"""Text and image search utilities backed by Qdrant.

Importing this module has side effects: it loads environment variables,
instantiates the dense / sparse / image encoders, creates the Qdrant and
Cohere clients and loads the Pokemon image dataset.
"""

import logging
import os
import uuid

import cohere
import numpy as np
import torch
from datasets import load_dataset
from dotenv import load_dotenv
from fastembed import SparseTextEmbedding
from PIL import Image, ImageFile
from qdrant_client import QdrantClient, models
from sentence_transformers import SentenceTransformer
from transformers import AutoModel, AutoImageProcessor

logger = logging.getLogger(__name__)

load_dotenv()

# Module-level singletons shared by the helpers and classes below.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
encoder = SentenceTransformer("sentence-transformers/LaBSE").to(device)
processor = AutoImageProcessor.from_pretrained('facebook/dinov2-large')
image_encoder = AutoModel.from_pretrained("facebook/dinov2-large").to(device)
qdrant_client = QdrantClient(
    url=os.getenv("qdrant_url"),
    api_key=os.getenv("qdrant_api_key"),
)
sparse_encoder = SparseTextEmbedding(model_name="prithivida/Splade_PP_en_v1")
co = cohere.ClientV2(os.getenv("cohere_api_key"))

dataset = load_dataset("Karbo31881/Pokemon_images")
ds = dataset["train"]
labels = ds["text"]


def _embed_sparse(text: str, model: SparseTextEmbedding) -> models.SparseVector:
    """Embed *text* with *model* and return the first embedding as a SparseVector."""
    first = next(iter(model.embed(text)))
    # Convert to plain Python lists so the vector serialises regardless of
    # whether the embedding exposes numpy arrays or lists.
    return models.SparseVector(
        indices=np.asarray(first.indices).tolist(),
        values=np.asarray(first.values).tolist(),
    )


def get_sparse_embedding(text: str, model: SparseTextEmbedding):
    """Return a ``{"sparse-text": SparseVector}`` mapping for storing *text*.

    Args:
        text: Text to embed.
        model: Sparse embedding model used to encode the text.

    Returns:
        A dict with the single key ``"sparse-text"``, suitable as the
        ``vector`` argument of a ``PointStruct``.
    """
    return {"sparse-text": _embed_sparse(text, model)}


def get_query_sparse_embedding(text: str, model: SparseTextEmbedding):
    """Return a ``NamedSparseVector`` named ``"sparse-text"`` for querying.

    Args:
        text: Query text to embed.
        model: Sparse embedding model used to encode the query.

    Returns:
        A ``models.NamedSparseVector`` wrapping the sparse embedding of *text*.
    """
    return models.NamedSparseVector(
        name="sparse-text",
        vector=_embed_sparse(text, model),
    )


def upload_text_to_qdrant(
    client: QdrantClient,
    collection_name: str,
    encoder: SentenceTransformer,
    text: str,
    point_id_dense: int,
    point_id_sparse: int,
):
    """Store *text* as two points: one dense-vector point and one sparse-vector point.

    The dense vector comes from *encoder*; the sparse vector comes from the
    module-level ``sparse_encoder``. Both points carry the payload
    ``{"text": text}`` and are written in a single upsert request so a failure
    cannot leave only one of the two points behind.

    Args:
        client: Qdrant client used for the upsert.
        collection_name: Target collection.
        encoder: Dense text encoder.
        text: Text to index.
        point_id_dense: Point id for the ``"dense-text"`` vector.
        point_id_sparse: Point id for the ``"sparse-text"`` vector.

    Returns:
        ``True`` on success, ``False`` if anything raised (the error is logged).
    """
    try:
        payload = {"text": text}
        dense_point = models.PointStruct(
            id=point_id_dense,
            vector={"dense-text": encoder.encode(text).tolist()},
            payload=payload,
        )
        sparse_point = models.PointStruct(
            id=point_id_sparse,
            vector=get_sparse_embedding(text, sparse_encoder),
            payload=payload,
        )
        client.upsert(
            collection_name=collection_name,
            points=[dense_point, sparse_point],
        )
    except Exception:
        # Best-effort API: callers only get a boolean, so keep the traceback
        # in the log instead of discarding it.
        logger.exception(
            "Failed to upload text to collection %r", collection_name
        )
        return False
    return True


def upload_images_to_qdrant(
    client: QdrantClient,
    collection_name: str,
    vectorsfile: str,
    labelslist: list,
):
    """Upload precomputed image vectors with their labels.

    Vector ``i`` loaded from *vectorsfile* is paired with ``labelslist[i]``
    and stored under point id ``i`` with payload ``{"label": label}``.

    Args:
        client: Qdrant client used for the upload.
        collection_name: Target collection.
        vectorsfile: Path passed to ``numpy.load``.
        labelslist: Labels, one per vector, in the same order as the vectors.

    Returns:
        ``True`` on success, ``False`` if anything raised (the error is logged),
        e.g. when the file cannot be loaded or there are more labels than
        vectors.
    """
    try:
        vectors = np.load(vectorsfile)
        points = [
            models.PointStruct(
                id=idx,
                vector=vectors[idx].tolist(),
                payload={"label": label},
            )
            for idx, label in enumerate(labelslist)
        ]
        client.upload_points(collection_name=collection_name, points=points)
    except Exception:
        logger.exception(
            "Failed to upload images from %r to collection %r",
            vectorsfile,
            collection_name,
        )
        return False
    return True


class SemanticCache:
    """Question/answer cache keyed by the embedding of the question."""

    def __init__(
        self,
        client: QdrantClient,
        text_encoder: SentenceTransformer,
        collection_name: str,
        threshold: float = 0.75,
    ):
        """Store the collaborators and the score threshold for cache hits.

        Args:
            client: Qdrant client holding the cache collection.
            text_encoder: Encoder used to embed questions.
            collection_name: Name of the cache collection.
            threshold: A hit counts only if its score is strictly greater
                than this value.
        """
        self.client = client
        self.text_encoder = text_encoder
        self.collection_name = collection_name
        self.threshold = threshold

    def upload_to_cache(self, question: str, answer: str):
        """Insert a question/answer pair under a fresh random UUID point id."""
        entry = {"question": question, "answer": answer}
        self.client.upsert(
            collection_name=self.collection_name,
            points=[
                models.PointStruct(
                    id=str(uuid.uuid4()),
                    vector=self.text_encoder.encode(question).tolist(),
                    payload=entry,
                )
            ],
        )

    def search_cache(self, question: str, limit: int = 5):
        """Return the cached answer for a sufficiently similar question.

        Args:
            question: Question to look up.
            limit: Maximum number of candidates requested from Qdrant.

        Returns:
            The ``"answer"`` payload of the first returned hit whose score
            exceeds ``self.threshold``, or ``""`` when there is none.
        """
        query = self.text_encoder.encode(question).tolist()
        hits = self.client.search(
            collection_name=self.collection_name,
            query_vector=query,
            query_filter=None,
            limit=limit,
        )
        for hit in hits:
            if hit.score > self.threshold:
                return hit.payload["answer"]
        return ""


class NeuralSearcher:
    """Dense + sparse text search, Cohere reranking and image search."""

    def __init__(
        self,
        text_collection_name: str,
        image_collection_name: str,
        client: QdrantClient,
        text_encoder: SentenceTransformer,
        image_encoder: AutoModel,
        image_processor: AutoImageProcessor,
        sparse_encoder: SparseTextEmbedding,
    ):
        """Store collection names, the Qdrant client and the encoders."""
        self.text_collection_name = text_collection_name
        self.image_collection_name = image_collection_name
        self.text_encoder = text_encoder
        self.image_encoder = image_encoder
        self.image_processor = image_processor
        self.qdrant_client = client
        self.sparse_encoder = sparse_encoder

    def search_text(self, text: str, limit: int = 5):
        """Search the text collection with both dense and sparse vectors.

        Args:
            text: Query text.
            limit: Maximum number of hits requested from *each* search.

        Returns:
            The ``"text"`` payloads of the dense hits followed by the sparse
            hits, with duplicates removed (first occurrence wins). Duplicates
            are expected because the same text can be indexed once under a
            dense vector and once under a sparse vector.
        """
        dense_query = models.NamedVector(
            name="dense-text",
            vector=self.text_encoder.encode(text).tolist(),
        )
        dense_hits = self.qdrant_client.search(
            collection_name=self.text_collection_name,
            query_vector=dense_query,
            query_filter=None,
            limit=limit,
        )
        sparse_hits = self.qdrant_client.search(
            collection_name=self.text_collection_name,
            query_vector=get_query_sparse_embedding(text, self.sparse_encoder),
            query_filter=None,
            limit=limit,
        )
        texts = [hit.payload["text"] for hit in dense_hits]
        texts += [hit.payload["text"] for hit in sparse_hits]
        # dict.fromkeys de-duplicates while preserving insertion order.
        return list(dict.fromkeys(texts))

    def reranking(self, text: str, search_result: list, top_n: int = 3):
        """Rerank *search_result* against *text* with Cohere ``rerank-v3.5``.

        Uses the module-level Cohere client ``co``.

        Args:
            text: Query the documents are ranked against.
            search_result: Candidate documents.
            top_n: Maximum number of documents to return.

        Returns:
            Up to *top_n* entries of *search_result* in the order returned by
            the reranker. Returns ``[]`` without calling the API when
            *search_result* is empty.
        """
        if not search_result:
            return []
        response = co.rerank(
            model="rerank-v3.5",
            query=text,
            documents=search_result,
            top_n=top_n,
        )
        # Iterate over what was actually returned: there may be fewer than
        # top_n results when there are few candidates.
        return [search_result[item.index] for item in response.results]

    def search_image(self, image: Image.Image, limit: int = 5):
        """Find the stored images most similar to *image*.

        The image is preprocessed with ``self.image_processor``, encoded with
        ``self.image_encoder`` on the module-level ``device``, and the
        encoder's ``last_hidden_state`` is averaged over dimension 1 to get
        the query vector.

        Args:
            image: Image to search with (passed as ``images=`` to the
                image processor).
            limit: Maximum number of hits to return.

        Returns:
            One formatted string per hit:
            ``"- <label> with score <score>"``.
        """
        inputs = self.image_processor(images=image, return_tensors="pt").to(device)
        with torch.no_grad():
            hidden = self.image_encoder(**inputs).last_hidden_state
            embedding = hidden.mean(dim=1).cpu().numpy()
        hits = self.qdrant_client.search(
            collection_name=self.image_collection_name,
            query_vector=embedding[0].tolist(),
            query_filter=None,
            limit=limit,
        )
        return [f"- {hit.payload['label']} with score {hit.score}" for hit in hits]