Upload BulkRNABert
bulkrnabert.py (+24 -17)
@@ -1,6 +1,5 @@
 import logging
-from dataclasses import dataclass, field
-from typing import Optional
+from typing import Any, Optional
 
 import numpy as np
 import torch
@@ -198,23 +197,31 @@ class SelfAttentionBlock(nn.Module):
         return output
 
 
-@dataclass
 class BulkRNABertConfig(PretrainedConfig):
     model_type = "BulkRNABert"
-    n_genes: int = 19_062
-    n_expressions_bins: int = 64
-    embed_dim: int = 256
-    init_gene_embed_dim: int = 200
-    use_gene_embedding: bool = True
-    project_gene_embedding: bool = True
-    num_attention_heads: int = 8
-    key_size: Optional[int] = None
-    ffn_embed_dim: int = 512
-    num_layers: int = 4
-
-    # return
-    embeddings_layers_to_save: tuple[int, ...] = ()
-    attention_maps_to_save: list[tuple[int, int]] = field(default_factory=list)
+
+    def __init__(self, **kwargs: Any) -> None:
+        super().__init__(**kwargs)
+        self.n_genes = kwargs.get("n_genes", 19_062)
+        self.n_expressions_bins = kwargs.get("n_expressions_bins", 64)
+        self.embed_dim = kwargs.get("embed_dim", 256)
+        self.init_gene_embed_dim = kwargs.get("init_gene_embed_dim", 200)
+        self.use_gene_embedding = kwargs.get("use_gene_embedding", True)
+        self.project_gene_embedding = kwargs.get("project_gene_embedding", True)
+        self.num_attention_heads = kwargs.get("num_attention_heads", 8)
+        self.key_size = kwargs.get("key_size", None)
+        self.ffn_embed_dim = kwargs.get("ffn_embed_dim", 512)
+        self.num_layers = kwargs.get("num_layers", 4)
+
+        # return
+        self.embeddings_layers_to_save: tuple[int, ...] = kwargs.get(
+            "embeddings_layers_to_save", ()
+        )
+        self.attention_maps_to_save: list[tuple[int, int]] = kwargs.get(
+            "attention_maps_to_save", []
+        )
+
+        self.__post_init__()
 
     def __post_init__(self):
        # Validate attention key size
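
For reference, a minimal usage sketch of the kwargs-driven config introduced above. It is not part of the commit: the import path and the override values are illustrative assumptions, and save_pretrained/from_pretrained are the standard transformers.PretrainedConfig methods.

# Hypothetical usage sketch for the refactored BulkRNABertConfig.
from bulkrnabert import BulkRNABertConfig  # assumed import path for this repo's module

# Keys that are not passed fall back to the defaults read via kwargs.get(...),
# e.g. embed_dim=256, num_attention_heads=8, num_layers=4.
cfg = BulkRNABertConfig()

# Overrides are plain keyword arguments; __init__ then calls __post_init__,
# which validates the attention key size per the comment in the diff.
cfg = BulkRNABertConfig(embed_dim=512, num_attention_heads=16, key_size=32)

# Layer and attention-map selections keep the annotated container types
# (a tuple of layer indices and a list of int pairs); values here are illustrative.
cfg = BulkRNABertConfig(embeddings_layers_to_save=(2, 4), attention_maps_to_save=[(1, 0), (3, 2)])

# Standard PretrainedConfig serialization round-trip.
cfg.save_pretrained("./bulkrnabert-config")  # writes config.json
cfg = BulkRNABertConfig.from_pretrained("./bulkrnabert-config")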