Upload modeling_fastesm.py with huggingface_hub
modeling_fastesm.py  +8 -4
modeling_fastesm.py
CHANGED
@@ -556,10 +556,6 @@ class FastEsmPreTrainedModel(PreTrainedModel):
         Returns:
             Dictionary mapping sequences to embeddings, or None if sql=True
         """
-        sequences = list(set([seq[:max_len] for seq in sequences]))
-        sequences = sorted(sequences, key=len, reverse=True)
-        dataset = ProteinDataset(sequences)
-        dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, collate_fn=self._collate_fn, shuffle=False)
         device = self.device
 
         def get_embeddings(residue_embeddings: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
@@ -570,6 +566,7 @@ class FastEsmPreTrainedModel(PreTrainedModel):
             else:
                 return residue_embeddings[:, 0, :]
 
+        sequences = list(set([seq[:max_len] for seq in sequences]))
         if sql:
             import sqlite3
             conn = sqlite3.connect(sql_db_path)
@@ -580,6 +577,9 @@ class FastEsmPreTrainedModel(PreTrainedModel):
             print(f"Found {len(already_embedded)} already embedded sequences in {sql_db_path}")
             print(f"Embedding {len(to_embed)} new sequences")
             if len(to_embed) > 0:
+                to_embed = sorted(to_embed, key=len, reverse=True)
+                dataset = ProteinDataset(to_embed)
+                dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, collate_fn=self._collate_fn, shuffle=False)
                 with torch.no_grad():
                     for i, batch in tqdm(enumerate(dataloader), total=len(dataloader), desc='Embedding batches'):
                         seqs = sequences[i * batch_size:(i + 1) * batch_size]
@@ -598,6 +598,10 @@ class FastEsmPreTrainedModel(PreTrainedModel):
             conn.close()
             return None
 
+        sequences = list(set([seq[:max_len] for seq in sequences]))
+        sequences = sorted(sequences, key=len, reverse=True)
+        dataset = ProteinDataset(sequences)
+        dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, collate_fn=self._collate_fn, shuffle=False)
         embeddings_dict = {}
         with torch.no_grad():
             for i, batch in tqdm(enumerate(dataloader), total=len(dataloader), desc='Embedding batches'):
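For orientation, the lines this commit moves implement a common embedding-preparation pattern: truncate each sequence to max_len, deduplicate, sort longest-first so padding within a batch stays small, then batch with a DataLoader. The self-contained sketch below illustrates only that pattern; ToyDataset, toy_collate, and prepare_loader are illustrative stand-ins, not the ProteinDataset and _collate_fn helpers defined in modeling_fastesm.py.

# Self-contained sketch of the truncate -> deduplicate -> length-sort -> batch
# pattern used in this change. ToyDataset and toy_collate are illustrative
# stand-ins, not the helpers defined in modeling_fastesm.py.
from typing import List

from torch.utils.data import DataLoader, Dataset


class ToyDataset(Dataset):
    def __init__(self, sequences: List[str]):
        self.sequences = sequences

    def __len__(self) -> int:
        return len(self.sequences)

    def __getitem__(self, idx: int) -> str:
        return self.sequences[idx]


def toy_collate(batch: List[str]) -> List[str]:
    # The real _collate_fn tokenizes and pads; the stand-in just passes strings through.
    return batch


def prepare_loader(sequences: List[str], max_len: int = 512, batch_size: int = 2) -> DataLoader:
    sequences = list(set(seq[:max_len] for seq in sequences))  # truncate, then deduplicate
    sequences = sorted(sequences, key=len, reverse=True)       # longest first, so each batch pads less
    dataset = ToyDataset(sequences)
    return DataLoader(dataset, batch_size=batch_size, collate_fn=toy_collate, shuffle=False)


loader = prepare_loader(["MKTAYIAKQR", "MKV", "MKTAYIAKQR", "MSE"], max_len=8)
for batch in loader:
    print(batch)  # unique, truncated sequences, longest first

In the committed code, this sort-and-batch step now runs only on to_embed (the sequences not already stored in the SQLite database) in the sql=True path, and on the deduplicated sequences list in the in-memory path, rather than once up front on the full input list.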