mgelard committed on
Commit 8f7201c · verified · 1 Parent(s): 25b4ce1

Upload MOJO

Files changed (4)
  1. README.md +199 -0
  2. config.json +44 -0
  3. model.safetensors +3 -0
  4. mojo.py +763 -0
README.md ADDED
@@ -0,0 +1,199 @@
1
+ ---
2
+ library_name: transformers
3
+ tags: []
4
+ ---
5
+
6
+ # Model Card for Model ID
7
+
8
+ <!-- Provide a quick summary of what the model is/does. -->
9
+
10
+
11
+
12
+ ## Model Details
13
+
14
+ ### Model Description
15
+
16
+ <!-- Provide a longer summary of what this model is. -->
17
+
18
+ This is the model card of a 🤗 transformers model that has been pushed to the Hub. This model card has been automatically generated.
19
+
20
+ - **Developed by:** [More Information Needed]
21
+ - **Funded by [optional]:** [More Information Needed]
22
+ - **Shared by [optional]:** [More Information Needed]
23
+ - **Model type:** [More Information Needed]
24
+ - **Language(s) (NLP):** [More Information Needed]
25
+ - **License:** [More Information Needed]
26
+ - **Finetuned from model [optional]:** [More Information Needed]
27
+
28
+ ### Model Sources [optional]
29
+
30
+ <!-- Provide the basic links for the model. -->
31
+
32
+ - **Repository:** [More Information Needed]
33
+ - **Paper [optional]:** [More Information Needed]
34
+ - **Demo [optional]:** [More Information Needed]
35
+
36
+ ## Uses
37
+
38
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
39
+
40
+ ### Direct Use
41
+
42
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
43
+
44
+ [More Information Needed]
45
+
46
+ ### Downstream Use [optional]
47
+
48
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
49
+
50
+ [More Information Needed]
51
+
52
+ ### Out-of-Scope Use
53
+
54
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
55
+
56
+ [More Information Needed]
57
+
58
+ ## Bias, Risks, and Limitations
59
+
60
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
61
+
62
+ [More Information Needed]
63
+
64
+ ### Recommendations
65
+
66
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
67
+
68
+ Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
69
+
70
+ ## How to Get Started with the Model
71
+
72
+ Use the code below to get started with the model.
73
+
74
+ [More Information Needed]
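A minimal sketch, assuming the checkpoint is loaded through the `auto_map` entries in `config.json` with `trust_remote_code=True`; the repo id below is a placeholder, not documented usage.

```python
import torch
from transformers import AutoConfig, AutoModel

repo_id = "<this-repo-id>"  # hypothetical placeholder

config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True).eval()

# MOJO.forward expects a dict of token ids, one tensor per omic
# ("methylation" and "rnaseq" here), each of shape [batch, fixed_sequence_length].
inputs = {
    omic: torch.randint(0, size, (1, config.fixed_sequence_length))
    for omic, size in config.alphabet_size.items()
}
with torch.no_grad():
    outs = model(inputs)

print(outs["logits"]["rnaseq"].shape)  # [1, fixed_sequence_length, alphabet_size]
```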
75
+
76
+ ## Training Details
77
+
78
+ ### Training Data
79
+
80
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
81
+
82
+ [More Information Needed]
83
+
84
+ ### Training Procedure
85
+
86
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
87
+
88
+ #### Preprocessing [optional]
89
+
90
+ [More Information Needed]
91
+
92
+
93
+ #### Training Hyperparameters
94
+
95
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
96
+
97
+ #### Speeds, Sizes, Times [optional]
98
+
99
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
100
+
101
+ [More Information Needed]
102
+
103
+ ## Evaluation
104
+
105
+ <!-- This section describes the evaluation protocols and provides the results. -->
106
+
107
+ ### Testing Data, Factors & Metrics
108
+
109
+ #### Testing Data
110
+
111
+ <!-- This should link to a Dataset Card if possible. -->
112
+
113
+ [More Information Needed]
114
+
115
+ #### Factors
116
+
117
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
118
+
119
+ [More Information Needed]
120
+
121
+ #### Metrics
122
+
123
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
124
+
125
+ [More Information Needed]
126
+
127
+ ### Results
128
+
129
+ [More Information Needed]
130
+
131
+ #### Summary
132
+
133
+
134
+
135
+ ## Model Examination [optional]
136
+
137
+ <!-- Relevant interpretability work for the model goes here -->
138
+
139
+ [More Information Needed]
140
+
141
+ ## Environmental Impact
142
+
143
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
144
+
145
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
146
+
147
+ - **Hardware Type:** [More Information Needed]
148
+ - **Hours used:** [More Information Needed]
149
+ - **Cloud Provider:** [More Information Needed]
150
+ - **Compute Region:** [More Information Needed]
151
+ - **Carbon Emitted:** [More Information Needed]
152
+
153
+ ## Technical Specifications [optional]
154
+
155
+ ### Model Architecture and Objective
156
+
157
+ [More Information Needed]
158
+
159
+ ### Compute Infrastructure
160
+
161
+ [More Information Needed]
162
+
163
+ #### Hardware
164
+
165
+ [More Information Needed]
166
+
167
+ #### Software
168
+
169
+ [More Information Needed]
170
+
171
+ ## Citation [optional]
172
+
173
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
174
+
175
+ **BibTeX:**
176
+
177
+ [More Information Needed]
178
+
179
+ **APA:**
180
+
181
+ [More Information Needed]
182
+
183
+ ## Glossary [optional]
184
+
185
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
186
+
187
+ [More Information Needed]
188
+
189
+ ## More Information [optional]
190
+
191
+ [More Information Needed]
192
+
193
+ ## Model Card Authors [optional]
194
+
195
+ [More Information Needed]
196
+
197
+ ## Model Card Contact
198
+
199
+ [More Information Needed]
config.json ADDED
@@ -0,0 +1,44 @@
1
+ {
2
+ "alphabet_size": {
3
+ "methylation": 66,
4
+ "rnaseq": 66
5
+ },
6
+ "architectures": [
7
+ "MOJO"
8
+ ],
9
+ "attention_maps_to_save": [],
10
+ "auto_map": {
11
+ "AutoConfig": "mojo.MOJOConfig",
12
+ "AutoModel": "mojo.MOJO"
13
+ },
14
+ "conv_init_embed_dim": 512,
15
+ "embed_dim": 512,
16
+ "embeddings_layers_to_save": [],
17
+ "ffn_embed_dim": 1024,
18
+ "filter_list": [
19
+ 512,
20
+ 512,
21
+ 512,
22
+ 512,
23
+ 512,
24
+ 512,
25
+ 512,
26
+ 512,
27
+ 512
28
+ ],
29
+ "fixed_sequence_length": 17152,
30
+ "init_gene_embed_dim": 200,
31
+ "key_size": 32,
32
+ "model_type": "MOJO",
33
+ "num_attention_heads": 16,
34
+ "num_downsamples": 8,
35
+ "num_hidden_layers_head": 1,
36
+ "num_layers": 8,
37
+ "project_gene_embedding": true,
38
+ "sequence_length": 17116,
39
+ "stem_kernel_shape": 15,
40
+ "token_embed_dim": 256,
41
+ "torch_dtype": "float32",
42
+ "transformers_version": "4.37.2",
43
+ "use_gene_embedding": true
44
+ }
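These values are tied together by `MOJOConfig.__post_init__` in `mojo.py`; a small sanity-check sketch, assuming it is run next to this `config.json`:

```python
import json
import math

with open("config.json") as f:
    cfg = json.load(f)

# The conv tower downsamples the sequence 2**num_downsamples times, so the
# padded ("fixed") length is the input length rounded up to that multiple.
factor = 2 ** cfg["num_downsamples"]  # 2**8 = 256
assert cfg["fixed_sequence_length"] == math.ceil(cfg["sequence_length"] / factor) * factor  # 17152

# One convolution width per resolution level: num_downsamples + 1 entries.
assert len(cfg["filter_list"]) == cfg["num_downsamples"] + 1
```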
model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:98055d27c5646170a1650531a4f410cdee34d51adf9efeee25723c77af8ef0a4
3
+ size 209206776
mojo.py ADDED
@@ -0,0 +1,763 @@
1
+ import logging
2
+ import math
3
+ from dataclasses import dataclass, field
4
+ from typing import Optional, Tuple
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F # noqa: N812
10
+ from transformers import PretrainedConfig, PreTrainedModel
11
+
12
+
13
+ @dataclass
14
+ class RotaryEmbeddingConfig:
15
+ """
16
+ Parameters to initialize the RotaryEmbedding layer. The rescaling factor allows
17
+ to adapt the rotary embeddings to larger lengths than what was used for training.
18
+ One such strategy is presented in the YaRN paper: https://arxiv.org/pdf/2309.00071.pdf. # noqa
19
+ Args:
+ rescaling_factor: rescaling factor applied to the rotary frequencies to extend the usable context length (None disables rescaling).
20
+ """
21
+
22
+ rescaling_factor: Optional[float]
23
+
24
+
25
+ class RotaryEmbedding(torch.nn.Module):
26
+ """
27
+ Rotary position embeddings based on those in
28
+ [RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer).
29
+ Query and keys are transformed by rotation
30
+ matrices which depend on their relative positions.
31
+ """
32
+
33
+ def __init__(self, dim: int, rotary_embedding_config: RotaryEmbeddingConfig):
34
+ super().__init__()
35
+
36
+ # Extract argument from the config
37
+ self.rescaling_factor = rotary_embedding_config.rescaling_factor
38
+ self.upper_freq = 10000
39
+ self.dim = dim
40
+
41
+ self._seq_len_cached = None
42
+ self._cos_cached = None
43
+ self._sin_cached = None
44
+
45
+ def _apply_rotary_pos_emb(
46
+ self,
47
+ heads: torch.Tensor,
48
+ cos: torch.Tensor,
49
+ sin: torch.Tensor,
50
+ ) -> torch.Tensor:
51
+ """ """
52
+ x_first, x_second = (
53
+ heads[..., : heads.shape[-1] // 2],
54
+ heads[..., heads.shape[-1] // 2 :],
55
+ )
56
+
57
+ first_part = x_first * cos - x_second * sin
58
+ second_part = x_second * cos + x_first * sin
59
+
60
+ return torch.cat((first_part, second_part), dim=-1)
61
+
62
+ def _compute_cos_sin_tables(
63
+ self, x: torch.Tensor, inv_freq: torch.Tensor, seq_dimension: int = 2
64
+ ) -> tuple[torch.Tensor, torch.Tensor]:
65
+ seq_len = x.shape[seq_dimension]
66
+ # Recompute the cos/sin tables for the current sequence length and device
67
+ # (they are rebuilt on every call rather than cached across calls)
68
+ self._seq_len_cached = seq_len
69
+ t = torch.arange(x.shape[seq_dimension], device=x.device).type_as(inv_freq)
70
+ freqs = torch.einsum("i, j -> ij", t, inv_freq)
71
+
72
+ self._cos_cached = torch.cos(freqs)[None, :, None, :]
73
+ self._sin_cached = torch.sin(freqs)[None, :, None, :]
74
+ return self._cos_cached, self._sin_cached
75
+
76
+ def forward(
77
+ self, q: torch.Tensor, k: torch.Tensor
78
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
79
+ if self.rescaling_factor is None:
80
+ inv_freq = 1.0 / (
81
+ self.upper_freq ** (torch.arange(0, self.dim, 2).float() / self.dim)
82
+ )
83
+ else:
84
+ updated_base = self.upper_freq * (
85
+ self.rescaling_factor ** (self.dim / (self.dim - 2))
86
+ )
87
+ inv_freq = 1.0 / (
88
+ updated_base ** (torch.arange(0, self.dim, 2).float() / self.dim)
89
+ )
90
+
91
+ self._cos_cached, self._sin_cached = self._compute_cos_sin_tables(
92
+ q,
93
+ inv_freq,
94
+ seq_dimension=-3,
95
+ )
96
+
97
+ return (
98
+ self._apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached),
99
+ self._apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached),
100
+ )
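# Editorial sketch (not part of the committed file): RotaryEmbedding.forward
# builds its cos/sin tables with seq_dimension=-3, i.e. it expects q and k of
# shape [batch, seq_len, num_heads, key_size] with an even key_size, e.g.
#   rope = RotaryEmbedding(32, RotaryEmbeddingConfig(rescaling_factor=None))
#   q = torch.randn(1, 67, 16, 32)  # [batch, seq, heads, key_size]
#   k = torch.randn(1, 67, 16, 32)
#   q_rot, k_rot = rope(q, k)       # same shapes, relative positions encoded by rotation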
101
+
102
+
103
+ class ResidualConvBlock(nn.Module):
104
+ """
105
+ Conv Block with Residual connection.
106
+ """
107
+
108
+ def __init__(
109
+ self, dim_in: int, dim_out: int, layer_norm_shape: int, kernel_size: int = 1
110
+ ):
111
+ super().__init__()
112
+ self.conv_block = ConvBlock(
113
+ dim_in=dim_in,
114
+ dim_out=dim_out,
115
+ layer_norm_shape=layer_norm_shape,
116
+ kernel_size=kernel_size,
117
+ )
118
+
119
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
120
+ y = self.conv_block(x)
121
+ return x.reshape(y.shape) + y
122
+
123
+
124
+ class ConvBlock(nn.Module):
125
+ """
126
+ Conv Block.
127
+ """
128
+
129
+ def __init__(
130
+ self, dim_in: int, dim_out: int, layer_norm_shape: int, kernel_size: int = 1
131
+ ):
132
+ super().__init__()
133
+ self.conv = nn.Conv1d(
134
+ in_channels=dim_in,
135
+ out_channels=dim_out,
136
+ kernel_size=kernel_size,
137
+ padding="same",
138
+ )
139
+ self.layer_norm = nn.LayerNorm(layer_norm_shape, eps=1e-5)
140
+
141
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
142
+ x = x.permute(0, 2, 1)
143
+ x = self.layer_norm(x)
144
+ x = x.permute(0, 2, 1)
145
+ x = self.conv(x)
146
+ x = F.gelu(x, approximate="tanh")
147
+ return x
148
+
149
+
150
+ class ConvTowerBlock(nn.Module):
151
+ def __init__(
152
+ self,
153
+ dim_in: int,
154
+ dim_out: int,
155
+ conv_layer_norm_shape: int,
156
+ resconv_layer_norm_shape: int,
157
+ kernel_size: int,
158
+ ) -> None:
159
+ super().__init__()
160
+ self.conv_layer = ConvBlock(
161
+ dim_in=dim_in,
162
+ dim_out=dim_out,
163
+ layer_norm_shape=conv_layer_norm_shape,
164
+ kernel_size=kernel_size,
165
+ )
166
+ self.res_conv = ResidualConvBlock(
167
+ dim_in=dim_out,
168
+ dim_out=dim_out,
169
+ layer_norm_shape=resconv_layer_norm_shape,
170
+ kernel_size=1,
171
+ )
172
+ self.avg_pool = nn.AvgPool1d(kernel_size=2, stride=2)
173
+
174
+ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
175
+ residual = x
176
+ x = self.conv_layer(x)
177
+ x = self.res_conv(x)
178
+ x = self.avg_pool(x)
179
+ return x, residual
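# Editorial sketch (not part of the committed file): each ConvTowerBlock halves
# the sequence length (AvgPool1d, stride 2) and also returns its input, before
# convolution and pooling, as the residual consumed later by the matching
# DeConvTowerBlock, e.g.
#   block = ConvTowerBlock(dim_in=512, dim_out=512, conv_layer_norm_shape=512,
#                          resconv_layer_norm_shape=512, kernel_size=5)
#   x = torch.randn(1, 512, 17152)  # [batch, channels, length]
#   y, res = block(x)               # y: [1, 512, 8576], res: [1, 512, 17152]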
180
+
181
+
182
+ class ResidualDeConvBlock(nn.Module):
183
+ """
184
+ Conv Block with Residual connection.
185
+ """
186
+
187
+ def __init__(
188
+ self,
189
+ dim_in: int,
190
+ dim_out: int,
191
+ layer_norm_shape: int,
192
+ kernel_size: int = 1,
193
+ stride: int = 1,
194
+ ):
195
+ super().__init__()
196
+ self.deconv_block = DeConvBlock(
197
+ dim_in=dim_in,
198
+ dim_out=dim_out,
199
+ layer_norm_shape=layer_norm_shape,
200
+ kernel_size=kernel_size,
201
+ stride=stride,
202
+ )
203
+
204
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
205
+ y = self.deconv_block(x)
206
+ return x.reshape(y.shape) + y
207
+
208
+
209
+ class DeConvBlock(nn.Module):
210
+ """
211
+ DeConv Block.
212
+ """
213
+
214
+ def __init__(
215
+ self,
216
+ dim_in: int,
217
+ dim_out: int,
218
+ layer_norm_shape: int,
219
+ kernel_size: int = 1,
220
+ stride: int = 1,
221
+ ):
222
+ super().__init__()
223
+ self.deconv = nn.ConvTranspose1d(
224
+ in_channels=dim_in,
225
+ out_channels=dim_out,
226
+ kernel_size=kernel_size,
227
+ stride=stride,
228
+ padding=0,
229
+ )
230
+ self.layer_norm = nn.LayerNorm(layer_norm_shape)
231
+ self.kernel_size = kernel_size
232
+
233
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
234
+ x = x.permute(0, 2, 1)
235
+ x = self.layer_norm(x)
236
+ x = x.permute(0, 2, 1)
237
+ x = self.deconv(x)
238
+ if self.kernel_size == 5:
239
+ # handle the special case where haiku
240
+ # deconv removes padding automatically
241
+ x = x[:, :, 1:-2]
242
+ x = F.gelu(x, approximate="tanh")
243
+ return x
244
+
245
+
246
+ class DeConvTowerBlock(nn.Module):
247
+ def __init__(
248
+ self,
249
+ dim_in: int,
250
+ dim_out: int,
251
+ kernel_size: int,
252
+ conv_layer_norm_shape: int,
253
+ resconv_layer_norm_shape: int,
254
+ stride: int = 2,
255
+ ):
256
+ super().__init__()
257
+ self.deconv_block = DeConvBlock(
258
+ dim_in=dim_in,
259
+ dim_out=dim_out,
260
+ layer_norm_shape=conv_layer_norm_shape,
261
+ kernel_size=kernel_size,
262
+ stride=stride,
263
+ )
264
+ self.res_deconv_block = ResidualDeConvBlock(
265
+ dim_in=dim_out,
266
+ dim_out=dim_out,
267
+ layer_norm_shape=resconv_layer_norm_shape,
268
+ kernel_size=1,
269
+ )
270
+
271
+ def forward(self, x: torch.Tensor, res: torch.Tensor) -> torch.Tensor:
272
+ x = self.deconv_block(x)
273
+ x = self.res_deconv_block(x)
274
+ x = x + res
275
+ return x
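# Editorial sketch (not part of the committed file): DeConvTowerBlock mirrors
# ConvTowerBlock, doubling the sequence length with a stride-2 ConvTranspose1d
# (the kernel_size == 5 branch above trims the extra transposed-conv padding)
# and adding back the residual saved on the way down, e.g.
#   block = DeConvTowerBlock(dim_in=512, dim_out=512, kernel_size=5,
#                            conv_layer_norm_shape=512, resconv_layer_norm_shape=512)
#   x = torch.randn(1, 512, 67)     # bottleneck resolution
#   res = torch.randn(1, 512, 134)  # residual from the matching conv block
#   y = block(x, res)               # [1, 512, 134]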
276
+
277
+
278
+ class MultiHeadAttention(nn.Module):
279
+ def __init__(
280
+ self,
281
+ num_heads: int,
282
+ key_size: int,
283
+ rotary_embedding_config: Optional[RotaryEmbeddingConfig] = None,
284
+ add_bias_kv: bool = False,
285
+ value_size: Optional[int] = None,
286
+ model_size: Optional[int] = None,
287
+ name: Optional[str] = None,
288
+ ):
289
+ super().__init__()
290
+ if not model_size:
291
+ model_size = key_size
292
+ if not value_size:
293
+ value_size = key_size
294
+ self.model_size = model_size
295
+ self.key_size = key_size
296
+ self.value_size = value_size
297
+ self.add_bias_kv = add_bias_kv
298
+ self.name = name
299
+ self.num_heads = num_heads
300
+ self._rotary_embedding_config = rotary_embedding_config
301
+
302
+ self.w_k = nn.Linear(self.model_size, self.num_heads * self.key_size)
303
+ self.w_q = nn.Linear(self.model_size, self.num_heads * self.key_size)
304
+ self.w_v = nn.Linear(self.model_size, self.num_heads * self.value_size)
305
+ self.output = nn.Linear(self.num_heads * self.value_size, self.model_size)
306
+ if self._rotary_embedding_config:
307
+ self._rotary_embedding = RotaryEmbedding(
308
+ self.key_size, self._rotary_embedding_config
309
+ )
310
+
311
+ def apply_rotary_embeddings(
312
+ self,
313
+ query: torch.Tensor,
314
+ key: torch.Tensor,
315
+ ) -> tuple[torch.Tensor, torch.Tensor]:
316
+ """ """
317
+ query, key = self._rotary_embedding(query, key)
318
+ return query, key
319
+
320
+ def forward(
321
+ self,
322
+ query: torch.Tensor,
323
+ key: torch.Tensor,
324
+ value: torch.Tensor,
325
+ attention_mask: Optional[torch.Tensor] = None,
326
+ attention_weight_bias: Optional[torch.Tensor] = None,
327
+ ) -> dict[str, torch.Tensor]:
328
+ """
329
+ Returns:
330
+ dictionary containing attention weights
331
+ and outputs.
332
+ """
333
+ key_heads = self.w_k(key).reshape(
334
+ (*key.shape[:-1], self.num_heads, self.key_size)
335
+ )
336
+ query_heads = self.w_q(query).reshape(
337
+ (*query.shape[:-1], self.num_heads, self.key_size)
338
+ )
339
+ value_heads = self.w_v(value).reshape(
340
+ (*value.shape[:-1], self.num_heads, self.value_size)
341
+ )
342
+ if self._rotary_embedding_config:
343
+ query_heads, key_heads = self.apply_rotary_embeddings(
344
+ query_heads, key_heads
345
+ )
346
+ attention_weights = torch.einsum(
347
+ "...thd, ...Thd -> ...htT", query_heads, key_heads
348
+ )
349
+ sqrt_key_size = np.sqrt(self.key_size)
350
+ attention_weights = attention_weights / sqrt_key_size
351
+ if attention_mask is not None:
352
+ attention_weights = torch.where(attention_mask, attention_weights, -1e30)
353
+ if attention_weight_bias is not None:
354
+ attention_weights = F.softmax(
355
+ attention_weights + attention_weight_bias, dim=-1
356
+ )
357
+ else:
358
+ attention_weights = F.softmax(attention_weights, dim=-1)
359
+ value_out = torch.einsum(
360
+ "...htT, ...Thd->...thd", attention_weights, value_heads
361
+ )
362
+ value_out = value_out.reshape((*value_out.shape[:-2], -1))
363
+ embeddings = self.output(value_out)
364
+
365
+ return {"attention_weights": attention_weights, "embeddings": embeddings}
366
+
367
+
368
+ class SelfAttentionBlock(nn.Module):
369
+ def __init__(
370
+ self,
371
+ num_heads: int,
372
+ embed_dim: int,
373
+ ffn_embed_dim: int,
374
+ key_size: Optional[int] = None,
375
+ add_bias_kv: bool = False,
376
+ add_bias_fnn: bool = True,
377
+ ffn_activation_name: str = "gelu-no-approx",
378
+ use_glu_in_ffn: bool = False,
379
+ layer_norm_eps: float = 1e-5, # this is the default haiku value
380
+ pre_layer_norm: bool = True,
381
+ name: Optional[str] = None,
382
+ rotary_embedding_config: Optional[RotaryEmbeddingConfig] = None,
383
+ ):
384
+ super().__init__()
385
+ if key_size is None:
386
+ if embed_dim % num_heads != 0:
387
+ raise ValueError(
388
+ f"The embedding dimension should be divisible by the number of "
389
+ f"heads, however provided embedding dimension is {embed_dim} and "
390
+ f"the number of heads is {num_heads}."
391
+ )
392
+ else:
393
+ key_size = embed_dim // num_heads
394
+
395
+ # Get ffn activation function
396
+ self._pre_layer_norm = pre_layer_norm
397
+ self._use_glu_in_fnn = use_glu_in_ffn
398
+ # Define layers
399
+ if use_glu_in_ffn:
400
+ # user should multiply ffn_embed_dim by 2/3 when using GLU
401
+ # to keep total number of parameters equal
402
+ # see https://arxiv.org/pdf/2002.05202.pdf. for more details
403
+ # we multiply by 2 here as the output will be split in 2 for GLU
404
+ self.fc1 = nn.Linear(embed_dim, int(2 * ffn_embed_dim), bias=add_bias_fnn)
405
+ else:
406
+ self.fc1 = nn.Linear(embed_dim, ffn_embed_dim, bias=add_bias_fnn)
407
+
408
+ self.fc2 = nn.Linear(ffn_embed_dim, embed_dim, bias=add_bias_fnn)
409
+
410
+ self.layer_norm_self_attention = nn.LayerNorm(
411
+ embed_dim,
412
+ )
413
+ self.layer_norm_mlp = nn.LayerNorm(embed_dim)
414
+ if ffn_activation_name == "swish":
415
+ self._ffn_activation_fn = nn.SiLU()
416
+ elif ffn_activation_name == "gelu-no-approx":
417
+ self._ffn_activation_fn = lambda x: F.gelu(x, approximate="none")
418
+ else:
419
+ self._ffn_activation_fn = getattr(torch.nn, ffn_activation_name)()  # instantiate the activation module
420
+
421
+ self.mha = MultiHeadAttention(
422
+ num_heads=num_heads,
423
+ key_size=key_size,
424
+ add_bias_kv=add_bias_kv,
425
+ model_size=embed_dim,
426
+ name="self_attention",
427
+ rotary_embedding_config=rotary_embedding_config,
428
+ )
429
+
430
+ def mlp(self, embed: torch.Tensor) -> torch.Tensor:
431
+
432
+ if self._pre_layer_norm:
433
+ x = self.layer_norm_mlp(embed)
434
+ else:
435
+ x = embed
436
+
437
+ if self._use_glu_in_fnn:
438
+ x = self.fc1(x)
439
+ x1, x2 = torch.split(x, split_size_or_sections=x.shape[-1] // 2, dim=-1)
440
+ x = self._ffn_activation_fn(x1) * x2
441
+ else:
442
+ x = self._ffn_activation_fn(self.fc1(x))
443
+ x = self.fc2(x)
444
+
445
+ if not self._pre_layer_norm:
446
+ x = self.layer_norm_mlp(x + embed)
447
+ return x
448
+
449
+ def forward(
450
+ self,
451
+ x: torch.Tensor,
452
+ attention_mask: Optional[torch.Tensor] = None,
453
+ attention_weight_bias: Optional[torch.Tensor] = None,
454
+ ) -> torch.Tensor:
455
+
456
+ res = x
457
+ if self._pre_layer_norm:
458
+ x = self.layer_norm_self_attention(x)
459
+
460
+ output = self.mha(
461
+ x,
462
+ x,
463
+ x,
464
+ attention_mask=attention_mask,
465
+ attention_weight_bias=attention_weight_bias,
466
+ )
467
+
468
+ if not self._pre_layer_norm:
469
+ output["embeddings"] = self.layer_norm_self_attention(
470
+ output["embeddings"] + res
471
+ )
472
+
473
+ x = output["embeddings"]
474
+ else:
475
+ x = output["embeddings"]
476
+ x = res + x
477
+
478
+ # MLP
479
+ if not self._pre_layer_norm:
480
+ x = self.mlp(x)
481
+ else:
482
+ x = x + self.mlp(x)
483
+
484
+ output["embeddings"] = x
485
+ return output
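# Editorial note (not part of the committed file): with pre_layer_norm=True (the
# setting MOJO uses below), this block follows the pre-LN transformer pattern:
#   x -> LayerNorm -> self-attention -> +residual -> LayerNorm -> (GLU) MLP -> +residual
# and returns a dict whose "embeddings" entry holds the updated hidden states,
# with the raw "attention_weights" kept alongside for optional saving.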
486
+
487
+
488
+ class LMHead(nn.Module):
489
+ def __init__(
490
+ self, dim_in: int, embed_dim: int, dim_out: int, num_hidden_layers: int
491
+ ) -> None:
492
+ """ """
493
+ super().__init__()
494
+ self.num_hidden_layers = num_hidden_layers
495
+ self.linear_layers = nn.ModuleList([nn.Linear(dim_in, embed_dim)])
496
+ self.linear_layers.extend(
497
+ nn.ModuleList(
498
+ nn.Linear(embed_dim, embed_dim) # noqa
499
+ for _ in range(num_hidden_layers - 1)
500
+ )
501
+ )
502
+ self.linear_out = nn.Linear(embed_dim, dim_out)
503
+
504
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
505
+ x = F.gelu(x, approximate="tanh")
506
+ for layer in self.linear_layers:
507
+ x = layer(x)
508
+ x = F.gelu(x, approximate="tanh")
509
+ out = self.linear_out(x)
510
+ return out
511
+
512
+
513
+ @dataclass
514
+ class MOJOConfig(PretrainedConfig): # noqa: N801
515
+ model_type = "MOJO"
516
+ alphabet_size: dict[str, int] = field(
517
+ default_factory=lambda: {"rnaseq": 66, "methylation": 66}
518
+ )
519
+ token_embed_dim: int = 256
520
+ init_gene_embed_dim: int = 200
521
+ use_gene_embedding: bool = True
522
+ project_gene_embedding: bool = True
523
+ sequence_length: int = 17_116 # n_genes
524
+ fixed_sequence_length: int | None = None
525
+ num_downsamples: int = 8
526
+ conv_init_embed_dim: int = 512
527
+ stem_kernel_shape: int = 15
528
+ embed_dim: int = 512
529
+ filter_list: list[int] = field(default_factory=list)
530
+ num_attention_heads: int = 16
531
+ key_size: Optional[int] = None
532
+ ffn_embed_dim: int = 1_024
533
+ num_layers: int = 8
534
+ num_hidden_layers_head: int = 1
535
+
536
+ # return
537
+ embeddings_layers_to_save: tuple[int, ...] = field(default_factory=tuple)
538
+ attention_maps_to_save: list[tuple[int, int]] = field(default_factory=list)
539
+
540
+ def __post_init__(self):
541
+ # Validate attention key size
542
+ key_size = self.key_size
543
+ if key_size is None:
544
+ embed_dim = self.embed_dim
545
+ num_attention_heads = self.num_attention_heads
546
+ if not embed_dim % num_attention_heads == 0:
547
+ raise ValueError(
548
+ f"When no key size is provided, the embedding dimension should be "
549
+ f"divisible by the number of heads, however provided embedding "
550
+ f"dimension is {embed_dim} and the number of heads is "
551
+ f"{num_attention_heads}."
552
+ )
553
+ self.key_size = embed_dim // num_attention_heads
554
+
555
+ # Validate gene embedding projection
556
+ use_gene_embedding = self.use_gene_embedding
557
+ if use_gene_embedding:
558
+ init_gene_embed_dim = self.init_gene_embed_dim
559
+ token_embed_dim = self.token_embed_dim
560
+ if init_gene_embed_dim != token_embed_dim:
561
+ project_gene_embedding = self.project_gene_embedding
562
+ if not project_gene_embedding:
563
+ logging.warning(
564
+ f"Init gene embedding dimension ({init_gene_embed_dim})"
565
+ f"different than token embedding dimension ({token_embed_dim})."
566
+ f"Setting `project_gene_embedding` to True"
567
+ )
568
+ self.project_gene_embedding = True
569
+
570
+ # Compute fixed_sequence_length
571
+ num_downsamples = self.num_downsamples
572
+ sequence_length = self.sequence_length
573
+ downsample_factor = 2**num_downsamples
574
+ fixed_sequence_length = (
575
+ math.ceil(sequence_length / downsample_factor) * downsample_factor
576
+ )
577
+ self.fixed_sequence_length = fixed_sequence_length
578
+
579
+ # Create filters list
580
+ num_downsamples = self.num_downsamples
581
+ filter_list = (
582
+ np.linspace(
583
+ self.conv_init_embed_dim,
584
+ self.embed_dim,
585
+ num_downsamples + 1,
586
+ )
587
+ .astype(int)
588
+ .tolist()
589
+ )
590
+ self.filter_list = filter_list # noqa
591
+
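# Editorial note (not part of the committed file): because conv_init_embed_dim and
# embed_dim are both 512 in the shipped config, the linspace above degenerates to a
# constant list, matching config.json:
#   np.linspace(512, 512, 8 + 1).astype(int).tolist() == [512] * 9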
592
+
593
+ class MOJO(PreTrainedModel): # noqa: N801
594
+ config_class = MOJOConfig
595
+
596
+ def __init__(self, config: MOJOConfig):
597
+ super().__init__(config=config)
598
+
599
+ # Embeddings
600
+ self.embedding_layers = nn.ModuleDict(
601
+ {
602
+ omic: nn.Embedding(config.alphabet_size[omic], config.token_embed_dim)
603
+ for omic in config.alphabet_size
604
+ }
605
+ )
606
+
607
+ self.gene_embedding_layer = nn.Embedding(
608
+ config.fixed_sequence_length,
609
+ config.init_gene_embed_dim,
610
+ )
611
+ self.fc_gene_embedding = nn.Linear(
612
+ config.init_gene_embed_dim, config.token_embed_dim
613
+ )
614
+
615
+ # Convolutions
616
+ self.stem_conv = nn.Sequential(
617
+ nn.Conv1d(
618
+ in_channels=config.token_embed_dim,
619
+ out_channels=config.conv_init_embed_dim,
620
+ kernel_size=config.stem_kernel_shape,
621
+ padding="same",
622
+ ),
623
+ nn.GELU(approximate="tanh"),
624
+ )
625
+
626
+ self.conv_tower = nn.ModuleList(
627
+ [
628
+ ConvTowerBlock(
629
+ dim_in=config.filter_list[i],
630
+ dim_out=config.filter_list[i + 1],
631
+ kernel_size=5,
632
+ conv_layer_norm_shape=config.filter_list[i],
633
+ resconv_layer_norm_shape=config.filter_list[i + 1],
634
+ )
635
+ for i in range(len(config.filter_list) - 1)
636
+ ]
637
+ )
638
+
639
+ # Transformer
640
+ attention_maps_to_save = config.attention_maps_to_save
641
+ self._attention_layers_to_save = list({t[0] for t in attention_maps_to_save})
642
+
643
+ self._attention_maps_per_layer_to_save = {
644
+ layer: [t[1] for t in attention_maps_to_save if t[0] == layer]
645
+ for layer in self._attention_layers_to_save
646
+ }
647
+
648
+ max_layer = max(self._attention_layers_to_save + [0])
649
+ if max_layer > config.num_layers:
650
+ raise ValueError(
651
+ f"You are requiring attention maps for layer {max_layer}, "
652
+ f"while the model has {config.num_layers} layers only."
653
+ )
654
+ self._rotary_embedding_config = RotaryEmbeddingConfig(rescaling_factor=None)
655
+ self.transformer_layers = nn.ModuleList(
656
+ [
657
+ SelfAttentionBlock(
658
+ num_heads=config.num_attention_heads,
659
+ embed_dim=config.embed_dim,
660
+ ffn_embed_dim=config.ffn_embed_dim,
661
+ key_size=config.key_size,
662
+ add_bias_kv=False,
663
+ add_bias_fnn=False,
664
+ ffn_activation_name="swish",
665
+ use_glu_in_ffn=True,
666
+ layer_norm_eps=1e-5, # this is the default haiku value
667
+ pre_layer_norm=True,
668
+ name=f"attention_layer_{layer_idx}",
669
+ rotary_embedding_config=self._rotary_embedding_config,
670
+ )
671
+ for layer_idx in range(config.num_layers)
672
+ ]
673
+ )
674
+
675
+ # Deconvolutions
676
+ self.deconv_tower = nn.ModuleList(
677
+ [
678
+ DeConvTowerBlock(
679
+ dim_in=config.filter_list[-1 - i],
680
+ dim_out=config.filter_list[-1 - i - 1],
681
+ kernel_size=5,
682
+ stride=2,
683
+ conv_layer_norm_shape=config.filter_list[-1 - i],
684
+ resconv_layer_norm_shape=config.filter_list[-1 - i - 1],
685
+ )
686
+ for i in range(len(config.filter_list) - 1)
687
+ ]
688
+ )
689
+
690
+ # Language Modeling heads
691
+ self.omic_lm_heads = nn.ModuleDict(
692
+ {
693
+ omic: LMHead(
694
+ dim_in=config.conv_init_embed_dim,
695
+ embed_dim=config.embed_dim,
696
+ dim_out=config.alphabet_size[omic],
697
+ num_hidden_layers=config.num_hidden_layers_head,
698
+ )
699
+ for omic in self.config.alphabet_size
700
+ }
701
+ )
702
+
703
+ def get_embeddings(
704
+ self,
705
+ input_ids: dict[str, torch.Tensor],
706
+ ) -> dict[str, torch.Tensor]:
707
+ omic_embeddings = {}
708
+ for omic, omic_tokens in input_ids.items():
709
+ omic_embeddings[omic] = self.embedding_layers[omic](omic_tokens)
710
+ return omic_embeddings
711
+
712
+ def forward(self, input_ids: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
713
+ outs = {}
714
+ embeddings = self.get_embeddings(input_ids)
715
+ outs["omic_embeddings"] = embeddings
716
+ x = torch.stack(list(embeddings.values()), dim=0).sum(dim=0) # [B, T, C]
717
+ outs["embeddings"] = x
718
+
719
+ if self.config.use_gene_embedding:
720
+ gene_indices = torch.arange(
721
+ self.config.fixed_sequence_length, device=x.device
722
+ )
723
+ gene_embedding = self.gene_embedding_layer(gene_indices)
724
+ if self.config.project_gene_embedding:
725
+ gene_embedding = self.fc_gene_embedding(gene_embedding)
726
+ x = x + gene_embedding
727
+ outs["embeddings_with_gene_embedding"] = x
728
+
729
+ x = x.permute(0, 2, 1)
730
+ x = self.stem_conv(x)
731
+ outs["stem"] = x
732
+
733
+ residuals = []
734
+ for conv_block in self.conv_tower:
735
+ x, res = conv_block(x)
736
+ residuals.append(res)
737
+ x = x.permute(0, 2, 1)
738
+ outs["conv_tower"] = x
739
+ outs["conv_tower_residuals"] = residuals # type: ignore
740
+ residuals = residuals[::-1]
741
+
742
+ for layer_idx, transformer in enumerate(self.transformer_layers):
743
+ output = transformer(x)
744
+ x = output["embeddings"]
745
+ if (layer_idx + 1) in self.config.embeddings_layers_to_save:
746
+ outs[f"embeddings_{(layer_idx + 1)}"] = output["embeddings"]
747
+ if (layer_idx + 1) in self._attention_layers_to_save:
748
+ for map_number in self._attention_maps_per_layer_to_save[layer_idx + 1]:
749
+ dkey = f"attention_map_layer_{layer_idx + 1}_number_{map_number}"
750
+ outs[dkey] = output["attention_weights"][:, map_number + 1]
751
+ outs["after_transformer_embedding"] = x
752
+
753
+ x = x.permute(0, 2, 1)
754
+ for deconv_block, res in zip(self.deconv_tower, residuals):
755
+ x = deconv_block(x, res)
756
+ x = x.permute(0, 2, 1)
757
+ outs["deconv_tower"] = x
758
+
759
+ outs["logits"] = {
760
+ omic: self.omic_lm_heads[omic](x) for omic in self.config.alphabet_size
761
+ }
762
+
763
+ return outs
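To make the forward contract concrete, here is a hedged sketch of running the model directly from this file (assuming `mojo.py` imports cleanly with the pinned transformers version; the defaults mirror the shipped `config.json`):

```python
import torch

from mojo import MOJO, MOJOConfig

# Request the layer-8 hidden states and one attention map from layer 1
# (layer/head indices follow the bookkeeping in MOJO.forward above).
config = MOJOConfig(
    embeddings_layers_to_save=(8,),
    attention_maps_to_save=[(1, 0)],
)
model = MOJO(config).eval()

inputs = {
    omic: torch.randint(0, size, (1, config.fixed_sequence_length))
    for omic, size in config.alphabet_size.items()
}
with torch.no_grad():
    outs = model(inputs)

print(outs["logits"]["rnaseq"].shape)                 # torch.Size([1, 17152, 66])
print(outs["embeddings_8"].shape)                     # bottleneck hidden states, [1, 67, 512]
print(outs["attention_map_layer_1_number_0"].shape)   # [1, 67, 67]
```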