Spaces:
Sleeping
Sleeping
import multiprocessing as mp | |
import ctypes | |
from time import sleep, time | |
from typing import Any, Dict, List | |
import pytest | |
from ding.framework.supervisor import RecvPayload, SendPayload, Supervisor, ChildType | |
class MockEnv:
    """Minimal stand-in environment driven by the supervisor tests.

    Each instance keeps its own step counter. ``block``, ``block_reset`` and
    ``sleep1`` only sleep, which lets the tests exercise slow or hanging
    children and the supervisor's timeout handling.
    """

    def __init__(self, _) -> None:
        # The constructor argument is accepted and ignored.
        self._counter = 0

    def step(self, _):
        """Increment and return this instance's step counter; the action is ignored."""
        self._counter += 1
        return self._counter

    @property
    def action_space(self):
        """Constant action-space size.

        A property so that reading the attribute (as the tests do through
        ``sv.action_space``) yields ``3`` rather than a bound method.
        """
        return 3

    def block(self):
        """Sleep for 10 seconds so that a supervisor ``recv_all`` times out."""
        sleep(10)

    def block_reset(self):
        """Sleep for 10 seconds; used as a follow-up request to a blocked child."""
        sleep(10)

    def sleep1(self):
        """Sleep for 1 second; used to check that children run concurrently."""
        sleep(1)
@pytest.mark.parametrize("type_", [ChildType.PROCESS, ChildType.THREAD])
def test_supervisor(type_):
    """Round-trip ``step`` calls through the supervisor.

    Checks single ``recv`` calls, ``recv_all`` on a subset of requests (with the
    remaining reply still available to ``recv``), and attribute fan-out via
    ``sv.action_space``.
    """
    sv = Supervisor(type_=type_)
    for _ in range(3):
        sv.register(lambda: MockEnv("AnyArgs"))
    sv.start_link()
    try:
        for env_id in range(len(sv._children)):
            sv.send(SendPayload(proc_id=env_id, method="step", args=["any action"]))
        recv_states: List[RecvPayload] = [sv.recv() for _ in range(3)]
        # One request per child, so each child must answer exactly once.
        assert sorted(payload.proc_id for payload in recv_states) == [0, 1, 2]
        assert all(payload.data == 1 for payload in recv_states)

        # Test recv_all
        send_payloads = []
        for env_id in range(len(sv._children)):
            payload = SendPayload(
                proc_id=env_id,
                method="step",
                args=["any action"],
            )
            send_payloads.append(payload)
            sv.send(payload)
        req_ids = [payload.req_id for payload in send_payloads]
        # Only wait for the last two messages; the first reply stays queued for sv.recv().
        recv_payloads = sv.recv_all(send_payloads[1:])
        assert len(recv_payloads) == 2
        for req_id, payload in zip(req_ids[1:], recv_payloads):
            assert req_id == payload.req_id
        recv_payload = sv.recv()
        assert recv_payload.req_id == req_ids[0]

        # Attribute access on the supervisor yields one value per child.
        assert len(sv.action_space) == 3
        assert all(a == 3 for a in sv.action_space)
    finally:
        # Stop the children even when an assertion above fails.
        sv.shutdown()
def test_supervisor_spawn():
    """Send one ``step`` to each child started via the ``spawn`` context and check every reply."""
    sv = Supervisor(type_=ChildType.PROCESS, mp_ctx=mp.get_context("spawn"))
    for _ in range(3):
        # Instances rather than lambdas. Presumably this is because spawn must pickle
        # what is registered, and lambdas are not picklable; confirm against Supervisor.
        sv.register(MockEnv("AnyArgs"))
    sv.start_link()
    try:
        for env_id in range(len(sv._children)):
            sv.send(SendPayload(proc_id=env_id, method="step", args=["any action"]))
        recv_states: List[RecvPayload] = [sv.recv() for _ in range(3)]
        # One request per child, so each child must answer exactly once.
        assert sorted(payload.proc_id for payload in recv_states) == [0, 1, 2]
        assert all(payload.data == 1 for payload in recv_states)
    finally:
        # Stop the spawned processes even when an assertion above fails.
        sv.shutdown()
class MockCrashEnv(MockEnv):
    """``MockEnv`` variant whose ``step`` raises on the second call after (re)construction."""

    def step(self, _):
        # The parent bumps the counter and hands back its new value.
        current = super().step(_)
        if current != 2:
            return current
        raise Exception("Ohh")
@pytest.mark.parametrize("type_", [ChildType.PROCESS, ChildType.THREAD])
def test_crash_supervisor(type_):
    """Exercise error handling for a child that raises.

    The error is reported through ``err`` when ``ignore_err=True``, the child can
    be restarted, and ``recv(ignore_err=False)`` raises when it fails again.
    """
    sv = Supervisor(type_=type_)
    for _ in range(2):
        sv.register(lambda: MockEnv("AnyArgs"))
    sv.register(lambda: MockCrashEnv("AnyArgs"))
    sv.start_link()
    try:
        # Send 6 messages (2 per child); the 2nd step on the third child (MockCrashEnv) raises.
        for env_id in range(len(sv._children)):
            for _ in range(2):
                sv.send(SendPayload(proc_id=env_id, method="step", args=["any action"]))

        # Find the error message, restarting whichever child reported it.
        recv_states: List[RecvPayload] = []
        for _ in range(6):
            recv_payload = sv.recv(ignore_err=True)
            if recv_payload.err:
                sv._children[recv_payload.proc_id].restart()
            recv_states.append(recv_payload)
        assert any(isinstance(payload.err, Exception) for payload in recv_states)

        # Resume: the healthy children keep counting; the restarted one starts from 1.
        for env_id in range(len(sv._children)):
            sv.send(SendPayload(proc_id=env_id, method="step", args=["any action"]))
        resumed_states: List[RecvPayload] = [sv.recv() for _ in range(3)]
        # 3 + 3 + 1
        assert sum(p.data for p in resumed_states) == 7

        # Child 2 now hits its 2nd step since the restart; without ignore_err, recv must raise.
        with pytest.raises(Exception):
            sv.send(SendPayload(proc_id=2, method="step", args=["any action"]))
            sv.recv(ignore_err=False)
    finally:
        sv.shutdown()
@pytest.mark.parametrize("type_", [ChildType.PROCESS, ChildType.THREAD])
def test_recv_all(type_):
    """``recv_all`` also waits for requests that its callback adds to ``remain_payloads``."""
    sv = Supervisor(type_=type_)
    for _ in range(3):
        sv.register(lambda: MockEnv("AnyArgs"))
    sv.start_link()
    try:
        send_payloads = []
        for env_id in range(len(sv._children)):
            payload = SendPayload(
                proc_id=env_id,
                method="step",
                args=["any action"],
            )
            send_payloads.append(payload)
            sv.send(payload)

        retry_times = {env_id: 0 for env_id in range(len(sv._children))}

        def recv_callback(recv_payload: RecvPayload, remain_payloads: Dict[str, SendPayload]):
            # Re-issue a step to the same child until it has been retried twice.
            if retry_times[recv_payload.proc_id] == 2:
                return
            retry_times[recv_payload.proc_id] += 1
            payload = SendPayload(proc_id=recv_payload.proc_id, method="step", args=["action"])
            sv.send(payload)
            # Registering the request makes recv_all wait for its reply too; the
            # retry_times assertion below depends on that.
            remain_payloads[payload.req_id] = payload

        recv_payloads = sv.recv_all(send_payloads=send_payloads, callback=recv_callback)
        assert len(recv_payloads) == 3
        assert all(v == 2 for v in retry_times.values())
    finally:
        sv.shutdown()
@pytest.mark.parametrize("type_", [ChildType.PROCESS, ChildType.THREAD])
def test_timeout(type_):
    """Check both ways ``recv_all`` reports a timeout.

    Without ``ignore_err`` it raises ``TimeoutError``. With ``ignore_err=True`` the
    timeout is attached to the blocked payload while other payloads succeed.
    """
    sv = Supervisor(type_=type_)
    for _ in range(3):
        sv.register(lambda: MockEnv("AnyArgs"))

    # Test timeout exception: every child blocks.
    sv.start_link()
    try:
        send_payloads = []
        for env_id in range(len(sv._children)):
            payload = SendPayload(proc_id=env_id, method="block")
            send_payloads.append(payload)
            sv.send(payload)
        with pytest.raises(TimeoutError):
            sv.recv_all(send_payloads=send_payloads, timeout=1)
    finally:
        # Bounded shutdown: the children may still be inside their 10 s sleep.
        sv.shutdown(timeout=1)

    # Test timeout with ignore error.
    sv.start_link()
    try:
        send_payloads = []
        # 0 is block
        payload = SendPayload(proc_id=0, method="block")
        send_payloads.append(payload)
        sv.send(payload)
        # 1 is step
        payload = SendPayload(proc_id=1, method="step", args=[""])
        send_payloads.append(payload)
        sv.send(payload)
        payloads = sv.recv_all(send_payloads=send_payloads, timeout=1, ignore_err=True)
        assert isinstance(payloads[0].err, TimeoutError)
        assert payloads[1].err is None
    finally:
        sv.shutdown(timeout=1)
@pytest.mark.parametrize("type_", [ChildType.PROCESS, ChildType.THREAD])
def test_timeout_with_callback(type_):
    """A payload added to ``remain_payloads`` from a timeout callback is itself reported back with an error."""
    sv = Supervisor(type_=type_)
    for _ in range(3):
        sv.register(lambda: MockEnv("AnyArgs"))
    sv.start_link()
    try:
        send_payloads = []
        # 0 is block
        payload = SendPayload(proc_id=0, method="block")
        send_payloads.append(payload)
        sv.send(payload)
        # 1 is step
        payload = SendPayload(proc_id=1, method="step", args=[""])
        send_payloads.append(payload)
        sv.send(payload)

        block_reset_callback = False

        # 1. When "block" fails, add another send payload from inside the callback.
        # 2. Expect that payload to come back through the callback with an error as well.
        def recv_callback(recv_payload: RecvPayload, remain_payloads: Dict[str, SendPayload]):
            nonlocal block_reset_callback
            if recv_payload.method == "block" and recv_payload.err:
                new_send_payload = SendPayload(proc_id=recv_payload.proc_id, method="block_reset")
                # NOTE(review): this payload is only registered, never passed to sv.send(),
                # so presumably it can only finish by timing out. Confirm that this is intended.
                remain_payloads[new_send_payload.req_id] = new_send_payload
                return
            if recv_payload.method == "block_reset" and recv_payload.err:
                block_reset_callback = True

        payloads = sv.recv_all(send_payloads=send_payloads, timeout=1, ignore_err=True, callback=recv_callback)
        assert block_reset_callback
        assert isinstance(payloads[0].err, TimeoutError)
        assert payloads[1].err is None
    finally:
        # Bounded shutdown: child 0 may still be inside its 10 s sleep.
        sv.shutdown(timeout=1)
# Passes on GitLab CI and locally, but consistently fails on GitHub CI.
def test_shared_memory():
    """Replies pass through ``shm_callback``, and slow requests on different children overlap."""
    sv = Supervisor(type_=ChildType.PROCESS)

    def shm_callback(payload: RecvPayload, shm: Any):
        # Record the reply's req_id in the child's shared-memory slot and blank its data.
        shm[payload.proc_id] = payload.req_id
        payload.data = 0

    # One uint8 slot per child; the req_ids used below (0..5) all fit.
    shm = mp.Array(ctypes.c_uint8, 3)
    for _ in range(3):
        sv.register(lambda: MockEnv("AnyArgs"), shm_buffer=shm, shm_callback=shm_callback)
    sv.start_link()
    try:
        # Send init request
        for env_id in range(len(sv._children)):
            sv.send(SendPayload(proc_id=env_id, req_id=env_id, method="sleep1", args=[]))
        start = time()
        for i in range(6):
            payload = sv.recv()
            assert payload.data == 0
            assert shm[payload.proc_id] == payload.req_id
            sv.send(SendPayload(proc_id=payload.proc_id, req_id=i, method="sleep1", args=[]))
        # Non blocking: six 1 s sleeps spread over 3 children only finish in ~2 s if they overlap.
        assert time() - start < 3
    finally:
        sv.shutdown()
@pytest.mark.parametrize("type_", [ChildType.PROCESS, ChildType.THREAD])
def test_supervisor_benchmark(type_):
    """1000 send/recv round trips across 3 children should complete within one second."""
    sv = Supervisor(type_=type_)
    for _ in range(3):
        sv.register(lambda: MockEnv("AnyArgs"))
    sv.start_link()
    try:
        for env_id in range(len(sv._children)):
            sv.send(SendPayload(proc_id=env_id, method="step", args=[""]))
        start = time()
        for _ in range(1000):
            payload = sv.recv()
            sv.send(SendPayload(proc_id=payload.proc_id, method="step", args=[""]))
        assert time() - start < 1
    finally:
        # Always stop the children; three step requests are still in flight at this point.
        sv.shutdown()