using System;
using System.Collections.Generic;
using System.Linq;
using Google.Protobuf;
using Unity.MLAgents.CommunicatorObjects;
using UnityEngine;
using System.Runtime.CompilerServices;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Demonstrations;
using Unity.MLAgents.Policies;
using Unity.MLAgents.Analytics;

[assembly: InternalsVisibleTo("Unity.ML-Agents.Editor")]
[assembly: InternalsVisibleTo("Unity.ML-Agents.Editor.Tests")]
[assembly: InternalsVisibleTo("Unity.ML-Agents.Runtime.Utils.Tests")]

namespace Unity.MLAgents
{
    internal static class GrpcExtensions
    {
        #region AgentInfo

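        /// <summary>
        /// Static flag to make sure the multi-agent-group capability warning is only fired once.
        /// </summary>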
        private static bool s_HaveWarnedTrainerCapabilitiesAgentGroup;

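        /// <summary>
        /// Converts an AgentInfo and its stored actions to a protobuf-generated
        /// AgentInfoActionPairProto.
        /// </summary>
        /// <param name="ai">The AgentInfo to convert.</param>
        /// <returns>The protobuf version of the AgentInfo/action pair.</returns>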
        public static AgentInfoActionPairProto ToInfoActionPairProto(this AgentInfo ai)
        {
            var agentInfoProto = ai.ToAgentInfoProto();
            var agentActionProto = new AgentActionProto();

            if (!ai.storedActions.IsEmpty())
            {
                if (!ai.storedActions.ContinuousActions.IsEmpty())
                {
                    agentActionProto.ContinuousActions.AddRange(ai.storedActions.ContinuousActions.Array);
                }
                if (!ai.storedActions.DiscreteActions.IsEmpty())
                {
                    agentActionProto.DiscreteActions.AddRange(ai.storedActions.DiscreteActions.Array);
                }
            }

            return new AgentInfoActionPairProto
            {
                AgentInfo = agentInfoProto,
                ActionInfo = agentActionProto
            };
        }

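        /// <summary>
        /// Converts an AgentInfo to a protobuf-generated AgentInfoProto.
        /// </summary>
        /// <param name="ai">The AgentInfo to convert.</param>
        /// <returns>The protobuf version of the AgentInfo.</returns>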
        public static AgentInfoProto ToAgentInfoProto(this AgentInfo ai)
        {
            if (ai.groupId > 0)
            {
                var trainerCanHandle = Academy.Instance.TrainerCapabilities == null || Academy.Instance.TrainerCapabilities.MultiAgentGroups;
                if (!trainerCanHandle)
                {
                    if (!s_HaveWarnedTrainerCapabilitiesAgentGroup)
                    {
                        Debug.LogWarning(
                            "Attached trainer doesn't support Multi Agent Groups; group rewards will be ignored. " +
                            "Please find the versions that work best together from our release page: " +
                            "https://github.com/Unity-Technologies/ml-agents/releases"
                        );
                        s_HaveWarnedTrainerCapabilitiesAgentGroup = true;
                    }
                }
            }
            var agentInfoProto = new AgentInfoProto
            {
                Reward = ai.reward,
                GroupReward = ai.groupReward,
                MaxStepReached = ai.maxStepReached,
                Done = ai.done,
                Id = ai.episodeId,
                GroupId = ai.groupId,
            };

            if (ai.discreteActionMasks != null)
            {
                agentInfoProto.ActionMask.AddRange(ai.discreteActionMasks);
            }

            return agentInfoProto;
        }

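        /// <summary>
        /// Extracts the observation shapes stored in an AgentInfoActionPairProto.
        /// </summary>
        /// <param name="infoActionPair">The proto to inspect.</param>
        /// <returns>A list of ObservationSummary, one per observation.</returns>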
        public static List<ObservationSummary> GetObservationSummaries(this AgentInfoActionPairProto infoActionPair)
        {
            var summariesOut = new List<ObservationSummary>();
            var agentInfo = infoActionPair.AgentInfo;
            foreach (var obs in agentInfo.Observations)
            {
                var summary = new ObservationSummary();
                summary.shape = obs.Shape.ToArray();
                summariesOut.Add(summary);
            }

            return summariesOut;
        }

        #endregion

        #region BrainParameters

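        /// <summary>
        /// Converts a BrainParameters instance to a protobuf-generated BrainParametersProto.
        /// </summary>
        /// <param name="bp">The BrainParameters to convert.</param>
        /// <param name="name">The name of the brain.</param>
        /// <param name="isTraining">Whether the brain is currently training.</param>
        /// <returns>The protobuf version of the BrainParameters.</returns>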
        public static BrainParametersProto ToProto(this BrainParameters bp, string name, bool isTraining)
        {
            // The deprecated vector-action fields are still populated for
            // compatibility with older trainers; suppress the obsolete-member
            // (CS0618) warnings while doing so.
#pragma warning disable CS0618
            var brainParametersProto = new BrainParametersProto
            {
                VectorActionSpaceTypeDeprecated = (SpaceTypeProto)bp.VectorActionSpaceType,
                BrainName = name,
                IsTraining = isTraining,
                ActionSpec = ToActionSpecProto(bp.ActionSpec),
            };
            if (bp.VectorActionSize != null)
            {
                brainParametersProto.VectorActionSizeDeprecated.AddRange(bp.VectorActionSize);
            }
            if (bp.VectorActionDescriptions != null)
            {
                brainParametersProto.VectorActionDescriptionsDeprecated.AddRange(bp.VectorActionDescriptions);
            }
#pragma warning restore CS0618
            return brainParametersProto;
        }

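        /// <summary>
        /// Converts an ActionSpec to a BrainParametersProto, filling in the deprecated
        /// vector-action fields when the attached trainer doesn't support hybrid actions.
        /// </summary>
        /// <param name="actionSpec">The ActionSpec to convert.</param>
        /// <param name="name">The name of the brain.</param>
        /// <param name="isTraining">Whether the brain is currently training.</param>
        /// <returns>The protobuf version of the BrainParameters.</returns>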
        public static BrainParametersProto ToBrainParametersProto(this ActionSpec actionSpec, string name, bool isTraining)
        {
            var brainParametersProto = new BrainParametersProto
            {
                BrainName = name,
                IsTraining = isTraining,
                ActionSpec = ToActionSpecProto(actionSpec),
            };

            var supportHybrid = Academy.Instance.TrainerCapabilities == null || Academy.Instance.TrainerCapabilities.HybridActions;
            if (!supportHybrid)
            {
                actionSpec.CheckAllContinuousOrDiscrete();
                if (actionSpec.NumContinuousActions > 0)
                {
                    brainParametersProto.VectorActionSizeDeprecated.Add(actionSpec.NumContinuousActions);
                    brainParametersProto.VectorActionSpaceTypeDeprecated = SpaceTypeProto.Continuous;
                }
                else if (actionSpec.NumDiscreteActions > 0)
                {
                    brainParametersProto.VectorActionSizeDeprecated.AddRange(actionSpec.BranchSizes);
                    brainParametersProto.VectorActionSpaceTypeDeprecated = SpaceTypeProto.Discrete;
                }
            }

            return brainParametersProto;
        }

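        /// <summary>
        /// Converts a BrainParametersProto to a BrainParameters struct, falling back to
        /// the deprecated vector-action fields when the proto carries no ActionSpec.
        /// </summary>
        /// <param name="bpp">An instance of a BrainParametersProto.</param>
        /// <returns>A BrainParameters struct.</returns>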
        public static BrainParameters ToBrainParameters(this BrainParametersProto bpp)
        {
            ActionSpec actionSpec;
            if (bpp.ActionSpec == null)
            {
                // No ActionSpec in the proto; reconstruct one from the
                // deprecated vector-action fields.
#pragma warning disable CS0618
                var spaceType = (SpaceType)bpp.VectorActionSpaceTypeDeprecated;
                if (spaceType == SpaceType.Continuous)
                {
                    actionSpec = ActionSpec.MakeContinuous(bpp.VectorActionSizeDeprecated.ToArray()[0]);
                }
                else
                {
                    actionSpec = ActionSpec.MakeDiscrete(bpp.VectorActionSizeDeprecated.ToArray());
                }
#pragma warning restore CS0618
            }
            else
            {
                actionSpec = ToActionSpec(bpp.ActionSpec);
            }
            var bp = new BrainParameters
            {
                VectorActionDescriptions = bpp.VectorActionDescriptionsDeprecated.ToArray(),
                ActionSpec = actionSpec,
            };
            return bp;
        }

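        /// <summary>
        /// Converts an ActionSpecProto to an ActionSpec struct.
        /// </summary>
        /// <param name="actionSpecProto">An instance of an ActionSpecProto.</param>
        /// <returns>An ActionSpec struct.</returns>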
        public static ActionSpec ToActionSpec(this ActionSpecProto actionSpecProto)
        {
            var actionSpec = new ActionSpec(actionSpecProto.NumContinuousActions);
            if (actionSpecProto.DiscreteBranchSizes != null)
            {
                actionSpec.BranchSizes = actionSpecProto.DiscreteBranchSizes.ToArray();
            }
            return actionSpec;
        }

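        /// <summary>
        /// Converts an ActionSpec struct to an ActionSpecProto.
        /// </summary>
        /// <param name="actionSpec">An instance of an ActionSpec.</param>
        /// <returns>An ActionSpecProto.</returns>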
        public static ActionSpecProto ToActionSpecProto(this ActionSpec actionSpec)
        {
            var actionSpecProto = new ActionSpecProto
            {
                NumContinuousActions = actionSpec.NumContinuousActions,
                NumDiscreteActions = actionSpec.NumDiscreteActions,
            };
            if (actionSpec.BranchSizes != null)
            {
                actionSpecProto.DiscreteBranchSizes.AddRange(actionSpec.BranchSizes);
            }
            return actionSpecProto;
        }

        #endregion

        #region DemonstrationMetaData

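        /// <summary>
        /// Converts DemonstrationMetaData to a protobuf-generated DemonstrationMetaProto.
        /// </summary>
        /// <param name="dm">The DemonstrationMetaData to convert.</param>
        /// <returns>The protobuf version of the DemonstrationMetaData.</returns>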
        public static DemonstrationMetaProto ToProto(this DemonstrationMetaData dm)
        {
            var demonstrationName = dm.demonstrationName ?? "";
            var demoProto = new DemonstrationMetaProto
            {
                ApiVersion = DemonstrationMetaData.ApiVersion,
                MeanReward = dm.meanReward,
                NumberSteps = dm.numberSteps,
                NumberEpisodes = dm.numberEpisodes,
                DemonstrationName = demonstrationName
            };
            return demoProto;
        }

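        /// <summary>
        /// Initializes a DemonstrationMetaData from a protobuf-generated DemonstrationMetaProto,
        /// throwing if the demonstration's API version is incompatible.
        /// </summary>
        /// <param name="demoProto">The proto to convert.</param>
        /// <returns>The DemonstrationMetaData.</returns>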
        public static DemonstrationMetaData ToDemonstrationMetaData(this DemonstrationMetaProto demoProto)
        {
            var dm = new DemonstrationMetaData
            {
                numberEpisodes = demoProto.NumberEpisodes,
                numberSteps = demoProto.NumberSteps,
                meanReward = demoProto.MeanReward,
                demonstrationName = demoProto.DemonstrationName
            };
            if (demoProto.ApiVersion != DemonstrationMetaData.ApiVersion)
            {
                throw new Exception("API versions of demonstration are incompatible.");
            }
            return dm;
        }

        #endregion

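        /// <summary>
        /// Converts a UnityRLInitializationInputProto to a UnityRLInitParameters struct.
        /// </summary>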
        public static UnityRLInitParameters ToUnityRLInitParameters(this UnityRLInitializationInputProto inputProto)
        {
            return new UnityRLInitParameters
            {
                seed = inputProto.Seed,
                numAreas = inputProto.NumAreas,
                pythonLibraryVersion = inputProto.PackageVersion,
                pythonCommunicationVersion = inputProto.CommunicationVersion,
                TrainerCapabilities = inputProto.Capabilities.ToRLCapabilities()
            };
        }

        #region AgentAction

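        /// <summary>
        /// Converts a ListAgentActionProto to a list of ActionBuffers, one per agent.
        /// </summary>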
        public static List<ActionBuffers> ToAgentActionList(this UnityRLInputProto.Types.ListAgentActionProto proto)
        {
            var agentActions = new List<ActionBuffers>(proto.Value.Count);
            foreach (var ap in proto.Value)
            {
                agentActions.Add(ap.ToActionBuffers());
            }
            return agentActions;
        }

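        /// <summary>
        /// Converts an AgentActionProto to an ActionBuffers struct.
        /// </summary>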
        public static ActionBuffers ToActionBuffers(this AgentActionProto proto)
        {
            return new ActionBuffers(proto.ContinuousActions.ToArray(), proto.DiscreteActions.ToArray());
        }

        #endregion

        #region Observations

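        /// <summary>
        /// Static flags to make sure each compression-capability warning is only fired once.
        /// </summary>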
        private static bool s_HaveWarnedTrainerCapabilitiesMultiPng;
        private static bool s_HaveWarnedTrainerCapabilitiesMapping;

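        /// <summary>
        /// Generates an ObservationProto for the sensor using the provided ObservationWriter,
        /// falling back to uncompressed data when the attached trainer lacks the required
        /// compression capabilities.
        /// </summary>
        /// <param name="sensor">The sensor to read from.</param>
        /// <param name="observationWriter">The writer used for uncompressed data.</param>
        /// <returns>The generated ObservationProto.</returns>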
        public static ObservationProto GetObservationProto(this ISensor sensor, ObservationWriter observationWriter)
        {
            var obsSpec = sensor.GetObservationSpec();
            var shape = obsSpec.Shape;
            ObservationProto observationProto = null;
            var compressionSpec = sensor.GetCompressionSpec();
            var compressionType = compressionSpec.SensorCompressionType;

            // PNG observations with more than 3 channels are sent as multiple
            // concatenated PNGs; fall back to uncompressed data if the attached
            // trainer can't decode that.
            if (compressionType == SensorCompressionType.PNG && shape.Length == 3 && shape[2] > 3)
            {
                var trainerCanHandle = Academy.Instance.TrainerCapabilities == null || Academy.Instance.TrainerCapabilities.ConcatenatedPngObservations;
                if (!trainerCanHandle)
                {
                    if (!s_HaveWarnedTrainerCapabilitiesMultiPng)
                    {
                        Debug.LogWarning(
                            $"Attached trainer doesn't support multiple PNGs. Switching to uncompressed observations for sensor {sensor.GetName()}. " +
                            "Please find the versions that work best together from our release page: " +
                            "https://github.com/Unity-Technologies/ml-agents/releases"
                        );
                        s_HaveWarnedTrainerCapabilitiesMultiPng = true;
                    }
                    compressionType = SensorCompressionType.None;
                }
            }

            // Compressed observations with a non-trivial channel mapping also
            // require explicit trainer support.
            if (compressionType != SensorCompressionType.None && shape.Length == 3 && shape[2] > 3)
            {
                var trainerCanHandleMapping = Academy.Instance.TrainerCapabilities == null || Academy.Instance.TrainerCapabilities.CompressedChannelMapping;
                var isTrivialMapping = compressionSpec.IsTrivialMapping();
                if (!trainerCanHandleMapping && !isTrivialMapping)
                {
                    if (!s_HaveWarnedTrainerCapabilitiesMapping)
                    {
                        Debug.LogWarning(
                            $"The sensor {sensor.GetName()} is using non-trivial mapping and " +
                            "the attached trainer doesn't support compression mapping. " +
                            "Switching to uncompressed observations. " +
                            "Please find the versions that work best together from our release page: " +
                            "https://github.com/Unity-Technologies/ml-agents/releases"
                        );
                        s_HaveWarnedTrainerCapabilitiesMapping = true;
                    }
                    compressionType = SensorCompressionType.None;
                }
            }

            if (compressionType == SensorCompressionType.None)
            {
                var numFloats = sensor.ObservationSize();
                var floatDataProto = new ObservationProto.Types.FloatData();

                // Pre-fill the repeated field so the ObservationWriter can write
                // into it in place.
                for (var i = 0; i < numFloats; i++)
                {
                    floatDataProto.Data.Add(0.0f);
                }

                observationWriter.SetTarget(floatDataProto.Data, sensor.GetObservationSpec(), 0);
                sensor.Write(observationWriter);

                observationProto = new ObservationProto
                {
                    FloatData = floatDataProto,
                    CompressionType = (CompressionTypeProto)SensorCompressionType.None,
                };
            }
            else
            {
                var compressedObs = sensor.GetCompressedObservation();
                if (compressedObs == null)
                {
                    throw new UnityAgentsException(
                        $"GetCompressedObservation() returned null data for sensor named {sensor.GetName()}. " +
                        "You must return a byte[]. If you don't want to use compressed observations, " +
                        "return CompressionSpec.Default() from GetCompressionSpec()."
                    );
                }
                observationProto = new ObservationProto
                {
                    CompressedData = ByteString.CopyFrom(compressedObs),
                    CompressionType = (CompressionTypeProto)sensor.GetCompressionSpec().SensorCompressionType,
                };
                if (compressionSpec.CompressedChannelMapping != null)
                {
                    observationProto.CompressedChannelMapping.AddRange(compressionSpec.CompressedChannelMapping);
                }
            }

            // Add the dimension properties to the observationProto.
            var dimensionProperties = obsSpec.DimensionProperties;
            for (int i = 0; i < dimensionProperties.Length; i++)
            {
                observationProto.DimensionProperties.Add((int)dimensionProperties[i]);
            }

            // A (VariableSize, None) dimension profile indicates a variable-length
            // observation, which the trainer must support.
            if (dimensionProperties == new InplaceArray<DimensionProperty>(DimensionProperty.VariableSize, DimensionProperty.None))
            {
                var trainerCanHandleVarLenObs = Academy.Instance.TrainerCapabilities == null || Academy.Instance.TrainerCapabilities.VariableLengthObservation;
                if (!trainerCanHandleVarLenObs)
                {
                    throw new UnityAgentsException("Variable Length Observations are not supported by the trainer");
                }
            }

            for (var i = 0; i < shape.Length; i++)
            {
                observationProto.Shape.Add(shape[i]);
            }

            var sensorName = sensor.GetName();
            if (!string.IsNullOrEmpty(sensorName))
            {
                observationProto.Name = sensorName;
            }

            observationProto.ObservationType = (ObservationTypeProto)obsSpec.ObservationType;
            return observationProto;
        }

        #endregion

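        /// <summary>
        /// Converts a UnityRLCapabilitiesProto to a UnityRLCapabilities struct.
        /// </summary>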
        public static UnityRLCapabilities ToRLCapabilities(this UnityRLCapabilitiesProto proto)
        {
            return new UnityRLCapabilities
            {
                BaseRLCapabilities = proto.BaseRLCapabilities,
                ConcatenatedPngObservations = proto.ConcatenatedPngObservations,
                CompressedChannelMapping = proto.CompressedChannelMapping,
                HybridActions = proto.HybridActions,
                TrainingAnalytics = proto.TrainingAnalytics,
                VariableLengthObservation = proto.VariableLengthObservation,
                MultiAgentGroups = proto.MultiAgentGroups,
            };
        }

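        /// <summary>
        /// Converts a UnityRLCapabilities struct back to its protobuf representation.
        /// </summary>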
        public static UnityRLCapabilitiesProto ToProto(this UnityRLCapabilities rlCaps)
        {
            return new UnityRLCapabilitiesProto
            {
                BaseRLCapabilities = rlCaps.BaseRLCapabilities,
                ConcatenatedPngObservations = rlCaps.ConcatenatedPngObservations,
                CompressedChannelMapping = rlCaps.CompressedChannelMapping,
                HybridActions = rlCaps.HybridActions,
                TrainingAnalytics = rlCaps.TrainingAnalytics,
                VariableLengthObservation = rlCaps.VariableLengthObservation,
                MultiAgentGroups = rlCaps.MultiAgentGroups,
            };
        }

        #region Analytics

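        /// <summary>
        /// Converts a TrainingEnvironmentInitialized proto to the corresponding
        /// analytics event.
        /// </summary>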
        internal static TrainingEnvironmentInitializedEvent ToTrainingEnvironmentInitializedEvent(
            this TrainingEnvironmentInitialized inputProto)
        {
            return new TrainingEnvironmentInitializedEvent
            {
                TrainerPythonVersion = inputProto.PythonVersion,
                MLAgentsVersion = inputProto.MlagentsVersion,
                MLAgentsEnvsVersion = inputProto.MlagentsEnvsVersion,
                TorchVersion = inputProto.TorchVersion,
                TorchDeviceType = inputProto.TorchDeviceType,
                NumEnvironments = inputProto.NumEnvs,
                NumEnvironmentParameters = inputProto.NumEnvironmentParameters,
                RunOptions = inputProto.RunOptions,
            };
        }

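        /// <summary>
        /// Converts a TrainingBehaviorInitialized proto to the corresponding
        /// analytics event.
        /// </summary>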
        internal static TrainingBehaviorInitializedEvent ToTrainingBehaviorInitializedEvent(
            this TrainingBehaviorInitialized inputProto)
        {
            // Fold the individual enabled-flags into the corresponding flag enums.
            RewardSignals rewardSignals = 0;
            rewardSignals |= inputProto.ExtrinsicRewardEnabled ? RewardSignals.Extrinsic : 0;
            rewardSignals |= inputProto.GailRewardEnabled ? RewardSignals.Gail : 0;
            rewardSignals |= inputProto.CuriosityRewardEnabled ? RewardSignals.Curiosity : 0;
            rewardSignals |= inputProto.RndRewardEnabled ? RewardSignals.Rnd : 0;

            TrainingFeatures trainingFeatures = 0;
            trainingFeatures |= inputProto.BehavioralCloningEnabled ? TrainingFeatures.BehavioralCloning : 0;
            trainingFeatures |= inputProto.RecurrentEnabled ? TrainingFeatures.Recurrent : 0;
            trainingFeatures |= inputProto.TrainerThreaded ? TrainingFeatures.Threaded : 0;
            trainingFeatures |= inputProto.SelfPlayEnabled ? TrainingFeatures.SelfPlay : 0;
            trainingFeatures |= inputProto.CurriculumEnabled ? TrainingFeatures.Curriculum : 0;

            return new TrainingBehaviorInitializedEvent
            {
                BehaviorName = inputProto.BehaviorName,
                TrainerType = inputProto.TrainerType,
                RewardSignalFlags = rewardSignals,
                TrainingFeatureFlags = trainingFeatures,
                VisualEncoder = inputProto.VisualEncoder,
                NumNetworkLayers = inputProto.NumNetworkLayers,
                NumNetworkHiddenUnits = inputProto.NumNetworkHiddenUnits,
                Config = inputProto.Config,
            };
        }

        #endregion
    }
}