"""MLP with convolutional gating (cgMLP) definition.
References:
https://openreview.net/forum?id=RA-zVvZLYIy
https://arxiv.org/abs/2105.08050
"""
import torch
from funasr_detach.models.transformer.utils.nets_utils import get_activation
from funasr_detach.models.transformer.layer_norm import LayerNorm
class ConvolutionalSpatialGatingUnit(torch.nn.Module):
"""Convolutional Spatial Gating Unit (CSGU)."""
    def __init__(
        self,
        size: int,
        kernel_size: int,
        dropout_rate: float,
        use_linear_after_conv: bool,
        gate_activation: str,
    ):
        super().__init__()

        # Split input channels: half carry content (x_r), half form the gate (x_g).
        n_channels = size // 2
        self.norm = LayerNorm(n_channels)
        # Depthwise 1-D convolution over time (groups=n_channels); padding of
        # (kernel_size - 1) // 2 preserves the sequence length for odd kernels.
        self.conv = torch.nn.Conv1d(
            n_channels,
            n_channels,
            kernel_size,
            1,
            (kernel_size - 1) // 2,
            groups=n_channels,
        )
        if use_linear_after_conv:
            self.linear = torch.nn.Linear(n_channels, n_channels)
        else:
            self.linear = None

        if gate_activation == "identity":
            self.act = torch.nn.Identity()
        else:
            self.act = get_activation(gate_activation)

        self.dropout = torch.nn.Dropout(dropout_rate)

    def espnet_initialization_fn(self):
        # Near-zero conv weights with unit biases make the gate start close to
        # identity, the initialization recommended in the gMLP paper.
        torch.nn.init.normal_(self.conv.weight, std=1e-6)
        torch.nn.init.ones_(self.conv.bias)
        if self.linear is not None:
            torch.nn.init.normal_(self.linear.weight, std=1e-6)
            torch.nn.init.ones_(self.linear.bias)

    def forward(self, x, gate_add=None):
        """Forward method.

        Args:
            x (torch.Tensor): (N, T, D)
            gate_add (torch.Tensor): (N, T, D/2)

        Returns:
            out (torch.Tensor): (N, T, D/2)
        """
        x_r, x_g = x.chunk(2, dim=-1)

        x_g = self.norm(x_g)  # (N, T, D/2)
        x_g = self.conv(x_g.transpose(1, 2)).transpose(1, 2)  # (N, T, D/2)
        if self.linear is not None:
            x_g = self.linear(x_g)

        if gate_add is not None:
            x_g = x_g + gate_add

        x_g = self.act(x_g)
        out = x_r * x_g  # (N, T, D/2)
        out = self.dropout(out)
        return out
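
# Illustrative shape walk-through for the CSGU above (values assumed, not
# from the original file): with size=1024, an input x of (N, T, 1024) is
# chunked into x_r and x_g, each (N, T, 512). The depthwise Conv1d mixes
# x_g along the time axis T only (no channel mixing), and the elementwise
# product x_r * x_g yields the (N, T, 512) output.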


class ConvolutionalGatingMLP(torch.nn.Module):
    """Convolutional Gating MLP (cgMLP)."""

    def __init__(
        self,
        size: int,
        linear_units: int,
        kernel_size: int,
        dropout_rate: float,
        use_linear_after_conv: bool,
        gate_activation: str,
    ):
        super().__init__()

        self.channel_proj1 = torch.nn.Sequential(
            torch.nn.Linear(size, linear_units), torch.nn.GELU()
        )
        self.csgu = ConvolutionalSpatialGatingUnit(
            size=linear_units,
            kernel_size=kernel_size,
            dropout_rate=dropout_rate,
            use_linear_after_conv=use_linear_after_conv,
            gate_activation=gate_activation,
        )
        self.channel_proj2 = torch.nn.Linear(linear_units // 2, size)

    def forward(self, x, mask):
        # `mask` is accepted for interface compatibility with other encoder
        # modules but is unused here.
        if isinstance(x, tuple):
            xs_pad, pos_emb = x
        else:
            xs_pad, pos_emb = x, None

        xs_pad = self.channel_proj1(xs_pad)  # size -> linear_units
        xs_pad = self.csgu(xs_pad)  # linear_units -> linear_units/2
        xs_pad = self.channel_proj2(xs_pad)  # linear_units/2 -> size

        if pos_emb is not None:
            out = (xs_pad, pos_emb)
        else:
            out = xs_pad
        return out
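

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). The hyperparameters
# below (size=256, linear_units=1024, kernel_size=31) are assumptions chosen
# for illustration; any even size/linear_units and odd kernel_size work.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)

    # CSGU alone halves the channel dimension: (N, T, D) -> (N, T, D/2).
    csgu = ConvolutionalSpatialGatingUnit(
        size=1024,
        kernel_size=31,
        dropout_rate=0.0,
        use_linear_after_conv=False,
        gate_activation="identity",
    )
    csgu.espnet_initialization_fn()
    assert csgu(torch.randn(2, 100, 1024)).shape == (2, 100, 512)

    # The full cgMLP block restores the model dimension via channel_proj2.
    mlp = ConvolutionalGatingMLP(
        size=256,
        linear_units=1024,
        kernel_size=31,
        dropout_rate=0.0,
        use_linear_after_conv=False,
        gate_activation="identity",
    )
    assert mlp(torch.randn(2, 100, 256), mask=None).shape == (2, 100, 256)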