# Extraction residue from the hosting page (not part of the module): "Spaces: Sleeping"
import pytest | |
import torch | |
import copy | |
from unittest.mock import patch | |
from ding.framework import OnlineRLContext, task | |
from ding.framework.middleware import TransitionList, inferencer, rolloutor | |
from ding.framework.middleware import StepCollector, EpisodeCollector | |
from ding.framework.middleware.tests import MockPolicy, MockEnv, CONFIG | |
def test_inferencer():
    """Run the ``inferencer`` middleware once and check what it stores on the context.

    The policy and env manager are replaced by ``MockPolicy`` / ``MockEnv``; after one
    call, ``ctx.inference_output`` must be a dict whose entries 0 and 1 hold the
    expected single-element ``action`` tensors.
    """
    ctx = OnlineRLContext()
    with patch("ding.policy.Policy", MockPolicy), patch("ding.envs.BaseEnvManagerV2", MockEnv):
        policy = MockPolicy()
        env = MockEnv()
        # First positional argument is 0 here; build the middleware and apply it to ctx.
        inferencer(0, policy, env)(ctx)

    assert isinstance(ctx.inference_output, dict)
    # Dict equality falls back to ``==`` on the tensors; this relies on them being
    # single-element so the result can be reduced to one bool.
    assert ctx.inference_output[0] == {'action': torch.Tensor([0.])}  # sum of zeros([2, 2])
    assert ctx.inference_output[1] == {'action': torch.Tensor([4.])}  # sum of ones([2, 2])
def test_rolloutor():
    """Alternate ``inferencer`` and ``rolloutor`` ten times and check the step/episode counters.

    A ``TransitionList(2)`` receives the rollout data. After ten inference + rollout
    rounds against the mocks, both ``ctx.env_episode`` and ``ctx.env_step`` must be 20.
    """
    ctx = OnlineRLContext()
    transitions = TransitionList(2)
    with patch("ding.policy.Policy", MockPolicy), patch("ding.envs.BaseEnvManagerV2", MockEnv):
        policy = MockPolicy()
        env = MockEnv()
        for _ in range(10):
            # Both middlewares are rebuilt on every round, exactly as in the original
            # test; do not hoist them, since construction may carry per-instance state.
            inferencer(0, policy, env)(ctx)
            rolloutor(policy, env, transitions)(ctx)

    assert ctx.env_episode == 20  # 10 * env_num
    assert ctx.env_step == 20  # 10 * env_num
def test_step_collector():
    """Check ``StepCollector`` output with and without ``random_collect_size``.

    Both phases share one ``OnlineRLContext`` and one deep copy of ``CONFIG``. Each
    phase builds fresh mocks and a fresh collector inside ``task.start()``, calls the
    collector once, and expects 16 trajectories with end indices ``[7, 15]``.
    """
    # Work on a private copy of the shared CONFIG object.
    cfg = copy.deepcopy(CONFIG)
    ctx = OnlineRLContext()

    # test no random_collect_size
    with patch("ding.policy.Policy", MockPolicy), patch("ding.envs.BaseEnvManagerV2", MockEnv):
        with task.start():
            policy = MockPolicy()
            env = MockEnv()
            collector = StepCollector(cfg, policy, env)
            collector(ctx)
    assert len(ctx.trajectories) == 16
    assert ctx.trajectory_end_idx == [7, 15]

    # test with random_collect_size
    # NOTE(review): ``ctx`` is reused from the first phase, so it already carries that
    # phase's results when this collector runs. Confirm the random-collect path is
    # really exercised and that the asserts below cannot pass on stale values.
    with patch("ding.policy.Policy", MockPolicy), patch("ding.envs.BaseEnvManagerV2", MockEnv):
        with task.start():
            policy = MockPolicy()
            env = MockEnv()
            collector = StepCollector(cfg, policy, env, random_collect_size=8)
            collector(ctx)
    assert len(ctx.trajectories) == 16
    assert ctx.trajectory_end_idx == [7, 15]
def test_episode_collector():
    """Check ``EpisodeCollector`` output with and without ``random_collect_size``.

    Both phases share one ``OnlineRLContext`` and one deep copy of ``CONFIG``. Each
    phase builds fresh mocks and a fresh collector inside ``task.start()``, calls the
    collector once, and expects ``ctx.episodes`` to contain 16 entries.
    """
    # Work on a private copy of the shared CONFIG object.
    cfg = copy.deepcopy(CONFIG)
    ctx = OnlineRLContext()

    # test no random_collect_size
    with patch("ding.policy.Policy", MockPolicy), patch("ding.envs.BaseEnvManagerV2", MockEnv):
        with task.start():
            policy = MockPolicy()
            env = MockEnv()
            collector = EpisodeCollector(cfg, policy, env)
            collector(ctx)
    assert len(ctx.episodes) == 16

    # test with random_collect_size
    # NOTE(review): ``ctx`` is reused from the first phase, so it already carries that
    # phase's results when this collector runs. Confirm the random-collect path is
    # really exercised and that the assert below cannot pass on stale values.
    with patch("ding.policy.Policy", MockPolicy), patch("ding.envs.BaseEnvManagerV2", MockEnv):
        with task.start():
            policy = MockPolicy()
            env = MockEnv()
            collector = EpisodeCollector(cfg, policy, env, random_collect_size=8)
            collector(ctx)
    assert len(ctx.episodes) == 16