tree3po committed on
Commit
692f4a4
1 Parent(s): 70be205

Create app.py

Files changed (1)
  1. app.py +283 -0
app.py ADDED
@@ -0,0 +1,283 @@
import os
import random

import gradio as gr
from huggingface_hub import InferenceClient

# Hugging Face token for gated/private models; assumed to be provided via the
# HF_TOKEN environment variable (e.g. a Space secret). It may be None for public models.
token = os.environ.get("HF_TOKEN")

models = [
    "facebook/MobileLLM-125M",
    "facebook/MobileLLM-350M",
    "facebook/MobileLLM-600M",
    "facebook/MobileLLM-1B",
]
client_z = []

def load_models(inp, new_models):
    # The dropdown uses type='index', so `inp` is a list of indices into `new_models`.
    if not new_models:
        new_models = models
    out_box = [gr.Chatbot(), gr.Chatbot(), gr.Chatbot(), gr.Chatbot()]
    client_z.clear()
    for z, idx in enumerate(inp):
        # Each selected model gets a huggingface_hub InferenceClient, which provides
        # the streaming text_generation API used by the chat handlers below.
        client_z.append(InferenceClient(new_models[idx], token=token))
        out_box[z] = gr.update(label=new_models[idx])
    return out_box[0], out_box[1], out_box[2], out_box[3]

def format_prompt_default(message, history):
    # Plain format: previous turns separated by newlines, then the new message.
    prompt = ""
    if history:
        for user_prompt, bot_response in history:
            prompt += f"{user_prompt}\n"
            prompt += f"{bot_response}\n"
    prompt += f"{message}\n"
    return prompt


def format_prompt_gemma(message, history):
    # Gemma chat format: <start_of_turn>user ... <end_of_turn><start_of_turn>model
    prompt = ""
    if history:
        for user_prompt, bot_response in history:
            prompt += f"{user_prompt}\n"
            prompt += f"{bot_response}\n"
    prompt += f"<start_of_turn>user{message}<end_of_turn><start_of_turn>model"
    return prompt


def format_prompt_mixtral(message, history):
    # Mixtral/Mistral instruct format: <s>[INST] ... [/INST]
    prompt = "<s>"
    if history:
        for user_prompt, bot_response in history:
            prompt += f"[INST] {user_prompt} [/INST]"
            prompt += f" {bot_response}</s> "
    prompt += f"[INST] {message} [/INST]"
    return prompt


def format_prompt_choose(message, history, model_name, new_models=None):
    # `model_name` is an index into `new_models` (the dropdown uses type='index').
    if not new_models:
        new_models = models
    if "gemma" in new_models[model_name].lower() and "it" in new_models[model_name].lower():
        return format_prompt_gemma(message, history)
    if "mixtral" in new_models[model_name].lower():
        return format_prompt_mixtral(message, history)
    # Fall back to the plain format for models that match neither template.
    return format_prompt_default(message, history)

mega_hist = [[], [], [], []]


# Alternate handler that keeps per-pane history in `mega_hist`; it is not wired to the UI below.
def chat_inf_tree(system_prompt, prompt, history, client_choice, seed, temp, tokens, top_p, rep_p, hid_val):
    if len(client_choice) >= hid_val:
        client = client_z[int(hid_val) - 1]
        if history:
            mega_hist[int(hid_val) - 1] = history
        generate_kwargs = dict(
            temperature=temp,
            max_new_tokens=tokens,
            top_p=top_p,
            repetition_penalty=rep_p,
            do_sample=True,
            seed=seed,
        )
        formatted_prompt = format_prompt_default(f"{system_prompt}, {prompt}", mega_hist[int(hid_val) - 1])
        stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
        output = ""
        for response in stream:
            output += response.token.text
            yield [(prompt, output)]
        mega_hist[int(hid_val) - 1].append((prompt, output))
        yield mega_hist[int(hid_val) - 1]
    else:
        yield None

def chat_inf(system_prompt, prompt, history, client_choice, seed, temp, tokens, top_p, rep_p, hid_val):
    # Shared streaming chat handler; hid_val (1-4) selects which loaded client/model this pane uses.
    if len(client_choice) >= hid_val:
        if system_prompt:
            system_prompt = f'{system_prompt}, '
        client = client_z[int(hid_val) - 1]
        if not history:
            history = []
        generate_kwargs = dict(
            temperature=temp,
            max_new_tokens=tokens,
            top_p=top_p,
            repetition_penalty=rep_p,
            do_sample=True,
            seed=seed,
        )
        formatted_prompt = format_prompt_choose(f"{system_prompt}{prompt}", history, client_choice[int(hid_val) - 1])
        stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
        output = ""
        for response in stream:
            output += response.token.text
            yield [(prompt, output)]
        history.append((prompt, output))
        yield history
    else:
        yield None

def add_new_model(inp, cur):
    cur.append(inp)
    return cur, gr.update(choices=cur)


def load_new(models=models):
    return models


def clear_fn():
    return None, None, None, None, None, None


rand_val = random.randint(1, 1111111111111111)


def check_rand(inp, val):
    if inp:
        return gr.Slider(label="Seed", minimum=1, maximum=1111111111111111, value=random.randint(1, 1111111111111111))
    else:
        return gr.Slider(label="Seed", minimum=1, maximum=1111111111111111, value=int(val))

with gr.Blocks() as app:
    new_models = gr.State([])
    gr.HTML("""<center><h1 style='font-size:xx-large;'>Chatbot Model Compare</h1></center>""")
    with gr.Row():
        chat_a = gr.Chatbot(height=500)
        chat_b = gr.Chatbot(height=500)
    with gr.Row():
        chat_c = gr.Chatbot(height=500)
        chat_d = gr.Chatbot(height=500)
    with gr.Group():
        with gr.Row():
            with gr.Column(scale=3):
                inp = gr.Textbox(label="Prompt")
                sys_inp = gr.Textbox(label="System Prompt (optional)")
                with gr.Row():
                    with gr.Column(scale=2):
                        btn = gr.Button("Chat")
                    with gr.Column(scale=1):
                        with gr.Group():
                            stop_btn = gr.Button("Stop")
                            clear_btn = gr.Button("Clear")
                client_choice = gr.Dropdown(label="Models", type='index', choices=models, max_choices=4, multiselect=True, interactive=True)
                add_model = gr.Textbox(label="New Model")
                add_btn = gr.Button("Add Model")
            with gr.Column(scale=1):
                with gr.Group():
                    rand = gr.Checkbox(label="Random Seed", value=True)
                    seed = gr.Slider(label="Seed", minimum=1, maximum=1111111111111111, step=1, value=rand_val)
                    tokens = gr.Slider(label="Max new tokens", value=3840, minimum=0, maximum=8000, step=64, interactive=True, visible=True, info="The maximum number of new tokens")
                    temp = gr.Slider(label="Temperature", step=0.01, minimum=0.01, maximum=1.0, value=0.9)
                    top_p = gr.Slider(label="Top-P", step=0.01, minimum=0.01, maximum=1.0, value=0.9)
                    rep_p = gr.Slider(label="Repetition Penalty", step=0.1, minimum=0.1, maximum=2.0, value=1.0)

    # Hidden pane indices (1-4) passed to the shared chat handler.
    hid1 = gr.Number(value=1, visible=False)
    hid2 = gr.Number(value=2, visible=False)
    hid3 = gr.Number(value=3, visible=False)
    hid4 = gr.Number(value=4, visible=False)

    app.load(load_new, None, new_models)
    add_btn.click(add_new_model, [add_model, new_models], [new_models, client_choice])
    client_choice.change(load_models, [client_choice, new_models], [chat_a, chat_b, chat_c, chat_d])

    # Each pane streams from its own selected model and feeds its own history back in.
    go1 = btn.click(check_rand, [rand, seed], seed).then(chat_inf, [sys_inp, inp, chat_a, client_choice, seed, temp, tokens, top_p, rep_p, hid1], chat_a)
    go2 = btn.click(check_rand, [rand, seed], seed).then(chat_inf, [sys_inp, inp, chat_b, client_choice, seed, temp, tokens, top_p, rep_p, hid2], chat_b)
    go3 = btn.click(check_rand, [rand, seed], seed).then(chat_inf, [sys_inp, inp, chat_c, client_choice, seed, temp, tokens, top_p, rep_p, hid3], chat_c)
    go4 = btn.click(check_rand, [rand, seed], seed).then(chat_inf, [sys_inp, inp, chat_d, client_choice, seed, temp, tokens, top_p, rep_p, hid4], chat_d)

    stop_btn.click(None, None, None, cancels=[go1, go2, go3, go4])
    clear_btn.click(clear_fn, None, [inp, sys_inp, chat_a, chat_b, chat_c, chat_d])

app.queue(default_concurrency_limit=10).launch()