Yanisadel commited on
Commit
3f5d1dc
·
1 Parent(s): ab017bf

Update chatNT.py

Browse files
Files changed (1) hide show
  1. chatNT.py +3 -7
chatNT.py CHANGED
@@ -720,8 +720,11 @@ class TorchMultiOmicsModel(PreTrainedModel):
720
  projected_bio_embeddings.append(proj)
721
  for key in output.keys():
722
  outs[f"{key}_{bio_seq_num}"] = output[key]
 
 
723
 
724
  projected_bio_embeddings = torch.stack(projected_bio_embeddings, dim=1)
 
725
 
726
  # decode
727
  logits = self.biobrain_decoder(
@@ -730,13 +733,6 @@ class TorchMultiOmicsModel(PreTrainedModel):
730
  )
731
 
732
  outs["logits"] = logits
733
- outs["projected_bio_embeddings"] = projected_bio_embeddings
734
-
735
- # Just for debugging
736
- print("(debug) remember to remove bio_embeddings storage")
737
- if projected_bio_embeddings is not None:
738
- for i, embed in enumerate(bio_embeddings_list):
739
- outs[f"bio_embeddings_list_{i}"] = embed
740
 
741
  return outs
742
 
 
720
  projected_bio_embeddings.append(proj)
721
  for key in output.keys():
722
  outs[f"{key}_{bio_seq_num}"] = output[key]
723
+ outs[f"bio_embeddings_list_{bio_seq_num}"] = proj
724
+
725
 
726
  projected_bio_embeddings = torch.stack(projected_bio_embeddings, dim=1)
727
+ outs["projected_bio_embeddings"] = projected_bio_embeddings.clone()
728
 
729
  # decode
730
  logits = self.biobrain_decoder(
 
733
  )
734
 
735
  outs["logits"] = logits
 
 
 
 
 
 
 
736
 
737
  return outs
738