Yongyi Zang committed
Commit 7872d8f · 1 Parent(s): 09cf0e5

Init Commit

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+checkpoints/*.pt filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,96 @@
+#!/usr/bin/env python
+import os
+import torch
+import numpy as np
+import soundfile as sf
+import gradio as gr
+from model import UFormer, UFormerConfig
+
+# ——————————————————————
+# 1) Setup & model loading from local checkpoints
+# ——————————————————————
+DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+CHECKPOINT_DIR = "checkpoints"
+config = UFormerConfig()
+_model_cache = {}
+
+VALID_CKPTS = [
+    "acoustic_guitar", "bass", "electric_guitar", "guitars", "keyboards",
+    "orchestra", "rhythm_section", "synth", "vocals"
+]
+
+def _get_model(ckpt_name: str):
+    if ckpt_name not in VALID_CKPTS:
+        raise ValueError(f"Invalid checkpoint {ckpt_name!r}, choose from {VALID_CKPTS}")
+    if ckpt_name in _model_cache:
+        return _model_cache[ckpt_name]
+    ckpt_path = os.path.join(CHECKPOINT_DIR, f"{ckpt_name}.pth")  # checkpoints are shipped as .pth files
+    model = UFormer(config).to(DEVICE).eval()
+    state = torch.load(ckpt_path, map_location=DEVICE)
+    model.load_state_dict(state)
+    _model_cache[ckpt_name] = model
+    return model
+
+# ——————————————————————
+# 2) Overlap-add for long audio
+# ——————————————————————
+def _overlap_add(model, x: np.ndarray, sr: int, chunk_s: float = 5., hop_s: float = 2.5):
+    C, T = x.shape
+    chunk, hop = int(sr * chunk_s), int(sr * hop_s)
+    pad = (-(T - chunk) % hop) if T > chunk else 0
+    x_pad = np.pad(x, ((0, 0), (0, pad)), mode="reflect")
+    win = np.hanning(chunk)[None, :]
+    out = np.zeros_like(x_pad)
+    norm = np.zeros((1, x_pad.shape[1]))
+    n_chunks = 1 + (x_pad.shape[1] - chunk) // hop
+
+    for i in range(n_chunks):
+        s = i * hop
+        seg = x_pad[:, s:s+chunk].astype(np.float32)
+        with torch.no_grad():
+            y = model(torch.from_numpy(seg[None]).to(DEVICE)).squeeze(0).cpu().numpy()
+        out[:, s:s+chunk] += y * win
+        norm[:, s:s+chunk] += win
+
+    return (out / np.maximum(norm, 1e-8))[:, :T]  # guard against zeros of the Hann window at chunk edges
+
+# ——————————————————————
+# 3) Restore function for Gradio
+# ——————————————————————
+def restore_fn(audio_path, checkpoint):
+    audio, sr = sf.read(audio_path)
+    if audio.ndim == 1:
+        audio = np.stack([audio, audio], axis=1)
+    x = audio.T  # (C, T)
+
+    model = _get_model(checkpoint)
+    if x.shape[1] <= sr * 5:
+        seg = x.astype(np.float32)[None]
+        with torch.no_grad():
+            y = model(torch.from_numpy(seg).to(DEVICE)).squeeze(0).cpu().numpy()
+    else:
+        y = _overlap_add(model, x, sr)
+
+    tmp = "restored.wav"
+    sf.write(tmp, y.T, sr, format="WAV")
+    return tmp
+
+# ——————————————————————
+# 4) Gradio App
+# ——————————————————————
+demo = gr.Interface(
+    fn=restore_fn,
+    inputs=[
+        gr.Audio(source="upload", type="filepath", label="Your Input"),
+        gr.Dropdown(VALID_CKPTS, label="Checkpoint")
+    ],
+    outputs=gr.Audio(type="filepath", label="Restored Output"),
+    title="🎵 Music Source Restoration",
+    description="Upload a WAV file and choose an instrument/group checkpoint to restore.",
+    allow_flagging="never"
+)
+
+if __name__ == "__main__":
+    demo.launch()
+else:
+    demo.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))
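For reference, the short-clip branch of restore_fn can be exercised offline without starting the Gradio app. The sketch below assumes the Space is cloned with git-lfs so that checkpoints/vocals.pth (part of this commit) is a real file, and that input.wav is any stereo or mono WAV you supply at 44.1 kHz (input.wav and the output name are placeholders, not part of the commit):

# Offline smoke test of the restoration path (a sketch; see assumptions above).
import numpy as np
import soundfile as sf
import torch
from model import UFormer, UFormerConfig

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UFormer(UFormerConfig()).to(device).eval()
model.load_state_dict(torch.load("checkpoints/vocals.pth", map_location=device))

audio, sr = sf.read("input.wav")               # (T,) mono or (T, C)
if audio.ndim == 1:
    audio = np.stack([audio, audio], axis=1)   # the model expects stereo input
x = torch.from_numpy(audio.T.astype(np.float32))[None].to(device)  # (1, 2, T)

with torch.no_grad():
    y = model(x).squeeze(0).cpu().numpy()      # (2, T)

sf.write("restored_offline.wav", y.T, sr)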
checkpoints/acoustic_guitar.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:43c8060f061fea8d9dd42e7244004cbbbdb5672e353dfb1a8de5dcc2837ff848
+size 57419739
checkpoints/bass.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5abfb0f75f1f10d07f483acca4612494767d676a62035877b81c48d67db7d73f
+size 57419739
checkpoints/electric_guitar.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:615143058e675760c757ac2eab996c8c426d4163f0b64949f112e5cc0c4072e4
+size 57419739
checkpoints/guitars.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6fcf5090367aee602bda4ccc6cb345da127cb7139be669bdd4ad9aad5b025a0d
+size 57419739
checkpoints/keyboards.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:132a79c2d6476a00e818031e097a0555662350510aa9d0733a679d34e3acf2c5
+size 57419739
checkpoints/orchestra.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e2c68f28302f6256008185c98db7b4610a606dce973fbf7f605627b19ef7cbab
+size 57419739
checkpoints/rhythm_section.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5aa8c854981359d3564b720c338d19c614029570586e856b0576124515bf01e2
+size 57419739
checkpoints/synth.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0b54daf5f65b9eeaef7d98efcdfd9b17616f732d1127a4a32c9a0bdd11689c4a
+size 57419739
checkpoints/vocals.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1029b7c5f3fb06969f740a7583ca27a6944a8cd078b2cd5c6169dc512dd7a097
+size 57419739
model.py ADDED
@@ -0,0 +1,550 @@
+from __future__ import annotations
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+import numpy as np
+from dataclasses import dataclass
+
+class Fourier(nn.Module):
+
+    def __init__(self,
+        n_fft=2048,
+        hop_length=441,
+        return_complex=True,
+        normalized=True
+    ):
+        super(Fourier, self).__init__()
+
+        self.n_fft = n_fft
+        self.hop_length = hop_length
+        self.return_complex = return_complex
+        self.normalized = normalized
+
+    def stft(self, waveform):
+        """
+        Args:
+            waveform: (b, c, samples_num)
+
+        Returns:
+            complex_sp: (b, c, t, f)
+        """
+
+        B, C, T = waveform.shape
+
+        x = rearrange(waveform, 'b c t -> (b c) t')
+
+        x = torch.stft(
+            input=x,
+            n_fft=self.n_fft,
+            hop_length=self.hop_length,
+            window=torch.hann_window(self.n_fft).to(x.device),
+            normalized=self.normalized,
+            return_complex=self.return_complex
+        )
+        # shape: (batch_size * channels_num, freq_bins, frames_num)
+
+        complex_sp = rearrange(x, '(b c) f t -> b c t f', b=B, c=C)
+        # shape: (batch_size, channels_num, frames_num, freq_bins)
+
+        return complex_sp
+
+    def istft(self, complex_sp):
+        """
+        Args:
+            complex_sp: (batch_size, channels_num, frames_num, freq_bins)
+
+        Returns:
+            waveform: (batch_size, channels_num, samples_num)
+        """
+
+        B, C, T, F = complex_sp.shape
+
+        x = rearrange(complex_sp, 'b c t f -> (b c) f t')
+
+        x = torch.istft(
+            input=x,
+            n_fft=self.n_fft,
+            hop_length=self.hop_length,
+            window=torch.hann_window(self.n_fft).to(x.device),
+            normalized=self.normalized,
+        )
+        # shape: (batch_size * channels_num, samples_num)
+
+        x = rearrange(x, '(b c) t -> b c t', b=B, c=C)
+        # shape: (batch_size, channels_num, samples_num)
+
+        return x
+
+class Block(nn.Module):
+    def __init__(self, config) -> None:
+        super().__init__()
+        self.att_norm = RMSNorm(config.n_embd)
+        self.att = SelfAttention(config)
+        self.ffn_norm = RMSNorm(config.n_embd)
+        self.mlp = MLP(config)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        rope: torch.Tensor,
+        mask: torch.Tensor,
+    ) -> torch.Tensor:
+        r"""
+
+        Args:
+            x: (b, t, d)
+            rope: (t, head_dim/2, 2)
+            mask: (1, 1, t, t) or None
+
+        Outputs:
+            x: (b, t, d)
+        """
+        x = x + self.att(self.att_norm(x), rope, mask)
+        x = x + self.mlp(self.ffn_norm(x))
+        return x
+
+
+class RMSNorm(nn.Module):
+    r"""Root Mean Square Layer Normalization.
+
+    Ref: https://github.com/meta-llama/llama/blob/main/llama/model.py
+    """
+    def __init__(self, dim: int, eps: float = 1e-6):
+
+        super().__init__()
+        self.eps = eps
+        self.scale = nn.Parameter(torch.ones(dim))
+
+    def forward(self, x):
+        r"""RMSNorm.
+
+        Args:
+            x: (b, t, d)
+
+        Outputs:
+            x: (b, t, d)
+        """
+        norm_x = torch.mean(x ** 2, dim=-1, keepdim=True)
+        output = x * torch.rsqrt(norm_x + self.eps) * self.scale
+        return output
+
+
+class SelfAttention(nn.Module):
+    def __init__(self, config) -> None:
+        super().__init__()
+        assert config.n_embd % config.n_head == 0
+
+        # key, query, value projections for all heads, but in a batch
+        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=False)
+
+        # output projection
+        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
+
+        self.n_head = config.n_head
+        self.n_embd = config.n_embd
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        rope: torch.Tensor,
+        mask: torch.Tensor,
+    ) -> torch.Tensor:
+        r"""Self attention.
+
+        b: batch size
+        t: time steps
+        d: latent dim
+        h: heads num
+
+        Args:
+            x: (b, t, d)
+            rope: (t, head_dim/2, 2)
+            mask: (1, 1, t, t) or None
+
+        Outputs:
+            x: (b, t, d)
+        """
+        B, T, D = x.shape
+
+        # Calculate query, key, values
+        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
+        # q, k, v shapes: (b, t, d)
+
+        k = k.view(B, T, self.n_head, D // self.n_head)
+        q = q.view(B, T, self.n_head, D // self.n_head)
+        v = v.view(B, T, self.n_head, D // self.n_head)
+        # q, k, v shapes: (b, t, h, head_dim)
+
+        q = apply_rope(q, rope)
+        k = apply_rope(k, rope)
+        # q, k shapes: (b, t, h, head_dim)
+
+        k = k.transpose(1, 2)
+        q = q.transpose(1, 2)
+        v = v.transpose(1, 2)
+        # q, k, v shapes: (b, h, t, head_dim)
+
+        # Efficient attention using Flash Attention CUDA kernels
+        x = F.scaled_dot_product_attention(
+            query=q,
+            key=k,
+            value=v,
+            attn_mask=mask,
+            dropout_p=0.0
+        )
+        # shape: (b, h, t, head_dim)
+
+        x = x.transpose(1, 2).contiguous().view(B, T, D)  # shape: (b, t, d)
+
+        # output projection
+        x = self.c_proj(x)  # shape: (b, t, d)
+
+        return x
+
+
+class MLP(nn.Module):
+    def __init__(self, config) -> None:
+        super().__init__()
+
+        # The hyper-parameters follow https://github.com/Lightning-AI/lit-llama/blob/main/lit_llama/model.py
+        hidden_dim = 4 * config.n_embd
+        n_hidden = int(2 * hidden_dim / 3)
+
+        self.c_fc1 = nn.Linear(config.n_embd, n_hidden, bias=False)
+        self.c_fc2 = nn.Linear(config.n_embd, n_hidden, bias=False)
+        self.c_proj = nn.Linear(n_hidden, config.n_embd, bias=False)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        r"""SwiGLU feed-forward network.
+
+        Args:
+            x: (b, t, d)
+
+        Outputs:
+            x: (b, t, d)
+        """
+        x = F.silu(self.c_fc1(x)) * self.c_fc2(x)
+        x = self.c_proj(x)
+        return x
+
+def build_rope(
+    seq_len: int, head_dim: int, base: int = 10000
+) -> torch.Tensor:
+    r"""Rotary Position Embedding.
+    Modified from: https://github.com/Lightning-AI/lit-llama/blob/main/lit_llama/model.py
+
+    Args:
+        seq_len: int, e.g., 1024
+        head_dim: head dim, e.g., 768/24
+        base: int
+
+    Outputs:
+        cache: (t, head_dim/2, 2)
+    """
+
+    theta = 1.0 / (base ** (torch.arange(0, head_dim, 2) / head_dim))
+
+    seq_idx = torch.arange(seq_len)
+
+    # Calculate the product of position index and $\theta_i$
+    idx_theta = torch.outer(seq_idx, theta).float()
+
+    cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1)
+
+    return cache
+
+
+def apply_rope(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
+    # truncate to support variable sizes
+    T = x.size(1)
+    rope_cache = rope_cache[:T]
+
+    # cast because the reference does
+    xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
+    rope_cache = rope_cache.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
+    x_out2 = torch.stack(
+        [
+            xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1],
+            xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1],
+        ],
+        -1,
+    )
+
+    x_out2 = x_out2.flatten(3)
+    return x_out2.type_as(x)
+
+
+@dataclass
+class UFormerConfig:
+    sr: float = 44100
+    n_fft: int = 2048
+    hop_length: int = 441
+
+    n_layer: int = 6
+    n_head: int = 8
+    n_embd: int = 256
+
+class UFormer(Fourier):
+    def __init__(self, config: UFormerConfig) -> None:
+
+        super(UFormer, self).__init__(
+            n_fft=config.n_fft,
+            hop_length=config.hop_length,
+            return_complex=True,
+            normalized=True
+        )
+
+        self.ds_factor = 16  # Downsample factor
+        self.fps = config.sr // config.hop_length
+
+        self.audio_channels = 2
+        self.cmplx_num = 2
+        in_channels = self.audio_channels * self.cmplx_num
+
+        self.encoder_block1 = EncoderBlock(in_channels, 16)
+        self.encoder_block2 = EncoderBlock(16, 64)
+        self.encoder_block3 = EncoderBlock(64, 256)
+        self.encoder_block4 = EncoderBlock(256, config.n_embd)
+        self.decoder_block1 = DecoderBlock(config.n_embd, 256)
+        self.decoder_block2 = DecoderBlock(256, 64)
+        self.decoder_block3 = DecoderBlock(64, 16)
+        self.decoder_block4 = DecoderBlock(16, 16)
+
+        self.t_blocks = nn.ModuleList(Block(config) for _ in range(config.n_layer))
+        self.f_blocks = nn.ModuleList(Block(config) for _ in range(config.n_layer))
+        self.head_dim = config.n_embd // config.n_head
+
+        t_rope = build_rope(seq_len=config.n_fft // 16, head_dim=self.head_dim)
+        f_rope = build_rope(seq_len=self.fps * 20, head_dim=self.head_dim)
+        self.register_buffer(name="t_rope", tensor=t_rope)  # shape: (t, head_dim/2, 2)
+        self.register_buffer(name="f_rope", tensor=f_rope)  # shape: (t, head_dim/2, 2)
+
+        self.post_fc = nn.Conv2d(
+            in_channels=16,
+            out_channels=in_channels,
+            kernel_size=1,
+            padding=0,
+        )
+
+    def forward(self, audio):
+        """Separation model.
+
+        b: batch_size
+        c: channels_num
+        l: audio_samples
+        t: frames_num
+        f: freq_bins
+
+        Args:
+            audio: (b, c, l)
+
+        Outputs:
+            output: (b, c, l)
+        """
+
+        # Complex spectrum
+        complex_sp = self.stft(audio)  # shape: (b, c, t, f)
+
+        x = torch.view_as_real(complex_sp)  # shape: (b, c, t, f, 2)
+        x = rearrange(x, 'b c t f k -> b (c k) t f')  # shape: (b, d, t, f)
+
+        # Pad stft
+        x, pad_t = self.pad_tensor(x)  # x: (b, d, t, f)
+        B = x.shape[0]
+
+        x1, latent1 = self.encoder_block1(x)
+        x2, latent2 = self.encoder_block2(x1)
+        x3, latent3 = self.encoder_block3(x2)
+        x, latent4 = self.encoder_block4(x3)
+        for t_block, f_block in zip(self.t_blocks, self.f_blocks):
+
+            x = rearrange(x, 'b d t f -> (b f) t d')
+            x = t_block(x, self.t_rope, mask=None)  # shape: (b*f, t, d)
+
+            x = rearrange(x, '(b f) t d -> (b t) f d', b=B)
+            x = f_block(x, self.f_rope, mask=None)  # shape: (b*t, f, d)
+
+            x = rearrange(x, '(b t) f d -> b d t f', b=B)  # shape: (b, d, t, f)
+        x5 = self.decoder_block1(x, latent4)
+        x6 = self.decoder_block2(x5, latent3)
+        x7 = self.decoder_block3(x6, latent2)
+        x8 = self.decoder_block4(x7, latent1)
+        x = self.post_fc(x8)
+
+        x = rearrange(x, 'b (c k) t f -> b c t f k', k=self.cmplx_num).contiguous()
+        x = x.to(torch.float)  # compatible with bf16
+        mask = torch.view_as_complex(x)  # shape: (b, c, t, f)
+
+        # Unpad mask to the original shape
+        mask = self.unpad_tensor(mask, pad_t)  # shape: (b, c, t, f)
+
+        # Calculate stft of separated audio
+        # sep_stft = mask * complex_sp  # shape: (b, c, t, f)
+
+        # ISTFT
+        output = self.istft(mask)  # shape: (b, c, l)
+
+        return output
+
+    def pad_tensor(self, x: torch.Tensor) -> tuple[torch.Tensor, int]:
+        """Pad a spectrum so that its frame count is evenly divisible by the downsample factor.
+
+        Args:
+            x: E.g., (b, c, t=201, f=1025)
+
+        Outputs:
+            output: E.g., (b, c, t=208, f=1024)
+        """
+
+        # Pad last frames, e.g., 201 -> 208
+        T = x.shape[2]
+        pad_t = -T % self.ds_factor
+        x = F.pad(x, pad=(0, 0, 0, pad_t))
+
+        # Remove last frequency bin, e.g., 1025 -> 1024
+        x = x[:, :, :, 0 : -1]
+
+        return x, pad_t
+
+    def unpad_tensor(self, x: torch.Tensor, pad_t: int) -> torch.Tensor:
+        """Unpad a spectrum to the original shape.
+
+        Args:
+            x: E.g., (b, c, t=208, f=1024)
+
+        Outputs:
+            x: E.g., (b, c, t=201, f=1025)
+        """
+
+        # Pad last frequency bin, e.g., 1024 -> 1025
+        x = F.pad(x, pad=(0, 1))
+
+        # Unpad last frames, e.g., 208 -> 201 (a no-op when pad_t == 0)
+        x = x[:, :, 0 : x.shape[2] - pad_t, :]
+
+        return x
+
+
+class ConvBlock(nn.Module):
+    def __init__(
+        self, in_channels, out_channels, kernel_size):
+        r"""Residual block."""
+        super(ConvBlock, self).__init__()
+
+        padding = [kernel_size[0] // 2, kernel_size[1] // 2]
+
+        self.bn1 = nn.BatchNorm2d(in_channels)
+        self.bn2 = nn.BatchNorm2d(out_channels)
+
+        self.conv1 = nn.Conv2d(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=kernel_size,
+            padding=padding,
+            bias=False,
+        )
+
+        self.conv2 = nn.Conv2d(
+            in_channels=out_channels,
+            out_channels=out_channels,
+            kernel_size=kernel_size,
+            padding=padding,
+            bias=False,
+        )
+
+        if in_channels != out_channels:
+            self.shortcut = nn.Conv2d(
+                in_channels=in_channels,
+                out_channels=out_channels,
+                kernel_size=(1, 1),
+                padding=(0, 0),
+            )
+            self.is_shortcut = True
+        else:
+            self.is_shortcut = False
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            x: (b, c_in, t, f)
+
+        Returns:
+            output: (b, c_out, t, f)
+        """
+        h = self.conv1(F.leaky_relu_(self.bn1(x)))
+        h = self.conv2(F.leaky_relu_(self.bn2(h)))
+
+        if self.is_shortcut:
+            return self.shortcut(x) + h
+        else:
+            return x + h
+
+
+class EncoderBlock(nn.Module):
+    def __init__(self, in_channels, out_channels, kernel_size=(3, 3)):
+        super(EncoderBlock, self).__init__()
+
+        self.pool_size = 2
+
+        self.conv_block = ConvBlock(in_channels, out_channels, kernel_size)
+
+    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        """
+        Args:
+            x: (b, c_in, t, f)
+
+        Returns:
+            output: (b, c_out, t/2, f/2)
+            latent: (b, c_out, t, f)
+        """
+
+        latent = self.conv_block(x)  # shape: (b, c_out, t, f)
+        output = F.avg_pool2d(latent, kernel_size=self.pool_size)  # shape: (b, c_out, t/2, f/2)
+        return output, latent
+
+
+class DecoderBlock(nn.Module):
+    def __init__(self, in_channels, out_channels, kernel_size=(3, 3)):
+        super(DecoderBlock, self).__init__()
+
+        stride = 2
+
+        self.upsample = torch.nn.ConvTranspose2d(
+            in_channels=in_channels,
+            out_channels=in_channels,
+            kernel_size=stride,
+            stride=stride,
+            padding=(0, 0),
+            bias=False,
+        )
+
+        self.conv_block = ConvBlock(in_channels * 2, out_channels, kernel_size)
+
+    def forward(self, x: torch.Tensor, latent: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            x: (b, c_in, t/2, f/2)
+            latent: (b, c_in, t, f)
+        Returns:
+            output: (b, c_out, t, f)
+        """
+
+        x = self.upsample(x)  # shape: (b, c_in, t, f)
+        x = torch.cat((x, latent), dim=1)  # shape: (b, 2*c_in, t, f)
+        x = self.conv_block(x)  # shape: (b, c_out, t, f)
+
+        return x
+
+if __name__ == "__main__":
+    # Example usage
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    config = UFormerConfig()
+    model = UFormer(config)
+    checkpoint_path = "checkpoints/vocals.pth"  # any of the released checkpoints
+    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
+    model.to(device)
+    audio = torch.randn(1, 2, 10*44100).to(device)  # Example audio input (batch_size=1, channels=2, samples=441000)
+    output = model(audio)
+    print(output.shape)  # Output shape
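The __main__ demo above needs a checkpoint on disk, but the Fourier front end alone can be sanity-checked without one. A small sketch with synthetic input; the shapes in the comments follow from n_fft=2048 and hop_length=441 with torch.stft's default centering:

# Shape check for the STFT/ISTFT front end (sketch; random input, no checkpoint required).
import torch
from model import Fourier

fourier = Fourier(n_fft=2048, hop_length=441)
wav = torch.randn(1, 2, 10 * 44100)   # 10 s of stereo noise at 44.1 kHz
spec = fourier.stft(wav)              # expected: torch.Size([1, 2, 1001, 1025]) -> (b, c, t, f)
wav_back = fourier.istft(spec)        # expected: torch.Size([1, 2, 441000])
print(spec.shape, wav_back.shape)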
requirements.txt ADDED
@@ -0,0 +1,6 @@
+torch
+numpy
+soundfile
+gradio
+einops
+huggingface-hub
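huggingface-hub is listed here even though app.py only reads checkpoints from the local checkpoints/ directory; if a checkpoint were to be pulled from the Hub at runtime instead, a sketch might look like the following (the repo id is a placeholder, not something named in this commit):

# Hypothetical checkpoint download via huggingface_hub (sketch; repo id is a placeholder).
from huggingface_hub import hf_hub_download

ckpt_path = hf_hub_download(
    repo_id="<user>/<this-space>",     # placeholder: the repo hosting the .pth files
    filename="checkpoints/vocals.pth",
    repo_type="space",                 # assumes the files live in a Space repo, as in this commit
)
print(ckpt_path)  # local cache path that torch.load() can read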