# adrd/nn/img_model_wrapper.py
import torch
from monai.networks.nets.swin_unetr import SwinUNETR

from .. import nn
from .. import model

class ImagingModelWrapper(torch.nn.Module):
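    """Wraps a 3D imaging backbone (SwinUNETR, a ViT autoencoder, or a
    DenseNet) behind a common interface and projects its features to a
    fixed-size embedding for multimodal fusion. Arch names containing
    "emb" skip the backbone and expect precomputed embeddings as input.
    """
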
def __init__(
self,
arch: str = 'ViTAutoEnc',
        tgt_modalities: dict | None = None,
img_size: int | None = 128,
patch_size: int | None = 16,
ckpt_path: str | None = None,
train_backbone: bool = False,
out_dim: int = 128,
layers: int | None = 1,
device: str = 'cpu',
fusion_stage: str = 'middle',
):
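        """
        :param arch: backbone name; names containing "emb" expect
            precomputed embeddings instead of raw volumes
        :param img_size: side length of the cubic input volume
        :param patch_size: ViT patch size
        :param ckpt_path: optional path to pretrained backbone weights
        :param train_backbone: if False, backbone parameters are frozen
        :param out_dim: dimensionality of the output embedding
        :param layers: number of Conv1d layers in the projection head
        :param fusion_stage: 'middle' builds a projection head; 'late'
            uses an identity mapping
        """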
        super().__init__()
self.arch = arch
self.tgt_modalities = tgt_modalities
self.img_size = img_size
self.patch_size = patch_size
self.train_backbone = train_backbone
self.ckpt_path = ckpt_path
self.device = device
self.out_dim = out_dim
self.layers = layers
self.fusion_stage = fusion_stage
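
        # Build the backbone according to `arch`; the "emb" variants only
        # build the projection head, since their inputs are precomputed
        # embeddings rather than raw volumes.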
if "swinunetr" in self.arch.lower():
if "emb" not in self.arch.lower():
                # Fall back to the project's pretrained Swin-ViT weights when
                # no checkpoint path is given explicitly.
                ckpt_path = self.ckpt_path or '/projectnb/ivc-ml/dlteif/pretrained_models/model_swinvit.pt'
                ckpt = torch.load(ckpt_path, map_location='cpu')
                self.img_model = SwinUNETR(
                    in_channels=1,
                    out_channels=1,
                    img_size=self.img_size,
                    feature_size=48,
                    use_checkpoint=True,
                )
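                # SwinUNETR.load_from expects the pretrained encoder weights
                # under a "module." prefix, so remap the checkpoint's
                # "swinViT." keys before loading.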
ckpt["state_dict"] = {k.replace("swinViT.", "module."): v for k, v in ckpt["state_dict"].items()}
self.img_model.load_from(ckpt)
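            # The deepest Swin encoder stage outputs 16 * feature_size = 768
            # channels, which is the token dimension fed to the projection head.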
self.dim = 768
elif "vit" in self.arch.lower():
if "emb" not in self.arch.lower():
                # Initialize the ViT backbone from the local nn package.
                self.img_model = nn.__dict__[self.arch](
                    in_channels=1,
                    img_size=self.img_size,
                    patch_size=self.patch_size,
                )
if self.ckpt_path:
self.img_model.load(self.ckpt_path, map_location=self.device)
self.dim = self.img_model.hidden_size
else:
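                # Precomputed ViT embeddings: assume the standard 768-dim
                # hidden size.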
self.dim = 768
if "vit" in self.arch.lower() or "swinunetr" in self.arch.lower():
dim = self.dim
if self.fusion_stage == 'middle':
downsample = torch.nn.ModuleList()
# print('Number of layers: ', self.layers)
for i in range(self.layers):
if i == self.layers - 1:
dim_out = self.out_dim
# print(layers)
ks = 2
stride = 2
else:
dim_out = dim // 2
ks = 2
stride = 2
downsample.append(
torch.nn.Conv1d(in_channels=dim, out_channels=dim_out, kernel_size=ks, stride=stride)
)
dim = dim_out
downsample.append(
torch.nn.BatchNorm1d(dim)
)
downsample.append(
torch.nn.ReLU()
)
self.downsample = torch.nn.Sequential(*downsample)
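                # Shape sketch with the defaults (dim=768, layers=1,
                # out_dim=128): tokens (B, 768, N) -> Conv1d(768, 128, k=2,
                # s=2) -> (B, 128, N // 2); forward() later averages over the
                # last axis to yield a (B, 128) embedding.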
elif self.fusion_stage == 'late':
self.downsample = torch.nn.Identity()
            else:
                # Any other fusion stage leaves self.downsample undefined.
                pass
elif "densenet" in self.arch.lower():
if "emb" not in self.arch.lower():
self.img_model = model.ImagingModel.from_ckpt(self.ckpt_path, device=self.device, img_backend=self.arch, load_from_ckpt=True).net_
self.downsample = torch.nn.Linear(3900, self.out_dim)
# randomly initialize weights for downsample block
for p in self.downsample.parameters():
if p.dim() > 1:
torch.nn.init.xavier_uniform_(p)
p.requires_grad = True
if "emb" not in self.arch.lower():
# freeze imaging model parameters
if "densenet" in self.arch.lower():
for n, p in self.img_model.features.named_parameters():
if not self.train_backbone:
p.requires_grad = False
else:
p.requires_grad = True
for n, p in self.img_model.tgt.named_parameters():
p.requires_grad = False
else:
for n, p in self.img_model.named_parameters():
# print(n, p.requires_grad)
if not self.train_backbone:
p.requires_grad = False
else:
p.requires_grad = True
    def forward(self, x):
        if "emb" not in self.arch.lower():
            # Raw images: run the backbone, then project its features.
            if "swinunetr" in self.arch.lower():
                out = self.img_model(x)
                out = self.downsample(out)
                # Average over the token axis for a fixed-size embedding.
                out = torch.mean(out, dim=-1)
            elif "vit" in self.arch.lower():
                # return_emb=True asks the ViT for token embeddings rather
                # than its reconstruction output.
                out = self.img_model(x, return_emb=True)
                out = self.downsample(out)
                out = torch.mean(out, dim=-1)
            elif "densenet" in self.arch.lower():
                out = self.img_model.features(x)
                out = torch.flatten(out, 1)
                out = self.downsample(out)
        else:
            # Precomputed embeddings: project them with the downsample head.
            if "swinunetr" in self.arch.lower():
                x = torch.squeeze(x, dim=1)
                x = x.view(x.size(0), self.dim, -1)
            out = self.downsample(x)

            if self.fusion_stage == 'middle':
                if "vit" in self.arch.lower() or "swinunetr" in self.arch.lower():
                    out = torch.mean(out, dim=-1)
                else:
                    out = torch.squeeze(out, dim=1)
            elif self.fusion_stage == 'late':
                pass

        return out
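

# A minimal usage sketch (illustrative only; the shapes and the choice of
# backbone are assumptions, not guarantees from this module):
#
#     wrapper = ImagingModelWrapper(
#         arch='ViTAutoEnc',
#         img_size=128,
#         patch_size=16,
#         out_dim=128,
#         fusion_stage='middle',
#     )
#     vol = torch.randn(2, 1, 128, 128, 128)  # (batch, channel, D, H, W)
#     emb = wrapper(vol)  # expected: (2, 128) embedding with middle fusion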