zmbfeng committed
Commit 004631e · verified · 1 Parent(s): 6bd9e07

added seed

Files changed (1)
  1. app.py +15 -3
app.py CHANGED
@@ -5,7 +5,7 @@ import copy
 
 from huggingface_hub import login
 from transformers import pipeline
-from transformers import GPT2Tokenizer, GPT2LMHeadModel
+from transformers import GPT2Tokenizer, GPT2LMHeadModel, set_seed
 login(os.environ["HF_TOKEN"])
 #https://huggingface.co/facebook/opt-1.3b
 #generator = pipeline('text-generation', model="microsoft/DialoGPT-medium")
@@ -19,6 +19,7 @@ default_repetition_penalty=1.5
 default_top_p=1.9
 default_top_k=50
 default_do_sample=True
+default_seed=45
 def create_response(input_str,
                     # num_beams,
                     num_return_sequences,
@@ -27,6 +28,7 @@ def create_response(input_str,
                     top_p,
                     top_k,
                     do_sample,
+                    seed,
                     model_name):
     print("input_str="+input_str)
     print("model_name="+str(model_name))
@@ -43,12 +45,13 @@ def create_response(input_str,
     if not do_sample:
         num_beams = 1
     print("num_beams=" + str(num_beams))
-
+    print("seed=" + str(seed))
     encoded = tokenizer.encode_plus(input_str + tokenizer.eos_token, return_tensors="pt")
     input_ids = encoded["input_ids"]
     attention_mask = encoded["attention_mask"]
 
-
+    if seed != -1:
+        set_seed(seed)
     if model_name == "original_model":
         output_ids = original_model.generate(input_ids,pad_token_id=tokenizer.eos_token_id,do_sample=do_sample, attention_mask=attention_mask, max_length=100, temperature=temperature, top_p=top_p, repetition_penalty=repetition_penalty,num_return_sequences=num_return_sequences )
     elif model_name == "untethered_model":
@@ -110,6 +113,9 @@ interface_original = gr.Interface(fn=create_response,
         "If is set to True, the generate function will use stochastic sampling, which means that it will randomly" +
         " select a word from the probability distribution at each step. This results in a more diverse and creative" +
         " output, but it might also introduce errors and inconsistencies ", value=default_do_sample),
+    gr.Number(
+        label="seed (integer) random seed, set to -1 to use a random seed every time",
+        value=default_seed),
     gr.Textbox(label="model", lines=3, value="original_model",visible=False)
     ],
     outputs="html"
@@ -154,6 +160,9 @@ interface_untethered_model = gr.Interface(fn=create_response,
         "If is set to True, the generate function will use stochastic sampling, which means that it will randomly" +
         " select a word from the probability distribution at each step. This results in a more diverse and creative" +
         " output, but it might also introduce errors and inconsistencies ", value=default_do_sample),
+    gr.Number(
+        label="seed (integer) random seed, set to -1 to use a random seed every time",
+        value=default_seed),
     gr.Textbox(label="model", lines=3, value="untethered_model",visible=False)
     ],
     outputs="html"
@@ -197,6 +206,9 @@ interface_untethered_paraphrased_model = gr.Interface(fn=create_response,
         "If is set to True, the generate function will use stochastic sampling, which means that it will randomly" +
         " select a word from the probability distribution at each step. This results in a more diverse and creative" +
         " output, but it might also introduce errors and inconsistencies ", value=default_do_sample),
+    gr.Number(
+        label="seed (integer) random seed, set to -1 to use a random seed every time",
+        value=default_seed),
     gr.Textbox(label="model", lines=3, value="untethered_paraphrased_model",visible=False)
     ],
     outputs= "html"