from torch.utils.data import Dataset
from transformers import AutoTokenizer
import torch


class TokenizedDataset(Dataset):
    def __init__(self, custom_dataset, tokenizer, max_seq_len):
        """
        custom_dataset: An instance of CustomDataset
        tokenizer: An instance of the tokenizer
        max_seq_len: Maximum sequence length for padding
        """
        self.dataset = custom_dataset
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len

    def __len__(self):
        # The length is inherited from the custom dataset
        return len(self.dataset)

    def tokenize_and_pad(self, text):
        """
        Tokenize a string (or list of strings) and pad/truncate it to max_seq_len.
        """
        tokens = self.tokenizer(
            text,
            padding='max_length',
            max_length=self.max_seq_len,
            truncation=True,
            return_tensors="pt",
        )
        return tokens

    def __getitem__(self, idx):
        # Fetch the transformed data from the CustomDataset instance
        transformed_data = self.dataset[idx]

        # Initialize containers for inputs and, optionally, labels
        tokenized_inputs = {}
        tokenized_labels = {}

        # Dynamically process each field in the example
        for key, value in transformed_data.items():
            if isinstance(value, int):
                # Integers are wrapped in tensors and routed to inputs or
                # labels based on the key prefix
                if key.startswith('label'):
                    tokenized_labels[key] = torch.tensor(value)
                else:
                    tokenized_inputs[key] = torch.tensor(value)
            elif isinstance(value, str):
                # Strings are tokenized; squeeze(0) drops the extra batch
                # dimension added by return_tensors="pt" so the DataLoader can
                # collate single examples into (batch_size, max_seq_len) tensors
                tokenized_data = self.tokenize_and_pad(value)
                if key.startswith('label'):
                    tokenized_labels[key] = tokenized_data['input_ids'].squeeze(0)
                    tokenized_labels['attention_mask_' + key] = tokenized_data['attention_mask'].squeeze(0)
                else:
                    tokenized_inputs[key] = tokenized_data['input_ids'].squeeze(0)
                    tokenized_inputs['attention_mask_' + key] = tokenized_data['attention_mask'].squeeze(0)

        # Prepare the return structure, including 'label' only if labels are present
        output = {"inputs": tokenized_inputs}
        if tokenized_labels:
            output["label"] = tokenized_labels

        return output
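

# Usage sketch (illustrative, not part of the original source): the wrapped
# dataset only needs to return a dict of raw strings and integer labels per
# example, e.g. {"text": "...", "label": 0}. The DemoDataset class and the
# "bert-base-uncased" checkpoint below are assumptions for demonstration.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    class DemoDataset(Dataset):
        """Hypothetical stand-in for CustomDataset."""

        def __init__(self):
            self.examples = [
                {"text": "the movie was great", "label": 1},
                {"text": "the movie was terrible", "label": 0},
            ]

        def __len__(self):
            return len(self.examples)

        def __getitem__(self, idx):
            return self.examples[idx]

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # assumed checkpoint
    tokenized = TokenizedDataset(DemoDataset(), tokenizer, max_seq_len=32)

    loader = DataLoader(tokenized, batch_size=2, shuffle=False)
    batch = next(iter(loader))
    print(batch["inputs"]["text"].shape)                 # torch.Size([2, 32])
    print(batch["inputs"]["attention_mask_text"].shape)  # torch.Size([2, 32])
    print(batch["label"]["label"])                       # tensor([1, 0])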