from openai import AsyncOpenAI
import logging
from typing import List, Optional, Union
from datasets import Dataset
import asyncio
from src.api.exceptions import OpenAIError

# Set up structured logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)


class EmbeddingService:
    """Generate OpenAI embeddings for Hugging Face Datasets or lists of strings.

    Work is split into batches of `batch_size`; within a batch, requests run
    concurrently but are throttled by an asyncio.Semaphore so that at most
    `max_concurrent_requests` API calls are in flight at any time.
    """

    def __init__(
        self,
        openai_api_key: str,
        model: str = "text-embedding-3-small",
        batch_size: int = 10,
        max_concurrent_requests: int = 10,  # Limit to 10 concurrent requests
    ):
        self.client = AsyncOpenAI(api_key=openai_api_key)
        self.model = model
        self.batch_size = batch_size
        self.semaphore = asyncio.Semaphore(max_concurrent_requests)  # Rate limiter
        self.total_requests = 0  # Total number of requests to process
        self.completed_requests = 0  # Number of completed requests

    async def get_embedding(self, text: str) -> List[float]:
        """Generate an embedding for the given text using OpenAI.

        Args:
            text: The input string. Newlines are flattened to spaces before
                the API call (they can degrade embedding quality).

        Returns:
            The embedding vector for `text`.

        Raises:
            OpenAIError: If the OpenAI API call fails for any reason.
        """
        text = text.replace("\n", " ")
        try:
            async with self.semaphore:  # Acquire a semaphore slot
                response = await self.client.embeddings.create(
                    input=[text], model=self.model
                )
                self.completed_requests += 1  # Increment completed requests
                self._log_progress()  # Log progress
                return response.data[0].embedding
        except Exception as e:
            logger.error(f"Failed to generate embedding: {e}")
            # Chain the original exception so the root cause stays visible.
            raise OpenAIError(f"OpenAI API error: {e}") from e

    async def create_embeddings(
        self,
        data: Union[Dataset, List[str]],
        target_column: Optional[str] = None,
        output_column: str = "embeddings",
    ) -> Union[Dataset, List[List[float]]]:
        """Create embeddings for either a Dataset or a list of strings.

        Args:
            data: Either a Hugging Face Dataset or a list of strings.
            target_column: The column in the Dataset to generate embeddings
                for (required if `data` is a Dataset).
            output_column: The column to store embeddings in the Dataset
                (default: "embeddings").

        Returns:
            If `data` is a Dataset, the Dataset with the embeddings column
            added; if `data` is a list of strings, a list of embeddings.

        Raises:
            ValueError: If `data` is a Dataset and `target_column` is missing.
            TypeError: If `data` is neither a Dataset nor a list.
        """
        if isinstance(data, Dataset):
            if not target_column:
                raise ValueError("target_column is required when data is a Dataset.")
            return await self._create_embeddings_for_dataset(
                data, target_column, output_column
            )
        elif isinstance(data, list):
            return await self._create_embeddings_for_texts(data)
        else:
            raise TypeError(
                "data must be either a Hugging Face Dataset or a list of strings."
            )

    async def _create_embeddings_for_dataset(
        self, dataset: Dataset, target_column: str, output_column: str
    ) -> Dataset:
        """Create embeddings for the target column in the Dataset."""
        logger.info("Generating embeddings for Dataset...")
        self.total_requests = len(dataset)  # Set total number of requests
        self.completed_requests = 0  # Reset completed requests counter
        embeddings: List[List[float]] = []
        for i in range(0, len(dataset), self.batch_size):
            # Slicing a Dataset yields a dict of column-name -> list of values.
            batch = dataset[i : i + self.batch_size]
            batch_embeddings = await asyncio.gather(
                *[self.get_embedding(text) for text in batch[target_column]]
            )
            embeddings.extend(batch_embeddings)
        dataset = dataset.add_column(output_column, embeddings)
        return dataset

    async def _create_embeddings_for_texts(
        self, texts: List[str]
    ) -> List[List[float]]:
        """Create embeddings for a list of strings."""
        logger.info("Generating embeddings for list of texts...")
        self.total_requests = len(texts)  # Set total number of requests
        self.completed_requests = 0  # Reset completed requests counter
        batches = [
            texts[i : i + self.batch_size]
            for i in range(0, len(texts), self.batch_size)
        ]
        embeddings: List[List[float]] = []
        for batch in batches:
            batch_embeddings = await asyncio.gather(
                *[self.get_embedding(text) for text in batch]
            )
            embeddings.extend(batch_embeddings)
        return embeddings

    def _log_progress(self) -> None:
        """Log the progress of embedding generation."""
        if self.total_requests == 0:
            # Nothing scheduled yet — avoid ZeroDivisionError.
            return
        progress = (self.completed_requests / self.total_requests) * 100
        logger.info(
            f"Progress: {self.completed_requests}/{self.total_requests} ({progress:.2f}%)"
        )