from openai import AsyncOpenAI
import logging
from typing import List, Dict, Optional, Union
import pandas as pd
import asyncio
from src.api.exceptions import OpenAIError

# Set up structured logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)


class EmbeddingService:
    """Generate OpenAI embeddings for lists of strings or DataFrame columns.

    Requests are issued concurrently but capped by a semaphore, and simple
    progress counters are logged as requests complete.

    NOTE(review): the progress counters are instance state, so interleaving
    two concurrent ``create_embeddings`` calls on the same instance will mix
    their progress figures — confirm callers use one call at a time.
    """

    def __init__(
        self,
        openai_api_key: str,
        model: str = "text-embedding-3-small",
        batch_size: int = 10,
        max_concurrent_requests: int = 10,  # Limit to 10 concurrent requests
    ):
        """Initialize the service.

        Args:
            openai_api_key: API key used to construct the async OpenAI client.
            model: Embedding model name.
            batch_size: Number of rows/texts submitted per gathered batch.
            max_concurrent_requests: Cap on simultaneous in-flight requests.
        """
        self.client = AsyncOpenAI(api_key=openai_api_key)
        self.model = model
        self.batch_size = batch_size
        self.semaphore = asyncio.Semaphore(max_concurrent_requests)  # Rate limiter
        self.total_requests = 0  # Total number of requests to process
        self.completed_requests = 0  # Number of completed requests

    async def get_embedding(self, text: str) -> List[float]:
        """Generate an embedding for the given text using OpenAI.

        Newlines are collapsed to spaces before submission (the embeddings
        endpoint handles single-line text best).

        Raises:
            OpenAIError: wrapping any failure from the underlying API call.
        """
        text = text.replace("\n", " ")
        try:
            async with self.semaphore:  # Acquire a semaphore slot
                response = await self.client.embeddings.create(
                    input=[text], model=self.model
                )
                self.completed_requests += 1  # Increment completed requests
                self._log_progress()  # Log progress
                return response.data[0].embedding
        except Exception as e:
            logger.error(f"Failed to generate embedding: {e}")
            # Chain the original exception so the root cause isn't lost.
            raise OpenAIError(f"OpenAI API error: {e}") from e

    async def create_embeddings(
        self,
        data: Union[pd.DataFrame, List[str]],
        target_column: Optional[str] = None,
        output_column: str = "embeddings",
    ) -> Union[pd.DataFrame, List[List[float]]]:
        """
        Create embeddings for either a DataFrame or a list of strings.

        Args:
            data: Either a DataFrame or a list of strings.
            target_column: The column in the DataFrame to generate embeddings
                for (required if data is a DataFrame).
            output_column: The column to store embeddings in the DataFrame
                (default: "embeddings").

        Returns:
            If data is a DataFrame, returns the DataFrame with the embeddings
            column. If data is a list of strings, returns a list of embeddings.

        Raises:
            ValueError: if data is a DataFrame and target_column is missing.
            TypeError: if data is neither a DataFrame nor a list.
        """
        if isinstance(data, pd.DataFrame):
            if not target_column:
                raise ValueError("target_column is required when data is a DataFrame.")
            return await self._create_embeddings_for_dataframe(
                data, target_column, output_column
            )
        elif isinstance(data, list):
            return await self._create_embeddings_for_texts(data)
        else:
            raise TypeError(
                "data must be either a pandas DataFrame or a list of strings."
            )

    async def _create_embeddings_for_dataframe(
        self, df: pd.DataFrame, target_column: str, output_column: str
    ) -> pd.DataFrame:
        """Create embeddings for the target column in the DataFrame."""
        logger.info("Generating embeddings for DataFrame...")
        self.total_requests = len(df)  # Set total number of requests
        self.completed_requests = 0  # Reset completed requests counter
        # Guard: pd.concat([]) raises ValueError, so handle an empty frame here.
        if df.empty:
            result = df.copy()
            result[output_column] = []
            return result
        batches = [
            df[i : i + self.batch_size] for i in range(0, len(df), self.batch_size)
        ]
        processed_batches = await asyncio.gather(
            *[
                self._process_batch(batch, target_column, output_column)
                for batch in batches
            ]
        )
        return pd.concat(processed_batches)

    async def _create_embeddings_for_texts(self, texts: List[str]) -> List[List[float]]:
        """Create embeddings for a list of strings."""
        logger.info("Generating embeddings for list of texts...")
        self.total_requests = len(texts)  # Set total number of requests
        self.completed_requests = 0  # Reset completed requests counter
        batches = [
            texts[i : i + self.batch_size]
            for i in range(0, len(texts), self.batch_size)
        ]
        embeddings = []
        for batch in batches:
            batch_embeddings = await asyncio.gather(
                *[self.get_embedding(text) for text in batch]
            )
            embeddings.extend(batch_embeddings)
        return embeddings

    async def _process_batch(
        self, df_batch: pd.DataFrame, target_column: str, output_column: str
    ) -> pd.DataFrame:
        """Process a batch of rows to generate embeddings."""
        embeddings = await asyncio.gather(
            *[self.get_embedding(row[target_column]) for _, row in df_batch.iterrows()]
        )
        # Copy first: df_batch is a slice (view) of the caller's DataFrame, and
        # assigning into it would mutate the input and raise SettingWithCopyWarning.
        df_batch = df_batch.copy()
        df_batch[output_column] = embeddings
        return df_batch

    def _log_progress(self):
        """Log the progress of embedding generation."""
        # Guard: total_requests is 0 when get_embedding is called directly
        # (outside create_embeddings) — avoid ZeroDivisionError.
        if self.total_requests <= 0:
            logger.info(f"Progress: {self.completed_requests} requests completed")
            return
        progress = (self.completed_requests / self.total_requests) * 100
        logger.info(
            f"Progress: {self.completed_requests}/{self.total_requests} ({progress:.2f}%)"
        )