Upload MOJO
mojo.py CHANGED
@@ -1,7 +1,7 @@
 import logging
 import math
-from dataclasses import dataclass
-from typing import Optional, Tuple
+from dataclasses import dataclass
+from typing import Any, Optional, Tuple

 import numpy as np
 import torch
@@ -510,32 +510,40 @@ class LMHead(nn.Module):
         return out


-@dataclass
 class MOJOConfig(PretrainedConfig):  # noqa: N801
     model_type = "MOJO"
-    [23 dataclass field definition lines removed; their contents are not shown in this view]
+
+    def __init__(self, **kwargs: Any) -> None:
+        super().__init__(**kwargs)
+        self.alphabet_size = kwargs.get(
+            "alphabet_size", {"rnaseq": 66, "methylation": 66}
+        )
+        self.token_embed_dim = kwargs.get("token_embed_dim", 256)
+        self.init_gene_embed_dim = kwargs.get("init_gene_embed_dim", 200)
+        self.use_gene_embedding = kwargs.get("use_gene_embedding", True)
+        self.project_gene_embedding = kwargs.get("project_gene_embedding", True)
+        self.sequence_length = kwargs.get("sequence_length", 17_116)  # n_genes
+        self.fixed_sequence_length = kwargs.get("fixed_sequence_length", None)
+        self.num_downsamples = kwargs.get("num_downsamples", 8)
+        self.conv_init_embed_dim = kwargs.get("conv_init_embed_dim", 512)
+        self.stem_kernel_shape = kwargs.get("stem_kernel_shape", 15)
+        self.embed_dim = kwargs.get("embed_dim", 512)
+        self.filter_list = kwargs.get("filter_list", [])
+        self.num_attention_heads = kwargs.get("num_attention_heads", 16)
+        self.key_size = kwargs.get("key_size", None)
+        self.ffn_embed_dim = kwargs.get("ffn_embed_dim", 1_024)
+        self.num_layers = kwargs.get("num_layers", 8)
+        self.num_hidden_layers_head = kwargs.get("num_hidden_layers_head", 1)
+
+        # return
+        self.embeddings_layers_to_save: tuple[int, ...] = kwargs.get(
+            "embeddings_layers_to_save", ()
+        )
+        self.attention_maps_to_save: list[tuple[int, int]] = kwargs.get(
+            "attention_maps_to_save", []
+        )
+
+        self.__post_init__()

     def __post_init__(self):
         # Validate attention key size