File size: 4,095 Bytes
05c9ac2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 |
using System;
using NUnit.Framework;
using Unity.Barracuda;
using Unity.MLAgents.Inference;
using Unity.MLAgents.Inference.Utils;
namespace Unity.MLAgents.Tests
{
/// <summary>
/// Unit tests for <see cref="TensorUtils"/>: resizing a <see cref="TensorProxy"/>
/// and filling a proxy's data with samples from <see cref="RandomNormal"/>.
/// </summary>
public class TensorUtilsTest
{
    /// <summary>
    /// Builds an integer proxy of shape {1, ..., height, width, channels} with the
    /// given rank, resizes it to batch size 42, and checks that the trailing
    /// height/width/channels entries are preserved both in
    /// <see cref="TensorProxy.shape"/> and in the shape of the backing tensor data.
    /// </summary>
    /// <param name="dimension">Rank of the shape under test (4 or 8 in the cases below).</param>
    [TestCase(4, TestName = "TestResizeTensor_4D")]
    [TestCase(8, TestName = "TestResizeTensor_8D")]
    public void TestResizeTensor(int dimension)
    {
        const int height = 64;
        const int width = 84;
        const int channels = 3;

        // Shape is {1, ..., height, width, channels}: every entry before the last
        // three is 1 (for 4D that is just the leading entry).
        var shape = new long[dimension];
        for (var i = 0; i < dimension; i++)
        {
            shape[i] = 1;
        }
        shape[dimension - 3] = height;
        shape[dimension - 2] = width;
        shape[dimension - 1] = channels;
        var intShape = Array.ConvertAll(shape, s => (int)s);

        var alloc = new TensorCachingAllocator();
        try
        {
            var tensorProxy = new TensorProxy
            {
                valueType = TensorProxy.TensorType.Integer,
                data = new Tensor(intShape),
                shape = shape,
            };

            // Sanity-check the starting data shape; these are the values that
            // must still hold after the resize below.
            Assert.AreEqual(height, tensorProxy.data.shape.height);
            Assert.AreEqual(width, tensorProxy.data.shape.width);
            Assert.AreEqual(channels, tensorProxy.data.shape.channels);

            TensorUtils.ResizeTensor(tensorProxy, 42, alloc);

            Assert.AreEqual(height, tensorProxy.shape[dimension - 3]);
            Assert.AreEqual(width, tensorProxy.shape[dimension - 2]);
            Assert.AreEqual(channels, tensorProxy.shape[dimension - 1]);
            Assert.AreEqual(height, tensorProxy.data.shape.height);
            Assert.AreEqual(width, tensorProxy.data.shape.width);
            Assert.AreEqual(channels, tensorProxy.data.shape.channels);
        }
        finally
        {
            // Release the allocator even when an assertion above fails, so a
            // failing case does not leak tensors into later tests.
            alloc.Dispose();
        }
    }

    /// <summary>
    /// An integer-typed proxy must be rejected with <see cref="NotImplementedException"/>.
    /// The proxy's data is also left null here, so this also checks that the
    /// type check is reported rather than a missing-data error.
    /// </summary>
    [Test]
    public void RandomNormalTestTensorInt()
    {
        var rn = new RandomNormal(1982);
        var t = new TensorProxy
        {
            valueType = TensorProxy.TensorType.Integer
        };

        Assert.Throws<NotImplementedException>(
            () => TensorUtils.FillTensorWithRandomNormal(t, rn));
    }

    /// <summary>
    /// A floating-point proxy with no data must be rejected with
    /// <see cref="ArgumentNullException"/>.
    /// </summary>
    [Test]
    public void RandomNormalTestDataNull()
    {
        var rn = new RandomNormal(1982);
        var t = new TensorProxy
        {
            valueType = TensorProxy.TensorType.FloatingPoint
        };

        Assert.Throws<ArgumentNullException>(
            () => TensorUtils.FillTensorWithRandomNormal(t, rn));
    }

    /// <summary>
    /// Fills a 1x3x4x2 floating-point tensor from a <see cref="RandomNormal"/>
    /// seeded with 1982 and compares every element, in index order, against
    /// recorded reference values.
    /// </summary>
    [Test]
    public void RandomNormalTestTensor()
    {
        const double tolerance = 0.0001;
        var rn = new RandomNormal(1982);
        var data = new Tensor(1, 3, 4, 2);
        try
        {
            var t = new TensorProxy
            {
                valueType = TensorProxy.TensorType.FloatingPoint,
                data = data
            };

            TensorUtils.FillTensorWithRandomNormal(t, rn);

            // Expected values for seed 1982, indexed like t.data.
            var reference = new[]
            {
                -0.4315872f,
                -1.11074f,
                0.3414804f,
                -1.130287f,
                0.1413168f,
                -0.5105762f,
                -0.3027347f,
                -0.2645015f,
                1.225356f,
                -0.02921959f,
                0.3716498f,
                -1.092338f,
                0.9561074f,
                -0.5018106f,
                1.167787f,
                -0.7763879f,
                -0.07491868f,
                0.5396146f,
                -0.1377991f,
                0.3331701f,
                0.06144788f,
                0.9520947f,
                1.088157f,
                -1.177194f,
            };

            // Report a size mismatch as an assertion failure instead of an
            // IndexOutOfRangeException from the loop below.
            Assert.AreEqual(reference.Length, t.data.length);
            for (var i = 0; i < t.data.length; i++)
            {
                // NUnit's overload is (expected, actual, delta).
                Assert.AreEqual(reference[i], t.data[i], tolerance);
            }
        }
        finally
        {
            // The test created this tensor, so the test releases it.
            data.Dispose();
        }
    }
}
}
|