# Similarity_Search/src/api/services/embedding_service.py
# Author: amaye15 — "Initial Deployment" (commit 2cb9dec)
from openai import AsyncOpenAI
import logging
from typing import List, Dict
import pandas as pd
import asyncio
from src.api.exceptions import OpenAIError
# Set up structured logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)


class EmbeddingService:
    """Generate OpenAI text embeddings, including batched DataFrame processing.

    Wraps an ``AsyncOpenAI`` client and exposes a single-text helper
    (:meth:`get_embedding`) plus a DataFrame pipeline
    (:meth:`create_embeddings`) that embeds one column in concurrent batches.
    """

    def __init__(
        self,
        openai_api_key: str,
        model: str = "text-embedding-3-small",
        batch_size: int = 100,
    ):
        """Initialize the service.

        Args:
            openai_api_key: API key used to construct the async OpenAI client.
            model: Embedding model name passed to the embeddings endpoint.
            batch_size: Number of DataFrame rows embedded per batch.
        """
        self.client = AsyncOpenAI(api_key=openai_api_key)
        self.model = model
        self.batch_size = batch_size

    async def get_embedding(self, text: str) -> List[float]:
        """Generate an embedding vector for ``text`` using OpenAI.

        Raises:
            OpenAIError: If the underlying API call fails for any reason.
        """
        # Newlines are flattened to spaces before embedding (OpenAI guidance
        # for the embedding endpoints).
        text = text.replace("\n", " ")
        try:
            response = await self.client.embeddings.create(
                input=[text], model=self.model
            )
            return response.data[0].embedding
        except Exception as e:
            # logger.exception records the traceback; "from e" preserves the
            # original cause for callers inspecting the OpenAIError.
            logger.exception("Failed to generate embedding")
            raise OpenAIError(f"OpenAI API error: {e}") from e

    async def create_embeddings(
        self, df: pd.DataFrame, target_column: str, output_column: str
    ) -> pd.DataFrame:
        """Embed ``target_column`` of ``df``, returning a new DataFrame.

        The result carries all original columns plus ``output_column`` holding
        the embedding vectors. The input DataFrame is not mutated.

        Args:
            df: Source DataFrame.
            target_column: Column whose text values are embedded.
            output_column: Column name for the resulting embedding vectors.

        Returns:
            A new DataFrame with the embeddings appended as a column.
        """
        logger.info("Generating embeddings...")
        if df.empty:
            # pd.concat([]) would raise ValueError; return a copy with the
            # output column present so the result schema is consistent.
            result = df.copy()
            result[output_column] = pd.Series(dtype=object)
            return result
        batches = [
            df.iloc[i : i + self.batch_size]
            for i in range(0, len(df), self.batch_size)
        ]
        processed_batches = await asyncio.gather(
            *[
                self._process_batch(batch, target_column, output_column)
                for batch in batches
            ]
        )
        return pd.concat(processed_batches)

    async def _process_batch(
        self, df_batch: pd.DataFrame, target_column: str, output_column: str
    ) -> pd.DataFrame:
        """Embed one batch of rows concurrently; returns a new DataFrame."""
        embeddings = await asyncio.gather(
            *[self.get_embedding(text) for text in df_batch[target_column]]
        )
        # Work on a copy: assigning into a slice of the caller's DataFrame
        # triggers SettingWithCopyWarning and can mutate the input.
        df_batch = df_batch.copy()
        df_batch[output_column] = embeddings
        return df_batch