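"""FastAPI service exposing upload, query, and status endpoints for a
retrieval-augmented Q&A (RAG) pipeline built on the aimakerspace utilities."""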
from fastapi import FastAPI, UploadFile, File, HTTPException
from pydantic import BaseModel
from typing import List
import uvicorn
from aimakerspace.text_utils import CharacterTextSplitter, TextFileLoader, PDFLoader
from aimakerspace.openai_utils.prompts import (
UserRolePrompt,
SystemRolePrompt,
)
from aimakerspace.vectordatabase import VectorDatabase
from aimakerspace.openai_utils.chatmodel import ChatOpenAI
import os
import tempfile
from fastapi.middleware.cors import CORSMiddleware
app = FastAPI(title="RAG API", description="REST API for RAG-based Q&A system")
# Move CORS middleware setup to the top, before any routes
app.add_middleware(
CORSMiddleware,
allow_origins=["http://localhost:3000"], # React app's address
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Keep the same prompt templates
system_template = """\
Use the following context to answer a user's question. If you cannot find the answer in the context, say you don't know the answer."""
system_role_prompt = SystemRolePrompt(system_template)
user_prompt_template = """\
Context:
{context}
Question:
{question}
"""
user_role_prompt = UserRolePrompt(user_prompt_template)
# Pydantic models for request/response
class Question(BaseModel):
query: str
class Answer(BaseModel):
response: str
context: List[str]
class Config:
json_schema_extra = {
"example": {
"response": "This is a sample response",
"context": ["Context piece 1", "Context piece 2"]
}
}
# Add this class near the top of the file, after imports
class AppState:
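    """Shared application state: the text splitter plus the vector database and
    QA pipeline that /upload builds and /query and /status read."""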
def __init__(self):
self.text_splitter = CharacterTextSplitter()
self.vector_db = None
self.qa_pipeline = None
def has_pipeline(self):
return self.qa_pipeline is not None
# Create a global app state
app_state = AppState()
class RetrievalAugmentedQAPipeline:
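    """Answers a query by retrieving relevant chunks from the vector database
    and passing them, together with the question, to the chat model."""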
    def __init__(self, llm: ChatOpenAI, vector_db_retriever: VectorDatabase) -> None:
self.llm = llm
self.vector_db_retriever = vector_db_retriever
async def arun_pipeline(self, user_query: str):
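        """Retrieve the top matching chunks, format the system and user prompts,
        and accumulate the model's streamed output into a single response."""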
context_list = self.vector_db_retriever.search_by_text(user_query, k=4)
context_prompt = ""
for context in context_list:
context_prompt += context[0] + "\n"
formatted_system_prompt = system_role_prompt.create_message()
formatted_user_prompt = user_role_prompt.create_message(
question=user_query, context=context_prompt
)
        # Aggregate the streamed chunks into one full response (no streaming to the client)
response = ""
async for chunk in self.llm.astream([formatted_system_prompt, formatted_user_prompt]):
response += chunk
return {
"response": response,
"context": [str(context[0]) for context in context_list] # Convert context to strings
}
def process_file(file_path: str, file_name: str):
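    """Load a .pdf or .txt file from disk and split it into text chunks."""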
if file_name.lower().endswith('.pdf'):
loader = PDFLoader(file_path)
else:
loader = TextFileLoader(file_path)
documents = loader.load_documents()
texts = app_state.text_splitter.split_texts(documents)
return texts
@app.post("/upload")
async def upload_file(file: UploadFile = File(...)):
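    """Accept a .txt or .pdf upload, chunk it, build the vector database, and
    initialise the QA pipeline; the temporary file is removed afterwards."""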
print("Starting file upload process...") # Debug print
if not file:
print("No file provided") # Debug print
raise HTTPException(400, detail="No file provided")
print(f"File received: {file.filename}") # Debug print
if not file.filename.lower().endswith(('.txt', '.pdf')):
print(f"Invalid file type: {file.filename}") # Debug print
raise HTTPException(400, detail="Only .txt and .pdf files are supported")
try:
suffix = f".{file.filename.split('.')[-1]}"
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
print(f"Created temp file: {temp_file.name}") # Debug print
content = await file.read()
temp_file.write(content)
temp_file.flush()
try:
print("Processing file...") # Debug print
texts = process_file(temp_file.name, file.filename)
print(f"Got {len(texts)} text chunks") # Debug print
# Initialize vector database
print("Initializing vector database...") # Debug print
app_state.vector_db = VectorDatabase()
app_state.vector_db = await app_state.vector_db.abuild_from_list(texts)
# Initialize QA pipeline
print("Initializing QA pipeline...") # Debug print
chat_openai = ChatOpenAI()
app_state.qa_pipeline = RetrievalAugmentedQAPipeline(
vector_db_retriever=app_state.vector_db,
llm=chat_openai
)
print("QA pipeline initialized successfully") # Debug print
return {"message": f"Successfully processed {file.filename}", "chunks": len(texts)}
finally:
try:
os.unlink(temp_file.name)
print("Temporary file cleaned up") # Debug print
except Exception as e:
print(f"Error cleaning up temporary file: {e}")
except Exception as e:
print(f"Error during file processing: {str(e)}") # Debug print
raise HTTPException(
status_code=500,
detail=f"Error processing file: {str(e)}"
)
@app.post("/query", response_model=Answer)
async def query(question: Question):
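    """Run the QA pipeline on the question; returns 400 if no document has been
    uploaded yet and 500 if the pipeline fails."""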
print(f"Received query: {question.query}") # Debug print
print(f"QA Pipeline exists: {app_state.has_pipeline()}") # Debug print
if not app_state.has_pipeline():
print("No QA pipeline available") # Debug print
raise HTTPException(
status_code=400,
detail="Please upload a document first"
)
try:
print("Starting query pipeline...") # Debug print
result = await app_state.qa_pipeline.arun_pipeline(question.query)
print(f"Generated result: {result}") # Debug print
return result
except Exception as e:
print(f"Error in query endpoint: {str(e)}") # Debug print
raise HTTPException(
status_code=500,
detail=f"Error processing query: {str(e)}"
)
@app.get("/status")
async def get_status():
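    """Report whether a document has been processed and queries can be served."""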
return {
"ready": app_state.has_pipeline(),
"message": "Document loaded and ready for queries" if app_state.has_pipeline() else "No document loaded"
}
if __name__ == "__main__":
uvicorn.run(
"api:app",
host="0.0.0.0",
port=8000,
reload=True, # Enable auto-reload
reload_dirs=["./"] # Watch current directory for changes
)
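# Example usage once the server is running locally (the file name and question
# below are placeholders, shown only as an illustration of the endpoints above):
#   curl -X POST -F "file=@sample.pdf" http://localhost:8000/upload
#   curl -X POST -H "Content-Type: application/json" \
#        -d '{"query": "What is this document about?"}' http://localhost:8000/query
#   curl http://localhost:8000/status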