# src/main.py
#
# FastAPI entry point for the RAG chatbot: wires together the document
# pipeline, the vector store, the LLM agent and the MongoDB chat history.
from datetime import datetime
from typing import List
import uuid

from fastapi import BackgroundTasks, FastAPI, File, HTTPException, UploadFile
from fastapi.responses import StreamingResponse

# Project-local modules
from src.agents.rag_agent import RAGAgent
from src.models.document import AllDocumentsResponse, StoredDocument
from src.utils.document_processor import DocumentProcessor
from src.utils.conversation_summarizer import ConversationSummarizer
from src.utils.logger import logger
from src.utils.llm_utils import get_llm_instance, get_vector_store
from src.db.mongodb_store import MongoDBStore
from src.implementations.document_service import DocumentService
from src.models import (
    ChatRequest,
    ChatResponse,
    DocumentResponse,
    BatchUploadResponse,
    SummarizeRequest,
    SummaryResponse,
    FeedbackRequest
)
from config.config import settings

app = FastAPI(title="RAG Chatbot API")

# Persistent store for conversations and feedback.
mongodb = MongoDBStore(settings.MONGODB_URI)

# Core processing components shared by every endpoint.
doc_processor = DocumentProcessor(
    chunk_size=1000,
    chunk_overlap=200,
    max_file_size=10 * 1024 * 1024  # 10 MiB per-file upload cap
)
summarizer = ConversationSummarizer()
document_service = DocumentService(doc_processor, mongodb)
async def upload_documents(
    files: List[UploadFile] = File(...),
    background_tasks: BackgroundTasks = None
):
    """Upload and process multiple documents.

    Args:
        files: Uploaded files to be chunked, embedded and indexed.
        background_tasks: FastAPI task queue for deferred work; a fresh
            instance is created when none is supplied.

    Returns:
        The response produced by the document service (batch upload result).

    Raises:
        HTTPException: 500 wrapping any processing failure.
    """
    # Bug fix: the original default `BackgroundTasks()` is a mutable default
    # argument — evaluated once and shared by every call. Use a None
    # sentinel and build a fresh instance per invocation instead.
    if background_tasks is None:
        background_tasks = BackgroundTasks()
    try:
        vector_store, _ = await get_vector_store()
        response = await document_service.process_documents(
            files,
            vector_store,
            background_tasks
        )
        return response
    except Exception as e:
        logger.error(f"Error in document upload: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        # Runs on both success and failure: release temporary resources
        # held by the document service.
        document_service.cleanup()
async def get_all_documents(include_embeddings: bool = False):
    """Return every document stored in the vector store.

    Args:
        include_embeddings (bool): When True, each document's embedding
            vector is included in the response payload.

    Raises:
        HTTPException: 500 wrapping any retrieval failure.
    """
    try:
        vector_store, _ = await get_vector_store()
        raw_docs = vector_store.get_all_documents(include_embeddings=include_embeddings)
        stored = [
            StoredDocument(
                id=entry['id'],
                text=entry['text'],
                embedding=entry.get('embedding'),
                metadata=entry.get('metadata')
            )
            for entry in raw_docs
        ]
        return AllDocumentsResponse(
            total_documents=len(raw_docs),
            documents=stored
        )
    except Exception as e:
        logger.error(f"Error retrieving documents: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
async def get_document_chunks(document_id: str):
    """Get all chunks for a specific document.

    Args:
        document_id: Identifier of the source document.

    Returns:
        Dict with the document id, chunk count, and the chunks themselves.

    Raises:
        HTTPException: 404 when the document has no chunks (unknown id),
            500 wrapping any other retrieval failure.
    """
    try:
        vector_store, _ = await get_vector_store()
        chunks = vector_store.get_document_chunks(document_id)
        if not chunks:
            raise HTTPException(status_code=404, detail="Document not found")
        return {
            "document_id": document_id,
            "total_chunks": len(chunks),
            "chunks": chunks
        }
    except HTTPException:
        # Bug fix: without this clause the 404 above was caught by the
        # generic handler below and re-raised as a 500.
        raise
    except Exception as e:
        logger.error(f"Error retrieving document chunks: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
async def chat_endpoint(
    request: ChatRequest,
    background_tasks: BackgroundTasks
):
    """Answer a chat query using retrieval-augmented generation.

    Streams the answer when `request.stream` is set; otherwise generates a
    complete response, persists the exchange to MongoDB, and returns it.

    Raises:
        HTTPException: 500 wrapping any generation or storage failure.
    """
    try:
        vector_store, embedding_model = await get_vector_store()
        llm = get_llm_instance(request.llm_provider)
        agent = RAGAgent(
            llm=llm,
            embedding=embedding_model,
            vector_store=vector_store
        )

        # Streaming mode short-circuits: no history is stored for streams.
        if request.stream:
            token_stream = agent.generate_streaming_response(request.query)
            return StreamingResponse(token_stream, media_type="text/event-stream")

        result = await agent.generate_response(
            query=request.query,
            temperature=request.temperature
        )

        # Reuse the caller's conversation id, or start a new conversation.
        conversation_id = request.conversation_id or str(uuid.uuid4())
        await mongodb.store_message(
            conversation_id=conversation_id,
            query=request.query,
            response=result.response,
            context=result.context_docs,
            sources=result.sources,
            llm_provider=request.llm_provider
        )

        # `scores` is optional on the agent's response object.
        doc_scores = getattr(result, 'scores', None)
        return ChatResponse(
            response=result.response,
            context=result.context_docs,
            sources=result.sources,
            conversation_id=conversation_id,
            timestamp=datetime.now(),
            relevant_doc_scores=doc_scores
        )
    except Exception as e:
        logger.error(f"Error in chat endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
async def get_conversation_history(conversation_id: str):
    """Return the full message history for a conversation.

    Raises:
        HTTPException: 404 when the conversation id is unknown or empty.
    """
    history = await mongodb.get_conversation_history(conversation_id)
    if not history:
        raise HTTPException(status_code=404, detail="Conversation not found")
    return {"conversation_id": conversation_id, "messages": history}
async def summarize_conversation(request: SummarizeRequest):
    """Generate a summary of a conversation.

    Args:
        request: Carries the conversation id and whether to include metadata.

    Returns:
        SummaryResponse built from the summarizer's output.

    Raises:
        HTTPException: 404 when the conversation has no messages,
            500 wrapping any summarization failure.
    """
    try:
        messages = await mongodb.get_messages_for_summary(request.conversation_id)
        if not messages:
            raise HTTPException(status_code=404, detail="Conversation not found")
        summary = await summarizer.summarize_conversation(
            messages,
            include_metadata=request.include_metadata
        )
        return SummaryResponse(**summary)
    except HTTPException:
        # Bug fix: re-raise untouched so the 404 above is not converted
        # to a 500 by the generic handler below.
        raise
    except Exception as e:
        logger.error(f"Error generating summary: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
async def submit_feedback(
    conversation_id: str,
    feedback_request: FeedbackRequest
):
    """Submit feedback for a conversation.

    Args:
        conversation_id: Conversation the feedback applies to.
        feedback_request: Free-text feedback plus a numeric rating.

    Returns:
        Dict confirming the feedback was stored.

    Raises:
        HTTPException: 404 when the conversation id is unknown,
            500 wrapping any storage failure.
    """
    try:
        success = await mongodb.update_feedback(
            conversation_id=conversation_id,
            feedback=feedback_request.feedback,
            rating=feedback_request.rating
        )
        if not success:
            raise HTTPException(status_code=404, detail="Conversation not found")
        return {"status": "Feedback submitted successfully"}
    except HTTPException:
        # Bug fix: without this clause the 404 above was swallowed by the
        # generic handler and surfaced to the client as a 500.
        raise
    except Exception as e:
        logger.error(f"Error submitting feedback: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
async def health_check():
    """Liveness probe: report that the service process is up."""
    return {"status": "healthy"}
if __name__ == "__main__":
    # Allow running the API directly (e.g. `python src/main.py`).
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)