File size: 3,226 Bytes
46fcc2f
 
 
41ed540
46fcc2f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import glob, json, os
import torch
import warnings
import numpy as np

from torch.utils.data import Dataset

class BaseDataset2(Dataset):
    """Template class for all datasets in the project.

    Wraps a pair of NumPy arrays as float32 torch tensors and exposes them
    through the ``torch.utils.data.Dataset`` protocol.

    Attributes:
        data (torch.Tensor): Input features, cast to float32.
        targets (torch.Tensor): Targets, cast to float32.
        latents: ``None`` here; presumably populated by subclasses (TODO confirm).
        labels: ``None`` here; not used by this class.
        is_radial (list): Empty list here; not used by this class.
        partition (bool): ``True`` here; not used by this class.
    """

    def __init__(self, x, y):
        """Initialize dataset.

        Args:
            x(ndarray): Input features; the first axis indexes samples.
            y(ndarray): Targets; the first axis indexes samples.
        """
        # torch.from_numpy needs a numeric ndarray; .float() casts to float32.
        self.data = torch.from_numpy(x).float()
        self.targets = torch.from_numpy(y).float()
        self.latents = None

        self.labels = None
        self.is_radial = []
        self.partition = True

    def __getitem__(self, index):
        """Return ``(features, target, index)`` for the sample at ``index``."""
        return self.data[index], self.targets[index], index

    def __len__(self):
        """Return the number of samples (length of the first axis of ``data``)."""
        return len(self.data)

    def numpy(self, idx=None):
        """Get dataset as ndarray.

        Specify indices to return a subset of the dataset, otherwise return whole dataset.

        Args:
            idx(int, slice or array-like, optional): Index or indices to return.
                ``None`` (default) returns the whole dataset.

        Returns:
            tuple: ``(data, targets)`` as ndarrays. Each sample of ``data`` is
            flattened, so the whole dataset has shape ``(n_samples, n_features)``.
        """
        n = len(self)
        data = self.data.numpy()
        # Compute the flattened width explicitly rather than reshaping with -1:
        # NumPy cannot infer -1 for a size-0 array, so an empty dataset would
        # otherwise raise ValueError.
        width = int(np.prod(data.shape[1:], dtype=np.int64))
        data = data.reshape((n, width))

        if idx is None:
            return data, self.targets.numpy()
        return data[idx], self.targets[idx].numpy()

    def get_latents(self):
        """Get latent variables.

        Returns:
            latents(ndarray): Latent variables for each sample, or ``None`` if
            they were never set (this class initializes them to ``None``).
        """
        return self.latents


def load_json(file_path):
    """Load and return the decoded contents of a JSON file.

    Args:
        file_path: Path of the JSON file to read.

    Returns:
        The decoded JSON value (dict, list, str, number, bool or None).

    Raises:
        OSError: If the file cannot be opened.
        UnicodeDecodeError: If the file is not valid UTF-8.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # JSON text is UTF-8 (RFC 8259); be explicit so decoding does not depend
    # on the platform's default locale encoding (e.g. cp1252 on Windows).
    with open(file_path, 'r', encoding='utf-8') as f:
        return json.load(f)

def read_json_files(file, label=1):
    """Parse one JSON file of samples into a feature matrix and a label vector.

    The file is decoded with ``load_json`` and iterated sample by sample; each
    sample is expected to be a dict (TODO confirm the file holds a list of
    them). A feature row is built, in this fixed order, from:

    * the concatenated contents of the 24 array-valued keys ``AX1``..``AX4``,
      ``AY1``..``AY4``, ``AZ1``..``AZ4``, ``GX1``..``GX4``, ``GY1``..``GY4``,
      ``GZ1``..``GZ4``;
    * the 5 time-difference scalars ``GZ1_precise_time_diff``..
      ``GZ4_precise_time_diff`` and ``precise_time_diff``, each rounded to
      an integer and multiplied by 20.

    A sample is skipped when any of its values is ``null`` (silently), or
    when its row length differs from ``24 * 64 + 5`` (with a warning). A
    missing key is reported with a warning; the shortened row is then
    normally rejected by the length check.

    Args:
        file: Path of the JSON file to read.
        label (int): Class label assigned to every kept sample. Defaults to 1,
            the value this function always used.

    Returns:
        tuple: ``(features, labels)`` where ``features`` is a float32 tensor of
        shape ``(num_valid, 1541)`` and ``labels`` a long tensor of shape
        ``(num_valid,)`` filled with ``label``.

    Raises:
        ValueError: If the file contains no valid sample.
    """
    array_keys = tuple(
        f"{channel}{i}"
        for channel in ("AX", "AY", "AZ", "GX", "GY", "GZ")
        for i in range(1, 5)
    )
    time_keys = (
        'GZ1_precise_time_diff', 'GZ2_precise_time_diff',
        'GZ3_precise_time_diff', 'GZ4_precise_time_diff',
        'precise_time_diff',
    )
    # 24 array keys * 64 values each + 5 scalar time differences = 1541.
    expected_len = len(array_keys) * 64 + len(time_keys)

    samples = load_json(file)
    data_x = []

    for sample in samples:
        data = []
        skip_sample = False
        for key in array_keys + time_keys:
            if key not in sample:
                warnings.warn(f"KeyError: {key} not found in JSON file: {file}")
                continue
            value = sample[key]
            if value is None:
                # A null (array or scalar) makes the sample unusable; previously
                # a null array crashed list.extend with TypeError.
                skip_sample = True
                break
            if key.endswith('precise_time_diff'):
                data.append(round(value) * 20)
            else:
                data.extend(value)

        if skip_sample:
            continue

        if len(data) != expected_len:
            warnings.warn(f"Incomplete sample in JSON file: {file}")
            continue

        data_x.append(torch.tensor(data, dtype=torch.float32))

    if not data_x:
        raise ValueError(f"No valid samples found in JSON file: {file}")

    labels = torch.full((len(data_x),), label, dtype=torch.long)
    return torch.stack(data_x), labels