Spaces:
Sleeping
Sleeping
added seed
Browse files
app.py
CHANGED
@@ -5,7 +5,7 @@ import copy
|
|
5 |
|
6 |
from huggingface_hub import login
|
7 |
from transformers import pipeline
|
8 |
-
from transformers import GPT2Tokenizer, GPT2LMHeadModel
|
9 |
login(os.environ["HF_TOKEN"])
|
10 |
#https://huggingface.co/facebook/opt-1.3b
|
11 |
#generator = pipeline('text-generation', model="microsoft/DialoGPT-medium")
|
@@ -19,6 +19,7 @@ default_repetition_penalty=1.5
|
|
19 |
default_top_p=1.9
|
20 |
default_top_k=50
|
21 |
default_do_sample=True
|
|
|
22 |
def create_response(input_str,
|
23 |
# num_beams,
|
24 |
num_return_sequences,
|
@@ -27,6 +28,7 @@ def create_response(input_str,
|
|
27 |
top_p,
|
28 |
top_k,
|
29 |
do_sample,
|
|
|
30 |
model_name):
|
31 |
print("input_str="+input_str)
|
32 |
print("model_name="+str(model_name))
|
@@ -43,12 +45,13 @@ def create_response(input_str,
|
|
43 |
if not do_sample:
|
44 |
num_beams = 1
|
45 |
print("num_beams=" + str(num_beams))
|
46 |
-
|
47 |
encoded = tokenizer.encode_plus(input_str + tokenizer.eos_token, return_tensors="pt")
|
48 |
input_ids = encoded["input_ids"]
|
49 |
attention_mask = encoded["attention_mask"]
|
50 |
|
51 |
-
|
|
|
52 |
if model_name == "original_model":
|
53 |
output_ids = original_model.generate(input_ids,pad_token_id=tokenizer.eos_token_id,do_sample=do_sample, attention_mask=attention_mask, max_length=100, temperature=temperature, top_p=top_p, repetition_penalty=repetition_penalty,num_return_sequences=num_return_sequences )
|
54 |
elif model_name == "untethered_model":
|
@@ -110,6 +113,9 @@ interface_original = gr.Interface(fn=create_response,
|
|
110 |
"If is set to True, the generate function will use stochastic sampling, which means that it will randomly" +
|
111 |
" select a word from the probability distribution at each step. This results in a more diverse and creative" +
|
112 |
" output, but it might also introduce errors and inconsistencies ", value=default_do_sample),
|
|
|
|
|
|
|
113 |
gr.Textbox(label="model", lines=3, value="original_model",visible=False)
|
114 |
],
|
115 |
outputs="html"
|
@@ -154,6 +160,9 @@ interface_untethered_model = gr.Interface(fn=create_response,
|
|
154 |
"If is set to True, the generate function will use stochastic sampling, which means that it will randomly" +
|
155 |
" select a word from the probability distribution at each step. This results in a more diverse and creative" +
|
156 |
" output, but it might also introduce errors and inconsistencies ", value=default_do_sample),
|
|
|
|
|
|
|
157 |
gr.Textbox(label="model", lines=3, value="untethered_model",visible=False)
|
158 |
],
|
159 |
outputs="html"
|
@@ -197,6 +206,9 @@ interface_untethered_paraphrased_model = gr.Interface(fn=create_response,
|
|
197 |
"If is set to True, the generate function will use stochastic sampling, which means that it will randomly" +
|
198 |
" select a word from the probability distribution at each step. This results in a more diverse and creative" +
|
199 |
" output, but it might also introduce errors and inconsistencies ", value=default_do_sample),
|
|
|
|
|
|
|
200 |
gr.Textbox(label="model", lines=3, value="untethered_paraphrased_model",visible=False)
|
201 |
],
|
202 |
outputs= "html"
|
|
|
5 |
|
6 |
from huggingface_hub import login
|
7 |
from transformers import pipeline
|
8 |
+
from transformers import GPT2Tokenizer, GPT2LMHeadModel,set_seed
|
9 |
login(os.environ["HF_TOKEN"])
|
10 |
#https://huggingface.co/facebook/opt-1.3b
|
11 |
#generator = pipeline('text-generation', model="microsoft/DialoGPT-medium")
|
|
|
19 |
default_top_p=1.9
|
20 |
default_top_k=50
|
21 |
default_do_sample=True
|
22 |
+
default_seed=45
|
23 |
def create_response(input_str,
|
24 |
# num_beams,
|
25 |
num_return_sequences,
|
|
|
28 |
top_p,
|
29 |
top_k,
|
30 |
do_sample,
|
31 |
+
seed,
|
32 |
model_name):
|
33 |
print("input_str="+input_str)
|
34 |
print("model_name="+str(model_name))
|
|
|
45 |
if not do_sample:
|
46 |
num_beams = 1
|
47 |
print("num_beams=" + str(num_beams))
|
48 |
+
print("seed" + str(seed))
|
49 |
encoded = tokenizer.encode_plus(input_str + tokenizer.eos_token, return_tensors="pt")
|
50 |
input_ids = encoded["input_ids"]
|
51 |
attention_mask = encoded["attention_mask"]
|
52 |
|
53 |
+
if seed != -1:
|
54 |
+
set_seed(seed)
|
55 |
if model_name == "original_model":
|
56 |
output_ids = original_model.generate(input_ids,pad_token_id=tokenizer.eos_token_id,do_sample=do_sample, attention_mask=attention_mask, max_length=100, temperature=temperature, top_p=top_p, repetition_penalty=repetition_penalty,num_return_sequences=num_return_sequences )
|
57 |
elif model_name == "untethered_model":
|
|
|
113 |
"If is set to True, the generate function will use stochastic sampling, which means that it will randomly" +
|
114 |
" select a word from the probability distribution at each step. This results in a more diverse and creative" +
|
115 |
" output, but it might also introduce errors and inconsistencies ", value=default_do_sample),
|
116 |
+
gr.Number(
|
117 |
+
label="seed (integer) random seed, set to -1 to use a random seed everytime",
|
118 |
+
value=default_seed),
|
119 |
gr.Textbox(label="model", lines=3, value="original_model",visible=False)
|
120 |
],
|
121 |
outputs="html"
|
|
|
160 |
"If is set to True, the generate function will use stochastic sampling, which means that it will randomly" +
|
161 |
" select a word from the probability distribution at each step. This results in a more diverse and creative" +
|
162 |
" output, but it might also introduce errors and inconsistencies ", value=default_do_sample),
|
163 |
+
gr.Number(
|
164 |
+
label="seed (integer) random seed, set to -1 to use a random seed everytime",
|
165 |
+
value=default_seed),
|
166 |
gr.Textbox(label="model", lines=3, value="untethered_model",visible=False)
|
167 |
],
|
168 |
outputs="html"
|
|
|
206 |
"If is set to True, the generate function will use stochastic sampling, which means that it will randomly" +
|
207 |
" select a word from the probability distribution at each step. This results in a more diverse and creative" +
|
208 |
" output, but it might also introduce errors and inconsistencies ", value=default_do_sample),
|
209 |
+
gr.Number(
|
210 |
+
label="seed (integer) random seed, set to -1 to use a random seed everytime",
|
211 |
+
value=default_seed),
|
212 |
gr.Textbox(label="model", lines=3, value="untethered_paraphrased_model",visible=False)
|
213 |
],
|
214 |
outputs= "html"
|