Mirror of https://github.com/samuelclay/NewsBlur.git (synced 2025-08-05 16:58:59 +00:00)
Using sentence transformers and the MiniLM model to create embeddings for feeds.
Commit 0f579cbc03 (parent 27b1069302) · 4 changed files with 90 additions and 3 deletions
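
In outline: the Elasticsearch feed mapping gains a 384-dimensional `content_vector` dense_vector field, `SearchFeed.generate_feed_content_vector()` builds a normalized MiniLM embedding from a feed's title and story text, and a new `SearchFeed.vector_query()` ranks feeds by cosine similarity against that field. The core embedding step, pulled out as a standalone sketch (the model name and normalization come from the diff below; the sample text is made up):

```python
import numpy as np
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("all-MiniLM-L6-v2")  # same model the commit loads lazily

# Any cleaned feed text would do; this string is purely illustrative.
text = "daring fireball apple commentary and links"
embedding = model.encode(text)                      # 384-dim numpy array
normalized = embedding / np.linalg.norm(embedding)  # unit length, as in the diff
```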
@@ -341,14 +341,14 @@ class Feed(models.Model):
         SearchFeed.create_elasticsearch_mapping(delete=True)
 
         last_pk = cls.objects.latest("pk").pk
-        for f in range(offset, last_pk, 1000):
+        for f in range(offset, last_pk, 10):
             print(
                 " ---> {f} / {last_pk} ({pct}%)".format(
                     f=f, last_pk=last_pk, pct=str(float(f) / last_pk * 100)[:2]
                 )
             )
             feeds = Feed.objects.filter(
-                pk__in=range(f, f + 1000), active=True, active_subscribers__gte=subscribers
+                pk__in=range(f, f + 10), active=True, active_subscribers__gte=subscribers
             ).values_list("pk")
             for (feed_id,) in feeds:
                 Feed.objects.get(pk=feed_id).index_feed_for_search()
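
The batch size here drops from 1000 to 10, presumably because `index_feed_for_search()` now computes an embedding per feed (see the next hunk), which is far slower than metadata-only indexing. The enclosing classmethod's name is not visible in this hunk; assuming it is NewsBlur's bulk reindexer `Feed.index_all_for_search`, a run from a Django shell might look like:

```python
from apps.rss_feeds.models import Feed

# Rebuilds the mapping (delete=True above), then reindexes active feeds
# in batches of 10. The method name and keyword arguments are assumptions;
# the hunk only shows the loop body.
Feed.index_all_for_search(offset=0, subscribers=2)
```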
@@ -364,6 +364,7 @@ class Feed(models.Model):
             address=self.feed_address,
             link=self.feed_link,
             num_subscribers=self.num_subscribers,
+            content_vector=SearchFeed.generate_feed_content_vector(self.pk),
         )
 
     def index_stories_for_search(self):
@@ -6,11 +6,13 @@ import time
 import celery
 import elasticsearch
 import mongoengine as mongo
+import numpy as np
 import pymongo
 import redis
 import urllib3
 from django.conf import settings
 from django.contrib.auth.models import User
+from sentence_transformers import SentenceTransformer
 
 from apps.search.tasks import (
     FinishIndexSubscriptionsForSearch,
@@ -491,6 +493,7 @@ class SearchStory:
 class SearchFeed:
     _es_client = None
     name = "feeds"
+    model = None
 
     @classmethod
     def ES(cls):
@@ -574,6 +577,10 @@ class SearchFeed:
                 "term_vector": "with_positions_offsets",
                 "type": "text",
             },
+            "content_vector": {
+                "type": "dense_vector",
+                "dims": 384,  # Number of dims from all-MiniLM-L6-v2
+            },
         }
         cls.ES().indices.put_mapping(
             body={
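
For reference, all-MiniLM-L6-v2 produces 384-dimensional sentence embeddings, so `dims` must be exactly 384 or indexing the vectors will fail. A quick sanity check, assuming sentence-transformers is installed:

```python
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("all-MiniLM-L6-v2")
assert model.get_sentence_embedding_dimension() == 384  # must match "dims" in the mapping
```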
@@ -584,13 +591,14 @@ class SearchFeed:
         cls.ES().indices.flush(cls.index_name())
 
     @classmethod
-    def index(cls, feed_id, title, address, link, num_subscribers):
+    def index(cls, feed_id, title, address, link, num_subscribers, content_vector):
         doc = {
             "feed_id": feed_id,
             "title": title,
             "feed_address": address,
             "link": link,
             "num_subscribers": num_subscribers,
+            "content_vector": content_vector,
         }
         try:
             cls.ES().create(index=cls.index_name(), id=feed_id, body=doc, doc_type=cls.doc_type())
@@ -681,6 +689,76 @@ class SearchFeed:
 
         return results["hits"]["hits"]
 
+    @classmethod
+    def vector_query(cls, query_vector, max_results=10):
+        try:
+            cls.ES().indices.flush(index=cls.index_name())
+        except elasticsearch.exceptions.NotFoundError as e:
+            logging.debug(f" ***> ~FRNo search server available: {e}")
+            return []
+
+        body = {
+            "query": {
+                "script_score": {
+                    "query": {"match_all": {}},
+                    "script": {
+                        "source": "cosineSimilarity(params.query_vector, 'content_vector') + 1.0",
+                        "params": {"query_vector": query_vector},
+                    },
+                }
+            },
+            "size": max_results,
+        }
+        try:
+            results = cls.ES().search(body=body, index=cls.index_name(), doc_type=cls.doc_type())
+        except elasticsearch.exceptions.RequestError as e:
+            logging.debug(" ***> ~FRNo search server available for querying: %s" % e)
+            return []
+
+        logging.info(
+            f"~FGVector search ~FCfeeds~FG: ~SB{max_results}~SN requested, ~SB{len(results['hits']['hits'])}~SN results"
+        )
+
+        return results["hits"]["hits"]
+
+    @classmethod
+    def generate_feed_content_vector(cls, feed_id, text=None):
+        from apps.rss_feeds.models import Feed
+
+        if cls.model is None:
+            cls.model = SentenceTransformer("all-MiniLM-L6-v2")
+
+        if text is None:
+            feed = Feed.objects.get(id=feed_id)
+
+            # cross_encoder = CrossEncoder("BAAI/bge-large-zh-v2", device="cpu")
+            # cross_encoder.encode([feed.feed_title, feed.feed_content], convert_to_tensors="all")
+
+            stories = feed.get_stories()
+            stories_text = ""
+            for story in stories:
+                stories_text += f"{story['story_title']} {story['story_authors']} {story['story_content']}"
+            text = f"{feed.feed_title} {stories_text}"
+
+        # Remove URLs
+        text = re.sub(r"http\S+", "", text)
+
+        # Remove special characters
+        text = re.sub(r"[^\w\s]", "", text)
+
+        # Convert to lowercase
+        text = text.lower()
+
+        # Remove extra whitespace
+        text = " ".join(text.split())
+
+        encoded_text = cls.model.encode(text)
+        normalized_embedding = encoded_text / np.linalg.norm(encoded_text)
+
+        # logging.debug(f" ---> ~FGNormalized embedding for feed {feed_id}: {normalized_embedding}")
+
+        return normalized_embedding
+
     @classmethod
     def export_csv(cls):
         import djqscsv
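
Together these two methods cover the embed-then-search round trip. A minimal sketch of combining them from application code; note that `generate_feed_content_vector()` returns a numpy array while `script_score` params must be JSON-serializable, so the `.tolist()` call is a defensive assumption, not something this diff shows:

```python
from apps.search.models import SearchFeed

# Passing text= skips the feed lookup, so feed_id is unused in that branch
# (it still has to be supplied because of the signature).
query_vector = SearchFeed.generate_feed_content_vector(feed_id=None, text="indie web publishing")

hits = SearchFeed.vector_query(query_vector.tolist(), max_results=5)
for hit in hits:
    # _score is cosineSimilarity + 1.0, so it ranges over [0.0, 2.0]
    print(hit["_id"], hit["_score"])
```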
@@ -126,3 +126,4 @@ webencodings==0.5.1
 XlsxWriter==1.3.7
 zope.event==4.5.0
 zope.interface==5.4.0
+sentence_transformers==3.0.1
@@ -110,6 +110,13 @@ MAX_EMAILS_SENT_PER_DAY_PER_USER = 20  # Most are story notifications
 # = Django-specific Modules =
 # ===========================
 
+SHELL_PLUS_IMPORTS = [
+    "from apps.search.models import SearchFeed, SearchStory",
+    "import redis",
+    "import datetime",
+    "from pprint import pprint",
+]
+# SHELL_PLUS_PRINT_SQL = True
 
 MIDDLEWARE = (
     "django_prometheus.middleware.PrometheusBeforeMiddleware",
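
The SHELL_PLUS_IMPORTS addition only affects django-extensions' shell_plus: it preloads the search models so the new embedding code can be exercised interactively. For example (the feed id is a made-up placeholder):

```python
# $ python manage.py shell_plus
# SearchFeed, redis, datetime, and pprint are preloaded per the settings above.
vector = SearchFeed.generate_feed_content_vector(42)
pprint(SearchFeed.vector_query(vector.tolist(), max_results=3))
```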