# Author: Rohil Bansal
# Commit: 2ed2129
import hashlib
import logging
import os
import pickle
from pathlib import Path
from typing import List

import numpy as np
import pandas as pd
from sentence_transformers import SentenceTransformer
# Disable HF tokenizers' thread parallelism before any model loads,
# to suppress the fork-related warnings it otherwise emits.
os.environ['TOKENIZERS_PARALLELISM'] = 'false' # To avoid warnings
# Module-level logger; handler/level configuration is left to the application.
logger = logging.getLogger(__name__)
class EmbeddingGenerator:
    """Generate and cache sentence embeddings for text data.

    Wraps a SentenceTransformer model and keeps a pickle-backed,
    on-disk cache keyed by a stable hash of each text, so repeated
    calls skip re-encoding texts that were embedded before.
    """

    def __init__(self, model_name: str = 'all-MiniLM-L6-v2', cache_dir: str = None):
        """Load the model and prime the embedding cache.

        Args:
            model_name: SentenceTransformer model identifier (hub names
                containing '/' are sanitized for the cache filename).
            cache_dir: Directory for the cache file; defaults to
                'data/embedding_cache'. Created if missing.

        Raises:
            Exception: Re-raises any model-loading failure after logging it.
        """
        try:
            self.model_name = model_name
            self.model = SentenceTransformer(model_name)
            # Setup cache directory
            self.cache_dir = Path(cache_dir) if cache_dir else Path('data/embedding_cache')
            self.cache_dir.mkdir(parents=True, exist_ok=True)
            # One cache file per model; '/' in hub model names is not path-safe.
            self.cache_file = self.cache_dir / f"embeddings_cache_{model_name.replace('/', '_')}.pkl"
            # Load existing cache if available
            self.embedding_cache = self._load_cache()
            logger.info(f"Successfully loaded model: {model_name}")
        except Exception as e:
            logger.error(f"Error loading model: {str(e)}")
            raise

    @staticmethod
    def _text_key(text: str) -> str:
        """Return a stable cache key for *text*.

        Uses SHA-256 instead of the built-in hash(), which is salted per
        process (PYTHONHASHSEED) and therefore never matches across runs —
        that made the persisted cache effectively useless between restarts.
        """
        return hashlib.sha256(text.encode('utf-8')).hexdigest()

    def _load_cache(self) -> dict:
        """Load the embedding cache from disk, or return an empty dict.

        Any load failure (corrupt/partial file) is logged and treated as a
        cache miss rather than an error.
        NOTE(review): pickle.load on an untrusted file can execute arbitrary
        code — the cache directory is assumed to be local and trusted.
        """
        try:
            if self.cache_file.exists():
                with open(self.cache_file, 'rb') as f:
                    cache = pickle.load(f)
                logger.info(f"Loaded {len(cache)} cached embeddings")
                return cache
            return {}
        except Exception as e:
            logger.warning(f"Error loading cache, starting fresh: {str(e)}")
            return {}

    def _save_cache(self):
        """Persist the in-memory embedding cache to disk (best-effort).

        Failures are logged, not raised, so a read-only cache directory
        does not break embedding generation.
        """
        try:
            with open(self.cache_file, 'wb') as f:
                pickle.dump(self.embedding_cache, f)
            logger.info(f"Saved {len(self.embedding_cache)} embeddings to cache")
        except Exception as e:
            logger.error(f"Error saving cache: {str(e)}")

    def generate_embeddings(self, texts: pd.Series) -> np.ndarray:
        """Return embeddings for *texts*, reusing cached vectors where possible.

        Only texts missing from the cache are sent to the model; new
        embeddings are written back to the on-disk cache afterwards.

        Args:
            texts: Series of strings to embed.

        Returns:
            Array of shape (len(texts), embedding_dim), rows in the same
            order as the input. Empty array for an empty Series.

        Raises:
            Exception: Re-raises any encoding failure after logging it.
        """
        try:
            text_list = texts.tolist()
            # Pre-size the result so cached and fresh vectors land directly at
            # their original positions (the old list.insert loop was O(n^2)).
            results = [None] * len(text_list)
            texts_to_embed = []
            indices_to_embed = []
            # Check cache for existing embeddings
            for i, text in enumerate(text_list):
                cached = self.embedding_cache.get(self._text_key(text))
                if cached is not None:
                    results[i] = cached
                else:
                    texts_to_embed.append(text)
                    indices_to_embed.append(i)
            # Generate embeddings only for new texts
            if texts_to_embed:
                logger.info(f"Generating embeddings for {len(texts_to_embed)} new texts")
                new_embeddings = self.model.encode(
                    texts_to_embed,
                    show_progress_bar=True,
                    convert_to_numpy=True
                )
                # Cache new embeddings and place them at their original indices.
                for idx, text, embedding in zip(indices_to_embed, texts_to_embed, new_embeddings):
                    self.embedding_cache[self._text_key(text)] = embedding
                    results[idx] = embedding
                # Save updated cache
                self._save_cache()
            else:
                logger.info("All embeddings found in cache")
            return np.array(results)
        except Exception as e:
            logger.error(f"Error generating embeddings: {str(e)}")
            raise

    def add_embeddings_to_df(self, df: pd.DataFrame, text_column: str = 'description') -> pd.DataFrame:
        """Attach an 'embeddings' column computed from *text_column*.

        Mutates *df* in place (existing behavior callers may rely on) and
        also returns it for chaining.
        """
        try:
            embeddings = self.generate_embeddings(df[text_column])
            df['embeddings'] = list(embeddings)
            return df
        except Exception as e:
            logger.error(f"Error adding embeddings to DataFrame: {str(e)}")
            raise

    def clear_cache(self):
        """Clear the in-memory cache and delete the cache file if present."""
        try:
            self.embedding_cache = {}
            if self.cache_file.exists():
                self.cache_file.unlink()
            logger.info("Embedding cache cleared")
        except Exception as e:
            logger.error(f"Error clearing cache: {str(e)}")
            raise