#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
Local (sliding-window) SNR targets computed from complex spectrograms.

Adapted from:
https://github.com/Rikorose/DeepFilterNet/blob/main/DeepFilterNet/df/modules.py#L816
"""
from typing import Tuple

import torch
import torch.nn as nn
from torch.nn import functional as F


def local_energy(spec: torch.Tensor, n_frame: int, device: torch.device) -> torch.Tensor:
    """
    Hann-weighted energy of a spectrogram in a sliding window over the time axis.

    :param spec: real view of a spectrogram, shape: [b, c, t, f, 2]
    :param n_frame: window length in frames; an even value is increased by one
        so that the window has a centre frame.
    :param device: device on which the window weights are created.
    :return: local energy, shape: [b, c, t]
    """
    if n_frame % 2 == 0:
        n_frame += 1
    n_frame_half = n_frame // 2

    # Per-frame energy: sum of squares over the last two axes (re/im, then frequency).
    energy = spec.pow(2).sum(-1).sum(-1)
    # energy shape: [b, c, t]

    # Zero-pad the time axis on both sides so that unfold() yields one window per frame.
    energy = F.pad(energy, (n_frame_half, n_frame_half, 0, 0))
    # energy shape: [b, c, t + 2 * n_frame_half]

    weight = torch.hann_window(n_frame, device=device, dtype=energy.dtype)
    # weight shape: [n_frame]

    windows = energy.unfold(-1, size=n_frame, step=1) * weight
    # windows shape: [b, c, t, n_frame]

    result = torch.sum(windows, dim=-1).div(n_frame)
    # result shape: [b, c, t]
    return result


def local_snr(spec_clean: torch.Tensor,
              spec_noise: torch.Tensor,
              n_frame: int = 5,
              db: bool = False,
              eps: float = 1e-12,
              ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Ratio of the local energy of a clean spectrogram to that of a noise spectrogram.

    :param spec_clean: torch.complex, shape: [b, c, t, f]
    :param spec_noise: torch.complex, shape: [b, c, t, f]
    :param n_frame: window length in frames, forwarded to ``local_energy``.
    :param db: if True, return the ratio as ``10 * log10(ratio)``.
    :param eps: lower bound applied to the noise energy before dividing, and to
        the ratio before taking the logarithm.
    :return: tuple ``(snr, energy_clean, energy_noise)``, each of shape [b, c, t]
    """
    # [b, c, t, f]
    spec_clean = torch.view_as_real(spec_clean)
    spec_noise = torch.view_as_real(spec_noise)
    # [b, c, t, f, 2]

    energy_clean = local_energy(spec_clean, n_frame=n_frame, device=spec_clean.device)
    energy_noise = local_energy(spec_noise, n_frame=n_frame, device=spec_noise.device)
    # [b, c, t]

    # Clamp the denominator so that silent noise frames do not divide by zero.
    snr = energy_clean / energy_noise.clamp_min(eps)
    # snr shape: [b, c, t]

    if db:
        # Clamp again so that log10 never sees zero.
        snr = snr.clamp_min(eps).log10().mul(10)
    return snr, energy_clean, energy_noise


class LocalSnrTarget(nn.Module):
    """
    Module wrapper around ``local_snr`` that clamps the result to a fixed range.

    ``sample_rate``, ``nfft``, ``win_size`` and ``hop_size`` are only stored on
    the instance; ``forward`` does not read them.
    """

    def __init__(self,
                 sample_rate: int = 8000,
                 nfft: int = 512,
                 win_size: int = 512,
                 hop_size: int = 256,
                 n_frame: int = 3,
                 min_local_snr: int = -15,
                 max_local_snr: int = 30,
                 db: bool = True,
                 ):
        super().__init__()
        self.sample_rate = sample_rate
        self.nfft = nfft
        self.win_size = win_size
        self.hop_size = hop_size

        self.n_frame = n_frame
        self.min_local_snr = min_local_snr
        self.max_local_snr = max_local_snr
        self.db = db

    def forward(self,
                spec_clean: torch.Tensor,
                spec_noise: torch.Tensor,
                ) -> torch.Tensor:
        """
        :param spec_clean: torch.complex, shape: [b, c, t, f]
        :param spec_noise: torch.complex, shape: [b, c, t, f]
        :return: lsnr clamped to [min_local_snr, max_local_snr]; shape: [b, t]
            when c == 1. ``squeeze(1)`` leaves the channel axis in place when
            c > 1, in which case the shape is [b, c, t].
        """
        lsnr, _, _ = local_snr(
            spec_clean=spec_clean,
            spec_noise=spec_noise,
            n_frame=self.n_frame,
            db=self.db,
        )
        # lsnr shape: [b, c, t]

        lsnr = lsnr.clamp(self.min_local_snr, self.max_local_snr).squeeze(1)
        # lsnr shape: [b, t]
        return lsnr


def main():
    """Smoke test: compute the local SNR target of a random signal against itself."""
    # torchaudio is needed only by this demo, so it is imported here; the
    # functions and the module above then depend on torch alone.
    import torchaudio

    sample_rate = 8000
    nfft = 512
    win_size = 512
    hop_size = 256
    window_fn = "hamming"

    transform = torchaudio.transforms.Spectrogram(
        n_fft=nfft,
        win_length=win_size,
        hop_length=hop_size,
        power=None,
        window_fn=torch.hamming_window if window_fn == "hamming" else torch.hann_window,
    )

    noisy = torch.randn(size=(1, 16000), dtype=torch.float32)

    # Call the modules rather than .forward() so that nn.Module hooks are run.
    spec = transform(noisy)
    spec = spec.permute(0, 2, 1)
    spec = torch.unsqueeze(spec, dim=1)
    print(f"spec.shape: {spec.shape}, spec.dtype: {spec.dtype}")
    # [b, c, t, f]

    local = LocalSnrTarget(
        sample_rate=sample_rate,
        nfft=nfft,
        win_size=win_size,
        hop_size=hop_size,
        n_frame=5,
        min_local_snr=-15,
        max_local_snr=30,
        db=True,
    )
    lsnr_target = local(spec, spec)
    print(f"lsnr_target.shape: {lsnr_target.shape}")


if __name__ == "__main__":
    main()