# src/main.py
import asyncio
import contextlib
import functools
import json
import os
import shutil
import uuid
from datetime import datetime
from pathlib import Path
from typing import List, Optional, AsyncGenerator, Dict, Tuple

import aiosqlite
from fastapi import FastAPI, UploadFile, File, HTTPException, Depends, BackgroundTasks
from fastapi.responses import StreamingResponse, JSONResponse
from pydantic import BaseModel

# Import custom modules
from .agents.rag_agent import RAGAgent
from .llms.openai_llm import OpenAILanguageModel
from .llms.ollama_llm import OllamaLanguageModel
from .llms.bert_llm import BERTLanguageModel
from .llms.falcon_llm import FalconLanguageModel
from .llms.llama_llm import LlamaLanguageModel
from .embeddings.huggingface_embedding import HuggingFaceEmbedding
from .vectorstores.chroma_vectorstore import ChromaVectorStore
from .utils.document_processor import DocumentProcessor
from .utils.conversation_summarizer import ConversationSummarizer
from .utils.logger import logger
from config.config import settings

app = FastAPI(title="RAG Chatbot API")

# SQLite file backing the chat_history table (relative to the working directory).
CHAT_DB_PATH = 'chat_history.db'

# Initialize core components
doc_processor = DocumentProcessor(
    chunk_size=1000,
    chunk_overlap=200,
    max_file_size=10 * 1024 * 1024,  # presumably bytes (10 MiB) -- confirm in DocumentProcessor
)
summarizer = ConversationSummarizer()


# Pydantic models
# (Documented with comments rather than docstrings: a model docstring would
# become the schema description in the generated OpenAPI document.)

# Body of POST /chat.
class ChatRequest(BaseModel):
    query: str
    llm_provider: str = 'openai'  # must be a key of _LLM_FACTORIES
    max_context_docs: int = 3
    temperature: float = 0.7
    stream: bool = False
    conversation_id: Optional[str] = None  # a new UUID is generated when omitted


# Non-streaming reply of POST /chat.
class ChatResponse(BaseModel):
    response: str
    context: Optional[List[str]] = None
    sources: Optional[List[Dict[str, str]]] = None
    conversation_id: str
    timestamp: datetime
    relevant_doc_scores: Optional[List[float]] = None


# Per-file result of POST /documents/upload.
class DocumentResponse(BaseModel):
    message: str
    document_id: str
    status: str
    document_info: Optional[dict] = None


# Aggregate result of POST /documents/upload.
class BatchUploadResponse(BaseModel):
    message: str
    processed_files: List[DocumentResponse]
    failed_files: List[dict]


# Body of POST /chat/summarize.
class SummarizeRequest(BaseModel):
    conversation_id: str
    include_metadata: bool = True


# Reply of POST /chat/summarize; built from the summarizer's returned dict.
class SummaryResponse(BaseModel):
    summary: str
    key_insights: Dict
    metadata: Optional[Dict] = None


# Body of POST /chat/feedback/{conversation_id}.
class FeedbackRequest(BaseModel):
    rating: int
    feedback: Optional[str] = None


# Database initialization
async def init_db():
    """Create the chat_history table and its conversation_id index if missing."""
    async with aiosqlite.connect(CHAT_DB_PATH) as db:
        await db.execute('''
            CREATE TABLE IF NOT EXISTS chat_history (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                conversation_id TEXT,
                timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
                query TEXT,
                response TEXT,
                context TEXT,
                sources TEXT,
                llm_provider TEXT,
                feedback TEXT,
                rating INTEGER
            )
        ''')
        # Every read and update below filters on conversation_id; index it so
        # those queries do not scan the whole table as history grows.
        await db.execute(
            'CREATE INDEX IF NOT EXISTS idx_chat_history_conversation_id '
            'ON chat_history (conversation_id)'
        )
        await db.commit()


# Utility functions

# Provider name -> zero-argument factory. The factories are lazy so only the
# requested model is constructed and only its settings are read.
_LLM_FACTORIES = {
    'openai': lambda: OpenAILanguageModel(api_key=settings.OPENAI_API_KEY),
    'ollama': lambda: OllamaLanguageModel(base_url=settings.OLLAMA_BASE_URL),
    'bert': lambda: BERTLanguageModel(),
    'falcon': lambda: FalconLanguageModel(),
    'llama': lambda: LlamaLanguageModel(),
}


def get_llm_instance(provider: str):
    """Get LLM instance based on provider.

    Args:
        provider: One of the keys of ``_LLM_FACTORIES``.

    Returns:
        A freshly constructed language-model object for ``provider``.

    Raises:
        ValueError: If ``provider`` is not supported.
    """
    try:
        factory = _LLM_FACTORIES[provider]
    except KeyError:
        raise ValueError(f"Unsupported LLM provider: {provider}") from None
    return factory()


@functools.lru_cache(maxsize=1)
def _build_vector_store() -> Tuple[ChromaVectorStore, HuggingFaceEmbedding]:
    """Construct the embedding model and vector store once per process.

    lru_cache does not cache raised exceptions, so a failed initialisation is
    retried on the next call instead of being remembered.
    """
    embedding = HuggingFaceEmbedding(model_name=settings.EMBEDDING_MODEL)
    vector_store = ChromaVectorStore(
        embedding_function=embedding.embed_documents,
        persist_directory=settings.CHROMA_PATH,
    )
    return vector_store, embedding


async def get_vector_store():
    """Initialize and return vector store with embedding model.

    The pair is built on first use and then reused, so the embedding model is
    not reloaded on every request.

    Returns:
        ``(vector_store, embedding)`` tuple.

    Raises:
        HTTPException: 500 if either component fails to initialise.
    """
    try:
        return _build_vector_store()
    except Exception as e:
        logger.error(f"Error initializing vector store: {str(e)}")
        raise HTTPException(status_code=500, detail="Failed to initialize vector store") from e


async def process_and_store_document(
    file_path: Path,
    vector_store: ChromaVectorStore,
    document_id: str,
):
    """Process document and store in vector database.

    Splits ``file_path`` with the module-level ``doc_processor`` and adds one
    vector-store entry per chunk, with ids ``"{document_id}_chunk_{i}"``.
    The file at ``file_path`` is deleted afterwards, whether or not
    processing succeeded.

    Returns:
        The dict returned by ``doc_processor.process_document``.
    """
    try:
        processed_doc = await doc_processor.process_document(file_path)
        chunks = processed_doc['chunks']
        # NOTE(review): add_documents is called synchronously; if it embeds
        # chunks itself it blocks the event loop -- consider a thread pool.
        # NOTE(review): the nested 'metadata' value may be rejected by stores
        # that accept only scalar metadata -- confirm ChromaVectorStore handles it.
        vector_store.add_documents(
            documents=chunks,
            metadatas=[
                {
                    'document_id': document_id,
                    'chunk_id': i,
                    'source': str(file_path.name),
                    'metadata': processed_doc['metadata'],
                }
                for i in range(len(chunks))
            ],
            ids=[f"{document_id}_chunk_{i}" for i in range(len(chunks))],
        )
        return processed_doc
    except Exception as e:
        # This runs as a background task after the HTTP response was sent, so
        # the client never sees the failure; make sure it reaches the log.
        logger.error(f"Error processing document {document_id} ({file_path.name}): {str(e)}")
        raise
    finally:
        if file_path.exists():
            file_path.unlink()


async def store_chat_history(
    conversation_id: str,
    query: str,
    response: str,
    context: List[str],
    sources: List[Dict],
    llm_provider: str,
):
    """Store chat history in database.

    Inserts one row per query/response exchange; ``context`` and ``sources``
    are stored as JSON text.
    """
    async with aiosqlite.connect(CHAT_DB_PATH) as db:
        await db.execute(
            '''INSERT INTO chat_history
               (conversation_id, query, response, context, sources, llm_provider)
               VALUES (?, ?, ?, ?, ?, ?)''',
            (conversation_id, query, response, json.dumps(context),
             json.dumps(sources), llm_provider),
        )
        await db.commit()


async def _fetch_conversation(conversation_id: str) -> List[aiosqlite.Row]:
    """Return all chat_history rows for ``conversation_id`` in insertion order."""
    async with aiosqlite.connect(CHAT_DB_PATH) as db:
        db.row_factory = aiosqlite.Row
        # timestamp has one-second resolution; id breaks ties so exchanges
        # written within the same second keep their insertion order.
        async with db.execute(
            'SELECT * FROM chat_history WHERE conversation_id = ? ORDER BY timestamp, id',
            (conversation_id,),
        ) as cursor:
            return await cursor.fetchall()


def _history_to_messages(history) -> List[Dict]:
    """Expand chat_history rows into a chronological role/content message list.

    Each row stores one exchange, so it yields a "user" message for the query
    followed by an "assistant" message for the response. The row's stored
    sources belong to the response and are attached to the assistant message.
    """
    messages = []
    for row in history:
        if row['query']:
            messages.append({
                'role': 'user',
                'content': row['query'],
                'timestamp': row['timestamp'],
                'sources': None,
            })
        if row['response']:
            messages.append({
                'role': 'assistant',
                'content': row['response'],
                'timestamp': row['timestamp'],
                'sources': json.loads(row['sources']) if row['sources'] else None,
            })
    return messages


# Endpoints
@app.post("/documents/upload", response_model=BatchUploadResponse)
async def upload_documents(
    files: List[UploadFile] = File(...),
    # NOTE: FastAPI recognises BackgroundTasks by its annotation and injects a
    # per-request instance; this shared default is only used on direct calls.
    background_tasks: BackgroundTasks = BackgroundTasks(),
):
    """Upload and process multiple documents"""
    # Each accepted file is written to a temp dir and queued for background
    # chunking/indexing; unsupported or failing files are reported per file.
    # Bound before the try so the finally clause can always reference it,
    # even when get_vector_store() fails first.
    upload_dir = Path("temp_uploads")
    try:
        vector_store, _ = await get_vector_store()
        upload_dir.mkdir(exist_ok=True)

        processed_files = []
        failed_files = []

        for file in files:
            try:
                document_id = str(uuid.uuid4())
                # Keep only the base name of the client-supplied filename so it
                # cannot carry directory components into the temp path; a
                # missing filename becomes "" and fails the format check.
                safe_name = Path(file.filename or "").name
                if not any(safe_name.lower().endswith(ext) for ext in doc_processor.supported_formats):
                    failed_files.append({
                        "filename": file.filename,
                        "error": "Unsupported file format",
                    })
                    continue

                temp_path = upload_dir / f"{document_id}_{safe_name}"
                with open(temp_path, "wb") as buffer:
                    shutil.copyfileobj(file.file, buffer)

                # process_and_store_document deletes temp_path when it finishes.
                background_tasks.add_task(
                    process_and_store_document,
                    temp_path,
                    vector_store,
                    document_id,
                )

                processed_files.append(
                    DocumentResponse(
                        message="Document queued for processing",
                        document_id=document_id,
                        status="processing",
                        document_info={
                            "original_filename": file.filename,
                            "size": os.path.getsize(temp_path),
                            "content_type": file.content_type,
                        },
                    )
                )
            except Exception as e:
                logger.error(f"Error processing file {file.filename}: {str(e)}")
                failed_files.append({
                    "filename": file.filename,
                    "error": str(e),
                })

        return BatchUploadResponse(
            message=f"Processed {len(processed_files)} documents with {len(failed_files)} failures",
            processed_files=processed_files,
            failed_files=failed_files,
        )
    except HTTPException:
        # Already carries the right status/detail (e.g. vector store init).
        raise
    except Exception as e:
        logger.error(f"Error in document upload: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Best-effort cleanup of an empty scratch dir; never let it mask the
        # response or the original error.
        with contextlib.suppress(OSError):
            if upload_dir.exists() and not any(upload_dir.iterdir()):
                upload_dir.rmdir()


@app.post("/chat", response_model=ChatResponse)
async def chat_endpoint(
    request: ChatRequest,
    background_tasks: BackgroundTasks,
):
    """Chat endpoint with RAG support"""
    # Unknown providers are a client error: answer 400 up front instead of
    # letting get_llm_instance's ValueError surface as a 500 below.
    if request.llm_provider not in _LLM_FACTORIES:
        raise HTTPException(
            status_code=400,
            detail=f"Unsupported LLM provider: {request.llm_provider}",
        )
    try:
        vector_store, embedding_model = await get_vector_store()
        llm = get_llm_instance(request.llm_provider)
        rag_agent = RAGAgent(
            llm=llm,
            embedding=embedding_model,
            vector_store=vector_store,
        )

        if request.stream:
            # NOTE(review): the streaming path does not pass temperature, does
            # not record chat history, and returns no conversation_id.
            return StreamingResponse(
                rag_agent.generate_streaming_response(request.query),
                media_type="text/event-stream",
            )

        # NOTE(review): request.max_context_docs is accepted but not used here.
        response = await rag_agent.generate_response(
            query=request.query,
            temperature=request.temperature,
        )

        conversation_id = request.conversation_id or str(uuid.uuid4())

        # Persist after the reply is sent so the DB write adds no latency.
        background_tasks.add_task(
            store_chat_history,
            conversation_id,
            request.query,
            response.response,
            response.context_docs,
            response.sources,
            request.llm_provider,
        )

        return ChatResponse(
            response=response.response,
            context=response.context_docs,
            sources=response.sources,
            conversation_id=conversation_id,
            timestamp=datetime.now(),
            relevant_doc_scores=getattr(response, 'scores', None),
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error in chat endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.get("/chat/history/{conversation_id}")
async def get_conversation_history(conversation_id: str):
    """Get complete conversation history"""
    # Raises 404 when no rows exist for the id.
    history = await _fetch_conversation(conversation_id)
    if not history:
        raise HTTPException(status_code=404, detail="Conversation not found")

    return {
        "conversation_id": conversation_id,
        "messages": [dict(row) for row in history],
    }


@app.post("/chat/summarize", response_model=SummaryResponse)
async def summarize_conversation(request: SummarizeRequest):
    """Generate a summary of a conversation"""
    # 404 if the conversation is unknown; any other failure is a 500.
    try:
        history = await _fetch_conversation(request.conversation_id)
        if not history:
            raise HTTPException(status_code=404, detail="Conversation not found")

        summary = await summarizer.summarize_conversation(
            _history_to_messages(history),
            include_metadata=request.include_metadata,
        )
        return SummaryResponse(**summary)
    except HTTPException:
        # Without this, the 404 above would be rewrapped as a 500.
        raise
    except Exception as e:
        logger.error(f"Error generating summary: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.post("/chat/feedback/{conversation_id}")
async def submit_feedback(
    conversation_id: str,
    feedback_request: FeedbackRequest,
):
    """Submit feedback for a conversation"""
    # 404 when no rows match conversation_id.
    try:
        async with aiosqlite.connect(CHAT_DB_PATH) as db:
            # NOTE(review): this overwrites feedback/rating on every row of the
            # conversation, not only the latest exchange.
            async with db.execute(
                '''UPDATE chat_history SET feedback = ?, rating = ?
                   WHERE conversation_id = ?''',
                (feedback_request.feedback, feedback_request.rating, conversation_id),
            ) as cursor:
                updated_rows = cursor.rowcount
            await db.commit()

        if updated_rows == 0:
            raise HTTPException(status_code=404, detail="Conversation not found")
        return {"status": "Feedback submitted successfully"}
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error submitting feedback: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.get("/health")
async def health_check():
    """Health check endpoint"""
    return {"status": "healthy"}


# Startup event
@app.on_event("startup")
async def startup_event():
    # Ensure the chat_history schema exists before serving requests.
    await init_db()


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)