reedmayhew committed
Commit f804d88 · verified · 1 Parent(s): b2f2185

Update app.py

Files changed (1): app.py (+48 -34)
app.py CHANGED
@@ -1,14 +1,12 @@
 import gradio as gr
 import os
 import spaces
-from transformers import GemmaTokenizer, AutoModelForCausalLM
 from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
 from threading import Thread

 # Set an environment variable
 HF_TOKEN = os.environ.get("HF_TOKEN", None)

-
 DESCRIPTION = '''
 <div>
 <h1 style="text-align: center;">DeepSeek-R1-Zero</h1>
@@ -28,7 +26,6 @@ PLACEHOLDER = """
 </div>
 """

-
 css = """
 h1 {
   text-align: center;
@@ -45,7 +42,8 @@ h1 {

 # Load the tokenizer and model
 tokenizer = AutoTokenizer.from_pretrained("reedmayhew/DeepSeek-R1-Refined-Llama-3.1-8B-hf")
-model = AutoModelForCausalLM.from_pretrained("reedmayhew/DeepSeek-R1-Refined-Llama-3.1-8B-hf", device_map="auto") # to("cuda:0")
+model = AutoModelForCausalLM.from_pretrained("reedmayhew/DeepSeek-R1-Refined-Llama-3.1-8B-hf", device_map="auto")
+
 terminators = [
     tokenizer.eos_token_id,
     tokenizer.convert_tokens_to_ids("<|eot_id|>")
@@ -53,10 +51,10 @@ terminators = [

 @spaces.GPU(duration=30)
 def chat_llama3_8b(message: str,
-                   history: list,
-                   temperature: float,
-                   max_new_tokens: int
-                   ) -> str:
+                   history: list,
+                   temperature: float,
+                   max_new_tokens: int
+                   ) -> str:
     """
     Generate a streaming response using the llama3-8b model.
     Args:
@@ -67,24 +65,31 @@ def chat_llama3_8b(message: str,
     Returns:
         str: The generated response.
     """
+
     conversation = []
     for user, assistant in history:
-        conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
+        conversation.extend([
+            {"role": "user", "content": user},
+            {"role": "assistant", "content": assistant}
+        ])
+
+    # Ensure the model starts with "<think>"
     conversation.append({"role": "user", "content": message})
+    conversation.append({"role": "assistant", "content": "<think> "})  # Force <think> at start

     input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt").to(model.device)

     streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)

     generate_kwargs = dict(
-        input_ids= input_ids,
+        input_ids=input_ids,
         streamer=streamer,
         max_new_tokens=max_new_tokens,
         do_sample=True,
         temperature=temperature,
         eos_token_id=terminators,
     )
-    # This will enforce greedy generation (do_sample=False) when the temperature is passed 0, avoiding the crash.
+
     if temperature == 0:
         generate_kwargs['do_sample'] = False

@@ -92,14 +97,34 @@ def chat_llama3_8b(message: str,
     t.start()

     outputs = []
+    buffer = ""
+    think_detected = False
+    thinking_message_sent = False
+    full_response = ""  # Store the full assistant response
+
     for text in streamer:
-        outputs.append(text)
-        #print(outputs)
-        yield "".join(outputs)
-
+        buffer += text
+        full_response += text  # Store raw assistant response (includes <think>)
+
+        # Send the "thinking" message once text starts generating
+        if not thinking_message_sent:
+            thinking_message_sent = True
+            yield "DeepSeek R1 is Thinking...\n\n"
+
+        # Wait until </think> is detected before streaming output
+        if not think_detected:
+            if "</think>" in buffer:
+                think_detected = True
+                buffer = buffer.split("</think>", 1)[1]  # Remove <think> section
+        else:
+            outputs.append(text)
+            yield "".join(outputs)
+
+    # Store the full response (including <think>) in history, but only show the user the cleaned response
+    history.append((message, full_response))  # Full assistant response saved for context

 # Gradio block
-chatbot=gr.Chatbot(height=450, placeholder=PLACEHOLDER, label='Gradio ChatInterface')
+chatbot = gr.Chatbot(height=450, placeholder=PLACEHOLDER, label='Gradio ChatInterface')

 with gr.Blocks(fill_height=True, css=css) as demo:

@@ -110,31 +135,20 @@ with gr.Blocks(fill_height=True, css=css) as demo:
         fill_height=True,
         additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
         additional_inputs=[
-            gr.Slider(minimum=0.6,
-                      maximum=0.6,
-                      step=0.1,
-                      value=0.6,
-                      label="Temperature",
-                      render=False),
-            gr.Slider(minimum=128,
-                      maximum=4096,
-                      step=64,
-                      value=1024,
-                      label="Max new tokens",
-                      render=False ),
-            ],
+            gr.Slider(minimum=0.6, maximum=0.6, step=0.1, value=0.6, label="Temperature", render=False),
+            gr.Slider(minimum=128, maximum=4096, step=64, value=1024, label="Max new tokens", render=False),
+        ],
         examples=[
             ['How to setup a human base on Mars? Give short answer.'],
             ['Explain theory of relativity to me like I’m 8 years old.'],
             ['What is 9,000 * 9,000?'],
             ['Write a pun-filled happy birthday message to my friend Alex.'],
             ['Justify why a penguin might make a good king of the jungle.']
-            ],
+        ],
         cache_examples=False,
-                     )
+    )

     gr.Markdown(LICENSE)
-
+
 if __name__ == "__main__":
-    demo.launch()
-
+    demo.launch()
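
For reference, below is a minimal, self-contained sketch of the </think> gating this commit adds to the streaming loop. It runs against a hard-coded token list instead of TextIteratorStreamer; the fake_stream() generator and its sample tokens are illustrative assumptions, not part of the commit.

# Sketch only: mirrors the commit's buffer / think_detected logic with a
# simulated token stream instead of a real model + TextIteratorStreamer.
def fake_stream():
    # Hidden reasoning first, then the visible answer (illustrative tokens).
    yield from ["<think> ", "reasoning ", "tokens ", "</think>", "Hello", ", ", "world!"]

def stream_visible(token_iter):
    buffer = ""
    outputs = []
    think_detected = False
    thinking_message_sent = False
    for text in token_iter:
        buffer += text
        # The placeholder message is sent once, as soon as generation starts.
        if not thinking_message_sent:
            thinking_message_sent = True
            yield "DeepSeek R1 is Thinking...\n\n"
        # Hold back all output until </think> closes the reasoning section.
        if not think_detected:
            if "</think>" in buffer:
                think_detected = True
                buffer = buffer.split("</think>", 1)[1]
        else:
            outputs.append(text)
            yield "".join(outputs)

for chunk in stream_visible(fake_stream()):
    print(repr(chunk))
# Expected: the "Thinking" placeholder, then 'Hello', 'Hello, ', 'Hello, world!'

As in the commit, any text arriving in the same chunk as </think> is dropped rather than streamed, and only tokens generated after the reasoning section reach the UI.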