# mask_rcnn_swin_meta.py - Mask R-CNN with Swin Transformer for data point segmentation
#
# ADAPTED FROM CASCADE R-CNN CONFIG:
# - Uses same Swin Transformer Base backbone with optimizations
# - Maintains data-point class weighting (10x) and IoU strategies
# - Adds mask head for instance segmentation of data points
# - Uses enhanced annotation files with segmentation masks
# - Keeps custom hooks and progressive loss strategies
#
# MASK-SPECIFIC OPTIMIZATIONS:
# - RoI size 14x14 for mask extraction (close to the ~16x16 px data points)
# - FCN mask head with 4 convolution layers
# - Mask loss weight balanced with bbox and classification losses
# - Enhanced test-time augmentation for better mask quality
#
# DATA POINT FOCUS:
# - Primary target: data-point class (ID 11) with 10x weight
# - Generates both bounding boxes AND instance masks
# - Optimized for 16x16 pixel data points in scientific charts
# Removed _base_ inheritance to avoid path issues - all configs are inlined below
# Custom imports - same as Cascade R-CNN setup
custom_imports = dict(
imports=[
'legend_match_swin.custom_models.register',
'legend_match_swin.custom_models.custom_hooks',
'legend_match_swin.custom_models.progressive_loss_hook',
'legend_match_swin.custom_models.flexible_load_annotations',
],
allow_failed_imports=False
)
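# The imported modules are expected to register every custom class referenced by
# string below (SquareFCNMaskHead, FlexibleLoadAnnotations, ClampBBoxes,
# CompatibleCheckpointHook and the custom hooks) with the MMEngine/MMDetection
# registries. Illustrative registration sketch only, not the actual
# implementation shipped in legend_match_swin.custom_models:
#   from mmdet.registry import MODELS
#   from mmdet.models.roi_heads.mask_heads import FCNMaskHead
#
#   @MODELS.register_module()
#   class SquareFCNMaskHead(FCNMaskHead):
#       ...  # variant that keeps square 14x14 mask targets (see mask_head below)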
# Add to Python path
import sys
sys.path.insert(0, '.')
# Mask R-CNN model with Swin Transformer backbone
model = dict(
type='MaskRCNN',
data_preprocessor=dict(
type='DetDataPreprocessor',
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
bgr_to_rgb=True,
pad_size_divisor=32,
pad_mask=True, # Important for mask training
mask_pad_value=0,
),
# Same Swin Transformer Base backbone as Cascade R-CNN
backbone=dict(
type='SwinTransformer',
embed_dims=128, # Swin Base embedding dimensions
depths=[2, 2, 18, 2], # Swin Base depths
num_heads=[4, 8, 16, 32], # Swin Base attention heads
window_size=7,
mlp_ratio=4,
qkv_bias=True,
qk_scale=None,
drop_rate=0.0,
attn_drop_rate=0.0,
drop_path_rate=0.3, # Same as Cascade config
patch_norm=True,
out_indices=(0, 1, 2, 3),
with_cp=False,
convert_weights=True,
init_cfg=dict(
type='Pretrained',
checkpoint='https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_base_patch4_window7_224_22k_20220317-4f79f7c0.pth'
)
),
# Same FPN as Cascade R-CNN
neck=dict(
type='FPN',
in_channels=[128, 256, 512, 1024], # Swin Base: embed_dims * 2^(stage)
out_channels=256,
num_outs=5, # Standard for Mask R-CNN (was 6 in Cascade)
start_level=0,
add_extra_convs='on_input'
),
# Same RPN configuration as Cascade R-CNN
rpn_head=dict(
type='RPNHead',
in_channels=256,
feat_channels=256,
anchor_generator=dict(
type='AnchorGenerator',
scales=[1, 2, 4, 8], # Same small scales for tiny objects
ratios=[0.5, 1.0, 2.0],
strides=[4, 8, 16, 32, 64]), # Standard FPN strides for Mask R-CNN
bbox_coder=dict(
type='DeltaXYWHBBoxCoder',
target_means=[.0, .0, .0, .0],
target_stds=[1.0, 1.0, 1.0, 1.0]),
loss_cls=dict(
type='CrossEntropyLoss',
use_sigmoid=True,
loss_weight=1.0),
loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)
),
# Mask R-CNN ROI head with bbox + mask branches
roi_head=dict(
type='StandardRoIHead',
# Bbox ROI extractor (same as Cascade R-CNN final stage)
bbox_roi_extractor=dict(
type='SingleRoIExtractor',
roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
out_channels=256,
featmap_strides=[4, 8, 16, 32]
),
# Bbox head with data-point class weighting
bbox_head=dict(
type='Shared2FCBBoxHead',
in_channels=256,
fc_out_channels=1024,
roi_feat_size=7,
num_classes=22, # 22 enhanced categories including boxplot
bbox_coder=dict(
type='DeltaXYWHBBoxCoder',
target_means=[0., 0., 0., 0.],
target_stds=[0.1, 0.1, 0.2, 0.2]
),
reg_class_agnostic=False,
loss_cls=dict(
type='CrossEntropyLoss',
use_sigmoid=False,
loss_weight=1.0,
            class_weight=[
                1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,  # class indices 0-10
                10.0,  # data-point (class index 11) gets 10x weight
                1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,  # class indices 12-21 (incl. boxplot)
                1.0]  # background weight last (MMDetection puts background at index num_classes)
),
loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)
),
# Mask ROI extractor (optimized for 16x16 data points)
mask_roi_extractor=dict(
type='SingleRoIExtractor',
            roi_layer=dict(type='RoIAlign', output_size=(14, 14), sampling_ratio=0, aligned=True),  # exact 14x14 RoI features; aligned=True is the standard (non-legacy) RoIAlign
out_channels=256,
featmap_strides=[4, 8, 16, 32]
),
# Mask head optimized for data points with square mask targets
mask_head=dict(
type='SquareFCNMaskHead',
num_convs=4, # 4 conv layers for good feature extraction
in_channels=256,
roi_feat_size=14, # Explicitly set ROI feature size
conv_out_channels=256,
num_classes=22, # 22 enhanced categories including boxplot
upsample_cfg=dict(type=None), # No upsampling - keep 14x14
loss_mask=dict(
type='CrossEntropyLoss',
use_mask=True,
loss_weight=1.0 # Balanced with bbox loss
)
)
),
# Training configuration adapted from Cascade R-CNN
train_cfg=dict(
rpn=dict(
assigner=dict(
type='MaxIoUAssigner',
pos_iou_thr=0.7,
neg_iou_thr=0.3,
min_pos_iou=0.3,
match_low_quality=True,
ignore_iof_thr=-1),
sampler=dict(
type='RandomSampler',
num=256,
pos_fraction=0.5,
neg_pos_ub=-1,
add_gt_as_proposals=False),
allowed_border=0,
pos_weight=-1,
debug=False),
rpn_proposal=dict(
nms_pre=2000,
max_per_img=1000,
nms=dict(type='nms', iou_threshold=0.7),
min_bbox_size=0),
# RCNN training (using Cascade stage 2 settings - balanced for mask training)
rcnn=dict(
assigner=dict(
type='MaxIoUAssigner',
pos_iou_thr=0.5, # Balanced IoU for bbox + mask training
neg_iou_thr=0.5,
min_pos_iou=0.5,
match_low_quality=True, # Important for small data points
ignore_iof_thr=-1),
sampler=dict(
type='RandomSampler',
num=512,
pos_fraction=0.25,
neg_pos_ub=-1,
add_gt_as_proposals=True),
mask_size=(14, 14), # Force exact 14x14 size for data points
pos_weight=-1,
debug=False)
),
# Test configuration with soft NMS
test_cfg=dict(
rpn=dict(
nms_pre=1000,
max_per_img=1000,
nms=dict(type='nms', iou_threshold=0.7),
min_bbox_size=0),
rcnn=dict(
score_thr=0.005, # Low threshold to catch data points
nms=dict(
type='soft_nms', # Soft NMS for better small object detection
iou_threshold=0.3, # Low for data points
min_score=0.005,
method='gaussian',
sigma=0.5),
max_per_img=100,
mask_thr_binary=0.5 # Binary mask threshold
)
)
)
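# Illustrative consistency checks, evaluated when the config file is parsed.
# Assumption: MMDetection's softmax-based bbox head appends the background class
# after the num_classes foreground classes, so class_weight needs
# num_classes + 1 entries with the background weight last.
assert (model['roi_head']['bbox_head']['num_classes']
        == model['roi_head']['mask_head']['num_classes'] == 22)
assert (len(model['roi_head']['bbox_head']['loss_cls']['class_weight'])
        == model['roi_head']['bbox_head']['num_classes'] + 1)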
# Dataset settings - using standard COCO dataset for mask support
dataset_type = 'CocoDataset'
data_root = ''
# 22 enhanced categories including boxplot
CLASSES = (
'title', 'subtitle', 'x-axis', 'y-axis', 'x-axis-label', 'y-axis-label', # 0-5
'x-tick-label', 'y-tick-label', 'legend', 'legend-title', 'legend-item', # 6-10
'data-point', 'data-line', 'data-bar', 'data-area', 'grid-line', # 11-15 (data-point at index 11)
'axis-title', 'tick-label', 'data-label', 'legend-text', 'plot-area', # 16-20
'boxplot' # 21
)
# Verify data-point class index
assert CLASSES[11] == 'data-point', f"Expected 'data-point' at index 11 in CLASSES tuple, got '{CLASSES[11]}'"
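# Illustrative cross-check (assumes class_weight indices 0..21 follow the CLASSES
# order, with the background weight last): the 10x weight must land on data-point.
assert len(CLASSES) == model['roi_head']['bbox_head']['num_classes']
assert model['roi_head']['bbox_head']['loss_cls']['class_weight'][CLASSES.index('data-point')] == 10.0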
# Training dataloader with mask annotations
train_dataloader = dict(
batch_size=2, # Same as Cascade R-CNN
num_workers=2,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=True),
dataset=dict(
type=dataset_type,
data_root=data_root,
ann_file='legend_match_swin/mask_generation/enhanced_datasets/train_filtered_with_masks_only.json',
data_prefix=dict(img='legend_data/train/images/'),
metainfo=dict(classes=CLASSES),
        filter_cfg=dict(filter_empty_gt=False, min_size=12),  # keep images without GT; only drop images smaller than 12 px
        test_mode=False,  # training split: annotations are loaded and used for supervision
pipeline=[
dict(type='LoadImageFromFile'),
dict(type='FlexibleLoadAnnotations', with_bbox=True, with_mask=True),
dict(type='Resize', scale=(1120, 672), keep_ratio=True),
dict(type='RandomFlip', prob=0.5),
dict(type='ClampBBoxes'),
dict(type='PackDetInputs')
]
)
)
# Validation dataloader with mask annotations
val_dataloader = dict(
batch_size=1,
num_workers=2,
persistent_workers=True,
drop_last=False,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=dict(
type=dataset_type,
data_root=data_root,
ann_file='legend_match_swin/mask_generation/enhanced_datasets/val_enriched_with_masks_only.json',
data_prefix=dict(img='legend_data/train/images/'),
metainfo=dict(classes=CLASSES),
test_mode=True,
pipeline=[
dict(type='LoadImageFromFile'),
dict(type='Resize', scale=(1120, 672), keep_ratio=True),
dict(type='FlexibleLoadAnnotations', with_bbox=True, with_mask=True),
dict(type='ClampBBoxes'),
dict(type='PackDetInputs', meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 'scale_factor'))
]
)
)
test_dataloader = val_dataloader
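# Quick sanity check for the enhanced annotation files (illustrative; run it in a
# separate script rather than at config-parse time, since the relative paths only
# resolve from the training workspace):
#   from pycocotools.coco import COCO
#   coco = COCO('legend_match_swin/mask_generation/enhanced_datasets/'
#               'val_enriched_with_masks_only.json')
#   n_masks = sum('segmentation' in a for a in coco.anns.values())
#   print(f'{n_masks}/{len(coco.anns)} annotations carry segmentation masks')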
# Enhanced evaluators for both bbox and mask metrics
val_evaluator = dict(
type='CocoMetric',
ann_file='legend_match_swin/mask_generation/enhanced_datasets/val_enriched_with_masks_only.json',
metric=['bbox', 'segm'],
format_only=False,
classwise=True,
proposal_nums=(100, 300, 1000)
)
test_evaluator = val_evaluator
# Default hooks (same setup as Cascade R-CNN; checkpointing uses the custom CompatibleCheckpointHook)
default_hooks = dict(
timer=dict(type='IterTimerHook'),
logger=dict(type='LoggerHook', interval=50),
param_scheduler=dict(type='ParamSchedulerHook'),
checkpoint=dict(type='CompatibleCheckpointHook', interval=1, save_best='auto', max_keep_ckpts=3),
sampler_seed=dict(type='DistSamplerSeedHook'),
visualization=dict(type='DetVisualizationHook')
)
# Same custom hooks as Cascade R-CNN (adapted for Mask R-CNN)
custom_hooks = [
dict(type='SkipBadSamplesHook', interval=1),
dict(type='ChartTypeDistributionHook', interval=500),
dict(type='MissingImageReportHook', interval=1000),
dict(type='NanRecoveryHook',
fallback_loss=1.0,
max_consecutive_nans=50,
log_interval=25),
# Note: Progressive loss hook not used in standard Mask R-CNN
# but could be adapted if needed for bbox loss only
]
# Training configuration - reduced to 20 epochs
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=20, val_interval=1)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')
# Same optimizer settings as Cascade R-CNN
optim_wrapper = dict(
type='OptimWrapper',
optimizer=dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001),
clip_grad=dict(max_norm=10.0, norm_type=2)
)
# Same learning rate schedule as Cascade R-CNN
param_scheduler = [
dict(
type='LinearLR',
start_factor=0.1,
by_epoch=False,
begin=0,
end=1000),
dict(
type='CosineAnnealingLR',
begin=0,
end=20,
by_epoch=True,
T_max=20,
eta_min=1e-5,
convert_to_iter_based=True)
]
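# Schedule note (illustrative numbers): LinearLR warms the LR up from
# 0.01 * 0.1 = 0.001 to the base 0.01 over the first 1000 iterations, while
# CosineAnnealingLR (converted to per-iteration updates) decays it toward
# eta_min=1e-5 across the 20 epochs. The overlapping warmup + cosine pair is the
# usual MMEngine pattern and matches the Cascade R-CNN schedule this was copied from.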
# Work directory
work_dir = '/content/drive/MyDrive/Research Summer 2025/Dense Captioning Toolkit/CHART-DeMatch/work_dirs/mask_rcnn_swin_base_20ep_meta'
# Fresh start
resume = False
load_from = None
# Default runtime settings (normally inherited from _base_)
default_scope = 'mmdet'
env_cfg = dict(
cudnn_benchmark=False,
mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
dist_cfg=dict(backend='nccl'),
)
vis_backends = [dict(type='LocalVisBackend')]
visualizer = dict(
type='DetLocalVisualizer', vis_backends=vis_backends, name='visualizer')
log_processor = dict(type='LogProcessor', window_size=50, by_epoch=True)
log_level = 'INFO'
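# Typical launch (illustrative; adjust the config path to where this file lives
# in your checkout):
#   python tools/train.py legend_match_swin/configs/mask_rcnn_swin_meta.py
# or programmatically with MMEngine:
#   from mmengine.config import Config
#   from mmengine.runner import Runner
#   Runner.from_cfg(Config.fromfile('mask_rcnn_swin_meta.py')).train()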