pengdaqian committed
Commit 3064720 · Parent: 62e9d65
Files changed (2):
  1. app.py +0 -4
  2. torchspleeter/tf2pytorch.py +0 -122
app.py CHANGED
@@ -1,9 +1,5 @@
 import os
 import sys
-import threading
-
-os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'
-
 from music.search import get_youtube, download_random
 from vits.models import SynthesizerInfer
 import whisper.inference
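Context note, not part of the commit: the removed TF_FORCE_GPU_ALLOW_GROWTH variable tells TensorFlow to allocate GPU memory on demand instead of reserving all of it at startup. A minimal sketch of the equivalent in-code setting under the TF 2.x API, assuming a GPU-enabled TensorFlow build:

```python
import tensorflow as tf

# In-code equivalent of setting TF_FORCE_GPU_ALLOW_GROWTH=true
# before TensorFlow initializes: grow GPU memory on demand.
for gpu in tf.config.list_physical_devices("GPU"):
    tf.config.experimental.set_memory_growth(gpu, True)
```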
torchspleeter/tf2pytorch.py DELETED
@@ -1,122 +0,0 @@
-from typing import Dict
-import numpy as np
-
-import os
-
-os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
-import tensorflow as tf
-
-
-def parse_int_or_default(s: str, default: int = 0) -> int:
-    try:
-        return int(s)
-    except:
-        return default
-
-
-def tf2pytorch(checkpoint_path: str) -> Dict:
-    init_vars = tf.train.list_variables(checkpoint_path)
-
-    tf_vars = {}
-    for name, _ in init_vars:
-        try:
-            # print('Loading TF Weight {} with shape {}'.format(name, shape))
-            data = tf.train.load_variable(checkpoint_path, name)
-            tf_vars[name] = data
-        except Exception as e:
-            print(f"Load error: {name}")
-            raise
-
-    layer_idxs = set(
-        [
-            parse_int_or_default(name.split("/")[0].split("_")[-1], default=0)
-            for name in tf_vars.keys()
-            if "conv2d_transpose" in name
-        ]
-    )
-
-    n_layers_per_unet = 6
-    n_layers_in_chkpt = max(layer_idxs) + 1
-    assert (
-        n_layers_in_chkpt % 6 == 0
-    ), f"expected multiple of {n_layers_per_unet}... ie: {n_layers_per_unet} layers per unet & 1 unet per stem"
-    n_stems = n_layers_in_chkpt // n_layers_per_unet
-
-    stem_names = {
-        2: ["vocals", "accompaniment"],
-        4: ["vocals", "drums", "bass", "other"],
-        5: ["vocals", "piano", "drums", "bass", "other"],
-    }.get(n_stems, [])
-
-    assert stem_names, f"Unsupported stem count: {n_stems}"
-
-    state_dict = {}
-    tf_idx_conv = 0
-    tf_idx_tconv = 0
-    tf_idx_bn = 0
-
-    for stem_name in stem_names:
-        # Encoder Blocks (Down sampling)
-        for layer_idx in range(n_layers_per_unet):
-            prefix = f"stems.{stem_name}.encoder_layers.{layer_idx}"
-            conv_suffix = "" if tf_idx_conv == 0 else f"_{tf_idx_conv}"
-            bn_suffix = "" if tf_idx_bn == 0 else f"_{tf_idx_bn}"
-
-            state_dict[f"{prefix}.conv.weight"] = np.transpose(
-                tf_vars[f"conv2d{conv_suffix}/kernel"], (3, 2, 0, 1)
-            )
-            state_dict[f"{prefix}.conv.bias"] = tf_vars[f"conv2d{conv_suffix}/bias"]
-            tf_idx_conv += 1
-
-            state_dict[f"{prefix}.bn.weight"] = tf_vars[
-                f"batch_normalization{bn_suffix}/gamma"
-            ]
-            state_dict[f"{prefix}.bn.bias"] = tf_vars[
-                f"batch_normalization{bn_suffix}/beta"
-            ]
-            state_dict[f"{prefix}.bn.running_mean"] = tf_vars[
-                f"batch_normalization{bn_suffix}/moving_mean"
-            ]
-            state_dict[f"{prefix}.bn.running_var"] = tf_vars[
-                f"batch_normalization{bn_suffix}/moving_variance"
-            ]
-            tf_idx_bn += 1
-
-        # Decoder Blocks (Up sampling)
-        for layer_idx in range(n_layers_per_unet):
-            prefix = f"stems.{stem_name}.decoder_layers.{layer_idx}"
-            tconv_suffix = "" if tf_idx_tconv == 0 else f"_{tf_idx_tconv}"
-            bn_suffix = f"_{tf_idx_bn}"
-
-            state_dict[f"{prefix}.tconv.weight"] = np.transpose(
-                tf_vars[f"conv2d_transpose{tconv_suffix}/kernel"], (3, 2, 0, 1)
-            )
-            state_dict[f"{prefix}.tconv.bias"] = tf_vars[
-                f"conv2d_transpose{tconv_suffix}/bias"
-            ]
-            tf_idx_tconv += 1
-
-            state_dict[f"{prefix}.bn.weight"] = tf_vars[
-                f"batch_normalization{bn_suffix}/gamma"
-            ]
-            state_dict[f"{prefix}.bn.bias"] = tf_vars[
-                f"batch_normalization{bn_suffix}/beta"
-            ]
-            state_dict[f"{prefix}.bn.running_mean"] = tf_vars[
-                f"batch_normalization{bn_suffix}/moving_mean"
-            ]
-            state_dict[f"{prefix}.bn.running_var"] = tf_vars[
-                f"batch_normalization{bn_suffix}/moving_variance"
-            ]
-            tf_idx_bn += 1
-
-        # Final conv2d
-        state_dict[f"stems.{stem_name}.up_final.weight"] = np.transpose(
-            tf_vars[f"conv2d_{tf_idx_conv}/kernel"], (3, 2, 0, 1)
-        )
-        state_dict[f"stems.{stem_name}.up_final.bias"] = tf_vars[
-            f"conv2d_{tf_idx_conv}/bias"
-        ]
-        tf_idx_conv += 1
-
-    return state_dict
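For anyone tracing what the deleted converter did: TensorFlow stores conv2d kernels as (H, W, C_in, C_out), while PyTorch nn.Conv2d weights are (C_out, C_in, H, W); the (3, 2, 0, 1) transpose above performs exactly that axis reordering. A minimal standalone sketch with a hypothetical kernel shape, not repo code:

```python
import numpy as np

# TF conv2d kernel layout: (H, W, C_in, C_out)
tf_kernel = np.zeros((5, 5, 1, 16))  # hypothetical 5x5 kernel, 1 -> 16 channels

# PyTorch nn.Conv2d weight layout: (C_out, C_in, H, W)
pt_weight = np.transpose(tf_kernel, (3, 2, 0, 1))
assert pt_weight.shape == (16, 1, 5, 5)
```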