"""Script to download the pre-trained tensorflow weights and convert them to pytorch weights.

Steps:
1. Download the RepNet TF checkpoint files into ``checkpoints/tf_checkpoint`` (skipping files already present).
2. Read the model variables from the checkpoint into a nested dictionary.
3. Rename/permute them according to ``WEIGHTS_MAPPING``, ``ATTR_MAPPING`` and ``WEIGHTS_PERMUTATION``.
4. Save the result to ``checkpoints/pytorch_weights.pth`` and check it loads into ``RepNet``.
"""
import os
import argparse

import numpy as np
import torch
from tensorflow.python.training import py_checkpoint_reader

from repnet import utils
from repnet.model import RepNet


# Relevant paths
PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))
TF_CHECKPOINT_BASE_URL = 'https://storage.googleapis.com/repnet_ckpt'
TF_CHECKPOINT_FILES = ['checkpoint', 'ckpt-88.data-00000-of-00002', 'ckpt-88.data-00001-of-00002', 'ckpt-88.index']
OUT_CHECKPOINTS_DIR = os.path.join(PROJECT_ROOT, 'checkpoints')

# Mapping of kernel ndim -> permutation to go from tf to pytorch.
# TF kernels are laid out as (in, out) for dense layers and (*spatial, in, out) for convolutions,
# while PyTorch expects (out, in) and (out, in, *spatial) respectively.
WEIGHTS_PERMUTATION = {
    2: (1, 0),
    4: (3, 2, 0, 1),
    5: (4, 3, 0, 1, 2)
}

# Mapping of tf attributes -> pytorch attributes (attributes not listed here keep their name)
ATTR_MAPPING = {
    'kernel': 'weight',
    'bias': 'bias',
    'beta': 'bias',
    'gamma': 'weight',
    'moving_mean': 'running_mean',
    'moving_variance': 'running_var'
}

# Mapping of (tf checkpoint key, tf model layer name, pytorch module prefix).
# The middle element is informational only: the conversion below ignores it.
WEIGHTS_MAPPING = [
    # Base frame encoder
    ('base_model.layer-2', 'conv1_conv', 'encoder.stem.conv'),
    ('base_model.layer-5', 'conv2_block1_preact_bn', 'encoder.stages.0.blocks.0.norm1'),
    ('base_model.layer-7', 'conv2_block1_1_conv', 'encoder.stages.0.blocks.0.conv1'),
    ('base_model.layer-8', 'conv2_block1_1_bn', 'encoder.stages.0.blocks.0.norm2'),
    ('base_model.layer_with_weights-4', 'conv2_block1_2_conv', 'encoder.stages.0.blocks.0.conv2'),
    ('base_model.layer_with_weights-5', 'conv2_block1_2_bn', 'encoder.stages.0.blocks.0.norm3'),
    ('base_model.layer_with_weights-6', 'conv2_block1_0_conv', 'encoder.stages.0.blocks.0.downsample.conv'),
    ('base_model.layer_with_weights-7', 'conv2_block1_3_conv', 'encoder.stages.0.blocks.0.conv3'),
    ('base_model.layer_with_weights-8', 'conv2_block2_preact_bn', 'encoder.stages.0.blocks.1.norm1'),
    ('base_model.layer_with_weights-9', 'conv2_block2_1_conv', 'encoder.stages.0.blocks.1.conv1'),
    ('base_model.layer_with_weights-10', 'conv2_block2_1_bn', 'encoder.stages.0.blocks.1.norm2'),
    ('base_model.layer_with_weights-11', 'conv2_block2_2_conv', 'encoder.stages.0.blocks.1.conv2'),
    ('base_model.layer_with_weights-12', 'conv2_block2_2_bn', 'encoder.stages.0.blocks.1.norm3'),
    ('base_model.layer_with_weights-13', 'conv2_block2_3_conv', 'encoder.stages.0.blocks.1.conv3'),
    ('base_model.layer_with_weights-14', 'conv2_block3_preact_bn', 'encoder.stages.0.blocks.2.norm1'),
    ('base_model.layer_with_weights-15', 'conv2_block3_1_conv', 'encoder.stages.0.blocks.2.conv1'),
    ('base_model.layer_with_weights-16', 'conv2_block3_1_bn', 'encoder.stages.0.blocks.2.norm2'),
    ('base_model.layer_with_weights-17', 'conv2_block3_2_conv', 'encoder.stages.0.blocks.2.conv2'),
    ('base_model.layer_with_weights-18', 'conv2_block3_2_bn', 'encoder.stages.0.blocks.2.norm3'),
    ('base_model.layer_with_weights-19', 'conv2_block3_3_conv', 'encoder.stages.0.blocks.2.conv3'),
    ('base_model.layer_with_weights-20', 'conv3_block1_preact_bn', 'encoder.stages.1.blocks.0.norm1'),
    ('base_model.layer_with_weights-21', 'conv3_block1_1_conv', 'encoder.stages.1.blocks.0.conv1'),
    ('base_model.layer_with_weights-22', 'conv3_block1_1_bn', 'encoder.stages.1.blocks.0.norm2'),
    ('base_model.layer_with_weights-23', 'conv3_block1_2_conv', 'encoder.stages.1.blocks.0.conv2'),
    ('base_model.layer-47', 'conv3_block1_2_bn', 'encoder.stages.1.blocks.0.norm3'),
    ('base_model.layer_with_weights-25', 'conv3_block1_0_conv', 'encoder.stages.1.blocks.0.downsample.conv'),
    ('base_model.layer_with_weights-26', 'conv3_block1_3_conv', 'encoder.stages.1.blocks.0.conv3'),
    ('base_model.layer_with_weights-27', 'conv3_block2_preact_bn', 'encoder.stages.1.blocks.1.norm1'),
    ('base_model.layer_with_weights-28', 'conv3_block2_1_conv', 'encoder.stages.1.blocks.1.conv1'),
    ('base_model.layer_with_weights-29', 'conv3_block2_1_bn', 'encoder.stages.1.blocks.1.norm2'),
    ('base_model.layer_with_weights-30', 'conv3_block2_2_conv', 'encoder.stages.1.blocks.1.conv2'),
    ('base_model.layer_with_weights-31', 'conv3_block2_2_bn', 'encoder.stages.1.blocks.1.norm3'),
    ('base_model.layer-61', 'conv3_block2_3_conv', 'encoder.stages.1.blocks.1.conv3'),
    ('base_model.layer-63', 'conv3_block3_preact_bn', 'encoder.stages.1.blocks.2.norm1'),
    ('base_model.layer-65', 'conv3_block3_1_conv', 'encoder.stages.1.blocks.2.conv1'),
    ('base_model.layer-66', 'conv3_block3_1_bn', 'encoder.stages.1.blocks.2.norm2'),
    ('base_model.layer-69', 'conv3_block3_2_conv', 'encoder.stages.1.blocks.2.conv2'),
    ('base_model.layer-70', 'conv3_block3_2_bn', 'encoder.stages.1.blocks.2.norm3'),
    ('base_model.layer_with_weights-38', 'conv3_block3_3_conv', 'encoder.stages.1.blocks.2.conv3'),
    ('base_model.layer-74', 'conv3_block4_preact_bn', 'encoder.stages.1.blocks.3.norm1'),
    ('base_model.layer_with_weights-40', 'conv3_block4_1_conv', 'encoder.stages.1.blocks.3.conv1'),
    ('base_model.layer_with_weights-41', 'conv3_block4_1_bn', 'encoder.stages.1.blocks.3.norm2'),
    ('base_model.layer_with_weights-42', 'conv3_block4_2_conv', 'encoder.stages.1.blocks.3.conv2'),
    ('base_model.layer_with_weights-43', 'conv3_block4_2_bn', 'encoder.stages.1.blocks.3.norm3'),
    ('base_model.layer_with_weights-44', 'conv3_block4_3_conv', 'encoder.stages.1.blocks.3.conv3'),
    ('base_model.layer_with_weights-45', 'conv4_block1_preact_bn', 'encoder.stages.2.blocks.0.norm1'),
    ('base_model.layer_with_weights-46', 'conv4_block1_1_conv', 'encoder.stages.2.blocks.0.conv1'),
    ('base_model.layer_with_weights-47', 'conv4_block1_1_bn', 'encoder.stages.2.blocks.0.norm2'),
    ('base_model.layer-92', 'conv4_block1_2_conv', 'encoder.stages.2.blocks.0.conv2'),
    ('base_model.layer-93', 'conv4_block1_2_bn', 'encoder.stages.2.blocks.0.norm3'),
    ('base_model.layer-95', 'conv4_block1_0_conv', 'encoder.stages.2.blocks.0.downsample.conv'),
    ('base_model.layer-96', 'conv4_block1_3_conv', 'encoder.stages.2.blocks.0.conv3'),
    ('base_model.layer-98', 'conv4_block2_preact_bn', 'encoder.stages.2.blocks.1.norm1'),
    ('base_model.layer-100', 'conv4_block2_1_conv', 'encoder.stages.2.blocks.1.conv1'),
    ('base_model.layer-101', 'conv4_block2_1_bn', 'encoder.stages.2.blocks.1.norm2'),
    ('base_model.layer-104', 'conv4_block2_2_conv', 'encoder.stages.2.blocks.1.conv2'),
    ('base_model.layer-105', 'conv4_block2_2_bn', 'encoder.stages.2.blocks.1.norm3'),
    ('base_model.layer-107', 'conv4_block2_3_conv', 'encoder.stages.2.blocks.1.conv3'),
    ('base_model.layer-109', 'conv4_block3_preact_bn', 'encoder.stages.2.blocks.2.norm1'),
    ('base_model.layer-111', 'conv4_block3_1_conv', 'encoder.stages.2.blocks.2.conv1'),
    ('base_model.layer-112', 'conv4_block3_1_bn', 'encoder.stages.2.blocks.2.norm2'),
    ('base_model.layer-115', 'conv4_block3_2_conv', 'encoder.stages.2.blocks.2.conv2'),
    ('base_model.layer-116', 'conv4_block3_2_bn', 'encoder.stages.2.blocks.2.norm3'),
    ('base_model.layer-118', 'conv4_block3_3_conv', 'encoder.stages.2.blocks.2.conv3'),
    # Temporal convolution
    ('temporal_conv_layers.0', 'conv3d', 'temporal_conv.0'),
    ('temporal_bn_layers.0', 'batch_normalization', 'temporal_conv.1'),
    ('conv_3x3_layer', 'conv2d', 'tsm_conv.0'),
    # Period length head
    ('input_projection', 'dense', 'period_length_head.0.input_projection'),
    ('pos_encoding', None, 'period_length_head.0.pos_encoding'),
    ('transformer_layers.0.ffn.layer-0', None, 'period_length_head.0.transformer_layer.linear1'),
    ('transformer_layers.0.ffn.layer-1', None, 'period_length_head.0.transformer_layer.linear2'),
    ('transformer_layers.0.layernorm1', None, 'period_length_head.0.transformer_layer.norm1'),
    ('transformer_layers.0.layernorm2', None, 'period_length_head.0.transformer_layer.norm2'),
    ('transformer_layers.0.mha.w_weight', None, 'period_length_head.0.transformer_layer.self_attn.in_proj_weight'),
    ('transformer_layers.0.mha.w_bias', None, 'period_length_head.0.transformer_layer.self_attn.in_proj_bias'),
    ('transformer_layers.0.mha.dense', None, 'period_length_head.0.transformer_layer.self_attn.out_proj'),
    ('fc_layers.0', 'dense_14', 'period_length_head.1'),
    ('fc_layers.1', 'dense_15', 'period_length_head.3'),
    ('fc_layers.2', 'dense_16', 'period_length_head.5'),
    # Periodicity head
    ('input_projection2', 'dense_1', 'periodicity_head.0.input_projection'),
    ('pos_encoding2', None, 'periodicity_head.0.pos_encoding'),
    ('transformer_layers2.0.ffn.layer-0', None, 'periodicity_head.0.transformer_layer.linear1'),
    ('transformer_layers2.0.ffn.layer-1', None, 'periodicity_head.0.transformer_layer.linear2'),
    ('transformer_layers2.0.layernorm1', None, 'periodicity_head.0.transformer_layer.norm1'),
    ('transformer_layers2.0.layernorm2', None, 'periodicity_head.0.transformer_layer.norm2'),
    ('transformer_layers2.0.mha.w_weight', None, 'periodicity_head.0.transformer_layer.self_attn.in_proj_weight'),
    ('transformer_layers2.0.mha.w_bias', None, 'periodicity_head.0.transformer_layer.self_attn.in_proj_bias'),
    ('transformer_layers2.0.mha.dense', None, 'periodicity_head.0.transformer_layer.self_attn.out_proj'),
    ('within_period_fc_layers.0', 'dense_17', 'periodicity_head.1'),
    ('within_period_fc_layers.1', 'dense_18', 'periodicity_head.3'),
    ('within_period_fc_layers.2', 'dense_19', 'periodicity_head.5'),
]

# Script arguments
parser = argparse.ArgumentParser(description='Download and convert the pre-trained weights from tensorflow to pytorch.')


if __name__ == '__main__':
    # There are no options yet; parsing still provides `--help` and rejects unknown arguments.
    parser.parse_args()

    # Download tensorflow checkpoints (files already on disk are reused as-is)
    print('Downloading checkpoints...')
    tf_checkpoint_dir = os.path.join(OUT_CHECKPOINTS_DIR, 'tf_checkpoint')
    os.makedirs(tf_checkpoint_dir, exist_ok=True)
    for filename in TF_CHECKPOINT_FILES:
        dst = os.path.join(tf_checkpoint_dir, filename)
        # NOTE(review): if utils.download_file writes directly to `dst`, an interrupted download may leave
        # a partial file that this existence check will reuse; delete it manually if loading fails.
        if not os.path.exists(dst):
            utils.download_file(f'{TF_CHECKPOINT_BASE_URL}/{filename}', dst)

    # Load tensorflow weights into a nested dictionary keyed by the checkpoint path components
    print('Loading tensorflow checkpoint...')
    checkpoint_path = os.path.join(tf_checkpoint_dir, 'ckpt-88')
    checkpoint_reader = py_checkpoint_reader.NewCheckpointReader(checkpoint_path)
    shape_map = checkpoint_reader.get_variable_to_shape_map()
    tf_state_dict = {}
    for var_name in sorted(shape_map.keys()):
        # Skip variables that are not part of the model (e.g. optimizer slots) *before* reading them,
        # so their tensors are never loaded from disk.
        if not var_name.startswith('model') or '.OPTIMIZER_SLOT' in var_name:
            continue
        var_tensor = checkpoint_reader.get_tensor(var_name)
        # Split var_name into a path, dropping the leading `model` key and the checkpoint bookkeeping suffixes
        var_path = var_name.split('/')[1:]
        var_path = [p for p in var_path if p not in ['.ATTRIBUTES', 'VARIABLE_VALUE']]
        # Map weights into a nested dictionary
        current_dict = tf_state_dict
        for path in var_path[:-1]:
            current_dict = current_dict.setdefault(path, {})
        current_dict[var_path[-1]] = var_tensor

    # Merge the separate q/k/v projections of each transformer's self-attention into the single packed
    # tensors expected by the PyTorch `self_attn.in_proj_weight` / `in_proj_bias` targets (q, k, v stacked
    # along dim 0). Kernels are transposed from TF (in, out) to PyTorch (out, in); biases are 1-D.
    for k in ['transformer_layers', 'transformer_layers2']:
        mha = tf_state_dict[k]['0']['mha']
        mha['w_weight'] = np.concatenate([mha[proj]['kernel'].T for proj in ('wq', 'wk', 'wv')], axis=0)
        mha['w_bias'] = np.concatenate([mha[proj]['bias'] for proj in ('wq', 'wk', 'wv')], axis=0)
        del mha['wq'], mha['wk'], mha['wv']

    # NOTE(review): assumes utils.flatten_dict(keep_last=True) joins the outer keys with '.' (matching the
    # WEIGHTS_MAPPING tf keys) and keeps the innermost {attr: array} level as a dict — TODO confirm.
    tf_state_dict = utils.flatten_dict(tf_state_dict, keep_last=True)
    # Add a missing final level for entries stored directly as arrays (e.g. the merged `w_weight`/`w_bias`),
    # so every entry has the same {attr: array} shape; the `None` attribute is dropped when flattening below.
    tf_state_dict = {k: v if isinstance(v, dict) else {None: v} for k, v in tf_state_dict.items()}

    # Convert to a format compatible with PyTorch and save
    print('Converting to PyTorch format...')
    pt_checkpoint_path = os.path.join(OUT_CHECKPOINTS_DIR, 'pytorch_weights.pth')
    pt_state_dict = {}
    for k_tf, _, k_pt in WEIGHTS_MAPPING:
        # Explicit check (not `assert`) so it still runs under `python -O`
        if k_pt in pt_state_dict:
            raise ValueError(f'Duplicate PyTorch key in WEIGHTS_MAPPING: {k_pt}')
        pt_state_dict[k_pt] = {}
        for attr, value in tf_state_dict[k_tf].items():
            tensor = torch.from_numpy(value)
            if attr == 'kernel':
                # Permute kernels into PyTorch's layout; `.contiguous()` so the saved tensor is a plain
                # standard-layout tensor rather than a permuted view.
                tensor = tensor.permute(WEIGHTS_PERMUTATION[tensor.ndim]).contiguous()
            pt_state_dict[k_pt][ATTR_MAPPING.get(attr, attr)] = tensor
    pt_state_dict = utils.flatten_dict(pt_state_dict, skip_none=True)
    torch.save(pt_state_dict, pt_checkpoint_path)

    # Initialize the model and try to load the weights (load_state_dict is strict by default)
    print('Check that the weights can be loaded into the model...')
    model = RepNet()
    pt_state_dict = torch.load(pt_checkpoint_path)
    model.load_state_dict(pt_state_dict)
    print(f'Done. PyTorch weights saved to {pt_checkpoint_path}.')