import torch
from torch import nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv, global_mean_pool, GCNConv
from torchvision import models
###############################################################
# These modules correspond to core building blocks of SAG-ViT:
# 1. A CNN feature extractor for high-fidelity multi-scale feature maps.
# 2. A Graph Attention Network (GAT) to refine patch embeddings.
# 3. A Transformer Encoder to capture global long-range dependencies.
# 4. An MLP classifier head.
###############################################################
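# A rough shape flow through these blocks (sizes are illustrative, not
# prescribed here): (B, 3, H, W) --CNN--> (B, C, H/16, W/16) feature maps,
# which are patchified into graphs, refined by the GAT, contextualized by the
# Transformer encoder, and classified by the MLP head.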
class EfficientNetV2FeatureExtractor(nn.Module):
"""
Extracts multi-scale, spatially-rich, and semantically-meaningful feature maps
    from images using an EfficientNetV2-S model (optionally pre-trained). This corresponds
to Section 3.1, where a CNN backbone (EfficientNetV2-S) is used to produce rich
feature maps that preserve semantic information at multiple scales.
"""
def __init__(self, pretrained=False):
super(EfficientNetV2FeatureExtractor, self).__init__()
        # Load EfficientNetV2-S, with ImageNet weights when pretrained=True
efficientnet = models.efficientnet_v2_s(
weights="IMAGENET1K_V1" if pretrained else None
)
# Extract layers up to the last block before downsampling below 16x16
self.extractor = nn.Sequential(*list(efficientnet.features.children())[:-2])
def forward(self, x):
"""
Forward pass through the CNN backbone.
Input:
- x (Tensor): Input images of shape (B, 3, H, W)
Output:
- features (Tensor): Extracted feature map of shape (B, C, H', W'),
where H' and W' are reduced spatial dimensions.
"""
features = self.extractor(x)
return features
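
# Example usage of the feature extractor (a minimal sketch; the batch size and
# 224x224 input resolution are illustrative assumptions):
#   extractor = EfficientNetV2FeatureExtractor(pretrained=True)
#   x = torch.randn(2, 3, 224, 224)
#   feats = extractor(x)  # (2, 160, 14, 14): spatial dims reduced 16x by the backbone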
class GATGNN(nn.Module):
"""
A Graph Attention Network (GAT) that processes patch-graph embeddings.
This module corresponds to the Graph Attention stage (Section 3.3),
refining local relationships between patches in a learned manner.
"""
def __init__(self, in_channels, hidden_channels, out_channels, heads=4):
super(GATGNN, self).__init__()
# GAT layers:
# First layer maps raw patch embeddings to a higher-level representation.
self.conv1 = GATConv(in_channels, hidden_channels, heads=heads)
# Final GCN layer for refined representation
self.conv2 = GCNConv(hidden_channels * heads, out_channels)
self.pool = global_mean_pool
    def forward(self, data):
        """
        Input:
        - data (PyG Data): Contains x (node features), edge_index (graph edges), and batch indexing.
        Output:
        - out (Tensor): Aggregated graph-level embedding after mean pooling.
        """
        x, edge_index, batch = data.x, data.edge_index, data.batch
        # GAT layer with ELU activation (the customary nonlinearity for GAT)
        x = F.elu(self.conv1(x, edge_index))
        # GCN layer for further aggregation
        x = self.conv2(x, edge_index)
        # Global mean pooling to obtain graph-level representation
        out = self.pool(x, batch)
        return out
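
# Example usage of the GAT stage (a minimal sketch; the node count, feature
# width, and edges below are illustrative assumptions, not SAG-ViT settings):
#   from torch_geometric.data import Data
#   nodes = torch.randn(16, 160)                 # 16 patch nodes with 160-d features
#   edge_index = torch.tensor([[0, 1], [1, 0]])  # one bidirectional edge, shape (2, E)
#   batch = torch.zeros(16, dtype=torch.long)    # all nodes belong to graph 0
#   gnn = GATGNN(in_channels=160, hidden_channels=64, out_channels=128)
#   emb = gnn(Data(x=nodes, edge_index=edge_index, batch=batch))  # (1, 128)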
class TransformerEncoder(nn.Module):
"""
A Transformer encoder to capture long-range dependencies among patch embeddings.
Integrates global dependencies after GAT processing, as per Section 3.3.
"""
def __init__(self, d_model, nhead, num_layers, dim_feedforward):
super(TransformerEncoder, self).__init__()
encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward)
self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
def forward(self, x):
"""
Input:
- x (Tensor): Sequence of patch embeddings with shape (B, N, D).
Output:
- (Tensor): Transformed embeddings with global relationships integrated (B, N, D).
"""
# The Transformer expects (N, B, D), so transpose first
x = x.transpose(0, 1) # (N, B, D)
x = self.transformer_encoder(x)
x = x.transpose(0, 1) # (B, N, D)
return x
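
# Example usage of the encoder (a minimal sketch; the sequence length and
# widths are illustrative assumptions):
#   encoder = TransformerEncoder(d_model=128, nhead=8, num_layers=2, dim_feedforward=256)
#   tokens = torch.randn(4, 49, 128)  # (B, N, D) patch embeddings
#   out = encoder(tokens)             # (4, 49, 128), with global context mixed in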
class MLPBlock(nn.Module):
"""
An MLP classification head to map final global embeddings to classification logits.
"""
def __init__(self, in_features, hidden_features, out_features):
super(MLPBlock, self).__init__()
self.mlp = nn.Sequential(
nn.Linear(in_features, hidden_features),
nn.ReLU(),
nn.Linear(hidden_features, out_features)
)
def forward(self, x):
return self.mlp(x)
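
# Example usage of the classifier head (a minimal sketch; the dimensions are
# illustrative assumptions):
#   head = MLPBlock(in_features=128, hidden_features=64, out_features=10)
#   logits = head(torch.randn(4, 128))  # (4, 10) class logits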