import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader


class TranslationDataset(Dataset):
    def __init__(self, file_path, src_tokenizer, tgt_tokenizer, max_len):
        self.data = []
        self.src_tokenizer = src_tokenizer
        self.tgt_tokenizer = tgt_tokenizer
        self.max_len = max_len

        # Each line of the file holds a tab-separated source/target sentence pair.
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue  # skip blank lines (e.g. a trailing newline)
                src, tgt = line.split('\t')
                self.data.append((self.encode(src, self.src_tokenizer),
                                  self.encode(tgt, self.tgt_tokenizer)))

    def encode(self, sentence, tokenizer):
        # Map whitespace tokens to ids, falling back to <UNK> for out-of-vocabulary words.
        tokens = sentence.split()
        ids = [tokenizer.word2int.get(token, tokenizer.word2int["<UNK>"]) for token in tokens]
        # Truncate to leave room for <BOS>/<EOS>, then pad up to max_len.
        ids = [tokenizer.word2int["<BOS>"]] + ids[:self.max_len - 2] + [tokenizer.word2int["<EOS>"]]
        ids += [tokenizer.word2int["<PAD>"]] * (self.max_len - len(ids))
        return ids

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        src, tgt = self.data[idx]
        return torch.tensor(src, dtype=torch.long), torch.tensor(tgt, dtype=torch.long)
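

# A minimal usage sketch, not part of the dataset itself. The tokenizer stub below is a
# hypothetical stand-in: TranslationDataset only assumes an object exposing a `word2int`
# dict containing <PAD>, <BOS>, <EOS>, and <UNK>. The file path, vocabularies, batch size,
# and max_len here are placeholders.
class SimpleTokenizer:
    def __init__(self, vocab):
        self.word2int = {"<PAD>": 0, "<BOS>": 1, "<EOS>": 2, "<UNK>": 3}
        for word in vocab:
            self.word2int.setdefault(word, len(self.word2int))


if __name__ == "__main__":
    src_tok = SimpleTokenizer(["hello", "world"])
    tgt_tok = SimpleTokenizer(["bonjour", "monde"])

    dataset = TranslationDataset("train.tsv", src_tok, tgt_tok, max_len=32)
    loader = DataLoader(dataset, batch_size=64, shuffle=True)

    for src_batch, tgt_batch in loader:
        # Each batch is a pair of LongTensors of shape (batch_size, max_len).
        print(src_batch.shape, tgt_batch.shape)
        break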