Update main.py
Browse files
main.py
CHANGED
@@ -96,7 +96,8 @@ class UnifiedModel(nn.Module):
|
|
96 |
super(UnifiedModel, self).__init__()
|
97 |
self.models = nn.ModuleList(models)
|
98 |
hidden_size = self.models[0].config.hidden_size
|
99 |
-
self.
|
|
|
100 |
|
101 |
def forward(self, input_ids, attention_mask):
|
102 |
hidden_states = []
|
@@ -108,7 +109,8 @@ class UnifiedModel(nn.Module):
|
|
108 |
hidden_states.append(outputs.logits) # Usar directamente outputs.logits
|
109 |
|
110 |
concatenated_hidden_states = torch.cat(hidden_states, dim=1)
|
111 |
-
|
|
|
112 |
return logits
|
113 |
|
114 |
@staticmethod
|
|
|
96 |
super(UnifiedModel, self).__init__()
|
97 |
self.models = nn.ModuleList(models)
|
98 |
hidden_size = self.models[0].config.hidden_size
|
99 |
+
self.projection = nn.Linear(len(models) * 3, hidden_size) # Nueva capa lineal para proyecci贸n
|
100 |
+
self.classifier = nn.Linear(hidden_size, 3) # 3 clases
|
101 |
|
102 |
def forward(self, input_ids, attention_mask):
|
103 |
hidden_states = []
|
|
|
109 |
hidden_states.append(outputs.logits) # Usar directamente outputs.logits
|
110 |
|
111 |
concatenated_hidden_states = torch.cat(hidden_states, dim=1)
|
112 |
+
projected_features = self.projection(concatenated_hidden_states) # Proyectar a hidden_size
|
113 |
+
logits = self.classifier(projected_features)
|
114 |
return logits
|
115 |
|
116 |
@staticmethod
|