poet-validators / corpus_capsulated_datasets.py
jinymusim's picture
New Gen
13cea7c verified
import os
import json
import numpy as np
import torch
from utils.poet_utils import StropheParams, SyllableMaker, TextAnalysis, TextManipulation
from torch.utils.data import Dataset
from transformers import PreTrainedTokenizerBase, PreTrainedModel
#TODO: Maybe replace year of book being written for year Author was born
class CorpusDatasetPytorch:
"""Dataset class responsible for data loading.
"""
class RawDataset:
"""Dataset distributing raw sting data with no preprocessing
"""
def __init__(self, data_file_paths, lower_case:bool = True):
"""Construct the frame around Raw data generation
Args:
data_file_paths (_type_): list of paths to data files
lower_case (bool, optional): if resulting data should be in lowercase. Defaults to True.
"""
self._data_file_paths = data_file_paths
self.lower_case = lower_case
def gen_files(self):
"""Get individual opened files
Yields:
_type_: open file object
"""
for filename in self._data_file_paths:
yield open(filename, 'r')
def get_text(self):
"""Get lines of text of poetry
Yields:
str: individual verse line
"""
for step,file in enumerate(self.gen_files()):
if step % 500 == 0:
print(f"Processing file {step}")
datum = json.load(file)
for data_line in datum:
for part_line in data_line['body']:
for text_line in part_line:
yield text_line['text'].lower() if self.lower_case else text_line['text']
def get_part(self):
"""Get strophe of poetry
Yields:
str: 1 strophe of poetry
"""
for step,file in enumerate(self.gen_files()):
if step % 500 == 0:
print(f"Processing file {step}")
datum = json.load(file)
for data_line in datum:
for part_line in data_line['body']:
part = []
for text_line in part_line:
part.append(text_line['text'])
yield "\n".join(part).lower() if self.lower_case else "\n".join(part)
def get_body(self):
"""Get whole poem
Yields:
str: 1 whole poem
"""
for step,file in enumerate(self.gen_files()):
if step % 500 == 0:
print(f"Processing file {step}")
datum = json.load(file)
for data_line in datum:
body = []
for part_line in data_line['body']:
for text_line in part_line:
body.append(text_line['text'])
body.append("\n")
yield "\n".join(body).lower() if self.lower_case else "\n".join(body)
class TextDataset(Dataset):
"""Dataset of preprocessed verse lines
Args:
Dataset (_type_): Dataset is child of torch class for better integration with torch and huggingface
"""
def __init__(self, data_file_paths, prompt_length=True, prompt_ending=True, lower_case=True, val_data_rate: float = 0.05, test_data_rate: float = 0.05):
"""Construct the class our given data files path and store variables
Args:
data_file_paths (_type_): list of paths to data files
prompt_length (bool, optional): If to prompt the syllable count. Defaults to True.
prompt_ending (bool, optional): If to prompt verse ending. Defaults to True.
lower_case (bool, optional): If the string should be in lowercase. Defaults to True.
val_data_rate (float, optional): Amount of data to be left for validation. Defaults to 0.05.
test_data_rate (float, optional): Amount of data to be left for validation. Defaults to 0.05.
"""
self._data_file_paths = data_file_paths
self.prompt_length = prompt_length
self.prompt_ending = prompt_ending
self.lower_case = lower_case
self.val_data_rate = val_data_rate
self.test_data_rate = test_data_rate
self.data = []
self.validation_data = []
self.test_data = []
def gen_files(self):
"""Get individual opened files
Yields:
_type_: open file object
"""
for filename in self._data_file_paths:
yield open(filename, 'r')
@staticmethod
def _vowels_and_endings(raw_text):
"""Get the verse ending and number of syllables in verse
Args:
raw_text (str): raw verse to analyze
Returns:
tuple: number of syllables, ending syllable
"""
syllabs = SyllableMaker.syllabify(raw_text)
vowels = len(syllabs) #INFO: Now counts the number of syllables
ending = syllabs[-1]
return vowels, ending
@staticmethod
def _ending_vector(end):
"""Construct One-hot encoded vector for ending syllable
Args:
end (str): Ending syllable
Returns:
numpy.ndarray: One-hot encoded vector of ending syllable
"""
verse_end_vector = np.zeros(len(StropheParams.ENDS))
if end in StropheParams.ENDS[:-1]:
verse_end_vector[StropheParams.ENDS.index(end)] = 1
else:
verse_end_vector[-1] = 1
return verse_end_vector
@staticmethod
def _syllable_line(raw_text):
"""Construct verse as sequence of syllables
Args:
raw_text (str): raw verse line
Returns:
str: Verse line as sequence of syllables
"""
ending = raw_text[-1] if raw_text[-1] in [',','.','!','?'] else ''
return " ".join(SyllableMaker.syllabify(raw_text)) + ending
def _construct_line(self, raw_text, metre):
"""Construct individual content line
Args:
raw_text (str): raw verse line
Returns:
str: Processed verse line with line parameters
"""
syllables = SyllableMaker.syllabify(raw_text)
num_str = f"{len(syllables)} # " if self.prompt_length else ""
verse_end = f"{syllables[-1]} # " if self.prompt_ending else ""
metre_txt = f"{metre} # "
return metre_txt + num_str + verse_end + raw_text
def _introduce_phonetics(self, raw_text:str, phonetics):
phonetic_text = raw_text
for word in phonetics['words']:
phonetic_text = phonetic_text.replace(f'{word["token_lc"]}', f'{word["phoebe"]}') if self.lower_case else phonetic_text.replace(f'{word["token"]}', f'{word["phoebe"]}')
return phonetic_text
def _construct_syllable_line(self, raw_text, metre):
"""Construct individual content line as sequence of syllables
Args:
raw_text (str): raw verse line
Returns:
str: Processed verse line as sequence of syllables with line parameters
"""
ending = raw_text[-1] if raw_text[-1] in [',','.','!','?'] else ''
syllables = SyllableMaker.syllabify(raw_text)
num_str = f"{len(syllables)} # " if self.prompt_length else ""
verse_end = f"{syllables[-1]} # " if self.prompt_ending else ""
metre_txt = f"{metre} # "
return metre_txt+ num_str + verse_end + " ".join(syllables) + ending
def data_text_line_gen(self):
"""Preprocess and process data for usage
"""
for step,file in enumerate(self.gen_files()):
if step % 500 == 0:
print(f"Processing file {step}")
datum = json.load(file)
for data_line in datum:
for part_line in data_line['body']:
for text_line in part_line:
metre = StropheParams.METER_TRANSLATE.get(text_line["metre"][0]["type"], "N")
scanned_text = TextManipulation._remove_most_nonchar(text_line['text'], self.lower_case)
text_line_scanned = self._construct_line(scanned_text, metre)
syllable_line = self._construct_syllable_line(scanned_text, metre)
#phonetic_text = self._introduce_phonetics(scanned_text, text_line)
num_vowels, verse_end = self._vowels_and_endings(scanned_text)
# Based on result of random chose proper set. Because data are large enough, will result in wanted split.
rand_split = np.random.rand()
if rand_split > self.val_data_rate + self.test_data_rate:
self.data.append({
"input_ids" : [text_line_scanned,syllable_line],
"nums": [num_vowels],
"verse_end": verse_end,
"metre": metre
})
elif rand_split < self.test_data_rate:
self.test_data.append({
"input_ids" : [text_line_scanned,syllable_line],
"nums": [num_vowels],
"verse_end": verse_end,
"metre": metre
})
else:
self.validation_data.append({
"input_ids" : [text_line_scanned,syllable_line],
"nums": [num_vowels],
"verse_end": verse_end,
"metre": metre
})
def __len__(self):
"""Return length of training data
Returns:
int: length of training data
"""
return len(self.data)
def __getitem__(self, index):
"""return indexed item
Args:
index (int): index from where to return
Returns:
dict: dict with indexed data
"""
return self.data[index]
class BodyDataset(Dataset):
"""Dataset of preprocessed strophe
Args:
Dataset (_type_): Dataset is child of torch class for better integration with torch and huggingface
"""
def __init__(self, data_file_paths,
prompt_length=True, prompt_ending=True, prompt_verse=True, verse_len=[4,6], lower_case=True, val_data_rate: float = 0.05, test_data_rate: float = 0.05):
"""Construct the class our given data files path and store variables
Args:
data_file_paths (_type_): list of paths to data files
prompt_length (bool, optional): If to prompt the syllable count. Defaults to True.
prompt_ending (bool, optional): If to prompt verse ending. Defaults to True.
prompt_verse (bool, optional): If to prompt rhyme schema . Defaults to True.
verse_len (list, optional): Considered length of strophe. Defaults to [4,6].
lower_case (bool, optional): If the string should be in lowercase. Defaults to True.
val_data_rate (float, optional): Amount of data to be left for validation. Defaults to 0.05.
test_data_rate (float, optional): Amount of data to be left for validation. Defaults to 0.05.
"""
self._data_file_paths = data_file_paths
self.prompt_length = prompt_length
self.prompt_ending = prompt_ending
self.prompt_verse = prompt_verse
self.verse_len = verse_len
self.lower_case = lower_case
self.val_data_rate = val_data_rate
self.test_data_rate = test_data_rate
self.data = []
self.validation_data = []
self.test_data = []
def gen_files(self):
"""Get individual opened files
Yields:
_type_: open file object
"""
for filename in self._data_file_paths:
yield open(filename, 'r')
def _construct_line(self, raw_text, metre):
"""Construct individual content line
Args:
raw_text (str): raw verse line
Returns:
str: Processed verse line with line parameters
"""
syllables = SyllableMaker.syllabify(raw_text)
num_str = f"{len(syllables)} # " if self.prompt_length else ""
verse_end = f"{syllables[-1]} # " if self.prompt_ending else ""
metre_txt = f"{metre} # "
return metre_txt + num_str + verse_end + raw_text
def _construct_syllable_line(self, raw_text, metre):
"""Construct individual content line as sequence of syllables
Args:
raw_text (str): raw verse line
Returns:
str: Processed verse line as sequence of syllables with line parameters
"""
ending = raw_text[-1] if raw_text[-1] in [',','.','!','?'] else ''
syllables = SyllableMaker.syllabify(raw_text)
num_str = f"{len(syllables)} # " if self.prompt_length else ""
verse_end = f"{syllables[-1]} # " if self.prompt_ending else ""
metre_txt = f"{metre} # "
return metre_txt + num_str + verse_end + " ".join(syllables) + ending
def data_body_gen(self):
"""Preprocess and process data for usage
"""
for step,file in enumerate(self.gen_files()):
if step % 500 == 0:
print(f"Processing file {step}")
datum = json.load(file)
for data_line in datum:
publish_year_text = TextManipulation._year_bucketor(data_line["biblio"]["year"])
publish_year_true = data_line["biblio"]["year"] if TextAnalysis._is_year(data_line["biblio"]["year"]) else 'NaN'
context = ["NO CONTEXT"]
for part_line in data_line['body']:
body = []
body_syllabs = []
rhyme= []
metres = []
i = 0
for text_line in part_line:
# In rare cases multiple, but from searching only 1 metre per line
metre = StropheParams.METER_TRANSLATE.get(text_line["metre"][0]["type"], "J")
metres += [metre]
rhyme.append(text_line["rhyme"])
scanned_text = TextManipulation._remove_most_nonchar(text_line["text"], self.lower_case)
body.append(self._construct_line(scanned_text,metre))
body_syllabs.append(self._construct_syllable_line(scanned_text,metre))
i+=1
if i in self.verse_len:
rhyme_str = TextManipulation._rhyme_string(rhyme)
text = f"# {rhyme_str} # {publish_year_text}\n" + "\n".join(body) + "\n"
syllable_text = f"# {rhyme_str} # {publish_year_text}\n" + "\n".join(body_syllabs) + "\n"
context_text= "\n".join(context)
rand_split = np.random.rand()
if rand_split > self.val_data_rate + self.test_data_rate:
self.data.append({
"input_ids" : [text,syllable_text],
"context_ids" : context_text,
"year": publish_year_true,
"rhyme": rhyme_str,
"metre_ids" : metres.copy()
})
elif rand_split < self.test_data_rate:
self.test_data.append({
"input_ids" : [text,syllable_text],
"context_ids" : context_text,
"year": publish_year_true,
"rhyme": rhyme_str,
"metre_ids" : metres.copy()
})
else:
self.validation_data.append({
"input_ids" : [text,syllable_text],
"context_ids" : context_text,
"year": publish_year_true,
"rhyme": rhyme_str,
"metre_ids" : metres.copy()
})
if i == max(self.verse_len):
body = []
body_syllabs = []
rhyme = []
metres = []
i=0
def __len__(self):
"""Return length of training data
Returns:
int: length of training data
"""
return len(self.data)
def __getitem__(self, index):
"""return indexed item
Args:
index (int): index from where to return
Returns:
dict: dict with indexed data
"""
return self.data[index]
def get_filenames(self):
"""Get paths of data files
Returns:
list: Paths of data files
"""
data_filenames = os.listdir(self.data_dir)
data_by_files = []
for filename in data_filenames:
file_path = os.path.join(self.data_dir, filename)
data_by_files.append(file_path)
return data_by_files
def load_raw_(self):
"""Load Raw dataset with raw string data
"""
filenames = self.get_filenames()
self.raw_dataset = CorpusDatasetPytorch.RawDataset(filenames, self.lower_case)
def load_json_filenames(self, prompt_length, prompt_ending, prompt_verse, verse_len=[4,6], val_data_rate=0.05, test_data_rate=0.05):
"""Load Verse and Strophe datasets
Args:
prompt_length (bool, optional): If to prompt the syllable count. Defaults to True.
prompt_ending (bool, optional): If to prompt verse ending. Defaults to True.
prompt_verse (bool, optional): If to prompt rhyme schema . Defaults to True.
verse_len (list, optional): Considered length of strophe. Defaults to [4,6].
val_data_rate (float, optional): If the string should be in lowercase. Defaults to 0.1.
"""
filenames = self.get_filenames()
self.pytorch_dataset_body = CorpusDatasetPytorch.BodyDataset(filenames, prompt_ending=prompt_ending,
prompt_length=prompt_length, prompt_verse=prompt_verse,
verse_len=verse_len, lower_case=self.lower_case,
val_data_rate=val_data_rate, test_data_rate=test_data_rate)
self.pytorch_dataset_body.data_body_gen()
self.pytorch_dataset_text = CorpusDatasetPytorch.TextDataset(filenames, prompt_ending=prompt_ending,
prompt_length=prompt_length, lower_case=self.lower_case,
val_data_rate=val_data_rate, test_data_rate=test_data_rate)
self.pytorch_dataset_text.data_text_line_gen()
self.val_pytorch_dataset_body = CorpusDatasetPytorch.BodyDataset([])
self.val_pytorch_dataset_text = CorpusDatasetPytorch.TextDataset([])
self.val_pytorch_dataset_body.data = self.pytorch_dataset_body.validation_data
self.val_pytorch_dataset_text.data = self.pytorch_dataset_text.validation_data
self.pytorch_dataset_text.validation_data = []
self.pytorch_dataset_body.validation_data = []
self.test_pytorch_dataset_body = CorpusDatasetPytorch.BodyDataset([])
self.test_pytorch_dataset_text = CorpusDatasetPytorch.TextDataset([])
self.test_pytorch_dataset_body.data = self.pytorch_dataset_body.test_data
self.test_pytorch_dataset_text.data = self.pytorch_dataset_text.test_data
self.pytorch_dataset_text.test_data = []
self.pytorch_dataset_body.test_data = []
def create_empty(self):
"""Create empty holder for possible load of processed data from file
"""
self.pytorch_dataset_body = CorpusDatasetPytorch.BodyDataset([])
self.pytorch_dataset_text = CorpusDatasetPytorch.TextDataset([])
self.val_pytorch_dataset_body = CorpusDatasetPytorch.BodyDataset([])
self.val_pytorch_dataset_text = CorpusDatasetPytorch.TextDataset([])
self.test_pytorch_dataset_body = CorpusDatasetPytorch.BodyDataset([])
self.test_pytorch_dataset_text = CorpusDatasetPytorch.TextDataset([])
@staticmethod
def collate(batch, tokenizer: PreTrainedTokenizerBase ,max_len = 1024, max_context = 1024 ,mask_rate = 0.0, syllables: bool = False, format: str = 'METER_VERSE'):
"""Process data for usage in LM
Args:
batch (_type_): Batch with selected data points
tokenizer (PreTrainedTokenizerBase): tokenizer to tokenize input text
max_len (int, optional): Maximum length of tokenization. Defaults to 1024.
max_context (int, optional): Maximum length of tokenization of context. Defaults to 1024.
mask_rate (float, optional): Rate in with to mask data. Defaults to 0.0.
syllables (bool, optional): If to use sequence of syllables as input text. Defaults to False.
Returns:
dict: tokenized and processed to tensors data
"""
index = 1 if syllables else 0
tokenizer.model_max_length = max_len
if batch[0]['input_ids'][0].startswith("#"):
data = [text['input_ids'][index] for text in batch]
if format == "BASIC":
data = ["\n".join
(
[line + f" # {datum.splitlines()[1].split()[0]}"
if i==0 else line.split('#')[-1] for i, line in enumerate(datum.splitlines())]
) + tokenizer.eos_token for j, datum in enumerate(data)
]
elif format == "VERSE_PAR":
data = ["\n".join
(
[line + f" # {datum.splitlines()[1].split()[0]}"
if i==0 else "#".join(line.split('#')[1:]) for i, line in enumerate(datum.splitlines())]
) + tokenizer.eos_token for j, datum in enumerate(data)
]
else:
data = [text['input_ids'][index] + tokenizer.eos_token for text in batch]
tokenized = tokenizer(data,return_tensors='pt', truncation=True, padding=True)
input_ids = tokenized['input_ids']
attention = tokenized["attention_mask"]
else:
tokenized = tokenizer([text['input_ids'][index] + tokenizer.eos_token for text in batch],return_tensors='pt', truncation=True, padding=True)
input_ids = tokenized['input_ids']
attention = tokenized["attention_mask"]
nums = None
if "nums" in batch[0].keys():
nums = torch.tensor(np.asarray([text['nums'] for text in batch], dtype=np.int32), dtype=torch.float32)
rhyme=None
if "rhyme" in batch[0].keys():
rhyme = torch.tensor(np.asarray([TextAnalysis._rhyme_vector(text["rhyme"]) for text in batch], dtype=np.int32), dtype=torch.float32)
verse_end = None
if "verse_end" in batch[0].keys():
verse_end = torch.tensor(np.asarray([CorpusDatasetPytorch.TextDataset._ending_vector(text["verse_end"]) for text in batch], dtype=np.int32), dtype=torch.float32)
year = None
if "year" in batch[0].keys():
year = torch.tensor(np.asarray([TextAnalysis._publish_year_vector(text["year"]) for text in batch], dtype=np.int32), dtype=torch.float32)
metre = None
if "metre" in batch[0].keys():
metre = torch.tensor(np.asarray([TextAnalysis._metre_vector(text["metre"]) for text in batch], dtype=np.int32), dtype=torch.float32)
context_ids = None
context_attention_mask = None
if "context_ids" in batch[0].keys():
tokenizer.model_max_length = max_context
tokenized_context = tokenizer([text['context_ids'] + tokenizer.eos_token for text in batch],return_tensors='pt', truncation=True, padding=True)
context_ids = tokenized_context['input_ids']
context_attention_mask = tokenized_context['attention_mask']
return {
"input_ids": input_ids,
"labels": input_ids.type(torch.LongTensor),
"attention_mask": attention,
"context_ids" : context_ids,
"context_attention_mask" : context_attention_mask,
"nums" : nums,
"rhyme": rhyme,
"verse_end" : verse_end,
"year": year,
"metre" : metre}
@staticmethod
def collate_distil(batch, tokenizer: PreTrainedTokenizerBase ,surrogate_model: PreTrainedModel = None,surrogate_model_device=None ,max_len = 1024):
tokenizer.model_max_length = max_len
tokenized = tokenizer([text['input_ids'][0] + tokenizer.eos_token for text in batch], return_tensors='pt', truncation=True, padding=True)
input_ids = tokenized['input_ids']
attention = tokenized["attention_mask"]
with torch.no_grad():
# This is Tuple
model_hidden_states = surrogate_model(input_ids=input_ids.to(surrogate_model_device),
attention_mask=attention.to(surrogate_model_device),
labels=input_ids.type(torch.LongTensor).to(surrogate_model_device))['hidden_states']
model_hidden_states = [hidden.cpu().detach() for hidden in model_hidden_states]
return {
"input_ids": input_ids,
"labels": input_ids.type(torch.LongTensor),
"attention_mask": attention,
"to_replicate_states": model_hidden_states
}
@staticmethod
def collate_validator(batch, tokenizer: PreTrainedTokenizerBase,syllables:bool, is_syllable:bool = False,max_len = 512):
"""Process data for use in LM for metre,rhyme and year prediction
Args:
batch (_type_): Batch with selected data points
tokenizer (PreTrainedTokenizerBase): tokenizer to tokenize input text
syllables (bool): If to use sequence of syllables as input text
is_syllable (bool, optional): Signal if the preprocessed inputs contain syllable data. Defaults to False.
max_len (int, optional): Maximum length of tokenization. Defaults to 1024.
Returns:
dict: tokenized and processed to tensors data
"""
index = 1 if syllables and is_syllable else 0
tokenizer.model_max_length = max_len
data_ids = ["\n".join(
[" ".join(
SyllableMaker.syllabify(line.split('#')[-1])
) + (line[-1] if line[-1] in [',','.','!','?'] else '') if (syllables and not is_syllable and line) else line.split('#')[-1] for line in text['input_ids'][index].splitlines()[1:]]
) for text in batch ]
tokenized = tokenizer(data_ids, return_tensors='pt', truncation=True, padding=True)
input_ids = tokenized['input_ids']
attention = tokenized["attention_mask"]
rhyme=None
if "rhyme" in batch[0].keys():
rhyme = torch.tensor(np.asarray([TextAnalysis._rhyme_vector(text["rhyme"]) for text in batch], dtype=np.int32), dtype=torch.float32)
year_bucket = None
year = None
if "year" in batch[0].keys():
year_bucket = torch.tensor(np.asarray([TextAnalysis._publish_year_vector(text["year"]) for text in batch], dtype=np.int32), dtype=torch.float32)
year = torch.tensor(np.asarray([ [int(text['year'])] if text['year'] != 'NaN' else [0] for text in batch], dtype=np.int32), dtype=torch.float32)
return {
"input_ids": input_ids,
"attention_mask": attention,
"rhyme": rhyme,
"metre_ids": None,
"year_bucket": year_bucket,
'year':year}
@staticmethod
def collate_meter(batch, tokenizer: PreTrainedTokenizerBase, syllables:bool, is_syllable:bool = False, max_len = 512):
index = 1 if syllables and is_syllable else 0
tokenizer.model_max_length = max_len
data_ids = []
metre = []
for datum in batch:
data_ids += [
" ".join(
SyllableMaker.syllabify(line.split('#')[-1])
) + (line[-1] if line[-1] in [',','.','!','?'] else '') if (syllables and not is_syllable and line) else line.split('#')[-1] for line in datum['input_ids'][index].splitlines()[1:]
]
if "metre_ids" in batch[0].keys():
metre += [TextAnalysis._metre_vector(one_metre) for one_metre in datum['metre_ids']]
tokenized = tokenizer(data_ids, return_tensors='pt', truncation=True, padding=True)
input_ids = tokenized['input_ids']
attention = tokenized["attention_mask"]
metre_ids = None
if len(metre) > 0:
metre_ids = torch.tensor(np.asarray(metre, dtype=np.int32), dtype=torch.float32)
return {
"input_ids": input_ids,
"attention_mask": attention,
"rhyme": None,
"metre_ids": metre_ids,
"year_bucket": None,
"year": None}
def __init__(self, data_dir = "PoetGen\corpusCzechVerse-master\ccv", cache_dir='./',
prompt_length=True, prompt_ending=True, prompt_verse=True, verse_len=[4,6], lower_case=True, val_data_rate=0.05, test_data_rate=0.05):
"""Construct the Dataloader and create Datasets
Args:
data_dir (str, optional): Path to data. Defaults to "PoetGen\corpusCzechVerse-master\ccv".
cache_dir (str, optional): Path where to store processed data. Defaults to './'.
prompt_length (bool, optional): If to prompt the syllable count. Defaults to True.
prompt_ending (bool, optional): If to prompt verse ending. Defaults to True.
prompt_verse (bool, optional): If to prompt rhyme schema. Defaults to True.
verse_len (list, optional): Considered length of strophe. Defaults to [4,6].
lower_case (bool, optional): If the string should be in lowercase. Defaults to True.
val_data_rate (float, optional): Amount of data to be left for validation. Defaults to 0.1.
"""
self.lower_case = lower_case
self.data_dir = data_dir
if os.path.isfile(os.path.join(cache_dir, "body_poet_data.json")) and os.path.isfile(os.path.join(cache_dir, "text_poet_data.json")) \
and os.path.isfile(os.path.join(cache_dir, "val_body_poet_data.json")) and os.path.isfile(os.path.join(cache_dir, "val_text_poet_data.json")) \
and os.path.isfile(os.path.join(cache_dir, "test_body_poet_data.json")) and os.path.isfile(os.path.join(cache_dir, "test_text_poet_data.json")) :
self.create_empty()
self.pytorch_dataset_body.data =list(json.load( open( os.path.join(cache_dir, "body_poet_data.json"), 'r')))
self.pytorch_dataset_text.data =list(json.load( open( os.path.join(cache_dir, "text_poet_data.json"), 'r')))
self.val_pytorch_dataset_body.data = list(json.load( open( os.path.join(cache_dir, "val_body_poet_data.json"), 'r')))
self.val_pytorch_dataset_text.data = list(json.load( open( os.path.join(cache_dir, "val_text_poet_data.json"), 'r')))
self.test_pytorch_dataset_body.data = list(json.load( open( os.path.join(cache_dir, "test_body_poet_data.json"), 'r')))
self.test_pytorch_dataset_text.data = list(json.load( open( os.path.join(cache_dir, "test_text_poet_data.json"), 'r')))
else:
self.load_json_filenames(prompt_length, prompt_ending, prompt_verse, verse_len=verse_len, val_data_rate=val_data_rate, test_data_rate=test_data_rate)
json.dump(self.pytorch_dataset_body.data, open( os.path.join(cache_dir, "body_poet_data.json"), 'w+'), indent = 6)
json.dump(self.pytorch_dataset_text.data, open( os.path.join(cache_dir, "text_poet_data.json"), 'w+'), indent = 6)
json.dump(self.val_pytorch_dataset_body.data, open( os.path.join(cache_dir, "val_body_poet_data.json"), 'w+'), indent = 6)
json.dump(self.val_pytorch_dataset_text.data, open( os.path.join(cache_dir, "val_text_poet_data.json"), 'w+'), indent = 6)
json.dump(self.test_pytorch_dataset_body.data, open( os.path.join(cache_dir, "test_body_poet_data.json"), 'w+'), indent = 6)
json.dump(self.test_pytorch_dataset_text.data, open( os.path.join(cache_dir, "test_text_poet_data.json"), 'w+'), indent = 6)
self.load_raw_()
#if __name__ == "__main__":
# Line Count
# print(len(list(CorpusDatasetPytorch(os.path.abspath(os.path.join(os.path.dirname(__file__), "corpusCzechVerse", "ccv")) ).raw_dataset.get_text())))
# Strophe Count
# print(len(list(CorpusDatasetPytorch(os.path.abspath(os.path.join(os.path.dirname(__file__), "corpusCzechVerse", "ccv")) ).raw_dataset.get_part())))
# Poem Count
# print(len(list(CorpusDatasetPytorch(os.path.abspath(os.path.join(os.path.dirname(__file__), "corpusCzechVerse", "ccv")) ).raw_dataset.get_body())))