File size: 2,703 Bytes
0d67dc2
 
 
 
cd842ff
0d67dc2
5ad3bc3
4228071
0d67dc2
cd842ff
0d67dc2
fdb3b96
8bd4741
 
 
0d67dc2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
07d7cbc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
716d802
 
 
 
31bf9c0
716d802
a2f46f0
07d7cbc
716d802
a2f46f0
716d802
0d67dc2
 
 
cd842ff
 
 
 
0d67dc2
cd842ff
 
 
7229992
cd842ff
 
 
 
0d67dc2
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import fastapi
import json
import markdown
import uvicorn
from fastapi.responses import HTMLResponse
from fastapi.middleware.cors import CORSMiddleware
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from ctransformers import AutoModelForCausalLM
from pydantic import BaseModel
from sse_starlette.sse import EventSourceResponse

# Generation settings.
# NOTE(review): this dict is never passed to from_pretrained below, so
# max_seq_len currently has no effect — confirm whether it should be
# forwarded (e.g. as context_length/config) to the model loader.
config = {"max_seq_len": 4096}
# Load the quantized (q4_0) MPT-7B StoryWriter GGML model via ctransformers.
# This downloads/loads the model at import time, so startup is slow.
llm = AutoModelForCausalLM.from_pretrained('TheBloke/MPT-7B-Storywriter-GGML',
                                           model_file='mpt-7b-storywriter.ggmlv3.q4_0.bin',
                                           model_type='mpt')
app = fastapi.FastAPI()
# Allow cross-origin browser access from any origin so pages served
# elsewhere (e.g. the /demo EventSource client) can reach the SSE endpoints.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

@app.get("/")
async def index():
    """Serve the project README, rendered from Markdown to HTML, at the root."""
    with open("README.md", "r", encoding="utf-8") as readme:
        markdown_text = readme.read()
    rendered = markdown.markdown(markdown_text)
    return HTMLResponse(content=rendered, status_code=200)

class ChatCompletionRequest(BaseModel):
    """Request body for POST /v1/chat/completions."""
    prompt: str  # text passed verbatim to the language model


@app.get("/demo")
async def demo():
    """Serve a minimal self-contained HTML page that demonstrates streaming.

    The page opens an EventSource against http://localhost:8000/stream and
    appends each received SSE message to a black console-style <div>.
    Note the hard-coded host/port: the demo only works when the server is
    reachable at localhost:8000 from the browser.
    """
    html_content = """
    <!DOCTYPE html>
    <html>
    <head>
      <style>
      #logs {
        background-color: black;
        color:white;
        height:600px;
        overflow-x: hidden;
        overflow-y: auto;
        text-align: left;
        padding-left:10px;
      }
      </style>
    </head>
    
    <body>
    
    <h1>StoryWriter Demo</h1>
    <div id="logs">
    </div>
    
    <script>
      var source = new EventSource("http://localhost:8000/stream");
      source.onmessage = function(event) {
        document.getElementById("logs").innerHTML += event.data + "<br>";
      };
    </script>
    
    </body>
    </html>
    """
    return HTMLResponse(content=html_content, status_code=200)


@app.get("/stream")
async def stream(prompt: str = "Once upon a time there was a "):
    """Stream a completion for *prompt* as server-sent events.

    The first event carries the prompt itself, then one event per generated
    chunk, then a final empty event to signal the end of the stream.

    Fixes two defects in the original:
    - it returned ``StreamingResponse``, a name never imported anywhere in
      this file, so every request raised ``NameError``;
    - a plain streaming response is not SSE-framed, but the /demo page
      consumes this endpoint with a browser ``EventSource``, which requires
      SSE framing.  ``EventSourceResponse`` (already imported) provides it.
    The coroutine is also renamed from ``chat`` so it no longer shadows the
    POST /v1/chat/completions handler below; the route path is unchanged.
    """
    # NOTE: generation runs synchronously here; the event loop blocks until
    # llm(prompt) returns the full chunk iterator.
    completion = llm(prompt)

    async def server_sent_events(chat_chunks):
        yield dict(data=prompt)
        for chat_chunk in chat_chunks:
            yield dict(data=chat_chunk)
        yield dict(data="")

    return EventSourceResponse(server_sent_events(completion))

@app.post("/v1/chat/completions")
async def chat(request: ChatCompletionRequest, response_mode=None):
    """Generate a completion for the posted prompt and stream it as SSE.

    Every generated chunk is emitted as a JSON-encoded server-sent event,
    followed by a final "[DONE]" event marking the end of the stream.
    """
    token_stream = llm(request.prompt)

    async def event_publisher(chunks):
        for chunk in chunks:
            yield {"data": json.dumps(chunk)}
        yield {"data": "[DONE]"}

    return EventSourceResponse(event_publisher(token_stream))


if __name__ == "__main__":
    # Launch the API with uvicorn when run as a script (all interfaces, port 8000).
    uvicorn.run(app, host="0.0.0.0", port=8000)