Matteo Sirri
feat: initial commit
169e11c
raw
history blame
2.11 kB
import logging
import torch
from torchvision.models.detection.faster_rcnn import fasterrcnn_resnet50_fpn_v2, FasterRCNN, FastRCNNPredictor
from torchvision.models.detection import FasterRCNN_ResNet50_FPN_V2_Weights
from torchvision.models.resnet import ResNet50_Weights
logger = logging.getLogger(__name__)
def set_seeds(seed: int = 42):
    """Seed the torch random number generators.

    Args:
        seed (int, optional): Value handed to each seeding function.
            Defaults to 42.
    """
    # Seed general torch operations first, then the CUDA (GPU) generator,
    # both with the same value.
    for apply_seed in (torch.manual_seed, torch.cuda.manual_seed):
        apply_seed(seed)
class ModelFactory:
    """Factory that builds detection models from string configuration values."""

    @staticmethod
    def get_model(name, weights, backbone, backbone_weights, trainable_backbone_layers, num_classes=2):
        """Build a detection model and replace its box predictor head.

        Args:
            name: Model identifier. Only "fasterrcnn_resnet50_fpn" is supported.
            weights: Member name looked up in FasterRCNN_ResNet50_FPN_V2_Weights,
                or None to build the detector without pretrained weights.
            backbone: Forwarded to the torchvision builder as ``backbone_name``.
            backbone_weights: Member name looked up in ResNet50_Weights, or
                None to pass no backbone weights.
            trainable_backbone_layers: Forwarded unchanged to the torchvision
                builder.
            num_classes: Number of outputs of the replacement box predictor,
                background included. Defaults to 2 (person + background).

        Returns:
            The FasterRCNN model whose ``roi_heads.box_predictor`` has been
            replaced by a new FastRCNNPredictor with ``num_classes`` outputs.

        Raises:
            ValueError: If ``name`` is not a supported model.
            KeyError: If ``weights`` or ``backbone_weights`` is a string that
                is not a member name of its weights enum.
        """
        logger.debug("get_model -> model:%s", name)

        # Reject unsupported names up front so the happy path stays flat.
        if name != "fasterrcnn_resnet50_fpn":
            logger.error(
                "Please, provide a valid model as argument. Select one of the following: fasterrcnn_resnet50_fpn.")
            raise ValueError(name)

        # Translate the configured names into weights enum members; None is
        # passed through so a model without pretrained weights can be built.
        model_weights = (
            None if weights is None
            else FasterRCNN_ResNet50_FPN_V2_Weights[weights])
        model_backbone_weights = (
            None if backbone_weights is None
            else ResNet50_Weights[backbone_weights])

        # NOTE(review): backbone_name is not a declared parameter of
        # fasterrcnn_resnet50_fpn_v2; it travels through **kwargs and may have
        # no effect on the backbone that is built - confirm against the
        # installed torchvision version. Kept to preserve existing behaviour.
        model: FasterRCNN = fasterrcnn_resnet50_fpn_v2(
            weights=model_weights,
            backbone_name=backbone,
            weights_backbone=model_backbone_weights,
            trainable_backbone_layers=trainable_backbone_layers)

        # Seed after the detector is built and immediately before the new
        # head is created, so the head is constructed from a fixed seed.
        set_seeds()
        in_features = model.roi_heads.box_predictor.cls_score.in_features
        model.roi_heads.box_predictor = FastRCNNPredictor(
            in_features, num_classes)
        return model