File size: 4,095 Bytes
05c9ac2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
using System;
using NUnit.Framework;
using Unity.Barracuda;
using Unity.MLAgents.Inference;
using Unity.MLAgents.Inference.Utils;

namespace Unity.MLAgents.Tests
{
    /// <summary>
    /// Unit tests for <see cref="TensorUtils"/>, covering
    /// <c>ResizeTensor</c> and <c>FillTensorWithRandomNormal</c>.
    /// </summary>
    public class TensorUtilsTest
    {
        /// <summary>
        /// Checks that the last three shape entries (height, width, channels) of a
        /// <see cref="TensorProxy"/> are unchanged by <c>TensorUtils.ResizeTensor</c>,
        /// both in the proxy's <c>shape</c> array and in the underlying tensor's shape.
        /// </summary>
        /// <param name="dimension">Rank of the shape array under test (4 or 8).</param>
        [TestCase(4, TestName = "TestResizeTensor_4D")]
        [TestCase(8, TestName = "TestResizeTensor_8D")]
        public void TestResizeTensor(int dimension)
        {
            var alloc = new TensorCachingAllocator();
            try
            {
                var height = 64;
                var width = 84;
                var channels = 3;

                // Set shape to {1, ..., height, width, channels}
                // For 8D, the ... are all 1's
                var shape = new long[dimension];
                for (var i = 0; i < dimension; i++)
                {
                    shape[i] = 1;
                }

                shape[dimension - 3] = height;
                shape[dimension - 2] = width;
                shape[dimension - 1] = channels;

                // The Tensor constructor takes an int[] while TensorProxy.shape is a long[],
                // so keep an int copy of the same shape.
                var intShape = new int[dimension];
                for (var i = 0; i < dimension; i++)
                {
                    intShape[i] = (int)shape[i];
                }

                var tensorProxy = new TensorProxy
                {
                    valueType = TensorProxy.TensorType.Integer,
                    data = new Tensor(intShape),
                    shape = shape,
                };

                // These should be invariant after the resize.
                Assert.AreEqual(height, tensorProxy.data.shape.height);
                Assert.AreEqual(width, tensorProxy.data.shape.width);
                Assert.AreEqual(channels, tensorProxy.data.shape.channels);

                // NOTE(review): 42 is presumably the new batch size - confirm against
                // the TensorUtils.ResizeTensor signature.
                TensorUtils.ResizeTensor(tensorProxy, 42, alloc);

                Assert.AreEqual(height, tensorProxy.shape[dimension - 3]);
                Assert.AreEqual(width, tensorProxy.shape[dimension - 2]);
                Assert.AreEqual(channels, tensorProxy.shape[dimension - 1]);

                Assert.AreEqual(height, tensorProxy.data.shape.height);
                Assert.AreEqual(width, tensorProxy.data.shape.width);
                Assert.AreEqual(channels, tensorProxy.data.shape.channels);
            }
            finally
            {
                // Release the allocator even when one of the assertions above throws;
                // previously a failing assertion skipped the Dispose call.
                alloc.Dispose();
            }
        }

        /// <summary>
        /// <c>FillTensorWithRandomNormal</c> is expected to throw
        /// <see cref="NotImplementedException"/> for a proxy whose value type is Integer.
        /// </summary>
        [Test]
        public void RandomNormalTestTensorInt()
        {
            var rn = new RandomNormal(1982);
            var t = new TensorProxy
            {
                valueType = TensorProxy.TensorType.Integer
            };

            Assert.Throws<NotImplementedException>(
                () => TensorUtils.FillTensorWithRandomNormal(t, rn));
        }

        /// <summary>
        /// <c>FillTensorWithRandomNormal</c> is expected to throw
        /// <see cref="ArgumentNullException"/> for a floating-point proxy
        /// that has no <c>data</c> assigned.
        /// </summary>
        [Test]
        public void RandomNormalTestDataNull()
        {
            var rn = new RandomNormal(1982);
            var t = new TensorProxy
            {
                valueType = TensorProxy.TensorType.FloatingPoint
            };

            Assert.Throws<ArgumentNullException>(
                () => TensorUtils.FillTensorWithRandomNormal(t, rn));
        }

        /// <summary>
        /// Fills a 1x3x4x2 floating-point tensor via <c>FillTensorWithRandomNormal</c>
        /// and compares every element against a fixed table of reference values
        /// (tolerance 0.0001) for a <see cref="RandomNormal"/> constructed with 1982.
        /// </summary>
        [Test]
        public void RandomNormalTestTensor()
        {
            var rn = new RandomNormal(1982);
            var t = new TensorProxy
            {
                valueType = TensorProxy.TensorType.FloatingPoint,
                data = new Tensor(1, 3, 4, 2)
            };

            // Reference values for RandomNormal(1982); presumably 1982 is the seed.
            var reference = new[]
            {
                -0.4315872f,
                -1.11074f,
                0.3414804f,
                -1.130287f,
                0.1413168f,
                -0.5105762f,
                -0.3027347f,
                -0.2645015f,
                1.225356f,
                -0.02921959f,
                0.3716498f,
                -1.092338f,
                0.9561074f,
                -0.5018106f,
                1.167787f,
                -0.7763879f,
                -0.07491868f,
                0.5396146f,
                -0.1377991f,
                0.3331701f,
                0.06144788f,
                0.9520947f,
                1.088157f,
                -1.177194f,
            };

            try
            {
                TensorUtils.FillTensorWithRandomNormal(t, rn);

                // Guard the element-wise loop: without this a shorter tensor would be
                // only partially verified and a longer one would fail with an
                // IndexOutOfRangeException instead of a readable assertion message.
                Assert.AreEqual(reference.Length, t.data.length);

                for (var i = 0; i < t.data.length; i++)
                {
                    // NUnit's argument order is (expected, actual, delta).
                    Assert.AreEqual(reference[i], t.data[i], 0.0001);
                }
            }
            finally
            {
                // The tensor was created by this test, so this test releases it.
                t.data.Dispose();
            }
        }
    }
}