import os
import logging
import pathlib
import time
import re
from typing import List

from pydantic import BaseModel
from fastapi import FastAPI
from fastapi.middleware import Middleware
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse

from .rag import ChatPDF

# NOTE(review): every origin, method and header is allowed. Fine for local
# development; restrict allow_origins before exposing this service publicly.
middleware = [
    Middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_methods=["*"],
        allow_headers=["*"],
    )
]

app = FastAPI(middleware=middleware)

# Uploaded files are staged here before being ingested by the assistant.
files_dir = os.path.expanduser("~/wtp_be_files/")
# exist_ok=True: otherwise startup fails whenever the directory already
# exists from a previous run.
os.makedirs(files_dir, exist_ok=True)

# One assistant instance shared by every request handled by this process.
session_assistant = ChatPDF()

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def _split_message(message):
    """Split *message* into words and the single whitespace chars between them.

    The capturing group keeps the separators, so ''.join() of the result
    reproduces *message* exactly; the client receives it word by word.
    """
    return re.split(r"(\s)", message)


def _message_response(message):
    """Wrap a fixed status *message* in a word-by-word event-stream response."""
    return StreamingResponse(
        astreamer(_split_message(message)), media_type="text/event-stream"
    )


def _safe_filename(filename):
    """Return the last path component of a client-supplied *filename*.

    Returns None when nothing usable remains ('', '.', '..'). Dropping any
    directory part prevents writes outside files_dir (e.g. '../../x').
    PureWindowsPath treats both '/' and '\\' as separators.
    """
    name = pathlib.PureWindowsPath(filename).name
    if name in ("", ".", ".."):
        return None
    return name


def astreamer(generator):
    """Yield each item of *generator* unchanged, logging elapsed milliseconds.

    A final "Over" line is logged when the stream ends. Because it sits in
    ``finally``, it is also logged if the consumer closes the generator early.
    """
    t0 = time.time()
    try:
        for chunk in generator:
            logger.info(
                "Chunk being yielded (time %dms)", int((time.time() - t0) * 1000)
            )
            yield chunk
    finally:
        logger.info("Over (time %dms)", int((time.time() - t0) * 1000))


@app.get("/query")
async def process_input(text: str):
    """Stream the assistant's answer to the query *text*.

    An empty or whitespace-only query, or a query sent while
    ``session_assistant.pdf_count`` is 0, gets a short explanatory message
    instead of an answer.
    """
    query = text.strip() if text else ""
    if not query:
        generator = _split_message("The provided query is empty.")
    elif session_assistant.pdf_count > 0:
        # NOTE(review): ask() is called directly inside an async handler. If it
        # blocks (retrieval / LLM call), it stalls the event loop; consider
        # starlette's run_in_threadpool. Confirm against ChatPDF.ask.
        streaming_response = session_assistant.ask(query)
        generator = streaming_response.response_gen
    else:
        generator = _split_message("Please add a PDF document first.")
    return StreamingResponse(astreamer(generator), media_type="text/event-stream")


class File(BaseModel):
    """JSON request body for /upload."""

    # Raw file content. NOTE(review): JSON cannot carry arbitrary binary
    # losslessly; confirm how the client encodes the PDF into this field.
    bytes: bytes
    # Client-supplied name; only its last path component is used on disk.
    filename: str


@app.post("/upload")
def upload(data: File):
    """Stage the uploaded file in files_dir and ingest that directory.

    The staged copy is deleted afterwards whether or not ingestion succeeded.
    A status message is streamed back to the client.
    """
    # BaseModel fields are attributes: the model has no .get() / ['key'].
    filename = _safe_filename(data.filename) if data.filename else None
    if not data.bytes or filename is None:
        return _message_response("Please provide both a file and a filename.")

    logger.info("Filename: %s", filename)
    path = os.path.join(files_dir, filename)
    try:
        with open(path, "wb") as f:
            f.write(data.bytes)
        session_assistant.ingest(files_dir)
    except Exception:
        # Request boundary: log the full traceback and report the failure
        # instead of claiming success.
        logger.exception("Failed to ingest %s", filename)
        message = "Failed to insert the file."
    else:
        message = "Files inserted successfully."
    finally:
        # Always remove the staged copy, so a file that failed ingestion is not
        # picked up again by the next ingest(files_dir).
        try:
            pathlib.Path(path).unlink(missing_ok=True)
        except OSError:
            logger.warning("Could not remove staged file %s", path)
    return _message_response(message)


# Previously also named `ping` and shadowed by the handler below. Route
# registration was unaffected, but the module name was misleading.
@app.get("/clear")
def clear():
    """Call session_assistant.clear() and stream a confirmation message."""
    session_assistant.clear()
    return _message_response(
        "All files have been cleared. The first query may take a little longer."
    )


@app.get("/")
def ping():
    """Return "Pong!" (presumably used as a liveness check)."""
    return "Pong!"