import torch
import torch.nn as nn
from einops import rearrange
from .mamba_block import TFMambaBlock
from .codec_module import DenseEncoder, MagDecoder, PhaseDecoder


class SEMamba(nn.Module):
    """
    Speech-enhancement network built from a dense encoder, a stack of
    TFMambaBlock layers, and two decoding heads (magnitude mask and phase).

    The network consumes a noisy magnitude/phase pair and produces the
    enhanced magnitude, the enhanced phase, and their real/imaginary
    combination.
    """

    def __init__(self, cfg):
        """
        Build the encoder, the Mamba block stack, and both decoders.

        Args:
        - cfg: Configuration object. It is indexed as
          cfg['model_cfg']['num_tfmamba'] here and is also handed unchanged
          to every sub-module constructor.
        """
        super().__init__()
        self.cfg = cfg

        # A value of None in the config selects the default depth of 4 blocks.
        block_count = cfg['model_cfg']['num_tfmamba']
        if block_count is None:
            block_count = 4  # default tfmamba: 4
        self.num_tscblocks = block_count

        # Sub-modules are registered in a fixed order (encoder, blocks,
        # magnitude head, phase head) so parameter/state_dict ordering is stable.
        self.dense_encoder = DenseEncoder(cfg)

        mamba_stack = nn.ModuleList()
        for _ in range(self.num_tscblocks):
            mamba_stack.append(TFMambaBlock(cfg))
        self.TSMamba = mamba_stack

        self.mask_decoder = MagDecoder(cfg)
        self.phase_decoder = PhaseDecoder(cfg)

    def forward(self, noisy_mag, noisy_pha):
        """
        Enhance a noisy magnitude/phase spectrogram pair.

        Args:
        - noisy_mag (torch.Tensor): Noisy magnitude input tensor [B, F, T].
        - noisy_pha (torch.Tensor): Noisy phase input tensor [B, F, T].

        Returns:
        - denoised_mag (torch.Tensor): Denoised magnitude tensor [B, F, T].
        - denoised_pha (torch.Tensor): Denoised phase tensor [B, F, T].
        - denoised_com (torch.Tensor): Denoised complex tensor [B, F, T, 2],
          with the real part at index 0 and the imaginary part at index 1
          of the last axis.
        """
        # Move time ahead of frequency and add a channel axis: [B, 1, T, F].
        mag_in = rearrange(noisy_mag, 'b f t -> b t f').unsqueeze(1)
        pha_in = rearrange(noisy_pha, 'b f t -> b t f').unsqueeze(1)

        # Magnitude and phase become the two input channels: [B, 2, T, F].
        features = self.dense_encoder(torch.cat((mag_in, pha_in), dim=1))

        # Run the encoded features through every Mamba block in sequence.
        for mamba_block in self.TSMamba:
            features = mamba_block(features)

        # The magnitude head output multiplies the noisy magnitude (a mask);
        # the phase head output is used directly.
        # The trailing squeeze drops the channel axis -- assumes both decoders
        # emit a single channel, as the documented [B, F, T] outputs imply.
        masked_mag = self.mask_decoder(features) * mag_in
        denoised_mag = rearrange(masked_mag, 'b c t f -> b f t c').squeeze(-1)

        phase_out = self.phase_decoder(features)
        denoised_pha = rearrange(phase_out, 'b c t f -> b f t c').squeeze(-1)

        # Polar -> rectangular: stack (mag * cos, mag * sin) on a new last axis.
        real_part = denoised_mag * torch.cos(denoised_pha)
        imag_part = denoised_mag * torch.sin(denoised_pha)
        denoised_com = torch.stack((real_part, imag_part), dim=-1)

        return denoised_mag, denoised_pha, denoised_com