mirror of https://github.com/samuelclay/NewsBlur.git
Using sentence transformers and the MiniLM model to create embeddings for feeds.
This commit is contained in: commit 0f579cbc03 (parent 27b1069302)
4 changed files with 90 additions and 3 deletions
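In outline: each feed's title and recent story text are encoded into a 384-dimensional vector with the all-MiniLM-L6-v2 model, the normalized vector is stored in an Elasticsearch dense_vector field, and queries rank feeds by cosine similarity. A minimal standalone sketch of the encoding step the diff below relies on (assumes sentence-transformers and numpy are installed; the sample string is illustrative):

import numpy as np
from sentence_transformers import SentenceTransformer

# all-MiniLM-L6-v2 produces 384-dimensional sentence embeddings
model = SentenceTransformer("all-MiniLM-L6-v2")
embedding = model.encode("feed title plus recent story text")
normalized = embedding / np.linalg.norm(embedding)  # unit length, as in this commit
print(normalized.shape)  # (384,)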
@@ -341,14 +341,14 @@ class Feed(models.Model):
         SearchFeed.create_elasticsearch_mapping(delete=True)
 
         last_pk = cls.objects.latest("pk").pk
-        for f in range(offset, last_pk, 1000):
+        for f in range(offset, last_pk, 10):
             print(
                 " ---> {f} / {last_pk} ({pct}%)".format(
                     f=f, last_pk=last_pk, pct=str(float(f) / last_pk * 100)[:2]
                 )
             )
             feeds = Feed.objects.filter(
-                pk__in=range(f, f + 1000), active=True, active_subscribers__gte=subscribers
+                pk__in=range(f, f + 10), active=True, active_subscribers__gte=subscribers
             ).values_list("pk")
             for (feed_id,) in feeds:
                 Feed.objects.get(pk=feed_id).index_feed_for_search()
@@ -364,6 +364,7 @@ class Feed(models.Model):
             address=self.feed_address,
             link=self.feed_link,
             num_subscribers=self.num_subscribers,
+            content_vector=SearchFeed.generate_feed_content_vector(self.pk),
         )
 
     def index_stories_for_search(self):
@@ -6,11 +6,13 @@ import time
 import celery
 import elasticsearch
 import mongoengine as mongo
+import numpy as np
 import pymongo
 import redis
 import urllib3
 from django.conf import settings
 from django.contrib.auth.models import User
+from sentence_transformers import SentenceTransformer
 
 from apps.search.tasks import (
     FinishIndexSubscriptionsForSearch,
@@ -491,6 +493,7 @@ class SearchStory:
 class SearchFeed:
     _es_client = None
     name = "feeds"
+    model = None
 
     @classmethod
     def ES(cls):
@@ -574,6 +577,10 @@ class SearchFeed:
                 "term_vector": "with_positions_offsets",
                 "type": "text",
             },
+            "content_vector": {
+                "type": "dense_vector",
+                "dims": 384,  # Number of dims produced by all-MiniLM-L6-v2
+            },
         }
         cls.ES().indices.put_mapping(
             body={
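The dims value has to match the model's embedding size exactly, since Elasticsearch rejects vectors of any other length. A quick way to confirm the figure, assuming sentence-transformers is available:

from sentence_transformers import SentenceTransformer

model = SentenceTransformer("all-MiniLM-L6-v2")
# Should print 384, matching the "dims" in the mapping above
print(model.get_sentence_embedding_dimension())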
@@ -584,13 +591,14 @@ class SearchFeed:
         cls.ES().indices.flush(cls.index_name())
 
     @classmethod
-    def index(cls, feed_id, title, address, link, num_subscribers):
+    def index(cls, feed_id, title, address, link, num_subscribers, content_vector):
         doc = {
             "feed_id": feed_id,
             "title": title,
             "feed_address": address,
             "link": link,
             "num_subscribers": num_subscribers,
+            "content_vector": content_vector,
         }
         try:
             cls.ES().create(index=cls.index_name(), id=feed_id, body=doc, doc_type=cls.doc_type())
@@ -681,6 +689,76 @@ class SearchFeed:
 
         return results["hits"]["hits"]
 
+    @classmethod
+    def vector_query(cls, query_vector, max_results=10):
+        try:
+            cls.ES().indices.flush(index=cls.index_name())
+        except elasticsearch.exceptions.NotFoundError as e:
+            logging.debug(f" ***> ~FRNo search server available: {e}")
+            return []
+
+        body = {
+            "query": {
+                "script_score": {
+                    "query": {"match_all": {}},
+                    "script": {
+                        "source": "cosineSimilarity(params.query_vector, 'content_vector') + 1.0",
+                        "params": {"query_vector": query_vector},
+                    },
+                }
+            },
+            "size": max_results,
+        }
+        try:
+            results = cls.ES().search(body=body, index=cls.index_name(), doc_type=cls.doc_type())
+        except elasticsearch.exceptions.RequestError as e:
+            logging.debug(" ***> ~FRNo search server available for querying: %s" % e)
+            return []
+
+        logging.info(
+            f"~FGVector search ~FCfeeds~FG: ~SB{max_results}~SN requested, ~SB{len(results['hits']['hits'])}~SN results"
+        )
+
+        return results["hits"]["hits"]
+
+    @classmethod
+    def generate_feed_content_vector(cls, feed_id, text=None):
+        from apps.rss_feeds.models import Feed
+
+        if cls.model is None:
+            cls.model = SentenceTransformer("all-MiniLM-L6-v2")
+
+        if text is None:
+            feed = Feed.objects.get(id=feed_id)
+
+            # cross_encoder = CrossEncoder("BAAI/bge-large-zh-v2", device="cpu")
+            # cross_encoder.encode([feed.feed_title, feed.feed_content], convert_to_tensors="all")
+
+            stories = feed.get_stories()
+            stories_text = ""
+            for story in stories:
+                stories_text += f"{story['story_title']} {story['story_authors']} {story['story_content']}"
+            text = f"{feed.feed_title} {stories_text}"
+
+        # Remove URLs
+        text = re.sub(r"http\S+", "", text)
+
+        # Remove special characters
+        text = re.sub(r"[^\w\s]", "", text)
+
+        # Convert to lowercase
+        text = text.lower()
+
+        # Remove extra whitespace
+        text = " ".join(text.split())
+
+        encoded_text = cls.model.encode(text)
+        normalized_embedding = encoded_text / np.linalg.norm(encoded_text)
+
+        # logging.debug(f" ---> ~FGNormalized embedding for feed {feed_id}: {normalized_embedding}")
+
+        return normalized_embedding
+
     @classmethod
     def export_csv(cls):
         import djqscsv
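The script_score query computes cosineSimilarity against every document matched by match_all, and the + 1.0 keeps scores non-negative, which Elasticsearch requires of script scores. A sketch of how the two new methods fit together, run from a Django context; the .tolist() call is an assumption here, to make the numpy array JSON-serializable for the request body:

from apps.search.models import SearchFeed

# Passing text directly skips the feed lookup; feed_id is unused in that branch
query_vector = SearchFeed.generate_feed_content_vector(feed_id=None, text="space exploration news")

# Convert the numpy array to a plain list for the Elasticsearch JSON body (assumption)
hits = SearchFeed.vector_query(query_vector.tolist(), max_results=5)
for hit in hits:
    print(hit["_id"], hit["_score"])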
@@ -126,3 +126,4 @@ webencodings==0.5.1
 XlsxWriter==1.3.7
 zope.event==4.5.0
 zope.interface==5.4.0
+sentence_transformers==3.0.1
@@ -110,6 +110,13 @@ MAX_EMAILS_SENT_PER_DAY_PER_USER = 20  # Most are story notifications
 # = Django-specific Modules =
 # ===========================
 
+SHELL_PLUS_IMPORTS = [
+    "from apps.search.models import SearchFeed, SearchStory",
+    "import redis",
+    "import datetime",
+    "from pprint import pprint",
+]
+# SHELL_PLUS_PRINT_SQL = True
 
 MIDDLEWARE = (
     "django_prometheus.middleware.PrometheusBeforeMiddleware",
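The SHELL_PLUS_IMPORTS addition preloads SearchFeed and SearchStory in django-extensions' shell_plus, which makes it easy to exercise the new vector search interactively. A hypothetical session, assuming a running Elasticsearch and an existing feed id of 42:

# $ python manage.py shell_plus
vector = SearchFeed.generate_feed_content_vector(feed_id=42)
pprint(SearchFeed.vector_query(vector.tolist())[:1])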