jinymusim commited on
Commit
13cea7c
·
verified ·
1 Parent(s): edd506f
.gitattributes CHANGED
@@ -36,3 +36,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
36
  validators/meter/ufal-robeczech-base_syllable_BPE_validator_1702489033354 filter=lfs diff=lfs merge=lfs -text
37
  validators/rhyme/distilroberta-base_syllable_BPE_validator_1702665903087 filter=lfs diff=lfs merge=lfs -text
38
  validators/year/ufal-robeczech-base_BPE_validator_1702393305267 filter=lfs diff=lfs merge=lfs -text
 
 
 
 
36
  validators/meter/ufal-robeczech-base_syllable_BPE_validator_1702489033354 filter=lfs diff=lfs merge=lfs -text
37
  validators/rhyme/distilroberta-base_syllable_BPE_validator_1702665903087 filter=lfs diff=lfs merge=lfs -text
38
  validators/year/ufal-robeczech-base_BPE_validator_1702393305267 filter=lfs diff=lfs merge=lfs -text
39
+ utils/validators/meter/ufal-robeczech-base_BPE_validator_1704126400265 filter=lfs diff=lfs merge=lfs -text
40
+ utils/validators/rhyme/distilroberta-base_BPE_validator_1704126399565 filter=lfs diff=lfs merge=lfs -text
41
+ utils/validators/year/ufal-robeczech-base_BPE_validator_1702393305267 filter=lfs diff=lfs merge=lfs -text
corpus_capsulated_datasets.py ADDED
@@ -0,0 +1,754 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import numpy as np
4
+ import torch
5
+
6
+ from utils.poet_utils import StropheParams, SyllableMaker, TextAnalysis, TextManipulation
7
+ from torch.utils.data import Dataset
8
+ from transformers import PreTrainedTokenizerBase, PreTrainedModel
9
+ #TODO: Maybe replace year of book being written for year Author was born
10
+ class CorpusDatasetPytorch:
11
+ """Dataset class responsible for data loading.
12
+ """
13
+
14
class RawDataset:
    """Dataset distributing raw string data with no preprocessing.

    Iterates corpus JSON files and yields plain text at three granularities:
    verse line (`get_text`), strophe (`get_part`) and whole poem (`get_body`).
    """

    def __init__(self, data_file_paths, lower_case: bool = True):
        """Construct the frame around raw data generation.

        Args:
            data_file_paths: list of paths to corpus JSON files.
            lower_case (bool, optional): if resulting data should be lowercased. Defaults to True.
        """
        self._data_file_paths = data_file_paths
        self.lower_case = lower_case

    def gen_files(self):
        """Yield open file objects, one per data file.

        Yields:
            file object: open file (consumers close it via `with`).
        """
        for filename in self._data_file_paths:
            yield open(filename, 'r')

    def _gen_parsed(self):
        """Yield parsed JSON content of each file, closing handles promptly.

        Yields:
            list: parsed poem records of one file.
        """
        for step, file in enumerate(self.gen_files()):
            if step % 500 == 0:
                print(f"Processing file {step}")
            # `with` closes the handle right after parsing (fixes fd leak).
            with file:
                yield json.load(file)

    def get_text(self):
        """Get lines of text of poetry.

        Yields:
            str: individual verse line.
        """
        for datum in self._gen_parsed():
            for data_line in datum:
                for part_line in data_line['body']:
                    for text_line in part_line:
                        yield text_line['text'].lower() if self.lower_case else text_line['text']

    def get_part(self):
        """Get strophe of poetry.

        Yields:
            str: one strophe of poetry (verses joined by newlines).
        """
        for datum in self._gen_parsed():
            for data_line in datum:
                for part_line in data_line['body']:
                    part = [text_line['text'] for text_line in part_line]
                    joined = "\n".join(part)
                    yield joined.lower() if self.lower_case else joined

    def get_body(self):
        """Get whole poem.

        Yields:
            str: one whole poem, strophes separated by a blank line.
        """
        for datum in self._gen_parsed():
            for data_line in datum:
                body = []
                for part_line in data_line['body']:
                    for text_line in part_line:
                        body.append(text_line['text'])
                    # Blank entry marks the strophe boundary in the join below.
                    body.append("\n")
                joined = "\n".join(body)
                yield joined.lower() if self.lower_case else joined
86
+
87
class TextDataset(Dataset):
    """Dataset of preprocessed verse lines.

    Child of torch `Dataset` for integration with torch and huggingface.
    Each record holds a raw and a syllabified rendering of one verse plus
    syllable count, ending syllable and metre.
    """

    def __init__(self, data_file_paths, prompt_length=True, prompt_ending=True, lower_case=True, val_data_rate: float = 0.05, test_data_rate: float = 0.05):
        """Construct the class for given data file paths and store parameters.

        Args:
            data_file_paths: list of paths to corpus JSON files.
            prompt_length (bool, optional): If to prompt the syllable count. Defaults to True.
            prompt_ending (bool, optional): If to prompt verse ending. Defaults to True.
            lower_case (bool, optional): If the string should be lowercased. Defaults to True.
            val_data_rate (float, optional): Fraction of data held out for validation. Defaults to 0.05.
            test_data_rate (float, optional): Fraction of data held out for testing. Defaults to 0.05.
        """
        self._data_file_paths = data_file_paths
        self.prompt_length = prompt_length
        self.prompt_ending = prompt_ending
        self.lower_case = lower_case

        self.val_data_rate = val_data_rate
        self.test_data_rate = test_data_rate

        # Split containers filled by data_text_line_gen().
        self.data = []
        self.validation_data = []
        self.test_data = []

    def gen_files(self):
        """Yield open file objects, one per data file.

        Yields:
            file object: open file (closed by the consumer).
        """
        for filename in self._data_file_paths:
            yield open(filename, 'r')

    @staticmethod
    def _vowels_and_endings(raw_text):
        """Get the number of syllables in a verse and its ending syllable.

        Args:
            raw_text (str): raw verse to analyze (must syllabify non-empty).

        Returns:
            tuple: (number of syllables, ending syllable)
        """
        syllabs = SyllableMaker.syllabify(raw_text)
        vowels = len(syllabs)  # INFO: counts syllables, not vowels
        ending = syllabs[-1]
        return vowels, ending

    @staticmethod
    def _ending_vector(end):
        """Construct a one-hot encoded vector for the ending syllable.

        Unknown endings map to the last (catch-all) slot.

        Args:
            end (str): ending syllable.

        Returns:
            numpy.ndarray: one-hot encoded vector.
        """
        verse_end_vector = np.zeros(len(StropheParams.ENDS))
        if end in StropheParams.ENDS[:-1]:
            verse_end_vector[StropheParams.ENDS.index(end)] = 1
        else:
            verse_end_vector[-1] = 1
        return verse_end_vector

    @staticmethod
    def _syllable_line(raw_text):
        """Render a verse as a space-separated sequence of syllables.

        Args:
            raw_text (str): raw verse line (non-empty).

        Returns:
            str: verse as syllables, trailing punctuation preserved.
        """
        ending = raw_text[-1] if raw_text[-1] in [',', '.', '!', '?'] else ''
        return " ".join(SyllableMaker.syllabify(raw_text)) + ending

    def _construct_line(self, raw_text, metre):
        """Construct an individual content line with its parameter prefix.

        Args:
            raw_text (str): raw verse line.
            metre (str): metre tag prefixed to the line.

        Returns:
            str: "METRE # [COUNT # ][END # ]raw_text"
        """
        syllables = SyllableMaker.syllabify(raw_text)
        num_str = f"{len(syllables)} # " if self.prompt_length else ""
        verse_end = f"{syllables[-1]} # " if self.prompt_ending else ""
        metre_txt = f"{metre} # "
        return metre_txt + num_str + verse_end + raw_text

    def _introduce_phonetics(self, raw_text: str, phonetics):
        """Replace words in the text by their phonetic transcription.

        NOTE(review): assumes `phonetics['words']` entries carry
        'token'/'token_lc' and a 'phoebe' transcription — confirm schema.
        """
        phonetic_text = raw_text
        for word in phonetics['words']:
            if self.lower_case:
                phonetic_text = phonetic_text.replace(f'{word["token_lc"]}', f'{word["phoebe"]}')
            else:
                phonetic_text = phonetic_text.replace(f'{word["token"]}', f'{word["phoebe"]}')
        return phonetic_text

    def _construct_syllable_line(self, raw_text, metre):
        """Construct an individual content line as a syllable sequence.

        Args:
            raw_text (str): raw verse line.
            metre (str): metre tag prefixed to the line.

        Returns:
            str: parameter prefix + syllabified verse.
        """
        ending = raw_text[-1] if raw_text[-1] in [',', '.', '!', '?'] else ''
        syllables = SyllableMaker.syllabify(raw_text)
        num_str = f"{len(syllables)} # " if self.prompt_length else ""
        verse_end = f"{syllables[-1]} # " if self.prompt_ending else ""
        metre_txt = f"{metre} # "
        return metre_txt + num_str + verse_end + " ".join(syllables) + ending

    def data_text_line_gen(self):
        """Preprocess corpus files into train/validation/test verse records."""
        for step, file in enumerate(self.gen_files()):
            if step % 500 == 0:
                print(f"Processing file {step}")
            # Close the handle right after parsing (fixes fd leak).
            with file:
                datum = json.load(file)
            for data_line in datum:
                for part_line in data_line['body']:
                    for text_line in part_line:
                        # NOTE(review): default "N" differs from BodyDataset's "J" — confirm intended.
                        metre = StropheParams.METER_TRANSLATE.get(text_line["metre"][0]["type"], "N")

                        scanned_text = TextManipulation._remove_most_nonchar(text_line['text'], self.lower_case)

                        text_line_scanned = self._construct_line(scanned_text, metre)
                        syllable_line = self._construct_syllable_line(scanned_text, metre)

                        num_vowels, verse_end = self._vowels_and_endings(scanned_text)

                        record = {
                            "input_ids": [text_line_scanned, syllable_line],
                            "nums": [num_vowels],
                            "verse_end": verse_end,
                            "metre": metre
                        }
                        # Random routing; with enough data this converges to the requested split.
                        rand_split = np.random.rand()
                        if rand_split > self.val_data_rate + self.test_data_rate:
                            self.data.append(record)
                        elif rand_split < self.test_data_rate:
                            self.test_data.append(record)
                        else:
                            self.validation_data.append(record)

    def __len__(self):
        """Return length of training data.

        Returns:
            int: length of training data.
        """
        return len(self.data)

    def __getitem__(self, index):
        """Return the indexed training record.

        Args:
            index (int): index of the record.

        Returns:
            dict: record at `index`.
        """
        return self.data[index]
273
+
274
class BodyDataset(Dataset):
    """Dataset of preprocessed strophes.

    Child of torch `Dataset` for integration with torch and huggingface.
    Each record holds a raw and a syllabified rendering of one strophe plus
    rhyme scheme, publish year and per-verse metres.
    """

    def __init__(self, data_file_paths,
                 prompt_length=True, prompt_ending=True, prompt_verse=True, verse_len=[4, 6], lower_case=True, val_data_rate: float = 0.05, test_data_rate: float = 0.05):
        """Construct the class for given data file paths and store parameters.

        Args:
            data_file_paths: list of paths to corpus JSON files.
            prompt_length (bool, optional): If to prompt the syllable count. Defaults to True.
            prompt_ending (bool, optional): If to prompt verse ending. Defaults to True.
            prompt_verse (bool, optional): If to prompt rhyme schema. Defaults to True.
            verse_len (list, optional): Considered strophe lengths. Defaults to [4, 6].
            lower_case (bool, optional): If the string should be lowercased. Defaults to True.
            val_data_rate (float, optional): Fraction of data held out for validation. Defaults to 0.05.
            test_data_rate (float, optional): Fraction of data held out for testing. Defaults to 0.05.
        """
        self._data_file_paths = data_file_paths
        self.prompt_length = prompt_length
        self.prompt_ending = prompt_ending
        self.prompt_verse = prompt_verse
        self.verse_len = verse_len
        self.lower_case = lower_case

        self.val_data_rate = val_data_rate
        self.test_data_rate = test_data_rate

        # Split containers filled by data_body_gen().
        self.data = []
        self.validation_data = []
        self.test_data = []

    def gen_files(self):
        """Yield open file objects, one per data file.

        Yields:
            file object: open file (closed by the consumer).
        """
        for filename in self._data_file_paths:
            yield open(filename, 'r')

    def _construct_line(self, raw_text, metre):
        """Construct an individual content line with its parameter prefix.

        Args:
            raw_text (str): raw verse line.
            metre (str): metre tag prefixed to the line.

        Returns:
            str: "METRE # [COUNT # ][END # ]raw_text"
        """
        syllables = SyllableMaker.syllabify(raw_text)
        num_str = f"{len(syllables)} # " if self.prompt_length else ""
        verse_end = f"{syllables[-1]} # " if self.prompt_ending else ""
        metre_txt = f"{metre} # "
        return metre_txt + num_str + verse_end + raw_text

    def _construct_syllable_line(self, raw_text, metre):
        """Construct an individual content line as a syllable sequence.

        Args:
            raw_text (str): raw verse line.
            metre (str): metre tag prefixed to the line.

        Returns:
            str: parameter prefix + syllabified verse.
        """
        ending = raw_text[-1] if raw_text[-1] in [',', '.', '!', '?'] else ''
        syllables = SyllableMaker.syllabify(raw_text)
        num_str = f"{len(syllables)} # " if self.prompt_length else ""
        verse_end = f"{syllables[-1]} # " if self.prompt_ending else ""
        metre_txt = f"{metre} # "
        return metre_txt + num_str + verse_end + " ".join(syllables) + ending

    def data_body_gen(self):
        """Preprocess corpus files into train/validation/test strophe records."""
        for step, file in enumerate(self.gen_files()):
            if step % 500 == 0:
                print(f"Processing file {step}")
            # Close the handle right after parsing (fixes fd leak).
            with file:
                datum = json.load(file)

            for data_line in datum:
                publish_year_text = TextManipulation._year_bucketor(data_line["biblio"]["year"])
                publish_year_true = data_line["biblio"]["year"] if TextAnalysis._is_year(data_line["biblio"]["year"]) else 'NaN'
                context = ["NO CONTEXT"]

                for part_line in data_line['body']:
                    body = []
                    body_syllabs = []
                    rhyme = []
                    metres = []
                    i = 0
                    for text_line in part_line:
                        # In rare cases multiple metres appear, but from searching only 1 metre per line.
                        # NOTE(review): default "J" differs from TextDataset's "N" — confirm intended.
                        metre = StropheParams.METER_TRANSLATE.get(text_line["metre"][0]["type"], "J")
                        metres += [metre]

                        rhyme.append(text_line["rhyme"])

                        scanned_text = TextManipulation._remove_most_nonchar(text_line["text"], self.lower_case)

                        body.append(self._construct_line(scanned_text, metre))
                        body_syllabs.append(self._construct_syllable_line(scanned_text, metre))

                        i += 1

                        # Emit a record at each requested strophe length.
                        if i in self.verse_len:
                            rhyme_str = TextManipulation._rhyme_string(rhyme)

                            text = f"# {rhyme_str} # {publish_year_text}\n" + "\n".join(body) + "\n"
                            syllable_text = f"# {rhyme_str} # {publish_year_text}\n" + "\n".join(body_syllabs) + "\n"
                            context_text = "\n".join(context)

                            record = {
                                "input_ids": [text, syllable_text],
                                "context_ids": context_text,
                                "year": publish_year_true,
                                "rhyme": rhyme_str,
                                "metre_ids": metres.copy()
                            }
                            # Random routing; with enough data this converges to the requested split.
                            rand_split = np.random.rand()
                            if rand_split > self.val_data_rate + self.test_data_rate:
                                self.data.append(record)
                            elif rand_split < self.test_data_rate:
                                self.test_data.append(record)
                            else:
                                self.validation_data.append(record)

                        # Restart accumulation once the longest window is emitted.
                        if i == max(self.verse_len):
                            body = []
                            body_syllabs = []
                            rhyme = []
                            metres = []
                            i = 0

    def __len__(self):
        """Return length of training data.

        Returns:
            int: length of training data.
        """
        return len(self.data)

    def __getitem__(self, index):
        """Return the indexed training record.

        Args:
            index (int): index of the record.

        Returns:
            dict: record at `index`.
        """
        return self.data[index]
446
+
447
def get_filenames(self):
    """Get paths of data files.

    Returns:
        list: full paths of all entries inside ``self.data_dir``.
    """
    return [os.path.join(self.data_dir, entry) for entry in os.listdir(self.data_dir)]
459
+
460
def load_raw_(self):
    """Load the raw dataset serving unprocessed string data."""
    self.raw_dataset = CorpusDatasetPytorch.RawDataset(self.get_filenames(), self.lower_case)
466
+
467
def load_json_filenames(self, prompt_length, prompt_ending, prompt_verse, verse_len=[4, 6], val_data_rate=0.05, test_data_rate=0.05):
    """Build the Verse and Strophe datasets and split off val/test holders.

    Args:
        prompt_length (bool): If to prompt the syllable count.
        prompt_ending (bool): If to prompt verse ending.
        prompt_verse (bool): If to prompt rhyme schema.
        verse_len (list, optional): Considered strophe lengths. Defaults to [4, 6].
        val_data_rate (float, optional): Fraction held out for validation. Defaults to 0.05.
        test_data_rate (float, optional): Fraction held out for testing. Defaults to 0.05.
    """
    filenames = self.get_filenames()

    # Strophe-level dataset.
    self.pytorch_dataset_body = CorpusDatasetPytorch.BodyDataset(
        filenames, prompt_ending=prompt_ending, prompt_length=prompt_length,
        prompt_verse=prompt_verse, verse_len=verse_len, lower_case=self.lower_case,
        val_data_rate=val_data_rate, test_data_rate=test_data_rate)
    self.pytorch_dataset_body.data_body_gen()

    # Verse-level dataset.
    self.pytorch_dataset_text = CorpusDatasetPytorch.TextDataset(
        filenames, prompt_ending=prompt_ending, prompt_length=prompt_length,
        lower_case=self.lower_case, val_data_rate=val_data_rate, test_data_rate=test_data_rate)
    self.pytorch_dataset_text.data_text_line_gen()

    # Move the validation split into dedicated empty holders.
    self.val_pytorch_dataset_body = CorpusDatasetPytorch.BodyDataset([])
    self.val_pytorch_dataset_text = CorpusDatasetPytorch.TextDataset([])
    self.val_pytorch_dataset_body.data = self.pytorch_dataset_body.validation_data
    self.val_pytorch_dataset_text.data = self.pytorch_dataset_text.validation_data
    self.pytorch_dataset_text.validation_data = []
    self.pytorch_dataset_body.validation_data = []

    # Same for the test split.
    self.test_pytorch_dataset_body = CorpusDatasetPytorch.BodyDataset([])
    self.test_pytorch_dataset_text = CorpusDatasetPytorch.TextDataset([])
    self.test_pytorch_dataset_body.data = self.pytorch_dataset_body.test_data
    self.test_pytorch_dataset_text.data = self.pytorch_dataset_text.test_data
    self.pytorch_dataset_text.test_data = []
    self.pytorch_dataset_body.test_data = []
509
+
510
def create_empty(self):
    """Create empty dataset holders for a later load of cached processed data."""
    for prefix in ("", "val_", "test_"):
        setattr(self, f"{prefix}pytorch_dataset_body", CorpusDatasetPytorch.BodyDataset([]))
        setattr(self, f"{prefix}pytorch_dataset_text", CorpusDatasetPytorch.TextDataset([]))
519
+
520
+
521
+ @staticmethod
522
+ def collate(batch, tokenizer: PreTrainedTokenizerBase ,max_len = 1024, max_context = 1024 ,mask_rate = 0.0, syllables: bool = False, format: str = 'METER_VERSE'):
523
+ """Process data for usage in LM
524
+
525
+ Args:
526
+ batch (_type_): Batch with selected data points
527
+ tokenizer (PreTrainedTokenizerBase): tokenizer to tokenize input text
528
+ max_len (int, optional): Maximum length of tokenization. Defaults to 1024.
529
+ max_context (int, optional): Maximum length of tokenization of context. Defaults to 1024.
530
+ mask_rate (float, optional): Rate in with to mask data. Defaults to 0.0.
531
+ syllables (bool, optional): If to use sequence of syllables as input text. Defaults to False.
532
+
533
+ Returns:
534
+ dict: tokenized and processed to tensors data
535
+ """
536
+ index = 1 if syllables else 0
537
+
538
+ tokenizer.model_max_length = max_len
539
+ if batch[0]['input_ids'][0].startswith("#"):
540
+
541
+ data = [text['input_ids'][index] for text in batch]
542
+ if format == "BASIC":
543
+ data = ["\n".join
544
+ (
545
+ [line + f" # {datum.splitlines()[1].split()[0]}"
546
+ if i==0 else line.split('#')[-1] for i, line in enumerate(datum.splitlines())]
547
+ ) + tokenizer.eos_token for j, datum in enumerate(data)
548
+ ]
549
+ elif format == "VERSE_PAR":
550
+ data = ["\n".join
551
+ (
552
+ [line + f" # {datum.splitlines()[1].split()[0]}"
553
+ if i==0 else "#".join(line.split('#')[1:]) for i, line in enumerate(datum.splitlines())]
554
+ ) + tokenizer.eos_token for j, datum in enumerate(data)
555
+ ]
556
+ else:
557
+ data = [text['input_ids'][index] + tokenizer.eos_token for text in batch]
558
+
559
+ tokenized = tokenizer(data,return_tensors='pt', truncation=True, padding=True)
560
+ input_ids = tokenized['input_ids']
561
+ attention = tokenized["attention_mask"]
562
+
563
+ else:
564
+ tokenized = tokenizer([text['input_ids'][index] + tokenizer.eos_token for text in batch],return_tensors='pt', truncation=True, padding=True)
565
+ input_ids = tokenized['input_ids']
566
+ attention = tokenized["attention_mask"]
567
+
568
+
569
+ nums = None
570
+ if "nums" in batch[0].keys():
571
+ nums = torch.tensor(np.asarray([text['nums'] for text in batch], dtype=np.int32), dtype=torch.float32)
572
+
573
+ rhyme=None
574
+ if "rhyme" in batch[0].keys():
575
+ rhyme = torch.tensor(np.asarray([TextAnalysis._rhyme_vector(text["rhyme"]) for text in batch], dtype=np.int32), dtype=torch.float32)
576
+
577
+ verse_end = None
578
+ if "verse_end" in batch[0].keys():
579
+ verse_end = torch.tensor(np.asarray([CorpusDatasetPytorch.TextDataset._ending_vector(text["verse_end"]) for text in batch], dtype=np.int32), dtype=torch.float32)
580
+
581
+ year = None
582
+ if "year" in batch[0].keys():
583
+ year = torch.tensor(np.asarray([TextAnalysis._publish_year_vector(text["year"]) for text in batch], dtype=np.int32), dtype=torch.float32)
584
+
585
+ metre = None
586
+ if "metre" in batch[0].keys():
587
+ metre = torch.tensor(np.asarray([TextAnalysis._metre_vector(text["metre"]) for text in batch], dtype=np.int32), dtype=torch.float32)
588
+
589
+ context_ids = None
590
+ context_attention_mask = None
591
+ if "context_ids" in batch[0].keys():
592
+ tokenizer.model_max_length = max_context
593
+ tokenized_context = tokenizer([text['context_ids'] + tokenizer.eos_token for text in batch],return_tensors='pt', truncation=True, padding=True)
594
+ context_ids = tokenized_context['input_ids']
595
+ context_attention_mask = tokenized_context['attention_mask']
596
+
597
+ return {
598
+ "input_ids": input_ids,
599
+ "labels": input_ids.type(torch.LongTensor),
600
+ "attention_mask": attention,
601
+ "context_ids" : context_ids,
602
+ "context_attention_mask" : context_attention_mask,
603
+ "nums" : nums,
604
+ "rhyme": rhyme,
605
+ "verse_end" : verse_end,
606
+ "year": year,
607
+ "metre" : metre}
608
+
609
+
610
+ @staticmethod
611
+ def collate_distil(batch, tokenizer: PreTrainedTokenizerBase ,surrogate_model: PreTrainedModel = None,surrogate_model_device=None ,max_len = 1024):
612
+ tokenizer.model_max_length = max_len
613
+ tokenized = tokenizer([text['input_ids'][0] + tokenizer.eos_token for text in batch], return_tensors='pt', truncation=True, padding=True)
614
+ input_ids = tokenized['input_ids']
615
+ attention = tokenized["attention_mask"]
616
+
617
+ with torch.no_grad():
618
+ # This is Tuple
619
+ model_hidden_states = surrogate_model(input_ids=input_ids.to(surrogate_model_device),
620
+ attention_mask=attention.to(surrogate_model_device),
621
+ labels=input_ids.type(torch.LongTensor).to(surrogate_model_device))['hidden_states']
622
+ model_hidden_states = [hidden.cpu().detach() for hidden in model_hidden_states]
623
+
624
+ return {
625
+ "input_ids": input_ids,
626
+ "labels": input_ids.type(torch.LongTensor),
627
+ "attention_mask": attention,
628
+ "to_replicate_states": model_hidden_states
629
+ }
630
+
631
+ @staticmethod
632
+ def collate_validator(batch, tokenizer: PreTrainedTokenizerBase,syllables:bool, is_syllable:bool = False,max_len = 512):
633
+ """Process data for use in LM for metre,rhyme and year prediction
634
+
635
+ Args:
636
+ batch (_type_): Batch with selected data points
637
+ tokenizer (PreTrainedTokenizerBase): tokenizer to tokenize input text
638
+ syllables (bool): If to use sequence of syllables as input text
639
+ is_syllable (bool, optional): Signal if the preprocessed inputs contain syllable data. Defaults to False.
640
+ max_len (int, optional): Maximum length of tokenization. Defaults to 1024.
641
+
642
+ Returns:
643
+ dict: tokenized and processed to tensors data
644
+ """
645
+ index = 1 if syllables and is_syllable else 0
646
+ tokenizer.model_max_length = max_len
647
+ data_ids = ["\n".join(
648
+ [" ".join(
649
+ SyllableMaker.syllabify(line.split('#')[-1])
650
+ ) + (line[-1] if line[-1] in [',','.','!','?'] else '') if (syllables and not is_syllable and line) else line.split('#')[-1] for line in text['input_ids'][index].splitlines()[1:]]
651
+ ) for text in batch ]
652
+
653
+
654
+ tokenized = tokenizer(data_ids, return_tensors='pt', truncation=True, padding=True)
655
+ input_ids = tokenized['input_ids']
656
+ attention = tokenized["attention_mask"]
657
+
658
+ rhyme=None
659
+ if "rhyme" in batch[0].keys():
660
+ rhyme = torch.tensor(np.asarray([TextAnalysis._rhyme_vector(text["rhyme"]) for text in batch], dtype=np.int32), dtype=torch.float32)
661
+
662
+ year_bucket = None
663
+ year = None
664
+ if "year" in batch[0].keys():
665
+ year_bucket = torch.tensor(np.asarray([TextAnalysis._publish_year_vector(text["year"]) for text in batch], dtype=np.int32), dtype=torch.float32)
666
+ year = torch.tensor(np.asarray([ [int(text['year'])] if text['year'] != 'NaN' else [0] for text in batch], dtype=np.int32), dtype=torch.float32)
667
+
668
+ return {
669
+ "input_ids": input_ids,
670
+ "attention_mask": attention,
671
+ "rhyme": rhyme,
672
+ "metre_ids": None,
673
+ "year_bucket": year_bucket,
674
+ 'year':year}
675
+
676
+ @staticmethod
677
+ def collate_meter(batch, tokenizer: PreTrainedTokenizerBase, syllables:bool, is_syllable:bool = False, max_len = 512):
678
+ index = 1 if syllables and is_syllable else 0
679
+ tokenizer.model_max_length = max_len
680
+ data_ids = []
681
+ metre = []
682
+ for datum in batch:
683
+ data_ids += [
684
+ " ".join(
685
+ SyllableMaker.syllabify(line.split('#')[-1])
686
+ ) + (line[-1] if line[-1] in [',','.','!','?'] else '') if (syllables and not is_syllable and line) else line.split('#')[-1] for line in datum['input_ids'][index].splitlines()[1:]
687
+ ]
688
+ if "metre_ids" in batch[0].keys():
689
+ metre += [TextAnalysis._metre_vector(one_metre) for one_metre in datum['metre_ids']]
690
+
691
+ tokenized = tokenizer(data_ids, return_tensors='pt', truncation=True, padding=True)
692
+ input_ids = tokenized['input_ids']
693
+ attention = tokenized["attention_mask"]
694
+
695
+ metre_ids = None
696
+ if len(metre) > 0:
697
+ metre_ids = torch.tensor(np.asarray(metre, dtype=np.int32), dtype=torch.float32)
698
+
699
+ return {
700
+ "input_ids": input_ids,
701
+ "attention_mask": attention,
702
+ "rhyme": None,
703
+ "metre_ids": metre_ids,
704
+ "year_bucket": None,
705
+ "year": None}
706
+
707
+
708
+
709
def __init__(self, data_dir = "PoetGen\corpusCzechVerse-master\ccv", cache_dir='./',
             prompt_length=True, prompt_ending=True, prompt_verse=True, verse_len=[4, 6], lower_case=True, val_data_rate=0.05, test_data_rate=0.05):
    """Construct the Dataloader and create (or load cached) Datasets.

    Args:
        data_dir (str, optional): Path to data. Defaults to "PoetGen\corpusCzechVerse-master\ccv".
        cache_dir (str, optional): Path where processed data are cached. Defaults to './'.
        prompt_length (bool, optional): If to prompt the syllable count. Defaults to True.
        prompt_ending (bool, optional): If to prompt verse ending. Defaults to True.
        prompt_verse (bool, optional): If to prompt rhyme schema. Defaults to True.
        verse_len (list, optional): Considered strophe lengths. Defaults to [4, 6].
        lower_case (bool, optional): If the string should be lowercased. Defaults to True.
        val_data_rate (float, optional): Fraction held out for validation. Defaults to 0.05.
        test_data_rate (float, optional): Fraction held out for testing. Defaults to 0.05.
    """
    self.lower_case = lower_case
    self.data_dir = data_dir

    # One cache file per split/granularity, in a fixed order matching the holders below.
    cache_names = ["body_poet_data.json", "text_poet_data.json",
                   "val_body_poet_data.json", "val_text_poet_data.json",
                   "test_body_poet_data.json", "test_text_poet_data.json"]
    cache_paths = [os.path.join(cache_dir, name) for name in cache_names]

    if all(os.path.isfile(path) for path in cache_paths):
        # Full cache present: load processed data instead of re-parsing the corpus.
        self.create_empty()
        holders = [self.pytorch_dataset_body, self.pytorch_dataset_text,
                   self.val_pytorch_dataset_body, self.val_pytorch_dataset_text,
                   self.test_pytorch_dataset_body, self.test_pytorch_dataset_text]
        for holder, path in zip(holders, cache_paths):
            # `with` closes the handles the original left open (fixes fd leak).
            with open(path, 'r') as cache_file:
                holder.data = list(json.load(cache_file))
    else:
        # Build the datasets from the corpus, then cache every split.
        self.load_json_filenames(prompt_length, prompt_ending, prompt_verse,
                                 verse_len=verse_len, val_data_rate=val_data_rate, test_data_rate=test_data_rate)
        holders = [self.pytorch_dataset_body, self.pytorch_dataset_text,
                   self.val_pytorch_dataset_body, self.val_pytorch_dataset_text,
                   self.test_pytorch_dataset_body, self.test_pytorch_dataset_text]
        for holder, path in zip(holders, cache_paths):
            # `with` guarantees the cache is flushed and closed (fixes fd leak).
            with open(path, 'w+') as cache_file:
                json.dump(holder.data, cache_file, indent=6)

    self.load_raw_()
745
+
746
+
747
+
748
+ #if __name__ == "__main__":
749
+ # Line Count
750
+ # print(len(list(CorpusDatasetPytorch(os.path.abspath(os.path.join(os.path.dirname(__file__), "corpusCzechVerse", "ccv")) ).raw_dataset.get_text())))
751
+ # Strophe Count
752
+ # print(len(list(CorpusDatasetPytorch(os.path.abspath(os.path.join(os.path.dirname(__file__), "corpusCzechVerse", "ccv")) ).raw_dataset.get_part())))
753
+ # Poem Count
754
+ # print(len(list(CorpusDatasetPytorch(os.path.abspath(os.path.join(os.path.dirname(__file__), "corpusCzechVerse", "ccv")) ).raw_dataset.get_body())))
simple_generation_player.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import torch
4
+ import numpy as np
5
+
6
+ import sys
7
+
8
+ from transformers import AutoTokenizer, PreTrainedTokenizerBase, PreTrainedTokenizerFast
9
+ from utils.poet_utils import StropheParams, Tokens, TextManipulation, TextAnalysis
10
+ from utils.base_poet_models import PoetModelBase
11
+ from utils.validators import ValidatorInterface
12
+
13
+ from corpus_capsulated_datasets import CorpusDatasetPytorch
14
+
15
def _str2bool(value):
    """Parse a boolean command-line value.

    argparse's ``type=bool`` is broken for flags: ``bool("False")`` is True,
    because any non-empty string is truthy. This parser accepts the usual
    spellings of true/false and raises on anything else.
    """
    if isinstance(value, bool):
        return value
    lowered = value.lower()
    if lowered in ('true', '1', 'yes', 'y'):
        return True
    if lowered in ('false', '0', 'no', 'n'):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}")


parser = argparse.ArgumentParser()

# Generator LM (Hugging Face hub id or local path).
parser.add_argument("--model_path_full", default='jinymusim/gpt-czech-poet', type=str, help="Path to Model")

# Pickled validator checkpoints (rhyme / meter / year), stored next to this script.
parser.add_argument("--rhyme_model_path_full", default=os.path.abspath(os.path.join(os.path.dirname(__file__), 'utils', 'validators', 'rhyme', 'distilroberta-base_BPE_validator_1704126399565')), type=str, help="Path to Model")
parser.add_argument("--metre_model_path_full", default=os.path.abspath(os.path.join(os.path.dirname(__file__), 'utils', "validators", 'meter', 'ufal-robeczech-base_BPE_validator_1704126400265')), type=str, help="Path to Model")
parser.add_argument("--year_model_path_full", default=os.path.abspath(os.path.join(os.path.dirname(__file__), 'utils', "validators", 'year', 'ufal-robeczech-base_BPE_validator_1702393305267')), type=str, help="Path to Model")

# Tokenizers the validators were trained with.
parser.add_argument("--validator_tokenizer_model_rhyme", default='distilroberta-base', type=str, help="Validator tokenizer")
parser.add_argument("--validator_tokenizer_model_meter", default='ufal/robeczech-base', type=str, help="Validator tokenizer")
parser.add_argument("--validator_tokenizer_model_year", default='ufal/robeczech-base', type=str, help="Validator tokenizer")
# NOTE: these previously used type=bool, so "--val_syllables_rhyme False" parsed as True.
parser.add_argument("--val_syllables_rhyme", default=False, type=_str2bool, help="Does validator use syllables")
parser.add_argument("--val_syllables_meter", default=False, type=_str2bool, help="Does validator use syllables")
parser.add_argument("--val_syllables_year", default=False, type=_str2bool, help="Does validator use syllables")
31
if __name__ == "__main__":
    args = parser.parse_args([] if "__file__" not in globals() else None)

    _, model_rel_name = os.path.split(args.model_path_full)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Generator LM.
    model = PoetModelBase(args.model_path_full).to(device)
    model.eval()

    def load_validator(model_path):
        """Load one pickled validator model onto `device`.

        Returns (model, checkpoint basename), or (None, "") when no path given.
        """
        if not model_path:
            return None, ""
        # NOTE(review): torch.load unpickles arbitrary objects — only load trusted checkpoints.
        validator: ValidatorInterface = torch.load(model_path, map_location=torch.device('cpu')).to(device)
        validator.eval()
        _, name = os.path.split(model_path)
        return validator, name

    rhyme_model, rhyme_model_name = load_validator(args.rhyme_model_path_full)
    meter_model, meter_model_name = load_validator(args.metre_model_path_full)
    year_model, year_model_name = load_validator(args.year_model_path_full)

    def load_validator_tokenizer(model_name, revision=None):
        """Load a validator tokenizer from the hub; on failure, treat `model_name`
        as a local tokenizer file and attach the project's special tokens.

        Previously this was triplicated inline with bare `except:` clauses
        (which also swallowed KeyboardInterrupt/SystemExit).
        """
        if not model_name:
            return None
        try:
            if revision is None:
                return AutoTokenizer.from_pretrained(model_name)
            return AutoTokenizer.from_pretrained(model_name, revision=revision)
        except Exception:
            tok: PreTrainedTokenizerBase = PreTrainedTokenizerFast(tokenizer_file=model_name)
            tok.eos_token = Tokens.EOS
            tok.eos_token_id = Tokens.EOS_ID
            tok.pad_token = Tokens.PAD
            tok.pad_token_id = Tokens.PAD_ID
            tok.unk_token = Tokens.UNK
            tok.unk_token_id = Tokens.UNK_ID
            tok.cls_token = Tokens.CLS
            tok.cls_token_id = Tokens.CLS_ID
            tok.sep_token = Tokens.SEP
            tok.sep_token_id = Tokens.SEP_ID
            return tok

    # Meter and year validators were trained against hub revision v1.0; rhyme uses the default.
    validator_tokenizer_rhyme: PreTrainedTokenizerBase = load_validator_tokenizer(args.validator_tokenizer_model_rhyme)
    validator_tokenizer_meter: PreTrainedTokenizerBase = load_validator_tokenizer(args.validator_tokenizer_model_meter, revision='v1.0')
    validator_tokenizer_year: PreTrainedTokenizerBase = load_validator_tokenizer(args.validator_tokenizer_model_year, revision='v1.0')

    # Load LM tokenizer
    tokenizer: PreTrainedTokenizerBase = AutoTokenizer.from_pretrained(args.model_path_full)

    generation = "BASIC"

    def decoder_helper(gen_mode, user_input):
        """Generate one strophe.

        BASIC  = plain top-k sampling from the LM on the raw prompt.
        FORCED = the model's constrained generation (rhyme/length/ending forcing).
        """
        # was named `type`, which shadowed the builtin
        if gen_mode == "BASIC":
            tokenized = tokenizer.encode(user_input, return_tensors='pt', truncation=True)
            out = model.model.generate(tokenized.to(device),
                                       max_length=512,
                                       do_sample=True,
                                       top_k=50,
                                       eos_token_id=tokenizer.eos_token_id,
                                       early_stopping=True,
                                       pad_token_id=tokenizer.pad_token_id)
            return tokenizer.decode(out.cpu()[0], skip_special_tokens=True)
        if gen_mode == "FORCED":
            return model.generate_forced(user_input, tokenizer, sample=True, device=device)

    # was named `help`, which shadowed the builtin
    help_text = f"Current setting is {generation} generating.\nChange it by writing FORCED/BASIC to input. type HELP for HELP.\nType EXIT to exit."

    print("Welcome to simple czech strophe generation.", help_text)

    while True:
        # Collect a (possibly multi-line) prompt; an empty line submits it.
        user_input = ""
        while True:
            curr_line = input(">").strip()
            if curr_line == 'EXIT':
                sys.exit()
            elif curr_line == "HELP":
                print(help_text)
                continue
            elif curr_line == "BASIC":
                print("Changed to BASIC")
                generation = 'BASIC'
                continue
            elif curr_line == "FORCED":
                print("Changed to FORCED")
                generation = "FORCED"
                continue
            if not curr_line:
                break
            user_input += curr_line + "\n"

        user_input = user_input.strip()
        user_reqs = model.analyze_prompt(user_input)

        # BASIC generation needs a parameter line; fall back to the generic "#" prompt.
        if "RHYME" not in user_reqs.keys() and generation == "BASIC":
            print("BASIC generation can't work with imputed format.", help_text)
            print("User input is substituted for #")
            user_input = '#'

        generated_poem: str = decoder_helper(generation, user_input)

        # Classify the generated strophe with the validator models.
        meters = []
        rhyme_pred = ''
        year_pred = 0
        for line in generated_poem.splitlines():
            # Stop at the first empty (or punctuation-only) line: end of the strophe.
            if not line.strip():
                break
            if not (TextManipulation._remove_most_nonchar(line)).strip():
                break
            # The parameter line triggers whole-poem rhyme + year classification.
            if TextAnalysis._is_param_line(line):
                data = CorpusDatasetPytorch.collate_validator([{"input_ids": [generated_poem]}], tokenizer=validator_tokenizer_rhyme,
                                                              is_syllable=False, syllables=args.val_syllables_rhyme,
                                                              max_len=rhyme_model.model.config.max_position_embeddings - 2)
                rhyme_pred = StropheParams.RHYME[np.argmax(rhyme_model.predict_state(input_ids=data['input_ids'].to(device)).detach().flatten().cpu().numpy())]
                data = CorpusDatasetPytorch.collate_validator([{"input_ids": [generated_poem]}], tokenizer=validator_tokenizer_year,
                                                              is_syllable=False, syllables=args.val_syllables_year,
                                                              max_len=year_model.model.config.max_position_embeddings - 2)
                year_pred = round(year_model.predict_state(input_ids=data['input_ids'].to(device)).detach().flatten().cpu().numpy()[0])
                continue
            # Per-verse meter classification; the dummy first line is skipped by the
            # collate function (presumably it always drops the header line — TODO confirm).
            data = CorpusDatasetPytorch.collate_meter([{"input_ids": ["FIRST LINE SKIP!\n" + line]}], tokenizer=validator_tokenizer_meter,
                                                      is_syllable=False, syllables=args.val_syllables_meter,
                                                      max_len=meter_model.model.config.max_position_embeddings - 2)
            meters.append(
                StropheParams.METER[np.argmax(meter_model.predict_state(input_ids=data['input_ids'].to(device)).detach().flatten().cpu().numpy())]
            )
        print(f"REQUESTED: {user_reqs}, GENERATED USING: {generation}\n")
        print(generated_poem.strip())
        print(f"PREDICTED: {rhyme_pred}, {year_pred}, {meters}\n\n")
utils/__init__.py ADDED
File without changes
utils/base_poet_models.py ADDED
@@ -0,0 +1,689 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .poet_model_utils import PoetModelInterface
2
+ from .poet_utils import TextAnalysis, StropheParams
3
+
4
+ from transformers import AutoTokenizer, AutoModelForCausalLM
5
+ from transformers.utils import ModelOutput
6
+ import random
7
+ import torch
8
+
9
class PoetModelFunctionalInterface(PoetModelInterface):
    """Abstract base with the shared prompt-analysis and forced-generation logic.

    Args:
        PoetModelInterface (_type_): parent interface carrying the core methods
    """
    def __init__(self, *args, **kwargs) -> None:
        """Constructor. As a child class, only forwards to the parent."""
        super().__init__(*args, **kwargs)

    def analyze_prompt(self, prompt) -> dict:
        """Analyze the user's prompt into a feature dict.

        Args:
            prompt (_type_): dict (returned unchanged) or string carrying the user's intent

        Returns:
            dict: parameters extracted from the prompt; per-verse features are
            suffixed with the rhyme-slot index (e.g. "END_0").
        """
        if isinstance(prompt, dict):
            return prompt
        features_dict = {}
        lines = prompt.splitlines()
        lines = list(map(str.strip, lines))
        # Drop empty lines in place (index is stepped back after each pop).
        i = 0
        while i < len(lines):
            if not lines[i]:
                lines.pop(i)
                i -= 1
            i += 1
        cont_line = 0
        for line in lines:
            if TextAnalysis._is_param_line(line):
                # Strophe-level parameters (rhyme scheme, year, ...).
                for key, value in TextAnalysis._first_line_analysis(line).items():
                    features_dict[key] = value
            else:
                # Map this verse to its rhyme-group slot: A->0, B->1, C->2, D->3;
                # falls back to the running verse index when no scheme is known.
                val = cont_line
                if "RHYME" in features_dict.keys() and cont_line < len(features_dict['RHYME']):
                    if features_dict["RHYME"][cont_line] == "A":
                        val = 0
                    elif features_dict["RHYME"][cont_line] == "B":
                        val = 1
                    elif features_dict["RHYME"][cont_line] == "C":
                        val = 2
                    elif features_dict["RHYME"][cont_line] == "D":
                        val = 3
                for key, value in TextAnalysis._continuos_line_analysis(line).items():
                    features_dict[f"{key}_{val}"] = value
                cont_line += 1

        return features_dict

    def generate_forced(self, prompt, tokenizer: AutoTokenizer, sample: bool = True, format: str = 'METER_VERSE', device= torch.device('cpu'), *args, **kwargs) -> str:
        """Generate a strophe using FORCED generation (verse-by-verse constraints).

        Args:
            prompt (_type_): dict or string with the user's intended strophe parameters
            tokenizer (AutoTokenizer): tokenizer to use; should match the model
            sample (bool, optional): whether to sample (else beam search). Defaults to True.
            format (str, optional): training format: BASIC, VERSE_PAR, METER_VERSE,
                OLD (DEPRECATED! for old models only). Defaults to 'METER_VERSE'.
            device (_type_, optional): device to generate on. Defaults to torch.device('cpu').

        Returns:
            str: generated strophe, one verse per line
        """
        features_dict_init = self.analyze_prompt(prompt)
        # If user parameters came as a dict, start with an empty verse list.
        if isinstance(prompt, dict):
            prompt_list = []
        else:
            prompt_list = prompt.splitlines()
        # Generate a candidate parameter line to fill in any missing poet params.
        token_gen_rhyme = tokenizer.encode("#", return_tensors='pt')
        if sample:
            rhyme_line = self.model.generate(token_gen_rhyme.to(device),
                                             max_new_tokens= 100,
                                             do_sample=True,
                                             top_k=50,
                                             early_stopping=True,
                                             pad_token_id=tokenizer.pad_token_id,
                                             eos_token_id=tokenizer.eos_token_id)
        else:
            rhyme_line = self.model.generate(token_gen_rhyme.to(device),
                                             max_new_tokens= 100,
                                             num_beams=8,
                                             no_repeat_ngram_size=2,
                                             early_stopping=True,
                                             pad_token_id=tokenizer.pad_token_id,
                                             eos_token_id=tokenizer.eos_token_id)
        rhyme_dec = tokenizer.decode(rhyme_line.cpu()[0], skip_special_tokens=True).splitlines()[0]
        # User-provided parameters override the generated ones.
        features_dict= TextAnalysis._first_line_analysis(rhyme_dec)
        for key, value in features_dict_init.items():
            features_dict[key] = value
        # CONSTRUCT BEST INPUT LINE
        # Backup rhyme scheme: pick a random one (excluding the last, catch-all entry).
        if "RHYME" not in features_dict.keys():
            features_dict["RHYME"] = random.choice(StropheParams.RHYME[:-1])
        # OLD format: no leading "# " marker.
        if format == 'OLD':
            poet_param_str = ""
            if "RHYME" in features_dict.keys():
                poet_param_str += features_dict["RHYME"]
            if "YEAR" in features_dict.keys():
                poet_param_str += f" # {features_dict['YEAR']}"
            if 'STROPHE_METER' in features_dict.keys():
                poet_param_str += f" # {features_dict['STROPHE_METER']}"
        # BASIC / VERSE_PAR formats: strophe-level meter allowed on the param line.
        elif format != 'METER_VERSE':
            poet_param_str = "# "
            if "RHYME" in features_dict.keys():
                poet_param_str += features_dict["RHYME"]
            if "YEAR" in features_dict.keys():
                poet_param_str += f" # {features_dict['YEAR']}"
            if 'STROPHE_METER' in features_dict.keys():
                poet_param_str += f" # {features_dict['STROPHE_METER']}"
        # METER_VERSE format: meter lives on each verse, so only rhyme + year here.
        else:
            poet_param_str = "# "
            if "RHYME" in features_dict.keys():
                poet_param_str += features_dict["RHYME"]
            if "YEAR" in features_dict.keys():
                poet_param_str += f" # {features_dict['YEAR']}"
        # Replace or insert the parameter line depending on what the user supplied.
        if len(features_dict_init.keys()) == 0: # Weird input: nothing usable, start over
            prompt_list = [poet_param_str]
        elif len(prompt_list) == 0: # Input was a dict
            prompt_list.append(poet_param_str)
        elif "RHYME" not in features_dict_init.keys():
            if "YEAR" in features_dict_init.keys() or 'STROPHE_METER' in features_dict_init.keys(): # Replace the incomplete first line
                prompt_list[0] = poet_param_str
            else:
                prompt_list.insert(0, poet_param_str)
        else:
            prompt_list[0] = poet_param_str

        # One verse per character of the rhyme scheme.
        verse_len = len(features_dict["RHYME"])

        # Finish possibly incomplete verses the user already typed (index 0 is the param line).
        base_prompt_len = len(prompt_list)
        for i in range(2,base_prompt_len + 1):
            # Rhyme-group slot of verse i-2; X (-1) means unconstrained.
            rhyme_char = 0
            if features_dict["RHYME"][(i - 2) % len(features_dict["RHYME"])] == "B":
                rhyme_char = 1
            elif features_dict["RHYME"][(i - 2) % len(features_dict["RHYME"])] == "C":
                rhyme_char = 2
            elif features_dict["RHYME"][(i - 2) % len(features_dict["RHYME"])] == "D":
                rhyme_char = 3
            elif features_dict["RHYME"][(i - 2) % len(features_dict["RHYME"])] == "X":
                rhyme_char = -1

            token_gen_finish = tokenizer.encode("\n".join(prompt_list[:i]), return_tensors='pt')
            if sample:
                finish_line = self.model.generate(token_gen_finish.to(device),
                                                  max_new_tokens= 100,
                                                  do_sample=True,
                                                  top_k=50,
                                                  early_stopping=True,
                                                  pad_token_id=tokenizer.pad_token_id,
                                                  eos_token_id=tokenizer.eos_token_id)
            else:
                finish_line = self.model.generate(token_gen_finish.to(device),
                                                  max_new_tokens= 100,
                                                  num_beams=8,
                                                  no_repeat_ngram_size=2,
                                                  early_stopping=True,
                                                  pad_token_id=tokenizer.pad_token_id,
                                                  eos_token_id=tokenizer.eos_token_id)
            decoded = tokenizer.decode(finish_line.cpu()[0], skip_special_tokens=True).splitlines()
            # Overwrite the prefix with the completed verses.
            to_dec = min(i, len(decoded))
            prompt_list[:to_dec] = decoded[:to_dec]

            if to_dec - 1 < len(prompt_list):
                dec_line = prompt_list[to_dec-1]
                # Record this rhyme group's length/ending (and meter for METER_VERSE)
                # from the verse header tokens, to force them on later verses.
                if format == 'VERSE_PAR' or format == 'OLD':
                    if f"END_{rhyme_char}" not in features_dict.keys() and len(dec_line.split()) > 1 and rhyme_char>=0 and dec_line.count("#") <=1:
                        features_dict[f'LENGTH_{rhyme_char}'] = dec_line.split()[0]
                        features_dict[f'END_{rhyme_char}'] = dec_line.split()[1]
                    elif f"END_{rhyme_char}" not in features_dict.keys() and len(dec_line.split()) > 2 and rhyme_char>=0:
                        features_dict[f'LENGTH_{rhyme_char}'] = dec_line.split()[0]
                        features_dict[f'END_{rhyme_char}'] = dec_line.split()[2]
                elif format == 'METER_VERSE':
                    if f"END_{rhyme_char}" not in features_dict.keys() and len(dec_line.split()) > 4 and rhyme_char>=0:
                        features_dict[f'METER_{rhyme_char}'] = dec_line.split()[0]
                        features_dict[f'LENGTH_{rhyme_char}'] = dec_line.split()[2]
                        features_dict[f'END_{rhyme_char}'] = dec_line.split()[4]



        # Generate the remaining verses, one per loop iteration.
        has_rep= False
        has_rep_again = False
        while len(prompt_list) <= verse_len:
            # Rhyme-group slot of the next verse (same mapping as above).
            j = 0
            if features_dict["RHYME"][(len(prompt_list) - 1) % len(features_dict["RHYME"])] == "B":
                j = 1
            elif features_dict["RHYME"][(len(prompt_list) - 1) % len(features_dict["RHYME"])] == "C":
                j = 2
            elif features_dict["RHYME"][(len(prompt_list) - 1) % len(features_dict["RHYME"])] == "D":
                j = 3
            elif features_dict["RHYME"][(len(prompt_list) - 1) % len(features_dict["RHYME"])] == "X":
                j=-1
            # Build the forced verse header for this format from any recorded constraints.
            if format == 'BASIC':
                line_start = ""
            elif format == 'OLD':
                line_start = (f"{features_dict[f'LENGTH_{j}']} " if f"LENGTH_{j}" in features_dict.keys() else "" ) + \
                    (f" {features_dict[f'END_{j}'] } #" if f"END_{j}" in features_dict.keys() else "")
            elif format == 'VERSE_PAR':
                line_start = (f"{features_dict[f'LENGTH_{j}']} #" if f"LENGTH_{j}" in features_dict.keys() else "" ) + \
                    (f" {features_dict[f'END_{j}'] } #" if f"END_{j}" in features_dict.keys() else "")
            else:
                line_start = (f"{features_dict[f'METER_{j}'] } #" if f"METER_{j}" in features_dict.keys() else "") + \
                    (f" {features_dict[f'LENGTH_{j}']} #" if f"LENGTH_{j}" in features_dict.keys() else "" ) + \
                    (f" {features_dict[f'END_{j}'] } #" if f"END_{j}" in features_dict.keys() else "")
            tokenized_poet_start = tokenizer.encode("\n".join(prompt_list) + "\n" + line_start, return_tensors='pt')
            if sample:
                out_line = self.model.generate(tokenized_poet_start.to(device),
                                               max_new_tokens= 100,
                                               do_sample=True,
                                               top_k=50,
                                               early_stopping=True,
                                               pad_token_id=tokenizer.pad_token_id,
                                               eos_token_id=tokenizer.eos_token_id)
            else:
                out_line = self.model.generate(tokenized_poet_start.to(device),
                                               max_new_tokens= 100,
                                               num_beams=2,
                                               no_repeat_ngram_size=2,
                                               early_stopping=True,
                                               pad_token_id=tokenizer.pad_token_id,
                                               eos_token_id=tokenizer.eos_token_id)
            decoded_lines = tokenizer.decode(out_line.cpu()[0], skip_special_tokens=True).splitlines()
            # Repetition catcher: the model produced no new line. Retry once; on the
            # second failure drop the last verse and accept whatever comes next.
            if len(decoded_lines) <= len(prompt_list) and not(has_rep_again and has_rep):
                if has_rep:
                    prompt_list.pop()
                    has_rep= False
                    has_rep_again = True
                else:
                    has_rep = True
                continue
            if has_rep_again and has_rep:
                decoded_line: str = decoded_lines[-1]
            else:
                decoded_line: str = decoded_lines[len(prompt_list)]
            # Record this rhyme group's constraints (same extraction as in the finishing loop).
            if format == 'VERSE_PAR' or format == 'OLD':
                if f"END_{j}" not in features_dict.keys() and len(decoded_line.split()) > 1 and j>=0 and decoded_line.count("#") <=1:
                    features_dict[f'LENGTH_{j}'] = decoded_line.split()[0]
                    features_dict[f'END_{j}'] = decoded_line.split()[1]
                elif f"END_{j}" not in features_dict.keys() and len(decoded_line.split()) > 2 and j>=0:
                    features_dict[f'LENGTH_{j}'] = decoded_line.split()[0]
                    features_dict[f'END_{j}'] = decoded_line.split()[2]
            elif format == 'METER_VERSE':
                if f"END_{j}" not in features_dict.keys() and len(decoded_line.split()) > 4 and j>=0:
                    features_dict[f'METER_{j}'] = decoded_line.split()[0]
                    features_dict[f'LENGTH_{j}'] = decoded_line.split()[2]
                    features_dict[f'END_{j}'] = decoded_line.split()[4]

            prompt_list.append(decoded_line)

        return "\n".join(prompt_list)
279
class PoetModelBase(PoetModelFunctionalInterface):
    """Plain causal-LM poet model: no auxiliary heads, loss comes straight from the LM."""

    # Config attributes probed, in order, for the hidden-layer width.
    _SIZE_ATTRS = ("n_embd", "hidden_size")

    def __init__(self, pretrainedModel, *args, **kwargs) -> None:
        """Load the pretrained causal LM and record its hidden size.

        Args:
            pretrainedModel: Hugging Face model id or local path.
        """
        super().__init__(*args, **kwargs)

        self.model = AutoModelForCausalLM.from_pretrained(pretrainedModel, output_hidden_states=True)

        # Hidden-layer width; defaults to 1 when neither known attribute exists.
        cfg = self.model.config
        self.model_size = 1
        for attr in self._SIZE_ATTRS:
            if hasattr(cfg, attr):
                self.model_size = getattr(cfg, attr)
                break

    def forward(self, input_ids=None, labels=None, attention_mask=None, *args, **kwargs):
        """Run the wrapped LM and repackage its result as a ModelOutput.

        Returns:
            ModelOutput with `loss` and the raw LM result under `model_output`.
        """
        lm_result = self.model(input_ids=input_ids, labels=labels, attention_mask=attention_mask)
        return ModelOutput(loss=lm_result.loss, model_output=lm_result)

    def save_LM(self, LM_path):
        """Persist only the wrapped LM (legacy pickle serialization, not safetensors)."""
        self.model.save_pretrained(LM_path, safe_serialization=False)
302
+
303
class PoetModelAllTasks(PoetModelFunctionalInterface):
    """Causal LM with five auxiliary heads (vowel count, rhyme, verse ending,
    meter, year), each contributing a 0.1-weighted term to the total loss."""
    def __init__(self, pretrainedModel, *args, **kwargs) -> None:
        """Load the pretrained LM and attach one linear head per auxiliary task.

        Args:
            pretrainedModel: Hugging Face model id or local path.
        """
        super().__init__(*args, **kwargs)

        self.model = AutoModelForCausalLM.from_pretrained(pretrainedModel, output_hidden_states=True)

        model_config = self.model.config
        self.model_size = 1
        # Check for hidden-layer size by attribute name (GPT-2 vs BERT-style configs).
        if hasattr(model_config, "n_embd"):
            self.model_size = model_config.n_embd
        elif hasattr(model_config, "hidden_size"):
            self.model_size = model_config.hidden_size

        self.vowels_regressor = torch.nn.Linear(self.model_size,1) # Vowel Count
        self.rhyme_regressor = torch.nn.Linear(self.model_size, len(StropheParams.RHYME)) # Rhyme Type
        self.verse_endings = torch.nn.Linear(self.model_size, len(StropheParams.ENDS)) # Verse End Syllable
        self.metre_regressor = torch.nn.Linear(self.model_size,len(StropheParams.METER)) # Meter Type
        self.year_regressor = torch.nn.Linear(self.model_size,len(StropheParams.YEAR)) # Year Bucket


    def forward(self, input_ids=None, labels=None, attention_mask=None, nums=None, rhyme=None, verse_end=None, year=None, metre=None, *args, **kwargs):
        """LM forward pass plus the auxiliary heads.

        Each auxiliary loss is computed only when its label tensor is given and
        is added to the LM loss with weight 0.1. All heads read the first
        position of the last hidden layer.

        Returns:
            dict with the LM output, each head's output and loss, and `loss`.
        """
        outputs = self.model(input_ids=input_ids, labels=labels, attention_mask=attention_mask)
        last_hidden = outputs['hidden_states'][-1]
        # All heads share the representation of position 0 (first token).
        vowel_regression = self.vowels_regressor((last_hidden[:,0,:].view(-1, self.model_size)))
        rhyme_regression = self.rhyme_regressor((last_hidden[:,0,:].view(-1, self.model_size)))
        verse_end_reg = self.verse_endings((last_hidden[:,0,:].view(-1, self.model_size)))
        metre_regression = self.metre_regressor((last_hidden[:,0,:].view(-1, self.model_size)))
        year_regression = self.year_regressor((last_hidden[:,0,:].view(-1, self.model_size)))
        full_loss = outputs.loss

        vowel_loss = None
        if nums is not None:
            loss_fct = torch.nn.MSELoss()
            vowel_loss = loss_fct(vowel_regression.view(-1, 1), nums.view(-1, 1))
            full_loss = full_loss + 0.1*vowel_loss

        # NOTE(review): CrossEntropyLoss expects raw logits; applying softmax first
        # (here and in the heads below) re-normalizes probabilities inside the loss.
        # Kept as-is since trained checkpoints depend on it — confirm intent.
        rhyme_loss = None
        if rhyme is not None:
            softmaxed = torch.softmax(rhyme_regression, dim=1)
            loss_fct = torch.nn.CrossEntropyLoss()
            rhyme_loss = loss_fct(softmaxed, rhyme)
            full_loss = full_loss + 0.1*rhyme_loss

        verse_loss = None
        if verse_end is not None:
            softmaxed = torch.softmax(verse_end_reg, dim=1)
            loss_fct = torch.nn.CrossEntropyLoss()
            verse_loss = loss_fct(softmaxed, verse_end)
            full_loss = full_loss + 0.1*verse_loss

        metre_loss = None
        if metre is not None:
            softmaxed = torch.softmax(metre_regression, dim=1)
            loss_fct = torch.nn.CrossEntropyLoss()
            metre_loss = loss_fct(softmaxed, metre)
            full_loss = full_loss + 0.1*metre_loss

        year_loss = None
        if year is not None:
            softmaxed = torch.softmax(year_regression, dim=1)
            loss_fct = torch.nn.CrossEntropyLoss()
            year_loss = loss_fct(softmaxed, year)
            full_loss = full_loss + 0.1*year_loss


        return {"model_output" : outputs,
                "vowel_regression_output": vowel_regression,
                "vowel_regression_loss": vowel_loss,
                "rhyme_regression_output": rhyme_regression,
                "rhyme_regression_loss": rhyme_loss,
                "verse_end_regression_output" : verse_end_reg,
                "verse_end_regression_loss" : verse_loss,
                "metre_regression_output" : metre_regression,
                "metre_regression_loss" : metre_loss,
                "year_regression_output" : year_regression,
                "year_regression_loss" : year_loss,
                "loss": full_loss}

    def save_LM(self, LM_path):
        """Persist only the wrapped LM; the auxiliary heads are not saved here."""
        self.model.save_pretrained(LM_path, safe_serialization=False)
384
+
385
+ from .poet_model_utils import ContextModule
386
+
387
class PoetModelContextInput(PoetModelFunctionalInterface):
    """Causal LM with a ContextModule spliced into the transformer stack and a
    rhyme classification head. Context is passed to the inserted block via
    attribute injection, since the HF forward signature can't carry it."""
    def __init__(self, pretrainedModel, context_input_size:int = 2048, block_count:int=3, *args, **kwargs) -> None:
        """Load the pretrained LM and insert the context block at position 3.

        Args:
            pretrainedModel: Hugging Face model id or local path.
            context_input_size (int): max context sequence length fed to the ContextModule.
            block_count (int): number of internal blocks inside the ContextModule.
        """
        super().__init__(*args, **kwargs)

        self.model = AutoModelForCausalLM.from_pretrained(pretrainedModel,output_hidden_states=True)

        model_config = self.model.config
        self.model_size = -1
        # Check for hidden-layer size by attribute name (GPT-2 vs BERT-style configs).
        if hasattr(model_config, "n_embd"):
            self.model_size = model_config.n_embd
        elif hasattr(model_config, "hidden_size"):
            self.model_size = model_config.hidden_size # Number of embeddings taken from config
        self.context_size = context_input_size


        self.model.base_model.h.insert(3, ContextModule(block_count, context_input_size, self.model_size, self.model_size))
        # Because of the inserted layer, head masks don't match => add 1 more layer to the config.
        self.model.base_model.config.n_layer += 1

        self.rhyme_regressor = torch.nn.Linear(self.model_size, len(StropheParams.RHYME)) # Rhyme Type


    def forward(self, input_ids=None, labels=None, attention_mask=None, rhyme=None, context_ids=None, context_attention_mask=None,*args, **kwargs):
        """LM forward pass with context injection and a rhyme loss.

        Returns:
            dict with the LM output, the rhyme head output/loss, and `loss`.
        """
        # Inject context to bypass the GPT2Block forward signature (can't forward it directly).
        self.model.base_model.h[3].context_ids = context_ids
        self.model.base_model.h[3].context_attention_mask = context_attention_mask

        outputs = self.model(input_ids=input_ids, labels=labels, attention_mask=attention_mask)
        last_hidden = outputs['hidden_states'][-1]
        # Rhyme head reads position 0 of the last hidden layer.
        rhyme_regression = self.rhyme_regressor((last_hidden[:,0,:].view(-1, self.model_size)))
        full_loss = outputs.loss

        # NOTE(review): CrossEntropyLoss expects raw logits; softmax first re-normalizes
        # inside the loss. Kept as-is (matches the rest of the file) — confirm intent.
        rhyme_loss = None
        if rhyme is not None:
            softmaxed = torch.softmax(rhyme_regression, dim=1)
            loss_fct = torch.nn.CrossEntropyLoss()
            rhyme_loss = loss_fct(softmaxed, rhyme)
            full_loss = full_loss + rhyme_loss
        # Clear the injected context so stale tensors don't leak into the next call.
        self.model.base_model.h[3].context_ids = None
        self.model.base_model.h[3].context_attention_mask = None

        return {"model_output" : outputs,
                "rhyme_regression_output": rhyme_regression,
                "rhyme_regression_loss": rhyme_loss,
                "loss": full_loss}

    def save_LM(self, LM_path):
        """Persist only the wrapped LM (with the inserted context block)."""
        self.model.save_pretrained(LM_path)
437
+
438
+ from .poet_model_utils import PoetTypeModule
439
+
440
class PoetModelContextYear(PoetModelFunctionalInterface):
    """Causal LM with spliced-in context/year sub-modules plus rhyme and year heads.

    Two auxiliary GPT-2 stacks are inserted into the host model's block list:
    after both inserts, index 3 holds a PoetTypeModule (year-bucket classifier)
    and index 4 a ContextModule (context injection). GPT2Blocks cannot forward
    extra arguments, so the context tensors are injected as attributes before
    the host forward and cleared afterwards.
    """
    def __init__(self, pretrainedModel, context_input_size:int = 2048, block_count:int=3, *args, **kwargs) -> None:
        """Load the pretrained LM, splice in the auxiliary modules and build the heads.

        Args:
            pretrainedModel (str): Name or path of the pretrained causal LM
            context_input_size (int, optional): Max context length for the sub-modules. Defaults to 2048.
            block_count (int, optional): Number of GPT2Blocks in each sub-module. Defaults to 3.
        """
        super().__init__(*args, **kwargs)

        self.model = AutoModelForCausalLM.from_pretrained(pretrainedModel, output_hidden_states=True)

        model_config = self.model.config
        self.model_size = -1
        # Check for Hidden layer size by Attribute Name
        if hasattr(model_config, "n_embd"):
            self.model_size = model_config.n_embd
        elif hasattr(model_config, "hidden_size"):
            self.model_size = model_config.hidden_size # Number of Emmbedings taken from config
        self.context_size = context_input_size


        # The second insert at index 3 pushes the ContextModule to index 4.
        self.model.base_model.h.insert(3, ContextModule(block_count, context_input_size, self.model_size, self.model_size))
        self.model.base_model.h.insert(3, PoetTypeModule(block_count, context_input_size, self.model_size, self.model_size))
        # Because of the two inserted layers, head masks don't match => reflect the new depth
        self.model.base_model.config.n_layer += 2

        self.rhyme_regressor = torch.nn.Linear(self.model_size, len(StropheParams.RHYME)) # Rhyme Type
        self.year_regressor = torch.nn.Linear(self.model_size, len(StropheParams.YEAR)) # Year Bucket


    def forward(self, input_ids=None, labels=None, attention_mask=None, rhyme=None, context_ids=None, context_attention_mask=None, year=None,*args, **kwargs):
        """Run the LM with injected context; add rhyme/year losses when labels are given.

        Args:
            input_ids: LM inputs. Defaults to None.
            labels: Language-model labels. Defaults to None.
            attention_mask: Padding mask. Defaults to None.
            rhyme: One-hot rhyme-schema targets. Defaults to None.
            context_ids: Context token ids for the spliced sub-modules. Defaults to None.
            context_attention_mask: Mask for the context ids. Defaults to None.
            year: One-hot year-bucket targets (also teacher-forced into h[3]). Defaults to None.

        Returns:
            dict: model output, per-head outputs/losses and the combined loss
        """
        # Inject Context to bypass GPT2Blocks (Can't Forward it)
        self.model.base_model.h[3].context_ids = context_ids
        self.model.base_model.h[3].context_attention_mask = context_attention_mask
        self.model.base_model.h[3].type_labels = year

        self.model.base_model.h[4].context_ids = context_ids
        self.model.base_model.h[4].context_attention_mask = context_attention_mask

        outputs = self.model(input_ids=input_ids, labels=labels, attention_mask=attention_mask)
        last_hidden = outputs['hidden_states'][-1]
        # Both heads read the first-token hidden state of the last layer.
        rhyme_regression = self.rhyme_regressor((last_hidden[:,0,:].view(-1, self.model_size)))
        full_loss = outputs.loss

        rhyme_loss = None
        if rhyme is not None:
            # NOTE(review): softmax is applied before CrossEntropyLoss, which
            # itself applies log-softmax — confirm this is intended.
            softmaxed = torch.softmax(rhyme_regression, dim=1)
            loss_fct = torch.nn.CrossEntropyLoss()
            rhyme_loss = loss_fct(softmaxed, rhyme)
            full_loss = full_loss + rhyme_loss


        year_regression = self.year_regressor((last_hidden[:,0,:].view(-1, self.model_size)))

        year_loss = None
        if year is not None:
            softmaxed = torch.softmax(year_regression, dim=1)
            loss_fct = torch.nn.CrossEntropyLoss()
            year_loss = loss_fct(softmaxed, year)
            # h[3] (PoetTypeModule) stored its own classification loss during the host forward.
            full_loss = full_loss + year_loss + self.model.base_model.h[3].indiv_loss

        # Delete the Injection to prevent Dataloss
        self.model.base_model.h[3].context_ids = None
        self.model.base_model.h[3].context_attention_mask = None
        self.model.base_model.h[3].type_labels = None
        # Delete Loss
        self.model.base_model.h[3].indiv_loss = None

        self.model.base_model.h[4].context_ids = None
        self.model.base_model.h[4].context_attention_mask = None

        return {"model_output" : outputs,
                "rhyme_regression_output": rhyme_regression,
                "rhyme_regression_loss": rhyme_loss,
                "year_regression_output" : year_regression,
                "year_loss" : year_loss,
                "loss": full_loss}

    def save_LM(self, LM_path):
        """Save only the underlying raw LM (heads and sub-modules excluded).

        Args:
            LM_path (str): Directory to store the LM via `save_pretrained`
        """
        self.model.save_pretrained(LM_path)
515
+
516
+
517
class DistilModel(PoetModelFunctionalInterface):
    """Layer-dropping distillation wrapper: keeps every odd block of a 12-layer LM.

    The student is created by deleting teacher blocks not listed in
    `kept_states`; training matches the student's hidden states against the
    teacher's corresponding states (supplied from outside via
    `to_replicate_states`).
    """

    def __init__(self, pretrainedModel, *args, **kwargs) -> None:
        """Load the pretrained LM and drop every block not in `kept_states`.

        Args:
            pretrainedModel (str): Name or path of the pretrained causal LM
        """
        super().__init__(*args, **kwargs)

        self.model = AutoModelForCausalLM.from_pretrained(pretrainedModel, output_hidden_states=True)

        model_config = self.model.config
        # -1 marks "unknown"; kept consistent with the sibling model wrappers.
        self.model_size = -1
        # Check for Hidden layer size by Attribute Name
        if hasattr(model_config, "n_embd"):
            self.model_size = model_config.n_embd
        elif hasattr(model_config, "hidden_size"):
            self.model_size = model_config.hidden_size

        # Indexes of teacher blocks the student keeps (every odd layer of 12).
        self.kept_states = [1, 3, 5, 7, 9, 11]

        # Pop from the back so earlier indexes stay valid while removing.
        for pop_index in sorted(list(set(range(len(self.model.base_model.h))) - set(self.kept_states)), reverse=True):
            self.model.base_model.h.pop(pop_index)
        # Layers were removed => config must reflect the new depth.
        self.model.base_model.config.n_layer = len(self.kept_states)

        self.loss_fnc = torch.nn.MSELoss()

    def forward(self, input_ids=None, labels=None, attention_mask=None, to_replicate_states= None, *args, **kwargs):
        """Compute LM loss plus hidden-state distillation loss against the teacher.

        Args:
            input_ids: Student LM inputs. Defaults to None.
            labels: Language-model labels. Defaults to None.
            attention_mask: Padding mask. Defaults to None.
            to_replicate_states: Teacher hidden states (embeddings + 12 layers). Defaults to None.

        Returns:
            dict: raw model output and the combined loss
        """
        outputs = self.model(input_ids=input_ids, labels=labels, attention_mask=attention_mask)
        loss = outputs.loss
        # The 6 kept layers + embeddings (the +1 shifts past the embedding state).
        # Out-of-place addition keeps the autograd graph intact (no in-place
        # mutation of a tensor that other ops may still need).
        for distil_index, original_index in enumerate([-1] + self.kept_states):
            loss = loss + self.loss_fnc(outputs['hidden_states'][distil_index], to_replicate_states[original_index + 1])

        return {"model_output" : outputs,
                "loss": loss}

    def save_LM(self, LM_path):
        """Persist the distilled LM (legacy .bin serialization kept for compatibility).

        Args:
            LM_path (str): Directory to store the LM
        """
        self.model.save_pretrained(LM_path, safe_serialization=False)

    def generate_forced(self, *args, **kwargs):
        """Forced generation is not supported by the distilled wrapper.

        Raises:
            NotImplementedError: always
        """
        raise NotImplementedError("Currently without")
557
+
558
class PoetModelHalfBase(PoetModelFunctionalInterface):
    """Plain causal-LM wrapper loaded in float16, with no auxiliary heads."""

    def __init__(self, pretrainedModel, *args, **kwargs) -> None:
        """Load the pretrained LM in half precision and record its hidden size.

        Args:
            pretrainedModel (str): Name or path of the pretrained causal LM
        """
        super().__init__(*args, **kwargs)

        self.model = AutoModelForCausalLM.from_pretrained(pretrainedModel, output_hidden_states=True, torch_dtype=torch.float16)

        # Resolve hidden size: GPT-2 style configs expose `n_embd`,
        # BERT-like configs expose `hidden_size`; -1 means "unknown".
        config = self.model.config
        self.model_size = -1
        for size_attr in ("n_embd", "hidden_size"):
            if hasattr(config, size_attr):
                self.model_size = getattr(config, size_attr)
                break

    def forward(self, input_ids=None, labels=None, attention_mask=None, *args, **kwargs):
        """Run the wrapped LM and surface its output and loss.

        Args:
            input_ids: LM inputs. Defaults to None.
            labels: Language-model labels. Defaults to None.
            attention_mask: Padding mask. Defaults to None.

        Returns:
            dict: raw model output and its loss
        """
        outputs = self.model(input_ids=input_ids, labels=labels, attention_mask=attention_mask)
        return {"model_output" : outputs,
                "loss" : outputs.loss}

    def save_LM(self, LM_path):
        """Persist the wrapped LM weights.

        Args:
            LM_path (str): Directory to store the LM
        """
        self.model.save_pretrained(LM_path)
581
+
582
+
583
class PoetModelSecondaryTasks(PoetModelFunctionalInterface):
    """Causal LM with two auxiliary heads: vowel count (MSE) and rhyme schema (CE).

    Both heads read the first-token hidden state of the last layer.
    """

    def __init__(self, pretrainedModel, *args, **kwargs) -> None:
        """Load the pretrained LM and attach the two auxiliary heads.

        Args:
            pretrainedModel (str): Name or path of the pretrained causal LM
        """
        super().__init__(*args, **kwargs)

        self.model = AutoModelForCausalLM.from_pretrained(pretrainedModel, output_hidden_states=True)

        # Resolve hidden size: GPT-2 configs expose `n_embd`, BERT-like
        # configs expose `hidden_size`; -1 means "unknown".
        config = self.model.config
        self.model_size = -1
        for size_attr in ("n_embd", "hidden_size"):
            if hasattr(config, size_attr):
                self.model_size = getattr(config, size_attr)
                break
        self.vowels_regressor = torch.nn.Linear(self.model_size, 1)                        # Vowel count head
        self.rhyme_regressor = torch.nn.Linear(self.model_size, len(StropheParams.RHYME))  # Rhyme schema head

    def forward(self, input_ids=None, labels=None, attention_mask=None, nums=None, rhyme=None, *args, **kwargs):
        """Run the LM; add an auxiliary loss for each label that is provided.

        Args:
            input_ids: LM inputs. Defaults to None.
            labels: Language-model labels. Defaults to None.
            attention_mask: Padding mask. Defaults to None.
            nums: Vowel-count regression targets. Defaults to None.
            rhyme: One-hot rhyme-schema targets. Defaults to None.

        Returns:
            dict: model output, per-head outputs/losses and the combined loss
        """
        outputs = self.model(input_ids=input_ids, labels=labels, attention_mask=attention_mask)
        cls_state = outputs['hidden_states'][-1][:, 0, :].view(-1, self.model_size)
        vowel_regression = self.vowels_regressor(cls_state)
        rhyme_regression = self.rhyme_regressor(cls_state)
        full_loss = outputs.loss

        vowel_loss = None
        if nums is not None:
            vowel_loss = torch.nn.MSELoss()(vowel_regression.view(-1, 1), nums.view(-1, 1))
            full_loss = full_loss + vowel_loss

        rhyme_loss = None
        if rhyme is not None:
            # NOTE(review): softmax is applied before CrossEntropyLoss (which
            # applies log-softmax itself) — kept to preserve original behavior.
            rhyme_loss = torch.nn.CrossEntropyLoss()(torch.softmax(rhyme_regression, dim=1), rhyme)
            full_loss = full_loss + rhyme_loss

        return {"model_output" : outputs,
                "vowel_regression_output": vowel_regression,
                "vowel_regression_loss": vowel_loss,
                "rhyme_regression_output": rhyme_regression,
                "rhyme_regression_loss": rhyme_loss,
                "loss": full_loss}

    def save_LM(self, LM_path):
        """Persist only the wrapped LM weights (heads excluded).

        Args:
            LM_path (str): Directory to store the LM
        """
        self.model.save_pretrained(LM_path)
630
+
631
+
632
class PoetModelVerseEnd(PoetModelFunctionalInterface):
    """Causal LM with three auxiliary heads: vowel count, rhyme schema and verse ending.

    All heads read the first-token hidden state of the last layer.
    """

    def __init__(self, pretrainedModel, *args, **kwargs) -> None:
        """Load the pretrained LM and attach the three auxiliary heads.

        Args:
            pretrainedModel (str): Name or path of the pretrained causal LM
        """
        super().__init__(*args, **kwargs)

        self.model = AutoModelForCausalLM.from_pretrained(pretrainedModel, output_hidden_states=True)

        # Resolve hidden size: GPT-2 configs expose `n_embd`, BERT-like
        # configs expose `hidden_size`; -1 means "unknown".
        config = self.model.config
        self.model_size = -1
        for size_attr in ("n_embd", "hidden_size"):
            if hasattr(config, size_attr):
                self.model_size = getattr(config, size_attr)
                break
        self.vowels_regressor = torch.nn.Linear(self.model_size, 1)                        # Vowel count head
        self.rhyme_regressor = torch.nn.Linear(self.model_size, len(StropheParams.RHYME))  # Rhyme schema head
        self.verse_endings = torch.nn.Linear(self.model_size, len(StropheParams.ENDS))     # Verse-end syllable head

    def forward(self, input_ids=None, labels=None, attention_mask=None, nums=None, rhyme=None, verse_end = None, *args, **kwargs):
        """Run the LM; add an auxiliary loss for each label that is provided.

        Args:
            input_ids: LM inputs. Defaults to None.
            labels: Language-model labels. Defaults to None.
            attention_mask: Padding mask. Defaults to None.
            nums: Vowel-count regression targets. Defaults to None.
            rhyme: One-hot rhyme-schema targets. Defaults to None.
            verse_end: One-hot verse-ending targets. Defaults to None.

        Returns:
            dict: model output, per-head outputs/losses and the combined loss
        """
        outputs = self.model(input_ids=input_ids, labels=labels, attention_mask=attention_mask)
        cls_state = outputs['hidden_states'][-1][:, 0, :].view(-1, self.model_size)
        vowel_regression = self.vowels_regressor(cls_state)
        rhyme_regression = self.rhyme_regressor(cls_state)
        verse_end_reg = self.verse_endings(cls_state)
        full_loss = outputs.loss

        vowel_loss = None
        if nums is not None:
            vowel_loss = torch.nn.MSELoss()(vowel_regression.view(-1, 1), nums.view(-1, 1))
            full_loss = full_loss + vowel_loss

        rhyme_loss = None
        if rhyme is not None:
            # NOTE(review): softmax precedes CrossEntropyLoss (which applies
            # log-softmax itself) — kept to preserve original behavior.
            rhyme_loss = torch.nn.CrossEntropyLoss()(torch.softmax(rhyme_regression, dim=1), rhyme)
            full_loss = full_loss + rhyme_loss

        verse_loss = None
        if verse_end is not None:
            verse_loss = torch.nn.CrossEntropyLoss()(torch.softmax(verse_end_reg, dim=1), verse_end)
            full_loss = full_loss + verse_loss

        return {"model_output" : outputs,
                "vowel_regression_output": vowel_regression,
                "vowel_regression_loss": vowel_loss,
                "rhyme_regression_output": rhyme_regression,
                "rhyme_regression_loss": rhyme_loss,
                "verse_end_regression_output" : verse_end_reg,
                "verse_end_regression_loss" : verse_loss,
                "loss": full_loss}

    def save_LM(self, LM_path):
        """Persist only the wrapped LM weights (heads excluded).

        Args:
            LM_path (str): Directory to store the LM
        """
        self.model.save_pretrained(LM_path)
utils/poet_model_utils.py ADDED
@@ -0,0 +1,272 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
class PoetModelInterface(torch.nn.Module):
    """Abstract base class for every poet model type.

    Inherits torch.nn.Module so subclasses integrate with torch and huggingface.
    """

    def __init__(self, *args, **kwargs) -> None:
        """Forward construction to torch.nn.Module."""
        super().__init__(*args, **kwargs)

    def forward(self, input_ids=None, labels=None, attention_mask=None, *args, **kwargs):
        """Compute model output and model loss.

        Args:
            input_ids: Model inputs. Defaults to None.
            labels: Language-model labels. Defaults to None.
            attention_mask: Mask marking where padding starts. Defaults to None.

        Raises:
            NotImplementedError: abstract method, must be overridden
        """
        raise NotImplementedError()

    def generate_forced(self, *args, **kwargs):
        """Generate output constrained by the inputs and past generation.

        Raises:
            NotImplementedError: abstract method, must be overridden
        """
        raise NotImplementedError()

    @staticmethod
    def rhyme_like(rhyme:str):
        """DEPRECATED: check whether a string looks like a rhyme schema.

        Args:
            rhyme (str): Candidate rhyme string

        Returns:
            bool: True for an all-uppercase string of length 4 or 6
        """
        return len(rhyme) in (4, 6) and rhyme.isupper()

    def save_LM(self, LM_path):
        """Save the raw LM.

        Args:
            LM_path (str): Where to store the LM

        Raises:
            NotImplementedError: abstract method, must be overridden
        """
        raise NotImplementedError()
58
+
59
+
60
+ from transformers import GPT2Config, GPT2Model
61
+ from .poet_utils import StropheParams
62
+
63
class ContextModule(torch.nn.Module):
    """Auxiliary GPT-2 stack that injects poem-context information into a host model.

    The module is spliced into the host model's list of GPT2Blocks, so its
    forward must mimic the GPT2Block output tuple. Context inputs cannot be
    forwarded through the host blocks, therefore they are injected from outside
    by assigning `context_ids` / `context_attention_mask` before the host
    forward and clearing them afterwards.
    """
    def __init__(self, block_count, input_size, n_embd, output_size, *args, **kwargs) -> None:
        """Construct the underlying small LM for context.

        Args:
            block_count: Number of GPT2Blocks in the context LM
            input_size: Maximum context input length (n_positions)
            n_embd: Hidden size of the context LM
            output_size: Size of the projection added to host hidden states
        """
        super().__init__(*args, **kwargs)
        # Head count scales with hidden size to keep 64-dim heads (768 // 12 == 64).
        self.config = GPT2Config(n_positions=input_size, n_head=(n_embd//(768//12)), n_embd=n_embd,
                                 n_layer=block_count, output_hidden_states=True, output_attentions=True)
        self.context_model = GPT2Model(self.config)
        self.linear_downscale = torch.nn.Linear(n_embd, output_size)
        self.input_size = input_size
        self.n_embd = n_embd
        self.output_size = output_size
        # Context is getting injected from outside (see class docstring).
        self.context_ids = None
        self.context_attention_mask = None

    def forward(self, hidden_states, layer_past=None, *args, **kwargs):
        """Compute the context LM output; data is injected from outside.

        Args:
            hidden_states (torch.Tensor): Current host hidden states
            layer_past: Past layer outputs (unused). Defaults to None.

        Returns:
            tuple: GPT2Block structured output (hidden states, layer past, (attentions, keys))
        """
        down = torch.zeros_like(hidden_states)
        model_output = None
        # Sometimes there might be no context; identity check avoids ambiguous
        # elementwise `!=` semantics when context_ids is a tensor.
        if self.context_ids is not None:
            model_output = self.context_model.forward(input_ids=self.context_ids, attention_mask=self.context_attention_mask)
            # Only the first (class) token of the last layer is projected and added.
            down = self.linear_downscale.forward(model_output["hidden_states"][-1][:, 0, :].view(-1, self.n_embd))[:, None, :]
        return (hidden_states + down,
                down[None, :, :, :],
                (None if model_output is None else model_output["attentions"],
                 None))
112
+
113
class PoetTypeModule(torch.nn.Module):
    """Auxiliary GPT-2 stack that classifies the poem's year bucket and feeds it back.

    Spliced into the host model's GPT2Block list, so its forward mimics the
    GPT2Block output tuple. Context ids and labels are injected from outside
    before the host forward; the classification loss is stored in `indiv_loss`
    for the host model to collect.
    """

    def __init__(self, block_count, input_size, n_embd, output_size, *args, **kwargs) -> None:
        """Construct the LM for poet-type classification.

        Args:
            block_count: Number of GPT2Blocks in the sub-model
            input_size: Maximum context input length (n_positions)
            n_embd: Hidden size of the sub-model
            output_size: Size of the projection added to host hidden states
        """
        super().__init__(*args, **kwargs)
        # Head count scales with hidden size to keep 64-dim heads (768 // 12 == 64).
        self.config = GPT2Config(n_positions=input_size, n_head=(n_embd//(768//12)), n_embd=n_embd,
                                 n_layer=block_count, output_hidden_states=True, output_attentions=True)
        self.type_model = GPT2Model(self.config)
        self.type_predict = torch.nn.Linear(n_embd, len(StropheParams.YEAR))
        # Explicit dim: implicit-dimension Softmax is deprecated; input is
        # (batch, classes), so dim=-1 matches the previous implicit choice.
        self.softmax = torch.nn.Softmax(dim=-1)
        self.linear_scale = torch.nn.Linear(len(StropheParams.YEAR), output_size)
        self.input_size = input_size
        self.n_embd = n_embd
        self.output_size = output_size
        # Context and labels are getting injected from outside (see class docstring).
        self.context_ids = None
        self.context_attention_mask = None
        self.type_labels = None
        # Classification loss stored for the host model to pick up.
        self.indiv_loss = None

    def forward(self, hidden_states, layer_past=None, *args, **kwargs):
        """Compute the year-bucket distribution and add its projection to hidden states.

        Args:
            hidden_states (torch.Tensor): Current host hidden states
            layer_past: Past layer outputs (unused). Defaults to None.

        Returns:
            tuple: GPT2Block structured output (hidden states, layer past, (attentions, keys))
        """
        # Follow the incoming hidden states' device instead of hard-coding CUDA:
        # forcing `.to("cuda")` whenever CUDA is available broke CPU-hosted
        # models on CUDA-capable machines.
        device = hidden_states.device
        type_prob = torch.zeros((hidden_states.shape[0], len(StropheParams.YEAR))).to(device)
        model_output = None
        # Sometimes there might be no context at all.
        if self.context_ids is not None:
            model_output = self.type_model.forward(input_ids=self.context_ids, attention_mask=self.context_attention_mask)
            # Only the first (class) token of the last layer is classified.
            poet_type = self.type_predict.forward(model_output["hidden_states"][-1][:, 0, :].view(-1, self.n_embd))
            type_prob = self.softmax.forward(poet_type)
        # Teacher forcing: with true labels present, store the loss and pass the
        # labels (not the predictions) on to the following blocks.
        if self.type_labels is not None:
            loss_fct = torch.nn.CrossEntropyLoss()
            # NOTE(review): loss is computed on softmaxed probabilities although
            # CrossEntropyLoss expects logits — preserved to keep training behavior.
            self.indiv_loss = loss_fct(type_prob, self.type_labels)
            type_prob = (self.type_labels.type(torch.FloatTensor)).to(device)
        linear_up = self.linear_scale.forward(type_prob)
        return (hidden_states + linear_up[:, None, :],
                linear_up[None, :, None, :],
                (None if model_output is None else model_output["attentions"],
                 None))
174
+
175
+ from transformers import PreTrainedTokenizerBase
176
+
177
class ModelManipulation:
    """Static Class incorporating methods for Manipulation with LMs
    Code Inspired by article: Fine-tuning the English GPT-2 in any language with Hugging Face
    Link: https://github.com/piegu/fastai-projects/blob/master/finetuning-English-GPT2-any-language-Portuguese-HuggingFace-fastaiv2.ipynb
    """

    @staticmethod
    def exchange_embedding(poet_model: PoetModelInterface, new_tokenizer: PreTrainedTokenizerBase, old_tokenizer: PreTrainedTokenizerBase, mirror_imbed:bool=False):
        """Exchange embedding matrixes for GPT2 Models

        Tokens shared by both vocabularies keep their trained embedding; new
        tokens are initialized with the mean of the old embedding matrix. The
        output head is rebuilt and tied to the new input embedding weights.

        Args:
            poet_model (PoetModelInterface): Model to manipulate with
            new_tokenizer (PreTrainedTokenizerBase): New tokenization
            old_tokenizer (PreTrainedTokenizerBase): Old tokenization
            mirror_imbed (bool, optional): NOTE(review) — accepted but never used here.
        """
        # Get old Embeddings
        if hasattr(poet_model.model, "transformer"):
            old_embed_in = poet_model.model.transformer.get_input_embeddings().weight.clone().detach()
        else:
            old_embed_in = poet_model.model.get_input_embeddings().weight.clone().detach()
        # Mean embedding acts as the init vector for tokens missing from the old vocab.
        old_mean_in = old_embed_in.mean(0)
        # Generate new Embedding based on new tokenization
        new_embd_in = old_embed_in.new_zeros(new_tokenizer.vocab_size, old_embed_in.size(1))
        old_vocab = old_tokenizer.get_vocab()

        vocab_hit = 0
        # Keep as much from old Embeddings as possible
        for w, idx_new in new_tokenizer.get_vocab().items():
            idx_old = old_vocab.get(w, -1)
            if idx_old >= 0:
                new_embd_in[idx_new] = old_embed_in[idx_old]
                vocab_hit +=1
            else:
                new_embd_in[idx_new] = old_mean_in

        # NOTE(review): the hit count iterates the NEW vocab but the denominator
        # is the OLD tokenizer's size — confirm the intended ratio.
        print(f"Vocab hit rate: {vocab_hit}/{old_tokenizer.vocab_size}")
        #Exchange Embeddings and Decoding
        new_embd_layer_in = torch.nn.Embedding(new_tokenizer.vocab_size, old_embed_in.size(1))
        new_embd_layer_in.weight.data = new_embd_in
        if hasattr(poet_model.model, "transformer"):
            poet_model.model.transformer.set_input_embeddings(new_embd_layer_in)
        else:
            poet_model.model.set_input_embeddings(new_embd_layer_in)

        # New output head shares (ties) weights with the freshly-set input embeddings.
        new_decoder = torch.nn.Linear( old_embed_in.size(1), new_tokenizer.vocab_size, bias=False)
        if hasattr(poet_model.model, "transformer"):
            new_decoder.weight = poet_model.model.transformer.wte.weight
        else:
            new_decoder.weight = poet_model.model.base_model.embeddings.weight
        if hasattr(poet_model.model, "lm_head"):
            poet_model.model.lm_head = new_decoder
        else:
            poet_model.model.head = new_decoder


        # Update LM config to reflect possible change in vocab size
        poet_model.model.config.vocab_size = new_tokenizer.vocab_size


    @staticmethod
    def exchange_embedding_roberta(metre_model, new_tokenizer: PreTrainedTokenizerBase, old_tokenizer: PreTrainedTokenizerBase):
        """Exchange embedding matrixes for Roberta Models

        Same token-transfer scheme as `exchange_embedding`, but for RoBERTa-style
        models (the decoder of `lm_head` is rebuilt and tied).

        Args:
            metre_model: Model to manipulate with
            new_tokenizer (PreTrainedTokenizerBase): New tokenization
            old_tokenizer (PreTrainedTokenizerBase): Old tokenization
        """
        # Get old Embeddings
        old_embed = metre_model.model.get_input_embeddings().weight.clone().detach()
        old_mean = old_embed.mean(0)
        # Generate new Embedding based on new tokenization
        new_embd = old_embed.new_zeros(new_tokenizer.vocab_size, old_embed.size(1))
        old_vocab = old_tokenizer.get_vocab()

        vocab_hit = 0
        # Keep as much from old Embeddings as possible
        for w, idx_new in new_tokenizer.get_vocab().items():
            idx_old = old_vocab.get(w, -1)
            if idx_old >= 0:
                new_embd[idx_new] = old_embed[idx_old]
                vocab_hit +=1
            else:
                new_embd[idx_new] = old_mean

        # NOTE(review): denominator uses the OLD tokenizer's size — see note above.
        print(f"Vocab hit rate: {vocab_hit}/{old_tokenizer.vocab_size}")
        #Exchange Embeddings and Decoding
        new_embd_layer = torch.nn.Embedding(new_tokenizer.vocab_size, old_embed.size(1))
        new_embd_layer.weight.data = new_embd
        metre_model.model.set_input_embeddings(new_embd_layer)
        new_decoder = torch.nn.Linear( old_embed.size(1), new_tokenizer.vocab_size)
        new_decoder.weight = metre_model.model.roberta.embeddings.word_embeddings.weight
        metre_model.model.lm_head.decoder = new_decoder
        # Update LM config to reflect possible change in vocab size
        metre_model.model.config.vocab_size = new_tokenizer.vocab_size
272
+
utils/poet_utils.py ADDED
@@ -0,0 +1,591 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
class StropheParams:
    """Constant categorical lists describing strophe properties.

    Each classification list ends with `None`, which serves as the
    "unknown/other" bucket for the corresponding model head.
    """

    # Most Common Rhyme Schemas (Every Rhyme schema with presence over 0.36 %)
    RHYME_SCHEMES = ['ABAB', 'XXXX',
                     'XAXA','AABB',
                     'XXXXXX','ABBA',
                     'AAXX', 'AABBCC',
                     'ABABCC','ABABXX',
                     'AABCCB','XXAA',
                     'XAAX', 'AXAX',
                     'XAXAXX','XXABAB',
                     'ABBACC','AXAA',
                     'XAABBX','AABCBC',
                     'AABBXX','ABBAXX',
                     'ABABAB','AAXA',
                     'AXXA','XAXABB',
                     'XXAABB','XXAAXX',
                     'ABABAX','XXABBA',
                     'AAXBBX','XXXAXA',
                     'AAAX','XABABX',
                     'XABBAX','AAXXBB',
                     'AXABBX','ABABBX',
                     'XAAXBB','AAAA',
                     'XAAA','XAABXB',
                     'AXABXB','AXAXBB',
                     None]
    # Short alias used across the project.
    RHYME = RHYME_SCHEMES

    # Schemas with no unrhymed (X) positions.
    NORMAL_SCHEMES = ["ABAB", "ABBA", "AABB", "AABBCC", "ABABCC", "ABBACC", "ABBAAB"]

    # First 200 Most common endings
    VERSE_ENDS = ['ní', 'la', 'je', 'tí', 'ce', 'ti', 'ky', 'ku', 'li', 'jí', 'ně', 'né', 'vá', 'se', 'ny', 'ly', 'na', 'ne', 'nou',
                  'lo', 'ci', 'mi', 'ný', 'sti', 'ka', 'le', 'cí', 'ná', 'ží', 'čí', 'ho', 'dí', 'ší', 'du', 'lí', 'dy', 'nu', 'ří',
                  'ji', 'ru', 'tě', 'ře', 'stí', 'vy', 'ká', 'še', 'dá', 'ni', 'te', 'ví', 'mu', 'tu', 'ta', 'vé', 'val', 'va', 'lý',
                  'tá', 'že', 'ty', 'no', 'vu', 'lá', 'kem', 'chu', 'ků', 'bě', 'vý', 'sy', 'me', 'zí', 'hu', 'vě', 'lu', 'da', 'ry',
                  'rá', 'lé', 'ko', 'ři', 'de', 'hy', 'lem', 'tem', 'kou', 'vou', 'ši', 'há', 'sí', 'ze', 'be', 'ra', 'má', 'to', 'by',
                  'mě', 'su', 'té', 'si', 'ných', 'den', 'či', 'ký', 'ním', 'če', 'tý', 'ma', 'my', 'sem', 'nem', 'dě', 'ha', 'vat', 'ným',
                  'dem', 'dou', 'sta', 'dla', 'svět', 'zem', 'jen', 'dal', 'mí', 'hou', 'zas', 'sen', 'rem', 'nů', 'bu', 'e', 'ba', 'ké',
                  'til', 'jest', 'ství', 'děl', 'květ', 'tů', 'chem', 'lou', 'sám', 'bí', 'tou', 'dé', 'šel', 'nul', 'chá', 'vem', 'sa',
                  'hlas', 'pí', 'čas', 'dil', 'let', 'cích', 'lů', 'žil', 'mů', 'dál', 'cha', 'byl', 'nost', 'ček', 'zy', 'hý', 'nám', 'di',
                  'bou', 'tím', 'ži', 'tek', 'vil', 'jsem', 'sů', 'dech', 'men', 'tla', 'sá', 'zrak', 'chy', 'vám', 'vi', 'dý', 'rád', 'svou',
                  'ném', 've', 'py', 'vo', 'vým', 'nek', 'již', 'víc', 'kal', 'mé', 'dů', 'stá', 'dnes', 'sty', 'ven', None]
    # Short alias used across the project.
    ENDS = VERSE_ENDS
    # Years to bucket to
    POET_YEARS_BUCKETS = [1800, 1820, 1840, 1860, 1880, 1900, 1920, 1940, 1960, None]
    # Short aliases used across the project.
    POET_YEARS = POET_YEARS_BUCKETS
    YEAR = POET_YEARS_BUCKETS
    # Possible Meter Types
    METER_TYPES = ["J","T","D","A","X","Y","N","H","P", None]
    METER = METER_TYPES
    # Translation of Meter to one char types
    METER_TRANSLATE = {
        "J":"J",
        "T":"T",
        "D":"D",
        "A":"A",
        "X":"X",
        "Y":"Y",
        "hexameter": "H",
        "pentameter": "P",
        "N":"N"
    }

    # Basic Characters to consider in rhyme and syllables (43)
    VALID_CHARS = [""," ",'a','á','b','c','č','d','ď','e','é','ě',
                   'f','g','h','i','í','j','k','l','m','n','ň',
                   'o','ó','p','q','r','ř','s','š','t','ť','u',
                   'ú','ů','v','w','x','y','ý','z','ž']
    # Short alias used across the project.
    CHARS = VALID_CHARS
73
class Tokens:
    """Special tokens shared by the project tokenizers, plus their fixed ids."""
    EOS = "<|EOS|>"
    EOS_ID = 0
    PAD = "<|PAD|>"
    PAD_ID = 1
    UNK = "<|UNK|>"
    UNK_ID = 2
    CLS = "<|CLS|>"
    CLS_ID = 3
    # The separator token is shared with the end-of-sequence token.
    SEP = EOS
    SEP_ID = EOS_ID

    # Lookup of every special token string to its id.
    ALL_TOKENS = {
        EOS: EOS_ID,
        PAD: PAD_ID,
        UNK: UNK_ID,
        CLS: CLS_ID,
    }
93
+
94
+
95
+
96
+ import re
97
+ import numpy as np
98
+
99
def parse_boolean(value):
    """Parse a truthy string into a bool.

    Recognized spellings (case-insensitive): "true", "yes", "y", "1", "t".
    Anything else — including explicit negatives — yields False.

    Args:
        value (str): string to interpret

    Returns:
        bool: True only for a recognized truthy spelling
    """
    return value.lower() in ["true", "yes", "y", "1", "t"]
108
+
109
class TextManipulation:
    """Static class for string manipulation methods

    Returns:
        _type_: str returned by all methods
    """

    @staticmethod
    def _remove_most_nonchar(raw_text, lower_case=True):
        """Remove most non-alpha non-whitespace characters

        Args:
            raw_text (str): Text to manipulate
            lower_case (bool, optional): If resulting text should be lowercase. Defaults to True.

        Returns:
            str: Cleaned up text
        """
        text = re.sub(r'[–\„\“\’\;\:()\]\[\_\*\‘\”\'\-\—\"]+', "", raw_text)
        return text.lower() if lower_case else text

    @staticmethod
    def _remove_all_nonchar(raw_text):
        """Remove all possible non-alpha characters

        Args:
            raw_text (str): Text to manipulate

        Returns:
            str: Cleaned up text
        """
        return re.sub(r'([^\w\s]+|[0-9]+)', '', raw_text)

    @staticmethod
    def _year_bucketor(raw_year):
        """Bucketizes year string to boundaries, Bad inputs returns NaN string

        Args:
            raw_year (str): Year string to bucketize

        Returns:
            str: Bucketized year string (or "NaN" for bad input)
        """
        if TextAnalysis._is_year(raw_year) and raw_year != "NaN":
            # Snap to the closest bucket boundary (last entry is the None sentinel).
            year_index = np.argmin(np.abs(np.asarray(StropheParams.YEAR[:-1]) - int(raw_year)))
            return str(StropheParams.YEAR[year_index])
        return "NaN"

    # Ordered rhyme-position markers; anything outside this range maps to "X".
    _RHYME_POS = ["A", "B", "C", "D", "E", "F", "G", "H"]

    @staticmethod
    def rhyme_sec(rhyme_ref, current_rhyme):
        """Return proper rhyme indicator to given reference

        Args:
            rhyme_ref: reference number of 'A'
            current_rhyme: current rhyme number that needs indication

        Returns:
            str: rhyme indicator character ("X" for unrhymed/out-of-range)
        """
        if (current_rhyme is None or current_rhyme == -1 or rhyme_ref is None
                or current_rhyme < rhyme_ref
                or current_rhyme >= rhyme_ref + len(TextManipulation._RHYME_POS)):
            return "X"
        return TextManipulation._RHYME_POS[current_rhyme - rhyme_ref]

    @staticmethod
    def __post_process_rhyme(rhyme_str: str):
        """Normalize a raw rhyme schema string (singletons -> X, compact and order markers)."""
        # First Pass: replace all markers that occur only once with X.
        marker_count = {marker: rhyme_str.count(marker) for marker in TextManipulation._RHYME_POS}
        for key, val in marker_count.items():
            if val == 1:
                rhyme_str = re.sub(key, 'X', rhyme_str)
        # Downscale higher markers to lower ones when the lower is absent.
        marker_count = {marker: rhyme_str.count(marker) for marker in TextManipulation._RHYME_POS}
        for key, val in marker_count.items():
            if val > 1 and key != 'X':
                key_index = TextManipulation._RHYME_POS.index(key)
                replacements = {marker: rhyme_str.count(marker) for marker in TextManipulation._RHYME_POS[:key_index]}
                for rep_key, rep_val in replacements.items():
                    if rep_val == 0:
                        rhyme_str = re.sub(key, rep_key, rhyme_str)
                        break

        # Final pass: swap markers so that they first appear in alphabetical order.
        marker_index = {marker: rhyme_str.find(marker) for marker in TextManipulation._RHYME_POS if rhyme_str.find(marker) != -1}
        keys_values = marker_index.items()
        keys = [v[0] for v in keys_values]
        values = [v[1] for v in keys_values]

        i = 0
        while i < len(keys):
            j = 0
            while j < len(keys):
                if TextManipulation._RHYME_POS.index(keys[j]) > TextManipulation._RHYME_POS.index(keys[i]) and values[j] < values[i]:
                    # Swap the two markers, using 'Z' as a temporary placeholder.
                    rhyme_str = re.sub(keys[j], 'Z', rhyme_str)
                    rhyme_str = re.sub(keys[i], keys[j], rhyme_str)
                    rhyme_str = re.sub('Z', keys[i], rhyme_str)
                    # Keep the recorded first positions in sync.
                    values[i], values[j] = values[j], values[i]
                j += 1
            i += 1

        return rhyme_str

    @staticmethod
    def _rhyme_string(curr_rhyme_list):
        """Translate rhyme as list of rhyming number to rhyme schema

        Args:
            curr_rhyme_list (list): Current rhyme as list of ints indicating rhyming verses

        Returns:
            str: Rhyme schema
        """
        rhyme_list = curr_rhyme_list.copy()
        reference = None
        # Track the smallest rhyme id as reference; give None a blank -1 rhyme id.
        for i in range(len(rhyme_list)):
            if rhyme_list[i] is not None and reference is None:
                reference = rhyme_list[i]
            elif rhyme_list[i] is not None and rhyme_list[i] < reference:
                reference = rhyme_list[i]
            elif rhyme_list[i] is None:
                rhyme_list[i] = -1

        # If there is a valid rhyme, normalize ids to be consecutive around the reference.
        if reference is not None:
            cheat_sheet = sorted(set(rhyme_list))
            ref_index = cheat_sheet.index(reference)
            for i in range(len(rhyme_list)):
                idx = cheat_sheet.index(rhyme_list[i])
                rhyme_list[i] = reference + (idx - ref_index)

        rhyme_str = ""
        for num in rhyme_list:
            rhyme_str += TextManipulation.rhyme_sec(reference, num)

        return TextManipulation.__post_process_rhyme(rhyme_str)
259
+
260
+ class TextAnalysis:
261
+ """Static class with methods of analysis of strings
262
+
263
+ Returns:
264
+ Union[str, bool, dict, numpy.ndarray]: Analyzed input
265
+ """
266
+
267
+ # Possible Keys if returned type is dict
268
+ POET_PARAM_LIST = ["RHYME", "YEAR", "METER", "LENGTH", "END", "TRUE_LENGTH", "TRUE_END"]
269
+
270
    @staticmethod
    def _is_meter(meter:str):
        """Return if string is meter type

        Args:
            meter (str): string to analyze

        Returns:
            bool: If string is a known meter type (the trailing None sentinel is excluded)
        """
        return meter in StropheParams.METER[:-1]
281
+
282
+ @staticmethod
283
+ def _is_year(year:str):
284
+ """Return if string is year or special NaN
285
+
286
+ Args:
287
+ year (str): string to analyze
288
+
289
+ Returns:
290
+ bool: If string is year or special NaN
291
+ """
292
+ return (year.isdecimal() and int(year) > 1_000 and int(year) < 10_000) or year == "NaN"
293
+
294
+ @staticmethod
295
+ def _rhyme_like(rhyme:str):
296
+ """Return if string is structured like rhyme schema
297
+
298
+ Args:
299
+ rhyme (str): string to analyze
300
+
301
+ Returns:
302
+ bool: If string is structured like rhyme schema
303
+ """
304
+ return (rhyme.isupper() and len(rhyme) >= 3 and len(rhyme) <= 6)
305
+
306
+ @staticmethod
307
+ def _rhyme_vector(rhyme:str) -> np.ndarray:
308
+ """Create One-hot encoded rhyme schema vector from given string
309
+
310
+ Args:
311
+ rhyme (str): string to construct vector from
312
+
313
+ Returns:
314
+ numpy.ndarray: One-hot encoded rhyme schema vector
315
+ """
316
+
317
+ rhyme_vec = np.zeros(len(StropheParams.RHYME))
318
+ if rhyme in StropheParams.RHYME:
319
+ rhyme_vec[StropheParams.RHYME.index(rhyme)] = 1
320
+ else:
321
+ rhyme_vec[-1] = 1
322
+
323
+ return rhyme_vec
324
+
325
+
326
+ @staticmethod
327
+ def _publish_year_vector(year_string):
328
+ """Construct vector of year of publishing, weighting by distance
329
+
330
+ Args:
331
+ year_string (str): String with publish year
332
+
333
+ Returns:
334
+ numpy.ndarray: Vector of bucketized One-hot encoded publish year
335
+ """
336
+ publish_year = None if not year_string.isdigit() else int(year_string)
337
+ publish_vector = np.zeros(len(StropheParams.YEAR))
338
+ if publish_year == None:
339
+ publish_vector[-1] = 1
340
+ else:
341
+ # Distance Part
342
+ #distance_weighting = [1/(1 + abs(year - publish_year)) for year in POET_YEARS_BUCKETS[:-1]] + [0]
343
+ #publish_vector = np.asarray(distance_weighting)
344
+ # Correct class correction
345
+ publish_vector[np.argmin( abs(np.asarray(StropheParams.YEAR[:-1]) - publish_year))] += 1
346
+ # Normalize
347
+ #publish_vector = publish_vector/np.sum(publish_vector)
348
+ return publish_vector
349
+
350
+ @staticmethod
351
+ def _rhyme_or_not(rhyme_str:str) -> np.ndarray:
352
+ """Create vector if given rhyme string is in our list of rhyme schemas
353
+
354
+ Args:
355
+ rhyme_str (str): string to construct vector from
356
+
357
+ Returns:
358
+ numpy.ndarray: Boolean flag vector
359
+ """
360
+ rhyme_vector = np.zeros(2)
361
+ if rhyme_str in StropheParams.RHYME:
362
+ rhyme_vector[0] = 1
363
+ else:
364
+ rhyme_vector[1] = 1
365
+ return rhyme_vector
366
+
367
+ @staticmethod
368
+ def _metre_vector(metre: str) -> np.ndarray:
369
+ """Create One-hot encoded metre vector from given string
370
+
371
+ Args:
372
+ metre (str): string to construct vector from
373
+
374
+ Returns:
375
+ numpy.ndarray: One-hot encoded metre vector
376
+ """
377
+ metre_vec = np.zeros(len(StropheParams.METER))
378
+ if metre in StropheParams.METER:
379
+ metre_vec[StropheParams.METER.index(metre)] = 1
380
+ else:
381
+ metre_vec[-1] = 1
382
+ return metre_vec
383
+
384
+ @staticmethod
385
+ def _first_line_analysis(text:str):
386
+ """Analysis of parameter line for RHYME, METER, YEAR
387
+
388
+ Args:
389
+ text (str): parameter line string
390
+
391
+ Returns:
392
+ dict: Dictionary with analysis result
393
+ """
394
+ line_striped = text.strip()
395
+ if not line_striped:
396
+ return {}
397
+ poet_params = {}
398
+ # Look for each possible parameter
399
+ for param in line_striped.split():
400
+ if TextAnalysis._is_year(param):
401
+ # Year is Bucketized so to fit
402
+ poet_params["YEAR"] = TextManipulation._year_bucketor(param)
403
+ elif TextAnalysis._rhyme_like(param):
404
+ poet_params["RHYME"] = param
405
+ elif TextAnalysis._is_meter(param):
406
+ poet_params["STROPHE_METER"] = param
407
+ return poet_params
408
+
409
+ @staticmethod
410
+ def _is_line_length(length:str):
411
+ """Return if string is number of syllables parameter
412
+
413
+ Args:
414
+ length (str): string to analyze
415
+
416
+ Returns:
417
+ bool: If string is number of syllables parameter
418
+ """
419
+ return length.isdigit() and int(length) > 1 and int(length) < 100
420
+
421
+ @staticmethod
422
+ def _is_line_end(end:str):
423
+ """Return if string is valid ending syllable/sequence parameter
424
+
425
+ Args:
426
+ end (str): string to analyze
427
+
428
+ Returns:
429
+ bool: If string is valid ending syllable/sequence parameter
430
+ """
431
+ return end.isalpha() and end.islower() and len(end) <= 5
432
+
433
+ @staticmethod
434
+ def _continuos_line_analysis(text:str):
435
+ """Analysis of Content lines for LENGTH, TRUE_LENGTH, END, TRUE_END
436
+
437
+ Args:
438
+ text (str): content line to analyze
439
+
440
+ Returns:
441
+ dict: Dictionary with analysis result
442
+ """
443
+ # Strip line of most separators and look if its empty
444
+ line_striped = TextManipulation._remove_most_nonchar(text, lower_case=False).strip()
445
+ if not line_striped:
446
+ return {}
447
+ line_params = {}
448
+ # OLD MODEL
449
+ if text.count('#') == 0: # BASIC
450
+ pass
451
+ else:
452
+ for param_group in text.split('#')[:-1]:
453
+ for param in param_group.split():
454
+ if TextAnalysis._is_meter(param.strip()):
455
+ line_params["METER"] = param.strip()
456
+ elif TextAnalysis._is_line_length(param.strip()):
457
+ line_params["LENGTH"] = int(param.strip())
458
+ elif TextAnalysis._is_line_end(param.strip()):
459
+ line_params["END"] = param.strip()
460
+
461
+
462
+ line_params["TRUE_LENGTH"] = len(SyllableMaker.syllabify(line_striped.split('#')[-1]))
463
+ line_only_char = TextManipulation._remove_all_nonchar(line_striped).strip()
464
+ if len(line_only_char) > 2:
465
+ line_params["TRUE_END"] = SyllableMaker.syllabify(" ".join(line_only_char.split()[-2:]))[-1]
466
+
467
+ return line_params
468
+
469
+ @staticmethod
470
+ def _is_param_line(text:str):
471
+ """Return if line is a Parameter line (Parameters RHYME, METER, YEAR)
472
+
473
+ Args:
474
+ text (str): line to analyze
475
+
476
+ Returns:
477
+ bool: If line is a Parameter line
478
+ """
479
+ line_striped = text.strip()
480
+ if not line_striped:
481
+ return False
482
+ small_analysis = TextAnalysis._first_line_analysis(line_striped)
483
+ return "RHYME" in small_analysis.keys() or "YEAR" in small_analysis.keys()
484
+
485
class SyllableMaker:
    """Static class with methods for separating a string into a list of syllables

    Returns:
        list: List of syllables
    """

    # NON-Original code!
    # Taken from Barbora Štěpánková

    @staticmethod
    def syllabify(text : str) -> list[str]:
        """Split the given text into syllables.

        Words are extracted first; single-letter prepositions (k, v, s, z)
        are merged with the following word before syllabification.
        """
        words = re.findall(r"[aábcčdďeéěfghiíjklmnňoópqrřsštťuúůvwxyýzžAÁBCČDĎEÉĚFGHIÍJKLMNŇOÓPQRŘSŠTŤUÚŮVWXYÝZŽäöüÄÜÖ]+", text)
        syllables : list[str] = []

        word_idx = 0
        while word_idx < len(words):
            word = words[word_idx]

            # Join single-letter prepositions to the next word (they carry no vowel).
            if word.lower() in ("k", "v", "s", "z") and word_idx < len(words) - 1 and len(words[word_idx + 1]) > 1:
                word_idx += 1
                word = word + words[word_idx]

            # Mask the word, split the mask, and cut the word by mask-piece lengths.
            cursor = 0
            for syllable_mask in SyllableMaker.__split_mask(SyllableMaker.__create_word_mask(word)):
                piece_len = len(syllable_mask)
                syllables.append(word[cursor:cursor + piece_len])
                cursor += piece_len

            word_idx += 1

        return syllables


    @staticmethod
    def __create_word_mask(word : str) -> str:
        """Encode the word as a mask of V (vocal), K (consonant) and 0 (glued letter)."""
        word = word.lower()

        vocals = r"[aeiyouáéěíýóůúäöü]"
        consonants = r"[bcčdďfghjklmnňpqrřsštťvwxzž]"

        # Applied strictly in order: digraphs first, then diphthongs,
        # then vocals, then consonant-cluster glue rules.
        replacements = [
            # double letters
            ('ch', 'c0'),
            ('rr', 'r0'),
            ('ll', 'l0'),
            ('nn', 'n0'),
            ('th', 't0'),

            # au, ou, ai, oi
            (r'[ao]u', '0V'),
            (r'[ao]i', '0V'),

            # eu at the beginning of the word
            (r'^eu', '0V'),

            # now all vocals
            (vocals, 'V'),

            # r,l that act like vocals in syllables
            (r'([^V])([rl])(0*[^0Vrl]|$)', r'\1V\3'),

            # sp, st, sk, št, Cř, Cl, Cr, Cv
            (r's[pt]', 's0'),
            (r'([^V0lr]0*)[řlrv]', r'\g<1>0'),
            (r'([^V0]0*)sk', r'\1s0'),
            (r'([^V0]0*)št', r'\1š0'),

            (consonants, 'K')
        ]

        for pattern, substitute in replacements:
            word = re.sub(pattern, substitute, word)

        return word


    @staticmethod
    def __split_mask(mask : str) -> list[str]:
        """Insert '/' syllable boundaries into a V/K/0 mask and split on them."""
        replacements = [
            # vocal at the beginning
            (r'(^0*V)(K0*V)', r'\1/\2'),
            (r'(^0*V0*K0*)K', r'\1/K'),

            # dividing the middle of the word
            (r'(K0*V(K0*$)?)', r'\1/'),
            (r'/(K0*)K', r'\1/K'),
            (r'/(0*V)(0*K0*V)', r'/\1/\2'),
            (r'/(0*V0*K0*)K', r'/\1/K'),

            # add the last consonant to the previous syllable
            (r'/(K0*)$', r'\1/')
        ]

        for pattern, substitute in replacements:
            mask = re.sub(pattern, substitute, mask)

        # Drop a trailing separator so the split yields no empty last piece.
        if mask.endswith("/"):
            mask = mask[:-1]

        return mask.split("/")
utils/validators.py ADDED
@@ -0,0 +1,359 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import transformers
3
+ import jellyfish
4
+ from tqdm import tqdm
5
+ from transformers import AutoModelForMaskedLM
6
+ from transformers.utils import ModelOutput
7
+ import numpy as np
8
+ from .poet_utils import StropheParams
9
+
10
+ from torch.utils.data import DataLoader, Dataset
11
+ from pytorch_optimizer import SAM
12
+
13
class ValidatorInterface(torch.nn.Module):
    """Pytorch Model Interface. Abstract class for all validators

    Args:
        torch (_type_): Is child of torch.nn.Module for integration with torch and huggingface
    """

    def __init__(self, *args, **kwargs) -> None:
        """Constructor. As child class needs to construct parent (torch.nn.Module)."""
        super().__init__(*args, **kwargs)

    def forward(self, input_ids=None, attention_mask=None, *args, **kwargs):
        """Compute model output and model loss.

        Args:
            input_ids (_type_, optional): Model inputs. Defaults to None.
            attention_mask (_type_, optional): Attention mask where padding starts. Defaults to None.

        Raises:
            NotImplementedError: Abstract class — must be overridden by subclasses.
        """
        raise NotImplementedError()

    def predict_state(self, input_ids=None, *args, **kwargs):
        """Compute model outputs only (no loss).

        Args:
            input_ids (_type_, optional): Model inputs. Defaults to None.

        Raises:
            NotImplementedError: Abstract class — must be overridden by subclasses.
        """
        raise NotImplementedError()

    def validate_model(self, input_ids=None, *args, **kwargs):
        """Validate model given some labels; doesn't use loss.

        Args:
            input_ids (_type_, optional): Model inputs. Defaults to None.

        Raises:
            NotImplementedError: Abstract class — must be overridden by subclasses.
        """
        raise NotImplementedError()
57
+
58
+
59
class RhymeValidator(ValidatorInterface):
    """Validator that classifies the rhyme schema of an input strophe.

    Wraps a pretrained masked-LM backbone and adds a linear classification
    head over the first-token (CLS) hidden state.
    """
    def __init__(self, pretrained_model, *args, **kwargs) -> None:
        """Load the backbone and build the rhyme classification head.

        Args:
            pretrained_model: Name or path of the pretrained masked LM to load.
        """
        super().__init__(*args, **kwargs)

        # Backbone masked LM; hidden states are required for the classification head.
        self.model = AutoModelForMaskedLM.from_pretrained(pretrained_model, output_hidden_states=True)

        self.config = self.model.config

        # Width of the backbone's hidden representation.
        self.model_size = self.config.hidden_size

        self.rhyme_regressor = torch.nn.Linear(self.model_size, len(StropheParams.RHYME)) # Common Rhyme Type

        # Per-class weights counteract the imbalanced rhyme-schema distribution
        # (rarer schemas receive larger weights); last entry is the fallback class.
        self.loss_fnc = torch.nn.CrossEntropyLoss(label_smoothing=0.0, weight=torch.tensor([1, 1, 1.5, 1.5, 1.5, 1.5,
                                                                                            2, 2, 2, 3, 3, 3,
                                                                                            3, 3, 3, 3, 4, 4,
                                                                                            5, 5, 5, 5, 7, 7,
                                                                                            7, 7, 7, 8, 8, 8,
                                                                                            9, 9, 9, 10, 10, 10,
                                                                                            12,12, 12, 12, 12, 12,
                                                                                            15,15,1.5]) )

    def forward(self, input_ids=None, attention_mask=None, rhyme=None, *args, **kwargs):
        """Compute rhyme classification loss plus the backbone's MLM loss.

        Args:
            input_ids: Tokenized inputs.
            attention_mask: Mask of non-padding positions.
            rhyme: Gold rhyme targets for the cross-entropy loss.

        Returns:
            ModelOutput: combined loss and softmaxed class scores.
        """
        # Passing labels=input_ids makes the backbone also produce its MLM loss.
        # NOTE(review): .type(torch.LongTensor) implicitly moves labels to CPU — confirm intended on GPU runs.
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids.type(torch.LongTensor))

        last_hidden = outputs['hidden_states'][-1]

        # Classify from the first (CLS) token representation.
        rhyme_regression = self.rhyme_regressor((last_hidden[:,0,:].view(-1, self.model_size)))

        softmaxed = torch.softmax(rhyme_regression, dim=1)
        # NOTE(review): CrossEntropyLoss is applied to already-softmaxed scores
        # (softmax effectively applied twice). Shipped checkpoints were trained
        # this way — confirm before "fixing".
        rhyme_loss = self.loss_fnc(softmaxed, rhyme)

        return ModelOutput(loss=rhyme_loss + outputs.loss, model_output=softmaxed)

    def predict_state(self, input_ids=None, *args, **kwargs):
        """Return softmaxed rhyme-class scores for the given inputs (no loss)."""

        outputs = self.model(input_ids=input_ids)

        last_hidden = outputs['hidden_states'][-1]

        rhyme_regression = self.rhyme_regressor((last_hidden[:,0,:].view(-1, self.model_size)))

        softmaxed = torch.softmax(rhyme_regression, dim=1)

        return softmaxed

    def validate_model(self, input_ids=None, rhyme=None, k:int = 2,*args, **kwargs):
        """Evaluate one example: accuracy, top-k hit, Levenshtein distance of schemas.

        Args:
            input_ids: Tokenized inputs (assumes batch size 1 — scores are flattened).
            rhyme: Gold rhyme one-hot/distribution vector.
            k (int): Size of the top-k window for the top-k metric.

        Returns:
            dict: "acc", "top_k", "lev_distance", "predicted_label" (gold-class probability).
        """
        outputs = self.model(input_ids=input_ids)

        last_hidden = outputs['hidden_states'][-1]

        rhyme_regression = self.rhyme_regressor((last_hidden[:,0,:].view(-1, self.model_size)))

        softmaxed = torch.softmax(rhyme_regression, dim=1)

        softmaxed = softmaxed.flatten().cpu()

        predicted_val = torch.argmax(softmaxed)

        predicted_top_k = torch.topk(softmaxed, k).indices

        label_val = torch.argmax(rhyme.flatten())

        validation_true_val = (label_val == predicted_val).float().sum().numpy()
        top_k_presence = 0
        if label_val in predicted_top_k:
            top_k_presence = 1

        # Edit distance between predicted and gold schema strings (None → "").
        levenshtein = jellyfish.levenshtein_distance(StropheParams.RHYME[predicted_val] if StropheParams.RHYME[predicted_val] != None else "", StropheParams.RHYME[label_val] if StropheParams.RHYME[label_val] != None else "")

        # Probability the model assigned to the gold class.
        hit_pred = softmaxed[label_val].detach().numpy()

        return {"acc" : validation_true_val,
                "top_k" : top_k_presence,
                "lev_distance": levenshtein,
                "predicted_label" : hit_pred
                }
135
+
136
+
137
+
138
class MeterValidator(ValidatorInterface):
    """Validator that classifies the meter of an input verse/strophe.

    Wraps a pretrained masked-LM backbone and adds a linear classification
    head over the first-token (CLS) hidden state.
    """
    def __init__(self, pretrained_model, *args, **kwargs) -> None:
        """Load the backbone and build the meter classification head.

        Args:
            pretrained_model: Name or path of the pretrained masked LM to load.
        """
        super().__init__(*args, **kwargs)
        self.model = AutoModelForMaskedLM.from_pretrained(pretrained_model, output_hidden_states=True)

        self.config = self.model.config

        # Width of the backbone's hidden representation.
        self.model_size = self.config.hidden_size

        self.meter_regressor = torch.nn.Linear(self.model_size, len(StropheParams.METER)) # Meter Type

        # Per-class weights counteract class imbalance; the final 0 weight
        # excludes the last (fallback/unknown) meter class from the loss.
        self.loss_fnc = torch.nn.CrossEntropyLoss(label_smoothing=0.0, weight=torch.tensor([1, 1.5, 5, 10, 10, 20, 5, 20, 20, 0]))

    def forward(self, input_ids=None, attention_mask=None, metre_ids=None, *args, **kwargs):
        """Compute meter classification loss plus the backbone's MLM loss.

        Args:
            input_ids: Tokenized inputs.
            attention_mask: Mask of non-padding positions.
            metre_ids: Gold meter targets for the cross-entropy loss.

        Returns:
            ModelOutput: combined loss and softmaxed class scores.
        """
        # Passing labels=input_ids makes the backbone also produce its MLM loss.
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids.type(torch.LongTensor))

        last_hidden = outputs['hidden_states'][-1]

        # Classify from the first (CLS) token representation.
        meter_regression = self.meter_regressor((last_hidden[:,0,:].view(-1, self.model_size)))

        softmaxed = torch.softmax(meter_regression, dim=1)
        # NOTE(review): CrossEntropyLoss on already-softmaxed scores (double
        # softmax); checkpoints were trained this way — confirm before changing.
        meter_loss = self.loss_fnc(softmaxed, metre_ids)

        return ModelOutput(loss=meter_loss + outputs.loss, model_output=softmaxed)

    def predict_state(self, input_ids=None, *args, **kwargs):
        """Return softmaxed meter-class scores for the given inputs (no loss)."""
        outputs = self.model(input_ids=input_ids)

        last_hidden = outputs['hidden_states'][-1]

        meter_regression = self.meter_regressor((last_hidden[:,0,:].view(-1, self.model_size)))

        softmaxed = torch.softmax(meter_regression, dim=1)

        return softmaxed

    def validate_model(self, input_ids=None, metre_ids=None, attention_mask=None, k: int=2,*args, **kwargs):
        """Evaluate one example: accuracy, top-k hit, gold-class probability.

        Args:
            input_ids: Tokenized inputs (assumes batch size 1 — scores are flattened).
            metre_ids: Gold meter one-hot/distribution vector.
            attention_mask: Mask of non-padding positions.
            k (int): Size of the top-k window for the top-k metric.

        Returns:
            dict: "acc", "top_k", "predicted_label" (gold-class probability).
        """
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask )

        last_hidden = outputs['hidden_states'][-1]

        meter_regression = self.meter_regressor((last_hidden[:,0,:].view(-1, self.model_size)))

        softmaxed = torch.softmax(meter_regression, dim=1)

        softmaxed = softmaxed.flatten().cpu()

        predicted_val = torch.argmax(softmaxed)

        predicted_top_k = torch.topk(softmaxed, k).indices

        label_val = torch.argmax(metre_ids.flatten())

        validation_true_val = (label_val == predicted_val).float().sum().numpy()
        top_k_presence = 0
        if label_val in predicted_top_k:
            top_k_presence = 1

        # Probability the model assigned to the gold class.
        hit_pred = softmaxed[label_val].detach().numpy()

        return {"acc" : validation_true_val,
                "top_k" : top_k_presence,
                "predicted_label" : hit_pred
                }
202
+
203
class YearValidator(ValidatorInterface):
    """Validator that predicts the publication year of an input.

    Uses two heads over the CLS hidden state: a regression head for the raw
    year value and a classification head over year-era buckets.
    """
    def __init__(self, pretrained_model, *args, **kwargs) -> None:
        """Load the backbone and build the year regression + era classification heads.

        Args:
            pretrained_model: Name or path of the pretrained masked LM to load.
        """
        super().__init__(*args, **kwargs)
        self.model = AutoModelForMaskedLM.from_pretrained(pretrained_model, output_hidden_states=True)

        self.config = self.model.config

        # Width of the backbone's hidden representation.
        self.model_size = self.config.hidden_size

        # Era (year-bucket) classification head.
        self.year_era = torch.nn.Linear(self.model_size, len(StropheParams.YEAR))
        self.softmax = torch.nn.Softmax(dim=-1)

        self.year_val = torch.nn.Linear(self.model_size, 1) # Year Value


        # Per-bucket weights counteract imbalance; the final 0 weight excludes
        # the last (unknown) bucket from the loss.
        self.loss_fnc_era = torch.nn.CrossEntropyLoss(label_smoothing=0.0,weight=torch.tensor([10, 5, 3, 3, 1, 1, 1.5, 2, 5, 0]))

        # L1 (mean absolute error) loss for the raw year regression.
        self.loss_fnc_val = torch.nn.L1Loss()

    def forward(self, input_ids=None, attention_mask=None, year_bucket=None, year=None, *args, **kwargs):
        """Compute year regression + era classification loss, plus the backbone's MLM loss.

        Args:
            input_ids: Tokenized inputs.
            attention_mask: Mask of non-padding positions.
            year_bucket: Gold era-bucket targets.
            year: Gold raw year value for the L1 loss.

        Returns:
            ModelOutput: combined loss and tuple (year value, softmaxed era scores).
        """
        # Passing labels=input_ids makes the backbone also produce its MLM loss.
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids.type(torch.LongTensor))

        last_hidden = outputs['hidden_states'][-1]


        # Regress raw year from the first (CLS) token representation.
        year_val = self.year_val((last_hidden[:,0,:].view(-1, self.model_size)))
        year_val_loss = self.loss_fnc_val(year_val, year)

        year_era = self.year_era((last_hidden[:,0,:].view(-1, self.model_size)))
        year_era = self.softmax(year_era)
        # NOTE(review): CrossEntropyLoss on already-softmaxed scores (double
        # softmax); checkpoints were trained this way — confirm before changing.
        year_era_loss = self.loss_fnc_era(year_era, year_bucket)

        return ModelOutput(loss=year_val_loss + year_era_loss + outputs.loss, model_output=(year_val, year_era))

    def predict_state(self, input_ids=None, *args, **kwargs):
        """Return the raw regressed year value for the given inputs (no loss)."""
        outputs = self.model(input_ids=input_ids)

        last_hidden = outputs['hidden_states'][-1]

        year_val = self.year_val((last_hidden[:,0,:].view(-1, self.model_size)))

        return year_val

    def validate_model(self, input_ids=None, year_bucket=None, k: int=2,*args, **kwargs):
        """Evaluate one example: bucket accuracy, top-k hit, gold-bucket probability, bucket distance.

        Combines the regressed year (turned into a distance-weighted bucket
        distribution) with the era head's scores when that head is present.

        Args:
            input_ids: Tokenized inputs (assumes batch size 1 — outputs are flattened).
            year_bucket: Gold era-bucket one-hot/distribution vector.
            k (int): Size of the top-k window for the top-k metric.

        Returns:
            dict: "acc", "top_k", "predicted_label", "distance" (|gold - predicted| in buckets).
        """

        outputs = self.model(input_ids=input_ids)

        last_hidden = outputs['hidden_states'][-1]

        year_val = self.year_val((last_hidden[:,0,:].view(-1, self.model_size)))
        # hasattr guard keeps compatibility with checkpoints saved without the era head.
        if hasattr(self, 'year_era'):
            year_era = self.year_era((last_hidden[:,0,:].view(-1, self.model_size)))
            year_era = self.softmax(year_era)

        year_val = year_val.detach().flatten().cpu().numpy()
        if hasattr(self, 'year_era'):
            year_era = year_era.detach().flatten().cpu().numpy()

        # Turn the regressed year into a bucket distribution weighted by inverse
        # distance to each bucket's year; last slot (unknown) gets 0.
        publish_vector = [1/(1 + abs(year - year_val[0])) for year in StropheParams.YEAR[:-1]] + [0]
        publish_vector = np.asarray(publish_vector)/np.sum(publish_vector)
        # Adding era prediction
        if hasattr(self, 'year_era'):
            publish_vector+= year_era
        publish_vector = torch.tensor( np.asarray(publish_vector)/np.sum(publish_vector))


        predicted_val = torch.argmax(publish_vector)

        predicted_top_k = torch.topk(publish_vector, k).indices

        label_val = torch.argmax(year_bucket.flatten())

        validation_true_val = (label_val == predicted_val).float().sum().numpy()
        top_k_presence = 0
        if label_val in predicted_top_k:
            top_k_presence = 1

        # Probability mass assigned to the gold bucket.
        hit_pred = publish_vector[label_val].detach().numpy()

        # How many buckets off the prediction is.
        distance = abs(label_val.numpy() - predicted_val.numpy())

        return {"acc" : validation_true_val,
                "top_k" : top_k_presence,
                "predicted_label" : hit_pred,
                "distance" : distance
                }
289
+
290
+
291
+
292
class ValidatorTrainer:
    """Training loop for validator models using SAM (Sharpness-Aware Minimization).

    Hyperparameters are read from an `args` dict with fallbacks to defaults
    when a key is absent.
    """
    def __init__(self, model: ValidatorInterface, args: dict, train_dataset: Dataset, data_collator, device):
        """Store model/data and build the SAM optimizer with a warmup schedule.

        Args:
            model (ValidatorInterface): Validator to train.
            args (dict): Optional keys: "epochs", "batch_size", "lr", "weight_decay".
            train_dataset (Dataset): Training data.
            data_collator: Collate function producing the batch dict.
            device: Device the batches are moved to during training.
        """
        self.model = model
        self.args = args
        # Hyperparameters with defaults when not supplied in args.
        self.epochs = 1 if "epochs" not in args.keys() else args["epochs"]
        self.batch_size = 1 if "batch_size" not in args.keys() else args["batch_size"]
        self.lr = 5e-5 if "lr" not in args.keys() else args["lr"]
        self.weight_decay = 0.0 if "weight_decay" not in args.keys() else args['weight_decay']

        # Shuffled loader over the training set.
        self.train_loader = DataLoader(train_dataset, self.batch_size, True, collate_fn=data_collator)

        # SAM Values
        self.device = device
        # SAM wraps AdamW as its base optimizer; requires the two-pass
        # first_step/second_step protocol in train().
        self.optimizer = SAM(self.model.parameters(), torch.optim.AdamW, lr=self.lr, weight_decay=self.weight_decay)
        # Constant LR after a warmup of ~4 epochs worth of steps.
        self.scheduler = transformers.get_constant_schedule_with_warmup(self.optimizer, 4 * len(train_dataset)//self.batch_size)

        # GSAM Value
        #self.device = device
        #self.base_optim = AdamP(self.model.parameters(), lr=self.lr, weight_decay=self.weight_decay)
        #self.scheduler = transformers.get_constant_schedule_with_warmup(self.base_optim, len(train_dataset)//self.batch_size)
        #self.rho_scheduler= ProportionScheduler( self.scheduler, max_lr=self.lr)
        #self.optimizer = GSAM(self.model.parameters(),self.base_optim, self.model, self.rho_scheduler, alpha=0.05)

    def train(self):
        """Run the SAM training loop for the configured number of epochs.

        Each batch does two forward/backward passes: the first perturbs the
        weights toward the sharpness ascent point, the second applies the
        actual update from the perturbed weights.
        """
        for epoch in tqdm(range(self.epochs)):
            self.model.train()

            # SAM Attempt

            for step, batch in enumerate(self.train_loader):
                # First Pass
                # NOTE(review): `batch[...] == None` compares a tensor with `==`;
                # works because torch returns False for tensor == None, but
                # `is None` would be the safe idiom — confirm before changing.
                loss = self.model(input_ids=batch["input_ids"].to(self.device), attention_mask=batch["attention_mask"].to(self.device),
                                  rhyme = None if batch["rhyme"] == None else batch["rhyme"].to(self.device),
                                  metre_ids = None if batch["metre_ids"] == None else batch["metre_ids"].to(self.device),
                                  year_bucket = None if batch["year_bucket"] == None else batch["year_bucket"].to(self.device),
                                  year = None if batch["year"] == None else batch["year"].to(self.device))['loss']
                loss.backward()
                # Perturb weights to the local sharpness ascent point.
                self.optimizer.first_step(zero_grad=True)
                # Second Pass
                loss = self.model(input_ids=batch["input_ids"].to(self.device), attention_mask=batch["attention_mask"].to(self.device),
                                  rhyme = None if batch["rhyme"] == None else batch["rhyme"].to(self.device),
                                  metre_ids = None if batch["metre_ids"] == None else batch["metre_ids"].to(self.device),
                                  year_bucket = None if batch["year_bucket"] == None else batch["year_bucket"].to(self.device),
                                  year = None if batch["year"] == None else batch["year"].to(self.device))['loss']

                loss.backward()
                # Apply the actual update from the perturbed weights.
                self.optimizer.second_step(zero_grad=True)
                self.scheduler.step()

                # GSAM Attempt

                #for step, batch in enumerate(self.train_loader):
                #    def closure():
                #        self.optimizer.base_optimizer.zero_grad()
                #        with torch.enable_grad():
                #            outputs = self.model(input_ids=batch["input_ids"].to(self.device), attention_mask=batch["attention_mask"].to(self.device),
                #                rhyme = None if batch["rhyme"] == None else batch["rhyme"].to(self.device),
                #                metre = None if batch["metre"] == None else batch["metre"].to(self.device))
                #        loss = torch.nn.functional.cross_entropy(outputs['model_output'].to(self.device),batch['rhyme'].to(self.device) if isinstance(self.model, RhymeValidator) else batch['metre'].to(self.device))
                #        loss.backward()
                #        return outputs['model_output'], loss.detach()
                #    predictions, loss = self.optimizer.step(closure)
                #    self.scheduler.step()
                #    self.optimizer.update_rho_t()
                #
                # Periodic progress report (second-pass loss).
                if step % 100 == 0:
                    print(f'Step {len(self.train_loader) * epoch + step}, loss : {loss.item()}', flush=True)
359
+
utils/validators/meter/ufal-robeczech-base_BPE_validator_1704126400265 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d83f2b8f9b00db0945584e3bcbce96f971cfc572cb8665ff713c6d3cc67854d4
3
+ size 504173324
utils/validators/rhyme/distilroberta-base_BPE_validator_1704126399565 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ceb77ef356a5e5ce3d59a6b2d31b96c925af09e29b4731c143ebabdaf3401c65
3
+ size 328898329
utils/validators/year/ufal-robeczech-base_BPE_validator_1702393305267 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4695ae160b8236b89c467fb50318c6cb429ae6152f9332f74ddcaff5cbe23da1
3
+ size 504177816