# Author: Filipstrozik
# Add initial implementation of EllipseRCNN model and dataset utilities
# Commit: afc2161
from types import NoneType
from typing import List, Tuple, Optional, Any
import pytorch_lightning as pl
import torch
from torch import nn
from torchvision.models import ResNet50_Weights, WeightsEnum
from torchvision.models.detection.anchor_utils import AnchorGenerator
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
from torchvision.models.detection.faster_rcnn import TwoMLPHead, FastRCNNPredictor # noqa: F
from torchvision.models.detection.generalized_rcnn import GeneralizedRCNN
from torchvision.models.detection.rpn import RPNHead, RegionProposalNetwork
from torchvision.models.detection.transform import GeneralizedRCNNTransform
from torchvision.ops import MultiScaleRoIAlign
from .ellipse_roi_head import EllipseRoIHeads, EllipseRegressor
from ..utils.types import CollatedBatchType
class EllipseRCNN(GeneralizedRCNN):
    """Generalized R-CNN that regresses ellipse parameters alongside boxes.

    Assembles a ResNet-FPN backbone, a standard Region Proposal Network and a
    Fast R-CNN box branch, plus a parallel ellipse-regression branch
    (``EllipseRoIHeads`` with an ``EllipseRegressor`` predictor) on top of the
    pooled RoI features.

    Args:
        num_classes: Number of output classes (including background). Must be
            ``None`` when a custom ``box_predictor`` is supplied, and non-None
            otherwise (mirrors the torchvision Faster R-CNN contract).
        backbone_name: torchvision ResNet variant for ``resnet_fpn_backbone``.
        weights: Backbone weights. The default is ResNet-50 specific, so any
            other ``backbone_name`` requires explicit weights.
        min_size / max_size: Image resize bounds used by the RCNN transform.
        image_mean / image_std: Per-channel normalization; defaults to the
            ImageNet statistics when ``None``.
        rpn_*: Region Proposal Network construction and sampling parameters.
        box_*: Fast R-CNN box branch parameters.
        ellipse_roi_pool / ellipse_head / ellipse_predictor: Optional custom
            modules for the ellipse branch; sensible defaults are built when
            ``None``.
        ellipse_loss_scale: Multiplier applied to the ellipse loss.
        ellipse_loss_normalize: Whether the ellipse KLD loss is normalized.

    Raises:
        ValueError: On inconsistent ``num_classes``/``box_predictor``
            combinations, a non-ResNet-50 backbone with default weights, a
            backbone lacking ``out_channels``, or a non-int pooler resolution.
        TypeError: If ``rpn_anchor_generator`` or ``box_roi_pool`` has an
            unexpected type.
    """

    @staticmethod
    def _build_two_mlp_head(
        roi_pool: MultiScaleRoIAlign, out_channels: int, representation_size: int
    ) -> TwoMLPHead:
        """Build a ``TwoMLPHead`` sized for *roi_pool*'s output resolution."""
        resolution = roi_pool.output_size[0]
        if not isinstance(resolution, int):
            raise ValueError(
                "resolution should be an int but is {}".format(resolution)
            )
        return TwoMLPHead(out_channels * resolution**2, representation_size)

    def __init__(
        self,
        num_classes: int = 2,
        # transform parameters
        backbone_name: str = "resnet50",
        weights: WeightsEnum | str = ResNet50_Weights.IMAGENET1K_V1,
        min_size: int = 256,
        max_size: int = 512,
        image_mean: Optional[List[float]] = None,
        image_std: Optional[List[float]] = None,
        # Region Proposal Network parameters
        rpn_anchor_generator: Optional[nn.Module] = None,
        rpn_head: Optional[nn.Module] = None,
        rpn_pre_nms_top_n_train: int = 2000,
        rpn_pre_nms_top_n_test: int = 1000,
        rpn_post_nms_top_n_train: int = 2000,
        rpn_post_nms_top_n_test: int = 1000,
        rpn_nms_thresh: float = 0.7,
        rpn_fg_iou_thresh: float = 0.7,
        rpn_bg_iou_thresh: float = 0.3,
        rpn_batch_size_per_image: int = 256,
        rpn_positive_fraction: float = 0.5,
        rpn_score_thresh: float = 0.0,
        # Box parameters
        box_roi_pool: Optional[nn.Module] = None,
        box_head: Optional[nn.Module] = None,
        box_predictor: Optional[nn.Module] = None,
        box_score_thresh: float = 0.05,
        box_nms_thresh: float = 0.5,
        box_detections_per_img: int = 100,
        box_fg_iou_thresh: float = 0.5,
        box_bg_iou_thresh: float = 0.5,
        box_batch_size_per_image: int = 512,
        box_positive_fraction: float = 0.25,
        bbox_reg_weights: Optional[Tuple[float, float, float, float]] = None,
        # Ellipse regressor
        ellipse_roi_pool: Optional[nn.Module] = None,
        ellipse_head: Optional[nn.Module] = None,
        ellipse_predictor: Optional[nn.Module] = None,
        ellipse_loss_scale: float = 1.0,
        ellipse_loss_normalize: bool = False,
    ):
        # The default weights are ResNet-50 weights; any other backbone must
        # come with explicitly chosen weights.
        if backbone_name != "resnet50" and weights == ResNet50_Weights.IMAGENET1K_V1:
            raise ValueError(
                "If backbone_name is not resnet50, weights_enum must be specified"
            )
        backbone = resnet_fpn_backbone(
            backbone_name=backbone_name, weights=weights, trainable_layers=5
        )
        if not hasattr(backbone, "out_channels"):
            raise ValueError(
                "backbone should contain an attribute out_channels "
                "specifying the number of output channels (assumed to be the "
                "same for all the levels)"
            )
        if not isinstance(rpn_anchor_generator, (AnchorGenerator, NoneType)):
            raise TypeError(
                "rpn_anchor_generator must be an instance of AnchorGenerator or None"
            )
        if not isinstance(box_roi_pool, (MultiScaleRoIAlign, NoneType)):
            raise TypeError(
                "box_roi_pool must be an instance of MultiScaleRoIAlign or None"
            )
        # num_classes and box_predictor are mutually exclusive: a custom
        # predictor already fixes the number of classes.
        if num_classes is not None:
            if box_predictor is not None:
                raise ValueError(
                    "num_classes should be None when box_predictor is specified"
                )
        else:
            if box_predictor is None:
                raise ValueError(
                    "num_classes should not be None when box_predictor "
                    "is not specified"
                )

        out_channels = backbone.out_channels

        # --- Region Proposal Network -------------------------------------
        if rpn_anchor_generator is None:
            # One anchor size per FPN level, three aspect ratios each.
            anchor_sizes = ((32,), (64,), (128,), (256,), (512,))
            aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
            rpn_anchor_generator = AnchorGenerator(anchor_sizes, aspect_ratios)
        if rpn_head is None:
            rpn_head = RPNHead(
                out_channels, rpn_anchor_generator.num_anchors_per_location()[0]
            )
        rpn_pre_nms_top_n = dict(
            training=rpn_pre_nms_top_n_train, testing=rpn_pre_nms_top_n_test
        )
        rpn_post_nms_top_n = dict(
            training=rpn_post_nms_top_n_train, testing=rpn_post_nms_top_n_test
        )
        rpn = RegionProposalNetwork(
            rpn_anchor_generator,
            rpn_head,
            rpn_fg_iou_thresh,
            rpn_bg_iou_thresh,
            rpn_batch_size_per_image,
            rpn_positive_fraction,
            rpn_pre_nms_top_n,
            rpn_post_nms_top_n,
            rpn_nms_thresh,
            score_thresh=rpn_score_thresh,
        )

        # --- Box branch ---------------------------------------------------
        default_representation_size = 1024
        if box_roi_pool is None:
            box_roi_pool = MultiScaleRoIAlign(
                featmap_names=["0", "1", "2", "3"], output_size=7, sampling_ratio=2
            )
        if box_head is None:
            box_head = self._build_two_mlp_head(
                box_roi_pool, out_channels, default_representation_size
            )
        if box_predictor is None:
            box_predictor = FastRCNNPredictor(default_representation_size, num_classes)

        # --- Ellipse branch -----------------------------------------------
        if ellipse_roi_pool is None:
            ellipse_roi_pool = MultiScaleRoIAlign(
                featmap_names=["0", "1", "2", "3"], output_size=7, sampling_ratio=2
            )
        if ellipse_head is None:
            # BUGFIX: size the ellipse head from the *ellipse* RoI pooler
            # (the original used box_roi_pool; the defaults coincide, but a
            # custom ellipse_roi_pool with a different output size would have
            # produced a mis-sized head).
            ellipse_head = self._build_two_mlp_head(
                ellipse_roi_pool, out_channels, default_representation_size
            )
        if ellipse_predictor is None:
            ellipse_predictor = EllipseRegressor(
                default_representation_size, num_classes
            )

        roi_heads = EllipseRoIHeads(
            # Box
            box_roi_pool,
            box_head,
            box_predictor,
            box_fg_iou_thresh,
            box_bg_iou_thresh,
            box_batch_size_per_image,
            box_positive_fraction,
            bbox_reg_weights,
            box_score_thresh,
            box_nms_thresh,
            box_detections_per_img,
            # Ellipse
            ellipse_roi_pool=ellipse_roi_pool,
            ellipse_head=ellipse_head,
            ellipse_predictor=ellipse_predictor,
            loss_scale=ellipse_loss_scale,
            kld_normalize=ellipse_loss_normalize,
        )

        # ImageNet normalization statistics by default.
        if image_mean is None:
            image_mean = [0.485, 0.456, 0.406]
        if image_std is None:
            image_std = [0.229, 0.224, 0.225]
        transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std)

        super().__init__(backbone, rpn, roi_heads, transform)
class EllipseRCNNLightning(pl.LightningModule):
    """LightningModule wrapper that trains an :class:`EllipseRCNN`.

    Handles optimizer/scheduler setup (AdamW + ReduceLROnPlateau monitored on
    ``val/loss_total``) and the per-step loss computation/logging for both
    training and validation.
    """

    def __init__(
        self,
        model: EllipseRCNN,
        lr: float = 1e-4,
        weight_decay: float = 1e-4,
    ):
        """Wrap *model* with the given AdamW learning rate and weight decay.

        The model itself is excluded from ``save_hyperparameters`` so it is
        not serialized into the hparams checkpoint; ``lr`` and
        ``weight_decay`` become available via ``self.hparams``.
        """
        super().__init__()
        self.model = model
        self.save_hyperparameters(ignore=["model"])

    def configure_optimizers(self) -> Any:
        """Return AdamW plus a plateau scheduler keyed to ``val/loss_total``."""
        optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=self.hparams.lr,
            weight_decay=self.hparams.weight_decay,
            amsgrad=True,
        )
        # Halve the LR after 2 epochs without val-loss improvement, down to 1e-6.
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode="min", factor=0.5, patience=2, min_lr=1e-6
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "monitor": "val/loss_total"},
        }

    def training_step(
        self, batch: CollatedBatchType, batch_idx: int = 0
    ) -> torch.Tensor:
        """Compute and log the summed detection/ellipse losses for one batch."""
        images, targets = batch
        # In training mode the R-CNN forward returns a dict of loss tensors.
        loss_dict = self.model(images, targets)
        self.log_dict(
            {f"train/{k}": v for k, v in loss_dict.items()},
            prog_bar=True,
            logger=True,
            on_step=True,
        )
        loss = sum(loss_dict.values())
        self.log("train/loss_total", loss, prog_bar=True, logger=True, on_step=True)
        return loss

    def validation_step(
        self, batch: CollatedBatchType, batch_idx: int = 0
    ) -> torch.Tensor:
        """Compute and log validation losses; returns the summed loss."""
        # NOTE(review): forcing train mode here makes the R-CNN forward return
        # loss dicts instead of detections (Lightning puts the module in eval
        # mode before validation). Presumably intentional, but it also leaves
        # e.g. BatchNorm layers in training mode during validation — confirm
        # this side effect is acceptable.
        self.train(True)
        images, targets = batch
        loss_dict = self.model(images, targets)
        self.log_dict(
            {f"val/{k}": v for k, v in loss_dict.items()},
            logger=True,
            on_step=False,
            on_epoch=True,
        )
        val_loss = sum(loss_dict.values())
        self.log(
            "val/loss_total",
            val_loss,
            prog_bar=True,
            logger=True,
            on_step=False,
            on_epoch=True,
        )
        # hp_metric feeds TensorBoard's hyperparameter comparison view.
        self.log(
            "hp_metric",
            val_loss,
        )
        # NOTE(review): ReduceLROnPlateau only gained get_last_lr() in recent
        # torch releases — verify against the pinned torch version.
        self.log(
            "lr",
            self.lr_schedulers().get_last_lr()[0],
        )
        return val_loss