import torch
from torch.utils.data import Dataset
from transformers import AutoTokenizer
class TokenizedDataset(Dataset):
    """Wrap an indexable dataset of dicts, tokenizing string fields on access.

    Each item of the underlying dataset is expected to be a mapping whose
    values are either strings (tokenized and padded to ``max_seq_len``) or
    ints (converted to 0-d tensors).  Keys whose names start with ``'label'``
    are routed into a separate label container; everything else becomes a
    model input.
    """

    def __init__(self, custom_dataset, tokenizer, max_seq_len):
        """
        custom_dataset: An indexable dataset yielding dicts of str/int values
        tokenizer: A HuggingFace-style tokenizer (callable on text)
        max_seq_len: Maximum sequence length for padding/truncation
        """
        self.dataset = custom_dataset
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len

    def __len__(self):
        # The length is inherited from the wrapped dataset.
        return len(self.dataset)

    def tokenize_and_pad(self, text_list):
        """
        Tokenize a string (or list of strings), padding to ``max_seq_len``
        and truncating longer sequences; returns the tokenizer's "pt" batch.
        """
        return self.tokenizer(
            text_list,
            padding='max_length',
            max_length=self.max_seq_len,
            truncation=True,
            return_tensors="pt",
        )

    def __getitem__(self, idx):
        """Return ``{"inputs": {...}}`` (plus ``"label"`` when any
        ``label*`` key was present) for the item at ``idx``."""
        transformed_data = self.dataset[idx]
        tokenized_inputs = {}
        tokenized_labels = {}
        for key, value in transformed_data.items():
            # Route every field by key prefix: 'label*' keys become labels.
            target = tokenized_labels if key.startswith('label') else tokenized_inputs
            if isinstance(value, int):
                # Scalar field: convert to a 0-d tensor.
                target[key] = torch.tensor(value)
            elif isinstance(value, str):  # elif: a value is never both int and str
                tokenized_data = self.tokenize_and_pad(value)
                # NOTE(review): tokenizing a single string with
                # return_tensors="pt" yields shape (1, max_seq_len); default
                # DataLoader collation would stack these into (batch, 1, L).
                # Callers may want .squeeze(0) here — confirm before changing.
                target[key] = tokenized_data['input_ids']
                target['attention_mask_' + key] = tokenized_data['attention_mask']
        # Conditionally include 'label' only when at least one label was found.
        output = {"inputs": tokenized_inputs}
        if tokenized_labels:
            output["label"] = tokenized_labels
        return output