kth-qa / main.py
erseux's picture
new huggingface structure
cc0bd0f
raw
history blame
2.57 kB
import json
import logging
from magic.conversational import question_handler
from schema import Answer
logger = logging.getLogger()
logging.basicConfig(encoding='utf-8', level=logging.INFO)
from pathlib import Path
from fastapi import FastAPI, Request
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from fastapi.templating import Jinja2Templates
from fastapi.staticfiles import StaticFiles
from starlette.routing import WebSocketRoute
import uvicorn
from schema import Question
from config import State
import arel
# --- Setup ---
# hot reload
async def reload_data():
print("Reloading server data...")
BASE_PATH = Path(__file__).resolve().parent
static_path = str(BASE_PATH / "static")
template_path = str(BASE_PATH / "templates")
hotreload = arel.HotReload(
paths=[
arel.Path(static_path),
arel.Path(template_path),
],
)
state = State()
app = FastAPI(
routes=[WebSocketRoute("/hot-reload", hotreload, name="hot-reload")],
on_startup=[hotreload.startup],
on_shutdown=[hotreload.shutdown],
)
# templates
app.mount("/static", StaticFiles(directory="static"), name="static")
BASE_PATH = Path(__file__).resolve().parent
templates = Jinja2Templates(directory=template_path)
templates.env.globals["DEBUG"] = True
templates.env.globals["hotreload"] = hotreload
# CORS
origins = [
"http://localhost",
"http://localhost:5001",
]
app.add_middleware(
CORSMiddleware,
allow_origins=origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# test questions
with open("test_response.json", "r") as f:
test_questions = json.load(f)
# --- Routes ---
@app.get("/", response_class=HTMLResponse)
def index(request: Request):
return templates.TemplateResponse(
"index.html",
{"request": request}
)
@app.post("/api/ask", response_class=JSONResponse)
async def ask(question: Question):
question_str = question.question
if question_str in test_questions:
return test_questions[question_str]
answer = None
try:
answer: Answer = await question_handler(question, state)
except Exception as e:
logger.exception(e)
if not answer:
return JSONResponse(status_code=404, content={"answer": "Something went wrong."})
return answer.dict(include={"answer", "urls"})
if __name__ == "__main__":
uvicorn.run("kth_qa:app", host="localhost", port=5001, reload=True, reload_excludes=['files/', 'logs/'], reload_dirs=['/templates', '/static'])