Spaces:
Running
Running
File size: 6,729 Bytes
e11e4fe |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 |
import grpc
from typing import Optional
from multiprocessing import Pipe
from sys import platform
import socket
import time
from concurrent.futures import ThreadPoolExecutor
from .communicator import Communicator, PollCallback
from mlagents_envs.communicator_objects.unity_to_external_pb2_grpc import (
UnityToExternalProtoServicer,
add_UnityToExternalProtoServicer_to_server,
)
from mlagents_envs.communicator_objects.unity_message_pb2 import UnityMessageProto
from mlagents_envs.communicator_objects.unity_input_pb2 import UnityInputProto
from mlagents_envs.communicator_objects.unity_output_pb2 import UnityOutputProto
from .exception import UnityTimeOutException, UnityWorkerInUseException
class UnityToExternalServicerImplementation(UnityToExternalProtoServicer):
    """
    Servicer that hands every incoming RPC request over a multiprocessing Pipe.

    Each RPC pushes its request onto ``child_conn`` and then blocks until a reply
    is available on that same end; whoever holds ``parent_conn`` is expected to
    read the request and write the reply.
    """

    def __init__(self):
        # Pipe() yields the two connected endpoints as a pair.
        endpoints = Pipe()
        self.parent_conn = endpoints[0]
        self.child_conn = endpoints[1]

    def _round_trip(self, request):
        """
        Send ``request`` on the child end of the pipe and return the next object
        received on that same end.
        """
        channel = self.child_conn
        channel.send(request)
        reply = channel.recv()
        return reply

    def Initialize(self, request, context):
        """
        Handle the Initialize RPC. ``context`` is not used.
        """
        return self._round_trip(request)

    def Exchange(self, request, context):
        """
        Handle the Exchange RPC. ``context`` is not used.
        """
        return self._round_trip(request)
class RpcCommunicator(Communicator):
    """
    Communicator that hosts a gRPC server on the Python side and exchanges
    messages with the servicer through a multiprocessing pipe.
    """

    def __init__(self, worker_id=0, base_port=5005, timeout_wait=30):
        """
        Python side of the grpc communication. Python is the server and Unity the client

        :param worker_id: Offset from base_port. Used for training multiple environments simultaneously.
        :param base_port: Baseline port number to connect to Unity environment over. worker_id increments over this.
        :param timeout_wait: Timeout (in seconds) to wait for a response before exiting.
        :raises UnityWorkerInUseException: If the port is already in use or the server cannot be started.
        """
        super().__init__(worker_id, base_port)
        self.port = base_port + worker_id
        self.worker_id = worker_id
        self.timeout_wait = timeout_wait
        self.server = None
        self.unity_to_external = None
        self.is_open = False
        self.create_server()

    def create_server(self):
        """
        Creates the GRPC server.

        :raises UnityWorkerInUseException: If the port check fails or any step of
            building/starting the server raises.
        """
        self.check_port(self.port)

        try:
            # Establish communication grpc
            self.server = grpc.server(
                thread_pool=ThreadPoolExecutor(max_workers=10),
                options=(("grpc.so_reuseport", 1),),
            )
            self.unity_to_external = UnityToExternalServicerImplementation()
            add_UnityToExternalProtoServicer_to_server(
                self.unity_to_external, self.server
            )
            # Using unspecified address, which means that grpc is communicating on all IPs
            # This is so that the docker container can connect.
            self.server.add_insecure_port("[::]:" + str(self.port))
            self.server.start()
            self.is_open = True
        except Exception as err:
            # Chain explicitly so the underlying failure stays visible in the traceback.
            raise UnityWorkerInUseException(self.worker_id) from err

    def check_port(self, port):
        """
        Attempts to bind to the requested communicator port, checking if it is already in use.

        :param port: Port number to probe on localhost.
        :raises UnityWorkerInUseException: If binding to the port fails with an OSError.
        """
        # The context manager closes the probe socket on every exit path,
        # including a failure while setting socket options.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
            if platform.startswith("linux"):
                # On linux, the port remains unusable for TIME_WAIT=60 seconds after closing
                # SO_REUSEADDR frees the port right after closing the environment
                probe.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            try:
                probe.bind(("localhost", port))
            except OSError as err:
                raise UnityWorkerInUseException(self.worker_id) from err

    def poll_for_timeout(self, poll_callback: Optional[PollCallback] = None) -> None:
        """
        Polls the GRPC parent connection for data, to be used before calling recv. This prevents
        us from hanging indefinitely in the case where the environment process has died or was not
        launched.

        Additionally, a callback can be passed to periodically check the state of the environment.
        This is used to detect the case when the environment dies without cleaning up the connection,
        so that we can stop sooner and raise a more appropriate error.

        :param poll_callback: Optional callable invoked after each unsuccessful poll; it should
            raise if it detects a problem.
        :raises UnityTimeOutException: If no data arrives within ``timeout_wait`` seconds.
        """
        connection = self.unity_to_external.parent_conn
        deadline = time.monotonic() + self.timeout_wait
        # True division keeps the interval non-zero for timeouts shorter than 10 seconds;
        # floor division would give 0 there and turn this loop into a busy spin.
        poll_interval = self.timeout_wait / 10
        while True:
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                break
            # Never wait past the overall deadline.
            if connection.poll(min(poll_interval, remaining)):
                # Got an acknowledgment from the connection
                return
            if poll_callback:
                # Fire the callback - if it detects something wrong, it should raise an exception.
                poll_callback()

        # Got this far without reading any data from the connection, so it must be dead.
        raise UnityTimeOutException(
            "The Unity environment took too long to respond. Make sure that :\n"
            "\t The environment does not need user interaction to launch\n"
            '\t The Agents\' Behavior Parameters > Behavior Type is set to "Default"\n'
            "\t The environment and the Python interface have compatible versions.\n"
            "\t If you're running on a headless server without graphics support, turn off display "
            "by either passing --no-graphics option or build your Unity executable as server build."
        )

    @staticmethod
    def _wrap_input(inputs: UnityInputProto) -> UnityMessageProto:
        """Build a status-200 UnityMessageProto carrying a copy of ``inputs``."""
        message = UnityMessageProto()
        message.header.status = 200
        message.unity_input.CopyFrom(inputs)
        return message

    def initialize(
        self, inputs: UnityInputProto, poll_callback: Optional[PollCallback] = None
    ) -> UnityOutputProto:
        """
        Perform the initial handshake over the pipe.

        Waits for the first message, keeps its ``unity_output``, replies with
        ``inputs``, then waits for and discards the next message.

        :param inputs: Input proto sent back as the reply to the first message.
        :param poll_callback: Optional callable forwarded to ``poll_for_timeout``.
        :return: The ``unity_output`` field of the first message received.
        :raises UnityTimeOutException: If either expected message fails to arrive in time.
        """
        connection = self.unity_to_external.parent_conn
        self.poll_for_timeout(poll_callback)
        aca_param = connection.recv().unity_output
        connection.send(self._wrap_input(inputs))
        # Guard the second recv as well, so a peer that dies right after the
        # handshake produces a timeout error rather than an indefinite hang.
        self.poll_for_timeout(poll_callback)
        connection.recv()
        return aca_param

    def exchange(
        self, inputs: UnityInputProto, poll_callback: Optional[PollCallback] = None
    ) -> Optional[UnityOutputProto]:
        """
        Send ``inputs`` and wait for the corresponding reply.

        :param inputs: Input proto to send.
        :param poll_callback: Optional callable forwarded to ``poll_for_timeout``.
        :return: The reply's ``unity_output``, or None when the reply's header status is not 200.
        :raises UnityTimeOutException: If no reply arrives within ``timeout_wait`` seconds.
        """
        connection = self.unity_to_external.parent_conn
        connection.send(self._wrap_input(inputs))
        self.poll_for_timeout(poll_callback)
        output = connection.recv()
        if output.header.status != 200:
            return None
        return output.unity_output

    def close(self):
        """
        Sends a shutdown signal to the unity environment, and closes the grpc connection.

        Does nothing if the communicator is not open.
        """
        if not self.is_open:
            return
        try:
            message_input = UnityMessageProto()
            message_input.header.status = 400
            self.unity_to_external.parent_conn.send(message_input)
            self.unity_to_external.parent_conn.close()
        finally:
            # Stop the server and mark the communicator closed even if the
            # shutdown message could not be delivered.
            self.server.stop(False)
            self.is_open = False
|