m-ric HF staff commited on
Commit
4fde691
1 Parent(s): e20ac5c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -6
app.py CHANGED
@@ -371,9 +371,10 @@ def get_beam_search_html(input_text, number_steps, number_beams, length_penalty)
371
  output_scores=True,
372
  do_sample=False,
373
  )
374
- print("Sequences:")
375
- print(tokenizer.batch_decode(outputs.sequences))
376
- print("Scores:", outputs.sequences_scores)
 
377
 
378
  original_tree = generate_beams(
379
  input_text,
@@ -382,7 +383,7 @@ def get_beam_search_html(input_text, number_steps, number_beams, length_penalty)
382
  length_penalty,
383
  )
384
  html = generate_html(input_text, original_tree)
385
- return html
386
 
387
 
388
  with gr.Blocks(
@@ -391,13 +392,24 @@ with gr.Blocks(
391
  ),
392
  css=STYLE,
393
  ) as demo:
 
 
 
 
 
 
 
 
 
 
394
  text = gr.Textbox(label="Sentence to decode from", value="Today is")
395
  with gr.Row():
396
  steps = gr.Slider(label="Number of steps", minimum=1, maximum=8, step=1, value=4)
397
  beams = gr.Slider(label="Number of beams", minimum=2, maximum=4, step=1, value=3)
398
  length_penalty = gr.Slider(label="Length penalty", minimum=-5, maximum=5, step=0.5, value=1)
399
  button = gr.Button()
400
- out = gr.Markdown(label="Output")
401
- button.click(get_beam_search_html, inputs=[text, steps, beams, length_penalty], outputs=out)
 
402
 
403
  demo.launch()
 
371
  output_scores=True,
372
  do_sample=False,
373
  )
374
+ markdown = "Sequences:"
375
+ decoded_sequences = tokenizer.batch_decode(outputs.sequences)
376
+ for i, sequence in enumerate(decoded_sequences):
377
+ markdown += f"\n- {sequence} ( score {outputs.sequences_scores[i]:.2f})"
378
 
379
  original_tree = generate_beams(
380
  input_text,
 
383
  length_penalty,
384
  )
385
  html = generate_html(input_text, original_tree)
386
+ return html, markdown
387
 
388
 
389
  with gr.Blocks(
 
392
  ),
393
  css=STYLE,
394
  ) as demo:
395
+ gr.Markdown("""# Beam search visualizer
396
+
397
+ Play with the parameters below to understand how beam search decoding works!
398
+
399
+ #### Parameters:
400
+ - **Sentence to decode from**: the input sequence to your decoder.
401
+ - **Number of steps**: the number of tokens to generate
402
+ - **Number of beams**: the number of beams to use
403
+ - **Length penalty**: the length penalty to apply to outputs. `length_penalty` > 0.0 promotes longer sequences, while `length_penalty` < 0.0 encourages shorter sequences.
404
+ """)
405
  text = gr.Textbox(label="Sentence to decode from", value="Today is")
406
  with gr.Row():
407
  steps = gr.Slider(label="Number of steps", minimum=1, maximum=8, step=1, value=4)
408
  beams = gr.Slider(label="Number of beams", minimum=2, maximum=4, step=1, value=3)
409
  length_penalty = gr.Slider(label="Length penalty", minimum=-5, maximum=5, step=0.5, value=1)
410
  button = gr.Button()
411
+ out_html = gr.Markdown()
412
+ out_markdown = gr.Markdown()
413
+ button.click(get_beam_search_html, inputs=[text, steps, beams, length_penalty], outputs=[out_html, out_markdown])
414
 
415
  demo.launch()