# Embedding generation with a persistent on-disk cache (SentenceTransformers).
import hashlib
import logging
import os
import pickle
from pathlib import Path
from typing import List

import numpy as np
import pandas as pd

from sentence_transformers import SentenceTransformer
os.environ['TOKENIZERS_PARALLELISM'] = 'false' # To avoid warnings | |
logger = logging.getLogger(__name__) | |
class EmbeddingGenerator:
    """Generate sentence embeddings with an in-memory + on-disk cache.

    Embeddings are memoized per text so repeated runs over the same corpus
    skip the expensive model forward pass. The cache is pickled to disk,
    keyed by a *stable* content digest of each text (see ``_text_key``),
    with one cache file per model name since embeddings from different
    models are not interchangeable.
    """

    def __init__(self, model_name: str = 'all-MiniLM-L6-v2', cache_dir: str = None):
        """Load the SentenceTransformer model and prime the embedding cache.

        Args:
            model_name: Any SentenceTransformers model identifier.
            cache_dir: Directory for the pickle cache; defaults to
                ``data/embedding_cache`` (created if missing).

        Raises:
            Exception: re-raised after logging if the model fails to load.
        """
        try:
            self.model_name = model_name
            self.model = SentenceTransformer(model_name)
            # Setup cache directory
            self.cache_dir = Path(cache_dir) if cache_dir else Path('data/embedding_cache')
            self.cache_dir.mkdir(parents=True, exist_ok=True)
            # One cache file per model; '/' in hub names is not path-safe.
            self.cache_file = self.cache_dir / f"embeddings_cache_{model_name.replace('/', '_')}.pkl"
            # Load existing cache if available
            self.embedding_cache = self._load_cache()
            logger.info(f"Successfully loaded model: {model_name}")
        except Exception as e:
            logger.error(f"Error loading model: {str(e)}")
            raise

    @staticmethod
    def _text_key(text: str) -> str:
        """Return a stable cache key for *text*.

        The builtin ``hash()`` is randomized per process (PYTHONHASHSEED),
        so keys computed in one run never match a cache pickled by a
        previous run -- which silently defeated the on-disk cache. A
        content digest is stable across processes and machines. (Old
        caches keyed by ``hash()`` simply miss and are rebuilt.)
        """
        return hashlib.sha256(text.encode('utf-8')).hexdigest()

    def _load_cache(self) -> dict:
        """Load embedding cache from file if it exists.

        NOTE(review): ``pickle.load`` executes arbitrary code from the
        cache file -- only use cache files this process wrote itself.
        Returns an empty dict on any failure (best-effort cache).
        """
        try:
            if self.cache_file.exists():
                with open(self.cache_file, 'rb') as f:
                    cache = pickle.load(f)
                logger.info(f"Loaded {len(cache)} cached embeddings")
                return cache
            return {}
        except Exception as e:
            # A corrupt/unreadable cache is not fatal; start fresh.
            logger.warning(f"Error loading cache, starting fresh: {str(e)}")
            return {}

    def _save_cache(self):
        """Persist the embedding cache to disk (best-effort, never raises)."""
        try:
            with open(self.cache_file, 'wb') as f:
                pickle.dump(self.embedding_cache, f)
            logger.info(f"Saved {len(self.embedding_cache)} embeddings to cache")
        except Exception as e:
            logger.error(f"Error saving cache: {str(e)}")

    def generate_embeddings(self, texts: pd.Series) -> np.ndarray:
        """Return one embedding per text, reusing cached vectors when possible.

        Args:
            texts: a pandas Series -- or, more generally, any iterable of
                strings (generalized; Series callers are unaffected).

        Returns:
            2-D float array of shape ``(len(texts), embedding_dim)``;
            an empty array for empty input.

        Raises:
            Exception: re-raised after logging on encode/cache failure.
        """
        try:
            # Accept both pandas Series and plain iterables of strings.
            text_list = texts.tolist() if hasattr(texts, 'tolist') else list(texts)
            # Pre-size the result so cached and freshly computed vectors can
            # be written straight to their original positions. (The previous
            # insert-into-a-compacted-list approach only worked by accident
            # of insertion order.)
            all_embeddings = [None] * len(text_list)
            texts_to_embed = []
            indices_to_embed = []
            # Check cache for existing embeddings
            for i, text in enumerate(text_list):
                cached = self.embedding_cache.get(self._text_key(text))
                # `is not None` rather than truthiness: an all-zero vector
                # would be falsy but is still a valid cached embedding.
                if cached is not None:
                    all_embeddings[i] = cached
                else:
                    texts_to_embed.append(text)
                    indices_to_embed.append(i)
            # Generate embeddings only for new texts
            if texts_to_embed:
                logger.info(f"Generating embeddings for {len(texts_to_embed)} new texts")
                new_embeddings = self.model.encode(
                    texts_to_embed,
                    show_progress_bar=True,
                    convert_to_numpy=True
                )
                # Cache each new embedding and place it at its original index.
                for idx, text, embedding in zip(indices_to_embed, texts_to_embed, new_embeddings):
                    self.embedding_cache[self._text_key(text)] = embedding
                    all_embeddings[idx] = embedding
                # Save updated cache
                self._save_cache()
            else:
                logger.info("All embeddings found in cache")
            return np.array(all_embeddings)
        except Exception as e:
            logger.error(f"Error generating embeddings: {str(e)}")
            raise

    def add_embeddings_to_df(self, df: pd.DataFrame, text_column: str = 'description') -> pd.DataFrame:
        """Attach an ``embeddings`` column computed from *text_column*.

        NOTE: mutates *df* in place (and also returns it), matching the
        original behavior callers may rely on.

        Raises:
            Exception: re-raised after logging (e.g. missing column).
        """
        try:
            embeddings = self.generate_embeddings(df[text_column])
            # Store one vector per row; list() splits the 2-D array row-wise.
            df['embeddings'] = list(embeddings)
            return df
        except Exception as e:
            logger.error(f"Error adding embeddings to DataFrame: {str(e)}")
            raise

    def clear_cache(self):
        """Drop all cached embeddings, in memory and on disk.

        Raises:
            Exception: re-raised after logging if the file cannot be removed.
        """
        try:
            self.embedding_cache = {}
            if self.cache_file.exists():
                self.cache_file.unlink()
            logger.info("Embedding cache cleared")
        except Exception as e:
            logger.error(f"Error clearing cache: {str(e)}")
            raise