|
using NUnit.Framework; |
|
using System; |
|
using System.Linq; |
|
using Unity.MLAgents.Actuators; |
|
using Unity.MLAgents.Policies; |
|
using UnityEngine; |
|
using Unity.MLAgents.Sensors; |
|
using Unity.MLAgents.Utils.Tests; |
|
|
|
namespace Unity.MLAgents.Tests
{
    /// <summary>
    /// Edit-mode tests for <see cref="StackingSensor"/>: naming and shape, vector and 3D stacking,
    /// Reset() behaviour, compressed-channel mappings, compressed observations and the reported
    /// built-in sensor type.
    /// </summary>
    public class StackingSensorTests
    {
        /// <summary>
        /// Start every test from a freshly created Academy with automatic stepping disabled, so the
        /// tests drive environment steps explicitly through <c>EnvironmentStep()</c>.
        /// </summary>
        [SetUp]
        public void SetUp()
        {
            if (Academy.IsInitialized)
            {
                Academy.Instance.Dispose();
            }

            Academy.Instance.AutomaticSteppingEnabled = false;
        }

        /// <summary>
        /// Calls <c>CommunicatorFactory.ClearCreator()</c> so communicator configuration does not
        /// carry over into the next test.
        /// </summary>
        [TearDown]
        public void TearDown()
        {
            CommunicatorFactory.ClearCreator();
        }

        /// <summary>
        /// Checks both views of a stacking sensor's contents: what it writes through an
        /// ObservationWriter, and what <c>GetStackedObservations()</c> returns.
        /// The write happens first because it is what fills the current stack slot.
        /// </summary>
        static void AssertStackedObservation(StackingSensor sensor, float[] expected)
        {
            SensorTestHelper.CompareObservation(sensor, expected);
            Assert.AreEqual(expected, sensor.GetStackedObservations().ToArray());
        }

        /// <summary>
        /// Casts each float to a byte, mirroring the fake "compression" done by <see cref="Dummy3DSensor"/>.
        /// </summary>
        static byte[] ToBytes(params float[] values)
        {
            return Array.ConvertAll(values, v => (byte)v);
        }

        [Test]
        public void TestCtor()
        {
            ISensor wrapped = new VectorSensor(4);
            ISensor sensor = new StackingSensor(wrapped, 4);

            Assert.AreEqual("StackingSensor_size4_VectorSensor_size4", sensor.GetName());
            // Four stacked copies of a 4-element vector flatten to 16 values.
            Assert.AreEqual(new InplaceArray<int>(16), sensor.GetObservationSpec().Shape);
        }

        [Test]
        public void AssertStackingReset()
        {
            var agentGo1 = new GameObject("TestAgent");
            try
            {
                var bp1 = agentGo1.AddComponent<BehaviorParameters>();
                bp1.BrainParameters.NumStackedVectorObservations = 3;
                bp1.BrainParameters.ActionSpec = ActionSpec.MakeContinuous(1);
                var aca = Academy.Instance;
                var agent1 = agentGo1.AddComponent<TestAgent>();
                var policy = new TestPolicy();
                agent1.SetPolicy(policy);

                // If several StackingSensors are attached, the last one is used.
                var sensor = agent1.sensors.OfType<StackingSensor>().LastOrDefault();
                Assert.NotNull(sensor);

                for (var step = 0; step < 20; step++)
                {
                    agent1.RequestDecision();
                    aca.EnvironmentStep();
                }
                SensorTestHelper.CompareObservation(sensor, new[] { 18f, 19f, 20f });

                // The hook is expected to fire while EndEpisode() runs, and should still see the
                // pre-reset history plus one more observation. Record that it actually ran so the
                // check cannot pass vacuously.
                var finalDecisionChecked = false;
                policy.OnRequestDecision = () =>
                {
                    SensorTestHelper.CompareObservation(sensor, new[] { 19f, 20f, 21f });
                    finalDecisionChecked = true;
                };
                agent1.EndEpisode();
                Assert.IsTrue(finalDecisionChecked, "EndEpisode() did not request a final decision.");
                policy.OnRequestDecision = () => { };

                // After the episode ends, the stack starts again from zero padding.
                SensorTestHelper.CompareObservation(sensor, new[] { 0f, 0f, 0f });
                for (var step = 0; step < 20; step++)
                {
                    agent1.RequestDecision();
                    aca.EnvironmentStep();
                    SensorTestHelper.CompareObservation(sensor, new[] { Math.Max(0, step - 1f), step, step + 1 });
                }
            }
            finally
            {
                // Don't leave the agent behind in the edit-mode scene for later tests.
                UnityEngine.Object.DestroyImmediate(agentGo1);
            }
        }

        [Test]
        public void TestVectorStacking()
        {
            var wrapped = new VectorSensor(2);
            var sensor = new StackingSensor(wrapped, 3);

            // Each Update() + AddObservation() pushes a new 2-value frame. Unfilled slots read as
            // zeros, and once three frames are stacked the oldest one falls off the front.
            wrapped.AddObservation(new[] { 1f, 2f });
            AssertStackedObservation(sensor, new[] { 0f, 0f, 0f, 0f, 1f, 2f });

            sensor.Update();
            wrapped.AddObservation(new[] { 3f, 4f });
            AssertStackedObservation(sensor, new[] { 0f, 0f, 1f, 2f, 3f, 4f });

            sensor.Update();
            wrapped.AddObservation(new[] { 5f, 6f });
            AssertStackedObservation(sensor, new[] { 1f, 2f, 3f, 4f, 5f, 6f });

            sensor.Update();
            wrapped.AddObservation(new[] { 7f, 8f });
            AssertStackedObservation(sensor, new[] { 3f, 4f, 5f, 6f, 7f, 8f });

            sensor.Update();
            wrapped.AddObservation(new[] { 9f, 10f });
            AssertStackedObservation(sensor, new[] { 5f, 6f, 7f, 8f, 9f, 10f });

            // Reading again without calling Update() must not shift the stack.
            AssertStackedObservation(sensor, new[] { 5f, 6f, 7f, 8f, 9f, 10f });
        }

        [Test]
        public void TestVectorStackingReset()
        {
            var wrapped = new VectorSensor(2);
            ISensor sensor = new StackingSensor(wrapped, 3);

            wrapped.AddObservation(new[] { 1f, 2f });
            SensorTestHelper.CompareObservation(sensor, new[] { 0f, 0f, 0f, 0f, 1f, 2f });

            sensor.Update();
            wrapped.AddObservation(new[] { 3f, 4f });
            SensorTestHelper.CompareObservation(sensor, new[] { 0f, 0f, 1f, 2f, 3f, 4f });

            // Reset() discards the history; only the newest frame survives, behind zero padding.
            sensor.Reset();
            wrapped.AddObservation(new[] { 5f, 6f });
            SensorTestHelper.CompareObservation(sensor, new[] { 0f, 0f, 0f, 0f, 5f, 6f });
        }

        /// <summary>
        /// Minimal rank-3 sensor whose spec, current observation and compression settings are
        /// assigned directly by the tests.
        /// </summary>
        class Dummy3DSensor : ISensor
        {
            public SensorCompressionType CompressionType = SensorCompressionType.PNG;
            public int[] Mapping;
            public ObservationSpec ObservationSpec;
            public float[,,] CurrentObservation;

            /// <summary>Total number of values described by the (h, w, c) shape.</summary>
            int ElementCount
            {
                get { return ObservationSpec.Shape[0] * ObservationSpec.Shape[1] * ObservationSpec.Shape[2]; }
            }

            public ObservationSpec GetObservationSpec()
            {
                return ObservationSpec;
            }

            /// <summary>
            /// Copies <see cref="CurrentObservation"/> into <paramref name="writer"/> using
            /// (h, w, c) indexing.
            /// </summary>
            /// <returns>The number of values written.</returns>
            public int Write(ObservationWriter writer)
            {
                for (var h = 0; h < ObservationSpec.Shape[0]; h++)
                {
                    for (var w = 0; w < ObservationSpec.Shape[1]; w++)
                    {
                        for (var c = 0; c < ObservationSpec.Shape[2]; c++)
                        {
                            writer[h, w, c] = CurrentObservation[h, w, c];
                        }
                    }
                }

                return ElementCount;
            }

            /// <summary>
            /// Fake "compression": the observation is flattened through an ObservationWriter and
            /// each value is cast to a byte, so tests can predict the exact payload.
            /// This is not real PNG data.
            /// </summary>
            public byte[] GetCompressedObservation()
            {
                var writer = new ObservationWriter();
                var flattenedObservation = new float[ElementCount];
                writer.SetTarget(flattenedObservation, ObservationSpec.Shape, 0);
                Write(writer);
                return ToBytes(flattenedObservation);
            }

            public void Update() { }

            public void Reset() { }

            public CompressionSpec GetCompressionSpec()
            {
                return new CompressionSpec(CompressionType, Mapping);
            }

            public string GetName()
            {
                return "Dummy";
            }
        }

        [Test]
        public void TestStackingMapping()
        {
            // Grayscale camera: each stacked frame has a single channel, so all three mapping
            // entries for a frame point at that frame's channel. The camera is never rendered
            // here; only the compression mapping is inspected.
            var cameraSensor = new CameraSensor(new Camera(), 64, 64,
                true, "grayscaleCamera", SensorCompressionType.PNG);
            var stackedCameraSensor = new StackingSensor(cameraSensor, 2);
            Assert.AreEqual(new[] { 0, 0, 0, 1, 1, 1 }, stackedCameraSensor.GetCompressionSpec().CompressedChannelMapping);

            // RGB render texture: the stacked mapping is the identity over 2 x 3 channels.
            var renderTexture = new RenderTexture(24, 16, 0);
            try
            {
                var renderTextureSensor = new RenderTextureSensor(renderTexture,
                    false, "renderTexture", SensorCompressionType.PNG);
                var stackedRenderTextureSensor = new StackingSensor(renderTextureSensor, 2);
                Assert.AreEqual(new[] { 0, 1, 2, 3, 4, 5 }, stackedRenderTextureSensor.GetCompressionSpec().CompressedChannelMapping);
            }
            finally
            {
                // RenderTextures hold native resources that garbage collection does not reclaim.
                UnityEngine.Object.DestroyImmediate(renderTexture);
            }

            // A channel count that is not a multiple of 3: each frame's mapping is padded with -1
            // up to a multiple of 3.
            var dummySensor = new Dummy3DSensor();
            dummySensor.ObservationSpec = ObservationSpec.Visual(2, 2, 4);
            dummySensor.Mapping = new[] { 0, 1, 2, 3 };
            var stackedDummySensor = new StackingSensor(dummySensor, 2);
            Assert.AreEqual(new[] { 0, 1, 2, 3, -1, -1, 4, 5, 6, 7, -1, -1 }, stackedDummySensor.GetCompressionSpec().CompressedChannelMapping);

            // A mapping that already ends in -1 padding must produce the same stacked result.
            var paddedDummySensor = new Dummy3DSensor();
            paddedDummySensor.ObservationSpec = ObservationSpec.Visual(2, 2, 4);
            paddedDummySensor.Mapping = new[] { 0, 1, 2, 3, -1, -1 };
            var stackedPaddedDummySensor = new StackingSensor(paddedDummySensor, 2);
            Assert.AreEqual(new[] { 0, 1, 2, 3, -1, -1, 4, 5, 6, 7, -1, -1 }, stackedPaddedDummySensor.GetCompressionSpec().CompressedChannelMapping);
        }

        [Test]
        public void Test3DStacking()
        {
            var wrapped = new Dummy3DSensor();
            wrapped.ObservationSpec = ObservationSpec.Visual(2, 1, 2);
            var sensor = new StackingSensor(wrapped, 2);

            // Frames are stacked along the channel axis: at each (h, w), the older frame's
            // channels come first and the newest frame's channels come last.
            wrapped.CurrentObservation = new[, ,] { { { 1f, 2f } }, { { 3f, 4f } } };
            SensorTestHelper.CompareObservation(sensor, new[, ,] { { { 0f, 0f, 1f, 2f } }, { { 0f, 0f, 3f, 4f } } });

            sensor.Update();
            wrapped.CurrentObservation = new[, ,] { { { 5f, 6f } }, { { 7f, 8f } } };
            SensorTestHelper.CompareObservation(sensor, new[, ,] { { { 1f, 2f, 5f, 6f } }, { { 3f, 4f, 7f, 8f } } });

            sensor.Update();
            wrapped.CurrentObservation = new[, ,] { { { 9f, 10f } }, { { 11f, 12f } } };
            SensorTestHelper.CompareObservation(sensor, new[, ,] { { { 5f, 6f, 9f, 10f } }, { { 7f, 8f, 11f, 12f } } });

            // Reading again without calling Update() must not shift the stack.
            SensorTestHelper.CompareObservation(sensor, new[, ,] { { { 5f, 6f, 9f, 10f } }, { { 7f, 8f, 11f, 12f } } });

            // Reset() discards the history; only the newest frame survives, behind zero padding.
            sensor.Reset();
            wrapped.CurrentObservation = new[, ,] { { { 13f, 14f } }, { { 15f, 16f } } };
            SensorTestHelper.CompareObservation(sensor, new[, ,] { { { 0f, 0f, 13f, 14f } }, { { 0f, 0f, 15f, 16f } } });
        }

        [Test]
        public void TestStackedGetCompressedObservation()
        {
            var wrapped = new Dummy3DSensor();
            wrapped.ObservationSpec = ObservationSpec.Visual(1, 1, 3);
            var sensor = new StackingSensor(wrapped, 2);

            // The still-empty older slot is expected to be encoded as sensor.CreateEmptyPNG().
            wrapped.CurrentObservation = new[, ,] { { { 1f, 2f, 3f } } };
            var expected1 = sensor.CreateEmptyPNG().Concat(ToBytes(1f, 2f, 3f)).ToArray();
            Assert.AreEqual(expected1, sensor.GetCompressedObservation());

            sensor.Update();
            wrapped.CurrentObservation = new[, ,] { { { 4f, 5f, 6f } } };
            Assert.AreEqual(ToBytes(1f, 2f, 3f, 4f, 5f, 6f), sensor.GetCompressedObservation());

            sensor.Update();
            wrapped.CurrentObservation = new[, ,] { { { 7f, 8f, 9f } } };
            Assert.AreEqual(ToBytes(4f, 5f, 6f, 7f, 8f, 9f), sensor.GetCompressedObservation());

            // After Reset() the older slot is empty again.
            sensor.Reset();
            wrapped.CurrentObservation = new[, ,] { { { 10f, 11f, 12f } } };
            var expected4 = sensor.CreateEmptyPNG().Concat(ToBytes(10f, 11f, 12f)).ToArray();
            Assert.AreEqual(expected4, sensor.GetCompressedObservation());
        }

        [Test]
        public void TestStackingSensorBuiltInSensorType()
        {
            // A wrapped sensor with no built-in type is reported as Unknown.
            var dummySensor = new Dummy3DSensor();
            dummySensor.ObservationSpec = ObservationSpec.Visual(2, 2, 4);
            dummySensor.Mapping = new[] { 0, 1, 2, 3 };
            var stackedDummySensor = new StackingSensor(dummySensor, 2);
            Assert.AreEqual(BuiltInSensorType.Unknown, stackedDummySensor.GetBuiltInSensorType());

            // A wrapped VectorSensor's type is reported through the stacking sensor.
            var vectorSensor = new VectorSensor(4);
            var stackedVectorSensor = new StackingSensor(vectorSensor, 4);
            Assert.AreEqual(BuiltInSensorType.VectorSensor, stackedVectorSensor.GetBuiltInSensorType());
        }
    }
}
|
|