Yjhhh commited on
Commit
ca53d81
verified
1 Parent(s): 0b404fc

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +5 -3
main.py CHANGED
@@ -105,8 +105,9 @@ class UnifiedModel(nn.Module):
105
  input_ids=input_id,
106
  attention_mask=attn_mask
107
  )
108
- hidden_states.append(outputs.logits)
109
- concatenated_hidden_states = torch.cat(hidden_states, dim=-1)
 
110
  logits = self.classifier(concatenated_hidden_states)
111
  return logits
112
 
@@ -119,7 +120,8 @@ class UnifiedModel(nn.Module):
119
  model.load_state_dict(torch.load(model_data_bytes))
120
  else:
121
  model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=3) # 3 clases
122
- return UnifiedModel([model])
 
123
 
124
  # Dataset para entrenamiento
125
  class SyntheticDataset(Dataset):
 
105
  input_ids=input_id,
106
  attention_mask=attn_mask
107
  )
108
+ hidden_states.append(outputs.pooler_output) # Usar pooler_output para obtener un vector de 768
109
+
110
+ concatenated_hidden_states = torch.cat(hidden_states, dim=1) # Concatenar en la dimensi贸n correcta
111
  logits = self.classifier(concatenated_hidden_states)
112
  return logits
113
 
 
120
  model.load_state_dict(torch.load(model_data_bytes))
121
  else:
122
  model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=3) # 3 clases
123
+
124
+ return UnifiedModel([model, model]) # Asegurar que se usa una lista de modelos, en este caso 2
125
 
126
  # Dataset para entrenamiento
127
  class SyntheticDataset(Dataset):