Upload modeling_fastesm.py with huggingface_hub
modeling_fastesm.py  CHANGED  (+7 -3)
@@ -603,6 +603,7 @@ class EmbeddingMixin:
         tokenizer: PreTrainedTokenizerBase,
         batch_size: int = 2,
         max_len: int = 512,
+        truncate: bool = True,
         full_embeddings: bool = False,
         embed_dtype: torch.dtype = torch.float32,
         pooling_types: List[str] = ['mean'],
@@ -654,8 +655,9 @@ class EmbeddingMixin:
         )
         >>> # embedding_dict is a dictionary mapping sequences to their embeddings as tensors for .pth or numpy arrays for sql
         """
-        sequences = list(set([seq[:max_len] for seq in sequences]))
+        sequences = list(set([seq[:max_len] if truncate else seq for seq in sequences]))
         sequences = sorted(sequences, key=len, reverse=True)
+        hidden_size = self.config.hidden_size
         collate_fn = build_collator(tokenizer)
         device = self.device
         pooler = Pooler(pooling_types) if not full_embeddings else None
@@ -686,7 +688,7 @@ class EmbeddingMixin:
                     embeddings = get_embeddings(residue_embeddings, attention_mask).cpu()
                     for seq, emb, mask in zip(seqs, embeddings, attention_mask):
                         if full_embeddings:
-                            emb = emb[mask.bool()]
+                            emb = emb[mask.bool()].reshape(-1, hidden_size)
                         c.execute("INSERT OR REPLACE INTO embeddings VALUES (?, ?)",
                                   (seq, emb.cpu().numpy().tobytes()))
 
@@ -716,7 +718,9 @@ class EmbeddingMixin:
                 input_ids, attention_mask = batch['input_ids'].to(device), batch['attention_mask'].to(device)
                 residue_embeddings = self._embed(input_ids, attention_mask)
                 embeddings = get_embeddings(residue_embeddings, attention_mask).to(embed_dtype).cpu()
-                for seq, emb in zip(seqs, embeddings):
+                for seq, emb, mask in zip(seqs, embeddings, attention_mask):
+                    if full_embeddings:
+                        emb = emb[mask.bool()].reshape(-1, hidden_size)
                     embeddings_dict[seq] = emb
 
         if save:
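
The reshape also matters on the read side: the sql branch stores each embedding with emb.cpu().numpy().tobytes(), which flattens the (seq_len, hidden_size) matrix into raw bytes. Below is a hedged sketch of how the stored full per-residue embeddings could be recovered; the database filename, the float32 dtype, and the literal hidden size are assumptions for illustration, while the "embeddings" table and its two-column layout follow from the INSERT OR REPLACE statement in the diff.

import sqlite3
import numpy as np

HIDDEN_SIZE = 960  # assumed for illustration; the model itself uses self.config.hidden_size

con = sqlite3.connect("embeddings.db")  # assumed path
for seq, blob in con.execute("SELECT * FROM embeddings"):
    emb = np.frombuffer(blob, dtype=np.float32)  # dtype assumed
    # tobytes() discards the array shape, so reshape with hidden_size to
    # recover the (seq_len, hidden_size) per-residue matrix
    emb = emb.reshape(-1, HIDDEN_SIZE)
    print(seq[:20], emb.shape)
con.close()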