Spaces:
Sleeping
Sleeping
import torch | |
from torch.nn.utils.rnn import pad_sequence | |
def load_config(path):
    """Load a saved checkpoint onto the CPU and return its ``'config'`` entry.

    Args:
        path: Location of the file, passed straight to ``torch.load``.

    Returns:
        Whatever object is stored under the ``'config'`` key of the loaded
        checkpoint.
    """
    # NOTE(review): torch.load unpickles the file, so only point this at
    # checkpoints from a trusted source.
    checkpoint = torch.load(path, map_location='cpu')
    config = checkpoint['config']
    return config
class Tokenizer:
    """Character-level tokenizer backed by a source and a target vocabulary.

    A symbol's id is its position in the corresponding vocabulary. Both
    vocabularies must contain the special symbols ``<NULL>``, ``<PAD>`` and
    ``<UNK>``; their ids are cached as attributes at construction time.
    """

    def __init__(self, config) -> None:
        """Build the symbol-to-index tables and cache the special-token ids.

        Args:
            config: Subscriptable object providing ``'src_vocab'`` and
                ``'trg_vocab'`` (presumably the dict returned from a saved
                checkpoint -- confirm against the caller).

        Raises:
            KeyError: If ``<NULL>``, ``<PAD>`` or ``<UNK>`` is absent from
                either vocabulary.
        """
        self.src_vocab = config['src_vocab']
        self.trg_vocab = config['trg_vocab']

        self.src_char_index = {char: i for i, char in enumerate(self.src_vocab)}
        self.trg_char_index = {char: i for i, char in enumerate(self.trg_vocab)}

        self.trg_null_idx = self.trg_char_index['<NULL>']
        self.src_null_idx = self.src_char_index['<NULL>']
        self.src_pad_idx = self.src_char_index['<PAD>']
        self.trg_pad_idx = self.trg_char_index['<PAD>']
        self.trg_unk_idx = self.trg_char_index['<UNK>']
        self.src_unk_idx = self.src_char_index['<UNK>']

    @staticmethod
    def _as_index_list(indices):
        """Return ``indices`` as a flat list usable for vocabulary lookup.

        Tensors of any shape are flattened and converted to Python numbers
        in a single call; any other iterable is copied into a list as-is.
        """
        if isinstance(indices, torch.Tensor):
            return indices.flatten().tolist()
        return list(indices)

    def encode_src(self, text: str) -> torch.Tensor:
        """Map each character of ``text`` to its source-vocabulary id.

        Characters missing from the source vocabulary are encoded as the
        ``<UNK>`` id.

        Returns:
            A 1-D ``torch.long`` tensor with one id per character.
        """
        src = [self.src_char_index.get(src_char, self.src_unk_idx) for src_char in text]
        return torch.tensor(src, dtype=torch.long)

    def decode_src(self, src: torch.Tensor):
        """Map source ids back to their vocabulary symbols.

        Args:
            src: Tensor of ids of any shape (it is flattened first), or a
                flat iterable of ids.

        Returns:
            A list with one source-vocabulary symbol per id, in order.
            Special symbols are kept.
        """
        # Flatten up front so batched (multi-dimensional) input decodes the
        # same way decode_trg does instead of failing on row tensors.
        return [self.src_vocab[i] for i in self._as_index_list(src)]

    def decode_trg(self, trg: torch.Tensor):
        """Map target ids back to their vocabulary symbols, dropping ``<NULL>``.

        Args:
            trg: Tensor of ids of any shape (it is flattened first), or a
                flat iterable of ids.

        Returns:
            A list of target-vocabulary symbols for every id that is not the
            ``<NULL>`` id, in order.
        """
        kept = [i for i in self._as_index_list(trg) if i != self.trg_null_idx]
        return [self.trg_vocab[i] for i in kept]

    def collate_fn(self, batch):
        """Pad a batch of ``(src, trg)`` tensor pairs to a common length.

        Args:
            batch: Sequence of ``(src, trg)`` pairs; it is iterated twice, so
                a one-shot iterator is not suitable.

        Returns:
            A ``(src_padded, trg_padded)`` tuple, batch dimension first,
            padded with the source and target ``<PAD>`` ids respectively.
        """
        src = [x for x, _ in batch]
        trg = [y for _, y in batch]
        src_padded = pad_sequence(src, batch_first=True, padding_value=self.src_pad_idx)
        trg_padded = pad_sequence(trg, batch_first=True, padding_value=self.trg_pad_idx)
        return src_padded, trg_padded