# SnakeCLEF2024 / moe.py
import sys
# use the locally modified timm checkout (pytorch-image-models); the stock release
# does not expose the out_stage3 option used by Moe below
sys.path.insert(0, '/data1/PycharmProjects/FGVC11/pytorch-image-models-main')
# sys.path.insert(0, '/data1/PycharmProjects/FGVC11/submision/pytorch-image-models-main')
# sys.path.insert(0, '/data1/PycharmProjects/FGVC11/submision/pytorch-image-models')
import timm
import torch
from torch import nn
import torch.nn.functional as F
from timm.layers import LayerNorm2d, LayerNorm, NormMlpClassifierHead, ClassifierHead
from timm.models.convnext import ConvNeXtStage
import numpy as np
BASE_CP = ''   # path to pretrained convnextv2_base backbone weights
LARGE_CP = ''  # path to pretrained convnext_large_mlp backbone weights
class expert(nn.Module):
    """Per-group expert: the final ConvNeXt stage plus its own classifier head."""
    def __init__(self, model_arch, num_classes, pretrain=True) -> None:
        super().__init__()
        # stage operating on the backbone's stage-3 feature map (768 -> 1536 channels)
        self.model = ConvNeXtStage(
            in_chs=768,
            out_chs=1536,
            kernel_size=7,
            stride=2,
            dilation=(1, 1),
            depth=3,
            drop_path_rates=[0.0, 0.0, 0.0],
            ls_init_value=1e-6,
            conv_mlp=False,
            conv_bias=True,
            use_grn=False,
            act_layer='gelu',
            norm_layer=LayerNorm2d,
            norm_layer_cl=LayerNorm,
        )
        self.cls_head = NormMlpClassifierHead(
            in_features=1536,
            num_classes=num_classes,
            hidden_size=1536,
            pool_type='avg',  # 'max' is the alternative
            drop_rate=0.0,
            norm_layer=LayerNorm2d,
            act_layer='gelu',
        )
        # load the last-stage weights of a pretrained backbone checkpoint into the expert
        if pretrain:
            # only the convnext_large_mlp expert branch is supported here
            if model_arch == 'convnext_large_mlp':
                checkpoints = LARGE_CP
            else:
                assert False, 'pretrain weight not found'
            print('use pretrain weight:', checkpoints)
            state_dict = torch.load(checkpoints)
            # strip the 'module.' prefix left over from DataParallel/DDP checkpoints
            for key in list(state_dict.keys()):
                if key.startswith('module.'):
                    new_key = key[7:]
                    state_dict[new_key] = state_dict[key]
                    del state_dict[key]
            # keep only the weights of the last ConvNeXt stage ('stages.3.') and
            # re-key them to match this standalone ConvNeXtStage module
            for key in list(state_dict.keys()):
                if 'stages.3.' not in key:
                    del state_dict[key]
                elif key.startswith('stages.3.'):
                    new_key = key[9:]
                    state_dict[new_key] = state_dict[key]
                    del state_dict[key]
            self.model.load_state_dict(state_dict, strict=True)
            del state_dict
def forward(self, out_stage3):
out = self.model(out_stage3)
out = self.cls_head(out)
return out
class Moe(nn.Module):
    """Mixture-of-experts classifier: shared backbone plus venomous / non-venomous experts."""
    def __init__(self, model_arch, num_classes, mask, pretrain=True) -> None:
        super().__init__()
        # out_stage3=True makes the (locally modified) timm backbone also return the
        # stage-3 feature map, which is fed to the two experts below
        self.backbone = timm.create_model(model_arch, num_classes=0, pretrained=False, out_stage3=True)
        if pretrain:
            if model_arch == 'convnextv2_base.fcmae_ft_in22k_in1k_384':
                checkpoints = BASE_CP
            elif model_arch == 'convnext_large_mlp':
                checkpoints = LARGE_CP
            else:
                assert False, 'pretrain weight not found'
            print('use pretrain weight:', checkpoints)
            state_dict = torch.load(checkpoints)
            self.backbone.load_state_dict(state_dict, strict=False)
            del state_dict
        # main classifier head on the backbone's final (stage-4) feature map
        self.head = NormMlpClassifierHead(
            in_features=1536,
            num_classes=num_classes,
            hidden_size=1536,
            pool_type='avg',  # 'max' is the alternative
            drop_rate=0.0,
            norm_layer=LayerNorm2d,
            act_layer='gelu',
        )
        self.expert_venomous = expert(model_arch, num_classes)
        self.expert_not_venomous = expert(model_arch, num_classes)
        # binary gate predicting "venomous" from pooled stage-3 + stage-4 features
        self.venomous_head = nn.Linear(768 + 1536, 1, bias=False)
        torch.nn.init.xavier_uniform_(self.venomous_head.weight)
        # mask: 1 for venomous classes, 0 otherwise; the complement masks the other expert
        self.venomous_mask = mask
        self.not_venomous_mask = torch.ones_like(mask) - mask
    def forward(self, x):
        out4, out3 = self.backbone(x)  # stage-4 and stage-3 feature maps
        feat = torch.cat([F.adaptive_max_pool2d(out3, 1).flatten(1),
                          F.adaptive_max_pool2d(out4, 1).flatten(1)], dim=-1)
        is_venomous = self.venomous_head(feat)
        alpha = torch.sigmoid(is_venomous)  # soft gate (unused while the masked sum below is active)
        # each expert only scores the classes of its own group
        venomous = self.expert_venomous(out3) * self.venomous_mask.to(x.device)
        not_venomous = self.expert_not_venomous(out3) * self.not_venomous_mask.to(x.device)
        y_hat = self.head(out4)
        # expert_pred = venomous * alpha + not_venomous * (1 - alpha)
        expert_pred = venomous + not_venomous
        final_pred = y_hat + expert_pred
        return y_hat, expert_pred, is_venomous, final_pred
class SeesawLossWithLogits(nn.Module):
    """
    An unofficial implementation of Seesaw loss, which was proposed in the
    technical report for the LVIS workshop at ECCV 2020.
    For more detail, please refer to https://arxiv.org/pdf/2008.10032.pdf.
    Args:
        class_counts: The list with the number of samples for each class.
            Should have the same length as num_classes.
        p: Scale parameter which adjusts the strength of the punishment.
            Set to 0.8 by default, following the original paper.
    """
    def __init__(self, class_counts: np.ndarray, num_classes, p: float = 0.8):
        super().__init__()
        class_counts = torch.FloatTensor(class_counts)
        # S[i, j] = (N_j / N_i) ** p when N_i > N_j, otherwise 1: logits of rarer
        # negative classes are down-weighted in the denominator of frequent classes
        conditions = class_counts[:, None] > class_counts[None, :]
        trues = (class_counts[None, :] / class_counts[:, None]) ** p
        falses = torch.ones(len(class_counts), len(class_counts))
        self.s = torch.where(conditions, trues, falses)
        self.num_classes = num_classes
        self.eps = 1.0e-6
    def forward(self, logits, targets):
        targets = nn.functional.one_hot(targets, num_classes=self.num_classes).float().to(targets.device)
        self.s = self.s.to(targets.device)
        max_element, _ = logits.max(axis=-1)
        logits = logits - max_element[:, None]  # subtract the per-row max to prevent overflow
        numerator = torch.exp(logits)
        # seesaw denominator: every negative class j of sample b contributes S[i, j] * exp(z_j)
        denominator = (
            (1 - targets)[:, None, :]
            * self.s[None, :, :]
            * torch.exp(logits)[:, None, :]).sum(axis=-1) \
            + torch.exp(logits)
        sigma = numerator / (denominator + self.eps)
        loss = (- targets * torch.log(sigma + self.eps)).sum(-1)
        return loss.mean()
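# A small worked example of the seesaw matrix (illustrative counts, not from the
# original pipeline): with class_counts = [1000, 10] and p = 0.8,
#     >>> SeesawLossWithLogits(np.array([1000, 10]), num_classes=2).s
#     tensor([[1.0000, 0.0251],
#             [1.0000, 1.0000]])
# so, when the frequent class 0 is the target, the logit of the rare class 1 enters the
# softmax denominator with only ~2.5% of its usual weight ((10 / 1000) ** 0.8 ≈ 0.0251),
# reducing the punishment the rare class receives from the frequent one.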
class all_loss(nn.Module):
    """Combined objective: Seesaw loss on the main head, the expert sum, and the fused prediction."""
    def __init__(self, class_counts: np.ndarray, num_classes):
        super().__init__()
        self.main_loss = SeesawLossWithLogits(class_counts, num_classes)
        self.venomous_loss = SeesawLossWithLogits(class_counts, num_classes)
        self.final_pred_loss = SeesawLossWithLogits(class_counts, num_classes)
        # self.venomous_loss = nn.CrossEntropyLoss()
        # self.alpha_loss = nn.BCEWithLogitsLoss()
        # self.final_pred_loss = nn.CrossEntropyLoss()
    def forward(self, y_hat, expert_pred, alpha, final_pred, targets, is_venomous):
        # alpha and is_venomous are kept in the signature for the (disabled) gate loss below
        loss1 = self.main_loss(y_hat, targets)
        loss2 = self.venomous_loss(expert_pred, targets)
        # loss3 = self.alpha_loss(alpha, is_venomous.unsqueeze(1))
        loss4 = self.final_pred_loss(final_pred, targets)
        return (loss1 + loss2 + loss4) / 3
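

# Minimal smoke-test sketch (not part of the original training code). It exercises the
# two experts and the combined Seesaw objective on random tensors with made-up shapes,
# mask, and class counts; the full Moe model is not built here because it additionally
# needs the locally modified timm (out_stage3) and the BASE_CP / LARGE_CP checkpoints.
if __name__ == '__main__':
    num_classes = 10
    # hypothetical venomous / non-venomous split over the classes
    mask = torch.zeros(num_classes)
    mask[:4] = 1.0

    # the experts consume a stage-3 ConvNeXt feature map of shape (B, 768, H, W)
    out_stage3 = torch.randn(2, 768, 12, 12)
    expert_venomous = expert('convnext_large_mlp', num_classes, pretrain=False)
    expert_not_venomous = expert('convnext_large_mlp', num_classes, pretrain=False)
    expert_pred = (expert_venomous(out_stage3) * mask
                   + expert_not_venomous(out_stage3) * (1 - mask))
    print('expert_pred shape:', expert_pred.shape)  # (2, num_classes)

    # combined loss on dummy predictions; alpha / is_venomous are unused while the
    # gate loss inside all_loss stays commented out, so None is passed for both
    class_counts = np.random.randint(1, 1000, size=num_classes)
    criterion = all_loss(class_counts, num_classes)
    targets = torch.randint(0, num_classes, (2,))
    y_hat = torch.randn(2, num_classes)
    loss = criterion(y_hat, expert_pred, None, y_hat + expert_pred, targets, None)
    print('loss:', loss.item())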