lhallee committed on
Commit
2e16ff9
·
verified ·
1 Parent(s): bcdd73d

Update modeling_fastesm.py

Browse files
Files changed (1) hide show
  1. modeling_fastesm.py +16 -16
modeling_fastesm.py CHANGED
@@ -442,22 +442,22 @@ class FastEsmPreTrainedModel(PreTrainedModel):
442
  to_embed = [seq for seq in sequences if seq not in already_embedded]
443
  print(f"Found {len(already_embedded)} already embedded sequences in {sql_db_path}")
444
  print(f"Embedding {len(to_embed)} new sequences")
445
-
446
- with torch.no_grad():
447
- for i, batch in tqdm(enumerate(dataloader), total=len(dataloader), desc='Embedding batches'):
448
- seqs = sequences[i * batch_size:(i + 1) * batch_size]
449
- input_ids, attention_mask = batch['input_ids'].to(device), batch['attention_mask'].to(device)
450
- residue_embeddings = self.forward(input_ids, attention_mask, output_hidden_states=True).hidden_states[-1].detach().float() # required for sql
451
- embeddings = get_embeddings(residue_embeddings, attention_mask).cpu()
452
-
453
- for seq, emb in zip(seqs, embeddings):
454
- c.execute("INSERT OR REPLACE INTO embeddings VALUES (?, ?)",
455
- (seq, emb.cpu().numpy().tobytes()))
456
-
457
- if (i + 1) % 100 == 0:
458
- conn.commit()
459
-
460
- conn.commit()
461
  conn.close()
462
  return None
463
 
 
442
  to_embed = [seq for seq in sequences if seq not in already_embedded]
443
  print(f"Found {len(already_embedded)} already embedded sequences in {sql_db_path}")
444
  print(f"Embedding {len(to_embed)} new sequences")
445
+ if len(to_embed) > 0:
446
+ with torch.no_grad():
447
+ for i, batch in tqdm(enumerate(dataloader), total=len(dataloader), desc='Embedding batches'):
448
+ seqs = sequences[i * batch_size:(i + 1) * batch_size]
449
+ input_ids, attention_mask = batch['input_ids'].to(device), batch['attention_mask'].to(device)
450
+ residue_embeddings = self.forward(input_ids, attention_mask, output_hidden_states=True).hidden_states[-1].detach().float() # required for sql
451
+ embeddings = get_embeddings(residue_embeddings, attention_mask).cpu()
452
+
453
+ for seq, emb in zip(seqs, embeddings):
454
+ c.execute("INSERT OR REPLACE INTO embeddings VALUES (?, ?)",
455
+ (seq, emb.cpu().numpy().tobytes()))
456
+
457
+ if (i + 1) % 100 == 0:
458
+ conn.commit()
459
+
460
+ conn.commit()
461
  conn.close()
462
  return None
463