Update chatNT.py
chatNT.py
CHANGED
@@ -720,8 +720,11 @@ class TorchMultiOmicsModel(PreTrainedModel):
             projected_bio_embeddings.append(proj)
             for key in output.keys():
                 outs[f"{key}_{bio_seq_num}"] = output[key]
+            outs[f"bio_embeddings_list_{bio_seq_num}"] = proj
+

         projected_bio_embeddings = torch.stack(projected_bio_embeddings, dim=1)
+        outs["projected_bio_embeddings"] = projected_bio_embeddings.clone()

         # decode
         logits = self.biobrain_decoder(
@@ -730,13 +733,6 @@ class TorchMultiOmicsModel(PreTrainedModel):
         )

         outs["logits"] = logits
-        outs["projected_bio_embeddings"] = projected_bio_embeddings
-
-        # Just for debugging
-        print("(debug) remember to remove bio_embeddings storage")
-        if projected_bio_embeddings is not None:
-            for i, embed in enumerate(bio_embeddings_list):
-                outs[f"bio_embeddings_list_{i}"] = embed

         return outs
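After this change, the forward pass stores each per-sequence projection under bio_embeddings_list_{bio_seq_num} as it is produced inside the loop, and the stacked tensor under projected_bio_embeddings, instead of re-exporting them at the end behind a debug print. A minimal, self-contained sketch of the resulting output layout is below; the shapes, the dummy outs dict, and the assumption that bio_seq_num runs over 0..num_bio_seqs-1 are illustrative and not taken from the repository.

import torch

# Illustrative shapes only (assumptions, not from the model code).
batch, num_bio_seqs, seq_len, dim = 1, 2, 8, 16

# Emulate what the forward pass now stores: one key per biological sequence
# written inside the loop, plus the stacked tensor written after the loop.
per_seq = [torch.randn(batch, seq_len, dim) for _ in range(num_bio_seqs)]
outs = {f"bio_embeddings_list_{i}": p for i, p in enumerate(per_seq)}
outs["projected_bio_embeddings"] = torch.stack(per_seq, dim=1).clone()

# A caller can recover either view and check that they agree.
stacked = torch.stack(
    [outs[f"bio_embeddings_list_{i}"] for i in range(num_bio_seqs)], dim=1
)
assert torch.allclose(stacked, outs["projected_bio_embeddings"])
print(outs["projected_bio_embeddings"].shape)  # torch.Size([1, 2, 8, 16])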