File size: 4,289 Bytes
c9b5796
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import logging

import numpy as np
import torch
import torch.nn as nn
import torchvision
from torchvision.models.feature_extraction import create_feature_extractor
import feature_extractor_models as smp
import torch
from .base import BaseModel

logger = logging.getLogger(__name__)



class FeatureExtractor(BaseModel):
    """Image feature extractor wrapping a ``feature_extractor_models`` network.

    ``_init`` builds one of the supported architectures (``FPN``,
    ``LightFPN``, ``PSP``) from the configuration and stores it as
    ``self.fmodel``. ``_forward`` normalizes ``data["image"]`` with the
    class-level ``mean``/``std`` and returns the network output wrapped as
    ``{"feature_maps": [output]}``.
    """

    default_conf = {
        "pretrained": True,  # initialize the encoder with ImageNet weights
        "input_dim": 3,  # number of channels of the input image
        "output_dim": 128,  # number of channels in output feature maps
        "encoder": "resnet50",  # encoder name, passed as `encoder_name`
        # NOTE(review): the five keys below are not read anywhere in this
        # class; they are presumably kept for config compatibility.
        "remove_stride_from_first_conv": False,
        "num_downsample": None,  # how many downsample blocks
        "decoder_norm": "nn.BatchNorm2d",  # normalization in decoder blocks
        "do_average_pooling": False,
        "checkpointed": False,  # whether to use gradient checkpointing
        "architecture": "FPN",  # one of the keys of `_ARCHITECTURES`
    }
    # Per-channel statistics used to normalize the input image in `_forward`
    # (the standard ImageNet RGB mean / std values).
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

    # architecture name -> (attribute name on `smp`, `upsampling` argument).
    # NOTE(review): "L" looks like a truncated class name for the LightFPN
    # variant; confirm against what `feature_extractor_models` exports.
    _ARCHITECTURES = {
        "FPN": ("FPN", 2),
        "LightFPN": ("L", 2),
        "PSP": ("PSPNet", 4),
    }

    def build_encoder(self, conf):
        """Sanity-check the encoder-related configuration.

        This method only validates ``conf`` and returns ``None``; it does not
        construct anything and is not called from ``_init``.

        Raises:
            AssertionError: if ``conf.encoder`` is not a string, or if
                ``conf.pretrained`` is set while ``conf.input_dim != 3``.
        """
        assert isinstance(conf.encoder, str)
        if conf.pretrained:
            assert conf.input_dim == 3

    def _init(self, conf):
        """Register the normalization buffers and build ``self.fmodel``.

        Args:
            conf: configuration object exposing the ``default_conf`` keys as
                attributes.

        Raises:
            ValueError: if ``conf.architecture`` is not a supported name.
        """
        # Preprocessing statistics; non-persistent so they are left out of
        # the state dict.
        self.register_buffer("mean_", torch.tensor(self.mean), persistent=False)
        self.register_buffer("std_", torch.tensor(self.std), persistent=False)

        try:
            model_attr, upsampling = self._ARCHITECTURES[conf.architecture]
        except KeyError:
            supported = ", ".join(self._ARCHITECTURES)
            raise ValueError(
                f"Unknown architecture {conf.architecture!r}; "
                f"supported architectures: {supported}"
            ) from None

        # Resolved lazily so that a missing class only fails for the
        # architecture that actually needs it.
        model_cls = getattr(smp, model_attr)
        self.fmodel = model_cls(
            encoder_name=conf.encoder,
            # Honor `pretrained`: None asks for a randomly initialized encoder.
            encoder_weights="imagenet" if conf.pretrained else None,
            in_channels=conf.input_dim,
            classes=conf.output_dim,  # channels of the output feature map
            upsampling=upsampling,  # final output upsampling factor
            activation=None,
        )

    def _forward(self, data):
        """Normalize the input image and run it through ``self.fmodel``.

        Args:
            data: mapping with an ``"image"`` tensor. Presumably shaped
                ``(..., C, H, W)`` with ``C == len(self.mean)`` so that the
                per-channel statistics broadcast; confirm against the caller.

        Returns:
            ``{"feature_maps": [output]}`` where ``output`` is whatever
            ``self.fmodel`` returns for the normalized image.
        """
        image = data["image"]
        # Reshape the (C,) statistics to (C, 1, 1) to broadcast over H and W.
        image = (image - self.mean_[:, None, None]) / self.std_[:, None, None]

        output = self.fmodel(image)

        return {"feature_maps": [output]}
if __name__ == "__main__":
    # Instantiate the extractor when this module is executed as a script.
    # NOTE(review): no configuration is passed here; confirm that the
    # BaseModel constructor accepts being called without arguments.
    model = FeatureExtractor()