|
import tensorflow as tf |
|
import torch |
|
import numpy as np |
|
|
|
|
|
# (TF suffix, PyTorch suffix) for the four tensors of every batch-norm layer,
# listed in the order the renames are applied.
_BN_SUFFIXES = (
    ("beta", "bias"),
    ("gamma", "weight"),
    ("moving_mean", "running_mean"),
    ("moving_variance", "running_var"),
)


def _bn_rules(tf_bn, torch_prefix):
    """Return (old, new) rename pairs mapping ``<tf_bn>.<param>`` to ``<torch_prefix>.<param>``."""
    return [("{}.{}".format(tf_bn, tf_suffix), "{}.{}".format(torch_prefix, torch_suffix))
            for tf_suffix, torch_suffix in _BN_SUFFIXES]


def _apply_rules(name, rules):
    """Apply every (old, new) pair to ``name`` with ``str.replace``, in order."""
    for old, new in rules:
        name = name.replace(old, new)
    return name


# Stem: a single conv followed by its batch norm.
_STEM_RULES = [("conv2d.kernel", "conv.weight")] + _bn_rules("tpu_batch_normalization", "bn")

# Head: projection conv + batch norm, then the dense classifier.
_HEAD_RULES = ([("conv2d.kernel", "project_conv.conv.weight"),
                ("dense.kernel", "classifier.weight"),
                ("dense.bias", "classifier.bias")]
               + _bn_rules("tpu_batch_normalization", "project_conv.bn"))

# The two convs of the per-block ``se`` sub-module (MBConv blocks only).
_SE_RULES = [("se.conv2d.bias", "se.conv_reduce.bias"),
             ("se.conv2d.kernel", "se.conv_reduce.weight"),
             ("se.conv2d_1.bias", "se.conv_expand.bias"),
             ("se.conv2d_1.kernel", "se.conv_expand.weight")]


def _convert_block_name(name, stage0_num, fused_conv_num):
    """Rename a ``blocks_<id>.*`` variable to the PyTorch ``blocks.<id>.*`` key.

    ``name`` has already had the model scope stripped and ``/`` replaced by ``.``.
    Raises ValueError if the first component is not ``blocks_<int>``.
    """
    # e.g. "blocks_3.conv2d.kernel" -> block id "3", then "blocks.3.conv2d.kernel"
    block_id = name.split(".", maxsplit=1)[0].replace("blocks_", "")
    prefix = "blocks.{}".format(block_id)
    name = name.replace("blocks_{}".format(block_id), prefix)
    index = int(block_id)

    if index <= stage0_num - 1:
        # Leading blocks carry only one conv + BN, mapped to the projection conv.
        rules = [("conv2d.kernel", "project_conv.conv.weight")]
        rules += _bn_rules("tpu_batch_normalization", "project_conv.bn")
        return _apply_rules(name, rules)

    # Remaining blocks: the first conv/BN is the expansion conv. The conv name is
    # anchored on the block prefix so that "se.conv2d.kernel" is not matched.
    rules = [(prefix + ".conv2d.kernel", prefix + ".expand_conv.conv.weight")]
    rules += _bn_rules("tpu_batch_normalization", "expand_conv.bn")
    rules.append((prefix + ".conv2d_1.kernel", prefix + ".project_conv.conv.weight"))

    if index <= fused_conv_num - 1:
        # Fused blocks: the second BN belongs to the projection conv.
        rules += _bn_rules("tpu_batch_normalization_1", "project_conv.bn")
    else:
        # MBConv blocks: depthwise conv + BN, projection BN, then the SE convs.
        rules.append(("depthwise_conv2d.depthwise_kernel", "dwconv.conv.weight"))
        rules += _bn_rules("tpu_batch_normalization_1", "dwconv.bn")
        rules += _bn_rules("tpu_batch_normalization_2", "project_conv.bn")
        rules += _SE_RULES
    return _apply_rules(name, rules)


def _convert_name(tf_name, model_name, stage0_num, fused_conv_num):
    """Map one TF checkpoint variable name to its PyTorch state-dict key.

    Names outside the stem/head/blocks sections are reported on stdout and
    returned with only the scope stripped and ``/`` -> ``.`` applied.
    """
    new_name = tf_name.replace(model_name + "/", "").replace("/", ".")

    if "stem" in tf_name:
        return _apply_rules(new_name, _STEM_RULES)
    if "head" in tf_name:
        return _apply_rules(new_name, _HEAD_RULES)
    if "blocks" in tf_name:
        return _convert_block_name(new_name, stage0_num, fused_conv_num)

    print("not recognized name: " + tf_name)
    return new_name


def _require_rank(name, var, rank):
    """Raise ValueError unless ``var`` has exactly ``rank`` dimensions.

    This is an explicit check rather than ``assert`` so it still runs under ``python -O``.
    """
    if len(var.shape) != rank:
        raise ValueError("expected a {}-D tensor for '{}', got shape {}".format(
            rank, name, tuple(var.shape)))


def _to_torch_layout(name, var):
    """Transpose a TF weight array into the layout PyTorch expects for ``name``."""
    if "conv" in name and "weight" in name and "bn" not in name and "dw" not in name:
        _require_rank(name, var, 4)
        # TF conv kernel (kh, kw, in_c, out_c) -> torch Conv2d weight (out_c, in_c, kh, kw)
        return np.transpose(var, (3, 2, 0, 1))
    if "bn" in name:
        # Batch-norm vectors need no reordering.
        return var
    if "dwconv" in name and "weight" in name:
        _require_rank(name, var, 4)
        # TF depthwise kernel (kh, kw, in_c, multiplier) -> (in_c, multiplier, kh, kw);
        # matches a grouped torch Conv2d weight when multiplier == 1 (TODO confirm).
        return np.transpose(var, (2, 3, 0, 1))
    if "classifier" in name and "weight" in name:
        _require_rank(name, var, 2)
        # TF dense kernel (in, out) -> torch Linear weight (out, in)
        return np.transpose(var, (1, 0))
    return var


def main(model_name: str = "efficientnetv2-s",
         tf_weights_path: str = "./efficientnetv2-s/model",
         stage0_num: int = 2,
         fused_conv_num: int = 10,
         save_path=None):
    """Convert a TensorFlow EfficientNetV2 checkpoint into a PyTorch state dict.

    Args:
        model_name: Top-level variable scope in the checkpoint. It is stripped
            from every variable name and used for the default output filename.
        tf_weights_path: Checkpoint path given to ``tf.train.list_variables`` and
            ``tf.train.load_checkpoint``.
        stage0_num: Blocks with id < stage0_num are converted as single-conv
            blocks (projection conv + BN only).
        fused_conv_num: Blocks with stage0_num <= id < fused_conv_num are
            converted as fused blocks (expand conv + projection conv). Later
            blocks are converted as MBConv (expand, depthwise, SE, projection).
        save_path: Output file for ``torch.save``. Defaults to
            ``"pre_" + model_name + ".pth"`` in the current directory.

    Raises:
        ValueError: if a kernel does not have the rank expected for its layer
            type, or a block id cannot be parsed as an integer.
    """
    except_var = ["global_step"]
    new_weights = {}

    # NOTE(review): variables whose name contains "Exponential" (moving-average
    # shadow copies) are dropped, so the raw weights are exported. Confirm this
    # is the intended set of weights.
    var_list = [i for i in tf.train.list_variables(tf_weights_path) if "Exponential" not in i[0]]
    reader = tf.train.load_checkpoint(tf_weights_path)

    for tf_name, _shape in var_list:
        if tf_name in except_var:
            continue
        new_name = _convert_name(tf_name, model_name, stage0_num, fused_conv_num)
        var = reader.get_tensor(tf_name)
        new_weights[new_name] = torch.as_tensor(_to_torch_layout(new_name, var))

    if save_path is None:
        save_path = "pre_" + model_name + ".pth"
    torch.save(new_weights, save_path)
|
|
|
|
|
if __name__ == '__main__':
    # Convert the EfficientNetV2-S checkpoint: blocks 0-1 are single-conv
    # blocks, blocks 2-9 are fused, and the rest are MBConv.
    main(
        model_name="efficientnetv2-s",
        tf_weights_path="./efficientnetv2-s/model",
        stage0_num=2,
        fused_conv_num=10,
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|