Commit
·
dc93082
1
Parent(s):
ad2eff7
working version with base model + mcqbert
Browse files- config.json +2 -2
- modeling_mcqbert.py +5 -3
config.json
CHANGED
@@ -5,7 +5,7 @@
|
|
5 |
],
|
6 |
"auto_map": {
|
7 |
"AutoConfig": "configuration_mcqbert.MCQBertConfig",
|
8 |
-
"
|
9 |
},
|
10 |
"attention_probs_dropout_prob": 0.1,
|
11 |
"classifier_dropout": null,
|
@@ -18,7 +18,7 @@
|
|
18 |
"intermediate_size": 3072,
|
19 |
"layer_norm_eps": 1e-12,
|
20 |
"max_position_embeddings": 512,
|
21 |
-
"model_type": "
|
22 |
"num_attention_heads": 12,
|
23 |
"num_hidden_layers": 12,
|
24 |
"pad_token_id": 0,
|
|
|
5 |
],
|
6 |
"auto_map": {
|
7 |
"AutoConfig": "configuration_mcqbert.MCQBertConfig",
|
8 |
+
"AutoModel": "modeling_mcqbert.MCQBert"
|
9 |
},
|
10 |
"attention_probs_dropout_prob": 0.1,
|
11 |
"classifier_dropout": null,
|
|
|
18 |
"intermediate_size": 3072,
|
19 |
"layer_norm_eps": 1e-12,
|
20 |
"max_position_embeddings": 512,
|
21 |
+
"model_type": "mcqbert",
|
22 |
"num_attention_heads": 12,
|
23 |
"num_hidden_layers": 12,
|
24 |
"pad_token_id": 0,
|
modeling_mcqbert.py
CHANGED
@@ -1,12 +1,14 @@
|
|
1 |
-
from transformers import
|
2 |
import torch
|
3 |
|
4 |
from .configuration_mcqbert import MCQBertConfig
|
5 |
|
6 |
-
class MCQBert(
|
7 |
def __init__(self, config: MCQBertConfig):
|
8 |
super().__init__(config)
|
9 |
-
|
|
|
|
|
10 |
|
11 |
cls_input_dim_multiplier = 2 if config.integration_strategy == "cat" else 1
|
12 |
cls_input_dim = self.config.hidden_size * cls_input_dim_multiplier
|
|
|
1 |
+
from transformers import BertModel
|
2 |
import torch
|
3 |
|
4 |
from .configuration_mcqbert import MCQBertConfig
|
5 |
|
6 |
+
class MCQBert(BertModel):
|
7 |
def __init__(self, config: MCQBertConfig):
|
8 |
super().__init__(config)
|
9 |
+
|
10 |
+
if config.integration_strategy is not None:
|
11 |
+
self.student_embedding_layer = torch.nn.Linear(config.student_embedding_size, config.hidden_size)
|
12 |
|
13 |
cls_input_dim_multiplier = 2 if config.integration_strategy == "cat" else 1
|
14 |
cls_input_dim = self.config.hidden_size * cls_input_dim_multiplier
|