matthoffner committed on
Commit
8263466
·
1 Parent(s): 8bf9fb0

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +9 -12
main.py CHANGED
@@ -5,13 +5,11 @@ import uvicorn
5
  from fastapi.responses import HTMLResponse
6
  from fastapi.middleware.cors import CORSMiddleware
7
  from sse_starlette.sse import EventSourceResponse
8
- from ctransformers import AutoModelForCausalLM
9
  from pydantic import BaseModel
10
 
11
- llm = AutoModelForCausalLM.from_pretrained("danforbes/santacoder-ggml-q4_1",
12
- model_file="santacoder-ggml-q4_1.bin",
13
- model_type="starcoder")
14
- app = fastapi.FastAPI()
15
  app.add_middleware(
16
  CORSMiddleware,
17
  allow_origins=["*"],
@@ -32,14 +30,13 @@ class ChatCompletionRequest(BaseModel):
32
 
33
  @app.post("/v1/chat/completions")
34
  async def chat(request: ChatCompletionRequest, response_mode=None):
35
- tokens = llm.tokenize(request.prompt)
36
- async def server_sent_events(chat_chunks, llm):
37
- for token in llm.generate(chat_chunks):
38
- yield dict(data=llm.detokenize(token))
39
  yield dict(data="[DONE]")
40
 
41
- return EventSourceResponse(server_sent_events(tokens, llm))
42
 
43
  if __name__ == "__main__":
44
- uvicorn.run(app, host="0.0.0.0", port=8000)
45
-
 
5
  from fastapi.responses import HTMLResponse
6
  from fastapi.middleware.cors import CORSMiddleware
7
  from sse_starlette.sse import EventSourceResponse
8
+ from ctransformers.langchain import CTransformers
9
  from pydantic import BaseModel
10
 
11
+ llm = CTransformers(model='ggml-model-q4_1.bin', model_type='starcoder')
12
+ app = fastapi.FastAPI(title="Santacoder")
 
 
13
  app.add_middleware(
14
  CORSMiddleware,
15
  allow_origins=["*"],
 
30
 
31
  @app.post("/v1/chat/completions")
32
  async def chat(request: ChatCompletionRequest, response_mode=None):
33
+ completion = llm(request.prompt)
34
+ async def server_sent_events(chat_chunks):
35
+ for chat_chunk in chat_chunks:
36
+ yield dict(data=json.dumps(chat_chunk))
37
  yield dict(data="[DONE]")
38
 
39
+ return EventSourceResponse(server_sent_events(completion))
40
 
41
  if __name__ == "__main__":
42
+ uvicorn.run(app, host="0.0.0.0", port=8000)