Using sentence transformers and the MiniLM model to create embeddings for feeds.

Samuel Clay 2024-06-30 09:13:25 -04:00
parent 27b1069302
commit 0f579cbc03
4 changed files with 90 additions and 3 deletions
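
In outline, the change embeds each feed's text with the all-MiniLM-L6-v2 sentence-transformers model, stores the normalized 384-dimension vector in Elasticsearch, and ranks feeds by cosine similarity at query time. A minimal sketch of the encoding step, assuming only that sentence-transformers and numpy are installed (the input string is illustrative):

import numpy as np
from sentence_transformers import SentenceTransformer

# all-MiniLM-L6-v2 maps text to a 384-dimensional vector.
model = SentenceTransformer("all-MiniLM-L6-v2")
embedding = model.encode("feed title and concatenated story text")
# Normalize to unit length so cosine similarity reduces to a dot product.
normalized = embedding / np.linalg.norm(embedding)
print(normalized.shape)  # (384,)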

@@ -341,14 +341,14 @@ class Feed(models.Model):
         SearchFeed.create_elasticsearch_mapping(delete=True)
 
         last_pk = cls.objects.latest("pk").pk
-        for f in range(offset, last_pk, 1000):
+        for f in range(offset, last_pk, 10):
             print(
                 " ---> {f} / {last_pk} ({pct}%)".format(
                     f=f, last_pk=last_pk, pct=str(float(f) / last_pk * 100)[:2]
                 )
             )
             feeds = Feed.objects.filter(
-                pk__in=range(f, f + 1000), active=True, active_subscribers__gte=subscribers
+                pk__in=range(f, f + 10), active=True, active_subscribers__gte=subscribers
             ).values_list("pk")
             for (feed_id,) in feeds:
                 Feed.objects.get(pk=feed_id).index_feed_for_search()
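
The loop's step and the slice width have to change together, which is why both 1000s become 10s; a small sketch with illustrative numbers shows the invariant:

# Step and slice width must match: a larger step would skip pks,
# a smaller one would index some feeds twice.
offset, last_pk, batch = 0, 25, 10
for f in range(offset, last_pk, batch):
    print(list(range(f, f + batch)))
# Prints pks 0-9, 10-19, 20-29; ids past last_pk simply match no feeds.
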
@@ -364,6 +364,7 @@ class Feed(models.Model):
             address=self.feed_address,
             link=self.feed_link,
             num_subscribers=self.num_subscribers,
+            content_vector=SearchFeed.generate_feed_content_vector(self.pk),
         )
 
     def index_stories_for_search(self):

@@ -6,11 +6,13 @@ import time
 import celery
 import elasticsearch
 import mongoengine as mongo
+import numpy as np
 import pymongo
 import redis
 import urllib3
 from django.conf import settings
 from django.contrib.auth.models import User
+from sentence_transformers import SentenceTransformer
 
 from apps.search.tasks import (
     FinishIndexSubscriptionsForSearch,
@@ -491,6 +493,7 @@ class SearchStory:
 class SearchFeed:
     _es_client = None
     name = "feeds"
+    model = None
 
     @classmethod
     def ES(cls):
@@ -574,6 +577,10 @@ class SearchFeed:
                 "term_vector": "with_positions_offsets",
                 "type": "text",
             },
+            "content_vector": {
+                "type": "dense_vector",
+                "dims": 384,  # Number of dims from all-MiniLM-L6-v2
+            },
         }
         cls.ES().indices.put_mapping(
             body={
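
The dims value is load-bearing: Elasticsearch rejects any vector whose length differs from the mapping. A one-line check against the model, using the sentence-transformers API:

from sentence_transformers import SentenceTransformer

# all-MiniLM-L6-v2 produces 384-dimension embeddings, matching dims above.
assert SentenceTransformer("all-MiniLM-L6-v2").get_sentence_embedding_dimension() == 384
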
@@ -584,13 +591,14 @@ class SearchFeed:
         cls.ES().indices.flush(cls.index_name())
 
     @classmethod
-    def index(cls, feed_id, title, address, link, num_subscribers):
+    def index(cls, feed_id, title, address, link, num_subscribers, content_vector):
         doc = {
             "feed_id": feed_id,
             "title": title,
             "feed_address": address,
             "link": link,
             "num_subscribers": num_subscribers,
+            "content_vector": content_vector,
         }
         try:
             cls.ES().create(index=cls.index_name(), id=feed_id, body=doc, doc_type=cls.doc_type())
@@ -681,6 +689,76 @@ class SearchFeed:
         return results["hits"]["hits"]
 
+    @classmethod
+    def vector_query(cls, query_vector, max_results=10):
+        try:
+            cls.ES().indices.flush(index=cls.index_name())
+        except elasticsearch.exceptions.NotFoundError as e:
+            logging.debug(f" ***> ~FRNo search server available: {e}")
+            return []
+
+        body = {
+            "query": {
+                "script_score": {
+                    "query": {"match_all": {}},
+                    "script": {
+                        "source": "cosineSimilarity(params.query_vector, 'content_vector') + 1.0",
+                        "params": {"query_vector": query_vector},
+                    },
+                }
+            },
+            "size": max_results,
+        }
+        try:
+            results = cls.ES().search(body=body, index=cls.index_name(), doc_type=cls.doc_type())
+        except elasticsearch.exceptions.RequestError as e:
+            logging.debug(" ***> ~FRNo search server available for querying: %s" % e)
+            return []
+
+        logging.info(
+            f"~FGVector search ~FCfeeds~FG: ~SB{max_results}~SN requested, ~SB{len(results['hits']['hits'])}~SN results"
+        )
+        return results["hits"]["hits"]
+
+    @classmethod
+    def generate_feed_content_vector(cls, feed_id, text=None):
+        from apps.rss_feeds.models import Feed
+
+        if cls.model is None:
+            cls.model = SentenceTransformer("all-MiniLM-L6-v2")
+
+        if text is None:
+            feed = Feed.objects.get(id=feed_id)
+
+            # cross_encoder = CrossEncoder("BAAI/bge-large-zh-v2", device="cpu")
+            # cross_encoder.encode([feed.feed_title, feed.feed_content], convert_to_tensors="all")
+
+            stories = feed.get_stories()
+            stories_text = ""
+            for story in stories:
+                stories_text += f"{story['story_title']} {story['story_authors']} {story['story_content']}"
+            text = f"{feed.feed_title} {stories_text}"
+
+        # Remove URLs
+        text = re.sub(r"http\S+", "", text)
+        # Remove special characters
+        text = re.sub(r"[^\w\s]", "", text)
+        # Convert to lowercase
+        text = text.lower()
+        # Remove extra whitespace
+        text = " ".join(text.split())
+
+        encoded_text = cls.model.encode(text)
+        normalized_embedding = encoded_text / np.linalg.norm(encoded_text)
+        # logging.debug(f" ---> ~FGNormalized embedding for feed {feed_id}: {normalized_embedding}")
+
+        return normalized_embedding
+
     @classmethod
     def export_csv(cls):
         import djqscsv
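
Together the two new classmethods support ad-hoc semantic lookups. The + 1.0 in the script_score source exists because Elasticsearch requires non-negative scores while cosineSimilarity ranges over [-1, 1]; with unit-normalized vectors the resulting _score lands in [0, 2]. A hedged usage sketch (the query text is illustrative):

from apps.search.models import SearchFeed

# Passing text= skips the per-feed story lookup inside generate_feed_content_vector.
query_vector = SearchFeed.generate_feed_content_vector(None, text="python web scraping tutorials")
# tolist() keeps the request body JSON-serializable.
for hit in SearchFeed.vector_query(query_vector.tolist(), max_results=5):
    print(hit["_id"], hit["_score"])  # _score = cosine similarity + 1.0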

@@ -126,3 +126,4 @@ webencodings==0.5.1
 XlsxWriter==1.3.7
 zope.event==4.5.0
 zope.interface==5.4.0
+sentence_transformers==3.0.1

@@ -110,6 +110,13 @@ MAX_EMAILS_SENT_PER_DAY_PER_USER = 20  # Most are story notifications
 # = Django-specific Modules =
 # ===========================
 
+SHELL_PLUS_IMPORTS = [
+    "from apps.search.models import SearchFeed, SearchStory",
+    "import redis",
+    "import datetime",
+    "from pprint import pprint",
+]
+# SHELL_PLUS_PRINT_SQL = True
 
 MIDDLEWARE = (
     "django_prometheus.middleware.PrometheusBeforeMiddleware",