|
from abc import ABC, abstractmethod |
|
from collections.abc import Callable |
|
from functools import wraps |
|
from typing import Any |
|
from numpy.typing import NDArray |
|
import numpy as np |
|
import torch |
|
|
|
class Imputer(ABC): |
|
''' ... ''' |
|
def __init__(self, |
|
modalities: dict[str, dict[str, Any]], |
|
is_embedding: dict[str, bool] | None = None |
|
) -> None: |
|
''' ... ''' |
|
self.modalities = modalities |
|
self.is_embedding = is_embedding |
|
|
|
@abstractmethod |
|
def __call__(self, |
|
smp: dict[str, int | NDArray[np.float32] | None], |
|
) -> dict[str, int | NDArray[np.float32]]: |
|
''' ... ''' |
|
pass |
|
|
|
@staticmethod |
|
def _keyerror_hint(func): |
|
''' Print hint for resolving KeyError. ''' |
|
@wraps(func) |
|
def wrapper(*args, **kwargs): |
|
try: |
|
return func(*args, **kwargs) |
|
except KeyError as err: |
|
raise ValueError('Format the data using Formatter module.') from err |
|
return wrapper |
|
|
|
|
|
class ConstantImputer(Imputer):
    ''' Imputer that replaces every missing (``None``) modality with zeros. '''

    # Length of the zero vector substituted for modalities flagged as embeddings.
    # Kept as a class attribute so subclasses can override it without copying
    # the whole ``__call__`` implementation.
    EMBEDDING_DIM: int = 256

    @Imputer._keyerror_hint
    def __call__(
        self,
        smp: dict[str, int | NDArray[np.float32] | None],
    ) -> dict[str, int | NDArray[np.float32]]:
        ''' Return a new sample dict in which missing modalities are zero-filled.

        For every modality ``k`` in ``self.modalities``:

        * a non-``None`` ``smp[k]`` is passed through unchanged (same object);
        * otherwise, if ``self.is_embedding`` flags ``k`` as truthy, a float32
          zero vector of length ``EMBEDDING_DIM`` is used;
        * otherwise a ``'categorical'`` modality becomes ``0`` and a
          ``'numerical'`` or ``'imaging'`` modality becomes a float32 zero
          array of shape ``info['shape']``.

        Keys of ``smp`` that are not listed in ``self.modalities`` are dropped.

        Raises
        ------
        ValueError
            If a missing modality has an unsupported ``'type'``, or (through
            the ``_keyerror_hint`` decorator) if a required key is absent
            from ``smp`` or from the modality description.
        '''
        new = {}
        for k, info in self.modalities.items():
            if smp[k] is not None:
                new[k] = smp[k]
            elif self.is_embedding is not None and self.is_embedding.get(k, False):
                # The embedding flag takes precedence over the declared type.
                new[k] = np.zeros(self.EMBEDDING_DIM, dtype=np.float32)
            elif info['type'] == 'categorical':
                new[k] = 0
            elif info['type'] in ('numerical', 'imaging'):
                new[k] = np.zeros(tuple(info['shape']), dtype=np.float32)
            else:
                raise ValueError(
                    f"Unsupported modality type {info['type']!r} for modality {k!r}."
                )
        return new
|
|
|
|
|
class FrequencyImputer(Imputer):
    ''' Imputer that fills missing numerical modalities by sampling observed values. '''

    @Imputer._keyerror_hint
    def __init__(
        self,
        modalities: dict[str, dict[str, Any]],
        dat: list[dict[str, int | NDArray[np.float32] | None]],
    ) -> None:
        ''' Build, per modality, the pool of values observed in ``dat``.

        Parameters
        ----------
        modalities : mapping from modality name to its description dict.
        dat : samples to draw replacement values from; every sample must
            contain every modality key (the value may be ``None``).

        Raises
        ------
        ValueError
            Through the ``_keyerror_hint`` decorator, if a sample in ``dat``
            lacks one of the modality keys.
        '''
        super().__init__(modalities)

        # For each modality keep only the values that are actually present.
        self.pool = {
            k: [smp[k] for smp in dat if smp[k] is not None]
            for k in modalities
        }

    @Imputer._keyerror_hint
    def __call__(
        self,
        smp: dict[str, int | NDArray[np.float32] | None],
    ) -> dict[str, int | NDArray[np.float32]]:
        ''' Return a new sample dict in which missing modalities are imputed.

        For every modality ``k`` in ``self.modalities``:

        * a non-``None`` ``smp[k]`` is passed through unchanged (same object);
        * a missing ``'categorical'`` modality becomes ``0``;
        * a missing ``'numerical'`` modality becomes a copy (``np.array``) of
          one value drawn uniformly at random from ``self.pool[k]`` using
          ``np.random.randint``;
        * a missing ``'imaging'`` modality becomes a float32 zero array of
          shape ``info['shape']``.

        Keys of ``smp`` that are not listed in ``self.modalities`` are dropped.

        Raises
        ------
        ValueError
            If a missing numerical modality has an empty pool, if a missing
            modality has an unsupported ``'type'``, or (through the
            ``_keyerror_hint`` decorator) if a required key is absent.
        '''
        new = {}
        for k, info in self.modalities.items():
            if smp[k] is not None:
                new[k] = smp[k]
            elif info['type'] == 'categorical':
                new[k] = 0
            elif info['type'] == 'numerical':
                candidates = self.pool[k]
                if not candidates:
                    # np.random.randint(0, 0) would fail with an opaque
                    # "low >= high" message; say what is actually wrong.
                    raise ValueError(
                        f'No observed values available to impute modality {k!r}.'
                    )
                rnd_idx = np.random.randint(0, len(candidates))
                new[k] = np.array(candidates[rnd_idx])
            elif info['type'] == 'imaging':
                new[k] = np.zeros(tuple(info['shape']), dtype=np.float32)
            else:
                raise ValueError(
                    f"Unsupported modality type {info['type']!r} for modality {k!r}."
                )
        return new