import logging
from functools import lru_cache
from typing import Any, Dict, List, Optional

import numpy as np
from sentence_transformers import SentenceTransformer

from app.config import (
    EMBEDDING_MODEL,
    MAX_TOKEN_LIMIT,
    ENABLE_CACHE,
    CACHE_SIZE,
)
from app.services.preprocessor import TextPreprocessor

logger = logging.getLogger(__name__)


class EmbeddingService:
    """Service for generating embeddings for text using sentence-transformers."""

    def __init__(self, model_name: str = EMBEDDING_MODEL, preprocessor=None):
        """
        Initialize the embedding service.

        Args:
            model_name: Name of the sentence-transformers model to use
            preprocessor: Optional TextPreprocessor instance
        """
        logger.info(f"Loading embedding model: {model_name}")
        self.model = SentenceTransformer(model_name)
        self.model_dim = self.model.get_sentence_embedding_dimension()
        logger.info(f"Model loaded. Embedding dimension: {self.model_dim}")

        # Use the provided preprocessor or create a default one
        self.preprocessor = preprocessor or TextPreprocessor()

        # Set up per-instance caching if enabled
        if ENABLE_CACHE:
            self.get_embedding = lru_cache(maxsize=CACHE_SIZE)(self._get_embedding)
        else:
            self.get_embedding = self._get_embedding

    def _get_embedding(self, text: str) -> List[float]:
        """
        Generate an embedding for a text string.

        Args:
            text: Text to generate an embedding for

        Returns:
            List of floats representing the embedding vector
        """
        if not text or not isinstance(text, str):
            logger.warning("Empty or invalid text provided for embedding generation")
            return [0.0] * self.model_dim

        # Use the preprocessor for token counting only
        token_count = self.preprocessor.count_tokens(text)

        # Check against the token limit
        if token_count > MAX_TOKEN_LIMIT:
            logger.error(
                f"Text exceeds max token limit ({token_count} > {MAX_TOKEN_LIMIT}). "
                f"Please chunk your text before encoding."
            )
            raise ValueError(
                f"Text exceeds max token limit ({token_count} > {MAX_TOKEN_LIMIT})"
            )

        try:
            # Directly encode the text string
            return self.model.encode(text).tolist()
        except Exception as e:
            logger.error(f"Error generating embedding: {e}")
            return [0.0] * self.model_dim

    def get_embeddings_batch(self, texts: List[str]) -> List[List[float]]:
        """
        Generate embeddings for a batch of texts.

        Args:
            texts: List of texts to generate embeddings for

        Returns:
            List of embedding vectors
        """
        if not texts:
            return []

        # Validate texts, replacing empty/invalid entries with "" so that
        # model.encode always receives a well-formed list of strings
        safe_texts: List[str] = []
        for i, text in enumerate(texts):
            if not text or not isinstance(text, str):
                logger.warning(f"Empty or invalid text at index {i}")
                safe_texts.append("")
                continue

            # Check the token count against the limit
            token_count = self.preprocessor.count_tokens(text)
            if token_count > MAX_TOKEN_LIMIT:
                logger.error(
                    f"Text at index {i} exceeds max token limit "
                    f"({token_count} > {MAX_TOKEN_LIMIT}). "
                    f"Please chunk your text before encoding."
                )
                raise ValueError(
                    f"Text at index {i} exceeds max token limit "
                    f"({token_count} > {MAX_TOKEN_LIMIT})"
                )
            safe_texts.append(text)

        try:
            # Let the model handle the batch encoding directly
            return self.model.encode(safe_texts).tolist()
        except Exception as e:
            logger.error(f"Error generating batch embeddings: {e}")
            # Build independent zero vectors; a multiplied list would alias
            # the same inner list across all entries
            return [[0.0] * self.model_dim for _ in texts]
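    # Illustrative note (an assumption, not part of the original API): both
    # get_embedding and get_embeddings_batch raise ValueError for over-limit
    # inputs, so callers are expected to split long text first. With a
    # hypothetical chunk_text() helper on the preprocessor, that might look like:
    #
    #     pieces = service.preprocessor.chunk_text(long_text, max_tokens=MAX_TOKEN_LIMIT)
    #     vectors = service.get_embeddings_batch(pieces)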
    def embed_chunks(self, chunks: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Generate embeddings for a list of text chunks.

        Args:
            chunks: List of chunk dictionaries with text and metadata

        Returns:
            List of chunk dictionaries with added embeddings
        """
        if not chunks:
            return []

        # Extract texts from the chunks
        texts = [chunk["text"] for chunk in chunks]

        # Generate embeddings
        embeddings = self.get_embeddings_batch(texts)

        # Attach each embedding to a copy of its chunk
        result_chunks = []
        for chunk, embedding in zip(chunks, embeddings):
            chunk_with_embedding = chunk.copy()
            chunk_with_embedding["embedding"] = embedding
            result_chunks.append(chunk_with_embedding)

        return result_chunks

    def similarity_search(
        self,
        query: str,
        embeddings: List[List[float]],
        texts: List[str],
        metadata: Optional[List[Dict[str, Any]]] = None,
        top_k: int = 5,
    ) -> List[Dict[str, Any]]:
        """
        Find the most similar texts to a query.

        Args:
            query: Query text
            embeddings: List of embedding vectors to search
            texts: List of texts corresponding to the embeddings
            metadata: Optional list of metadata for each text
            top_k: Number of top matches to return

        Returns:
            List of matches with text, score, and metadata
        """
        if not query or not embeddings or not texts:
            return []

        if metadata is None:
            metadata = [{} for _ in range(len(texts))]

        # Generate the query embedding
        query_embedding = self.get_embedding(query)

        # Convert to numpy arrays for efficient computation
        query_embedding_np = np.array(query_embedding)
        embeddings_np = np.array(embeddings)

        # Compute cosine similarity; the small epsilon guards against
        # division by zero when a stored vector is the all-zero fallback
        norms = np.linalg.norm(embeddings_np, axis=1) * np.linalg.norm(query_embedding_np)
        similarity_scores = np.dot(embeddings_np, query_embedding_np) / (norms + 1e-10)

        # Get the top-k indices, clamping top_k to the number of candidates
        top_k = min(top_k, len(texts))
        top_indices = np.argsort(similarity_scores)[-top_k:][::-1]

        # Prepare the results
        results = []
        for idx in top_indices:
            results.append({
                "text": texts[idx],
                "score": float(similarity_scores[idx]),
                "metadata": metadata[idx],
            })

        return results
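
# A minimal usage sketch, assuming the app.config values above resolve to a
# valid sentence-transformers model name and sane limits. The sample corpus
# and query are illustrative only; running this downloads the model on
# first use.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    service = EmbeddingService()

    # Embed a small corpus, then rank it against a query
    corpus = [
        "Sentence-transformers map text to dense vectors.",
        "Cosine similarity compares two embedding vectors.",
        "The cat sat on the mat.",
    ]
    corpus_embeddings = service.get_embeddings_batch(corpus)

    matches = service.similarity_search(
        query="How do I compare embeddings?",
        embeddings=corpus_embeddings,
        texts=corpus,
        top_k=2,
    )
    for match in matches:
        print(f"{match['score']:.3f}  {match['text']}")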