File size: 4,171 Bytes
786f6a6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 |
import os
import pkgutil
from copy import deepcopy
from torch import nn as nn
from timm.layers import Conv2dSame, BatchNormAct2d, Linear
__all__ = ['extract_layer', 'set_layer', 'adapt_model_from_string', 'adapt_model_from_file']
def extract_layer(model, layer):
    """Look up a nested sub-module by its dotted path.

    Args:
        model: Root object to search. If it has a ``module`` attribute and the
            path does not start with ``'module'``, the search starts at
            ``model.module`` instead.
        layer: Dotted path such as ``'blocks.0.conv'``. All-digit components
            are used as integer indices, every other component as an
            attribute name. A leading ``'module'`` component is dropped when
            ``model`` has no ``module`` attribute.

    Returns:
        The object found at the end of the path. If a component cannot be
        resolved with ``hasattr``, traversal stops and the object reached so
        far is returned.
    """
    parts = layer.split('.')
    has_wrapper = hasattr(model, 'module')
    starts_with_module = parts[0] == 'module'

    current = model
    if has_wrapper and not starts_with_module:
        # Path is relative to the wrapped model, so step inside the wrapper.
        current = model.module
    elif not has_wrapper and starts_with_module:
        # Path was recorded on a wrapped model; ignore the wrapper prefix.
        parts = parts[1:]

    for name in parts:
        if not hasattr(current, name):
            break
        current = current[int(name)] if name.isdigit() else getattr(current, name)
    return current
def set_layer(model, layer, val):
    """Replace the sub-module addressed by a dotted path with ``val``.

    Path handling follows ``extract_layer``: if ``model`` has a ``module``
    attribute and the path does not start with ``'module'``, the path is
    resolved inside ``model.module``; if ``model`` has no ``module`` attribute
    and the path starts with ``'module'`` followed by further components, that
    prefix is dropped.

    Args:
        model: Root object that owns the layer to replace.
        layer: Dotted path such as ``'blocks.0.conv'``. All-digit components
            are used as integer indices while walking, every other component
            as an attribute name.
        val: Object to store at the target location (via ``setattr`` on the
            parent of the target).

    Returns:
        None. ``model`` is modified in place.

    Note:
        When every component resolves, the final component is replaced on its
        parent. Components that do not resolve are skipped while counting, and
        the component whose index is (number of resolved components - 1) is
        the one that gets assigned.
    """
    parts = layer.split('.')
    has_wrapper = hasattr(model, 'module')

    root = model
    if has_wrapper and parts[0] != 'module':
        root = model.module
    elif not has_wrapper and len(parts) > 1 and parts[0] == 'module':
        # Without this, a 'module.'-prefixed path on an unwrapped model would
        # assign ``val`` to a new ``model.module`` attribute instead of the
        # intended layer. A bare 'module' path is left alone so it still
        # addresses the attribute of that name.
        parts = parts[1:]

    # First pass: count how many components of the path can be resolved.
    probe = root
    resolved = 0
    for name in parts:
        if hasattr(probe, name):
            probe = probe[int(name)] if name.isdigit() else getattr(probe, name)
            resolved += 1
    target_index = resolved - 1

    # Second pass: walk to the parent of the target, then swap the target.
    parent = root
    for name in parts[:target_index]:
        parent = parent[int(name)] if name.isdigit() else getattr(parent, name)
    setattr(parent, parts[target_index], val)
def adapt_model_from_string(parent_module, model_string):
    """Create a copy of ``parent_module`` whose layers are rebuilt to given shapes.

    Args:
        parent_module: Source model. It is deep-copied; the copy receives the
            rebuilt layers. The source is switched to eval mode before return.
        model_string: Records joined by ``'***'``. Each record is
            ``<parameter name>:<shape>``, where the first and last character
            of ``<shape>`` are discarded (presumably enclosing brackets) and
            the remainder is a comma-separated list of integers. Records with
            an empty shape are ignored.

    Returns:
        The deep copy, in eval mode, with every Conv2d / Conv2dSame /
        BatchNormAct2d / BatchNorm2d / Linear module replaced by a newly
        constructed module sized from the ``<module name>.weight`` record.
        No weights are copied into the new layers here.

    Raises:
        KeyError: If a matching module has no ``<name>.weight`` record.
        IndexError: If a record contains no ``':'``.
    """
    separator = '***'

    # Parse "name:[d0, d1, ...]" records into {name: [d0, d1, ...]}.
    shapes = {}
    for record in model_string.split(separator):
        fields = record.split(':')
        name = fields[0]
        dims = fields[1][1:-1].split(',')
        if dims[0] != '':
            shapes[name] = [int(d) for d in dims]

    new_module = deepcopy(parent_module)
    for mod_name, _ in parent_module.named_modules():
        old = extract_layer(parent_module, mod_name)
        weight_key = mod_name + '.weight'

        if isinstance(old, (nn.Conv2d, Conv2dSame)):
            conv_cls = Conv2dSame if isinstance(old, Conv2dSame) else nn.Conv2d
            weight_shape = shapes[weight_key]
            in_channels = weight_shape[1]
            out_channels = weight_shape[0]
            groups = 1
            if old.groups > 1:
                # Grouped convs are rebuilt with groups == in == out channels.
                in_channels = out_channels
                groups = out_channels
            replacement = conv_cls(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=old.kernel_size,
                bias=old.bias is not None,
                padding=old.padding,
                dilation=old.dilation,
                groups=groups,
                stride=old.stride,
            )
            set_layer(new_module, mod_name, replacement)
        elif isinstance(old, BatchNormAct2d):
            replacement = BatchNormAct2d(
                shapes[weight_key][0],
                eps=old.eps,
                momentum=old.momentum,
                affine=old.affine,
                track_running_stats=True,
            )
            # Reuse the original drop / activation sub-modules on the new norm.
            replacement.drop = old.drop
            replacement.act = old.act
            set_layer(new_module, mod_name, replacement)
        elif isinstance(old, nn.BatchNorm2d):
            replacement = nn.BatchNorm2d(
                num_features=shapes[weight_key][0],
                eps=old.eps,
                momentum=old.momentum,
                affine=old.affine,
                track_running_stats=True,
            )
            set_layer(new_module, mod_name, replacement)
        elif isinstance(old, nn.Linear):
            # FIXME extra checks to ensure this is actually the FC classifier layer and not a diff Linear layer?
            in_features = shapes[weight_key][1]
            replacement = Linear(
                in_features=in_features,
                out_features=old.out_features,
                bias=old.bias is not None,
            )
            set_layer(new_module, mod_name, replacement)
            if hasattr(new_module, 'num_features'):
                new_module.num_features = in_features

    new_module.eval()
    parent_module.eval()
    return new_module
def adapt_model_from_file(parent_module, model_variant):
    """Adapt ``parent_module`` using a shape listing bundled with this package.

    Loads the resource ``_pruned/<model_variant>.txt`` relative to this
    module's package with ``pkgutil.get_data``, decodes it as UTF-8, strips
    surrounding whitespace and passes the text to ``adapt_model_from_string``.

    Args:
        parent_module: Model to adapt; forwarded to ``adapt_model_from_string``.
        model_variant: Base name (without ``.txt``) of the listing to load.

    Returns:
        Whatever ``adapt_model_from_string`` returns for the loaded text.
    """
    # pkgutil resource names use '/' as the separator on every platform;
    # os.path.join would insert a backslash on Windows.
    resource = '_pruned/' + model_variant + '.txt'
    adapt_data = pkgutil.get_data(__name__, resource)
    return adapt_model_from_string(parent_module, adapt_data.decode('utf-8').strip())
|