nmed2024 / adrd /nn /unet_img_model.py
xf3227's picture
ok
6fc43ab
raw
history blame
7.47 kB
from pyexpat import features
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast
import numpy as np
import re
from icecream import ic
import math
import torch.nn.utils.weight_norm as weightNorm
# from . import UNet3DBase
from .unet_3d import UNet3DBase
def init_weights(m):
classname = m.__class__.__name__
if classname.find('Conv2d') != -1 or classname.find('ConvTranspose2d') != -1:
nn.init.kaiming_uniform_(m.weight)
nn.init.zeros_(m.bias)
elif classname.find('BatchNorm') != -1:
nn.init.normal_(m.weight, 1.0, 0.02)
nn.init.zeros_(m.bias)
elif classname.find('Linear') != -1:
nn.init.xavier_normal_(m.weight)
nn.init.zeros_(m.bias)
class feat_classifier(nn.Module):
def __init__(self, class_num, bottleneck_dim=256, type="linear"):
super(feat_classifier, self).__init__()
self.type = type
# if type in ['conv', 'gap'] and len(bottleneck_dim) > 3:
# bottleneck_dim = bottleneck_dim[-3:]
ic(bottleneck_dim)
if type == 'wn':
self.layer = weightNorm(
nn.Linear(bottleneck_dim[1:], class_num), name="weight")
# self.fc.apply(init_weights)
elif type == 'gap':
if len(bottleneck_dim) > 3:
bottleneck_dim = bottleneck_dim[-3:]
self.layer = nn.AvgPool3d(bottleneck_dim, stride=(1,1,1))
elif type == 'conv':
if len(bottleneck_dim) > 3:
bottleneck_dim = bottleneck_dim[-4:]
ic(bottleneck_dim)
self.layer = nn.Conv3d(bottleneck_dim[0], class_num, kernel_size=bottleneck_dim[1:])
ic(self.layer)
else:
print('bottleneck dim: ', bottleneck_dim)
self.layer = nn.Sequential(
torch.nn.Flatten(start_dim=1, end_dim=-1),
nn.Linear(math.prod(bottleneck_dim), class_num)
)
self.layer.apply(init_weights)
def forward(self, x):
# print('=> feat_classifier forward')
# ic(x.size())
x = self.layer(x)
# ic(x.size())
if self.type in ['gap','conv']:
x = torch.squeeze(x)
if len(x.shape) < 2:
x = torch.unsqueeze(x,0)
# print('returning x: ', x.size())
return x
class ImageModel(nn.Module):
"""
Empirical Risk Minimization (ERM)
"""
def __init__(
self,
counts=None,
classifier='gap',
accum_iter=8,
save_emb=False,
# ssl,
num_classes=1,
load_img_ckpt=False,
):
super(ImageModel, self).__init__()
if counts is not None:
if isinstance(counts[0], list):
counts = np.stack(counts, axis=0).sum(axis=0)
print('counts: ', counts)
total = np.sum(counts)
print(total/counts)
self.weight = total/torch.FloatTensor(counts)
else:
total = sum(counts)
self.weight = torch.FloatTensor([total/c for c in counts])
else:
self.weight = None
print('weight: ', self.weight)
# device = torch.device(f'cuda:{args.gpu_id}' if args.gpu_id is not None else 'cpu')
self.criterion = nn.CrossEntropyLoss(weight=self.weight)
# if ssl:
# # add contrastive loss
# # self.ssl_criterion =
# pass
self.featurizer = UNet3DBase(n_class=num_classes, attention=True, pretrained=load_img_ckpt)
self.classifier = feat_classifier(
num_classes, self.featurizer.in_features, classifier)
self.network = nn.Sequential(
self.featurizer, self.classifier)
self.accum_iter = accum_iter
self.acc_steps = 0
self.save_embedding = save_emb
def update(self, minibatches, opt, sch, scaler):
print('--------------def update----------------')
device = list(self.parameters())[0].device
all_x = torch.cat([data[1].to(device).float() for data in minibatches])
all_y = torch.cat([data[2].to(device).long() for data in minibatches])
print('all_x: ', all_x.size())
# all_p = self.predict(all_x)
# all_probs =
label_list = all_y.tolist()
count = float(len(label_list))
ic(count)
uniques = sorted(list(set(label_list)))
ic(uniques)
counts = [float(label_list.count(i)) for i in uniques]
ic(counts)
weights = [count / c for c in counts]
ic(weights)
with autocast():
loss = self.criterion(self.predict(all_x), all_y)
self.acc_steps += 1
print('class: ', loss.item())
scaler.scale(loss / self.accum_iter).backward()
if self.acc_steps == self.accum_iter:
scaler.step(opt)
if sch:
sch.step()
scaler.update()
self.zero_grad()
self.acc_steps = 0
torch.cuda.empty_cache()
del all_x
del all_y
return {'class': loss.item()}, sch
def forward(self, *args, **kwargs):
return self.network(*args, **kwargs)
def predict(self, x, stage='normal', attention=False):
# print('network device: ', list(self.network.parameters())[0].device)
# print('x device: ', x.device)
if stage == 'get_features' or self.save_embedding:
feats = self.network[0](x, attention=attention)
output = self.network[1](feats[-1] if attention else feats)
return feats, output
else:
return self.network(x)
def extract_features(self, x, attention=False):
feats = self.network[0](x, attention=attention)
return feats
def load_checkpoint(self, state_dict):
try:
self.load_checkpoint_helper(state_dict)
except:
featurizer_dict = {}
net_dict = {}
for key,val in state_dict.items():
if 'featurizer' in key:
featurizer_dict[key] = val
elif 'network' in key:
net_dict[key] = val
self.featurizer.load_state_dict(featurizer_dict)
self.classifier.load_state_dict(net_dict)
def load_checkpoint_helper(self, state_dict):
try:
self.load_state_dict(state_dict)
print('try: loaded')
except RuntimeError as e:
print('--> except')
if 'Missing key(s) in state_dict:' in str(e):
state_dict = {
key.replace('module.', '', 1): value
for key, value in state_dict.items()
}
state_dict = {
key.replace('featurizer.', '', 1).replace('classifier.','',1): value
for key, value in state_dict.items()
}
state_dict = {
re.sub('network.[0-9].', '', key): value
for key, value in state_dict.items()
}
try:
del state_dict['criterion.weight']
except:
pass
self.load_state_dict(state_dict)
print('except: loaded')