Spaces:
Running
Running
import torch | |
from torch.utils.data import Dataset | |
import numpy as np | |
class SchemaStringDataset(Dataset):
    """Map-style dataset that renders each record as ``"key:value"`` pairs.

    Each record ``data[idx]`` must expose ``.items()`` (e.g. a dict). Missing
    values (``None`` and floating-point NaN) are dropped, and the remaining
    fields are joined with single spaces.

    Args:
        data: Indexable collection of mapping-like records.
        config: Configuration mapping. ``config["dataset_size"]``, when
            present, limits the number of items reported by ``__len__``.
    """

    def __init__(self, data, config):
        self.data = data
        self.config = config

    def __len__(self):
        """Return the number of items to expose.

        Uses ``config["dataset_size"]`` when present, capped at
        ``len(self.data)`` so a sampler never requests an index past the end
        of ``data``. Falls back to ``len(self.data)`` when the key is absent.
        """
        size = self.config.get("dataset_size")
        if size is None:
            return len(self.data)
        return min(int(size), len(self.data))

    @staticmethod
    def _is_missing(value):
        """Return True for ``None`` and floating-point NaN values."""
        if value is None:
            return True
        # ``value is np.nan`` only matches the np.nan singleton object; NaNs
        # from float("nan"), arithmetic, or numpy scalars are distinct objects,
        # so test numerically instead. np.floating covers float32/float16 too.
        return isinstance(value, (float, np.floating)) and bool(np.isnan(value))

    def transform_entry(self, entry):
        """Render a record as space-separated ``"key:value"`` pairs.

        Args:
            entry: Mapping of field name to value.

        Returns:
            str: The pairs in the mapping's iteration order, with ``None``/NaN
            values omitted; ``''`` when no usable fields remain.
        """
        parts = [f"{k}:{v}" for k, v in entry.items() if not self._is_missing(v)]
        # Joining an empty list gives '', so an all-missing record still yields
        # a str (not a tuple), keeping the 'inputs' type uniform for batching.
        return ' '.join(parts)

    def __getitem__(self, idx):
        """Return ``{'inputs': <rendered record>, 'idx': idx}`` for record ``idx``."""
        item = self.data[idx]
        return {
            'inputs': self.transform_entry(item),
            'idx': idx,
        }