Yjhhh commited on
Commit
8a6d0ed
·
verified ·
1 Parent(s): d9c2299

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +3 -2
main.py CHANGED
@@ -31,7 +31,8 @@ class UnifiedModel(nn.Module):
31
  def __init__(self, models):
32
  super(UnifiedModel, self).__init__()
33
  self.models = nn.ModuleList(models)
34
- self.classifier = nn.Linear(sum([model.config.hidden_size for model in models]), 2)
 
35
 
36
  def forward(self, input_ids, attention_mask):
37
  hidden_states = []
@@ -40,7 +41,7 @@ class UnifiedModel(nn.Module):
40
  input_ids=input_id,
41
  attention_mask=attn_mask
42
  )
43
- hidden_states.append(outputs.logits[:, 0])
44
  concatenated_hidden_states = torch.cat(hidden_states, dim=-1)
45
  logits = self.classifier(concatenated_hidden_states)
46
  return logits
 
31
  def __init__(self, models):
32
  super(UnifiedModel, self).__init__()
33
  self.models = nn.ModuleList(models)
34
+ hidden_size = self.models[0].config.hidden_size
35
+ self.classifier = nn.Linear(len(models) * hidden_size, 2)
36
 
37
  def forward(self, input_ids, attention_mask):
38
  hidden_states = []
 
41
  input_ids=input_id,
42
  attention_mask=attn_mask
43
  )
44
+ hidden_states.append(outputs.logits[:, -1, :])
45
  concatenated_hidden_states = torch.cat(hidden_states, dim=-1)
46
  logits = self.classifier(concatenated_hidden_states)
47
  return logits