File size: 4,302 Bytes
87e3471
16456f4
87e3471
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3f463d2
 
 
 
16456f4
 
 
 
 
 
 
 
 
3f463d2
 
 
 
16456f4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3f463d2
 
 
 
 
 
16456f4
3f463d2
 
 
 
 
 
 
 
 
9e0c113
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import os
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.init as init
from torchvision.models import ResNet101_Weights, resnet101, resnet50
from torchvision.models.resnet import BasicBlock, Bottleneck


class ResNet50Classifier(nn.Module):
    """ResNet-50 image classifier built without pretrained weights.

    On construction, a checkpoint is loaded from
    ``best_pth/ResNet50Classifier.pth`` (in this module's directory) if it
    exists. Otherwise the weights are initialized by :meth:`_init_weights`.

    Args:
        num_classes: Number of outputs of the final fully connected layer.
    """

    def __init__(self, num_classes: int = 14):
        super(ResNet50Classifier, self).__init__()
        # ``weights=None`` replaces the ``pretrained=False`` argument that
        # torchvision deprecated in 0.13. This file already uses the
        # weights-enum API (ResNet101_Weights), so it is available.
        self.resnet = resnet50(weights=None, num_classes=num_classes)

        # Initialize explicitly only when no checkpoint could be loaded.
        if not self._load():
            self._init_weights()

    def _init_weights(self) -> None:
        """Re-initialize the weights of every submodule in place.

        - Zero-initialize the last BN in each residual branch, so that the
          residual branch starts with zeros and each residual block behaves
          like an identity. This improves the model by 0.2~0.3% according to
          https://arxiv.org/abs/1706.02677
        - Convolutions get Kaiming-normal init (fan_out, ReLU).
        - Linear layers get Xavier-normal init.
        - Conv/linear biases, if present, are zeroed.
        """
        for m in self.modules():
            if isinstance(m, Bottleneck) and m.bn3.weight is not None:
                nn.init.constant_(m.bn3.weight, 0)  # type: ignore[arg-type]
            elif isinstance(m, BasicBlock) and m.bn2.weight is not None:
                nn.init.constant_(m.bn2.weight, 0)  # type: ignore[arg-type]
            elif isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                init.xavier_normal_(m.weight)
                if m.bias is not None:
                    init.zeros_(m.bias)

    def forward(self, x):
        """Return the wrapped ResNet's output for ``x`` (no activation is applied here)."""
        return self.resnet(x)

    def _load(self, filename: Optional[str] = None) -> bool:
        """Load a saved state dict into this module if the file exists.

        Args:
            filename: Checkpoint path. Defaults to
                ``best_pth/ResNet50Classifier.pth`` in this module's directory.

        Returns:
            True if the checkpoint was loaded, False if the file does not exist.
            Errors raised by ``torch.load`` / ``load_state_dict`` (e.g. for a
            mismatched state dict) propagate to the caller.
        """
        if filename is None:
            # Resolved relative to this source file, not the process's cwd.
            module_dir = os.path.dirname(__file__)
            filename = os.path.join(module_dir, "best_pth", "ResNet50Classifier.pth")
        if not os.path.exists(filename):
            print("Model file does not exist.")
            return False
        # map_location="cpu" lets checkpoints saved from a GPU load on CPU-only
        # hosts. load_state_dict then copies the values into this module's
        # existing parameters.
        # NOTE: torch.load unpickles the file, so only load trusted checkpoints.
        self.load_state_dict(torch.load(filename, map_location="cpu"))
        print("Model loaded successfully.")
        return True


class ResNet101Classifier(nn.Module):
    """ResNet-101 image classifier.

    Args:
        num_classes: Number of outputs of the final fully connected layer.
        pretrained: If True (the default, matching the previous behavior),
            the network starts from torchvision's ``ResNet101_Weights.DEFAULT``
            weights. Its final fully connected layer is replaced with a new
            ``nn.Linear(in_features, num_classes)``.

            If False, the network is built without pretrained weights.
            ``best_pth/ResNet101Classifier.pth`` (in this module's directory)
            is loaded if present; otherwise the weights are initialized with
            :meth:`_init_weights`.
    """

    def __init__(self, num_classes: int = 14, *, pretrained: bool = True):
        super(ResNet101Classifier, self).__init__()
        if pretrained:
            # NOTE: this path was originally marked "ONLY USE THIS FOR TESTING".
            # Pass pretrained=False for the from-scratch / checkpoint path.
            self.resnet = resnet101(weights=ResNet101_Weights.DEFAULT)
            # Replace the final fully connected layer with one that outputs
            # the desired number of classes.
            in_features = self.resnet.fc.in_features
            self.resnet.fc = nn.Linear(in_features, num_classes)
        else:
            self.resnet = resnet101(weights=None, num_classes=num_classes)
            # Initialize explicitly only when no checkpoint could be loaded.
            if not self._load():
                self._init_weights()

    def _init_weights(self) -> None:
        """Re-initialize the weights of every submodule in place.

        - Zero-initialize the last BN in each residual branch, so that the
          residual branch starts with zeros and each residual block behaves
          like an identity. This improves the model by 0.2~0.3% according to
          https://arxiv.org/abs/1706.02677
        - Convolutions get Kaiming-normal init (fan_out, ReLU).
        - Linear layers get Xavier-normal init.
        - Conv/linear biases, if present, are zeroed.
        """
        for m in self.modules():
            if isinstance(m, Bottleneck) and m.bn3.weight is not None:
                nn.init.constant_(m.bn3.weight, 0)  # type: ignore[arg-type]
            elif isinstance(m, BasicBlock) and m.bn2.weight is not None:
                nn.init.constant_(m.bn2.weight, 0)  # type: ignore[arg-type]
            elif isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                init.xavier_normal_(m.weight)
                if m.bias is not None:
                    init.zeros_(m.bias)

    def forward(self, x):
        """Return the wrapped ResNet's output for ``x`` (no activation is applied here)."""
        return self.resnet(x)

    def _load(self, filename: Optional[str] = None) -> bool:
        """Load a saved state dict into this module if the file exists.

        Args:
            filename: Checkpoint path. Defaults to
                ``best_pth/ResNet101Classifier.pth`` in this module's directory.

        Returns:
            True if the checkpoint was loaded, False if the file does not exist.
            Errors raised by ``torch.load`` / ``load_state_dict`` (e.g. for a
            mismatched state dict) propagate to the caller.
        """
        if filename is None:
            # Resolved relative to this source file, not the process's cwd.
            module_dir = os.path.dirname(__file__)
            filename = os.path.join(module_dir, "best_pth", "ResNet101Classifier.pth")
        if not os.path.exists(filename):
            print("Model file does not exist.")
            return False
        # map_location="cpu" lets checkpoints saved from a GPU load on CPU-only
        # hosts. load_state_dict then copies the values into this module's
        # existing parameters.
        # NOTE: torch.load unpickles the file, so only load trusted checkpoints.
        self.load_state_dict(torch.load(filename, map_location="cpu"))
        print("Model loaded successfully.")
        return True