import argparse
import os.path as osp

import mmengine
import numpy as np
import torch


def vit_jax_to_torch(jax_weights, num_layer=12):
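    """Map a dict of JAX ViT parameters (already converted to torch tensors)
    to MMSegmentation-style state dict keys."""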
    torch_weights = dict()

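    # patch embedding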
    conv_filters = jax_weights['embedding/kernel']
    conv_filters = conv_filters.permute(3, 2, 0, 1)
    torch_weights['patch_embed.projection.weight'] = conv_filters
    torch_weights['patch_embed.projection.bias'] = jax_weights[
        'embedding/bias']

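    # pos embedding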
    torch_weights['pos_embed'] = jax_weights[
        'Transformer/posembed_input/pos_embedding']

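    # cls token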
    torch_weights['cls_token'] = jax_weights['cls']

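    # final encoder norm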
    torch_weights['ln1.weight'] = jax_weights['Transformer/encoder_norm/scale']
    torch_weights['ln1.bias'] = jax_weights['Transformer/encoder_norm/bias']

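    # convert each transformer encoder block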
    for i in range(num_layer):
        jax_block = f'Transformer/encoderblock_{i}'
        torch_block = f'layers.{i}'

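        # attention norm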
        torch_weights[f'{torch_block}.ln1.weight'] = jax_weights[
            f'{jax_block}/LayerNorm_0/scale']
        torch_weights[f'{torch_block}.ln1.bias'] = jax_weights[
            f'{jax_block}/LayerNorm_0/bias']

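        # attention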
        query_weight = jax_weights[
            f'{jax_block}/MultiHeadDotProductAttention_1/query/kernel']
        query_bias = jax_weights[
            f'{jax_block}/MultiHeadDotProductAttention_1/query/bias']
        key_weight = jax_weights[
            f'{jax_block}/MultiHeadDotProductAttention_1/key/kernel']
        key_bias = jax_weights[
            f'{jax_block}/MultiHeadDotProductAttention_1/key/bias']
        value_weight = jax_weights[
            f'{jax_block}/MultiHeadDotProductAttention_1/value/kernel']
        value_bias = jax_weights[
            f'{jax_block}/MultiHeadDotProductAttention_1/value/bias']

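        # stack the separate q/k/v projections into one fused tensor and
        # flatten (num_heads, head_dim) to match torch's in_proj layout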
        qkv_weight = torch.from_numpy(
            np.stack((query_weight, key_weight, value_weight), 1))
        qkv_weight = torch.flatten(qkv_weight, start_dim=1)
        qkv_bias = torch.from_numpy(
            np.stack((query_bias, key_bias, value_bias), 0))
        qkv_bias = torch.flatten(qkv_bias, start_dim=0)

        torch_weights[f'{torch_block}.attn.attn.in_proj_weight'] = qkv_weight
        torch_weights[f'{torch_block}.attn.attn.in_proj_bias'] = qkv_bias
        to_out_weight = jax_weights[
            f'{jax_block}/MultiHeadDotProductAttention_1/out/kernel']
        to_out_weight = torch.flatten(to_out_weight, start_dim=0, end_dim=1)
        torch_weights[
            f'{torch_block}.attn.attn.out_proj.weight'] = to_out_weight
        torch_weights[f'{torch_block}.attn.attn.out_proj.bias'] = jax_weights[
            f'{jax_block}/MultiHeadDotProductAttention_1/out/bias']

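        # mlp norm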
        torch_weights[f'{torch_block}.ln2.weight'] = jax_weights[
            f'{jax_block}/LayerNorm_2/scale']
        torch_weights[f'{torch_block}.ln2.bias'] = jax_weights[
            f'{jax_block}/LayerNorm_2/bias']

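        # mlp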
        torch_weights[f'{torch_block}.ffn.layers.0.0.weight'] = jax_weights[
            f'{jax_block}/MlpBlock_3/Dense_0/kernel']
        torch_weights[f'{torch_block}.ffn.layers.0.0.bias'] = jax_weights[
            f'{jax_block}/MlpBlock_3/Dense_0/bias']
        torch_weights[f'{torch_block}.ffn.layers.1.weight'] = jax_weights[
            f'{jax_block}/MlpBlock_3/Dense_1/kernel']
        torch_weights[f'{torch_block}.ffn.layers.1.bias'] = jax_weights[
            f'{jax_block}/MlpBlock_3/Dense_1/bias']

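    # jax stores linear kernels as (in, out); transpose them to torch's
    # (out, in), leaving the conv patch embedding and layer norms untouched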
    for k, v in torch_weights.items():
        if 'weight' in k and 'patch_embed' not in k and 'ln' not in k:
            v = v.permute(1, 0)
        torch_weights[k] = v

    return torch_weights


def main():
    parser = argparse.ArgumentParser(
        description='Convert keys from jax official pretrained vit models to '
        'MMSegmentation style.')
    parser.add_argument('src', help='src model path or url')
    parser.add_argument('dst', help='save path')
    args = parser.parse_args()

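    # load the .npz checkpoint and convert every array to a torch tensor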
    jax_weights = np.load(args.src)
    jax_weights_tensor = {}
    for key in jax_weights.files:
        value = torch.from_numpy(jax_weights[key])
        jax_weights_tensor[key] = value
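    # ViT-L/16 checkpoints have 24 encoder blocks, ViT-B has 12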
    if 'L_16-i21k' in args.src:
        num_layer = 24
    else:
        num_layer = 12
    torch_weights = vit_jax_to_torch(jax_weights_tensor, num_layer)
    mmengine.mkdir_or_exist(osp.dirname(args.dst))
    torch.save(torch_weights, args.dst)


if __name__ == '__main__':
    main()