File size: 1,189 Bytes
2cdd41c
 
1615d09
2cdd41c
1615d09
2cdd41c
 
 
 
 
 
1615d09
 
 
 
 
 
 
 
 
2cdd41c
 
1615d09
 
 
 
 
 
 
2cdd41c
 
 
 
 
 
 
 
 
1615d09
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import torch.nn as nn

from isegm.model.modifiers import LRMult
from isegm.utils.serialization import serialize

from .is_model import ISModel
from .modeling.hrnet_ocr import HighResolutionNet


class HRNetModel(ISModel):
    """ISModel subclass that uses a ``HighResolutionNet`` as its feature extractor.

    The network is built with ``num_classes=1`` and is stored on the instance
    as ``feature_extractor``.
    """

    @serialize
    def __init__(
        self,
        width=48,
        ocr_width=256,
        small=False,
        backbone_lr_mult=0.1,
        norm_layer=nn.BatchNorm2d,
        **kwargs
    ):
        """Build the feature extractor and attach learning-rate modifiers.

        Args:
            width: Forwarded to ``HighResolutionNet`` as ``width``.
            ocr_width: Forwarded to ``HighResolutionNet`` as ``ocr_width``.
                When greater than zero, the three OCR-related submodules get
                an ``LRMult(1.0)`` applied after the backbone-wide modifier.
            small: Forwarded to ``HighResolutionNet`` as ``small``.
            backbone_lr_mult: Value given to the ``LRMult`` modifier that is
                applied to the whole feature extractor.
            norm_layer: Passed both to the parent constructor and to
                ``HighResolutionNet``.
            **kwargs: Forwarded unchanged to the parent constructor.
        """
        super().__init__(norm_layer=norm_layer, **kwargs)

        self.feature_extractor = HighResolutionNet(
            width=width,
            ocr_width=ocr_width,
            small=small,
            num_classes=1,
            norm_layer=norm_layer,
        )
        extractor = self.feature_extractor

        # First tag the entire extractor with the backbone multiplier ...
        extractor.apply(LRMult(backbone_lr_mult))

        # ... then, if the OCR part is enabled, re-tag its submodules with 1.0.
        # NOTE(review): presumably this exempts the OCR heads from the reduced
        # backbone learning rate -- confirm against LRMult's implementation.
        if ocr_width > 0:
            ocr_submodules = ("ocr_distri_head", "ocr_gather_head", "conv3x3_ocr")
            for submodule_name in ocr_submodules:
                getattr(extractor, submodule_name).apply(LRMult(1.0))

    def backbone_forward(self, image, coord_features=None):
        """Run the feature extractor and name its first two outputs.

        Args:
            image: First positional argument handed to the feature extractor.
            coord_features: Optional second positional argument handed to the
                feature extractor; defaults to ``None``.

        Returns:
            A dict where ``"instances"`` holds element 0 and
            ``"instances_aux"`` holds element 1 of the extractor's output.
        """
        outputs = self.feature_extractor(image, coord_features)
        primary = outputs[0]
        auxiliary = outputs[1]

        return {"instances": primary, "instances_aux": auxiliary}