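"""Convert a Spleeter-style TensorFlow U-Net checkpoint into a PyTorch state dict.

TensorFlow numbers variables globally ("conv2d", "conv2d_1",
"batch_normalization_3", ...), so the converter walks each stem's encoder and
decoder in order, keeping running counters to map every TF variable onto the
matching ``stems.<name>.*`` key expected by the PyTorch model.
"""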
import os
from typing import Dict

import numpy as np

# Silence TensorFlow's C++ logging; this must be set before importing tensorflow.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
import tensorflow as tf


def parse_int_or_default(s: str, default: int = 0) -> int:
    """Parse ``s`` as an int, falling back to ``default`` if it is not numeric."""
    try:
        return int(s)
    except ValueError:
        return default


def tf2pytorch(checkpoint_path: str) -> Dict[str, np.ndarray]:
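    """Read every variable from the TF checkpoint at ``checkpoint_path`` and
    return a PyTorch-style state dict of numpy arrays, keyed by
    ``stems.<stem_name>.<layer>`` parameter names.
    """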
    init_vars = tf.train.list_variables(checkpoint_path)

    tf_vars = {}
    for name, _ in init_vars:
        try:
            tf_vars[name] = tf.train.load_variable(checkpoint_path, name)
        except Exception:
            print(f"Failed to load TF variable: {name}")
            raise

    # Transposed convolutions occur once per decoder layer, so the highest
    # numeric suffix reveals how many decoder layers (and hence how many
    # stems) the checkpoint holds.
    layer_idxs = {
        parse_int_or_default(name.split("/")[0].split("_")[-1], default=0)
        for name in tf_vars
        if "conv2d_transpose" in name
    }

    n_layers_per_unet = 6
    n_layers_in_chkpt = max(layer_idxs) + 1
    assert n_layers_in_chkpt % n_layers_per_unet == 0, (
        f"expected a multiple of {n_layers_per_unet} transposed convolutions "
        f"({n_layers_per_unet} per U-Net, one U-Net per stem), got {n_layers_in_chkpt}"
    )
    n_stems = n_layers_in_chkpt // n_layers_per_unet

    stem_names = {
        2: ["vocals", "accompaniment"],
        4: ["vocals", "drums", "bass", "other"],
        5: ["vocals", "piano", "drums", "bass", "other"],
    }.get(n_stems, [])

    assert stem_names, f"Unsupported stem count: {n_stems}"

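    # TF numbers variables across the whole graph, so the indices below run
    # continuously over all stems instead of resetting for each U-Net.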
    state_dict = {}
    tf_idx_conv = 0
    tf_idx_tconv = 0
    tf_idx_bn = 0

    for stem_name in stem_names:
        # Encoder blocks (downsampling)
        for layer_idx in range(n_layers_per_unet):
            prefix = f"stems.{stem_name}.encoder_layers.{layer_idx}"
            conv_suffix = "" if tf_idx_conv == 0 else f"_{tf_idx_conv}"
            bn_suffix = "" if tf_idx_bn == 0 else f"_{tf_idx_bn}"

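            # TF Conv2D kernels are laid out HWIO (h, w, in, out); PyTorch's
            # Conv2d expects OIHW, hence the (3, 2, 0, 1) transpose.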
            state_dict[f"{prefix}.conv.weight"] = np.transpose(
                tf_vars[f"conv2d{conv_suffix}/kernel"], (3, 2, 0, 1)
            )
            state_dict[f"{prefix}.conv.bias"] = tf_vars[f"conv2d{conv_suffix}/bias"]
            tf_idx_conv += 1

            state_dict[f"{prefix}.bn.weight"] = tf_vars[
                f"batch_normalization{bn_suffix}/gamma"
            ]
            state_dict[f"{prefix}.bn.bias"] = tf_vars[
                f"batch_normalization{bn_suffix}/beta"
            ]
            state_dict[f"{prefix}.bn.running_mean"] = tf_vars[
                f"batch_normalization{bn_suffix}/moving_mean"
            ]
            state_dict[f"{prefix}.bn.running_var"] = tf_vars[
                f"batch_normalization{bn_suffix}/moving_variance"
            ]
            tf_idx_bn += 1

        # Decoder blocks (upsampling)
        for layer_idx in range(n_layers_per_unet):
            prefix = f"stems.{stem_name}.decoder_layers.{layer_idx}"
            tconv_suffix = "" if tf_idx_tconv == 0 else f"_{tf_idx_tconv}"
            # tf_idx_bn is already > 0 after the encoder pass, so the
            # batch-norm suffix is always present here.
            bn_suffix = f"_{tf_idx_bn}"

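            # TF Conv2DTranspose kernels are HWOI (h, w, out, in); the same
            # (3, 2, 0, 1) transpose maps them onto PyTorch's ConvTranspose2d
            # layout of (in, out, h, w).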
            state_dict[f"{prefix}.tconv.weight"] = np.transpose(
                tf_vars[f"conv2d_transpose{tconv_suffix}/kernel"], (3, 2, 0, 1)
            )
            state_dict[f"{prefix}.tconv.bias"] = tf_vars[
                f"conv2d_transpose{tconv_suffix}/bias"
            ]
            tf_idx_tconv += 1

            state_dict[f"{prefix}.bn.weight"] = tf_vars[
                f"batch_normalization{bn_suffix}/gamma"
            ]
            state_dict[f"{prefix}.bn.bias"] = tf_vars[
                f"batch_normalization{bn_suffix}/beta"
            ]
            state_dict[f"{prefix}.bn.running_mean"] = tf_vars[
                f"batch_normalization{bn_suffix}/moving_mean"
            ]
            state_dict[f"{prefix}.bn.running_var"] = tf_vars[
                f"batch_normalization{bn_suffix}/moving_variance"
            ]
            tf_idx_bn += 1

        # Final conv2d
        state_dict[f"stems.{stem_name}.up_final.weight"] = np.transpose(
            tf_vars[f"conv2d_{tf_idx_conv}/kernel"], (3, 2, 0, 1)
        )
        state_dict[f"stems.{stem_name}.up_final.bias"] = tf_vars[
            f"conv2d_{tf_idx_conv}/bias"
        ]
        tf_idx_conv += 1

    return state_dict
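

if __name__ == "__main__":
    # Minimal usage sketch, not part of the converter itself: the checkpoint
    # path and output filename below are assumptions; adjust them to your setup.
    import sys

    import torch

    ckpt_path = sys.argv[1] if len(sys.argv) > 1 else "checkpoints/4stems/model"
    weights = tf2pytorch(ckpt_path)
    # np.transpose returns non-contiguous views; make them contiguous tensors
    # so the saved file loads cleanly with load_state_dict.
    tensors = {k: torch.from_numpy(np.ascontiguousarray(v)) for k, v in weights.items()}
    torch.save(tensors, "converted_model.pth")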