|
using System; |
|
using System.Linq; |
|
using NUnit.Framework; |
|
using UnityEngine; |
|
using UnityEditor; |
|
using Unity.Barracuda; |
|
using Unity.MLAgents.Actuators; |
|
using Unity.MLAgents.Inference; |
|
using Unity.MLAgents.Policies; |
|
using System.Collections.Generic; |
|
|
|
namespace Unity.MLAgents.Tests |
|
{ |
|
/// <summary>
/// Equality comparer that treats two floats as equal when the absolute value of their
/// difference is strictly below a fixed threshold.
/// </summary>
/// <remarks>
/// Intended for pairwise comparisons only (for example <c>Enumerable.SequenceEqual</c>).
/// Hashing is deliberately unsupported, so instances must not be used with hash-based
/// collections.
/// </remarks>
public class FloatThresholdComparer : IEqualityComparer<float>
{
    // Exclusive upper bound on |x - y| for two values to count as equal.
    private readonly float _tolerance;

    /// <summary>
    /// Creates a comparer with the given threshold.
    /// </summary>
    /// <param name="threshold">Exclusive upper bound on the absolute difference.</param>
    public FloatThresholdComparer(float threshold)
    {
        _tolerance = threshold;
    }

    /// <summary>
    /// Returns true when the absolute difference of <paramref name="x"/> and
    /// <paramref name="y"/> is strictly less than the threshold.
    /// </summary>
    public bool Equals(float x, float y) => Math.Abs(x - y) < _tolerance;

    /// <summary>
    /// Not supported by this comparer.
    /// </summary>
    /// <exception cref="NotImplementedException">Always thrown.</exception>
    public int GetHashCode(float f)
    {
        throw new NotImplementedException("Unable to generate a hash code for threshold floats, do not use this method");
    }
}
|
|
|
/// <summary>
/// Editor tests for ModelRunner: loading of the test model assets, construction with
/// various action specs, model/device lookup, and batched inference in both stochastic
/// and deterministic mode.
/// </summary>
[TestFixture]
public class ModelRunnerTest
{
    // AssetDatabase paths of the test models.
    const string k_hybrid_ONNX_recurr_v2 = "Packages/com.unity.ml-agents/Tests/Editor/TestModels/hybrid0vis8vec_2c_2_3d_v2_0.onnx";

    const string k_continuousONNXPath = "Packages/com.unity.ml-agents/Tests/Editor/TestModels/continuous2vis8vec2action_v1_0.onnx";
    const string k_discreteONNXPath = "Packages/com.unity.ml-agents/Tests/Editor/TestModels/discrete1vis0vec_2_3action_obsolete_recurr_v1_0.onnx";
    const string k_hybridONNXPath = "Packages/com.unity.ml-agents/Tests/Editor/TestModels/hybrid0vis53vec_3c_2daction_v1_0.onnx";
    const string k_continuousNNPath = "Packages/com.unity.ml-agents/Tests/Editor/TestModels/continuous2vis8vec2action_deprecated_v1_0.nn";
    const string k_discreteNNPath = "Packages/com.unity.ml-agents/Tests/Editor/TestModels/discrete1vis0vec_2_3action_recurr_deprecated_v1_0.nn";

    private const string k_deterministic_discreteNNPath = "Packages/com.unity.ml-agents/Tests/Editor/TestModels/deterDiscrete1obs3action_v2_0.onnx";
    private const string k_deterministic_continuousNNPath = "Packages/com.unity.ml-agents/Tests/Editor/TestModels/deterContinuous2vis8vec2action_v2_0.onnx";

    // Model assets, (re)loaded from the paths above before every test.
    NNModel hybridONNXModelV2;
    NNModel continuousONNXModel;
    NNModel discreteONNXModel;
    NNModel hybridONNXModel;
    NNModel continuousNNModel;
    NNModel discreteNNModel;
    NNModel deterministicDiscreteNNModel;
    NNModel deterministicContinuousNNModel;

    // Sensor components attached to sensorGameObject; the inference tests pass their
    // sensors to the runner together with an 8-float vector sensor.
    Test3DSensorComponent sensor_21_20_3;
    Test3DSensorComponent sensor_20_22_3;

    // Host object for the two sensor components; created in SetUp, destroyed in TearDown.
    GameObject sensorGameObject;

    /// <summary>Action spec with two continuous actions.</summary>
    ActionSpec GetContinuous2vis8vec2actionActionSpec()
    {
        return ActionSpec.MakeContinuous(2);
    }

    /// <summary>Action spec with two discrete branches of sizes 2 and 3.</summary>
    ActionSpec GetDiscrete1vis0vec_2_3action_recurrModelActionSpec()
    {
        return ActionSpec.MakeDiscrete(2, 3);
    }

    /// <summary>Action spec with three continuous actions and one discrete branch of size 2.</summary>
    ActionSpec GetHybrid0vis53vec_3c_2dActionSpec()
    {
        return new ActionSpec(3, new[] { 2 });
    }

    /// <summary>
    /// Loads the NNModel asset stored at <paramref name="assetPath"/>.
    /// </summary>
    /// <returns>The loaded model, or null when the asset database returns nothing for that path.</returns>
    static NNModel LoadModel(string assetPath)
    {
        return (NNModel)AssetDatabase.LoadAssetAtPath(assetPath, typeof(NNModel));
    }

    /// <summary>
    /// Loads every test model and builds the GameObject carrying the two 3D sensor components.
    /// </summary>
    [SetUp]
    public void SetUp()
    {
        hybridONNXModelV2 = LoadModel(k_hybrid_ONNX_recurr_v2);

        continuousONNXModel = LoadModel(k_continuousONNXPath);
        discreteONNXModel = LoadModel(k_discreteONNXPath);
        hybridONNXModel = LoadModel(k_hybridONNXPath);
        continuousNNModel = LoadModel(k_continuousNNPath);
        discreteNNModel = LoadModel(k_discreteNNPath);
        deterministicDiscreteNNModel = LoadModel(k_deterministic_discreteNNPath);
        deterministicContinuousNNModel = LoadModel(k_deterministic_continuousNNPath);

        sensorGameObject = new GameObject("SensorA");
        sensor_21_20_3 = sensorGameObject.AddComponent<Test3DSensorComponent>();
        sensor_21_20_3.Sensor = new Test3DSensor("SensorA", 21, 20, 3);
        sensor_20_22_3 = sensorGameObject.AddComponent<Test3DSensorComponent>();
        sensor_20_22_3.Sensor = new Test3DSensor("SensorB", 20, 22, 3);
    }

    /// <summary>
    /// Destroys the GameObject created by SetUp so it does not accumulate across tests.
    /// </summary>
    [TearDown]
    public void TearDown()
    {
        if (sensorGameObject != null)
        {
            // Fully qualified: both System and UnityEngine are imported by this file,
            // so a bare "Object" would be ambiguous.
            UnityEngine.Object.DestroyImmediate(sensorGameObject);
            sensorGameObject = null;
        }
    }

    /// <summary>Every test model asset must be present in the package.</summary>
    [Test]
    public void TestModelExist()
    {
        Assert.IsNotNull(continuousONNXModel);
        Assert.IsNotNull(discreteONNXModel);
        Assert.IsNotNull(hybridONNXModel);
        Assert.IsNotNull(continuousNNModel);
        Assert.IsNotNull(discreteNNModel);
        Assert.IsNotNull(hybridONNXModelV2);
        Assert.IsNotNull(deterministicDiscreteNNModel);
        Assert.IsNotNull(deterministicContinuousNNModel);
    }

    /// <summary>
    /// Constructs (and immediately disposes) a runner for each test model. The two discrete
    /// recurrent v1 models are expected to be rejected with UnityAgentsException.
    /// </summary>
    [Test]
    public void TestCreation()
    {
        var inferenceDevice = InferenceDevice.Burst;

        var modelRunner = new ModelRunner(continuousONNXModel, GetContinuous2vis8vec2actionActionSpec(), inferenceDevice);
        modelRunner.Dispose();

        // Construction is expected to throw; the Dispose call only matters if it unexpectedly succeeds.
        Assert.Throws<UnityAgentsException>(() =>
        {
            var rejected = new ModelRunner(discreteONNXModel, GetDiscrete1vis0vec_2_3action_recurrModelActionSpec(), inferenceDevice);
            rejected.Dispose();
        });

        modelRunner = new ModelRunner(hybridONNXModel, GetHybrid0vis53vec_3c_2dActionSpec(), inferenceDevice);
        modelRunner.Dispose();

        modelRunner = new ModelRunner(continuousNNModel, GetContinuous2vis8vec2actionActionSpec(), inferenceDevice);
        modelRunner.Dispose();

        // Same expectation for the .nn variant of the discrete recurrent model.
        Assert.Throws<UnityAgentsException>(() =>
        {
            var rejected = new ModelRunner(discreteNNModel, GetDiscrete1vis0vec_2_3action_recurrModelActionSpec(), inferenceDevice);
            rejected.Dispose();
        });

        modelRunner = new ModelRunner(hybridONNXModelV2, new ActionSpec(2, new[] { 2, 3 }), inferenceDevice);
        modelRunner.Dispose();

        modelRunner = new ModelRunner(deterministicDiscreteNNModel, new ActionSpec(0, new[] { 7 }), inferenceDevice);
        modelRunner.Dispose();

        modelRunner = new ModelRunner(deterministicContinuousNNModel,
            GetContinuous2vis8vec2actionActionSpec(), inferenceDevice,
            deterministicInference: true);
        modelRunner.Dispose();
    }

    /// <summary>
    /// HasModel must be true only for the exact model / inference device pair the runner was built with.
    /// </summary>
    [Test]
    public void TestHasModel()
    {
        var modelRunner = new ModelRunner(continuousONNXModel, GetContinuous2vis8vec2actionActionSpec(), InferenceDevice.CPU);
        try
        {
            Assert.True(modelRunner.HasModel(continuousONNXModel, InferenceDevice.CPU));
            Assert.False(modelRunner.HasModel(continuousONNXModel, InferenceDevice.GPU));
            Assert.False(modelRunner.HasModel(discreteONNXModel, InferenceDevice.CPU));
        }
        finally
        {
            // Release the runner even when an assertion above fails.
            modelRunner.Dispose();
        }
    }

    /// <summary>
    /// Queues observations for episodes 1 and 2, runs one batch, and checks that actions exist
    /// for exactly those episodes and have the discrete size requested by the action spec.
    /// </summary>
    [Test]
    public void TestRunModel()
    {
        var actionSpec = GetContinuous2vis8vec2actionActionSpec();
        var modelRunner = new ModelRunner(continuousONNXModel, actionSpec, InferenceDevice.Burst);
        try
        {
            var sensor_8 = new Sensors.VectorSensor(8, "VectorSensor8");

            var info1 = new AgentInfo();
            info1.episodeId = 1;
            modelRunner.PutObservations(info1, new[]
            {
                sensor_8,
                sensor_21_20_3.CreateSensors()[0],
                sensor_20_22_3.CreateSensors()[0]
            }.ToList());

            var info2 = new AgentInfo();
            info2.episodeId = 2;
            modelRunner.PutObservations(info2, new[]
            {
                sensor_8,
                sensor_21_20_3.CreateSensors()[0],
                sensor_20_22_3.CreateSensors()[0]
            }.ToList());

            modelRunner.DecideBatch();

            Assert.IsFalse(modelRunner.GetAction(1).Equals(ActionBuffers.Empty));
            Assert.IsFalse(modelRunner.GetAction(2).Equals(ActionBuffers.Empty));
            // Episode 3 never submitted observations, so no action may be reported for it.
            Assert.IsTrue(modelRunner.GetAction(3).Equals(ActionBuffers.Empty));
            Assert.AreEqual(actionSpec.NumDiscreteActions, modelRunner.GetAction(1).DiscreteActions.Length);
        }
        finally
        {
            modelRunner.Dispose();
        }
    }

    /// <summary>
    /// Without deterministic inference, two decisions on the same observations must yield
    /// continuous actions that differ by at least the comparison threshold.
    /// </summary>
    [Test]
    public void TestRunModel_stochastic()
    {
        var actionSpec = GetContinuous2vis8vec2actionActionSpec();
        var modelRunner = new ModelRunner(deterministicContinuousNNModel, actionSpec, InferenceDevice.Burst);
        try
        {
            var sensor_8 = new Sensors.VectorSensor(8, "VectorSensor8");
            var obs = new[]
            {
                sensor_8,
                sensor_21_20_3.CreateSensors()[0],
                sensor_20_22_3.CreateSensors()[0]
            }.ToList();
            var info1 = new AgentInfo();
            info1.episodeId = 1;

            modelRunner.PutObservations(info1, obs);
            modelRunner.DecideBatch();
            // Clone the values so the second decision cannot affect the first snapshot.
            var stochAction1 = (float[])modelRunner.GetAction(1).ContinuousActions.Array.Clone();

            modelRunner.PutObservations(info1, obs);
            modelRunner.DecideBatch();
            var stochAction2 = (float[])modelRunner.GetAction(1).ContinuousActions.Array.Clone();

            Assert.IsFalse(Enumerable.SequenceEqual(stochAction1, stochAction2, new FloatThresholdComparer(0.001f)));
        }
        finally
        {
            modelRunner.Dispose();
        }
    }

    /// <summary>
    /// With deterministic inference enabled, two decisions on the same observations must
    /// yield the same continuous actions (within the comparison threshold).
    /// </summary>
    [Test]
    public void TestRunModel_deterministic()
    {
        var actionSpec = GetContinuous2vis8vec2actionActionSpec();
        var deterministicModelRunner = new ModelRunner(deterministicContinuousNNModel, actionSpec, InferenceDevice.Burst,
            deterministicInference: true);
        try
        {
            var sensor_8 = new Sensors.VectorSensor(8, "VectorSensor8");
            var obs = new[]
            {
                sensor_8,
                sensor_21_20_3.CreateSensors()[0],
                sensor_20_22_3.CreateSensors()[0]
            }.ToList();
            var info1 = new AgentInfo();
            info1.episodeId = 1;

            deterministicModelRunner.PutObservations(info1, obs);
            deterministicModelRunner.DecideBatch();
            // Clone the values so the second decision cannot affect the first snapshot.
            var deterministicAction1 = (float[])deterministicModelRunner.GetAction(1).ContinuousActions.Array.Clone();

            deterministicModelRunner.PutObservations(info1, obs);
            deterministicModelRunner.DecideBatch();
            var deterministicAction2 = (float[])deterministicModelRunner.GetAction(1).ContinuousActions.Array.Clone();

            Assert.IsTrue(Enumerable.SequenceEqual(deterministicAction1, deterministicAction2, new FloatThresholdComparer(0.001f)));
        }
        finally
        {
            // Dispose the runner that was actually exercised by this test.
            deterministicModelRunner.Dispose();
        }
    }
}
|
} |
|
|