File size: 1,219 Bytes
37c2a8d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import torch
from torch.utils.data import Dataset
import numpy as np


class SchemaStringDataset(Dataset):
    """Map-style dataset that renders each dict record as a ``"key:value"`` string.

    Each record ``data[idx]`` is expected to be a mapping (it must provide
    ``.items()``). Fields whose value is ``None`` or NaN are dropped. The
    remaining fields are formatted as ``"key:value"`` and joined with single
    spaces.
    """

    def __init__(self, data, config):
        """Store the records and the configuration.

        Args:
            data: Indexable collection of mapping records.
            config: Mapping that must contain the key ``"dataset_size"``.
        """
        self.data = data
        self.config = config

    def __len__(self):
        """Return the dataset size specified in the configuration.

        NOTE(review): this is taken from ``config["dataset_size"]``, not from
        ``len(self.data)``. If it is larger than the number of records,
        ``__getitem__`` will fail for out-of-range indices. Confirm this is
        intended.
        """
        return self.config["dataset_size"]

    @staticmethod
    def _is_missing(value):
        """Return True if *value* is ``None`` or a floating-point NaN."""
        if value is None:
            return True
        # An identity test (`value is np.nan`) only catches the np.nan
        # singleton. Any other NaN (float("nan"), np.float64/np.float32 NaN)
        # must be detected by value.
        return isinstance(value, (float, np.floating)) and bool(np.isnan(value))

    def transform_entry(self, entry):
        """Format one record as space-separated ``"key:value"`` pairs.

        Args:
            entry: Mapping of field name to value.

        Returns:
            A ``str``. It is ``''`` when every value is missing or the mapping
            is empty. Missing means ``None`` or NaN.
        """
        pairs = [f"{k}:{v}" for k, v in entry.items() if not self._is_missing(v)]
        # Joining an empty list yields '', so all-missing records still
        # produce a str (never a tuple), keeping batch collation consistent.
        return ' '.join(pairs)

    def __getitem__(self, idx):
        """Return ``{'inputs': <formatted str>, 'idx': idx}`` for record *idx*."""
        return {
            'inputs': self.transform_entry(self.data[idx]),
            'idx': idx,
        }