lhallee committed
Commit 7ce555f · verified · 1 Parent(s): 7cff869

Upload modeling_fastesm.py with huggingface_hub

Files changed (1)
  1. modeling_fastesm.py +8 -4
modeling_fastesm.py CHANGED
@@ -556,10 +556,6 @@ class FastEsmPreTrainedModel(PreTrainedModel):
         Returns:
             Dictionary mapping sequences to embeddings, or None if sql=True
         """
-        sequences = list(set([seq[:max_len] for seq in sequences]))
-        sequences = sorted(sequences, key=len, reverse=True)
-        dataset = ProteinDataset(sequences)
-        dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, collate_fn=self._collate_fn, shuffle=False)
         device = self.device
 
         def get_embeddings(residue_embeddings: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
@@ -570,6 +566,7 @@ class FastEsmPreTrainedModel(PreTrainedModel):
             else:
                 return residue_embeddings[:, 0, :]
 
+        sequences = list(set([seq[:max_len] for seq in sequences]))
         if sql:
             import sqlite3
             conn = sqlite3.connect(sql_db_path)
@@ -580,6 +577,9 @@ class FastEsmPreTrainedModel(PreTrainedModel):
             print(f"Found {len(already_embedded)} already embedded sequences in {sql_db_path}")
             print(f"Embedding {len(to_embed)} new sequences")
             if len(to_embed) > 0:
+                to_embed = sorted(to_embed, key=len, reverse=True)
+                dataset = ProteinDataset(to_embed)
+                dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, collate_fn=self._collate_fn, shuffle=False)
                 with torch.no_grad():
                     for i, batch in tqdm(enumerate(dataloader), total=len(dataloader), desc='Embedding batches'):
                         seqs = sequences[i * batch_size:(i + 1) * batch_size]
@@ -598,6 +598,10 @@ class FastEsmPreTrainedModel(PreTrainedModel):
             conn.close()
             return None
 
+        sequences = list(set([seq[:max_len] for seq in sequences]))
+        sequences = sorted(sequences, key=len, reverse=True)
+        dataset = ProteinDataset(sequences)
+        dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, collate_fn=self._collate_fn, shuffle=False)
         embeddings_dict = {}
         with torch.no_grad():
             for i, batch in tqdm(enumerate(dataloader), total=len(dataloader), desc='Embedding batches'):
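The net effect of the patch: the dataset and DataLoader are no longer built over every input up front. Deduplication and truncation still happen once, but when sql=True only the sequences not already stored in the SQLite database (to_embed) are sorted, wrapped in a ProteinDataset, and batched, while the non-SQL path builds its own DataLoader just before filling the in-memory dictionary. Below is a minimal, self-contained sketch of that caching pattern; it is an illustration, not the model's actual code, and the table name embeddings, the helper embed_new_sequences_only, the caller-supplied embed_fn, and the float32-byte storage format are assumptions made for the example.

import sqlite3
import torch
from torch.utils.data import DataLoader

def embed_new_sequences_only(sequences, embed_fn, db_path="embeddings.db", batch_size=4, max_len=512):
    """Embed only the sequences missing from a SQLite cache (sketch of the commit's pattern)."""
    # Deduplicate and truncate once, as the patched method does.
    sequences = list(set(seq[:max_len] for seq in sequences))

    conn = sqlite3.connect(db_path)
    c = conn.cursor()
    c.execute("CREATE TABLE IF NOT EXISTS embeddings (sequence TEXT PRIMARY KEY, embedding BLOB)")
    already = {row[0] for row in c.execute("SELECT sequence FROM embeddings")}
    to_embed = [s for s in sequences if s not in already]
    print(f"Found {len(already)} cached sequences, embedding {len(to_embed)} new ones")

    if len(to_embed) > 0:
        # Batch only the uncached sequences, longest first so batch members are similar in length.
        to_embed = sorted(to_embed, key=len, reverse=True)
        loader = DataLoader(to_embed, batch_size=batch_size, shuffle=False)
        with torch.no_grad():
            for batch in loader:  # default collate yields a list of strings per batch
                embs = embed_fn(list(batch))  # expected shape: (len(batch), hidden_dim)
                for seq, emb in zip(batch, embs):
                    c.execute(
                        "INSERT OR REPLACE INTO embeddings VALUES (?, ?)",
                        (seq, emb.detach().cpu().float().numpy().tobytes()),
                    )
                conn.commit()
    conn.close()

# Usage with a dummy embedder; swap in the real model's pooled embeddings.
embed_new_sequences_only(["MKTAYIAK", "MKV", "MKTAYIAK"], embed_fn=lambda seqs: torch.randn(len(seqs), 8))

Sorting the pending sequences by length before batching keeps each batch's members similar in size, which reduces padding waste when the collate function pads to the longest sequence in the batch.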