yuchenlin committed (verified)
Commit f60e921 · Parent(s): e7e3b25

Update app.py

Files changed (1):
  1. app.py  +34 -40
app.py CHANGED
@@ -2,6 +2,7 @@ import gradio as gr
 from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
 import spaces
 from threading import Thread
+from typing import Iterator
 
 # Load model and tokenizer
 model_name = "Magpie-Align/MagpieLM-4B-Chat-v0.1"
@@ -16,27 +17,25 @@ model.to(device)
 
 MAX_INPUT_TOKEN_LENGTH = 4096  # You may need to adjust this value
 
-@spaces.GPU(enable_queue=True)
-def respond(
-    message,
-    history: list[tuple[str, str]],
-    system_message,
-    max_tokens=2048,
-    temperature=0.6,
-    top_p=0.9,
-    repetition_penalty=1.0,
-):
-    messages = [{"role": "system", "content": system_message}]
-
-    for val in history:
-        if val[0]:
-            messages.append({"role": "user", "content": val[0]})
-        if val[1]:
-            messages.append({"role": "assistant", "content": val[1]})
-
-    messages.append({"role": "user", "content": message})
-
-    input_ids = tokenizer.apply_chat_template(messages, return_tensors="pt")
+@spaces.GPU
+def generate(
+    message: str,
+    chat_history: list[tuple[str, str]],
+    system_prompt: str,
+    max_new_tokens: int = 1024,
+    temperature: float = 0.6,
+    top_p: float = 0.9,
+    top_k: int = 50,
+    repetition_penalty: float = 1.2,
+) -> Iterator[str]:
+    conversation = []
+    if system_prompt:
+        conversation.append({"role": "system", "content": system_prompt})
+    for user, assistant in chat_history:
+        conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
+    conversation.append({"role": "user", "content": message})
+
+    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
     if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
         input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
         gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
@@ -46,36 +45,31 @@ def respond(
     generate_kwargs = dict(
         input_ids=input_ids,
         streamer=streamer,
-        max_new_tokens=max_tokens,
+        max_new_tokens=max_new_tokens,
         do_sample=True,
         top_p=top_p,
+        top_k=top_k,
         temperature=temperature,
+        num_beams=1,
         repetition_penalty=repetition_penalty,
     )
-
-    thread = Thread(target=model.generate, kwargs=generate_kwargs)
-    thread.start()
-
-    def stream():
-        for text in streamer:
-            yield text
+    t = Thread(target=model.generate, kwargs=generate_kwargs)
+    t.start()
 
-    return stream()
+    outputs = []
+    for text in streamer:
+        outputs.append(text)
+        yield "".join(outputs)
 
 demo = gr.ChatInterface(
-    respond,
+    generate,
     additional_inputs=[
-        gr.Textbox(value="You are Magpie, a friendly Chatbot.", label="System message"),
-        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
+        gr.Textbox(value="You are Magpie, a friendly Chatbot.", label="System prompt"),
+        gr.Slider(minimum=1, maximum=2048, value=1024, step=1, label="Max new tokens"),
         gr.Slider(minimum=0.1, maximum=4.0, value=0.6, step=0.1, label="Temperature"),
-        gr.Slider(
-            minimum=0.1,
-            maximum=1.0,
-            value=0.9,
-            step=0.05,
-            label="Top-p (nucleus sampling)",
-        ),
-        gr.Slider(minimum=0.5, maximum=1.5, value=1.0, step=0.1, label="Repetition Penalty"),
+        gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top-p"),
+        gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Top-k"),
+        gr.Slider(minimum=0.5, maximum=2.0, value=1.2, step=0.1, label="Repetition Penalty"),
     ],
 )
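A note on prompt assembly: the new loop assumes every history entry is a complete (user, assistant) pair, where the old respond() skipped empty slots with if val[0]: / if val[1]:. Either way the result is the alternating role list that tokenizer.apply_chat_template expects, and the system message is now included only when non-empty. A self-contained sketch with made-up turns:

# Made-up example inputs in gr.ChatInterface's tuple-history format.
system_prompt = "You are Magpie, a friendly Chatbot."
chat_history = [("Hi!", "Hello! How can I help you today?")]
message = "Tell me a joke."

conversation = []
if system_prompt:  # skipped entirely when the system textbox is cleared
    conversation.append({"role": "system", "content": system_prompt})
for user, assistant in chat_history:
    conversation.extend([
        {"role": "user", "content": user},
        {"role": "assistant", "content": assistant},
    ])
conversation.append({"role": "user", "content": message})

for turn in conversation:
    print(f'{turn["role"]}: {turn["content"]}')
# system: You are Magpie, a friendly Chatbot.
# user: Hi!
# assistant: Hello! How can I help you today?
# user: Tell me a joke.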
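The token-budget guard is unchanged by this commit: the negative slice keeps the last MAX_INPUT_TOKEN_LENGTH tokens, so the oldest turns fall out of context first. A toy illustration with a deliberately tiny budget:

import torch

MAX_INPUT_TOKEN_LENGTH = 8  # tiny value for illustration only

input_ids = torch.arange(12).unsqueeze(0)  # pretend prompt of 12 token ids, shape (1, 12)
if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
    input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]  # keep the newest tokens
print(input_ids)
# tensor([[ 4,  5,  6,  7,  8,  9, 10, 11]])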
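The user-visible fix is the streaming loop. gr.ChatInterface renders each value yielded by its callback as the complete bot message so far, so the old stream() helper, which yielded each fragment from TextIteratorStreamer on its own, would have displayed only the newest fragment instead of a growing reply (and respond() itself was a plain function returning a generator rather than a generator function, which Gradio's streaming path may not detect). A minimal sketch of the corrected pattern, with a hypothetical queue-based producer standing in for model.generate and the streamer:

from queue import Queue
from threading import Thread

def fake_model_generate(q: Queue) -> None:
    # Hypothetical stand-in for model.generate feeding a TextIteratorStreamer:
    # emit text fragments, then a sentinel to end the stream.
    for fragment in ["Magpie ", "says ", "hello."]:
        q.put(fragment)
    q.put(None)

def generate():
    q: Queue = Queue()
    Thread(target=fake_model_generate, args=(q,)).start()
    outputs = []
    while (fragment := q.get()) is not None:
        outputs.append(fragment)
        yield "".join(outputs)  # full response so far, not just the delta

for partial in generate():
    print(partial)
# Magpie
# Magpie says
# Magpie says hello.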