Yanisadel commited on
Commit
8f1087e
·
1 Parent(s): 80a78d5

Update chatNT.py

Browse files
Files changed (1) hide show
  1. chatNT.py +10 -0
chatNT.py CHANGED
@@ -405,7 +405,9 @@ class TorchBioBrainDecoder(nn.Module):
405
  """
406
 
407
  # Compute English token embeddings
 
408
  tokens_embeddings = self.gpt_model.token_embed(english_token_ids)
 
409
 
410
  if projected_bio_embeddings is not None:
411
  (
@@ -696,11 +698,14 @@ class TorchMultiOmicsModel(PreTrainedModel):
696
 
697
  if projected_bio_embeddings is None:
698
  # Compute bio sequences embeddings
 
699
  bio_embeddings_list = [
700
  self.biobrain_encoder(bio_token_ids=bio_token_ids[:, bio_seq_num])
701
  for bio_seq_num in range(num_bio_sequences)
702
  ]
703
 
 
 
704
  # Project these embeddings
705
  projected_bio_embeddings = [
706
  self.projection_model(
@@ -710,9 +715,14 @@ class TorchMultiOmicsModel(PreTrainedModel):
710
  )
711
  for bio_seq_num, bio_embeddings in enumerate(bio_embeddings_list)
712
  ]
 
713
  projected_bio_embeddings = torch.stack(projected_bio_embeddings, dim=1)
 
714
 
715
  # decode
 
 
 
716
  logits = self.biobrain_decoder(
717
  english_token_ids=english_token_ids,
718
  projected_bio_embeddings=projected_bio_embeddings,
 
405
  """
406
 
407
  # Compute English token embeddings
408
+ print("(debug) in biobraindecoder, english tokens ids : ", english_token_ids.shape)
409
  tokens_embeddings = self.gpt_model.token_embed(english_token_ids)
410
+ print("(debug) tokens_embeddings shape : ", tokens_embeddings.shape)
411
 
412
  if projected_bio_embeddings is not None:
413
  (
 
698
 
699
  if projected_bio_embeddings is None:
700
  # Compute bio sequences embeddings
701
+ print("(debug) shape bio tokens ids : ", bio_tokens_ids.shape)
702
  bio_embeddings_list = [
703
  self.biobrain_encoder(bio_token_ids=bio_token_ids[:, bio_seq_num])
704
  for bio_seq_num in range(num_bio_sequences)
705
  ]
706
 
707
+ print("(debug) shape of embeddings : ", bio_embeddings_list[0].shape)
708
+
709
  # Project these embeddings
710
  projected_bio_embeddings = [
711
  self.projection_model(
 
715
  )
716
  for bio_seq_num, bio_embeddings in enumerate(bio_embeddings_list)
717
  ]
718
+ print("(debug) Shape output projection model : ", projected_bio_embeddings[0].shape)
719
  projected_bio_embeddings = torch.stack(projected_bio_embeddings, dim=1)
720
+ print("(debug) Shape projected bio embeddings : "), projected_bio_embeddings.shape)
721
 
722
  # decode
723
+ print("(debug) Going in biobrain decoder : ")
724
+ print("(debug) English token ids : ", english_token_ids.shape)
725
+ print("(debug) Projected bio embeddings : ", projected_bio_embeddings.shape)
726
  logits = self.biobrain_decoder(
727
  english_token_ids=english_token_ids,
728
  projected_bio_embeddings=projected_bio_embeddings,