import torch
from torch.utils.data import Dataset


class TokenizedDataset(Dataset):
    def __init__(self, custom_dataset, tokenizer, max_seq_len):
        """
        custom_dataset: An instance of CustomDataset
        tokenizer: An instance of the tokenizer
        max_seq_len: Maximum sequence length for padding
        """
        self.dataset = custom_dataset
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len

    def __len__(self):
        # The length is inherited from the custom dataset
        return len(self.dataset)

    def tokenize_and_pad(self, text):
        """
        Tokenize a string (or list of strings), padding and truncating
        to max_seq_len.
        """
        return self.tokenizer(
            text,
            padding='max_length',
            max_length=self.max_seq_len,
            truncation=True,
            return_tensors="pt",
        )

    def __getitem__(self, idx):
        # Fetch the transformed data from the CustomDataset instance
        transformed_data = self.dataset[idx]

        # Containers for model inputs and (optionally) labels
        tokenized_inputs = {}
        tokenized_labels = {}

        # Dynamically process each field; keys prefixed with 'label' are
        # routed to the label dict, everything else to the inputs dict
        for key, value in transformed_data.items():
            if isinstance(value, int):
                # Scalar fields become 0-d tensors
                if key.startswith('label'):
                    tokenized_labels[key] = torch.tensor(value)
                else:
                    tokenized_inputs[key] = torch.tensor(value)
            elif isinstance(value, str):
                # Text fields are tokenized and padded; squeeze(0) drops the
                # batch dimension the tokenizer adds for a single string, so
                # a DataLoader can later stack items into (batch, max_seq_len)
                tokenized_data = self.tokenize_and_pad(value)
                input_ids = tokenized_data['input_ids'].squeeze(0)
                attention_mask = tokenized_data['attention_mask'].squeeze(0)
                if key.startswith('label'):
                    tokenized_labels[key] = input_ids
                    tokenized_labels['attention_mask_' + key] = attention_mask
                else:
                    tokenized_inputs[key] = input_ids
                    tokenized_inputs['attention_mask_' + key] = attention_mask

        # Prepare the return structure, conditionally including 'label' if labels are present
        output = {"inputs": tokenized_inputs}
        if tokenized_labels:  # Check if there are any labels before adding to the output
            output["label"] = tokenized_labels

        return output
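

# Usage sketch (illustrative): how this wrapper might pair with a DataLoader.
# The inline DemoDataset and the "bert-base-uncased" checkpoint below are
# assumptions made for the example, not requirements of TokenizedDataset;
# any map-style dataset yielding dicts of strings/ints should work.
if __name__ == "__main__":
    from torch.utils.data import DataLoader
    from transformers import AutoTokenizer

    class DemoDataset(Dataset):
        """Minimal stand-in for the CustomDataset this wrapper expects."""

        def __init__(self):
            self.samples = [
                {"text": "stacks are last-in, first-out", "label": 1},
                {"text": "queues are first-in, first-out", "label": 0},
            ]

        def __len__(self):
            return len(self.samples)

        def __getitem__(self, idx):
            return self.samples[idx]

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    dataset = TokenizedDataset(DemoDataset(), tokenizer, max_seq_len=16)
    loader = DataLoader(dataset, batch_size=2)

    # Default collation recurses into the nested dicts and stacks tensors
    batch = next(iter(loader))
    print(batch["inputs"]["text"].shape)                 # torch.Size([2, 16])
    print(batch["inputs"]["attention_mask_text"].shape)  # torch.Size([2, 16])
    print(batch["label"]["label"])                       # tensor([1, 0])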