import torch
from torch.utils.data import Dataset
import numpy as np


class SchemaStringDataset(Dataset):
    """Dataset that renders each record as a single ``key:value`` string.

    Each element of ``data`` is expected to expose ``.items()`` (e.g. a
    dict). Fields whose value is ``None`` or NaN are dropped, and the
    remaining fields are joined with single spaces as ``"key:value"`` tokens.
    """

    def __init__(self, data, config):
        """Store the records and the configuration.

        Args:
            data: Indexable collection of mapping-like records.
            config: Mapping that must contain the key ``"dataset_size"``.
        """
        super().__init__()
        self.data = data
        self.config = config

    def __len__(self) -> int:
        """Return the configured dataset size, capped at the records available.

        The size comes from ``config["dataset_size"]``. When ``data`` supports
        ``len()``, the result is clamped to it so that iterating over the
        dataset can never index past the end of ``data``.
        """
        configured_size = self.config["dataset_size"]
        try:
            available = len(self.data)
        except TypeError:
            # ``data`` is not sized; trust the configured value as-is.
            return configured_size
        return min(configured_size, available)

    @staticmethod
    def _is_missing(value) -> bool:
        """Return True when ``value`` is ``None`` or a floating-point NaN."""
        if value is None:
            return True
        # An identity test against ``np.nan`` misses NaNs that are distinct
        # objects (``float("nan")``, NumPy scalars, results of arithmetic),
        # so test the value itself. The isinstance guard keeps strings and
        # containers away from ``np.isnan``, which rejects them.
        return isinstance(value, (float, np.floating)) and bool(np.isnan(value))

    def transform_entry(self, entry) -> str:
        """Render ``entry`` as space-separated ``key:value`` tokens.

        Args:
            entry: Mapping-like record exposing ``.items()``.

        Returns:
            The joined string, or ``''`` when every value is ``None``/NaN
            (or the entry is empty). The return type is always ``str``.
        """
        return ' '.join(
            f"{key}:{value}"
            for key, value in entry.items()
            if not self._is_missing(value)
        )

    def __getitem__(self, idx):
        """Return the transformed record at position ``idx``.

        Returns:
            A dict with ``'inputs'`` (the rendered string) and ``'idx'``
            (the index that was requested).
        """
        item = self.data[idx]
        return {
            'inputs': self.transform_entry(item),
            'idx': idx,
        }