basit123796 committed
Commit 47511ce · verified · 1 Parent(s): 3d13b76

Update app.py

Files changed (1)
  1. app.py +113 -54
app.py CHANGED
@@ -4,27 +4,25 @@ from huggingface_hub import InferenceClient

ss_client = Client("https://omnibus-html-image-current-tab.hf.space/")

-models=[
    "google/gemma-7b",
    "google/gemma-7b-it",
    "google/gemma-2b",
    "google/gemma-2b-it"
]
-clients=[
-    InferenceClient(models[0]),
-    InferenceClient(models[1]),
-    InferenceClient(models[2]),
-    InferenceClient(models[3]),
]

-VERBOSE=False

-def load_models(inp):
-    if VERBOSE==True:
-        print(type(inp))
-        print(inp)
-        print(models[inp])
-    return gr.update(label=models[inp])

def format_prompt(message, history):
    prompt = ""
@@ -32,28 +30,34 @@ def format_prompt(message, history):
    for user_prompt, bot_response in history:
        prompt += f"<start_of_turn>user{user_prompt}<end_of_turn>"
        prompt += f"<start_of_turn>model{bot_response}<end_of_turn>"
-    if VERBOSE==True:
        print(prompt)
    prompt += message
    return prompt

-def chat_inf(prompt,history,memory,client_choice,temp,tokens,top_p,rep_p,chat_mem):
-    hist_len=0
-    client=clients[int(client_choice)-1]
    if not history:
        history = []
-        hist_len=0
    if not memory:
        memory = []
-        mem_len=0
    if memory:
-        for ea in memory[0-chat_mem:]:
-            hist_len+=len(str(ea))
-    in_len=len(prompt)+hist_len

-    if (in_len+tokens) > 8000:
-        history.append((prompt,"Wait, that's too many tokens, please reduce the 'Chat Memory' value, or reduce the 'Max new tokens' value"))
-        yield history,memory
    else:
        generate_kwargs = dict(
            temperature=temp,
@@ -62,33 +66,59 @@ def chat_inf(prompt,history,memory,client_choice,temp,tokens,top_p,rep_p,chat_me
            repetition_penalty=rep_p,
            do_sample=True,
        )
-        formatted_prompt = format_prompt(prompt, memory[0-chat_mem:])
-        stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=True)
        output = ""
        for response in stream:
            output += response.token.text
-            yield [(prompt,output)],memory
-        history.append((prompt,output))
-        memory.append((prompt,output))
-        yield history,memory
-
-    if VERBOSE==True:
-        print("\n######### HIST "+str(in_len))
-        print("\n######### TOKENS "+str(tokens))
-
-def get_screenshot(chat: list,height=5000,width=600,chatblock=[],theme="light",wait=3000,header=True):
    tog = 0
    if chatblock:
        tog = 3
-    result = ss_client.predict(str(chat),height,width,chatblock,header,theme,wait,api_name="/run_script")
    out = f'https://omnibus-html-image-current-tab.hf.space/file={result[tog]}'
    return out

def clear_fn():
-    return None,None,None,None

with gr.Blocks() as app:
-    memory=gr.State()
    chat_b = gr.Chatbot(height=500)
    with gr.Group():
        with gr.Row():
@@ -97,17 +127,46 @@ with gr.Blocks() as app:
            btn = gr.Button("Chat")
        with gr.Column(scale=1):
            with gr.Group():
-                temp=gr.Slider(label="Temperature",step=0.01, minimum=0.01, maximum=1.0, value=0.49)
-                tokens = gr.Slider(label="Max new tokens",value=1600,minimum=0,maximum=8000,step=64,interactive=True, visible=True,info="The maximum number of tokens")
-                top_p=gr.Slider(label="Top-P",step=0.01, minimum=0.01, maximum=1.0, value=0.49)
-                rep_p=gr.Slider(label="Repetition Penalty",step=0.01, minimum=0.1, maximum=2.0, value=0.99)
-                chat_mem=gr.Number(label="Chat Memory", info="Number of previous chats to retain",value=4)
-
-    client_choice=gr.Dropdown(label="Models",type='index',choices=[c for c in models],value=models[0],interactive=True)
-    client_choice.change(load_models,client_choice,[chat_b])
-    app.load(load_models,client_choice,[chat_b])
-
-    chat_sub=inp.submit().then(chat_inf,[inp,chat_b,memory,client_choice,temp,tokens,top_p,rep_p,chat_mem],[chat_b,memory])
-    go=btn.click().then(chat_inf,[inp,chat_b,memory,client_choice,temp,tokens,top_p,rep_p,chat_mem],[chat_b,memory])
-
-    app.queue(default_concurrency_limit=10).launch()
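For context on the string those turn markers produce: format_prompt simply wraps each remembered exchange in <start_of_turn>user / <start_of_turn>model markers and then appends the new message. An illustrative example (the history pair and message below are invented, not taken from this commit):

history = [("Hi", "Hello!")]
message = "How are you?"
# format_prompt(message, history) returns:
# "<start_of_turn>userHi<end_of_turn><start_of_turn>modelHello!<end_of_turn>How are you?"
# Note: the usual Gemma chat template also inserts newlines and a trailing
# "<start_of_turn>model" before generation; this helper omits those.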
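The generation call itself goes through huggingface_hub's InferenceClient.text_generation in streaming mode, which is what chat_inf iterates over. A minimal sketch of that pattern outside Gradio (the model name, prompt, and max_new_tokens value are illustrative, and a Hugging Face API token may be required):

from huggingface_hub import InferenceClient

client = InferenceClient("google/gemma-2b-it")  # illustrative model choice

stream = client.text_generation(
    "<start_of_turn>userWrite a haiku about the sea<end_of_turn>",
    max_new_tokens=64,  # illustrative value
    stream=True,        # yield chunks as they are generated
    details=True,       # each chunk exposes token metadata via .token.text
)
for chunk in stream:
    print(chunk.token.text, end="", flush=True)

The updated app.py from this commit follows.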
ss_client = Client("https://omnibus-html-image-current-tab.hf.space/")

+models = [
    "google/gemma-7b",
    "google/gemma-7b-it",
    "google/gemma-2b",
    "google/gemma-2b-it"
]
+clients = [
+    InferenceClient(models[0]),
+    InferenceClient(models[1]),
+    InferenceClient(models[2]),
+    InferenceClient(models[3]),
]

+VERBOSE = False
+
+
+def load_models():
+    return gr.update(label=models[0])

def format_prompt(message, history):
    prompt = ""

    for user_prompt, bot_response in history:
        prompt += f"<start_of_turn>user{user_prompt}<end_of_turn>"
        prompt += f"<start_of_turn>model{bot_response}<end_of_turn>"
+    if VERBOSE:
        print(prompt)
    prompt += message
    return prompt

+
+def chat_inf(prompt, history, memory, temp, tokens, top_p, rep_p, chat_mem):
+    hist_len = 0
+    client = clients[0]
    if not history:
        history = []
+        hist_len = 0
    if not memory:
        memory = []
+        mem_len = 0
    if memory:
+        for ea in memory[0 - chat_mem :]:
+            hist_len += len(str(ea))
+    in_len = len(prompt) + hist_len

+    if (in_len + tokens) > 8000:
+        history.append(
+            (
+                prompt,
+                "Wait, that's too many tokens, please reduce the 'Chat Memory' value, or reduce the 'Max new tokens' value",
+            )
+        )
+        yield history, memory
    else:
        generate_kwargs = dict(
            temperature=temp,

            repetition_penalty=rep_p,
            do_sample=True,
        )
+        formatted_prompt = format_prompt(prompt, memory[0 - chat_mem :])
+        stream = client.text_generation(
+            formatted_prompt,
+            **generate_kwargs,
+            stream=True,
+            details=True,
+            return_full_text=True,
+        )
        output = ""
        for response in stream:
            output += response.token.text
+            yield [(prompt, output)], memory
+        history.append((prompt, output))
+        memory.append((prompt, output))
+        yield history, memory
+
+    if VERBOSE:
+        print("\n######### HIST " + str(in_len))
+        print("\n######### TOKENS " + str(tokens))
+
+
+def get_screenshot(
+    chat: list,
+    height=5000,
+    width=600,
+    chatblock=[],
+    theme="light",
+    wait=3000,
+    header=True,
+):
    tog = 0
    if chatblock:
        tog = 3
+    result = ss_client.predict(
+        str(chat),
+        height,
+        width,
+        chatblock,
+        header,
+        theme,
+        wait,
+        api_name="/run_script",
+    )
    out = f'https://omnibus-html-image-current-tab.hf.space/file={result[tog]}'
    return out

+
def clear_fn():
+    return None, None, None, None
+

with gr.Blocks() as app:
+    memory = gr.State()
    chat_b = gr.Chatbot(height=500)
    with gr.Group():
        with gr.Row():

            btn = gr.Button("Chat")
        with gr.Column(scale=1):
            with gr.Group():
+                temp = gr.Slider(
+                    label="Temperature",
+                    step=0.01,
+                    minimum=0.01,
+                    maximum=1.0,
+                    value=0.49,
+                )
+                tokens = gr.Slider(
+                    label="Max new tokens",
+                    value=1600,
+                    minimum=0,
+                    maximum=8000,
+                    step=64,
+                    interactive=True,
+                    visible=True,
+                    info="The maximum number of tokens",
+                )
+                top_p = gr.Slider(
+                    label="Top-P",
+                    step=0.01,
+                    minimum=0.01,
+                    maximum=1.0,
+                    value=0.49,
+                )
+                rep_p = gr.Slider(
+                    label="Repetition Penalty",
+                    step=0.01,
+                    minimum=0.1,
+                    maximum=2.0,
+                    value=0.99,
+                )
+                chat_mem = gr.Number(
+                    label="Chat Memory",
+                    info="Number of previous chats to retain",
+                    value=4,
+                )
+
+    app.load(load_models)
+    chat_sub = inp.submit().then(
+        chat_inf, [inp, chat_b, memory, temp, tokens, top_p, rep_p, chat_mem], [chat_b, memory]
+    )
+    go = btn.click().then(
+        chat_inf,