|
using System; |
|
using System.Collections.Generic; |
|
using System.Linq; |
|
using Unity.Barracuda; |
|
using FailedCheck = Unity.MLAgents.Inference.BarracudaModelParamLoader.FailedCheck; |
|
|
|
namespace Unity.MLAgents.Inference |
|
{ |
|
|
|
|
|
|
|
/// <summary>
/// Extension methods that read ML-Agents metadata (input/output names, action sizes,
/// API version, memory size) out of a Barracuda <see cref="Model"/>.
/// </summary>
internal static class BarracudaModelExtensions
{
    /// <summary>
    /// Get the names of the model's inputs, including memory inputs.
    /// </summary>
    /// <param name="model">The model; may be null.</param>
    /// <returns>
    /// The <c>name</c> of every entry in <c>model.inputs</c> plus the <c>input</c> name of every
    /// entry in <c>model.memories</c>, sorted with <see cref="StringComparer.InvariantCulture"/>.
    /// An empty array when <paramref name="model"/> is null.
    /// </returns>
    public static string[] GetInputNames(this Model model)
    {
        var names = new List<string>();

        if (model == null)
        {
            return names.ToArray();
        }

        foreach (var input in model.inputs)
        {
            names.Add(input.name);
        }

        foreach (var mem in model.memories)
        {
            names.Add(mem.input);
        }

        names.Sort(StringComparer.InvariantCulture);

        return names.ToArray();
    }

    /// <summary>
    /// Get the model's API version: the first element of the
    /// <see cref="TensorNames.VersionNumber"/> constant, cast to <c>int</c>.
    /// </summary>
    /// <remarks>
    /// The tensor returned by <c>GetTensorByName</c> is indexed without a null check, so a model
    /// that lacks this constant will throw here. <see cref="CheckExpectedTensors"/> reports that
    /// case as a failed check instead.
    /// </remarks>
    /// <param name="model">The model to read from.</param>
    /// <returns>The version number stored in the model.</returns>
    public static int GetVersion(this Model model)
    {
        return (int)model.GetTensorByName(TensorNames.VersionNumber)[0];
    }

    /// <summary>
    /// Build a <see cref="TensorProxy"/> description for each of the model's inputs.
    /// </summary>
    /// <param name="model">The model; may be null.</param>
    /// <returns>
    /// One proxy per entry in <c>model.inputs</c>. Each proxy has the input's name, a
    /// <see cref="TensorProxy.TensorType.FloatingPoint"/> value type, no data, and the input's
    /// shape converted to <c>long</c>. The proxies are sorted by name with
    /// <see cref="StringComparison.InvariantCulture"/>. The list is empty when
    /// <paramref name="model"/> is null. Memory inputs are not included.
    /// </returns>
    public static IReadOnlyList<TensorProxy> GetInputTensors(this Model model)
    {
        var tensors = new List<TensorProxy>();

        if (model == null)
        {
            return tensors;
        }

        foreach (var input in model.inputs)
        {
            tensors.Add(new TensorProxy
            {
                name = input.name,
                valueType = TensorProxy.TensorType.FloatingPoint,
                data = null,
                shape = input.shape.Select(i => (long)i).ToArray()
            });
        }

        tensors.Sort((el1, el2) => string.Compare(el1.name, el2.name, StringComparison.InvariantCulture));

        return tensors;
    }

    /// <summary>
    /// Count the model inputs whose name starts with
    /// <see cref="TensorNames.VisualObservationPlaceholderPrefix"/>.
    /// </summary>
    /// <param name="model">The model; may be null.</param>
    /// <returns>The number of visual inputs, or 0 when <paramref name="model"/> is null.</returns>
    public static int GetNumVisualInputs(this Model model)
    {
        var count = 0;
        if (model == null)
        {
            return count;
        }

        foreach (var input in model.inputs)
        {
            // Tensor names are identifiers, so compare them ordinally rather than
            // culture-sensitively.
            if (input.name.StartsWith(TensorNames.VisualObservationPlaceholderPrefix, StringComparison.Ordinal))
            {
                count++;
            }
        }

        return count;
    }

    /// <summary>
    /// Get the names of the output tensors to fetch for the given inference mode.
    /// </summary>
    /// <param name="model">The model; may be null.</param>
    /// <param name="deterministicInference">
    /// When true, the deterministic action output names are used instead of the stochastic ones.
    /// </param>
    /// <returns>
    /// The continuous and/or discrete action output names (whichever the model has for this mode).
    /// <see cref="TensorNames.RecurrentOutput"/> is added when the model's
    /// <see cref="TensorNames.MemorySize"/> constant is positive. The names are sorted with
    /// <see cref="StringComparer.InvariantCulture"/>. An empty array when
    /// <paramref name="model"/> is null.
    /// </returns>
    public static string[] GetOutputNames(this Model model, bool deterministicInference = false)
    {
        var names = new List<string>();

        if (model == null)
        {
            return names.ToArray();
        }

        if (model.HasContinuousOutputs(deterministicInference))
        {
            names.Add(model.ContinuousOutputName(deterministicInference));
        }
        if (model.HasDiscreteOutputs(deterministicInference))
        {
            names.Add(model.DiscreteOutputName(deterministicInference));
        }

        // A missing memory-size constant is treated as "no memory" rather than dereferencing null.
        var memoryTensor = model.GetTensorByName(TensorNames.MemorySize);
        var memory = memoryTensor == null ? 0 : (int)memoryTensor[0];
        if (memory > 0)
        {
            names.Add(TensorNames.RecurrentOutput);
        }

        names.Sort(StringComparer.InvariantCulture);

        return names.ToArray();
    }

    /// <summary>
    /// Check whether the model produces continuous actions for the given inference mode.
    /// </summary>
    /// <param name="model">The model; may be null.</param>
    /// <param name="deterministicInference">
    /// When true, look for the deterministic continuous output instead of the stochastic one.
    /// </param>
    /// <returns>
    /// For the deprecated format: whether <see cref="TensorNames.IsContinuousControlDeprecated"/>
    /// is positive. Otherwise: whether the output for the requested mode is present and
    /// <see cref="ContinuousOutputSize"/> is positive. False when <paramref name="model"/> is null.
    /// </returns>
    public static bool HasContinuousOutputs(this Model model, bool deterministicInference = false)
    {
        if (model == null)
        {
            return false;
        }

        if (!model.SupportsContinuousAndDiscrete())
        {
            return (int)model.GetTensorByName(TensorNames.IsContinuousControlDeprecated)[0] > 0;
        }

        bool hasStochasticOutput = !deterministicInference &&
            model.outputs.Contains(TensorNames.ContinuousActionOutput);
        bool hasDeterministicOutput = deterministicInference &&
            model.outputs.Contains(TensorNames.DeterministicContinuousActionOutput);

        // ContinuousOutputSize() tolerates a missing shape tensor (returns 0), mirroring how
        // HasDiscreteOutputs relies on DiscreteOutputSize().
        return (hasStochasticOutput || hasDeterministicOutput) &&
            model.ContinuousOutputSize() > 0;
    }

    /// <summary>
    /// Get the size of the model's continuous action output.
    /// </summary>
    /// <param name="model">The model; may be null.</param>
    /// <returns>
    /// For the deprecated format: <see cref="TensorNames.ActionOutputShapeDeprecated"/>[0] if the
    /// model is flagged as continuous, else 0. Otherwise:
    /// <see cref="TensorNames.ContinuousActionOutputShape"/>[0], or 0 if that tensor is missing.
    /// 0 when <paramref name="model"/> is null.
    /// </returns>
    public static int ContinuousOutputSize(this Model model)
    {
        if (model == null)
        {
            return 0;
        }

        if (!model.SupportsContinuousAndDiscrete())
        {
            return (int)model.GetTensorByName(TensorNames.IsContinuousControlDeprecated)[0] > 0 ?
                (int)model.GetTensorByName(TensorNames.ActionOutputShapeDeprecated)[0] : 0;
        }

        var continuousOutputShape = model.GetTensorByName(TensorNames.ContinuousActionOutputShape);
        return continuousOutputShape == null ? 0 : (int)continuousOutputShape[0];
    }

    /// <summary>
    /// Get the name of the tensor holding continuous actions.
    /// </summary>
    /// <param name="model">The model; may be null.</param>
    /// <param name="deterministicInference">
    /// When true, return the deterministic output name (new format only).
    /// </param>
    /// <returns>
    /// <see cref="TensorNames.ActionOutputDeprecated"/> for the deprecated format. Otherwise the
    /// deterministic or stochastic continuous output name. Null when <paramref name="model"/> is null.
    /// </returns>
    public static string ContinuousOutputName(this Model model, bool deterministicInference = false)
    {
        if (model == null)
        {
            return null;
        }

        if (!model.SupportsContinuousAndDiscrete())
        {
            return TensorNames.ActionOutputDeprecated;
        }

        return deterministicInference ? TensorNames.DeterministicContinuousActionOutput : TensorNames.ContinuousActionOutput;
    }

    /// <summary>
    /// Check whether the model produces discrete actions for the given inference mode.
    /// </summary>
    /// <param name="model">The model; may be null.</param>
    /// <param name="deterministicInference">
    /// When true, look for the deterministic discrete output instead of the stochastic one.
    /// </param>
    /// <returns>
    /// For the deprecated format: whether <see cref="TensorNames.IsContinuousControlDeprecated"/>
    /// is 0. Otherwise: whether the output for the requested mode is present and
    /// <see cref="DiscreteOutputSize"/> is positive. False when <paramref name="model"/> is null.
    /// </returns>
    public static bool HasDiscreteOutputs(this Model model, bool deterministicInference = false)
    {
        if (model == null)
        {
            return false;
        }

        if (!model.SupportsContinuousAndDiscrete())
        {
            return (int)model.GetTensorByName(TensorNames.IsContinuousControlDeprecated)[0] == 0;
        }

        bool hasStochasticOutput = !deterministicInference &&
            model.outputs.Contains(TensorNames.DiscreteActionOutput);
        bool hasDeterministicOutput = deterministicInference &&
            model.outputs.Contains(TensorNames.DeterministicDiscreteActionOutput);

        return (hasStochasticOutput || hasDeterministicOutput) &&
            model.DiscreteOutputSize() > 0;
    }

    /// <summary>
    /// Get the total size of the model's discrete action output.
    /// </summary>
    /// <param name="model">The model; may be null.</param>
    /// <returns>
    /// For the deprecated format: 0 if the model is flagged as continuous, else
    /// <see cref="TensorNames.ActionOutputShapeDeprecated"/>[0]. Otherwise: the sum of every element
    /// of <see cref="TensorNames.DiscreteActionOutputShape"/>, or 0 if that tensor is missing.
    /// 0 when <paramref name="model"/> is null.
    /// </returns>
    public static int DiscreteOutputSize(this Model model)
    {
        if (model == null)
        {
            return 0;
        }

        if (!model.SupportsContinuousAndDiscrete())
        {
            return (int)model.GetTensorByName(TensorNames.IsContinuousControlDeprecated)[0] > 0 ?
                0 : (int)model.GetTensorByName(TensorNames.ActionOutputShapeDeprecated)[0];
        }

        var discreteOutputShape = model.GetTensorByName(TensorNames.DiscreteActionOutputShape);
        if (discreteOutputShape == null)
        {
            return 0;
        }

        // Each element of the shape tensor is added to the total; presumably one entry per
        // discrete branch — TODO confirm against the exporter.
        int result = 0;
        for (int i = 0; i < discreteOutputShape.length; i++)
        {
            result += (int)discreteOutputShape[i];
        }
        return result;
    }

    /// <summary>
    /// Get the name of the tensor holding discrete actions.
    /// </summary>
    /// <param name="model">The model; may be null.</param>
    /// <param name="deterministicInference">
    /// When true, return the deterministic output name (new format only).
    /// </param>
    /// <returns>
    /// <see cref="TensorNames.ActionOutputDeprecated"/> for the deprecated format. Otherwise the
    /// deterministic or stochastic discrete output name. Null when <paramref name="model"/> is null.
    /// </returns>
    public static string DiscreteOutputName(this Model model, bool deterministicInference = false)
    {
        if (model == null)
        {
            return null;
        }

        if (!model.SupportsContinuousAndDiscrete())
        {
            return TensorNames.ActionOutputDeprecated;
        }

        return deterministicInference ? TensorNames.DeterministicDiscreteActionOutput : TensorNames.DiscreteActionOutput;
    }

    /// <summary>
    /// Check whether the model uses the newer format with separate continuous and discrete outputs.
    /// </summary>
    /// <remarks>
    /// When this returns false, the other methods in this class fall back to the deprecated
    /// single-action-output tensors.
    /// </remarks>
    /// <param name="model">The model; may be null.</param>
    /// <returns>
    /// True if the model's outputs contain <see cref="TensorNames.ContinuousActionOutput"/> or
    /// <see cref="TensorNames.DiscreteActionOutput"/>. Also true when <paramref name="model"/> is null.
    /// </returns>
    public static bool SupportsContinuousAndDiscrete(this Model model)
    {
        return model == null ||
            model.outputs.Contains(TensorNames.ContinuousActionOutput) ||
            model.outputs.Contains(TensorNames.DiscreteActionOutput);
    }

    /// <summary>
    /// Validate that the model contains the constants and output tensors that the other methods
    /// in this class read.
    /// </summary>
    /// <param name="model">The model to check. It is not null-checked here.</param>
    /// <param name="failedModelChecks">
    /// Receives a warning describing the first problem found.
    /// </param>
    /// <param name="deterministicInference">
    /// When true, require the deterministic action outputs instead of the stochastic ones.
    /// </param>
    /// <returns>
    /// False as soon as one check fails (after adding a warning to
    /// <paramref name="failedModelChecks"/>). True if every check passes.
    /// </returns>
    public static bool CheckExpectedTensors(this Model model, List<FailedCheck> failedModelChecks, bool deterministicInference = false)
    {
        // The API version constant is required.
        var modelApiVersionTensor = model.GetTensorByName(TensorNames.VersionNumber);
        if (modelApiVersionTensor == null)
        {
            failedModelChecks.Add(
                FailedCheck.Warning($"Required constant \"{TensorNames.VersionNumber}\" was not found in the model file.")
            );
            return false;
        }

        // The memory size constant is required.
        var memorySizeTensor = model.GetTensorByName(TensorNames.MemorySize);
        if (memorySizeTensor == null)
        {
            failedModelChecks.Add(
                FailedCheck.Warning($"Required constant \"{TensorNames.MemorySize}\" was not found in the model file.")
            );
            return false;
        }

        // At least one action output tensor (deprecated, stochastic or deterministic) must exist.
        if (!model.outputs.Contains(TensorNames.ActionOutputDeprecated) &&
            !model.outputs.Contains(TensorNames.ContinuousActionOutput) &&
            !model.outputs.Contains(TensorNames.DiscreteActionOutput) &&
            !model.outputs.Contains(TensorNames.DeterministicContinuousActionOutput) &&
            !model.outputs.Contains(TensorNames.DeterministicDiscreteActionOutput))
        {
            failedModelChecks.Add(
                FailedCheck.Warning("The model does not contain any Action Output Node.")
            );
            return false;
        }

        if (!model.SupportsContinuousAndDiscrete())
        {
            // Deprecated format: a single action output described by a shape tensor and a
            // continuous/discrete flag.
            if (model.GetTensorByName(TensorNames.ActionOutputShapeDeprecated) == null)
            {
                failedModelChecks.Add(
                    FailedCheck.Warning("The model does not contain any Action Output Shape Node.")
                );
                return false;
            }
            if (model.GetTensorByName(TensorNames.IsContinuousControlDeprecated) == null)
            {
                failedModelChecks.Add(
                    FailedCheck.Warning($"Required constant \"{TensorNames.IsContinuousControlDeprecated}\" was " +
                        "not found in the model file. " +
                        "This is only required for model that uses a deprecated model format.")
                );
                return false;
            }
        }
        else
        {
            if (model.outputs.Contains(TensorNames.ContinuousActionOutput))
            {
                if (model.GetTensorByName(TensorNames.ContinuousActionOutputShape) == null)
                {
                    failedModelChecks.Add(
                        FailedCheck.Warning("The model uses continuous action but does not contain Continuous Action Output Shape Node.")
                    );
                    return false;
                }
                else if (!model.HasContinuousOutputs(deterministicInference))
                {
                    failedModelChecks.Add(
                        FailedCheck.Warning(MissingActionOutputMessage("Continuous", deterministicInference))
                    );
                    return false;
                }
            }

            if (model.outputs.Contains(TensorNames.DiscreteActionOutput))
            {
                if (model.GetTensorByName(TensorNames.DiscreteActionOutputShape) == null)
                {
                    failedModelChecks.Add(
                        FailedCheck.Warning("The model uses discrete action but does not contain Discrete Action Output Shape Node.")
                    );
                    return false;
                }
                else if (!model.HasDiscreteOutputs(deterministicInference))
                {
                    failedModelChecks.Add(
                        FailedCheck.Warning(MissingActionOutputMessage("Discrete", deterministicInference))
                    );
                    return false;
                }
            }
        }
        return true;
    }

    /// <summary>
    /// Build the warning text used by <see cref="CheckExpectedTensors"/> when the action output
    /// tensor for the selected inference mode is missing.
    /// </summary>
    /// <param name="actionKind">"Continuous" or "Discrete".</param>
    /// <param name="deterministicInference">The inference mode that was requested.</param>
    private static string MissingActionOutputMessage(string actionKind, bool deterministicInference)
    {
        var actionType = deterministicInference ? "deterministic" : "stochastic";
        // The trailing space lives in the prefix so that stochastic mode does not emit a double space.
        var actionName = deterministicInference ? "Deterministic " : "";
        return $"The model uses {actionType} inference but does not contain {actionName}{actionKind} Action Output Tensor. Uncheck `Deterministic inference` flag.";
    }
}
|
} |
|
|