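# NOTE: the three lines originally here ("Spaces:", "Sleeping", "Sleeping") were
# hosting-page status text captured together with the file; they are not part
# of the test module.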
from collections import namedtuple | |
import numpy as np | |
import pytest | |
import torch | |
import treetensor.torch as ttorch | |
from ding.utils.default_helper import lists_to_dicts, dicts_to_lists, squeeze, default_get, override, error_wrapper, \ | |
list_split, LimitedSpaceContainer, set_pkg_seed, deep_merge_dicts, deep_update, flatten_dict, RunningMeanStd, \ | |
one_time_warning, split_data_generator, get_shape0 | |
class TestDefaultHelper():
    """Tests for the general-purpose helpers imported from ``ding.utils.default_helper``."""

    def test_get_shape0(self):
        """``get_shape0`` returns 4 for tree-tensor, list and tuple inputs; a string-list leaf must fail."""
        a = {
            'a': {
                'b': torch.randn(4, 3)
            },
            'c': {
                'd': torch.randn(4)
            },
        }
        # ``b`` and ``c`` are built before ``a`` is rebound below, so they wrap the plain nested dict.
        b = [a, a]
        c = (a, a)
        d = {
            'a': {
                'b': ["a", "b", "c", "d"]
            },
            'c': {
                'd': torch.randn(4)
            },
        }
        a = ttorch.as_tensor(a)
        assert get_shape0(a) == 4
        assert get_shape0(b) == 4
        assert get_shape0(c) == 4
        # Either the call itself raises or the comparison fails; both count as the expected failure.
        with pytest.raises(Exception):
            assert get_shape0(d) == 4

    def test_lists_to_dicts(self):
        """``lists_to_dicts`` rejects empty / non-mapping input and transposes dicts and namedtuples."""
        set_pkg_seed(12)
        with pytest.raises(ValueError):
            lists_to_dicts([])
        with pytest.raises(TypeError):
            lists_to_dicts([1])
        assert lists_to_dicts([{1: 1, 10: 3}, {1: 2, 10: 4}]) == {1: [1, 2], 10: [3, 4]}

        # A list of namedtuples comes back as a single namedtuple of the same class.
        T = namedtuple('T', ['location', 'race'])
        data = [T({'x': 1, 'y': 2}, 'zerg') for _ in range(3)]
        output = lists_to_dicts(data)
        assert isinstance(output, T) and output.__class__ == T
        assert len(output.location) == 3

        # With ``recursive=True`` nested dict values are transposed as well.
        data = [{'value': torch.randn(1), 'obs': {'scalar': torch.randn(4)}} for _ in range(3)]
        output = lists_to_dicts(data, recursive=True)
        assert isinstance(output, dict)
        assert len(output['value']) == 3
        assert len(output['obs']['scalar']) == 3

    def test_dicts_to_lists(self):
        """``dicts_to_lists`` turns a dict of lists into a list of dicts."""
        assert dicts_to_lists({1: [1, 2], 10: [3, 4]}) == [{1: 1, 10: 3}, {1: 2, 10: 4}]

    def test_squeeze(self):
        """``squeeze`` unwraps single-element containers and leaves longer data intact."""
        assert squeeze((4, )) == 4
        assert squeeze({'a': 4}) == 4
        assert squeeze([1, 3]) == (1, 3)
        data = np.random.randn(3)
        output = squeeze(data)
        assert (output == data).all()

    def test_default_get(self):
        """``default_get`` falls back to a default value / factory and validates it with ``judge_fn``."""
        assert default_get({}, 'a', default_value=1, judge_fn=lambda x: x < 2) == 1
        assert default_get({}, 'a', default_fn=lambda: 1, judge_fn=lambda x: x < 2) == 1
        with pytest.raises(AssertionError):
            default_get({}, 'a', default_fn=lambda: 1, judge_fn=lambda x: x < 0)
        assert default_get({'val': 1}, 'val', default_value=2) == 1

    def test_override(self):
        """``override`` accepts a method that exists on the base class and rejects one that does not."""

        class foo(object):

            def fun(self):
                raise NotImplementedError

        class foo1(foo):

            @override(foo)
            def fun(self):
                return "a"

        # ``func`` is not defined on ``foo``, so decorating it is expected to raise
        # NameError while the class body of ``foo2`` is being executed.
        with pytest.raises(NameError):

            class foo2(foo):

                @override(foo)
                def func(self):
                    pass

        with pytest.raises(NotImplementedError):
            foo().fun()
        foo1().fun()

    def test_error_wrapper(self):
        """``error_wrapper`` passes through normal results and substitutes the default on error."""

        def good_ret(a, b=1):
            return a + b

        wrap_good_ret = error_wrapper(good_ret, 0)
        assert good_ret(1) == wrap_good_ret(1)

        def bad_ret(a, b=0):
            # Division by the default ``b=0`` raises ZeroDivisionError.
            return a / b

        wrap_bad_ret = error_wrapper(bad_ret, 0)
        assert wrap_bad_ret(1) == 0
        # NOTE(review): this wrapper is only constructed, never invoked, so the
        # customized-message path is not exercised by this test.
        wrap_bad_ret_with_customized_log = error_wrapper(bad_ret, 0, 'customized_information')

    def test_list_split(self):
        """``list_split`` returns full chunks plus the leftover items (``None`` when nothing is left)."""
        data = [i for i in range(10)]
        output, residual = list_split(data, step=4)
        assert len(output) == 2
        assert output[1] == [4, 5, 6, 7]
        assert residual == [8, 9]
        output, residual = list_split(data, step=5)
        assert len(output) == 2
        assert output[1] == [5, 6, 7, 8, 9]
        assert residual is None
class TestLimitedSpaceContainer():
    """Exercises acquiring, releasing and resizing on ``LimitedSpaceContainer``."""

    def test_container(self):
        """Walk one container through acquire / residual / release / resize and check its counters."""
        box = LimitedSpaceContainer(0, 5)

        # A single acquisition succeeds and bumps the usage counter to one.
        acquired_once = box.acquire_space()
        assert acquired_once
        assert box.cur == 1

        # Asking for the residual space hands back the remaining four slots and fills the container.
        remaining = box.get_residual_space()
        assert remaining == 4
        assert box.cur == box.max_val == 5

        # A full container refuses a further acquisition ...
        denied = box.acquire_space()
        assert not denied

        # ... until its capacity is increased.
        box.increase_space()
        acquired_after_grow = box.acquire_space()
        assert acquired_after_grow

        # Releasing six times counts the usage down from 5 to 0.
        for expected_cur in range(5, -1, -1):
            box.release_space()
            assert box.cur == expected_cur

        # Shrinking brings the capacity back to its initial value.
        box.decrease_space()
        assert box.max_val == 5
class TestDict:
    """Tests for the dict utilities, ``RunningMeanStd`` and ``split_data_generator``."""

    def test_deep_merge_dicts(self):
        """``deep_merge_dicts`` keeps untouched keys and lets the second dict win on conflicts."""
        dict1 = {
            'a': 3,
            'b': {
                'c': 3,
                'd': {
                    'e': 6,
                    'f': 5,
                }
            }
        }
        dict2 = {
            'b': {
                'c': 5,
                'd': 6,
                'g': 4,
            }
        }
        new_dict = deep_merge_dicts(dict1, dict2)
        assert new_dict['a'] == 3
        assert isinstance(new_dict['b'], dict)
        assert new_dict['b']['c'] == 5
        # The nested dict under 'd' is replaced by the scalar from ``dict2``
        # (same expectation as the 'd' check in ``test_deep_update``).
        assert new_dict['b']['d'] == 6
        assert new_dict['b']['g'] == 4

    def test_deep_update(self):
        """``deep_update`` honours ``new_keys_allowed``, ``whitelist`` and ``override_all_if_type_changes``."""
        dict1 = {
            'a': 3,
            'b': {
                'c': 3,
                'd': {
                    'e': 6,
                    'f': 5,
                },
                'z': 4,
            }
        }
        dict2 = {
            'b': {
                'c': 5,
                'd': 6,
                'g': 4,
            }
        }
        # 'g' is a new key under 'b', which must be rejected when new keys are not allowed.
        with pytest.raises(RuntimeError):
            deep_update(dict1, dict2, new_keys_allowed=False)
        # Whitelisting 'b' permits new keys below it.
        new2 = deep_update(dict1, dict2, new_keys_allowed=False, whitelist=['b'])
        assert new2['a'] == 3
        assert new2['b']['c'] == 5
        assert new2['b']['d'] == 6
        assert new2['b']['g'] == 4
        assert new2['b']['z'] == 4

        dict1 = {
            'a': 3,
            'b': {
                'type': 'old',
                'z': 4,
            }
        }
        dict2 = {
            'b': {
                'type': 'new',
                'c': 5,
            }
        }
        # When the 'type' entry changes, the whole sub-dict under 'b' is replaced, so 'z' disappears.
        new3 = deep_update(dict1, dict2, new_keys_allowed=True, whitelist=[], override_all_if_type_changes=['b'])
        assert new3['a'] == 3
        assert new3['b']['type'] == 'new'
        assert new3['b']['c'] == 5
        assert 'z' not in new3['b']

    def test_flatten_dict(self):
        """``flatten_dict`` joins nested keys with '/' into a single-level mapping."""
        # Named ``nested`` rather than ``dict`` so the builtin is not shadowed.
        nested = {
            'a': 3,
            'b': {
                'c': 3,
                'd': {
                    'e': 6,
                    'f': 5,
                },
                'z': 4,
            }
        }
        flat = flatten_dict(nested)
        assert flat['a'] == 3
        assert flat['b/c'] == 3
        assert flat['b/d/e'] == 6
        assert flat['b/d/f'] == 5
        assert flat['b/z'] == 4

    def test_one_time_warning(self):
        """Smoke test: ``one_time_warning`` can be called with a message."""
        one_time_warning('test_one_time_warning')

    def test_running_mean_std(self):
        """``RunningMeanStd`` tracks mean / std across updates, resets, and supports a vector shape."""
        running = RunningMeanStd()
        running.reset()
        running.update(np.arange(1, 10))
        assert running.mean == pytest.approx(5, abs=1e-4)
        assert running.std == pytest.approx(2.582030, abs=1e-6)
        running.update(np.arange(2, 11))
        assert running.mean == pytest.approx(5.5, abs=1e-4)
        assert running.std == pytest.approx(2.629981, abs=1e-6)

        # After a reset the statistics match those of the first update again.
        running.reset()
        running.update(np.arange(1, 10))
        assert pytest.approx(running.mean, abs=1e-4) == 5
        assert running.mean == pytest.approx(5, abs=1e-4)
        assert running.std == pytest.approx(2.582030, abs=1e-6)

        new_shape = running.new_shape((2, 4), (3, ), (1, ))
        assert isinstance(new_shape, tuple) and len(new_shape) == 3

        # With an explicit shape, mean and std are exposed as tensors of that shape.
        running = RunningMeanStd(shape=(4, ))
        running.reset()
        running.update(np.random.random((10, 4)))
        assert isinstance(running.mean, torch.Tensor) and running.mean.shape == (4, )
        assert isinstance(running.std, torch.Tensor) and running.std.shape == (4, )

    def test_split_data_generator(self):
        """``split_data_generator`` yields fixed-size batches and passes ``None`` entries through."""

        def get_data():
            return {
                'obs': torch.randn(5),
                'action': torch.randint(0, 10, size=(1, )),
                'prev_state': [None, None],
                'info': {
                    'other_obs': torch.randn(5)
                },
            }

        # Collate four samples into one dict and stack the tensor fields along a new first axis.
        data = [get_data() for _ in range(4)]
        data = lists_to_dicts(data)
        data['obs'] = torch.stack(data['obs'])
        data['action'] = torch.stack(data['action'])
        data['info'] = {'other_obs': torch.stack([t['other_obs'] for t in data['info']])}
        assert len(data['obs']) == 4
        data['NoneKey'] = None

        # Four samples with split size 3 give two batches, and both batches hold three items.
        generator = split_data_generator(data, 3)
        generator_result = list(generator)
        assert len(generator_result) == 2
        assert generator_result[0]['NoneKey'] is None
        assert len(generator_result[0]['obs']) == 3
        assert generator_result[0]['info']['other_obs'].shape == (3, 5)
        assert generator_result[1]['NoneKey'] is None
        assert len(generator_result[1]['obs']) == 3
        assert generator_result[1]['info']['other_obs'].shape == (3, 5)

        # NOTE(review): the result is created but never iterated here, so the
        # ``shuffle=False`` path is presumably only checked for accepting the argument.
        generator = split_data_generator(data, 3, shuffle=False)