Upload modeling_fastesm.py with huggingface_hub
Browse files- modeling_fastesm.py +2 -2
modeling_fastesm.py
CHANGED
@@ -717,11 +717,11 @@ class EmbeddingMixin:
|
|
717 |
seqs = to_embed[i * batch_size:(i + 1) * batch_size]
|
718 |
input_ids, attention_mask = batch['input_ids'].to(device), batch['attention_mask'].to(device)
|
719 |
residue_embeddings = self._embed(input_ids, attention_mask)
|
720 |
-
embeddings = get_embeddings(residue_embeddings, attention_mask).to(embed_dtype)
|
721 |
for seq, emb, mask in zip(seqs, embeddings, attention_mask):
|
722 |
if full_embeddings:
|
723 |
emb = emb[mask.bool()].reshape(-1, hidden_size)
|
724 |
-
embeddings_dict[seq] = emb
|
725 |
|
726 |
if save:
|
727 |
torch.save(embeddings_dict, save_path)
|
|
|
717 |
seqs = to_embed[i * batch_size:(i + 1) * batch_size]
|
718 |
input_ids, attention_mask = batch['input_ids'].to(device), batch['attention_mask'].to(device)
|
719 |
residue_embeddings = self._embed(input_ids, attention_mask)
|
720 |
+
embeddings = get_embeddings(residue_embeddings, attention_mask).to(embed_dtype)
|
721 |
for seq, emb, mask in zip(seqs, embeddings, attention_mask):
|
722 |
if full_embeddings:
|
723 |
emb = emb[mask.bool()].reshape(-1, hidden_size)
|
724 |
+
embeddings_dict[seq] = emb.cpu()
|
725 |
|
726 |
if save:
|
727 |
torch.save(embeddings_dict, save_path)
|