# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
# International Conference on Computer Vision (ICCV), 2023
import json
from typing import Any

import numpy as np
import torch.nn as nn

from src.efficientvit.apps.utils import CosineLRwithWarmup, build_optimizer

__all__ = ["Scheduler", "RunConfig"]
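
# RunConfig collects the training hyperparameters (schedule, warmup, optimizer,
# regularization), validates them against the type annotations below, and builds
# the optimizer and lr scheduler from them; Scheduler.PROGRESS exposes the current
# training progress (0 to 1) at module level.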


class Scheduler:
    # training progress in [0, 1]; class-level so it is globally visible, updated by RunConfig
    PROGRESS = 0


class RunConfig:
    n_epochs: int
    init_lr: float
    warmup_epochs: int
    warmup_lr: float
    lr_schedule_name: str
    lr_schedule_param: dict
    optimizer_name: str
    optimizer_params: dict
    weight_decay: float
    no_wd_keys: list
    grad_clip: float  # allow none to turn off grad clipping
    reset_bn: bool
    reset_bn_size: int
    reset_bn_batch_size: int
    eval_image_size: list  # allow none to use image_size in data_provider

    @property
    def none_allowed(self):
        # keys that may be set to None (see the type check in __init__)
        return ["grad_clip", "eval_image_size"]

    def __init__(self, **kwargs):  # arguments must be passed as kwargs
        for k, val in kwargs.items():
            setattr(self, k, val)

        # check that all relevant configs are there
        annotations = {}
        for clas in type(self).mro():
            if hasattr(clas, "__annotations__"):
                annotations.update(clas.__annotations__)
        for k, k_type in annotations.items():
            assert hasattr(self, k), f"Key {k} with type {k_type} required for initialization."
            attr = getattr(self, k)
            if k in self.none_allowed:
                k_type = (k_type, type(None))
            assert isinstance(attr, k_type), f"Key {k} must be type {k_type}, provided={attr}."

        self.global_step = 0
        self.batch_per_epoch = 1

    def build_optimizer(self, network: nn.Module) -> tuple[Any, Any]:
        r"""require setting 'batch_per_epoch' before building optimizer & lr_scheduler"""
        # group parameters by (weight_decay, lr) so each group gets its own optimizer settings
        param_dict = {}
        for name, param in network.named_parameters():
            if param.requires_grad:
                opt_config = [self.weight_decay, self.init_lr]
                if self.no_wd_keys is not None and len(self.no_wd_keys) > 0:
                    # disable weight decay for parameters whose name matches any no_wd key
                    if np.any([key in name for key in self.no_wd_keys]):
                        opt_config[0] = 0
                opt_key = json.dumps(opt_config)
                param_dict[opt_key] = param_dict.get(opt_key, []) + [param]
        net_params = []
        for opt_key, param_list in param_dict.items():
            wd, lr = json.loads(opt_key)
            net_params.append({"params": param_list, "weight_decay": wd, "lr": lr})

        optimizer = build_optimizer(net_params, self.optimizer_name, self.optimizer_params, self.init_lr)

        # build lr scheduler
        if self.lr_schedule_name == "cosine":
            decay_steps = []
            for epoch in self.lr_schedule_param.get("step", []):
                decay_steps.append(epoch * self.batch_per_epoch)
            decay_steps.append(self.n_epochs * self.batch_per_epoch)
            decay_steps.sort()
            lr_scheduler = CosineLRwithWarmup(
                optimizer,
                self.warmup_epochs * self.batch_per_epoch,
                self.warmup_lr,
                decay_steps,
            )
        else:
            raise NotImplementedError
        return optimizer, lr_scheduler

    def update_global_step(self, epoch, batch_id=0) -> None:
        self.global_step = epoch * self.batch_per_epoch + batch_id
        Scheduler.PROGRESS = self.progress

    @property
    def progress(self) -> float:
        # fraction of the post-warmup training schedule that has been completed
        warmup_steps = self.warmup_epochs * self.batch_per_epoch
        steps = max(0, self.global_step - warmup_steps)
        return steps / (self.n_epochs * self.batch_per_epoch)

    def step(self) -> None:
        self.global_step += 1
        Scheduler.PROGRESS = self.progress

    def get_remaining_epoch(self, epoch, post=True) -> int:
        return self.n_epochs + self.warmup_epochs - epoch - int(post)

    def epoch_format(self, epoch: int) -> str:
        # zero-padded "[current/total]" string, e.g. n_epochs=300, epoch=warmup_epochs -> "[001/300]"
        epoch_format = f"%.{len(str(self.n_epochs))}d"
        epoch_format = f"[{epoch_format}/{epoch_format}]"
        epoch_format = epoch_format % (epoch + 1 - self.warmup_epochs, self.n_epochs)
        return epoch_format
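

# --------------------------------------------------------------------------- #
# Illustrative usage sketch (not part of the original training pipeline).
# It shows how a RunConfig is constructed, why batch_per_epoch must be set
# before build_optimizer, and how update_global_step/step drive the shared
# Scheduler.PROGRESS. The hyperparameter values, the "adamw" optimizer name,
# and the toy model below are assumptions for demonstration only; real values
# come from the experiment configuration files.
# --------------------------------------------------------------------------- #
if __name__ == "__main__":
    run_config = RunConfig(
        n_epochs=10,
        init_lr=1e-3,
        warmup_epochs=2,
        warmup_lr=1e-5,
        lr_schedule_name="cosine",
        lr_schedule_param={},  # may optionally contain {"step": [...]} (extra decay milestones, in epochs)
        optimizer_name="adamw",  # assumed to be supported by build_optimizer
        optimizer_params={"betas": (0.9, 0.999)},
        weight_decay=0.05,
        no_wd_keys=["bias", "norm"],  # parameters whose names match these get weight_decay=0
        grad_clip=None,
        reset_bn=False,
        reset_bn_size=16,
        reset_bn_batch_size=64,
        eval_image_size=None,
    )

    toy_net = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))
    run_config.batch_per_epoch = 5  # must be set before building optimizer & lr scheduler
    optimizer, lr_scheduler = run_config.build_optimizer(toy_net)

    # warmup epochs come in addition to the n_epochs covered by the cosine schedule
    for epoch in range(run_config.n_epochs + run_config.warmup_epochs):
        run_config.update_global_step(epoch)
        if epoch >= run_config.warmup_epochs:
            print(run_config.epoch_format(epoch), f"progress={run_config.progress:.2f}")
        for _ in range(run_config.batch_per_epoch):
            # ... forward / backward / optimizer.step() would go here ...
            run_config.step()  # advances global_step and Scheduler.PROGRESS
            lr_scheduler.step()  # decay_steps are expressed in batches, so step per batch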