# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
# International Conference on Computer Vision (ICCV), 2023

import os
import sys
from typing import Any, Optional

import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm

from efficientvit.apps.trainer import Trainer
from efficientvit.apps.utils import AverageMeter, is_master, sync_tensor
from efficientvit.clscore.trainer.utils import accuracy, apply_mixup, label_smooth
from efficientvit.models.utils import list_join, list_mean, torch_random_choices

__all__ = ["ClsTrainer"]


class ClsTrainer(Trainer):
    def __init__(
        self,
        path: str,
        model: nn.Module,
        data_provider,
        auto_restart_thresh: Optional[float] = None,
    ) -> None:
        super().__init__(
            path=path,
            model=model,
            data_provider=data_provider,
        )
        self.auto_restart_thresh = auto_restart_thresh
        self.test_criterion = nn.CrossEntropyLoss()

    def _validate(self, model, data_loader, epoch) -> dict[str, Any]:
        val_loss = AverageMeter()
        val_top1 = AverageMeter()
        val_top5 = AverageMeter()

        with torch.no_grad():
            with tqdm(
                total=len(data_loader),
                desc=f"Validate Epoch #{epoch + 1}",
                disable=not is_master(),
                file=sys.stdout,
            ) as t:
                for images, labels in data_loader:
                    images, labels = images.cuda(), labels.cuda()
                    # compute output
                    output = model(images)
                    loss = self.test_criterion(output, labels)
                    val_loss.update(loss, images.shape[0])
                    # top5 is only meaningful for datasets with many classes
                    if self.data_provider.n_classes >= 100:
                        acc1, acc5 = accuracy(output, labels, topk=(1, 5))
                        val_top5.update(acc5[0], images.shape[0])
                    else:
                        acc1 = accuracy(output, labels, topk=(1,))[0]
                    val_top1.update(acc1[0], images.shape[0])

                    t.set_postfix(
                        {
                            "loss": val_loss.avg,
                            "top1": val_top1.avg,
                            "top5": val_top5.avg,
                            "#samples": val_top1.get_count(),
                            "bs": images.shape[0],
                            "res": images.shape[2],
                        }
                    )
                    t.update()
        return {
            "val_top1": val_top1.avg,
            "val_loss": val_loss.avg,
            # only report top5 when it was actually tracked
            **({"val_top5": val_top5.avg} if val_top5.count > 0 else {}),
        }

    def before_step(self, feed_dict: dict[str, Any]) -> dict[str, Any]:
        images = feed_dict["data"].cuda()
        labels = feed_dict["label"].cuda()

        # label smoothing (converts labels to smoothed one-hot targets)
        labels = label_smooth(labels, self.data_provider.n_classes, self.run_config.label_smooth)

        # mixup
        if self.run_config.mixup_config is not None:
            # sample one active mixup op according to its configured weight
            mix_weight_list = [mix_list[2] for mix_list in self.run_config.mixup_config["op"]]
            active_id = torch_random_choices(
                list(range(len(self.run_config.mixup_config["op"]))),
                weight_list=mix_weight_list,
            )
            # broadcast the sampled op and lambda from the root rank so that
            # all distributed workers apply the same augmentation
            active_id = int(sync_tensor(active_id, reduce="root"))
            active_mixup_config = self.run_config.mixup_config["op"][active_id]
            mixup_type, mixup_alpha = active_mixup_config[:2]

            lam = float(torch.distributions.beta.Beta(mixup_alpha, mixup_alpha).sample())
            lam = float(np.clip(lam, 0, 1))
            lam = float(sync_tensor(lam, reduce="root"))

            images, labels = apply_mixup(images, labels, lam, mixup_type)

        return {
            "data": images,
            "label": labels,
        }

    def run_step(self, feed_dict: dict[str, Any]) -> dict[str, Any]:
        images = feed_dict["data"]
        labels = feed_dict["label"]

        # setup mesa: once training progress passes the threshold, distill
        # against the EMA model's sigmoid outputs as an auxiliary target
        if self.run_config.mesa is not None and self.run_config.mesa["thresh"] <= self.run_config.progress:
            ema_model = self.ema.shadows
            with torch.inference_mode():
                ema_output = ema_model(images).detach()
            # inference-mode tensors cannot participate in autograd; clone to
            # get an ordinary tensor before using it in the loss
            ema_output = torch.clone(ema_output)
            ema_output = torch.sigmoid(ema_output).detach()
        else:
            ema_output = None

        with torch.autocast(device_type="cuda", dtype=self.amp_dtype, enabled=self.enable_amp):
            output = self.model(images)
            loss = self.train_criterion(output, labels)
            # mesa loss
            if ema_output is not None:
                mesa_loss = self.train_criterion(output, ema_output)
                loss = loss + self.run_config.mesa["ratio"] * mesa_loss
        self.scaler.scale(loss).backward()

        # train top1 accuracy (undefined under mixed targets, so skip under mixup)
        if self.run_config.mixup_config is None:
            top1 = accuracy(output, torch.argmax(labels, dim=1), topk=(1,))[0][0]
        else:
            top1 = None

        return {
            "loss": loss,
            "top1": top1,
        }

    def _train_one_epoch(self, epoch: int) -> dict[str, Any]:
        train_loss = AverageMeter()
        train_top1 = AverageMeter()

        with tqdm(
            total=len(self.data_provider.train),
            desc=f"Train Epoch #{epoch + 1}",
            disable=not is_master(),
            file=sys.stdout,
        ) as t:
            for images, labels in self.data_provider.train:
                feed_dict = {"data": images, "label": labels}

                # preprocessing
                feed_dict = self.before_step(feed_dict)
                # clear gradient
                self.optimizer.zero_grad()
                # forward & backward
                output_dict = self.run_step(feed_dict)
                # update: optimizer, lr_scheduler
                self.after_step()

                # update train metrics
                train_loss.update(output_dict["loss"], images.shape[0])
                if output_dict["top1"] is not None:
                    train_top1.update(output_dict["top1"], images.shape[0])

                # tqdm
                postfix_dict = {
                    "loss": train_loss.avg,
                    "top1": train_top1.avg,
                    "bs": images.shape[0],
                    "res": images.shape[2],
                    "lr": list_join(
                        sorted(set([group["lr"] for group in self.optimizer.param_groups])),
                        "#",
                        "%.1E",
                    ),
                    "progress": self.run_config.progress,
                }
                t.set_postfix(postfix_dict)
                t.update()
        return {
            **({"train_top1": train_top1.avg} if train_top1.count > 0 else {}),
            "train_loss": train_loss.avg,
        }

    def train(self, trials=0, save_freq=1) -> None:
        # train loss: BCE with soft targets if configured, otherwise cross-entropy
        if self.run_config.bce:
            self.train_criterion = nn.BCEWithLogitsLoss()
        else:
            self.train_criterion = nn.CrossEntropyLoss()

        for epoch in range(self.start_epoch, self.run_config.n_epochs + self.run_config.warmup_epochs):
            train_info_dict = self.train_one_epoch(epoch)

            # eval
            val_info_dict = self.multires_validate(epoch=epoch)
            avg_top1 = list_mean([info_dict["val_top1"] for info_dict in val_info_dict.values()])
            is_best = avg_top1 > self.best_val
            self.best_val = max(avg_top1, self.best_val)

            # auto restart: if accuracy drops too far below the best seen so
            # far, reload the best checkpoint and resume training
            if self.auto_restart_thresh is not None:
                if self.best_val - avg_top1 > self.auto_restart_thresh:
                    self.write_log(f"Abnormal accuracy drop: {self.best_val} -> {avg_top1}")
                    self.load_model(os.path.join(self.checkpoint_path, "model_best.pt"))
                    return self.train(trials + 1, save_freq)

            # log
            val_log = self.run_config.epoch_format(epoch)
            val_log += f"\tval_top1={avg_top1:.2f}({self.best_val:.2f})"
            val_log += "\tVal("
            for key in list(val_info_dict.values())[0]:
                if key == "val_top1":
                    continue
                val_log += f"{key}={list_mean([info_dict[key] for info_dict in val_info_dict.values()]):.2f},"
            val_log += ")\tTrain("
            for key, val in train_info_dict.items():
                val_log += f"{key}={val:.2E},"
            val_log += (
                f'lr={list_join(sorted(set([group["lr"] for group in self.optimizer.param_groups])), "#", "%.1E")})'
            )
            self.write_log(val_log, prefix="valid", print_log=False)

            # save model
            if (epoch + 1) % save_freq == 0 or (is_best and self.run_config.progress > 0.8):
                self.save_model(
                    only_state_dict=False,
                    epoch=epoch,
                    model_name="model_best.pt" if is_best else "checkpoint.pt",
                )
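
# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original file): how ClsTrainer
# is typically driven. Construction of `model`, `data_provider`, and
# `run_config` is elided, and `prep_for_training` / `load_model` are assumed
# base-Trainer hooks as used by the repo's training scripts -- treat the exact
# arguments below as hypothetical.
#
#   trainer = ClsTrainer(
#       path="exp/cls/imagenet/b1",    # experiment / checkpoint directory
#       model=model,                   # any nn.Module classifier
#       data_provider=data_provider,   # exposes .train, .n_classes, ...
#       auto_restart_thresh=5.0,       # reload best ckpt on a >5-point top1 drop
#   )
#   trainer.prep_for_training(run_config, ema_decay, amp)  # assumed setup hook
#   trainer.load_model()               # optional: resume from a checkpoint
#   trainer.train(save_freq=1)         # save a checkpoint every epoch
# ---------------------------------------------------------------------------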