ffreemt committed
Commit c48ba74 · 1 Parent(s): 584239a

Update API ready, TODO: fix info

Files changed (1): app.py (+183 −12)
app.py CHANGED
 
@@ -6,6 +6,30 @@ transformers 4.31.0
 import torch
 torch.cuda.empty_cache()
 
+model.chat(
+    tokenizer: transformers.tokenization_utils.PreTrainedTokenizer,
+    query: str,
+    history: Optional[List[Tuple[str, str]]],
+    system: str = 'You are a helpful assistant.',
+    append_history: bool = True,
+    stream: Optional[bool] = <object object at 0x7f905797ec20>,
+    stop_words_ids: Optional[List[List[int]]] = None,
+    **kwargs,
+) -> Tuple[str, List[Tuple[str, str]]]
+
+model.generation_config
+GenerationConfig {
+  "chat_format": "chatml",
+  "do_sample": true,
+  "eos_token_id": 151643,
+  "max_new_tokens": 512,
+  "max_window_size": 6144,
+  "pad_token_id": 151643,
+  "top_k": 0,
+  "top_p": 0.5,
+  "transformers_version": "4.31.0",
+  "trust_remote_code": true
+}
 """
 # pylint: disable=line-too-long, invalid-name, no-member, redefined-outer-name, missing-function-docstring, missing-class-docstring, broad-except,
 import gc
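
Editor's aside: the docstring added above records the Qwen chat API for reference. A minimal standalone sketch of a direct call, assuming the Qwen/Qwen-7B-Chat remote-code model that this app loads (the sketch is not part of the committed file):

```python
# Hypothetical sketch; assumes Qwen/Qwen-7B-Chat with trust_remote_code=True,
# matching the generation-config dump recorded in the docstring above.
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "Qwen/Qwen-7B-Chat"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name, device_map="auto", trust_remote_code=True
).eval()

# Per the signature above: returns (response, updated_history).
response, history = model.chat(tokenizer, "你好!", history=None)
print(response)
```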
 
@@ -14,7 +38,9 @@ import sys
 import time
 from collections import deque
 from dataclasses import asdict, dataclass
+from textwrap import dedent
 from types import SimpleNamespace
+from typing import List, Optional
 
 import gradio as gr
 import torch
 
@@ -100,6 +126,10 @@ model = None
 gc.collect()
 torch.cuda.empty_cache()
 
+if not torch.cuda.is_available():
+    # raise gr.Error("GPU not available, can't run. Turn on GPU and retry")
+    raise SystemExit("GPU not available, can't run. Turn on GPU and retry")
+
 model = gen_model(model_name)
 
 
 
@@ -136,6 +166,7 @@ def bot(chat_history, **kwargs):
         chat_history[:-1].append(["message", str(exc)])
         return chat_history
 
+
 def bot_stream(chat_history, **kwargs):
     logger.trace(f"{chat_history=}")
     logger.trace(f"{kwargs=}")
 
@@ -149,14 +180,17 @@ def bot_stream(chat_history, **kwargs):
 
     # for elm in model.chat_stream(tokenizer, message, chat_history):
     model.generation_config.update(**kwargs)
+    response = ""
     for elm in model.chat_stream(tokenizer, message, chat_history):
         chat_history[-1] = [message, elm]
+        response = elm
         yield chat_history
-    logger.debug(f"response: {elm}")
+    logger.debug(f"{model.generation_config=}")
+    logger.debug(f"{response=}")
 
 
 SYSTEM_PROMPT = "You are a helpful assistant."
-MAX_MAX_NEW_TOKENS = 1024
+MAX_MAX_NEW_TOKENS = 2048  # sequence length 2048
 MAX_NEW_TOKENS = 256
 
 
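Editor's aside on the `response = elm` pattern above: `model.chat_stream` yields progressively longer partial replies (which is also why the loop overwrites `chat_history[-1]` rather than appending), so whatever `response` holds after the loop is the complete reply. A toy illustration, with a made-up generator standing in for `chat_stream`:

```python
# fake_chat_stream is a hypothetical stand-in; Qwen's chat_stream behaves
# the same way: each yield is the cumulative response so far.
def fake_chat_stream():
    text = ""
    for token in ["Hello", ",", " world", "!"]:
        text += token
        yield text

response = ""
for elm in fake_chat_stream():
    response = elm  # always the latest (longest) partial
print(response)  # -> Hello, world!
```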
 
@@ -172,6 +206,72 @@ class Config:
 # stats_default = SimpleNamespace(llm=model, system_prompt=SYSTEM_PROMPT, config=Config())
 stats_default = SimpleNamespace(llm=None, system_prompt=SYSTEM_PROMPT, config=Config())
 
+
+# input_text max_new_tokens temperature repetition_penalty top_k top_p system_prompt history
+def api_fn(  # pylint: disable=too-many-arguments
+    input_text: Optional[str],
+    # max_length: int = 256,
+    max_new_tokens: int = stats_default.config.max_new_tokens,
+    temperature: float = stats_default.config.temperature,
+    repetition_penalty: float = stats_default.config.repetition_penalty,
+    top_k: int = stats_default.config.top_k,
+    top_p: float = stats_default.config.top_p,
+    system_prompt: Optional[str] = None,
+    history: Optional[List[str]] = None,
+):
+    if input_text is None:
+        input_text = ""
+    try:
+        input_text = str(input_text).strip()
+    except Exception as exc:
+        logger.error(exc)
+        input_text = ""
+    if not input_text:
+        return ""
+    if history is None:
+        history = []
+    try:
+        temperature = float(temperature)
+    except Exception:
+        temperature = stats_default.config.temperature
+
+    if system_prompt is None:
+        system_prompt = stats_default.system_prompt
+    # if max_length < 10: max_length = 4096
+    if max_new_tokens < 10:
+        max_new_tokens = stats_default.config.max_new_tokens
+    if top_p < 0.1 or top_p > 1:
+        top_p = stats_default.config.top_p
+    if temperature <= 0.5:
+        temperature = stats_default.config.temperature
+
+    _ = {
+        "max_new_tokens": max_new_tokens,
+        "temperature": temperature,
+        "repetition_penalty": repetition_penalty,
+        "top_k": top_k,
+        "top_p": top_p,
+    }
+    model.generation_config.update(**_)
+    try:
+        res, _ = model.chat(
+            tokenizer,
+            input_text,
+            history=history,
+            # max_length=max_length,
+            append_history=False,
+        )
+        # logger.debug(f"{res=} \n{_=}")
+    except Exception as exc:
+        logger.error(f"{exc=}")
+        res = str(exc)
+
+    logger.debug(f"api {model.generation_config=}")
+    logger.debug(f"api {res=}")
+
+    return res
+
+
 theme = gr.themes.Soft(text_size="sm")
 with gr.Blocks(
     theme=theme,
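
Editor's aside: a hypothetical direct invocation of the `api_fn` added above (it assumes `model` and `tokenizer` are already loaded, as elsewhere in app.py), showing how out-of-range values are clamped back to the defaults:

```python
# Illustrative only; argument names mirror api_fn's signature above.
reply = api_fn(
    "Translate to English: 你好!",
    max_new_tokens=256,
    temperature=0.45,        # <= 0.5, so reset to stats_default.config.temperature
    repetition_penalty=1.1,
    top_k=0,
    top_p=0.9,
    system_prompt=None,      # falls back to stats_default.system_prompt
    history=None,            # treated as an empty history
)
print(reply)
```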
 
@@ -179,24 +279,69 @@ with gr.Blocks(
     css=css,
 ) as block:
     stats = gr.State(stats_default)
-    if not torch.cuda.is_available():
-        raise gr.Error("GPU not available, cant run. Turn on GPU and restart")
 
+    # would this reset the model?
+    model.generation_config = GenerationConfig.from_pretrained(
+        model_name,
+        trust_remote_code=True,
+    )
     config = asdict(stats.value.config)
+
     def bot_stream_state(chat_history):
         logger.trace(f"{chat_history=}")
         yield from bot_stream(chat_history, **config)
 
     with gr.Accordion("🎈 Info", open=False):
         gr.Markdown(
-            f"""<h5><center>{model_name.lower()}</center></h4>
-            Set `repetition_penalty` to 2.1 or higher for a chatty conversation. Lower it to 1.1 or smaller if more focused anwsers are desired (for example for translations or fact-oriented queries). Smaller `top_k` probably will result in smoothier sentences.
-            (`top_k=0` is equivalent to `top_k` equal to very very big though.) Consult `transformers` documentation for more details.
-
-
-            Most examples are meant for another model.
-            You probably should try to test
-            some related prompts.""",
+            dedent(
+                f"""
+                ## {model_name.lower()}
+
+                * temperature range: .51 and up; a higher temperature implies more random output. Around 1.1 is suggested for chatting and creative writing, while 0.51-1.0 suits, for example, summarizing and translation.
+
+                * Set `repetition_penalty` to 2.1 or higher for a chatty conversation (more unpredictable and undesirable output). Lower it to 1.1 or smaller if more focused answers are desired (for example for translations or fact-oriented queries).
+
+                * A smaller `top_k` will probably result in smoother sentences.
+                (`top_k=0` is effectively equivalent to a very large `top_k` though.) Consult the `transformers` documentation for more details.
+
+                * If you inadvertently messed up the parameters or the model, reset it in Advanced Options or reload the browser.
+
+                <p></p>
+                An API is available at https://mikeee-qwen-7b-chat.hf.space/, e.g. in Python:
+
+                ```python
+                from gradio_client import Client
+
+                client = Client("https://mikeee-qwen-7b-chat.hf.space/")
+                result = client.predict(
+                    "你好!",  # user prompt
+                    256,  # max_new_tokens
+                    0.951,  # temperature
+                    1.1,  # repetition_penalty
+                    0,  # top_k
+                    0.9,  # top_p
+                    "You are a helpful assistant.",  # system_prompt
+                    None,  # history
+                    api_name="/api",
+                )
+                print(result)
+                ```
+
+                or in JavaScript:
+                ```js
+                import {{ client }} from "@gradio/client";
+
+                const app = await client("https://mikeee-qwen-7b-chat.hf.space/");
+                const result = await app.predict("/api", [...]);
+                console.log(result.data);
+                ```
+                Check the documentation and examples by clicking `Use via API` at the very bottom of this page.
+
+                <p></p>
+                Most examples are meant for another model.
+                You probably should test some related prompts of your own."""
+            ),
             elem_classes="xsmall",
         )
 
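
Editor's aside on the `# would this reset the model?` question left in the code above: replacing `model.generation_config` swaps decoding defaults only (sampling flags, token limits); the model weights are untouched. A small sketch under that assumption, using the defaults from the config dump in the module docstring:

```python
# Sketch: runtime tweaks are discarded once the shipped config is reloaded.
from transformers import GenerationConfig

model.generation_config.update(top_p=0.95)  # runtime tweak
model.generation_config = GenerationConfig.from_pretrained(
    model_name, trust_remote_code=True  # back to shipped defaults
)
assert model.generation_config.top_p == 0.5  # per the dump at the top of app.py
```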
 
 
@@ -367,5 +512,31 @@ with gr.Blocks(
         elem_classes=["disclaimer"],
     )
 
+    with gr.Accordion("For Chat/Translation API", open=False, visible=False):
+        input_text = gr.Text()
+        api_history = gr.Chatbot(value=[])
+        api_btn = gr.Button("Go", variant="primary")
+        out_text = gr.Text()
+
+        # api_fn argument order:
+        # input_text max_new_tokens temperature repetition_penalty top_k top_p system_prompt history
+        api_btn.click(
+            api_fn,
+            [
+                input_text,
+                max_new_tokens,
+                temperature,
+                repetition_penalty,
+                top_k,
+                top_p,
+                system_prompt,
+                api_history,  # don't know how to pass this in gradio_client.Client calls
+            ],
+            out_text,
+            api_name="api",
+        )
+
+
 if __name__ == "__main__":
+    logger.info("Just record start time")
     block.queue(max_size=8).launch(debug=True)
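
Editor's aside: since `api_name="api"` registers the click handler as a named endpoint, gradio_client exposes it as `/api`. A hedged smoke test mirroring the example in the Info accordion (the `history` slot is passed as `None`, per the open question in the comment above):

```python
from gradio_client import Client

client = Client("https://mikeee-qwen-7b-chat.hf.space/")
result = client.predict(
    "Hello!",  # input_text
    256,       # max_new_tokens
    0.95,      # temperature
    1.1,       # repetition_penalty
    0,         # top_k
    0.9,       # top_p
    "You are a helpful assistant.",  # system_prompt
    None,      # history (Chatbot input; see the open comment above)
    api_name="/api",
)
print(result)
```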