File size: 4,925 Bytes
89c0b51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
# Copyright 2024 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import warnings

import torch
from torch.optim.lr_scheduler import LRScheduler


class CosineAnnealingWithWarmup(LRScheduler):
    """Linear warmup followed by cosine decay towards ``min_lr``.

    The learning rate is a pure function of the step index (``last_epoch``):

    * ``step <= warmup_steps``: linear ramp
      ``(step + 1) / (warmup_steps + 1) * lr``, which equals ``lr`` exactly
      at ``step == warmup_steps``.
    * ``warmup_steps < step < decay_steps``: cosine interpolation from
      ``lr`` (start of decay) to ``min_lr`` (at ``decay_steps``).
    * ``step >= decay_steps``: constant ``min_lr``.

    Every parameter group receives the same value; the groups' own initial
    learning rates (``base_lrs``) are not used.

    Args:
        optimizer: Optimizer whose param groups are updated.
        warmup_steps: Last step of the linear warmup phase.
        decay_steps: Step at which the cosine decay reaches ``min_lr``.
        lr: Peak learning rate, reached at the end of warmup.
        min_lr: Learning rate used from ``decay_steps`` onward.
        last_epoch: Index of the last step; ``-1`` starts a fresh schedule.
        verbose: Forwarded to ``LRScheduler`` only when truthy.
    """

    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        warmup_steps: int,
        decay_steps: int,
        lr: float,
        min_lr: float,
        last_epoch: int = -1,
        verbose: bool = False,
    ):
        self.warmup_steps = warmup_steps
        self.decay_steps = decay_steps
        self.lr = lr
        self.min_lr = min_lr
        # Newer PyTorch releases deprecate `verbose` and warn whenever it is
        # passed explicitly (even as False); it may also disappear from the
        # signature. Only forward it when the caller actually asked for it.
        if verbose:
            super().__init__(optimizer, last_epoch, verbose)
        else:
            super().__init__(optimizer, last_epoch)

    def _get_step_lr(self, step):
        """Return the learning rate for step index ``step``."""
        if step <= self.warmup_steps:
            return (step + 1) / (self.warmup_steps + 1) * self.lr
        elif step >= self.decay_steps:
            return self.min_lr
        else:
            # Here warmup_steps < step < decay_steps, so the denominator is
            # positive and the ratio lies in (0, 1).
            decay_ratio = (step - self.warmup_steps) / (
                self.decay_steps - self.warmup_steps
            )
            assert 0 <= decay_ratio <= 1
            coff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
            return self.min_lr + coff * (self.lr - self.min_lr)

    def get_lr(self):
        """Learning rate for every param group at the current ``last_epoch``."""
        if not self._get_lr_called_within_step:
            warnings.warn(
                "To get the last learning rate computed by the scheduler, "
                "please use `get_last_lr()`.",
                UserWarning,
            )
        return [
            self._get_step_lr(self.last_epoch) for _ in self.optimizer.param_groups
        ]

    def _get_closed_form_lr(self):
        # The schedule depends only on the step index, so the closed form is
        # the same per-step formula applied once per param group.
        return [self._get_step_lr(self.last_epoch) for _ in self.base_lrs]


# The AlphaFold3 learning-rate scheduler, as described in section 5.4.
class AlphaFold3LRScheduler(LRScheduler):
    """Linear warmup followed by step-wise exponential decay.

    The learning rate is a pure function of the step index (``last_epoch``):

    * ``0 < warmup_steps`` and ``step <= warmup_steps``: linear ramp
      ``step / warmup_steps * lr`` (0 at step 0, ``lr`` at the last warmup
      step).
    * otherwise: ``lr * decay_factor ** (step // decay_every_n_steps)``.
      The decay count is measured from step 0, not from the end of warmup.

    A ``warmup_steps`` of 0 (or less) disables warmup entirely.

    Every parameter group receives the same value; the groups' own initial
    learning rates (``base_lrs``) are not used.

    Args:
        optimizer: Optimizer whose param groups are updated.
        last_epoch: Index of the last step; ``-1`` starts a fresh schedule.
        verbose: Forwarded to ``LRScheduler`` only when truthy.
        warmup_steps: Length of the linear warmup phase.
        lr: Peak learning rate.
        decay_every_n_steps: Number of steps between decay events.
        decay_factor: Multiplicative factor applied at each decay event.
    """

    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        last_epoch: int = -1,
        verbose: bool = False,
        warmup_steps: int = 1000,
        lr: float = 1.8e-3,
        decay_every_n_steps: int = 50000,
        decay_factor: float = 0.95,
    ) -> None:
        self.warmup_steps = warmup_steps
        self.decay_steps = decay_every_n_steps
        self.lr = lr
        self.decay_factor = decay_factor
        # Newer PyTorch releases deprecate `verbose` and warn whenever it is
        # passed explicitly (even as False); it may also disappear from the
        # signature. Only forward it when the caller actually asked for it.
        if verbose:
            super(AlphaFold3LRScheduler, self).__init__(
                optimizer=optimizer, last_epoch=last_epoch, verbose=verbose
            )
        else:
            super(AlphaFold3LRScheduler, self).__init__(
                optimizer=optimizer, last_epoch=last_epoch
            )

    def _get_step_lr(self, step):
        """Return the learning rate for step index ``step``."""
        # Guard on warmup_steps > 0: with no warmup, step 0 would otherwise
        # evaluate 0 / 0. The scheduler's initial step in __init__ hits this.
        if self.warmup_steps > 0 and step <= self.warmup_steps:
            lr = step / self.warmup_steps * self.lr
        else:
            decay_count = step // self.decay_steps
            lr = self.lr * (self.decay_factor**decay_count)
        return lr

    def get_lr(self) -> list[float]:
        """Learning rate for every param group at the current ``last_epoch``."""
        if not self._get_lr_called_within_step:
            warnings.warn(
                "To get the last learning rate computed by the scheduler, "
                "please use `get_last_lr()`.",
                UserWarning,
            )
        return [
            self._get_step_lr(self.last_epoch) for _ in self.optimizer.param_groups
        ]

    def _get_closed_form_lr(self) -> list[float]:
        # The schedule depends only on the step index, so the closed form is
        # the same per-step formula applied once per param group.
        return [self._get_step_lr(self.last_epoch) for _ in self.base_lrs]


def get_lr_scheduler(
    configs, optimizer: torch.optim.Optimizer, **kwargs
) -> torch.optim.lr_scheduler.LRScheduler:
    """
    Build the learning-rate scheduler named by ``configs.lr_scheduler``.

    Recognised names:
        "af3": ``AlphaFold3LRScheduler`` built from the keyword mapping
            ``configs.af3_lr_scheduler``.
        "cosine_annealing": ``CosineAnnealingWithWarmup`` with
            ``configs.warmup_steps`` warmup steps, decay ending at
            ``configs.max_steps``, peak ``configs.lr`` and final learning rate
            ``configs.lr * configs.min_lr_ratio``.
        "constant": ``torch.optim.lr_scheduler.ConstantLR`` with
            ``factor=1.0`` and ``total_iters=configs.max_steps``.

    Args:
        configs: Configuration object carrying the scheduler settings above.
        optimizer (torch.optim.Optimizer): Optimizer the scheduler drives.
        **kwargs: Extra keyword arguments forwarded to the scheduler constructor.

    Returns:
        torch.optim.lr_scheduler.LRScheduler: The constructed scheduler.

    Raises:
        ValueError: If ``configs.lr_scheduler`` is not one of the names above.
    """
    scheduler_name = configs.lr_scheduler

    if scheduler_name == "af3":
        return AlphaFold3LRScheduler(
            optimizer, **configs.af3_lr_scheduler, **kwargs
        )

    if scheduler_name == "cosine_annealing":
        return CosineAnnealingWithWarmup(
            optimizer,
            configs.warmup_steps,
            configs.max_steps,
            configs.lr,
            configs.lr * configs.min_lr_ratio,
            **kwargs,
        )

    if scheduler_name == "constant":
        return torch.optim.lr_scheduler.ConstantLR(
            optimizer,
            factor=1.0,
            total_iters=configs.max_steps,
            **kwargs,
        )

    raise ValueError(f"Invalid lr scheduler: [{scheduler_name}]")