File size: 8,695 Bytes
71b93be |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 |
import tensorflow as tf
import torch
import numpy as np
def _bn_rules(tf_layer: str, torch_prefix: str) -> list:
    """Return the four ordered (old, new) substring rewrites for one batch-norm layer.

    ``tf_layer`` is the TensorFlow layer name (e.g. ``tpu_batch_normalization_1``)
    and ``torch_prefix`` is the prefix used on the PyTorch side
    (e.g. ``project_conv.bn``).
    """
    return [
        (tf_layer + ".beta", torch_prefix + ".bias"),
        (tf_layer + ".gamma", torch_prefix + ".weight"),
        (tf_layer + ".moving_mean", torch_prefix + ".running_mean"),
        (tf_layer + ".moving_variance", torch_prefix + ".running_var"),
    ]


def _block_rules(dotted_name: str, stage0_num: int, fused_conv_num: int) -> list:
    """Return the ordered substring rewrites for a ``blocks_<id>`` variable.

    ``dotted_name`` is the variable name after the model prefix was removed and
    ``/`` was replaced by ``.``; the block index is read from its first
    component. Raises ``ValueError`` (from ``int``) when that component is not
    ``blocks_<integer>``.
    """
    # e.g. blocks_0.conv2d.kernel -> 0
    blocks_id = dotted_name.split(".", maxsplit=1)[0].replace("blocks_", "")
    prefix = "blocks.{}".format(blocks_id)
    block_index = int(blocks_id)

    rules = [("blocks_{}".format(blocks_id), prefix)]

    if block_index <= stage0_num - 1:  # expansion=1 fused_mbconv
        rules.append(("conv2d.kernel", "project_conv.conv.weight"))
        rules.extend(_bn_rules("tpu_batch_normalization", "project_conv.bn"))
        return rules

    # Every later block starts with an expand conv + batch norm. The pattern is
    # anchored on the block prefix so that "se.conv2d.kernel" is left alone.
    rules.append((prefix + ".conv2d.kernel", prefix + ".expand_conv.conv.weight"))
    rules.extend(_bn_rules("tpu_batch_normalization", "expand_conv.bn"))

    if block_index <= fused_conv_num - 1:  # fused_mbconv
        rules.append((prefix + ".conv2d_1.kernel", prefix + ".project_conv.conv.weight"))
        rules.extend(_bn_rules("tpu_batch_normalization_1", "project_conv.bn"))
    else:  # mbconv
        rules.append((prefix + ".conv2d_1.kernel", prefix + ".project_conv.conv.weight"))
        rules.append(("depthwise_conv2d.depthwise_kernel", "dwconv.conv.weight"))
        rules.extend(_bn_rules("tpu_batch_normalization_1", "dwconv.bn"))
        rules.extend(_bn_rules("tpu_batch_normalization_2", "project_conv.bn"))
        rules.extend([
            ("se.conv2d.bias", "se.conv_reduce.bias"),
            ("se.conv2d.kernel", "se.conv_reduce.weight"),
            ("se.conv2d_1.bias", "se.conv_expand.bias"),
            ("se.conv2d_1.kernel", "se.conv_expand.weight"),
        ])
    return rules


def _convert_name(tf_name: str, model_name: str,
                  stage0_num: int, fused_conv_num: int) -> str:
    """Translate one TensorFlow variable name into its PyTorch state-dict key.

    Names that belong to none of the stem / head / blocks groups are reported
    on stdout and returned with only the prefix/separator normalisation applied.
    """
    new_name = tf_name.replace(model_name + "/", "").replace("/", ".")

    if "stem" in tf_name:
        rules = [("conv2d.kernel", "conv.weight")]
        rules.extend(_bn_rules("tpu_batch_normalization", "bn"))
    elif "head" in tf_name:
        rules = [
            ("conv2d.kernel", "project_conv.conv.weight"),
            ("dense.kernel", "classifier.weight"),
            ("dense.bias", "classifier.bias"),
        ]
        rules.extend(_bn_rules("tpu_batch_normalization", "project_conv.bn"))
    elif "blocks" in tf_name:
        rules = _block_rules(new_name, stage0_num, fused_conv_num)
    else:
        print("not recognized name: " + tf_name)
        rules = []

    # Rules are applied strictly in order; each one is a plain substring replace.
    for old, new in rules:
        new_name = new_name.replace(old, new)
    return new_name


def _require_rank(var, rank: int, new_name: str) -> None:
    """Raise ``ValueError`` unless ``var`` has exactly ``rank`` dimensions."""
    if len(var.shape) != rank:
        raise ValueError(
            "expected a {}-D tensor for '{}', got shape {}".format(
                rank, new_name, tuple(var.shape)))


def _to_torch_layout(new_name: str, var):
    """Reorder the axes of ``var`` according to the kind of parameter ``new_name`` denotes.

    Batch-norm tensors, biases and anything unmatched are returned unchanged.
    """
    is_weight = "weight" in new_name

    if "conv" in new_name and is_weight and "bn" not in new_name and "dw" not in new_name:
        _require_rank(var, 4, new_name)
        # conv kernel [h, w, c, n] -> [n, c, h, w]
        return np.transpose(var, (3, 2, 0, 1))
    if "bn" in new_name:
        return var
    if "dwconv" in new_name and is_weight:
        _require_rank(var, 4, new_name)
        # dw_kernel [h, w, n, c] -> [n, c, h, w]
        return np.transpose(var, (2, 3, 0, 1))
    if "classifier" in new_name and is_weight:
        _require_rank(var, 2, new_name)
        return np.transpose(var, (1, 0))
    return var


def main(model_name: str = "efficientnetv2-s",
         tf_weights_path: str = "./efficientnetv2-s/model",
         stage0_num: int = 2,
         fused_conv_num: int = 10):
    """Convert a TensorFlow EfficientNetV2 checkpoint into a PyTorch weight file.

    Every checkpoint variable is renamed, has its axes reordered where needed,
    and is stored in a dict that is written to ``"pre_" + model_name + ".pth"``
    in the current working directory.

    Args:
        model_name: Prefix of the TensorFlow variable names (stripped from the
            keys) and stem of the output file name.
        tf_weights_path: Checkpoint path handed to ``tf.train.list_variables``
            and ``tf.train.load_checkpoint``.
        stage0_num: Blocks with index below this value are mapped as
            expansion=1 fused_mbconv (a single ``project_conv``).
        fused_conv_num: Remaining blocks with index below this value are mapped
            as fused_mbconv; all later blocks are mapped as mbconv.

    Raises:
        ValueError: If a kernel does not have the expected number of
            dimensions, or a block index cannot be parsed from a variable name.
    """
    except_var = ["global_step"]

    # Variables whose name contains "Exponential" are skipped
    # (presumably moving-average copies - confirm against the checkpoint).
    var_names = [name for name, _shape in tf.train.list_variables(tf_weights_path)
                 if "Exponential" not in name]
    reader = tf.train.load_checkpoint(tf_weights_path)

    new_weights = {}
    for tf_name in var_names:
        if tf_name in except_var:
            continue
        new_name = _convert_name(tf_name, model_name, stage0_num, fused_conv_num)
        new_var = _to_torch_layout(new_name, reader.get_tensor(tf_name))
        new_weights[new_name] = torch.as_tensor(new_var)

    torch.save(new_weights, "pre_" + model_name + ".pth")
if __name__ == "__main__":
    # Settings for the other published variants (swap them into the dict below):
    #   efficientnetv2-m: tf_weights_path="./efficientnetv2-m/model",
    #                     stage0_num=3, fused_conv_num=13
    #   efficientnetv2-l: tf_weights_path="./efficientnetv2-l/model",
    #                     stage0_num=4, fused_conv_num=18
    conversion_args = {
        "model_name": "efficientnetv2-s",
        "tf_weights_path": "./efficientnetv2-s/model",
        "stage0_num": 2,
        "fused_conv_num": 10,
    }
    main(**conversion_args)
|