File size: 1,277 Bytes
b959f6e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import torchvision
import torch

from functools import partial
from torchvision.models.detection import RetinaNet_ResNet50_FPN_V2_Weights
from torchvision.models.detection.retinanet import RetinaNetClassificationHead
from config import NUM_CLASSES


def create_model(num_classes=91):
    """
    Create a RetinaNet-ResNet50-FPN v2 detector with a new classification head.

    The model is built with the ``COCO_V1`` weights, then its classification
    head is swapped for a freshly constructed ``RetinaNetClassificationHead``
    sized for ``num_classes``. The replacement head is newly initialized, so it
    does not carry the pre-trained classification weights; the backbone and
    the regression head are left as loaded.

    Args:
        num_classes: Number of output classes for the new classification
            head. Must be at least 1.
            NOTE(review): torchvision detection models presumably count the
            background as one of these classes -- confirm against the dataset.

    Returns:
        The RetinaNet model with the replaced classification head.

    Raises:
        ValueError: If ``num_classes`` is less than 1.
    """
    # Fail fast: reject an unusable class count before the (potentially slow)
    # pre-trained weight loading below.
    if num_classes < 1:
        raise ValueError(
            f"num_classes must be a positive integer, got {num_classes!r}"
        )

    model = torchvision.models.detection.retinanet_resnet50_fpn_v2(
        weights=RetinaNet_ResNet50_FPN_V2_Weights.COCO_V1
    )

    # Read the head's geometry from the loaded model rather than hard-coding
    # it, so the new head always matches the backbone it is attached to.
    num_anchors = model.head.classification_head.num_anchors
    in_channels = model.backbone.out_channels

    # Replace the classification head.
    # GroupNorm with 32 groups is kept from the original code; presumably it
    # mirrors the normalization of the stock v2 head -- TODO confirm.
    model.head.classification_head = RetinaNetClassificationHead(
        in_channels=in_channels,
        num_anchors=num_anchors,
        num_classes=num_classes,
        norm_layer=partial(torch.nn.GroupNorm, 32),
    )
    return model


if __name__ == "__main__":
    # Manual smoke check: build the model for the configured class count,
    # dump its architecture, and report its parameter counts.
    model = create_model(num_classes=NUM_CLASSES)
    print(model)

    # Tally every parameter, and separately those that require gradients,
    # in a single pass over the model's parameters.
    total_params = 0
    total_trainable_params = 0
    for param in model.parameters():
        element_count = param.numel()
        total_params += element_count
        if param.requires_grad:
            total_trainable_params += element_count

    print(f"{total_params:,} total parameters.")
    print(f"{total_trainable_params:,} training parameters.")