import sys |
sys.path.insert(0, '/data1/PycharmProjects/FGVC11/pytorch-image-models-main') |
import timm |
import torch |
from torch import nn |
import torch.nn.functional as F |
from timm.layers import LayerNorm2d, LayerNorm,NormMlpClassifierHead, ClassifierHead |
from timm.models.convnext import ConvNeXtStage |
import numpy as np |
# Filesystem path of the checkpoint used for 'convnext_large_mlp'.
# It is empty here, so it must be set to a real checkpoint file before any
# model in this module is built with pretrained weights.
LARGE_CP = ''
class expert(nn.Module):
    """Expert branch: one ConvNeXt stage followed by a norm-MLP classifier head.

    The stage takes 768-channel feature maps and produces 1536-channel maps
    (stride 2, depth 3); the head pools them into ``num_classes`` logits.

    Args:
        model_arch: Architecture name. Only ``'convnext_large_mlp'`` has a
            checkpoint mapping in this class.
        num_classes: Number of output logits.
        pretrain: When true (default), initialise the stage from the
            ``stages.3.*`` entries of the checkpoint at ``LARGE_CP``.
            When false, the randomly initialised weights are kept.

    Raises:
        ValueError: If ``pretrain`` is true and ``model_arch`` has no known
            checkpoint.
    """

    def __init__(self, model_arch, num_classes, pretrain=True) -> None:
        super().__init__()
        self.model = ConvNeXtStage(
            in_chs=768,
            out_chs=1536,
            kernel_size=7,
            stride=2,
            dilation=(1, 1),
            depth=3,
            drop_path_rates=[0.0, 0.0, 0.0],
            ls_init_value=1e-6,
            conv_mlp=False,
            conv_bias=True,
            use_grn=False,
            act_layer='gelu',
            norm_layer=LayerNorm2d,
            norm_layer_cl=LayerNorm,
        )
        self.cls_head = NormMlpClassifierHead(
            in_features=1536,
            num_classes=num_classes,
            hidden_size=1536,
            pool_type='avg',
            drop_rate=0.0,
            norm_layer=LayerNorm2d,
            act_layer='gelu',
        )
        if pretrain:
            self._load_stage_weights(model_arch)

    @staticmethod
    def _stage3_state(state_dict):
        """Return the ``stages.3`` entries of ``state_dict`` with prefixes stripped."""
        module_prefix = 'module.'
        stage_prefix = 'stages.3.'
        stage_state = {}
        for key, value in state_dict.items():
            # Checkpoints saved from a (Distributed)DataParallel wrapper
            # carry a leading 'module.' on every key.
            if key.startswith(module_prefix):
                key = key[len(module_prefix):]
            if stage_prefix not in key:
                continue
            if key.startswith(stage_prefix):
                key = key[len(stage_prefix):]
            stage_state[key] = value
        return stage_state

    def _load_stage_weights(self, model_arch):
        """Load the last-stage weights of the pretrained backbone into ``self.model``."""
        if model_arch != 'convnext_large_mlp':
            # An explicit raise survives ``python -O``, unlike ``assert``.
            raise ValueError('pretrain weight not found')
        checkpoints = LARGE_CP
        print('use pretrain weight:', checkpoints)
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        # map_location='cpu' avoids needing the device the file was saved
        # from; load_state_dict copies into the existing parameters anyway.
        state_dict = torch.load(checkpoints, map_location='cpu')
        self.model.load_state_dict(self._stage3_state(state_dict), strict=True)

    def forward(self, out_stage3):
        """Run the stage and the classifier head on the given feature maps.

        Args:
            out_stage3: Input of the ConvNeXt stage (768 channels, as fixed
                by ``in_chs`` above).

        Returns:
            The classifier head output for the batch.
        """
        out = self.model(out_stage3)
        out = self.cls_head(out)
        return out
class Moe(nn.Module):
    """Backbone with a main head plus two class-masked expert branches.

    The backbone is expected to return ``(out4, out3)``. ``out4`` feeds the
    main head; ``out3`` feeds both experts, whose logits are multiplied by
    ``mask`` and ``1 - mask`` respectively and then summed.

    Args:
        model_arch: timm architecture name passed to ``timm.create_model``.
        num_classes: Number of output logits of every head.
        mask: Tensor multiplied onto the "venomous" expert logits
            (presumably a 0/1 indicator per class; verify against caller).
            Its complement ``1 - mask`` is applied to the other expert.
        pretrain: When true (default), load backbone and expert weights
            from the checkpoint configured for ``model_arch``.

    Raises:
        ValueError: If ``pretrain`` is true and ``model_arch`` has no known
            checkpoint.
    """

    def __init__(self, model_arch, num_classes, mask, pretrain=True) -> None:
        super().__init__()
        # The backbone is always needed by forward(); only the weight
        # loading below depends on ``pretrain``.
        self.backbone = timm.create_model(
            model_arch, num_classes=0, pretrained=False, out_stage3=True)
        if pretrain:
            if model_arch == 'convnextv2_base.fcmae_ft_in22k_in1k_384':
                # NOTE(review): BASE_CP is not defined in the part of the
                # file visible here; confirm it exists at module level.
                checkpoints = BASE_CP
            elif model_arch == 'convnext_large_mlp':
                checkpoints = LARGE_CP
            else:
                # An explicit raise survives ``python -O``, unlike ``assert``.
                raise ValueError('pretrain weight not found')
            print('use pretrain weight:', checkpoints)
            # NOTE: torch.load unpickles the file; only load trusted
            # checkpoints. Load on CPU so the saving device is not required.
            state_dict = torch.load(checkpoints, map_location='cpu')
            # strict=False: keys that do not match the backbone are skipped.
            self.backbone.load_state_dict(state_dict, strict=False)
            del state_dict
        # NOTE(review): 768 / 1536 below are hard-coded channel widths.
        self.head = NormMlpClassifierHead(
            in_features=1536,
            num_classes=num_classes,
            hidden_size=1536,
            pool_type='avg',
            drop_rate=0.0,
            norm_layer=LayerNorm2d,
            act_layer='gelu',
        )
        self.expert_venomous = expert(model_arch, num_classes, pretrain=pretrain)
        self.expert_not_venomous = expert(model_arch, num_classes, pretrain=pretrain)
        self.venomous_head = nn.Linear(768 + 1536, 1, bias=False)
        torch.nn.init.xavier_uniform_(self.venomous_head.weight)
        self.venomous_mask = mask
        self.not_venomous_mask = torch.ones_like(mask) - mask

    def forward(self, x):
        """Compute main, expert and combined logits for a batch.

        Args:
            x: Input batch forwarded to the backbone.

        Returns:
            Tuple ``(y_hat, expert_pred, is_venomous, final_pred)``:
            main-head logits, masked sum of both expert logits, the raw
            (pre-sigmoid) output of ``venomous_head``, and
            ``y_hat + expert_pred``.
        """
        out4, out3 = self.backbone(x)
        # Global max-pool both feature maps and concatenate them as the
        # input of the single-logit venomous head.
        feat = torch.cat(
            [F.adaptive_max_pool2d(out3, 1).flatten(1),
             F.adaptive_max_pool2d(out4, 1).flatten(1)],
            dim=-1)
        is_venomous = self.venomous_head(feat)
        venomous = self.expert_venomous(out3) * self.venomous_mask.to(x.device)
        not_venomous = (self.expert_not_venomous(out3)
                        * self.not_venomous_mask.to(x.device))
        y_hat = self.head(out4)
        expert_pred = venomous + not_venomous
        final_pred = y_hat + expert_pred
        return y_hat, expert_pred, is_venomous, final_pred
class SeesawLossWithLogits(nn.Module):
    """
    This is an unofficial implementation of Seesaw loss,
    which was proposed in the technical report for the LVIS workshop at ECCV 2020.
    For more detail, please refer to https://arxiv.org/pdf/2008.10032.pdf.

    Only a class-count based re-weighting of the negative classes (with
    exponent ``p``) is implemented here.

    Args:
        class_counts: The list which has the number of samples for each class.
            Should have the same length as num_classes.
        num_classes: Number of classes (size of the last logits dimension).
        p: Scale parameter which adjusts the strength of punishment.
            Set to 0.8 as a default by following the original paper.

    Raises:
        ValueError: If ``class_counts`` does not have ``num_classes`` entries.
    """

    def __init__(self, class_counts: np.array, num_classes, p: float = 0.8):
        super().__init__()
        counts = torch.as_tensor(class_counts, dtype=torch.float32)
        if len(counts) != num_classes:
            raise ValueError(
                'class_counts must have num_classes entries, got '
                f'{len(counts)} for num_classes={num_classes}')
        # s[i, j] = (N_j / N_i) ** p when class i has more samples than
        # class j, else 1.
        more_frequent = counts[:, None] > counts[None, :]
        ratio = (counts[None, :] / counts[:, None]) ** p
        s = torch.where(more_frequent, ratio, torch.ones_like(ratio))
        # Non-persistent buffer: follows .to()/.cuda() of the module but is
        # kept out of state_dict, so existing checkpoints are unaffected.
        self.register_buffer('s', s, persistent=False)
        self.num_classes = num_classes
        self.eps = 1.0e-6

    def forward(self, logits, targets):
        """Return the mean Seesaw loss of a batch.

        Args:
            logits: Tensor indexed as ``(batch, class)``.
            targets: Integer class indices, one per batch row.

        Returns:
            Scalar tensor: the loss averaged over the batch.
        """
        # Kept for callers that never move the loss module to the device.
        if self.s.device != logits.device:
            self.s = self.s.to(logits.device)
        one_hot = nn.functional.one_hot(
            targets, num_classes=self.num_classes).float()
        # Subtract the row maximum so exp() cannot overflow.
        shifted = logits - logits.max(dim=-1, keepdim=True).values
        exp_logits = torch.exp(shifted)
        target_exp = exp_logits.gather(-1, targets.unsqueeze(-1)).squeeze(-1)
        # Only the row of ``s`` belonging to each sample's target class
        # contributes to the loss, so gather that row instead of building
        # a (batch, class, class) intermediate.
        target_rows = self.s[targets]
        negatives = ((1.0 - one_hot) * target_rows * exp_logits).sum(dim=-1)
        denominator = negatives + target_exp
        sigma = target_exp / (denominator + self.eps)
        loss = -torch.log(sigma + self.eps)
        return loss.mean()
class all_loss(nn.Module):
    """Average of three Seesaw losses: main head, expert sum and combined logits.

    Args:
        class_counts: Per-class sample counts handed to every Seesaw loss.
        num_classes: Number of classes handed to every Seesaw loss.
    """

    def __init__(self, class_counts: np.array, num_classes):
        super().__init__()
        # One independent criterion per supervised output, registered in
        # this fixed order.
        for attr_name in ('main_loss', 'venomous_loss', 'final_pred_loss'):
            setattr(self, attr_name,
                    SeesawLossWithLogits(class_counts, num_classes))

    def forward(self, y_hat, expert_pred, alpha, final_pred, targets, is_venomous):
        """Return the mean of the three per-output losses.

        ``alpha`` and ``is_venomous`` are accepted but do not enter the
        returned value.
        """
        main_term = self.main_loss(y_hat, targets)
        expert_term = self.venomous_loss(expert_pred, targets)
        combined_term = self.final_pred_loss(final_pred, targets)
        total = main_term + expert_term + combined_term
        return total / 3