from collections.abc import Iterator
from typing import TypeVar

import torch
from torch import nn


def optim_step(
    loss: torch.Tensor,
    optim: torch.optim.Optimizer,
    module: nn.Module | None = None,
    max_grad_norm: float | None = None,
) -> None:
    """Perform a single optimization step: zero_grad -> backward (-> clip_grad_norm) -> step.

    :param loss: the scalar loss tensor on which ``backward()`` is called
    :param optim: the optimizer whose ``step()`` is performed
    :param module: the module whose gradients are clipped; required if max_grad_norm is passed
    :param max_grad_norm: if passed, clip the gradient norm of the module's parameters to this value before the step
    """
    optim.zero_grad()
    loss.backward()
    if max_grad_norm is not None:
        if module is None:
            raise ValueError(
                "module must be passed if max_grad_norm is passed. "
                "Note: often the module will be the policy, i.e. `self`.",
            )
        nn.utils.clip_grad_norm_(module.parameters(), max_norm=max_grad_norm)
    optim.step()
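

# Illustrative usage sketch (an assumption added for documentation, not part of the
# original module): a toy module and placeholder loss showing the intended call pattern
# of `optim_step`, including gradient clipping.
def _example_optim_step() -> None:
    policy = nn.Linear(4, 2)  # hypothetical policy network
    optim = torch.optim.Adam(policy.parameters(), lr=1e-3)
    loss = policy(torch.randn(8, 4)).pow(2).mean()  # placeholder scalar loss
    optim_step(loss, optim, module=policy, max_grad_norm=0.5)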


_STANDARD_TORCH_OPTIMIZERS = [
    torch.optim.Adam,
    torch.optim.SGD,
    torch.optim.RMSprop,
    torch.optim.Adadelta,
    torch.optim.AdamW,
    torch.optim.Adamax,
    torch.optim.NAdam,
    torch.optim.SparseAdam,
    torch.optim.LBFGS,
]

TOptim = TypeVar("TOptim", bound=torch.optim.Optimizer)


def clone_optimizer(
    optim: TOptim,
    new_params: nn.Parameter | Iterator[nn.Parameter],
) -> TOptim:
    """Clone an optimizer to get a new optim instance with new parameters.

    **WARNING**: This is a temporary measure, and should not be used in downstream code!
    Once tianshou interfaces have moved to optimizer factories instead of optimizers,
    this will be removed.

    :param optim: the optimizer to clone
    :param new_params: the new parameters to use
    :return: a new optimizer with the same configuration as the old one
    """
    optim_class = type(optim)
    # custom optimizers may not behave as expected
    if optim_class not in _STANDARD_TORCH_OPTIMIZERS:
        raise ValueError(
            f"Cannot clone optimizer {optim} of type {optim_class}"
            f"Currently, only standard torch optimizers are supported.",
        )
    return optim_class(new_params, **optim.defaults)
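

# Illustrative usage sketch (an assumption added for documentation, not part of the
# original module): cloning an Adam optimizer so that it keeps its hyperparameters but
# tracks the parameters of a freshly constructed module.
def _example_clone_optimizer() -> None:
    old_net = nn.Linear(4, 2)
    old_optim = torch.optim.Adam(old_net.parameters(), lr=3e-4, weight_decay=1e-2)
    new_net = nn.Linear(4, 2)
    new_optim = clone_optimizer(old_optim, new_net.parameters())
    # The clone shares the configuration (lr, betas, weight_decay, ...) but not the state.
    assert new_optim.defaults["lr"] == old_optim.defaults["lr"]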