# NOTE: "Spaces: Sleeping" -- hosting-page status text captured during extraction; not part of this test module.
import os | |
import random | |
import shutil | |
import numpy as np | |
import pytest | |
import torch | |
from ding.envs.common.common_function import sqrt_one_hot, div_one_hot, div_func, clip_one_hot, \ | |
reorder_one_hot, reorder_one_hot_array, reorder_boolean_vector, \ | |
batch_binary_encode, get_postion_vector, \ | |
affine_transform, save_frames_as_gif | |
# Raw (non-contiguous) values used by the reorder tests; the position of each
# value in this list is the index it is remapped to.
VALUES = [2, 3, 5, 7, 11]
@pytest.fixture(scope="function")
def setup_reorder_array():
    """Build the array form of the value -> index mapping defined by ``VALUES``.

    Registered as a pytest fixture because the tests in this module request it
    by parameter name and use the injected value directly instead of calling
    this function.

    Returns:
        np.ndarray: 1-D integer array where ``ret[v] == i`` for every
        ``(i, v)`` in ``enumerate(VALUES)`` and every other slot is ``-1``.
    """
    # One slot per possible raw value 0..max(VALUES); this is 12 for the
    # current VALUES, matching the previously hard-coded length.
    ret = np.full(max(VALUES) + 1, -1)
    for i, v in enumerate(VALUES):
        ret[v] = i
    return ret
@pytest.fixture(scope="function")
def setup_reorder_dict():
    """Build the dict form of the value -> index mapping defined by ``VALUES``.

    Registered as a pytest fixture because the tests in this module request it
    by parameter name and use the injected value directly instead of calling
    this function.

    Returns:
        dict: ``{value: index}`` for every ``(index, value)`` in
        ``enumerate(VALUES)``.
    """
    return {v: i for i, v in enumerate(VALUES)}
def generate_data():
    """Create one random sample dict with an optional ``'priority'`` entry.

    The sample always has an ``'obs'`` key holding 4 standard-normal values.
    A uniform draw in [0, 1) then picks one of three equally likely cases:
    no ``'priority'`` key, ``'priority'`` set to ``None``, or ``'priority'``
    set to a fresh uniform random float.

    Returns:
        dict: The generated sample.
    """
    ret = {
        'obs': np.random.randn(4),
    }
    p_weight = np.random.uniform()
    if p_weight < 1. / 3:
        # Case 1: leave the 'priority' key out entirely.
        pass
    elif p_weight < 2. / 3:
        # Case 2: key present but explicitly empty.
        ret['priority'] = None
    else:
        # Case 3: key present with a random value.
        ret['priority'] = np.random.uniform()
    return ret
class TestEnvCommonFunc:
    """Tests for the encoding helpers imported from ``ding.envs.common.common_function``."""

    def test_one_hot(self):
        """Check the sqrt/div/clip one-hot encoders and ``div_func`` on a (2, 3) input."""
        a = torch.Tensor([[3, 4, 5], [1, 2, 6]])
        # The input has 2 * 3 = 6 entries; each encoded entry must have exactly one hot bit.
        expected_sums = [1] * 6

        a_sqrt = sqrt_one_hot(a, 6)
        assert a_sqrt.max().item() == 1
        assert [j.sum().item() for i in a_sqrt for j in i] == expected_sums
        sqrt_dim = 3
        assert a_sqrt.shape == (2, 3, sqrt_dim)

        a_div = div_one_hot(a, 6, 2)
        assert a_div.max().item() == 1
        assert [j.sum().item() for i in a_div for j in i] == expected_sums
        div_dim = 4
        assert a_div.shape == (2, 3, div_dim)

        a_di = div_func(a, 2)
        assert a_di.shape == (2, 1, 3)
        # Multiplying back by the divisor must recover the input.
        assert torch.eq(a_di.squeeze() * 2, a).all()

        a_clip = clip_one_hot(a.long(), 4)
        assert a_clip.max().item() == 1
        assert [j.sum().item() for i in a_clip for j in i] == expected_sums
        clip_dim = 4
        assert a_clip.shape == (2, 3, clip_dim)

    def test_reorder(self, setup_reorder_array, setup_reorder_dict):
        """Check that the array-based and dict-based reorder encoders agree.

        ``setup_reorder_array`` and ``setup_reorder_dict`` are received as
        arguments (injected by pytest by parameter name) and passed straight
        to the encoders as the value -> index mapping.
        """
        a = torch.LongTensor([2, 7])  # VALUES = [2, 3, 5, 7, 11]
        a_array = reorder_one_hot_array(a, setup_reorder_array, 5)
        a_dict = reorder_one_hot(a, setup_reorder_dict, 5)
        assert torch.eq(a_array, a_dict).all()
        assert a_array.max().item() == 1
        assert [j.sum().item() for j in a_array] == [1] * 2
        reorder_dim = 5
        assert a_array.shape == (2, reorder_dim)

        a_bool = reorder_boolean_vector(a, setup_reorder_dict, 5)
        # Previously this re-checked a_array (already asserted above); the
        # value under test at this point is a_bool.
        assert a_bool.max().item() == 1
        # The boolean vector must equal the element-wise sum of the one-hot rows.
        assert torch.eq(a_bool, sum(a_array)).all()

    def test_binary(self):
        """Check ``batch_binary_encode`` against a 10-bit, zero-padded reference."""
        a = torch.LongTensor([445, 1023])
        a_binary = batch_binary_encode(a, 10)
        # Reference encoding: most-significant bit first, left-padded with zeros to 10 bits.
        ans = torch.Tensor([[int(bit) for bit in format(int(number), '010b')] for number in a])
        assert torch.eq(a_binary, ans).all()

    def test_position(self):
        """Check the output shape of ``get_postion_vector`` for a 32-element input."""
        a = [random.randint(0, 5000) for _ in range(32)]
        a_position = get_postion_vector(a)
        assert a_position.shape == (64, )

    def test_affine_transform(self):
        """Check ``affine_transform`` in both (min_val, max_val) and (alpha, beta) modes."""
        # Rescale so the extremes of the input are exactly -1 and 1.
        a = torch.rand(4, 3)
        a = (a - a.min()) / (a.max() - a.min())
        a = a * 2 - 1
        ans = affine_transform(a, min_val=-2, max_val=2)
        assert ans.shape == (4, 3)
        assert ans.min() == -2 and ans.max() == 2

        # Same normalisation for a numpy input; the expected extremes are
        # 4 * (-1) + 1 = -3 and 4 * 1 + 1 = 5.
        a = np.random.rand(3, 5)
        a = (a - a.min()) / (a.max() - a.min())
        a = a * 2 - 1
        ans = affine_transform(a, alpha=4, beta=1)
        assert ans.shape == (3, 5)
        assert ans.min() == -3 and ans.max() == 5
def test_save_frames_as_gif():
    """Smoke-test ``save_frames_as_gif`` with 100 random 84x84x3 frames.

    The gif is written under ``./replay_path_gif``; that directory is removed
    afterwards even when saving raises, so a failing run leaves nothing behind.
    """
    frames = [np.random.randint(0, 255, [84, 84, 3]) for _ in range(100)]
    replay_path_gif = './replay_path_gif'
    env_id = 'test'
    save_replay_count = 1
    # exist_ok avoids the separate exists() check and the race it implies.
    os.makedirs(replay_path_gif, exist_ok=True)
    try:
        path = os.path.join(replay_path_gif, '{}_episode_{}.gif'.format(env_id, save_replay_count))
        save_frames_as_gif(frames, path)
    finally:
        # Always clean up the scratch directory, including on failure.
        shutil.rmtree(replay_path_gif)