"""FastAPI application exposing the mapper, answerer and accelerator services.

Routes:
    GET  /            -- serves the static ``HTML`` page.
    WS   /accelerate  -- registers a websocket with the shared ``Accelerator``.
    POST /map         -- scores ``items`` against ``query`` with the ``Mapper``.
    WS   /answer      -- answers one text prompt, then closes the socket.
"""

from asyncio import sleep
from typing import Union

from fastapi import FastAPI
from fastapi.encoders import jsonable_encoder
from fastapi.websockets import WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse, JSONResponse

from accelerator import Accelerator
from answerer import Answerer
from mapper import Mapper

# Loading the mapper model is best-effort: the app still starts without it.
# ``mapper`` is bound to None on failure so that /map can report the problem
# explicitly instead of failing with a NameError on first use.
mapper: Union[Mapper, None]
try:
    mapper = Mapper("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")
except Exception as e:
    print(f"ERROR! cannot load Mapper model!\n{e}")
    mapper = None

answerer = Answerer(
    model="RWKV-5-World-3B-v2-20231118-ctx16k",
    vocab="rwkv_vocab_v20230424",
    strategy="cpu bf16",
    ctx_limit=16 * 1024,
)

accelerator = Accelerator()

app = FastAPI()

# Page body returned by the index route.
HTML = """

"""


@app.get("/")
def index():
    """Serve the static ``HTML`` page."""
    return HTMLResponse(HTML)


@app.websocket("/accelerate")
async def accelerate(ws: WebSocket):
    """Hand the websocket to the shared accelerator and keep the handler alive.

    The handler polls ``accelerator.connected()`` every 10 seconds and returns
    once it reports False.
    """
    await accelerator.connect(ws)
    # Presumably the accelerator owns the socket from here on and this loop
    # only keeps the connection open -- confirm against Accelerator.
    while accelerator.connected():
        await sleep(10)


# NOTE: the name ``map`` shadows the builtin within this module; it is kept
# because FastAPI derives the route name / operation id from it.
@app.post("/map")
def map(query: Union[str, None], items: Union[list[str], None]):
    """Score ``items`` against ``query`` using the mapper model.

    Returns:
        A JSON response with the JSON-encoded result of ``mapper(query, items)``,
        or a 503 JSON response when the mapper model failed to load at startup.
    """
    if mapper is None:
        return JSONResponse(
            {"detail": "Mapper model is not available"},
            status_code=503,
        )
    scores = mapper(query, items)
    return JSONResponse(jsonable_encoder(scores))


@app.websocket("/answer")
async def answer(ws: WebSocket):
    """Answer a single text prompt received over the websocket.

    When an accelerator is connected, the prompt is delegated to it and the
    whole reply is sent as one message. Otherwise the local answerer is used
    and each element it yields is sent as a separate message. The socket is
    closed afterwards; a client disconnect ends the handler silently.
    """
    await ws.accept()
    try:
        prompt = await ws.receive_text()
        if accelerator.connected():
            reply = await accelerator.accelerate(prompt)
            await ws.send_text(reply)
        else:
            # 32 is presumably a cap on generated length -- confirm against
            # Answerer.__call__.
            async for piece in answerer(prompt, 32):
                await ws.send_text(piece)
    except WebSocketDisconnect:
        # The peer is gone, so there is nothing left to close.
        return
    await ws.close()