import torch
import torch.nn.functional as F
from torch.distributions import Categorical, Independent, Normal
from collections import namedtuple

from .isw import compute_importance_weights
from ding.hpc_rl import hpc_wrapper


def vtrace_nstep_return(clipped_rhos, clipped_cs, reward, bootstrap_values, gamma=0.99, lambda_=0.95):
    r"""
    Overview:
        Computation of the v-trace return (IMPALA, arXiv:1802.01561) via the backward recursion
        :math:`v_s = V(x_s) + \delta_s + \gamma \lambda \bar{c}_s (v_{s+1} - V(x_{s+1}))`, with
        :math:`\delta_s = \bar{\rho}_s (r_s + \gamma V(x_{s+1}) - V(x_s))`.
    Returns:
        - vtrace_return (:obj:`torch.FloatTensor`): the v-trace return, i.e. the per-timestep value target
    Shapes:
        - clipped_rhos (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep, B is batch size
        - clipped_cs (:obj:`torch.FloatTensor`): :math:`(T, B)`
        - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`
        - bootstrap_values (:obj:`torch.FloatTensor`): :math:`(T+1, B)`
        - vtrace_return (:obj:`torch.FloatTensor`): :math:`(T, B)`
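    Examples:
        >>> # a minimal usage sketch with random inputs; shapes follow the table above
        >>> T, B = 4, 8
        >>> clipped_rhos = torch.rand(T, B)
        >>> clipped_cs = torch.rand(T, B)
        >>> reward = torch.rand(T, B)
        >>> bootstrap_values = torch.randn(T + 1, B)
        >>> return_ = vtrace_nstep_return(clipped_rhos, clipped_cs, reward, bootstrap_values)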
""" | |
deltas = clipped_rhos * (reward + gamma * bootstrap_values[1:] - bootstrap_values[:-1]) | |
factor = gamma * lambda_ | |
result = bootstrap_values[:-1].clone() | |
vtrace_item = 0. | |
for t in reversed(range(reward.size()[0])): | |
vtrace_item = deltas[t] + factor * clipped_cs[t] * vtrace_item | |
result[t] += vtrace_item | |
return result | |
def vtrace_advantage(clipped_pg_rhos, reward, return_, bootstrap_values, gamma):
    r"""
    Overview:
        Computation of the v-trace advantage used by the policy gradient:
        :math:`A_s = \bar{\rho}_s (r_s + \gamma v_{s+1} - V(x_s))`, where ``return_`` is the next-step
        v-trace return :math:`v_{s+1}` and ``bootstrap_values`` is :math:`V(x_s)`.
    Returns:
        - vtrace_advantage (:obj:`torch.FloatTensor`): the v-trace advantage
    Shapes:
        - clipped_pg_rhos (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep, B is batch size
        - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`
        - return_ (:obj:`torch.FloatTensor`): :math:`(T, B)`
        - bootstrap_values (:obj:`torch.FloatTensor`): :math:`(T, B)`
        - vtrace_advantage (:obj:`torch.FloatTensor`): :math:`(T, B)`
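    Examples:
        >>> # a minimal usage sketch with random inputs; shapes follow the table above
        >>> T, B = 4, 8
        >>> clipped_pg_rhos = torch.rand(T, B)
        >>> reward = torch.rand(T, B)
        >>> return_ = torch.randn(T, B)
        >>> bootstrap_values = torch.randn(T, B)
        >>> adv = vtrace_advantage(clipped_pg_rhos, reward, return_, bootstrap_values, gamma=0.99)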
""" | |
return clipped_pg_rhos * (reward + gamma * return_ - bootstrap_values) | |
vtrace_data = namedtuple('vtrace_data', ['target_output', 'behaviour_output', 'action', 'value', 'reward', 'weight'])
vtrace_loss = namedtuple('vtrace_loss', ['policy_loss', 'value_loss', 'entropy_loss'])


def shape_fn_vtrace_discrete_action(args, kwargs):
    r"""
    Overview:
        Return the shape of the vtrace input for hpc
    Returns:
        shape: [T, B, N]
    """
    if len(args) <= 0:
        tmp = kwargs['data'].target_output.shape
    else:
        tmp = args[0].target_output.shape
    return tmp
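

# NOTE: in the upstream repo this function is wrapped for an optional HPC fast path
# (``hpc_wrapper`` is imported above and ``shape_fn_vtrace_discrete_action`` exists only for
# that purpose); the decorator arguments below are a best-effort reconstruction, not verified.
@hpc_wrapper(
    shape_fn=shape_fn_vtrace_discrete_action,
    namedtuple_data=True,
    include_args=[0, 1, 2, 3, 4, 5],
    include_kwargs=['data', 'gamma', 'lambda_', 'rho_clip_ratio', 'c_clip_ratio', 'rho_pg_clip_ratio']
)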
def vtrace_error_discrete_action(
        data: namedtuple,
        gamma: float = 0.99,
        lambda_: float = 0.95,
        rho_clip_ratio: float = 1.0,
        c_clip_ratio: float = 1.0,
        rho_pg_clip_ratio: float = 1.0
):
    """
    Overview:
        Implementation of vtrace (IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner\
        Architectures, arXiv:1802.01561) for discrete action spaces
    Arguments:
        - data (:obj:`namedtuple`): input data with fields shown in ``vtrace_data``
            - target_output (:obj:`torch.Tensor`): the output taking the action by the current policy network,\
                usually this output is the network output logit
            - behaviour_output (:obj:`torch.Tensor`): the output taking the action by the behaviour policy network,\
                usually this output is the network output logit, which is used to produce the trajectory(collector)
            - action (:obj:`torch.Tensor`): the chosen action (index for the discrete action space) in trajectory,\
                i.e.: behaviour_action
            - value (:obj:`torch.Tensor`): the value predicted by the current policy network, including\
                the bootstrap value at the last timestep
            - reward (:obj:`torch.Tensor`): the reward collected in the trajectory
            - weight (:obj:`torch.Tensor`): the per-sample weight (e.g. a mask), or ``None`` for uniform weight
        - gamma (:obj:`float`): the future discount factor, defaults to 0.99
        - lambda_ (:obj:`float`): mix factor between 1-step (lambda_=0) and n-step, defaults to 0.95
        - rho_clip_ratio (:obj:`float`): the clipping threshold for importance weights (rho) when calculating\
            the baseline targets (vs)
        - c_clip_ratio (:obj:`float`): the clipping threshold for importance weights (c) when calculating\
            the baseline targets (vs)
        - rho_pg_clip_ratio (:obj:`float`): the clipping threshold for importance weights (rho) when calculating\
            the policy gradient advantage
    Returns:
        - vtrace_loss (:obj:`namedtuple`): the vtrace loss item, all of them are differentiable 0-dim tensors
    Shapes:
        - target_output (:obj:`torch.FloatTensor`): :math:`(T, B, N)`, where T is timestep, B is batch size and\
            N is action dim
        - behaviour_output (:obj:`torch.FloatTensor`): :math:`(T, B, N)`
        - action (:obj:`torch.LongTensor`): :math:`(T, B)`
        - value (:obj:`torch.FloatTensor`): :math:`(T+1, B)`
        - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`
        - weight (:obj:`torch.FloatTensor`): :math:`(T, B)`
    Examples:
        >>> T, B, N = 4, 8, 16
        >>> value = torch.randn(T + 1, B).requires_grad_(True)
        >>> reward = torch.rand(T, B)
        >>> target_output = torch.randn(T, B, N).requires_grad_(True)
        >>> behaviour_output = torch.randn(T, B, N)
        >>> action = torch.randint(0, N, size=(T, B))
        >>> data = vtrace_data(target_output, behaviour_output, action, value, reward, None)
        >>> loss = vtrace_error_discrete_action(data, rho_clip_ratio=1.1)
    """
    target_output, behaviour_output, action, value, reward, weight = data
    with torch.no_grad():
        # importance sampling ratio between the target and behaviour policies
        IS = compute_importance_weights(target_output, behaviour_output, action, 'discrete')
        rhos = torch.clamp(IS, max=rho_clip_ratio)
        cs = torch.clamp(IS, max=c_clip_ratio)
        return_ = vtrace_nstep_return(rhos, cs, reward, value, gamma, lambda_)
        pg_rhos = torch.clamp(IS, max=rho_pg_clip_ratio)
        return_t_plus_1 = torch.cat([return_[1:], value[-1:]], 0)
        adv = vtrace_advantage(pg_rhos, reward, return_t_plus_1, value[:-1], gamma)
    if weight is None:
        weight = torch.ones_like(reward)
    dist_target = Categorical(logits=target_output)
    pg_loss = -(dist_target.log_prob(action) * adv * weight).mean()
    value_loss = (F.mse_loss(value[:-1], return_, reduction='none') * weight).mean()
    entropy_loss = (dist_target.entropy() * weight).mean()
    return vtrace_loss(pg_loss, value_loss, entropy_loss)


def vtrace_error_continuous_action(
        data: namedtuple,
        gamma: float = 0.99,
        lambda_: float = 0.95,
        rho_clip_ratio: float = 1.0,
        c_clip_ratio: float = 1.0,
        rho_pg_clip_ratio: float = 1.0
):
    """
    Overview:
        Implementation of vtrace (IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner\
        Architectures, arXiv:1802.01561) for continuous action spaces
    Arguments:
        - data (:obj:`namedtuple`): input data with fields shown in ``vtrace_data``
            - target_output (:obj:`dict{key: torch.Tensor}`): the output taking the action \
                by the current policy network, usually the distribution parameters \
                (e.g. mu and sigma) used with the reparameterization trick
            - behaviour_output (:obj:`dict{key: torch.Tensor}`): the output taking the action \
                by the behaviour policy network, in the same parameter format, \
                which is used to produce the trajectory(collector)
            - action (:obj:`torch.Tensor`): the chosen action (continuous action vector) in trajectory, \
                i.e.: behaviour_action
        - gamma (:obj:`float`): the future discount factor, defaults to 0.99
        - lambda_ (:obj:`float`): mix factor between 1-step (lambda_=0) and n-step, defaults to 0.95
        - rho_clip_ratio (:obj:`float`): the clipping threshold for importance weights (rho) when calculating\
            the baseline targets (vs)
        - c_clip_ratio (:obj:`float`): the clipping threshold for importance weights (c) when calculating\
            the baseline targets (vs)
        - rho_pg_clip_ratio (:obj:`float`): the clipping threshold for importance weights (rho) when calculating\
            the policy gradient advantage
    Returns:
        - vtrace_loss (:obj:`namedtuple`): the vtrace loss item, all of them are differentiable 0-dim tensors
    Shapes:
        - target_output (:obj:`dict{key: torch.FloatTensor}`): :math:`(T, B, N)`, \
            where T is timestep, B is batch size and \
            N is action dim. The keys are usually parameters of the reparameterization trick.
        - behaviour_output (:obj:`dict{key: torch.FloatTensor}`): :math:`(T, B, N)`
        - action (:obj:`torch.FloatTensor`): :math:`(T, B, N)`
        - value (:obj:`torch.FloatTensor`): :math:`(T+1, B)`
        - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`
        - weight (:obj:`torch.FloatTensor`): :math:`(T, B)`
    Examples:
        >>> T, B, N = 4, 8, 16
        >>> value = torch.randn(T + 1, B).requires_grad_(True)
        >>> reward = torch.rand(T, B)
        >>> target_output = dict(
        >>>     mu=torch.randn(T, B, N).requires_grad_(True),
        >>>     sigma=torch.exp(torch.randn(T, B, N).requires_grad_(True)),
        >>> )
        >>> behaviour_output = dict(
        >>>     mu=torch.randn(T, B, N),
        >>>     sigma=torch.exp(torch.randn(T, B, N)),
        >>> )
        >>> action = torch.randn((T, B, N))
        >>> data = vtrace_data(target_output, behaviour_output, action, value, reward, None)
        >>> loss = vtrace_error_continuous_action(data, rho_clip_ratio=1.1)
    """
    target_output, behaviour_output, action, value, reward, weight = data
    with torch.no_grad():
        # importance sampling ratio between the target and behaviour policies
        IS = compute_importance_weights(target_output, behaviour_output, action, 'continuous')
        rhos = torch.clamp(IS, max=rho_clip_ratio)
        cs = torch.clamp(IS, max=c_clip_ratio)
        return_ = vtrace_nstep_return(rhos, cs, reward, value, gamma, lambda_)
        pg_rhos = torch.clamp(IS, max=rho_pg_clip_ratio)
        return_t_plus_1 = torch.cat([return_[1:], value[-1:]], 0)
        adv = vtrace_advantage(pg_rhos, reward, return_t_plus_1, value[:-1], gamma)
    if weight is None:
        weight = torch.ones_like(reward)
    # diagonal Gaussian over the N action dims, so log_prob reduces over the last dim
    dist_target = Independent(Normal(loc=target_output['mu'], scale=target_output['sigma']), 1)
    pg_loss = -(dist_target.log_prob(action) * adv * weight).mean()
    value_loss = (F.mse_loss(value[:-1], return_, reduction='none') * weight).mean()
    entropy_loss = (dist_target.entropy() * weight).mean()
    return vtrace_loss(pg_loss, value_loss, entropy_loss)
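

if __name__ == "__main__":
    # Minimal smoke test (an illustrative sketch, not part of the upstream module): run both
    # error functions on random data and check that the combined losses backpropagate.
    T, B, N = 4, 8, 16
    value = torch.randn(T + 1, B).requires_grad_(True)
    reward = torch.rand(T, B)
    # discrete-action case: logits for both policies
    data = vtrace_data(
        torch.randn(T, B, N).requires_grad_(True), torch.randn(T, B, N), torch.randint(0, N, size=(T, B)), value,
        reward, None
    )
    loss = vtrace_error_discrete_action(data, rho_clip_ratio=1.1)
    (loss.policy_loss + loss.value_loss + loss.entropy_loss).backward()
    # continuous-action case: diagonal Gaussian parameters for both policies
    data = vtrace_data(
        dict(mu=torch.randn(T, B, N).requires_grad_(True), sigma=torch.exp(torch.randn(T, B, N))),
        dict(mu=torch.randn(T, B, N), sigma=torch.exp(torch.randn(T, B, N))), torch.randn(T, B, N), value, reward,
        None
    )
    loss = vtrace_error_continuous_action(data, rho_clip_ratio=1.1)
    (loss.policy_loss + loss.value_loss + loss.entropy_loss).backward()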