import sys

# Use the local pytorch-image-models checkout; timm.create_model below relies on
# a custom ``out_stage3`` option expected from this modified copy.
sys.path.insert(0, '/data1/PycharmProjects/FGVC11/pytorch-image-models-main')
# sys.path.insert(0, '/data1/PycharmProjects/FGVC11/submision/pytorch-image-models-main')
# sys.path.insert(0, '/data1/PycharmProjects/FGVC11/submision/pytorch-image-models')

import numpy as np
import timm
import torch
from torch import nn
import torch.nn.functional as F
from timm.layers import LayerNorm2d, LayerNorm, NormMlpClassifierHead, ClassifierHead
from timm.models.convnext import ConvNeXtStage

# Paths to the pretrained backbone checkpoints (fill in before training).
BASE_CP = ''   # convnextv2_base.fcmae_ft_in22k_in1k_384 weights
LARGE_CP = ''  # convnext_large_mlp weights



class expert(nn.Module):
    """One expert branch: a copy of the backbone's last ConvNeXt stage plus its
    own classifier head, initialized from the pretrained stage-3 weights."""

    def __init__(self, model_arch, num_classes, pretrain=True) -> None:
        super().__init__()

        # Last stage of ConvNeXt-Large ('stages.3'): 768 -> 1536 channels, stride 2, depth 3.
        self.model = ConvNeXtStage(
            in_chs=768,
            out_chs=1536,
            kernel_size=7,
            stride=2,
            dilation=(1, 1),
            depth=3,
            drop_path_rates=[0.0, 0.0, 0.0],
            ls_init_value=1e-6,
            conv_mlp=False,
            conv_bias=True,
            use_grn=False,
            act_layer='gelu',
            norm_layer=LayerNorm2d,
            norm_layer_cl=LayerNorm,
        )

        self.cls_head = NormMlpClassifierHead(
            in_features=1536,
            num_classes=num_classes,
            hidden_size=1536,
            pool_type='avg',  # max
            drop_rate=0.0,
            norm_layer=LayerNorm2d,
            act_layer='gelu',
        )

        if model_arch == 'convnext_large_mlp':
            checkpoints = LARGE_CP
        else:
            assert False, 'pretrain weight not found'
        print('use pretrain weight:', checkpoints)

        state_dict = torch.load(checkpoints)
        # Strip the DistributedDataParallel 'module.' prefix if present.
        for key in list(state_dict.keys()):
            if key.startswith('module.'):
                new_key = key[7:]
                state_dict[new_key] = state_dict[key]
                del state_dict[key]
        # Keep only last-stage weights and drop the 'stages.3.' prefix so the
        # keys match this standalone ConvNeXtStage module.
        for key in list(state_dict.keys()):
            if 'stages.3.' not in key:
                del state_dict[key]
            elif key.startswith('stages.3.'):
                new_key = key[9:]
                state_dict[new_key] = state_dict[key]
                del state_dict[key]
        self.model.load_state_dict(state_dict, strict=True)
        del state_dict

    def forward(self, out_stage3):
        out = self.model(out_stage3)
        out = self.cls_head(out)
        return out
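
# Illustrative shape check for a single expert branch (a sketch, not part of the
# training code; it assumes LARGE_CP points to a valid convnext_large_mlp
# checkpoint, and the class count and spatial size here are arbitrary):
#
#   branch = expert('convnext_large_mlp', num_classes=100)
#   stage3 = torch.randn(2, 768, 24, 24)   # (B, C, H, W) stage-3 feature map
#   logits = branch(stage3)                # -> (2, 100) class logits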
        
class Moe(nn.Module):
    """Mixture-of-experts classifier: a shared ConvNeXt backbone, a main head on
    stage-4 features, and two expert branches (venomous / non-venomous) on
    stage-3 features whose logits are masked to their own class subsets."""

    def __init__(self, model_arch, num_classes, mask, pretrain=True) -> None:
        super().__init__()
        # The patched timm backbone returns both stage-4 and stage-3 feature maps
        # when built with out_stage3=True.
        self.backbone = timm.create_model(model_arch, num_classes=0, pretrained=False, out_stage3=True)
        if pretrain:
            if model_arch == 'convnextv2_base.fcmae_ft_in22k_in1k_384':
                checkpoints = BASE_CP
            elif model_arch == 'convnext_large_mlp':
                checkpoints = LARGE_CP
            else:
                assert False, 'pretrain weight not found'
            print('use pretrain weight:', checkpoints)

            state_dict = torch.load(checkpoints)
            self.backbone.load_state_dict(state_dict, strict=False)
            del state_dict

        self.head = NormMlpClassifierHead(
            in_features=1536,
            num_classes=num_classes,
            hidden_size=1536,
            pool_type='avg',  # max
            drop_rate=0.0,
            norm_layer=LayerNorm2d,
            act_layer='gelu',
        )
        self.expert_venomous = expert(model_arch, num_classes)
        self.expert_not_venomous = expert(model_arch, num_classes)
        # Binary venomous/non-venomous gate over pooled stage-3 + stage-4 features.
        self.venomous_head = nn.Linear(768 + 1536, 1, bias=False)
        torch.nn.init.xavier_uniform_(self.venomous_head.weight)
        # mask: (num_classes,) tensor with 1 for venomous classes, 0 otherwise.
        self.venomous_mask = mask
        self.not_venomous_mask = torch.ones_like(mask) - mask

    def forward(self, x):
        out4, out3 = self.backbone(x)
        feat = torch.cat([F.adaptive_max_pool2d(out3, 1).flatten(1),
                          F.adaptive_max_pool2d(out4, 1).flatten(1)], dim=-1)
        is_venomous = self.venomous_head(feat)
        alpha = torch.sigmoid(is_venomous)  # soft gate (unused in the plain sum below)
        # Each expert only scores the classes of its own group.
        venomous = self.expert_venomous(out3) * self.venomous_mask.to(x.device)
        not_venomous = self.expert_not_venomous(out3) * self.not_venomous_mask.to(x.device)
        y_hat = self.head(out4)

        # expert_pred = venomous * alpha + not_venomous * (1 - alpha)
        expert_pred = venomous + not_venomous
        final_pred = y_hat + expert_pred
        return y_hat, expert_pred, is_venomous, final_pred
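
# Minimal construction sketch (illustrative only; the class count, venomous class
# indices and input size are placeholders, and BASE_CP / LARGE_CP must point to
# real checkpoint files for this to run):
#
#   num_classes = 100
#   venomous_ids = [3, 17, 42]              # hypothetical venomous class indices
#   mask = torch.zeros(num_classes)
#   mask[venomous_ids] = 1.0
#   model = Moe('convnext_large_mlp', num_classes, mask)
#   y_hat, expert_pred, is_venomous, final_pred = model(torch.randn(2, 3, 384, 384))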

class SeesawLossWithLogits(nn.Module):
    """
    Unofficial implementation of Seesaw loss, proposed in the technical report
    for the LVIS workshop at ECCV 2020.
    For more detail, please refer to https://arxiv.org/pdf/2008.10032.pdf.

    Args:
        class_counts: Array (or list) of per-class sample counts.
                      Should have the same length as num_classes.
        p: Scale parameter which adjusts the strength of the punishment.
           Set to 0.8 by default, following the original paper.
    """

    def __init__(self, class_counts: np.ndarray, num_classes, p: float = 0.8):
        super().__init__()

        class_counts = torch.FloatTensor(class_counts)
        # Mitigation factor s[i, j] = (n_j / n_i)^p when class i has more samples
        # than class j, and 1 otherwise.
        conditions = class_counts[:, None] > class_counts[None, :]
        trues = (class_counts[None, :] / class_counts[:, None]) ** p
        falses = torch.ones(len(class_counts), len(class_counts))
        self.s = torch.where(conditions, trues, falses)
        self.num_classes = num_classes
        self.eps = 1.0e-6

    def forward(self, logits, targets):
        targets = nn.functional.one_hot(targets, num_classes=self.num_classes).float().to(logits.device)
        self.s = self.s.to(targets.device)
        max_element, _ = logits.max(axis=-1)
        logits = logits - max_element[:, None]  # to prevent overflow

        numerator = torch.exp(logits)
        denominator = (
            (1 - targets)[:, None, :]
            * self.s[None, :, :]
            * torch.exp(logits)[:, None, :]
        ).sum(axis=-1) + torch.exp(logits)

        sigma = numerator / (denominator + self.eps)
        loss = (- targets * torch.log(sigma + self.eps)).sum(-1)
        return loss.mean()
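
# Quick usage sketch for the Seesaw loss on its own (illustrative; the class
# counts, batch size and labels below are made up):
#
#   counts = np.array([500, 50, 5])
#   criterion = SeesawLossWithLogits(counts, num_classes=3)
#   loss = criterion(torch.randn(4, 3), torch.tensor([0, 1, 2, 0]))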
        
        
class all_loss(nn.Module):
    """Total training loss: Seesaw loss on the main head, on the combined expert
    logits, and on the fused prediction, averaged over the three terms."""

    def __init__(self, class_counts: np.ndarray, num_classes):
        super().__init__()
        self.main_loss = SeesawLossWithLogits(class_counts, num_classes)
        self.venomous_loss = SeesawLossWithLogits(class_counts, num_classes)
        self.final_pred_loss = SeesawLossWithLogits(class_counts, num_classes)
        # self.venomous_loss = nn.CrossEntropyLoss()
        # self.alpha_loss = nn.BCEWithLogitsLoss()
        # self.final_pred_loss = nn.CrossEntropyLoss()

    def forward(self, y_hat, expert_pred, alpha, final_pred, targets, is_venomous):
        loss1 = self.main_loss(y_hat, targets)
        loss2 = self.venomous_loss(expert_pred, targets)
        # loss3 = self.alpha_loss(alpha, is_venomous.unsqueeze(1))
        loss4 = self.final_pred_loss(final_pred, targets)

        return (loss1 + loss2 + loss4) / 3
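

# Standalone smoke test of the loss pieces on random tensors (a sketch added for
# illustration; shapes and counts are arbitrary, no checkpoints are needed, and
# it does not exercise the Moe / expert modules themselves).
if __name__ == '__main__':
    num_classes = 10
    class_counts = np.random.randint(5, 500, size=num_classes)

    logits = torch.randn(4, num_classes)          # stand-in for the main head output
    expert_logits = torch.randn(4, num_classes)   # stand-in for the combined expert output
    gate_logit = torch.randn(4, 1)                # stand-in for the venomous-gate logit
    targets = torch.randint(0, num_classes, (4,))
    is_venomous = torch.randint(0, 2, (4,)).float()

    criterion = all_loss(class_counts, num_classes)
    loss = criterion(logits, expert_logits, gate_logit,
                     logits + expert_logits, targets, is_venomous)
    print('total loss:', loss.item())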