diff --git a/__data_source_separation/source_separation/test/0000/mix_snr_-4.wav b/__data_source_separation/source_separation/test/0000/mix_snr_-4.wav new file mode 100644 index 0000000000000000000000000000000000000000..b2e6c3bfb1b2394ff89a3f9f4101a5e44682804f Binary files /dev/null and b/__data_source_separation/source_separation/test/0000/mix_snr_-4.wav differ diff --git a/__data_source_separation/source_separation/test/0000/noise.wav b/__data_source_separation/source_separation/test/0000/noise.wav new file mode 100644 index 0000000000000000000000000000000000000000..50b439ed797541c70fedee2bee9a198572d3fbab Binary files /dev/null and b/__data_source_separation/source_separation/test/0000/noise.wav differ diff --git a/__data_source_separation/source_separation/test/0000/voice.wav b/__data_source_separation/source_separation/test/0000/voice.wav new file mode 100644 index 0000000000000000000000000000000000000000..b1d1b1b19a81976b288557414723c4845d267adc Binary files /dev/null and b/__data_source_separation/source_separation/test/0000/voice.wav differ diff --git a/__data_source_separation/source_separation/test/0001/mix_snr_2.wav b/__data_source_separation/source_separation/test/0001/mix_snr_2.wav new file mode 100644 index 0000000000000000000000000000000000000000..691a32f0411a3cd22fab2cae5c2cc3b23bf7983d Binary files /dev/null and b/__data_source_separation/source_separation/test/0001/mix_snr_2.wav differ diff --git a/__data_source_separation/source_separation/test/0001/noise.wav b/__data_source_separation/source_separation/test/0001/noise.wav new file mode 100644 index 0000000000000000000000000000000000000000..ec7155a76ea4ad760805a3333b0d4aebabf7f56f Binary files /dev/null and b/__data_source_separation/source_separation/test/0001/noise.wav differ diff --git a/__data_source_separation/source_separation/test/0001/voice.wav b/__data_source_separation/source_separation/test/0001/voice.wav new file mode 100644 index 0000000000000000000000000000000000000000..fb687f9c81e744974b17c32c4360f67fc88f4ff0 Binary files /dev/null and b/__data_source_separation/source_separation/test/0001/voice.wav differ diff --git a/__output_audiosep/0004_0000/model_0059.pt b/__output_audiosep/0004_0000/model_0059.pt new file mode 100644 index 0000000000000000000000000000000000000000..9df73d2b10711d150af3e7f6951ae74b19462e5b --- /dev/null +++ b/__output_audiosep/0004_0000/model_0059.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ef62b0fdd9e9b81000da0db190be1df7e5451d70ab3c86609cab409ec3e38ab8 +size 34402 diff --git a/__output_audiosep/1004_0000/model_0119.pt b/__output_audiosep/1004_0000/model_0119.pt new file mode 100644 index 0000000000000000000000000000000000000000..486897ba8edf9b6029db12850c946b9f2b1d7038 --- /dev/null +++ b/__output_audiosep/1004_0000/model_0119.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:61768dbf9e81845dc656c5da3e8fe5ea053f73108f422caf7879b4ab55a0792a +size 12755810 diff --git a/__output_audiosep/3001_0000/model_0199.pt b/__output_audiosep/3001_0000/model_0199.pt new file mode 100644 index 0000000000000000000000000000000000000000..ab529c812b2c49e869d122738bde06f82e26f2dd --- /dev/null +++ b/__output_audiosep/3001_0000/model_0199.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:11341d63bad0a1ec15c5a94c5cc6f049720869fb9988da45ef7d07edd3c82e21 +size 12743211 diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..c1ddedb381eb376aacb4eac782b82a861a23f980 --- /dev/null +++ 
b/app.py @@ -0,0 +1,7 @@ +import sys +import os +src_path = os.path.abspath(os.path.join(os.path.dirname(__file__), 'src')) +os.sys.path.append(src_path) +from gyraudio.audio_separation.visualization.interactive_audio import main as interactive_audio_main +if __name__ == "__main__": + interactive_audio_main(sys.argv[1:]) diff --git a/audio_samples/0009/mix_snr_-1.wav b/audio_samples/0009/mix_snr_-1.wav new file mode 100644 index 0000000000000000000000000000000000000000..eabbc1abc48115022ec91784d5526783caf7c658 Binary files /dev/null and b/audio_samples/0009/mix_snr_-1.wav differ diff --git a/audio_samples/0009/noise.wav b/audio_samples/0009/noise.wav new file mode 100644 index 0000000000000000000000000000000000000000..ad31680043a11f3288ff0d6d558681c078f77dcd Binary files /dev/null and b/audio_samples/0009/noise.wav differ diff --git a/audio_samples/0009/voice.wav b/audio_samples/0009/voice.wav new file mode 100644 index 0000000000000000000000000000000000000000..e08679c58aadd0313565ed475d126a7e234551d5 Binary files /dev/null and b/audio_samples/0009/voice.wav differ diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..d3e98baa6b00e88de5c4e4ea5c24a82246cc4738 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,3 @@ +batch_processing +interactive-pipe>=0.7.0 +torch>=2.0.0 diff --git a/src/gyraudio/__init__.py b/src/gyraudio/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..39c149120fc0b690fbf4c5557336eaabac0a6e13 --- /dev/null +++ b/src/gyraudio/__init__.py @@ -0,0 +1,2 @@ +from pathlib import Path +root_dir = Path(__file__).parent.parent.parent diff --git a/src/gyraudio/audio_separation/architecture/building_block.py b/src/gyraudio/audio_separation/architecture/building_block.py new file mode 100644 index 0000000000000000000000000000000000000000..82b791a79ecd847614f42e63daf0ba305c1c6b63 --- /dev/null +++ b/src/gyraudio/audio_separation/architecture/building_block.py @@ -0,0 +1,51 @@ +import torch +from typing import List + + +class FilterBank(torch.nn.Module): + """Convolution filter bank (linear) + Serves as an embedding for the audio signal + """ + + def __init__(self, ch_in: int, out_dim=16, k_size=5, dilation_list: List[int] = [1, 2, 4, 8]): + super().__init__() + self.out_dim = out_dim + self.source_modality_conv = torch.nn.ModuleList() + for dilation in dilation_list: + self.source_modality_conv.append( + torch.nn.Conv1d(ch_in, out_dim//len(dilation_list), k_size, dilation=dilation, padding=(dilation*(k_size//2))) + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + out = torch.cat([conv(x) for conv in self.source_modality_conv], axis=1) + assert out.shape[1] == self.out_dim + return out + + +class ResConvolution(torch.nn.Module): + """ResNet building block + https://paperswithcode.com/method/residual-connection + """ + + def __init__(self, ch, hdim=None, k_size=5): + super().__init__() + hdim = hdim or ch + self.conv1 = torch.nn.Conv1d(ch, hdim, k_size, padding=k_size//2) + self.conv2 = torch.nn.Conv1d(hdim, ch, k_size, padding=k_size//2) + self.non_linearity = torch.nn.ReLU() + + def forward(self, x_in): + x = self.conv1(x_in) + x = self.non_linearity(x) + x = self.conv2(x) + x += x_in + x = self.non_linearity(x) + return x + + +if __name__ == "__main__": + model = FilterBank(1, 16) + inp = torch.rand(2, 1, 2048) + out = model(inp) + print(model) + print(out[0].shape) diff --git a/src/gyraudio/audio_separation/architecture/flat_conv.py 
b/src/gyraudio/audio_separation/architecture/flat_conv.py new file mode 100644 index 0000000000000000000000000000000000000000..4ee9b47a838c7666e8d19dab0fa3b48c8d092905 --- /dev/null +++ b/src/gyraudio/audio_separation/architecture/flat_conv.py @@ -0,0 +1,62 @@ +import torch +from gyraudio.audio_separation.architecture.model import SeparationModel +from typing import Tuple + + +class FlatConvolutional(SeparationModel): + """Convolutional neural network for audio separation, + No decimation, no bottleneck, just basic signal processing + """ + + def __init__(self, + ch_in: int = 1, + ch_out: int = 2, + h_dim=16, + k_size=5, + dilation=1 + ) -> None: + super().__init__() + self.conv1 = torch.nn.Conv1d( + ch_in, h_dim, k_size, + dilation=dilation, padding=dilation*(k_size//2)) + self.conv2 = torch.nn.Conv1d( + h_dim, h_dim, k_size, + dilation=dilation, padding=dilation*(k_size//2)) + self.conv3 = torch.nn.Conv1d( + h_dim, h_dim, k_size, + dilation=dilation, padding=dilation*(k_size//2)) + self.conv4 = torch.nn.Conv1d( + h_dim, h_dim, k_size, + dilation=dilation, padding=dilation*(k_size//2)) + self.relu = torch.nn.ReLU() + self.encoder = torch.nn.Sequential( + self.conv1, + self.relu, + self.conv2, + self.relu, + self.conv3, + self.relu, + self.conv4, + self.relu + ) + self.demux = torch.nn.Sequential(*( + torch.nn.Conv1d(h_dim, h_dim//2, 1), # conv1x1 + torch.nn.ReLU(), + torch.nn.Conv1d(h_dim//2, ch_out, 1), # conv1x1 + )) + + def forward(self, mixed_sig_in: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """Perform feature extraction followed by classifier head + + Args: + sig_in (torch.Tensor): [N, C, T] + + Returns: + torch.Tensor: logits (not probabilities) [N, n_classes] + """ + # Convolution backbone + # [N, C, T] -> [N, h, T] + features = self.encoder(mixed_sig_in) + # [N, h, T] -> [N, 2, T] + demuxed = self.demux(features) + return torch.chunk(demuxed, 2, dim=1) # [N, 1, T], [N, 1, T] diff --git a/src/gyraudio/audio_separation/architecture/model.py b/src/gyraudio/audio_separation/architecture/model.py new file mode 100644 index 0000000000000000000000000000000000000000..ab38c4e89c40592b9179d8ba98302deecbfea33b --- /dev/null +++ b/src/gyraudio/audio_separation/architecture/model.py @@ -0,0 +1,28 @@ +import torch + + +class SeparationModel(torch.nn.Module): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + def count_parameters(self) -> int: + """Count the total number of parameters of the model + + Returns: + int: total amount of parameters + """ + return sum(p.numel() for p in self.parameters() if p.requires_grad) + + def receptive_field(self) -> int: + """Compute the receptive field of the model + + Returns: + int: receptive field + """ + input_tensor = torch.rand(1, 1, 4096, requires_grad=True) + out, out_noise = self.forward(input_tensor) + grad = torch.zeros_like(out) + grad[..., out.shape[-1]//2] = torch.nan # set NaN gradient at the middle of the output + out.backward(gradient=grad) + self.zero_grad() # reset to avoid future problems + return int(torch.sum(input_tensor.grad.isnan()).cpu()) # Count NaN in the input diff --git a/src/gyraudio/audio_separation/architecture/neutral.py b/src/gyraudio/audio_separation/architecture/neutral.py new file mode 100644 index 0000000000000000000000000000000000000000..b81c5319a173d5ce1fad5ef95713c3ed10e29309 --- /dev/null +++ b/src/gyraudio/audio_separation/architecture/neutral.py @@ -0,0 +1,15 @@ + +import torch +from gyraudio.audio_separation.architecture.model import SeparationModel + + +class 
NeutralModel(SeparationModel): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.fake = torch.nn.Conv1d(1, 1, 1, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Identity function + """ + n = self.fake(x) + return x, n diff --git a/src/gyraudio/audio_separation/architecture/transformer.py b/src/gyraudio/audio_separation/architecture/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..40467c6cf0b4697d00c85bd18d23ff511061d06b --- /dev/null +++ b/src/gyraudio/audio_separation/architecture/transformer.py @@ -0,0 +1,91 @@ +import torch +from gyraudio.audio_separation.architecture.model import SeparationModel +from gyraudio.audio_separation.architecture.building_block import FilterBank +from typing import Optional + + +class TransformerModel(SeparationModel): + """Transformer base model + ========================= + - Embed signal with a filter bank + - No positional encoding (Potential =add/concatenate positional encoding) + - `nlayers` * transformer blocks + """ + + def __init__(self, + nhead: int = 8, # H + nlayers: int = 4, # L + k_size=5, + embedding_dim: int = 64, # D + ch_in: int = 1, + ch_out: int = 1, + dropout: float = 0., # dr + positional_encoding: str = None + ) -> None: + """Transformer base model + + Args: + nhead (int): number of heads in each of the MHA models + embedding_dim (int): D number of channels in the audio embeddings + = output of the filter bank + assume `embedding_dim` = `h_dim` + h_dim is the hidden dimension of the model. + nlayers (int): number of nn.TransformerEncoderLayer in nn.TransformerEncoder + dropout (float, optional): dropout value. Defaults to 0. + """ + super().__init__() + self.model_type = "Transformer" + h_dim = embedding_dim # use the same embedding & hidden dimensions + + self.encoder = FilterBank(ch_in, embedding_dim, k_size=k_size) + if positional_encoding is None: + self.pos_encoder = torch.nn.Identity() + else: + raise NotImplementedError( + f"Unknown positional encoding {positional_encoding} - should be add/concat in future") + # self.pos_encoder = PositionalEncoding(h_dim, dropout=dropout) + + encoder_layers = torch.nn.TransformerEncoderLayer( + d_model=h_dim, # input dimension to the transformer encoder layer + nhead=nhead, # number of heads for MHA (Multi-head attention) + dim_feedforward=h_dim, # output dimension of the MLP on top of the transformer. + dropout=dropout, + batch_first=True + ) # we assume h_dim = d_model = dim_feedforward + + self.transformer_encoder = torch.nn.TransformerEncoder( + encoder_layers, + num_layers=nlayers + ) + self.h_dim = h_dim + self.target_modality_conv = torch.nn.Conv1d(h_dim, ch_out, 1) # conv1x1 channel mixer + # Note: we could finish with a few residual conv blocks... this is pure signal processing + + def forward( + self, src: torch.LongTensor, + src_mask: Optional[torch.FloatTensor] = None + ) -> torch.FloatTensor: + """Embdeddings, positional encoders, go trough `nlayers` of residual {multi (`nhead`) attention heads + MLP}. 
+ + Args: + src (torch.LongTensor): [N, 1, T] audio signal + + Returns: + torch.FloatTensor: separated signal [N, 1, T] + """ + src = self.encoder(src) # [N, 1, T] -> [N, D, T] + src = src.transpose(-1, -2) # [N, D, T] -> [N, T, D] # Transformer expects (batch N, seq "T", features "D") + src = self.pos_encoder(src) # -> [N, T, D] - add positional encoding + + output = self.transformer_encoder(src, mask=src_mask) # -> [N, T, D] + output = output.transpose(-1, -2) # -> [N, D, T] + output = self.target_modality_conv(output) # -> [N, 1, T] + return output, None + + +if __name__ == "__main__": + model = TransformerModel() + inp = torch.rand(2, 1, 2048) + out = model(inp) + print(model) + print(out[0].shape) diff --git a/src/gyraudio/audio_separation/architecture/unet.py b/src/gyraudio/audio_separation/architecture/unet.py new file mode 100644 index 0000000000000000000000000000000000000000..89bf30f2d2f360e14e91e2bdd18361984161506f --- /dev/null +++ b/src/gyraudio/audio_separation/architecture/unet.py @@ -0,0 +1,151 @@ +import torch +from gyraudio.audio_separation.architecture.model import SeparationModel +from gyraudio.audio_separation.architecture.building_block import ResConvolution +from typing import Optional +# import logging + + +class EncoderSingleStage(torch.nn.Module): + """ + Extend channels + Resnet + Downsample by 2 + """ + + def __init__(self, ch: int, ch_out: int, hdim: Optional[int] = None, k_size=5): + # ch_out ~ ch_in*extension_factor + super().__init__() + hdim = hdim or ch + self.extension_conv = torch.nn.Conv1d(ch, ch_out, k_size, padding=k_size//2) + self.res_conv = ResConvolution(ch_out, hdim=hdim, k_size=k_size) + # warning on maxpooling jitter offset! + self.max_pool = torch.nn.MaxPool1d(kernel_size=2) + + def forward(self, x): + x = self.extension_conv(x) + x = self.res_conv(x) + x_ds = self.max_pool(x) + return x, x_ds + + +class DecoderSingleStage(torch.nn.Module): + """ + Upsample by 2 + Resnet + Extend channels + """ + + def __init__(self, ch: int, ch_out: int, hdim: Optional[int] = None, k_size=5): + """Decoder stage + + Args: + ch (int): channel size (downsampled & skip connection have same channel size) + ch_out (int): number of output channels (shall match the number of input channels of the next stage) + hdim (Optional[int], optional): Hidden dimension used in the residual block. Defaults to None. + k_size (int, optional): Convolution size. Defaults to 5. 
+ Notes: + ====== + ch_out = 2*ch/extension_factor + + self.scale_mixers_conv + - tells how lower decoded (x_ds) scale is merged with current encoded scale (x_skip) + - could be a pointwise aka conv1x1 + """ + + super().__init__() + hdim = hdim or ch + self.scale_mixers_conv = torch.nn.Conv1d(2*ch, ch_out, k_size, padding=k_size//2) + + self.res_conv = ResConvolution(ch_out, hdim=hdim, k_size=k_size) + # warning: Linear interpolation shall be "conjugated" with the skipping downsampling + # special care shall be taken care of regarding offsets + # https://arxiv.org/abs/1806.03185 + self.upsample = torch.nn.Upsample(scale_factor=2, mode="linear", align_corners=True) + self.non_linearity = torch.nn.ReLU() + + def forward(self, x_ds: torch.Tensor, x_skip: torch.Tensor) -> torch.Tensor: + """""" + x_us = self.upsample(x_ds) # [N, ch, T/2] -> [N, ch, T] + + x = torch.cat([x_us, x_skip], dim=1) # [N, 2.ch, T] + x = self.scale_mixers_conv(x) # [N, ch_out, T] + x = self.non_linearity(x) + x = self.res_conv(x) # [N, ch_out, T] + return x + + +class ResUNet(SeparationModel): + """Convolutional neural network for audio separation, + + Decimation, bottleneck + """ + + def __init__(self, + ch_in: int = 1, + ch_out: int = 2, + channels_extension: float = 1.5, + h_dim=16, + k_size=5, + ) -> None: + super().__init__() + self.need_split = ch_out != ch_in + self.ch_out = ch_out + self.source_modality_conv = torch.nn.Conv1d(ch_in, h_dim, k_size, padding=k_size//2) + self.encoder_list = torch.nn.ModuleList() + self.decoder_list = torch.nn.ModuleList() + self.non_linearity = torch.nn.ReLU() + + h_dim_current = h_dim + for _level in range(4): + h_dim_ds = int(h_dim_current*channels_extension) + self.encoder_list.append(EncoderSingleStage(h_dim_current, h_dim_ds, k_size=k_size)) + self.decoder_list.append(DecoderSingleStage(h_dim_ds, h_dim_current, k_size=k_size)) + h_dim_current = h_dim_ds + self.bottleneck = ResConvolution(h_dim_current, k_size=k_size) + self.target_modality_conv = torch.nn.Conv1d(h_dim, ch_out, 1) # conv1x1 channel mixer + + def forward(self, x_in): + # x_in (1, 2048) + x0 = self.source_modality_conv(x_in) + x0 = self.non_linearity(x0) + # x0 -> (16, 2048) + + x1_skip, x1_ds = self.encoder_list[0](x0) + # x1_skip -> (24, 2048) + # x1_ds -> (24, 1024) + # print(x1_skip.shape, x1_ds.shape) + + x2_skip, x2_ds = self.encoder_list[1](x1_ds) + # x2_skip -> (36, 1024) + # x2_ds -> (36, 512) + # print(x2_skip.shape, x2_ds.shape) + + x3_skip, x3_ds = self.encoder_list[2](x2_ds) + # x3_skip -> (54, 512) + # x3_ds -> (54, 256) + # print(x3_skip.shape, x3_ds.shape) + + x4_skip, x4_ds = self.encoder_list[3](x3_ds) + # x4_skip -> (81, 256) + # x4_ds -> (81, 128) + # print(x4_skip.shape, x4_ds.shape) + + x4_dec = self.bottleneck(x4_ds) + x3_dec = self.decoder_list[3](x4_dec, x4_skip) + x2_dec = self.decoder_list[2](x3_dec, x3_skip) + x1_dec = self.decoder_list[1](x2_dec, x2_skip) + x0_dec = self.decoder_list[0](x1_dec, x1_skip) + demuxed = self.target_modality_conv(x0_dec) + # no relu + if self.need_split: + return torch.chunk(demuxed, self.ch_out, dim=1) + return demuxed, None + + +if __name__ == "__main__": + model = ResUNet() + inp = torch.rand(2, 1, 2048) + out = model(inp) + print(model) + print(model.count_parameters()) + print(out[0].shape) diff --git a/src/gyraudio/audio_separation/architecture/wave_unet.py b/src/gyraudio/audio_separation/architecture/wave_unet.py new file mode 100644 index 0000000000000000000000000000000000000000..b330c35ad80bb373e2a386355cd77aeb5153185c --- /dev/null +++ 
b/src/gyraudio/audio_separation/architecture/wave_unet.py @@ -0,0 +1,163 @@ +import torch +from gyraudio.audio_separation.architecture.model import SeparationModel +from typing import Optional, Tuple + + +def get_non_linearity(activation: str): + if activation == "LeakyReLU": + non_linearity = torch.nn.LeakyReLU() + else: + non_linearity = torch.nn.ReLU() + return non_linearity + + +class BaseConvolutionBlock(torch.nn.Module): + def __init__(self, ch_in, ch_out: int, k_size: int, activation="LeakyReLU", dropout: float = 0, bias: bool = True) -> None: + super().__init__() + self.conv = torch.nn.Conv1d(ch_in, ch_out, k_size, padding=k_size//2, bias=bias) + self.non_linearity = get_non_linearity(activation) + self.dropout = torch.nn.Dropout1d(p=dropout) + + def forward(self, x_in: torch.Tensor) -> torch.Tensor: + x = self.conv(x_in) # [N, ch_in, T] -> [N, ch_in+channels_extension, T] + x = self.non_linearity(x) + x = self.dropout(x) + return x + + +class EncoderStage(torch.nn.Module): + """Conv (and extend channels), downsample 2 by skipping samples + """ + + def __init__(self, ch_in: int, ch_out: int, k_size: int = 15, dropout: float = 0, bias: bool = True) -> None: + + super().__init__() + + self.conv_block = BaseConvolutionBlock(ch_in, ch_out, k_size=k_size, dropout=dropout, bias=bias) + + def forward(self, x): + x = self.conv_block(x) + + x_ds = x[..., ::2] + # ch_out = ch_in+channels_extension + return x, x_ds + + +class DecoderStage(torch.nn.Module): + """Upsample by 2, Concatenate with skip connection, Conv (and shrink channels) + """ + + def __init__(self, ch_in: int, ch_out: int, k_size: int = 5, dropout: float = 0., bias: bool = True) -> None: + """Decoder stage + """ + + super().__init__() + self.conv_block = BaseConvolutionBlock(ch_in, ch_out, k_size=k_size, dropout=dropout, bias=bias) + self.upsample = torch.nn.Upsample(scale_factor=2, mode="linear", align_corners=True) + + def forward(self, x_ds: torch.Tensor, x_skip: torch.Tensor) -> torch.Tensor: + """""" + x_us = self.upsample(x_ds) # [N, ch, T/2] -> [N, ch, T] + x = torch.cat([x_us, x_skip], dim=1) # [N, 2.ch, T] + x = self.conv_block(x) # [N, ch_out, T] + return x + + +class WaveUNet(SeparationModel): + """UNET in temporal domain (waveform) + = Multiscale convolutional neural network for audio separation + https://arxiv.org/abs/1806.03185 + """ + + def __init__(self, + ch_in: int = 1, + ch_out: int = 2, + channels_extension: int = 24, + k_conv_ds: int = 15, + k_conv_us: int = 5, + num_layers: int = 6, + dropout: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + self.need_split = ch_out != ch_in + self.ch_out = ch_out + self.encoder_list = torch.nn.ModuleList() + self.decoder_list = torch.nn.ModuleList() + # Defining first encoder + self.encoder_list.append(EncoderStage(ch_in, channels_extension, k_size=k_conv_ds, dropout=dropout, bias=bias)) + for level in range(1, num_layers+1): + ch_i = level*channels_extension + ch_o = (level+1)*channels_extension + if level < num_layers: + # Skipping last encoder since we defined the first one outside the loop + self.encoder_list.append(EncoderStage(ch_i, ch_o, k_size=k_conv_ds, dropout=dropout, bias=bias)) + self.decoder_list.append(DecoderStage(ch_o+ch_i, ch_i, k_size=k_conv_us, dropout=dropout, bias=bias)) + self.bottleneck = BaseConvolutionBlock( + num_layers*channels_extension, + (num_layers+1)*channels_extension, + k_size=k_conv_ds, + dropout=dropout, + bias=bias) + self.dropout = torch.nn.Dropout1d(p=dropout) + self.target_modality_conv = torch.nn.Conv1d( + 
channels_extension+ch_in, ch_out, 1, bias=bias) # conv1x1 channel mixer + + def forward(self, x_in: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """Forward UNET pass + + ``` + (1 , 2048)----------------->(24 , 2048) > (1 , 2048) + v ^ + (24 , 1024)----------------->(48 , 1024) + v ^ + (48 , 512 )----------------->(72 , 512 ) + v ^ + (72 , 256 )----------------->(96 , 256 ) + v ^ + (96 , 128 )----BOTTLENECK--->(120, 128 ) + ``` + + """ + skipped_list = [] + ds_list = [x_in] + for level, enc in enumerate(self.encoder_list): + x_skip, x_ds = enc(ds_list[-1]) + skipped_list.append(x_skip) + ds_list.append(x_ds.clone()) + # print(x_skip.shape, x_ds.shape) + x_dec = self.bottleneck(ds_list[-1]) + for level, dec in enumerate(self.decoder_list[::-1]): + x_dec = dec(x_dec, skipped_list[-1-level]) + # print(x_dec.shape) + x_dec = torch.cat([x_dec, x_in], dim=1) + # print(x_dec.shape) + x_dec = self.dropout(x_dec) + demuxed = self.target_modality_conv(x_dec) + # print(demuxed.shape) + if self.need_split: + return torch.chunk(demuxed, self.ch_out, dim=1) + return demuxed, None + + # x_skip, x_ds + # (24, 2048), (24, 1024) + # (48, 1024), (48, 512 ) + # (72, 512 ), (72, 256 ) + # (96, 256 ), (96, 128 ) + + # (120, 128 ) + # (96 , 256 ) + # (72 , 512 ) + # (48 , 1024) + # (24 , 2048) + # (25 , 2048) demuxed - after concat + # (1 , 2048) + + +if __name__ == "__main__": + model = WaveUNet(ch_out=1, num_layers=9) + inp = torch.rand(2, 1, 2048) + out = model(inp) + print(model) + print(model.count_parameters()) + print(out[0].shape) diff --git a/src/gyraudio/audio_separation/data/__init__.py b/src/gyraudio/audio_separation/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1d20b74f53e8b855d1d4cebf2b8a05d3c0496938 --- /dev/null +++ b/src/gyraudio/audio_separation/data/__init__.py @@ -0,0 +1,5 @@ +from gyraudio.audio_separation.data.mixed import MixedAudioDataset +from gyraudio.audio_separation.data.remixed_fixed import RemixedFixedAudioDataset +from gyraudio.audio_separation.data.remixed_rnd import RemixedRandomAudioDataset +from gyraudio.audio_separation.data.single import SingleAudioDataset +from gyraudio.audio_separation.data.dataloader import get_dataloader, get_config_dataloader diff --git a/src/gyraudio/audio_separation/data/dataloader.py b/src/gyraudio/audio_separation/data/dataloader.py new file mode 100644 index 0000000000000000000000000000000000000000..861c33a55ad41675fa3ba7db0cfd9823672e1a18 --- /dev/null +++ b/src/gyraudio/audio_separation/data/dataloader.py @@ -0,0 +1,47 @@ +from torch.utils.data import DataLoader +from gyraudio.audio_separation.data.mixed import MixedAudioDataset +from typing import Optional, List +from gyraudio.audio_separation.properties import ( + DATA_PATH, AUGMENTATION, SNR_FILTER, SHUFFLE, BATCH_SIZE, TRAIN, VALID, TEST, AUG_TRIM +) +from gyraudio import root_dir +RAW_AUDIO_ROOT = root_dir/"__data_source_separation"/"voice_origin" +MIXED_AUDIO_ROOT = root_dir/"__data_source_separation"/"source_separation" + + +def get_dataloader(configurations: dict, audio_dataset=MixedAudioDataset): + dataloaders = {} + for mode, configuration in configurations.items(): + dataset = audio_dataset( + configuration[DATA_PATH], + augmentation_config=configuration[AUGMENTATION], + snr_filter=configuration[SNR_FILTER] + ) + dl = DataLoader( + dataset, + shuffle=configuration[SHUFFLE], + batch_size=configuration[BATCH_SIZE], + collate_fn=dataset.collate_fn + ) + dataloaders[mode] = dl + return dataloaders + + +def get_config_dataloader( + 
audio_root=MIXED_AUDIO_ROOT, + mode: str = TRAIN, + shuffle: Optional[bool] = None, + batch_size: Optional[int] = 16, + snr_filter: Optional[List[float]] = None, + augmentation: dict = {}): + audio_folder = audio_root/mode + assert mode in [TRAIN, VALID, TEST] + assert audio_folder.exists() + config = { + DATA_PATH: audio_folder, + SHUFFLE: shuffle if shuffle is not None else (True if mode == TRAIN else False), + AUGMENTATION: augmentation, + SNR_FILTER: snr_filter, + BATCH_SIZE: batch_size + } + return config diff --git a/src/gyraudio/audio_separation/data/dataset.py b/src/gyraudio/audio_separation/data/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..904edbe33d2e712aad655426f98d5739f45a6407 --- /dev/null +++ b/src/gyraudio/audio_separation/data/dataset.py @@ -0,0 +1,104 @@ +from torch.utils.data import Dataset +from pathlib import Path +from typing import Optional +import torch +from torch.utils.data import default_collate +from typing import Tuple +from functools import partial +from gyraudio.audio_separation.properties import ( + AUG_AWGN, AUG_RESCALE, AUG_TRIM, LENGTHS, LENGTH_DIVIDER, TRIM_PROB +) + + +class AudioDataset(Dataset): + def __init__( + self, + data_path: Path, + augmentation_config: dict = {}, + snr_filter: Optional[float] = None, + debug: bool = False + ): + self.debug = debug + self.data_path = data_path + self.augmentation_config = augmentation_config + self.snr_filter = snr_filter + self.load_data() + self.length = len(self.file_list) + self.collate_fn = None + if AUG_TRIM in self.augmentation_config: + self.collate_fn = partial(collate_fn_generic, + lengths_lim=self.augmentation_config[AUG_TRIM][LENGTHS], + length_divider=self.augmentation_config[AUG_TRIM][LENGTH_DIVIDER], + trim_prob=self.augmentation_config[AUG_TRIM][TRIM_PROB]) + + def filter_data(self, snr): + if self.snr_filter is None: + return True + if snr in self.snr_filter: + return True + else: + return False + + def load_data(self): + raise NotImplementedError("load_data method must be implemented") + + def augment_data(self, mixed_audio_signal, clean_audio_signal, noise_audio_signal): + if AUG_RESCALE in self.augmentation_config: + current_amplitude = 0.5 + 1.5*torch.rand(1, device=mixed_audio_signal.device) + # logging.debug(current_amplitude) + mixed_audio_signal *= current_amplitude + noise_audio_signal *= current_amplitude + clean_audio_signal *= current_amplitude + if AUG_AWGN in self.augmentation_config: + # noise_std = self.augmentation_config[AUG_AWGN]["noise_std"] + noise_std = 0.01 + current_noise_std = torch.randn(1) * noise_std + # logging.debug(current_noise_std) + extra_awgn = torch.randn(mixed_audio_signal.shape, device=mixed_audio_signal.device) * current_noise_std + mixed_audio_signal = mixed_audio_signal+extra_awgn + # Open question: should we add noise to the noise signal aswell? + + return mixed_audio_signal, clean_audio_signal, noise_audio_signal + + def __len__(self): + return self.length + + def __getitem__(self, idx: int) -> torch.Tensor: + raise NotImplementedError("__getitem__ method must be implemented") + + +def collate_fn_generic(batch, lengths_lim, length_divider=1024, trim_prob=0.5) -> Tuple[torch.Tensor, torch.Tensor]: + """Collate function to allow trimming (=crop the time dimension) of the signals in a batch. 
+ + Args: + batch (list): A list of tuples (triplets), where each tuple contain: + - mixed_audio_signal + - clean_audio_signal + - noise_audio_signal + lengths_lim (list) : A list of containing a minimum length (0) and a maximum length (1) + length_divider (int) : has to be a trimmed length divider + trim_prob (float) : trimming probability + + Returns: + - Tensor: A batch of mixed_audio_signal, trimmed to the same length. + - Tensor: A batch of clean_audio_signal + - Tensor: A batch of noise_audio_signal + """ + + # Find the length of the shortest signal in the batch + mixed_audio_signal, clean_audio_signal, noise_audio_signal = default_collate(batch) + length = mixed_audio_signal[0].shape[-1] + min_length, max_length = lengths_lim + take_full_signal = torch.rand(1) > trim_prob + if not take_full_signal: + start = torch.randint(0, length-min_length, (1,)) + trim_length = torch.randint(min_length, min(max_length, length-start-1)+1, (1,)) + trim_length = trim_length-trim_length % length_divider + end = start + trim_length + else: + start = 0 + end = length - length % length_divider + mixed_audio_signal = mixed_audio_signal[..., start:end] + clean_audio_signal = clean_audio_signal[..., start:end] + noise_audio_signal = noise_audio_signal[..., start:end] + return mixed_audio_signal, clean_audio_signal, noise_audio_signal diff --git a/src/gyraudio/audio_separation/data/mixed.py b/src/gyraudio/audio_separation/data/mixed.py new file mode 100644 index 0000000000000000000000000000000000000000..84e8b2c391da533b25d60e31a2330626d1505040 --- /dev/null +++ b/src/gyraudio/audio_separation/data/mixed.py @@ -0,0 +1,40 @@ +from gyraudio.audio_separation.data.dataset import AudioDataset +import logging +import torch +import torchaudio +from typing import Tuple + + +class MixedAudioDataset(AudioDataset): + def load_data(self): + self.folder_list = sorted(list(self.data_path.iterdir())) + self.file_list = [ + [ + list(folder.glob("mix*.wav"))[0], + folder/"voice.wav", + folder/"noise.wav" + ] for folder in self.folder_list + ] + snr_list = [float(file[0].stem.split("_")[-1]) for file in self.file_list] + self.file_list = [files for snr, files in zip(snr_list, self.file_list) if self.filter_data(snr)] + if self.debug: + logging.info(f"Available SNR {set(snr_list)}") + print(f"Available SNR {set(snr_list)}") + print("Filtered", len(self.file_list), self.snr_filter) + self.sampling_rate = None + + def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + mixed_audio_path, signal_path, noise_path = self.file_list[idx] + assert mixed_audio_path.exists() + assert signal_path.exists() + assert noise_path.exists() + mixed_audio_signal, sampling_rate = torchaudio.load(str(mixed_audio_path)) + clean_audio_signal, sampling_rate = torchaudio.load(str(signal_path)) + noise_audio_signal, sampling_rate = torchaudio.load(str(noise_path)) + self.sampling_rate = sampling_rate + mixed_audio_signal, clean_audio_signal, noise_audio_signal = self.augment_data(mixed_audio_signal, clean_audio_signal, noise_audio_signal) + if self.debug: + logging.debug(f"{mixed_audio_signal.shape}") + logging.debug(f"{clean_audio_signal.shape}") + logging.debug(f"{noise_audio_signal.shape}") + return mixed_audio_signal, clean_audio_signal, noise_audio_signal diff --git a/src/gyraudio/audio_separation/data/remixed.py b/src/gyraudio/audio_separation/data/remixed.py new file mode 100644 index 0000000000000000000000000000000000000000..8ba066e8a04c1ad0d1c7c1827c2b48adb4be585e --- /dev/null +++ 
b/src/gyraudio/audio_separation/data/remixed.py @@ -0,0 +1,53 @@ +from gyraudio.audio_separation.data.dataset import AudioDataset +from typing import Tuple +import logging +from torch import Tensor +import torch +import torchaudio + + +class RemixedAudioDataset(AudioDataset): + def generate_snr_list(self): + self.snr_list = None + + def load_data(self): + self.folder_list = sorted(list(self.data_path.iterdir())) + self.file_list = [ + [ + folder/"voice.wav", + folder/"noise.wav" + ] for folder in self.folder_list + ] + self.sampling_rate = None + self.min_snr, self.max_snr = -4, 4 + self.generate_snr_list() + if self.debug: + print("Not filtered", len(self.file_list), self.snr_filter) + print(self.snr_list) + + def get_idx_noise(self, idx): + raise NotImplementedError("get_idx_noise method must be implemented") + + def get_snr(self, idx): + raise NotImplementedError("get_snr method must be implemented") + + def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor, Tensor]: + signal_path = self.file_list[idx][0] + idx_noise = self.get_idx_noise(idx) + noise_path = self.file_list[idx_noise][1] + + assert signal_path.exists() + assert noise_path.exists() + clean_audio_signal, sampling_rate = torchaudio.load(str(signal_path)) + noise_audio_signal, sampling_rate = torchaudio.load(str(noise_path)) + snr = self.get_snr(idx) + alpha = 10 ** (-snr / 20) * torch.norm(clean_audio_signal) / torch.norm(noise_audio_signal) + mixed_audio_signal = clean_audio_signal + alpha*noise_audio_signal + self.sampling_rate = sampling_rate + mixed_audio_signal, clean_audio_signal, noise_audio_signal = self.augment_data( + mixed_audio_signal, clean_audio_signal, noise_audio_signal) + if self.debug: + logging.debug(f"{mixed_audio_signal.shape}") + logging.debug(f"{clean_audio_signal.shape}") + logging.debug(f"{noise_audio_signal.shape}") + return mixed_audio_signal, clean_audio_signal, noise_audio_signal diff --git a/src/gyraudio/audio_separation/data/remixed_fixed.py b/src/gyraudio/audio_separation/data/remixed_fixed.py new file mode 100644 index 0000000000000000000000000000000000000000..78a7d5232053e22cfb8806d34088957029ea0fa2 --- /dev/null +++ b/src/gyraudio/audio_separation/data/remixed_fixed.py @@ -0,0 +1,18 @@ +from gyraudio.audio_separation.data.remixed import RemixedAudioDataset +import torch + +class RemixedFixedAudioDataset(RemixedAudioDataset): + def generate_snr_list(self) : + rnd_gen = torch.Generator() + rnd_gen.manual_seed(2147483647) + if self.snr_filter is None : + self.snr_list = self.min_snr + (self.max_snr - self.min_snr)*torch.rand(len(self.file_list), generator = rnd_gen) + else : + indices = torch.randint(0, len(self.snr_filter), (len(self.file_list),), generator=rnd_gen) + self.snr_list = [self.snr_filter[idx] for idx in indices] + + def get_idx_noise(self, idx) : + return idx + + def get_snr(self, idx) : + return self.snr_list[idx] \ No newline at end of file diff --git a/src/gyraudio/audio_separation/data/remixed_rnd.py b/src/gyraudio/audio_separation/data/remixed_rnd.py new file mode 100644 index 0000000000000000000000000000000000000000..3a9cb822e62a2d7279f559a44af61eeb94625da8 --- /dev/null +++ b/src/gyraudio/audio_separation/data/remixed_rnd.py @@ -0,0 +1,12 @@ +from gyraudio.audio_separation.data.remixed import RemixedAudioDataset +from torch import rand, randint + + +class RemixedRandomAudioDataset(RemixedAudioDataset): + def get_idx_noise(self, idx): + return randint(0, len(self.file_list)-1, (1,)) + + def get_snr(self, idx): + if self.snr_filter is None: + return self.min_snr + 
(self.max_snr - self.min_snr)*rand(1) + return self.snr_filter[randint(0, len(self.snr_filter), (1,))] diff --git a/src/gyraudio/audio_separation/data/silence_detector.py b/src/gyraudio/audio_separation/data/silence_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..e69d35a952e9ca4ba7718eaa987e725675078815 --- /dev/null +++ b/src/gyraudio/audio_separation/data/silence_detector.py @@ -0,0 +1,55 @@ +import torch +import matplotlib.pyplot as plt +import numpy as np + + +def get_silence_mask( + sig: torch.Tensor, morph_kernel_size: int = 499, k_smooth=21, thresh=0.0001, + debug: bool = False) -> torch.Tensor: + with torch.no_grad(): + smooth = torch.nn.Conv1d(1, 1, k_smooth, padding=k_smooth//2, bias=False).to(sig.device) + smooth.weight.data.fill_(1./k_smooth) + smoothed = smooth(torch.abs(sig)) + st = 1.*(torch.abs(smoothed) < thresh*torch.ones_like(smoothed, device=sig.device)) + sig_dil = torch.nn.MaxPool1d(morph_kernel_size, stride=1, padding=morph_kernel_size//2)(st) + sig_ero = -torch.nn.MaxPool1d(morph_kernel_size, stride=1, padding=morph_kernel_size//2)(-sig_dil) + if debug: + return sig_ero.squeeze(0), smoothed.squeeze(0), st.squeeze(0) + else: + return sig_ero + + +def visualize_silence_mask(sig: torch.Tensor, silence_thresh: float = 0.0001): + silence_thresh = 0.0001 + silence_mask, smoothed_amplitude, _ = get_silence_mask( + sig, k_smooth=21, morph_kernel_size=499, thresh=silence_thresh, debug=True + ) + plt.figure(figsize=(12, 4)) + plt.subplot(121) + plt.plot(sig.squeeze(0).cpu().numpy(), "k-", label="voice", alpha=0.5) + plt.plot(0.01*silence_mask.cpu().numpy(), "r-", alpha=1., label="silence mask") + plt.grid() + plt.legend() + plt.title("Voice and silence mask") + plt.ylim(-0.04, 0.04) + + plt.subplot(122) + plt.plot(smoothed_amplitude.cpu().numpy(), "g--", alpha=0.5, label="smoothed amplitude") + plt.plot(np.ones(silence_mask.shape[-1])*silence_thresh, "c--", alpha=1., label="threshold") + plt.plot(-silence_thresh+silence_thresh*silence_mask.cpu().numpy(), "r-", alpha=1, label="silence mask") + plt.grid() + plt.legend() + plt.title("Thresholding mechanism") + plt.ylim(-silence_thresh, silence_thresh*10) + plt.show() + + +if __name__ == "__main__": + from gyraudio.default_locations import SAMPLE_ROOT + from gyraudio.audio_separation.visualization.pre_load_audio import audio_loading + from gyraudio.audio_separation.properties import CLEAN, BUFFERS + sample_folder = SAMPLE_ROOT/"0009" + signals = audio_loading(sample_folder, preload=True) + device = "cuda" if torch.cuda.is_available() else "cpu" + sig_in = signals[BUFFERS][CLEAN].to(device) + visualize_silence_mask(sig_in) diff --git a/src/gyraudio/audio_separation/data/single.py b/src/gyraudio/audio_separation/data/single.py new file mode 100644 index 0000000000000000000000000000000000000000..307c289c40ae4994996e029370f87ea5fbdcf511 --- /dev/null +++ b/src/gyraudio/audio_separation/data/single.py @@ -0,0 +1,15 @@ +from gyraudio.audio_separation.data.dataset import AudioDataset +import logging +import torchaudio + + +class SingleAudioDataset(AudioDataset): + def load_data(self): + self.file_list = sorted(list(self.data_path.glob("*.wav"))) + + def __getitem__(self, idx: int): + audio_path = self.file_list[idx] + assert audio_path.exists() + audio_signal, sampling_rate = torchaudio.load(str(audio_path)) + logging.debug(f"{audio_signal.shape}") + return audio_signal diff --git a/src/gyraudio/audio_separation/experiment_tracking/experiments.py 
b/src/gyraudio/audio_separation/experiment_tracking/experiments.py new file mode 100644 index 0000000000000000000000000000000000000000..b11220c8dd7109aadc65daf7d2de3f4f6524d7c2 --- /dev/null +++ b/src/gyraudio/audio_separation/experiment_tracking/experiments.py @@ -0,0 +1,122 @@ +from gyraudio.default_locations import MIXED_AUDIO_ROOT +from gyraudio.audio_separation.properties import ( + TRAIN, TEST, VALID, NAME, EPOCHS, LEARNING_RATE, + OPTIMIZER, BATCH_SIZE, DATALOADER, AUGMENTATION, + SHORT_NAME, AUG_TRIM, TRIM_PROB, LENGTH_DIVIDER, LENGTHS, SNR_FILTER +) +from gyraudio.audio_separation.data.remixed_fixed import RemixedFixedAudioDataset +from gyraudio.audio_separation.data.remixed_rnd import RemixedRandomAudioDataset +from gyraudio.audio_separation.data import get_dataloader, get_config_dataloader +from gyraudio.audio_separation.experiment_tracking.experiments_definition import get_experiment_generator +import torch +from typing import Tuple + + +def get_experience(exp_major: int, exp_minor: int = 0, snr_filter_test=None, dry_run=False) -> Tuple[str, torch.nn.Module, dict, dict]: + """Get all experience details + + Args: + exp_major (int): Major experience number + exp_minor (int, optional): Used for HP search. Defaults to 0. + + + Returns: + Tuple[str, torch.nn.Module, dict, dict]: short_name, model, config, dataloaders + """ + model = None + config = {} + dataloader_name = "remix" + config = { + NAME: None, + OPTIMIZER: { + NAME: "adam", + LEARNING_RATE: 0.001 + }, + EPOCHS: 60, + DATALOADER: { + NAME: dataloader_name, + }, + BATCH_SIZE: [16, 16, 16], + SNR_FILTER : snr_filter_test + } + + model, config = get_experiment_generator(exp_major=exp_major)(config, no_model=dry_run, minor=exp_minor) + # POST PROCESSING + if isinstance(config[BATCH_SIZE], list) or isinstance(config[BATCH_SIZE], tuple): + config[BATCH_SIZE] = { + TRAIN: config[BATCH_SIZE][0], + TEST: config[BATCH_SIZE][1], + VALID: config[BATCH_SIZE][2], + } + + if config[DATALOADER][NAME] == "premix": + mixed_audio_root = MIXED_AUDIO_ROOT + dataloaders = get_dataloader({ + TRAIN: get_config_dataloader( + audio_root=mixed_audio_root, + mode=TRAIN, + shuffle=True, + batch_size=config[BATCH_SIZE][TRAIN], + augmentation=config[DATALOADER].get(AUGMENTATION, {}) + ), + TEST: get_config_dataloader( + audio_root=mixed_audio_root, + mode=TEST, + shuffle=False, + batch_size=config[BATCH_SIZE][TEST], + snr_filter=config[SNR_FILTER] + ) + }) + elif config[DATALOADER][NAME] == "remix": + mixed_audio_root = MIXED_AUDIO_ROOT + aug_test = {} + if AUG_TRIM in config[DATALOADER].get(AUGMENTATION, {}): + aug_test = { + AUG_TRIM: {LENGTHS: [None, None], LENGTH_DIVIDER: config[DATALOADER][AUGMENTATION] + [AUG_TRIM][LENGTH_DIVIDER], TRIM_PROB: -1.} + } + dl_train = get_dataloader( + { + TRAIN: get_config_dataloader( + audio_root=mixed_audio_root, + mode=TRAIN, + shuffle=True, + batch_size=config[BATCH_SIZE][TRAIN], + augmentation=config[DATALOADER].get(AUGMENTATION, {}) + ) + }, + audio_dataset=RemixedRandomAudioDataset + )[TRAIN] + dl_test = get_dataloader( + { + TEST: get_config_dataloader( + audio_root=mixed_audio_root, + mode=TEST, + shuffle=False, + batch_size=config[BATCH_SIZE][TEST] + ) + }, + audio_dataset=RemixedFixedAudioDataset + )[TEST] + dataloaders = { + TRAIN: dl_train, + TEST: dl_test + } + else: + raise NotImplementedError(f"Unknown dataloader {dataloader_name}") + assert config[NAME] is not None + + short_name = f"{exp_major:04d}_{exp_minor:04d}" + config[SHORT_NAME] = short_name + return short_name, model, config, dataloaders + 
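Editor's note (illustrative sketch, not part of the committed file): a minimal example of how `get_experience` and the dataloaders it returns are typically consumed. `get_experience`, `TRAIN`, the (mixed, clean, noise) batch triplets and the two-tensor model output come from the sources in this diff; the experiment id 1004, the `device` variable and the single-batch loop are assumptions added purely for illustration.

import torch
from gyraudio.audio_separation.experiment_tracking.experiments import get_experience
from gyraudio.audio_separation.properties import TRAIN

# Build a registered experiment (1004 is one of the Wave-UNet experiments defined below).
short_name, model, config, dataloaders = get_experience(1004)
device = "cuda" if torch.cuda.is_available() else "cpu"  # assumed helper, not defined in this module
model.to(device)
# The "remix" dataloaders yield (mixed_audio_signal, clean_audio_signal, noise_audio_signal) batches.
for mixed, clean, noise in dataloaders[TRAIN]:
    estimated_signal, estimated_noise = model(mixed.to(device))  # separation models return a pair
    break  # single batch shown for illustration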
+ +if __name__ == "__main__": + from gyraudio.audio_separation.parser import shared_parser + parser_def = shared_parser() + args = parser_def.parse_args() + + for exp in args.experiments: + short_name, model, config, dl = get_experience(exp) + print(short_name) + print(config) diff --git a/src/gyraudio/audio_separation/experiment_tracking/experiments_decorator.py b/src/gyraudio/audio_separation/experiment_tracking/experiments_decorator.py new file mode 100644 index 0000000000000000000000000000000000000000..82119f937b303f3795624e01b29ecf7dcee610b6 --- /dev/null +++ b/src/gyraudio/audio_separation/experiment_tracking/experiments_decorator.py @@ -0,0 +1,48 @@ +import torch +from gyraudio.audio_separation.properties import ( + NAME, ANNOTATIONS, NB_PARAMS, RECEPTIVE_FIELD +) +from typing import Optional +REGISTERED_EXPERIMENTS_LIST = {} + + +# def count_parameters(model: torch.nn.Module) -> int: +# """Count number of trainable parameters + +# Args: +# model (torch.nn.Module): Pytorch model + +# Returns: +# int: Number of trainable elements +# """ +# return sum(p.numel() for p in model.parameters() if p.requires_grad) + + +def registered_experiment(major: Optional[int] = None, failed: Optional[bool] = False) -> callable: + """Decorate and register an experiment + - Register the experiment in the list of experiments + - Count the number of parameters and add it to the config + + Args: + major (Optional[int], optional): major id version = Number of the experiment. Defaults to None. + failed (Optional[bool], optional): If an experiment failed, + keep track of it but prevent from evaluating. Defaults to False. + + Returns: + callable: decorator function + """ + def decorator(func): + assert (major) not in REGISTERED_EXPERIMENTS_LIST, f"Experiment {major} already registered" + + def wrapper(config, minor=None, no_model=False, model=torch.nn.Module()): + config, model = func(config, model=None if not no_model else model, minor=minor) + config[NB_PARAMS] = model.count_parameters() + config[RECEPTIVE_FIELD] = model.receptive_field() + assert NAME in config, "NAME not defined" + assert ANNOTATIONS in config, "ANNOTATIONS not defined" + return model, config + if not failed: + REGISTERED_EXPERIMENTS_LIST[major] = wrapper + return wrapper + + return decorator diff --git a/src/gyraudio/audio_separation/experiment_tracking/experiments_definition.py b/src/gyraudio/audio_separation/experiment_tracking/experiments_definition.py new file mode 100644 index 0000000000000000000000000000000000000000..19afb4dc7e36aa9795a98121eff431a8b9caa54b --- /dev/null +++ b/src/gyraudio/audio_separation/experiment_tracking/experiments_definition.py @@ -0,0 +1,320 @@ +from gyraudio.audio_separation.architecture.flat_conv import FlatConvolutional +from gyraudio.audio_separation.architecture.unet import ResUNet +from gyraudio.audio_separation.architecture.wave_unet import WaveUNet +from gyraudio.audio_separation.architecture.neutral import NeutralModel +from gyraudio.audio_separation.architecture.transformer import TransformerModel +from gyraudio.audio_separation.properties import ( + NAME, ANNOTATIONS, MAX_STEPS_PER_EPOCH, EPOCHS, BATCH_SIZE, + OPTIMIZER, LEARNING_RATE, + DATALOADER, + WEIGHT_DECAY, + LOSS, LOSS_L1, + AUGMENTATION, AUG_TRIM, AUG_AWGN, AUG_RESCALE, + LENGTHS, LENGTH_DIVIDER, TRIM_PROB, + SCHEDULER, SCHEDULER_CONFIGURATION +) +from gyraudio.audio_separation.experiment_tracking.experiments_decorator import ( + registered_experiment, REGISTERED_EXPERIMENTS_LIST +) + + +@registered_experiment(major=9999) +def 
neutral(config, model: bool = None, minor=None): + config[BATCH_SIZE] = [4, 4, 4] + config[EPOCHS] = 1 + config[NAME] = "Neutral" + config[ANNOTATIONS] = "Neutral" + if model is None: + model = NeutralModel() + config[NAME] = "Neutral" + return config, model + + +@registered_experiment(major=0) +def exp_unit_test(config, model: bool = None, minor=None): + config[MAX_STEPS_PER_EPOCH] = 2 + config[BATCH_SIZE] = [4, 4, 4] + config[EPOCHS] = 2 + config[NAME] = "Unit Test - Flat Convolutional" + config[ANNOTATIONS] = "Baseline" + config[SCHEDULER] = "ReduceLROnPlateau" + config[SCHEDULER_CONFIGURATION] = dict(patience=5, factor=0.8) + if model is None: + model = FlatConvolutional() + return config, model + +# ---------------- Low Baseline ----------------- + + +def exp_low_baseline( + config: dict, + batch_size: int = 16, + h_dim: int = 16, + k_size: int = 9, + dilation: int = 0, + model: bool = None, + minor=None +): + config[BATCH_SIZE] = [batch_size, batch_size, batch_size] + config[NAME] = "Flat Convolutional" + config[ANNOTATIONS] = f"Baseline H={h_dim}_K={k_size}" + if dilation > 1: + config[ANNOTATIONS] += f"_dil={dilation}" + config["Architecture"] = { + "name": "Flat-Conv", + "h_dim": h_dim, + "scales": 1, + "k_size": k_size, + "dilation": dilation + } + if model is None: + model = FlatConvolutional(k_size=k_size, h_dim=h_dim) + return config, model + + +@registered_experiment(major=1) +def exp_1(config, model: bool = None, minor=None): + config, model = exp_low_baseline(config, batch_size=32, k_size=5) + return config, model + + +@registered_experiment(major=2) +def exp_2(config, model: bool = None, minor=None): + config, model = exp_low_baseline(config, batch_size=32, k_size=9) + return config, model + + +@registered_experiment(major=3) +def exp_3(config, model: bool = None, minor=None): + config, model = exp_low_baseline(config, batch_size=32, k_size=9, dilation=2) + return config, model + + +@registered_experiment(major=4) +def exp_4(config, model: bool = None, minor=None): + config, model = exp_low_baseline(config, batch_size=16, k_size=9) + return config, model + +# ------------------ Res U-Net ------------------ + + +def exp_resunet(config, h_dim=16, k_size=5, model=None): + config[NAME] = "Res-UNet" + scales = 4 + config[ANNOTATIONS] = f"Res-UNet-{scales}scales_h={h_dim}_k={k_size}" + config["Architecture"] = { + "name": "Res-UNet", + "h_dim": h_dim, + "scales": scales, + "k_size": k_size, + } + if model is None: + model = ResUNet(h_dim=h_dim, k_size=k_size) + return config, model + + +@registered_experiment(major=2000) +def exp_2000_waveunet(config, model: bool = None, minor=None): + config[EPOCHS] = 60 + config, model = exp_resunet(config) + return config, model + + +@registered_experiment(major=2001) +def exp_2001_waveunet(config, model: bool = None, minor=None): + config[EPOCHS] = 60 + config, model = exp_resunet(config, h_dim=32, k_size=5) + return config, model + +# ------------------ Wave U-Net ------------------ + + +def exp_wave_unet(config: dict, + channels_extension: int = 24, + k_conv_ds: int = 15, + k_conv_us: int = 5, + num_layers: int = 4, + dropout: float = 0.0, + bias: bool = True, + model=None): + config[NAME] = "Wave-UNet" + config[ANNOTATIONS] = f"Wave-UNet-{num_layers}scales_h_ext={channels_extension}_k={k_conv_ds}ds-{k_conv_us}us" + if dropout > 0: + config[ANNOTATIONS] += f"-dr{dropout:.1e}" + if not bias: + config[ANNOTATIONS] += "-BiasFree" + config["Architecture"] = { + "k_conv_us": k_conv_us, + "k_conv_ds": k_conv_ds, + "num_layers": num_layers, + 
"channels_extension": channels_extension, + "dropout": dropout, + "bias": bias + } + if model is None: + model = WaveUNet( + **config["Architecture"] + ) + config["Architecture"][NAME] = "Wave-UNet" + return config, model + + +@registered_experiment(major=1000) +def exp_1000_waveunet(config, model: bool = None, minor=None): + config[EPOCHS] = 60 + config, model = exp_wave_unet(config, model=model, num_layers=4, channels_extension=24) + # 4 layers, ext +24 - Nvidia T500 4Gb RAM - 16 batch size + return config, model + + +@registered_experiment(major=1001) +def exp_1001_waveunet(config, model: bool = None, minor=None): + # OVERFIT 1M param ? + config[EPOCHS] = 60 + config, model = exp_wave_unet(config, model=model, num_layers=7, channels_extension=16) + # 7 layers, ext +16 - Nvidia T500 4Gb RAM - 16 batch size + return config, model + + +@registered_experiment(major=1002) +def exp_1002_waveunet(config, model: bool = None, minor=None): + # OVERFIT 1M param ? + config[EPOCHS] = 60 + config, model = exp_wave_unet(config, model=model, num_layers=7, channels_extension=16) + config[DATALOADER][AUGMENTATION] = { + AUG_TRIM: {LENGTHS: [8192, 80000], LENGTH_DIVIDER: 1024, TRIM_PROB: 0.8}, + AUG_RESCALE: True + } + # 7 layers, ext +16 - Nvidia T500 4Gb RAM - 16 batch size + return config, model + + +@registered_experiment(major=1003) +def exp_1003_waveunet(config, model: bool = None, minor=None): + # OVERFIT 2.3M params + config[EPOCHS] = 60 + config, model = exp_wave_unet(config, model=model, num_layers=7, channels_extension=24) + # 7 layers, ext +24 - Nvidia RTX3060 6Gb RAM - 16 batch size + return config, model + + +@registered_experiment(major=1004) +def exp_1004_waveunet(config, model: bool = None, minor=None): + config[EPOCHS] = 120 + config, model = exp_wave_unet(config, model=model, num_layers=7, channels_extension=28) + # 7 layers, ext +28 - Nvidia RTX3060 6Gb RAM - 16 batch size + return config, model + + +@registered_experiment(major=1014) +def exp_1014_waveunet(config, model: bool = None, minor=None): + # trained with min and max mixing snr hard coded between -2 and -1 + config[EPOCHS] = 50 + config, model = exp_wave_unet(config, model=model, num_layers=7, channels_extension=28) + # 7 layers, ext +28 - Nvidia RTX3060 6Gb RAM - 16 batch size + return config, model + + +@registered_experiment(major=1005) +def exp_1005_waveunet(config, model: bool = None, minor=None): + config[EPOCHS] = 150 + config, model = exp_wave_unet(config, model=model, num_layers=7, channels_extension=16) + config[DATALOADER][AUGMENTATION] = { + AUG_TRIM: {LENGTHS: [8192, 80000], LENGTH_DIVIDER: 1024, TRIM_PROB: 0.8}, + AUG_RESCALE: True + } + # 7 layers, ext +16 - Nvidia T500 4Gb RAM - 16 batch size + return config, model + + +@registered_experiment(major=1006) +def exp_1006_waveunet(config, model: bool = None, minor=None): + config[EPOCHS] = 150 + config, model = exp_wave_unet(config, model=model, num_layers=11, channels_extension=16) + config[DATALOADER][AUGMENTATION] = { + AUG_TRIM: {LENGTHS: [8192, 80000], LENGTH_DIVIDER: 4096, TRIM_PROB: 0.8}, + AUG_RESCALE: True + } + # 11 layers, ext +16 - Nvidia T500 4Gb RAM - 16 batch size + return config, model + + +@registered_experiment(major=1007) +def exp_1007_waveunet(config, model: bool = None, minor=None): + config[EPOCHS] = 150 + config, model = exp_wave_unet(config, model=model, num_layers=9, channels_extension=16) + config[DATALOADER][AUGMENTATION] = { + AUG_TRIM: {LENGTHS: [8192, 80000], LENGTH_DIVIDER: 4096, TRIM_PROB: 0.8}, + AUG_RESCALE: True + } + # 11 layers, 
ext +16 - Nvidia T500 4Gb RAM - 16 batch size + return config, model + + +@registered_experiment(major=1008) +def exp_1008_waveunet(config, model: bool = None, minor=None): + # CHEAP BASELINE + config[EPOCHS] = 150 + config, model = exp_wave_unet(config, model=model, num_layers=4, channels_extension=16) + config[DATALOADER][AUGMENTATION] = { + AUG_TRIM: {LENGTHS: [8192, 80000], LENGTH_DIVIDER: 1024, TRIM_PROB: 0.8}, + AUG_RESCALE: True + } + # 4 layers, ext +16 - Nvidia T500 4Gb RAM - 16 batch size + return config, model + + +@registered_experiment(major=3000) +def exp_3000_waveunet(config, model: bool = None, minor=None): + config[EPOCHS] = 120 + config, model = exp_wave_unet(config, model=model, num_layers=7, channels_extension=28, bias=False) + # 7 layers, ext +28 - Nvidia RTX3060 6Gb RAM - 16 batch size + return config, model + + +@registered_experiment(major=3001) +def exp_3001_waveunet(config, model: bool = None, minor=None): + config[EPOCHS] = 200 + config, model = exp_wave_unet(config, model=model, num_layers=7, channels_extension=28, bias=False) + # 7 layers, ext +28 - Nvidia RTX3060 6Gb RAM - 16 batch size + config[SCHEDULER] = "ReduceLROnPlateau" + config[SCHEDULER_CONFIGURATION] = dict(patience=5, factor=0.8) + config[OPTIMIZER][LEARNING_RATE] = 0.002 + return config, model + + +@registered_experiment(major=3002) +def exp_3002_waveunet(config, model: bool = None, minor=None): + # TRAINED WITH SNR -12db +12db (code changed manually!) + # See f910c6da3123e3d35cc0ce588bb5a72ce4a8c422 + config[EPOCHS] = 200 + config, model = exp_wave_unet(config, model=model, num_layers=7, channels_extension=28, bias=False) + # 7 layers, ext +28 - Nvidia RTX3060 6Gb RAM - 16 batch size + config[SCHEDULER] = "ReduceLROnPlateau" + config[SCHEDULER_CONFIGURATION] = dict(patience=5, factor=0.8) + config[OPTIMIZER][LEARNING_RATE] = 0.002 + return config, model + + +@registered_experiment(major=4000) +def exp_4000_bias_free_waveunet_l1(config, model: bool = None, minor=None): + # config[MAX_STEPS_PER_EPOCH] = 2 + # config[BATCH_SIZE] = [2, 2, 2] + config[EPOCHS] = 200 + config[LOSS] = LOSS_L1 + config, model = exp_wave_unet(config, model=model, num_layers=7, channels_extension=28, bias=False) + # 7 layers, ext +28 - Nvidia RTX3060 6Gb RAM - 16 batch size + config[SCHEDULER] = "ReduceLROnPlateau" + config[SCHEDULER_CONFIGURATION] = dict(patience=5, factor=0.8) + config[OPTIMIZER][LEARNING_RATE] = 0.002 + return config, model + + +def get_experiment_generator(exp_major: int): + assert exp_major in REGISTERED_EXPERIMENTS_LIST, f"Experiment {exp_major} not registered" + exp_generator = REGISTERED_EXPERIMENTS_LIST[exp_major] + return exp_generator + + +if __name__ == "__main__": + print(f"Available experiments: {list(REGISTERED_EXPERIMENTS_LIST.keys())}") diff --git a/src/gyraudio/audio_separation/experiment_tracking/storage.py b/src/gyraudio/audio_separation/experiment_tracking/storage.py new file mode 100644 index 0000000000000000000000000000000000000000..08439660827ff10d30b037851a2fade0aefe968c --- /dev/null +++ b/src/gyraudio/audio_separation/experiment_tracking/storage.py @@ -0,0 +1,65 @@ +from gyraudio.audio_separation.properties import SHORT_NAME, MODEL, OPTIMIZER, CURRENT_EPOCH, CONFIGURATION +from pathlib import Path +from gyraudio.default_locations import EXPERIMENT_STORAGE_ROOT +import logging +import torch + + +def get_output_folder(config: dict, root_dir: Path = EXPERIMENT_STORAGE_ROOT, override: bool = False) -> Path: + output_folder = root_dir/config["short_name"] + exists = False + if 
output_folder.exists(): + if not override: + logging.info(f"Experiment {config[SHORT_NAME]} already exists. Override is set to False. Skipping.") + if override: + logging.warning(f"Experiment {config[SHORT_NAME]} will be OVERRIDDEN") + exists = True + else: + output_folder.mkdir(parents=True, exist_ok=True) + exists = True + return exists, output_folder + + +def checkpoint_paths(exp_dir: Path, epoch=None): + if epoch is None: + checkpoints = sorted(exp_dir.glob("model_*.pt")) + assert len(checkpoints) > 0, f"No checkpoints found in {exp_dir}" + model_checkpoint = checkpoints[-1] + epoch = int(model_checkpoint.stem.split("_")[-1]) + optimizer_checkpoint = exp_dir/model_checkpoint.stem.replace("model", "optimizer") + else: + model_checkpoint = exp_dir/f"model_{epoch:04d}.pt" + optimizer_checkpoint = exp_dir/f"optimizer_{epoch:04d}.pt" + return model_checkpoint, optimizer_checkpoint, epoch + + +def load_checkpoint(model, exp_dir: Path, optimizer=None, epoch: int = None, + device="cuda" if torch.cuda.is_available() else "cpu"): + config = {} + model_checkpoint, optimizer_checkpoint, epoch = checkpoint_paths(exp_dir, epoch=epoch) + model_state_dict = torch.load(model_checkpoint, map_location=torch.device(device)) + model.load_state_dict(model_state_dict[MODEL]) + if optimizer is not None: + optimizer_state_dict = torch.load(optimizer_checkpoint, map_location=torch.device(device)) + optimizer.load_state_dict(optimizer_state_dict[OPTIMIZER]) + config = optimizer_state_dict[CONFIGURATION] + return model, optimizer, epoch, config + + +def save_checkpoint(model, exp_dir: Path, optimizer=None, config: dict = {}, epoch: int = None): + model_checkpoint, optimizer_checkpoint, epoch = checkpoint_paths(exp_dir, epoch=epoch) + torch.save( + { + MODEL: model.state_dict(), + }, + model_checkpoint + ) + torch.save( + { + CURRENT_EPOCH: epoch, + CONFIGURATION: config, + OPTIMIZER: optimizer.state_dict() + }, + optimizer_checkpoint + ) + print(f"Checkpoint saved:\n - model: {model_checkpoint}\n - checkpoint: {optimizer_checkpoint}") diff --git a/src/gyraudio/audio_separation/infer.py b/src/gyraudio/audio_separation/infer.py new file mode 100644 index 0000000000000000000000000000000000000000..1456afc9032b76d1dd7d286ace7ad49b528c6d7a --- /dev/null +++ b/src/gyraudio/audio_separation/infer.py @@ -0,0 +1,202 @@ +from gyraudio.audio_separation.experiment_tracking.experiments import get_experience +from gyraudio.audio_separation.parser import shared_parser +from gyraudio.audio_separation.properties import TEST, NAME, SHORT_NAME, CURRENT_EPOCH, SNR_FILTER +from gyraudio.default_locations import EXPERIMENT_STORAGE_ROOT +from gyraudio.audio_separation.experiment_tracking.storage import load_checkpoint +from gyraudio.audio_separation.experiment_tracking.storage import get_output_folder +from gyraudio.audio_separation.metrics import snr +from gyraudio.io.dump import Dump +from pathlib import Path +import sys +import torch +from tqdm import tqdm +import torchaudio +import pandas as pd +from typing import List +# Files paths +DEFAULT_RECORD_FILE = "infer_record.csv" # Store the characteristics of the inference record file +DEFAULT_EVALUATION_FILE = "eval_df.csv" # Store the characteristics of the inference record file +# Record keys +NBATCH = "nb_batch" +BEST_SNR = "best_snr" +BEST_SAVE_SNR = "best_save_snr" +WORST_SNR = "worst_snr" +WORST_SAVE_SNR = "worst_save_snr" +RECORD_KEYS = [NAME, SHORT_NAME, CURRENT_EPOCH, NBATCH, SNR_FILTER, BEST_SAVE_SNR, BEST_SNR, WORST_SAVE_SNR, WORST_SNR] +# Exaluation keys +SAVE_IDX = 
"save_idx" +SNR_IN = "snr_in" +SNR_OUT = "snr_out" +EVAL_KEYS = [SAVE_IDX, SNR_IN, SNR_OUT] + + +def load_file(path: Path, keys: List[str]) -> pd.DataFrame: + if not (path.exists()): + df = pd.DataFrame(columns=keys) + df.to_csv(path) + return pd.read_csv(path) + + +def launch_infer(exp: int, snr_filter: list = None, device: str = "cuda", model_dir: Path = None, + output_dir: Path = EXPERIMENT_STORAGE_ROOT, force_reload=False, max_batches=None, + ext=".wav"): + # Load experience + if snr_filter is not None: + snr_filter = sorted(snr_filter) + short_name, model, config, dl = get_experience(exp, snr_filter_test=snr_filter) + exists, exp_dir = get_output_folder(config, root_dir=model_dir, override=False) + assert exp_dir.exists(), f"Experiment {short_name} does not exist in {model_dir}" + model.eval() + model.to(device) + model, optimizer, epoch, config_checkpt = load_checkpoint(model, exp_dir, epoch=None, device=device) + # Folder creation + if output_dir is not None: + record_path = output_dir/DEFAULT_RECORD_FILE + record_df = load_file(record_path, RECORD_KEYS) + + # Define conditions for filtering + exist_conditions = { + NAME: config[NAME], + SHORT_NAME: config[SHORT_NAME], + CURRENT_EPOCH: epoch, + NBATCH: max_batches, + } + # Create boolean masks and combine them + masks = [(record_df[key] == value) for key, value in exist_conditions.items()] + if snr_filter is None: + masks.append((record_df[SNR_FILTER]).isnull()) + else: + masks.append(record_df[SNR_FILTER] == str(snr_filter)) + combined_mask = pd.Series(True, index=record_df.index) + for mask in masks: + combined_mask = combined_mask & mask + filtered_df = record_df[combined_mask] + + save_dir = output_dir/(exp_dir.name+"_infer" + (f"_epoch_{epoch:04d}_nbatch_{max_batches if max_batches is not None else len(dl[TEST])}") + + ("" if snr_filter is None else f"_snrs_{'_'.join(map(str, snr_filter))}")) + evaluation_path = save_dir/DEFAULT_EVALUATION_FILE + if not (filtered_df.empty) and not (force_reload): + assert evaluation_path.exists() + print(f"Inference already exists, see folder {save_dir}") + record_row_df = filtered_df + else: + record_row_df = pd.DataFrame({ + NAME: config[NAME], + SHORT_NAME: config[SHORT_NAME], + CURRENT_EPOCH: epoch, + NBATCH: max_batches, + SNR_FILTER: [None], + }, index=[0], columns=RECORD_KEYS) + record_row_df.at[0, SNR_FILTER] = snr_filter + + save_dir.mkdir(parents=True, exist_ok=True) + evaluation_df = load_file(evaluation_path, EVAL_KEYS) + with torch.no_grad(): + test_loss = 0. 
+ save_idx = 0 + best_snr = 0 + worst_snr = 0 + processed_batches = 0 + for step_index, (batch_mix, batch_signal, batch_noise) in tqdm( + enumerate(dl[TEST]), desc=f"Inference epoch {epoch}", total=max_batches if max_batches is not None else len(dl[TEST])): + batch_mix, batch_signal, batch_noise = batch_mix.to( + device), batch_signal.to(device), batch_noise.to(device) + batch_output_signal, _batch_output_noise = model(batch_mix) + loss = torch.nn.functional.mse_loss(batch_output_signal, batch_signal) + test_loss += loss.item() + + # SNR stats + snr_in = snr(batch_mix, batch_signal, reduce=None) + snr_out = snr(batch_output_signal, batch_signal, reduce=None) + best_current, best_idx_current = torch.max(snr_out-snr_in, axis=0) + worst_current, worst_idx_current = torch.min(snr_out-snr_in, axis=0) + if best_current > best_snr: + best_snr = best_current + best_save_idx = save_idx + best_idx_current + if worst_current < worst_snr: + worst_snr = worst_current + worst_save_idx = save_idx + worst_idx_current + + # Save by signal + batch_output_signal = batch_output_signal.detach().cpu() + batch_signal = batch_signal.detach().cpu() + batch_mix = batch_mix.detach().cpu() + for audio_idx in range(batch_output_signal.shape[0]): + dic = {SAVE_IDX: save_idx, SNR_IN: float( + snr_in[audio_idx]), SNR_OUT: float(snr_out[audio_idx])} + new_eval_row = pd.DataFrame(dic, index=[0]) + evaluation_df = pd.concat([new_eval_row, evaluation_df.loc[:]], ignore_index=True) + + # Save .wav + torchaudio.save( + str(save_dir/f"{save_idx:04d}_mixed{ext}"), + batch_mix[audio_idx, :, :], + sample_rate=dl[TEST].dataset.sampling_rate, + channels_first=True + ) + torchaudio.save( + str(save_dir/f"{save_idx:04d}_out{ext}"), + batch_output_signal[audio_idx, :, :], + sample_rate=dl[TEST].dataset.sampling_rate, + channels_first=True + ) + torchaudio.save( + str(save_dir/f"{save_idx:04d}_original{ext}"), + batch_signal[audio_idx, :, :], + sample_rate=dl[TEST].dataset.sampling_rate, + channels_first=True + ) + Dump.save_json(dic, save_dir/f"{save_idx:04d}.json") + save_idx += 1 + processed_batches += 1 + if max_batches is not None and processed_batches >= max_batches: + break + test_loss = test_loss/len(dl[TEST]) + evaluation_df.to_csv(evaluation_path) + + record_row_df[BEST_SAVE_SNR] = int(best_save_idx) + record_row_df[BEST_SNR] = float(best_snr) + record_row_df[WORST_SAVE_SNR] = int(worst_save_idx) + record_row_df[WORST_SNR] = float(worst_snr) + record_df = pd.concat([record_row_df, record_df.loc[:]], ignore_index=True) + record_df.to_csv(record_path, index=0) + + print(f"Test loss: {test_loss:.3e}, \nbest snr performance: {best_save_idx} with {best_snr:.1f}dB, \nworst snr performance: {worst_save_idx} with {worst_snr:.1f}dB") + + return record_row_df, evaluation_path + + +def main(argv): + default_device = "cuda" if torch.cuda.is_available() else "cpu" + parser_def = shared_parser(help="Launch inference on a specific model" + + ("\n<<>>" if default_device == "cuda" else "")) + parser_def.add_argument("-i", "--input-dir", type=str, default=EXPERIMENT_STORAGE_ROOT) + parser_def.add_argument("-o", "--output-dir", type=str, default=EXPERIMENT_STORAGE_ROOT) + parser_def.add_argument("-d", "--device", type=str, default=default_device, + help="Training device", choices=["cpu", "cuda"]) + parser_def.add_argument("-r", "--reload", action="store_true", + help="Force reload files") + parser_def.add_argument("-b", "--nb-batch", type=int, default=None, + help="Number of batches to process") + parser_def.add_argument("-s", "--snr-filter", 
type=float, nargs="+", default=None, + help="SNR filters on the inference dataloader") + parser_def.add_argument("-ext", "--extension", type=str, default=".wav", help="Extension of the audio files to save", + choices=[".wav", ".mp4"]) + args = parser_def.parse_args(argv) + for exp in args.experiments: + launch_infer( + exp, + model_dir=Path(args.input_dir), + output_dir=Path(args.output_dir), + device=args.device, + force_reload=args.reload, + max_batches=args.nb_batch, + snr_filter=args.snr_filter, + ext=args.extension + ) + + +if __name__ == "__main__": + main(sys.argv[1:]) + +# Example : python src\gyraudio\audio_separation\infer.py -i ./__output_audiosep -e 1002 -d cpu -b 2 -s 4 5 6 diff --git a/src/gyraudio/audio_separation/metrics.py b/src/gyraudio/audio_separation/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..febb511003811418be8a30f1c389b3e19a79ed1b --- /dev/null +++ b/src/gyraudio/audio_separation/metrics.py @@ -0,0 +1,101 @@ +from gyraudio.audio_separation.properties import SIGNAL, NOISE, TOTAL, LOSS_TYPE, COEFFICIENT, SNR +import torch + + +def snr(prediction: torch.Tensor, ground_truth: torch.Tensor, reduce="mean") -> torch.Tensor: + """Compute the SNR between two tensors. + Args: + prediction (torch.Tensor): prediction tensor + ground_truth (torch.Tensor): ground truth tensor + Returns: + torch.Tensor: SNR + """ + power_signal = torch.sum(ground_truth**2, dim=(-2, -1)) + power_error = torch.sum((prediction-ground_truth)**2, dim=(-2, -1)) + eps = torch.finfo(torch.float32).eps + snr_per_element = 10*torch.log10((power_signal+eps)/(power_error+eps)) + final_snr = torch.mean(snr_per_element) if reduce == "mean" else snr_per_element + return final_snr + + +DEFAULT_COST = { + SIGNAL: { + COEFFICIENT: 0.5, + LOSS_TYPE: torch.nn.functional.mse_loss + }, + NOISE: { + COEFFICIENT: 0.5, + LOSS_TYPE: torch.nn.functional.mse_loss + }, + SNR: { + LOSS_TYPE: snr + } +} + + +class Costs: + """Keep track of cost functions. + ``` + for epoch in range(...): + metric.reset_epoch() + for step in dataloader(...): + ... # forward + prediction = model.forward(batch) + metric.update(prediction1, groundtruth1, SIGNAL1) + metric.update(prediction2, groundtruth2, SIGNAL2) + loss = metric.finish_step() + + loss.backward() + ... # backprop + metric.finish_epoch() + ... # log metrics + ``` + """ + + def __init__(self, name: str, costs=DEFAULT_COST) -> None: + self.name = name + self.keys = list(costs.keys()) + self.cost = costs + + def __reset_step(self) -> None: + self.metrics = {key: 0. for key in self.keys} + + def reset_epoch(self) -> None: + self.__reset_step() + self.total_metric = {key: 0. for key in self.keys+[TOTAL]} + self.count = 0 + + def update(self, + prediction: torch.Tensor, + ground_truth: torch.Tensor, + key: str + ) -> torch.Tensor: + assert key != TOTAL + # Compute loss for a single batch (=step) + loss_signal = self.cost[key][LOSS_TYPE](prediction, ground_truth) + self.metrics[key] = loss_signal + + def finish_step(self) -> torch.Tensor: + # Reset current total + self.metrics[TOTAL] = 0. 
+ # Sum all metrics to total + for key in self.metrics: + if key != TOTAL and self.cost[key].get(COEFFICIENT, False): + self.metrics[TOTAL] += self.cost[key][COEFFICIENT]*self.metrics[key] + loss_signal = self.metrics[TOTAL] + for key in self.metrics: + if not isinstance(self.metrics[key], float): + self.metrics[key] = self.metrics[key].item() + self.total_metric[key] += self.metrics[key] + self.count += 1 + return loss_signal + + def finish_epoch(self) -> None: + for key in self.metrics: + self.total_metric[key] /= self.count + + def __repr__(self) -> str: + rep = f"{self.name}\t:\t" + for key in self.total_metric: + rep += f"{key}: {self.total_metric[key]:.3e} | " + return rep diff --git a/src/gyraudio/audio_separation/parser.py b/src/gyraudio/audio_separation/parser.py new file mode 100644 index 0000000000000000000000000000000000000000..85834acaee006a9879d9530b98e29bf2c5355693 --- /dev/null +++ b/src/gyraudio/audio_separation/parser.py @@ -0,0 +1,8 @@ +import argparse + + +def shared_parser(help="Train models for audio separation"): + parser = argparse.ArgumentParser(description=help) + parser.add_argument("-e", "--experiments", type=int, nargs="+", required=True, + help="Experiment ids to be trained sequentially") + return parser diff --git a/src/gyraudio/audio_separation/properties.py b/src/gyraudio/audio_separation/properties.py new file mode 100644 index 0000000000000000000000000000000000000000..7704163f9af806dc9681f02c8e22988685cd62f9 --- /dev/null +++ b/src/gyraudio/audio_separation/properties.py @@ -0,0 +1,79 @@ +# Training modes (Train, Validation, Test) +TRAIN = "train" +VALID = "validation" +TEST = "test" + +# Dataset properties (keys) +DATA_PATH = "path" +BATCH_SIZE = "batch_size" +SHUFFLE = "shuffle" +SNR_FILTER = "snr_filter" +AUGMENTATION = "augmentation" +DATALOADER = "dataloader" + + +# Loss split +SIGNAL = "signal" +NOISE = "noise" +TOTAL = "total" +COEFFICIENT = "coefficient" + + +# Augmentation types +AUG_TRIM = "trim" # trim batches to arbitrary length +AUG_AWGN = "awgn" # add white gaussian noise +AUG_RESCALE = "rescale" # rescale all signals arbitrarily + +# Trim types +LENGTHS = "lengths" # a list of min and max length +LENGTH_DIVIDER = "length_divider" # an int that divides the length +TRIM_PROB = "trim_probability" # a float in [0, 1] of trimming probability + + +# Training configuration properties (keys) + +OPTIMIZER = "optimizer" +LEARNING_RATE = "lr" +WEIGHT_DECAY = "weight_decay" +BETAS = "betas" +EPOCHS = "epochs" +BATCH_SIZE = "batch_size" +MAX_STEPS_PER_EPOCH = "max_steps_per_epoch" + + +# Properties for the model +NAME = "name" +ANNOTATIONS = "annotations" +NB_PARAMS = "nb_params" +RECEPTIVE_FIELD = "receptive_field" +SHORT_NAME = "short_name" + + +# Scheduler +SCHEDULER = "scheduler" +SCHEDULER_CONFIGURATION = "scheduler_configuration" + +# Loss +LOSS = "loss" +LOSS_L1 = "L1" +LOSS_L2 = "MSE" +LOSS_TYPE = "loss_type" +SNR = "snr" + +# Checkpoint +MODEL = "model" +CURRENT_EPOCH = "current_epoch" +CONFIGURATION = "configuration" + + +# Signal names +CLEAN = "clean" +NOISY = "noise" +MIXED = "mixed" +PREDICTED = "predicted" + + +# MISC +PATHS = "paths" +BUFFERS = "buffers" +SAMPLING_RATE = "sampling_rate" diff --git a/src/gyraudio/audio_separation/train.py b/src/gyraudio/audio_separation/train.py new file mode 100644 index 0000000000000000000000000000000000000000..a45cb33c8d10edf1ac6dc27e0e6e31b9719a16ba --- /dev/null +++ b/src/gyraudio/audio_separation/train.py @@ -0,0 +1,183 @@ +from gyraudio.audio_separation.experiment_tracking.experiments import 
get_experience +from gyraudio.audio_separation.parser import shared_parser +from gyraudio.audio_separation.properties import ( + TRAIN, TEST, EPOCHS, OPTIMIZER, NAME, MAX_STEPS_PER_EPOCH, + SIGNAL, NOISE, TOTAL, SNR, SCHEDULER, SCHEDULER_CONFIGURATION, + LOSS, LOSS_L2, LOSS_L1, LOSS_TYPE, COEFFICIENT +) +from gyraudio.default_locations import EXPERIMENT_STORAGE_ROOT +from torch.optim.lr_scheduler import ReduceLROnPlateau +from gyraudio.audio_separation.experiment_tracking.storage import get_output_folder, save_checkpoint +from gyraudio.audio_separation.metrics import Costs, snr +# from gyraudio.audio_separation.experiment_tracking.storage import load_checkpoint +from pathlib import Path +from gyraudio.io.dump import Dump +import sys +import torch +from tqdm import tqdm +from copy import deepcopy +import wandb +import logging + + +def launch_training(exp: int, wandb_flag: bool = True, device: str = "cuda", save_dir: Path = None, override=False): + + short_name, model, config, dl = get_experience(exp) + exists, output_folder = get_output_folder(config, root_dir=save_dir, override=override) + if not exists: + logging.warning(f"Skipping experiment {short_name}") + return False + else: + logging.info(f"Experiment {short_name} saved in {output_folder}") + + print(short_name) + print(config) + logging.info(f"Starting training for {short_name}") + logging.info(f"Config: {config}") + if wandb_flag: + wandb.init( + project="audio-separation", + entity="teammd", + name=short_name, + tags=["debug"], + config=config + ) + training_loop(model, config, dl, wandb_flag=wandb_flag, device=device, exp_dir=output_folder) + if wandb_flag: + wandb.finish() + return True + + +def update_metrics(metrics, phase, pred, gt, pred_noise, gt_noise): + metrics[phase].update(pred, gt, SIGNAL) + metrics[phase].update(pred_noise, gt_noise, NOISE) + metrics[phase].update(pred, gt, SNR) + loss = metrics[phase].finish_step() + return loss + + +def training_loop(model: torch.nn.Module, config: dict, dl, device: str = "cuda", wandb_flag: bool = False, + exp_dir: Path = None): + optim_params = deepcopy(config[OPTIMIZER]) + optim_name = optim_params[NAME] + optim_params.pop(NAME) + if optim_name == "adam": + optimizer = torch.optim.Adam(model.parameters(), **optim_params) + + scheduler = None + scheduler_config = config.get(SCHEDULER_CONFIGURATION, {}) + scheduler_name = config.get(SCHEDULER, False) + if scheduler_name: + if scheduler_name == "ReduceLROnPlateau": + scheduler = ReduceLROnPlateau(optimizer, mode='max', verbose=True, **scheduler_config) + logging.info(f"Using scheduler {scheduler_name} with config {scheduler_config}") + else: + raise NotImplementedError(f"Scheduler {scheduler_name} not implemented") + max_steps = config.get(MAX_STEPS_PER_EPOCH, None) + chosen_loss = config.get(LOSS, LOSS_L2) + if chosen_loss == LOSS_L2: + costs = {TRAIN: Costs(TRAIN), TEST: Costs(TEST)} + elif chosen_loss == LOSS_L1: + cost_init = { + SIGNAL: { + COEFFICIENT: 0.5, + LOSS_TYPE: torch.nn.functional.l1_loss + }, + NOISE: { + COEFFICIENT: 0.5, + LOSS_TYPE: torch.nn.functional.l1_loss + }, + SNR: { + LOSS_TYPE: snr + } + } + costs = { + TRAIN: Costs(TRAIN, costs=cost_init), + TEST: Costs(TEST) + } + for epoch in range(config[EPOCHS]): + costs[TRAIN].reset_epoch() + costs[TEST].reset_epoch() + model.to(device) + # Training loop + # ----------------------------------------------------------- + + metrics = {TRAIN: {}, TEST: {}} + for step_index, (batch_mix, batch_signal, batch_noise) in tqdm( + enumerate(dl[TRAIN]), desc=f"Epoch {epoch}", 
total=len(dl[TRAIN])): + if max_steps is not None and step_index >= max_steps: + break + batch_mix, batch_signal, batch_noise = batch_mix.to(device), batch_signal.to(device), batch_noise.to(device) + model.zero_grad() + batch_output_signal, batch_output_noise = model(batch_mix) + loss = update_metrics( + costs, TRAIN, + batch_output_signal, batch_signal, + batch_output_noise, batch_noise + ) + # costs[TRAIN].update(batch_output_signal, batch_signal, SIGNAL) + # costs[TRAIN].update(batch_output_noise, batch_noise, NOISE) + # loss = costs[TRAIN].finish_step() + loss.backward() + optimizer.step() + costs[TRAIN].finish_epoch() + + # Validation loop + # ----------------------------------------------------------- + model.eval() + torch.cuda.empty_cache() + with torch.no_grad(): + for step_index, (batch_mix, batch_signal, batch_noise) in tqdm( + enumerate(dl[TEST]), desc=f"Epoch {epoch}", total=len(dl[TEST])): + if max_steps is not None and step_index >= max_steps: + break + batch_mix, batch_signal, batch_noise = batch_mix.to( + device), batch_signal.to(device), batch_noise.to(device) + batch_output_signal, batch_output_noise = model(batch_mix) + loss = update_metrics( + costs, TEST, + batch_output_signal, batch_signal, + batch_output_noise, batch_noise + ) + costs[TEST].finish_epoch() + if scheduler is not None and isinstance(scheduler, ReduceLROnPlateau): + scheduler.step(costs[TEST].total_metric[SNR]) + print(f"epoch {epoch}:\n{costs[TRAIN]}\n{costs[TEST]}") + wandblogs = {} + if wandb_flag: + for phase in [TRAIN, TEST]: + wandblogs[f"{phase} loss signal"] = costs[phase].total_metric[SIGNAL] + wandblogs[f"debug loss/{phase} loss signal"] = costs[phase].total_metric[SIGNAL] + wandblogs[f"debug loss/{phase} loss total"] = costs[phase].total_metric[TOTAL] + wandblogs[f"debug loss/{phase} loss noise"] = costs[phase].total_metric[NOISE] + wandblogs[f"{phase} snr"] = costs[phase].total_metric[SNR] + wandblogs["learning rate"] = optimizer.param_groups[0]['lr'] + wandb.log(wandblogs) + metrics[TRAIN] = costs[TRAIN].total_metric + metrics[TEST] = costs[TEST].total_metric + Dump.save_json(metrics, exp_dir/f"metrics_{epoch:04d}.json") + save_checkpoint(model, exp_dir, optimizer, config=config, epoch=epoch) + torch.cuda.empty_cache() + + +def main(argv): + default_device = "cuda" if torch.cuda.is_available() else "cpu" + parser_def = shared_parser(help="Launch training \nCheck results at: https://wandb.ai/teammd/audio-separation" + + ("\n<<>>" if default_device == "cuda" else "")) + parser_def.add_argument("-nowb", "--no-wandb", action="store_true") + parser_def.add_argument("-o", "--output-dir", type=str, default=EXPERIMENT_STORAGE_ROOT) + parser_def.add_argument("-f", "--force", action="store_true", help="Override existing experiment") + + parser_def.add_argument("-d", "--device", type=str, default=default_device, + help="Training device", choices=["cpu", "cuda"]) + args = parser_def.parse_args(argv) + for exp in args.experiments: + launch_training( + exp, wandb_flag=not args.no_wandb, save_dir=Path(args.output_dir), + override=args.force, + device=args.device + ) + + +if __name__ == "__main__": + main(sys.argv[1:]) diff --git a/src/gyraudio/audio_separation/visualization/audio_player.py b/src/gyraudio/audio_separation/visualization/audio_player.py new file mode 100644 index 0000000000000000000000000000000000000000..d8a9c0a5e9772fe25beca5f08ffd0eb3120ae007 --- /dev/null +++ b/src/gyraudio/audio_separation/visualization/audio_player.py @@ -0,0 +1,84 @@ +from gyraudio.audio_separation.properties import 
CLEAN, NOISY, MIXED, PREDICTED, SAMPLING_RATE +from pathlib import Path +from gyraudio.io.audio import save_audio_tensor +from gyraudio import root_dir +from interactive_pipe import Control, KeyboardControl +from interactive_pipe import interactive +import logging + +HERE = Path(__file__).parent +MUTE = "mute" +LOGOS = { + PREDICTED: HERE/"play_logo_pred.png", + MIXED: HERE/"play_logo_mixed.png", + CLEAN: HERE/"play_logo_clean.png", + NOISY: HERE/"play_logo_noise.png", + MUTE: HERE/"mute_logo.png", +} +ICONS = [it for key, it in LOGOS.items()] +KEYS = [key for key, it in LOGOS.items()] + +ping_pong_index = 0 + + +@interactive( + player=Control(MUTE, KEYS, icons=ICONS)) +def audio_selector(sig, mixed, pred, global_params={}, player=MUTE): + + global_params["selected_audio"] = player if player != MUTE else global_params.get("selected_audio", MIXED) + global_params[MUTE] = player == MUTE + if player == CLEAN: + audio_track = sig["buffers"][CLEAN] + elif player == NOISY: + audio_track = sig["buffers"][NOISY] + elif player == MIXED: + audio_track = mixed + elif player == PREDICTED: + audio_track = pred + else: + audio_track = mixed + return audio_track + + +@interactive( + loop=KeyboardControl(True, keydown="l")) +def audio_trim(audio_track, global_params={}, loop=True): + sampling_rate = global_params.get(SAMPLING_RATE, 8000) + if global_params.get("trim", False): + start, end = global_params["trim"]["start"], global_params["trim"]["end"] + remainder = (end-start) % 8 + audio_trim = audio_track[..., start:end-remainder] + repeat_factor = int(sampling_rate*4./(end-start)) + logging.debug(f"{repeat_factor}") + repeat_factor = max(1, repeat_factor) + if loop: + repeat_factor = 1 + audio_trim = audio_trim.repeat(1, repeat_factor) + logging.debug(f"{audio_trim.shape}") + else: + audio_trim = audio_track + return audio_trim + + +@interactive( + volume=(100, [0, 1000], "volume"), +) +def audio_player(audio_trim, global_params={}, volume=100): + sampling_rate = global_params.get(SAMPLING_RATE, 8000) + try: + if global_params.get(MUTE, True): + global_params["__stop"]() + print("mute!") + else: + ping_pong_path = root_dir/"__ping_pong" + ping_pong_path.mkdir(exist_ok=True) + global ping_pong_index + audio_track_path = ping_pong_path/f"_tmp_{ping_pong_index}.wav" + ping_pong_index = (ping_pong_index + 1) % 10 + save_audio_tensor(audio_track_path, volume/100.*audio_trim, + sampling_rate=global_params.get(SAMPLING_RATE, sampling_rate)) + global_params["__set_audio"](audio_track_path) + global_params["__play"]() + except Exception as exc: + logging.warning(f"Exception in audio_player {exc}") + pass diff --git a/src/gyraudio/audio_separation/visualization/interactive_audio.py b/src/gyraudio/audio_separation/visualization/interactive_audio.py new file mode 100644 index 0000000000000000000000000000000000000000..aea428f3976f981a0621220869a5275d52aa6ebd --- /dev/null +++ b/src/gyraudio/audio_separation/visualization/interactive_audio.py @@ -0,0 +1,392 @@ +from batch_processing import Batch +import argparse +from pathlib import Path +from gyraudio.audio_separation.experiment_tracking.experiments import get_experience +from gyraudio.audio_separation.experiment_tracking.storage import get_output_folder +from gyraudio.default_locations import EXPERIMENT_STORAGE_ROOT +from gyraudio.audio_separation.properties import ( + SHORT_NAME, CLEAN, NOISY, MIXED, PREDICTED, ANNOTATIONS, PATHS, BUFFERS, SAMPLING_RATE, NAME +) +import torch +from gyraudio.audio_separation.experiment_tracking.storage import load_checkpoint +from 
gyraudio.audio_separation.visualization.pre_load_audio import ( + parse_command_line_audio_load, load_buffers, audio_loading_batch) +from gyraudio.audio_separation.visualization.pre_load_custom_audio import ( + parse_command_line_generic_audio_load, generic_audio_loading_batch, + load_buffers_custom +) +from torchaudio.functional import resample +from typing import List +import numpy as np +import logging +from interactive_pipe.data_objects.curves import Curve, SingleCurve +from interactive_pipe import interactive, KeyboardControl, Control +from interactive_pipe.headless.pipeline import HeadlessPipeline +from interactive_pipe.graphical.qt_gui import InteractivePipeQT +from interactive_pipe.graphical.mpl_gui import InteractivePipeMatplotlib +from gyraudio.audio_separation.visualization.audio_player import audio_selector, audio_trim, audio_player + +default_device = "cuda" if torch.cuda.is_available() else "cpu" +LEARNT_SAMPLING_RATE = 8000 + + +@interactive( + idx=KeyboardControl(value_default=0, value_range=[ + 0, 1000], modulo=True, keyup="8", keydown="2"), + idn=KeyboardControl(value_default=0, value_range=[ + 0, 1000], modulo=True, keyup="9", keydown="3") +) +def signal_selector(signals, idx=0, idn=0, global_params={}): + if isinstance(signals, dict): + clean_sigs = signals[CLEAN] + clean = clean_sigs[idx % len(clean_sigs)] + if BUFFERS not in clean: + load_buffers_custom(clean) + noise_sigs = signals[NOISY] + noise = noise_sigs[idn % len(noise_sigs)] + if BUFFERS not in noise: + load_buffers_custom(noise) + cbuf, nbuf = clean[BUFFERS], noise[BUFFERS] + if clean[SAMPLING_RATE] != LEARNT_SAMPLING_RATE: + cbuf = resample(cbuf, clean[SAMPLING_RATE], LEARNT_SAMPLING_RATE) + clean[SAMPLING_RATE] = LEARNT_SAMPLING_RATE + if noise[SAMPLING_RATE] != LEARNT_SAMPLING_RATE: + nbuf = resample(nbuf, noise[SAMPLING_RATE], LEARNT_SAMPLING_RATE) + noise[SAMPLING_RATE] = LEARNT_SAMPLING_RATE + min_length = min(cbuf.shape[-1], nbuf.shape[-1]) + min_length = min_length - min_length % 1024 + signal = { + PATHS: { + CLEAN: clean[PATHS], + NOISY: noise[PATHS] + + }, + BUFFERS: { + CLEAN: cbuf[..., :1, :min_length], + NOISY: nbuf[..., :1, :min_length], + }, + NAME: f"Clean={clean[NAME]} | Noise={noise[NAME]}", + SAMPLING_RATE: LEARNT_SAMPLING_RATE + } + else: + # signals are loaded in CPU + signal = signals[idx % len(signals)] + if BUFFERS not in signal: + load_buffers(signal) + global_params["premixed_snr"] = signal.get("premixed_snr", None) + signal[NAME] = f"File={signal[NAME]}" + global_params["selected_info"] = signal[NAME] + global_params[SAMPLING_RATE] = signal[SAMPLING_RATE] + return signal + + +@interactive( + snr=(0., [-10., 10.], "SNR [dB]") +) +def remix(signals, snr=0., global_params={}): + signal = signals[BUFFERS][CLEAN] + noisy = signals[BUFFERS][NOISY] + alpha = 10 ** (-snr / 20) * torch.norm(signal) / torch.norm(noisy) + mixed_signal = signal + alpha * noisy + global_params["snr"] = snr + return mixed_signal + + +@interactive(std_dev=Control(0., value_range=[0., 0.1], name="extra noise std", step=0.0001), + amplify=(1., [0., 10.], "amplification of everything")) +def augment(signals, mixed, std_dev=0., amplify=1.): + signals[BUFFERS][MIXED] *= amplify + signals[BUFFERS][NOISY] *= amplify + signals[BUFFERS][CLEAN] *= amplify + mixed = mixed*amplify+torch.randn_like(mixed)*std_dev + return signals, mixed + + +# @interactive( +# device=("cuda", ["cpu", "cuda"]) if default_device == "cuda" else ("cpu", ["cpu"]) +# ) +def select_device(device=default_device, global_params={}): + 
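+ # Stores the chosen device in the pipeline's shared state; audio_sep_inference
+ # reads it back to move the selected model and the mixed signal before running inference.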
global_params["device"] = device + + +@interactive( + model=KeyboardControl(value_default=0, value_range=[ + 0, 99], keyup="pagedown", keydown="pageup") +) +def audio_sep_inference(mixed, models, configs, model: int = 0, global_params={}): + selected_model = models[model % len(models)] + config = configs[model % len(models)] + short_name = config.get(SHORT_NAME, "") + annotations = config.get(ANNOTATIONS, "") + device = global_params.get("device", "cpu") + with torch.no_grad(): + selected_model.eval() + selected_model.to(device) + predicted_signal, predicted_noise = selected_model( + mixed.to(device).unsqueeze(0)) + predicted_signal = predicted_signal.squeeze(0) + pred_curve = SingleCurve(y=predicted_signal[0, :].detach().cpu().numpy(), + style="g-", label=f"predicted_{short_name} {annotations}") + return predicted_signal, pred_curve + + +def compute_metrics(pred, sig, global_params={}): + METRICS = "metrics" + target = sig[BUFFERS][CLEAN] + global_params[METRICS] = {} + global_params[METRICS]["MSE"] = torch.mean((target-pred.cpu())**2) + global_params[METRICS]["SNR"] = 10. * \ + torch.log10(torch.sum(target**2)/torch.sum((target-pred.cpu())**2)) + + +def get_trim(sig, zoom, center, num_samples=300): + N = len(sig) + native_ds = N/num_samples + center_idx = int(center*N) + window = int(num_samples/zoom*native_ds) + start_idx = max(0, center_idx - window//2) + end_idx = min(N, center_idx + window//2) + skip_factor = max(1, int(native_ds/zoom)) + return start_idx, end_idx, skip_factor + + +def zin(sig, zoom, center, num_samples=300): + start_idx, end_idx, skip_factor = get_trim( + sig, zoom, center, num_samples=num_samples) + out = np.zeros(num_samples) + trimmed = sig[start_idx:end_idx:skip_factor] + out[:len(trimmed)] = trimmed[:num_samples] + return out + + +@interactive( + center=KeyboardControl(value_default=0.5, value_range=[ + 0., 1.], step=0.01, keyup="6", keydown="4"), + zoom=KeyboardControl(value_default=0., value_range=[ + 0., 15.], step=1, keyup="+", keydown="-"), + zoomy=KeyboardControl( + value_default=0., value_range=[-15., 15.], step=1, keyup="up", keydown="down") +) +def visualize_audio(signal: dict, mixed_signal, pred, zoom=1, zoomy=0., center=0.5, global_params={}): + """Create curves + """ + zval = 1.5**zoom + start_idx, end_idx, _skip_factor = get_trim( + signal[BUFFERS][CLEAN][0, :], zval, center) + global_params["trim"] = dict(start=start_idx, end=end_idx) + selected = global_params.get("selected_audio", MIXED) + clean = SingleCurve(y=zin(signal[BUFFERS][CLEAN][0, :], zval, center), + alpha=1., + style="k-", + linewidth=0.9, + label=("*" if selected == CLEAN else " ")+"clean") + noisy = SingleCurve(y=zin(signal[BUFFERS][NOISY][0, :], zval, center), + alpha=0.3, + style="y--", + linewidth=1, + label=("*" if selected == NOISY else " ") + "noisy" + ) + mixed = SingleCurve(y=zin(mixed_signal[0, :], zval, center), style="r-", + alpha=0.1, + linewidth=2, + label=("*" if selected == MIXED else " ") + "mixed") + # true_mixed = SingleCurve(y=zin(signal[BUFFERS][MIXED][0, :], zval, center), + # alpha=0.3, style="b-", linewidth=1, label="true mixed") + pred.y = zin(pred.y, zval, center) + pred.label = ("*" if selected == PREDICTED else " ") + pred.label + curves = [noisy, mixed, pred, clean] + title = f"SNR in {global_params['snr']:.1f} dB" + if "selected_info" in global_params: + title += f" | {global_params['selected_info']}" + title += "\n" + for metric_name, metric_value in global_params.get("metrics", {}).items(): + title += f" | {metric_name} " + title += 
f"{metric_value:.2e}" if (abs(metric_value) < 1e-2 or abs(metric_value) + > 1000) else f"{metric_value:.2f}" + # if global_params.get("premixed_snr", None) is not None: + # title += f"| Premixed SNR : {global_params['premixed_snr']:.1f} dB" + return Curve(curves, ylim=[-0.04 * 1.5 ** zoomy, 0.04 * 1.5 ** zoomy], xlabel="Time index", ylabel="Amplitude", title=title) + + +def interactive_audio_separation_processing(signals, model_list, config_list): + sig = signal_selector(signals) + mixed = remix(sig) + # sig, mixed = augment(sig, mixed) + select_device() + pred, pred_curve = audio_sep_inference(mixed, model_list, config_list) + compute_metrics(pred, sig) + sound = audio_selector(sig, mixed, pred) + curve = visualize_audio(sig, mixed, pred_curve) + trimmed_sound = audio_trim(sound) + audio_player(trimmed_sound) + return curve + + +def interactive_audio_separation_visualization( + all_signals: List[dict], + model_list: List[torch.nn.Module], + config_list: List[dict], + gui="qt" +): + pip = HeadlessPipeline.from_function( + interactive_audio_separation_processing, cache=False) + if gui == "qt": + app = InteractivePipeQT( + pipeline=pip, name="audio separation", size=(1000, 1000), audio=True) + else: + logging.warning("No support for audio player with Matplotlib") + app = InteractivePipeMatplotlib( + pipeline=pip, name="audio separation", size=None, audio=False) + app(all_signals, model_list, config_list) + + +def visualization( + all_signals: List[dict], + model_list: List[torch.nn.Module], + config_list: List[dict], + device="cuda" +): + for signal in all_signals: + if BUFFERS not in signal: + load_buffers(signal, device="cpu") + clean = SingleCurve(y=signal[BUFFERS][CLEAN][0, :], label="clean") + noisy = SingleCurve(y=signal[BUFFERS][NOISY] + [0, :], label="noise", alpha=0.3) + curves = [clean, noisy] + for config, model in zip(config_list, model_list): + short_name = config.get(SHORT_NAME, "unknown") + predicted_signal, predicted_noise = model( + signal[BUFFERS][MIXED].to(device).unsqueeze(0)) + predicted = SingleCurve(y=predicted_signal.squeeze(0)[0, :].detach().cpu().numpy(), + label=f"predicted_{short_name}") + curves.append(predicted) + Curve(curves).show() + + +def parse_command_line(parser: Batch = None, gradio_demo=True) -> argparse.ArgumentParser: + if gradio_demo: + parser = parse_command_line_gradio(parser) + else: + parser = parse_command_line_generic(parser) + return parser + + +def parse_command_line_gradio(parser: Batch = None, gradio_demo=True) -> argparse.ArgumentParser: + if parser is None: + parser = parse_command_line_audio_load() + default_device = "cuda" if torch.cuda.is_available() else "cpu" + iparse = parser.add_argument_group("Audio separation visualization") + iparse.add_argument("-e", "--experiments", type=int, nargs="+", default=3001, + help="Experiment ids to be inferred sequentially") + iparse.add_argument("-p", "--interactive", default=True, + action="store_true", help="Play = Interactive mode") + iparse.add_argument("-m", "--model-root", type=str, + default=EXPERIMENT_STORAGE_ROOT) + iparse.add_argument("-d", "--device", type=str, default=default_device, + choices=["cpu", "cuda"] if default_device == "cuda" else ["cpu"]) + iparse.add_argument("-gui", "--gui", type=str, + default="gradio", choices=["qt", "mpl", "gradio"]) + pri + return parser + + +def parse_command_line_generic(parser: Batch = None, gradio_demo=True) -> argparse.ArgumentParser: + if parser is None: + parser = parse_command_line_audio_load() + default_device = "cuda" if 
torch.cuda.is_available() else "cpu" + iparse = parser.add_argument_group("Audio separation visualization") + iparse.add_argument("-e", "--experiments", type=int, nargs="+", required=True, + help="Experiment ids to be inferred sequentially") + iparse.add_argument("-p", "--interactive", + action="store_true", help="Play = Interactive mode") + iparse.add_argument("-m", "--model-root", type=str, + default=EXPERIMENT_STORAGE_ROOT) + iparse.add_argument("-d", "--device", type=str, default=default_device, + choices=["cpu", "cuda"] if default_device == "cuda" else ["cpu"]) + iparse.add_argument("-gui", "--gui", type=str, + default="qt", choices=["qt", "mpl", "gradio"]) + return parser + + +def main(argv: List[str]): + """Paired signals and noise in folders""" + batch = Batch(argv) + batch.set_io_description( + input_help='input audio files', + output_help=argparse.SUPPRESS + ) + batch.set_multiprocessing_enabled(False) + parser = parse_command_line() + args = batch.parse_args(parser) + exp = args.experiments[0] + device = args.device + models_list = [] + config_list = [] + logging.info(f"Loading experiments models {args.experiments}") + for exp in args.experiments: + model_dir = Path(args.model_root) + short_name, model, config, _dl = get_experience(exp) + _, exp_dir = get_output_folder( + config, root_dir=model_dir, override=False) + assert exp_dir.exists( + ), f"Experiment {short_name} does not exist in {model_dir}" + model.eval() + model.to(device) + model, __optimizer, epoch, config = load_checkpoint( + model, exp_dir, epoch=None, device=args.device) + config[SHORT_NAME] = short_name + models_list.append(model) + config_list.append(config) + logging.info("Load audio buffers:") + all_signals = batch.run(audio_loading_batch) + if not args.interactive: + visualization(all_signals, models_list, config_list, device=device) + else: + interactive_audio_separation_visualization( + all_signals, models_list, config_list, gui=args.gui) + + +def main_custom(argv: List[str]): + """Handle custom noise and custom signals + """ + parser = parse_command_line() + parser.add_argument("-s", "--signal", type=str, required=True, + nargs="+", help="Signal to be preloaded") + parser.add_argument("-n", "--noise", type=str, required=True, + nargs="+", help="Noise to be preloaded") + args = parser.parse_args(argv) + exp = args.experiments[0] + device = args.device + models_list = [] + config_list = [] + logging.info(f"Loading experiments models {args.experiments}") + for exp in args.experiments: + model_dir = Path(args.model_root) + short_name, model, config, _dl = get_experience(exp) + _, exp_dir = get_output_folder( + config, root_dir=model_dir, override=False) + assert exp_dir.exists( + ), f"Experiment {short_name} does not exist in {model_dir}" + model.eval() + model.to(device) + model, __optimizer, epoch, config = load_checkpoint( + model, exp_dir, epoch=None, device=args.device) + config[SHORT_NAME] = short_name + models_list.append(model) + config_list.append(config) + all_signals = {} + for args_paths, key in zip([args.signal, args.noise], [CLEAN, NOISY]): + new_argv = ["-i"] + args_paths + if args.preload: + new_argv += ["--preload"] + batch = Batch(new_argv) + new_parser = parse_command_line_generic_audio_load() + batch.set_io_description( + input_help=argparse.SUPPRESS, # 'input audio files', + output_help=argparse.SUPPRESS + ) + batch.set_multiprocessing_enabled(False) + _ = batch.parse_args(new_parser) + all_signals[key] = batch.run(generic_audio_loading_batch) + interactive_audio_separation_visualization( 
+ all_signals, models_list, config_list, gui=args.gui) diff --git a/src/gyraudio/audio_separation/visualization/interactive_infer.py b/src/gyraudio/audio_separation/visualization/interactive_infer.py new file mode 100644 index 0000000000000000000000000000000000000000..cafa00b67fffe1e9108c93ed9f45c9188436ca00 --- /dev/null +++ b/src/gyraudio/audio_separation/visualization/interactive_infer.py @@ -0,0 +1,117 @@ +from gyraudio.default_locations import EXPERIMENT_STORAGE_ROOT +from gyraudio.audio_separation.parser import shared_parser +from gyraudio.audio_separation.infer import launch_infer, RECORD_KEYS, SNR_OUT, SNR_IN, NBATCH, SAVE_IDX +from gyraudio.audio_separation.properties import TEST, NAME, SHORT_NAME, CURRENT_EPOCH, SNR_FILTER +import sys +import os +from dash import Dash, html, dcc, callback, Output, Input, dash_table +import plotly.express as px +import plotly.graph_objects as go +from plotly.subplots import make_subplots +import pandas as pd +from typing import List +import torch +from pathlib import Path +DIFF_SNR = 'SNR out - SNR in' + + + +def get_app(record_row_dfs : pd.DataFrame, eval_dfs : List[pd.DataFrame]) : + app = Dash(__name__) + # names_options = [{'label' : f"{record[SHORT_NAME]} - {record[NAME]} epoch {record[CURRENT_EPOCH]:04d}", 'value' : record[NAME] } for idx, record in record_row_dfs.iterrows()] + app.layout = html.Div([ + html.H1(children='Inference results', style={'textAlign':'center'}), + # dcc.Dropdown(names_options, names_options[0]['value'], id='exp-selection'), + # dcc.RadioItems(['scatter', 'box'], 'box', inline=True, id='radio-plot-type'), + dcc.RadioItems([SNR_OUT, DIFF_SNR], DIFF_SNR, inline=True, id='radio-plot-out'), + dcc.Graph(id='graph-content') + ]) + + @callback( + Output('graph-content', 'figure'), + # Input('exp-selection', 'value'), + # Input('radio-plot-type', 'value'), + Input('radio-plot-out', 'value'), + ) + def update_graph(radio_plot_out) : + fig = make_subplots(rows = 2, cols = 1) + colors = px.colors.qualitative.Plotly + for id, record in record_row_dfs.iterrows() : + color = colors[id % len(colors)] + eval_df = eval_dfs[id].sort_values(by=SNR_IN) + eval_df[DIFF_SNR] = eval_df[SNR_OUT] - eval_df[SNR_IN] + legend = f'{record[SHORT_NAME]}_{record[NAME]}' + fig.add_trace( + go.Scatter( + x=eval_df[SNR_IN], + y=eval_df[radio_plot_out], + mode="markers", marker={'color' : color}, + name=legend, + hovertemplate = 'File : %{text}'+ + '
<br> %{y} <br>
', + text = [f"{eval[SAVE_IDX]:.0f}" for idx, eval in eval_df.iterrows()] + ), + row = 1, col = 1 + ) + eval_df_bins = eval_df + eval_df_bins[SNR_IN] = eval_df_bins[SNR_IN].apply(lambda snr : round(snr)) + fig.add_trace( + go.Box( + x=eval_df[SNR_IN], + y=eval_df[radio_plot_out], + fillcolor = color, + marker={'color' : color}, + name = legend + ), + row = 2, col = 1 + ) + + title = f"SNR performances" + fig.update_layout( + title=title, + xaxis2_title = SNR_IN, + yaxis_title = radio_plot_out, + hovermode='x unified' + ) + return fig + + + + return app + + +def main(argv): + default_device = "cuda" if torch.cuda.is_available() else "cpu" + parser_def = shared_parser(help="Launch training \nCheck results at: https://wandb.ai/balthazarneveu/audio-sep" + + ("\n<<>>" if default_device == "cuda" else "")) + parser_def.add_argument("-i", "--input-dir", type=str, default=EXPERIMENT_STORAGE_ROOT) + parser_def.add_argument("-o", "--output-dir", type=str, default=EXPERIMENT_STORAGE_ROOT) + parser_def.add_argument("-d", "--device", type=str, default=default_device, + help="Training device", choices=["cpu", "cuda"]) + parser_def.add_argument("-b", "--nb-batch", type=int, default=None, + help="Number of batches to process") + parser_def.add_argument("-s", "--snr-filter", type=float, nargs="+", default=None, + help="SNR filters on the inference dataloader") + args = parser_def.parse_args(argv) + record_row_dfs = pd.DataFrame(columns = RECORD_KEYS) + eval_dfs = [] + for exp in args.experiments: + record_row_df, evaluation_path = launch_infer( + exp, + model_dir=Path(args.input_dir), + output_dir=Path(args.output_dir), + device=args.device, + max_batches=args.nb_batch, + snr_filter=args.snr_filter + ) + eval_df = pd.read_csv(evaluation_path) + # Careful, list order for concat is important for index matching eval_dfs list + record_row_dfs = pd.concat([record_row_dfs.loc[:], record_row_df], ignore_index=True) + eval_dfs.append(eval_df) + app = get_app(record_row_dfs, eval_dfs) + app.run(debug=True) + + +if __name__ == '__main__': + os.environ["KMP_DUPLICATE_LIB_OK"] = "True" + main(sys.argv[1:]) \ No newline at end of file diff --git a/src/gyraudio/audio_separation/visualization/mute_logo.png b/src/gyraudio/audio_separation/visualization/mute_logo.png new file mode 100644 index 0000000000000000000000000000000000000000..fa9ec3935ff80e354800d911b46022c0a8bb0d2f Binary files /dev/null and b/src/gyraudio/audio_separation/visualization/mute_logo.png differ diff --git a/src/gyraudio/audio_separation/visualization/play_logo_clean.png b/src/gyraudio/audio_separation/visualization/play_logo_clean.png new file mode 100644 index 0000000000000000000000000000000000000000..20868f9d3f5a36964965b494b9c573f803ec526d Binary files /dev/null and b/src/gyraudio/audio_separation/visualization/play_logo_clean.png differ diff --git a/src/gyraudio/audio_separation/visualization/play_logo_mixed.png b/src/gyraudio/audio_separation/visualization/play_logo_mixed.png new file mode 100644 index 0000000000000000000000000000000000000000..357aab6dea8e8f27798792f108c71c2f56c1ef24 Binary files /dev/null and b/src/gyraudio/audio_separation/visualization/play_logo_mixed.png differ diff --git a/src/gyraudio/audio_separation/visualization/play_logo_noise.png b/src/gyraudio/audio_separation/visualization/play_logo_noise.png new file mode 100644 index 0000000000000000000000000000000000000000..79ad0dd983d70e6f203622f031e5d6c8e62fcae2 Binary files /dev/null and b/src/gyraudio/audio_separation/visualization/play_logo_noise.png differ diff --git 
a/src/gyraudio/audio_separation/visualization/play_logo_pred.png b/src/gyraudio/audio_separation/visualization/play_logo_pred.png new file mode 100644 index 0000000000000000000000000000000000000000..26281e1319071eb5cd75ba7aa85e657e2b969076 Binary files /dev/null and b/src/gyraudio/audio_separation/visualization/play_logo_pred.png differ diff --git a/src/gyraudio/audio_separation/visualization/pre_load_audio.py b/src/gyraudio/audio_separation/visualization/pre_load_audio.py new file mode 100644 index 0000000000000000000000000000000000000000..c6a9a1b36eebe27cad611255879591569a403c5f --- /dev/null +++ b/src/gyraudio/audio_separation/visualization/pre_load_audio.py @@ -0,0 +1,71 @@ +from batch_processing import Batch +import argparse +import sys +from pathlib import Path +from gyraudio.audio_separation.properties import CLEAN, NOISY, MIXED, PATHS, BUFFERS, NAME, SAMPLING_RATE +from gyraudio.io.audio import load_audio_tensor + + +def parse_command_line_audio_load() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description='Batch audio processing', + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + + parser.add_argument("-preload", "--preload", action="store_true", help="Preload audio files") + return parser + + +def outp(path: Path, suffix: str, extension=".wav"): + return (path.parent / (path.stem + suffix)).with_suffix(extension) + + +def load_buffers(signal: dict, device="cpu") -> None: + clean_signal, sampling_rate = load_audio_tensor(signal[PATHS][CLEAN], device=device) + noisy_signal, sampling_rate = load_audio_tensor(signal[PATHS][NOISY], device=device) + mixed_signal, sampling_rate = load_audio_tensor(signal[PATHS][MIXED], device=device) + signal[BUFFERS] = { + CLEAN: clean_signal, + NOISY: noisy_signal, + MIXED: mixed_signal + } + signal[SAMPLING_RATE] = sampling_rate + + +def audio_loading(input: Path, preload: bool) -> dict: + name = input.name + clean_audio_path = input/"voice.wav" + noisy_audio_path = input/"noise.wav" + mixed_audio_path = list(input.glob("mix*.wav"))[0] + signal = { + NAME: name, + PATHS: { + CLEAN: clean_audio_path, + NOISY: noisy_audio_path, + MIXED: mixed_audio_path + } + } + signal["premixed_snr"] = float(mixed_audio_path.stem.split("_")[-1]) + if preload: + load_buffers(signal) + return signal + + +def audio_loading_batch(input: Path, args: argparse.Namespace) -> dict: + """Wrapper to load audio files from a directory using batch_processing + """ + return audio_loading(input, preload=args.preload) + + +def main(argv): + batch = Batch(argv) + batch.set_io_description( + input_help='input audio files', + output_help=argparse.SUPPRESS + ) + parser = parse_command_line_audio_load() + batch.parse_args(parser) + all_signals = batch.run(audio_loading_batch) + return all_signals + + +if __name__ == "__main__": + main(sys.argv[1:]) diff --git a/src/gyraudio/audio_separation/visualization/pre_load_custom_audio.py b/src/gyraudio/audio_separation/visualization/pre_load_custom_audio.py new file mode 100644 index 0000000000000000000000000000000000000000..e4c08865cca223ee7ce5f835d25f26d7212ef2aa --- /dev/null +++ b/src/gyraudio/audio_separation/visualization/pre_load_custom_audio.py @@ -0,0 +1,53 @@ +from batch_processing import Batch +import argparse +import sys +from pathlib import Path +from gyraudio.audio_separation.properties import PATHS, BUFFERS, NAME, SAMPLING_RATE +from gyraudio.io.audio import load_audio_tensor + + +def parse_command_line_generic_audio_load() -> argparse.ArgumentParser: + parser = 
argparse.ArgumentParser(description='Batch audio loading', + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + + parser.add_argument("-preload", "--preload", action="store_true", help="Preload audio files") + return parser + + +def load_buffers_custom(signal: dict, device="cpu") -> None: + generic_signal, sampling_rate = load_audio_tensor(signal[PATHS], device=device) + signal[BUFFERS] = generic_signal + signal[SAMPLING_RATE] = sampling_rate + + +def audio_loading(input: Path, preload: bool) -> dict: + name = input.parent.name + "/" + input.stem + signal = { + NAME: name, + PATHS: input, + } + if preload: + load_buffers_custom(signal) + return signal + + +def generic_audio_loading_batch(input: Path, args: argparse.Namespace) -> dict: + """Wrapper to load audio files from a directory using batch_processing + """ + return audio_loading(input, preload=args.preload) + + +def main(argv): + batch = Batch(argv) + batch.set_io_description( + input_help='input audio files', + output_help=argparse.SUPPRESS + ) + parser = parse_command_line_generic_audio_load() + batch.parse_args(parser) + all_signals = batch.run(generic_audio_loading_batch) + return all_signals + + +if __name__ == "__main__": + main(sys.argv[1:]) diff --git a/src/gyraudio/default_locations.py b/src/gyraudio/default_locations.py new file mode 100644 index 0000000000000000000000000000000000000000..4b3201a8cb33924a00ac89451759a8a34d84021b --- /dev/null +++ b/src/gyraudio/default_locations.py @@ -0,0 +1,5 @@ +from gyraudio import root_dir +RAW_AUDIO_ROOT = root_dir/"__data_source_separation"/"voice_origin" +MIXED_AUDIO_ROOT = root_dir/"__data_source_separation"/"source_separation" +EXPERIMENT_STORAGE_ROOT = root_dir/"__output_audiosep" +SAMPLE_ROOT = root_dir/"audio_samples" diff --git a/src/gyraudio/io/audio.py b/src/gyraudio/io/audio.py new file mode 100644 index 0000000000000000000000000000000000000000..e97121a3fafcadb9324ed5636f5c3e57320a0eb0 --- /dev/null +++ b/src/gyraudio/io/audio.py @@ -0,0 +1,29 @@ +import scipy +import numpy as np +import torchaudio +import torch +from pathlib import Path +from typing import Tuple + + +def load_raw_audio(path: str) -> np.array: + assert path.exists(), f"Audio path {path} does not exist" + rate, signal = scipy.io.wavfile.read(path) + return rate, signal + + +def load_audio_tensor(path: Path, device=None) -> Tuple[torch.Tensor, int]: + assert path.exists(), f"Audio path {path} does not exist" + signal, rate = torchaudio.load(str(path)) + if device is not None: + signal = signal.to(device) + return signal, rate + + +def save_audio_tensor(path: Path, signal: torch.Tensor, sampling_rate: int): + torchaudio.save( + str(path), + signal.detach().cpu(), + sample_rate=sampling_rate, + channels_first=True + ) diff --git a/src/gyraudio/io/dump.py b/src/gyraudio/io/dump.py new file mode 100644 index 0000000000000000000000000000000000000000..4c7453339493a47ac10cb0266cd42e69056584f3 --- /dev/null +++ b/src/gyraudio/io/dump.py @@ -0,0 +1,54 @@ +from pathlib import Path +import json +import logging +import pickle +YAML_SUPPORT = True +YAML_NOT_DETECTED_MESSAGE = "yaml is not installed, consider installing it by pip install PyYAML" +try: + import yaml + from yaml.loader import SafeLoader, BaseLoader +except ImportError as e: + YAML_SUPPORT = False + logging.warning(f"{e}\n{YAML_NOT_DETECTED_MESSAGE}") + + +class Dump: + @staticmethod + def load_yaml(path: Path, safe_load=True) -> dict: + assert YAML_SUPPORT, YAML_NOT_DETECTED_MESSAGE + with open(path) as file: + params = yaml.load( + file, 
Loader=SafeLoader if safe_load else BaseLoader) + return params + + @staticmethod + def save_yaml(data: dict, path: Path, **kwargs): + path.parent.mkdir(parents=True, exist_ok=True) + assert YAML_SUPPORT, YAML_NOT_DETECTED_MESSAGE + with open(path, 'w') as outfile: + yaml.dump(data, outfile, **kwargs) + + @staticmethod + def load_json(path: Path,) -> dict: + with open(path) as file: + params = json.load(file) + return params + + @staticmethod + def save_json(data: dict, path: Path): + path.parent.mkdir(parents=True, exist_ok=True) + with open(path, 'w') as outfile: + json.dump(data, outfile) + + @staticmethod + def load_pickle(path: Path,) -> dict: + with open(path, "rb") as file: + unpickler = pickle.Unpickler(file) + params = unpickler.load() + # params = pickle.load(file) + return params + + @staticmethod + def save_pickle(data: dict, path: Path): + with open(path, 'wb') as outfile: + pickle.dump(data, outfile) diff --git a/src/gyraudio/io/imu.py b/src/gyraudio/io/imu.py new file mode 100644 index 0000000000000000000000000000000000000000..aac673cb1756b350eda4f9bb424e307d83106774 --- /dev/null +++ b/src/gyraudio/io/imu.py @@ -0,0 +1,76 @@ +from pathlib import Path +import gpmf +import numpy as np +from gyraudio.properties import GYRO_KEY, ACCL_KEY +def extract_imu_blocks(stream, key=GYRO_KEY): + """ Extract imu data blocks from binary stream + + This is a generator on lists `KVLItem` objects. In + the GPMF stream, imu data comes into blocks of several + different data items. For each of these blocks we return a list. + + Parameters + ---------- + stream: bytes + The raw GPMF binary stream + + Returns + ------- + imu_items_generator: generator + Generator of lists of `KVLItem` objects + """ + for s in gpmf.parse.filter_klv(stream, "STRM"): + content = [] + is_imu = False + for elt in s.value: + content.append(elt) + if elt.key == key: + is_imu = True + if is_imu: + yield content + + +def parse_imu_block(imu_block, key=GYRO_KEY): + """Turn imu data blocks into `imuData` objects + + Parameters + ---------- + imu_block: list of KVLItem + A list of KVLItem corresponding to a imu data block. + + Returns + ------- + imu_data: imuData + A imuData object holding the imu information of a block. + """ + block_dict = { + s.key: s for s in imu_block + } + + imu_data = block_dict[key].value * 1.0 / block_dict["SCAL"].value + return { + "timestamp": block_dict["TSMP"], + key: imu_data, + } + + +def get_imu_data(pth: Path) -> np.array: + """Extract imu data from GPMF metadata from a Gopro video file + + Parameters + ---------- + path: str + Path to the GPMF file + + Returns + ------- + imu_data: list of imuData + List of imuData objects holding the imu information of the file. + """ + stream = gpmf.io.extract_gpmf_stream(pth) + imu_data_dict = {} + for key in [GYRO_KEY, ACCL_KEY]: + imu_blocks = extract_imu_blocks(stream, key=key) + imu_data = [parse_imu_block(imu_block, key=key) for imu_block in imu_blocks] + imu_data_dict[key] = np.vstack([np.array(imu[key]) for imu in imu_data]) + return imu_data_dict[GYRO_KEY], imu_data_dict[ACCL_KEY] diff --git a/src/gyraudio/properties.py b/src/gyraudio/properties.py new file mode 100644 index 0000000000000000000000000000000000000000..56b5e8608c78ce733ce99c45f99909bd8715c2ee --- /dev/null +++ b/src/gyraudio/properties.py @@ -0,0 +1,6 @@ +GYRO = "gyro" +ACCL = "accelerometer" +AUDIO = "audio" +AUDIO_RATE = "audio_rate" +GYRO_KEY = "GYRO" +ACCL_KEY = "ACCL"
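For reference, a minimal standalone sketch of the SNR convention shared by remix (interactive_audio.py) and snr (metrics.py): scaling the noise by alpha = 10 ** (-snr / 20) * ||clean|| / ||noise|| before adding it to the clean signal yields a mixture whose measured SNR is the requested value. Only torch is required; snr_db, clean, noise and target_snr below are illustrative names, not part of the repository.

import torch


def snr_db(prediction: torch.Tensor, ground_truth: torch.Tensor) -> torch.Tensor:
    # Same formula as gyraudio.audio_separation.metrics.snr (mean reduction, eps dropped for brevity)
    power_signal = torch.sum(ground_truth ** 2, dim=(-2, -1))
    power_error = torch.sum((prediction - ground_truth) ** 2, dim=(-2, -1))
    return torch.mean(10 * torch.log10(power_signal / power_error))


clean = torch.randn(1, 8000)   # hypothetical [channels, time] buffer, 1 second at 8 kHz
noise = torch.randn(1, 8000)
target_snr = -4.0              # dB, the value encoded in the mix_snr_*.wav file names
alpha = 10 ** (-target_snr / 20) * torch.norm(clean) / torch.norm(noise)
mixed = clean + alpha * noise
print(snr_db(mixed, clean))    # ~ -4.0 dB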