|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as Func
|
|
|
|
class RMSNorm(nn.Module):
|
|
def __init__(self, dim):
|
|
super().__init__()
|
|
self.scale = dim ** 0.5
|
|
self.gamma = nn.Parameter(torch.ones(dim))
|
|
|
|
def forward(self, x):
|
|
return Func.normalize(x, dim=-1) * self.scale * self.gamma
|
|
|
|
|
|
class MambaModule(nn.Module):
|
|
def __init__(self, d_model, d_state, d_conv, d_expand):
|
|
super().__init__()
|
|
self.norm = RMSNorm(dim=d_model)
|
|
self.mamba = Mamba(
|
|
d_model=d_model,
|
|
d_state=d_state,
|
|
d_conv=d_conv,
|
|
d_expand=d_expand
|
|
)
|
|
|
|
def forward(self, x):
|
|
x = x + self.mamba(self.norm(x))
|
|
return x
|
|
|
|
|
|
class RNNModule(nn.Module):
|
|
"""
|
|
RNNModule class implements a recurrent neural network module with LSTM cells.
|
|
|
|
Args:
|
|
- input_dim (int): Dimensionality of the input features.
|
|
- hidden_dim (int): Dimensionality of the hidden state of the LSTM.
|
|
- bidirectional (bool, optional): If True, uses bidirectional LSTM. Defaults to True.
|
|
|
|
Shapes:
|
|
- Input: (B, T, D) where
|
|
B is batch size,
|
|
T is sequence length,
|
|
D is input dimensionality.
|
|
- Output: (B, T, D) where
|
|
B is batch size,
|
|
T is sequence length,
|
|
D is input dimensionality.
|
|
"""
|
|
|
|
def __init__(self, input_dim: int, hidden_dim: int, bidirectional: bool = True):
|
|
"""
|
|
Initializes RNNModule with input dimension, hidden dimension, and bidirectional flag.
|
|
"""
|
|
super().__init__()
|
|
self.groupnorm = nn.GroupNorm(num_groups=1, num_channels=input_dim)
|
|
self.rnn = nn.LSTM(
|
|
input_dim, hidden_dim, batch_first=True, bidirectional=bidirectional
|
|
)
|
|
self.fc = nn.Linear(hidden_dim * 2 if bidirectional else hidden_dim, input_dim)
|
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
"""
|
|
Performs forward pass through the RNNModule.
|
|
|
|
Args:
|
|
- x (torch.Tensor): Input tensor of shape (B, T, D).
|
|
|
|
Returns:
|
|
- torch.Tensor: Output tensor of shape (B, T, D).
|
|
"""
|
|
x = x.transpose(1, 2)
|
|
x = self.groupnorm(x)
|
|
x = x.transpose(1, 2)
|
|
|
|
x, (hidden, _) = self.rnn(x)
|
|
x = self.fc(x)
|
|
return x
|
|
|
|
|
|
class RFFTModule(nn.Module):
|
|
"""
|
|
RFFTModule class implements a module for performing real-valued Fast Fourier Transform (FFT)
|
|
or its inverse on input tensors.
|
|
|
|
Args:
|
|
- inverse (bool, optional): If False, performs forward FFT. If True, performs inverse FFT. Defaults to False.
|
|
|
|
Shapes:
|
|
- Input: (B, F, T, D) where
|
|
B is batch size,
|
|
F is the number of features,
|
|
T is sequence length,
|
|
D is input dimensionality.
|
|
- Output: (B, F, T // 2 + 1, D * 2) if performing forward FFT.
|
|
(B, F, T, D // 2, 2) if performing inverse FFT.
|
|
"""
|
|
|
|
def __init__(self, inverse: bool = False):
|
|
"""
|
|
Initializes RFFTModule with inverse flag.
|
|
"""
|
|
super().__init__()
|
|
self.inverse = inverse
|
|
|
|
def forward(self, x: torch.Tensor, time_dim: int) -> torch.Tensor:
|
|
"""
|
|
Performs forward or inverse FFT on the input tensor x.
|
|
|
|
Args:
|
|
- x (torch.Tensor): Input tensor of shape (B, F, T, D).
|
|
- time_dim (int): Input size of time dimension.
|
|
|
|
Returns:
|
|
- torch.Tensor: Output tensor after FFT or its inverse operation.
|
|
"""
|
|
dtype = x.dtype
|
|
B, F, T, D = x.shape
|
|
|
|
|
|
x = x.float()
|
|
|
|
if not self.inverse:
|
|
x = torch.fft.rfft(x, dim=2)
|
|
x = torch.view_as_real(x)
|
|
x = x.reshape(B, F, T // 2 + 1, D * 2)
|
|
else:
|
|
x = x.reshape(B, F, T, D // 2, 2)
|
|
x = torch.view_as_complex(x)
|
|
x = torch.fft.irfft(x, n=time_dim, dim=2)
|
|
|
|
x = x.to(dtype)
|
|
return x
|
|
|
|
def extra_repr(self) -> str:
|
|
"""
|
|
Returns extra representation string with module's configuration.
|
|
"""
|
|
return f"inverse={self.inverse}"
|
|
|
|
|
|
class DualPathRNN(nn.Module):
|
|
"""
|
|
DualPathRNN class implements a neural network with alternating layers of RNNModule and RFFTModule.
|
|
|
|
Args:
|
|
- n_layers (int): Number of layers in the network.
|
|
- input_dim (int): Dimensionality of the input features.
|
|
- hidden_dim (int): Dimensionality of the hidden state of the RNNModule.
|
|
|
|
Shapes:
|
|
- Input: (B, F, T, D) where
|
|
B is batch size,
|
|
F is the number of features (frequency dimension),
|
|
T is sequence length (time dimension),
|
|
D is input dimensionality (channel dimension).
|
|
- Output: (B, F, T, D) where
|
|
B is batch size,
|
|
F is the number of features (frequency dimension),
|
|
T is sequence length (time dimension),
|
|
D is input dimensionality (channel dimension).
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
n_layers: int,
|
|
input_dim: int,
|
|
hidden_dim: int,
|
|
|
|
use_mamba: bool = False,
|
|
d_state: int = 16,
|
|
d_conv: int = 4,
|
|
d_expand: int = 2
|
|
):
|
|
"""
|
|
Initializes DualPathRNN with the specified number of layers, input dimension, and hidden dimension.
|
|
"""
|
|
super().__init__()
|
|
|
|
if use_mamba:
|
|
from mamba_ssm.modules.mamba_simple import Mamba
|
|
net = MambaModule
|
|
dkwargs = {"d_model": input_dim, "d_state": d_state, "d_conv": d_conv, "d_expand": d_expand}
|
|
ukwargs = {"d_model": input_dim * 2, "d_state": d_state, "d_conv": d_conv, "d_expand": d_expand * 2}
|
|
else:
|
|
net = RNNModule
|
|
dkwargs = {"input_dim": input_dim, "hidden_dim": hidden_dim}
|
|
ukwargs = {"input_dim": input_dim * 2, "hidden_dim": hidden_dim * 2}
|
|
|
|
self.layers = nn.ModuleList()
|
|
for i in range(1, n_layers + 1):
|
|
kwargs = dkwargs if i % 2 == 1 else ukwargs
|
|
layer = nn.ModuleList([
|
|
net(**kwargs),
|
|
net(**kwargs),
|
|
RFFTModule(inverse=(i % 2 == 0)),
|
|
])
|
|
self.layers.append(layer)
|
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
"""
|
|
Performs forward pass through the DualPathRNN.
|
|
|
|
Args:
|
|
- x (torch.Tensor): Input tensor of shape (B, F, T, D).
|
|
|
|
Returns:
|
|
- torch.Tensor: Output tensor of shape (B, F, T, D).
|
|
"""
|
|
|
|
time_dim = x.shape[2]
|
|
|
|
for time_layer, freq_layer, rfft_layer in self.layers:
|
|
B, F, T, D = x.shape
|
|
|
|
x = x.reshape((B * F), T, D)
|
|
x = time_layer(x)
|
|
x = x.reshape(B, F, T, D)
|
|
x = x.permute(0, 2, 1, 3)
|
|
|
|
x = x.reshape((B * T), F, D)
|
|
x = freq_layer(x)
|
|
x = x.reshape(B, T, F, D)
|
|
x = x.permute(0, 2, 1, 3)
|
|
|
|
x = rfft_layer(x, time_dim)
|
|
|
|
return x
|
|
|