Spaces:
Sleeping
Sleeping
from collections import namedtuple | |
import torch | |
import torch.nn.functional as F | |
from torch.distributions import Independent, Normal | |
# Inputs consumed by ``a2c_error`` / ``a2c_error_continuous``.
# ``weight`` may be None, in which case every sample is weighted equally.
a2c_data = namedtuple('a2c_data', ['logit', 'action', 'value', 'adv', 'return_', 'weight'])
# Outputs of the A2C loss functions: three differentiable 0-dim (mean-reduced) tensors.
a2c_loss = namedtuple('a2c_loss', ['policy_loss', 'value_loss', 'entropy_loss'])
def a2c_error(data: namedtuple) -> namedtuple:
    """
    Overview:
        Implementation of A2C (Advantage Actor-Critic) (arXiv:1602.01783) for discrete action space.
    Arguments:
        - data (:obj:`namedtuple`): a2c input data with fields shown in ``a2c_data``
    Returns:
        - a2c_loss (:obj:`namedtuple`): the a2c loss item, all of them are the differentiable 0-dim tensor
    Shapes:
        - logit (:obj:`torch.FloatTensor`): :math:`(B, N)`, where B is batch size and N is action dim
        - action (:obj:`torch.LongTensor`): :math:`(B, )`
        - value (:obj:`torch.FloatTensor`): :math:`(B, )`
        - adv (:obj:`torch.FloatTensor`): :math:`(B, )`
        - return_ (:obj:`torch.FloatTensor`): :math:`(B, )`
        - weight (:obj:`torch.FloatTensor` or :obj:`None`): :math:`(B, )`
        - policy_loss (:obj:`torch.FloatTensor`): :math:`()`, 0-dim tensor
        - value_loss (:obj:`torch.FloatTensor`): :math:`()`
        - entropy_loss (:obj:`torch.FloatTensor`): :math:`()`
    Notes:
        - ``entropy_loss`` is the weighted mean entropy itself (not negated); the caller decides
          how to combine it with the other terms (e.g. subtract it to encourage exploration).
        - ``adv`` is used as-is and is not detached here, so pass an advantage that carries no
          gradient if the value network must not be updated through the policy term.
        - ``weight=None`` weights every sample equally.
    Examples:
        >>> data = a2c_data(
        ...     logit=torch.randn(2, 3),
        ...     action=torch.randint(0, 3, (2, )),
        ...     value=torch.randn(2, ),
        ...     adv=torch.randn(2, ),
        ...     return_=torch.randn(2, ),
        ...     weight=torch.ones(2, ),
        ... )
        >>> loss = a2c_error(data)
    """
    logit, action, value, adv, return_, weight = data
    if weight is None:
        # No per-sample weights supplied: weight every sample equally.
        weight = torch.ones_like(value)
    dist = torch.distributions.Categorical(logits=logit)
    logp = dist.log_prob(action)
    # Positive (weighted) mean entropy; sign convention is left to the caller.
    entropy_loss = (dist.entropy() * weight).mean()
    # Policy-gradient surrogate: maximizing log pi(a|s) * A means minimizing its negation.
    policy_loss = -(logp * adv * weight).mean()
    # Per-sample squared error so the weights can be applied before averaging.
    value_loss = (F.mse_loss(return_, value, reduction='none') * weight).mean()
    return a2c_loss(policy_loss, value_loss, entropy_loss)
def a2c_error_continuous(data: namedtuple) -> namedtuple:
    """
    Overview:
        Implementation of A2C (Advantage Actor-Critic) (arXiv:1602.01783) for continuous action space.
        The policy is a diagonal Gaussian: ``Independent(Normal(mu, sigma), 1)``, so log-probabilities
        and entropies are summed over the last (action) dimension, giving one value per sample.
    Arguments:
        - data (:obj:`namedtuple`): a2c input data with fields shown in ``a2c_data``
    Returns:
        - a2c_loss (:obj:`namedtuple`): the a2c loss item, all of them are the differentiable 0-dim tensor
    Shapes:
        - logit (:obj:`dict`): with keys ``'mu'`` (mean) and ``'sigma'`` (standard deviation, must be
          positive), each a :obj:`torch.FloatTensor` of shape :math:`(B, N)`, where B is batch size
          and N is action dim
        - action (:obj:`torch.FloatTensor`): :math:`(B, N)`
        - value (:obj:`torch.FloatTensor`): :math:`(B, )`
        - adv (:obj:`torch.FloatTensor`): :math:`(B, )`
        - return_ (:obj:`torch.FloatTensor`): :math:`(B, )`
        - weight (:obj:`torch.FloatTensor` or :obj:`None`): :math:`(B, )`
        - policy_loss (:obj:`torch.FloatTensor`): :math:`()`, 0-dim tensor
        - value_loss (:obj:`torch.FloatTensor`): :math:`()`
        - entropy_loss (:obj:`torch.FloatTensor`): :math:`()`
    Notes:
        - ``entropy_loss`` is the weighted mean entropy itself (not negated); the caller decides
          how to combine it with the other terms.
        - ``adv`` is used as-is and is not detached here.
        - ``weight=None`` weights every sample equally.
    Examples:
        >>> data = a2c_data(
        ...     logit={'mu': torch.randn(2, 3), 'sigma': torch.sqrt(torch.randn(2, 3)**2)},
        ...     action=torch.randn(2, 3),
        ...     value=torch.randn(2, ),
        ...     adv=torch.randn(2, ),
        ...     return_=torch.randn(2, ),
        ...     weight=torch.ones(2, ),
        ... )
        >>> loss = a2c_error_continuous(data)
    """
    logit, action, value, adv, return_, weight = data
    if weight is None:
        # No per-sample weights supplied: weight every sample equally.
        weight = torch.ones_like(value)
    # Independent(..., 1) treats the last dim as the event dim, so log_prob/entropy are (B, ).
    dist = Independent(Normal(logit['mu'], logit['sigma']), 1)
    logp = dist.log_prob(action)
    # Positive (weighted) mean entropy; sign convention is left to the caller.
    entropy_loss = (dist.entropy() * weight).mean()
    # Policy-gradient surrogate: maximizing log pi(a|s) * A means minimizing its negation.
    policy_loss = -(logp * adv * weight).mean()
    # Per-sample squared error so the weights can be applied before averaging.
    value_loss = (F.mse_loss(return_, value, reduction='none') * weight).mean()
    return a2c_loss(policy_loss, value_loss, entropy_loss)