Update main.py
Browse files
main.py
CHANGED
@@ -105,8 +105,9 @@ class UnifiedModel(nn.Module):
|
|
105 |
input_ids=input_id,
|
106 |
attention_mask=attn_mask
|
107 |
)
|
108 |
-
hidden_states.append(outputs.
|
109 |
-
|
|
|
110 |
logits = self.classifier(concatenated_hidden_states)
|
111 |
return logits
|
112 |
|
@@ -119,7 +120,8 @@ class UnifiedModel(nn.Module):
|
|
119 |
model.load_state_dict(torch.load(model_data_bytes))
|
120 |
else:
|
121 |
model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=3) # 3 clases
|
122 |
-
|
|
|
123 |
|
124 |
# Dataset para entrenamiento
|
125 |
class SyntheticDataset(Dataset):
|
|
|
105 |
input_ids=input_id,
|
106 |
attention_mask=attn_mask
|
107 |
)
|
108 |
+
hidden_states.append(outputs.pooler_output) # Usar pooler_output para obtener un vector de 768
|
109 |
+
|
110 |
+
concatenated_hidden_states = torch.cat(hidden_states, dim=1) # Concatenar en la dimensi贸n correcta
|
111 |
logits = self.classifier(concatenated_hidden_states)
|
112 |
return logits
|
113 |
|
|
|
120 |
model.load_state_dict(torch.load(model_data_bytes))
|
121 |
else:
|
122 |
model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=3) # 3 clases
|
123 |
+
|
124 |
+
return UnifiedModel([model, model]) # Asegurar que se usa una lista de modelos, en este caso 2
|
125 |
|
126 |
# Dataset para entrenamiento
|
127 |
class SyntheticDataset(Dataset):
|