import logging
from typing import Optional, Callable

from .base import BaseStrategy
from .utils import dry_run_for_search_space
from ..execution import query_available_resources

try:
    has_tianshou = True
    import torch
    from tianshou.data import Collector, VectorReplayBuffer
    from tianshou.env import BaseVectorEnv
    from tianshou.policy import BasePolicy, PPOPolicy  # pylint: disable=unused-import
    from ._rl_impl import ModelEvaluationEnv, MultiThreadEnvWorker, Preprocessor, Actor, Critic
except ImportError:
    has_tianshou = False


_logger = logging.getLogger(__name__)


class PolicyBasedRL(BaseStrategy):
    """
    Algorithm for policy-based reinforcement learning.
    This is a wrapper around algorithms provided in tianshou (PPO by default),
    and can be easily customized with other algorithms that inherit ``BasePolicy`` (e.g., REINFORCE [1]_).

    Parameters
    ----------
    max_collect : int
        How many times the collector runs to collect trials for RL. Default: 100.
    trial_per_collect : int
        How many trials (trajectories) the collector collects each time it runs.
        After each collect, the trainer samples a batch from the replay buffer and performs the update. Default: 20.
    policy_fn : function
        Takes a ``ModelEvaluationEnv`` as input and returns a policy. See ``_default_policy_fn`` for an example.

    References
    ----------

    .. [1] Barret Zoph and Quoc V. Le, "Neural Architecture Search with Reinforcement Learning".
        https://arxiv.org/abs/1611.01578
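
    Examples
    --------
    A minimal sketch of supplying a custom policy function; it mirrors
    ``_default_policy_fn`` but with a different learning rate (``my_policy_fn``
    is an illustrative name, reusing the helper networks from ``._rl_impl``)::

        def my_policy_fn(env):
            net = Preprocessor(env.observation_space)
            actor = Actor(env.action_space, net)
            critic = Critic(net)
            optim = torch.optim.Adam(set(actor.parameters()).union(critic.parameters()), lr=1e-3)
            return PPOPolicy(actor, critic, optim, torch.distributions.Categorical,
                             discount_factor=1., action_space=env.action_space)

        strategy = PolicyBasedRL(max_collect=50, trial_per_collect=10, policy_fn=my_policy_fn)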
    """

    def __init__(self, max_collect: int = 100, trial_per_collect: int = 20,
                 policy_fn: Optional[Callable[['ModelEvaluationEnv'], 'BasePolicy']] = None):
        if not has_tianshou:
            raise ImportError('`tianshou` is required to run RL-based strategy. '
                              'Please use "pip install tianshou" to install it beforehand.')

        self.policy_fn = policy_fn or self._default_policy_fn
        self.max_collect = max_collect
        self.trial_per_collect = trial_per_collect

    @staticmethod
    def _default_policy_fn(env):
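        # Actor and critic share an observation preprocessor; their parameters are optimized jointly.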
        net = Preprocessor(env.observation_space)
        actor = Actor(env.action_space, net)
        critic = Critic(net)
        optim = torch.optim.Adam(set(actor.parameters()).union(critic.parameters()), lr=1e-4)
        return PPOPolicy(actor, critic, optim, torch.distributions.Categorical,
                         discount_factor=1., action_space=env.action_space)

    def run(self, base_model, applied_mutators):
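        # Dry-run the mutators once to enumerate the search space (no trial is actually submitted),
        # and size the environment pool according to currently available training resources.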
        search_space = dry_run_for_search_space(base_model, applied_mutators)
        concurrency = query_available_resources()

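        # A probe environment instance is created only to construct the policy;
        # the same factory is reused for the vectorized environments below.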
        env_fn = lambda: ModelEvaluationEnv(base_model, applied_mutators, search_space)
        policy = self.policy_fn(env_fn())

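        # Run `concurrency` environments in parallel threads and let the collector
        # gather trajectories from all of them into a shared replay buffer.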
        env = BaseVectorEnv([env_fn for _ in range(concurrency)], MultiThreadEnvWorker)
        collector = Collector(policy, env, VectorReplayBuffer(20000, len(env)))

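        # Alternate between collecting `trial_per_collect` episodes and updating
        # the policy on the data accumulated in the buffer.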
        for cur_collect in range(1, self.max_collect + 1):
            _logger.info('Collect [%d] Running...', cur_collect)
            result = collector.collect(n_episode=self.trial_per_collect)
            _logger.info('Collect [%d] Result: %s', cur_collect, str(result))
            policy.update(0, collector.buffer, batch_size=64, repeat=5)