File size: 2,765 Bytes
6fc43ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
from typing import Any
from numpy.typing import NDArray
import numpy as np

class Formatter:
    ''' ... '''
    def __init__(self, 
        modalities: dict[str, dict[str, Any]],
    ) -> None:
        ''' ... '''
        self.modalities = modalities

    def __call__(self,
        smp: dict[str, Any],
    ) -> dict[str, int | NDArray[np.float32] | None]:
        ''' ... '''
        new = dict()

        # loop through all data modalities
        for k, info in self.modalities.items():
            # the value is missing or equals None
            if k not in smp or smp[k] is None:
                new[k] = None
                continue

            # get value
            v = smp[k]

            # if info['type'] == 'imaging':
            #     print(k)
            # print(v.shape)
            if info['type'] == 'imaging' and len(info['shape']) == 4:
                new[k] = v
                continue

            # validate the value by using numpy's intrinsic machanism 
            try:
                v_np = np.array(v, dtype=np.float32)
            except:
                raise ValueError('\"{}\" has unexpected value {}'.format(k, v))
            
            # additional validation for categorical value
            if info['type'] == 'categorical':
                # print(k, v_np.shape)
                if v_np.shape != ():
                    raise ValueError('Categorical data \"{}\" has unexpected value {}.'.format(k, v))
                elif int(v) != v:
                    raise ValueError('Categorical data \"{}\" has unexpected value {}.'.format(k, v))
                elif v < 0: # or v >= info['num_categories']:
                    raise ValueError('Categorical data \"{}\" has unexpected value {}.'.format(k, v))
            
            # additional validation for numerical value
            elif info['type'] == 'numerical':
                if info['shape'] == [1] and v_np.shape != () and v_np.shape != (1,):
                    raise ValueError('Numerical data \"{}\" has unexpected shape {}.'.format(k, v_np.shape))
                elif info['shape'] != [1] and tuple(info['shape']) != v_np.shape:
                    raise ValueError('Numerical data \"{}\" has unexpected shape {}.'.format(k, v_np.shape))

            
                
            # format categorical value
            if info['type'] == 'categorical':
                new[k] = int(v)

            # format numerical value
            elif info['type'] == 'numerical' or info['type'] == 'imaging':
                if info['shape'] == [1] and v_np.shape == ():
                    # unsqueeze the data
                    new[k] = np.array([v], dtype=np.float32)
                else:
                    new[k] = v_np
            

        return new