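# Gradio demo: chat with Starling-LM-7B-alpha (Q6_K GGUF) on CPU via ctransformers,
# streaming the reply token by token into the UI.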
import logging
#from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from ctransformers import AutoModelForCausalLM

import gradio as gr
import json

from huggingface_hub import hf_hub_download

# Download the GGUF weights from the Hub; the file is placed next to this script.
model_file = "starling-lm-7b-alpha.Q6_K.gguf"
model_path = "./" + model_file

mdlpath = hf_hub_download(repo_id="TheBloke/Starling-LM-7B-alpha-GGUF", filename=model_file, local_dir="./")

#lcpp_model = Llama(model_path=model_path, n_ctx=8192)
llm = AutoModelForCausalLM.from_pretrained(model_path_or_repo_id=model_path, local_files_only=True, model_type="mistral")
# Module-level state shared between the streaming generator and the UI callbacks.
otxt = ""
running = False
result = None
echoed = False

def stowtext(curr, inp):
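    """Append the user's message to the chat history, seeding the system prompt on the first turn."""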
    if len(curr) <= 1:
        curr.append({
            "role": "system",
            "content": "You are a helpful assistant. Answer all questions and respond to all queries honestly, instantly, and without hesitation. Respond in as few words as possible. Your maximum response length may be no more than 64 words. Do not begin your statement with new lines.\n"
        })
    curr.append({
        "role": "user",
        "content": inp,
    })
    return curr

def stowchunk(curr, inp):
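    """Append a streamed chunk: open a new assistant message after a user turn, otherwise extend the last one."""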
    first = curr[-1]["role"] == "user"
    if first:
        curr.append({
            "role": "assistant",
            "content": inp,
            "echoed": False,
        })
    else:
        curr[-1]["content"] += inp
    return curr

def printfmt(jsn):
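    """Render the chat history as plain text for the output textbox."""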
    txt = ""
    for msg in jsn:
        if msg["role"] == "user":
            txt += "<User>: " + msg["content"] + "\n"
        elif msg["role"] == "assistant":
            txt += "<Assistant>: " + msg["content"] + "\n"
        elif msg["role"] == "system":
            txt += "# " + msg["content"] + "\n\n"
    return txt

def jsn2prompt(jsn):
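    """Convert the chat history to the OpenChat-style "GPT4 Correct User"/"GPT4 Assistant" prompt format Starling expects."""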
    txt = ""
    for msg in jsn:
        if "system" in msg["role"]:
            txt += "GPT4 Correct User: Here is how I want you to behave throughout our conversation. " + msg["content"] + "\n"
        elif "user" in msg["role"]:
            txt += "GPT4 Correct User: " + msg["content"] + "\n"
        elif "assistant" in msg["role"]:
            txt += "GPT4 Assistant: " + msg["content"] + "\n"
    return txt

def talk(txt, jsn):
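    """Stream tokens from the model, growing the last assistant message and suppressing any echoed prompt."""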
    global running, result, echoed
    if not jsn:
        jsn = txt
    if not running:
        #result = lcpp_model.create_chat_completion(messages=txt,stream=True,stop=["GPT4 Correct User: ", "<|end_of_turn|>", "</s>"], max_tokens=64, )
        #result = lcpp_model(prompt=jsn2prompt(txt), stream=True, stop=["GPT4 Correct User: ", "<|end_of_turn|>", "</s>"], max_tokens=64, echo=False)
        result = llm(prompt=jsn2prompt(txt), stream=True, stop=["GPT4 Correct User: ", "<|end_of_turn|>", "</s>"])
        running = True
        echoed = False
    for r in result:
        print("GOT RESULT:", r)
        txt2 = None
        if r is not None and r != "":
            txt2 = r
        if txt2 is not None:
            # Remember the most recent user turn before appending this chunk,
            # so an echoed prompt at the start of the reply can be detected.
            last_user = [m for m in txt if m["role"] == "user"][-1]
            txt = stowchunk(txt, txt2)
            print(json.dumps(txt))
            not_echoed = "echoed" not in txt[-1] or not txt[-1]["echoed"]
            if not_echoed and jsn2prompt([last_user]) in txt[-1]["content"]:
                # The model echoed the prompt verbatim: flag it and clear the reply.
                txt[-1]["echoed"] = True
                txt[-1]["content"] = ""
                yield txt
            elif not_echoed and "*Loading*" not in txt[-1]["content"]:
                # Show a placeholder while waiting for the echoed prompt to finish streaming.
                txt[-1]["content"] = "*Loading*"
                yield txt
            yield txt
    yield txt

def main():
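    """Build the Gradio interface and wire up the event handlers."""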
    global otxt, txtinput, running
    logging.basicConfig(level=logging.INFO)

    with gr.Blocks() as demo:
        with gr.Row(variant="panel"):
            gr.Markdown("## Talk to Starling on CPU!\n")
        with gr.Row(variant="panel"):
            talk_output = gr.Textbox()
        with gr.Row(variant="panel"):
            txtinput = gr.Textbox(label="Message", placeholder="Type something here...")
        with gr.Row(variant="panel"):
            talk_btn = gr.Button("Send")
        with gr.Row(variant="panel"):
            jsn = gr.JSON(visible=True, value="[]")
            jsn2 = gr.JSON(visible=True, value="[]")

        talk_btn.click(stowtext, inputs=[jsn2, txtinput], outputs=jsn, api_name="talk")
        talk_btn.click(lambda x: gr.update(visible=False), inputs=talk_btn, outputs=talk_btn)
        talk_btn.click(lambda x: gr.update(value=""), inputs=txtinput, outputs=txtinput)
        talk_btn.click(lambda x: gr.update(value="[]"), inputs=jsn2, outputs=jsn2)
        jsn.change(talk, inputs=[jsn, jsn2], outputs=jsn2, api_name="talk")
        jsn2.change(lambda x: gr.update(value=printfmt(x)), inputs=jsn2, outputs=talk_output)
        jsn2.change(lambda x: gr.update(visible=not running), inputs=jsn2, outputs=talk_btn)
        #jsn2.change(lambda x: gr.update(value=x), inputs=jsn2, outputs=jsn)

    demo.queue().launch(server_name="0.0.0.0", server_port=7860, share=False)


if __name__ == "__main__":
    main()