|
from __future__ import annotations

import asyncio
import uuid
from dataclasses import dataclass
from typing import Callable

from fastapi import FastAPI, HTTPException, WebSocket
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
|
|
|
# Application object that the route and startup decorators in this module
# register their handlers on.
app = FastAPI()
|
|
|
@dataclass
class Event:
    """One client job, from the moment it is queued until it has finished.

    The endpoints build an ``Event`` and append it to ``queue``; the
    dispatcher later moves it into ``active_jobs`` and ``process_event``
    delivers its outputs according to ``mode``.
    """

    # Identifier generated by the endpoint that created the event.
    session_id: str

    # Text payload supplied by the client.
    data: str

    # Buffer that receives outputs in "sse" mode; ``None`` for "ws" events.
    outputs: asyncio.Queue[str] | None

    # Delivery channel: "sse" or "ws".
    mode: str

    # Connection that outputs are sent on in "ws" mode.
    websocket: WebSocket | None = None

    # Flipped to True by ``process_event`` once the job is done.
    completed: bool = False
|
|
|
# Number of jobs that may run at the same time (size of the slot pool below).
_MAX_ACTIVE_JOBS = 1000

# Events waiting for a free slot, oldest first.
queue: list[Event] = []

# Fixed pool of job slots; ``None`` marks a slot that is free.
active_jobs: list[Event | None] = [None] * _MAX_ACTIVE_JOBS
|
|
|
|
|
# Strong references to in-flight background tasks. The event loop itself only
# keeps weak references, so a task nobody else holds could be garbage
# collected before it finishes.
_background_tasks: set[asyncio.Task] = set()


def run_coro_in_background(func: Callable, *args, **kwargs) -> asyncio.Task:
    """Schedule ``func(*args, **kwargs)`` as a task and return it immediately.

    Args:
        func: Coroutine function to call.
        *args: Positional arguments forwarded to ``func``.
        **kwargs: Keyword arguments forwarded to ``func``.

    Returns:
        The created ``asyncio.Task``. Callers may ignore it; the task is kept
        alive internally until it completes.
    """
    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        # No loop is running in this thread: keep the legacy behaviour of
        # scheduling on the thread's current event loop.
        loop = asyncio.get_event_loop()

    task = loop.create_task(func(*args, **kwargs))
    _background_tasks.add(task)
    # Drop the reference once the task is done so the set does not grow.
    task.add_done_callback(_background_tasks.discard)
    return task
|
|
|
async def queue_process():
    """Dispatcher loop that moves queued events into free job slots.

    Whenever an event is waiting and ``active_jobs`` has a free slot, the
    oldest queued event is placed in the first free slot and
    ``process_event`` is started for it in the background. Otherwise the
    loop sleeps for 50 ms before checking again. Never returns.
    """
    while True:
        nothing_to_dispatch = not queue or None not in active_jobs
        if nothing_to_dispatch:
            await asyncio.sleep(0.05)
            continue

        free_slot = active_jobs.index(None)
        next_event = queue.pop(0)
        active_jobs[free_slot] = next_event
        run_coro_in_background(process_event, next_event)
|
|
|
@app.on_event("startup")
async def startup_event():
    """Start the queue dispatcher as a background task when the app starts."""
    run_coro_in_background(queue_process)
|
|
|
async def number_generator(_):
    """Yield 500 messages of growing length, pausing 10 ms after each one.

    The n-th message (n = 1..500) is the string ``"Lorem "`` repeated n
    times. The argument is accepted but not used.
    """
    repeats = 0
    while repeats < 500:
        repeats += 1
        yield "Lorem " * repeats
        await asyncio.sleep(0.01)
|
|
|
async def process_event(event: Event):
    """Run one job and deliver each output over the event's channel.

    Outputs come from ``number_generator(event.data)``. In "sse" mode they
    are pushed onto ``event.outputs``, followed by a ``None`` end-of-stream
    sentinel; in "ws" mode they are sent over ``event.websocket``.

    Cleanup always runs, even when sending fails or the task is cancelled:
    the event is marked completed and its slot in ``active_jobs`` is freed.
    Without this, one failed send would leak the slot permanently and leave
    the WebSocket handler waiting on ``event.completed`` forever.

    Args:
        event: The job to run; it must currently occupy a slot in
            ``active_jobs``.

    Raises:
        ValueError: If ``event`` is not present in ``active_jobs``.
    """
    try:
        async for output in number_generator(event.data):
            if event.mode == "sse":
                event.outputs.put_nowait(output)
            elif event.mode == "ws":
                await event.websocket.send_text(output)
    finally:
        if event.mode == "sse":
            # Sentinel that tells the SSE listener the stream is over.
            event.outputs.put_nowait(None)
        event.completed = True
        active_jobs[active_jobs.index(event)] = None
|
|
|
class EventData(BaseModel):
    """Request body for ``POST /sse/send``."""

    # Text payload stored on the created event.
    data: str
|
|
|
@app.post("/sse/send")
async def sse_send(data: EventData):
    """Queue a new "sse" job and return the session id to listen on.

    The response is ``{"session_id": <uuid4 string>}``.
    """
    new_session = str(uuid.uuid4())
    queue.append(
        Event(
            session_id=new_session,
            data=data.data,
            outputs=asyncio.Queue(),
            mode="sse",
        )
    )
    return {"session_id": new_session}
|
|
|
def _find_event(session_id: str) -> Event | None:
    """Return the active or still-queued event with this session id, if any."""
    for candidates in (active_jobs, queue):
        for evt in candidates:
            if evt is not None and evt.session_id == session_id:
                return evt
    return None


@app.get("/sse/listen")
async def sse_listen(session_id: str):
    """Stream the outputs of an "sse" job as server-sent events.

    The event is looked up among both running and still-queued jobs. For a
    queued job the stream simply stays silent until the job starts producing
    output, so no polling is needed.

    Args:
        session_id: Identifier returned by ``POST /sse/send``.

    Returns:
        A ``text/event-stream`` response with one ``data:`` frame per output.

    Raises:
        HTTPException: 404 when no running or queued SSE job has this id.
            Previously such a request waited forever.
    """
    event = _find_event(session_id)
    if event is None or event.outputs is None:
        raise HTTPException(status_code=404, detail="Unknown or finished session")

    async def event_generator():
        # Stop only when the job is done AND its buffer is drained. Checking
        # `completed` alone would drop outputs still buffered for a consumer
        # that reads more slowly than the job produces.
        while not (event.completed and event.outputs.empty()):
            output = await event.outputs.get()
            if output is None:
                # End-of-stream sentinel.
                break
            yield f"data: {output}\n\n"

    return StreamingResponse(event_generator(), media_type="text/event-stream")
|
|
|
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Run one job per WebSocket connection.

    Accepts the connection, reads a single text message as the job payload,
    queues a "ws" event bound to this socket, and then keeps the handler
    alive, checking once per second, until the event is marked completed.
    """
    await websocket.accept()
    payload = await websocket.receive_text()

    event = Event(
        session_id=str(uuid.uuid4()),
        data=payload,
        outputs=None,
        mode="ws",
        websocket=websocket,
    )
    queue.append(event)

    # Outputs are sent by process_event; this handler only has to stay open.
    await asyncio.sleep(1)
    while not event.completed:
        await asyncio.sleep(1)
|
|
|
import uvicorn

if __name__ == "__main__":
    # Start the server only when this file is executed as a script, so that
    # importing the module (for example to serve `app` from an external ASGI
    # runner) does not launch a second, blocking server as a side effect.
    uvicorn.run(app, host="0.0.0.0", port=7860)
|
|