adjustment
app.py
CHANGED
@@ -16,10 +16,10 @@ untethered_paraphrased_model = GPT2LMHeadModel.from_pretrained('zmbfeng/untether
 default_num_return_sequences=5
 default_temperature=0.5
 default_repetition_penalty=1.5
-default_top_p=
+default_top_p=2
 default_top_k=50
 default_do_sample=True
-default_seed=
+default_seed=43
 def create_response(input_str,
                     # num_beams,
                     num_return_sequences,
@@ -49,25 +49,25 @@ def create_response(input_str,
     print("seed" + str(seed))
     encoded = tokenizer.encode_plus(input_str + tokenizer.eos_token, return_tensors="pt")
     input_ids = encoded["input_ids"]
-    attention_mask = encoded["attention_mask"]
+    #attention_mask = encoded["attention_mask"]

     if seed != -1:
         set_seed(seed)
     if model_name == "original_model":
-        output = original_model.generate(input_ids,
+        output = original_model.generate(input_ids, do_sample=do_sample, max_length=100, temperature=temperature, top_p=top_p, repetition_penalty=repetition_penalty, num_return_sequences=num_return_sequences, return_dict_in_generate=True, output_scores=True)
         transition_scores = original_model.compute_transition_scores(output.sequences, output.scores,
                                                                       normalize_logits=False)

     elif model_name == "untethered_model":
-        output = untethered_model.generate(input_ids,
+        output = untethered_model.generate(input_ids, do_sample=do_sample, max_length=100, temperature=temperature, top_p=top_p, repetition_penalty=repetition_penalty, num_return_sequences=num_return_sequences, return_dict_in_generate=True, output_scores=True)
         transition_scores = untethered_model.compute_transition_scores(output.sequences, output.scores,
                                                                        normalize_logits=False)
     elif model_name == "untethered_paraphrased_model":
-        output = untethered_paraphrased_model.generate(input_ids,
+        output = untethered_paraphrased_model.generate(input_ids, do_sample=do_sample, max_length=100, temperature=temperature, top_p=top_p, repetition_penalty=repetition_penalty, num_return_sequences=num_return_sequences, return_dict_in_generate=True, output_scores=True)
         transition_scores = untethered_paraphrased_model.compute_transition_scores(output.sequences, output.scores,
                                                                                     normalize_logits=False)
     else:
-        output = original_model.generate(input_ids,
+        output = original_model.generate(input_ids, do_sample=do_sample, max_length=100, temperature=temperature, top_p=top_p, repetition_penalty=repetition_penalty, num_return_sequences=num_return_sequences, return_dict_in_generate=True, output_scores=True)
         transition_scores = original_model.compute_transition_scores(output.sequences, output.scores,
                                                                       normalize_logits=False)
     score_list = []
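For reference, the new generate() calls combine the sampling parameters (do_sample, temperature, top_p, repetition_penalty, num_return_sequences) with return_dict_in_generate=True and output_scores=True, which is what allows compute_transition_scores() to attach a per-token score to each returned sequence. Below is a minimal, self-contained sketch of that same pattern, not the Space's actual code: the stock gpt2 checkpoint stands in for the zmbfeng models, the prompt is made up, and the values mirror the defaults introduced in this commit, except that top_p is set to 0.95 here because top_p is normally a probability in (0, 1] (with the diff's default of 2, nucleus filtering is either skipped or rejected depending on the transformers version).

# Sketch only: stock gpt2 stands in for the Space's fine-tuned checkpoints,
# and the parameter values mirror the defaults introduced in this commit.
from transformers import GPT2LMHeadModel, GPT2Tokenizer, set_seed

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

input_str = "Hello, how are you today?"  # hypothetical prompt
encoded = tokenizer.encode_plus(input_str + tokenizer.eos_token, return_tensors="pt")
input_ids = encoded["input_ids"]

set_seed(43)  # mirrors default_seed=43; the app only seeds when seed != -1
output = model.generate(
    input_ids,
    do_sample=True,               # default_do_sample=True
    max_length=100,
    temperature=0.5,              # default_temperature=0.5
    top_p=0.95,                   # see note above about default_top_p=2
    repetition_penalty=1.5,       # default_repetition_penalty=1.5
    num_return_sequences=5,       # default_num_return_sequences=5
    return_dict_in_generate=True,
    output_scores=True,
)

# One row of transition scores per returned sequence, one column per generated token.
transition_scores = model.compute_transition_scores(
    output.sequences, output.scores, normalize_logits=False
)
generated = output.sequences[:, input_ids.shape[-1]:]  # drop the prompt tokens
for seq, scores in zip(generated, transition_scores):
    print(tokenizer.decode(seq, skip_special_tokens=True), float(scores.sum()))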