|
import torch |
|
from torch.utils.data import Dataset |
|
import numpy as np |
|
from functools import cached_property |
|
from typing import Any, Type |
|
from numpy.typing import NDArray |
|
import random |
|
from monai.utils.type_conversion import convert_to_tensor |
|
import time |
|
|
|
# Alias for annotation readability. NOTE: this was `Type[torch.Tensor]`,
# which denotes the *class object* (i.e. the value `torch.Tensor` itself),
# not a tensor instance — wrong for annotating returned tensors.
Tensor = torch.Tensor
|
|
|
from .masker import Masker |
|
from . import DropoutMasker |
|
from . import MissingMasker |
|
from . import LabelMasker |
|
|
|
from .imputer import Imputer |
|
from . import FrequencyImputer |
|
from . import ConstantImputer |
|
from . import Formatter |
|
import random |
|
import os |
|
|
|
class TransformerDataset(torch.utils.data.Dataset):
    ''' Base dataset pairing formatted source features with optional targets.

    Subclasses provide the ``imputer_src`` / ``imputer_tgt`` / ``masker_src``
    / ``masker_tgt`` cached properties that ``__getitem__`` uses to fill in
    and mask missing modalities.
    '''

    def __init__(self,
        src: list[dict[str, Any]],
        tgt: list[dict[str, Any]] | None,
        src_modalities: dict[str, dict[str, Any]],
        tgt_modalities: dict[str, dict[str, Any]] | None,
        img_transform: Any | None = None,
        is_embedding: dict[str, bool] | None = None
    ) -> None:
        ''' Format all samples up front using the modality specifications.

        Parameters
        ----------
        src: per-sample feature dicts.
        tgt: per-sample label dicts, or None for unlabeled (testing) data.
        src_modalities: per-feature metadata consumed by Formatter and the
            subclass imputers/maskers.
        tgt_modalities: per-label metadata; expected when tgt is given.
        img_transform: optional transform applied lazily in __getitem__ to
            image features whose value is a string path.
        is_embedding: optional per-feature flag forwarded to the imputer.
        '''
        self.is_embedding = is_embedding

        # format source samples eagerly
        self.fmt_src = Formatter(src_modalities)
        self.src = [self.fmt_src(smp) for smp in src]
        self.src_modalities = src_modalities

        # BUGFIX: this assignment used to sit AFTER the unlabeled
        # early-return below, so datasets constructed with tgt=None
        # (e.g. TransformerTestingDataset) never had self.img_transform
        # and img_input_trans raised AttributeError. Assign unconditionally.
        self.img_transform = img_transform

        # Unlabeled dataset: self.tgt / self.fmt_tgt / self.tgt_modalities
        # are intentionally left unset; __getitem__ tests hasattr(self, 'tgt').
        if tgt is None:
            return

        self.fmt_tgt = Formatter(tgt_modalities)
        self.tgt = [self.fmt_tgt(smp) for smp in tgt]
        self.tgt_modalities = tgt_modalities

    def __len__(self) -> int:
        ''' Number of samples. '''
        return len(self.src)

    def img_input_trans(self, k, x):
        ''' Load/transform image feature ``k`` from value ``x`` (a file path).

        Returns the transformed volume, ``x`` unchanged when no transform is
        configured, or None when the transform fails or yields an invalid
        volume (contains NaNs, or channel dimension != 1).
        '''
        if self.img_transform is None:
            return x
        try:
            mri = self.img_transform({"image": x})["image"]
            # reject corrupt volumes so the sample is treated as missing
            if torch.isnan(mri).any() or mri.size(0) != 1:
                return None
            return mri
        except Exception:
            # was a bare `except:`, which also swallowed KeyboardInterrupt /
            # SystemExit; a failed image load is treated as a missing feature
            return None

    def __getitem__(self,
        idx: int
    ) -> tuple[
        dict[str, int | NDArray[np.float32]],
        dict[str, int | NDArray[np.float32]],
        dict[str, bool],
        dict[str, int | NDArray[np.float32]],
    ]:
        ''' Return (x_imputed, y_imputed, x_mask, y_mask) for one sample.

        y_imputed and y_mask are None when the dataset is unlabeled.
        '''
        # Lazily transform image-path features; the result is cached
        # in-place in self.src so later accesses skip the transform.
        for k, v in self.src[idx].items():
            if isinstance(v, str):
                if not os.path.exists(v):
                    # was a bare `assert`, which is stripped under `python -O`
                    raise FileNotFoundError(f'image file not found: {v}')
                self.src[idx][k] = self.img_input_trans(k, v)

        x_imp = self.imputer_src(self.src[idx])
        mask_x = self.masker_src(self.src[idx])
        y_imp = self.imputer_tgt(self.tgt[idx]) if hasattr(self, 'tgt') else None
        mask_y = self.masker_tgt(self.tgt[idx]) if hasattr(self, 'tgt') else None

        # Materialize np.memmap values as in-memory arrays (presumably to
        # avoid sharing memmap handles across DataLoader workers — confirm).
        for k, v in x_imp.items():
            if isinstance(v, np.memmap):
                x_imp[k] = np.load(v.filename)
                x_imp[k] = np.reshape(x_imp[k], v.shape)

        return x_imp, y_imp, mask_x, mask_y

    @cached_property
    def imputer_src(self) -> Imputer:
        ''' Imputer for source features; must be provided by subclasses. '''
        raise NotImplementedError

    @cached_property
    def imputer_tgt(self) -> Imputer:
        ''' Imputer for target labels; override in labeled subclasses. '''
        pass

    @cached_property
    def masker_src(self) -> Masker:
        ''' Mask generator for source features; must be provided by subclasses. '''
        raise NotImplementedError

    @cached_property
    def masker_tgt(self) -> LabelMasker:
        ''' Mask generator for target labels; override in labeled subclasses. '''
        pass

    @staticmethod
    def collate_fn(
        batch: list[
            tuple[
                dict[str, int | NDArray[np.float32]],
                dict[str, int | NDArray[np.float32]],
                dict[str, bool],
                dict[str, int | NDArray[np.float32]],
            ]
        ]
    ) -> tuple[
        dict[str, Tensor],
        dict[str, Tensor],
        dict[str, Tensor],
        dict[str, Tensor],
    ]:
        ''' Collate per-sample tuples into per-modality batched tensors. '''
        _x = [smp[0] for smp in batch]
        y = [smp[1] for smp in batch]
        m = [smp[2] for smp in batch]
        m_y = [smp[3] for smp in batch]

        # replace None label values (unlabeled samples) with 0 so they batch
        y = [{k: v if v is not None else 0 for k, v in y[i].items()} for i in range(len(y))]

        x = {k: torch.stack([convert_to_tensor(_x[i][k]) for i in range(len(_x))]) for k in _x[0]}
        y = {k: torch.as_tensor(np.array([y[i][k] for i in range(len(y))])) for k in y[0]}
        m = {k: torch.as_tensor(np.array([m[i][k] for i in range(len(m))])) for k in m[0]}
        m_y = {k: torch.as_tensor(np.array([m_y[i][k] for i in range(len(m_y))])) for k in m_y[0]}

        return x, y, m, m_y
|
|
|
|
|
class TransformerTrainingDataset(TransformerDataset):
    ''' Training variant: frequency-based imputation and dropout masking. '''

    def __init__(self,
        src: list[dict[str, Any]],
        tgt: list[dict[str, Any]],
        src_modalities: dict[str, dict[str, Any]],
        tgt_modalities: dict[str, dict[str, Any]],
        dropout_rate: float = .5,
        dropout_strategy: str = 'permutation',
        img_transform: Any | None = None,
    ) -> None:
        ''' See TransformerDataset.__init__ for the shared parameters.

        dropout_rate and dropout_strategy are forwarded to DropoutMasker.
        '''
        super().__init__(src, tgt, src_modalities, tgt_modalities, img_transform=img_transform)

        self.dropout_rate = dropout_rate
        self.dropout_strategy = dropout_strategy
        # removed leftover debug `print(img_transform)` that polluted stdout

    @cached_property
    def imputer_src(self) -> FrequencyImputer:
        ''' Imputer that fills source features from the training samples. '''
        return FrequencyImputer(self.src_modalities, self.src)

    @cached_property
    def imputer_tgt(self) -> ConstantImputer:
        ''' Constant-fill imputer for target labels. '''
        return ConstantImputer(self.tgt_modalities)

    @cached_property
    def masker_src(self) -> DropoutMasker:
        ''' Mask generator that randomly drops available source features. '''
        return DropoutMasker(
            self.src_modalities, self.src,
            dropout_rate = self.dropout_rate,
            dropout_strategy = self.dropout_strategy,
        )

    @cached_property
    def masker_tgt(self) -> LabelMasker:
        ''' Mask generator for target labels. '''
        return LabelMasker(self.tgt_modalities)
|
|
|
class TransformerValidationDataset(TransformerDataset):
    ''' Validation variant: constant imputation and missingness-based masks. '''

    def __init__(self,
        src: list[dict[str, Any]],
        tgt: list[dict[str, Any]],
        src_modalities: dict[str, dict[str, Any]],
        tgt_modalities: dict[str, dict[str, Any]],
        img_transform: Any | None = None,
        is_embedding: dict[str, bool] | None = None
    ) -> None:
        ''' Delegate entirely to TransformerDataset.__init__. '''
        super().__init__(
            src, tgt, src_modalities, tgt_modalities,
            img_transform=img_transform,
            is_embedding=is_embedding,
        )

    @cached_property
    def imputer_src(self) -> ConstantImputer:
        ''' Constant-fill imputer for source features. '''
        return ConstantImputer(self.src_modalities, self.is_embedding)

    @cached_property
    def imputer_tgt(self) -> ConstantImputer:
        ''' Constant-fill imputer for target labels. '''
        return ConstantImputer(self.tgt_modalities)

    @cached_property
    def masker_src(self) -> MissingMasker:
        ''' Mask generator flagging which source features are present. '''
        return MissingMasker(self.src_modalities)

    @cached_property
    def masker_tgt(self) -> LabelMasker:
        ''' Mask generator for target labels. '''
        return LabelMasker(self.tgt_modalities)
|
|
|
|
|
class TransformerTestingDataset(TransformerValidationDataset):
    ''' Testing variant: unlabeled data, yields (features, feature-mask) only. '''

    def __init__(self,
        src: list[dict[str, Any]],
        src_modalities: dict[str, dict[str, Any]],
        img_transform: Any | None = None,
        is_embedding: dict[str, bool] | None = None
    ) -> None:
        ''' Construct the parent with no targets and no target modalities. '''
        super().__init__(
            src, None, src_modalities, None,
            img_transform=img_transform,
            is_embedding=is_embedding,
        )

    def __getitem__(self,
        idx: int
    ) -> tuple[
        dict[str, int | NDArray[np.float32]],
        dict[str, bool],
    ]:
        ''' Return (x_imputed, x_mask); the target slots are discarded. '''
        x_imp, _, mask_x, _ = super().__getitem__(idx)
        return x_imp, mask_x

    @staticmethod
    def collate_fn(
        batch: list[
            tuple[
                dict[str, int | NDArray[np.float32]],
                dict[str, bool],
            ]
        ]
    ) -> tuple[
        dict[str, Tensor],
        dict[str, Tensor],
    ]:
        ''' Collate (features, mask) pairs into per-modality batched tensors. '''
        feats = [pair[0] for pair in batch]
        masks = [pair[1] for pair in batch]

        x = {k: torch.as_tensor(np.array([smp[k] for smp in feats])) for k in feats[0]}
        m = {k: torch.as_tensor(np.array([smp[k] for smp in masks])) for k in masks[0]}

        return x, m
|
|
|
|
|
class TransformerBalancedTrainingDataset(TransformerTrainingDataset):
    ''' Training variant that balances sampling across tasks and classes. '''

    def __init__(self,
        src: list[dict[str, Any]],
        tgt: list[dict[str, Any]],
        src_modalities: dict[str, dict[str, Any]],
        tgt_modalities: dict[str, dict[str, Any]],
        dropout_rate: float = .5,
        dropout_strategy: str = 'permutation',
        img_transform: Any | None = None,
    ) -> None:
        ''' Build the base training dataset, then index samples by task/class. '''
        super().__init__(
            src, tgt, src_modalities, tgt_modalities,
            dropout_rate, dropout_strategy, img_transform,
        )

        # tgt_indices[task][cls] lists the sample indices whose label for
        # `task` equals `cls` (0 or 1); built in a single pass per task
        self.tgt_indices: dict[str, dict[int, list[int]]] = dict()
        for task in self.tgt_modalities:
            buckets: dict[int, list[int]] = {0: [], 1: []}
            for i in range(len(self.tgt)):
                label = self.tgt[i][task]
                if label == 0:
                    buckets[0].append(i)
                elif label == 1:
                    buckets[1].append(i)
            self.tgt_indices[task] = buckets

    def __getitem__(self,
        idx: int
    ) -> tuple[
        dict[str, int | NDArray[np.float32]],
        dict[str, int | NDArray[np.float32]],
        dict[str, bool],
        dict[str, bool],
    ]:
        ''' Ignore idx: draw a random task and class, then a matching sample. '''
        task = random.choice(list(self.tgt_modalities.keys()))
        cls = random.choice([0, 1])
        sample_idx = random.choice(self.tgt_indices[task][cls])

        x_imp, y_imp, mask_x, mask_y = super().__getitem__(sample_idx)

        # supervise only the sampled task; zero out every other label mask
        mask_y = {k: mask_y[k] if k == task else 0 for k in self.tgt_modalities}

        return x_imp, y_imp, mask_x, mask_y
|
|
|
|
|
class Transformer2ndOrderBalancedTrainingDataset(TransformerTrainingDataset):
    ''' Training variant balancing over ordered pairs of target tasks. '''

    def __init__(self,
        src: list[dict[str, Any]],
        tgt: list[dict[str, Any]],
        src_modalities: dict[str, dict[str, Any]],
        tgt_modalities: dict[str, dict[str, Any]],
        dropout_rate: float = .5,
        dropout_strategy: str = 'permutation',
        img_transform: Any | None = None,
    ) -> None:
        """ Build the base training dataset, then index samples by task pair. """
        super().__init__(
            src, tgt, src_modalities, tgt_modalities,
            dropout_rate, dropout_strategy, img_transform,
        )

        # tasks[(k0, k1)] lists sample indices labeled 0 for k0 AND 1 for k1
        tgt_keys = list(self.tgt_modalities.keys())
        self.tasks: dict[tuple[str, str], list[int]] = {
            (tgt_k_0, tgt_k_1): []
            for tgt_k_0 in tgt_keys
            for tgt_k_1 in tgt_keys
        }
        for i, smp in enumerate(tgt):
            for tgt_k_0 in tgt_keys:
                if smp[tgt_k_0] == 0:
                    for tgt_k_1 in tgt_keys:
                        if smp[tgt_k_1] == 1:
                            self.tasks[(tgt_k_0, tgt_k_1)].append(i)

    def __getitem__(self,
        idx: int
    ) -> tuple[
        dict[str, int | NDArray[np.float32]],
        dict[str, int | NDArray[np.float32]],
        dict[str, bool],
        dict[str, bool],
    ]:
        ''' Ignore idx: draw a random non-empty task pair, then a sample. '''
        # BUGFIX: the original rejection loop (`while True` re-drawing pairs
        # until a non-empty one appeared) would spin forever when every task
        # list was empty. Sampling uniformly from the non-empty pairs yields
        # the same distribution without the hang hazard.
        nonempty = [pair for pair, indices in self.tasks.items() if indices]
        if not nonempty:
            raise RuntimeError('no sample has a (0, 1) label for any task pair')
        tgt_k_0, tgt_k_1 = random.choice(nonempty)
        idx = random.choice(self.tasks[(tgt_k_0, tgt_k_1)])

        x_imp, y_imp, mask_x, mask_y = super().__getitem__(idx)

        # supervise only the two sampled tasks; zero out the other masks
        mask_y = {k: mask_y[k] if k in [tgt_k_0, tgt_k_1] else 0 for k in self.tgt_modalities}

        return x_imp, y_imp, mask_x, mask_y
|
|
|
|