Spaces:
Running
Running
import os | |
from fastapi import FastAPI, Depends, HTTPException | |
from fastapi.responses import JSONResponse, RedirectResponse | |
from fastapi.middleware.gzip import GZipMiddleware | |
from pydantic import BaseModel | |
from typing import List, Dict | |
from src.api.models.embedding_models import ( | |
CreateEmbeddingRequest, | |
ReadEmbeddingRequest, | |
UpdateEmbeddingRequest, | |
DeleteEmbeddingRequest, | |
EmbedRequest, | |
) | |
from src.api.database import get_db, Database, QueryExecutionError, HealthCheckError | |
from src.api.services.embedding_service import EmbeddingService | |
from src.api.services.huggingface_service import HuggingFaceService | |
from src.api.exceptions import DatasetNotFoundError, DatasetPushError, OpenAIError | |
# from src.api.dependency import get_embedding_service, get_huggingface_service | |
import pandas as pd | |
import logging | |
from dotenv import load_dotenv | |
# Load environment variables (python-dotenv) before anything below reads
# them through os.getenv.
load_dotenv()

# Set up structured logging: INFO level, "time - logger - level - message".
_LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT)

# Module-level logger used by the request handlers in this file.
logger = logging.getLogger(__name__)
# Passed to FastAPI below as the API description.
description = """A FastAPI application for similarity search with PostgreSQL and OpenAI embeddings.
Direct/API URL:
https://re-mind-similarity-search.hf.space
"""

# Initialize FastAPI app
app = FastAPI(title="Similarity Search API", description=description, version="1.0.0")

# GZip-compress responses; minimum_size=1000 leaves smaller bodies uncompressed.
app.add_middleware(GZipMiddleware, minimum_size=1000)
# Dependency to get EmbeddingService
def get_embedding_service() -> EmbeddingService:
    """Build an EmbeddingService using the OPENAI_API_KEY environment variable.

    Returns:
        A new EmbeddingService. The key passed to it is None when
        OPENAI_API_KEY is not set in the environment.
    """
    api_key = os.getenv("OPENAI_API_KEY")
    return EmbeddingService(openai_api_key=api_key)
# Dependency to get HuggingFaceService
def get_huggingface_service() -> HuggingFaceService:
    """Build a HuggingFaceService with its default constructor arguments.

    Returns:
        A new HuggingFaceService instance (one per call).
    """
    service = HuggingFaceService()
    return service
# Root endpoint redirects to /docs
# NOTE(review): no route decorator is visible on this handler in this view —
# confirm it is registered on `app`.
async def root():
    """Return a redirect response pointing at ``/docs``."""
    docs_redirect = RedirectResponse(url="/docs")
    return docs_redirect
# Health check endpoint
# NOTE(review): no route decorator is visible on this handler in this view —
# confirm it is registered on `app`.
async def health_check(db: Database = Depends(get_db)):
    """Report whether the database answers its health check.

    Args:
        db: Database handle (injected via ``get_db``).

    Returns:
        ``{"status": "healthy"}`` when ``db.health_check()`` is truthy.

    Raises:
        HTTPException: 500 when the check returns a falsy value or raises
            HealthCheckError (whose message becomes the detail).
    """
    try:
        db_ok = await db.health_check()
    except HealthCheckError as exc:
        raise HTTPException(status_code=500, detail=str(exc))

    if db_ok:
        return {"status": "healthy"}
    raise HTTPException(status_code=500, detail="Database is unhealthy")
# Endpoint to generate embeddings for a list of strings
# NOTE(review): no route decorator is visible on this handler in this view —
# confirm it is registered on `app`.
async def embed(
    request: EmbedRequest,
    embedding_service: EmbeddingService = Depends(get_embedding_service),
):
    """
    Generate embeddings for a list of strings and return them in the response.

    Args:
        request: Request body; its ``texts`` field is what gets embedded.
        embedding_service: Service used to create the embeddings (injected).

    Returns:
        JSONResponse with ``message``, ``embeddings`` and ``num_texts`` keys.

    Raises:
        HTTPException: 500 when the service raises OpenAIError or any other
            exception; the original exception is chained as the cause.
    """
    try:
        # Step 1: Generate embeddings
        logger.info("Generating embeddings for list of texts...")
        embeddings = await embedding_service.create_embeddings(request.texts)
        # Built inside the try so a non-serializable result is reported as a
        # 500 by the catch-all below.
        return JSONResponse(
            content={
                "message": "Embeddings generated successfully.",
                "embeddings": embeddings,
                "num_texts": len(request.texts),
            }
        )
    except OpenAIError as e:
        logger.error(f"OpenAI API error: {e}")
        raise HTTPException(status_code=500, detail=f"OpenAI API error: {e}") from e
    except Exception as e:
        # Unexpected failure: log with the traceback so the root cause is
        # recoverable from the logs, not just the message text.
        logger.exception("An error occurred: %s", e)
        raise HTTPException(status_code=500, detail=f"An error occurred: {e}") from e
# Endpoint to create embeddings from a database query
# NOTE(review): no route decorator is visible on this handler in this view —
# confirm it is registered on `app`.
async def create_embedding(
    request: CreateEmbeddingRequest,
    db: Database = Depends(get_db),
    embedding_service: EmbeddingService = Depends(get_embedding_service),
    huggingface_service: HuggingFaceService = Depends(get_huggingface_service),
):
    """
    Create embeddings for the target column in the dataset.

    Runs ``request.query`` against the database, loads the rows into a
    DataFrame, asks the embedding service to embed ``request.target_column``
    into ``request.output_column``, and pushes the result to the Hugging Face
    Hub as ``request.dataset_name``.

    Args:
        request: Request body with ``query``, ``target_column``,
            ``output_column`` and ``dataset_name``.
        db: Database handle (injected via ``get_db``).
        embedding_service: Service used to create the embeddings (injected).
        huggingface_service: Service used to push the dataset (injected).

    Returns:
        JSONResponse with ``message``, ``dataset_name`` and ``num_rows`` keys.

    Raises:
        HTTPException: 500 on QueryExecutionError, OpenAIError,
            DatasetPushError or any other exception; the original exception
            is chained as the cause.
    """
    try:
        # Step 1: Query the database
        # NOTE(review): request.query is handed to db.fetch unchanged —
        # confirm this endpoint is only reachable by trusted callers.
        logger.info("Fetching data from the database...")
        result = await db.fetch(request.query)
        df = pd.DataFrame(result)

        # Step 2: Generate embeddings
        # NOTE(review): called here with (DataFrame, target, output) while
        # embed() passes a list of texts — confirm the service accepts both.
        df = await embedding_service.create_embeddings(
            df, request.target_column, request.output_column
        )

        # Step 3: Push to Hugging Face Hub
        await huggingface_service.push_to_hub(df, request.dataset_name)

        return JSONResponse(
            content={
                "message": "Embeddings created and pushed to Hugging Face Hub.",
                "dataset_name": request.dataset_name,
                "num_rows": len(df),
            }
        )
    except QueryExecutionError as e:
        logger.error(f"Database query failed: {e}")
        raise HTTPException(status_code=500, detail=f"Database query failed: {e}") from e
    except OpenAIError as e:
        logger.error(f"OpenAI API error: {e}")
        raise HTTPException(status_code=500, detail=f"OpenAI API error: {e}") from e
    except DatasetPushError as e:
        logger.error(f"Failed to push dataset: {e}")
        raise HTTPException(status_code=500, detail=f"Failed to push dataset: {e}") from e
    except Exception as e:
        # Unexpected failure: log with the traceback so the root cause is
        # recoverable from the logs, not just the message text.
        logger.exception("An error occurred: %s", e)
        raise HTTPException(status_code=500, detail=f"An error occurred: {e}") from e
# Endpoint to read embeddings
# NOTE(review): no route decorator is visible on this handler in this view —
# confirm it is registered on `app`.
async def read_embeddings(
    request: ReadEmbeddingRequest,
    huggingface_service: HuggingFaceService = Depends(get_huggingface_service),
):
    """
    Read embeddings from a Hugging Face dataset.

    Args:
        request: Request body; ``dataset_name`` selects the dataset.
        huggingface_service: Service used to read the dataset (injected).

    Returns:
        Whatever ``huggingface_service.read_dataset`` returns, unchanged.

    Raises:
        HTTPException: 404 on DatasetNotFoundError, 500 on any other
            exception; the original exception is chained as the cause.
    """
    try:
        # NOTE(review): the result is returned as-is; if it is a pandas
        # DataFrame it probably needs converting (e.g. to records) before the
        # framework can serialize it — confirm.
        df = await huggingface_service.read_dataset(request.dataset_name)
        return df
    except DatasetNotFoundError as e:
        logger.error(f"Dataset not found: {e}")
        raise HTTPException(status_code=404, detail=f"Dataset not found: {e}") from e
    except Exception as e:
        # Unexpected failure: log with the traceback so the root cause is
        # recoverable from the logs, not just the message text.
        logger.exception("An error occurred: %s", e)
        raise HTTPException(status_code=500, detail=f"An error occurred: {e}") from e
# Endpoint to update embeddings (the commented-out block below is the superseded two-argument version; the live handler follows it)
# @app.post("/update_embeddings") | |
# async def update_embeddings( | |
# request: UpdateEmbeddingRequest, | |
# huggingface_service: HuggingFaceService = Depends(get_huggingface_service), | |
# ): | |
# """ | |
# Update embeddings in a Hugging Face dataset. | |
# """ | |
# try: | |
# df = await huggingface_service.update_dataset( | |
# request.dataset_name, request.updates | |
# ) | |
# return { | |
# "message": "Embeddings updated successfully.", | |
# "dataset_name": request.dataset_name, | |
# } | |
# except DatasetPushError as e: | |
# logger.error(f"Failed to update dataset: {e}") | |
# raise HTTPException(status_code=500, detail=f"Failed to update dataset: {e}") | |
# except Exception as e: | |
# logger.error(f"An error occurred: {e}") | |
# raise HTTPException(status_code=500, detail=f"An error occurred: {e}") | |
# NOTE(review): no route decorator is visible on this handler in this view;
# the commented-out predecessor above used @app.post("/update_embeddings") —
# confirm this handler is registered on `app`.
async def update_embeddings(
    request: UpdateEmbeddingRequest,
    huggingface_service: HuggingFaceService = Depends(get_huggingface_service),
):
    """
    Update embeddings in a Hugging Face dataset by generating embeddings for new data and concatenating it with the existing dataset.

    Args:
        request: Request body with ``dataset_name``, ``updates``,
            ``target_column`` and ``output_column``.
        huggingface_service: Service that performs the update (injected).

    Returns:
        Dict with ``message``, ``dataset_name`` and ``num_rows`` (the length
        of the value returned by ``update_dataset``).

    Raises:
        HTTPException: 500 on DatasetPushError or any other exception; the
            original exception is chained as the cause.
    """
    try:
        # Call the update_dataset method to generate embeddings, concatenate, and push the updated dataset
        updated_df = await huggingface_service.update_dataset(
            request.dataset_name,
            request.updates,
            request.target_column,
            request.output_column,
        )
        return {
            "message": "Embeddings updated successfully.",
            "dataset_name": request.dataset_name,
            "num_rows": len(updated_df),
        }
    except DatasetPushError as e:
        logger.error(f"Failed to update dataset: {e}")
        raise HTTPException(status_code=500, detail=f"Failed to update dataset: {e}") from e
    except Exception as e:
        # Unexpected failure: log with the traceback so the root cause is
        # recoverable from the logs, not just the message text.
        logger.exception("An error occurred: %s", e)
        raise HTTPException(status_code=500, detail=f"An error occurred: {e}") from e
# Endpoint to delete embeddings
# NOTE(review): no route decorator is visible on this handler in this view —
# confirm it is registered on `app`.
async def delete_embeddings(
    request: DeleteEmbeddingRequest,
    huggingface_service: HuggingFaceService = Depends(get_huggingface_service),
):
    """
    Delete embeddings from a Hugging Face dataset.

    Args:
        request: Request body; ``dataset_name`` selects the dataset.
        huggingface_service: Service that performs the deletion (injected).

    Returns:
        Dict with ``message`` and ``dataset_name`` keys.

    Raises:
        HTTPException: 500 on DatasetPushError or any other exception; the
            original exception is chained as the cause.
    """
    try:
        await huggingface_service.delete_dataset(request.dataset_name)
        return {
            "message": "Embeddings deleted successfully.",
            "dataset_name": request.dataset_name,
        }
    except DatasetPushError as e:
        # NOTE(review): the message says "columns" although the call above is
        # delete_dataset — left unchanged in case clients match on it; confirm.
        logger.error(f"Failed to delete columns: {e}")
        raise HTTPException(status_code=500, detail=f"Failed to delete columns: {e}") from e
    except Exception as e:
        # Unexpected failure: log with the traceback so the root cause is
        # recoverable from the logs, not just the message text.
        logger.exception("An error occurred: %s", e)
        raise HTTPException(status_code=500, detail=f"An error occurred: {e}") from e