# StarlingCPU / code / app.py
# Samuel L Meyers
# Fix text
# ad039da
# raw
# history blame
# 2.79 kB
import logging
from typing import cast
from threading import Lock
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import torch
from conversation import get_default_conv_template
import gradio as gr
from llama_cpp import Llama
import json
from huggingface_hub import hf_hub_download
# Name of the quantized Starling weights file inside the Hugging Face repo.
MODEL_FILENAME = "starling-lm-7b-alpha.Q6_K.gguf"

# Location of the weights relative to the working directory (kept for any
# code that still refers to ``model_path``).
model_path = "./" + MODEL_FILENAME

# Download the weights into the current directory; hf_hub_download returns
# the local path of the fetched file.
mdlpath = hf_hub_download(
    repo_id="TheBloke/Starling-LM-7B-alpha-GGUF",
    filename=MODEL_FILENAME,
    local_dir="./",
)

# Load the model from the path the download reported instead of a separately
# hard-coded one, so the two can never drift apart.
lcpp_model = Llama(model_path=mdlpath)

# Module-level name that main() declares as a global.  (A ``global`` statement
# at module scope has no effect, so none is needed here.)
otxt = ""
def stowtext(curr, inp):
    """Record a new user message in the conversation history.

    Args:
        curr: List of message dicts; it is extended in place.
        inp: Text of the user's message.

    Returns:
        A two-element list holding the same (mutated) history object twice,
        so a single call can feed two outputs.
    """
    user_message = {"role": "user", "content": inp}
    curr.append(user_message)
    return [curr] * 2
def stowchunk(curr, inp):
    """Merge one streamed text fragment into the conversation history.

    If the most recent message came from the user, a new assistant message
    containing ``inp`` is appended.  Otherwise ``inp`` is concatenated onto
    the content of the most recent message.

    Args:
        curr: Non-empty list of message dicts; modified in place.
        inp: Text fragment to store.

    Returns:
        The same ``curr`` list.
    """
    latest = curr[-1]
    if latest["role"] != "user":
        # A reply is already in progress: extend it.
        latest["content"] += inp
        return curr
    # First fragment of a reply: open a new assistant message.
    curr.append({"role": "assistant", "content": inp})
    return curr
def printfmt(jsn):
    """Render a list of chat messages as a plain-text transcript.

    User and assistant messages become ``<User>: ...`` / ``<Assistant>: ...``
    lines; system messages become a ``# ...`` line followed by a blank line.
    Messages with any other role are left out.

    Args:
        jsn: Iterable of dicts with "role" and "content" keys.

    Returns:
        The transcript as a single string (empty when there is nothing to show).
    """
    # role -> (text placed before the content, text placed after it)
    wrappers = {
        "user": ("<User>: ", "\n"),
        "assistant": ("<Assistant>: ", "\n"),
        "system": ("# ", "\n\n"),
    }
    pieces = []
    for message in jsn:
        wrapper = wrappers.get(message["role"])
        if wrapper is None:
            continue
        before, after = wrapper
        pieces.append(before + message["content"] + after)
    return "".join(pieces)
def talk(txt):
    """Stream the model's reply to the conversation in ``txt``.

    Requests a streaming chat completion from ``lcpp_model`` and, for every
    streamed delta that carries content, folds the text into the history with
    ``stowchunk`` and yields the updated state.

    Args:
        txt: List of message dicts passed to the model as ``messages``; it is
            updated through ``stowchunk`` as the reply arrives.

    Yields:
        ``[transcript, history]`` pairs, where ``transcript`` is the
        ``printfmt`` rendering of the history.  One final pair is yielded
        after the stream ends.
    """
    stream = lcpp_model.create_chat_completion(
        messages=txt,
        stop=["</s>", "<|end_of_text|>", "GPT4 User: ", "<|im_sep|>", "\n\n"],
        stream=True,
    )
    for chunk in stream:
        piece = chunk["choices"][0]["delta"].get("content")
        if piece is None:
            # Delta without text: nothing to store, and calling
            # str methods on it would raise AttributeError.
            continue
        if piece.startswith("\n"):
            # Drop a single leading newline from the fragment.
            piece = piece[1:]
        txt = stowchunk(txt, piece)
        yield [printfmt(txt), txt]
    # Emit the final state once the stream is exhausted.
    yield [printfmt(txt), txt]
def main():
    """Build the Gradio chat interface and start serving it.

    Layout, top to bottom: a heading, the transcript box, the message input,
    the send button, and two hidden JSON components used to pass the
    conversation between handlers.  Clicking the button stores the message
    via ``stowtext`` and clears the input; a change of the first hidden JSON
    component runs ``talk``, which streams into the transcript box.
    """
    global otxt, txtinput
    logging.basicConfig(level=logging.INFO)

    with gr.Blocks() as app:
        with gr.Row(variant="panel"):
            gr.Markdown("## Talk to Starling on CPU!\n")
        with gr.Row(variant="panel"):
            reply_box = gr.Textbox()
        with gr.Row(variant="panel"):
            txtinput = gr.Textbox(
                label="Message",
                placeholder="Type something here...",
            )
        with gr.Row(variant="panel"):
            send_button = gr.Button("Send")
        with gr.Row(variant="panel"):
            pending = gr.JSON(visible=False, value="[]")
            history = gr.JSON(visible=False, value="[]")

        # Store the typed message in the history.
        send_button.click(
            stowtext,
            inputs=[history, txtinput],
            outputs=[pending, history],
            api_name="talk",
        )
        # Empty the input box on the same click.
        send_button.click(
            lambda x: gr.update(value=""),
            inputs=txtinput,
            outputs=txtinput,
        )
        # Generate the reply whenever the pending history changes.
        pending.change(
            talk,
            inputs=pending,
            outputs=[reply_box, history],
            api_name="talk",
        )

    app.queue().launch(server_name="0.0.0.0", server_port=7860, share=True)


if __name__ == "__main__":
    main()