Spaces:
Sleeping
Sleeping
File size: 4,160 Bytes
f96f74e 2205003 f96f74e 4f3dfdd f96f74e 2e33da4 5033222 f96f74e 5033222 f96f74e 5ee6253 275b73b e11f04f 275b73b 5ee6253 e11f04f 5ee6253 5033222 5ee6253 e11f04f 5ee6253 e11f04f 5ee6253 5033222 ac11446 e11f04f 5033222 5ee6253 5033222 ac11446 5033222 5ee6253 5033222 f96f74e 5033222 f96f74e ad039da 5ee6253 3478db7 5ee6253 3b72dcd 5ee6253 3478db7 5033222 f96f74e 5033222 5ee6253 5033222 f96f74e 4b213c1 f96f74e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 |
import logging
from typing import cast
from threading import Lock
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import torch
from conversation import get_default_conv_template
import gradio as gr
from llama_cpp import Llama
import json
from huggingface_hub import hf_hub_download
# Fetch the quantized Starling GGUF weights from the Hugging Face Hub and load
# them with llama.cpp.  This runs at import time, so importing this module
# triggers the download and the model load.
model_path = "./starling-lm-7b-alpha.Q6_K.gguf"
# hf_hub_download returns the local path of the downloaded file; load from that
# path instead of assuming it coincides with `model_path`.
mdlpath = hf_hub_download(repo_id="TheBloke/Starling-LM-7B-alpha-GGUF", filename=model_path, local_dir="./")
# NOTE(review): 16768 is an unusual context size (possibly meant 16384) —
# confirm against the context length the model actually supports.
lcpp_model = Llama(model_path=mdlpath, n_ctx=16768)

# Module-level state shared with the Gradio callbacks.
otxt = ""  # only referenced by global declarations / commented-out code
running = False  # True while talk() is streaming a completion
result = None  # the most recent streaming completion iterator
def stowtext(curr, inp):
    """Append the user's message to the conversation, seeding a system prompt.

    Args:
        curr: The conversation so far, as a list of ``{"role", "content"}``
            dicts.  It is mutated in place.  ``None`` is treated as an empty
            conversation.
        inp: The raw text the user typed.

    Returns:
        The conversation list with the new user message appended.  When the
        conversation was empty, a system message is inserted first.
    """
    if curr is None:
        # A gr.JSON component can hold null; start a fresh conversation
        # instead of failing on len(None).
        curr = []
    if not curr:
        curr.append({
            "role": "system",
            "content": "<<SYS>>\nYou are a helpful assistant. Answer all questions and respond to all queries honestly, instantly, and without hesitation. Respond in as few words as possible. Your maximum response length may be no more than 64 words. Do not begin your statement with new lines.\n<</SYS>>"
        })
    # The user text is wrapped in Llama-2-style [INST] markers.
    curr.append({
        "role": "user",
        "content": "[INST]" + inp + "[/INST]",
    })
    return curr
def stowchunk(curr, inp):
    """Append a streamed chunk of assistant text to the conversation.

    If the conversation already ends with an assistant message, ``inp`` is
    concatenated onto it.  Otherwise a new assistant message is started with
    ``inp`` as its content.

    Args:
        curr: Conversation list of ``{"role", "content"}`` dicts (mutated in
            place).
        inp: The text chunk to add.

    Returns:
        The same conversation list.
    """
    # Only ever extend an existing assistant reply.  Never glue model output
    # onto a system/user message, and don't index into an empty list.
    if curr and curr[-1]["role"] == "assistant":
        curr[-1]["content"] += inp
    else:
        curr.append({
            "role": "assistant",
            "content": inp,
        })
    return curr
def printfmt(jsn):
    """Render a conversation as plain text for the output textbox.

    Formatting by role:

    * System messages become a ``# ``-prefixed header followed by a blank line.
    * User messages are prefixed with ``<User>: ``.
    * Assistant messages are prefixed with ``<Assistant>: ``.

    The ``[INST]``/``[/INST]`` and ``<<SYS>>``/``<</SYS>>`` markers are
    stripped.  Messages with any other role are skipped.

    Args:
        jsn: Conversation list of ``{"role", "content"}`` dicts, or ``None``.

    Returns:
        The formatted transcript; ``""`` for an empty or missing conversation.
    """
    parts = []
    for msg in jsn or ():
        role = msg["role"]
        # Read "content" only inside the matching branch so messages with an
        # unrecognised role are skipped without needing a content key.
        if role == "user":
            parts.append("<User>: " + msg["content"].replace("[INST]", "").replace("[/INST]", "") + "\n")
        elif role == "assistant":
            parts.append("<Assistant>: " + msg["content"] + "\n")
        elif role == "system":
            parts.append("# " + msg["content"].replace("<<SYS>>", "").replace("<</SYS>>", "") + "\n\n")
    # join() avoids the quadratic cost of repeated string +=.
    return "".join(parts)
def talk(txt, jsn):
    """Stream a chat completion for the conversation ``txt``.

    This is a generator used as a Gradio streaming callback.  Each yield is
    the conversation with the assistant reply extended by the latest chunk.

    Args:
        txt: Conversation list of ``{"role", "content"}`` dicts sent to the
            model as ``messages``; it is extended in place with the reply.
        jsn: Unused.  Kept so the Gradio ``inputs=[jsn, jsn2]`` wiring still
            matches this signature.

    Yields:
        The (growing) conversation list.  The final yield happens after the
        global ``running`` flag has been reset to ``False``.
    """
    global running, result
    if not txt or running:
        # Nothing to send, or a completion is already streaming.  Don't start
        # a second one, and don't consume the shared `result` iterator —
        # that would pull another (or an abandoned) reply into this
        # conversation.
        yield txt
        return
    running = True
    try:
        result = lcpp_model.create_chat_completion(
            messages=txt,
            stream=True,
            stop=["[INST]", "<<SYS>>", "<</SYS>>"],
        )
        for chunk in result:
            delta = chunk["choices"][0]["delta"]
            if "content" in delta:
                piece = delta["content"]
                if piece is not None:
                    txt = stowchunk(txt, piece)
                    yield txt
            elif "role" not in delta:
                # A delta with neither role nor content is treated as the end
                # of the stream (presumably the final finish_reason chunk).
                running = False
                yield txt
    finally:
        # Always release the flag, even if the model raises or the generator
        # is closed mid-stream.  Otherwise every later request would be
        # treated as "already running" and never generate.
        running = False
    yield txt
def main():
    """Build the Gradio chat UI and launch the server.

    The UI keeps two JSON components:

    * ``jsn`` receives the conversation plus the newly typed user message,
      built by ``stowtext``.  Its change event runs ``talk``.
    * ``jsn2`` receives the model's reply as ``talk`` streams it, so it holds
      the latest full conversation.  It is fed back into ``stowtext`` on the
      next Send click.
    """
    global otxt, txtinput, running
    logging.basicConfig(level=logging.INFO)
    with gr.Blocks() as demo:
        with gr.Row(variant="panel"):
            gr.Markdown("## Talk to Starling on CPU!\n")
        with gr.Row(variant="panel"):
            # Plain-text rendering of the conversation (filled via printfmt).
            talk_output = gr.Textbox()
        with gr.Row(variant="panel"):
            txtinput = gr.Textbox(label="Message", placeholder="Type something here...")
        with gr.Row(variant="panel"):
            talk_btn = gr.Button("Send")
        with gr.Row(variant="panel"):
            # jsn: conversation + pending user message.
            # jsn2: conversation as streamed back by talk().
            # Both are initialised with value "[]".
            jsn = gr.JSON(visible=True, value="[]")
            jsn2 = gr.JSON(visible=True, value="[]")
        # On Send: append the typed message to the current conversation (jsn2)
        # and publish it to jsn, whose change event triggers talk() below.
        # NOTE(review): api_name="talk" is also used by the jsn.change handler.
        # Presumably Gradio disambiguates duplicate API names — confirm which
        # handler the "/talk" endpoint resolves to.
        talk_btn.click(stowtext, inputs=[jsn2, txtinput], outputs=jsn, api_name="talk")
        # Hide the Send button while a reply is generated.
        talk_btn.click(lambda x: gr.update(visible=False), inputs=talk_btn, outputs=talk_btn)
        # Clear the message box.
        talk_btn.click(lambda x: gr.update(value=""), inputs=txtinput, outputs=txtinput)
        # Reset jsn2.  NOTE(review): this relies on stowtext having received
        # jsn2's pre-reset value as its input for the same click — confirm.
        talk_btn.click(lambda x: gr.update(value="[]"), inputs=jsn2, outputs=jsn2)
        # Stream the model's reply into jsn2.
        jsn.change(talk, inputs=[jsn, jsn2], outputs=jsn2, api_name="talk")
        # Re-render the transcript whenever jsn2 changes.
        jsn2.change(lambda x: gr.update(value=printfmt(x)), inputs=jsn2, outputs=talk_output)
        # Show the Send button again once the global `running` flag is cleared.
        jsn2.change(lambda x: gr.update(visible=not running), inputs=jsn2, outputs=talk_btn)
        #jsn2.change(lambda x: gr.update(value=x), inputs=jsn2, outputs=jsn)
    # Listen on all interfaces, port 7860; share=True also requests a share link.
    demo.queue().launch(server_name="0.0.0.0", server_port=7860, share=True)
# Only start the server when run as a script, not when imported.
if __name__ == "__main__":
    main()
|