Yanisadel committed on
Commit 03e9d8a · verified · 1 Parent(s): 9df085d

Upload model

chatNT.py ADDED
@@ -0,0 +1,1896 @@
1
+ # This file stores ChatNT and all associated layers and configs
2
+
3
+ from dataclasses import asdict, dataclass, field
4
+ from typing import Dict, List, Optional, Tuple
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F # noqa: N812
10
+ from transformers import PretrainedConfig, PreTrainedModel
11
+
12
+
13
+ @dataclass
14
+ class RotaryEmbeddingConfig:
15
+ """
16
+ Rotary Positional Embedding configuration
17
+ max_seq_len: The number of positions to encode and cache.
18
+ dim: Dimension of RoPE.
19
+ theta: Base frequency used to compute the rotation angles.
20
+ """
21
+
22
+ max_seq_len: int
23
+ dim: int
24
+ theta: float
25
+
26
+
27
+ @dataclass
28
+ class PerceiverResamplerConfig:
29
+ """
30
+ Parameters to initialize a PerceiverResampler model.
31
+
32
+ Args:
33
+ emb_layer_norm_before: Whether to use layer norm before the first attention
34
+ layer.
35
+ attention_heads: Number of attention heads.
36
+ key_size: The dimension of the query, key, and values within each attention
37
+ head, if not specified, it is set to embed_dim // attention_heads.
38
+ It can be useful to set a custom key size if we want to impose the size of
39
+ the query, key and value tensor ( for example, tensors shaped with
40
+ power of 2 are more efficiently handled on TPUs ).
41
+ Note: Parametrizing the model with a custom key size has been done in :
42
+ Brown, Tom, et al. "Language models are few-shot learners."
43
+ Advances in neural information processing systems 33 (2020): 1877-1901.
44
+ embed_dim: Embedding dimension.
45
+ ffn_embed_dim: Feed forward embedding dimension.
46
+ num_layers: Number of attention blocks.
47
+ ffn_activation_name: Activation function to be used in FFN block. Supported
48
+ names are "gelu", "relu", "swish".
49
+ use_glu_in_ffn: Whether to use Gated Linear Unit (GLU) in Feed
50
+ Forward Network (FFN) block. To do a swiGLU (gated-swish) put this arg
51
+ to True and use swish as ffn_activation_name.
52
+ Same principle for a gated-relu. To keep the same number of parameters in
53
+ the FFN block, one should multiply by 2/3 the ffn_embed_dim when using GLU.
54
+ See https://arxiv.org/pdf/2002.05202.pdf for more details.
55
+ resampled_length: length of the resampled output of the module
56
+ use_gradient_checkpointing: Whether to use gradient checkpointing (checkpoint
57
+ gradients in the forward pass to reduce the computation in the backward).
58
+ """
59
+
60
+ # architecture
61
+ emb_layer_norm_before: bool = False
62
+ attention_heads: int = 20
63
+ key_size: Optional[int] = None
64
+ embed_dim: int = 1280
65
+ ffn_embed_dim: int = 5120
66
+ num_layers: int = 24
67
+ add_bias_kv: bool = False
68
+ add_bias_ffn: bool = True
69
+ ffn_activation_name: str = "gelu-no-approx"
70
+ use_glu_in_ffn: bool = False
71
+ resampled_length: int = 64
72
+
73
+ # performance
74
+ use_gradient_checkpointing: bool = False
75
+
76
+ def __post_init__(self) -> None:
77
+ """
78
+ Checks that the given values are compatible.
79
+ """
80
+
81
+ if self.key_size is None:
82
+ if not self.embed_dim % self.attention_heads == 0:
83
+ raise ValueError(
84
+ f"When no key size is provided, the embedding dimension should be "
85
+ f"divisible by the number of heads, however provided embedding "
86
+ f"dimension is {self.embed_dim} and the number of heads is "
87
+ f"{self.attention_heads}."
88
+ )
89
+ self.key_size = self.embed_dim // self.attention_heads
90
+
91
+
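+ # Worked example for the GLU sizing note in the config docstrings above and
+ # below (illustrative arithmetic only): with GLU enabled, fc1 produces
+ # 2 * ffn_embed_dim units, so to keep the FFN parameter count roughly equal
+ # to a non-GLU block one would scale ffn_embed_dim by 2/3, e.g.:
+ #   ffn_embed_dim_no_glu = 5120
+ #   ffn_embed_dim_glu = int(2 / 3 * ffn_embed_dim_no_glu)  # ~3413
+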
92
+ @dataclass
93
+ class GptConfig:
94
+ """
95
+ Parameters to initialize a Gpt model.
96
+
97
+ NOTE: the pad token is not defined
98
+
99
+ Args:
100
+ vocab_size: Token vocabulary size.
101
+ eos_token_id: used to stop sentence generation
102
+ embed_dim: Embedding dimension.
103
+ ffn_embed_dim: Feed forward embedding dimension.
104
+ num_heads: Number of attention heads.
105
+ num_kv_heads: Number of key and value heads to support Grouped-Query and
106
+ Multi-Query Attention. If None, the number of key and value heads is
107
+ equal to the number of attention heads.
108
+ num_layers: Number of decoder layers.
109
+ rope_config: The configuration for the rotary positional embeddings
110
+ add_bias_ffn: Add bias in feed forward network block.
111
+ ffn_activation_name: Activation function to be used in FFN block. Supported
112
+ names are "gelu", "gelu-no-approx", "relu", "swish".
113
+ use_glu_in_ffn: whether to use Gated Linear Unit (GLU) in Feed
114
+ Forward Network (FFN) block.
115
+ example: To do a swiGLU (gated-swish) put this arg
116
+ to True and use swish as ffn_activation_name.
117
+ Same principle for a gated-relu.
118
+ add_bias_lm_head: whether to use bias in the final LM layer
119
+ norm_type: The type of norm (pre-normalization scheme) used. Can be
120
+ one of ["layer_norm", "RMS_norm"].
121
+ parallel_attention_ff: Whether to do the attention and the MLP in parallel,
122
+ and then sum up the results as it is done in Gpt-NeoX :
123
+ Black, Sid, et al. "Gpt-neox-20b: An open-source autoregressive
124
+ language model." arXiv preprint arXiv:2204.06745 (2022).
125
+ It is said to improve the training time by 15% when compiling with JAX.
126
+ use_gradient_checkpointing: Whether to use gradient checkpointing (checkpoint
127
+ gradients in the forward pass to reduce the computation in the backward).
128
+ add_bias_attn: Add bias to the attention mechanism (key, query, value, and
129
+ output projections).
130
+ """
131
+
132
+ # vocabulary
133
+ vocab_size: int
134
+ eos_token_id: int
135
+
136
+ # architecture
137
+ embed_dim: int = 16
138
+ ffn_embed_dim: int = 64
139
+ num_heads: int = 2
140
+ num_kv_heads: Optional[int] = None
141
+ num_layers: int = 2
142
+ rope_config: RotaryEmbeddingConfig = field(
143
+ default_factory=lambda: RotaryEmbeddingConfig(
144
+ max_seq_len=512, dim=8, theta=10000.0
145
+ )
146
+ )
147
+ add_bias_ffn: bool = False
148
+ ffn_activation_name: str = "swish"
149
+ use_glu_in_ffn: bool = True
150
+ add_bias_lm_head: bool = False
151
+ norm_type: str = "RMS_norm"
152
+ rms_norm_eps: float = 1e-6
153
+ parallel_attention_ff: bool = True
154
+
155
+ # inference / backward behavior
156
+ use_gradient_checkpointing: bool = False
157
+
158
+ # architecture params with default values
159
+ add_bias_attn: bool = False
160
+
161
+ def __post_init__(self) -> None:
162
+ """
163
+ Checks that the given values are compatible.
164
+ """
165
+ if not self.embed_dim % self.num_heads == 0:
166
+ raise ValueError(
167
+ f"The embedding dimension should be "
168
+ f"divisible by the number of heads, however provided embedding "
169
+ f"dimension is {self.embed_dim} and the number of heads is "
170
+ f"{self.num_heads}."
171
+ )
172
+
173
+ if not self.embed_dim // self.num_heads > 1:
174
+ raise ValueError(
175
+ "embed_dim / num_heads must be higher than 2 to apply rotary embeddings"
176
+ )
177
+
178
+ if not self.embed_dim // self.num_heads >= self.rope_config.dim:
179
+ raise ValueError(
180
+ "embed_dim // num_heads must be higher than rope_config.dim "
181
+ "to apply rotary embeddings"
182
+ )
183
+
184
+ def to_dict(self): # type: ignore
185
+ output = asdict(self)
186
+ output["rope_config"] = asdict(self.rope_config)
187
+ return output
188
+
189
+
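+ # Minimal construction sketch for GptConfig (values are illustrative, not the
+ # shipped checkpoint's settings):
+ #   rope = RotaryEmbeddingConfig(max_seq_len=512, dim=8, theta=10000.0)
+ #   gpt_cfg = GptConfig(vocab_size=32000, eos_token_id=3, rope_config=rope)
+ #   gpt_cfg.to_dict()  # the nested rope_config is serialized via asdict
+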
190
+ @dataclass
191
+ class NucleotideTransformerConfig:
192
+ """
193
+ Parameters to initialize an NT model.
194
+
195
+ Args:
196
+ alphabet_size: Token vocabulary size.
197
+ pad_token_id: ID of pad token.
198
+ mask_token_id: ID of mask token.
199
+ max_positions: Maximum sequence length.
200
+ embed_scale: Correction ratio applied to the embeddings to make up for the
201
+ norm difference between the input during training and inference.
202
+ emb_layer_norm_before: Whether to use layer norm before the first attention
203
+ layer.
204
+ attention_heads: Number of attention heads.
205
+ key_size: The dimension of the query, key, and values within each attention
206
+ head, if not specified, it is set to embed_dim // attention_heads.
207
+ It can be useful to set a custom key size if we want to impose the size of
208
+ the query, key and value tensor ( for example, tensors shaped with
209
+ power of 2 are more efficiently handled on TPUs ).
210
+ Note: Parametrizing the model with a custom key size has been done in :
211
+ Brown, Tom, et al. "Language models are few-shot learners."
212
+ Advances in neural information processing systems 33 (2020): 1877-1901.
213
+ embed_dim: Embedding dimension.
214
+ ffn_embed_dim: Feed forward embedding dimension.
215
+ num_layers: Number of attention blocks.
216
+ positional_embedding: Type of positional embedding to use before the first
217
+ attention layer. Options: "learned", "learned_standard", "sinusoidal" or
218
+ None.
219
+ NOTE: "learned" is the positional embedding of ESM, and "learned_standard"
220
+ is a more standard one, used for example in DNAbert.
221
+ lm_head: type of language model head. Options: "simple", "roberta" or None.
222
+ add_bias_kv: Add bias in attention layer.
223
+ add_bias_ffn: Add bias in feed forward network block.
224
+ use_rotary_embedding: Whether to use rotary embeddings. Requires:
225
+ positional_embeddings = None.
226
+ rescaling_factor: Scaling factor to use for rotary embeddings.
227
+ ffn_activation_name: Activation function to be used in FFN block. Supported
228
+ names are "gelu", "relu", "swish".
229
+ use_glu_in_ffn: Whether to use Gated Linear Unit (GLU) in Feed
230
+ Forward Network (FFN) block. To do a swiGLU (gated-swish) put this arg
231
+ to True and use swish as ffn_activation_name.
232
+ Same principle for a gated-relu. To keep the same number of parameters in
233
+ the FFN block, one should multiply by 2/3 the ffn_embed_dim when using GLU.
234
+ See https://arxiv.org/pdf/2002.05202.pdf for more details.
235
+ mask_before_attention: Use mask before attention layers.
236
+ layer_norm_eps: the eps factor in the different layer norms of the model (refer
237
+ to layer norm implementation)
238
+ token_dropout: Token dropout.
239
+ masking_ratio: Masking ratio (used if token dropout is enabled).
240
+ masking_prob: Masking probability (used if token dropout is enabled).
241
+ use_gradient_checkpointing: Whether to use gradient checkpointing (checkpoint
242
+ gradients in the forward pass to reduce the computation in the backward).
243
+ """
244
+
245
+ alphabet_size: int
246
+ pad_token_id: int
247
+ mask_token_id: int
248
+
249
+ max_positions: int = 1024
250
+ embed_scale: float = 1.0
251
+
252
+ # architecture
253
+ emb_layer_norm_before: bool = False
254
+ attention_heads: int = 20
255
+ key_size: Optional[int] = None
256
+ embed_dim: int = 1280
257
+ ffn_embed_dim: int = 5120
258
+ num_layers: int = 24
259
+ positional_embedding: Optional[str] = "learned"
260
+ lm_head: Optional[str] = "simple"
261
+ add_bias_kv: bool = False
262
+ add_bias_ffn: bool = True
263
+ use_rotary_embedding: bool = False
264
+ rescaling_factor: Optional[float] = None
265
+ ffn_activation_name: str = "gelu-no-approx"
266
+ use_glu_in_ffn: bool = False
267
+ mask_before_attention: bool = False
268
+ layer_norm_eps: float = 1e-5
269
+ pre_layer_norm: bool = True
270
+ bias_word_embedding: bool = False
271
+
272
+ # dropout
273
+ token_dropout: bool = False
274
+ masking_ratio: float = 0.1
275
+ masking_prob: float = 0.8
276
+
277
+ # performance
278
+ use_gradient_checkpointing: bool = False
279
+
280
+ # return
281
+ embeddings_layers_to_save: List[int] = field(default_factory=list)
282
+ attention_maps_to_save: List[Tuple[int, int]] = field(default_factory=list)
283
+
284
+ def __post_init__(self) -> None:
285
+ """
286
+ Checks that the given values are compatible.
287
+ """
288
+
289
+ if self.key_size is None:
290
+ if not self.embed_dim % self.attention_heads == 0:
291
+ raise ValueError(
292
+ f"When no key size is provided, the embedding dimension should be "
293
+ f"divisible by the number of heads, however provided embedding "
294
+ f"dimension is {self.embed_dim} and the number of heads is "
295
+ f"{self.attention_heads}."
296
+ )
297
+ self.key_size = self.embed_dim // self.attention_heads
298
+ if self.positional_embedding is not None:
299
+ if type(self.positional_embedding) != str:
300
+ raise TypeError
301
+
302
+ if self.positional_embedding not in [
303
+ "learned",
304
+ "sinusoidal",
305
+ "learned_standard",
306
+ "alibi_dnabert_2",
307
+ ]:
308
+ raise ValueError(
309
+ "The positional_embedding argument should either be None,"
310
+ "`learned`, `sinusoidal`, 'learned_standard' or 'alibi_dnabert_2'."
311
+ )
312
+ if self.lm_head is not None:
313
+ if type(self.lm_head) != str:
314
+ raise TypeError
315
+
316
+ if self.lm_head not in ["simple", "roberta"]:
317
+ raise ValueError(
318
+ "The lm_head argument should either be None,"
319
+ "`simple` or `roberta`."
320
+ )
321
+
322
+ if self.use_rotary_embedding and self.positional_embedding is not None:
323
+ raise ValueError(
324
+ "When using rotary embedding, positional_embedding must be set to none"
325
+ )
326
+
327
+ if self.add_bias_kv and self.use_rotary_embedding:
328
+ raise ValueError(
329
+ "Biases on key and values are not compatible with Rotary embeddings."
330
+ )
331
+
332
+ if self.positional_embedding == "alibi_dnabert_2":
333
+ assert not self.add_bias_kv
334
+
335
+
336
+ @dataclass
337
+ class ChatNTConfig(PretrainedConfig):
338
+ model_type = "ChatNT"
339
+
340
+ def __init__(self, **kwargs): # type: ignore
341
+ self.gpt_config: GptConfig = kwargs.get("gpt_config", GptConfig(32000, 3))
342
+ self.nt_config: NucleotideTransformerConfig = kwargs.get(
343
+ "nt_config", NucleotideTransformerConfig(4000, 1, 4)
344
+ )
345
+ self.perceiver_resampler_config: PerceiverResamplerConfig = kwargs.get(
346
+ "perceiver_resampler_config", PerceiverResamplerConfig()
347
+ )
348
+ self.seq_token_id: int = kwargs.get("seq_token_id", 32000)
349
+ self.bio_pad_token_id: int = kwargs.get("bio_pad_token_id", 1)
350
+ self.english_pad_token_id: int = kwargs.get("english_pad_token_id", 2)
351
+ super().__init__(**kwargs)
352
+
353
+ def to_dict(self): # type: ignore
354
+ output = super().to_dict()
355
+
356
+ def serialize(obj): # type: ignore
357
+ return obj.to_dict() if hasattr(obj, "to_dict") else vars(obj)
358
+
359
+ output["gpt_config"] = serialize(self.gpt_config) # type: ignore
360
+ output["nt_config"] = serialize(self.nt_config) # type: ignore
361
+ output["perceiver_resampler_config"] = serialize( # type: ignore
362
+ self.perceiver_resampler_config
363
+ )
364
+ return output
365
+
366
+
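+ # Minimal construction sketch for ChatNTConfig (these are the default
+ # sub-configs used in __init__, shown here only for illustration):
+ #   cfg = ChatNTConfig(
+ #       gpt_config=GptConfig(32000, 3),
+ #       nt_config=NucleotideTransformerConfig(4000, 1, 4),
+ #       perceiver_resampler_config=PerceiverResamplerConfig(),
+ #   )
+ #   cfg.to_dict()  # nested configs are serialized recursively
+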
367
+ class TorchBioBrainDecoder(nn.Module):
368
+ def __init__(
369
+ self,
370
+ gpt_config: GptConfig,
371
+ seq_token_id: int,
372
+ ):
373
+ """
374
+ Initializes the BioBrain decoder, using a GPT model for text generation with
375
+ bio embeddings.
376
+
377
+ Args:
378
+ gpt_config: Configuration for the GPT model
379
+ seq_token_id: Index of the SEQ token
380
+ """
381
+ super(TorchBioBrainDecoder, self).__init__()
382
+ self.gpt_config = gpt_config
383
+ self.seq_token_id = seq_token_id
384
+
385
+ # Initialize the GPT decoder model
386
+ self.gpt_model = TorchGptDecoder(self.gpt_config)
387
+
388
+ def forward(
389
+ self, english_token_ids: torch.Tensor, projected_bio_embeddings: torch.Tensor
390
+ ) -> torch.Tensor:
391
+ """
392
+ Forward pass through the model.
393
+
394
+ Args:
395
+ english_token_ids: Tensor of English token IDs with shape
396
+ (batch_size, num_english_tokens).
397
+ projected_bio_embeddings: Optional tensor of bio embeddings with shape
398
+ (batch_size, num_bio_sequences, ?, embed_dim).
399
+
400
+ Returns:
401
+ torch.Tensor: The logits from the GPT model,
402
+ shaped (batch_size, num_english_tokens, vocab_size).
403
+ """
404
+
405
+ # Compute English token embeddings
406
+ tokens_embeddings = self.gpt_model.token_embed(english_token_ids)
407
+
408
+ if projected_bio_embeddings is not None:
409
+ (
410
+ batch_size,
411
+ num_bio_sequences,
412
+ _,
413
+ bio_embed_dim,
414
+ ) = projected_bio_embeddings.shape
415
+
416
+ # Insert the bio embeddings at the SEQ token positions
417
+ processed_tokens_ids = english_token_ids.clone()
418
+ for bio_seq_num in range(num_bio_sequences):
419
+ tokens_embeddings, processed_tokens_ids = self.insert_embeddings(
420
+ processed_tokens_ids,
421
+ tokens_embeddings,
422
+ projected_bio_embeddings[:, bio_seq_num, :, :],
423
+ bio_seq_num=bio_seq_num,
424
+ )
425
+
426
+ # Regular GPT pass through
427
+ embeddings = self.gpt_model.apply_transformer_layers(tokens_embeddings)
428
+ embeddings = self.gpt_model.final_norm(embeddings)
429
+
430
+ # Compute logits
431
+ logits = self.gpt_model.lm_head(embeddings)
432
+
433
+ if projected_bio_embeddings is not None:
434
+ # Clean logits sequentially
435
+ processed_tokens_ids = english_token_ids.clone()
436
+ resampled_length = projected_bio_embeddings.shape[-2]
437
+ for _ in range(num_bio_sequences):
438
+ logits, processed_tokens_ids = self.cleanup_logits(
439
+ tokens=processed_tokens_ids,
440
+ logits=logits,
441
+ resampled_length=resampled_length,
442
+ )
443
+
444
+ return logits
445
+
446
+ def insert_embeddings(
447
+ self,
448
+ tokens: torch.Tensor,
449
+ input_embeddings: torch.Tensor,
450
+ resampled_embeddings: torch.Tensor,
451
+ bio_seq_num: int,
452
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
453
+ """
454
+ Inserts resampled embeddings in input_embeddings, starting at the SEQ token
455
+
456
+ Args:
457
+ tokens (torch.Tensor): Shape (batch_size, num_tokens)
458
+ input_embeddings (torch.Tensor): Shape (batch_size, num_tokens, embed_dim)
459
+ resampled_embeddings (torch.Tensor):
460
+ Shape (batch_size, bio_sequence_length, embed_dim)
461
+
462
+ Returns:
463
+ Tuple[torch.Tensor, torch.Tensor]:
464
+ - input_embeddings with resampled_embeddings inserted at the SEQ token
465
+ - tokens with the SEQ token set to -1
466
+ """
467
+
468
+ def _insert(
469
+ tokens_1d: torch.Tensor,
470
+ input_embeddings_1d: torch.Tensor,
471
+ resampled_embeddings_1d: torch.Tensor,
472
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
473
+ """
474
+ Args:
475
+ tokens (torch.Tensor): Shape (num_tokens,)
476
+ input_embeddings (torch.Tensor): Shape (num_tokens, embed_dim,)
477
+ resampled_embeddings (torch.Tensor):
478
+ Shape (bio_sequence_length, embed_dim,)
479
+ """
480
+ indices = torch.where(tokens_1d == self.seq_token_id)[0]
481
+ if indices.numel() > 0:
482
+ idx = indices[0].item()
483
+ insertion_pos = idx + resampled_embeddings_1d.shape[-2] * bio_seq_num
484
+ x = torch.cat(
485
+ [
486
+ input_embeddings_1d[:insertion_pos, :],
487
+ resampled_embeddings_1d,
488
+ input_embeddings_1d[insertion_pos:, :],
489
+ ],
490
+ dim=0,
491
+ )[: tokens_1d.shape[0] + 1, :]
492
+ x = torch.roll(torch.roll(x, shifts=-idx, dims=0), shifts=idx, dims=0)[
493
+ :-1, :
494
+ ]
495
+ tokens_1d[idx] = -1
496
+ return x, tokens_1d
497
+ else:
498
+ return (
499
+ input_embeddings_1d,
500
+ tokens_1d,
501
+ ) # Return unchanged if seq_token_id is not found
502
+
503
+ tokens_acc = []
504
+ embeddings_acc = []
505
+
506
+ for i in range(tokens.shape[0]):
507
+ embeddings_out, tokens_out = _insert(
508
+ tokens[i].clone(),
509
+ input_embeddings[i].clone(),
510
+ resampled_embeddings[i].clone(),
511
+ )
512
+ tokens_acc.append(tokens_out)
513
+ embeddings_acc.append(embeddings_out)
514
+ tokens_acc = torch.stack(tokens_acc)
515
+ embeddings_acc = torch.stack(embeddings_acc)
516
+
517
+ return embeddings_acc, tokens_acc
518
+
519
+ def cleanup_logits(
520
+ self, tokens: torch.Tensor, logits: torch.Tensor, resampled_length: int
521
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
522
+ """
523
+ Removes the logits corresponding to the unused embeddings.
524
+
525
+ Args:
526
+ tokens: Input english tokens.
527
+ logits: Input logits.
528
+
529
+ Returns:
530
+ Cleaned logits, last values will be equal to 0.
531
+ """
532
+
533
+ def _clean(
534
+ token: torch.Tensor, logit: torch.Tensor
535
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
536
+ indices = torch.where(token == self.seq_token_id)[0]
537
+ if indices.numel() > 0:
538
+ idx = indices[0].item()
539
+
540
+ mask_idx = (
541
+ torch.arange(logit.shape[0] - resampled_length, device=logit.device)
542
+ > idx
543
+ )
544
+ mask_idx = mask_idx.unsqueeze(1)
545
+
546
+ # Remove values corresponding to bio tokens
547
+ logit = (
548
+ logit[:-resampled_length] * (~mask_idx)
549
+ + logit[resampled_length:] * mask_idx
550
+ )
551
+
552
+ # Append zeros at the end
553
+ logit = torch.cat(
554
+ (
555
+ logit,
556
+ torch.zeros(
557
+ (resampled_length, logit.shape[1]),
558
+ dtype=logit.dtype,
559
+ device=logit.device,
560
+ ),
561
+ )
562
+ )
563
+
564
+ # Update token
565
+ token[idx] = -1
566
+
567
+ return logit, token
568
+
569
+ else:
570
+ return logit, token
571
+
572
+ tokens_acc = []
573
+ logits_acc = []
574
+
575
+ for i in range(tokens.shape[0]):
576
+ logits_out, tokens_out = _clean(tokens[i].clone(), logits[i].clone())
577
+ tokens_acc.append(tokens_out)
578
+ logits_acc.append(logits_out)
579
+ tokens_acc = torch.stack(tokens_acc)
580
+ logits_acc = torch.stack(logits_acc)
581
+
582
+ return logits_acc, tokens_acc
583
+
584
+
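+ # Rough sketch of what insert_embeddings / cleanup_logits do above (token ids
+ # below are illustrative): a prompt like [..., tok, SEQ, tok, ...] has its SEQ
+ # embedding replaced by `resampled_length` projected bio embeddings, shifting
+ # the following english embeddings right while keeping the total length fixed
+ # (the tail is truncated); cleanup_logits later drops the logits produced at
+ # those inserted positions and appends zeros so the output lines up again with
+ # the english token sequence.
+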
585
+ class TorchMultiOmicsModel(PreTrainedModel):
586
+ config_class = ChatNTConfig
587
+
588
+ def __init__(self, config: ChatNTConfig) -> None:
589
+ if isinstance(config, dict):
590
+ # If config is a dictionary instead of ChatNTConfig (which can happen
591
+ # depending how the config was saved), we convert it to the config
592
+ config["gpt_config"]["rope_config"] = RotaryEmbeddingConfig(
593
+ **config["gpt_config"]["rope_config"]
594
+ )
595
+ config["gpt_config"] = GptConfig(**config["gpt_config"])
596
+ config["nt_config"] = NucleotideTransformerConfig(**config["nt_config"])
597
+ config["perceiver_resampler_config"] = PerceiverResamplerConfig(
598
+ **config["perceiver_resampler_config"]
599
+ )
600
+ config = ChatNTConfig(**config) # type: ignore
601
+
602
+ else:
603
+ if isinstance(config.gpt_config, dict):
604
+ config.gpt_config["rope_config"] = RotaryEmbeddingConfig(
605
+ **config.gpt_config["rope_config"]
606
+ )
607
+ config.gpt_config = GptConfig(**config.gpt_config)
608
+
609
+ if isinstance(config.nt_config, dict):
610
+ config.nt_config = NucleotideTransformerConfig(**config.nt_config)
611
+
612
+ if isinstance(config.perceiver_resampler_config, dict):
613
+ config.perceiver_resampler_config = PerceiverResamplerConfig(
614
+ **config.perceiver_resampler_config
615
+ )
616
+
617
+ super().__init__(config=config)
618
+ self.gpt_config = config.gpt_config
619
+ self.nt_config = config.nt_config
620
+ self.perceiver_resampler_config = config.perceiver_resampler_config
621
+ self.seq_token_id = config.seq_token_id
622
+ self.bio_pad_token_id = config.bio_pad_token_id
623
+ self.english_pad_token_id = config.english_pad_token_id
624
+
625
+ # Correct seq_token_id
626
+ self.seq_token_id -= 1
627
+
628
+ self.biobrain_encoder = TorchBioBrainEncoder(nt_config=self.nt_config)
629
+ self.biobrain_decoder = TorchBioBrainDecoder(
630
+ gpt_config=self.gpt_config, seq_token_id=self.seq_token_id
631
+ )
632
+ self.projection_model = TorchMultiModalPerceiverResamplerProjection(
633
+ perceiver_resampler_config=self.perceiver_resampler_config,
634
+ input_embed_dim=self.nt_config.embed_dim,
635
+ embed_dim=self.gpt_config.embed_dim,
636
+ english_vocab_size=self.gpt_config.vocab_size,
637
+ bio_pad_token_id=self.bio_pad_token_id,
638
+ english_pad_token_id=self.english_pad_token_id,
639
+ )
640
+
641
+ def forward(
642
+ self,
643
+ multi_omics_tokens_ids: tuple[torch.Tensor, torch.Tensor],
644
+ projection_english_tokens_ids: torch.Tensor,
645
+ projected_bio_embeddings: torch.Tensor = None,
646
+ ) -> dict[str, torch.Tensor]:
647
+ """
648
+
649
+ Args:
650
+ multi_omics_tokens_ids (Tuple[torch.Tensor, torch.Tensor]):
651
+ english_tokens_ids: Represents the prompt tokens (english tokens)
652
+ Shape (batch_size, num_english_tokens)
653
+
654
+ bio_tokens_ids: Represents the bio sequences tokens
655
+ Shape (batch_size, num_bio_sequences, num_bio_tokens)
656
+
657
+ projection_english_tokens_ids (torch.Tensor):
658
+ Shape (batch_size, num_english_tokens)
659
+
660
+ projected_bio_embeddings (projected_bio_embeddings, optional):
661
+ Shape (batch_size, num_bio_sequences, ?, embed_dim).
662
+ Defaults to None.
663
+
664
+ Returns:
665
+ dict[str, torch.Tensor] containing:
666
+ - logits:
667
+ Shape (batch_size, num_tokens, vocab_size)
668
+
669
+ - projected_bio_embeddings:
670
+ Shape (batch_size, num_bio_sequences, ?, embed_dim)
671
+ """
672
+ english_token_ids, bio_token_ids = multi_omics_tokens_ids
673
+ english_token_ids = english_token_ids.clone()
674
+ bio_token_ids = bio_token_ids.clone()
675
+ projection_english_tokens_ids = projection_english_tokens_ids.clone()
676
+ if projected_bio_embeddings is not None:
677
+ projected_bio_embeddings = projected_bio_embeddings.clone()
678
+
679
+ # Remap english token ids that fall outside the GPT vocabulary.
680
+ # The default vocab size (32000) does not account for the extra
681
+ # seq_token_id (=32000) that was appended to the vocabulary.
682
+ # As a workaround, seq_token_id is remapped to 31999 and the original
683
+ # token 31999 (an unknown token) is remapped to 0, so the vocab size
684
+ # in the config does not need to change.
685
+ vocab_size = self.gpt_config.vocab_size
686
+ # Replace vocab
687
+ english_token_ids[english_token_ids == vocab_size - 1] = 0
688
+ projection_english_tokens_ids[
689
+ projection_english_tokens_ids == vocab_size - 1
690
+ ] = 0
691
+ english_token_ids[english_token_ids == vocab_size] = vocab_size - 1
692
+ projection_english_tokens_ids[projection_english_tokens_ids == vocab_size] = (
693
+ vocab_size - 1
694
+ )
695
+
696
+ if bio_token_ids is None:
697
+ projected_bio_embeddings = None
698
+ else:
699
+ num_bio_sequences = bio_token_ids.shape[1]
700
+
701
+ if projected_bio_embeddings is None:
702
+ # Compute bio sequences embeddings
703
+ bio_embeddings_list = [
704
+ self.biobrain_encoder(bio_token_ids=bio_token_ids[:, bio_seq_num])
705
+ for bio_seq_num in range(num_bio_sequences)
706
+ ]
707
+
708
+ # Project these embeddings
709
+ projected_bio_embeddings = [
710
+ self.projection_model(
711
+ bio_token_ids=bio_token_ids[:, bio_seq_num],
712
+ bio_embeddings=bio_embeddings,
713
+ english_token_ids=projection_english_tokens_ids,
714
+ )
715
+ for bio_seq_num, bio_embeddings in enumerate(bio_embeddings_list)
716
+ ]
717
+ projected_bio_embeddings = torch.stack(projected_bio_embeddings, dim=1)
718
+
719
+ # decode
720
+ logits = self.biobrain_decoder(
721
+ english_token_ids=english_token_ids,
722
+ projected_bio_embeddings=projected_bio_embeddings,
723
+ )
724
+
725
+ outs = {"logits": logits, "projected_bio_embeddings": projected_bio_embeddings}
726
+
727
+ return outs
728
+
729
+
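+ # Usage sketch for the forward pass (assumes `model` is a TorchMultiOmicsModel
+ # instance; shapes and token ids below are dummy values, and a real english
+ # prompt would contain the SEQ token id so the bio embeddings get spliced in):
+ #   english_ids = torch.randint(0, 100, (1, 256))   # (batch, num_english_tokens)
+ #   bio_ids = torch.randint(0, 4000, (1, 1, 512))   # (batch, num_bio_seqs, num_bio_tokens)
+ #   outs = model((english_ids, bio_ids), projection_english_tokens_ids=english_ids)
+ #   outs["logits"]                    # (batch, num_english_tokens, vocab_size)
+ #   outs["projected_bio_embeddings"]  # (batch, num_bio_seqs, resampled_length, decoder embed_dim)
+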
730
+ class TorchRotaryEmbedding(torch.nn.Module):
731
+ def __init__(self, config: RotaryEmbeddingConfig):
732
+ super().__init__()
733
+
734
+ self.max_seq_len = config.max_seq_len
735
+ self.dim = config.dim
736
+ self.theta = config.theta
737
+ self.sincos_cache = None
738
+
739
+ def _create_sinusoidal_positions(self, device: torch.device) -> torch.Tensor:
740
+ """
741
+ Create the sines and cosines for the RoPE.
742
+
743
+ Returns:
744
+ Sinusoidal positions of shape (self.max_seq_len, self.dim).
745
+ """
746
+ # Create the inverse frequency based on theta and dim
747
+ inv_freq = 1.0 / (
748
+ self.theta
749
+ ** (torch.arange(0, self.dim, 2, device=device).float() / self.dim)
750
+ )
751
+
752
+ # Compute sinusoidal input using the broadcasting
753
+ sinusoid_inp = torch.einsum(
754
+ "i,j->ij", torch.arange(self.max_seq_len, device=device).float(), inv_freq
755
+ )
756
+
757
+ # Apply sin and cos to the sinusoidal input
758
+ sin, cos = sinusoid_inp.sin(), sinusoid_inp.cos()
759
+
760
+ # Allocate a tensor for the final sin-cos values
761
+ sincos = torch.zeros(
762
+ (self.max_seq_len, self.dim), dtype=torch.float32, device=device
763
+ )
764
+
765
+ # Fill the sincos tensor with sin and cos values
766
+ sentinel = self.dim // 2 + self.dim % 2
767
+ sincos[:, :sentinel] = sin
768
+ sincos[:, sentinel:] = cos
769
+
770
+ return sincos
771
+
772
+ def _rotate_every_two(self, x: torch.Tensor) -> torch.Tensor:
773
+ """
774
+ Prepare a tensor to apply the RoPE mechanism.
775
+
776
+ Args:
777
+ x: Tensor of shape (batch_size, seq_len, num_heads, head_dim),
778
+ typically this is the key or query tensor.
779
+
780
+ Returns:
781
+ Tensor where each consecutive channel pair (x1, x2) is mapped to (-x2, x1).
782
+ Tensor of shape (batch_size, seq_len, num_heads, head_dim).
783
+ """
784
+ # Split the tensor into two halves (odd and even indexed dimensions)
785
+ rotate_half = torch.stack((-x[..., 1::2], x[..., ::2]), dim=-1)
786
+
787
+ # Reshape the tensor to the original shape
788
+ rotate_half = rotate_half.view(rotate_half.shape[:-2] + (-1,))
789
+ return rotate_half
790
+
791
+ def _apply_rotary_pos_emb(
792
+ self, x: torch.Tensor, sincos: torch.Tensor
793
+ ) -> torch.Tensor:
794
+ """
795
+ Applies rotary embeddings to x.
796
+
797
+ Args:
798
+ x: Tensor of shape (batch_size, seq_len, num_heads, head_dim),
799
+ typically this is the key or query tensor.
800
+ sincos: Tuple of sine and cosine tensors for position encoding.
801
+
802
+ Returns:
803
+ RoPE embeddings tensor.
804
+ """
805
+ sin_pos, cos_pos = sincos
806
+
807
+ # Reshape the sin and cos tensors for broadcasting
808
+ sin_pos = torch.repeat_interleave(sin_pos.unsqueeze(2), repeats=2, dim=-1)
809
+ cos_pos = torch.repeat_interleave(cos_pos.unsqueeze(2), repeats=2, dim=-1)
810
+
811
+ # Apply the rotary embedding mechanism
812
+ return (x * cos_pos) + (self._rotate_every_two(x) * sin_pos)
813
+
814
+ def __call__(
815
+ self, k: torch.Tensor, q: torch.Tensor, positions: Optional[torch.Tensor] = None
816
+ ) -> tuple[torch.Tensor, torch.Tensor]:
817
+ """
818
+ Applies rotary embeddings to k and q.
819
+
820
+ Args:
821
+ k: key tensor of shape (batch_size, seq_len, num_heads, head_dim),
822
+ q: query tensor of shape (batch_size, seq_len, num_heads, head_dim),
823
+ positions: optional positions offset useful when caching,
824
+
825
+ Returns:
826
+ RoPE embeddings for the keys and queries.
827
+ """
828
+ if self.sincos_cache is None:
829
+ device = k.device
830
+ self.sincos_cache = self._create_sinusoidal_positions(device=device)
831
+
832
+ batch_size, seq_len, num_heads, head_dim = k.shape
833
+
834
+ # Generate position ids
835
+ position_ids = (
836
+ torch.arange(seq_len, device=k.device).unsqueeze(0).expand(batch_size, -1)
837
+ )
838
+
839
+ if positions is not None:
840
+ position_ids += positions
841
+
842
+ # Retrieve sincos values using the position_ids
843
+ sincos = self.sincos_cache[position_ids] # type: ignore
844
+
845
+ # Split sincos into sin_pos and cos_pos
846
+ sincos = torch.chunk(sincos, 2, dim=-1)
847
+
848
+ # Apply rotary position embedding to key (k) and query (q)
849
+ k_rot = self._apply_rotary_pos_emb(k[..., : self.dim], sincos)
850
+ k_pass = k[..., self.dim :]
851
+
852
+ q_rot = self._apply_rotary_pos_emb(q[..., : self.dim], sincos)
853
+ q_pass = q[..., self.dim :]
854
+
855
+ # Concatenate the rotated and non-rotated parts
856
+ keys = torch.cat([k_rot, k_pass], dim=-1)
857
+ queries = torch.cat([q_rot, q_pass], dim=-1)
858
+
859
+ return keys, queries
860
+
861
+
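+ # Usage sketch for TorchRotaryEmbedding (dummy shapes; the sin/cos table is
+ # cached on the first call):
+ #   rope = TorchRotaryEmbedding(
+ #       RotaryEmbeddingConfig(max_seq_len=512, dim=8, theta=10000.0)
+ #   )
+ #   k = torch.randn(2, 16, 4, 8)   # (batch, seq_len, num_heads, head_dim)
+ #   q = torch.randn(2, 16, 4, 8)
+ #   k_rot, q_rot = rope(k, q)      # same shapes, rotation applied to the first `dim` channels
+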
862
+ class TorchGptGroupedQueryAttention(nn.Module):
863
+ def __init__(
864
+ self,
865
+ embed_dim: int,
866
+ num_heads: int,
867
+ rope_config: RotaryEmbeddingConfig,
868
+ num_kv_heads: int = None, # type: ignore
869
+ head_dim: int = None, # type: ignore
870
+ add_bias_attn: bool = False, # type: ignore
871
+ ) -> None:
872
+ super().__init__()
873
+ self.num_heads = num_heads
874
+ self.num_kv_heads = num_kv_heads or num_heads
875
+ self.embed_dim = embed_dim
876
+ self.head_dim = head_dim or (embed_dim // num_heads)
877
+ self.add_bias_attn = add_bias_attn
878
+ self.rope = TorchRotaryEmbedding(rope_config)
879
+
880
+ self.query_linear = nn.Linear(
881
+ embed_dim, self.num_heads * self.head_dim, bias=add_bias_attn
882
+ )
883
+ self.key_linear = nn.Linear(
884
+ embed_dim, self.num_kv_heads * self.head_dim, bias=add_bias_attn
885
+ )
886
+ self.value_linear = nn.Linear(
887
+ embed_dim, self.num_kv_heads * self.head_dim, bias=add_bias_attn
888
+ )
889
+ self.out_linear = nn.Linear(
890
+ self.num_heads * self.head_dim, embed_dim, bias=add_bias_attn
891
+ )
892
+
893
+ def forward(
894
+ self,
895
+ query_inputs: torch.Tensor,
896
+ key_inputs: torch.Tensor,
897
+ value_inputs: torch.Tensor,
898
+ attention_mask: torch.Tensor = None,
899
+ ) -> torch.Tensor:
900
+ batch_size, seq_len, _ = query_inputs.shape
901
+
902
+ queries = self.query_linear(query_inputs).view( # noqa
903
+ batch_size, seq_len, self.num_heads, self.head_dim
904
+ )
905
+ keys = self.key_linear(key_inputs).view( # noqa
906
+ batch_size, seq_len, self.num_kv_heads, self.head_dim
907
+ )
908
+ values = self.value_linear(value_inputs).view( # noqa
909
+ batch_size, seq_len, self.num_kv_heads, self.head_dim
910
+ )
911
+
912
+ keys, queries = self.rope(keys, queries)
913
+
914
+ n_rep = self.num_heads // self.num_kv_heads
915
+ keys = keys.repeat_interleave(n_rep, dim=2)
916
+ values = values.repeat_interleave(n_rep, dim=2)
917
+
918
+ attention_logits = torch.einsum("bthd,bThd->bhtT", queries, keys) / (
919
+ self.head_dim**0.5
920
+ )
921
+
922
+ if attention_mask is not None:
923
+ attention_logits = attention_logits.masked_fill(
924
+ attention_mask == 0, float("-inf")
925
+ )
926
+
927
+ attention_weights = nn.functional.softmax(attention_logits, dim=-1)
928
+
929
+ values = torch.einsum("bhtT,bThd->bthd", attention_weights, values)
930
+ values = values.contiguous().view(batch_size, seq_len, -1)
931
+
932
+ return self.out_linear(values)
933
+
934
+
935
+ class TorchGptDecoder(nn.Module):
936
+ def __init__(self, config: GptConfig, name: Optional[str] = None):
937
+ super().__init__()
938
+ self.config = config
939
+
940
+ self.token_embed = nn.Embedding(config.vocab_size, config.embed_dim)
941
+
942
+ if config.norm_type == "layer_norm":
943
+ self.final_norm = nn.LayerNorm(config.embed_dim)
944
+ elif config.norm_type == "RMS_norm":
945
+ self.final_norm = TorchRMSNorm(config.embed_dim, eps=config.rms_norm_eps)
946
+ else:
947
+ raise ValueError(f"unrecognized norm_type in config {config.norm_type}")
948
+
949
+ self.layers = nn.ModuleList(
950
+ [
951
+ TorchGptDecoderLayer(
952
+ embed_dim=config.embed_dim,
953
+ ffn_embed_dim=config.ffn_embed_dim,
954
+ num_heads=config.num_heads,
955
+ rope_config=config.rope_config,
956
+ norm_type=config.norm_type,
957
+ parallel_attention_ff=config.parallel_attention_ff,
958
+ add_bias_ffn=config.add_bias_ffn,
959
+ ffn_activation_name=config.ffn_activation_name,
960
+ use_glu_in_ffn=config.use_glu_in_ffn,
961
+ num_kv_heads=config.num_kv_heads, # type: ignore
962
+ add_bias_attn=config.add_bias_attn,
963
+ rms_norm_eps=config.rms_norm_eps,
964
+ )
965
+ for _ in range(config.num_layers)
966
+ ]
967
+ )
968
+
969
+ self.lm_head = TorchSimpleLMHead(
970
+ embed_dim=config.embed_dim,
971
+ alphabet_size=config.vocab_size,
972
+ add_bias_lm_head=config.add_bias_lm_head,
973
+ )
974
+
975
+ def apply_transformer_layers(
976
+ self, embeddings: torch.Tensor, attention_mask: torch.Tensor = None
977
+ ) -> torch.Tensor:
978
+ if attention_mask is None:
979
+ attention_mask = build_causal_attention_mask(
980
+ 1, embeddings.shape[1], device=embeddings.device
981
+ )
982
+ for layer in self.layers:
983
+ embeddings = layer(embeddings, attention_mask)
984
+
985
+ return embeddings
986
+
987
+ def forward(
988
+ self, token_ids: torch.Tensor, attention_mask: torch.Tensor = None
989
+ ) -> dict[str, torch.Tensor]:
990
+ if attention_mask is None:
991
+ attention_mask = build_causal_attention_mask(
992
+ 1, token_ids.shape[1], device=token_ids.device
993
+ )
994
+
995
+ tokens_embeddings = self.token_embed(token_ids)
996
+
997
+ after_transformer_embeddings = self.apply_transformer_layers(
998
+ tokens_embeddings, attention_mask=attention_mask
999
+ )
1000
+
1001
+ embeddings = self.final_norm(after_transformer_embeddings)
1002
+ logits = self.lm_head(embeddings)
1003
+ return {"embeddings": embeddings, "logits": logits}
1004
+
1005
+
1006
+ class TorchSimpleLMHead(nn.Module):
1007
+ def __init__(
1008
+ self, embed_dim: int, alphabet_size: int, add_bias_lm_head: bool = True
1009
+ ) -> None:
1010
+ super().__init__()
1011
+ self.fc = nn.Linear(embed_dim, alphabet_size, bias=add_bias_lm_head)
1012
+
1013
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
1014
+ return self.fc(x)
1015
+
1016
+
1017
+ class TorchGptDecoderLayer(nn.Module):
1018
+ def __init__(
1019
+ self,
1020
+ embed_dim: int,
1021
+ ffn_embed_dim: int,
1022
+ num_heads: int,
1023
+ rope_config: RotaryEmbeddingConfig,
1024
+ norm_type: str,
1025
+ parallel_attention_ff: bool,
1026
+ add_bias_ffn: bool,
1027
+ ffn_activation_name: str,
1028
+ use_glu_in_ffn: bool,
1029
+ num_kv_heads: int,
1030
+ add_bias_attn: bool,
1031
+ rms_norm_eps: float = 1e-6,
1032
+ ) -> None:
1033
+ super().__init__()
1034
+ self.num_heads = num_heads
1035
+ self.parallel_attention_ff = parallel_attention_ff
1036
+ self.use_glu_in_ffn = use_glu_in_ffn
1037
+
1038
+ # Self-Attention layer
1039
+ self.self_attn = TorchGptGroupedQueryAttention(
1040
+ embed_dim=embed_dim,
1041
+ num_heads=num_heads,
1042
+ num_kv_heads=num_kv_heads,
1043
+ rope_config=rope_config,
1044
+ add_bias_attn=add_bias_attn,
1045
+ )
1046
+
1047
+ # Normalization layers
1048
+ if norm_type == "layer_norm":
1049
+ self.attn_norm = nn.LayerNorm(embed_dim)
1050
+ if not self.parallel_attention_ff:
1051
+ self.ffn_norm = nn.LayerNorm(embed_dim)
1052
+ elif norm_type == "RMS_norm":
1053
+ self.attn_norm = TorchRMSNorm(embed_dim, eps=rms_norm_eps)
1054
+ if not self.parallel_attention_ff:
1055
+ self.ffn_norm = TorchRMSNorm(embed_dim, eps=rms_norm_eps)
1056
+ else:
1057
+ raise ValueError(f"unrecognized norm_type: {norm_type}")
1058
+
1059
+ # Feedforward network
1060
+ self.activation = get_activation_fn(ffn_activation_name)
1061
+ ffn_hidden_dim = ffn_embed_dim * (2 if use_glu_in_ffn else 1)
1062
+ self.fc1 = nn.Linear(embed_dim, ffn_hidden_dim, bias=add_bias_ffn)
1063
+ self.fc2 = nn.Linear(ffn_embed_dim, embed_dim, bias=add_bias_ffn)
1064
+
1065
+ def forward(
1066
+ self, embeddings: torch.Tensor, attention_mask: torch.Tensor
1067
+ ) -> torch.Tensor:
1068
+ residuals = embeddings
1069
+
1070
+ if self.parallel_attention_ff:
1071
+ # Parallel Attention + MLP
1072
+ embeddings_normed = self.attn_norm(embeddings)
1073
+
1074
+ attn_output = self.self_attn(
1075
+ embeddings_normed,
1076
+ embeddings_normed,
1077
+ embeddings_normed,
1078
+ attention_mask=attention_mask,
1079
+ )
1080
+ ffn_output = self.mlp(embeddings_normed) # type: ignore
1081
+
1082
+ return residuals + attn_output + ffn_output
1083
+ else:
1084
+ # Sequential Attention + MLP
1085
+ normed_embeddings = self.attn_norm(embeddings)
1086
+
1087
+ attn_output = embeddings + self.self_attn(
1088
+ normed_embeddings,
1089
+ normed_embeddings,
1090
+ normed_embeddings,
1091
+ attention_mask=attention_mask,
1092
+ )
1093
+
1094
+ normed_embeddings2 = self.ffn_norm(attn_output)
1095
+ ffn_output = self.mlp(normed_embeddings2) # type: ignore
1096
+ return attn_output + ffn_output # Residual connection
1097
+
1098
+ def mlp(self, x: torch.Tensor) -> torch.Tensor:
1099
+ """Applies the feedforward network (MLP) with optional GLU."""
1100
+ ffn_output = self.fc1(x)
1101
+
1102
+ if self.use_glu_in_ffn:
1103
+ ffn_output1, ffn_output2 = ffn_output.chunk(2, dim=-1)
1104
+ ffn_output = self.activation(ffn_output1) * ffn_output2
1105
+ else:
1106
+ ffn_output = self.activation(ffn_output)
1107
+
1108
+ return self.fc2(ffn_output)
1109
+
1110
+
1111
+ class TorchRMSNorm(nn.Module):
1112
+ def __init__(self, dim: int, eps: float = 1e-6) -> None:
1113
+ super().__init__()
1114
+ self.eps = eps
1115
+ self.scale = nn.Parameter(torch.ones(dim))
1116
+
1117
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
1118
+ return (
1119
+ x
1120
+ * self.scale
1121
+ / torch.sqrt(torch.mean(x**2, dim=-1, keepdim=True) + self.eps)
1122
+ )
1123
+
1124
+
1125
+ def get_activation_fn(activation_name: str): # type: ignore
1126
+ activations = {
1127
+ "gelu": nn.functional.gelu,
1128
+ "relu": nn.functional.relu,
1129
+ "swish": nn.functional.silu,
1130
+ "silu": nn.functional.silu,
1131
+ }
1132
+ return activations.get(activation_name, nn.functional.relu)
1133
+
1134
+
1135
+ def build_causal_attention_mask(
1136
+ batch_size: int, seq_len: int, device: torch.device
1137
+ ) -> torch.Tensor:
1138
+ """
1139
+ Builds a batch of causal masks of shape (batch_size, 1, seq_len, seq_len) to feed
1140
+ to an attention layer.
1141
+
1142
+ Args:
1143
+ batch_size: Batch size.
1144
+ seq_len: Length of the sequences.
1145
+
1146
+ Returns:
1147
+ Batch of causal masks.
1148
+ """
1149
+ mask = torch.ones((batch_size, 1, seq_len, seq_len), device=device)
1150
+ causal_mask = torch.tril(mask)
1151
+ return causal_mask
1152
+
1153
+
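+ # Example (illustrative): build_causal_attention_mask(1, 4, device) returns a
+ # lower-triangular mask of shape (1, 1, 4, 4):
+ #   [[1, 0, 0, 0],
+ #    [1, 1, 0, 0],
+ #    [1, 1, 1, 0],
+ #    [1, 1, 1, 1]]
+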
1154
+ @dataclass
1155
+ class RotaryEmbeddingConfigBis:
1156
+ """
1157
+ Parameters to initialize the RotaryEmbedding layer. The rescaling factor allows
1158
+ to adapt the rotary embeddings to larger lengths than what was used for training.
1159
+ One such strategy is presented in the YaRN paper: https://arxiv.org/pdf/2309.00071.pdf. # noqa
1160
+ Args:
+ rescaling_factor: Optional factor used to rescale the rotary embedding
+ frequencies for sequences longer than those seen during training.
1161
+ """
1162
+
1163
+ rescaling_factor: Optional[float]
1164
+
1165
+
1166
+ class RotaryEmbeddingBis(torch.nn.Module):
1167
+ """
1168
+ Rotary position embeddings based on those in
1169
+ [RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer).
1170
+ Query and keys are transformed by rotation
1171
+ matrices which depend on their relative positions.
1172
+ """
1173
+
1174
+ def __init__(self, dim: int, rotary_embedding_config: RotaryEmbeddingConfigBis):
1175
+ super().__init__()
1176
+
1177
+ # Extract argument from the config
1178
+ self.rescaling_factor = rotary_embedding_config.rescaling_factor
1179
+ self.upper_freq = 10000
1180
+ self.dim = dim
1181
+
1182
+ self._seq_len_cached = None
1183
+ self._cos_cached = None
1184
+ self._sin_cached = None
1185
+
1186
+ def _apply_rotary_pos_emb(
1187
+ self,
1188
+ heads: torch.Tensor,
1189
+ cos: torch.Tensor,
1190
+ sin: torch.Tensor,
1191
+ ) -> torch.Tensor:
1192
+ """ """
1193
+ x_first, x_second = (
1194
+ heads[..., : heads.shape[-1] // 2],
1195
+ heads[..., heads.shape[-1] // 2 :],
1196
+ )
1197
+
1198
+ first_part = x_first * cos - x_second * sin
1199
+ second_part = x_second * cos + x_first * sin
1200
+
1201
+ return torch.cat((first_part, second_part), dim=-1)
1202
+
1203
+ def _compute_cos_sin_tables(
1204
+ self, x: torch.Tensor, inv_freq: torch.Tensor, seq_dimension: int = 2
1205
+ ) -> tuple[torch.Tensor, torch.Tensor]:
1206
+ seq_len = x.shape[seq_dimension]
1207
+ # Recompute the cos/sin tables for the current sequence length and device
1208
+ # (they are rebuilt on every call; no cache-invalidation check is done here).
1209
+ self._seq_len_cached = seq_len
1210
+ t = torch.arange(x.shape[seq_dimension], device=x.device).type_as(inv_freq)
1211
+ # freqs = torch.outer(t, inv_freq)
1212
+ freqs = torch.einsum("i, j -> ij", t, inv_freq)
1213
+
1214
+ self._cos_cached = torch.cos(freqs)[None, :, None, :]
1215
+ self._sin_cached = torch.sin(freqs)[None, :, None, :]
1216
+ # emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
1217
+
1218
+ # self._cos_cached = emb.cos()[None, None, :, :]
1219
+ # self._sin_cached = emb.sin()[None, None, :, :]
1220
+
1221
+ return self._cos_cached, self._sin_cached
1222
+
1223
+ def forward(
1224
+ self, q: torch.Tensor, k: torch.Tensor
1225
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
1226
+ if self.rescaling_factor is None:
1227
+ inv_freq = 1.0 / (
1228
+ self.upper_freq
1229
+ ** (torch.arange(0, self.dim, 2, device=q.device).float() / self.dim)
1230
+ )
1231
+ else:
1232
+ updated_base = self.upper_freq * (
1233
+ self.rescaling_factor ** (self.dim / (self.dim - 2))
1234
+ )
1235
+ inv_freq = 1.0 / (
1236
+ updated_base
1237
+ ** (torch.arange(0, self.dim, 2, device=q.device).float() / self.dim)
1238
+ )
1239
+
1240
+ self._cos_cached, self._sin_cached = self._compute_cos_sin_tables(
1241
+ q,
1242
+ inv_freq,
1243
+ seq_dimension=-3,
1244
+ )
1245
+
1246
+ return (
1247
+ self._apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached),
1248
+ self._apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached),
1249
+ )
1250
+
1251
+
1252
+ class MultiHeadAttention(nn.Module):
1253
+ def __init__(
1254
+ self,
1255
+ num_heads: int,
1256
+ key_size: int,
1257
+ rotary_embedding_config: Optional[RotaryEmbeddingConfigBis] = None,
1258
+ add_bias_kv: bool = False,
1259
+ value_size: Optional[int] = None,
1260
+ model_size: Optional[int] = None,
1261
+ name: Optional[str] = None,
1262
+ ):
1263
+ super().__init__()
1264
+ if not model_size:
1265
+ model_size = key_size * num_heads
1266
+ if not value_size:
1267
+ value_size = key_size
1268
+ self.model_size = model_size
1269
+ self.key_size = key_size
1270
+ self.value_size = value_size
1271
+ self.add_bias_kv = add_bias_kv
1272
+ self.name = name
1273
+ self.num_heads = num_heads
1274
+ self._rotary_embedding_config = rotary_embedding_config
1275
+
1276
+ self.w_k = nn.Linear(self.model_size, self.num_heads * self.key_size)
1277
+ self.w_q = nn.Linear(self.model_size, self.num_heads * self.key_size)
1278
+ self.w_v = nn.Linear(self.model_size, self.num_heads * self.value_size)
1279
+ self.output = nn.Linear(self.num_heads * self.value_size, self.model_size)
1280
+ if self._rotary_embedding_config:
1281
+ self._rotary_embedding = RotaryEmbeddingBis(
1282
+ self.key_size, self._rotary_embedding_config
1283
+ )
1284
+
1285
+ def apply_rotary_embeddings(
1286
+ self,
1287
+ query: torch.Tensor,
1288
+ key: torch.Tensor,
1289
+ ) -> tuple[torch.Tensor, torch.Tensor]:
1290
+ """ """
1291
+ query, key = self._rotary_embedding(query, key)
1292
+ return query, key
1293
+
1294
+ def forward(
1295
+ self,
1296
+ query: torch.Tensor,
1297
+ key: torch.Tensor,
1298
+ value: torch.Tensor,
1299
+ attention_mask: Optional[torch.Tensor] = None,
1300
+ attention_weight_bias: Optional[torch.Tensor] = None,
1301
+ ) -> dict[str, torch.Tensor]:
1302
+ """
1303
+ Returns:
1304
+ dictionary containing attention weights
1305
+ and outputs.
1306
+ """
1307
+ key_heads = self.w_k(key).reshape(
1308
+ (*key.shape[:-1], self.num_heads, self.key_size)
1309
+ )
1310
+ query_heads = self.w_q(query).reshape(
1311
+ (*query.shape[:-1], self.num_heads, self.key_size)
1312
+ )
1313
+ value_heads = self.w_v(value).reshape(
1314
+ (*value.shape[:-1], self.num_heads, self.value_size)
1315
+ )
1316
+ if self._rotary_embedding_config:
1317
+ query_heads, key_heads = self.apply_rotary_embeddings(
1318
+ query_heads, key_heads
1319
+ )
1320
+ attention_weights = torch.einsum(
1321
+ "...thd, ...Thd -> ...htT", query_heads, key_heads
1322
+ )
1323
+ sqrt_key_size = np.sqrt(self.key_size)
1324
+ attention_weights = attention_weights / sqrt_key_size
1325
+ if attention_mask is not None:
1326
+ attention_weights = torch.where(attention_mask, attention_weights, -1e30)
1327
+ if attention_weight_bias is not None:
1328
+ attention_weights = F.softmax(
1329
+ attention_weights + attention_weight_bias, dim=-1
1330
+ )
1331
+ else:
1332
+ attention_weights = F.softmax(attention_weights, dim=-1)
1333
+ value_out = torch.einsum(
1334
+ "...htT, ...Thd->...thd", attention_weights, value_heads
1335
+ )
1336
+ value_out = value_out.reshape((*value_out.shape[:-2], -1))
1337
+ embeddings = self.output(value_out)
1338
+
1339
+ return {"attention_weights": attention_weights, "embeddings": embeddings}
1340
+
1341
+
1342
+ class SelfAttentionBlock(nn.Module):
1343
+ def __init__(
1344
+ self,
1345
+ num_heads: int,
1346
+ embed_dim: int,
1347
+ ffn_embed_dim: int,
1348
+ key_size: Optional[int] = None,
1349
+ add_bias_kv: bool = False,
1350
+ add_bias_fnn: bool = True,
1351
+ ffn_activation_name: str = "gelu-no-approx",
1352
+ use_glu_in_ffn: bool = False,
1353
+ layer_norm_eps: float = 1e-5, # this is the default haiku value
1354
+ pre_layer_norm: bool = True,
1355
+ name: Optional[str] = None,
1356
+ rotary_embedding_config: Optional[RotaryEmbeddingConfigBis] = None,
1357
+ ):
1358
+ super().__init__()
1359
+ if key_size is None:
1360
+ if embed_dim % num_heads != 0:
1361
+ raise ValueError(
1362
+ f"The embedding dimension should be divisible by the number of "
1363
+ f"heads, however provided embedding dimension is {embed_dim} and "
1364
+ f"the number of heads is {num_heads}."
1365
+ )
1366
+ else:
1367
+ key_size = embed_dim // num_heads
1368
+
1369
+ # Get ffn activation function
1370
+ self._pre_layer_norm = pre_layer_norm
1371
+ self._use_glu_in_fnn = use_glu_in_ffn
1372
+ # Define layers
1373
+ if use_glu_in_ffn:
1374
+ # user should multiply ffn_embed_dim by 2/3 when using GLU
1375
+ # to keep total number of parameters equal
1376
+ # see https://arxiv.org/pdf/2002.05202.pdf for more details
1377
+ # we multiply by 2 here as the output will be split in 2 for GLU
1378
+ self.fc1 = nn.Linear(embed_dim, int(2 * ffn_embed_dim), bias=add_bias_fnn)
1379
+ else:
1380
+ self.fc1 = nn.Linear(embed_dim, ffn_embed_dim, bias=add_bias_fnn)
1381
+
1382
+ self.fc2 = nn.Linear(ffn_embed_dim, embed_dim, bias=add_bias_fnn)
1383
+
1384
+ self.layer_norm_self_attention = nn.LayerNorm(
1385
+ embed_dim, eps=layer_norm_eps
1386
+ )
1387
+ self.layer_norm_mlp = nn.LayerNorm(embed_dim, eps=layer_norm_eps)
1388
+ if ffn_activation_name == "swish":
1389
+ self._ffn_activation_fn = nn.SiLU()
1390
+ elif ffn_activation_name == "gelu-no-approx":
1391
+ self._ffn_activation_fn = nn.GELU(approximate="tanh")
1392
+ else:
1393
+ self._ffn_activation_fn = getattr(torch.nn, ffn_activation_name)
1394
+
1395
+ self.mha = MultiHeadAttention(
1396
+ num_heads=num_heads,
1397
+ key_size=key_size,
1398
+ add_bias_kv=add_bias_kv,
1399
+ model_size=embed_dim,
1400
+ name="self_attention",
1401
+ rotary_embedding_config=rotary_embedding_config,
1402
+ )
1403
+
1404
+ def mlp(self, embed: torch.Tensor) -> torch.Tensor:
1405
+
1406
+ if self._pre_layer_norm:
1407
+ x = self.layer_norm_mlp(embed)
1408
+ else:
1409
+ x = embed
1410
+
1411
+ if self._use_glu_in_fnn:
1412
+ x = self.fc1(x)
1413
+ x1, x2 = torch.split(x, split_size_or_sections=x.shape[-1] // 2, dim=-1)
1414
+ x = self._ffn_activation_fn(x1) * x2
1415
+ else:
1416
+ x = self._ffn_activation_fn(self.fc1(x))
1417
+ x = self.fc2(x)
1418
+
1419
+ if not self._pre_layer_norm:
1420
+ x = self.layer_norm_mlp(x + embed)
1421
+ return x
1422
+
1423
+ def forward(
1424
+ self,
1425
+ x: torch.Tensor,
1426
+ attention_mask: Optional[torch.Tensor] = None,
1427
+ attention_weight_bias: Optional[torch.Tensor] = None,
1428
+ ) -> dict[str, torch.Tensor]:
1429
+
1430
+ res = x
1431
+ if self._pre_layer_norm:
1432
+ x = self.layer_norm_self_attention(x)
1433
+
1434
+ output: dict[str, torch.Tensor] = self.mha(
1435
+ x,
1436
+ x,
1437
+ x,
1438
+ attention_mask=attention_mask,
1439
+ attention_weight_bias=attention_weight_bias,
1440
+ )
1441
+
1442
+ if not self._pre_layer_norm:
1443
+ output["embeddings"] = self.layer_norm_self_attention(
1444
+ output["embeddings"] + res
1445
+ )
1446
+
1447
+ x = output["embeddings"]
1448
+ else:
1449
+ x = output["embeddings"]
1450
+ x = res + x
1451
+
1452
+ # MLP
1453
+ if not self._pre_layer_norm:
1454
+ x = self.mlp(x)
1455
+ else:
1456
+ x = x + self.mlp(x)
1457
+
1458
+ output["embeddings"] = x
1459
+ return output
1460
+
1461
+
1462
+ class RobertaLMHead(nn.Module):
1463
+ """
1464
+ Roberta Language Model head. Transforms final attention layer output into a
1465
+ distribution over tokens at each position.
1466
+ """
1467
+
1468
+ def __init__(self, embed_dim: int, alphabet_size: int):
1469
+ """
1470
+ Args:
1471
+ embed_dim: Embedding dimension.
1472
+ alphabet_size: Number of tokens in the alphabet.
1473
+ """
1474
+ super().__init__()
1475
+ self.embed_dim = embed_dim
1476
+ self.alphabet_size = alphabet_size
1477
+
1478
+ # Define layers
1479
+ self._first_layer_norm = nn.LayerNorm(embed_dim, elementwise_affine=True)
1480
+ self._fc1 = nn.Linear(embed_dim, embed_dim)
1481
+ self._second_layer_norm = nn.LayerNorm(embed_dim, elementwise_affine=True)
1482
+ self._final_fc = nn.Linear(embed_dim, alphabet_size)
1483
+
1484
+ def forward(self, x: torch.Tensor) -> dict:
1485
+ x = self._first_layer_norm(x)
1486
+ embeddings = x
1487
+ x = self._fc1(x)
1488
+ x = nn.functional.gelu(x)
1489
+ x = self._second_layer_norm(x)
1490
+ logits = self._final_fc(x)
1491
+ return {"embeddings": embeddings, "logits": logits}
1492
+
1493
+
1494
+ class TorchNucleotideTransformer(nn.Module):
1495
+ def __init__(
1496
+ self,
1497
+ nt_config: NucleotideTransformerConfig,
1498
+ ):
1499
+ super(TorchNucleotideTransformer, self).__init__()
1500
+ self.nt_config = nt_config
1501
+
1502
+ # Other cases are not implemented
1503
+ assert nt_config.positional_embedding is None
1504
+ assert nt_config.lm_head == "roberta"
1505
+ assert nt_config.use_rotary_embedding is True
1506
+ assert nt_config.token_dropout is False
1507
+ assert nt_config.emb_layer_norm_before is False
1508
+ assert nt_config.mask_before_attention is False
1509
+ assert nt_config.bias_word_embedding is False
1510
+ assert nt_config.use_gradient_checkpointing is False
1511
+
1512
+ self.embed_layer = nn.Embedding(nt_config.alphabet_size, nt_config.embed_dim)
1513
+
1514
+ self.lm_head = RobertaLMHead(
1515
+ embed_dim=nt_config.embed_dim,
1516
+ alphabet_size=nt_config.alphabet_size,
1517
+ )
1518
+
1519
+ self.rotary_embedding_config = RotaryEmbeddingConfigBis(
1520
+ rescaling_factor=nt_config.rescaling_factor
1521
+ )
1522
+
1523
+ self.attention_blocks = nn.ModuleList(
1524
+ [
1525
+ SelfAttentionBlock( # type: ignore
1526
+ num_heads=nt_config.attention_heads,
1527
+ embed_dim=nt_config.embed_dim,
1528
+ key_size=nt_config.key_size,
1529
+ ffn_embed_dim=nt_config.ffn_embed_dim,
1530
+ add_bias_kv=nt_config.add_bias_kv,
1531
+ add_bias_fnn=nt_config.add_bias_ffn,
1532
+ ffn_activation_name=nt_config.ffn_activation_name,
1533
+ use_glu_in_ffn=nt_config.use_glu_in_ffn,
1534
+ rotary_embedding_config=self.rotary_embedding_config,
1535
+ layer_norm_eps=nt_config.layer_norm_eps,
1536
+ pre_layer_norm=nt_config.pre_layer_norm,
1537
+ )
1538
+ for _ in range(nt_config.num_layers)
1539
+ ]
1540
+ )
1541
+
1542
+ def forward(
1543
+ self, tokens: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
1544
+ ) -> torch.Tensor:
1545
+ """
1546
+ Computes the embeddings based on the input tokens.
1547
+
1548
+ Args:
1549
+ tokens: Input tokens out of the tokenizer of shape (batch_size, seq_len).
1550
+ attention_mask: Attention mask of shape (batch_size, 1, seq_len, seq_len).
1551
+ If no mask is provided, a default mask is computed, equal to 1 over all
1552
+ non-pad tokens and 0 over pad tokens.
1553
+
1554
+ Returns:
1555
+ Tensor of final embeddings after the LM head, of shape (batch_size, seq_len, embed_dim).
1556
+ """
1557
+ x = self.embed_layer(tokens)
1558
+
1559
+ # RoBERTa-style embedding scaling factor
1560
+ x = self.nt_config.embed_scale * x
1561
+
1562
+ if attention_mask is None:
1563
+ attention_mask = build_padding_attention_mask(
1564
+ tokens=tokens, pad_token_id=self.nt_config.pad_token_id
1565
+ )
1566
+
1567
+ for layer in self.attention_blocks:
1568
+ x = layer(x, attention_mask)["embeddings"]
1569
+
1570
+ assert self.nt_config.lm_head == "roberta"
1571
+ x = self.lm_head(x)["embeddings"]
1572
+
1573
+ return x
1574
+
1575
+
1576
+ def build_padding_attention_mask(
1577
+ tokens: torch.Tensor, pad_token_id: int
1578
+ ) -> torch.Tensor:
1579
+ """
1580
+ Builds a padding mask from a sequence of tokens by masking <pad> in the attention.
1581
+
1582
+ Args:
1583
+ tokens: Batch of sequences of shape (batch_size, seq_len).
1584
+ pad_token_id: Int corresponding to the <pad> token to mask.
1585
+
1586
+ Returns:
1587
+ Batch of attention masks, masking out <pad> tokens.
1588
+ """
1589
+ padding_mask = tokens != pad_token_id
1590
+ padding_mask = padding_mask.unsqueeze(1)
1591
+ padding_mask = torch.einsum("bhT, bht -> bhtT", padding_mask, padding_mask)
1592
+ return padding_mask
1593
+
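+ # A minimal sketch of the mask built by build_padding_attention_mask above (the token
+ # ids are illustrative): the result is True only where both the query and the key
+ # positions are non-pad tokens.
+ #
+ #     tokens = torch.tensor([[5, 6, 1]])           # pad_token_id = 1
+ #     mask = build_padding_attention_mask(tokens, pad_token_id=1)
+ #     mask.shape                                   # torch.Size([1, 1, 3, 3])
+ #     mask[0, 0]                                   # True on the 2x2 top-left block, False elsewhere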
1594
+
1595
+ class TorchBioBrainEncoder(nn.Module):
1596
+ def __init__(
1597
+ self,
1598
+ nt_config: NucleotideTransformerConfig,
1599
+ ):
1600
+ super(TorchBioBrainEncoder, self).__init__()
1601
+ self.nt_config = nt_config
1602
+ self.nt_model = TorchNucleotideTransformer(self.nt_config)
1603
+
1604
+ def forward(
1605
+ self,
1606
+ bio_token_ids: torch.Tensor,
1607
+ ) -> torch.Tensor:
1608
+ """
1609
+ Args:
1610
+ bio_token_ids (torch.Tensor):
1611
+ Shape (batch_size, num_bio_tokens)
1612
+
1613
+ Returns:
1614
+ torch.Tensor:
1615
+ Shape (batch_size, num_bio_tokens, embed_dim)
1616
+ """
1617
+ bio_embeddings = self.nt_model(tokens=bio_token_ids)
1618
+
1619
+ return bio_embeddings
1620
+
1621
+
1622
+ class TorchMultiModalPerceiverResamplerBlock(nn.Module):
1623
+ def __init__(
1624
+ self,
1625
+ num_heads: int,
1626
+ embed_dim: int,
1627
+ ffn_embed_dim: int,
1628
+ key_size: Optional[int] = None,
1629
+ add_bias_kv: bool = False,
1630
+ add_bias_ffn: bool = True,
1631
+ ffn_activation_name: str = "gelu",
1632
+ use_glu_in_ffn: bool = False,
1633
+ ):
1634
+ super().__init__()
1635
+
1636
+ if key_size is None:
1637
+ if embed_dim % num_heads != 0:
1638
+ raise ValueError(
1639
+ f"Embedding dimension {embed_dim} should be divisible by "
1640
+ f"num_heads {num_heads}."
1641
+ )
1642
+ key_size = embed_dim // num_heads
1643
+
1644
+ self.num_heads = num_heads
1645
+ self.embed_dim = embed_dim
1646
+ self.ffn_embed_dim = ffn_embed_dim * 2 if use_glu_in_ffn else ffn_embed_dim
1647
+ self.use_glu_in_ffn = use_glu_in_ffn
1648
+
1649
+ self.cross_attention_1 = MultiHeadAttention(
1650
+ num_heads=num_heads, key_size=key_size, add_bias_kv=add_bias_kv
1651
+ )
1652
+ self.cross_attention_2 = MultiHeadAttention(
1653
+ num_heads=num_heads, key_size=key_size, add_bias_kv=add_bias_kv
1654
+ )
1655
+
1656
+ self.norm_cross_attention_1 = nn.LayerNorm(embed_dim)
1657
+ self.norm_cross_attention_2 = nn.LayerNorm(embed_dim)
1658
+ self.norm_mlp = nn.LayerNorm(embed_dim)
1659
+
1660
+ self.fc1 = nn.Linear(embed_dim, self.ffn_embed_dim, bias=add_bias_ffn)
1661
+ self.fc2 = nn.Linear(self.ffn_embed_dim, embed_dim, bias=add_bias_ffn)
1662
+
1663
+ self.activation_fn = getattr(
1664
+ nn.functional, ffn_activation_name, nn.functional.gelu
1665
+ )
1666
+
1667
+ def mlp(self, x: torch.Tensor) -> torch.Tensor:
1668
+ x = self.norm_mlp(x)
1669
+ if self.use_glu_in_ffn:
1670
+ x1, x2 = torch.chunk(self.fc1(x), 2, dim=-1)
1671
+ x = self.activation_fn(x1) * x2
1672
+ else:
1673
+ x = self.activation_fn(self.fc1(x))
1674
+ return self.fc2(x)
1675
+
1676
+ def forward(
1677
+ self,
1678
+ x: torch.Tensor,
1679
+ cross_attention_embeddings_1: torch.Tensor,
1680
+ cross_attention_embeddings_2: torch.Tensor,
1681
+ attention_mask_1: Optional[torch.Tensor] = None,
1682
+ attention_mask_2: Optional[torch.Tensor] = None,
1683
+ ) -> Dict[str, torch.Tensor]:
1684
+ res = x
1685
+ x = self.norm_cross_attention_1(x)
1686
+
1687
+ attn_output = self.cross_attention_1(
1688
+ query=x,
1689
+ key=cross_attention_embeddings_1,
1690
+ value=cross_attention_embeddings_1,
1691
+ attention_mask=attention_mask_1,
1692
+ )["embeddings"]
1693
+ x = res + attn_output
1694
+
1695
+ res = x
1696
+ x = self.norm_cross_attention_2(x)
1697
+ attn_output = self.cross_attention_2(
1698
+ query=x,
1699
+ key=cross_attention_embeddings_2,
1700
+ value=cross_attention_embeddings_2,
1701
+ attention_mask=attention_mask_2,
1702
+ )["embeddings"]
1703
+ x = res + attn_output
1704
+
1705
+ x = x + self.mlp(x)
1706
+
1707
+ return {"embeddings": x}
1708
+
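+ # A minimal shape sketch for TorchMultiModalPerceiverResamplerBlock above (all sizes are
+ # assumptions for illustration): the latent queries keep their own length while cross-attending
+ # first to one modality and then to the other, followed by the feed-forward update. In the full
+ # resampler below, the keys/values are the concatenation of each modality with the latents.
+ #
+ #     block = TorchMultiModalPerceiverResamplerBlock(num_heads=2, embed_dim=16, ffn_embed_dim=32)
+ #     latents = torch.randn(1, 4, 16)              # (batch, resampled_length, embed_dim)
+ #     bio = torch.randn(1, 7, 16)                  # modality 1 keys/values
+ #     eng = torch.randn(1, 5, 16)                  # modality 2 keys/values
+ #     block(latents, bio, eng)["embeddings"].shape # torch.Size([1, 4, 16])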
1709
+
1710
+ class TorchMultiModalPerceiverResampler(nn.Module):
1711
+ """
1712
+ Perceiver Resampler model, made of successive PerceiverResamplerBlocks.
1713
+ """
1714
+
1715
+ def __init__(
1716
+ self,
1717
+ config: PerceiverResamplerConfig,
1718
+ name: Optional[str] = None,
1719
+ ):
1720
+ """
1721
+ Initialize a Perceiver Resampler model.
1722
+
1723
+ Args:
1724
+ config: Dataclass containing model hyperparameters.
1725
+ name: Name for the module (a custom name will break weight loading).
1726
+ """
1727
+ super().__init__()
1728
+ self.config = config
1729
+ self.name = name
1730
+ self.layers = nn.ModuleList(
1731
+ [
1732
+ TorchMultiModalPerceiverResamplerBlock(
1733
+ num_heads=self.config.attention_heads,
1734
+ embed_dim=self.config.embed_dim,
1735
+ key_size=self.config.key_size,
1736
+ ffn_embed_dim=self.config.ffn_embed_dim,
1737
+ add_bias_kv=self.config.add_bias_kv,
1738
+ add_bias_ffn=self.config.add_bias_ffn,
1739
+ ffn_activation_name=self.config.ffn_activation_name,
1740
+ use_glu_in_ffn=self.config.use_glu_in_ffn,
1741
+ )
1742
+ for _ in range(self.config.num_layers)
1743
+ ]
1744
+ )
1745
+
1746
+ self.latent_queries = torch.nn.Parameter(
1747
+ torch.randn(self.config.resampled_length, self.config.embed_dim)
1748
+ * (
1749
+ 1.0
1750
+ / torch.sqrt(torch.tensor(self.config.embed_dim, dtype=torch.float32))
1751
+ )
1752
+ )
1753
+
1754
+ def apply_attention_blocks(
1755
+ self,
1756
+ x: torch.Tensor,
1757
+ xf_1: torch.Tensor,
1758
+ xf_2: torch.Tensor,
1759
+ outs: Dict[str, torch.Tensor],
1760
+ attention_mask_1: Optional[torch.Tensor] = None,
1761
+ attention_mask_2: Optional[torch.Tensor] = None,
1762
+ ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
1763
+ """
1764
+ Applies each attention block to the latents and the two input modalities.
1765
+ """
1766
+ for layer in self.layers:
1767
+ concat_input_1 = torch.cat([xf_1, x], dim=1)
1768
+ concat_input_2 = torch.cat([xf_2, x], dim=1)
1769
+
1770
+ output = layer(
1771
+ x=x,
1772
+ cross_attention_embeddings_1=concat_input_1,
1773
+ cross_attention_embeddings_2=concat_input_2,
1774
+ attention_mask_1=attention_mask_1,
1775
+ attention_mask_2=attention_mask_2,
1776
+ )
1777
+ x = output["embeddings"]
1778
+
1779
+ return x, outs
1780
+
1781
+ def forward(
1782
+ self,
1783
+ input_embeddings_1: torch.Tensor,
1784
+ input_embeddings_2: torch.Tensor,
1785
+ attention_mask_1: Optional[torch.Tensor] = None,
1786
+ attention_mask_2: Optional[torch.Tensor] = None,
1787
+ ) -> Dict[str, torch.Tensor]:
1788
+ """
1789
+ Computes the embeddings based on the input tokens.
1790
+ """
1791
+ assert (
1792
+ input_embeddings_1.shape[-1] == self.config.embed_dim
1793
+ ), "The input embedding dim should match the model embed dim"
1794
+ assert (
1795
+ input_embeddings_2.shape[-1] == self.config.embed_dim
1796
+ ), "The input embedding dim should match the model embed dim"
1797
+
1798
+ batch_size = input_embeddings_1.shape[0]
1799
+
1800
+ latent_queries = self.latent_queries.unsqueeze(0).repeat(batch_size, 1, 1)
1801
+
1802
+ outs: Dict[str, torch.Tensor] = {}
1803
+ x = latent_queries
1804
+
1805
+ x, outs = self.apply_attention_blocks(
1806
+ x=x,
1807
+ xf_1=input_embeddings_1,
1808
+ xf_2=input_embeddings_2,
1809
+ outs=outs,
1810
+ attention_mask_1=attention_mask_1,
1811
+ attention_mask_2=attention_mask_2,
1812
+ )
1813
+
1814
+ outs["embeddings"] = x
1815
+
1816
+ return outs
1817
+
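+ # A minimal shape sketch for the full resampler above (the config values named here are
+ # assumptions for illustration): whatever the two input sequence lengths are, the output
+ # always has resampled_length positions, because the latent queries fix the query length.
+ #
+ #     config = PerceiverResamplerConfig(embed_dim=16, resampled_length=4, ...)  # hypothetical values
+ #     resampler = TorchMultiModalPerceiverResampler(config)
+ #     out = resampler(torch.randn(1, 100, 16), torch.randn(1, 30, 16))["embeddings"]
+ #     out.shape                                    # torch.Size([1, 4, 16])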
1818
+
1819
+ class TorchMultiModalPerceiverResamplerProjection(nn.Module):
1820
+ def __init__(
1821
+ self,
1822
+ perceiver_resampler_config: PerceiverResamplerConfig,
1823
+ input_embed_dim: int,
1824
+ embed_dim: int,
1825
+ bio_pad_token_id: int,
1826
+ english_pad_token_id: int,
1827
+ english_vocab_size: int,
1828
+ ):
1829
+ super().__init__()
1830
+ self.config = perceiver_resampler_config
1831
+ self.input_embed_dim = input_embed_dim
1832
+ self.embed_dim = embed_dim
1833
+ self.bio_pad_token_id = bio_pad_token_id
1834
+ self.english_pad_token_id = english_pad_token_id
1835
+ self.english_vocab_size = english_vocab_size
1836
+
1837
+ self.bio_projection = nn.Linear(input_embed_dim, embed_dim)
1838
+ self.token_embedding = nn.Embedding(english_vocab_size, embed_dim)
1839
+ self.perceiver_resampler = TorchMultiModalPerceiverResampler(config=self.config)
1840
+
1841
+ def forward(
1842
+ self,
1843
+ bio_token_ids: torch.Tensor,
1844
+ bio_embeddings: torch.Tensor,
1845
+ english_token_ids: torch.Tensor,
1846
+ ) -> torch.Tensor:
1847
+ """
1848
+ Args:
1849
+ bio_token_ids (torch.Tensor):
1850
+ Shape (batch_size, num_bio_tokens)
1851
+
1852
+ bio_embeddings (torch.Tensor):
1853
+ Shape (batch_size, num_bio_tokens, embed_dim)
1854
+
1855
+ english_token_ids (torch.Tensor):
1856
+ Shape (batch_size, num_english_tokens)
1857
+ """
1858
+ projected_bio_embeddings = self.bio_projection(bio_embeddings)
1859
+ english_embeddings = self.token_embedding(english_token_ids)
1860
+
1861
+ bio_attention_mask = build_perceiver_padding_attention_mask(
1862
+ bio_token_ids, self.config.resampled_length, self.bio_pad_token_id
1863
+ )
1864
+ english_attention_mask = build_perceiver_padding_attention_mask(
1865
+ english_token_ids, self.config.resampled_length, self.english_pad_token_id
1866
+ )
1867
+
1868
+ projected_embeddings = self.perceiver_resampler(
1869
+ input_embeddings_1=projected_bio_embeddings,
1870
+ attention_mask_1=bio_attention_mask,
1871
+ input_embeddings_2=english_embeddings,
1872
+ attention_mask_2=english_attention_mask,
1873
+ )["embeddings"]
1874
+
1875
+ return projected_embeddings
1876
+
1877
+
1878
+ def build_perceiver_padding_attention_mask(
1879
+ tokens: torch.Tensor, resampled_length: int, pad_token_id: int
1880
+ ) -> torch.Tensor:
1881
+ batch_size, seq_len = tokens.shape
1882
+ padding_mask = tokens != pad_token_id # (batch_size, seq_len)
1883
+
1884
+ padding_mask = torch.cat(
1885
+ [
1886
+ padding_mask,
1887
+ torch.ones(
1888
+ (batch_size, resampled_length), dtype=torch.bool, device=tokens.device
1889
+ ),
1890
+ ],
1891
+ dim=1,
1892
+ ) # (batch_size, seq_len + resampled_length)
1893
+
1894
+ padding_mask = padding_mask[:, None, None, :]
1895
+ padding_mask = padding_mask.repeat(1, 1, resampled_length, 1) # noqa
1896
+ return padding_mask
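+ # A minimal sketch of build_perceiver_padding_attention_mask above (the token ids are
+ # illustrative): the mask spans the concatenated [tokens, latents] key axis, so the
+ # trailing resampled_length positions are always attendable while pad columns are masked.
+ #
+ #     tokens = torch.tensor([[5, 6, 1]])           # pad_token_id = 1
+ #     mask = build_perceiver_padding_attention_mask(tokens, resampled_length=4, pad_token_id=1)
+ #     mask.shape                                   # torch.Size([1, 1, 4, 7])
+ #     mask[0, 0, 0]                                # [True, True, False, True, True, True, True]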
config.json CHANGED
@@ -7,17 +7,32 @@
7
  "AutoModel": "chatNT.TorchMultiOmicsModel"
8
  },
9
  "bio_pad_token_id": 1,
10
- "custom_pipelines": {
11
- "ChatNT-text-generation": {
12
- "impl": "text_generation.TextGenerationPipeline",
13
- "pt": [
14
- "AutoModel"
15
- ],
16
- "tf": []
17
- }
18
- },
19
  "english_pad_token_id": 2,
20
- "esm_config": {
21
  "add_bias_ffn": false,
22
  "add_bias_kv": false,
23
  "alphabet_size": 4107,
@@ -50,30 +65,6 @@
50
  "use_gradient_checkpointing": false,
51
  "use_rotary_embedding": true
52
  },
53
- "gpt_config": {
54
- "add_bias_attn": false,
55
- "add_bias_ffn": false,
56
- "add_bias_lm_head": false,
57
- "embed_dim": 4096,
58
- "eos_token_id": 2,
59
- "ffn_activation_name": "silu",
60
- "ffn_embed_dim": 11008,
61
- "norm_type": "RMS_norm",
62
- "num_heads": 32,
63
- "num_kv_heads": 32,
64
- "num_layers": 32,
65
- "parallel_attention_ff": false,
66
- "rms_norm_eps": 1e-06,
67
- "rope_config": {
68
- "dim": 128,
69
- "max_seq_len": 2048,
70
- "theta": 10000.0
71
- },
72
- "use_glu_in_ffn": true,
73
- "use_gradient_checkpointing": false,
74
- "vocab_size": 32000
75
- },
76
- "model_type": "ChatNT",
77
  "perceiver_resampler_config": {
78
  "add_bias_ffn": true,
79
  "add_bias_kv": false,
 
7
  "AutoModel": "chatNT.TorchMultiOmicsModel"
8
  },
9
  "bio_pad_token_id": 1,
  "english_pad_token_id": 2,
11
+ "gpt_config": {
12
+ "add_bias_attn": false,
13
+ "add_bias_ffn": false,
14
+ "add_bias_lm_head": false,
15
+ "embed_dim": 4096,
16
+ "eos_token_id": 2,
17
+ "ffn_activation_name": "silu",
18
+ "ffn_embed_dim": 11008,
19
+ "norm_type": "RMS_norm",
20
+ "num_heads": 32,
21
+ "num_kv_heads": 32,
22
+ "num_layers": 32,
23
+ "parallel_attention_ff": false,
24
+ "rms_norm_eps": 1e-06,
25
+ "rope_config": {
26
+ "dim": 128,
27
+ "max_seq_len": 2048,
28
+ "theta": 10000.0
29
+ },
30
+ "use_glu_in_ffn": true,
31
+ "use_gradient_checkpointing": false,
32
+ "vocab_size": 32000
33
+ },
34
+ "model_type": "ChatNT",
35
+ "nt_config": {
36
  "add_bias_ffn": false,
37
  "add_bias_kv": false,
38
  "alphabet_size": 4107,
 
65
  "use_gradient_checkpointing": false,
66
  "use_rotary_embedding": true
67
  },
68
  "perceiver_resampler_config": {
69
  "add_bias_ffn": true,
70
  "add_bias_kv": false,
model-00001-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:08ad2c4dfd29e6d52694c7a6e2888d8904bad60eb7bf8979832dbc14802c6988
3
+ size 4998275134
model-00002-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:187615f3a8661430364e2e824d5b0a0363c9cf5b3d8512f33c44015b0be27343
3
+ size 4890784808
model-00003-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:916b86538557669e3a74c00d4d58ae44e494c4439aba8c2d6ee51baf05f62ebe
3
+ size 4985672264
model-00004-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8524670292b2f477cd558fd76b3372840949dadd0b0a6c386519b05a82faebe6
3
+ size 1212565848
model.safetensors.index.json ADDED
@@ -0,0 +1,763 @@
1
+ {
2
+ "metadata": {
3
+ "total_size": 16087194134
4
+ },
5
+ "weight_map": {
6
+ "biobrain_decoder.gpt_model.final_norm.scale": "model-00001-of-00004.safetensors",
7
+ "biobrain_decoder.gpt_model.layers.0.attn_norm.scale": "model-00001-of-00004.safetensors",
8
+ "biobrain_decoder.gpt_model.layers.0.fc1.weight": "model-00001-of-00004.safetensors",
9
+ "biobrain_decoder.gpt_model.layers.0.fc2.weight": "model-00001-of-00004.safetensors",
10
+ "biobrain_decoder.gpt_model.layers.0.ffn_norm.scale": "model-00001-of-00004.safetensors",
11
+ "biobrain_decoder.gpt_model.layers.0.self_attn.key_linear.weight": "model-00001-of-00004.safetensors",
12
+ "biobrain_decoder.gpt_model.layers.0.self_attn.out_linear.weight": "model-00001-of-00004.safetensors",
13
+ "biobrain_decoder.gpt_model.layers.0.self_attn.query_linear.weight": "model-00001-of-00004.safetensors",
14
+ "biobrain_decoder.gpt_model.layers.0.self_attn.value_linear.weight": "model-00001-of-00004.safetensors",
15
+ "biobrain_decoder.gpt_model.layers.1.attn_norm.scale": "model-00001-of-00004.safetensors",
16
+ "biobrain_decoder.gpt_model.layers.1.fc1.weight": "model-00001-of-00004.safetensors",
17
+ "biobrain_decoder.gpt_model.layers.1.fc2.weight": "model-00001-of-00004.safetensors",
18
+ "biobrain_decoder.gpt_model.layers.1.ffn_norm.scale": "model-00001-of-00004.safetensors",
19
+ "biobrain_decoder.gpt_model.layers.1.self_attn.key_linear.weight": "model-00001-of-00004.safetensors",
20
+ "biobrain_decoder.gpt_model.layers.1.self_attn.out_linear.weight": "model-00001-of-00004.safetensors",
21
+ "biobrain_decoder.gpt_model.layers.1.self_attn.query_linear.weight": "model-00001-of-00004.safetensors",
22
+ "biobrain_decoder.gpt_model.layers.1.self_attn.value_linear.weight": "model-00001-of-00004.safetensors",
23
+ "biobrain_decoder.gpt_model.layers.10.attn_norm.scale": "model-00002-of-00004.safetensors",
24
+ "biobrain_decoder.gpt_model.layers.10.fc1.weight": "model-00002-of-00004.safetensors",
25
+ "biobrain_decoder.gpt_model.layers.10.fc2.weight": "model-00002-of-00004.safetensors",
26
+ "biobrain_decoder.gpt_model.layers.10.ffn_norm.scale": "model-00002-of-00004.safetensors",
27
+ "biobrain_decoder.gpt_model.layers.10.self_attn.key_linear.weight": "model-00002-of-00004.safetensors",
28
+ "biobrain_decoder.gpt_model.layers.10.self_attn.out_linear.weight": "model-00002-of-00004.safetensors",
29
+ "biobrain_decoder.gpt_model.layers.10.self_attn.query_linear.weight": "model-00002-of-00004.safetensors",
30
+ "biobrain_decoder.gpt_model.layers.10.self_attn.value_linear.weight": "model-00002-of-00004.safetensors",
31
+ "biobrain_decoder.gpt_model.layers.11.attn_norm.scale": "model-00002-of-00004.safetensors",
32
+ "biobrain_decoder.gpt_model.layers.11.fc1.weight": "model-00002-of-00004.safetensors",
33
+ "biobrain_decoder.gpt_model.layers.11.fc2.weight": "model-00002-of-00004.safetensors",
34
+ "biobrain_decoder.gpt_model.layers.11.ffn_norm.scale": "model-00002-of-00004.safetensors",
35
+ "biobrain_decoder.gpt_model.layers.11.self_attn.key_linear.weight": "model-00002-of-00004.safetensors",
36
+ "biobrain_decoder.gpt_model.layers.11.self_attn.out_linear.weight": "model-00002-of-00004.safetensors",
37
+ "biobrain_decoder.gpt_model.layers.11.self_attn.query_linear.weight": "model-00002-of-00004.safetensors",
38
+ "biobrain_decoder.gpt_model.layers.11.self_attn.value_linear.weight": "model-00002-of-00004.safetensors",
39
+ "biobrain_decoder.gpt_model.layers.12.attn_norm.scale": "model-00002-of-00004.safetensors",
40
+ "biobrain_decoder.gpt_model.layers.12.fc1.weight": "model-00002-of-00004.safetensors",
41
+ "biobrain_decoder.gpt_model.layers.12.fc2.weight": "model-00002-of-00004.safetensors",
42
+ "biobrain_decoder.gpt_model.layers.12.ffn_norm.scale": "model-00002-of-00004.safetensors",
43
+ "biobrain_decoder.gpt_model.layers.12.self_attn.key_linear.weight": "model-00002-of-00004.safetensors",
44
+ "biobrain_decoder.gpt_model.layers.12.self_attn.out_linear.weight": "model-00002-of-00004.safetensors",
45
+ "biobrain_decoder.gpt_model.layers.12.self_attn.query_linear.weight": "model-00002-of-00004.safetensors",
46
+ "biobrain_decoder.gpt_model.layers.12.self_attn.value_linear.weight": "model-00002-of-00004.safetensors",
47
+ "biobrain_decoder.gpt_model.layers.13.attn_norm.scale": "model-00002-of-00004.safetensors",
48
+ "biobrain_decoder.gpt_model.layers.13.fc1.weight": "model-00002-of-00004.safetensors",
49
+ "biobrain_decoder.gpt_model.layers.13.fc2.weight": "model-00002-of-00004.safetensors",
50
+ "biobrain_decoder.gpt_model.layers.13.ffn_norm.scale": "model-00002-of-00004.safetensors",
51
+ "biobrain_decoder.gpt_model.layers.13.self_attn.key_linear.weight": "model-00002-of-00004.safetensors",
52
+ "biobrain_decoder.gpt_model.layers.13.self_attn.out_linear.weight": "model-00002-of-00004.safetensors",
53
+ "biobrain_decoder.gpt_model.layers.13.self_attn.query_linear.weight": "model-00002-of-00004.safetensors",
54
+ "biobrain_decoder.gpt_model.layers.13.self_attn.value_linear.weight": "model-00002-of-00004.safetensors",
55
+ "biobrain_decoder.gpt_model.layers.14.attn_norm.scale": "model-00002-of-00004.safetensors",
56
+ "biobrain_decoder.gpt_model.layers.14.fc1.weight": "model-00002-of-00004.safetensors",
57
+ "biobrain_decoder.gpt_model.layers.14.fc2.weight": "model-00002-of-00004.safetensors",
58
+ "biobrain_decoder.gpt_model.layers.14.ffn_norm.scale": "model-00002-of-00004.safetensors",
59
+ "biobrain_decoder.gpt_model.layers.14.self_attn.key_linear.weight": "model-00002-of-00004.safetensors",
60
+ "biobrain_decoder.gpt_model.layers.14.self_attn.out_linear.weight": "model-00002-of-00004.safetensors",
61
+ "biobrain_decoder.gpt_model.layers.14.self_attn.query_linear.weight": "model-00002-of-00004.safetensors",
62
+ "biobrain_decoder.gpt_model.layers.14.self_attn.value_linear.weight": "model-00002-of-00004.safetensors",
63
+ "biobrain_decoder.gpt_model.layers.15.attn_norm.scale": "model-00002-of-00004.safetensors",
64
+ "biobrain_decoder.gpt_model.layers.15.fc1.weight": "model-00002-of-00004.safetensors",
65
+ "biobrain_decoder.gpt_model.layers.15.fc2.weight": "model-00002-of-00004.safetensors",
66
+ "biobrain_decoder.gpt_model.layers.15.ffn_norm.scale": "model-00002-of-00004.safetensors",
67
+ "biobrain_decoder.gpt_model.layers.15.self_attn.key_linear.weight": "model-00002-of-00004.safetensors",
68
+ "biobrain_decoder.gpt_model.layers.15.self_attn.out_linear.weight": "model-00002-of-00004.safetensors",
69
+ "biobrain_decoder.gpt_model.layers.15.self_attn.query_linear.weight": "model-00002-of-00004.safetensors",
70
+ "biobrain_decoder.gpt_model.layers.15.self_attn.value_linear.weight": "model-00002-of-00004.safetensors",
71
+ "biobrain_decoder.gpt_model.layers.16.attn_norm.scale": "model-00002-of-00004.safetensors",
72
+ "biobrain_decoder.gpt_model.layers.16.fc1.weight": "model-00002-of-00004.safetensors",
73
+ "biobrain_decoder.gpt_model.layers.16.fc2.weight": "model-00002-of-00004.safetensors",
74
+ "biobrain_decoder.gpt_model.layers.16.ffn_norm.scale": "model-00002-of-00004.safetensors",
75
+ "biobrain_decoder.gpt_model.layers.16.self_attn.key_linear.weight": "model-00002-of-00004.safetensors",
76
+ "biobrain_decoder.gpt_model.layers.16.self_attn.out_linear.weight": "model-00002-of-00004.safetensors",
77
+ "biobrain_decoder.gpt_model.layers.16.self_attn.query_linear.weight": "model-00002-of-00004.safetensors",
78
+ "biobrain_decoder.gpt_model.layers.16.self_attn.value_linear.weight": "model-00002-of-00004.safetensors",
79
+ "biobrain_decoder.gpt_model.layers.17.attn_norm.scale": "model-00002-of-00004.safetensors",
80
+ "biobrain_decoder.gpt_model.layers.17.fc1.weight": "model-00002-of-00004.safetensors",
81
+ "biobrain_decoder.gpt_model.layers.17.fc2.weight": "model-00002-of-00004.safetensors",
82
+ "biobrain_decoder.gpt_model.layers.17.ffn_norm.scale": "model-00002-of-00004.safetensors",
83
+ "biobrain_decoder.gpt_model.layers.17.self_attn.key_linear.weight": "model-00002-of-00004.safetensors",
84
+ "biobrain_decoder.gpt_model.layers.17.self_attn.out_linear.weight": "model-00002-of-00004.safetensors",
85
+ "biobrain_decoder.gpt_model.layers.17.self_attn.query_linear.weight": "model-00002-of-00004.safetensors",
86
+ "biobrain_decoder.gpt_model.layers.17.self_attn.value_linear.weight": "model-00002-of-00004.safetensors",
87
+ "biobrain_decoder.gpt_model.layers.18.attn_norm.scale": "model-00002-of-00004.safetensors",
88
+ "biobrain_decoder.gpt_model.layers.18.fc1.weight": "model-00002-of-00004.safetensors",
89
+ "biobrain_decoder.gpt_model.layers.18.fc2.weight": "model-00002-of-00004.safetensors",
90
+ "biobrain_decoder.gpt_model.layers.18.ffn_norm.scale": "model-00002-of-00004.safetensors",
91
+ "biobrain_decoder.gpt_model.layers.18.self_attn.key_linear.weight": "model-00002-of-00004.safetensors",
92
+ "biobrain_decoder.gpt_model.layers.18.self_attn.out_linear.weight": "model-00002-of-00004.safetensors",
93
+ "biobrain_decoder.gpt_model.layers.18.self_attn.query_linear.weight": "model-00002-of-00004.safetensors",
94
+ "biobrain_decoder.gpt_model.layers.18.self_attn.value_linear.weight": "model-00002-of-00004.safetensors",
95
+ "biobrain_decoder.gpt_model.layers.19.attn_norm.scale": "model-00002-of-00004.safetensors",
96
+ "biobrain_decoder.gpt_model.layers.19.fc1.weight": "model-00002-of-00004.safetensors",
97
+ "biobrain_decoder.gpt_model.layers.19.fc2.weight": "model-00002-of-00004.safetensors",
98
+ "biobrain_decoder.gpt_model.layers.19.ffn_norm.scale": "model-00002-of-00004.safetensors",
99
+ "biobrain_decoder.gpt_model.layers.19.self_attn.key_linear.weight": "model-00002-of-00004.safetensors",
100
+ "biobrain_decoder.gpt_model.layers.19.self_attn.out_linear.weight": "model-00002-of-00004.safetensors",
101
+ "biobrain_decoder.gpt_model.layers.19.self_attn.query_linear.weight": "model-00002-of-00004.safetensors",
102
+ "biobrain_decoder.gpt_model.layers.19.self_attn.value_linear.weight": "model-00002-of-00004.safetensors",
103
+ "biobrain_decoder.gpt_model.layers.2.attn_norm.scale": "model-00001-of-00004.safetensors",
104
+ "biobrain_decoder.gpt_model.layers.2.fc1.weight": "model-00001-of-00004.safetensors",
105
+ "biobrain_decoder.gpt_model.layers.2.fc2.weight": "model-00001-of-00004.safetensors",
106
+ "biobrain_decoder.gpt_model.layers.2.ffn_norm.scale": "model-00001-of-00004.safetensors",
107
+ "biobrain_decoder.gpt_model.layers.2.self_attn.key_linear.weight": "model-00001-of-00004.safetensors",
108
+ "biobrain_decoder.gpt_model.layers.2.self_attn.out_linear.weight": "model-00001-of-00004.safetensors",
109
+ "biobrain_decoder.gpt_model.layers.2.self_attn.query_linear.weight": "model-00001-of-00004.safetensors",
110
+ "biobrain_decoder.gpt_model.layers.2.self_attn.value_linear.weight": "model-00001-of-00004.safetensors",
111
+ "biobrain_decoder.gpt_model.layers.20.attn_norm.scale": "model-00002-of-00004.safetensors",
112
+ "biobrain_decoder.gpt_model.layers.20.fc1.weight": "model-00002-of-00004.safetensors",
113
+ "biobrain_decoder.gpt_model.layers.20.fc2.weight": "model-00002-of-00004.safetensors",
114
+ "biobrain_decoder.gpt_model.layers.20.ffn_norm.scale": "model-00002-of-00004.safetensors",
115
+ "biobrain_decoder.gpt_model.layers.20.self_attn.key_linear.weight": "model-00002-of-00004.safetensors",
116
+ "biobrain_decoder.gpt_model.layers.20.self_attn.out_linear.weight": "model-00002-of-00004.safetensors",
117
+ "biobrain_decoder.gpt_model.layers.20.self_attn.query_linear.weight": "model-00002-of-00004.safetensors",
118
+ "biobrain_decoder.gpt_model.layers.20.self_attn.value_linear.weight": "model-00002-of-00004.safetensors",
119
+ "biobrain_decoder.gpt_model.layers.21.attn_norm.scale": "model-00002-of-00004.safetensors",
120
+ "biobrain_decoder.gpt_model.layers.21.fc1.weight": "model-00003-of-00004.safetensors",
121
+ "biobrain_decoder.gpt_model.layers.21.fc2.weight": "model-00003-of-00004.safetensors",
122
+ "biobrain_decoder.gpt_model.layers.21.ffn_norm.scale": "model-00002-of-00004.safetensors",
123
+ "biobrain_decoder.gpt_model.layers.21.self_attn.key_linear.weight": "model-00002-of-00004.safetensors",
124
+ "biobrain_decoder.gpt_model.layers.21.self_attn.out_linear.weight": "model-00002-of-00004.safetensors",
125
+ "biobrain_decoder.gpt_model.layers.21.self_attn.query_linear.weight": "model-00002-of-00004.safetensors",
126
+ "biobrain_decoder.gpt_model.layers.21.self_attn.value_linear.weight": "model-00002-of-00004.safetensors",
127
+ "biobrain_decoder.gpt_model.layers.22.attn_norm.scale": "model-00003-of-00004.safetensors",
128
+ "biobrain_decoder.gpt_model.layers.22.fc1.weight": "model-00003-of-00004.safetensors",
129
+ "biobrain_decoder.gpt_model.layers.22.fc2.weight": "model-00003-of-00004.safetensors",
130
+ "biobrain_decoder.gpt_model.layers.22.ffn_norm.scale": "model-00003-of-00004.safetensors",
131
+ "biobrain_decoder.gpt_model.layers.22.self_attn.key_linear.weight": "model-00003-of-00004.safetensors",
132
+ "biobrain_decoder.gpt_model.layers.22.self_attn.out_linear.weight": "model-00003-of-00004.safetensors",
133
+ "biobrain_decoder.gpt_model.layers.22.self_attn.query_linear.weight": "model-00003-of-00004.safetensors",
134
+ "biobrain_decoder.gpt_model.layers.22.self_attn.value_linear.weight": "model-00003-of-00004.safetensors",
135
+ "biobrain_decoder.gpt_model.layers.23.attn_norm.scale": "model-00003-of-00004.safetensors",
136
+ "biobrain_decoder.gpt_model.layers.23.fc1.weight": "model-00003-of-00004.safetensors",
137
+ "biobrain_decoder.gpt_model.layers.23.fc2.weight": "model-00003-of-00004.safetensors",
138
+ "biobrain_decoder.gpt_model.layers.23.ffn_norm.scale": "model-00003-of-00004.safetensors",
139
+ "biobrain_decoder.gpt_model.layers.23.self_attn.key_linear.weight": "model-00003-of-00004.safetensors",
140
+ "biobrain_decoder.gpt_model.layers.23.self_attn.out_linear.weight": "model-00003-of-00004.safetensors",
141
+ "biobrain_decoder.gpt_model.layers.23.self_attn.query_linear.weight": "model-00003-of-00004.safetensors",
142
+ "biobrain_decoder.gpt_model.layers.23.self_attn.value_linear.weight": "model-00003-of-00004.safetensors",
143
+ "biobrain_decoder.gpt_model.layers.24.attn_norm.scale": "model-00003-of-00004.safetensors",
144
+ "biobrain_decoder.gpt_model.layers.24.fc1.weight": "model-00003-of-00004.safetensors",
145
+ "biobrain_decoder.gpt_model.layers.24.fc2.weight": "model-00003-of-00004.safetensors",
146
+ "biobrain_decoder.gpt_model.layers.24.ffn_norm.scale": "model-00003-of-00004.safetensors",
147
+ "biobrain_decoder.gpt_model.layers.24.self_attn.key_linear.weight": "model-00003-of-00004.safetensors",
148
+ "biobrain_decoder.gpt_model.layers.24.self_attn.out_linear.weight": "model-00003-of-00004.safetensors",
149
+ "biobrain_decoder.gpt_model.layers.24.self_attn.query_linear.weight": "model-00003-of-00004.safetensors",
150
+ "biobrain_decoder.gpt_model.layers.24.self_attn.value_linear.weight": "model-00003-of-00004.safetensors",
151
+ "biobrain_decoder.gpt_model.layers.25.attn_norm.scale": "model-00003-of-00004.safetensors",
152
+ "biobrain_decoder.gpt_model.layers.25.fc1.weight": "model-00003-of-00004.safetensors",
153
+ "biobrain_decoder.gpt_model.layers.25.fc2.weight": "model-00003-of-00004.safetensors",
154
+ "biobrain_decoder.gpt_model.layers.25.ffn_norm.scale": "model-00003-of-00004.safetensors",
155
+ "biobrain_decoder.gpt_model.layers.25.self_attn.key_linear.weight": "model-00003-of-00004.safetensors",
156
+ "biobrain_decoder.gpt_model.layers.25.self_attn.out_linear.weight": "model-00003-of-00004.safetensors",
157
+ "biobrain_decoder.gpt_model.layers.25.self_attn.query_linear.weight": "model-00003-of-00004.safetensors",
158
+ "biobrain_decoder.gpt_model.layers.25.self_attn.value_linear.weight": "model-00003-of-00004.safetensors",
159
+ "biobrain_decoder.gpt_model.layers.26.attn_norm.scale": "model-00003-of-00004.safetensors",
160
+ "biobrain_decoder.gpt_model.layers.26.fc1.weight": "model-00003-of-00004.safetensors",
161
+ "biobrain_decoder.gpt_model.layers.26.fc2.weight": "model-00003-of-00004.safetensors",
162
+ "biobrain_decoder.gpt_model.layers.26.ffn_norm.scale": "model-00003-of-00004.safetensors",
163
+ "biobrain_decoder.gpt_model.layers.26.self_attn.key_linear.weight": "model-00003-of-00004.safetensors",
164
+ "biobrain_decoder.gpt_model.layers.26.self_attn.out_linear.weight": "model-00003-of-00004.safetensors",
165
+ "biobrain_decoder.gpt_model.layers.26.self_attn.query_linear.weight": "model-00003-of-00004.safetensors",
166
+ "biobrain_decoder.gpt_model.layers.26.self_attn.value_linear.weight": "model-00003-of-00004.safetensors",
167
+ "biobrain_decoder.gpt_model.layers.27.attn_norm.scale": "model-00003-of-00004.safetensors",
168
+ "biobrain_decoder.gpt_model.layers.27.fc1.weight": "model-00003-of-00004.safetensors",
169
+ "biobrain_decoder.gpt_model.layers.27.fc2.weight": "model-00003-of-00004.safetensors",
170
+ "biobrain_decoder.gpt_model.layers.27.ffn_norm.scale": "model-00003-of-00004.safetensors",
171
+ "biobrain_decoder.gpt_model.layers.27.self_attn.key_linear.weight": "model-00003-of-00004.safetensors",
172
+ "biobrain_decoder.gpt_model.layers.27.self_attn.out_linear.weight": "model-00003-of-00004.safetensors",
173
+ "biobrain_decoder.gpt_model.layers.27.self_attn.query_linear.weight": "model-00003-of-00004.safetensors",
174
+ "biobrain_decoder.gpt_model.layers.27.self_attn.value_linear.weight": "model-00003-of-00004.safetensors",
175
+ "biobrain_decoder.gpt_model.layers.28.attn_norm.scale": "model-00003-of-00004.safetensors",
176
+ "biobrain_decoder.gpt_model.layers.28.fc1.weight": "model-00003-of-00004.safetensors",
177
+ "biobrain_decoder.gpt_model.layers.28.fc2.weight": "model-00003-of-00004.safetensors",
178
+ "biobrain_decoder.gpt_model.layers.28.ffn_norm.scale": "model-00003-of-00004.safetensors",
179
+ "biobrain_decoder.gpt_model.layers.28.self_attn.key_linear.weight": "model-00003-of-00004.safetensors",
180
+ "biobrain_decoder.gpt_model.layers.28.self_attn.out_linear.weight": "model-00003-of-00004.safetensors",
181
+ "biobrain_decoder.gpt_model.layers.28.self_attn.query_linear.weight": "model-00003-of-00004.safetensors",
182
+ "biobrain_decoder.gpt_model.layers.28.self_attn.value_linear.weight": "model-00003-of-00004.safetensors",
183
+ "biobrain_decoder.gpt_model.layers.29.attn_norm.scale": "model-00003-of-00004.safetensors",
184
+ "biobrain_decoder.gpt_model.layers.29.fc1.weight": "model-00003-of-00004.safetensors",
185
+ "biobrain_decoder.gpt_model.layers.29.fc2.weight": "model-00003-of-00004.safetensors",
186
+ "biobrain_decoder.gpt_model.layers.29.ffn_norm.scale": "model-00003-of-00004.safetensors",
187
+ "biobrain_decoder.gpt_model.layers.29.self_attn.key_linear.weight": "model-00003-of-00004.safetensors",
188
+ "biobrain_decoder.gpt_model.layers.29.self_attn.out_linear.weight": "model-00003-of-00004.safetensors",
189
+ "biobrain_decoder.gpt_model.layers.29.self_attn.query_linear.weight": "model-00003-of-00004.safetensors",
190
+ "biobrain_decoder.gpt_model.layers.29.self_attn.value_linear.weight": "model-00003-of-00004.safetensors",
191
+ "biobrain_decoder.gpt_model.layers.3.attn_norm.scale": "model-00001-of-00004.safetensors",
192
+ "biobrain_decoder.gpt_model.layers.3.fc1.weight": "model-00001-of-00004.safetensors",
193
+ "biobrain_decoder.gpt_model.layers.3.fc2.weight": "model-00001-of-00004.safetensors",
194
+ "biobrain_decoder.gpt_model.layers.3.ffn_norm.scale": "model-00001-of-00004.safetensors",
195
+ "biobrain_decoder.gpt_model.layers.3.self_attn.key_linear.weight": "model-00001-of-00004.safetensors",
196
+ "biobrain_decoder.gpt_model.layers.3.self_attn.out_linear.weight": "model-00001-of-00004.safetensors",
197
+ "biobrain_decoder.gpt_model.layers.3.self_attn.query_linear.weight": "model-00001-of-00004.safetensors",
198
+ "biobrain_decoder.gpt_model.layers.3.self_attn.value_linear.weight": "model-00001-of-00004.safetensors",
199
+ "biobrain_decoder.gpt_model.layers.30.attn_norm.scale": "model-00003-of-00004.safetensors",
200
+ "biobrain_decoder.gpt_model.layers.30.fc1.weight": "model-00003-of-00004.safetensors",
201
+ "biobrain_decoder.gpt_model.layers.30.fc2.weight": "model-00003-of-00004.safetensors",
202
+ "biobrain_decoder.gpt_model.layers.30.ffn_norm.scale": "model-00003-of-00004.safetensors",
203
+ "biobrain_decoder.gpt_model.layers.30.self_attn.key_linear.weight": "model-00003-of-00004.safetensors",
204
+ "biobrain_decoder.gpt_model.layers.30.self_attn.out_linear.weight": "model-00003-of-00004.safetensors",
205
+ "biobrain_decoder.gpt_model.layers.30.self_attn.query_linear.weight": "model-00003-of-00004.safetensors",
206
+ "biobrain_decoder.gpt_model.layers.30.self_attn.value_linear.weight": "model-00003-of-00004.safetensors",
207
+ "biobrain_decoder.gpt_model.layers.31.attn_norm.scale": "model-00003-of-00004.safetensors",
208
+ "biobrain_decoder.gpt_model.layers.31.fc1.weight": "model-00003-of-00004.safetensors",
209
+ "biobrain_decoder.gpt_model.layers.31.fc2.weight": "model-00003-of-00004.safetensors",
210
+ "biobrain_decoder.gpt_model.layers.31.ffn_norm.scale": "model-00003-of-00004.safetensors",
211
+ "biobrain_decoder.gpt_model.layers.31.self_attn.key_linear.weight": "model-00003-of-00004.safetensors",
212
+ "biobrain_decoder.gpt_model.layers.31.self_attn.out_linear.weight": "model-00003-of-00004.safetensors",
213
+ "biobrain_decoder.gpt_model.layers.31.self_attn.query_linear.weight": "model-00003-of-00004.safetensors",
214
+ "biobrain_decoder.gpt_model.layers.31.self_attn.value_linear.weight": "model-00003-of-00004.safetensors",
215
+ "biobrain_decoder.gpt_model.layers.4.attn_norm.scale": "model-00001-of-00004.safetensors",
216
+ "biobrain_decoder.gpt_model.layers.4.fc1.weight": "model-00001-of-00004.safetensors",
217
+ "biobrain_decoder.gpt_model.layers.4.fc2.weight": "model-00001-of-00004.safetensors",
218
+ "biobrain_decoder.gpt_model.layers.4.ffn_norm.scale": "model-00001-of-00004.safetensors",
219
+ "biobrain_decoder.gpt_model.layers.4.self_attn.key_linear.weight": "model-00001-of-00004.safetensors",
220
+ "biobrain_decoder.gpt_model.layers.4.self_attn.out_linear.weight": "model-00001-of-00004.safetensors",
221
+ "biobrain_decoder.gpt_model.layers.4.self_attn.query_linear.weight": "model-00001-of-00004.safetensors",
222
+ "biobrain_decoder.gpt_model.layers.4.self_attn.value_linear.weight": "model-00001-of-00004.safetensors",
223
+ "biobrain_decoder.gpt_model.layers.5.attn_norm.scale": "model-00001-of-00004.safetensors",
224
+ "biobrain_decoder.gpt_model.layers.5.fc1.weight": "model-00001-of-00004.safetensors",
225
+ "biobrain_decoder.gpt_model.layers.5.fc2.weight": "model-00001-of-00004.safetensors",
226
+ "biobrain_decoder.gpt_model.layers.5.ffn_norm.scale": "model-00001-of-00004.safetensors",
227
+ "biobrain_decoder.gpt_model.layers.5.self_attn.key_linear.weight": "model-00001-of-00004.safetensors",
228
+ "biobrain_decoder.gpt_model.layers.5.self_attn.out_linear.weight": "model-00001-of-00004.safetensors",
229
+ "biobrain_decoder.gpt_model.layers.5.self_attn.query_linear.weight": "model-00001-of-00004.safetensors",
230
+ "biobrain_decoder.gpt_model.layers.5.self_attn.value_linear.weight": "model-00001-of-00004.safetensors",
231
+ "biobrain_decoder.gpt_model.layers.6.attn_norm.scale": "model-00001-of-00004.safetensors",
232
+ "biobrain_decoder.gpt_model.layers.6.fc1.weight": "model-00001-of-00004.safetensors",
233
+ "biobrain_decoder.gpt_model.layers.6.fc2.weight": "model-00001-of-00004.safetensors",
234
+ "biobrain_decoder.gpt_model.layers.6.ffn_norm.scale": "model-00001-of-00004.safetensors",
235
+ "biobrain_decoder.gpt_model.layers.6.self_attn.key_linear.weight": "model-00001-of-00004.safetensors",
236
+ "biobrain_decoder.gpt_model.layers.6.self_attn.out_linear.weight": "model-00001-of-00004.safetensors",
237
+ "biobrain_decoder.gpt_model.layers.6.self_attn.query_linear.weight": "model-00001-of-00004.safetensors",
238
+ "biobrain_decoder.gpt_model.layers.6.self_attn.value_linear.weight": "model-00001-of-00004.safetensors",
239
+ "biobrain_decoder.gpt_model.layers.7.attn_norm.scale": "model-00001-of-00004.safetensors",
240
+ "biobrain_decoder.gpt_model.layers.7.fc1.weight": "model-00001-of-00004.safetensors",
241
+ "biobrain_decoder.gpt_model.layers.7.fc2.weight": "model-00001-of-00004.safetensors",
242
+ "biobrain_decoder.gpt_model.layers.7.ffn_norm.scale": "model-00001-of-00004.safetensors",
243
+ "biobrain_decoder.gpt_model.layers.7.self_attn.key_linear.weight": "model-00001-of-00004.safetensors",
244
+ "biobrain_decoder.gpt_model.layers.7.self_attn.out_linear.weight": "model-00001-of-00004.safetensors",
245
+ "biobrain_decoder.gpt_model.layers.7.self_attn.query_linear.weight": "model-00001-of-00004.safetensors",
246
+ "biobrain_decoder.gpt_model.layers.7.self_attn.value_linear.weight": "model-00001-of-00004.safetensors",
247
+ "biobrain_decoder.gpt_model.layers.8.attn_norm.scale": "model-00001-of-00004.safetensors",
248
+ "biobrain_decoder.gpt_model.layers.8.fc1.weight": "model-00001-of-00004.safetensors",
249
+ "biobrain_decoder.gpt_model.layers.8.fc2.weight": "model-00001-of-00004.safetensors",
250
+ "biobrain_decoder.gpt_model.layers.8.ffn_norm.scale": "model-00001-of-00004.safetensors",
251
+ "biobrain_decoder.gpt_model.layers.8.self_attn.key_linear.weight": "model-00001-of-00004.safetensors",
252
+ "biobrain_decoder.gpt_model.layers.8.self_attn.out_linear.weight": "model-00001-of-00004.safetensors",
253
+ "biobrain_decoder.gpt_model.layers.8.self_attn.query_linear.weight": "model-00001-of-00004.safetensors",
254
+ "biobrain_decoder.gpt_model.layers.8.self_attn.value_linear.weight": "model-00001-of-00004.safetensors",
255
+ "biobrain_decoder.gpt_model.layers.9.attn_norm.scale": "model-00002-of-00004.safetensors",
256
+ "biobrain_decoder.gpt_model.layers.9.fc1.weight": "model-00002-of-00004.safetensors",
257
+ "biobrain_decoder.gpt_model.layers.9.fc2.weight": "model-00002-of-00004.safetensors",
258
+ "biobrain_decoder.gpt_model.layers.9.ffn_norm.scale": "model-00002-of-00004.safetensors",
259
+ "biobrain_decoder.gpt_model.layers.9.self_attn.key_linear.weight": "model-00001-of-00004.safetensors",
260
+ "biobrain_decoder.gpt_model.layers.9.self_attn.out_linear.weight": "model-00002-of-00004.safetensors",
261
+ "biobrain_decoder.gpt_model.layers.9.self_attn.query_linear.weight": "model-00001-of-00004.safetensors",
262
+ "biobrain_decoder.gpt_model.layers.9.self_attn.value_linear.weight": "model-00001-of-00004.safetensors",
263
+ "biobrain_decoder.gpt_model.lm_head.fc.weight": "model-00003-of-00004.safetensors",
264
+ "biobrain_decoder.gpt_model.token_embed.weight": "model-00001-of-00004.safetensors",
265
+ "biobrain_encoder.nt_model.attention_blocks.0.fc1.weight": "model-00001-of-00004.safetensors",
266
+ "biobrain_encoder.nt_model.attention_blocks.0.fc2.weight": "model-00001-of-00004.safetensors",
267
+ "biobrain_encoder.nt_model.attention_blocks.0.layer_norm_mlp.bias": "model-00001-of-00004.safetensors",
268
+ "biobrain_encoder.nt_model.attention_blocks.0.layer_norm_mlp.weight": "model-00001-of-00004.safetensors",
269
+ "biobrain_encoder.nt_model.attention_blocks.0.layer_norm_self_attention.bias": "model-00001-of-00004.safetensors",
270
+ "biobrain_encoder.nt_model.attention_blocks.0.layer_norm_self_attention.weight": "model-00001-of-00004.safetensors",
271
+ "biobrain_encoder.nt_model.attention_blocks.0.mha.output.bias": "model-00001-of-00004.safetensors",
272
+ "biobrain_encoder.nt_model.attention_blocks.0.mha.output.weight": "model-00001-of-00004.safetensors",
273
+ "biobrain_encoder.nt_model.attention_blocks.0.mha.w_k.bias": "model-00001-of-00004.safetensors",
274
+ "biobrain_encoder.nt_model.attention_blocks.0.mha.w_k.weight": "model-00001-of-00004.safetensors",
275
+ "biobrain_encoder.nt_model.attention_blocks.0.mha.w_q.bias": "model-00001-of-00004.safetensors",
276
+ "biobrain_encoder.nt_model.attention_blocks.0.mha.w_q.weight": "model-00001-of-00004.safetensors",
277
+ "biobrain_encoder.nt_model.attention_blocks.0.mha.w_v.bias": "model-00001-of-00004.safetensors",
278
+ "biobrain_encoder.nt_model.attention_blocks.0.mha.w_v.weight": "model-00001-of-00004.safetensors",
279
+ "biobrain_encoder.nt_model.attention_blocks.1.fc1.weight": "model-00001-of-00004.safetensors",
280
+ "biobrain_encoder.nt_model.attention_blocks.1.fc2.weight": "model-00001-of-00004.safetensors",
281
+ "biobrain_encoder.nt_model.attention_blocks.1.layer_norm_mlp.bias": "model-00001-of-00004.safetensors",
282
+ "biobrain_encoder.nt_model.attention_blocks.1.layer_norm_mlp.weight": "model-00001-of-00004.safetensors",
283
+ "biobrain_encoder.nt_model.attention_blocks.1.layer_norm_self_attention.bias": "model-00001-of-00004.safetensors",
284
+ "biobrain_encoder.nt_model.attention_blocks.1.layer_norm_self_attention.weight": "model-00001-of-00004.safetensors",
285
+ "biobrain_encoder.nt_model.attention_blocks.1.mha.output.bias": "model-00001-of-00004.safetensors",
286
+ "biobrain_encoder.nt_model.attention_blocks.1.mha.output.weight": "model-00001-of-00004.safetensors",
287
+ "biobrain_encoder.nt_model.attention_blocks.1.mha.w_k.bias": "model-00001-of-00004.safetensors",
288
+ "biobrain_encoder.nt_model.attention_blocks.1.mha.w_k.weight": "model-00001-of-00004.safetensors",
289
+ "biobrain_encoder.nt_model.attention_blocks.1.mha.w_q.bias": "model-00001-of-00004.safetensors",
290
+ "biobrain_encoder.nt_model.attention_blocks.1.mha.w_q.weight": "model-00001-of-00004.safetensors",
291
+ "biobrain_encoder.nt_model.attention_blocks.1.mha.w_v.bias": "model-00001-of-00004.safetensors",
292
+ "biobrain_encoder.nt_model.attention_blocks.1.mha.w_v.weight": "model-00001-of-00004.safetensors",
293
+ "biobrain_encoder.nt_model.attention_blocks.10.fc1.weight": "model-00001-of-00004.safetensors",
294
+ "biobrain_encoder.nt_model.attention_blocks.10.fc2.weight": "model-00001-of-00004.safetensors",
295
+ "biobrain_encoder.nt_model.attention_blocks.10.layer_norm_mlp.bias": "model-00001-of-00004.safetensors",
296
+ "biobrain_encoder.nt_model.attention_blocks.10.layer_norm_mlp.weight": "model-00001-of-00004.safetensors",
297
+ "biobrain_encoder.nt_model.attention_blocks.10.layer_norm_self_attention.bias": "model-00001-of-00004.safetensors",
298
+ "biobrain_encoder.nt_model.attention_blocks.10.layer_norm_self_attention.weight": "model-00001-of-00004.safetensors",
299
+ "biobrain_encoder.nt_model.attention_blocks.10.mha.output.bias": "model-00001-of-00004.safetensors",
300
+ "biobrain_encoder.nt_model.attention_blocks.10.mha.output.weight": "model-00001-of-00004.safetensors",
301
+ "biobrain_encoder.nt_model.attention_blocks.10.mha.w_k.bias": "model-00001-of-00004.safetensors",
302
+ "biobrain_encoder.nt_model.attention_blocks.10.mha.w_k.weight": "model-00001-of-00004.safetensors",
303
+ "biobrain_encoder.nt_model.attention_blocks.10.mha.w_q.bias": "model-00001-of-00004.safetensors",
304
+ "biobrain_encoder.nt_model.attention_blocks.10.mha.w_q.weight": "model-00001-of-00004.safetensors",
305
+ "biobrain_encoder.nt_model.attention_blocks.10.mha.w_v.bias": "model-00001-of-00004.safetensors",
306
+ "biobrain_encoder.nt_model.attention_blocks.10.mha.w_v.weight": "model-00001-of-00004.safetensors",
307
+ "biobrain_encoder.nt_model.attention_blocks.11.fc1.weight": "model-00001-of-00004.safetensors",
308
+ "biobrain_encoder.nt_model.attention_blocks.11.fc2.weight": "model-00001-of-00004.safetensors",
309
+ "biobrain_encoder.nt_model.attention_blocks.11.layer_norm_mlp.bias": "model-00001-of-00004.safetensors",
310
+ "biobrain_encoder.nt_model.attention_blocks.11.layer_norm_mlp.weight": "model-00001-of-00004.safetensors",
311
+ "biobrain_encoder.nt_model.attention_blocks.11.layer_norm_self_attention.bias": "model-00001-of-00004.safetensors",
312
+ "biobrain_encoder.nt_model.attention_blocks.11.layer_norm_self_attention.weight": "model-00001-of-00004.safetensors",
313
+ "biobrain_encoder.nt_model.attention_blocks.11.mha.output.bias": "model-00001-of-00004.safetensors",
314
+ "biobrain_encoder.nt_model.attention_blocks.11.mha.output.weight": "model-00001-of-00004.safetensors",
315
+ "biobrain_encoder.nt_model.attention_blocks.11.mha.w_k.bias": "model-00001-of-00004.safetensors",
316
+ "biobrain_encoder.nt_model.attention_blocks.11.mha.w_k.weight": "model-00001-of-00004.safetensors",
317
+ "biobrain_encoder.nt_model.attention_blocks.11.mha.w_q.bias": "model-00001-of-00004.safetensors",
318
+ "biobrain_encoder.nt_model.attention_blocks.11.mha.w_q.weight": "model-00001-of-00004.safetensors",
319
+ "biobrain_encoder.nt_model.attention_blocks.11.mha.w_v.bias": "model-00001-of-00004.safetensors",
320
+ "biobrain_encoder.nt_model.attention_blocks.11.mha.w_v.weight": "model-00001-of-00004.safetensors",
321
+ "biobrain_encoder.nt_model.attention_blocks.12.fc1.weight": "model-00001-of-00004.safetensors",
322
+ "biobrain_encoder.nt_model.attention_blocks.12.fc2.weight": "model-00001-of-00004.safetensors",
323
+ "biobrain_encoder.nt_model.attention_blocks.12.layer_norm_mlp.bias": "model-00001-of-00004.safetensors",
324
+ "biobrain_encoder.nt_model.attention_blocks.12.layer_norm_mlp.weight": "model-00001-of-00004.safetensors",
325
+ "biobrain_encoder.nt_model.attention_blocks.12.layer_norm_self_attention.bias": "model-00001-of-00004.safetensors",
326
+ "biobrain_encoder.nt_model.attention_blocks.12.layer_norm_self_attention.weight": "model-00001-of-00004.safetensors",
327
+ "biobrain_encoder.nt_model.attention_blocks.12.mha.output.bias": "model-00001-of-00004.safetensors",
328
+ "biobrain_encoder.nt_model.attention_blocks.12.mha.output.weight": "model-00001-of-00004.safetensors",
329
+ "biobrain_encoder.nt_model.attention_blocks.12.mha.w_k.bias": "model-00001-of-00004.safetensors",
330
+ "biobrain_encoder.nt_model.attention_blocks.12.mha.w_k.weight": "model-00001-of-00004.safetensors",
331
+ "biobrain_encoder.nt_model.attention_blocks.12.mha.w_q.bias": "model-00001-of-00004.safetensors",
332
+ "biobrain_encoder.nt_model.attention_blocks.12.mha.w_q.weight": "model-00001-of-00004.safetensors",
333
+ "biobrain_encoder.nt_model.attention_blocks.12.mha.w_v.bias": "model-00001-of-00004.safetensors",
334
+ "biobrain_encoder.nt_model.attention_blocks.12.mha.w_v.weight": "model-00001-of-00004.safetensors",
335
+ "biobrain_encoder.nt_model.attention_blocks.13.fc1.weight": "model-00001-of-00004.safetensors",
336
+ "biobrain_encoder.nt_model.attention_blocks.13.fc2.weight": "model-00001-of-00004.safetensors",
337
+ "biobrain_encoder.nt_model.attention_blocks.13.layer_norm_mlp.bias": "model-00001-of-00004.safetensors",
338
+ "biobrain_encoder.nt_model.attention_blocks.13.layer_norm_mlp.weight": "model-00001-of-00004.safetensors",
339
+ "biobrain_encoder.nt_model.attention_blocks.13.layer_norm_self_attention.bias": "model-00001-of-00004.safetensors",
340
+ "biobrain_encoder.nt_model.attention_blocks.13.layer_norm_self_attention.weight": "model-00001-of-00004.safetensors",
341
+ "biobrain_encoder.nt_model.attention_blocks.13.mha.output.bias": "model-00001-of-00004.safetensors",
342
+ "biobrain_encoder.nt_model.attention_blocks.13.mha.output.weight": "model-00001-of-00004.safetensors",
343
+ "biobrain_encoder.nt_model.attention_blocks.13.mha.w_k.bias": "model-00001-of-00004.safetensors",
344
+ "biobrain_encoder.nt_model.attention_blocks.13.mha.w_k.weight": "model-00001-of-00004.safetensors",
345
+ "biobrain_encoder.nt_model.attention_blocks.13.mha.w_q.bias": "model-00001-of-00004.safetensors",
346
+ "biobrain_encoder.nt_model.attention_blocks.13.mha.w_q.weight": "model-00001-of-00004.safetensors",
347
+ "biobrain_encoder.nt_model.attention_blocks.13.mha.w_v.bias": "model-00001-of-00004.safetensors",
348
+ "biobrain_encoder.nt_model.attention_blocks.13.mha.w_v.weight": "model-00001-of-00004.safetensors",
349
+ "biobrain_encoder.nt_model.attention_blocks.14.fc1.weight": "model-00001-of-00004.safetensors",
350
+ "biobrain_encoder.nt_model.attention_blocks.14.fc2.weight": "model-00001-of-00004.safetensors",
351
+ "biobrain_encoder.nt_model.attention_blocks.14.layer_norm_mlp.bias": "model-00001-of-00004.safetensors",
352
+ "biobrain_encoder.nt_model.attention_blocks.14.layer_norm_mlp.weight": "model-00001-of-00004.safetensors",
353
+ "biobrain_encoder.nt_model.attention_blocks.14.layer_norm_self_attention.bias": "model-00001-of-00004.safetensors",
354
+ "biobrain_encoder.nt_model.attention_blocks.14.layer_norm_self_attention.weight": "model-00001-of-00004.safetensors",
355
+ "biobrain_encoder.nt_model.attention_blocks.14.mha.output.bias": "model-00001-of-00004.safetensors",
356
+ "biobrain_encoder.nt_model.attention_blocks.14.mha.output.weight": "model-00001-of-00004.safetensors",
357
+ "biobrain_encoder.nt_model.attention_blocks.14.mha.w_k.bias": "model-00001-of-00004.safetensors",
358
+ "biobrain_encoder.nt_model.attention_blocks.14.mha.w_k.weight": "model-00001-of-00004.safetensors",
359
+ "biobrain_encoder.nt_model.attention_blocks.14.mha.w_q.bias": "model-00001-of-00004.safetensors",
360
+ "biobrain_encoder.nt_model.attention_blocks.14.mha.w_q.weight": "model-00001-of-00004.safetensors",
361
+ "biobrain_encoder.nt_model.attention_blocks.14.mha.w_v.bias": "model-00001-of-00004.safetensors",
362
+ "biobrain_encoder.nt_model.attention_blocks.14.mha.w_v.weight": "model-00001-of-00004.safetensors",
363
+ "biobrain_encoder.nt_model.attention_blocks.15.fc1.weight": "model-00001-of-00004.safetensors",
364
+ "biobrain_encoder.nt_model.attention_blocks.15.fc2.weight": "model-00001-of-00004.safetensors",
365
+ "biobrain_encoder.nt_model.attention_blocks.15.layer_norm_mlp.bias": "model-00001-of-00004.safetensors",
366
+ "biobrain_encoder.nt_model.attention_blocks.15.layer_norm_mlp.weight": "model-00001-of-00004.safetensors",
367
+ "biobrain_encoder.nt_model.attention_blocks.15.layer_norm_self_attention.bias": "model-00001-of-00004.safetensors",
368
+ "biobrain_encoder.nt_model.attention_blocks.15.layer_norm_self_attention.weight": "model-00001-of-00004.safetensors",
369
+ "biobrain_encoder.nt_model.attention_blocks.15.mha.output.bias": "model-00001-of-00004.safetensors",
370
+ "biobrain_encoder.nt_model.attention_blocks.15.mha.output.weight": "model-00001-of-00004.safetensors",
371
+ "biobrain_encoder.nt_model.attention_blocks.15.mha.w_k.bias": "model-00001-of-00004.safetensors",
372
+ "biobrain_encoder.nt_model.attention_blocks.15.mha.w_k.weight": "model-00001-of-00004.safetensors",
373
+ "biobrain_encoder.nt_model.attention_blocks.15.mha.w_q.bias": "model-00001-of-00004.safetensors",
374
+ "biobrain_encoder.nt_model.attention_blocks.15.mha.w_q.weight": "model-00001-of-00004.safetensors",
375
+ "biobrain_encoder.nt_model.attention_blocks.15.mha.w_v.bias": "model-00001-of-00004.safetensors",
376
+ "biobrain_encoder.nt_model.attention_blocks.15.mha.w_v.weight": "model-00001-of-00004.safetensors",
377
+ "biobrain_encoder.nt_model.attention_blocks.16.fc1.weight": "model-00001-of-00004.safetensors",
378
+ "biobrain_encoder.nt_model.attention_blocks.16.fc2.weight": "model-00001-of-00004.safetensors",
379
+ "biobrain_encoder.nt_model.attention_blocks.16.layer_norm_mlp.bias": "model-00001-of-00004.safetensors",
380
+ "biobrain_encoder.nt_model.attention_blocks.16.layer_norm_mlp.weight": "model-00001-of-00004.safetensors",
381
+ "biobrain_encoder.nt_model.attention_blocks.16.layer_norm_self_attention.bias": "model-00001-of-00004.safetensors",
382
+ "biobrain_encoder.nt_model.attention_blocks.16.layer_norm_self_attention.weight": "model-00001-of-00004.safetensors",
383
+ "biobrain_encoder.nt_model.attention_blocks.16.mha.output.bias": "model-00001-of-00004.safetensors",
384
+ "biobrain_encoder.nt_model.attention_blocks.16.mha.output.weight": "model-00001-of-00004.safetensors",
385
+ "biobrain_encoder.nt_model.attention_blocks.16.mha.w_k.bias": "model-00001-of-00004.safetensors",
386
+ "biobrain_encoder.nt_model.attention_blocks.16.mha.w_k.weight": "model-00001-of-00004.safetensors",
387
+ "biobrain_encoder.nt_model.attention_blocks.16.mha.w_q.bias": "model-00001-of-00004.safetensors",
388
+ "biobrain_encoder.nt_model.attention_blocks.16.mha.w_q.weight": "model-00001-of-00004.safetensors",
389
+ "biobrain_encoder.nt_model.attention_blocks.16.mha.w_v.bias": "model-00001-of-00004.safetensors",
390
+ "biobrain_encoder.nt_model.attention_blocks.16.mha.w_v.weight": "model-00001-of-00004.safetensors",
391
+ "biobrain_encoder.nt_model.attention_blocks.17.fc1.weight": "model-00001-of-00004.safetensors",
392
+ "biobrain_encoder.nt_model.attention_blocks.17.fc2.weight": "model-00001-of-00004.safetensors",
393
+ "biobrain_encoder.nt_model.attention_blocks.17.layer_norm_mlp.bias": "model-00001-of-00004.safetensors",
394
+ "biobrain_encoder.nt_model.attention_blocks.17.layer_norm_mlp.weight": "model-00001-of-00004.safetensors",
395
+ "biobrain_encoder.nt_model.attention_blocks.17.layer_norm_self_attention.bias": "model-00001-of-00004.safetensors",
396
+ "biobrain_encoder.nt_model.attention_blocks.17.layer_norm_self_attention.weight": "model-00001-of-00004.safetensors",
397
+ "biobrain_encoder.nt_model.attention_blocks.17.mha.output.bias": "model-00001-of-00004.safetensors",
398
+ "biobrain_encoder.nt_model.attention_blocks.17.mha.output.weight": "model-00001-of-00004.safetensors",
399
+ "biobrain_encoder.nt_model.attention_blocks.17.mha.w_k.bias": "model-00001-of-00004.safetensors",
400
+ "biobrain_encoder.nt_model.attention_blocks.17.mha.w_k.weight": "model-00001-of-00004.safetensors",
401
+ "biobrain_encoder.nt_model.attention_blocks.17.mha.w_q.bias": "model-00001-of-00004.safetensors",
402
+ "biobrain_encoder.nt_model.attention_blocks.17.mha.w_q.weight": "model-00001-of-00004.safetensors",
403
+ "biobrain_encoder.nt_model.attention_blocks.17.mha.w_v.bias": "model-00001-of-00004.safetensors",
404
+ "biobrain_encoder.nt_model.attention_blocks.17.mha.w_v.weight": "model-00001-of-00004.safetensors",
405
+ "biobrain_encoder.nt_model.attention_blocks.18.fc1.weight": "model-00001-of-00004.safetensors",
406
+ "biobrain_encoder.nt_model.attention_blocks.18.fc2.weight": "model-00001-of-00004.safetensors",
407
+ "biobrain_encoder.nt_model.attention_blocks.18.layer_norm_mlp.bias": "model-00001-of-00004.safetensors",
408
+ "biobrain_encoder.nt_model.attention_blocks.18.layer_norm_mlp.weight": "model-00001-of-00004.safetensors",
409
+ "biobrain_encoder.nt_model.attention_blocks.18.layer_norm_self_attention.bias": "model-00001-of-00004.safetensors",
410
+ "biobrain_encoder.nt_model.attention_blocks.18.layer_norm_self_attention.weight": "model-00001-of-00004.safetensors",
411
+ "biobrain_encoder.nt_model.attention_blocks.18.mha.output.bias": "model-00001-of-00004.safetensors",
412
+ "biobrain_encoder.nt_model.attention_blocks.18.mha.output.weight": "model-00001-of-00004.safetensors",
413
+ "biobrain_encoder.nt_model.attention_blocks.18.mha.w_k.bias": "model-00001-of-00004.safetensors",
414
+ "biobrain_encoder.nt_model.attention_blocks.18.mha.w_k.weight": "model-00001-of-00004.safetensors",
415
+ "biobrain_encoder.nt_model.attention_blocks.18.mha.w_q.bias": "model-00001-of-00004.safetensors",
416
+ "biobrain_encoder.nt_model.attention_blocks.18.mha.w_q.weight": "model-00001-of-00004.safetensors",
417
+ "biobrain_encoder.nt_model.attention_blocks.18.mha.w_v.bias": "model-00001-of-00004.safetensors",
418
+ "biobrain_encoder.nt_model.attention_blocks.18.mha.w_v.weight": "model-00001-of-00004.safetensors",
419
+ "biobrain_encoder.nt_model.attention_blocks.19.fc1.weight": "model-00001-of-00004.safetensors",
420
+ "biobrain_encoder.nt_model.attention_blocks.19.fc2.weight": "model-00001-of-00004.safetensors",
421
+ "biobrain_encoder.nt_model.attention_blocks.19.layer_norm_mlp.bias": "model-00001-of-00004.safetensors",
422
+ "biobrain_encoder.nt_model.attention_blocks.19.layer_norm_mlp.weight": "model-00001-of-00004.safetensors",
423
+ "biobrain_encoder.nt_model.attention_blocks.19.layer_norm_self_attention.bias": "model-00001-of-00004.safetensors",
424
+ "biobrain_encoder.nt_model.attention_blocks.19.layer_norm_self_attention.weight": "model-00001-of-00004.safetensors",
425
+ "biobrain_encoder.nt_model.attention_blocks.19.mha.output.bias": "model-00001-of-00004.safetensors",
426
+ "biobrain_encoder.nt_model.attention_blocks.19.mha.output.weight": "model-00001-of-00004.safetensors",
427
+ "biobrain_encoder.nt_model.attention_blocks.19.mha.w_k.bias": "model-00001-of-00004.safetensors",
428
+ "biobrain_encoder.nt_model.attention_blocks.19.mha.w_k.weight": "model-00001-of-00004.safetensors",
429
+ "biobrain_encoder.nt_model.attention_blocks.19.mha.w_q.bias": "model-00001-of-00004.safetensors",
430
+ "biobrain_encoder.nt_model.attention_blocks.19.mha.w_q.weight": "model-00001-of-00004.safetensors",
431
+ "biobrain_encoder.nt_model.attention_blocks.19.mha.w_v.bias": "model-00001-of-00004.safetensors",
432
+ "biobrain_encoder.nt_model.attention_blocks.19.mha.w_v.weight": "model-00001-of-00004.safetensors",
433
+ "biobrain_encoder.nt_model.attention_blocks.2.fc1.weight": "model-00001-of-00004.safetensors",
434
+ "biobrain_encoder.nt_model.attention_blocks.2.fc2.weight": "model-00001-of-00004.safetensors",
435
+ "biobrain_encoder.nt_model.attention_blocks.2.layer_norm_mlp.bias": "model-00001-of-00004.safetensors",
436
+ "biobrain_encoder.nt_model.attention_blocks.2.layer_norm_mlp.weight": "model-00001-of-00004.safetensors",
437
+ "biobrain_encoder.nt_model.attention_blocks.2.layer_norm_self_attention.bias": "model-00001-of-00004.safetensors",
438
+ "biobrain_encoder.nt_model.attention_blocks.2.layer_norm_self_attention.weight": "model-00001-of-00004.safetensors",
439
+ "biobrain_encoder.nt_model.attention_blocks.2.mha.output.bias": "model-00001-of-00004.safetensors",
440
+ "biobrain_encoder.nt_model.attention_blocks.2.mha.output.weight": "model-00001-of-00004.safetensors",
441
+ "biobrain_encoder.nt_model.attention_blocks.2.mha.w_k.bias": "model-00001-of-00004.safetensors",
442
+ "biobrain_encoder.nt_model.attention_blocks.2.mha.w_k.weight": "model-00001-of-00004.safetensors",
443
+ "biobrain_encoder.nt_model.attention_blocks.2.mha.w_q.bias": "model-00001-of-00004.safetensors",
444
+ "biobrain_encoder.nt_model.attention_blocks.2.mha.w_q.weight": "model-00001-of-00004.safetensors",
445
+ "biobrain_encoder.nt_model.attention_blocks.2.mha.w_v.bias": "model-00001-of-00004.safetensors",
446
+ "biobrain_encoder.nt_model.attention_blocks.2.mha.w_v.weight": "model-00001-of-00004.safetensors",
447
+ "biobrain_encoder.nt_model.attention_blocks.20.fc1.weight": "model-00001-of-00004.safetensors",
448
+ "biobrain_encoder.nt_model.attention_blocks.20.fc2.weight": "model-00001-of-00004.safetensors",
449
+ "biobrain_encoder.nt_model.attention_blocks.20.layer_norm_mlp.bias": "model-00001-of-00004.safetensors",
450
+ "biobrain_encoder.nt_model.attention_blocks.20.layer_norm_mlp.weight": "model-00001-of-00004.safetensors",
451
+ "biobrain_encoder.nt_model.attention_blocks.20.layer_norm_self_attention.bias": "model-00001-of-00004.safetensors",
452
+ "biobrain_encoder.nt_model.attention_blocks.20.layer_norm_self_attention.weight": "model-00001-of-00004.safetensors",
453
+ "biobrain_encoder.nt_model.attention_blocks.20.mha.output.bias": "model-00001-of-00004.safetensors",
454
+ "biobrain_encoder.nt_model.attention_blocks.20.mha.output.weight": "model-00001-of-00004.safetensors",
455
+ "biobrain_encoder.nt_model.attention_blocks.20.mha.w_k.bias": "model-00001-of-00004.safetensors",
456
+ "biobrain_encoder.nt_model.attention_blocks.20.mha.w_k.weight": "model-00001-of-00004.safetensors",
457
+ "biobrain_encoder.nt_model.attention_blocks.20.mha.w_q.bias": "model-00001-of-00004.safetensors",
458
+ "biobrain_encoder.nt_model.attention_blocks.20.mha.w_q.weight": "model-00001-of-00004.safetensors",
459
+ "biobrain_encoder.nt_model.attention_blocks.20.mha.w_v.bias": "model-00001-of-00004.safetensors",
460
+ "biobrain_encoder.nt_model.attention_blocks.20.mha.w_v.weight": "model-00001-of-00004.safetensors",
461
+ "biobrain_encoder.nt_model.attention_blocks.21.fc1.weight": "model-00001-of-00004.safetensors",
462
+ "biobrain_encoder.nt_model.attention_blocks.21.fc2.weight": "model-00001-of-00004.safetensors",
463
+ "biobrain_encoder.nt_model.attention_blocks.21.layer_norm_mlp.bias": "model-00001-of-00004.safetensors",
464
+ "biobrain_encoder.nt_model.attention_blocks.21.layer_norm_mlp.weight": "model-00001-of-00004.safetensors",
465
+ "biobrain_encoder.nt_model.attention_blocks.21.layer_norm_self_attention.bias": "model-00001-of-00004.safetensors",
466
+ "biobrain_encoder.nt_model.attention_blocks.21.layer_norm_self_attention.weight": "model-00001-of-00004.safetensors",
467
+ "biobrain_encoder.nt_model.attention_blocks.21.mha.output.bias": "model-00001-of-00004.safetensors",
468
+ "biobrain_encoder.nt_model.attention_blocks.21.mha.output.weight": "model-00001-of-00004.safetensors",
469
+ "biobrain_encoder.nt_model.attention_blocks.21.mha.w_k.bias": "model-00001-of-00004.safetensors",
470
+ "biobrain_encoder.nt_model.attention_blocks.21.mha.w_k.weight": "model-00001-of-00004.safetensors",
471
+ "biobrain_encoder.nt_model.attention_blocks.21.mha.w_q.bias": "model-00001-of-00004.safetensors",
472
+ "biobrain_encoder.nt_model.attention_blocks.21.mha.w_q.weight": "model-00001-of-00004.safetensors",
473
+ "biobrain_encoder.nt_model.attention_blocks.21.mha.w_v.bias": "model-00001-of-00004.safetensors",
474
+ "biobrain_encoder.nt_model.attention_blocks.21.mha.w_v.weight": "model-00001-of-00004.safetensors",
475
+ "biobrain_encoder.nt_model.attention_blocks.22.fc1.weight": "model-00001-of-00004.safetensors",
476
+ "biobrain_encoder.nt_model.attention_blocks.22.fc2.weight": "model-00001-of-00004.safetensors",
477
+ "biobrain_encoder.nt_model.attention_blocks.22.layer_norm_mlp.bias": "model-00001-of-00004.safetensors",
478
+ "biobrain_encoder.nt_model.attention_blocks.22.layer_norm_mlp.weight": "model-00001-of-00004.safetensors",
479
+ "biobrain_encoder.nt_model.attention_blocks.22.layer_norm_self_attention.bias": "model-00001-of-00004.safetensors",
480
+ "biobrain_encoder.nt_model.attention_blocks.22.layer_norm_self_attention.weight": "model-00001-of-00004.safetensors",
481
+ "biobrain_encoder.nt_model.attention_blocks.22.mha.output.bias": "model-00001-of-00004.safetensors",
482
+ "biobrain_encoder.nt_model.attention_blocks.22.mha.output.weight": "model-00001-of-00004.safetensors",
483
+ "biobrain_encoder.nt_model.attention_blocks.22.mha.w_k.bias": "model-00001-of-00004.safetensors",
484
+ "biobrain_encoder.nt_model.attention_blocks.22.mha.w_k.weight": "model-00001-of-00004.safetensors",
485
+ "biobrain_encoder.nt_model.attention_blocks.22.mha.w_q.bias": "model-00001-of-00004.safetensors",
486
+ "biobrain_encoder.nt_model.attention_blocks.22.mha.w_q.weight": "model-00001-of-00004.safetensors",
487
+ "biobrain_encoder.nt_model.attention_blocks.22.mha.w_v.bias": "model-00001-of-00004.safetensors",
488
+ "biobrain_encoder.nt_model.attention_blocks.22.mha.w_v.weight": "model-00001-of-00004.safetensors",
489
+ "biobrain_encoder.nt_model.attention_blocks.23.fc1.weight": "model-00001-of-00004.safetensors",
490
+ "biobrain_encoder.nt_model.attention_blocks.23.fc2.weight": "model-00001-of-00004.safetensors",
491
+ "biobrain_encoder.nt_model.attention_blocks.23.layer_norm_mlp.bias": "model-00001-of-00004.safetensors",
492
+ "biobrain_encoder.nt_model.attention_blocks.23.layer_norm_mlp.weight": "model-00001-of-00004.safetensors",
493
+ "biobrain_encoder.nt_model.attention_blocks.23.layer_norm_self_attention.bias": "model-00001-of-00004.safetensors",
494
+ "biobrain_encoder.nt_model.attention_blocks.23.layer_norm_self_attention.weight": "model-00001-of-00004.safetensors",
495
+ "biobrain_encoder.nt_model.attention_blocks.23.mha.output.bias": "model-00001-of-00004.safetensors",
496
+ "biobrain_encoder.nt_model.attention_blocks.23.mha.output.weight": "model-00001-of-00004.safetensors",
497
+ "biobrain_encoder.nt_model.attention_blocks.23.mha.w_k.bias": "model-00001-of-00004.safetensors",
498
+ "biobrain_encoder.nt_model.attention_blocks.23.mha.w_k.weight": "model-00001-of-00004.safetensors",
499
+ "biobrain_encoder.nt_model.attention_blocks.23.mha.w_q.bias": "model-00001-of-00004.safetensors",
500
+ "biobrain_encoder.nt_model.attention_blocks.23.mha.w_q.weight": "model-00001-of-00004.safetensors",
501
+ "biobrain_encoder.nt_model.attention_blocks.23.mha.w_v.bias": "model-00001-of-00004.safetensors",
502
+ "biobrain_encoder.nt_model.attention_blocks.23.mha.w_v.weight": "model-00001-of-00004.safetensors",
503
+ "biobrain_encoder.nt_model.attention_blocks.24.fc1.weight": "model-00001-of-00004.safetensors",
504
+ "biobrain_encoder.nt_model.attention_blocks.24.fc2.weight": "model-00001-of-00004.safetensors",
505
+ "biobrain_encoder.nt_model.attention_blocks.24.layer_norm_mlp.bias": "model-00001-of-00004.safetensors",
506
+ "biobrain_encoder.nt_model.attention_blocks.24.layer_norm_mlp.weight": "model-00001-of-00004.safetensors",
507
+ "biobrain_encoder.nt_model.attention_blocks.24.layer_norm_self_attention.bias": "model-00001-of-00004.safetensors",
508
+ "biobrain_encoder.nt_model.attention_blocks.24.layer_norm_self_attention.weight": "model-00001-of-00004.safetensors",
509
+ "biobrain_encoder.nt_model.attention_blocks.24.mha.output.bias": "model-00001-of-00004.safetensors",
510
+ "biobrain_encoder.nt_model.attention_blocks.24.mha.output.weight": "model-00001-of-00004.safetensors",
511
+ "biobrain_encoder.nt_model.attention_blocks.24.mha.w_k.bias": "model-00001-of-00004.safetensors",
512
+ "biobrain_encoder.nt_model.attention_blocks.24.mha.w_k.weight": "model-00001-of-00004.safetensors",
513
+ "biobrain_encoder.nt_model.attention_blocks.24.mha.w_q.bias": "model-00001-of-00004.safetensors",
514
+ "biobrain_encoder.nt_model.attention_blocks.24.mha.w_q.weight": "model-00001-of-00004.safetensors",
515
+ "biobrain_encoder.nt_model.attention_blocks.24.mha.w_v.bias": "model-00001-of-00004.safetensors",
516
+ "biobrain_encoder.nt_model.attention_blocks.24.mha.w_v.weight": "model-00001-of-00004.safetensors",
517
+ "biobrain_encoder.nt_model.attention_blocks.25.fc1.weight": "model-00001-of-00004.safetensors",
518
+ "biobrain_encoder.nt_model.attention_blocks.25.fc2.weight": "model-00001-of-00004.safetensors",
519
+ "biobrain_encoder.nt_model.attention_blocks.25.layer_norm_mlp.bias": "model-00001-of-00004.safetensors",
520
+ "biobrain_encoder.nt_model.attention_blocks.25.layer_norm_mlp.weight": "model-00001-of-00004.safetensors",
521
+ "biobrain_encoder.nt_model.attention_blocks.25.layer_norm_self_attention.bias": "model-00001-of-00004.safetensors",
522
+ "biobrain_encoder.nt_model.attention_blocks.25.layer_norm_self_attention.weight": "model-00001-of-00004.safetensors",
523
+ "biobrain_encoder.nt_model.attention_blocks.25.mha.output.bias": "model-00001-of-00004.safetensors",
524
+ "biobrain_encoder.nt_model.attention_blocks.25.mha.output.weight": "model-00001-of-00004.safetensors",
525
+ "biobrain_encoder.nt_model.attention_blocks.25.mha.w_k.bias": "model-00001-of-00004.safetensors",
526
+ "biobrain_encoder.nt_model.attention_blocks.25.mha.w_k.weight": "model-00001-of-00004.safetensors",
527
+ "biobrain_encoder.nt_model.attention_blocks.25.mha.w_q.bias": "model-00001-of-00004.safetensors",
528
+ "biobrain_encoder.nt_model.attention_blocks.25.mha.w_q.weight": "model-00001-of-00004.safetensors",
529
+ "biobrain_encoder.nt_model.attention_blocks.25.mha.w_v.bias": "model-00001-of-00004.safetensors",
530
+ "biobrain_encoder.nt_model.attention_blocks.25.mha.w_v.weight": "model-00001-of-00004.safetensors",
531
+ "biobrain_encoder.nt_model.attention_blocks.26.fc1.weight": "model-00001-of-00004.safetensors",
532
+ "biobrain_encoder.nt_model.attention_blocks.26.fc2.weight": "model-00001-of-00004.safetensors",
533
+ "biobrain_encoder.nt_model.attention_blocks.26.layer_norm_mlp.bias": "model-00001-of-00004.safetensors",
534
+ "biobrain_encoder.nt_model.attention_blocks.26.layer_norm_mlp.weight": "model-00001-of-00004.safetensors",
535
+ "biobrain_encoder.nt_model.attention_blocks.26.layer_norm_self_attention.bias": "model-00001-of-00004.safetensors",
536
+ "biobrain_encoder.nt_model.attention_blocks.26.layer_norm_self_attention.weight": "model-00001-of-00004.safetensors",
537
+ "biobrain_encoder.nt_model.attention_blocks.26.mha.output.bias": "model-00001-of-00004.safetensors",
538
+ "biobrain_encoder.nt_model.attention_blocks.26.mha.output.weight": "model-00001-of-00004.safetensors",
539
+ "biobrain_encoder.nt_model.attention_blocks.26.mha.w_k.bias": "model-00001-of-00004.safetensors",
540
+ "biobrain_encoder.nt_model.attention_blocks.26.mha.w_k.weight": "model-00001-of-00004.safetensors",
541
+ "biobrain_encoder.nt_model.attention_blocks.26.mha.w_q.bias": "model-00001-of-00004.safetensors",
542
+ "biobrain_encoder.nt_model.attention_blocks.26.mha.w_q.weight": "model-00001-of-00004.safetensors",
543
+ "biobrain_encoder.nt_model.attention_blocks.26.mha.w_v.bias": "model-00001-of-00004.safetensors",
544
+ "biobrain_encoder.nt_model.attention_blocks.26.mha.w_v.weight": "model-00001-of-00004.safetensors",
545
+ "biobrain_encoder.nt_model.attention_blocks.27.fc1.weight": "model-00001-of-00004.safetensors",
546
+ "biobrain_encoder.nt_model.attention_blocks.27.fc2.weight": "model-00001-of-00004.safetensors",
547
+ "biobrain_encoder.nt_model.attention_blocks.27.layer_norm_mlp.bias": "model-00001-of-00004.safetensors",
548
+ "biobrain_encoder.nt_model.attention_blocks.27.layer_norm_mlp.weight": "model-00001-of-00004.safetensors",
549
+ "biobrain_encoder.nt_model.attention_blocks.27.layer_norm_self_attention.bias": "model-00001-of-00004.safetensors",
550
+ "biobrain_encoder.nt_model.attention_blocks.27.layer_norm_self_attention.weight": "model-00001-of-00004.safetensors",
551
+ "biobrain_encoder.nt_model.attention_blocks.27.mha.output.bias": "model-00001-of-00004.safetensors",
552
+ "biobrain_encoder.nt_model.attention_blocks.27.mha.output.weight": "model-00001-of-00004.safetensors",
553
+ "biobrain_encoder.nt_model.attention_blocks.27.mha.w_k.bias": "model-00001-of-00004.safetensors",
554
+ "biobrain_encoder.nt_model.attention_blocks.27.mha.w_k.weight": "model-00001-of-00004.safetensors",
555
+ "biobrain_encoder.nt_model.attention_blocks.27.mha.w_q.bias": "model-00001-of-00004.safetensors",
556
+ "biobrain_encoder.nt_model.attention_blocks.27.mha.w_q.weight": "model-00001-of-00004.safetensors",
557
+ "biobrain_encoder.nt_model.attention_blocks.27.mha.w_v.bias": "model-00001-of-00004.safetensors",
558
+ "biobrain_encoder.nt_model.attention_blocks.27.mha.w_v.weight": "model-00001-of-00004.safetensors",
559
+ "biobrain_encoder.nt_model.attention_blocks.28.fc1.weight": "model-00001-of-00004.safetensors",
560
+ "biobrain_encoder.nt_model.attention_blocks.28.fc2.weight": "model-00001-of-00004.safetensors",
561
+ "biobrain_encoder.nt_model.attention_blocks.28.layer_norm_mlp.bias": "model-00001-of-00004.safetensors",
562
+ "biobrain_encoder.nt_model.attention_blocks.28.layer_norm_mlp.weight": "model-00001-of-00004.safetensors",
563
+ "biobrain_encoder.nt_model.attention_blocks.28.layer_norm_self_attention.bias": "model-00001-of-00004.safetensors",
564
+ "biobrain_encoder.nt_model.attention_blocks.28.layer_norm_self_attention.weight": "model-00001-of-00004.safetensors",
565
+ "biobrain_encoder.nt_model.attention_blocks.28.mha.output.bias": "model-00001-of-00004.safetensors",
566
+ "biobrain_encoder.nt_model.attention_blocks.28.mha.output.weight": "model-00001-of-00004.safetensors",
567
+ "biobrain_encoder.nt_model.attention_blocks.28.mha.w_k.bias": "model-00001-of-00004.safetensors",
568
+ "biobrain_encoder.nt_model.attention_blocks.28.mha.w_k.weight": "model-00001-of-00004.safetensors",
569
+ "biobrain_encoder.nt_model.attention_blocks.28.mha.w_q.bias": "model-00001-of-00004.safetensors",
570
+ "biobrain_encoder.nt_model.attention_blocks.28.mha.w_q.weight": "model-00001-of-00004.safetensors",
571
+ "biobrain_encoder.nt_model.attention_blocks.28.mha.w_v.bias": "model-00001-of-00004.safetensors",
572
+ "biobrain_encoder.nt_model.attention_blocks.28.mha.w_v.weight": "model-00001-of-00004.safetensors",
573
+ "biobrain_encoder.nt_model.attention_blocks.3.fc1.weight": "model-00001-of-00004.safetensors",
574
+ "biobrain_encoder.nt_model.attention_blocks.3.fc2.weight": "model-00001-of-00004.safetensors",
575
+ "biobrain_encoder.nt_model.attention_blocks.3.layer_norm_mlp.bias": "model-00001-of-00004.safetensors",
576
+ "biobrain_encoder.nt_model.attention_blocks.3.layer_norm_mlp.weight": "model-00001-of-00004.safetensors",
577
+ "biobrain_encoder.nt_model.attention_blocks.3.layer_norm_self_attention.bias": "model-00001-of-00004.safetensors",
578
+ "biobrain_encoder.nt_model.attention_blocks.3.layer_norm_self_attention.weight": "model-00001-of-00004.safetensors",
579
+ "biobrain_encoder.nt_model.attention_blocks.3.mha.output.bias": "model-00001-of-00004.safetensors",
580
+ "biobrain_encoder.nt_model.attention_blocks.3.mha.output.weight": "model-00001-of-00004.safetensors",
581
+ "biobrain_encoder.nt_model.attention_blocks.3.mha.w_k.bias": "model-00001-of-00004.safetensors",
582
+ "biobrain_encoder.nt_model.attention_blocks.3.mha.w_k.weight": "model-00001-of-00004.safetensors",
583
+ "biobrain_encoder.nt_model.attention_blocks.3.mha.w_q.bias": "model-00001-of-00004.safetensors",
584
+ "biobrain_encoder.nt_model.attention_blocks.3.mha.w_q.weight": "model-00001-of-00004.safetensors",
585
+ "biobrain_encoder.nt_model.attention_blocks.3.mha.w_v.bias": "model-00001-of-00004.safetensors",
586
+ "biobrain_encoder.nt_model.attention_blocks.3.mha.w_v.weight": "model-00001-of-00004.safetensors",
587
+ "biobrain_encoder.nt_model.attention_blocks.4.fc1.weight": "model-00001-of-00004.safetensors",
588
+ "biobrain_encoder.nt_model.attention_blocks.4.fc2.weight": "model-00001-of-00004.safetensors",
589
+ "biobrain_encoder.nt_model.attention_blocks.4.layer_norm_mlp.bias": "model-00001-of-00004.safetensors",
590
+ "biobrain_encoder.nt_model.attention_blocks.4.layer_norm_mlp.weight": "model-00001-of-00004.safetensors",
591
+ "biobrain_encoder.nt_model.attention_blocks.4.layer_norm_self_attention.bias": "model-00001-of-00004.safetensors",
592
+ "biobrain_encoder.nt_model.attention_blocks.4.layer_norm_self_attention.weight": "model-00001-of-00004.safetensors",
593
+ "biobrain_encoder.nt_model.attention_blocks.4.mha.output.bias": "model-00001-of-00004.safetensors",
594
+ "biobrain_encoder.nt_model.attention_blocks.4.mha.output.weight": "model-00001-of-00004.safetensors",
595
+ "biobrain_encoder.nt_model.attention_blocks.4.mha.w_k.bias": "model-00001-of-00004.safetensors",
596
+ "biobrain_encoder.nt_model.attention_blocks.4.mha.w_k.weight": "model-00001-of-00004.safetensors",
597
+ "biobrain_encoder.nt_model.attention_blocks.4.mha.w_q.bias": "model-00001-of-00004.safetensors",
598
+ "biobrain_encoder.nt_model.attention_blocks.4.mha.w_q.weight": "model-00001-of-00004.safetensors",
599
+ "biobrain_encoder.nt_model.attention_blocks.4.mha.w_v.bias": "model-00001-of-00004.safetensors",
600
+ "biobrain_encoder.nt_model.attention_blocks.4.mha.w_v.weight": "model-00001-of-00004.safetensors",
601
+ "biobrain_encoder.nt_model.attention_blocks.5.fc1.weight": "model-00001-of-00004.safetensors",
602
+ "biobrain_encoder.nt_model.attention_blocks.5.fc2.weight": "model-00001-of-00004.safetensors",
603
+ "biobrain_encoder.nt_model.attention_blocks.5.layer_norm_mlp.bias": "model-00001-of-00004.safetensors",
604
+ "biobrain_encoder.nt_model.attention_blocks.5.layer_norm_mlp.weight": "model-00001-of-00004.safetensors",
605
+ "biobrain_encoder.nt_model.attention_blocks.5.layer_norm_self_attention.bias": "model-00001-of-00004.safetensors",
606
+ "biobrain_encoder.nt_model.attention_blocks.5.layer_norm_self_attention.weight": "model-00001-of-00004.safetensors",
607
+ "biobrain_encoder.nt_model.attention_blocks.5.mha.output.bias": "model-00001-of-00004.safetensors",
608
+ "biobrain_encoder.nt_model.attention_blocks.5.mha.output.weight": "model-00001-of-00004.safetensors",
609
+ "biobrain_encoder.nt_model.attention_blocks.5.mha.w_k.bias": "model-00001-of-00004.safetensors",
610
+ "biobrain_encoder.nt_model.attention_blocks.5.mha.w_k.weight": "model-00001-of-00004.safetensors",
611
+ "biobrain_encoder.nt_model.attention_blocks.5.mha.w_q.bias": "model-00001-of-00004.safetensors",
612
+ "biobrain_encoder.nt_model.attention_blocks.5.mha.w_q.weight": "model-00001-of-00004.safetensors",
613
+ "biobrain_encoder.nt_model.attention_blocks.5.mha.w_v.bias": "model-00001-of-00004.safetensors",
614
+ "biobrain_encoder.nt_model.attention_blocks.5.mha.w_v.weight": "model-00001-of-00004.safetensors",
615
+ "biobrain_encoder.nt_model.attention_blocks.6.fc1.weight": "model-00001-of-00004.safetensors",
616
+ "biobrain_encoder.nt_model.attention_blocks.6.fc2.weight": "model-00001-of-00004.safetensors",
617
+ "biobrain_encoder.nt_model.attention_blocks.6.layer_norm_mlp.bias": "model-00001-of-00004.safetensors",
618
+ "biobrain_encoder.nt_model.attention_blocks.6.layer_norm_mlp.weight": "model-00001-of-00004.safetensors",
619
+ "biobrain_encoder.nt_model.attention_blocks.6.layer_norm_self_attention.bias": "model-00001-of-00004.safetensors",
620
+ "biobrain_encoder.nt_model.attention_blocks.6.layer_norm_self_attention.weight": "model-00001-of-00004.safetensors",
621
+ "biobrain_encoder.nt_model.attention_blocks.6.mha.output.bias": "model-00001-of-00004.safetensors",
622
+ "biobrain_encoder.nt_model.attention_blocks.6.mha.output.weight": "model-00001-of-00004.safetensors",
623
+ "biobrain_encoder.nt_model.attention_blocks.6.mha.w_k.bias": "model-00001-of-00004.safetensors",
624
+ "biobrain_encoder.nt_model.attention_blocks.6.mha.w_k.weight": "model-00001-of-00004.safetensors",
625
+ "biobrain_encoder.nt_model.attention_blocks.6.mha.w_q.bias": "model-00001-of-00004.safetensors",
626
+ "biobrain_encoder.nt_model.attention_blocks.6.mha.w_q.weight": "model-00001-of-00004.safetensors",
627
+ "biobrain_encoder.nt_model.attention_blocks.6.mha.w_v.bias": "model-00001-of-00004.safetensors",
628
+ "biobrain_encoder.nt_model.attention_blocks.6.mha.w_v.weight": "model-00001-of-00004.safetensors",
629
+ "biobrain_encoder.nt_model.attention_blocks.7.fc1.weight": "model-00001-of-00004.safetensors",
630
+ "biobrain_encoder.nt_model.attention_blocks.7.fc2.weight": "model-00001-of-00004.safetensors",
631
+ "biobrain_encoder.nt_model.attention_blocks.7.layer_norm_mlp.bias": "model-00001-of-00004.safetensors",
632
+ "biobrain_encoder.nt_model.attention_blocks.7.layer_norm_mlp.weight": "model-00001-of-00004.safetensors",
633
+ "biobrain_encoder.nt_model.attention_blocks.7.layer_norm_self_attention.bias": "model-00001-of-00004.safetensors",
634
+ "biobrain_encoder.nt_model.attention_blocks.7.layer_norm_self_attention.weight": "model-00001-of-00004.safetensors",
635
+ "biobrain_encoder.nt_model.attention_blocks.7.mha.output.bias": "model-00001-of-00004.safetensors",
636
+ "biobrain_encoder.nt_model.attention_blocks.7.mha.output.weight": "model-00001-of-00004.safetensors",
637
+ "biobrain_encoder.nt_model.attention_blocks.7.mha.w_k.bias": "model-00001-of-00004.safetensors",
638
+ "biobrain_encoder.nt_model.attention_blocks.7.mha.w_k.weight": "model-00001-of-00004.safetensors",
639
+ "biobrain_encoder.nt_model.attention_blocks.7.mha.w_q.bias": "model-00001-of-00004.safetensors",
640
+ "biobrain_encoder.nt_model.attention_blocks.7.mha.w_q.weight": "model-00001-of-00004.safetensors",
641
+ "biobrain_encoder.nt_model.attention_blocks.7.mha.w_v.bias": "model-00001-of-00004.safetensors",
642
+ "biobrain_encoder.nt_model.attention_blocks.7.mha.w_v.weight": "model-00001-of-00004.safetensors",
643
+ "biobrain_encoder.nt_model.attention_blocks.8.fc1.weight": "model-00001-of-00004.safetensors",
644
+ "biobrain_encoder.nt_model.attention_blocks.8.fc2.weight": "model-00001-of-00004.safetensors",
645
+ "biobrain_encoder.nt_model.attention_blocks.8.layer_norm_mlp.bias": "model-00001-of-00004.safetensors",
646
+ "biobrain_encoder.nt_model.attention_blocks.8.layer_norm_mlp.weight": "model-00001-of-00004.safetensors",
647
+ "biobrain_encoder.nt_model.attention_blocks.8.layer_norm_self_attention.bias": "model-00001-of-00004.safetensors",
648
+ "biobrain_encoder.nt_model.attention_blocks.8.layer_norm_self_attention.weight": "model-00001-of-00004.safetensors",
649
+ "biobrain_encoder.nt_model.attention_blocks.8.mha.output.bias": "model-00001-of-00004.safetensors",
650
+ "biobrain_encoder.nt_model.attention_blocks.8.mha.output.weight": "model-00001-of-00004.safetensors",
651
+ "biobrain_encoder.nt_model.attention_blocks.8.mha.w_k.bias": "model-00001-of-00004.safetensors",
652
+ "biobrain_encoder.nt_model.attention_blocks.8.mha.w_k.weight": "model-00001-of-00004.safetensors",
653
+ "biobrain_encoder.nt_model.attention_blocks.8.mha.w_q.bias": "model-00001-of-00004.safetensors",
654
+ "biobrain_encoder.nt_model.attention_blocks.8.mha.w_q.weight": "model-00001-of-00004.safetensors",
655
+ "biobrain_encoder.nt_model.attention_blocks.8.mha.w_v.bias": "model-00001-of-00004.safetensors",
656
+ "biobrain_encoder.nt_model.attention_blocks.8.mha.w_v.weight": "model-00001-of-00004.safetensors",
657
+ "biobrain_encoder.nt_model.attention_blocks.9.fc1.weight": "model-00001-of-00004.safetensors",
658
+ "biobrain_encoder.nt_model.attention_blocks.9.fc2.weight": "model-00001-of-00004.safetensors",
659
+ "biobrain_encoder.nt_model.attention_blocks.9.layer_norm_mlp.bias": "model-00001-of-00004.safetensors",
660
+ "biobrain_encoder.nt_model.attention_blocks.9.layer_norm_mlp.weight": "model-00001-of-00004.safetensors",
661
+ "biobrain_encoder.nt_model.attention_blocks.9.layer_norm_self_attention.bias": "model-00001-of-00004.safetensors",
662
+ "biobrain_encoder.nt_model.attention_blocks.9.layer_norm_self_attention.weight": "model-00001-of-00004.safetensors",
663
+ "biobrain_encoder.nt_model.attention_blocks.9.mha.output.bias": "model-00001-of-00004.safetensors",
664
+ "biobrain_encoder.nt_model.attention_blocks.9.mha.output.weight": "model-00001-of-00004.safetensors",
665
+ "biobrain_encoder.nt_model.attention_blocks.9.mha.w_k.bias": "model-00001-of-00004.safetensors",
666
+ "biobrain_encoder.nt_model.attention_blocks.9.mha.w_k.weight": "model-00001-of-00004.safetensors",
667
+ "biobrain_encoder.nt_model.attention_blocks.9.mha.w_q.bias": "model-00001-of-00004.safetensors",
668
+ "biobrain_encoder.nt_model.attention_blocks.9.mha.w_q.weight": "model-00001-of-00004.safetensors",
669
+ "biobrain_encoder.nt_model.attention_blocks.9.mha.w_v.bias": "model-00001-of-00004.safetensors",
670
+ "biobrain_encoder.nt_model.attention_blocks.9.mha.w_v.weight": "model-00001-of-00004.safetensors",
671
+ "biobrain_encoder.nt_model.embed_layer.weight": "model-00001-of-00004.safetensors",
672
+ "biobrain_encoder.nt_model.lm_head._fc1.bias": "model-00001-of-00004.safetensors",
673
+ "biobrain_encoder.nt_model.lm_head._fc1.weight": "model-00001-of-00004.safetensors",
674
+ "biobrain_encoder.nt_model.lm_head._final_fc.bias": "model-00001-of-00004.safetensors",
675
+ "biobrain_encoder.nt_model.lm_head._final_fc.weight": "model-00001-of-00004.safetensors",
676
+ "biobrain_encoder.nt_model.lm_head._first_layer_norm.bias": "model-00001-of-00004.safetensors",
677
+ "biobrain_encoder.nt_model.lm_head._first_layer_norm.weight": "model-00001-of-00004.safetensors",
678
+ "biobrain_encoder.nt_model.lm_head._second_layer_norm.bias": "model-00001-of-00004.safetensors",
679
+ "biobrain_encoder.nt_model.lm_head._second_layer_norm.weight": "model-00001-of-00004.safetensors",
680
+ "projection_model.bio_projection.bias": "model-00003-of-00004.safetensors",
681
+ "projection_model.bio_projection.weight": "model-00003-of-00004.safetensors",
682
+ "projection_model.perceiver_resampler.latent_queries": "model-00003-of-00004.safetensors",
683
+ "projection_model.perceiver_resampler.layers.0.cross_attention_1.output.bias": "model-00003-of-00004.safetensors",
684
+ "projection_model.perceiver_resampler.layers.0.cross_attention_1.output.weight": "model-00003-of-00004.safetensors",
685
+ "projection_model.perceiver_resampler.layers.0.cross_attention_1.w_k.bias": "model-00003-of-00004.safetensors",
686
+ "projection_model.perceiver_resampler.layers.0.cross_attention_1.w_k.weight": "model-00003-of-00004.safetensors",
687
+ "projection_model.perceiver_resampler.layers.0.cross_attention_1.w_q.bias": "model-00003-of-00004.safetensors",
688
+ "projection_model.perceiver_resampler.layers.0.cross_attention_1.w_q.weight": "model-00003-of-00004.safetensors",
689
+ "projection_model.perceiver_resampler.layers.0.cross_attention_1.w_v.bias": "model-00003-of-00004.safetensors",
690
+ "projection_model.perceiver_resampler.layers.0.cross_attention_1.w_v.weight": "model-00003-of-00004.safetensors",
691
+ "projection_model.perceiver_resampler.layers.0.cross_attention_2.output.bias": "model-00004-of-00004.safetensors",
692
+ "projection_model.perceiver_resampler.layers.0.cross_attention_2.output.weight": "model-00004-of-00004.safetensors",
693
+ "projection_model.perceiver_resampler.layers.0.cross_attention_2.w_k.bias": "model-00004-of-00004.safetensors",
694
+ "projection_model.perceiver_resampler.layers.0.cross_attention_2.w_k.weight": "model-00004-of-00004.safetensors",
695
+ "projection_model.perceiver_resampler.layers.0.cross_attention_2.w_q.bias": "model-00004-of-00004.safetensors",
696
+ "projection_model.perceiver_resampler.layers.0.cross_attention_2.w_q.weight": "model-00004-of-00004.safetensors",
697
+ "projection_model.perceiver_resampler.layers.0.cross_attention_2.w_v.bias": "model-00004-of-00004.safetensors",
698
+ "projection_model.perceiver_resampler.layers.0.cross_attention_2.w_v.weight": "model-00004-of-00004.safetensors",
699
+ "projection_model.perceiver_resampler.layers.0.fc1.bias": "model-00004-of-00004.safetensors",
700
+ "projection_model.perceiver_resampler.layers.0.fc1.weight": "model-00004-of-00004.safetensors",
701
+ "projection_model.perceiver_resampler.layers.0.fc2.bias": "model-00004-of-00004.safetensors",
702
+ "projection_model.perceiver_resampler.layers.0.fc2.weight": "model-00004-of-00004.safetensors",
703
+ "projection_model.perceiver_resampler.layers.0.norm_cross_attention_1.bias": "model-00004-of-00004.safetensors",
704
+ "projection_model.perceiver_resampler.layers.0.norm_cross_attention_1.weight": "model-00004-of-00004.safetensors",
705
+ "projection_model.perceiver_resampler.layers.0.norm_cross_attention_2.bias": "model-00004-of-00004.safetensors",
706
+ "projection_model.perceiver_resampler.layers.0.norm_cross_attention_2.weight": "model-00004-of-00004.safetensors",
707
+ "projection_model.perceiver_resampler.layers.0.norm_mlp.bias": "model-00004-of-00004.safetensors",
708
+ "projection_model.perceiver_resampler.layers.0.norm_mlp.weight": "model-00004-of-00004.safetensors",
709
+ "projection_model.perceiver_resampler.layers.1.cross_attention_1.output.bias": "model-00004-of-00004.safetensors",
710
+ "projection_model.perceiver_resampler.layers.1.cross_attention_1.output.weight": "model-00004-of-00004.safetensors",
711
+ "projection_model.perceiver_resampler.layers.1.cross_attention_1.w_k.bias": "model-00004-of-00004.safetensors",
712
+ "projection_model.perceiver_resampler.layers.1.cross_attention_1.w_k.weight": "model-00004-of-00004.safetensors",
713
+ "projection_model.perceiver_resampler.layers.1.cross_attention_1.w_q.bias": "model-00004-of-00004.safetensors",
714
+ "projection_model.perceiver_resampler.layers.1.cross_attention_1.w_q.weight": "model-00004-of-00004.safetensors",
715
+ "projection_model.perceiver_resampler.layers.1.cross_attention_1.w_v.bias": "model-00004-of-00004.safetensors",
716
+ "projection_model.perceiver_resampler.layers.1.cross_attention_1.w_v.weight": "model-00004-of-00004.safetensors",
717
+ "projection_model.perceiver_resampler.layers.1.cross_attention_2.output.bias": "model-00004-of-00004.safetensors",
718
+ "projection_model.perceiver_resampler.layers.1.cross_attention_2.output.weight": "model-00004-of-00004.safetensors",
719
+ "projection_model.perceiver_resampler.layers.1.cross_attention_2.w_k.bias": "model-00004-of-00004.safetensors",
720
+ "projection_model.perceiver_resampler.layers.1.cross_attention_2.w_k.weight": "model-00004-of-00004.safetensors",
721
+ "projection_model.perceiver_resampler.layers.1.cross_attention_2.w_q.bias": "model-00004-of-00004.safetensors",
722
+ "projection_model.perceiver_resampler.layers.1.cross_attention_2.w_q.weight": "model-00004-of-00004.safetensors",
723
+ "projection_model.perceiver_resampler.layers.1.cross_attention_2.w_v.bias": "model-00004-of-00004.safetensors",
724
+ "projection_model.perceiver_resampler.layers.1.cross_attention_2.w_v.weight": "model-00004-of-00004.safetensors",
725
+ "projection_model.perceiver_resampler.layers.1.fc1.bias": "model-00004-of-00004.safetensors",
726
+ "projection_model.perceiver_resampler.layers.1.fc1.weight": "model-00004-of-00004.safetensors",
727
+ "projection_model.perceiver_resampler.layers.1.fc2.bias": "model-00004-of-00004.safetensors",
728
+ "projection_model.perceiver_resampler.layers.1.fc2.weight": "model-00004-of-00004.safetensors",
729
+ "projection_model.perceiver_resampler.layers.1.norm_cross_attention_1.bias": "model-00004-of-00004.safetensors",
730
+ "projection_model.perceiver_resampler.layers.1.norm_cross_attention_1.weight": "model-00004-of-00004.safetensors",
731
+ "projection_model.perceiver_resampler.layers.1.norm_cross_attention_2.bias": "model-00004-of-00004.safetensors",
732
+ "projection_model.perceiver_resampler.layers.1.norm_cross_attention_2.weight": "model-00004-of-00004.safetensors",
733
+ "projection_model.perceiver_resampler.layers.1.norm_mlp.bias": "model-00004-of-00004.safetensors",
734
+ "projection_model.perceiver_resampler.layers.1.norm_mlp.weight": "model-00004-of-00004.safetensors",
735
+ "projection_model.perceiver_resampler.layers.2.cross_attention_1.output.bias": "model-00004-of-00004.safetensors",
736
+ "projection_model.perceiver_resampler.layers.2.cross_attention_1.output.weight": "model-00004-of-00004.safetensors",
737
+ "projection_model.perceiver_resampler.layers.2.cross_attention_1.w_k.bias": "model-00004-of-00004.safetensors",
738
+ "projection_model.perceiver_resampler.layers.2.cross_attention_1.w_k.weight": "model-00004-of-00004.safetensors",
739
+ "projection_model.perceiver_resampler.layers.2.cross_attention_1.w_q.bias": "model-00004-of-00004.safetensors",
740
+ "projection_model.perceiver_resampler.layers.2.cross_attention_1.w_q.weight": "model-00004-of-00004.safetensors",
741
+ "projection_model.perceiver_resampler.layers.2.cross_attention_1.w_v.bias": "model-00004-of-00004.safetensors",
742
+ "projection_model.perceiver_resampler.layers.2.cross_attention_1.w_v.weight": "model-00004-of-00004.safetensors",
743
+ "projection_model.perceiver_resampler.layers.2.cross_attention_2.output.bias": "model-00004-of-00004.safetensors",
744
+ "projection_model.perceiver_resampler.layers.2.cross_attention_2.output.weight": "model-00004-of-00004.safetensors",
745
+ "projection_model.perceiver_resampler.layers.2.cross_attention_2.w_k.bias": "model-00004-of-00004.safetensors",
746
+ "projection_model.perceiver_resampler.layers.2.cross_attention_2.w_k.weight": "model-00004-of-00004.safetensors",
747
+ "projection_model.perceiver_resampler.layers.2.cross_attention_2.w_q.bias": "model-00004-of-00004.safetensors",
748
+ "projection_model.perceiver_resampler.layers.2.cross_attention_2.w_q.weight": "model-00004-of-00004.safetensors",
749
+ "projection_model.perceiver_resampler.layers.2.cross_attention_2.w_v.bias": "model-00004-of-00004.safetensors",
750
+ "projection_model.perceiver_resampler.layers.2.cross_attention_2.w_v.weight": "model-00004-of-00004.safetensors",
751
+ "projection_model.perceiver_resampler.layers.2.fc1.bias": "model-00004-of-00004.safetensors",
752
+ "projection_model.perceiver_resampler.layers.2.fc1.weight": "model-00004-of-00004.safetensors",
753
+ "projection_model.perceiver_resampler.layers.2.fc2.bias": "model-00004-of-00004.safetensors",
754
+ "projection_model.perceiver_resampler.layers.2.fc2.weight": "model-00004-of-00004.safetensors",
755
+ "projection_model.perceiver_resampler.layers.2.norm_cross_attention_1.bias": "model-00004-of-00004.safetensors",
756
+ "projection_model.perceiver_resampler.layers.2.norm_cross_attention_1.weight": "model-00004-of-00004.safetensors",
757
+ "projection_model.perceiver_resampler.layers.2.norm_cross_attention_2.bias": "model-00004-of-00004.safetensors",
758
+ "projection_model.perceiver_resampler.layers.2.norm_cross_attention_2.weight": "model-00004-of-00004.safetensors",
759
+ "projection_model.perceiver_resampler.layers.2.norm_mlp.bias": "model-00004-of-00004.safetensors",
760
+ "projection_model.perceiver_resampler.layers.2.norm_mlp.weight": "model-00004-of-00004.safetensors",
761
+ "projection_model.token_embedding.weight": "model-00003-of-00004.safetensors"
762
+ }
763
+ }
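
The entries above form the tail of the checkpoint's weight map (the "weight_map" section of a sharded safetensors index): each parameter name is mapped to the shard file that stores it, with the biobrain encoder weights in shard 1 and the projection model split across shards 3 and 4. As a rough illustration only, and assuming the index is saved next to its shards under the conventional name model.safetensors.index.json, a map like this can be read directly with the safetensors library; the directory and variable names below are placeholders, not part of this repository's code.

# Minimal sketch (assumptions noted above): read a sharded safetensors index
# and load each parameter from the shard file listed in its weight_map entry.
import json
from pathlib import Path

from safetensors.torch import load_file

checkpoint_dir = Path("checkpoint")  # placeholder directory holding the shards
index = json.loads((checkpoint_dir / "model.safetensors.index.json").read_text())

# weight_map maps parameter names to shard files, e.g.
# "projection_model.token_embedding.weight" -> "model-00003-of-00004.safetensors"
weight_map = index["weight_map"]

# Group parameter names by shard so each shard file is opened only once.
params_per_shard: dict[str, list[str]] = {}
for param_name, shard_file in weight_map.items():
    params_per_shard.setdefault(shard_file, []).append(param_name)

state_dict = {}
for shard_file, param_names in params_per_shard.items():
    shard_tensors = load_file(str(checkpoint_dir / shard_file))  # dict of tensors
    for name in param_names:
        state_dict[name] = shard_tensors[name]

print(f"Loaded {len(state_dict)} tensors from {len(params_per_shard)} shard(s)")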