Spaces:
Sleeping
Sleeping
import random | |
import time | |
import socket | |
import pytest | |
import multiprocessing as mp | |
from ditk import logging | |
from ding.framework import task | |
from ding.framework.parallel import Parallel | |
from ding.framework.context import OnlineRLContext | |
from ding.framework.middleware.barrier import Barrier | |
# Port prefixes, one per test case (indexed by ``test_id`` in ``launch_barrier``).
# They are kept as strings because ``launch_barrier`` appends the node index as
# a trailing digit before converting the result to an int port number.
PORTS_LIST = [
    "1235",
    "1236",
    "1237",
]
class EnvStepMiddleware:
    """Generator-style middleware that bumps ``ctx.env_step`` after the yield."""

    def __call__(self, ctx):
        """Yield once, then increment ``ctx.env_step`` by one when resumed.

        Args:
            ctx: Any object exposing a numeric ``env_step`` attribute.
        """
        # Nothing happens before the yield; the counter is only touched on resume.
        yield
        ctx.env_step = ctx.env_step + 1
class SleepMiddleware:
    """Middleware that sleeps a pseudo-random number of seconds around the yield.

    The sleep length is derived from the node id and the current env step, so
    each (node, step) pair always sleeps for the same duration.
    """

    def __init__(self, node_id):
        """Remember which node this middleware instance belongs to."""
        self.node_id = node_id

    def random_sleep(self, diection, step):
        """Sleep between 1 and 5 seconds, logging before and after.

        Args:
            diection: Label inserted into the log messages (the callers in this
                class pass ``"forward"`` or ``"backward"``).
            step: Current env step; added to ``node_id`` to seed the RNG.
        """
        # Re-seeding the module-level RNG makes the duration reproducible.
        random.seed(self.node_id + step)
        seconds = random.randint(1, 5)
        logging.info("Node:[{}] env_step:[{}]-{} will sleep:{}s".format(self.node_id, step, diection, seconds))
        # Sleep in one-second slices and report progress after each slice.
        for _ in range(seconds):
            time.sleep(1)
            print("Node:[{}] sleepping...".format(self.node_id))
        logging.info("Node:[{}] env_step:[{}]-{} wake up!".format(self.node_id, step, diection))

    def __call__(self, ctx):
        """Sleep once before yielding and once more after being resumed."""
        self.random_sleep("forward", ctx.env_step)
        yield
        # ctx.env_step is read again here, so a change made while this
        # generator was suspended affects the second sleep duration.
        self.random_sleep("backward", ctx.env_step)
def star_barrier():
    """Run the barrier/sleep pipeline for one node of the star-topology test.

    Node 0 builds its ``Barrier`` with 3 and every other node with 0
    (presumably the number of peers attaching to that node -- confirm against
    ``Barrier``). The same barrier instance is registered both before and
    after the sleep middleware. Any exception raised by ``task.run`` is logged
    and turned into an assertion failure.
    """
    with task.start(ctx=OnlineRLContext()):
        node_id = task.router.node_id
        attach_count = 3 if node_id == 0 else 0
        shared_barrier = Barrier(attach_count)

        # Registration order matters: barrier -> sleep -> barrier -> env step.
        pipeline = (
            shared_barrier,
            SleepMiddleware(node_id),
            shared_barrier,
            EnvStepMiddleware(),
        )
        for middleware in pipeline:
            task.use(middleware, lock=False)

        try:
            task.run(2)
        except Exception as err:
            logging.error(err)
            assert False
def mesh_barrier():
    """Run the barrier/sleep pipeline for one node of the mesh-topology test.

    Each node builds its ``Barrier`` with ``3 - node_id`` (presumably the
    number of higher-numbered peers that attach to it -- confirm against
    ``Barrier``). One barrier instance is registered on both sides of the
    sleep middleware. Any exception raised by ``task.run`` is logged and
    turned into an assertion failure.
    """
    with task.start(ctx=OnlineRLContext()):
        node_id = task.router.node_id
        sync_point = Barrier(3 - node_id)
        sleeper = SleepMiddleware(node_id)
        step_counter = EnvStepMiddleware()

        # Same ordering as the original: barrier, sleep, barrier, env step.
        task.use(sync_point, lock=False)
        task.use(sleeper, lock=False)
        task.use(sync_point, lock=False)
        task.use(step_counter, lock=False)

        try:
            task.run(2)
        except Exception as err:
            logging.error(err)
            assert False
def unmatch_barrier():
    """Run a deliberately mismatched barrier setup for one node.

    Node 2 registers a single ``Barrier`` while every other node registers
    two, each built as ``Barrier(3 - node_id, 5)`` (the ``5`` is presumably a
    timeout in seconds -- confirm against ``Barrier``).

    Expected outcome, enforced with assertions:
      * ``task.run`` raising ``TimeoutError`` is only acceptable on nodes
        other than 2;
      * ``task.run`` returning normally is only acceptable on node 2, which
        then sleeps for 5 seconds before logging completion.
    """
    with task.start(ctx=OnlineRLContext()):
        node_id = task.router.node_id
        attach_count = 3 - node_id

        # Node 2 gets one barrier, everyone else two -> the counts never match.
        barrier_total = 1 if node_id == 2 else 2
        for _ in range(barrier_total):
            task.use(Barrier(attach_count, 5), lock=False)

        try:
            task.run(2)
        except TimeoutError:
            assert node_id != 2
            logging.info("Node:[{}] timeout with barrier".format(node_id))
        else:
            time.sleep(5)
            assert node_id == 2
            logging.info("Node:[{}] finish barrier".format(node_id))
def launch_barrier(args):
    """Start one parallel node and run ``fn`` on it.

    Args:
        args: A 4-item sequence ``(i, topo, fn, test_id)`` where ``i`` is the
            node index, ``topo`` selects how peers are attached (``'star'`` or
            ``'mesh'``; anything else attaches to nobody), ``fn`` is the
            callable handed to the runner, and ``test_id`` indexes
            ``PORTS_LIST``.

    The node listens on the port formed by appending ``i`` to the test's port
    prefix. In a star every non-zero node attaches to node 0; in a mesh node
    ``i`` attaches to all nodes with a smaller index.
    """
    node_index, topo, fn, test_id = args
    host = socket.gethostbyname(socket.gethostname())
    port_prefix = PORTS_LIST[test_id]

    def endpoint(peer_index):
        # The peer index is the last digit of the port number.
        return 'tcp://{}:{}{}'.format(host, port_prefix, peer_index)

    if topo == 'star':
        peers = [] if node_index == 0 else [endpoint(0)]
    elif topo == 'mesh':
        peers = [endpoint(peer) for peer in range(node_index)]
    else:
        peers = []

    runner = Parallel.runner(
        node_ids=node_index,
        ports=int(port_prefix + str(node_index)),
        attach_to=peers,
        # Peers are wired up explicitly via attach_to, so the runner's own
        # topology setting stays at "alone".
        topology="alone",
        protocol="tcp",
        n_parallel_workers=1,
        startup_interval=0
    )
    runner(fn)
def test_star_topology_barrier():
    """Launch four spawned processes running ``star_barrier`` in a star layout."""
    # One job per node; the trailing 0 selects the first port prefix.
    jobs = [[node, 'star', star_barrier, 0] for node in range(4)]
    spawn_ctx = mp.get_context("spawn")
    with spawn_ctx.Pool(processes=4) as workers:
        workers.map(launch_barrier, jobs)
        workers.close()
        workers.join()
def test_mesh_topology_barrier():
    """Launch four spawned processes running ``mesh_barrier`` in a mesh layout."""
    # One job per node; the trailing 1 selects the second port prefix.
    jobs = [[node, 'mesh', mesh_barrier, 1] for node in range(4)]
    spawn_ctx = mp.get_context("spawn")
    with spawn_ctx.Pool(processes=4) as workers:
        workers.map(launch_barrier, jobs)
        workers.close()
        workers.join()
def test_unmatch_barrier():
    """Launch four spawned processes running ``unmatch_barrier`` in a mesh layout."""
    # One job per node; the trailing 2 selects the third port prefix.
    jobs = [[node, 'mesh', unmatch_barrier, 2] for node in range(4)]
    spawn_ctx = mp.get_context("spawn")
    with spawn_ctx.Pool(processes=4) as workers:
        workers.map(launch_barrier, jobs)
        workers.close()
        workers.join()