# src/main.py
from fastapi import FastAPI, UploadFile, File, HTTPException, Depends, BackgroundTasks
from fastapi.responses import StreamingResponse, JSONResponse
from pydantic import BaseModel
from typing import List, Optional, AsyncGenerator, Dict
import asyncio
import json
import uuid
from datetime import datetime
import aiosqlite
from pathlib import Path
import shutil
import os

# Import custom modules
from .agents.rag_agent import RAGAgent
from .llms.openai_llm import OpenAILanguageModel
from .llms.ollama_llm import OllamaLanguageModel
from .llms.bert_llm import BERTLanguageModel
from .llms.falcon_llm import FalconLanguageModel
from .llms.llama_llm import LlamaLanguageModel
from .embeddings.huggingface_embedding import HuggingFaceEmbedding
from .vectorstores.chroma_vectorstore import ChromaVectorStore
from .utils.document_processor import DocumentProcessor
from .utils.conversation_summarizer import ConversationSummarizer
from .utils.logger import logger
from config.config import settings

app = FastAPI(title="RAG Chatbot API")

# Initialize core components
doc_processor = DocumentProcessor(
    chunk_size=1000,
    chunk_overlap=200,
    max_file_size=10 * 1024 * 1024  # 10 MB upload limit
)
summarizer = ConversationSummarizer()

# Pydantic models
class ChatRequest(BaseModel):
    query: str
    llm_provider: str = 'openai'
    max_context_docs: int = 3
    temperature: float = 0.7
    stream: bool = False
    conversation_id: Optional[str] = None

class ChatResponse(BaseModel):
    response: str
    context: Optional[List[str]] = None
    sources: Optional[List[Dict[str, str]]] = None
    conversation_id: str
    timestamp: datetime
    relevant_doc_scores: Optional[List[float]] = None

class DocumentResponse(BaseModel):
    message: str
    document_id: str
    status: str
    document_info: Optional[dict] = None

class BatchUploadResponse(BaseModel):
    message: str
    processed_files: List[DocumentResponse]
    failed_files: List[dict]

class SummarizeRequest(BaseModel):
    conversation_id: str
    include_metadata: bool = True

class SummaryResponse(BaseModel):
    summary: str
    key_insights: Dict
    metadata: Optional[Dict] = None

class FeedbackRequest(BaseModel):
    rating: int
    feedback: Optional[str] = None
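# Example ChatRequest body (illustrative values only; the field names and
# defaults come from the model above):
#
# {
#     "query": "What does the uploaded report say about Q3 revenue?",
#     "llm_provider": "ollama",
#     "max_context_docs": 3,
#     "temperature": 0.7,
#     "stream": false,
#     "conversation_id": null
# }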
# Database initialization
async def init_db():
    async with aiosqlite.connect('chat_history.db') as db:
        await db.execute('''
            CREATE TABLE IF NOT EXISTS chat_history (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                conversation_id TEXT,
                timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
                query TEXT,
                response TEXT,
                context TEXT,
                sources TEXT,
                llm_provider TEXT,
                feedback TEXT,
                rating INTEGER
            )
        ''')
        await db.commit()
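# The history, summary, and feedback queries below all filter on
# conversation_id, so an index keeps them fast as the table grows. An
# optional, minimal sketch; call it alongside init_db() at startup if adopted:
async def init_indexes():
    """Create supporting indexes (idempotent, safe to call repeatedly)."""
    async with aiosqlite.connect('chat_history.db') as db:
        await db.execute(
            'CREATE INDEX IF NOT EXISTS idx_chat_history_conversation '
            'ON chat_history (conversation_id)'
        )
        await db.commit()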
# Utility functions
def get_llm_instance(provider: str):
    """Get LLM instance based on provider."""
    llm_map = {
        'openai': lambda: OpenAILanguageModel(api_key=settings.OPENAI_API_KEY),
        'ollama': lambda: OllamaLanguageModel(base_url=settings.OLLAMA_BASE_URL),
        'bert': lambda: BERTLanguageModel(),
        'falcon': lambda: FalconLanguageModel(),
        'llama': lambda: LlamaLanguageModel(),
    }
    if provider not in llm_map:
        raise ValueError(f"Unsupported LLM provider: {provider}")
    return llm_map[provider]()

async def get_vector_store():
    """Initialize and return vector store with embedding model."""
    try:
        embedding = HuggingFaceEmbedding(model_name=settings.EMBEDDING_MODEL)
        vector_store = ChromaVectorStore(
            embedding_function=embedding.embed_documents,
            persist_directory=settings.CHROMA_PATH
        )
        return vector_store, embedding
    except Exception as e:
        logger.error(f"Error initializing vector store: {str(e)}")
        raise HTTPException(status_code=500, detail="Failed to initialize vector store")
async def process_and_store_document(
    file_path: Path,
    vector_store: ChromaVectorStore,
    document_id: str
):
    """Process document and store in vector database."""
    try:
        processed_doc = await doc_processor.process_document(file_path)
        vector_store.add_documents(
            documents=processed_doc['chunks'],
            metadatas=[{
                'document_id': document_id,
                'chunk_id': i,
                'source': str(file_path.name),
                'metadata': processed_doc['metadata']
            } for i in range(len(processed_doc['chunks']))],
            ids=[f"{document_id}_chunk_{i}" for i in range(len(processed_doc['chunks']))]
        )
        return processed_doc
    except Exception as e:
        # This runs as a background task after the response is sent, so log
        # failures explicitly instead of letting them vanish silently
        logger.error(f"Error processing document {document_id}: {str(e)}")
        raise
    finally:
        # Always remove the temporary upload, even on failure
        if file_path.exists():
            file_path.unlink()
async def store_chat_history(
    conversation_id: str,
    query: str,
    response: str,
    context: List[str],
    sources: List[Dict],
    llm_provider: str
):
    """Store chat history in database."""
    async with aiosqlite.connect('chat_history.db') as db:
        await db.execute(
            '''INSERT INTO chat_history
               (conversation_id, query, response, context, sources, llm_provider)
               VALUES (?, ?, ?, ?, ?, ?)''',
            (conversation_id, query, response, json.dumps(context),
             json.dumps(sources), llm_provider)
        )
        await db.commit()
# Endpoints (route paths here are illustrative)
@app.post("/documents/upload", response_model=BatchUploadResponse)
async def upload_documents(
    background_tasks: BackgroundTasks,
    files: List[UploadFile] = File(...)
):
    """Upload and process multiple documents."""
    # Declare BackgroundTasks without a manual default so FastAPI injects it;
    # define upload_dir before the try block so the finally clause cannot
    # reference an unbound name
    upload_dir = Path("temp_uploads")
    try:
        vector_store, _ = await get_vector_store()
        upload_dir.mkdir(exist_ok=True)
        processed_files = []
        failed_files = []
        for file in files:
            try:
                document_id = str(uuid.uuid4())
                if not any(file.filename.lower().endswith(ext)
                           for ext in doc_processor.supported_formats):
                    failed_files.append({
                        "filename": file.filename,
                        "error": "Unsupported file format"
                    })
                    continue
                temp_path = upload_dir / f"{document_id}_{file.filename}"
                with open(temp_path, "wb") as buffer:
                    shutil.copyfileobj(file.file, buffer)
                background_tasks.add_task(
                    process_and_store_document,
                    temp_path,
                    vector_store,
                    document_id
                )
                processed_files.append(
                    DocumentResponse(
                        message="Document queued for processing",
                        document_id=document_id,
                        status="processing",
                        document_info={
                            "original_filename": file.filename,
                            "size": os.path.getsize(temp_path),
                            "content_type": file.content_type
                        }
                    )
                )
            except Exception as e:
                logger.error(f"Error processing file {file.filename}: {str(e)}")
                failed_files.append({
                    "filename": file.filename,
                    "error": str(e)
                })
        return BatchUploadResponse(
            message=f"Processed {len(processed_files)} documents with {len(failed_files)} failures",
            processed_files=processed_files,
            failed_files=failed_files
        )
    except Exception as e:
        logger.error(f"Error in document upload: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        # Remove the staging directory only if the background tasks left it empty
        if upload_dir.exists() and not any(upload_dir.iterdir()):
            upload_dir.rmdir()
@app.post("/chat", response_model=ChatResponse)
async def chat_endpoint(
    request: ChatRequest,
    background_tasks: BackgroundTasks
):
    """Chat endpoint with RAG support."""
    try:
        vector_store, embedding_model = await get_vector_store()
        llm = get_llm_instance(request.llm_provider)
        rag_agent = RAGAgent(
            llm=llm,
            embedding=embedding_model,
            vector_store=vector_store
        )
        if request.stream:
            return StreamingResponse(
                rag_agent.generate_streaming_response(request.query),
                media_type="text/event-stream"
            )
        response = await rag_agent.generate_response(
            query=request.query,
            temperature=request.temperature
        )
        conversation_id = request.conversation_id or str(uuid.uuid4())
        background_tasks.add_task(
            store_chat_history,
            conversation_id,
            request.query,
            response.response,
            response.context_docs,
            response.sources,
            request.llm_provider
        )
        return ChatResponse(
            response=response.response,
            context=response.context_docs,
            sources=response.sources,
            conversation_id=conversation_id,
            timestamp=datetime.now(),
            relevant_doc_scores=response.scores if hasattr(response, 'scores') else None
        )
    except Exception as e:
        logger.error(f"Error in chat endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/conversations/{conversation_id}")
async def get_conversation_history(conversation_id: str):
    """Get complete conversation history."""
    async with aiosqlite.connect('chat_history.db') as db:
        db.row_factory = aiosqlite.Row
        async with db.execute(
            'SELECT * FROM chat_history WHERE conversation_id = ? ORDER BY timestamp',
            (conversation_id,)
        ) as cursor:
            history = await cursor.fetchall()
    if not history:
        raise HTTPException(status_code=404, detail="Conversation not found")
    return {
        "conversation_id": conversation_id,
        "messages": [dict(row) for row in history]
    }
@app.post("/conversations/summarize", response_model=SummaryResponse)
async def summarize_conversation(request: SummarizeRequest):
    """Generate a summary of a conversation."""
    try:
        async with aiosqlite.connect('chat_history.db') as db:
            db.row_factory = aiosqlite.Row
            async with db.execute(
                'SELECT * FROM chat_history WHERE conversation_id = ? ORDER BY timestamp',
                (request.conversation_id,)
            ) as cursor:
                history = await cursor.fetchall()
        if not history:
            raise HTTPException(status_code=404, detail="Conversation not found")
        # Each row stores a query/response pair, so emit one user message and
        # one assistant message per row
        messages = []
        for msg in history:
            messages.append({
                'role': 'user',
                'content': msg['query'],
                'timestamp': msg['timestamp'],
                'sources': None
            })
            messages.append({
                'role': 'assistant',
                'content': msg['response'],
                'timestamp': msg['timestamp'],
                'sources': json.loads(msg['sources']) if msg['sources'] else None
            })
        summary = await summarizer.summarize_conversation(
            messages,
            include_metadata=request.include_metadata
        )
        return SummaryResponse(**summary)
    except HTTPException:
        raise  # Preserve the 404 instead of wrapping it in a 500
    except Exception as e:
        logger.error(f"Error generating summary: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/conversations/{conversation_id}/feedback")
async def submit_feedback(
    conversation_id: str,
    feedback_request: FeedbackRequest
):
    """Submit feedback for a conversation."""
    try:
        async with aiosqlite.connect('chat_history.db') as db:
            await db.execute(
                '''UPDATE chat_history
                   SET feedback = ?, rating = ?
                   WHERE conversation_id = ?''',
                (feedback_request.feedback, feedback_request.rating, conversation_id)
            )
            await db.commit()
        return {"status": "Feedback submitted successfully"}
    except Exception as e:
        logger.error(f"Error submitting feedback: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/health")
async def health_check():
    """Health check endpoint."""
    return {"status": "healthy"}

# Startup event
@app.on_event("startup")
async def startup_event():
    await init_db()
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
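# Quick smoke test once the server is running (paths match the illustrative
# route decorators above; adjust if you mount them differently):
#
#   curl -X POST http://localhost:8000/documents/upload \
#        -F "files=@report.pdf"
#
#   curl -X POST http://localhost:8000/chat \
#        -H "Content-Type: application/json" \
#        -d '{"query": "Summarize the uploaded report", "llm_provider": "openai"}'
#
#   curl http://localhost:8000/health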