"""Caduceus config for Hugging Face.

"""

from typing import Optional

from transformers import PretrainedConfig


class CaduceusConfig(PretrainedConfig):
    """Config that extends the original MambaConfig with params relevant to bi-directionality and RC equivariance."""

    model_type = "caduceus"

    def __init__(
        self,
        # From original MambaConfig
        d_model: int = 2560,
        d_intermediate: int = 0,
        use_mamba2: bool = False,
        n_layer: int = 64,
        vocab_size: int = 50277,
        ssm_cfg: Optional[dict] = None,
        rms_norm: bool = True,
        residual_in_fp32: bool = True,
        fused_add_norm: bool = True,
        pad_vocab_size_multiple: int = 8,
        # Not in original MambaConfig, but default arg in create_block in mamba_ssm repo; used in layer norm
        norm_epsilon: float = 1e-5,
        # Used in init_weights
        initializer_cfg: Optional[dict] = None,
        # Caduceus-specific params
        bidirectional: bool = True,
        bidirectional_strategy: Optional[str] = "add",
        bidirectional_weight_tie: bool = True,
        rcps: bool = False,
        complement_map: Optional[dict] = None,  # used for RCPSEmbedding / RCPSLMHead
        pos_embeddings: Optional[str] = None,
        row_first: Optional[bool] = True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.d_model = d_model
        self.d_intermediate = d_intermediate
        self.use_mamba2 = use_mamba2
        self.n_layer = n_layer
        self.vocab_size = vocab_size
        self.ssm_cfg = ssm_cfg
        self.rms_norm = rms_norm
        self.residual_in_fp32 = residual_in_fp32
        self.fused_add_norm = fused_add_norm
        self.pad_vocab_size_multiple = pad_vocab_size_multiple
        self.norm_epsilon = norm_epsilon
        self.initializer_cfg = initializer_cfg
        self.bidirectional = bidirectional
        self.bidirectional_strategy = bidirectional_strategy
        self.bidirectional_weight_tie = bidirectional_weight_tie
        self.rcps = rcps
        self.complement_map = complement_map
        self.pos_embeddings = pos_embeddings
        self.row_first = row_first


class AxialCaduceusConfig(PretrainedConfig):
    """Config for the axial Caduceus variant; extends the original MambaConfig with params relevant to bi-directionality and RC equivariance."""

    model_type = "axial_caduceus"

    def __init__(
        self,
        # From original MambaConfig
        d_model: int = 2560,
        d_intermediate: int = 0,
        use_mamba2: bool = False,
        n_layer: int = 64,
        vocab_size: int = 50277,
        ssm_cfg: Optional[dict] = None,
        rms_norm: bool = True,
        residual_in_fp32: bool = True,
        fused_add_norm: bool = True,
        pad_vocab_size_multiple: int = 8,
        # Not in original MambaConfig, but default arg in create_block in mamba_ssm repo; used in layer norm
        norm_epsilon: float = 1e-5,
        # Used in init_weights
        initializer_cfg: Optional[dict] = None,
        # Caduceus-specific params
        bidirectional: bool = True,
        bidirectional_strategy: Optional[str] = "add",
        bidirectional_weight_tie: bool = True,
        rcps: bool = False,
        complement_map: Optional[dict] = None,  # used for RCPSEmbedding / RCPSLMHead
        pos_embeddings: Optional[str] = None,
        row_first: Optional[bool] = True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.d_model = d_model
        self.d_intermediate = d_intermediate
        self.use_mamba2 = use_mamba2
        self.n_layer = n_layer
        self.vocab_size = vocab_size
        self.ssm_cfg = ssm_cfg
        self.rms_norm = rms_norm
        self.residual_in_fp32 = residual_in_fp32
        self.fused_add_norm = fused_add_norm
        self.pad_vocab_size_multiple = pad_vocab_size_multiple
        self.norm_epsilon = norm_epsilon
        self.initializer_cfg = initializer_cfg
        self.bidirectional = bidirectional
        self.bidirectional_strategy = bidirectional_strategy
        self.bidirectional_weight_tie = bidirectional_weight_tie
        self.rcps = rcps
        self.complement_map = complement_map
        self.pos_embeddings = pos_embeddings
        self.row_first = row_first


class MixedCaduceusConfig(PretrainedConfig):
    """Config that extends CaduceusConfig with params relevant to alternating between attention and Caduceus blocks."""

    model_type = "mixed_caduceus"

    def __init__(
        self,
        # From original MambaConfig
        d_model: int = 2560,
        d_intermediate: int = 0,
        use_mamba2: bool = False,
        n_layer: int = 64,
        vocab_size: int = 50277,
        ssm_cfg: Optional[dict] = None,
        rms_norm: bool = True,
        residual_in_fp32: bool = True,
        fused_add_norm: bool = True,
        pad_vocab_size_multiple: int = 8,
        # Not in original MambaConfig, but default arg in create_block in mamba_ssm repo; used in layer norm
        norm_epsilon: float = 1e-5,
        # Used in init_weights
        initializer_cfg: Optional[dict] = None,
        # Caduceus-specific params
        bidirectional: bool = True,
        bidirectional_strategy: Optional[str] = "add",
        bidirectional_weight_tie: bool = True,
        rcps: bool = False,
        complement_map: Optional[dict] = None,  # used for RCPSEmbedding / RCPSLMHead
        # Attention-specific params
        attn_d_model: int = 128,
        attn_n_heads: int = 16,
        attn_attn_dropout: float = 0.1,
        attn_block_dropout: float = 0.1,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.d_model = d_model
        self.d_intermediate = d_intermediate
        self.use_mamba2 = use_mamba2
        self.n_layer = n_layer
        self.vocab_size = vocab_size
        self.ssm_cfg = ssm_cfg
        self.rms_norm = rms_norm
        self.residual_in_fp32 = residual_in_fp32
        self.fused_add_norm = fused_add_norm
        self.pad_vocab_size_multiple = pad_vocab_size_multiple
        self.norm_epsilon = norm_epsilon
        self.initializer_cfg = initializer_cfg
        self.bidirectional = bidirectional
        self.bidirectional_strategy = bidirectional_strategy
        self.bidirectional_weight_tie = bidirectional_weight_tie
        self.rcps = rcps
        self.complement_map = complement_map
        self.attn_d_model = attn_d_model
        self.attn_n_heads = attn_n_heads
        self.attn_attn_dropout = attn_attn_dropout
        self.attn_block_dropout = attn_block_dropout
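

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the config API): the values
# below are hypothetical small settings chosen to show how these configs are
# built and serialized through the standard PretrainedConfig interface.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Small bi-directional, RC-equivariant Caduceus config; the complement map
    # is a hypothetical 4-token DNA alphabet (A<->T, C<->G), shown only to
    # illustrate the expected id-to-id mapping used by RCPSEmbedding / RCPSLMHead.
    caduceus_cfg = CaduceusConfig(
        d_model=256,
        n_layer=4,
        vocab_size=12,
        bidirectional=True,
        bidirectional_strategy="add",
        rcps=True,
        complement_map={0: 3, 1: 2, 2: 1, 3: 0},
    )
    print(caduceus_cfg.to_json_string())

    # Mixed config: the same Mamba backbone knobs plus the attention-block params.
    mixed_cfg = MixedCaduceusConfig(
        d_model=256,
        n_layer=4,
        vocab_size=12,
        attn_d_model=128,
        attn_n_heads=8,
        attn_attn_dropout=0.1,
        attn_block_dropout=0.1,
    )
    print(mixed_cfg.to_json_string())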