lhallee committed
Commit bcdd73d · verified · 1 Parent(s): ce7dbf9

Update modeling_fastesm.py

Files changed (1):
modeling_fastesm.py +4 -4
modeling_fastesm.py CHANGED
@@ -447,8 +447,8 @@ class FastEsmPreTrainedModel(PreTrainedModel):
         for i, batch in tqdm(enumerate(dataloader), total=len(dataloader), desc='Embedding batches'):
             seqs = sequences[i * batch_size:(i + 1) * batch_size]
             input_ids, attention_mask = batch['input_ids'].to(device), batch['attention_mask'].to(device)
-            residue_embeddings = self.forward(input_ids, attention_mask, output_hidden_states=True).hidden_states[-1].float() # required for sql
-            embeddings = get_embeddings(residue_embeddings, attention_mask)
+            residue_embeddings = self.forward(input_ids, attention_mask, output_hidden_states=True).hidden_states[-1].detach().float() # required for sql
+            embeddings = get_embeddings(residue_embeddings, attention_mask).cpu()
 
             for seq, emb in zip(seqs, embeddings):
                 c.execute("INSERT OR REPLACE INTO embeddings VALUES (?, ?)",
@@ -466,10 +466,10 @@ class FastEsmPreTrainedModel(PreTrainedModel):
         for i, batch in tqdm(enumerate(dataloader), total=len(dataloader), desc='Embedding batches'):
             seqs = sequences[i * batch_size:(i + 1) * batch_size]
             input_ids, attention_mask = batch['input_ids'].to(device), batch['attention_mask'].to(device)
-            residue_embeddings = self.forward(input_ids, attention_mask, output_hidden_states=True).hidden_states[-1].float()
+            residue_embeddings = self.forward(input_ids, attention_mask, output_hidden_states=True).hidden_states[-1].detach().float()
             if full_precision:
                 residue_embeddings = residue_embeddings.float()
-            embeddings = get_embeddings(residue_embeddings, attention_mask)
+            embeddings = get_embeddings(residue_embeddings, attention_mask).cpu()
             for seq, emb in zip(seqs, embeddings):
                 embeddings_dict[seq] = emb
 
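
For context, a minimal sketch of what the two added calls buy. The get_embeddings body and the (sequence TEXT, embedding BLOB) table schema below are illustrative assumptions, not code from this diff; only the .detach().float() / .cpu() pattern mirrors the commit.

import sqlite3

import torch


def get_embeddings(residue_embeddings: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    # Assumed behavior: mean-pool per-residue embeddings over non-padding
    # positions, (B, L, D) x (B, L) -> (B, D).
    mask = attention_mask.unsqueeze(-1).to(residue_embeddings.dtype)
    return (residue_embeddings * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)


# Stand-in for hidden_states[-1]: half precision and attached to an autograd
# graph, as it would be after a mixed-precision forward pass.
hidden = torch.randn(2, 8, 16, dtype=torch.float16, requires_grad=True)
attention_mask = torch.ones(2, 8)

# .detach() drops the autograd graph so each cached embedding does not pin
# the forward activations behind it; .float() fixes the dtype before raw
# bytes are written (sqlite3 stores whatever bytes it is handed).
residue_embeddings = hidden.detach().float()
# .cpu() moves the pooled batch off the accelerator so the cache does not
# accumulate device memory (a no-op in this CPU-only sketch).
embeddings = get_embeddings(residue_embeddings, attention_mask).cpu()

conn = sqlite3.connect(":memory:")
c = conn.cursor()
c.execute("CREATE TABLE IF NOT EXISTS embeddings (sequence TEXT PRIMARY KEY, embedding BLOB)")
for seq, emb in zip(["SEQA", "SEQB"], embeddings):
    c.execute("INSERT OR REPLACE INTO embeddings VALUES (?, ?)", (seq, emb.numpy().tobytes()))
conn.commit()

Without .detach(), every row written in the first hunk's loop would keep its forward graph alive for the lifetime of the cache; without .cpu(), the embeddings_dict built in the second hunk would hold device tensors for the whole run.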