matthoffner committed on
Commit
6349229
·
1 Parent(s): dd98bd4

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +11 -7
main.py CHANGED
@@ -5,10 +5,13 @@ import uvicorn
5
  from fastapi.responses import HTMLResponse
6
  from fastapi.middleware.cors import CORSMiddleware
7
  from sse_starlette.sse import EventSourceResponse
 
8
  from ctransformers.langchain import CTransformers
9
  from pydantic import BaseModel
10
 
11
# Previous model: StarCoder GGML weights loaded through the LangChain wrapper.
llm = CTransformers(model_type="starcoder", model="ggml-model-q4_1.bin")
 
 
12
  app = fastapi.FastAPI()
13
  app.add_middleware(
14
  CORSMiddleware,
@@ -30,13 +33,14 @@ class ChatCompletionRequest(BaseModel):
30
 
31
@app.post("/v1/chat/completions")
async def chat(request: ChatCompletionRequest, response_mode=None):
    """Run the model on the request prompt and stream the result as SSE.

    Each chunk produced by the model is JSON-encoded and sent as one
    server-sent event; a final "[DONE]" event signals end of stream.
    """
    completion = llm(request.prompt)

    async def server_sent_events(chat_chunks):
        for chunk in chat_chunks:
            yield {"data": json.dumps(chunk)}
        yield {"data": "[DONE]"}

    return EventSourceResponse(server_sent_events(completion))
40
 
41
  if __name__ == "__main__":
42
  uvicorn.run(app, host="0.0.0.0", port=8000)
 
5
  from fastapi.responses import HTMLResponse
6
  from fastapi.middleware.cors import CORSMiddleware
7
  from sse_starlette.sse import EventSourceResponse
8
+ from ctransformers import AutoModelForCausalLM
9
  from ctransformers.langchain import CTransformers
10
  from pydantic import BaseModel
11
 
12
# Load the Gorilla 7B GGML model (llama architecture) once at import time
# so a single in-memory instance serves every request.
llm = AutoModelForCausalLM.from_pretrained(
    "TheBloke/gorilla-7B-GGML",
    model_file="Gorilla-7B.ggmlv3.q4_0.bin",
    model_type="llama",
)
15
  app = fastapi.FastAPI()
16
  app.add_middleware(
17
  CORSMiddleware,
 
33
 
34
@app.post("/v1/chat/completions")
async def chat(request: ChatCompletionRequest, response_mode=None):
    """Stream a chat completion for `request.prompt` as server-sent events.

    The prompt is tokenized and fed to the model; the stream yields the
    prompt text first, then each generated token detokenized to text, and
    finally an empty event to mark end of stream.

    Args:
        request: body containing the `prompt` field to complete.
        response_mode: unused; kept for interface compatibility.

    Returns:
        EventSourceResponse streaming incremental text chunks.
    """
    # BUG FIX: the original referenced an undefined name `prompt`
    # (NameError on every request); the prompt lives on the request body.
    prompt = request.prompt
    tokens = llm.tokenize(prompt)

    async def server_sent_events(chat_chunks, llm):
        # Echo the prompt first so the client can render it immediately.
        yield prompt
        # generate() yields token ids one at a time; detokenize each so
        # the client receives incremental text.
        for chat_chunk in llm.generate(chat_chunks):
            yield llm.detokenize(chat_chunk)
        # Empty event signals end of stream.
        yield ""

    return EventSourceResponse(server_sent_events(tokens, llm))
44
 
45
  if __name__ == "__main__":
46
  uvicorn.run(app, host="0.0.0.0", port=8000)