File size: 2,113 Bytes
169e11c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import logging
import torch
from torchvision.models.detection.faster_rcnn import fasterrcnn_resnet50_fpn_v2, FasterRCNN, FastRCNNPredictor
from torchvision.models.detection import FasterRCNN_ResNet50_FPN_V2_Weights
from torchvision.models.resnet import ResNet50_Weights

logger = logging.getLogger(__name__)


def set_seeds(seed: int = 42) -> None:
    """Sets random seeds for torch operations.

    Seeds the global torch generator and the generators of every CUDA
    device, so results are repeatable on single- and multi-GPU machines.

    Args:
        seed (int, optional): Random seed to set. Defaults to 42.
    """
    # Set the seed for general torch operations
    torch.manual_seed(seed)
    # Set the seed for CUDA torch operations on *all* GPUs, not only the
    # current one; this call is silently ignored when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)


class ModelFactory:
    """Factory that builds object-detection models from a string identifier."""

    @staticmethod
    def get_model(
        name,
        weights,
        backbone,
        backbone_weights,
        trainable_backbone_layers,
        num_classes: int = 2,
    ) -> "FasterRCNN":
        """Builds a detection model and swaps in a fresh box predictor head.

        Args:
            name (str): Model identifier. Only "fasterrcnn_resnet50_fpn" is
                supported.
            weights (str, optional): Member name of
                ``FasterRCNN_ResNet50_FPN_V2_Weights`` to load, or None to
                request no detector weights.
            backbone: Value forwarded to the torchvision builder as
                ``backbone_name``.
            backbone_weights (str, optional): Member name of
                ``ResNet50_Weights`` to load, or None to request no backbone
                weights.
            trainable_backbone_layers: Forwarded unchanged to the torchvision
                builder.
            num_classes (int, optional): Number of outputs of the replacement
                box predictor, background included. Defaults to 2
                (1 class (person) + background).

        Returns:
            FasterRCNN: The model with ``roi_heads.box_predictor`` replaced by
            a new ``FastRCNNPredictor`` sized for ``num_classes``.

        Raises:
            ValueError: If ``name`` is not a supported model identifier.
            KeyError: If ``weights`` or ``backbone_weights`` is a string that
                is not a member name of the corresponding weights enum.
        """
        logger.debug("get_model -> model:%s", name)

        # Guard clause: reject unknown identifiers before doing any work.
        if name != "fasterrcnn_resnet50_fpn":
            logger.error(
                "Please, provide a valid model as argument. Select one of the following: fasterrcnn_resnet50_fpn.")
            raise ValueError(name)

        # Enum lookup by member name; None is passed through instead of
        # failing the lookup, so callers can ask for no pretrained weights.
        model_weights = (
            FasterRCNN_ResNet50_FPN_V2_Weights[weights]
            if weights is not None else None
        )
        model_backbone_weights = (
            ResNet50_Weights[backbone_weights]
            if backbone_weights is not None else None
        )

        # NOTE(review): `backbone_name` is kept exactly as the original call
        # passed it; presumably it travels through the builder's **kwargs --
        # confirm against the installed torchvision that it has any effect.
        model: FasterRCNN = fasterrcnn_resnet50_fpn_v2(
            weights=model_weights,
            backbone_name=backbone,
            weights_backbone=model_backbone_weights,
            trainable_backbone_layers=trainable_backbone_layers,
        )

        # Re-seed (default seed) right before the new head is created,
        # presumably so its initial weights are reproducible. This also
        # resets the global torch RNG state for the caller.
        set_seeds()

        # Replace the box predictor with one sized for our class count,
        # reusing the input width of the existing classification layer.
        in_features = model.roi_heads.box_predictor.cls_score.in_features
        model.roi_heads.box_predictor = FastRCNNPredictor(
            in_features, num_classes)

        return model