# NOTE(review): the three lines below were a paste-site artifact (file size,
# content hash, and a run of line numbers) that made this file invalid Python;
# they carried no program content and have been commented out.
from dataclasses import dataclass, field
from typing import Literal
import torch
# https://github.com/state-spaces/mamba/blob//mamba_ssm/utils/generation.py#L18
@dataclass
class InferenceParams:
"""Inference parameters that are passed to the main model in order
to efficienly calculate and store the context during inference."""
max_seqlen: int
max_batch_size: int
seqlen_offset: int = 0
batch_size_offset: int = 0
key_value_memory_dict: dict = field(default_factory=dict)
lengths_per_sample: torch.Tensor | None = None
def reset(self, max_seqlen, max_batch_size):
self.max_seqlen = max_seqlen
self.max_batch_size = max_batch_size
self.seqlen_offset = 0
if self.lengths_per_sample is not None:
self.lengths_per_sample.zero_()
@dataclass
class BackboneConfig:
    """Configuration for the model backbone (SSM layers, optionally
    interleaved with attention layers — see `attn_layer_idx`)."""

    d_model: int = 1024  # residual-stream width
    d_intermediate: int = 0  # presumably MLP hidden width for SSM blocks (0 = none) — confirm
    attn_mlp_d_intermediate: int = 0  # presumably MLP hidden width for attention blocks — confirm
    n_layer: int = 16  # number of backbone layers
    ssm_cfg: dict = field(default_factory=dict)  # kwargs forwarded to the SSM layer constructor
    attn_layer_idx: list = field(default_factory=list)  # layer indices that use attention instead of SSM
    attn_cfg: dict = field(default_factory=dict)  # kwargs forwarded to the attention layer constructor
    rms_norm: bool = False  # use RMSNorm instead of LayerNorm
    residual_in_fp32: bool = False  # keep the residual stream in fp32
    norm_epsilon: float = 1e-5  # epsilon for the normalization layers
@dataclass
class PrefixConditionerConfig:
    """Configuration for the prefix conditioner (no defaults: both fields
    are required)."""

    conditioners: list[dict]  # one config dict per conditioner — schema defined by the conditioner classes
    projection: Literal["none", "linear", "mlp"]  # projection applied to conditioner outputs
@dataclass
class ZonosConfig:
    """Top-level Zonos model configuration: a backbone config, a prefix
    conditioner config, and a handful of vocabulary/token constants."""

    backbone: BackboneConfig
    prefix_conditioner: PrefixConditionerConfig
    eos_token_id: int = 1024  # end-of-sequence token id
    masked_token_id: int = 1025  # id used for masked positions
    pad_vocab_to_multiple_of: int = 8  # vocab padding granularity

    @classmethod
    def from_dict(cls, d: dict) -> "ZonosConfig":
        """Build a ZonosConfig from a plain dict (e.g. parsed JSON).

        The nested "backbone" and "prefix_conditioner" entries are expanded
        into their dataclasses; every remaining key is passed through as a
        keyword argument. The input dict is not mutated.
        """
        remaining = dict(d)
        backbone = BackboneConfig(**remaining.pop("backbone"))
        prefix = PrefixConditionerConfig(**remaining.pop("prefix_conditioner"))
        return cls(backbone, prefix, **remaining)
|