Yanisadel committed on
Commit 39f4cde · 1 Parent(s): ef46589

Upload model

Files changed (3)
  1. chatNT.py +1850 -0
  2. config.json +85 -0
  3. model.safetensors.index.json +763 -0
chatNT.py ADDED
@@ -0,0 +1,1850 @@
1
+ # This file stores ChatNT and all associated layers and configs
2
+
3
+ from dataclasses import asdict, dataclass, field
4
+ from typing import Dict, List, Optional, Tuple
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F # noqa: N812
10
+ from transformers import PretrainedConfig, PreTrainedModel
11
+
12
+
13
+ @dataclass
14
+ class RotaryEmbeddingConfig:
15
+ """
16
+ Rotary Positional Embedding configuration
17
+ max_seq_len: The number of positions to encode and cache.
18
+ dim: Dimension of RoPE.
19
+ theta: Rotation angle.
20
+ """
21
+
22
+ max_seq_len: int
23
+ dim: int
24
+ theta: float
25
+
26
+
27
+ @dataclass
28
+ class PerceiverResamplerConfig:
29
+ """
30
+ Parameters to initialize a PerceiverResampler model. Based on the ESM architecture.
31
+
32
+ Args:
33
+ emb_layer_norm_before: Whether to use layer norm before the first attention
34
+ layer.
35
+ attention_heads: Number of attention heads.
36
+ key_size: The dimension of the query, key, and values within each attention
37
+ head; if not specified, it is set to embed_dim // attention_heads.
38
+ It can be useful to set a custom key size if we want to impose the size of
39
+ the query, key and value tensor ( for example, tensors shaped with
40
+ power of 2 are more efficiently handled on TPUs ).
41
+ Note: Parametrizing the model with a custom key size has been done in :
42
+ Brown, Tom, et al. "Language models are few-shot learners."
43
+ Advances in neural information processing systems 33 (2020): 1877-1901.
44
+ embed_dim: Embedding dimension.
45
+ ffn_embed_dim: Feed forward embedding dimension.
46
+ num_layers: Number of attention blocks.
47
+ ffn_activation_name: Activation function to be used in FFN block. Supported
48
+ names are "gelu", "relu", "swish".
49
+ use_glu_in_ffn: Whether to use Gated Linear Unit (GLU) in Feed
50
+ Forward Network (FFN) block. To do a swiGLU (gated-swish) put this arg
51
+ to True and use swish as ffn_activation_name.
52
+ Same principle for a gated-relu. To keep the same number of parameters in
53
+ the FFN block, one should multiply by 2/3 the ffn_embed_dim when using GLU.
54
+ See https://arxiv.org/pdf/2002.05202.pdf for more details.
55
+ resampled_length: length of the resampled output of the module
56
+ use_gradient_checkpointing: Whether to use gradient checkpointing (checkpoint
57
+ gradients in the forward pass to reduce the computation in the backward).
58
+ """
59
+
60
+ # architecture
61
+ emb_layer_norm_before: bool = False
62
+ attention_heads: int = 20
63
+ key_size: Optional[int] = None
64
+ embed_dim: int = 1280
65
+ ffn_embed_dim: int = 5120
66
+ num_layers: int = 24
67
+ add_bias_kv: bool = False
68
+ add_bias_ffn: bool = True
69
+ ffn_activation_name: str = "gelu-no-approx"
70
+ use_glu_in_ffn: bool = False
71
+ resampled_length: int = 64
72
+
73
+ # performance
74
+ use_gradient_checkpointing: bool = False
75
+
76
+ def __post_init__(self) -> None:
77
+ """
78
+ Checks that the given values are compatible.
79
+ """
80
+
81
+ if self.key_size is None:
82
+ if not self.embed_dim % self.attention_heads == 0:
83
+ raise ValueError(
84
+ f"When no key size is provided, the embedding dimension should be "
85
+ f"divisible by the number of heads, however provided embedding "
86
+ f"dimension is {self.embed_dim} and the number of heads is "
87
+ f"{self.attention_heads}."
88
+ )
89
+ self.key_size = self.embed_dim // self.attention_heads
90
+
91
+
92
+ @dataclass
93
+ class GptConfig:
94
+ """
95
+ Parameters to initialize a Gpt model.
96
+
97
+ NOTE: the pad token is not defined
98
+
99
+ Args:
100
+ vocab_size: Token vocabulary.
101
+ eos_token_id: used to stop sentence generation
102
+ embed_dim: Embedding dimension.
103
+ ffn_embed_dim: Feed forward embedding dimension.
104
+ num_heads: Number of attention heads.
105
+ num_kv_heads: Number of key and value heads to support Grouped-Query and
106
+ Multi-Query Attention. If None, the number of key and value heads is
107
+ equal to the number of attention heads.
108
+ num_layers: Number of Decoder layer_stack
109
+ rope_config: The configuration for the rotary positional embeddings
110
+ add_bias_ffn: Add bias in feed forward network block.
111
+ ffn_activation_name: Activation function to be used in FFN block. Supported
112
+ names are "gelu", "gelu-no-approx", "relu", "swish".
113
+ use_glu_in_ffn: whether to use Gated Linear Unit (GLU) in Feed
114
+ Forward Network (FFN) block.
115
+ example: To do a swiGLU (gated-swish) put this arg
116
+ to True and use swish as ffn_activation_name.
117
+ Same principle for a gated-relu.
118
+ add_bias_lm_head: whether to use bias in the final LM layer
119
+ norm_type: The type of norm used (pre-normalization scheme). Can be
120
+ one of ["layer_norm", "RMS_norm"]
121
+ parallel_attention_ff: Whether to do the attention and the MLP in parallel,
122
+ and then sum up the results, as is done in GPT-NeoX:
123
+ Black, Sid, et al. "Gpt-neox-20b: An open-source autoregressive
124
+ language model." arXiv preprint arXiv:2204.06745 (2022).
125
+ It is reported to improve training time by about 15% when compiling with JAX
126
+ use_gradient_checkpointing: Whether to use gradient checkpointing (checkpoint
127
+ gradients in the forward pass to reduce the computation in the backward).
128
+ add_bias_attn: Add bias to the attention mechanism (key, query, value, and
129
+ output projections).
130
+ """
131
+
132
+ # vocabulary
133
+ vocab_size: int
134
+ eos_token_id: int
135
+
136
+ # architecture
137
+ embed_dim: int = 16
138
+ ffn_embed_dim: int = 64
139
+ num_heads: int = 2
140
+ num_kv_heads: Optional[int] = None
141
+ num_layers: int = 2
142
+ rope_config: RotaryEmbeddingConfig = field(
143
+ default_factory=lambda: RotaryEmbeddingConfig(
144
+ max_seq_len=512, dim=8, theta=10000.0
145
+ )
146
+ )
147
+ add_bias_ffn: bool = False
148
+ ffn_activation_name: str = "swish"
149
+ use_glu_in_ffn: bool = True
150
+ add_bias_lm_head: bool = False
151
+ norm_type: str = "RMS_norm"
152
+ rms_norm_eps: float = 1e-6
153
+ parallel_attention_ff: bool = True
154
+
155
+ # inference / backward behavior
156
+ use_gradient_checkpointing: bool = False
157
+
158
+ # architecture params with default values
159
+ add_bias_attn: bool = False
160
+
161
+ def __post_init__(self) -> None:
162
+ """
163
+ Checks that the given values are compatible.
164
+ """
165
+ if not self.embed_dim % self.num_heads == 0:
166
+ raise ValueError(
167
+ f"The embedding dimension should be "
168
+ f"divisible by the number of heads, however provided embedding "
169
+ f"dimension is {self.embed_dim} and the number of heads is "
170
+ f"{self.num_heads}."
171
+ )
172
+
173
+ if not self.embed_dim // self.num_heads > 1:
174
+ raise ValueError(
175
+ "embed_dim / num_heads must be higher than 2 to apply rotary embeddings"
176
+ )
177
+
178
+ if not self.embed_dim // self.num_heads >= self.rope_config.dim:
179
+ raise ValueError(
180
+ "embed_dim // num_heads must be higher than rope_config.dim "
181
+ "to apply rotary embeddings"
182
+ )
183
+
184
+ def to_dict(self): # type: ignore
185
+ output = asdict(self)
186
+ output["rope_config"] = asdict(self.rope_config)
187
+ return output
188
+
189
+
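# Illustrative sketch (not part of the uploaded chatNT.py): constructing a small
# GptConfig in the context of this module. The values below are arbitrary examples,
# not the shipped ChatNT hyperparameters. As the docstring notes, ffn_embed_dim is
# conventionally scaled by 2/3 when use_glu_in_ffn=True to keep parameter counts
# comparable.
example_gpt_config = GptConfig(
    vocab_size=32000,
    eos_token_id=3,
    embed_dim=256,
    ffn_embed_dim=int(4 * 256 * 2 / 3),  # 2/3 scaling because of the GLU gate
    num_heads=8,
    num_layers=4,
    rope_config=RotaryEmbeddingConfig(max_seq_len=512, dim=32, theta=10000.0),
    ffn_activation_name="swish",  # swish + GLU = SwiGLU
    use_glu_in_ffn=True,
)
# __post_init__ validates that embed_dim is divisible by num_heads and that
# embed_dim // num_heads (here 32) is at least rope_config.dim.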
190
+ @dataclass
191
+ class ESMTransformerConfig:
192
+ """
193
+ Parameters to initialize an ESM model. While the ESM architecture is an encoder-only
194
+ model, different choices have been made for each version and this configuration aims
195
+ to cover most of them.
196
+
197
+ Args:
198
+ alphabet_size: Token vocabulary.
199
+ pad_token_id: ID of pad token.
200
+ mask_token_id: ID of mask token.
201
+ max_positions: Maximum sequence length.
202
+ embed_scale: Correction ratio applied to the embeddings to make up for the
203
+ norm difference between the input during training and inference.
204
+ emb_layer_norm_before: Whether to use layer norm before the first attention
205
+ layer.
206
+ attention_heads: Number of attention heads.
207
+ key_size: The dimension of the query, key, and values within each attention
208
+ head; if not specified, it is set to embed_dim // attention_heads.
209
+ It can be useful to set a custom key size if we want to impose the size of
210
+ the query, key and value tensor ( for example, tensors shaped with
211
+ power of 2 are more efficiently handled on TPUs ).
212
+ Note: Parametrizing the model with a custom key size has been done in :
213
+ Brown, Tom, et al. "Language models are few-shot learners."
214
+ Advances in neural information processing systems 33 (2020): 1877-1901.
215
+ embed_dim: Embedding dimension.
216
+ ffn_embed_dim: Feed forward embedding dimension.
217
+ num_layers: Number of attention blocks.
218
+ positional_embedding: Type of positional embedding to use before the first
219
+ attention layer. Options: "learned", "learned_standard", "sinusoidal" or
220
+ None.
221
+ NOTE: "learned" is the positional embedding of ESM, and "learned_standard"
222
+ is a more standard one, used for example in DNAbert.
223
+ lm_head: type of language model head. Options: "simple", "roberta" or None.
224
+ add_bias_kv: Add bias in attention layer.
225
+ add_bias_ffn: Add bias in feed forward network block.
226
+ use_rotary_embedding: Whether to use rotary embeddings (for ESM2). Requires:
227
+ positional_embeddings = None.
228
+ rescaling_factor: Scaling factor to use for rotary embeddings.
229
+ ffn_activation_name: Activation function to be used in FFN block. Supported
230
+ names are "gelu", "relu", "swish".
231
+ use_glu_in_ffn: Whether to use Gated Linear Unit (GLU) in Feed
232
+ Forward Network (FFN) block. To do a swiGLU (gated-swish) put this arg
233
+ to True and use swish as ffn_activation_name.
234
+ Same principle for a gated-relu. To keep the same number of parameters in
235
+ the FFN block, one should multiply by 2/3 the ffn_embed_dim when using GLU.
236
+ See https://arxiv.org/pdf/2002.05202.pdf for more details.
237
+ mask_before_attention: Use mask before attention layers (for ESM1b and ESM2).
238
+ layer_norm_eps: the eps factor in the different layer norms of the model (refer
239
+ to layer norm implementation)
240
+ token_dropout: Token dropout.
241
+ masking_ratio: Masking ratio (used if token dropout is enabled).
242
+ masking_prob: Masking probability (used if token dropout is enabled).
243
+ use_gradient_checkpointing: Whether to use gradient checkpointing (checkpoint
244
+ gradients in the forward pass to reduce the computation in the backward).
245
+ """
246
+
247
+ alphabet_size: int
248
+ pad_token_id: int
249
+ mask_token_id: int
250
+
251
+ max_positions: int = 1024
252
+ embed_scale: float = 1.0
253
+
254
+ # architecture
255
+ emb_layer_norm_before: bool = False
256
+ attention_heads: int = 20
257
+ key_size: Optional[int] = None
258
+ embed_dim: int = 1280
259
+ ffn_embed_dim: int = 5120
260
+ num_layers: int = 24
261
+ positional_embedding: Optional[str] = "learned"
262
+ lm_head: Optional[str] = "simple"
263
+ add_bias_kv: bool = False
264
+ add_bias_ffn: bool = True
265
+ use_rotary_embedding: bool = False
266
+ rescaling_factor: Optional[float] = None
267
+ ffn_activation_name: str = "gelu-no-approx"
268
+ use_glu_in_ffn: bool = False
269
+ mask_before_attention: bool = False
270
+ layer_norm_eps: float = 1e-5
271
+ pre_layer_norm: bool = True
272
+ bias_word_embedding: bool = False
273
+
274
+ # dropout
275
+ token_dropout: bool = False
276
+ masking_ratio: float = 0.1
277
+ masking_prob: float = 0.8
278
+
279
+ # logging
280
+ use_gradient_checkpointing: bool = False
281
+
282
+ # return
283
+ embeddings_layers_to_save: List[int] = field(default_factory=list)
284
+ attention_maps_to_save: List[Tuple[int, int]] = field(default_factory=list)
285
+
286
+ def __post_init__(self) -> None:
287
+ """
288
+ Checks that the given values are compatible.
289
+ """
290
+
291
+ if self.key_size is None:
292
+ if not self.embed_dim % self.attention_heads == 0:
293
+ raise ValueError(
294
+ f"When no key size is provided, the embedding dimension should be "
295
+ f"divisible by the number of heads, however provided embedding "
296
+ f"dimension is {self.embed_dim} and the number of heads is "
297
+ f"{self.attention_heads}."
298
+ )
299
+ self.key_size = self.embed_dim // self.attention_heads
300
+ if self.positional_embedding is not None:
301
+ if type(self.positional_embedding) != str:
302
+ raise TypeError
303
+
304
+ if self.positional_embedding not in [
305
+ "learned",
306
+ "sinusoidal",
307
+ "learned_standard",
308
+ "alibi_dnabert_2",
309
+ ]:
310
+ raise ValueError(
311
+ "The positional_embedding argument should either be None,"
312
+ "`learned`, `sinusoidal`, 'learned_standard' or 'alibi_dnabert_2'."
313
+ )
314
+ if self.lm_head is not None:
315
+ if type(self.lm_head) != str:
316
+ raise TypeError
317
+
318
+ if self.lm_head not in ["simple", "roberta"]:
319
+ raise ValueError(
320
+ "The lm_head argument should either be None,"
321
+ "`simple` or `roberta`."
322
+ )
323
+
324
+ if self.use_rotary_embedding and self.positional_embedding is not None:
325
+ raise ValueError(
326
+ "When using rotary embedding, positional_embedding must be set to none"
327
+ )
328
+
329
+ if self.add_bias_kv and self.use_rotary_embedding:
330
+ raise ValueError(
331
+ "Biases on key and values are not compatible with Rotary embeddings."
332
+ )
333
+
334
+ if self.positional_embedding == "alibi_dnabert_2":
335
+ assert not self.add_bias_kv
336
+
337
+
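# Illustrative sketch (not part of the uploaded chatNT.py): an ESMTransformerConfig
# compatible with the assertions made later in TorchESMTransformer (rotary
# embeddings, RoBERTa head, no token dropout). Sizes are arbitrary examples; the
# token ids match the defaults used in ChatNTConfig below.
example_esm_config = ESMTransformerConfig(
    alphabet_size=4000,
    pad_token_id=1,
    mask_token_id=4,
    embed_dim=512,
    attention_heads=16,
    ffn_embed_dim=2048,
    num_layers=12,
    positional_embedding=None,   # required when use_rotary_embedding=True
    lm_head="roberta",
    use_rotary_embedding=True,
    token_dropout=False,
)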
338
+ @dataclass
339
+ class ChatNTConfig(PretrainedConfig):
340
+ model_type = "ChatNT"
341
+
342
+ def __init__(self, **kwargs): # type: ignore
343
+ self.gpt_config: GptConfig = kwargs.get("gpt_config", GptConfig(32000, 3))
344
+ self.esm_config: ESMTransformerConfig = kwargs.get(
345
+ "esm_config", ESMTransformerConfig(4000, 1, 4)
346
+ )
347
+ self.perceiver_resampler_config: PerceiverResamplerConfig = kwargs.get(
348
+ "perceiver_resampler_config", PerceiverResamplerConfig()
349
+ )
350
+ self.seq_token_id: int = kwargs.get("seq_token_id", 32000)
351
+ self.bio_pad_token_id: int = kwargs.get("bio_pad_token_id", 1)
352
+ self.english_pad_token_id: int = kwargs.get("english_pad_token_id", 2)
353
+ super().__init__(**kwargs)
354
+
355
+ def to_dict(self): # type: ignore
356
+ output = super().to_dict()
357
+
358
+ def serialize(obj): # type: ignore
359
+ return obj.to_dict() if hasattr(obj, "to_dict") else vars(obj)
360
+
361
+ output["gpt_config"] = serialize(self.gpt_config) # type: ignore
362
+ output["esm_config"] = serialize(self.esm_config) # type: ignore
363
+ output["perceiver_resampler_config"] = serialize( # type: ignore
364
+ self.perceiver_resampler_config
365
+ )
366
+ return output
367
+
368
+
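# Illustrative sketch (not part of the uploaded chatNT.py): assembling a ChatNTConfig
# from the three sub-configs and serializing it. The values reuse the defaults
# hard-coded in __init__ above; seq_token_id is the SEQ placeholder token at which
# bio embeddings are later spliced into the English prompt.
example_chatnt_config = ChatNTConfig(
    gpt_config=GptConfig(vocab_size=32000, eos_token_id=3),
    esm_config=ESMTransformerConfig(alphabet_size=4000, pad_token_id=1, mask_token_id=4),
    perceiver_resampler_config=PerceiverResamplerConfig(),
    seq_token_id=32000,
    bio_pad_token_id=1,
    english_pad_token_id=2,
)
config_dict = example_chatnt_config.to_dict()  # nested dataclasses become plain dicts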
369
+ class TorchBioBrainDecoder(nn.Module):
370
+ def __init__(
371
+ self,
372
+ gpt_config: GptConfig,
373
+ seq_token_id: int,
374
+ ):
375
+ """
376
+ Initializes the BioBrain decoder, using a GPT model for text generation with
377
+ bio embeddings.
378
+
379
+ Args:
380
+ gpt_config: Configuration for the GPT model
381
+ seq_token_id: Index of the SEQ token
382
+ """
383
+ super(TorchBioBrainDecoder, self).__init__()
384
+ self.gpt_config = gpt_config
385
+ self.seq_token_id = seq_token_id
386
+
387
+ # Initialize the GPT decoder used for text generation
388
+ self.gpt_model = TorchGptDecoder(self.gpt_config)
389
+
390
+ def forward(
391
+ self, english_token_ids: torch.Tensor, projected_bio_embeddings: torch.Tensor
392
+ ) -> torch.Tensor:
393
+ """
394
+ Forward pass through the model.
395
+
396
+ Args:
397
+ english_token_ids: Tensor of English token IDs with shape
398
+ (batch_size, num_english_tokens).
399
+ projected_bio_embeddings: Optional tensor of bio embeddings with shape
400
+ (batch_size, num_bio_sequences, ?, embed_dim).
401
+
402
+ Returns:
403
+ torch.Tensor: The logits from the GPT model,
404
+ shaped (batch_size, num_english_tokens, vocab_size).
405
+ """
406
+
407
+ # Compute English token embeddings
408
+ tokens_embeddings = self.gpt_model.token_embed(english_token_ids)
409
+
410
+ if projected_bio_embeddings is not None:
411
+ (
412
+ batch_size,
413
+ num_bio_sequences,
414
+ _,
415
+ bio_embed_dim,
416
+ ) = projected_bio_embeddings.shape
417
+
418
+ # Insert the bio embeddings at the SEQ token positions
419
+ processed_tokens_ids = english_token_ids.clone()
420
+ for bio_seq_num in range(num_bio_sequences):
421
+ tokens_embeddings, processed_tokens_ids = self.insert_embeddings(
422
+ processed_tokens_ids,
423
+ tokens_embeddings,
424
+ projected_bio_embeddings[:, bio_seq_num, :, :],
425
+ bio_seq_num=bio_seq_num,
426
+ )
427
+
428
+ # Regular GPT pass through
429
+ embeddings = self.gpt_model.apply_transformer_layers(tokens_embeddings)
430
+ embeddings = self.gpt_model.final_norm(embeddings)
431
+
432
+ # Compute logits
433
+ logits = self.gpt_model.lm_head(embeddings)
434
+
435
+ if projected_bio_embeddings is not None:
436
+ # Clean logits sequentially
437
+ processed_tokens_ids = english_token_ids.clone()
438
+ resampled_length = projected_bio_embeddings.shape[-2]
439
+ for _ in range(num_bio_sequences):
440
+ logits, processed_tokens_ids = self.cleanup_logits(
441
+ tokens=processed_tokens_ids,
442
+ logits=logits,
443
+ resampled_length=resampled_length,
444
+ )
445
+
446
+ return logits
447
+
448
+ def insert_embeddings(
449
+ self,
450
+ tokens: torch.Tensor,
451
+ input_embeddings: torch.Tensor,
452
+ resampled_embeddings: torch.Tensor,
453
+ bio_seq_num: int,
454
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
455
+ """
456
+ Inserts resampled embeddings in input_embeddings, starting at the SEQ token
457
+
458
+ Args:
459
+ tokens (torch.Tensor): Shape (batch_size, num_tokens)
460
+ input_embeddings (torch.Tensor): Shape (batch_size, num_tokens, embed_dim)
461
+ resampled_embeddings (torch.Tensor):
462
+ Shape (batch_size, bio_sequence_length, embed_dim)
463
+
464
+ Returns:
465
+ Tuple[torch.Tensor, torch.Tensor]:
466
+ - input_embeddings with resampled_embeddings inserted at the SEQ token
467
+ - tokens with the SEQ token set to -1
468
+ """
469
+
470
+ def _insert(
471
+ tokens_1d: torch.Tensor,
472
+ input_embeddings_1d: torch.Tensor,
473
+ resampled_embeddings_1d: torch.Tensor,
474
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
475
+ """
476
+ Args:
477
+ tokens (torch.Tensor): Shape (num_tokens,)
478
+ input_embeddings (torch.Tensor): Shape (num_tokens, embed_dim,)
479
+ resampled_embeddings (torch.Tensor):
480
+ Shape (bio_sequence_length, embed_dim,)
481
+ """
482
+ indices = torch.where(tokens_1d == self.seq_token_id)[0]
483
+ if indices.numel() > 0:
484
+ idx = indices[0].item()
485
+ insertion_pos = idx + resampled_embeddings_1d.shape[-2] * bio_seq_num
486
+ x = torch.cat(
487
+ [
488
+ input_embeddings_1d[:insertion_pos, :],
489
+ resampled_embeddings_1d,
490
+ input_embeddings_1d[insertion_pos:, :],
491
+ ],
492
+ dim=0,
493
+ )[: tokens_1d.shape[0] + 1, :]
494
+ x = torch.roll(torch.roll(x, shifts=-idx, dims=0), shifts=idx, dims=0)[
495
+ :-1, :
496
+ ]
497
+ tokens_1d[idx] = -1
498
+ return x, tokens_1d
499
+ else:
500
+ return (
501
+ input_embeddings_1d,
502
+ tokens_1d,
503
+ ) # Return unchanged if seq_token_id is not found
504
+
505
+ tokens_acc = []
506
+ embeddings_acc = []
507
+
508
+ for i in range(tokens.shape[0]):
509
+ embeddings_out, tokens_out = _insert(
510
+ tokens[i].clone(),
511
+ input_embeddings[i].clone(),
512
+ resampled_embeddings[i].clone(),
513
+ )
514
+ tokens_acc.append(tokens_out)
515
+ embeddings_acc.append(embeddings_out)
516
+ tokens_acc = torch.stack(tokens_acc)
517
+ embeddings_acc = torch.stack(embeddings_acc)
518
+
519
+ return embeddings_acc, tokens_acc
520
+
521
+ def cleanup_logits(
522
+ self, tokens: torch.Tensor, logits: torch.Tensor, resampled_length: int
523
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
524
+ """
525
+ Removes the logits corresponding to the unused embeddings.
526
+
527
+ Args:
528
+ tokens: Input english tokens.
529
+ logits: Input logits.
530
+
531
+ Returns:
532
+ Cleaned logits, last values will be equal to 0.
533
+ """
534
+
535
+ def _clean(
536
+ token: torch.Tensor, logit: torch.Tensor
537
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
538
+ indices = torch.where(token == self.seq_token_id)[0]
539
+ if indices.numel() > 0:
540
+ idx = indices[0].item()
541
+
542
+ mask_idx = (
543
+ torch.arange(logit.shape[0] - resampled_length, device=logit.device)
544
+ > idx
545
+ )
546
+ mask_idx = mask_idx.unsqueeze(1)
547
+
548
+ # Remove values corresponding to bio tokens
549
+ logit = (
550
+ logit[:-resampled_length] * (~mask_idx)
551
+ + logit[resampled_length:] * mask_idx
552
+ )
553
+
554
+ # Append zeros at the end
555
+ logit = torch.cat(
556
+ (
557
+ logit,
558
+ torch.zeros(
559
+ (resampled_length, logit.shape[1]),
560
+ dtype=logit.dtype,
561
+ device=logit.device,
562
+ ),
563
+ )
564
+ )
565
+
566
+ # Update token
567
+ token[idx] = -1
568
+
569
+ return logit, token
570
+
571
+ else:
572
+ return logit, token
573
+
574
+ tokens_acc = []
575
+ logits_acc = []
576
+
577
+ for i in range(tokens.shape[0]):
578
+ logits_out, tokens_out = _clean(tokens[i].clone(), logits[i].clone())
579
+ tokens_acc.append(tokens_out)
580
+ logits_acc.append(logits_out)
581
+ tokens_acc = torch.stack(tokens_acc)
582
+ logits_acc = torch.stack(logits_acc)
583
+
584
+ return logits_acc, tokens_acc
585
+
586
+
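# Illustrative sketch (not part of the uploaded chatNT.py): shapes involved in a
# TorchBioBrainDecoder forward pass, using a tiny arbitrary GptConfig. seq_token_id
# is set to 31999 here to mirror the "seq_token_id -= 1" correction applied in
# TorchMultiOmicsModel below. (Assumes the full module is loaded, since
# TorchGptDecoder and its layers are defined further down in this file.)
decoder = TorchBioBrainDecoder(
    gpt_config=GptConfig(vocab_size=32000, eos_token_id=3, embed_dim=16, num_heads=2),
    seq_token_id=31999,
)
english_token_ids = torch.randint(0, 31000, (2, 40))  # (batch, num_english_tokens)
english_token_ids[:, 5] = 31999                       # position of the SEQ placeholder
projected_bio_embeddings = torch.randn(2, 1, 8, 16)   # (batch, num_bio_seqs, resampled_length, embed_dim)
logits = decoder(english_token_ids, projected_bio_embeddings)
# logits: (2, 40, 32000). For rows containing a SEQ token, cleanup_logits removes the
# positions occupied by the inserted bio embeddings and zero-fills the tail.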
587
+ class TorchMultiOmicsModel(PreTrainedModel):
588
+ config_class = ChatNTConfig
589
+
590
+ def __init__(self, config: ChatNTConfig) -> None:
591
+ super().__init__(config=config)
592
+ self.gpt_config = config.gpt_config
593
+ self.esm_config = config.esm_config
594
+ self.perceiver_resampler_config = config.perceiver_resampler_config
595
+ self.seq_token_id = config.seq_token_id
596
+ self.bio_pad_token_id = config.bio_pad_token_id
597
+ self.english_pad_token_id = config.english_pad_token_id
598
+
599
+ # Correct seq_token_id
600
+ self.seq_token_id -= 1
601
+
602
+ self.biobrain_encoder = TorchBioBrainEncoder(esm_config=self.esm_config)
603
+ self.biobrain_decoder = TorchBioBrainDecoder(
604
+ gpt_config=self.gpt_config, seq_token_id=self.seq_token_id
605
+ )
606
+ self.projection_model = TorchMultiModalPerceiverResamplerProjection(
607
+ perceiver_resampler_config=self.perceiver_resampler_config,
608
+ input_embed_dim=self.esm_config.embed_dim,
609
+ embed_dim=self.gpt_config.embed_dim,
610
+ english_vocab_size=self.gpt_config.vocab_size,
611
+ bio_pad_token_id=self.bio_pad_token_id,
612
+ english_pad_token_id=self.english_pad_token_id,
613
+ )
614
+
615
+ def forward(
616
+ self,
617
+ multi_omics_tokens_ids: tuple[torch.Tensor, torch.Tensor],
618
+ projection_english_tokens_ids: torch.Tensor,
619
+ projected_bio_embeddings: torch.Tensor = None,
620
+ ) -> dict[str, torch.Tensor]:
621
+ """
622
+
623
+ Args:
624
+ multi_omics_tokens_ids (Tuple[torch.Tensor, torch.Tensor]):
625
+ english_tokens_ids: Represents the prompt tokens (english tokens)
626
+ Shape (batch_size, num_english_tokens)
627
+
628
+ bio_tokens_ids: Represents the bio sequences tokens
629
+ Shape (batch_size, num_bio_sequences, num_bio_tokens)
630
+
631
+ projection_english_tokens_ids (torch.Tensor):
632
+ Shape (batch_size, num_english_tokens)
633
+
634
+ projected_bio_embeddings (projected_bio_embeddings, optional):
635
+ Shape (batch_size, num_bio_sequences, ?, embed_dim).
636
+ Defaults to None.
637
+
638
+ Returns:
639
+ dict[str, torch.Tensor] containing:
640
+ - logits:
641
+ Shape (batch_size, num_tokens, vocab_size)
642
+
643
+ - projected_bio_embeddings:
644
+ Shape (batch_size, num_bio_sequences, ?, embed_dim)
645
+ """
646
+ english_token_ids, bio_token_ids = multi_omics_tokens_ids
647
+
648
+ # Replace config.vocab_size value in english tokens
649
+ # We do this because the default vocab size (32000) doesn't match with the
650
+ # number of tokens because of seq_token_id(=32000) that was added
651
+ # Therefore, we will put seq_token_id to 31999
652
+ # (token 31999 itself is remapped to 0, the unknown token)
653
+ # This is a workaround to avoid having to change the vocab size in the config
654
+ vocab_size = self.gpt_config.vocab_size
655
+ # Replace vocab
656
+ english_token_ids[english_token_ids == vocab_size - 1] = 0
657
+ projection_english_tokens_ids[
658
+ projection_english_tokens_ids == vocab_size - 1
659
+ ] = 0
660
+ english_token_ids[english_token_ids == vocab_size] = vocab_size - 1
661
+ projection_english_tokens_ids[projection_english_tokens_ids == vocab_size] = (
662
+ vocab_size - 1
663
+ )
664
+
665
+ if bio_token_ids is None:
666
+ projected_bio_embeddings = None
667
+ else:
668
+ num_bio_sequences = bio_token_ids.shape[1]
669
+
670
+ if projected_bio_embeddings is None:
671
+ # Compute bio sequences embeddings
672
+ bio_embeddings_list = [
673
+ self.biobrain_encoder(bio_token_ids=bio_token_ids[:, bio_seq_num])
674
+ for bio_seq_num in range(num_bio_sequences)
675
+ ]
676
+
677
+ # Project these embeddings
678
+ projected_bio_embeddings = [
679
+ self.projection_model(
680
+ bio_token_ids=bio_token_ids[:, bio_seq_num],
681
+ bio_embeddings=bio_embeddings,
682
+ english_token_ids=projection_english_tokens_ids,
683
+ )
684
+ for bio_seq_num, bio_embeddings in enumerate(bio_embeddings_list)
685
+ ]
686
+ projected_bio_embeddings = torch.stack(projected_bio_embeddings, dim=1)
687
+
688
+ # decode
689
+ logits = self.biobrain_decoder(
690
+ english_token_ids=english_token_ids,
691
+ projected_bio_embeddings=projected_bio_embeddings,
692
+ )
693
+
694
+ outs = {"logits": logits, "projected_bio_embeddings": projected_bio_embeddings}
695
+
696
+ return outs
697
+
698
+
699
+ class TorchRotaryEmbedding(torch.nn.Module):
700
+ def __init__(self, config: RotaryEmbeddingConfig):
701
+ super().__init__()
702
+
703
+ self.max_seq_len = config.max_seq_len
704
+ self.dim = config.dim
705
+ self.theta = config.theta
706
+ self.sincos_cache = self._create_sinusoidal_positions()
707
+
708
+ def _create_sinusoidal_positions(self) -> torch.Tensor:
709
+ """
710
+ Create the sines and cosines for the RoPE.
711
+
712
+ Returns:
713
+ Sinusoidal positions of shape (self.max_seq_len, self.dim).
714
+ """
715
+ # Create the inverse frequency based on theta and dim
716
+ inv_freq = 1.0 / (
717
+ self.theta ** (torch.arange(0, self.dim, 2).float() / self.dim)
718
+ )
719
+
720
+ # Compute sinusoidal input using the broadcasting
721
+ sinusoid_inp = torch.einsum(
722
+ "i,j->ij", torch.arange(self.max_seq_len).float(), inv_freq
723
+ )
724
+
725
+ # Apply sin and cos to the sinusoidal input
726
+ sin, cos = sinusoid_inp.sin(), sinusoid_inp.cos()
727
+
728
+ # Allocate a tensor for the final sin-cos values
729
+ sincos = torch.zeros((self.max_seq_len, self.dim), dtype=torch.float32)
730
+
731
+ # Fill the sincos tensor with sin and cos values
732
+ sentinel = self.dim // 2 + self.dim % 2
733
+ sincos[:, :sentinel] = sin
734
+ sincos[:, sentinel:] = cos
735
+
736
+ return sincos
737
+
738
+ def _rotate_every_two(self, x: torch.Tensor) -> torch.Tensor:
739
+ """
740
+ Prepare a tensor to apply the RoPE mechanism.
741
+
742
+ Args:
743
+ x: Tensor of shape (batch_size, seq_len, num_heads, head_dim),
744
+ typically this is the key or query tensor.
745
+
746
+ Returns:
747
+ Each consecutive pair (x1, x2) along the last dimension becomes (-x2, x1).
748
+ Tensor of shape (batch_size, seq_len, num_heads, head_dim).
749
+ """
750
+ # Split the tensor into two halves (odd and even indexed dimensions)
751
+ rotate_half = torch.stack((-x[..., 1::2], x[..., ::2]), dim=-1)
752
+
753
+ # Reshape the tensor to the original shape
754
+ rotate_half = rotate_half.view(rotate_half.shape[:-2] + (-1,))
755
+ return rotate_half
756
+
757
+ def _apply_rotary_pos_emb(
758
+ self, x: torch.Tensor, sincos: torch.Tensor
759
+ ) -> torch.Tensor:
760
+ """
761
+ Applies rotary embeddings to x.
762
+
763
+ Args:
764
+ x: Tensor of shape (batch_size, seq_len, num_heads, head_dim),
765
+ typically this is the key or query tensor.
766
+ sincos: Tuple of sine and cosine tensors for position encoding.
767
+
768
+ Returns:
769
+ RoPE embeddings tensor.
770
+ """
771
+ sin_pos, cos_pos = sincos
772
+
773
+ # Reshape the sin and cos tensors for broadcasting
774
+ sin_pos = torch.repeat_interleave(sin_pos.unsqueeze(2), repeats=2, dim=-1)
775
+ cos_pos = torch.repeat_interleave(cos_pos.unsqueeze(2), repeats=2, dim=-1)
776
+
777
+ # Apply the rotary embedding mechanism
778
+ return (x * cos_pos) + (self._rotate_every_two(x) * sin_pos)
779
+
780
+ def __call__(
781
+ self, k: torch.Tensor, q: torch.Tensor, positions: Optional[torch.Tensor] = None
782
+ ) -> tuple[torch.Tensor, torch.Tensor]:
783
+ """
784
+ Applies rotary embeddings to k and q.
785
+
786
+ Args:
787
+ k: key tensor of shape (batch_size, seq_len, num_heads, head_dim),
788
+ q: query tensor of shape (batch_size, seq_len, num_heads, head_dim),
789
+ positions: optional positions offset useful when caching,
790
+
791
+ Returns:
792
+ RoPE embeddings for the keys and queries.
793
+ """
794
+ batch_size, seq_len, num_heads, head_dim = k.shape
795
+
796
+ # Generate position ids
797
+ position_ids = (
798
+ torch.arange(seq_len, device=k.device).unsqueeze(0).expand(batch_size, -1)
799
+ )
800
+
801
+ if positions is not None:
802
+ position_ids += positions
803
+
804
+ # Retrieve sincos values using the position_ids
805
+ sincos = self.sincos_cache[position_ids]
806
+
807
+ # Split sincos into sin_pos and cos_pos
808
+ sincos = torch.chunk(sincos, 2, dim=-1)
809
+
810
+ # Apply rotary position embedding to key (k) and query (q)
811
+ k_rot = self._apply_rotary_pos_emb(k[..., : self.dim], sincos)
812
+ k_pass = k[..., self.dim :]
813
+
814
+ q_rot = self._apply_rotary_pos_emb(q[..., : self.dim], sincos)
815
+ q_pass = q[..., self.dim :]
816
+
817
+ # Concatenate the rotated and non-rotated parts
818
+ keys = torch.cat([k_rot, k_pass], dim=-1)
819
+ queries = torch.cat([q_rot, q_pass], dim=-1)
820
+
821
+ return keys, queries
822
+
823
+
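# Illustrative sketch (not part of the uploaded chatNT.py): applying
# TorchRotaryEmbedding to toy key/query tensors. Only the first `dim` channels of
# each head are rotated; the remaining channels pass through unchanged.
rope = TorchRotaryEmbedding(RotaryEmbeddingConfig(max_seq_len=512, dim=8, theta=10000.0))
k = torch.randn(2, 10, 4, 16)  # (batch, seq_len, num_heads, head_dim)
q = torch.randn(2, 10, 4, 16)
k_rot, q_rot = rope(k, q)      # same shapes as the inputs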
824
+ class TorchGptGroupedQueryAttention(nn.Module):
825
+ def __init__(
826
+ self,
827
+ embed_dim: int,
828
+ num_heads: int,
829
+ rope_config: RotaryEmbeddingConfig,
830
+ num_kv_heads: int = None, # type: ignore
831
+ head_dim: int = None, # type: ignore
832
+ add_bias_attn: bool = False, # type: ignore
833
+ ) -> None:
834
+ super().__init__()
835
+ self.num_heads = num_heads
836
+ self.num_kv_heads = num_kv_heads or num_heads
837
+ self.embed_dim = embed_dim
838
+ self.head_dim = head_dim or (embed_dim // num_heads)
839
+ self.add_bias_attn = add_bias_attn
840
+ self.rope = TorchRotaryEmbedding(rope_config)
841
+
842
+ self.query_linear = nn.Linear(
843
+ embed_dim, self.num_heads * self.head_dim, bias=add_bias_attn
844
+ )
845
+ self.key_linear = nn.Linear(
846
+ embed_dim, self.num_kv_heads * self.head_dim, bias=add_bias_attn
847
+ )
848
+ self.value_linear = nn.Linear(
849
+ embed_dim, self.num_kv_heads * self.head_dim, bias=add_bias_attn
850
+ )
851
+ self.out_linear = nn.Linear(
852
+ self.num_heads * self.head_dim, embed_dim, bias=add_bias_attn
853
+ )
854
+
855
+ def forward(
856
+ self,
857
+ query_inputs: torch.Tensor,
858
+ key_inputs: torch.Tensor,
859
+ value_inputs: torch.Tensor,
860
+ attention_mask: torch.Tensor = None,
861
+ ) -> torch.Tensor:
862
+ batch_size, seq_len, _ = query_inputs.shape
863
+
864
+ queries = self.query_linear(query_inputs).view( # noqa
865
+ batch_size, seq_len, self.num_heads, self.head_dim
866
+ )
867
+ keys = self.key_linear(key_inputs).view( # noqa
868
+ batch_size, seq_len, self.num_kv_heads, self.head_dim
869
+ )
870
+ values = self.value_linear(value_inputs).view( # noqa
871
+ batch_size, seq_len, self.num_kv_heads, self.head_dim
872
+ )
873
+
874
+ keys, queries = self.rope(keys, queries)
875
+
876
+ n_rep = self.num_heads // self.num_kv_heads
877
+ keys = keys.repeat_interleave(n_rep, dim=2)
878
+ values = values.repeat_interleave(n_rep, dim=2)
879
+
880
+ attention_logits = torch.einsum("bthd,bThd->bhtT", queries, keys) / (
881
+ self.head_dim**0.5
882
+ )
883
+
884
+ if attention_mask is not None:
885
+ attention_logits = attention_logits.masked_fill(
886
+ attention_mask == 0, float("-inf")
887
+ )
888
+
889
+ attention_weights = nn.functional.softmax(attention_logits, dim=-1)
890
+
891
+ values = torch.einsum("bhtT,bThd->bthd", attention_weights, values)
892
+ values = values.contiguous().view(batch_size, seq_len, -1)
893
+
894
+ return self.out_linear(values)
895
+
896
+
897
+ class TorchGptDecoder(nn.Module):
898
+ def __init__(self, config: GptConfig, name: Optional[str] = None):
899
+ super().__init__()
900
+ self.config = config
901
+
902
+ self.token_embed = nn.Embedding(config.vocab_size, config.embed_dim)
903
+
904
+ if config.norm_type == "layer_norm":
905
+ self.final_norm = nn.LayerNorm(config.embed_dim)
906
+ elif config.norm_type == "RMS_norm":
907
+ self.final_norm = TorchRMSNorm(config.embed_dim, eps=config.rms_norm_eps)
908
+ else:
909
+ raise ValueError(f"unrecognized norm_type in config {config.norm_type}")
910
+
911
+ self.layers = nn.ModuleList(
912
+ [
913
+ TorchGptDecoderLayer(
914
+ embed_dim=config.embed_dim,
915
+ ffn_embed_dim=config.ffn_embed_dim,
916
+ num_heads=config.num_heads,
917
+ rope_config=config.rope_config,
918
+ norm_type=config.norm_type,
919
+ parallel_attention_ff=config.parallel_attention_ff,
920
+ add_bias_ffn=config.add_bias_ffn,
921
+ ffn_activation_name=config.ffn_activation_name,
922
+ use_glu_in_ffn=config.use_glu_in_ffn,
923
+ num_kv_heads=config.num_kv_heads, # type: ignore
924
+ add_bias_attn=config.add_bias_attn,
925
+ rms_norm_eps=config.rms_norm_eps,
926
+ )
927
+ for _ in range(config.num_layers)
928
+ ]
929
+ )
930
+
931
+ self.lm_head = TorchSimpleLMHead(
932
+ embed_dim=config.embed_dim,
933
+ alphabet_size=config.vocab_size,
934
+ add_bias_lm_head=config.add_bias_lm_head,
935
+ )
936
+
937
+ def apply_transformer_layers(
938
+ self, embeddings: torch.Tensor, attention_mask: torch.Tensor = None
939
+ ) -> torch.Tensor:
940
+ if attention_mask is None:
941
+ attention_mask = build_causal_attention_mask(1, embeddings.shape[1])
942
+ for layer in self.layers:
943
+ embeddings = layer(embeddings, attention_mask)
944
+
945
+ return embeddings
946
+
947
+ def forward(
948
+ self, token_ids: torch.Tensor, attention_mask: torch.Tensor = None
949
+ ) -> dict[str, torch.Tensor]:
950
+ if attention_mask is None:
951
+ attention_mask = build_causal_attention_mask(1, token_ids.shape[1])
952
+
953
+ tokens_embeddings = self.token_embed(token_ids)
954
+
955
+ after_transformer_embeddings = self.apply_transformer_layers(
956
+ tokens_embeddings, attention_mask=attention_mask
957
+ )
958
+
959
+ embeddings = self.final_norm(after_transformer_embeddings)
960
+ logits = self.lm_head(embeddings)
961
+ return {"embeddings": embeddings, "logits": logits}
962
+
963
+
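# Illustrative sketch (not part of the uploaded chatNT.py): a forward pass through
# the standalone GPT decoder with a tiny arbitrary configuration. (Assumes the full
# module is loaded, since the layer classes are defined further down in this file.)
tiny_gpt = TorchGptDecoder(GptConfig(vocab_size=100, eos_token_id=3, embed_dim=16, num_heads=2))
out = tiny_gpt(torch.randint(0, 100, (2, 12)))
# out["embeddings"]: (2, 12, 16); out["logits"]: (2, 12, 100)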
964
+ class TorchSimpleLMHead(nn.Module):
965
+ def __init__(
966
+ self, embed_dim: int, alphabet_size: int, add_bias_lm_head: bool = True
967
+ ) -> None:
968
+ super().__init__()
969
+ self.fc = nn.Linear(embed_dim, alphabet_size, bias=add_bias_lm_head)
970
+
971
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
972
+ return self.fc(x)
973
+
974
+
975
+ class TorchGptDecoderLayer(nn.Module):
976
+ def __init__(
977
+ self,
978
+ embed_dim: int,
979
+ ffn_embed_dim: int,
980
+ num_heads: int,
981
+ rope_config: RotaryEmbeddingConfig,
982
+ norm_type: str,
983
+ parallel_attention_ff: bool,
984
+ add_bias_ffn: bool,
985
+ ffn_activation_name: str,
986
+ use_glu_in_ffn: bool,
987
+ num_kv_heads: int,
988
+ add_bias_attn: bool,
989
+ rms_norm_eps: float = 1e-6,
990
+ ) -> None:
991
+ super().__init__()
992
+ self.num_heads = num_heads
993
+ self.parallel_attention_ff = parallel_attention_ff
994
+ self.use_glu_in_ffn = use_glu_in_ffn
995
+
996
+ # Self-Attention layer
997
+ self.self_attn = TorchGptGroupedQueryAttention(
998
+ embed_dim=embed_dim,
999
+ num_heads=num_heads,
1000
+ num_kv_heads=num_kv_heads,
1001
+ rope_config=rope_config,
1002
+ add_bias_attn=add_bias_attn,
1003
+ )
1004
+
1005
+ # Normalization layers
1006
+ if norm_type == "layer_norm":
1007
+ self.attn_norm = nn.LayerNorm(embed_dim)
1008
+ if not self.parallel_attention_ff:
1009
+ self.ffn_norm = nn.LayerNorm(embed_dim)
1010
+ elif norm_type == "RMS_norm":
1011
+ self.attn_norm = TorchRMSNorm(embed_dim, eps=rms_norm_eps)
1012
+ if not self.parallel_attention_ff:
1013
+ self.ffn_norm = TorchRMSNorm(embed_dim, eps=rms_norm_eps)
1014
+ else:
1015
+ raise ValueError(f"unrecognized norm_type: {norm_type}")
1016
+
1017
+ # Feedforward network
1018
+ self.activation = get_activation_fn(ffn_activation_name)
1019
+ ffn_hidden_dim = ffn_embed_dim * (2 if use_glu_in_ffn else 1)
1020
+ self.fc1 = nn.Linear(embed_dim, ffn_hidden_dim, bias=add_bias_ffn)
1021
+ self.fc2 = nn.Linear(ffn_embed_dim, embed_dim, bias=add_bias_ffn)
1022
+
1023
+ def forward(
1024
+ self, embeddings: torch.Tensor, attention_mask: torch.Tensor
1025
+ ) -> torch.Tensor:
1026
+ residuals = embeddings
1027
+
1028
+ if self.parallel_attention_ff:
1029
+ # Parallel Attention + MLP
1030
+ embeddings_normed = self.attn_norm(embeddings)
1031
+
1032
+ attn_output, _ = self.self_attn(
1033
+ embeddings_normed,
1034
+ embeddings_normed,
1035
+ embeddings_normed,
1036
+ attn_mask=attention_mask,
1037
+ )
1038
+ ffn_output = self.mlp(embeddings_normed) # type: ignore
1039
+
1040
+ return residuals + attn_output + ffn_output
1041
+ else:
1042
+ # Sequential Attention + MLP
1043
+ normed_embeddings = self.attn_norm(embeddings)
1044
+
1045
+ attn_output = embeddings + self.self_attn(
1046
+ normed_embeddings,
1047
+ normed_embeddings,
1048
+ normed_embeddings,
1049
+ attention_mask=attention_mask,
1050
+ )
1051
+
1052
+ normed_embeddings2 = self.ffn_norm(attn_output)
1053
+ ffn_output = self.mlp(normed_embeddings2) # type: ignore
1054
+ return attn_output + ffn_output # Residual connection
1055
+
1056
+ def mlp(self, x: torch.Tensor) -> torch.Tensor:
1057
+ """Applies the feedforward network (MLP) with optional GLU."""
1058
+ ffn_output = self.fc1(x)
1059
+
1060
+ if self.use_glu_in_ffn:
1061
+ ffn_output1, ffn_output2 = ffn_output.chunk(2, dim=-1)
1062
+ ffn_output = self.activation(ffn_output1) * ffn_output2
1063
+ else:
1064
+ ffn_output = self.activation(ffn_output)
1065
+
1066
+ return self.fc2(ffn_output)
1067
+
1068
+
1069
+ class TorchRMSNorm(nn.Module):
1070
+ def __init__(self, dim: int, eps: float = 1e-6) -> None:
1071
+ super().__init__()
1072
+ self.eps = eps
1073
+ self.scale = nn.Parameter(torch.ones(dim))
1074
+
1075
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
1076
+ return (
1077
+ x
1078
+ * self.scale
1079
+ / torch.sqrt(torch.mean(x**2, dim=-1, keepdim=True) + self.eps)
1080
+ )
1081
+
1082
+
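# Illustrative sketch (not part of the uploaded chatNT.py): TorchRMSNorm computes
# x * scale / sqrt(mean(x**2) + eps), i.e. RMS normalization without mean-centering.
rms = TorchRMSNorm(dim=16)
normed = rms(torch.randn(2, 10, 16))  # same shape; per-position RMS is ~1 at init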
1083
+ def get_activation_fn(activation_name: str): # type: ignore
1084
+ activations = {
1085
+ "gelu": nn.functional.gelu,
1086
+ "relu": nn.functional.relu,
1087
+ "swish": nn.functional.silu,
1088
+ "silu": nn.functional.silu,
1089
+ }
1090
+ return activations.get(activation_name, nn.functional.relu)
1091
+
1092
+
1093
+ def build_causal_attention_mask(batch_size: int, seq_len: int) -> torch.Tensor:
1094
+ """
1095
+ Builds a batch of causal masks of shape (batch_size, 1, seq_len, seq_len) to feed
1096
+ to an attention layer.
1097
+
1098
+ Args:
1099
+ batch_size: Batch size.
1100
+ seq_len: Length of the sequences.
1101
+
1102
+ Returns:
1103
+ Batch of causal masks.
1104
+ """
1105
+ mask = torch.ones((batch_size, 1, seq_len, seq_len))
1106
+ causal_mask = torch.tril(mask)
1107
+ return causal_mask
1108
+
1109
+
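# Illustrative sketch (not part of the uploaded chatNT.py): for seq_len=4 the causal
# mask is lower triangular, so position t can only attend to positions <= t.
causal = build_causal_attention_mask(batch_size=1, seq_len=4)
# causal[0, 0] ==
# tensor([[1., 0., 0., 0.],
#         [1., 1., 0., 0.],
#         [1., 1., 1., 0.],
#         [1., 1., 1., 1.]])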
1110
+ @dataclass
1111
+ class RotaryEmbeddingConfigBis:
1112
+ """
1113
+ Parameters to initialize the RotaryEmbedding layer. The rescaling factor allows
1114
+ to adapt the rotary embeddings to larger lengths than what was used for training.
1115
+ One such strategy is presented in the YaRN paper: https://arxiv.org/pdf/2309.00071.pdf. # noqa
1116
+ Args:
+ rescaling_factor: Scaling factor applied to the rotary frequencies to handle
+ sequences longer than those seen during training (None disables rescaling).
1117
+ """
1118
+
1119
+ rescaling_factor: Optional[float]
1120
+
1121
+
1122
+ class RotaryEmbeddingBis(torch.nn.Module):
1123
+ """
1124
+ Rotary position embeddings based on those in
1125
+ [RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer).
1126
+ Query and keys are transformed by rotation
1127
+ matrices which depend on their relative positions.
1128
+ """
1129
+
1130
+ def __init__(self, dim: int, rotary_embedding_config: RotaryEmbeddingConfigBis):
1131
+ super().__init__()
1132
+
1133
+ # Extract argument from the config
1134
+ self.rescaling_factor = rotary_embedding_config.rescaling_factor
1135
+ self.upper_freq = 10000
1136
+ self.dim = dim
1137
+
1138
+ self._seq_len_cached = None
1139
+ self._cos_cached = None
1140
+ self._sin_cached = None
1141
+
1142
+ def _apply_rotary_pos_emb(
1143
+ self,
1144
+ heads: torch.Tensor,
1145
+ cos: torch.Tensor,
1146
+ sin: torch.Tensor,
1147
+ ) -> torch.Tensor:
1148
+ """ """
1149
+ x_first, x_second = (
1150
+ heads[..., : heads.shape[-1] // 2],
1151
+ heads[..., heads.shape[-1] // 2 :],
1152
+ )
1153
+
1154
+ first_part = x_first * cos - x_second * sin
1155
+ second_part = x_second * cos + x_first * sin
1156
+
1157
+ return torch.cat((first_part, second_part), dim=-1)
1158
+
1159
+ def _compute_cos_sin_tables(
1160
+ self, x: torch.Tensor, inv_freq: torch.Tensor, seq_dimension: int = 2
1161
+ ) -> tuple[torch.Tensor, torch.Tensor]:
1162
+ seq_len = x.shape[seq_dimension]
1163
+ # Reset the tables if the sequence length has changed,
1164
+ # or if we're on a new device (possibly due to tracing for instance)
1165
+ self._seq_len_cached = seq_len
1166
+ t = torch.arange(x.shape[seq_dimension], device=x.device).type_as(inv_freq)
1167
+ # freqs = torch.outer(t, inv_freq)
1168
+ freqs = torch.einsum("i, j -> ij", t, inv_freq)
1169
+
1170
+ self._cos_cached = torch.cos(freqs)[None, :, None, :]
1171
+ self._sin_cached = torch.sin(freqs)[None, :, None, :]
1172
+ # emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
1173
+
1174
+ # self._cos_cached = emb.cos()[None, None, :, :]
1175
+ # self._sin_cached = emb.sin()[None, None, :, :]
1176
+
1177
+ return self._cos_cached, self._sin_cached
1178
+
1179
+ def forward(
1180
+ self, q: torch.Tensor, k: torch.Tensor
1181
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
1182
+ if self.rescaling_factor is None:
1183
+ inv_freq = 1.0 / (
1184
+ self.upper_freq ** (torch.arange(0, self.dim, 2).float() / self.dim)
1185
+ )
1186
+ else:
1187
+ updated_base = self.upper_freq * (
1188
+ self.rescaling_factor ** (self.dim / (self.dim - 2))
1189
+ )
1190
+ inv_freq = 1.0 / (
1191
+ updated_base ** (torch.arange(0, self.dim, 2).float() / self.dim)
1192
+ )
1193
+
1194
+ self._cos_cached, self._sin_cached = self._compute_cos_sin_tables(
1195
+ q,
1196
+ inv_freq,
1197
+ seq_dimension=-3,
1198
+ )
1199
+
1200
+ return (
1201
+ self._apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached),
1202
+ self._apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached),
1203
+ )
1204
+
1205
+
1206
+ class MultiHeadAttention(nn.Module):
1207
+ def __init__(
1208
+ self,
1209
+ num_heads: int,
1210
+ key_size: int,
1211
+ rotary_embedding_config: Optional[RotaryEmbeddingConfigBis] = None,
1212
+ add_bias_kv: bool = False,
1213
+ value_size: Optional[int] = None,
1214
+ model_size: Optional[int] = None,
1215
+ name: Optional[str] = None,
1216
+ ):
1217
+ super().__init__()
1218
+ if not model_size:
1219
+ model_size = key_size * num_heads
1220
+ if not value_size:
1221
+ value_size = key_size
1222
+ self.model_size = model_size
1223
+ self.key_size = key_size
1224
+ self.value_size = value_size
1225
+ self.add_bias_kv = add_bias_kv
1226
+ self.name = name
1227
+ self.num_heads = num_heads
1228
+ self._rotary_embedding_config = rotary_embedding_config
1229
+
1230
+ self.w_k = nn.Linear(self.model_size, self.num_heads * self.key_size)
1231
+ self.w_q = nn.Linear(self.model_size, self.num_heads * self.key_size)
1232
+ self.w_v = nn.Linear(self.model_size, self.num_heads * self.value_size)
1233
+ self.output = nn.Linear(self.num_heads * self.value_size, self.model_size)
1234
+ if self._rotary_embedding_config:
1235
+ self._rotary_embedding = RotaryEmbeddingBis(
1236
+ self.key_size, self._rotary_embedding_config
1237
+ )
1238
+
1239
+ def apply_rotary_embeddings(
1240
+ self,
1241
+ query: torch.Tensor,
1242
+ key: torch.Tensor,
1243
+ ) -> tuple[torch.Tensor, torch.Tensor]:
1244
+ """ """
1245
+ query, key = self._rotary_embedding(query, key)
1246
+ return query, key
1247
+
1248
+ def forward(
1249
+ self,
1250
+ query: torch.Tensor,
1251
+ key: torch.Tensor,
1252
+ value: torch.Tensor,
1253
+ attention_mask: Optional[torch.Tensor] = None,
1254
+ attention_weight_bias: Optional[torch.Tensor] = None,
1255
+ ) -> dict[str, torch.Tensor]:
1256
+ """
1257
+ Returns:
1258
+ dictionary containing attention weights
1259
+ and outputs.
1260
+ """
1261
+ key_heads = self.w_k(key).reshape(
1262
+ (*key.shape[:-1], self.num_heads, self.key_size)
1263
+ )
1264
+ query_heads = self.w_q(query).reshape(
1265
+ (*query.shape[:-1], self.num_heads, self.key_size)
1266
+ )
1267
+ value_heads = self.w_v(value).reshape(
1268
+ (*value.shape[:-1], self.num_heads, self.value_size)
1269
+ )
1270
+ if self._rotary_embedding_config:
1271
+ query_heads, key_heads = self.apply_rotary_embeddings(
1272
+ query_heads, key_heads
1273
+ )
1274
+ attention_weights = torch.einsum(
1275
+ "...thd, ...Thd -> ...htT", query_heads, key_heads
1276
+ )
1277
+ sqrt_key_size = np.sqrt(self.key_size)
1278
+ attention_weights = attention_weights / sqrt_key_size
1279
+ if attention_mask is not None:
1280
+ attention_weights = torch.where(attention_mask, attention_weights, -1e30)
1281
+ if attention_weight_bias is not None:
1282
+ attention_weights = F.softmax(
1283
+ attention_weights + attention_weight_bias, dim=-1
1284
+ )
1285
+ else:
1286
+ attention_weights = F.softmax(attention_weights, dim=-1)
1287
+ value_out = torch.einsum(
1288
+ "...htT, ...Thd->...thd", attention_weights, value_heads
1289
+ )
1290
+ value_out = value_out.reshape((*value_out.shape[:-2], -1))
1291
+ embeddings = self.output(value_out)
1292
+
1293
+ return {"attention_weights": attention_weights, "embeddings": embeddings}
1294
+
1295
+
1296
+ class SelfAttentionBlock(nn.Module):
1297
+ def __init__(
1298
+ self,
1299
+ num_heads: int,
1300
+ embed_dim: int,
1301
+ ffn_embed_dim: int,
1302
+ key_size: Optional[int] = None,
1303
+ add_bias_kv: bool = False,
1304
+ add_bias_fnn: bool = True,
1305
+ ffn_activation_name: str = "gelu-no-approx",
1306
+ use_glu_in_ffn: bool = False,
1307
+ layer_norm_eps: float = 1e-5, # this is the default haiku value
1308
+ pre_layer_norm: bool = True,
1309
+ name: Optional[str] = None,
1310
+ rotary_embedding_config: Optional[RotaryEmbeddingConfigBis] = None,
1311
+ ):
1312
+ super().__init__()
1313
+ if key_size is None:
1314
+ if embed_dim % num_heads != 0:
1315
+ raise ValueError(
1316
+ f"The embedding dimension should be divisible by the number of "
1317
+ f"heads, however provided embedding dimension is {embed_dim} and "
1318
+ f"the number of heads is {num_heads}."
1319
+ )
1320
+ else:
1321
+ key_size = embed_dim // num_heads
1322
+
1323
+ # Get ffn activation function
1324
+ self._pre_layer_norm = pre_layer_norm
1325
+ self._use_glu_in_fnn = use_glu_in_ffn
1326
+ # Define layers
1327
+ if use_glu_in_ffn:
1328
+ # user should multiply ffn_embed_dim by 2/3 when using GLU
1329
+ # to keep total number of parameters equal
1330
+ # see https://arxiv.org/pdf/2002.05202.pdf. for more details
1331
+ # we multiply by 2 here as the output will be split in 2 for GLU
1332
+ self.fc1 = nn.Linear(embed_dim, int(2 * ffn_embed_dim), bias=add_bias_fnn)
1333
+ else:
1334
+ self.fc1 = nn.Linear(embed_dim, ffn_embed_dim, bias=add_bias_fnn)
1335
+
1336
+ self.fc2 = nn.Linear(ffn_embed_dim, embed_dim, bias=add_bias_fnn)
1337
+
1338
+ self.layer_norm_self_attention = nn.LayerNorm(
1339
+ embed_dim,
1340
+ )
1341
+ self.layer_norm_mlp = nn.LayerNorm(embed_dim)
1342
+ if ffn_activation_name == "swish":
1343
+ self._ffn_activation_fn = nn.SiLU()
1344
+ elif ffn_activation_name == "gelu-no-approx":
1345
+ self._ffn_activation_fn = nn.GELU(approximate="tanh")
1346
+ else:
1347
+ self._ffn_activation_fn = getattr(torch.nn, ffn_activation_name)
1348
+
1349
+ self.mha = MultiHeadAttention(
1350
+ num_heads=num_heads,
1351
+ key_size=key_size,
1352
+ add_bias_kv=add_bias_kv,
1353
+ model_size=embed_dim,
1354
+ name="self_attention",
1355
+ rotary_embedding_config=rotary_embedding_config,
1356
+ )
1357
+
1358
+ def mlp(self, embed: torch.Tensor) -> torch.Tensor:
1359
+
1360
+ if self._pre_layer_norm:
1361
+ x = self.layer_norm_mlp(embed)
1362
+ else:
1363
+ x = embed
1364
+
1365
+ if self._use_glu_in_fnn:
1366
+ x = self.fc1(x)
1367
+ x1, x2 = torch.split(x, split_size_or_sections=x.shape[-1] // 2, dim=-1)
1368
+ x = self._ffn_activation_fn(x1) * x2
1369
+ else:
1370
+ x = self._ffn_activation_fn(self.fc1(x))
1371
+ x = self.fc2(x)
1372
+
1373
+ if not self._pre_layer_norm:
1374
+ x = self.layer_norm_mlp(x + embed)
1375
+ return x
1376
+
1377
+ def forward(
1378
+ self,
1379
+ x: torch.Tensor,
1380
+ attention_mask: Optional[torch.Tensor] = None,
1381
+ attention_weight_bias: Optional[torch.Tensor] = None,
1382
+ ) -> dict[str, torch.Tensor]:
1383
+
1384
+ res = x
1385
+ if self._pre_layer_norm:
1386
+ x = self.layer_norm_self_attention(x)
1387
+
1388
+ output: dict[str, torch.Tensor] = self.mha(
1389
+ x,
1390
+ x,
1391
+ x,
1392
+ attention_mask=attention_mask,
1393
+ attention_weight_bias=attention_weight_bias,
1394
+ )
1395
+
1396
+ if not self._pre_layer_norm:
1397
+ output["embeddings"] = self.layer_norm_self_attention(
1398
+ output["embeddings"] + res
1399
+ )
1400
+
1401
+ x = output["embeddings"]
1402
+ else:
1403
+ x = output["embeddings"]
1404
+ x = res + x
1405
+
1406
+ # MLP
1407
+ if not self._pre_layer_norm:
1408
+ x = self.mlp(x)
1409
+ else:
1410
+ x = x + self.mlp(x)
1411
+
1412
+ output["embeddings"] = x
1413
+ return output
1414
+
1415
+
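# Illustrative sketch (not part of the uploaded chatNT.py): one ESM-style
# self-attention block over a toy embedding tensor (arbitrary sizes). With no mask,
# attention is fully bidirectional.
block = SelfAttentionBlock(num_heads=4, embed_dim=64, ffn_embed_dim=256)
x = torch.randn(2, 10, 64)
out = block(x)
# out["embeddings"]: (2, 10, 64); out["attention_weights"]: (2, 4, 10, 10)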
1416
+ class RobertaLMHead(nn.Module):
1417
+ """
1418
+ Roberta Language Model head. Transforms final attention layer output into a
1419
+ distribution over tokens at each position.
1420
+ """
1421
+
1422
+ def __init__(self, embed_dim: int, alphabet_size: int):
1423
+ """
1424
+ Args:
1425
+ embed_dim: Embedding dimension.
1426
+ alphabet_size: Number of tokens in the alphabet.
1427
+ """
1428
+ super().__init__()
1429
+ self.embed_dim = embed_dim
1430
+ self.alphabet_size = alphabet_size
1431
+
1432
+ # Define layers
1433
+ self._first_layer_norm = nn.LayerNorm(embed_dim, elementwise_affine=True)
1434
+ self._fc1 = nn.Linear(embed_dim, embed_dim)
1435
+ self._second_layer_norm = nn.LayerNorm(embed_dim, elementwise_affine=True)
1436
+ self._final_fc = nn.Linear(embed_dim, alphabet_size)
1437
+
1438
+ def forward(self, x: torch.Tensor) -> dict:
1439
+ x = self._first_layer_norm(x)
1440
+ embeddings = x
1441
+ x = self._fc1(x)
1442
+ x = nn.functional.gelu(x)
1443
+ x = self._second_layer_norm(x)
1444
+ logits = self._final_fc(x)
1445
+ return {"embeddings": embeddings, "logits": logits}
1446
+
1447
+
1448
+ class TorchESMTransformer(nn.Module):
1449
+ def __init__(
1450
+ self,
1451
+ esm_config: ESMTransformerConfig,
1452
+ ):
1453
+ super(TorchESMTransformer, self).__init__()
1454
+ self.esm_config = esm_config
1455
+
1456
+ # Other cases are not implemented
1457
+ assert esm_config.positional_embedding is None
1458
+ assert esm_config.lm_head == "roberta"
1459
+ assert esm_config.use_rotary_embedding is True
1460
+ assert esm_config.token_dropout is False
1461
+ assert esm_config.emb_layer_norm_before is False
1462
+ assert esm_config.mask_before_attention is False
1463
+ assert esm_config.bias_word_embedding is False
1464
+ assert esm_config.use_gradient_checkpointing is False
1465
+
1466
+ self.embed_layer = nn.Embedding(esm_config.alphabet_size, esm_config.embed_dim)
1467
+
1468
+ self.lm_head = RobertaLMHead(
1469
+ embed_dim=esm_config.embed_dim,
1470
+ alphabet_size=esm_config.alphabet_size,
1471
+ )
1472
+
1473
+ self.rotary_embedding_config = RotaryEmbeddingConfigBis(
1474
+ rescaling_factor=esm_config.rescaling_factor
1475
+ )
1476
+
1477
+ self.attention_blocks = nn.ModuleList(
1478
+ [
1479
+ SelfAttentionBlock( # type: ignore
1480
+ num_heads=esm_config.attention_heads,
1481
+ embed_dim=esm_config.embed_dim,
1482
+ key_size=esm_config.key_size,
1483
+ ffn_embed_dim=esm_config.ffn_embed_dim,
1484
+ add_bias_kv=esm_config.add_bias_kv,
1485
+ add_bias_fnn=esm_config.add_bias_ffn,
1486
+ ffn_activation_name=esm_config.ffn_activation_name,
1487
+ use_glu_in_ffn=esm_config.use_glu_in_ffn,
1488
+ rotary_embedding_config=self.rotary_embedding_config,
1489
+ layer_norm_eps=esm_config.layer_norm_eps,
1490
+ pre_layer_norm=esm_config.pre_layer_norm,
1491
+ )
1492
+ for _ in range(esm_config.num_layers)
1493
+ ]
1494
+ )
1495
+
1496
+ def forward(
1497
+ self, tokens: torch.Tensor, attention_mask: torch.Tensor = None
1498
+ ) -> torch.Tensor:
1499
+ """
1500
+ Computes the embeddings based on the input tokens.
1501
+
1502
+ Args:
1503
+ tokens: Input tokens out of the tokenizer of shape (batch_size, seq_len).
1504
+ attention_mask: Attention mask of shape (batch_size, 1, seq_len, seq_len).
1505
+ If no mask is provided, a default mask is computed, equal to 1 over all
1506
+ non-pad tokens and 0 over pad tokens.
1507
+
1508
+ Returns:
1509
+ Final embeddings of shape (batch_size, seq_len, embed_dim), taken from the LM head output.
1510
+ """
1511
+ x = self.embed_layer(tokens)
1512
+
1513
+ # RoBERTa-style embedding scaling factor (embed_scale)
1514
+ x = self.esm_config.embed_scale * x
1515
+
1516
+ if attention_mask is None:
1517
+ attention_mask = build_padding_attention_mask(
1518
+ tokens=tokens, pad_token_id=self.esm_config.pad_token_id
1519
+ )
1520
+
1521
+ for layer in self.attention_blocks:
1522
+ x = layer(x, attention_mask)["embeddings"]
1523
+
1524
+ assert self.esm_config.lm_head == "roberta"
1525
+ x = self.lm_head(x)["embeddings"]
1526
+
1527
+ return x
1528
+
1529
+
1530
+ def build_padding_attention_mask(
1531
+ tokens: torch.Tensor, pad_token_id: int
1532
+ ) -> torch.Tensor:
1533
+ """
1534
+ Builds a padding mask from a sequence of tokens by masking <pad> in the attention.
1535
+
1536
+ Args:
1537
+ tokens: Batch of sequences of shape (batch_size, seq_len).
1538
+ pad_token_id: Int corresponding to the <pad> token to mask.
1539
+
1540
+ Returns:
1541
+ Batch of attention masks, masking out <pad> tokens.
1542
+ """
1543
+ padding_mask = tokens != pad_token_id
1544
+ padding_mask = padding_mask.unsqueeze(1)
1545
+ padding_mask = torch.einsum("bhT, bht -> bhtT", padding_mask, padding_mask)
1546
+ return padding_mask
1547
+
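For clarity, the einsum above is simply a pairwise AND of the per-position keep mask with itself; an equivalent boolean formulation (a sketch, not the committed code):

import torch

tokens = torch.tensor([[5, 6, 1, 1]])           # pad_token_id = 1, as in config.json
keep = (tokens != 1).unsqueeze(1)               # (batch, 1, seq_len)
mask = keep.unsqueeze(-1) & keep.unsqueeze(-2)  # (batch, 1, seq_len, seq_len)
print(mask[0, 0].int())
# tensor([[1, 1, 0, 0],
#         [1, 1, 0, 0],
#         [0, 0, 0, 0],
#         [0, 0, 0, 0]])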
1548
+
1549
+ class TorchBioBrainEncoder(nn.Module):
1550
+ def __init__(
1551
+ self,
1552
+ esm_config: ESMTransformerConfig,
1553
+ ):
1554
+ super(TorchBioBrainEncoder, self).__init__()
1555
+ self.esm_config = esm_config
1556
+ self.esm_model = TorchESMTransformer(self.esm_config)
1557
+
1558
+ def forward(
1559
+ self,
1560
+ bio_token_ids: torch.Tensor,
1561
+ ) -> torch.Tensor:
1562
+ """
1563
+ Args:
1564
+ bio_token_ids (torch.Tensor):
1565
+ Shape (batch_size, num_bio_tokens)
1566
+
1567
+ Returns:
1568
+ torch.Tensor:
1569
+ Shape (batch_size, num_bio_tokens, embed_dim)
1570
+ """
1571
+ bio_embeddings = self.esm_model(tokens=bio_token_ids)
1572
+
1573
+ return bio_embeddings
1574
+
1575
+
1576
+ class TorchMultiModalPerceiverResamplerBlock(nn.Module):
1577
+ def __init__(
1578
+ self,
1579
+ num_heads: int,
1580
+ embed_dim: int,
1581
+ ffn_embed_dim: int,
1582
+ key_size: Optional[int] = None,
1583
+ add_bias_kv: bool = False,
1584
+ add_bias_ffn: bool = True,
1585
+ ffn_activation_name: str = "gelu",
1586
+ use_glu_in_ffn: bool = False,
1587
+ ):
1588
+ super().__init__()
1589
+
1590
+ if key_size is None:
1591
+ if embed_dim % num_heads != 0:
1592
+ raise ValueError(
1593
+ f"Embedding dimension {embed_dim} should be divisible by "
1594
+ f"num_heads {num_heads}."
1595
+ )
1596
+ key_size = embed_dim // num_heads
1597
+
1598
+ self.num_heads = num_heads
1599
+ self.embed_dim = embed_dim
1600
+ self.ffn_embed_dim = ffn_embed_dim * 2 if use_glu_in_ffn else ffn_embed_dim
1601
+ self.use_glu_in_ffn = use_glu_in_ffn
1602
+
1603
+ self.cross_attention_1 = MultiHeadAttention(
1604
+ num_heads=num_heads, key_size=key_size, add_bias_kv=add_bias_kv
1605
+ )
1606
+ self.cross_attention_2 = MultiHeadAttention(
1607
+ num_heads=num_heads, key_size=key_size, add_bias_kv=add_bias_kv
1608
+ )
1609
+
1610
+ self.norm_cross_attention_1 = nn.LayerNorm(embed_dim)
1611
+ self.norm_cross_attention_2 = nn.LayerNorm(embed_dim)
1612
+ self.norm_mlp = nn.LayerNorm(embed_dim)
1613
+
1614
+ self.fc1 = nn.Linear(embed_dim, self.ffn_embed_dim, bias=add_bias_ffn)
1615
+ self.fc2 = nn.Linear(self.ffn_embed_dim, embed_dim, bias=add_bias_ffn)
1616
+
1617
+ self.activation_fn = getattr(
1618
+ nn.functional, ffn_activation_name, nn.functional.gelu
1619
+ )
1620
+
1621
+ def mlp(self, x: torch.Tensor) -> torch.Tensor:
1622
+ x = self.norm_mlp(x)
1623
+ if self.use_glu_in_ffn:
1624
+ x1, x2 = torch.chunk(self.fc1(x), 2, dim=-1)
1625
+ x = self.activation_fn(x1) * x2
1626
+ else:
1627
+ x = self.activation_fn(self.fc1(x))
1628
+ return self.fc2(x)
1629
+
1630
+ def forward(
1631
+ self,
1632
+ x: torch.Tensor,
1633
+ cross_attention_embeddings_1: torch.Tensor,
1634
+ cross_attention_embeddings_2: torch.Tensor,
1635
+ attention_mask_1: Optional[torch.Tensor] = None,
1636
+ attention_mask_2: Optional[torch.Tensor] = None,
1637
+ ) -> Dict[str, torch.Tensor]:
1638
+ res = x
1639
+ x = self.norm_cross_attention_1(x)
1640
+
1641
+ attn_output = self.cross_attention_1(
1642
+ query=x,
1643
+ key=cross_attention_embeddings_1,
1644
+ value=cross_attention_embeddings_1,
1645
+ attention_mask=attention_mask_1,
1646
+ )["embeddings"]
1647
+ x = res + attn_output
1648
+
1649
+ res = x
1650
+ x = self.norm_cross_attention_2(x)
1651
+ attn_output = self.cross_attention_2(
1652
+ query=x,
1653
+ key=cross_attention_embeddings_2,
1654
+ value=cross_attention_embeddings_2,
1655
+ attention_mask=attention_mask_2,
1656
+ )["embeddings"]
1657
+ x = res + attn_output
1658
+
1659
+ x = x + self.mlp(x)
1660
+
1661
+ return {"embeddings": x}
1662
+
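Each block refines the latent queries twice, first against one stream and then against the other, before the feed-forward update. A usage sketch with hypothetical small sizes, assuming the MultiHeadAttention module defined earlier in this file accepts attention_mask=None:

import torch

block = TorchMultiModalPerceiverResamplerBlock(num_heads=2, embed_dim=8, ffn_embed_dim=16)
latents = torch.randn(1, 4, 8)     # latent queries for both cross-attentions
stream_1 = torch.randn(1, 10, 8)   # e.g. projected bio embeddings (keys/values of cross-attention 1)
stream_2 = torch.randn(1, 6, 8)    # e.g. English token embeddings (keys/values of cross-attention 2)
out = block(latents, stream_1, stream_2)["embeddings"]
print(out.shape)                   # torch.Size([1, 4, 8]): the latent length is preserved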
1663
+
1664
+ class TorchMultiModalPerceiverResampler(nn.Module):
1665
+ """
1666
+ Perceiver Resampler model, made of successive PerceiverResamplerBlocks.
1667
+ """
1668
+
1669
+ def __init__(
1670
+ self,
1671
+ config: PerceiverResamplerConfig,
1672
+ name: Optional[str] = None,
1673
+ ):
1674
+ """
1675
+ Initialize a Perceiver Resampler model.
1676
+
1677
+ Args:
1678
+ config: Dataclass containing model hyperparameters.
1679
+ name: Name for the module (a custom name will break weight loading).
1680
+ """
1681
+ super().__init__()
1682
+ self.config = config
1683
+ self.name = name
1684
+ self.layers = nn.ModuleList(
1685
+ [
1686
+ TorchMultiModalPerceiverResamplerBlock(
1687
+ num_heads=self.config.attention_heads,
1688
+ embed_dim=self.config.embed_dim,
1689
+ key_size=self.config.key_size,
1690
+ ffn_embed_dim=self.config.ffn_embed_dim,
1691
+ add_bias_kv=self.config.add_bias_kv,
1692
+ add_bias_ffn=self.config.add_bias_ffn,
1693
+ ffn_activation_name=self.config.ffn_activation_name,
1694
+ use_glu_in_ffn=self.config.use_glu_in_ffn,
1695
+ )
1696
+ for _ in range(self.config.num_layers)
1697
+ ]
1698
+ )
1699
+
1700
+ self.latent_queries = torch.nn.Parameter(
1701
+ torch.randn(self.config.resampled_length, self.config.embed_dim)
1702
+ * (
1703
+ 1.0
1704
+ / torch.sqrt(torch.tensor(self.config.embed_dim, dtype=torch.float32))
1705
+ )
1706
+ )
1707
+
1708
+ def apply_attention_blocks(
1709
+ self,
1710
+ x: torch.Tensor,
1711
+ xf_1: torch.Tensor,
1712
+ xf_2: torch.Tensor,
1713
+ outs: Dict[str, torch.Tensor],
1714
+ attention_mask_1: Optional[torch.Tensor] = None,
1715
+ attention_mask_2: Optional[torch.Tensor] = None,
1716
+ ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
1717
+ """
1718
+ Applies the attention blocks in sequence, concatenating the latent queries to each cross-attended input stream.
1719
+ """
1720
+ for layer in self.layers:
1721
+ concat_input_1 = torch.cat([xf_1, x], dim=1)
1722
+ concat_input_2 = torch.cat([xf_2, x], dim=1)
1723
+
1724
+ output = layer(
1725
+ x=x,
1726
+ cross_attention_embeddings_1=concat_input_1,
1727
+ cross_attention_embeddings_2=concat_input_2,
1728
+ attention_mask_1=attention_mask_1,
1729
+ attention_mask_2=attention_mask_2,
1730
+ )
1731
+ x = output["embeddings"]
1732
+
1733
+ return x, outs
1734
+
1735
+ def forward(
1736
+ self,
1737
+ input_embeddings_1: torch.Tensor,
1738
+ input_embeddings_2: torch.Tensor,
1739
+ attention_mask_1: Optional[torch.Tensor] = None,
1740
+ attention_mask_2: Optional[torch.Tensor] = None,
1741
+ ) -> Dict[str, torch.Tensor]:
1742
+ """
1743
+ Computes resampled embeddings from the two input embedding sequences.
1744
+ """
1745
+ assert (
1746
+ input_embeddings_1.shape[-1] == self.config.embed_dim
1747
+ ), "The input embedding dim should match the model embed dim"
1748
+ assert (
1749
+ input_embeddings_2.shape[-1] == self.config.embed_dim
1750
+ ), "The input embedding dim should match the model embed dim"
1751
+
1752
+ batch_size = input_embeddings_1.shape[0]
1753
+
1754
+ latent_queries = self.latent_queries.unsqueeze(0).repeat(batch_size, 1, 1)
1755
+
1756
+ outs: Dict[str, torch.Tensor] = {}
1757
+ x = latent_queries
1758
+
1759
+ x, outs = self.apply_attention_blocks(
1760
+ x=x,
1761
+ xf_1=input_embeddings_1,
1762
+ xf_2=input_embeddings_2,
1763
+ outs=outs,
1764
+ attention_mask_1=attention_mask_1,
1765
+ attention_mask_2=attention_mask_2,
1766
+ )
1767
+
1768
+ outs["embeddings"] = x
1769
+
1770
+ return outs
1771
+
1772
+
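The resampler always emits a fixed number of vectors regardless of input length: the learned latent queries are initialised with a 1/sqrt(embed_dim) scale and broadcast per batch. A minimal sketch of that initialisation, using the sizes from perceiver_resampler_config in config.json (illustration only):

import torch

resampled_length, embed_dim = 64, 4096
latent_queries = torch.randn(resampled_length, embed_dim) / embed_dim ** 0.5
print(latent_queries.shape)   # torch.Size([64, 4096])
# In forward(), these learned queries are broadcast to (batch_size, 64, 4096)
# and refined by the cross-attention blocks against both input streams.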
1773
+ class TorchMultiModalPerceiverResamplerProjection(nn.Module):
1774
+ def __init__(
1775
+ self,
1776
+ perceiver_resampler_config: PerceiverResamplerConfig,
1777
+ input_embed_dim: int,
1778
+ embed_dim: int,
1779
+ bio_pad_token_id: int,
1780
+ english_pad_token_id: int,
1781
+ english_vocab_size: int,
1782
+ ):
1783
+ super().__init__()
1784
+ self.config = perceiver_resampler_config
1785
+ self.input_embed_dim = input_embed_dim
1786
+ self.embed_dim = embed_dim
1787
+ self.bio_pad_token_id = bio_pad_token_id
1788
+ self.english_pad_token_id = english_pad_token_id
1789
+ self.english_vocab_size = english_vocab_size
1790
+
1791
+ self.bio_projection = nn.Linear(input_embed_dim, embed_dim)
1792
+ self.token_embedding = nn.Embedding(english_vocab_size, embed_dim)
1793
+ self.perceiver_resampler = TorchMultiModalPerceiverResampler(config=self.config)
1794
+
1795
+ def forward(
1796
+ self,
1797
+ bio_token_ids: torch.Tensor,
1798
+ bio_embeddings: torch.Tensor,
1799
+ english_token_ids: torch.Tensor,
1800
+ ) -> torch.Tensor:
1801
+ """
1802
+ Args:
1803
+ bio_token_ids (torch.Tensor):
1804
+ Shape (batch_size, num_bio_tokens)
1805
+
1806
+ bio_embeddings (torch.Tensor):
1807
+ Shape (batch_size, num_bio_tokens, embed_dim)
1808
+
1809
+ english_token_ids (torch.Tensor):
1810
+ Shape (batch_size, num_english_tokens)
1811
+ """
1812
+ projected_bio_embeddings = self.bio_projection(bio_embeddings)
1813
+ english_embeddings = self.token_embedding(english_token_ids)
1814
+
1815
+ bio_attention_mask = build_perceiver_padding_attention_mask(
1816
+ bio_token_ids, self.config.resampled_length, self.bio_pad_token_id
1817
+ )
1818
+ english_attention_mask = build_perceiver_padding_attention_mask(
1819
+ english_token_ids, self.config.resampled_length, self.english_pad_token_id
1820
+ )
1821
+
1822
+ projected_embeddings = self.perceiver_resampler(
1823
+ input_embeddings_1=projected_bio_embeddings,
1824
+ attention_mask_1=bio_attention_mask,
1825
+ input_embeddings_2=english_embeddings,
1826
+ attention_mask_2=english_attention_mask,
1827
+ )["embeddings"]
1828
+
1829
+ return projected_embeddings
1830
+
1831
+
1832
+ def build_perceiver_padding_attention_mask(
1833
+ tokens: torch.Tensor, resampled_length: int, pad_token_id: int
1834
+ ) -> torch.Tensor:
1835
+ batch_size, seq_len = tokens.shape
1836
+ padding_mask = tokens != pad_token_id # (batch_size, seq_len)
1837
+
1838
+ padding_mask = torch.cat(
1839
+ [
1840
+ padding_mask,
1841
+ torch.ones(
1842
+ (batch_size, resampled_length), dtype=torch.bool, device=tokens.device
1843
+ ),
1844
+ ],
1845
+ dim=1,
1846
+ ) # (batch_size, seq_len + resampled_length)
1847
+
1848
+ padding_mask = padding_mask[:, None, None, :]
1849
+ padding_mask = padding_mask.repeat(1, 1, resampled_length, 1) # noqa
1850
+ return padding_mask
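A small shape check for the helper above (illustrative only): the mask spans the concatenated key sequence of real tokens plus the latent queries, repeated over the resampled query positions.

import torch

tokens = torch.tensor([[7, 8, 1]])   # one sequence with a single <pad> (pad_token_id = 1)
mask = build_perceiver_padding_attention_mask(tokens, resampled_length=4, pad_token_id=1)
print(mask.shape)   # torch.Size([1, 1, 4, 7]) = (batch, 1, resampled_length, seq_len + resampled_length)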
config.json ADDED
@@ -0,0 +1,85 @@
1
+ {
2
+ "architectures": [
3
+ "TorchMultiOmicsModel"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "chatNT.ChatNTConfig",
7
+ "AutoModel": "chatNT.TorchMultiOmicsModel"
8
+ },
9
+ "bio_pad_token_id": 1,
10
+ "english_pad_token_id": 2,
11
+ "esm_config": {
12
+ "add_bias_ffn": false,
13
+ "add_bias_kv": false,
14
+ "alphabet_size": 4107,
15
+ "attention_heads": 16,
16
+ "attention_maps_to_save": [],
17
+ "bias_word_embedding": false,
18
+ "emb_layer_norm_before": false,
19
+ "embed_dim": 1024,
20
+ "embed_scale": 1.0,
21
+ "embeddings_layers_to_save": [
22
+ 21
23
+ ],
24
+ "ffn_activation_name": "swish",
25
+ "ffn_embed_dim": 4096,
26
+ "key_size": 64,
27
+ "layer_norm_eps": 1e-05,
28
+ "lm_head": "roberta",
29
+ "mask_before_attention": false,
30
+ "mask_token_id": 2,
31
+ "masking_prob": 0.0,
32
+ "masking_ratio": 0.0,
33
+ "max_positions": 2048,
34
+ "num_layers": 29,
35
+ "pad_token_id": 1,
36
+ "positional_embedding": null,
37
+ "pre_layer_norm": true,
38
+ "rescaling_factor": null,
39
+ "token_dropout": false,
40
+ "use_glu_in_ffn": true,
41
+ "use_gradient_checkpointing": false,
42
+ "use_rotary_embedding": true
43
+ },
44
+ "gpt_config": {
45
+ "add_bias_attn": false,
46
+ "add_bias_ffn": false,
47
+ "add_bias_lm_head": false,
48
+ "embed_dim": 4096,
49
+ "eos_token_id": 2,
50
+ "ffn_activation_name": "silu",
51
+ "ffn_embed_dim": 11008,
52
+ "norm_type": "RMS_norm",
53
+ "num_heads": 32,
54
+ "num_kv_heads": 32,
55
+ "num_layers": 32,
56
+ "parallel_attention_ff": false,
57
+ "rms_norm_eps": 1e-06,
58
+ "rope_config": {
59
+ "dim": 128,
60
+ "max_seq_len": 2048,
61
+ "theta": 10000.0
62
+ },
63
+ "use_glu_in_ffn": true,
64
+ "use_gradient_checkpointing": false,
65
+ "vocab_size": 32000
66
+ },
67
+ "model_type": "ChatNT",
68
+ "perceiver_resampler_config": {
69
+ "add_bias_ffn": true,
70
+ "add_bias_kv": false,
71
+ "attention_heads": 32,
72
+ "emb_layer_norm_before": false,
73
+ "embed_dim": 4096,
74
+ "ffn_activation_name": "gelu-no-approx",
75
+ "ffn_embed_dim": 11008,
76
+ "key_size": 128,
77
+ "num_layers": 3,
78
+ "resampled_length": 64,
79
+ "use_glu_in_ffn": false,
80
+ "use_gradient_checkpointing": false
81
+ },
82
+ "seq_token_id": 32000,
83
+ "torch_dtype": "float32",
84
+ "transformers_version": "4.41.1"
85
+ }
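Given the auto_map entries above, the checkpoint is intended to be loaded through the transformers Auto classes with remote code enabled; a minimal sketch (the repository id below is a placeholder):

from transformers import AutoConfig, AutoModel

repo_id = "path/to/this/repo"  # placeholder: substitute the actual model repository id
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)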
model.safetensors.index.json ADDED
@@ -0,0 +1,763 @@
1
+ {
2
+ "metadata": {
3
+ "total_size": 32174388268
4
+ },
5
+ "weight_map": {
6
+ "biobrain_decoder.gpt_model.final_norm.scale": "model-00001-of-00007.safetensors",
7
+ "biobrain_decoder.gpt_model.layers.0.attn_norm.scale": "model-00001-of-00007.safetensors",
8
+ "biobrain_decoder.gpt_model.layers.0.fc1.weight": "model-00001-of-00007.safetensors",
9
+ "biobrain_decoder.gpt_model.layers.0.fc2.weight": "model-00001-of-00007.safetensors",
10
+ "biobrain_decoder.gpt_model.layers.0.ffn_norm.scale": "model-00001-of-00007.safetensors",
11
+ "biobrain_decoder.gpt_model.layers.0.self_attn.key_linear.weight": "model-00001-of-00007.safetensors",
12
+ "biobrain_decoder.gpt_model.layers.0.self_attn.out_linear.weight": "model-00001-of-00007.safetensors",
13
+ "biobrain_decoder.gpt_model.layers.0.self_attn.query_linear.weight": "model-00001-of-00007.safetensors",
14
+ "biobrain_decoder.gpt_model.layers.0.self_attn.value_linear.weight": "model-00001-of-00007.safetensors",
15
+ "biobrain_decoder.gpt_model.layers.1.attn_norm.scale": "model-00001-of-00007.safetensors",
16
+ "biobrain_decoder.gpt_model.layers.1.fc1.weight": "model-00001-of-00007.safetensors",
17
+ "biobrain_decoder.gpt_model.layers.1.fc2.weight": "model-00001-of-00007.safetensors",
18
+ "biobrain_decoder.gpt_model.layers.1.ffn_norm.scale": "model-00001-of-00007.safetensors",
19
+ "biobrain_decoder.gpt_model.layers.1.self_attn.key_linear.weight": "model-00001-of-00007.safetensors",
20
+ "biobrain_decoder.gpt_model.layers.1.self_attn.out_linear.weight": "model-00001-of-00007.safetensors",
21
+ "biobrain_decoder.gpt_model.layers.1.self_attn.query_linear.weight": "model-00001-of-00007.safetensors",
22
+ "biobrain_decoder.gpt_model.layers.1.self_attn.value_linear.weight": "model-00001-of-00007.safetensors",
23
+ "biobrain_decoder.gpt_model.layers.10.attn_norm.scale": "model-00003-of-00007.safetensors",
24
+ "biobrain_decoder.gpt_model.layers.10.fc1.weight": "model-00003-of-00007.safetensors",
25
+ "biobrain_decoder.gpt_model.layers.10.fc2.weight": "model-00003-of-00007.safetensors",
26
+ "biobrain_decoder.gpt_model.layers.10.ffn_norm.scale": "model-00003-of-00007.safetensors",
27
+ "biobrain_decoder.gpt_model.layers.10.self_attn.key_linear.weight": "model-00003-of-00007.safetensors",
28
+ "biobrain_decoder.gpt_model.layers.10.self_attn.out_linear.weight": "model-00003-of-00007.safetensors",
29
+ "biobrain_decoder.gpt_model.layers.10.self_attn.query_linear.weight": "model-00003-of-00007.safetensors",
30
+ "biobrain_decoder.gpt_model.layers.10.self_attn.value_linear.weight": "model-00003-of-00007.safetensors",
31
+ "biobrain_decoder.gpt_model.layers.11.attn_norm.scale": "model-00003-of-00007.safetensors",
32
+ "biobrain_decoder.gpt_model.layers.11.fc1.weight": "model-00003-of-00007.safetensors",
33
+ "biobrain_decoder.gpt_model.layers.11.fc2.weight": "model-00003-of-00007.safetensors",
34
+ "biobrain_decoder.gpt_model.layers.11.ffn_norm.scale": "model-00003-of-00007.safetensors",
35
+ "biobrain_decoder.gpt_model.layers.11.self_attn.key_linear.weight": "model-00003-of-00007.safetensors",
36
+ "biobrain_decoder.gpt_model.layers.11.self_attn.out_linear.weight": "model-00003-of-00007.safetensors",
37
+ "biobrain_decoder.gpt_model.layers.11.self_attn.query_linear.weight": "model-00003-of-00007.safetensors",
38
+ "biobrain_decoder.gpt_model.layers.11.self_attn.value_linear.weight": "model-00003-of-00007.safetensors",
39
+ "biobrain_decoder.gpt_model.layers.12.attn_norm.scale": "model-00003-of-00007.safetensors",
40
+ "biobrain_decoder.gpt_model.layers.12.fc1.weight": "model-00003-of-00007.safetensors",
41
+ "biobrain_decoder.gpt_model.layers.12.fc2.weight": "model-00003-of-00007.safetensors",
42
+ "biobrain_decoder.gpt_model.layers.12.ffn_norm.scale": "model-00003-of-00007.safetensors",
43
+ "biobrain_decoder.gpt_model.layers.12.self_attn.key_linear.weight": "model-00003-of-00007.safetensors",
44
+ "biobrain_decoder.gpt_model.layers.12.self_attn.out_linear.weight": "model-00003-of-00007.safetensors",
45
+ "biobrain_decoder.gpt_model.layers.12.self_attn.query_linear.weight": "model-00003-of-00007.safetensors",
46
+ "biobrain_decoder.gpt_model.layers.12.self_attn.value_linear.weight": "model-00003-of-00007.safetensors",
47
+ "biobrain_decoder.gpt_model.layers.13.attn_norm.scale": "model-00003-of-00007.safetensors",
48
+ "biobrain_decoder.gpt_model.layers.13.fc1.weight": "model-00003-of-00007.safetensors",
49
+ "biobrain_decoder.gpt_model.layers.13.fc2.weight": "model-00003-of-00007.safetensors",
50
+ "biobrain_decoder.gpt_model.layers.13.ffn_norm.scale": "model-00003-of-00007.safetensors",
51
+ "biobrain_decoder.gpt_model.layers.13.self_attn.key_linear.weight": "model-00003-of-00007.safetensors",
52
+ "biobrain_decoder.gpt_model.layers.13.self_attn.out_linear.weight": "model-00003-of-00007.safetensors",
53
+ "biobrain_decoder.gpt_model.layers.13.self_attn.query_linear.weight": "model-00003-of-00007.safetensors",
54
+ "biobrain_decoder.gpt_model.layers.13.self_attn.value_linear.weight": "model-00003-of-00007.safetensors",
55
+ "biobrain_decoder.gpt_model.layers.14.attn_norm.scale": "model-00003-of-00007.safetensors",
56
+ "biobrain_decoder.gpt_model.layers.14.fc1.weight": "model-00003-of-00007.safetensors",
57
+ "biobrain_decoder.gpt_model.layers.14.fc2.weight": "model-00003-of-00007.safetensors",
58
+ "biobrain_decoder.gpt_model.layers.14.ffn_norm.scale": "model-00003-of-00007.safetensors",
59
+ "biobrain_decoder.gpt_model.layers.14.self_attn.key_linear.weight": "model-00003-of-00007.safetensors",
60
+ "biobrain_decoder.gpt_model.layers.14.self_attn.out_linear.weight": "model-00003-of-00007.safetensors",
61
+ "biobrain_decoder.gpt_model.layers.14.self_attn.query_linear.weight": "model-00003-of-00007.safetensors",
62
+ "biobrain_decoder.gpt_model.layers.14.self_attn.value_linear.weight": "model-00003-of-00007.safetensors",
63
+ "biobrain_decoder.gpt_model.layers.15.attn_norm.scale": "model-00003-of-00007.safetensors",
64
+ "biobrain_decoder.gpt_model.layers.15.fc1.weight": "model-00004-of-00007.safetensors",
65
+ "biobrain_decoder.gpt_model.layers.15.fc2.weight": "model-00004-of-00007.safetensors",
66
+ "biobrain_decoder.gpt_model.layers.15.ffn_norm.scale": "model-00003-of-00007.safetensors",
67
+ "biobrain_decoder.gpt_model.layers.15.self_attn.key_linear.weight": "model-00003-of-00007.safetensors",
68
+ "biobrain_decoder.gpt_model.layers.15.self_attn.out_linear.weight": "model-00003-of-00007.safetensors",
69
+ "biobrain_decoder.gpt_model.layers.15.self_attn.query_linear.weight": "model-00003-of-00007.safetensors",
70
+ "biobrain_decoder.gpt_model.layers.15.self_attn.value_linear.weight": "model-00003-of-00007.safetensors",
71
+ "biobrain_decoder.gpt_model.layers.16.attn_norm.scale": "model-00004-of-00007.safetensors",
72
+ "biobrain_decoder.gpt_model.layers.16.fc1.weight": "model-00004-of-00007.safetensors",
73
+ "biobrain_decoder.gpt_model.layers.16.fc2.weight": "model-00004-of-00007.safetensors",
74
+ "biobrain_decoder.gpt_model.layers.16.ffn_norm.scale": "model-00004-of-00007.safetensors",
75
+ "biobrain_decoder.gpt_model.layers.16.self_attn.key_linear.weight": "model-00004-of-00007.safetensors",
76
+ "biobrain_decoder.gpt_model.layers.16.self_attn.out_linear.weight": "model-00004-of-00007.safetensors",
77
+ "biobrain_decoder.gpt_model.layers.16.self_attn.query_linear.weight": "model-00004-of-00007.safetensors",
78
+ "biobrain_decoder.gpt_model.layers.16.self_attn.value_linear.weight": "model-00004-of-00007.safetensors",
79
+ "biobrain_decoder.gpt_model.layers.17.attn_norm.scale": "model-00004-of-00007.safetensors",
80
+ "biobrain_decoder.gpt_model.layers.17.fc1.weight": "model-00004-of-00007.safetensors",
81
+ "biobrain_decoder.gpt_model.layers.17.fc2.weight": "model-00004-of-00007.safetensors",
82
+ "biobrain_decoder.gpt_model.layers.17.ffn_norm.scale": "model-00004-of-00007.safetensors",
83
+ "biobrain_decoder.gpt_model.layers.17.self_attn.key_linear.weight": "model-00004-of-00007.safetensors",
84
+ "biobrain_decoder.gpt_model.layers.17.self_attn.out_linear.weight": "model-00004-of-00007.safetensors",
85
+ "biobrain_decoder.gpt_model.layers.17.self_attn.query_linear.weight": "model-00004-of-00007.safetensors",
86
+ "biobrain_decoder.gpt_model.layers.17.self_attn.value_linear.weight": "model-00004-of-00007.safetensors",
87
+ "biobrain_decoder.gpt_model.layers.18.attn_norm.scale": "model-00004-of-00007.safetensors",
88
+ "biobrain_decoder.gpt_model.layers.18.fc1.weight": "model-00004-of-00007.safetensors",
89
+ "biobrain_decoder.gpt_model.layers.18.fc2.weight": "model-00004-of-00007.safetensors",
90
+ "biobrain_decoder.gpt_model.layers.18.ffn_norm.scale": "model-00004-of-00007.safetensors",
91
+ "biobrain_decoder.gpt_model.layers.18.self_attn.key_linear.weight": "model-00004-of-00007.safetensors",
92
+ "biobrain_decoder.gpt_model.layers.18.self_attn.out_linear.weight": "model-00004-of-00007.safetensors",
93
+ "biobrain_decoder.gpt_model.layers.18.self_attn.query_linear.weight": "model-00004-of-00007.safetensors",
94
+ "biobrain_decoder.gpt_model.layers.18.self_attn.value_linear.weight": "model-00004-of-00007.safetensors",
95
+ "biobrain_decoder.gpt_model.layers.19.attn_norm.scale": "model-00004-of-00007.safetensors",
96
+ "biobrain_decoder.gpt_model.layers.19.fc1.weight": "model-00004-of-00007.safetensors",
97
+ "biobrain_decoder.gpt_model.layers.19.fc2.weight": "model-00004-of-00007.safetensors",
98
+ "biobrain_decoder.gpt_model.layers.19.ffn_norm.scale": "model-00004-of-00007.safetensors",
99
+ "biobrain_decoder.gpt_model.layers.19.self_attn.key_linear.weight": "model-00004-of-00007.safetensors",
100
+ "biobrain_decoder.gpt_model.layers.19.self_attn.out_linear.weight": "model-00004-of-00007.safetensors",
101
+ "biobrain_decoder.gpt_model.layers.19.self_attn.query_linear.weight": "model-00004-of-00007.safetensors",
102
+ "biobrain_decoder.gpt_model.layers.19.self_attn.value_linear.weight": "model-00004-of-00007.safetensors",
103
+ "biobrain_decoder.gpt_model.layers.2.attn_norm.scale": "model-00001-of-00007.safetensors",
104
+ "biobrain_decoder.gpt_model.layers.2.fc1.weight": "model-00001-of-00007.safetensors",
105
+ "biobrain_decoder.gpt_model.layers.2.fc2.weight": "model-00001-of-00007.safetensors",
106
+ "biobrain_decoder.gpt_model.layers.2.ffn_norm.scale": "model-00001-of-00007.safetensors",
107
+ "biobrain_decoder.gpt_model.layers.2.self_attn.key_linear.weight": "model-00001-of-00007.safetensors",
108
+ "biobrain_decoder.gpt_model.layers.2.self_attn.out_linear.weight": "model-00001-of-00007.safetensors",
109
+ "biobrain_decoder.gpt_model.layers.2.self_attn.query_linear.weight": "model-00001-of-00007.safetensors",
110
+ "biobrain_decoder.gpt_model.layers.2.self_attn.value_linear.weight": "model-00001-of-00007.safetensors",
111
+ "biobrain_decoder.gpt_model.layers.20.attn_norm.scale": "model-00004-of-00007.safetensors",
112
+ "biobrain_decoder.gpt_model.layers.20.fc1.weight": "model-00004-of-00007.safetensors",
113
+ "biobrain_decoder.gpt_model.layers.20.fc2.weight": "model-00004-of-00007.safetensors",
114
+ "biobrain_decoder.gpt_model.layers.20.ffn_norm.scale": "model-00004-of-00007.safetensors",
115
+ "biobrain_decoder.gpt_model.layers.20.self_attn.key_linear.weight": "model-00004-of-00007.safetensors",
116
+ "biobrain_decoder.gpt_model.layers.20.self_attn.out_linear.weight": "model-00004-of-00007.safetensors",
117
+ "biobrain_decoder.gpt_model.layers.20.self_attn.query_linear.weight": "model-00004-of-00007.safetensors",
118
+ "biobrain_decoder.gpt_model.layers.20.self_attn.value_linear.weight": "model-00004-of-00007.safetensors",
119
+ "biobrain_decoder.gpt_model.layers.21.attn_norm.scale": "model-00004-of-00007.safetensors",
120
+ "biobrain_decoder.gpt_model.layers.21.fc1.weight": "model-00005-of-00007.safetensors",
121
+ "biobrain_decoder.gpt_model.layers.21.fc2.weight": "model-00005-of-00007.safetensors",
122
+ "biobrain_decoder.gpt_model.layers.21.ffn_norm.scale": "model-00004-of-00007.safetensors",
123
+ "biobrain_decoder.gpt_model.layers.21.self_attn.key_linear.weight": "model-00004-of-00007.safetensors",
124
+ "biobrain_decoder.gpt_model.layers.21.self_attn.out_linear.weight": "model-00004-of-00007.safetensors",
125
+ "biobrain_decoder.gpt_model.layers.21.self_attn.query_linear.weight": "model-00004-of-00007.safetensors",
126
+ "biobrain_decoder.gpt_model.layers.21.self_attn.value_linear.weight": "model-00004-of-00007.safetensors",
127
+ "biobrain_decoder.gpt_model.layers.22.attn_norm.scale": "model-00005-of-00007.safetensors",
128
+ "biobrain_decoder.gpt_model.layers.22.fc1.weight": "model-00005-of-00007.safetensors",
129
+ "biobrain_decoder.gpt_model.layers.22.fc2.weight": "model-00005-of-00007.safetensors",
130
+ "biobrain_decoder.gpt_model.layers.22.ffn_norm.scale": "model-00005-of-00007.safetensors",
131
+ "biobrain_decoder.gpt_model.layers.22.self_attn.key_linear.weight": "model-00005-of-00007.safetensors",
132
+ "biobrain_decoder.gpt_model.layers.22.self_attn.out_linear.weight": "model-00005-of-00007.safetensors",
133
+ "biobrain_decoder.gpt_model.layers.22.self_attn.query_linear.weight": "model-00005-of-00007.safetensors",
134
+ "biobrain_decoder.gpt_model.layers.22.self_attn.value_linear.weight": "model-00005-of-00007.safetensors",
135
+ "biobrain_decoder.gpt_model.layers.23.attn_norm.scale": "model-00005-of-00007.safetensors",
136
+ "biobrain_decoder.gpt_model.layers.23.fc1.weight": "model-00005-of-00007.safetensors",
137
+ "biobrain_decoder.gpt_model.layers.23.fc2.weight": "model-00005-of-00007.safetensors",
138
+ "biobrain_decoder.gpt_model.layers.23.ffn_norm.scale": "model-00005-of-00007.safetensors",
139
+ "biobrain_decoder.gpt_model.layers.23.self_attn.key_linear.weight": "model-00005-of-00007.safetensors",
140
+ "biobrain_decoder.gpt_model.layers.23.self_attn.out_linear.weight": "model-00005-of-00007.safetensors",
141
+ "biobrain_decoder.gpt_model.layers.23.self_attn.query_linear.weight": "model-00005-of-00007.safetensors",
142
+ "biobrain_decoder.gpt_model.layers.23.self_attn.value_linear.weight": "model-00005-of-00007.safetensors",
143
+ "biobrain_decoder.gpt_model.layers.24.attn_norm.scale": "model-00005-of-00007.safetensors",
144
+ "biobrain_decoder.gpt_model.layers.24.fc1.weight": "model-00005-of-00007.safetensors",
145
+ "biobrain_decoder.gpt_model.layers.24.fc2.weight": "model-00005-of-00007.safetensors",
146
+ "biobrain_decoder.gpt_model.layers.24.ffn_norm.scale": "model-00005-of-00007.safetensors",
147
+ "biobrain_decoder.gpt_model.layers.24.self_attn.key_linear.weight": "model-00005-of-00007.safetensors",
148
+ "biobrain_decoder.gpt_model.layers.24.self_attn.out_linear.weight": "model-00005-of-00007.safetensors",
149
+ "biobrain_decoder.gpt_model.layers.24.self_attn.query_linear.weight": "model-00005-of-00007.safetensors",
150
+ "biobrain_decoder.gpt_model.layers.24.self_attn.value_linear.weight": "model-00005-of-00007.safetensors",
151
+ "biobrain_decoder.gpt_model.layers.25.attn_norm.scale": "model-00005-of-00007.safetensors",
152
+ "biobrain_decoder.gpt_model.layers.25.fc1.weight": "model-00005-of-00007.safetensors",
153
+ "biobrain_decoder.gpt_model.layers.25.fc2.weight": "model-00005-of-00007.safetensors",
154
+ "biobrain_decoder.gpt_model.layers.25.ffn_norm.scale": "model-00005-of-00007.safetensors",
155
+ "biobrain_decoder.gpt_model.layers.25.self_attn.key_linear.weight": "model-00005-of-00007.safetensors",
156
+ "biobrain_decoder.gpt_model.layers.25.self_attn.out_linear.weight": "model-00005-of-00007.safetensors",
157
+ "biobrain_decoder.gpt_model.layers.25.self_attn.query_linear.weight": "model-00005-of-00007.safetensors",
158
+ "biobrain_decoder.gpt_model.layers.25.self_attn.value_linear.weight": "model-00005-of-00007.safetensors",
159
+ "biobrain_decoder.gpt_model.layers.26.attn_norm.scale": "model-00005-of-00007.safetensors",
160
+ "biobrain_decoder.gpt_model.layers.26.fc1.weight": "model-00005-of-00007.safetensors",
161
+ "biobrain_decoder.gpt_model.layers.26.fc2.weight": "model-00005-of-00007.safetensors",
162
+ "biobrain_decoder.gpt_model.layers.26.ffn_norm.scale": "model-00005-of-00007.safetensors",
163
+ "biobrain_decoder.gpt_model.layers.26.self_attn.key_linear.weight": "model-00005-of-00007.safetensors",
164
+ "biobrain_decoder.gpt_model.layers.26.self_attn.out_linear.weight": "model-00005-of-00007.safetensors",
165
+ "biobrain_decoder.gpt_model.layers.26.self_attn.query_linear.weight": "model-00005-of-00007.safetensors",
166
+ "biobrain_decoder.gpt_model.layers.26.self_attn.value_linear.weight": "model-00005-of-00007.safetensors",
167
+ "biobrain_decoder.gpt_model.layers.27.attn_norm.scale": "model-00005-of-00007.safetensors",
168
+ "biobrain_decoder.gpt_model.layers.27.fc1.weight": "model-00006-of-00007.safetensors",
169
+ "biobrain_decoder.gpt_model.layers.27.fc2.weight": "model-00006-of-00007.safetensors",
170
+ "biobrain_decoder.gpt_model.layers.27.ffn_norm.scale": "model-00005-of-00007.safetensors",
171
+ "biobrain_decoder.gpt_model.layers.27.self_attn.key_linear.weight": "model-00005-of-00007.safetensors",
172
+ "biobrain_decoder.gpt_model.layers.27.self_attn.out_linear.weight": "model-00005-of-00007.safetensors",
173
+ "biobrain_decoder.gpt_model.layers.27.self_attn.query_linear.weight": "model-00005-of-00007.safetensors",
174
+ "biobrain_decoder.gpt_model.layers.27.self_attn.value_linear.weight": "model-00005-of-00007.safetensors",
175
+ "biobrain_decoder.gpt_model.layers.28.attn_norm.scale": "model-00006-of-00007.safetensors",
176
+ "biobrain_decoder.gpt_model.layers.28.fc1.weight": "model-00006-of-00007.safetensors",
177
+ "biobrain_decoder.gpt_model.layers.28.fc2.weight": "model-00006-of-00007.safetensors",
178
+ "biobrain_decoder.gpt_model.layers.28.ffn_norm.scale": "model-00006-of-00007.safetensors",
179
+ "biobrain_decoder.gpt_model.layers.28.self_attn.key_linear.weight": "model-00006-of-00007.safetensors",
180
+ "biobrain_decoder.gpt_model.layers.28.self_attn.out_linear.weight": "model-00006-of-00007.safetensors",
181
+ "biobrain_decoder.gpt_model.layers.28.self_attn.query_linear.weight": "model-00006-of-00007.safetensors",
182
+ "biobrain_decoder.gpt_model.layers.28.self_attn.value_linear.weight": "model-00006-of-00007.safetensors",
183
+ "biobrain_decoder.gpt_model.layers.29.attn_norm.scale": "model-00006-of-00007.safetensors",
184
+ "biobrain_decoder.gpt_model.layers.29.fc1.weight": "model-00006-of-00007.safetensors",
185
+ "biobrain_decoder.gpt_model.layers.29.fc2.weight": "model-00006-of-00007.safetensors",
186
+ "biobrain_decoder.gpt_model.layers.29.ffn_norm.scale": "model-00006-of-00007.safetensors",
187
+ "biobrain_decoder.gpt_model.layers.29.self_attn.key_linear.weight": "model-00006-of-00007.safetensors",
188
+ "biobrain_decoder.gpt_model.layers.29.self_attn.out_linear.weight": "model-00006-of-00007.safetensors",
189
+ "biobrain_decoder.gpt_model.layers.29.self_attn.query_linear.weight": "model-00006-of-00007.safetensors",
190
+ "biobrain_decoder.gpt_model.layers.29.self_attn.value_linear.weight": "model-00006-of-00007.safetensors",
191
+ "biobrain_decoder.gpt_model.layers.3.attn_norm.scale": "model-00002-of-00007.safetensors",
192
+ "biobrain_decoder.gpt_model.layers.3.fc1.weight": "model-00002-of-00007.safetensors",
193
+ "biobrain_decoder.gpt_model.layers.3.fc2.weight": "model-00002-of-00007.safetensors",
194
+ "biobrain_decoder.gpt_model.layers.3.ffn_norm.scale": "model-00002-of-00007.safetensors",
195
+ "biobrain_decoder.gpt_model.layers.3.self_attn.key_linear.weight": "model-00002-of-00007.safetensors",
196
+ "biobrain_decoder.gpt_model.layers.3.self_attn.out_linear.weight": "model-00002-of-00007.safetensors",
197
+ "biobrain_decoder.gpt_model.layers.3.self_attn.query_linear.weight": "model-00002-of-00007.safetensors",
198
+ "biobrain_decoder.gpt_model.layers.3.self_attn.value_linear.weight": "model-00002-of-00007.safetensors",
199
+ "biobrain_decoder.gpt_model.layers.30.attn_norm.scale": "model-00006-of-00007.safetensors",
200
+ "biobrain_decoder.gpt_model.layers.30.fc1.weight": "model-00006-of-00007.safetensors",
201
+ "biobrain_decoder.gpt_model.layers.30.fc2.weight": "model-00006-of-00007.safetensors",
202
+ "biobrain_decoder.gpt_model.layers.30.ffn_norm.scale": "model-00006-of-00007.safetensors",
203
+ "biobrain_decoder.gpt_model.layers.30.self_attn.key_linear.weight": "model-00006-of-00007.safetensors",
204
+ "biobrain_decoder.gpt_model.layers.30.self_attn.out_linear.weight": "model-00006-of-00007.safetensors",
205
+ "biobrain_decoder.gpt_model.layers.30.self_attn.query_linear.weight": "model-00006-of-00007.safetensors",
206
+ "biobrain_decoder.gpt_model.layers.30.self_attn.value_linear.weight": "model-00006-of-00007.safetensors",
207
+ "biobrain_decoder.gpt_model.layers.31.attn_norm.scale": "model-00006-of-00007.safetensors",
208
+ "biobrain_decoder.gpt_model.layers.31.fc1.weight": "model-00006-of-00007.safetensors",
209
+ "biobrain_decoder.gpt_model.layers.31.fc2.weight": "model-00006-of-00007.safetensors",
210
+ "biobrain_decoder.gpt_model.layers.31.ffn_norm.scale": "model-00006-of-00007.safetensors",
211
+ "biobrain_decoder.gpt_model.layers.31.self_attn.key_linear.weight": "model-00006-of-00007.safetensors",
212
+ "biobrain_decoder.gpt_model.layers.31.self_attn.out_linear.weight": "model-00006-of-00007.safetensors",
213
+ "biobrain_decoder.gpt_model.layers.31.self_attn.query_linear.weight": "model-00006-of-00007.safetensors",
214
+ "biobrain_decoder.gpt_model.layers.31.self_attn.value_linear.weight": "model-00006-of-00007.safetensors",
215
+ "biobrain_decoder.gpt_model.layers.4.attn_norm.scale": "model-00002-of-00007.safetensors",
216
+ "biobrain_decoder.gpt_model.layers.4.fc1.weight": "model-00002-of-00007.safetensors",
217
+ "biobrain_decoder.gpt_model.layers.4.fc2.weight": "model-00002-of-00007.safetensors",
218
+ "biobrain_decoder.gpt_model.layers.4.ffn_norm.scale": "model-00002-of-00007.safetensors",
219
+ "biobrain_decoder.gpt_model.layers.4.self_attn.key_linear.weight": "model-00002-of-00007.safetensors",
220
+ "biobrain_decoder.gpt_model.layers.4.self_attn.out_linear.weight": "model-00002-of-00007.safetensors",
221
+ "biobrain_decoder.gpt_model.layers.4.self_attn.query_linear.weight": "model-00002-of-00007.safetensors",
222
+ "biobrain_decoder.gpt_model.layers.4.self_attn.value_linear.weight": "model-00002-of-00007.safetensors",
223
+ "biobrain_decoder.gpt_model.layers.5.attn_norm.scale": "model-00002-of-00007.safetensors",
224
+ "biobrain_decoder.gpt_model.layers.5.fc1.weight": "model-00002-of-00007.safetensors",
225
+ "biobrain_decoder.gpt_model.layers.5.fc2.weight": "model-00002-of-00007.safetensors",
226
+ "biobrain_decoder.gpt_model.layers.5.ffn_norm.scale": "model-00002-of-00007.safetensors",
227
+ "biobrain_decoder.gpt_model.layers.5.self_attn.key_linear.weight": "model-00002-of-00007.safetensors",
228
+ "biobrain_decoder.gpt_model.layers.5.self_attn.out_linear.weight": "model-00002-of-00007.safetensors",
229
+ "biobrain_decoder.gpt_model.layers.5.self_attn.query_linear.weight": "model-00002-of-00007.safetensors",
230
+ "biobrain_decoder.gpt_model.layers.5.self_attn.value_linear.weight": "model-00002-of-00007.safetensors",
231
+ "biobrain_decoder.gpt_model.layers.6.attn_norm.scale": "model-00002-of-00007.safetensors",
232
+ "biobrain_decoder.gpt_model.layers.6.fc1.weight": "model-00002-of-00007.safetensors",
233
+ "biobrain_decoder.gpt_model.layers.6.fc2.weight": "model-00002-of-00007.safetensors",
234
+ "biobrain_decoder.gpt_model.layers.6.ffn_norm.scale": "model-00002-of-00007.safetensors",
235
+ "biobrain_decoder.gpt_model.layers.6.self_attn.key_linear.weight": "model-00002-of-00007.safetensors",
236
+ "biobrain_decoder.gpt_model.layers.6.self_attn.out_linear.weight": "model-00002-of-00007.safetensors",
237
+ "biobrain_decoder.gpt_model.layers.6.self_attn.query_linear.weight": "model-00002-of-00007.safetensors",
238
+ "biobrain_decoder.gpt_model.layers.6.self_attn.value_linear.weight": "model-00002-of-00007.safetensors",
239
+ "biobrain_decoder.gpt_model.layers.7.attn_norm.scale": "model-00002-of-00007.safetensors",
240
+ "biobrain_decoder.gpt_model.layers.7.fc1.weight": "model-00002-of-00007.safetensors",
241
+ "biobrain_decoder.gpt_model.layers.7.fc2.weight": "model-00002-of-00007.safetensors",
242
+ "biobrain_decoder.gpt_model.layers.7.ffn_norm.scale": "model-00002-of-00007.safetensors",
243
+ "biobrain_decoder.gpt_model.layers.7.self_attn.key_linear.weight": "model-00002-of-00007.safetensors",
244
+ "biobrain_decoder.gpt_model.layers.7.self_attn.out_linear.weight": "model-00002-of-00007.safetensors",
245
+ "biobrain_decoder.gpt_model.layers.7.self_attn.query_linear.weight": "model-00002-of-00007.safetensors",
246
+ "biobrain_decoder.gpt_model.layers.7.self_attn.value_linear.weight": "model-00002-of-00007.safetensors",
247
+ "biobrain_decoder.gpt_model.layers.8.attn_norm.scale": "model-00002-of-00007.safetensors",
248
+ "biobrain_decoder.gpt_model.layers.8.fc1.weight": "model-00002-of-00007.safetensors",
249
+ "biobrain_decoder.gpt_model.layers.8.fc2.weight": "model-00002-of-00007.safetensors",
250
+ "biobrain_decoder.gpt_model.layers.8.ffn_norm.scale": "model-00002-of-00007.safetensors",
251
+ "biobrain_decoder.gpt_model.layers.8.self_attn.key_linear.weight": "model-00002-of-00007.safetensors",
252
+ "biobrain_decoder.gpt_model.layers.8.self_attn.out_linear.weight": "model-00002-of-00007.safetensors",
253
+ "biobrain_decoder.gpt_model.layers.8.self_attn.query_linear.weight": "model-00002-of-00007.safetensors",
254
+ "biobrain_decoder.gpt_model.layers.8.self_attn.value_linear.weight": "model-00002-of-00007.safetensors",
255
+ "biobrain_decoder.gpt_model.layers.9.attn_norm.scale": "model-00003-of-00007.safetensors",
256
+ "biobrain_decoder.gpt_model.layers.9.fc1.weight": "model-00003-of-00007.safetensors",
257
+ "biobrain_decoder.gpt_model.layers.9.fc2.weight": "model-00003-of-00007.safetensors",
258
+ "biobrain_decoder.gpt_model.layers.9.ffn_norm.scale": "model-00003-of-00007.safetensors",
259
+ "biobrain_decoder.gpt_model.layers.9.self_attn.key_linear.weight": "model-00002-of-00007.safetensors",
260
+ "biobrain_decoder.gpt_model.layers.9.self_attn.out_linear.weight": "model-00003-of-00007.safetensors",
261
+ "biobrain_decoder.gpt_model.layers.9.self_attn.query_linear.weight": "model-00002-of-00007.safetensors",
262
+ "biobrain_decoder.gpt_model.layers.9.self_attn.value_linear.weight": "model-00003-of-00007.safetensors",
263
+ "biobrain_decoder.gpt_model.lm_head.fc.weight": "model-00006-of-00007.safetensors",
264
+ "biobrain_decoder.gpt_model.token_embed.weight": "model-00001-of-00007.safetensors",
265
+ "biobrain_encoder.esm_model.attention_blocks.0.fc1.weight": "model-00001-of-00007.safetensors",
266
+ "biobrain_encoder.esm_model.attention_blocks.0.fc2.weight": "model-00001-of-00007.safetensors",
267
+ "biobrain_encoder.esm_model.attention_blocks.0.layer_norm_mlp.bias": "model-00001-of-00007.safetensors",
268
+ "biobrain_encoder.esm_model.attention_blocks.0.layer_norm_mlp.weight": "model-00001-of-00007.safetensors",
269
+ "biobrain_encoder.esm_model.attention_blocks.0.layer_norm_self_attention.bias": "model-00001-of-00007.safetensors",
270
+ "biobrain_encoder.esm_model.attention_blocks.0.layer_norm_self_attention.weight": "model-00001-of-00007.safetensors",
271
+ "biobrain_encoder.esm_model.attention_blocks.0.mha.output.bias": "model-00001-of-00007.safetensors",
272
+ "biobrain_encoder.esm_model.attention_blocks.0.mha.output.weight": "model-00001-of-00007.safetensors",
273
+ "biobrain_encoder.esm_model.attention_blocks.0.mha.w_k.bias": "model-00001-of-00007.safetensors",
274
+ "biobrain_encoder.esm_model.attention_blocks.0.mha.w_k.weight": "model-00001-of-00007.safetensors",
275
+ "biobrain_encoder.esm_model.attention_blocks.0.mha.w_q.bias": "model-00001-of-00007.safetensors",
276
+ "biobrain_encoder.esm_model.attention_blocks.0.mha.w_q.weight": "model-00001-of-00007.safetensors",
277
+ "biobrain_encoder.esm_model.attention_blocks.0.mha.w_v.bias": "model-00001-of-00007.safetensors",
278
+ "biobrain_encoder.esm_model.attention_blocks.0.mha.w_v.weight": "model-00001-of-00007.safetensors",
279
+ "biobrain_encoder.esm_model.attention_blocks.1.fc1.weight": "model-00001-of-00007.safetensors",
280
+ "biobrain_encoder.esm_model.attention_blocks.1.fc2.weight": "model-00001-of-00007.safetensors",
281
+ "biobrain_encoder.esm_model.attention_blocks.1.layer_norm_mlp.bias": "model-00001-of-00007.safetensors",
282
+ "biobrain_encoder.esm_model.attention_blocks.1.layer_norm_mlp.weight": "model-00001-of-00007.safetensors",
283
+ "biobrain_encoder.esm_model.attention_blocks.1.layer_norm_self_attention.bias": "model-00001-of-00007.safetensors",
284
+ "biobrain_encoder.esm_model.attention_blocks.1.layer_norm_self_attention.weight": "model-00001-of-00007.safetensors",
285
+ "biobrain_encoder.esm_model.attention_blocks.1.mha.output.bias": "model-00001-of-00007.safetensors",
286
+ "biobrain_encoder.esm_model.attention_blocks.1.mha.output.weight": "model-00001-of-00007.safetensors",
287
+ "biobrain_encoder.esm_model.attention_blocks.1.mha.w_k.bias": "model-00001-of-00007.safetensors",
288
+ "biobrain_encoder.esm_model.attention_blocks.1.mha.w_k.weight": "model-00001-of-00007.safetensors",
289
+ "biobrain_encoder.esm_model.attention_blocks.1.mha.w_q.bias": "model-00001-of-00007.safetensors",
290
+ "biobrain_encoder.esm_model.attention_blocks.1.mha.w_q.weight": "model-00001-of-00007.safetensors",
291
+ "biobrain_encoder.esm_model.attention_blocks.1.mha.w_v.bias": "model-00001-of-00007.safetensors",
292
+ "biobrain_encoder.esm_model.attention_blocks.1.mha.w_v.weight": "model-00001-of-00007.safetensors",
293
+ "biobrain_encoder.esm_model.attention_blocks.10.fc1.weight": "model-00001-of-00007.safetensors",
294
+ "biobrain_encoder.esm_model.attention_blocks.10.fc2.weight": "model-00001-of-00007.safetensors",
295
+ "biobrain_encoder.esm_model.attention_blocks.10.layer_norm_mlp.bias": "model-00001-of-00007.safetensors",
296
+ "biobrain_encoder.esm_model.attention_blocks.10.layer_norm_mlp.weight": "model-00001-of-00007.safetensors",
297
+ "biobrain_encoder.esm_model.attention_blocks.10.layer_norm_self_attention.bias": "model-00001-of-00007.safetensors",
298
+ "biobrain_encoder.esm_model.attention_blocks.10.layer_norm_self_attention.weight": "model-00001-of-00007.safetensors",
299
+ "biobrain_encoder.esm_model.attention_blocks.10.mha.output.bias": "model-00001-of-00007.safetensors",
300
+ "biobrain_encoder.esm_model.attention_blocks.10.mha.output.weight": "model-00001-of-00007.safetensors",
301
+ "biobrain_encoder.esm_model.attention_blocks.10.mha.w_k.bias": "model-00001-of-00007.safetensors",
302
+ "biobrain_encoder.esm_model.attention_blocks.10.mha.w_k.weight": "model-00001-of-00007.safetensors",
303
+ "biobrain_encoder.esm_model.attention_blocks.10.mha.w_q.bias": "model-00001-of-00007.safetensors",
304
+ "biobrain_encoder.esm_model.attention_blocks.10.mha.w_q.weight": "model-00001-of-00007.safetensors",
305
+ "biobrain_encoder.esm_model.attention_blocks.10.mha.w_v.bias": "model-00001-of-00007.safetensors",
306
+ "biobrain_encoder.esm_model.attention_blocks.10.mha.w_v.weight": "model-00001-of-00007.safetensors",
307
+ "biobrain_encoder.esm_model.attention_blocks.11.fc1.weight": "model-00001-of-00007.safetensors",
308
+ "biobrain_encoder.esm_model.attention_blocks.11.fc2.weight": "model-00001-of-00007.safetensors",
309
+ "biobrain_encoder.esm_model.attention_blocks.11.layer_norm_mlp.bias": "model-00001-of-00007.safetensors",
310
+ "biobrain_encoder.esm_model.attention_blocks.11.layer_norm_mlp.weight": "model-00001-of-00007.safetensors",
311
+ "biobrain_encoder.esm_model.attention_blocks.11.layer_norm_self_attention.bias": "model-00001-of-00007.safetensors",
312
+ "biobrain_encoder.esm_model.attention_blocks.11.layer_norm_self_attention.weight": "model-00001-of-00007.safetensors",
313
+ "biobrain_encoder.esm_model.attention_blocks.11.mha.output.bias": "model-00001-of-00007.safetensors",
314
+ "biobrain_encoder.esm_model.attention_blocks.11.mha.output.weight": "model-00001-of-00007.safetensors",
315
+ "biobrain_encoder.esm_model.attention_blocks.11.mha.w_k.bias": "model-00001-of-00007.safetensors",
316
+ "biobrain_encoder.esm_model.attention_blocks.11.mha.w_k.weight": "model-00001-of-00007.safetensors",
317
+ "biobrain_encoder.esm_model.attention_blocks.11.mha.w_q.bias": "model-00001-of-00007.safetensors",
318
+ "biobrain_encoder.esm_model.attention_blocks.11.mha.w_q.weight": "model-00001-of-00007.safetensors",
319
+ "biobrain_encoder.esm_model.attention_blocks.11.mha.w_v.bias": "model-00001-of-00007.safetensors",
320
+ "biobrain_encoder.esm_model.attention_blocks.11.mha.w_v.weight": "model-00001-of-00007.safetensors",
321
+ "biobrain_encoder.esm_model.attention_blocks.12.fc1.weight": "model-00001-of-00007.safetensors",
322
+ "biobrain_encoder.esm_model.attention_blocks.12.fc2.weight": "model-00001-of-00007.safetensors",
323
+ "biobrain_encoder.esm_model.attention_blocks.12.layer_norm_mlp.bias": "model-00001-of-00007.safetensors",
324
+ "biobrain_encoder.esm_model.attention_blocks.12.layer_norm_mlp.weight": "model-00001-of-00007.safetensors",
325
+ "biobrain_encoder.esm_model.attention_blocks.12.layer_norm_self_attention.bias": "model-00001-of-00007.safetensors",
326
+ "biobrain_encoder.esm_model.attention_blocks.12.layer_norm_self_attention.weight": "model-00001-of-00007.safetensors",
327
+ "biobrain_encoder.esm_model.attention_blocks.12.mha.output.bias": "model-00001-of-00007.safetensors",
328
+ "biobrain_encoder.esm_model.attention_blocks.12.mha.output.weight": "model-00001-of-00007.safetensors",
329
+ "biobrain_encoder.esm_model.attention_blocks.12.mha.w_k.bias": "model-00001-of-00007.safetensors",
330
+ "biobrain_encoder.esm_model.attention_blocks.12.mha.w_k.weight": "model-00001-of-00007.safetensors",
331
+ "biobrain_encoder.esm_model.attention_blocks.12.mha.w_q.bias": "model-00001-of-00007.safetensors",
332
+ "biobrain_encoder.esm_model.attention_blocks.12.mha.w_q.weight": "model-00001-of-00007.safetensors",
333
+ "biobrain_encoder.esm_model.attention_blocks.12.mha.w_v.bias": "model-00001-of-00007.safetensors",
334
+ "biobrain_encoder.esm_model.attention_blocks.12.mha.w_v.weight": "model-00001-of-00007.safetensors",
335
+ "biobrain_encoder.esm_model.attention_blocks.13.fc1.weight": "model-00001-of-00007.safetensors",
336
+ "biobrain_encoder.esm_model.attention_blocks.13.fc2.weight": "model-00001-of-00007.safetensors",
337
+ "biobrain_encoder.esm_model.attention_blocks.13.layer_norm_mlp.bias": "model-00001-of-00007.safetensors",
338
+ "biobrain_encoder.esm_model.attention_blocks.13.layer_norm_mlp.weight": "model-00001-of-00007.safetensors",
339
+ "biobrain_encoder.esm_model.attention_blocks.13.layer_norm_self_attention.bias": "model-00001-of-00007.safetensors",
340
+ "biobrain_encoder.esm_model.attention_blocks.13.layer_norm_self_attention.weight": "model-00001-of-00007.safetensors",
341
+ "biobrain_encoder.esm_model.attention_blocks.13.mha.output.bias": "model-00001-of-00007.safetensors",
342
+ "biobrain_encoder.esm_model.attention_blocks.13.mha.output.weight": "model-00001-of-00007.safetensors",
343
+ "biobrain_encoder.esm_model.attention_blocks.13.mha.w_k.bias": "model-00001-of-00007.safetensors",
344
+ "biobrain_encoder.esm_model.attention_blocks.13.mha.w_k.weight": "model-00001-of-00007.safetensors",
345
+ "biobrain_encoder.esm_model.attention_blocks.13.mha.w_q.bias": "model-00001-of-00007.safetensors",
346
+ "biobrain_encoder.esm_model.attention_blocks.13.mha.w_q.weight": "model-00001-of-00007.safetensors",
347
+ "biobrain_encoder.esm_model.attention_blocks.13.mha.w_v.bias": "model-00001-of-00007.safetensors",
348
+ "biobrain_encoder.esm_model.attention_blocks.13.mha.w_v.weight": "model-00001-of-00007.safetensors",
349
+ "biobrain_encoder.esm_model.attention_blocks.14.fc1.weight": "model-00001-of-00007.safetensors",
350
+ "biobrain_encoder.esm_model.attention_blocks.14.fc2.weight": "model-00001-of-00007.safetensors",
351
+ "biobrain_encoder.esm_model.attention_blocks.14.layer_norm_mlp.bias": "model-00001-of-00007.safetensors",
352
+ "biobrain_encoder.esm_model.attention_blocks.14.layer_norm_mlp.weight": "model-00001-of-00007.safetensors",
353
+ "biobrain_encoder.esm_model.attention_blocks.14.layer_norm_self_attention.bias": "model-00001-of-00007.safetensors",
354
+ "biobrain_encoder.esm_model.attention_blocks.14.layer_norm_self_attention.weight": "model-00001-of-00007.safetensors",
355
+ "biobrain_encoder.esm_model.attention_blocks.14.mha.output.bias": "model-00001-of-00007.safetensors",
356
+ "biobrain_encoder.esm_model.attention_blocks.14.mha.output.weight": "model-00001-of-00007.safetensors",
357
+ "biobrain_encoder.esm_model.attention_blocks.14.mha.w_k.bias": "model-00001-of-00007.safetensors",
358
+ "biobrain_encoder.esm_model.attention_blocks.14.mha.w_k.weight": "model-00001-of-00007.safetensors",
359
+ "biobrain_encoder.esm_model.attention_blocks.14.mha.w_q.bias": "model-00001-of-00007.safetensors",
360
+ "biobrain_encoder.esm_model.attention_blocks.14.mha.w_q.weight": "model-00001-of-00007.safetensors",
361
+ "biobrain_encoder.esm_model.attention_blocks.14.mha.w_v.bias": "model-00001-of-00007.safetensors",
362
+ "biobrain_encoder.esm_model.attention_blocks.14.mha.w_v.weight": "model-00001-of-00007.safetensors",
363
+ "biobrain_encoder.esm_model.attention_blocks.15.fc1.weight": "model-00001-of-00007.safetensors",
364
+ "biobrain_encoder.esm_model.attention_blocks.15.fc2.weight": "model-00001-of-00007.safetensors",
365
+ "biobrain_encoder.esm_model.attention_blocks.15.layer_norm_mlp.bias": "model-00001-of-00007.safetensors",
366
+ "biobrain_encoder.esm_model.attention_blocks.15.layer_norm_mlp.weight": "model-00001-of-00007.safetensors",
367
+ "biobrain_encoder.esm_model.attention_blocks.15.layer_norm_self_attention.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.15.layer_norm_self_attention.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.15.mha.output.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.15.mha.output.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.15.mha.w_k.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.15.mha.w_k.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.15.mha.w_q.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.15.mha.w_q.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.15.mha.w_v.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.15.mha.w_v.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.16.fc1.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.16.fc2.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.16.layer_norm_mlp.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.16.layer_norm_mlp.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.16.layer_norm_self_attention.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.16.layer_norm_self_attention.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.16.mha.output.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.16.mha.output.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.16.mha.w_k.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.16.mha.w_k.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.16.mha.w_q.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.16.mha.w_q.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.16.mha.w_v.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.16.mha.w_v.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.17.fc1.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.17.fc2.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.17.layer_norm_mlp.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.17.layer_norm_mlp.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.17.layer_norm_self_attention.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.17.layer_norm_self_attention.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.17.mha.output.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.17.mha.output.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.17.mha.w_k.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.17.mha.w_k.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.17.mha.w_q.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.17.mha.w_q.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.17.mha.w_v.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.17.mha.w_v.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.18.fc1.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.18.fc2.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.18.layer_norm_mlp.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.18.layer_norm_mlp.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.18.layer_norm_self_attention.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.18.layer_norm_self_attention.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.18.mha.output.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.18.mha.output.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.18.mha.w_k.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.18.mha.w_k.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.18.mha.w_q.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.18.mha.w_q.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.18.mha.w_v.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.18.mha.w_v.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.19.fc1.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.19.fc2.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.19.layer_norm_mlp.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.19.layer_norm_mlp.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.19.layer_norm_self_attention.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.19.layer_norm_self_attention.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.19.mha.output.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.19.mha.output.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.19.mha.w_k.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.19.mha.w_k.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.19.mha.w_q.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.19.mha.w_q.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.19.mha.w_v.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.19.mha.w_v.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.2.fc1.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.2.fc2.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.2.layer_norm_mlp.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.2.layer_norm_mlp.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.2.layer_norm_self_attention.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.2.layer_norm_self_attention.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.2.mha.output.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.2.mha.output.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.2.mha.w_k.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.2.mha.w_k.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.2.mha.w_q.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.2.mha.w_q.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.2.mha.w_v.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.2.mha.w_v.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.20.fc1.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.20.fc2.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.20.layer_norm_mlp.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.20.layer_norm_mlp.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.20.layer_norm_self_attention.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.20.layer_norm_self_attention.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.20.mha.output.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.20.mha.output.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.20.mha.w_k.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.20.mha.w_k.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.20.mha.w_q.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.20.mha.w_q.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.20.mha.w_v.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.20.mha.w_v.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.21.fc1.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.21.fc2.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.21.layer_norm_mlp.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.21.layer_norm_mlp.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.21.layer_norm_self_attention.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.21.layer_norm_self_attention.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.21.mha.output.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.21.mha.output.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.21.mha.w_k.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.21.mha.w_k.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.21.mha.w_q.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.21.mha.w_q.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.21.mha.w_v.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.21.mha.w_v.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.22.fc1.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.22.fc2.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.22.layer_norm_mlp.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.22.layer_norm_mlp.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.22.layer_norm_self_attention.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.22.layer_norm_self_attention.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.22.mha.output.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.22.mha.output.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.22.mha.w_k.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.22.mha.w_k.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.22.mha.w_q.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.22.mha.w_q.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.22.mha.w_v.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.22.mha.w_v.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.23.fc1.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.23.fc2.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.23.layer_norm_mlp.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.23.layer_norm_mlp.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.23.layer_norm_self_attention.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.23.layer_norm_self_attention.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.23.mha.output.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.23.mha.output.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.23.mha.w_k.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.23.mha.w_k.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.23.mha.w_q.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.23.mha.w_q.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.23.mha.w_v.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.23.mha.w_v.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.24.fc1.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.24.fc2.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.24.layer_norm_mlp.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.24.layer_norm_mlp.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.24.layer_norm_self_attention.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.24.layer_norm_self_attention.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.24.mha.output.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.24.mha.output.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.24.mha.w_k.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.24.mha.w_k.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.24.mha.w_q.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.24.mha.w_q.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.24.mha.w_v.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.24.mha.w_v.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.25.fc1.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.25.fc2.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.25.layer_norm_mlp.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.25.layer_norm_mlp.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.25.layer_norm_self_attention.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.25.layer_norm_self_attention.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.25.mha.output.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.25.mha.output.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.25.mha.w_k.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.25.mha.w_k.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.25.mha.w_q.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.25.mha.w_q.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.25.mha.w_v.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.25.mha.w_v.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.26.fc1.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.26.fc2.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.26.layer_norm_mlp.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.26.layer_norm_mlp.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.26.layer_norm_self_attention.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.26.layer_norm_self_attention.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.26.mha.output.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.26.mha.output.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.26.mha.w_k.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.26.mha.w_k.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.26.mha.w_q.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.26.mha.w_q.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.26.mha.w_v.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.26.mha.w_v.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.27.fc1.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.27.fc2.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.27.layer_norm_mlp.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.27.layer_norm_mlp.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.27.layer_norm_self_attention.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.27.layer_norm_self_attention.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.27.mha.output.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.27.mha.output.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.27.mha.w_k.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.27.mha.w_k.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.27.mha.w_q.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.27.mha.w_q.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.27.mha.w_v.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.27.mha.w_v.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.28.fc1.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.28.fc2.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.28.layer_norm_mlp.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.28.layer_norm_mlp.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.28.layer_norm_self_attention.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.28.layer_norm_self_attention.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.28.mha.output.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.28.mha.output.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.28.mha.w_k.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.28.mha.w_k.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.28.mha.w_q.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.28.mha.w_q.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.28.mha.w_v.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.28.mha.w_v.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.3.fc1.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.3.fc2.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.3.layer_norm_mlp.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.3.layer_norm_mlp.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.3.layer_norm_self_attention.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.3.layer_norm_self_attention.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.3.mha.output.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.3.mha.output.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.3.mha.w_k.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.3.mha.w_k.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.3.mha.w_q.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.3.mha.w_q.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.3.mha.w_v.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.3.mha.w_v.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.4.fc1.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.4.fc2.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.4.layer_norm_mlp.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.4.layer_norm_mlp.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.4.layer_norm_self_attention.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.4.layer_norm_self_attention.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.4.mha.output.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.4.mha.output.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.4.mha.w_k.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.4.mha.w_k.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.4.mha.w_q.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.4.mha.w_q.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.4.mha.w_v.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.4.mha.w_v.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.5.fc1.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.5.fc2.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.5.layer_norm_mlp.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.5.layer_norm_mlp.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.5.layer_norm_self_attention.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.5.layer_norm_self_attention.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.5.mha.output.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.5.mha.output.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.5.mha.w_k.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.5.mha.w_k.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.5.mha.w_q.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.5.mha.w_q.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.5.mha.w_v.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.5.mha.w_v.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.6.fc1.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.6.fc2.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.6.layer_norm_mlp.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.6.layer_norm_mlp.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.6.layer_norm_self_attention.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.6.layer_norm_self_attention.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.6.mha.output.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.6.mha.output.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.6.mha.w_k.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.6.mha.w_k.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.6.mha.w_q.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.6.mha.w_q.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.6.mha.w_v.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.6.mha.w_v.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.7.fc1.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.7.fc2.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.7.layer_norm_mlp.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.7.layer_norm_mlp.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.7.layer_norm_self_attention.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.7.layer_norm_self_attention.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.7.mha.output.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.7.mha.output.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.7.mha.w_k.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.7.mha.w_k.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.7.mha.w_q.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.7.mha.w_q.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.7.mha.w_v.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.7.mha.w_v.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.8.fc1.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.8.fc2.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.8.layer_norm_mlp.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.8.layer_norm_mlp.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.8.layer_norm_self_attention.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.8.layer_norm_self_attention.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.8.mha.output.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.8.mha.output.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.8.mha.w_k.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.8.mha.w_k.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.8.mha.w_q.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.8.mha.w_q.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.8.mha.w_v.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.8.mha.w_v.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.9.fc1.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.9.fc2.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.9.layer_norm_mlp.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.9.layer_norm_mlp.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.9.layer_norm_self_attention.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.9.layer_norm_self_attention.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.9.mha.output.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.9.mha.output.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.9.mha.w_k.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.9.mha.w_k.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.9.mha.w_q.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.9.mha.w_q.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.9.mha.w_v.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.attention_blocks.9.mha.w_v.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.embed_layer.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.lm_head._fc1.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.lm_head._fc1.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.lm_head._final_fc.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.lm_head._final_fc.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.lm_head._first_layer_norm.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.lm_head._first_layer_norm.weight": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.lm_head._second_layer_norm.bias": "model-00001-of-00007.safetensors",
+ "biobrain_encoder.esm_model.lm_head._second_layer_norm.weight": "model-00001-of-00007.safetensors",
+ "projection_model.bio_projection.bias": "model-00006-of-00007.safetensors",
+ "projection_model.bio_projection.weight": "model-00006-of-00007.safetensors",
+ "projection_model.perceiver_resampler.latent_queries": "model-00006-of-00007.safetensors",
+ "projection_model.perceiver_resampler.layers.0.cross_attention_1.output.bias": "model-00007-of-00007.safetensors",
+ "projection_model.perceiver_resampler.layers.0.cross_attention_1.output.weight": "model-00007-of-00007.safetensors",
+ "projection_model.perceiver_resampler.layers.0.cross_attention_1.w_k.bias": "model-00006-of-00007.safetensors",
+ "projection_model.perceiver_resampler.layers.0.cross_attention_1.w_k.weight": "model-00006-of-00007.safetensors",
+ "projection_model.perceiver_resampler.layers.0.cross_attention_1.w_q.bias": "model-00006-of-00007.safetensors",
+ "projection_model.perceiver_resampler.layers.0.cross_attention_1.w_q.weight": "model-00006-of-00007.safetensors",
+ "projection_model.perceiver_resampler.layers.0.cross_attention_1.w_v.bias": "model-00007-of-00007.safetensors",
+ "projection_model.perceiver_resampler.layers.0.cross_attention_1.w_v.weight": "model-00007-of-00007.safetensors",
+ "projection_model.perceiver_resampler.layers.0.cross_attention_2.output.bias": "model-00007-of-00007.safetensors",
+ "projection_model.perceiver_resampler.layers.0.cross_attention_2.output.weight": "model-00007-of-00007.safetensors",
+ "projection_model.perceiver_resampler.layers.0.cross_attention_2.w_k.bias": "model-00007-of-00007.safetensors",
+ "projection_model.perceiver_resampler.layers.0.cross_attention_2.w_k.weight": "model-00007-of-00007.safetensors",
+ "projection_model.perceiver_resampler.layers.0.cross_attention_2.w_q.bias": "model-00007-of-00007.safetensors",
+ "projection_model.perceiver_resampler.layers.0.cross_attention_2.w_q.weight": "model-00007-of-00007.safetensors",
+ "projection_model.perceiver_resampler.layers.0.cross_attention_2.w_v.bias": "model-00007-of-00007.safetensors",
+ "projection_model.perceiver_resampler.layers.0.cross_attention_2.w_v.weight": "model-00007-of-00007.safetensors",
+ "projection_model.perceiver_resampler.layers.0.fc1.bias": "model-00007-of-00007.safetensors",
+ "projection_model.perceiver_resampler.layers.0.fc1.weight": "model-00007-of-00007.safetensors",
+ "projection_model.perceiver_resampler.layers.0.fc2.bias": "model-00007-of-00007.safetensors",
+ "projection_model.perceiver_resampler.layers.0.fc2.weight": "model-00007-of-00007.safetensors",
+ "projection_model.perceiver_resampler.layers.0.norm_cross_attention_1.bias": "model-00007-of-00007.safetensors",
+ "projection_model.perceiver_resampler.layers.0.norm_cross_attention_1.weight": "model-00007-of-00007.safetensors",
+ "projection_model.perceiver_resampler.layers.0.norm_cross_attention_2.bias": "model-00007-of-00007.safetensors",
+ "projection_model.perceiver_resampler.layers.0.norm_cross_attention_2.weight": "model-00007-of-00007.safetensors",
+ "projection_model.perceiver_resampler.layers.0.norm_mlp.bias": "model-00007-of-00007.safetensors",
+ "projection_model.perceiver_resampler.layers.0.norm_mlp.weight": "model-00007-of-00007.safetensors",
+ "projection_model.perceiver_resampler.layers.1.cross_attention_1.output.bias": "model-00007-of-00007.safetensors",
+ "projection_model.perceiver_resampler.layers.1.cross_attention_1.output.weight": "model-00007-of-00007.safetensors",
+ "projection_model.perceiver_resampler.layers.1.cross_attention_1.w_k.bias": "model-00007-of-00007.safetensors",
+ "projection_model.perceiver_resampler.layers.1.cross_attention_1.w_k.weight": "model-00007-of-00007.safetensors",
+ "projection_model.perceiver_resampler.layers.1.cross_attention_1.w_q.bias": "model-00007-of-00007.safetensors",
+ "projection_model.perceiver_resampler.layers.1.cross_attention_1.w_q.weight": "model-00007-of-00007.safetensors",
+ "projection_model.perceiver_resampler.layers.1.cross_attention_1.w_v.bias": "model-00007-of-00007.safetensors",
+ "projection_model.perceiver_resampler.layers.1.cross_attention_1.w_v.weight": "model-00007-of-00007.safetensors",
+ "projection_model.perceiver_resampler.layers.1.cross_attention_2.output.bias": "model-00007-of-00007.safetensors",
+ "projection_model.perceiver_resampler.layers.1.cross_attention_2.output.weight": "model-00007-of-00007.safetensors",
+ "projection_model.perceiver_resampler.layers.1.cross_attention_2.w_k.bias": "model-00007-of-00007.safetensors",
+ "projection_model.perceiver_resampler.layers.1.cross_attention_2.w_k.weight": "model-00007-of-00007.safetensors",
+ "projection_model.perceiver_resampler.layers.1.cross_attention_2.w_q.bias": "model-00007-of-00007.safetensors",
+ "projection_model.perceiver_resampler.layers.1.cross_attention_2.w_q.weight": "model-00007-of-00007.safetensors",
+ "projection_model.perceiver_resampler.layers.1.cross_attention_2.w_v.bias": "model-00007-of-00007.safetensors",
+ "projection_model.perceiver_resampler.layers.1.cross_attention_2.w_v.weight": "model-00007-of-00007.safetensors",
+ "projection_model.perceiver_resampler.layers.1.fc1.bias": "model-00007-of-00007.safetensors",
+ "projection_model.perceiver_resampler.layers.1.fc1.weight": "model-00007-of-00007.safetensors",
+ "projection_model.perceiver_resampler.layers.1.fc2.bias": "model-00007-of-00007.safetensors",
+ "projection_model.perceiver_resampler.layers.1.fc2.weight": "model-00007-of-00007.safetensors",
+ "projection_model.perceiver_resampler.layers.1.norm_cross_attention_1.bias": "model-00007-of-00007.safetensors",
+ "projection_model.perceiver_resampler.layers.1.norm_cross_attention_1.weight": "model-00007-of-00007.safetensors",
+ "projection_model.perceiver_resampler.layers.1.norm_cross_attention_2.bias": "model-00007-of-00007.safetensors",
+ "projection_model.perceiver_resampler.layers.1.norm_cross_attention_2.weight": "model-00007-of-00007.safetensors",
+ "projection_model.perceiver_resampler.layers.1.norm_mlp.bias": "model-00007-of-00007.safetensors",
+ "projection_model.perceiver_resampler.layers.1.norm_mlp.weight": "model-00007-of-00007.safetensors",
+ "projection_model.perceiver_resampler.layers.2.cross_attention_1.output.bias": "model-00007-of-00007.safetensors",
+ "projection_model.perceiver_resampler.layers.2.cross_attention_1.output.weight": "model-00007-of-00007.safetensors",
+ "projection_model.perceiver_resampler.layers.2.cross_attention_1.w_k.bias": "model-00007-of-00007.safetensors",
+ "projection_model.perceiver_resampler.layers.2.cross_attention_1.w_k.weight": "model-00007-of-00007.safetensors",
+ "projection_model.perceiver_resampler.layers.2.cross_attention_1.w_q.bias": "model-00007-of-00007.safetensors",
+ "projection_model.perceiver_resampler.layers.2.cross_attention_1.w_q.weight": "model-00007-of-00007.safetensors",
+ "projection_model.perceiver_resampler.layers.2.cross_attention_1.w_v.bias": "model-00007-of-00007.safetensors",
+ "projection_model.perceiver_resampler.layers.2.cross_attention_1.w_v.weight": "model-00007-of-00007.safetensors",
+ "projection_model.perceiver_resampler.layers.2.cross_attention_2.output.bias": "model-00007-of-00007.safetensors",
+ "projection_model.perceiver_resampler.layers.2.cross_attention_2.output.weight": "model-00007-of-00007.safetensors",
+ "projection_model.perceiver_resampler.layers.2.cross_attention_2.w_k.bias": "model-00007-of-00007.safetensors",
+ "projection_model.perceiver_resampler.layers.2.cross_attention_2.w_k.weight": "model-00007-of-00007.safetensors",
+ "projection_model.perceiver_resampler.layers.2.cross_attention_2.w_q.bias": "model-00007-of-00007.safetensors",
+ "projection_model.perceiver_resampler.layers.2.cross_attention_2.w_q.weight": "model-00007-of-00007.safetensors",
+ "projection_model.perceiver_resampler.layers.2.cross_attention_2.w_v.bias": "model-00007-of-00007.safetensors",
+ "projection_model.perceiver_resampler.layers.2.cross_attention_2.w_v.weight": "model-00007-of-00007.safetensors",
+ "projection_model.perceiver_resampler.layers.2.fc1.bias": "model-00007-of-00007.safetensors",
+ "projection_model.perceiver_resampler.layers.2.fc1.weight": "model-00007-of-00007.safetensors",
+ "projection_model.perceiver_resampler.layers.2.fc2.bias": "model-00007-of-00007.safetensors",
+ "projection_model.perceiver_resampler.layers.2.fc2.weight": "model-00007-of-00007.safetensors",
+ "projection_model.perceiver_resampler.layers.2.norm_cross_attention_1.bias": "model-00007-of-00007.safetensors",
+ "projection_model.perceiver_resampler.layers.2.norm_cross_attention_1.weight": "model-00007-of-00007.safetensors",
+ "projection_model.perceiver_resampler.layers.2.norm_cross_attention_2.bias": "model-00007-of-00007.safetensors",
+ "projection_model.perceiver_resampler.layers.2.norm_cross_attention_2.weight": "model-00007-of-00007.safetensors",
+ "projection_model.perceiver_resampler.layers.2.norm_mlp.bias": "model-00007-of-00007.safetensors",
+ "projection_model.perceiver_resampler.layers.2.norm_mlp.weight": "model-00007-of-00007.safetensors",
+ "projection_model.token_embedding.weight": "model-00006-of-00007.safetensors"
+ }
+ }
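
The map above follows the standard sharded-safetensors index convention: each entry in "weight_map" points a parameter name (here the biobrain_encoder and projection_model parameters) to the shard file that stores it, so a loader only has to open the shards it actually needs. As a minimal, illustrative sketch that is not part of this upload, assuming the shard files sit in the same directory as model.safetensors.index.json and that the safetensors package is installed, a single tensor can be resolved and read like this:

# Hypothetical usage sketch: look up one parameter in the weight_map and read it
# directly from its shard. Paths are assumed to be local to the checkpoint directory.
import json

from safetensors import safe_open

with open("model.safetensors.index.json") as f:
    index = json.load(f)

weight_map = index["weight_map"]  # parameter name -> shard file name
name = "projection_model.token_embedding.weight"
shard = weight_map[name]  # "model-00006-of-00007.safetensors" per the entry above

with safe_open(shard, framework="pt", device="cpu") as shard_file:
    tensor = shard_file.get_tensor(name)  # reads only this tensor, not the whole shard

print(name, tuple(tensor.shape), "stored in", shard)

When the checkpoint is loaded through transformers with from_pretrained (likely with trust_remote_code=True, since the modeling code is shipped in this repository), the index is resolved automatically, so a sketch like this is only needed for inspecting individual shards by hand.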