Yanisadel committed
Commit 2164e14 · 1 Parent(s): e98c2c7

Update chatNT.py

Files changed (1):
  1. chatNT.py +14 -14
chatNT.py CHANGED
@@ -405,9 +405,7 @@ class TorchBioBrainDecoder(nn.Module):
         """
 
         # Compute English token embeddings
-        print("(debug) in biobraindecoder, english tokens ids : ", english_token_ids.shape)
         tokens_embeddings = self.gpt_model.token_embed(english_token_ids)
-        print("(debug) tokens_embeddings shape : ", tokens_embeddings.shape)
 
         if projected_bio_embeddings is not None:
             (
@@ -419,8 +417,10 @@ class TorchBioBrainDecoder(nn.Module):
 
             # Insert the bio embeddings at the SEQ token positions
             processed_tokens_ids = english_token_ids.clone()
-            print("(debug) Inside : processed tokens ids shape : ", processed_tokens_ids.shape)
-            print("(debug) Inside : projected bio embeddings shape : ", projected_bio_embeddings.shape)
+            print("(debug) Before call tokens embeddings shape : ", tokens_embeddings.shape)
+            print("(debug) Before call Processed tokens ids shape : ", processed_tokens_ids.shape)
+            print("(debug) Before call Projected bio embeddings shape : ", projected_bio_embeddings.shape)
+            print("num bio sequences : ", num_bio_sequences)
             for bio_seq_num in range(num_bio_sequences):
                 tokens_embeddings, processed_tokens_ids = self.insert_embeddings(
                     processed_tokens_ids,
@@ -431,7 +431,6 @@ class TorchBioBrainDecoder(nn.Module):
             print("After call : ", tokens_embeddings.shape)
 
         # Regular GPT pass through
-        print("(debug) tokens embeddings shape : ", tokens_embeddings.shape)
         embeddings = self.gpt_model.apply_transformer_layers(tokens_embeddings)
         embeddings = self.gpt_model.final_norm(embeddings)
 
@@ -472,6 +471,11 @@ class TorchBioBrainDecoder(nn.Module):
             - input_embeddings with resampled_embeddings inserted at the SEQ token
             - tokens with the SEQ token set to -1
         """
+        print("Insert_embeddings input shape : ")
+        print("Tokens : ", tokens.shape)
+        print("Input embeddings : ", input_embeddings.shape)
+        print("Resampled embeddings : ", resampled_embeddings.shape)
+        print("Bio seq num : ", bio_seq_num)
 
         def _insert(
             tokens_1d: torch.Tensor,
@@ -485,6 +489,7 @@ class TorchBioBrainDecoder(nn.Module):
                 resampled_embeddings (torch.Tensor):
                     Shape (bio_sequence_length, embed_dim,)
             """
+            print("_insert input : ", input_embeddings_1d.shape, resampled_embeddings_1d.shape)
             indices = torch.where(tokens_1d == self.seq_token_id)[0]
             if indices.numel() > 0:
                 idx = indices[0].item()
@@ -501,6 +506,7 @@ class TorchBioBrainDecoder(nn.Module):
                     :-1, :
                 ]
                 tokens_1d[idx] = -1
+                print("_insert output : ", x.shape)
                 return x, tokens_1d
             else:
                 return (
@@ -519,8 +525,11 @@ class TorchBioBrainDecoder(nn.Module):
             )
             tokens_acc.append(tokens_out)
             embeddings_acc.append(embeddings_out)
+
+        print("(Embeddings_acc[0] shape : ", embeddings_acc[0].shape)
         tokens_acc = torch.stack(tokens_acc)
         embeddings_acc = torch.stack(embeddings_acc)
+        print("Embeddings acc shape : ", embeddings_acc.shape)
 
         return embeddings_acc, tokens_acc
 
@@ -701,13 +710,11 @@ class TorchMultiOmicsModel(PreTrainedModel):
 
         if projected_bio_embeddings is None:
             # Compute bio sequences embeddings
-            print("(debug) shape bio tokens ids : ", bio_token_ids.shape)
             bio_embeddings_list = [
                 self.biobrain_encoder(bio_token_ids=bio_token_ids[:, bio_seq_num])
                 for bio_seq_num in range(num_bio_sequences)
             ]
 
-            print("(debug) shape of embeddings : ", bio_embeddings_list[0].shape)
 
             # Project these embeddings
             projected_bio_embeddings = [
@@ -718,14 +725,9 @@ class TorchMultiOmicsModel(PreTrainedModel):
                 )
                 for bio_seq_num, bio_embeddings in enumerate(bio_embeddings_list)
             ]
-            print("(debug) Shape output projection model : ", projected_bio_embeddings[0].shape)
             projected_bio_embeddings = torch.stack(projected_bio_embeddings, dim=1)
-            print("(debug) Shape projected bio embeddings : ", projected_bio_embeddings.shape)
 
         # decode
-        print("(debug) Going in biobrain decoder : ")
-        print("(debug) English token ids : ", english_token_ids.shape)
-        print("(debug) Projected bio embeddings : ", projected_bio_embeddings.shape)
         logits = self.biobrain_decoder(
             english_token_ids=english_token_ids,
             projected_bio_embeddings=projected_bio_embeddings,
@@ -899,7 +901,6 @@ class TorchGptGroupedQueryAttention(nn.Module):
         value_inputs: torch.Tensor,
         attention_mask: torch.Tensor = None,
     ) -> torch.Tensor:
-        print("(debug) Query input shape : ", query_inputs.shape)
         batch_size, seq_len, _ = query_inputs.shape
 
         queries = self.query_linear(query_inputs).view(  # noqa
@@ -981,7 +982,6 @@ class TorchGptDecoder(nn.Module):
         if attention_mask is None:
             attention_mask = build_causal_attention_mask(1, embeddings.shape[1])
         for layer in self.layers:
-            print("Embedding shape in apply_transformer_layers : ", embeddings.shape)
            embeddings = layer(embeddings, attention_mask)
 
         return embeddings
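For context on the shapes these prints inspect: insert_embeddings locates the SEQ placeholder token in each 1-D sequence and splices the resampled bio embeddings in at that position. The following is a minimal, self-contained sketch reconstructed from the fragments visible in this diff, not the repository's exact code; the function name, the placeholder id, and the toy dimensions are illustrative assumptions.

import torch

def splice_bio_embeddings(
    tokens_1d: torch.Tensor,                # (num_tokens,)
    input_embeddings_1d: torch.Tensor,      # (num_tokens, embed_dim)
    resampled_embeddings_1d: torch.Tensor,  # (bio_len, embed_dim)
    seq_token_id: int,
):
    # Locate the first remaining SEQ placeholder, if any.
    indices = torch.where(tokens_1d == seq_token_id)[0]
    if indices.numel() == 0:
        # No placeholder left: return the inputs unchanged.
        return input_embeddings_1d, tokens_1d
    idx = indices[0].item()
    # Splice the bio embeddings in at idx, then drop the final row,
    # mirroring the "[:-1, :]" slice visible in the diff.
    x = torch.cat(
        [
            input_embeddings_1d[:idx, :],
            resampled_embeddings_1d,
            input_embeddings_1d[idx:, :],
        ],
        dim=0,
    )[:-1, :]
    tokens_1d = tokens_1d.clone()
    tokens_1d[idx] = -1  # mark this SEQ token as consumed
    return x, tokens_1d

# Toy check: 6 tokens with one placeholder at index 2, 3 bio rows spliced in.
seq_id = 99
tokens = torch.tensor([1, 2, seq_id, 3, 4, 5])
embeddings = torch.randn(6, 4)
bio = torch.randn(3, 4)
out_embeddings, out_tokens = splice_bio_embeddings(tokens, embeddings, bio, seq_id)
print(out_embeddings.shape)  # torch.Size([8, 4]): grows by bio_len - 1 per splice
print(out_tokens)            # tensor([ 1,  2, -1,  3,  4,  5])

Under this scheme the embedding sequence grows with every splice while the token-id tensor keeps its length, which is precisely the shape drift the relocated "Before call" / "After call" prints are watching.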
 
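Since the commit only relocates ad-hoc print statements, one natural cleanup would be to route these shape traces through Python's standard logging module so they can be toggled without editing chatNT.py again. A minimal sketch; the trace_shapes helper and the logger name are hypothetical, not part of the file:

import logging

import torch

logger = logging.getLogger("chatNT")

def trace_shapes(**tensors: torch.Tensor) -> None:
    # Emit each tensor's shape only when DEBUG logging is enabled.
    if logger.isEnabledFor(logging.DEBUG):
        for name, t in tensors.items():
            logger.debug("%s shape: %s", name, tuple(t.shape))

# Example mirroring the prints added before the insert_embeddings loop:
logging.basicConfig(level=logging.DEBUG)
trace_shapes(
    tokens_embeddings=torch.randn(2, 16, 32),
    processed_tokens_ids=torch.zeros(2, 16, dtype=torch.long),
)

Enabling or silencing every trace then becomes a single logging configuration call rather than another round of print edits.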