Commit 49c3eca · 1 Parent(s): 785ac30
vwxyzjn committed

Update app.py

Files changed (1):
  1. app.py (+32 -31)
app.py CHANGED
@@ -159,7 +159,7 @@ def generate(
             repetition_penalty=repetition_penalty,
             do_sample=True,
             seed=42,
-            stop_sequences=["<call>"]
+            stop_sequences=["<call>", "<submit>"]
         )
         generation_still_running = True
         request_idx = -1
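The functional change in this hunk is the second stop sequence: generation now halts at `<call>` (so the app can run the tool) or at `<submit>` (a finished answer), rather than only at `<call>`. A minimal sketch of the call these kwargs feed, assuming the `text_generation` client used elsewhere in app.py; the endpoint URL and sampling values are placeholders:

```python
from text_generation import Client

client = Client("http://localhost:8080")  # hypothetical endpoint, not the Space's
generate_kwargs = dict(
    temperature=0.9,            # illustrative values; the app takes these from sliders
    max_new_tokens=256,
    top_p=0.95,
    repetition_penalty=1.2,
    do_sample=True,
    seed=42,
    stop_sequences=["<call>", "<submit>"],  # stop at a tool call or a final answer
)
for response in client.generate_stream("Q: Which year did Python 3 ship?", **generate_kwargs):
    print(response.token.text, end="")
```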
@@ -179,13 +179,10 @@ def generate(
         highlighted_output = [
             (prompt, ""),
         ]
-        yield highlighted_output
+        yield highlighted_output, output[generation_start_idx:]
         for response in stream:
             i += 1
-            if response.token.text == "<|endoftext|>":
-                return output
-            else:
-                output += response.token.text
+            output += response.token.text
             tool, query = parse_tool_call(output[generation_start_idx:])
 
             if tool is not None and query is not None:
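The streaming loop now yields a `(highlighted_output, raw_text)` pair because, later in this commit, both event handlers bind `outputs=[output, output2]`; the early `return output` also disappears, which looks intentional, since a bare string is the wrong shape for a `HighlightedText` component and the new `<submit>` stop sequence lets the server end generation instead. A minimal, self-contained sketch of the two-output generator pattern (component names here are illustrative, not app.py's):

```python
import gradio as gr

def stream(prompt):
    text = ""
    for chunk in prompt.split():  # stand-in for the token stream
        text += chunk + " "
        # One value per bound output: spans for HighlightedText, raw string for Code.
        yield [(text, None)], text

with gr.Blocks() as demo:
    box = gr.Textbox(label="Prompt")
    highlighted = gr.HighlightedText(label="Output")
    raw = gr.Code(label="Raw output")
    box.submit(stream, inputs=[box], outputs=[highlighted, raw])

demo.queue()  # generator handlers need the queue in Gradio 3.x
```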
@@ -205,17 +202,24 @@ def generate(
                 call_idx = output[generation_start_idx:].find("<call>")
             if response_idx == -1:
                 response_idx = output[generation_start_idx:].find("<response>")
-
-            # if `<request>` is in the output, highlight it, if `<call>` is in the output, highlight it
-            # print("-------", generation_start_idx, request_idx, call_idx, response_idx)
+            if submit_idx == -1:
+                submit_idx = output[generation_start_idx:].find("<submit>")
+            # I am sorry about the code
+            print("-------", generation_start_idx, request_idx, call_idx, response_idx)
             highlighted_output = [
                 (prompt, ""),
+                (output[generation_start_idx:], "") if request_idx == -1 else ("", ""),
                 (output[generation_start_idx:generation_start_idx+request_idx], ""),
+                (output[generation_start_idx+request_idx:], "") if call_idx == -1 else ("", ""),
                 (output[generation_start_idx+request_idx:generation_start_idx+call_idx], "request"),
-                (output[generation_start_idx+call_idx:-1], "call"),
+                (output[generation_start_idx+call_idx:generation_start_idx+response_idx], "call"),
+                (output[generation_start_idx+response_idx:], "submit") if submit_idx != -1 else ("", ""),
+                # (output[generation_start_idx:generation_start_idx+request_idx], ""),
+                # (output[generation_start_idx+request_idx:generation_start_idx+call_idx], "request"),
+                # (output[generation_start_idx+call_idx:], "call"),
             ]
-            # print(i, highlighted_output, output)
-            yield highlighted_output
+            print(i, highlighted_output, output[generation_start_idx:])
+            yield highlighted_output, output[generation_start_idx:]
 
             # breakpoint()
             call_output = copy.deepcopy(output)
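The extra conditional tuples do inline what a small helper could make explicit: slice the generated text at the first occurrence of each marker, label each slice for `gr.HighlightedText`, and emit empty spans while a marker has not appeared yet (`find` returning `-1`). A hedged sketch of that segmentation; `label_spans` is a hypothetical helper, not a function in app.py:

```python
def label_spans(text: str):
    # Marker -> label for the span that starts at that marker; note the app
    # labels the tail after "<response>" as "submit" (one of the color_map keys).
    markers = [("<request>", "request"), ("<call>", "call"), ("<response>", "submit")]
    spans, prev_idx, prev_label = [], 0, None
    for marker, label in markers:
        idx = text.find(marker)
        if idx == -1:
            continue  # marker not generated yet; its text stays in the previous span
        spans.append((text[prev_idx:idx], prev_label))
        prev_idx, prev_label = idx, label
    spans.append((text[prev_idx:], prev_label))
    return spans

print(label_spans("Q: ...<request>Wiki<call>Python<response>2008<submit>"))
```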
@@ -224,26 +228,23 @@ def generate(
             generate_kwargs["stop_sequences"] = ["<submit>"]
             stream = client.generate_stream(output, **generate_kwargs)
             for response in stream:
-                if response.token.text == "<|endoftext|>":
-                    return output
-                else:
-                    output += response.token.text
+                output += response.token.text
                 if submit_idx == -1:
                     submit_idx = output[generation_start_idx:].find("<submit>")
-                # print("-------", generation_start_idx, request_idx, call_idx, response_idx, submit_idx)
+                # print("-------", generation_start_idx, request_idx, call_idx, response_idx)
                 highlighted_output = [
                     (prompt, ""),
                     (output[generation_start_idx:generation_start_idx+request_idx], ""),
                     (output[generation_start_idx+request_idx:generation_start_idx+call_idx], "request"),
                     (output[generation_start_idx+call_idx:generation_start_idx+response_idx], "call"),
-                    (output[generation_start_idx+response_idx:-1], "submit"),
+                    (output[generation_start_idx+response_idx:], "submit") if submit_idx != -1 else ("", ""),
                 ]
-                # print(highlighted_output, output)
-                yield highlighted_output
-            print("-------", generation_start_idx, request_idx, call_idx, response_idx, submit_idx)
-            print(highlighted_output, output)
+                # print(highlighted_output, output[generation_start_idx:])
+                yield highlighted_output, output[generation_start_idx:]
+            print("-------", generation_start_idx, request_idx, call_idx, response_idx)
+            print(highlighted_output, output[generation_start_idx:])
 
-        return highlighted_output
+        return highlighted_output, output[generation_start_idx:]
     except Exception as e:
         if "loading" in str(e):
             gr.Warning("waiting for model to load... (this could take up to 20 minutes, after which things are much faster)")
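After the tool result is folded into `output`, this second pass streams with `<submit>` as the only stop sequence and simply accumulates tokens; the manual `<|endoftext|>` comparison is gone, presumably because server-side stopping makes it redundant. A short sketch of that resume step, reusing the `client`, `generate_kwargs`, and `output` names from the sketch after the first hunk:

```python
# Resume pass (sketch): only "<submit>" may end this stream.
generate_kwargs["stop_sequences"] = ["<submit>"]
for response in client.generate_stream(output, **generate_kwargs):
    output += response.token.text  # server-enforced stops replace the manual EOS check
```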
@@ -281,14 +282,13 @@ css += share_btn_css + monospace_css + ".gradio-container {color: black}"
 
 description = """
 <div style="text-align: center;">
-    <h1> ⭐ StarCoderBase TriviaQA <span style='color: #e6b800;'>Models</span> Playground</h1>
+    <h1> ⭐ TRL + TextEnvironment <span style='color: #e6b800;'>Models</span> Playground</h1>
 </div>
 <div style="text-align: left;">
-    <p>This is a demo to generate text and code with the following StarCoderBase TriviaQA models:</p>
+    <p>This is a demo to generate text and code with the following StarCoderBase models:</p>
     <ul>
-        <li><a href="https://huggingface.co/bigcode/starcoderplus" style='color: #e6b800;'>StarCoderPlus</a>: A finetuned version of StarCoderBase on English web data, making it strong in both English text and code generation.</li>
-        <li><a href="https://huggingface.co/bigcode/starcoderbase" style='color: #e6b800;'>StarCoderBase</a>: A code generation model trained on 80+ programming languages, providing broad language coverage for code generation tasks.</li>
-        <li><a href="https://huggingface.co/bigcode/starcoder" style='color: #e6b800;'>StarCoderBase TriviaQA</a>: A finetuned version of StarCoderBase specifically focused on Python, while also maintaining strong performance on other programming languages.</li>
+        <li><a href="https://huggingface.co/bigcode/starcoderplus" style='color: #e6b800;'>StarCoderBase TriviaQA</a>: A finetuned version of StarCoderBase on the TriviaQA dataset using reinforcement learning via TRL's TextEnvironment (https://github.com/huggingface/trl/pull/424).</li>
+        <li><a href="https://huggingface.co/bigcode/starcoderbase" style='color: #e6b800;'>StarCoderBase GSM8K</a>: A finetuned version of StarCoderBase on the GSM8K dataset using reinforcement learning via TRL's TextEnvironment (https://github.com/huggingface/trl/pull/424).</li>
     </ul>
     <p><b>Please note:</b> These models are not designed for instruction purposes. If you're looking for instruction or want to chat with a fine-tuned model, you can visit the <a href="https://huggingface.co/spaces/HuggingFaceH4/starchat-playground">StarChat Playground</a>.</p>
 </div>
@@ -326,11 +326,12 @@ with gr.Blocks(theme=theme, analytics_enabled=False, css=css) as demo:
                 elem_id="q-input",
             )
             submit = gr.Button("Generate", variant="primary")
-            # output = gr.Code(elem_id="q-output", lines=30, label="Output")
+            #
             output = gr.HighlightedText(
                 label="Output",
                 color_map={"query": "red", "call": "green", "response": "blue", "submit": "yellow", "model": "pink"},
             )
+            output2 = gr.Code(elem_id="q-output", lines=30, label="Raw output")
         with gr.Row():
             with gr.Column():
                 with gr.Accordion("Advanced settings", open=False):
@@ -387,14 +388,14 @@ with gr.Blocks(theme=theme, analytics_enabled=False, css=css) as demo:
             inputs=[instruction],
             cache_examples=False,
             fn=process_example,
-            outputs=[output],
+            outputs=[output, output2],
         )
         # gr.Markdown(FORMATS)
 
         submit.click(
             generate,
             inputs=[instruction, system_prompt, version, temperature, max_new_tokens, top_p, repetition_penalty],
-            outputs=[output],
+            outputs=[output, output2],
         )
         share_button.click(None, [], [], _js=share_js)
         demo.queue(concurrency_count=16).launch(debug=True)
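With `output2` in both `outputs=` lists, every bound handler, including the non-streaming `process_example`, has to produce one value per component. An illustrative sketch only, since `process_example` itself is not shown in this diff:

```python
# Hypothetical handler shape for outputs=[output, output2].
def process_example(instruction):
    spans = [(instruction, None)]  # value for the HighlightedText component
    return spans, instruction      # plus the raw string for the Code component
```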
 