import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast
import numpy as np
import re
from icecream import ic
import math
from torch.nn.utils import weight_norm as weightNorm
# from . import UNet3DBase
from .unet_3d import UNet3DBase
def init_weights(m):
    classname = m.__class__.__name__
    # Match 'Conv' generically: the classifier below builds an nn.Conv3d head,
    # which the original 'Conv2d'-only check silently skipped.
    if classname.find('Conv') != -1:
        nn.init.kaiming_uniform_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight, 1.0, 0.02)
        nn.init.zeros_(m.bias)
    elif classname.find('Linear') != -1:
        nn.init.xavier_normal_(m.weight)
        nn.init.zeros_(m.bias)
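# Usage: applied recursively via `module.apply(init_weights)`, as done on the
# classifier head below.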
class feat_classifier(nn.Module):
def __init__(self, class_num, bottleneck_dim=256, type="linear"):
super(feat_classifier, self).__init__()
self.type = type
# if type in ['conv', 'gap'] and len(bottleneck_dim) > 3:
# bottleneck_dim = bottleneck_dim[-3:]
ic(bottleneck_dim)
        if type == 'wn':
            # Assumed fix: bottleneck_dim is a feature-map shape, so flatten and
            # pass the flattened size to the weight-normalised linear head (the
            # original passed a slice of the shape itself, which nn.Linear rejects).
            self.layer = nn.Sequential(
                torch.nn.Flatten(start_dim=1, end_dim=-1),
                weightNorm(nn.Linear(math.prod(bottleneck_dim), class_num), name="weight"),
            )
elif type == 'gap':
if len(bottleneck_dim) > 3:
bottleneck_dim = bottleneck_dim[-3:]
self.layer = nn.AvgPool3d(bottleneck_dim, stride=(1,1,1))
elif type == 'conv':
if len(bottleneck_dim) > 3:
bottleneck_dim = bottleneck_dim[-4:]
ic(bottleneck_dim)
self.layer = nn.Conv3d(bottleneck_dim[0], class_num, kernel_size=bottleneck_dim[1:])
ic(self.layer)
else:
print('bottleneck dim: ', bottleneck_dim)
self.layer = nn.Sequential(
torch.nn.Flatten(start_dim=1, end_dim=-1),
nn.Linear(math.prod(bottleneck_dim), class_num)
)
self.layer.apply(init_weights)
    def forward(self, x):
        x = self.layer(x)
        if self.type in ['gap', 'conv']:
            # Pooling/conv heads emit (B, C, 1, 1, 1): drop the singleton
            # spatial dims, then restore the batch dim if B == 1.
            x = torch.squeeze(x)
            if len(x.shape) < 2:
                x = torch.unsqueeze(x, 0)
        return x
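# Example (hypothetical shapes): a featurizer emitting (B, 256, 4, 4, 4) feature
# maps pairs with a 'conv' head as
#
#   clf = feat_classifier(class_num=2, bottleneck_dim=(256, 4, 4, 4), type='conv')
#   logits = clf(torch.randn(8, 256, 4, 4, 4))  # -> shape (8, 2)
#
# Note the 'gap' head only pools, so its output width equals the input channel
# count rather than class_num.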
class ImageModel(nn.Module):
"""
Empirical Risk Minimization (ERM)
"""
def __init__(
self,
counts=None,
classifier='gap',
accum_iter=8,
save_emb=False,
# ssl,
num_classes=1,
load_img_ckpt=False,
):
super(ImageModel, self).__init__()
if counts is not None:
if isinstance(counts[0], list):
counts = np.stack(counts, axis=0).sum(axis=0)
print('counts: ', counts)
total = np.sum(counts)
print(total/counts)
self.weight = total/torch.FloatTensor(counts)
else:
total = sum(counts)
self.weight = torch.FloatTensor([total/c for c in counts])
else:
self.weight = None
print('weight: ', self.weight)
# device = torch.device(f'cuda:{args.gpu_id}' if args.gpu_id is not None else 'cpu')
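        # Inverse-frequency class weighting: rarer classes contribute more to
        # the cross-entropy loss.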
self.criterion = nn.CrossEntropyLoss(weight=self.weight)
# if ssl:
# # add contrastive loss
# # self.ssl_criterion =
# pass
self.featurizer = UNet3DBase(n_class=num_classes, attention=True, pretrained=load_img_ckpt)
self.classifier = feat_classifier(
num_classes, self.featurizer.in_features, classifier)
self.network = nn.Sequential(
self.featurizer, self.classifier)
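        # Gradient accumulation: update() takes an optimizer step only every
        # `accum_iter` calls, emulating an effectively larger batch.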
self.accum_iter = accum_iter
self.acc_steps = 0
self.save_embedding = save_emb
    def update(self, minibatches, opt, sch, scaler):
        device = list(self.parameters())[0].device
        all_x = torch.cat([data[1].to(device).float() for data in minibatches])
        all_y = torch.cat([data[2].to(device).long() for data in minibatches])
        # Diagnostics only: log the label distribution of this step's batch.
        label_list = all_y.tolist()
        count = float(len(label_list))
        uniques = sorted(set(label_list))
        counts = [float(label_list.count(i)) for i in uniques]
        weights = [count / c for c in counts]
        ic(count, uniques, counts, weights)
        # Forward pass under mixed precision; backward() is kept outside the
        # autocast context, as recommended for AMP.
        with autocast():
            loss = self.criterion(self.predict(all_x), all_y)
        self.acc_steps += 1
        # Divide the loss so gradients accumulated over `accum_iter`
        # micro-batches average to one effective batch.
        scaler.scale(loss / self.accum_iter).backward()
        if self.acc_steps == self.accum_iter:
            scaler.step(opt)
            if sch:
                sch.step()
            scaler.update()
            self.zero_grad()
            self.acc_steps = 0
            torch.cuda.empty_cache()
        del all_x, all_y
        return {'class': loss.item()}, sch
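    # Sketch of the expected training-loop contract (hypothetical loader/optimizer
    # names; each minibatch element is an (id, x, y)-style tuple):
    #
    #   scaler = torch.cuda.amp.GradScaler()
    #   for minibatches in loader:
    #       metrics, sch = model.update(minibatches, opt, sch, scaler)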
def forward(self, *args, **kwargs):
return self.network(*args, **kwargs)
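    # predict() has two modes: the default returns logits only, while
    # stage='get_features' (or save_emb=True) also returns the featurizer
    # outputs for embedding extraction.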
def predict(self, x, stage='normal', attention=False):
# print('network device: ', list(self.network.parameters())[0].device)
# print('x device: ', x.device)
if stage == 'get_features' or self.save_embedding:
feats = self.network[0](x, attention=attention)
output = self.network[1](feats[-1] if attention else feats)
return feats, output
else:
return self.network(x)
def extract_features(self, x, attention=False):
feats = self.network[0](x, attention=attention)
return feats
    def load_checkpoint(self, state_dict):
        try:
            self.load_checkpoint_helper(state_dict)
        except Exception:
            # Fall back to loading the featurizer and classifier separately.
            featurizer_dict = {}
            net_dict = {}
            for key, val in state_dict.items():
                if 'featurizer' in key:
                    featurizer_dict[key] = val
                elif 'network' in key:
                    net_dict[key] = val
            self.featurizer.load_state_dict(featurizer_dict)
            self.classifier.load_state_dict(net_dict)
    def load_checkpoint_helper(self, state_dict):
        try:
            self.load_state_dict(state_dict)
            print('try: loaded')
        except RuntimeError as e:
            print('--> except')
            if 'Missing key(s) in state_dict:' in str(e):
                # Strip the wrapper prefixes left by DataParallel ('module.')
                # and by saving the full ImageModel rather than its submodules.
                state_dict = {
                    key.replace('module.', '', 1): value
                    for key, value in state_dict.items()
                }
                state_dict = {
                    key.replace('featurizer.', '', 1).replace('classifier.', '', 1): value
                    for key, value in state_dict.items()
                }
                # Escape the dots and allow multi-digit indices (the original
                # pattern 'network.[0-9].' treated '.' as a wildcard).
                state_dict = {
                    re.sub(r'network\.[0-9]+\.', '', key): value
                    for key, value in state_dict.items()
                }
                # The loss weight is a buffer, not a model parameter; drop it.
                state_dict.pop('criterion.weight', None)
                self.load_state_dict(state_dict)
                print('except: loaded')
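    # Typical restore (hypothetical checkpoint layout):
    #   ckpt = torch.load('model.pt', map_location='cpu')
    #   model.load_checkpoint(ckpt.get('model_dict', ckpt))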