import os
import json

import numpy as np
import torch
from torch.utils.data import Dataset
from transformers import PreTrainedTokenizerBase, PreTrainedModel

from utils.poet_utils import StropheParams, SyllableMaker, TextAnalysis, TextManipulation

class CorpusDatasetPytorch:
    """Dataset class responsible for data loading."""

    class RawDataset:
        """Dataset distributing raw string data with no preprocessing."""

        def __init__(self, data_file_paths, lower_case: bool = True):
            """Construct the frame around raw data generation.

            Args:
                data_file_paths (list): list of paths to data files
                lower_case (bool, optional): if the resulting data should be lowercased. Defaults to True.
            """
            self._data_file_paths = data_file_paths
            self.lower_case = lower_case

        def gen_files(self):
            """Open the data files one by one.

            Yields:
                TextIO: open file object
            """
            for filename in self._data_file_paths:
                yield open(filename, 'r')

        def get_text(self):
            """Get individual verse lines of poetry.

            Yields:
                str: individual verse line
            """
            for step, file in enumerate(self.gen_files()):
                if step % 500 == 0:
                    print(f"Processing file {step}")
                datum = json.load(file)
                for data_line in datum:
                    for part_line in data_line['body']:
                        for text_line in part_line:
                            yield text_line['text'].lower() if self.lower_case else text_line['text']

        def get_part(self):
            """Get strophes of poetry.

            Yields:
                str: one strophe of poetry
            """
            for step, file in enumerate(self.gen_files()):
                if step % 500 == 0:
                    print(f"Processing file {step}")
                datum = json.load(file)
                for data_line in datum:
                    for part_line in data_line['body']:
                        part = []
                        for text_line in part_line:
                            part.append(text_line['text'])
                        yield "\n".join(part).lower() if self.lower_case else "\n".join(part)

        def get_body(self):
            """Get whole poems.

            Yields:
                str: one whole poem
            """
            for step, file in enumerate(self.gen_files()):
                if step % 500 == 0:
                    print(f"Processing file {step}")
                datum = json.load(file)
                for data_line in datum:
                    body = []
                    for part_line in data_line['body']:
                        for text_line in part_line:
                            body.append(text_line['text'])
                        body.append("\n")
                    yield "\n".join(body).lower() if self.lower_case else "\n".join(body)

    class TextDataset(Dataset):
        """Dataset of preprocessed verse lines.

        Args:
            Dataset (torch.utils.data.Dataset): Parent torch class for better integration with torch and huggingface
        """

        def __init__(self, data_file_paths, prompt_length=True, prompt_ending=True, lower_case=True, val_data_rate: float = 0.05, test_data_rate: float = 0.05):
            """Construct the class around the given data file paths and store parameters.

            Args:
                data_file_paths (list): list of paths to data files
                prompt_length (bool, optional): Whether to prompt the syllable count. Defaults to True.
                prompt_ending (bool, optional): Whether to prompt the verse ending. Defaults to True.
                lower_case (bool, optional): Whether the string should be lowercased. Defaults to True.
                val_data_rate (float, optional): Fraction of data to be set aside for validation. Defaults to 0.05.
                test_data_rate (float, optional): Fraction of data to be set aside for testing. Defaults to 0.05.
            """
            self._data_file_paths = data_file_paths
            self.prompt_length = prompt_length
            self.prompt_ending = prompt_ending
            self.lower_case = lower_case

            self.val_data_rate = val_data_rate
            self.test_data_rate = test_data_rate

            self.data = []
            self.validation_data = []
            self.test_data = []

        def gen_files(self):
            """Open the data files one by one.

            Yields:
                TextIO: open file object
            """
            for filename in self._data_file_paths:
                yield open(filename, 'r')

        @staticmethod
        def _vowels_and_endings(raw_text):
            """Get the number of syllables in a verse and its ending syllable.

            Args:
                raw_text (str): raw verse to analyze

            Returns:
                tuple: (number of syllables, ending syllable)
            """
            syllabs = SyllableMaker.syllabify(raw_text)
            vowels = len(syllabs)
            ending = syllabs[-1]
            return vowels, ending

        @staticmethod
        def _ending_vector(end):
            """Construct a one-hot encoded vector for the ending syllable.

            Args:
                end (str): ending syllable

            Returns:
                numpy.ndarray: one-hot encoded vector of the ending syllable
            """
            verse_end_vector = np.zeros(len(StropheParams.ENDS))
            if end in StropheParams.ENDS[:-1]:
                verse_end_vector[StropheParams.ENDS.index(end)] = 1
            else:
                # Unknown endings fall into the final catch-all slot.
                verse_end_vector[-1] = 1
            return verse_end_vector

        @staticmethod
        def _syllable_line(raw_text):
            """Construct a verse as a sequence of syllables.

            Args:
                raw_text (str): raw verse line

            Returns:
                str: verse line as a sequence of syllables
            """
            # Preserve trailing punctuation, which syllabification would drop.
            ending = raw_text[-1] if raw_text[-1] in [',', '.', '!', '?'] else ''
            return " ".join(SyllableMaker.syllabify(raw_text)) + ending

        def _construct_line(self, raw_text, metre):
            """Construct an individual content line.

            Args:
                raw_text (str): raw verse line
                metre (str): metre label of the verse

            Returns:
                str: processed verse line with line parameters
            """
            syllables = SyllableMaker.syllabify(raw_text)
            num_str = f"{len(syllables)} # " if self.prompt_length else ""
            verse_end = f"{syllables[-1]} # " if self.prompt_ending else ""
            metre_txt = f"{metre} # "
            return metre_txt + num_str + verse_end + raw_text
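
        # The resulting line has the shape
        #   "<metre> # <syllable count> # <last syllable> # <raw text>"
        # with the count and ending segments omitted when prompt_length or
        # prompt_ending is False.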

        def _introduce_phonetics(self, raw_text: str, phonetics):
            """Replace words in the verse with their phonetic transcription.

            Args:
                raw_text (str): raw verse line
                phonetics (dict): per-word phonetic annotations

            Returns:
                str: verse line with phonetic transcriptions substituted in
            """
            phonetic_text = raw_text
            for word in phonetics['words']:
                phonetic_text = phonetic_text.replace(word["token_lc"], word["phoebe"]) if self.lower_case else phonetic_text.replace(word["token"], word["phoebe"])
            return phonetic_text

        def _construct_syllable_line(self, raw_text, metre):
            """Construct an individual content line as a sequence of syllables.

            Args:
                raw_text (str): raw verse line
                metre (str): metre label of the verse

            Returns:
                str: processed verse line as a sequence of syllables with line parameters
            """
            # Preserve trailing punctuation, which syllabification would drop.
            ending = raw_text[-1] if raw_text[-1] in [',', '.', '!', '?'] else ''
            syllables = SyllableMaker.syllabify(raw_text)
            num_str = f"{len(syllables)} # " if self.prompt_length else ""
            verse_end = f"{syllables[-1]} # " if self.prompt_ending else ""
            metre_txt = f"{metre} # "
            return metre_txt + num_str + verse_end + " ".join(syllables) + ending

        def data_text_line_gen(self):
            """Preprocess and process the data for usage."""
            for step, file in enumerate(self.gen_files()):
                if step % 500 == 0:
                    print(f"Processing file {step}")
                datum = json.load(file)
                for data_line in datum:
                    for part_line in data_line['body']:
                        for text_line in part_line:
                            metre = StropheParams.METER_TRANSLATE.get(text_line["metre"][0]["type"], "N")

                            scanned_text = TextManipulation._remove_most_nonchar(text_line['text'], self.lower_case)

                            text_line_scanned = self._construct_line(scanned_text, metre)
                            syllable_line = self._construct_syllable_line(scanned_text, metre)

                            num_vowels, verse_end = self._vowels_and_endings(scanned_text)

                            # Randomly route each verse into train, test, or
                            # validation according to the configured rates.
                            rand_split = np.random.rand()
                            if rand_split > self.val_data_rate + self.test_data_rate:
                                target = self.data
                            elif rand_split < self.test_data_rate:
                                target = self.test_data
                            else:
                                target = self.validation_data
                            target.append({
                                "input_ids": [text_line_scanned, syllable_line],
                                "nums": [num_vowels],
                                "verse_end": verse_end,
                                "metre": metre
                            })

        def __len__(self):
            """Return the length of the training data.

            Returns:
                int: length of the training data
            """
            return len(self.data)

        def __getitem__(self, index):
            """Return the indexed item.

            Args:
                index (int): index of the item to return

            Returns:
                dict: dict with the indexed data
            """
            return self.data[index]

    class BodyDataset(Dataset):
        """Dataset of preprocessed strophes.

        Args:
            Dataset (torch.utils.data.Dataset): Parent torch class for better integration with torch and huggingface
        """

        def __init__(self, data_file_paths,
                     prompt_length=True, prompt_ending=True, prompt_verse=True, verse_len=[4, 6], lower_case=True, val_data_rate: float = 0.05, test_data_rate: float = 0.05):
            """Construct the class around the given data file paths and store parameters.

            Args:
                data_file_paths (list): list of paths to data files
                prompt_length (bool, optional): Whether to prompt the syllable count. Defaults to True.
                prompt_ending (bool, optional): Whether to prompt the verse ending. Defaults to True.
                prompt_verse (bool, optional): Whether to prompt the rhyme schema. Defaults to True.
                verse_len (list, optional): Considered strophe lengths. Defaults to [4, 6].
                lower_case (bool, optional): Whether the string should be lowercased. Defaults to True.
                val_data_rate (float, optional): Fraction of data to be set aside for validation. Defaults to 0.05.
                test_data_rate (float, optional): Fraction of data to be set aside for testing. Defaults to 0.05.
            """
            self._data_file_paths = data_file_paths
            self.prompt_length = prompt_length
            self.prompt_ending = prompt_ending
            self.prompt_verse = prompt_verse
            self.verse_len = verse_len
            self.lower_case = lower_case

            self.val_data_rate = val_data_rate
            self.test_data_rate = test_data_rate

            self.data = []
            self.validation_data = []
            self.test_data = []

        def gen_files(self):
            """Open the data files one by one.

            Yields:
                TextIO: open file object
            """
            for filename in self._data_file_paths:
                yield open(filename, 'r')

        def _construct_line(self, raw_text, metre):
            """Construct an individual content line.

            Args:
                raw_text (str): raw verse line
                metre (str): metre label of the verse

            Returns:
                str: processed verse line with line parameters
            """
            syllables = SyllableMaker.syllabify(raw_text)
            num_str = f"{len(syllables)} # " if self.prompt_length else ""
            verse_end = f"{syllables[-1]} # " if self.prompt_ending else ""
            metre_txt = f"{metre} # "
            return metre_txt + num_str + verse_end + raw_text

        def _construct_syllable_line(self, raw_text, metre):
            """Construct an individual content line as a sequence of syllables.

            Args:
                raw_text (str): raw verse line
                metre (str): metre label of the verse

            Returns:
                str: processed verse line as a sequence of syllables with line parameters
            """
            # Preserve trailing punctuation, which syllabification would drop.
            ending = raw_text[-1] if raw_text[-1] in [',', '.', '!', '?'] else ''
            syllables = SyllableMaker.syllabify(raw_text)
            num_str = f"{len(syllables)} # " if self.prompt_length else ""
            verse_end = f"{syllables[-1]} # " if self.prompt_ending else ""
            metre_txt = f"{metre} # "
            return metre_txt + num_str + verse_end + " ".join(syllables) + ending

        def data_body_gen(self):
            """Preprocess and process the data for usage."""
            for step, file in enumerate(self.gen_files()):
                if step % 500 == 0:
                    print(f"Processing file {step}")
                datum = json.load(file)

                for data_line in datum:
                    publish_year_text = TextManipulation._year_bucketor(data_line["biblio"]["year"])
                    publish_year_true = data_line["biblio"]["year"] if TextAnalysis._is_year(data_line["biblio"]["year"]) else 'NaN'
                    context = ["NO CONTEXT"]

                    for part_line in data_line['body']:
                        body = []
                        body_syllabs = []
                        rhyme = []
                        metres = []
                        i = 0
                        for text_line in part_line:
                            metre = StropheParams.METER_TRANSLATE.get(text_line["metre"][0]["type"], "J")
                            metres += [metre]

                            rhyme.append(text_line["rhyme"])

                            scanned_text = TextManipulation._remove_most_nonchar(text_line["text"], self.lower_case)

                            body.append(self._construct_line(scanned_text, metre))
                            body_syllabs.append(self._construct_syllable_line(scanned_text, metre))

                            i += 1

                            # Emit a strophe whenever one of the considered lengths is reached.
                            if i in self.verse_len:
                                rhyme_str = TextManipulation._rhyme_string(rhyme)

                                text = f"# {rhyme_str} # {publish_year_text}\n" + "\n".join(body) + "\n"
                                syllable_text = f"# {rhyme_str} # {publish_year_text}\n" + "\n".join(body_syllabs) + "\n"
                                context_text = "\n".join(context)

                                # Randomly route each strophe into train, test, or
                                # validation according to the configured rates.
                                rand_split = np.random.rand()
                                if rand_split > self.val_data_rate + self.test_data_rate:
                                    target = self.data
                                elif rand_split < self.test_data_rate:
                                    target = self.test_data
                                else:
                                    target = self.validation_data
                                target.append({
                                    "input_ids": [text, syllable_text],
                                    "context_ids": context_text,
                                    "year": publish_year_true,
                                    "rhyme": rhyme_str,
                                    "metre_ids": metres.copy()
                                })

                            # Start accumulating a fresh strophe once the longest
                            # considered length has been emitted.
                            if i == max(self.verse_len):
                                body = []
                                body_syllabs = []
                                rhyme = []
                                metres = []
                                i = 0
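
        # A single emitted training example has this shape (values are
        # illustrative; they depend on the corpus and on SyllableMaker):
        #
        #   {"input_ids": ["# AABB # <year bucket>\nJ # 8 # ... # <verse>\n...",
        #                  "<syllabified variant>"],
        #    "context_ids": "NO CONTEXT",
        #    "year": "1899",
        #    "rhyme": "AABB",
        #    "metre_ids": ["J", "J", "J", "J"]}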

        def __len__(self):
            """Return the length of the training data.

            Returns:
                int: length of the training data
            """
            return len(self.data)

        def __getitem__(self, index):
            """Return the indexed item.

            Args:
                index (int): index of the item to return

            Returns:
                dict: dict with the indexed data
            """
            return self.data[index]

    def get_filenames(self):
        """Get the paths of the data files.

        Returns:
            list: paths of the data files
        """
        data_filenames = os.listdir(self.data_dir)
        data_by_files = []
        for filename in data_filenames:
            file_path = os.path.join(self.data_dir, filename)
            data_by_files.append(file_path)
        return data_by_files

    def load_raw_(self):
        """Load the raw dataset with raw string data."""
        filenames = self.get_filenames()

        self.raw_dataset = CorpusDatasetPytorch.RawDataset(filenames, self.lower_case)

    def load_json_filenames(self, prompt_length, prompt_ending, prompt_verse, verse_len=[4, 6], val_data_rate=0.05, test_data_rate=0.05):
        """Load the verse and strophe datasets.

        Args:
            prompt_length (bool): Whether to prompt the syllable count.
            prompt_ending (bool): Whether to prompt the verse ending.
            prompt_verse (bool): Whether to prompt the rhyme schema.
            verse_len (list, optional): Considered strophe lengths. Defaults to [4, 6].
            val_data_rate (float, optional): Fraction of data to be set aside for validation. Defaults to 0.05.
            test_data_rate (float, optional): Fraction of data to be set aside for testing. Defaults to 0.05.
        """
        filenames = self.get_filenames()

        self.pytorch_dataset_body = CorpusDatasetPytorch.BodyDataset(filenames, prompt_ending=prompt_ending,
                                                                     prompt_length=prompt_length, prompt_verse=prompt_verse,
                                                                     verse_len=verse_len, lower_case=self.lower_case,
                                                                     val_data_rate=val_data_rate, test_data_rate=test_data_rate)
        self.pytorch_dataset_body.data_body_gen()

        self.pytorch_dataset_text = CorpusDatasetPytorch.TextDataset(filenames, prompt_ending=prompt_ending,
                                                                     prompt_length=prompt_length, lower_case=self.lower_case,
                                                                     val_data_rate=val_data_rate, test_data_rate=test_data_rate)
        self.pytorch_dataset_text.data_text_line_gen()

        # Move the validation and test splits into their own dataset holders.
        self.val_pytorch_dataset_body = CorpusDatasetPytorch.BodyDataset([])
        self.val_pytorch_dataset_text = CorpusDatasetPytorch.TextDataset([])

        self.val_pytorch_dataset_body.data = self.pytorch_dataset_body.validation_data
        self.val_pytorch_dataset_text.data = self.pytorch_dataset_text.validation_data

        self.pytorch_dataset_text.validation_data = []
        self.pytorch_dataset_body.validation_data = []

        self.test_pytorch_dataset_body = CorpusDatasetPytorch.BodyDataset([])
        self.test_pytorch_dataset_text = CorpusDatasetPytorch.TextDataset([])

        self.test_pytorch_dataset_body.data = self.pytorch_dataset_body.test_data
        self.test_pytorch_dataset_text.data = self.pytorch_dataset_text.test_data

        self.pytorch_dataset_text.test_data = []
        self.pytorch_dataset_body.test_data = []

    def create_empty(self):
        """Create empty holders for a possible load of processed data from file."""
        self.pytorch_dataset_body = CorpusDatasetPytorch.BodyDataset([])
        self.pytorch_dataset_text = CorpusDatasetPytorch.TextDataset([])
        self.val_pytorch_dataset_body = CorpusDatasetPytorch.BodyDataset([])
        self.val_pytorch_dataset_text = CorpusDatasetPytorch.TextDataset([])
        self.test_pytorch_dataset_body = CorpusDatasetPytorch.BodyDataset([])
        self.test_pytorch_dataset_text = CorpusDatasetPytorch.TextDataset([])

    @staticmethod
    def collate(batch, tokenizer: PreTrainedTokenizerBase, max_len=1024, max_context=1024, mask_rate=0.0, syllables: bool = False, format: str = 'METER_VERSE'):
        """Process data for usage in an LM.

        Args:
            batch (list): Batch with the selected data points
            tokenizer (PreTrainedTokenizerBase): Tokenizer to tokenize the input text
            max_len (int, optional): Maximum length of tokenization. Defaults to 1024.
            max_context (int, optional): Maximum length of tokenization of the context. Defaults to 1024.
            mask_rate (float, optional): Rate with which to mask data. Defaults to 0.0.
            syllables (bool, optional): Whether to use the sequence of syllables as input text. Defaults to False.
            format (str, optional): Line format to emit ('METER_VERSE', 'BASIC' or 'VERSE_PAR'). Defaults to 'METER_VERSE'.

        Returns:
            dict: Tokenized data processed to tensors
        """
        index = 1 if syllables else 0

        tokenizer.model_max_length = max_len
        if batch[0]['input_ids'][0].startswith("#"):
            # Strophe data: optionally strip the per-line parameters down to the
            # requested format before tokenization.
            data = [text['input_ids'][index] for text in batch]
            if format == "BASIC":
                data = ["\n".join(
                    [line + f" # {datum.splitlines()[1].split()[0]}"
                     if i == 0 else line.split('#')[-1] for i, line in enumerate(datum.splitlines())]
                ) + tokenizer.eos_token for datum in data]
            elif format == "VERSE_PAR":
                data = ["\n".join(
                    [line + f" # {datum.splitlines()[1].split()[0]}"
                     if i == 0 else "#".join(line.split('#')[1:]) for i, line in enumerate(datum.splitlines())]
                ) + tokenizer.eos_token for datum in data]
            else:
                data = [text['input_ids'][index] + tokenizer.eos_token for text in batch]

            tokenized = tokenizer(data, return_tensors='pt', truncation=True, padding=True)
            input_ids = tokenized['input_ids']
            attention = tokenized["attention_mask"]

        else:
            # Verse data: tokenize as-is.
            tokenized = tokenizer([text['input_ids'][index] + tokenizer.eos_token for text in batch], return_tensors='pt', truncation=True, padding=True)
            input_ids = tokenized['input_ids']
            attention = tokenized["attention_mask"]

        nums = None
        if "nums" in batch[0].keys():
            nums = torch.tensor(np.asarray([text['nums'] for text in batch], dtype=np.int32), dtype=torch.float32)

        rhyme = None
        if "rhyme" in batch[0].keys():
            rhyme = torch.tensor(np.asarray([TextAnalysis._rhyme_vector(text["rhyme"]) for text in batch], dtype=np.int32), dtype=torch.float32)

        verse_end = None
        if "verse_end" in batch[0].keys():
            verse_end = torch.tensor(np.asarray([CorpusDatasetPytorch.TextDataset._ending_vector(text["verse_end"]) for text in batch], dtype=np.int32), dtype=torch.float32)

        year = None
        if "year" in batch[0].keys():
            year = torch.tensor(np.asarray([TextAnalysis._publish_year_vector(text["year"]) for text in batch], dtype=np.int32), dtype=torch.float32)

        metre = None
        if "metre" in batch[0].keys():
            metre = torch.tensor(np.asarray([TextAnalysis._metre_vector(text["metre"]) for text in batch], dtype=np.int32), dtype=torch.float32)

        context_ids = None
        context_attention_mask = None
        if "context_ids" in batch[0].keys():
            tokenizer.model_max_length = max_context
            tokenized_context = tokenizer([text['context_ids'] + tokenizer.eos_token for text in batch], return_tensors='pt', truncation=True, padding=True)
            context_ids = tokenized_context['input_ids']
            context_attention_mask = tokenized_context['attention_mask']

        return {
            "input_ids": input_ids,
            "labels": input_ids.type(torch.LongTensor),
            "attention_mask": attention,
            "context_ids": context_ids,
            "context_attention_mask": context_attention_mask,
            "nums": nums,
            "rhyme": rhyme,
            "verse_end": verse_end,
            "year": year,
            "metre": metre}

    @staticmethod
    def collate_distil(batch, tokenizer: PreTrainedTokenizerBase, surrogate_model: PreTrainedModel = None, surrogate_model_device=None, max_len=1024):
        """Process data for usage in knowledge distillation.

        Args:
            batch (list): Batch with the selected data points
            tokenizer (PreTrainedTokenizerBase): Tokenizer to tokenize the input text
            surrogate_model (PreTrainedModel, optional): Model whose hidden states are to be replicated. Defaults to None.
            surrogate_model_device (optional): Device the surrogate model lives on. Defaults to None.
            max_len (int, optional): Maximum length of tokenization. Defaults to 1024.

        Returns:
            dict: Tokenized data processed to tensors, with the surrogate model's hidden states
        """
        tokenizer.model_max_length = max_len
        tokenized = tokenizer([text['input_ids'][0] + tokenizer.eos_token for text in batch], return_tensors='pt', truncation=True, padding=True)
        input_ids = tokenized['input_ids']
        attention = tokenized["attention_mask"]

        with torch.no_grad():
            model_hidden_states = surrogate_model(input_ids=input_ids.to(surrogate_model_device),
                                                  attention_mask=attention.to(surrogate_model_device),
                                                  labels=input_ids.type(torch.LongTensor).to(surrogate_model_device))['hidden_states']
            model_hidden_states = [hidden.cpu().detach() for hidden in model_hidden_states]

        return {
            "input_ids": input_ids,
            "labels": input_ids.type(torch.LongTensor),
            "attention_mask": attention,
            "to_replicate_states": model_hidden_states
        }

    @staticmethod
    def collate_validator(batch, tokenizer: PreTrainedTokenizerBase, syllables: bool, is_syllable: bool = False, max_len=512):
        """Process data for use in an LM for metre, rhyme and year prediction.

        Args:
            batch (list): Batch with the selected data points
            tokenizer (PreTrainedTokenizerBase): Tokenizer to tokenize the input text
            syllables (bool): Whether to use the sequence of syllables as input text
            is_syllable (bool, optional): Signals whether the preprocessed inputs contain syllable data. Defaults to False.
            max_len (int, optional): Maximum length of tokenization. Defaults to 512.

        Returns:
            dict: Tokenized data processed to tensors
        """
        index = 1 if syllables and is_syllable else 0
        tokenizer.model_max_length = max_len
        # Strip the per-line parameters, syllabifying on the fly when syllable
        # input is requested but not already precomputed.
        data_ids = ["\n".join(
            [" ".join(
                SyllableMaker.syllabify(line.split('#')[-1])
            ) + (line[-1] if line[-1] in [',', '.', '!', '?'] else '') if (syllables and not is_syllable and line) else line.split('#')[-1] for line in text['input_ids'][index].splitlines()[1:]]
        ) for text in batch]

        tokenized = tokenizer(data_ids, return_tensors='pt', truncation=True, padding=True)
        input_ids = tokenized['input_ids']
        attention = tokenized["attention_mask"]

        rhyme = None
        if "rhyme" in batch[0].keys():
            rhyme = torch.tensor(np.asarray([TextAnalysis._rhyme_vector(text["rhyme"]) for text in batch], dtype=np.int32), dtype=torch.float32)

        year_bucket = None
        year = None
        if "year" in batch[0].keys():
            year_bucket = torch.tensor(np.asarray([TextAnalysis._publish_year_vector(text["year"]) for text in batch], dtype=np.int32), dtype=torch.float32)
            year = torch.tensor(np.asarray([[int(text['year'])] if text['year'] != 'NaN' else [0] for text in batch], dtype=np.int32), dtype=torch.float32)

        return {
            "input_ids": input_ids,
            "attention_mask": attention,
            "rhyme": rhyme,
            "metre_ids": None,
            "year_bucket": year_bucket,
            "year": year}

    @staticmethod
    def collate_meter(batch, tokenizer: PreTrainedTokenizerBase, syllables: bool, is_syllable: bool = False, max_len=512):
        """Process data for use in an LM for per-line metre prediction.

        Args:
            batch (list): Batch with the selected data points
            tokenizer (PreTrainedTokenizerBase): Tokenizer to tokenize the input text
            syllables (bool): Whether to use the sequence of syllables as input text
            is_syllable (bool, optional): Signals whether the preprocessed inputs contain syllable data. Defaults to False.
            max_len (int, optional): Maximum length of tokenization. Defaults to 512.

        Returns:
            dict: Tokenized data processed to tensors, with one metre vector per line
        """
        index = 1 if syllables and is_syllable else 0
        tokenizer.model_max_length = max_len
        data_ids = []
        metre = []
        for datum in batch:
            data_ids += [
                " ".join(
                    SyllableMaker.syllabify(line.split('#')[-1])
                ) + (line[-1] if line[-1] in [',', '.', '!', '?'] else '') if (syllables and not is_syllable and line) else line.split('#')[-1] for line in datum['input_ids'][index].splitlines()[1:]
            ]
            if "metre_ids" in batch[0].keys():
                metre += [TextAnalysis._metre_vector(one_metre) for one_metre in datum['metre_ids']]

        tokenized = tokenizer(data_ids, return_tensors='pt', truncation=True, padding=True)
        input_ids = tokenized['input_ids']
        attention = tokenized["attention_mask"]

        metre_ids = None
        if len(metre) > 0:
            metre_ids = torch.tensor(np.asarray(metre, dtype=np.int32), dtype=torch.float32)

        return {
            "input_ids": input_ids,
            "attention_mask": attention,
            "rhyme": None,
            "metre_ids": metre_ids,
            "year_bucket": None,
            "year": None}

    def __init__(self, data_dir=r"PoetGen\corpusCzechVerse-master\ccv", cache_dir='./',
                 prompt_length=True, prompt_ending=True, prompt_verse=True, verse_len=[4, 6], lower_case=True, val_data_rate=0.05, test_data_rate=0.05):
        r"""Construct the dataloader and create the datasets.

        Args:
            data_dir (str, optional): Path to the data. Defaults to "PoetGen\corpusCzechVerse-master\ccv".
            cache_dir (str, optional): Path where to store the processed data. Defaults to './'.
            prompt_length (bool, optional): Whether to prompt the syllable count. Defaults to True.
            prompt_ending (bool, optional): Whether to prompt the verse ending. Defaults to True.
            prompt_verse (bool, optional): Whether to prompt the rhyme schema. Defaults to True.
            verse_len (list, optional): Considered strophe lengths. Defaults to [4, 6].
            lower_case (bool, optional): Whether the string should be lowercased. Defaults to True.
            val_data_rate (float, optional): Fraction of data to be set aside for validation. Defaults to 0.05.
            test_data_rate (float, optional): Fraction of data to be set aside for testing. Defaults to 0.05.
        """
        self.lower_case = lower_case
        self.data_dir = data_dir
        # Reuse cached processed data when all split files are present;
        # otherwise process the corpus and cache the result.
        if os.path.isfile(os.path.join(cache_dir, "body_poet_data.json")) and os.path.isfile(os.path.join(cache_dir, "text_poet_data.json")) \
                and os.path.isfile(os.path.join(cache_dir, "val_body_poet_data.json")) and os.path.isfile(os.path.join(cache_dir, "val_text_poet_data.json")) \
                and os.path.isfile(os.path.join(cache_dir, "test_body_poet_data.json")) and os.path.isfile(os.path.join(cache_dir, "test_text_poet_data.json")):
            self.create_empty()
            self.pytorch_dataset_body.data = list(json.load(open(os.path.join(cache_dir, "body_poet_data.json"), 'r')))
            self.pytorch_dataset_text.data = list(json.load(open(os.path.join(cache_dir, "text_poet_data.json"), 'r')))
            self.val_pytorch_dataset_body.data = list(json.load(open(os.path.join(cache_dir, "val_body_poet_data.json"), 'r')))
            self.val_pytorch_dataset_text.data = list(json.load(open(os.path.join(cache_dir, "val_text_poet_data.json"), 'r')))
            self.test_pytorch_dataset_body.data = list(json.load(open(os.path.join(cache_dir, "test_body_poet_data.json"), 'r')))
            self.test_pytorch_dataset_text.data = list(json.load(open(os.path.join(cache_dir, "test_text_poet_data.json"), 'r')))
        else:
            self.load_json_filenames(prompt_length, prompt_ending, prompt_verse, verse_len=verse_len, val_data_rate=val_data_rate, test_data_rate=test_data_rate)
            json.dump(self.pytorch_dataset_body.data, open(os.path.join(cache_dir, "body_poet_data.json"), 'w+'), indent=6)
            json.dump(self.pytorch_dataset_text.data, open(os.path.join(cache_dir, "text_poet_data.json"), 'w+'), indent=6)
            json.dump(self.val_pytorch_dataset_body.data, open(os.path.join(cache_dir, "val_body_poet_data.json"), 'w+'), indent=6)
            json.dump(self.val_pytorch_dataset_text.data, open(os.path.join(cache_dir, "val_text_poet_data.json"), 'w+'), indent=6)
            json.dump(self.test_pytorch_dataset_body.data, open(os.path.join(cache_dir, "test_body_poet_data.json"), 'w+'), indent=6)
            json.dump(self.test_pytorch_dataset_text.data, open(os.path.join(cache_dir, "test_text_poet_data.json"), 'w+'), indent=6)

        self.load_raw_()
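

# Illustrative end-to-end sketch (paths and model name are placeholders, not
# shipped with this module):
#
#   from transformers import AutoTokenizer
#
#   tokenizer = AutoTokenizer.from_pretrained("gpt2")
#   tokenizer.pad_token = tokenizer.eos_token
#   dataset = CorpusDatasetPytorch(data_dir="corpusCzechVerse-master/ccv", cache_dir="./cache")
#   batch = CorpusDatasetPytorch.collate(dataset.pytorch_dataset_text.data[:4],
#                                        tokenizer=tokenizer)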