Spaces:
Build error
Build error
import logging | |
import torch | |
from torchvision.models.detection.faster_rcnn import fasterrcnn_resnet50_fpn_v2, FasterRCNN, FastRCNNPredictor | |
from torchvision.models.detection import FasterRCNN_ResNet50_FPN_V2_Weights | |
from torchvision.models.resnet import ResNet50_Weights | |
logger = logging.getLogger(__name__) | |
def set_seeds(seed: int = 42): | |
"""Sets random sets for torch operations. | |
Args: | |
seed (int, optional): Random seed to set. Defaults to 42. | |
""" | |
# Set the seed for general torch operations | |
torch.manual_seed(seed) | |
# Set the seed for CUDA torch operations (ones that happen on the GPU) | |
torch.cuda.manual_seed(seed) | |
class ModelFactory: | |
def get_model(name, weights, backbone, backbone_weights, trainable_backbone_layers): | |
logger.debug(f"get_model -> model:{name}") | |
if name == "fasterrcnn_resnet50_fpn": | |
# backbone = backbone | |
model_weights = FasterRCNN_ResNet50_FPN_V2_Weights[weights] | |
model_backbone_weights = ResNet50_Weights[backbone_weights] | |
# trainable_backbone_layers = 1 | |
model: FasterRCNN = fasterrcnn_resnet50_fpn_v2( | |
weights=model_weights, backbone_name=backbone, weights_backbone=model_backbone_weights, trainable_backbone_layers=trainable_backbone_layers) | |
# for param in model.rpn.parameters(): | |
# param.requires_grad = False | |
# for param in model.roi_heads.parameters(): | |
# param.requires_grad = False | |
# for param in model.backbone.fpn.parameters(): | |
# param.requires_grad = False | |
set_seeds() | |
num_classes = 2 # 1 class (person) + background | |
in_features = model.roi_heads.box_predictor.cls_score.in_features | |
model.roi_heads.box_predictor = FastRCNNPredictor( | |
in_features, num_classes) | |
else: | |
logger.error( | |
"Please, provide a valid model as argument. Select one of the following: fasterrcnn_resnet50_fpn.") | |
raise ValueError(name) | |
return model | |