File size: 1,277 Bytes
b959f6e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 |
import torchvision
import torch
from functools import partial
from torchvision.models.detection import RetinaNet_ResNet50_FPN_V2_Weights
from torchvision.models.detection.retinanet import RetinaNetClassificationHead
from config import NUM_CLASSES
def create_model(num_classes: int = 91) -> torch.nn.Module:
    """
    Build a RetinaNet-ResNet50-FPN v2 detector initialised from COCO weights.

    The classification head that ships with the COCO checkpoint is replaced
    by a freshly constructed ``RetinaNetClassificationHead`` sized for
    ``num_classes``. Only that head is swapped; every other module of the
    loaded model is returned as loaded.

    Args:
        num_classes: Number of classes the new classification head predicts.
            Defaults to 91.

    Returns:
        The RetinaNet model with its classification head replaced.
    """
    model = torchvision.models.detection.retinanet_resnet50_fpn_v2(
        weights=RetinaNet_ResNet50_FPN_V2_Weights.COCO_V1
    )

    # Take the head dimensions from the loaded model instead of hard-coding
    # them, so the new head always matches the backbone feature maps and
    # the anchor layout of the checkpoint.
    in_channels = model.backbone.out_channels
    num_anchors = model.head.classification_head.num_anchors

    # Replace the classification head.
    model.head.classification_head = RetinaNetClassificationHead(
        in_channels=in_channels,
        num_anchors=num_anchors,
        num_classes=num_classes,
        # GroupNorm with 32 groups; presumably mirrors the normalisation of
        # the original v2 head -- confirm against the torchvision builder.
        norm_layer=partial(torch.nn.GroupNorm, 32),
    )
    return model
if __name__ == "__main__":
    # Script entry point: build the detector, dump its architecture, and
    # report how many parameters it has in total and how many are trainable.
    detector = create_model(num_classes=NUM_CLASSES)
    print(detector)

    # Collect (element count, requires_grad) once per parameter tensor.
    param_info = [(p.numel(), p.requires_grad) for p in detector.parameters()]

    # Total parameters:
    n_total = sum(count for count, _ in param_info)
    # Trainable parameters:
    n_trainable = sum(count for count, trainable in param_info if trainable)

    print(f"{n_total:,} total parameters.")
    print(f"{n_trainable:,} training parameters.")
|