jinymusim commited on
Commit
33687a3
·
verified ·
1 Parent(s): 25ac77d

Upload base_poet_models.py

Browse files
Files changed (1) hide show
  1. utils/base_poet_models.py +11 -9
utils/base_poet_models.py CHANGED
@@ -147,15 +147,6 @@ class PoetModelFunctionalInterface(PoetModelInterface):
147
  # Finish possible not completed lines
148
  base_prompt_len = len(prompt_list)
149
  for i in range(2,base_prompt_len + 1):
150
- rhyme_char = 0
151
- if features_dict["RHYME"][(i - 2) % len(features_dict["RHYME"])] == "B":
152
- rhyme_char = 1
153
- elif features_dict["RHYME"][(i - 2) % len(features_dict["RHYME"])] == "C":
154
- rhyme_char = 2
155
- elif features_dict["RHYME"][(i - 2) % len(features_dict["RHYME"])] == "D":
156
- rhyme_char = 3
157
- elif features_dict["RHYME"][(i - 2) % len(features_dict["RHYME"])] == "X":
158
- rhyme_char = -1
159
 
160
  token_gen_finish = tokenizer.encode("\n".join(prompt_list[:i]), return_tensors='pt')
161
  if sample:
@@ -178,6 +169,17 @@ class PoetModelFunctionalInterface(PoetModelInterface):
178
  to_dec = min(i, len(decoded))
179
  prompt_list[:to_dec] = decoded[:to_dec]
180
 
 
 
 
 
 
 
 
 
 
 
 
181
  if to_dec - 1 < len(prompt_list):
182
  dec_line = prompt_list[to_dec-1]
183
  #OLD
 
147
  # Finish possible not completed lines
148
  base_prompt_len = len(prompt_list)
149
  for i in range(2,base_prompt_len + 1):
 
 
 
 
 
 
 
 
 
150
 
151
  token_gen_finish = tokenizer.encode("\n".join(prompt_list[:i]), return_tensors='pt')
152
  if sample:
 
169
  to_dec = min(i, len(decoded))
170
  prompt_list[:to_dec] = decoded[:to_dec]
171
 
172
+
173
+ rhyme_char = 0
174
+ if features_dict["RHYME"][(to_dec - 2) % len(features_dict["RHYME"])] == "B":
175
+ rhyme_char = 1
176
+ elif features_dict["RHYME"][(to_dec - 2) % len(features_dict["RHYME"])] == "C":
177
+ rhyme_char = 2
178
+ elif features_dict["RHYME"][(to_dec - 2) % len(features_dict["RHYME"])] == "D":
179
+ rhyme_char = 3
180
+ elif features_dict["RHYME"][(to_dec - 2) % len(features_dict["RHYME"])] == "X":
181
+ rhyme_char = -1
182
+
183
  if to_dec - 1 < len(prompt_list):
184
  dec_line = prompt_list[to_dec-1]
185
  #OLD