from typing import Dict
import numpy as np
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
import tensorflow as tf
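
# This converter maps a TensorFlow U-Net checkpoint (one U-Net per separated stem)
# onto a PyTorch-style state dict keyed "stems.<stem>.encoder_layers.<i>.*" and
# "stems.<stem>.decoder_layers.<i>.*", transposing kernels into PyTorch layout.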


def parse_int_or_default(s: str, default: int = 0) -> int:
    try:
        return int(s)
    except ValueError:
        return default


def tf2pytorch(checkpoint_path: str) -> Dict:
    init_vars = tf.train.list_variables(checkpoint_path)
    tf_vars = {}
    for name, shape in init_vars:
        try:
            # print('Loading TF Weight {} with shape {}'.format(name, shape))
            data = tf.train.load_variable(checkpoint_path, name)
            tf_vars[name] = data
        except Exception:
            print(f"Load error: {name}")
            raise
    # Infer the number of decoder (transposed-conv) layers from the numeric
    # suffixes of the "conv2d_transpose[_<i>]" variables; each U-Net contributes
    # n_layers_per_unet of them, and there is one U-Net per stem.
    layer_idxs = {
        parse_int_or_default(name.split("/")[0].split("_")[-1], default=0)
        for name in tf_vars
        if "conv2d_transpose" in name
    }
    n_layers_per_unet = 6
    n_layers_in_chkpt = max(layer_idxs) + 1
    assert n_layers_in_chkpt % n_layers_per_unet == 0, (
        f"expected a multiple of {n_layers_per_unet} layers, "
        f"i.e. {n_layers_per_unet} layers per U-Net and one U-Net per stem"
    )
    n_stems = n_layers_in_chkpt // n_layers_per_unet
    stem_names = {
        2: ["vocals", "accompaniment"],
        4: ["vocals", "drums", "bass", "other"],
        5: ["vocals", "piano", "drums", "bass", "other"],
    }.get(n_stems, [])
    assert stem_names, f"Unsupported stem count: {n_stems}"
    state_dict = {}
    tf_idx_conv = 0   # running index into TF's "conv2d[_<i>]" variable names
    tf_idx_tconv = 0  # running index into "conv2d_transpose[_<i>]"
    tf_idx_bn = 0     # running index into "batch_normalization[_<i>]"
    for stem_name in stem_names:
        # Encoder blocks (down-sampling)
        for layer_idx in range(n_layers_per_unet):
            prefix = f"stems.{stem_name}.encoder_layers.{layer_idx}"
            # TF names the first variable "conv2d" and later ones "conv2d_<i>";
            # the batch-norm variables follow the same convention.
            conv_suffix = "" if tf_idx_conv == 0 else f"_{tf_idx_conv}"
            bn_suffix = "" if tf_idx_bn == 0 else f"_{tf_idx_bn}"
            # TF Conv2D kernels are (H, W, in, out); PyTorch Conv2d expects (out, in, H, W).
            state_dict[f"{prefix}.conv.weight"] = np.transpose(
                tf_vars[f"conv2d{conv_suffix}/kernel"], (3, 2, 0, 1)
            )
            state_dict[f"{prefix}.conv.bias"] = tf_vars[f"conv2d{conv_suffix}/bias"]
            tf_idx_conv += 1
            state_dict[f"{prefix}.bn.weight"] = tf_vars[
                f"batch_normalization{bn_suffix}/gamma"
            ]
            state_dict[f"{prefix}.bn.bias"] = tf_vars[
                f"batch_normalization{bn_suffix}/beta"
            ]
            state_dict[f"{prefix}.bn.running_mean"] = tf_vars[
                f"batch_normalization{bn_suffix}/moving_mean"
            ]
            state_dict[f"{prefix}.bn.running_var"] = tf_vars[
                f"batch_normalization{bn_suffix}/moving_variance"
            ]
            tf_idx_bn += 1
        # Decoder blocks (up-sampling)
        for layer_idx in range(n_layers_per_unet):
            prefix = f"stems.{stem_name}.decoder_layers.{layer_idx}"
            tconv_suffix = "" if tf_idx_tconv == 0 else f"_{tf_idx_tconv}"
            # The batch-norm index is always > 0 by now, so the suffix is never empty.
            bn_suffix = f"_{tf_idx_bn}"
            # TF Conv2DTranspose kernels are (H, W, out, in); PyTorch ConvTranspose2d
            # expects (in, out, H, W), so the same (3, 2, 0, 1) transpose applies.
            state_dict[f"{prefix}.tconv.weight"] = np.transpose(
                tf_vars[f"conv2d_transpose{tconv_suffix}/kernel"], (3, 2, 0, 1)
            )
            state_dict[f"{prefix}.tconv.bias"] = tf_vars[
                f"conv2d_transpose{tconv_suffix}/bias"
            ]
            tf_idx_tconv += 1
            state_dict[f"{prefix}.bn.weight"] = tf_vars[
                f"batch_normalization{bn_suffix}/gamma"
            ]
            state_dict[f"{prefix}.bn.bias"] = tf_vars[
                f"batch_normalization{bn_suffix}/beta"
            ]
            state_dict[f"{prefix}.bn.running_mean"] = tf_vars[
                f"batch_normalization{bn_suffix}/moving_mean"
            ]
            state_dict[f"{prefix}.bn.running_var"] = tf_vars[
                f"batch_normalization{bn_suffix}/moving_variance"
            ]
            tf_idx_bn += 1
        # Final output convolution for this stem
        state_dict[f"stems.{stem_name}.up_final.weight"] = np.transpose(
            tf_vars[f"conv2d_{tf_idx_conv}/kernel"], (3, 2, 0, 1)
        )
        state_dict[f"stems.{stem_name}.up_final.bias"] = tf_vars[
            f"conv2d_{tf_idx_conv}/bias"
        ]
        tf_idx_conv += 1

    return state_dict
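

# Minimal usage sketch (an assumption, not part of the original module): convert a
# checkpoint from the command line and save the result with torch.save. The argument
# handling and the torch dependency below are illustrative only.
if __name__ == "__main__":
    import sys

    import torch

    ckpt_path, out_path = sys.argv[1], sys.argv[2]
    sd = tf2pytorch(ckpt_path)
    # Module.load_state_dict expects tensors, so convert the numpy arrays before saving.
    torch.save({k: torch.from_numpy(np.asarray(v)) for k, v in sd.items()}, out_path)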