from typing import TYPE_CHECKING, Callable, List, Tuple, Any
from functools import reduce
import treetensor.torch as ttorch
import numpy as np
from ditk import logging
from ding.utils import EasyTimer
from ding.envs import BaseEnvManager
from ding.policy import Policy
from ding.torch_utils import to_ndarray, get_shape0

if TYPE_CHECKING:
    from ding.framework import OnlineRLContext


class TransitionList:
    """
    Overview:
        Cache transitions per environment id and convert them into flat trajectories or
        complete episodes for the training process.
    """

    def __init__(self, env_num: int) -> None:
        self.env_num = env_num
        self._transitions = [[] for _ in range(env_num)]
        self._done_idx = [[] for _ in range(env_num)]

    def append(self, env_id: int, transition: Any) -> None:
        self._transitions[env_id].append(transition)
        if transition.done:
            # Record the exclusive end index of the finished episode in this env's buffer.
            self._done_idx[env_id].append(len(self._transitions[env_id]))

    def to_trajectories(self) -> Tuple[List[Any], List[int]]:
        trajectories = sum(self._transitions, [])
        lengths = [len(t) for t in self._transitions]
        # Cumulative lengths (minus one) give the inclusive index of each env's last
        # transition inside the flattened trajectory list.
        trajectory_end_idx = [reduce(lambda x, y: x + y, lengths[:i + 1]) for i in range(len(lengths))]
        trajectory_end_idx = [t - 1 for t in trajectory_end_idx]
        return trajectories, trajectory_end_idx
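
    # Worked example of the index arithmetic above (hypothetical sizes, for illustration only):
    # with two envs holding 3 and 2 transitions respectively, to_trajectories() returns a
    # flattened list of 5 transitions, lengths == [3, 2], the cumulative sums are [3, 5], and
    # trajectory_end_idx == [2, 4], i.e. the inclusive position of each env's last transition
    # in the flattened list.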

    def to_episodes(self) -> List[List[Any]]:
        episodes = []
        for env_id in range(self.env_num):
            last_idx = 0
            for done_idx in self._done_idx[env_id]:
                episodes.append(self._transitions[env_id][last_idx:done_idx])
                last_idx = done_idx
        return episodes

    def clear(self):
        for item in self._transitions:
            item.clear()
        for item in self._done_idx:
            item.clear()
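

# A minimal usage sketch of TransitionList (the `transition_*` objects are hypothetical
# placeholders; in the real pipeline they are produced by `policy.process_transition`
# inside `rolloutor` below and carry at least a `done` flag):
#
#     transitions = TransitionList(env_num=2)
#     transitions.append(0, transition_a)      # transition_a.done is False
#     transitions.append(0, transition_b)      # transition_b.done is True -> episode ends
#     episodes = transitions.to_episodes()     # [[transition_a, transition_b]]
#     samples, end_idx = transitions.to_trajectories()
#     transitions.clear()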


def inferencer(seed: int, policy: Policy, env: BaseEnvManager) -> Callable:
    """
    Overview:
        The middleware that executes the inference process.
    Arguments:
        - seed (:obj:`int`): Random seed.
        - policy (:obj:`Policy`): The policy used to perform inference.
        - env (:obj:`BaseEnvManager`): The env where the inference process is performed. \
            The env.ready_obs (:obj:`tnp.array`) will be used as model input.
    """
    env.seed(seed)

    def _inference(ctx: "OnlineRLContext"):
        """
        Output of ctx:
            - obs (:obj:`Union[torch.Tensor, Dict[torch.Tensor]]`): The input observations collected \
                from all collector environments.
            - action (:obj:`List[np.ndarray]`): The inferred actions listed by env_id.
            - inference_output (:obj:`Dict[int, Dict]`): The dict whose key is env_id (int) and \
                whose value is the corresponding inference result (Dict).
        """

        if env.closed:
            env.launch()
        obs = ttorch.as_tensor(env.ready_obs)
        ctx.obs = obs
        obs = obs.to(dtype=ttorch.float32)
        # TODO mask necessary rollout
        # Split the batched observation into a dict keyed by env index, as expected by policy.forward.
        obs = {i: obs[i] for i in range(get_shape0(obs))}  # TBD
        inference_output = policy.forward(obs, **ctx.collect_kwargs)
        ctx.action = [to_ndarray(v['action']) for v in inference_output.values()]  # TBD
        ctx.inference_output = inference_output

    return _inference
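

# A minimal wiring sketch for inferencer (assumptions: the surrounding ding.framework task
# pipeline; `cfg`, `policy` and `collector_env` are placeholders built elsewhere):
#
#     from ding.framework import task, OnlineRLContext
#
#     with task.start(ctx=OnlineRLContext()):
#         task.use(inferencer(cfg.seed, policy, collector_env))
#         # ... followed by rolloutor and the training middleware
#         task.run()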


def rolloutor(
        policy: Policy,
        env: BaseEnvManager,
        transitions: TransitionList,
        collect_print_freq: int = 100,
) -> Callable:
    """
    Overview:
        The middleware that executes the transition process in the env.
    Arguments:
        - policy (:obj:`Policy`): The policy to be used during transition.
        - env (:obj:`BaseEnvManager`): The env for the collection, the BaseEnvManager object or \
            its derivatives are supported.
        - transitions (:obj:`TransitionList`): The transition information which will be filled \
            in this process, including `obs`, `next_obs`, `action`, `logit`, `value`, `reward` \
            and `done`.
        - collect_print_freq (:obj:`int`): The interval (in train iterations) at which the \
            collection statistics are logged.
    """
    env_episode_id = list(range(env.env_num))
    current_id = env.env_num
    timer = EasyTimer()
    last_train_iter = 0
    total_envstep_count = 0
    total_episode_count = 0
    total_train_sample_count = 0
    env_info = {env_id: {'time': 0., 'step': 0, 'train_sample': 0} for env_id in range(env.env_num)}
    episode_info = []

    def _rollout(ctx: "OnlineRLContext"):
        """
        Input of ctx:
            - action (:obj:`List[np.ndarray]`): The inferred actions from previous inference process.
            - obs (:obj:`Dict[Tensor]`): The states fed into the transition dict.
            - inference_output (:obj:`Dict[int, Dict]`): The inference results to be fed into the \
                transition dict.
            - train_iter (:obj:`int`): The train iteration count to be fed into the transition dict.
            - env_step (:obj:`int`): The count of env step, which will increase by 1 for a single \
                transition call.
            - env_episode (:obj:`int`): The count of env episode, which will increase by 1 if the \
                trajectory stops.
        """

        nonlocal current_id, env_info, episode_info, timer, \
            total_episode_count, total_envstep_count, total_train_sample_count, last_train_iter
        timesteps = env.step(ctx.action)
        ctx.env_step += len(timesteps)
        timesteps = [t.tensor() for t in timesteps]

        collected_sample = 0
        collected_step = 0
        collected_episode = 0
        interaction_duration = timer.value / len(timesteps)
        for i, timestep in enumerate(timesteps):
            with timer:
                transition = policy.process_transition(ctx.obs[i], ctx.inference_output[i], timestep)
                transition = ttorch.as_tensor(transition)
                transition.collect_train_iter = ttorch.as_tensor([ctx.train_iter])
                transition.env_data_id = ttorch.as_tensor([env_episode_id[timestep.env_id]])
                transitions.append(timestep.env_id, transition)
                collected_step += 1
                collected_sample += len(transition.obs)
                env_info[timestep.env_id.item()]['step'] += 1
                env_info[timestep.env_id.item()]['train_sample'] += len(transition.obs)

            env_info[timestep.env_id.item()]['time'] += timer.value + interaction_duration

            # Episode finished: record its statistics and reset the policy state for this env.
            if timestep.done:
                info = {
                    'reward': timestep.info['eval_episode_return'],
                    'time': env_info[timestep.env_id.item()]['time'],
                    'step': env_info[timestep.env_id.item()]['step'],
                    'train_sample': env_info[timestep.env_id.item()]['train_sample'],
                }
                episode_info.append(info)
                policy.reset([timestep.env_id.item()])
                env_episode_id[timestep.env_id.item()] = current_id
                collected_episode += 1
                current_id += 1
                ctx.env_episode += 1

        total_envstep_count += collected_step
        total_episode_count += collected_episode
        total_train_sample_count += collected_sample

        # Periodically flush the accumulated episode statistics to the log.
        if (ctx.train_iter - last_train_iter) >= collect_print_freq and len(episode_info) > 0:
            output_log(episode_info, total_episode_count, total_envstep_count, total_train_sample_count)
            last_train_iter = ctx.train_iter

    return _rollout
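

# A minimal end-to-end sketch for one collection stage (hedged: this mirrors how collector
# middleware such as StepCollector is usually assembled; `cfg`, `policy` and `collector_env`
# are placeholders assumed to be built elsewhere):
#
#     transitions = TransitionList(collector_env.env_num)
#     task.use(inferencer(cfg.seed, policy, collector_env))
#     task.use(rolloutor(policy, collector_env, transitions))
#     # After the task has run, the buffered data can be fetched with
#     # transitions.to_trajectories() (sample-based) or transitions.to_episodes()
#     # (episode-based), and transitions.clear() resets the buffer for the next round.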


def output_log(episode_info, total_episode_count, total_envstep_count, total_train_sample_count) -> None:
    """
    Overview:
        Print the output log information. You can refer to the docs of `Best Practice` to understand \
        the training generated logs and tensorboards.
    Arguments:
        - episode_info (:obj:`List[Dict]`): The statistics of each episode finished since the last \
            log, including `reward`, `time`, `step` and `train_sample`. It is cleared after logging.
        - total_episode_count (:obj:`int`): The total number of collected episodes so far.
        - total_envstep_count (:obj:`int`): The total number of env steps so far.
        - total_train_sample_count (:obj:`int`): The total number of train samples so far.
    """
    episode_count = len(episode_info)
    envstep_count = sum([d['step'] for d in episode_info])
    train_sample_count = sum([d['train_sample'] for d in episode_info])
    duration = sum([d['time'] for d in episode_info])
    episode_return = [d['reward'].item() for d in episode_info]
    info = {
        'episode_count': episode_count,
        'envstep_count': envstep_count,
        'train_sample_count': train_sample_count,
        'avg_envstep_per_episode': envstep_count / episode_count,
        'avg_sample_per_episode': train_sample_count / episode_count,
        'avg_envstep_per_sec': envstep_count / duration,
        'avg_train_sample_per_sec': train_sample_count / duration,
        'avg_episode_per_sec': episode_count / duration,
        'reward_mean': np.mean(episode_return),
        'reward_std': np.std(episode_return),
        'reward_max': np.max(episode_return),
        'reward_min': np.min(episode_return),
        'total_envstep_count': total_envstep_count,
        'total_train_sample_count': total_train_sample_count,
        'total_episode_count': total_episode_count,
        # 'each_reward': episode_return,
    }
    episode_info.clear()
    logging.info("collect end:\n{}".format('\n'.join(['{}: {}'.format(k, v) for k, v in info.items()])))
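

# Worked example of the aggregate statistics above (hypothetical numbers, for illustration only):
# with episode_info == [{'reward': tensor(1.), 'time': 2.0, 'step': 10, 'train_sample': 10},
#                       {'reward': tensor(3.), 'time': 4.0, 'step': 20, 'train_sample': 20}]
# the log reports episode_count=2, envstep_count=30, avg_envstep_per_episode=15.0,
# avg_envstep_per_sec=5.0 (30 steps / 6.0 s) and reward_mean=2.0.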