|
using NUnit.Framework; |
|
using Unity.Barracuda; |
|
using Unity.MLAgents.Sensors; |
|
using Unity.MLAgents.Inference; |
|
|
|
|
|
namespace Unity.MLAgents.Tests
{
    /// <summary>
    /// Verifies that <see cref="ObservationWriter"/> places values at the expected positions
    /// of its two kinds of target: a flat float buffer and a Barracuda tensor wrapped in a
    /// <see cref="TensorProxy"/>.
    /// </summary>
    public class ObservationWriterTests
    {
        /// <summary>
        /// Writes into a float array through the indexer and through <c>AddList</c>,
        /// with and without a starting offset.
        /// </summary>
        [Test]
        public void TestWritesToIList()
        {
            var writer = new ObservationWriter();
            var buffer = new[] { 0f, 0f, 0f };
            var shape = new InplaceArray<int>(3);

            // Offset 0: writer indices map one-to-one onto buffer indices.
            writer.SetTarget(buffer, shape, 0);
            writer[0] = 1f;
            writer[2] = 2f;
            Assert.AreEqual(new[] { 1f, 0f, 2f }, buffer);

            // Offset 1: writer index 0 lands on buffer index 1; other entries are untouched.
            writer.SetTarget(buffer, shape, 1);
            writer[0] = 3f;
            Assert.AreEqual(new[] { 1f, 3f, 2f }, buffer);

            // AddList at offset 0 overwrites the first two entries only.
            writer.SetTarget(buffer, shape, 0);
            writer.AddList(new[] { 4f, 5f });
            Assert.AreEqual(new[] { 4f, 5f, 2f }, buffer);

            // AddList at offset 1 overwrites the last two entries only.
            writer.SetTarget(buffer, shape, 1);
            writer.AddList(new[] { 6f, 7f });
            Assert.AreEqual(new[] { 4f, 6f, 7f }, buffer);
        }

        /// <summary>
        /// Writes into a 2x3 tensor through the indexer and through <c>AddList</c>, checking
        /// that the second <c>SetTarget</c> argument selects the first tensor index and the
        /// third argument shifts the last tensor index.
        /// </summary>
        [Test]
        public void TestWritesToTensor()
        {
            var writer = new ObservationWriter();

            // Barracuda tensors are IDisposable; the using blocks release them even when an
            // assertion fails part-way through the test.
            using (var indexerData = new Tensor(2, 3))
            {
                var t = new TensorProxy
                {
                    valueType = TensorProxy.TensorType.FloatingPoint,
                    data = indexerData
                };

                writer.SetTarget(t, 0, 0);
                Assert.AreEqual(0f, t.data[0, 0]);
                writer[0] = 1f;
                Assert.AreEqual(1f, t.data[0, 0]);

                writer.SetTarget(t, 1, 1);
                writer[0] = 2f;
                writer[1] = 3f;

                // The value written earlier at [0, 0] must be unaffected by the retargeting.
                Assert.AreEqual(1f, t.data[0, 0]);
                Assert.AreEqual(2f, t.data[1, 1]);
                Assert.AreEqual(3f, t.data[1, 2]);
            }

            // AddList: start from a fresh tensor so every untouched cell can be checked.
            using (var addListData = new Tensor(2, 3))
            {
                var t = new TensorProxy
                {
                    valueType = TensorProxy.TensorType.FloatingPoint,
                    data = addListData
                };

                writer.SetTarget(t, 1, 1);
                writer.AddList(new[] { -1f, -2f });
                Assert.AreEqual(0f, t.data[0, 0]);
                Assert.AreEqual(0f, t.data[0, 1]);
                Assert.AreEqual(0f, t.data[0, 2]);
                Assert.AreEqual(0f, t.data[1, 0]);
                Assert.AreEqual(-1f, t.data[1, 1]);
                Assert.AreEqual(-2f, t.data[1, 2]);
            }
        }

        /// <summary>
        /// Writes into a 2x2x2x3 tensor through the three-index indexer, with and without an
        /// offset on the last index.
        /// </summary>
        [Test]
        public void TestWritesToTensor3D()
        {
            var writer = new ObservationWriter();

            // Dispose the Barracuda tensor when the test ends, pass or fail.
            using (var data = new Tensor(2, 2, 2, 3))
            {
                var t = new TensorProxy
                {
                    valueType = TensorProxy.TensorType.FloatingPoint,
                    data = data
                };

                writer.SetTarget(t, 0, 0);
                writer[1, 0, 1] = 1f;
                Assert.AreEqual(1f, t.data[0, 1, 0, 1]);

                // With offset 1, writer index [1, 0, 0] addresses the same cell as
                // [1, 0, 1] did with offset 0.
                writer.SetTarget(t, 0, 1);
                writer[1, 0, 0] = 2f;
                Assert.AreEqual(2f, t.data[0, 1, 0, 1]);
            }
        }
    }
}
|
|