import pytest
import random
import copy

import torch
import treetensor.torch as ttorch
from unittest.mock import Mock, patch

from ding.data.buffer import DequeBuffer
from ding.framework import OnlineRLContext, task
from ding.framework.middleware import trainer, multistep_trainer, OffPolicyLearner, HERLearner
from ding.framework.middleware.tests import MockHerRewardModel, CONFIG
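
# Unit tests for the training middlewares: trainer, multistep_trainer,
# OffPolicyLearner and HERLearner, exercised with mock policies and mock data.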


class MockPolicy(Mock):
    # MockPolicy class for train mode
    _device = 'cpu'

    def forward(self, train_data, **kwargs):
        res = {
            'total_loss': 0.1,
        }
        return res


class MultiStepMockPolicy(Mock):
    # MockPolicy class for multi-step train mode
    _device = 'cpu'

    def forward(self, train_data, **kwargs):
        res = [
            {
                'total_loss': 0.1,
            },
            {
                'total_loss': 1.0,
            },
        ]
        return res


def get_mock_train_input():
    data = {'obs': torch.rand(2, 2), 'next_obs': torch.rand(2, 2), 'reward': random.random(), 'info': {}}
    return ttorch.as_tensor(data)
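

# trainer should be a no-op while ctx.train_data is None, and should advance
# ctx.train_iter once per call (storing the loss in ctx.train_output) otherwise.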
@pytest.mark.unittest
def test_trainer():
    cfg = copy.deepcopy(CONFIG)
    ctx = OnlineRLContext()

    ctx.train_data = None
    with patch("ding.policy.Policy", MockPolicy):
        policy = MockPolicy()
        for _ in range(10):
            trainer(cfg, policy)(ctx)
    assert ctx.train_iter == 0

    ctx.train_data = get_mock_train_input()
    with patch("ding.policy.Policy", MockPolicy):
        policy = MockPolicy()
        for _ in range(30):
            trainer(cfg, policy)(ctx)
    assert ctx.train_iter == 30
    assert ctx.train_output["total_loss"] == 0.1
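

# multistep_trainer records one output per unrolled step, so 30 calls with the
# 2-step mock policy are expected to advance ctx.train_iter to 60.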
@pytest.mark.unittest
def test_multistep_trainer():
    cfg = copy.deepcopy(CONFIG)
    ctx = OnlineRLContext()

    ctx.train_data = None
    with patch("ding.policy.Policy", MockPolicy):
        policy = MockPolicy()
        for _ in range(10):
            trainer(cfg, policy)(ctx)
    assert ctx.train_iter == 0

    ctx.train_data = get_mock_train_input()
    with patch("ding.policy.Policy", MultiStepMockPolicy):
        policy = MultiStepMockPolicy()
        for _ in range(30):
            multistep_trainer(policy, 10)(ctx)
    assert ctx.train_iter == 60
    assert ctx.train_output[0]["total_loss"] == 0.1
    assert ctx.train_output[1]["total_loss"] == 1.0
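

# OffPolicyLearner samples from the buffer and runs several train iterations
# per call; with this CONFIG the mock run is expected to yield 4 train outputs.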
@pytest.mark.unittest
def test_offpolicy_learner():
    cfg = copy.deepcopy(CONFIG)
    ctx = OnlineRLContext()
    buffer = DequeBuffer(size=10)
    for _ in range(10):
        buffer.push(get_mock_train_input())
    with patch("ding.policy.Policy", MockPolicy):
        with task.start():
            policy = MockPolicy()
            learner = OffPolicyLearner(cfg, policy, buffer)
            learner(ctx)
    assert len(ctx.train_output) == 4
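

# HERLearner additionally relabels sampled episodes with a HER reward model
# before training; the mocked run should likewise produce 4 train outputs.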
@pytest.mark.unittest
def test_her_learner():
    cfg = copy.deepcopy(CONFIG)
    ctx = OnlineRLContext()
    buffer = DequeBuffer(size=10)
    for _ in range(10):
        buffer.push([get_mock_train_input(), get_mock_train_input()])
    with patch("ding.policy.Policy", MockPolicy), patch("ding.reward_model.HerRewardModel", MockHerRewardModel):
        with task.start():
            policy = MockPolicy()
            her_reward_model = MockHerRewardModel()
            learner = HERLearner(cfg, policy, buffer, her_reward_model)
            learner(ctx)
    assert len(ctx.train_output) == 4