// Source provenance: commit 05c9ac2 ("Second Push") by AnnaMats.
using NUnit.Framework;
using Unity.MLAgents.Sensors;
namespace Unity.MLAgents.Tests
{
[TestFixture]
public class ObservationSpecTests
{
    /// <summary>
    /// Verifies every observable facet of an <see cref="ObservationSpec"/>:
    /// its rank, each entry of its shape, each per-dimension property,
    /// and that its observation type is <see cref="ObservationType.Default"/>.
    /// </summary>
    /// <param name="spec">The spec under test.</param>
    /// <param name="expectedShape">Expected size of each dimension; its length is the expected rank.</param>
    /// <param name="expectedDimensionProperties">Expected property for each dimension, in order.</param>
    static void AssertSpecMatches(
        ObservationSpec spec,
        int[] expectedShape,
        DimensionProperty[] expectedDimensionProperties)
    {
        var expectedRank = expectedShape.Length;
        Assert.AreEqual(expectedRank, spec.Rank);

        var actualShape = spec.Shape;
        Assert.AreEqual(expectedRank, actualShape.Length);
        for (var dim = 0; dim < expectedRank; dim++)
        {
            Assert.AreEqual(expectedShape[dim], actualShape[dim]);
        }

        var actualProps = spec.DimensionProperties;
        Assert.AreEqual(expectedDimensionProperties.Length, actualProps.Length);
        for (var dim = 0; dim < expectedDimensionProperties.Length; dim++)
        {
            Assert.AreEqual(expectedDimensionProperties[dim], actualProps[dim]);
        }

        Assert.AreEqual(ObservationType.Default, spec.ObservationType);
    }

    /// <summary>A vector spec is rank-1 with no special dimension property.</summary>
    [Test]
    public void TestVectorObsSpec()
    {
        AssertSpecMatches(
            ObservationSpec.Vector(5),
            new[] { 5 },
            new[] { DimensionProperty.None });
    }

    /// <summary>A variable-length spec marks only its first dimension as variable-size.</summary>
    [Test]
    public void TestVariableLengthObsSpec()
    {
        AssertSpecMatches(
            ObservationSpec.VariableLength(5, 6),
            new[] { 5, 6 },
            new[] { DimensionProperty.VariableSize, DimensionProperty.None });
    }

    /// <summary>A visual spec marks its first two dimensions as translationally equivariant.</summary>
    [Test]
    public void TestVisualObsSpec()
    {
        AssertSpecMatches(
            ObservationSpec.Visual(5, 6, 7),
            new[] { 5, 6, 7 },
            new[]
            {
                DimensionProperty.TranslationalEquivariance,
                DimensionProperty.TranslationalEquivariance,
                DimensionProperty.None,
            });
    }

    /// <summary>
    /// Constructing a spec whose shape and dimension-property arrays differ in length
    /// must be rejected with a <see cref="UnityAgentsException"/>.
    /// </summary>
    [Test]
    public void TestMismatchShapeDimensionPropThrows()
    {
        // Two dimensions in the shape, but only one dimension property.
        var twoDimShape = new InplaceArray<int>(1, 2);
        var singleProp = new InplaceArray<DimensionProperty>(DimensionProperty.TranslationalEquivariance);
        TestDelegate construct = () => new ObservationSpec(twoDimShape, singleProp);
        Assert.Throws<UnityAgentsException>(construct);
    }
}
}