lhallee commited on
Commit
9f3d5a4
·
verified ·
1 Parent(s): 1a988ad

Upload modeling_fastesm.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_fastesm.py +2 -2
modeling_fastesm.py CHANGED
@@ -717,11 +717,11 @@ class EmbeddingMixin:
717
  seqs = to_embed[i * batch_size:(i + 1) * batch_size]
718
  input_ids, attention_mask = batch['input_ids'].to(device), batch['attention_mask'].to(device)
719
  residue_embeddings = self._embed(input_ids, attention_mask)
720
- embeddings = get_embeddings(residue_embeddings, attention_mask).to(embed_dtype).cpu()
721
  for seq, emb, mask in zip(seqs, embeddings, attention_mask):
722
  if full_embeddings:
723
  emb = emb[mask.bool()].reshape(-1, hidden_size)
724
- embeddings_dict[seq] = emb
725
 
726
  if save:
727
  torch.save(embeddings_dict, save_path)
 
717
  seqs = to_embed[i * batch_size:(i + 1) * batch_size]
718
  input_ids, attention_mask = batch['input_ids'].to(device), batch['attention_mask'].to(device)
719
  residue_embeddings = self._embed(input_ids, attention_mask)
720
+ embeddings = get_embeddings(residue_embeddings, attention_mask).to(embed_dtype)
721
  for seq, emb, mask in zip(seqs, embeddings, attention_mask):
722
  if full_embeddings:
723
  emb = emb[mask.bool()].reshape(-1, hidden_size)
724
+ embeddings_dict[seq] = emb.cpu()
725
 
726
  if save:
727
  torch.save(embeddings_dict, save_path)