Update main.py
main.py CHANGED
@@ -105,9 +105,9 @@ class UnifiedModel(nn.Module):
                 input_ids=input_id,
                 attention_mask=attn_mask
             )
-            hidden_states.append(outputs.
+            hidden_states.append(outputs.logits[:, -1, :])  # get the last hidden state
 
-        concatenated_hidden_states = torch.cat(hidden_states, dim=1)
+        concatenated_hidden_states = torch.cat(hidden_states, dim=1)
         logits = self.classifier(concatenated_hidden_states)
         return logits
 
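For context, the changed lines live inside `UnifiedModel.forward`, which the commit does not show in full. The sketch below is a hypothetical, minimal reconstruction of that flow under stated assumptions: each backbone's last-position logits are collected, concatenated along `dim=1`, and passed to a linear classifier. The `DummyBackbone` class, the sizes, and the reuse of a single `input_ids`/`attention_mask` pair for every backbone are assumptions; only the append / `torch.cat` / classifier pattern comes from the diff.

```python
# Hypothetical sketch of the forward pass implied by the change; names and sizes
# other than hidden_states / torch.cat(..., dim=1) / self.classifier are assumed.
from types import SimpleNamespace

import torch
import torch.nn as nn


class DummyBackbone(nn.Module):
    """Stand-in for a Hugging Face-style model whose output exposes `.logits`."""

    def __init__(self, vocab_size=100, hidden=32):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden)
        self.proj = nn.Linear(hidden, vocab_size)

    def forward(self, input_ids, attention_mask=None):
        logits = self.proj(self.embed(input_ids))   # (batch, seq_len, vocab_size)
        return SimpleNamespace(logits=logits)       # mimic a ModelOutput with .logits


class UnifiedModelSketch(nn.Module):
    def __init__(self, backbones, vocab_size=100, num_labels=2):
        super().__init__()
        self.backbones = nn.ModuleList(backbones)
        # Each backbone contributes one last-position vector of length vocab_size.
        self.classifier = nn.Linear(vocab_size * len(backbones), num_labels)

    def forward(self, input_ids, attention_mask):
        hidden_states = []
        for backbone in self.backbones:
            outputs = backbone(input_ids=input_ids, attention_mask=attention_mask)
            hidden_states.append(outputs.logits[:, -1, :])  # last position: (batch, vocab_size)
        concatenated_hidden_states = torch.cat(hidden_states, dim=1)
        logits = self.classifier(concatenated_hidden_states)
        return logits


model = UnifiedModelSketch([DummyBackbone(), DummyBackbone()])
input_ids = torch.randint(0, 100, (4, 16))
attention_mask = torch.ones_like(input_ids)
print(model(input_ids, attention_mask).shape)  # torch.Size([4, 2])
```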
@@ -390,7 +390,8 @@ async def get_home():
 def push_to_redis(models, tokenizers, redis_client, model_name, tokenizer_name):
     for model_name, model in models.items():
         torch.save(model.state_dict(), model_name)
-
+        with open(model_name, "rb") as f:
+            redis_client.set(f"model:{model_name}", f.read())
 
     for tokenizer_name, tokenizer in tokenizers.items():
         tokens = tokenizer.get_vocab()
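The added lines write each model's `state_dict` to a local file with `torch.save` and then push the raw bytes into Redis under a `model:<name>` key. The commit does not show the read path; the sketch below is a hypothetical counterpart using the same key scheme with a standard redis-py client. The `pull_from_redis` helper name and the `io.BytesIO`/`torch.load` round trip are assumptions, not part of the commit.

```python
# Hypothetical read-side counterpart to push_to_redis; only the model:<name>
# key scheme comes from the diff, everything else here is assumed.
import io

import torch


def pull_from_redis(model, redis_client, model_name):
    """Fetch the serialized weights stored under model:<model_name> and load them."""
    raw = redis_client.get(f"model:{model_name}")
    if raw is None:
        raise KeyError(f"no weights stored under model:{model_name}")
    state_dict = torch.load(io.BytesIO(raw), map_location="cpu")
    model.load_state_dict(state_dict)
    return model


# Example wiring (assumes a reachable Redis instance and the existing helpers):
#   import redis
#   client = redis.Redis(host="localhost", port=6379, db=0)
#   push_to_redis(models, tokenizers, client, model_name, tokenizer_name)
#   restored = pull_from_redis(my_model, client, "some-model-name")
```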