import torch.nn as nn

from isegm.model.modifiers import LRMult
from isegm.utils.serialization import serialize

from .is_model import ISModel
from .modeling.hrnet_ocr import HighResolutionNet


class HRNetModel(ISModel):
    """``ISModel`` subclass whose backbone is a single-class ``HighResolutionNet``.

    Every submodule of the backbone is given ``LRMult(backbone_lr_mult)``.
    When ``ocr_width > 0`` the OCR-related modules (``ocr_distri_head``,
    ``ocr_gather_head``, ``conv3x3_ocr``) are afterwards given ``LRMult(1.0)``.
    """

    @serialize
    def __init__(
        self,
        width=48,
        ocr_width=256,
        small=False,
        backbone_lr_mult=0.1,
        norm_layer=nn.BatchNorm2d,
        **kwargs
    ):
        """Construct the HRNet backbone and attach LR multipliers to it.

        Args:
            width: forwarded to ``HighResolutionNet(width=...)``.
            ocr_width: forwarded to ``HighResolutionNet(ocr_width=...)``. When
                positive, the OCR modules are additionally given
                ``LRMult(1.0)``.
            small: forwarded to ``HighResolutionNet(small=...)``.
            backbone_lr_mult: value handed to ``LRMult`` for every backbone
                submodule. Presumably a learning-rate multiplier; confirm
                against ``isegm.model.modifiers.LRMult``.
            norm_layer: normalization layer class, passed both to
                ``ISModel.__init__`` and to the backbone.
            **kwargs: forwarded unchanged to ``ISModel.__init__``.
        """
        super().__init__(norm_layer=norm_layer, **kwargs)

        # The backbone always predicts a single class (num_classes=1).
        backbone = HighResolutionNet(
            width=width,
            ocr_width=ocr_width,
            small=small,
            num_classes=1,
            norm_layer=norm_layer,
        )
        self.feature_extractor = backbone

        # nn.Module.apply visits the module and all of its submodules, so
        # this tags the entire backbone with the backbone multiplier first.
        backbone.apply(LRMult(backbone_lr_mult))

        # The OCR modules are then re-tagged with a multiplier of 1.0,
        # presumably so they train at the base rate. These modules are only
        # touched when ocr_width > 0; the backbone presumably builds them only
        # in that case (TODO confirm in hrnet_ocr).
        if ocr_width > 0:
            for ocr_module in (
                backbone.ocr_distri_head,
                backbone.ocr_gather_head,
                backbone.conv3x3_ocr,
            ):
                ocr_module.apply(LRMult(1.0))

    def backbone_forward(self, image, coord_features=None):
        """Run the feature extractor and package its outputs as a dict.

        Args:
            image: first positional argument to ``self.feature_extractor``.
            coord_features: optional second positional argument to
                ``self.feature_extractor``.

        Returns:
            dict mapping ``"instances"`` to the extractor's first output and
            ``"instances_aux"`` to its second output.
        """
        outputs = self.feature_extractor(image, coord_features)
        main_out = outputs[0]
        aux_out = outputs[1]
        return {"instances": main_out, "instances_aux": aux_out}