Update chatNT.py
chatNT.py
CHANGED
@@ -405,9 +405,7 @@ class TorchBioBrainDecoder(nn.Module):
         """
 
         # Compute English token embeddings
-        print("(debug) in biobraindecoder, english tokens ids : ", english_token_ids.shape)
         tokens_embeddings = self.gpt_model.token_embed(english_token_ids)
-        print("(debug) tokens_embeddings shape : ", tokens_embeddings.shape)
 
         if projected_bio_embeddings is not None:
             (
@@ -419,8 +417,10 @@ class TorchBioBrainDecoder(nn.Module):
 
             # Insert the bio embeddings at the SEQ token positions
             processed_tokens_ids = english_token_ids.clone()
-            print("(debug)
-            print("(debug)
+            print("(debug) Before call tokens embeddings shape : ", tokens_embeddings.shape)
+            print("(debug) Before call Processed tokens ids shape : ", processed_tokens_ids.shape)
+            print("(debug) Before call Projected bio embeddings shape : ", projected_bio_embeddings.shape)
+            print("num bio sequences : ", num_bio_sequences)
             for bio_seq_num in range(num_bio_sequences):
                 tokens_embeddings, processed_tokens_ids = self.insert_embeddings(
                     processed_tokens_ids,
@@ -431,7 +431,6 @@ class TorchBioBrainDecoder(nn.Module):
             print("After call : ", tokens_embeddings.shape)
 
         # Regular GPT pass through
-        print("(debug) tokens embeddings shape : ", tokens_embeddings.shape)
         embeddings = self.gpt_model.apply_transformer_layers(tokens_embeddings)
         embeddings = self.gpt_model.final_norm(embeddings)
 
@@ -472,6 +471,11 @@ class TorchBioBrainDecoder(nn.Module):
             - input_embeddings with resampled_embeddings inserted at the SEQ token
             - tokens with the SEQ token set to -1
         """
+        print("Insert_embeddings input shape : ")
+        print("Tokens : ", tokens.shape)
+        print("Input embeddings : ", input_embeddings.shape)
+        print("Resampled embeddings : ", resampled_embeddings.shape)
+        print("Bio seq num : ", bio_seq_num)
 
         def _insert(
             tokens_1d: torch.Tensor,
@@ -485,6 +489,7 @@ class TorchBioBrainDecoder(nn.Module):
                 resampled_embeddings (torch.Tensor):
                     Shape (bio_sequence_length, embed_dim,)
             """
+            print("_insert input : ", input_embeddings_1d.shape, resampled_embeddings_1d.shape)
             indices = torch.where(tokens_1d == self.seq_token_id)[0]
             if indices.numel() > 0:
                 idx = indices[0].item()
@@ -501,6 +506,7 @@ class TorchBioBrainDecoder(nn.Module):
                     :-1, :
                 ]
                 tokens_1d[idx] = -1
+                print("_insert output : ", x.shape)
                 return x, tokens_1d
             else:
                 return (
@@ -519,8 +525,11 @@ class TorchBioBrainDecoder(nn.Module):
             )
             tokens_acc.append(tokens_out)
             embeddings_acc.append(embeddings_out)
+
+        print("(Embeddings_acc[0] shape : ", embeddings_acc[0].shape)
         tokens_acc = torch.stack(tokens_acc)
         embeddings_acc = torch.stack(embeddings_acc)
+        print("Embeddings acc shape : ", embeddings_acc.shape)
 
         return embeddings_acc, tokens_acc
 
@@ -701,13 +710,11 @@ class TorchMultiOmicsModel(PreTrainedModel):
 
         if projected_bio_embeddings is None:
             # Compute bio sequences embeddings
-            print("(debug) shape bio tokens ids : ", bio_token_ids.shape)
             bio_embeddings_list = [
                 self.biobrain_encoder(bio_token_ids=bio_token_ids[:, bio_seq_num])
                 for bio_seq_num in range(num_bio_sequences)
             ]
 
-            print("(debug) shape of embeddings : ", bio_embeddings_list[0].shape)
 
             # Project these embeddings
             projected_bio_embeddings = [
@@ -718,14 +725,9 @@ class TorchMultiOmicsModel(PreTrainedModel):
                 )
                 for bio_seq_num, bio_embeddings in enumerate(bio_embeddings_list)
             ]
-            print("(debug) Shape output projection model : ", projected_bio_embeddings[0].shape)
            projected_bio_embeddings = torch.stack(projected_bio_embeddings, dim=1)
-            print("(debug) Shape projected bio embeddings : ", projected_bio_embeddings.shape)
 
         # decode
-        print("(debug) Going in biobrain decoder : ")
-        print("(debug) English token ids : ", english_token_ids.shape)
-        print("(debug) Projected bio embeddings : ", projected_bio_embeddings.shape)
         logits = self.biobrain_decoder(
             english_token_ids=english_token_ids,
             projected_bio_embeddings=projected_bio_embeddings,
@@ -899,7 +901,6 @@ class TorchGptGroupedQueryAttention(nn.Module):
         value_inputs: torch.Tensor,
         attention_mask: torch.Tensor = None,
     ) -> torch.Tensor:
-        print("(debug) Query input shape : ", query_inputs.shape)
         batch_size, seq_len, _ = query_inputs.shape
 
         queries = self.query_linear(query_inputs).view( # noqa
@@ -981,7 +982,6 @@ class TorchGptDecoder(nn.Module):
         if attention_mask is None:
             attention_mask = build_causal_attention_mask(1, embeddings.shape[1])
         for layer in self.layers:
-            print("Embedding shape in apply_transformer_layers : ", embeddings.shape)
             embeddings = layer(embeddings, attention_mask)
 
         return embeddings
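For reference, the insert_embeddings helper touched above replaces a SEQ placeholder token with a block of projected bio embeddings and marks the consumed token with -1, as its docstring in the diff states. The sketch below illustrates that idea on a single sequence; it is not the code from this commit, and the function name insert_at_seq_token, the explicit seq_token_id argument, and the truncation back to the original length are assumptions made for the example.

```python
import torch


def insert_at_seq_token(
    tokens_1d: torch.Tensor,                # (seq_len,) token ids
    input_embeddings_1d: torch.Tensor,      # (seq_len, embed_dim)
    resampled_embeddings_1d: torch.Tensor,  # (bio_seq_len, embed_dim)
    seq_token_id: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Illustrative sketch: splice bio embeddings in at the first SEQ token."""
    indices = torch.where(tokens_1d == seq_token_id)[0]
    if indices.numel() == 0:
        # No SEQ placeholder left: return the inputs unchanged.
        return input_embeddings_1d, tokens_1d
    idx = indices[0].item()
    # Replace the SEQ token embedding with the bio embedding block, then
    # truncate back to the original length (assumed rule; the real _insert
    # in chatNT.py defines its own trimming).
    x = torch.cat(
        [
            input_embeddings_1d[:idx, :],
            resampled_embeddings_1d,
            input_embeddings_1d[idx + 1 :, :],
        ],
        dim=0,
    )[: input_embeddings_1d.shape[0], :]
    # Mark the consumed placeholder so it is not matched again.
    tokens_1d = tokens_1d.clone()
    tokens_1d[idx] = -1
    return x, tokens_1d
```

In chatNT.py itself this logic runs per batch element inside insert_embeddings, which also receives a bio_seq_num argument (visible in the added prints) that the sketch above ignores.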
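The last hunk keeps the call build_causal_attention_mask(1, embeddings.shape[1]) as context. The diff does not show that function's body, so the snippet below is only a plausible sketch of a causal mask builder with that (batch_size, seq_len) signature; the exact shape and dtype conventions used in chatNT.py are assumptions.

```python
import torch


def build_causal_attention_mask(batch_size: int, seq_len: int) -> torch.Tensor:
    # Lower-triangular boolean mask: position i may attend to positions <= i.
    mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    # Broadcastable over attention heads: (batch, 1, seq_len, seq_len).
    return mask.unsqueeze(0).unsqueeze(0).expand(batch_size, 1, seq_len, seq_len)
```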