File size: 4,331 Bytes
ad5354d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
# International Conference on Computer Vision (ICCV), 2023

import json
from typing import Any

import numpy as np
import torch.nn as nn

from src.efficientvit.apps.utils import CosineLRwithWarmup, build_optimizer

__all__ = ["Scheduler", "RunConfig"]


class Scheduler:
    """Class-level holder for the current training progress.

    ``RunConfig.update_global_step`` and ``RunConfig.step`` overwrite
    ``PROGRESS`` with the value of ``RunConfig.progress`` each time they run.
    """

    # Starts at 0 until a RunConfig updates it.
    PROGRESS = 0


class RunConfig:
    """Validated container for the hyper-parameters of a training run.

    Every class-level annotation below (plus any added by subclasses) names a
    field that must be passed to ``__init__`` as a keyword argument with a
    value of the annotated type.  Fields listed in ``none_allowed`` may also
    be ``None``.
    """

    n_epochs: int
    init_lr: float
    warmup_epochs: int
    warmup_lr: float
    lr_schedule_name: str
    lr_schedule_param: dict
    optimizer_name: str
    optimizer_params: dict
    weight_decay: float
    no_wd_keys: list
    grad_clip: float  # allow none to turn off grad clipping
    reset_bn: bool
    reset_bn_size: int
    reset_bn_batch_size: int
    eval_image_size: list  # allow none to use image_size in data_provider

    @property
    def none_allowed(self):
        """Names of the annotated fields that may be ``None``."""
        return ["grad_clip", "eval_image_size"]

    def __init__(self, **kwargs):  # arguments must be passed as kwargs
        """Store every keyword argument as an attribute, then validate.

        Raises:
            AssertionError: if an annotated field is missing or its value is
                not an instance of the annotated type.
        """
        for key, value in kwargs.items():
            setattr(self, key, value)

        # Merge annotations base-first so that a subclass re-declaring a field
        # overrides the parent's type instead of being overwritten by it.
        annotations = {}
        for klass in reversed(type(self).mro()):
            annotations.update(getattr(klass, "__annotations__", {}))

        # check that all relevant configs are there
        none_allowed = self.none_allowed
        for k, k_type in annotations.items():
            assert hasattr(
                self, k
            ), f"Key {k} with type {k_type} required for initialization."
            attr = getattr(self, k)
            if k in none_allowed:
                k_type = (k_type, type(None))
            assert isinstance(
                attr, k_type
            ), f"Key {k} must be type {k_type}, provided={attr}."

        self.global_step = 0
        self.batch_per_epoch = 1

    def build_optimizer(self, network: nn.Module) -> tuple[Any, Any]:
        """Build the optimizer and learning-rate scheduler for ``network``.

        Requires ``batch_per_epoch`` to be set beforehand: it is used to
        convert the epoch-based settings into step counts.

        Trainable parameters are grouped by their ``(weight_decay, lr)`` pair.
        A parameter whose name contains any entry of ``no_wd_keys`` gets a
        weight decay of 0.

        Returns:
            ``(optimizer, lr_scheduler)``.

        Raises:
            NotImplementedError: if ``lr_schedule_name`` is not ``"cosine"``.
        """
        no_wd_keys = self.no_wd_keys or ()

        groups = {}
        for name, param in network.named_parameters():
            if not param.requires_grad:
                continue
            opt_config = [self.weight_decay, self.init_lr]
            if any(key in name for key in no_wd_keys):
                opt_config[0] = 0
            # JSON text serves as a hashable key that can be decoded again below.
            groups.setdefault(json.dumps(opt_config), []).append(param)

        net_params = []
        for opt_key, param_list in groups.items():
            wd, lr = json.loads(opt_key)
            net_params.append({"params": param_list, "weight_decay": wd, "lr": lr})

        optimizer = build_optimizer(
            net_params, self.optimizer_name, self.optimizer_params, self.init_lr
        )

        # build lr scheduler
        if self.lr_schedule_name != "cosine":
            raise NotImplementedError(
                f"lr_schedule_name={self.lr_schedule_name!r} is not supported; "
                "only 'cosine' is implemented."
            )
        # Decay boundaries in steps: the optional "step" epochs plus the final epoch.
        decay_steps = sorted(
            [
                epoch * self.batch_per_epoch
                for epoch in self.lr_schedule_param.get("step", [])
            ]
            + [self.n_epochs * self.batch_per_epoch]
        )
        lr_scheduler = CosineLRwithWarmup(
            optimizer,
            self.warmup_epochs * self.batch_per_epoch,
            self.warmup_lr,
            decay_steps,
        )
        return optimizer, lr_scheduler

    def update_global_step(self, epoch, batch_id=0) -> None:
        """Set ``global_step`` from ``epoch``/``batch_id`` and publish progress."""
        self.global_step = epoch * self.batch_per_epoch + batch_id
        Scheduler.PROGRESS = self.progress

    @property
    def progress(self) -> float:
        """Steps taken after warmup divided by ``n_epochs * batch_per_epoch``.

        Evaluates to 0 while ``global_step`` is still within the warmup steps.
        """
        warmup_steps = self.warmup_epochs * self.batch_per_epoch
        steps_after_warmup = max(0, self.global_step - warmup_steps)
        return steps_after_warmup / (self.n_epochs * self.batch_per_epoch)

    def step(self) -> None:
        """Advance ``global_step`` by one and publish the new progress."""
        self.global_step += 1
        Scheduler.PROGRESS = self.progress

    def get_remaining_epoch(self, epoch, post=True) -> int:
        """Return ``n_epochs + warmup_epochs - epoch``, minus one if ``post``."""
        return self.n_epochs + self.warmup_epochs - epoch - int(post)

    def epoch_format(self, epoch: int) -> str:
        """Format ``epoch`` as ``"[current/total]"``.

        ``current`` is ``epoch + 1 - warmup_epochs`` and ``total`` is
        ``n_epochs``; both are zero-padded to the digit count of ``n_epochs``.
        """
        width = len(str(self.n_epochs))
        template = f"[%.{width}d/%.{width}d]"
        return template % (epoch + 1 - self.warmup_epochs, self.n_epochs)