File size: 736 Bytes
b7914f1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
from res.impl.HRNetV2 import HRNetV2
import torch


class Config:
    """Plain, initially empty attribute container for model hyper-parameters.

    Instances start with no attributes; callers assign fields directly
    (e.g. ``config.data_len = 5000``) and then hand the object to a model
    constructor.
    """


class HRNetV2Wrapper:
    """Build an HRNetV2 model with a fixed configuration and load saved weights.

    After construction, ``self.model`` holds the network in evaluation mode
    on ``device``.

    Args:
        weights_path: Path to the saved ``state_dict``. The default is the
            original hard-coded location, which is resolved relative to the
            current working directory.
        device: Device the model is moved to after loading (default ``"cpu"``,
            matching the original behaviour).

    Raises:
        Whatever ``torch.load`` raises for a missing or unreadable file, and
        whatever ``load_state_dict`` raises when the checkpoint's keys or
        shapes do not match the model built from ``_build_config``.
    """

    DEFAULT_WEIGHTS_PATH = "./res/models/hrnetv2/weights.pth"

    def __init__(self, weights_path=DEFAULT_WEIGHTS_PATH, device="cpu"):
        self.model = HRNetV2(self._build_config())
        # map_location="cpu" remaps every stored tensor onto the CPU. Without
        # it, a checkpoint saved from CUDA tensors tries to restore onto its
        # original GPU and fails on CPU-only machines, before the .to() below
        # ever runs.
        # NOTE(review): torch.load unpickles the file; only load weights from
        # a trusted source.
        weights = torch.load(weights_path, map_location="cpu")
        self.model.load_state_dict(weights)
        self.model = self.model.to(device).eval()

    @staticmethod
    def _build_config():
        """Return the fixed hyper-parameter set used to construct HRNetV2.

        These values presumably must match the ones the stored weights were
        trained with; otherwise ``load_state_dict`` will reject the checkpoint.
        The meaning of each field is defined by ``HRNetV2`` (not visible here).
        """
        config = Config()
        # NOTE(review): presumably the input sequence length; confirm in HRNetV2.
        config.data_len = 5000
        config.kernel_size = 5
        config.dilation = 1
        config.num_stages = 3
        config.num_blocks = 6
        config.num_modules = [1, 1, 1, 4, 3]
        config.use_bottleneck = [1, 0, 0, 0, 0]
        config.stage1_channels = 128
        config.num_channels_init = 48
        # "linear" interpolation in torch applies to 1-D (N, C, L) data, which
        # suggests a 1-D signal model; TODO confirm against HRNetV2.
        config.interpolate_mode = "linear"
        # NOTE(review): presumably the number of output classes/channels.
        config.output_size = 3
        return config