Update modeling_fastesm.py
Browse files- modeling_fastesm.py +4 -4
modeling_fastesm.py
CHANGED
@@ -447,8 +447,8 @@ class FastEsmPreTrainedModel(PreTrainedModel):
|
|
447 |
for i, batch in tqdm(enumerate(dataloader), total=len(dataloader), desc='Embedding batches'):
|
448 |
seqs = sequences[i * batch_size:(i + 1) * batch_size]
|
449 |
input_ids, attention_mask = batch['input_ids'].to(device), batch['attention_mask'].to(device)
|
450 |
-
residue_embeddings = self.forward(input_ids, attention_mask, output_hidden_states=True).hidden_states[-1].float() # required for sql
|
451 |
-
embeddings = get_embeddings(residue_embeddings, attention_mask)
|
452 |
|
453 |
for seq, emb in zip(seqs, embeddings):
|
454 |
c.execute("INSERT OR REPLACE INTO embeddings VALUES (?, ?)",
|
@@ -466,10 +466,10 @@ class FastEsmPreTrainedModel(PreTrainedModel):
|
|
466 |
for i, batch in tqdm(enumerate(dataloader), total=len(dataloader), desc='Embedding batches'):
|
467 |
seqs = sequences[i * batch_size:(i + 1) * batch_size]
|
468 |
input_ids, attention_mask = batch['input_ids'].to(device), batch['attention_mask'].to(device)
|
469 |
-
residue_embeddings = self.forward(input_ids, attention_mask, output_hidden_states=True).hidden_states[-1].float()
|
470 |
if full_precision:
|
471 |
residue_embeddings = residue_embeddings.float()
|
472 |
-
embeddings = get_embeddings(residue_embeddings, attention_mask)
|
473 |
for seq, emb in zip(seqs, embeddings):
|
474 |
embeddings_dict[seq] = emb
|
475 |
|
|
|
447 |
for i, batch in tqdm(enumerate(dataloader), total=len(dataloader), desc='Embedding batches'):
|
448 |
seqs = sequences[i * batch_size:(i + 1) * batch_size]
|
449 |
input_ids, attention_mask = batch['input_ids'].to(device), batch['attention_mask'].to(device)
|
450 |
+
residue_embeddings = self.forward(input_ids, attention_mask, output_hidden_states=True).hidden_states[-1].detach().float() # required for sql
|
451 |
+
embeddings = get_embeddings(residue_embeddings, attention_mask).cpu()
|
452 |
|
453 |
for seq, emb in zip(seqs, embeddings):
|
454 |
c.execute("INSERT OR REPLACE INTO embeddings VALUES (?, ?)",
|
|
|
466 |
for i, batch in tqdm(enumerate(dataloader), total=len(dataloader), desc='Embedding batches'):
|
467 |
seqs = sequences[i * batch_size:(i + 1) * batch_size]
|
468 |
input_ids, attention_mask = batch['input_ids'].to(device), batch['attention_mask'].to(device)
|
469 |
+
residue_embeddings = self.forward(input_ids, attention_mask, output_hidden_states=True).hidden_states[-1].detach().float()
|
470 |
if full_precision:
|
471 |
residue_embeddings = residue_embeddings.float()
|
472 |
+
embeddings = get_embeddings(residue_embeddings, attention_mask).cpu()
|
473 |
for seq, emb in zip(seqs, embeddings):
|
474 |
embeddings_dict[seq] = emb
|
475 |
|