"""Encoder self-attention layer definition.""" | |
import math | |
import pdb | |
from functools import partial | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from vita.model.multimodal_encoder.whale.utils import IGNORE_ID, strtobool | |

try:
    from mamba_ssm.modules.mamba_simple import Mamba, Block
    from mamba_ssm.models.mixer_seq_simple import _init_weights
    from mamba_ssm.ops.triton.layernorm import RMSNorm
except ImportError:
    print("Please install mamba_ssm to use the MambaSSM component.")


class MambaBlock(nn.Module):
    """Stack of Mamba SSM blocks with RMSNorm, optionally bidirectional."""

    def __init__(self, in_channels, n_layer=1, d_state=16, d_conv=4, expand=4, bidirectional=False):
        super(MambaBlock, self).__init__()
        # Forward-direction Mamba blocks, normalized by a final RMSNorm.
        self.forward_blocks = nn.ModuleList([])
        self.forward_norm_f = RMSNorm(in_channels, eps=1e-5)
        for i in range(n_layer):
            self.forward_blocks.append(
                Block(
                    in_channels,
                    mixer_cls=partial(
                        Mamba, layer_idx=i, d_state=d_state, d_conv=d_conv, expand=expand
                    ),
                    norm_cls=partial(RMSNorm, eps=1e-5),
                    fused_add_norm=True,
                    residual_in_fp32=True,
                )
            )
        if bidirectional:
            # Backward-direction blocks process the time-reversed sequence.
            self.backward_blocks = nn.ModuleList([])
            for i in range(n_layer):
                self.backward_blocks.append(
                    Block(
                        in_channels,
                        mixer_cls=partial(
                            Mamba, layer_idx=i, d_state=d_state, d_conv=d_conv, expand=expand
                        ),
                        norm_cls=partial(RMSNorm, eps=1e-5),
                        fused_add_norm=True,
                        residual_in_fp32=True,
                    )
                )
            self.backward_norm_f = RMSNorm(in_channels, eps=1e-5)
        else:
            self.backward_blocks = None

        self.apply(partial(_init_weights, n_layer=n_layer))

    def forward(self, input):
        # Run the forward-direction blocks, carrying the residual stream alongside
        # the hidden states as required by mamba_ssm's fused add+norm Blocks.
        for_residual = None
        forward_f = input.clone()
        for block in self.forward_blocks:
            forward_f, for_residual = block(forward_f, for_residual, inference_params=None)
        residual = (forward_f + for_residual) if for_residual is not None else forward_f
        residual = self.forward_norm_f(residual)

        if self.backward_blocks is not None:
            # Process the time-reversed sequence, flip the result back, and
            # concatenate with the forward output along the channel dimension.
            back_residual = None
            backward_f = torch.flip(input, [1])
            for block in self.backward_blocks:
                backward_f, back_residual = block(backward_f, back_residual, inference_params=None)
            back_residual = (
                (backward_f + back_residual) if back_residual is not None else backward_f
            )
            back_residual = torch.flip(back_residual, [1])
            back_residual = self.backward_norm_f(back_residual)
            residual = torch.cat([residual, back_residual], -1)

        return residual
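

def _mamba_block_smoke_test():
    # Usage sketch (not part of the original module), assuming mamba_ssm is
    # installed and a CUDA device is available, since the fused Mamba/Triton
    # kernels run on GPU. Shapes are illustrative: (batch, time, channels).
    block = MambaBlock(in_channels=256, n_layer=2, bidirectional=True).cuda()
    x = torch.randn(8, 100, 256, device="cuda")
    y = block(x)
    # bidirectional=True concatenates forward and backward outputs along the
    # channel dimension, so 256 -> 512.
    assert y.shape == (8, 100, 512)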


class MambaSSM(torch.nn.Module):
    """Mamba SSM encoder module."""

    @staticmethod
    def add_arguments(group):
        """Add MambaSSM common arguments."""
        group.add_argument(
            "--mamba-num-layers", default=4, type=int, help="Number of Mamba layers in MambaSSM."
        )
        group.add_argument(
            "--mamba-input-dim", default=256, type=int, help="Input dim of MambaSSM."
        )
        group.add_argument(
            "--mamba-output-dim", default=256, type=int, help="Output dim of MambaSSM."
        )
        group.add_argument("--mamba-d-state", default=16, type=int, help="d-state of MambaSSM.")
        group.add_argument("--mamba-d-conv", default=4, type=int, help="d-conv of MambaSSM.")
        group.add_argument("--mamba-expand", default=4, type=int, help="expand of MambaSSM.")
        return group

    def __init__(self, args):
        """Construct a MambaSSM encoder object."""
        super(MambaSSM, self).__init__()
        self.mamba_num_layers = args.mamba_num_layers
        self.mamba_input_dim = args.mamba_input_dim
        self.mamba_output_dim = args.mamba_output_dim
        self.mamba_d_state = args.mamba_d_state
        self.mamba_d_conv = args.mamba_d_conv
        self.mamba_expand = args.mamba_expand
        self.mamba = MambaBlock(
            self.mamba_input_dim,
            self.mamba_num_layers,
            self.mamba_d_state,
            self.mamba_d_conv,
            self.mamba_expand,
        )

    def forward(self, xs, ilens=None, masks=None):
        """Apply the stacked Mamba SSM blocks to the input sequence.

        :param torch.Tensor xs: input tensor (batch, time, dim)
        :param torch.Tensor ilens: input lengths, passed through unchanged
        :param torch.Tensor masks: input mask, passed through unchanged
        :return: encoded tensor, input lengths and mask
        :rtype: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
        """
        xs_out = self.mamba(xs)
        return xs_out.to(xs.dtype), ilens, masks
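

if __name__ == "__main__":
    # Minimal usage sketch (not from the original module), assuming mamba_ssm is
    # installed and a CUDA device is available. The argparse wiring mirrors
    # add_arguments(); its defaults are used as-is.
    import argparse

    parser = argparse.ArgumentParser()
    MambaSSM.add_arguments(parser.add_argument_group("mamba"))
    args = parser.parse_args([])

    model = MambaSSM(args).cuda()
    xs = torch.randn(2, 50, args.mamba_input_dim, device="cuda")
    xs_out, ilens, masks = model(xs)
    print(xs_out.shape)  # expected: torch.Size([2, 50, 256])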