import pytest
from itertools import product
import torch
import torch.nn as nn
import torch.nn.functional as F

from ding.model.template import DecisionTransformer
from ding.torch_utils import is_differentiable
action_space = ['continuous', 'discrete']
state_encoder = [None, nn.Sequential(nn.Flatten(), nn.Linear(8, 8), nn.Tanh())]
args = list(product(*[action_space, state_encoder]))
args.pop(1)
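# After the pop, args holds three (action_space, state_encoder) pairs:
# ('continuous', None), ('discrete', None) and ('discrete', <nn.Sequential encoder>);
# the ('continuous', encoder) combination is excluded from the test matrix.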


@pytest.mark.unittest
@pytest.mark.parametrize('action_space, state_encoder', args)
def test_decision_transformer(action_space, state_encoder):
    B, T = 4, 6
    if state_encoder:
        # image-like observation whose 2 x 2 x 2 = 8 elements match the
        # nn.Flatten() + nn.Linear(8, 8) encoder defined above
        state_dim = (2, 2, 2)
    else:
        # plain vector observation
        state_dim = 3
    act_dim = 2
    DT_model = DecisionTransformer(
        state_dim=state_dim,
        act_dim=act_dim,
        state_encoder=state_encoder,
        n_blocks=3,
        h_dim=8,
        context_len=T,
        n_heads=2,
        drop_p=0.1,
        continuous=(action_space == 'continuous')
    )
    DT_model.configure_optimizers(1.0, 0.0003)
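    # NOTE (assumption): the two positional arguments are weight_decay and
    # learning_rate, in the style of minGPT's configure_optimizers; the
    # returned optimizer is not needed for this forward/backward test.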
    is_continuous = (action_space == 'continuous')
    if state_encoder:
        # with a state encoder the model expects one timestep index per
        # interleaved (rtg, state, action) token: 3 * T - 1 of them, since
        # the last action is not part of the input sequence
        timesteps = torch.randint(0, 100, [B, 3 * T - 1, 1], dtype=torch.long)  # B x (3T - 1) x 1
    else:
        timesteps = torch.randint(0, 100, [B, T], dtype=torch.long)  # B x T
    if isinstance(state_dim, int):
        states = torch.randn([B, T, state_dim])  # B x T x state_dim
    else:
        states = torch.randn([B, T, *state_dim])  # B x T x 2 x 2 x 2
    if action_space == 'continuous':
        actions = torch.randn([B, T, act_dim])  # B x T x act_dim
        action_target = torch.randn([B, T, act_dim])
    else:
        # discrete actions are integer indices in [0, act_dim)
        actions = torch.randint(0, act_dim, [B, T, 1])
        action_target = torch.randint(0, act_dim, [B, T, 1])
    # a linearly decaying return-to-go signal, shared across the batch
    returns_to_go_sample = torch.tensor([1, 0.8, 0.6, 0.4, 0.2, 0.])
    returns_to_go = returns_to_go_sample.repeat([B, 1]).unsqueeze(-1)  # B x T x 1
    # all ones since there is no padding
    traj_mask = torch.ones([B, T], dtype=torch.long)  # B x T
    if is_continuous:
        assert action_target.shape == (B, T, act_dim)
    else:
        assert action_target.shape == (B, T, 1)
        # the model consumes discrete actions as index tensors of shape B x T
        actions = actions.squeeze(-1)
    returns_to_go = returns_to_go.float()
    state_preds, action_preds, return_preds = DT_model.forward(
        timesteps=timesteps, states=states, actions=actions, returns_to_go=returns_to_go
    )
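    # the model predicts a return, state and action for every context position;
    # with a state encoder only action_preds is produced (the asserts below
    # check exactly this)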
    if state_encoder:
        assert state_preds is None
        assert return_preds is None
    else:
        assert state_preds.shape == (B, T, state_dim)
        assert return_preds.shape == (B, T, 1)
    assert action_preds.shape == (B, T, act_dim)
    # only consider non-padded elements
    if state_encoder:
        action_preds = action_preds.reshape(-1, act_dim)
    else:
        action_preds = action_preds.view(-1, act_dim)[traj_mask.view(-1) > 0]
    if is_continuous:
        action_target = action_target.view(-1, act_dim)[traj_mask.view(-1) > 0]
    else:
        action_target = action_target.view(-1)[traj_mask.view(-1) > 0]
    if is_continuous:
        action_loss = F.mse_loss(action_preds, action_target)
    else:
        # cross_entropy expects logits of shape (N, act_dim) and class-index
        # targets of shape (N, ), hence the flattening above
        action_loss = F.cross_entropy(action_preds, action_target)
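    # is_differentiable backpropagates the loss and checks that every parameter
    # of the listed modules receives a gradient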
    if state_encoder:
        is_differentiable(
            action_loss, [DT_model.transformer, DT_model.embed_action, DT_model.embed_rtg, DT_model.state_encoder]
        )
    else:
        is_differentiable(
            action_loss, [
                DT_model.transformer, DT_model.embed_action, DT_model.predict_action, DT_model.embed_rtg,
                DT_model.embed_state
            ]
        )
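

# Minimal sketch for running the parameterizations directly, without pytest
# discovery (assumption: direct invocation is acceptable here; pytest marks
# attach metadata only and do not prevent calling the function as-is).
if __name__ == "__main__":
    for _action_space, _state_encoder in args:
        test_decision_transformer(_action_space, _state_encoder)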