ruggsea committed
Commit f97c568 · 1 Parent(s): d6f5bef

Fixing the chat history

Files changed (1)
  1. app.py +39 -31
app.py CHANGED
@@ -24,9 +24,6 @@ LICENSE = """
 As a derivative work of Llama 3.1, this demo is governed by the original Meta license and acceptable use policy.
 """
 
-if not torch.cuda.is_available():
-    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
-
 # Initialize model and tokenizer
 if torch.cuda.is_available():
     model_id = "ruggsea/Llama3.1-Instruct-SEP-Chat"
@@ -35,30 +32,30 @@ if torch.cuda.is_available():
     tokenizer.use_default_system_prompt = False
 
 @spaces.GPU
-def generate(
-    message: str,
-    chat_history: list[tuple[str, str]],
+def user(user_message: str, history: list, system_prompt: str) -> tuple[str, list]:
+    """Add user message to history"""
+    if history is None:
+        history = []
+    history.append({"role": "user", "content": user_message.strip()})
+    return "", history
+
+@spaces.GPU
+def bot(
+    history: list,
     system_prompt: str,
     max_new_tokens: int = 1024,
     temperature: float = 0.7,
     top_p: float = 0.9,
     top_k: int = 50,
     repetition_penalty: float = 1.1,
-) -> Iterator[list[tuple[str, str]]]:
-    if chat_history is None:
-        chat_history = []
-
+) -> Iterator[list]:
+    """Generate bot response"""
     conversation = []
     if system_prompt:
         conversation.append({"role": "system", "content": system_prompt})
 
-    for user, assistant in chat_history:
-        conversation.extend([
-            {"role": "user", "content": str(user).strip()},
-            {"role": "assistant", "content": str(assistant).strip()}
-        ])
-
-    conversation.append({"role": "user", "content": str(message).strip()})
+    for message in history:
+        conversation.append(message)
 
     try:
         input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
@@ -83,17 +80,15 @@ def generate(
         t = Thread(target=model.generate, kwargs=generate_kwargs)
         t.start()
 
-        outputs = []
+        history.append({"role": "assistant", "content": ""})
         for text in streamer:
-            outputs.append(text)
-            partial_output = "".join(outputs)
-            chat_history = chat_history + [(message, partial_output)]
-            yield chat_history
+            history[-1]["content"] += text
+            yield history
 
     except Exception as e:
         gr.Warning(f"Error during generation: {str(e)}")
-        chat_history = chat_history + [(message, "I apologize, but I encountered an error. Please try again.")]
-        yield chat_history
+        history.append({"role": "assistant", "content": "I apologize, but I encountered an error. Please try again."})
+        yield history
 
 def create_demo() -> gr.Blocks:
     with gr.Blocks(css="style.css") as demo:
@@ -109,6 +104,7 @@ def create_demo() -> gr.Blocks:
         chatbot = gr.Chatbot(
            show_label=False,
            avatar_images=(None, None),
+           bubble_full_width=False,
        )
 
        with gr.Row():
@@ -119,7 +115,7 @@ def create_demo() -> gr.Blocks:
                container=False,
            )
            submit = gr.Button("Submit", scale=1, variant="primary")
-
+
        system_prompt = gr.Textbox(
            label="System prompt",
            lines=6,
@@ -177,15 +173,27 @@ def create_demo() -> gr.Blocks:
            cache_examples=False,
        )
 
+       # Chain the user and bot responses
        msg.submit(
-           generate,
-           [msg, chatbot, system_prompt, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
-           [chatbot],
+           user,
+           [msg, chatbot, system_prompt],
+           [msg, chatbot],
+           queue=False
+       ).then(
+           bot,
+           [chatbot, system_prompt, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
+           chatbot
        )
+
        submit.click(
-           generate,
-           [msg, chatbot, system_prompt, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
-           [chatbot],
+           user,
+           [msg, chatbot, system_prompt],
+           [msg, chatbot],
+           queue=False
+       ).then(
+           bot,
+           [chatbot, system_prompt, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
+           chatbot
        )
 
        gr.Markdown(LICENSE)
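
The diff replaces the tuple-based chat history with a list of role/content message dicts and splits each turn into a fast `user` step chained via `.then()` to a streaming `bot` step. Below is a minimal, self-contained sketch of that pattern, not the Space's actual app.py: the model call is stubbed with a fake streamer so it runs without a GPU, and `type="messages"` on the Chatbot is an assumption about the Gradio version in use.

import time
import gradio as gr


def user(user_message: str, history: list) -> tuple[str, list]:
    """Append the user's message to the history and clear the textbox."""
    history = history or []
    history.append({"role": "user", "content": user_message.strip()})
    return "", history


def bot(history: list):
    """Stream an assistant reply piece by piece into the last history entry."""
    reply = "This is a placeholder reply."  # stand-in for the model's token streamer
    history.append({"role": "assistant", "content": ""})
    for token in reply.split():
        history[-1]["content"] += token + " "
        time.sleep(0.05)
        yield history


with gr.Blocks() as demo:
    chatbot = gr.Chatbot(type="messages", show_label=False)  # assumes Gradio 4.x messages format
    msg = gr.Textbox(container=False)
    # Chain: first record the user turn (unqueued, so the textbox clears immediately),
    # then stream the bot turn into the same history list.
    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot, chatbot, chatbot
    )

if __name__ == "__main__":
    demo.launch()

Because both handlers read and yield the same dict-based history, the streamed partial replies land in the last assistant entry instead of repeatedly appending new (message, reply) tuples, which is the chat-history bug this commit fixes.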