Using sentence transformers and the MiniLM model to create embeddings for feeds.

Samuel Clay 2024-06-30 09:13:25 -04:00
parent 27b1069302
commit 0f579cbc03
4 changed files with 90 additions and 3 deletions
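
What the change does end to end: each feed's title and recent story text is cleaned up, encoded with the all-MiniLM-L6-v2 sentence-transformer into a 384-dimension embedding, L2-normalized, and stored in a new dense_vector field on the feeds index, where a script_score query ranks feeds by cosine similarity. A minimal sketch of just the encode-and-normalize step, outside NewsBlur (the sample text is made up):

    import numpy as np
    from sentence_transformers import SentenceTransformer

    # Same checkpoint the commit pins on; fetched from the Hugging Face hub on first use.
    model = SentenceTransformer("all-MiniLM-L6-v2")

    # Placeholder text; the real code concatenates feed_title with story titles, authors, and content.
    text = "daring fireball apple commentary and link blog"

    embedding = model.encode(text)                     # numpy array of shape (384,)
    embedding = embedding / np.linalg.norm(embedding)  # unit length, so cosine similarity reduces to a dot product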

@@ -341,14 +341,14 @@ class Feed(models.Model):
             SearchFeed.create_elasticsearch_mapping(delete=True)

         last_pk = cls.objects.latest("pk").pk
-        for f in range(offset, last_pk, 1000):
+        for f in range(offset, last_pk, 10):
             print(
                 " ---> {f} / {last_pk} ({pct}%)".format(
                     f=f, last_pk=last_pk, pct=str(float(f) / last_pk * 100)[:2]
                 )
             )
             feeds = Feed.objects.filter(
-                pk__in=range(f, f + 1000), active=True, active_subscribers__gte=subscribers
+                pk__in=range(f, f + 10), active=True, active_subscribers__gte=subscribers
             ).values_list("pk")
             for (feed_id,) in feeds:
                 Feed.objects.get(pk=feed_id).index_feed_for_search()
@@ -364,6 +364,7 @@ class Feed(models.Model):
                 address=self.feed_address,
                 link=self.feed_link,
                 num_subscribers=self.num_subscribers,
+                content_vector=SearchFeed.generate_feed_content_vector(self.pk),
             )

     def index_stories_for_search(self):

@@ -6,11 +6,13 @@ import time
 import celery
 import elasticsearch
 import mongoengine as mongo
+import numpy as np
 import pymongo
 import redis
 import urllib3
 from django.conf import settings
 from django.contrib.auth.models import User
+from sentence_transformers import SentenceTransformer

 from apps.search.tasks import (
     FinishIndexSubscriptionsForSearch,
@@ -491,6 +493,7 @@ class SearchStory:
 class SearchFeed:
     _es_client = None
     name = "feeds"
+    model = None

     @classmethod
     def ES(cls):
@@ -574,6 +577,10 @@ class SearchFeed:
                 "term_vector": "with_positions_offsets",
                 "type": "text",
             },
+            "content_vector": {
+                "type": "dense_vector",
+                "dims": 384,  # Number of dims from all-MiniLM-L6-v2
+            },
         }
         cls.ES().indices.put_mapping(
             body={
@@ -584,13 +591,14 @@ class SearchFeed:
         cls.ES().indices.flush(cls.index_name())

     @classmethod
-    def index(cls, feed_id, title, address, link, num_subscribers):
+    def index(cls, feed_id, title, address, link, num_subscribers, content_vector):
         doc = {
             "feed_id": feed_id,
             "title": title,
             "feed_address": address,
             "link": link,
             "num_subscribers": num_subscribers,
+            "content_vector": content_vector,
         }
         try:
             cls.ES().create(index=cls.index_name(), id=feed_id, body=doc, doc_type=cls.doc_type())
@@ -681,6 +689,76 @@ class SearchFeed:
         return results["hits"]["hits"]

+    @classmethod
+    def vector_query(cls, query_vector, max_results=10):
+        try:
+            cls.ES().indices.flush(index=cls.index_name())
+        except elasticsearch.exceptions.NotFoundError as e:
+            logging.debug(f" ***> ~FRNo search server available: {e}")
+            return []
+
+        body = {
+            "query": {
+                "script_score": {
+                    "query": {"match_all": {}},
+                    "script": {
+                        "source": "cosineSimilarity(params.query_vector, 'content_vector') + 1.0",
+                        "params": {"query_vector": query_vector},
+                    },
+                }
+            },
+            "size": max_results,
+        }
+        try:
+            results = cls.ES().search(body=body, index=cls.index_name(), doc_type=cls.doc_type())
+        except elasticsearch.exceptions.RequestError as e:
+            logging.debug(" ***> ~FRNo search server available for querying: %s" % e)
+            return []
+
+        logging.info(
+            f"~FGVector search ~FCfeeds~FG: ~SB{max_results}~SN requested, ~SB{len(results['hits']['hits'])}~SN results"
+        )
+        return results["hits"]["hits"]
+
+    @classmethod
+    def generate_feed_content_vector(cls, feed_id, text=None):
+        from apps.rss_feeds.models import Feed
+
+        if cls.model is None:
+            cls.model = SentenceTransformer("all-MiniLM-L6-v2")
+
+        if text is None:
+            feed = Feed.objects.get(id=feed_id)
+
+            # cross_encoder = CrossEncoder("BAAI/bge-large-zh-v2", device="cpu")
+            # cross_encoder.encode([feed.feed_title, feed.feed_content], convert_to_tensors="all")
+
+            stories = feed.get_stories()
+            stories_text = ""
+            for story in stories:
+                stories_text += f"{story['story_title']} {story['story_authors']} {story['story_content']}"
+            text = f"{feed.feed_title} {stories_text}"
+
+        # Remove URLs
+        text = re.sub(r"http\S+", "", text)
+
+        # Remove special characters
+        text = re.sub(r"[^\w\s]", "", text)
+
+        # Convert to lowercase
+        text = text.lower()
+
+        # Remove extra whitespace
+        text = " ".join(text.split())
+
+        encoded_text = cls.model.encode(text)
+        normalized_embedding = encoded_text / np.linalg.norm(encoded_text)
+
+        # logging.debug(f" ---> ~FGNormalized embedding for feed {feed_id}: {normalized_embedding}")
+
+        return normalized_embedding
+
     @classmethod
     def export_csv(cls):
         import djqscsv
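
A sketch of how the two new classmethods fit together, e.g. from the shell_plus session configured at the bottom of this commit. The feed id is a placeholder, and passing the vector through .tolist() is an assumption about what the Elasticsearch client serializes cleanly, not something this diff does:

    from apps.search.models import SearchFeed

    # Encode one feed's title plus recent stories into a normalized 384-dim vector.
    query_vector = SearchFeed.generate_feed_content_vector(42)  # 42 is a made-up feed id

    # Rank indexed feeds by cosine similarity; scores land in [0, 2] because of the
    # "+ 1.0" in the script_score source above.
    hits = SearchFeed.vector_query(query_vector.tolist(), max_results=5)
    for hit in hits:
        print(hit["_source"]["feed_id"], hit["_score"])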

@@ -126,3 +126,4 @@ webencodings==0.5.1
 XlsxWriter==1.3.7
 zope.event==4.5.0
 zope.interface==5.4.0
+sentence_transformers==3.0.1

@@ -110,6 +110,13 @@ MAX_EMAILS_SENT_PER_DAY_PER_USER = 20  # Most are story notifications
 # = Django-specific Modules =
 # ===========================

+SHELL_PLUS_IMPORTS = [
+    "from apps.search.models import SearchFeed, SearchStory",
+    "import redis",
+    "import datetime",
+    "from pprint import pprint",
+]
+# SHELL_PLUS_PRINT_SQL = True
+
 MIDDLEWARE = (
     "django_prometheus.middleware.PrometheusBeforeMiddleware",