|
import torch |
|
import torch.nn as nn |
|
from .layers import MLP, MultiHeadAttention |
|
|
|
|
|
class MultiViewClaimRepresentation(nn.Module):
    """
    Multi-view claim representation module with transformer-like architecture
    for self-attention and cross-attention (co-attention) in the text and
    image modalities.

    Each modality is first self-attended; when both modalities are present,
    each modality's representations then cross-attend to the OTHER
    modality's keys/values to produce fused representations.
    """

    def __init__(self, text_input_dim=384, image_input_dim=1024, embed_dim=512,
                 num_heads=8, dropout=0.1, mlp_ratio=4.0, fused_attn=False):
        super().__init__()
        self.text_input_dim = text_input_dim
        self.image_input_dim = image_input_dim
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout

        # Project raw encoder features into the shared embedding space.
        self.text_proj = nn.Linear(text_input_dim, embed_dim)
        self.image_proj = nn.Linear(image_input_dim, embed_dim)

        # Per-modality Q/K/V projections.
        self.text_WQ = nn.Linear(embed_dim, embed_dim)
        self.text_WK = nn.Linear(embed_dim, embed_dim)
        self.text_WV = nn.Linear(embed_dim, embed_dim)

        self.image_WQ = nn.Linear(embed_dim, embed_dim)
        self.image_WK = nn.Linear(embed_dim, embed_dim)
        self.image_WV = nn.Linear(embed_dim, embed_dim)

        # Output projections, one per attention path.
        self.text_self_attn_out = nn.Linear(embed_dim, embed_dim)
        self.image_self_attn_out = nn.Linear(embed_dim, embed_dim)
        self.text_cross_attn_out = nn.Linear(embed_dim, embed_dim)
        self.image_cross_attn_out = nn.Linear(embed_dim, embed_dim)

        # Post-residual LayerNorms (post-norm transformer style).
        self.text_self_ln1 = nn.LayerNorm(embed_dim)
        self.text_self_ln2 = nn.LayerNorm(embed_dim)
        self.image_self_ln1 = nn.LayerNorm(embed_dim)
        self.image_self_ln2 = nn.LayerNorm(embed_dim)
        self.text_cross_ln1 = nn.LayerNorm(embed_dim)
        self.text_cross_ln2 = nn.LayerNorm(embed_dim)
        self.image_cross_ln1 = nn.LayerNorm(embed_dim)
        self.image_cross_ln2 = nn.LayerNorm(embed_dim)

        # Position-wise feed-forward blocks.
        self.text_mlp = MLP(embed_dim, mlp_ratio, dropout)
        self.image_mlp = MLP(embed_dim, mlp_ratio, dropout)

        # Shared multi-head attention operator; the per-path weights live in
        # the Q/K/V and output projections above.
        self.attention = MultiHeadAttention(embed_dim, num_heads, dropout, fused_attn)
        self.proj_dropout = nn.Dropout(dropout)

    def forward(self, X_t=None, X_i=None):
        """
        Args:
            X_t (Tensor, optional): Text embeddings of shape (B, L_t, text_input_dim)
            X_i (Tensor, optional): Image embeddings of shape (B, L_i, image_input_dim)

        Returns:
            (H_t_fused, H_i_fused):
                H_t_fused: Text representations with self- and co-attention
                           (None if X_t is None)
                H_i_fused: Image representations with self- and co-attention
                           (None if X_i is None)

        Raises:
            ValueError: if both X_t and X_i are None.
        """
        if X_t is None and X_i is None:
            raise ValueError("At least one of X_t or X_i must be provided")

        if X_t is not None:
            X_t = self.text_proj(X_t)
        if X_i is not None:
            X_i = self.image_proj(X_i)

        text_Q = self.text_WQ(X_t) if X_t is not None else None
        text_K = self.text_WK(X_t) if X_t is not None else None
        text_V = self.text_WV(X_t) if X_t is not None else None

        image_Q = self.image_WQ(X_i) if X_i is not None else None
        image_K = self.image_WK(X_i) if X_i is not None else None
        image_V = self.image_WV(X_i) if X_i is not None else None

        # Unimodal text: self-attention + MLP only (no cross-modal signal).
        if X_t is not None and X_i is None:
            H_t = X_t + self.attention(text_Q, text_K, text_V, self.text_self_attn_out)
            H_t = self.text_self_ln1(H_t)
            H_t = H_t + self.text_mlp(H_t)
            H_t = self.text_self_ln2(H_t)
            return H_t, None

        # Unimodal image: self-attention + MLP only.
        if X_i is not None and X_t is None:
            H_i = X_i + self.attention(image_Q, image_K, image_V, self.image_self_attn_out)
            H_i = self.image_self_ln1(H_i)
            H_i = H_i + self.image_mlp(H_i)
            H_i = self.image_self_ln2(H_i)
            return None, H_i

        # Bimodal: per-modality self-attention first.
        H_t = X_t + self.attention(text_Q, text_K, text_V, self.text_self_attn_out)
        H_t = self.text_self_ln1(H_t)

        H_i = X_i + self.attention(image_Q, image_K, image_V, self.image_self_attn_out)
        H_i = self.image_self_ln1(H_i)

        # Co-attention: each modality's (self-attended) queries attend to the
        # OTHER modality's keys/values.
        # BUG FIX: the original attended each modality to its own K/V
        # (H_t with text_K/text_V, H_i with image_K/image_V), which was just a
        # second self-attention layer and provided no cross-modal fusion.
        C_t = H_t + self.attention(H_t, image_K, image_V, self.text_cross_attn_out)
        C_t = self.text_cross_ln1(C_t)
        C_t = C_t + self.text_mlp(C_t)
        C_t = self.text_cross_ln2(C_t)

        C_i = H_i + self.attention(H_i, text_K, text_V, self.image_cross_attn_out)
        C_i = self.image_cross_ln1(C_i)
        C_i = C_i + self.image_mlp(C_i)
        C_i = self.image_cross_ln2(C_i)

        return C_t, C_i
|
|
|
|
|
class CrossAttentionEvidenceConditioning(nn.Module):
    """
    Cross-attention module to condition claim representations
    on textual and visual evidence.

    Each of the four (claim modality, evidence modality) paths has its own
    output projection and LayerNorms; a path is computed only when BOTH its
    claim and its evidence tensors are provided.
    """

    def __init__(self, text_input_dim=384, image_input_dim=1024, embed_dim=768,
                 num_heads=8, dropout=0.1, mlp_ratio=4.0, fused_attn=False):
        super().__init__()
        self.num_heads = num_heads
        self.embed_dim = embed_dim
        self.dropout = dropout
        self.fused_attn = fused_attn

        # Claim-side query projections (claims are already embed_dim wide).
        self.text_WQ = nn.Linear(embed_dim, embed_dim)
        self.image_WQ = nn.Linear(embed_dim, embed_dim)

        # Evidence-side key/value projections from raw encoder dims.
        self.text_evidence_key = nn.Linear(text_input_dim, embed_dim)
        self.text_evidence_value = nn.Linear(text_input_dim, embed_dim)

        self.image_evidence_key = nn.Linear(image_input_dim, embed_dim)
        self.image_evidence_value = nn.Linear(image_input_dim, embed_dim)

        # Output projections, one per (claim, evidence) path.
        self.text_text_out = nn.Linear(embed_dim, embed_dim)
        self.text_image_out = nn.Linear(embed_dim, embed_dim)
        self.image_text_out = nn.Linear(embed_dim, embed_dim)
        self.image_image_out = nn.Linear(embed_dim, embed_dim)

        # Post-residual LayerNorms per path.
        self.text_text_ln1 = nn.LayerNorm(embed_dim)
        self.text_text_ln2 = nn.LayerNorm(embed_dim)
        self.text_image_ln1 = nn.LayerNorm(embed_dim)
        self.text_image_ln2 = nn.LayerNorm(embed_dim)
        self.image_text_ln1 = nn.LayerNorm(embed_dim)
        self.image_text_ln2 = nn.LayerNorm(embed_dim)
        self.image_image_ln1 = nn.LayerNorm(embed_dim)
        self.image_image_ln2 = nn.LayerNorm(embed_dim)

        # Position-wise feed-forward blocks, shared per claim modality.
        self.text_mlp = MLP(embed_dim, mlp_ratio, dropout)
        self.image_mlp = MLP(embed_dim, mlp_ratio, dropout)

        self.attention = MultiHeadAttention(embed_dim, num_heads, dropout, fused_attn)
        self.proj_dropout = nn.Dropout(dropout)

    def _condition(self, H, E, q_proj, k_proj, v_proj, out_proj, ln1, ln2, mlp):
        """One claim-on-evidence path: cross-attn + residual + LN, then MLP + residual + LN."""
        S = self.attention(
            Q=q_proj(H),
            K=k_proj(E),
            V=v_proj(E),
            out_proj=out_proj
        )
        S = ln1(H + S)
        S = ln2(S + mlp(S))
        return S

    def forward(self, H_t=None, H_i=None, E_t=None, E_i=None):
        """
        Args:
            H_t (Tensor, optional): Text claim representations (B, L_t, embed_dim)
            H_i (Tensor, optional): Image claim representations (B, L_i, embed_dim)
            E_t (Tensor, optional): Text evidence embeddings (B, L_e_t, text_input_dim)
            E_i (Tensor, optional): Image evidence embeddings (B, L_e_i, image_input_dim)

        Returns:
            (S_t, S_i): Each a tuple of (text_evidence_output, image_evidence_output);
            an entry is None when its claim or evidence input is missing.
        """
        S_t_t = S_t_i = S_i_t = S_i_i = None

        # BUG FIX: the original guarded only on H_t/H_i and always projected
        # both evidence tensors, crashing with a TypeError when E_t or E_i
        # was None (e.g. the unimodal text-only path).
        if H_t is not None and E_t is not None:
            S_t_t = self._condition(
                H_t, E_t,
                self.text_WQ, self.text_evidence_key, self.text_evidence_value,
                self.text_text_out, self.text_text_ln1, self.text_text_ln2,
                self.text_mlp
            )
        if H_t is not None and E_i is not None:
            S_t_i = self._condition(
                H_t, E_i,
                self.text_WQ, self.image_evidence_key, self.image_evidence_value,
                self.text_image_out, self.text_image_ln1, self.text_image_ln2,
                self.text_mlp
            )
        if H_i is not None and E_t is not None:
            S_i_t = self._condition(
                H_i, E_t,
                self.image_WQ, self.text_evidence_key, self.text_evidence_value,
                self.image_text_out, self.image_text_ln1, self.image_text_ln2,
                self.image_mlp
            )
        if H_i is not None and E_i is not None:
            S_i_i = self._condition(
                H_i, E_i,
                self.image_WQ, self.image_evidence_key, self.image_evidence_value,
                self.image_image_out, self.image_image_ln1, self.image_image_ln2,
                self.image_mlp
            )

        return (S_t_t, S_t_i), (S_i_t, S_i_i)
|
|
|
|
|
class ClassificationModule(nn.Module):
    """
    Classification module that takes final text/image representations
    and outputs logits for {support, refute, not enough info}
    for each of the four (claim modality, evidence modality) paths.
    """

    def __init__(self, embed_dim=768, hidden_dim=256, num_classes=3, dropout=0.1):
        super().__init__()

        def make_head():
            # One independent 2-layer MLP classifier head per evidence path.
            return nn.Sequential(
                nn.Linear(embed_dim, hidden_dim),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(hidden_dim, num_classes)
            )

        # Attribute names are part of the checkpoint format; keep them stable.
        self.mlp_text_given_text = make_head()
        self.mlp_text_given_image = make_head()
        self.mlp_image_given_text = make_head()
        self.mlp_image_given_image = make_head()

    @staticmethod
    def _classify(rep, head):
        """Mean-pool over the sequence dim and classify; None passes through."""
        if rep is None:
            return None
        return head(rep.mean(dim=1))

    def forward(self, S_t=None, S_i=None):
        """
        Args:
            S_t: Tuple of (text_given_text, text_given_image) representations,
                 each (B, L, embed_dim) or None
            S_i: Tuple of (image_given_text, image_given_image) representations,
                 each (B, L, embed_dim) or None
        Returns:
            y_t: Tuple of (text_given_text_logits, text_given_image_logits)
            y_i: Tuple of (image_given_text_logits, image_given_image_logits)
            Each logit tensor has shape (B, num_classes); None where the
            corresponding representation was missing.
        """
        y_t_t = y_t_i = y_i_t = y_i_i = None

        if S_t is not None:
            S_t_t, S_t_i = S_t
            y_t_t = self._classify(S_t_t, self.mlp_text_given_text)
            y_t_i = self._classify(S_t_i, self.mlp_text_given_image)

        if S_i is not None:
            S_i_t, S_i_i = S_i
            y_i_t = self._classify(S_i_t, self.mlp_image_given_text)
            y_i_i = self._classify(S_i_i, self.mlp_image_given_image)

        return (y_t_t, y_t_i), (y_i_t, y_i_i)
|
|
|
|
|
class MisinformationDetectionModel(nn.Module):
    """
    End-to-end model combining:
    1) Multi-view claim representation
    2) Cross-attention evidence conditioning
    3) Classification for each evidence path
    """

    def __init__(self,
                 text_input_dim=384,
                 image_input_dim=1024,
                 embed_dim=512,
                 num_heads=8,
                 dropout=0.1,
                 hidden_dim=256,
                 num_classes=3,
                 mlp_ratio=4.0,
                 fused_attn=False):
        super().__init__()

        # Both attention stages share the same dimension/head configuration.
        stage_cfg = dict(
            text_input_dim=text_input_dim,
            image_input_dim=image_input_dim,
            embed_dim=embed_dim,
            num_heads=num_heads,
            dropout=dropout,
            mlp_ratio=mlp_ratio,
            fused_attn=fused_attn,
        )
        self.representation = MultiViewClaimRepresentation(**stage_cfg)
        self.cross_attn = CrossAttentionEvidenceConditioning(**stage_cfg)
        self.classifier = ClassificationModule(
            embed_dim=embed_dim,
            hidden_dim=hidden_dim,
            num_classes=num_classes,
            dropout=dropout,
        )

        self._initialize_weights()

    def _initialize_weights(self):
        """Xavier-uniform init for Linear weights, identity init for LayerNorm."""
        for mod in self.modules():
            if isinstance(mod, nn.LayerNorm):
                nn.init.ones_(mod.weight)
                nn.init.zeros_(mod.bias)
            elif isinstance(mod, nn.Linear):
                nn.init.xavier_uniform_(mod.weight)
                if mod.bias is not None:
                    nn.init.zeros_(mod.bias)

    def forward(self, X_t=None, X_i=None, E_t=None, E_i=None):
        """
        Args:
            X_t (Tensor): Text claim embeddings (B, L_t, D)
            X_i (Tensor): Image claim embeddings (B, L_i, D)
            E_t (Tensor): Text evidence embeddings (B, L_e_t, D)
            E_i (Tensor): Image evidence embeddings (B, L_e_i, D)

        Returns:
            y_t: Tuple of (text_given_text_logits, text_given_image_logits)
            y_i: Tuple of (image_given_text_logits, image_given_image_logits)
            Each logit tensor has shape (B, num_classes)
        """
        # Stage 1: fuse the claim's own modalities.
        H_t, H_i = self.representation(X_t, X_i)
        # Stage 2: condition the fused claim on the evidence.
        S_t, S_i = self.cross_attn(H_t, H_i, E_t, E_i)
        # Stage 3: one logit head per evidence path.
        return self.classifier(S_t=S_t, S_i=S_i)
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
batch_size = 2 |
|
seq_len_t = 5 |
|
seq_len_i = 7 |
|
evidence_len_t = 6 |
|
evidence_len_i = 8 |
|
embed_dim = 768 |
|
|
|
|
|
text_claim = torch.randn(batch_size, seq_len_t, embed_dim) |
|
image_claim = torch.randn(batch_size, seq_len_i, embed_dim) |
|
text_evidence = torch.randn(batch_size, evidence_len_t, embed_dim) |
|
image_evidence = torch.randn(batch_size, evidence_len_i, embed_dim) |
|
|
|
|
|
model = MisinformationDetectionModel( |
|
embed_dim=embed_dim, |
|
num_heads=8, |
|
dropout=0.1, |
|
hidden_dim=256, |
|
num_classes=3 |
|
) |
|
|
|
|
|
(y_t_t, y_t_i), (y_i_t, y_i_i) = model( |
|
X_t=text_claim, |
|
X_i=image_claim, |
|
E_t=text_evidence, |
|
E_i=image_evidence |
|
) |
|
print("Text-Text logits:", y_t_t.shape) |
|
print("Text-Image logits:", y_t_i.shape) |
|
print("Image-Text logits:", y_i_t.shape) |
|
print("Image-Image logits:", y_i_i.shape) |
|
|
|
|
|
(y_t_t, y_t_i), (y_i_t, y_i_i) = model( |
|
X_t=text_claim, |
|
E_t=text_evidence |
|
) |
|
print("\nUnimodal Text:") |
|
print("Text-Text logits:", y_t_t.shape if y_t_t is not None else None) |
|
print("Text-Image logits:", y_t_i if y_t_i is not None else None) |
|
print("Image-Text logits:", y_i_t if y_i_t is not None else None) |
|
print("Image-Image logits:", y_i_i if y_i_i is not None else None) |