Yanisadel commited on
Commit
579bdca
·
1 Parent(s): c22355f

Update chatNT.py

Browse files
Files changed (1) hide show
  1. chatNT.py +6 -3
chatNT.py CHANGED
@@ -1788,7 +1788,7 @@ class TorchMultiModalPerceiverResampler(nn.Module):
1788
  """
1789
  Computes the embeddings based on the input tokens.
1790
  """
1791
- outs = {}
1792
  assert (
1793
  input_embeddings_1.shape[-1] == self.config.embed_dim
1794
  ), "The input embedding dim should match the model embed dim"
@@ -1803,7 +1803,7 @@ class TorchMultiModalPerceiverResampler(nn.Module):
1803
  outs: Dict[str, torch.Tensor] = {}
1804
  x = latent_queries
1805
 
1806
- outs["latent_queries"] = x.clone()
1807
 
1808
  x, outs = self.apply_attention_blocks(
1809
  x=x,
@@ -1814,9 +1814,12 @@ class TorchMultiModalPerceiverResampler(nn.Module):
1814
  attention_mask_2=attention_mask_2,
1815
  )
1816
 
 
 
 
1817
  outs["embeddings"] = x
1818
 
1819
- return outs
1820
 
1821
 
1822
  class TorchMultiModalPerceiverResamplerProjection(nn.Module):
 
1788
  """
1789
  Computes the embeddings based on the input tokens.
1790
  """
1791
+ new_outs = {}
1792
  assert (
1793
  input_embeddings_1.shape[-1] == self.config.embed_dim
1794
  ), "The input embedding dim should match the model embed dim"
 
1803
  outs: Dict[str, torch.Tensor] = {}
1804
  x = latent_queries
1805
 
1806
+ new_outs["latent_queries"] = x.clone()
1807
 
1808
  x, outs = self.apply_attention_blocks(
1809
  x=x,
 
1814
  attention_mask_2=attention_mask_2,
1815
  )
1816
 
1817
+ for key in outs.keys():
1818
+ new_outs[key] = outs[key].copy()
1819
+
1820
  outs["embeddings"] = x
1821
 
1822
+ return outs, new_outs
1823
 
1824
 
1825
  class TorchMultiModalPerceiverResamplerProjection(nn.Module):