Spaces:
Sleeping
Sleeping
import grpc | |
from typing import Optional | |
from multiprocessing import Pipe | |
from sys import platform | |
import socket | |
import time | |
from concurrent.futures import ThreadPoolExecutor | |
from .communicator import Communicator, PollCallback | |
from mlagents_envs.communicator_objects.unity_to_external_pb2_grpc import ( | |
UnityToExternalProtoServicer, | |
add_UnityToExternalProtoServicer_to_server, | |
) | |
from mlagents_envs.communicator_objects.unity_message_pb2 import UnityMessageProto | |
from mlagents_envs.communicator_objects.unity_input_pb2 import UnityInputProto | |
from mlagents_envs.communicator_objects.unity_output_pb2 import UnityOutputProto | |
from .exception import UnityTimeOutException, UnityWorkerInUseException | |
class UnityToExternalServicerImplementation(UnityToExternalProtoServicer):
    """
    gRPC servicer that relays every incoming request across a multiprocessing Pipe.

    Each RPC handler pushes the request it received onto ``child_conn`` and then
    blocks until a reply object is available on that same end of the pipe. Whoever
    holds ``parent_conn`` reads the request and writes the reply.
    """

    def __init__(self):
        # Two ends of one duplex pipe: handlers use child_conn, the owner of this
        # servicer uses parent_conn.
        parent_end, child_end = Pipe()
        self.parent_conn = parent_end
        self.child_conn = child_end

    def _relay(self, request):
        """Send ``request`` over the child end and return the next object received on it."""
        conn = self.child_conn
        conn.send(request)
        reply = conn.recv()
        return reply

    def Initialize(self, request, context):
        """Handle the Initialize RPC by relaying ``request`` and returning the reply."""
        return self._relay(request)

    def Exchange(self, request, context):
        """Handle the Exchange RPC by relaying ``request`` and returning the reply."""
        return self._relay(request)
class RpcCommunicator(Communicator):
    """
    Communicator that talks to a Unity environment over gRPC.

    Python hosts the gRPC server; requests arriving at the server are relayed through
    the servicer's pipe and answered from ``initialize`` / ``exchange`` / ``close``.
    """

    def __init__(self, worker_id=0, base_port=5005, timeout_wait=30):
        """
        Python side of the grpc communication. Python is the server and Unity the client

        :int base_port: Baseline port number to connect to Unity environment over. worker_id increments over this.
        :int worker_id: Offset from base_port. Used for training multiple environments simultaneously.
        :int timeout_wait: Timeout (in seconds) to wait for a response before exiting.
        """
        super().__init__(worker_id, base_port)
        self.port = base_port + worker_id
        self.worker_id = worker_id
        self.timeout_wait = timeout_wait
        self.server = None
        self.unity_to_external = None
        self.is_open = False
        self.create_server()

    def create_server(self):
        """
        Creates the GRPC server and starts it listening on ``self.port``.

        Raises UnityWorkerInUseException if the port is already bound or if the
        server cannot be created/started; the underlying error is kept as the cause.
        """
        self.check_port(self.port)

        try:
            # Establish communication grpc
            self.server = grpc.server(
                thread_pool=ThreadPoolExecutor(max_workers=10),
                options=(("grpc.so_reuseport", 1),),
            )
            self.unity_to_external = UnityToExternalServicerImplementation()
            add_UnityToExternalProtoServicer_to_server(
                self.unity_to_external, self.server
            )
            # Using unspecified address, which means that grpc is communicating on all IPs
            # This is so that the docker container can connect.
            self.server.add_insecure_port("[::]:" + str(self.port))
            self.server.start()
            self.is_open = True
        except Exception as err:
            # Chain the original error so the real failure is visible in the traceback.
            raise UnityWorkerInUseException(self.worker_id) from err

    def check_port(self, port):
        """
        Attempts to bind to the requested communicator port, checking if it is already in use.

        Raises UnityWorkerInUseException if the bind fails. The probe socket is always
        closed before returning or raising.
        """
        # The context manager guarantees the probe socket is closed even if
        # setsockopt itself fails.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
            if platform in ("linux", "linux2"):
                # On linux, the port remains unusable for TIME_WAIT=60 seconds after closing
                # SO_REUSEADDR frees the port right after closing the environment
                probe.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            try:
                probe.bind(("localhost", port))
            except OSError as err:
                raise UnityWorkerInUseException(self.worker_id) from err

    def poll_for_timeout(self, poll_callback: Optional[PollCallback] = None) -> None:
        """
        Polls the GRPC parent connection for data, to be used before calling recv. This prevents
        us from hanging indefinitely in the case where the environment process has died or was not
        launched.

        Additionally, a callback can be passed to periodically check the state of the environment.
        This is used to detect the case when the environment dies without cleaning up the connection,
        so that we can stop sooner and raise a more appropriate error.

        Raises UnityTimeOutException if no data arrives within ``self.timeout_wait`` seconds.
        """
        deadline = time.monotonic() + self.timeout_wait
        # Wake up about ten times per timeout window to run the callback. True division
        # is used on purpose: floor division yields 0 for timeouts under 10 seconds, which
        # turns the loop below into a busy-spin of non-blocking polls.
        callback_timeout_wait = self.timeout_wait / 10
        while True:
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                break
            # Never block past the overall deadline.
            if self.unity_to_external.parent_conn.poll(
                min(callback_timeout_wait, remaining)
            ):
                # Got an acknowledgment from the connection
                return
            if poll_callback:
                # Fire the callback - if it detects something wrong, it should raise an exception.
                poll_callback()

        # Got this far without reading any data from the connection, so it must be dead.
        raise UnityTimeOutException(
            "The Unity environment took too long to respond. Make sure that :\n"
            "\t The environment does not need user interaction to launch\n"
            '\t The Agents\' Behavior Parameters > Behavior Type is set to "Default"\n'
            "\t The environment and the Python interface have compatible versions.\n"
            "\t If you're running on a headless server without graphics support, turn off display "
            "by either passing --no-graphics option or build your Unity executable as server build."
        )

    def initialize(
        self, inputs: UnityInputProto, poll_callback: Optional[PollCallback] = None
    ) -> UnityOutputProto:
        """
        Performs the initial handshake over the pipe.

        Waits for the first incoming message, replies with ``inputs`` wrapped in a
        status-200 message, then waits for (and discards) the next incoming message.
        Returns the ``unity_output`` of the first message. Both waits are guarded by
        ``poll_for_timeout``, so UnityTimeOutException is raised instead of blocking forever.
        """
        conn = self.unity_to_external.parent_conn

        self.poll_for_timeout(poll_callback)
        aca_param = conn.recv().unity_output

        message = UnityMessageProto()
        message.header.status = 200
        message.unity_input.CopyFrom(inputs)
        conn.send(message)

        # Guard the follow-up receive as well; an unguarded recv() would hang
        # indefinitely if the environment died right after the first message.
        self.poll_for_timeout(poll_callback)
        conn.recv()
        return aca_param

    def exchange(
        self, inputs: UnityInputProto, poll_callback: Optional[PollCallback] = None
    ) -> Optional[UnityOutputProto]:
        """
        Sends ``inputs`` wrapped in a status-200 message and waits for the reply.

        Returns the reply's ``unity_output``, or None when the reply header status
        is not 200. Raises UnityTimeOutException if no reply arrives in time.
        """
        conn = self.unity_to_external.parent_conn

        message = UnityMessageProto()
        message.header.status = 200
        message.unity_input.CopyFrom(inputs)
        conn.send(message)

        self.poll_for_timeout(poll_callback)
        output = conn.recv()
        if output.header.status != 200:
            return None
        return output.unity_output

    def close(self):
        """
        Sends a shutdown signal to the unity environment, and closes the grpc connection.

        Does nothing if the communicator is not open. The grpc server is stopped even
        if sending the shutdown message or closing the pipe fails; that error still
        propagates to the caller.
        """
        if not self.is_open:
            return
        try:
            message_input = UnityMessageProto()
            message_input.header.status = 400
            self.unity_to_external.parent_conn.send(message_input)
            self.unity_to_external.parent_conn.close()
        finally:
            # Always release the server so a broken pipe cannot leave it running.
            self.server.stop(False)
            self.is_open = False