"""Convert a Spleeter TensorFlow checkpoint into a PyTorch state dict."""

import os
from typing import Dict

import numpy as np

# Must be set before TensorFlow is imported (see lazy import in tf2pytorch)
# to silence TF's C++ logging.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"


def parse_int_or_default(s: str, default: int = 0) -> int:
    """Parse ``s`` as an int, returning ``default`` when parsing fails.

    Used to pull the trailing numeric suffix out of TF variable names
    such as ``conv2d_transpose_3`` (the unsuffixed first variable parses
    as the default 0).
    """
    try:
        return int(s)
    except (ValueError, TypeError):  # narrowed from a bare except: these are what int() raises
        return default


def tf2pytorch(checkpoint_path: str) -> Dict:
    """Load a TF checkpoint and remap its variables to PyTorch state-dict keys.

    Parameters
    ----------
    checkpoint_path:
        Path accepted by ``tf.train.list_variables`` /
        ``tf.train.load_variable``.

    Returns
    -------
    dict
        Mapping from PyTorch parameter names
        (``stems.<stem>.encoder_layers.<i>...``, ``...decoder_layers...``,
        ``...up_final...``) to numpy arrays. Conv kernels are transposed
        from TF's (H, W, in, out) layout to PyTorch's (out, in, H, W).

    Raises
    ------
    AssertionError
        If the checkpoint's transpose-conv count is not a whole number of
        6-layer U-Nets, or the implied stem count is not 2, 4, or 5.
    """
    # Imported lazily so merely importing this module does not pull in
    # TensorFlow; TF_CPP_MIN_LOG_LEVEL is already set at module import time.
    import tensorflow as tf

    init_vars = tf.train.list_variables(checkpoint_path)
    tf_vars = {}
    for name, _ in init_vars:
        try:
            tf_vars[name] = tf.train.load_variable(checkpoint_path, name)
        except Exception:
            # Identify which variable failed before propagating.
            print(f"Load error: {name}")
            raise

    # Count U-Nets via the global numbering TF gives transpose-conv layers
    # ("conv2d_transpose", "conv2d_transpose_1", ...): the unsuffixed first
    # layer parses as 0 through parse_int_or_default.
    layer_idxs = {
        parse_int_or_default(name.split("/")[0].split("_")[-1], default=0)
        for name in tf_vars
        if "conv2d_transpose" in name
    }
    n_layers_per_unet = 6
    n_layers_in_chkpt = max(layer_idxs) + 1
    assert n_layers_in_chkpt % n_layers_per_unet == 0, (
        f"expected multiple of {n_layers_per_unet}..."
        f" ie: {n_layers_per_unet} layers per unet & 1 unet per stem"
    )
    n_stems = n_layers_in_chkpt // n_layers_per_unet
    stem_names = {
        2: ["vocals", "accompaniment"],
        4: ["vocals", "drums", "bass", "other"],
        5: ["vocals", "piano", "drums", "bass", "other"],
    }.get(n_stems, [])
    assert stem_names, f"Unsupported stem count: {n_stems}"

    state_dict = {}
    # TF numbers conv/bn variables globally across ALL stems, so these
    # indices keep running across the outer loop.
    tf_idx_conv = 0
    tf_idx_tconv = 0
    tf_idx_bn = 0
    for stem_name in stem_names:
        # Encoder blocks (down-sampling).
        for layer_idx in range(n_layers_per_unet):
            prefix = f"stems.{stem_name}.encoder_layers.{layer_idx}"
            # The very first variable of each kind carries no "_<n>" suffix.
            conv_suffix = "" if tf_idx_conv == 0 else f"_{tf_idx_conv}"
            bn_suffix = "" if tf_idx_bn == 0 else f"_{tf_idx_bn}"
            # (H, W, in, out) -> (out, in, H, W).
            state_dict[f"{prefix}.conv.weight"] = np.transpose(
                tf_vars[f"conv2d{conv_suffix}/kernel"], (3, 2, 0, 1)
            )
            state_dict[f"{prefix}.conv.bias"] = tf_vars[f"conv2d{conv_suffix}/bias"]
            tf_idx_conv += 1
            state_dict[f"{prefix}.bn.weight"] = tf_vars[
                f"batch_normalization{bn_suffix}/gamma"
            ]
            state_dict[f"{prefix}.bn.bias"] = tf_vars[
                f"batch_normalization{bn_suffix}/beta"
            ]
            state_dict[f"{prefix}.bn.running_mean"] = tf_vars[
                f"batch_normalization{bn_suffix}/moving_mean"
            ]
            state_dict[f"{prefix}.bn.running_var"] = tf_vars[
                f"batch_normalization{bn_suffix}/moving_variance"
            ]
            tf_idx_bn += 1

        # Decoder blocks (up-sampling).
        for layer_idx in range(n_layers_per_unet):
            prefix = f"stems.{stem_name}.decoder_layers.{layer_idx}"
            tconv_suffix = "" if tf_idx_tconv == 0 else f"_{tf_idx_tconv}"
            # tf_idx_bn > 0 here (the encoder loop already ran), so the
            # suffix is always "_<n>".
            bn_suffix = f"_{tf_idx_bn}"
            state_dict[f"{prefix}.tconv.weight"] = np.transpose(
                tf_vars[f"conv2d_transpose{tconv_suffix}/kernel"], (3, 2, 0, 1)
            )
            state_dict[f"{prefix}.tconv.bias"] = tf_vars[
                f"conv2d_transpose{tconv_suffix}/bias"
            ]
            tf_idx_tconv += 1
            state_dict[f"{prefix}.bn.weight"] = tf_vars[
                f"batch_normalization{bn_suffix}/gamma"
            ]
            state_dict[f"{prefix}.bn.bias"] = tf_vars[
                f"batch_normalization{bn_suffix}/beta"
            ]
            state_dict[f"{prefix}.bn.running_mean"] = tf_vars[
                f"batch_normalization{bn_suffix}/moving_mean"
            ]
            state_dict[f"{prefix}.bn.running_var"] = tf_vars[
                f"batch_normalization{bn_suffix}/moving_variance"
            ]
            tf_idx_bn += 1

        # Final conv2d of this stem.
        state_dict[f"stems.{stem_name}.up_final.weight"] = np.transpose(
            tf_vars[f"conv2d_{tf_idx_conv}/kernel"], (3, 2, 0, 1)
        )
        state_dict[f"stems.{stem_name}.up_final.bias"] = tf_vars[
            f"conv2d_{tf_idx_conv}/bias"
        ]
        tf_idx_conv += 1

    return state_dict