sobir-hf's picture
First commit
b56c828
raw
history blame
1.62 kB
import torch
from torch.nn.utils.rnn import pad_sequence
def load_config(path):
    """Read a torch checkpoint from *path* and return its 'config' entry.

    The checkpoint is mapped to CPU so it loads on machines without a GPU.
    """
    # NOTE(review): torch.load unpickles arbitrary objects — only call this
    # on checkpoint files from a trusted source.
    checkpoint = torch.load(path, map_location='cpu')
    return checkpoint['config']
class Tokenizer:
    """Character-level tokenizer for a seq2seq model.

    Built from a config dict holding ``src_vocab`` and ``trg_vocab`` —
    lists of tokens whose list index is the integer id. Both vocabularies
    must contain the special tokens ``<NULL>``, ``<PAD>`` and ``<UNK>``.
    """

    def __init__(self, config) -> None:
        self.src_vocab = config['src_vocab']
        self.trg_vocab = config['trg_vocab']
        # Reverse lookup tables: token -> integer id (id == position in list).
        self.src_char_index = dict(zip(self.src_vocab, range(len(self.src_vocab))))
        self.trg_char_index = dict(zip(self.trg_vocab, range(len(self.trg_vocab))))
        # Cache special-token ids for both sides.
        self.src_null_idx = self.src_char_index['<NULL>']
        self.trg_null_idx = self.trg_char_index['<NULL>']
        self.src_pad_idx = self.src_char_index['<PAD>']
        self.trg_pad_idx = self.trg_char_index['<PAD>']
        self.src_unk_idx = self.src_char_index['<UNK>']
        self.trg_unk_idx = self.trg_char_index['<UNK>']

    def encode_src(self, text: str):
        """Map each character of *text* to its source id; unknowns -> <UNK>.

        Returns a 1-D ``torch.long`` tensor of length ``len(text)``.
        """
        lookup = self.src_char_index.get
        unk = self.src_unk_idx
        ids = [lookup(ch, unk) for ch in text]
        return torch.tensor(ids, dtype=torch.long)

    def decode_src(self, src: torch.Tensor):
        """Translate a tensor of source ids back into a list of tokens."""
        # NOTE(review): unlike decode_trg, no special tokens are filtered here.
        return [self.src_vocab[idx] for idx in src]

    def decode_trg(self, trg: torch.Tensor):
        """Translate target ids into tokens, dropping every <NULL> id.

        The input tensor is flattened first, so any shape is accepted.
        """
        null_id = self.trg_null_idx
        kept = [idx for idx in trg.flatten().tolist() if idx != null_id]
        return [self.trg_vocab[idx] for idx in kept]

    def collate_fn(self, batch):
        """Collate ``(src, trg)`` tensor pairs into padded batch tensors.

        Pads with each side's <PAD> id and returns
        ``(src_padded, trg_padded)``, both batch-first.
        """
        sources, targets = [], []
        for pair in batch:
            sources.append(pair[0])
            targets.append(pair[1])
        src_padded = pad_sequence(sources, batch_first=True, padding_value=self.src_pad_idx)
        trg_padded = pad_sequence(targets, batch_first=True, padding_value=self.trg_pad_idx)
        return src_padded, trg_padded