# Similarity_Search/src/api/services/embedding_service.py
# Author: amaye15
# Commit 0611c31: Feat - Use huggingface dataset instead of pandas
# from openai import AsyncOpenAI
# import logging
# from typing import List, Dict, Union
# import pandas as pd
# import asyncio
# from src.api.exceptions import OpenAIError
# # Set up structured logging
# logging.basicConfig(
# level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
# )
# logger = logging.getLogger(__name__)
# class EmbeddingService:
# def __init__(
# self,
# openai_api_key: str,
# model: str = "text-embedding-3-small",
# batch_size: int = 10,
# max_concurrent_requests: int = 10, # Limit to 10 concurrent requests
# ):
# self.client = AsyncOpenAI(api_key=openai_api_key)
# self.model = model
# self.batch_size = batch_size
# self.semaphore = asyncio.Semaphore(max_concurrent_requests) # Rate limiter
# self.total_requests = 0 # Total number of requests to process
# self.completed_requests = 0 # Number of completed requests
# async def get_embedding(self, text: str) -> List[float]:
# """Generate embeddings for the given text using OpenAI."""
# text = text.replace("\n", " ")
# try:
# async with self.semaphore: # Acquire a semaphore slot
# response = await self.client.embeddings.create(
# input=[text], model=self.model
# )
# self.completed_requests += 1 # Increment completed requests
# self._log_progress() # Log progress
# return response.data[0].embedding
# except Exception as e:
# logger.error(f"Failed to generate embedding: {e}")
# raise OpenAIError(f"OpenAI API error: {e}")
# async def create_embeddings(
# self,
# data: Union[pd.DataFrame, List[str]],
# target_column: str = None,
# output_column: str = "embeddings",
# ) -> Union[pd.DataFrame, List[List[float]]]:
# """
# Create embeddings for either a DataFrame or a list of strings.
# Args:
# data: Either a DataFrame or a list of strings.
# target_column: The column in the DataFrame to generate embeddings for (required if data is a DataFrame).
# output_column: The column to store embeddings in the DataFrame (default: "embeddings").
# Returns:
# If data is a DataFrame, returns the DataFrame with the embeddings column.
# If data is a list of strings, returns a list of embeddings.
# """
# if isinstance(data, pd.DataFrame):
# if not target_column:
# raise ValueError("target_column is required when data is a DataFrame.")
# return await self._create_embeddings_for_dataframe(
# data, target_column, output_column
# )
# elif isinstance(data, list):
# return await self._create_embeddings_for_texts(data)
# else:
# raise TypeError(
# "data must be either a pandas DataFrame or a list of strings."
# )
# async def _create_embeddings_for_dataframe(
# self, df: pd.DataFrame, target_column: str, output_column: str
# ) -> pd.DataFrame:
# """Create embeddings for the target column in the DataFrame."""
# logger.info("Generating embeddings for DataFrame...")
# self.total_requests = len(df) # Set total number of requests
# self.completed_requests = 0 # Reset completed requests counter
# batches = [
# df[i : i + self.batch_size] for i in range(0, len(df), self.batch_size)
# ]
# processed_batches = await asyncio.gather(
# *[
# self._process_batch(batch, target_column, output_column)
# for batch in batches
# ]
# )
# return pd.concat(processed_batches)
# async def _create_embeddings_for_texts(self, texts: List[str]) -> List[List[float]]:
# """Create embeddings for a list of strings."""
# logger.info("Generating embeddings for list of texts...")
# self.total_requests = len(texts) # Set total number of requests
# self.completed_requests = 0 # Reset completed requests counter
# batches = [
# texts[i : i + self.batch_size]
# for i in range(0, len(texts), self.batch_size)
# ]
# embeddings = []
# for batch in batches:
# batch_embeddings = await asyncio.gather(
# *[self.get_embedding(text) for text in batch]
# )
# embeddings.extend(batch_embeddings)
# return embeddings
# async def _process_batch(
# self, df_batch: pd.DataFrame, target_column: str, output_column: str
# ) -> pd.DataFrame:
# """Process a batch of rows to generate embeddings."""
# embeddings = await asyncio.gather(
# *[self.get_embedding(row[target_column]) for _, row in df_batch.iterrows()]
# )
# df_batch[output_column] = embeddings
# return df_batch
# def _log_progress(self):
# """Log the progress of embedding generation."""
# progress = (self.completed_requests / self.total_requests) * 100
# logger.info(
# f"Progress: {self.completed_requests}/{self.total_requests} ({progress:.2f}%)"
# )
from openai import AsyncOpenAI
import logging
from typing import List, Dict, Union
from datasets import Dataset
import asyncio
from src.api.exceptions import OpenAIError
# Set up structured logging
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)
class EmbeddingService:
    """Generate OpenAI embeddings for Hugging Face Datasets or lists of strings.

    Work is processed in batches of ``batch_size``, and an asyncio semaphore
    caps the number of embedding requests in flight at
    ``max_concurrent_requests``. Progress is logged after each completed call.
    """

    def __init__(
        self,
        openai_api_key: str,
        model: str = "text-embedding-3-small",
        batch_size: int = 10,
        max_concurrent_requests: int = 10,  # Limit to 10 concurrent requests
    ):
        self.client = AsyncOpenAI(api_key=openai_api_key)
        self.model = model
        self.batch_size = batch_size
        # Rate limiter: bounds how many embedding calls run concurrently.
        self.semaphore = asyncio.Semaphore(max_concurrent_requests)
        self.total_requests = 0  # Total number of requests to process
        self.completed_requests = 0  # Number of completed requests

    async def get_embedding(self, text: str) -> List[float]:
        """Generate an embedding for a single text using OpenAI.

        Newlines are replaced with spaces before the request is sent.

        Args:
            text: The input text to embed.

        Returns:
            The embedding vector for ``text``.

        Raises:
            OpenAIError: If the underlying API call fails for any reason.
        """
        text = text.replace("\n", " ")
        try:
            async with self.semaphore:  # Acquire a semaphore slot
                response = await self.client.embeddings.create(
                    input=[text], model=self.model
                )
                self.completed_requests += 1  # Increment completed requests
                self._log_progress()  # Log progress
                return response.data[0].embedding
        except Exception as e:
            logger.error(f"Failed to generate embedding: {e}")
            # Chain the original exception so the full cause is preserved
            # in tracebacks instead of being swallowed.
            raise OpenAIError(f"OpenAI API error: {e}") from e

    async def create_embeddings(
        self,
        data: Union[Dataset, List[str]],
        target_column: Union[str, None] = None,
        output_column: str = "embeddings",
    ) -> Union[Dataset, List[List[float]]]:
        """Create embeddings for either a Dataset or a list of strings.

        Args:
            data: Either a Hugging Face Dataset or a list of strings.
            target_column: The column in the Dataset to generate embeddings
                for (required if ``data`` is a Dataset).
            output_column: The column to store embeddings in the Dataset
                (default: "embeddings").

        Returns:
            If ``data`` is a Dataset, the Dataset with the embeddings column
            added; if ``data`` is a list of strings, a list of embeddings.

        Raises:
            ValueError: If ``data`` is a Dataset and ``target_column`` is not
                given.
            TypeError: If ``data`` is neither a Dataset nor a list.
        """
        if isinstance(data, Dataset):
            if not target_column:
                raise ValueError("target_column is required when data is a Dataset.")
            return await self._create_embeddings_for_dataset(
                data, target_column, output_column
            )
        elif isinstance(data, list):
            return await self._create_embeddings_for_texts(data)
        else:
            raise TypeError(
                "data must be either a Hugging Face Dataset or a list of strings."
            )

    async def _create_embeddings_for_dataset(
        self, dataset: Dataset, target_column: str, output_column: str
    ) -> Dataset:
        """Create embeddings for the target column in the Dataset.

        Batches are submitted sequentially; within a batch the individual
        requests run concurrently (bounded by the semaphore).
        """
        logger.info("Generating embeddings for Dataset...")
        self.total_requests = len(dataset)  # Set total number of requests
        self.completed_requests = 0  # Reset completed requests counter
        embeddings: List[List[float]] = []
        for i in range(0, len(dataset), self.batch_size):
            # Slicing a Dataset yields a dict of column-name -> list of values.
            batch = dataset[i : i + self.batch_size]
            batch_embeddings = await asyncio.gather(
                *[self.get_embedding(text) for text in batch[target_column]]
            )
            embeddings.extend(batch_embeddings)
        # add_column returns a new Dataset; the input is not mutated.
        dataset = dataset.add_column(output_column, embeddings)
        return dataset

    async def _create_embeddings_for_texts(self, texts: List[str]) -> List[List[float]]:
        """Create embeddings for a list of strings, preserving input order."""
        logger.info("Generating embeddings for list of texts...")
        self.total_requests = len(texts)  # Set total number of requests
        self.completed_requests = 0  # Reset completed requests counter
        batches = [
            texts[i : i + self.batch_size]
            for i in range(0, len(texts), self.batch_size)
        ]
        embeddings: List[List[float]] = []
        for batch in batches:
            batch_embeddings = await asyncio.gather(
                *[self.get_embedding(text) for text in batch]
            )
            embeddings.extend(batch_embeddings)
        return embeddings

    def _log_progress(self):
        """Log the progress of embedding generation.

        Guards against division by zero: ``total_requests`` is only set by the
        batch entry points, so a direct ``get_embedding`` call would otherwise
        crash here with ``ZeroDivisionError``.
        """
        if self.total_requests <= 0:
            logger.info(f"Progress: {self.completed_requests} requests completed")
            return
        progress = (self.completed_requests / self.total_requests) * 100
        logger.info(
            f"Progress: {self.completed_requests}/{self.total_requests} ({progress:.2f}%)"
        )