File size: 7,311 Bytes
51e2f90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
import torch
import torch.nn as nn
import torch.nn.functional as Func

class RMSNorm(nn.Module):
    """
    Root-mean-square normalisation over the last dimension.

    The input is L2-normalised along its last axis and multiplied by
    ``sqrt(dim)``, which turns a unit-L2-norm vector into one with unit RMS.
    A learnable per-feature gain ``gamma`` (initialised to ones) is then applied.

    Args:
    - dim (int): Size of the last dimension of the input.
    """

    def __init__(self, dim):
        super().__init__()
        # sqrt(dim): rescales the unit-norm direction to unit RMS.
        self.scale = dim ** 0.5
        # Learnable per-feature gain.
        self.gamma = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        direction = Func.normalize(x, p=2.0, dim=-1)
        rescaled = direction * self.scale
        return rescaled * self.gamma


class MambaModule(nn.Module):
    """
    Pre-norm residual wrapper around a ``mamba_ssm`` Mamba block.

    Computes ``x + Mamba(RMSNorm(x))``.

    Args:
    - d_model (int): Size of the last (feature) dimension of the input.
    - d_state (int): State size, forwarded to ``Mamba``.
    - d_conv (int): Convolution width, forwarded to ``Mamba``.
    - d_expand (int): Expansion factor, forwarded as ``Mamba``'s ``expand`` argument.

    Shapes:
    - Input / output: same shape, last dimension ``d_model``
      (presumably (batch, seq, d_model) as Mamba expects — TODO confirm).
    """

    def __init__(self, d_model, d_state, d_conv, d_expand):
        super().__init__()
        # mamba_ssm is an optional dependency, so it is imported here rather than
        # at module scope; the name must be bound in this scope to be usable.
        from mamba_ssm.modules.mamba_simple import Mamba

        self.norm = RMSNorm(dim=d_model)
        self.mamba = Mamba(
            d_model=d_model,
            d_state=d_state,
            d_conv=d_conv,
            expand=d_expand,  # Mamba's keyword is `expand`; `d_expand` is rejected
        )

    def forward(self, x):
        # Residual connection around the normalised Mamba block.
        x = x + self.mamba(self.norm(x))
        return x


class RNNModule(nn.Module):
    """
    Normalise-then-recur block: GroupNorm over channels, an LSTM over the
    sequence axis, and a linear projection back to the input width.

    Args:
    - input_dim (int): Channel count D of the input (and of the output).
    - hidden_dim (int): Hidden size of the LSTM per direction.
    - bidirectional (bool, optional): Run the LSTM in both directions. Defaults to True.

    Shapes:
    - Input: (B, T, D) — batch size, sequence length, channels.
    - Output: (B, T, D) — same as the input.

    Note: the block returns the projected LSTM output only; no residual
    connection to the input is added here.
    """

    def __init__(self, input_dim: int, hidden_dim: int, bidirectional: bool = True):
        """Build the GroupNorm, LSTM and output projection layers."""
        super().__init__()
        # One group: statistics are taken across all channels and time steps
        # of each sample.
        self.groupnorm = nn.GroupNorm(num_groups=1, num_channels=input_dim)
        self.rnn = nn.LSTM(
            input_dim, hidden_dim, batch_first=True, bidirectional=bidirectional
        )
        # A bidirectional LSTM concatenates the features of both directions.
        directions = 2 if bidirectional else 1
        self.fc = nn.Linear(hidden_dim * directions, input_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply GroupNorm -> LSTM -> linear projection.

        Args:
        - x (torch.Tensor): Tensor of shape (B, T, D).

        Returns:
        - torch.Tensor: Tensor of shape (B, T, D).
        """
        # GroupNorm expects channels on dim 1, so move D there and back.
        normed = self.groupnorm(x.transpose(1, 2)).transpose(1, 2)
        features, _ = self.rnn(normed)
        return self.fc(features)


class RFFTModule(nn.Module):
    """
    Real-valued FFT (or its inverse) along the time axis (dim 2) of a 4-D
    tensor, with real and imaginary parts packed into the last axis.

    Args:
    - inverse (bool, optional): If False, performs forward FFT. If True, performs inverse FFT. Defaults to False.

    Shapes:
    - Input: (B, F, T, D) where
        B is batch size,
        F is the number of features,
        T is the length of the transformed axis,
        D is the channel dimensionality.
    - Output:
        forward: (B, F, T // 2 + 1, D * 2). For every channel, the real and
                 imaginary parts are adjacent along the last axis.
        inverse: (B, F, time_dim, D // 2). D must be even; adjacent channel
                 pairs are read as (real, imag). If time_dim is None, the
                 length is 2 * (T - 1).

    The output has the same dtype as the input.
    """

    def __init__(self, inverse: bool = False):
        """

        Initializes RFFTModule with inverse flag.

        """
        super().__init__()
        self.inverse = inverse

    def forward(self, x: torch.Tensor, time_dim: "int | None" = None) -> torch.Tensor:
        """
        Performs forward or inverse real FFT on the input tensor x.

        Args:
        - x (torch.Tensor): Input tensor of shape (B, F, T, D).
        - time_dim (int, optional): Length of the reconstructed time axis in
          inverse mode (passed as ``n`` to ``torch.fft.irfft``). It is required
          to recover an odd original length. When None, irfft's default
          ``2 * (T - 1)`` is used. It is ignored in forward mode.

        Returns:
        - torch.Tensor: Transformed tensor (see class docstring for shapes),
          cast back to the input dtype.
        """
        dtype = x.dtype
        B, F, T, D = x.shape

        # cuFFT only supports power-of-two sizes in half precision, so
        # reduced-precision inputs are computed in float32. float32/float64 are
        # left untouched, so that double-precision inputs are not silently
        # truncated to float32.
        if dtype not in (torch.float32, torch.float64):
            x = x.float()

        if not self.inverse:
            x = torch.fft.rfft(x, dim=2)
            x = torch.view_as_real(x)
            x = x.reshape(B, F, T // 2 + 1, D * 2)
        else:
            x = x.reshape(B, F, T, D // 2, 2)
            x = torch.view_as_complex(x)
            x = torch.fft.irfft(x, n=time_dim, dim=2)

        return x.to(dtype)

    def extra_repr(self) -> str:
        """

        Returns extra representation string with module's configuration.

        """
        return f"inverse={self.inverse}"


class DualPathRNN(nn.Module):
    """
    Stack of dual-path layers. Each layer runs a sequence model along time,
    then along frequency, then a real FFT (or its inverse) along time.

    Layers alternate, counting from 1:
    - odd layers work at width ``input_dim`` and end with a forward RFFT
      (T -> T // 2 + 1 bins, D -> 2 * D channels);
    - even layers work at width ``2 * input_dim`` and end with an inverse RFFT
      back to the original time length and D channels.

    Args:
    - n_layers (int): Number of dual-path layers.
    - input_dim (int): Channel dimensionality D of the input.
    - hidden_dim (int): LSTM hidden size for odd layers (doubled for even
      layers). Unused when ``use_mamba`` is True.
    - use_mamba (bool, optional): Use ``MambaModule`` instead of ``RNNModule``
      as the sequence model. Defaults to False.
    - d_state, d_conv (int, optional): Passed to ``MambaModule``.
    - d_expand (int, optional): Passed to ``MambaModule`` (doubled for even layers).

    Shapes:
    - Input: (B, F, T, D) — batch, frequency, time, channels.
    - Output: (B, F, T, D) when ``n_layers`` is even. When ``n_layers`` is odd,
      the last layer is a forward RFFT, so the output is (B, F, T // 2 + 1, 2 * D).
    """

    def __init__(
        self,
        n_layers: int,
        input_dim: int,
        hidden_dim: int,
        use_mamba: bool = False,
        d_state: int = 16,
        d_conv: int = 4,
        d_expand: int = 2,
    ):
        """Build ``n_layers`` (time model, frequency model, RFFT) triples."""
        super().__init__()

        if use_mamba:
            # Raises ImportError early if the optional mamba_ssm package is
            # missing; the imported name itself is not used here.
            from mamba_ssm.modules.mamba_simple import Mamba  # noqa: F401
            block_cls = MambaModule
            narrow = {"d_model": input_dim, "d_state": d_state, "d_conv": d_conv, "d_expand": d_expand}
            wide = {**narrow, "d_model": input_dim * 2, "d_expand": d_expand * 2}
        else:
            block_cls = RNNModule
            narrow = {"input_dim": input_dim, "hidden_dim": hidden_dim}
            wide = {"input_dim": input_dim * 2, "hidden_dim": hidden_dim * 2}

        self.layers = nn.ModuleList()
        for idx in range(n_layers):
            # 0-based odd index == 1-based even layer: spectral width + inverse FFT.
            spectral_side = idx % 2 == 1
            cfg = wide if spectral_side else narrow
            self.layers.append(
                nn.ModuleList([
                    block_cls(**cfg),
                    block_cls(**cfg),
                    RFFTModule(inverse=spectral_side),
                ])
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Run every dual-path layer in order.

        Args:
        - x (torch.Tensor): Input tensor of shape (B, F, T, D).

        Returns:
        - torch.Tensor: (B, F, T, D) for even ``n_layers``. For odd ``n_layers``,
          (B, F, T // 2 + 1, 2 * D).
        """
        # Original time length, needed by the inverse RFFT layers.
        n_time = x.shape[2]

        for time_net, freq_net, spectral in self.layers:
            batch, n_freq, n_steps, chans = x.shape

            # Time path: fold frequency into the batch axis.
            x = time_net(x.reshape(batch * n_freq, n_steps, chans))
            x = x.reshape(batch, n_freq, n_steps, chans).permute(0, 2, 1, 3)

            # Frequency path: fold time into the batch axis.
            x = freq_net(x.reshape(batch * n_steps, n_freq, chans))
            x = x.reshape(batch, n_steps, n_freq, chans).permute(0, 2, 1, 3)

            x = spectral(x, n_time)

        return x