# Source: uploaded by ankanpy in commit b959f6e ("Upload 9 files").
import torchvision
import torch
from functools import partial
from torchvision.models.detection import RetinaNet_ResNet50_FPN_V2_Weights
from torchvision.models.detection.retinanet import RetinaNetClassificationHead
from config import NUM_CLASSES
def create_model(num_classes=91, *, weights=RetinaNet_ResNet50_FPN_V2_Weights.COCO_V1):
    """Build a RetinaNet-ResNet50-FPN v2 detector with a fresh classification head.

    The full detector is first built with ``weights``. By default these are the
    COCO-pretrained weights, so the backbone, FPN and box-regression head start
    pretrained. Its classification head is then replaced with a newly
    initialised ``RetinaNetClassificationHead`` sized for ``num_classes``.

    Args:
        num_classes: Number of output classes for the new classification head.
            Defaults to 91, the size of torchvision's COCO label space.
        weights: Weights forwarded to ``retinanet_resnet50_fpn_v2``. Defaults to
            the COCO-pretrained checkpoint (the original behaviour). Pass
            ``None`` to build a randomly initialised model without downloading.

    Returns:
        The torchvision RetinaNet model with the replaced classification head.

    Raises:
        ValueError: If ``num_classes`` is less than 1.
    """
    if num_classes < 1:
        raise ValueError(f"num_classes must be >= 1, got {num_classes}")

    model = torchvision.models.detection.retinanet_resnet50_fpn_v2(weights=weights)

    # Size the new head from the existing model rather than hard-coding 256:
    # the anchor count comes from the current head and the input width from
    # the backbone's real output channels, so the head always matches the
    # feature maps it will receive.
    num_anchors = model.head.classification_head.num_anchors
    in_channels = model.backbone.out_channels

    # The v2 architecture normalises its head convolutions with GroupNorm(32);
    # the replacement uses the same layer so it matches the original design.
    model.head.classification_head = RetinaNetClassificationHead(
        in_channels=in_channels,
        num_anchors=num_anchors,
        num_classes=num_classes,
        norm_layer=partial(torch.nn.GroupNorm, 32),
    )
    return model
if __name__ == "__main__":
    # Smoke check: build the project-configured detector, dump its
    # architecture, then report overall and gradient-enabled parameter counts.
    model = create_model(num_classes=NUM_CLASSES)
    print(model)

    params = list(model.parameters())
    n_total = sum(p.numel() for p in params)
    n_trainable = sum(p.numel() for p in params if p.requires_grad)

    print(f"{n_total:,} total parameters.")
    print(f"{n_trainable:,} training parameters.")