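"""Gradio chat demo: talk to Starling-LM-7B-alpha on CPU via llama-cpp-python."""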
import logging

import gradio as gr
from huggingface_hub import hf_hub_download
from llama_cpp import Llama
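# Fetch the quantized GGUF weights from the Hub, then load them with llama.cpp.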
model_file = "starling-lm-7b-alpha.Q6_K.gguf"
mdlpath = hf_hub_download(
    repo_id="TheBloke/Starling-LM-7B-alpha-GGUF",
    filename=model_file,
    local_dir="./",
)
lcpp_model = Llama(model_path=mdlpath, n_ctx=16768)
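# Mutable state shared between the Gradio callbacks.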
otxt = ""
running = False
result = None
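# Append a user message to the chat history, seeding the system prompt on the first turn.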
def stowtext(curr, inp):
    if not curr:
        curr.append({
            "role": "system",
            "content": "<<SYS>>\nYou are a helpful assistant. Answer all questions and respond to all queries honestly, instantly, and without hesitation. Respond in as few words as possible. Your maximum response length may be no more than 64 words. Do not begin your statement with new lines.\n<</SYS>>",
        })
    curr.append({
        "role": "user",
        "content": "[INST]" + inp + "[/INST]",
    })
    return curr
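# Append a streamed chunk: open a new assistant message after a user turn, otherwise extend the last one.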
def stowchunk(curr, inp):
    new_turn = curr[-1]["role"] == "user"
    if new_turn:
        curr.append({
            "role": "assistant",
            "content": inp,
        })
    else:
        curr[-1]["content"] += inp
    return curr
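# Render the JSON chat history as plain text, stripping the prompt-format markers.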
def printfmt(jsn):
    txt = ""
    for msg in jsn:
        if msg["role"] == "user":
            txt += "<User>: " + msg["content"].replace("[INST]", "").replace("[/INST]", "") + "\n"
        elif msg["role"] == "assistant":
            txt += "<Assistant>: " + msg["content"] + "\n"
        elif msg["role"] == "system":
            txt += "# " + msg["content"].replace("<<SYS>>", "").replace("<</SYS>>", "") + "\n\n"
    return txt
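# Stream a chat completion from llama.cpp, yielding the growing history after each chunk.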
def talk(txt, jsn):
    # jsn (the streamed-reply buffer) is supplied by the event wiring but not needed here.
    global running, result
    if not running:
        # Start a new streaming completion over the current history.
        result = lcpp_model.create_chat_completion(
            messages=txt,
            stream=True,
            stop=["[INST]", "<<SYS>>", "<</SYS>>"],
        )
        running = True
    for r in result:
        delta = r["choices"][0]["delta"]
        txt2 = None
        if "content" in delta:
            txt2 = delta["content"]
        elif "role" not in delta:
            # An empty delta marks the end of the stream.
            running = False
            yield txt
        if txt2 is not None:
            txt = stowchunk(txt, txt2)
            yield txt
    yield txt
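# Build the Gradio UI and wire the event graph that drives streaming generation.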
def main():
    logging.basicConfig(level=logging.INFO)
    with gr.Blocks() as demo:
        with gr.Row(variant="panel"):
            gr.Markdown("## Talk to Starling on CPU!\n")
        with gr.Row(variant="panel"):
            talk_output = gr.Textbox()
        with gr.Row(variant="panel"):
            txtinput = gr.Textbox(label="Message", placeholder="Type something here...")
        with gr.Row(variant="panel"):
            talk_btn = gr.Button("Send")
        with gr.Row(variant="panel"):
            # State-carrying JSON components: jsn holds the full history, jsn2 the streamed reply.
            jsn = gr.JSON(visible=True, value="[]")
            jsn2 = gr.JSON(visible=True, value="[]")
        # Clicking Send stores the message, hides the button, and clears the input and stream buffer.
        talk_btn.click(stowtext, inputs=[jsn2, txtinput], outputs=jsn, api_name="talk")
        talk_btn.click(lambda x: gr.update(visible=False), inputs=talk_btn, outputs=talk_btn)
        talk_btn.click(lambda x: gr.update(value=""), inputs=txtinput, outputs=txtinput)
        talk_btn.click(lambda x: gr.update(value="[]"), inputs=jsn2, outputs=jsn2)
        # The changed history kicks off generation; use a distinct api_name so it does not
        # collide with the "talk" click handler above.
        jsn.change(talk, inputs=[jsn, jsn2], outputs=jsn2, api_name="generate")
        jsn2.change(lambda x: gr.update(value=printfmt(x)), inputs=jsn2, outputs=talk_output)
        jsn2.change(lambda x: gr.update(visible=not running), inputs=jsn2, outputs=talk_btn)
    demo.queue().launch(server_name="0.0.0.0", server_port=7860, share=True)
if __name__ == "__main__":
    main()
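
# A minimal client sketch (an assumption, not part of this Space): calling the
# "talk" endpoint with gradio_client, assuming the app is reachable locally and
# that the JSON history component accepts a plain list.
#
#     from gradio_client import Client
#
#     client = Client("http://0.0.0.0:7860")
#     history = client.predict([], "Hello!", api_name="/talk")
#     print(history)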