diff --git a/apps/rss_feeds/models.py b/apps/rss_feeds/models.py
index cb0a34672..0e647ecbc 100755
--- a/apps/rss_feeds/models.py
+++ b/apps/rss_feeds/models.py
@@ -341,14 +341,14 @@ class Feed(models.Model):
         SearchFeed.create_elasticsearch_mapping(delete=True)
 
         last_pk = cls.objects.latest("pk").pk
-        for f in range(offset, last_pk, 1000):
+        for f in range(offset, last_pk, 10):  # NOTE(review): was 1000 — a 10-feed batch looks like a debugging leftover; restore before merge
             print(
                 " ---> {f} / {last_pk} ({pct}%)".format(
                     f=f, last_pk=last_pk, pct=str(float(f) / last_pk * 100)[:2]
                 )
             )
             feeds = Feed.objects.filter(
-                pk__in=range(f, f + 1000), active=True, active_subscribers__gte=subscribers
+                pk__in=range(f, f + 10), active=True, active_subscribers__gte=subscribers
             ).values_list("pk")
             for (feed_id,) in feeds:
                 Feed.objects.get(pk=feed_id).index_feed_for_search()
@@ -364,6 +364,7 @@ class Feed(models.Model):
             address=self.feed_address,
             link=self.feed_link,
             num_subscribers=self.num_subscribers,
+            content_vector=SearchFeed.generate_feed_content_vector(self.pk),
         )
 
     def index_stories_for_search(self):
diff --git a/apps/search/models.py b/apps/search/models.py
index 949c67542..06aa54516 100644
--- a/apps/search/models.py
+++ b/apps/search/models.py
@@ -6,11 +6,13 @@ import time
 import celery
 import elasticsearch
 import mongoengine as mongo
+import numpy as np
 import pymongo
 import redis
 import urllib3
 from django.conf import settings
 from django.contrib.auth.models import User
+from sentence_transformers import SentenceTransformer
 
 from apps.search.tasks import (
     FinishIndexSubscriptionsForSearch,
@@ -491,6 +493,7 @@ class SearchStory:
 class SearchFeed:
     _es_client = None
     name = "feeds"
+    model = None
 
     @classmethod
     def ES(cls):
@@ -574,6 +577,10 @@
                 "term_vector": "with_positions_offsets",
                 "type": "text",
             },
+            "content_vector": {
+                "type": "dense_vector",
+                "dims": 384,  # Number of dims from all-MiniLM-L6-v2
+            },
         }
         cls.ES().indices.put_mapping(
             body={
@@ -584,13 +591,14 @@
         cls.ES().indices.flush(cls.index_name())
 
     @classmethod
-    def index(cls, feed_id, title, address, link, num_subscribers):
+    def index(cls, feed_id, title, address, link, num_subscribers, content_vector):
         doc = {
             "feed_id": feed_id,
             "title": title,
             "feed_address": address,
             "link": link,
             "num_subscribers": num_subscribers,
+            "content_vector": content_vector,
         }
         try:
             cls.ES().create(index=cls.index_name(), id=feed_id, body=doc, doc_type=cls.doc_type())
@@ -681,6 +689,76 @@
         return results["hits"]["hits"]
 
+    @classmethod
+    def vector_query(cls, query_vector, max_results=10):
+        try:
+            cls.ES().indices.flush(index=cls.index_name())
+        except elasticsearch.exceptions.NotFoundError as e:
+            logging.debug(f" ***> ~FRNo search server available: {e}")
+            return []
+
+        body = {
+            "query": {
+                "script_score": {
+                    "query": {"match_all": {}},
+                    "script": {
+                        "source": "cosineSimilarity(params.query_vector, 'content_vector') + 1.0",
+                        "params": {"query_vector": query_vector},
+                    },
+                }
+            },
+            "size": max_results,
+        }
+        try:
+            results = cls.ES().search(body=body, index=cls.index_name(), doc_type=cls.doc_type())
+        except elasticsearch.exceptions.RequestError as e:
+            logging.debug(" ***> ~FRNo search server available for querying: %s" % e)
+            return []
+
+        logging.info(
+            f"~FGVector search ~FCfeeds~FG: ~SB{max_results}~SN requested, ~SB{len(results['hits']['hits'])}~SN results"
+        )
+
+        return results["hits"]["hits"]
+
+    @classmethod
+    def generate_feed_content_vector(cls, feed_id, text=None):
+        from apps.rss_feeds.models import Feed
+
+        if cls.model is None:
+            cls.model = SentenceTransformer("all-MiniLM-L6-v2")
+
+        if text is None:
+            feed = Feed.objects.get(id=feed_id)
+
+            # cross_encoder = CrossEncoder("BAAI/bge-large-zh-v2", device="cpu")
+            # cross_encoder.encode([feed.feed_title, feed.feed_content], convert_to_tensors="all")
+
+            stories = feed.get_stories()
+            stories_text = ""
+            for story in stories:
+                stories_text += f"{story['story_title']} {story['story_authors']} {story['story_content']}"
+            text = f"{feed.feed_title} {stories_text}"
+
+        # Remove URLs
+        text = re.sub(r"http\S+", "", text)
+
+        # Remove special characters
+        text = re.sub(r"[^\w\s]", "", text)
+
+        # Convert to lowercase
+        text = text.lower()
+
+        # Remove extra whitespace
+        text = " ".join(text.split())
+
+        encoded_text = cls.model.encode(text)
+        normalized_embedding = encoded_text / np.linalg.norm(encoded_text)
+
+        # logging.debug(f" ---> ~FGNormalized embedding for feed {feed_id}: {normalized_embedding}")
+
+        return normalized_embedding  # NOTE(review): numpy ndarray — confirm the ES client JSON-serializes this; may need .tolist()
+
     @classmethod
     def export_csv(cls):
         import djqscsv
diff --git a/config/requirements.txt b/config/requirements.txt
index 5c8338985..f56901953 100644
--- a/config/requirements.txt
+++ b/config/requirements.txt
@@ -126,3 +126,4 @@ webencodings==0.5.1
 XlsxWriter==1.3.7
 zope.event==4.5.0
 zope.interface==5.4.0
+sentence_transformers==3.0.1
diff --git a/newsblur_web/settings.py b/newsblur_web/settings.py
index 35b2e3c71..f69a17baa 100644
--- a/newsblur_web/settings.py
+++ b/newsblur_web/settings.py
@@ -110,6 +110,13 @@ MAX_EMAILS_SENT_PER_DAY_PER_USER = 20  # Most are story notifications
 # = Django-specific Modules =
 # ===========================
 
+SHELL_PLUS_IMPORTS = [
+    "from apps.search.models import SearchFeed, SearchStory",
+    "import redis",
+    "import datetime",
+    "from pprint import pprint",
+]
+# SHELL_PLUS_PRINT_SQL = True
 MIDDLEWARE = (
     "django_prometheus.middleware.PrometheusBeforeMiddleware",