# src/main.py
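"""FastAPI entry point for the RAG chatbot.

Exposes endpoints for document upload, retrieval-augmented chat (optionally
streamed), conversation history, summarization, and feedback.
"""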
from fastapi import FastAPI, UploadFile, File, HTTPException, Depends, BackgroundTasks
from fastapi.responses import StreamingResponse, JSONResponse
from pydantic import BaseModel
from typing import List, Optional, AsyncGenerator, Dict
import asyncio
import json
import uuid
from datetime import datetime
import aiosqlite
from pathlib import Path
import shutil
import os
# Import custom modules
from .agents.rag_agent import RAGAgent
from .llms.openai_llm import OpenAILanguageModel
from .llms.ollama_llm import OllamaLanguageModel
from .llms.bert_llm import BERTLanguageModel
from .llms.falcon_llm import FalconLanguageModel
from .llms.llama_llm import LlamaLanguageModel
from .embeddings.huggingface_embedding import HuggingFaceEmbedding
from .vectorstores.chroma_vectorstore import ChromaVectorStore
from .utils.document_processor import DocumentProcessor
from .utils.conversation_summarizer import ConversationSummarizer
from .utils.logger import logger
from config.config import settings
app = FastAPI(title="RAG Chatbot API")
# Initialize core components
doc_processor = DocumentProcessor(
    chunk_size=1000,
    chunk_overlap=200,
    max_file_size=10 * 1024 * 1024
)
summarizer = ConversationSummarizer()
# Pydantic models
class ChatRequest(BaseModel):
    query: str
    llm_provider: str = 'openai'
    max_context_docs: int = 3
    temperature: float = 0.7
    stream: bool = False
    conversation_id: Optional[str] = None

class ChatResponse(BaseModel):
    response: str
    context: Optional[List[str]] = None
    sources: Optional[List[Dict[str, str]]] = None
    conversation_id: str
    timestamp: datetime
    relevant_doc_scores: Optional[List[float]] = None

class DocumentResponse(BaseModel):
    message: str
    document_id: str
    status: str
    document_info: Optional[dict] = None

class BatchUploadResponse(BaseModel):
    message: str
    processed_files: List[DocumentResponse]
    failed_files: List[dict]

class SummarizeRequest(BaseModel):
    conversation_id: str
    include_metadata: bool = True

class SummaryResponse(BaseModel):
    summary: str
    key_insights: Dict
    metadata: Optional[Dict] = None

class FeedbackRequest(BaseModel):
    rating: int
    feedback: Optional[str] = None
# Database initialization
async def init_db():
    async with aiosqlite.connect('chat_history.db') as db:
        await db.execute('''
            CREATE TABLE IF NOT EXISTS chat_history (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                conversation_id TEXT,
                timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
                query TEXT,
                response TEXT,
                context TEXT,
                sources TEXT,
                llm_provider TEXT,
                feedback TEXT,
                rating INTEGER
            )
        ''')
        await db.commit()
# Utility functions
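# Each entry in llm_map below is a zero-argument factory, so only the
# requested provider's model is actually constructed.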
def get_llm_instance(provider: str):
    """Get LLM instance based on provider."""
    llm_map = {
        'openai': lambda: OpenAILanguageModel(api_key=settings.OPENAI_API_KEY),
        'ollama': lambda: OllamaLanguageModel(base_url=settings.OLLAMA_BASE_URL),
        'bert': lambda: BERTLanguageModel(),
        'falcon': lambda: FalconLanguageModel(),
        'llama': lambda: LlamaLanguageModel(),
    }
    if provider not in llm_map:
        raise ValueError(f"Unsupported LLM provider: {provider}")
    return llm_map[provider]()
async def get_vector_store():
    """Initialize and return vector store with embedding model."""
    try:
        embedding = HuggingFaceEmbedding(model_name=settings.EMBEDDING_MODEL)
        vector_store = ChromaVectorStore(
            embedding_function=embedding.embed_documents,
            persist_directory=settings.CHROMA_PATH
        )
        return vector_store, embedding
    except Exception as e:
        logger.error(f"Error initializing vector store: {str(e)}")
        raise HTTPException(status_code=500, detail="Failed to initialize vector store")
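# NOTE: get_vector_store() re-creates the embedding model and Chroma client on
# every call; callers that need throughput may want to cache these instead.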
async def process_and_store_document(
    file_path: Path,
    vector_store: ChromaVectorStore,
    document_id: str
):
    """Process document and store in vector database."""
    try:
        processed_doc = await doc_processor.process_document(file_path)
        vector_store.add_documents(
            documents=processed_doc['chunks'],
            metadatas=[{
                'document_id': document_id,
                'chunk_id': i,
                'source': str(file_path.name),
                'metadata': processed_doc['metadata']
            } for i in range(len(processed_doc['chunks']))],
            ids=[f"{document_id}_chunk_{i}" for i in range(len(processed_doc['chunks']))]
        )
        return processed_doc
    finally:
        if file_path.exists():
            file_path.unlink()
async def store_chat_history(
    conversation_id: str,
    query: str,
    response: str,
    context: List[str],
    sources: List[Dict],
    llm_provider: str
):
    """Store chat history in database."""
    async with aiosqlite.connect('chat_history.db') as db:
        await db.execute(
            '''INSERT INTO chat_history
               (conversation_id, query, response, context, sources, llm_provider)
               VALUES (?, ?, ?, ?, ?, ?)''',
            (conversation_id, query, response, json.dumps(context),
             json.dumps(sources), llm_provider)
        )
        await db.commit()
# Endpoints
@app.post("/documents/upload", response_model=BatchUploadResponse)
async def upload_documents(
files: List[UploadFile] = File(...),
background_tasks: BackgroundTasks = BackgroundTasks()
):
"""Upload and process multiple documents"""
try:
vector_store, _ = await get_vector_store()
upload_dir = Path("temp_uploads")
upload_dir.mkdir(exist_ok=True)
processed_files = []
failed_files = []
for file in files:
try:
document_id = str(uuid.uuid4())
if not any(file.filename.lower().endswith(ext)
for ext in doc_processor.supported_formats):
failed_files.append({
"filename": file.filename,
"error": "Unsupported file format"
})
continue
temp_path = upload_dir / f"{document_id}_{file.filename}"
with open(temp_path, "wb") as buffer:
shutil.copyfileobj(file.file, buffer)
background_tasks.add_task(
process_and_store_document,
temp_path,
vector_store,
document_id
)
processed_files.append(
DocumentResponse(
message="Document queued for processing",
document_id=document_id,
status="processing",
document_info={
"original_filename": file.filename,
"size": os.path.getsize(temp_path),
"content_type": file.content_type
}
)
)
except Exception as e:
logger.error(f"Error processing file {file.filename}: {str(e)}")
failed_files.append({
"filename": file.filename,
"error": str(e)
})
return BatchUploadResponse(
message=f"Processed {len(processed_files)} documents with {len(failed_files)} failures",
processed_files=processed_files,
failed_files=failed_files
)
except Exception as e:
logger.error(f"Error in document upload: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
finally:
if upload_dir.exists() and not any(upload_dir.iterdir()):
upload_dir.rmdir()
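# Example request (assumes a local run on port 8000; accepted extensions come
# from doc_processor.supported_formats, and report.pdf is a placeholder):
#   curl -X POST http://localhost:8000/documents/upload -F "files=@report.pdf"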
@app.post("/chat", response_model=ChatResponse)
async def chat_endpoint(
request: ChatRequest,
background_tasks: BackgroundTasks
):
"""Chat endpoint with RAG support"""
try:
vector_store, embedding_model = await get_vector_store()
llm = get_llm_instance(request.llm_provider)
rag_agent = RAGAgent(
llm=llm,
embedding=embedding_model,
vector_store=vector_store
)
if request.stream:
return StreamingResponse(
rag_agent.generate_streaming_response(request.query),
media_type="text/event-stream"
)
response = await rag_agent.generate_response(
query=request.query,
temperature=request.temperature
)
conversation_id = request.conversation_id or str(uuid.uuid4())
background_tasks.add_task(
store_chat_history,
conversation_id,
request.query,
response.response,
response.context_docs,
response.sources,
request.llm_provider
)
return ChatResponse(
response=response.response,
context=response.context_docs,
sources=response.sources,
conversation_id=conversation_id,
timestamp=datetime.now(),
relevant_doc_scores=response.scores if hasattr(response, 'scores') else None
)
except Exception as e:
logger.error(f"Error in chat endpoint: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
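# Example request (assumes a local run on port 8000):
#   curl -X POST http://localhost:8000/chat \
#     -H "Content-Type: application/json" \
#     -d '{"query": "What do the uploaded documents say?", "llm_provider": "openai"}'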
@app.get("/chat/history/{conversation_id}")
async def get_conversation_history(conversation_id: str):
"""Get complete conversation history"""
async with aiosqlite.connect('chat_history.db') as db:
db.row_factory = aiosqlite.Row
async with db.execute(
'SELECT * FROM chat_history WHERE conversation_id = ? ORDER BY timestamp',
(conversation_id,)
) as cursor:
history = await cursor.fetchall()
if not history:
raise HTTPException(status_code=404, detail="Conversation not found")
return {
"conversation_id": conversation_id,
"messages": [dict(row) for row in history]
}
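# Example request (assumes a local run on port 8000):
#   curl http://localhost:8000/chat/history/<conversation_id>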
@app.post("/chat/summarize", response_model=SummaryResponse)
async def summarize_conversation(request: SummarizeRequest):
"""Generate a summary of a conversation"""
try:
async with aiosqlite.connect('chat_history.db') as db:
db.row_factory = aiosqlite.Row
async with db.execute(
'SELECT * FROM chat_history WHERE conversation_id = ? ORDER BY timestamp',
(request.conversation_id,)
) as cursor:
history = await cursor.fetchall()
if not history:
raise HTTPException(status_code=404, detail="Conversation not found")
messages = [{
'role': 'user' if msg['query'] else 'assistant',
'content': msg['query'] or msg['response'],
'timestamp': msg['timestamp'],
'sources': json.loads(msg['sources']) if msg['sources'] else None
} for msg in history]
summary = await summarizer.summarize_conversation(
messages,
include_metadata=request.include_metadata
)
return SummaryResponse(**summary)
except Exception as e:
logger.error(f"Error generating summary: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
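# Example request (conversation_id comes from a prior /chat response):
#   curl -X POST http://localhost:8000/chat/summarize \
#     -H "Content-Type: application/json" \
#     -d '{"conversation_id": "<conversation_id>", "include_metadata": true}'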
@app.post("/chat/feedback/{conversation_id}")
async def submit_feedback(
conversation_id: str,
feedback_request: FeedbackRequest
):
"""Submit feedback for a conversation"""
try:
async with aiosqlite.connect('chat_history.db') as db:
await db.execute(
'''UPDATE chat_history
SET feedback = ?, rating = ?
WHERE conversation_id = ?''',
(feedback_request.feedback, feedback_request.rating, conversation_id)
)
await db.commit()
return {"status": "Feedback submitted successfully"}
except Exception as e:
logger.error(f"Error submitting feedback: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
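# Example request (the rating scale is a client convention; the API only
# stores the integer):
#   curl -X POST http://localhost:8000/chat/feedback/<conversation_id> \
#     -H "Content-Type: application/json" \
#     -d '{"rating": 5, "feedback": "Accurate answer"}'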
@app.get("/health")
async def health_check():
"""Health check endpoint"""
return {"status": "healthy"}
# Startup event
@app.on_event("startup")
async def startup_event():
await init_db()
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)