File size: 1,448 Bytes
a5c42f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
from torch.utils.data import Dataset
import torch
import pandas as pd

def load_data(args, split):
    """Load one CSV split and return its text and label columns.

    The file read is ``<args.data_root>/<split>.csv``.

    Args:
        args: Any object exposing a ``data_root`` attribute naming the
            directory that holds the split CSV files.
        split: Split name used as the CSV file stem (e.g. ``"train"``).

    Returns:
        A ``(texts, labels)`` tuple of Python lists built from the ``text``
        and ``target`` columns, in file row order.
    """
    csv_path = f"{args.data_root}/{split}.csv"
    frame = pd.read_csv(csv_path)
    # Columns are pulled in the same order as before: "text" first, so a
    # missing "text" column is reported ahead of a missing "target" column.
    texts, labels = (frame[column].values.tolist() for column in ("text", "target"))
    return texts, labels

class MyDataset(Dataset):
    """Map-style dataset that tokenizes raw texts on demand.

    Each item is a dict holding ``input_ids`` and ``attention_mask`` tensors
    and, unless ``is_test`` is set, a ``labels`` tensor.

    Args:
        data: A ``(texts, labels)`` pair, e.g. as returned by ``load_data``.
            ``data[1]`` is stored even when ``is_test`` is true, but it is
            never indexed in that case.
        tokenizer: Object exposing ``batch_encode_plus`` with HuggingFace-style
            ``max_length``/``truncation``/``padding``/``return_tensors``
            keywords (presumably a ``transformers`` tokenizer — TODO confirm).
        max_length: Length every text is truncated / padded to by the tokenizer.
        is_test: When true, no ``labels`` entry is produced.
    """

    def __init__(self, data, tokenizer, max_length, is_test):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.texts = data[0]
        self.labels = data[1]
        self.is_test = is_test

    def __len__(self):
        """Return the number of texts in the dataset."""
        return len(self.texts)

    def __getitem__(self, index):
        """Encode the text at ``index`` and return the model inputs.

        Returns:
            dict with ``input_ids`` and ``attention_mask`` (the tokenizer's
            single-row outputs with the batch dimension removed) and, for
            non-test data, ``labels`` built from the raw label value.
        """
        text = str(self.texts[index])
        # ``padding="max_length"`` already requests fixed-length padding; the
        # deprecated ``pad_to_max_length`` flag that used to accompany it is
        # ignored when ``padding`` is given, so it is not passed any more.
        source = self.tokenizer.batch_encode_plus(
            [text],
            max_length=self.max_length,
            truncation=True,
            padding="max_length",
            return_tensors="pt",
        )
        # Remove only the leading batch dimension (size 1, since a single text
        # was encoded). A bare ``squeeze()`` would also collapse the sequence
        # dimension when ``max_length == 1``, producing a 0-d tensor.
        data_sample = {
            "input_ids": source["input_ids"].squeeze(0),
            "attention_mask": source["attention_mask"].squeeze(0),
        }
        if not self.is_test:
            label = self.labels[index]
            data_sample["labels"] = torch.tensor(label).squeeze()
        return data_sample