import sys
sys.path.insert(0, './pytorch-image-models-main')
#######################################
from moe import Moe, all_loss
#######################################
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "4,5,6,7"  # alternatively 0,1,2,3
import torch
import cv2
from albumentations.pytorch import ToTensorV2
from albumentations import (
    HorizontalFlip, VerticalFlip, ShiftScaleRotate, CLAHE, RandomRotate90,
    Transpose, Blur, OpticalDistortion, GridDistortion, HueSaturationValue,
    GaussNoise, MotionBlur, MedianBlur, PiecewiseAffine, RandomResizedCrop,
    RandomBrightnessContrast, Flip, OneOf, Compose, Normalize, CoarseDropout,
    CenterCrop, Resize, SmallestMaxSize
)
import time
# import timm
import torch.multiprocessing as mp
import torch.distributed as dist
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from torch.cuda.amp import autocast, GradScaler
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam, SGD, AdamW, RMSprop
from torch import nn
import random
from tqdm import tqdm
from PIL import Image
import numpy as np
import logging
from sklearn.model_selection import GroupKFold, StratifiedKFold
import pandas as pd
import math

CFG = {
    'seed': 42,  # 719, 42, 68
    'model_arch': 'convnext_large_mlp',  # alternative: convnextv2_base.fcmae_ft_in22k_in1k_384
    'patch': 16,
    'mean': [0.485, 0.456, 0.406],
    'std': [0.229, 0.224, 0.225],
    'mix_type': 'cutmix',  # cutmix, mixup, tokenmix, randommix, none
    'mix_prob': 0.7,
    'img_size': 512,
    'class_num': 1784,
    'warmup_epochs': 1,
    'warmup_lr_factor': 0.01,  # warmup_lr = lr * warmup_lr_factor
    'epochs': 11,  # convnext-large v2: 11 epochs, without the alpha term
    'train_bs': 24,
    'valid_bs': 64,
    'lr': 7.5e-5,
    'min_lr': 1e-5,  # 1e-6
    'differLR': False,
    # 'backbone_lr_factor': 0.2,  # if 'differLR' is True, the backbone lr will be lr * backbone_lr_factor
    'head_lr': 0,  # used when 'differLR' is True
    'head_wd': 0.05,
    'num_workers': 8,
    'device': 'cuda',
    'smoothing': 0.1,  # label smoothing
    'weight_decay': 2e-5,
    'accum_iter': 1,  # supports gradient accumulation for an effectively larger batch size
    'verbose_step': 1,  # the step interval for printing the loss
}

logger = logging.getLogger(__name__)
logger.setLevel(level=logging.INFO)
handler = logging.FileHandler(f"logs/{CFG['model_arch']}_train_moe.log")
handler.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
handler.setFormatter(formatter)
logger.addHandler(handler)


def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True


def get_img(path):
    # print(path)
    im_bgr = cv2.imread(path)
    im_rgb = im_bgr[:, :, ::-1]
    return im_rgb


train_data_root = '/data1/dataset/SnakeCLEF2024/'
val_data_root = '/data1/dataset/SnakeCLEF2023/val/SnakeCLEF2023-large_size/'
train_df = pd.read_csv('./metadata/train_full.csv')
valid_df = pd.read_csv('./metadata/SnakeCLEF2023-ValMetadata.csv')
is_venomous_df = pd.read_csv('./metadata/venomous_status_list.csv')

class_id2venomous = {}
venomous_mask = torch.ones(CFG['class_num'])
for class_id, is_venomous in zip(is_venomous_df['class_id'], is_venomous_df['MIVS']):
    venomous_mask[class_id] = is_venomous
    if class_id not in class_id2venomous.keys():
        class_id2venomous[class_id] = is_venomous

train_df['MIVS'] = train_df['class_id'].map(class_id2venomous)
valid_df['MIVS'] = valid_df['class_id'].map(class_id2venomous)
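# Note on the metadata layout assumed above (inferred from how the columns are
# used in this script, not from an official schema): 'train_full.csv' and
# 'SnakeCLEF2023-ValMetadata.csv' must provide 'class_id' and 'image_path'
# columns, and 'venomous_status_list.csv' must map each 'class_id' to a binary
# 'MIVS' flag. 'venomous_mask' is a length-class_num tensor of these flags,
# indexed by class_id, and is later handed to the Moe model.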
class FGVCDataset(Dataset):
    def __init__(self, df, data_root, transforms=None, output_label=True, one_hot_label=False):
        super().__init__()
        self.df = df.reset_index(drop=True).copy()
        self.transforms = transforms
        self.data_root = data_root
        self.output_label = output_label
        self.one_hot_label = one_hot_label

        if output_label == True:
            self.labels = self.df['class_id'].values
            self.is_venomous = self.df['MIVS']
            if one_hot_label is True:
                self.labels = np.eye(self.df['class_id'].max() + 1)[self.labels]

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, index: int):
        # get labels
        if self.output_label:
            target = self.labels[index]
            venomous = self.is_venomous[index]

        image_path = self.data_root + self.df.loc[index]['image_path']
        img = get_img(image_path)

        if self.transforms:
            img = self.transforms(image=img)['image']

        if self.output_label == True:
            return img, target, venomous
        else:
            return img


def get_train_transforms():
    return Compose([
        RandomResizedCrop(CFG['img_size'], CFG['img_size'], interpolation=cv2.INTER_CUBIC, scale=(0.5, 1.3)),
        Transpose(p=0.5),
        HorizontalFlip(p=0.5),
        VerticalFlip(p=0.5),
        ShiftScaleRotate(p=0.3),
        PiecewiseAffine(p=0.5),
        RandomBrightnessContrast(brightness_limit=(-0.2, 0.2), contrast_limit=(-0.2, 0.2), p=1.0),
        OneOf([
            OpticalDistortion(distort_limit=1.0),
            GridDistortion(num_steps=5, distort_limit=1.),
        ], p=0.5),
        Normalize(mean=CFG['mean'], std=CFG['std'], max_pixel_value=255.0, p=1.0),
        ToTensorV2(p=1.0),
    ], p=1.)


def get_valid_transforms():
    return Compose([
        # SmallestMaxSize(CFG['img_size']),
        Resize(CFG['img_size'], CFG['img_size'], interpolation=cv2.INTER_CUBIC),
        # CenterCrop(CFG['img_size'], CFG['img_size']),
        Normalize(mean=CFG['mean'], std=CFG['std'], max_pixel_value=255.0, p=1.0),
        ToTensorV2(p=1.0),
    ], p=1.)


def prepare_dataloader(train_df, val_df, train_idx, val_idx):
    train_ = train_df.loc[train_idx, :].reset_index(drop=True)
    valid_ = val_df.loc[val_idx, :].reset_index(drop=True)

    train_ds = FGVCDataset(train_, train_data_root, transforms=get_train_transforms())
    valid_ds = FGVCDataset(valid_, val_data_root, transforms=get_valid_transforms())

    train_loader = torch.utils.data.DataLoader(
        train_ds,
        batch_size=CFG['train_bs'],
        pin_memory=False,
        drop_last=False,
        shuffle=True,
        num_workers=CFG['num_workers'],
    )
    val_loader = torch.utils.data.DataLoader(
        valid_ds,
        batch_size=CFG['valid_bs'],
        num_workers=CFG['num_workers'],
        shuffle=False,
        pin_memory=False,
    )
    return train_loader, val_loader


def rand_bbox(size, lam):
    W = size[2]
    H = size[3]
    cut_rat = np.sqrt(1. - lam)
    cut_w = np.int32(W * cut_rat)
    cut_h = np.int32(H * cut_rat)

    # uniform
    cx = np.random.randint(W)
    cy = np.random.randint(H)

    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)

    return bbx1, bby1, bbx2, bby2


def generate_mask_random(imgs, patch=CFG['patch'], mask_token_num_start=14, lam=0.5):
    _, _, W, H = imgs.shape
    assert W % patch == 0
    assert H % patch == 0
    p = W // patch
    mask_ratio = 1 - lam
    num_masking_patches = min(p**2, int(mask_ratio * (p**2)) + mask_token_num_start)
    mask_idx = np.random.permutation(p**2)[:num_masking_patches]
    lam = 1 - num_masking_patches / (p**2)
    return mask_idx, lam


def get_mixed_data(imgs, image_labels, is_venomous, mix_type):
    mix_lst = ['cutmix', 'tokenmix', 'mixup', 'randommix']
    assert mix_type in mix_lst, f'Not supported mix type: {mix_type}'
    if mix_type == 'randommix':
        # randomly choose between cutmix and tokenmix
        mix_type = random.choice(mix_lst[:-2])

    if mix_type == 'mixup':
        alpha = 2.0
        rand_index = torch.randperm(imgs.size()[0]).cuda()
        target_a = image_labels
        target_b = image_labels[rand_index]
        # pair the venomous flags with the permuted labels
        is_venomous_a = is_venomous
        is_venomous_b = is_venomous[rand_index]
        lam = np.random.beta(alpha, alpha)
        imgs = imgs * lam + imgs[rand_index] * (1 - lam)
    elif mix_type == 'cutmix':
        beta = 1.0
        lam = np.random.beta(beta, beta)
        rand_index = torch.randperm(imgs.size()[0]).cuda()
        target_a = image_labels
        target_b = image_labels[rand_index]
        is_venomous_a = is_venomous
        is_venomous_b = is_venomous[rand_index]
        bbx1, bby1, bbx2, bby2 = rand_bbox(imgs.size(), lam)
        imgs[:, :, bbx1:bbx2, bby1:bby2] = imgs[rand_index, :, bbx1:bbx2, bby1:bby2]
        # adjust lambda to exactly match the pixel ratio
        lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (imgs.size()[-1] * imgs.size()[-2]))
    elif mix_type == 'tokenmix':
        B, C, W, H = imgs.shape
        mask_idx, lam = generate_mask_random(imgs)
        rand_index = torch.randperm(imgs.size()[0]).cuda()
        p = W // CFG['patch']
        patch_w = CFG['patch']
        patch_h = CFG['patch']
        for idx in mask_idx:
            row_s = idx // p
            col_s = idx % p
            x1 = patch_w * row_s
            x2 = x1 + patch_w
            y1 = patch_h * col_s
            y2 = y1 + patch_h
            imgs[:, :, x1:x2, y1:y2] = imgs[rand_index, :, x1:x2, y1:y2]
        target_a = image_labels
        target_b = image_labels[rand_index]
        # pair the venomous flags with the permuted labels
        is_venomous_a = is_venomous
        is_venomous_b = is_venomous[rand_index]

    return imgs, target_a, target_b, is_venomous_a, is_venomous_b, lam
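# The training loop below assumes the following interface from moe.py (the
# implementation is not shown in this file): model(imgs) returns a 4-tuple
# (y_hat, expert_pred, alpha, image_preds), where 'image_preds' are the logits
# used for accuracy, and all_loss is called as
#     loss_fn(y_hat, expert_pred, alpha, image_preds, targets, is_venomous).
# Under mixing, the loss is the usual convex combination:
#     lam * loss_fn(..., target_a, is_venomous_a) + (1 - lam) * loss_fn(..., target_b, is_venomous_b)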
def train_one_epoch_mix(epoch, model, loss_fn, optimizer, train_loader, device,
                        scheduler=None, schd_batch_update=False, mix_type=CFG['mix_type']):
    model.train()

    running_loss = None
    image_preds_all = []
    image_targets_all = []

    pbar = tqdm(enumerate(train_loader), total=len(train_loader), ncols=70)
    for step, (imgs, image_labels, is_venomous) in pbar:
        imgs = imgs.to(device).float()
        image_labels = image_labels.to(device).long()
        is_venomous = is_venomous.to(device).float()  # .long()

        if np.random.rand(1) < CFG['mix_prob']:
            imgs, target_a, target_b, is_venomous_a, is_venomous_b, lam = get_mixed_data(
                imgs, image_labels, is_venomous, mix_type)
            with autocast():
                # image_preds = model(imgs)
                # loss = loss_fn(image_preds, target_a) * lam + loss_fn(image_preds, target_b) * (1. - lam)
                y_hat, expert_pred, alpha, image_preds = model(imgs)
                loss = (loss_fn(y_hat, expert_pred, alpha, image_preds, target_a, is_venomous_a) * lam
                        + loss_fn(y_hat, expert_pred, alpha, image_preds, target_b, is_venomous_b) * (1.0 - lam))
            scaler.scale(loss).backward()
        else:
            with autocast():
                y_hat, expert_pred, alpha, image_preds = model(imgs)
                loss = loss_fn(y_hat, expert_pred, alpha, image_preds, image_labels, is_venomous)
            scaler.scale(loss).backward()

        image_preds_all += [torch.argmax(image_preds, 1).detach().cpu().numpy()]
        image_targets_all += [image_labels.detach().cpu().numpy()]

        if running_loss is None:
            running_loss = loss.item()
        else:
            running_loss = running_loss * .99 + loss.item() * .01
        # if running_loss > 10:
        #     print(epoch)

        if ((step + 1) % CFG['accum_iter'] == 0) or ((step + 1) == len(train_loader)):
            # may unscale_ here if desired (e.g., to allow clipping unscaled gradients)
            # torch.nn.utils.clip_grad_norm_(model.parameters(), 1e-8)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

            if scheduler is not None and schd_batch_update:
                scheduler.step()

        if ((step + 1) % CFG['verbose_step'] == 0) or ((step + 1) == len(train_loader)):
            description = f'epoch {epoch} loss: {running_loss:.4f}'
            pbar.set_description(description)

    image_preds_all = np.concatenate(image_preds_all)
    image_targets_all = np.concatenate(image_targets_all)
    accuracy = (image_preds_all == image_targets_all).mean()
    print('Train multi-class accuracy = {:.4f}'.format(accuracy))
    logger.info(' Epoch: ' + str(epoch) + ' Train multi-class accuracy = {:.4f}'.format(accuracy))
    logger.info(' Epoch: ' + str(epoch) + ' Train loss = {:.4f}'.format(running_loss))

    if scheduler is not None and not schd_batch_update:
        scheduler.step()


def valid_one_epoch(epoch, model, loss_fn, val_loader, device, scheduler=None, schd_loss_update=False):
    model.eval()

    loss_sum = 0
    sample_num = 0
    image_preds_all = []
    image_targets_all = []

    pbar = tqdm(enumerate(val_loader), total=len(val_loader), ncols=70)
    for step, (imgs, image_labels, is_venomous) in pbar:
        imgs = imgs.to(device).float()
        image_labels = image_labels.to(device).long()
        is_venomous = is_venomous.to(device).float()  # .long()

        # image_preds = model(imgs)
        y_hat, expert_pred, alpha, image_preds = model(imgs)
        image_preds_all += [torch.argmax(image_preds, 1).detach().cpu().numpy()]
        image_targets_all += [image_labels.detach().cpu().numpy()]

        # for open-set samples, remap the unknown label so the loss can be computed without errors
        openset_idx = image_labels == -1
        image_labels[openset_idx] = 0  # just assign class_id 0

        loss = loss_fn(image_preds, image_labels)

        loss_sum += loss.item() * image_labels.shape[0]
        sample_num += image_labels.shape[0]

        if ((step + 1) % CFG['verbose_step'] == 0) or ((step + 1) == len(val_loader)):
            description = f'epoch {epoch} loss: {loss_sum / sample_num:.4f}'
            pbar.set_description(description)

    image_preds_all = np.concatenate(image_preds_all)
    image_targets_all = np.concatenate(image_targets_all)
    accuracy = (image_preds_all == image_targets_all).mean()
    print('validation multi-class accuracy = {:.4f}'.format(accuracy))
    logger.info(' Epoch: ' + str(epoch) + ' validation multi-class accuracy = {:.4f}'.format(accuracy))

    if scheduler is not None:
        if schd_loss_update:
            scheduler.step(loss_sum / sample_num)
        else:
            scheduler.step()

    return accuracy


if __name__ == '__main__':
    # time.sleep(150 * 60)
    seed_everything(CFG['seed'])
    logger.info(CFG)

    trn_idx = np.arange(train_df.shape[0])
    val_idx = np.arange(valid_df.shape[0])

    df_class_id = np.array(train_df['class_id'])
    class_counts = np.bincount(df_class_id)

    device = torch.device(CFG['device'])
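    # 'class_counts' holds the per-class sample counts of the training set and is
    # passed to all_loss below; how the counts are used (e.g. for long-tail
    # re-weighting) is defined in moe.py, which is not shown here.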
    model = Moe(CFG['model_arch'], CFG['class_num'], venomous_mask)
    model = nn.DataParallel(model)
    model.to(device)
    # note: if these masks are plain tensors, .to(device) is not in-place;
    # whether these two calls are needed depends on how Moe registers them in moe.py
    model.module.not_venomous_mask.to(device)
    model.module.venomous_mask.to(device)

    train_loader, val_loader = prepare_dataloader(train_df, valid_df, trn_idx, val_idx)

    scaler = GradScaler()

    if CFG['differLR']:
        backbone_params = list(map(id, model.module.backbone.parameters()))
        head_params = filter(lambda p: id(p) not in backbone_params, model.parameters())
        if CFG['head_lr'] > 0:
            lr_cfg = [
                {'params': model.module.backbone.parameters(), 'lr': CFG['lr'], 'weight_decay': CFG['weight_decay']},
                {'params': head_params, 'lr': CFG['head_lr'], 'weight_decay': CFG['head_wd']},
            ]
            optimizer = torch.optim.AdamW(lr_cfg, lr=CFG['lr'], weight_decay=CFG['weight_decay'])
        else:
            # lr_cfg = [{'params': model.module.backbone.parameters(), 'lr': CFG['lr'], 'weight_decay': CFG['weight_decay']}]
            print('frozen center')
            # for param in model.module.center.parameters():
            model.module.center.requires_grad = False
            lr_cfg = [
                {'params': model.module.backbone.parameters(), 'lr': CFG['lr'], 'weight_decay': CFG['weight_decay']},
            ]
            # head_params = {name: param for name, param in named_parameters if id(param) not in backbone_params}
            optimizer = torch.optim.AdamW(lr_cfg, lr=CFG['lr'], weight_decay=CFG['weight_decay'])
    else:
        optimizer = torch.optim.AdamW(model.parameters(), lr=CFG['lr'], weight_decay=CFG['weight_decay'])

    main_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=CFG['epochs'] - CFG['warmup_epochs'], eta_min=CFG['min_lr']
    )
    warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(
        optimizer, start_factor=CFG['warmup_lr_factor'], total_iters=CFG['warmup_epochs']
    )
    scheduler = torch.optim.lr_scheduler.SequentialLR(
        optimizer, schedulers=[warmup_lr_scheduler, main_lr_scheduler], milestones=[CFG['warmup_epochs']]
    )

    loss_tr = all_loss(class_counts, CFG['class_num']).to(device)
    loss_fn = nn.CrossEntropyLoss(label_smoothing=CFG['smoothing']).to(device)

    best_acc = 0.0
    for epoch in range(CFG['epochs']):
        print(optimizer.param_groups[0]['lr'])
        train_one_epoch_mix(epoch, model, loss_tr, optimizer, train_loader, device, scheduler=scheduler)

        with torch.no_grad():
            temp_acc = valid_one_epoch(epoch, model, loss_fn, val_loader, device,
                                       scheduler=None, schd_loss_update=False)

        if temp_acc > best_acc:
            torch.save(
                model.state_dict(),
                './checkpoints_moe/moe_{}_mix_{}_mixprob_{}_seed_{}_ls_{}_epochs_{}_differLR_{}_imsize{}.pth'.format(
                    CFG['model_arch'], CFG['mix_type'], CFG['mix_prob'], CFG['seed'], CFG['smoothing'],
                    CFG['epochs'], CFG['differLR'], CFG['img_size']))
            best_acc = temp_acc

    del model, optimizer, train_loader, val_loader, scaler, scheduler
    print(best_acc)
    logger.info('BEST-Valid-ACC: ' + str(best_acc))
    torch.cuda.empty_cache()
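# A minimal sketch of how a checkpoint saved by this script could be reloaded
# for evaluation (the checkpoint filename below is illustrative; Moe and
# venomous_mask are assumed to be constructed exactly as during training).
# Because the weights are saved from an nn.DataParallel wrapper, the keys carry
# a 'module.' prefix, so the model is wrapped the same way before loading:
#
#     eval_model = Moe(CFG['model_arch'], CFG['class_num'], venomous_mask)
#     eval_model = nn.DataParallel(eval_model)
#     state_dict = torch.load('./checkpoints_moe/<checkpoint_name>.pth', map_location='cpu')
#     eval_model.load_state_dict(state_dict)
#     eval_model.to(device).eval()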