Yanisadel committed
Commit 80a78d5 · 1 Parent(s): 8d737ee

Update chatNT.py

Files changed (1): chatNT.py (+3 -0)
chatNT.py CHANGED
@@ -426,6 +426,7 @@ class TorchBioBrainDecoder(nn.Module):
         )
 
         # Regular GPT pass through
+        print("(debug) tokens embeddings shape : ", tokens_embeddings.shape)
         embeddings = self.gpt_model.apply_transformer_layers(tokens_embeddings)
         embeddings = self.gpt_model.final_norm(embeddings)
 
@@ -885,6 +886,7 @@ class TorchGptGroupedQueryAttention(nn.Module):
         value_inputs: torch.Tensor,
         attention_mask: torch.Tensor = None,
     ) -> torch.Tensor:
+        print("(debug) Query input shape : ", query_inputs.shape)
         batch_size, seq_len, _ = query_inputs.shape
 
         queries = self.query_linear(query_inputs).view(  # noqa
@@ -966,6 +968,7 @@ class TorchGptDecoder(nn.Module):
         if attention_mask is None:
             attention_mask = build_causal_attention_mask(1, embeddings.shape[1])
         for layer in self.layers:
+            print("Embedding shape in apply_transformer_layers : ", embeddings.shape)
             embeddings = layer(embeddings, attention_mask)
 
         return embeddings
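
The three added lines are plain shape-diagnostic prints placed just before the GPT layer stack, the grouped-query attention projection, and inside the per-layer loop. A minimal runnable sketch of the loop-level pattern, using a hypothetical stand-in decoder (TinyDecoder) rather than the actual TorchGptDecoder from chatNT.py:

import torch
from torch import nn

class TinyDecoder(nn.Module):
    """Hypothetical stand-in illustrating where the added debug print sits."""

    def __init__(self, embed_dim: int = 16, num_layers: int = 2):
        super().__init__()
        self.layers = nn.ModuleList(
            [nn.Linear(embed_dim, embed_dim) for _ in range(num_layers)]
        )

    def apply_transformer_layers(self, embeddings: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            # Same placement as the print added in this commit: one line per layer.
            print("Embedding shape in apply_transformer_layers : ", embeddings.shape)
            embeddings = layer(embeddings)
        return embeddings

decoder = TinyDecoder()
out = decoder.apply_transformer_layers(torch.zeros(1, 8, 16))
# Prints "torch.Size([1, 8, 16])" once per layer.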