"""Encoder self-attention layer definition."""
import math
from functools import partial
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from vita.model.multimodal_encoder.whale.utils import IGNORE_ID, strtobool
try:
from mamba_ssm.modules.mamba_simple import Mamba, Block
from mamba_ssm.models.mixer_seq_simple import _init_weights
from mamba_ssm.ops.triton.layernorm import RMSNorm
except ImportError:
print("Please install mamba_ssm to use MambaSSM component.")
class MambaBlock(nn.Module):
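    """Stack of Mamba blocks with a final RMSNorm, optionally bidirectional.

    With ``bidirectional=True`` a second stack runs over the time-reversed
    input and its output is concatenated with the forward output along the
    channel dimension, doubling the output size.
    """
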
def __init__(self, in_channels, n_layer=1, d_state=16, d_conv=4, expand=4, bidirectional=False):
super(MambaBlock, self).__init__()
self.forward_blocks = nn.ModuleList([])
self.forward_norm_f = RMSNorm(in_channels, eps=1e-5)
for i in range(n_layer):
self.forward_blocks.append(
Block(
in_channels,
mixer_cls=partial(
Mamba, layer_idx=i, d_state=d_state, d_conv=d_conv, expand=expand
),
norm_cls=partial(RMSNorm, eps=1e-5),
fused_add_norm=True,
residual_in_fp32=True,
)
)
if bidirectional:
self.backward_blocks = nn.ModuleList([])
for i in range(n_layer):
self.backward_blocks.append(
Block(
in_channels,
mixer_cls=partial(
Mamba, layer_idx=i, d_state=d_state, d_conv=d_conv, expand=expand
),
norm_cls=partial(RMSNorm, eps=1e-5),
fused_add_norm=True,
residual_in_fp32=True,
)
)
self.backward_norm_f = RMSNorm(in_channels, eps=1e-5)
else:
self.backward_blocks = None
self.apply(partial(_init_weights, n_layer=n_layer))
def forward(self, input):
for_residual = None
forward_f = input.clone()
for block in self.forward_blocks:
forward_f, for_residual = block(forward_f, for_residual, inference_params=None)
residual = (forward_f + for_residual) if for_residual is not None else forward_f
residual = self.forward_norm_f(residual)
if self.backward_blocks is not None:
back_residual = None
backward_f = torch.flip(input, [1])
for block in self.backward_blocks:
backward_f, back_residual = block(backward_f, back_residual, inference_params=None)
back_residual = (
(backward_f + back_residual) if back_residual is not None else backward_f
)
back_residual = torch.flip(back_residual, [1])
back_residual = self.backward_norm_f(back_residual)
residual = torch.cat([residual, back_residual], -1)
return residual
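

# Example usage (sketch; the shapes and device below are assumptions, not
# values taken from this repo's configs). MambaBlock expects input of shape
# (batch, time, in_channels); with bidirectional=True the output channel
# dimension doubles because the forward and backward outputs are concatenated.
# The mamba_ssm kernels require a CUDA device.
#
#   block = MambaBlock(in_channels=256, n_layer=2, bidirectional=True).cuda()
#   x = torch.randn(8, 100, 256, device="cuda")
#   y = block(x)  # y.shape == (8, 100, 512)
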
class MambaSSM(torch.nn.Module):
@staticmethod
def add_arguments(group):
"""Add TDNN common arguments."""
        group.add_argument(
            "--mamba-num-layers", default=4, type=int, help="Number of Mamba layers in MambaSSM."
        )
group.add_argument(
"--mamba-input-dim", default=256, type=int, help="Input dim of MambaSSM."
)
group.add_argument(
"--mamba-output-dim", default=256, type=int, help="Output dim of MambaSSM."
)
group.add_argument("--mamba-d-state", default=16, type=int, help="d-state of MambaSSM.")
group.add_argument("--mamba-d-conv", default=4, type=int, help="d-conv of MambaSSM.")
group.add_argument("--mamba-expand", default=4, type=int, help="expand of MambaSSM.")
return group
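
    # Example (sketch, hypothetical parser wiring): the flags registered above
    # map onto the attributes read in __init__ below, with dashes converted to
    # underscores by argparse.
    #
    #   import argparse
    #   parser = argparse.ArgumentParser()
    #   MambaSSM.add_arguments(parser.add_argument_group("mamba"))
    #   args = parser.parse_args([])          # all defaults
    #   encoder = MambaSSM(args)
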
def __init__(self, args):
"""Construct an Encoder object."""
super(MambaSSM, self).__init__()
        self.mamba_num_layers = args.mamba_num_layers
self.mamba_input_dim = args.mamba_input_dim
self.mamba_output_dim = args.mamba_output_dim
self.mamba_d_state = args.mamba_d_state
self.mamba_d_conv = args.mamba_d_conv
self.mamba_expand = args.mamba_expand
self.mamba = MambaBlock(
self.mamba_input_dim,
            self.mamba_num_layers,
self.mamba_d_state,
self.mamba_d_conv,
self.mamba_expand,
)
@torch.jit.unused
def forward(self, xs, ilens=None, masks=None):
"""Embed positions in tensor.
:param torch.Tensor xs: input tensor
:param torch.Tensor masks: input mask
:return: position embedded tensor and mask
:rtype Tuple[torch.Tensor, torch.Tensor]:
"""
xs_out = self.mamba(xs)
return xs_out.to(xs.dtype), ilens, masks
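

if __name__ == "__main__":
    # Minimal smoke test (sketch): builds the encoder from its own CLI
    # arguments with default values and runs a random batch through it.
    # Assumes mamba_ssm is installed and a CUDA device is available; the
    # batch size and sequence length here are arbitrary.
    import argparse

    parser = argparse.ArgumentParser()
    MambaSSM.add_arguments(parser.add_argument_group("mamba"))
    args = parser.parse_args([])
    encoder = MambaSSM(args).cuda()
    xs = torch.randn(2, 50, args.mamba_input_dim, device="cuda")
    xs_out, ilens, masks = encoder(xs)
    print(xs_out.shape)  # expected: torch.Size([2, 50, 256])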