Spaces:
Sleeping
Sleeping
#unittest for priority_calculator | |
import unittest | |
import pytest | |
import numpy as np | |
from unittest.mock import Mock, patch | |
from ding.framework import OnlineRLContext, OfflineRLContext | |
from ding.framework import task, Parallel | |
from ding.framework.middleware.functional import priority_calculator | |
class MockPolicy(Mock):
    """Test double for a policy where only ``priority_fun`` has real behavior.

    Because this subclasses :class:`unittest.mock.Mock`, any other attribute a
    caller touches is auto-created as a mock instead of raising.
    """

    def priority_fun(self, data):
        """Return one random priority in ``[0, 1)`` per item of ``data``.

        The values are drawn from NumPy's global random state, so they change
        between calls unless the caller seeds that state first.
        """
        n_items = len(data)
        return np.random.random_sample(size=(n_items,))
def test_priority_calculator():
    """``priority_calculator`` stores the computed priority on each trajectory.

    Builds the middleware around ``MockPolicy.priority_fun``, runs it once on an
    ``OnlineRLContext`` holding ten transitions, and checks that every
    transition received a float ``'priority'`` equal to the value the priority
    function returned for it.
    """
    policy = MockPolicy()
    ctx = OnlineRLContext()
    ctx.trajectories = [
        {
            'obs': np.random.rand(2, 2),
            'next_obs': np.random.rand(2, 2),
            'reward': np.random.rand(1),
            'info': {}
        } for _ in range(10)
    ]

    # The mock returns random priorities, so record exactly what it produced;
    # otherwise there is nothing to compare the stored values against.
    returned = []

    def recording_priority_fn(data):
        priority = policy.priority_fun(data)
        returned.append(priority)
        return priority

    priority_calculator_middleware = priority_calculator(priority_calculation_fn=recording_priority_fn)
    priority_calculator_middleware(ctx)

    assert len(ctx.trajectories) == 10
    # Expect a single call covering the whole batch of trajectories.
    assert len(returned) == 1
    expected = returned[0]
    assert len(expected) == len(ctx.trajectories)
    for traj, value in zip(ctx.trajectories, expected):
        assert isinstance(traj['priority'], float)
        assert traj['priority'] == value