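# NOTE(review): the text "Spaces: Sleeping Sleeping" appeared here; it looks like web-page status text captured with the file, not Python source — confirm and remove.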
import pytest | |
from collections import namedtuple | |
import random | |
import numpy as np | |
import torch | |
from ding.utils.data import timestep_collate, default_collate, default_decollate, diff_shape_collate | |
# B: number of samples collated into one batch; T: number of timesteps in
# each sample's per-key list. Both are read by the tests in this module.
B, T = 4, 3
class TestTimestepCollate:
    """Checks ``timestep_collate`` on batches of ``B`` samples of ``T`` timesteps."""

    def get_data(self):
        """Return one sample: a dict whose values are lists of length ``T``.

        Every timestep's ``prev_state`` is a pair of size-3 tensors and every
        ``action`` is a two-element list of tensors (sizes 3 and 5).
        """
        return {
            'obs': [torch.randn(4) for _ in range(T)],
            'reward': [torch.FloatTensor([0]) for _ in range(T)],
            'done': [False for _ in range(T)],
            'prev_state': [(torch.randn(3), torch.randn(3)) for _ in range(T)],
            'action': [[torch.randn(3), torch.randn(5)] for _ in range(T)],
        }

    def get_multi_shape_state_data(self):
        """Return one sample whose ``prev_state`` holds three differently shaped states.

        Same layout as ``get_data`` except that each timestep's ``prev_state``
        is a list of three tuples with differing lengths and tensor sizes.
        """
        return {
            'obs': [torch.randn(4) for _ in range(T)],
            'reward': [torch.FloatTensor([0]) for _ in range(T)],
            'done': [False for _ in range(T)],
            'prev_state': [
                [(torch.randn(3), torch.randn(5)), (torch.randn(4), ), (torch.randn(5), torch.randn(6))]
                for _ in range(T)
            ],
            'action': [[torch.randn(3), torch.randn(5)] for _ in range(T)],
        }

    def _check_common_fields(self, batch):
        """Assert the collated layout that both sample variants must share.

        The name does not start with ``test`` so pytest does not collect it
        as a test of its own.
        """
        assert isinstance(batch, dict)
        assert set(batch.keys()) == {'obs', 'reward', 'done', 'prev_state', 'action'}
        # Tensor-valued fields are expected with the timestep dimension first.
        assert batch['obs'].shape == (T, B, 4)
        assert batch['reward'].shape == (T, B)
        assert batch['done'].shape == (T, B) and batch['done'].dtype == torch.bool
        # prev_state stays a nested list: T entries, each holding B per-sample states.
        assert isinstance(batch['prev_state'], list)
        assert len(batch['prev_state']) == T and len(batch['prev_state'][0]) == B
        # action stays a list over timesteps; each component is batched over B.
        assert isinstance(batch['action'], list) and len(batch['action']) == T
        assert batch['action'][0][0].shape == (B, 3)
        assert batch['action'][0][1].shape == (B, 5)

    def test(self):
        """Collate ``B`` samples of each variant and check the resulting layout."""
        batch = timestep_collate([self.get_data() for _ in range(B)])
        self._check_common_fields(batch)

        # A hidden state may contain several prev_states with different shapes.
        batch = timestep_collate([self.get_multi_shape_state_data() for _ in range(B)])
        self._check_common_fields(batch)
        # Each per-sample state keeps its three differently shaped entries.
        assert len(batch['prev_state'][0][0]) == 3
class TestDefaultCollate:
    """Checks ``default_collate`` on numpy arrays, Python scalars, strings and containers."""

    def test_numpy(self):
        """Collate numpy arrays and numpy-derived floats; string arrays must be rejected."""
        arrays = [np.random.randn(4, 3).astype(np.float64) for _ in range(5)]
        batch = default_collate(arrays)
        assert batch.shape == (5, 4, 3)
        assert batch.dtype == torch.float64

        # Plain Python floats are expected to come back as float32.
        floats = [float(np.random.randn(1)[0]) for _ in range(6)]
        batch = default_collate(floats)
        assert batch.shape == (6, )
        assert batch.dtype == torch.float32

        with pytest.raises(TypeError):
            default_collate([np.array(['str']) for _ in range(3)])

    def test_basic(self):
        """Collate floats, ints, strings, namedtuples and dicts; bare objects must be rejected."""
        floats = [random.random() for _ in range(3)]
        batch = default_collate(floats)
        assert batch.shape == (3, )
        assert batch.dtype == torch.float32

        ints = [random.randint(0, 10) for _ in range(3)]
        batch = default_collate(ints)
        assert batch.shape == (3, )
        assert batch.dtype == torch.int64

        strings = ['str' for _ in range(4)]
        batch = default_collate(strings)
        assert len(batch) == 4
        assert all(s == 'str' for s in batch)

        # Bound to a name other than ``T`` so the module-level timestep
        # constant is not shadowed inside this test; the typename is unchanged.
        PairT = namedtuple('T', ['x', 'y'])
        pairs = [PairT(1, 2) for _ in range(4)]
        batch = default_collate(pairs)
        assert isinstance(batch, PairT)
        assert batch.x.shape == (4, ) and batch.x.eq(1).sum() == 4
        assert batch.y.shape == (4, ) and batch.y.eq(2).sum() == 4

        with pytest.raises(TypeError):
            default_collate([object() for _ in range(4)])

        ignored = [{'collate_ignore_data': random.random()} for _ in range(4)]
        batch = default_collate(ignored)
        assert isinstance(batch, dict)
        assert len(batch['collate_ignore_data']) == 4
class TestDefaultDecollate:
    """Checks ``default_decollate`` on a tensor, a list of tensors and a dict."""

    def test(self):
        """Split batched inputs into per-sample items; bare objects must be rejected."""
        with pytest.raises(TypeError):
            default_decollate([object() for _ in range(4)])

        # A single tensor is split along its first dimension.
        items = default_decollate(torch.randn(4, 3, 5))
        assert len(items) == 4 and all(d.shape == (3, 5) for d in items)

        # A list of tensors sharing the first dimension yields one entry per sample.
        items = default_decollate([torch.randn(8, 2, 4), torch.randn(8, 5)])
        assert len(items) == 8 and all(d[0].shape == (2, 4) and d[1].shape == (5, ) for d in items)

        batch = {
            'logit': torch.randn(4, 13),
            'action': torch.randint(0, 13, size=(4, )),
            'prev_state': [(torch.zeros(3, 1, 12), torch.zeros(3, 1, 12)) for _ in range(4)],
        }
        items = default_decollate(batch)
        assert len(items) == 4 and isinstance(items, list)
        assert all(d['logit'].shape == (13, ) for d in items)
        # Each per-sample action is expected as a 1-element tensor.
        assert all(d['action'].shape == (1, ) for d in items)
        assert all(len(d['prev_state']) == 2 and d['prev_state'][0].shape == (3, 1, 12) for d in items)
class TestDiffShapeCollate:
    """Checks ``diff_shape_collate`` on samples whose fields may differ in shape or be ``None``."""

    def test(self):
        """Same-shape fields are expected stacked; mismatched or ``None`` fields stay lists."""
        with pytest.raises(TypeError):
            diff_shape_collate([object() for _ in range(4)])

        samples = [
            {
                'item1': torch.randn(4),
                'item2': None,
                'item3': torch.randn(3),
                'item4': np.random.randn(5, 6),
            },
            {
                'item1': torch.randn(5),
                'item2': torch.randn(6),
                'item3': torch.randn(3),
                'item4': np.random.randn(5, 6),
            },
        ]
        batch = diff_shape_collate(samples)
        # item1 differs in length between samples and item2 is None in one of them.
        assert isinstance(batch['item1'], list) and len(batch['item1']) == 2
        assert isinstance(batch['item2'], list) and len(batch['item2']) == 2 and batch['item2'][0] is None
        # item3 and item4 have matching shapes across samples.
        assert batch['item3'].shape == (2, 3)
        assert batch['item4'].shape == (2, 5, 6)

        samples = [
            {
                'item1': 1,
                'item2': 3,
                'item3': 2.0,
            },
            {
                'item1': None,
                'item2': 4,
                'item3': 2.0,
            },
        ]
        batch = diff_shape_collate(samples)
        assert isinstance(batch['item1'], list) and len(batch['item1']) == 2 and batch['item1'][1] is None
        assert batch['item2'].shape == (2, ) and batch['item2'].dtype == torch.int64
        assert batch['item3'].shape == (2, ) and batch['item3'].dtype == torch.float32