New Gen
Browse files- .gitattributes +3 -0
- corpus_capsulated_datasets.py +754 -0
- simple_generation_player.py +195 -0
- utils/__init__.py +0 -0
- utils/base_poet_models.py +689 -0
- utils/poet_model_utils.py +272 -0
- utils/poet_utils.py +591 -0
- utils/validators.py +359 -0
- utils/validators/meter/ufal-robeczech-base_BPE_validator_1704126400265 +3 -0
- utils/validators/rhyme/distilroberta-base_BPE_validator_1704126399565 +3 -0
- utils/validators/year/ufal-robeczech-base_BPE_validator_1702393305267 +3 -0
.gitattributes
CHANGED
@@ -36,3 +36,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
36 |
validators/meter/ufal-robeczech-base_syllable_BPE_validator_1702489033354 filter=lfs diff=lfs merge=lfs -text
|
37 |
validators/rhyme/distilroberta-base_syllable_BPE_validator_1702665903087 filter=lfs diff=lfs merge=lfs -text
|
38 |
validators/year/ufal-robeczech-base_BPE_validator_1702393305267 filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
36 |
validators/meter/ufal-robeczech-base_syllable_BPE_validator_1702489033354 filter=lfs diff=lfs merge=lfs -text
|
37 |
validators/rhyme/distilroberta-base_syllable_BPE_validator_1702665903087 filter=lfs diff=lfs merge=lfs -text
|
38 |
validators/year/ufal-robeczech-base_BPE_validator_1702393305267 filter=lfs diff=lfs merge=lfs -text
|
39 |
+
utils/validators/meter/ufal-robeczech-base_BPE_validator_1704126400265 filter=lfs diff=lfs merge=lfs -text
|
40 |
+
utils/validators/rhyme/distilroberta-base_BPE_validator_1704126399565 filter=lfs diff=lfs merge=lfs -text
|
41 |
+
utils/validators/year/ufal-robeczech-base_BPE_validator_1702393305267 filter=lfs diff=lfs merge=lfs -text
|
corpus_capsulated_datasets.py
ADDED
@@ -0,0 +1,754 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
|
6 |
+
from utils.poet_utils import StropheParams, SyllableMaker, TextAnalysis, TextManipulation
|
7 |
+
from torch.utils.data import Dataset
|
8 |
+
from transformers import PreTrainedTokenizerBase, PreTrainedModel
|
9 |
+
#TODO: Maybe replace year of book being written for year Author was born
|
10 |
+
class CorpusDatasetPytorch:
|
11 |
+
"""Dataset class responsible for data loading.
|
12 |
+
"""
|
13 |
+
|
14 |
+
class RawDataset:
    """Dataset distributing raw string data with no preprocessing."""

    def __init__(self, data_file_paths, lower_case: bool = True):
        """Construct the frame around raw data generation.

        Args:
            data_file_paths (_type_): list of paths to data files
            lower_case (bool, optional): if resulting data should be in lowercase. Defaults to True.
        """
        self._data_file_paths = data_file_paths
        self.lower_case = lower_case

    def gen_files(self):
        """Get individual opened files.

        Yields:
            _type_: open file object (consumers close it once the JSON is read)
        """
        for filename in self._data_file_paths:
            yield open(filename, 'r')

    def get_text(self):
        """Get lines of text of poetry.

        Yields:
            str: individual verse line
        """
        for step, file in enumerate(self.gen_files()):
            if step % 500 == 0:
                print(f"Processing file {step}")
            # FIX: close the file handle after loading (was previously leaked)
            with file:
                datum = json.load(file)
            for data_line in datum:
                for part_line in data_line['body']:
                    for text_line in part_line:
                        yield text_line['text'].lower() if self.lower_case else text_line['text']

    def get_part(self):
        """Get strophe of poetry.

        Yields:
            str: 1 strophe of poetry
        """
        for step, file in enumerate(self.gen_files()):
            if step % 500 == 0:
                print(f"Processing file {step}")
            # FIX: close the file handle after loading (was previously leaked)
            with file:
                datum = json.load(file)
            for data_line in datum:
                for part_line in data_line['body']:
                    part = []
                    for text_line in part_line:
                        part.append(text_line['text'])
                    yield "\n".join(part).lower() if self.lower_case else "\n".join(part)

    def get_body(self):
        """Get whole poem.

        Yields:
            str: 1 whole poem (strophes separated by a blank line)
        """
        for step, file in enumerate(self.gen_files()):
            if step % 500 == 0:
                print(f"Processing file {step}")
            # FIX: close the file handle after loading (was previously leaked)
            with file:
                datum = json.load(file)
            for data_line in datum:
                body = []
                for part_line in data_line['body']:
                    for text_line in part_line:
                        body.append(text_line['text'])
                    # Blank separator element between strophes
                    body.append("\n")
                yield "\n".join(body).lower() if self.lower_case else "\n".join(body)
|
86 |
+
|
87 |
+
class TextDataset(Dataset):
    """Dataset of preprocessed verse lines.

    Args:
        Dataset (_type_): Dataset is child of torch class for better integration with torch and huggingface
    """

    def __init__(self, data_file_paths, prompt_length=True, prompt_ending=True, lower_case=True, val_data_rate: float = 0.05, test_data_rate: float = 0.05):
        """Construct the class over given data file paths and store parameters.

        Args:
            data_file_paths (_type_): list of paths to data files
            prompt_length (bool, optional): If to prompt the syllable count. Defaults to True.
            prompt_ending (bool, optional): If to prompt verse ending. Defaults to True.
            lower_case (bool, optional): If the string should be in lowercase. Defaults to True.
            val_data_rate (float, optional): Amount of data to be left for validation. Defaults to 0.05.
            test_data_rate (float, optional): Amount of data to be left for testing. Defaults to 0.05.
        """
        self._data_file_paths = data_file_paths
        self.prompt_length = prompt_length
        self.prompt_ending = prompt_ending
        self.lower_case = lower_case

        self.val_data_rate = val_data_rate
        self.test_data_rate = test_data_rate

        # Filled by data_text_line_gen(); each entry is a dict with keys
        # "input_ids", "nums", "verse_end", "metre".
        self.data = []
        self.validation_data = []
        self.test_data = []

    def gen_files(self):
        """Get individual opened files.

        Yields:
            _type_: open file object (consumers close it once the JSON is read)
        """
        for filename in self._data_file_paths:
            yield open(filename, 'r')

    @staticmethod
    def _vowels_and_endings(raw_text):
        """Get the verse ending and number of syllables in verse.

        Args:
            raw_text (str): raw verse to analyze

        Returns:
            tuple: number of syllables, ending syllable
        """
        syllabs = SyllableMaker.syllabify(raw_text)
        vowels = len(syllabs)  # INFO: Now counts the number of syllables
        ending = syllabs[-1]
        return vowels, ending

    @staticmethod
    def _ending_vector(end):
        """Construct One-hot encoded vector for ending syllable.

        Args:
            end (str): Ending syllable

        Returns:
            numpy.ndarray: One-hot encoded vector of ending syllable
        """
        verse_end_vector = np.zeros(len(StropheParams.ENDS))
        if end in StropheParams.ENDS[:-1]:
            verse_end_vector[StropheParams.ENDS.index(end)] = 1
        else:
            # Unknown endings fall into the catch-all last slot
            verse_end_vector[-1] = 1
        return verse_end_vector

    @staticmethod
    def _syllable_line(raw_text):
        """Construct verse as sequence of syllables.

        Args:
            raw_text (str): raw verse line

        Returns:
            str: Verse line as sequence of syllables
        """
        # Preserve trailing punctuation, which syllabify would otherwise drop
        ending = raw_text[-1] if raw_text[-1] in [',', '.', '!', '?'] else ''
        return " ".join(SyllableMaker.syllabify(raw_text)) + ending

    def _construct_line(self, raw_text, metre):
        """Construct individual content line with line parameters.

        Args:
            raw_text (str): raw verse line
            metre (str): metre marker for the line

        Returns:
            str: Processed verse line with line parameters ("METRE # NUM # END # text")
        """
        syllables = SyllableMaker.syllabify(raw_text)
        num_str = f"{len(syllables)} # " if self.prompt_length else ""
        verse_end = f"{syllables[-1]} # " if self.prompt_ending else ""
        metre_txt = f"{metre} # "
        return metre_txt + num_str + verse_end + raw_text

    def _introduce_phonetics(self, raw_text: str, phonetics):
        """Replace words in the verse by their phonetic transcription.

        Args:
            raw_text (str): raw verse line
            phonetics (_type_): mapping with a 'words' list; each word carries
                'token'/'token_lc' and its transcription under 'phoebe'

        Returns:
            str: verse with phonetic transcriptions substituted in
        """
        phonetic_text = raw_text
        for word in phonetics['words']:
            phonetic_text = phonetic_text.replace(f'{word["token_lc"]}', f'{word["phoebe"]}') if self.lower_case else phonetic_text.replace(f'{word["token"]}', f'{word["phoebe"]}')
        return phonetic_text

    def _construct_syllable_line(self, raw_text, metre):
        """Construct individual content line as sequence of syllables.

        Args:
            raw_text (str): raw verse line
            metre (str): metre marker for the line

        Returns:
            str: Processed verse line as sequence of syllables with line parameters
        """
        ending = raw_text[-1] if raw_text[-1] in [',', '.', '!', '?'] else ''
        syllables = SyllableMaker.syllabify(raw_text)
        num_str = f"{len(syllables)} # " if self.prompt_length else ""
        verse_end = f"{syllables[-1]} # " if self.prompt_ending else ""
        metre_txt = f"{metre} # "
        return metre_txt + num_str + verse_end + " ".join(syllables) + ending

    def data_text_line_gen(self):
        """Preprocess and process data for usage."""
        for step, file in enumerate(self.gen_files()):
            if step % 500 == 0:
                print(f"Processing file {step}")
            # FIX: close each file handle after reading (was previously leaked)
            with file:
                datum = json.load(file)
            for data_line in datum:
                for part_line in data_line['body']:
                    for text_line in part_line:
                        metre = StropheParams.METER_TRANSLATE.get(text_line["metre"][0]["type"], "N")

                        scanned_text = TextManipulation._remove_most_nonchar(text_line['text'], self.lower_case)

                        text_line_scanned = self._construct_line(scanned_text, metre)
                        syllable_line = self._construct_syllable_line(scanned_text, metre)
                        # phonetic_text = self._introduce_phonetics(scanned_text, text_line)

                        num_vowels, verse_end = self._vowels_and_endings(scanned_text)

                        # Build the record once instead of repeating it per split branch
                        record = {
                            "input_ids": [text_line_scanned, syllable_line],
                            "nums": [num_vowels],
                            "verse_end": verse_end,
                            "metre": metre
                        }
                        # Based on result of random choose proper set. Because data are
                        # large enough, this converges to the wanted split.
                        rand_split = np.random.rand()
                        if rand_split > self.val_data_rate + self.test_data_rate:
                            self.data.append(record)
                        elif rand_split < self.test_data_rate:
                            self.test_data.append(record)
                        else:
                            self.validation_data.append(record)

    def __len__(self):
        """Return length of training data.

        Returns:
            int: length of training data
        """
        return len(self.data)

    def __getitem__(self, index):
        """Return indexed item.

        Args:
            index (int): index from where to return

        Returns:
            dict: dict with indexed data
        """
        return self.data[index]
|
273 |
+
|
274 |
+
class BodyDataset(Dataset):
    """Dataset of preprocessed strophes.

    Args:
        Dataset (_type_): Dataset is child of torch class for better integration with torch and huggingface
    """

    def __init__(self, data_file_paths,
                 prompt_length=True, prompt_ending=True, prompt_verse=True, verse_len=None, lower_case=True, val_data_rate: float = 0.05, test_data_rate: float = 0.05):
        """Construct the class over given data file paths and store parameters.

        Args:
            data_file_paths (_type_): list of paths to data files
            prompt_length (bool, optional): If to prompt the syllable count. Defaults to True.
            prompt_ending (bool, optional): If to prompt verse ending. Defaults to True.
            prompt_verse (bool, optional): If to prompt rhyme schema. Defaults to True.
            verse_len (list, optional): Considered lengths of strophe. Defaults to [4,6].
            lower_case (bool, optional): If the string should be in lowercase. Defaults to True.
            val_data_rate (float, optional): Amount of data to be left for validation. Defaults to 0.05.
            test_data_rate (float, optional): Amount of data to be left for testing. Defaults to 0.05.
        """
        self._data_file_paths = data_file_paths
        self.prompt_length = prompt_length
        self.prompt_ending = prompt_ending
        self.prompt_verse = prompt_verse
        # FIX: avoid a shared mutable default argument; [4,6] is the effective default
        self.verse_len = [4, 6] if verse_len is None else verse_len
        self.lower_case = lower_case

        self.val_data_rate = val_data_rate
        self.test_data_rate = test_data_rate

        # Filled by data_body_gen(); each entry is a dict with keys
        # "input_ids", "context_ids", "year", "rhyme", "metre_ids".
        self.data = []
        self.validation_data = []
        self.test_data = []

    def gen_files(self):
        """Get individual opened files.

        Yields:
            _type_: open file object (consumers close it once the JSON is read)
        """
        for filename in self._data_file_paths:
            yield open(filename, 'r')

    def _construct_line(self, raw_text, metre):
        """Construct individual content line with line parameters.

        Args:
            raw_text (str): raw verse line
            metre (str): metre marker for the line

        Returns:
            str: Processed verse line with line parameters ("METRE # NUM # END # text")
        """
        syllables = SyllableMaker.syllabify(raw_text)
        num_str = f"{len(syllables)} # " if self.prompt_length else ""
        verse_end = f"{syllables[-1]} # " if self.prompt_ending else ""
        metre_txt = f"{metre} # "
        return metre_txt + num_str + verse_end + raw_text

    def _construct_syllable_line(self, raw_text, metre):
        """Construct individual content line as sequence of syllables.

        Args:
            raw_text (str): raw verse line
            metre (str): metre marker for the line

        Returns:
            str: Processed verse line as sequence of syllables with line parameters
        """
        # Preserve trailing punctuation, which syllabify would otherwise drop
        ending = raw_text[-1] if raw_text[-1] in [',', '.', '!', '?'] else ''
        syllables = SyllableMaker.syllabify(raw_text)
        num_str = f"{len(syllables)} # " if self.prompt_length else ""
        verse_end = f"{syllables[-1]} # " if self.prompt_ending else ""
        metre_txt = f"{metre} # "
        return metre_txt + num_str + verse_end + " ".join(syllables) + ending

    def data_body_gen(self):
        """Preprocess and process data for usage."""
        for step, file in enumerate(self.gen_files()):
            if step % 500 == 0:
                print(f"Processing file {step}")
            # FIX: close each file handle after reading (was previously leaked)
            with file:
                datum = json.load(file)

            for data_line in datum:
                publish_year_text = TextManipulation._year_bucketor(data_line["biblio"]["year"])
                publish_year_true = data_line["biblio"]["year"] if TextAnalysis._is_year(data_line["biblio"]["year"]) else 'NaN'
                context = ["NO CONTEXT"]

                for part_line in data_line['body']:
                    body = []
                    body_syllabs = []
                    rhyme = []
                    metres = []
                    i = 0
                    for text_line in part_line:

                        # In rare cases multiple, but from searching only 1 metre per line
                        metre = StropheParams.METER_TRANSLATE.get(text_line["metre"][0]["type"], "J")
                        metres += [metre]

                        rhyme.append(text_line["rhyme"])

                        scanned_text = TextManipulation._remove_most_nonchar(text_line["text"], self.lower_case)

                        body.append(self._construct_line(scanned_text, metre))
                        body_syllabs.append(self._construct_syllable_line(scanned_text, metre))

                        i += 1

                        # Emit a datapoint at every configured strophe length
                        if i in self.verse_len:

                            rhyme_str = TextManipulation._rhyme_string(rhyme)

                            text = f"# {rhyme_str} # {publish_year_text}\n" + "\n".join(body) + "\n"
                            syllable_text = f"# {rhyme_str} # {publish_year_text}\n" + "\n".join(body_syllabs) + "\n"
                            context_text = "\n".join(context)

                            # Build the record once instead of repeating it per split branch
                            record = {
                                "input_ids": [text, syllable_text],
                                "context_ids": context_text,
                                "year": publish_year_true,
                                "rhyme": rhyme_str,
                                "metre_ids": metres.copy()
                            }
                            rand_split = np.random.rand()
                            if rand_split > self.val_data_rate + self.test_data_rate:
                                self.data.append(record)
                            elif rand_split < self.test_data_rate:
                                self.test_data.append(record)
                            else:
                                self.validation_data.append(record)

                        # Restart accumulation once the longest considered strophe was emitted
                        if i == max(self.verse_len):
                            body = []
                            body_syllabs = []
                            rhyme = []
                            metres = []
                            i = 0

    def __len__(self):
        """Return length of training data.

        Returns:
            int: length of training data
        """
        return len(self.data)

    def __getitem__(self, index):
        """Return indexed item.

        Args:
            index (int): index from where to return

        Returns:
            dict: dict with indexed data
        """
        return self.data[index]
|
446 |
+
|
447 |
+
def get_filenames(self):
    """Get paths of data files.

    Returns:
        list: Paths of data files inside self.data_dir
    """
    base_dir = self.data_dir
    return [os.path.join(base_dir, entry) for entry in os.listdir(base_dir)]
|
459 |
+
|
460 |
+
def load_raw_(self):
    """Load Raw dataset with raw string data."""
    corpus_files = self.get_filenames()
    self.raw_dataset = CorpusDatasetPytorch.RawDataset(corpus_files, self.lower_case)
|
466 |
+
|
467 |
+
def load_json_filenames(self, prompt_length, prompt_ending, prompt_verse, verse_len=[4,6], val_data_rate=0.05, test_data_rate=0.05):
    """Load Verse and Strophe datasets and split off validation/test partitions.

    Args:
        prompt_length (bool, optional): If to prompt the syllable count. Defaults to True.
        prompt_ending (bool, optional): If to prompt verse ending. Defaults to True.
        prompt_verse (bool, optional): If to prompt rhyme schema. Defaults to True.
        verse_len (list, optional): Considered length of strophe. Defaults to [4,6].
        val_data_rate (float, optional): Amount of data to be left for validation. Defaults to 0.05.
        test_data_rate (float, optional): Amount of data to be left for testing. Defaults to 0.05.
    """
    filenames = self.get_filenames()

    # Build and populate the strophe-level dataset
    self.pytorch_dataset_body = CorpusDatasetPytorch.BodyDataset(filenames, prompt_ending=prompt_ending,
                                                                 prompt_length=prompt_length, prompt_verse=prompt_verse,
                                                                 verse_len=verse_len, lower_case=self.lower_case,
                                                                 val_data_rate=val_data_rate, test_data_rate=test_data_rate)
    self.pytorch_dataset_body.data_body_gen()

    # Build and populate the verse-level dataset
    self.pytorch_dataset_text = CorpusDatasetPytorch.TextDataset(filenames, prompt_ending=prompt_ending,
                                                                 prompt_length=prompt_length, lower_case=self.lower_case,
                                                                 val_data_rate=val_data_rate, test_data_rate=test_data_rate)
    self.pytorch_dataset_text.data_text_line_gen()

    # Wrap the validation split in standalone (empty-constructed) dataset objects
    self.val_pytorch_dataset_body = CorpusDatasetPytorch.BodyDataset([])
    self.val_pytorch_dataset_text = CorpusDatasetPytorch.TextDataset([])

    self.val_pytorch_dataset_body.data = self.pytorch_dataset_body.validation_data
    self.val_pytorch_dataset_text.data = self.pytorch_dataset_text.validation_data

    # Drop the moved data from the training datasets so it is not held twice
    self.pytorch_dataset_text.validation_data = []
    self.pytorch_dataset_body.validation_data = []

    # Same handover for the test split
    self.test_pytorch_dataset_body = CorpusDatasetPytorch.BodyDataset([])
    self.test_pytorch_dataset_text = CorpusDatasetPytorch.TextDataset([])

    self.test_pytorch_dataset_body.data = self.pytorch_dataset_body.test_data
    self.test_pytorch_dataset_text.data = self.pytorch_dataset_text.test_data

    self.pytorch_dataset_text.test_data = []
    self.pytorch_dataset_body.test_data = []
|
509 |
+
|
510 |
+
def create_empty(self):
    """Create empty holder for possible load of processed data from file."""
    body_cls = CorpusDatasetPytorch.BodyDataset
    text_cls = CorpusDatasetPytorch.TextDataset
    self.pytorch_dataset_body = body_cls([])
    self.pytorch_dataset_text = text_cls([])
    self.val_pytorch_dataset_body = body_cls([])
    self.val_pytorch_dataset_text = text_cls([])
    self.test_pytorch_dataset_body = body_cls([])
    self.test_pytorch_dataset_text = text_cls([])
|
519 |
+
|
520 |
+
|
521 |
+
@staticmethod
def collate(batch, tokenizer: PreTrainedTokenizerBase, max_len=1024, max_context=1024, mask_rate=0.0, syllables: bool = False, format: str = 'METER_VERSE'):
    """Process data for usage in LM.

    Args:
        batch (_type_): Batch with selected data points
        tokenizer (PreTrainedTokenizerBase): tokenizer to tokenize input text
        max_len (int, optional): Maximum length of tokenization. Defaults to 1024.
        max_context (int, optional): Maximum length of tokenization of context. Defaults to 1024.
        mask_rate (float, optional): Rate in which to mask data. Defaults to 0.0.
            NOTE(review): not referenced anywhere in this body — appears unused.
        syllables (bool, optional): If to use sequence of syllables as input text. Defaults to False.
        format (str, optional): Strophe reformatting mode ("BASIC", "VERSE_PAR",
            anything else keeps the text unchanged). Defaults to 'METER_VERSE'.
            NOTE(review): shadows the `format` builtin.

    Returns:
        dict: tokenized and processed to tensors data
    """
    # input_ids holds [plain_text, syllable_text]; pick the syllable variant if requested
    index = 1 if syllables else 0

    tokenizer.model_max_length = max_len
    # Strophe data starts with a "# RHYME # YEAR" header line; verse data does not
    if batch[0]['input_ids'][0].startswith("#"):

        data = [text['input_ids'][index] for text in batch]
        if format == "BASIC":
            # Header line gets the first verse's metre token appended
            # (splitlines()[1].split()[0] is the metre marker of verse 1);
            # every other line keeps only the text after the last '#'
            # (i.e. drops all per-line parameter prefixes).
            data = ["\n".join
                    (
                        [line + f" # {datum.splitlines()[1].split()[0]}"
                         if i == 0 else line.split('#')[-1] for i, line in enumerate(datum.splitlines())]
                    ) + tokenizer.eos_token for j, datum in enumerate(data)
                    ]
        elif format == "VERSE_PAR":
            # Header line handled as above; other lines drop only the segment
            # before the first '#' (the metre marker), keeping the remaining
            # '#'-separated per-line parameters.
            data = ["\n".join
                    (
                        [line + f" # {datum.splitlines()[1].split()[0]}"
                         if i == 0 else "#".join(line.split('#')[1:]) for i, line in enumerate(datum.splitlines())]
                    ) + tokenizer.eos_token for j, datum in enumerate(data)
                    ]
        else:
            # Default (e.g. 'METER_VERSE'): keep the stored text untouched
            data = [text['input_ids'][index] + tokenizer.eos_token for text in batch]

        tokenized = tokenizer(data, return_tensors='pt', truncation=True, padding=True)
        input_ids = tokenized['input_ids']
        attention = tokenized["attention_mask"]

    else:
        # Verse-level data: no header, no reformatting needed
        tokenized = tokenizer([text['input_ids'][index] + tokenizer.eos_token for text in batch], return_tensors='pt', truncation=True, padding=True)
        input_ids = tokenized['input_ids']
        attention = tokenized["attention_mask"]

    # Optional auxiliary targets; present only for some dataset variants
    nums = None
    if "nums" in batch[0].keys():
        nums = torch.tensor(np.asarray([text['nums'] for text in batch], dtype=np.int32), dtype=torch.float32)

    rhyme = None
    if "rhyme" in batch[0].keys():
        rhyme = torch.tensor(np.asarray([TextAnalysis._rhyme_vector(text["rhyme"]) for text in batch], dtype=np.int32), dtype=torch.float32)

    verse_end = None
    if "verse_end" in batch[0].keys():
        verse_end = torch.tensor(np.asarray([CorpusDatasetPytorch.TextDataset._ending_vector(text["verse_end"]) for text in batch], dtype=np.int32), dtype=torch.float32)

    year = None
    if "year" in batch[0].keys():
        year = torch.tensor(np.asarray([TextAnalysis._publish_year_vector(text["year"]) for text in batch], dtype=np.int32), dtype=torch.float32)

    metre = None
    if "metre" in batch[0].keys():
        metre = torch.tensor(np.asarray([TextAnalysis._metre_vector(text["metre"]) for text in batch], dtype=np.int32), dtype=torch.float32)

    # Context is tokenized separately with its own length budget
    context_ids = None
    context_attention_mask = None
    if "context_ids" in batch[0].keys():
        tokenizer.model_max_length = max_context
        tokenized_context = tokenizer([text['context_ids'] + tokenizer.eos_token for text in batch], return_tensors='pt', truncation=True, padding=True)
        context_ids = tokenized_context['input_ids']
        context_attention_mask = tokenized_context['attention_mask']

    return {
        "input_ids": input_ids,
        # Causal-LM training: labels are the inputs themselves, as LongTensor
        "labels": input_ids.type(torch.LongTensor),
        "attention_mask": attention,
        "context_ids": context_ids,
        "context_attention_mask": context_attention_mask,
        "nums": nums,
        "rhyme": rhyme,
        "verse_end": verse_end,
        "year": year,
        "metre": metre}
|
608 |
+
|
609 |
+
|
610 |
+
@staticmethod
def collate_distil(batch, tokenizer: PreTrainedTokenizerBase, surrogate_model: PreTrainedModel = None, surrogate_model_device=None, max_len=1024):
    """Tokenize a batch and attach hidden states of a surrogate (teacher) model.

    Args:
        batch (_type_): Batch with selected data points
        tokenizer (PreTrainedTokenizerBase): tokenizer to tokenize input text
        surrogate_model (PreTrainedModel, optional): model whose hidden states are to be replicated. Defaults to None.
        surrogate_model_device (_type_, optional): device to run the surrogate model on. Defaults to None.
        max_len (int, optional): Maximum length of tokenization. Defaults to 1024.

    Returns:
        dict: tokenized tensors plus the surrogate model's per-layer hidden states
    """
    tokenizer.model_max_length = max_len
    texts = [sample['input_ids'][0] + tokenizer.eos_token for sample in batch]
    tokenized = tokenizer(texts, return_tensors='pt', truncation=True, padding=True)
    input_ids = tokenized['input_ids']
    attention = tokenized["attention_mask"]

    with torch.no_grad():
        # 'hidden_states' is a tuple of tensors, one per layer
        outputs = surrogate_model(input_ids=input_ids.to(surrogate_model_device),
                                  attention_mask=attention.to(surrogate_model_device),
                                  labels=input_ids.type(torch.LongTensor).to(surrogate_model_device))
        # Detach and move to CPU so the batch can be shipped between processes
        replicate_states = [layer.cpu().detach() for layer in outputs['hidden_states']]

    return {
        "input_ids": input_ids,
        "labels": input_ids.type(torch.LongTensor),
        "attention_mask": attention,
        "to_replicate_states": replicate_states
    }
|
630 |
+
|
631 |
+
@staticmethod
def collate_validator(batch, tokenizer: PreTrainedTokenizerBase, syllables: bool, is_syllable: bool = False, max_len=512):
    """Process data for use in LM for metre, rhyme and year prediction.

    Args:
        batch (_type_): Batch with selected data points
        tokenizer (PreTrainedTokenizerBase): tokenizer to tokenize input text
        syllables (bool): If to use sequence of syllables as input text
        is_syllable (bool, optional): Signal if the preprocessed inputs contain syllable data. Defaults to False.
        max_len (int, optional): Maximum length of tokenization. Defaults to 512.

    Returns:
        dict: tokenized and processed to tensors data
    """
    # Use the precomputed syllable variant only when it actually exists
    index = 1 if syllables and is_syllable else 0
    tokenizer.model_max_length = max_len
    # Per datapoint: drop the "# RHYME # YEAR" header line ([1:]); per line,
    # keep only the text after the last '#' (strips parameter prefixes) and,
    # when syllable input is requested but not precomputed, syllabify on the
    # fly, re-appending trailing punctuation that syllabify drops.
    data_ids = ["\n".join(
        [" ".join(
            SyllableMaker.syllabify(line.split('#')[-1])
        ) + (line[-1] if line[-1] in [',', '.', '!', '?'] else '') if (syllables and not is_syllable and line) else line.split('#')[-1] for line in text['input_ids'][index].splitlines()[1:]]
    ) for text in batch]

    tokenized = tokenizer(data_ids, return_tensors='pt', truncation=True, padding=True)
    input_ids = tokenized['input_ids']
    attention = tokenized["attention_mask"]

    rhyme = None
    if "rhyme" in batch[0].keys():
        rhyme = torch.tensor(np.asarray([TextAnalysis._rhyme_vector(text["rhyme"]) for text in batch], dtype=np.int32), dtype=torch.float32)

    year_bucket = None
    year = None
    if "year" in batch[0].keys():
        # One-hot year bucket for classification
        year_bucket = torch.tensor(np.asarray([TextAnalysis._publish_year_vector(text["year"]) for text in batch], dtype=np.int32), dtype=torch.float32)
        # Raw numeric year for regression; the 'NaN' marker maps to 0
        year = torch.tensor(np.asarray([[int(text['year'])] if text['year'] != 'NaN' else [0] for text in batch], dtype=np.int32), dtype=torch.float32)

    return {
        "input_ids": input_ids,
        "attention_mask": attention,
        "rhyme": rhyme,
        "metre_ids": None,
        "year_bucket": year_bucket,
        'year': year}
|
675 |
+
|
676 |
+
@staticmethod
def collate_meter(batch, tokenizer: PreTrainedTokenizerBase, syllables:bool, is_syllable:bool = False, max_len = 512):
    """Collate a batch for the meter validator: one tokenized sample per verse line.

    Each datum's text is split into lines; the first line (strophe parameter
    line) is dropped.  Every remaining verse line is reduced to the text after
    the last '#' (stripping the meter/length/end prefix), optionally re-spelled
    as space-separated syllables, and the whole flat list of verse lines is
    tokenized together.

    Args:
        batch (list[dict]): Selected data points; each has 'input_ids' and
            optionally 'metre_ids' (one meter label per verse line).
        tokenizer (PreTrainedTokenizerBase): tokenizer for the verse text.
        syllables (bool): If to use a sequence of syllables as input text.
        is_syllable (bool, optional): Signals the preprocessed inputs already
            contain syllable data. Defaults to False.
        max_len (int, optional): Maximum tokenization length. Defaults to 512.

    Returns:
        dict: tokenized inputs plus meter targets; unrelated keys are None.
    """
    # Second variant of the stored text holds the syllabified form.
    source_idx = 1 if syllables and is_syllable else 0
    tokenizer.model_max_length = max_len

    verse_texts = []
    metre_vectors = []
    for datum in batch:
        for raw_line in datum['input_ids'][source_idx].splitlines()[1:]:
            if syllables and not is_syllable and raw_line:
                # Syllabify the text after the parameter prefix and keep the
                # trailing punctuation mark, if any.
                text = " ".join(SyllableMaker.syllabify(raw_line.split('#')[-1]))
                if raw_line[-1] in [',', '.', '!', '?']:
                    text += raw_line[-1]
            else:
                text = raw_line.split('#')[-1]
            verse_texts.append(text)
        # NOTE(review): mirrors the original — presence of 'metre_ids' is
        # checked on batch[0], not on each datum; assumes a homogeneous batch.
        if "metre_ids" in batch[0].keys():
            for one_metre in datum['metre_ids']:
                metre_vectors.append(TextAnalysis._metre_vector(one_metre))

    tokenized = tokenizer(verse_texts, return_tensors='pt', truncation=True, padding=True)

    metre_ids = None
    if metre_vectors:
        metre_ids = torch.tensor(np.asarray(metre_vectors, dtype=np.int32), dtype=torch.float32)

    return {
        "input_ids": tokenized['input_ids'],
        "attention_mask": tokenized["attention_mask"],
        "rhyme": None,
        "metre_ids": metre_ids,
        "year_bucket": None,
        "year": None}
|
706 |
+
|
707 |
+
|
708 |
+
|
709 |
+
def __init__(self, data_dir = r"PoetGen\corpusCzechVerse-master\ccv", cache_dir='./',
             prompt_length=True, prompt_ending=True, prompt_verse=True, verse_len=None, lower_case=True, val_data_rate=0.05, test_data_rate=0.05):
    """Construct the Dataloader and create Datasets.

    If all six cached JSON splits exist in ``cache_dir`` they are loaded
    directly; otherwise the raw corpus is processed and the caches written.

    Args:
        data_dir (str, optional): Path to raw corpus data. Defaults to the
            Windows-style relative path ``PoetGen\\corpusCzechVerse-master\\ccv``.
        cache_dir (str, optional): Path where processed data is cached. Defaults to './'.
        prompt_length (bool, optional): If to prompt the syllable count. Defaults to True.
        prompt_ending (bool, optional): If to prompt verse ending. Defaults to True.
        prompt_verse (bool, optional): If to prompt rhyme schema. Defaults to True.
        verse_len (list, optional): Considered lengths of strophe. Defaults to [4, 6].
        lower_case (bool, optional): If the string should be in lowercase. Defaults to True.
        val_data_rate (float, optional): Amount of data left for validation. Defaults to 0.05.
        test_data_rate (float, optional): Amount of data left for testing. Defaults to 0.05.
    """
    # Avoid the shared-mutable-default pitfall: a fresh [4, 6] per call.
    if verse_len is None:
        verse_len = [4, 6]
    self.lower_case = lower_case
    self.data_dir = data_dir

    # The six cache files: (dataset attribute, file name) for train/val/test x body/text.
    splits = [
        ("pytorch_dataset_body", "body_poet_data.json"),
        ("pytorch_dataset_text", "text_poet_data.json"),
        ("val_pytorch_dataset_body", "val_body_poet_data.json"),
        ("val_pytorch_dataset_text", "val_text_poet_data.json"),
        ("test_pytorch_dataset_body", "test_body_poet_data.json"),
        ("test_pytorch_dataset_text", "test_text_poet_data.json"),
    ]

    if all(os.path.isfile(os.path.join(cache_dir, filename)) for _, filename in splits):
        # All caches present: build empty dataset shells and fill them from disk.
        self.create_empty()
        for attr, filename in splits:
            # `with` closes the handle promptly (the original leaked open files).
            # Encoding is left as the platform default to stay compatible with
            # caches written by earlier runs of this code.
            with open(os.path.join(cache_dir, filename), 'r') as cache_file:
                getattr(self, attr).data = list(json.load(cache_file))
    else:
        # Process the raw corpus, then persist every split for the next run.
        self.load_json_filenames(prompt_length, prompt_ending, prompt_verse, verse_len=verse_len, val_data_rate=val_data_rate, test_data_rate=test_data_rate)
        for attr, filename in splits:
            with open(os.path.join(cache_dir, filename), 'w+') as cache_file:
                json.dump(getattr(self, attr).data, cache_file, indent = 6)

    self.load_raw_()
|
745 |
+
|
746 |
+
|
747 |
+
|
748 |
+
#if __name__ == "__main__":
|
749 |
+
# Line Count
|
750 |
+
# print(len(list(CorpusDatasetPytorch(os.path.abspath(os.path.join(os.path.dirname(__file__), "corpusCzechVerse", "ccv")) ).raw_dataset.get_text())))
|
751 |
+
# Strophe Count
|
752 |
+
# print(len(list(CorpusDatasetPytorch(os.path.abspath(os.path.join(os.path.dirname(__file__), "corpusCzechVerse", "ccv")) ).raw_dataset.get_part())))
|
753 |
+
# Poem Count
|
754 |
+
# print(len(list(CorpusDatasetPytorch(os.path.abspath(os.path.join(os.path.dirname(__file__), "corpusCzechVerse", "ccv")) ).raw_dataset.get_body())))
|
simple_generation_player.py
ADDED
@@ -0,0 +1,195 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
import torch
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
import sys
|
7 |
+
|
8 |
+
from transformers import AutoTokenizer, PreTrainedTokenizerBase, PreTrainedTokenizerFast
|
9 |
+
from utils.poet_utils import StropheParams, Tokens, TextManipulation, TextAnalysis
|
10 |
+
from utils.base_poet_models import PoetModelBase
|
11 |
+
from utils.validators import ValidatorInterface
|
12 |
+
|
13 |
+
from corpus_capsulated_datasets import CorpusDatasetPytorch
|
14 |
+
|
15 |
+
def _str2bool(value) -> bool:
    """Parse a boolean CLI value.

    argparse's ``type=bool`` is a classic pitfall: ``bool("False")`` is True
    because any non-empty string is truthy.  This helper accepts the common
    spellings of "true" (case-insensitive) and maps everything else to False,
    preserving the existing ``default=False`` behaviour.
    """
    return str(value).strip().lower() in ('true', '1', 'yes', 'y', 't')


parser = argparse.ArgumentParser()

# Generator language model (local path or HuggingFace hub id).
parser.add_argument("--model_path_full", default='jinymusim/gpt-czech-poet', type=str, help="Path to Model")

# Pre-trained validator checkpoints (full pickled models, loaded via torch.load).
parser.add_argument("--rhyme_model_path_full", default=os.path.abspath(os.path.join(os.path.dirname(__file__), 'utils', 'validators', 'rhyme', 'distilroberta-base_BPE_validator_1704126399565')), type=str, help="Path to Model")
parser.add_argument("--metre_model_path_full", default=os.path.abspath(os.path.join(os.path.dirname(__file__), 'utils' ,"validators", 'meter', 'ufal-robeczech-base_BPE_validator_1704126400265')), type=str, help="Path to Model")
parser.add_argument("--year_model_path_full", default=os.path.abspath(os.path.join(os.path.dirname(__file__), 'utils' ,"validators", 'year', 'ufal-robeczech-base_BPE_validator_1702393305267')), type=str, help="Path to Model")

# Tokenizers matching each validator (hub id or local tokenizer file).
parser.add_argument("--validator_tokenizer_model_rhyme", default='distilroberta-base', type=str, help="Validator tokenizer")
parser.add_argument("--validator_tokenizer_model_meter", default='ufal/robeczech-base', type=str, help="Validator tokenizer")
parser.add_argument("--validator_tokenizer_model_year", default='ufal/robeczech-base', type=str, help="Validator tokenizer")
# Whether each validator was trained on syllabified input.
# Fix: these used ``type=bool``, which parses the string "False" as True.
parser.add_argument("--val_syllables_rhyme", default=False, type=_str2bool, help="Does validator use syllables")
parser.add_argument("--val_syllables_meter", default=False, type=_str2bool, help="Does validator use syllables")
parser.add_argument("--val_syllables_year", default=False, type=_str2bool, help="Does validator use syllables")
|
29 |
+
|
30 |
+
|
31 |
+
if __name__ == "__main__":
    # Parse CLI arguments.  When there is no __file__ (interactive/notebook
    # session) fall back to the declared defaults instead of reading sys.argv.
    args = parser.parse_args([] if "__file__" not in globals() else None)

    _ ,model_rel_name = os.path.split(args.model_path_full)

    # Prefer GPU when available; all models and inputs are moved to this device.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Generator language model (causal LM wrapped by PoetModelBase).
    model = PoetModelBase(args.model_path_full).to(device)
    model.eval()

    # Optional validator models: rhyme-scheme, meter and publication-year predictors.
    rhyme_model, meter_model, year_model = None, None, None
    rhyme_model_name, meter_model_name, year_model_name = "", "", ""
    if args.rhyme_model_path_full:
        # NOTE(review): torch.load of a fully pickled model — only load trusted checkpoints.
        rhyme_model: ValidatorInterface = (torch.load(args.rhyme_model_path_full, map_location=torch.device('cpu'))).to(device)
        rhyme_model.eval()
        _, rhyme_model_name = os.path.split(args.rhyme_model_path_full)

    if args.metre_model_path_full:
        meter_model: ValidatorInterface = (torch.load(args.metre_model_path_full, map_location=torch.device('cpu'))).to(device)
        meter_model.eval()
        _, meter_model_name = os.path.split(args.metre_model_path_full)

    if args.year_model_path_full:
        year_model: ValidatorInterface = (torch.load(args.year_model_path_full, map_location=torch.device('cpu'))).to(device)
        year_model.eval()
        _, year_model_name = os.path.split(args.year_model_path_full)
    # Load Rhyme tokenizer
    validator_tokenizer_rhyme: PreTrainedTokenizerBase = None
    if args.validator_tokenizer_model_rhyme:
        try:
            validator_tokenizer_rhyme = AutoTokenizer.from_pretrained(args.validator_tokenizer_model_rhyme)
        # NOTE(review): bare except — any hub-load failure falls back to
        # treating the argument as a local tokenizer file, whose special
        # tokens must then be set manually.
        except:
            validator_tokenizer_rhyme: PreTrainedTokenizerBase = PreTrainedTokenizerFast(tokenizer_file=args.validator_tokenizer_model_rhyme)
            validator_tokenizer_rhyme.eos_token = Tokens.EOS
            validator_tokenizer_rhyme.eos_token_id = Tokens.EOS_ID
            validator_tokenizer_rhyme.pad_token = Tokens.PAD
            validator_tokenizer_rhyme.pad_token_id = Tokens.PAD_ID
            validator_tokenizer_rhyme.unk_token = Tokens.UNK
            validator_tokenizer_rhyme.unk_token_id = Tokens.UNK_ID
            validator_tokenizer_rhyme.cls_token = Tokens.CLS
            validator_tokenizer_rhyme.cls_token_id = Tokens.CLS_ID
            validator_tokenizer_rhyme.sep_token = Tokens.SEP
            validator_tokenizer_rhyme.sep_token_id = Tokens.SEP_ID

    # Load Meter tokenizer
    validator_tokenizer_meter: PreTrainedTokenizerBase = None
    if args.validator_tokenizer_model_meter:
        try:
            validator_tokenizer_meter = AutoTokenizer.from_pretrained(args.validator_tokenizer_model_meter, revision='v1.0')
        except:
            validator_tokenizer_meter: PreTrainedTokenizerBase = PreTrainedTokenizerFast(tokenizer_file=args.validator_tokenizer_model_meter)
            validator_tokenizer_meter.eos_token = Tokens.EOS
            validator_tokenizer_meter.eos_token_id = Tokens.EOS_ID
            validator_tokenizer_meter.pad_token = Tokens.PAD
            validator_tokenizer_meter.pad_token_id = Tokens.PAD_ID
            validator_tokenizer_meter.unk_token = Tokens.UNK
            validator_tokenizer_meter.unk_token_id = Tokens.UNK_ID
            validator_tokenizer_meter.cls_token = Tokens.CLS
            validator_tokenizer_meter.cls_token_id = Tokens.CLS_ID
            validator_tokenizer_meter.sep_token = Tokens.SEP
            validator_tokenizer_meter.sep_token_id = Tokens.SEP_ID

    # Load Year tokenizer
    validator_tokenizer_year: PreTrainedTokenizerBase = None
    if args.validator_tokenizer_model_year:
        try:
            validator_tokenizer_year = AutoTokenizer.from_pretrained(args.validator_tokenizer_model_year, revision='v1.0')
        except:
            validator_tokenizer_year: PreTrainedTokenizerBase = PreTrainedTokenizerFast(tokenizer_file=args.validator_tokenizer_model_year)
            validator_tokenizer_year.eos_token = Tokens.EOS
            validator_tokenizer_year.eos_token_id = Tokens.EOS_ID
            validator_tokenizer_year.pad_token = Tokens.PAD
            validator_tokenizer_year.pad_token_id = Tokens.PAD_ID
            validator_tokenizer_year.unk_token = Tokens.UNK
            validator_tokenizer_year.unk_token_id = Tokens.UNK_ID
            validator_tokenizer_year.cls_token = Tokens.CLS
            validator_tokenizer_year.cls_token_id = Tokens.CLS_ID
            validator_tokenizer_year.sep_token = Tokens.SEP
            validator_tokenizer_year.sep_token_id = Tokens.SEP_ID

    # Load LM tokenizers
    tokenizer: PreTrainedTokenizerBase = AutoTokenizer.from_pretrained(args.model_path_full)

    # Current generation mode; toggled interactively below.
    generation = "BASIC"

    def decoder_helper(type, user_input):
        # Dispatch one strophe generation.  NOTE(review): parameter `type`
        # shadows the builtin; implicitly returns None for unknown modes.
        if type == "BASIC":
            # Free-running sampling from the LM conditioned on the raw prompt.
            tokenized = tokenizer.encode(user_input, return_tensors='pt', truncation=True)
            out = model.model.generate(tokenized.to(device),
                                        max_length=512,
                                        do_sample=True,
                                        top_k=50,
                                        eos_token_id = tokenizer.eos_token_id,
                                        early_stopping=True,
                                        pad_token_id= tokenizer.pad_token_id)
            return tokenizer.decode(out.cpu()[0], skip_special_tokens=True)
        if type=="FORCED":
            # Constrained line-by-line generation implemented by the model wrapper.
            return model.generate_forced(user_input, tokenizer, sample=True, device=device)

    # NOTE(review): `help` shadows the builtin, and the f-string is rendered
    # once — the text keeps saying "BASIC" even after the mode is switched.
    help = f"Current setting is {generation} generating.\nChange it by writing FORCED/BASIC to input. type HELP for HELP.\nType EXIT to exit."

    print("Welcome to simple czech strophe generation.", help)

    # Outer REPL loop: one iteration = one generated strophe.
    while True:

        # Collect a multi-line prompt; an empty line submits it.
        user_input = ""
        while True:
            curr_line = input(">").strip()
            if curr_line == 'EXIT':
                sys.exit()
            elif curr_line == "HELP":
                print(help)
                continue
            elif curr_line == "BASIC":
                print("Changed to BASIC")
                generation = 'BASIC'
                continue
            elif curr_line == "FORCED":
                print("Changed to FORCED")
                generation = "FORCED"
                continue
            if not curr_line:
                break
            user_input += curr_line + "\n"

        user_input = user_input.strip()
        # Extract the strophe parameters the user asked for (rhyme, year, ...).
        user_reqs = model.analyze_prompt(user_input)

        # BASIC mode needs a parameter header; without one, restart the prompt
        # with the bare '#' parameter marker.
        if "RHYME" not in user_reqs.keys() and generation == "BASIC":
            print("BASIC generation can't work with imputed format.", help)
            print("User input is substituted for #")
            user_input = '#'

        generated_poem:str = decoder_helper(generation, user_input)

        # Predictions
        meters = []
        rhyme_pred = ''
        year_pred = 0
        for line in generated_poem.splitlines():
            # Skip Empty lines (stop at the first blank / non-alphanumeric line)
            if not line.strip():
                break
            if not (TextManipulation._remove_most_nonchar(line)).strip():
                break
            # Validate for Strophe Parameters: the parameter line triggers the
            # whole-poem rhyme and year predictions (truncated to the validator's
            # position-embedding budget, minus 2 for special tokens).
            if TextAnalysis._is_param_line(line):
                data = CorpusDatasetPytorch.collate_validator([{"input_ids" :[generated_poem]}],tokenizer=validator_tokenizer_rhyme,
                                                            is_syllable=False, syllables=args.val_syllables_rhyme,
                                                            max_len=rhyme_model.model.config.max_position_embeddings - 2)
                rhyme_pred =StropheParams.RHYME[np.argmax(rhyme_model.predict_state(input_ids=data['input_ids'].to(device)).detach().flatten().cpu().numpy())]
                data = CorpusDatasetPytorch.collate_validator([{"input_ids" :[generated_poem]}],tokenizer=validator_tokenizer_year,
                                                            is_syllable=False, syllables=args.val_syllables_year,
                                                            max_len=year_model.model.config.max_position_embeddings - 2)
                year_pred = round(year_model.predict_state(input_ids=data['input_ids'].to(device)).detach().flatten().cpu().numpy()[0])
                continue
            # Per-verse meter prediction; a dummy first line is prepended because
            # collate_meter drops the first (parameter) line of its input.
            data = CorpusDatasetPytorch.collate_meter([{"input_ids" :["FIRST LINE SKIP!\n" + line]}],tokenizer=validator_tokenizer_meter,
                                                    is_syllable=False, syllables=args.val_syllables_meter,
                                                    max_len=meter_model.model.config.max_position_embeddings - 2)
            meters.append(
                StropheParams.METER[np.argmax(meter_model.predict_state(input_ids=data['input_ids'].to(device)).detach().flatten().cpu().numpy())]
            )
        print(f"REQUESTED: {user_reqs}, GENERATED USING: {generation}\n")
        print(generated_poem.strip())
        print(f"PREDICTED: {rhyme_pred}, {year_pred}, {meters}\n\n")
|
utils/__init__.py
ADDED
File without changes
|
utils/base_poet_models.py
ADDED
@@ -0,0 +1,689 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .poet_model_utils import PoetModelInterface
|
2 |
+
from .poet_utils import TextAnalysis, StropheParams
|
3 |
+
|
4 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
5 |
+
from transformers.utils import ModelOutput
|
6 |
+
import random
|
7 |
+
import torch
|
8 |
+
|
9 |
+
class PoetModelFunctionalInterface(PoetModelInterface):
    """Poet Model Functional Interface. Abstract class with implementation of
    prompt analysis and FORCED strophe generation shared by concrete models
    (subclasses must provide ``self.model``, a HF causal LM).

    Args:
        PoetModelInterface (_type_): Is child of PoetModelInterface for carrying core methods
    """
    def __init__(self, *args, **kwargs) -> None:
        """ Constructor. As child Class needs to construct Parent
        """
        super().__init__(*args, **kwargs)
        
    def analyze_prompt(self, prompt) -> dict:
        """Analysis of users prompt.

        A dict prompt is taken as the final analysis and returned unchanged.
        A string prompt is split into stripped, non-empty lines; parameter
        lines contribute strophe-level features, ordinary verse lines
        contribute per-line features keyed by their rhyme-schema slot.

        Args:
            prompt (_type_): dict or string, carrying users intent

        Returns:
            dict: Analysis with users intended input
        """
        if isinstance(prompt, dict):
            return prompt
        features_dict = {}  
        lines = prompt.splitlines()
        lines = list(map(str.strip, lines))
        # Remove blank lines in place; the index is stepped back after a pop
        # so the element shifted into position i is re-examined.
        i = 0
        while i < len(lines):
            if not lines[i]:
                lines.pop(i)
                i-=1
            i+=1
        cont_line = 0
        for line in lines:
            if TextAnalysis._is_param_line(line):
                # Strophe parameter line (rhyme schema, year, strophe meter, ...).
                for key, value in TextAnalysis._first_line_analysis(line).items():
                    features_dict[key] = value
            else:
                # Verse line: map the rhyme-schema letter at this position
                # (A-D) to a numeric slot; otherwise keep the raw line index.
                val = cont_line
                if "RHYME" in features_dict.keys() and cont_line < len(features_dict['RHYME']):
                    if features_dict["RHYME"][cont_line] == "A":
                        val = 0
                    elif features_dict["RHYME"][cont_line] == "B":
                        val = 1
                    elif features_dict["RHYME"][cont_line] == "C":
                        val = 2
                    elif features_dict["RHYME"][cont_line] == "D":
                        val = 3
                for key, value in TextAnalysis._continuos_line_analysis(line).items():
                    features_dict[f"{key}_{val}"] = value
                cont_line += 1
            
        return features_dict
    
    def generate_forced(self, prompt, tokenizer: AutoTokenizer, sample: bool = True, format: str = 'METER_VERSE', device= torch.device('cpu'), *args, **kwargs) -> str:
        """Generate Strophe using the FORCED generation.

        Line-by-line generation: a parameter line is generated (or taken from
        the prompt), missing rhyme schema is backed up by a random one, user
        lines are completed, and new verses are forced to reuse the meter /
        length / ending already fixed for their rhyme slot.

        Args:
            prompt (_type_): dict or string of users intended parameters of strophe start
            tokenizer (AutoTokenizer): tokenizer to be used during generation. Should be model specific.
            sample (bool, optional): If to sample. Defaults to True.
            format (str, optional): Format of generation to be used. Should be same as trained on. possible formats: BASIC, VERSE_PAR, METER_VERSE, OLD (DEPRECATED! For old models compatibility only). Defaults to 'METER_VERSE'.
            device (_type_, optional): Device to generate on. CPU as default. Defaults to torch.device('cpu').
                NOTE(review): this default is evaluated once at import time.

        Returns:
            str: Generated Strophe
        """
        features_dict_init = self.analyze_prompt(prompt)
        # If user parameters as dict, list is initialized to carry future verses.
        if isinstance(prompt, dict):
            prompt_list = []
        else:
            prompt_list = prompt.splitlines()
        # GENERATE FOR POSSIBLE MISSING POET PARAM:
        # sample a parameter line from the bare '#' prompt, then overlay the
        # user-supplied features on top of the generated ones.
        token_gen_rhyme = tokenizer.encode("#", return_tensors='pt')
        if sample:
            rhyme_line = self.model.generate(token_gen_rhyme.to(device),
                                max_new_tokens= 100,
                                do_sample=True,
                                top_k=50,
                                early_stopping=True,
                                pad_token_id=tokenizer.pad_token_id,
                                eos_token_id=tokenizer.eos_token_id)
        else:
            rhyme_line = self.model.generate(token_gen_rhyme.to(device),
                                max_new_tokens= 100,
                                num_beams=8,
                                no_repeat_ngram_size=2,
                                early_stopping=True,
                                pad_token_id=tokenizer.pad_token_id,
                                eos_token_id=tokenizer.eos_token_id)
        rhyme_dec = tokenizer.decode(rhyme_line.cpu()[0], skip_special_tokens=True).splitlines()[0]
        features_dict= TextAnalysis._first_line_analysis(rhyme_dec)
        # User-requested features take precedence over the generated ones.
        for key, value in features_dict_init.items():
            features_dict[key] = value
        # CONSTRUCT BEST INPUT LINE
        # BACKUP RHYME (random scheme, excluding the last catch-all entry)
        if "RHYME" not in features_dict.keys():
            features_dict["RHYME"] = random.choice(StropheParams.RHYME[:-1])
        # OLD format: no leading '# ' marker.
        if format == 'OLD':
            poet_param_str = ""
            if "RHYME" in features_dict.keys():
                poet_param_str += features_dict["RHYME"]
            if "YEAR" in features_dict.keys():
                poet_param_str += f" # {features_dict['YEAR']}"
            if 'STROPHE_METER' in features_dict.keys():
                poet_param_str += f" # {features_dict['STROPHE_METER']}"
            
        elif format != 'METER_VERSE':
            poet_param_str = "# "
            if "RHYME" in features_dict.keys():
                poet_param_str += features_dict["RHYME"]
            if "YEAR" in features_dict.keys():
                poet_param_str += f" # {features_dict['YEAR']}"
            if 'STROPHE_METER' in features_dict.keys():
                poet_param_str += f" # {features_dict['STROPHE_METER']}"
        # NEW (METER_VERSE): meter is carried per verse, not on the header line.
        else:
            poet_param_str = "# "
            if "RHYME" in features_dict.keys():
                poet_param_str += features_dict["RHYME"]
            if "YEAR" in features_dict.keys():
                poet_param_str += f" # {features_dict['YEAR']}"
        # REPLACE OR INSERT BASED ON PRESENCE of a parameter line in the prompt.
        if len(features_dict_init.keys()) == 0: # Weird input: nothing recognized
            prompt_list = [poet_param_str]
        elif len(prompt_list) == 0: # Inputed as Dict
            prompt_list.append(poet_param_str)
        elif "RHYME" not in features_dict_init.keys():
            if "YEAR" in features_dict_init.keys() or 'STROPHE_METER' in features_dict_init.keys(): # Replace the incomplete first line
                prompt_list[0] = poet_param_str
            else:
                prompt_list.insert(0, poet_param_str)
        else:
            prompt_list[0] = poet_param_str
            
        # Number of verses dictated by the length of the rhyme schema string.
        verse_len = len(features_dict["RHYME"])
        
        # Finish possible not completed lines supplied by the user; each pass
        # regenerates the prompt up to line i and records the forced verse
        # parameters (meter/length/end) for that line's rhyme slot.
        base_prompt_len = len(prompt_list)
        for i in range(2,base_prompt_len + 1):
            rhyme_char = 0
            if features_dict["RHYME"][(i - 2) % len(features_dict["RHYME"])] == "B":
                rhyme_char = 1
            elif features_dict["RHYME"][(i - 2) % len(features_dict["RHYME"])] == "C":
                rhyme_char = 2
            elif features_dict["RHYME"][(i - 2) % len(features_dict["RHYME"])] == "D":
                rhyme_char = 3
            elif features_dict["RHYME"][(i - 2) % len(features_dict["RHYME"])] == "X":
                rhyme_char = -1  # X = unrhymed: no slot parameters recorded
                
            token_gen_finish = tokenizer.encode("\n".join(prompt_list[:i]), return_tensors='pt')
            if sample:
                finish_line = self.model.generate(token_gen_finish.to(device),
                                max_new_tokens= 100,
                                do_sample=True,
                                top_k=50,
                                early_stopping=True,
                                pad_token_id=tokenizer.pad_token_id,
                                eos_token_id=tokenizer.eos_token_id)
            else:
                finish_line = self.model.generate(token_gen_finish.to(device),
                                max_new_tokens= 100,
                                num_beams=8,
                                no_repeat_ngram_size=2,
                                early_stopping=True,
                                pad_token_id=tokenizer.pad_token_id,
                                eos_token_id=tokenizer.eos_token_id)
            decoded = tokenizer.decode(finish_line.cpu()[0], skip_special_tokens=True).splitlines()
            to_dec = min(i, len(decoded))
            prompt_list[:to_dec] = decoded[:to_dec]
            
            if to_dec - 1 < len(prompt_list):
                dec_line = prompt_list[to_dec-1]
                # OLD formats: "<length> # <end> # <text>" (or without one '#').
                if format == 'VERSE_PAR' or format == 'OLD':
                    if f"END_{rhyme_char}" not in features_dict.keys() and len(dec_line.split()) > 1 and rhyme_char>=0 and dec_line.count("#") <=1:
                        features_dict[f'LENGTH_{rhyme_char}'] = dec_line.split()[0]
                        features_dict[f'END_{rhyme_char}'] = dec_line.split()[1]
                    elif f"END_{rhyme_char}" not in features_dict.keys() and len(dec_line.split()) > 2 and rhyme_char>=0:
                        features_dict[f'LENGTH_{rhyme_char}'] = dec_line.split()[0]
                        features_dict[f'END_{rhyme_char}'] = dec_line.split()[2]
                # NEW format: "<meter> # <length> # <end> # <text>".
                elif format == 'METER_VERSE':
                    if f"END_{rhyme_char}" not in features_dict.keys() and len(dec_line.split()) > 4 and rhyme_char>=0:
                        features_dict[f'METER_{rhyme_char}'] = dec_line.split()[0]
                        features_dict[f'LENGTH_{rhyme_char}'] = dec_line.split()[2]
                        features_dict[f'END_{rhyme_char}'] = dec_line.split()[4]
            
        
        
        # Generating remaining verses until the rhyme schema is filled.
        has_rep= False
        has_rep_again = False
        while len(prompt_list) <= verse_len:
            j = 0
            if features_dict["RHYME"][(len(prompt_list) - 1) % len(features_dict["RHYME"])] == "B":
                j = 1
            elif features_dict["RHYME"][(len(prompt_list) - 1) % len(features_dict["RHYME"])] == "C":
                j = 2
            elif features_dict["RHYME"][(len(prompt_list) - 1) % len(features_dict["RHYME"])] == "D":
                j = 3
            elif features_dict["RHYME"][(len(prompt_list) - 1) % len(features_dict["RHYME"])] == "X":
                j=-1
            # Build the forced prefix of the next verse from the parameters
            # already recorded for rhyme slot j (format-dependent).
            if format == 'BASIC':
                line_start = ""
            elif format == 'OLD':
                line_start =  (f"{features_dict[f'LENGTH_{j}']} " if  f"LENGTH_{j}" in features_dict.keys() else "" ) + \
                    (f" {features_dict[f'END_{j}'] } #" if  f"END_{j}" in features_dict.keys() else "")
            elif format == 'VERSE_PAR':
                line_start =  (f"{features_dict[f'LENGTH_{j}']} #" if  f"LENGTH_{j}" in features_dict.keys() else "" ) + \
                    (f" {features_dict[f'END_{j}'] } #" if  f"END_{j}" in features_dict.keys() else "")
            else:
                line_start = (f"{features_dict[f'METER_{j}'] } #" if  f"METER_{j}" in features_dict.keys() else "") + \
                    (f" {features_dict[f'LENGTH_{j}']} #" if  f"LENGTH_{j}" in features_dict.keys() else "" ) + \
                    (f" {features_dict[f'END_{j}'] } #" if  f"END_{j}" in features_dict.keys() else "")
            tokenized_poet_start = tokenizer.encode("\n".join(prompt_list) + "\n" + line_start, return_tensors='pt')
            if sample:
                out_line = self.model.generate(tokenized_poet_start.to(device),
                                max_new_tokens= 100,
                                do_sample=True,
                                top_k=50,
                                early_stopping=True,
                                pad_token_id=tokenizer.pad_token_id,
                                eos_token_id=tokenizer.eos_token_id)
            else:
                out_line = self.model.generate(tokenized_poet_start.to(device),
                                max_new_tokens= 100,
                                num_beams=2,
                                no_repeat_ngram_size=2,
                                early_stopping=True,
                                pad_token_id=tokenizer.pad_token_id,
                                eos_token_id=tokenizer.eos_token_id)
            decoded_lines = tokenizer.decode(out_line.cpu()[0], skip_special_tokens=True).splitlines()
            # Repetition catcher: if the model produced no new line, retry;
            # on a second failure drop the last verse (first flag pass), and
            # after both flags are set accept the model's last line as-is.
            
            # Possible repetition: no line beyond the prompt was generated.
            if len(decoded_lines) <= len(prompt_list) and not(has_rep_again and has_rep):
                if has_rep:
                    prompt_list.pop()
                    has_rep= False
                    has_rep_again = True
                else:
                    has_rep = True
                continue
            if has_rep_again and has_rep:
                decoded_line: str = decoded_lines[-1]
            else:
                decoded_line: str = decoded_lines[len(prompt_list)]
            # OLD formats: record slot parameters from the freshly generated verse.
            if format == 'VERSE_PAR' or format == 'OLD':
                if f"END_{j}" not in features_dict.keys() and len(decoded_line.split()) > 1 and j>=0 and decoded_line.count("#") <=1:
                    features_dict[f'LENGTH_{j}'] = decoded_line.split()[0]
                    features_dict[f'END_{j}'] = decoded_line.split()[1]
                elif f"END_{j}" not in features_dict.keys() and len(decoded_line.split()) > 2 and j>=0:
                    features_dict[f'LENGTH_{j}'] = decoded_line.split()[0]
                    features_dict[f'END_{j}'] = decoded_line.split()[2]
            # NEW format: meter, length and end occupy fixed token positions.
            elif format == 'METER_VERSE':
                if f"END_{j}" not in features_dict.keys() and len(decoded_line.split()) > 4 and j>=0:
                    features_dict[f'METER_{j}'] = decoded_line.split()[0]
                    features_dict[f'LENGTH_{j}'] = decoded_line.split()[2]
                    features_dict[f'END_{j}'] = decoded_line.split()[4]
                
            prompt_list.append(decoded_line)
            
        return "\n".join(prompt_list)
|
277 |
+
|
278 |
+
|
279 |
+
class PoetModelBase(PoetModelFunctionalInterface):
    """Plain causal-LM wrapper with no auxiliary task heads."""

    def __init__(self, pretrainedModel, *args, **kwargs) -> None:
        """Load the pretrained causal LM and record its hidden-layer width.

        Args:
            pretrainedModel: model name or path understood by AutoModelForCausalLM.
        """
        super().__init__(*args, **kwargs)

        self.model = AutoModelForCausalLM.from_pretrained(pretrainedModel, output_hidden_states=True)

        config = self.model.config
        # Hidden-width attribute name differs between architectures
        # (GPT-style "n_embd" vs BERT-style "hidden_size").
        self.model_size = 1
        for size_attr in ("n_embd", "hidden_size"):
            if hasattr(config, size_attr):
                self.model_size = getattr(config, size_attr)
                break

    def forward(self, input_ids=None, labels=None, attention_mask=None, *args, **kwargs):
        """Run the wrapped LM and expose its loss and raw output."""
        lm_out = self.model(input_ids=input_ids, labels=labels, attention_mask=attention_mask)

        return ModelOutput(loss=lm_out.loss, model_output=lm_out)  # {"model_output" : outputs,"loss" : outputs.loss}

    def save_LM(self, LM_path):
        """Persist only the underlying language model to LM_path."""
        self.model.save_pretrained(LM_path, safe_serialization=False)
|
302 |
+
|
303 |
+
class PoetModelAllTasks(PoetModelFunctionalInterface):
    """Causal LM with auxiliary heads for vowel count, rhyme scheme,
    verse-ending syllable, meter type and year bucket.

    Each auxiliary loss is added to the LM loss with weight 0.1.
    """

    def __init__(self, pretrainedModel, *args, **kwargs) -> None:
        """Load the pretrained LM and attach one linear head per auxiliary task.

        Args:
            pretrainedModel: model name or path understood by AutoModelForCausalLM.
        """
        super().__init__(*args, **kwargs)

        self.model = AutoModelForCausalLM.from_pretrained(pretrainedModel, output_hidden_states=True)

        model_config = self.model.config
        self.model_size = 1
        # Hidden-width attribute name differs between architectures.
        if hasattr(model_config, "n_embd"):
            self.model_size = model_config.n_embd
        elif hasattr(model_config, "hidden_size"):
            self.model_size = model_config.hidden_size

        self.vowels_regressor = torch.nn.Linear(self.model_size, 1)                         # Vowel Count
        self.rhyme_regressor = torch.nn.Linear(self.model_size, len(StropheParams.RHYME))   # Rhyme Type
        self.verse_endings = torch.nn.Linear(self.model_size, len(StropheParams.ENDS))      # Verse End Syllable
        self.metre_regressor = torch.nn.Linear(self.model_size, len(StropheParams.METER))   # Meter Type
        self.year_regressor = torch.nn.Linear(self.model_size, len(StropheParams.YEAR))     # Year Bucket

    def forward(self, input_ids=None, labels=None, attention_mask=None, nums=None, rhyme=None, verse_end=None, year=None, metre=None, *args, **kwargs):
        """Compute LM loss plus 0.1-weighted auxiliary losses for every label set provided.

        All heads read the hidden state of the first (class) token of the last layer.

        Returns:
            dict with each head's raw output, its loss (None when its labels
            are absent) and the combined "loss".
        """
        outputs = self.model(input_ids=input_ids, labels=labels, attention_mask=attention_mask)
        # Class-token state of the last layer feeds every auxiliary head.
        cls_state = outputs['hidden_states'][-1][:, 0, :].view(-1, self.model_size)
        vowel_regression = self.vowels_regressor(cls_state)
        rhyme_regression = self.rhyme_regressor(cls_state)
        verse_end_reg = self.verse_endings(cls_state)
        metre_regression = self.metre_regressor(cls_state)
        year_regression = self.year_regressor(cls_state)
        full_loss = outputs.loss

        vowel_loss = None
        if nums is not None:
            vowel_loss = torch.nn.MSELoss()(vowel_regression.view(-1, 1), nums.view(-1, 1))
            full_loss = full_loss + 0.1 * vowel_loss

        # BUGFIX: CrossEntropyLoss expects raw logits (it applies log-softmax
        # internally); the original fed it softmax-ed probabilities, which
        # flattens gradients and miscalibrates the auxiliary losses.
        ce_loss = torch.nn.CrossEntropyLoss()

        rhyme_loss = None
        if rhyme is not None:
            rhyme_loss = ce_loss(rhyme_regression, rhyme)
            full_loss = full_loss + 0.1 * rhyme_loss

        verse_loss = None
        if verse_end is not None:
            verse_loss = ce_loss(verse_end_reg, verse_end)
            full_loss = full_loss + 0.1 * verse_loss

        metre_loss = None
        if metre is not None:
            metre_loss = ce_loss(metre_regression, metre)
            full_loss = full_loss + 0.1 * metre_loss

        year_loss = None
        if year is not None:
            year_loss = ce_loss(year_regression, year)
            full_loss = full_loss + 0.1 * year_loss

        return {"model_output" : outputs,
                "vowel_regression_output": vowel_regression,
                "vowel_regression_loss": vowel_loss,
                "rhyme_regression_output": rhyme_regression,
                "rhyme_regression_loss": rhyme_loss,
                "verse_end_regression_output" : verse_end_reg,
                "verse_end_regression_loss" : verse_loss,
                "metre_regression_output" : metre_regression,
                "metre_regression_loss" : metre_loss,
                "year_regression_output" : year_regression,
                "year_regression_loss" : year_loss,
                "loss": full_loss}

    def save_LM(self, LM_path):
        """Persist only the underlying language model to LM_path."""
        self.model.save_pretrained(LM_path, safe_serialization=False)
384 |
+
|
385 |
+
from .poet_model_utils import ContextModule
|
386 |
+
|
387 |
+
class PoetModelContextInput(PoetModelFunctionalInterface):
    """Causal LM with a ContextModule spliced between its blocks and a rhyme head."""

    def __init__(self, pretrainedModel, context_input_size:int = 2048, block_count:int=3, *args, **kwargs) -> None:
        """Load the pretrained LM, splice in the context module and add the rhyme head.

        Args:
            pretrainedModel: model name or path understood by AutoModelForCausalLM.
            context_input_size: input size of the context sub-LM. Defaults to 2048.
            block_count: number of GPT2 blocks inside the context sub-LM. Defaults to 3.
        """
        super().__init__(*args, **kwargs)

        self.model = AutoModelForCausalLM.from_pretrained(pretrainedModel, output_hidden_states=True)

        model_config = self.model.config
        self.model_size = -1
        # Hidden-width attribute name differs between architectures.
        if hasattr(model_config, "n_embd"):
            self.model_size = model_config.n_embd
        elif hasattr(model_config, "hidden_size"):
            self.model_size = model_config.hidden_size # Number of Embeddings taken from config
        self.context_size = context_input_size

        self.model.base_model.h.insert(3, ContextModule(block_count, context_input_size, self.model_size, self.model_size))
        # Because of Inserted Layer, Head Masks don't match => Add 1 more
        self.model.base_model.config.n_layer += 1

        self.rhyme_regressor = torch.nn.Linear(self.model_size, len(StropheParams.RHYME)) # Rhyme Type

    def forward(self, input_ids=None, labels=None, attention_mask=None, rhyme=None, context_ids=None, context_attention_mask=None,*args, **kwargs):
        """Compute LM loss plus a rhyme classification loss on the class token.

        Context is injected by attribute assignment because GPT2Block forward
        signatures cannot carry extra tensors.
        """
        # Inject Context to bypass GPT2Blocks (Can't Forward it)
        self.model.base_model.h[3].context_ids = context_ids
        self.model.base_model.h[3].context_attention_mask = context_attention_mask

        outputs = self.model(input_ids=input_ids, labels=labels, attention_mask=attention_mask)
        last_hidden = outputs['hidden_states'][-1]
        rhyme_regression = self.rhyme_regressor(last_hidden[:, 0, :].view(-1, self.model_size))
        full_loss = outputs.loss

        rhyme_loss = None
        if rhyme is not None:
            # BUGFIX: pass raw logits to CrossEntropyLoss (it applies
            # log-softmax itself); the original softmax-ed the logits first.
            rhyme_loss = torch.nn.CrossEntropyLoss()(rhyme_regression, rhyme)
            full_loss = full_loss + rhyme_loss
        # Delete the Injection to prevent Dataloss
        self.model.base_model.h[3].context_ids = None
        self.model.base_model.h[3].context_attention_mask = None

        return {"model_output" : outputs,
                "rhyme_regression_output": rhyme_regression,
                "rhyme_regression_loss": rhyme_loss,
                "loss": full_loss}

    def save_LM(self, LM_path):
        """Persist only the underlying language model to LM_path."""
        self.model.save_pretrained(LM_path)
437 |
+
|
438 |
+
from .poet_model_utils import PoetTypeModule
|
439 |
+
|
440 |
+
class PoetModelContextYear(PoetModelFunctionalInterface):
    """Causal LM with injected context and poet-type modules plus rhyme and year heads."""

    def __init__(self, pretrainedModel, context_input_size:int = 2048, block_count:int=3, *args, **kwargs) -> None:
        """Load the pretrained LM, splice in both auxiliary modules and add the heads.

        Args:
            pretrainedModel: model name or path understood by AutoModelForCausalLM.
            context_input_size: input size of the auxiliary sub-LMs. Defaults to 2048.
            block_count: number of GPT2 blocks inside each sub-LM. Defaults to 3.
        """
        super().__init__(*args, **kwargs)

        self.model = AutoModelForCausalLM.from_pretrained(pretrainedModel, output_hidden_states=True)

        model_config = self.model.config
        self.model_size = -1
        # Hidden-width attribute name differs between architectures.
        if hasattr(model_config, "n_embd"):
            self.model_size = model_config.n_embd
        elif hasattr(model_config, "hidden_size"):
            self.model_size = model_config.hidden_size # Number of Embeddings taken from config
        self.context_size = context_input_size

        # The second insert at index 3 shifts the ContextModule to index 4,
        # so h[3] is the PoetTypeModule and h[4] the ContextModule.
        self.model.base_model.h.insert(3, ContextModule(block_count, context_input_size, self.model_size, self.model_size))
        self.model.base_model.h.insert(3, PoetTypeModule(block_count, context_input_size, self.model_size, self.model_size))
        # Because of Inserted Layers, Head Masks don't match => Add 2 more
        self.model.base_model.config.n_layer += 2

        self.rhyme_regressor = torch.nn.Linear(self.model_size, len(StropheParams.RHYME)) # Rhyme Type
        self.year_regressor = torch.nn.Linear(self.model_size, len(StropheParams.YEAR))   # Year Bucket

    def forward(self, input_ids=None, labels=None, attention_mask=None, rhyme=None, context_ids=None, context_attention_mask=None, year=None,*args, **kwargs):
        """Compute LM loss plus rhyme and year classification losses.

        When year labels are supplied, the poet-type module's own loss
        (h[3].indiv_loss) is added as well.
        """
        # Inject Context to bypass GPT2Blocks (Can't Forward it)
        self.model.base_model.h[3].context_ids = context_ids
        self.model.base_model.h[3].context_attention_mask = context_attention_mask
        self.model.base_model.h[3].type_labels = year

        self.model.base_model.h[4].context_ids = context_ids
        self.model.base_model.h[4].context_attention_mask = context_attention_mask

        outputs = self.model(input_ids=input_ids, labels=labels, attention_mask=attention_mask)
        cls_state = outputs['hidden_states'][-1][:, 0, :].view(-1, self.model_size)
        rhyme_regression = self.rhyme_regressor(cls_state)
        full_loss = outputs.loss

        # BUGFIX: CrossEntropyLoss expects raw logits (it applies log-softmax
        # internally); the original fed it softmax-ed probabilities.
        rhyme_loss = None
        if rhyme is not None:
            rhyme_loss = torch.nn.CrossEntropyLoss()(rhyme_regression, rhyme)
            full_loss = full_loss + rhyme_loss

        year_regression = self.year_regressor(cls_state)

        year_loss = None
        if year is not None:
            year_loss = torch.nn.CrossEntropyLoss()(year_regression, year)
            full_loss = full_loss + year_loss + self.model.base_model.h[3].indiv_loss

        # Delete the Injection to prevent Dataloss
        self.model.base_model.h[3].context_ids = None
        self.model.base_model.h[3].context_attention_mask = None
        self.model.base_model.h[3].type_labels = None
        # Delete Loss
        self.model.base_model.h[3].indiv_loss = None

        self.model.base_model.h[4].context_ids = None
        self.model.base_model.h[4].context_attention_mask = None

        return {"model_output" : outputs,
                "rhyme_regression_output": rhyme_regression,
                "rhyme_regression_loss": rhyme_loss,
                "year_regression_output" : year_regression,
                "year_loss" : year_loss,
                "loss": full_loss}

    def save_LM(self, LM_path):
        """Persist only the underlying language model to LM_path."""
        self.model.save_pretrained(LM_path)
515 |
+
|
516 |
+
|
517 |
+
class DistilModel(PoetModelFunctionalInterface):
    """Distilled variant: keeps every other transformer block of the teacher
    and trains the survivors to replicate the teacher's hidden states."""

    def __init__(self, pretrainedModel, *args, **kwargs) -> None:
        """Load the teacher LM and prune it down to the kept blocks.

        Args:
            pretrainedModel: model name or path understood by AutoModelForCausalLM.
        """
        super().__init__(*args, **kwargs)

        self.model = AutoModelForCausalLM.from_pretrained(pretrainedModel, output_hidden_states=True)

        model_config = self.model.config
        self.model_size = 1
        # Hidden-width attribute name differs between architectures.
        if hasattr(model_config, "n_embd"):
            self.model_size = model_config.n_embd
        elif hasattr(model_config, "hidden_size"):
            self.model_size = model_config.hidden_size

        # Indexes of teacher blocks that survive distillation.
        self.kept_states = [1, 3, 5, 7, 9, 11]

        # Drop every block not in kept_states; walk indexes high-to-low so
        # earlier pops don't shift the ones still to be removed.
        dropped = set(range(len(self.model.base_model.h))) - set(self.kept_states)
        for pop_index in sorted(dropped, reverse=True):
            self.model.base_model.h.pop(pop_index)
        # Head masks must match the reduced layer count.
        self.model.base_model.config.n_layer = len(self.kept_states)

        self.loss_fnc = torch.nn.MSELoss()

    def forward(self, input_ids=None, labels=None, attention_mask=None, to_replicate_states= None, *args, **kwargs):
        """LM loss plus MSE between each kept layer's hidden state and the
        matching teacher state (index -1 stands for the embedding layer)."""
        outputs = self.model(input_ids=input_ids, labels=labels, attention_mask=attention_mask)
        loss = outputs.loss
        # The 6 layers + embeddings (add + 1 to shift the original_index)
        for distil_index, original_index in enumerate([-1] + self.kept_states):
            loss += self.loss_fnc(outputs['hidden_states'][distil_index], to_replicate_states[original_index + 1])

        return {"model_output" : outputs,
                "loss": loss}

    def save_LM(self, LM_path):
        """Persist only the underlying language model to LM_path."""
        self.model.save_pretrained(LM_path, safe_serialization=False)

    def generate_forced(self, *args, **kwargs):
        """Forced generation is not supported by the distilled model."""
        raise NotImplementedError("Currently without")
557 |
+
|
558 |
+
class PoetModelHalfBase(PoetModelFunctionalInterface):
    """Plain causal-LM wrapper loaded in float16 half precision."""

    def __init__(self, pretrainedModel, *args, **kwargs) -> None:
        """Load the pretrained LM in float16 and record its hidden-layer width.

        Args:
            pretrainedModel: model name or path understood by AutoModelForCausalLM.
        """
        super().__init__(*args, **kwargs)

        self.model = AutoModelForCausalLM.from_pretrained(pretrainedModel, output_hidden_states=True, torch_dtype=torch.float16)

        config = self.model.config
        # Hidden-width attribute name differs between architectures.
        self.model_size = -1
        for size_attr in ("n_embd", "hidden_size"):
            if hasattr(config, size_attr):
                self.model_size = getattr(config, size_attr)
                break

    def forward(self, input_ids=None, labels=None, attention_mask=None, *args, **kwargs):
        """Run the wrapped LM and expose its raw output and loss."""
        lm_out = self.model(input_ids=input_ids, labels=labels, attention_mask=attention_mask)

        return {"model_output" : lm_out,
                "loss" : lm_out.loss}

    def save_LM(self, LM_path):
        """Persist only the underlying language model to LM_path."""
        self.model.save_pretrained(LM_path)
581 |
+
|
582 |
+
|
583 |
+
class PoetModelSecondaryTasks(PoetModelFunctionalInterface):
    """Causal LM with vowel-count and rhyme auxiliary heads."""

    def __init__(self, pretrainedModel, *args, **kwargs) -> None:
        """Load the pretrained LM and attach the two auxiliary heads.

        Args:
            pretrainedModel: model name or path understood by AutoModelForCausalLM.
        """
        super().__init__(*args, **kwargs)

        self.model = AutoModelForCausalLM.from_pretrained(pretrainedModel, output_hidden_states=True)

        model_config = self.model.config
        self.model_size = -1
        # Hidden-width attribute name differs between architectures.
        if hasattr(model_config, "n_embd"):
            self.model_size = model_config.n_embd
        elif hasattr(model_config, "hidden_size"):
            self.model_size = model_config.hidden_size # Number of Embeddings taken from config
        self.vowels_regressor = torch.nn.Linear(self.model_size, 1)                        # Vowel count
        self.rhyme_regressor = torch.nn.Linear(self.model_size, len(StropheParams.RHYME))  # Rhyme Type

    def forward(self, input_ids=None, labels=None, attention_mask=None, nums=None, rhyme=None, *args, **kwargs):
        """Compute LM loss plus MSE vowel loss and rhyme classification loss.

        Both heads read the class-token hidden state of the last layer.
        Each auxiliary loss is None when its labels are absent.
        """
        outputs = self.model(input_ids=input_ids, labels=labels, attention_mask=attention_mask)
        cls_state = outputs['hidden_states'][-1][:, 0, :].view(-1, self.model_size)
        vowel_regression = self.vowels_regressor(cls_state)
        rhyme_regression = self.rhyme_regressor(cls_state)
        full_loss = outputs.loss

        vowel_loss = None
        if nums is not None:
            vowel_loss = torch.nn.MSELoss()(vowel_regression.view(-1, 1), nums.view(-1, 1))
            full_loss = full_loss + vowel_loss

        rhyme_loss = None
        if rhyme is not None:
            # BUGFIX: pass raw logits to CrossEntropyLoss (it applies
            # log-softmax itself); the original softmax-ed the logits first.
            rhyme_loss = torch.nn.CrossEntropyLoss()(rhyme_regression, rhyme)
            full_loss = full_loss + rhyme_loss

        return {"model_output" : outputs,
                "vowel_regression_output": vowel_regression,
                "vowel_regression_loss": vowel_loss,
                "rhyme_regression_output": rhyme_regression,
                "rhyme_regression_loss": rhyme_loss,
                "loss": full_loss}

    def save_LM(self, LM_path):
        """Persist only the underlying language model to LM_path."""
        self.model.save_pretrained(LM_path)
630 |
+
|
631 |
+
|
632 |
+
class PoetModelVerseEnd(PoetModelFunctionalInterface):
    """Causal LM with vowel-count, rhyme and verse-ending auxiliary heads."""

    def __init__(self, pretrainedModel, *args, **kwargs) -> None:
        """Load the pretrained LM and attach the three auxiliary heads.

        Args:
            pretrainedModel: model name or path understood by AutoModelForCausalLM.
        """
        super().__init__(*args, **kwargs)

        self.model = AutoModelForCausalLM.from_pretrained(pretrainedModel, output_hidden_states=True)

        model_config = self.model.config
        self.model_size = -1
        # Hidden-width attribute name differs between architectures.
        if hasattr(model_config, "n_embd"):
            self.model_size = model_config.n_embd
        elif hasattr(model_config, "hidden_size"):
            self.model_size = model_config.hidden_size # Number of Embeddings taken from config
        self.vowels_regressor = torch.nn.Linear(self.model_size, 1)                        # Vowel count
        self.rhyme_regressor = torch.nn.Linear(self.model_size, len(StropheParams.RHYME))  # Rhyme Type
        self.verse_endings = torch.nn.Linear(self.model_size, len(StropheParams.ENDS))     # Verse End Syllable

    def forward(self, input_ids=None, labels=None, attention_mask=None, nums=None, rhyme=None, verse_end = None, *args, **kwargs):
        """Compute LM loss plus vowel MSE, rhyme and verse-ending classification losses.

        All heads read the class-token hidden state of the last layer.
        Each auxiliary loss is None when its labels are absent.
        """
        outputs = self.model(input_ids=input_ids, labels=labels, attention_mask=attention_mask)
        cls_state = outputs['hidden_states'][-1][:, 0, :].view(-1, self.model_size)
        vowel_regression = self.vowels_regressor(cls_state)
        rhyme_regression = self.rhyme_regressor(cls_state)
        verse_end_reg = self.verse_endings(cls_state)
        full_loss = outputs.loss

        vowel_loss = None
        if nums is not None:
            vowel_loss = torch.nn.MSELoss()(vowel_regression.view(-1, 1), nums.view(-1, 1))
            full_loss = full_loss + vowel_loss

        # BUGFIX: CrossEntropyLoss expects raw logits (it applies log-softmax
        # internally); the original fed it softmax-ed probabilities.
        ce_loss = torch.nn.CrossEntropyLoss()

        rhyme_loss = None
        if rhyme is not None:
            rhyme_loss = ce_loss(rhyme_regression, rhyme)
            full_loss = full_loss + rhyme_loss

        verse_loss = None
        if verse_end is not None:
            verse_loss = ce_loss(verse_end_reg, verse_end)
            full_loss = full_loss + verse_loss

        return {"model_output" : outputs,
                "vowel_regression_output": vowel_regression,
                "vowel_regression_loss": vowel_loss,
                "rhyme_regression_output": rhyme_regression,
                "rhyme_regression_loss": rhyme_loss,
                "verse_end_regression_output" : verse_end_reg,
                "verse_end_regression_loss" : verse_loss,
                "loss": full_loss}

    def save_LM(self, LM_path):
        """Persist only the underlying language model to LM_path."""
        self.model.save_pretrained(LM_path)
utils/poet_model_utils.py
ADDED
@@ -0,0 +1,272 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
class PoetModelInterface(torch.nn.Module):
    """Abstract base class for every poet model variant.

    Inherits torch.nn.Module so subclasses integrate with torch and huggingface.
    """

    def __init__(self, *args, **kwargs) -> None:
        """Forward construction to the torch.nn.Module parent."""
        super().__init__(*args, **kwargs)

    def forward(self, input_ids=None, labels=None, attention_mask=None, *args, **kwargs):
        """Compute model output and model loss; subclasses must override.

        Args:
            input_ids (_type_, optional): Model inputs. Defaults to None.
            labels (_type_, optional): Language Model labels. Defaults to None.
            attention_mask (_type_, optional): Attention mask where padding starts. Defaults to None.

        Raises:
            NotImplementedError: abstract method.
        """
        raise NotImplementedError()

    def generate_forced(self, *args, **kwargs):
        """Generate output constrained by inputs and past generation; subclasses must override.

        Raises:
            NotImplementedError: abstract method.
        """
        raise NotImplementedError()

    @staticmethod
    def rhyme_like(rhyme: str):
        """DEPRECATED: heuristic test whether a string looks like a rhyme scheme.

        Args:
            rhyme (str): Candidate rhyme-scheme string.

        Returns:
            bool: True when the string is 4 or 6 characters long and upper-case.
        """
        return len(rhyme) in (4, 6) and rhyme.isupper()

    def save_LM(self, LM_path):
        """Save the raw LM to LM_path; subclasses must override.

        Args:
            LM_path (str): Where to store the LM.

        Raises:
            NotImplementedError: abstract method.
        """
        raise NotImplementedError()
58 |
+
|
59 |
+
|
60 |
+
from transformers import GPT2Config, GPT2Model
|
61 |
+
from .poet_utils import StropheParams
|
62 |
+
|
63 |
+
class ContextModule(torch.nn.Module):
    """Module for understanding poet context.

    Runs a small GPT-2 LM over injected context ids and adds a linear
    projection of its class-token state to the host model's hidden states.
    Child of torch.nn.Module for integration with torch and huggingface.
    """

    def __init__(self, block_count, input_size, n_embd, output_size, *args, **kwargs) -> None:
        """Construct the underlying small LM for context.

        Args:
            block_count (_type_): LM number of blocks of GPT2Block
            input_size (_type_): LM size of input
            n_embd (_type_): LM size of hidden layers
            output_size (_type_): LM size of output
        """
        super().__init__(*args, **kwargs)
        self.config = GPT2Config(n_positions=input_size, n_head=(n_embd//(768//12)), n_embd=n_embd,
                                 n_layer=block_count, output_hidden_states=True, output_attentions=True)
        self.context_model = GPT2Model(self.config)
        self.linear_downscale = torch.nn.Linear(n_embd, output_size)
        self.input_size = input_size
        self.n_embd = n_embd
        self.output_size = output_size
        # Context is getting injected from Outside (attribute assignment by the host model)
        self.context_ids = None
        self.context_attention_mask = None

    def forward(self, hidden_states, layer_past=None, *args, **kwargs):
        """Compute Context LM output; data are injected from outside.

        Args:
            hidden_states (_type_): Current hidden states
            layer_past (_type_, optional): Past layer outputs. Defaults to None.

        Returns:
            _type_: GPT2Block structured output (hidden states, layer past, attention, keys)
        """
        down = torch.zeros_like(hidden_states)
        model_output = None
        # Sometimes there might be no context.
        # BUGFIX: use "is not None" — "!= None" on an injected tensor is an
        # elementwise comparison, not a presence test, and makes the truth
        # value of the condition ambiguous/incorrect.
        if self.context_ids is not None:
            model_output = self.context_model.forward(input_ids=self.context_ids, attention_mask=self.context_attention_mask)
            # Take only the Class token state and downscale it to output_size.
            down = self.linear_downscale.forward(model_output["hidden_states"][-1][:, 0, :].view(-1, self.n_embd))[:, None, :]
        return (hidden_states + down,
                down[None, :, :, :],
                (None if model_output is None else model_output["attentions"],
                 None))
112 |
+
|
113 |
+
class PoetTypeModule(torch.nn.Module):
    """Module to classify poet type (year bucket) from injected context.

    Runs a small GPT-2 LM over the context, predicts a year-bucket
    distribution, and adds an upscaled projection of it to the host
    model's hidden states. Child of torch.nn.Module for integration
    with torch and huggingface.
    """

    def __init__(self, block_count, input_size, n_embd, output_size, *args, **kwargs) -> None:
        """Construct LM for poet classification from inputs.

        Args:
            block_count (_type_): LM number of blocks of GPT2Block
            input_size (_type_): LM size of input
            n_embd (_type_): LM size of hidden layers
            output_size (_type_): LM size of output
        """
        super().__init__(*args, **kwargs)
        self.config = GPT2Config(n_positions=input_size, n_head=(n_embd//(768//12)), n_embd=n_embd,
                                 n_layer=block_count, output_hidden_states=True, output_attentions=True)
        self.type_model = GPT2Model(self.config)
        self.type_predict = torch.nn.Linear(n_embd, len(StropheParams.YEAR))
        self.softmax = torch.nn.Softmax()
        self.linear_scale = torch.nn.Linear(len(StropheParams.YEAR), output_size)
        self.input_size = input_size
        self.n_embd = n_embd
        self.output_size = output_size
        # Context and labels are getting injected from Outside (host model)
        self.context_ids = None
        self.context_attention_mask = None
        self.type_labels = None
        # Store for loss for model itself; read (and cleared) by the host model
        self.indiv_loss = None

    def forward(self, hidden_states, layer_past=None, *args, **kwargs):
        """Compute Classification LM output and loss.

        Args:
            hidden_states (_type_): Current hidden states
            layer_past (_type_, optional): Past layer outputs. Defaults to None.

        Returns:
            _type_: GPT2Block structured output (hidden states, layer past, attention, keys)
        """
        type_prob = torch.zeros((hidden_states.shape[0], len(StropheParams.YEAR))).to("cuda" if torch.cuda.is_available() else "cpu")
        model_output = None
        # Sometimes there might be no context.
        # BUGFIX: "is not None" instead of "!= None" — tensor != None is an
        # elementwise comparison, not a presence test.
        if self.context_ids is not None:
            model_output = self.type_model.forward(input_ids=self.context_ids, attention_mask=self.context_attention_mask)
            # Only Class token state is used for the prediction.
            poet_type = self.type_predict.forward(model_output["hidden_states"][-1][:, 0, :].view(-1, self.n_embd))
            type_prob = self.softmax.forward(poet_type)
            # If type labels are present, inject the true labels to future blocks (teacher forcing).
            if self.type_labels is not None:
                # BUGFIX: feed raw logits to CrossEntropyLoss (it applies
                # log-softmax internally); the original passed the softmax output.
                loss_fct = torch.nn.CrossEntropyLoss()
                self.indiv_loss = loss_fct(poet_type, self.type_labels)
                type_prob = (self.type_labels.type(torch.FloatTensor)).to("cuda" if torch.cuda.is_available() else "cpu")
        linear_up = self.linear_scale.forward(type_prob)
        return (hidden_states + linear_up[:, None, :],
                linear_up[None, :, None, :],
                (None if model_output is None else model_output["attentions"],
                 None))
174 |
+
|
175 |
+
from transformers import PreTrainedTokenizerBase
|
176 |
+
|
177 |
+
class ModelManipulation:
    """Static Class incorporating methods for Manipulation with LMs
    Code Inspired by article: Fine-tuning the English GPT-2 in any language with Hugging Face
    Link: https://github.com/piegu/fastai-projects/blob/master/finetuning-English-GPT2-any-language-Portuguese-HuggingFace-fastaiv2.ipynb
    """

    @staticmethod
    def exchange_embedding(poet_model: PoetModelInterface, new_tokenizer: PreTrainedTokenizerBase, old_tokenizer: PreTrainedTokenizerBase, mirror_imbed:bool=False):
        """Exchange embedding matrixes for GPT2 Models

        Tokens shared between the two vocabularies keep their old embedding
        vector; tokens new to new_tokenizer get the mean of all old embeddings.
        The output head is re-tied to the new input embedding matrix.

        Args:
            poet_model (PoetModelInterface): Model to manipulate with
            new_tokenizer (PreTrainedTokenizerBase): New tokenization
            old_tokenizer (PreTrainedTokenizerBase): Old tokenization
        """
        # NOTE(review): the mirror_imbed parameter is accepted but never used here — confirm intent.
        # Get old Embeddings (cloned/detached so the copy is independent of the live model)
        if hasattr(poet_model.model, "transformer"):
            old_embed_in = poet_model.model.transformer.get_input_embeddings().weight.clone().detach()
        else:
            old_embed_in = poet_model.model.get_input_embeddings().weight.clone().detach()
        old_mean_in = old_embed_in.mean(0)
        # Generate new Embedding based on new tokenization
        new_embd_in = old_embed_in.new_zeros(new_tokenizer.vocab_size, old_embed_in.size(1))
        old_vocab = old_tokenizer.get_vocab()

        vocab_hit = 0
        # Keep as much from old Embeddings as possible
        for w, idx_new in new_tokenizer.get_vocab().items():
            idx_old = old_vocab.get(w, -1)
            if idx_old >= 0:
                new_embd_in[idx_new] = old_embed_in[idx_old]
                vocab_hit +=1
            else:
                # Unknown token: fall back to the mean of all old embeddings
                new_embd_in[idx_new] = old_mean_in

        # NOTE(review): hits are counted over new_tokenizer's vocab, but the
        # denominator printed is old_tokenizer.vocab_size — confirm which rate is intended.
        print(f"Vocab hit rate: {vocab_hit}/{old_tokenizer.vocab_size}")
        #Exchange Embeddings and Decoding
        new_embd_layer_in = torch.nn.Embedding(new_tokenizer.vocab_size, old_embed_in.size(1))
        new_embd_layer_in.weight.data = new_embd_in
        if hasattr(poet_model.model, "transformer"):
            poet_model.model.transformer.set_input_embeddings(new_embd_layer_in)
        else:
            poet_model.model.set_input_embeddings(new_embd_layer_in)

        # New output projection, weight-tied to the (new) input embedding matrix
        new_decoder = torch.nn.Linear( old_embed_in.size(1), new_tokenizer.vocab_size, bias=False)
        if hasattr(poet_model.model, "transformer"):
            new_decoder.weight = poet_model.model.transformer.wte.weight
        else:
            new_decoder.weight = poet_model.model.base_model.embeddings.weight
        if hasattr(poet_model.model, "lm_head"):
            poet_model.model.lm_head = new_decoder
        else:
            poet_model.model.head = new_decoder


        # Update LM config to reflect possible change in vocab size
        poet_model.model.config.vocab_size = new_tokenizer.vocab_size


    @staticmethod
    def exchange_embedding_roberta(metre_model, new_tokenizer: PreTrainedTokenizerBase, old_tokenizer: PreTrainedTokenizerBase):
        """Exchange embedding matrixes for Roberta Models

        Same scheme as exchange_embedding: shared tokens keep their old
        vector, new tokens get the mean embedding, and the LM-head decoder
        is re-tied to the new word embeddings.

        Args:
            metre_model: Model to manipulate with
            new_tokenizer (PreTrainedTokenizerBase): New tokenization
            old_tokenizer (PreTrainedTokenizerBase): Old tokenization
        """
        # Get old Embeddings
        old_embed = metre_model.model.get_input_embeddings().weight.clone().detach()
        old_mean = old_embed.mean(0)
        # Generate new Embedding based on new tokenization
        new_embd = old_embed.new_zeros(new_tokenizer.vocab_size, old_embed.size(1))
        old_vocab = old_tokenizer.get_vocab()

        vocab_hit = 0
        # Keep as much from old Embeddings as possible
        for w, idx_new in new_tokenizer.get_vocab().items():
            idx_old = old_vocab.get(w, -1)
            if idx_old >= 0:
                new_embd[idx_new] = old_embed[idx_old]
                vocab_hit +=1
            else:
                # Unknown token: fall back to the mean of all old embeddings
                new_embd[idx_new] = old_mean

        # NOTE(review): same denominator question as in exchange_embedding above.
        print(f"Vocab hit rate: {vocab_hit}/{old_tokenizer.vocab_size}")
        #Exchange Embeddings and Decoding
        new_embd_layer = torch.nn.Embedding(new_tokenizer.vocab_size, old_embed.size(1))
        new_embd_layer.weight.data = new_embd
        metre_model.model.set_input_embeddings(new_embd_layer)
        new_decoder = torch.nn.Linear( old_embed.size(1), new_tokenizer.vocab_size)
        new_decoder.weight = metre_model.model.roberta.embeddings.word_embeddings.weight
        metre_model.model.lm_head.decoder = new_decoder
        # Update LM config to reflect possible change in vocab size
        metre_model.model.config.vocab_size = new_tokenizer.vocab_size
|
utils/poet_utils.py
ADDED
@@ -0,0 +1,591 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
class StropheParams:
    """Constant tables describing strophe properties (rhyme schemas, verse
    endings, publication-year buckets, meter types, and the character set
    used for rhyme/syllable analysis). Several short aliases (RHYME, ENDS,
    YEAR, METER, CHARS) point at the same lists for convenience."""

    # Most Common Rhyme Schemas (Every Rhyme schema with presence over 0.36 %)
    # Trailing None acts as the "unknown/other" class in one-hot vectors.
    RHYME_SCHEMES = ['ABAB', 'XXXX',
                     'XAXA','AABB',
                     'XXXXXX','ABBA',
                     'AAXX', 'AABBCC',
                     'ABABCC','ABABXX',
                     'AABCCB','XXAA',
                     'XAAX', 'AXAX',
                     'XAXAXX','XXABAB',
                     'ABBACC','AXAA',
                     'XAABBX','AABCBC',
                     'AABBXX','ABBAXX',
                     'ABABAB','AAXA',
                     'AXXA','XAXABB',
                     'XXAABB','XXAAXX',
                     'ABABAX','XXABBA',
                     'AAXBBX','XXXAXA',
                     'AAAX','XABABX',
                     'XABBAX','AAXXBB',
                     'AXABBX','ABABBX',
                     'XAAXBB','AAAA',
                     'XAAA','XAABXB',
                     'AXABXB','AXAXBB',
                     None]

    # Alias kept for call sites that use the short name
    RHYME = RHYME_SCHEMES

    # Fully-rhymed schemas (no 'X' free verses)
    NORMAL_SCHEMES = ["ABAB", "ABBA", "AABB", "AABBCC", "ABABCC", "ABBACC", "ABBAAB"]

    # First 200 Most common endings
    # Trailing None again marks the "other" class.
    VERSE_ENDS = ['ní', 'la', 'je', 'tí', 'ce', 'ti', 'ky', 'ku', 'li', 'jí', 'ně', 'né', 'vá', 'se', 'ny', 'ly', 'na', 'ne', 'nou',
                  'lo', 'ci', 'mi', 'ný', 'sti', 'ka', 'le', 'cí', 'ná', 'ží', 'čí', 'ho', 'dí', 'ší', 'du', 'lí', 'dy', 'nu', 'ří',
                  'ji', 'ru', 'tě', 'ře', 'stí', 'vy', 'ká', 'še', 'dá', 'ni', 'te', 'ví', 'mu', 'tu', 'ta', 'vé', 'val', 'va', 'lý',
                  'tá', 'že', 'ty', 'no', 'vu', 'lá', 'kem', 'chu', 'ků', 'bě', 'vý', 'sy', 'me', 'zí', 'hu', 'vě', 'lu', 'da', 'ry',
                  'rá', 'lé', 'ko', 'ři', 'de', 'hy', 'lem', 'tem', 'kou', 'vou', 'ši', 'há', 'sí', 'ze', 'be', 'ra', 'má', 'to', 'by',
                  'mě', 'su', 'té', 'si', 'ných', 'den', 'či', 'ký', 'ním', 'če', 'tý', 'ma', 'my', 'sem', 'nem', 'dě', 'ha', 'vat', 'ným',
                  'dem', 'dou', 'sta', 'dla', 'svět', 'zem', 'jen', 'dal', 'mí', 'hou', 'zas', 'sen', 'rem', 'nů', 'bu', 'e', 'ba', 'ké',
                  'til', 'jest', 'ství', 'děl', 'květ', 'tů', 'chem', 'lou', 'sám', 'bí', 'tou', 'dé', 'šel', 'nul', 'chá', 'vem', 'sa',
                  'hlas', 'pí', 'čas', 'dil', 'let', 'cích', 'lů', 'žil', 'mů', 'dál', 'cha', 'byl', 'nost', 'ček', 'zy', 'hý', 'nám', 'di',
                  'bou', 'tím', 'ži', 'tek', 'vil', 'jsem', 'sů', 'dech', 'men', 'tla', 'sá', 'zrak', 'chy', 'vám', 'vi', 'dý', 'rád', 'svou',
                  'ném', 've', 'py', 'vo', 'vým', 'nek', 'již', 'víc', 'kal', 'mé', 'dů', 'stá', 'dnes', 'sty', 'ven', None]
    ENDS = VERSE_ENDS
    # Years to bucket to
    POET_YEARS_BUCKETS = [1800, 1820, 1840, 1860, 1880, 1900, 1920, 1940, 1960, None]
    POET_YEARS = POET_YEARS_BUCKETS
    YEAR = POET_YEARS_BUCKETS
    # Possible Meter Types (one-char codes; None = unknown class)
    METER_TYPES = ["J","T","D","A","X","Y","N","H","P", None]
    METER = METER_TYPES
    # Translation of Meter to one char types
    METER_TRANSLATE = {
        "J":"J",
        "T":"T",
        "D":"D",
        "A":"A",
        "X":"X",
        "Y":"Y",
        "hexameter": "H",
        "pentameter": "P",
        "N":"N"
    }

    # Basic Characters to consider in rhyme and syllables (43)
    VALID_CHARS = [""," ",'a','á','b','c','č','d','ď','e','é','ě',
                   'f','g','h','i','í','j','k','l','m','n','ň',
                   'o','ó','p','q','r','ř','s','š','t','ť','u',
                   'ú','ů','v','w','x','y','ý','z','ž']
    CHARS = VALID_CHARS
|
73 |
+
class Tokens:
    """Special-token strings and their fixed ids shared by the project's
    custom tokenizers."""

    # Tokenizers Special Tokens
    EOS = "<|EOS|>"
    EOS_ID = 0
    PAD = "<|PAD|>"
    PAD_ID = 1
    UNK = "<|UNK|>"
    UNK_ID = 2
    CLS = "<|CLS|>"
    CLS_ID = 3
    # SEP Token is EOS Token
    SEP = EOS
    SEP_ID = 0

    # Mapping of every special token string to its id (mirrors the *_ID
    # constants above; keep the two in sync when adding tokens).
    ALL_TOKENS = {
        EOS : 0,
        PAD : 1,
        UNK : 2,
        CLS : 3,
    }
|
93 |
+
|
94 |
+
|
95 |
+
|
96 |
+
import re
|
97 |
+
import numpy as np
|
98 |
+
|
99 |
+
def parse_boolean(value):
    """Loosely interpret a string as a boolean flag.

    Args:
        value (str): Text such as "true", "Yes", "0", "T", ...

    Returns:
        bool: True for the recognized truthy spellings
        ("true", "yes", "y", "1", "t", case-insensitive);
        False for everything else, including unrecognized input.
    """
    return value.lower() in ("true", "yes", "y", "1", "t")
|
108 |
+
|
109 |
+
class TextManipulation:
    """Static class for string manipulation methods

    Returns:
        _type_: str returned by all methods
    """

    @staticmethod
    def _remove_most_nonchar(raw_text, lower_case=True):
        """Remove most non-alpha non-whitespace characters

        Strips a fixed set of punctuation/typographic marks (dashes, quotes,
        brackets, colons, asterisks, ...) but keeps letters, digits and spaces.

        Args:
            raw_text (str): Text to manipulate
            lower_case (bool, optional): If resulting text should be lowercase. Defaults to True.

        Returns:
            str: Cleaned up text
        """
        text = re.sub(r'[–\„\“\’\;\:()\]\[\_\*\‘\”\'\-\—\"]+', "", raw_text)
        return text.lower() if lower_case else text

    @staticmethod
    def _remove_all_nonchar(raw_text):
        """Remove all possible non-alpha characters

        Drops every non-word, non-whitespace character and all digits.

        Args:
            raw_text (str): Text to manipulate

        Returns:
            str: Cleaned up text
        """
        sub = re.sub(r'([^\w\s]+|[0-9]+)', '', raw_text)
        return sub

    @staticmethod
    def _year_bucketor(raw_year):
        """Bucketizes year string to boundaries, Bad inputs returns NaN string

        Snaps the year to the nearest bucket in StropheParams.YEAR (the
        trailing None is excluded from the distance computation).

        Args:
            raw_year (str): Year string to bucketize

        Returns:
            _type_: Bucketized year string
        """
        if TextAnalysis._is_year(raw_year) and raw_year != "NaN":
            year_index = np.argmin(np.abs(np.asarray(StropheParams.YEAR[:-1]) - int(raw_year)))
            return str(StropheParams.YEAR[year_index])
        else:
            return "NaN"

    # Rhyme letters in order of assignment; index = distance from reference
    _RHYME_POS = ["A", "B", "C", "D", "E", "F", "G", "H"]

    @staticmethod
    def rhyme_sec(rhyme_ref, current_rhyme):
        """Return proper rhyme indicator to given reference

        Args:
            rhyme_ref (_type_): reference number of 'A'
            current_rhyme (_type_): current rhyme number that needs indication

        Returns:
            str: rhyme indicator character ('X' for no/out-of-range rhyme)
        """

        return "X" if current_rhyme == None or current_rhyme== -1 or rhyme_ref == None or current_rhyme < rhyme_ref or current_rhyme >= rhyme_ref + len(TextManipulation._RHYME_POS) else TextManipulation._RHYME_POS[current_rhyme - rhyme_ref]

    @staticmethod
    def __post_process_rhyme(rhyme_str: str):
        """Normalize a raw rhyme schema string.

        Three in-place passes over the schema (order matters):
        1. letters occurring exactly once become 'X' (not a rhyme),
        2. letters are downscaled to the lowest unused letter,
        3. letters are swapped so they first appear in alphabetical order.
        """
        # First Pass
        marker_count = {marker: rhyme_str.count(marker) for marker in TextManipulation._RHYME_POS}
        for key, val in marker_count.items():
            # Replace all, that ocurr only once with X
            if val == 1:
                rhyme_str = re.sub(key, 'X', rhyme_str)
        # Downscale higher to lower if lower not present
        marker_count = {marker: rhyme_str.count(marker) for marker in TextManipulation._RHYME_POS}
        for key, val in marker_count.items():
            if val > 1 and key != 'X':
                key_index = TextManipulation._RHYME_POS.index(key)
                replacements = {marker: rhyme_str.count(marker) for marker in TextManipulation._RHYME_POS[:key_index]}
                for rep_key, rep_val in replacements.items():
                    if rep_val ==0:
                        rhyme_str = re.sub(key, rep_key, rhyme_str)
                        break

        # Pass to swap letters
        marker_index = {marker: rhyme_str.find(marker) for marker in TextManipulation._RHYME_POS if rhyme_str.find(marker) != -1}
        keys_values = marker_index.items()
        keys = [v[0] for v in keys_values]
        values = [v[1] for v in keys_values]

        i = 0
        while i < len(keys):
            j= 0
            while j< len(keys):
                # A later letter appearing before an earlier one -> swap them
                if TextManipulation._RHYME_POS.index(keys[j]) > TextManipulation._RHYME_POS.index(keys[i]) and values[j] < values[i]:
                    # Swap the positions ('Z' is a temporary placeholder,
                    # never a valid rhyme letter here)
                    rhyme_str = re.sub(keys[j], 'Z', rhyme_str)
                    rhyme_str = re.sub(keys[i], keys[j], rhyme_str)
                    rhyme_str = re.sub('Z', keys[i], rhyme_str)
                    # Need to update the value
                    temp = values[i]
                    values[i]= values[j]
                    values[j] = temp
                j+=1
            i+=1


        return rhyme_str


    @staticmethod
    def _rhyme_string(curr_rhyme_list):
        """Translate rhyme as list of rhyming number to rhyme schema

        Args:
            curr_rhyme_list (list): Current rhyme as list of ints indicating rhyming verses

        Returns:
            str: Rhyme schema
        """
        rhyme_list = curr_rhyme_list.copy()
        reference = None
        # Give None a blank -1 rhyme id; track the smallest non-None id as reference
        for i in range(len(rhyme_list)):
            if rhyme_list[i] != None and reference == None:
                reference = rhyme_list[i]
            elif rhyme_list[i] != None and rhyme_list[i] < reference:
                reference = rhyme_list[i]
            elif rhyme_list[i] == None:
                rhyme_list[i] = -1

        # With more robust post processing, this is may not needed

        # if there is valid rhyme, normalize
        if reference != None:
            # sort the rhyme and get index of reference number
            cheat_sheet = sorted(list(set(rhyme_list[:])))
            ref_index = cheat_sheet.index(reference)
            # normalize the rest around this reference
            for i in range(len(rhyme_list)):
                idx = cheat_sheet.index(rhyme_list[i])
                rhyme_list[i] = reference + (idx - ref_index)


        rhyme_str = ""
        for num in rhyme_list:
            rhyme_str += TextManipulation.rhyme_sec(reference, num)

        return TextManipulation.__post_process_rhyme(rhyme_str)
|
259 |
+
|
260 |
+
class TextAnalysis:
    """Static class with methods of analysis of strings

    Returns:
        Union[str, bool, dict, numpy.ndarray]: Analyzed input
    """

    # Possible Keys if returned type is dict
    POET_PARAM_LIST = ["RHYME", "YEAR", "METER", "LENGTH", "END", "TRUE_LENGTH", "TRUE_END"]

    @staticmethod
    def _is_meter(meter:str):
        """Return if string is meter type

        Args:
            meter (str): string to analyze

        Returns:
            bool: If string is meter type (the trailing None class is excluded)
        """
        return meter in StropheParams.METER[:-1]

    @staticmethod
    def _is_year(year:str):
        """Return if string is year or special NaN

        Accepts decimal strings strictly between 1000 and 10000, or "NaN".

        Args:
            year (str): string to analyze

        Returns:
            bool: If string is year or special NaN
        """
        return (year.isdecimal() and int(year) > 1_000 and int(year) < 10_000) or year == "NaN"

    @staticmethod
    def _rhyme_like(rhyme:str):
        """Return if string is structured like rhyme schema

        Heuristic only: any all-uppercase string of length 3-6 passes.

        Args:
            rhyme (str): string to analyze

        Returns:
            bool: If string is structured like rhyme schema
        """
        return (rhyme.isupper() and len(rhyme) >= 3 and len(rhyme) <= 6)

    @staticmethod
    def _rhyme_vector(rhyme:str) -> np.ndarray:
        """Create One-hot encoded rhyme schema vector from given string

        Unknown schemas fall into the last ("other"/None) slot.

        Args:
            rhyme (str): string to construct vector from

        Returns:
            numpy.ndarray: One-hot encoded rhyme schema vector
        """

        rhyme_vec = np.zeros(len(StropheParams.RHYME))
        if rhyme in StropheParams.RHYME:
            rhyme_vec[StropheParams.RHYME.index(rhyme)] = 1
        else:
            rhyme_vec[-1] = 1

        return rhyme_vec


    @staticmethod
    def _publish_year_vector(year_string):
        """Construct vector of year of publishing, weighting by distance

        Args:
            year_string (str): String with publish year

        Returns:
            numpy.ndarray: Vector of bucketized One-hot encoded publish year
        """
        publish_year = None if not year_string.isdigit() else int(year_string)
        publish_vector = np.zeros(len(StropheParams.YEAR))
        if publish_year == None:
            publish_vector[-1] = 1
        else:
            # Distance Part (kept for reference; currently disabled)
            #distance_weighting = [1/(1 + abs(year - publish_year)) for year in POET_YEARS_BUCKETS[:-1]] + [0]
            #publish_vector = np.asarray(distance_weighting)
            # Correct class correction: one-hot at the nearest year bucket
            publish_vector[np.argmin( abs(np.asarray(StropheParams.YEAR[:-1]) - publish_year))] += 1
            # Normalize (disabled together with distance weighting)
            #publish_vector = publish_vector/np.sum(publish_vector)
        return publish_vector

    @staticmethod
    def _rhyme_or_not(rhyme_str:str) -> np.ndarray:
        """Create vector if given rhyme string is in our list of rhyme schemas

        Args:
            rhyme_str (str): string to construct vector from

        Returns:
            numpy.ndarray: Boolean flag vector ([1,0] = known schema, [0,1] = unknown)
        """
        rhyme_vector = np.zeros(2)
        if rhyme_str in StropheParams.RHYME:
            rhyme_vector[0] = 1
        else:
            rhyme_vector[1] = 1
        return rhyme_vector

    @staticmethod
    def _metre_vector(metre: str) -> np.ndarray:
        """Create One-hot encoded metre vector from given string

        Unknown meters fall into the last ("other"/None) slot.

        Args:
            metre (str): string to construct vector from

        Returns:
            numpy.ndarray: One-hot encoded metre vector
        """
        metre_vec = np.zeros(len(StropheParams.METER))
        if metre in StropheParams.METER:
            metre_vec[StropheParams.METER.index(metre)] = 1
        else:
            metre_vec[-1] = 1
        return metre_vec

    @staticmethod
    def _first_line_analysis(text:str):
        """Analysis of parameter line for RHYME, METER, YEAR

        Args:
            text (str): parameter line string

        Returns:
            dict: Dictionary with analysis result (subset of
            {"YEAR", "RHYME", "STROPHE_METER"}; empty for blank input)
        """
        line_striped = text.strip()
        if not line_striped:
            return {}
        poet_params = {}
        # Look for each possible parameter
        for param in line_striped.split():
            if TextAnalysis._is_year(param):
                # Year is Bucketized so to fit
                poet_params["YEAR"] = TextManipulation._year_bucketor(param)
            elif TextAnalysis._rhyme_like(param):
                poet_params["RHYME"] = param
            elif TextAnalysis._is_meter(param):
                poet_params["STROPHE_METER"] = param
        return poet_params

    @staticmethod
    def _is_line_length(length:str):
        """Return if string is number of syllables parameter

        Args:
            length (str): string to analyze

        Returns:
            bool: If string is number of syllables parameter (2..99)
        """
        return length.isdigit() and int(length) > 1 and int(length) < 100

    @staticmethod
    def _is_line_end(end:str):
        """Return if string is valid ending syllable/sequence parameter

        Args:
            end (str): string to analyze

        Returns:
            bool: If string is valid ending syllable/sequence parameter
            (lowercase alphabetic, at most 5 chars)
        """
        return end.isalpha() and end.islower() and len(end) <= 5

    @staticmethod
    def _continuos_line_analysis(text:str):
        """Analysis of Content lines for LENGTH, TRUE_LENGTH, END, TRUE_END

        Lines may carry '#'-separated parameter groups before the verse text;
        the last '#'-segment is the verse itself.

        Args:
            text (str): content line to analyze

        Returns:
            dict: Dictionary with analysis result
        """
        # Strip line of most separators and look if its empty
        line_striped = TextManipulation._remove_most_nonchar(text, lower_case=False).strip()
        if not line_striped:
            return {}
        line_params = {}
        # OLD MODEL
        if text.count('#') == 0: # BASIC format: no declared parameters on the line
            pass
        else:
            # Parameter groups precede the verse text
            for param_group in text.split('#')[:-1]:
                for param in param_group.split():
                    if TextAnalysis._is_meter(param.strip()):
                        line_params["METER"] = param.strip()
                    elif TextAnalysis._is_line_length(param.strip()):
                        line_params["LENGTH"] = int(param.strip())
                    elif TextAnalysis._is_line_end(param.strip()):
                        line_params["END"] = param.strip()


        # Measured values from the verse text itself
        line_params["TRUE_LENGTH"] = len(SyllableMaker.syllabify(line_striped.split('#')[-1]))
        line_only_char = TextManipulation._remove_all_nonchar(line_striped).strip()
        if len(line_only_char) > 2:
            # Last syllable of the final (up to two) words
            line_params["TRUE_END"] = SyllableMaker.syllabify(" ".join(line_only_char.split()[-2:]))[-1]

        return line_params

    @staticmethod
    def _is_param_line(text:str):
        """Return if line is a Parameter line (Parameters RHYME, METER, YEAR)

        Args:
            text (str): line to analyze

        Returns:
            bool: If line is a Parameter line
        """
        line_striped = text.strip()
        if not line_striped:
            return False
        small_analysis = TextAnalysis._first_line_analysis(line_striped)
        return "RHYME" in small_analysis.keys() or "YEAR" in small_analysis.keys()
|
484 |
+
|
485 |
+
class SyllableMaker:
    """Static class with methods for separating string to list of Syllables

    Works on Czech (plus a few German vowel) characters: each word is turned
    into a V/K/0 mask (vowel / consonant / "glued to previous char") and the
    mask is split into syllables.

    Returns:
        list: List of syllables
    """


    # NON-Original code!
    # Taken from Barbora Štěpánková

    @staticmethod
    def syllabify(text : str) -> list[str]:
        """Split *text* into a flat list of syllables (word boundaries are
        not preserved in the output)."""
        words = re.findall(r"[aábcčdďeéěfghiíjklmnňoópqrřsštťuúůvwxyýzžAÁBCČDĎEÉĚFGHIÍJKLMNŇOÓPQRŘSŠTŤUÚŮVWXYÝZŽäöüÄÜÖ]+", text)
        syllables : list[str] = []

        i = 0
        while i < len(words):
            word = words[i]

            # Single-letter prepositions (k, v, s, z) are merged with the
            # following word, as they form one phonetic unit in Czech.
            if (word.lower() == "k" or word.lower() == "v" or word.lower() == "s" or word.lower() == "z") and i < len(words) - 1 and len(words[i + 1]) > 1:
                i += 1
                word = word + words[i]

            letter_counter = 0

            # Get syllables: mask the word and split the mask
            for syllable_mask in SyllableMaker.__split_mask(SyllableMaker.__create_word_mask(word)):
                word_syllable = ""
                # The mask has one symbol per letter, so copy as many
                # letters as the mask segment is long.
                for character in syllable_mask:
                    word_syllable += word[letter_counter]
                    letter_counter += 1

                syllables.append(word_syllable)

            i += 1

        return syllables


    @staticmethod
    def __create_word_mask(word : str) -> str:
        """Encode *word* as a mask: 'V' = vowel (syllable nucleus),
        'K' = consonant, '0' = letter glued to the previous symbol.
        Replacement order below is significant."""
        word = word.lower()

        vocals = r"[aeiyouáéěíýóůúäöü]"
        consonants = r"[bcčdďfghjklmnňpqrřsštťvwxzž]"

        replacements = [
            #double letters
            ('ch', 'c0'),
            ('rr', 'r0'),
            ('ll', 'l0'),
            ('nn', 'n0'),
            ('th', 't0'),

            # au, ou, ai, oi
            (r'[ao]u', '0V'),
            (r'[ao]i','0V'),

            # eu at the beginning of the word
            (r'^eu', '0V'),

            # now all vocals
            (vocals, 'V'),

            # r,l that act like vocals in syllables
            (r'([^V])([rl])(0*[^0Vrl]|$)', r'\1V\3'),

            # sp, st, sk, št, Cř, Cl, Cr, Cv
            (r's[pt]', 's0'),
            (r'([^V0lr]0*)[řlrv]', r'\g<1>0'),
            (r'([^V0]0*)sk', r'\1s0'),
            (r'([^V0]0*)št', r'\1š0'),

            (consonants, 'K')
        ]

        for (original, replacement) in replacements:
            word = re.sub(original, replacement, word)

        return word


    @staticmethod
    def __split_mask(mask : str) -> list[str]:
        """Insert '/' syllable separators into a V/K/0 mask and return the
        mask segments. Replacement order below is significant."""
        replacements = [
            # vocal at the beginning
            (r'(^0*V)(K0*V)', r'\1/\2'),
            (r'(^0*V0*K0*)K', r'\1/K'),

            # dividing the middle of the word
            (r'(K0*V(K0*$)?)', r'\1/'),
            (r'/(K0*)K', r'\1/K'),
            (r'/(0*V)(0*K0*V)', r'/\1/\2'),
            (r'/(0*V0*K0*)K', r'/\1/K'),

            # add the last consonant to the previous syllable
            (r'/(K0*)$', r'\1/')
        ]

        for (original, replacement) in replacements:
            mask = re.sub(original, replacement, mask)

        # Drop a dangling trailing separator before splitting
        if len(mask) > 0 and mask[-1] == "/":
            mask = mask[0:-1]

        return mask.split("/")
|
utils/validators.py
ADDED
@@ -0,0 +1,359 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import transformers
|
3 |
+
import jellyfish
|
4 |
+
from tqdm import tqdm
|
5 |
+
from transformers import AutoModelForMaskedLM
|
6 |
+
from transformers.utils import ModelOutput
|
7 |
+
import numpy as np
|
8 |
+
from .poet_utils import StropheParams
|
9 |
+
|
10 |
+
from torch.utils.data import DataLoader, Dataset
|
11 |
+
from pytorch_optimizer import SAM
|
12 |
+
|
13 |
+
class ValidatorInterface(torch.nn.Module):
    """Abstract base for all validator models.

    Subclasses torch.nn.Module so concrete validators plug directly into
    torch training loops and huggingface tooling. Every public method here
    is abstract and must be overridden.
    """

    def __init__(self, *args, **kwargs) -> None:
        """Delegate construction to torch.nn.Module (required by subclasses)."""
        super().__init__(*args, **kwargs)

    def forward(self, input_ids=None, attention_mask=None, *args, **kwargs):
        """Compute model output and model loss.

        Args:
            input_ids: Model inputs. Defaults to None.
            attention_mask: Attention mask where padding starts. Defaults to None.

        Raises:
            NotImplementedError: Always — abstract method.
        """
        raise NotImplementedError()

    def predict_state(self, input_ids=None, *args, **kwargs):
        """Compute model outputs (no loss).

        Args:
            input_ids: Model inputs. Defaults to None.

        Raises:
            NotImplementedError: Always — abstract method.
        """
        raise NotImplementedError()

    def validate_model(self, input_ids=None, *args, **kwargs):
        """Validate model given some labels; does not use loss.

        Args:
            input_ids: Model inputs. Defaults to None.

        Raises:
            NotImplementedError: Always — abstract method.
        """
        raise NotImplementedError()
|
57 |
+
|
58 |
+
|
59 |
+
class RhymeValidator(ValidatorInterface):
    """Validator that classifies the rhyme schema of a strophe using a
    masked-LM backbone with a linear head over the [CLS] position."""

    def __init__(self, pretrained_model, *args, **kwargs) -> None:
        """Load the masked-LM backbone and build the rhyme classification head.

        Args:
            pretrained_model: Name or path passed to AutoModelForMaskedLM.from_pretrained.
        """
        super().__init__(*args, **kwargs)

        # Backbone must expose hidden states; last layer feeds the head.
        self.model = AutoModelForMaskedLM.from_pretrained(pretrained_model, output_hidden_states=True)

        self.config = self.model.config

        self.model_size = self.config.hidden_size

        self.rhyme_regressor = torch.nn.Linear(self.model_size, len(StropheParams.RHYME)) # Common Rhyme Type

        # Class weights counteract the skewed rhyme-schema distribution
        # (rarer schemas get larger weights); one weight per entry of
        # StropheParams.RHYME, the last one for the None/"other" class.
        self.loss_fnc = torch.nn.CrossEntropyLoss(label_smoothing=0.0, weight=torch.tensor([1, 1, 1.5, 1.5, 1.5, 1.5,
                                                                                            2, 2, 2, 3, 3, 3,
                                                                                            3, 3, 3, 3, 4, 4,
                                                                                            5, 5, 5, 5, 7, 7,
                                                                                            7, 7, 7, 8, 8, 8,
                                                                                            9, 9, 9, 10, 10, 10,
                                                                                            12,12, 12, 12, 12, 12,
                                                                                            15,15,1.5]) )

    def forward(self, input_ids=None, attention_mask=None, rhyme=None, *args, **kwargs):
        """Training pass: combined MLM loss + weighted rhyme-classification loss.

        Args:
            input_ids: Token ids (also used as MLM labels).
            attention_mask: Padding mask.
            rhyme: Rhyme class targets for CrossEntropyLoss.

        Returns:
            ModelOutput: loss (MLM + rhyme) and model_output (softmax scores).
        """
        # NOTE(review): .type(torch.LongTensor) moves the labels to CPU —
        # verify this is intended when training on GPU.
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids.type(torch.LongTensor))

        last_hidden = outputs['hidden_states'][-1]

        # Classify from the first ([CLS]) position of the last hidden layer
        rhyme_regression = self.rhyme_regressor((last_hidden[:,0,:].view(-1, self.model_size)))

        softmaxed = torch.softmax(rhyme_regression, dim=1)
        # NOTE(review): CrossEntropyLoss expects raw logits and applies
        # log-softmax internally; feeding softmax output double-normalizes.
        # Kept as-is for consistency with already-trained validator weights.
        rhyme_loss = self.loss_fnc(softmaxed, rhyme)

        return ModelOutput(loss=rhyme_loss + outputs.loss, model_output=softmaxed)

    def predict_state(self, input_ids=None, *args, **kwargs):
        """Inference pass: softmax distribution over rhyme schemas.

        Args:
            input_ids: Token ids.

        Returns:
            torch.Tensor: Softmax scores, shape (batch, len(StropheParams.RHYME)).
        """

        outputs = self.model(input_ids=input_ids)

        last_hidden = outputs['hidden_states'][-1]

        rhyme_regression = self.rhyme_regressor((last_hidden[:,0,:].view(-1, self.model_size)))

        softmaxed = torch.softmax(rhyme_regression, dim=1)

        return softmaxed

    def validate_model(self, input_ids=None, rhyme=None, k:int = 2,*args, **kwargs):
        """Evaluate one example (batch of 1 — result is flattened).

        Args:
            input_ids: Token ids.
            rhyme: One-hot target vector.
            k (int): Cutoff for the top-k hit metric. Defaults to 2.

        Returns:
            dict: "acc" (0/1), "top_k" (0/1), "lev_distance" (edit distance
            between predicted and true schema strings), "predicted_label"
            (model probability of the true class).
        """
        outputs = self.model(input_ids=input_ids)

        last_hidden = outputs['hidden_states'][-1]

        rhyme_regression = self.rhyme_regressor((last_hidden[:,0,:].view(-1, self.model_size)))

        softmaxed = torch.softmax(rhyme_regression, dim=1)

        # Flatten assumes batch size 1; move to CPU for numpy conversion below
        softmaxed = softmaxed.flatten().cpu()

        predicted_val = torch.argmax(softmaxed)

        predicted_top_k = torch.topk(softmaxed, k).indices

        label_val = torch.argmax(rhyme.flatten())

        validation_true_val = (label_val == predicted_val).float().sum().numpy()
        top_k_presence = 0
        if label_val in predicted_top_k:
            top_k_presence = 1

        # The None ("other") class is compared as the empty string
        levenshtein = jellyfish.levenshtein_distance(StropheParams.RHYME[predicted_val] if StropheParams.RHYME[predicted_val] != None else "", StropheParams.RHYME[label_val] if StropheParams.RHYME[label_val] != None else "")

        hit_pred = softmaxed[label_val].detach().numpy()

        return {"acc" : validation_true_val,
                "top_k" : top_k_presence,
                "lev_distance": levenshtein,
                "predicted_label" : hit_pred
                }
|
135 |
+
|
136 |
+
|
137 |
+
|
138 |
+
class MeterValidator(ValidatorInterface):
    """Classifies the poetic metre of an input sequence.

    A masked-LM backbone provides contextual embeddings; the first-token
    (CLS) vector of its last hidden layer feeds a linear metre classifier.
    """

    def __init__(self, pretrained_model, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Backbone MLM with hidden states exposed so the CLS vector is readable.
        self.model = AutoModelForMaskedLM.from_pretrained(pretrained_model, output_hidden_states=True)

        self.config = self.model.config
        self.model_size = self.config.hidden_size

        # Linear head over the CLS embedding, one logit per metre class.
        self.meter_regressor = torch.nn.Linear(self.model_size, len(StropheParams.METER))  # Meter Type
        # Hand-tuned class weights against metre imbalance; the final class
        # (weight 0) is excluded from the loss.
        self.loss_fnc = torch.nn.CrossEntropyLoss(label_smoothing=0.0, weight=torch.tensor([1, 1.5, 5, 10, 10, 20, 5, 20, 20, 0]))

    def forward(self, input_ids=None, attention_mask=None, metre_ids=None, *args, **kwargs):
        """Forward pass returning combined metre-classification + MLM loss."""
        # Reusing input_ids as MLM labels adds a language-modelling loss term.
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids.type(torch.LongTensor))

        cls_vector = outputs['hidden_states'][-1][:, 0, :]
        logits = self.meter_regressor(cls_vector.view(-1, self.model_size))
        probs = torch.softmax(logits, dim=1)
        # NOTE(review): CrossEntropyLoss normally expects raw logits, not
        # softmaxed probabilities; kept as-is to match the trained checkpoints.
        meter_loss = self.loss_fnc(probs, metre_ids)

        return ModelOutput(loss=meter_loss + outputs.loss, model_output=probs)

    def predict_state(self, input_ids=None, *args, **kwargs):
        """Return the softmaxed metre distribution for `input_ids`."""
        cls_vector = self.model(input_ids=input_ids)['hidden_states'][-1][:, 0, :]
        logits = self.meter_regressor(cls_vector.view(-1, self.model_size))
        return torch.softmax(logits, dim=1)

    def validate_model(self, input_ids=None, metre_ids=None, attention_mask=None, k: int=2,*args, **kwargs):
        """Score one example: accuracy, top-k hit and gold-label probability."""
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask )
        cls_vector = outputs['hidden_states'][-1][:, 0, :]
        logits = self.meter_regressor(cls_vector.view(-1, self.model_size))
        probs = torch.softmax(logits, dim=1).flatten().cpu()

        prediction = torch.argmax(probs)
        top_k_indices = torch.topk(probs, k).indices
        gold = torch.argmax(metre_ids.flatten())

        acc = (gold == prediction).float().sum().numpy()
        in_top_k = 1 if gold in top_k_indices else 0
        gold_prob = probs[gold].detach().numpy()

        return {"acc" : acc,
                "top_k" : in_top_k,
                "predicted_label" : gold_prob
                }
class YearValidator(ValidatorInterface):
    """Estimates a poem's publication year.

    Combines a scalar year regressor (L1 loss) with a year-era classifier
    (weighted cross-entropy), both reading the CLS embedding of a
    masked-LM backbone.
    """

    def __init__(self, pretrained_model, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.model = AutoModelForMaskedLM.from_pretrained(pretrained_model, output_hidden_states=True)

        self.config = self.model.config
        self.model_size = self.config.hidden_size

        # Era classification head: one logit per year bucket.
        self.year_era = torch.nn.Linear(self.model_size, len(StropheParams.YEAR))
        self.softmax = torch.nn.Softmax(dim=-1)

        # Scalar regression head predicting the year directly.
        self.year_val = torch.nn.Linear(self.model_size, 1)  # Year Value

        # Weighted CE against bucket imbalance; last bucket (weight 0) ignored.
        self.loss_fnc_era = torch.nn.CrossEntropyLoss(label_smoothing=0.0,weight=torch.tensor([10, 5, 3, 3, 1, 1, 1.5, 2, 5, 0]))
        self.loss_fnc_val = torch.nn.L1Loss()

    def forward(self, input_ids=None, attention_mask=None, year_bucket=None, year=None, *args, **kwargs):
        """Forward pass returning year-regression + era + MLM loss."""
        # MLM pass with input_ids as labels keeps a language-model loss term.
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids.type(torch.LongTensor))
        cls_vector = outputs['hidden_states'][-1][:, 0, :].view(-1, self.model_size)

        predicted_year = self.year_val(cls_vector)
        year_val_loss = self.loss_fnc_val(predicted_year, year)

        era_probs = self.softmax(self.year_era(cls_vector))
        # NOTE(review): CE over softmaxed probabilities double-softmaxes;
        # left unchanged to stay faithful to the trained checkpoints.
        year_era_loss = self.loss_fnc_era(era_probs, year_bucket)

        return ModelOutput(loss=year_val_loss + year_era_loss + outputs.loss, model_output=(predicted_year, era_probs))

    def predict_state(self, input_ids=None, *args, **kwargs):
        """Return the raw scalar year prediction for `input_ids`."""
        cls_vector = self.model(input_ids=input_ids)['hidden_states'][-1][:, 0, :]
        return self.year_val(cls_vector.view(-1, self.model_size))

    def validate_model(self, input_ids=None, year_bucket=None, k: int=2,*args, **kwargs):
        """Score one example against its gold year bucket.

        Builds a bucket distribution from the regressed year
        (closeness-weighted), mixes in the era classifier's probabilities,
        and reports accuracy, top-k hit, gold-label probability and the
        bucket distance of the prediction.
        """
        hidden = self.model(input_ids=input_ids)['hidden_states'][-1]
        cls_vector = hidden[:, 0, :].view(-1, self.model_size)

        year_val = self.year_val(cls_vector)
        # These guards are always true for this class (year_era is set in
        # __init__); presumably they allow checkpoints without the era head.
        if hasattr(self, 'year_era'):
            year_era = self.softmax(self.year_era(cls_vector))

        year_val = year_val.detach().flatten().cpu().numpy()
        if hasattr(self, 'year_era'):
            year_era = year_era.detach().flatten().cpu().numpy()

        # Closeness of the regressed year to each bucket's reference year;
        # the final catch-all bucket gets probability 0.
        publish_vector = [1/(1 + abs(year - year_val[0])) for year in StropheParams.YEAR[:-1]] + [0]
        publish_vector = np.asarray(publish_vector)/np.sum(publish_vector)
        # Adding era prediction
        if hasattr(self, 'year_era'):
            publish_vector+= year_era
        publish_vector = torch.tensor( np.asarray(publish_vector)/np.sum(publish_vector))

        prediction = torch.argmax(publish_vector)
        top_k_indices = torch.topk(publish_vector, k).indices
        gold = torch.argmax(year_bucket.flatten())

        acc = (gold == prediction).float().sum().numpy()
        in_top_k = 1 if gold in top_k_indices else 0
        gold_prob = publish_vector[gold].detach().numpy()
        distance = abs(gold.numpy() - prediction.numpy())

        return {"acc" : acc,
                "top_k" : in_top_k,
                "predicted_label" : gold_prob,
                "distance" : distance
                }
class ValidatorTrainer:
    """Minimal training loop for ValidatorInterface models using SAM.

    SAM (sharpness-aware minimization) needs two forward/backward passes
    per batch, which rules out the stock HuggingFace Trainer; hence this
    hand-rolled loop.  (An earlier GSAM variant was removed; see history.)
    """

    def __init__(self, model: ValidatorInterface, args: dict, train_dataset: Dataset, data_collator, device):
        """Build the loop from a model, a hyperparameter dict and a dataset.

        Recognised ``args`` keys (all optional): ``epochs`` (default 1),
        ``batch_size`` (default 1), ``lr`` (default 5e-5),
        ``weight_decay`` (default 0.0).
        """
        self.model = model
        self.args = args
        # dict.get replaces the verbose "'k' not in args.keys()" ternaries.
        self.epochs = args.get("epochs", 1)
        self.batch_size = args.get("batch_size", 1)
        self.lr = args.get("lr", 5e-5)
        self.weight_decay = args.get("weight_decay", 0.0)

        self.train_loader = DataLoader(train_dataset, self.batch_size, shuffle=True, collate_fn=data_collator)

        # SAM wraps AdamW; constant LR after a warmup of 4 epochs' worth of steps.
        self.device = device
        self.optimizer = SAM(self.model.parameters(), torch.optim.AdamW, lr=self.lr, weight_decay=self.weight_decay)
        self.scheduler = transformers.get_constant_schedule_with_warmup(self.optimizer, 4 * len(train_dataset)//self.batch_size)

    def _batch_loss(self, batch):
        """Run one forward pass and return the scalar loss for ``batch``.

        Label entries that are absent (None) are forwarded as None, so each
        validator type only consumes the labels it needs.  Uses ``is None``
        rather than ``== None``: equality on tensors is elementwise and a
        latent hazard.
        """
        def to_device(key):
            value = batch[key]
            return None if value is None else value.to(self.device)

        return self.model(input_ids=batch["input_ids"].to(self.device),
                          attention_mask=batch["attention_mask"].to(self.device),
                          rhyme=to_device("rhyme"),
                          metre_ids=to_device("metre_ids"),
                          year_bucket=to_device("year_bucket"),
                          year=to_device("year"))['loss']

    def train(self):
        """Train for ``self.epochs`` epochs with SAM's two-step update."""
        for epoch in tqdm(range(self.epochs)):
            self.model.train()

            for step, batch in enumerate(self.train_loader):
                # First pass: gradients at the current weights, then ascend
                # to the local worst-case perturbation.
                loss = self._batch_loss(batch)
                loss.backward()
                self.optimizer.first_step(zero_grad=True)
                # Second pass: gradients at the perturbed weights, then descend.
                loss = self._batch_loss(batch)
                loss.backward()
                self.optimizer.second_step(zero_grad=True)
                self.scheduler.step()

                if step % 100 == 0:
                    print(f'Step {len(self.train_loader) * epoch + step}, loss : {loss.item()}', flush=True)
utils/validators/meter/ufal-robeczech-base_BPE_validator_1704126400265
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:d83f2b8f9b00db0945584e3bcbce96f971cfc572cb8665ff713c6d3cc67854d4
|
3 |
+
size 504173324
|
utils/validators/rhyme/distilroberta-base_BPE_validator_1704126399565
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ceb77ef356a5e5ce3d59a6b2d31b96c925af09e29b4731c143ebabdaf3401c65
|
3 |
+
size 328898329
|
utils/validators/year/ufal-robeczech-base_BPE_validator_1702393305267
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:4695ae160b8236b89c467fb50318c6cb429ae6152f9332f74ddcaff5cbe23da1
|
3 |
+
size 504177816
|