Update main.py
Browse files
main.py
CHANGED
@@ -31,7 +31,8 @@ class UnifiedModel(nn.Module):
|
|
31 |
def __init__(self, models):
|
32 |
super(UnifiedModel, self).__init__()
|
33 |
self.models = nn.ModuleList(models)
|
34 |
-
|
|
|
35 |
|
36 |
def forward(self, input_ids, attention_mask):
|
37 |
hidden_states = []
|
@@ -40,7 +41,7 @@ class UnifiedModel(nn.Module):
|
|
40 |
input_ids=input_id,
|
41 |
attention_mask=attn_mask
|
42 |
)
|
43 |
-
hidden_states.append(outputs.logits[:,
|
44 |
concatenated_hidden_states = torch.cat(hidden_states, dim=-1)
|
45 |
logits = self.classifier(concatenated_hidden_states)
|
46 |
return logits
|
|
|
31 |
def __init__(self, models):
|
32 |
super(UnifiedModel, self).__init__()
|
33 |
self.models = nn.ModuleList(models)
|
34 |
+
hidden_size = self.models[0].config.hidden_size
|
35 |
+
self.classifier = nn.Linear(len(models) * hidden_size, 2)
|
36 |
|
37 |
def forward(self, input_ids, attention_mask):
|
38 |
hidden_states = []
|
|
|
41 |
input_ids=input_id,
|
42 |
attention_mask=attn_mask
|
43 |
)
|
44 |
+
hidden_states.append(outputs.logits[:, -1, :])
|
45 |
concatenated_hidden_states = torch.cat(hidden_states, dim=-1)
|
46 |
logits = self.classifier(concatenated_hidden_states)
|
47 |
return logits
|