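"""Multimodal Transformer backbone for the ADRD model.

Defines ``Transformer``, which embeds heterogeneous source modalities
(categorical, numerical, imaging), fuses them with learnable per-target
auxiliary tokens in a transformer encoder, and attaches one binary
classifier head per target, along with the sinusoidal
``PositionalEncoding`` it uses.
"""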
import math
from typing import Any

import torch

from .. import nn
from .net_resnet3d import r3d_18  # unused here; kept in case it is re-exported elsewhere

# Alias for annotating tensor *instances*; the original used
# `Type[torch.Tensor]`, which denotes the class itself.
Tensor = torch.Tensor


class Transformer(torch.nn.Module):
    '''Multimodal transformer: embeds each source modality, prepends one
    learnable auxiliary token per target, runs a transformer encoder over
    the combined sequence, and classifies each target from the encoder
    output.
    '''
def __init__(self,
src_modalities: dict[str, dict[str, Any]],
tgt_modalities: dict[str, dict[str, Any]],
d_model: int,
nhead: int,
num_encoder_layers: int = 1,
num_decoder_layers: int = 1,
device: str = 'cpu',
        cuda_devices: list | None = None,  # unused in this module; the mutable default ([3]) was dropped
img_net: str | None = None,
layers: int = 3,
img_size: int | None = 128,
patch_size: int | None = 16,
imgnet_ckpt: str | None = None,
train_imgnet: bool = False,
fusion_stage: str = 'middle',
) -> None:
        '''
        src_modalities / tgt_modalities describe each input and output
        modality (see the illustrative schema below). d_model and nhead
        configure the shared transformer encoder; the img_* arguments
        configure the imaging backbone.
        '''
super().__init__()
self.d_model = d_model
self.nhead = nhead
self.num_encoder_layers = num_encoder_layers
self.num_decoder_layers = num_decoder_layers
self.img_net = img_net
self.img_size = img_size
self.patch_size = patch_size
self.imgnet_ckpt = imgnet_ckpt
self.train_imgnet = train_imgnet
self.layers = layers
self.src_modalities = src_modalities
self.tgt_modalities = tgt_modalities
self.device = device
self.fusion_stage = fusion_stage
# embedding modules for source
self.modules_emb_src = torch.nn.ModuleDict()
print('Downsample layers: ', self.layers)
        self.img_model = nn.ImagingModelWrapper(
            arch=self.img_net,
            img_size=self.img_size,
            patch_size=self.patch_size,
            ckpt_path=self.imgnet_ckpt,
            train_backbone=self.train_imgnet,
            layers=self.layers,
            out_dim=self.d_model,
            device=self.device,
            fusion_stage=self.fusion_stage,
        )
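        # Illustrative schema for src_modalities / tgt_modalities; the field
        # names below are hypothetical examples, not the authors' actual keys:
        #   src_modalities = {
        #       'age':        {'type': 'numerical',   'shape': [1]},
        #       'sex':        {'type': 'categorical', 'num_categories': 2},
        #       'img_MRI_T1': {'type': 'imaging', 'shape': ..., 'img_shape': ...},
        #   }
        #   tgt_modalities = {'label': {'type': 'categorical', 'num_categories': 2}}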
for k, info in src_modalities.items():
if info['type'] == 'categorical':
self.modules_emb_src[k] = torch.nn.Embedding(info['num_categories'], d_model)
elif info['type'] == 'numerical':
self.modules_emb_src[k] = torch.nn.Sequential(
torch.nn.BatchNorm1d(info['shape'][0]),
torch.nn.Linear(info['shape'][0], d_model)
)
elif info['type'] == 'imaging':
if self.img_net:
self.modules_emb_src[k] = self.img_model
else:
# unrecognized
raise ValueError('{} is an unrecognized data modality'.format(k))
# positional encoding
self.pe = PositionalEncoding(d_model)
# auxiliary embedding vectors for targets
self.emb_aux = torch.nn.Parameter(
torch.zeros(len(tgt_modalities), 1, d_model),
requires_grad = True,
)
# transformer
enc = torch.nn.TransformerEncoderLayer(
self.d_model, self.nhead,
dim_feedforward = self.d_model,
activation = 'gelu',
dropout = 0.3,
)
self.transformer = torch.nn.TransformerEncoder(enc, self.num_encoder_layers)
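        # note: dim_feedforward above equals d_model, much smaller than the
        # common 4 * d_model; num_decoder_layers is stored but no decoder is
        # built, since the auxiliary target tokens ride through the encoder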
# classifiers (binary only)
self.modules_cls = torch.nn.ModuleDict()
for k, info in tgt_modalities.items():
if info['type'] == 'categorical' and info['num_categories'] == 2:
self.modules_cls[k] = torch.nn.Linear(d_model, 1)
            else:
                # only binary categorical targets are supported
                raise ValueError('{} is an unsupported target modality; only binary categorical targets are handled'.format(k))
    def forward(self,
        x: dict[str, Tensor],
        mask: dict[str, Tensor],
        skip_embedding: dict[str, bool] | None = None,
        return_out_emb: bool = False,
    ) -> dict[str, Tensor] | tuple[dict[str, Tensor], dict[str, Tensor]]:
        """Embed the inputs, fuse them in the transformer, and classify.

        Returns a dict mapping each target name to logits of shape (N,);
        if return_out_emb is True, returns (source embeddings, logits).
        """
        out_emb = self.forward_emb(x, mask, skip_embedding)
        if self.fusion_stage == "late":
            # split the imaging embeddings out *before* filtering them away;
            # the original filtered first, leaving img_out_emb always empty
            img_out_emb = {k: v for k, v in out_emb.items() if "img_MRI" in k}
            out_emb = {k: v for k, v in out_emb.items() if "img_MRI" not in k}
            mask_nonimg = {k: v for k, v in mask.items() if "img_MRI" not in k}
            out_trf = self.forward_trf(out_emb, mask_nonimg)
            # the original branch ended with a bare `torch.concatenate()`,
            # which cannot run; how out_trf was meant to be combined with
            # img_out_emb is unclear, so the unfinished state is surfaced
            # explicitly here rather than guessed at
            raise NotImplementedError('late fusion of transformer output and imaging embeddings is unfinished')
        else:
            out_trf = self.forward_trf(out_emb, mask)
out_cls = self.forward_cls(out_trf)
if return_out_emb:
return out_emb, out_cls
return out_cls
    def forward_emb(self,
        x: dict[str, Tensor],
        mask: dict[str, Tensor],
        skip_embedding: dict[str, bool] | None = None,
    ) -> dict[str, Tensor]:
        """Embed each source modality; an imaging modality that is missing
        for the whole batch is replaced by zeros instead of being run
        through the image model."""
out_emb = dict()
for k in self.modules_emb_src.keys():
if skip_embedding is None or k not in skip_embedding or not skip_embedding[k]:
if "img_MRI" in k:
# print("img_MRI in ", k)
if torch.all(mask[k]):
if "swinunetr" in self.img_net.lower() and self.fusion_stage == 'late':
                            # hard-coded SwinUNETR late-fusion feature-map shape
                            out_emb[k] = torch.zeros((1, 768, 4, 4, 4))
                        else:
                            # modality absent for the whole batch: zero embedding
                            if 'cuda' in self.device:
                                device = x[k].device
                            else:
                                device = self.device
                            out_emb[k] = torch.zeros((mask[k].shape[0], self.d_model)).to(device, non_blocking=True)
                    else:
                        out_emb[k] = self.modules_emb_src[k](x[k])
                else:
                    out_emb[k] = self.modules_emb_src[k](x[k])
else:
out_emb[k] = x[k]
return out_emb
    def forward_trf(self,
        out_emb: dict[str, Tensor],
        mask: dict[str, Tensor],
    ) -> Tensor:
        """Prepend the auxiliary target tokens to the stacked source
        embeddings, run the transformer encoder, and return its output at
        the first auxiliary-token position, shape (N, d_model)."""
N = len(next(iter(out_emb.values()))) # batch size
S = len(self.modules_emb_src) # number of sources
T = len(self.modules_cls) # number of targets
if self.fusion_stage == 'late':
src_iter = [k for k in self.modules_emb_src.keys() if "img_MRI" not in k]
S = len(src_iter) # number of sources
else:
src_iter = self.modules_emb_src.keys()
        emb_src = torch.stack([o for o in out_emb.values()], dim=0)
        # reset the positional-encoding cursor, then add PE to the sources
        self.pe.index = -1
        emb_src = self.pe(emb_src)
        # one learnable auxiliary token per target, broadcast over the batch
        emb_tgt = self.emb_aux.repeat(1, N, 1)
        # concatenate target embeddings (first) and source embeddings
emb_all = torch.concatenate((emb_tgt, emb_src), dim=0)
# combine masks
mask_src = [mask[k] for k in src_iter]
mask_src = torch.stack(mask_src, dim=1)
# target masks
mask_tgt = torch.zeros((N, T), dtype=torch.bool, device=self.emb_aux.device)
# concatenate source masks and target masks
mask_all = torch.concatenate((mask_tgt, mask_src), dim=1)
# repeat mask_all to fit transformer
mask_all = mask_all.unsqueeze(1).expand(-1, S + T, -1).repeat(self.nhead, 1, 1)
        # run the transformer and keep only the first auxiliary-token
        # position, which all classifier heads below share
        out_trf = self.transformer(
            src = emb_all,
            mask = mask_all,
        )[0]
        return out_trf
    def forward_cls(self,
        out_trf: Tensor,
    ) -> dict[str, Tensor]:
        """Apply each binary head to the fused embedding; returns logits of
        shape (N,) per target."""
        tgt_iter = self.modules_cls.keys()
        out_cls = {k: self.modules_cls[k](out_trf).squeeze(1) for k in tgt_iter}
        return out_cls


class PositionalEncoding(torch.nn.Module):
    def __init__(self,
        d_model: int,
        max_len: int = 512
    ):
        """Sinusoidal positional encoding with a stateful cursor (`index`)
        that callers reset before each pass (see forward_trf above)."""
super().__init__()
position = torch.arange(max_len).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
pe = torch.zeros(max_len, 1, d_model)
pe[:, 0, 0::2] = torch.sin(position * div_term)
pe[:, 0, 1::2] = torch.cos(position * div_term)
self.register_buffer('pe', pe)
self.index = -1
    def forward(self, x: Tensor, pe_type: str = 'non_img') -> Tensor:
        """
        Arguments:
            x: Tensor, shape ``[seq_len, batch_size, embedding_dim]``
        """
        self.index += 1
        if pe_type == 'img':
            # one positional slot per call, shared across the whole input
            return x + self.pe[self.index]
        else:
            # one slot per sequence position, offset by the cursor
            return x + self.pe[self.index:x.size(0) + self.index]
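    # Usage sketch: forward_trf() above resets `index` to -1 and calls this
    # once, so a source stack of shape (S, N, d_model) receives pe[0:S];
    # the 'img' branch instead advances one slot per call.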


if __name__ == '__main__':
    ''' for testing purposes only '''
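    # Minimal smoke test (our sketch, not the authors' test). Assumptions:
    # run as a package module (e.g. `python -m adrd.nn.transformer`) because
    # of the relative imports above, and ImagingModelWrapper tolerates
    # arch=None when no imaging modality is present.
    src = {
        'age': {'type': 'numerical', 'shape': [1]},           # hypothetical field
        'sex': {'type': 'categorical', 'num_categories': 2},  # hypothetical field
    }
    tgt = {'label': {'type': 'categorical', 'num_categories': 2}}
    model = Transformer(src, tgt, d_model=32, nhead=4)
    model.eval()  # keep BatchNorm1d on running stats for the tiny batch
    x = {'age': torch.randn(8, 1), 'sex': torch.randint(0, 2, (8,))}
    mask = {k: torch.zeros(8, dtype=torch.bool) for k in src}  # False = observed
    with torch.no_grad():
        out = model(x, mask)
    print({k: tuple(v.shape) for k, v in out.items()})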