Update chatNT.py
Browse files
chatNT.py
CHANGED
@@ -405,7 +405,9 @@ class TorchBioBrainDecoder(nn.Module):
|
|
405 |
"""
|
406 |
|
407 |
# Compute English token embeddings
|
|
|
408 |
tokens_embeddings = self.gpt_model.token_embed(english_token_ids)
|
|
|
409 |
|
410 |
if projected_bio_embeddings is not None:
|
411 |
(
|
@@ -696,11 +698,14 @@ class TorchMultiOmicsModel(PreTrainedModel):
|
|
696 |
|
697 |
if projected_bio_embeddings is None:
|
698 |
# Compute bio sequences embeddings
|
|
|
699 |
bio_embeddings_list = [
|
700 |
self.biobrain_encoder(bio_token_ids=bio_token_ids[:, bio_seq_num])
|
701 |
for bio_seq_num in range(num_bio_sequences)
|
702 |
]
|
703 |
|
|
|
|
|
704 |
# Project these embeddings
|
705 |
projected_bio_embeddings = [
|
706 |
self.projection_model(
|
@@ -710,9 +715,14 @@ class TorchMultiOmicsModel(PreTrainedModel):
|
|
710 |
)
|
711 |
for bio_seq_num, bio_embeddings in enumerate(bio_embeddings_list)
|
712 |
]
|
|
|
713 |
projected_bio_embeddings = torch.stack(projected_bio_embeddings, dim=1)
|
|
|
714 |
|
715 |
# decode
|
|
|
|
|
|
|
716 |
logits = self.biobrain_decoder(
|
717 |
english_token_ids=english_token_ids,
|
718 |
projected_bio_embeddings=projected_bio_embeddings,
|
|
|
405 |
"""
|
406 |
|
407 |
# Compute English token embeddings
|
408 |
+
print("(debug) in biobraindecoder, english tokens ids : ", english_token_ids.shape)
|
409 |
tokens_embeddings = self.gpt_model.token_embed(english_token_ids)
|
410 |
+
print("(debug) tokens_embeddings shape : ", tokens_embeddings.shape)
|
411 |
|
412 |
if projected_bio_embeddings is not None:
|
413 |
(
|
|
|
698 |
|
699 |
if projected_bio_embeddings is None:
|
700 |
# Compute bio sequences embeddings
|
701 |
+
print("(debug) shape bio tokens ids : ", bio_tokens_ids.shape)
|
702 |
bio_embeddings_list = [
|
703 |
self.biobrain_encoder(bio_token_ids=bio_token_ids[:, bio_seq_num])
|
704 |
for bio_seq_num in range(num_bio_sequences)
|
705 |
]
|
706 |
|
707 |
+
print("(debug) shape of embeddings : ", bio_embeddings_list[0].shape)
|
708 |
+
|
709 |
# Project these embeddings
|
710 |
projected_bio_embeddings = [
|
711 |
self.projection_model(
|
|
|
715 |
)
|
716 |
for bio_seq_num, bio_embeddings in enumerate(bio_embeddings_list)
|
717 |
]
|
718 |
+
print("(debug) Shape output projection model : ", projected_bio_embeddings[0].shape)
|
719 |
projected_bio_embeddings = torch.stack(projected_bio_embeddings, dim=1)
|
720 |
+
print("(debug) Shape projected bio embeddings : "), projected_bio_embeddings.shape)
|
721 |
|
722 |
# decode
|
723 |
+
print("(debug) Going in biobrain decoder : ")
|
724 |
+
print("(debug) English token ids : ", english_token_ids.shape)
|
725 |
+
print("(debug) Projected bio embeddings : ", projected_bio_embeddings.shape)
|
726 |
logits = self.biobrain_decoder(
|
727 |
english_token_ids=english_token_ids,
|
728 |
projected_bio_embeddings=projected_bio_embeddings,
|