Spaces:
Sleeping
Sleeping
from collections import namedtuple | |
import torch | |
from ding.hpc_rl import hpc_wrapper | |
# Input record consumed by ``gae``. ``done`` and ``traj_flag`` may be passed as
# ``None``; ``gae`` then substitutes defaults for them.
gae_data = namedtuple(
    'gae_data',
    (
        'value',
        'next_value',
        'reward',
        'done',
        'traj_flag',
    ),
)
def shape_fn_gae(args, kwargs):
    r"""
    Overview:
        Return the shape of the ``reward`` field of the gae input data (used as the gae shape for hpc).
    Arguments:
        - args: positional arguments of the call; when non-empty, ``args[0]`` is taken as the gae data.
        - kwargs: keyword arguments of the call; ``kwargs['data']`` is used when ``args`` is empty.
    Returns:
        - shape: the shape of ``reward``, i.e. [T, B]
    """
    # The gae data is either the first positional argument or the ``data`` keyword.
    gae_input = args[0] if len(args) > 0 else kwargs['data']
    return gae_input.reward.shape
def gae(data: namedtuple, gamma: float = 0.99, lambda_: float = 0.97) -> torch.FloatTensor:
    """
    Overview:
        Implementation of Generalized Advantage Estimator (arXiv:1506.02438)
    Arguments:
        - data (:obj:`namedtuple`): gae input data with fields \
            ['value', 'next_value', 'reward', 'done', 'traj_flag'], which contains some episodes or \
            trajectories data. ``done`` may be None (treated as all zeros) and ``traj_flag`` may be None \
            (falls back to ``done``). The input tensors are not modified.
        - gamma (:obj:`float`): the future discount factor, should be in [0, 1], defaults to 0.99.
        - lambda_ (:obj:`float`): the gae parameter lambda, should be in [0, 1], defaults to 0.97, when \
            lambda -> 0, it induces bias, but when lambda -> 1, it has high variance due to the sum of terms.
    Returns:
        - adv (:obj:`torch.FloatTensor`): the calculated advantage, with the same shape as ``value``
    Shapes:
        - value (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is trajectory length and B is batch size. \
            It may carry one extra trailing dimension, e.g. :math:`(T, B, A)` in some multi-agent cases; \
            reward, done and traj_flag are then unsqueezed on the last dim to broadcast against it.
        - next_value (:obj:`torch.FloatTensor`): :math:`(T, B)`
        - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`
        - done (:obj:`torch.Tensor` or None): expected to match the shape of reward
        - traj_flag (:obj:`torch.Tensor` or None): expected to match the shape of reward
        - adv (:obj:`torch.FloatTensor`): :math:`(T, B)`
    Examples:
        >>> value = torch.randn(2, 3)
        >>> next_value = torch.randn(2, 3)
        >>> reward = torch.randn(2, 3)
        >>> data = gae_data(value, next_value, reward, None, None)
        >>> adv = gae(data)
    """
    value, next_value, reward, done, traj_flag = data
    if done is None:
        done = torch.zeros_like(reward)
    if traj_flag is None:
        traj_flag = done
    done = done.float()
    traj_flag = traj_flag.float()
    if len(value.shape) == len(reward.shape) + 1:  # for some marl case: value(T, B, A), reward(T, B)
        reward = reward.unsqueeze(-1)
        done = done.unsqueeze(-1)
        traj_flag = traj_flag.unsqueeze(-1)

    # Zero the bootstrap value where done is set. This is done out-of-place on purpose:
    # an in-place ``*=`` would overwrite the caller's ``next_value`` tensor and fails on
    # leaf tensors that require grad.
    next_value = next_value * (1 - done)
    # One-step TD residual.
    delta = reward + gamma * next_value - value
    # Discount applied to the accumulated advantage; it becomes 0 where traj_flag is set,
    # so nothing is carried over from the following step.
    factor = gamma * lambda_ * (1 - traj_flag)
    adv = torch.zeros_like(value)
    gae_item = torch.zeros_like(value[0])

    # Backward recursion over time: adv[t] = delta[t] + factor[t] * adv[t + 1].
    for t in reversed(range(reward.shape[0])):
        gae_item = delta[t] + factor[t] * gae_item
        adv[t] = gae_item
    return adv