|
from dataclasses import dataclass, fields, asdict |
|
import json |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
from mamba import Mamba, MambaConfig, RMSNorm |
|
|
|
""" |
|
|
|
Encapsulates a Mamba model as a language model. It has an embedding layer and an LM head that maps the model output to logits.
|
|
|
""" |
|
|
|
|
|
|
|
|
|
@dataclass
class MambaLMConfig(MambaConfig):
    """Hyper-parameters of a Mamba language model.

    Adds vocabulary settings on top of the backbone's MambaConfig; the extra
    fields are consumed by MambaLM (embedding and LM head), not by the backbone.
    """

    # Number of embedding rows and of LM-head outputs in MambaLM.
    vocab_size: int = 32000
    # NOTE(review): this value is not applied anywhere in this class --
    # vocab_size is used exactly as given. Callers that need a padded
    # vocabulary must round it up themselves. TODO confirm whether padding
    # was intentionally disabled.
    pad_vocab_size_multiple: int = 8

    def __post_init__(self):
        # Delegate to MambaConfig's post-init; vocab_size is left untouched.
        super().__post_init__()

    def to_mamba_config(self) -> MambaConfig:
        """Build a plain MambaConfig from the fields this config shares with it.

        LM-only fields (vocab_size, pad_vocab_size_multiple) are dropped so the
        backbone receives only the arguments MambaConfig declares.
        """
        accepted = {f.name for f in fields(MambaConfig)}
        values = asdict(self)
        shared = {name: value for name, value in values.items() if name in accepted}
        return MambaConfig(**shared)
|
|
|
|
|
def from_pretrained(name: str): |
|
""" |
|
Returns a model loaded with pretrained weights pulled from HuggingFace. |
|
|
|
Args: |
|
name: As of now, supports |
|
* 'state-spaces/mamba-2.8b-slimpj' |
|
* 'state-spaces/mamba-2.8b' |
|
* 'state-spaces/mamba-1.4b' |
|
* 'state-spaces/mamba-790m' |
|
* 'state-spaces/mamba-370m' |
|
* 'state-spaces/mamba-130m' |
|
|
|
Returns: |
|
model: a Mamba model configured with the proper parameters and initialized with the proper weights |
|
""" |
|
|
|
from transformers.utils import WEIGHTS_NAME, CONFIG_NAME |
|
from transformers.utils.hub import cached_file |
|
|
|
def load_config_hf(model_name): |
|
resolved_archive_file = cached_file(model_name, CONFIG_NAME, _raise_exceptions_for_missing_entries=False) |
|
return json.load(open(resolved_archive_file)) |
|
|
|
def load_state_dict_hf(model_name): |
|
resolved_archive_file = cached_file(model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False) |
|
return torch.load(resolved_archive_file, weights_only=True, map_location='cpu', mmap=True) |
|
|
|
|
|
config_data = load_config_hf(name) |
|
config = MambaLMConfig(d_model=config_data['d_model'], n_layers=config_data['n_layer'], vocab_size=config_data['vocab_size']) |
|
|
|
model = MambaLM(config) |
|
|
|
|
|
state_dict = load_state_dict_hf(name) |
|
|
|
new_state_dict = {} |
|
for key in state_dict: |
|
if key == 'backbone.embedding.weight' or key == 'backbone.norm_f.weight': |
|
new_key = key.replace('backbone.', '') |
|
else: |
|
new_key = key.replace('backbone', 'mamba') |
|
|
|
new_state_dict[new_key] = state_dict[key] |
|
|
|
model.load_state_dict(new_state_dict) |
|
|
|
return model |
|
|
|
class MambaLM(nn.Module):
    """Mamba backbone wrapped as a language model.

    Token ids are embedded, passed through the Mamba layers, normalized by a
    final RMSNorm and projected to vocabulary logits by ``lm_head``, whose
    weight is tied to the embedding matrix.
    """

    def __init__(self, lm_config: MambaLMConfig):
        super().__init__()
        self.lm_config = lm_config
        # The backbone is built from the plain MambaConfig fields only.
        self.config = lm_config.to_mamba_config()

        self.embedding = nn.Embedding(self.lm_config.vocab_size, self.config.d_model)
        self.mamba = Mamba(self.config)
        self.norm_f = RMSNorm(self.config.d_model)

        self.lm_head = nn.Linear(self.config.d_model, self.lm_config.vocab_size, bias=False)
        # Weight tying: the LM head reuses the embedding matrix.
        self.lm_head.weight = self.embedding.weight

    def forward(self, tokens):
        """Compute logits for every position of a token sequence.

        Args:
            tokens: integer token ids, presumably of shape (B, L) -- TODO
                confirm against Mamba.forward.

        Returns:
            logits: last dimension is vocab_size (output of ``lm_head``).
        """
        x = self.embedding(tokens)
        x = self.mamba(x)
        x = self.norm_f(x)
        logits = self.lm_head(x)
        return logits

    def step(self, token, caches):
        """Advance the model by a single token using recurrent caches.

        Args:
            token: token ids for one position, e.g. ``input_ids[:, i]`` as
                passed by ``generate`` (presumably shape (B,)).
            caches: per-layer cache list handed to and returned by
                ``Mamba.step``; ``generate`` initializes it as one
                ``(None, conv_zeros)`` tuple per layer.

        Returns:
            (logits, caches): next-token logits and the updated caches.
        """
        x = self.embedding(token)
        x, caches = self.mamba.step(x, caches)
        x = self.norm_f(x)
        logits = self.lm_head(x)
        return logits, caches

    def generate(self, tokenizer, prompt: str, num_tokens: int = 50, sample: bool = True, top_k: int = 40):
        """Autoregressively extend ``prompt`` by ``num_tokens`` tokens.

        The prompt is fed one token at a time through ``step`` to fill the
        caches; once the last prompt token has been consumed, each step's
        logits produce the next token, which is appended and fed back.

        Args:
            tokenizer: callable returning an object with ``input_ids`` when
                called with ``return_tensors='pt'``, and providing ``decode``.
            prompt: text to continue.
            num_tokens: number of new tokens to generate.
            sample: if True, sample from the (top-k filtered) distribution;
                otherwise take the most probable token.
            top_k: keep only the ``top_k`` most probable tokens before
                renormalizing; ``None`` disables the filter. Values larger than
                the vocabulary are clamped to the vocabulary size.

        Returns:
            The decoded prompt plus continuation for the first batch row.

        The module's train/eval mode is restored on exit, even on error.
        """
        was_training = self.training
        self.eval()
        try:
            param = next(self.parameters())
            input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(param.device)
            prompt_len = input_ids.size(1)

            # One (hidden_state, conv_inputs) cache per layer. The hidden state
            # starts as None (presumably created by Mamba.step on first use);
            # the conv cache is sized for the last d_conv-1 inputs and follows
            # the model's device and dtype so it can be combined with layer inputs.
            caches = [
                (None, torch.zeros(1, self.config.d_inner, self.config.d_conv - 1,
                                   device=param.device, dtype=param.dtype))
                for _ in range(self.config.n_layers)
            ]

            with torch.no_grad():
                for i in range(prompt_len + num_tokens - 1):
                    next_token_logits, caches = self.step(input_ids[:, i], caches)

                    # Still consuming the prompt: only the cache update matters.
                    if i + 1 < prompt_len:
                        continue

                    probs = F.softmax(next_token_logits, dim=-1)

                    if top_k is not None:
                        k = min(top_k, probs.size(-1))
                        values, _ = torch.topk(probs, k=k)
                        # Zero everything below the k-th largest probability.
                        probs[probs < values[:, -1, None]] = 0
                        probs = probs / probs.sum(dim=-1, keepdim=True)

                    if sample:
                        next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
                    else:
                        next_token = torch.argmax(probs, dim=-1)

                    input_ids = torch.cat([input_ids, next_token.unsqueeze(1)], dim=1)
        finally:
            # Restore whichever mode the caller had, instead of forcing train mode.
            self.train(was_training)

        return tokenizer.decode(input_ids[0].tolist())
|
|
|
|