import contextlib
from copy import deepcopy
from typing import Sequence

import torch
import torch.nn as nn
from thop import profile

__all__ = [
    "fuse_conv_and_bn",
    "fuse_model",
    "get_model_info",
    "replace_module",
    "freeze_module",
    "adjust_status",
]


def get_model_info(model: nn.Module, tsize: Sequence[int]) -> str:
    """Return a string with the parameter count (M) and GFLOPs of ``model`` at input size ``tsize``."""
    stride = 64
    img = torch.zeros((1, 3, stride, stride), device=next(model.parameters()).device)
    flops, params = profile(deepcopy(model), inputs=(img,), verbose=False)
    params /= 1e6
    flops /= 1e9
    # profile() reports MACs at the probe resolution; rescale to the target
    # size and multiply by 2 to convert MACs to FLOPs.
    flops *= tsize[0] * tsize[1] / stride / stride * 2
    info = "Params: {:.2f}M, Gflops: {:.2f}".format(params, flops)
    return info
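

# Usage sketch (hypothetical model and input size; any nn.Module with
# parameters works):
#   model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.BatchNorm2d(16), nn.ReLU())
#   print(get_model_info(model, (640, 640)))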


def fuse_conv_and_bn(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> nn.Conv2d:
    """
    Fuse convolution and batchnorm layers.
    See https://tehnokv.com/posts/fusing-batchnorm-and-conv/ for more details.

    Args:
        conv (nn.Conv2d): convolution to fuse.
        bn (nn.BatchNorm2d): batchnorm to fuse.

    Returns:
        nn.Conv2d: fused convolution that behaves the same as the input conv and bn.
    """
    fusedconv = (
        nn.Conv2d(
            conv.in_channels,
            conv.out_channels,
            kernel_size=conv.kernel_size,
            stride=conv.stride,
            padding=conv.padding,
            dilation=conv.dilation,  # keep dilation, otherwise dilated convs fuse incorrectly
            groups=conv.groups,
            bias=True,
        )
        .requires_grad_(False)
        .to(conv.weight.device)
    )

    # fold the BN scale (gamma / sqrt(var + eps)) into the conv weights
    w_conv = conv.weight.clone().view(conv.out_channels, -1)
    w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
    fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape))

    # fold the BN shift (beta - gamma * mean / sqrt(var + eps)) into the bias
    b_conv = (
        torch.zeros(conv.weight.size(0), device=conv.weight.device)
        if conv.bias is None
        else conv.bias
    )
    b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(
        torch.sqrt(bn.running_var + bn.eps)
    )
    fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)

    return fusedconv
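

# Quick equivalence check (a minimal sketch; layer shapes are arbitrary, and
# bn must be in eval mode so it uses its running statistics):
#   conv = nn.Conv2d(8, 16, 3, padding=1, bias=False)
#   bn = nn.BatchNorm2d(16).eval()
#   x = torch.randn(1, 8, 32, 32)
#   fused = fuse_conv_and_bn(conv, bn)
#   assert torch.allclose(bn(conv(x)), fused(x), atol=1e-5)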


def fuse_model(model: nn.Module) -> nn.Module:
    """Fuse conv and bn layers in the model in place.

    Args:
        model (nn.Module): model to fuse.

    Returns:
        nn.Module: fused model.
    """
    # imported locally to avoid a circular import at module load time
    from yolox.models.network_blocks import BaseConv

    for m in model.modules():
        if type(m) is BaseConv and hasattr(m, "bn"):
            m.conv = fuse_conv_and_bn(m.conv, m.bn)  # update conv
            delattr(m, "bn")  # remove batchnorm
            m.forward = m.fuseforward  # update forward
    return model
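

# Typical use before export/inference (sketch; assumes a YOLOX model whose
# BaseConv blocks expose `conv`, `bn`, and `fuseforward`):
#   model.eval()
#   model = fuse_model(model)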


def replace_module(module, replaced_module_type, new_module_type, replace_func=None) -> nn.Module:
    """
    Replace modules of a given type in `module` with a new type. Mostly used in deploy.

    Args:
        module (nn.Module): model to apply the replace operation on.
        replaced_module_type (Type): module type to be replaced.
        new_module_type (Type): module type to replace with.
        replace_func (function): function describing the replace logic. Default value None.

    Returns:
        model (nn.Module): module in which the replacement has been applied.
    """

    def default_replace_func(replaced_module_type, new_module_type):
        return new_module_type()

    if replace_func is None:
        replace_func = default_replace_func

    model = module
    if isinstance(module, replaced_module_type):
        model = replace_func(replaced_module_type, new_module_type)
    else:
        for name, child in module.named_children():
            # recurse, forwarding `replace_func` so custom logic applies at every depth
            new_child = replace_module(
                child, replaced_module_type, new_module_type, replace_func
            )
            if new_child is not child:  # child was replaced
                model.add_module(name, new_child)

    return model
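

# Sketch: swap every SiLU for ReLU before deployment (the module types here
# are illustrative, not prescribed by this API):
#   model = replace_module(model, nn.SiLU, nn.ReLU)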


def freeze_module(module: nn.Module, name=None) -> nn.Module:
    """Freeze module in place.

    Args:
        module (nn.Module): module to freeze.
        name (str, optional): name of the submodule to freeze. If not given,
            freeze the whole module. Note that fuzzy matching is not supported.
            Defaults to None.

    Examples:
        freeze the backbone of the model
        >>> freeze_module(model.backbone)

        or freeze the backbone of the model by name
        >>> freeze_module(model, name="backbone")
    """
    # stop gradients for matching parameters
    for param_name, parameter in module.named_parameters():
        if name is None or name in param_name:
            parameter.requires_grad = False

    # put matching submodules into eval mode so e.g. BN statistics stay fixed
    for module_name, sub_module in module.named_modules():
        if name is None or name in module_name:
            sub_module.eval()

    return module
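

# Sanity-check sketch (hypothetical `model` with a `backbone` attribute):
#   freeze_module(model, name="backbone")
#   assert not any(p.requires_grad for p in model.backbone.parameters())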


@contextlib.contextmanager
def adjust_status(module: nn.Module, training: bool = False) -> nn.Module:
    """Adjust module to training/eval mode temporarily.

    Args:
        module (nn.Module): module whose status should be adjusted.
        training (bool): training mode to set. True for train mode, False for eval mode.

    Examples:
        >>> with adjust_status(model, training=False):
        ...     model(data)
    """
    status = {}

    def backup_status(module):
        for m in module.modules():
            # save the current flag, then override it
            status[m] = m.training
            m.training = training

    def recover_status(module):
        for m in module.modules():
            # restore the saved flag
            m.training = status.pop(m)

    backup_status(module)
    yield module
    recover_status(module)
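

# Sketch: run a validation pass mid-training; the per-module flags are
# restored on exit (`validate` is a hypothetical evaluation helper):
#   model.train()
#   with adjust_status(model, training=False):
#       validate(model)
#   assert model.training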