import torch
from torch import nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv, global_mean_pool, GCNConv

from torchvision import models

###############################################################
# These modules correspond to core building blocks of SAG-ViT:
# 1. A CNN feature extractor for high-fidelity multi-scale feature maps.
# 2. A Graph Attention Network (GAT) to refine patch embeddings.
# 3. A Transformer Encoder to capture global long-range dependencies.
# 4. An MLP classifier head.
###############################################################

class EfficientNetV2FeatureExtractor(nn.Module):
    """
    Extracts multi-scale, spatially-rich, and semantically-meaningful feature maps 
    from images using a pre-trained EfficientNetV2-S model. This corresponds 
    to Section 3.1, where a CNN backbone (EfficientNetV2-S) is used to produce rich 
    feature maps that preserve semantic information at multiple scales.
    """
    def __init__(self, pretrained=False):
        super(EfficientNetV2FeatureExtractor, self).__init__()
        
        # Load EfficientNetV2-S, optionally with ImageNet-pretrained weights
        efficientnet = models.efficientnet_v2_s(
            weights="IMAGENET1K_V1" if pretrained else None
        )

        # Keep all feature stages except the final downsampling stage and the
        # 1x1 head convolution, so the output feature map stays at stride 16
        # (e.g. 14x14 for a 224x224 input) instead of being reduced further.
        self.extractor = nn.Sequential(*list(efficientnet.features.children())[:-2])

    def forward(self, x):
        """
        Forward pass through the CNN backbone.

        Input:
        - x (Tensor): Input images of shape (B, 3, H, W)

        Output:
        - features (Tensor): Extracted feature map of shape (B, C, H', W'),
          where H' and W' are reduced spatial dimensions.
        """
        features = self.extractor(x)
        return features
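
# A minimal sanity-check sketch for the extractor above. The expected shape
# assumes torchvision's EfficientNetV2-S with the last two feature stages
# dropped (stride 16, 160 channels); verify against your torchvision version.
def _demo_feature_extractor():
    extractor = EfficientNetV2FeatureExtractor(pretrained=False)
    images = torch.randn(2, 3, 224, 224)  # dummy (B, 3, H, W) batch
    features = extractor(images)
    print(features.shape)  # expected: torch.Size([2, 160, 14, 14])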

class GATGNN(nn.Module):
    """
    A Graph Attention Network (GAT) that processes patch-graph embeddings.
    This module corresponds to the Graph Attention stage (Section 3.3),
    refining local relationships between patches in a learned manner.
    """
    def __init__(self, in_channels, hidden_channels, out_channels, heads=4):
        super(GATGNN, self).__init__()
        # Graph layers:
        # A GAT layer first maps raw patch embeddings to a higher-level
        # representation via multi-head attention (output: hidden_channels * heads dims).
        self.conv1 = GATConv(in_channels, hidden_channels, heads=heads)
        # A GCN layer then produces the refined output representation.
        self.conv2 = GCNConv(hidden_channels * heads, out_channels)
        self.pool = global_mean_pool

    def forward(self, data):
        """
        Input:
        - data (PyG Data): Contains x (node features), edge_index (graph edges), and batch indexing.

        Output:
        - out (Tensor): Aggregated graph-level embedding of shape (num_graphs, out_channels)
          after global mean pooling.
        """
        x, edge_index, batch = data.x, data.edge_index, data.batch

        # GAT layer with ELU activation (the standard choice for GAT)
        x = F.elu(self.conv1(x, edge_index))

        # GCN layer for further aggregation
        x = self.conv2(x, edge_index)

        # Global mean pooling to obtain a graph-level representation
        out = self.pool(x, batch)

        return out
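
# A small usage sketch for GATGNN on a toy patch graph; the 4-node cycle and
# feature sizes below are arbitrary placeholders, not values from the paper.
def _demo_gat_gnn():
    from torch_geometric.data import Data

    x = torch.randn(4, 32)  # 4 nodes with 32-dim features
    edge_index = torch.tensor([[0, 1, 2, 3],
                               [1, 2, 3, 0]], dtype=torch.long)  # simple cycle
    batch = torch.zeros(4, dtype=torch.long)  # all nodes belong to one graph
    data = Data(x=x, edge_index=edge_index, batch=batch)

    model = GATGNN(in_channels=32, hidden_channels=64, out_channels=128)
    out = model(data)
    print(out.shape)  # expected: torch.Size([1, 128]), one graph embedding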

class TransformerEncoder(nn.Module):
    """
    A Transformer encoder to capture long-range dependencies among patch embeddings.
    Integrates global dependencies after GAT processing, as per Section 3.3.
    """
    def __init__(self, d_model, nhead, num_layers, dim_feedforward):
        super(TransformerEncoder, self).__init__()
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

    def forward(self, x):
        """
        Input:
        - x (Tensor): Sequence of patch embeddings with shape (B, N, D).

        Output:
        - (Tensor): Transformed embeddings with global relationships integrated (B, N, D).
        """
        # nn.TransformerEncoder defaults to batch_first=False and expects
        # (N, B, D), so transpose before and after.
        x = x.transpose(0, 1)  # (N, B, D)
        x = self.transformer_encoder(x)
        x = x.transpose(0, 1)  # (B, N, D)
        return x
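
# A quick shape check for the encoder above; d_model, nhead, and sequence
# length are illustrative only. Input and output are both (B, N, D).
def _demo_transformer_encoder():
    encoder = TransformerEncoder(d_model=128, nhead=8,
                                 num_layers=2, dim_feedforward=256)
    tokens = torch.randn(2, 49, 128)  # (B, N, D): 49 patch embeddings
    out = encoder(tokens)
    print(out.shape)  # expected: torch.Size([2, 49, 128])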

class MLPBlock(nn.Module):
    """
    An MLP classification head to map final global embeddings to classification logits.
    """
    def __init__(self, in_features, hidden_features, out_features):
        super(MLPBlock, self).__init__()
        self.mlp = nn.Sequential(
            nn.Linear(in_features, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, out_features)
        )

    def forward(self, x):
        return self.mlp(x)
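
# An end-to-end shape walkthrough wiring the blocks above together. This is a
# simplified sketch: in SAG-ViT the CNN feature map is first organized into a
# patch graph and refined by GATGNN before the Transformer stage, which is
# omitted here. The flatten-to-tokens step and all dimensions below are
# illustrative assumptions, not the paper's configuration.
def _demo_pipeline(num_classes=10):
    backbone = EfficientNetV2FeatureExtractor(pretrained=False)
    encoder = TransformerEncoder(d_model=160, nhead=8,
                                 num_layers=2, dim_feedforward=320)
    head = MLPBlock(in_features=160, hidden_features=256,
                    out_features=num_classes)

    images = torch.randn(2, 3, 224, 224)
    fmap = backbone(images)                   # (B, C, H', W'), C assumed 160
    tokens = fmap.flatten(2).transpose(1, 2)  # (B, H'*W', C) token sequence
    tokens = encoder(tokens)                  # (B, N, C), globally contextualized
    pooled = tokens.mean(dim=1)               # (B, C) pooled global embedding
    logits = head(pooled)                     # (B, num_classes)
    print(logits.shape)  # expected: torch.Size([2, 10])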