Update modeling_fastesm.py
modeling_fastesm.py +16 -16
```diff
@@ -442,22 +442,22 @@ class FastEsmPreTrainedModel(PreTrainedModel):
         to_embed = [seq for seq in sequences if seq not in already_embedded]
         print(f"Found {len(already_embedded)} already embedded sequences in {sql_db_path}")
         print(f"Embedding {len(to_embed)} new sequences")
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        if len(to_embed) > 0:
+            with torch.no_grad():
+                for i, batch in tqdm(enumerate(dataloader), total=len(dataloader), desc='Embedding batches'):
+                    seqs = sequences[i * batch_size:(i + 1) * batch_size]
+                    input_ids, attention_mask = batch['input_ids'].to(device), batch['attention_mask'].to(device)
+                    residue_embeddings = self.forward(input_ids, attention_mask, output_hidden_states=True).hidden_states[-1].detach().float()  # required for sql
+                    embeddings = get_embeddings(residue_embeddings, attention_mask).cpu()
+
+                    for seq, emb in zip(seqs, embeddings):
+                        c.execute("INSERT OR REPLACE INTO embeddings VALUES (?, ?)",
+                                  (seq, emb.cpu().numpy().tobytes()))
+
+                    if (i + 1) % 100 == 0:
+                        conn.commit()
+
+            conn.commit()
         conn.close()
         return None
```
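The hunk calls a `get_embeddings(residue_embeddings, attention_mask)` helper that is defined elsewhere in modeling_fastesm.py and not visible in this diff. A common choice for collapsing per-residue hidden states into one vector per sequence is attention-masked mean pooling; the sketch below assumes that and is not necessarily the repo's actual implementation.

```python
import torch

def get_embeddings(residue_embeddings: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    # Hypothetical masked mean pool over the residue dimension.
    # residue_embeddings: (batch, seq_len, hidden); attention_mask: (batch, seq_len)
    mask = attention_mask.unsqueeze(-1).to(residue_embeddings.dtype)  # (batch, seq_len, 1)
    summed = (residue_embeddings * mask).sum(dim=1)                   # (batch, hidden)
    counts = mask.sum(dim=1).clamp(min=1)                             # guard against all-padding rows
    return summed / counts
```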
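The loop caches one float32 blob per sequence via `INSERT OR REPLACE INTO embeddings VALUES (?, ?)`. A minimal sketch of reading the cache back, assuming only the two-column schema implied by that statement (sequence text, embedding bytes) and the float32 dtype forced by the `.float()` cast:

```python
import sqlite3
import numpy as np

def load_cached_embeddings(sql_db_path: str) -> dict[str, np.ndarray]:
    # Iterate over the two-column embeddings table written by the loop above.
    conn = sqlite3.connect(sql_db_path)
    embeddings = {}
    for seq, blob in conn.execute("SELECT * FROM embeddings"):
        # Blobs were written with tensor.numpy().tobytes() on float32 tensors.
        embeddings[seq] = np.frombuffer(blob, dtype=np.float32)
    conn.close()
    return embeddings
```

Note that `np.frombuffer` returns a read-only view over the blob; call `.copy()` on the result if the arrays need to be writable.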