Upload base_poet_models.py
Browse files- utils/base_poet_models.py +11 -9
utils/base_poet_models.py
CHANGED
@@ -147,15 +147,6 @@ class PoetModelFunctionalInterface(PoetModelInterface):
|
|
147 |
# Finish possible not completed lines
|
148 |
base_prompt_len = len(prompt_list)
|
149 |
for i in range(2,base_prompt_len + 1):
|
150 |
-
rhyme_char = 0
|
151 |
-
if features_dict["RHYME"][(i - 2) % len(features_dict["RHYME"])] == "B":
|
152 |
-
rhyme_char = 1
|
153 |
-
elif features_dict["RHYME"][(i - 2) % len(features_dict["RHYME"])] == "C":
|
154 |
-
rhyme_char = 2
|
155 |
-
elif features_dict["RHYME"][(i - 2) % len(features_dict["RHYME"])] == "D":
|
156 |
-
rhyme_char = 3
|
157 |
-
elif features_dict["RHYME"][(i - 2) % len(features_dict["RHYME"])] == "X":
|
158 |
-
rhyme_char = -1
|
159 |
|
160 |
token_gen_finish = tokenizer.encode("\n".join(prompt_list[:i]), return_tensors='pt')
|
161 |
if sample:
|
@@ -178,6 +169,17 @@ class PoetModelFunctionalInterface(PoetModelInterface):
|
|
178 |
to_dec = min(i, len(decoded))
|
179 |
prompt_list[:to_dec] = decoded[:to_dec]
|
180 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
181 |
if to_dec - 1 < len(prompt_list):
|
182 |
dec_line = prompt_list[to_dec-1]
|
183 |
#OLD
|
|
|
147 |
# Finish possible not completed lines
|
148 |
base_prompt_len = len(prompt_list)
|
149 |
for i in range(2,base_prompt_len + 1):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
150 |
|
151 |
token_gen_finish = tokenizer.encode("\n".join(prompt_list[:i]), return_tensors='pt')
|
152 |
if sample:
|
|
|
169 |
to_dec = min(i, len(decoded))
|
170 |
prompt_list[:to_dec] = decoded[:to_dec]
|
171 |
|
172 |
+
|
173 |
+
rhyme_char = 0
|
174 |
+
if features_dict["RHYME"][(to_dec - 2) % len(features_dict["RHYME"])] == "B":
|
175 |
+
rhyme_char = 1
|
176 |
+
elif features_dict["RHYME"][(to_dec - 2) % len(features_dict["RHYME"])] == "C":
|
177 |
+
rhyme_char = 2
|
178 |
+
elif features_dict["RHYME"][(to_dec - 2) % len(features_dict["RHYME"])] == "D":
|
179 |
+
rhyme_char = 3
|
180 |
+
elif features_dict["RHYME"][(to_dec - 2) % len(features_dict["RHYME"])] == "X":
|
181 |
+
rhyme_char = -1
|
182 |
+
|
183 |
if to_dec - 1 < len(prompt_list):
|
184 |
dec_line = prompt_list[to_dec-1]
|
185 |
#OLD
|