File size: 2,994 Bytes
2a13495
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import torch
import torch.nn as nn

from segmentation_models_pytorch import MAnet
from segmentation_models_pytorch.base.modules import Activation


class SegformerGH(MAnet):
    """MAnet encoder/decoder with two convolutional prediction heads.

    The backbone (``self.encoder`` / ``self.decoder``) is built by the parent
    ``MAnet`` constructor. Afterwards every ``nn.ReLU`` inside the encoder and
    decoder is swapped for ``nn.Mish``. Two ``DeepSegmantationHead`` modules
    then read the decoder output:

    * ``gradflow_head`` -> 2 channels
    * ``cellprob_head`` -> 1 channel

    ``forward`` returns both concatenated along dim 1, gradflow channels first,
    giving a 3-channel output.
    """

    def __init__(
        self,
        encoder_name: str = "mit_b5",
        encoder_weights="imagenet",
        decoder_channels=(256, 128, 64, 32, 32),
        decoder_pab_channels=256,
        in_channels: int = 3,
        classes: int = 3,
    ):
        # NOTE(review): ``classes`` is only forwarded to the parent; presumably
        # it sizes MAnet's own head, which ``forward`` below never calls.
        super().__init__(
            encoder_name=encoder_name,
            encoder_weights=encoder_weights,
            decoder_channels=decoder_channels,
            decoder_pab_channels=decoder_pab_channels,
            in_channels=in_channels,
            classes=classes,
        )

        # Only the parent-built backbone needs converting; the heads created
        # below already use Mish internally.
        for backbone_part in (self.encoder, self.decoder):
            convert_relu_to_mish(backbone_part)

        # Both heads consume the final decoder stage, whose width is the last
        # entry of ``decoder_channels``. Registration order (cellprob first)
        # is kept so state_dict ordering matches existing checkpoints.
        head_width = decoder_channels[-1]
        self.cellprob_head = DeepSegmantationHead(
            in_channels=head_width, out_channels=1, kernel_size=3
        )
        self.gradflow_head = DeepSegmantationHead(
            in_channels=head_width, out_channels=2, kernel_size=3
        )

    def forward(self, x):
        """Run ``x`` through the encoder, the decoder and both heads.

        Returns the gradflow (2 channels) and cellprob (1 channel) predictions
        concatenated along dim 1.
        """
        # Input-shape validation is provided by the parent class.
        self.check_input_shape(x)

        decoded = self.decoder(*self.encoder(x))

        # List items are evaluated left to right: gradflow, then cellprob.
        return torch.cat(
            [self.gradflow_head(decoded), self.cellprob_head(decoded)],
            dim=1,
        )


class DeepSegmantationHead(nn.Sequential):
    """Two-convolution prediction head packaged as an ``nn.Sequential``.

    Layer order (the positional indices are also the ``state_dict`` keys):

    0. ``Conv2d(in_channels -> in_channels // 2)``
    1. ``Mish`` (in place)
    2. ``BatchNorm2d(in_channels // 2)``
    3. ``Conv2d(in_channels // 2 -> out_channels)``
    4. bilinear upsampling by ``upsampling`` if it is > 1, else ``Identity``
    5. ``Activation(activation)`` from segmentation_models_pytorch

    Both convolutions use ``padding = kernel_size // 2``. For odd kernel sizes
    (stride 1) this keeps the spatial size unchanged.

    Note that BatchNorm comes *after* the activation here. The order must stay
    as is for checkpoints saved with this layout to keep loading.
    """

    def __init__(
        self, in_channels, out_channels, kernel_size=3, activation=None, upsampling=1
    ):
        hidden_channels = in_channels // 2
        same_padding = kernel_size // 2

        if upsampling > 1:
            resize = nn.UpsamplingBilinear2d(scale_factor=upsampling)
        else:
            resize = nn.Identity()

        layers = [
            nn.Conv2d(
                in_channels,
                hidden_channels,
                kernel_size=kernel_size,
                padding=same_padding,
            ),
            nn.Mish(inplace=True),
            nn.BatchNorm2d(hidden_channels),
            nn.Conv2d(
                hidden_channels,
                out_channels,
                kernel_size=kernel_size,
                padding=same_padding,
            ),
            resize,
            # NOTE(review): with ``activation=None`` this is presumably an
            # identity mapping; confirm against segmentation_models_pytorch.
            Activation(activation),
        ]
        super().__init__(*layers)


def convert_relu_to_mish(model: nn.Module) -> None:
    """Recursively replace every ``nn.ReLU`` submodule of *model* with ``nn.Mish``.

    Each replacement is installed with ``setattr`` on the parent module, so
    module names are unchanged. Neither activation has parameters, so
    ``state_dict`` keys are unaffected.

    Each new ``nn.Mish`` inherits the ``inplace`` flag of the ``nn.ReLU`` it
    replaces. An out-of-place ReLU therefore never silently becomes an in-place
    op, which could otherwise overwrite a tensor that autograd or another
    branch still needs.

    Only descendants are rewritten; if *model* is itself an ``nn.ReLU`` it is
    left untouched.

    Args:
        model: module whose submodules are rewritten in place.

    Returns:
        None. *model* is modified in place.
    """
    for child_name, child in model.named_children():
        if isinstance(child, nn.ReLU):
            # Replacing the value for an existing key does not resize the
            # children dict, so it is safe while iterating named_children().
            setattr(model, child_name, nn.Mish(inplace=child.inplace))
        else:
            convert_relu_to_mish(child)


if __name__ == "__main__":
    # Convert a state_dict checkpoint into a fully pickled model object.
    model = SegformerGH(
        encoder_name="mit_b5",
        encoder_weights=None,  # all weights come from the checkpoint below
        decoder_channels=(1024, 512, 256, 128, 64),
        decoder_pab_channels=256,
        in_channels=3,
        classes=3,
    )

    # NOTE(review): torch.load unpickles the file; only load trusted
    # checkpoints. Consider weights_only=True if the installed torch supports it.
    state_dict = torch.load("./main_model.pth", map_location="cpu")
    model.load_state_dict(state_dict)

    # Export in inference mode so the BatchNorm layers use running statistics
    # by default for whoever loads the pickled model.
    model.eval()

    # torch.save(model, ...) pickles the class by reference. When this file is
    # run as a script, SegformerGH lives in ``__main__``, so the loader must
    # make it resolvable there. NOTE(review): confirm how consumers load this.
    torch.save(model, "main_model.pt")