mgelard committed
Commit 3c657e7 · verified · 1 Parent(s): 8f7201c

Upload MOJO

Files changed (1)
  1. mojo.py +34 -26
mojo.py CHANGED
@@ -1,7 +1,7 @@
 import logging
 import math
-from dataclasses import dataclass, field
-from typing import Optional, Tuple
+from dataclasses import dataclass
+from typing import Any, Optional, Tuple
 
 import numpy as np
 import torch
@@ -510,32 +510,40 @@ class LMHead(nn.Module):
         return out
 
 
-@dataclass
 class MOJOConfig(PretrainedConfig):  # noqa: N801
     model_type = "MOJO"
-    alphabet_size: dict[str, int] = field(
-        default_factory=lambda: {"rnaseq": 66, "methylation": 66}
-    )
-    token_embed_dim: int = 256
-    init_gene_embed_dim: int = 200
-    use_gene_embedding: bool = True
-    project_gene_embedding: bool = True
-    sequence_length: int = 17_116  # n_genes
-    fixed_sequence_length: int | None = None
-    num_downsamples: int = 8
-    conv_init_embed_dim: int = 512
-    stem_kernel_shape: int = 15
-    embed_dim: int = 512
-    filter_list: list[int] = field(default_factory=list)
-    num_attention_heads: int = 16
-    key_size: Optional[int] = None
-    ffn_embed_dim: int = 1_024
-    num_layers: int = 8
-    num_hidden_layers_head: int = 1
-
-    # return
-    embeddings_layers_to_save: tuple[int, ...] = field(default_factory=tuple)
-    attention_maps_to_save: list[tuple[int, int]] = field(default_factory=list)
+
+    def __init__(self, **kwargs: Any) -> None:
+        super().__init__(**kwargs)
+        self.alphabet_size = kwargs.get(
+            "alphabet_size", {"rnaseq": 66, "methylation": 66}
+        )
+        self.token_embed_dim = kwargs.get("token_embed_dim", 256)
+        self.init_gene_embed_dim = kwargs.get("init_gene_embed_dim", 200)
+        self.use_gene_embedding = kwargs.get("use_gene_embedding", True)
+        self.project_gene_embedding = kwargs.get("project_gene_embedding", True)
+        self.sequence_length = kwargs.get("sequence_length", 17_116)  # n_genes
+        self.fixed_sequence_length = kwargs.get("fixed_sequence_length", None)
+        self.num_downsamples = kwargs.get("num_downsamples", 8)
+        self.conv_init_embed_dim = kwargs.get("conv_init_embed_dim", 512)
+        self.stem_kernel_shape = kwargs.get("stem_kernel_shape", 15)
+        self.embed_dim = kwargs.get("embed_dim", 512)
+        self.filter_list = kwargs.get("filter_list", [])
+        self.num_attention_heads = kwargs.get("num_attention_heads", 16)
+        self.key_size = kwargs.get("key_size", None)
+        self.ffn_embed_dim = kwargs.get("ffn_embed_dim", 1_024)
+        self.num_layers = kwargs.get("num_layers", 8)
+        self.num_hidden_layers_head = kwargs.get("num_hidden_layers_head", 1)
+
+        # return
+        self.embeddings_layers_to_save: tuple[int, ...] = kwargs.get(
+            "embeddings_layers_to_save", ()
+        )
+        self.attention_maps_to_save: list[tuple[int, int]] = kwargs.get(
+            "attention_maps_to_save", []
+        )
+
+        self.__post_init__()
 
     def __post_init__(self):
         # Validate attention key size
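For reference, a minimal usage sketch of the reworked config (not part of the commit): it assumes the mojo.py from this diff is importable and relies only on the standard transformers PretrainedConfig save/load API; the field names come from the diff and the directory name is illustrative.

```python
# Minimal sketch, assuming mojo.py from this commit is on the path and
# transformers is installed; the save directory below is hypothetical.
from mojo import MOJOConfig

# Fields left unspecified fall back to the kwargs.get(...) defaults in __init__,
# and __post_init__() still runs at the end of __init__ to validate them.
config = MOJOConfig(
    alphabet_size={"rnaseq": 66, "methylation": 66},
    embed_dim=512,
    num_layers=8,
)

# With defaults stored as instance attributes (rather than dataclass fields),
# they round-trip through the usual PretrainedConfig JSON serialization.
config.save_pretrained("mojo_config")  # writes mojo_config/config.json
restored = MOJOConfig.from_pretrained("mojo_config")
assert restored.token_embed_dim == 256
```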