# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
# International Conference on Computer Vision (ICCV), 2023
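#
# Classification trainer: extends the generic Trainer with label smoothing,
# mixup, optional MESA self-distillation, and multi-resolution validation.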
import os
import sys
from typing import Any, Optional
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from efficientvit.apps.trainer import Trainer
from efficientvit.apps.utils import AverageMeter, is_master, sync_tensor
from efficientvit.clscore.trainer.utils import accuracy, apply_mixup, label_smooth
from efficientvit.models.utils import list_join, list_mean, torch_random_choices
__all__ = ["ClsTrainer"]
class ClsTrainer(Trainer):
def __init__(
self,
path: str,
model: nn.Module,
data_provider,
        auto_restart_thresh: Optional[float] = None,
) -> None:
super().__init__(
path=path,
model=model,
data_provider=data_provider,
)
self.auto_restart_thresh = auto_restart_thresh
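        # evaluation always uses plain cross-entropy, even if training uses BCE (see train())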
self.test_criterion = nn.CrossEntropyLoss()
    def _validate(self, model, data_loader, epoch) -> dict[str, Any]:
val_loss = AverageMeter()
val_top1 = AverageMeter()
val_top5 = AverageMeter()
with torch.no_grad():
with tqdm(
total=len(data_loader),
desc=f"Validate Epoch #{epoch + 1}",
disable=not is_master(),
file=sys.stdout,
) as t:
for images, labels in data_loader:
images, labels = images.cuda(), labels.cuda()
# compute output
output = model(images)
loss = self.test_criterion(output, labels)
val_loss.update(loss, images.shape[0])
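                    # track top-5 only for datasets with many classes (e.g., ImageNet)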
if self.data_provider.n_classes >= 100:
acc1, acc5 = accuracy(output, labels, topk=(1, 5))
val_top5.update(acc5[0], images.shape[0])
else:
acc1 = accuracy(output, labels, topk=(1,))[0]
val_top1.update(acc1[0], images.shape[0])
t.set_postfix(
{
"loss": val_loss.avg,
"top1": val_top1.avg,
"top5": val_top5.avg,
"#samples": val_top1.get_count(),
"bs": images.shape[0],
"res": images.shape[2],
}
)
t.update()
return {
"val_top1": val_top1.avg,
"val_loss": val_loss.avg,
**({"val_top5": val_top5.avg} if val_top5.count > 0 else {}),
}
    def before_step(self, feed_dict: dict[str, Any]) -> dict[str, Any]:
images = feed_dict["data"].cuda()
labels = feed_dict["label"].cuda()
# label smooth
labels = label_smooth(labels, self.data_provider.n_classes, self.run_config.label_smooth)
# mixup
if self.run_config.mixup_config is not None:
# choose active mixup config
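            # each "op" entry is (mixup_type, mixup_alpha, sampling_weight)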
mix_weight_list = [mix_list[2] for mix_list in self.run_config.mixup_config["op"]]
active_id = torch_random_choices(
list(range(len(self.run_config.mixup_config["op"]))),
weight_list=mix_weight_list,
)
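            # broadcast the sampled op index from the root rank so all workers apply the same op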
active_id = int(sync_tensor(active_id, reduce="root"))
active_mixup_config = self.run_config.mixup_config["op"][active_id]
mixup_type, mixup_alpha = active_mixup_config[:2]
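            # sample the mixing coefficient lam ~ Beta(alpha, alpha), clip it to [0, 1], and share it across ranks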
lam = float(torch.distributions.beta.Beta(mixup_alpha, mixup_alpha).sample())
lam = float(np.clip(lam, 0, 1))
lam = float(sync_tensor(lam, reduce="root"))
images, labels = apply_mixup(images, labels, lam, mixup_type)
return {
"data": images,
"label": labels,
}
    def run_step(self, feed_dict: dict[str, Any]) -> dict[str, Any]:
images = feed_dict["data"]
labels = feed_dict["label"]
# setup mesa
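        # MESA self-distillation: once training progress passes the threshold,
        # the EMA model's sigmoid outputs serve as soft targets for an extra loss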
if self.run_config.mesa is not None and self.run_config.mesa["thresh"] <= self.run_config.progress:
ema_model = self.ema.shadows
with torch.inference_mode():
ema_output = ema_model(images).detach()
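            # tensors created under inference_mode cannot enter the autograd graph, so clone first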
ema_output = torch.clone(ema_output)
            ema_output = torch.sigmoid(ema_output).detach()
else:
ema_output = None
with torch.autocast(device_type="cuda", dtype=self.amp_dtype, enabled=self.enable_amp):
output = self.model(images)
loss = self.train_criterion(output, labels)
# mesa loss
if ema_output is not None:
mesa_loss = self.train_criterion(output, ema_output)
loss = loss + self.run_config.mesa["ratio"] * mesa_loss
self.scaler.scale(loss).backward()
# calc train top1 acc
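        # mixup replaces hard labels with soft mixtures, so top-1 is only meaningful without it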
if self.run_config.mixup_config is None:
top1 = accuracy(output, torch.argmax(labels, dim=1), topk=(1,))[0][0]
else:
top1 = None
return {
"loss": loss,
"top1": top1,
}
    def _train_one_epoch(self, epoch: int) -> dict[str, Any]:
train_loss = AverageMeter()
train_top1 = AverageMeter()
with tqdm(
total=len(self.data_provider.train),
desc="Train Epoch #{}".format(epoch + 1),
disable=not is_master(),
file=sys.stdout,
) as t:
for images, labels in self.data_provider.train:
feed_dict = {"data": images, "label": labels}
# preprocessing
feed_dict = self.before_step(feed_dict)
# clear gradient
self.optimizer.zero_grad()
# forward & backward
output_dict = self.run_step(feed_dict)
# update: optimizer, lr_scheduler
self.after_step()
# update train metrics
train_loss.update(output_dict["loss"], images.shape[0])
if output_dict["top1"] is not None:
train_top1.update(output_dict["top1"], images.shape[0])
# tqdm
postfix_dict = {
"loss": train_loss.avg,
"top1": train_top1.avg,
"bs": images.shape[0],
"res": images.shape[2],
"lr": list_join(
sorted(set([group["lr"] for group in self.optimizer.param_groups])),
"#",
"%.1E",
),
"progress": self.run_config.progress,
}
t.set_postfix(postfix_dict)
t.update()
return {
**({"train_top1": train_top1.avg} if train_top1.count > 0 else {}),
"train_loss": train_loss.avg,
}
def train(self, trials=0, save_freq=1) -> None:
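        # BCE with soft targets is a common alternative to CE under heavy mixup/label smoothing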
if self.run_config.bce:
self.train_criterion = nn.BCEWithLogitsLoss()
else:
self.train_criterion = nn.CrossEntropyLoss()
for epoch in range(self.start_epoch, self.run_config.n_epochs + self.run_config.warmup_epochs):
train_info_dict = self.train_one_epoch(epoch)
# eval
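            # run validation at each configured resolution and average top-1 across settings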
val_info_dict = self.multires_validate(epoch=epoch)
avg_top1 = list_mean([info_dict["val_top1"] for info_dict in val_info_dict.values()])
is_best = avg_top1 > self.best_val
self.best_val = max(avg_top1, self.best_val)
if self.auto_restart_thresh is not None:
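                # recover from divergence: reload the best checkpoint and restart training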
if self.best_val - avg_top1 > self.auto_restart_thresh:
self.write_log(f"Abnormal accuracy drop: {self.best_val} -> {avg_top1}")
self.load_model(os.path.join(self.checkpoint_path, "model_best.pt"))
return self.train(trials + 1, save_freq)
# log
val_log = self.run_config.epoch_format(epoch)
val_log += f"\tval_top1={avg_top1:.2f}({self.best_val:.2f})"
val_log += "\tVal("
for key in list(val_info_dict.values())[0]:
if key == "val_top1":
continue
val_log += f"{key}={list_mean([info_dict[key] for info_dict in val_info_dict.values()]):.2f},"
val_log += ")\tTrain("
for key, val in train_info_dict.items():
val_log += f"{key}={val:.2E},"
val_log += (
f'lr={list_join(sorted(set([group["lr"] for group in self.optimizer.param_groups])), "#", "%.1E")})'
)
self.write_log(val_log, prefix="valid", print_log=False)
# save model
if (epoch + 1) % save_freq == 0 or (is_best and self.run_config.progress > 0.8):
self.save_model(
only_state_dict=False,
epoch=epoch,
model_name="model_best.pt" if is_best else "checkpoint.pt",
)
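
# Typical usage (a sketch, kept as comments; the setup calls below follow the
# surrounding repo's conventions and are assumptions, not defined in this file):
#
#   model = ...  # any classification nn.Module (e.g., from efficientvit.cls_model_zoo)
#   trainer = ClsTrainer(
#       path="runs/exp0",
#       model=model,
#       data_provider=data_provider,   # ImageNet-style data provider
#       auto_restart_thresh=5.0,       # restart if top-1 drops >5 points below best
#   )
#   trainer.prep_for_training(run_config, ema_decay=0.9998, amp="fp16")
#   trainer.train(save_freq=1)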