Sergidev committed
Commit de91fd4
1 Parent(s): 4af388b

Update app.py

Files changed (1)
  1. app.py +12 -15
app.py CHANGED
@@ -3,7 +3,6 @@ from fastapi.responses import HTMLResponse, StreamingResponse
 from fastapi.staticfiles import StaticFiles
 from modules.pmbl import PMBL
 import torch
-from queue import Queue
 import asyncio

 print(f"CUDA available: {torch.cuda.is_available()}")
@@ -17,7 +16,8 @@ app.mount("/static", StaticFiles(directory="static"), name="static")
 app.mount("/templates", StaticFiles(directory="templates"), name="templates")

 pmbl = PMBL("./PMB-7b.Q6_K.gguf", gpu_layers=50)
-request_queue = Queue()
+request_queue = asyncio.Queue()
+processing_lock = asyncio.Lock()

 @app.head("/")
 @app.get("/")
@@ -26,9 +26,10 @@ def index() -> HTMLResponse:
         return HTMLResponse(content=f.read())

 async def process_request(user_input: str, mode: str):
-    history = pmbl.get_chat_history(mode, user_input)
-    async for chunk in pmbl.generate_response(user_input, history, mode):
-        yield chunk
+    async with processing_lock:
+        history = pmbl.get_chat_history(mode, user_input)
+        async for chunk in pmbl.generate_response(user_input, history, mode):
+            yield chunk

 @app.post("/chat")
 async def chat(request: Request, background_tasks: BackgroundTasks):
@@ -38,11 +39,8 @@ async def chat(request: Request, background_tasks: BackgroundTasks):
     mode = data["mode"]

     async def response_generator():
-        future = asyncio.Future()
-        request_queue.put((future, user_input, mode))
-        await future
-
-        async for chunk in future.result():
+        await request_queue.put((user_input, mode))
+        async for chunk in await process_request(user_input, mode):
             yield chunk

     return StreamingResponse(response_generator(), media_type="text/plain")
@@ -52,11 +50,10 @@ async def chat(request: Request, background_tasks: BackgroundTasks):

 async def queue_worker():
     while True:
-        if not request_queue.empty():
-            future, user_input, mode = request_queue.get()
-            result = process_request(user_input, mode)
-            future.set_result(result)
-        await asyncio.sleep(0.1)
+        user_input, mode = await request_queue.get()
+        async for _ in process_request(user_input, mode):
+            pass
+        request_queue.task_done()

 @app.on_event("startup")
 async def startup_event():
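
For context, here is a minimal, self-contained sketch of the pattern this commit moves to: an asyncio.Lock serializes streamed generations so only one request drives the GPU-backed model at a time, instead of handing futures to a polling worker. The fake_generate helper is a hypothetical stand-in for pmbl.generate_response and is not part of the repository; the route is simplified to a single query parameter.

# Minimal sketch (not the repository code): serialize streamed generations
# with a single asyncio.Lock, mirroring the pattern this commit adopts.
import asyncio

from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()
processing_lock = asyncio.Lock()  # only one generation may run at a time


async def fake_generate(prompt: str):
    # Hypothetical stand-in for pmbl.generate_response: yields chunks slowly.
    for word in prompt.split():
        await asyncio.sleep(0.05)
        yield word + " "


async def process_request(prompt: str):
    # Hold the lock for the whole generation; concurrent requests wait here
    # until the previous stream finishes.
    async with processing_lock:
        async for chunk in fake_generate(prompt):
            yield chunk


@app.post("/chat")
async def chat(prompt: str):
    # StreamingResponse consumes the async generator directly.
    return StreamingResponse(process_request(prompt), media_type="text/plain")

Run this sketch with uvicorn and issue two concurrent POSTs to /chat: the second only starts streaming once the first generation releases the lock.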