File size: 4,370 Bytes
05c9ac2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 |
using NUnit.Framework;
using UnityEngine;
using Unity.MLAgents.Sensors;
namespace Unity.MLAgents.Tests
{
/// <summary>
/// Unit tests for <see cref="VectorSensor"/>: sensor naming, the reported observation type,
/// the AddObservation overloads exercised below, and how the sensor behaves when it is given
/// more or fewer values than its declared size.
/// </summary>
public class VectorSensorTests
{
    /// <summary>
    /// A sensor created without a name reports "VectorSensor_size4" for size 4,
    /// and a sensor created with an explicit name reports that name.
    /// </summary>
    [Test]
    public void TestCtor()
    {
        ISensor sensor = new VectorSensor(4);
        Assert.AreEqual("VectorSensor_size4", sensor.GetName());

        sensor = new VectorSensor(3, "test_sensor");
        Assert.AreEqual("test_sensor", sensor.GetName());
    }

    /// <summary>
    /// Observations added one at a time are produced in insertion order, are still produced
    /// on a second read, and read back as zeros once Update() has been called.
    /// </summary>
    [Test]
    public void TestWrite()
    {
        var sensor = new VectorSensor(4);
        sensor.AddObservation(1f);
        sensor.AddObservation(2f);
        sensor.AddObservation(3f);
        sensor.AddObservation(4f);

        SensorTestHelper.CompareObservation(sensor, new[] { 1f, 2f, 3f, 4f });
        // Check that if we don't call Update(), the same observations are produced
        SensorTestHelper.CompareObservation(sensor, new[] { 1f, 2f, 3f, 4f });

        // Check that Update() clears the data
        sensor.Update();
        SensorTestHelper.CompareObservation(sensor, new[] { 0f, 0f, 0f, 0f });
    }

    /// <summary>
    /// A single float observation is produced unchanged.
    /// </summary>
    [Test]
    public void TestAddObservationFloat()
    {
        var sensor = new VectorSensor(1);
        sensor.AddObservation(1.2f);

        SensorTestHelper.CompareObservation(sensor, new[] { 1.2f });
    }

    /// <summary>
    /// The observation spec reports <c>ObservationType.Default</c> when no type (or
    /// <c>Default</c>) is passed to the constructor, and <c>GoalSignal</c> when that is requested.
    /// </summary>
    [Test]
    public void TestObservationType()
    {
        // NUnit's signature is AreEqual(expected, actual); the expected value goes first so a
        // failure message reads correctly. Comparing the enum values directly (no int casts)
        // lets the failure output show the member names.
        var sensor = new VectorSensor(1);
        var spec = sensor.GetObservationSpec();
        Assert.AreEqual(ObservationType.Default, spec.ObservationType);

        sensor = new VectorSensor(1, observationType: ObservationType.Default);
        spec = sensor.GetObservationSpec();
        Assert.AreEqual(ObservationType.Default, spec.ObservationType);

        sensor = new VectorSensor(1, observationType: ObservationType.GoalSignal);
        spec = sensor.GetObservationSpec();
        Assert.AreEqual(ObservationType.GoalSignal, spec.ObservationType);
    }

    /// <summary>
    /// An int observation is produced as the equivalent float (42 becomes 42f).
    /// </summary>
    [Test]
    public void TestAddObservationInt()
    {
        var sensor = new VectorSensor(1);
        sensor.AddObservation(42);

        SensorTestHelper.CompareObservation(sensor, new[] { 42f });
    }

    /// <summary>
    /// Vector3 and Vector2 observations are produced as three and two floats respectively,
    /// in the order the components were given to the vector constructor.
    /// </summary>
    [Test]
    public void TestAddObservationVec()
    {
        var sensor = new VectorSensor(3);
        sensor.AddObservation(new Vector3(1, 2, 3));
        SensorTestHelper.CompareObservation(sensor, new[] { 1f, 2f, 3f });

        sensor = new VectorSensor(2);
        sensor.AddObservation(new Vector2(4, 5));
        SensorTestHelper.CompareObservation(sensor, new[] { 4f, 5f });
    }

    /// <summary>
    /// The identity quaternion is produced as the four floats (0, 0, 0, 1).
    /// </summary>
    [Test]
    public void TestAddObservationQuaternion()
    {
        var sensor = new VectorSensor(4);
        sensor.AddObservation(Quaternion.identity);

        SensorTestHelper.CompareObservation(sensor, new[] { 0f, 0f, 0f, 1f });
    }

    /// <summary>
    /// A float array added in one call is produced element by element, in order.
    /// </summary>
    [Test]
    public void TestWriteEnumerable()
    {
        var sensor = new VectorSensor(4);
        sensor.AddObservation(new[] { 1f, 2f, 3f, 4f });

        SensorTestHelper.CompareObservation(sensor, new[] { 1f, 2f, 3f, 4f });
    }

    /// <summary>
    /// A <c>true</c> bool observation is produced as 1f.
    /// </summary>
    [Test]
    public void TestAddObservationBool()
    {
        var sensor = new VectorSensor(1);
        sensor.AddObservation(true);

        SensorTestHelper.CompareObservation(sensor, new[] { 1f });
    }

    /// <summary>
    /// A one-hot observation of value 2 over a range of 4 is produced as (0, 0, 1, 0).
    /// </summary>
    [Test]
    public void TestAddObservationOneHot()
    {
        var sensor = new VectorSensor(4);
        sensor.AddOneHotObservation(2, 4);

        SensorTestHelper.CompareObservation(sensor, new[] { 0f, 0f, 1f, 0f });
    }

    /// <summary>
    /// When four values are added to a sensor of size 2, only the first two are produced.
    /// </summary>
    [Test]
    public void TestWriteTooMany()
    {
        var sensor = new VectorSensor(2);
        sensor.AddObservation(new[] { 1f, 2f, 3f, 4f });

        SensorTestHelper.CompareObservation(sensor, new[] { 1f, 2f });
    }

    /// <summary>
    /// When two values are added to a sensor of size 4, the remaining slots are produced as zeros.
    /// </summary>
    [Test]
    public void TestWriteNotEnough()
    {
        var sensor = new VectorSensor(4);
        sensor.AddObservation(new[] { 1f, 2f });

        // Make sure extra zeros are added
        SensorTestHelper.CompareObservation(sensor, new[] { 1f, 2f, 0f, 0f });
    }
}
}
|