# orthrus/src/model.py

import json
import math
import os
from functools import partial

import numpy as np
import torch
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin
from mamba_ssm.modules.mamba_simple import Mamba, Block


def seq_to_oh(seq):
    """Convert a nucleotide string to a one-hot (L x 4) array.

    Bases are encoded in A, C, G, T order; any other character
    (e.g. 'N') maps to an all-zero row.
    """
    oh = np.zeros((len(seq), 4), dtype=int)
    for i, base in enumerate(seq):
        if base == 'A':
            oh[i, 0] = 1
        elif base == 'C':
            oh[i, 1] = 1
        elif base == 'G':
            oh[i, 2] = 1
        elif base == 'T':
            oh[i, 3] = 1
    return oh
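

# Illustrative usage (assumes a 4-track model, i.e. input_dim=4): encode a
# short sequence and reshape it into the (B x C x L) layout that
# MixerModel.forward expects when channel_last=False.
#
#   oh = seq_to_oh("ACGT")                              # (4, 4) int array
#   x = torch.from_numpy(oh).float().T.unsqueeze(0)     # (1, 4, 4) == (B, C, L)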


def create_block(
    d_model,
    ssm_cfg=None,
    norm_epsilon=1e-5,
    residual_in_fp32=False,
    fused_add_norm=False,
    layer_idx=None,
    device=None,
    dtype=None,
):
    if ssm_cfg is None:
        ssm_cfg = {}
    factory_kwargs = {"device": device, "dtype": dtype}

    mix_cls = partial(Mamba, layer_idx=layer_idx, **ssm_cfg, **factory_kwargs)
    norm_cls = partial(nn.LayerNorm, eps=norm_epsilon, **factory_kwargs)

    block = Block(
        d_model,
        mix_cls,
        norm_cls=norm_cls,
        fused_add_norm=fused_add_norm,
        residual_in_fp32=residual_in_fp32,
    )
    block.layer_idx = layer_idx
    return block


class MixerModel(
    nn.Module,
    PyTorchModelHubMixin,
):
    """Stack of Mamba blocks over a linear input embedding, with a final LayerNorm."""

    def __init__(
        self,
        d_model: int,
        n_layer: int,
        input_dim: int,
        ssm_cfg=None,
        norm_epsilon: float = 1e-5,
        rms_norm: bool = False,
        initializer_cfg=None,
        fused_add_norm=False,
        residual_in_fp32=False,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.residual_in_fp32 = residual_in_fp32

        # NOTE: rms_norm is accepted for config compatibility but is not used
        # here; LayerNorm is applied throughout.
        self.embedding = nn.Linear(input_dim, d_model, **factory_kwargs)
        self.layers = nn.ModuleList(
            [
                create_block(
                    d_model,
                    ssm_cfg=ssm_cfg,
                    norm_epsilon=norm_epsilon,
                    residual_in_fp32=residual_in_fp32,
                    fused_add_norm=fused_add_norm,
                    layer_idx=i,
                    **factory_kwargs,
                )
                for i in range(n_layer)
            ]
        )
        self.norm_f = nn.LayerNorm(d_model, eps=norm_epsilon, **factory_kwargs)

        self.apply(
            partial(
                _init_weights,
                n_layer=n_layer,
                **(initializer_cfg if initializer_cfg is not None else {}),
            )
        )

    def forward(self, x, inference_params=None, channel_last=False):
        if not channel_last:
            x = x.transpose(1, 2)

        hidden_states = self.embedding(x)
        residual = None
        for layer in self.layers:
            hidden_states, residual = layer(
                hidden_states, residual, inference_params=inference_params
            )

        # Fold in the final residual before the last normalization.
        residual = (hidden_states + residual) if residual is not None else hidden_states
        hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
        return hidden_states

    def representation(
        self,
        x: torch.Tensor,
        lengths: torch.Tensor,
        channel_last: bool = False,
    ) -> torch.Tensor:
        """Get global representation of input data.

        Args:
            x: Data to embed. Has shape (B x C x L) if not channel_last.
            lengths: Unpadded length of each data input.
            channel_last: Expects input of shape (B x L x C).

        Returns:
            Global representation vector of shape (B x H).
        """
        out = self.forward(x, channel_last=channel_last)
        mean_tensor = mean_unpadded(out, lengths)
        return mean_tensor
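

# Illustrative usage sketch for MixerModel.representation (hyperparameters and
# tensors below are made up; mamba_ssm's fused kernels generally require a
# CUDA device):
#
#   model = MixerModel(d_model=256, n_layer=3, input_dim=4).cuda()
#   x = torch.randn(2, 4, 100, device="cuda")        # (B x C x L)
#   lengths = torch.tensor([100, 80], device="cuda")
#   emb = model.representation(x, lengths)           # (2, 256)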


def mean_unpadded(x: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
    """Take mean of tensor across second dimension without padding.

    Args:
        x: Tensor to take unpadded mean. Has shape (B x L x H).
        lengths: Tensor of unpadded lengths. Has shape (B).

    Returns:
        Mean tensor of shape (B x H).
    """
    mask = torch.arange(x.size(1), device=x.device)[None, :] < lengths[:, None]
    masked_tensor = x * mask.unsqueeze(-1)
    sum_tensor = masked_tensor.sum(dim=1)
    mean_tensor = sum_tensor / lengths.unsqueeze(-1).float()
    return mean_tensor
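

# Worked example (illustrative) for mean_unpadded: with lengths = [2, 3],
# batch element 0 averages only its first two positions while element 1
# averages all three, so padded positions never dilute the mean.
#
#   x = torch.arange(12, dtype=torch.float32).reshape(2, 3, 2)
#   mean_unpadded(x, torch.tensor([2, 3]))
#   # tensor([[1., 2.],
#   #         [8., 9.]])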


def _init_weights(
    module,
    n_layer,
    initializer_range=0.02,  # Now only used for embedding layer.
    rescale_prenorm_residual=True,
    n_residuals_per_layer=1,  # Change to 2 if we have MLP
):
    if isinstance(module, nn.Linear):
        if module.bias is not None:
            if not getattr(module.bias, "_no_reinit", False):
                nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
        nn.init.normal_(module.weight, std=initializer_range)

    if rescale_prenorm_residual:
        for name, p in module.named_parameters():
            if name in ["out_proj.weight", "fc2.weight"]:
                nn.init.kaiming_uniform_(p, a=math.sqrt(5))
                with torch.no_grad():
                    p /= math.sqrt(n_residuals_per_layer * n_layer)


def load_model(run_path: str, checkpoint_name: str) -> nn.Module:
    """Load trained model located at specified path.

    Args:
        run_path: Path where run data is located.
        checkpoint_name: Name of model checkpoint to load.

    Returns:
        Model with loaded weights.
    """
    model_config_path = os.path.join(run_path, "model_config.json")
    data_config_path = os.path.join(run_path, "data_config.json")

    with open(model_config_path, "r") as f:
        model_params = json.load(f)

    # TODO: Temp backwards compatibility. Older runs stored n_tracks in the
    # data config rather than the model config.
    if "n_tracks" not in model_params:
        with open(data_config_path, "r") as f:
            data_params = json.load(f)
        n_tracks = data_params["n_tracks"]
    else:
        n_tracks = model_params["n_tracks"]

    model_path = os.path.join(run_path, checkpoint_name)
    model = MixerModel(
        d_model=model_params["ssm_model_dim"],
        n_layer=model_params["ssm_n_layers"],
        input_dim=n_tracks,
    )

    checkpoint = torch.load(model_path, map_location=torch.device("cpu"))

    # Checkpoint keys are prefixed with "model." by the training wrapper;
    # strip the prefix explicitly rather than with str.lstrip, which removes
    # a character set rather than a prefix.
    state_dict = {}
    for k, v in checkpoint["state_dict"].items():
        if k.startswith("model."):
            state_dict[k[len("model."):]] = v

    model.load_state_dict(state_dict)
    return model
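

# Illustrative end-to-end sketch (run path, checkpoint name, and the sequence
# below are placeholders, not artifacts shipped with this module; assumes a
# 4-track checkpoint, i.e. input_dim=4):
#
#   model = load_model("runs/example_run", "model.ckpt").cuda().eval()
#   oh = seq_to_oh("ACGT" * 25)                           # (100, 4)
#   x = torch.from_numpy(oh).float().T[None].cuda()       # (1, 4, 100)
#   lengths = torch.tensor([x.shape[-1]], device="cuda")
#   with torch.no_grad():
#       emb = model.representation(x, lengths)            # (1, d_model)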