Yanisadel committed
Commit 64c0358 · 1 Parent(s): 386475b

Upload model

Files changed (3)
  1. chatNT.py +1883 -0
  2. config.json +85 -0
  3. model.safetensors.index.json +763 -0
chatNT.py ADDED
@@ -0,0 +1,1883 @@
1
+ # This file stores ChatNT and all associated layers and configs
2
+
3
+ from dataclasses import asdict, dataclass, field
4
+ from typing import Dict, List, Optional, Tuple
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F # noqa: N812
10
+ from transformers import PretrainedConfig, PreTrainedModel
11
+
12
+
13
+ @dataclass
14
+ class RotaryEmbeddingConfig:
15
+ """
16
+ Rotary Positional Embedding configuration
17
+ max_seq_len: The number of positions to encode and cache.
18
+ dim: Dimension of RoPE.
19
+ theta: Base frequency used to compute the rotation angles.
20
+ """
21
+
22
+ max_seq_len: int
23
+ dim: int
24
+ theta: float
25
+
26
+
27
+ @dataclass
28
+ class PerceiverResamplerConfig:
29
+ """
30
+ Parameters to initialize a PerceiverResampler model. Based on the ESM architecture.
31
+
32
+ Args:
33
+ emb_layer_norm_before: Whether to use layer norm before the first attention
34
+ layer.
35
+ attention_heads: Number of attention heads.
36
+ key_size: The dimension of the query, key, and values within each attention
37
+ head. If not specified, it is set to embed_dim // attention_heads.
38
+ It can be useful to set a custom key size if we want to impose the size of
39
+ the query, key and value tensor ( for example, tensors shaped with
40
+ power of 2 are more efficiently handled on TPUs ).
41
+ Note: Parametrizing the model with a custom key size has been done in :
42
+ Brown, Tom, et al. "Language models are few-shot learners."
43
+ Advances in neural information processing systems 33 (2020): 1877-1901.
44
+ embed_dim: Embedding dimension.
45
+ ffn_embed_dim: Feed forward embedding dimension.
46
+ num_layers: Number of attention blocks.
47
+ ffn_activation_name: Activation function to be used in FFN block. Supported
48
+ names are "gelu", "relu", "swish".
49
+ use_glu_in_ffn: Whether to use Gated Linear Unit (GLU) in Feed
50
+ Forward Network (FFN) block. To do a swiGLU (gated-swish) put this arg
51
+ to True and use swish as ffn_activation_name.
52
+ Same principle for a gated-relu. To keep the same number of parameters in
53
+ the FFN block, one should multiply by 2/3 the ffn_embed_dim when using GLU.
54
+ See https://arxiv.org/pdf/2002.05202.pdf for more details.
55
+ resampled_length: Length of the resampled output of the module.
56
+ use_gradient_checkpointing: Whether to use gradient checkpointing (checkpoint
57
+ gradients in the forward pass to reduce the computation in the backward).
58
+ """
59
+
60
+ # architecture
61
+ emb_layer_norm_before: bool = False
62
+ attention_heads: int = 20
63
+ key_size: Optional[int] = None
64
+ embed_dim: int = 1280
65
+ ffn_embed_dim: int = 5120
66
+ num_layers: int = 24
67
+ add_bias_kv: bool = False
68
+ add_bias_ffn: bool = True
69
+ ffn_activation_name: str = "gelu-no-approx"
70
+ use_glu_in_ffn: bool = False
71
+ resampled_length: int = 64
72
+
73
+ # performance
74
+ use_gradient_checkpointing: bool = False
75
+
76
+ def __post_init__(self) -> None:
77
+ """
78
+ Checks that the given values are compatible.
79
+ """
80
+
81
+ if self.key_size is None:
82
+ if not self.embed_dim % self.attention_heads == 0:
83
+ raise ValueError(
84
+ f"When no key size is provided, the embedding dimension should be "
85
+ f"divisible by the number of heads, however provided embedding "
86
+ f"dimension is {self.embed_dim} and the number of heads is "
87
+ f"{self.attention_heads}."
88
+ )
89
+ self.key_size = self.embed_dim // self.attention_heads
90
+
91
+
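# Usage sketch (illustrative; the values below are hypothetical, not taken from the
# released checkpoint): key_size is derived in __post_init__ when it is not given.
def _example_perceiver_resampler_config() -> None:
    config = PerceiverResamplerConfig(embed_dim=1280, attention_heads=20)
    assert config.key_size == 1280 // 20  # embed_dim // attention_heads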
92
+ @dataclass
93
+ class GptConfig:
94
+ """
95
+ Parameters to initialize a Gpt model.
96
+
97
+ NOTE: the pad token is not defined
98
+
99
+ Args:
100
+ vocab_size: Token vocabulary.
101
+ eos_token_id: used to stop sentence generation
102
+ embed_dim: Embedding dimension.
103
+ ffn_embed_dim: Feed forward embedding dimension.
104
+ num_heads: Number of attention heads.
105
+ num_kv_heads: Number of key and value heads to support Grouped-Query and
106
+ Multi-Query Attention. If None, the number of key and value heads is
107
+ equal to the number of attention heads.
108
+ num_layers: Number of decoder layers.
109
+ rope_config: The configuration for the rotary positional embeddings
110
+ add_bias_ffn: Add bias in feed forward network block.
111
+ ffn_activation_name: Activation function to be used in FFN block. Supported
112
+ names are "gelu", "gelu-no-approx", "relu", "swish".
113
+ use_glu_in_ffn: whether to use Gated Linear Unit (GLU) in Feed
114
+ Forward Network (FFN) block.
115
+ example: To do a swiGLU (gated-swish) put this arg
116
+ to True and use swish as ffn_activation_name.
117
+ Same principle for a gated-relu.
118
+ add_bias_lm_head: whether to use bias in the final LM layer
119
+ norm_type: The type of norm used ( pre normalization scheme ) used. can be
120
+ one of ["layer_norm", "RMS_norm"]
121
+ parallel_attention_ff: Whether to do the attention and the MLP in parallel,
122
+ and then sum up the results as it is done in Gpt-NeoX :
123
+ Black, Sid, et al. "Gpt-neox-20b: An open-source autoregressive
124
+ language model." arXiv preprint arXiv:2204.06745 (2022).
125
+ It is reported to improve training time by about 15% when compiling with JAX.
126
+ use_gradient_checkpointing: Whether to use gradient checkpointing (checkpoint
127
+ gradients in the forward pass to reduce the computation in the backward).
128
+ add_bias_attn: Add bias to the attention mechanism (key, query, value, and
129
+ output projections).
130
+ """
131
+
132
+ # vocabulary
133
+ vocab_size: int
134
+ eos_token_id: int
135
+
136
+ # architecture
137
+ embed_dim: int = 16
138
+ ffn_embed_dim: int = 64
139
+ num_heads: int = 2
140
+ num_kv_heads: Optional[int] = None
141
+ num_layers: int = 2
142
+ rope_config: RotaryEmbeddingConfig = field(
143
+ default_factory=lambda: RotaryEmbeddingConfig(
144
+ max_seq_len=512, dim=8, theta=10000.0
145
+ )
146
+ )
147
+ add_bias_ffn: bool = False
148
+ ffn_activation_name: str = "swish"
149
+ use_glu_in_ffn: bool = True
150
+ add_bias_lm_head: bool = False
151
+ norm_type: str = "RMS_norm"
152
+ rms_norm_eps: float = 1e-6
153
+ parallel_attention_ff: bool = True
154
+
155
+ # inference / backward behavior
156
+ use_gradient_checkpointing: bool = False
157
+
158
+ # architecture params with default values
159
+ add_bias_attn: bool = False
160
+
161
+ def __post_init__(self) -> None:
162
+ """
163
+ Checks that the given values are compatible.
164
+ """
165
+ if not self.embed_dim % self.num_heads == 0:
166
+ raise ValueError(
167
+ f"The embedding dimension should be "
168
+ f"divisible by the number of heads, however provided embedding "
169
+ f"dimension is {self.embed_dim} and the number of heads is "
170
+ f"{self.num_heads}."
171
+ )
172
+
173
+ if not self.embed_dim // self.num_heads > 1:
174
+ raise ValueError(
175
+ "embed_dim / num_heads must be higher than 2 to apply rotary embeddings"
176
+ )
177
+
178
+ if not self.embed_dim // self.num_heads >= self.rope_config.dim:
179
+ raise ValueError(
180
+ "embed_dim // num_heads must be higher than rope_config.dim "
181
+ "to apply rotary embeddings"
182
+ )
183
+
184
+ def to_dict(self): # type: ignore
185
+ output = asdict(self)
186
+ output["rope_config"] = asdict(self.rope_config)
187
+ return output
188
+
189
+
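# Usage sketch (illustrative; the vocab size, token id and dimensions are hypothetical):
# the divisibility checks in __post_init__ run at construction time, and to_dict()
# also serialises the nested RotaryEmbeddingConfig.
def _example_gpt_config() -> None:
    config = GptConfig(
        vocab_size=32000,
        eos_token_id=3,
        embed_dim=64,
        num_heads=4,
        rope_config=RotaryEmbeddingConfig(max_seq_len=2048, dim=16, theta=10000.0),
    )
    assert config.to_dict()["rope_config"]["dim"] == 16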
190
+ @dataclass
191
+ class ESMTransformerConfig:
192
+ """
193
+ Parameters to initialize an ESM model. While the ESM architecture is an encoder-only
194
+ model, different choices have been made for each version and this configuration aims
195
+ to cover most of them.
196
+
197
+ Args:
198
+ alphabet_size: Token vocabulary.
199
+ pad_token_id: ID of pad token.
200
+ mask_token_id: ID of mask token.
201
+ max_positions: Maximum sequence length.
202
+ embed_scale: Correction ratio applied to the embeddings to make up for the
203
+ norm difference between the input during training and inference.
204
+ emb_layer_norm_before: Whether to use layer norm before the first attention
205
+ layer.
206
+ attention_heads: Number of attention heads.
207
+ key_size: The dimension of the query, key, and values within each attention
208
+ head. If not specified, it is set to embed_dim // attention_heads.
209
+ It can be useful to set a custom key size if we want to impose the size of
210
+ the query, key and value tensor ( for example, tensors shaped with
211
+ power of 2 are more efficiently handled on TPUs ).
212
+ Note: Parametrizing the model with a custom key size has been done in :
213
+ Brown, Tom, et al. "Language models are few-shot learners."
214
+ Advances in neural information processing systems 33 (2020): 1877-1901.
215
+ embed_dim: Embedding dimension.
216
+ ffn_embed_dim: Feed forward embedding dimension.
217
+ num_layers: Number of attention blocks.
218
+ positional_embedding: Type of positional embedding to use before the first
219
+ attention layer. Options: "learned", "learned_standard", "sinusoidal" or
220
+ None.
221
+ NOTE: "learned" is the positional embedding of ESM, and "learned_standard"
222
+ is a more standard one, used for example in DNAbert.
223
+ lm_head: type of language model head. Options: "simple", "roberta" or None.
224
+ add_bias_kv: Add bias in attention layer.
225
+ add_bias_ffn: Add bias in feed forward network block.
226
+ use_rotary_embedding: Whether to use rotary embeddings (for ESM2). Requires:
227
+ positional_embeddings = None.
228
+ rescaling_factor: Scaling factor to use for rotary embeddings.
229
+ ffn_activation_name: Activation function to be used in FFN block. Supported
230
+ names are "gelu", "relu", "swish".
231
+ use_glu_in_ffn: Whether to use Gated Linear Unit (GLU) in Feed
232
+ Forward Network (FFN) block. To do a swiGLU (gated-swish) put this arg
233
+ to True and use swish as ffn_activation_name.
234
+ Same principle for a gated-relu. To keep the same number of parameters in
235
+ the FFN block, one should multiply by 2/3 the ffn_embed_dim when using GLU.
236
+ See https://arxiv.org/pdf/2002.05202.pdf for more details.
237
+ mask_before_attention: Use mask before attention layers (for ESM1b and ESM2).
238
+ layer_norm_eps: the eps factor in the different layer norms of the model (refer
239
+ to layer norm implementation)
240
+ token_dropout: Token dropout.
241
+ masking_ratio: Masking ratio (used if token dropout is enabled).
242
+ masking_prob: Masking probability (used if token dropout is enabled).
243
+ use_gradient_checkpointing: Whether to use gradient checkpointing (checkpoint
244
+ gradients in the forward pass to reduce the computation in the backward).
245
+ """
246
+
247
+ alphabet_size: int
248
+ pad_token_id: int
249
+ mask_token_id: int
250
+
251
+ max_positions: int = 1024
252
+ embed_scale: float = 1.0
253
+
254
+ # architecture
255
+ emb_layer_norm_before: bool = False
256
+ attention_heads: int = 20
257
+ key_size: Optional[int] = None
258
+ embed_dim: int = 1280
259
+ ffn_embed_dim: int = 5120
260
+ num_layers: int = 24
261
+ positional_embedding: Optional[str] = "learned"
262
+ lm_head: Optional[str] = "simple"
263
+ add_bias_kv: bool = False
264
+ add_bias_ffn: bool = True
265
+ use_rotary_embedding: bool = False
266
+ rescaling_factor: Optional[float] = None
267
+ ffn_activation_name: str = "gelu-no-approx"
268
+ use_glu_in_ffn: bool = False
269
+ mask_before_attention: bool = False
270
+ layer_norm_eps: float = 1e-5
271
+ pre_layer_norm: bool = True
272
+ bias_word_embedding: bool = False
273
+
274
+ # dropout
275
+ token_dropout: bool = False
276
+ masking_ratio: float = 0.1
277
+ masking_prob: float = 0.8
278
+
279
+ # performance
280
+ use_gradient_checkpointing: bool = False
281
+
282
+ # return
283
+ embeddings_layers_to_save: List[int] = field(default_factory=list)
284
+ attention_maps_to_save: List[Tuple[int, int]] = field(default_factory=list)
285
+
286
+ def __post_init__(self) -> None:
287
+ """
288
+ Checks that the given values are compatible.
289
+ """
290
+
291
+ if self.key_size is None:
292
+ if not self.embed_dim % self.attention_heads == 0:
293
+ raise ValueError(
294
+ f"When no key size is provided, the embedding dimension should be "
295
+ f"divisible by the number of heads, however provided embedding "
296
+ f"dimension is {self.embed_dim} and the number of heads is "
297
+ f"{self.attention_heads}."
298
+ )
299
+ self.key_size = self.embed_dim // self.attention_heads
300
+ if self.positional_embedding is not None:
301
+ if type(self.positional_embedding) != str:
302
+ raise TypeError
303
+
304
+ if self.positional_embedding not in [
305
+ "learned",
306
+ "sinusoidal",
307
+ "learned_standard",
308
+ "alibi_dnabert_2",
309
+ ]:
310
+ raise ValueError(
311
+ "The positional_embedding argument should either be None,"
312
+ "`learned`, `sinusoidal`, 'learned_standard' or 'alibi_dnabert_2'."
313
+ )
314
+ if self.lm_head is not None:
315
+ if type(self.lm_head) != str:
316
+ raise TypeError
317
+
318
+ if self.lm_head not in ["simple", "roberta"]:
319
+ raise ValueError(
320
+ "The lm_head argument should either be None,"
321
+ "`simple` or `roberta`."
322
+ )
323
+
324
+ if self.use_rotary_embedding and self.positional_embedding is not None:
325
+ raise ValueError(
326
+ "When using rotary embedding, positional_embedding must be set to none"
327
+ )
328
+
329
+ if self.add_bias_kv and self.use_rotary_embedding:
330
+ raise ValueError(
331
+ "Biases on key and values are not compatible with Rotary embeddings."
332
+ )
333
+
334
+ if self.positional_embedding == "alibi_dnabert_2":
335
+ assert not self.add_bias_kv
336
+
337
+
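# Usage sketch (illustrative; the alphabet size and token ids are hypothetical):
# a rotary-embedding configuration requires positional_embedding=None, which is
# also what TorchESMTransformer below asserts.
def _example_esm_config() -> None:
    config = ESMTransformerConfig(
        alphabet_size=4107,
        pad_token_id=1,
        mask_token_id=2,
        positional_embedding=None,
        lm_head="roberta",
        use_rotary_embedding=True,
    )
    assert config.key_size == config.embed_dim // config.attention_heads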
338
+ @dataclass
339
+ class ChatNTConfig(PretrainedConfig):
340
+ model_type = "ChatNT"
341
+
342
+ def __init__(self, **kwargs): # type: ignore
343
+ self.gpt_config: GptConfig = kwargs.get("gpt_config", GptConfig(32000, 3))
344
+ self.esm_config: ESMTransformerConfig = kwargs.get(
345
+ "esm_config", ESMTransformerConfig(4000, 1, 4)
346
+ )
347
+ self.perceiver_resampler_config: PerceiverResamplerConfig = kwargs.get(
348
+ "perceiver_resampler_config", PerceiverResamplerConfig()
349
+ )
350
+ self.seq_token_id: int = kwargs.get("seq_token_id", 32000)
351
+ self.bio_pad_token_id: int = kwargs.get("bio_pad_token_id", 1)
352
+ self.english_pad_token_id: int = kwargs.get("english_pad_token_id", 2)
353
+ super().__init__(**kwargs)
354
+
355
+ def to_dict(self): # type: ignore
356
+ output = super().to_dict()
357
+
358
+ def serialize(obj): # type: ignore
359
+ return obj.to_dict() if hasattr(obj, "to_dict") else vars(obj)
360
+
361
+ output["gpt_config"] = serialize(self.gpt_config) # type: ignore
362
+ output["esm_config"] = serialize(self.esm_config) # type: ignore
363
+ output["perceiver_resampler_config"] = serialize( # type: ignore
364
+ self.perceiver_resampler_config
365
+ )
366
+ return output
367
+
368
+
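# Usage sketch (illustrative): the default sub-configs round-trip through to_dict(),
# with each nested dataclass serialised to a plain dictionary.
def _example_chatnt_config() -> None:
    config = ChatNTConfig()
    as_dict = config.to_dict()
    assert as_dict["seq_token_id"] == 32000
    assert as_dict["gpt_config"]["vocab_size"] == 32000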
369
+ class TorchBioBrainDecoder(nn.Module):
370
+ def __init__(
371
+ self,
372
+ gpt_config: GptConfig,
373
+ seq_token_id: int,
374
+ ):
375
+ """
376
+ Initializes the BioBrain decoder, using a GPT model for text generation with
377
+ bio embeddings.
378
+
379
+ Args:
380
+ gpt_config: Configuration for the GPT model
381
+ seq_token_id: Index of the SEQ token
382
+ """
383
+ super(TorchBioBrainDecoder, self).__init__()
384
+ self.gpt_config = gpt_config
385
+ self.seq_token_id = seq_token_id
386
+
387
+ # Initialize the GPT model (assumed you have it already in PyTorch)
388
+ self.gpt_model = TorchGptDecoder(self.gpt_config)
389
+
390
+ def forward(
391
+ self, english_token_ids: torch.Tensor, projected_bio_embeddings: torch.Tensor
392
+ ) -> torch.Tensor:
393
+ """
394
+ Forward pass through the model.
395
+
396
+ Args:
397
+ english_token_ids: Tensor of English token IDs with shape
398
+ (batch_size, num_english_tokens).
399
+ projected_bio_embeddings: Optional tensor of bio embeddings with shape
400
+ (batch_size, num_bio_sequences, ?, embed_dim).
401
+
402
+ Returns:
403
+ torch.Tensor: The logits from the GPT model,
404
+ shaped (batch_size, num_english_tokens, vocab_size).
405
+ """
406
+
407
+ # Compute English token embeddings
408
+ tokens_embeddings = self.gpt_model.token_embed(english_token_ids)
409
+
410
+ if projected_bio_embeddings is not None:
411
+ (
412
+ batch_size,
413
+ num_bio_sequences,
414
+ _,
415
+ bio_embed_dim,
416
+ ) = projected_bio_embeddings.shape
417
+
418
+ # Insert the bio embeddings at the SEQ token positions
419
+ processed_tokens_ids = english_token_ids.clone()
420
+ for bio_seq_num in range(num_bio_sequences):
421
+ tokens_embeddings, processed_tokens_ids = self.insert_embeddings(
422
+ processed_tokens_ids,
423
+ tokens_embeddings,
424
+ projected_bio_embeddings[:, bio_seq_num, :, :],
425
+ bio_seq_num=bio_seq_num,
426
+ )
427
+
428
+ # Regular GPT pass through
429
+ embeddings = self.gpt_model.apply_transformer_layers(tokens_embeddings)
430
+ embeddings = self.gpt_model.final_norm(embeddings)
431
+
432
+ # Compute logits
433
+ logits = self.gpt_model.lm_head(embeddings)
434
+
435
+ if projected_bio_embeddings is not None:
436
+ # Clean logits sequentially
437
+ processed_tokens_ids = english_token_ids.clone()
438
+ resampled_length = projected_bio_embeddings.shape[-2]
439
+ for _ in range(num_bio_sequences):
440
+ logits, processed_tokens_ids = self.cleanup_logits(
441
+ tokens=processed_tokens_ids,
442
+ logits=logits,
443
+ resampled_length=resampled_length,
444
+ )
445
+
446
+ return logits
447
+
448
+ def insert_embeddings(
449
+ self,
450
+ tokens: torch.Tensor,
451
+ input_embeddings: torch.Tensor,
452
+ resampled_embeddings: torch.Tensor,
453
+ bio_seq_num: int,
454
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
455
+ """
456
+ Inserts resampled embeddings in input_embeddings, starting at the SEQ token
457
+
458
+ Args:
459
+ tokens (torch.Tensor): Shape (batch_size, num_tokens)
460
+ input_embeddings (torch.Tensor): Shape (batch_size, num_tokens, embed_dim)
461
+ resampled_embeddings (torch.Tensor):
462
+ Shape (batch_size, num_bio_sequences, bio_sequence_length, embed_dim)
463
+
464
+ Returns:
465
+ Tuple[torch.Tensor, torch.Tensor]:
466
+ - input_embeddings with resampled_embeddings inserted at the SEQ token
467
+ - tokens with the SEQ token set to -1
468
+ """
469
+
470
+ def _insert(
471
+ tokens_1d: torch.Tensor,
472
+ input_embeddings_1d: torch.Tensor,
473
+ resampled_embeddings_1d: torch.Tensor,
474
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
475
+ """
476
+ Args:
477
+ tokens (torch.Tensor): Shape (num_tokens,)
478
+ input_embeddings (torch.Tensor): Shape (num_tokens, embed_dim,)
479
+ resampled_embeddings (torch.Tensor):
480
+ Shape (bio_sequence_length, embed_dim,)
481
+ """
482
+ indices = torch.where(tokens_1d == self.seq_token_id)[0]
483
+ if indices.numel() > 0:
484
+ idx = indices[0].item()
485
+ insertion_pos = idx + resampled_embeddings_1d.shape[-2] * bio_seq_num
486
+ x = torch.cat(
487
+ [
488
+ input_embeddings_1d[:insertion_pos, :],
489
+ resampled_embeddings_1d,
490
+ input_embeddings_1d[insertion_pos:, :],
491
+ ],
492
+ dim=0,
493
+ )[: tokens_1d.shape[0] + 1, :]
494
+ x = torch.roll(torch.roll(x, shifts=-idx, dims=0), shifts=idx, dims=0)[
495
+ :-1, :
496
+ ]
497
+ tokens_1d[idx] = -1
498
+ return x, tokens_1d
499
+ else:
500
+ return (
501
+ input_embeddings_1d,
502
+ tokens_1d,
503
+ ) # Return unchanged if seq_token_id is not found
504
+
505
+ tokens_acc = []
506
+ embeddings_acc = []
507
+
508
+ for i in range(tokens.shape[0]):
509
+ embeddings_out, tokens_out = _insert(
510
+ tokens[i].clone(),
511
+ input_embeddings[i].clone(),
512
+ resampled_embeddings[i].clone(),
513
+ )
514
+ tokens_acc.append(tokens_out)
515
+ embeddings_acc.append(embeddings_out)
516
+ tokens_acc = torch.stack(tokens_acc)
517
+ embeddings_acc = torch.stack(embeddings_acc)
518
+
519
+ return embeddings_acc, tokens_acc
520
+
521
+ def cleanup_logits(
522
+ self, tokens: torch.Tensor, logits: torch.Tensor, resampled_length: int
523
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
524
+ """
525
+ Removes the logits corresponding to the unused embeddings.
526
+
527
+ Args:
528
+ tokens: Input english tokens.
529
+ logits: Input logits.
530
+
531
+ Returns:
532
+ Cleaned logits (with the trailing resampled_length positions zero-filled) and the tokens with the processed SEQ token set to -1.
533
+ """
534
+
535
+ def _clean(
536
+ token: torch.Tensor, logit: torch.Tensor
537
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
538
+ indices = torch.where(token == self.seq_token_id)[0]
539
+ if indices.numel() > 0:
540
+ idx = indices[0].item()
541
+
542
+ mask_idx = (
543
+ torch.arange(logit.shape[0] - resampled_length, device=logit.device)
544
+ > idx
545
+ )
546
+ mask_idx = mask_idx.unsqueeze(1)
547
+
548
+ # Remove values corresponding to bio tokens
549
+ logit = (
550
+ logit[:-resampled_length] * (~mask_idx)
551
+ + logit[resampled_length:] * mask_idx
552
+ )
553
+
554
+ # Append zeros at the end
555
+ logit = torch.cat(
556
+ (
557
+ logit,
558
+ torch.zeros(
559
+ (resampled_length, logit.shape[1]),
560
+ dtype=logit.dtype,
561
+ device=logit.device,
562
+ ),
563
+ )
564
+ )
565
+
566
+ # Update token
567
+ token[idx] = -1
568
+
569
+ return logit, token
570
+
571
+ else:
572
+ return logit, token
573
+
574
+ tokens_acc = []
575
+ logits_acc = []
576
+
577
+ for i in range(tokens.shape[0]):
578
+ logits_out, tokens_out = _clean(tokens[i].clone(), logits[i].clone())
579
+ tokens_acc.append(tokens_out)
580
+ logits_acc.append(logits_out)
581
+ tokens_acc = torch.stack(tokens_acc)
582
+ logits_acc = torch.stack(logits_acc)
583
+
584
+ return logits_acc, tokens_acc
585
+
586
+
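# Shape sketch (illustrative; the config, token ids and sizes are hypothetical):
# a single bio sequence resampled to 4 positions is spliced in at the SEQ token,
# and the returned logits keep the shape of the English prompt.
def _example_biobrain_decoder() -> None:
    config = GptConfig(vocab_size=100, eos_token_id=3)
    decoder = TorchBioBrainDecoder(gpt_config=config, seq_token_id=5)
    english_token_ids = torch.tensor([[10, 11, 5, 12, 13, 3]])  # contains the SEQ token (5)
    projected_bio_embeddings = torch.randn(1, 1, 4, config.embed_dim)
    logits = decoder(english_token_ids, projected_bio_embeddings)
    assert logits.shape == (1, 6, config.vocab_size)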
587
+ class TorchMultiOmicsModel(PreTrainedModel):
588
+ config_class = ChatNTConfig
589
+
590
+ def __init__(self, config: ChatNTConfig) -> None:
591
+ if isinstance(config, dict):
592
+ # If config is a dictionary instead of ChatNTConfig (which can happen
593
+ # depending how the config was saved), we convert it to the config
594
+ config["gpt_config"]["rope_config"] = RotaryEmbeddingConfig(
595
+ **config["gpt_config"]["rope_config"]
596
+ )
597
+ config["gpt_config"] = GptConfig(**config["gpt_config"])
598
+ config["esm_config"] = ESMTransformerConfig(**config["esm_config"])
599
+ config["perceiver_resampler_config"] = PerceiverResamplerConfig(
600
+ **config["perceiver_resampler_config"]
601
+ )
602
+ config = ChatNTConfig(**config) # type: ignore
603
+
604
+ else:
605
+ if isinstance(config.gpt_config, dict):
606
+ config.gpt_config["rope_config"] = RotaryEmbeddingConfig(
607
+ **config.gpt_config["rope_config"]
608
+ )
609
+ config.gpt_config = GptConfig(**config.gpt_config)
610
+
611
+ if isinstance(config.esm_config, dict):
612
+ config.esm_config = ESMTransformerConfig(**config.esm_config)
613
+
614
+ if isinstance(config.perceiver_resampler_config, dict):
615
+ config.perceiver_resampler_config = PerceiverResamplerConfig(
616
+ **config.perceiver_resampler_config
617
+ )
618
+
619
+ super().__init__(config=config)
620
+ self.gpt_config = config.gpt_config
621
+ self.esm_config = config.esm_config
622
+ self.perceiver_resampler_config = config.perceiver_resampler_config
623
+ self.seq_token_id = config.seq_token_id
624
+ self.bio_pad_token_id = config.bio_pad_token_id
625
+ self.english_pad_token_id = config.english_pad_token_id
626
+
627
+ # Correct seq_token_id
628
+ self.seq_token_id -= 1
629
+
630
+ self.biobrain_encoder = TorchBioBrainEncoder(esm_config=self.esm_config)
631
+ self.biobrain_decoder = TorchBioBrainDecoder(
632
+ gpt_config=self.gpt_config, seq_token_id=self.seq_token_id
633
+ )
634
+ self.projection_model = TorchMultiModalPerceiverResamplerProjection(
635
+ perceiver_resampler_config=self.perceiver_resampler_config,
636
+ input_embed_dim=self.esm_config.embed_dim,
637
+ embed_dim=self.gpt_config.embed_dim,
638
+ english_vocab_size=self.gpt_config.vocab_size,
639
+ bio_pad_token_id=self.bio_pad_token_id,
640
+ english_pad_token_id=self.english_pad_token_id,
641
+ )
642
+
643
+ def forward(
644
+ self,
645
+ multi_omics_tokens_ids: tuple[torch.Tensor, torch.Tensor],
646
+ projection_english_tokens_ids: torch.Tensor,
647
+ projected_bio_embeddings: torch.Tensor = None,
648
+ ) -> dict[str, torch.Tensor]:
649
+ """
650
+
651
+ Args:
652
+ multi_omics_tokens_ids (Tuple[torch.Tensor, torch.Tensor]):
653
+ english_tokens_ids: Represents the prompt tokens (english tokens)
654
+ Shape (batch_size, num_english_tokens)
655
+
656
+ bio_tokens_ids: Represents the bio sequences tokens
657
+ Shape (batch_size, num_bio_sequences, num_bio_tokens)
658
+
659
+ projection_english_tokens_ids (torch.Tensor):
660
+ Shape (batch_size, num_english_tokens)
661
+
662
+ projected_bio_embeddings (projected_bio_embeddings, optional):
663
+ Shape (batch_size, num_bio_sequences, ?, embed_dim).
664
+ Defaults to None.
665
+
666
+ Returns:
667
+ dict[str, torch.Tensor] containing:
668
+ - logits:
669
+ Shape (batch_size, num_tokens, vocab_size)
670
+
671
+ - projected_bio_embeddings:
672
+ Shape (batch_size, num_bio_sequences, ?, embed_dim)
673
+ """
674
+ english_token_ids, bio_token_ids = multi_omics_tokens_ids
675
+ english_token_ids = english_token_ids.clone()
676
+ bio_token_ids = bio_token_ids.clone()
677
+ projection_english_tokens_ids = projection_english_tokens_ids.clone()
678
+ if projected_bio_embeddings is not None:
679
+ projected_bio_embeddings = projected_bio_embeddings.clone()
680
+
681
+ # Remap out-of-range english token ids.
682
+ # The default vocab size (32000) does not account for the extra seq_token_id
683
+ # (=32000) that was added to the vocabulary, so seq_token_id is remapped to
684
+ # 31999, and the original token 31999 (the unknown token) is remapped to 0.
685
+ # This is a workaround that avoids having to change the vocab size in the
686
+ # config.
687
+ vocab_size = self.gpt_config.vocab_size
688
+ # Replace vocab
689
+ english_token_ids[english_token_ids == vocab_size - 1] = 0
690
+ projection_english_tokens_ids[
691
+ projection_english_tokens_ids == vocab_size - 1
692
+ ] = 0
693
+ english_token_ids[english_token_ids == vocab_size] = vocab_size - 1
694
+ projection_english_tokens_ids[projection_english_tokens_ids == vocab_size] = (
695
+ vocab_size - 1
696
+ )
697
+
698
+ if bio_token_ids is None:
699
+ projected_bio_embeddings = None
700
+ else:
701
+ num_bio_sequences = bio_token_ids.shape[1]
702
+
703
+ if projected_bio_embeddings is None:
704
+ # Compute bio sequences embeddings
705
+ bio_embeddings_list = [
706
+ self.biobrain_encoder(bio_token_ids=bio_token_ids[:, bio_seq_num])
707
+ for bio_seq_num in range(num_bio_sequences)
708
+ ]
709
+
710
+ # Project these embeddings
711
+ projected_bio_embeddings = [
712
+ self.projection_model(
713
+ bio_token_ids=bio_token_ids[:, bio_seq_num],
714
+ bio_embeddings=bio_embeddings,
715
+ english_token_ids=projection_english_tokens_ids,
716
+ )
717
+ for bio_seq_num, bio_embeddings in enumerate(bio_embeddings_list)
718
+ ]
719
+ projected_bio_embeddings = torch.stack(projected_bio_embeddings, dim=1)
720
+
721
+ # decode
722
+ logits = self.biobrain_decoder(
723
+ english_token_ids=english_token_ids,
724
+ projected_bio_embeddings=projected_bio_embeddings,
725
+ )
726
+
727
+ outs = {"logits": logits, "projected_bio_embeddings": projected_bio_embeddings}
728
+
729
+ return outs
730
+
731
+
732
+ class TorchRotaryEmbedding(torch.nn.Module):
733
+ def __init__(self, config: RotaryEmbeddingConfig):
734
+ super().__init__()
735
+
736
+ self.max_seq_len = config.max_seq_len
737
+ self.dim = config.dim
738
+ self.theta = config.theta
739
+ self.sincos_cache = self._create_sinusoidal_positions()
740
+
741
+ def _create_sinusoidal_positions(self) -> torch.Tensor:
742
+ """
743
+ Create the sines and cosines for the RoPE.
744
+
745
+ Returns:
746
+ Sinusoidal positions of shape (self.max_seq_len, self.dim).
747
+ """
748
+ # Create the inverse frequency based on theta and dim
749
+ inv_freq = 1.0 / (
750
+ self.theta ** (torch.arange(0, self.dim, 2).float() / self.dim)
751
+ )
752
+
753
+ # Compute sinusoidal input using the broadcasting
754
+ sinusoid_inp = torch.einsum(
755
+ "i,j->ij", torch.arange(self.max_seq_len).float(), inv_freq
756
+ )
757
+
758
+ # Apply sin and cos to the sinusoidal input
759
+ sin, cos = sinusoid_inp.sin(), sinusoid_inp.cos()
760
+
761
+ # Allocate a tensor for the final sin-cos values
762
+ sincos = torch.zeros((self.max_seq_len, self.dim), dtype=torch.float32)
763
+
764
+ # Fill the sincos tensor with sin and cos values
765
+ sentinel = self.dim // 2 + self.dim % 2
766
+ sincos[:, :sentinel] = sin
767
+ sincos[:, sentinel:] = cos
768
+
769
+ return sincos
770
+
771
+ def _rotate_every_two(self, x: torch.Tensor) -> torch.Tensor:
772
+ """
773
+ Prepare a tensor to apply the RoPE mechanism.
774
+
775
+ Args:
776
+ x: Tensor of shape (batch_size, seq_len, num_heads, head_dim),
777
+ typically this is the key or query tensor.
778
+
779
+ Returns:
780
+ A tensor where each consecutive pair (x_{2i}, x_{2i+1}) is mapped to (-x_{2i+1}, x_{2i}).
781
+ Tensor of shape (batch_size, seq_len, num_heads, head_dim).
782
+ """
783
+ # Split the tensor into two halves (odd and even indexed dimensions)
784
+ rotate_half = torch.stack((-x[..., 1::2], x[..., ::2]), dim=-1)
785
+
786
+ # Reshape the tensor to the original shape
787
+ rotate_half = rotate_half.view(rotate_half.shape[:-2] + (-1,))
788
+ return rotate_half
789
+
790
+ def _apply_rotary_pos_emb(
791
+ self, x: torch.Tensor, sincos: torch.Tensor
792
+ ) -> torch.Tensor:
793
+ """
794
+ Applies rotary embeddings to x.
795
+
796
+ Args:
797
+ x: Tensor of shape (batch_size, seq_len, num_heads, head_dim),
798
+ typically this is the key or query tensor.
799
+ sincos: Tuple of sine and cosine tensors for position encoding.
800
+
801
+ Returns:
802
+ RoPE embeddings tensor.
803
+ """
804
+ sin_pos, cos_pos = sincos
805
+
806
+ # Reshape the sin and cos tensors for broadcasting
807
+ sin_pos = torch.repeat_interleave(sin_pos.unsqueeze(2), repeats=2, dim=-1)
808
+ cos_pos = torch.repeat_interleave(cos_pos.unsqueeze(2), repeats=2, dim=-1)
809
+
810
+ # Apply the rotary embedding mechanism
811
+ return (x * cos_pos) + (self._rotate_every_two(x) * sin_pos)
812
+
813
+ def __call__(
814
+ self, k: torch.Tensor, q: torch.Tensor, positions: Optional[torch.Tensor] = None
815
+ ) -> tuple[torch.Tensor, torch.Tensor]:
816
+ """
817
+ Applies rotary embeddings to k and q.
818
+
819
+ Args:
820
+ k: key tensor of shape (batch_size, seq_len, num_heads, head_dim),
821
+ q: query tensor of shape (batch_size, seq_len, num_heads, head_dim),
822
+ positions: optional positions offset useful when caching,
823
+
824
+ Returns:
825
+ RoPE embeddings for the keys and queries.
826
+ """
827
+ batch_size, seq_len, num_heads, head_dim = k.shape
828
+
829
+ # Generate position ids
830
+ position_ids = (
831
+ torch.arange(seq_len, device=k.device).unsqueeze(0).expand(batch_size, -1)
832
+ )
833
+
834
+ if positions is not None:
835
+ position_ids += positions
836
+
837
+ # Retrieve sincos values using the position_ids
838
+ sincos = self.sincos_cache[position_ids]
839
+
840
+ # Split sincos into sin_pos and cos_pos
841
+ sincos = torch.chunk(sincos, 2, dim=-1)
842
+
843
+ # Apply rotary position embedding to key (k) and query (q)
844
+ k_rot = self._apply_rotary_pos_emb(k[..., : self.dim], sincos)
845
+ k_pass = k[..., self.dim :]
846
+
847
+ q_rot = self._apply_rotary_pos_emb(q[..., : self.dim], sincos)
848
+ q_pass = q[..., self.dim :]
849
+
850
+ # Concatenate the rotated and non-rotated parts
851
+ keys = torch.cat([k_rot, k_pass], dim=-1)
852
+ queries = torch.cat([q_rot, q_pass], dim=-1)
853
+
854
+ return keys, queries
855
+
856
+
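# Shape sketch (illustrative; all sizes are hypothetical): rotary embeddings are
# applied to the first `dim` channels of each head and leave the shapes unchanged.
def _example_rotary_embedding() -> None:
    rope = TorchRotaryEmbedding(RotaryEmbeddingConfig(max_seq_len=32, dim=8, theta=10000.0))
    k = torch.randn(1, 10, 2, 8)  # (batch, seq_len, num_heads, head_dim)
    q = torch.randn(1, 10, 2, 8)
    k_rot, q_rot = rope(k, q)
    assert k_rot.shape == k.shape and q_rot.shape == q.shape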
857
+ class TorchGptGroupedQueryAttention(nn.Module):
858
+ def __init__(
859
+ self,
860
+ embed_dim: int,
861
+ num_heads: int,
862
+ rope_config: RotaryEmbeddingConfig,
863
+ num_kv_heads: int = None, # type: ignore
864
+ head_dim: int = None, # type: ignore
865
+ add_bias_attn: bool = False, # type: ignore
866
+ ) -> None:
867
+ super().__init__()
868
+ self.num_heads = num_heads
869
+ self.num_kv_heads = num_kv_heads or num_heads
870
+ self.embed_dim = embed_dim
871
+ self.head_dim = head_dim or (embed_dim // num_heads)
872
+ self.add_bias_attn = add_bias_attn
873
+ self.rope = TorchRotaryEmbedding(rope_config)
874
+
875
+ self.query_linear = nn.Linear(
876
+ embed_dim, self.num_heads * self.head_dim, bias=add_bias_attn
877
+ )
878
+ self.key_linear = nn.Linear(
879
+ embed_dim, self.num_kv_heads * self.head_dim, bias=add_bias_attn
880
+ )
881
+ self.value_linear = nn.Linear(
882
+ embed_dim, self.num_kv_heads * self.head_dim, bias=add_bias_attn
883
+ )
884
+ self.out_linear = nn.Linear(
885
+ self.num_heads * self.head_dim, embed_dim, bias=add_bias_attn
886
+ )
887
+
888
+ def forward(
889
+ self,
890
+ query_inputs: torch.Tensor,
891
+ key_inputs: torch.Tensor,
892
+ value_inputs: torch.Tensor,
893
+ attention_mask: torch.Tensor = None,
894
+ ) -> torch.Tensor:
895
+ batch_size, seq_len, _ = query_inputs.shape
896
+
897
+ queries = self.query_linear(query_inputs).view( # noqa
898
+ batch_size, seq_len, self.num_heads, self.head_dim
899
+ )
900
+ keys = self.key_linear(key_inputs).view( # noqa
901
+ batch_size, seq_len, self.num_kv_heads, self.head_dim
902
+ )
903
+ values = self.value_linear(value_inputs).view( # noqa
904
+ batch_size, seq_len, self.num_kv_heads, self.head_dim
905
+ )
906
+
907
+ keys, queries = self.rope(keys, queries)
908
+
909
+ n_rep = self.num_heads // self.num_kv_heads
910
+ keys = keys.repeat_interleave(n_rep, dim=2)
911
+ values = values.repeat_interleave(n_rep, dim=2)
912
+
913
+ attention_logits = torch.einsum("bthd,bThd->bhtT", queries, keys) / (
914
+ self.head_dim**0.5
915
+ )
916
+
917
+ if attention_mask is not None:
918
+ attention_logits = attention_logits.masked_fill(
919
+ attention_mask == 0, float("-inf")
920
+ )
921
+
922
+ attention_weights = nn.functional.softmax(attention_logits, dim=-1)
923
+
924
+ values = torch.einsum("bhtT,bThd->bthd", attention_weights, values)
925
+ values = values.contiguous().view(batch_size, seq_len, -1)
926
+
927
+ return self.out_linear(values)
928
+
929
+
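# Shape sketch (illustrative; all sizes are hypothetical): with num_kv_heads smaller
# than num_heads, key/value heads are shared across query heads (grouped-query attention).
def _example_grouped_query_attention() -> None:
    attn = TorchGptGroupedQueryAttention(
        embed_dim=32,
        num_heads=4,
        num_kv_heads=2,
        rope_config=RotaryEmbeddingConfig(max_seq_len=64, dim=8, theta=10000.0),
    )
    x = torch.randn(1, 6, 32)
    out = attn(x, x, x)
    assert out.shape == (1, 6, 32)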
930
+ class TorchGptDecoder(nn.Module):
931
+ def __init__(self, config: GptConfig, name: Optional[str] = None):
932
+ super().__init__()
933
+ self.config = config
934
+
935
+ self.token_embed = nn.Embedding(config.vocab_size, config.embed_dim)
936
+
937
+ if config.norm_type == "layer_norm":
938
+ self.final_norm = nn.LayerNorm(config.embed_dim)
939
+ elif config.norm_type == "RMS_norm":
940
+ self.final_norm = TorchRMSNorm(config.embed_dim, eps=config.rms_norm_eps)
941
+ else:
942
+ raise ValueError(f"unrecognized norm_type in config {config.norm_type}")
943
+
944
+ self.layers = nn.ModuleList(
945
+ [
946
+ TorchGptDecoderLayer(
947
+ embed_dim=config.embed_dim,
948
+ ffn_embed_dim=config.ffn_embed_dim,
949
+ num_heads=config.num_heads,
950
+ rope_config=config.rope_config,
951
+ norm_type=config.norm_type,
952
+ parallel_attention_ff=config.parallel_attention_ff,
953
+ add_bias_ffn=config.add_bias_ffn,
954
+ ffn_activation_name=config.ffn_activation_name,
955
+ use_glu_in_ffn=config.use_glu_in_ffn,
956
+ num_kv_heads=config.num_kv_heads, # type: ignore
957
+ add_bias_attn=config.add_bias_attn,
958
+ rms_norm_eps=config.rms_norm_eps,
959
+ )
960
+ for _ in range(config.num_layers)
961
+ ]
962
+ )
963
+
964
+ self.lm_head = TorchSimpleLMHead(
965
+ embed_dim=config.embed_dim,
966
+ alphabet_size=config.vocab_size,
967
+ add_bias_lm_head=config.add_bias_lm_head,
968
+ )
969
+
970
+ def apply_transformer_layers(
971
+ self, embeddings: torch.Tensor, attention_mask: torch.Tensor = None
972
+ ) -> torch.Tensor:
973
+ if attention_mask is None:
974
+ attention_mask = build_causal_attention_mask(1, embeddings.shape[1])
975
+ for layer in self.layers:
976
+ embeddings = layer(embeddings, attention_mask)
977
+
978
+ return embeddings
979
+
980
+ def forward(
981
+ self, token_ids: torch.Tensor, attention_mask: torch.Tensor = None
982
+ ) -> dict[str, torch.Tensor]:
983
+ if attention_mask is None:
984
+ attention_mask = build_causal_attention_mask(1, token_ids.shape[1])
985
+
986
+ tokens_embeddings = self.token_embed(token_ids)
987
+
988
+ after_transformer_embeddings = self.apply_transformer_layers(
989
+ tokens_embeddings, attention_mask=attention_mask
990
+ )
991
+
992
+ embeddings = self.final_norm(after_transformer_embeddings)
993
+ logits = self.lm_head(embeddings)
994
+ return {"embeddings": embeddings, "logits": logits}
995
+
996
+
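# Shape sketch (illustrative; the config and batch are hypothetical): a full decoder
# pass with the default causal mask built internally.
def _example_gpt_decoder() -> None:
    decoder = TorchGptDecoder(GptConfig(vocab_size=100, eos_token_id=3))
    token_ids = torch.randint(0, 100, (2, 12))
    out = decoder(token_ids)
    assert out["logits"].shape == (2, 12, 100)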
997
+ class TorchSimpleLMHead(nn.Module):
998
+ def __init__(
999
+ self, embed_dim: int, alphabet_size: int, add_bias_lm_head: bool = True
1000
+ ) -> None:
1001
+ super().__init__()
1002
+ self.fc = nn.Linear(embed_dim, alphabet_size, bias=add_bias_lm_head)
1003
+
1004
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
1005
+ return self.fc(x)
1006
+
1007
+
1008
+ class TorchGptDecoderLayer(nn.Module):
1009
+ def __init__(
1010
+ self,
1011
+ embed_dim: int,
1012
+ ffn_embed_dim: int,
1013
+ num_heads: int,
1014
+ rope_config: RotaryEmbeddingConfig,
1015
+ norm_type: str,
1016
+ parallel_attention_ff: bool,
1017
+ add_bias_ffn: bool,
1018
+ ffn_activation_name: str,
1019
+ use_glu_in_ffn: bool,
1020
+ num_kv_heads: int,
1021
+ add_bias_attn: bool,
1022
+ rms_norm_eps: float = 1e-6,
1023
+ ) -> None:
1024
+ super().__init__()
1025
+ self.num_heads = num_heads
1026
+ self.parallel_attention_ff = parallel_attention_ff
1027
+ self.use_glu_in_ffn = use_glu_in_ffn
1028
+
1029
+ # Self-Attention layer
1030
+ self.self_attn = TorchGptGroupedQueryAttention(
1031
+ embed_dim=embed_dim,
1032
+ num_heads=num_heads,
1033
+ num_kv_heads=num_kv_heads,
1034
+ rope_config=rope_config,
1035
+ add_bias_attn=add_bias_attn,
1036
+ )
1037
+
1038
+ # Normalization layers
1039
+ if norm_type == "layer_norm":
1040
+ self.attn_norm = nn.LayerNorm(embed_dim)
1041
+ if not self.parallel_attention_ff:
1042
+ self.ffn_norm = nn.LayerNorm(embed_dim)
1043
+ elif norm_type == "RMS_norm":
1044
+ self.attn_norm = TorchRMSNorm(embed_dim, eps=rms_norm_eps)
1045
+ if not self.parallel_attention_ff:
1046
+ self.ffn_norm = TorchRMSNorm(embed_dim, eps=rms_norm_eps)
1047
+ else:
1048
+ raise ValueError(f"unrecognized norm_type: {norm_type}")
1049
+
1050
+ # Feedforward network
1051
+ self.activation = get_activation_fn(ffn_activation_name)
1052
+ ffn_hidden_dim = ffn_embed_dim * (2 if use_glu_in_ffn else 1)
1053
+ self.fc1 = nn.Linear(embed_dim, ffn_hidden_dim, bias=add_bias_ffn)
1054
+ self.fc2 = nn.Linear(ffn_embed_dim, embed_dim, bias=add_bias_ffn)
1055
+
1056
+ def forward(
1057
+ self, embeddings: torch.Tensor, attention_mask: torch.Tensor
1058
+ ) -> torch.Tensor:
1059
+ residuals = embeddings
1060
+
1061
+ if self.parallel_attention_ff:
1062
+ # Parallel Attention + MLP
1063
+ embeddings_normed = self.attn_norm(embeddings)
1064
+
1065
+ attn_output = self.self_attn(
1066
+ embeddings_normed,
1067
+ embeddings_normed,
1068
+ embeddings_normed,
1069
+ attention_mask=attention_mask,
1070
+ )
1071
+ ffn_output = self.mlp(embeddings_normed) # type: ignore
1072
+
1073
+ return residuals + attn_output + ffn_output
1074
+ else:
1075
+ # Sequential Attention + MLP
1076
+ normed_embeddings = self.attn_norm(embeddings)
1077
+
1078
+ attn_output = embeddings + self.self_attn(
1079
+ normed_embeddings,
1080
+ normed_embeddings,
1081
+ normed_embeddings,
1082
+ attention_mask=attention_mask,
1083
+ )
1084
+
1085
+ normed_embeddings2 = self.ffn_norm(attn_output)
1086
+ ffn_output = self.mlp(normed_embeddings2) # type: ignore
1087
+ return attn_output + ffn_output # Residual connection
1088
+
1089
+ def mlp(self, x: torch.Tensor) -> torch.Tensor:
1090
+ """Applies the feedforward network (MLP) with optional GLU."""
1091
+ ffn_output = self.fc1(x)
1092
+
1093
+ if self.use_glu_in_ffn:
1094
+ ffn_output1, ffn_output2 = ffn_output.chunk(2, dim=-1)
1095
+ ffn_output = self.activation(ffn_output1) * ffn_output2
1096
+ else:
1097
+ ffn_output = self.activation(ffn_output)
1098
+
1099
+ return self.fc2(ffn_output)
1100
+
1101
+
1102
+ class TorchRMSNorm(nn.Module):
1103
+ def __init__(self, dim: int, eps: float = 1e-6) -> None:
1104
+ super().__init__()
1105
+ self.eps = eps
1106
+ self.scale = nn.Parameter(torch.ones(dim))
1107
+
1108
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
1109
+ return (
1110
+ x
1111
+ * self.scale
1112
+ / torch.sqrt(torch.mean(x**2, dim=-1, keepdim=True) + self.eps)
1113
+ )
1114
+
1115
+
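# Shape sketch (illustrative): RMS normalisation rescales each position to unit
# root-mean-square before applying the learned scale parameter.
def _example_rms_norm() -> None:
    norm = TorchRMSNorm(dim=16)
    x = torch.randn(2, 5, 16)
    assert norm(x).shape == x.shape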
1116
+ def get_activation_fn(activation_name: str): # type: ignore
1117
+ activations = {
1118
+ "gelu": nn.functional.gelu,
1119
+ "relu": nn.functional.relu,
1120
+ "swish": nn.functional.silu,
1121
+ "silu": nn.functional.silu,
1122
+ }
1123
+ return activations.get(activation_name, nn.functional.relu)
1124
+
1125
+
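# Usage sketch (illustrative): "swish" resolves to SiLU, and unknown names fall back
# to ReLU.
def _example_activation_lookup() -> None:
    assert get_activation_fn("swish") is nn.functional.silu
    assert get_activation_fn("unknown") is nn.functional.relu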
1126
+ def build_causal_attention_mask(batch_size: int, seq_len: int) -> torch.Tensor:
1127
+ """
1128
+ Builds a batch of causal masks of shape (batch_size, 1, seq_len, seq_len) to feed
1129
+ to an attention layer.
1130
+
1131
+ Args:
1132
+ batch_size: Batch size.
1133
+ seq_len: Length of the sequences.
1134
+
1135
+ Returns:
1136
+ Batch of causal masks.
1137
+ """
1138
+ mask = torch.ones((batch_size, 1, seq_len, seq_len))
1139
+ causal_mask = torch.tril(mask)
1140
+ return causal_mask
1141
+
1142
+
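# Usage sketch (illustrative): the mask is lower-triangular, so position t can only
# attend to positions <= t.
def _example_causal_mask() -> None:
    mask = build_causal_attention_mask(batch_size=1, seq_len=4)
    assert mask.shape == (1, 1, 4, 4)
    assert mask[0, 0, 0, 1] == 0 and mask[0, 0, 3, 0] == 1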
1143
+ @dataclass
1144
+ class RotaryEmbeddingConfigBis:
1145
+ """
1146
+ Parameters to initialize the RotaryEmbedding layer. The rescaling factor allows
1147
+ to adapt the rotary embeddings to larger lengths than what was used for training.
1148
+ One such strategy is presented in the YaRN paper: https://arxiv.org/pdf/2309.00071.pdf. # noqa
1149
+ Args:
+ rescaling_factor: Optional factor used to rescale the rotary embedding
+ frequencies so that sequences longer than those seen during training
+ can be handled.
1150
+ """
1151
+
1152
+ rescaling_factor: Optional[float]
1153
+
1154
+
1155
+ class RotaryEmbeddingBis(torch.nn.Module):
1156
+ """
1157
+ Rotary position embeddings based on those in
1158
+ [RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer).
1159
+ Query and keys are transformed by rotation
1160
+ matrices which depend on their relative positions.
1161
+ """
1162
+
1163
+ def __init__(self, dim: int, rotary_embedding_config: RotaryEmbeddingConfigBis):
1164
+ super().__init__()
1165
+
1166
+ # Extract argument from the config
1167
+ self.rescaling_factor = rotary_embedding_config.rescaling_factor
1168
+ self.upper_freq = 10000
1169
+ self.dim = dim
1170
+
1171
+ self._seq_len_cached = None
1172
+ self._cos_cached = None
1173
+ self._sin_cached = None
1174
+
1175
+ def _apply_rotary_pos_emb(
1176
+ self,
1177
+ heads: torch.Tensor,
1178
+ cos: torch.Tensor,
1179
+ sin: torch.Tensor,
1180
+ ) -> torch.Tensor:
1181
+ """ """
1182
+ x_first, x_second = (
1183
+ heads[..., : heads.shape[-1] // 2],
1184
+ heads[..., heads.shape[-1] // 2 :],
1185
+ )
1186
+
1187
+ first_part = x_first * cos - x_second * sin
1188
+ second_part = x_second * cos + x_first * sin
1189
+
1190
+ return torch.cat((first_part, second_part), dim=-1)
1191
+
1192
+ def _compute_cos_sin_tables(
1193
+ self, x: torch.Tensor, inv_freq: torch.Tensor, seq_dimension: int = 2
1194
+ ) -> tuple[torch.Tensor, torch.Tensor]:
1195
+ seq_len = x.shape[seq_dimension]
1196
+ # Reset the tables if the sequence length has changed,
1197
+ # or if we're on a new device (possibly due to tracing for instance)
1198
+ self._seq_len_cached = seq_len
1199
+ t = torch.arange(x.shape[seq_dimension], device=x.device).type_as(inv_freq)
1200
+ # freqs = torch.outer(t, inv_freq)
1201
+ freqs = torch.einsum("i, j -> ij", t, inv_freq)
1202
+
1203
+ self._cos_cached = torch.cos(freqs)[None, :, None, :]
1204
+ self._sin_cached = torch.sin(freqs)[None, :, None, :]
1205
+ # emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
1206
+
1207
+ # self._cos_cached = emb.cos()[None, None, :, :]
1208
+ # self._sin_cached = emb.sin()[None, None, :, :]
1209
+
1210
+ return self._cos_cached, self._sin_cached
1211
+
1212
+ def forward(
1213
+ self, q: torch.Tensor, k: torch.Tensor
1214
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
1215
+ if self.rescaling_factor is None:
1216
+ inv_freq = 1.0 / (
1217
+ self.upper_freq ** (torch.arange(0, self.dim, 2).float() / self.dim)
1218
+ )
1219
+ else:
1220
+ updated_base = self.upper_freq * (
1221
+ self.rescaling_factor ** (self.dim / (self.dim - 2))
1222
+ )
1223
+ inv_freq = 1.0 / (
1224
+ updated_base ** (torch.arange(0, self.dim, 2).float() / self.dim)
1225
+ )
1226
+
1227
+ self._cos_cached, self._sin_cached = self._compute_cos_sin_tables(
1228
+ q,
1229
+ inv_freq,
1230
+ seq_dimension=-3,
1231
+ )
1232
+
1233
+ return (
1234
+ self._apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached),
1235
+ self._apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached),
1236
+ )
1237
+
1238
+
1239
+ class MultiHeadAttention(nn.Module):
1240
+ def __init__(
1241
+ self,
1242
+ num_heads: int,
1243
+ key_size: int,
1244
+ rotary_embedding_config: Optional[RotaryEmbeddingConfigBis] = None,
1245
+ add_bias_kv: bool = False,
1246
+ value_size: Optional[int] = None,
1247
+ model_size: Optional[int] = None,
1248
+ name: Optional[str] = None,
1249
+ ):
1250
+ super().__init__()
1251
+ if not model_size:
1252
+ model_size = key_size * num_heads
1253
+ if not value_size:
1254
+ value_size = key_size
1255
+ self.model_size = model_size
1256
+ self.key_size = key_size
1257
+ self.value_size = value_size
1258
+ self.add_bias_kv = add_bias_kv
1259
+ self.name = name
1260
+ self.num_heads = num_heads
1261
+ self._rotary_embedding_config = rotary_embedding_config
1262
+
1263
+ self.w_k = nn.Linear(self.model_size, self.num_heads * self.key_size)
1264
+ self.w_q = nn.Linear(self.model_size, self.num_heads * self.key_size)
1265
+ self.w_v = nn.Linear(self.model_size, self.num_heads * self.value_size)
1266
+ self.output = nn.Linear(self.num_heads * self.value_size, self.model_size)
1267
+ if self._rotary_embedding_config:
1268
+ self._rotary_embedding = RotaryEmbeddingBis(
1269
+ self.key_size, self._rotary_embedding_config
1270
+ )
1271
+
1272
+ def apply_rotary_embeddings(
1273
+ self,
1274
+ query: torch.Tensor,
1275
+ key: torch.Tensor,
1276
+ ) -> tuple[torch.Tensor, torch.Tensor]:
1277
+ """ """
1278
+ query, key = self._rotary_embedding(query, key)
1279
+ return query, key
1280
+
1281
+ def forward(
1282
+ self,
1283
+ query: torch.Tensor,
1284
+ key: torch.Tensor,
1285
+ value: torch.Tensor,
1286
+ attention_mask: Optional[torch.Tensor] = None,
1287
+ attention_weight_bias: Optional[torch.Tensor] = None,
1288
+ ) -> dict[str, torch.Tensor]:
1289
+ """
1290
+ Returns:
1291
+ dictionary containing attention weights
1292
+ and outputs.
1293
+ """
1294
+ key_heads = self.w_k(key).reshape(
1295
+ (*key.shape[:-1], self.num_heads, self.key_size)
1296
+ )
1297
+ query_heads = self.w_q(query).reshape(
1298
+ (*query.shape[:-1], self.num_heads, self.key_size)
1299
+ )
1300
+ value_heads = self.w_v(value).reshape(
1301
+ (*value.shape[:-1], self.num_heads, self.value_size)
1302
+ )
1303
+ if self._rotary_embedding_config:
1304
+ query_heads, key_heads = self.apply_rotary_embeddings(
1305
+ query_heads, key_heads
1306
+ )
1307
+ attention_weights = torch.einsum(
1308
+ "...thd, ...Thd -> ...htT", query_heads, key_heads
1309
+ )
1310
+ sqrt_key_size = np.sqrt(self.key_size)
1311
+ attention_weights = attention_weights / sqrt_key_size
1312
+ if attention_mask is not None:
1313
+ attention_weights = torch.where(attention_mask, attention_weights, -1e30)
1314
+ if attention_weight_bias is not None:
1315
+ attention_weights = F.softmax(
1316
+ attention_weights + attention_weight_bias, dim=-1
1317
+ )
1318
+ else:
1319
+ attention_weights = F.softmax(attention_weights, dim=-1)
1320
+ value_out = torch.einsum(
1321
+ "...htT, ...Thd->...thd", attention_weights, value_heads
1322
+ )
1323
+ value_out = value_out.reshape((*value_out.shape[:-2], -1))
1324
+ embeddings = self.output(value_out)
1325
+
1326
+ return {"attention_weights": attention_weights, "embeddings": embeddings}
1327
+
1328
+
1329
+ class SelfAttentionBlock(nn.Module):
1330
+ def __init__(
1331
+ self,
1332
+ num_heads: int,
1333
+ embed_dim: int,
1334
+ ffn_embed_dim: int,
1335
+ key_size: Optional[int] = None,
1336
+ add_bias_kv: bool = False,
1337
+ add_bias_fnn: bool = True,
1338
+ ffn_activation_name: str = "gelu-no-approx",
1339
+ use_glu_in_ffn: bool = False,
1340
+ layer_norm_eps: float = 1e-5, # this is the default haiku value
1341
+ pre_layer_norm: bool = True,
1342
+ name: Optional[str] = None,
1343
+ rotary_embedding_config: Optional[RotaryEmbeddingConfigBis] = None,
1344
+ ):
1345
+ super().__init__()
1346
+ if key_size is None:
1347
+ if embed_dim % num_heads != 0:
1348
+ raise ValueError(
1349
+ f"The embedding dimension should be divisible by the number of "
1350
+ f"heads, however provided embedding dimension is {embed_dim} and "
1351
+ f"the number of heads is {num_heads}."
1352
+ )
1353
+ else:
1354
+ key_size = embed_dim // num_heads
1355
+
1356
+ # Get ffn activation function
1357
+ self._pre_layer_norm = pre_layer_norm
1358
+ self._use_glu_in_fnn = use_glu_in_ffn
1359
+ # Define layers
1360
+ if use_glu_in_ffn:
1361
+ # user should multiply ffn_embed_dim by 2/3 when using GLU
1362
+ # to keep total number of parameters equal
1363
+ # see https://arxiv.org/pdf/2002.05202.pdf. for more details
1364
+ # we multiply by 2 here as the output will be split in 2 for GLU
1365
+ self.fc1 = nn.Linear(embed_dim, int(2 * ffn_embed_dim), bias=add_bias_fnn)
1366
+ else:
1367
+ self.fc1 = nn.Linear(embed_dim, ffn_embed_dim, bias=add_bias_fnn)
1368
+
1369
+ self.fc2 = nn.Linear(ffn_embed_dim, embed_dim, bias=add_bias_fnn)
1370
+
1371
+ self.layer_norm_self_attention = nn.LayerNorm(
1372
+ embed_dim,
1373
+ )
1374
+ self.layer_norm_mlp = nn.LayerNorm(embed_dim)
1375
+ if ffn_activation_name == "swish":
1376
+ self._ffn_activation_fn = nn.SiLU()
1377
+ elif ffn_activation_name == "gelu-no-approx":
1378
+ self._ffn_activation_fn = nn.GELU(approximate="tanh")
1379
+ else:
1380
+ self._ffn_activation_fn = getattr(torch.nn, ffn_activation_name)
1381
+
1382
+ self.mha = MultiHeadAttention(
1383
+ num_heads=num_heads,
1384
+ key_size=key_size,
1385
+ add_bias_kv=add_bias_kv,
1386
+ model_size=embed_dim,
1387
+ name="self_attention",
1388
+ rotary_embedding_config=rotary_embedding_config,
1389
+ )
1390
+
1391
+ def mlp(self, embed: torch.Tensor) -> torch.Tensor:
1392
+
1393
+ if self._pre_layer_norm:
1394
+ x = self.layer_norm_mlp(embed)
1395
+ else:
1396
+ x = embed
1397
+
1398
+ if self._use_glu_in_fnn:
1399
+ x = self.fc1(x)
1400
+ x1, x2 = torch.split(x, split_size_or_sections=x.shape[-1] // 2, dim=-1)
1401
+ x = self._ffn_activation_fn(x1) * x2
1402
+ else:
1403
+ x = self._ffn_activation_fn(self.fc1(x))
1404
+ x = self.fc2(x)
1405
+
1406
+ if not self._pre_layer_norm:
1407
+ x = self.layer_norm_mlp(x + embed)
1408
+ return x
1409
+
1410
+ def forward(
1411
+ self,
1412
+ x: torch.Tensor,
1413
+ attention_mask: Optional[torch.Tensor] = None,
1414
+ attention_weight_bias: Optional[torch.Tensor] = None,
1415
+ ) -> dict[str, torch.Tensor]:
1416
+
1417
+ res = x
1418
+ if self._pre_layer_norm:
1419
+ x = self.layer_norm_self_attention(x)
1420
+
1421
+ output: dict[str, torch.Tensor] = self.mha(
1422
+ x,
1423
+ x,
1424
+ x,
1425
+ attention_mask=attention_mask,
1426
+ attention_weight_bias=attention_weight_bias,
1427
+ )
1428
+
1429
+ if not self._pre_layer_norm:
1430
+ output["embeddings"] = self.layer_norm_self_attention(
1431
+ output["embeddings"] + res
1432
+ )
1433
+
1434
+ x = output["embeddings"]
1435
+ else:
1436
+ x = output["embeddings"]
1437
+ x = res + x
1438
+
1439
+ # MLP
1440
+ if not self._pre_layer_norm:
1441
+ x = self.mlp(x)
1442
+ else:
1443
+ x = x + self.mlp(x)
1444
+
1445
+ output["embeddings"] = x
1446
+ return output
1447
+
1448
+
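# Shape sketch (illustrative; all sizes are hypothetical): one ESM-style
# pre-layer-norm self-attention block with rotary embeddings.
def _example_self_attention_block() -> None:
    block = SelfAttentionBlock(
        num_heads=2,
        embed_dim=16,
        ffn_embed_dim=32,
        rotary_embedding_config=RotaryEmbeddingConfigBis(rescaling_factor=None),
    )
    out = block(torch.randn(1, 5, 16))
    assert out["embeddings"].shape == (1, 5, 16)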
1449
+ class RobertaLMHead(nn.Module):
1450
+ """
1451
+ Roberta Language Model head. Transforms final attention layer output into a
1452
+ distribution over tokens at each position.
1453
+ """
1454
+
1455
+ def __init__(self, embed_dim: int, alphabet_size: int):
1456
+ """
1457
+ Args:
1458
+ embed_dim: Embedding dimension.
1459
+ alphabet_size: Number of tokens in the alphabet.
1460
+ """
1461
+ super().__init__()
1462
+ self.embed_dim = embed_dim
1463
+ self.alphabet_size = alphabet_size
1464
+
1465
+ # Define layers
1466
+ self._first_layer_norm = nn.LayerNorm(embed_dim, elementwise_affine=True)
1467
+ self._fc1 = nn.Linear(embed_dim, embed_dim)
1468
+ self._second_layer_norm = nn.LayerNorm(embed_dim, elementwise_affine=True)
1469
+ self._final_fc = nn.Linear(embed_dim, alphabet_size)
1470
+
1471
+ def forward(self, x: torch.Tensor) -> dict:
1472
+ x = self._first_layer_norm(x)
1473
+ embeddings = x
1474
+ x = self._fc1(x)
1475
+ x = nn.functional.gelu(x)
1476
+ x = self._second_layer_norm(x)
1477
+ logits = self._final_fc(x)
1478
+ return {"embeddings": embeddings, "logits": logits}
1479
+
1480
+
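# Shape sketch (illustrative; sizes are hypothetical): the head returns both the
# layer-normalised embeddings and the per-position token logits.
def _example_roberta_lm_head() -> None:
    head = RobertaLMHead(embed_dim=16, alphabet_size=30)
    out = head(torch.randn(2, 7, 16))
    assert out["embeddings"].shape == (2, 7, 16)
    assert out["logits"].shape == (2, 7, 30)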
1481
+ class TorchESMTransformer(nn.Module):
1482
+ def __init__(
1483
+ self,
1484
+ esm_config: ESMTransformerConfig,
1485
+ ):
1486
+ super(TorchESMTransformer, self).__init__()
1487
+ self.esm_config = esm_config
1488
+
1489
+ # Other cases are not implemented
1490
+ assert esm_config.positional_embedding is None
1491
+ assert esm_config.lm_head == "roberta"
1492
+ assert esm_config.use_rotary_embedding is True
1493
+ assert esm_config.token_dropout is False
1494
+ assert esm_config.emb_layer_norm_before is False
1495
+ assert esm_config.mask_before_attention is False
1496
+ assert esm_config.bias_word_embedding is False
1497
+ assert esm_config.use_gradient_checkpointing is False
1498
+
1499
+ self.embed_layer = nn.Embedding(esm_config.alphabet_size, esm_config.embed_dim)
1500
+
1501
+ self.lm_head = RobertaLMHead(
1502
+ embed_dim=esm_config.embed_dim,
1503
+ alphabet_size=esm_config.alphabet_size,
1504
+ )
1505
+
1506
+ self.rotary_embedding_config = RotaryEmbeddingConfigBis(
1507
+ rescaling_factor=esm_config.rescaling_factor
1508
+ )
1509
+
1510
+ self.attention_blocks = nn.ModuleList(
1511
+ [
1512
+ SelfAttentionBlock( # type: ignore
1513
+ num_heads=esm_config.attention_heads,
1514
+ embed_dim=esm_config.embed_dim,
1515
+ key_size=esm_config.key_size,
1516
+ ffn_embed_dim=esm_config.ffn_embed_dim,
1517
+ add_bias_kv=esm_config.add_bias_kv,
1518
+ add_bias_fnn=esm_config.add_bias_ffn,
1519
+ ffn_activation_name=esm_config.ffn_activation_name,
1520
+ use_glu_in_ffn=esm_config.use_glu_in_ffn,
1521
+ rotary_embedding_config=self.rotary_embedding_config,
1522
+ layer_norm_eps=esm_config.layer_norm_eps,
1523
+ pre_layer_norm=esm_config.pre_layer_norm,
1524
+ )
1525
+ for _ in range(esm_config.num_layers)
1526
+ ]
1527
+ )
1528
+
1529
+ def forward(
1530
+ self, tokens: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
1531
+ ) -> torch.Tensor:
1532
+ """
1533
+ Computes the embeddings based on the input tokens.
1534
+
1535
+ Args:
1536
+ tokens: Input tokens out of the tokenizer of shape (batch_size, seq_len).
1537
+ attention_mask: Attention mask of shape (batch_size, 1, seq_len, seq_len).
1538
+ If no mask is provided, a default padding mask is computed: 1 over
+ non-pad tokens and 0 over pad tokens.
1540
+
1541
+ Returns:
1542
+ Final embeddings of shape (batch_size, seq_len, embed_dim), taken from the
+ LM head's first layer norm; logits are not returned here.
1543
+ """
1544
+ x = self.embed_layer(tokens)
1545
+
1546
+ # Scale the token embeddings (RoBERTa-style embed_scale)
1547
+ x = self.esm_config.embed_scale * x
1548
+
1549
+ if attention_mask is None:
1550
+ attention_mask = build_padding_attention_mask(
1551
+ tokens=tokens, pad_token_id=self.esm_config.pad_token_id
1552
+ )
1553
+
1554
+ for layer in self.attention_blocks:
1555
+ x = layer(x, attention_mask)["embeddings"]
1556
+
1557
+ assert self.esm_config.lm_head == "roberta"
1558
+ x = self.lm_head(x)["embeddings"]
1559
+
1560
+ return x
1561
+
1562
+
1563
+ def build_padding_attention_mask(
1564
+ tokens: torch.Tensor, pad_token_id: int
1565
+ ) -> torch.Tensor:
1566
+ """
1567
+ Builds a padding mask from a sequence of tokens by masking <pad> in the attention.
1568
+
1569
+ Args:
1570
+ tokens: Batch of sequences of shape (batch_size, seq_len).
1571
+ pad_token_id: Int corresponding to the <pad> token to mask.
1572
+
1573
+ Returns:
1574
+ Batch of attention masks, masking out <pad> tokens.
1575
+ """
1576
+ padding_mask = tokens != pad_token_id
1577
+ padding_mask = padding_mask.unsqueeze(1)
1578
+ padding_mask = torch.einsum("bhT, bht -> bhtT", padding_mask, padding_mask)
1579
+ return padding_mask
1580
+
1581
+
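A small sanity check of build_padding_attention_mask above (token values are illustrative):

    import torch

    tokens = torch.tensor([[5, 6, 7, 1, 1]])                   # two trailing <pad> tokens (pad_token_id = 1)
    mask = build_padding_attention_mask(tokens, pad_token_id=1)
    print(mask.shape)      # torch.Size([1, 1, 5, 5])
    print(mask[0, 0, 0])   # tensor([ True,  True,  True, False, False])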
1582
+ class TorchBioBrainEncoder(nn.Module):
1583
+ def __init__(
1584
+ self,
1585
+ esm_config: ESMTransformerConfig,
1586
+ ):
1587
+ super(TorchBioBrainEncoder, self).__init__()
1588
+ self.esm_config = esm_config
1589
+ self.esm_model = TorchESMTransformer(self.esm_config)
1590
+
1591
+ def forward(
1592
+ self,
1593
+ bio_token_ids: torch.Tensor,
1594
+ ) -> torch.Tensor:
1595
+ """
1596
+ Args:
1597
+ bio_token_ids (torch.Tensor):
1598
+ Shape (batch_size, num_bio_tokens)
1599
+
1600
+ Returns:
1601
+ torch.Tensor:
1602
+ Shape (batch_size, num_bio_tokens, embed_dim)
1603
+ """
1604
+ bio_embeddings = self.esm_model(tokens=bio_token_ids)
1605
+
1606
+ return bio_embeddings
1607
+
1608
+
1609
+ class TorchMultiModalPerceiverResamplerBlock(nn.Module):
1610
+ def __init__(
1611
+ self,
1612
+ num_heads: int,
1613
+ embed_dim: int,
1614
+ ffn_embed_dim: int,
1615
+ key_size: Optional[int] = None,
1616
+ add_bias_kv: bool = False,
1617
+ add_bias_ffn: bool = True,
1618
+ ffn_activation_name: str = "gelu",
1619
+ use_glu_in_ffn: bool = False,
1620
+ ):
1621
+ super().__init__()
1622
+
1623
+ if key_size is None:
1624
+ if embed_dim % num_heads != 0:
1625
+ raise ValueError(
1626
+ f"Embedding dimension {embed_dim} should be divisible by "
1627
+ f"num_heads {num_heads}."
1628
+ )
1629
+ key_size = embed_dim // num_heads
1630
+
1631
+ self.num_heads = num_heads
1632
+ self.embed_dim = embed_dim
1633
+ self.ffn_embed_dim = ffn_embed_dim * 2 if use_glu_in_ffn else ffn_embed_dim
1634
+ self.use_glu_in_ffn = use_glu_in_ffn
1635
+
1636
+ self.cross_attention_1 = MultiHeadAttention(
1637
+ num_heads=num_heads, key_size=key_size, add_bias_kv=add_bias_kv
1638
+ )
1639
+ self.cross_attention_2 = MultiHeadAttention(
1640
+ num_heads=num_heads, key_size=key_size, add_bias_kv=add_bias_kv
1641
+ )
1642
+
1643
+ self.norm_cross_attention_1 = nn.LayerNorm(embed_dim)
1644
+ self.norm_cross_attention_2 = nn.LayerNorm(embed_dim)
1645
+ self.norm_mlp = nn.LayerNorm(embed_dim)
1646
+
1647
+ self.fc1 = nn.Linear(embed_dim, self.ffn_embed_dim, bias=add_bias_ffn)
1648
+ self.fc2 = nn.Linear(self.ffn_embed_dim, embed_dim, bias=add_bias_ffn)
1649
+
1650
+ self.activation_fn = getattr(
1651
+ nn.functional, ffn_activation_name, nn.functional.gelu
1652
+ )
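+ # NOTE: activation names that are not attributes of nn.functional
+ # (e.g. "gelu-no-approx") fall back to the exact nn.functional.gelu default.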
1653
+
1654
+ def mlp(self, x: torch.Tensor) -> torch.Tensor:
1655
+ x = self.norm_mlp(x)
1656
+ if self.use_glu_in_ffn:
1657
+ x1, x2 = torch.chunk(self.fc1(x), 2, dim=-1)
1658
+ x = self.activation_fn(x1) * x2
1659
+ else:
1660
+ x = self.activation_fn(self.fc1(x))
1661
+ return self.fc2(x)
1662
+
1663
+ def forward(
1664
+ self,
1665
+ x: torch.Tensor,
1666
+ cross_attention_embeddings_1: torch.Tensor,
1667
+ cross_attention_embeddings_2: torch.Tensor,
1668
+ attention_mask_1: Optional[torch.Tensor] = None,
1669
+ attention_mask_2: Optional[torch.Tensor] = None,
1670
+ ) -> Dict[str, torch.Tensor]:
1671
+ res = x
1672
+ x = self.norm_cross_attention_1(x)
1673
+
1674
+ attn_output = self.cross_attention_1(
1675
+ query=x,
1676
+ key=cross_attention_embeddings_1,
1677
+ value=cross_attention_embeddings_1,
1678
+ attention_mask=attention_mask_1,
1679
+ )["embeddings"]
1680
+ x = res + attn_output
1681
+
1682
+ res = x
1683
+ x = self.norm_cross_attention_2(x)
1684
+ attn_output = self.cross_attention_2(
1685
+ query=x,
1686
+ key=cross_attention_embeddings_2,
1687
+ value=cross_attention_embeddings_2,
1688
+ attention_mask=attention_mask_2,
1689
+ )["embeddings"]
1690
+ x = res + attn_output
1691
+
1692
+ x = x + self.mlp(x)
1693
+
1694
+ return {"embeddings": x}
1695
+
1696
+
1697
+ class TorchMultiModalPerceiverResampler(nn.Module):
1698
+ """
1699
+ Perceiver Resampler model, made of successive PerceiverResamplerBlocks.
1700
+ """
1701
+
1702
+ def __init__(
1703
+ self,
1704
+ config: PerceiverResamplerConfig,
1705
+ name: Optional[str] = None,
1706
+ ):
1707
+ """
1708
+ Initialize a Perceiver Resampler model.
1709
+
1710
+ Args:
1711
+ config: Dataclass containing model hyperparameters.
1712
+ name: Name for module (custom will break weight loading).
1713
+ """
1714
+ super().__init__()
1715
+ self.config = config
1716
+ self.name = name
1717
+ self.layers = nn.ModuleList(
1718
+ [
1719
+ TorchMultiModalPerceiverResamplerBlock(
1720
+ num_heads=self.config.attention_heads,
1721
+ embed_dim=self.config.embed_dim,
1722
+ key_size=self.config.key_size,
1723
+ ffn_embed_dim=self.config.ffn_embed_dim,
1724
+ add_bias_kv=self.config.add_bias_kv,
1725
+ add_bias_ffn=self.config.add_bias_ffn,
1726
+ ffn_activation_name=self.config.ffn_activation_name,
1727
+ use_glu_in_ffn=self.config.use_glu_in_ffn,
1728
+ )
1729
+ for _ in range(self.config.num_layers)
1730
+ ]
1731
+ )
1732
+
1733
+ self.latent_queries = torch.nn.Parameter(
1734
+ torch.randn(self.config.resampled_length, self.config.embed_dim)
1735
+ * (
1736
+ 1.0
1737
+ / torch.sqrt(torch.tensor(self.config.embed_dim, dtype=torch.float32))
1738
+ )
1739
+ )
1740
+
1741
+ def apply_attention_blocks(
1742
+ self,
1743
+ x: torch.Tensor,
1744
+ xf_1: torch.Tensor,
1745
+ xf_2: torch.Tensor,
1746
+ outs: Dict[str, torch.Tensor],
1747
+ attention_mask_1: Optional[torch.Tensor] = None,
1748
+ attention_mask_2: Optional[torch.Tensor] = None,
1749
+ ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
1750
+ """
1751
+ Applies the stack of cross-attention blocks to the latent queries.
1752
+ """
1753
+ for layer in self.layers:
1754
+ concat_input_1 = torch.cat([xf_1, x], dim=1)
1755
+ concat_input_2 = torch.cat([xf_2, x], dim=1)
1756
+
1757
+ output = layer(
1758
+ x=x,
1759
+ cross_attention_embeddings_1=concat_input_1,
1760
+ cross_attention_embeddings_2=concat_input_2,
1761
+ attention_mask_1=attention_mask_1,
1762
+ attention_mask_2=attention_mask_2,
1763
+ )
1764
+ x = output["embeddings"]
1765
+
1766
+ return x, outs
1767
+
1768
+ def forward(
1769
+ self,
1770
+ input_embeddings_1: torch.Tensor,
1771
+ input_embeddings_2: torch.Tensor,
1772
+ attention_mask_1: Optional[torch.Tensor] = None,
1773
+ attention_mask_2: Optional[torch.Tensor] = None,
1774
+ ) -> Dict[str, torch.Tensor]:
1775
+ """
1776
+ Resamples the two input embedding streams into resampled_length latent embeddings.
1777
+ """
1778
+ assert (
1779
+ input_embeddings_1.shape[-1] == self.config.embed_dim
1780
+ ), "The input embedding dim should match the model embed dim"
1781
+ assert (
1782
+ input_embeddings_2.shape[-1] == self.config.embed_dim
1783
+ ), "The input embedding dim should match the model embed dim"
1784
+
1785
+ batch_size = input_embeddings_1.shape[0]
1786
+
1787
+ latent_queries = self.latent_queries.unsqueeze(0).repeat(batch_size, 1, 1)
1788
+
1789
+ outs: Dict[str, torch.Tensor] = {}
1790
+ x = latent_queries
1791
+
1792
+ x, outs = self.apply_attention_blocks(
1793
+ x=x,
1794
+ xf_1=input_embeddings_1,
1795
+ xf_2=input_embeddings_2,
1796
+ outs=outs,
1797
+ attention_mask_1=attention_mask_1,
1798
+ attention_mask_2=attention_mask_2,
1799
+ )
1800
+
1801
+ outs["embeddings"] = x
1802
+
1803
+ return outs
1804
+
1805
+
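The resampler always produces resampled_length output vectors regardless of input length: latent_queries are repeated per batch element, and in apply_attention_blocks each input stream is concatenated with the current latents before serving as keys/values, so the latents also attend to themselves. A pure-tensor shape sketch (toy dimensions, not the checkpoint's embed_dim=4096 / resampled_length=64):

    import torch

    batch, resampled_length, embed_dim = 2, 4, 8
    n_bio, n_english = 10, 6

    latent_queries = torch.randn(resampled_length, embed_dim)
    x = latent_queries.unsqueeze(0).repeat(batch, 1, 1)   # (2, 4, 8) latent queries

    bio = torch.randn(batch, n_bio, embed_dim)
    english = torch.randn(batch, n_english, embed_dim)

    keys_values_1 = torch.cat([bio, x], dim=1)            # (2, 14, 8) fed to cross_attention_1
    keys_values_2 = torch.cat([english, x], dim=1)        # (2, 10, 8) fed to cross_attention_2
    # After the stacked blocks the output keeps the latent shape: (batch, resampled_length, embed_dim)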
1806
+ class TorchMultiModalPerceiverResamplerProjection(nn.Module):
1807
+ def __init__(
1808
+ self,
1809
+ perceiver_resampler_config: PerceiverResamplerConfig,
1810
+ input_embed_dim: int,
1811
+ embed_dim: int,
1812
+ bio_pad_token_id: int,
1813
+ english_pad_token_id: int,
1814
+ english_vocab_size: int,
1815
+ ):
1816
+ super().__init__()
1817
+ self.config = perceiver_resampler_config
1818
+ self.input_embed_dim = input_embed_dim
1819
+ self.embed_dim = embed_dim
1820
+ self.bio_pad_token_id = bio_pad_token_id
1821
+ self.english_pad_token_id = english_pad_token_id
1822
+ self.english_vocab_size = english_vocab_size
1823
+
1824
+ self.bio_projection = nn.Linear(input_embed_dim, embed_dim)
1825
+ self.token_embedding = nn.Embedding(english_vocab_size, embed_dim)
1826
+ self.perceiver_resampler = TorchMultiModalPerceiverResampler(config=self.config)
1827
+
1828
+ def forward(
1829
+ self,
1830
+ bio_token_ids: torch.Tensor,
1831
+ bio_embeddings: torch.Tensor,
1832
+ english_token_ids: torch.Tensor,
1833
+ ) -> torch.Tensor:
1834
+ """
1835
+ Args:
1836
+ bio_token_ids (torch.Tensor):
1837
+ Shape (batch_size, num_bio_tokens)
1838
+
1839
+ bio_embeddings (torch.Tensor):
1840
+ Shape (batch_size, num_bio_tokens, embed_dim)
1841
+
1842
+ english_token_ids (torch.Tensor):
1843
+ Shape (batch_size, num_english_tokens)
1844
+ """
1845
+ projected_bio_embeddings = self.bio_projection(bio_embeddings)
1846
+ english_embeddings = self.token_embedding(english_token_ids)
1847
+
1848
+ bio_attention_mask = build_perceiver_padding_attention_mask(
1849
+ bio_token_ids, self.config.resampled_length, self.bio_pad_token_id
1850
+ )
1851
+ english_attention_mask = build_perceiver_padding_attention_mask(
1852
+ english_token_ids, self.config.resampled_length, self.english_pad_token_id
1853
+ )
1854
+
1855
+ projected_embeddings = self.perceiver_resampler(
1856
+ input_embeddings_1=projected_bio_embeddings,
1857
+ attention_mask_1=bio_attention_mask,
1858
+ input_embeddings_2=english_embeddings,
1859
+ attention_mask_2=english_attention_mask,
1860
+ )["embeddings"]
1861
+
1862
+ return projected_embeddings
1863
+
1864
+
1865
+ def build_perceiver_padding_attention_mask(
1866
+ tokens: torch.Tensor, resampled_length: int, pad_token_id: int
1867
+ ) -> torch.Tensor:
1868
+ batch_size, seq_len = tokens.shape
1869
+ padding_mask = tokens != pad_token_id # (batch_size, seq_len)
1870
+
1871
+ padding_mask = torch.cat(
1872
+ [
1873
+ padding_mask,
1874
+ torch.ones(
1875
+ (batch_size, resampled_length), dtype=torch.bool, device=tokens.device
1876
+ ),
1877
+ ],
1878
+ dim=1,
1879
+ ) # (batch_size, seq_len + resampled_length)
1880
+
1881
+ padding_mask = padding_mask[:, None, None, :]
1882
+ padding_mask = padding_mask.repeat(1, 1, resampled_length, 1) # noqa
1883
+ return padding_mask
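The helper above extends the key-side padding mask with ones for the resampled_length latent positions that get concatenated to each input stream, then broadcasts it over the latent queries. A quick illustration (toy sizes):

    import torch

    tokens = torch.tensor([[4, 5, 1, 1]])   # seq_len = 4, pad_token_id = 1
    mask = build_perceiver_padding_attention_mask(tokens, resampled_length=3, pad_token_id=1)
    print(mask.shape)     # torch.Size([1, 1, 3, 7]) -- (batch, 1, resampled_length, seq_len + resampled_length)
    print(mask[0, 0, 0])  # tensor([ True,  True, False, False,  True,  True,  True])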
config.json ADDED
@@ -0,0 +1,85 @@
1
+ {
2
+ "architectures": [
3
+ "TorchMultiOmicsModel"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "chatNT.ChatNTConfig",
7
+ "AutoModel": "chatNT.TorchMultiOmicsModel"
8
+ },
9
+ "bio_pad_token_id": 1,
10
+ "english_pad_token_id": 2,
11
+ "esm_config": {
12
+ "add_bias_ffn": false,
13
+ "add_bias_kv": false,
14
+ "alphabet_size": 4107,
15
+ "attention_heads": 16,
16
+ "attention_maps_to_save": [],
17
+ "bias_word_embedding": false,
18
+ "emb_layer_norm_before": false,
19
+ "embed_dim": 1024,
20
+ "embed_scale": 1.0,
21
+ "embeddings_layers_to_save": [
22
+ 21
23
+ ],
24
+ "ffn_activation_name": "swish",
25
+ "ffn_embed_dim": 4096,
26
+ "key_size": 64,
27
+ "layer_norm_eps": 1e-05,
28
+ "lm_head": "roberta",
29
+ "mask_before_attention": false,
30
+ "mask_token_id": 2,
31
+ "masking_prob": 0.0,
32
+ "masking_ratio": 0.0,
33
+ "max_positions": 2048,
34
+ "num_layers": 29,
35
+ "pad_token_id": 1,
36
+ "positional_embedding": null,
37
+ "pre_layer_norm": true,
38
+ "rescaling_factor": null,
39
+ "token_dropout": false,
40
+ "use_glu_in_ffn": true,
41
+ "use_gradient_checkpointing": false,
42
+ "use_rotary_embedding": true
43
+ },
44
+ "gpt_config": {
45
+ "add_bias_attn": false,
46
+ "add_bias_ffn": false,
47
+ "add_bias_lm_head": false,
48
+ "embed_dim": 4096,
49
+ "eos_token_id": 2,
50
+ "ffn_activation_name": "silu",
51
+ "ffn_embed_dim": 11008,
52
+ "norm_type": "RMS_norm",
53
+ "num_heads": 32,
54
+ "num_kv_heads": 32,
55
+ "num_layers": 32,
56
+ "parallel_attention_ff": false,
57
+ "rms_norm_eps": 1e-06,
58
+ "rope_config": {
59
+ "dim": 128,
60
+ "max_seq_len": 2048,
61
+ "theta": 10000.0
62
+ },
63
+ "use_glu_in_ffn": true,
64
+ "use_gradient_checkpointing": false,
65
+ "vocab_size": 32000
66
+ },
67
+ "model_type": "ChatNT",
68
+ "perceiver_resampler_config": {
69
+ "add_bias_ffn": true,
70
+ "add_bias_kv": false,
71
+ "attention_heads": 32,
72
+ "emb_layer_norm_before": false,
73
+ "embed_dim": 4096,
74
+ "ffn_activation_name": "gelu-no-approx",
75
+ "ffn_embed_dim": 11008,
76
+ "key_size": 128,
77
+ "num_layers": 3,
78
+ "resampled_length": 64,
79
+ "use_glu_in_ffn": false,
80
+ "use_gradient_checkpointing": false
81
+ },
82
+ "seq_token_id": 32000,
83
+ "torch_dtype": "float32",
84
+ "transformers_version": "4.41.1"
85
+ }
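Given the auto_map above, the checkpoint should be loadable through the transformers auto classes with remote code enabled. A minimal sketch; the repository id is a placeholder and the accompanying tokenizer files are assumed to be present in the repo:

    from transformers import AutoConfig, AutoModel

    repo_id = "<user>/<this-repo>"  # placeholder -- replace with the actual Hub repository id
    config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)  # chatNT.ChatNTConfig
    model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)    # chatNT.TorchMultiOmicsModel
    print(type(model).__name__)  # TorchMultiOmicsModel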
model.safetensors.index.json ADDED
@@ -0,0 +1,763 @@
1
+ {
2
+ "metadata": {
3
+ "total_size": 32174388268
4
+ },
5
+ "weight_map": {
6
+ "biobrain_decoder.gpt_model.final_norm.scale": "model-00001-of-00007.safetensors",
7
+ "biobrain_decoder.gpt_model.layers.0.attn_norm.scale": "model-00001-of-00007.safetensors",
8
+ "biobrain_decoder.gpt_model.layers.0.fc1.weight": "model-00001-of-00007.safetensors",
9
+ "biobrain_decoder.gpt_model.layers.0.fc2.weight": "model-00001-of-00007.safetensors",
10
+ "biobrain_decoder.gpt_model.layers.0.ffn_norm.scale": "model-00001-of-00007.safetensors",
11
+ "biobrain_decoder.gpt_model.layers.0.self_attn.key_linear.weight": "model-00001-of-00007.safetensors",
12
+ "biobrain_decoder.gpt_model.layers.0.self_attn.out_linear.weight": "model-00001-of-00007.safetensors",
13
+ "biobrain_decoder.gpt_model.layers.0.self_attn.query_linear.weight": "model-00001-of-00007.safetensors",
14
+ "biobrain_decoder.gpt_model.layers.0.self_attn.value_linear.weight": "model-00001-of-00007.safetensors",
15
+ "biobrain_decoder.gpt_model.layers.1.attn_norm.scale": "model-00001-of-00007.safetensors",
16
+ "biobrain_decoder.gpt_model.layers.1.fc1.weight": "model-00001-of-00007.safetensors",
17
+ "biobrain_decoder.gpt_model.layers.1.fc2.weight": "model-00001-of-00007.safetensors",
18
+ "biobrain_decoder.gpt_model.layers.1.ffn_norm.scale": "model-00001-of-00007.safetensors",
19
+ "biobrain_decoder.gpt_model.layers.1.self_attn.key_linear.weight": "model-00001-of-00007.safetensors",
20
+ "biobrain_decoder.gpt_model.layers.1.self_attn.out_linear.weight": "model-00001-of-00007.safetensors",
21
+ "biobrain_decoder.gpt_model.layers.1.self_attn.query_linear.weight": "model-00001-of-00007.safetensors",
22
+ "biobrain_decoder.gpt_model.layers.1.self_attn.value_linear.weight": "model-00001-of-00007.safetensors",
23
+ "biobrain_decoder.gpt_model.layers.10.attn_norm.scale": "model-00003-of-00007.safetensors",
24
+ "biobrain_decoder.gpt_model.layers.10.fc1.weight": "model-00003-of-00007.safetensors",
25
+ "biobrain_decoder.gpt_model.layers.10.fc2.weight": "model-00003-of-00007.safetensors",
26
+ "biobrain_decoder.gpt_model.layers.10.ffn_norm.scale": "model-00003-of-00007.safetensors",
27
+ "biobrain_decoder.gpt_model.layers.10.self_attn.key_linear.weight": "model-00003-of-00007.safetensors",
28
+ "biobrain_decoder.gpt_model.layers.10.self_attn.out_linear.weight": "model-00003-of-00007.safetensors",
29
+ "biobrain_decoder.gpt_model.layers.10.self_attn.query_linear.weight": "model-00003-of-00007.safetensors",
30
+ "biobrain_decoder.gpt_model.layers.10.self_attn.value_linear.weight": "model-00003-of-00007.safetensors",
31
+ "biobrain_decoder.gpt_model.layers.11.attn_norm.scale": "model-00003-of-00007.safetensors",
32
+ "biobrain_decoder.gpt_model.layers.11.fc1.weight": "model-00003-of-00007.safetensors",
33
+ "biobrain_decoder.gpt_model.layers.11.fc2.weight": "model-00003-of-00007.safetensors",
34
+ "biobrain_decoder.gpt_model.layers.11.ffn_norm.scale": "model-00003-of-00007.safetensors",
35
+ "biobrain_decoder.gpt_model.layers.11.self_attn.key_linear.weight": "model-00003-of-00007.safetensors",
36
+ "biobrain_decoder.gpt_model.layers.11.self_attn.out_linear.weight": "model-00003-of-00007.safetensors",
37
+ "biobrain_decoder.gpt_model.layers.11.self_attn.query_linear.weight": "model-00003-of-00007.safetensors",
38
+ "biobrain_decoder.gpt_model.layers.11.self_attn.value_linear.weight": "model-00003-of-00007.safetensors",
39
+ "biobrain_decoder.gpt_model.layers.12.attn_norm.scale": "model-00003-of-00007.safetensors",
40
+ "biobrain_decoder.gpt_model.layers.12.fc1.weight": "model-00003-of-00007.safetensors",
41
+ "biobrain_decoder.gpt_model.layers.12.fc2.weight": "model-00003-of-00007.safetensors",
42
+ "biobrain_decoder.gpt_model.layers.12.ffn_norm.scale": "model-00003-of-00007.safetensors",
43
+ "biobrain_decoder.gpt_model.layers.12.self_attn.key_linear.weight": "model-00003-of-00007.safetensors",
44
+ "biobrain_decoder.gpt_model.layers.12.self_attn.out_linear.weight": "model-00003-of-00007.safetensors",
45
+ "biobrain_decoder.gpt_model.layers.12.self_attn.query_linear.weight": "model-00003-of-00007.safetensors",
46
+ "biobrain_decoder.gpt_model.layers.12.self_attn.value_linear.weight": "model-00003-of-00007.safetensors",
47
+ "biobrain_decoder.gpt_model.layers.13.attn_norm.scale": "model-00003-of-00007.safetensors",
48
+ "biobrain_decoder.gpt_model.layers.13.fc1.weight": "model-00003-of-00007.safetensors",
49
+ "biobrain_decoder.gpt_model.layers.13.fc2.weight": "model-00003-of-00007.safetensors",
50
+ "biobrain_decoder.gpt_model.layers.13.ffn_norm.scale": "model-00003-of-00007.safetensors",
51
+ "biobrain_decoder.gpt_model.layers.13.self_attn.key_linear.weight": "model-00003-of-00007.safetensors",
52
+ "biobrain_decoder.gpt_model.layers.13.self_attn.out_linear.weight": "model-00003-of-00007.safetensors",
53
+ "biobrain_decoder.gpt_model.layers.13.self_attn.query_linear.weight": "model-00003-of-00007.safetensors",
54
+ "biobrain_decoder.gpt_model.layers.13.self_attn.value_linear.weight": "model-00003-of-00007.safetensors",
55
+ "biobrain_decoder.gpt_model.layers.14.attn_norm.scale": "model-00003-of-00007.safetensors",
56
+ "biobrain_decoder.gpt_model.layers.14.fc1.weight": "model-00003-of-00007.safetensors",
57
+ "biobrain_decoder.gpt_model.layers.14.fc2.weight": "model-00003-of-00007.safetensors",
58
+ "biobrain_decoder.gpt_model.layers.14.ffn_norm.scale": "model-00003-of-00007.safetensors",
59
+ "biobrain_decoder.gpt_model.layers.14.self_attn.key_linear.weight": "model-00003-of-00007.safetensors",
60
+ "biobrain_decoder.gpt_model.layers.14.self_attn.out_linear.weight": "model-00003-of-00007.safetensors",
61
+ "biobrain_decoder.gpt_model.layers.14.self_attn.query_linear.weight": "model-00003-of-00007.safetensors",
62
+ "biobrain_decoder.gpt_model.layers.14.self_attn.value_linear.weight": "model-00003-of-00007.safetensors",
63
+ "biobrain_decoder.gpt_model.layers.15.attn_norm.scale": "model-00003-of-00007.safetensors",
64
+ "biobrain_decoder.gpt_model.layers.15.fc1.weight": "model-00004-of-00007.safetensors",
65
+ "biobrain_decoder.gpt_model.layers.15.fc2.weight": "model-00004-of-00007.safetensors",
66
+ "biobrain_decoder.gpt_model.layers.15.ffn_norm.scale": "model-00003-of-00007.safetensors",
67
+ "biobrain_decoder.gpt_model.layers.15.self_attn.key_linear.weight": "model-00003-of-00007.safetensors",
68
+ "biobrain_decoder.gpt_model.layers.15.self_attn.out_linear.weight": "model-00003-of-00007.safetensors",
69
+ "biobrain_decoder.gpt_model.layers.15.self_attn.query_linear.weight": "model-00003-of-00007.safetensors",
70
+ "biobrain_decoder.gpt_model.layers.15.self_attn.value_linear.weight": "model-00003-of-00007.safetensors",
71
+ "biobrain_decoder.gpt_model.layers.16.attn_norm.scale": "model-00004-of-00007.safetensors",
72
+ "biobrain_decoder.gpt_model.layers.16.fc1.weight": "model-00004-of-00007.safetensors",
73
+ "biobrain_decoder.gpt_model.layers.16.fc2.weight": "model-00004-of-00007.safetensors",
74
+ "biobrain_decoder.gpt_model.layers.16.ffn_norm.scale": "model-00004-of-00007.safetensors",
75
+ "biobrain_decoder.gpt_model.layers.16.self_attn.key_linear.weight": "model-00004-of-00007.safetensors",
76
+ "biobrain_decoder.gpt_model.layers.16.self_attn.out_linear.weight": "model-00004-of-00007.safetensors",
77
+ "biobrain_decoder.gpt_model.layers.16.self_attn.query_linear.weight": "model-00004-of-00007.safetensors",
78
+ "biobrain_decoder.gpt_model.layers.16.self_attn.value_linear.weight": "model-00004-of-00007.safetensors",
79
+ "biobrain_decoder.gpt_model.layers.17.attn_norm.scale": "model-00004-of-00007.safetensors",
80
+ "biobrain_decoder.gpt_model.layers.17.fc1.weight": "model-00004-of-00007.safetensors",
81
+ "biobrain_decoder.gpt_model.layers.17.fc2.weight": "model-00004-of-00007.safetensors",
82
+ "biobrain_decoder.gpt_model.layers.17.ffn_norm.scale": "model-00004-of-00007.safetensors",
83
+ "biobrain_decoder.gpt_model.layers.17.self_attn.key_linear.weight": "model-00004-of-00007.safetensors",
84
+ "biobrain_decoder.gpt_model.layers.17.self_attn.out_linear.weight": "model-00004-of-00007.safetensors",
85
+ "biobrain_decoder.gpt_model.layers.17.self_attn.query_linear.weight": "model-00004-of-00007.safetensors",
86
+ "biobrain_decoder.gpt_model.layers.17.self_attn.value_linear.weight": "model-00004-of-00007.safetensors",
87
+ "biobrain_decoder.gpt_model.layers.18.attn_norm.scale": "model-00004-of-00007.safetensors",
88
+ "biobrain_decoder.gpt_model.layers.18.fc1.weight": "model-00004-of-00007.safetensors",
89
+ "biobrain_decoder.gpt_model.layers.18.fc2.weight": "model-00004-of-00007.safetensors",
90
+ "biobrain_decoder.gpt_model.layers.18.ffn_norm.scale": "model-00004-of-00007.safetensors",
91
+ "biobrain_decoder.gpt_model.layers.18.self_attn.key_linear.weight": "model-00004-of-00007.safetensors",
92
+ "biobrain_decoder.gpt_model.layers.18.self_attn.out_linear.weight": "model-00004-of-00007.safetensors",
93
+ "biobrain_decoder.gpt_model.layers.18.self_attn.query_linear.weight": "model-00004-of-00007.safetensors",
94
+ "biobrain_decoder.gpt_model.layers.18.self_attn.value_linear.weight": "model-00004-of-00007.safetensors",
95
+ "biobrain_decoder.gpt_model.layers.19.attn_norm.scale": "model-00004-of-00007.safetensors",
96
+ "biobrain_decoder.gpt_model.layers.19.fc1.weight": "model-00004-of-00007.safetensors",
97
+ "biobrain_decoder.gpt_model.layers.19.fc2.weight": "model-00004-of-00007.safetensors",
98
+ "biobrain_decoder.gpt_model.layers.19.ffn_norm.scale": "model-00004-of-00007.safetensors",
99
+ "biobrain_decoder.gpt_model.layers.19.self_attn.key_linear.weight": "model-00004-of-00007.safetensors",
100
+ "biobrain_decoder.gpt_model.layers.19.self_attn.out_linear.weight": "model-00004-of-00007.safetensors",
101
+ "biobrain_decoder.gpt_model.layers.19.self_attn.query_linear.weight": "model-00004-of-00007.safetensors",
102
+ "biobrain_decoder.gpt_model.layers.19.self_attn.value_linear.weight": "model-00004-of-00007.safetensors",
103
+ "biobrain_decoder.gpt_model.layers.2.attn_norm.scale": "model-00001-of-00007.safetensors",
104
+ "biobrain_decoder.gpt_model.layers.2.fc1.weight": "model-00001-of-00007.safetensors",
105
+ "biobrain_decoder.gpt_model.layers.2.fc2.weight": "model-00001-of-00007.safetensors",
106
+ "biobrain_decoder.gpt_model.layers.2.ffn_norm.scale": "model-00001-of-00007.safetensors",
107
+ "biobrain_decoder.gpt_model.layers.2.self_attn.key_linear.weight": "model-00001-of-00007.safetensors",
108
+ "biobrain_decoder.gpt_model.layers.2.self_attn.out_linear.weight": "model-00001-of-00007.safetensors",
109
+ "biobrain_decoder.gpt_model.layers.2.self_attn.query_linear.weight": "model-00001-of-00007.safetensors",
110
+ "biobrain_decoder.gpt_model.layers.2.self_attn.value_linear.weight": "model-00001-of-00007.safetensors",
111
+ "biobrain_decoder.gpt_model.layers.20.attn_norm.scale": "model-00004-of-00007.safetensors",
112
+ "biobrain_decoder.gpt_model.layers.20.fc1.weight": "model-00004-of-00007.safetensors",
113
+ "biobrain_decoder.gpt_model.layers.20.fc2.weight": "model-00004-of-00007.safetensors",
114
+ "biobrain_decoder.gpt_model.layers.20.ffn_norm.scale": "model-00004-of-00007.safetensors",
115
+ "biobrain_decoder.gpt_model.layers.20.self_attn.key_linear.weight": "model-00004-of-00007.safetensors",
116
+ "biobrain_decoder.gpt_model.layers.20.self_attn.out_linear.weight": "model-00004-of-00007.safetensors",
117
+ "biobrain_decoder.gpt_model.layers.20.self_attn.query_linear.weight": "model-00004-of-00007.safetensors",
118
+ "biobrain_decoder.gpt_model.layers.20.self_attn.value_linear.weight": "model-00004-of-00007.safetensors",
119
+ "biobrain_decoder.gpt_model.layers.21.attn_norm.scale": "model-00004-of-00007.safetensors",
120
+ "biobrain_decoder.gpt_model.layers.21.fc1.weight": "model-00005-of-00007.safetensors",
121
+ "biobrain_decoder.gpt_model.layers.21.fc2.weight": "model-00005-of-00007.safetensors",
122
+ "biobrain_decoder.gpt_model.layers.21.ffn_norm.scale": "model-00004-of-00007.safetensors",
123
+ "biobrain_decoder.gpt_model.layers.21.self_attn.key_linear.weight": "model-00004-of-00007.safetensors",
124
+ "biobrain_decoder.gpt_model.layers.21.self_attn.out_linear.weight": "model-00004-of-00007.safetensors",
125
+ "biobrain_decoder.gpt_model.layers.21.self_attn.query_linear.weight": "model-00004-of-00007.safetensors",
126
+ "biobrain_decoder.gpt_model.layers.21.self_attn.value_linear.weight": "model-00004-of-00007.safetensors",
127
+ "biobrain_decoder.gpt_model.layers.22.attn_norm.scale": "model-00005-of-00007.safetensors",
128
+ "biobrain_decoder.gpt_model.layers.22.fc1.weight": "model-00005-of-00007.safetensors",
129
+ "biobrain_decoder.gpt_model.layers.22.fc2.weight": "model-00005-of-00007.safetensors",
130
+ "biobrain_decoder.gpt_model.layers.22.ffn_norm.scale": "model-00005-of-00007.safetensors",
131
+ "biobrain_decoder.gpt_model.layers.22.self_attn.key_linear.weight": "model-00005-of-00007.safetensors",
132
+ "biobrain_decoder.gpt_model.layers.22.self_attn.out_linear.weight": "model-00005-of-00007.safetensors",
133
+ "biobrain_decoder.gpt_model.layers.22.self_attn.query_linear.weight": "model-00005-of-00007.safetensors",
134
+ "biobrain_decoder.gpt_model.layers.22.self_attn.value_linear.weight": "model-00005-of-00007.safetensors",
135
+ "biobrain_decoder.gpt_model.layers.23.attn_norm.scale": "model-00005-of-00007.safetensors",
136
+ "biobrain_decoder.gpt_model.layers.23.fc1.weight": "model-00005-of-00007.safetensors",
137
+ "biobrain_decoder.gpt_model.layers.23.fc2.weight": "model-00005-of-00007.safetensors",
138
+ "biobrain_decoder.gpt_model.layers.23.ffn_norm.scale": "model-00005-of-00007.safetensors",
139
+ "biobrain_decoder.gpt_model.layers.23.self_attn.key_linear.weight": "model-00005-of-00007.safetensors",
140
+ "biobrain_decoder.gpt_model.layers.23.self_attn.out_linear.weight": "model-00005-of-00007.safetensors",
141
+ "biobrain_decoder.gpt_model.layers.23.self_attn.query_linear.weight": "model-00005-of-00007.safetensors",
142
+ "biobrain_decoder.gpt_model.layers.23.self_attn.value_linear.weight": "model-00005-of-00007.safetensors",
143
+ "biobrain_decoder.gpt_model.layers.24.attn_norm.scale": "model-00005-of-00007.safetensors",
144
+ "biobrain_decoder.gpt_model.layers.24.fc1.weight": "model-00005-of-00007.safetensors",
145
+ "biobrain_decoder.gpt_model.layers.24.fc2.weight": "model-00005-of-00007.safetensors",
146
+ "biobrain_decoder.gpt_model.layers.24.ffn_norm.scale": "model-00005-of-00007.safetensors",
147
+ "biobrain_decoder.gpt_model.layers.24.self_attn.key_linear.weight": "model-00005-of-00007.safetensors",
148
+ "biobrain_decoder.gpt_model.layers.24.self_attn.out_linear.weight": "model-00005-of-00007.safetensors",
149
+ "biobrain_decoder.gpt_model.layers.24.self_attn.query_linear.weight": "model-00005-of-00007.safetensors",
150
+ "biobrain_decoder.gpt_model.layers.24.self_attn.value_linear.weight": "model-00005-of-00007.safetensors",
151
+ "biobrain_decoder.gpt_model.layers.25.attn_norm.scale": "model-00005-of-00007.safetensors",
152
+ "biobrain_decoder.gpt_model.layers.25.fc1.weight": "model-00005-of-00007.safetensors",
153
+ "biobrain_decoder.gpt_model.layers.25.fc2.weight": "model-00005-of-00007.safetensors",
154
+ "biobrain_decoder.gpt_model.layers.25.ffn_norm.scale": "model-00005-of-00007.safetensors",
155
+ "biobrain_decoder.gpt_model.layers.25.self_attn.key_linear.weight": "model-00005-of-00007.safetensors",
156
+ "biobrain_decoder.gpt_model.layers.25.self_attn.out_linear.weight": "model-00005-of-00007.safetensors",
157
+ "biobrain_decoder.gpt_model.layers.25.self_attn.query_linear.weight": "model-00005-of-00007.safetensors",
158
+ "biobrain_decoder.gpt_model.layers.25.self_attn.value_linear.weight": "model-00005-of-00007.safetensors",
159
+ "biobrain_decoder.gpt_model.layers.26.attn_norm.scale": "model-00005-of-00007.safetensors",
160
+ "biobrain_decoder.gpt_model.layers.26.fc1.weight": "model-00005-of-00007.safetensors",
161
+ "biobrain_decoder.gpt_model.layers.26.fc2.weight": "model-00005-of-00007.safetensors",
162
+ "biobrain_decoder.gpt_model.layers.26.ffn_norm.scale": "model-00005-of-00007.safetensors",
163
+ "biobrain_decoder.gpt_model.layers.26.self_attn.key_linear.weight": "model-00005-of-00007.safetensors",
164
+ "biobrain_decoder.gpt_model.layers.26.self_attn.out_linear.weight": "model-00005-of-00007.safetensors",
165
+ "biobrain_decoder.gpt_model.layers.26.self_attn.query_linear.weight": "model-00005-of-00007.safetensors",
166
+ "biobrain_decoder.gpt_model.layers.26.self_attn.value_linear.weight": "model-00005-of-00007.safetensors",
167
+ "biobrain_decoder.gpt_model.layers.27.attn_norm.scale": "model-00005-of-00007.safetensors",
168
+ "biobrain_decoder.gpt_model.layers.27.fc1.weight": "model-00006-of-00007.safetensors",
169
+ "biobrain_decoder.gpt_model.layers.27.fc2.weight": "model-00006-of-00007.safetensors",
170
+ "biobrain_decoder.gpt_model.layers.27.ffn_norm.scale": "model-00005-of-00007.safetensors",
171
+ "biobrain_decoder.gpt_model.layers.27.self_attn.key_linear.weight": "model-00005-of-00007.safetensors",
172
+ "biobrain_decoder.gpt_model.layers.27.self_attn.out_linear.weight": "model-00005-of-00007.safetensors",
173
+ "biobrain_decoder.gpt_model.layers.27.self_attn.query_linear.weight": "model-00005-of-00007.safetensors",
174
+ "biobrain_decoder.gpt_model.layers.27.self_attn.value_linear.weight": "model-00005-of-00007.safetensors",
175
+ "biobrain_decoder.gpt_model.layers.28.attn_norm.scale": "model-00006-of-00007.safetensors",
176
+ "biobrain_decoder.gpt_model.layers.28.fc1.weight": "model-00006-of-00007.safetensors",
177
+ "biobrain_decoder.gpt_model.layers.28.fc2.weight": "model-00006-of-00007.safetensors",
178
+ "biobrain_decoder.gpt_model.layers.28.ffn_norm.scale": "model-00006-of-00007.safetensors",
179
+ "biobrain_decoder.gpt_model.layers.28.self_attn.key_linear.weight": "model-00006-of-00007.safetensors",
180
+ "biobrain_decoder.gpt_model.layers.28.self_attn.out_linear.weight": "model-00006-of-00007.safetensors",
181
+ "biobrain_decoder.gpt_model.layers.28.self_attn.query_linear.weight": "model-00006-of-00007.safetensors",
182
+ "biobrain_decoder.gpt_model.layers.28.self_attn.value_linear.weight": "model-00006-of-00007.safetensors",
183
+ "biobrain_decoder.gpt_model.layers.29.attn_norm.scale": "model-00006-of-00007.safetensors",
184
+ "biobrain_decoder.gpt_model.layers.29.fc1.weight": "model-00006-of-00007.safetensors",
185
+ "biobrain_decoder.gpt_model.layers.29.fc2.weight": "model-00006-of-00007.safetensors",
186
+ "biobrain_decoder.gpt_model.layers.29.ffn_norm.scale": "model-00006-of-00007.safetensors",
187
+ "biobrain_decoder.gpt_model.layers.29.self_attn.key_linear.weight": "model-00006-of-00007.safetensors",
188
+ "biobrain_decoder.gpt_model.layers.29.self_attn.out_linear.weight": "model-00006-of-00007.safetensors",
189
+ "biobrain_decoder.gpt_model.layers.29.self_attn.query_linear.weight": "model-00006-of-00007.safetensors",
190
+ "biobrain_decoder.gpt_model.layers.29.self_attn.value_linear.weight": "model-00006-of-00007.safetensors",
191
+ "biobrain_decoder.gpt_model.layers.3.attn_norm.scale": "model-00002-of-00007.safetensors",
192
+ "biobrain_decoder.gpt_model.layers.3.fc1.weight": "model-00002-of-00007.safetensors",
193
+ "biobrain_decoder.gpt_model.layers.3.fc2.weight": "model-00002-of-00007.safetensors",
194
+ "biobrain_decoder.gpt_model.layers.3.ffn_norm.scale": "model-00002-of-00007.safetensors",
195
+ "biobrain_decoder.gpt_model.layers.3.self_attn.key_linear.weight": "model-00002-of-00007.safetensors",
196
+ "biobrain_decoder.gpt_model.layers.3.self_attn.out_linear.weight": "model-00002-of-00007.safetensors",
197
+ "biobrain_decoder.gpt_model.layers.3.self_attn.query_linear.weight": "model-00002-of-00007.safetensors",
198
+ "biobrain_decoder.gpt_model.layers.3.self_attn.value_linear.weight": "model-00002-of-00007.safetensors",
199
+ "biobrain_decoder.gpt_model.layers.30.attn_norm.scale": "model-00006-of-00007.safetensors",
200
+ "biobrain_decoder.gpt_model.layers.30.fc1.weight": "model-00006-of-00007.safetensors",
201
+ "biobrain_decoder.gpt_model.layers.30.fc2.weight": "model-00006-of-00007.safetensors",
202
+ "biobrain_decoder.gpt_model.layers.30.ffn_norm.scale": "model-00006-of-00007.safetensors",
203
+ "biobrain_decoder.gpt_model.layers.30.self_attn.key_linear.weight": "model-00006-of-00007.safetensors",
204
+ "biobrain_decoder.gpt_model.layers.30.self_attn.out_linear.weight": "model-00006-of-00007.safetensors",
205
+ "biobrain_decoder.gpt_model.layers.30.self_attn.query_linear.weight": "model-00006-of-00007.safetensors",
206
+ "biobrain_decoder.gpt_model.layers.30.self_attn.value_linear.weight": "model-00006-of-00007.safetensors",
207
+ "biobrain_decoder.gpt_model.layers.31.attn_norm.scale": "model-00006-of-00007.safetensors",
208
+ "biobrain_decoder.gpt_model.layers.31.fc1.weight": "model-00006-of-00007.safetensors",
209
+ "biobrain_decoder.gpt_model.layers.31.fc2.weight": "model-00006-of-00007.safetensors",
210
+ "biobrain_decoder.gpt_model.layers.31.ffn_norm.scale": "model-00006-of-00007.safetensors",
211
+ "biobrain_decoder.gpt_model.layers.31.self_attn.key_linear.weight": "model-00006-of-00007.safetensors",
212
+ "biobrain_decoder.gpt_model.layers.31.self_attn.out_linear.weight": "model-00006-of-00007.safetensors",
213
+ "biobrain_decoder.gpt_model.layers.31.self_attn.query_linear.weight": "model-00006-of-00007.safetensors",
214
+ "biobrain_decoder.gpt_model.layers.31.self_attn.value_linear.weight": "model-00006-of-00007.safetensors",
215
+ "biobrain_decoder.gpt_model.layers.4.attn_norm.scale": "model-00002-of-00007.safetensors",
216
+ "biobrain_decoder.gpt_model.layers.4.fc1.weight": "model-00002-of-00007.safetensors",
217
+ "biobrain_decoder.gpt_model.layers.4.fc2.weight": "model-00002-of-00007.safetensors",
218
+ "biobrain_decoder.gpt_model.layers.4.ffn_norm.scale": "model-00002-of-00007.safetensors",
219
+ "biobrain_decoder.gpt_model.layers.4.self_attn.key_linear.weight": "model-00002-of-00007.safetensors",
220
+ "biobrain_decoder.gpt_model.layers.4.self_attn.out_linear.weight": "model-00002-of-00007.safetensors",
221
+ "biobrain_decoder.gpt_model.layers.4.self_attn.query_linear.weight": "model-00002-of-00007.safetensors",
222
+ "biobrain_decoder.gpt_model.layers.4.self_attn.value_linear.weight": "model-00002-of-00007.safetensors",
223
+ "biobrain_decoder.gpt_model.layers.5.attn_norm.scale": "model-00002-of-00007.safetensors",
224
+ "biobrain_decoder.gpt_model.layers.5.fc1.weight": "model-00002-of-00007.safetensors",
225
+ "biobrain_decoder.gpt_model.layers.5.fc2.weight": "model-00002-of-00007.safetensors",
226
+ "biobrain_decoder.gpt_model.layers.5.ffn_norm.scale": "model-00002-of-00007.safetensors",
227
+ "biobrain_decoder.gpt_model.layers.5.self_attn.key_linear.weight": "model-00002-of-00007.safetensors",
228
+ "biobrain_decoder.gpt_model.layers.5.self_attn.out_linear.weight": "model-00002-of-00007.safetensors",
229
+ "biobrain_decoder.gpt_model.layers.5.self_attn.query_linear.weight": "model-00002-of-00007.safetensors",
230
+ "biobrain_decoder.gpt_model.layers.5.self_attn.value_linear.weight": "model-00002-of-00007.safetensors",
231
+ "biobrain_decoder.gpt_model.layers.6.attn_norm.scale": "model-00002-of-00007.safetensors",
232
+ "biobrain_decoder.gpt_model.layers.6.fc1.weight": "model-00002-of-00007.safetensors",
233
+ "biobrain_decoder.gpt_model.layers.6.fc2.weight": "model-00002-of-00007.safetensors",
234
+ "biobrain_decoder.gpt_model.layers.6.ffn_norm.scale": "model-00002-of-00007.safetensors",
235
+ "biobrain_decoder.gpt_model.layers.6.self_attn.key_linear.weight": "model-00002-of-00007.safetensors",
236
+ "biobrain_decoder.gpt_model.layers.6.self_attn.out_linear.weight": "model-00002-of-00007.safetensors",
237
+ "biobrain_decoder.gpt_model.layers.6.self_attn.query_linear.weight": "model-00002-of-00007.safetensors",
238
+ "biobrain_decoder.gpt_model.layers.6.self_attn.value_linear.weight": "model-00002-of-00007.safetensors",
239
+ "biobrain_decoder.gpt_model.layers.7.attn_norm.scale": "model-00002-of-00007.safetensors",
240
+ "biobrain_decoder.gpt_model.layers.7.fc1.weight": "model-00002-of-00007.safetensors",
241
+ "biobrain_decoder.gpt_model.layers.7.fc2.weight": "model-00002-of-00007.safetensors",
242
+ "biobrain_decoder.gpt_model.layers.7.ffn_norm.scale": "model-00002-of-00007.safetensors",
243
+ "biobrain_decoder.gpt_model.layers.7.self_attn.key_linear.weight": "model-00002-of-00007.safetensors",
244
+ "biobrain_decoder.gpt_model.layers.7.self_attn.out_linear.weight": "model-00002-of-00007.safetensors",
245
+ "biobrain_decoder.gpt_model.layers.7.self_attn.query_linear.weight": "model-00002-of-00007.safetensors",
246
+ "biobrain_decoder.gpt_model.layers.7.self_attn.value_linear.weight": "model-00002-of-00007.safetensors",
247
+ "biobrain_decoder.gpt_model.layers.8.attn_norm.scale": "model-00002-of-00007.safetensors",
248
+ "biobrain_decoder.gpt_model.layers.8.fc1.weight": "model-00002-of-00007.safetensors",
249
+ "biobrain_decoder.gpt_model.layers.8.fc2.weight": "model-00002-of-00007.safetensors",
250
+ "biobrain_decoder.gpt_model.layers.8.ffn_norm.scale": "model-00002-of-00007.safetensors",
251
+ "biobrain_decoder.gpt_model.layers.8.self_attn.key_linear.weight": "model-00002-of-00007.safetensors",
252
+ "biobrain_decoder.gpt_model.layers.8.self_attn.out_linear.weight": "model-00002-of-00007.safetensors",
253
+ "biobrain_decoder.gpt_model.layers.8.self_attn.query_linear.weight": "model-00002-of-00007.safetensors",
254
+ "biobrain_decoder.gpt_model.layers.8.self_attn.value_linear.weight": "model-00002-of-00007.safetensors",
255
+ "biobrain_decoder.gpt_model.layers.9.attn_norm.scale": "model-00003-of-00007.safetensors",
256
+ "biobrain_decoder.gpt_model.layers.9.fc1.weight": "model-00003-of-00007.safetensors",
257
+ "biobrain_decoder.gpt_model.layers.9.fc2.weight": "model-00003-of-00007.safetensors",
258
+ "biobrain_decoder.gpt_model.layers.9.ffn_norm.scale": "model-00003-of-00007.safetensors",
259
+ "biobrain_decoder.gpt_model.layers.9.self_attn.key_linear.weight": "model-00002-of-00007.safetensors",
260
+ "biobrain_decoder.gpt_model.layers.9.self_attn.out_linear.weight": "model-00003-of-00007.safetensors",
261
+ "biobrain_decoder.gpt_model.layers.9.self_attn.query_linear.weight": "model-00002-of-00007.safetensors",
262
+ "biobrain_decoder.gpt_model.layers.9.self_attn.value_linear.weight": "model-00003-of-00007.safetensors",
263
+ "biobrain_decoder.gpt_model.lm_head.fc.weight": "model-00006-of-00007.safetensors",
264
+ "biobrain_decoder.gpt_model.token_embed.weight": "model-00001-of-00007.safetensors",
265
+ "biobrain_encoder.esm_model.attention_blocks.0.fc1.weight": "model-00001-of-00007.safetensors",
266
+ "biobrain_encoder.esm_model.attention_blocks.0.fc2.weight": "model-00001-of-00007.safetensors",
267
+ "biobrain_encoder.esm_model.attention_blocks.0.layer_norm_mlp.bias": "model-00001-of-00007.safetensors",
268
+ "biobrain_encoder.esm_model.attention_blocks.0.layer_norm_mlp.weight": "model-00001-of-00007.safetensors",
269
+ "biobrain_encoder.esm_model.attention_blocks.0.layer_norm_self_attention.bias": "model-00001-of-00007.safetensors",
270
+ "biobrain_encoder.esm_model.attention_blocks.0.layer_norm_self_attention.weight": "model-00001-of-00007.safetensors",
271
+ "biobrain_encoder.esm_model.attention_blocks.0.mha.output.bias": "model-00001-of-00007.safetensors",
272
+ "biobrain_encoder.esm_model.attention_blocks.0.mha.output.weight": "model-00001-of-00007.safetensors",
273
+ "biobrain_encoder.esm_model.attention_blocks.0.mha.w_k.bias": "model-00001-of-00007.safetensors",
274
+ "biobrain_encoder.esm_model.attention_blocks.0.mha.w_k.weight": "model-00001-of-00007.safetensors",
275
+ "biobrain_encoder.esm_model.attention_blocks.0.mha.w_q.bias": "model-00001-of-00007.safetensors",
276
+ "biobrain_encoder.esm_model.attention_blocks.0.mha.w_q.weight": "model-00001-of-00007.safetensors",
277
+ "biobrain_encoder.esm_model.attention_blocks.0.mha.w_v.bias": "model-00001-of-00007.safetensors",
278
+ "biobrain_encoder.esm_model.attention_blocks.0.mha.w_v.weight": "model-00001-of-00007.safetensors",
279
+ "biobrain_encoder.esm_model.attention_blocks.1.fc1.weight": "model-00001-of-00007.safetensors",
280
+ "biobrain_encoder.esm_model.attention_blocks.1.fc2.weight": "model-00001-of-00007.safetensors",
281
+ "biobrain_encoder.esm_model.attention_blocks.1.layer_norm_mlp.bias": "model-00001-of-00007.safetensors",
282
+ "biobrain_encoder.esm_model.attention_blocks.1.layer_norm_mlp.weight": "model-00001-of-00007.safetensors",
283
+ "biobrain_encoder.esm_model.attention_blocks.1.layer_norm_self_attention.bias": "model-00001-of-00007.safetensors",
284
+ "biobrain_encoder.esm_model.attention_blocks.1.layer_norm_self_attention.weight": "model-00001-of-00007.safetensors",
285
+ "biobrain_encoder.esm_model.attention_blocks.1.mha.output.bias": "model-00001-of-00007.safetensors",
286
+ "biobrain_encoder.esm_model.attention_blocks.1.mha.output.weight": "model-00001-of-00007.safetensors",
287
+ "biobrain_encoder.esm_model.attention_blocks.1.mha.w_k.bias": "model-00001-of-00007.safetensors",
288
+ "biobrain_encoder.esm_model.attention_blocks.1.mha.w_k.weight": "model-00001-of-00007.safetensors",
289
+ "biobrain_encoder.esm_model.attention_blocks.1.mha.w_q.bias": "model-00001-of-00007.safetensors",
290
+ "biobrain_encoder.esm_model.attention_blocks.1.mha.w_q.weight": "model-00001-of-00007.safetensors",
291
+ "biobrain_encoder.esm_model.attention_blocks.1.mha.w_v.bias": "model-00001-of-00007.safetensors",
292
+ "biobrain_encoder.esm_model.attention_blocks.1.mha.w_v.weight": "model-00001-of-00007.safetensors",
293
+ "biobrain_encoder.esm_model.attention_blocks.10.fc1.weight": "model-00001-of-00007.safetensors",
294
+ "biobrain_encoder.esm_model.attention_blocks.10.fc2.weight": "model-00001-of-00007.safetensors",
295
+ "biobrain_encoder.esm_model.attention_blocks.10.layer_norm_mlp.bias": "model-00001-of-00007.safetensors",
296
+ "biobrain_encoder.esm_model.attention_blocks.10.layer_norm_mlp.weight": "model-00001-of-00007.safetensors",
297
+ "biobrain_encoder.esm_model.attention_blocks.10.layer_norm_self_attention.bias": "model-00001-of-00007.safetensors",
298
+ "biobrain_encoder.esm_model.attention_blocks.10.layer_norm_self_attention.weight": "model-00001-of-00007.safetensors",
299
+ "biobrain_encoder.esm_model.attention_blocks.10.mha.output.bias": "model-00001-of-00007.safetensors",
300
+ "biobrain_encoder.esm_model.attention_blocks.10.mha.output.weight": "model-00001-of-00007.safetensors",
301
+ "biobrain_encoder.esm_model.attention_blocks.10.mha.w_k.bias": "model-00001-of-00007.safetensors",
302
+ "biobrain_encoder.esm_model.attention_blocks.10.mha.w_k.weight": "model-00001-of-00007.safetensors",
303
+ "biobrain_encoder.esm_model.attention_blocks.10.mha.w_q.bias": "model-00001-of-00007.safetensors",
304
+ "biobrain_encoder.esm_model.attention_blocks.10.mha.w_q.weight": "model-00001-of-00007.safetensors",
305
+ "biobrain_encoder.esm_model.attention_blocks.10.mha.w_v.bias": "model-00001-of-00007.safetensors",
306
+ "biobrain_encoder.esm_model.attention_blocks.10.mha.w_v.weight": "model-00001-of-00007.safetensors",
307
+ "biobrain_encoder.esm_model.attention_blocks.11.fc1.weight": "model-00001-of-00007.safetensors",
308
+ "biobrain_encoder.esm_model.attention_blocks.11.fc2.weight": "model-00001-of-00007.safetensors",
309
+ "biobrain_encoder.esm_model.attention_blocks.11.layer_norm_mlp.bias": "model-00001-of-00007.safetensors",
310
+ "biobrain_encoder.esm_model.attention_blocks.11.layer_norm_mlp.weight": "model-00001-of-00007.safetensors",
311
+ "biobrain_encoder.esm_model.attention_blocks.11.layer_norm_self_attention.bias": "model-00001-of-00007.safetensors",
312
+ "biobrain_encoder.esm_model.attention_blocks.11.layer_norm_self_attention.weight": "model-00001-of-00007.safetensors",
313
+ "biobrain_encoder.esm_model.attention_blocks.11.mha.output.bias": "model-00001-of-00007.safetensors",
314
+ "biobrain_encoder.esm_model.attention_blocks.11.mha.output.weight": "model-00001-of-00007.safetensors",
315
+ "biobrain_encoder.esm_model.attention_blocks.11.mha.w_k.bias": "model-00001-of-00007.safetensors",
316
+ "biobrain_encoder.esm_model.attention_blocks.11.mha.w_k.weight": "model-00001-of-00007.safetensors",
317
+ "biobrain_encoder.esm_model.attention_blocks.11.mha.w_q.bias": "model-00001-of-00007.safetensors",
318
+ "biobrain_encoder.esm_model.attention_blocks.11.mha.w_q.weight": "model-00001-of-00007.safetensors",
319
+ "biobrain_encoder.esm_model.attention_blocks.11.mha.w_v.bias": "model-00001-of-00007.safetensors",
320
+ "biobrain_encoder.esm_model.attention_blocks.11.mha.w_v.weight": "model-00001-of-00007.safetensors",
321
+ "biobrain_encoder.esm_model.attention_blocks.12.fc1.weight": "model-00001-of-00007.safetensors",
322
+ "biobrain_encoder.esm_model.attention_blocks.12.fc2.weight": "model-00001-of-00007.safetensors",
323
+ "biobrain_encoder.esm_model.attention_blocks.12.layer_norm_mlp.bias": "model-00001-of-00007.safetensors",
324
+ "biobrain_encoder.esm_model.attention_blocks.12.layer_norm_mlp.weight": "model-00001-of-00007.safetensors",
325
+ "biobrain_encoder.esm_model.attention_blocks.12.layer_norm_self_attention.bias": "model-00001-of-00007.safetensors",
326
+ "biobrain_encoder.esm_model.attention_blocks.12.layer_norm_self_attention.weight": "model-00001-of-00007.safetensors",
327
+ "biobrain_encoder.esm_model.attention_blocks.12.mha.output.bias": "model-00001-of-00007.safetensors",
328
+ "biobrain_encoder.esm_model.attention_blocks.12.mha.output.weight": "model-00001-of-00007.safetensors",
329
+ "biobrain_encoder.esm_model.attention_blocks.12.mha.w_k.bias": "model-00001-of-00007.safetensors",
330
+ "biobrain_encoder.esm_model.attention_blocks.12.mha.w_k.weight": "model-00001-of-00007.safetensors",
331
+ "biobrain_encoder.esm_model.attention_blocks.12.mha.w_q.bias": "model-00001-of-00007.safetensors",
332
+ "biobrain_encoder.esm_model.attention_blocks.12.mha.w_q.weight": "model-00001-of-00007.safetensors",
333
+ "biobrain_encoder.esm_model.attention_blocks.12.mha.w_v.bias": "model-00001-of-00007.safetensors",
334
+ "biobrain_encoder.esm_model.attention_blocks.12.mha.w_v.weight": "model-00001-of-00007.safetensors",
335
+ "biobrain_encoder.esm_model.attention_blocks.13.fc1.weight": "model-00001-of-00007.safetensors",
336
+ "biobrain_encoder.esm_model.attention_blocks.13.fc2.weight": "model-00001-of-00007.safetensors",
337
+ "biobrain_encoder.esm_model.attention_blocks.13.layer_norm_mlp.bias": "model-00001-of-00007.safetensors",
338
+ "biobrain_encoder.esm_model.attention_blocks.13.layer_norm_mlp.weight": "model-00001-of-00007.safetensors",
339
+ "biobrain_encoder.esm_model.attention_blocks.13.layer_norm_self_attention.bias": "model-00001-of-00007.safetensors",
340
+ "biobrain_encoder.esm_model.attention_blocks.13.layer_norm_self_attention.weight": "model-00001-of-00007.safetensors",
341
+ "biobrain_encoder.esm_model.attention_blocks.13.mha.output.bias": "model-00001-of-00007.safetensors",
342
+ "biobrain_encoder.esm_model.attention_blocks.13.mha.output.weight": "model-00001-of-00007.safetensors",
343
+ "biobrain_encoder.esm_model.attention_blocks.13.mha.w_k.bias": "model-00001-of-00007.safetensors",
344
+ "biobrain_encoder.esm_model.attention_blocks.13.mha.w_k.weight": "model-00001-of-00007.safetensors",
345
+ "biobrain_encoder.esm_model.attention_blocks.13.mha.w_q.bias": "model-00001-of-00007.safetensors",
346
+ "biobrain_encoder.esm_model.attention_blocks.13.mha.w_q.weight": "model-00001-of-00007.safetensors",
347
+ "biobrain_encoder.esm_model.attention_blocks.13.mha.w_v.bias": "model-00001-of-00007.safetensors",
348
+ "biobrain_encoder.esm_model.attention_blocks.13.mha.w_v.weight": "model-00001-of-00007.safetensors",
349
+ "biobrain_encoder.esm_model.attention_blocks.14.fc1.weight": "model-00001-of-00007.safetensors",
350
+ "biobrain_encoder.esm_model.attention_blocks.14.fc2.weight": "model-00001-of-00007.safetensors",
351
+ "biobrain_encoder.esm_model.attention_blocks.14.layer_norm_mlp.bias": "model-00001-of-00007.safetensors",
352
+ "biobrain_encoder.esm_model.attention_blocks.14.layer_norm_mlp.weight": "model-00001-of-00007.safetensors",
353
+ "biobrain_encoder.esm_model.attention_blocks.14.layer_norm_self_attention.bias": "model-00001-of-00007.safetensors",
354
+ "biobrain_encoder.esm_model.attention_blocks.14.layer_norm_self_attention.weight": "model-00001-of-00007.safetensors",
355
+ "biobrain_encoder.esm_model.attention_blocks.14.mha.output.bias": "model-00001-of-00007.safetensors",
356
+ "biobrain_encoder.esm_model.attention_blocks.14.mha.output.weight": "model-00001-of-00007.safetensors",
357
+ "biobrain_encoder.esm_model.attention_blocks.14.mha.w_k.bias": "model-00001-of-00007.safetensors",
358
+ "biobrain_encoder.esm_model.attention_blocks.14.mha.w_k.weight": "model-00001-of-00007.safetensors",
359
+ "biobrain_encoder.esm_model.attention_blocks.14.mha.w_q.bias": "model-00001-of-00007.safetensors",
360
+ "biobrain_encoder.esm_model.attention_blocks.14.mha.w_q.weight": "model-00001-of-00007.safetensors",
361
+ "biobrain_encoder.esm_model.attention_blocks.14.mha.w_v.bias": "model-00001-of-00007.safetensors",
362
+ "biobrain_encoder.esm_model.attention_blocks.14.mha.w_v.weight": "model-00001-of-00007.safetensors",
363
+ "biobrain_encoder.esm_model.attention_blocks.15.fc1.weight": "model-00001-of-00007.safetensors",
364
+ "biobrain_encoder.esm_model.attention_blocks.15.fc2.weight": "model-00001-of-00007.safetensors",
365
+ "biobrain_encoder.esm_model.attention_blocks.15.layer_norm_mlp.bias": "model-00001-of-00007.safetensors",
366
+ "biobrain_encoder.esm_model.attention_blocks.15.layer_norm_mlp.weight": "model-00001-of-00007.safetensors",
367
+ "biobrain_encoder.esm_model.attention_blocks.15.layer_norm_self_attention.bias": "model-00001-of-00007.safetensors",
368
+ "biobrain_encoder.esm_model.attention_blocks.15.layer_norm_self_attention.weight": "model-00001-of-00007.safetensors",
369
+ "biobrain_encoder.esm_model.attention_blocks.15.mha.output.bias": "model-00001-of-00007.safetensors",
370
+ "biobrain_encoder.esm_model.attention_blocks.15.mha.output.weight": "model-00001-of-00007.safetensors",
371
+ "biobrain_encoder.esm_model.attention_blocks.15.mha.w_k.bias": "model-00001-of-00007.safetensors",
372
+ "biobrain_encoder.esm_model.attention_blocks.15.mha.w_k.weight": "model-00001-of-00007.safetensors",
373
+ "biobrain_encoder.esm_model.attention_blocks.15.mha.w_q.bias": "model-00001-of-00007.safetensors",
374
+ "biobrain_encoder.esm_model.attention_blocks.15.mha.w_q.weight": "model-00001-of-00007.safetensors",
375
+ "biobrain_encoder.esm_model.attention_blocks.15.mha.w_v.bias": "model-00001-of-00007.safetensors",
376
+ "biobrain_encoder.esm_model.attention_blocks.15.mha.w_v.weight": "model-00001-of-00007.safetensors",
377
+ "biobrain_encoder.esm_model.attention_blocks.16.fc1.weight": "model-00001-of-00007.safetensors",
378
+ "biobrain_encoder.esm_model.attention_blocks.16.fc2.weight": "model-00001-of-00007.safetensors",
379
+ "biobrain_encoder.esm_model.attention_blocks.16.layer_norm_mlp.bias": "model-00001-of-00007.safetensors",
380
+ "biobrain_encoder.esm_model.attention_blocks.16.layer_norm_mlp.weight": "model-00001-of-00007.safetensors",
381
+ "biobrain_encoder.esm_model.attention_blocks.16.layer_norm_self_attention.bias": "model-00001-of-00007.safetensors",
382
+ "biobrain_encoder.esm_model.attention_blocks.16.layer_norm_self_attention.weight": "model-00001-of-00007.safetensors",
383
+ "biobrain_encoder.esm_model.attention_blocks.16.mha.output.bias": "model-00001-of-00007.safetensors",
384
+ "biobrain_encoder.esm_model.attention_blocks.16.mha.output.weight": "model-00001-of-00007.safetensors",
385
+ "biobrain_encoder.esm_model.attention_blocks.16.mha.w_k.bias": "model-00001-of-00007.safetensors",
386
+ "biobrain_encoder.esm_model.attention_blocks.16.mha.w_k.weight": "model-00001-of-00007.safetensors",
387
+ "biobrain_encoder.esm_model.attention_blocks.16.mha.w_q.bias": "model-00001-of-00007.safetensors",
388
+ "biobrain_encoder.esm_model.attention_blocks.16.mha.w_q.weight": "model-00001-of-00007.safetensors",
389
+ "biobrain_encoder.esm_model.attention_blocks.16.mha.w_v.bias": "model-00001-of-00007.safetensors",
390
+ "biobrain_encoder.esm_model.attention_blocks.16.mha.w_v.weight": "model-00001-of-00007.safetensors",
391
+ "biobrain_encoder.esm_model.attention_blocks.17.fc1.weight": "model-00001-of-00007.safetensors",
392
+ "biobrain_encoder.esm_model.attention_blocks.17.fc2.weight": "model-00001-of-00007.safetensors",
393
+ "biobrain_encoder.esm_model.attention_blocks.17.layer_norm_mlp.bias": "model-00001-of-00007.safetensors",
394
+ "biobrain_encoder.esm_model.attention_blocks.17.layer_norm_mlp.weight": "model-00001-of-00007.safetensors",
395
+ "biobrain_encoder.esm_model.attention_blocks.17.layer_norm_self_attention.bias": "model-00001-of-00007.safetensors",
396
+ "biobrain_encoder.esm_model.attention_blocks.17.layer_norm_self_attention.weight": "model-00001-of-00007.safetensors",
397
+ "biobrain_encoder.esm_model.attention_blocks.17.mha.output.bias": "model-00001-of-00007.safetensors",
398
+ "biobrain_encoder.esm_model.attention_blocks.17.mha.output.weight": "model-00001-of-00007.safetensors",
399
+ "biobrain_encoder.esm_model.attention_blocks.17.mha.w_k.bias": "model-00001-of-00007.safetensors",
400
+ "biobrain_encoder.esm_model.attention_blocks.17.mha.w_k.weight": "model-00001-of-00007.safetensors",
401
+ "biobrain_encoder.esm_model.attention_blocks.17.mha.w_q.bias": "model-00001-of-00007.safetensors",
402
+ "biobrain_encoder.esm_model.attention_blocks.17.mha.w_q.weight": "model-00001-of-00007.safetensors",
403
+ "biobrain_encoder.esm_model.attention_blocks.17.mha.w_v.bias": "model-00001-of-00007.safetensors",
404
+ "biobrain_encoder.esm_model.attention_blocks.17.mha.w_v.weight": "model-00001-of-00007.safetensors",
405
+ "biobrain_encoder.esm_model.attention_blocks.18.fc1.weight": "model-00001-of-00007.safetensors",
406
+ "biobrain_encoder.esm_model.attention_blocks.18.fc2.weight": "model-00001-of-00007.safetensors",
407
+ "biobrain_encoder.esm_model.attention_blocks.18.layer_norm_mlp.bias": "model-00001-of-00007.safetensors",
408
+ "biobrain_encoder.esm_model.attention_blocks.18.layer_norm_mlp.weight": "model-00001-of-00007.safetensors",
409
+ "biobrain_encoder.esm_model.attention_blocks.18.layer_norm_self_attention.bias": "model-00001-of-00007.safetensors",
410
+ "biobrain_encoder.esm_model.attention_blocks.18.layer_norm_self_attention.weight": "model-00001-of-00007.safetensors",
411
+ "biobrain_encoder.esm_model.attention_blocks.18.mha.output.bias": "model-00001-of-00007.safetensors",
412
+ "biobrain_encoder.esm_model.attention_blocks.18.mha.output.weight": "model-00001-of-00007.safetensors",
413
+ "biobrain_encoder.esm_model.attention_blocks.18.mha.w_k.bias": "model-00001-of-00007.safetensors",
414
+ "biobrain_encoder.esm_model.attention_blocks.18.mha.w_k.weight": "model-00001-of-00007.safetensors",
415
+ "biobrain_encoder.esm_model.attention_blocks.18.mha.w_q.bias": "model-00001-of-00007.safetensors",
416
+ "biobrain_encoder.esm_model.attention_blocks.18.mha.w_q.weight": "model-00001-of-00007.safetensors",
417
+ "biobrain_encoder.esm_model.attention_blocks.18.mha.w_v.bias": "model-00001-of-00007.safetensors",
418
+ "biobrain_encoder.esm_model.attention_blocks.18.mha.w_v.weight": "model-00001-of-00007.safetensors",
419
+ "biobrain_encoder.esm_model.attention_blocks.19.fc1.weight": "model-00001-of-00007.safetensors",
420
+ "biobrain_encoder.esm_model.attention_blocks.19.fc2.weight": "model-00001-of-00007.safetensors",
421
+ "biobrain_encoder.esm_model.attention_blocks.19.layer_norm_mlp.bias": "model-00001-of-00007.safetensors",
422
+ "biobrain_encoder.esm_model.attention_blocks.19.layer_norm_mlp.weight": "model-00001-of-00007.safetensors",
423
+ "biobrain_encoder.esm_model.attention_blocks.19.layer_norm_self_attention.bias": "model-00001-of-00007.safetensors",
424
+ "biobrain_encoder.esm_model.attention_blocks.19.layer_norm_self_attention.weight": "model-00001-of-00007.safetensors",
425
+ "biobrain_encoder.esm_model.attention_blocks.19.mha.output.bias": "model-00001-of-00007.safetensors",
426
+ "biobrain_encoder.esm_model.attention_blocks.19.mha.output.weight": "model-00001-of-00007.safetensors",
427
+ "biobrain_encoder.esm_model.attention_blocks.19.mha.w_k.bias": "model-00001-of-00007.safetensors",
428
+ "biobrain_encoder.esm_model.attention_blocks.19.mha.w_k.weight": "model-00001-of-00007.safetensors",
429
+ "biobrain_encoder.esm_model.attention_blocks.19.mha.w_q.bias": "model-00001-of-00007.safetensors",
430
+ "biobrain_encoder.esm_model.attention_blocks.19.mha.w_q.weight": "model-00001-of-00007.safetensors",
431
+ "biobrain_encoder.esm_model.attention_blocks.19.mha.w_v.bias": "model-00001-of-00007.safetensors",
432
+ "biobrain_encoder.esm_model.attention_blocks.19.mha.w_v.weight": "model-00001-of-00007.safetensors",
433
+ "biobrain_encoder.esm_model.attention_blocks.2.fc1.weight": "model-00001-of-00007.safetensors",
434
+ "biobrain_encoder.esm_model.attention_blocks.2.fc2.weight": "model-00001-of-00007.safetensors",
435
+ "biobrain_encoder.esm_model.attention_blocks.2.layer_norm_mlp.bias": "model-00001-of-00007.safetensors",
436
+ "biobrain_encoder.esm_model.attention_blocks.2.layer_norm_mlp.weight": "model-00001-of-00007.safetensors",
437
+ "biobrain_encoder.esm_model.attention_blocks.2.layer_norm_self_attention.bias": "model-00001-of-00007.safetensors",
438
+ "biobrain_encoder.esm_model.attention_blocks.2.layer_norm_self_attention.weight": "model-00001-of-00007.safetensors",
439
+ "biobrain_encoder.esm_model.attention_blocks.2.mha.output.bias": "model-00001-of-00007.safetensors",
440
+ "biobrain_encoder.esm_model.attention_blocks.2.mha.output.weight": "model-00001-of-00007.safetensors",
441
+ "biobrain_encoder.esm_model.attention_blocks.2.mha.w_k.bias": "model-00001-of-00007.safetensors",
442
+ "biobrain_encoder.esm_model.attention_blocks.2.mha.w_k.weight": "model-00001-of-00007.safetensors",
443
+ "biobrain_encoder.esm_model.attention_blocks.2.mha.w_q.bias": "model-00001-of-00007.safetensors",
444
+ "biobrain_encoder.esm_model.attention_blocks.2.mha.w_q.weight": "model-00001-of-00007.safetensors",
445
+ "biobrain_encoder.esm_model.attention_blocks.2.mha.w_v.bias": "model-00001-of-00007.safetensors",
446
+ "biobrain_encoder.esm_model.attention_blocks.2.mha.w_v.weight": "model-00001-of-00007.safetensors",
447
+ "biobrain_encoder.esm_model.attention_blocks.20.fc1.weight": "model-00001-of-00007.safetensors",
448
+ "biobrain_encoder.esm_model.attention_blocks.20.fc2.weight": "model-00001-of-00007.safetensors",
449
+ "biobrain_encoder.esm_model.attention_blocks.20.layer_norm_mlp.bias": "model-00001-of-00007.safetensors",
450
+ "biobrain_encoder.esm_model.attention_blocks.20.layer_norm_mlp.weight": "model-00001-of-00007.safetensors",
451
+ "biobrain_encoder.esm_model.attention_blocks.20.layer_norm_self_attention.bias": "model-00001-of-00007.safetensors",
452
+ "biobrain_encoder.esm_model.attention_blocks.20.layer_norm_self_attention.weight": "model-00001-of-00007.safetensors",
453
+ "biobrain_encoder.esm_model.attention_blocks.20.mha.output.bias": "model-00001-of-00007.safetensors",
454
+ "biobrain_encoder.esm_model.attention_blocks.20.mha.output.weight": "model-00001-of-00007.safetensors",
455
+ "biobrain_encoder.esm_model.attention_blocks.20.mha.w_k.bias": "model-00001-of-00007.safetensors",
456
+ "biobrain_encoder.esm_model.attention_blocks.20.mha.w_k.weight": "model-00001-of-00007.safetensors",
457
+ "biobrain_encoder.esm_model.attention_blocks.20.mha.w_q.bias": "model-00001-of-00007.safetensors",
458
+ "biobrain_encoder.esm_model.attention_blocks.20.mha.w_q.weight": "model-00001-of-00007.safetensors",
459
+ "biobrain_encoder.esm_model.attention_blocks.20.mha.w_v.bias": "model-00001-of-00007.safetensors",
460
+ "biobrain_encoder.esm_model.attention_blocks.20.mha.w_v.weight": "model-00001-of-00007.safetensors",
461
+ "biobrain_encoder.esm_model.attention_blocks.21.fc1.weight": "model-00001-of-00007.safetensors",
462
+ "biobrain_encoder.esm_model.attention_blocks.21.fc2.weight": "model-00001-of-00007.safetensors",
463
+ "biobrain_encoder.esm_model.attention_blocks.21.layer_norm_mlp.bias": "model-00001-of-00007.safetensors",
464
+ "biobrain_encoder.esm_model.attention_blocks.21.layer_norm_mlp.weight": "model-00001-of-00007.safetensors",
465
+ "biobrain_encoder.esm_model.attention_blocks.21.layer_norm_self_attention.bias": "model-00001-of-00007.safetensors",
466
+ "biobrain_encoder.esm_model.attention_blocks.21.layer_norm_self_attention.weight": "model-00001-of-00007.safetensors",
467
+ "biobrain_encoder.esm_model.attention_blocks.21.mha.output.bias": "model-00001-of-00007.safetensors",
468
+ "biobrain_encoder.esm_model.attention_blocks.21.mha.output.weight": "model-00001-of-00007.safetensors",
469
+ "biobrain_encoder.esm_model.attention_blocks.21.mha.w_k.bias": "model-00001-of-00007.safetensors",
470
+ "biobrain_encoder.esm_model.attention_blocks.21.mha.w_k.weight": "model-00001-of-00007.safetensors",
471
+ "biobrain_encoder.esm_model.attention_blocks.21.mha.w_q.bias": "model-00001-of-00007.safetensors",
472
+ "biobrain_encoder.esm_model.attention_blocks.21.mha.w_q.weight": "model-00001-of-00007.safetensors",
473
+ "biobrain_encoder.esm_model.attention_blocks.21.mha.w_v.bias": "model-00001-of-00007.safetensors",
474
+ "biobrain_encoder.esm_model.attention_blocks.21.mha.w_v.weight": "model-00001-of-00007.safetensors",
475
+ "biobrain_encoder.esm_model.attention_blocks.22.fc1.weight": "model-00001-of-00007.safetensors",
476
+ "biobrain_encoder.esm_model.attention_blocks.22.fc2.weight": "model-00001-of-00007.safetensors",
477
+ "biobrain_encoder.esm_model.attention_blocks.22.layer_norm_mlp.bias": "model-00001-of-00007.safetensors",
478
+ "biobrain_encoder.esm_model.attention_blocks.22.layer_norm_mlp.weight": "model-00001-of-00007.safetensors",
479
+ "biobrain_encoder.esm_model.attention_blocks.22.layer_norm_self_attention.bias": "model-00001-of-00007.safetensors",
480
+ "biobrain_encoder.esm_model.attention_blocks.22.layer_norm_self_attention.weight": "model-00001-of-00007.safetensors",
481
+ "biobrain_encoder.esm_model.attention_blocks.22.mha.output.bias": "model-00001-of-00007.safetensors",
482
+ "biobrain_encoder.esm_model.attention_blocks.22.mha.output.weight": "model-00001-of-00007.safetensors",
483
+ "biobrain_encoder.esm_model.attention_blocks.22.mha.w_k.bias": "model-00001-of-00007.safetensors",
484
+ "biobrain_encoder.esm_model.attention_blocks.22.mha.w_k.weight": "model-00001-of-00007.safetensors",
485
+ "biobrain_encoder.esm_model.attention_blocks.22.mha.w_q.bias": "model-00001-of-00007.safetensors",
486
+ "biobrain_encoder.esm_model.attention_blocks.22.mha.w_q.weight": "model-00001-of-00007.safetensors",
487
+ "biobrain_encoder.esm_model.attention_blocks.22.mha.w_v.bias": "model-00001-of-00007.safetensors",
488
+ "biobrain_encoder.esm_model.attention_blocks.22.mha.w_v.weight": "model-00001-of-00007.safetensors",
489
+ "biobrain_encoder.esm_model.attention_blocks.23.fc1.weight": "model-00001-of-00007.safetensors",
490
+ "biobrain_encoder.esm_model.attention_blocks.23.fc2.weight": "model-00001-of-00007.safetensors",
491
+ "biobrain_encoder.esm_model.attention_blocks.23.layer_norm_mlp.bias": "model-00001-of-00007.safetensors",
492
+ "biobrain_encoder.esm_model.attention_blocks.23.layer_norm_mlp.weight": "model-00001-of-00007.safetensors",
493
+ "biobrain_encoder.esm_model.attention_blocks.23.layer_norm_self_attention.bias": "model-00001-of-00007.safetensors",
494
+ "biobrain_encoder.esm_model.attention_blocks.23.layer_norm_self_attention.weight": "model-00001-of-00007.safetensors",
495
+ "biobrain_encoder.esm_model.attention_blocks.23.mha.output.bias": "model-00001-of-00007.safetensors",
496
+ "biobrain_encoder.esm_model.attention_blocks.23.mha.output.weight": "model-00001-of-00007.safetensors",
497
+ "biobrain_encoder.esm_model.attention_blocks.23.mha.w_k.bias": "model-00001-of-00007.safetensors",
498
+ "biobrain_encoder.esm_model.attention_blocks.23.mha.w_k.weight": "model-00001-of-00007.safetensors",
499
+ "biobrain_encoder.esm_model.attention_blocks.23.mha.w_q.bias": "model-00001-of-00007.safetensors",
500
+ "biobrain_encoder.esm_model.attention_blocks.23.mha.w_q.weight": "model-00001-of-00007.safetensors",
501
+ "biobrain_encoder.esm_model.attention_blocks.23.mha.w_v.bias": "model-00001-of-00007.safetensors",
502
+ "biobrain_encoder.esm_model.attention_blocks.23.mha.w_v.weight": "model-00001-of-00007.safetensors",
503
+ "biobrain_encoder.esm_model.attention_blocks.24.fc1.weight": "model-00001-of-00007.safetensors",
504
+ "biobrain_encoder.esm_model.attention_blocks.24.fc2.weight": "model-00001-of-00007.safetensors",
505
+ "biobrain_encoder.esm_model.attention_blocks.24.layer_norm_mlp.bias": "model-00001-of-00007.safetensors",
506
+ "biobrain_encoder.esm_model.attention_blocks.24.layer_norm_mlp.weight": "model-00001-of-00007.safetensors",
507
+ "biobrain_encoder.esm_model.attention_blocks.24.layer_norm_self_attention.bias": "model-00001-of-00007.safetensors",
508
+ "biobrain_encoder.esm_model.attention_blocks.24.layer_norm_self_attention.weight": "model-00001-of-00007.safetensors",
509
+ "biobrain_encoder.esm_model.attention_blocks.24.mha.output.bias": "model-00001-of-00007.safetensors",
510
+ "biobrain_encoder.esm_model.attention_blocks.24.mha.output.weight": "model-00001-of-00007.safetensors",
511
+ "biobrain_encoder.esm_model.attention_blocks.24.mha.w_k.bias": "model-00001-of-00007.safetensors",
512
+ "biobrain_encoder.esm_model.attention_blocks.24.mha.w_k.weight": "model-00001-of-00007.safetensors",
513
+ "biobrain_encoder.esm_model.attention_blocks.24.mha.w_q.bias": "model-00001-of-00007.safetensors",
514
+ "biobrain_encoder.esm_model.attention_blocks.24.mha.w_q.weight": "model-00001-of-00007.safetensors",
515
+ "biobrain_encoder.esm_model.attention_blocks.24.mha.w_v.bias": "model-00001-of-00007.safetensors",
516
+ "biobrain_encoder.esm_model.attention_blocks.24.mha.w_v.weight": "model-00001-of-00007.safetensors",
517
+ "biobrain_encoder.esm_model.attention_blocks.25.fc1.weight": "model-00001-of-00007.safetensors",
518
+ "biobrain_encoder.esm_model.attention_blocks.25.fc2.weight": "model-00001-of-00007.safetensors",
519
+ "biobrain_encoder.esm_model.attention_blocks.25.layer_norm_mlp.bias": "model-00001-of-00007.safetensors",
520
+ "biobrain_encoder.esm_model.attention_blocks.25.layer_norm_mlp.weight": "model-00001-of-00007.safetensors",
521
+ "biobrain_encoder.esm_model.attention_blocks.25.layer_norm_self_attention.bias": "model-00001-of-00007.safetensors",
522
+ "biobrain_encoder.esm_model.attention_blocks.25.layer_norm_self_attention.weight": "model-00001-of-00007.safetensors",
523
+ "biobrain_encoder.esm_model.attention_blocks.25.mha.output.bias": "model-00001-of-00007.safetensors",
524
+ "biobrain_encoder.esm_model.attention_blocks.25.mha.output.weight": "model-00001-of-00007.safetensors",
525
+ "biobrain_encoder.esm_model.attention_blocks.25.mha.w_k.bias": "model-00001-of-00007.safetensors",
526
+ "biobrain_encoder.esm_model.attention_blocks.25.mha.w_k.weight": "model-00001-of-00007.safetensors",
527
+ "biobrain_encoder.esm_model.attention_blocks.25.mha.w_q.bias": "model-00001-of-00007.safetensors",
528
+ "biobrain_encoder.esm_model.attention_blocks.25.mha.w_q.weight": "model-00001-of-00007.safetensors",
529
+ "biobrain_encoder.esm_model.attention_blocks.25.mha.w_v.bias": "model-00001-of-00007.safetensors",
530
+ "biobrain_encoder.esm_model.attention_blocks.25.mha.w_v.weight": "model-00001-of-00007.safetensors",
531
+ "biobrain_encoder.esm_model.attention_blocks.26.fc1.weight": "model-00001-of-00007.safetensors",
532
+ "biobrain_encoder.esm_model.attention_blocks.26.fc2.weight": "model-00001-of-00007.safetensors",
533
+ "biobrain_encoder.esm_model.attention_blocks.26.layer_norm_mlp.bias": "model-00001-of-00007.safetensors",
534
+ "biobrain_encoder.esm_model.attention_blocks.26.layer_norm_mlp.weight": "model-00001-of-00007.safetensors",
535
+ "biobrain_encoder.esm_model.attention_blocks.26.layer_norm_self_attention.bias": "model-00001-of-00007.safetensors",
536
+ "biobrain_encoder.esm_model.attention_blocks.26.layer_norm_self_attention.weight": "model-00001-of-00007.safetensors",
537
+ "biobrain_encoder.esm_model.attention_blocks.26.mha.output.bias": "model-00001-of-00007.safetensors",
538
+ "biobrain_encoder.esm_model.attention_blocks.26.mha.output.weight": "model-00001-of-00007.safetensors",
539
+ "biobrain_encoder.esm_model.attention_blocks.26.mha.w_k.bias": "model-00001-of-00007.safetensors",
540
+ "biobrain_encoder.esm_model.attention_blocks.26.mha.w_k.weight": "model-00001-of-00007.safetensors",
541
+ "biobrain_encoder.esm_model.attention_blocks.26.mha.w_q.bias": "model-00001-of-00007.safetensors",
542
+ "biobrain_encoder.esm_model.attention_blocks.26.mha.w_q.weight": "model-00001-of-00007.safetensors",
543
+ "biobrain_encoder.esm_model.attention_blocks.26.mha.w_v.bias": "model-00001-of-00007.safetensors",
544
+ "biobrain_encoder.esm_model.attention_blocks.26.mha.w_v.weight": "model-00001-of-00007.safetensors",
545
+ "biobrain_encoder.esm_model.attention_blocks.27.fc1.weight": "model-00001-of-00007.safetensors",
546
+ "biobrain_encoder.esm_model.attention_blocks.27.fc2.weight": "model-00001-of-00007.safetensors",
547
+ "biobrain_encoder.esm_model.attention_blocks.27.layer_norm_mlp.bias": "model-00001-of-00007.safetensors",
548
+ "biobrain_encoder.esm_model.attention_blocks.27.layer_norm_mlp.weight": "model-00001-of-00007.safetensors",
549
+ "biobrain_encoder.esm_model.attention_blocks.27.layer_norm_self_attention.bias": "model-00001-of-00007.safetensors",
550
+ "biobrain_encoder.esm_model.attention_blocks.27.layer_norm_self_attention.weight": "model-00001-of-00007.safetensors",
551
+ "biobrain_encoder.esm_model.attention_blocks.27.mha.output.bias": "model-00001-of-00007.safetensors",
552
+ "biobrain_encoder.esm_model.attention_blocks.27.mha.output.weight": "model-00001-of-00007.safetensors",
553
+ "biobrain_encoder.esm_model.attention_blocks.27.mha.w_k.bias": "model-00001-of-00007.safetensors",
554
+ "biobrain_encoder.esm_model.attention_blocks.27.mha.w_k.weight": "model-00001-of-00007.safetensors",
555
+ "biobrain_encoder.esm_model.attention_blocks.27.mha.w_q.bias": "model-00001-of-00007.safetensors",
556
+ "biobrain_encoder.esm_model.attention_blocks.27.mha.w_q.weight": "model-00001-of-00007.safetensors",
557
+ "biobrain_encoder.esm_model.attention_blocks.27.mha.w_v.bias": "model-00001-of-00007.safetensors",
558
+ "biobrain_encoder.esm_model.attention_blocks.27.mha.w_v.weight": "model-00001-of-00007.safetensors",
559
+ "biobrain_encoder.esm_model.attention_blocks.28.fc1.weight": "model-00001-of-00007.safetensors",
560
+ "biobrain_encoder.esm_model.attention_blocks.28.fc2.weight": "model-00001-of-00007.safetensors",
561
+ "biobrain_encoder.esm_model.attention_blocks.28.layer_norm_mlp.bias": "model-00001-of-00007.safetensors",
562
+ "biobrain_encoder.esm_model.attention_blocks.28.layer_norm_mlp.weight": "model-00001-of-00007.safetensors",
563
+ "biobrain_encoder.esm_model.attention_blocks.28.layer_norm_self_attention.bias": "model-00001-of-00007.safetensors",
564
+ "biobrain_encoder.esm_model.attention_blocks.28.layer_norm_self_attention.weight": "model-00001-of-00007.safetensors",
565
+ "biobrain_encoder.esm_model.attention_blocks.28.mha.output.bias": "model-00001-of-00007.safetensors",
566
+ "biobrain_encoder.esm_model.attention_blocks.28.mha.output.weight": "model-00001-of-00007.safetensors",
567
+ "biobrain_encoder.esm_model.attention_blocks.28.mha.w_k.bias": "model-00001-of-00007.safetensors",
568
+ "biobrain_encoder.esm_model.attention_blocks.28.mha.w_k.weight": "model-00001-of-00007.safetensors",
569
+ "biobrain_encoder.esm_model.attention_blocks.28.mha.w_q.bias": "model-00001-of-00007.safetensors",
570
+ "biobrain_encoder.esm_model.attention_blocks.28.mha.w_q.weight": "model-00001-of-00007.safetensors",
571
+ "biobrain_encoder.esm_model.attention_blocks.28.mha.w_v.bias": "model-00001-of-00007.safetensors",
572
+ "biobrain_encoder.esm_model.attention_blocks.28.mha.w_v.weight": "model-00001-of-00007.safetensors",
573
+ "biobrain_encoder.esm_model.attention_blocks.3.fc1.weight": "model-00001-of-00007.safetensors",
574
+ "biobrain_encoder.esm_model.attention_blocks.3.fc2.weight": "model-00001-of-00007.safetensors",
575
+ "biobrain_encoder.esm_model.attention_blocks.3.layer_norm_mlp.bias": "model-00001-of-00007.safetensors",
576
+ "biobrain_encoder.esm_model.attention_blocks.3.layer_norm_mlp.weight": "model-00001-of-00007.safetensors",
577
+ "biobrain_encoder.esm_model.attention_blocks.3.layer_norm_self_attention.bias": "model-00001-of-00007.safetensors",
578
+ "biobrain_encoder.esm_model.attention_blocks.3.layer_norm_self_attention.weight": "model-00001-of-00007.safetensors",
579
+ "biobrain_encoder.esm_model.attention_blocks.3.mha.output.bias": "model-00001-of-00007.safetensors",
580
+ "biobrain_encoder.esm_model.attention_blocks.3.mha.output.weight": "model-00001-of-00007.safetensors",
581
+ "biobrain_encoder.esm_model.attention_blocks.3.mha.w_k.bias": "model-00001-of-00007.safetensors",
582
+ "biobrain_encoder.esm_model.attention_blocks.3.mha.w_k.weight": "model-00001-of-00007.safetensors",
583
+ "biobrain_encoder.esm_model.attention_blocks.3.mha.w_q.bias": "model-00001-of-00007.safetensors",
584
+ "biobrain_encoder.esm_model.attention_blocks.3.mha.w_q.weight": "model-00001-of-00007.safetensors",
585
+ "biobrain_encoder.esm_model.attention_blocks.3.mha.w_v.bias": "model-00001-of-00007.safetensors",
586
+ "biobrain_encoder.esm_model.attention_blocks.3.mha.w_v.weight": "model-00001-of-00007.safetensors",
587
+ "biobrain_encoder.esm_model.attention_blocks.4.fc1.weight": "model-00001-of-00007.safetensors",
588
+ "biobrain_encoder.esm_model.attention_blocks.4.fc2.weight": "model-00001-of-00007.safetensors",
589
+ "biobrain_encoder.esm_model.attention_blocks.4.layer_norm_mlp.bias": "model-00001-of-00007.safetensors",
590
+ "biobrain_encoder.esm_model.attention_blocks.4.layer_norm_mlp.weight": "model-00001-of-00007.safetensors",
591
+ "biobrain_encoder.esm_model.attention_blocks.4.layer_norm_self_attention.bias": "model-00001-of-00007.safetensors",
592
+ "biobrain_encoder.esm_model.attention_blocks.4.layer_norm_self_attention.weight": "model-00001-of-00007.safetensors",
593
+ "biobrain_encoder.esm_model.attention_blocks.4.mha.output.bias": "model-00001-of-00007.safetensors",
594
+ "biobrain_encoder.esm_model.attention_blocks.4.mha.output.weight": "model-00001-of-00007.safetensors",
595
+ "biobrain_encoder.esm_model.attention_blocks.4.mha.w_k.bias": "model-00001-of-00007.safetensors",
596
+ "biobrain_encoder.esm_model.attention_blocks.4.mha.w_k.weight": "model-00001-of-00007.safetensors",
597
+ "biobrain_encoder.esm_model.attention_blocks.4.mha.w_q.bias": "model-00001-of-00007.safetensors",
598
+ "biobrain_encoder.esm_model.attention_blocks.4.mha.w_q.weight": "model-00001-of-00007.safetensors",
599
+ "biobrain_encoder.esm_model.attention_blocks.4.mha.w_v.bias": "model-00001-of-00007.safetensors",
600
+ "biobrain_encoder.esm_model.attention_blocks.4.mha.w_v.weight": "model-00001-of-00007.safetensors",
601
+ "biobrain_encoder.esm_model.attention_blocks.5.fc1.weight": "model-00001-of-00007.safetensors",
602
+ "biobrain_encoder.esm_model.attention_blocks.5.fc2.weight": "model-00001-of-00007.safetensors",
603
+ "biobrain_encoder.esm_model.attention_blocks.5.layer_norm_mlp.bias": "model-00001-of-00007.safetensors",
604
+ "biobrain_encoder.esm_model.attention_blocks.5.layer_norm_mlp.weight": "model-00001-of-00007.safetensors",
605
+ "biobrain_encoder.esm_model.attention_blocks.5.layer_norm_self_attention.bias": "model-00001-of-00007.safetensors",
606
+ "biobrain_encoder.esm_model.attention_blocks.5.layer_norm_self_attention.weight": "model-00001-of-00007.safetensors",
607
+ "biobrain_encoder.esm_model.attention_blocks.5.mha.output.bias": "model-00001-of-00007.safetensors",
608
+ "biobrain_encoder.esm_model.attention_blocks.5.mha.output.weight": "model-00001-of-00007.safetensors",
609
+ "biobrain_encoder.esm_model.attention_blocks.5.mha.w_k.bias": "model-00001-of-00007.safetensors",
610
+ "biobrain_encoder.esm_model.attention_blocks.5.mha.w_k.weight": "model-00001-of-00007.safetensors",
611
+ "biobrain_encoder.esm_model.attention_blocks.5.mha.w_q.bias": "model-00001-of-00007.safetensors",
612
+ "biobrain_encoder.esm_model.attention_blocks.5.mha.w_q.weight": "model-00001-of-00007.safetensors",
613
+ "biobrain_encoder.esm_model.attention_blocks.5.mha.w_v.bias": "model-00001-of-00007.safetensors",
614
+ "biobrain_encoder.esm_model.attention_blocks.5.mha.w_v.weight": "model-00001-of-00007.safetensors",
615
+ "biobrain_encoder.esm_model.attention_blocks.6.fc1.weight": "model-00001-of-00007.safetensors",
616
+ "biobrain_encoder.esm_model.attention_blocks.6.fc2.weight": "model-00001-of-00007.safetensors",
617
+ "biobrain_encoder.esm_model.attention_blocks.6.layer_norm_mlp.bias": "model-00001-of-00007.safetensors",
618
+ "biobrain_encoder.esm_model.attention_blocks.6.layer_norm_mlp.weight": "model-00001-of-00007.safetensors",
619
+ "biobrain_encoder.esm_model.attention_blocks.6.layer_norm_self_attention.bias": "model-00001-of-00007.safetensors",
620
+ "biobrain_encoder.esm_model.attention_blocks.6.layer_norm_self_attention.weight": "model-00001-of-00007.safetensors",
621
+ "biobrain_encoder.esm_model.attention_blocks.6.mha.output.bias": "model-00001-of-00007.safetensors",
622
+ "biobrain_encoder.esm_model.attention_blocks.6.mha.output.weight": "model-00001-of-00007.safetensors",
623
+ "biobrain_encoder.esm_model.attention_blocks.6.mha.w_k.bias": "model-00001-of-00007.safetensors",
624
+ "biobrain_encoder.esm_model.attention_blocks.6.mha.w_k.weight": "model-00001-of-00007.safetensors",
625
+ "biobrain_encoder.esm_model.attention_blocks.6.mha.w_q.bias": "model-00001-of-00007.safetensors",
626
+ "biobrain_encoder.esm_model.attention_blocks.6.mha.w_q.weight": "model-00001-of-00007.safetensors",
627
+ "biobrain_encoder.esm_model.attention_blocks.6.mha.w_v.bias": "model-00001-of-00007.safetensors",
628
+ "biobrain_encoder.esm_model.attention_blocks.6.mha.w_v.weight": "model-00001-of-00007.safetensors",
629
+ "biobrain_encoder.esm_model.attention_blocks.7.fc1.weight": "model-00001-of-00007.safetensors",
630
+ "biobrain_encoder.esm_model.attention_blocks.7.fc2.weight": "model-00001-of-00007.safetensors",
631
+ "biobrain_encoder.esm_model.attention_blocks.7.layer_norm_mlp.bias": "model-00001-of-00007.safetensors",
632
+ "biobrain_encoder.esm_model.attention_blocks.7.layer_norm_mlp.weight": "model-00001-of-00007.safetensors",
633
+ "biobrain_encoder.esm_model.attention_blocks.7.layer_norm_self_attention.bias": "model-00001-of-00007.safetensors",
634
+ "biobrain_encoder.esm_model.attention_blocks.7.layer_norm_self_attention.weight": "model-00001-of-00007.safetensors",
635
+ "biobrain_encoder.esm_model.attention_blocks.7.mha.output.bias": "model-00001-of-00007.safetensors",
636
+ "biobrain_encoder.esm_model.attention_blocks.7.mha.output.weight": "model-00001-of-00007.safetensors",
637
+ "biobrain_encoder.esm_model.attention_blocks.7.mha.w_k.bias": "model-00001-of-00007.safetensors",
638
+ "biobrain_encoder.esm_model.attention_blocks.7.mha.w_k.weight": "model-00001-of-00007.safetensors",
639
+ "biobrain_encoder.esm_model.attention_blocks.7.mha.w_q.bias": "model-00001-of-00007.safetensors",
640
+ "biobrain_encoder.esm_model.attention_blocks.7.mha.w_q.weight": "model-00001-of-00007.safetensors",
641
+ "biobrain_encoder.esm_model.attention_blocks.7.mha.w_v.bias": "model-00001-of-00007.safetensors",
642
+ "biobrain_encoder.esm_model.attention_blocks.7.mha.w_v.weight": "model-00001-of-00007.safetensors",
643
+ "biobrain_encoder.esm_model.attention_blocks.8.fc1.weight": "model-00001-of-00007.safetensors",
644
+ "biobrain_encoder.esm_model.attention_blocks.8.fc2.weight": "model-00001-of-00007.safetensors",
645
+ "biobrain_encoder.esm_model.attention_blocks.8.layer_norm_mlp.bias": "model-00001-of-00007.safetensors",
646
+ "biobrain_encoder.esm_model.attention_blocks.8.layer_norm_mlp.weight": "model-00001-of-00007.safetensors",
647
+ "biobrain_encoder.esm_model.attention_blocks.8.layer_norm_self_attention.bias": "model-00001-of-00007.safetensors",
648
+ "biobrain_encoder.esm_model.attention_blocks.8.layer_norm_self_attention.weight": "model-00001-of-00007.safetensors",
649
+ "biobrain_encoder.esm_model.attention_blocks.8.mha.output.bias": "model-00001-of-00007.safetensors",
650
+ "biobrain_encoder.esm_model.attention_blocks.8.mha.output.weight": "model-00001-of-00007.safetensors",
651
+ "biobrain_encoder.esm_model.attention_blocks.8.mha.w_k.bias": "model-00001-of-00007.safetensors",
652
+ "biobrain_encoder.esm_model.attention_blocks.8.mha.w_k.weight": "model-00001-of-00007.safetensors",
653
+ "biobrain_encoder.esm_model.attention_blocks.8.mha.w_q.bias": "model-00001-of-00007.safetensors",
654
+ "biobrain_encoder.esm_model.attention_blocks.8.mha.w_q.weight": "model-00001-of-00007.safetensors",
655
+ "biobrain_encoder.esm_model.attention_blocks.8.mha.w_v.bias": "model-00001-of-00007.safetensors",
656
+ "biobrain_encoder.esm_model.attention_blocks.8.mha.w_v.weight": "model-00001-of-00007.safetensors",
657
+ "biobrain_encoder.esm_model.attention_blocks.9.fc1.weight": "model-00001-of-00007.safetensors",
658
+ "biobrain_encoder.esm_model.attention_blocks.9.fc2.weight": "model-00001-of-00007.safetensors",
659
+ "biobrain_encoder.esm_model.attention_blocks.9.layer_norm_mlp.bias": "model-00001-of-00007.safetensors",
660
+ "biobrain_encoder.esm_model.attention_blocks.9.layer_norm_mlp.weight": "model-00001-of-00007.safetensors",
661
+ "biobrain_encoder.esm_model.attention_blocks.9.layer_norm_self_attention.bias": "model-00001-of-00007.safetensors",
662
+ "biobrain_encoder.esm_model.attention_blocks.9.layer_norm_self_attention.weight": "model-00001-of-00007.safetensors",
663
+ "biobrain_encoder.esm_model.attention_blocks.9.mha.output.bias": "model-00001-of-00007.safetensors",
664
+ "biobrain_encoder.esm_model.attention_blocks.9.mha.output.weight": "model-00001-of-00007.safetensors",
665
+ "biobrain_encoder.esm_model.attention_blocks.9.mha.w_k.bias": "model-00001-of-00007.safetensors",
666
+ "biobrain_encoder.esm_model.attention_blocks.9.mha.w_k.weight": "model-00001-of-00007.safetensors",
667
+ "biobrain_encoder.esm_model.attention_blocks.9.mha.w_q.bias": "model-00001-of-00007.safetensors",
668
+ "biobrain_encoder.esm_model.attention_blocks.9.mha.w_q.weight": "model-00001-of-00007.safetensors",
669
+ "biobrain_encoder.esm_model.attention_blocks.9.mha.w_v.bias": "model-00001-of-00007.safetensors",
670
+ "biobrain_encoder.esm_model.attention_blocks.9.mha.w_v.weight": "model-00001-of-00007.safetensors",
671
+ "biobrain_encoder.esm_model.embed_layer.weight": "model-00001-of-00007.safetensors",
672
+ "biobrain_encoder.esm_model.lm_head._fc1.bias": "model-00001-of-00007.safetensors",
673
+ "biobrain_encoder.esm_model.lm_head._fc1.weight": "model-00001-of-00007.safetensors",
674
+ "biobrain_encoder.esm_model.lm_head._final_fc.bias": "model-00001-of-00007.safetensors",
675
+ "biobrain_encoder.esm_model.lm_head._final_fc.weight": "model-00001-of-00007.safetensors",
676
+ "biobrain_encoder.esm_model.lm_head._first_layer_norm.bias": "model-00001-of-00007.safetensors",
677
+ "biobrain_encoder.esm_model.lm_head._first_layer_norm.weight": "model-00001-of-00007.safetensors",
678
+ "biobrain_encoder.esm_model.lm_head._second_layer_norm.bias": "model-00001-of-00007.safetensors",
679
+ "biobrain_encoder.esm_model.lm_head._second_layer_norm.weight": "model-00001-of-00007.safetensors",
680
+ "projection_model.bio_projection.bias": "model-00006-of-00007.safetensors",
681
+ "projection_model.bio_projection.weight": "model-00006-of-00007.safetensors",
682
+ "projection_model.perceiver_resampler.latent_queries": "model-00006-of-00007.safetensors",
683
+ "projection_model.perceiver_resampler.layers.0.cross_attention_1.output.bias": "model-00007-of-00007.safetensors",
684
+ "projection_model.perceiver_resampler.layers.0.cross_attention_1.output.weight": "model-00007-of-00007.safetensors",
685
+ "projection_model.perceiver_resampler.layers.0.cross_attention_1.w_k.bias": "model-00006-of-00007.safetensors",
686
+ "projection_model.perceiver_resampler.layers.0.cross_attention_1.w_k.weight": "model-00006-of-00007.safetensors",
687
+ "projection_model.perceiver_resampler.layers.0.cross_attention_1.w_q.bias": "model-00006-of-00007.safetensors",
688
+ "projection_model.perceiver_resampler.layers.0.cross_attention_1.w_q.weight": "model-00006-of-00007.safetensors",
689
+ "projection_model.perceiver_resampler.layers.0.cross_attention_1.w_v.bias": "model-00007-of-00007.safetensors",
690
+ "projection_model.perceiver_resampler.layers.0.cross_attention_1.w_v.weight": "model-00007-of-00007.safetensors",
691
+ "projection_model.perceiver_resampler.layers.0.cross_attention_2.output.bias": "model-00007-of-00007.safetensors",
692
+ "projection_model.perceiver_resampler.layers.0.cross_attention_2.output.weight": "model-00007-of-00007.safetensors",
693
+ "projection_model.perceiver_resampler.layers.0.cross_attention_2.w_k.bias": "model-00007-of-00007.safetensors",
694
+ "projection_model.perceiver_resampler.layers.0.cross_attention_2.w_k.weight": "model-00007-of-00007.safetensors",
695
+ "projection_model.perceiver_resampler.layers.0.cross_attention_2.w_q.bias": "model-00007-of-00007.safetensors",
696
+ "projection_model.perceiver_resampler.layers.0.cross_attention_2.w_q.weight": "model-00007-of-00007.safetensors",
697
+ "projection_model.perceiver_resampler.layers.0.cross_attention_2.w_v.bias": "model-00007-of-00007.safetensors",
698
+ "projection_model.perceiver_resampler.layers.0.cross_attention_2.w_v.weight": "model-00007-of-00007.safetensors",
699
+ "projection_model.perceiver_resampler.layers.0.fc1.bias": "model-00007-of-00007.safetensors",
700
+ "projection_model.perceiver_resampler.layers.0.fc1.weight": "model-00007-of-00007.safetensors",
701
+ "projection_model.perceiver_resampler.layers.0.fc2.bias": "model-00007-of-00007.safetensors",
702
+ "projection_model.perceiver_resampler.layers.0.fc2.weight": "model-00007-of-00007.safetensors",
703
+ "projection_model.perceiver_resampler.layers.0.norm_cross_attention_1.bias": "model-00007-of-00007.safetensors",
704
+ "projection_model.perceiver_resampler.layers.0.norm_cross_attention_1.weight": "model-00007-of-00007.safetensors",
705
+ "projection_model.perceiver_resampler.layers.0.norm_cross_attention_2.bias": "model-00007-of-00007.safetensors",
706
+ "projection_model.perceiver_resampler.layers.0.norm_cross_attention_2.weight": "model-00007-of-00007.safetensors",
707
+ "projection_model.perceiver_resampler.layers.0.norm_mlp.bias": "model-00007-of-00007.safetensors",
708
+ "projection_model.perceiver_resampler.layers.0.norm_mlp.weight": "model-00007-of-00007.safetensors",
709
+ "projection_model.perceiver_resampler.layers.1.cross_attention_1.output.bias": "model-00007-of-00007.safetensors",
710
+ "projection_model.perceiver_resampler.layers.1.cross_attention_1.output.weight": "model-00007-of-00007.safetensors",
711
+ "projection_model.perceiver_resampler.layers.1.cross_attention_1.w_k.bias": "model-00007-of-00007.safetensors",
712
+ "projection_model.perceiver_resampler.layers.1.cross_attention_1.w_k.weight": "model-00007-of-00007.safetensors",
713
+ "projection_model.perceiver_resampler.layers.1.cross_attention_1.w_q.bias": "model-00007-of-00007.safetensors",
714
+ "projection_model.perceiver_resampler.layers.1.cross_attention_1.w_q.weight": "model-00007-of-00007.safetensors",
715
+ "projection_model.perceiver_resampler.layers.1.cross_attention_1.w_v.bias": "model-00007-of-00007.safetensors",
716
+ "projection_model.perceiver_resampler.layers.1.cross_attention_1.w_v.weight": "model-00007-of-00007.safetensors",
717
+ "projection_model.perceiver_resampler.layers.1.cross_attention_2.output.bias": "model-00007-of-00007.safetensors",
718
+ "projection_model.perceiver_resampler.layers.1.cross_attention_2.output.weight": "model-00007-of-00007.safetensors",
719
+ "projection_model.perceiver_resampler.layers.1.cross_attention_2.w_k.bias": "model-00007-of-00007.safetensors",
720
+ "projection_model.perceiver_resampler.layers.1.cross_attention_2.w_k.weight": "model-00007-of-00007.safetensors",
721
+ "projection_model.perceiver_resampler.layers.1.cross_attention_2.w_q.bias": "model-00007-of-00007.safetensors",
722
+ "projection_model.perceiver_resampler.layers.1.cross_attention_2.w_q.weight": "model-00007-of-00007.safetensors",
723
+ "projection_model.perceiver_resampler.layers.1.cross_attention_2.w_v.bias": "model-00007-of-00007.safetensors",
724
+ "projection_model.perceiver_resampler.layers.1.cross_attention_2.w_v.weight": "model-00007-of-00007.safetensors",
725
+ "projection_model.perceiver_resampler.layers.1.fc1.bias": "model-00007-of-00007.safetensors",
726
+ "projection_model.perceiver_resampler.layers.1.fc1.weight": "model-00007-of-00007.safetensors",
727
+ "projection_model.perceiver_resampler.layers.1.fc2.bias": "model-00007-of-00007.safetensors",
728
+ "projection_model.perceiver_resampler.layers.1.fc2.weight": "model-00007-of-00007.safetensors",
729
+ "projection_model.perceiver_resampler.layers.1.norm_cross_attention_1.bias": "model-00007-of-00007.safetensors",
730
+ "projection_model.perceiver_resampler.layers.1.norm_cross_attention_1.weight": "model-00007-of-00007.safetensors",
731
+ "projection_model.perceiver_resampler.layers.1.norm_cross_attention_2.bias": "model-00007-of-00007.safetensors",
732
+ "projection_model.perceiver_resampler.layers.1.norm_cross_attention_2.weight": "model-00007-of-00007.safetensors",
733
+ "projection_model.perceiver_resampler.layers.1.norm_mlp.bias": "model-00007-of-00007.safetensors",
734
+ "projection_model.perceiver_resampler.layers.1.norm_mlp.weight": "model-00007-of-00007.safetensors",
735
+ "projection_model.perceiver_resampler.layers.2.cross_attention_1.output.bias": "model-00007-of-00007.safetensors",
736
+ "projection_model.perceiver_resampler.layers.2.cross_attention_1.output.weight": "model-00007-of-00007.safetensors",
737
+ "projection_model.perceiver_resampler.layers.2.cross_attention_1.w_k.bias": "model-00007-of-00007.safetensors",
738
+ "projection_model.perceiver_resampler.layers.2.cross_attention_1.w_k.weight": "model-00007-of-00007.safetensors",
739
+ "projection_model.perceiver_resampler.layers.2.cross_attention_1.w_q.bias": "model-00007-of-00007.safetensors",
740
+ "projection_model.perceiver_resampler.layers.2.cross_attention_1.w_q.weight": "model-00007-of-00007.safetensors",
741
+ "projection_model.perceiver_resampler.layers.2.cross_attention_1.w_v.bias": "model-00007-of-00007.safetensors",
742
+ "projection_model.perceiver_resampler.layers.2.cross_attention_1.w_v.weight": "model-00007-of-00007.safetensors",
743
+ "projection_model.perceiver_resampler.layers.2.cross_attention_2.output.bias": "model-00007-of-00007.safetensors",
744
+ "projection_model.perceiver_resampler.layers.2.cross_attention_2.output.weight": "model-00007-of-00007.safetensors",
745
+ "projection_model.perceiver_resampler.layers.2.cross_attention_2.w_k.bias": "model-00007-of-00007.safetensors",
746
+ "projection_model.perceiver_resampler.layers.2.cross_attention_2.w_k.weight": "model-00007-of-00007.safetensors",
747
+ "projection_model.perceiver_resampler.layers.2.cross_attention_2.w_q.bias": "model-00007-of-00007.safetensors",
748
+ "projection_model.perceiver_resampler.layers.2.cross_attention_2.w_q.weight": "model-00007-of-00007.safetensors",
749
+ "projection_model.perceiver_resampler.layers.2.cross_attention_2.w_v.bias": "model-00007-of-00007.safetensors",
750
+ "projection_model.perceiver_resampler.layers.2.cross_attention_2.w_v.weight": "model-00007-of-00007.safetensors",
751
+ "projection_model.perceiver_resampler.layers.2.fc1.bias": "model-00007-of-00007.safetensors",
752
+ "projection_model.perceiver_resampler.layers.2.fc1.weight": "model-00007-of-00007.safetensors",
753
+ "projection_model.perceiver_resampler.layers.2.fc2.bias": "model-00007-of-00007.safetensors",
754
+ "projection_model.perceiver_resampler.layers.2.fc2.weight": "model-00007-of-00007.safetensors",
755
+ "projection_model.perceiver_resampler.layers.2.norm_cross_attention_1.bias": "model-00007-of-00007.safetensors",
756
+ "projection_model.perceiver_resampler.layers.2.norm_cross_attention_1.weight": "model-00007-of-00007.safetensors",
757
+ "projection_model.perceiver_resampler.layers.2.norm_cross_attention_2.bias": "model-00007-of-00007.safetensors",
758
+ "projection_model.perceiver_resampler.layers.2.norm_cross_attention_2.weight": "model-00007-of-00007.safetensors",
759
+ "projection_model.perceiver_resampler.layers.2.norm_mlp.bias": "model-00007-of-00007.safetensors",
760
+ "projection_model.perceiver_resampler.layers.2.norm_mlp.weight": "model-00007-of-00007.safetensors",
761
+ "projection_model.token_embedding.weight": "model-00006-of-00007.safetensors"
762
+ }
763
+ }