import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

class TranslationDataset(Dataset):
    """Parallel-corpus dataset: each line of the file is a tab-separated
    source/target sentence pair, encoded to fixed-length id sequences."""

    def __init__(self, file_path, src_tokenizer, tgt_tokenizer, max_len):
        self.data = []
        self.src_tokenizer = src_tokenizer
        self.tgt_tokenizer = tgt_tokenizer
        self.max_len = max_len

        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue  # skip blank lines (e.g. a trailing newline)
                src, tgt = line.split('\t')
                self.data.append((self.encode(src, self.src_tokenizer),
                                  self.encode(tgt, self.tgt_tokenizer)))

    def encode(self, sentence, tokenizer):
        # Whitespace-tokenize and map each word to its id, falling back to <UNK>.
        tokens = sentence.split()
        ids = [tokenizer.word2int.get(token, tokenizer.word2int["<UNK>"]) for token in tokens]
        # Truncate to leave room for <BOS>/<EOS>, then pad with <PAD> up to max_len.
        ids = [tokenizer.word2int["<BOS>"]] + ids[:self.max_len - 2] + [tokenizer.word2int["<EOS>"]]
        ids += [tokenizer.word2int["<PAD>"]] * (self.max_len - len(ids))
        return ids

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        src, tgt = self.data[idx]
        return torch.tensor(src, dtype=torch.long), torch.tensor(tgt, dtype=torch.long)
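

# Example usage: a minimal sketch of how this dataset plugs into a DataLoader.
# The tokenizers are assumed to be simple objects exposing a `word2int` dict
# with <PAD>, <UNK>, <BOS>, <EOS> entries, as the class above expects, and
# "data.tsv" is a hypothetical file of "source<TAB>target" sentence pairs.
if __name__ == "__main__":
    from types import SimpleNamespace

    specials = {"<PAD>": 0, "<UNK>": 1, "<BOS>": 2, "<EOS>": 3}
    src_tokenizer = SimpleNamespace(word2int=dict(specials))
    tgt_tokenizer = SimpleNamespace(word2int=dict(specials))

    dataset = TranslationDataset("data.tsv", src_tokenizer, tgt_tokenizer, max_len=32)
    loader = DataLoader(dataset, batch_size=16, shuffle=True)

    for src_batch, tgt_batch in loader:
        # Each batch is a pair of LongTensors of shape (batch_size, max_len).
        print(src_batch.shape, tgt_batch.shape)
        break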