xnetba commited on
Commit
f82a3d5
·
1 Parent(s): 13057a9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +290 -0
app.py CHANGED
@@ -0,0 +1,290 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gradio as gr
3
+ from text_generation import Client, InferenceAPIClient
4
+
5
# Seed conversation prepended for models that use plain <human>/<bot>
# chat markup ("Zdravo!" is a Bosnian greeting). Used as the fallback
# pre-prompt by get_usernames().
openchat_preprompt = (
    "\n<human>: Zdravo!\n<bot>: \n"
)
8
+
9
def get_client(model: str):
    """Return a text-generation client for *model*.

    Uses the Hugging Face Inference API, authenticating with the
    HF_TOKEN environment variable when it is set.
    """
    # BUG FIX: the original constructed the client but never returned it,
    # so every caller received None and predict() crashed on
    # client.generate_stream(...).
    return InferenceAPIClient(model, token=os.getenv("HF_TOKEN", None))
11
+
12
def get_usernames(model: str):
    """Return the prompt-formatting markers for *model*.

    Returns:
        (str, str, str, str): pre-prompt, user marker, bot marker, separator
    """
    if model in ("OpenAssistant/oasst-sft-1-pythia-12b", "OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5"):
        return "", "<|prompter|>", "<|assistant|>", "<|endoftext|>"
    # BUG FIX: the original fell off the end and returned None for any
    # other model, making the caller's 4-way unpack raise TypeError.
    # Fall back to plain <human>/<bot> chat markup with the module-level
    # greeting pre-prompt.
    return openchat_preprompt, "<human>: ", "<bot>: ", "\n"
19
+
20
def predict(
    model: str,
    inputs: str,
    typical_p: float,
    top_p: float,
    temperature: float,
    top_k: int,
    repetition_penalty: float,
    watermark: bool,
    chatbot,
    history,
):
    """Stream a model reply, yielding progressively updated chat state.

    Args:
        model: Hugging Face model id selected in the UI.
        inputs: The user's new message.
        typical_p, top_p, temperature, top_k, repetition_penalty, watermark:
            Sampling parameters forwarded to the generation endpoint.
        chatbot: Current list of (user, bot) display pairs (gr.Chatbot value).
        history: Flat list of alternating user/bot utterances (gr.State value);
            mutated in place.

    Yields:
        (chat, history): chat is a list of (user, bot) pairs for the
        Chatbot widget; history is the updated flat utterance list.
    """
    client = get_client(model)
    preprompt, user_name, assistant_name, sep = get_usernames(model)

    history.append(inputs)

    # Rebuild the raw prompt for every prior exchange, re-attaching the
    # role markers in case the displayed strings were stripped.
    past = []
    for user_data, model_data in chatbot:
        if not user_data.startswith(user_name):
            user_data = user_name + user_data
        if not model_data.startswith(sep + assistant_name):
            model_data = sep + assistant_name + model_data
        past.append(user_data + model_data.rstrip() + sep)

    if not inputs.startswith(user_name):
        inputs = user_name + inputs

    total_inputs = preprompt + "".join(past) + inputs + sep + assistant_name.rstrip()

    partial_words = ""

    if model in ("OpenAssistant/oasst-sft-1-pythia-12b", "OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5"):
        # OpenAssistant checkpoints are driven by typical-p decoding only.
        iterator = client.generate_stream(
            total_inputs,
            typical_p=typical_p,
            truncate=1000,
            watermark=watermark,
            max_new_tokens=500,
        )
    else:
        iterator = client.generate_stream(
            total_inputs,
            top_p=top_p if top_p < 1.0 else None,
            top_k=top_k,
            truncate=1000,
            repetition_penalty=repetition_penalty,
            watermark=watermark,
            temperature=temperature,
            max_new_tokens=500,
            stop_sequences=[user_name.rstrip(), assistant_name.rstrip()],
        )

    # Role markers that must not leak into the displayed reply.
    user_marker = user_name.rstrip()
    assistant_marker = assistant_name.rstrip()

    for i, response in enumerate(iterator):
        if response.token.special:
            continue

        partial_words = partial_words + response.token.text
        # BUG FIX: the original used str.rstrip(marker), which strips any
        # trailing *characters* that appear in the marker (mangling
        # legitimate output ending in e.g. '>', '|' or marker letters),
        # not the marker suffix itself. Strip the exact suffix instead.
        if user_marker and partial_words.endswith(user_marker):
            partial_words = partial_words[: -len(user_marker)]
        if assistant_marker and partial_words.endswith(assistant_marker):
            partial_words = partial_words[: -len(assistant_marker)]

        if i == 0:
            history.append(" " + partial_words)
        elif response.token.text not in user_name:
            history[-1] = partial_words

        # Pair up the flat history into (user, bot) tuples for display.
        # (Renamed the index so it no longer shadows the outer loop's `i`.)
        chat = [
            (history[j].strip(), history[j + 1].strip())
            for j in range(0, len(history) - 1, 2)
        ]
        yield chat, history
96
+
97
+
98
def reset_textbox():
    """Produce a Gradio update that blanks the input textbox."""
    cleared = gr.update(value="")
    return cleared
100
+
101
+
102
def radio_on_change(
    value: str,
    disclaimer,
    typical_p,
    top_p,
    top_k,
    temperature,
    repetition_penalty,
    watermark,
):
    """Toggle which sampling controls are visible for the selected model.

    The OpenAssistant checkpoints use typical-p decoding only, so the
    other sliders are hidden; any other selection restores the
    top-p / top-k / temperature / repetition-penalty controls.

    Returns the component updates in the order:
    (disclaimer, typical_p, top_p, top_k, temperature,
     repetition_penalty, watermark).
    """
    oa_models = (
        "OpenAssistant/oasst-sft-1-pythia-12b",
        "OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5",
    )
    if value in oa_models:
        return (
            disclaimer.update(visible=False),
            typical_p.update(value=0.2, visible=True),
            top_p.update(visible=False),
            top_k.update(visible=False),
            temperature.update(visible=False),
            repetition_penalty.update(visible=False),
            watermark.update(False),
        )
    return (
        disclaimer.update(visible=False),
        typical_p.update(visible=False),
        top_p.update(value=0.95, visible=True),
        top_k.update(value=4, visible=True),
        temperature.update(value=0.5, visible=True),
        repetition_penalty.update(value=1.03, visible=True),
        watermark.update(True),
    )
137
+
138
+
139
# --- UI text (Bosnian strings are user-facing; left untranslated) ---
title = """<h1 align="center">LLM Chat</h1>"""
# Usage notes rendered under the chat window.
description = """LLM Chat predložak:
```
User: <utterance>
Assistant: <utterance>
User: <utterance>
Assistant: <utterance>
...
```
Mjenjajući predložak.
"""

# Footer crediting the hosting infrastructure.
text_generation_inference = """
<div align="center">Pokrenuto od: <a href=https://xnet.ba/>xnet.ba</a></div>
"""

# Promotional note; only shown when the `disclaimer` Markdown is made visible.
openchat_disclaimer = """
<div align="center">Checkout the official <a href=https://xnet.ba>xChat app</a> for the full experience.</div>
"""

with gr.Blocks(
    css="""#col_container {margin-left: auto; margin-right: auto;}
                #chatbot {height: 520px; overflow: auto;}"""
) as demo:
    gr.HTML(title)
    gr.Markdown(text_generation_inference, visible=True)
    with gr.Column(elem_id="col_container"):
        # Model selector; radio_on_change() toggles which sliders apply.
        model = gr.Radio(
            value="OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5",
            choices=[
                "OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5",
                "OpenAssistant/oasst-sft-1-pythia-12b",
            ],
            label="Model",
            interactive=True,
        )

        chatbot = gr.Chatbot(elem_id="chatbot")
        inputs = gr.Textbox(
            placeholder="Vozdra raja!", label="Unesi pitanje i pritisni Enter"
        )
        disclaimer = gr.Markdown(openchat_disclaimer, visible=False)
        # Flat alternating user/bot utterance history fed to predict().
        state = gr.State([])
        b1 = gr.Button()

        # Sampling parameters ("Parametri"); initial visibility matches the
        # default OpenAssistant model (typical-p only).
        with gr.Accordion("Parametri", open=False):
            typical_p = gr.Slider(
                minimum=-0,
                maximum=1.0,
                value=0.2,
                step=0.05,
                interactive=True,
                label="Tipična P masa",
            )
            top_p = gr.Slider(
                minimum=-0,
                maximum=1.0,
                value=0.25,
                step=0.05,
                interactive=True,
                label="Top-p (uzorkovanje jezgra)",
                visible=False,
            )
            temperature = gr.Slider(
                minimum=-0,
                maximum=5.0,
                value=0.6,
                step=0.1,
                interactive=True,
                label="Temperatura",
                visible=False,
            )
            top_k = gr.Slider(
                minimum=1,
                maximum=50,
                value=50,
                step=1,
                interactive=True,
                label="Top-k",
                visible=False,
            )
            repetition_penalty = gr.Slider(
                minimum=0.1,
                maximum=3.0,
                value=1.03,
                step=0.01,
                interactive=True,
                label="Kazna za ponavljanje",
                visible=False,
            )
            watermark = gr.Checkbox(value=False, label="Vodeni žig teksta")

        # NOTE(review): radio_on_change receives the component objects and
        # calls their .update() methods — a Gradio 3.x pattern; this wiring
        # breaks on Gradio 4+.
        model.change(
            lambda value: radio_on_change(
                value,
                disclaimer,
                typical_p,
                top_p,
                top_k,
                temperature,
                repetition_penalty,
                watermark,
            ),
            inputs=model,
            outputs=[
                disclaimer,
                typical_p,
                top_p,
                top_k,
                temperature,
                repetition_penalty,
                watermark,
            ],
        )

        # Enter in the textbox and the button both stream predict() output
        # into the chatbot and the state.
        inputs.submit(
            predict,
            [
                model,
                inputs,
                typical_p,
                top_p,
                temperature,
                top_k,
                repetition_penalty,
                watermark,
                chatbot,
                state,
            ],
            [chatbot, state],
        )
        b1.click(
            predict,
            [
                model,
                inputs,
                typical_p,
                top_p,
                temperature,
                top_k,
                repetition_penalty,
                watermark,
                chatbot,
                state,
            ],
            [chatbot, state],
        )
        # Clear the textbox after either submission path.
        b1.click(reset_textbox, [], [inputs])
        inputs.submit(reset_textbox, [], [inputs])

    gr.Markdown(description)
# queue() is required for generator (streaming) handlers;
# concurrency_count is Gradio 3.x API.
demo.queue(concurrency_count=16).launch(debug=True)