Update chatNT.py
chatNT.py CHANGED
@@ -1763,8 +1763,8 @@ class TorchMultiModalPerceiverResampler(nn.Module):
             concat_input_1 = torch.cat([xf_1, x], dim=1)
             concat_input_2 = torch.cat([xf_2, x], dim=1)

-            outs[f"concat_input_1_{layer_idx}"] = concat_input_1.clone()
-            outs[f"concat_input_2_{layer_idx}"] = concat_input_2.clone()
+            #outs[f"concat_input_1_{layer_idx}"] = concat_input_1.clone()
+            #outs[f"concat_input_2_{layer_idx}"] = concat_input_2.clone()

             output = layer(
                 x=x,
@@ -1774,7 +1774,7 @@ class TorchMultiModalPerceiverResampler(nn.Module):
                 attention_mask_2=attention_mask_2,
             )
             x = output["embeddings"]
-            outs[f"attention_embeddings_{layer_idx}"] = output["embeddings"].clone()
+            #outs[f"attention_embeddings_{layer_idx}"] = output["embeddings"].clone()

         return x, outs

@@ -1789,6 +1789,9 @@ class TorchMultiModalPerceiverResampler(nn.Module):
         Computes the embeddings based on the input tokens.
         """
         new_outs = {}
+        new_outs["input_embeddings_1"] = input_embeddings_1.clone()
+        new_outs["input_embeddings_2"] = input_embeddings_2.clone()
+
         assert (
             input_embeddings_1.shape[-1] == self.config.embed_dim
         ), "The input embedding dim should match the model embed dim"
@@ -1884,7 +1887,7 @@ class TorchMultiModalPerceiverResamplerProjection(nn.Module):
         projected_embeddings = projected_embeddings["embeddings"]

         for key in new_outs.keys():
-            outs[f"{key}"] = new_outs[key]
+            outs[f"PERCEIVER_{key}"] = new_outs[key]

         return projected_embeddings, outs

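For context, the change stops recording the per-layer debug tensors (`concat_input_1_{layer_idx}`, `concat_input_2_{layer_idx}`, `attention_embeddings_{layer_idx}`), records the two input embeddings instead, and namespaces whatever the resampler reports with a `PERCEIVER_` prefix when the projection module merges it into its own output dict. Below is a minimal sketch of that capture-and-prefix pattern; the module and key names are illustrative placeholders, not the actual ChatNT classes.

# Sketch (hypothetical names): a submodule records intermediate tensors in a
# dict, and the wrapping module merges them under a "PERCEIVER_" prefix so
# keys cannot collide with its own captures.
import torch
import torch.nn as nn


class Resampler(nn.Module):
    def __init__(self, dim: int = 8):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, dict]:
        outs = {"input_embeddings_1": x.clone()}  # capture the raw input
        x = self.proj(x)
        return x, outs


class Projection(nn.Module):
    def __init__(self, dim: int = 8):
        super().__init__()
        self.resampler = Resampler(dim)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, dict]:
        outs: dict = {}
        x, new_outs = self.resampler(x)
        # Re-key the submodule's captures with a prefix before merging.
        for key in new_outs.keys():
            outs[f"PERCEIVER_{key}"] = new_outs[key]
        return x, outs


if __name__ == "__main__":
    y, outs = Projection()(torch.randn(2, 4, 8))
    print(sorted(outs))  # ['PERCEIVER_input_embeddings_1']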