tommymarto committed
Commit: dc93082
Parent(s): ad2eff7

working version with base model + mcqbert

Files changed (2):
  1. config.json (+2 -2)
  2. modeling_mcqbert.py (+5 -3)
config.json CHANGED
@@ -5,7 +5,7 @@
   ],
   "auto_map": {
     "AutoConfig": "configuration_mcqbert.MCQBertConfig",
-    "AutoModelForCausalLM": "modeling_mcqbert.MCQBert"
+    "AutoModel": "modeling_mcqbert.MCQBert"
   },
   "attention_probs_dropout_prob": 0.1,
   "classifier_dropout": null,
@@ -18,7 +18,7 @@
   "intermediate_size": 3072,
   "layer_norm_eps": 1e-12,
   "max_position_embeddings": 512,
-  "model_type": "bert",
+  "model_type": "mcqbert",
   "num_attention_heads": 12,
   "num_hidden_layers": 12,
   "pad_token_id": 0,
modeling_mcqbert.py CHANGED
@@ -1,12 +1,14 @@
-from transformers import BertPreTrainedModel
+from transformers import BertModel
 import torch
 
 from .configuration_mcqbert import MCQBertConfig
 
-class MCQBert(BertPreTrainedModel):
+class MCQBert(BertModel):
     def __init__(self, config: MCQBertConfig):
         super().__init__(config)
-        self.student_embedding_layer = torch.nn.Linear(config.student_embedding_size, config.hidden_size)
+
+        if config.integration_strategy is not None:
+            self.student_embedding_layer = torch.nn.Linear(config.student_embedding_size, config.hidden_size)
 
         cls_input_dim_multiplier = 2 if config.integration_strategy == "cat" else 1
         cls_input_dim = self.config.hidden_size * cls_input_dim_multiplier
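
The commit only touches __init__; no forward pass appears in the diff. Purely as an illustration of how the pieces set up above could fit together, the sketch below assumes the projected student embedding is concatenated with the [CLS] representation when integration_strategy == "cat" (which is why cls_input_dim doubles) and fused additively otherwise, followed by a classifier head of width cls_input_dim. The class name MCQBertSketch, the forward signature, and the classifier attribute are assumptions, not part of the commit.

import torch
from transformers import BertModel

class MCQBertSketch(BertModel):
    # Illustrative subclass only: it mirrors the __init__ from the diff and
    # adds a hypothetical classifier head plus a hypothetical forward pass.
    def __init__(self, config):
        super().__init__(config)
        if config.integration_strategy is not None:
            self.student_embedding_layer = torch.nn.Linear(
                config.student_embedding_size, config.hidden_size
            )
        cls_input_dim_multiplier = 2 if config.integration_strategy == "cat" else 1
        cls_input_dim = self.config.hidden_size * cls_input_dim_multiplier
        # Assumed head; the commit does not show what consumes cls_input_dim.
        self.classifier = torch.nn.Linear(cls_input_dim, config.num_labels)

    def forward(self, input_ids, attention_mask=None, student_embeddings=None):
        outputs = super().forward(input_ids=input_ids, attention_mask=attention_mask)
        cls_repr = outputs.last_hidden_state[:, 0]  # [CLS] token representation
        if self.config.integration_strategy == "cat":
            # Concatenation doubles the classifier input width (the "* 2" above).
            cls_repr = torch.cat([cls_repr, self.student_embedding_layer(student_embeddings)], dim=-1)
        elif self.config.integration_strategy is not None:
            # Any other non-null strategy is assumed here to fuse additively.
            cls_repr = cls_repr + self.student_embedding_layer(student_embeddings)
        return self.classifier(cls_repr)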