Spaces:
Running
Running
from multiprocessing import Process, Pipe | |
import gym | |
def worker(conn, env):
    """Serve environment commands arriving on ``conn`` until the peer hangs up.

    Used as the target of the processes started by ``ParallelEnv``. Every
    message read from ``conn`` is a ``(cmd, data)`` pair and exactly one
    reply is sent back for each message:

    * ``"step"``: calls ``env.step(data)``. If the episode is done the
      environment is reset and the fresh observation replaces the terminal
      one. Replies with ``(obs, reward, done, info)``.
    * ``"set_curriculum_parameters"``: calls
      ``env.set_curriculum_parameters(data)`` and replies with ``None``.
    * ``"reset"``: replies with ``env.reset()``.
    * ``"get_mission"``: replies with ``env.get_mission()``.

    Args:
        conn: Connection-like object; only ``recv()`` and ``send()`` are used.
        env: Environment driven by the commands above.

    Raises:
        NotImplementedError: If an unknown command is received.
    """
    while True:
        try:
            message = conn.recv()
        except EOFError:
            # Nothing left to read and the other end is closed: stop serving
            # quietly instead of ending the process with a traceback.
            return

        cmd, data = message
        if cmd == "step":
            obs, reward, done, info = env.step(data)
            if done:
                # Hand back the first observation of the next episode so the
                # caller never has to reset a finished environment itself.
                obs = env.reset()
            conn.send((obs, reward, done, info))
        elif cmd == "set_curriculum_parameters":
            env.set_curriculum_parameters(data)
            # Reply even though there is no result, so that each request is
            # matched by exactly one response on the pipe.
            conn.send(None)
        elif cmd == "reset":
            conn.send(env.reset())
        elif cmd == "get_mission":
            conn.send(env.get_mission())
        else:
            raise NotImplementedError
class ParallelEnv(gym.Env):
    """A concurrent execution of environments in multiple processes.

    The first environment is driven directly in the calling process. Every
    other environment is handed to a daemon ``Process`` running ``worker``
    and is controlled through a ``Pipe``. Results are ordered like ``envs``:
    the local environment first, then the workers in creation order.
    """

    def __init__(self, envs):
        """Start one worker process for each environment after the first.

        Args:
            envs: Non-empty sequence of environments. The observation space,
                action space and (if present) ``curriculum`` attribute are
                taken from ``envs[0]``.
        """
        assert len(envs) >= 1, "No environment given."

        self.envs = envs
        self.observation_space = self.envs[0].observation_space
        self.action_space = self.envs[0].action_space
        if hasattr(self.envs[0], "curriculum"):
            self.curriculum = self.envs[0].curriculum

        # Parent-side pipe ends, one per worker, in the order of envs[1:].
        self.locals = []
        for env in self.envs[1:]:
            local, remote = Pipe()
            self.locals.append(local)
            p = Process(target=worker, args=(remote, env))
            p.daemon = True
            p.start()
            # The worker was given the remote end; this process does not use
            # it, so release this process's handle.
            remote.close()

    def _send_all(self, cmd, data):
        """Send the same ``(cmd, data)`` request to every worker."""
        for local in self.locals:
            local.send((cmd, data))

    def _recv_all(self):
        """Return one reply from each worker, in worker order."""
        return [local.recv() for local in self.locals]

    def broadcast_curriculum_parameters(self, data):
        """Apply ``data`` as curriculum parameters to every environment.

        Args:
            data: Passed unchanged to ``set_curriculum_parameters`` of each
                environment.
        """
        self._send_all("set_curriculum_parameters", data)
        self.envs[0].set_curriculum_parameters(data)
        # Each worker answers every request; read those replies so that they
        # are not mistaken for the answer to the next request.
        self._recv_all()

    def get_mission(self):
        """Return a list with ``get_mission()`` of every environment."""
        self._send_all("get_mission", None)
        return [self.envs[0].get_mission()] + self._recv_all()

    def reset(self):
        """Reset every environment and return the list of their results."""
        self._send_all("reset", None)
        return [self.envs[0].reset()] + self._recv_all()

    def step(self, actions):
        """Step every environment with its own action.

        Args:
            actions: Indexable, sliceable sequence with one action per
                environment, ordered like ``envs``. Entries beyond the number
                of environments are ignored.

        Returns:
            A ``zip`` iterator over the per-environment ``(obs, reward, done,
            info)`` tuples, transposed so that it yields all observations,
            then all rewards, all dones and all infos. An environment whose
            episode finished is reset and contributes the new episode's first
            observation.

        Raises:
            ValueError: If fewer actions than environments are given.
        """
        expected = len(self.locals) + 1
        if len(actions) < expected:
            # Without this check some workers would get no command while a
            # reply is still awaited from each of them, blocking forever.
            raise ValueError(
                "Expected at least %d actions, got %d." % (expected, len(actions))
            )

        # Dispatch to the workers first so they run while the local
        # environment is stepped below.
        for local, action in zip(self.locals, actions[1:]):
            local.send(("step", action))

        obs, reward, done, info = self.envs[0].step(actions[0])
        if done:
            obs = self.envs[0].reset()

        results = zip(*[(obs, reward, done, info)] + self._recv_all())
        return results

    def render(self):
        """Rendering is not supported for parallel environments."""
        raise NotImplementedError