fisherman611 committed
Commit 89ae6ce (verified)
1 Parent(s): 8313ba2

Create models/can/can.py

Files changed (1): models/can/can.py (+819, -0)
models/can/can.py ADDED
@@ -0,0 +1,819 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import math


"""Custom DenseNet Backbone"""
class DenseBlock(nn.Module):
    """
    Basic DenseNet block
    """
    def __init__(self, in_channels, growth_rate, num_layers):
        super(DenseBlock, self).__init__()
        self.layers = nn.ModuleList()
        for i in range(num_layers):
            self.layers.append(self._make_layer(in_channels + i * growth_rate, growth_rate))

    def _make_layer(self, in_channels, growth_rate):
        layer = nn.Sequential(
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels, 4 * growth_rate, kernel_size=1, bias=False),
            nn.BatchNorm2d(4 * growth_rate),
            nn.ReLU(inplace=True),
            nn.Conv2d(4 * growth_rate, growth_rate, kernel_size=3, padding=1, bias=False)
        )
        return layer

    def forward(self, x):
        features = [x]
        for layer in self.layers:
            new_feature = layer(torch.cat(features, dim=1))
            features.append(new_feature)
        return torch.cat(features, dim=1)


class TransitionLayer(nn.Module):
    """
    Transition layer between DenseBlocks
    """
    def __init__(self, in_channels, out_channels):
        super(TransitionLayer, self).__init__()
        self.transition = nn.Sequential(
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
            nn.AvgPool2d(kernel_size=2, stride=2)
        )

    def forward(self, x):
        return self.transition(x)

class DenseNetBackbone(nn.Module):
    """
    DenseNet backbone for CAN
    """
    def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), num_init_features=64):
        super(DenseNetBackbone, self).__init__()

        # Initial layer
        self.features = nn.Sequential(
            nn.Conv2d(1, num_init_features, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(num_init_features),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )

        # DenseBlocks
        num_features = num_init_features
        for i, num_layers in enumerate(block_config):
            block = DenseBlock(num_features, growth_rate, num_layers)
            self.features.add_module(f'denseblock{i+1}', block)
            num_features = num_features + growth_rate * num_layers
            if i != len(block_config) - 1:
                trans = TransitionLayer(num_features, num_features // 2)
                self.features.add_module(f'transition{i+1}', trans)
                num_features = num_features // 2

        # Final processing
        self.features.add_module('norm5', nn.BatchNorm2d(num_features))
        self.features.add_module('relu5', nn.ReLU(inplace=True))

        self.out_channels = num_features  # 1024 with the default configuration

    def forward(self, x):
        return self.features(x)

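# Illustrative shape check (not part of the original file): with the default
# configuration and a 1-channel 128x384 input, the stem downsamples by 4 and the
# three transition layers each halve the resolution, so the backbone should emit a
# 1024-channel map at 1/32 resolution. A minimal sketch, assuming such an input:
#
#   backbone = DenseNetBackbone()
#   feats = backbone(torch.randn(1, 1, 128, 384))
#   print(feats.shape)  # expected: torch.Size([1, 1024, 4, 12])

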
"""Pretrained DenseNet"""
class DenseNetFeatureExtractor(nn.Module):
    def __init__(self, densenet_model, out_channels=684):
        super().__init__()
        # Change input conv to 1 channel
        self.conv0 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        # Copy pretrained weights (average over RGB channels)
        self.conv0.weight.data = densenet_model.features.conv0.weight.data.mean(dim=1, keepdim=True)
        self.features = densenet_model.features
        self.out_channels = out_channels
        # Add a 1x1 conv to match the expected output channels if needed
        self.final_conv = nn.Conv2d(1024, out_channels, kernel_size=1)
        self.final_bn = nn.BatchNorm2d(out_channels)
        self.final_relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv0(x)
        x = self.features.norm0(x)
        x = self.features.relu0(x)
        x = self.features.pool0(x)
        x = self.features.denseblock1(x)
        x = self.features.transition1(x)
        x = self.features.denseblock2(x)
        x = self.features.transition2(x)
        x = self.features.denseblock3(x)
        x = self.features.transition3(x)
        x = self.features.denseblock4(x)
        x = self.features.norm5(x)
        x = self.final_conv(x)
        x = self.final_bn(x)
        x = self.final_relu(x)
        return x

"""Custom ResNet Backbone"""
class BasicBlock(nn.Module):
    """
    Basic ResNet block
    """
    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels * self.expansion:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels * self.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels * self.expansion)
            )

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        out += self.shortcut(identity)
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    """
    Bottleneck ResNet block
    """
    expansion = 4

    def __init__(self, in_channels, out_channels, stride=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.conv3 = nn.Conv2d(out_channels, out_channels * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channels * self.expansion)
        self.relu = nn.ReLU(inplace=True)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels * self.expansion:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels * self.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels * self.expansion)
            )

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        out += self.shortcut(identity)
        out = self.relu(out)

        return out

class ResNetBackbone(nn.Module):
    """
    ResNet backbone for CAN model, designed to output the same feature width as the pretrained extractors
    """
    def __init__(self, block_type='bottleneck', layers=[3, 4, 6, 3], num_init_features=64):
        super(ResNetBackbone, self).__init__()

        # Initial layer
        self.conv1 = nn.Conv2d(1, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(num_init_features)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Define block type
        if block_type == 'basic':
            block = BasicBlock
            expansion = 1
        elif block_type == 'bottleneck':
            block = Bottleneck
            expansion = 4
        else:
            raise ValueError(f"Unknown block type: {block_type}")

        # Create layers
        self.layer1 = self._make_layer(block, num_init_features, 64, layers[0], stride=1)
        self.layer2 = self._make_layer(block, 64 * expansion, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 128 * expansion, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 256 * expansion, 512, layers[3], stride=2)

        # Final 1x1 conv to produce the 684-channel feature width used elsewhere in this file
        self.final_conv = nn.Conv2d(512 * expansion, 684, kernel_size=1)
        self.final_bn = nn.BatchNorm2d(684)
        self.final_relu = nn.ReLU(inplace=True)

        self.out_channels = 684

        # Initialize weights
        self._initialize_weights()

    def _make_layer(self, block, in_channels, out_channels, num_blocks, stride):
        layers = []
        layers.append(block(in_channels, out_channels, stride))
        for _ in range(1, num_blocks):
            layers.append(block(out_channels * block.expansion, out_channels))
        return nn.Sequential(*layers)

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.final_conv(x)
        x = self.final_bn(x)
        x = self.final_relu(x)

        return x

"""Pretrained ResNet"""
class ResNetFeatureExtractor(nn.Module):
    def __init__(self, resnet_model, out_channels=684):
        super().__init__()
        # Change input conv to 1 channel
        self.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        # Collapse the pretrained RGB weights into a single input channel
        self.conv1.weight.data = resnet_model.conv1.weight.data.sum(dim=1, keepdim=True)
        self.bn1 = resnet_model.bn1
        self.relu = resnet_model.relu
        self.maxpool = resnet_model.maxpool
        self.layer1 = resnet_model.layer1
        self.layer2 = resnet_model.layer2
        self.layer3 = resnet_model.layer3
        self.layer4 = resnet_model.layer4
        # Add a 1x1 conv to match DenseNet output channels if needed
        self.final_conv = nn.Conv2d(2048, out_channels, kernel_size=1)
        self.final_bn = nn.BatchNorm2d(out_channels)
        self.final_relu = nn.ReLU(inplace=True)
        self.out_channels = out_channels

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.final_conv(x)
        x = self.final_bn(x)
        x = self.final_relu(x)
        return x

"""Channel Attention"""
class ChannelAttention(nn.Module):
    """
    Channel-wise attention mechanism
    """
    def __init__(self, in_channels, ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        self.fc = nn.Sequential(
            nn.Conv2d(in_channels, in_channels // ratio, kernel_size=1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels // ratio, in_channels, kernel_size=1, bias=False)
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.fc(self.avg_pool(x))
        max_out = self.fc(self.max_pool(x))
        out = avg_out + max_out
        return self.sigmoid(out)

"""Multi-scale Counting Module"""
class MSCM(nn.Module):
    """
    Multi-Scale Counting Module
    """
    def __init__(self, in_channels, num_classes):
        super(MSCM, self).__init__()

        # Branch 1: 3x3 kernel
        self.branch1 = nn.Sequential(
            nn.Conv2d(in_channels, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Dropout2d(p=0.2)
        )
        self.attention1 = ChannelAttention(256)

        # Branch 2: 5x5 kernel
        self.branch2 = nn.Sequential(
            nn.Conv2d(in_channels, 256, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.Dropout2d(p=0.2)
        )
        self.attention2 = ChannelAttention(256)

        # 1x1 Conv layer to reduce channels and create counting map
        self.conv_reduce = nn.Conv2d(512, num_classes, kernel_size=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Process branch 1
        out1 = self.branch1(x)
        out1 = out1 * self.attention1(out1)

        # Process branch 2
        out2 = self.branch2(x)
        out2 = out2 * self.attention2(out2)

        # Concatenate features from both branches
        concat_features = torch.cat([out1, out2], dim=1)  # Shape: B x 512 x H x W

        # Create counting map
        count_map = self.sigmoid(self.conv_reduce(concat_features))  # Shape: B x C x H x W

        # Apply sum-pooling to create 1D counting vector
        # Sum over the entire feature map along height and width
        count_vector = torch.sum(count_map, dim=(2, 3))  # Shape: B x C

        return count_map, count_vector

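# Illustrative sketch (not part of the original file): the counting branch is trained
# against per-class symbol counts. One common way to build such a target from a label
# sequence of token indices is a simple histogram (padding/special tokens excluded in
# practice); the helper name below is hypothetical.
#
#   def make_count_target(label_indices, num_classes):
#       count_target = torch.zeros(num_classes)
#       for idx in label_indices:
#           count_target[idx] += 1
#       return count_target

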
"""Positional Encoding"""
class PositionalEncoding(nn.Module):
    """
    Positional encoding for attention decoder
    """
    def __init__(self, d_model, max_seq_len=1024):
        super(PositionalEncoding, self).__init__()
        self.d_model = d_model

        # Create positional encoding matrix
        pe = torch.zeros(max_seq_len, d_model)
        position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x shape: B x H x W x d_model
        b, h, w, _ = x.shape

        # Ensure we have enough positional encodings for the feature map size
        if h * w > self.pe.size(0):  # type: ignore
            # Dynamically extend positional encodings if needed
            device = self.pe.device
            extended_pe = torch.zeros(h * w, self.d_model, device=device)  # type: ignore
            position = torch.arange(0, h * w, dtype=torch.float, device=device).unsqueeze(1)  # type: ignore
            div_term = torch.exp(torch.arange(0, self.d_model, 2, device=device).float() * (-math.log(10000.0) / self.d_model))  # type: ignore

            extended_pe[:, 0::2] = torch.sin(position * div_term)
            extended_pe[:, 1::2] = torch.cos(position * div_term)

            pos_encoding = extended_pe.view(h, w, -1)
        else:
            # Use pre-computed positional encodings
            pos_encoding = self.pe[:h * w].view(h, w, -1)  # type: ignore

        pos_encoding = pos_encoding.unsqueeze(0).expand(b, -1, -1, -1)  # B x H x W x d_model
        return pos_encoding

"""Counting-combined Attentional Decoder"""
class CCAD(nn.Module):
    """
    Counting-Combined Attentional Decoder
    """
    def __init__(self, input_channels, hidden_size, embedding_dim, num_classes, use_coverage=True):
        super(CCAD, self).__init__()

        self.hidden_size = hidden_size
        self.embedding_dim = embedding_dim
        self.use_coverage = use_coverage

        # Input layer to reduce feature map
        self.feature_proj = nn.Conv2d(input_channels, hidden_size * 2, kernel_size=1)

        # Positional encoding
        self.pos_encoder = PositionalEncoding(hidden_size * 2)

        # Embedding layer for output symbols
        self.embedding = nn.Embedding(num_classes, embedding_dim)

        # GRU cell
        self.gru = nn.GRUCell(embedding_dim + hidden_size + num_classes, hidden_size)

        # Attention
        self.attention_w = nn.Linear(hidden_size * 2, hidden_size)
        self.attention_v = nn.Linear(hidden_size, 1)
        if use_coverage:
            self.coverage_proj = nn.Linear(1, hidden_size)

        # Output layer
        self.out = nn.Linear(hidden_size + hidden_size + num_classes, num_classes)
        self.dropout = nn.Dropout(p=0.3)

    def forward(self, feature_map, count_vector, target=None, teacher_forcing_ratio=0.5, max_len=200):
        batch_size = feature_map.size(0)
        device = feature_map.device

        # Transform feature map
        projected_features = self.feature_proj(feature_map)  # B x 2*hidden_size x H x W
        H, W = projected_features.size(2), projected_features.size(3)

        # Reshape feature map to B x H*W x 2*hidden_size
        projected_features = projected_features.permute(0, 2, 3, 1).contiguous()  # B x H x W x 2*hidden_size

        # Add positional encoding
        pos_encoding = self.pos_encoder(projected_features)  # B x H x W x 2*hidden_size
        projected_features = projected_features + pos_encoding

        # Reshape for attention processing
        projected_features = projected_features.view(batch_size, H * W, -1)  # B x H*W x 2*hidden_size

        # Initialize initial hidden state
        h_t = torch.zeros(batch_size, self.hidden_size, device=device)

        # Initialize coverage attention if used
        if self.use_coverage:
            coverage = torch.zeros(batch_size, H * W, 1, device=device)

        # First <SOS> token
        y_t_1 = torch.ones(batch_size, dtype=torch.long, device=device)

        # Prepare target sequence if provided
        if target is not None:
            max_len = target.size(1)

        # Array to store predictions
        outputs = torch.zeros(batch_size, max_len, self.embedding.num_embeddings, device=device)

        for t in range(max_len):
            # Apply embedding to the previous symbol
            embedded = self.embedding(y_t_1)  # B x embedding_dim

            # Compute attention
            attention_input = self.attention_w(projected_features)  # B x H*W x hidden_size

            # Add coverage attention if used
            if self.use_coverage:
                coverage_input = self.coverage_proj(coverage.float())  # type: ignore
                attention_input = attention_input + coverage_input

            # Add hidden state to attention
            h_expanded = h_t.unsqueeze(1).expand(-1, H * W, -1)  # B x H*W x hidden_size
            attention_input = torch.tanh(attention_input + h_expanded)

            # Compute attention weights
            e_t = self.attention_v(attention_input).squeeze(-1)  # B x H*W
            alpha_t = F.softmax(e_t, dim=1)  # B x H*W

            # Update coverage if used
            if self.use_coverage:
                coverage = coverage + alpha_t.unsqueeze(-1)  # type: ignore

            # Compute context vector
            alpha_t = alpha_t.unsqueeze(1)  # B x 1 x H*W
            context = torch.bmm(alpha_t, projected_features).squeeze(1)  # B x 2*hidden_size
            context = context[:, :self.hidden_size]  # Take the first half as the context vector

            # Combine embedding, context vector, and count vector
            gru_input = torch.cat([embedded, context, count_vector], dim=1)

            # Update hidden state
            h_t = self.gru(gru_input, h_t)

            # Predict output symbol
            output = self.out(torch.cat([h_t, context, count_vector], dim=1))
            outputs[:, t] = output

            # Decide the next input symbol
            if target is not None and torch.rand(1).item() < teacher_forcing_ratio:
                y_t_1 = target[:, t]
            else:
                # Greedy decoding
                _, y_t_1 = output.max(1)

        return outputs

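# Illustrative call (not part of the original file): during training the decoder receives
# the backbone features, the MSCM count vector, and the ground-truth target. A minimal
# sketch, assuming `backbone`, `features`, `count_vector`, and `target` already exist:
#
#   decoder = CCAD(input_channels=backbone.out_channels, hidden_size=256, embedding_dim=256, num_classes=101)
#   logits = decoder(features, count_vector, target=target, teacher_forcing_ratio=0.5)
#   # logits: B x target_len x num_classes (one score vector per decoding step)

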
"""Full model CAN (Counting-Aware Network)"""
class CAN(nn.Module):
    """
    Counting-Aware Network for handwritten mathematical expression recognition
    """
    def __init__(self, num_classes, backbone=None, hidden_size=256, embedding_dim=256, use_coverage=True):
        super(CAN, self).__init__()

        # Backbone
        if backbone is None:
            self.backbone = DenseNetBackbone()
        else:
            self.backbone = backbone
        backbone_channels = self.backbone.out_channels

        # Multi-Scale Counting Module
        self.mscm = MSCM(backbone_channels, num_classes)

        # Counting-Combined Attentional Decoder
        self.decoder = CCAD(
            input_channels=backbone_channels,
            hidden_size=hidden_size,
            embedding_dim=embedding_dim,
            num_classes=num_classes,
            use_coverage=use_coverage
        )

        # Save parameters for later use
        self.hidden_size = hidden_size
        self.embedding_dim = embedding_dim
        self.num_classes = num_classes
        self.use_coverage = use_coverage

    def init_hidden_state(self, visual_features):
        """
        Initialize recurrent hidden and cell states

        Args:
            visual_features: Visual features from backbone

        Returns:
            h, c: Initial hidden and cell states
        """
        batch_size = visual_features.size(0)
        device = visual_features.device

        # Initialize hidden state with zeros
        h = torch.zeros(1, batch_size, self.hidden_size, device=device)
        c = torch.zeros(1, batch_size, self.hidden_size, device=device)

        return h, c

    def forward(self, x, target=None, teacher_forcing_ratio=0.5):
        # Extract features from backbone
        features = self.backbone(x)

        # Compute count map and count vector from MSCM
        count_map, count_vector = self.mscm(features)

        # Decode with CCAD
        outputs = self.decoder(features, count_vector, target, teacher_forcing_ratio)

        return outputs, count_vector

    def calculate_loss(self, outputs, targets, count_vectors, count_targets, lambda_count=0.01):
        """
        Compute the combined loss function for CAN

        Args:
            outputs: Predicted output sequence from decoder
            targets: Actual target sequence
            count_vectors: Predicted count vector
            count_targets: Actual target count vector
            lambda_count: Weight for counting loss

        Returns:
            Total loss: L = L_cls + lambda * L_counting
        """
        # Loss for decoder (cross entropy)
        L_cls = F.cross_entropy(outputs.view(-1, outputs.size(-1)), targets.view(-1))

        # Loss for counting (smooth L1)
        L_counting = F.smooth_l1_loss(count_vectors / self.num_classes, count_targets / self.num_classes)

        # Total loss
        total_loss = L_cls + lambda_count * L_counting

        return total_loss, L_cls, L_counting

    def recognize(self, images, max_length=150, start_token=None, end_token=None, beam_width=5):
        """
        Recognize the handwritten expression using beam search (batch_size=1 only).

        Args:
            images: Input image tensor, shape (1, channels, height, width)
            max_length: Maximum length of the output sequence
            start_token: Start token index
            end_token: End token index
            beam_width: Beam width for beam search

        Returns:
            best_sequence: List of token indices
            attention_weights: List of attention weights for visualization
        """
        if images.size(0) != 1:
            raise ValueError("Beam search is implemented only for batch_size=1")

        device = images.device

        # Encode the image
        visual_features = self.backbone(images)

        # Get count vector
        _, count_vector = self.mscm(visual_features)

        # Prepare feature map for decoder
        projected_features = self.decoder.feature_proj(visual_features)  # (1, 2*hidden_size, H, W)
        H, W = projected_features.size(2), projected_features.size(3)
        projected_features = projected_features.permute(0, 2, 3, 1).contiguous()  # (1, H, W, 2*hidden_size)
        pos_encoding = self.decoder.pos_encoder(projected_features)  # (1, H, W, 2*hidden_size)
        projected_features = projected_features + pos_encoding  # (1, H, W, 2*hidden_size)
        projected_features = projected_features.view(1, H * W, -1)  # (1, H*W, 2*hidden_size)

        # Initialize beams
        beam_sequences = [torch.tensor([start_token], device=device)] * beam_width  # List of (seq_len) tensors
        beam_scores = torch.zeros(beam_width, device=device)  # (beam_width)
        h_t = torch.zeros(beam_width, self.hidden_size, device=device)  # (beam_width, hidden_size)
        if self.use_coverage:
            coverage = torch.zeros(beam_width, H * W, device=device)  # (beam_width, H*W)

        all_attention_weights = []

        for step in range(max_length):
            # Get current tokens for all beams
            current_tokens = torch.tensor([seq[-1] for seq in beam_sequences], device=device)  # (beam_width)

            # Apply embedding
            embedded = self.decoder.embedding(current_tokens)  # (beam_width, embedding_dim)

            # Compute attention for each beam
            attention_input = self.decoder.attention_w(projected_features.expand(beam_width, -1, -1))  # (beam_width, H*W, hidden_size)
            if self.use_coverage:
                coverage_input = self.decoder.coverage_proj(coverage.unsqueeze(-1))  # (beam_width, H*W, hidden_size)  # type: ignore
                attention_input = attention_input + coverage_input

            h_expanded = h_t.unsqueeze(1).expand(-1, H * W, -1)  # (beam_width, H*W, hidden_size)
            attention_input = torch.tanh(attention_input + h_expanded)

            e_t = self.decoder.attention_v(attention_input).squeeze(-1)  # (beam_width, H*W)
            alpha_t = F.softmax(e_t, dim=1)  # (beam_width, H*W)

            all_attention_weights.append(alpha_t.detach())

            if self.use_coverage:
                coverage = coverage + alpha_t  # type: ignore

            context = torch.bmm(alpha_t.unsqueeze(1), projected_features.expand(beam_width, -1, -1)).squeeze(1)  # (beam_width, 2*hidden_size)
            context = context[:, :self.hidden_size]  # (beam_width, hidden_size)

            # Expand count_vector to (beam_width, num_classes)
            count_vector_expanded = count_vector.expand(beam_width, -1)  # (beam_width, num_classes)

            gru_input = torch.cat([embedded, context, count_vector_expanded], dim=1)  # (beam_width, embedding_dim + hidden_size + num_classes)

            h_t = self.decoder.gru(gru_input, h_t)  # (beam_width, hidden_size)

            output = self.decoder.out(torch.cat([h_t, context, count_vector_expanded], dim=1))  # (beam_width, num_classes)
            scores = F.log_softmax(output, dim=1)  # (beam_width, num_classes)

            # Compute new scores for all beam-token combinations
            new_beam_scores = beam_scores.unsqueeze(1) + scores  # (beam_width, num_classes)
            new_beam_scores_flat = new_beam_scores.view(-1)  # (beam_width * num_classes)

            # Select top beam_width scores and indices
            topk_scores, topk_indices = new_beam_scores_flat.topk(beam_width)

            # Determine which beam and token each top score corresponds to
            beam_indices = topk_indices // self.num_classes  # (beam_width)
            token_indices = topk_indices % self.num_classes  # (beam_width)

            # Create new beam sequences and states
            new_beam_sequences = []
            new_h_t = []
            if self.use_coverage:
                new_coverage = []
            for i in range(beam_width):
                prev_beam_idx = beam_indices[i].item()
                token = token_indices[i].item()
                new_seq = torch.cat([beam_sequences[prev_beam_idx], torch.tensor([token], device=device)])  # type: ignore
                new_beam_sequences.append(new_seq)
                new_h_t.append(h_t[prev_beam_idx])
                if self.use_coverage:
                    new_coverage.append(coverage[prev_beam_idx])  # type: ignore

            # Update beams
            beam_sequences = new_beam_sequences
            beam_scores = topk_scores
            h_t = torch.stack(new_h_t)
            if self.use_coverage:
                coverage = torch.stack(new_coverage)  # type: ignore

        # Select the sequence with the highest score
        best_idx = beam_scores.argmax()
        best_sequence = beam_sequences[best_idx].tolist()

        # Remove <start> and stop at <end>
        if best_sequence[0] == start_token:
            best_sequence = best_sequence[1:]
        if end_token in best_sequence:
            end_idx = best_sequence.index(end_token)
            best_sequence = best_sequence[:end_idx]

        return best_sequence, all_attention_weights

def create_can_model(num_classes, hidden_size=256, embedding_dim=256, use_coverage=True, pretrained_backbone=False, backbone_type='densenet'):
    """
    Create CAN model with either DenseNet or ResNet backbone

    Args:
        num_classes: Number of symbol classes
        hidden_size: Decoder hidden size
        embedding_dim: Symbol embedding dimension
        use_coverage: Whether to use coverage attention in the decoder
        pretrained_backbone: Whether to use a pretrained backbone
        backbone_type: Type of backbone to use ('densenet' or 'resnet')

    Returns:
        CAN model
    """
    # Create backbone
    if backbone_type == 'densenet':
        if pretrained_backbone:
            densenet = models.densenet121(pretrained=True)
            backbone = DenseNetFeatureExtractor(densenet, out_channels=684)
        else:
            backbone = DenseNetBackbone()
    elif backbone_type == 'resnet':
        if pretrained_backbone:
            resnet = models.resnet50(pretrained=True)
            backbone = ResNetFeatureExtractor(resnet, out_channels=684)
        else:
            backbone = ResNetBackbone(block_type='bottleneck', layers=[3, 4, 6, 3])
    else:
        raise ValueError(f"Unknown backbone type: {backbone_type}")

    # Create model
    model = CAN(
        num_classes=num_classes,
        backbone=backbone,
        hidden_size=hidden_size,
        embedding_dim=embedding_dim,
        use_coverage=use_coverage
    )

    return model

# # Example usage
# if __name__ == "__main__":
#     # Create CAN model with 101 symbol classes (example)
#     num_classes = 101  # Number of symbol classes + special tokens like <SOS>, <EOS>
#     model = create_can_model(num_classes)
#
#     # Create dummy input data
#     batch_size = 4
#     input_image = torch.randn(batch_size, 1, 128, 384)  # B x C x H x W
#     target = torch.randint(0, num_classes, (batch_size, 50))  # B x max_len
#
#     # Forward pass
#     outputs, count_vectors = model(input_image, target)
#
#     # Print output shapes
#     print(f"Outputs shape: {outputs.shape}")  # B x max_len x num_classes
#     print(f"Count vectors shape: {count_vectors.shape}")  # B x num_classes
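#
#     # Illustrative continuation (not part of the original file): the training loss combines
#     # the decoder and counting terms, and inference uses beam search. The <SOS>=1 and <EOS>=2
#     # token indices below are assumptions, not fixed by this file.
#     count_targets = torch.zeros(batch_size, num_classes)
#     for b in range(batch_size):
#         for idx in target[b]:
#             count_targets[b, idx] += 1
#     total_loss, L_cls, L_counting = model.calculate_loss(outputs, target, count_vectors, count_targets)
#
#     # Inference (batch size 1 only)
#     model.eval()
#     with torch.no_grad():
#         tokens, attn = model.recognize(input_image[:1], start_token=1, end_token=2)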