File size: 950 Bytes
32b542e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
# Copyright 2021 JD.com, Inc., JD AI
"""
@author: Jianjie Luo
@contact: [email protected]
"""
import torch
from uniperceiver.config import configurable
from .build import SOLVER_REGISTRY

@SOLVER_REGISTRY.register()
class AdamW(torch.optim.AdamW):
    """``torch.optim.AdamW`` registered in ``SOLVER_REGISTRY`` and buildable from a config.

    Construction goes through ``@configurable``: hyper-parameters are either
    passed explicitly as keyword arguments, or produced from a config by
    :meth:`from_config`.
    """

    @configurable
    def __init__(
        self,
        *,
        params,
        lr=1e-3,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=0.01,
        amsgrad=False,
    ):
        """Create the optimizer; every argument is forwarded to ``torch.optim.AdamW``.

        Args:
            params: parameters (or parameter groups) to optimize, passed
                through unchanged.
            lr: learning rate.
            betas: coefficients for the running averages of the gradient and
                of its square.
            eps: term added to the denominator for numerical stability.
            weight_decay: weight-decay coefficient.
            amsgrad: whether to use the AMSGrad variant.
        """
        # Forward by keyword rather than position so that a change in the
        # ordering of torch's optional parameters can never silently bind a
        # value to the wrong hyper-parameter.
        super().__init__(
            params,
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            amsgrad=amsgrad,
        )

    @classmethod
    def from_config(cls, cfg, params):
        """Map solver config entries onto the keyword arguments of ``__init__``.

        Args:
            cfg: config object; reads ``cfg.SOLVER.BASE_LR``, ``cfg.SOLVER.BETAS``,
                ``cfg.SOLVER.EPS``, ``cfg.SOLVER.WEIGHT_DECAY`` and
                ``cfg.SOLVER.AMSGRAD``.
            params: parameters (or parameter groups) to optimize.

        Returns:
            dict of keyword arguments accepted by ``__init__``.
        """
        solver_cfg = cfg.SOLVER
        return {
            "params": params,
            "lr": solver_cfg.BASE_LR,
            "betas": solver_cfg.BETAS,
            "eps": solver_cfg.EPS,
            "weight_decay": solver_cfg.WEIGHT_DECAY,
            "amsgrad": solver_cfg.AMSGRAD,
        }