balthou committed
Commit f6b56a2 · 1 Parent(s): 4454dfe

draft audio sep app

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. __data_source_separation/source_separation/test/0000/mix_snr_-4.wav +0 -0
  2. __data_source_separation/source_separation/test/0000/noise.wav +0 -0
  3. __data_source_separation/source_separation/test/0000/voice.wav +0 -0
  4. __data_source_separation/source_separation/test/0001/mix_snr_2.wav +0 -0
  5. __data_source_separation/source_separation/test/0001/noise.wav +0 -0
  6. __data_source_separation/source_separation/test/0001/voice.wav +0 -0
  7. __output_audiosep/0004_0000/model_0059.pt +3 -0
  8. __output_audiosep/1004_0000/model_0119.pt +3 -0
  9. __output_audiosep/3001_0000/model_0199.pt +3 -0
  10. app.py +7 -0
  11. audio_samples/0009/mix_snr_-1.wav +0 -0
  12. audio_samples/0009/noise.wav +0 -0
  13. audio_samples/0009/voice.wav +0 -0
  14. requirements.txt +3 -0
  15. src/gyraudio/__init__.py +2 -0
  16. src/gyraudio/audio_separation/architecture/building_block.py +51 -0
  17. src/gyraudio/audio_separation/architecture/flat_conv.py +62 -0
  18. src/gyraudio/audio_separation/architecture/model.py +28 -0
  19. src/gyraudio/audio_separation/architecture/neutral.py +15 -0
  20. src/gyraudio/audio_separation/architecture/transformer.py +91 -0
  21. src/gyraudio/audio_separation/architecture/unet.py +151 -0
  22. src/gyraudio/audio_separation/architecture/wave_unet.py +163 -0
  23. src/gyraudio/audio_separation/data/__init__.py +5 -0
  24. src/gyraudio/audio_separation/data/dataloader.py +47 -0
  25. src/gyraudio/audio_separation/data/dataset.py +104 -0
  26. src/gyraudio/audio_separation/data/mixed.py +40 -0
  27. src/gyraudio/audio_separation/data/remixed.py +53 -0
  28. src/gyraudio/audio_separation/data/remixed_fixed.py +18 -0
  29. src/gyraudio/audio_separation/data/remixed_rnd.py +12 -0
  30. src/gyraudio/audio_separation/data/silence_detector.py +55 -0
  31. src/gyraudio/audio_separation/data/single.py +15 -0
  32. src/gyraudio/audio_separation/experiment_tracking/experiments.py +122 -0
  33. src/gyraudio/audio_separation/experiment_tracking/experiments_decorator.py +48 -0
  34. src/gyraudio/audio_separation/experiment_tracking/experiments_definition.py +320 -0
  35. src/gyraudio/audio_separation/experiment_tracking/storage.py +65 -0
  36. src/gyraudio/audio_separation/infer.py +202 -0
  37. src/gyraudio/audio_separation/metrics.py +101 -0
  38. src/gyraudio/audio_separation/parser.py +8 -0
  39. src/gyraudio/audio_separation/properties.py +79 -0
  40. src/gyraudio/audio_separation/train.py +183 -0
  41. src/gyraudio/audio_separation/visualization/audio_player.py +84 -0
  42. src/gyraudio/audio_separation/visualization/interactive_audio.py +392 -0
  43. src/gyraudio/audio_separation/visualization/interactive_infer.py +117 -0
  44. src/gyraudio/audio_separation/visualization/mute_logo.png +0 -0
  45. src/gyraudio/audio_separation/visualization/play_logo_clean.png +0 -0
  46. src/gyraudio/audio_separation/visualization/play_logo_mixed.png +0 -0
  47. src/gyraudio/audio_separation/visualization/play_logo_noise.png +0 -0
  48. src/gyraudio/audio_separation/visualization/play_logo_pred.png +0 -0
  49. src/gyraudio/audio_separation/visualization/pre_load_audio.py +71 -0
  50. src/gyraudio/audio_separation/visualization/pre_load_custom_audio.py +53 -0
__data_source_separation/source_separation/test/0000/mix_snr_-4.wav ADDED
Binary file (320 kB)
 
__data_source_separation/source_separation/test/0000/noise.wav ADDED
Binary file (320 kB)
 
__data_source_separation/source_separation/test/0000/voice.wav ADDED
Binary file (320 kB)
 
__data_source_separation/source_separation/test/0001/mix_snr_2.wav ADDED
Binary file (320 kB)
 
__data_source_separation/source_separation/test/0001/noise.wav ADDED
Binary file (320 kB)
 
__data_source_separation/source_separation/test/0001/voice.wav ADDED
Binary file (320 kB)
 
__output_audiosep/0004_0000/model_0059.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ef62b0fdd9e9b81000da0db190be1df7e5451d70ab3c86609cab409ec3e38ab8
+ size 34402
__output_audiosep/1004_0000/model_0119.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:61768dbf9e81845dc656c5da3e8fe5ea053f73108f422caf7879b4ab55a0792a
+ size 12755810
__output_audiosep/3001_0000/model_0199.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:11341d63bad0a1ec15c5a94c5cc6f049720869fb9988da45ef7d07edd3c82e21
+ size 12743211
app.py ADDED
@@ -0,0 +1,7 @@
import sys
import os
src_path = os.path.abspath(os.path.join(os.path.dirname(__file__), 'src'))
os.sys.path.append(src_path)
from gyraudio.audio_separation.visualization.interactive_audio import main as interactive_audio_main
if __name__ == "__main__":
    interactive_audio_main(sys.argv[1:])
audio_samples/0009/mix_snr_-1.wav ADDED
Binary file (320 kB)
 
audio_samples/0009/noise.wav ADDED
Binary file (320 kB)
 
audio_samples/0009/voice.wav ADDED
Binary file (320 kB)
 
requirements.txt ADDED
@@ -0,0 +1,3 @@
batch_processing
interactive-pipe>=0.7.0
torch>=2.0.0
src/gyraudio/__init__.py ADDED
@@ -0,0 +1,2 @@
from pathlib import Path
root_dir = Path(__file__).parent.parent.parent
src/gyraudio/audio_separation/architecture/building_block.py ADDED
@@ -0,0 +1,51 @@
import torch
from typing import List


class FilterBank(torch.nn.Module):
    """Convolution filter bank (linear)
    Serves as an embedding for the audio signal
    """

    def __init__(self, ch_in: int, out_dim=16, k_size=5, dilation_list: List[int] = [1, 2, 4, 8]):
        super().__init__()
        self.out_dim = out_dim
        self.source_modality_conv = torch.nn.ModuleList()
        for dilation in dilation_list:
            self.source_modality_conv.append(
                torch.nn.Conv1d(ch_in, out_dim//len(dilation_list), k_size, dilation=dilation, padding=(dilation*(k_size//2)))
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = torch.cat([conv(x) for conv in self.source_modality_conv], axis=1)
        assert out.shape[1] == self.out_dim
        return out


class ResConvolution(torch.nn.Module):
    """ResNet building block
    https://paperswithcode.com/method/residual-connection
    """

    def __init__(self, ch, hdim=None, k_size=5):
        super().__init__()
        hdim = hdim or ch
        self.conv1 = torch.nn.Conv1d(ch, hdim, k_size, padding=k_size//2)
        self.conv2 = torch.nn.Conv1d(hdim, ch, k_size, padding=k_size//2)
        self.non_linearity = torch.nn.ReLU()

    def forward(self, x_in):
        x = self.conv1(x_in)
        x = self.non_linearity(x)
        x = self.conv2(x)
        x += x_in
        x = self.non_linearity(x)
        return x


if __name__ == "__main__":
    model = FilterBank(1, 16)
    inp = torch.rand(2, 1, 2048)
    out = model(inp)
    print(model)
    print(out[0].shape)
src/gyraudio/audio_separation/architecture/flat_conv.py ADDED
@@ -0,0 +1,62 @@
import torch
from gyraudio.audio_separation.architecture.model import SeparationModel
from typing import Tuple


class FlatConvolutional(SeparationModel):
    """Convolutional neural network for audio separation.
    No decimation, no bottleneck, just basic signal processing
    """

    def __init__(self,
                 ch_in: int = 1,
                 ch_out: int = 2,
                 h_dim=16,
                 k_size=5,
                 dilation=1
                 ) -> None:
        super().__init__()
        self.conv1 = torch.nn.Conv1d(
            ch_in, h_dim, k_size,
            dilation=dilation, padding=dilation*(k_size//2))
        self.conv2 = torch.nn.Conv1d(
            h_dim, h_dim, k_size,
            dilation=dilation, padding=dilation*(k_size//2))
        self.conv3 = torch.nn.Conv1d(
            h_dim, h_dim, k_size,
            dilation=dilation, padding=dilation*(k_size//2))
        self.conv4 = torch.nn.Conv1d(
            h_dim, h_dim, k_size,
            dilation=dilation, padding=dilation*(k_size//2))
        self.relu = torch.nn.ReLU()
        self.encoder = torch.nn.Sequential(
            self.conv1,
            self.relu,
            self.conv2,
            self.relu,
            self.conv3,
            self.relu,
            self.conv4,
            self.relu
        )
        self.demux = torch.nn.Sequential(*(
            torch.nn.Conv1d(h_dim, h_dim//2, 1),  # conv1x1
            torch.nn.ReLU(),
            torch.nn.Conv1d(h_dim//2, ch_out, 1),  # conv1x1
        ))

    def forward(self, mixed_sig_in: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Separate the mixed input signal into two signals (voice, noise)

        Args:
            mixed_sig_in (torch.Tensor): [N, C, T]

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: estimated voice and noise signals, [N, 1, T] each
        """
        # Convolution backbone
        # [N, C, T] -> [N, h, T]
        features = self.encoder(mixed_sig_in)
        # [N, h, T] -> [N, 2, T]
        demuxed = self.demux(features)
        return torch.chunk(demuxed, 2, dim=1)  # [N, 1, T], [N, 1, T]
src/gyraudio/audio_separation/architecture/model.py ADDED
@@ -0,0 +1,28 @@
import torch


class SeparationModel(torch.nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

    def count_parameters(self) -> int:
        """Count the total number of parameters of the model

        Returns:
            int: total amount of parameters
        """
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def receptive_field(self) -> int:
        """Compute the receptive field of the model

        Returns:
            int: receptive field
        """
        input_tensor = torch.rand(1, 1, 4096, requires_grad=True)
        out, out_noise = self.forward(input_tensor)
        grad = torch.zeros_like(out)
        grad[..., out.shape[-1]//2] = torch.nan  # set NaN gradient at the middle of the output
        out.backward(gradient=grad)
        self.zero_grad()  # reset to avoid future problems
        return int(torch.sum(input_tensor.grad.isnan()).cpu())  # Count NaN in the input
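Note: receptive_field() measures the receptive field empirically, by injecting a NaN into the output gradient at the central sample and counting how many NaN entries propagate back to the input gradient. A minimal usage sketch (assuming the FlatConvolutional model defined in flat_conv.py; any SeparationModel subclass works the same way):

from gyraudio.audio_separation.architecture.flat_conv import FlatConvolutional

model = FlatConvolutional(h_dim=16, k_size=5)
print(model.count_parameters())  # total number of trainable parameters
print(model.receptive_field())   # number of input samples that influence one output sample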
src/gyraudio/audio_separation/architecture/neutral.py ADDED
@@ -0,0 +1,15 @@

import torch
from gyraudio.audio_separation.architecture.model import SeparationModel


class NeutralModel(SeparationModel):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.fake = torch.nn.Conv1d(1, 1, 1, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Identity function
        """
        n = self.fake(x)
        return x, n
src/gyraudio/audio_separation/architecture/transformer.py ADDED
@@ -0,0 +1,91 @@
import torch
from gyraudio.audio_separation.architecture.model import SeparationModel
from gyraudio.audio_separation.architecture.building_block import FilterBank
from typing import Optional


class TransformerModel(SeparationModel):
    """Transformer base model
    =========================
    - Embed signal with a filter bank
    - No positional encoding (potentially add/concatenate a positional encoding later)
    - `nlayers` * transformer blocks
    """

    def __init__(self,
                 nhead: int = 8,  # H
                 nlayers: int = 4,  # L
                 k_size=5,
                 embedding_dim: int = 64,  # D
                 ch_in: int = 1,
                 ch_out: int = 1,
                 dropout: float = 0.,  # dr
                 positional_encoding: str = None
                 ) -> None:
        """Transformer base model

        Args:
            nhead (int): number of heads in each of the MHA models
            embedding_dim (int): D number of channels in the audio embeddings
                = output of the filter bank
                assume `embedding_dim` = `h_dim`
                h_dim is the hidden dimension of the model.
            nlayers (int): number of nn.TransformerEncoderLayer in nn.TransformerEncoder
            dropout (float, optional): dropout value. Defaults to 0.
        """
        super().__init__()
        self.model_type = "Transformer"
        h_dim = embedding_dim  # use the same embedding & hidden dimensions

        self.encoder = FilterBank(ch_in, embedding_dim, k_size=k_size)
        if positional_encoding is None:
            self.pos_encoder = torch.nn.Identity()
        else:
            raise NotImplementedError(
                f"Unknown positional encoding {positional_encoding} - should be add/concat in future")
            # self.pos_encoder = PositionalEncoding(h_dim, dropout=dropout)

        encoder_layers = torch.nn.TransformerEncoderLayer(
            d_model=h_dim,  # input dimension to the transformer encoder layer
            nhead=nhead,  # number of heads for MHA (Multi-head attention)
            dim_feedforward=h_dim,  # output dimension of the MLP on top of the transformer.
            dropout=dropout,
            batch_first=True
        )  # we assume h_dim = d_model = dim_feedforward

        self.transformer_encoder = torch.nn.TransformerEncoder(
            encoder_layers,
            num_layers=nlayers
        )
        self.h_dim = h_dim
        self.target_modality_conv = torch.nn.Conv1d(h_dim, ch_out, 1)  # conv1x1 channel mixer
        # Note: we could finish with a few residual conv blocks... this is pure signal processing

    def forward(
            self, src: torch.FloatTensor,
            src_mask: Optional[torch.FloatTensor] = None
    ) -> torch.FloatTensor:
        """Embeddings, positional encoding, then `nlayers` of residual {multi-head (`nhead`) attention + MLP} blocks.

        Args:
            src (torch.FloatTensor): [N, 1, T] audio signal

        Returns:
            torch.FloatTensor: separated signal [N, 1, T]
        """
        src = self.encoder(src)  # [N, 1, T] -> [N, D, T]
        src = src.transpose(-1, -2)  # [N, D, T] -> [N, T, D]  # Transformer expects (batch N, seq "T", features "D")
        src = self.pos_encoder(src)  # -> [N, T, D] - add positional encoding

        output = self.transformer_encoder(src, mask=src_mask)  # -> [N, T, D]
        output = output.transpose(-1, -2)  # -> [N, D, T]
        output = self.target_modality_conv(output)  # -> [N, 1, T]
        return output, None


if __name__ == "__main__":
    model = TransformerModel()
    inp = torch.rand(2, 1, 2048)
    out = model(inp)
    print(model)
    print(out[0].shape)
src/gyraudio/audio_separation/architecture/unet.py ADDED
@@ -0,0 +1,151 @@
import torch
from gyraudio.audio_separation.architecture.model import SeparationModel
from gyraudio.audio_separation.architecture.building_block import ResConvolution
from typing import Optional
# import logging


class EncoderSingleStage(torch.nn.Module):
    """
    Extend channels
    Resnet
    Downsample by 2
    """

    def __init__(self, ch: int, ch_out: int, hdim: Optional[int] = None, k_size=5):
        # ch_out ~ ch_in*extension_factor
        super().__init__()
        hdim = hdim or ch
        self.extension_conv = torch.nn.Conv1d(ch, ch_out, k_size, padding=k_size//2)
        self.res_conv = ResConvolution(ch_out, hdim=hdim, k_size=k_size)
        # warning on maxpooling jitter offset!
        self.max_pool = torch.nn.MaxPool1d(kernel_size=2)

    def forward(self, x):
        x = self.extension_conv(x)
        x = self.res_conv(x)
        x_ds = self.max_pool(x)
        return x, x_ds


class DecoderSingleStage(torch.nn.Module):
    """
    Upsample by 2
    Merge with the skip connection
    Resnet
    """

    def __init__(self, ch: int, ch_out: int, hdim: Optional[int] = None, k_size=5):
        """Decoder stage

        Args:
            ch (int): channel size (downsampled & skip connection have same channel size)
            ch_out (int): number of output channels (shall match the number of input channels of the next stage)
            hdim (Optional[int], optional): Hidden dimension used in the residual block. Defaults to None.
            k_size (int, optional): Convolution size. Defaults to 5.
        Notes:
        ======
        ch_out = 2*ch/extension_factor

        self.scale_mixers_conv
        - defines how the lower decoded scale (x_ds) is merged with the current encoded scale (x_skip)
        - could be a pointwise convolution (conv1x1)
        """

        super().__init__()
        hdim = hdim or ch
        self.scale_mixers_conv = torch.nn.Conv1d(2*ch, ch_out, k_size, padding=k_size//2)

        self.res_conv = ResConvolution(ch_out, hdim=hdim, k_size=k_size)
        # warning: linear interpolation shall be "conjugated" with the skip downsampling
        # special care is required regarding offsets
        # https://arxiv.org/abs/1806.03185
        self.upsample = torch.nn.Upsample(scale_factor=2, mode="linear", align_corners=True)
        self.non_linearity = torch.nn.ReLU()

    def forward(self, x_ds: torch.Tensor, x_skip: torch.Tensor) -> torch.Tensor:
        """"""
        x_us = self.upsample(x_ds)  # [N, ch, T/2] -> [N, ch, T]

        x = torch.cat([x_us, x_skip], dim=1)  # [N, 2.ch, T]
        x = self.scale_mixers_conv(x)  # [N, ch_out, T]
        x = self.non_linearity(x)
        x = self.res_conv(x)  # [N, ch_out, T]
        return x


class ResUNet(SeparationModel):
    """Convolutional neural network for audio separation

    with decimation and a bottleneck
    """

    def __init__(self,
                 ch_in: int = 1,
                 ch_out: int = 2,
                 channels_extension: float = 1.5,
                 h_dim=16,
                 k_size=5,
                 ) -> None:
        super().__init__()
        self.need_split = ch_out != ch_in
        self.ch_out = ch_out
        self.source_modality_conv = torch.nn.Conv1d(ch_in, h_dim, k_size, padding=k_size//2)
        self.encoder_list = torch.nn.ModuleList()
        self.decoder_list = torch.nn.ModuleList()
        self.non_linearity = torch.nn.ReLU()

        h_dim_current = h_dim
        for _level in range(4):
            h_dim_ds = int(h_dim_current*channels_extension)
            self.encoder_list.append(EncoderSingleStage(h_dim_current, h_dim_ds, k_size=k_size))
            self.decoder_list.append(DecoderSingleStage(h_dim_ds, h_dim_current, k_size=k_size))
            h_dim_current = h_dim_ds
        self.bottleneck = ResConvolution(h_dim_current, k_size=k_size)
        self.target_modality_conv = torch.nn.Conv1d(h_dim, ch_out, 1)  # conv1x1 channel mixer

    def forward(self, x_in):
        # x_in (1, 2048)
        x0 = self.source_modality_conv(x_in)
        x0 = self.non_linearity(x0)
        # x0 -> (16, 2048)

        x1_skip, x1_ds = self.encoder_list[0](x0)
        # x1_skip -> (24, 2048)
        # x1_ds -> (24, 1024)
        # print(x1_skip.shape, x1_ds.shape)

        x2_skip, x2_ds = self.encoder_list[1](x1_ds)
        # x2_skip -> (36, 1024)
        # x2_ds -> (36, 512)
        # print(x2_skip.shape, x2_ds.shape)

        x3_skip, x3_ds = self.encoder_list[2](x2_ds)
        # x3_skip -> (54, 512)
        # x3_ds -> (54, 256)
        # print(x3_skip.shape, x3_ds.shape)

        x4_skip, x4_ds = self.encoder_list[3](x3_ds)
        # x4_skip -> (81, 256)
        # x4_ds -> (81, 128)
        # print(x4_skip.shape, x4_ds.shape)

        x4_dec = self.bottleneck(x4_ds)
        x3_dec = self.decoder_list[3](x4_dec, x4_skip)
        x2_dec = self.decoder_list[2](x3_dec, x3_skip)
        x1_dec = self.decoder_list[1](x2_dec, x2_skip)
        x0_dec = self.decoder_list[0](x1_dec, x1_skip)
        demuxed = self.target_modality_conv(x0_dec)
        # no relu
        if self.need_split:
            return torch.chunk(demuxed, self.ch_out, dim=1)
        return demuxed, None


if __name__ == "__main__":
    model = ResUNet()
    inp = torch.rand(2, 1, 2048)
    out = model(inp)
    print(model)
    print(model.count_parameters())
    print(out[0].shape)
src/gyraudio/audio_separation/architecture/wave_unet.py ADDED
@@ -0,0 +1,163 @@
import torch
from gyraudio.audio_separation.architecture.model import SeparationModel
from typing import Optional, Tuple


def get_non_linearity(activation: str):
    if activation == "LeakyReLU":
        non_linearity = torch.nn.LeakyReLU()
    else:
        non_linearity = torch.nn.ReLU()
    return non_linearity


class BaseConvolutionBlock(torch.nn.Module):
    def __init__(self, ch_in, ch_out: int, k_size: int, activation="LeakyReLU", dropout: float = 0, bias: bool = True) -> None:
        super().__init__()
        self.conv = torch.nn.Conv1d(ch_in, ch_out, k_size, padding=k_size//2, bias=bias)
        self.non_linearity = get_non_linearity(activation)
        self.dropout = torch.nn.Dropout1d(p=dropout)

    def forward(self, x_in: torch.Tensor) -> torch.Tensor:
        x = self.conv(x_in)  # [N, ch_in, T] -> [N, ch_in+channels_extension, T]
        x = self.non_linearity(x)
        x = self.dropout(x)
        return x


class EncoderStage(torch.nn.Module):
    """Conv (and extend channels), downsample by 2 by skipping samples
    """

    def __init__(self, ch_in: int, ch_out: int, k_size: int = 15, dropout: float = 0, bias: bool = True) -> None:

        super().__init__()

        self.conv_block = BaseConvolutionBlock(ch_in, ch_out, k_size=k_size, dropout=dropout, bias=bias)

    def forward(self, x):
        x = self.conv_block(x)

        x_ds = x[..., ::2]
        # ch_out = ch_in+channels_extension
        return x, x_ds


class DecoderStage(torch.nn.Module):
    """Upsample by 2, concatenate with skip connection, Conv (and shrink channels)
    """

    def __init__(self, ch_in: int, ch_out: int, k_size: int = 5, dropout: float = 0., bias: bool = True) -> None:
        """Decoder stage
        """

        super().__init__()
        self.conv_block = BaseConvolutionBlock(ch_in, ch_out, k_size=k_size, dropout=dropout, bias=bias)
        self.upsample = torch.nn.Upsample(scale_factor=2, mode="linear", align_corners=True)

    def forward(self, x_ds: torch.Tensor, x_skip: torch.Tensor) -> torch.Tensor:
        """"""
        x_us = self.upsample(x_ds)  # [N, ch, T/2] -> [N, ch, T]
        x = torch.cat([x_us, x_skip], dim=1)  # [N, 2.ch, T]
        x = self.conv_block(x)  # [N, ch_out, T]
        return x


class WaveUNet(SeparationModel):
    """UNET in temporal domain (waveform)
    = Multiscale convolutional neural network for audio separation
    https://arxiv.org/abs/1806.03185
    """

    def __init__(self,
                 ch_in: int = 1,
                 ch_out: int = 2,
                 channels_extension: int = 24,
                 k_conv_ds: int = 15,
                 k_conv_us: int = 5,
                 num_layers: int = 6,
                 dropout: float = 0.0,
                 bias: bool = True,
                 ) -> None:
        super().__init__()
        self.need_split = ch_out != ch_in
        self.ch_out = ch_out
        self.encoder_list = torch.nn.ModuleList()
        self.decoder_list = torch.nn.ModuleList()
        # Defining first encoder
        self.encoder_list.append(EncoderStage(ch_in, channels_extension, k_size=k_conv_ds, dropout=dropout, bias=bias))
        for level in range(1, num_layers+1):
            ch_i = level*channels_extension
            ch_o = (level+1)*channels_extension
            if level < num_layers:
                # Skipping last encoder since we defined the first one outside the loop
                self.encoder_list.append(EncoderStage(ch_i, ch_o, k_size=k_conv_ds, dropout=dropout, bias=bias))
            self.decoder_list.append(DecoderStage(ch_o+ch_i, ch_i, k_size=k_conv_us, dropout=dropout, bias=bias))
        self.bottleneck = BaseConvolutionBlock(
            num_layers*channels_extension,
            (num_layers+1)*channels_extension,
            k_size=k_conv_ds,
            dropout=dropout,
            bias=bias)
        self.dropout = torch.nn.Dropout1d(p=dropout)
        self.target_modality_conv = torch.nn.Conv1d(
            channels_extension+ch_in, ch_out, 1, bias=bias)  # conv1x1 channel mixer

    def forward(self, x_in: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Forward UNET pass

        ```
        (1  , 2048)----------------->(24 , 2048) > (1 , 2048)
            v                            ^
        (24 , 1024)----------------->(48 , 1024)
            v                            ^
        (48 , 512 )----------------->(72 , 512 )
            v                            ^
        (72 , 256 )----------------->(96 , 256 )
            v                            ^
        (96 , 128 )----BOTTLENECK--->(120, 128 )
        ```

        """
        skipped_list = []
        ds_list = [x_in]
        for level, enc in enumerate(self.encoder_list):
            x_skip, x_ds = enc(ds_list[-1])
            skipped_list.append(x_skip)
            ds_list.append(x_ds.clone())
            # print(x_skip.shape, x_ds.shape)
        x_dec = self.bottleneck(ds_list[-1])
        for level, dec in enumerate(self.decoder_list[::-1]):
            x_dec = dec(x_dec, skipped_list[-1-level])
            # print(x_dec.shape)
        x_dec = torch.cat([x_dec, x_in], dim=1)
        # print(x_dec.shape)
        x_dec = self.dropout(x_dec)
        demuxed = self.target_modality_conv(x_dec)
        # print(demuxed.shape)
        if self.need_split:
            return torch.chunk(demuxed, self.ch_out, dim=1)
        return demuxed, None

    # x_skip, x_ds
    # (24, 2048), (24, 1024)
    # (48, 1024), (48, 512 )
    # (72, 512 ), (72, 256 )
    # (96, 256 ), (96, 128 )

    # (120, 128 )
    # (96 , 256 )
    # (72 , 512 )
    # (48 , 1024)
    # (24 , 2048)
    # (25 , 2048) demuxed - after concat
    # (1  , 2048)


if __name__ == "__main__":
    model = WaveUNet(ch_out=1, num_layers=9)
    inp = torch.rand(2, 1, 2048)
    out = model(inp)
    print(model)
    print(model.count_parameters())
    print(out[0].shape)
src/gyraudio/audio_separation/data/__init__.py ADDED
@@ -0,0 +1,5 @@
from gyraudio.audio_separation.data.mixed import MixedAudioDataset
from gyraudio.audio_separation.data.remixed_fixed import RemixedFixedAudioDataset
from gyraudio.audio_separation.data.remixed_rnd import RemixedRandomAudioDataset
from gyraudio.audio_separation.data.single import SingleAudioDataset
from gyraudio.audio_separation.data.dataloader import get_dataloader, get_config_dataloader
src/gyraudio/audio_separation/data/dataloader.py ADDED
@@ -0,0 +1,47 @@
from torch.utils.data import DataLoader
from gyraudio.audio_separation.data.mixed import MixedAudioDataset
from typing import Optional, List
from gyraudio.audio_separation.properties import (
    DATA_PATH, AUGMENTATION, SNR_FILTER, SHUFFLE, BATCH_SIZE, TRAIN, VALID, TEST, AUG_TRIM
)
from gyraudio import root_dir
RAW_AUDIO_ROOT = root_dir/"__data_source_separation"/"voice_origin"
MIXED_AUDIO_ROOT = root_dir/"__data_source_separation"/"source_separation"


def get_dataloader(configurations: dict, audio_dataset=MixedAudioDataset):
    dataloaders = {}
    for mode, configuration in configurations.items():
        dataset = audio_dataset(
            configuration[DATA_PATH],
            augmentation_config=configuration[AUGMENTATION],
            snr_filter=configuration[SNR_FILTER]
        )
        dl = DataLoader(
            dataset,
            shuffle=configuration[SHUFFLE],
            batch_size=configuration[BATCH_SIZE],
            collate_fn=dataset.collate_fn
        )
        dataloaders[mode] = dl
    return dataloaders


def get_config_dataloader(
        audio_root=MIXED_AUDIO_ROOT,
        mode: str = TRAIN,
        shuffle: Optional[bool] = None,
        batch_size: Optional[int] = 16,
        snr_filter: Optional[List[float]] = None,
        augmentation: dict = {}):
    audio_folder = audio_root/mode
    assert mode in [TRAIN, VALID, TEST]
    assert audio_folder.exists()
    config = {
        DATA_PATH: audio_folder,
        SHUFFLE: shuffle if shuffle is not None else (True if mode == TRAIN else False),
        AUGMENTATION: augmentation,
        SNR_FILTER: snr_filter,
        BATCH_SIZE: batch_size
    }
    return config
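Note: a minimal usage sketch of the two helpers above, assuming the premixed dataset is present under the default MIXED_AUDIO_ROOT layout (train/test subfolders):

from gyraudio.audio_separation.data import get_dataloader, get_config_dataloader
from gyraudio.audio_separation.properties import TRAIN, TEST

# one configuration dictionary per mode, one DataLoader per mode in return
dataloaders = get_dataloader({
    TRAIN: get_config_dataloader(mode=TRAIN, batch_size=16),
    TEST: get_config_dataloader(mode=TEST, shuffle=False, batch_size=16),
})
mixed, clean, noise = next(iter(dataloaders[TRAIN]))  # each shaped [batch, 1, T]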
src/gyraudio/audio_separation/data/dataset.py ADDED
@@ -0,0 +1,104 @@
from torch.utils.data import Dataset
from pathlib import Path
from typing import Optional, List
import torch
from torch.utils.data import default_collate
from typing import Tuple
from functools import partial
from gyraudio.audio_separation.properties import (
    AUG_AWGN, AUG_RESCALE, AUG_TRIM, LENGTHS, LENGTH_DIVIDER, TRIM_PROB
)


class AudioDataset(Dataset):
    def __init__(
            self,
            data_path: Path,
            augmentation_config: dict = {},
            snr_filter: Optional[List[float]] = None,
            debug: bool = False
    ):
        self.debug = debug
        self.data_path = data_path
        self.augmentation_config = augmentation_config
        self.snr_filter = snr_filter
        self.load_data()
        self.length = len(self.file_list)
        self.collate_fn = None
        if AUG_TRIM in self.augmentation_config:
            self.collate_fn = partial(collate_fn_generic,
                                      lengths_lim=self.augmentation_config[AUG_TRIM][LENGTHS],
                                      length_divider=self.augmentation_config[AUG_TRIM][LENGTH_DIVIDER],
                                      trim_prob=self.augmentation_config[AUG_TRIM][TRIM_PROB])

    def filter_data(self, snr):
        if self.snr_filter is None:
            return True
        if snr in self.snr_filter:
            return True
        else:
            return False

    def load_data(self):
        raise NotImplementedError("load_data method must be implemented")

    def augment_data(self, mixed_audio_signal, clean_audio_signal, noise_audio_signal):
        if AUG_RESCALE in self.augmentation_config:
            current_amplitude = 0.5 + 1.5*torch.rand(1, device=mixed_audio_signal.device)
            # logging.debug(current_amplitude)
            mixed_audio_signal *= current_amplitude
            noise_audio_signal *= current_amplitude
            clean_audio_signal *= current_amplitude
        if AUG_AWGN in self.augmentation_config:
            # noise_std = self.augmentation_config[AUG_AWGN]["noise_std"]
            noise_std = 0.01
            current_noise_std = torch.randn(1) * noise_std
            # logging.debug(current_noise_std)
            extra_awgn = torch.randn(mixed_audio_signal.shape, device=mixed_audio_signal.device) * current_noise_std
            mixed_audio_signal = mixed_audio_signal+extra_awgn
            # Open question: should we add noise to the noise signal as well?

        return mixed_audio_signal, clean_audio_signal, noise_audio_signal

    def __len__(self):
        return self.length

    def __getitem__(self, idx: int) -> torch.Tensor:
        raise NotImplementedError("__getitem__ method must be implemented")


def collate_fn_generic(batch, lengths_lim, length_divider=1024, trim_prob=0.5) -> Tuple[torch.Tensor, torch.Tensor]:
    """Collate function to allow trimming (=cropping the time dimension) of the signals in a batch.

    Args:
        batch (list): A list of tuples (triplets), where each tuple contains:
            - mixed_audio_signal
            - clean_audio_signal
            - noise_audio_signal
        lengths_lim (list): A list containing a minimum length (index 0) and a maximum length (index 1)
        length_divider (int): the trimmed length must be a multiple of this value
        trim_prob (float): trimming probability

    Returns:
        - Tensor: A batch of mixed_audio_signal, trimmed to the same length.
        - Tensor: A batch of clean_audio_signal
        - Tensor: A batch of noise_audio_signal
    """

    # All signals in the batch share the same length after default_collate
    mixed_audio_signal, clean_audio_signal, noise_audio_signal = default_collate(batch)
    length = mixed_audio_signal[0].shape[-1]
    min_length, max_length = lengths_lim
    take_full_signal = torch.rand(1) > trim_prob
    if not take_full_signal:
        start = torch.randint(0, length-min_length, (1,))
        trim_length = torch.randint(min_length, min(max_length, length-start-1)+1, (1,))
        trim_length = trim_length-trim_length % length_divider
        end = start + trim_length
    else:
        start = 0
        end = length - length % length_divider
    mixed_audio_signal = mixed_audio_signal[..., start:end]
    clean_audio_signal = clean_audio_signal[..., start:end]
    noise_audio_signal = noise_audio_signal[..., start:end]
    return mixed_audio_signal, clean_audio_signal, noise_audio_signal
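Note: the trimming collate function is only active when AUG_TRIM is present in the augmentation config. A sketch of such a config, reusing the keys imported above (the values mirror the ones used in experiments_definition.py and are only an example):

from gyraudio.audio_separation.properties import AUG_TRIM, AUG_RESCALE, LENGTHS, LENGTH_DIVIDER, TRIM_PROB

augmentation = {
    # random temporal crops of 8192 to 80000 samples, rounded down to a multiple of 1024,
    # applied with probability 0.8
    AUG_TRIM: {LENGTHS: [8192, 80000], LENGTH_DIVIDER: 1024, TRIM_PROB: 0.8},
    AUG_RESCALE: True,  # random amplitude rescaling
}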
src/gyraudio/audio_separation/data/mixed.py ADDED
@@ -0,0 +1,40 @@
from gyraudio.audio_separation.data.dataset import AudioDataset
import logging
import torch
import torchaudio
from typing import Tuple


class MixedAudioDataset(AudioDataset):
    def load_data(self):
        self.folder_list = sorted(list(self.data_path.iterdir()))
        self.file_list = [
            [
                list(folder.glob("mix*.wav"))[0],
                folder/"voice.wav",
                folder/"noise.wav"
            ] for folder in self.folder_list
        ]
        snr_list = [float(file[0].stem.split("_")[-1]) for file in self.file_list]
        self.file_list = [files for snr, files in zip(snr_list, self.file_list) if self.filter_data(snr)]
        if self.debug:
            logging.info(f"Available SNR {set(snr_list)}")
            print(f"Available SNR {set(snr_list)}")
            print("Filtered", len(self.file_list), self.snr_filter)
        self.sampling_rate = None

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        mixed_audio_path, signal_path, noise_path = self.file_list[idx]
        assert mixed_audio_path.exists()
        assert signal_path.exists()
        assert noise_path.exists()
        mixed_audio_signal, sampling_rate = torchaudio.load(str(mixed_audio_path))
        clean_audio_signal, sampling_rate = torchaudio.load(str(signal_path))
        noise_audio_signal, sampling_rate = torchaudio.load(str(noise_path))
        self.sampling_rate = sampling_rate
        mixed_audio_signal, clean_audio_signal, noise_audio_signal = self.augment_data(mixed_audio_signal, clean_audio_signal, noise_audio_signal)
        if self.debug:
            logging.debug(f"{mixed_audio_signal.shape}")
            logging.debug(f"{clean_audio_signal.shape}")
            logging.debug(f"{noise_audio_signal.shape}")
        return mixed_audio_signal, clean_audio_signal, noise_audio_signal
src/gyraudio/audio_separation/data/remixed.py ADDED
@@ -0,0 +1,53 @@
from gyraudio.audio_separation.data.dataset import AudioDataset
from typing import Tuple
import logging
from torch import Tensor
import torch
import torchaudio


class RemixedAudioDataset(AudioDataset):
    def generate_snr_list(self):
        self.snr_list = None

    def load_data(self):
        self.folder_list = sorted(list(self.data_path.iterdir()))
        self.file_list = [
            [
                folder/"voice.wav",
                folder/"noise.wav"
            ] for folder in self.folder_list
        ]
        self.sampling_rate = None
        self.min_snr, self.max_snr = -4, 4
        self.generate_snr_list()
        if self.debug:
            print("Not filtered", len(self.file_list), self.snr_filter)
            print(self.snr_list)

    def get_idx_noise(self, idx):
        raise NotImplementedError("get_idx_noise method must be implemented")

    def get_snr(self, idx):
        raise NotImplementedError("get_snr method must be implemented")

    def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor, Tensor]:
        signal_path = self.file_list[idx][0]
        idx_noise = self.get_idx_noise(idx)
        noise_path = self.file_list[idx_noise][1]

        assert signal_path.exists()
        assert noise_path.exists()
        clean_audio_signal, sampling_rate = torchaudio.load(str(signal_path))
        noise_audio_signal, sampling_rate = torchaudio.load(str(noise_path))
        snr = self.get_snr(idx)
        alpha = 10 ** (-snr / 20) * torch.norm(clean_audio_signal) / torch.norm(noise_audio_signal)
        mixed_audio_signal = clean_audio_signal + alpha*noise_audio_signal
        self.sampling_rate = sampling_rate
        mixed_audio_signal, clean_audio_signal, noise_audio_signal = self.augment_data(
            mixed_audio_signal, clean_audio_signal, noise_audio_signal)
        if self.debug:
            logging.debug(f"{mixed_audio_signal.shape}")
            logging.debug(f"{clean_audio_signal.shape}")
            logging.debug(f"{noise_audio_signal.shape}")
        return mixed_audio_signal, clean_audio_signal, noise_audio_signal
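Note: the gain alpha above scales the noise so that the clean/noise energy ratio of the mixture matches the requested SNR in dB. A standalone sanity-check sketch (arbitrary signals, not part of the dataset code):

import torch

clean = torch.randn(1, 16000)
noise = torch.randn(1, 16000)
snr = -4.0  # requested SNR in dB
alpha = 10 ** (-snr / 20) * torch.norm(clean) / torch.norm(noise)
mixed = clean + alpha * noise
# 20*log10(||clean|| / ||alpha*noise||) recovers the requested SNR
print(20 * torch.log10(torch.norm(clean) / torch.norm(alpha * noise)))  # ~ -4.0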
src/gyraudio/audio_separation/data/remixed_fixed.py ADDED
@@ -0,0 +1,18 @@
from gyraudio.audio_separation.data.remixed import RemixedAudioDataset
import torch

class RemixedFixedAudioDataset(RemixedAudioDataset):
    def generate_snr_list(self):
        rnd_gen = torch.Generator()
        rnd_gen.manual_seed(2147483647)
        if self.snr_filter is None:
            self.snr_list = self.min_snr + (self.max_snr - self.min_snr)*torch.rand(len(self.file_list), generator=rnd_gen)
        else:
            indices = torch.randint(0, len(self.snr_filter), (len(self.file_list),), generator=rnd_gen)
            self.snr_list = [self.snr_filter[idx] for idx in indices]

    def get_idx_noise(self, idx):
        return idx

    def get_snr(self, idx):
        return self.snr_list[idx]
src/gyraudio/audio_separation/data/remixed_rnd.py ADDED
@@ -0,0 +1,12 @@
from gyraudio.audio_separation.data.remixed import RemixedAudioDataset
from torch import rand, randint


class RemixedRandomAudioDataset(RemixedAudioDataset):
    def get_idx_noise(self, idx):
        return randint(0, len(self.file_list)-1, (1,))

    def get_snr(self, idx):
        if self.snr_filter is None:
            return self.min_snr + (self.max_snr - self.min_snr)*rand(1)
        return self.snr_filter[randint(0, len(self.snr_filter), (1,))]
src/gyraudio/audio_separation/data/silence_detector.py ADDED
@@ -0,0 +1,55 @@
import torch
import matplotlib.pyplot as plt
import numpy as np


def get_silence_mask(
        sig: torch.Tensor, morph_kernel_size: int = 499, k_smooth=21, thresh=0.0001,
        debug: bool = False) -> torch.Tensor:
    with torch.no_grad():
        smooth = torch.nn.Conv1d(1, 1, k_smooth, padding=k_smooth//2, bias=False).to(sig.device)
        smooth.weight.data.fill_(1./k_smooth)
        smoothed = smooth(torch.abs(sig))
        st = 1.*(torch.abs(smoothed) < thresh*torch.ones_like(smoothed, device=sig.device))
        sig_dil = torch.nn.MaxPool1d(morph_kernel_size, stride=1, padding=morph_kernel_size//2)(st)
        sig_ero = -torch.nn.MaxPool1d(morph_kernel_size, stride=1, padding=morph_kernel_size//2)(-sig_dil)
    if debug:
        return sig_ero.squeeze(0), smoothed.squeeze(0), st.squeeze(0)
    else:
        return sig_ero


def visualize_silence_mask(sig: torch.Tensor, silence_thresh: float = 0.0001):
    silence_thresh = 0.0001
    silence_mask, smoothed_amplitude, _ = get_silence_mask(
        sig, k_smooth=21, morph_kernel_size=499, thresh=silence_thresh, debug=True
    )
    plt.figure(figsize=(12, 4))
    plt.subplot(121)
    plt.plot(sig.squeeze(0).cpu().numpy(), "k-", label="voice", alpha=0.5)
    plt.plot(0.01*silence_mask.cpu().numpy(), "r-", alpha=1., label="silence mask")
    plt.grid()
    plt.legend()
    plt.title("Voice and silence mask")
    plt.ylim(-0.04, 0.04)

    plt.subplot(122)
    plt.plot(smoothed_amplitude.cpu().numpy(), "g--", alpha=0.5, label="smoothed amplitude")
    plt.plot(np.ones(silence_mask.shape[-1])*silence_thresh, "c--", alpha=1., label="threshold")
    plt.plot(-silence_thresh+silence_thresh*silence_mask.cpu().numpy(), "r-", alpha=1, label="silence mask")
    plt.grid()
    plt.legend()
    plt.title("Thresholding mechanism")
    plt.ylim(-silence_thresh, silence_thresh*10)
    plt.show()


if __name__ == "__main__":
    from gyraudio.default_locations import SAMPLE_ROOT
    from gyraudio.audio_separation.visualization.pre_load_audio import audio_loading
    from gyraudio.audio_separation.properties import CLEAN, BUFFERS
    sample_folder = SAMPLE_ROOT/"0009"
    signals = audio_loading(sample_folder, preload=True)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    sig_in = signals[BUFFERS][CLEAN].to(device)
    visualize_silence_mask(sig_in)
src/gyraudio/audio_separation/data/single.py ADDED
@@ -0,0 +1,15 @@
from gyraudio.audio_separation.data.dataset import AudioDataset
import logging
import torchaudio


class SingleAudioDataset(AudioDataset):
    def load_data(self):
        self.file_list = sorted(list(self.data_path.glob("*.wav")))

    def __getitem__(self, idx: int):
        audio_path = self.file_list[idx]
        assert audio_path.exists()
        audio_signal, sampling_rate = torchaudio.load(str(audio_path))
        logging.debug(f"{audio_signal.shape}")
        return audio_signal
src/gyraudio/audio_separation/experiment_tracking/experiments.py ADDED
@@ -0,0 +1,122 @@
from gyraudio.default_locations import MIXED_AUDIO_ROOT
from gyraudio.audio_separation.properties import (
    TRAIN, TEST, VALID, NAME, EPOCHS, LEARNING_RATE,
    OPTIMIZER, BATCH_SIZE, DATALOADER, AUGMENTATION,
    SHORT_NAME, AUG_TRIM, TRIM_PROB, LENGTH_DIVIDER, LENGTHS, SNR_FILTER
)
from gyraudio.audio_separation.data.remixed_fixed import RemixedFixedAudioDataset
from gyraudio.audio_separation.data.remixed_rnd import RemixedRandomAudioDataset
from gyraudio.audio_separation.data import get_dataloader, get_config_dataloader
from gyraudio.audio_separation.experiment_tracking.experiments_definition import get_experiment_generator
import torch
from typing import Tuple


def get_experience(exp_major: int, exp_minor: int = 0, snr_filter_test=None, dry_run=False) -> Tuple[str, torch.nn.Module, dict, dict]:
    """Get all experiment details

    Args:
        exp_major (int): Major experiment number
        exp_minor (int, optional): Used for HP search. Defaults to 0.


    Returns:
        Tuple[str, torch.nn.Module, dict, dict]: short_name, model, config, dataloaders
    """
    model = None
    config = {}
    dataloader_name = "remix"
    config = {
        NAME: None,
        OPTIMIZER: {
            NAME: "adam",
            LEARNING_RATE: 0.001
        },
        EPOCHS: 60,
        DATALOADER: {
            NAME: dataloader_name,
        },
        BATCH_SIZE: [16, 16, 16],
        SNR_FILTER: snr_filter_test
    }

    model, config = get_experiment_generator(exp_major=exp_major)(config, no_model=dry_run, minor=exp_minor)
    # POST PROCESSING
    if isinstance(config[BATCH_SIZE], list) or isinstance(config[BATCH_SIZE], tuple):
        config[BATCH_SIZE] = {
            TRAIN: config[BATCH_SIZE][0],
            TEST: config[BATCH_SIZE][1],
            VALID: config[BATCH_SIZE][2],
        }

    if config[DATALOADER][NAME] == "premix":
        mixed_audio_root = MIXED_AUDIO_ROOT
        dataloaders = get_dataloader({
            TRAIN: get_config_dataloader(
                audio_root=mixed_audio_root,
                mode=TRAIN,
                shuffle=True,
                batch_size=config[BATCH_SIZE][TRAIN],
                augmentation=config[DATALOADER].get(AUGMENTATION, {})
            ),
            TEST: get_config_dataloader(
                audio_root=mixed_audio_root,
                mode=TEST,
                shuffle=False,
                batch_size=config[BATCH_SIZE][TEST],
                snr_filter=config[SNR_FILTER]
            )
        })
    elif config[DATALOADER][NAME] == "remix":
        mixed_audio_root = MIXED_AUDIO_ROOT
        aug_test = {}
        if AUG_TRIM in config[DATALOADER].get(AUGMENTATION, {}):
            aug_test = {
                AUG_TRIM: {LENGTHS: [None, None], LENGTH_DIVIDER: config[DATALOADER][AUGMENTATION]
                           [AUG_TRIM][LENGTH_DIVIDER], TRIM_PROB: -1.}
            }
        dl_train = get_dataloader(
            {
                TRAIN: get_config_dataloader(
                    audio_root=mixed_audio_root,
                    mode=TRAIN,
                    shuffle=True,
                    batch_size=config[BATCH_SIZE][TRAIN],
                    augmentation=config[DATALOADER].get(AUGMENTATION, {})
                )
            },
            audio_dataset=RemixedRandomAudioDataset
        )[TRAIN]
        dl_test = get_dataloader(
            {
                TEST: get_config_dataloader(
                    audio_root=mixed_audio_root,
                    mode=TEST,
                    shuffle=False,
                    batch_size=config[BATCH_SIZE][TEST]
                )
            },
            audio_dataset=RemixedFixedAudioDataset
        )[TEST]
        dataloaders = {
            TRAIN: dl_train,
            TEST: dl_test
        }
    else:
        raise NotImplementedError(f"Unknown dataloader {dataloader_name}")
    assert config[NAME] is not None

    short_name = f"{exp_major:04d}_{exp_minor:04d}"
    config[SHORT_NAME] = short_name
    return short_name, model, config, dataloaders


if __name__ == "__main__":
    from gyraudio.audio_separation.parser import shared_parser
    parser_def = shared_parser()
    args = parser_def.parse_args()

    for exp in args.experiments:
        short_name, model, config, dl = get_experience(exp)
        print(short_name)
        print(config)
src/gyraudio/audio_separation/experiment_tracking/experiments_decorator.py ADDED
@@ -0,0 +1,48 @@
import torch
from gyraudio.audio_separation.properties import (
    NAME, ANNOTATIONS, NB_PARAMS, RECEPTIVE_FIELD
)
from typing import Optional
REGISTERED_EXPERIMENTS_LIST = {}


# def count_parameters(model: torch.nn.Module) -> int:
#     """Count number of trainable parameters
#
#     Args:
#         model (torch.nn.Module): Pytorch model
#
#     Returns:
#         int: Number of trainable elements
#     """
#     return sum(p.numel() for p in model.parameters() if p.requires_grad)


def registered_experiment(major: Optional[int] = None, failed: Optional[bool] = False) -> callable:
    """Decorate and register an experiment
    - Register the experiment in the list of experiments
    - Count the number of parameters and add it to the config

    Args:
        major (Optional[int], optional): major id version = Number of the experiment. Defaults to None.
        failed (Optional[bool], optional): If an experiment failed,
            keep track of it but prevent it from being evaluated. Defaults to False.

    Returns:
        callable: decorator function
    """
    def decorator(func):
        assert (major) not in REGISTERED_EXPERIMENTS_LIST, f"Experiment {major} already registered"

        def wrapper(config, minor=None, no_model=False, model=torch.nn.Module()):
            config, model = func(config, model=None if not no_model else model, minor=minor)
            config[NB_PARAMS] = model.count_parameters()
            config[RECEPTIVE_FIELD] = model.receptive_field()
            assert NAME in config, "NAME not defined"
            assert ANNOTATIONS in config, "ANNOTATIONS not defined"
            return model, config
        if not failed:
            REGISTERED_EXPERIMENTS_LIST[major] = wrapper
        return wrapper

    return decorator
src/gyraudio/audio_separation/experiment_tracking/experiments_definition.py ADDED
@@ -0,0 +1,320 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from gyraudio.audio_separation.architecture.flat_conv import FlatConvolutional
2
+ from gyraudio.audio_separation.architecture.unet import ResUNet
3
+ from gyraudio.audio_separation.architecture.wave_unet import WaveUNet
4
+ from gyraudio.audio_separation.architecture.neutral import NeutralModel
5
+ from gyraudio.audio_separation.architecture.transformer import TransformerModel
6
+ from gyraudio.audio_separation.properties import (
7
+ NAME, ANNOTATIONS, MAX_STEPS_PER_EPOCH, EPOCHS, BATCH_SIZE,
8
+ OPTIMIZER, LEARNING_RATE,
9
+ DATALOADER,
10
+ WEIGHT_DECAY,
11
+ LOSS, LOSS_L1,
12
+ AUGMENTATION, AUG_TRIM, AUG_AWGN, AUG_RESCALE,
13
+ LENGTHS, LENGTH_DIVIDER, TRIM_PROB,
14
+ SCHEDULER, SCHEDULER_CONFIGURATION
15
+ )
16
+ from gyraudio.audio_separation.experiment_tracking.experiments_decorator import (
17
+ registered_experiment, REGISTERED_EXPERIMENTS_LIST
18
+ )
19
+
20
+
21
+ @registered_experiment(major=9999)
22
+ def neutral(config, model: bool = None, minor=None):
23
+ config[BATCH_SIZE] = [4, 4, 4]
24
+ config[EPOCHS] = 1
25
+ config[NAME] = "Neutral"
26
+ config[ANNOTATIONS] = "Neutral"
27
+ if model is None:
28
+ model = NeutralModel()
29
+ config[NAME] = "Neutral"
30
+ return config, model
31
+
32
+
33
+ @registered_experiment(major=0)
34
+ def exp_unit_test(config, model: bool = None, minor=None):
35
+ config[MAX_STEPS_PER_EPOCH] = 2
36
+ config[BATCH_SIZE] = [4, 4, 4]
37
+ config[EPOCHS] = 2
38
+ config[NAME] = "Unit Test - Flat Convolutional"
39
+ config[ANNOTATIONS] = "Baseline"
40
+ config[SCHEDULER] = "ReduceLROnPlateau"
41
+ config[SCHEDULER_CONFIGURATION] = dict(patience=5, factor=0.8)
42
+ if model is None:
43
+ model = FlatConvolutional()
44
+ return config, model
45
+
46
+ # ---------------- Low Baseline -----------------
47
+
48
+
49
+ def exp_low_baseline(
50
+ config: dict,
51
+ batch_size: int = 16,
52
+ h_dim: int = 16,
53
+ k_size: int = 9,
54
+ dilation: int = 0,
55
+ model: bool = None,
56
+ minor=None
57
+ ):
58
+ config[BATCH_SIZE] = [batch_size, batch_size, batch_size]
59
+ config[NAME] = "Flat Convolutional"
60
+ config[ANNOTATIONS] = f"Baseline H={h_dim}_K={k_size}"
61
+ if dilation > 1:
62
+ config[ANNOTATIONS] += f"_dil={dilation}"
63
+ config["Architecture"] = {
64
+ "name": "Flat-Conv",
65
+ "h_dim": h_dim,
66
+ "scales": 1,
67
+ "k_size": k_size,
68
+ "dilation": dilation
69
+ }
70
+ if model is None:
71
+ model = FlatConvolutional(k_size=k_size, h_dim=h_dim)
72
+ return config, model
73
+
74
+
75
+ @registered_experiment(major=1)
76
+ def exp_1(config, model: bool = None, minor=None):
77
+ config, model = exp_low_baseline(config, batch_size=32, k_size=5)
78
+ return config, model
79
+
80
+
81
+ @registered_experiment(major=2)
82
+ def exp_2(config, model: bool = None, minor=None):
83
+ config, model = exp_low_baseline(config, batch_size=32, k_size=9)
84
+ return config, model
85
+
86
+
87
+ @registered_experiment(major=3)
88
+ def exp_3(config, model: bool = None, minor=None):
89
+ config, model = exp_low_baseline(config, batch_size=32, k_size=9, dilation=2)
90
+ return config, model
91
+
92
+
93
+ @registered_experiment(major=4)
94
+ def exp_4(config, model: bool = None, minor=None):
95
+ config, model = exp_low_baseline(config, batch_size=16, k_size=9)
96
+ return config, model
97
+
98
+ # ------------------ Res U-Net ------------------
99
+
100
+
101
+ def exp_resunet(config, h_dim=16, k_size=5, model=None):
102
+ config[NAME] = "Res-UNet"
103
+ scales = 4
104
+ config[ANNOTATIONS] = f"Res-UNet-{scales}scales_h={h_dim}_k={k_size}"
105
+ config["Architecture"] = {
106
+ "name": "Res-UNet",
107
+ "h_dim": h_dim,
108
+ "scales": scales,
109
+ "k_size": k_size,
110
+ }
111
+ if model is None:
112
+ model = ResUNet(h_dim=h_dim, k_size=k_size)
113
+ return config, model
114
+
115
+
116
+ @registered_experiment(major=2000)
117
+ def exp_2000_waveunet(config, model: bool = None, minor=None):
118
+ config[EPOCHS] = 60
119
+ config, model = exp_resunet(config)
120
+ return config, model
121
+
122
+
123
+ @registered_experiment(major=2001)
124
+ def exp_2001_waveunet(config, model: bool = None, minor=None):
125
+ config[EPOCHS] = 60
126
+ config, model = exp_resunet(config, h_dim=32, k_size=5)
127
+ return config, model
128
+
129
+ # ------------------ Wave U-Net ------------------
130
+
131
+
132
+ def exp_wave_unet(config: dict,
133
+ channels_extension: int = 24,
134
+ k_conv_ds: int = 15,
135
+ k_conv_us: int = 5,
136
+ num_layers: int = 4,
137
+ dropout: float = 0.0,
138
+ bias: bool = True,
139
+ model=None):
140
+ config[NAME] = "Wave-UNet"
141
+ config[ANNOTATIONS] = f"Wave-UNet-{num_layers}scales_h_ext={channels_extension}_k={k_conv_ds}ds-{k_conv_us}us"
142
+ if dropout > 0:
143
+ config[ANNOTATIONS] += f"-dr{dropout:.1e}"
144
+ if not bias:
145
+ config[ANNOTATIONS] += "-BiasFree"
146
+ config["Architecture"] = {
147
+ "k_conv_us": k_conv_us,
148
+ "k_conv_ds": k_conv_ds,
149
+ "num_layers": num_layers,
150
+ "channels_extension": channels_extension,
151
+ "dropout": dropout,
152
+ "bias": bias
153
+ }
154
+ if model is None:
155
+ model = WaveUNet(
156
+ **config["Architecture"]
157
+ )
158
+ config["Architecture"][NAME] = "Wave-UNet"
159
+ return config, model
160
+
161
+
162
+ @registered_experiment(major=1000)
163
+ def exp_1000_waveunet(config, model: bool = None, minor=None):
164
+ config[EPOCHS] = 60
165
+ config, model = exp_wave_unet(config, model=model, num_layers=4, channels_extension=24)
166
+ # 4 layers, ext +24 - Nvidia T500 4Gb RAM - 16 batch size
167
+ return config, model
168
+
169
+
170
+ @registered_experiment(major=1001)
171
+ def exp_1001_waveunet(config, model: bool = None, minor=None):
172
+ # OVERFIT 1M param ?
173
+ config[EPOCHS] = 60
174
+ config, model = exp_wave_unet(config, model=model, num_layers=7, channels_extension=16)
175
+ # 7 layers, ext +16 - Nvidia T500 4Gb RAM - 16 batch size
176
+ return config, model
177
+
178
+
179
+ @registered_experiment(major=1002)
180
+ def exp_1002_waveunet(config, model: bool = None, minor=None):
181
+ # OVERFIT 1M param ?
182
+ config[EPOCHS] = 60
183
+ config, model = exp_wave_unet(config, model=model, num_layers=7, channels_extension=16)
184
+ config[DATALOADER][AUGMENTATION] = {
185
+ AUG_TRIM: {LENGTHS: [8192, 80000], LENGTH_DIVIDER: 1024, TRIM_PROB: 0.8},
186
+ AUG_RESCALE: True
187
+ }
188
+ # 7 layers, ext +16 - Nvidia T500 4Gb RAM - 16 batch size
189
+ return config, model
190
+
191
+
192
+ @registered_experiment(major=1003)
193
+ def exp_1003_waveunet(config, model: bool = None, minor=None):
194
+ # OVERFIT 2.3M params
195
+ config[EPOCHS] = 60
196
+ config, model = exp_wave_unet(config, model=model, num_layers=7, channels_extension=24)
197
+ # 7 layers, ext +24 - Nvidia RTX3060 6Gb RAM - 16 batch size
198
+ return config, model
199
+
200
+
201
+ @registered_experiment(major=1004)
202
+ def exp_1004_waveunet(config, model: bool = None, minor=None):
203
+ config[EPOCHS] = 120
204
+ config, model = exp_wave_unet(config, model=model, num_layers=7, channels_extension=28)
205
+ # 7 layers, ext +28 - Nvidia RTX3060 6Gb RAM - 16 batch size
206
+ return config, model
207
+
208
+
209
+ @registered_experiment(major=1014)
210
+ def exp_1014_waveunet(config, model: bool = None, minor=None):
211
+ # trained with min and max mixing snr hard coded between -2 and -1
212
+ config[EPOCHS] = 50
213
+ config, model = exp_wave_unet(config, model=model, num_layers=7, channels_extension=28)
214
+ # 7 layers, ext +28 - Nvidia RTX3060 6Gb RAM - 16 batch size
215
+ return config, model
216
+
217
+
218
+ @registered_experiment(major=1005)
219
+ def exp_1005_waveunet(config, model: bool = None, minor=None):
220
+ config[EPOCHS] = 150
221
+ config, model = exp_wave_unet(config, model=model, num_layers=7, channels_extension=16)
222
+ config[DATALOADER][AUGMENTATION] = {
223
+ AUG_TRIM: {LENGTHS: [8192, 80000], LENGTH_DIVIDER: 1024, TRIM_PROB: 0.8},
224
+ AUG_RESCALE: True
225
+ }
226
+ # 7 layers, ext +16 - Nvidia T500 4Gb RAM - 16 batch size
227
+ return config, model
228
+
229
+
230
+ @registered_experiment(major=1006)
231
+ def exp_1006_waveunet(config, model: bool = None, minor=None):
232
+ config[EPOCHS] = 150
233
+ config, model = exp_wave_unet(config, model=model, num_layers=11, channels_extension=16)
234
+ config[DATALOADER][AUGMENTATION] = {
235
+ AUG_TRIM: {LENGTHS: [8192, 80000], LENGTH_DIVIDER: 4096, TRIM_PROB: 0.8},
236
+ AUG_RESCALE: True
237
+ }
238
+ # 11 layers, ext +16 - Nvidia T500 4Gb RAM - 16 batch size
239
+ return config, model
240
+
241
+
242
+ @registered_experiment(major=1007)
243
+ def exp_1007_waveunet(config, model: bool = None, minor=None):
244
+ config[EPOCHS] = 150
245
+ config, model = exp_wave_unet(config, model=model, num_layers=9, channels_extension=16)
246
+ config[DATALOADER][AUGMENTATION] = {
247
+ AUG_TRIM: {LENGTHS: [8192, 80000], LENGTH_DIVIDER: 4096, TRIM_PROB: 0.8},
248
+ AUG_RESCALE: True
249
+ }
250
+ # 9 layers, ext +16 - Nvidia T500 4Gb RAM - 16 batch size
251
+ return config, model
252
+
253
+
254
+ @registered_experiment(major=1008)
255
+ def exp_1008_waveunet(config, model: bool = None, minor=None):
256
+ # CHEAP BASELINE
257
+ config[EPOCHS] = 150
258
+ config, model = exp_wave_unet(config, model=model, num_layers=4, channels_extension=16)
259
+ config[DATALOADER][AUGMENTATION] = {
260
+ AUG_TRIM: {LENGTHS: [8192, 80000], LENGTH_DIVIDER: 1024, TRIM_PROB: 0.8},
261
+ AUG_RESCALE: True
262
+ }
263
+ # 4 layers, ext +16 - Nvidia T500 4Gb RAM - 16 batch size
264
+ return config, model
265
+
266
+
267
+ @registered_experiment(major=3000)
268
+ def exp_3000_waveunet(config, model: bool = None, minor=None):
269
+ config[EPOCHS] = 120
270
+ config, model = exp_wave_unet(config, model=model, num_layers=7, channels_extension=28, bias=False)
271
+ # 7 layers, ext +28 - Nvidia RTX3060 6Gb RAM - 16 batch size
272
+ return config, model
273
+
274
+
275
+ @registered_experiment(major=3001)
276
+ def exp_3001_waveunet(config, model: bool = None, minor=None):
277
+ config[EPOCHS] = 200
278
+ config, model = exp_wave_unet(config, model=model, num_layers=7, channels_extension=28, bias=False)
279
+ # 7 layers, ext +28 - Nvidia RTX3060 6Gb RAM - 16 batch size
280
+ config[SCHEDULER] = "ReduceLROnPlateau"
281
+ config[SCHEDULER_CONFIGURATION] = dict(patience=5, factor=0.8)
282
+ config[OPTIMIZER][LEARNING_RATE] = 0.002
283
+ return config, model
284
+
285
+
286
+ @registered_experiment(major=3002)
287
+ def exp_3002_waveunet(config, model: bool = None, minor=None):
288
+ # TRAINED WITH SNR -12db +12db (code changed manually!)
289
+ # See f910c6da3123e3d35cc0ce588bb5a72ce4a8c422
290
+ config[EPOCHS] = 200
291
+ config, model = exp_wave_unet(config, model=model, num_layers=7, channels_extension=28, bias=False)
292
+ # 7 layers, ext +28 - Nvidia RTX3060 6Gb RAM - 16 batch size
293
+ config[SCHEDULER] = "ReduceLROnPlateau"
294
+ config[SCHEDULER_CONFIGURATION] = dict(patience=5, factor=0.8)
295
+ config[OPTIMIZER][LEARNING_RATE] = 0.002
296
+ return config, model
297
+
298
+
299
+ @registered_experiment(major=4000)
300
+ def exp_4000_bias_free_waveunet_l1(config, model: bool = None, minor=None):
301
+ # config[MAX_STEPS_PER_EPOCH] = 2
302
+ # config[BATCH_SIZE] = [2, 2, 2]
303
+ config[EPOCHS] = 200
304
+ config[LOSS] = LOSS_L1
305
+ config, model = exp_wave_unet(config, model=model, num_layers=7, channels_extension=28, bias=False)
306
+ # 7 layers, ext +28 - Nvidia RTX3060 6Gb RAM - 16 batch size
307
+ config[SCHEDULER] = "ReduceLROnPlateau"
308
+ config[SCHEDULER_CONFIGURATION] = dict(patience=5, factor=0.8)
309
+ config[OPTIMIZER][LEARNING_RATE] = 0.002
310
+ return config, model
311
+
312
+
313
+ def get_experiment_generator(exp_major: int):
314
+ assert exp_major in REGISTERED_EXPERIMENTS_LIST, f"Experiment {exp_major} not registered"
315
+ exp_generator = REGISTERED_EXPERIMENTS_LIST[exp_major]
316
+ return exp_generator
317
+
318
+
319
+ if __name__ == "__main__":
320
+ print(f"Available experiments: {list(REGISTERED_EXPERIMENTS_LIST.keys())}")
src/gyraudio/audio_separation/experiment_tracking/storage.py ADDED
@@ -0,0 +1,65 @@
1
+ from gyraudio.audio_separation.properties import SHORT_NAME, MODEL, OPTIMIZER, CURRENT_EPOCH, CONFIGURATION
2
+ from pathlib import Path
3
+ from gyraudio.default_locations import EXPERIMENT_STORAGE_ROOT
4
+ import logging
5
+ import torch
6
+
7
+
8
+ def get_output_folder(config: dict, root_dir: Path = EXPERIMENT_STORAGE_ROOT, override: bool = False) -> Path:
9
+ output_folder = root_dir/config["short_name"]
10
+ exists = False
11
+ if output_folder.exists():
12
+ if not override:
13
+ logging.info(f"Experiment {config[SHORT_NAME]} already exists. Override is set to False. Skipping.")
14
+ if override:
15
+ logging.warning(f"Experiment {config[SHORT_NAME]} will be OVERRIDDEN")
16
+ exists = True
17
+ else:
18
+ output_folder.mkdir(parents=True, exist_ok=True)
19
+ exists = True
20
+ return exists, output_folder
21
+
22
+
23
+ def checkpoint_paths(exp_dir: Path, epoch=None):
24
+ if epoch is None:
25
+ checkpoints = sorted(exp_dir.glob("model_*.pt"))
26
+ assert len(checkpoints) > 0, f"No checkpoints found in {exp_dir}"
27
+ model_checkpoint = checkpoints[-1]
28
+ epoch = int(model_checkpoint.stem.split("_")[-1])
29
+ optimizer_checkpoint = exp_dir/model_checkpoint.stem.replace("model", "optimizer")
30
+ else:
31
+ model_checkpoint = exp_dir/f"model_{epoch:04d}.pt"
32
+ optimizer_checkpoint = exp_dir/f"optimizer_{epoch:04d}.pt"
33
+ return model_checkpoint, optimizer_checkpoint, epoch
34
+
35
+
36
+ def load_checkpoint(model, exp_dir: Path, optimizer=None, epoch: int = None,
37
+ device="cuda" if torch.cuda.is_available() else "cpu"):
38
+ config = {}
39
+ model_checkpoint, optimizer_checkpoint, epoch = checkpoint_paths(exp_dir, epoch=epoch)
40
+ model_state_dict = torch.load(model_checkpoint, map_location=torch.device(device))
41
+ model.load_state_dict(model_state_dict[MODEL])
42
+ if optimizer is not None:
43
+ optimizer_state_dict = torch.load(optimizer_checkpoint, map_location=torch.device(device))
44
+ optimizer.load_state_dict(optimizer_state_dict[OPTIMIZER])
45
+ config = optimizer_state_dict[CONFIGURATION]
46
+ return model, optimizer, epoch, config
47
+
48
+
49
+ def save_checkpoint(model, exp_dir: Path, optimizer=None, config: dict = {}, epoch: int = None):
50
+ model_checkpoint, optimizer_checkpoint, epoch = checkpoint_paths(exp_dir, epoch=epoch)
51
+ torch.save(
52
+ {
53
+ MODEL: model.state_dict(),
54
+ },
55
+ model_checkpoint
56
+ )
57
+ torch.save(
58
+ {
59
+ CURRENT_EPOCH: epoch,
60
+ CONFIGURATION: config,
61
+ OPTIMIZER: optimizer.state_dict()
62
+ },
63
+ optimizer_checkpoint
64
+ )
65
+ print(f"Checkpoint saved:\n - model: {model_checkpoint}\n - checkpoint: {optimizer_checkpoint}")
src/gyraudio/audio_separation/infer.py ADDED
@@ -0,0 +1,202 @@
1
+ from gyraudio.audio_separation.experiment_tracking.experiments import get_experience
2
+ from gyraudio.audio_separation.parser import shared_parser
3
+ from gyraudio.audio_separation.properties import TEST, NAME, SHORT_NAME, CURRENT_EPOCH, SNR_FILTER
4
+ from gyraudio.default_locations import EXPERIMENT_STORAGE_ROOT
5
+ from gyraudio.audio_separation.experiment_tracking.storage import load_checkpoint
6
+ from gyraudio.audio_separation.experiment_tracking.storage import get_output_folder
7
+ from gyraudio.audio_separation.metrics import snr
8
+ from gyraudio.io.dump import Dump
9
+ from pathlib import Path
10
+ import sys
11
+ import torch
12
+ from tqdm import tqdm
13
+ import torchaudio
14
+ import pandas as pd
15
+ from typing import List
16
+ # Files paths
17
+ DEFAULT_RECORD_FILE = "infer_record.csv" # Store the characteristics of the inference record file
18
+ DEFAULT_EVALUATION_FILE = "eval_df.csv"  # Store the per-sample evaluation results (SNR in / SNR out)
19
+ # Record keys
20
+ NBATCH = "nb_batch"
21
+ BEST_SNR = "best_snr"
22
+ BEST_SAVE_SNR = "best_save_snr"
23
+ WORST_SNR = "worst_snr"
24
+ WORST_SAVE_SNR = "worst_save_snr"
25
+ RECORD_KEYS = [NAME, SHORT_NAME, CURRENT_EPOCH, NBATCH, SNR_FILTER, BEST_SAVE_SNR, BEST_SNR, WORST_SAVE_SNR, WORST_SNR]
26
+ # Evaluation keys
27
+ SAVE_IDX = "save_idx"
28
+ SNR_IN = "snr_in"
29
+ SNR_OUT = "snr_out"
30
+ EVAL_KEYS = [SAVE_IDX, SNR_IN, SNR_OUT]
31
+
32
+
33
+ def load_file(path: Path, keys: List[str]) -> pd.DataFrame:
34
+ if not (path.exists()):
35
+ df = pd.DataFrame(columns=keys)
36
+ df.to_csv(path)
37
+ return pd.read_csv(path)
38
+
39
+
40
+ def launch_infer(exp: int, snr_filter: list = None, device: str = "cuda", model_dir: Path = None,
41
+ output_dir: Path = EXPERIMENT_STORAGE_ROOT, force_reload=False, max_batches=None,
42
+ ext=".wav"):
43
+ # Load experience
44
+ if snr_filter is not None:
45
+ snr_filter = sorted(snr_filter)
46
+ short_name, model, config, dl = get_experience(exp, snr_filter_test=snr_filter)
47
+ exists, exp_dir = get_output_folder(config, root_dir=model_dir, override=False)
48
+ assert exp_dir.exists(), f"Experiment {short_name} does not exist in {model_dir}"
49
+ model.eval()
50
+ model.to(device)
51
+ model, optimizer, epoch, config_checkpt = load_checkpoint(model, exp_dir, epoch=None, device=device)
52
+ # Folder creation
53
+ if output_dir is not None:
54
+ record_path = output_dir/DEFAULT_RECORD_FILE
55
+ record_df = load_file(record_path, RECORD_KEYS)
56
+
57
+ # Define conditions for filtering
58
+ exist_conditions = {
59
+ NAME: config[NAME],
60
+ SHORT_NAME: config[SHORT_NAME],
61
+ CURRENT_EPOCH: epoch,
62
+ NBATCH: max_batches,
63
+ }
64
+ # Create boolean masks and combine them
65
+ masks = [(record_df[key] == value) for key, value in exist_conditions.items()]
66
+ if snr_filter is None:
67
+ masks.append((record_df[SNR_FILTER]).isnull())
68
+ else:
69
+ masks.append(record_df[SNR_FILTER] == str(snr_filter))
70
+ combined_mask = pd.Series(True, index=record_df.index)
71
+ for mask in masks:
72
+ combined_mask = combined_mask & mask
73
+ filtered_df = record_df[combined_mask]
74
+
75
+ save_dir = output_dir/(exp_dir.name+"_infer" + (f"_epoch_{epoch:04d}_nbatch_{max_batches if max_batches is not None else len(dl[TEST])}")
76
+ + ("" if snr_filter is None else f"_snrs_{'_'.join(map(str, snr_filter))}"))
77
+ evaluation_path = save_dir/DEFAULT_EVALUATION_FILE
78
+ if not (filtered_df.empty) and not (force_reload):
79
+ assert evaluation_path.exists()
80
+ print(f"Inference already exists, see folder {save_dir}")
81
+ record_row_df = filtered_df
82
+ else:
83
+ record_row_df = pd.DataFrame({
84
+ NAME: config[NAME],
85
+ SHORT_NAME: config[SHORT_NAME],
86
+ CURRENT_EPOCH: epoch,
87
+ NBATCH: max_batches,
88
+ SNR_FILTER: [None],
89
+ }, index=[0], columns=RECORD_KEYS)
90
+ record_row_df.at[0, SNR_FILTER] = snr_filter
91
+
92
+ save_dir.mkdir(parents=True, exist_ok=True)
93
+ evaluation_df = load_file(evaluation_path, EVAL_KEYS)
94
+ with torch.no_grad():
95
+ test_loss = 0.
96
+ save_idx = 0
97
+ best_snr = 0
98
+ worst_snr = 0
99
+ processed_batches = 0
100
+ for step_index, (batch_mix, batch_signal, batch_noise) in tqdm(
101
+ enumerate(dl[TEST]), desc=f"Inference epoch {epoch}", total=max_batches if max_batches is not None else len(dl[TEST])):
102
+ batch_mix, batch_signal, batch_noise = batch_mix.to(
103
+ device), batch_signal.to(device), batch_noise.to(device)
104
+ batch_output_signal, _batch_output_noise = model(batch_mix)
105
+ loss = torch.nn.functional.mse_loss(batch_output_signal, batch_signal)
106
+ test_loss += loss.item()
107
+
108
+ # SNR stats
109
+ snr_in = snr(batch_mix, batch_signal, reduce=None)
110
+ snr_out = snr(batch_output_signal, batch_signal, reduce=None)
111
+ best_current, best_idx_current = torch.max(snr_out-snr_in, axis=0)
112
+ worst_current, worst_idx_current = torch.min(snr_out-snr_in, axis=0)
113
+ if best_current > best_snr:
114
+ best_snr = best_current
115
+ best_save_idx = save_idx + best_idx_current
116
+ if worst_current < worst_snr:  # track the lowest SNR improvement seen so far
117
+ worst_snr = worst_current
118
+ worst_save_idx = save_idx + worst_idx_current
119
+
120
+ # Save by signal
121
+ batch_output_signal = batch_output_signal.detach().cpu()
122
+ batch_signal = batch_signal.detach().cpu()
123
+ batch_mix = batch_mix.detach().cpu()
124
+ for audio_idx in range(batch_output_signal.shape[0]):
125
+ dic = {SAVE_IDX: save_idx, SNR_IN: float(
126
+ snr_in[audio_idx]), SNR_OUT: float(snr_out[audio_idx])}
127
+ new_eval_row = pd.DataFrame(dic, index=[0])
128
+ evaluation_df = pd.concat([new_eval_row, evaluation_df.loc[:]], ignore_index=True)
129
+
130
+ # Save .wav
131
+ torchaudio.save(
132
+ str(save_dir/f"{save_idx:04d}_mixed{ext}"),
133
+ batch_mix[audio_idx, :, :],
134
+ sample_rate=dl[TEST].dataset.sampling_rate,
135
+ channels_first=True
136
+ )
137
+ torchaudio.save(
138
+ str(save_dir/f"{save_idx:04d}_out{ext}"),
139
+ batch_output_signal[audio_idx, :, :],
140
+ sample_rate=dl[TEST].dataset.sampling_rate,
141
+ channels_first=True
142
+ )
143
+ torchaudio.save(
144
+ str(save_dir/f"{save_idx:04d}_original{ext}"),
145
+ batch_signal[audio_idx, :, :],
146
+ sample_rate=dl[TEST].dataset.sampling_rate,
147
+ channels_first=True
148
+ )
149
+ Dump.save_json(dic, save_dir/f"{save_idx:04d}.json")
150
+ save_idx += 1
151
+ processed_batches += 1
152
+ if max_batches is not None and processed_batches >= max_batches:
153
+ break
154
+ test_loss = test_loss/len(dl[TEST])
155
+ evaluation_df.to_csv(evaluation_path)
156
+
157
+ record_row_df[BEST_SAVE_SNR] = int(best_save_idx)
158
+ record_row_df[BEST_SNR] = float(best_snr)
159
+ record_row_df[WORST_SAVE_SNR] = int(worst_save_idx)
160
+ record_row_df[WORST_SNR] = float(worst_snr)
161
+ record_df = pd.concat([record_row_df, record_df.loc[:]], ignore_index=True)
162
+ record_df.to_csv(record_path, index=0)
163
+
164
+ print(f"Test loss: {test_loss:.3e}, \nbest snr performance: {best_save_idx} with {best_snr:.1f}dB, \nworst snr performance: {worst_save_idx} with {worst_snr:.1f}dB")
165
+
166
+ return record_row_df, evaluation_path
167
+
168
+
169
+ def main(argv):
170
+ default_device = "cuda" if torch.cuda.is_available() else "cpu"
171
+ parser_def = shared_parser(help="Launch inference on a specific model"
172
+ + ("\n<<<Cuda available>>>" if default_device == "cuda" else ""))
173
+ parser_def.add_argument("-i", "--input-dir", type=str, default=EXPERIMENT_STORAGE_ROOT)
174
+ parser_def.add_argument("-o", "--output-dir", type=str, default=EXPERIMENT_STORAGE_ROOT)
175
+ parser_def.add_argument("-d", "--device", type=str, default=default_device,
176
+ help="Training device", choices=["cpu", "cuda"])
177
+ parser_def.add_argument("-r", "--reload", action="store_true",
178
+ help="Force reload files")
179
+ parser_def.add_argument("-b", "--nb-batch", type=int, default=None,
180
+ help="Number of batches to process")
181
+ parser_def.add_argument("-s", "--snr-filter", type=float, nargs="+", default=None,
182
+ help="SNR filters on the inference dataloader")
183
+ parser_def.add_argument("-ext", "--extension", type=str, default=".wav", help="Extension of the audio files to save",
184
+ choices=[".wav", ".mp4"])
185
+ args = parser_def.parse_args(argv)
186
+ for exp in args.experiments:
187
+ launch_infer(
188
+ exp,
189
+ model_dir=Path(args.input_dir),
190
+ output_dir=Path(args.output_dir),
191
+ device=args.device,
192
+ force_reload=args.reload,
193
+ max_batches=args.nb_batch,
194
+ snr_filter=args.snr_filter,
195
+ ext=args.extension
196
+ )
197
+
198
+
199
+ if __name__ == "__main__":
200
+ main(sys.argv[1:])
201
+
202
+ # Example : python src\gyraudio\audio_separation\infer.py -i ./__output_audiosep -e 1002 -d cpu -b 2 -s 4 5 6
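For reference, the programmatic equivalent of the command above (not part of this diff; the experiment id, paths and filters simply mirror that example):

```python
from pathlib import Path
from gyraudio.audio_separation.infer import launch_infer

record_row_df, evaluation_path = launch_infer(
    exp=1002,
    model_dir=Path("./__output_audiosep"),   # where the trained checkpoints live
    output_dir=Path("./__output_audiosep"),  # where the .wav files and CSVs are written
    device="cpu",
    max_batches=2,
    snr_filter=[4, 5, 6],
)
```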
src/gyraudio/audio_separation/metrics.py ADDED
@@ -0,0 +1,101 @@
1
+ from gyraudio.audio_separation.properties import SIGNAL, NOISE, TOTAL, LOSS_TYPE, COEFFICIENT, SNR
2
+ import torch
3
+
4
+
5
+ def snr(prediction: torch.Tensor, ground_truth: torch.Tensor, reduce="mean") -> torch.Tensor:
6
+ """Compute the SNR between two tensors.
7
+ Args:
8
+ prediction (torch.Tensor): prediction tensor
9
+ ground_truth (torch.Tensor): ground truth tensor
10
+ Returns:
11
+ torch.Tensor: SNR
12
+ """
13
+ power_signal = torch.sum(ground_truth**2, dim=(-2, -1))
14
+ power_error = torch.sum((prediction-ground_truth)**2, dim=(-2, -1))
15
+ eps = torch.finfo(torch.float32).eps
16
+ snr_per_element = 10*torch.log10((power_signal+eps)/(power_error+eps))
17
+ final_snr = torch.mean(snr_per_element) if reduce == "mean" else snr_per_element
18
+ return final_snr
19
+
20
+
21
+ DEFAULT_COST = {
22
+ SIGNAL: {
23
+ COEFFICIENT: 0.5,
24
+ LOSS_TYPE: torch.nn.functional.mse_loss
25
+ },
26
+ NOISE: {
27
+ COEFFICIENT: 0.5,
28
+ LOSS_TYPE: torch.nn.functional.mse_loss
29
+ },
30
+ SNR: {
31
+ LOSS_TYPE: snr
32
+ }
33
+ }
34
+
35
+
36
+ class Costs:
37
+ """Keep track of cost functions.
38
+ ```
39
+ for epoch in range(...):
40
+ metric.reset_epoch()
41
+ for step in dataloader(...):
42
+ ... # forward
43
+ prediction = model.forward(batch)
44
+ metric.update(prediction1, groundtruth1, SIGNAL1)
45
+ metric.update(prediction2, groundtruth2, SIGNAL2)
46
+ loss = metric.finish_step()
47
+
48
+ loss.backward()
49
+ ... # backprop
50
+ metric.finish_epoch()
51
+ ... # log metrics
52
+ ```
53
+ """
54
+
55
+ def __init__(self, name: str, costs=DEFAULT_COST) -> None:
56
+ self.name = name
57
+ self.keys = list(costs.keys())
58
+ self.cost = costs
59
+
60
+ def __reset_step(self) -> None:
61
+ self.metrics = {key: 0. for key in self.keys}
62
+
63
+ def reset_epoch(self) -> None:
64
+ self.__reset_step()
65
+ self.total_metric = {key: 0. for key in self.keys+[TOTAL]}
66
+ self.count = 0
67
+
68
+ def update(self,
69
+ prediction: torch.Tensor,
70
+ ground_truth: torch.Tensor,
71
+ key: str
72
+ ) -> torch.Tensor:
73
+ assert key != TOTAL
74
+ # Compute loss for a single batch (=step)
75
+ loss_signal = self.cost[key][LOSS_TYPE](prediction, ground_truth)
76
+ self.metrics[key] = loss_signal
77
+
78
+ def finish_step(self) -> torch.Tensor:
79
+ # Reset current total
80
+ self.metrics[TOTAL] = 0.
81
+ # Sum all metrics to total
82
+ for key in self.metrics:
83
+ if key != TOTAL and self.cost[key].get(COEFFICIENT, False):
84
+ self.metrics[TOTAL] += self.cost[key][COEFFICIENT]*self.metrics[key]
85
+ loss_signal = self.metrics[TOTAL]
86
+ for key in self.metrics:
87
+ if not isinstance(self.metrics[key], float):
88
+ self.metrics[key] = self.metrics[key].item()
89
+ self.total_metric[key] += self.metrics[key]
90
+ self.count += 1
91
+ return loss_signal
92
+
93
+ def finish_epoch(self) -> None:
94
+ for key in self.metrics:
95
+ self.total_metric[key] /= self.count
96
+
97
+ def __repr__(self) -> str:
98
+ rep = f"{self.name}\t:\t"
99
+ for key in self.total_metric:
100
+ rep += f"{key}: {self.total_metric[key]:.3e} | "
101
+ return rep
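For reference, a quick numerical sanity check of the `snr` helper above (not part of this diff; the tensor shapes and the 0.1 noise scale are arbitrary):

```python
import torch
from gyraudio.audio_separation.metrics import snr

gt = torch.randn(4, 1, 8000)                 # batch of 4 mono signals
pred = gt + 0.1 * torch.randn_like(gt)       # error power ~1% of signal power
print(snr(pred, gt))                         # ≈ 20 dB, averaged over the batch
print(snr(pred, gt, reduce=None).shape)      # torch.Size([4]), one SNR per element
```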
src/gyraudio/audio_separation/parser.py ADDED
@@ -0,0 +1,8 @@
1
+ import argparse
2
+
3
+
4
+ def shared_parser(help="Train models for audio separation"):
5
+ parser = argparse.ArgumentParser(description=help)
6
+ parser.add_argument("-e", "--experiments", type=int, nargs="+", required=True,
7
+ help="Experiment ids to be trained sequentially")
8
+ return parser
src/gyraudio/audio_separation/properties.py ADDED
@@ -0,0 +1,79 @@
1
+ # Training modes (Train, Validation, Test)
2
+ TRAIN = "train"
3
+ VALID = "validation"
4
+ TEST = "test"
5
+
6
+ # Dataset properties (keys)
7
+ DATA_PATH = "path"
8
+ BATCH_SIZE = "batch_size"
9
+ SHUFFLE = "shuffle"
10
+ SNR_FILTER = "snr_filter"
11
+ AUGMENTATION = "augmentation"
12
+ DATALOADER = "dataloader"
13
+
14
+
15
+ # Loss split
16
+ SIGNAL = "signal"
17
+ NOISE = "noise"
18
+ TOTAL = "total"
19
+ COEFFICIENT = "coefficient"
20
+
21
+
22
+ # Augmentation types
23
+ AUG_TRIM = "trim" # trim batches to arbitrary length
24
+ AUG_AWGN = "awgn" # add white gaussian noise
25
+ AUG_RESCALE = "rescale" # rescale all signals arbitrarily
26
+
27
+ # Trim types
28
+ LENGTHS = "lengths" # a list of min and max length
29
+ LENGTH_DIVIDER = "length_divider" # an int that divides the length
30
+ TRIM_PROB = "trim_probability" # a float in [0, 1] of trimming probability
31
+
32
+
33
+ # Training configuration properties (keys)
34
+
35
+ OPTIMIZER = "optimizer"
36
+ LEARNING_RATE = "lr"
37
+ WEIGHT_DECAY = "weight_decay"
38
+ BETAS = "betas"
39
+ EPOCHS = "epochs"
40
+ BATCH_SIZE = "batch_size"
41
+ MAX_STEPS_PER_EPOCH = "max_steps_per_epoch"
42
+
43
+
44
+ # Properties for the model
45
+ NAME = "name"
46
+ ANNOTATIONS = "annotations"
47
+ NB_PARAMS = "nb_params"
48
+ RECEPTIVE_FIELD = "receptive_field"
49
+ SHORT_NAME = "short_name"
50
+
51
+
52
+ # Scheduler
53
+ SCHEDULER = "scheduler"
54
+ SCHEDULER_CONFIGURATION = "scheduler_configuration"
55
+
56
+ # Loss
57
+ LOSS = "loss"
58
+ LOSS_L1 = "L1"
59
+ LOSS_L2 = "MSE"
60
+ LOSS_TYPE = "loss_type"
61
+ SNR = "snr"
62
+
63
+ # Checkpoint
64
+ MODEL = "model"
65
+ CURRENT_EPOCH = "current_epoch"
66
+ CONFIGURATION = "configuration"
67
+
68
+
69
+ # Signal names
70
+ CLEAN = "clean"
71
+ NOISY = "noise"
72
+ MIXED = "mixed"
73
+ PREDICTED = "predicted"
74
+
75
+
76
+ # MISC
77
+ PATHS = "paths"
78
+ BUFFERS = "buffers"
79
+ SAMPLING_RATE = "sampling_rate"
src/gyraudio/audio_separation/train.py ADDED
@@ -0,0 +1,183 @@
1
+ from gyraudio.audio_separation.experiment_tracking.experiments import get_experience
2
+ from gyraudio.audio_separation.parser import shared_parser
3
+ from gyraudio.audio_separation.properties import (
4
+ TRAIN, TEST, EPOCHS, OPTIMIZER, NAME, MAX_STEPS_PER_EPOCH,
5
+ SIGNAL, NOISE, TOTAL, SNR, SCHEDULER, SCHEDULER_CONFIGURATION,
6
+ LOSS, LOSS_L2, LOSS_L1, LOSS_TYPE, COEFFICIENT
7
+ )
8
+ from gyraudio.default_locations import EXPERIMENT_STORAGE_ROOT
9
+ from torch.optim.lr_scheduler import ReduceLROnPlateau
10
+ from gyraudio.audio_separation.experiment_tracking.storage import get_output_folder, save_checkpoint
11
+ from gyraudio.audio_separation.metrics import Costs, snr
12
+ # from gyraudio.audio_separation.experiment_tracking.storage import load_checkpoint
13
+ from pathlib import Path
14
+ from gyraudio.io.dump import Dump
15
+ import sys
16
+ import torch
17
+ from tqdm import tqdm
18
+ from copy import deepcopy
19
+ import wandb
20
+ import logging
21
+
22
+
23
+ def launch_training(exp: int, wandb_flag: bool = True, device: str = "cuda", save_dir: Path = None, override=False):
24
+
25
+ short_name, model, config, dl = get_experience(exp)
26
+ exists, output_folder = get_output_folder(config, root_dir=save_dir, override=override)
27
+ if not exists:
28
+ logging.warning(f"Skipping experiment {short_name}")
29
+ return False
30
+ else:
31
+ logging.info(f"Experiment {short_name} saved in {output_folder}")
32
+
33
+ print(short_name)
34
+ print(config)
35
+ logging.info(f"Starting training for {short_name}")
36
+ logging.info(f"Config: {config}")
37
+ if wandb_flag:
38
+ wandb.init(
39
+ project="audio-separation",
40
+ entity="teammd",
41
+ name=short_name,
42
+ tags=["debug"],
43
+ config=config
44
+ )
45
+ training_loop(model, config, dl, wandb_flag=wandb_flag, device=device, exp_dir=output_folder)
46
+ if wandb_flag:
47
+ wandb.finish()
48
+ return True
49
+
50
+
51
+ def update_metrics(metrics, phase, pred, gt, pred_noise, gt_noise):
52
+ metrics[phase].update(pred, gt, SIGNAL)
53
+ metrics[phase].update(pred_noise, gt_noise, NOISE)
54
+ metrics[phase].update(pred, gt, SNR)
55
+ loss = metrics[phase].finish_step()
56
+ return loss
57
+
58
+
59
+ def training_loop(model: torch.nn.Module, config: dict, dl, device: str = "cuda", wandb_flag: bool = False,
60
+ exp_dir: Path = None):
61
+ optim_params = deepcopy(config[OPTIMIZER])
62
+ optim_name = optim_params[NAME]
63
+ optim_params.pop(NAME)
64
+ if optim_name == "adam":
65
+ optimizer = torch.optim.Adam(model.parameters(), **optim_params)
66
+
67
+ scheduler = None
68
+ scheduler_config = config.get(SCHEDULER_CONFIGURATION, {})
69
+ scheduler_name = config.get(SCHEDULER, False)
70
+ if scheduler_name:
71
+ if scheduler_name == "ReduceLROnPlateau":
72
+ scheduler = ReduceLROnPlateau(optimizer, mode='max', verbose=True, **scheduler_config)
73
+ logging.info(f"Using scheduler {scheduler_name} with config {scheduler_config}")
74
+ else:
75
+ raise NotImplementedError(f"Scheduler {scheduler_name} not implemented")
76
+ max_steps = config.get(MAX_STEPS_PER_EPOCH, None)
77
+ chosen_loss = config.get(LOSS, LOSS_L2)
78
+ if chosen_loss == LOSS_L2:
79
+ costs = {TRAIN: Costs(TRAIN), TEST: Costs(TEST)}
80
+ elif chosen_loss == LOSS_L1:
81
+ cost_init = {
82
+ SIGNAL: {
83
+ COEFFICIENT: 0.5,
84
+ LOSS_TYPE: torch.nn.functional.l1_loss
85
+ },
86
+ NOISE: {
87
+ COEFFICIENT: 0.5,
88
+ LOSS_TYPE: torch.nn.functional.l1_loss
89
+ },
90
+ SNR: {
91
+ LOSS_TYPE: snr
92
+ }
93
+ }
94
+ costs = {
95
+ TRAIN: Costs(TRAIN, costs=cost_init),
96
+ TEST: Costs(TEST)
97
+ }
98
+ for epoch in range(config[EPOCHS]):
99
+ costs[TRAIN].reset_epoch()
100
+ costs[TEST].reset_epoch()
101
+ model.to(device)
102
+ # Training loop
103
+ # -----------------------------------------------------------
104
+
105
+ metrics = {TRAIN: {}, TEST: {}}
106
+ for step_index, (batch_mix, batch_signal, batch_noise) in tqdm(
107
+ enumerate(dl[TRAIN]), desc=f"Epoch {epoch}", total=len(dl[TRAIN])):
108
+ if max_steps is not None and step_index >= max_steps:
109
+ break
110
+ batch_mix, batch_signal, batch_noise = batch_mix.to(device), batch_signal.to(device), batch_noise.to(device)
111
+ model.zero_grad()
112
+ batch_output_signal, batch_output_noise = model(batch_mix)
113
+ loss = update_metrics(
114
+ costs, TRAIN,
115
+ batch_output_signal, batch_signal,
116
+ batch_output_noise, batch_noise
117
+ )
118
+ # costs[TRAIN].update(batch_output_signal, batch_signal, SIGNAL)
119
+ # costs[TRAIN].update(batch_output_noise, batch_noise, NOISE)
120
+ # loss = costs[TRAIN].finish_step()
121
+ loss.backward()
122
+ optimizer.step()
123
+ costs[TRAIN].finish_epoch()
124
+
125
+ # Validation loop
126
+ # -----------------------------------------------------------
127
+ model.eval()
128
+ torch.cuda.empty_cache()
129
+ with torch.no_grad():
130
+ for step_index, (batch_mix, batch_signal, batch_noise) in tqdm(
131
+ enumerate(dl[TEST]), desc=f"Epoch {epoch}", total=len(dl[TEST])):
132
+ if max_steps is not None and step_index >= max_steps:
133
+ break
134
+ batch_mix, batch_signal, batch_noise = batch_mix.to(
135
+ device), batch_signal.to(device), batch_noise.to(device)
136
+ batch_output_signal, batch_output_noise = model(batch_mix)
137
+ loss = update_metrics(
138
+ costs, TEST,
139
+ batch_output_signal, batch_signal,
140
+ batch_output_noise, batch_noise
141
+ )
142
+ costs[TEST].finish_epoch()
143
+ if scheduler is not None and isinstance(scheduler, ReduceLROnPlateau):
144
+ scheduler.step(costs[TEST].total_metric[SNR])
145
+ print(f"epoch {epoch}:\n{costs[TRAIN]}\n{costs[TEST]}")
146
+ wandblogs = {}
147
+ if wandb_flag:
148
+ for phase in [TRAIN, TEST]:
149
+ wandblogs[f"{phase} loss signal"] = costs[phase].total_metric[SIGNAL]
150
+ wandblogs[f"debug loss/{phase} loss signal"] = costs[phase].total_metric[SIGNAL]
151
+ wandblogs[f"debug loss/{phase} loss total"] = costs[phase].total_metric[TOTAL]
152
+ wandblogs[f"debug loss/{phase} loss noise"] = costs[phase].total_metric[NOISE]
153
+ wandblogs[f"{phase} snr"] = costs[phase].total_metric[SNR]
154
+ wandblogs["learning rate"] = optimizer.param_groups[0]['lr']
155
+ wandb.log(wandblogs)
156
+ metrics[TRAIN] = costs[TRAIN].total_metric
157
+ metrics[TEST] = costs[TEST].total_metric
158
+ Dump.save_json(metrics, exp_dir/f"metrics_{epoch:04d}.json")
159
+ save_checkpoint(model, exp_dir, optimizer, config=config, epoch=epoch)
160
+ torch.cuda.empty_cache()
161
+
162
+
163
+ def main(argv):
164
+ default_device = "cuda" if torch.cuda.is_available() else "cpu"
165
+ parser_def = shared_parser(help="Launch training \nCheck results at: https://wandb.ai/teammd/audio-separation"
166
+ + ("\n<<<Cuda available>>>" if default_device == "cuda" else ""))
167
+ parser_def.add_argument("-nowb", "--no-wandb", action="store_true")
168
+ parser_def.add_argument("-o", "--output-dir", type=str, default=EXPERIMENT_STORAGE_ROOT)
169
+ parser_def.add_argument("-f", "--force", action="store_true", help="Override existing experiment")
170
+
171
+ parser_def.add_argument("-d", "--device", type=str, default=default_device,
172
+ help="Training device", choices=["cpu", "cuda"])
173
+ args = parser_def.parse_args(argv)
174
+ for exp in args.experiments:
175
+ launch_training(
176
+ exp, wandb_flag=not args.no_wandb, save_dir=Path(args.output_dir),
177
+ override=args.force,
178
+ device=args.device
179
+ )
180
+
181
+
182
+ if __name__ == "__main__":
183
+ main(sys.argv[1:])
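For reference, the training entry point can be driven either through the CLI (for instance `python src/gyraudio/audio_separation/train.py -e 1004 -d cuda -o ./__output_audiosep`) or programmatically; a minimal sketch, assuming experiment 1004 is registered and Weights & Biases logging is disabled (not part of this diff):

```python
from pathlib import Path
from gyraudio.audio_separation.train import launch_training

launch_training(
    1004,                                  # registered experiment id
    wandb_flag=False,                      # skip Weights & Biases logging
    device="cpu",
    save_dir=Path("./__output_audiosep"),  # illustrative output root
)
```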
src/gyraudio/audio_separation/visualization/audio_player.py ADDED
@@ -0,0 +1,84 @@
1
+ from gyraudio.audio_separation.properties import CLEAN, NOISY, MIXED, PREDICTED, SAMPLING_RATE
2
+ from pathlib import Path
3
+ from gyraudio.io.audio import save_audio_tensor
4
+ from gyraudio import root_dir
5
+ from interactive_pipe import Control, KeyboardControl
6
+ from interactive_pipe import interactive
7
+ import logging
8
+
9
+ HERE = Path(__file__).parent
10
+ MUTE = "mute"
11
+ LOGOS = {
12
+ PREDICTED: HERE/"play_logo_pred.png",
13
+ MIXED: HERE/"play_logo_mixed.png",
14
+ CLEAN: HERE/"play_logo_clean.png",
15
+ NOISY: HERE/"play_logo_noise.png",
16
+ MUTE: HERE/"mute_logo.png",
17
+ }
18
+ ICONS = [it for key, it in LOGOS.items()]
19
+ KEYS = [key for key, it in LOGOS.items()]
20
+
21
+ ping_pong_index = 0
22
+
23
+
24
+ @interactive(
25
+ player=Control(MUTE, KEYS, icons=ICONS))
26
+ def audio_selector(sig, mixed, pred, global_params={}, player=MUTE):
27
+
28
+ global_params["selected_audio"] = player if player != MUTE else global_params.get("selected_audio", MIXED)
29
+ global_params[MUTE] = player == MUTE
30
+ if player == CLEAN:
31
+ audio_track = sig["buffers"][CLEAN]
32
+ elif player == NOISY:
33
+ audio_track = sig["buffers"][NOISY]
34
+ elif player == MIXED:
35
+ audio_track = mixed
36
+ elif player == PREDICTED:
37
+ audio_track = pred
38
+ else:
39
+ audio_track = mixed
40
+ return audio_track
41
+
42
+
43
+ @interactive(
44
+ loop=KeyboardControl(True, keydown="l"))
45
+ def audio_trim(audio_track, global_params={}, loop=True):
46
+ sampling_rate = global_params.get(SAMPLING_RATE, 8000)
47
+ if global_params.get("trim", False):
48
+ start, end = global_params["trim"]["start"], global_params["trim"]["end"]
49
+ remainder = (end-start) % 8
50
+ audio_trim = audio_track[..., start:end-remainder]
51
+ repeat_factor = int(sampling_rate*4./(end-start))
52
+ logging.debug(f"{repeat_factor}")
53
+ repeat_factor = max(1, repeat_factor)
54
+ if loop:
55
+ repeat_factor = 1
56
+ audio_trim = audio_trim.repeat(1, repeat_factor)
57
+ logging.debug(f"{audio_trim.shape}")
58
+ else:
59
+ audio_trim = audio_track
60
+ return audio_trim
61
+
62
+
63
+ @interactive(
64
+ volume=(100, [0, 1000], "volume"),
65
+ )
66
+ def audio_player(audio_trim, global_params={}, volume=100):
67
+ sampling_rate = global_params.get(SAMPLING_RATE, 8000)
68
+ try:
69
+ if global_params.get(MUTE, True):
70
+ global_params["__stop"]()
71
+ print("mute!")
72
+ else:
73
+ ping_pong_path = root_dir/"__ping_pong"
74
+ ping_pong_path.mkdir(exist_ok=True)
75
+ global ping_pong_index
76
+ audio_track_path = ping_pong_path/f"_tmp_{ping_pong_index}.wav"
77
+ ping_pong_index = (ping_pong_index + 1) % 10
78
+ save_audio_tensor(audio_track_path, volume/100.*audio_trim,
79
+ sampling_rate=global_params.get(SAMPLING_RATE, sampling_rate))
80
+ global_params["__set_audio"](audio_track_path)
81
+ global_params["__play"]()
82
+ except Exception as exc:
83
+ logging.warning(f"Exception in audio_player {exc}")
84
+ pass
src/gyraudio/audio_separation/visualization/interactive_audio.py ADDED
@@ -0,0 +1,392 @@
1
+ from batch_processing import Batch
2
+ import argparse
3
+ from pathlib import Path
4
+ from gyraudio.audio_separation.experiment_tracking.experiments import get_experience
5
+ from gyraudio.audio_separation.experiment_tracking.storage import get_output_folder
6
+ from gyraudio.default_locations import EXPERIMENT_STORAGE_ROOT
7
+ from gyraudio.audio_separation.properties import (
8
+ SHORT_NAME, CLEAN, NOISY, MIXED, PREDICTED, ANNOTATIONS, PATHS, BUFFERS, SAMPLING_RATE, NAME
9
+ )
10
+ import torch
11
+ from gyraudio.audio_separation.experiment_tracking.storage import load_checkpoint
12
+ from gyraudio.audio_separation.visualization.pre_load_audio import (
13
+ parse_command_line_audio_load, load_buffers, audio_loading_batch)
14
+ from gyraudio.audio_separation.visualization.pre_load_custom_audio import (
15
+ parse_command_line_generic_audio_load, generic_audio_loading_batch,
16
+ load_buffers_custom
17
+ )
18
+ from torchaudio.functional import resample
19
+ from typing import List
20
+ import numpy as np
21
+ import logging
22
+ from interactive_pipe.data_objects.curves import Curve, SingleCurve
23
+ from interactive_pipe import interactive, KeyboardControl, Control
24
+ from interactive_pipe.headless.pipeline import HeadlessPipeline
25
+ from interactive_pipe.graphical.qt_gui import InteractivePipeQT
26
+ from interactive_pipe.graphical.mpl_gui import InteractivePipeMatplotlib
27
+ from gyraudio.audio_separation.visualization.audio_player import audio_selector, audio_trim, audio_player
28
+
29
+ default_device = "cuda" if torch.cuda.is_available() else "cpu"
30
+ LEARNT_SAMPLING_RATE = 8000
31
+
32
+
33
+ @interactive(
34
+ idx=KeyboardControl(value_default=0, value_range=[
35
+ 0, 1000], modulo=True, keyup="8", keydown="2"),
36
+ idn=KeyboardControl(value_default=0, value_range=[
37
+ 0, 1000], modulo=True, keyup="9", keydown="3")
38
+ )
39
+ def signal_selector(signals, idx=0, idn=0, global_params={}):
40
+ if isinstance(signals, dict):
41
+ clean_sigs = signals[CLEAN]
42
+ clean = clean_sigs[idx % len(clean_sigs)]
43
+ if BUFFERS not in clean:
44
+ load_buffers_custom(clean)
45
+ noise_sigs = signals[NOISY]
46
+ noise = noise_sigs[idn % len(noise_sigs)]
47
+ if BUFFERS not in noise:
48
+ load_buffers_custom(noise)
49
+ cbuf, nbuf = clean[BUFFERS], noise[BUFFERS]
50
+ if clean[SAMPLING_RATE] != LEARNT_SAMPLING_RATE:
51
+ cbuf = resample(cbuf, clean[SAMPLING_RATE], LEARNT_SAMPLING_RATE)
52
+ clean[SAMPLING_RATE] = LEARNT_SAMPLING_RATE
53
+ if noise[SAMPLING_RATE] != LEARNT_SAMPLING_RATE:
54
+ nbuf = resample(nbuf, noise[SAMPLING_RATE], LEARNT_SAMPLING_RATE)
55
+ noise[SAMPLING_RATE] = LEARNT_SAMPLING_RATE
56
+ min_length = min(cbuf.shape[-1], nbuf.shape[-1])
57
+ min_length = min_length - min_length % 1024
58
+ signal = {
59
+ PATHS: {
60
+ CLEAN: clean[PATHS],
61
+ NOISY: noise[PATHS]
62
+
63
+ },
64
+ BUFFERS: {
65
+ CLEAN: cbuf[..., :1, :min_length],
66
+ NOISY: nbuf[..., :1, :min_length],
67
+ },
68
+ NAME: f"Clean={clean[NAME]} | Noise={noise[NAME]}",
69
+ SAMPLING_RATE: LEARNT_SAMPLING_RATE
70
+ }
71
+ else:
72
+ # signals are loaded in CPU
73
+ signal = signals[idx % len(signals)]
74
+ if BUFFERS not in signal:
75
+ load_buffers(signal)
76
+ global_params["premixed_snr"] = signal.get("premixed_snr", None)
77
+ signal[NAME] = f"File={signal[NAME]}"
78
+ global_params["selected_info"] = signal[NAME]
79
+ global_params[SAMPLING_RATE] = signal[SAMPLING_RATE]
80
+ return signal
81
+
82
+
83
+ @interactive(
84
+ snr=(0., [-10., 10.], "SNR [dB]")
85
+ )
86
+ def remix(signals, snr=0., global_params={}):
87
+ signal = signals[BUFFERS][CLEAN]
88
+ noisy = signals[BUFFERS][NOISY]
89
+ alpha = 10 ** (-snr / 20) * torch.norm(signal) / torch.norm(noisy)
90
+ mixed_signal = signal + alpha * noisy
91
+ global_params["snr"] = snr
92
+ return mixed_signal
93
+
94
+
95
+ @interactive(std_dev=Control(0., value_range=[0., 0.1], name="extra noise std", step=0.0001),
96
+ amplify=(1., [0., 10.], "amplification of everything"))
97
+ def augment(signals, mixed, std_dev=0., amplify=1.):
98
+ signals[BUFFERS][MIXED] *= amplify
99
+ signals[BUFFERS][NOISY] *= amplify
100
+ signals[BUFFERS][CLEAN] *= amplify
101
+ mixed = mixed*amplify+torch.randn_like(mixed)*std_dev
102
+ return signals, mixed
103
+
104
+
105
+ # @interactive(
106
+ # device=("cuda", ["cpu", "cuda"]) if default_device == "cuda" else ("cpu", ["cpu"])
107
+ # )
108
+ def select_device(device=default_device, global_params={}):
109
+ global_params["device"] = device
110
+
111
+
112
+ @interactive(
113
+ model=KeyboardControl(value_default=0, value_range=[
114
+ 0, 99], keyup="pagedown", keydown="pageup")
115
+ )
116
+ def audio_sep_inference(mixed, models, configs, model: int = 0, global_params={}):
117
+ selected_model = models[model % len(models)]
118
+ config = configs[model % len(models)]
119
+ short_name = config.get(SHORT_NAME, "")
120
+ annotations = config.get(ANNOTATIONS, "")
121
+ device = global_params.get("device", "cpu")
122
+ with torch.no_grad():
123
+ selected_model.eval()
124
+ selected_model.to(device)
125
+ predicted_signal, predicted_noise = selected_model(
126
+ mixed.to(device).unsqueeze(0))
127
+ predicted_signal = predicted_signal.squeeze(0)
128
+ pred_curve = SingleCurve(y=predicted_signal[0, :].detach().cpu().numpy(),
129
+ style="g-", label=f"predicted_{short_name} {annotations}")
130
+ return predicted_signal, pred_curve
131
+
132
+
133
+ def compute_metrics(pred, sig, global_params={}):
134
+ METRICS = "metrics"
135
+ target = sig[BUFFERS][CLEAN]
136
+ global_params[METRICS] = {}
137
+ global_params[METRICS]["MSE"] = torch.mean((target-pred.cpu())**2)
138
+ global_params[METRICS]["SNR"] = 10. * \
139
+ torch.log10(torch.sum(target**2)/torch.sum((target-pred.cpu())**2))
140
+
141
+
142
+ def get_trim(sig, zoom, center, num_samples=300):
143
+ N = len(sig)
144
+ native_ds = N/num_samples
145
+ center_idx = int(center*N)
146
+ window = int(num_samples/zoom*native_ds)
147
+ start_idx = max(0, center_idx - window//2)
148
+ end_idx = min(N, center_idx + window//2)
149
+ skip_factor = max(1, int(native_ds/zoom))
150
+ return start_idx, end_idx, skip_factor
151
+
152
+
153
+ def zin(sig, zoom, center, num_samples=300):
154
+ start_idx, end_idx, skip_factor = get_trim(
155
+ sig, zoom, center, num_samples=num_samples)
156
+ out = np.zeros(num_samples)
157
+ trimmed = sig[start_idx:end_idx:skip_factor]
158
+ out[:len(trimmed)] = trimmed[:num_samples]
159
+ return out
160
+
161
+
162
+ @interactive(
163
+ center=KeyboardControl(value_default=0.5, value_range=[
164
+ 0., 1.], step=0.01, keyup="6", keydown="4"),
165
+ zoom=KeyboardControl(value_default=0., value_range=[
166
+ 0., 15.], step=1, keyup="+", keydown="-"),
167
+ zoomy=KeyboardControl(
168
+ value_default=0., value_range=[-15., 15.], step=1, keyup="up", keydown="down")
169
+ )
170
+ def visualize_audio(signal: dict, mixed_signal, pred, zoom=1, zoomy=0., center=0.5, global_params={}):
171
+ """Create curves
172
+ """
173
+ zval = 1.5**zoom
174
+ start_idx, end_idx, _skip_factor = get_trim(
175
+ signal[BUFFERS][CLEAN][0, :], zval, center)
176
+ global_params["trim"] = dict(start=start_idx, end=end_idx)
177
+ selected = global_params.get("selected_audio", MIXED)
178
+ clean = SingleCurve(y=zin(signal[BUFFERS][CLEAN][0, :], zval, center),
179
+ alpha=1.,
180
+ style="k-",
181
+ linewidth=0.9,
182
+ label=("*" if selected == CLEAN else " ")+"clean")
183
+ noisy = SingleCurve(y=zin(signal[BUFFERS][NOISY][0, :], zval, center),
184
+ alpha=0.3,
185
+ style="y--",
186
+ linewidth=1,
187
+ label=("*" if selected == NOISY else " ") + "noisy"
188
+ )
189
+ mixed = SingleCurve(y=zin(mixed_signal[0, :], zval, center), style="r-",
190
+ alpha=0.1,
191
+ linewidth=2,
192
+ label=("*" if selected == MIXED else " ") + "mixed")
193
+ # true_mixed = SingleCurve(y=zin(signal[BUFFERS][MIXED][0, :], zval, center),
194
+ # alpha=0.3, style="b-", linewidth=1, label="true mixed")
195
+ pred.y = zin(pred.y, zval, center)
196
+ pred.label = ("*" if selected == PREDICTED else " ") + pred.label
197
+ curves = [noisy, mixed, pred, clean]
198
+ title = f"SNR in {global_params['snr']:.1f} dB"
199
+ if "selected_info" in global_params:
200
+ title += f" | {global_params['selected_info']}"
201
+ title += "\n"
202
+ for metric_name, metric_value in global_params.get("metrics", {}).items():
203
+ title += f" | {metric_name} "
204
+ title += f"{metric_value:.2e}" if (abs(metric_value) < 1e-2 or abs(metric_value)
205
+ > 1000) else f"{metric_value:.2f}"
206
+ # if global_params.get("premixed_snr", None) is not None:
207
+ # title += f"| Premixed SNR : {global_params['premixed_snr']:.1f} dB"
208
+ return Curve(curves, ylim=[-0.04 * 1.5 ** zoomy, 0.04 * 1.5 ** zoomy], xlabel="Time index", ylabel="Amplitude", title=title)
209
+
210
+
211
+ def interactive_audio_separation_processing(signals, model_list, config_list):
212
+ sig = signal_selector(signals)
213
+ mixed = remix(sig)
214
+ # sig, mixed = augment(sig, mixed)
215
+ select_device()
216
+ pred, pred_curve = audio_sep_inference(mixed, model_list, config_list)
217
+ compute_metrics(pred, sig)
218
+ sound = audio_selector(sig, mixed, pred)
219
+ curve = visualize_audio(sig, mixed, pred_curve)
220
+ trimmed_sound = audio_trim(sound)
221
+ audio_player(trimmed_sound)
222
+ return curve
223
+
224
+
225
+ def interactive_audio_separation_visualization(
226
+ all_signals: List[dict],
227
+ model_list: List[torch.nn.Module],
228
+ config_list: List[dict],
229
+ gui="qt"
230
+ ):
231
+ pip = HeadlessPipeline.from_function(
232
+ interactive_audio_separation_processing, cache=False)
233
+ if gui == "qt":
234
+ app = InteractivePipeQT(
235
+ pipeline=pip, name="audio separation", size=(1000, 1000), audio=True)
236
+ else:
237
+ logging.warning("No support for audio player with Matplotlib")
238
+ app = InteractivePipeMatplotlib(
239
+ pipeline=pip, name="audio separation", size=None, audio=False)
240
+ app(all_signals, model_list, config_list)
241
+
242
+
243
+ def visualization(
244
+ all_signals: List[dict],
245
+ model_list: List[torch.nn.Module],
246
+ config_list: List[dict],
247
+ device="cuda"
248
+ ):
249
+ for signal in all_signals:
250
+ if BUFFERS not in signal:
251
+ load_buffers(signal, device="cpu")
252
+ clean = SingleCurve(y=signal[BUFFERS][CLEAN][0, :], label="clean")
253
+ noisy = SingleCurve(y=signal[BUFFERS][NOISY]
254
+ [0, :], label="noise", alpha=0.3)
255
+ curves = [clean, noisy]
256
+ for config, model in zip(config_list, model_list):
257
+ short_name = config.get(SHORT_NAME, "unknown")
258
+ predicted_signal, predicted_noise = model(
259
+ signal[BUFFERS][MIXED].to(device).unsqueeze(0))
260
+ predicted = SingleCurve(y=predicted_signal.squeeze(0)[0, :].detach().cpu().numpy(),
261
+ label=f"predicted_{short_name}")
262
+ curves.append(predicted)
263
+ Curve(curves).show()
264
+
265
+
266
+ def parse_command_line(parser: Batch = None, gradio_demo=True) -> argparse.ArgumentParser:
267
+ if gradio_demo:
268
+ parser = parse_command_line_gradio(parser)
269
+ else:
270
+ parser = parse_command_line_generic(parser)
271
+ return parser
272
+
273
+
274
+ def parse_command_line_gradio(parser: Batch = None, gradio_demo=True) -> argparse.ArgumentParser:
275
+ if parser is None:
276
+ parser = parse_command_line_audio_load()
277
+ default_device = "cuda" if torch.cuda.is_available() else "cpu"
278
+ iparse = parser.add_argument_group("Audio separation visualization")
279
+ iparse.add_argument("-e", "--experiments", type=int, nargs="+", default=3001,
280
+ help="Experiment ids to be inferred sequentially")
281
+ iparse.add_argument("-p", "--interactive", default=True,
282
+ action="store_true", help="Play = Interactive mode")
283
+ iparse.add_argument("-m", "--model-root", type=str,
284
+ default=EXPERIMENT_STORAGE_ROOT)
285
+ iparse.add_argument("-d", "--device", type=str, default=default_device,
286
+ choices=["cpu", "cuda"] if default_device == "cuda" else ["cpu"])
287
+ iparse.add_argument("-gui", "--gui", type=str,
288
+ default="gradio", choices=["qt", "mpl", "gradio"])
289
290
+ return parser
291
+
292
+
293
+ def parse_command_line_generic(parser: Batch = None, gradio_demo=True) -> argparse.ArgumentParser:
294
+ if parser is None:
295
+ parser = parse_command_line_audio_load()
296
+ default_device = "cuda" if torch.cuda.is_available() else "cpu"
297
+ iparse = parser.add_argument_group("Audio separation visualization")
298
+ iparse.add_argument("-e", "--experiments", type=int, nargs="+", required=True,
299
+ help="Experiment ids to be inferred sequentially")
300
+ iparse.add_argument("-p", "--interactive",
301
+ action="store_true", help="Play = Interactive mode")
302
+ iparse.add_argument("-m", "--model-root", type=str,
303
+ default=EXPERIMENT_STORAGE_ROOT)
304
+ iparse.add_argument("-d", "--device", type=str, default=default_device,
305
+ choices=["cpu", "cuda"] if default_device == "cuda" else ["cpu"])
306
+ iparse.add_argument("-gui", "--gui", type=str,
307
+ default="qt", choices=["qt", "mpl", "gradio"])
308
+ return parser
309
+
310
+
311
+ def main(argv: List[str]):
312
+ """Paired signals and noise in folders"""
313
+ batch = Batch(argv)
314
+ batch.set_io_description(
315
+ input_help='input audio files',
316
+ output_help=argparse.SUPPRESS
317
+ )
318
+ batch.set_multiprocessing_enabled(False)
319
+ parser = parse_command_line()
320
+ args = batch.parse_args(parser)
321
+ exp = args.experiments[0]
322
+ device = args.device
323
+ models_list = []
324
+ config_list = []
325
+ logging.info(f"Loading experiments models {args.experiments}")
326
+ for exp in args.experiments:
327
+ model_dir = Path(args.model_root)
328
+ short_name, model, config, _dl = get_experience(exp)
329
+ _, exp_dir = get_output_folder(
330
+ config, root_dir=model_dir, override=False)
331
+ assert exp_dir.exists(
332
+ ), f"Experiment {short_name} does not exist in {model_dir}"
333
+ model.eval()
334
+ model.to(device)
335
+ model, __optimizer, epoch, config = load_checkpoint(
336
+ model, exp_dir, epoch=None, device=args.device)
337
+ config[SHORT_NAME] = short_name
338
+ models_list.append(model)
339
+ config_list.append(config)
340
+ logging.info("Load audio buffers:")
341
+ all_signals = batch.run(audio_loading_batch)
342
+ if not args.interactive:
343
+ visualization(all_signals, models_list, config_list, device=device)
344
+ else:
345
+ interactive_audio_separation_visualization(
346
+ all_signals, models_list, config_list, gui=args.gui)
347
+
348
+
349
+ def main_custom(argv: List[str]):
350
+ """Handle custom noise and custom signals
351
+ """
352
+ parser = parse_command_line()
353
+ parser.add_argument("-s", "--signal", type=str, required=True,
354
+ nargs="+", help="Signal to be preloaded")
355
+ parser.add_argument("-n", "--noise", type=str, required=True,
356
+ nargs="+", help="Noise to be preloaded")
357
+ args = parser.parse_args(argv)
358
+ exp = args.experiments[0]
359
+ device = args.device
360
+ models_list = []
361
+ config_list = []
362
+ logging.info(f"Loading experiments models {args.experiments}")
363
+ for exp in args.experiments:
364
+ model_dir = Path(args.model_root)
365
+ short_name, model, config, _dl = get_experience(exp)
366
+ _, exp_dir = get_output_folder(
367
+ config, root_dir=model_dir, override=False)
368
+ assert exp_dir.exists(
369
+ ), f"Experiment {short_name} does not exist in {model_dir}"
370
+ model.eval()
371
+ model.to(device)
372
+ model, __optimizer, epoch, config = load_checkpoint(
373
+ model, exp_dir, epoch=None, device=args.device)
374
+ config[SHORT_NAME] = short_name
375
+ models_list.append(model)
376
+ config_list.append(config)
377
+ all_signals = {}
378
+ for args_paths, key in zip([args.signal, args.noise], [CLEAN, NOISY]):
379
+ new_argv = ["-i"] + args_paths
380
+ if args.preload:
381
+ new_argv += ["--preload"]
382
+ batch = Batch(new_argv)
383
+ new_parser = parse_command_line_generic_audio_load()
384
+ batch.set_io_description(
385
+ input_help=argparse.SUPPRESS, # 'input audio files',
386
+ output_help=argparse.SUPPRESS
387
+ )
388
+ batch.set_multiprocessing_enabled(False)
389
+ _ = batch.parse_args(new_parser)
390
+ all_signals[key] = batch.run(generic_audio_loading_batch)
391
+ interactive_audio_separation_visualization(
392
+ all_signals, models_list, config_list, gui=args.gui)
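For reference, the `remix` stage above scales the noise by alpha = 10^(−SNR/20)·‖clean‖/‖noise‖ so that the resulting mix hits the requested SNR exactly; a standalone check of that rule (not part of this diff; the shapes and the 6 dB target are arbitrary):

```python
import torch

clean, noise = torch.randn(1, 8000), torch.randn(1, 8000)
target_snr = 6.0  # dB
alpha = 10 ** (-target_snr / 20) * torch.norm(clean) / torch.norm(noise)
mix = clean + alpha * noise
measured = 10 * torch.log10(torch.sum(clean ** 2) / torch.sum((alpha * noise) ** 2))
print(measured)  # ≈ 6.0 dB by construction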
src/gyraudio/audio_separation/visualization/interactive_infer.py ADDED
@@ -0,0 +1,117 @@
1
+ from gyraudio.default_locations import EXPERIMENT_STORAGE_ROOT
2
+ from gyraudio.audio_separation.parser import shared_parser
3
+ from gyraudio.audio_separation.infer import launch_infer, RECORD_KEYS, SNR_OUT, SNR_IN, NBATCH, SAVE_IDX
4
+ from gyraudio.audio_separation.properties import TEST, NAME, SHORT_NAME, CURRENT_EPOCH, SNR_FILTER
5
+ import sys
6
+ import os
7
+ from dash import Dash, html, dcc, callback, Output, Input, dash_table
8
+ import plotly.express as px
9
+ import plotly.graph_objects as go
10
+ from plotly.subplots import make_subplots
11
+ import pandas as pd
12
+ from typing import List
13
+ import torch
14
+ from pathlib import Path
15
+ DIFF_SNR = 'SNR out - SNR in'
16
+
17
+
18
+
19
+ def get_app(record_row_dfs: pd.DataFrame, eval_dfs: List[pd.DataFrame]):
20
+ app = Dash(__name__)
21
+ # names_options = [{'label' : f"{record[SHORT_NAME]} - {record[NAME]} epoch {record[CURRENT_EPOCH]:04d}", 'value' : record[NAME] } for idx, record in record_row_dfs.iterrows()]
22
+ app.layout = html.Div([
23
+ html.H1(children='Inference results', style={'textAlign':'center'}),
24
+ # dcc.Dropdown(names_options, names_options[0]['value'], id='exp-selection'),
25
+ # dcc.RadioItems(['scatter', 'box'], 'box', inline=True, id='radio-plot-type'),
26
+ dcc.RadioItems([SNR_OUT, DIFF_SNR], DIFF_SNR, inline=True, id='radio-plot-out'),
27
+ dcc.Graph(id='graph-content')
28
+ ])
29
+
30
+ @callback(
31
+ Output('graph-content', 'figure'),
32
+ # Input('exp-selection', 'value'),
33
+ # Input('radio-plot-type', 'value'),
34
+ Input('radio-plot-out', 'value'),
35
+ )
36
+ def update_graph(radio_plot_out) :
37
+ fig = make_subplots(rows = 2, cols = 1)
38
+ colors = px.colors.qualitative.Plotly
39
+ for id, record in record_row_dfs.iterrows() :
40
+ color = colors[id % len(colors)]
41
+ eval_df = eval_dfs[id].sort_values(by=SNR_IN)
42
+ eval_df[DIFF_SNR] = eval_df[SNR_OUT] - eval_df[SNR_IN]
43
+ legend = f'{record[SHORT_NAME]}_{record[NAME]}'
44
+ fig.add_trace(
45
+ go.Scatter(
46
+ x=eval_df[SNR_IN],
47
+ y=eval_df[radio_plot_out],
48
+ mode="markers", marker={'color' : color},
49
+ name=legend,
50
+ hovertemplate = 'File : %{text}'+
51
+ '<br>%{y}<br>',
52
+ text = [f"{eval[SAVE_IDX]:.0f}" for idx, eval in eval_df.iterrows()]
53
+ ),
54
+ row = 1, col = 1
55
+ )
56
+ eval_df_bins = eval_df
57
+ eval_df_bins[SNR_IN] = eval_df_bins[SNR_IN].apply(lambda snr : round(snr))
58
+ fig.add_trace(
59
+ go.Box(
60
+ x=eval_df[SNR_IN],
61
+ y=eval_df[radio_plot_out],
62
+ fillcolor = color,
63
+ marker={'color' : color},
64
+ name = legend
65
+ ),
66
+ row = 2, col = 1
67
+ )
68
+
69
+ title = "SNR performance"
70
+ fig.update_layout(
71
+ title=title,
72
+ xaxis2_title = SNR_IN,
73
+ yaxis_title = radio_plot_out,
74
+ hovermode='x unified'
75
+ )
76
+ return fig
77
+
78
+
79
+
80
+ return app
81
+
82
+
83
+ def main(argv):
84
+ default_device = "cuda" if torch.cuda.is_available() else "cpu"
85
+ parser_def = shared_parser(help="Launch training \nCheck results at: https://wandb.ai/balthazarneveu/audio-sep"
86
+ + ("\n<<<Cuda available>>>" if default_device == "cuda" else ""))
87
+ parser_def.add_argument("-i", "--input-dir", type=str, default=EXPERIMENT_STORAGE_ROOT)
88
+ parser_def.add_argument("-o", "--output-dir", type=str, default=EXPERIMENT_STORAGE_ROOT)
89
+ parser_def.add_argument("-d", "--device", type=str, default=default_device,
90
+ help="Training device", choices=["cpu", "cuda"])
91
+ parser_def.add_argument("-b", "--nb-batch", type=int, default=None,
92
+ help="Number of batches to process")
93
+ parser_def.add_argument("-s", "--snr-filter", type=float, nargs="+", default=None,
94
+ help="SNR filters on the inference dataloader")
95
+ args = parser_def.parse_args(argv)
96
+ record_row_dfs = pd.DataFrame(columns = RECORD_KEYS)
97
+ eval_dfs = []
98
+ for exp in args.experiments:
99
+ record_row_df, evaluation_path = launch_infer(
100
+ exp,
101
+ model_dir=Path(args.input_dir),
102
+ output_dir=Path(args.output_dir),
103
+ device=args.device,
104
+ max_batches=args.nb_batch,
105
+ snr_filter=args.snr_filter
106
+ )
107
+ eval_df = pd.read_csv(evaluation_path)
108
+ # Careful, list order for concat is important for index matching eval_dfs list
109
+ record_row_dfs = pd.concat([record_row_dfs.loc[:], record_row_df], ignore_index=True)
110
+ eval_dfs.append(eval_df)
111
+ app = get_app(record_row_dfs, eval_dfs)
112
+ app.run(debug=True)
113
+
114
+
115
+ if __name__ == '__main__':
116
+ os.environ["KMP_DUPLICATE_LIB_OK"] = "True"
117
+ main(sys.argv[1:])
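For reference, this dashboard is launched from the CLI like the other entry points, e.g. `python src/gyraudio/audio_separation/visualization/interactive_infer.py -e 1002 3001 -d cpu -b 2` (the experiment ids, device and batch cap shown here are illustrative); it first runs or reuses the inference pass via `launch_infer`, then serves the Plotly/Dash figures locally.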
src/gyraudio/audio_separation/visualization/mute_logo.png ADDED
src/gyraudio/audio_separation/visualization/play_logo_clean.png ADDED
src/gyraudio/audio_separation/visualization/play_logo_mixed.png ADDED
src/gyraudio/audio_separation/visualization/play_logo_noise.png ADDED
src/gyraudio/audio_separation/visualization/play_logo_pred.png ADDED
src/gyraudio/audio_separation/visualization/pre_load_audio.py ADDED
@@ -0,0 +1,71 @@
1
+ from batch_processing import Batch
2
+ import argparse
3
+ import sys
4
+ from pathlib import Path
5
+ from gyraudio.audio_separation.properties import CLEAN, NOISY, MIXED, PATHS, BUFFERS, NAME, SAMPLING_RATE
6
+ from gyraudio.io.audio import load_audio_tensor
7
+
8
+
9
+ def parse_command_line_audio_load() -> argparse.ArgumentParser:
10
+ parser = argparse.ArgumentParser(description='Batch audio processing',
11
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter)
12
+
13
+ parser.add_argument("-preload", "--preload", action="store_true", help="Preload audio files")
14
+ return parser
15
+
16
+
17
+ def outp(path: Path, suffix: str, extension=".wav"):
18
+ return (path.parent / (path.stem + suffix)).with_suffix(extension)
19
+
20
+
21
+ def load_buffers(signal: dict, device="cpu") -> None:
22
+ clean_signal, sampling_rate = load_audio_tensor(signal[PATHS][CLEAN], device=device)
23
+ noisy_signal, sampling_rate = load_audio_tensor(signal[PATHS][NOISY], device=device)
24
+ mixed_signal, sampling_rate = load_audio_tensor(signal[PATHS][MIXED], device=device)
25
+ signal[BUFFERS] = {
26
+ CLEAN: clean_signal,
27
+ NOISY: noisy_signal,
28
+ MIXED: mixed_signal
29
+ }
30
+ signal[SAMPLING_RATE] = sampling_rate
31
+
32
+
33
+ def audio_loading(input: Path, preload: bool) -> dict:
34
+ name = input.name
35
+ clean_audio_path = input/"voice.wav"
36
+ noisy_audio_path = input/"noise.wav"
37
+ mixed_audio_path = list(input.glob("mix*.wav"))[0]
38
+ signal = {
39
+ NAME: name,
40
+ PATHS: {
41
+ CLEAN: clean_audio_path,
42
+ NOISY: noisy_audio_path,
43
+ MIXED: mixed_audio_path
44
+ }
45
+ }
46
+ signal["premixed_snr"] = float(mixed_audio_path.stem.split("_")[-1])
47
+ if preload:
48
+ load_buffers(signal)
49
+ return signal
50
+
51
+
52
+ def audio_loading_batch(input: Path, args: argparse.Namespace) -> dict:
53
+ """Wrapper to load audio files from a directory using batch_processing
54
+ """
55
+ return audio_loading(input, preload=args.preload)
56
+
57
+
58
+ def main(argv):
59
+ batch = Batch(argv)
60
+ batch.set_io_description(
61
+ input_help='input audio files',
62
+ output_help=argparse.SUPPRESS
63
+ )
64
+ parser = parse_command_line_audio_load()
65
+ batch.parse_args(parser)
66
+ all_signals = batch.run(audio_loading_batch)
67
+ return all_signals
68
+
69
+
70
+ if __name__ == "__main__":
71
+ main(sys.argv[1:])
src/gyraudio/audio_separation/visualization/pre_load_custom_audio.py ADDED
@@ -0,0 +1,53 @@
1
+ from batch_processing import Batch
2
+ import argparse
3
+ import sys
4
+ from pathlib import Path
5
+ from gyraudio.audio_separation.properties import PATHS, BUFFERS, NAME, SAMPLING_RATE
6
+ from gyraudio.io.audio import load_audio_tensor
7
+
8
+
9
+ def parse_command_line_generic_audio_load() -> argparse.ArgumentParser:
10
+ parser = argparse.ArgumentParser(description='Batch audio loading',
11
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter)
12
+
13
+ parser.add_argument("-preload", "--preload", action="store_true", help="Preload audio files")
14
+ return parser
15
+
16
+
17
+ def load_buffers_custom(signal: dict, device="cpu") -> None:
18
+ generic_signal, sampling_rate = load_audio_tensor(signal[PATHS], device=device)
19
+ signal[BUFFERS] = generic_signal
20
+ signal[SAMPLING_RATE] = sampling_rate
21
+
22
+
23
+ def audio_loading(input: Path, preload: bool) -> dict:
24
+ name = input.parent.name + "/" + input.stem
25
+ signal = {
26
+ NAME: name,
27
+ PATHS: input,
28
+ }
29
+ if preload:
30
+ load_buffers_custom(signal)
31
+ return signal
32
+
33
+
34
+ def generic_audio_loading_batch(input: Path, args: argparse.Namespace) -> dict:
35
+ """Wrapper to load audio files from a directory using batch_processing
36
+ """
37
+ return audio_loading(input, preload=args.preload)
38
+
39
+
40
+ def main(argv):
41
+ batch = Batch(argv)
42
+ batch.set_io_description(
43
+ input_help='input audio files',
44
+ output_help=argparse.SUPPRESS
45
+ )
46
+ parser = parse_command_line_generic_audio_load()
47
+ batch.parse_args(parser)
48
+ all_signals = batch.run(generic_audio_loading_batch)
49
+ return all_signals
50
+
51
+
52
+ if __name__ == "__main__":
53
+ main(sys.argv[1:])