Yjhhh commited on
Commit
212546b
verified
1 Parent(s): ca53d81

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +4 -3
main.py CHANGED
@@ -105,9 +105,9 @@ class UnifiedModel(nn.Module):
105
  input_ids=input_id,
106
  attention_mask=attn_mask
107
  )
108
- hidden_states.append(outputs.pooler_output) # Usar pooler_output para obtener un vector de 768
109
 
110
- concatenated_hidden_states = torch.cat(hidden_states, dim=1) # Concatenar en la dimensi贸n correcta
111
  logits = self.classifier(concatenated_hidden_states)
112
  return logits
113
 
@@ -390,7 +390,8 @@ async def get_home():
390
  def push_to_redis(models, tokenizers, redis_client, model_name, tokenizer_name):
391
  for model_name, model in models.items():
392
  torch.save(model.state_dict(), model_name)
393
- redis_client.set(f"model:{model_name}", open(model_name, "rb").read())
 
394
 
395
  for tokenizer_name, tokenizer in tokenizers.items():
396
  tokens = tokenizer.get_vocab()
 
105
  input_ids=input_id,
106
  attention_mask=attn_mask
107
  )
108
+ hidden_states.append(outputs.logits[:, -1, :]) # Obtener el 煤ltimo hidden state
109
 
110
+ concatenated_hidden_states = torch.cat(hidden_states, dim=1)
111
  logits = self.classifier(concatenated_hidden_states)
112
  return logits
113
 
 
390
  def push_to_redis(models, tokenizers, redis_client, model_name, tokenizer_name):
391
  for model_name, model in models.items():
392
  torch.save(model.state_dict(), model_name)
393
+ with open(model_name, "rb") as f:
394
+ redis_client.set(f"model:{model_name}", f.read())
395
 
396
  for tokenizer_name, tokenizer in tokenizers.items():
397
  tokens = tokenizer.get_vocab()