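"""Unit tests for the model wrappers in ding.model.wrapper (applied via model_wrap)."""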
from copy import deepcopy

import pytest
import torch
import torch.nn as nn

from ding.torch_utils import get_lstm
from ding.torch_utils.network.gtrxl import GTrXL
from ding.model import model_wrap, register_wrapper, IModelWrapper
from ding.model.wrapper.model_wrappers import BaseModelWrapper
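

# Fixture networks wrapped by the tests below. TempMLP is a minimal MLP whose
# raw tensor output is used in the target-network test.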
class TempMLP(torch.nn.Module):

    def __init__(self):
        super(TempMLP, self).__init__()
        self.fc1 = nn.Linear(3, 4)
        self.bn1 = nn.BatchNorm1d(4)
        self.fc2 = nn.Linear(4, 6)
        self.act = nn.ReLU()

    def forward(self, x):
        x = self.fc1(x)
        x = self.bn1(x)
        x = self.act(x)
        x = self.fc2(x)
        x = self.act(x)
        return x
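

# ActorMLP maps {'obs': ...} to discrete-action logits plus dummy 'action' and
# 'tmp' entries, echoing an optional 'mask' as 'action_mask'.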
class ActorMLP(torch.nn.Module):

    def __init__(self):
        super(ActorMLP, self).__init__()
        self.fc1 = nn.Linear(3, 4)
        self.bn1 = nn.BatchNorm1d(4)
        self.fc2 = nn.Linear(4, 6)
        self.act = nn.ReLU()
        self.out = nn.Softmax(dim=-1)

    def forward(self, inputs, tmp=0):
        x = self.fc1(inputs['obs'])
        x = self.bn1(x)
        x = self.act(x)
        x = self.fc2(x)
        x = self.act(x)
        x = self.out(x)
        ret = {'logit': x, 'tmp': tmp, 'action': x + torch.rand_like(x)}
        if 'mask' in inputs:
            ret['action_mask'] = inputs['mask']
        return ret
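

# HybridActorMLP models a hybrid action space: a discrete softmax head
# ('logit') plus a continuous head ('action_args').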
class HybridActorMLP(torch.nn.Module):

    def __init__(self):
        super(HybridActorMLP, self).__init__()
        self.fc1 = nn.Linear(3, 4)
        self.bn1 = nn.BatchNorm1d(4)
        self.fc2 = nn.Linear(4, 6)
        self.act = nn.ReLU()
        self.out = nn.Softmax(dim=-1)
        self.fc2_cont = nn.Linear(4, 6)
        self.act_cont = nn.ReLU()

    def forward(self, inputs, tmp=0):
        x = self.fc1(inputs['obs'])
        x = self.bn1(x)
        x_ = self.act(x)
        x = self.fc2(x_)
        x = self.act(x)
        x_disc = self.out(x)
        x = self.fc2_cont(x_)
        x_cont = self.act_cont(x)
        ret = {'logit': x_disc, 'action_args': x_cont, 'tmp': tmp}
        if 'mask' in inputs:
            ret['action_mask'] = inputs['mask']
        return ret
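

# HybridReparamActorMLP replaces the continuous head with a reparameterized one
# that outputs {'mu', 'sigma'} for a Gaussian.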
class HybridReparamActorMLP(torch.nn.Module):

    def __init__(self):
        super(HybridReparamActorMLP, self).__init__()
        self.fc1 = nn.Linear(3, 4)
        self.bn1 = nn.BatchNorm1d(4)
        self.fc2 = nn.Linear(4, 6)
        self.act = nn.ReLU()
        self.out = nn.Softmax(dim=-1)
        self.fc2_cont_mu = nn.Linear(4, 6)
        self.act_cont_mu = nn.ReLU()
        self.fc2_cont_sigma = nn.Linear(4, 6)
        self.act_cont_sigma = nn.ReLU()

    def forward(self, inputs, tmp=0):
        x = self.fc1(inputs['obs'])
        x = self.bn1(x)
        x_ = self.act(x)
        x = self.fc2(x_)
        x = self.act(x)
        x_disc = self.out(x)
        x = self.fc2_cont_mu(x_)
        x_cont_mu = self.act_cont_mu(x)
        x = self.fc2_cont_sigma(x_)
        x_cont_sigma = self.act_cont_sigma(x) + 1e-8
        ret = {'logit': {'action_type': x_disc, 'action_args': {'mu': x_cont_mu, 'sigma': x_cont_sigma}}, 'tmp': tmp}
        if 'mask' in inputs:
            ret['action_mask'] = inputs['mask']
        return ret
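

# ReparamActorMLP outputs only the reparameterized continuous logits
# {'mu', 'sigma'}, as consumed by 'reparam_sample'.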
class ReparamActorMLP(torch.nn.Module):

    def __init__(self):
        super(ReparamActorMLP, self).__init__()
        self.fc1 = nn.Linear(3, 4)
        self.bn1 = nn.BatchNorm1d(4)
        self.fc2 = nn.Linear(4, 6)
        self.act = nn.ReLU()
        self.fc2_cont_mu = nn.Linear(4, 6)
        self.fc2_cont_sigma = nn.Linear(4, 6)

    def forward(self, inputs, tmp=0):
        x = self.fc1(inputs['obs'])
        x = self.bn1(x)
        x_ = self.act(x)
        x = self.fc2_cont_mu(x_)
        x_cont_mu = self.act(x)
        x = self.fc2_cont_sigma(x_)
        x_cont_sigma = self.act(x) + 1e-8
        ret = {'logit': {'mu': x_cont_mu, 'sigma': x_cont_sigma}, 'tmp': tmp}
        if 'mask' in inputs:
            ret['action_mask'] = inputs['mask']
        return ret
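

# DeterministicActorMLP outputs only a 'mu' head, as consumed by
# 'deterministic_sample'.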
class DeterministicActorMLP(torch.nn.Module):

    def __init__(self):
        super(DeterministicActorMLP, self).__init__()
        self.fc1 = nn.Linear(3, 4)
        self.bn1 = nn.BatchNorm1d(4)
        self.act = nn.ReLU()
        self.fc2_cont_mu = nn.Linear(4, 6)
        self.act_cont_mu = nn.ReLU()

    def forward(self, inputs):
        x = self.fc1(inputs['obs'])
        x = self.bn1(x)
        x_ = self.act(x)
        x = self.fc2_cont_mu(x_)
        x_cont_mu = self.act_cont_mu(x)
        ret = {
            'logit': {
                'mu': x_cont_mu,
            }
        }
        if 'mask' in inputs:
            ret['action_mask'] = inputs['mask']
        return ret
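

# TempLSTM routes hidden state through the 'prev_state'/'next_state' keys, the
# protocol exercised by the 'hidden_state' wrapper.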
class TempLSTM(torch.nn.Module):

    def __init__(self):
        super(TempLSTM, self).__init__()
        self.model = get_lstm(lstm_type='pytorch', input_size=36, hidden_size=32, num_layers=2, norm_type=None)

    def forward(self, data):
        output, next_state = self.model(data['f'], data['prev_state'], list_next_state=True)
        return {'output': output, 'next_state': next_state}
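

# Simple factory for a plain linear model.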
def setup_model():
    return torch.nn.Linear(3, 6)


class TestModelWrappers:
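
    # 'hidden_state' keeps per-sample LSTM state across forward calls, keyed by
    # data_id; save_prev_state=True also returns the state each forward used.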
    def test_hidden_state_wrapper(self):
        model = TempLSTM()
        state_num = 4
        model = model_wrap(model, wrapper_name='hidden_state', state_num=state_num, save_prev_state=True)
        model.reset()
        data = {'f': torch.randn(2, 4, 36)}
        output = model.forward(data)
        assert output['output'].shape == (2, state_num, 32)
        assert len(output['prev_state']) == 4
        assert output['prev_state'][0]['h'].shape == (2, 1, 32)
        for item in model._state.values():
            assert isinstance(item, dict) and len(item) == 2
            assert all(t.shape == (2, 1, 32) for t in item.values())
        data = {'f': torch.randn(2, 3, 36)}
        data_id = [0, 1, 3]
        output = model.forward(data, data_id=data_id)
        assert output['output'].shape == (2, 3, 32)
        assert all([len(s) == 2 for s in output['prev_state']])
        for item in model._state.values():
            assert isinstance(item, dict) and len(item) == 2
            assert all(t.shape == (2, 1, 32) for t in item.values())
        data = {'f': torch.randn(2, 2, 36)}
        data_id = [0, 1]
        output = model.forward(data, data_id=data_id)
        assert output['output'].shape == (2, 2, 32)
        assert all([isinstance(s, dict) and len(s) == 2 for s in model._state.values()])
        model.reset()
        assert all([s is None for s in model._state.values()])
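
    # 'target' maintains a detached copy of the model. update_type='assign'
    # copies the weights every `freq` update() calls; update_type='momentum'
    # applies an EMA: target = (1 - theta) * target + theta * model.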
    def test_target_network_wrapper(self):
        model = TempMLP()
        target_model = deepcopy(model)
        target_model2 = deepcopy(model)
        target_model = model_wrap(target_model, wrapper_name='target', update_type='assign', update_kwargs={'freq': 2})
        model = model_wrap(model, wrapper_name='base')
        register_wrapper('abstract', IModelWrapper)
        assert all([hasattr(target_model, n) for n in ['reset', 'forward', 'update']])
        assert model.fc1.weight.eq(target_model.fc1.weight).sum() == 12
        model.fc1.weight.data = torch.randn_like(model.fc1.weight)
        assert model.fc1.weight.ne(target_model.fc1.weight).sum() == 12
        target_model.update(model.state_dict(), direct=True)
        assert model.fc1.weight.eq(target_model.fc1.weight).sum() == 12
        model.reset()
        target_model.reset()
        inputs = torch.randn(2, 3)
        model.train()
        target_model.train()
        output = model.forward(inputs)
        with torch.no_grad():
            output_target = target_model.forward(inputs)
        assert output.eq(output_target).sum() == 2 * 6
        model.fc1.weight.data = torch.randn_like(model.fc1.weight)
        assert model.fc1.weight.ne(target_model.fc1.weight).sum() == 12
        target_model.update(model.state_dict())
        assert model.fc1.weight.ne(target_model.fc1.weight).sum() == 12
        target_model.update(model.state_dict())
        assert model.fc1.weight.eq(target_model.fc1.weight).sum() == 12
        # test that reset(target_update_count=...) really resets the update count
        assert target_model._update_count != 0
        target_model.reset()
        assert target_model._update_count != 0
        target_model.reset(target_update_count=0)
        assert target_model._update_count == 0
        target_model2 = model_wrap(
            target_model2, wrapper_name='target', update_type='momentum', update_kwargs={'theta': 0.01}
        )
        target_model2.update(model.state_dict(), direct=True)
        assert model.fc1.weight.eq(target_model2.fc1.weight).sum() == 12
        model.fc1.weight.data = torch.randn_like(model.fc1.weight)
        old_state_dict = target_model2.state_dict()
        target_model2.update(model.state_dict())
        assert target_model2.fc1.weight.data.eq(
            old_state_dict['fc1.weight'] * (1 - 0.01) + model.fc1.weight.data * 0.01
        ).all()
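
    # 'eps_greedy_sample' takes the greedy action with probability 1 - eps and
    # a random action otherwise, respecting 'mask' when it is provided.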
    def test_eps_greedy_wrapper(self):
        model = ActorMLP()
        model = model_wrap(model, wrapper_name='eps_greedy_sample')
        model.eval()
        eps_threshold = 0.5
        data = {'obs': torch.randn(4, 3), 'mask': torch.randint(0, 2, size=(4, 6))}
        with torch.no_grad():
            output = model.forward(data, eps=eps_threshold)
        assert output['tmp'] == 0
        for i in range(10):
            if i == 5:
                data.pop('mask')
            with torch.no_grad():
                output = model.forward(data, eps=eps_threshold, tmp=1)
            assert isinstance(output, dict)
        assert output['tmp'] == 1
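
    # 'multinomial_sample' draws one discrete action index per sample from the
    # categorical distribution defined by the output logits.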
    def test_multinomial_sample_wrapper(self):
        model = model_wrap(ActorMLP(), wrapper_name='multinomial_sample')
        data = {'obs': torch.randn(4, 3)}
        output = model.forward(data)
        assert output['action'].shape == (4, )
        data = {'obs': torch.randn(4, 3), 'mask': torch.randint(0, 2, size=(4, 6))}
        output = model.forward(data)
        assert output['action'].shape == (4, )
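
    # Multinomial variant of epsilon-greedy sampling; also exercises the extra
    # 'alpha' argument accepted by this wrapper.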
    def test_eps_greedy_multinomial_wrapper(self):
        model = ActorMLP()
        model = model_wrap(model, wrapper_name='eps_greedy_multinomial_sample')
        model.eval()
        eps_threshold = 0.5
        data = {'obs': torch.randn(4, 3), 'mask': torch.randint(0, 2, size=(4, 6))}
        with torch.no_grad():
            output = model.forward(data, eps=eps_threshold, alpha=0.2)
        assert output['tmp'] == 0
        for i in range(10):
            if i == 5:
                data.pop('mask')
            with torch.no_grad():
                output = model.forward(data, eps=eps_threshold, tmp=1, alpha=0.2)
            assert isinstance(output, dict)
        assert output['tmp'] == 1
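
    # Hybrid samplers return a dict action: discrete 'action_type' plus
    # continuous 'action_args'.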
    def test_hybrid_eps_greedy_wrapper(self):
        model = HybridActorMLP()
        model = model_wrap(model, wrapper_name='hybrid_eps_greedy_sample')
        model.eval()
        eps_threshold = 0.5
        data = {'obs': torch.randn(4, 3), 'mask': torch.randint(0, 2, size=(4, 6))}
        with torch.no_grad():
            output = model.forward(data, eps=eps_threshold)
        # logit = output['logit']
        # assert output['action']['action_type'].eq(logit.argmax(dim=-1)).all()
        assert isinstance(output['action']['action_args'],
                          torch.Tensor) and output['action']['action_args'].shape == (4, 6)
        for i in range(10):
            if i == 5:
                data.pop('mask')
            with torch.no_grad():
                output = model.forward(data, eps=eps_threshold, tmp=1)
            assert isinstance(output, dict)
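
    # As above, but the discrete branch is sampled multinomially instead of
    # taken greedily.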
    def test_hybrid_eps_greedy_multinomial_wrapper(self):
        model = HybridActorMLP()
        model = model_wrap(model, wrapper_name='hybrid_eps_greedy_multinomial_sample')
        model.eval()
        eps_threshold = 0.5
        data = {'obs': torch.randn(4, 3), 'mask': torch.randint(0, 2, size=(4, 6))}
        with torch.no_grad():
            output = model.forward(data, eps=eps_threshold)
        assert isinstance(output['logit'], torch.Tensor) and output['logit'].shape == (4, 6)
        assert isinstance(output['action']['action_type'],
                          torch.Tensor) and output['action']['action_type'].shape == (4, )
        assert isinstance(output['action']['action_args'],
                          torch.Tensor) and output['action']['action_args'].shape == (4, 6)
        for i in range(10):
            if i == 5:
                data.pop('mask')
            with torch.no_grad():
                output = model.forward(data, eps=eps_threshold, tmp=1)
            assert isinstance(output, dict)
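
    # 'hybrid_reparam_multinomial_sample': discrete branch sampled from the
    # 'action_type' logits, continuous branch drawn from N(mu, sigma).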
    def test_hybrid_reparam_multinomial_wrapper(self):
        model = HybridReparamActorMLP()
        model = model_wrap(model, wrapper_name='hybrid_reparam_multinomial_sample')
        model.eval()
        data = {'obs': torch.randn(4, 3), 'mask': torch.randint(0, 2, size=(4, 6))}
        with torch.no_grad():
            output = model.forward(data)
        assert isinstance(output['logit'], dict) and output['logit']['action_type'].shape == (4, 6)
        assert isinstance(output['logit']['action_args'], dict) and output['logit']['action_args']['mu'].shape == (
            4, 6
        ) and output['logit']['action_args']['sigma'].shape == (4, 6)
        assert isinstance(output['action']['action_type'],
                          torch.Tensor) and output['action']['action_type'].shape == (4, )
        assert isinstance(output['action']['action_args'],
                          torch.Tensor) and output['action']['action_args'].shape == (4, 6)
        for i in range(10):
            if i == 5:
                data.pop('mask')
            with torch.no_grad():
                output = model.forward(data, tmp=1)
            assert isinstance(output, dict)
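
    # 'argmax_sample' picks argmax(logit); masked-out entries are suppressed by
    # subtracting a large constant before the argmax, as mirrored in the check.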
    def test_argmax_sample_wrapper(self):
        model = model_wrap(ActorMLP(), wrapper_name='argmax_sample')
        data = {'obs': torch.randn(4, 3)}
        output = model.forward(data)
        logit = output['logit']
        assert output['action'].eq(logit.argmax(dim=-1)).all()
        data = {'obs': torch.randn(4, 3), 'mask': torch.randint(0, 2, size=(4, 6))}
        output = model.forward(data)
        logit = output['logit'].sub(1e8 * (1 - data['mask']))
        assert output['action'].eq(logit.argmax(dim=-1)).all()
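
    # Hybrid argmax: discrete branch via argmax, continuous 'action_args'
    # passed through unchanged.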
    def test_hybrid_argmax_sample_wrapper(self):
        model = model_wrap(HybridActorMLP(), wrapper_name='hybrid_argmax_sample')
        data = {'obs': torch.randn(4, 3)}
        output = model.forward(data)
        logit = output['logit']
        assert output['action']['action_type'].eq(logit.argmax(dim=-1)).all()
        assert isinstance(output['action']['action_args'],
                          torch.Tensor) and output['action']['action_args'].shape == (4, 6)
        data = {'obs': torch.randn(4, 3), 'mask': torch.randint(0, 2, size=(4, 6))}
        output = model.forward(data)
        logit = output['logit'].sub(1e8 * (1 - data['mask']))
        assert output['action']['action_type'].eq(logit.argmax(dim=-1)).all()
        assert output['action']['action_args'].shape == (4, 6)
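
    # Deterministic hybrid argmax: the continuous branch equals the 'mu' head.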
    def test_hybrid_deterministic_argmax_sample_wrapper(self):
        model = model_wrap(HybridReparamActorMLP(), wrapper_name='hybrid_deterministic_argmax_sample')
        data = {'obs': torch.randn(4, 3)}
        output = model.forward(data)
        assert output['action']['action_type'].eq(output['logit']['action_type'].argmax(dim=-1)).all()
        assert isinstance(output['action']['action_args'],
                          torch.Tensor) and output['action']['action_args'].shape == (4, 6)
        assert output['action']['action_args'].eq(output['logit']['action_args']['mu']).all()
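
    # 'deterministic_sample' returns 'mu' directly as the action.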
    def test_deterministic_sample_wrapper(self):
        model = model_wrap(DeterministicActorMLP(), wrapper_name='deterministic_sample')
        data = {'obs': torch.randn(4, 3)}
        output = model.forward(data)
        assert output['action'].eq(output['logit']['mu']).all()
        assert isinstance(output['action'], torch.Tensor) and output['action'].shape == (4, 6)
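
    # 'reparam_sample' draws the action from N(mu, sigma) produced by the model.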
    def test_reparam_wrapper(self):
        model = ReparamActorMLP()
        model = model_wrap(model, wrapper_name='reparam_sample')
        model.eval()
        data = {'obs': torch.randn(4, 3), 'mask': torch.randint(0, 2, size=(4, 6))}
        with torch.no_grad():
            output = model.forward(data)
        assert isinstance(output['logit'],
                          dict) and output['logit']['mu'].shape == (4, 6) and output['logit']['sigma'].shape == (4, 6)
        for i in range(10):
            if i == 5:
                data.pop('mask')
            with torch.no_grad():
                output = model.forward(data, tmp=1)
            assert isinstance(output, dict)
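
    # eps may also be a dict with one value per sample (e.g. per-environment
    # epsilons, as used by NGU).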
    def test_eps_greedy_wrapper_with_list_eps(self):
        model = ActorMLP()
        model = model_wrap(model, wrapper_name='eps_greedy_sample')
        model.eval()
        eps_threshold = {i: 0.5 for i in range(4)}  # for NGU
        data = {'obs': torch.randn(4, 3), 'mask': torch.randint(0, 2, size=(4, 6))}
        with torch.no_grad():
            output = model.forward(data, eps=eps_threshold)
        assert output['tmp'] == 0
        for i in range(10):
            if i == 5:
                data.pop('mask')
            with torch.no_grad():
                output = model.forward(data, eps=eps_threshold, tmp=1)
            assert isinstance(output, dict)
        assert output['tmp'] == 1
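
    # 'action_noise' adds noise clipped to noise_range, then clamps the result
    # into action_range, which the final assert verifies.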
    def test_action_noise_wrapper(self):
        model = model_wrap(
            ActorMLP(),
            wrapper_name='action_noise',
            noise_type='gauss',
            noise_range={
                'min': -0.1,
                'max': 0.1
            },
            action_range={
                'min': -0.05,
                'max': 0.05
            }
        )
        data = {'obs': torch.randn(4, 3)}
        output = model.forward(data)
        action = output['action']
        assert action.shape == (4, 6)
        assert action.eq(action.clamp(-0.05, 0.05)).all()
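
    # 'transformer_input' accumulates the last seq_len observations into a
    # zero-padded input sequence for GTrXL; obs_memory holds the sequence and
    # memory_idx tracks the fill level per batch index.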
    def test_transformer_input_wrapper(self):
        seq_len, bs, obs_shape = 8, 8, 32
        emb_dim = 64
        model = GTrXL(input_dim=obs_shape, embedding_dim=emb_dim)
        model = model_wrap(model, wrapper_name='transformer_input', seq_len=seq_len)
        obs = []
        for i in range(seq_len + 1):
            obs.append(torch.randn((bs, obs_shape)))
        out = model.forward(obs[0], only_last_logit=False)
        assert out['logit'].shape == (seq_len, bs, emb_dim)
        assert out['input_seq'].shape == (seq_len, bs, obs_shape)
        assert sum(out['input_seq'][1:].flatten()) == 0
        for i in range(1, seq_len - 1):
            out = model.forward(obs[i])
            assert out['logit'].shape == (bs, emb_dim)
            assert out['input_seq'].shape == (seq_len, bs, obs_shape)
        assert sum(out['input_seq'][seq_len - 1:].flatten()) == 0
        assert sum(out['input_seq'][:seq_len - 1].flatten()) != 0
        out = model.forward(obs[seq_len - 1])
        prev_memory = torch.clone(out['input_seq'])
        out = model.forward(obs[seq_len])
        assert torch.all(torch.eq(out['input_seq'][seq_len - 2], prev_memory[seq_len - 1]))
        # test reset of single batch entries in the memory
        model.reset(data_id=[0, 5])  # reset memory batch in position 0 and 5
        assert sum(model.obs_memory[:, 0].flatten()) == 0 and sum(model.obs_memory[:, 5].flatten()) == 0
        assert sum(model.obs_memory[:, 1].flatten()) != 0
        assert model.memory_idx[0] == 0 and model.memory_idx[5] == 0 and model.memory_idx[1] == seq_len
        # test reset
        model.reset()
        assert model.obs_memory is None
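
    # 'transformer_segment': smoke test that forwards one full segment and
    # queries the wrapper's info() method.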
    def test_transformer_segment_wrapper(self):
        seq_len, bs, obs_shape = 12, 8, 32
        layer_num, memory_len, emb_dim = 3, 4, 4
        model = GTrXL(input_dim=obs_shape, embedding_dim=emb_dim, memory_len=memory_len, layer_num=layer_num)
        model = model_wrap(model, wrapper_name='transformer_segment', seq_len=seq_len)
        inputs1 = torch.randn((seq_len, bs, obs_shape))
        out = model.forward(inputs1)
        info = model.info('info')
        info = model.info('x')
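
    # 'transformer_memory' carries the GTrXL layer memory across forward calls;
    # reset(data_id=...) zeroes the memory of individual batch entries only.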
    def test_transformer_memory_wrapper(self):
        seq_len, bs, obs_shape = 12, 8, 32
        layer_num, memory_len, emb_dim = 3, 4, 4
        model = GTrXL(input_dim=obs_shape, embedding_dim=emb_dim, memory_len=memory_len, layer_num=layer_num)
        model1 = model_wrap(model, wrapper_name='transformer_memory', batch_size=bs)
        model2 = model_wrap(model, wrapper_name='transformer_memory', batch_size=bs)
        model1.show_memory_occupancy()
        inputs1 = torch.randn((seq_len, bs, obs_shape))
        out = model1.forward(inputs1)
        new_memory1 = model1.memory
        inputs2 = torch.randn((seq_len, bs, obs_shape))
        out = model2.forward(inputs2)
        new_memory2 = model2.memory
        assert not torch.all(torch.eq(new_memory1, new_memory2))
        model1.reset(data_id=[0, 5])
        assert sum(model1.memory[:, :, 0].flatten()) == 0 and sum(model1.memory[:, :, 5].flatten()) == 0
        assert sum(model1.memory[:, :, 1].flatten()) != 0
        model1.reset()
        assert sum(model1.memory.flatten()) == 0
        seq_len, bs, obs_shape = 8, 8, 32
        layer_num, memory_len, emb_dim = 3, 20, 4
        model = GTrXL(input_dim=obs_shape, embedding_dim=emb_dim, memory_len=memory_len, layer_num=layer_num)
        model = model_wrap(model, wrapper_name='transformer_memory', batch_size=bs)
        inputs1 = torch.randn((seq_len, bs, obs_shape))
        out = model.forward(inputs1)
        new_memory1 = model.memory
        inputs2 = torch.randn((seq_len, bs, obs_shape))
        out = model.forward(inputs2)
        new_memory2 = model.memory
        assert sum(new_memory1[:, -8:].flatten()) != 0
        assert sum(new_memory1[:, :-8].flatten()) == 0
        assert sum(new_memory2[:, -16:].flatten()) != 0
        assert sum(new_memory2[:, :-16].flatten()) == 0
        assert torch.all(torch.eq(new_memory1[:, -8:], new_memory2[:, -16:-8]))
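
    # The combination samplers pick shot_number action indices per sample:
    # greedily here, multinomially in the next test.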
    def test_combination_argmax_sample_wrapper(self):
        model = model_wrap(ActorMLP(), wrapper_name='combination_argmax_sample')
        data = {'obs': torch.randn(4, 3)}
        shot_number = 2
        output = model.forward(shot_number=shot_number, inputs=data)
        assert output['action'].shape == (4, shot_number)
        assert (output['action'] >= 0).all() and (output['action'] < 64).all()

    def test_combination_multinomial_sample_wrapper(self):
        model = model_wrap(ActorMLP(), wrapper_name='combination_multinomial_sample')
        data = {'obs': torch.randn(4, 3)}
        shot_number = 2
        output = model.forward(shot_number=shot_number, inputs=data)
        assert output['action'].shape == (4, shot_number)
        assert (output['action'] >= 0).all() and (output['action'] < 64).all()