import torch
import torch.nn as nn
from mamba_ssm.models.mixer_seq_simple import create_block
from mamba_ssm.ops.triton.layer_norm import layer_norm_fn
from mamba_ssm.utils.generation import InferenceParams

from zonos.config import BackboneConfig


class ZonosBackbone(nn.Module):
    def __init__(self, config: BackboneConfig):
        super().__init__()
        self.config = config

        # Build the layer stack. Layers whose index is listed in
        # config.attn_layer_idx are created as attention blocks and use the
        # attention MLP width; every other layer is an SSM (Mamba) block with
        # the default intermediate width.
        self.layers = nn.ModuleList(
            [
                create_block(
                    d_model=config.d_model,
                    d_intermediate=config.d_intermediate
                    if (i not in config.attn_layer_idx)
                    else config.attn_mlp_d_intermediate,
                    ssm_cfg=config.ssm_cfg,
                    layer_idx=i,
                    attn_layer_idx=config.attn_layer_idx,
                    attn_cfg=config.attn_cfg,
                    norm_epsilon=config.norm_epsilon,
                    residual_in_fp32=config.residual_in_fp32,
                    fused_add_norm=True,
                    rms_norm=config.rms_norm,
                )
                for i in range(config.n_layer)
            ]
        )

        # Final norm. nn.LayerNorm is used here only as a parameter container;
        # the actual normalization (LayerNorm or RMSNorm) is applied by the
        # fused layer_norm_fn kernel in forward().
        self.norm_f = nn.LayerNorm(config.d_model, eps=config.norm_epsilon)

|
    def forward(self, hidden_states: torch.Tensor, inference_params: InferenceParams | None = None):
        # Each block returns (hidden_states, residual); threading the residual
        # stream through the stack lets every block fuse its add + norm.
        residual = None
        for layer in self.layers:
            hidden_states, residual = layer(hidden_states, residual, inference_params)

        # Final fused add + norm: folds the last residual addition into the
        # normalization kernel instead of materializing the sum separately.
        return layer_norm_fn(
            hidden_states,
            self.norm_f.weight,
            self.norm_f.bias,
            residual,
            eps=self.norm_f.eps,
            residual_in_fp32=self.config.residual_in_fp32,
            is_rms_norm=self.config.rms_norm,
        )
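

# Reference sketch (not part of the original module): a plain-PyTorch reading
# of what the fused layer_norm_fn call above computes, assuming its default
# prenorm=False return convention (output only, no residual returned). With
# is_rms_norm=True the kernel skips the mean subtraction (RMSNorm); this
# sketch shows the LayerNorm case.
def _unfused_final_norm_sketch(
    hidden_states: torch.Tensor,
    residual: torch.Tensor | None,
    norm_f: nn.LayerNorm,
    residual_in_fp32: bool = False,
) -> torch.Tensor:
    # Optionally keep the residual stream in fp32, as the fused path does.
    if residual_in_fp32:
        hidden_states = hidden_states.to(torch.float32)
        residual = residual.to(torch.float32) if residual is not None else None
    # Fold in the last residual addition performed by the fused kernel.
    res = hidden_states if residual is None else hidden_states + residual
    # Normalize, casting back to the parameter dtype first.
    return norm_f(res.to(norm_f.weight.dtype))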
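

# Usage sketch (illustrative, not the shipped Zonos setup): instantiating the
# backbone and running a prefill pass followed by one decode step. Every
# BackboneConfig value below is an assumption chosen to exercise the fields
# this module references; the real defaults may differ. The mamba_ssm kernels
# are Triton-based, so a CUDA device is assumed.
if __name__ == "__main__":
    config = BackboneConfig(
        d_model=1024,
        d_intermediate=0,              # no MLP inside pure-SSM blocks
        attn_mlp_d_intermediate=4096,  # MLP width inside attention blocks
        n_layer=8,
        ssm_cfg={"layer": "Mamba2"},
        attn_layer_idx=[3, 7],         # these two layers become attention blocks
        attn_cfg={"num_heads": 16},
        rms_norm=True,
        residual_in_fp32=True,
        norm_epsilon=1e-5,
    )
    model = ZonosBackbone(config).to("cuda", dtype=torch.bfloat16)

    # Prefill: run the whole prompt once; each layer caches its attention
    # KV / SSM state in inference_params.
    batch, prompt_len = 1, 128
    params = InferenceParams(max_seqlen=4096, max_batch_size=batch)
    x = torch.randn(batch, prompt_len, config.d_model, device="cuda", dtype=torch.bfloat16)
    y = model(x, inference_params=params)  # (batch, prompt_len, d_model)

    # Decode: advance the cache offset and feed one position at a time.
    params.seqlen_offset += prompt_len
    step = torch.randn(batch, 1, config.d_model, device="cuda", dtype=torch.bfloat16)
    y_step = model(step, inference_params=params)  # (batch, 1, d_model)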