"""Encoder self-attention layer definition.""" import math import pdb from functools import partial import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from vita.model.multimodal_encoder.whale.utils import IGNORE_ID, strtobool try: from mamba_ssm.modules.mamba_simple import Mamba, Block from mamba_ssm.models.mixer_seq_simple import _init_weights from mamba_ssm.ops.triton.layernorm import RMSNorm except ImportError: print("Please install mamba_ssm to use MambaSSM component.") class MambaBlock(nn.Module): def __init__(self, in_channels, n_layer=1, d_state=16, d_conv=4, expand=4, bidirectional=False): super(MambaBlock, self).__init__() self.forward_blocks = nn.ModuleList([]) self.forward_norm_f = RMSNorm(in_channels, eps=1e-5) for i in range(n_layer): self.forward_blocks.append( Block( in_channels, mixer_cls=partial( Mamba, layer_idx=i, d_state=d_state, d_conv=d_conv, expand=expand ), norm_cls=partial(RMSNorm, eps=1e-5), fused_add_norm=True, residual_in_fp32=True, ) ) if bidirectional: self.backward_blocks = nn.ModuleList([]) for i in range(n_layer): self.backward_blocks.append( Block( in_channels, mixer_cls=partial( Mamba, layer_idx=i, d_state=d_state, d_conv=d_conv, expand=expand ), norm_cls=partial(RMSNorm, eps=1e-5), fused_add_norm=True, residual_in_fp32=True, ) ) self.backward_norm_f = RMSNorm(in_channels, eps=1e-5) else: self.backward_blocks = None self.apply(partial(_init_weights, n_layer=n_layer)) def forward(self, input): for_residual = None forward_f = input.clone() for block in self.forward_blocks: forward_f, for_residual = block(forward_f, for_residual, inference_params=None) residual = (forward_f + for_residual) if for_residual is not None else forward_f residual = self.forward_norm_f(residual) if self.backward_blocks is not None: back_residual = None backward_f = torch.flip(input, [1]) for block in self.backward_blocks: backward_f, back_residual = block(backward_f, back_residual, inference_params=None) back_residual = ( (backward_f + back_residual) if back_residual is not None else backward_f ) back_residual = torch.flip(back_residual, [1]) back_residual = self.backward_norm_f(back_residual) residual = torch.cat([residual, back_residual], -1) return residual class MambaSSM(torch.nn.Module): @staticmethod def add_arguments(group): """Add TDNN common arguments.""" group.add_argument( "--mamba-num-layers", default=4, type=int, help="Output dim of MambaSSM." ) group.add_argument( "--mamba-input-dim", default=256, type=int, help="Input dim of MambaSSM." ) group.add_argument( "--mamba-output-dim", default=256, type=int, help="Output dim of MambaSSM." ) group.add_argument("--mamba-d-state", default=16, type=int, help="d-state of MambaSSM.") group.add_argument("--mamba-d-conv", default=4, type=int, help="d-conv of MambaSSM.") group.add_argument("--mamba-expand", default=4, type=int, help="expand of MambaSSM.") return group def __init__(self, args): """Construct an Encoder object.""" super(MambaSSM, self).__init__() self.mamb_num_layers = args.mamba_num_layers self.mamba_input_dim = args.mamba_input_dim self.mamba_output_dim = args.mamba_output_dim self.mamba_d_state = args.mamba_d_state self.mamba_d_conv = args.mamba_d_conv self.mamba_expand = args.mamba_expand self.mamba = MambaBlock( self.mamba_input_dim, self.mamb_num_layers, self.mamba_d_state, self.mamba_d_conv, self.mamba_expand, ) @torch.jit.unused def forward(self, xs, ilens=None, masks=None): """Embed positions in tensor. 
:param torch.Tensor xs: input tensor :param torch.Tensor masks: input mask :return: position embedded tensor and mask :rtype Tuple[torch.Tensor, torch.Tensor]: """ xs_out = self.mamba(xs) return xs_out.to(xs.dtype), ilens, masks
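

if __name__ == "__main__":
    # Minimal usage sketch (illustrative addition, not part of the original module).
    # It assumes mamba_ssm is installed and a CUDA device is available, since the
    # fused Mamba kernels run on GPU; the batch size and sequence length below are
    # arbitrary example values.
    import argparse

    parser = argparse.ArgumentParser()
    MambaSSM.add_arguments(parser)
    args = parser.parse_args([])  # fall back to the defaults declared in add_arguments

    model = MambaSSM(args).cuda()
    xs = torch.randn(2, 100, args.mamba_input_dim, device="cuda")  # (batch, time, feat)
    ilens = torch.full((2,), 100, dtype=torch.long, device="cuda")

    xs_out, ilens_out, masks_out = model(xs, ilens, masks=None)
    print(xs_out.shape)  # expected: torch.Size([2, 100, 256]) with the default dims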