# Provenance (copied from the hosting page): uploaded by Suburst in commit
# f8bd4d2 ("Upload 27 files"), original size 1.19 kB.
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
class TranslationDataset(Dataset):
    """Parallel-corpus dataset of fixed-length (source, target) token-id pairs.

    Each non-blank line of ``file_path`` must hold a source sentence and a
    target sentence separated by one tab. Sentences are split on whitespace,
    mapped to ids through each tokenizer's ``word2int`` dict, wrapped in
    ``<BOS>``/``<EOS>``, truncated, and right-padded with ``<PAD>`` so every
    sequence has exactly ``max_len`` ids. All lines are encoded eagerly when
    the dataset is constructed.

    Args:
        file_path: Path to the tab-separated UTF-8 corpus file (a leading
            byte-order mark is tolerated).
        src_tokenizer: Object with a ``word2int`` dict mapping source tokens
            to ids; it must contain ``"<UNK>"``, ``"<BOS>"``, ``"<EOS>"`` and
            ``"<PAD>"``.
        tgt_tokenizer: Same as ``src_tokenizer``, for target tokens.
        max_len: Length of every encoded sequence, counting ``<BOS>`` and
            ``<EOS>``; must be at least 2.

    Raises:
        ValueError: If ``max_len < 2``, or if a non-blank line does not split
            into exactly two tab-separated fields.
    """

    def __init__(self, file_path, src_tokenizer, tgt_tokenizer, max_len):
        # With max_len < 2 the slice ``ids[:max_len - 2]`` in encode() has a
        # negative bound: it drops tokens from the END instead of truncating,
        # and the result is longer than max_len. Reject it up front.
        if max_len < 2:
            raise ValueError(
                f"max_len must be at least 2 to fit <BOS> and <EOS>, got {max_len}"
            )
        self.data = []
        self.src_tokenizer = src_tokenizer
        self.tgt_tokenizer = tgt_tokenizer
        self.max_len = max_len
        # 'utf-8-sig' decodes plain UTF-8 identically, but also strips a
        # leading BOM that would otherwise be glued onto the first token.
        with open(file_path, 'r', encoding='utf-8-sig') as f:
            for line_no, line in enumerate(f, start=1):
                line = line.strip()
                if not line:
                    continue  # tolerate blank lines, e.g. a trailing newline
                fields = line.split('\t')
                if len(fields) != 2:
                    raise ValueError(
                        f"{file_path}:{line_no}: expected 'source<TAB>target', "
                        f"found {len(fields)} tab-separated field(s)"
                    )
                src, tgt = fields
                self.data.append(
                    (self.encode(src, self.src_tokenizer),
                     self.encode(tgt, self.tgt_tokenizer))
                )

    def encode(self, sentence, tokenizer):
        """Encode a sentence as a list of exactly ``self.max_len`` token ids.

        The layout is ``<BOS>``, then at most ``max_len - 2`` ids of the
        whitespace-separated tokens, then ``<EOS>``, then ``<PAD>`` ids up to
        ``max_len``. Tokens missing from ``tokenizer.word2int`` map to
        ``<UNK>``.
        """
        vocab = tokenizer.word2int
        # Hoisted: as the .get() default this lookup ran once per token.
        unk_id = vocab["<UNK>"]
        ids = [vocab.get(token, unk_id) for token in sentence.split()]
        ids = [vocab["<BOS>"]] + ids[:self.max_len - 2] + [vocab["<EOS>"]]
        ids += [vocab["<PAD>"]] * (self.max_len - len(ids))
        return ids

    def __len__(self):
        """Return the number of encoded sentence pairs."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``idx``-th pair as two 1-D ``torch.long`` tensors of
        length ``max_len`` (source ids, target ids)."""
        src, tgt = self.data[idx]
        return torch.tensor(src, dtype=torch.long), torch.tensor(tgt, dtype=torch.long)