jinymusim committed on
Commit edd506f · verified · 1 Parent(s): 65896b7

Delete corpus_capsulated_datasets.py

Files changed (1)
  1. corpus_capsulated_datasets.py +0 -754
corpus_capsulated_datasets.py DELETED
@@ -1,754 +0,0 @@
1
- import os
2
- import json
3
- import numpy as np
4
- import torch
5
-
6
- from utils.poet_utils import StropheParams, SyllableMaker, TextAnalysis, TextManipulation
7
- from torch.utils.data import Dataset
8
- from transformers import PreTrainedTokenizerBase, PreTrainedModel
9
- #TODO: Maybe replace the year the book was written with the year the Author was born
10
- class CorpusDatasetPytorch:
11
- """Dataset class responsible for data loading.
12
- """
13
-
14
- class RawDataset:
15
- """Dataset distributing raw sting data with no preprocessing
16
- """
17
- def __init__(self, data_file_paths, lower_case:bool = True):
18
- """Construct the frame around Raw data generation
19
-
20
- Args:
21
- data_file_paths (_type_): list of paths to data files
22
- lower_case (bool, optional): if resulting data should be in lowercase. Defaults to True.
23
- """
24
- self._data_file_paths = data_file_paths
25
- self.lower_case = lower_case
26
-
27
- def gen_files(self):
28
- """Get individual opened files
29
-
30
- Yields:
31
- _type_: open file object
32
- """
33
- for filename in self._data_file_paths:
34
- yield open(filename, 'r')
35
-
36
- def get_text(self):
37
- """Get lines of text of poetry
38
-
39
- Yields:
40
- str: individual verse line
41
- """
42
- for step,file in enumerate(self.gen_files()):
43
- if step % 500 == 0:
44
- print(f"Processing file {step}")
45
- datum = json.load(file)
46
- for data_line in datum:
47
- for part_line in data_line['body']:
48
- for text_line in part_line:
49
- yield text_line['text'].lower() if self.lower_case else text_line['text']
50
-
51
- def get_part(self):
52
- """Get strophe of poetry
53
-
54
- Yields:
55
- str: 1 strophe of poetry
56
- """
57
- for step,file in enumerate(self.gen_files()):
58
- if step % 500 == 0:
59
- print(f"Processing file {step}")
60
- datum = json.load(file)
61
- for data_line in datum:
62
- for part_line in data_line['body']:
63
- part = []
64
- for text_line in part_line:
65
- part.append(text_line['text'])
66
- yield "\n".join(part).lower() if self.lower_case else "\n".join(part)
67
-
68
- def get_body(self):
69
- """Get whole poem
70
-
71
- Yields:
72
- str: 1 whole poem
73
- """
74
- for step,file in enumerate(self.gen_files()):
75
- if step % 500 == 0:
76
- print(f"Processing file {step}")
77
- datum = json.load(file)
78
- for data_line in datum:
79
- body = []
80
- for part_line in data_line['body']:
81
-
82
- for text_line in part_line:
83
- body.append(text_line['text'])
84
- body.append("\n")
85
- yield "\n".join(body).lower() if self.lower_case else "\n".join(body)
86
-
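For orientation, a minimal usage sketch of this raw iterator (hypothetical: it assumes the deleted module were still importable and that a local ccv JSON file exists at the illustrative path):

import itertools
from corpus_capsulated_datasets import CorpusDatasetPytorch

# Hypothetical example: the path and file name are illustrative only.
raw = CorpusDatasetPytorch.RawDataset(["./ccv/0001.json"], lower_case=True)
for verse in itertools.islice(raw.get_text(), 5):   # peek at the first five verse lines
    print(verse)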
87
- class TextDataset(Dataset):
88
- """Dataset of preprocessed verse lines
89
-
90
- Args:
91
- Dataset (_type_): Dataset is child of torch class for better integration with torch and huggingface
92
- """
93
-
94
- def __init__(self, data_file_paths, prompt_length=True, prompt_ending=True, lower_case=True, val_data_rate: float = 0.05, test_data_rate: float = 0.05):
95
- """Construct the class our given data files path and store variables
96
-
97
- Args:
98
- data_file_paths (_type_): list of paths to data files
99
- prompt_length (bool, optional): If to prompt the syllable count. Defaults to True.
100
- prompt_ending (bool, optional): If to prompt verse ending. Defaults to True.
101
- lower_case (bool, optional): If the string should be in lowercase. Defaults to True.
102
- val_data_rate (float, optional): Amount of data to be left for validation. Defaults to 0.05.
103
- test_data_rate (float, optional): Amount of data to be left for testing. Defaults to 0.05.
104
- """
105
- self._data_file_paths = data_file_paths
106
- self.prompt_length = prompt_length
107
- self.prompt_ending = prompt_ending
108
- self.lower_case = lower_case
109
-
110
- self.val_data_rate = val_data_rate
111
- self.test_data_rate = test_data_rate
112
-
113
- self.data = []
114
- self.validation_data = []
115
- self.test_data = []
116
-
117
-
118
- def gen_files(self):
119
- """Get individual opened files
120
-
121
- Yields:
122
- _type_: open file object
123
- """
124
- for filename in self._data_file_paths:
125
- yield open(filename, 'r')
126
-
127
- @staticmethod
128
- def _vowels_and_endings(raw_text):
129
- """Get the verse ending and number of syllables in verse
130
-
131
- Args:
132
- raw_text (str): raw verse to analyze
133
-
134
- Returns:
135
- tuple: number of syllables, ending syllable
136
- """
137
- syllabs = SyllableMaker.syllabify(raw_text)
138
- vowels = len(syllabs) #INFO: Now counts the number of syllables
139
- ending = syllabs[-1]
140
- return vowels, ending
141
-
142
- @staticmethod
143
- def _ending_vector(end):
144
- """Construct One-hot encoded vector for ending syllable
145
-
146
- Args:
147
- end (str): Ending syllable
148
-
149
- Returns:
150
- numpy.ndarray: One-hot encoded vector of ending syllable
151
- """
152
- verse_end_vector = np.zeros(len(StropheParams.ENDS))
153
- if end in StropheParams.ENDS[:-1]:
154
- verse_end_vector[StropheParams.ENDS.index(end)] = 1
155
- else:
156
- verse_end_vector[-1] = 1
157
- return verse_end_vector
158
-
159
- @staticmethod
160
- def _syllable_line(raw_text):
161
- """Construct verse as sequence of syllables
162
-
163
- Args:
164
- raw_text (str): raw verse line
165
-
166
- Returns:
167
- str: Verse line as sequence of syllables
168
- """
169
- ending = raw_text[-1] if raw_text[-1] in [',','.','!','?'] else ''
170
- return " ".join(SyllableMaker.syllabify(raw_text)) + ending
171
-
172
- def _construct_line(self, raw_text, metre):
173
- """Construct individual content line
174
-
175
- Args:
176
- raw_text (str): raw verse line
177
-
178
- Returns:
179
- str: Processed verse line with line parameters
180
- """
181
- syllables = SyllableMaker.syllabify(raw_text)
182
- num_str = f"{len(syllables)} # " if self.prompt_length else ""
183
- verse_end = f"{syllables[-1]} # " if self.prompt_ending else ""
184
- metre_txt = f"{metre} # "
185
- return metre_txt + num_str + verse_end + raw_text
186
-
187
- def _introduce_phonetics(self, raw_text:str, phonetics):
188
- phonetic_text = raw_text
189
- for word in phonetics['words']:
190
- phonetic_text = phonetic_text.replace(f'{word["token_lc"]}', f'{word["phoebe"]}') if self.lower_case else phonetic_text.replace(f'{word["token"]}', f'{word["phoebe"]}')
191
- return phonetic_text
192
-
193
- def _construct_syllable_line(self, raw_text, metre):
194
- """Construct individual content line as sequence of syllables
195
-
196
- Args:
197
- raw_text (str): raw verse line
198
-
199
- Returns:
200
- str: Processed verse line as sequence of syllables with line parameters
201
- """
202
- ending = raw_text[-1] if raw_text[-1] in [',','.','!','?'] else ''
203
- syllables = SyllableMaker.syllabify(raw_text)
204
- num_str = f"{len(syllables)} # " if self.prompt_length else ""
205
- verse_end = f"{syllables[-1]} # " if self.prompt_ending else ""
206
- metre_txt = f"{metre} # "
207
- return metre_txt+ num_str + verse_end + " ".join(syllables) + ending
208
-
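A short illustration of the two formats these helpers emit (the Czech example and its syllabification are made up for illustration, not real SyllableMaker output):

# generic pattern: "<metre> # <syllable count> # <ending syllable> # <verse text or space-separated syllables>"
plain_line    = "J # 6 # jí # na kopečku stojí"
syllable_line = "J # 6 # jí # na ko peč ku sto jí"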
209
-
210
- def data_text_line_gen(self):
211
- """Preprocess and process data for usage
212
- """
213
- for step,file in enumerate(self.gen_files()):
214
- if step % 500 == 0:
215
- print(f"Processing file {step}")
216
- datum = json.load(file)
217
- for data_line in datum:
218
- for part_line in data_line['body']:
219
- for text_line in part_line:
220
- metre = StropheParams.METER_TRANSLATE.get(text_line["metre"][0]["type"], "N")
221
-
222
- scanned_text = TextManipulation._remove_most_nonchar(text_line['text'], self.lower_case)
223
-
224
- text_line_scanned = self._construct_line(scanned_text, metre)
225
- syllable_line = self._construct_syllable_line(scanned_text, metre)
226
- #phonetic_text = self._introduce_phonetics(scanned_text, text_line)
227
-
228
- num_vowels, verse_end = self._vowels_and_endings(scanned_text)
229
-
230
- # Based on a random draw, choose the proper set. Because the data are large enough, this results in the wanted split.
231
- rand_split = np.random.rand()
232
- if rand_split > self.val_data_rate + self.test_data_rate:
233
- self.data.append({
234
- "input_ids" : [text_line_scanned,syllable_line],
235
- "nums": [num_vowels],
236
- "verse_end": verse_end,
237
- "metre": metre
238
- })
239
- elif rand_split < self.test_data_rate:
240
- self.test_data.append({
241
- "input_ids" : [text_line_scanned,syllable_line],
242
- "nums": [num_vowels],
243
- "verse_end": verse_end,
244
- "metre": metre
245
- })
246
- else:
247
- self.validation_data.append({
248
- "input_ids" : [text_line_scanned,syllable_line],
249
- "nums": [num_vowels],
250
- "verse_end": verse_end,
251
- "metre": metre
252
- })
253
-
254
-
255
- def __len__(self):
256
- """Return length of training data
257
-
258
- Returns:
259
- int: length of training data
260
- """
261
- return len(self.data)
262
-
263
- def __getitem__(self, index):
264
- """return indexed item
265
-
266
- Args:
267
- index (int): index from where to return
268
-
269
- Returns:
270
- dict: dict with indexed data
271
- """
272
- return self.data[index]
273
-
274
- class BodyDataset(Dataset):
275
- """Dataset of preprocessed strophe
276
-
277
- Args:
278
- Dataset (_type_): Dataset is child of torch class for better integration with torch and huggingface
279
- """
280
- def __init__(self, data_file_paths,
281
- prompt_length=True, prompt_ending=True, prompt_verse=True, verse_len=[4,6], lower_case=True, val_data_rate: float = 0.05, test_data_rate: float = 0.05):
282
- """Construct the class our given data files path and store variables
283
-
284
- Args:
285
- data_file_paths (_type_): list of paths to data files
286
- prompt_length (bool, optional): If to prompt the syllable count. Defaults to True.
287
- prompt_ending (bool, optional): If to prompt verse ending. Defaults to True.
288
- prompt_verse (bool, optional): If to prompt rhyme schema. Defaults to True.
289
- verse_len (list, optional): Considered length of strophe. Defaults to [4,6].
290
- lower_case (bool, optional): If the string should be in lowercase. Defaults to True.
291
- val_data_rate (float, optional): Amount of data to be left for validation. Defaults to 0.05.
292
- test_data_rate (float, optional): Amount of data to be left for testing. Defaults to 0.05.
293
- """
294
- self._data_file_paths = data_file_paths
295
- self.prompt_length = prompt_length
296
- self.prompt_ending = prompt_ending
297
- self.prompt_verse = prompt_verse
298
- self.verse_len = verse_len
299
- self.lower_case = lower_case
300
-
301
- self.val_data_rate = val_data_rate
302
- self.test_data_rate = test_data_rate
303
-
304
- self.data = []
305
- self.validation_data = []
306
- self.test_data = []
307
-
308
- def gen_files(self):
309
- """Get individual opened files
310
-
311
- Yields:
312
- _type_: open file object
313
- """
314
- for filename in self._data_file_paths:
315
- yield open(filename, 'r')
316
-
317
-
318
-
319
-
320
- def _construct_line(self, raw_text, metre):
321
- """Construct individual content line
322
-
323
- Args:
324
- raw_text (str): raw verse line
325
-
326
- Returns:
327
- str: Processed verse line with line parameters
328
- """
329
- syllables = SyllableMaker.syllabify(raw_text)
330
- num_str = f"{len(syllables)} # " if self.prompt_length else ""
331
- verse_end = f"{syllables[-1]} # " if self.prompt_ending else ""
332
- metre_txt = f"{metre} # "
333
- return metre_txt + num_str + verse_end + raw_text
334
-
335
- def _construct_syllable_line(self, raw_text, metre):
336
- """Construct individual content line as sequence of syllables
337
-
338
- Args:
339
- raw_text (str): raw verse line
340
-
341
- Returns:
342
- str: Processed verse line as sequence of syllables with line parameters
343
- """
344
- ending = raw_text[-1] if raw_text[-1] in [',','.','!','?'] else ''
345
- syllables = SyllableMaker.syllabify(raw_text)
346
- num_str = f"{len(syllables)} # " if self.prompt_length else ""
347
- verse_end = f"{syllables[-1]} # " if self.prompt_ending else ""
348
- metre_txt = f"{metre} # "
349
- return metre_txt + num_str + verse_end + " ".join(syllables) + ending
350
-
351
-
352
-
353
- def data_body_gen(self):
354
- """Preprocess and process data for usage
355
- """
356
- for step,file in enumerate(self.gen_files()):
357
- if step % 500 == 0:
358
- print(f"Processing file {step}")
359
- datum = json.load(file)
360
-
361
- for data_line in datum:
362
- publish_year_text = TextManipulation._year_bucketor(data_line["biblio"]["year"])
363
- publish_year_true = data_line["biblio"]["year"] if TextAnalysis._is_year(data_line["biblio"]["year"]) else 'NaN'
364
- context = ["NO CONTEXT"]
365
-
366
- for part_line in data_line['body']:
367
- body = []
368
- body_syllabs = []
369
- rhyme= []
370
- metres = []
371
- i = 0
372
- for text_line in part_line:
373
-
374
- # In rare cases there are multiple metres, but from inspection only 1 metre per line is used
375
- metre = StropheParams.METER_TRANSLATE.get(text_line["metre"][0]["type"], "J")
376
- metres += [metre]
377
-
378
- rhyme.append(text_line["rhyme"])
379
-
380
- scanned_text = TextManipulation._remove_most_nonchar(text_line["text"], self.lower_case)
381
-
382
- body.append(self._construct_line(scanned_text,metre))
383
- body_syllabs.append(self._construct_syllable_line(scanned_text,metre))
384
-
385
- i+=1
386
-
387
- if i in self.verse_len:
388
-
389
- rhyme_str = TextManipulation._rhyme_string(rhyme)
390
-
391
- text = f"# {rhyme_str} # {publish_year_text}\n" + "\n".join(body) + "\n"
392
- syllable_text = f"# {rhyme_str} # {publish_year_text}\n" + "\n".join(body_syllabs) + "\n"
393
- context_text= "\n".join(context)
394
- rand_split = np.random.rand()
395
- if rand_split > self.val_data_rate + self.test_data_rate:
396
- self.data.append({
397
- "input_ids" : [text,syllable_text],
398
- "context_ids" : context_text,
399
- "year": publish_year_true,
400
- "rhyme": rhyme_str,
401
- "metre_ids" : metres.copy()
402
- })
403
- elif rand_split < self.test_data_rate:
404
- self.test_data.append({
405
- "input_ids" : [text,syllable_text],
406
- "context_ids" : context_text,
407
- "year": publish_year_true,
408
- "rhyme": rhyme_str,
409
- "metre_ids" : metres.copy()
410
- })
411
- else:
412
- self.validation_data.append({
413
- "input_ids" : [text,syllable_text],
414
- "context_ids" : context_text,
415
- "year": publish_year_true,
416
- "rhyme": rhyme_str,
417
- "metre_ids" : metres.copy()
418
- })
419
-
420
- if i == max(self.verse_len):
421
- body = []
422
- body_syllabs = []
423
- rhyme = []
424
- metres = []
425
- i=0
426
-
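Accordingly, the "input_ids" text stored for each strophe entry has roughly this shape (a schematic sketch; the rhyme string and year bucket come from TextManipulation):

# <rhyme schema> # <publish year bucket>
<metre> # <syllable count> # <ending syllable> # <verse 1>
<metre> # <syllable count> # <ending syllable> # <verse 2>
...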
427
-
428
- def __len__(self):
429
- """Return length of training data
430
-
431
- Returns:
432
- int: length of training data
433
- """
434
- return len(self.data)
435
-
436
- def __getitem__(self, index):
437
- """return indexed item
438
-
439
- Args:
440
- index (int): index from where to return
441
-
442
- Returns:
443
- dict: dict with indexed data
444
- """
445
- return self.data[index]
446
-
447
- def get_filenames(self):
448
- """Get paths of data files
449
-
450
- Returns:
451
- list: Paths of data files
452
- """
453
- data_filenames = os.listdir(self.data_dir)
454
- data_by_files = []
455
- for filename in data_filenames:
456
- file_path = os.path.join(self.data_dir, filename)
457
- data_by_files.append(file_path)
458
- return data_by_files
459
-
460
- def load_raw_(self):
461
- """Load Raw dataset with raw string data
462
- """
463
- filenames = self.get_filenames()
464
-
465
- self.raw_dataset = CorpusDatasetPytorch.RawDataset(filenames, self.lower_case)
466
-
467
- def load_json_filenames(self, prompt_length, prompt_ending, prompt_verse, verse_len=[4,6], val_data_rate=0.05, test_data_rate=0.05):
468
- """Load Verse and Strophe datasets
469
-
470
- Args:
471
- prompt_length (bool, optional): If to prompt the syllable count. Defaults to True.
472
- prompt_ending (bool, optional): If to prompt verse ending. Defaults to True.
473
- prompt_verse (bool, optional): If to prompt rhyme schema. Defaults to True.
474
- verse_len (list, optional): Considered length of strophe. Defaults to [4,6].
475
- val_data_rate (float, optional): Amount of data to be left for validation. Defaults to 0.05.
476
- """
477
- filenames = self.get_filenames()
478
-
479
- self.pytorch_dataset_body = CorpusDatasetPytorch.BodyDataset(filenames, prompt_ending=prompt_ending,
480
- prompt_length=prompt_length, prompt_verse=prompt_verse,
481
- verse_len=verse_len, lower_case=self.lower_case,
482
- val_data_rate=val_data_rate, test_data_rate=test_data_rate)
483
- self.pytorch_dataset_body.data_body_gen()
484
-
485
-
486
- self.pytorch_dataset_text = CorpusDatasetPytorch.TextDataset(filenames, prompt_ending=prompt_ending,
487
- prompt_length=prompt_length, lower_case=self.lower_case,
488
- val_data_rate=val_data_rate, test_data_rate=test_data_rate)
489
-
490
- self.pytorch_dataset_text.data_text_line_gen()
491
-
492
- self.val_pytorch_dataset_body = CorpusDatasetPytorch.BodyDataset([])
493
- self.val_pytorch_dataset_text = CorpusDatasetPytorch.TextDataset([])
494
-
495
- self.val_pytorch_dataset_body.data = self.pytorch_dataset_body.validation_data
496
- self.val_pytorch_dataset_text.data = self.pytorch_dataset_text.validation_data
497
-
498
- self.pytorch_dataset_text.validation_data = []
499
- self.pytorch_dataset_body.validation_data = []
500
-
501
- self.test_pytorch_dataset_body = CorpusDatasetPytorch.BodyDataset([])
502
- self.test_pytorch_dataset_text = CorpusDatasetPytorch.TextDataset([])
503
-
504
- self.test_pytorch_dataset_body.data = self.pytorch_dataset_body.test_data
505
- self.test_pytorch_dataset_text.data = self.pytorch_dataset_text.test_data
506
-
507
- self.pytorch_dataset_text.test_data = []
508
- self.pytorch_dataset_body.test_data = []
509
-
510
- def create_empty(self):
511
- """Create empty holder for possible load of processed data from file
512
- """
513
- self.pytorch_dataset_body = CorpusDatasetPytorch.BodyDataset([])
514
- self.pytorch_dataset_text = CorpusDatasetPytorch.TextDataset([])
515
- self.val_pytorch_dataset_body = CorpusDatasetPytorch.BodyDataset([])
516
- self.val_pytorch_dataset_text = CorpusDatasetPytorch.TextDataset([])
517
- self.test_pytorch_dataset_body = CorpusDatasetPytorch.BodyDataset([])
518
- self.test_pytorch_dataset_text = CorpusDatasetPytorch.TextDataset([])
519
-
520
-
521
- @staticmethod
522
- def collate(batch, tokenizer: PreTrainedTokenizerBase, max_len=1024, max_context=1024, mask_rate=0.0, syllables: bool = False, format: str = 'METER_VERSE'):
523
- """Process data for usage in LM
524
-
525
- Args:
526
- batch (_type_): Batch with selected data points
527
- tokenizer (PreTrainedTokenizerBase): tokenizer to tokenize input text
528
- max_len (int, optional): Maximum length of tokenization. Defaults to 1024.
529
- max_context (int, optional): Maximum length of tokenization of context. Defaults to 1024.
530
- mask_rate (float, optional): Rate at which to mask data. Defaults to 0.0.
531
- syllables (bool, optional): If to use sequence of syllables as input text. Defaults to False.
532
-
533
- Returns:
534
- dict: tokenized and processed to tensors data
535
- """
536
- index = 1 if syllables else 0
537
-
538
- tokenizer.model_max_length = max_len
539
- if batch[0]['input_ids'][0].startswith("#"):
540
-
541
- data = [text['input_ids'][index] for text in batch]
542
- if format == "BASIC":
543
- data = ["\n".join
544
- (
545
- [line + f" # {datum.splitlines()[1].split()[0]}"
546
- if i==0 else line.split('#')[-1] for i, line in enumerate(datum.splitlines())]
547
- ) + tokenizer.eos_token for j, datum in enumerate(data)
548
- ]
549
- elif format == "VERSE_PAR":
550
- data = ["\n".join
551
- (
552
- [line + f" # {datum.splitlines()[1].split()[0]}"
553
- if i==0 else "#".join(line.split('#')[1:]) for i, line in enumerate(datum.splitlines())]
554
- ) + tokenizer.eos_token for j, datum in enumerate(data)
555
- ]
556
- else:
557
- data = [text['input_ids'][index] + tokenizer.eos_token for text in batch]
558
-
559
- tokenized = tokenizer(data,return_tensors='pt', truncation=True, padding=True)
560
- input_ids = tokenized['input_ids']
561
- attention = tokenized["attention_mask"]
562
-
563
- else:
564
- tokenized = tokenizer([text['input_ids'][index] + tokenizer.eos_token for text in batch],return_tensors='pt', truncation=True, padding=True)
565
- input_ids = tokenized['input_ids']
566
- attention = tokenized["attention_mask"]
567
-
568
-
569
- nums = None
570
- if "nums" in batch[0].keys():
571
- nums = torch.tensor(np.asarray([text['nums'] for text in batch], dtype=np.int32), dtype=torch.float32)
572
-
573
- rhyme=None
574
- if "rhyme" in batch[0].keys():
575
- rhyme = torch.tensor(np.asarray([TextAnalysis._rhyme_vector(text["rhyme"]) for text in batch], dtype=np.int32), dtype=torch.float32)
576
-
577
- verse_end = None
578
- if "verse_end" in batch[0].keys():
579
- verse_end = torch.tensor(np.asarray([CorpusDatasetPytorch.TextDataset._ending_vector(text["verse_end"]) for text in batch], dtype=np.int32), dtype=torch.float32)
580
-
581
- year = None
582
- if "year" in batch[0].keys():
583
- year = torch.tensor(np.asarray([TextAnalysis._publish_year_vector(text["year"]) for text in batch], dtype=np.int32), dtype=torch.float32)
584
-
585
- metre = None
586
- if "metre" in batch[0].keys():
587
- metre = torch.tensor(np.asarray([TextAnalysis._metre_vector(text["metre"]) for text in batch], dtype=np.int32), dtype=torch.float32)
588
-
589
- context_ids = None
590
- context_attention_mask = None
591
- if "context_ids" in batch[0].keys():
592
- tokenizer.model_max_length = max_context
593
- tokenized_context = tokenizer([text['context_ids'] + tokenizer.eos_token for text in batch],return_tensors='pt', truncation=True, padding=True)
594
- context_ids = tokenized_context['input_ids']
595
- context_attention_mask = tokenized_context['attention_mask']
596
-
597
- return {
598
- "input_ids": input_ids,
599
- "labels": input_ids.type(torch.LongTensor),
600
- "attention_mask": attention,
601
- "context_ids" : context_ids,
602
- "context_attention_mask" : context_attention_mask,
603
- "nums" : nums,
604
- "rhyme": rhyme,
605
- "verse_end" : verse_end,
606
- "year": year,
607
- "metre" : metre}
608
-
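A minimal sketch of plugging this collate into a torch DataLoader (the tokenizer checkpoint and corpus path are illustrative assumptions; any PreTrainedTokenizerBase with eos and pad tokens should work):

from functools import partial
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from corpus_capsulated_datasets import CorpusDatasetPytorch

tokenizer = AutoTokenizer.from_pretrained("gpt2")      # illustrative checkpoint
tokenizer.pad_token = tokenizer.eos_token              # collate pads the batch, so a pad token is required

dataset = CorpusDatasetPytorch(data_dir="./ccv")       # illustrative corpus path
loader = DataLoader(
    dataset.pytorch_dataset_body,
    batch_size=8,
    shuffle=True,
    collate_fn=partial(CorpusDatasetPytorch.collate, tokenizer=tokenizer, max_len=1024),
)
batch = next(iter(loader))   # dict with input_ids, labels, attention_mask, rhyme, year, metre, ...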
609
-
610
- @staticmethod
611
- def collate_distil(batch, tokenizer: PreTrainedTokenizerBase, surrogate_model: PreTrainedModel = None, surrogate_model_device=None, max_len=1024):
612
- tokenizer.model_max_length = max_len
613
- tokenized = tokenizer([text['input_ids'][0] + tokenizer.eos_token for text in batch], return_tensors='pt', truncation=True, padding=True)
614
- input_ids = tokenized['input_ids']
615
- attention = tokenized["attention_mask"]
616
-
617
- with torch.no_grad():
618
- # hidden_states is a tuple of per-layer tensors
619
- model_hidden_states = surrogate_model(input_ids=input_ids.to(surrogate_model_device),
620
- attention_mask=attention.to(surrogate_model_device),
621
- labels=input_ids.type(torch.LongTensor).to(surrogate_model_device))['hidden_states']
622
- model_hidden_states = [hidden.cpu().detach() for hidden in model_hidden_states]
623
-
624
- return {
625
- "input_ids": input_ids,
626
- "labels": input_ids.type(torch.LongTensor),
627
- "attention_mask": attention,
628
- "to_replicate_states": model_hidden_states
629
- }
630
-
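collate_distil assumes the surrogate model returns hidden states; a hedged setup sketch follows (the checkpoint name is an illustrative assumption, not the one used in this project):

from functools import partial
from transformers import AutoModelForCausalLM, AutoTokenizer
from corpus_capsulated_datasets import CorpusDatasetPytorch

tokenizer = AutoTokenizer.from_pretrained("gpt2")        # illustrative checkpoint
tokenizer.pad_token = tokenizer.eos_token
surrogate = AutoModelForCausalLM.from_pretrained("gpt2", output_hidden_states=True).eval()

distil_collate = partial(CorpusDatasetPytorch.collate_distil, tokenizer=tokenizer,
                         surrogate_model=surrogate, surrogate_model_device="cpu")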
631
- @staticmethod
632
- def collate_validator(batch, tokenizer: PreTrainedTokenizerBase, syllables: bool, is_syllable: bool = False, max_len=512):
633
- """Process data for use in LM for metre,rhyme and year prediction
634
-
635
- Args:
636
- batch (_type_): Batch with selected data points
637
- tokenizer (PreTrainedTokenizerBase): tokenizer to tokenize input text
638
- syllables (bool): If to use sequence of syllables as input text
639
- is_syllable (bool, optional): Signal if the preprocessed inputs contain syllable data. Defaults to False.
640
- max_len (int, optional): Maximum length of tokenization. Defaults to 512.
641
-
642
- Returns:
643
- dict: tokenized and processed to tensors data
644
- """
645
- index = 1 if syllables and is_syllable else 0
646
- tokenizer.model_max_length = max_len
647
- data_ids = ["\n".join(
648
- [" ".join(
649
- SyllableMaker.syllabify(line.split('#')[-1])
650
- ) + (line[-1] if line[-1] in [',','.','!','?'] else '') if (syllables and not is_syllable and line) else line.split('#')[-1] for line in text['input_ids'][index].splitlines()[1:]]
651
- ) for text in batch ]
652
-
653
-
654
- tokenized = tokenizer(data_ids, return_tensors='pt', truncation=True, padding=True)
655
- input_ids = tokenized['input_ids']
656
- attention = tokenized["attention_mask"]
657
-
658
- rhyme=None
659
- if "rhyme" in batch[0].keys():
660
- rhyme = torch.tensor(np.asarray([TextAnalysis._rhyme_vector(text["rhyme"]) for text in batch], dtype=np.int32), dtype=torch.float32)
661
-
662
- year_bucket = None
663
- year = None
664
- if "year" in batch[0].keys():
665
- year_bucket = torch.tensor(np.asarray([TextAnalysis._publish_year_vector(text["year"]) for text in batch], dtype=np.int32), dtype=torch.float32)
666
- year = torch.tensor(np.asarray([ [int(text['year'])] if text['year'] != 'NaN' else [0] for text in batch], dtype=np.int32), dtype=torch.float32)
667
-
668
- return {
669
- "input_ids": input_ids,
670
- "attention_mask": attention,
671
- "rhyme": rhyme,
672
- "metre_ids": None,
673
- "year_bucket": year_bucket,
674
- 'year':year}
675
-
676
- @staticmethod
677
- def collate_meter(batch, tokenizer: PreTrainedTokenizerBase, syllables: bool, is_syllable: bool = False, max_len=512):
678
- index = 1 if syllables and is_syllable else 0
679
- tokenizer.model_max_length = max_len
680
- data_ids = []
681
- metre = []
682
- for datum in batch:
683
- data_ids += [
684
- " ".join(
685
- SyllableMaker.syllabify(line.split('#')[-1])
686
- ) + (line[-1] if line[-1] in [',','.','!','?'] else '') if (syllables and not is_syllable and line) else line.split('#')[-1] for line in datum['input_ids'][index].splitlines()[1:]
687
- ]
688
- if "metre_ids" in batch[0].keys():
689
- metre += [TextAnalysis._metre_vector(one_metre) for one_metre in datum['metre_ids']]
690
-
691
- tokenized = tokenizer(data_ids, return_tensors='pt', truncation=True, padding=True)
692
- input_ids = tokenized['input_ids']
693
- attention = tokenized["attention_mask"]
694
-
695
- metre_ids = None
696
- if len(metre) > 0:
697
- metre_ids = torch.tensor(np.asarray(metre, dtype=np.int32), dtype=torch.float32)
698
-
699
- return {
700
- "input_ids": input_ids,
701
- "attention_mask": attention,
702
- "rhyme": None,
703
- "metre_ids": metre_ids,
704
- "year_bucket": None,
705
- "year": None}
706
-
707
-
708
-
709
- def __init__(self, data_dir = "PoetGen\corpusCzechVerse-master\ccv", cache_dir='./',
710
- prompt_length=True, prompt_ending=True, prompt_verse=True, verse_len=[4,6], lower_case=True, val_data_rate=0.05, test_data_rate=0.05):
711
- """Construct the Dataloader and create Datasets
712
-
713
- Args:
714
- data_dir (str, optional): Path to data. Defaults to "PoetGen\corpusCzechVerse-master\ccv".
715
- cache_dir (str, optional): Path where to store processed data. Defaults to './'.
716
- prompt_length (bool, optional): If to prompt the syllable count. Defaults to True.
717
- prompt_ending (bool, optional): If to prompt verse ending. Defaults to True.
718
- prompt_verse (bool, optional): If to prompt rhyme schema. Defaults to True.
719
- verse_len (list, optional): Considered length of strophe. Defaults to [4,6].
720
- lower_case (bool, optional): If the string should be in lowercase. Defaults to True.
721
- val_data_rate (float, optional): Amount of data to be left for validation. Defaults to 0.05.
722
- """
723
- self.lower_case = lower_case
724
- self.data_dir = data_dir
725
- if os.path.isfile(os.path.join(cache_dir, "body_poet_data.json")) and os.path.isfile(os.path.join(cache_dir, "text_poet_data.json")) \
726
- and os.path.isfile(os.path.join(cache_dir, "val_body_poet_data.json")) and os.path.isfile(os.path.join(cache_dir, "val_text_poet_data.json")) \
727
- and os.path.isfile(os.path.join(cache_dir, "test_body_poet_data.json")) and os.path.isfile(os.path.join(cache_dir, "test_text_poet_data.json")) :
728
- self.create_empty()
729
- self.pytorch_dataset_body.data =list(json.load( open( os.path.join(cache_dir, "body_poet_data.json"), 'r')))
730
- self.pytorch_dataset_text.data =list(json.load( open( os.path.join(cache_dir, "text_poet_data.json"), 'r')))
731
- self.val_pytorch_dataset_body.data = list(json.load( open( os.path.join(cache_dir, "val_body_poet_data.json"), 'r')))
732
- self.val_pytorch_dataset_text.data = list(json.load( open( os.path.join(cache_dir, "val_text_poet_data.json"), 'r')))
733
- self.test_pytorch_dataset_body.data = list(json.load( open( os.path.join(cache_dir, "test_body_poet_data.json"), 'r')))
734
- self.test_pytorch_dataset_text.data = list(json.load( open( os.path.join(cache_dir, "test_text_poet_data.json"), 'r')))
735
- else:
736
- self.load_json_filenames(prompt_length, prompt_ending, prompt_verse, verse_len=verse_len, val_data_rate=val_data_rate, test_data_rate=test_data_rate)
737
- json.dump(self.pytorch_dataset_body.data, open( os.path.join(cache_dir, "body_poet_data.json"), 'w+'), indent = 6)
738
- json.dump(self.pytorch_dataset_text.data, open( os.path.join(cache_dir, "text_poet_data.json"), 'w+'), indent = 6)
739
- json.dump(self.val_pytorch_dataset_body.data, open( os.path.join(cache_dir, "val_body_poet_data.json"), 'w+'), indent = 6)
740
- json.dump(self.val_pytorch_dataset_text.data, open( os.path.join(cache_dir, "val_text_poet_data.json"), 'w+'), indent = 6)
741
- json.dump(self.test_pytorch_dataset_body.data, open( os.path.join(cache_dir, "test_body_poet_data.json"), 'w+'), indent = 6)
742
- json.dump(self.test_pytorch_dataset_text.data, open( os.path.join(cache_dir, "test_text_poet_data.json"), 'w+'), indent = 6)
743
-
744
- self.load_raw_()
745
-
746
-
747
-
748
- #if __name__ == "__main__":
749
- # Line Count
750
- # print(len(list(CorpusDatasetPytorch(os.path.abspath(os.path.join(os.path.dirname(__file__), "corpusCzechVerse", "ccv")) ).raw_dataset.get_text())))
751
- # Strophe Count
752
- # print(len(list(CorpusDatasetPytorch(os.path.abspath(os.path.join(os.path.dirname(__file__), "corpusCzechVerse", "ccv")) ).raw_dataset.get_part())))
753
- # Poem Count
754
- # print(len(list(CorpusDatasetPytorch(os.path.abspath(os.path.join(os.path.dirname(__file__), "corpusCzechVerse", "ccv")) ).raw_dataset.get_body())))
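For completeness, a sketch of what the commented-out block above would print if re-enabled (path resolution copied from the original; it still assumes the corpusCzechVerse checkout sits next to this file):

import os
from corpus_capsulated_datasets import CorpusDatasetPytorch

ccv_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "corpusCzechVerse", "ccv"))
corpus = CorpusDatasetPytorch(ccv_path)
print("verse lines:", len(list(corpus.raw_dataset.get_text())))
print("strophes:", len(list(corpus.raw_dataset.get_part())))
print("poems:", len(list(corpus.raw_dataset.get_body())))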