lhallee committed
Commit 7376a9a · verified · Parent(s): 4f5815f

Upload modeling_fastesm.py with huggingface_hub

Files changed (1): modeling_fastesm.py (+7 -3)
modeling_fastesm.py CHANGED
@@ -603,6 +603,7 @@ class EmbeddingMixin:
         tokenizer: PreTrainedTokenizerBase,
         batch_size: int = 2,
         max_len: int = 512,
+        truncate: bool = True,
         full_embeddings: bool = False,
         embed_dtype: torch.dtype = torch.float32,
         pooling_types: List[str] = ['mean'],
@@ -654,8 +655,9 @@ class EmbeddingMixin:
         )
         >>> # embedding_dict is a dictionary mapping sequences to their embeddings as tensors for .pth or numpy arrays for sql
         """
-        sequences = list(set([seq[:max_len] for seq in sequences]))
+        sequences = list(set([seq[:max_len] if truncate else seq for seq in sequences]))
         sequences = sorted(sequences, key=len, reverse=True)
+        hidden_size = self.config.hidden_size
         collate_fn = build_collator(tokenizer)
         device = self.device
         pooler = Pooler(pooling_types) if not full_embeddings else None
@@ -686,7 +688,7 @@ class EmbeddingMixin:
                     embeddings = get_embeddings(residue_embeddings, attention_mask).cpu()
                     for seq, emb, mask in zip(seqs, embeddings, attention_mask):
                         if full_embeddings:
-                            emb = emb[mask.bool()]
+                            emb = emb[mask.bool()].reshape(-1, hidden_size)
                         c.execute("INSERT OR REPLACE INTO embeddings VALUES (?, ?)",
                                   (seq, emb.cpu().numpy().tobytes()))
 
@@ -716,7 +718,9 @@ class EmbeddingMixin:
                 input_ids, attention_mask = batch['input_ids'].to(device), batch['attention_mask'].to(device)
                 residue_embeddings = self._embed(input_ids, attention_mask)
                 embeddings = get_embeddings(residue_embeddings, attention_mask).to(embed_dtype).cpu()
-                for seq, emb in zip(seqs, embeddings):
+                for seq, emb, mask in zip(seqs, embeddings, attention_mask):
+                    if full_embeddings:
+                        emb = emb[mask.bool()].reshape(-1, hidden_size)
                     embeddings_dict[seq] = emb
 
         if save:
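The first two hunks make truncation configurable rather than unconditional: the new `truncate` flag (default `True`) preserves the old clip-to-`max_len` behavior, while `truncate=False` keeps full-length sequences. A minimal sketch of just that preprocessing step; `prepare_sequences` is a hypothetical standalone helper for illustration, not a function from this repo:

def prepare_sequences(sequences, max_len=512, truncate=True):
    # Optionally clip each sequence to max_len, de-duplicate via set(),
    # then sort longest-first (as the method does), which keeps the
    # sequences in each consecutive batch at similar lengths.
    sequences = list(set([seq[:max_len] if truncate else seq for seq in sequences]))
    return sorted(sequences, key=len, reverse=True)

seqs = ["MKTAYIAKQR", "MKTAYIAKQR", "MVLSPADKTNVKAAW"]
print(prepare_sequences(seqs, max_len=12))
# ['MVLSPADKTNVK', 'MKTAYIAKQR']  -- duplicate dropped, long sequence clipped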
 
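The remaining hunks tighten the `full_embeddings` path: both the SQL and in-memory branches now mask out padding positions and reshape to an explicit [num_tokens, hidden_size] matrix (previously the in-memory branch stored padded embeddings unmasked). A toy, self-contained illustration of the masking pattern, with made-up shapes standing in for a real model:

import torch

hidden_size = 4                                 # stand-in for self.config.hidden_size
emb = torch.randn(8, hidden_size)               # one sequence padded to length 8
mask = torch.tensor([1, 1, 1, 1, 1, 1, 0, 0])   # attention mask: 1 = real token

# Boolean indexing keeps only the rows where the mask is True; the
# reshape pins the result to 2-D [num_tokens, hidden_size].
trimmed = emb[mask.bool()].reshape(-1, hidden_size)
print(trimmed.shape)   # torch.Size([6, 4])

On a [seq_len, hidden_size] input the reshape is effectively a no-op, but making the stored layout explicit matters most for the SQL branch: the bytes written by `tobytes()` carry no shape information, so a reader reconstructing the array must already know `hidden_size`.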