import torch
import transformers
from transformers import AutoTokenizer

from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

from lm_eval.api.model import LM
from lm_eval.models.huggingface import HFLM
from lm_eval.api.registry import register_model
from lm_eval.__main__ import cli_evaluate


@register_model("mamba")
class MambaEvalWrapper(HFLM):
    """Wraps MambaLMHeadModel so lm-evaluation-harness can score it like a Hugging Face causal LM."""

    AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM

    def __init__(self, pretrained="state-spaces/mamba-2.8b", max_length=2048,
                 batch_size=None, device="cuda", dtype=torch.float16):
        LM.__init__(self)
        self._model = MambaLMHeadModel.from_pretrained(pretrained, device=device, dtype=dtype)
        # Mamba checkpoints were trained with the GPT-NeoX-20B tokenizer.
        self.tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
        self.vocab_size = self.tokenizer.vocab_size
        self._batch_size = int(batch_size) if batch_size is not None else 64
        self._max_length = max_length
        self._device = torch.device(device)

    @property
    def batch_size(self):
        return self._batch_size

    def _model_generate(self, context, max_length, stop, **generation_kwargs):
        # Only loglikelihood-style evaluations are supported; free-form generation is not implemented.
        raise NotImplementedError()


if __name__ == "__main__":
    cli_evaluate()
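
# Usage sketch: cli_evaluate() exposes the standard lm-evaluation-harness CLI, and
# --model mamba selects the wrapper registered above via @register_model("mamba").
# Assumptions: this file is saved as lm_harness_eval.py, and the task names below
# are only illustrative; any harness task should work the same way.
#
#   python lm_harness_eval.py --model mamba \
#       --model_args pretrained=state-spaces/mamba-2.8b \
#       --tasks lambada_openai,hellaswag \
#       --device cuda --batch_size 64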