import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader


class TranslationDataset(Dataset):
    def __init__(self, file_path, src_tokenizer, tgt_tokenizer, max_len):
        self.data = []
        self.src_tokenizer = src_tokenizer
        self.tgt_tokenizer = tgt_tokenizer
        self.max_len = max_len
        # Each line of the corpus file holds a tab-separated "source\ttarget" pair.
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                src, tgt = line.strip().split('\t')
                self.data.append((self.encode(src, self.src_tokenizer),
                                  self.encode(tgt, self.tgt_tokenizer)))

    def encode(self, sentence, tokenizer):
        # Map whitespace-split tokens to ids, falling back to the
        # unknown-token id for out-of-vocabulary words.
        tokens = sentence.split()
        ids = [tokenizer.word2int.get(token, tokenizer.word2int["<unk>"])
               for token in tokens]
        # Wrap with start/end markers, truncating so both still fit in max_len.
        ids = ([tokenizer.word2int["<sos>"]]
               + ids[:self.max_len - 2]
               + [tokenizer.word2int["<eos>"]])
        # Right-pad to a fixed length so sequences stack into a batch tensor.
        ids += [tokenizer.word2int["<pad>"]] * (self.max_len - len(ids))
        return ids

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        src, tgt = self.data[idx]
        return torch.tensor(src, dtype=torch.long), torch.tensor(tgt, dtype=torch.long)
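
# A minimal usage sketch, not part of the original code: `SimpleTokenizer` is a
# hypothetical stand-in for whatever tokenizer the dataset expects (anything
# exposing a `word2int` dict with the four special tokens works), and
# 'data/train.tsv' is a placeholder path to a tab-separated parallel corpus.

class SimpleTokenizer:
    def __init__(self, vocab):
        # Reserve the ids the dataset relies on before adding real words.
        self.word2int = {"<pad>": 0, "<sos>": 1, "<eos>": 2, "<unk>": 3}
        for word in vocab:
            self.word2int.setdefault(word, len(self.word2int))


if __name__ == "__main__":
    src_tok = SimpleTokenizer(["hello", "world"])
    tgt_tok = SimpleTokenizer(["bonjour", "monde"])
    dataset = TranslationDataset("data/train.tsv", src_tok, tgt_tok, max_len=32)
    # Each batch is a pair of (batch_size, max_len) LongTensors, ready to feed
    # an embedding layer; shuffling is the usual choice for training.
    loader = DataLoader(dataset, batch_size=64, shuffle=True)
    for src_batch, tgt_batch in loader:
        print(src_batch.shape, tgt_batch.shape)
        break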