|
#if UNITY_EDITOR || UNITY_STANDALONE |
|
#define MLA_SUPPORTED_TRAINING_PLATFORM |
|
#endif |
|
|
|
#if MLA_SUPPORTED_TRAINING_PLATFORM |
|
using Grpc.Core; |
|
#if UNITY_EDITOR |
|
using UnityEditor; |
|
#endif |
|
using System; |
|
using System.Collections.Generic; |
|
using System.Linq; |
|
using UnityEngine; |
|
using Unity.MLAgents.Actuators; |
|
using Unity.MLAgents.CommunicatorObjects; |
|
using Unity.MLAgents.Sensors; |
|
using Unity.MLAgents.SideChannels; |
|
using Google.Protobuf; |
|
|
|
using Unity.MLAgents.Analytics; |
|
|
|
namespace Unity.MLAgents |
|
{ |
|
|
|
public class RpcCommunicator : ICommunicator |
|
{ |
|
/// <summary>Raised by NotifyQuitAndShutDownChannel before it attempts to shut the channel down.</summary>
public event QuitCommandHandler QuitCommandReceived;

/// <summary>Raised by SendCommandEvent when a Reset command is handled.</summary>
public event ResetCommandHandler ResetCommandReceived;

/// <summary>
/// Set to true when the port-level Initialize starts; set back to false when an exchange
/// fails or returns a non-200 status, or when Dispose succeeds. Exchange and Dispose are
/// no-ops while it is false.
/// </summary>
bool m_IsOpen;

/// <summary>Behavior names that have been registered through SubscribeBrain.</summary>
List<string> m_BehaviorNames = new List<string>();

/// <summary>Set by PutObservations and consumed (reset) by DecideBatch.</summary>
bool m_NeedCommunicateThisStep;

/// <summary>Writer handed to every sensor when its observation proto is generated.</summary>
ObservationWriter m_ObservationWriter = new ObservationWriter();

/// <summary>One validator per behavior name; only used by PutObservations in DEBUG builds.</summary>
Dictionary<string, SensorShapeValidator> m_SensorShapeValidators = new Dictionary<string, SensorShapeValidator>();

/// <summary>
/// Per behavior name, the episode ids of the not-done agents in the order PutObservations
/// received them. Cleared after each batch and on Reset.
/// </summary>
Dictionary<string, List<int>> m_OrderedAgentsRequestingDecisions = new Dictionary<string, List<int>>();

/// <summary>Output proto that accumulates agent infos until the next exchange.</summary>
UnityRLOutputProto m_CurrentUnityRlOutput = new UnityRLOutputProto();

/// <summary>Behavior name -> episode id -> last action stored for that agent.</summary>
Dictionary<string, Dictionary<int, ActionBuffers>> m_LastActionsReceived = new Dictionary<string, Dictionary<int, ActionBuffers>>();

/// <summary>Behavior names recorded by UpdateSentActionSpec after an exchange carried their brain parameters.</summary>
HashSet<string> m_SentBrainKeys = new HashSet<string>();

/// <summary>Action specs cached by CacheActionSpec that still have to be sent.</summary>
Dictionary<string, ActionSpec> m_UnsentBrainKeys = new Dictionary<string, ActionSpec>();

/// <summary>gRPC client used for every exchange; created by the port-level Initialize.</summary>
UnityToExternalProto.UnityToExternalProtoClient m_Client;

/// <summary>gRPC channel backing m_Client; shut down by NotifyQuitAndShutDownChannel.</summary>
Channel m_Channel;
|
|
|
|
|
|
|
|
|
/// <summary>
/// Parameterless constructor that performs no work. It is protected, so outside code
/// obtains instances through <see cref="Create"/> (or through a subclass).
/// </summary>
protected RpcCommunicator() { }
|
|
|
/// <summary>
/// Factory for <see cref="RpcCommunicator"/>.
/// </summary>
/// <returns>
/// A new communicator when MLA_SUPPORTED_TRAINING_PLATFORM is defined; null otherwise.
/// </returns>
public static RpcCommunicator Create()
{
#if MLA_SUPPORTED_TRAINING_PLATFORM
    var communicator = new RpcCommunicator();
    return communicator;
#else
    return null;
#endif
}
|
|
|
#region Initialization |
|
|
|
/// <summary>
/// Decides whether the Unity and Python communication protocol versions can talk to each other.
/// The major versions must always match; while the Unity major version is 0 the minor
/// versions must match as well. A minor mismatch with a non-zero major version is accepted.
/// </summary>
/// <param name="unityCommunicationVersion">Version string of the Unity side.</param>
/// <param name="pythonApiVersion">Version string reported by the Python side.</param>
/// <returns>true when the versions are compatible, false otherwise.</returns>
/// <remarks>Both strings are parsed with <see cref="Version"/>, which throws on invalid input.</remarks>
internal static bool CheckCommunicationVersionsAreCompatible(
    string unityCommunicationVersion,
    string pythonApiVersion
)
{
    var unity = new Version(unityCommunicationVersion);
    var python = new Version(pythonApiVersion);

    // A different major version is never compatible.
    if (unity.Major != python.Major)
    {
        return false;
    }

    // For 0.x releases the minor version has to agree too.
    var isPreRelease = unity.Major == 0;
    if (isPreRelease && unity.Minor != python.Minor)
    {
        return false;
    }

    return true;
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/// <summary>
/// Connects to the external process on <c>initParameters.port</c>, sends the Unity-side
/// initialization parameters and reads back the initialization reply and the first input.
/// </summary>
/// <param name="initParameters">Name, versions, capabilities and port used for the handshake.</param>
/// <param name="initParametersOut">
/// Parameters converted from the reply on success; a default-constructed value on failure.
/// </param>
/// <returns>
/// true when the handshake produced both an initialization reply and a first input;
/// false when the connection attempt threw, a reply was missing, or the platform is unsupported.
/// </returns>
public bool Initialize(CommunicatorInitParameters initParameters, out UnityRLInitParameters initParametersOut)
{
#if MLA_SUPPORTED_TRAINING_PLATFORM
    var academyParameters = new UnityRLInitializationOutputProto
    {
        Name = initParameters.name,
        PackageVersion = initParameters.unityPackageVersion,
        CommunicationVersion = initParameters.unityCommunicationVersion,
        Capabilities = initParameters.CSharpCapabilities.ToProto()
    };

    UnityInputProto input;
    UnityInputProto initializationInput;
    try
    {
        initializationInput = Initialize(
            initParameters.port,
            new UnityOutputProto
            {
                RlInitializationOutput = academyParameters
            },
            out input
        );
    }
    catch (Exception ex)
    {
        if (ex is RpcException rpcException)
        {
            switch (rpcException.Status.StatusCode)
            {
                case StatusCode.Unavailable:
                    // Not logged.
                    break;
                case StatusCode.DeadlineExceeded:
                    // Not logged.
                    break;
                default:
                    Debug.Log($"Unexpected gRPC exception when trying to initialize communication: {rpcException}");
                    break;
            }
        }
        else
        {
            Debug.Log($"Unexpected exception when trying to initialize communication: {ex}");
        }
        initParametersOut = new UnityRLInitParameters();
        NotifyQuitAndShutDownChannel();
        return false;
    }

    // The reply may carry no initialization payload at all. Check before reading the
    // version fields instead of failing with a NullReferenceException.
    var rlInitializationInput = initializationInput?.RlInitializationInput;
    if (rlInitializationInput == null)
    {
        Debug.LogWarning(
            "The reply to the initialization message did not contain initialization parameters; " +
            "communication could not be initialized."
        );
        initParametersOut = new UnityRLInitParameters();
        return false;
    }

    var pythonPackageVersion = rlInitializationInput.PackageVersion;
    var pythonCommunicationVersion = rlInitializationInput.CommunicationVersion;
    TrainingAnalytics.SetTrainerInformation(pythonPackageVersion, pythonCommunicationVersion);

    // The compatibility flag only selects which warning is printed below, so a version
    // string that System.Version cannot parse is treated as "incompatible" rather than
    // letting the parse exception escape from Initialize.
    bool communicationIsCompatible;
    try
    {
        communicationIsCompatible = CheckCommunicationVersionsAreCompatible(
            initParameters.unityCommunicationVersion,
            pythonCommunicationVersion
        );
    }
    catch (Exception versionError) when (
        versionError is ArgumentException ||
        versionError is FormatException ||
        versionError is OverflowException)
    {
        communicationIsCompatible = false;
    }

    // An initialization reply without a following input means the handshake only got part-way.
    if (input == null)
    {
        if (!communicationIsCompatible)
        {
            Debug.LogWarningFormat(
                "Communication protocol between python ({0}) and Unity ({1}) have different " +
                "versions which make them incompatible. Python library version: {2}.",
                pythonCommunicationVersion, initParameters.unityCommunicationVersion,
                pythonPackageVersion
            );
        }
        else
        {
            Debug.LogWarningFormat(
                "Unknown communication error between Python. Python communication protocol: {0}, " +
                "Python library version: {1}.",
                pythonCommunicationVersion,
                pythonPackageVersion
            );
        }

        initParametersOut = new UnityRLInitParameters();
        return false;
    }

    UpdateEnvironmentWithInput(input.RlInput);
    initParametersOut = rlInitializationInput.ToUnityRLInitParameters();

    // Make sure the channel is shut down when the application quits.
    Application.quitting += NotifyQuitAndShutDownChannel;
    return true;
#else
    initParametersOut = new UnityRLInitParameters();
    return false;
#endif
}
|
|
|
|
|
|
|
|
|
|
|
|
|
/// <summary>
/// Registers a behavior name with the communicator. The first call for a given key adds an
/// empty agent-info list for it to the pending output and caches its action spec; later
/// calls with the same key do nothing.
/// </summary>
/// <param name="brainKey">Behavior name to register.</param>
/// <param name="actionSpec">Action spec cached for the behavior.</param>
public void SubscribeBrain(string brainKey, ActionSpec actionSpec)
{
    var alreadySubscribed = m_BehaviorNames.Contains(brainKey);
    if (!alreadySubscribed)
    {
        m_BehaviorNames.Add(brainKey);

        var emptyAgentInfos = new UnityRLOutputProto.Types.ListAgentInfoProto();
        m_CurrentUnityRlOutput.AgentInfos.Add(brainKey, emptyAgentInfos);

        CacheActionSpec(brainKey, actionSpec);
    }
}
|
|
|
/// <summary>
/// Applies an input received from the external process: first forwards its side-channel
/// bytes to <c>SideChannelManager</c>, then dispatches its command.
/// </summary>
/// <param name="rlInput">Input whose side-channel data and command are processed.</param>
void UpdateEnvironmentWithInput(UnityRLInputProto rlInput)
{
    var sideChannelBytes = rlInput.SideChannel.ToArray();
    SideChannelManager.ProcessSideChannelData(sideChannelBytes);

    // Read the command only after the side-channel data has been processed.
    var command = rlInput.Command;
    SendCommandEvent(command);
}
|
|
|
/// <summary>
/// Creates the insecure gRPC channel to <c>localhost:port</c> and performs the two initial
/// exchanges: one carrying <paramref name="unityOutput"/> and one with an empty payload.
/// </summary>
/// <param name="port">Local port the channel connects to.</param>
/// <param name="unityOutput">Payload of the first exchange.</param>
/// <param name="unityInput">Input contained in the reply to the second exchange.</param>
/// <returns>Input contained in the reply to the first exchange.</returns>
/// <remarks>
/// Exceptions from channel creation or the exchanges propagate to the caller. The
/// communicator is marked closed before they do, so a failed connection attempt does not
/// leave <c>m_IsOpen</c> set.
/// </remarks>
UnityInputProto Initialize(int port, UnityOutputProto unityOutput, out UnityInputProto unityInput)
{
    m_IsOpen = true;
    try
    {
        m_Channel = new Channel($"localhost:{port}", ChannelCredentials.Insecure);

        m_Client = new UnityToExternalProto.UnityToExternalProtoClient(m_Channel);
        var result = m_Client.Exchange(WrapMessage(unityOutput, 200));
        var inputMessage = m_Client.Exchange(WrapMessage(null, 200));
        unityInput = inputMessage.UnityInput;
#if UNITY_EDITOR
        EditorApplication.playModeStateChanged += HandleOnPlayModeChanged;
#endif
        // Any status other than 200 on either reply closes the communicator.
        if (result.Header.Status != 200 || inputMessage.Header.Status != 200)
        {
            m_IsOpen = false;
            NotifyQuitAndShutDownChannel();
        }
        return result.UnityInput;
    }
    catch
    {
        // Nothing was established; do not report the communicator as open.
        m_IsOpen = false;
        throw;
    }
}
|
|
|
/// <summary>
/// Raises <see cref="QuitCommandReceived"/> and then blocks until the gRPC channel has shut
/// down. Any exception thrown by the shutdown attempt is swallowed.
/// </summary>
void NotifyQuitAndShutDownChannel()
{
    var quitHandlers = QuitCommandReceived;
    if (quitHandlers != null)
    {
        quitHandlers.Invoke();
    }

    try
    {
        var shutdownTask = m_Channel.ShutdownAsync();
        shutdownTask.Wait();
    }
    catch (Exception)
    {
        // Shutdown failures are ignored.
    }
}
|
|
|
#endregion |
|
|
|
#region Destruction |
|
|
|
|
|
|
|
|
|
/// <summary>
/// Closes the communicator. While it is open, a message with status 400 and no payload is
/// sent to the external process (presumably the "closing" notification - confirm against
/// the trainer protocol). Exceptions from that exchange are swallowed.
/// </summary>
/// <remarks>
/// The communicator is marked closed even when the exchange fails, so later calls to
/// Dispose or Exchange do not keep retrying. In the editor, the play-mode handler
/// registered during initialization is removed so the static editor event no longer
/// references this instance.
/// </remarks>
public void Dispose()
{
#if UNITY_EDITOR
    // Removing a handler that was never added is a no-op.
    EditorApplication.playModeStateChanged -= HandleOnPlayModeChanged;
#endif

    if (!m_IsOpen)
    {
        return;
    }

    try
    {
        m_Client.Exchange(WrapMessage(null, 400));
    }
    catch
    {
        // Best effort: a failure to deliver the final message is ignored.
    }
    finally
    {
        m_IsOpen = false;
    }
}
|
|
|
#endregion |
|
|
|
#region Sending Events |
|
|
|
/// <summary>
/// Reacts to a command received from the external process. Quit notifies listeners and
/// shuts the channel down. Reset clears every pending decision-request list and then
/// raises <see cref="ResetCommandReceived"/>. Any other command is ignored.
/// </summary>
/// <param name="command">Command to dispatch.</param>
void SendCommandEvent(CommandProto command)
{
    if (command == CommandProto.Quit)
    {
        NotifyQuitAndShutDownChannel();
    }
    else if (command == CommandProto.Reset)
    {
        // Forget which agents were waiting for a decision before listeners are told to reset.
        foreach (var pendingAgentIds in m_OrderedAgentsRequestingDecisions.Values)
        {
            pendingAgentIds.Clear();
        }

        var resetHandlers = ResetCommandReceived;
        if (resetHandlers != null)
        {
            resetHandlers.Invoke();
        }
    }
}
|
|
|
#endregion |
|
|
|
#region Sending and retrieving data
|
|
|
/// <summary>
/// Sends the batched agent infos if any observation was put since the previous batch,
/// and does nothing otherwise. The "needs communication" flag is reset before sending.
/// </summary>
public void DecideBatch()
{
    if (m_NeedCommunicateThisStep)
    {
        m_NeedCommunicateThisStep = false;
        SendBatchedMessageHelper();
    }
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/// <summary>
/// Queues one agent's info and sensor observations for the next batch.
/// </summary>
/// <param name="behaviorName">Behavior the agent belongs to; its agent-info list must already exist in the pending output.</param>
/// <param name="info">Agent info converted to a proto and appended to the pending output.</param>
/// <param name="sensors">Sensors whose observation protos are attached to the agent info.</param>
/// <remarks>
/// Agents that are not done are recorded as waiting for a decision and get an empty action
/// stored as their last action. Agents that are done have their stored action removed.
/// </remarks>
public void PutObservations(string behaviorName, AgentInfo info, List<ISensor> sensors)
{
#if DEBUG
    if (!m_SensorShapeValidators.TryGetValue(behaviorName, out var shapeValidator))
    {
        shapeValidator = new SensorShapeValidator();
        m_SensorShapeValidators[behaviorName] = shapeValidator;
    }
    shapeValidator.ValidateSensors(sensors);
#endif

    using (TimerStack.Instance.Scoped("AgentInfo.ToProto"))
    {
        var infoProto = info.ToAgentInfoProto();

        using (TimerStack.Instance.Scoped("GenerateSensorData"))
        {
            foreach (var currentSensor in sensors)
            {
                infoProto.Observations.Add(currentSensor.GetObservationProto(m_ObservationWriter));
            }
        }

        var agentInfosForBehavior = m_CurrentUnityRlOutput.AgentInfos[behaviorName];
        agentInfosForBehavior.Value.Add(infoProto);
    }

    m_NeedCommunicateThisStep = true;

    // Make sure both per-behavior containers exist before touching them.
    if (!m_OrderedAgentsRequestingDecisions.TryGetValue(behaviorName, out var pendingAgentIds))
    {
        pendingAgentIds = new List<int>();
        m_OrderedAgentsRequestingDecisions[behaviorName] = pendingAgentIds;
    }

    if (!m_LastActionsReceived.TryGetValue(behaviorName, out var lastActions))
    {
        lastActions = new Dictionary<int, ActionBuffers>();
        m_LastActionsReceived[behaviorName] = lastActions;
    }

    if (info.done)
    {
        // A finished agent neither waits for a decision nor keeps a stored action.
        lastActions.Remove(info.episodeId);
    }
    else
    {
        pendingAgentIds.Add(info.episodeId);
        lastActions[info.episodeId] = ActionBuffers.Empty;
    }
}
|
|
|
|
|
|
|
|
|
|
|
/// <summary>
/// Sends the accumulated agent infos, any not-yet-sent brain parameters and the pending
/// side-channel bytes in one exchange, then stores the returned actions for the agents
/// that requested a decision.
/// </summary>
/// <remarks>
/// Returned actions are matched to agents by position: the i-th action of a behavior goes
/// to the i-th agent recorded for it by PutObservations. Behaviors in the reply that have
/// no recorded agents are skipped. If fewer actions than waiting agents come back, only
/// the available ones are applied and a warning is logged. The pending decision requests
/// are cleared in every case, including a missing reply or an exception, because the
/// agent infos they refer to have already been cleared.
/// </remarks>
void SendBatchedMessageHelper()
{
    var message = new UnityOutputProto
    {
        RlOutput = m_CurrentUnityRlOutput,
    };
    var tempUnityRlInitializationOutput = GetTempUnityRlInitializationOutput();
    if (tempUnityRlInitializationOutput != null)
    {
        message.RlInitializationOutput = tempUnityRlInitializationOutput;
    }

    byte[] messageAggregated = SideChannelManager.GetSideChannelMessage();
    message.RlOutput.SideChannel = ByteString.CopyFrom(messageAggregated);

    var input = Exchange(message);
    UpdateSentActionSpec(tempUnityRlInitializationOutput);

    // The infos have been handed to Exchange; start the next batch empty.
    foreach (var agentInfoList in m_CurrentUnityRlOutput.AgentInfos.Values)
    {
        agentInfoList.Value.Clear();
    }

    try
    {
        var rlInput = input?.RlInput;
        if (rlInput?.AgentActions == null)
        {
            return;
        }

        UpdateEnvironmentWithInput(rlInput);

        foreach (var brainName in rlInput.AgentActions.Keys)
        {
            // The reply may name a behavior for which no agent is waiting (or that was never seen).
            if (!m_OrderedAgentsRequestingDecisions.TryGetValue(brainName, out var orderedAgentIds) ||
                orderedAgentIds.Count == 0)
            {
                continue;
            }

            var actionListProto = rlInput.AgentActions[brainName];
            if (!actionListProto.Value.Any())
            {
                continue;
            }

            if (!m_LastActionsReceived.TryGetValue(brainName, out var lastActions))
            {
                continue;
            }

            var agentActions = actionListProto.ToAgentActionList();
            if (agentActions.Count < orderedAgentIds.Count)
            {
                Debug.LogWarning(
                    $"Received {agentActions.Count} actions for behavior '{brainName}' but " +
                    $"{orderedAgentIds.Count} agents requested a decision; the remaining agents receive no new action."
                );
            }

            var numAgents = Math.Min(orderedAgentIds.Count, agentActions.Count);
            for (var i = 0; i < numAgents; i++)
            {
                var agentId = orderedAgentIds[i];
                // Only agents that still have a stored entry are updated.
                if (lastActions.ContainsKey(agentId))
                {
                    lastActions[agentId] = agentActions[i];
                }
            }
        }
    }
    finally
    {
        foreach (var pendingAgentIds in m_OrderedAgentsRequestingDecisions.Values)
        {
            pendingAgentIds.Clear();
        }
    }
}
|
|
|
/// <summary>
/// Looks up the last action stored for an agent.
/// </summary>
/// <param name="behaviorName">Behavior the agent belongs to.</param>
/// <param name="agentId">Id the action was stored under.</param>
/// <returns>
/// The stored action, or <c>ActionBuffers.Empty</c> when nothing is stored for the
/// behavior or the agent.
/// </returns>
public ActionBuffers GetActions(string behaviorName, int agentId)
{
    if (!m_LastActionsReceived.TryGetValue(behaviorName, out var actionsForBehavior))
    {
        return ActionBuffers.Empty;
    }

    return actionsForBehavior.TryGetValue(agentId, out var storedAction)
        ? storedAction
        : ActionBuffers.Empty;
}
|
|
|
|
|
|
|
|
|
|
|
|
|
/// <summary>
/// Sends <paramref name="unityOutput"/> with status 200 and returns the input contained in the reply.
/// </summary>
/// <param name="unityOutput">Payload to send.</param>
/// <returns>
/// The reply's input; null when the communicator is closed or the exchange threw.
/// </returns>
/// <remarks>
/// A reply whose status is not 200, or any exception, closes the communicator and triggers
/// the quit notification and channel shutdown. gRPC "Unavailable" errors are not logged,
/// "ResourceExhausted" and non-gRPC exceptions are logged as errors, and other gRPC errors
/// as plain log messages.
/// </remarks>
UnityInputProto Exchange(UnityOutputProto unityOutput)
{
    if (!m_IsOpen)
    {
        return null;
    }

    try
    {
        var reply = m_Client.Exchange(WrapMessage(unityOutput, 200));
        if (reply.Header.Status != 200)
        {
            m_IsOpen = false;
            NotifyQuitAndShutDownChannel();
        }
        return reply.UnityInput;
    }
    catch (RpcException rpcException)
    {
        var statusCode = rpcException.Status.StatusCode;
        if (statusCode == StatusCode.ResourceExhausted)
        {
            Debug.LogError($"GRPC Exception: {rpcException.Message}. Disconnecting from trainer.");
        }
        else if (statusCode != StatusCode.Unavailable)
        {
            Debug.Log($"GRPC Exception: {rpcException.Message}. Disconnecting from trainer.");
        }
    }
    catch (Exception ex)
    {
        Debug.LogError($"Communication Exception: {ex.Message}. Disconnecting from trainer.");
    }

    // Only reached after one of the catch clauses ran: the connection is considered lost.
    m_IsOpen = false;
    NotifyQuitAndShutDownChannel();
    return null;
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/// <summary>
/// Builds the message envelope sent over the channel: a header carrying
/// <paramref name="status"/> plus the given payload.
/// </summary>
/// <param name="content">Payload of the message; callers in this class also pass null.</param>
/// <param name="status">Status value written into the header.</param>
/// <returns>The wrapped message.</returns>
static UnityMessageProto WrapMessage(UnityOutputProto content, int status)
{
    var header = new HeaderProto { Status = status };
    var wrapped = new UnityMessageProto
    {
        Header = header,
        UnityOutput = content
    };
    return wrapped;
}
|
|
|
/// <summary>
/// Remembers the action spec of a behavior until it is sent. Behaviors that have already
/// been sent are ignored; a spec cached earlier for the same name is overwritten.
/// </summary>
/// <param name="behaviorName">Behavior the spec belongs to.</param>
/// <param name="actionSpec">Spec to cache.</param>
void CacheActionSpec(string behaviorName, ActionSpec actionSpec)
{
    var alreadySent = m_SentBrainKeys.Contains(behaviorName);
    if (!alreadySent)
    {
        m_UnsentBrainKeys[behaviorName] = actionSpec;
    }
}
|
|
|
/// <summary>
/// Collects the brain parameters of every unsent behavior whose agent-info entry in the
/// pending output has a non-zero calculated size.
/// </summary>
/// <returns>
/// An initialization output holding those brain parameters, or null when no unsent
/// behavior qualifies.
/// </returns>
UnityRLInitializationOutputProto GetTempUnityRlInitializationOutput()
{
    UnityRLInitializationOutputProto pendingOutput = null;

    foreach (var unsent in m_UnsentBrainKeys)
    {
        var behaviorName = unsent.Key;

        if (!m_CurrentUnityRlOutput.AgentInfos.ContainsKey(behaviorName))
        {
            continue;
        }

        // Skip behaviors whose agent-info entry currently serializes to nothing.
        if (m_CurrentUnityRlOutput.AgentInfos[behaviorName].CalculateSize() <= 0)
        {
            continue;
        }

        // The output object is created lazily so that "nothing to send" stays null.
        if (pendingOutput == null)
        {
            pendingOutput = new UnityRLInitializationOutputProto();
        }

        var brainParameters = unsent.Value.ToBrainParametersProto(behaviorName, true);
        pendingOutput.BrainParameters.Add(brainParameters);
    }

    return pendingOutput;
}
|
|
|
/// <summary>
/// Moves every behavior listed in <paramref name="output"/> from the unsent cache to the
/// set of sent behavior names. Does nothing when <paramref name="output"/> is null.
/// </summary>
/// <param name="output">Initialization output that was included in the last message, or null.</param>
void UpdateSentActionSpec(UnityRLInitializationOutputProto output)
{
    if (output != null)
    {
        foreach (var sentBrain in output.BrainParameters)
        {
            var sentName = sentBrain.BrainName;
            m_SentBrainKeys.Add(sentName);
            m_UnsentBrainKeys.Remove(sentName);
        }
    }
}
|
|
|
#endregion |
|
|
|
#if UNITY_EDITOR |
|
|
|
|
|
|
|
|
|
/// <summary>
/// Editor play-mode callback: disposes the communicator when the editor is leaving play
/// mode and ignores every other state change.
/// </summary>
/// <param name="state">Play-mode transition reported by the editor.</param>
void HandleOnPlayModeChanged(PlayModeStateChange state)
{
    if (state != PlayModeStateChange.ExitingPlayMode)
    {
        return;
    }

    Dispose();
}
|
|
|
#endif |
|
} |
|
} |
|
#endif // MLA_SUPPORTED_TRAINING_PLATFORM
|
|