Spaces:
Configuration error
Configuration error
import torch | |
from torch import nn | |
from torch.nn import functional as F | |
__version__ = "0.5.1" | |
from .utils import ( | |
GlobalParams, | |
BlockArgs, | |
BlockDecoder, | |
efficientnet, | |
get_model_params, | |
) | |
from .utils import ( | |
round_filters, | |
round_repeats, | |
drop_connect, | |
get_same_padding_conv2d, | |
get_same_padding_conv2d_freeze, | |
get_model_params, | |
efficientnet_params, | |
load_pretrained_weights, | |
Swish, | |
MemoryEfficientSwish, | |
gram_matrix, | |
) | |
class MBConvBlock(nn.Module):
    """
    Mobile inverted residual bottleneck block (MBConv).

    The block chains an optional 1x1 expansion convolution, a depthwise
    convolution, an optional squeeze-and-excitation gate and a 1x1
    projection convolution. An identity skip connection is added when the
    block keeps both stride and channel count unchanged.

    Args:
        block_args (namedtuple): BlockArgs describing this block.
        global_params (namedtuple): GlobalParams shared between blocks.

    Attributes:
        has_se (bool): Whether the block contains a Squeeze and Excitation layer.
        id_skip (bool): Whether the skip connection (and drop connect) is enabled.
    """

    def __init__(self, block_args, global_params):
        super().__init__()
        self._block_args = block_args
        # The batch-norm modules receive 1 - batch_norm_momentum as momentum.
        self._bn_mom = 1 - global_params.batch_norm_momentum
        self._bn_eps = global_params.batch_norm_epsilon

        se_ratio = block_args.se_ratio
        self.has_se = (se_ratio is not None) and (0 < se_ratio <= 1)
        self.id_skip = block_args.id_skip

        # Static or dynamic "same"-padding convolution, chosen from the image size.
        conv_cls = get_same_padding_conv2d(image_size=global_params.image_size)

        def make_bn(channels):
            return nn.BatchNorm2d(num_features=channels, momentum=self._bn_mom, eps=self._bn_eps)

        in_ch = block_args.input_filters
        expanded_ch = in_ch * block_args.expand_ratio

        # Expansion phase (skipped entirely when expand_ratio == 1).
        if block_args.expand_ratio != 1:
            self._expand_conv = conv_cls(
                in_channels=in_ch, out_channels=expanded_ch, kernel_size=1, bias=False)
            self._bn0 = make_bn(expanded_ch)

        # Depthwise phase: groups == channels makes the convolution depthwise.
        self._depthwise_conv = conv_cls(
            in_channels=expanded_ch,
            out_channels=expanded_ch,
            groups=expanded_ch,
            kernel_size=block_args.kernel_size,
            stride=block_args.stride,
            bias=False)
        self._bn1 = make_bn(expanded_ch)

        # Squeeze-and-excitation phase; the bottleneck width is derived from
        # the block's *input* filters, never dropping below one channel.
        if self.has_se:
            squeezed_ch = max(1, int(in_ch * se_ratio))
            self._se_reduce = conv_cls(
                in_channels=expanded_ch, out_channels=squeezed_ch, kernel_size=1)
            self._se_expand = conv_cls(
                in_channels=squeezed_ch, out_channels=expanded_ch, kernel_size=1)

        # Projection phase.
        projected_ch = block_args.output_filters
        self._project_conv = conv_cls(
            in_channels=expanded_ch, out_channels=projected_ch, kernel_size=1, bias=False)
        self._bn2 = make_bn(projected_ch)
        self._swish = MemoryEfficientSwish()

    def forward(self, inputs, drop_connect_rate=None):
        """
        :param inputs: input tensor
        :param drop_connect_rate: drop connect rate (float, between 0 and 1)
        :return: output of block
        """
        args = self._block_args
        act = self._swish

        # Expansion
        out = inputs
        if args.expand_ratio != 1:
            out = self._expand_conv(inputs)
            out = self._bn0(out)
            out = act(out)

        # Depthwise convolution
        out = self._depthwise_conv(out)
        out = self._bn1(out)
        out = act(out)

        # Squeeze and Excitation: gate the features with a per-channel sigmoid
        # computed from their globally pooled summary.
        if self.has_se:
            gate = F.adaptive_avg_pool2d(out, 1)
            gate = self._se_reduce(gate)
            gate = act(gate)
            gate = self._se_expand(gate)
            out = torch.sigmoid(gate) * out

        # Projection (no activation afterwards)
        out = self._bn2(self._project_conv(out))

        # Skip connection and drop connect
        use_skip = (
            self.id_skip
            and args.stride == 1
            and args.input_filters == args.output_filters
        )
        if not use_skip:
            return out
        if drop_connect_rate:
            out = drop_connect(out, p=drop_connect_rate, training=self.training)
        return out + inputs

    def set_swish(self, memory_efficient=True):
        """Sets swish function as memory efficient (for training) or standard (for export)"""
        if memory_efficient:
            self._swish = MemoryEfficientSwish()
        else:
            self._swish = Swish()
class MBConvBlock_freeze(nn.Module):
    """
    Functional Mobile Inverted Residual Bottleneck Block.

    This variant registers no convolution or batch-norm parameters of its
    own. Every weight is looked up at call time in the ``weights`` mapping
    passed to :meth:`forward`, under keys of the form
    ``'_blocks.<index>.<layer>.<param>'`` (for example
    ``'_blocks.3._depthwise_conv.weight'``).

    The computation mirrors ``MBConvBlock``: expansion, depthwise
    convolution, optional squeeze-and-excitation and projection, with a
    swish activation after the expansion and depthwise batch norms.

    Args:
        block_args (namedtuple): BlockArgs describing this block.
        index (int): Position of the block, used to build the weight-key prefix.
        device: Stored on the instance as ``self.device``.
        global_params (namedtuple): GlobalParams shared between blocks.

    Attributes:
        has_se (bool): Whether the block contains a Squeeze and Excitation layer.
    """

    def __init__(self, block_args, index, device, global_params):
        super().__init__()
        self._block_args = block_args
        self._bn_mom = 1 - global_params.batch_norm_momentum
        self._bn_eps = global_params.batch_norm_epsilon
        self.has_se = (self._block_args.se_ratio is not None) and (0 < self._block_args.se_ratio <= 1)
        self.id_skip = block_args.id_skip  # skip connection and drop connect
        # Functional "same"-padding convolution, called as
        # self.Conv2d(x, weight[, bias], groups=..., stride=...).
        self.Conv2d = get_same_padding_conv2d_freeze(image_size=global_params.image_size)
        self._swish = MemoryEfficientSwish()
        # Channel count after the expansion phase.
        self.oup = self._block_args.input_filters * self._block_args.expand_ratio
        self.s = self._block_args.stride
        self.block_name = '_blocks.{:d}'.format(index)
        self.device = device

    def _weight(self, weights, suffix):
        """Return the entry of ``weights`` stored under this block's prefix."""
        return weights['{}.{}'.format(self.block_name, suffix)]

    def _batch_norm(self, x, weights, layer):
        """Batch-normalize ``x`` with batch statistics and the affine params of ``layer``."""
        # training=True always normalizes with the statistics of the current
        # batch, so no running buffers are needed; passing None avoids
        # allocating throw-away tensors on every call.
        return F.batch_norm(
            x, None, None,
            self._weight(weights, layer + '.weight'),
            self._weight(weights, layer + '.bias'),
            training=True,
            eps=self._bn_eps)

    def forward(self, inputs, weights, drop_connect_rate=None):
        """
        :param inputs: input tensor
        :param weights: mapping from parameter name to tensor, containing this
            block's '_blocks.<index>.*' entries
        :param drop_connect_rate: drop connect rate (float, between 0 and 1)
        :return: output of block
        """
        # Expansion and Depthwise Convolution
        x = inputs
        if self._block_args.expand_ratio != 1:
            x = self.Conv2d(x, self._weight(weights, '_expand_conv.weight'))
            x = self._swish(self._batch_norm(x, weights, '_bn0'))
        x = self.Conv2d(
            x, self._weight(weights, '_depthwise_conv.weight'),
            groups=self.oup, stride=self.s)
        x = self._swish(self._batch_norm(x, weights, '_bn1'))

        # Squeeze and Excitation: the reduce/expand convolutions act on the
        # pooled descriptor, and the resulting gate rescales the feature map.
        if self.has_se:
            x_squeezed = F.adaptive_avg_pool2d(x, 1)
            x_squeezed = self.Conv2d(
                x_squeezed,
                self._weight(weights, '_se_reduce.weight'),
                self._weight(weights, '_se_reduce.bias'))
            x_squeezed = self.Conv2d(
                self._swish(x_squeezed),
                self._weight(weights, '_se_expand.weight'),
                self._weight(weights, '_se_expand.bias'))
            x = torch.sigmoid(x_squeezed) * x

        # Projection (no activation afterwards)
        x = self.Conv2d(x, self._weight(weights, '_project_conv.weight'))
        x = self._batch_norm(x, weights, '_bn2')

        # Skip connection and drop connect
        input_filters, output_filters = self._block_args.input_filters, self._block_args.output_filters
        if self.id_skip and self._block_args.stride == 1 and input_filters == output_filters:
            if drop_connect_rate:
                x = drop_connect(x, p=drop_connect_rate, training=self.training)
            x = x + inputs  # skip connection
        return x

    def set_swish(self, memory_efficient=True):
        """Sets swish function as memory efficient (for training) or standard (for export)"""
        self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
class EfficientNet(nn.Module):
    """
    An EfficientNet model. Most easily loaded with the .from_name or .from_pretrained methods

    Compared with a plain classifier, this variant has a 4-channel stem, a
    single-output final linear layer, and optional auxiliary heads selected
    by ``head_type``. :meth:`forward` only uses the backbone and ``_fc``.

    Args:
        device: Stored on the instance as ``self.device``.
        blocks_args (list): A list of BlockArgs to construct blocks
        global_params (namedtuple): A set of GlobalParams shared between blocks
        head_type (str, optional): Which auxiliary head modules to create:
            'big_map', 'img', 'deconv_map' or 'deconv_img'. Any other value
            (including the default ``None``) creates only ``conv_reg``.

    Example:
        model = EfficientNet.from_pretrained('efficientnet-b0')
    """

    # Number of input channels of the stem convolution built by __init__.
    _STEM_IN_CHANNELS = 4

    def __init__(self, device, blocks_args=None, global_params=None, head_type=None):
        super().__init__()
        assert isinstance(blocks_args, list), 'blocks_args should be a list'
        assert len(blocks_args) > 0, 'block args must be greater than 0'
        self._global_params = global_params
        self._blocks_args = blocks_args
        # Kept under the name `head_type` (not `type`) so that the inherited
        # nn.Module.type() casting method is not shadowed on the instance.
        self.head_type = head_type

        # Get static or dynamic convolution depending on image size
        Conv2d = get_same_padding_conv2d(image_size=global_params.image_size)

        # Batch norm parameters
        bn_mom = 1 - self._global_params.batch_norm_momentum
        bn_eps = self._global_params.batch_norm_epsilon

        # Stem
        in_channels = self._STEM_IN_CHANNELS  # 4 input channels, i.e. not plain RGB
        out_channels = round_filters(32, self._global_params)  # number of output channels
        self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
        self._bn0 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)

        # Build blocks
        self._blocks = nn.ModuleList([])
        for block_args in self._blocks_args:
            # Update block input and output filters based on depth multiplier.
            block_args = block_args._replace(
                input_filters=round_filters(block_args.input_filters, self._global_params),
                output_filters=round_filters(block_args.output_filters, self._global_params),
                num_repeat=round_repeats(block_args.num_repeat, self._global_params)
            )
            # The first block needs to take care of stride and filter size increase.
            self._blocks.append(MBConvBlock(block_args, self._global_params))
            if block_args.num_repeat > 1:
                # Repeats keep the channel count and use stride 1.
                block_args = block_args._replace(input_filters=block_args.output_filters, stride=1)
            for _ in range(block_args.num_repeat - 1):
                self._blocks.append(MBConvBlock(block_args, self._global_params))

        # Head
        in_channels = block_args.output_filters  # output of final block
        out_channels = round_filters(1280, self._global_params)
        self._conv_head = Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        self._bn1 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)

        # Final linear layer
        self._avg_pooling = nn.AdaptiveAvgPool2d(1)
        self._dropout = nn.Dropout(self._global_params.dropout_rate)

        # Auxiliary heads.
        # NOTE(review): their input width is hard-coded to 1792 rather than
        # derived from `out_channels` - presumably sized for one specific
        # model variant; confirm before using them with another one.
        self.conv_reg = nn.Conv2d(1792, 1, 1)
        if head_type in ('big_map', 'img'):
            self.conv_transe1 = nn.Conv2d(1792, 448, 1)
            self.bn_transe1 = nn.BatchNorm2d(num_features=448, momentum=bn_mom, eps=bn_eps)
            self.conv_transe2 = nn.Conv2d(448, 112, 1)
            self.bn_transe2 = nn.BatchNorm2d(num_features=112, momentum=bn_mom, eps=bn_eps)
            if head_type == 'big_map':
                self.conv_transe_mask = nn.Conv2d(112, 1, 1)
                self.deconv_big = nn.ConvTranspose2d(1792, 1, 5, stride=4)
            else:
                self.conv_transe3 = nn.Conv2d(112, 3, 1)
                self.deconv_img = nn.ConvTranspose2d(1792, 3, 5, stride=4)
        elif head_type in ('deconv_map', 'deconv_img'):
            self.conv_big_reg = nn.ConvTranspose2d(1792, 1, 5, stride=4)
            self.conv_img = nn.ConvTranspose2d(1792, 3, 5, stride=4)

        self.relu = nn.ReLU()
        self.up_double = nn.Upsample(scale_factor=2, mode='bilinear')
        # Single output unit; the `num_classes` override is not consulted here.
        self._fc = nn.Linear(out_channels, 1)
        self._swish = MemoryEfficientSwish()
        self.sig = nn.Sigmoid()
        self.device = device

    def set_swish(self, memory_efficient=True):
        """Sets swish function as memory efficient (for training) or standard (for export)"""
        self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
        for block in self._blocks:
            block.set_swish(memory_efficient)

    def extract_features(self, inputs):
        """ Returns output of the final convolution layer """
        # Stem
        x = self._swish(self._bn0(self._conv_stem(inputs)))

        # Blocks: the drop connect rate grows linearly with the block index.
        num_blocks = len(self._blocks)
        for idx, block in enumerate(self._blocks):
            drop_connect_rate = self._global_params.drop_connect_rate
            if drop_connect_rate:
                drop_connect_rate *= float(idx) / num_blocks
            x = block(x, drop_connect_rate=drop_connect_rate)

        # Head
        x = self._swish(self._bn1(self._conv_head(x)))
        return x

    def forward(self, inputs, weights=None):
        """
        Calls extract_features to extract features, applies final linear layer, and returns logits.

        :param inputs: input tensor; its first dimension is used as the batch size
        :param weights: accepted for call compatibility; not used by this method
        :return: output of the final linear layer
        """
        bs = inputs.size(0)
        # Convolution layers
        x = self.extract_features(inputs)
        # Pooling and final linear layer
        x = self._avg_pooling(x)
        x = x.view(bs, -1)
        x = self._dropout(x)
        x = self._fc(x)
        return x

    @classmethod
    def from_name(cls, model_name, device, override_params=None):
        """
        Build a randomly initialized model from its name.

        :param model_name: one of 'efficientnet-b0' ... 'efficientnet-b7'
        :param device: forwarded to the constructor
        :param override_params: optional dict forwarded to get_model_params
        :raises ValueError: if model_name is not a known model
        """
        cls._check_model_name_is_valid(model_name)
        blocks_args, global_params = get_model_params(model_name, override_params)
        return cls(device, blocks_args, global_params)

    @classmethod
    def from_pretrained(cls, model_name, num_classes=1000, in_channels=None, device=None):
        """
        Build a model from its name and load pretrained weights into it.

        :param model_name: one of the names accepted by from_name
        :param num_classes: passed as the 'num_classes' override; the final
            layer weights are loaded only when it equals 1000
        :param in_channels: if given and different from the stem's channel
            count, the stem convolution is replaced by a freshly initialized
            one with that many input channels (done after loading, so the new
            stem carries no pretrained weights)
        :param device: forwarded to the constructor
        :raises ValueError: if model_name is not a known model
        """
        model = cls.from_name(model_name, device, override_params={'num_classes': num_classes})
        load_pretrained_weights(model, model_name, load_fc=(num_classes == 1000))
        if in_channels is not None and in_channels != cls._STEM_IN_CHANNELS:
            Conv2d = get_same_padding_conv2d(image_size=model._global_params.image_size)
            out_channels = round_filters(32, model._global_params)
            model._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
        return model

    @classmethod
    def get_image_size(cls, model_name):
        """
        Return the third field of efficientnet_params(model_name).

        :raises ValueError: if model_name is not a known model
        """
        cls._check_model_name_is_valid(model_name)
        _, _, res, _ = efficientnet_params(model_name)
        return res

    @classmethod
    def _check_model_name_is_valid(cls, model_name, also_need_pretrained_weights=False):
        """ Validates model name. Note that pretrained weights are only available for
        the first four models (efficientnet-b{i} for i in 0,1,2,3) at the moment.

        :raises ValueError: if model_name is not among the accepted names
        """
        num_models = 4 if also_need_pretrained_weights else 8
        valid_models = ['efficientnet-b' + str(i) for i in range(num_models)]
        if model_name not in valid_models:
            raise ValueError('model_name should be one of: ' + ', '.join(valid_models))