Spaces:
Runtime error
Runtime error
import uuid | |
from typing import List, Dict | |
from qdrant_client import QdrantClient, models as qmodels | |
from llama_index.llms.openai import OpenAI | |
from fastembed import TextEmbedding | |
from .models import FoodItem | |
from .utils import synthesize_food_item | |
# Starter preferences seeded by the ``__main__`` demo at the bottom of the file.
likes = [
    "dosa",
    "fanta",
    "croissant",
    "waffles",
]
dislikes = [
    "virgin mojito",
]
# Candidate items the demo asks the engine to rank.
menu = [
    "croissant",
    "mango",
    "jalebi",
]
class RecommendationEngine:
    """Rank candidate items against liked/disliked examples stored in Qdrant.

    Every preference is one point in the ``{category}_preferences`` collection.
    Its vector is a fastembed embedding of the item's fields. Its payload is the
    item's ``model_dump()`` plus any extra flags (``liked`` for preferences,
    ``query_id`` for temporary recommendation candidates).
    """

    # Dimension used when the collection has to be (re)created. The original
    # code hard-coded 384; presumably this is the output size of the fastembed
    # model that gets passed in. NOTE(review): confirm when using another model.
    DEFAULT_VECTOR_SIZE = 384

    def __init__(
        self,
        category: str,
        qdrant: QdrantClient,
        fastembed_model: TextEmbedding,
        vector_size: int = DEFAULT_VECTOR_SIZE,
    ) -> None:
        """Attach to the preference collection for *category*, creating it if absent.

        Args:
            category: Prefix for the collection name (``"{category}_preferences"``).
            qdrant: Client used for every collection/point operation.
            fastembed_model: Model used to embed items.
            vector_size: Vector dimension used only when the collection is
                (re)created; must match the embedding model's output size.
        """
        self.collection = f"{category}_preferences"
        self.qdrant = qdrant
        self.embedding_model = fastembed_model
        self.vector_size = vector_size
        if self.qdrant.collection_exists(self.collection):
            # Number of points currently stored. The ``__main__`` demo compares
            # this to its seed data to decide whether to re-seed.
            self.counter = self.qdrant.count(self.collection, exact=True).count
        else:
            self.reset()  # creates the collection and sets counter to 0

    def reset(self):
        """Drop every stored point by recreating the collection empty.

        ``delete_collection`` + ``create_collection`` replace the deprecated
        ``recreate_collection``.
        """
        if self.qdrant.collection_exists(self.collection):
            self.qdrant.delete_collection(self.collection)
        self.qdrant.create_collection(
            self.collection,
            vectors_config=qmodels.VectorParams(
                size=self.vector_size, distance=qmodels.Distance.COSINE
            ),
        )
        # Keep the point count in sync with the now-empty collection.
        self.counter = 0

    def _generate_vector(self, model_json: dict) -> List[float]:
        """Embed the ``key: value`` pairs of *model_json* as one passage.

        Returns the embedding as plain Python floats so it can be passed
        straight to ``PointStruct``.
        """
        # Pairs are concatenated without a separator, exactly as before.
        # Changing this text would make new vectors inconsistent with the
        # preferences already stored.
        embedding_txt = "".join(f"{k}: {v}" for k, v in model_json.items())
        embedding = next(iter(self.embedding_model.passage_embed([embedding_txt])))
        return [float(x) for x in embedding]

    @staticmethod
    def _match_filter(key: str, value) -> qmodels.Filter:
        """Build a filter matching points whose payload *key* equals *value*."""
        return qmodels.Filter(
            must=[
                qmodels.FieldCondition(key=key, match=qmodels.MatchValue(value=value))
            ]
        )

    def _preference_ids(self, liked: bool, page_size: int = 256) -> List:
        """Return the ids of ALL points whose ``liked`` flag equals *liked*.

        Pages through ``scroll`` until ``next_page_offset`` is ``None``; a
        single call would return only the first page (10 points by default).
        """
        ids = []
        offset = None
        query_filter = self._match_filter("liked", liked)
        while True:
            points, offset = self.qdrant.scroll(
                self.collection,
                scroll_filter=query_filter,
                limit=page_size,
                offset=offset,
                with_payload=False,
                with_vectors=False,
            )
            ids.extend(p.id for p in points)
            if offset is None:
                return ids

    def _insert_preference(self, item: FoodItem, *args, **kwargs):
        """Embed *item* and upsert it as a new point.

        Keyword arguments (e.g. ``liked``, ``query_id``) are stored in the
        payload only; they are not part of the embedded text. Positional
        ``*args`` are accepted for compatibility but ignored.
        """
        model_json: dict = item.model_dump()
        # Embed before merging kwargs so flags don't influence the vector.
        embedding = self._generate_vector(model_json)
        model_json.update(kwargs)
        self.qdrant.upsert(
            self.collection,
            points=[
                qmodels.PointStruct(
                    # Random UUIDs cannot collide with surviving points. Ids
                    # derived from the point count can, once points have been
                    # deleted, and upsert would then silently overwrite them.
                    id=str(uuid.uuid4()),
                    payload=model_json,
                    vector=embedding,
                )
            ],
        )
        self.counter += 1

    def like(self, item: FoodItem):
        """Store *item* as a positive example."""
        self._insert_preference(item, liked=True)

    def dislike(self, item: FoodItem):
        """Store *item* as a negative example."""
        self._insert_preference(item, liked=False)

    def recommend_from_given(
        self, items: List[FoodItem], limit: int = 3
    ) -> Dict[str, float]:
        """Score *items* against the stored likes/dislikes and return the best.

        The candidates are inserted into the collection temporarily, tagged
        with a fresh ``query_id``, so that Qdrant's ``recommend`` can return
        them. They are always removed again, even if the query fails.

        Args:
            items: Candidate items to rank.
            limit: Maximum number of recommendations to return.

        Returns:
            Mapping of each returned item's payload ``name`` to its score.
        """
        if not items:
            return {}
        liked_ids = self._preference_ids(liked=True)
        disliked_ids = self._preference_ids(liked=False)

        query_id = str(uuid.uuid4())
        query_filter = self._match_filter("query_id", query_id)
        inserted = 0
        try:
            for item in items:
                self._insert_preference(item, query_id=query_id)
                inserted += 1
            scored_points = self.qdrant.recommend(
                self.collection,
                positive=liked_ids,
                negative=disliked_ids,
                query_filter=query_filter,
                limit=limit,
                with_payload=True,
                strategy="best_score",
            )
        finally:
            # Remove every temporary candidate, not just the ones returned,
            # so none of them linger in the preference store.
            self.qdrant.delete(
                self.collection,
                points_selector=qmodels.FilterSelector(filter=query_filter),
            )
            self.counter -= inserted
        return {point.payload["name"]: point.score for point in scored_points}
if __name__ == "__main__":
    # Demo: make sure the store holds the starter preferences, then rank the
    # sample menu against them and print the resulting name -> score mapping.
    llm = OpenAI(model="gpt-3.5-turbo")
    qdrant = QdrantClient()
    fastembed_model = TextEmbedding()
    rec_engine = RecommendationEngine("food", qdrant, fastembed_model)

    # Re-seed only when the stored point count differs from the seed size.
    seed_size = len(likes) + len(dislikes)
    if rec_engine.counter != seed_size:
        rec_engine.reset()
        print("Filling with starter data")
        # Likes first, then dislikes, each synthesized just before storing.
        seed_plan = [(name, rec_engine.like) for name in likes]
        seed_plan += [(name, rec_engine.dislike) for name in dislikes]
        for seed_name, store in seed_plan:
            store(synthesize_food_item(seed_name, llm))

    candidates = [synthesize_food_item(dish, llm) for dish in menu]
    ranking = rec_engine.recommend_from_given(items=candidates)
    print(ranking)