tommymarto
commited on
Update modeling_mcqbert.py
Browse files- modeling_mcqbert.py +3 -7
modeling_mcqbert.py
CHANGED
@@ -24,23 +24,19 @@ class MCQStudentBert(BertModel):
|
|
24 |
def forward(self, input_ids, student_embeddings=None):
|
25 |
if self.config.integration_strategy is None:
|
26 |
# don't consider embeddings is no integration strategy (MCQBert)
|
27 |
-
|
28 |
-
|
29 |
-
input_embeddings = self.embeddings(input_ids)
|
30 |
-
combined_embeddings = input_embeddings + self.student_embedding_layer(student_embeddings).unsqueeze(1).repeat(1, input_embeddings.size(1), 1)
|
31 |
-
output = super().forward(inputs_embeds = combined_embeddings)
|
32 |
return self.classifier(output.last_hidden_state[:, 0, :])
|
33 |
|
34 |
elif self.config.integration_strategy == "cat":
|
35 |
# MCQStudentBertCat
|
36 |
output = super().forward(input_ids)
|
37 |
-
output_with_student_embedding = torch.cat((output.last_hidden_state[:, 0, :], self.student_embedding_layer(student_embeddings)), dim = 1)
|
38 |
return self.classifier(output_with_student_embedding)
|
39 |
|
40 |
elif self.config.integration_strategy == "sum":
|
41 |
# MCQStudentBertSum
|
42 |
input_embeddings = self.embeddings(input_ids)
|
43 |
-
combined_embeddings = input_embeddings + self.student_embedding_layer(student_embeddings).
|
44 |
output = super().forward(inputs_embeds = combined_embeddings)
|
45 |
return self.classifier(output.last_hidden_state[:, 0, :])
|
46 |
|
|
|
24 |
def forward(self, input_ids, student_embeddings=None):
|
25 |
if self.config.integration_strategy is None:
|
26 |
# don't consider embeddings is no integration strategy (MCQBert)
|
27 |
+
output = super().forward(input_ids)
|
|
|
|
|
|
|
|
|
28 |
return self.classifier(output.last_hidden_state[:, 0, :])
|
29 |
|
30 |
elif self.config.integration_strategy == "cat":
|
31 |
# MCQStudentBertCat
|
32 |
output = super().forward(input_ids)
|
33 |
+
output_with_student_embedding = torch.cat((output.last_hidden_state[:, 0, :], self.student_embedding_layer(student_embeddings).unsqueeze(0)), dim = 1)
|
34 |
return self.classifier(output_with_student_embedding)
|
35 |
|
36 |
elif self.config.integration_strategy == "sum":
|
37 |
# MCQStudentBertSum
|
38 |
input_embeddings = self.embeddings(input_ids)
|
39 |
+
combined_embeddings = input_embeddings + self.student_embedding_layer(student_embeddings).repeat(1, input_embeddings.size(1), 1)
|
40 |
output = super().forward(inputs_embeds = combined_embeddings)
|
41 |
return self.classifier(output.last_hidden_state[:, 0, :])
|
42 |
|