"""Embedding generation with a persistent on-disk cache.

Wraps a SentenceTransformer model and memoizes text embeddings in a pickle
file so repeated runs only encode texts that have not been seen before.
"""

import hashlib
import logging
import os
import pickle
from pathlib import Path
from typing import Optional

import numpy as np
import pandas as pd
from sentence_transformers import SentenceTransformer

# Silence the Hugging Face tokenizers fork-parallelism warning.
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

logger = logging.getLogger(__name__)

class EmbeddingGenerator:
    def __init__(self, model_name: str = 'all-MiniLM-L6-v2', cache_dir: Optional[str] = None):
        """Load the sentence-transformer model and prepare a persistent embedding cache."""
        try:
            self.model_name = model_name
            self.model = SentenceTransformer(model_name)
            
            # Setup cache directory
            self.cache_dir = Path(cache_dir) if cache_dir else Path('data/embedding_cache')
            self.cache_dir.mkdir(parents=True, exist_ok=True)
            
            # Cache file for embeddings
            self.cache_file = self.cache_dir / f"embeddings_cache_{model_name.replace('/', '_')}.pkl"
            
            # Load existing cache if available
            self.embedding_cache = self._load_cache()
            
            logger.info(f"Successfully loaded model: {model_name}")
        except Exception as e:
            logger.error(f"Error loading model: {str(e)}")
            raise

    @staticmethod
    def _text_key(text: str) -> str:
        """Return a stable cache key for a text.

        Python's built-in hash() is salted per process, so its values cannot
        be reused across runs; a content hash keeps the pickled cache valid.
        """
        return hashlib.sha256(text.encode('utf-8')).hexdigest()

    def _load_cache(self) -> dict:
        """Load the embedding cache from disk if it exists."""
        try:
            if self.cache_file.exists():
                with open(self.cache_file, 'rb') as f:
                    cache = pickle.load(f)
                logger.info(f"Loaded {len(cache)} cached embeddings")
                return cache
            return {}
        except Exception as e:
            logger.warning(f"Error loading cache, starting fresh: {str(e)}")
            return {}

    def _save_cache(self):
        """Persist the embedding cache to disk."""
        try:
            with open(self.cache_file, 'wb') as f:
                pickle.dump(self.embedding_cache, f)
            logger.info(f"Saved {len(self.embedding_cache)} embeddings to cache")
        except Exception as e:
            logger.error(f"Error saving cache: {str(e)}")

    def generate_embeddings(self, texts: pd.Series) -> np.ndarray:
        """Embed a Series of texts, reusing cached vectors where possible."""
        try:
            text_list = texts.tolist()

            # Placeholder slots so cached and newly computed embeddings both
            # land at the position of their source text
            all_embeddings = [None] * len(text_list)
            texts_to_embed = []
            indices_to_embed = []

            # Check cache for existing embeddings
            for i, text in enumerate(text_list):
                cached = self.embedding_cache.get(self._text_key(text))
                if cached is not None:
                    all_embeddings[i] = cached
                else:
                    texts_to_embed.append(text)
                    indices_to_embed.append(i)

            # Generate embeddings only for texts missing from the cache
            if texts_to_embed:
                logger.info(f"Generating embeddings for {len(texts_to_embed)} new texts")
                new_embeddings = self.model.encode(
                    texts_to_embed,
                    show_progress_bar=True,
                    convert_to_numpy=True
                )

                # Cache each new embedding and place it at its original position
                for idx, text, embedding in zip(indices_to_embed, texts_to_embed, new_embeddings):
                    self.embedding_cache[self._text_key(text)] = embedding
                    all_embeddings[idx] = embedding

                # Save updated cache
                self._save_cache()
            else:
                logger.info("All embeddings found in cache")

            return np.array(all_embeddings)

        except Exception as e:
            logger.error(f"Error generating embeddings: {str(e)}")
            raise

    def add_embeddings_to_df(self, df: pd.DataFrame, text_column: str = 'description') -> pd.DataFrame:
        """Attach an 'embeddings' column (one vector per row) to the DataFrame in place."""
        try:
            embeddings = self.generate_embeddings(df[text_column])
            df['embeddings'] = list(embeddings)
            return df
        except Exception as e:
            logger.error(f"Error adding embeddings to DataFrame: {str(e)}")
            raise

    def clear_cache(self):
        """Clear the embedding cache"""
        try:
            self.embedding_cache = {}
            if self.cache_file.exists():
                self.cache_file.unlink()
            logger.info("Embedding cache cleared")
        except Exception as e:
            logger.error(f"Error clearing cache: {str(e)}")
            raise
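

if __name__ == '__main__':
    # Minimal usage sketch, not part of the original module: the DataFrame,
    # column values, and cache path below are illustrative assumptions.
    logging.basicConfig(level=logging.INFO)

    generator = EmbeddingGenerator(cache_dir='data/embedding_cache')
    df = pd.DataFrame({'description': [
        'A lightweight laptop with long battery life.',
        'Noise-cancelling over-ear headphones.',
    ]})

    df = generator.add_embeddings_to_df(df, text_column='description')
    print(df['embeddings'].iloc[0].shape)  # (384,) for all-MiniLM-L6-v2

    # A second call should log "All embeddings found in cache".
    generator.generate_embeddings(df['description'])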