import pytest
import torch
from ding.framework import OnlineRLContext
from ding.data.buffer import DequeBuffer
from typing import Any
import numpy as np
import copy
from ding.framework.middleware.functional.enhancer import reward_estimator, her_data_enhancer
from unittest.mock import Mock, patch
from ding.framework.middleware.tests import MockHerRewardModel, CONFIG
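
# Shared fixture: a fake batch of 20 transitions with random observations,
# consumed by MockRewardModel.estimate in the reward_estimator test below.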
DATA = [{'obs': torch.rand(2, 2), 'next_obs': torch.rand(2, 2)} for _ in range(20)]
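

# Minimal stand-in for a reward model: estimate() only verifies that it is called
# with the complete, unmodified training batch.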
class MockRewardModel(Mock):

    def estimate(self, data: list) -> Any:
        assert len(data) == len(DATA)
        assert torch.equal(data[0]['obs'], DATA[0]['obs'])
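

# reward_estimator should forward ctx.train_data to the reward model's estimate().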
@pytest.mark.unittest
def test_reward_estimator():
    ctx = OnlineRLContext()
    ctx.train_data = copy.deepcopy(DATA)
    with patch("ding.reward_model.HerRewardModel", MockHerRewardModel):
        reward_estimator(cfg=None, reward_model=MockRewardModel())(ctx)
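

# her_data_enhancer should sample episodes from the buffer, pass them through the
# (mocked) HER reward model, and fill ctx.train_data with the resulting transitions.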
@pytest.mark.unittest
def test_her_data_enhancer():
    cfg = copy.deepcopy(CONFIG)
    ctx = OnlineRLContext()
    with patch("ding.reward_model.HerRewardModel", MockHerRewardModel):
        mock_her_reward_model = MockHerRewardModel()
        buffer = DequeBuffer(mock_her_reward_model.episode_size)
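        # Build `episode_size` fake episodes of varying length (1, 4 or 5 transitions each).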
        train_data = [
            [
                {
                    'action': torch.randint(low=0, high=5, size=(1, )),
                    'collect_train_iter': torch.tensor([0]),
                    'done': torch.tensor(False),
                    'next_obs': torch.randint(low=0, high=2, size=(10, ), dtype=torch.float32),
                    'obs': torch.randint(low=0, high=2, size=(10, ), dtype=torch.float32),
                    'reward': torch.randint(low=0, high=2, size=(1, ), dtype=torch.float32),
                } for _ in range(np.random.choice([1, 4, 5], size=1)[0])
            ] for _ in range(mock_her_reward_model.episode_size)
        ]
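        # Each buffer entry holds one whole episode (a list of transition dicts).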
        for d in train_data:
            buffer.push(d)
        her_data_enhancer(cfg=cfg, buffer_=buffer, her_reward_model=MockHerRewardModel())(ctx)
        # Each enhanced sample is still a dict with the original 6 fields.
        assert len(ctx.train_data) == mock_her_reward_model.episode_size * mock_her_reward_model.episode_element_size
        assert len(ctx.train_data[0]) == 6
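
        # Repeat with a buffer sized by cfg.policy.learn.batch_size and episode_size reset to None.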
        buffer = DequeBuffer(cfg.policy.learn.batch_size)
        for d in train_data:
            buffer.push(d)
        mock_her_reward_model.episode_size = None
        her_data_enhancer(cfg=cfg, buffer_=buffer, her_reward_model=MockHerRewardModel())(ctx)
        assert len(ctx.train_data) == cfg.policy.learn.batch_size * mock_her_reward_model.episode_element_size
        assert len(ctx.train_data[0]) == 6