draft audio sep app
This view is limited to 50 files because it contains too many changes.
- __data_source_separation/source_separation/test/0000/mix_snr_-4.wav +0 -0
- __data_source_separation/source_separation/test/0000/noise.wav +0 -0
- __data_source_separation/source_separation/test/0000/voice.wav +0 -0
- __data_source_separation/source_separation/test/0001/mix_snr_2.wav +0 -0
- __data_source_separation/source_separation/test/0001/noise.wav +0 -0
- __data_source_separation/source_separation/test/0001/voice.wav +0 -0
- __output_audiosep/0004_0000/model_0059.pt +3 -0
- __output_audiosep/1004_0000/model_0119.pt +3 -0
- __output_audiosep/3001_0000/model_0199.pt +3 -0
- app.py +7 -0
- audio_samples/0009/mix_snr_-1.wav +0 -0
- audio_samples/0009/noise.wav +0 -0
- audio_samples/0009/voice.wav +0 -0
- requirements.txt +3 -0
- src/gyraudio/__init__.py +2 -0
- src/gyraudio/audio_separation/architecture/building_block.py +51 -0
- src/gyraudio/audio_separation/architecture/flat_conv.py +62 -0
- src/gyraudio/audio_separation/architecture/model.py +28 -0
- src/gyraudio/audio_separation/architecture/neutral.py +15 -0
- src/gyraudio/audio_separation/architecture/transformer.py +91 -0
- src/gyraudio/audio_separation/architecture/unet.py +151 -0
- src/gyraudio/audio_separation/architecture/wave_unet.py +163 -0
- src/gyraudio/audio_separation/data/__init__.py +5 -0
- src/gyraudio/audio_separation/data/dataloader.py +47 -0
- src/gyraudio/audio_separation/data/dataset.py +104 -0
- src/gyraudio/audio_separation/data/mixed.py +40 -0
- src/gyraudio/audio_separation/data/remixed.py +53 -0
- src/gyraudio/audio_separation/data/remixed_fixed.py +18 -0
- src/gyraudio/audio_separation/data/remixed_rnd.py +12 -0
- src/gyraudio/audio_separation/data/silence_detector.py +55 -0
- src/gyraudio/audio_separation/data/single.py +15 -0
- src/gyraudio/audio_separation/experiment_tracking/experiments.py +122 -0
- src/gyraudio/audio_separation/experiment_tracking/experiments_decorator.py +48 -0
- src/gyraudio/audio_separation/experiment_tracking/experiments_definition.py +320 -0
- src/gyraudio/audio_separation/experiment_tracking/storage.py +65 -0
- src/gyraudio/audio_separation/infer.py +202 -0
- src/gyraudio/audio_separation/metrics.py +101 -0
- src/gyraudio/audio_separation/parser.py +8 -0
- src/gyraudio/audio_separation/properties.py +79 -0
- src/gyraudio/audio_separation/train.py +183 -0
- src/gyraudio/audio_separation/visualization/audio_player.py +84 -0
- src/gyraudio/audio_separation/visualization/interactive_audio.py +392 -0
- src/gyraudio/audio_separation/visualization/interactive_infer.py +117 -0
- src/gyraudio/audio_separation/visualization/mute_logo.png +0 -0
- src/gyraudio/audio_separation/visualization/play_logo_clean.png +0 -0
- src/gyraudio/audio_separation/visualization/play_logo_mixed.png +0 -0
- src/gyraudio/audio_separation/visualization/play_logo_noise.png +0 -0
- src/gyraudio/audio_separation/visualization/play_logo_pred.png +0 -0
- src/gyraudio/audio_separation/visualization/pre_load_audio.py +71 -0
- src/gyraudio/audio_separation/visualization/pre_load_custom_audio.py +53 -0
__data_source_separation/source_separation/test/0000/mix_snr_-4.wav
ADDED
Binary file (320 kB).

__data_source_separation/source_separation/test/0000/noise.wav
ADDED
Binary file (320 kB).

__data_source_separation/source_separation/test/0000/voice.wav
ADDED
Binary file (320 kB).

__data_source_separation/source_separation/test/0001/mix_snr_2.wav
ADDED
Binary file (320 kB).

__data_source_separation/source_separation/test/0001/noise.wav
ADDED
Binary file (320 kB).

__data_source_separation/source_separation/test/0001/voice.wav
ADDED
Binary file (320 kB).
__output_audiosep/0004_0000/model_0059.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ef62b0fdd9e9b81000da0db190be1df7e5451d70ab3c86609cab409ec3e38ab8
size 34402

__output_audiosep/1004_0000/model_0119.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:61768dbf9e81845dc656c5da3e8fe5ea053f73108f422caf7879b4ab55a0792a
size 12755810

__output_audiosep/3001_0000/model_0199.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:11341d63bad0a1ec15c5a94c5cc6f049720869fb9988da45ef7d07edd3c82e21
size 12743211
app.py
ADDED
@@ -0,0 +1,7 @@
import sys
import os
src_path = os.path.abspath(os.path.join(os.path.dirname(__file__), 'src'))
os.sys.path.append(src_path)
from gyraudio.audio_separation.visualization.interactive_audio import main as interactive_audio_main
if __name__ == "__main__":
    interactive_audio_main(sys.argv[1:])
audio_samples/0009/mix_snr_-1.wav
ADDED
Binary file (320 kB).

audio_samples/0009/noise.wav
ADDED
Binary file (320 kB).

audio_samples/0009/voice.wav
ADDED
Binary file (320 kB).
requirements.txt
ADDED
@@ -0,0 +1,3 @@
batch_processing
interactive-pipe>=0.7.0
torch>=2.0.0
src/gyraudio/__init__.py
ADDED
@@ -0,0 +1,2 @@
from pathlib import Path
root_dir = Path(__file__).parent.parent.parent
src/gyraudio/audio_separation/architecture/building_block.py
ADDED
@@ -0,0 +1,51 @@
import torch
from typing import List


class FilterBank(torch.nn.Module):
    """Convolution filter bank (linear)
    Serves as an embedding for the audio signal
    """

    def __init__(self, ch_in: int, out_dim=16, k_size=5, dilation_list: List[int] = [1, 2, 4, 8]):
        super().__init__()
        self.out_dim = out_dim
        self.source_modality_conv = torch.nn.ModuleList()
        for dilation in dilation_list:
            self.source_modality_conv.append(
                torch.nn.Conv1d(ch_in, out_dim//len(dilation_list), k_size, dilation=dilation, padding=(dilation*(k_size//2)))
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = torch.cat([conv(x) for conv in self.source_modality_conv], axis=1)
        assert out.shape[1] == self.out_dim
        return out


class ResConvolution(torch.nn.Module):
    """ResNet building block
    https://paperswithcode.com/method/residual-connection
    """

    def __init__(self, ch, hdim=None, k_size=5):
        super().__init__()
        hdim = hdim or ch
        self.conv1 = torch.nn.Conv1d(ch, hdim, k_size, padding=k_size//2)
        self.conv2 = torch.nn.Conv1d(hdim, ch, k_size, padding=k_size//2)
        self.non_linearity = torch.nn.ReLU()

    def forward(self, x_in):
        x = self.conv1(x_in)
        x = self.non_linearity(x)
        x = self.conv2(x)
        x += x_in
        x = self.non_linearity(x)
        return x


if __name__ == "__main__":
    model = FilterBank(1, 16)
    inp = torch.rand(2, 1, 2048)
    out = model(inp)
    print(model)
    print(out[0].shape)
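Side note (not part of the commit): each dilation branch outputs out_dim//len(dilation_list) channels, so the assert in forward only passes when out_dim is a multiple of len(dilation_list), and padding=dilation*(k_size//2) keeps the time dimension unchanged. A minimal sketch:

```python
import torch
from gyraudio.audio_separation.architecture.building_block import FilterBank

# 16 output channels split across 4 dilation branches -> 4 channels per branch
fb = FilterBank(ch_in=1, out_dim=16, k_size=5, dilation_list=[1, 2, 4, 8])
out = fb(torch.rand(2, 1, 2048))
print(out.shape)  # expected: [2, 16, 2048] - time length preserved by the padding
```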
src/gyraudio/audio_separation/architecture/flat_conv.py
ADDED
@@ -0,0 +1,62 @@
import torch
from gyraudio.audio_separation.architecture.model import SeparationModel
from typing import Tuple


class FlatConvolutional(SeparationModel):
    """Convolutional neural network for audio separation,
    No decimation, no bottleneck, just basic signal processing
    """

    def __init__(self,
                 ch_in: int = 1,
                 ch_out: int = 2,
                 h_dim=16,
                 k_size=5,
                 dilation=1
                 ) -> None:
        super().__init__()
        self.conv1 = torch.nn.Conv1d(
            ch_in, h_dim, k_size,
            dilation=dilation, padding=dilation*(k_size//2))
        self.conv2 = torch.nn.Conv1d(
            h_dim, h_dim, k_size,
            dilation=dilation, padding=dilation*(k_size//2))
        self.conv3 = torch.nn.Conv1d(
            h_dim, h_dim, k_size,
            dilation=dilation, padding=dilation*(k_size//2))
        self.conv4 = torch.nn.Conv1d(
            h_dim, h_dim, k_size,
            dilation=dilation, padding=dilation*(k_size//2))
        self.relu = torch.nn.ReLU()
        self.encoder = torch.nn.Sequential(
            self.conv1,
            self.relu,
            self.conv2,
            self.relu,
            self.conv3,
            self.relu,
            self.conv4,
            self.relu
        )
        self.demux = torch.nn.Sequential(*(
            torch.nn.Conv1d(h_dim, h_dim//2, 1),  # conv1x1
            torch.nn.ReLU(),
            torch.nn.Conv1d(h_dim//2, ch_out, 1),  # conv1x1
        ))

    def forward(self, mixed_sig_in: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Perform feature extraction followed by the demultiplexing head

        Args:
            mixed_sig_in (torch.Tensor): [N, C, T]

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: estimated signal and noise, [N, 1, T] each
        """
        # Convolution backbone
        # [N, C, T] -> [N, h, T]
        features = self.encoder(mixed_sig_in)
        # [N, h, T] -> [N, 2, T]
        demuxed = self.demux(features)
        return torch.chunk(demuxed, 2, dim=1)  # [N, 1, T], [N, 1, T]
src/gyraudio/audio_separation/architecture/model.py
ADDED
@@ -0,0 +1,28 @@
import torch


class SeparationModel(torch.nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

    def count_parameters(self) -> int:
        """Count the total number of parameters of the model

        Returns:
            int: total amount of parameters
        """
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def receptive_field(self) -> int:
        """Compute the receptive field of the model

        Returns:
            int: receptive field
        """
        input_tensor = torch.rand(1, 1, 4096, requires_grad=True)
        out, out_noise = self.forward(input_tensor)
        grad = torch.zeros_like(out)
        grad[..., out.shape[-1]//2] = torch.nan  # set NaN gradient at the middle of the output
        out.backward(gradient=grad)
        self.zero_grad()  # reset to avoid future problems
        return int(torch.sum(input_tensor.grad.isnan()).cpu())  # Count NaN in the input
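For illustration (not part of the commit): the receptive-field trick injects a NaN into the output gradient at a single time step; backpropagation spreads it to every input sample that influenced that output, so counting NaNs in the input gradient measures the receptive field. A minimal sketch, assuming the default FlatConvolutional (four k=5 convolutions, so an expected field of 1 + 4·(5−1) = 17 samples):

```python
import torch
from gyraudio.audio_separation.architecture.flat_conv import FlatConvolutional
from gyraudio.audio_separation.architecture.unet import ResUNet

flat = FlatConvolutional(k_size=5, dilation=1)
print(flat.count_parameters(), flat.receptive_field())  # receptive field expected around 17

# Multi-scale models naturally report a much larger receptive field
print(ResUNet().receptive_field())
```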
src/gyraudio/audio_separation/architecture/neutral.py
ADDED
@@ -0,0 +1,15 @@
import torch
from gyraudio.audio_separation.architecture.model import SeparationModel


class NeutralModel(SeparationModel):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.fake = torch.nn.Conv1d(1, 1, 1, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Identity function
        """
        n = self.fake(x)
        return x, n
src/gyraudio/audio_separation/architecture/transformer.py
ADDED
@@ -0,0 +1,91 @@
import torch
from gyraudio.audio_separation.architecture.model import SeparationModel
from gyraudio.audio_separation.architecture.building_block import FilterBank
from typing import Optional


class TransformerModel(SeparationModel):
    """Transformer base model
    =========================
    - Embed signal with a filter bank
    - No positional encoding (potential improvement: add/concatenate positional encoding)
    - `nlayers` * transformer blocks
    """

    def __init__(self,
                 nhead: int = 8,  # H
                 nlayers: int = 4,  # L
                 k_size=5,
                 embedding_dim: int = 64,  # D
                 ch_in: int = 1,
                 ch_out: int = 1,
                 dropout: float = 0.,  # dr
                 positional_encoding: str = None
                 ) -> None:
        """Transformer base model

        Args:
            nhead (int): number of heads in each of the MHA models
            embedding_dim (int): D number of channels in the audio embeddings
                = output of the filter bank
                assume `embedding_dim` = `h_dim`
                h_dim is the hidden dimension of the model.
            nlayers (int): number of nn.TransformerEncoderLayer in nn.TransformerEncoder
            dropout (float, optional): dropout value. Defaults to 0.
        """
        super().__init__()
        self.model_type = "Transformer"
        h_dim = embedding_dim  # use the same embedding & hidden dimensions

        self.encoder = FilterBank(ch_in, embedding_dim, k_size=k_size)
        if positional_encoding is None:
            self.pos_encoder = torch.nn.Identity()
        else:
            raise NotImplementedError(
                f"Unknown positional encoding {positional_encoding} - should be add/concat in future")
            # self.pos_encoder = PositionalEncoding(h_dim, dropout=dropout)

        encoder_layers = torch.nn.TransformerEncoderLayer(
            d_model=h_dim,  # input dimension to the transformer encoder layer
            nhead=nhead,  # number of heads for MHA (Multi-head attention)
            dim_feedforward=h_dim,  # output dimension of the MLP on top of the transformer.
            dropout=dropout,
            batch_first=True
        )  # we assume h_dim = d_model = dim_feedforward

        self.transformer_encoder = torch.nn.TransformerEncoder(
            encoder_layers,
            num_layers=nlayers
        )
        self.h_dim = h_dim
        self.target_modality_conv = torch.nn.Conv1d(h_dim, ch_out, 1)  # conv1x1 channel mixer
        # Note: we could finish with a few residual conv blocks... this is pure signal processing

    def forward(
        self, src: torch.LongTensor,
        src_mask: Optional[torch.FloatTensor] = None
    ) -> torch.FloatTensor:
        """Embeddings, positional encoder, then go through `nlayers` of residual {multi (`nhead`) attention heads + MLP}.

        Args:
            src (torch.LongTensor): [N, 1, T] audio signal

        Returns:
            torch.FloatTensor: separated signal [N, 1, T]
        """
        src = self.encoder(src)  # [N, 1, T] -> [N, D, T]
        src = src.transpose(-1, -2)  # [N, D, T] -> [N, T, D]  # Transformer expects (batch N, seq "T", features "D")
        src = self.pos_encoder(src)  # -> [N, T, D] - add positional encoding

        output = self.transformer_encoder(src, mask=src_mask)  # -> [N, T, D]
        output = output.transpose(-1, -2)  # -> [N, D, T]
        output = self.target_modality_conv(output)  # -> [N, 1, T]
        return output, None


if __name__ == "__main__":
    model = TransformerModel()
    inp = torch.rand(2, 1, 2048)
    out = model(inp)
    print(model)
    print(out[0].shape)
src/gyraudio/audio_separation/architecture/unet.py
ADDED
@@ -0,0 +1,151 @@
import torch
from gyraudio.audio_separation.architecture.model import SeparationModel
from gyraudio.audio_separation.architecture.building_block import ResConvolution
from typing import Optional
# import logging


class EncoderSingleStage(torch.nn.Module):
    """
    Extend channels
    Resnet
    Downsample by 2
    """

    def __init__(self, ch: int, ch_out: int, hdim: Optional[int] = None, k_size=5):
        # ch_out ~ ch_in*extension_factor
        super().__init__()
        hdim = hdim or ch
        self.extension_conv = torch.nn.Conv1d(ch, ch_out, k_size, padding=k_size//2)
        self.res_conv = ResConvolution(ch_out, hdim=hdim, k_size=k_size)
        # warning on maxpooling jitter offset!
        self.max_pool = torch.nn.MaxPool1d(kernel_size=2)

    def forward(self, x):
        x = self.extension_conv(x)
        x = self.res_conv(x)
        x_ds = self.max_pool(x)
        return x, x_ds


class DecoderSingleStage(torch.nn.Module):
    """
    Upsample by 2
    Resnet
    Extend channels
    """

    def __init__(self, ch: int, ch_out: int, hdim: Optional[int] = None, k_size=5):
        """Decoder stage

        Args:
            ch (int): channel size (downsampled & skip connection have same channel size)
            ch_out (int): number of output channels (shall match the number of input channels of the next stage)
            hdim (Optional[int], optional): Hidden dimension used in the residual block. Defaults to None.
            k_size (int, optional): Convolution size. Defaults to 5.
        Notes:
        ======
        ch_out = 2*ch/extension_factor

        self.scale_mixers_conv
        - tells how the lower decoded scale (x_ds) is merged with the current encoded scale (x_skip)
        - could be a pointwise aka conv1x1
        """

        super().__init__()
        hdim = hdim or ch
        self.scale_mixers_conv = torch.nn.Conv1d(2*ch, ch_out, k_size, padding=k_size//2)

        self.res_conv = ResConvolution(ch_out, hdim=hdim, k_size=k_size)
        # warning: Linear interpolation shall be "conjugated" with the skipping downsampling
        # special care shall be taken regarding offsets
        # https://arxiv.org/abs/1806.03185
        self.upsample = torch.nn.Upsample(scale_factor=2, mode="linear", align_corners=True)
        self.non_linearity = torch.nn.ReLU()

    def forward(self, x_ds: torch.Tensor, x_skip: torch.Tensor) -> torch.Tensor:
        """"""
        x_us = self.upsample(x_ds)  # [N, ch, T/2] -> [N, ch, T]

        x = torch.cat([x_us, x_skip], dim=1)  # [N, 2.ch, T]
        x = self.scale_mixers_conv(x)  # [N, ch_out, T]
        x = self.non_linearity(x)
        x = self.res_conv(x)  # [N, ch_out, T]
        return x


class ResUNet(SeparationModel):
    """Convolutional neural network for audio separation,

    Decimation, bottleneck
    """

    def __init__(self,
                 ch_in: int = 1,
                 ch_out: int = 2,
                 channels_extension: float = 1.5,
                 h_dim=16,
                 k_size=5,
                 ) -> None:
        super().__init__()
        self.need_split = ch_out != ch_in
        self.ch_out = ch_out
        self.source_modality_conv = torch.nn.Conv1d(ch_in, h_dim, k_size, padding=k_size//2)
        self.encoder_list = torch.nn.ModuleList()
        self.decoder_list = torch.nn.ModuleList()
        self.non_linearity = torch.nn.ReLU()

        h_dim_current = h_dim
        for _level in range(4):
            h_dim_ds = int(h_dim_current*channels_extension)
            self.encoder_list.append(EncoderSingleStage(h_dim_current, h_dim_ds, k_size=k_size))
            self.decoder_list.append(DecoderSingleStage(h_dim_ds, h_dim_current, k_size=k_size))
            h_dim_current = h_dim_ds
        self.bottleneck = ResConvolution(h_dim_current, k_size=k_size)
        self.target_modality_conv = torch.nn.Conv1d(h_dim, ch_out, 1)  # conv1x1 channel mixer

    def forward(self, x_in):
        # x_in (1, 2048)
        x0 = self.source_modality_conv(x_in)
        x0 = self.non_linearity(x0)
        # x0 -> (16, 2048)

        x1_skip, x1_ds = self.encoder_list[0](x0)
        # x1_skip -> (24, 2048)
        # x1_ds -> (24, 1024)
        # print(x1_skip.shape, x1_ds.shape)

        x2_skip, x2_ds = self.encoder_list[1](x1_ds)
        # x2_skip -> (36, 1024)
        # x2_ds -> (36, 512)
        # print(x2_skip.shape, x2_ds.shape)

        x3_skip, x3_ds = self.encoder_list[2](x2_ds)
        # x3_skip -> (54, 512)
        # x3_ds -> (54, 256)
        # print(x3_skip.shape, x3_ds.shape)

        x4_skip, x4_ds = self.encoder_list[3](x3_ds)
        # x4_skip -> (81, 256)
        # x4_ds -> (81, 128)
        # print(x4_skip.shape, x4_ds.shape)

        x4_dec = self.bottleneck(x4_ds)
        x3_dec = self.decoder_list[3](x4_dec, x4_skip)
        x2_dec = self.decoder_list[2](x3_dec, x3_skip)
        x1_dec = self.decoder_list[1](x2_dec, x2_skip)
        x0_dec = self.decoder_list[0](x1_dec, x1_skip)
        demuxed = self.target_modality_conv(x0_dec)
        # no relu
        if self.need_split:
            return torch.chunk(demuxed, self.ch_out, dim=1)
        return demuxed, None


if __name__ == "__main__":
    model = ResUNet()
    inp = torch.rand(2, 1, 2048)
    out = model(inp)
    print(model)
    print(model.count_parameters())
    print(out[0].shape)
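For reference (not part of the commit): with the default configuration the channel count grows by ×1.5 per scale (16 → 24 → 36 → 54 → 81) while the time axis halves (2048 → 1024 → 512 → 256 → 128 at the bottleneck), which matches the shape comments in forward. A quick sketch:

```python
import torch
from gyraudio.audio_separation.architecture.unet import ResUNet

model = ResUNet(ch_in=1, ch_out=2, h_dim=16, channels_extension=1.5)
voice, noise = model(torch.rand(2, 1, 2048))  # need_split is True since ch_out != ch_in
print(voice.shape, noise.shape)  # expected: [2, 1, 2048] each
```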
src/gyraudio/audio_separation/architecture/wave_unet.py
ADDED
@@ -0,0 +1,163 @@
import torch
from gyraudio.audio_separation.architecture.model import SeparationModel
from typing import Optional, Tuple


def get_non_linearity(activation: str):
    if activation == "LeakyReLU":
        non_linearity = torch.nn.LeakyReLU()
    else:
        non_linearity = torch.nn.ReLU()
    return non_linearity


class BaseConvolutionBlock(torch.nn.Module):
    def __init__(self, ch_in, ch_out: int, k_size: int, activation="LeakyReLU", dropout: float = 0, bias: bool = True) -> None:
        super().__init__()
        self.conv = torch.nn.Conv1d(ch_in, ch_out, k_size, padding=k_size//2, bias=bias)
        self.non_linearity = get_non_linearity(activation)
        self.dropout = torch.nn.Dropout1d(p=dropout)

    def forward(self, x_in: torch.Tensor) -> torch.Tensor:
        x = self.conv(x_in)  # [N, ch_in, T] -> [N, ch_in+channels_extension, T]
        x = self.non_linearity(x)
        x = self.dropout(x)
        return x


class EncoderStage(torch.nn.Module):
    """Conv (and extend channels), downsample 2 by skipping samples
    """

    def __init__(self, ch_in: int, ch_out: int, k_size: int = 15, dropout: float = 0, bias: bool = True) -> None:
        super().__init__()
        self.conv_block = BaseConvolutionBlock(ch_in, ch_out, k_size=k_size, dropout=dropout, bias=bias)

    def forward(self, x):
        x = self.conv_block(x)
        x_ds = x[..., ::2]
        # ch_out = ch_in+channels_extension
        return x, x_ds


class DecoderStage(torch.nn.Module):
    """Upsample by 2, Concatenate with skip connection, Conv (and shrink channels)
    """

    def __init__(self, ch_in: int, ch_out: int, k_size: int = 5, dropout: float = 0., bias: bool = True) -> None:
        """Decoder stage
        """
        super().__init__()
        self.conv_block = BaseConvolutionBlock(ch_in, ch_out, k_size=k_size, dropout=dropout, bias=bias)
        self.upsample = torch.nn.Upsample(scale_factor=2, mode="linear", align_corners=True)

    def forward(self, x_ds: torch.Tensor, x_skip: torch.Tensor) -> torch.Tensor:
        """"""
        x_us = self.upsample(x_ds)  # [N, ch, T/2] -> [N, ch, T]
        x = torch.cat([x_us, x_skip], dim=1)  # [N, 2.ch, T]
        x = self.conv_block(x)  # [N, ch_out, T]
        return x


class WaveUNet(SeparationModel):
    """UNET in temporal domain (waveform)
    = Multiscale convolutional neural network for audio separation
    https://arxiv.org/abs/1806.03185
    """

    def __init__(self,
                 ch_in: int = 1,
                 ch_out: int = 2,
                 channels_extension: int = 24,
                 k_conv_ds: int = 15,
                 k_conv_us: int = 5,
                 num_layers: int = 6,
                 dropout: float = 0.0,
                 bias: bool = True,
                 ) -> None:
        super().__init__()
        self.need_split = ch_out != ch_in
        self.ch_out = ch_out
        self.encoder_list = torch.nn.ModuleList()
        self.decoder_list = torch.nn.ModuleList()
        # Defining first encoder
        self.encoder_list.append(EncoderStage(ch_in, channels_extension, k_size=k_conv_ds, dropout=dropout, bias=bias))
        for level in range(1, num_layers+1):
            ch_i = level*channels_extension
            ch_o = (level+1)*channels_extension
            if level < num_layers:
                # Skipping last encoder since we defined the first one outside the loop
                self.encoder_list.append(EncoderStage(ch_i, ch_o, k_size=k_conv_ds, dropout=dropout, bias=bias))
            self.decoder_list.append(DecoderStage(ch_o+ch_i, ch_i, k_size=k_conv_us, dropout=dropout, bias=bias))
        self.bottleneck = BaseConvolutionBlock(
            num_layers*channels_extension,
            (num_layers+1)*channels_extension,
            k_size=k_conv_ds,
            dropout=dropout,
            bias=bias)
        self.dropout = torch.nn.Dropout1d(p=dropout)
        self.target_modality_conv = torch.nn.Conv1d(
            channels_extension+ch_in, ch_out, 1, bias=bias)  # conv1x1 channel mixer

    def forward(self, x_in: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Forward UNET pass

        ```
        (1  , 2048)----------------->(24 , 2048) > (1 , 2048)
             v                           ^
        (24 , 1024)----------------->(48 , 1024)
             v                           ^
        (48 , 512 )----------------->(72 , 512 )
             v                           ^
        (72 , 256 )----------------->(96 , 256 )
             v                           ^
        (96 , 128 )----BOTTLENECK--->(120, 128 )
        ```

        """
        skipped_list = []
        ds_list = [x_in]
        for level, enc in enumerate(self.encoder_list):
            x_skip, x_ds = enc(ds_list[-1])
            skipped_list.append(x_skip)
            ds_list.append(x_ds.clone())
            # print(x_skip.shape, x_ds.shape)
        x_dec = self.bottleneck(ds_list[-1])
        for level, dec in enumerate(self.decoder_list[::-1]):
            x_dec = dec(x_dec, skipped_list[-1-level])
            # print(x_dec.shape)
        x_dec = torch.cat([x_dec, x_in], dim=1)
        # print(x_dec.shape)
        x_dec = self.dropout(x_dec)
        demuxed = self.target_modality_conv(x_dec)
        # print(demuxed.shape)
        if self.need_split:
            return torch.chunk(demuxed, self.ch_out, dim=1)
        return demuxed, None


# x_skip, x_ds
# (24, 2048), (24, 1024)
# (48, 1024), (48, 512 )
# (72, 512 ), (72, 256 )
# (96, 256 ), (96, 128 )

# (120, 128 )
# (96 , 256 )
# (72 , 512 )
# (48 , 1024)
# (24 , 2048)
# (25 , 2048) demuxed - after concat
# (1  , 2048)


if __name__ == "__main__":
    model = WaveUNet(ch_out=1, num_layers=9)
    inp = torch.rand(2, 1, 2048)
    out = model(inp)
    print(model)
    print(model.count_parameters())
    print(out[0].shape)
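As an aside (not part of the commit), the shape diagram in the docstring corresponds to a 4-layer configuration; a quick sketch to reproduce it, assuming an input of 2048 samples:

```python
import torch
from gyraudio.audio_separation.architecture.wave_unet import WaveUNet

# num_layers=4 reproduces the docstring diagram: 24/48/72/96 channels, bottleneck at 120
model = WaveUNet(ch_in=1, ch_out=1, channels_extension=24, num_layers=4)
out, _ = model(torch.rand(1, 1, 2048))  # ch_out == ch_in, so no voice/noise split
print(out.shape)  # expected: [1, 1, 2048]
```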
src/gyraudio/audio_separation/data/__init__.py
ADDED
@@ -0,0 +1,5 @@
from gyraudio.audio_separation.data.mixed import MixedAudioDataset
from gyraudio.audio_separation.data.remixed_fixed import RemixedFixedAudioDataset
from gyraudio.audio_separation.data.remixed_rnd import RemixedRandomAudioDataset
from gyraudio.audio_separation.data.single import SingleAudioDataset
from gyraudio.audio_separation.data.dataloader import get_dataloader, get_config_dataloader
src/gyraudio/audio_separation/data/dataloader.py
ADDED
@@ -0,0 +1,47 @@
from torch.utils.data import DataLoader
from gyraudio.audio_separation.data.mixed import MixedAudioDataset
from typing import Optional, List
from gyraudio.audio_separation.properties import (
    DATA_PATH, AUGMENTATION, SNR_FILTER, SHUFFLE, BATCH_SIZE, TRAIN, VALID, TEST, AUG_TRIM
)
from gyraudio import root_dir
RAW_AUDIO_ROOT = root_dir/"__data_source_separation"/"voice_origin"
MIXED_AUDIO_ROOT = root_dir/"__data_source_separation"/"source_separation"


def get_dataloader(configurations: dict, audio_dataset=MixedAudioDataset):
    dataloaders = {}
    for mode, configuration in configurations.items():
        dataset = audio_dataset(
            configuration[DATA_PATH],
            augmentation_config=configuration[AUGMENTATION],
            snr_filter=configuration[SNR_FILTER]
        )
        dl = DataLoader(
            dataset,
            shuffle=configuration[SHUFFLE],
            batch_size=configuration[BATCH_SIZE],
            collate_fn=dataset.collate_fn
        )
        dataloaders[mode] = dl
    return dataloaders


def get_config_dataloader(
        audio_root=MIXED_AUDIO_ROOT,
        mode: str = TRAIN,
        shuffle: Optional[bool] = None,
        batch_size: Optional[int] = 16,
        snr_filter: Optional[List[float]] = None,
        augmentation: dict = {}):
    audio_folder = audio_root/mode
    assert mode in [TRAIN, VALID, TEST]
    assert audio_folder.exists()
    config = {
        DATA_PATH: audio_folder,
        SHUFFLE: shuffle if shuffle is not None else (True if mode == TRAIN else False),
        AUGMENTATION: augmentation,
        SNR_FILTER: snr_filter,
        BATCH_SIZE: batch_size
    }
    return config
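A minimal usage sketch (illustration only; it assumes the full pre-mixed dataset tree is available locally, whereas this commit only ships two test samples under __data_source_separation):

```python
from gyraudio.audio_separation.data import get_dataloader, get_config_dataloader
from gyraudio.audio_separation.properties import TRAIN, TEST

# One config dict per split, then one DataLoader per split
dataloaders = get_dataloader({
    TRAIN: get_config_dataloader(mode=TRAIN, shuffle=True, batch_size=16),
    TEST: get_config_dataloader(mode=TEST, shuffle=False, batch_size=16),
})
mixed, voice, noise = next(iter(dataloaders[TEST]))  # each tensor is [N, 1, T]
```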
src/gyraudio/audio_separation/data/dataset.py
ADDED
@@ -0,0 +1,104 @@
from torch.utils.data import Dataset
from pathlib import Path
from typing import Optional
import torch
from torch.utils.data import default_collate
from typing import Tuple
from functools import partial
from gyraudio.audio_separation.properties import (
    AUG_AWGN, AUG_RESCALE, AUG_TRIM, LENGTHS, LENGTH_DIVIDER, TRIM_PROB
)


class AudioDataset(Dataset):
    def __init__(
        self,
        data_path: Path,
        augmentation_config: dict = {},
        snr_filter: Optional[float] = None,
        debug: bool = False
    ):
        self.debug = debug
        self.data_path = data_path
        self.augmentation_config = augmentation_config
        self.snr_filter = snr_filter
        self.load_data()
        self.length = len(self.file_list)
        self.collate_fn = None
        if AUG_TRIM in self.augmentation_config:
            self.collate_fn = partial(collate_fn_generic,
                                      lengths_lim=self.augmentation_config[AUG_TRIM][LENGTHS],
                                      length_divider=self.augmentation_config[AUG_TRIM][LENGTH_DIVIDER],
                                      trim_prob=self.augmentation_config[AUG_TRIM][TRIM_PROB])

    def filter_data(self, snr):
        if self.snr_filter is None:
            return True
        if snr in self.snr_filter:
            return True
        else:
            return False

    def load_data(self):
        raise NotImplementedError("load_data method must be implemented")

    def augment_data(self, mixed_audio_signal, clean_audio_signal, noise_audio_signal):
        if AUG_RESCALE in self.augmentation_config:
            current_amplitude = 0.5 + 1.5*torch.rand(1, device=mixed_audio_signal.device)
            # logging.debug(current_amplitude)
            mixed_audio_signal *= current_amplitude
            noise_audio_signal *= current_amplitude
            clean_audio_signal *= current_amplitude
        if AUG_AWGN in self.augmentation_config:
            # noise_std = self.augmentation_config[AUG_AWGN]["noise_std"]
            noise_std = 0.01
            current_noise_std = torch.randn(1) * noise_std
            # logging.debug(current_noise_std)
            extra_awgn = torch.randn(mixed_audio_signal.shape, device=mixed_audio_signal.device) * current_noise_std
            mixed_audio_signal = mixed_audio_signal+extra_awgn
            # Open question: should we add noise to the noise signal as well?

        return mixed_audio_signal, clean_audio_signal, noise_audio_signal

    def __len__(self):
        return self.length

    def __getitem__(self, idx: int) -> torch.Tensor:
        raise NotImplementedError("__getitem__ method must be implemented")


def collate_fn_generic(batch, lengths_lim, length_divider=1024, trim_prob=0.5) -> Tuple[torch.Tensor, torch.Tensor]:
    """Collate function to allow trimming (=cropping the time dimension) of the signals in a batch.

    Args:
        batch (list): A list of tuples (triplets), where each tuple contains:
            - mixed_audio_signal
            - clean_audio_signal
            - noise_audio_signal
        lengths_lim (list): A list containing a minimum length (index 0) and a maximum length (index 1)
        length_divider (int): the trimmed length is forced to be a multiple of this divider
        trim_prob (float): trimming probability

    Returns:
        - Tensor: A batch of mixed_audio_signal, trimmed to the same length.
        - Tensor: A batch of clean_audio_signal
        - Tensor: A batch of noise_audio_signal
    """

    # Collate first, then pick a common crop for the whole batch
    mixed_audio_signal, clean_audio_signal, noise_audio_signal = default_collate(batch)
    length = mixed_audio_signal[0].shape[-1]
    min_length, max_length = lengths_lim
    take_full_signal = torch.rand(1) > trim_prob
    if not take_full_signal:
        start = torch.randint(0, length-min_length, (1,))
        trim_length = torch.randint(min_length, min(max_length, length-start-1)+1, (1,))
        trim_length = trim_length-trim_length % length_divider
        end = start + trim_length
    else:
        start = 0
        end = length - length % length_divider
    mixed_audio_signal = mixed_audio_signal[..., start:end]
    clean_audio_signal = clean_audio_signal[..., start:end]
    noise_audio_signal = noise_audio_signal[..., start:end]
    return mixed_audio_signal, clean_audio_signal, noise_audio_signal
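A small illustration (not part of the commit) of the trimming collate function on a dummy batch; the tensor sizes below are arbitrary:

```python
import torch
from gyraudio.audio_separation.data.dataset import collate_fn_generic

# Fake batch of 4 (mixed, voice, noise) triplets, each 1 x 8000 samples
batch = [(torch.rand(1, 8000), torch.rand(1, 8000), torch.rand(1, 8000)) for _ in range(4)]
mixed, voice, noise = collate_fn_generic(batch, lengths_lim=[2048, 4096], length_divider=1024, trim_prob=1.0)
# With trim_prob=1 a random crop is always taken; its length is a multiple of 1024 within [2048, 4096]
print(mixed.shape, voice.shape, noise.shape)
```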
src/gyraudio/audio_separation/data/mixed.py
ADDED
@@ -0,0 +1,40 @@
from gyraudio.audio_separation.data.dataset import AudioDataset
import logging
import torch
import torchaudio
from typing import Tuple


class MixedAudioDataset(AudioDataset):
    def load_data(self):
        self.folder_list = sorted(list(self.data_path.iterdir()))
        self.file_list = [
            [
                list(folder.glob("mix*.wav"))[0],
                folder/"voice.wav",
                folder/"noise.wav"
            ] for folder in self.folder_list
        ]
        snr_list = [float(file[0].stem.split("_")[-1]) for file in self.file_list]
        self.file_list = [files for snr, files in zip(snr_list, self.file_list) if self.filter_data(snr)]
        if self.debug:
            logging.info(f"Available SNR {set(snr_list)}")
            print(f"Available SNR {set(snr_list)}")
            print("Filtered", len(self.file_list), self.snr_filter)
        self.sampling_rate = None

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        mixed_audio_path, signal_path, noise_path = self.file_list[idx]
        assert mixed_audio_path.exists()
        assert signal_path.exists()
        assert noise_path.exists()
        mixed_audio_signal, sampling_rate = torchaudio.load(str(mixed_audio_path))
        clean_audio_signal, sampling_rate = torchaudio.load(str(signal_path))
        noise_audio_signal, sampling_rate = torchaudio.load(str(noise_path))
        self.sampling_rate = sampling_rate
        mixed_audio_signal, clean_audio_signal, noise_audio_signal = self.augment_data(mixed_audio_signal, clean_audio_signal, noise_audio_signal)
        if self.debug:
            logging.debug(f"{mixed_audio_signal.shape}")
            logging.debug(f"{clean_audio_signal.shape}")
            logging.debug(f"{noise_audio_signal.shape}")
        return mixed_audio_signal, clean_audio_signal, noise_audio_signal
src/gyraudio/audio_separation/data/remixed.py
ADDED
@@ -0,0 +1,53 @@
from gyraudio.audio_separation.data.dataset import AudioDataset
from typing import Tuple
import logging
from torch import Tensor
import torch
import torchaudio


class RemixedAudioDataset(AudioDataset):
    def generate_snr_list(self):
        self.snr_list = None

    def load_data(self):
        self.folder_list = sorted(list(self.data_path.iterdir()))
        self.file_list = [
            [
                folder/"voice.wav",
                folder/"noise.wav"
            ] for folder in self.folder_list
        ]
        self.sampling_rate = None
        self.min_snr, self.max_snr = -4, 4
        self.generate_snr_list()
        if self.debug:
            print("Not filtered", len(self.file_list), self.snr_filter)
            print(self.snr_list)

    def get_idx_noise(self, idx):
        raise NotImplementedError("get_idx_noise method must be implemented")

    def get_snr(self, idx):
        raise NotImplementedError("get_snr method must be implemented")

    def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor, Tensor]:
        signal_path = self.file_list[idx][0]
        idx_noise = self.get_idx_noise(idx)
        noise_path = self.file_list[idx_noise][1]

        assert signal_path.exists()
        assert noise_path.exists()
        clean_audio_signal, sampling_rate = torchaudio.load(str(signal_path))
        noise_audio_signal, sampling_rate = torchaudio.load(str(noise_path))
        snr = self.get_snr(idx)
        alpha = 10 ** (-snr / 20) * torch.norm(clean_audio_signal) / torch.norm(noise_audio_signal)
        mixed_audio_signal = clean_audio_signal + alpha*noise_audio_signal
        self.sampling_rate = sampling_rate
        mixed_audio_signal, clean_audio_signal, noise_audio_signal = self.augment_data(
            mixed_audio_signal, clean_audio_signal, noise_audio_signal)
        if self.debug:
            logging.debug(f"{mixed_audio_signal.shape}")
            logging.debug(f"{clean_audio_signal.shape}")
            logging.debug(f"{noise_audio_signal.shape}")
        return mixed_audio_signal, clean_audio_signal, noise_audio_signal
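A quick numerical check (illustration only) of the alpha scaling used in __getitem__: with alpha = 10^(-snr/20) · ||clean|| / ||noise||, the scaled noise has norm 10^(-snr/20) · ||clean||, so 20·log10(||clean|| / ||alpha·noise||) recovers the target SNR in dB:

```python
import torch

clean, noise = torch.rand(1, 8000) - 0.5, torch.rand(1, 8000) - 0.5
snr = 2.0  # target SNR in dB
alpha = 10 ** (-snr / 20) * torch.norm(clean) / torch.norm(noise)
measured = 20 * torch.log10(torch.norm(clean) / torch.norm(alpha * noise))
print(measured)  # expected: ~2.0 dB
```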
src/gyraudio/audio_separation/data/remixed_fixed.py
ADDED
@@ -0,0 +1,18 @@
from gyraudio.audio_separation.data.remixed import RemixedAudioDataset
import torch


class RemixedFixedAudioDataset(RemixedAudioDataset):
    def generate_snr_list(self):
        rnd_gen = torch.Generator()
        rnd_gen.manual_seed(2147483647)
        if self.snr_filter is None:
            self.snr_list = self.min_snr + (self.max_snr - self.min_snr)*torch.rand(len(self.file_list), generator=rnd_gen)
        else:
            indices = torch.randint(0, len(self.snr_filter), (len(self.file_list),), generator=rnd_gen)
            self.snr_list = [self.snr_filter[idx] for idx in indices]

    def get_idx_noise(self, idx):
        return idx

    def get_snr(self, idx):
        return self.snr_list[idx]
src/gyraudio/audio_separation/data/remixed_rnd.py
ADDED
@@ -0,0 +1,12 @@
from gyraudio.audio_separation.data.remixed import RemixedAudioDataset
from torch import rand, randint


class RemixedRandomAudioDataset(RemixedAudioDataset):
    def get_idx_noise(self, idx):
        return randint(0, len(self.file_list)-1, (1,))

    def get_snr(self, idx):
        if self.snr_filter is None:
            return self.min_snr + (self.max_snr - self.min_snr)*rand(1)
        return self.snr_filter[randint(0, len(self.snr_filter), (1,))]
src/gyraudio/audio_separation/data/silence_detector.py
ADDED
@@ -0,0 +1,55 @@
import torch
import matplotlib.pyplot as plt
import numpy as np


def get_silence_mask(
        sig: torch.Tensor, morph_kernel_size: int = 499, k_smooth=21, thresh=0.0001,
        debug: bool = False) -> torch.Tensor:
    with torch.no_grad():
        smooth = torch.nn.Conv1d(1, 1, k_smooth, padding=k_smooth//2, bias=False).to(sig.device)
        smooth.weight.data.fill_(1./k_smooth)
        smoothed = smooth(torch.abs(sig))
        st = 1.*(torch.abs(smoothed) < thresh*torch.ones_like(smoothed, device=sig.device))
        sig_dil = torch.nn.MaxPool1d(morph_kernel_size, stride=1, padding=morph_kernel_size//2)(st)
        sig_ero = -torch.nn.MaxPool1d(morph_kernel_size, stride=1, padding=morph_kernel_size//2)(-sig_dil)
        if debug:
            return sig_ero.squeeze(0), smoothed.squeeze(0), st.squeeze(0)
        else:
            return sig_ero


def visualize_silence_mask(sig: torch.Tensor, silence_thresh: float = 0.0001):
    silence_thresh = 0.0001
    silence_mask, smoothed_amplitude, _ = get_silence_mask(
        sig, k_smooth=21, morph_kernel_size=499, thresh=silence_thresh, debug=True
    )
    plt.figure(figsize=(12, 4))
    plt.subplot(121)
    plt.plot(sig.squeeze(0).cpu().numpy(), "k-", label="voice", alpha=0.5)
    plt.plot(0.01*silence_mask.cpu().numpy(), "r-", alpha=1., label="silence mask")
    plt.grid()
    plt.legend()
    plt.title("Voice and silence mask")
    plt.ylim(-0.04, 0.04)

    plt.subplot(122)
    plt.plot(smoothed_amplitude.cpu().numpy(), "g--", alpha=0.5, label="smoothed amplitude")
    plt.plot(np.ones(silence_mask.shape[-1])*silence_thresh, "c--", alpha=1., label="threshold")
    plt.plot(-silence_thresh+silence_thresh*silence_mask.cpu().numpy(), "r-", alpha=1, label="silence mask")
    plt.grid()
    plt.legend()
    plt.title("Thresholding mechanism")
    plt.ylim(-silence_thresh, silence_thresh*10)
    plt.show()


if __name__ == "__main__":
    from gyraudio.default_locations import SAMPLE_ROOT
    from gyraudio.audio_separation.visualization.pre_load_audio import audio_loading
    from gyraudio.audio_separation.properties import CLEAN, BUFFERS
    sample_folder = SAMPLE_ROOT/"0009"
    signals = audio_loading(sample_folder, preload=True)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    sig_in = signals[BUFFERS][CLEAN].to(device)
    visualize_silence_mask(sig_in)
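Illustration (not part of the commit): the detector smooths the absolute signal with a box filter, thresholds it, then applies a morphological closing (max-pool dilation followed by erosion) so short dips in amplitude do not fragment the silence mask. A sketch on a synthetic signal:

```python
import torch
from gyraudio.audio_separation.data.silence_detector import get_silence_mask

# Synthetic 1-channel signal: loud start, near-silent middle, loud end
sig = torch.zeros(1, 16000)
sig[:, :6000] = 0.1 * torch.randn(6000)
sig[:, 12000:] = 0.1 * torch.randn(4000)
mask = get_silence_mask(sig, morph_kernel_size=499, k_smooth=21, thresh=0.0001)
print(mask.shape, mask.sum())  # mask is ~1 over the silent middle region, 0 elsewhere
```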
src/gyraudio/audio_separation/data/single.py
ADDED
@@ -0,0 +1,15 @@
from gyraudio.audio_separation.data.dataset import AudioDataset
import logging
import torchaudio


class SingleAudioDataset(AudioDataset):
    def load_data(self):
        self.file_list = sorted(list(self.data_path.glob("*.wav")))

    def __getitem__(self, idx: int):
        audio_path = self.file_list[idx]
        assert audio_path.exists()
        audio_signal, sampling_rate = torchaudio.load(str(audio_path))
        logging.debug(f"{audio_signal.shape}")
        return audio_signal
src/gyraudio/audio_separation/experiment_tracking/experiments.py
ADDED
@@ -0,0 +1,122 @@
from gyraudio.default_locations import MIXED_AUDIO_ROOT
from gyraudio.audio_separation.properties import (
    TRAIN, TEST, VALID, NAME, EPOCHS, LEARNING_RATE,
    OPTIMIZER, BATCH_SIZE, DATALOADER, AUGMENTATION,
    SHORT_NAME, AUG_TRIM, TRIM_PROB, LENGTH_DIVIDER, LENGTHS, SNR_FILTER
)
from gyraudio.audio_separation.data.remixed_fixed import RemixedFixedAudioDataset
from gyraudio.audio_separation.data.remixed_rnd import RemixedRandomAudioDataset
from gyraudio.audio_separation.data import get_dataloader, get_config_dataloader
from gyraudio.audio_separation.experiment_tracking.experiments_definition import get_experiment_generator
import torch
from typing import Tuple


def get_experience(exp_major: int, exp_minor: int = 0, snr_filter_test=None, dry_run=False) -> Tuple[str, torch.nn.Module, dict, dict]:
    """Get all experience details

    Args:
        exp_major (int): Major experience number
        exp_minor (int, optional): Used for HP search. Defaults to 0.

    Returns:
        Tuple[str, torch.nn.Module, dict, dict]: short_name, model, config, dataloaders
    """
    model = None
    config = {}
    dataloader_name = "remix"
    config = {
        NAME: None,
        OPTIMIZER: {
            NAME: "adam",
            LEARNING_RATE: 0.001
        },
        EPOCHS: 60,
        DATALOADER: {
            NAME: dataloader_name,
        },
        BATCH_SIZE: [16, 16, 16],
        SNR_FILTER: snr_filter_test
    }

    model, config = get_experiment_generator(exp_major=exp_major)(config, no_model=dry_run, minor=exp_minor)
    # POST PROCESSING
    if isinstance(config[BATCH_SIZE], list) or isinstance(config[BATCH_SIZE], tuple):
        config[BATCH_SIZE] = {
            TRAIN: config[BATCH_SIZE][0],
            TEST: config[BATCH_SIZE][1],
            VALID: config[BATCH_SIZE][2],
        }

    if config[DATALOADER][NAME] == "premix":
        mixed_audio_root = MIXED_AUDIO_ROOT
        dataloaders = get_dataloader({
            TRAIN: get_config_dataloader(
                audio_root=mixed_audio_root,
                mode=TRAIN,
                shuffle=True,
                batch_size=config[BATCH_SIZE][TRAIN],
                augmentation=config[DATALOADER].get(AUGMENTATION, {})
            ),
            TEST: get_config_dataloader(
                audio_root=mixed_audio_root,
                mode=TEST,
                shuffle=False,
                batch_size=config[BATCH_SIZE][TEST],
                snr_filter=config[SNR_FILTER]
            )
        })
    elif config[DATALOADER][NAME] == "remix":
        mixed_audio_root = MIXED_AUDIO_ROOT
        aug_test = {}
        if AUG_TRIM in config[DATALOADER].get(AUGMENTATION, {}):
            aug_test = {
                AUG_TRIM: {
                    LENGTHS: [None, None],
                    LENGTH_DIVIDER: config[DATALOADER][AUGMENTATION][AUG_TRIM][LENGTH_DIVIDER],
                    TRIM_PROB: -1.
                }
            }
        dl_train = get_dataloader(
            {
                TRAIN: get_config_dataloader(
                    audio_root=mixed_audio_root,
                    mode=TRAIN,
                    shuffle=True,
                    batch_size=config[BATCH_SIZE][TRAIN],
                    augmentation=config[DATALOADER].get(AUGMENTATION, {})
                )
            },
            audio_dataset=RemixedRandomAudioDataset
        )[TRAIN]
        dl_test = get_dataloader(
            {
                TEST: get_config_dataloader(
                    audio_root=mixed_audio_root,
                    mode=TEST,
                    shuffle=False,
                    batch_size=config[BATCH_SIZE][TEST]
                )
            },
            audio_dataset=RemixedFixedAudioDataset
        )[TEST]
        dataloaders = {
            TRAIN: dl_train,
            TEST: dl_test
        }
    else:
        raise NotImplementedError(f"Unknown dataloader {dataloader_name}")
    assert config[NAME] is not None

    short_name = f"{exp_major:04d}_{exp_minor:04d}"
    config[SHORT_NAME] = short_name
    return short_name, model, config, dataloaders


if __name__ == "__main__":
    from gyraudio.audio_separation.parser import shared_parser
    parser_def = shared_parser()
    args = parser_def.parse_args()

    for exp in args.experiments:
        short_name, model, config, dl = get_experience(exp)
        print(short_name)
        print(config)
src/gyraudio/audio_separation/experiment_tracking/experiments_decorator.py
ADDED
@@ -0,0 +1,48 @@
import torch
from gyraudio.audio_separation.properties import (
    NAME, ANNOTATIONS, NB_PARAMS, RECEPTIVE_FIELD
)
from typing import Optional
REGISTERED_EXPERIMENTS_LIST = {}


# def count_parameters(model: torch.nn.Module) -> int:
#     """Count number of trainable parameters
#
#     Args:
#         model (torch.nn.Module): Pytorch model
#
#     Returns:
#         int: Number of trainable elements
#     """
#     return sum(p.numel() for p in model.parameters() if p.requires_grad)


def registered_experiment(major: Optional[int] = None, failed: Optional[bool] = False) -> callable:
    """Decorate and register an experiment
    - Register the experiment in the list of experiments
    - Count the number of parameters and add it to the config

    Args:
        major (Optional[int], optional): major id version = Number of the experiment. Defaults to None.
        failed (Optional[bool], optional): If an experiment failed,
            keep track of it but prevent from evaluating. Defaults to False.

    Returns:
        callable: decorator function
    """
    def decorator(func):
        assert (major) not in REGISTERED_EXPERIMENTS_LIST, f"Experiment {major} already registered"

        def wrapper(config, minor=None, no_model=False, model=torch.nn.Module()):
            config, model = func(config, model=None if not no_model else model, minor=minor)
            config[NB_PARAMS] = model.count_parameters()
            config[RECEPTIVE_FIELD] = model.receptive_field()
            assert NAME in config, "NAME not defined"
            assert ANNOTATIONS in config, "ANNOTATIONS not defined"
            return model, config
        if not failed:
            REGISTERED_EXPERIMENTS_LIST[major] = wrapper
        return wrapper

    return decorator
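A minimal sketch (illustration only; the experiment id and names below are hypothetical and not part of the commit) of how an experiment gets registered and what the wrapped generator returns:

```python
from gyraudio.audio_separation.experiment_tracking.experiments_decorator import (
    registered_experiment, REGISTERED_EXPERIMENTS_LIST
)
from gyraudio.audio_separation.architecture.flat_conv import FlatConvolutional
from gyraudio.audio_separation.properties import NAME, ANNOTATIONS


@registered_experiment(major=12345)  # hypothetical id
def exp_demo(config, model=None, minor=None):
    config[NAME] = "Demo - Flat Convolutional"
    config[ANNOTATIONS] = "Hypothetical demo entry"
    if model is None:
        model = FlatConvolutional()
    return config, model


# The wrapper fills NB_PARAMS / RECEPTIVE_FIELD before returning (model, config)
model, config = REGISTERED_EXPERIMENTS_LIST[12345]({})
print(type(model).__name__, list(config.keys()))
```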
src/gyraudio/audio_separation/experiment_tracking/experiments_definition.py
ADDED
@@ -0,0 +1,320 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from gyraudio.audio_separation.architecture.flat_conv import FlatConvolutional
from gyraudio.audio_separation.architecture.unet import ResUNet
from gyraudio.audio_separation.architecture.wave_unet import WaveUNet
from gyraudio.audio_separation.architecture.neutral import NeutralModel
from gyraudio.audio_separation.architecture.transformer import TransformerModel
from gyraudio.audio_separation.properties import (
    NAME, ANNOTATIONS, MAX_STEPS_PER_EPOCH, EPOCHS, BATCH_SIZE,
    OPTIMIZER, LEARNING_RATE,
    DATALOADER,
    WEIGHT_DECAY,
    LOSS, LOSS_L1,
    AUGMENTATION, AUG_TRIM, AUG_AWGN, AUG_RESCALE,
    LENGTHS, LENGTH_DIVIDER, TRIM_PROB,
    SCHEDULER, SCHEDULER_CONFIGURATION
)
from gyraudio.audio_separation.experiment_tracking.experiments_decorator import (
    registered_experiment, REGISTERED_EXPERIMENTS_LIST
)


@registered_experiment(major=9999)
def neutral(config, model: bool = None, minor=None):
    config[BATCH_SIZE] = [4, 4, 4]
    config[EPOCHS] = 1
    config[NAME] = "Neutral"
    config[ANNOTATIONS] = "Neutral"
    if model is None:
        model = NeutralModel()
    config[NAME] = "Neutral"
    return config, model


@registered_experiment(major=0)
def exp_unit_test(config, model: bool = None, minor=None):
    config[MAX_STEPS_PER_EPOCH] = 2
    config[BATCH_SIZE] = [4, 4, 4]
    config[EPOCHS] = 2
    config[NAME] = "Unit Test - Flat Convolutional"
    config[ANNOTATIONS] = "Baseline"
    config[SCHEDULER] = "ReduceLROnPlateau"
    config[SCHEDULER_CONFIGURATION] = dict(patience=5, factor=0.8)
    if model is None:
        model = FlatConvolutional()
    return config, model

# ---------------- Low Baseline -----------------


def exp_low_baseline(
        config: dict,
        batch_size: int = 16,
        h_dim: int = 16,
        k_size: int = 9,
        dilation: int = 0,
        model: bool = None,
        minor=None
):
    config[BATCH_SIZE] = [batch_size, batch_size, batch_size]
    config[NAME] = "Flat Convolutional"
    config[ANNOTATIONS] = f"Baseline H={h_dim}_K={k_size}"
    if dilation > 1:
        config[ANNOTATIONS] += f"_dil={dilation}"
    config["Architecture"] = {
        "name": "Flat-Conv",
        "h_dim": h_dim,
        "scales": 1,
        "k_size": k_size,
        "dilation": dilation
    }
    if model is None:
        model = FlatConvolutional(k_size=k_size, h_dim=h_dim)
    return config, model


@registered_experiment(major=1)
def exp_1(config, model: bool = None, minor=None):
    config, model = exp_low_baseline(config, batch_size=32, k_size=5)
    return config, model


@registered_experiment(major=2)
def exp_2(config, model: bool = None, minor=None):
    config, model = exp_low_baseline(config, batch_size=32, k_size=9)
    return config, model


@registered_experiment(major=3)
def exp_3(config, model: bool = None, minor=None):
    config, model = exp_low_baseline(config, batch_size=32, k_size=9, dilation=2)
    return config, model


@registered_experiment(major=4)
def exp_4(config, model: bool = None, minor=None):
    config, model = exp_low_baseline(config, batch_size=16, k_size=9)
    return config, model

# ------------------ Res U-Net ------------------


def exp_resunet(config, h_dim=16, k_size=5, model=None):
    config[NAME] = "Res-UNet"
    scales = 4
    config[ANNOTATIONS] = f"Res-UNet-{scales}scales_h={h_dim}_k={k_size}"
    config["Architecture"] = {
        "name": "Res-UNet",
        "h_dim": h_dim,
        "scales": scales,
        "k_size": k_size,
    }
    if model is None:
        model = ResUNet(h_dim=h_dim, k_size=k_size)
    return config, model


@registered_experiment(major=2000)
def exp_2000_waveunet(config, model: bool = None, minor=None):
    config[EPOCHS] = 60
    config, model = exp_resunet(config)
    return config, model


@registered_experiment(major=2001)
def exp_2001_waveunet(config, model: bool = None, minor=None):
    config[EPOCHS] = 60
    config, model = exp_resunet(config, h_dim=32, k_size=5)
    return config, model

# ------------------ Wave U-Net ------------------


def exp_wave_unet(config: dict,
                  channels_extension: int = 24,
                  k_conv_ds: int = 15,
                  k_conv_us: int = 5,
                  num_layers: int = 4,
                  dropout: float = 0.0,
                  bias: bool = True,
                  model=None):
    config[NAME] = "Wave-UNet"
    config[ANNOTATIONS] = f"Wave-UNet-{num_layers}scales_h_ext={channels_extension}_k={k_conv_ds}ds-{k_conv_us}us"
    if dropout > 0:
        config[ANNOTATIONS] += f"-dr{dropout:.1e}"
    if not bias:
        config[ANNOTATIONS] += "-BiasFree"
    config["Architecture"] = {
        "k_conv_us": k_conv_us,
        "k_conv_ds": k_conv_ds,
        "num_layers": num_layers,
        "channels_extension": channels_extension,
        "dropout": dropout,
        "bias": bias
    }
    if model is None:
        model = WaveUNet(
            **config["Architecture"]
        )
    config["Architecture"][NAME] = "Wave-UNet"
    return config, model


@registered_experiment(major=1000)
def exp_1000_waveunet(config, model: bool = None, minor=None):
    config[EPOCHS] = 60
    config, model = exp_wave_unet(config, model=model, num_layers=4, channels_extension=24)
    # 4 layers, ext +24 - Nvidia T500 4Gb RAM - 16 batch size
    return config, model


@registered_experiment(major=1001)
def exp_1001_waveunet(config, model: bool = None, minor=None):
    # OVERFIT 1M param ?
    config[EPOCHS] = 60
    config, model = exp_wave_unet(config, model=model, num_layers=7, channels_extension=16)
    # 7 layers, ext +16 - Nvidia T500 4Gb RAM - 16 batch size
    return config, model


@registered_experiment(major=1002)
def exp_1002_waveunet(config, model: bool = None, minor=None):
    # OVERFIT 1M param ?
    config[EPOCHS] = 60
    config, model = exp_wave_unet(config, model=model, num_layers=7, channels_extension=16)
    config[DATALOADER][AUGMENTATION] = {
        AUG_TRIM: {LENGTHS: [8192, 80000], LENGTH_DIVIDER: 1024, TRIM_PROB: 0.8},
        AUG_RESCALE: True
    }
    # 7 layers, ext +16 - Nvidia T500 4Gb RAM - 16 batch size
    return config, model


@registered_experiment(major=1003)
def exp_1003_waveunet(config, model: bool = None, minor=None):
    # OVERFIT 2.3M params
    config[EPOCHS] = 60
    config, model = exp_wave_unet(config, model=model, num_layers=7, channels_extension=24)
    # 7 layers, ext +24 - Nvidia RTX3060 6Gb RAM - 16 batch size
    return config, model


@registered_experiment(major=1004)
def exp_1004_waveunet(config, model: bool = None, minor=None):
    config[EPOCHS] = 120
    config, model = exp_wave_unet(config, model=model, num_layers=7, channels_extension=28)
    # 7 layers, ext +28 - Nvidia RTX3060 6Gb RAM - 16 batch size
    return config, model


@registered_experiment(major=1014)
def exp_1014_waveunet(config, model: bool = None, minor=None):
    # trained with min and max mixing snr hard coded between -2 and -1
    config[EPOCHS] = 50
    config, model = exp_wave_unet(config, model=model, num_layers=7, channels_extension=28)
    # 7 layers, ext +28 - Nvidia RTX3060 6Gb RAM - 16 batch size
    return config, model


@registered_experiment(major=1005)
def exp_1005_waveunet(config, model: bool = None, minor=None):
    config[EPOCHS] = 150
    config, model = exp_wave_unet(config, model=model, num_layers=7, channels_extension=16)
    config[DATALOADER][AUGMENTATION] = {
        AUG_TRIM: {LENGTHS: [8192, 80000], LENGTH_DIVIDER: 1024, TRIM_PROB: 0.8},
        AUG_RESCALE: True
    }
    # 7 layers, ext +16 - Nvidia T500 4Gb RAM - 16 batch size
    return config, model


@registered_experiment(major=1006)
def exp_1006_waveunet(config, model: bool = None, minor=None):
    config[EPOCHS] = 150
    config, model = exp_wave_unet(config, model=model, num_layers=11, channels_extension=16)
    config[DATALOADER][AUGMENTATION] = {
        AUG_TRIM: {LENGTHS: [8192, 80000], LENGTH_DIVIDER: 4096, TRIM_PROB: 0.8},
        AUG_RESCALE: True
    }
    # 11 layers, ext +16 - Nvidia T500 4Gb RAM - 16 batch size
    return config, model


@registered_experiment(major=1007)
def exp_1007_waveunet(config, model: bool = None, minor=None):
    config[EPOCHS] = 150
    config, model = exp_wave_unet(config, model=model, num_layers=9, channels_extension=16)
    config[DATALOADER][AUGMENTATION] = {
        AUG_TRIM: {LENGTHS: [8192, 80000], LENGTH_DIVIDER: 4096, TRIM_PROB: 0.8},
        AUG_RESCALE: True
    }
    # 9 layers, ext +16 - Nvidia T500 4Gb RAM - 16 batch size
    return config, model


@registered_experiment(major=1008)
def exp_1008_waveunet(config, model: bool = None, minor=None):
    # CHEAP BASELINE
    config[EPOCHS] = 150
    config, model = exp_wave_unet(config, model=model, num_layers=4, channels_extension=16)
    config[DATALOADER][AUGMENTATION] = {
        AUG_TRIM: {LENGTHS: [8192, 80000], LENGTH_DIVIDER: 1024, TRIM_PROB: 0.8},
        AUG_RESCALE: True
    }
    # 4 layers, ext +16 - Nvidia T500 4Gb RAM - 16 batch size
    return config, model


@registered_experiment(major=3000)
def exp_3000_waveunet(config, model: bool = None, minor=None):
    config[EPOCHS] = 120
    config, model = exp_wave_unet(config, model=model, num_layers=7, channels_extension=28, bias=False)
    # 7 layers, ext +28 - Nvidia RTX3060 6Gb RAM - 16 batch size
    return config, model


@registered_experiment(major=3001)
def exp_3001_waveunet(config, model: bool = None, minor=None):
    config[EPOCHS] = 200
    config, model = exp_wave_unet(config, model=model, num_layers=7, channels_extension=28, bias=False)
    # 7 layers, ext +28 - Nvidia RTX3060 6Gb RAM - 16 batch size
    config[SCHEDULER] = "ReduceLROnPlateau"
    config[SCHEDULER_CONFIGURATION] = dict(patience=5, factor=0.8)
    config[OPTIMIZER][LEARNING_RATE] = 0.002
    return config, model


@registered_experiment(major=3002)
def exp_3002_waveunet(config, model: bool = None, minor=None):
    # TRAINED WITH SNR -12db +12db (code changed manually!)
    # See f910c6da3123e3d35cc0ce588bb5a72ce4a8c422
    config[EPOCHS] = 200
    config, model = exp_wave_unet(config, model=model, num_layers=7, channels_extension=28, bias=False)
    # 7 layers, ext +28 - Nvidia RTX3060 6Gb RAM - 16 batch size
    config[SCHEDULER] = "ReduceLROnPlateau"
    config[SCHEDULER_CONFIGURATION] = dict(patience=5, factor=0.8)
    config[OPTIMIZER][LEARNING_RATE] = 0.002
    return config, model


@registered_experiment(major=4000)
def exp_4000_bias_free_waveunet_l1(config, model: bool = None, minor=None):
    # config[MAX_STEPS_PER_EPOCH] = 2
    # config[BATCH_SIZE] = [2, 2, 2]
    config[EPOCHS] = 200
    config[LOSS] = LOSS_L1
    config, model = exp_wave_unet(config, model=model, num_layers=7, channels_extension=28, bias=False)
    # 7 layers, ext +28 - Nvidia RTX3060 6Gb RAM - 16 batch size
    config[SCHEDULER] = "ReduceLROnPlateau"
    config[SCHEDULER_CONFIGURATION] = dict(patience=5, factor=0.8)
    config[OPTIMIZER][LEARNING_RATE] = 0.002
    return config, model


def get_experiment_generator(exp_major: int):
    assert exp_major in REGISTERED_EXPERIMENTS_LIST, f"Experiment {exp_major} not registered"
    exp_generator = REGISTERED_EXPERIMENTS_LIST[exp_major]
    return exp_generator


if __name__ == "__main__":
    print(f"Available experiments: {list(REGISTERED_EXPERIMENTS_LIST.keys())}")
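Registering an additional variant only requires one more decorated function on top of the helpers above; a minimal sketch (the major id 1099 and the hyper-parameters below are illustrative, not experiments that exist in this repository):

@registered_experiment(major=1099)
def exp_1099_waveunet_example(config, model: bool = None, minor=None):
    # Hypothetical variant: shallower Wave-UNet with a smaller channel extension
    config[EPOCHS] = 60
    config, model = exp_wave_unet(config, model=model, num_layers=5, channels_extension=16)
    return config, model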
src/gyraudio/audio_separation/experiment_tracking/storage.py
ADDED
@@ -0,0 +1,65 @@
from gyraudio.audio_separation.properties import SHORT_NAME, MODEL, OPTIMIZER, CURRENT_EPOCH, CONFIGURATION
from pathlib import Path
from gyraudio.default_locations import EXPERIMENT_STORAGE_ROOT
import logging
import torch


def get_output_folder(config: dict, root_dir: Path = EXPERIMENT_STORAGE_ROOT, override: bool = False) -> Path:
    output_folder = root_dir/config["short_name"]
    exists = False
    if output_folder.exists():
        if not override:
            logging.info(f"Experiment {config[SHORT_NAME]} already exists. Override is set to False. Skipping.")
        if override:
            logging.warning(f"Experiment {config[SHORT_NAME]} will be OVERRIDDEN")
            exists = True
    else:
        output_folder.mkdir(parents=True, exist_ok=True)
        exists = True
    return exists, output_folder


def checkpoint_paths(exp_dir: Path, epoch=None):
    if epoch is None:
        checkpoints = sorted(exp_dir.glob("model_*.pt"))
        assert len(checkpoints) > 0, f"No checkpoints found in {exp_dir}"
        model_checkpoint = checkpoints[-1]
        epoch = int(model_checkpoint.stem.split("_")[-1])
        optimizer_checkpoint = exp_dir/model_checkpoint.stem.replace("model", "optimizer")
    else:
        model_checkpoint = exp_dir/f"model_{epoch:04d}.pt"
        optimizer_checkpoint = exp_dir/f"optimizer_{epoch:04d}.pt"
    return model_checkpoint, optimizer_checkpoint, epoch


def load_checkpoint(model, exp_dir: Path, optimizer=None, epoch: int = None,
                    device="cuda" if torch.cuda.is_available() else "cpu"):
    config = {}
    model_checkpoint, optimizer_checkpoint, epoch = checkpoint_paths(exp_dir, epoch=epoch)
    model_state_dict = torch.load(model_checkpoint, map_location=torch.device(device))
    model.load_state_dict(model_state_dict[MODEL])
    if optimizer is not None:
        optimizer_state_dict = torch.load(optimizer_checkpoint, map_location=torch.device(device))
        optimizer.load_state_dict(optimizer_state_dict[OPTIMIZER])
        config = optimizer_state_dict[CONFIGURATION]
    return model, optimizer, epoch, config


def save_checkpoint(model, exp_dir: Path, optimizer=None, config: dict = {}, epoch: int = None):
    model_checkpoint, optimizer_checkpoint, epoch = checkpoint_paths(exp_dir, epoch=epoch)
    torch.save(
        {
            MODEL: model.state_dict(),
        },
        model_checkpoint
    )
    torch.save(
        {
            CURRENT_EPOCH: epoch,
            CONFIGURATION: config,
            OPTIMIZER: optimizer.state_dict()
        },
        optimizer_checkpoint
    )
    print(f"Checkpoint saved:\n - model: {model_checkpoint}\n - checkpoint: {optimizer_checkpoint}")
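A minimal save/load round trip with these helpers, as used by train.py and infer.py (the folder name and the model/optimizer objects below are placeholders for this sketch):

# Hypothetical usage of save_checkpoint / load_checkpoint
exp_dir = Path("__output_audiosep/9999_0000")  # placeholder experiment folder
save_checkpoint(model, exp_dir, optimizer=optimizer, config=config, epoch=0)
model, optimizer, epoch, config = load_checkpoint(model, exp_dir, optimizer=optimizer)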
src/gyraudio/audio_separation/infer.py
ADDED
@@ -0,0 +1,202 @@
from gyraudio.audio_separation.experiment_tracking.experiments import get_experience
from gyraudio.audio_separation.parser import shared_parser
from gyraudio.audio_separation.properties import TEST, NAME, SHORT_NAME, CURRENT_EPOCH, SNR_FILTER
from gyraudio.default_locations import EXPERIMENT_STORAGE_ROOT
from gyraudio.audio_separation.experiment_tracking.storage import load_checkpoint
from gyraudio.audio_separation.experiment_tracking.storage import get_output_folder
from gyraudio.audio_separation.metrics import snr
from gyraudio.io.dump import Dump
from pathlib import Path
import sys
import torch
from tqdm import tqdm
import torchaudio
import pandas as pd
from typing import List
# Files paths
DEFAULT_RECORD_FILE = "infer_record.csv"  # Store the characteristics of each inference run
DEFAULT_EVALUATION_FILE = "eval_df.csv"  # Store the per-sample evaluation results
# Record keys
NBATCH = "nb_batch"
BEST_SNR = "best_snr"
BEST_SAVE_SNR = "best_save_snr"
WORST_SNR = "worst_snr"
WORST_SAVE_SNR = "worst_save_snr"
RECORD_KEYS = [NAME, SHORT_NAME, CURRENT_EPOCH, NBATCH, SNR_FILTER, BEST_SAVE_SNR, BEST_SNR, WORST_SAVE_SNR, WORST_SNR]
# Evaluation keys
SAVE_IDX = "save_idx"
SNR_IN = "snr_in"
SNR_OUT = "snr_out"
EVAL_KEYS = [SAVE_IDX, SNR_IN, SNR_OUT]


def load_file(path: Path, keys: List[str]) -> pd.DataFrame:
    if not path.exists():
        df = pd.DataFrame(columns=keys)
        df.to_csv(path)
    return pd.read_csv(path)


def launch_infer(exp: int, snr_filter: list = None, device: str = "cuda", model_dir: Path = None,
                 output_dir: Path = EXPERIMENT_STORAGE_ROOT, force_reload=False, max_batches=None,
                 ext=".wav"):
    # Load experience
    if snr_filter is not None:
        snr_filter = sorted(snr_filter)
    short_name, model, config, dl = get_experience(exp, snr_filter_test=snr_filter)
    exists, exp_dir = get_output_folder(config, root_dir=model_dir, override=False)
    assert exp_dir.exists(), f"Experiment {short_name} does not exist in {model_dir}"
    model.eval()
    model.to(device)
    model, optimizer, epoch, config_checkpt = load_checkpoint(model, exp_dir, epoch=None, device=device)
    # Folder creation
    if output_dir is not None:
        record_path = output_dir/DEFAULT_RECORD_FILE
        record_df = load_file(record_path, RECORD_KEYS)

        # Define conditions for filtering
        exist_conditions = {
            NAME: config[NAME],
            SHORT_NAME: config[SHORT_NAME],
            CURRENT_EPOCH: epoch,
            NBATCH: max_batches,
        }
        # Create boolean masks and combine them
        masks = [(record_df[key] == value) for key, value in exist_conditions.items()]
        if snr_filter is None:
            masks.append(record_df[SNR_FILTER].isnull())
        else:
            masks.append(record_df[SNR_FILTER] == str(snr_filter))
        combined_mask = pd.Series(True, index=record_df.index)
        for mask in masks:
            combined_mask = combined_mask & mask
        filtered_df = record_df[combined_mask]

        save_dir = output_dir/(exp_dir.name+"_infer" + (f"_epoch_{epoch:04d}_nbatch_{max_batches if max_batches is not None else len(dl[TEST])}")
                               + ("" if snr_filter is None else f"_snrs_{'_'.join(map(str, snr_filter))}"))
        evaluation_path = save_dir/DEFAULT_EVALUATION_FILE
        if not filtered_df.empty and not force_reload:
            assert evaluation_path.exists()
            print(f"Inference already exists, see folder {save_dir}")
            record_row_df = filtered_df
        else:
            record_row_df = pd.DataFrame({
                NAME: config[NAME],
                SHORT_NAME: config[SHORT_NAME],
                CURRENT_EPOCH: epoch,
                NBATCH: max_batches,
                SNR_FILTER: [None],
            }, index=[0], columns=RECORD_KEYS)
            record_row_df.at[0, SNR_FILTER] = snr_filter

            save_dir.mkdir(parents=True, exist_ok=True)
            evaluation_df = load_file(evaluation_path, EVAL_KEYS)
            with torch.no_grad():
                test_loss = 0.
                save_idx = 0
                best_snr = -float("inf")  # largest SNR improvement seen so far
                worst_snr = float("inf")  # smallest (worst) SNR improvement seen so far
                processed_batches = 0
                for step_index, (batch_mix, batch_signal, batch_noise) in tqdm(
                        enumerate(dl[TEST]), desc=f"Inference epoch {epoch}", total=max_batches if max_batches is not None else len(dl[TEST])):
                    batch_mix, batch_signal, batch_noise = batch_mix.to(
                        device), batch_signal.to(device), batch_noise.to(device)
                    batch_output_signal, _batch_output_noise = model(batch_mix)
                    loss = torch.nn.functional.mse_loss(batch_output_signal, batch_signal)
                    test_loss += loss.item()

                    # SNR stats
                    snr_in = snr(batch_mix, batch_signal, reduce=None)
                    snr_out = snr(batch_output_signal, batch_signal, reduce=None)
                    best_current, best_idx_current = torch.max(snr_out-snr_in, axis=0)
                    worst_current, worst_idx_current = torch.min(snr_out-snr_in, axis=0)
                    if best_current > best_snr:
                        best_snr = best_current
                        best_save_idx = save_idx + best_idx_current
                    if worst_current < worst_snr:
                        worst_snr = worst_current
                        worst_save_idx = save_idx + worst_idx_current

                    # Save by signal
                    batch_output_signal = batch_output_signal.detach().cpu()
                    batch_signal = batch_signal.detach().cpu()
                    batch_mix = batch_mix.detach().cpu()
                    for audio_idx in range(batch_output_signal.shape[0]):
                        dic = {SAVE_IDX: save_idx, SNR_IN: float(
                            snr_in[audio_idx]), SNR_OUT: float(snr_out[audio_idx])}
                        new_eval_row = pd.DataFrame(dic, index=[0])
                        evaluation_df = pd.concat([new_eval_row, evaluation_df.loc[:]], ignore_index=True)

                        # Save .wav
                        torchaudio.save(
                            str(save_dir/f"{save_idx:04d}_mixed{ext}"),
                            batch_mix[audio_idx, :, :],
                            sample_rate=dl[TEST].dataset.sampling_rate,
                            channels_first=True
                        )
                        torchaudio.save(
                            str(save_dir/f"{save_idx:04d}_out{ext}"),
                            batch_output_signal[audio_idx, :, :],
                            sample_rate=dl[TEST].dataset.sampling_rate,
                            channels_first=True
                        )
                        torchaudio.save(
                            str(save_dir/f"{save_idx:04d}_original{ext}"),
                            batch_signal[audio_idx, :, :],
                            sample_rate=dl[TEST].dataset.sampling_rate,
                            channels_first=True
                        )
                        Dump.save_json(dic, save_dir/f"{save_idx:04d}.json")
                        save_idx += 1
                    processed_batches += 1
                    if max_batches is not None and processed_batches >= max_batches:
                        break
                test_loss = test_loss/len(dl[TEST])
                evaluation_df.to_csv(evaluation_path)

            record_row_df[BEST_SAVE_SNR] = int(best_save_idx)
            record_row_df[BEST_SNR] = float(best_snr)
            record_row_df[WORST_SAVE_SNR] = int(worst_save_idx)
            record_row_df[WORST_SNR] = float(worst_snr)
            record_df = pd.concat([record_row_df, record_df.loc[:]], ignore_index=True)
            record_df.to_csv(record_path, index=0)

            print(f"Test loss: {test_loss:.3e}, \nbest snr performance: {best_save_idx} with {best_snr:.1f}dB, \nworst snr performance: {worst_save_idx} with {worst_snr:.1f}dB")

    return record_row_df, evaluation_path


def main(argv):
    default_device = "cuda" if torch.cuda.is_available() else "cpu"
    parser_def = shared_parser(help="Launch inference on a specific model"
                               + ("\n<<<Cuda available>>>" if default_device == "cuda" else ""))
    parser_def.add_argument("-i", "--input-dir", type=str, default=EXPERIMENT_STORAGE_ROOT)
    parser_def.add_argument("-o", "--output-dir", type=str, default=EXPERIMENT_STORAGE_ROOT)
    parser_def.add_argument("-d", "--device", type=str, default=default_device,
                            help="Training device", choices=["cpu", "cuda"])
    parser_def.add_argument("-r", "--reload", action="store_true",
                            help="Force reload files")
    parser_def.add_argument("-b", "--nb-batch", type=int, default=None,
                            help="Number of batches to process")
    parser_def.add_argument("-s", "--snr-filter", type=float, nargs="+", default=None,
                            help="SNR filters on the inference dataloader")
    parser_def.add_argument("-ext", "--extension", type=str, default=".wav", help="Extension of the audio files to save",
                            choices=[".wav", ".mp4"])
    args = parser_def.parse_args(argv)
    for exp in args.experiments:
        launch_infer(
            exp,
            model_dir=Path(args.input_dir),
            output_dir=Path(args.output_dir),
            device=args.device,
            force_reload=args.reload,
            max_batches=args.nb_batch,
            snr_filter=args.snr_filter,
            ext=args.extension
        )


if __name__ == "__main__":
    main(sys.argv[1:])

# Example : python src\gyraudio\audio_separation\infer.py -i ./__output_audiosep -e 1002 -d cpu -b 2 -s 4 5 6
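Since eval_df.csv keeps one row per exported sample with its input and output SNR, aggregate statistics can be recomputed afterwards; a small sketch (the folder name below is a placeholder for whatever save_dir launch_infer printed):

# Hypothetical post-processing of an evaluation CSV produced by launch_infer
import pandas as pd
eval_df = pd.read_csv("some_experiment_infer_epoch_0059_nbatch_2/eval_df.csv")  # placeholder path
print((eval_df["snr_out"] - eval_df["snr_in"]).mean(), "dB mean SNR improvement")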
src/gyraudio/audio_separation/metrics.py
ADDED
@@ -0,0 +1,101 @@
from gyraudio.audio_separation.properties import SIGNAL, NOISE, TOTAL, LOSS_TYPE, COEFFICIENT, SNR
import torch


def snr(prediction: torch.Tensor, ground_truth: torch.Tensor, reduce="mean") -> torch.Tensor:
    """Compute the SNR between two tensors.
    Args:
        prediction (torch.Tensor): prediction tensor
        ground_truth (torch.Tensor): ground truth tensor
    Returns:
        torch.Tensor: SNR
    """
    power_signal = torch.sum(ground_truth**2, dim=(-2, -1))
    power_error = torch.sum((prediction-ground_truth)**2, dim=(-2, -1))
    eps = torch.finfo(torch.float32).eps
    snr_per_element = 10*torch.log10((power_signal+eps)/(power_error+eps))
    final_snr = torch.mean(snr_per_element) if reduce == "mean" else snr_per_element
    return final_snr


DEFAULT_COST = {
    SIGNAL: {
        COEFFICIENT: 0.5,
        LOSS_TYPE: torch.nn.functional.mse_loss
    },
    NOISE: {
        COEFFICIENT: 0.5,
        LOSS_TYPE: torch.nn.functional.mse_loss
    },
    SNR: {
        LOSS_TYPE: snr
    }
}


class Costs:
    """Keep track of cost functions.
    ```
    for epoch in range(...):
        metric.reset_epoch()
        for step in dataloader(...):
            ...  # forward
            prediction = model.forward(batch)
            metric.update(prediction1, groundtruth1, SIGNAL1)
            metric.update(prediction2, groundtruth2, SIGNAL2)
            loss = metric.finish_step()

            loss.backward()
            ...  # backprop
        metric.finish_epoch()
        ...  # log metrics
    ```
    """

    def __init__(self, name: str, costs=DEFAULT_COST) -> None:
        self.name = name
        self.keys = list(costs.keys())
        self.cost = costs

    def __reset_step(self) -> None:
        self.metrics = {key: 0. for key in self.keys}

    def reset_epoch(self) -> None:
        self.__reset_step()
        self.total_metric = {key: 0. for key in self.keys+[TOTAL]}
        self.count = 0

    def update(self,
               prediction: torch.Tensor,
               ground_truth: torch.Tensor,
               key: str
               ) -> torch.Tensor:
        assert key != TOTAL
        # Compute loss for a single batch (=step)
        loss_signal = self.cost[key][LOSS_TYPE](prediction, ground_truth)
        self.metrics[key] = loss_signal

    def finish_step(self) -> torch.Tensor:
        # Reset current total
        self.metrics[TOTAL] = 0.
        # Sum all metrics to total
        for key in self.metrics:
            if key != TOTAL and self.cost[key].get(COEFFICIENT, False):
                self.metrics[TOTAL] += self.cost[key][COEFFICIENT]*self.metrics[key]
        loss_signal = self.metrics[TOTAL]
        for key in self.metrics:
            if not isinstance(self.metrics[key], float):
                self.metrics[key] = self.metrics[key].item()
            self.total_metric[key] += self.metrics[key]
        self.count += 1
        return loss_signal

    def finish_epoch(self) -> None:
        for key in self.metrics:
            self.total_metric[key] /= self.count

    def __repr__(self) -> str:
        rep = f"{self.name}\t:\t"
        for key in self.total_metric:
            rep += f"{key}: {self.total_metric[key]:.3e} | "
        return rep
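To illustrate the convention used by snr() (higher is better, roughly 0 dB when the error power equals the signal power), a quick sanity check on toy tensors:

# Toy check of the snr() helper (shapes follow the (channels, samples) layout used above)
clean = torch.ones(1, 8000)
noisy = clean + 0.1 * torch.randn(1, 8000)
print(snr(noisy, clean))   # roughly 20 dB for a 0.1-std error
print(snr(clean, clean))   # very large value, capped by the eps term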
src/gyraudio/audio_separation/parser.py
ADDED
@@ -0,0 +1,8 @@
import argparse


def shared_parser(help="Train models for audio separation"):
    parser = argparse.ArgumentParser(description=help)
    parser.add_argument("-e", "--experiments", type=int, nargs="+", required=True,
                        help="Experiment ids to be trained sequentially")
    return parser
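Both train.py and infer.py below build on this shared parser; a minimal sketch of such an extension (the extra flag and the argument values are illustrative):

# Hypothetical script extending the shared parser
parser = shared_parser(help="My audio separation tool")
parser.add_argument("-d", "--device", type=str, default="cpu")
args = parser.parse_args(["-e", "1004", "-d", "cpu"])
print(args.experiments, args.device)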
src/gyraudio/audio_separation/properties.py
ADDED
@@ -0,0 +1,79 @@
# Training modes (Train, Validation, Test)
TRAIN = "train"
VALID = "validation"
TEST = "test"

# Dataset properties (keys)
DATA_PATH = "path"
BATCH_SIZE = "batch_size"
SHUFFLE = "shuffle"
SNR_FILTER = "snr_filter"
AUGMENTATION = "augmentation"
DATALOADER = "dataloader"


# Loss split
SIGNAL = "signal"
NOISE = "noise"
TOTAL = "total"
COEFFICIENT = "coefficient"


# Augmentation types
AUG_TRIM = "trim"  # trim batches to arbitrary length
AUG_AWGN = "awgn"  # add white gaussian noise
AUG_RESCALE = "rescale"  # rescale all signals arbitrarily

# Trim types
LENGTHS = "lengths"  # a list of min and max length
LENGTH_DIVIDER = "length_divider"  # an int that divides the length
TRIM_PROB = "trim_probability"  # a float in [0, 1] of trimming probability


# Training configuration properties (keys)

OPTIMIZER = "optimizer"
LEARNING_RATE = "lr"
WEIGHT_DECAY = "weight_decay"
BETAS = "betas"
EPOCHS = "epochs"
BATCH_SIZE = "batch_size"
MAX_STEPS_PER_EPOCH = "max_steps_per_epoch"


# Properties for the model
NAME = "name"
ANNOTATIONS = "annotations"
NB_PARAMS = "nb_params"
RECEPTIVE_FIELD = "receptive_field"
SHORT_NAME = "short_name"


# Scheduler
SCHEDULER = "scheduler"
SCHEDULER_CONFIGURATION = "scheduler_configuration"

# Loss
LOSS = "loss"
LOSS_L1 = "L1"
LOSS_L2 = "MSE"
LOSS_TYPE = "loss_type"
SNR = "snr"

# Checkpoint
MODEL = "model"
CURRENT_EPOCH = "current_epoch"
CONFIGURATION = "configuration"


# Signal names
CLEAN = "clean"
NOISY = "noise"
MIXED = "mixed"
PREDICTED = "predicted"


# MISC
PATHS = "paths"
BUFFERS = "buffers"
SAMPLING_RATE = "sampling_rate"
src/gyraudio/audio_separation/train.py
ADDED
@@ -0,0 +1,183 @@
from gyraudio.audio_separation.experiment_tracking.experiments import get_experience
from gyraudio.audio_separation.parser import shared_parser
from gyraudio.audio_separation.properties import (
    TRAIN, TEST, EPOCHS, OPTIMIZER, NAME, MAX_STEPS_PER_EPOCH,
    SIGNAL, NOISE, TOTAL, SNR, SCHEDULER, SCHEDULER_CONFIGURATION,
    LOSS, LOSS_L2, LOSS_L1, LOSS_TYPE, COEFFICIENT
)
from gyraudio.default_locations import EXPERIMENT_STORAGE_ROOT
from torch.optim.lr_scheduler import ReduceLROnPlateau
from gyraudio.audio_separation.experiment_tracking.storage import get_output_folder, save_checkpoint
from gyraudio.audio_separation.metrics import Costs, snr
# from gyraudio.audio_separation.experiment_tracking.storage import load_checkpoint
from pathlib import Path
from gyraudio.io.dump import Dump
import sys
import torch
from tqdm import tqdm
from copy import deepcopy
import wandb
import logging


def launch_training(exp: int, wandb_flag: bool = True, device: str = "cuda", save_dir: Path = None, override=False):

    short_name, model, config, dl = get_experience(exp)
    exists, output_folder = get_output_folder(config, root_dir=save_dir, override=override)
    if not exists:
        logging.warning(f"Skipping experiment {short_name}")
        return False
    else:
        logging.info(f"Experiment {short_name} saved in {output_folder}")

    print(short_name)
    print(config)
    logging.info(f"Starting training for {short_name}")
    logging.info(f"Config: {config}")
    if wandb_flag:
        wandb.init(
            project="audio-separation",
            entity="teammd",
            name=short_name,
            tags=["debug"],
            config=config
        )
    training_loop(model, config, dl, wandb_flag=wandb_flag, device=device, exp_dir=output_folder)
    if wandb_flag:
        wandb.finish()
    return True


def update_metrics(metrics, phase, pred, gt, pred_noise, gt_noise):
    metrics[phase].update(pred, gt, SIGNAL)
    metrics[phase].update(pred_noise, gt_noise, NOISE)
    metrics[phase].update(pred, gt, SNR)
    loss = metrics[phase].finish_step()
    return loss


def training_loop(model: torch.nn.Module, config: dict, dl, device: str = "cuda", wandb_flag: bool = False,
                  exp_dir: Path = None):
    optim_params = deepcopy(config[OPTIMIZER])
    optim_name = optim_params[NAME]
    optim_params.pop(NAME)
    if optim_name == "adam":
        optimizer = torch.optim.Adam(model.parameters(), **optim_params)
    else:
        raise NotImplementedError(f"Optimizer {optim_name} not implemented")

    scheduler = None
    scheduler_config = config.get(SCHEDULER_CONFIGURATION, {})
    scheduler_name = config.get(SCHEDULER, False)
    if scheduler_name:
        if scheduler_name == "ReduceLROnPlateau":
            scheduler = ReduceLROnPlateau(optimizer, mode='max', verbose=True, **scheduler_config)
            logging.info(f"Using scheduler {scheduler_name} with config {scheduler_config}")
        else:
            raise NotImplementedError(f"Scheduler {scheduler_name} not implemented")
    max_steps = config.get(MAX_STEPS_PER_EPOCH, None)
    chosen_loss = config.get(LOSS, LOSS_L2)
    if chosen_loss == LOSS_L2:
        costs = {TRAIN: Costs(TRAIN), TEST: Costs(TEST)}
    elif chosen_loss == LOSS_L1:
        cost_init = {
            SIGNAL: {
                COEFFICIENT: 0.5,
                LOSS_TYPE: torch.nn.functional.l1_loss
            },
            NOISE: {
                COEFFICIENT: 0.5,
                LOSS_TYPE: torch.nn.functional.l1_loss
            },
            SNR: {
                LOSS_TYPE: snr
            }
        }
        costs = {
            TRAIN: Costs(TRAIN, costs=cost_init),
            TEST: Costs(TEST)
        }
    for epoch in range(config[EPOCHS]):
        costs[TRAIN].reset_epoch()
        costs[TEST].reset_epoch()
        model.to(device)
        model.train()  # back to training mode (model.eval() is set before the validation loop below)
        # Training loop
        # -----------------------------------------------------------

        metrics = {TRAIN: {}, TEST: {}}
        for step_index, (batch_mix, batch_signal, batch_noise) in tqdm(
                enumerate(dl[TRAIN]), desc=f"Epoch {epoch}", total=len(dl[TRAIN])):
            if max_steps is not None and step_index >= max_steps:
                break
            batch_mix, batch_signal, batch_noise = batch_mix.to(device), batch_signal.to(device), batch_noise.to(device)
            model.zero_grad()
            batch_output_signal, batch_output_noise = model(batch_mix)
            loss = update_metrics(
                costs, TRAIN,
                batch_output_signal, batch_signal,
                batch_output_noise, batch_noise
            )
            # costs[TRAIN].update(batch_output_signal, batch_signal, SIGNAL)
            # costs[TRAIN].update(batch_output_noise, batch_noise, NOISE)
            # loss = costs[TRAIN].finish_step()
            loss.backward()
            optimizer.step()
        costs[TRAIN].finish_epoch()

        # Validation loop
        # -----------------------------------------------------------
        model.eval()
        torch.cuda.empty_cache()
        with torch.no_grad():
            for step_index, (batch_mix, batch_signal, batch_noise) in tqdm(
                    enumerate(dl[TEST]), desc=f"Epoch {epoch}", total=len(dl[TEST])):
                if max_steps is not None and step_index >= max_steps:
                    break
                batch_mix, batch_signal, batch_noise = batch_mix.to(
                    device), batch_signal.to(device), batch_noise.to(device)
                batch_output_signal, batch_output_noise = model(batch_mix)
                loss = update_metrics(
                    costs, TEST,
                    batch_output_signal, batch_signal,
                    batch_output_noise, batch_noise
                )
        costs[TEST].finish_epoch()
        if scheduler is not None and isinstance(scheduler, ReduceLROnPlateau):
            scheduler.step(costs[TEST].total_metric[SNR])
        print(f"epoch {epoch}:\n{costs[TRAIN]}\n{costs[TEST]}")
        wandblogs = {}
        if wandb_flag:
            for phase in [TRAIN, TEST]:
                wandblogs[f"{phase} loss signal"] = costs[phase].total_metric[SIGNAL]
                wandblogs[f"debug loss/{phase} loss signal"] = costs[phase].total_metric[SIGNAL]
                wandblogs[f"debug loss/{phase} loss total"] = costs[phase].total_metric[TOTAL]
                wandblogs[f"debug loss/{phase} loss noise"] = costs[phase].total_metric[NOISE]
                wandblogs[f"{phase} snr"] = costs[phase].total_metric[SNR]
            wandblogs["learning rate"] = optimizer.param_groups[0]['lr']
            wandb.log(wandblogs)
        metrics[TRAIN] = costs[TRAIN].total_metric
        metrics[TEST] = costs[TEST].total_metric
        Dump.save_json(metrics, exp_dir/f"metrics_{epoch:04d}.json")
        save_checkpoint(model, exp_dir, optimizer, config=config, epoch=epoch)
        torch.cuda.empty_cache()


def main(argv):
    default_device = "cuda" if torch.cuda.is_available() else "cpu"
    parser_def = shared_parser(help="Launch training \nCheck results at: https://wandb.ai/teammd/audio-separation"
                               + ("\n<<<Cuda available>>>" if default_device == "cuda" else ""))
    parser_def.add_argument("-nowb", "--no-wandb", action="store_true")
    parser_def.add_argument("-o", "--output-dir", type=str, default=EXPERIMENT_STORAGE_ROOT)
    parser_def.add_argument("-f", "--force", action="store_true", help="Override existing experiment")

    parser_def.add_argument("-d", "--device", type=str, default=default_device,
                            help="Training device", choices=["cpu", "cuda"])
    args = parser_def.parse_args(argv)
    for exp in args.experiments:
        launch_training(
            exp, wandb_flag=not args.no_wandb, save_dir=Path(args.output_dir),
            override=args.force,
            device=args.device
        )


if __name__ == "__main__":
    main(sys.argv[1:])
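By analogy with the example given at the end of infer.py, a typical training invocation might look like the line below (the experiment id and output folder are illustrative):

# Example : python src/gyraudio/audio_separation/train.py -e 1004 -o ./__output_audiosep -d cuda --no-wandb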
src/gyraudio/audio_separation/visualization/audio_player.py
ADDED
@@ -0,0 +1,84 @@
from gyraudio.audio_separation.properties import CLEAN, NOISY, MIXED, PREDICTED, SAMPLING_RATE
from pathlib import Path
from gyraudio.io.audio import save_audio_tensor
from gyraudio import root_dir
from interactive_pipe import Control, KeyboardControl
from interactive_pipe import interactive
import logging

HERE = Path(__file__).parent
MUTE = "mute"
LOGOS = {
    PREDICTED: HERE/"play_logo_pred.png",
    MIXED: HERE/"play_logo_mixed.png",
    CLEAN: HERE/"play_logo_clean.png",
    NOISY: HERE/"play_logo_noise.png",
    MUTE: HERE/"mute_logo.png",
}
ICONS = [it for key, it in LOGOS.items()]
KEYS = [key for key, it in LOGOS.items()]

ping_pong_index = 0


@interactive(
    player=Control(MUTE, KEYS, icons=ICONS))
def audio_selector(sig, mixed, pred, global_params={}, player=MUTE):

    global_params["selected_audio"] = player if player != MUTE else global_params.get("selected_audio", MIXED)
    global_params[MUTE] = player == MUTE
    if player == CLEAN:
        audio_track = sig["buffers"][CLEAN]
    elif player == NOISY:
        audio_track = sig["buffers"][NOISY]
    elif player == MIXED:
        audio_track = mixed
    elif player == PREDICTED:
        audio_track = pred
    else:
        audio_track = mixed
    return audio_track


@interactive(
    loop=KeyboardControl(True, keydown="l"))
def audio_trim(audio_track, global_params={}, loop=True):
    sampling_rate = global_params.get(SAMPLING_RATE, 8000)
    if global_params.get("trim", False):
        start, end = global_params["trim"]["start"], global_params["trim"]["end"]
        remainder = (end-start) % 8
        audio_trim = audio_track[..., start:end-remainder]
        repeat_factor = int(sampling_rate*4./(end-start))
        logging.debug(f"{repeat_factor}")
        repeat_factor = max(1, repeat_factor)
        if loop:
            repeat_factor = 1
        audio_trim = audio_trim.repeat(1, repeat_factor)
        logging.debug(f"{audio_trim.shape}")
    else:
        audio_trim = audio_track
    return audio_trim


@interactive(
    volume=(100, [0, 1000], "volume"),
)
def audio_player(audio_trim, global_params={}, volume=100):
    sampling_rate = global_params.get(SAMPLING_RATE, 8000)
    try:
        if global_params.get(MUTE, True):
            global_params["__stop"]()
            print("mute!")
        else:
            ping_pong_path = root_dir/"__ping_pong"
            ping_pong_path.mkdir(exist_ok=True)
            global ping_pong_index
            audio_track_path = ping_pong_path/f"_tmp_{ping_pong_index}.wav"
            ping_pong_index = (ping_pong_index + 1) % 10
            save_audio_tensor(audio_track_path, volume/100.*audio_trim,
                              sampling_rate=global_params.get(SAMPLING_RATE, sampling_rate))
            global_params["__set_audio"](audio_track_path)
            global_params["__play"]()
    except Exception as exc:
        logging.warning(f"Exception in audio_player {exc}")
        pass
src/gyraudio/audio_separation/visualization/interactive_audio.py
ADDED
@@ -0,0 +1,392 @@
1 |
+
from batch_processing import Batch
|
2 |
+
import argparse
|
3 |
+
from pathlib import Path
|
4 |
+
from gyraudio.audio_separation.experiment_tracking.experiments import get_experience
|
5 |
+
from gyraudio.audio_separation.experiment_tracking.storage import get_output_folder
|
6 |
+
from gyraudio.default_locations import EXPERIMENT_STORAGE_ROOT
|
7 |
+
from gyraudio.audio_separation.properties import (
|
8 |
+
SHORT_NAME, CLEAN, NOISY, MIXED, PREDICTED, ANNOTATIONS, PATHS, BUFFERS, SAMPLING_RATE, NAME
|
9 |
+
)
|
10 |
+
import torch
|
11 |
+
from gyraudio.audio_separation.experiment_tracking.storage import load_checkpoint
|
12 |
+
from gyraudio.audio_separation.visualization.pre_load_audio import (
|
13 |
+
parse_command_line_audio_load, load_buffers, audio_loading_batch)
|
14 |
+
from gyraudio.audio_separation.visualization.pre_load_custom_audio import (
|
15 |
+
parse_command_line_generic_audio_load, generic_audio_loading_batch,
|
16 |
+
load_buffers_custom
|
17 |
+
)
|
18 |
+
from torchaudio.functional import resample
|
19 |
+
from typing import List
|
20 |
+
import numpy as np
|
21 |
+
import logging
|
22 |
+
from interactive_pipe.data_objects.curves import Curve, SingleCurve
|
23 |
+
from interactive_pipe import interactive, KeyboardControl, Control
|
24 |
+
from interactive_pipe.headless.pipeline import HeadlessPipeline
|
25 |
+
from interactive_pipe.graphical.qt_gui import InteractivePipeQT
|
26 |
+
from interactive_pipe.graphical.mpl_gui import InteractivePipeMatplotlib
|
27 |
+
from gyraudio.audio_separation.visualization.audio_player import audio_selector, audio_trim, audio_player
|
28 |
+
|
29 |
+
default_device = "cuda" if torch.cuda.is_available() else "cpu"
|
30 |
+
LEARNT_SAMPLING_RATE = 8000
|
31 |
+
|
32 |
+
|
33 |
+
@interactive(
|
34 |
+
idx=KeyboardControl(value_default=0, value_range=[
|
35 |
+
0, 1000], modulo=True, keyup="8", keydown="2"),
|
36 |
+
idn=KeyboardControl(value_default=0, value_range=[
|
37 |
+
0, 1000], modulo=True, keyup="9", keydown="3")
|
38 |
+
)
|
39 |
+
def signal_selector(signals, idx=0, idn=0, global_params={}):
|
40 |
+
if isinstance(signals, dict):
|
41 |
+
clean_sigs = signals[CLEAN]
|
42 |
+
clean = clean_sigs[idx % len(clean_sigs)]
|
43 |
+
if BUFFERS not in clean:
|
44 |
+
load_buffers_custom(clean)
|
45 |
+
noise_sigs = signals[NOISY]
|
46 |
+
noise = noise_sigs[idn % len(noise_sigs)]
|
47 |
+
if BUFFERS not in noise:
|
48 |
+
load_buffers_custom(noise)
|
49 |
+
cbuf, nbuf = clean[BUFFERS], noise[BUFFERS]
|
50 |
+
if clean[SAMPLING_RATE] != LEARNT_SAMPLING_RATE:
|
51 |
+
cbuf = resample(cbuf, clean[SAMPLING_RATE], LEARNT_SAMPLING_RATE)
|
52 |
+
clean[SAMPLING_RATE] = LEARNT_SAMPLING_RATE
|
53 |
+
if noise[SAMPLING_RATE] != LEARNT_SAMPLING_RATE:
|
54 |
+
nbuf = resample(nbuf, noise[SAMPLING_RATE], LEARNT_SAMPLING_RATE)
|
55 |
+
noise[SAMPLING_RATE] = LEARNT_SAMPLING_RATE
|
56 |
+
min_length = min(cbuf.shape[-1], nbuf.shape[-1])
|
57 |
+
min_length = min_length - min_length % 1024
|
58 |
+
signal = {
|
59 |
+
PATHS: {
|
60 |
+
CLEAN: clean[PATHS],
|
61 |
+
NOISY: noise[PATHS]
|
62 |
+
|
63 |
+
},
|
64 |
+
BUFFERS: {
|
65 |
+
CLEAN: cbuf[..., :1, :min_length],
|
66 |
+
NOISY: nbuf[..., :1, :min_length],
|
67 |
+
},
|
68 |
+
NAME: f"Clean={clean[NAME]} | Noise={noise[NAME]}",
|
69 |
+
SAMPLING_RATE: LEARNT_SAMPLING_RATE
|
70 |
+
}
|
71 |
+
else:
|
72 |
+
# signals are loaded in CPU
|
73 |
+
signal = signals[idx % len(signals)]
|
74 |
+
if BUFFERS not in signal:
|
75 |
+
load_buffers(signal)
|
76 |
+
global_params["premixed_snr"] = signal.get("premixed_snr", None)
|
77 |
+
signal[NAME] = f"File={signal[NAME]}"
|
78 |
+
global_params["selected_info"] = signal[NAME]
|
79 |
+
global_params[SAMPLING_RATE] = signal[SAMPLING_RATE]
|
80 |
+
return signal
|
81 |
+
|
82 |
+
|
83 |
+
@interactive(
|
84 |
+
snr=(0., [-10., 10.], "SNR [dB]")
|
85 |
+
)
|
86 |
+
def remix(signals, snr=0., global_params={}):
|
87 |
+
signal = signals[BUFFERS][CLEAN]
|
88 |
+
noisy = signals[BUFFERS][NOISY]
|
89 |
+
alpha = 10 ** (-snr / 20) * torch.norm(signal) / torch.norm(noisy)
|
90 |
+
mixed_signal = signal + alpha * noisy
|
91 |
+
global_params["snr"] = snr
|
92 |
+
return mixed_signal
|
93 |
+
|
94 |
+
|
95 |
+
@interactive(std_dev=Control(0., value_range=[0., 0.1], name="extra noise std", step=0.0001),
|
96 |
+
amplify=(1., [0., 10.], "amplification of everything"))
|
97 |
+
def augment(signals, mixed, std_dev=0., amplify=1.):
|
98 |
+
signals[BUFFERS][MIXED] *= amplify
|
99 |
+
signals[BUFFERS][NOISY] *= amplify
|
100 |
+
signals[BUFFERS][CLEAN] *= amplify
|
101 |
+
mixed = mixed*amplify+torch.randn_like(mixed)*std_dev
|
102 |
+
return signals, mixed
|
103 |
+
|
104 |
+
|
105 |
+
# @interactive(
|
106 |
+
# device=("cuda", ["cpu", "cuda"]) if default_device == "cuda" else ("cpu", ["cpu"])
|
107 |
+
# )
|
108 |
+
def select_device(device=default_device, global_params={}):
|
109 |
+
global_params["device"] = device
|
110 |
+
|
111 |
+
|
112 |
+
@interactive(
|
113 |
+
model=KeyboardControl(value_default=0, value_range=[
|
114 |
+
0, 99], keyup="pagedown", keydown="pageup")
|
115 |
+
)
|
116 |
+
def audio_sep_inference(mixed, models, configs, model: int = 0, global_params={}):
|
117 |
+
selected_model = models[model % len(models)]
|
118 |
+
config = configs[model % len(models)]
|
119 |
+
short_name = config.get(SHORT_NAME, "")
|
120 |
+
annotations = config.get(ANNOTATIONS, "")
|
121 |
+
device = global_params.get("device", "cpu")
|
122 |
+
with torch.no_grad():
|
123 |
+
selected_model.eval()
|
124 |
+
selected_model.to(device)
|
125 |
+
predicted_signal, predicted_noise = selected_model(
|
126 |
+
mixed.to(device).unsqueeze(0))
|
127 |
+
predicted_signal = predicted_signal.squeeze(0)
|
128 |
+
pred_curve = SingleCurve(y=predicted_signal[0, :].detach().cpu().numpy(),
|
129 |
+
style="g-", label=f"predicted_{short_name} {annotations}")
|
130 |
+
return predicted_signal, pred_curve
|
131 |
+
|
132 |
+
|
133 |
+
def compute_metrics(pred, sig, global_params={}):
|
134 |
+
METRICS = "metrics"
|
135 |
+
target = sig[BUFFERS][CLEAN]
|
136 |
+
global_params[METRICS] = {}
|
137 |
+
global_params[METRICS]["MSE"] = torch.mean((target-pred.cpu())**2)
|
138 |
+
global_params[METRICS]["SNR"] = 10. * \
|
139 |
+
torch.log10(torch.sum(target**2)/torch.sum((target-pred.cpu())**2))
|
140 |
+
|
141 |
+
|
142 |
+
def get_trim(sig, zoom, center, num_samples=300):
|
143 |
+
N = len(sig)
|
144 |
+
native_ds = N/num_samples
|
145 |
+
center_idx = int(center*N)
|
146 |
+
window = int(num_samples/zoom*native_ds)
|
147 |
+
start_idx = max(0, center_idx - window//2)
|
148 |
+
end_idx = min(N, center_idx + window//2)
|
149 |
+
skip_factor = max(1, int(native_ds/zoom))
|
150 |
+
return start_idx, end_idx, skip_factor
|
151 |
+
|
152 |
+
|
153 |
+
def zin(sig, zoom, center, num_samples=300):
|
154 |
+
start_idx, end_idx, skip_factor = get_trim(
|
155 |
+
sig, zoom, center, num_samples=num_samples)
|
156 |
+
out = np.zeros(num_samples)
|
157 |
+
trimmed = sig[start_idx:end_idx:skip_factor]
|
158 |
+
out[:len(trimmed)] = trimmed[:num_samples]
|
159 |
+
return out
|
160 |
+
|
161 |
+
|
162 |
+
@interactive(
|
163 |
+
center=KeyboardControl(value_default=0.5, value_range=[
|
164 |
+
0., 1.], step=0.01, keyup="6", keydown="4"),
|
165 |
+
zoom=KeyboardControl(value_default=0., value_range=[
|
166 |
+
0., 15.], step=1, keyup="+", keydown="-"),
|
167 |
+
zoomy=KeyboardControl(
|
168 |
+
value_default=0., value_range=[-15., 15.], step=1, keyup="up", keydown="down")
|
169 |
+
)
|
170 |
+
def visualize_audio(signal: dict, mixed_signal, pred, zoom=1, zoomy=0., center=0.5, global_params={}):
|
171 |
+
"""Create curves
|
172 |
+
"""
|
173 |
+
zval = 1.5**zoom
|
174 |
+
start_idx, end_idx, _skip_factor = get_trim(
|
175 |
+
signal[BUFFERS][CLEAN][0, :], zval, center)
|
176 |
+
global_params["trim"] = dict(start=start_idx, end=end_idx)
|
177 |
+
selected = global_params.get("selected_audio", MIXED)
|
178 |
+
clean = SingleCurve(y=zin(signal[BUFFERS][CLEAN][0, :], zval, center),
|
179 |
+
alpha=1.,
|
180 |
+
style="k-",
|
181 |
+
linewidth=0.9,
|
182 |
+
label=("*" if selected == CLEAN else " ")+"clean")
|
183 |
+
noisy = SingleCurve(y=zin(signal[BUFFERS][NOISY][0, :], zval, center),
|
184 |
+
alpha=0.3,
|
185 |
+
style="y--",
|
186 |
+
linewidth=1,
|
187 |
+
label=("*" if selected == NOISY else " ") + "noisy"
|
188 |
+
)
|
189 |
+
mixed = SingleCurve(y=zin(mixed_signal[0, :], zval, center), style="r-",
|
190 |
+
alpha=0.1,
|
191 |
+
linewidth=2,
|
192 |
+
label=("*" if selected == MIXED else " ") + "mixed")
|
193 |
+
# true_mixed = SingleCurve(y=zin(signal[BUFFERS][MIXED][0, :], zval, center),
|
194 |
+
# alpha=0.3, style="b-", linewidth=1, label="true mixed")
|
195 |
+
pred.y = zin(pred.y, zval, center)
|
196 |
+
pred.label = ("*" if selected == PREDICTED else " ") + pred.label
|
197 |
+
curves = [noisy, mixed, pred, clean]
|
198 |
+
title = f"SNR in {global_params['snr']:.1f} dB"
|
199 |
+
if "selected_info" in global_params:
|
200 |
+
title += f" | {global_params['selected_info']}"
|
201 |
+
title += "\n"
|
202 |
+
for metric_name, metric_value in global_params.get("metrics", {}).items():
|
203 |
+
title += f" | {metric_name} "
|
204 |
+
title += f"{metric_value:.2e}" if (abs(metric_value) < 1e-2 or abs(metric_value)
|
205 |
+
> 1000) else f"{metric_value:.2f}"
|
206 |
+
# if global_params.get("premixed_snr", None) is not None:
|
207 |
+
# title += f"| Premixed SNR : {global_params['premixed_snr']:.1f} dB"
|
208 |
+
return Curve(curves, ylim=[-0.04 * 1.5 ** zoomy, 0.04 * 1.5 ** zoomy], xlabel="Time index", ylabel="Amplitude", title=title)
|
209 |
+
|
210 |
+
|
211 |
+

def interactive_audio_separation_processing(signals, model_list, config_list):
    sig = signal_selector(signals)
    mixed = remix(sig)
    # sig, mixed = augment(sig, mixed)
    select_device()
    pred, pred_curve = audio_sep_inference(mixed, model_list, config_list)
    compute_metrics(pred, sig)
    sound = audio_selector(sig, mixed, pred)
    curve = visualize_audio(sig, mixed, pred_curve)
    trimmed_sound = audio_trim(sound)
    audio_player(trimmed_sound)
    return curve


def interactive_audio_separation_visualization(
    all_signals: List[dict],
    model_list: List[torch.nn.Module],
    config_list: List[dict],
    gui="qt"
):
    pip = HeadlessPipeline.from_function(
        interactive_audio_separation_processing, cache=False)
    if gui == "qt":
        app = InteractivePipeQT(
            pipeline=pip, name="audio separation", size=(1000, 1000), audio=True)
    else:
        logging.warning("No support for audio player with Matplotlib")
        app = InteractivePipeMatplotlib(
            pipeline=pip, name="audio separation", size=None, audio=False)
    app(all_signals, model_list, config_list)


def visualization(
    all_signals: List[dict],
    model_list: List[torch.nn.Module],
    config_list: List[dict],
    device="cuda"
):
    for signal in all_signals:
        if BUFFERS not in signal:
            load_buffers(signal, device="cpu")
        clean = SingleCurve(y=signal[BUFFERS][CLEAN][0, :], label="clean")
        noisy = SingleCurve(y=signal[BUFFERS][NOISY][0, :], label="noise", alpha=0.3)
        curves = [clean, noisy]
        for config, model in zip(config_list, model_list):
            short_name = config.get(SHORT_NAME, "unknown")
            predicted_signal, predicted_noise = model(
                signal[BUFFERS][MIXED].to(device).unsqueeze(0))
            predicted = SingleCurve(y=predicted_signal.squeeze(0)[0, :].detach().cpu().numpy(),
                                    label=f"predicted_{short_name}")
            curves.append(predicted)
        Curve(curves).show()


def parse_command_line(parser: Batch = None, gradio_demo=True) -> argparse.ArgumentParser:
    if gradio_demo:
        parser = parse_command_line_gradio(parser)
    else:
        parser = parse_command_line_generic(parser)
    return parser


def parse_command_line_gradio(parser: Batch = None, gradio_demo=True) -> argparse.ArgumentParser:
    if parser is None:
        parser = parse_command_line_audio_load()
    default_device = "cuda" if torch.cuda.is_available() else "cpu"
    iparse = parser.add_argument_group("Audio separation visualization")
    iparse.add_argument("-e", "--experiments", type=int, nargs="+", default=[3001],  # list default to match nargs="+"
                        help="Experiment ids to be inferred sequentially")
    iparse.add_argument("-p", "--interactive", default=True,
                        action="store_true", help="Play = Interactive mode")
    iparse.add_argument("-m", "--model-root", type=str,
                        default=EXPERIMENT_STORAGE_ROOT)
    iparse.add_argument("-d", "--device", type=str, default=default_device,
                        choices=["cpu", "cuda"] if default_device == "cuda" else ["cpu"])
    iparse.add_argument("-gui", "--gui", type=str,
                        default="gradio", choices=["qt", "mpl", "gradio"])
    return parser


def parse_command_line_generic(parser: Batch = None, gradio_demo=True) -> argparse.ArgumentParser:
    if parser is None:
        parser = parse_command_line_audio_load()
    default_device = "cuda" if torch.cuda.is_available() else "cpu"
    iparse = parser.add_argument_group("Audio separation visualization")
    iparse.add_argument("-e", "--experiments", type=int, nargs="+", required=True,
                        help="Experiment ids to be inferred sequentially")
    iparse.add_argument("-p", "--interactive",
                        action="store_true", help="Play = Interactive mode")
    iparse.add_argument("-m", "--model-root", type=str,
                        default=EXPERIMENT_STORAGE_ROOT)
    iparse.add_argument("-d", "--device", type=str, default=default_device,
                        choices=["cpu", "cuda"] if default_device == "cuda" else ["cpu"])
    iparse.add_argument("-gui", "--gui", type=str,
                        default="qt", choices=["qt", "mpl", "gradio"])
    return parser


def main(argv: List[str]):
    """Paired signals and noise in folders"""
    batch = Batch(argv)
    batch.set_io_description(
        input_help='input audio files',
        output_help=argparse.SUPPRESS
    )
    batch.set_multiprocessing_enabled(False)
    parser = parse_command_line()
    args = batch.parse_args(parser)
    exp = args.experiments[0]
    device = args.device
    models_list = []
    config_list = []
    logging.info(f"Loading experiments models {args.experiments}")
    for exp in args.experiments:
        model_dir = Path(args.model_root)
        short_name, model, config, _dl = get_experience(exp)
        _, exp_dir = get_output_folder(
            config, root_dir=model_dir, override=False)
        assert exp_dir.exists(), f"Experiment {short_name} does not exist in {model_dir}"
        model.eval()
        model.to(device)
        model, __optimizer, epoch, config = load_checkpoint(
            model, exp_dir, epoch=None, device=args.device)
        config[SHORT_NAME] = short_name
        models_list.append(model)
        config_list.append(config)
    logging.info("Load audio buffers:")
    all_signals = batch.run(audio_loading_batch)
    if not args.interactive:
        visualization(all_signals, models_list, config_list, device=device)
    else:
        interactive_audio_separation_visualization(
            all_signals, models_list, config_list, gui=args.gui)


def main_custom(argv: List[str]):
    """Handle custom noise and custom signals"""
    parser = parse_command_line()
    parser.add_argument("-s", "--signal", type=str, required=True,
                        nargs="+", help="Signal to be preloaded")
    parser.add_argument("-n", "--noise", type=str, required=True,
                        nargs="+", help="Noise to be preloaded")
    args = parser.parse_args(argv)
    exp = args.experiments[0]
    device = args.device
    models_list = []
    config_list = []
    logging.info(f"Loading experiments models {args.experiments}")
    for exp in args.experiments:
        model_dir = Path(args.model_root)
        short_name, model, config, _dl = get_experience(exp)
        _, exp_dir = get_output_folder(
            config, root_dir=model_dir, override=False)
        assert exp_dir.exists(), f"Experiment {short_name} does not exist in {model_dir}"
        model.eval()
        model.to(device)
        model, __optimizer, epoch, config = load_checkpoint(
            model, exp_dir, epoch=None, device=args.device)
        config[SHORT_NAME] = short_name
        models_list.append(model)
        config_list.append(config)
    all_signals = {}
    for args_paths, key in zip([args.signal, args.noise], [CLEAN, NOISY]):
        new_argv = ["-i"] + args_paths
        if args.preload:
            new_argv += ["--preload"]
        batch = Batch(new_argv)
        new_parser = parse_command_line_generic_audio_load()
        batch.set_io_description(
            input_help=argparse.SUPPRESS,  # 'input audio files',
            output_help=argparse.SUPPRESS
        )
        batch.set_multiprocessing_enabled(False)
        _ = batch.parse_args(new_parser)
        all_signals[key] = batch.run(generic_audio_loading_batch)
    interactive_audio_separation_visualization(
        all_signals, models_list, config_list, gui=args.gui)
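
A minimal sketch of launching this entry point from Python; the input folder and experiment id are placeholders (assumptions), and the import path simply mirrors this file's location in the package:

    from gyraudio.audio_separation.visualization.interactive_audio import main

    # Open the Qt viewer on one folder containing voice.wav, noise.wav and mix_snr_*.wav
    main([
        "-i", "path/to/test/0000",   # placeholder folder consumed by the Batch wrapper
        "-e", "3001",                # placeholder experiment id
        "-d", "cpu",
        "-gui", "qt",
    ])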
src/gyraudio/audio_separation/visualization/interactive_infer.py
ADDED
@@ -0,0 +1,117 @@
from gyraudio.default_locations import EXPERIMENT_STORAGE_ROOT
from gyraudio.audio_separation.parser import shared_parser
from gyraudio.audio_separation.infer import launch_infer, RECORD_KEYS, SNR_OUT, SNR_IN, NBATCH, SAVE_IDX
from gyraudio.audio_separation.properties import TEST, NAME, SHORT_NAME, CURRENT_EPOCH, SNR_FILTER
import sys
import os
from dash import Dash, html, dcc, callback, Output, Input, dash_table
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import pandas as pd
from typing import List
import torch
from pathlib import Path

DIFF_SNR = 'SNR out - SNR in'


def get_app(record_row_dfs: pd.DataFrame, eval_dfs: List[pd.DataFrame]):
    app = Dash(__name__)
    # names_options = [{'label': f"{record[SHORT_NAME]} - {record[NAME]} epoch {record[CURRENT_EPOCH]:04d}",
    #                   'value': record[NAME]} for idx, record in record_row_dfs.iterrows()]
    app.layout = html.Div([
        html.H1(children='Inference results', style={'textAlign': 'center'}),
        # dcc.Dropdown(names_options, names_options[0]['value'], id='exp-selection'),
        # dcc.RadioItems(['scatter', 'box'], 'box', inline=True, id='radio-plot-type'),
        dcc.RadioItems([SNR_OUT, DIFF_SNR], DIFF_SNR, inline=True, id='radio-plot-out'),
        dcc.Graph(id='graph-content')
    ])

    @callback(
        Output('graph-content', 'figure'),
        # Input('exp-selection', 'value'),
        # Input('radio-plot-type', 'value'),
        Input('radio-plot-out', 'value'),
    )
    def update_graph(radio_plot_out):
        fig = make_subplots(rows=2, cols=1)
        colors = px.colors.qualitative.Plotly
        for id, record in record_row_dfs.iterrows():
            color = colors[id % len(colors)]
            eval_df = eval_dfs[id].sort_values(by=SNR_IN)
            eval_df[DIFF_SNR] = eval_df[SNR_OUT] - eval_df[SNR_IN]
            legend = f'{record[SHORT_NAME]}_{record[NAME]}'
            fig.add_trace(
                go.Scatter(
                    x=eval_df[SNR_IN],
                    y=eval_df[radio_plot_out],
                    mode="markers", marker={'color': color},
                    name=legend,
                    hovertemplate='File : %{text}' +
                    '<br>%{y}<br>',
                    text=[f"{eval[SAVE_IDX]:.0f}" for idx, eval in eval_df.iterrows()]
                ),
                row=1, col=1
            )
            eval_df_bins = eval_df
            eval_df_bins[SNR_IN] = eval_df_bins[SNR_IN].apply(lambda snr: round(snr))
            fig.add_trace(
                go.Box(
                    x=eval_df[SNR_IN],
                    y=eval_df[radio_plot_out],
                    fillcolor=color,
                    marker={'color': color},
                    name=legend
                ),
                row=2, col=1
            )

        title = "SNR performances"
        fig.update_layout(
            title=title,
            xaxis2_title=SNR_IN,
            yaxis_title=radio_plot_out,
            hovermode='x unified'
        )
        return fig

    return app


def main(argv):
    default_device = "cuda" if torch.cuda.is_available() else "cpu"
    parser_def = shared_parser(help="Launch inference\nCheck results at: https://wandb.ai/balthazarneveu/audio-sep"
                               + ("\n<<<Cuda available>>>" if default_device == "cuda" else ""))
    parser_def.add_argument("-i", "--input-dir", type=str, default=EXPERIMENT_STORAGE_ROOT)
    parser_def.add_argument("-o", "--output-dir", type=str, default=EXPERIMENT_STORAGE_ROOT)
    parser_def.add_argument("-d", "--device", type=str, default=default_device,
                            help="Inference device", choices=["cpu", "cuda"])
    parser_def.add_argument("-b", "--nb-batch", type=int, default=None,
                            help="Number of batches to process")
    parser_def.add_argument("-s", "--snr-filter", type=float, nargs="+", default=None,
                            help="SNR filters on the inference dataloader")
    args = parser_def.parse_args(argv)
    record_row_dfs = pd.DataFrame(columns=RECORD_KEYS)
    eval_dfs = []
    for exp in args.experiments:
        record_row_df, evaluation_path = launch_infer(
            exp,
            model_dir=Path(args.input_dir),
            output_dir=Path(args.output_dir),
            device=args.device,
            max_batches=args.nb_batch,
            snr_filter=args.snr_filter
        )
        eval_df = pd.read_csv(evaluation_path)
        # Careful, list order for concat is important for index matching eval_dfs list
        record_row_dfs = pd.concat([record_row_dfs.loc[:], record_row_df], ignore_index=True)
        eval_dfs.append(eval_df)
    app = get_app(record_row_dfs, eval_dfs)
    app.run(debug=True)


if __name__ == '__main__':
    os.environ["KMP_DUPLICATE_LIB_OK"] = "True"
    main(sys.argv[1:])
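
A minimal sketch of running this dashboard from Python rather than the shell; the experiment id is a placeholder and the -e/--experiments flag is assumed to be defined by shared_parser:

    from gyraudio.audio_separation.visualization.interactive_infer import main

    # Runs inference for the listed experiments, reads the per-file evaluation CSVs,
    # then serves the SNR scatter/box plots with Dash (http://127.0.0.1:8050 by default).
    main(["-e", "3001", "-d", "cpu", "-b", "2"])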
src/gyraudio/audio_separation/visualization/mute_logo.png
ADDED
src/gyraudio/audio_separation/visualization/play_logo_clean.png
ADDED
src/gyraudio/audio_separation/visualization/play_logo_mixed.png
ADDED
src/gyraudio/audio_separation/visualization/play_logo_noise.png
ADDED
src/gyraudio/audio_separation/visualization/play_logo_pred.png
ADDED
src/gyraudio/audio_separation/visualization/pre_load_audio.py
ADDED
@@ -0,0 +1,71 @@
from batch_processing import Batch
import argparse
import sys
from pathlib import Path
from gyraudio.audio_separation.properties import CLEAN, NOISY, MIXED, PATHS, BUFFERS, NAME, SAMPLING_RATE
from gyraudio.io.audio import load_audio_tensor


def parse_command_line_audio_load() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description='Batch audio processing',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("-preload", "--preload", action="store_true", help="Preload audio files")
    return parser


def outp(path: Path, suffix: str, extension=".wav"):
    return (path.parent / (path.stem + suffix)).with_suffix(extension)


def load_buffers(signal: dict, device="cpu") -> None:
    clean_signal, sampling_rate = load_audio_tensor(signal[PATHS][CLEAN], device=device)
    noisy_signal, sampling_rate = load_audio_tensor(signal[PATHS][NOISY], device=device)
    mixed_signal, sampling_rate = load_audio_tensor(signal[PATHS][MIXED], device=device)
    signal[BUFFERS] = {
        CLEAN: clean_signal,
        NOISY: noisy_signal,
        MIXED: mixed_signal
    }
    signal[SAMPLING_RATE] = sampling_rate


def audio_loading(input: Path, preload: bool) -> dict:
    name = input.name
    clean_audio_path = input/"voice.wav"
    noisy_audio_path = input/"noise.wav"
    mixed_audio_path = list(input.glob("mix*.wav"))[0]
    signal = {
        NAME: name,
        PATHS: {
            CLEAN: clean_audio_path,
            NOISY: noisy_audio_path,
            MIXED: mixed_audio_path
        }
    }
    signal["premixed_snr"] = float(mixed_audio_path.stem.split("_")[-1])
    if preload:
        load_buffers(signal)
    return signal


def audio_loading_batch(input: Path, args: argparse.Namespace) -> dict:
    """Wrapper to load audio files from a directory using batch_processing"""
    return audio_loading(input, preload=args.preload)


def main(argv):
    batch = Batch(argv)
    batch.set_io_description(
        input_help='input audio files',
        output_help=argparse.SUPPRESS
    )
    parser = parse_command_line_audio_load()
    batch.parse_args(parser)
    all_signals = batch.run(audio_loading_batch)
    return all_signals


if __name__ == "__main__":
    main(sys.argv[1:])
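
A minimal sketch of calling this loader directly, without the Batch wrapper; the folder path is a placeholder and must contain voice.wav, noise.wav and a mix_*.wav whose stem ends with the premixed SNR value:

    from pathlib import Path
    from gyraudio.audio_separation.properties import BUFFERS, CLEAN
    from gyraudio.audio_separation.visualization.pre_load_audio import audio_loading

    signal = audio_loading(Path("path/to/test/0000"), preload=True)
    print(signal["premixed_snr"])        # parsed from the mix_snr_<value>.wav file name
    print(signal[BUFFERS][CLEAN].shape)  # audio tensor, channels first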
src/gyraudio/audio_separation/visualization/pre_load_custom_audio.py
ADDED
@@ -0,0 +1,53 @@
from batch_processing import Batch
import argparse
import sys
from pathlib import Path
from gyraudio.audio_separation.properties import PATHS, BUFFERS, NAME, SAMPLING_RATE
from gyraudio.io.audio import load_audio_tensor


def parse_command_line_generic_audio_load() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description='Batch audio loading',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("-preload", "--preload", action="store_true", help="Preload audio files")
    return parser


def load_buffers_custom(signal: dict, device="cpu") -> None:
    generic_signal, sampling_rate = load_audio_tensor(signal[PATHS], device=device)
    signal[BUFFERS] = generic_signal
    signal[SAMPLING_RATE] = sampling_rate


def audio_loading(input: Path, preload: bool) -> dict:
    name = input.parent.name + "/" + input.stem
    signal = {
        NAME: name,
        PATHS: input,
    }
    if preload:
        load_buffers_custom(signal)
    return signal


def generic_audio_loading_batch(input: Path, args: argparse.Namespace) -> dict:
    """Wrapper to load audio files from a directory using batch_processing"""
    return audio_loading(input, preload=args.preload)


def main(argv):
    batch = Batch(argv)
    batch.set_io_description(
        input_help='input audio files',
        output_help=argparse.SUPPRESS
    )
    parser = parse_command_line_generic_audio_load()
    batch.parse_args(parser)
    all_signals = batch.run(generic_audio_loading_batch)
    return all_signals


if __name__ == "__main__":
    main(sys.argv[1:])
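
A minimal sketch of preloading a single custom recording with this generic loader; the file path is a placeholder:

    from pathlib import Path
    from gyraudio.audio_separation.properties import BUFFERS, SAMPLING_RATE
    from gyraudio.audio_separation.visualization.pre_load_custom_audio import audio_loading

    signal = audio_loading(Path("path/to/my_recording.wav"), preload=True)
    print(signal[SAMPLING_RATE])   # sampling rate returned by load_audio_tensor
    print(signal[BUFFERS].shape)   # raw audio tensor for the whole file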