""" This file contains helper functions for building the model and for loading model parameters. These helper functions are built to mirror those in the official TensorFlow implementation. """ import re import math import collections from functools import partial import torch from torch import nn from torch.nn import functional as F from torch.utils import model_zoo ######################################################################## ############### HELPERS FUNCTIONS FOR MODEL ARCHITECTURE ############### ######################################################################## # Parameters for the entire model (stem, all blocks, and head) GlobalParams = collections.namedtuple('GlobalParams', [ 'batch_norm_momentum', 'batch_norm_epsilon', 'dropout_rate', 'num_classes', 'width_coefficient', 'depth_coefficient', 'depth_divisor', 'min_depth', 'drop_connect_rate', 'image_size']) # Parameters for an individual model block BlockArgs = collections.namedtuple('BlockArgs', [ 'kernel_size', 'num_repeat', 'input_filters', 'output_filters', 'expand_ratio', 'id_skip', 'stride', 'se_ratio']) # Change namedtuple defaults GlobalParams.__new__.__defaults__ = (None,) * len(GlobalParams._fields) BlockArgs.__new__.__defaults__ = (None,) * len(BlockArgs._fields) #nnunet pkg softmax_helper = lambda x: F.softmax(x, 1) sigmoid_helper = lambda x: F.sigmoid(x) class InitWeights_He(object): def __init__(self, neg_slope=1e-2): self.neg_slope = neg_slope def __call__(self, module): if isinstance(module, nn.Conv3d) or isinstance(module, nn.Conv2d) or isinstance(module, nn.ConvTranspose2d) or isinstance(module, nn.ConvTranspose3d): module.weight = nn.init.kaiming_normal_(module.weight, a=self.neg_slope) if module.bias is not None: module.bias = nn.init.constant_(module.bias, 0) def maybe_to_torch(d): if isinstance(d, list): d = [maybe_to_torch(i) if not isinstance(i, torch.Tensor) else i for i in d] elif not isinstance(d, torch.Tensor): d = torch.from_numpy(d).float() return d def to_cuda(data, non_blocking=True, gpu_id=0): if isinstance(data, list): data = [i.cuda(gpu_id, non_blocking=non_blocking) for i in data] else: data = data.cuda(gpu_id, non_blocking=non_blocking) return data class no_op(object): def __enter__(self): pass def __exit__(self, *args): pass class SwishImplementation(torch.autograd.Function): @staticmethod def forward(ctx, i): result = i * torch.sigmoid(i) ctx.save_for_backward(i) return result @staticmethod def backward(ctx, grad_output): i = ctx.saved_variables[0] sigmoid_i = torch.sigmoid(i) return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i))) class MemoryEfficientSwish(nn.Module): def forward(self, x): return SwishImplementation.apply(x) class Swish(nn.Module): def forward(self, x): return x * torch.sigmoid(x) def round_filters(filters, global_params): """ Calculate and round number of filters based on depth multiplier. """ multiplier = global_params.width_coefficient if not multiplier: return filters divisor = global_params.depth_divisor min_depth = global_params.min_depth filters *= multiplier min_depth = min_depth or divisor new_filters = max(min_depth, int(filters + divisor / 2) // divisor * divisor) if new_filters < 0.9 * filters: # prevent rounding by more than 10% new_filters += divisor return int(new_filters) def round_repeats(repeats, global_params): """ Round number of filters based on depth multiplier. """ multiplier = global_params.depth_coefficient if not multiplier: return repeats return int(math.ceil(multiplier * repeats)) def drop_connect(inputs, p, training): """ Drop connect. """ if not training: return inputs batch_size = inputs.shape[0] keep_prob = 1 - p random_tensor = keep_prob random_tensor += torch.rand([batch_size, 1, 1, 1], dtype=inputs.dtype, device=inputs.device) binary_tensor = torch.floor(random_tensor) output = inputs / keep_prob * binary_tensor return output def get_same_padding_conv2d(image_size=None): """ Chooses static padding if you have specified an image size, and dynamic padding otherwise. Static padding is necessary for ONNX exporting of models. """ if image_size is None: return Conv2dDynamicSamePadding else: return partial(Conv2dStaticSamePadding, image_size=image_size) def get_same_padding_conv2d_freeze(image_size=None): """ Chooses static padding if you have specified an image size, and dynamic padding otherwise. Static padding is necessary for ONNX exporting of models. """ if image_size is None: return Conv2dStaticSamePadding_freeze else: return partial(Conv2dStaticSamePadding_freeze, image_size=image_size) class Conv2dDynamicSamePadding(nn.Conv2d): """ 2D Convolutions like TensorFlow, for a dynamic image size """ def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, groups=1, bias=True): super().__init__(in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias) self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2 def forward(self, x): ih, iw = x.size()[-2:] kh, kw = self.weight.size()[-2:] sh, sw = self.stride oh, ow = math.ceil(ih / sh), math.ceil(iw / sw) pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0) pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0) if pad_h > 0 or pad_w > 0: x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]) return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) class Conv2dStaticSamePadding(nn.Conv2d): """ 2D Convolutions like TensorFlow, for a fixed image size""" def __init__(self, in_channels, out_channels, kernel_size, image_size=None, **kwargs): super().__init__(in_channels, out_channels, kernel_size, **kwargs) self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2 # Calculate padding based on image size and save it assert image_size is not None ih, iw = image_size if type(image_size) == list else [image_size, image_size] kh, kw = self.weight.size()[-2:] sh, sw = self.stride oh, ow = math.ceil(ih / sh), math.ceil(iw / sw) pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0) pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0) if pad_h > 0 or pad_w > 0: self.static_padding = nn.ZeroPad2d((pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)) else: self.static_padding = Identity() def forward(self, x): x = self.static_padding(x) x = F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) return x def Conv2dStaticSamePadding_freeze(inputs, weight, bias=None, image_size=None, stride=(1,1), padding=0, dilation=(1,1), groups=1): """ 2D Convolutions like TensorFlow, for a fixed image size""" if type(stride) == int: stride = [stride] * 2 else: stride = stride if len(stride) == 2 else [stride[0]] * 2 # Calculate padding based on image size and save it assert image_size is not None ih, iw = image_size if type(image_size) == list else [image_size, image_size] kh, kw = weight.size()[-2:] sh, sw = stride oh, ow = math.ceil(ih / sh), math.ceil(iw / sw) pad_h = max((oh - 1) * stride[0] + (kh - 1) * dilation[0] + 1 - ih, 0) pad_w = max((ow - 1) * stride[1] + (kw - 1) * dilation[1] + 1 - iw, 0) if pad_h > 0 or pad_w > 0: static_padding = nn.ZeroPad2d((pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)) else: static_padding = Identity() x = static_padding(inputs) x = F.conv2d(x, weight, bias, stride, padding, dilation, groups) return x class Identity(nn.Module): def __init__(self, ): super(Identity, self).__init__() def forward(self, input): return input ######################################################################## ############## HELPERS FUNCTIONS FOR LOADING MODEL PARAMS ############## ######################################################################## def efficientnet_params(model_name): """ Map EfficientNet model name to parameter coefficients. """ params_dict = { # Coefficients: width,depth,res,dropout 'efficientnet-b0': (1.0, 1.0, 224, 0.2), 'efficientnet-b1': (1.0, 1.1, 240, 0.2), 'efficientnet-b2': (1.1, 1.2, 260, 0.3), 'efficientnet-b3': (1.2, 1.4, 300, 0.3), 'efficientnet-b4': (1.4, 1.8, 380, 0.4), 'efficientnet-b5': (1.6, 2.2, 456, 0.4), 'efficientnet-b6': (1.8, 2.6, 528, 0.5), 'efficientnet-b7': (2.0, 3.1, 600, 0.5), } return params_dict[model_name] class BlockDecoder(object): """ Block Decoder for readability, straight from the official TensorFlow repository """ @staticmethod def _decode_block_string(block_string): """ Gets a block through a string notation of arguments. """ assert isinstance(block_string, str) ops = block_string.split('_') options = {} for op in ops: splits = re.split(r'(\d.*)', op) if len(splits) >= 2: key, value = splits[:2] options[key] = value # Check stride assert (('s' in options and len(options['s']) == 1) or (len(options['s']) == 2 and options['s'][0] == options['s'][1])) return BlockArgs( kernel_size=int(options['k']), num_repeat=int(options['r']), input_filters=int(options['i']), output_filters=int(options['o']), expand_ratio=int(options['e']), id_skip=('noskip' not in block_string), se_ratio=float(options['se']) if 'se' in options else None, stride=[int(options['s'][0])]) @staticmethod def _encode_block_string(block): """Encodes a block to a string.""" args = [ 'r%d' % block.num_repeat, 'k%d' % block.kernel_size, 's%d%d' % (block.strides[0], block.strides[1]), 'e%s' % block.expand_ratio, 'i%d' % block.input_filters, 'o%d' % block.output_filters ] if 0 < block.se_ratio <= 1: args.append('se%s' % block.se_ratio) if block.id_skip is False: args.append('noskip') return '_'.join(args) @staticmethod def decode(string_list): """ Decodes a list of string notations to specify blocks inside the network. :param string_list: a list of strings, each string is a notation of block :return: a list of BlockArgs namedtuples of block args """ assert isinstance(string_list, list) blocks_args = [] for block_string in string_list: blocks_args.append(BlockDecoder._decode_block_string(block_string)) return blocks_args @staticmethod def encode(blocks_args): """ Encodes a list of BlockArgs to a list of strings. :param blocks_args: a list of BlockArgs namedtuples of block args :return: a list of strings, each string is a notation of block """ block_strings = [] for block in blocks_args: block_strings.append(BlockDecoder._encode_block_string(block)) return block_strings def efficientnet(width_coefficient=None, depth_coefficient=None, dropout_rate=0.2, drop_connect_rate=0.2, image_size=None, num_classes=1000): """ Creates a efficientnet model. """ blocks_args = [ 'r1_k3_s11_e1_i32_o16_se0.25', 'r2_k3_s22_e6_i16_o24_se0.25', 'r2_k5_s22_e6_i24_o40_se0.25', 'r3_k3_s22_e6_i40_o80_se0.25', 'r3_k5_s11_e6_i80_o112_se0.25', 'r4_k5_s22_e6_i112_o192_se0.25', 'r1_k3_s11_e6_i192_o320_se0.25', ] blocks_args = BlockDecoder.decode(blocks_args) global_params = GlobalParams( batch_norm_momentum=0.99, batch_norm_epsilon=1e-3, dropout_rate=dropout_rate, drop_connect_rate=drop_connect_rate, # data_format='channels_last', # removed, this is always true in PyTorch num_classes=num_classes, width_coefficient=width_coefficient, depth_coefficient=depth_coefficient, depth_divisor=8, min_depth=None, image_size=image_size, ) return blocks_args, global_params def get_model_params(model_name, override_params): """ Get the block args and global params for a given model """ if model_name.startswith('efficientnet'): w, d, s, p = efficientnet_params(model_name) # note: all models have drop connect rate = 0.2 blocks_args, global_params = efficientnet( width_coefficient=w, depth_coefficient=d, dropout_rate=p, image_size=s) else: raise NotImplementedError('model name is not pre-defined: %s' % model_name) if override_params: # ValueError will be raised here if override_params has fields not included in global_params. global_params = global_params._replace(**override_params) return blocks_args, global_params url_map = { 'efficientnet-b0': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b0-355c32eb.pth', 'efficientnet-b1': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b1-f1951068.pth', 'efficientnet-b2': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b2-8bb594d6.pth', 'efficientnet-b3': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b3-5fb5a3c3.pth', 'efficientnet-b4': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b4-6ed6700e.pth', 'efficientnet-b5': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b5-b6417697.pth', 'efficientnet-b6': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b6-c76e70fd.pth', 'efficientnet-b7': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b7-dcc49843.pth', } def load_pretrained_weights(model, model_name, load_fc=True): """ Loads pretrained weights, and downloads if loading for the first time. """ state_dict = model_zoo.load_url(url_map[model_name]) if load_fc: model.load_state_dict(state_dict) else: state_dict.pop('_fc.weight') state_dict.pop('_fc.bias') res = model.load_state_dict(state_dict, strict=False) assert set(res.missing_keys) == set(['_fc.weight', '_fc.bias']), 'issue loading pretrained weights' print('Loaded pretrained weights for {}'.format(model_name)) def gram_matrix(input): a, b, c, d = input.size() # a=batch size(=1) # b=number of feature maps # (c,d)=dimensions of a f. map (N=c*d) features = input.view(a * b, c * d) # resise F_XL into \hat F_XL G = torch.mm(features, features.t()) # compute the gram product # we 'normalize' the values of the gram matrix # by dividing by the number of element in each feature maps. return G.div(a * b * c * d)