|
from types import NoneType |
|
from typing import List, Tuple, Optional, Any |
|
|
|
import pytorch_lightning as pl |
|
import torch |
|
from torch import nn |
|
from torchvision.models import ResNet50_Weights, WeightsEnum |
|
from torchvision.models.detection.anchor_utils import AnchorGenerator |
|
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone |
|
from torchvision.models.detection.faster_rcnn import TwoMLPHead, FastRCNNPredictor |
|
from torchvision.models.detection.generalized_rcnn import GeneralizedRCNN |
|
from torchvision.models.detection.rpn import RPNHead, RegionProposalNetwork |
|
from torchvision.models.detection.transform import GeneralizedRCNNTransform |
|
from torchvision.ops import MultiScaleRoIAlign |
|
|
|
from .ellipse_roi_head import EllipseRoIHeads, EllipseRegressor |
|
from ..utils.types import CollatedBatchType |
|
|
|
|
|
class EllipseRCNN(GeneralizedRCNN):
    """Faster R-CNN-style detector with an additional ellipse-regression branch.

    Builds a ResNet-FPN backbone, a standard torchvision RPN, and an
    ``EllipseRoIHeads`` that extends the usual box branch with an ellipse
    head and predictor. The ``rpn_*`` and ``box_*`` defaults mirror
    torchvision's ``FasterRCNN``; the ``ellipse_*`` arguments configure the
    extra branch.

    Args:
        num_classes: Number of output classes (including background).
            Must be ``None`` when ``box_predictor`` is supplied explicitly.
        backbone_name: torchvision ResNet variant for the FPN backbone.
        weights: Backbone weights. The default is only valid for
            ``backbone_name="resnet50"``; any other backbone requires an
            explicit weights value.
        min_size / max_size: Image resize bounds used by the transform.
        image_mean / image_std: Per-channel normalization; defaults to the
            ImageNet statistics when ``None``.
        rpn_*: Region-proposal-network modules and hyperparameters.
        box_*: Box branch modules and hyperparameters.
        ellipse_roi_pool / ellipse_head / ellipse_predictor: Optional
            custom modules for the ellipse branch; defaults are built to
            match the box branch layout.
        ellipse_loss_scale: Multiplier applied to the ellipse loss.
        ellipse_loss_normalize: Whether the ellipse KLD loss is normalized.

    Raises:
        ValueError: On inconsistent ``num_classes``/``box_predictor``
            combinations, a non-default backbone with default weights, a
            backbone without ``out_channels``, or a non-int RoI resolution.
        TypeError: If ``rpn_anchor_generator`` / ``box_roi_pool`` have the
            wrong type.
    """

    def __init__(
        self,
        num_classes: int = 2,
        # Backbone
        backbone_name: str = "resnet50",
        weights: WeightsEnum | str = ResNet50_Weights.IMAGENET1K_V1,
        # Transform (resize + normalize)
        min_size: int = 256,
        max_size: int = 512,
        image_mean: Optional[List[float]] = None,
        image_std: Optional[List[float]] = None,
        # RPN
        rpn_anchor_generator: Optional[nn.Module] = None,
        rpn_head: Optional[nn.Module] = None,
        rpn_pre_nms_top_n_train: int = 2000,
        rpn_pre_nms_top_n_test: int = 1000,
        rpn_post_nms_top_n_train: int = 2000,
        rpn_post_nms_top_n_test: int = 1000,
        rpn_nms_thresh: float = 0.7,
        rpn_fg_iou_thresh: float = 0.7,
        rpn_bg_iou_thresh: float = 0.3,
        rpn_batch_size_per_image: int = 256,
        rpn_positive_fraction: float = 0.5,
        rpn_score_thresh: float = 0.0,
        # Box branch
        box_roi_pool: Optional[nn.Module] = None,
        box_head: Optional[nn.Module] = None,
        box_predictor: Optional[nn.Module] = None,
        box_score_thresh: float = 0.05,
        box_nms_thresh: float = 0.5,
        box_detections_per_img: int = 100,
        box_fg_iou_thresh: float = 0.5,
        box_bg_iou_thresh: float = 0.5,
        box_batch_size_per_image: int = 512,
        box_positive_fraction: float = 0.25,
        bbox_reg_weights: Optional[Tuple[float, float, float, float]] = None,
        # Ellipse branch
        ellipse_roi_pool: Optional[nn.Module] = None,
        ellipse_head: Optional[nn.Module] = None,
        ellipse_predictor: Optional[nn.Module] = None,
        ellipse_loss_scale: float = 1.0,
        ellipse_loss_normalize: bool = False,
    ):
        # The default (ResNet50) weights only make sense for the default
        # backbone; any other backbone must come with its own weights.
        if backbone_name != "resnet50" and weights == ResNet50_Weights.IMAGENET1K_V1:
            raise ValueError(
                "If backbone_name is not resnet50, weights must be specified"
            )

        backbone = resnet_fpn_backbone(
            backbone_name=backbone_name, weights=weights, trainable_layers=5
        )

        if not hasattr(backbone, "out_channels"):
            raise ValueError(
                "backbone should contain an attribute out_channels "
                "specifying the number of output channels (assumed to be the "
                "same for all the levels)"
            )

        if not isinstance(rpn_anchor_generator, (AnchorGenerator, NoneType)):
            raise TypeError(
                "rpn_anchor_generator must be an instance of AnchorGenerator or None"
            )

        if not isinstance(box_roi_pool, (MultiScaleRoIAlign, NoneType)):
            raise TypeError(
                "box_roi_pool must be an instance of MultiScaleRoIAlign or None"
            )

        # num_classes and box_predictor are mutually exclusive: a custom
        # predictor already fixes the number of classes.
        if num_classes is not None:
            if box_predictor is not None:
                raise ValueError(
                    "num_classes should be None when box_predictor is specified"
                )
        else:
            if box_predictor is None:
                raise ValueError(
                    "num_classes should not be None when box_predictor "
                    "is not specified"
                )

        out_channels = backbone.out_channels

        if rpn_anchor_generator is None:
            # One anchor size per FPN level, three aspect ratios each.
            anchor_sizes = ((32,), (64,), (128,), (256,), (512,))
            aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
            rpn_anchor_generator = AnchorGenerator(anchor_sizes, aspect_ratios)
        if rpn_head is None:
            rpn_head = RPNHead(
                out_channels, rpn_anchor_generator.num_anchors_per_location()[0]
            )

        rpn_pre_nms_top_n = dict(
            training=rpn_pre_nms_top_n_train, testing=rpn_pre_nms_top_n_test
        )
        rpn_post_nms_top_n = dict(
            training=rpn_post_nms_top_n_train, testing=rpn_post_nms_top_n_test
        )

        rpn = RegionProposalNetwork(
            rpn_anchor_generator,
            rpn_head,
            rpn_fg_iou_thresh,
            rpn_bg_iou_thresh,
            rpn_batch_size_per_image,
            rpn_positive_fraction,
            rpn_pre_nms_top_n,
            rpn_post_nms_top_n,
            rpn_nms_thresh,
            score_thresh=rpn_score_thresh,
        )

        default_representation_size = 1024

        def _mlp_head(roi_pool: nn.Module) -> TwoMLPHead:
            # Both branches use the same two-layer MLP head layout; the
            # flattened input size depends on the pooler's spatial resolution.
            resolution = roi_pool.output_size[0]
            if not isinstance(resolution, int):
                raise ValueError(
                    "resolution should be an int but is {}".format(resolution)
                )
            return TwoMLPHead(
                out_channels * resolution**2, default_representation_size
            )

        if box_roi_pool is None:
            box_roi_pool = MultiScaleRoIAlign(
                featmap_names=["0", "1", "2", "3"], output_size=7, sampling_ratio=2
            )

        if box_head is None:
            box_head = _mlp_head(box_roi_pool)

        if box_predictor is None:
            box_predictor = FastRCNNPredictor(default_representation_size, num_classes)

        if ellipse_roi_pool is None:
            ellipse_roi_pool = MultiScaleRoIAlign(
                featmap_names=["0", "1", "2", "3"], output_size=7, sampling_ratio=2
            )

        if ellipse_head is None:
            # Fix: size the ellipse head from the *ellipse* pooler, not the
            # box pooler — they may differ when supplied by the caller.
            ellipse_head = _mlp_head(ellipse_roi_pool)

        if ellipse_predictor is None:
            ellipse_predictor = EllipseRegressor(
                default_representation_size, num_classes
            )

        roi_heads = EllipseRoIHeads(
            # Box branch
            box_roi_pool,
            box_head,
            box_predictor,
            box_fg_iou_thresh,
            box_bg_iou_thresh,
            box_batch_size_per_image,
            box_positive_fraction,
            bbox_reg_weights,
            box_score_thresh,
            box_nms_thresh,
            box_detections_per_img,
            # Ellipse branch
            ellipse_roi_pool=ellipse_roi_pool,
            ellipse_head=ellipse_head,
            ellipse_predictor=ellipse_predictor,
            loss_scale=ellipse_loss_scale,
            kld_normalize=ellipse_loss_normalize,
        )

        if image_mean is None:
            # ImageNet statistics, matching the pretrained backbone.
            image_mean = [0.485, 0.456, 0.406]
        if image_std is None:
            image_std = [0.229, 0.224, 0.225]
        transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std)

        super().__init__(backbone, rpn, roi_heads, transform)
|
|
|
|
|
class EllipseRCNNLightning(pl.LightningModule):
    """Thin PyTorch-Lightning wrapper around :class:`EllipseRCNN`.

    Owns the optimizer/scheduler configuration and the per-step loss
    logging; the wrapped detection model produces the loss dictionary.
    """

    def __init__(
        self,
        model: EllipseRCNN,
        lr: float = 1e-4,
        weight_decay: float = 1e-4,
    ):
        super().__init__()
        self.model = model
        # Persist lr / weight_decay in hparams; the model itself is excluded
        # from checkpointed hyperparameters.
        self.save_hyperparameters(ignore=["model"])

    def configure_optimizers(self) -> Any:
        """Return an AdamW optimizer with a plateau-driven LR scheduler."""
        opt = torch.optim.AdamW(
            self.model.parameters(),
            lr=self.hparams.lr,
            weight_decay=self.hparams.weight_decay,
            amsgrad=True,
        )
        sched = torch.optim.lr_scheduler.ReduceLROnPlateau(
            opt, mode="min", factor=0.5, patience=2, min_lr=1e-6
        )
        lr_config = {"scheduler": sched, "monitor": "val/loss_total"}
        return {"optimizer": opt, "lr_scheduler": lr_config}

    def training_step(
        self, batch: CollatedBatchType, batch_idx: int = 0
    ) -> torch.Tensor:
        """Run one training step and log each loss term plus their sum."""
        images, targets = batch
        losses = self.model(images, targets)
        named = {f"train/{name}": value for name, value in losses.items()}
        self.log_dict(named, prog_bar=True, logger=True, on_step=True)

        total = sum(losses.values())
        self.log("train/loss_total", total, prog_bar=True, logger=True, on_step=True)

        return total

    def validation_step(
        self, batch: CollatedBatchType, batch_idx: int = 0
    ) -> torch.Tensor:
        """Compute and log validation losses (epoch-aggregated)."""
        # NOTE(review): train mode is forced so the detection model returns
        # its loss dict instead of predictions. This also keeps BatchNorm /
        # Dropout in training behavior during validation — presumably
        # intentional; confirm.
        self.train(True)
        images, targets = batch

        losses = self.model(images, targets)

        named = {f"val/{name}": value for name, value in losses.items()}
        self.log_dict(named, logger=True, on_step=False, on_epoch=True)

        total = sum(losses.values())
        self.log(
            "val/loss_total",
            total,
            prog_bar=True,
            logger=True,
            on_step=False,
            on_epoch=True,
        )

        # Mirror the total into TensorBoard's hp_metric and record the
        # current learning rate from the scheduler.
        self.log(
            "hp_metric",
            total,
        )
        self.log(
            "lr",
            self.lr_schedulers().get_last_lr()[0],
        )

        return total
|
|