|
import os |
|
import pkgutil |
|
from copy import deepcopy |
|
|
|
from torch import nn as nn |
|
|
|
from timm.layers import Conv2dSame, BatchNormAct2d, Linear |
|
|
|
# Public entry points: look up / replace submodules by dotted path, and rebuild
# a model with adapted layer shapes (from a string or a bundled '_pruned' file).
__all__ = ['extract_layer', 'set_layer', 'adapt_model_from_string', 'adapt_model_from_file']
|
|
|
|
|
def extract_layer(model, layer):
    """Return the submodule of ``model`` found at dotted path ``layer``.

    Path handling:
      * If ``model`` has a ``module`` attribute (e.g. a DataParallel-style
        wrapper) and the path does not start with ``module``, the path is
        resolved relative to ``model.module``.
      * If ``model`` has no ``module`` attribute but the path starts with
        ``module``, that leading component is dropped.

    Resolution is lenient. Walking stops at the first component that does not
    exist, and the deepest object reached so far is returned. So ``''``
    returns the root, or ``model.module`` for a wrapped model.

    Args:
        model: Root object, typically an ``nn.Module``.
        layer: Dotted path such as ``'blocks.0.conv_pw'``, e.g. a name yielded
            by ``named_modules()``.

    Returns:
        The resolved submodule, or the deepest ancestor that could be resolved.
    """
    path = layer.split('.')
    module = model
    has_wrapper = hasattr(model, 'module')
    if has_wrapper and path[0] != 'module':
        module = model.module
    elif not has_wrapper and path[0] == 'module':
        path = path[1:]
    for name in path:
        if not hasattr(module, name):
            break
        # Resolve by registration name, not by position. Digit names are not
        # necessarily positional indices (ModuleDict keys, or a Sequential
        # built from an OrderedDict such as {'1': ..., '3': ...}).
        module = getattr(module, name)
    return module
|
|
|
|
|
def set_layer(model, layer, val):
    """Assign ``val`` at dotted path ``layer`` inside ``model`` (in place).

    Path handling matches ``extract_layer``:
      * If ``model`` has a ``module`` attribute (e.g. a DataParallel-style
        wrapper) and the path does not start with ``module``, the path is
        resolved relative to ``model.module``.
      * If ``model`` has no ``module`` attribute and the path starts with
        ``module.``, that leading component is dropped.

    Every component except the last is resolved with ``getattr``. The last
    component is then set with ``setattr`` on its parent, so an existing
    submodule is replaced and a missing final name is added.

    Args:
        model: Root object, typically an ``nn.Module``, modified in place.
        layer: Dotted path such as ``'blocks.0.conv_pw'``.
        val: Object to assign at that path.

    Raises:
        AttributeError: If an intermediate path component does not exist.
    """
    path = layer.split('.')
    parent = model
    has_wrapper = hasattr(model, 'module')
    if has_wrapper and path[0] != 'module':
        parent = model.module
    elif not has_wrapper and path[0] == 'module' and len(path) > 1:
        # Key from a wrapped model applied to an unwrapped one. Without this,
        # a 'module' attribute would be set on the root instead.
        path = path[1:]
    # Resolve by registration name, not by position. Digit names are not
    # necessarily positional indices (ModuleDict keys, or a Sequential built
    # from an OrderedDict).
    for name in path[:-1]:
        parent = getattr(parent, name)
    setattr(parent, path[-1], val)
|
|
|
|
|
def _parse_model_string(model_string, separator='***'):
    """Parse a serialized shape listing into a ``{param_name: [dims, ...]}`` dict.

    ``model_string`` is a ``separator``-joined sequence of
    ``'<param_name>:[d0, d1, ...]'`` entries. Two kinds of entry are skipped:
    blank entries (e.g. from a trailing separator) and entries with an empty
    shape (``'<param_name>:[]'``).

    Raises:
        ValueError: If a non-blank entry has no ``':'``.
    """
    shapes = {}
    for entry in model_string.split(separator):
        entry = entry.strip()
        if not entry:
            continue
        key, sep, shape_str = entry.partition(':')
        if not sep:
            raise ValueError(f'Malformed entry {entry!r}; expected "<param_name>:[d0, d1, ...]"')
        # Drop the surrounding brackets: '[16, 3, 3, 3]' -> '16, 3, 3, 3'.
        dims = shape_str[1:-1]
        if dims.strip():
            shapes[key] = [int(d) for d in dims.split(',')]
    return shapes


def adapt_model_from_string(parent_module, model_string):
    """Build a copy of ``parent_module`` whose layers match the shapes in ``model_string``.

    ``parent_module`` is deep-copied. For each of its submodules, the matching
    layer in the copy is replaced by a freshly constructed layer sized from
    the ``'<name>.weight'`` entry. No weights are copied.

      * ``nn.Conv2d`` / ``Conv2dSame``: ``out_channels = shape[0]`` and
        ``in_channels = shape[1]``. Kernel size, stride, padding, dilation and
        bias presence are kept. Grouped convs are rebuilt as depthwise
        (``groups = in_channels = out_channels``).
      * ``BatchNormAct2d`` / ``nn.BatchNorm2d``: ``num_features = shape[0]``.
        eps, momentum and affine are kept, and ``track_running_stats=True``.
      * ``nn.Linear``: ``in_features = shape[1]`` and ``out_features`` is kept.
        The copy's root ``num_features`` attribute, if any, is updated.

    Args:
        parent_module: Reference model. Its structure is not modified, but it
            is switched to eval mode.
        model_string: ``'***'``-separated ``'<param_name>:[d0, d1, ...]'``
            entries.

    Returns:
        The adapted copy, in eval mode.

    Raises:
        KeyError: If a rebuilt layer has no ``'<name>.weight'`` entry.
        ValueError: If ``model_string`` contains a malformed entry.
    """
    shapes = _parse_model_string(model_string)

    new_module = deepcopy(parent_module)
    # named_modules() already yields each module. Using it directly avoids a
    # redundant path lookup that can resolve to the wrong object when the
    # model happens to have a 'module' attribute.
    for name, old_module in parent_module.named_modules():
        if isinstance(old_module, (nn.Conv2d, Conv2dSame)):
            conv_cls = Conv2dSame if isinstance(old_module, Conv2dSame) else nn.Conv2d
            weight_shape = shapes[name + '.weight']
            out_channels, in_channels = weight_shape[0], weight_shape[1]
            groups = 1
            if old_module.groups > 1:
                # NOTE(review): any grouped conv is assumed to be depthwise;
                # a non-depthwise grouped conv would be rebuilt incorrectly.
                in_channels = out_channels
                groups = out_channels
            new_conv = conv_cls(
                in_channels=in_channels, out_channels=out_channels, kernel_size=old_module.kernel_size,
                bias=old_module.bias is not None, padding=old_module.padding, dilation=old_module.dilation,
                groups=groups, stride=old_module.stride)
            set_layer(new_module, name, new_conv)

        elif isinstance(old_module, BatchNormAct2d):
            # Checked before nn.BatchNorm2d. Presumably BatchNormAct2d
            # subclasses it, so this order matters.
            new_bn = BatchNormAct2d(
                shapes[name + '.weight'][0], eps=old_module.eps, momentum=old_module.momentum,
                affine=old_module.affine, track_running_stats=True)
            # Copy rather than share these submodules. Otherwise any state in
            # them (e.g. learnable activation parameters) would be aliased
            # between the parent and the adapted model.
            new_bn.drop = deepcopy(old_module.drop)
            new_bn.act = deepcopy(old_module.act)
            set_layer(new_module, name, new_bn)

        elif isinstance(old_module, nn.BatchNorm2d):
            new_bn = nn.BatchNorm2d(
                num_features=shapes[name + '.weight'][0], eps=old_module.eps, momentum=old_module.momentum,
                affine=old_module.affine, track_running_stats=True)
            set_layer(new_module, name, new_bn)

        elif isinstance(old_module, nn.Linear):
            # FIXME(review): every nn.Linear is treated as the classifier head.
            # Its input width is adapted and the root num_features is
            # overwritten. This is only correct when the head is the model's
            # sole Linear layer.
            num_features = shapes[name + '.weight'][1]
            new_fc = Linear(
                in_features=num_features, out_features=old_module.out_features, bias=old_module.bias is not None)
            set_layer(new_module, name, new_fc)
            if hasattr(new_module, 'num_features'):
                new_module.num_features = num_features

    new_module.eval()
    parent_module.eval()
    return new_module
|
|
|
|
|
def adapt_model_from_file(parent_module, model_variant):
    """Adapt ``parent_module`` using the bundled ``_pruned/<model_variant>.txt`` resource.

    The resource is loaded relative to this module via ``pkgutil.get_data``,
    decoded as UTF-8 and stripped, then passed to ``adapt_model_from_string``.

    Args:
        parent_module: Reference model to copy and adapt.
        model_variant: Base name of the shape file, without ``.txt``.

    Returns:
        The adapted model returned by ``adapt_model_from_string``.

    Raises:
        FileNotFoundError: If the resource does not exist or cannot be loaded.
    """
    # pkgutil.get_data requires '/' as the separator on every platform.
    # os.path.join would yield '\\' on Windows and break non-filesystem
    # loaders such as zipimport.
    resource = f'_pruned/{model_variant}.txt'
    adapt_data = pkgutil.get_data(__name__, resource)
    if adapt_data is None:
        # get_data returns None when the loader cannot provide the resource.
        raise FileNotFoundError(f'Could not load resource {resource!r} for {__name__!r}')
    return adapt_model_from_string(parent_module, adapt_data.decode('utf-8').strip())
|
|