Yjhhh commited on
Commit
f8425ea
verified
1 Parent(s): 2bb71c1

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +4 -2
main.py CHANGED
@@ -96,7 +96,8 @@ class UnifiedModel(nn.Module):
96
  super(UnifiedModel, self).__init__()
97
  self.models = nn.ModuleList(models)
98
  hidden_size = self.models[0].config.hidden_size
99
- self.classifier = nn.Linear(len(models) * hidden_size, 3) # 3 clases
 
100
 
101
  def forward(self, input_ids, attention_mask):
102
  hidden_states = []
@@ -108,7 +109,8 @@ class UnifiedModel(nn.Module):
108
  hidden_states.append(outputs.logits) # Usar directamente outputs.logits
109
 
110
  concatenated_hidden_states = torch.cat(hidden_states, dim=1)
111
- logits = self.classifier(concatenated_hidden_states)
 
112
  return logits
113
 
114
  @staticmethod
 
96
  super(UnifiedModel, self).__init__()
97
  self.models = nn.ModuleList(models)
98
  hidden_size = self.models[0].config.hidden_size
99
+ self.projection = nn.Linear(len(models) * 3, hidden_size) # Nueva capa lineal para proyecci贸n
100
+ self.classifier = nn.Linear(hidden_size, 3) # 3 clases
101
 
102
  def forward(self, input_ids, attention_mask):
103
  hidden_states = []
 
109
  hidden_states.append(outputs.logits) # Usar directamente outputs.logits
110
 
111
  concatenated_hidden_states = torch.cat(hidden_states, dim=1)
112
+ projected_features = self.projection(concatenated_hidden_states) # Proyectar a hidden_size
113
+ logits = self.classifier(projected_features)
114
  return logits
115
 
116
  @staticmethod