# Spaces: Sleeping  (banner text that appears to be page-extraction residue, not module content)
from typing import Tuple, List
from collections import namedtuple

import torch
import torch.nn.functional as F
EPS = 1e-8  # added to ``ratio`` before dividing, to avoid division by zero


def acer_policy_error(
        q_values: torch.Tensor,
        q_retraces: torch.Tensor,
        v_pred: torch.Tensor,
        target_logit: torch.Tensor,
        actions: torch.Tensor,
        ratio: torch.Tensor,
        c_clip_ratio: float = 10.0
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Overview:
        Get ACER policy loss: the truncated importance-weighted term for the taken action plus the \
        bias correction term summed over all actions. Neither term is negated or reduced here; \
        that is left to the caller.
    Arguments:
        - q_values (:obj:`torch.Tensor`): Q values
        - q_retraces (:obj:`torch.Tensor`): Q values (be calculated by retrace method)
        - v_pred (:obj:`torch.Tensor`): V values
        - target_logit (:obj:`torch.Tensor`): The new policy's log-probability (the code applies \
            ``torch.exp`` to it to recover the probability)
        - actions (:obj:`torch.Tensor`): The actions in replay buffer
        - ratio (:obj:`torch.Tensor`): ratio of new policy with behavior policy
        - c_clip_ratio (:obj:`float`): clip value for ratio
    Returns:
        - actor_loss (:obj:`torch.Tensor`): policy loss from q_retrace
        - bias_correction_loss (:obj:`torch.Tensor`): bias correction policy loss
    Shapes:
        - q_values (:obj:`torch.FloatTensor`): :math:`(T, B, N)`, where B is batch size and N is action dim
        - q_retraces (:obj:`torch.FloatTensor`): :math:`(T, B, 1)`
        - v_pred (:obj:`torch.FloatTensor`): :math:`(T, B, 1)`
        - target_logit (:obj:`torch.FloatTensor`): :math:`(T, B, N)`
        - actions (:obj:`torch.LongTensor`): :math:`(T, B)`
        - ratio (:obj:`torch.FloatTensor`): :math:`(T, B, N)`
        - actor_loss (:obj:`torch.FloatTensor`): :math:`(T, B, 1)`
        - bias_correction_loss (:obj:`torch.FloatTensor`): :math:`(T, B, 1)`
    Examples:
        >>> q_values = torch.randn(2, 3, 4)
        >>> q_retraces = torch.randn(2, 3, 1)
        >>> v_pred = torch.randn(2, 3, 1)
        >>> target_logit = torch.log_softmax(torch.randn(2, 3, 4), dim=-1)
        >>> actions = torch.randint(0, 4, (2, 3))
        >>> ratio = torch.rand(2, 3, 4) * 20
        >>> actor_loss, bias_correction_loss = acer_policy_error(
        ...     q_values, q_retraces, v_pred, target_logit, actions, ratio
        ... )
    """
    # (T, B) -> (T, B, 1) so the indices can be used with ``gather`` on the action dim.
    actions = actions.unsqueeze(-1)

    # Advantages act as fixed weights: no gradient flows into q_values / q_retraces / v_pred.
    with torch.no_grad():
        advantage_retraces = q_retraces - v_pred  # shape T,B,1
        advantage_native = q_values - v_pred  # shape T,B,env_action_shape

    # Truncated importance weight min(c, ratio) of the taken action, times the retrace
    # advantage, times the log-probability of the taken action.
    truncated_ratio = ratio.gather(-1, actions).clamp(max=c_clip_ratio)
    actor_loss = truncated_ratio * advantage_retraces * target_logit.gather(-1, actions)  # shape T,B,1

    # Bias correction term over all actions, weighted by max(0, 1 - c / ratio).
    # ``exp(target_logit)`` is detached, so the gradient only flows through the
    # trailing ``target_logit`` factor.
    correction_weight = (1.0 - c_clip_ratio / (ratio + EPS)).clamp(min=0.0)
    bias_correction_loss = (
        correction_weight * torch.exp(target_logit).detach() * advantage_native * target_logit
    )  # shape T,B,env_action_shape
    bias_correction_loss = bias_correction_loss.sum(-1, keepdim=True)  # shape T,B,1
    return actor_loss, bias_correction_loss
def acer_value_error(q_values: torch.Tensor, q_retraces: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
    """
    Overview:
        Get ACER critic loss: half the squared difference between the retrace target and the \
        Q value of the taken action, returned element-wise (no reduction).
    Arguments:
        - q_values (:obj:`torch.Tensor`): Q values
        - q_retraces (:obj:`torch.Tensor`): Q values (be calculated by retrace method)
        - actions (:obj:`torch.Tensor`): The actions in replay buffer
    Returns:
        - critic_loss (:obj:`torch.Tensor`): critic loss
    Shapes:
        - q_values (:obj:`torch.FloatTensor`): :math:`(T, B, N)`, where B is batch size and N is action dim
        - q_retraces (:obj:`torch.FloatTensor`): :math:`(T, B, 1)`
        - actions (:obj:`torch.LongTensor`): :math:`(T, B)`
        - critic_loss (:obj:`torch.FloatTensor`): :math:`(T, B, 1)`
    Examples:
        >>> q_values = torch.randn(2, 3, 4)
        >>> q_retraces = torch.randn(2, 3, 1)
        >>> actions = torch.randint(0, 4, (2, 3))
        >>> loss = acer_value_error(q_values, q_retraces, actions)
    """
    # (T, B) -> (T, B, 1) so the indices can select one Q value per step along the action dim.
    actions = actions.unsqueeze(-1)
    q_taken = q_values.gather(-1, actions)
    # No detach is applied here; whether ``q_retraces`` carries gradient is up to the caller.
    critic_loss = 0.5 * (q_retraces - q_taken).pow(2)
    return critic_loss
def acer_trust_region_update(
        actor_gradients: List[torch.Tensor], target_logit: torch.Tensor, avg_logit: torch.Tensor,
        trust_region_value: float
) -> List[torch.Tensor]:
    """
    Overview:
        Calculate gradient with trust region constraint. With ``g = actor_gradients[0]`` and \
        ``k = exp(avg_logit)``, the result is ``g - max(0, (k . g - trust_region_value) / (k . k)) * k``, \
        where the dot products are taken over the last dimension.
    Arguments:
        - actor_gradients (:obj:`list(torch.Tensor)`): gradients value's for different part; only the \
            first element is used
        - target_logit (:obj:`torch.Tensor`): The new policy's log-probability (currently unused by the \
            computation; kept in the signature for interface compatibility)
        - avg_logit (:obj:`torch.Tensor`): The average policy's log-probability (exponentiated internally)
        - trust_region_value (:obj:`float`): the range of trust region
    Returns:
        - update_gradients (:obj:`list(torch.Tensor)`): gradients with trust region constraint \
            (a list with a single element)
    Shapes:
        - target_logit (:obj:`torch.FloatTensor`): :math:`(T, B, N)`
        - avg_logit (:obj:`torch.FloatTensor`): :math:`(T, B, N)`
        - update_gradients (:obj:`list(torch.FloatTensor)`): :math:`(T, B, N)`
    Examples:
        >>> actor_gradients = [torch.randn(2, 3, 4)]
        >>> target_logit = torch.log_softmax(torch.randn(2, 3, 4), dim=-1)
        >>> avg_logit = torch.log_softmax(torch.randn(2, 3, 4), dim=-1)
        >>> update_gradients = acer_trust_region_update(actor_gradients, target_logit, avg_logit, 0.1)
    """
    # The direction used for the constraint is the average policy's probability; it is
    # computed without tracking gradient.
    with torch.no_grad():
        KL_gradients = [torch.exp(avg_logit)]
    update_gradients = []
    # TODO: here is only one elements in this list.Maybe will use to more elements in the future
    actor_gradient = actor_gradients[0]
    KL_gradient = KL_gradients[0]
    # Amount by which the projection of the actor gradient onto the KL direction exceeds the
    # trust region, normalised by the squared norm of that direction and floored at zero so the
    # gradient is left untouched when the constraint is already satisfied.
    scale = actor_gradient.mul(KL_gradient).sum(-1, keepdim=True) - trust_region_value
    scale = torch.div(scale, KL_gradient.mul(KL_gradient).sum(-1, keepdim=True)).clamp(min=0.0)
    update_gradients.append(actor_gradient - scale * KL_gradient)
    return update_gradients