Yanisadel commited on
Commit
13dcc57
·
1 Parent(s): 0c3bab2

Update chatNT.py

Browse files
Files changed (1) hide show
  1. chatNT.py +7 -4
chatNT.py CHANGED
@@ -1763,8 +1763,8 @@ class TorchMultiModalPerceiverResampler(nn.Module):
1763
  concat_input_1 = torch.cat([xf_1, x], dim=1)
1764
  concat_input_2 = torch.cat([xf_2, x], dim=1)
1765
 
1766
- outs[f"concat_input_1_{layer_idx}"] = concat_input_1.clone()
1767
- outs[f"concat_input_2_{layer_idx}"] = concat_input_2.clone()
1768
 
1769
  output = layer(
1770
  x=x,
@@ -1774,7 +1774,7 @@ class TorchMultiModalPerceiverResampler(nn.Module):
1774
  attention_mask_2=attention_mask_2,
1775
  )
1776
  x = output["embeddings"]
1777
- outs[f"attention_embeddings_{layer_idx}"] = output["embeddings"].clone()
1778
 
1779
  return x, outs
1780
 
@@ -1789,6 +1789,9 @@ class TorchMultiModalPerceiverResampler(nn.Module):
1789
  Computes the embeddings based on the input tokens.
1790
  """
1791
  new_outs = {}
 
 
 
1792
  assert (
1793
  input_embeddings_1.shape[-1] == self.config.embed_dim
1794
  ), "The input embedding dim should match the model embed dim"
@@ -1884,7 +1887,7 @@ class TorchMultiModalPerceiverResamplerProjection(nn.Module):
1884
  projected_embeddings = projected_embeddings["embeddings"]
1885
 
1886
  for key in new_outs.keys():
1887
- outs[f"{key}_perceiver"] = new_outs[key]
1888
 
1889
  return projected_embeddings, outs
1890
 
 
1763
  concat_input_1 = torch.cat([xf_1, x], dim=1)
1764
  concat_input_2 = torch.cat([xf_2, x], dim=1)
1765
 
1766
+ #outs[f"concat_input_1_{layer_idx}"] = concat_input_1.clone()
1767
+ #outs[f"concat_input_2_{layer_idx}"] = concat_input_2.clone()
1768
 
1769
  output = layer(
1770
  x=x,
 
1774
  attention_mask_2=attention_mask_2,
1775
  )
1776
  x = output["embeddings"]
1777
+ #outs[f"attention_embeddings_{layer_idx}"] = output["embeddings"].clone()
1778
 
1779
  return x, outs
1780
 
 
1789
  Computes the embeddings based on the input tokens.
1790
  """
1791
  new_outs = {}
1792
+ new_outs["input_embeddings_1"] = input_embeddings_1.clone()
1793
+ new_outs["input_embeddings_2"] = input_embeddings_2.clone()
1794
+
1795
  assert (
1796
  input_embeddings_1.shape[-1] == self.config.embed_dim
1797
  ), "The input embedding dim should match the model embed dim"
 
1887
  projected_embeddings = projected_embeddings["embeddings"]
1888
 
1889
  for key in new_outs.keys():
1890
+ outs[f"PERCEIVER_{key}"] = new_outs[key]
1891
 
1892
  return projected_embeddings, outs
1893