import re
from typing import Optional, Union

import requests
import torch as t
from torch.utils.data import Dataset


class WordsDataset(Dataset):
    '''Simple (input, label) pair dataset for autoregressive language modelling.'''

    def __init__(self, texts, labels):
        self.texts = texts
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return (self.texts[idx], self.labels[idx])

#%%

def tokenize(text):
    '''Splits text at word boundaries, keeping both the words and the separators between them.'''
    return re.split(r"\b", text)


def _remove_duplicates(text, string=" "):
    '''Recursively collapses consecutive occurrences of `string` into a single occurrence.'''
    if string + string in text:
        text = text.replace(string + string, string)
        return _remove_duplicates(text, string)
    return text


def remove_duplicates(text):
    '''Collapses repeated spaces and repeated newlines in the text.'''
    text = _remove_duplicates(text, ' ')
    text = _remove_duplicates(text, '\n')
    return text

# %%

class WordData:
    '''Word-level corpus: builds a vocabulary and converts between tokens and token ids.'''

    def __init__(self, text, start, end, device):
        self.complete_text = remove_duplicates(text)
        if start is not None and end is not None:
            self.complete_text = self.get_excerpt(start, end)
        self.complete_tokens = tokenize(self.complete_text)
        self.vocab = sorted(set(self.complete_tokens))
        self.token_to_id = {token: i for i, token in enumerate(self.vocab)}
        self.id_to_token = {i: token for i, token in enumerate(self.vocab)}
        self.model_max_length = None
        self.device = device

    @staticmethod
    def from_link(link, device, start=None, end=None):
        '''Builds a WordData object from text downloaded from a URL.'''
        return WordData(
            requests.get(link).content.decode('utf-8'),
            start,
            end,
            device=device
        )

    @staticmethod
    def from_file(filename, device, start=None, end=None):
        '''Builds a WordData object from a local text file.'''
        with open(filename, encoding='utf-8') as f:
            text = f.read()
        return WordData(text, start, end, device=device)

    def get_excerpt(self, start="THE SONNETS", end="THE END", text=None):
        '''Returns the portion of the text between the first occurrence of `start` and the first subsequent occurrence of `end`.'''
        if text is None:
            text = self.complete_text
        assert start in text, f'get_excerpt: cannot find {start} in text'
        l_stripped = text.split(start, maxsplit=1)[1]
        assert end in l_stripped, f'get_excerpt: cannot find {end} in text'
        r_stripped = l_stripped.split(end, maxsplit=1)[0]
        return r_stripped

    def generate_autoregressive_dataset(self, sequence_length, text=None):
        '''
        Builds a WordsDataset of sliding windows: each input is `sequence_length`
        consecutive token ids, and its label is the same window shifted one token to the right.
        '''
        self.model_max_length = sequence_length
        if text is None:
            text = self.complete_text
        token_ids = self.encode(text, return_tensors="pt")
        inputs = [token_ids[i:i + sequence_length] for i in range(len(token_ids) - sequence_length)]
        labels = [token_ids[i + 1:i + 1 + sequence_length] for i in range(len(token_ids) - sequence_length)]
        return WordsDataset(inputs, labels)

    def encode(self, initial_text: str, return_tensors: Optional[str] = None) -> Union[list, t.Tensor]:
        '''
        Tokenizes initial_text, then returns the token ids.

        Return type is list by default, but if return_tensors="pt" then it is returned as a tensor.
        '''
        tokens = tokenize(initial_text)
        token_ids = [self.token_to_id[token] for token in tokens]
        if return_tensors == "pt":
            return t.tensor(token_ids, device=self.device)
        return token_ids

    def decode(self, list_of_ids: Union[t.Tensor, list]) -> str:
        '''
        Converts ids to a list of tokens, then joins them into a single string.
        '''
        tokens = [self.id_to_token[int(i)] for i in list_of_ids]
        return "".join(tokens)
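# %%
# Usage sketch (illustrative, not part of the module above): build a word-level
# dataset from a local plain-text file and batch it with a DataLoader. The file
# path, excerpt markers, and hyperparameters below are hypothetical placeholders;
# substitute whatever corpus and delimiters you actually have.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    device = t.device("cuda" if t.cuda.is_available() else "cpu")

    # "shakespeare.txt", "THE SONNETS", and "THE END" are assumed example values.
    data = WordData.from_file(
        "shakespeare.txt", device=device, start="THE SONNETS", end="THE END"
    )
    dataset = data.generate_autoregressive_dataset(sequence_length=32)
    loader = DataLoader(dataset, batch_size=16, shuffle=True)

    tokens, targets = next(iter(loader))
    print(tokens.shape, targets.shape)  # each is (batch_size, sequence_length)
    print(data.decode(tokens[0]))       # round-trip the first window back to text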