# src/main.py
from fastapi import FastAPI, UploadFile, File, HTTPException, Depends, BackgroundTasks
from fastapi.responses import StreamingResponse, JSONResponse
from pydantic import BaseModel
from typing import List, Optional, AsyncGenerator, Dict
import asyncio
import json
import uuid
from datetime import datetime
import aiosqlite
from pathlib import Path
import shutil
import os
# Import custom modules
from .agents.rag_agent import RAGAgent
from .llms.openai_llm import OpenAILanguageModel
from .llms.ollama_llm import OllamaLanguageModel
from .llms.bert_llm import BERTLanguageModel
from .llms.falcon_llm import FalconLanguageModel
from .llms.llama_llm import LlamaLanguageModel
from .embeddings.huggingface_embedding import HuggingFaceEmbedding
from .vectorstores.chroma_vectorstore import ChromaVectorStore
from .utils.document_processor import DocumentProcessor
from .utils.conversation_summarizer import ConversationSummarizer
from .utils.logger import logger
from config.config import settings
app = FastAPI(title="RAG Chatbot API")
# Initialize core components
doc_processor = DocumentProcessor(
chunk_size=1000,
chunk_overlap=200,
max_file_size=10 * 1024 * 1024
)
summarizer = ConversationSummarizer()
# Pydantic models
class ChatRequest(BaseModel):
query: str
llm_provider: str = 'openai'
max_context_docs: int = 3
temperature: float = 0.7
stream: bool = False
conversation_id: Optional[str] = None
class ChatResponse(BaseModel):
response: str
context: Optional[List[str]] = None
sources: Optional[List[Dict[str, str]]] = None
conversation_id: str
timestamp: datetime
relevant_doc_scores: Optional[List[float]] = None
class DocumentResponse(BaseModel):
message: str
document_id: str
status: str
document_info: Optional[dict] = None
class BatchUploadResponse(BaseModel):
message: str
processed_files: List[DocumentResponse]
failed_files: List[dict]
class SummarizeRequest(BaseModel):
conversation_id: str
include_metadata: bool = True
class SummaryResponse(BaseModel):
summary: str
key_insights: Dict
metadata: Optional[Dict] = None
class FeedbackRequest(BaseModel):
rating: int
feedback: Optional[str] = None
# Database initialization
async def init_db():
async with aiosqlite.connect('chat_history.db') as db:
await db.execute('''
CREATE TABLE IF NOT EXISTS chat_history (
id INTEGER PRIMARY KEY AUTOINCREMENT,
conversation_id TEXT,
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
query TEXT,
response TEXT,
context TEXT,
sources TEXT,
llm_provider TEXT,
feedback TEXT,
rating INTEGER
)
''')
await db.commit()
# Utility functions
def get_llm_instance(provider: str):
"""Get LLM instance based on provider"""
llm_map = {
'openai': lambda: OpenAILanguageModel(api_key=settings.OPENAI_API_KEY),
'ollama': lambda: OllamaLanguageModel(base_url=settings.OLLAMA_BASE_URL),
'bert': lambda: BERTLanguageModel(),
'falcon': lambda: FalconLanguageModel(),
'llama': lambda: LlamaLanguageModel(),
}
if provider not in llm_map:
raise ValueError(f"Unsupported LLM provider: {provider}")
return llm_map[provider]()
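# Illustrative usage (hypothetical call, not part of the request flow):
#   llm = get_llm_instance('ollama')  # raises ValueError for unsupported names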
async def get_vector_store():
"""Initialize and return vector store with embedding model."""
try:
embedding = HuggingFaceEmbedding(model_name=settings.EMBEDDING_MODEL)
vector_store = ChromaVectorStore(
embedding_function=embedding.embed_documents,
persist_directory=settings.CHROMA_PATH
)
return vector_store, embedding
except Exception as e:
logger.error(f"Error initializing vector store: {str(e)}")
raise HTTPException(status_code=500, detail="Failed to initialize vector store")
async def process_and_store_document(
file_path: Path,
vector_store: ChromaVectorStore,
document_id: str
):
"""Process document and store in vector database."""
    try:
        processed_doc = await doc_processor.process_document(file_path)
        chunks = processed_doc['chunks']
        vector_store.add_documents(
            documents=chunks,
            metadatas=[{
                'document_id': document_id,
                'chunk_id': i,
                'source': str(file_path.name),
                'metadata': processed_doc['metadata']
            } for i in range(len(chunks))],
            ids=[f"{document_id}_chunk_{i}" for i in range(len(chunks))]
        )
        return processed_doc
    except Exception as e:
        # This runs as a background task, so there is no request to report
        # errors to; log the failure instead of letting it vanish silently.
        logger.error(f"Background processing failed for {file_path.name}: {str(e)}")
        raise
    finally:
        # Remove the temporary upload whether or not processing succeeded.
        if file_path.exists():
            file_path.unlink()
async def store_chat_history(
conversation_id: str,
query: str,
response: str,
context: List[str],
sources: List[Dict],
llm_provider: str
):
"""Store chat history in database"""
async with aiosqlite.connect('chat_history.db') as db:
await db.execute(
'''INSERT INTO chat_history
(conversation_id, query, response, context, sources, llm_provider)
VALUES (?, ?, ?, ?, ?, ?)''',
(conversation_id, query, response, json.dumps(context),
json.dumps(sources), llm_provider)
)
await db.commit()
# Endpoints
@app.post("/documents/upload", response_model=BatchUploadResponse)
async def upload_documents(
    background_tasks: BackgroundTasks,
    files: List[UploadFile] = File(...)
):
    """Upload and process multiple documents"""
    # Create upload_dir before the try block so the finally clause can
    # always reference it, even if vector store initialization fails.
    upload_dir = Path("temp_uploads")
    upload_dir.mkdir(exist_ok=True)
    try:
        vector_store, _ = await get_vector_store()
processed_files = []
failed_files = []
for file in files:
try:
document_id = str(uuid.uuid4())
                if not file.filename or not any(
                    file.filename.lower().endswith(ext)
                    for ext in doc_processor.supported_formats
                ):
failed_files.append({
"filename": file.filename,
"error": "Unsupported file format"
})
continue
temp_path = upload_dir / f"{document_id}_{file.filename}"
with open(temp_path, "wb") as buffer:
shutil.copyfileobj(file.file, buffer)
background_tasks.add_task(
process_and_store_document,
temp_path,
vector_store,
document_id
)
processed_files.append(
DocumentResponse(
message="Document queued for processing",
document_id=document_id,
status="processing",
document_info={
"original_filename": file.filename,
"size": os.path.getsize(temp_path),
"content_type": file.content_type
}
)
)
except Exception as e:
logger.error(f"Error processing file {file.filename}: {str(e)}")
failed_files.append({
"filename": file.filename,
"error": str(e)
})
return BatchUploadResponse(
message=f"Processed {len(processed_files)} documents with {len(failed_files)} failures",
processed_files=processed_files,
failed_files=failed_files
)
    except HTTPException:
        # Preserve intentional HTTP errors (e.g. from get_vector_store)
        # instead of collapsing them into a generic 500.
        raise
    except Exception as e:
        logger.error(f"Error in document upload: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
finally:
if upload_dir.exists() and not any(upload_dir.iterdir()):
upload_dir.rmdir()
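# Example request (assuming the API is served locally on port 8000; the
# filenames are illustrative). List[UploadFile] accepts repeated form fields:
#   curl -X POST http://localhost:8000/documents/upload \
#        -F "files=@report.pdf" -F "files=@notes.txt"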
@app.post("/chat", response_model=ChatResponse)
async def chat_endpoint(
request: ChatRequest,
background_tasks: BackgroundTasks
):
"""Chat endpoint with RAG support"""
try:
vector_store, embedding_model = await get_vector_store()
llm = get_llm_instance(request.llm_provider)
rag_agent = RAGAgent(
llm=llm,
embedding=embedding_model,
vector_store=vector_store
)
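        # Streaming requests return the SSE stream directly: no
        # conversation_id is issued and no history is stored for them.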
if request.stream:
return StreamingResponse(
rag_agent.generate_streaming_response(request.query),
media_type="text/event-stream"
)
response = await rag_agent.generate_response(
query=request.query,
temperature=request.temperature
)
conversation_id = request.conversation_id or str(uuid.uuid4())
background_tasks.add_task(
store_chat_history,
conversation_id,
request.query,
response.response,
response.context_docs,
response.sources,
request.llm_provider
)
return ChatResponse(
response=response.response,
context=response.context_docs,
sources=response.sources,
conversation_id=conversation_id,
timestamp=datetime.now(),
            relevant_doc_scores=getattr(response, 'scores', None)
)
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error in chat endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
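# Example request (illustrative query; assumes a local server on port 8000):
#   curl -X POST http://localhost:8000/chat \
#        -H "Content-Type: application/json" \
#        -d '{"query": "What does the report conclude?", "llm_provider": "openai"}'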
@app.get("/chat/history/{conversation_id}")
async def get_conversation_history(conversation_id: str):
"""Get complete conversation history"""
async with aiosqlite.connect('chat_history.db') as db:
db.row_factory = aiosqlite.Row
async with db.execute(
'SELECT * FROM chat_history WHERE conversation_id = ? ORDER BY timestamp',
(conversation_id,)
) as cursor:
history = await cursor.fetchall()
if not history:
raise HTTPException(status_code=404, detail="Conversation not found")
return {
"conversation_id": conversation_id,
"messages": [dict(row) for row in history]
}
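# Example request (conversation IDs are the UUIDs returned by /chat):
#   curl http://localhost:8000/chat/history/<conversation_id>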
@app.post("/chat/summarize", response_model=SummaryResponse)
async def summarize_conversation(request: SummarizeRequest):
"""Generate a summary of a conversation"""
try:
async with aiosqlite.connect('chat_history.db') as db:
db.row_factory = aiosqlite.Row
async with db.execute(
'SELECT * FROM chat_history WHERE conversation_id = ? ORDER BY timestamp',
(request.conversation_id,)
) as cursor:
history = await cursor.fetchall()
if not history:
raise HTTPException(status_code=404, detail="Conversation not found")
        # Each row stores both the user query and the assistant response,
        # so expand every row into a user message and an assistant message.
        messages = []
        for msg in history:
            messages.append({'role': 'user', 'content': msg['query'],
                             'timestamp': msg['timestamp'], 'sources': None})
            messages.append({'role': 'assistant', 'content': msg['response'],
                             'timestamp': msg['timestamp'],
                             'sources': json.loads(msg['sources']) if msg['sources'] else None})
summary = await summarizer.summarize_conversation(
messages,
include_metadata=request.include_metadata
)
return SummaryResponse(**summary)
    except HTTPException:
        # Re-raise the 404 above instead of converting it into a 500.
        raise
    except Exception as e:
        logger.error(f"Error generating summary: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
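# Example request (uses a conversation_id previously returned by /chat):
#   curl -X POST http://localhost:8000/chat/summarize \
#        -H "Content-Type: application/json" \
#        -d '{"conversation_id": "<conversation_id>", "include_metadata": true}'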
@app.post("/chat/feedback/{conversation_id}")
async def submit_feedback(
conversation_id: str,
feedback_request: FeedbackRequest
):
"""Submit feedback for a conversation"""
    try:
        async with aiosqlite.connect('chat_history.db') as db:
            cursor = await db.execute(
                '''UPDATE chat_history
                SET feedback = ?, rating = ?
                WHERE conversation_id = ?''',
                (feedback_request.feedback, feedback_request.rating, conversation_id)
            )
            await db.commit()
            # Report a 404 when no rows matched the conversation_id.
            if cursor.rowcount == 0:
                raise HTTPException(status_code=404, detail="Conversation not found")
        return {"status": "Feedback submitted successfully"}
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error submitting feedback: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
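# Example request (the schema only requires an integer rating; the 1-5 scale
# shown here is an assumed client convention):
#   curl -X POST http://localhost:8000/chat/feedback/<conversation_id> \
#        -H "Content-Type: application/json" \
#        -d '{"rating": 5, "feedback": "Accurate and well-sourced answer."}'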
@app.get("/health")
async def health_check():
"""Health check endpoint"""
return {"status": "healthy"}
# Startup event
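# Note: on_event is deprecated in recent FastAPI releases in favor of the
# lifespan handler; it is kept here because it still works and is concise.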
@app.on_event("startup")
async def startup_event():
await init_db()
if __name__ == "__main__":
import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
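# To run the API with the uvicorn CLI instead (the module path assumes this
# file lives at src/main.py under the project root, per the header comment):
#   uvicorn src.main:app --host 0.0.0.0 --port 8000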