"""FastAPI server that indexes uploaded documents per session and streams RAG-grounded chat answers."""
import os
import tempfile
import uuid
from typing import Dict, Optional

import uvicorn
from fastapi import Body, Depends, FastAPI, File, HTTPException, UploadFile
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.security import APIKeyHeader
from fastapi.staticfiles import StaticFiles
from openai import AsyncOpenAI
from qdrant_client import QdrantClient

from aimakerspace.openai_utils.prompts import (
    SystemRolePrompt,
    UserRolePrompt,
)
from aimakerspace.text_utils import CharacterTextSplitter, PDFLoader, TextFileLoader
# Prompt templates for the RAG chat flow.
# The system prompt constrains the model to the retrieved context and tells it
# to admit when the answer is not present.
system_template = """\
Use the following context to answer a users question.
If you cannot find the answer in the context, say you don't know the answer.
"""
system_role_prompt = SystemRolePrompt(system_template)

# The user prompt interpolates the retrieved context and the user's question;
# get_response fills {context} and {question} via create_message().
user_prompt_template = """\
Context:
{context}
Question:
{question}
"""
user_role_prompt = UserRolePrompt(user_prompt_template)
app = FastAPI()
openai = AsyncOpenAI()  # reads OPENAI_API_KEY from the environment
vector_db = QdrantClient(":memory:")  # ephemeral, per-process vector store
text_splitter = CharacterTextSplitter()

# In-memory per-session state, keyed by the X-Session-ID header value.
# NOTE(review): the "vector_db"/"vector_db_retriever" slots created in
# get_session are never populated elsewhere in this file — presumably dead
# state; confirm before relying on them.
sessions: Dict[str, dict] = {}

# Session id travels in the X-Session-ID header; auto_error=False lets a
# missing header fall through so get_session can mint a new session instead
# of FastAPI rejecting the request outright.
api_key_header = APIKeyHeader(name="X-Session-ID", auto_error=False)
async def get_session(session_id: Optional[str] = Depends(api_key_header)):
    """Resolve (or create) the session identified by the X-Session-ID header.

    Returns:
        A ``(session_id, session_state)`` tuple.

    Raises:
        HTTPException: 404 when a session id is supplied but unknown.
    """
    if session_id:
        # Known-client path: the id must already exist.
        if session_id not in sessions:
            raise HTTPException(status_code=404, detail="Session not found")
        return session_id, sessions[session_id]

    # No header supplied — mint a fresh session for this client.
    fresh_id = str(uuid.uuid4())
    sessions[fresh_id] = {
        "vector_db": None,
        "vector_db_retriever": None,
    }
    return fresh_id, sessions[fresh_id]
def process_file(file: UploadFile):
    """Persist an uploaded file to a temp path, then load and chunk its text.

    Args:
        file: The uploaded file; ``.pdf`` goes through PDFLoader, anything
            else through TextFileLoader.

    Returns:
        The list of text chunks produced by the module-level text_splitter.
    """
    print(f"Processing file: {file.filename}")
    # Preserve the real extension so the loader dispatch below works; fall
    # back to ".txt" for extensionless names (the old split('.')[-1] approach
    # produced a bogus suffix like ".myfile" in that case).
    suffix = os.path.splitext(file.filename)[1] or ".txt"
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
        temp_file.write(file.file.read())
        temp_file.flush()
        temp_path = temp_file.name
    print(f"Created temporary file at: {temp_path}")
    # Everything after creation sits inside try/finally so the temp file is
    # removed even if loader construction or parsing raises (the original
    # leaked the file when the loader constructor failed).
    try:
        if suffix.lower() == ".pdf":
            loader = PDFLoader(temp_path)
        else:
            loader = TextFileLoader(temp_path)
        documents = loader.load_documents()
        return text_splitter.split_texts(documents)
    finally:
        try:
            os.unlink(temp_path)
        except OSError as e:
            # Best-effort cleanup; never mask the real error with a cleanup one.
            print(f"Error cleaning up temporary file: {e}")
async def get_response(msg: str, session_id: str, vector_db: QdrantClient):
    """Answer *msg* using the session's indexed documents, streamed token by token.

    Queries the session's Qdrant collection for the top-4 matching chunks,
    builds the system/user prompt pair, and returns a StreamingResponse over
    the OpenAI chat-completion stream.
    """
    hits = vector_db.query(
        collection_name=session_id,
        query_text=msg,
        limit=4,
    )
    # Concatenate retrieved chunks, one per line (trailing newline included).
    context_text = "".join(hit.document + "\n" for hit in hits)

    messages = [
        system_role_prompt.create_message(),
        user_role_prompt.create_message(question=msg, context=context_text),
    ]
    stream = await openai.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=messages,
        temperature=0.0,
        stream=True,
    )

    async def token_stream():
        # Relay each non-empty delta, then a final empty sentinel chunk.
        async for chunk in stream:
            delta = chunk.choices[0].delta.content
            if delta is not None:
                yield delta
        yield ""

    return StreamingResponse(token_stream(), media_type="text/event-stream")
# NOTE(review): this handler was defined but never registered on the app, so
# it was unreachable. The route path below is a plausible choice — confirm it
# against the calls made by the dist/ frontend.
@app.post("/api/chat")
async def get_bot_response(
    msg: str = Body(...),
    session_data: tuple = Depends(get_session)
):
    """Stream a chat answer grounded in the session's uploaded documents.

    Args:
        msg: The user's question (request body).
        session_data: ``(session_id, state)`` injected by get_session.

    Returns:
        The StreamingResponse produced by get_response.
    """
    session_id, _ = session_data
    print(f"Session ID: {session_id}")
    return await get_response(msg, session_id, vector_db)
# NOTE(review): this handler was defined but never registered on the app, so
# it was unreachable. The route path below is a plausible choice — confirm it
# against the calls made by the dist/ frontend.
@app.post("/api/upload")
async def get_file_response(
    file: UploadFile = File(..., description="A text file to process"),
    session_data: tuple = Depends(get_session)
):
    """Chunk an uploaded file and index it into the session's collection.

    Args:
        file: The uploaded text or PDF file.
        session_data: ``(session_id, state)`` injected by get_session.

    Returns:
        A success payload with the session id, or a 422 JSONResponse when
        processing/indexing fails.
    """
    session_id, _ = session_data
    print(f"Session ID: {session_id}")
    if not file.filename:
        return {"error": "No file uploaded"}
    try:
        chunks = process_file(file)
        # One Qdrant collection per session keeps uploads isolated.
        vector_db.add(
            collection_name=session_id,
            documents=chunks,
        )
        return {
            "message": "File processed successfully",
            "session_id": session_id
        }
    except Exception as e:
        # Boundary handler: surface the failure to the client as a 422.
        return JSONResponse(
            status_code=422,
            content={"detail": str(e)}
        )
# Serve the built frontend; API routes registered above take precedence over
# this catch-all mount. The original also did `app.get("/")(StaticFiles(...))`,
# which registers a StaticFiles ASGI app as an endpoint handler — that is
# incorrect and redundant with the mount, so it has been removed.
app.mount("/", StaticFiles(directory="dist", html=True), name="static")

if __name__ == "__main__":
    uvicorn.run("server:app")