|
"""Script to download the pre-trained tensorflow weights and convert them to pytorch weights.""" |
|
import os |
|
import argparse |
|
import torch |
|
import numpy as np |
|
from tensorflow.python.training import py_checkpoint_reader |
|
|
|
from repnet import utils |
|
from repnet.model import RepNet |
|
|
|
|
|
|
|
# Directory containing this script; output paths are resolved relative to it.
PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))

# Base URL the TensorFlow checkpoint files are downloaded from.
TF_CHECKPOINT_BASE_URL = 'https://storage.googleapis.com/repnet_ckpt'

# All files of the TF checkpoint: the `checkpoint` file, the two variable data
# shards of `ckpt-88`, and its index file.
TF_CHECKPOINT_FILES = [
    'checkpoint',
    *(f'ckpt-88.data-{shard:05d}-of-00002' for shard in range(2)),
    'ckpt-88.index',
]

# Root output directory for both the downloaded TF files and the converted weights.
OUT_CHECKPOINTS_DIR = os.path.join(PROJECT_ROOT, 'checkpoints')
|
|
|
|
|
# Axis permutations that turn a TF/Keras kernel into the PyTorch weight layout,
# keyed by kernel rank (each key is simply the length of its permutation):
#   rank 2, Dense:  (in, out)             -> (out, in)
#   rank 4, Conv2D: (kh, kw, in, out)     -> (out, in, kh, kw)
#   rank 5, Conv3D: (kt, kh, kw, in, out) -> (out, in, kt, kh, kw)
WEIGHTS_PERMUTATION = {
    len(axes): axes
    for axes in (
        (1, 0),
        (3, 2, 0, 1),
        (4, 3, 0, 1, 2),
    )
}
|
|
|
|
|
# Rename per-layer TF/Keras variable names to the PyTorch attribute names.
#   Conv / Dense:          kernel -> weight (the conversion step also permutes
#                          kernels via WEIGHTS_PERMUTATION), bias -> bias
#   BatchNorm / LayerNorm: gamma -> weight, beta -> bias
#   BatchNorm statistics:  moving_mean -> running_mean, moving_variance -> running_var
# Names not listed here are passed through unchanged by the conversion loop.
ATTR_MAPPING = dict(
    kernel='weight',
    bias='bias',
    beta='bias',
    gamma='weight',
    moving_mean='running_mean',
    moving_variance='running_var',
)
|
|
|
|
|
# Explicit correspondence between TF checkpoint variable groups and PyTorch
# module prefixes. Each entry is a 3-tuple (tf_key, keras_layer_name, pt_key):
#   tf_key           - dotted key into the flattened TF state dict built by the
#                      conversion step (TF variable name split on '/', leading
#                      'model*' component and '.ATTRIBUTES/VARIABLE_VALUE' dropped,
#                      presumably joined with '.' by utils.flatten_dict). It must
#                      exist in that dict, so these keys are copied verbatim from
#                      the checkpoint. The mix of 'layer-N' and
#                      'layer_with_weights-N' is intentional; do not "normalize" it.
#   keras_layer_name - original Keras layer name, for human reference only; the
#                      conversion loop ignores it. None where the variable has no
#                      named Keras layer.
#   pt_key           - prefix of the matching parameters/buffers in RepNet.
WEIGHTS_MAPPING = [
    # Encoder: Keras ResNetV2-style layers (conv1 stem, then stages conv2..conv4)
    # -> encoder.stem and encoder.stages.{0,1,2}.
    ('base_model.layer-2', 'conv1_conv', 'encoder.stem.conv'),
    ('base_model.layer-5', 'conv2_block1_preact_bn', 'encoder.stages.0.blocks.0.norm1'),
    ('base_model.layer-7', 'conv2_block1_1_conv', 'encoder.stages.0.blocks.0.conv1'),
    ('base_model.layer-8', 'conv2_block1_1_bn', 'encoder.stages.0.blocks.0.norm2'),
    ('base_model.layer_with_weights-4', 'conv2_block1_2_conv', 'encoder.stages.0.blocks.0.conv2'),
    ('base_model.layer_with_weights-5', 'conv2_block1_2_bn', 'encoder.stages.0.blocks.0.norm3'),
    ('base_model.layer_with_weights-6', 'conv2_block1_0_conv', 'encoder.stages.0.blocks.0.downsample.conv'),
    ('base_model.layer_with_weights-7', 'conv2_block1_3_conv', 'encoder.stages.0.blocks.0.conv3'),
    ('base_model.layer_with_weights-8', 'conv2_block2_preact_bn', 'encoder.stages.0.blocks.1.norm1'),
    ('base_model.layer_with_weights-9', 'conv2_block2_1_conv', 'encoder.stages.0.blocks.1.conv1'),
    ('base_model.layer_with_weights-10', 'conv2_block2_1_bn', 'encoder.stages.0.blocks.1.norm2'),
    ('base_model.layer_with_weights-11', 'conv2_block2_2_conv', 'encoder.stages.0.blocks.1.conv2'),
    ('base_model.layer_with_weights-12', 'conv2_block2_2_bn', 'encoder.stages.0.blocks.1.norm3'),
    ('base_model.layer_with_weights-13', 'conv2_block2_3_conv', 'encoder.stages.0.blocks.1.conv3'),
    ('base_model.layer_with_weights-14', 'conv2_block3_preact_bn', 'encoder.stages.0.blocks.2.norm1'),
    ('base_model.layer_with_weights-15', 'conv2_block3_1_conv', 'encoder.stages.0.blocks.2.conv1'),
    ('base_model.layer_with_weights-16', 'conv2_block3_1_bn', 'encoder.stages.0.blocks.2.norm2'),
    ('base_model.layer_with_weights-17', 'conv2_block3_2_conv', 'encoder.stages.0.blocks.2.conv2'),
    ('base_model.layer_with_weights-18', 'conv2_block3_2_bn', 'encoder.stages.0.blocks.2.norm3'),
    ('base_model.layer_with_weights-19', 'conv2_block3_3_conv', 'encoder.stages.0.blocks.2.conv3'),
    ('base_model.layer_with_weights-20', 'conv3_block1_preact_bn', 'encoder.stages.1.blocks.0.norm1'),
    ('base_model.layer_with_weights-21', 'conv3_block1_1_conv', 'encoder.stages.1.blocks.0.conv1'),
    ('base_model.layer_with_weights-22', 'conv3_block1_1_bn', 'encoder.stages.1.blocks.0.norm2'),
    ('base_model.layer_with_weights-23', 'conv3_block1_2_conv', 'encoder.stages.1.blocks.0.conv2'),
    ('base_model.layer-47', 'conv3_block1_2_bn', 'encoder.stages.1.blocks.0.norm3'),
    ('base_model.layer_with_weights-25', 'conv3_block1_0_conv', 'encoder.stages.1.blocks.0.downsample.conv'),
    ('base_model.layer_with_weights-26', 'conv3_block1_3_conv', 'encoder.stages.1.blocks.0.conv3'),
    ('base_model.layer_with_weights-27', 'conv3_block2_preact_bn', 'encoder.stages.1.blocks.1.norm1'),
    ('base_model.layer_with_weights-28', 'conv3_block2_1_conv', 'encoder.stages.1.blocks.1.conv1'),
    ('base_model.layer_with_weights-29', 'conv3_block2_1_bn', 'encoder.stages.1.blocks.1.norm2'),
    ('base_model.layer_with_weights-30', 'conv3_block2_2_conv', 'encoder.stages.1.blocks.1.conv2'),
    ('base_model.layer_with_weights-31', 'conv3_block2_2_bn', 'encoder.stages.1.blocks.1.norm3'),
    ('base_model.layer-61', 'conv3_block2_3_conv', 'encoder.stages.1.blocks.1.conv3'),
    ('base_model.layer-63', 'conv3_block3_preact_bn', 'encoder.stages.1.blocks.2.norm1'),
    ('base_model.layer-65', 'conv3_block3_1_conv', 'encoder.stages.1.blocks.2.conv1'),
    ('base_model.layer-66', 'conv3_block3_1_bn', 'encoder.stages.1.blocks.2.norm2'),
    ('base_model.layer-69', 'conv3_block3_2_conv', 'encoder.stages.1.blocks.2.conv2'),
    ('base_model.layer-70', 'conv3_block3_2_bn', 'encoder.stages.1.blocks.2.norm3'),
    ('base_model.layer_with_weights-38', 'conv3_block3_3_conv', 'encoder.stages.1.blocks.2.conv3'),
    ('base_model.layer-74', 'conv3_block4_preact_bn', 'encoder.stages.1.blocks.3.norm1'),
    ('base_model.layer_with_weights-40', 'conv3_block4_1_conv', 'encoder.stages.1.blocks.3.conv1'),
    ('base_model.layer_with_weights-41', 'conv3_block4_1_bn', 'encoder.stages.1.blocks.3.norm2'),
    ('base_model.layer_with_weights-42', 'conv3_block4_2_conv', 'encoder.stages.1.blocks.3.conv2'),
    ('base_model.layer_with_weights-43', 'conv3_block4_2_bn', 'encoder.stages.1.blocks.3.norm3'),
    ('base_model.layer_with_weights-44', 'conv3_block4_3_conv', 'encoder.stages.1.blocks.3.conv3'),
    ('base_model.layer_with_weights-45', 'conv4_block1_preact_bn', 'encoder.stages.2.blocks.0.norm1'),
    ('base_model.layer_with_weights-46', 'conv4_block1_1_conv', 'encoder.stages.2.blocks.0.conv1'),
    ('base_model.layer_with_weights-47', 'conv4_block1_1_bn', 'encoder.stages.2.blocks.0.norm2'),
    ('base_model.layer-92', 'conv4_block1_2_conv', 'encoder.stages.2.blocks.0.conv2'),
    ('base_model.layer-93', 'conv4_block1_2_bn', 'encoder.stages.2.blocks.0.norm3'),
    ('base_model.layer-95', 'conv4_block1_0_conv', 'encoder.stages.2.blocks.0.downsample.conv'),
    ('base_model.layer-96', 'conv4_block1_3_conv', 'encoder.stages.2.blocks.0.conv3'),
    ('base_model.layer-98', 'conv4_block2_preact_bn', 'encoder.stages.2.blocks.1.norm1'),
    ('base_model.layer-100', 'conv4_block2_1_conv', 'encoder.stages.2.blocks.1.conv1'),
    ('base_model.layer-101', 'conv4_block2_1_bn', 'encoder.stages.2.blocks.1.norm2'),
    ('base_model.layer-104', 'conv4_block2_2_conv', 'encoder.stages.2.blocks.1.conv2'),
    ('base_model.layer-105', 'conv4_block2_2_bn', 'encoder.stages.2.blocks.1.norm3'),
    ('base_model.layer-107', 'conv4_block2_3_conv', 'encoder.stages.2.blocks.1.conv3'),
    ('base_model.layer-109', 'conv4_block3_preact_bn', 'encoder.stages.2.blocks.2.norm1'),
    ('base_model.layer-111', 'conv4_block3_1_conv', 'encoder.stages.2.blocks.2.conv1'),
    ('base_model.layer-112', 'conv4_block3_1_bn', 'encoder.stages.2.blocks.2.norm2'),
    ('base_model.layer-115', 'conv4_block3_2_conv', 'encoder.stages.2.blocks.2.conv2'),
    ('base_model.layer-116', 'conv4_block3_2_bn', 'encoder.stages.2.blocks.2.norm3'),
    ('base_model.layer-118', 'conv4_block3_3_conv', 'encoder.stages.2.blocks.2.conv3'),

    # Temporal 3D conv + its batch norm, then the 2D conv on the
    # self-similarity features ('tsm' presumably = temporal self-similarity matrix).
    ('temporal_conv_layers.0', 'conv3d', 'temporal_conv.0'),
    ('temporal_bn_layers.0', 'batch_normalization', 'temporal_conv.1'),
    ('conv_3x3_layer', 'conv2d', 'tsm_conv.0'),

    # Period-length head: input projection, positional encoding, one transformer
    # layer, then three FC layers (module indices 1/3/5; the skipped indices are
    # presumably parameter-free, e.g. activations — TODO confirm in RepNet).
    # mha.w_weight / mha.w_bias do not exist in the checkpoint: the conversion
    # step builds them by concatenating wq/wk/wv into PyTorch's packed
    # in_proj_weight / in_proj_bias layout.
    ('input_projection', 'dense', 'period_length_head.0.input_projection'),
    ('pos_encoding', None, 'period_length_head.0.pos_encoding'),
    ('transformer_layers.0.ffn.layer-0', None, 'period_length_head.0.transformer_layer.linear1'),
    ('transformer_layers.0.ffn.layer-1', None, 'period_length_head.0.transformer_layer.linear2'),
    ('transformer_layers.0.layernorm1', None, 'period_length_head.0.transformer_layer.norm1'),
    ('transformer_layers.0.layernorm2', None, 'period_length_head.0.transformer_layer.norm2'),
    ('transformer_layers.0.mha.w_weight', None, 'period_length_head.0.transformer_layer.self_attn.in_proj_weight'),
    ('transformer_layers.0.mha.w_bias', None, 'period_length_head.0.transformer_layer.self_attn.in_proj_bias'),
    ('transformer_layers.0.mha.dense', None, 'period_length_head.0.transformer_layer.self_attn.out_proj'),
    ('fc_layers.0', 'dense_14', 'period_length_head.1'),
    ('fc_layers.1', 'dense_15', 'period_length_head.3'),
    ('fc_layers.2', 'dense_16', 'period_length_head.5'),

    # Periodicity head: same structure as the period-length head, using the
    # '...2' / 'within_period_*' TF variables.
    ('input_projection2', 'dense_1', 'periodicity_head.0.input_projection'),
    ('pos_encoding2', None, 'periodicity_head.0.pos_encoding'),
    ('transformer_layers2.0.ffn.layer-0', None, 'periodicity_head.0.transformer_layer.linear1'),
    ('transformer_layers2.0.ffn.layer-1', None, 'periodicity_head.0.transformer_layer.linear2'),
    ('transformer_layers2.0.layernorm1', None, 'periodicity_head.0.transformer_layer.norm1'),
    ('transformer_layers2.0.layernorm2', None, 'periodicity_head.0.transformer_layer.norm2'),
    ('transformer_layers2.0.mha.w_weight',None, 'periodicity_head.0.transformer_layer.self_attn.in_proj_weight'),
    ('transformer_layers2.0.mha.w_bias', None, 'periodicity_head.0.transformer_layer.self_attn.in_proj_bias'),
    ('transformer_layers2.0.mha.dense', None, 'periodicity_head.0.transformer_layer.self_attn.out_proj'),
    ('within_period_fc_layers.0', 'dense_17', 'periodicity_head.1'),
    ('within_period_fc_layers.1', 'dense_18', 'periodicity_head.3'),
    ('within_period_fc_layers.2', 'dense_19', 'periodicity_head.5'),
]
|
|
|
|
|
# Command-line interface. No options are defined yet, so parsing mainly
# provides `--help` and rejects unexpected arguments.
parser = argparse.ArgumentParser(
    description='Download and convert the pre-trained weights from tensorflow to pytorch.',
)
|
|
|
|
|
if __name__ == '__main__':
    # No options are defined yet; parsing still handles `--help` and rejects
    # unexpected arguments.
    args = parser.parse_args()

    # ---- 1. Download the TensorFlow checkpoint files not already on disk. ----
    print('Downloading checkpoints...')
    tf_checkpoint_dir = os.path.join(OUT_CHECKPOINTS_DIR, 'tf_checkpoint')
    os.makedirs(tf_checkpoint_dir, exist_ok=True)
    for file_name in TF_CHECKPOINT_FILES:
        dst = os.path.join(tf_checkpoint_dir, file_name)
        # NOTE(review): only existence is checked. Unless utils.download_file
        # writes atomically (not verified here), a file truncated by an
        # interrupted download will be reused; delete it manually to re-fetch.
        if not os.path.exists(dst):
            utils.download_file(f'{TF_CHECKPOINT_BASE_URL}/{file_name}', dst)

    # ---- 2. Read the model variables into a nested dict. ----
    # The dict mirrors the '/'-separated TF variable names.
    print('Loading tensorflow checkpoint...')
    checkpoint_path = os.path.join(tf_checkpoint_dir, 'ckpt-88')
    checkpoint_reader = py_checkpoint_reader.NewCheckpointReader(checkpoint_path)
    shape_map = checkpoint_reader.get_variable_to_shape_map()
    tf_state_dict = {}
    for var_name in sorted(shape_map):
        # Filter *before* get_tensor() so optimizer slots and non-model
        # variables are never read from disk.
        if not var_name.startswith('model') or '.OPTIMIZER_SLOT' in var_name:
            continue
        var_tensor = checkpoint_reader.get_tensor(var_name)
        # Drop the leading 'model*' component and the
        # '.ATTRIBUTES/VARIABLE_VALUE' suffix components.
        var_path = [
            part for part in var_name.split('/')[1:]
            if part not in ('.ATTRIBUTES', 'VARIABLE_VALUE')
        ]
        current_dict = tf_state_dict
        for path in var_path[:-1]:
            current_dict = current_dict.setdefault(path, {})
        current_dict[var_path[-1]] = var_tensor

    # ---- 3. Fuse the separate query/key/value projections of each attention block. ----
    # This produces PyTorch's packed in_proj_weight / in_proj_bias. Kernels are
    # transposed (Keras Dense stores (in, out)) and stacked in q, k, v order
    # along axis 0. Dense biases are 1-D, so they are concatenated as-is.
    qkv_names = ('wq', 'wk', 'wv')
    for k in ['transformer_layers', 'transformer_layers2']:
        mha = tf_state_dict[k]['0']['mha']
        mha['w_weight'] = np.concatenate([mha[name]['kernel'].T for name in qkv_names], axis=0)
        mha['w_bias'] = np.concatenate([mha[name]['bias'] for name in qkv_names], axis=0)
        for name in qkv_names:
            del mha[name]

    # Flatten to dotted keys. keep_last=True presumably leaves the innermost
    # {attr: array} dicts intact (TODO confirm against utils.flatten_dict).
    tf_state_dict = utils.flatten_dict(tf_state_dict, keep_last=True)
    # Bare arrays (e.g. the fused w_weight / w_bias above) are wrapped under a
    # None attribute, so every value is an {attr: array} dict.
    for k, v in tf_state_dict.items():
        if not isinstance(v, dict):
            tf_state_dict[k] = {None: v}

    # ---- 4. Rename and re-layout the variables, then save the PyTorch state dict. ----
    print('Converting to PyTorch format...')
    pt_checkpoint_path = os.path.join(OUT_CHECKPOINTS_DIR, 'pytorch_weights.pth')
    pt_state_dict = {}
    for k_tf, _, k_pt in WEIGHTS_MAPPING:
        # Explicit errors instead of `assert` (stripped under -O) and bare KeyErrors.
        if k_pt in pt_state_dict:
            raise ValueError(f'Duplicate PyTorch target {k_pt!r} in WEIGHTS_MAPPING.')
        if k_tf not in tf_state_dict:
            raise KeyError(
                f'TensorFlow variable group {k_tf!r} (for {k_pt!r}) '
                f'not found in checkpoint {checkpoint_path}.'
            )
        pt_state_dict[k_pt] = {}
        for attr, array in tf_state_dict[k_tf].items():
            tensor = torch.from_numpy(array)
            if attr == 'kernel':
                axes = WEIGHTS_PERMUTATION.get(tensor.ndim)
                if axes is None:
                    raise ValueError(
                        f'No kernel permutation defined for a {tensor.ndim}-D kernel ({k_tf!r}).'
                    )
                # contiguous() materializes the permuted layout, so the saved
                # tensor is compact instead of a strided view.
                tensor = tensor.permute(axes).contiguous()
            # Unknown attribute names (including None) pass through unchanged.
            pt_state_dict[k_pt][ATTR_MAPPING.get(attr, attr)] = tensor
    # skip_none=True presumably drops the None attribute level, so bare tensors
    # land directly on their pt_key (TODO confirm against utils.flatten_dict).
    pt_state_dict = utils.flatten_dict(pt_state_dict, skip_none=True)
    torch.save(pt_state_dict, pt_checkpoint_path)

    # ---- 5. Sanity check: a strict load must match every RepNet parameter/buffer. ----
    print('Check that the weights can be loaded into the model...')
    model = RepNet()
    pt_state_dict = torch.load(pt_checkpoint_path)
    model.load_state_dict(pt_state_dict)

    print(f'Done. PyTorch weights saved to {pt_checkpoint_path}.')
|
|