import torch
from torch import nn
from torch.nn import functional as F

__version__ = "0.5.1"

from .utils import (BlockArgs, BlockDecoder, GlobalParams,
                    MemoryEfficientSwish, Swish, drop_connect, efficientnet,
                    efficientnet_params, get_model_params,
                    get_same_padding_conv2d, gram_matrix,
                    load_pretrained_weights, round_filters, round_repeats)


class MBConvBlock(nn.Module):
    """
    Mobile Inverted Residual Bottleneck Block

    Args:
        block_args (namedtuple): BlockArgs (imported from .utils)
        global_params (namedtuple): GlobalParams (imported from .utils)

    Attributes:
        has_se (bool): Whether the block contains a Squeeze and Excitation layer.
        id_skip (bool): Copied from ``block_args.id_skip``; gates the residual
            connection (and drop connect) in ``forward``.
    """

    def __init__(self, block_args, global_params):
        super().__init__()
        self._block_args = block_args
        # The params store "1 - momentum" relative to what nn.BatchNorm2d expects.
        self._bn_mom = 1 - global_params.batch_norm_momentum
        self._bn_eps = global_params.batch_norm_epsilon
        self.has_se = (self._block_args.se_ratio is not None) and (0 < self._block_args.se_ratio <= 1)
        self.id_skip = block_args.id_skip  # skip connection and drop connect

        # Get static or dynamic convolution depending on image size
        Conv2d = get_same_padding_conv2d(image_size=global_params.image_size)

        # Expansion phase (skipped entirely when expand_ratio == 1)
        inp = self._block_args.input_filters  # number of input channels
        oup = self._block_args.input_filters * self._block_args.expand_ratio  # number of output channels
        if self._block_args.expand_ratio != 1:
            self._expand_conv = Conv2d(in_channels=inp, out_channels=oup, kernel_size=1, bias=False)
            self._bn0 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)

        # Depthwise convolution phase
        k = self._block_args.kernel_size
        s = self._block_args.stride
        self._depthwise_conv = Conv2d(
            in_channels=oup, out_channels=oup, groups=oup,  # groups makes it depthwise
            kernel_size=k, stride=s, bias=False)
        self._bn1 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)

        # Squeeze and Excitation layer, if desired
        if self.has_se:
            # The squeeze width is derived from the block's *input* filters,
            # not from the expanded width.
            num_squeezed_channels = max(1, int(self._block_args.input_filters * self._block_args.se_ratio))
            self._se_reduce = Conv2d(in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1)
            self._se_expand = Conv2d(in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1)

        # Output phase
        final_oup = self._block_args.output_filters
        self._project_conv = Conv2d(in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False)
        self._bn2 = nn.BatchNorm2d(num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps)
        self._swish = MemoryEfficientSwish()

    def forward(self, inputs, drop_connect_rate=None):
        """
        :param inputs: input tensor
        :param drop_connect_rate: drop connect rate (float, between 0 and 1)
        :return: output of block
        """

        # Expansion and Depthwise Convolution
        x = inputs
        if self._block_args.expand_ratio != 1:
            x = self._swish(self._bn0(self._expand_conv(inputs)))
        x = self._swish(self._bn1(self._depthwise_conv(x)))

        # Squeeze and Excitation: pool to 1x1, reduce, expand, then gate x
        if self.has_se:
            x_squeezed = F.adaptive_avg_pool2d(x, 1)
            x_squeezed = self._se_expand(self._swish(self._se_reduce(x_squeezed)))
            x = torch.sigmoid(x_squeezed) * x

        x = self._bn2(self._project_conv(x))

        # Skip connection and drop connect; only applied when the block keeps
        # both the stride (== 1) and the channel count unchanged.
        input_filters, output_filters = self._block_args.input_filters, self._block_args.output_filters
        if self.id_skip and self._block_args.stride == 1 and input_filters == output_filters:
            if drop_connect_rate:
                x = drop_connect(x, p=drop_connect_rate, training=self.training)
            x = x + inputs  # skip connection
        return x

    def set_swish(self, memory_efficient=True):
        """Sets swish function as memory efficient (for training) or standard (for export)"""
        self._swish = MemoryEfficientSwish() if memory_efficient else Swish()


class EfficientNet(nn.Module):
    """
    An EfficientNet backbone with a selectable output head.

    Most easily loaded with the .from_name or .from_pretrained methods.

    Args:
        type (str): Selects the head used by ``forward``. The values handled
            there are 'map', 'big_map', 'img', 'deconv_map', 'deconv_img'
            and 'feature'.
        blocks_args (list): A list of BlockArgs to construct blocks
        global_params (namedtuple): A set of GlobalParams shared between blocks

    Example:
        model = EfficientNet.from_name('efficientnet-b0', 'map')
    """

    def __init__(self, type, blocks_args=None, global_params=None):
        super().__init__()
        assert isinstance(blocks_args, list), 'blocks_args should be a list'
        assert len(blocks_args) > 0, 'block args must be greater than 0'
        self._global_params = global_params
        self._blocks_args = blocks_args
        self.type = type

        # Get static or dynamic convolution depending on image size
        Conv2d = get_same_padding_conv2d(image_size=global_params.image_size)

        # Batch norm parameters
        bn_mom = 1 - self._global_params.batch_norm_momentum
        bn_eps = self._global_params.batch_norm_epsilon

        # Stem
        # forward() feeds the stem the channel-wise concatenation of
        # (label, natural, seg), so this is not a plain RGB input.
        # NOTE(review): presumably 1 (label) + 3 (natural) + 1 (seg) channels - confirm.
        in_channels = 5
        out_channels = round_filters(32, self._global_params)  # number of output channels
        self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
        self._bn0 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)

        # Build blocks
        self._blocks = nn.ModuleList([])
        for block_args in self._blocks_args:

            # Update block input and output filters based on depth multiplier.
            block_args = block_args._replace(
                input_filters=round_filters(block_args.input_filters, self._global_params),
                output_filters=round_filters(block_args.output_filters, self._global_params),
                num_repeat=round_repeats(block_args.num_repeat, self._global_params)
            )

            # The first block needs to take care of stride and filter size increase.
            self._blocks.append(MBConvBlock(block_args, self._global_params))
            if block_args.num_repeat > 1:
                # Repeats operate on the first block's output: same width, stride 1.
                block_args = block_args._replace(input_filters=block_args.output_filters, stride=1)
            for _ in range(block_args.num_repeat - 1):
                self._blocks.append(MBConvBlock(block_args, self._global_params))

        # Head
        in_channels = block_args.output_filters  # output of final block
        out_channels = round_filters(1280, self._global_params)
        self._conv_head = Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        self._bn1 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)

        # Final linear layer
        # (_avg_pooling, _dropout and _fc are registered but not used by forward().)
        self._avg_pooling = nn.AdaptiveAvgPool2d(1)
        self._dropout = nn.Dropout(self._global_params.dropout_rate)
        self._fc = nn.Linear(out_channels, 1)
        self._swish = MemoryEfficientSwish()

        # Task-specific heads.
        # NOTE(review): every head below hard-codes 1792 input channels, while
        # the backbone emits round_filters(1280, global_params) channels. The
        # heads therefore only match backbones whose head width is 1792 -
        # confirm which model variants are intended.
        self.conv_reg = nn.Conv2d(1792, 1, 1)
        if self.type == 'big_map' or self.type == 'img':
            self.conv_transe1 = nn.Conv2d(1792, 448, 1)
            self.bn_transe1 = nn.BatchNorm2d(num_features=448, momentum=bn_mom, eps=bn_eps)
            self.conv_transe2 = nn.Conv2d(448, 112, 1)
            self.bn_transe2 = nn.BatchNorm2d(num_features=112, momentum=bn_mom, eps=bn_eps)
            if self.type == 'big_map':
                self.conv_transe_mask = nn.Conv2d(112, 1, 1)
                self.deconv_big = nn.ConvTranspose2d(1792, 1, 5, stride=4)  # transpose
            if self.type == 'img':
                self.conv_transe3 = nn.Conv2d(112, 3, 1)
                self.deconv_img = nn.ConvTranspose2d(1792, 3, 5, stride=4)  # transpose
        elif self.type == 'deconv_map' or self.type == 'deconv_img':
            self.conv_big_reg = nn.ConvTranspose2d(1792, 1, 5, stride=4)  # transpose
            self.conv_img = nn.ConvTranspose2d(1792, 3, 5, stride=4)  # transpose
        else:
            # NOTE(review): conv_reg was already created above, so this
            # re-creation is redundant. It is kept because dropping it would
            # change the sequence of random draws (and hence the initial
            # weights) under a fixed seed.
            self.conv_reg = nn.Conv2d(1792, 1, 1)
        self.relu = nn.ReLU()
        self.up_double = nn.Upsample(scale_factor=2, mode='bilinear')
        self.sig = nn.Sigmoid()

    def set_swish(self, memory_efficient=True):
        """Sets swish function as memory efficient (for training) or standard (for export)"""
        self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
        for block in self._blocks:
            block.set_swish(memory_efficient)

    def extract_features(self, inputs):
        """ Returns output of the final convolution layer """

        # Stem
        x = self._swish(self._bn0(self._conv_stem(inputs)))

        # Blocks
        for idx, block in enumerate(self._blocks):
            drop_connect_rate = self._global_params.drop_connect_rate
            if drop_connect_rate:
                # Scale the rate linearly with block depth (0 for the first block).
                drop_connect_rate *= float(idx) / len(self._blocks)
            x = block(x, drop_connect_rate=drop_connect_rate)

        # Head
        x = self._swish(self._bn1(self._conv_head(x)))

        return x

    def _decode_features(self, x):
        """Upsample backbone features twice (2x each) while reducing channels 1792 -> 448 -> 112.

        Shared by the 'big_map' and 'img' heads; relies on the layers that
        ``__init__`` only creates for those two types.
        """
        x = self.up_double(x)
        x = self.relu(x)
        x = self.conv_transe1(x)  # 448 channels
        x = self.bn_transe1(x)
        x = self.up_double(x)
        x = self.relu(x)
        x = self.conv_transe2(x)  # 112 channels
        x = self.bn_transe2(x)
        return x

    def forward(self, seg, label, natural):
        """Run the backbone on the concatenated inputs and apply the head selected by ``self.type``.

        :param seg: tensor whose size the label is broadcast to
        :param label: tensor that gets three trailing singleton dims appended
            and is then expanded to ``seg.size()``
            (presumably one value per sample, shape (batch,) - TODO confirm)
        :param natural: tensor concatenated between label and seg along dim 1
        :return: sigmoid output of the selected head, or, for type 'feature',
            ``gram_matrix`` of the batch-mean-centred backbone features
        :raises ValueError: if ``self.type`` is not one of the handled types
        """
        label = label.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).expand(seg.size())
        x = torch.cat((label, natural, seg), 1)  # concatenated input

        # Convolution layers
        x = self.extract_features(x)

        if self.type == 'map':
            reg = self.sig(self.conv_reg(x))
        elif self.type == 'big_map':
            reg = self.sig(self.conv_transe_mask(self._decode_features(x)))  # 1 channel
        elif self.type == 'img':
            reg = self.sig(self.conv_transe3(self._decode_features(x)))  # 3 channels
        elif self.type == 'deconv_map':
            reg = self.sig(self.conv_big_reg(x))
        elif self.type == 'deconv_img':
            reg = self.sig(self.conv_img(x))
        elif self.type == 'feature':
            # Centre the features on their mean over dim 0 before the Gram matrix.
            reg = gram_matrix(x - x.mean(0, True))
        else:
            # Previously this fell through to an UnboundLocalError on `reg`.
            raise ValueError(
                "unsupported type {!r}; expected one of: map, big_map, img, "
                "deconv_map, deconv_img, feature".format(self.type))
        return reg

    @classmethod
    def from_name(cls, model_name, type, override_params=None):
        """Build an untrained model for ``model_name`` with the given head ``type``.

        :raises ValueError: if ``model_name`` is not a recognised model name
        """
        cls._check_model_name_is_valid(model_name)
        blocks_args, global_params = get_model_params(model_name, override_params)
        return cls(type, blocks_args, global_params)

    @classmethod
    def from_pretrained(cls, model_name, num_classes=1000, in_channels=3, type='map'):
        """Build a model via ``from_name`` and load pretrained weights into it.

        This merges two earlier definitions of the same name (the second
        silently shadowed the first), and forwards ``type`` to ``from_name``,
        which requires it; previously every call raised ``TypeError``.

        :param model_name: model name, validated by ``from_name``
        :param num_classes: forwarded as the 'num_classes' override; the fc
            weights are only loaded when this equals 1000
        :param in_channels: when different from 3, the stem convolution is
            replaced by a freshly initialised one with this many input channels
        :param type: head type passed to the constructor (defaults to 'map')
        :return: the constructed model

        NOTE(review): the stem built by ``__init__`` has 5 input channels and
        the model carries extra head layers, so whether the published weights
        load cleanly depends on ``load_pretrained_weights`` - confirm.
        """
        model = cls.from_name(model_name, type, override_params={'num_classes': num_classes})
        load_pretrained_weights(model, model_name, load_fc=(num_classes == 1000))
        if in_channels != 3:
            Conv2d = get_same_padding_conv2d(image_size=model._global_params.image_size)
            out_channels = round_filters(32, model._global_params)
            model._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
        return model

    @classmethod
    def get_image_size(cls, model_name):
        """Return the third value of ``efficientnet_params(model_name)`` (the resolution) for a valid model name."""
        cls._check_model_name_is_valid(model_name)
        _, _, res, _ = efficientnet_params(model_name)
        return res

    @classmethod
    def _check_model_name_is_valid(cls, model_name, also_need_pretrained_weights=False):
        """ Validates model name. Note that pretrained weights are only available for
        the first four models (efficientnet-b{i} for i in 0,1,2,3) at the moment.

        :raises ValueError: if ``model_name`` is not in the accepted list
        """
        num_models = 4 if also_need_pretrained_weights else 8
        valid_models = ['efficientnet-b' + str(i) for i in range(num_models)]
        if model_name not in valid_models:
            raise ValueError('model_name should be one of: ' + ', '.join(valid_models))