Omnibus commited on
Commit
92fbdae
·
verified ·
1 Parent(s): 2fd87d9

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +199 -0
app.py ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from gradio_client import Client
3
+ from huggingface_hub import InferenceClient
4
+ import random
5
+ from datetime import datetime
6
+ #from models import models
7
+ ss_client = Client("https://omnibus-html-image-current-tab.hf.space/")
8
+
9
+ models=[
10
+ "bigcode/starcoder2-15b",
11
+ "bigcode/starcoder2-7b",
12
+ "bigcode/starcoder2-3b",
13
+ ]
14
+
15
+
16
+
17
+ def test_models():
18
+ log_box=[]
19
+ for model in models:
20
+ start_time = datetime.now()
21
+ try:
22
+
23
+ generate_kwargs = dict(
24
+ temperature=0.9,
25
+ max_new_tokens=128,
26
+ top_p=0.9,
27
+ repetition_penalty=1.0,
28
+ do_sample=True,
29
+ seed=111111111,
30
+ )
31
+
32
+ print(f'trying: {model}\n')
33
+ client= InferenceClient(model)
34
+ outp=""
35
+ stream=client.text_generation("What is a cat", **generate_kwargs, stream=True, details=True, return_full_text=True)
36
+ for response in stream:
37
+ outp += response.token.text
38
+ print (outp)
39
+ time_delta = datetime.now() - start_time
40
+ count=time_delta.total_seconds()
41
+ #if time_delta.total_seconds() >= 180:
42
+ log = {"Model":model,"Status":"Success","Output":outp, "Time":count}
43
+ print(f'{log}\n')
44
+ log_box.append(log)
45
+ except Exception as e:
46
+ time_delta = datetime.now() - start_time
47
+ count=time_delta.total_seconds()
48
+
49
+ log = {"Model":model,"Status":"Error","Output":e,"Time":count}
50
+ print(f'{log}\n')
51
+ log_box.append(log)
52
+ yield log_box
53
+
54
+ def format_prompt_default(message, history,cust_p):
55
+ prompt = ""
56
+ if history:
57
+ #<start_of_turn>userHow does the brain work?<end_of_turn><start_of_turn>model
58
+ for user_prompt, bot_response in history:
59
+ prompt += f"{user_prompt}\n"
60
+ print(prompt)
61
+ prompt += f"{bot_response}\n"
62
+ print(prompt)
63
+ #prompt += f"{message}\n"
64
+ prompt+=cust_p.replace("USER_INPUT",message)
65
+ return prompt
66
+
67
+
68
+
69
+ def load_models(inp):
70
+ print(type(inp))
71
+ print(inp)
72
+ print(models[inp])
73
+ model_state= InferenceClient(models[inp])
74
+ out_box=gr.update(label=models[inp])
75
+ prompt_out="USER_INPUT\n"
76
+ return out_box,prompt_out, model_state
77
+
78
+
79
+ VERBOSE=False
80
+
81
+
82
+ def chat_inf(system_prompt,prompt,history,memory,model_state,model_name,seed,temp,tokens,top_p,rep_p,chat_mem,cust_p):
83
+ #token max=8192
84
+ model_n=models[model_name]
85
+ print(model_state)
86
+ hist_len=0
87
+ client=model_state
88
+ if not history:
89
+ history = []
90
+ hist_len=0
91
+ if not memory:
92
+ memory = []
93
+ mem_len=0
94
+ if memory:
95
+ for ea in memory[0-chat_mem:]:
96
+ hist_len+=len(str(ea))
97
+ in_len=len(system_prompt+prompt)+hist_len
98
+
99
+ if (in_len+tokens) > 8000:
100
+ history.append((prompt,"Wait, that's too many tokens, please reduce the 'Chat Memory' value, or reduce the 'Max new tokens' value"))
101
+ yield history,memory
102
+ else:
103
+ generate_kwargs = dict(
104
+ temperature=temp,
105
+ max_new_tokens=tokens,
106
+ top_p=top_p,
107
+ repetition_penalty=rep_p,
108
+ do_sample=True,
109
+ seed=seed,
110
+ )
111
+
112
+ stream = client.text_generation(prompt, **generate_kwargs, stream=True, details=True, return_full_text=True)
113
+ output = ""
114
+ for response in stream:
115
+ output += response.token.text
116
+ yield [(prompt,output)],memory
117
+ history.append((prompt,output))
118
+ memory.append((prompt,output))
119
+ yield history,memory
120
+
121
+
122
+
123
+ def get_screenshot(chat: list,height=5000,width=600,chatblock=[],theme="light",wait=3000,header=True):
124
+ print(chatblock)
125
+ tog = 0
126
+ if chatblock:
127
+ tog = 3
128
+ result = ss_client.predict(str(chat),height,width,chatblock,header,theme,wait,api_name="/run_script")
129
+ out = f'https://omnibus-html-image-current-tab.hf.space/file={result[tog]}'
130
+ print(out)
131
+ return out
132
+
133
+ def clear_fn():
134
+ return None,None,None,None
135
+ rand_val=random.randint(1,1111111111111111)
136
+
137
+ def check_rand(inp,val):
138
+ if inp==True:
139
+ return gr.Slider(label="Seed", minimum=1, maximum=1111111111111111, value=random.randint(1,1111111111111111))
140
+ else:
141
+ return gr.Slider(label="Seed", minimum=1, maximum=1111111111111111, value=int(val))
142
+
143
+ with gr.Blocks() as app:
144
+ model_state=gr.State()
145
+ memory=gr.State()
146
+ gr.HTML("""<center><h1 style='font-size:xx-large;'>Huggingface Hub InferenceClient</h1><br><h3>Chatbot's</h3></center>""")
147
+ chat_b = gr.Chatbot(height=500)
148
+ with gr.Group():
149
+ with gr.Row():
150
+ with gr.Column(scale=3):
151
+
152
+ inp = gr.Textbox(label="Prompt")
153
+ sys_inp = gr.Textbox(label="System Prompt (optional)")
154
+ with gr.Accordion("Prompt Format",open=False):
155
+ custom_prompt=gr.Textbox(label="Modify Prompt Format", info="For testing purposes. 'USER_INPUT' is where 'SYSTEM_PROMPT, PROMPT' will be placed", lines=3,value="<start_of_turn>userUSER_INPUT<end_of_turn><start_of_turn>model")
156
+ with gr.Row():
157
+ with gr.Column(scale=2):
158
+ btn = gr.Button("Chat")
159
+ with gr.Column(scale=1):
160
+ with gr.Group():
161
+ stop_btn=gr.Button("Stop")
162
+ clear_btn=gr.Button("Clear")
163
+ test_btn=gr.Button("Test")
164
+ client_choice=gr.Dropdown(label="Models",type='index',choices=[c for c in models],value=models[0],interactive=True)
165
+ with gr.Column(scale=1):
166
+ with gr.Group():
167
+ rand = gr.Checkbox(label="Random Seed", value=True)
168
+ seed=gr.Slider(label="Seed", minimum=1, maximum=1111111111111111,step=1, value=rand_val)
169
+ tokens = gr.Slider(label="Max new tokens",value=1600,minimum=0,maximum=8000,step=64,interactive=True, visible=True,info="The maximum number of tokens")
170
+ temp=gr.Slider(label="Temperature",step=0.01, minimum=0.01, maximum=1.0, value=0.49)
171
+ top_p=gr.Slider(label="Top-P",step=0.01, minimum=0.01, maximum=1.0, value=0.49)
172
+ rep_p=gr.Slider(label="Repetition Penalty",step=0.01, minimum=0.1, maximum=2.0, value=0.99)
173
+ chat_mem=gr.Number(label="Chat Memory", info="Number of previous chats to retain",value=4)
174
+ with gr.Accordion(label="Screenshot",open=False):
175
+ with gr.Row():
176
+ with gr.Column(scale=3):
177
+ im_btn=gr.Button("Screenshot")
178
+ img=gr.Image(type='filepath')
179
+ with gr.Column(scale=1):
180
+ with gr.Row():
181
+ im_height=gr.Number(label="Height",value=5000)
182
+ im_width=gr.Number(label="Width",value=500)
183
+ wait_time=gr.Number(label="Wait Time",value=3000)
184
+ theme=gr.Radio(label="Theme", choices=["light","dark"],value="light")
185
+ chatblock=gr.Dropdown(label="Chatblocks",info="Choose specific blocks of chat",choices=[c for c in range(1,40)],multiselect=True)
186
+ test_json=gr.JSON(label="Test Output")
187
+ test_btn.click(test_models,None,test_json)
188
+
189
+ client_choice.change(load_models,client_choice,[chat_b,custom_prompt,model_state])
190
+ app.load(load_models,client_choice,[chat_b,custom_prompt,model_state])
191
+
192
+ im_go=im_btn.click(get_screenshot,[chat_b,im_height,im_width,chatblock,theme,wait_time],img)
193
+
194
+ chat_sub=inp.submit(check_rand,[rand,seed],seed).then(chat_inf,[sys_inp,inp,chat_b,memory,model_state,client_choice,seed,temp,tokens,top_p,rep_p,chat_mem,custom_prompt],[chat_b,memory])
195
+ go=btn.click(check_rand,[rand,seed],seed).then(chat_inf,[sys_inp,inp,chat_b,memory,model_state,client_choice,seed,temp,tokens,top_p,rep_p,chat_mem,custom_prompt],[chat_b,memory])
196
+
197
+ stop_btn.click(None,None,None,cancels=[go,im_go,chat_sub])
198
+ clear_btn.click(clear_fn,None,[inp,sys_inp,chat_b,memory])
199
+ app.queue(default_concurrency_limit=10).launch()