Update chatNT.py
chatNT.py
CHANGED
@@ -694,6 +694,7 @@ class TorchMultiOmicsModel(PreTrainedModel):
             vocab_size - 1
         )
 
+        outs = {}
         if bio_token_ids is None:
             projected_bio_embeddings = None
         else:
@@ -708,14 +709,18 @@ class TorchMultiOmicsModel(PreTrainedModel):
 
 
             # Project these embeddings
-            projected_bio_embeddings = [
-                self.projection_model(
+            projected_bio_embeddings = []
+            print("(debug) remember to remove loop for projected")
+            for bio_seq_num, bio_embeddings in enumerate(bio_embeddings_list):
+                proj, output = self.projection_model(
                     bio_token_ids=bio_token_ids[:, bio_seq_num],
                     bio_embeddings=bio_embeddings,
                     english_token_ids=projection_english_tokens_ids,
                 )
-                for bio_seq_num, bio_embeddings in enumerate(bio_embeddings_list)
-            ]
+                projected_bio_embeddings.append(proj)
+                for key in output.keys():
+                    outs[f"{key}_{bio_seq_num}"] = output[key]
+
             projected_bio_embeddings = torch.stack(projected_bio_embeddings, dim=1)
 
         # decode
@@ -724,7 +729,8 @@ class TorchMultiOmicsModel(PreTrainedModel):
             projected_bio_embeddings=projected_bio_embeddings,
         )
 
-        outs = {"logits": logits, "projected_bio_embeddings": projected_bio_embeddings}
+        outs["logits"] = logits
+        outs["projected_bio_embeddings"] = projected_bio_embeddings
 
         # Just for debugging
         print("(debug) remember to remove bio_embeddings storage")
@@ -1848,8 +1854,12 @@ class TorchMultiModalPerceiverResamplerProjection(nn.Module):
             english_token_ids (torch.Tensor):
                 Shape (batch_size, num_english_tokens)
         """
+        outs = {}
         projected_bio_embeddings = self.bio_projection(bio_embeddings)
+        print("(debug) remember to remove this projected_bio_embeddings out, and 'outs' output")
+        outs['projected_bio_embeddings'] = projected_bio_embeddings
         english_embeddings = self.token_embedding(english_token_ids)
+        outs['english_embeddings'] = english_embeddings
 
         bio_attention_mask = build_perceiver_padding_attention_mask(
             bio_token_ids, self.config.resampled_length, self.bio_pad_token_id
@@ -1865,7 +1875,7 @@ class TorchMultiModalPerceiverResamplerProjection(nn.Module):
             attention_mask_2=english_attention_mask,
         )["embeddings"]
 
-        return projected_embeddings
+        return projected_embeddings, outs
 
 
 def build_perceiver_padding_attention_mask(
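
For readers tracing the change: TorchMultiModalPerceiverResamplerProjection.forward now returns a (projected_embeddings, outs) tuple, and TorchMultiOmicsModel.forward flattens each per-sequence outs into one debug dictionary keyed by f"{key}_{bio_seq_num}", alongside "logits" and the stacked "projected_bio_embeddings". Below is a minimal sketch of that collection pattern, under stated assumptions: fake_projection and the tensor shapes are illustrative stand-ins, not part of chatNT.py.

import torch

# Illustrative stand-in (not chatNT.py) for self.projection_model, which now
# returns the projected embeddings plus a dict of intermediate tensors.
def fake_projection(bio_embeddings):
    projected = bio_embeddings.mean(dim=1, keepdim=True)  # placeholder "projection"
    outs = {
        "projected_bio_embeddings": projected,
        "english_embeddings": torch.zeros(bio_embeddings.shape[0], 4, 8),
    }
    return projected, outs

bio_embeddings_list = [torch.randn(2, 16, 8) for _ in range(3)]  # 3 bio sequences

outs = {}
projected_bio_embeddings = []
for bio_seq_num, bio_embeddings in enumerate(bio_embeddings_list):
    proj, output = fake_projection(bio_embeddings)
    projected_bio_embeddings.append(proj)
    for key in output.keys():
        # flatten per-sequence debug outputs into one dict, suffixed by sequence index
        outs[f"{key}_{bio_seq_num}"] = output[key]

projected_bio_embeddings = torch.stack(projected_bio_embeddings, dim=1)

print(projected_bio_embeddings.shape)  # torch.Size([2, 3, 1, 8])
print(sorted(outs.keys()))
# ['english_embeddings_0', 'english_embeddings_1', 'english_embeddings_2',
#  'projected_bio_embeddings_0', 'projected_bio_embeddings_1', 'projected_bio_embeddings_2']

Note that any other call site of the projection model must now unpack the returned tuple, and the added print statements plus the extra dictionary entries are flagged in the diff itself as temporary debugging aids to be removed.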
|