# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
# International Conference on Computer Vision (ICCV), 2023
import os
import sys
from typing import Any

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

from efficientvit.apps.trainer import Trainer
from efficientvit.apps.utils import AverageMeter, is_master, sync_tensor
from efficientvit.clscore.trainer.utils import accuracy, apply_mixup, label_smooth
from efficientvit.models.utils import list_join, list_mean, torch_random_choices

__all__ = ["ClsTrainer"]
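

# `accuracy` is imported above and used for both validation and train metrics.
# As a point of reference, here is a minimal top-k accuracy sketch following the
# usual torchvision-style recipe; `_sketch_topk_accuracy` is a hypothetical name
# for illustration only and is NOT used by the trainer below.
def _sketch_topk_accuracy(output: torch.Tensor, target: torch.Tensor, topk=(1,)) -> list[torch.Tensor]:
    maxk = max(topk)
    batch_size = target.shape[0]
    # indices of the maxk largest logits per sample -> (maxk, batch) after transpose
    _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
    pred = pred.t()
    correct = pred.eq(target.reshape(1, -1).expand_as(pred))
    # one percentage per requested k
    return [correct[:k].reshape(-1).float().sum(0) * (100.0 / batch_size) for k in topk]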


class ClsTrainer(Trainer):
    def __init__(
        self,
        path: str,
        model: nn.Module,
        data_provider,
        auto_restart_thresh: float | None = None,
    ) -> None:
        super().__init__(
            path=path,
            model=model,
            data_provider=data_provider,
        )
        self.auto_restart_thresh = auto_restart_thresh
        self.test_criterion = nn.CrossEntropyLoss()
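
    # `auto_restart_thresh` is a divergence guard: if validation top-1 falls
    # more than this many points below the best value seen so far, `train()`
    # reloads `model_best.pt` and restarts from it (see the check in `train()`).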

    def _validate(self, model, data_loader, epoch) -> dict[str, Any]:
        val_loss = AverageMeter()
        val_top1 = AverageMeter()
        val_top5 = AverageMeter()

        with torch.no_grad():
            with tqdm(
                total=len(data_loader),
                desc=f"Validate Epoch #{epoch + 1}",
                disable=not is_master(),
                file=sys.stdout,
            ) as t:
                for images, labels in data_loader:
                    images, labels = images.cuda(), labels.cuda()
                    # compute output
                    output = model(images)
                    loss = self.test_criterion(output, labels)
                    val_loss.update(loss, images.shape[0])
                    if self.data_provider.n_classes >= 100:
                        acc1, acc5 = accuracy(output, labels, topk=(1, 5))
                        val_top5.update(acc5[0], images.shape[0])
                    else:
                        acc1 = accuracy(output, labels, topk=(1,))[0]
                    val_top1.update(acc1[0], images.shape[0])

                    t.set_postfix(
                        {
                            "loss": val_loss.avg,
                            "top1": val_top1.avg,
                            "top5": val_top5.avg,
                            "#samples": val_top1.get_count(),
                            "bs": images.shape[0],
                            "res": images.shape[2],
                        }
                    )
                    t.update()
        return {
            "val_top1": val_top1.avg,
            "val_loss": val_loss.avg,
            **({"val_top5": val_top5.avg} if val_top5.count > 0 else {}),
        }
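
    # Top-5 accuracy is tracked only when the dataset has at least 100 classes
    # (e.g., ImageNet-1k); for small label spaces top-5 is near-saturated and
    # uninformative, so "val_top5" is returned only when its meter saw samples.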

    def before_step(self, feed_dict: dict[str, Any]) -> dict[str, Any]:
        images = feed_dict["data"].cuda()
        labels = feed_dict["label"].cuda()

        # label smoothing
        labels = label_smooth(labels, self.data_provider.n_classes, self.run_config.label_smooth)

        # mixup
        if self.run_config.mixup_config is not None:
            # choose the active mixup op by its configured sampling weight
            mix_weight_list = [mix_list[2] for mix_list in self.run_config.mixup_config["op"]]
            active_id = torch_random_choices(
                list(range(len(self.run_config.mixup_config["op"]))),
                weight_list=mix_weight_list,
            )
            active_id = int(sync_tensor(active_id, reduce="root"))
            active_mixup_config = self.run_config.mixup_config["op"][active_id]
            mixup_type, mixup_alpha = active_mixup_config[:2]

            lam = float(torch.distributions.beta.Beta(mixup_alpha, mixup_alpha).sample())
            lam = float(np.clip(lam, 0, 1))
            lam = float(sync_tensor(lam, reduce="root"))
            images, labels = apply_mixup(images, labels, lam, mixup_type)

        return {
            "data": images,
            "label": labels,
        }
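
    # `label_smooth` and `apply_mixup` are imported helpers; the math they are
    # assumed to implement, following the standard recipes, is:
    #   label smoothing: y_smooth = (1 - eps) * one_hot(y) + eps / n_classes
    #   mixup:           x_mix = lam * x + (1 - lam) * x.flip(0)   (same for labels)
    # `sync_tensor(..., reduce="root")` broadcasts the root rank's op choice and
    # lambda so that every distributed worker applies the identical augmentation.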

    def run_step(self, feed_dict: dict[str, Any]) -> dict[str, Any]:
        images = feed_dict["data"]
        labels = feed_dict["label"]

        # setup mesa: soft targets from the frozen EMA shadow model
        if self.run_config.mesa is not None and self.run_config.mesa["thresh"] <= self.run_config.progress:
            ema_model = self.ema.shadows
            with torch.inference_mode():
                ema_output = ema_model(images).detach()
            # clone to move the tensor out of inference mode before it enters the loss
            ema_output = torch.clone(ema_output)
            ema_output = torch.sigmoid(ema_output).detach()
        else:
            ema_output = None

        with torch.autocast(device_type="cuda", dtype=self.amp_dtype, enabled=self.enable_amp):
            output = self.model(images)
            loss = self.train_criterion(output, labels)
            # mesa loss
            if ema_output is not None:
                mesa_loss = self.train_criterion(output, ema_output)
                loss = loss + self.run_config.mesa["ratio"] * mesa_loss
        self.scaler.scale(loss).backward()

        # train top-1 accuracy (only meaningful without mixup)
        if self.run_config.mixup_config is None:
            top1 = accuracy(output, torch.argmax(labels, dim=1), topk=(1,))[0][0]
        else:
            top1 = None

        return {
            "loss": loss,
            "top1": top1,
        }
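
    # The "mesa" term is a self-distillation regularizer: once training progress
    # passes mesa["thresh"], the EMA shadow model's sigmoided logits act as soft
    # targets, giving
    #   loss = criterion(output, labels) + mesa["ratio"] * criterion(output, sigmoid(ema_output))
    # Gradients flow only through `output`; the EMA branch runs under
    # `torch.inference_mode()` and is detached.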

    def _train_one_epoch(self, epoch: int) -> dict[str, Any]:
        train_loss = AverageMeter()
        train_top1 = AverageMeter()

        with tqdm(
            total=len(self.data_provider.train),
            desc=f"Train Epoch #{epoch + 1}",
            disable=not is_master(),
            file=sys.stdout,
        ) as t:
            for images, labels in self.data_provider.train:
                feed_dict = {"data": images, "label": labels}

                # preprocessing
                feed_dict = self.before_step(feed_dict)
                # clear gradient
                self.optimizer.zero_grad()
                # forward & backward
                output_dict = self.run_step(feed_dict)
                # update: optimizer, lr_scheduler
                self.after_step()

                # update train metrics
                train_loss.update(output_dict["loss"], images.shape[0])
                if output_dict["top1"] is not None:
                    train_top1.update(output_dict["top1"], images.shape[0])

                # tqdm
                postfix_dict = {
                    "loss": train_loss.avg,
                    "top1": train_top1.avg,
                    "bs": images.shape[0],
                    "res": images.shape[2],
                    "lr": list_join(
                        sorted(set(group["lr"] for group in self.optimizer.param_groups)),
                        "#",
                        "%.1E",
                    ),
                    "progress": self.run_config.progress,
                }
                t.set_postfix(postfix_dict)
                t.update()

        return {
            **({"train_top1": train_top1.avg} if train_top1.count > 0 else {}),
            "train_loss": train_loss.avg,
        }
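
    # `run_step` only performs the backward pass; the AMP `scaler.step` /
    # `scaler.update` calls and the LR-scheduler tick presumably live in the
    # base `Trainer.after_step()` (not shown in this file). "train_top1" is
    # omitted when mixup is active, since mixed labels have no single
    # ground-truth class to score against.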

    def train(self, trials=0, save_freq=1) -> None:
        if self.run_config.bce:
            self.train_criterion = nn.BCEWithLogitsLoss()
        else:
            self.train_criterion = nn.CrossEntropyLoss()

        for epoch in range(self.start_epoch, self.run_config.n_epochs + self.run_config.warmup_epochs):
            train_info_dict = self.train_one_epoch(epoch)

            # eval
            val_info_dict = self.multires_validate(epoch=epoch)
            avg_top1 = list_mean([info_dict["val_top1"] for info_dict in val_info_dict.values()])
            is_best = avg_top1 > self.best_val
            self.best_val = max(avg_top1, self.best_val)

            if self.auto_restart_thresh is not None:
                if self.best_val - avg_top1 > self.auto_restart_thresh:
                    self.write_log(f"Abnormal accuracy drop: {self.best_val} -> {avg_top1}")
                    self.load_model(os.path.join(self.checkpoint_path, "model_best.pt"))
                    return self.train(trials + 1, save_freq)

            # log
            val_log = self.run_config.epoch_format(epoch)
            val_log += f"\tval_top1={avg_top1:.2f}({self.best_val:.2f})"
            val_log += "\tVal("
            for key in list(val_info_dict.values())[0]:
                if key == "val_top1":
                    continue
                val_log += f"{key}={list_mean([info_dict[key] for info_dict in val_info_dict.values()]):.2f},"
            val_log += ")\tTrain("
            for key, val in train_info_dict.items():
                val_log += f"{key}={val:.2E},"
            val_log += (
                f'lr={list_join(sorted(set(group["lr"] for group in self.optimizer.param_groups)), "#", "%.1E")})'
            )
            self.write_log(val_log, prefix="valid", print_log=False)

            # save model
            if (epoch + 1) % save_freq == 0 or (is_best and self.run_config.progress > 0.8):
                self.save_model(
                    only_state_dict=False,
                    epoch=epoch,
                    model_name="model_best.pt" if is_best else "checkpoint.pt",
                )
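

# A minimal, self-contained sanity check of the augmentation math used in
# `before_step`, runnable without a model or data provider. The inlined
# smoothing/mixup formulas are assumptions mirroring the standard recipes,
# not the imported helpers themselves.
if __name__ == "__main__":
    n_classes, eps, alpha = 10, 0.1, 0.2
    y = torch.randint(0, n_classes, (4,))
    # label smoothing: (1 - eps) * one_hot + eps / n_classes (rows still sum to 1)
    y_smooth = (1 - eps) * F.one_hot(y, n_classes).float() + eps / n_classes
    assert torch.allclose(y_smooth.sum(dim=1), torch.ones(4))
    # mixup lambda: Beta(alpha, alpha) sample, clipped to [0, 1]
    lam = float(np.clip(float(torch.distributions.beta.Beta(alpha, alpha).sample()), 0, 1))
    x = torch.randn(4, 3, 224, 224)
    x_mix = lam * x + (1 - lam) * x.flip(0)
    print(f"lam={lam:.3f}, mixed batch shape={tuple(x_mix.shape)}")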