File size: 3,066 Bytes
05c9ac2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 |
using NUnit.Framework;
using Unity.Barracuda;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Inference;
namespace Unity.MLAgents.Tests
{
    /// <summary>
    /// Tests that <see cref="ObservationWriter"/> routes elementwise and list writes to the
    /// expected positions of its target, both for a plain float array and for tensors.
    /// </summary>
    public class ObservationWriterTests
    {
        /// <summary>
        /// Wraps <paramref name="data"/> in a floating-point <see cref="TensorProxy"/>.
        /// The caller keeps ownership of <paramref name="data"/> and must dispose it.
        /// </summary>
        /// <param name="data">Tensor to expose through the proxy.</param>
        /// <returns>A proxy whose <c>data</c> field references <paramref name="data"/>.</returns>
        private static TensorProxy MakeFloatProxy(Tensor data)
        {
            return new TensorProxy
            {
                valueType = TensorProxy.TensorType.FloatingPoint,
                data = data
            };
        }

        /// <summary>
        /// Writing through the writer mutates the target array in place, and the offset
        /// passed to <c>SetTarget</c> shifts every index it writes.
        /// </summary>
        [Test]
        public void TestWritesToIList()
        {
            var writer = new ObservationWriter();
            var buffer = new[] { 0f, 0f, 0f };
            var shape = new InplaceArray<int>(3);

            // Elementwise writes
            writer.SetTarget(buffer, shape, 0);
            writer[0] = 1f;
            writer[2] = 2f;
            Assert.AreEqual(new[] { 1f, 0f, 2f }, buffer);

            // Elementwise writes with offset: writer[0] lands at buffer[1].
            writer.SetTarget(buffer, shape, 1);
            writer[0] = 3f;
            Assert.AreEqual(new[] { 1f, 3f, 2f }, buffer);

            // AddList starts at the beginning of the target.
            writer.SetTarget(buffer, shape, 0);
            writer.AddList(new[] { 4f, 5f });
            Assert.AreEqual(new[] { 4f, 5f, 2f }, buffer);

            // AddList with offset starts at the offset position.
            writer.SetTarget(buffer, shape, 1);
            writer.AddList(new[] { 6f, 7f });
            Assert.AreEqual(new[] { 4f, 6f, 7f }, buffer);
        }

        /// <summary>
        /// For a 2D tensor, the second <c>SetTarget</c> argument selects the row (first index)
        /// and the third is an offset into the last index, as the assertions below check.
        /// </summary>
        [Test]
        public void TestWritesToTensor()
        {
            var writer = new ObservationWriter();

            // Tensors are IDisposable; dispose each one so the test does not leak tensor data.
            using (var data = new Tensor(2, 3))
            {
                var t = MakeFloatProxy(data);
                writer.SetTarget(t, 0, 0);
                Assert.AreEqual(0f, t.data[0, 0]);
                writer[0] = 1f;
                Assert.AreEqual(1f, t.data[0, 0]);

                writer.SetTarget(t, 1, 1);
                writer[0] = 2f;
                writer[1] = 3f;
                // [0, 0] shouldn't change
                Assert.AreEqual(1f, t.data[0, 0]);
                Assert.AreEqual(2f, t.data[1, 1]);
                Assert.AreEqual(3f, t.data[1, 2]);
            }

            // AddList: a fresh tensor, so every position the list should not touch can be checked as 0.
            using (var data = new Tensor(2, 3))
            {
                var t = MakeFloatProxy(data);
                writer.SetTarget(t, 1, 1);
                writer.AddList(new[] { -1f, -2f });
                Assert.AreEqual(0f, t.data[0, 0]);
                Assert.AreEqual(0f, t.data[0, 1]);
                Assert.AreEqual(0f, t.data[0, 2]);
                Assert.AreEqual(0f, t.data[1, 0]);
                Assert.AreEqual(-1f, t.data[1, 1]);
                Assert.AreEqual(-2f, t.data[1, 2]);
            }
        }

        /// <summary>
        /// For a 4D tensor, the writer's three-index indexer maps onto the last three tensor
        /// indices of the selected batch row, with the channel offset added to the last index.
        /// </summary>
        [Test]
        public void TestWritesToTensor3D()
        {
            var writer = new ObservationWriter();

            using (var data = new Tensor(2, 2, 2, 3))
            {
                var t = MakeFloatProxy(data);

                writer.SetTarget(t, 0, 0);
                writer[1, 0, 1] = 1f;
                Assert.AreEqual(1f, t.data[0, 1, 0, 1]);

                // With a channel offset of 1, writer[1, 0, 0] addresses the same cell as above.
                writer.SetTarget(t, 0, 1);
                writer[1, 0, 0] = 2f;
                Assert.AreEqual(2f, t.data[0, 1, 0, 1]);
            }
        }
    }
}
|