rjiang12 commited on
Commit
65a6fca
·
1 Parent(s): aadd5d4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -3
app.py CHANGED
@@ -2,6 +2,7 @@ import gradio as gr
2
  #from transformers import
3
  import tensorflow as tf
4
  from transformers import TFGPT2LMHeadModel, GPT2Tokenizer
 
5
 
6
  #generator = pipeline('text-generation', model='gpt2')
7
 
@@ -18,15 +19,17 @@ def func(sentence, max_length, temperature):
18
  output_list = model.generate(
19
  input_ids,
20
  do_sample=True,
21
- max_length=max_length,
 
22
  top_p=0.92,
23
  top_k=0
 
24
  )
25
  output_strs = [tokenizer.decode(output, skip_special_tokens=True) for output in output_list]
26
- return output_strs[0]
27
 
28
 
29
- demo = gr.Interface(fn=func, inputs=["text", gr.Slider(5, 25), gr.Slider(0.1, 100)], outputs="text")
30
 
31
  if __name__ == "__main__":
32
  demo.launch()
 
2
  #from transformers import
3
  import tensorflow as tf
4
  from transformers import TFGPT2LMHeadModel, GPT2Tokenizer
5
+ import math
6
 
7
  #generator = pipeline('text-generation', model='gpt2')
8
 
 
19
  output_list = model.generate(
20
  input_ids,
21
  do_sample=True,
22
+ max_length=Math.floor(max_length),
23
+ temperature=temperature,
24
  top_p=0.92,
25
  top_k=0
26
+ num_return_sequences=5
27
  )
28
  output_strs = [tokenizer.decode(output, skip_special_tokens=True) for output in output_list]
29
+ return output_strs
30
 
31
 
32
+ demo = gr.Interface(fn=func, inputs=["text", gr.Slider(5, 25), gr.Slider(0.1, 100)], outputs=["text", "text", "text", "text", "text"])
33
 
34
  if __name__ == "__main__":
35
  demo.launch()