File size: 3,464 Bytes
6fc43ab |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 |
from abc import ABC, abstractmethod
from collections.abc import Callable
from functools import wraps
from typing import Any
from numpy.typing import NDArray
import numpy as np
import torch
class Imputer(ABC):
''' ... '''
def __init__(self,
modalities: dict[str, dict[str, Any]],
is_embedding: dict[str, bool] | None = None
) -> None:
''' ... '''
self.modalities = modalities
self.is_embedding = is_embedding
@abstractmethod
def __call__(self,
smp: dict[str, int | NDArray[np.float32] | None],
) -> dict[str, int | NDArray[np.float32]]:
''' ... '''
pass
@staticmethod
def _keyerror_hint(func):
''' Print hint for resolving KeyError. '''
@wraps(func)
def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except KeyError as err:
raise ValueError('Format the data using Formatter module.') from err
return wrapper
class ConstantImputer(Imputer):
    ''' Impute missing modalities with constant zero placeholders.

    Entries of the sample that are not ``None`` are passed through unchanged
    (the same object, not a copy).
    '''

    # Length of the zero vector substituted for a missing modality that is
    # flagged in ``is_embedding``. Subclasses may override it.
    EMBEDDING_DIM = 256

    @Imputer._keyerror_hint
    def __call__(
        self,
        smp: dict[str, int | NDArray[np.float32] | None],
    ) -> dict[str, int | NDArray[np.float32]]:
        ''' Return a new dict with every modality of ``self.modalities`` filled.

        A missing (``None``) entry is replaced as follows:

        * flagged truthy in ``self.is_embedding``: a float32 zero vector of
          length ``EMBEDDING_DIM``;
        * type ``'categorical'``: the integer ``0``;
        * type ``'numerical'`` or ``'imaging'``: a float32 zero array of
          shape ``info['shape']``.

        Args:
            smp: Sample keyed by modality name; ``None`` marks a missing value.

        Returns:
            A new dict holding one entry per key of ``self.modalities``.

        Raises:
            ValueError: If a missing modality has an unrecognized type, or
                (through the KeyError hint) if a required key is absent from
                ``smp`` or from the modality metadata.
        '''
        # ``is_embedding`` is optional; an empty mapping means "no embeddings".
        embedding_flags = self.is_embedding or {}
        new = {}
        for k, info in self.modalities.items():
            value = smp[k]
            if value is not None:
                new[k] = value
            elif embedding_flags.get(k):
                # The embedding flag takes precedence over the declared type.
                new[k] = np.zeros(self.EMBEDDING_DIM, dtype=np.float32)
            elif info['type'] == 'categorical':
                new[k] = 0
            elif info['type'] in ('numerical', 'imaging'):
                new[k] = np.zeros(tuple(info['shape']), dtype=np.float32)
            else:
                raise ValueError(
                    f"Cannot impute modality {k!r}: "
                    f"unknown type {info['type']!r}."
                )
        return new
class FrequencyImputer(Imputer):
    ''' Impute missing numerical modalities by sampling observed values.

    At construction time the non-missing values of every modality are
    collected from ``dat`` into ``self.pool``. At call time a missing
    numerical entry is replaced by a value drawn uniformly at random from
    that pool; categorical and imaging entries receive constant fill values.
    '''

    @Imputer._keyerror_hint
    def __init__(
        self,
        modalities: dict[str, dict[str, Any]],
        dat: list[dict[str, int | NDArray[np.float32] | None]],
    ) -> None:
        ''' Build the per-modality pools of observed values.

        Args:
            modalities: Per-modality metadata, keyed by modality name.
            dat: Samples to draw from; each must contain every key of
                ``modalities`` (``None`` marks a missing value).

        Raises:
            ValueError: Through the KeyError hint, if a sample lacks one of
                the modality keys.
        '''
        super().__init__(modalities)
        # Transpose list-of-dicts into dict-of-lists, dropping missing values.
        self.pool = {
            k: [v for v in (smp[k] for smp in dat) if v is not None]
            for k in modalities
        }

    @Imputer._keyerror_hint
    def __call__(
        self,
        smp: dict[str, int | NDArray[np.float32] | None],
    ) -> dict[str, int | NDArray[np.float32]]:
        ''' Return a new dict with every modality of ``self.modalities`` filled.

        Present entries are passed through unchanged. A missing (``None``)
        entry is replaced as follows:

        * type ``'categorical'``: the integer ``0``;
        * type ``'numerical'``: an array copy of a value picked at random
          (``np.random.randint``) from the pool gathered in ``__init__``;
        * type ``'imaging'``: a float32 zero array of shape ``info['shape']``.

        Args:
            smp: Sample keyed by modality name; ``None`` marks a missing value.

        Returns:
            A new dict holding one entry per key of ``self.modalities``.

        Raises:
            ValueError: If a missing modality has an unrecognized type, if
                the pool of a missing numerical modality is empty, or
                (through the KeyError hint) if a required key is absent.
        '''
        new = {}
        for k, info in self.modalities.items():
            value = smp[k]
            if value is not None:
                new[k] = value
                continue

            kind = info['type']
            if kind == 'categorical':
                new[k] = 0
            elif kind == 'numerical':
                candidates = self.pool[k]
                if not candidates:
                    # randint(0, 0) would fail with an opaque "low >= high";
                    # raise the same exception type with a usable message.
                    raise ValueError(
                        f"Cannot impute modality {k!r}: "
                        "no observed values to sample from."
                    )
                rnd_idx = np.random.randint(0, len(candidates))
                new[k] = np.array(candidates[rnd_idx])
            elif kind == 'imaging':
                new[k] = np.zeros(tuple(info['shape']), dtype=np.float32)
            else:
                raise ValueError(
                    f"Cannot impute modality {k!r}: unknown type {kind!r}."
                )
        return new