File size: 1,148 Bytes
9b19c29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
from abc import ABC, abstractmethod

import numpy as np
import torch
from torch.optim.lr_scheduler import LambdaLR, LRScheduler

from tianshou.highlevel.config import SamplingConfig
from tianshou.utils.string import ToStringMixin


class LRSchedulerFactory(ToStringMixin, ABC):
    """Abstract base for factories that build a learning rate scheduler for an optimizer."""

    @abstractmethod
    def create_scheduler(self, optim: torch.optim.Optimizer) -> LRScheduler:
        """Create a learning rate scheduler for the given optimizer.

        :param optim: the optimizer for which the scheduler is to be created
        :return: the learning rate scheduler
        """


class LRSchedulerFactoryLinear(LRSchedulerFactory):
    """Factory for a :class:`LambdaLR` scheduler with a linearly decaying factor.

    The multiplicative factor starts at 1 and decreases linearly, reaching 0 after
    ``ceil(step_per_epoch / step_per_collect) * num_epochs`` scheduler steps, where
    the three quantities are read from the sampling configuration.
    """

    def __init__(self, sampling_config: SamplingConfig):
        """
        :param sampling_config: the configuration providing ``step_per_epoch``,
            ``step_per_collect`` and ``num_epochs``, which determine the length of the decay
        """
        self.sampling_config = sampling_config

    def create_scheduler(self, optim: torch.optim.Optimizer) -> LRScheduler:
        """Create a :class:`LambdaLR` scheduler applying the linear decay to the given optimizer.

        :param optim: the optimizer whose learning rate is to be scheduled
        :return: the scheduler
        """
        return LambdaLR(optim, lr_lambda=self._LRLambda(self.sampling_config).compute)

    class _LRLambda:
        """Callable holder computing the linear decay factor passed to :class:`LambdaLR`."""

        def __init__(self, sampling_config: SamplingConfig):
            """
            :param sampling_config: the configuration from which the number of
                scheduler steps over which the factor decays to zero is derived
            """
            # Number of steps after which the factor reaches zero.
            # NOTE(review): presumably the scheduler is stepped once per collect step,
            # making this the total number of updates - confirm against the trainer.
            self.max_update_num = (
                np.ceil(sampling_config.step_per_epoch / sampling_config.step_per_collect)
                * sampling_config.num_epochs
            )

        def compute(self, epoch: int) -> float:
            """Compute the multiplicative learning rate factor for the given step count.

            :param epoch: the step counter passed in by the scheduler
            :return: ``1 - epoch / max_update_num``, clamped from below at 0
            """
            factor = 1.0 - epoch / self.max_update_num
            # Clamp at zero: if the scheduler is stepped more often than max_update_num,
            # the unclamped factor would turn negative and yield a negative learning rate.
            # The float() conversion yields a builtin float rather than a NumPy scalar.
            return max(0.0, float(factor))