from fastapi import FastAPI, UploadFile, File, HTTPException
from pydantic import BaseModel
from typing import List
import uvicorn
from aimakerspace.text_utils import CharacterTextSplitter, TextFileLoader, PDFLoader
from aimakerspace.openai_utils.prompts import (
    UserRolePrompt,
    SystemRolePrompt,
)
from aimakerspace.vectordatabase import VectorDatabase
from aimakerspace.openai_utils.chatmodel import ChatOpenAI
import os
import tempfile
import shutil
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
import json

app = FastAPI(title="RAG API", description="REST API for RAG-based Q&A system")

# Register the CORS middleware before any routes are defined
app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:3000"],  # React app's address
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Keep the same prompt templates
system_template = """\
Use the following context to answer a user's question.
If you cannot find the answer in the context, say you don't know the answer."""
system_role_prompt = SystemRolePrompt(system_template)

user_prompt_template = """\
Context:
{context}

Question:
{question}
"""
user_role_prompt = UserRolePrompt(user_prompt_template)


# Pydantic models for request/response
class Question(BaseModel):
    query: str


class Answer(BaseModel):
    response: str
    context: List[str]

    class Config:
        json_schema_extra = {
            "example": {
                "response": "This is a sample response",
                "context": ["Context piece 1", "Context piece 2"]
            }
        }


# Shared application state, kept near the top of the file after the imports
class AppState:
    def __init__(self):
        self.text_splitter = CharacterTextSplitter()
        self.vector_db = None
        self.qa_pipeline = None

    def has_pipeline(self):
        return self.qa_pipeline is not None


# Create a global app state
app_state = AppState()


class RetrievalAugmentedQAPipeline:
    def __init__(self, llm: ChatOpenAI, vector_db_retriever: VectorDatabase) -> None:
        self.llm = llm
        self.vector_db_retriever = vector_db_retriever

    async def arun_pipeline(self, user_query: str):
        context_list = self.vector_db_retriever.search_by_text(user_query, k=4)

        context_prompt = ""
        for context in context_list:
            context_prompt += context[0] + "\n"

        formatted_system_prompt = system_role_prompt.create_message()
        formatted_user_prompt = user_role_prompt.create_message(
            question=user_query, context=context_prompt
        )

        # Get the full response instead of streaming
        response = ""
        async for chunk in self.llm.astream([formatted_system_prompt, formatted_user_prompt]):
            response += chunk

        return {
            "response": response,
            "context": [str(context[0]) for context in context_list]  # Convert context to strings
        }


def process_file(file_path: str, file_name: str):
    if file_name.lower().endswith('.pdf'):
        loader = PDFLoader(file_path)
    else:
        loader = TextFileLoader(file_path)
    documents = loader.load_documents()
    texts = app_state.text_splitter.split_texts(documents)
    return texts


@app.post("/upload")
async def upload_file(file: UploadFile = File(...)):
    print("Starting file upload process...")  # Debug print

    if not file:
        print("No file provided")  # Debug print
        raise HTTPException(400, detail="No file provided")

    print(f"File received: {file.filename}")  # Debug print

    if not file.filename.lower().endswith(('.txt', '.pdf')):
        print(f"Invalid file type: {file.filename}")  # Debug print
        raise HTTPException(400, detail="Only .txt and .pdf files are supported")

    try:
        suffix = f".{file.filename.split('.')[-1]}"
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
            print(f"Created temp file: {temp_file.name}")  # Debug print
            content = await file.read()
            temp_file.write(content)
            temp_file.flush()

            try:
                print("Processing file...")  # Debug print
                texts = process_file(temp_file.name, file.filename)
                print(f"Got {len(texts)} text chunks")  # Debug print

                # Initialize vector database
                print("Initializing vector database...")  # Debug print
                app_state.vector_db = VectorDatabase()
                app_state.vector_db = await app_state.vector_db.abuild_from_list(texts)

                # Initialize QA pipeline
                print("Initializing QA pipeline...")  # Debug print
                chat_openai = ChatOpenAI()
                app_state.qa_pipeline = RetrievalAugmentedQAPipeline(
                    vector_db_retriever=app_state.vector_db,
                    llm=chat_openai
                )
                print("QA pipeline initialized successfully")  # Debug print

                return {"message": f"Successfully processed {file.filename}", "chunks": len(texts)}
            finally:
                try:
                    os.unlink(temp_file.name)
                    print("Temporary file cleaned up")  # Debug print
                except Exception as e:
                    print(f"Error cleaning up temporary file: {e}")
    except Exception as e:
        print(f"Error during file processing: {str(e)}")  # Debug print
        raise HTTPException(
            status_code=500,
            detail=f"Error processing file: {str(e)}"
        )


@app.post("/query", response_model=Answer)
async def query(question: Question):
    print(f"Received query: {question.query}")  # Debug print
    print(f"QA Pipeline exists: {app_state.has_pipeline()}")  # Debug print

    if not app_state.has_pipeline():
        print("No QA pipeline available")  # Debug print
        raise HTTPException(
            status_code=400,
            detail="Please upload a document first"
        )

    try:
        print("Starting query pipeline...")  # Debug print
        result = await app_state.qa_pipeline.arun_pipeline(question.query)
        print(f"Generated result: {result}")  # Debug print
        return result
    except Exception as e:
        print(f"Error in query endpoint: {str(e)}")  # Debug print
        raise HTTPException(
            status_code=500,
            detail=f"Error processing query: {str(e)}"
        )


@app.get("/status")
async def get_status():
    return {
        "ready": app_state.has_pipeline(),
        "message": "Document loaded and ready for queries" if app_state.has_pipeline() else "No document loaded"
    }


if __name__ == "__main__":
    uvicorn.run(
        "api:app",
        host="0.0.0.0",
        port=8000,
        reload=True,  # Enable auto-reload
        reload_dirs=["./"]  # Watch current directory for changes
    )