Samuel L Meyers committed on
Commit 5033222 · 1 Parent(s): ad039da
Files changed (1)
  1. code/app.py +26 -13
code/app.py CHANGED
@@ -17,15 +17,17 @@ model_path = "./starling-lm-7b-alpha.Q6_K.gguf"
 mdlpath = hf_hub_download(repo_id="TheBloke/Starling-LM-7B-alpha-GGUF", filename=model_path, local_dir="./")
 
 lcpp_model = Llama(model_path=model_path)
-global otxt, txtinput, txtoutput
+global otxt, txtinput, txtoutput, running, result
 otxt = ""
+running = False
+result = None
 
 def stowtext(curr, inp):
     curr.append({
         "role": "user",
         "content": inp,
     })
-    return [curr, curr]
+    return curr
 
 def stowchunk(curr, inp):
     first = curr[-1]["role"] == "user"
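
This first hunk moves the app toward streaming: `running` and `result` become module-level state shared between the generator and the UI callbacks, and `stowtext` now returns the history once, since the rewired click handler below writes to a single `gr.JSON` output instead of two. A small self-contained sketch of the message-list shape it maintains (the values here are illustrative only):

```python
# Illustrative only: the OpenAI-style chat history stowtext appends to,
# which llama-cpp-python's create_chat_completion() consumes directly.
history = []

def stowtext(curr, inp):
    curr.append({"role": "user", "content": inp})
    return curr  # one value for the single gr.JSON output (was [curr, curr])

history = stowtext(history, "Hello!")
print(history)  # [{'role': 'user', 'content': 'Hello!'}]
```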
@@ -49,21 +51,27 @@ def printfmt(jsn):
         txt += "# " + msg["content"] + "\n\n"
     return txt
 
-def talk(txt):
-    result = lcpp_model.create_chat_completion(messages=txt, stop=["</s>", "<|end_of_text|>", "GPT4 User: ", "<|im_sep|>", "\n\n"], stream=True)
+def talk(txt, jsn):
+    global running, result
+    if not jsn:
+        jsn = txt
+    if not running:
+        result = lcpp_model.create_chat_completion(messages=txt,stream=True)
+        running = True
     for r in result:
         txt2 = None
         if "content" in r["choices"][0]["delta"]:
             txt2 = r["choices"][0]["delta"]["content"]
-            if txt2.startswith("\n"):
-                txt2 = txt2[1:]
+        elif not "content" in r["choices"][0]["delta"] and not "role" in r["choices"][0]["delta"]:
+            running = False
+            yield txt
         if txt2 is not None:
             txt = stowchunk(txt, txt2)
-            yield [printfmt(txt), txt]
-    yield [printfmt(txt), txt]
+            yield txt
+    yield txt
 
 def main():
-    global otxt, txtinput
+    global otxt, txtinput, running
     logging.basicConfig(level=logging.INFO)
 
     with gr.Blocks() as demo:
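
The reworked `talk()` now decides everything from the shape of each streamed delta. Assuming llama-cpp-python's OpenAI-compatible stream format (a `role`-only delta first, `content` deltas in the middle, and a final chunk whose delta carries neither key), the new `elif` branch is the end-of-stream detector that clears `running`. A self-contained sketch, with fake chunks standing in for the real stream:

```python
# Fake chunks in the assumed OpenAI-style shape that
# create_chat_completion(stream=True) yields; not real model output.
fake_stream = [
    {"choices": [{"delta": {"role": "assistant"}}]},        # first chunk
    {"choices": [{"delta": {"content": "Hel"}}]},           # tokens...
    {"choices": [{"delta": {"content": "lo"}}]},
    {"choices": [{"delta": {}, "finish_reason": "stop"}]},  # terminator
]

running, text = True, ""
for r in fake_stream:
    delta = r["choices"][0]["delta"]
    if "content" in delta:
        text += delta["content"]  # accumulate, as stowchunk does
    elif "role" not in delta:
        running = False           # empty delta: generation finished
print(text, running)              # -> Hello False
```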
@@ -76,12 +84,17 @@ def main():
         with gr.Row(variant="panel"):
             talk_btn = gr.Button("Send")
         with gr.Row(variant="panel"):
-            jsn = gr.JSON(visible=False, value="[]")
-            jsn2 = gr.JSON(visible=False, value="[]")
+            jsn = gr.JSON(visible=True, value="[]")
+            jsn2 = gr.JSON(visible=True, value="[]")
 
-        talk_btn.click(stowtext, inputs=[jsn2, txtinput], outputs=[jsn, jsn2], api_name="talk")
+        talk_btn.click(stowtext, inputs=[jsn2, txtinput], outputs=jsn, api_name="talk")
+        talk_btn.click(lambda x: gr.update(visible=False), inputs=talk_btn, outputs=talk_btn)
         talk_btn.click(lambda x: gr.update(value=""), inputs=txtinput, outputs=txtinput)
-        jsn.change(talk, inputs=jsn, outputs=[talk_output, jsn2], api_name="talk")
+        talk_btn.click(lambda x: gr.update(value="[]"), inputs=jsn2, outputs=jsn2)
+        jsn.change(talk, inputs=[jsn, jsn2], outputs=jsn2, api_name="talk")
+        jsn2.change(lambda x: gr.update(value=printfmt(x)), inputs=jsn2, outputs=talk_output)
+        jsn2.change(lambda x: gr.update(visible=not running), inputs=jsn2, outputs=talk_btn)
+        #jsn2.change(lambda x: gr.update(value=x), inputs=jsn2, outputs=jsn)
 
     demo.queue().launch(server_name="0.0.0.0", server_port=7860, share=True)
 
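
Taken together, the new wiring is a `.change()`-driven pipeline: the click stores the user message in `jsn`, `jsn.change` runs the `talk` generator that streams history snapshots into `jsn2`, and each `jsn2.change` re-renders the Markdown and re-shows the Send button once `running` clears. A minimal sketch of that event-chaining pattern (component names here are illustrative; Gradio 3.x `Blocks` semantics assumed):

```python
import gradio as gr

with gr.Blocks() as demo:
    box = gr.Textbox(label="Message")
    stage = gr.JSON(value=[])  # intermediate pipeline stage
    out = gr.Markdown()
    btn = gr.Button("Send")

    # Step 1: the click only writes into the JSON component...
    btn.click(lambda t: [{"role": "user", "content": t}],
              inputs=box, outputs=stage)
    # Step 2: ...whose change event then fires the next stage by itself.
    stage.change(lambda j: "# " + j[-1]["content"], inputs=stage, outputs=out)

demo.queue().launch()
```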
 
 