File size: 1,185 Bytes
f8bd4d2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 |
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
class TranslationDataset(Dataset):
    """Parallel-corpus dataset of fixed-length (source, target) token-id pairs.

    Each non-blank line of ``file_path`` must hold one source sentence and one
    target sentence separated by a single tab. Each sentence is processed as
    follows:

    * split on whitespace;
    * mapped to ids through ``tokenizer.word2int`` (unknown words map to
      ``<UNK>``);
    * wrapped in ``<BOS>`` / ``<EOS>``;
    * truncated and right-padded with ``<PAD>`` to exactly ``max_len`` ids.

    The whole file is encoded eagerly in ``__init__``.

    Args:
        file_path: Path to the UTF-8, tab-separated corpus.
        src_tokenizer: Object exposing a ``word2int`` dict for the source side.
            It must contain the ``<BOS>``, ``<EOS>``, ``<PAD>`` and ``<UNK>``
            keys.
        tgt_tokenizer: Same as ``src_tokenizer``, for the target side.
        max_len: Length of every encoded sequence, including ``<BOS>`` and
            ``<EOS>``. Must be at least 2.

    Raises:
        ValueError: If ``max_len`` < 2, or if a non-blank line does not contain
            exactly two tab-separated fields.
    """

    def __init__(self, file_path, src_tokenizer, tgt_tokenizer, max_len):
        if max_len < 2:
            # <BOS> and <EOS> alone need two slots. With a smaller max_len the
            # ids[:max_len - 2] slice in encode() becomes a negative index, and
            # sequences would come out longer than max_len.
            raise ValueError(f"max_len must be at least 2, got {max_len}")
        self.data = []
        self.src_tokenizer = src_tokenizer
        self.tgt_tokenizer = tgt_tokenizer
        self.max_len = max_len
        with open(file_path, 'r', encoding='utf-8') as f:
            for lineno, line in enumerate(f, start=1):
                # Strip only the line terminator. A full strip() would also
                # remove a leading/trailing tab and break the two-field split.
                line = line.rstrip('\r\n')
                if not line.strip():
                    continue  # tolerate blank lines, e.g. a trailing empty line
                fields = line.split('\t')
                if len(fields) != 2:
                    raise ValueError(
                        f"{file_path}:{lineno}: expected 2 tab-separated fields, "
                        f"got {len(fields)}"
                    )
                src, tgt = fields
                self.data.append(
                    (self.encode(src, self.src_tokenizer),
                     self.encode(tgt, self.tgt_tokenizer))
                )

    def encode(self, sentence, tokenizer):
        """Map a whitespace-tokenized sentence to exactly ``self.max_len`` ids.

        Unknown tokens map to ``<UNK>``. The sequence is wrapped in
        ``<BOS>`` / ``<EOS>`` and truncated so the total never exceeds
        ``self.max_len``. It is then right-padded with ``<PAD>``.

        Args:
            sentence: Raw sentence text.
            tokenizer: Object exposing a ``word2int`` dict.

        Returns:
            A list of ``self.max_len`` integer ids.
        """
        vocab = tokenizer.word2int
        unk = vocab["<UNK>"]  # loop-invariant: look it up once, not per token
        body = [vocab.get(token, unk) for token in sentence.split()]
        # Reserve two slots for <BOS>/<EOS>; __init__ guarantees max_len >= 2.
        ids = [vocab["<BOS>"]] + body[:self.max_len - 2] + [vocab["<EOS>"]]
        ids += [vocab["<PAD>"]] * (self.max_len - len(ids))
        return ids

    def __len__(self):
        """Return the number of sentence pairs loaded from the file."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``idx``-th pair as two 1-D ``torch.long`` tensors of length ``max_len``."""
        src, tgt = self.data[idx]
        return torch.tensor(src, dtype=torch.long), torch.tensor(tgt, dtype=torch.long)