tommymarto commited on
Commit
773beeb
·
verified ·
1 Parent(s): 35a5cdb

Update modeling_mcqbert.py

Browse files
Files changed (1) hide show
  1. modeling_mcqbert.py +3 -7
modeling_mcqbert.py CHANGED
@@ -24,23 +24,19 @@ class MCQStudentBert(BertModel):
24
  def forward(self, input_ids, student_embeddings=None):
25
  if self.config.integration_strategy is None:
26
  # don't consider embeddings is no integration strategy (MCQBert)
27
- student_embeddings = torch.zeros(self.config.student_embedding_layer)
28
-
29
- input_embeddings = self.embeddings(input_ids)
30
- combined_embeddings = input_embeddings + self.student_embedding_layer(student_embeddings).unsqueeze(1).repeat(1, input_embeddings.size(1), 1)
31
- output = super().forward(inputs_embeds = combined_embeddings)
32
  return self.classifier(output.last_hidden_state[:, 0, :])
33
 
34
  elif self.config.integration_strategy == "cat":
35
  # MCQStudentBertCat
36
  output = super().forward(input_ids)
37
- output_with_student_embedding = torch.cat((output.last_hidden_state[:, 0, :], self.student_embedding_layer(student_embeddings)), dim = 1)
38
  return self.classifier(output_with_student_embedding)
39
 
40
  elif self.config.integration_strategy == "sum":
41
  # MCQStudentBertSum
42
  input_embeddings = self.embeddings(input_ids)
43
- combined_embeddings = input_embeddings + self.student_embedding_layer(student_embeddings).unsqueeze(1).repeat(1, input_embeddings.size(1), 1)
44
  output = super().forward(inputs_embeds = combined_embeddings)
45
  return self.classifier(output.last_hidden_state[:, 0, :])
46
 
 
24
  def forward(self, input_ids, student_embeddings=None):
25
  if self.config.integration_strategy is None:
26
  # don't consider embeddings is no integration strategy (MCQBert)
27
+ output = super().forward(input_ids)
 
 
 
 
28
  return self.classifier(output.last_hidden_state[:, 0, :])
29
 
30
  elif self.config.integration_strategy == "cat":
31
  # MCQStudentBertCat
32
  output = super().forward(input_ids)
33
+ output_with_student_embedding = torch.cat((output.last_hidden_state[:, 0, :], self.student_embedding_layer(student_embeddings).unsqueeze(0)), dim = 1)
34
  return self.classifier(output_with_student_embedding)
35
 
36
  elif self.config.integration_strategy == "sum":
37
  # MCQStudentBertSum
38
  input_embeddings = self.embeddings(input_ids)
39
+ combined_embeddings = input_embeddings + self.student_embedding_layer(student_embeddings).repeat(1, input_embeddings.size(1), 1)
40
  output = super().forward(inputs_embeds = combined_embeddings)
41
  return self.classifier(output.last_hidden_state[:, 0, :])
42