|
from typing import Any |
|
from numpy.typing import NDArray |
|
import numpy as np |
|
|
|
class Formatter: |
|
''' ... ''' |
|
def __init__(self, |
|
modalities: dict[str, dict[str, Any]], |
|
) -> None: |
|
''' ... ''' |
|
self.modalities = modalities |
|
|
|
def __call__(self, |
|
smp: dict[str, Any], |
|
) -> dict[str, int | NDArray[np.float32] | None]: |
|
''' ... ''' |
|
new = dict() |
|
|
|
|
|
for k, info in self.modalities.items(): |
|
|
|
if k not in smp or smp[k] is None: |
|
new[k] = None |
|
continue |
|
|
|
|
|
v = smp[k] |
|
|
|
|
|
|
|
|
|
if info['type'] == 'imaging' and len(info['shape']) == 4: |
|
new[k] = v |
|
continue |
|
|
|
|
|
try: |
|
v_np = np.array(v, dtype=np.float32) |
|
except: |
|
raise ValueError('\"{}\" has unexpected value {}'.format(k, v)) |
|
|
|
|
|
if info['type'] == 'categorical': |
|
|
|
if v_np.shape != (): |
|
raise ValueError('Categorical data \"{}\" has unexpected value {}.'.format(k, v)) |
|
elif int(v) != v: |
|
raise ValueError('Categorical data \"{}\" has unexpected value {}.'.format(k, v)) |
|
elif v < 0: |
|
raise ValueError('Categorical data \"{}\" has unexpected value {}.'.format(k, v)) |
|
|
|
|
|
elif info['type'] == 'numerical': |
|
if info['shape'] == [1] and v_np.shape != () and v_np.shape != (1,): |
|
raise ValueError('Numerical data \"{}\" has unexpected shape {}.'.format(k, v_np.shape)) |
|
elif info['shape'] != [1] and tuple(info['shape']) != v_np.shape: |
|
raise ValueError('Numerical data \"{}\" has unexpected shape {}.'.format(k, v_np.shape)) |
|
|
|
|
|
|
|
|
|
if info['type'] == 'categorical': |
|
new[k] = int(v) |
|
|
|
|
|
elif info['type'] == 'numerical' or info['type'] == 'imaging': |
|
if info['shape'] == [1] and v_np.shape == (): |
|
|
|
new[k] = np.array([v], dtype=np.float32) |
|
else: |
|
new[k] = v_np |
|
|
|
|
|
return new |