sca255 committed on
Commit
b80e447
·
verified ·
1 Parent(s): e4716cd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -19
app.py CHANGED
@@ -1,10 +1,18 @@
1
  import gradio as gr
2
- from huggingface_hub import InferenceClient
 
 
 
 
 
 
 
 
 
3
 
4
  """
5
  For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
6
  """
7
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
8
 
9
 
10
  def respond(
@@ -24,29 +32,53 @@ def respond(
24
  messages.append({"role": "assistant", "content": val[1]})
25
 
26
  messages.append({"role": "user", "content": message})
27
-
 
28
  response = ""
 
 
29
 
30
- for message in client.chat_completion(
31
- messages,
32
- max_tokens=max_tokens,
33
- stream=True,
34
- temperature=temperature,
35
- top_p=top_p,
36
- ):
37
- token = message.choices[0].delta.content
38
-
39
- response += token
40
- yield response
 
 
 
 
 
 
 
 
 
 
41
 
42
  """
43
  For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
44
  """
45
- demo = gr.ChatInterface(
 
 
 
 
 
 
 
 
 
 
 
46
  respond,
47
  additional_inputs=[
48
- gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
49
- gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
50
  gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
51
  gr.Slider(
52
  minimum=0.1,
@@ -56,9 +88,8 @@ demo = gr.ChatInterface(
56
  label="Top-p (nucleus sampling)",
57
  ),
58
  ],
59
- theme=gr.themes.dark
60
  )
61
 
62
 
63
  if __name__ == "__main__":
64
- demo.launch()
 
import gradio as gr
import hf_transfer  # imported for its side availability; accelerates Hub downloads when HF_HUB_ENABLE_HF_TRANSFER is set
from threading import Thread

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteriaList,
    TextIteratorStreamer,
)

# NOTE(review): the model and tokenizer come from two different repos
# ("mistralai/Mistral-7B-v0.1" vs "kubernetes-bad/chargen-v2") — confirm their
# vocabularies are compatible, otherwise decoded output will be garbage.
model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1",
)
tknz = AutoTokenizer.from_pretrained("kubernetes-bad/chargen-v2")
# BUG FIX: the rest of the file refers to the tokenizer as `tokenizer`, but
# only `tknz` was defined (NameError at first request). Keep the original
# name and alias it so both resolve to the same object.
tokenizer = tknz

"""
For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
"""
 
16
 
17
 
18
def respond(
    message,
    history,
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """Stream an assistant reply for *message* given the gradio chat *history*.

    Yields the accumulated response string token-by-token, which is the
    contract `gr.ChatInterface` expects from a streaming callback.

    NOTE(review): the diff hunk hides the original signature and the start of
    the history loop; this reconstruction follows the standard gradio
    ChatInterface template this file was derived from — confirm parameters.
    """
    # Rebuild the OpenAI-style message list from gradio's (user, bot) pairs.
    messages = [{"role": "system", "content": system_message}]
    for val in history:
        if val[0]:
            messages.append({"role": "user", "content": val[0]})
        if val[1]:
            messages.append({"role": "assistant", "content": val[1]})
    messages.append({"role": "user", "content": message})

    # BUG FIX: the original called `tokenizer.build_chat_input(...)` and
    # `tokenizer.get_command(...)` — ChatGLM-specific APIs — and referenced
    # the undefined names `tokenizer` and `stop` (NameError at runtime).
    # Use the portable chat-template API on the tokenizer that IS defined.
    model_inputs = tknz.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    ).to(next(model.parameters()).device)

    streamer = TextIteratorStreamer(tknz, timeout=600, skip_prompt=True)
    generate_kwargs = {
        "input_ids": model_inputs,
        "streamer": streamer,
        "max_new_tokens": max_tokens,
        "do_sample": True,
        "top_p": top_p,
        "temperature": temperature,
        "repetition_penalty": 1,
        "eos_token_id": tknz.eos_token_id,
    }
    # BUG FIX: the original built the Thread but never started it, so the
    # streamer produced nothing and blocked for its full 600 s timeout.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    response = ""
    for new_token in streamer:
        # Trim anything after a user-turn marker that leaks into the stream.
        if new_token and "<|user|>" in new_token:
            new_token = new_token.split("<|user|>")[0]
        if new_token:
            # BUG FIX: the original mutated `history[-1][1]` and yielded
            # `history`, but ChatInterface passes history as tuples
            # (TypeError) and expects the reply string — yield the
            # accumulated response instead, as the previous revision did.
            response += new_token
            yield response
62
 
63
  """
64
  For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
65
  """
66
+ js_func = """
67
+ function refresh() {
68
+ const url = new URL(window.location);
69
+
70
+ if (url.searchParams.get('__theme') !== 'dark') {
71
+ url.searchParams.set('__theme', 'dark');
72
+ window.location.href = url.href;
73
+ }
74
+ }
75
+ """
76
+ app = gr.ChatInterface(
77
+ js=js_func
78
  respond,
79
  additional_inputs=[
80
+ gr.Textbox(value="You are a bot who generates perfect roleplaying charecters.", label="System message"),
81
+ gr.Slider(minimum=1, maximum=, value=512, step=1, label="Max new tokens"),
82
  gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
83
  gr.Slider(
84
  minimum=0.1,
 
88
  label="Top-p (nucleus sampling)",
89
  ),
90
  ],
 
91
  )
92
 
93
 
94
  if __name__ == "__main__":
95
+ app.launch()