Yanisadel committed on
Commit 60edd4e · 1 Parent(s): 93a5c8c

Update chatNT.py

Files changed (1)
  1. chatNT.py +13 -2
chatNT.py CHANGED
@@ -1763,6 +1763,9 @@ class TorchMultiModalPerceiverResampler(nn.Module):
             concat_input_1 = torch.cat([xf_1, x], dim=1)
             concat_input_2 = torch.cat([xf_2, x], dim=1)
 
+            outs[f"concat_input_1_{layer_idx}"] = concat_input_1.clone()
+            outs[f"concat_input_2_{layer_idx}"] = concat_input_2.clone()
+
             output = layer(
                 x=x,
                 cross_attention_embeddings_1=concat_input_1,
@@ -1771,6 +1774,7 @@ class TorchMultiModalPerceiverResampler(nn.Module):
                 attention_mask_2=attention_mask_2,
             )
             x = output["embeddings"]
+            outs[f"attention_embeddings_{layer_idx}"] = output["embeddings"].clone()
 
         return x, outs
 
@@ -1784,6 +1788,7 @@ class TorchMultiModalPerceiverResampler(nn.Module):
         """
         Computes the embeddings based on the input tokens.
         """
+        outs = {}
         assert (
             input_embeddings_1.shape[-1] == self.config.embed_dim
         ), "The input embedding dim should match the model embed dim"
@@ -1798,6 +1803,8 @@ class TorchMultiModalPerceiverResampler(nn.Module):
         outs: Dict[str, torch.Tensor] = {}
         x = latent_queries
 
+        outs["latent_queries"] = x.clone()
+
         x, outs = self.apply_attention_blocks(
             x=x,
             xf_1=input_embeddings_1,
@@ -1865,13 +1872,17 @@ class TorchMultiModalPerceiverResamplerProjection(nn.Module):
             english_token_ids, self.config.resampled_length, self.english_pad_token_id
         )
 
-        projected_embeddings = self.perceiver_resampler(
+        projected_embeddings, new_outs = self.perceiver_resampler(
            input_embeddings_1=projected_bio_embeddings,
            attention_mask_1=bio_attention_mask,
            input_embeddings_2=english_embeddings,
            attention_mask_2=english_attention_mask,
-        )["embeddings"]
+        )
+        projected_embeddings = projected_embeddings["embeddings"]
 
+        for key in new_outs.keys():
+            outs[f"{key}_perceiver"] = new_outs[key]
+
         return projected_embeddings, outs
 
 
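For readers following the change: the commit threads an `outs` dictionary through `TorchMultiModalPerceiverResampler` so that intermediate tensors (the concatenated cross-attention inputs, each layer's output embeddings, and the latent queries) are cloned and returned alongside the resampled embeddings, and the projection module then copies them into its own `outs` under a `_perceiver` suffix. Below is a minimal sketch of that capture-and-return pattern, not the actual ChatNT code; `ToyResampler`, its dimensions, and the tensor shapes are illustrative placeholders.

from typing import Dict, Tuple

import torch
import torch.nn as nn


class ToyResampler(nn.Module):
    # Illustrative stand-in for the perceiver resampler (placeholder name).
    def __init__(self, hidden_dim: int = 16, num_layers: int = 2):
        super().__init__()
        self.layers = nn.ModuleList(
            [nn.Linear(hidden_dim, hidden_dim) for _ in range(num_layers)]
        )

    def forward(
        self, x: torch.Tensor
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        outs: Dict[str, torch.Tensor] = {}
        # Save the initial latent queries, as the commit does.
        outs["latent_queries"] = x.clone()
        for layer_idx, layer in enumerate(self.layers):
            x = layer(x)
            # Clone so later operations cannot mutate the stored activation.
            outs[f"attention_embeddings_{layer_idx}"] = x.clone()
        return x, outs


# The caller unpacks both return values and re-keys the intermediates,
# mirroring the "_perceiver" suffix loop in the diff above.
embeddings, new_outs = ToyResampler()(torch.randn(1, 4, 16))
collected = {f"{key}_perceiver": value for key, value in new_outs.items()}
print(sorted(collected))

This is why the call site in `TorchMultiModalPerceiverResamplerProjection` changes from indexing `["embeddings"]` on the result directly to unpacking a tuple first and merging `new_outs` into `outs` before returning.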