from fastapi import FastAPI, HTTPException, Header, Depends, BackgroundTasks
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
from typing import List, Optional, Dict, AsyncGenerator
import json
import os
import logging
from txtai.embeddings import Embeddings
import pandas as pd
import glob
import uuid
import httpx

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(
    title="Embeddings API",
    description="An API for creating and querying text embeddings indexes.",
    version="1.0.0"
)

CHAT_AUTH_KEY = os.environ.get("CHAT_AUTH_KEY", "default_secret_key")

# Enable CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allows all origins
    allow_credentials=True,
    allow_methods=["*"],  # Allows all methods
    allow_headers=["*"],  # Allows all headers
)
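
# Load the sentence-embedding model once at import time; every request shares
# (and mutates, via load()/index()) this single global instance.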
embeddings = Embeddings({"path": "avsolatorio/GIST-all-MiniLM-L6-v2"})

class DocumentRequest(BaseModel):
    index_id: str = Field(..., description="Unique identifier for the index")
    documents: List[str] = Field(..., description="List of documents to be indexed")

class QueryRequest(BaseModel):
    index_id: str = Field(..., description="Unique identifier for the index to query")
    query: str = Field(..., description="The search query")
    num_results: int = Field(..., description="Number of results to return", ge=1)

def save_embeddings(index_id: str, document_list: List[str]):
    try:
        folder_path = f"/app/indexes/{index_id}"
        os.makedirs(folder_path, exist_ok=True)
        # Save embeddings
        embeddings.save(f"{folder_path}/embeddings")
        # Save document_list
        with open(f"{folder_path}/document_list.json", "w") as f:
            json.dump(document_list, f)
        logger.info(f"Embeddings and document list saved for index_id: {index_id}")
    except Exception as e:
        logger.error(f"Error saving embeddings for index_id {index_id}: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error saving embeddings: {str(e)}")

def load_embeddings(index_id: str) -> List[str]:
    try:
        folder_path = f"/app/indexes/{index_id}"
        if not os.path.exists(folder_path):
            logger.error(f"Index not found for index_id: {index_id}")
            raise HTTPException(status_code=404, detail="Index not found")
        # Load embeddings
        embeddings.load(f"{folder_path}/embeddings")
        # Load document_list
        with open(f"{folder_path}/document_list.json", "r") as f:
            document_list = json.load(f)
        logger.info(f"Embeddings and document list loaded for index_id: {index_id}")
        return document_list
    except HTTPException:
        # Re-raise as-is so the 404 above is not converted into a 500
        raise
    except Exception as e:
        logger.error(f"Error loading embeddings for index_id {index_id}: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error loading embeddings: {str(e)}")
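
# On-disk layout produced and consumed by the two helpers above:
#   /app/indexes/<index_id>/embeddings          -- serialized txtai index
#   /app/indexes/<index_id>/document_list.json  -- original texts, by position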

@app.post("/create_index")
async def create_index(request: DocumentRequest):
    """
    Create a new index with the given documents.

    - **index_id**: Unique identifier for the index
    - **documents**: List of documents to be indexed
    """
    try:
        document_list = [(i, text, None) for i, text in enumerate(request.documents)]
        embeddings.index(document_list)
        save_embeddings(request.index_id, request.documents)  # Save the original documents
        logger.info(f"Index created successfully for index_id: {request.index_id}")
        return {"message": "Index created successfully"}
    except Exception as e:
        logger.error(f"Error creating index: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error creating index: {str(e)}")
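
# Example request (a sketch, assuming the "/create_index" route above and a
# server on localhost:8000):
#   curl -X POST http://localhost:8000/create_index \
#     -H "Content-Type: application/json" \
#     -d '{"index_id": "demo", "documents": ["first document", "second document"]}'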

@app.post("/query_index")
async def query_index(request: QueryRequest):
    """
    Query an existing index with the given search query.

    - **index_id**: Unique identifier for the index to query
    - **query**: The search query
    - **num_results**: Number of results to return
    """
    try:
        document_list = load_embeddings(request.index_id)
        results = embeddings.search(request.query, request.num_results)
        queried_texts = [document_list[idx[0]] for idx in results]
        logger.info(f"Query executed successfully for index_id: {request.index_id}")
        return {"queried_texts": queried_texts}
    except HTTPException:
        # Propagate 404s from load_embeddings unchanged
        raise
    except Exception as e:
        logger.error(f"Error querying index: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error querying index: {str(e)}")
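
# Example request (a sketch, assuming the "/query_index" route above):
#   curl -X POST http://localhost:8000/query_index \
#     -H "Content-Type: application/json" \
#     -d '{"index_id": "demo", "query": "first", "num_results": 1}'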

def process_csv_file(file_path):
    try:
        df = pd.read_csv(file_path)
        # Flatten each row into a single space-separated string
        df_rows = df.apply(lambda row: ' '.join(row.values.astype(str)), axis=1)
        txtai_data = [(i, row, None) for i, row in enumerate(df_rows)]
        return txtai_data, df_rows.tolist()
    except Exception as e:
        logger.error(f"Error processing CSV file {file_path}: {str(e)}")
        return None, None

def check_and_index_csv_files():
    index_data_folder = "/app/index_data"
    if not os.path.exists(index_data_folder):
        logger.warning(f"index_data folder not found: {index_data_folder}")
        return
    csv_files = glob.glob(os.path.join(index_data_folder, "*.csv"))
    for csv_file in csv_files:
        # The file name (without extension) becomes the index_id
        index_id = os.path.splitext(os.path.basename(csv_file))[0]
        if not os.path.exists(f"/app/indexes/{index_id}"):
            logger.info(f"Processing CSV file: {csv_file}")
            txtai_data, documents = process_csv_file(csv_file)
            if txtai_data and documents:
                embeddings.index(txtai_data)
                save_embeddings(index_id, documents)
                logger.info(f"CSV file indexed successfully: {csv_file}")
            else:
                logger.warning(f"Failed to process CSV file: {csv_file}")
        else:
            logger.info(f"Index already exists for: {csv_file}")
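
# Example (hypothetical file): /app/index_data/products.csv is flattened one
# row at a time ("col1 col2 ..." per row), indexed under index_id "products",
# and is then queryable like any index created through the API.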

class ChatRequest(BaseModel):
    query: str = Field(..., description="The user's query")
    index_id: str = Field(..., description="Unique identifier for the index to query")
    conversation_id: Optional[str] = Field(None, description="Unique identifier for the conversation")
    model_id: str = Field(..., description="Identifier for the LLM model to use")
    user_id: str = Field(..., description="Unique identifier for the user")

async def get_api_key(x_api_key: str = Header(...)) -> str:
    if x_api_key != CHAT_AUTH_KEY:
        raise HTTPException(status_code=403, detail="Invalid API key")
    return x_api_key

async def stream_llm_request(api_key: str, llm_request: Dict[str, str]) -> AsyncGenerator[str, None]:
    """
    Make a streaming request to the LLM service.
    """
    try:
        async with httpx.AsyncClient() as client:
            async with client.stream(
                "POST",
                "https://pvanand-audio-chat.hf.space/llm-agent",
                headers={
                    "accept": "text/event-stream",
                    "X-API-Key": api_key,
                    "Content-Type": "application/json"
                },
                json=llm_request
            ) as response:
                if response.status_code != 200:
                    raise HTTPException(status_code=response.status_code, detail="Error from LLM service")
                async for chunk in response.aiter_text():
                    yield chunk
    except HTTPException:
        # Keep the upstream status code instead of collapsing it into a 500
        raise
    except httpx.HTTPError as e:
        logger.error(f"HTTP error occurred while making LLM request: {str(e)}")
        raise HTTPException(status_code=500, detail=f"HTTP error occurred while making LLM request: {str(e)}")
    except Exception as e:
        logger.error(f"Unexpected error occurred while making LLM request: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Unexpected error occurred while making LLM request: {str(e)}")
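
# save_conversation is referenced (commented out) in the chat endpoint below but
# never defined. A minimal sketch, assuming appending each exchange to a JSONL
# file per conversation is acceptable:
def save_conversation(user_id: str, conversation_id: str, query: str, response: str):
    record = {
        "user_id": user_id,
        "conversation_id": conversation_id,
        "query": query,
        "response": response,
    }
    os.makedirs("/app/conversations", exist_ok=True)
    with open(f"/app/conversations/{conversation_id}.jsonl", "a") as f:
        f.write(json.dumps(record) + "\n")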

@app.post("/chat")
async def chat(request: ChatRequest, background_tasks: BackgroundTasks, api_key: str = Depends(get_api_key)):
    """
    Chat endpoint that uses embeddings search and LLM for response generation.
    """
    try:
        # Load embeddings for the specified index
        document_list = load_embeddings(request.index_id)
        # Perform embeddings search
        search_results = embeddings.search(request.query, 5)  # Get top 5 relevant results
        context = "\n".join([document_list[idx[0]] for idx in search_results])
        # Create RAG prompt
        rag_prompt = f"Based on the following context, please answer the user's question:\n\nContext:\n{context}\n\nUser's question: {request.query}\n\nAnswer:"
        # Generate conversation_id if not provided
        conversation_id = request.conversation_id or str(uuid.uuid4())
        # Prepare the request for the LLM service
        llm_request = {
            "prompt": request.query,
            "system_message": rag_prompt,
            "model_id": request.model_id,
            "conversation_id": conversation_id,
            "user_id": request.user_id
        }

        async def response_generator():
            full_response = ""
            async for chunk in stream_llm_request(api_key, llm_request):
                full_response += chunk
                yield chunk
            # Here you might want to add logic to save the conversation or perform other background tasks
            # For example:
            # background_tasks.add_task(save_conversation, request.user_id, conversation_id, request.query, full_response)

        logger.info(f"Starting chat response generation for user: {request.user_id}")
        return StreamingResponse(response_generator(), media_type="text/event-stream")
    except HTTPException:
        # Propagate 404s from load_embeddings and 403s from auth unchanged
        raise
    except Exception as e:
        logger.error(f"Error in chat endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error in chat endpoint: {str(e)}")
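
# Example request (a sketch, assuming the "/chat" route above and CHAT_AUTH_KEY
# exported in the calling shell; -N disables curl buffering for the stream):
#   curl -N -X POST http://localhost:8000/chat \
#     -H "Content-Type: application/json" -H "X-API-Key: $CHAT_AUTH_KEY" \
#     -d '{"query": "What is in the docs?", "index_id": "demo", "model_id": "<model>", "user_id": "u1"}'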

@app.on_event("startup")
async def startup_event():
    # Index any CSV files dropped into /app/index_data before serving requests
    check_and_index_csv_files()

if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)
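
# Alternatively, assuming this file is saved as main.py:
#   uvicorn main:app --host 0.0.0.0 --port 8000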