import torch
from torch.utils.data import Dataset
import numpy as np
from functools import cached_property
from typing import Any
from numpy.typing import NDArray
import random
from monai.utils.type_conversion import convert_to_tensor
Tensor = torch.Tensor
from .masker import Masker
from . import DropoutMasker
from . import MissingMasker
from . import LabelMasker
from .imputer import Imputer
from . import FrequencyImputer
from . import ConstantImputer
from . import Formatter
import os
class TransformerDataset(Dataset):
    '''Base dataset that formats source/target samples and yields (x_imp, y_imp, mask_x, mask_y) tuples for the transformer model.'''
def __init__(self,
src: list[dict[str, Any]],
tgt: list[dict[str, Any]] | None,
src_modalities: dict[str, dict[str, Any]],
tgt_modalities: dict[str, dict[str, Any]] | None,
img_transform: Any | None = None,
is_embedding: dict[str, bool] | None = None
) -> None:
        '''Format the source (and optional target) samples and store the modality metadata, image transform and embedding flags.'''
# boolean dict to indicate which features are embeddings
self.is_embedding = is_embedding
# format source
self.fmt_src = Formatter(src_modalities)
self.src = [self.fmt_src(smp) for smp in src]
self.src_modalities = src_modalities
        # the image transform must be set before the early return below so it is
        # also available when no targets are provided (e.g. at test time)
        self.img_transform = img_transform

        # format target (absent for the testing dataset)
        if tgt is None: return
        self.fmt_tgt = Formatter(tgt_modalities)
        self.tgt = [self.fmt_tgt(smp) for smp in tgt]
        self.tgt_modalities = tgt_modalities
def __len__(self) -> int:
''' ... '''
return len(self.src)
    def img_input_trans(self, k, x):
        '''Apply the image transform to a file path; return None if it fails or yields an invalid volume.'''
        if self.img_transform is None:
            return x
        try:
            mri = self.img_transform({"image": x})["image"]
            # reject volumes containing NaNs or with an unexpected channel dimension
            if torch.isnan(mri).any() or mri.size(0) != 1:
                return None
            return mri
        except Exception:
            return None
def __getitem__(self,
idx: int
) -> tuple[
dict[str, int | NDArray[np.float32]],
dict[str, int | NDArray[np.float32]],
dict[str, bool],
        dict[str, bool],
]:
        '''Return (x_imp, y_imp, mask_x, mask_y) for the sample at index idx.'''
        # image features stored as file paths are loaded and transformed here,
        # and the loaded tensor replaces the path in the cached sample
        for k, v in self.src[idx].items():
            if isinstance(v, str):
                assert os.path.exists(v)
                self.src[idx][k] = self.img_input_trans(k, v)
# impute x and y
x_imp = self.imputer_src(self.src[idx])
mask_x = self.masker_src(self.src[idx])
y_imp = self.imputer_tgt(self.tgt[idx]) if hasattr(self, 'tgt') else None
mask_y = self.masker_tgt(self.tgt[idx]) if hasattr(self, 'tgt') else None
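        # x_imp / y_imp are fully populated dicts after imputation, while mask_x / mask_y
        # are boolean dicts produced by the masker objects recording which entries are
        # masked (the exact policy depends on the Masker subclass in use)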
# replace mmap object by the loaded one
for k, v in x_imp.items():
if isinstance(v, np.memmap):
x_imp[k] = np.load(v.filename)
x_imp[k] = np.reshape(x_imp[k], v.shape)
return x_imp, y_imp, mask_x, mask_y
@cached_property
def imputer_src(self) -> Imputer:
''' imputer object '''
raise NotImplementedError
@cached_property
def imputer_tgt(self) -> Imputer:
''' imputer object '''
        raise NotImplementedError
@cached_property
def masker_src(self) -> Masker:
''' mask generator object '''
raise NotImplementedError
@cached_property
def masker_tgt(self) -> LabelMasker:
''' mask generator object '''
        raise NotImplementedError
@staticmethod
def collate_fn(
batch: list[
tuple[
dict[str, int | NDArray[np.float32]],
dict[str, int | NDArray[np.float32]],
dict[str, bool],
                dict[str, bool],
]
]
) -> tuple[
dict[str, Tensor],
dict[str, Tensor],
dict[str, Tensor],
dict[str, Tensor],
]:
        '''Collate a list of samples into dictionaries of batched tensors.'''
        # separate entries
        _x = [smp[0] for smp in batch]
        y = [smp[1] for smp in batch]
        m = [smp[2] for smp in batch]
        m_y = [smp[3] for smp in batch]
        # replace None target values with a 0 placeholder so they can be stacked
        y = [{k: v if v is not None else 0 for k, v in y[i].items()} for i in range(len(y))]
        # stack each feature / label / mask across the batch and convert to tensors
        x = {k: torch.stack([convert_to_tensor(_x[i][k]) for i in range(len(_x))]) for k in _x[0]}
        y = {k: torch.as_tensor(np.array([y[i][k] for i in range(len(y))])) for k in y[0]}
        m = {k: torch.as_tensor(np.array([m[i][k] for i in range(len(m))])) for k in m[0]}
        m_y = {k: torch.as_tensor(np.array([m_y[i][k] for i in range(len(m_y))])) for k in m_y[0]}
        return x, y, m, m_y
class TransformerTrainingDataset(TransformerDataset):
    '''Training dataset: frequency-based imputation of missing inputs and random feature dropout masking.'''
def __init__(self,
src: list[dict[str, Any]],
tgt: list[dict[str, Any]],
src_modalities: dict[str, dict[str, Any]],
tgt_modalities: dict[str, dict[str, Any]],
dropout_rate: float = .5,
dropout_strategy: str = 'permutation',
img_transform: Any | None = None,
) -> None:
''' ... '''
# call the constructor of parent class
super().__init__(src, tgt, src_modalities, tgt_modalities, img_transform=img_transform)
self.dropout_rate = dropout_rate
self.dropout_strategy = dropout_strategy
@cached_property
def imputer_src(self) -> FrequencyImputer:
''' imputer object '''
return FrequencyImputer(self.src_modalities, self.src)
@cached_property
def imputer_tgt(self) -> ConstantImputer:
''' imputer object '''
return ConstantImputer(self.tgt_modalities)
@cached_property
def masker_src(self) -> DropoutMasker:
''' mask generator object '''
return DropoutMasker(
self.src_modalities, self.src,
dropout_rate = self.dropout_rate,
dropout_strategy = self.dropout_strategy,
)
@cached_property
def masker_tgt(self) -> LabelMasker:
''' mask generator object '''
return LabelMasker(self.tgt_modalities)
class TransformerValidationDataset(TransformerDataset):
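    '''Validation dataset: constant imputation for inputs and masks that flag features missing from the raw sample.'''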
def __init__(self,
src: list[dict[str, Any]],
tgt: list[dict[str, Any]],
src_modalities: dict[str, dict[str, Any]],
tgt_modalities: dict[str, dict[str, Any]],
img_transform: Any | None = None,
is_embedding: dict[str, bool] | None = None
) -> None:
''' ... '''
# call the constructor of parent class
super().__init__(src, tgt, src_modalities, tgt_modalities, img_transform=img_transform, is_embedding=is_embedding)
@cached_property
def imputer_src(self) -> ConstantImputer:
''' imputer object '''
return ConstantImputer(self.src_modalities, self.is_embedding)
@cached_property
def imputer_tgt(self) -> ConstantImputer:
''' imputer object '''
return ConstantImputer(self.tgt_modalities)
@cached_property
def masker_src(self) -> MissingMasker:
''' mask generator object '''
return MissingMasker(self.src_modalities)
@cached_property
def masker_tgt(self) -> LabelMasker:
''' mask generator object '''
return LabelMasker(self.tgt_modalities)
class TransformerTestingDataset(TransformerValidationDataset):
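    '''Testing dataset: same behaviour as the validation dataset but without targets; __getitem__ yields only (x_imp, mask_x).'''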
def __init__(self,
src: list[dict[str, Any]],
src_modalities: dict[str, dict[str, Any]],
img_transform: Any | None = None,
is_embedding: dict[str, bool] | None = None
) -> None:
''' ... '''
# call the constructor of parent class
super().__init__(src, None, src_modalities, None, img_transform=img_transform, is_embedding=is_embedding)
def __getitem__(self,
idx: int
) -> tuple[
dict[str, int | NDArray[np.float32]],
dict[str, bool],
]:
''' ... '''
x_imp, _, mask_x, _ = super().__getitem__(idx)
return x_imp, mask_x
@staticmethod
def collate_fn(
batch: list[
tuple[
dict[str, int | NDArray[np.float32]],
dict[str, bool],
]
]
) -> tuple[
dict[str, Tensor],
dict[str, Tensor],
]:
''' ... '''
        # separate entries
x = [smp[0] for smp in batch]
m = [smp[1] for smp in batch]
# stack and convert to tensor
x = {k: torch.as_tensor(np.array([x[i][k] for i in range(len(x))])) for k in x[0]}
m = {k: torch.as_tensor(np.array([m[i][k] for i in range(len(m))])) for k in m[0]}
return x, m
class TransformerBalancedTrainingDataset(TransformerTrainingDataset):
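    '''Training dataset that re-samples cases so that, per draw, a single randomly chosen target is balanced between its two classes.'''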
def __init__(self,
src: list[dict[str, Any]],
tgt: list[dict[str, Any]],
src_modalities: dict[str, dict[str, Any]],
tgt_modalities: dict[str, dict[str, Any]],
dropout_rate: float = .5,
dropout_strategy: str = 'permutation',
img_transform: Any | None = None,
) -> None:
''' ... '''
# call the constructor of parent class
super().__init__(
src, tgt, src_modalities, tgt_modalities,
dropout_rate, dropout_strategy, img_transform,
)
# for each target/label, collect the indices of available cases
self.tgt_indices: dict[str, dict[int, list[int]]] = dict()
for tgt_k in self.tgt_modalities:
tmp = [self.tgt[i][tgt_k] for i in range(len(self.tgt))]
self.tgt_indices[tgt_k] = dict()
self.tgt_indices[tgt_k][0] = [i for i in range(len(self.tgt)) if tmp[i] == 0]
self.tgt_indices[tgt_k][1] = [i for i in range(len(self.tgt)) if tmp[i] == 1]
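        # note: __getitem__ below ignores the incoming idx and instead draws a random
        # (target, class) pair followed by a random sample index from tgt_indices,
        # which balances positive and negative cases for each target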
def __getitem__(self,
idx: int
) -> tuple[
dict[str, int | NDArray[np.float32]],
dict[str, int | NDArray[np.float32]],
dict[str, bool],
dict[str, bool],
]:
# select random target, class and index
tgt_k = random.choice(list(self.tgt_modalities.keys()))
cls = random.choice([0, 1])
idx = random.choice(self.tgt_indices[tgt_k][cls])
# call __getitem__ of super class
x_imp, y_imp, mask_x, mask_y = super().__getitem__(idx)
        # keep only the mask entry of the selected target tgt_k; all other targets are masked out
        mask_y = {k: mask_y[k] if k == tgt_k else 0 for k in self.tgt_modalities}
return x_imp, y_imp, mask_x, mask_y
class Transformer2ndOrderBalancedTrainingDataset(TransformerTrainingDataset):
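    '''Training dataset balanced over ordered pairs of targets: each draw supplies a sample that is negative for one target and positive for another.'''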
def __init__(self,
src: list[dict[str, Any]],
tgt: list[dict[str, Any]],
src_modalities: dict[str, dict[str, Any]],
tgt_modalities: dict[str, dict[str, Any]],
dropout_rate: float = .5,
dropout_strategy: str = 'permutation',
img_transform: Any | None = None,
) -> None:
""" ... """
# call the constructor of parent class
super().__init__(
src, tgt, src_modalities, tgt_modalities,
dropout_rate, dropout_strategy, img_transform,
)
# construct dictionary of paired tasks
self.tasks: dict[tuple[str, str], list[int]] = {}
tgt_keys = list(self.tgt_modalities.keys())
for tgt_k_0 in tgt_keys:
for tgt_k_1 in tgt_keys:
self.tasks[(tgt_k_0, tgt_k_1)] = []
for i, smp in enumerate(tgt):
for tgt_k_0 in tgt_keys:
for tgt_k_1 in tgt_keys:
if smp[tgt_k_0] == 0 and smp[tgt_k_1] == 1:
self.tasks[(tgt_k_0, tgt_k_1)].append(i)
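        # tasks[(k0, k1)] holds the indices of samples that are negative for target k0 and
        # positive for target k1; __getitem__ draws a random non-empty pair and then a
        # random sample index from it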
def __getitem__(self,
idx: int
) -> tuple[
dict[str, int | NDArray[np.float32]],
dict[str, int | NDArray[np.float32]],
dict[str, bool],
dict[str, bool],
]:
# select random task
while True:
tgt_k_0 = random.choice(list(self.tgt_modalities.keys()))
tgt_k_1 = random.choice(list(self.tgt_modalities.keys()))
if len(self.tasks[(tgt_k_0, tgt_k_1)]) != 0:
idx = random.choice(self.tasks[(tgt_k_0, tgt_k_1)])
break
# call __getitem__ of super class
x_imp, y_imp, mask_x, mask_y = super().__getitem__(idx)
        # keep only the mask entries of the two selected targets; all others are masked out
        mask_y = {k: mask_y[k] if k in [tgt_k_0, tgt_k_1] else 0 for k in self.tgt_modalities}
return x_imp, y_imp, mask_x, mask_y
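

# -----------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module).
# It demonstrates the batching contract of collate_fn with hand-built samples;
# the feature/label names ('age', 'NC') are hypothetical placeholders. In
# practice, samples come from the dataset classes above and collate_fn is
# passed to a torch.utils.data.DataLoader, e.g.
#     DataLoader(dat_trn, batch_size=8, shuffle=True,
#                collate_fn=TransformerTrainingDataset.collate_fn)
# -----------------------------------------------------------------------------
if __name__ == '__main__':
    fake_batch = [
        ({'age': 70.0}, {'NC': 0}, {'age': False}, {'NC': True}),
        ({'age': 82.0}, {'NC': 1}, {'age': True},  {'NC': True}),
    ]
    x, y, mask_x, mask_y = TransformerDataset.collate_fn(fake_batch)
    # each output is a dict of batched tensors, e.g. x['age'] has shape (2,)
    print(x['age'].shape, y['NC'], mask_x['age'], mask_y['NC'])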