import sys

# NOTE: this project expects a locally modified copy of timm
# (pytorch-image-models) that can return intermediate stage features.
sys.path.insert(0, '/data1/PycharmProjects/FGVC11/pytorch-image-models-main')

import timm
import torch
from torch import nn
import torch.nn.functional as F
from timm.layers import LayerNorm2d, LayerNorm, NormMlpClassifierHead
from timm.models.convnext import ConvNeXtStage
import numpy as np

# Paths to the pretrained backbone checkpoints (left blank in the source).
BASE_CP = ''
LARGE_CP = ''


class expert(nn.Module):
    """One expert branch: the final ConvNeXt stage plus its own classifier.

    It consumes stage-3 feature maps from the shared backbone and produces
    class logits.
    """

    def __init__(self, model_arch, num_classes, pretrain=True) -> None:
        super().__init__()

        # Stage 4 of ConvNeXt-Large: 768 -> 1536 channels, 3 blocks, stride 2.
        self.model = ConvNeXtStage(
            in_chs=768,
            out_chs=1536,
            kernel_size=7,
            stride=2,
            dilation=(1, 1),
            depth=3,
            drop_path_rates=[0.0, 0.0, 0.0],
            ls_init_value=1e-6,
            conv_mlp=False,
            conv_bias=True,
            use_grn=False,
            act_layer='gelu',
            norm_layer=LayerNorm2d,
            norm_layer_cl=LayerNorm,
        )

        self.cls_head = NormMlpClassifierHead(
            in_features=1536,
            num_classes=num_classes,
            hidden_size=1536,
            pool_type='avg',
            drop_rate=0.0,
            norm_layer=LayerNorm2d,
            act_layer='gelu',
        )

        if pretrain:
            if model_arch == 'convnext_large_mlp':
                checkpoints = LARGE_CP
            else:
                assert False, 'pretrain weight not found'
            print('use pretrain weight:', checkpoints)
            state_dict = torch.load(checkpoints)
            # Strip the 'module.' prefix left behind by DataParallel/DDP.
            for key in list(state_dict.keys()):
                if key.startswith('module.'):
                    new_key = key[7:]
                    state_dict[new_key] = state_dict[key]
                    del state_dict[key]
            # Keep only the last-stage weights and drop their 'stages.3.'
            # prefix so they line up with this standalone ConvNeXtStage.
            for key in list(state_dict.keys()):
                if 'stages.3.' not in key:
                    del state_dict[key]
                elif key.startswith('stages.3.'):
                    new_key = key[9:]
                    state_dict[new_key] = state_dict[key]
                    del state_dict[key]
            self.model.load_state_dict(state_dict, strict=True)
            del state_dict

    def forward(self, out_stage3):
        out = self.model(out_stage3)
        out = self.cls_head(out)
        return out
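

# A minimal shape sketch for the expert branch (sizes assume a 384px input to
# ConvNeXt-Large, whose stage-3 features come out at roughly [B, 768, 24, 24]):
#
#   exp = expert('convnext_large_mlp', num_classes=10, pretrain=False)
#   logits = exp(torch.randn(2, 768, 24, 24))   # -> [2, 10]

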
class Moe(nn.Module):
    """Mixture-of-experts classifier on a shared ConvNeXt backbone.

    A binary gate predicts whether the image shows a venomous species; two
    expert branches classify over complementary class subsets selected by a
    0/1 class mask, and their masked logits are added to the main head's.
    """

    def __init__(self, model_arch, num_classes, mask, pretrain=True) -> None:
        super().__init__()
        # out_stage3=True is a flag of the locally modified timm: the backbone
        # then returns both the final and the stage-3 feature maps.
        self.backbone = timm.create_model(model_arch, num_classes=0, pretrained=False, out_stage3=True)
        if pretrain:
            if model_arch == 'convnextv2_base.fcmae_ft_in22k_in1k_384':
                checkpoints = BASE_CP
            elif model_arch == 'convnext_large_mlp':
                checkpoints = LARGE_CP
            else:
                assert False, 'pretrain weight not found'
            print('use pretrain weight:', checkpoints)

            state_dict = torch.load(checkpoints)
            self.backbone.load_state_dict(state_dict, strict=False)
            del state_dict

        # Main classifier on the final backbone features.
        self.head = NormMlpClassifierHead(
            in_features=1536,
            num_classes=num_classes,
            hidden_size=1536,
            pool_type='avg',
            drop_rate=0.0,
            norm_layer=LayerNorm2d,
            act_layer='gelu',
        )
        self.expert_venomous = expert(model_arch, num_classes, pretrain=pretrain)
        self.expert_not_venomous = expert(model_arch, num_classes, pretrain=pretrain)
        # Binary gate over concatenated pooled stage-3 and stage-4 features.
        self.venomous_head = nn.Linear(768 + 1536, 1, bias=False)
        torch.nn.init.xavier_uniform_(self.venomous_head.weight)
        # mask: float tensor of shape [num_classes], 1 for venomous classes.
        self.venomous_mask = mask
        self.not_venomous_mask = torch.ones_like(mask) - mask

    def forward(self, x):
        out4, out3 = self.backbone(x)
        feat = torch.cat([F.adaptive_max_pool2d(out3, 1).flatten(1),
                          F.adaptive_max_pool2d(out4, 1).flatten(1)], dim=-1)
        is_venomous = self.venomous_head(feat)
        # Gate probability; computed here but not applied to the experts below.
        alpha = torch.sigmoid(is_venomous)
        # Each expert only scores the classes its mask selects.
        venomous = self.expert_venomous(out3) * self.venomous_mask.to(x.device)
        not_venomous = self.expert_not_venomous(out3) * self.not_venomous_mask.to(x.device)
        y_hat = self.head(out4)

        expert_pred = venomous + not_venomous
        final_pred = y_hat + expert_pred
        return y_hat, expert_pred, is_venomous, final_pred
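

# Sketch of how the class mask drives the two experts (values hypothetical):
# with num_classes=5 and venomous classes {0, 2},
#
#   mask = torch.tensor([1., 0., 1., 0., 0.])
#   model = Moe('convnext_large_mlp', 5, mask, pretrain=False)
#
# the venomous expert's logits survive only at indices 0 and 2, the other
# expert's only at 1, 3, 4, so expert_pred covers every class exactly once.

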
class SeesawLossWithLogits(nn.Module):
    """
    Unofficial implementation of the Seesaw loss proposed in the technical
    report for the LVIS workshop at ECCV 2020.
    For more detail, please refer to https://arxiv.org/pdf/2008.10032.pdf.

    Args:
        class_counts: Number of training samples for each class.
            Must have the same length as num_classes.
        p: Scale parameter that adjusts the strength of the punishment
            applied to over-represented classes. Set to 0.8 by default,
            following the original paper.
    """

    def __init__(self, class_counts: np.ndarray, num_classes, p: float = 0.8):
        super().__init__()

        class_counts = torch.FloatTensor(class_counts)
        # Mitigation factor S[i, j] = (N_j / N_i)^p when class i outnumbers
        # class j, and 1 otherwise.
        conditions = class_counts[:, None] > class_counts[None, :]
        trues = (class_counts[None, :] / class_counts[:, None]) ** p
        falses = torch.ones(len(class_counts), len(class_counts))
        self.s = torch.where(conditions, trues, falses)
        self.num_classes = num_classes
        self.eps = 1.0e-6

    def forward(self, logits, targets):
        targets = F.one_hot(targets, num_classes=self.num_classes).float()
        self.s = self.s.to(targets.device)
        # Subtract the row-wise max for numerical stability.
        max_element, _ = logits.max(dim=-1)
        logits = logits - max_element[:, None]

        numerator = torch.exp(logits)
        denominator = (
            (1 - targets)[:, None, :]
            * self.s[None, :, :]
            * torch.exp(logits)[:, None, :]).sum(dim=-1) \
            + torch.exp(logits)

        sigma = numerator / (denominator + self.eps)
        loss = (- targets * torch.log(sigma + self.eps)).sum(-1)
        return loss.mean()
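

# A minimal usage sketch on dummy data (counts and shapes hypothetical):
#
#   counts = np.array([500, 50, 5])            # long-tailed class counts
#   criterion = SeesawLossWithLogits(counts, num_classes=3)
#   logits = torch.randn(4, 3)
#   targets = torch.tensor([0, 1, 2, 0])
#   loss = criterion(logits, targets)          # scalar tensor

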
class all_loss(nn.Module):
    """Combined objective: Seesaw loss on the main head, on the summed
    expert logits, and on the fused prediction, averaged equally."""

    def __init__(self, class_counts: np.ndarray, num_classes):
        super().__init__()
        self.main_loss = SeesawLossWithLogits(class_counts, num_classes)
        self.venomous_loss = SeesawLossWithLogits(class_counts, num_classes)
        self.final_pred_loss = SeesawLossWithLogits(class_counts, num_classes)

    def forward(self, y_hat, expert_pred, alpha, final_pred, targets, is_venomous):
        # alpha and is_venomous are accepted for interface compatibility but
        # not used; no gating loss (a would-be loss3) is computed here.
        loss1 = self.main_loss(y_hat, targets)
        loss2 = self.venomous_loss(expert_pred, targets)
        loss4 = self.final_pred_loss(final_pred, targets)

        return (loss1 + loss2 + loss4) / 3
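

if __name__ == '__main__':
    # Hedged smoke test, not part of the original training pipeline: every
    # size, count, and mask value below is a hypothetical placeholder, and
    # the backbone still requires the locally modified timm on sys.path.
    num_classes = 10
    mask = torch.zeros(num_classes)
    mask[:4] = 1.0  # pretend the first four classes are venomous

    model = Moe('convnext_large_mlp', num_classes, mask, pretrain=False)
    criterion = all_loss(np.random.randint(1, 1000, size=num_classes), num_classes)

    images = torch.randn(2, 3, 384, 384)
    labels = torch.randint(0, num_classes, (2,))
    y_hat, expert_pred, is_venomous, final_pred = model(images)
    alpha = torch.sigmoid(is_venomous)
    loss = criterion(y_hat, expert_pred, alpha, final_pred, labels, is_venomous)
    print('loss:', loss.item())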