Yanisadel committed on
Commit ab017bf · 1 Parent(s): f4d5754

Update chatNT.py

Files changed (1):
  1. chatNT.py +16 -6
chatNT.py CHANGED
@@ -694,6 +694,7 @@ class TorchMultiOmicsModel(PreTrainedModel):
             vocab_size - 1
         )
 
+        outs = {}
         if bio_token_ids is None:
             projected_bio_embeddings = None
         else:
@@ -708,14 +709,18 @@ class TorchMultiOmicsModel(PreTrainedModel):
 
 
             # Project these embeddings
-            projected_bio_embeddings = [
-                self.projection_model(
+            projected_bio_embeddings = []
+            print("(debug) remember to remove loop for projected")
+            for bio_seq_num, bio_embeddings in enumerate(bio_embeddings_list):
+                proj, output = self.projection_model(
                     bio_token_ids=bio_token_ids[:, bio_seq_num],
                     bio_embeddings=bio_embeddings,
                     english_token_ids=projection_english_tokens_ids,
                 )
-                for bio_seq_num, bio_embeddings in enumerate(bio_embeddings_list)
-            ]
+                projected_bio_embeddings.append(proj)
+                for key in output.keys():
+                    outs[f"{key}_{bio_seq_num}"] = output[key]
+
             projected_bio_embeddings = torch.stack(projected_bio_embeddings, dim=1)
 
             # decode
@@ -724,7 +729,8 @@ class TorchMultiOmicsModel(PreTrainedModel):
             projected_bio_embeddings=projected_bio_embeddings,
         )
 
-        outs = {"logits": logits, "projected_bio_embeddings": projected_bio_embeddings}
+        outs["logits"] = logits
+        outs["projected_bio_embeddings"] = projected_bio_embeddings
 
         # Just for debugging
         print("(debug) remember to remove bio_embeddings storage")
@@ -1848,8 +1854,12 @@ class TorchMultiModalPerceiverResamplerProjection(nn.Module):
             english_token_ids (torch.Tensor):
                 Shape (batch_size, num_english_tokens)
         """
+        outs = {}
         projected_bio_embeddings = self.bio_projection(bio_embeddings)
+        print("(debug) remember to remove this projected_bio_embeddings out, and 'outs' output")
+        outs['projected_bio_embeddings'] = projected_bio_embeddings
         english_embeddings = self.token_embedding(english_token_ids)
+        outs['english_embeddings'] = english_embeddings
 
         bio_attention_mask = build_perceiver_padding_attention_mask(
             bio_token_ids, self.config.resampled_length, self.bio_pad_token_id
@@ -1865,7 +1875,7 @@ class TorchMultiModalPerceiverResamplerProjection(nn.Module):
             attention_mask_2=english_attention_mask,
         )["embeddings"]
 
-        return projected_embeddings
+        return projected_embeddings, outs
 
 
 def build_perceiver_padding_attention_mask(
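
For context on the change above: each call to self.projection_model now returns a (projected_embeddings, outs) tuple, and the caller suffixes every key of the per-sequence outs dict with the biological sequence index before merging it into the top-level outs. A minimal, self-contained sketch of that accumulation pattern (the tensors and dict contents below are made up for illustration, not the actual ChatNT modules) could look like this:

import torch

# Hypothetical per-sequence outputs, standing in for what each
# self.projection_model(...) call returns alongside the projection.
per_sequence_outputs = [
    {"projected_bio_embeddings": torch.zeros(2, 4), "english_embeddings": torch.zeros(2, 8)},
    {"projected_bio_embeddings": torch.ones(2, 4), "english_embeddings": torch.ones(2, 8)},
]

outs = {}
projected_bio_embeddings = []
for bio_seq_num, output in enumerate(per_sequence_outputs):
    # Keep the projection itself in a list so it can be stacked later.
    projected_bio_embeddings.append(output["projected_bio_embeddings"])
    # Store every intermediate tensor under a per-sequence key,
    # e.g. "english_embeddings_0", "english_embeddings_1", ...
    for key, value in output.items():
        outs[f"{key}_{bio_seq_num}"] = value

projected_bio_embeddings = torch.stack(projected_bio_embeddings, dim=1)
print(sorted(outs.keys()))
# ['english_embeddings_0', 'english_embeddings_1',
#  'projected_bio_embeddings_0', 'projected_bio_embeddings_1']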