Spaces:
svjack
/
Runtime error

File size: 3,352 Bytes
c614b0f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
from collections import OrderedDict

import torch
import torch.nn as nn
from torchvision.models import (
    ResNet50_Weights,
    VGG16_BN_Weights,
    VGG16_Weights,
    resnet50,
    vgg16,
    vgg16_bn,
)

from engine.BiRefNet.config import Config
from engine.BiRefNet.models.backbones.pvt_v2 import (
    pvt_v2_b0,
    pvt_v2_b1,
    pvt_v2_b2,
    pvt_v2_b5,
)
from engine.BiRefNet.models.backbones.swin_v1 import (
    swin_v1_b,
    swin_v1_l,
    swin_v1_s,
    swin_v1_t,
)

# Module-level configuration, instantiated once when this module is imported.
# `load_weights` reads `config.weights[model_name]` and passes it to `torch.load`
# as the checkpoint location (presumably a backbone-name -> file-path mapping;
# confirm against Config).
config = Config()


def build_backbone(bb_name, pretrained=True, params_settings=""):
    """Build the encoder backbone identified by ``bb_name``.

    ``"vgg16"``, ``"vgg16bn"`` and ``"resnet50"`` are torchvision models split
    into four stages, exposed as the ``conv1``..``conv4`` children of an
    ``nn.Sequential``. The remaining supported names select one of the PVT-v2 /
    Swin-v1 builders imported by this module; that builder is called with
    ``params_settings`` and, if ``pretrained`` is true, initialised through
    ``load_weights``.

    Args:
        bb_name: Backbone identifier.
        pretrained: Whether to initialise the backbone with pretrained weights.
        params_settings: Source text of the argument list passed to the
            PVT-v2 / Swin-v1 builder (``""`` means no arguments). Non-empty
            values are evaluated with ``eval``, so they must only come from
            trusted configuration.

    Returns:
        The backbone module. For PVT-v2 / Swin-v1 backbones this is ``None``
        when ``load_weights`` finds no matching weights.

    Raises:
        ValueError: If ``bb_name`` is not a supported backbone.
    """
    if bb_name == "vgg16":
        # Use `weights=`: the `pretrained=` keyword is deprecated in
        # torchvision >= 0.13. For VGG16, DEFAULT is IMAGENET1K_V1, the same
        # checkpoint the legacy keyword resolved to.
        bb_net = list(
            vgg16(weights=VGG16_Weights.DEFAULT if pretrained else None).children()
        )[0]  # first child holds the convolutional feature layers
        bb = nn.Sequential(
            OrderedDict(
                {
                    "conv1": bb_net[:4],
                    "conv2": bb_net[4:9],
                    "conv3": bb_net[9:16],
                    "conv4": bb_net[16:23],
                }
            )
        )
    elif bb_name == "vgg16bn":
        # Same as above; for VGG16-BN, DEFAULT is IMAGENET1K_V1 as well.
        bb_net = list(
            vgg16_bn(weights=VGG16_BN_Weights.DEFAULT if pretrained else None).children()
        )[0]
        bb = nn.Sequential(
            OrderedDict(
                {
                    "conv1": bb_net[:6],
                    "conv2": bb_net[6:13],
                    "conv3": bb_net[13:23],
                    "conv4": bb_net[23:33],
                }
            )
        )
    elif bb_name == "resnet50":
        # The legacy `pretrained=<truthy>` keyword resolved to IMAGENET1K_V1,
        # not DEFAULT (IMAGENET1K_V2). Pin V1 explicitly so the same weights are
        # loaded while dropping the deprecated keyword.
        bb_net = list(
            resnet50(
                weights=ResNet50_Weights.IMAGENET1K_V1 if pretrained else None
            ).children()
        )
        bb = nn.Sequential(
            OrderedDict(
                {
                    # Children 0-2 form the stem; child 3 (max-pool) is not used.
                    # NOTE(review): confirm skipping the max-pool is intended.
                    "conv1": nn.Sequential(*bb_net[0:3]),
                    "conv2": bb_net[4],
                    "conv3": bb_net[5],
                    "conv4": bb_net[6],
                }
            )
        )
    else:
        # Restrict `bb_name` to known builders so it can no longer inject an
        # arbitrary expression into the eval() below.
        builders = {
            "pvt_v2_b0": pvt_v2_b0,
            "pvt_v2_b1": pvt_v2_b1,
            "pvt_v2_b2": pvt_v2_b2,
            "pvt_v2_b5": pvt_v2_b5,
            "swin_v1_t": swin_v1_t,
            "swin_v1_s": swin_v1_s,
            "swin_v1_b": swin_v1_b,
            "swin_v1_l": swin_v1_l,
        }
        if bb_name not in builders:
            raise ValueError(
                'Unsupported backbone "{}". Expected one of: {}.'.format(
                    bb_name,
                    ", ".join(["vgg16", "vgg16bn", "resnet50", *builders]),
                )
            )
        if params_settings:
            # SECURITY: `params_settings` is evaluated as Python source; it must
            # come from trusted configuration, never from user input.
            bb = eval("{}({})".format(bb_name, params_settings))
        else:
            bb = builders[bb_name]()
        if pretrained:
            bb = load_weights(bb, bb_name)
    return bb


def load_weights(model, model_name, weights_path=None):
    """Load a pretrained checkpoint into ``model``, tolerating partial matches.

    Checkpoint entries whose key is not in ``model.state_dict()`` are dropped.
    Entries whose size differs from the model's own tensor keep the model's
    current value, so a backbone with modified layers can still be partially
    initialised. If no top-level key matches and the checkpoint holds exactly
    one item that is itself a dict (e.g. ``{"model": state_dict}``), that
    nested dict is used instead.

    Args:
        model: The ``nn.Module`` to initialise.
        model_name: Key into ``config.weights``; used to locate the checkpoint
            when ``weights_path`` is not given.
        weights_path: Optional explicit checkpoint path that overrides the
            ``config.weights`` lookup.

    Returns:
        ``model`` with the weights loaded, or ``None`` (after printing a
        message) when no matching weights are found.
    """
    if weights_path is None:
        weights_path = config.weights[model_name]
    save_model = torch.load(weights_path, map_location="cpu", weights_only=True)
    model_dict = model.state_dict()

    def _matching(source):
        # Keep only keys the model knows. On a size mismatch keep the model's
        # own tensor, i.e. ignore weights of layers that were modified.
        return {
            k: v if v.size() == model_dict[k].size() else model_dict[k]
            for k, v in source.items()
            if k in model_dict
        }

    state_dict = _matching(save_model)
    if not state_dict:
        # Fall back to a single wrapped item. Only index `save_model` when that
        # item exists: an empty or multi-item checkpoint previously raised
        # KeyError(None) here instead of reaching the failure message.
        sub_item = next(iter(save_model)) if len(save_model) == 1 else None
        nested = save_model[sub_item] if sub_item is not None else None
        if isinstance(nested, dict):
            state_dict = _matching(nested)
        if not state_dict:
            print(
                "Weights are not successfully loaded. Check the state dict of weights file."
            )
            return None
        print(
            'Found correct weights in the "{}" item of loaded state_dict.'.format(
                sub_item
            )
        )
    model_dict.update(state_dict)
    model.load_state_dict(model_dict)
    return model