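"""Chunk-wise inference helpers that split long audio into overlapping
chunks, run a model on each batch of chunks, and merge the results back
together with windowed overlap-add (`BaseFader`, `LinearFader`,
`OverlapAddFader`).
"""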
from collections import defaultdict
from typing import Callable, Dict, Tuple

import numpy as np
import torch
from torch import nn
from torch.nn import functional as F


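# `merge` undoes the chunk batching performed by `unfold`: model output of
# shape (batch * n_chunks, n_channel, chunk_size) is rearranged into
# (batch * n_channel, chunk_size, n_chunks), the layout expected by F.fold.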
@torch.jit.script
def merge(
    combined: torch.Tensor,
    original_batch_size: int,
    n_channel: int,
    n_chunks: int,
    chunk_size: int,
) -> torch.Tensor:
    combined = torch.reshape(
        combined,
        (original_batch_size, n_chunks, n_channel, chunk_size)
    )
    combined = torch.permute(combined, (0, 2, 3, 1)).reshape(
        original_batch_size * n_channel,
        chunk_size,
        n_chunks
    )

    return combined


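# `unfold` slices padded audio of shape (batch [* n_channel], 1, n_padded)
# into overlapping chunks and flattens them into the batch dimension,
# returning (batch * n_chunks, n_channel, chunk_size).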
@torch.jit.script
def unfold(
    padded_audio: torch.Tensor,
    original_batch_size: int,
    n_channel: int,
    chunk_size: int,
    hop_size: int
) -> torch.Tensor:
    unfolded_input = F.unfold(
        padded_audio[:, :, None, :],
        kernel_size=(1, chunk_size),
        stride=(1, hop_size)
    )

    _, _, n_chunks = unfolded_input.shape
    unfolded_input = unfolded_input.view(
        original_batch_size,
        n_channel,
        chunk_size,
        n_chunks
    )
    unfolded_input = torch.permute(
        unfolded_input,
        (0, 3, 1, 2)
    ).reshape(
        original_batch_size * n_chunks,
        n_channel,
        chunk_size
    )

    return unfolded_input


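# Overlap-add merge for inputs that were reflect-padded at the edges
# (fade_edge_frames=True): every chunk is weighted by the standard window,
# folded back, and the reflect padding is trimmed off. The unused
# first_window/last_window arguments keep the signature interchangeable
# with merge_chunks_edge.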
@torch.jit.script
def merge_chunks_all(
    combined: torch.Tensor,
    original_batch_size: int,
    n_channel: int,
    n_samples: int,
    n_padded_samples: int,
    n_chunks: int,
    chunk_size: int,
    hop_size: int,
    edge_frame_pad_sizes: Tuple[int, int],
    standard_window: torch.Tensor,
    first_window: torch.Tensor,
    last_window: torch.Tensor
):
    combined = merge(
        combined,
        original_batch_size,
        n_channel,
        n_chunks,
        chunk_size
    )

    combined = combined * standard_window[:, None].to(combined.device)

    combined = F.fold(
        combined.to(torch.float32),
        output_size=(1, n_padded_samples),
        kernel_size=(1, chunk_size),
        stride=(1, hop_size)
    )

    combined = combined.view(
        original_batch_size,
        n_channel,
        n_padded_samples
    )

    pad_front, pad_back = edge_frame_pad_sizes
    combined = combined[..., pad_front:-pad_back]

    combined = combined[..., :n_samples]

    return combined


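# Overlap-add merge for inputs without edge padding (fade_edge_frames=False):
# the first and last chunks get asymmetric windows so the signal edges are
# not attenuated, while interior chunks use the standard window.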
def merge_chunks_edge(
    combined: torch.Tensor,
    original_batch_size: int,
    n_channel: int,
    n_samples: int,
    n_padded_samples: int,
    n_chunks: int,
    chunk_size: int,
    hop_size: int,
    edge_frame_pad_sizes: Tuple[int, int],
    standard_window: torch.Tensor,
    first_window: torch.Tensor,
    last_window: torch.Tensor
):
    combined = merge(
        combined,
        original_batch_size,
        n_channel,
        n_chunks,
        chunk_size
    )

    combined[..., 0] = combined[..., 0] * first_window
    combined[..., -1] = combined[..., -1] * last_window
    combined[..., 1:-1] = combined[..., 1:-1] * standard_window[:, None]

    combined = F.fold(
        combined,
        output_size=(1, n_padded_samples),
        kernel_size=(1, chunk_size),
        stride=(1, hop_size)
    )

    combined = combined.view(
        original_batch_size,
        n_channel,
        n_padded_samples
    )

    combined = combined[..., :n_samples]

    return combined


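# BaseFader drives the full chunk/process/merge pipeline. Subclasses must
# provide `standard_window`, `edge_frame_pad_sizes`, and (when
# fade_edge_frames is False) `first_window`/`last_window`.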
class BaseFader(nn.Module):
    def __init__(
        self,
        chunk_size_second: float,
        hop_size_second: float,
        fs: int,
        fade_edge_frames: bool,
        batch_size: int,
    ) -> None:
        super().__init__()

        self.chunk_size = int(chunk_size_second * fs)
        self.hop_size = int(hop_size_second * fs)
        self.overlap_size = self.chunk_size - self.hop_size
        self.fade_edge_frames = fade_edge_frames
        self.batch_size = batch_size

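    # Reflect-pad the edges when requested, then zero-pad the tail so the
    # signal splits into an integer number of hops.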
    def prepare(self, audio):
        if self.fade_edge_frames:
            audio = F.pad(audio, self.edge_frame_pad_sizes, mode="reflect")

        n_samples = audio.shape[-1]
        n_chunks = int(
            np.ceil((n_samples - self.chunk_size) / self.hop_size) + 1
        )

        padded_size = (n_chunks - 1) * self.hop_size + self.chunk_size
        pad_size = padded_size - n_samples

        padded_audio = F.pad(audio, (0, pad_size))

        return padded_audio, n_chunks

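    # Chunk the input, run `model_fn` batch by batch (keeping chunks on CPU
    # between calls to bound peak memory), then overlap-add each returned
    # stem back to the original length, dtype, and device.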
    def forward(
        self,
        audio: torch.Tensor,
        model_fn: Callable[[torch.Tensor], Dict[str, torch.Tensor]],
    ):
        original_dtype = audio.dtype
        original_device = audio.device

        audio = audio.to("cpu")

        original_batch_size, n_channel, n_samples = audio.shape
        padded_audio, n_chunks = self.prepare(audio)
        del audio
        n_padded_samples = padded_audio.shape[-1]

        if n_channel > 1:
            padded_audio = padded_audio.view(
                original_batch_size * n_channel, 1, n_padded_samples
            )

        unfolded_input = unfold(
            padded_audio,
            original_batch_size,
            n_channel,
            self.chunk_size,
            self.hop_size
        )

        n_total_chunks, n_channel, chunk_size = unfolded_input.shape

        n_batch = int(np.ceil(n_total_chunks / self.batch_size))

        chunks_in = [
            unfolded_input[
                b * self.batch_size:(b + 1) * self.batch_size, ...
            ].clone()
            for b in range(n_batch)
        ]

        all_chunks_out = defaultdict(
            lambda: torch.zeros_like(unfolded_input, device="cpu")
        )

        for b, cin in enumerate(chunks_in):
            # Skip silent batches; their output slots stay zero.
            if torch.allclose(cin, torch.tensor(0.0)):
                del cin
                continue

            chunks_out = model_fn(cin.to(original_device))
            del cin
            for s, c in chunks_out.items():
                all_chunks_out[s][
                    b * self.batch_size:(b + 1) * self.batch_size, ...
                ] = c.cpu()
            del chunks_out

        del unfolded_input
        del padded_audio

        if self.fade_edge_frames:
            fn = merge_chunks_all
        else:
            fn = merge_chunks_edge
        outputs = {}

        torch.cuda.empty_cache()

        for s, c in all_chunks_out.items():
            combined: torch.Tensor = fn(
                c,
                original_batch_size,
                n_channel,
                n_samples,
                n_padded_samples,
                n_chunks,
                self.chunk_size,
                self.hop_size,
                self.edge_frame_pad_sizes,
                self.standard_window,
                # nn.Module keeps parameters/buffers outside __dict__, so use
                # getattr to fall back to the standard window when the edge
                # windows are not defined.
                getattr(self, "first_window", self.standard_window),
                getattr(self, "last_window", self.standard_window)
            )

            outputs[s] = combined.to(
                dtype=original_dtype,
                device=original_device
            )

        return {
            "audio": outputs
        }


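# LinearFader cross-fades neighbouring chunks with linear ramps over the
# overlap region; hop_size must be at least half the chunk size so at most
# two chunks overlap at any sample.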
class LinearFader(BaseFader):
    def __init__(
        self,
        chunk_size_second: float,
        hop_size_second: float,
        fs: int,
        fade_edge_frames: bool = False,
        batch_size: int = 1,
    ) -> None:
        assert hop_size_second >= chunk_size_second / 2

        super().__init__(
            chunk_size_second=chunk_size_second,
            hop_size_second=hop_size_second,
            fs=fs,
            fade_edge_frames=fade_edge_frames,
            batch_size=batch_size,
        )

        in_fade = torch.linspace(0.0, 1.0, self.overlap_size + 1)[:-1]
        out_fade = torch.linspace(1.0, 0.0, self.overlap_size + 1)[1:]
        center_ones = torch.ones(self.chunk_size - 2 * self.overlap_size)
        inout_ones = torch.ones(self.overlap_size)

        self.register_buffer(
            "standard_window",
            torch.concat([in_fade, center_ones, out_fade])
        )

        self.edge_frame_pad_sizes = (self.overlap_size, self.overlap_size)

        if not self.fade_edge_frames:
            # The first and last chunks keep full gain at the signal edges.
            self.register_buffer(
                "first_window",
                torch.concat([inout_ones, center_ones, out_fade])
            )
            self.register_buffer(
                "last_window",
                torch.concat([in_fade, center_ones, inout_ones])
            )


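# OverlapAddFader uses a standard analysis window looked up by name and
# scaled by chunk_size / (2 * hop_size), so that overlapping Hann windows
# sum to unity; edges are always reflect-padded (fade_edge_frames=True).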
class OverlapAddFader(BaseFader):
    def __init__(
        self,
        window_type: str,
        chunk_size_second: float,
        hop_size_second: float,
        fs: int,
        batch_size: int = 1,
    ) -> None:
        assert (chunk_size_second / hop_size_second) % 2 == 0
        assert int(chunk_size_second * fs) % 2 == 0

        super().__init__(
            chunk_size_second=chunk_size_second,
            hop_size_second=hop_size_second,
            fs=fs,
            fade_edge_frames=True,
            batch_size=batch_size,
        )

        self.hop_multiplier = self.chunk_size / (2 * self.hop_size)

        self.edge_frame_pad_sizes = (
            2 * self.overlap_size,
            2 * self.overlap_size
        )

        # Window functions live in torch.signal.windows (e.g. "hann",
        # "hamming"); look one up by name and normalize for overlap-add.
        self.register_buffer(
            "standard_window",
            getattr(torch.signal.windows, window_type)(
                self.chunk_size, sym=False,
            ) / self.hop_multiplier
        )


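# Smoke test: with an identity model the overlap-add reconstruction should
# return the input unchanged. Adjust the audio path for your setup.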
if __name__ == "__main__":
    import torchaudio as ta

    fs = 44100
    ola = OverlapAddFader(
        "hann",
        6.0,
        1.0,
        fs,
        batch_size=16
    )
    audio_, _ = ta.load(
        "$DATA_ROOT/MUSDB18/HQ/canonical/test/BKS - Too "
        "Much/vocals.wav"
    )
    audio_ = audio_[None, ...]
    out = ola(audio_, lambda x: {"stem": x})["audio"]["stem"]
    print(torch.allclose(out, audio_))