Commit 5dd5c4c (verified) · mgelard committed · Parent: 1b644d1

Upload BulkRNABert

Files changed (1):
  1. bulkrnabert.py +24 -17
bulkrnabert.py CHANGED
@@ -1,6 +1,5 @@
 import logging
-from dataclasses import dataclass, field
-from typing import Optional
+from typing import Any, Optional
 
 import numpy as np
 import torch
@@ -198,23 +197,31 @@ class SelfAttentionBlock(nn.Module):
         return output
 
 
-@dataclass
 class BulkRNABertConfig(PretrainedConfig):
     model_type = "BulkRNABert"
-    n_genes: int = 19_062
-    n_expressions_bins: int = 64
-    embed_dim: int = 256
-    init_gene_embed_dim: int = 200
-    use_gene_embedding: bool = True
-    project_gene_embedding: bool = True
-    num_attention_heads: int = 8
-    key_size: Optional[int] = None
-    ffn_embed_dim: int = 512
-    num_layers: int = 4
-
-    # return
-    embeddings_layers_to_save: tuple[int, ...] = field(default_factory=tuple)
-    attention_maps_to_save: list[tuple[int, int]] = field(default_factory=list)
+
+    def __init__(self, **kwargs: Any) -> None:
+        super().__init__(**kwargs)
+        self.n_genes = kwargs.get("n_genes", 19_062)
+        self.n_expressions_bins = kwargs.get("n_expressions_bins", 64)
+        self.embed_dim = kwargs.get("embed_dim", 256)
+        self.init_gene_embed_dim = kwargs.get("init_gene_embed_dim", 200)
+        self.use_gene_embedding = kwargs.get("use_gene_embedding", True)
+        self.project_gene_embedding = kwargs.get("project_gene_embedding", True)
+        self.num_attention_heads = kwargs.get("num_attention_heads", 8)
+        self.key_size = kwargs.get("key_size", None)
+        self.ffn_embed_dim = kwargs.get("ffn_embed_dim", 512)
+        self.num_layers = kwargs.get("num_layers", 4)
+
+        # return
+        self.embeddings_layers_to_save: tuple[int, ...] = kwargs.get(
+            "embeddings_layers_to_save", ()
+        )
+        self.attention_maps_to_save: list[tuple[int, int]] = kwargs.get(
+            "attention_maps_to_save", []
+        )
+
+        self.__post_init__()
 
     def __post_init__(self):
         # Validate attention key size
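
For orientation, below is a minimal sketch of how the reworked config can be exercised. The class name and default values come from the diff above; importing `bulkrnabert` directly and the exact checks performed inside `__post_init__` are assumptions, so treat this as an illustration rather than the repository's documented usage.

```python
# Sketch only: assumes bulkrnabert.py is importable from the working directory
# and that __post_init__ accepts the default embed_dim / num_attention_heads pair.
from bulkrnabert import BulkRNABertConfig

# Defaults are now pulled from kwargs.get(...) inside __init__
# instead of dataclass fields.
config = BulkRNABertConfig(n_expressions_bins=32, num_layers=6)
assert config.n_genes == 19_062   # untouched default
assert config.num_layers == 6     # overridden via kwargs

# Because attributes are set on the instance in __init__, the standard
# Hugging Face dict round-trip goes through the new constructor.
restored = BulkRNABertConfig(**config.to_dict())
assert restored.n_expressions_bins == 32
```

Replacing the `@dataclass` with an explicit `__init__(**kwargs)` follows the usual `PretrainedConfig` pattern: keyword arguments flow through `super().__init__`, and serialization (`to_dict`, `save_pretrained`) reads the attributes set on the instance.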