Spaces:
Build error
Build error
File size: 2,113 Bytes
169e11c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 |
import logging
import torch
from torchvision.models.detection.faster_rcnn import fasterrcnn_resnet50_fpn_v2, FasterRCNN, FastRCNNPredictor
from torchvision.models.detection import FasterRCNN_ResNet50_FPN_V2_Weights
from torchvision.models.resnet import ResNet50_Weights
logger = logging.getLogger(__name__)
def set_seeds(seed: int = 42) -> None:
    """Sets random seeds for torch operations.

    Args:
        seed (int, optional): Random seed to set. Defaults to 42.
    """
    # Set the seed for general torch operations.
    torch.manual_seed(seed)
    # Set the seed for CUDA torch operations on every GPU, not only the
    # current device; the call is silently ignored when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
class ModelFactory:
    """Factory that builds detection models from a model name."""

    @staticmethod
    def get_model(name, weights, backbone, backbone_weights,
                  trainable_backbone_layers, num_classes=2) -> "FasterRCNN":
        """Builds the requested detection model with a fresh box predictor.

        Args:
            name (str): Model identifier. Only "fasterrcnn_resnet50_fpn" is
                supported; it is built with ``fasterrcnn_resnet50_fpn_v2``.
            weights (str | None): Member name of
                ``FasterRCNN_ResNet50_FPN_V2_Weights`` to load, or None to
                build the detector without pretrained weights.
            backbone: Forwarded to the builder as ``backbone_name``.
            backbone_weights (str | None): Member name of
                ``ResNet50_Weights`` to load, or None for no pretrained
                backbone weights.
            trainable_backbone_layers: Forwarded unchanged to the builder.
            num_classes (int, optional): Number of outputs of the new box
                predictor, background included. Defaults to 2
                (one object class + background).

        Returns:
            FasterRCNN: The model with ``roi_heads.box_predictor`` replaced
            by a new ``FastRCNNPredictor`` sized for ``num_classes``.

        Raises:
            ValueError: If ``name`` is not a supported model.
            KeyError: If ``weights`` or ``backbone_weights`` is a string that
                does not name a member of its weights enum.
        """
        logger.debug("get_model -> model:%s", name)

        # Guard clause: reject unsupported names before doing any work.
        if name != "fasterrcnn_resnet50_fpn":
            logger.error(
                "Please, provide a valid model as argument. Select one of the following: fasterrcnn_resnet50_fpn.")
            raise ValueError(name)

        # Resolve the weight names to enum members; None means "no
        # pretrained weights" and is passed through to the builder as-is.
        model_weights = None
        if weights is not None:
            model_weights = FasterRCNN_ResNet50_FPN_V2_Weights[weights]
        model_backbone_weights = None
        if backbone_weights is not None:
            model_backbone_weights = ResNet50_Weights[backbone_weights]

        # NOTE(review): `backbone_name` is not a documented parameter of
        # fasterrcnn_resnet50_fpn_v2; it is forwarded through **kwargs.
        # Confirm it has any effect on the backbone that gets built.
        model = fasterrcnn_resnet50_fpn_v2(
            weights=model_weights,
            backbone_name=backbone,
            weights_backbone=model_backbone_weights,
            trainable_backbone_layers=trainable_backbone_layers,
        )

        # Re-seed with the default seed right before creating the new head,
        # presumably so its initial weights are reproducible.
        set_seeds()

        # Swap the box predictor for one with the requested class count,
        # keeping the input width of the existing classification layer.
        in_features = model.roi_heads.box_predictor.cls_score.in_features
        model.roi_heads.box_predictor = FastRCNNPredictor(
            in_features, num_classes)
        return model
|