using System;
using System.Collections.Generic;
using Unity.Barracuda;
using Unity.MLAgents.Inference.Utils;

namespace Unity.MLAgents.Inference
{
    /// <summary>
    /// TensorProxy - A class that encapsulates a Tensor used for inference.
    ///
    /// It holds the data array, the shape, the element type, and the name of the
    /// placeholder in the execution graph. All fields except the data are editable
    /// in the Inspector, so the user can configure everything else graphically.
    /// </summary>
    [Serializable]
    internal class TensorProxy
    {
        public enum TensorType
        {
            Integer,
            FloatingPoint
        };

        static readonly Dictionary<TensorType, Type> k_TypeMap =
            new Dictionary<TensorType, Type>()
        {
            {TensorType.FloatingPoint, typeof(float)},
            {TensorType.Integer, typeof(int)}
        };

        public string name;
        public TensorType valueType;

        // Since System.Type is not serializable, the Inspector edits the TensorType
        // enum above; DataType maps it back to the corresponding System.Type.
        public Type DataType => k_TypeMap[valueType];
        public long[] shape;
        public Tensor data;

        // For rank-4 shapes the layout is (N, H, W, C); for Barracuda's rank-8
        // shapes it is (S, R, N, T, D, H, W, C), so H, W and C sit at indices
        // 5, 6 and 7.
        public long Height
        {
            get { return shape.Length == 4 ? shape[1] : shape[5]; }
        }

        public long Width
        {
            get { return shape.Length == 4 ? shape[2] : shape[6]; }
        }

        public long Channels
        {
            get { return shape.Length == 4 ? shape[3] : shape[7]; }
        }
    }
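
    // A minimal sketch of how a TensorProxy for a visual observation might be
    // set up; the name and dimensions below are illustrative, not taken from a
    // real model:
    //
    //     var proxy = new TensorProxy
    //     {
    //         name = "visual_observation_0",
    //         valueType = TensorProxy.TensorType.FloatingPoint,
    //         shape = new long[] { 1, 84, 84, 3 }  // (batch, height, width, channels)
    //     };
    //     // proxy.Height == 84, proxy.Width == 84, proxy.Channels == 3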

    internal static class TensorUtils
    {
        /// <summary>
        /// Resize a TensorProxy to a new batch size, reallocating its underlying
        /// Barracuda Tensor unless the existing one already matches. Rank-4 and
        /// rank-8 shapes keep their spatial dimensions; any other rank is
        /// flattened to (batch, features).
        /// </summary>
        public static void ResizeTensor(TensorProxy tensor, int batch, ITensorAllocator allocator)
        {
            if (tensor.shape[0] == batch &&
                tensor.data != null && tensor.data.batch == batch)
            {
                return;
            }

            tensor.data?.Dispose();
            tensor.shape[0] = batch;

            if (tensor.shape.Length == 4 || tensor.shape.Length == 8)
            {
                tensor.data = allocator.Alloc(
                    new TensorShape(
                        batch,
                        (int)tensor.Height,
                        (int)tensor.Width,
                        (int)tensor.Channels));
            }
            else
            {
                tensor.data = allocator.Alloc(
                    new TensorShape(
                        batch,
                        (int)tensor.shape[tensor.shape.Length - 1]));
            }
        }
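
        // A minimal usage sketch, assuming a Barracuda TensorCachingAllocator is
        // on hand (any ITensorAllocator works; the names below are illustrative):
        //
        //     var allocator = new TensorCachingAllocator();
        //     var proxy = new TensorProxy
        //     {
        //         name = "vector_observation",
        //         valueType = TensorProxy.TensorType.FloatingPoint,
        //         shape = new long[] { -1, 8 }
        //     };
        //     TensorUtils.ResizeTensor(proxy, 4, allocator);  // allocates a (4, 8) Tensor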

        /// <summary>
        /// Convert a Barracuda TensorShape to the long[] convention used by
        /// TensorProxy: 2D (batch, channels) when height and width are both 1,
        /// otherwise 4D (batch, height, width, channels).
        /// </summary>
        internal static long[] TensorShapeFromBarracuda(TensorShape src)
        {
            if (src.height == 1 && src.width == 1)
            {
                return new long[] { src.batch, src.channels };
            }

            return new long[] { src.batch, src.height, src.width, src.channels };
        }
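
        // For example: a Barracuda shape (batch: 2, height: 1, width: 1,
        // channels: 10) becomes { 2, 10 }, while (2, 84, 84, 3) stays
        // { 2, 84, 84, 3 }.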

        /// <summary>
        /// Wrap a Barracuda Tensor in a float-valued TensorProxy, optionally
        /// overriding its name.
        /// </summary>
        public static TensorProxy TensorProxyFromBarracuda(Tensor src, string nameOverride = null)
        {
            var shape = TensorShapeFromBarracuda(src.shape);
            return new TensorProxy
            {
                name = nameOverride ?? src.name,
                valueType = TensorProxy.TensorType.FloatingPoint,
                shape = shape,
                data = src
            };
        }

        /// <summary>
        /// Fill a specific batch of a TensorProxy with a given value
        /// </summary>
        /// <param name="tensorProxy"></param>
        /// <param name="batch">The batch index to fill.</param>
        /// <param name="fillValue"></param>
        public static void FillTensorBatch(TensorProxy tensorProxy, int batch, float fillValue)
        {
            var height = tensorProxy.data.height;
            var width = tensorProxy.data.width;
            var channels = tensorProxy.data.channels;
            for (var h = 0; h < height; h++)
            {
                for (var w = 0; w < width; w++)
                {
                    for (var c = 0; c < channels; c++)
                    {
                        tensorProxy.data[batch, h, w, c] = fillValue;
                    }
                }
            }
        }
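
        // For instance, FillTensorBatch(proxy, 3, 0f) zeroes every element of
        // batch entry 3, e.g. to blank out a padding slot in a partially used
        // batch.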

        /// <summary>
        /// Fill a pre-allocated Tensor with random numbers
        /// </summary>
        /// <param name="tensorProxy">The pre-allocated Tensor to fill</param>
        /// <param name="randomNormal">RandomNormal object used to populate tensor</param>
        /// <exception cref="NotImplementedException">
        /// Throws when trying to fill a Tensor of type other than float
        /// </exception>
        /// <exception cref="ArgumentNullException">
        /// Throws when the Tensor is not allocated
        /// </exception>
        public static void FillTensorWithRandomNormal(
            TensorProxy tensorProxy, RandomNormal randomNormal)
        {
            if (tensorProxy.DataType != typeof(float))
            {
                throw new NotImplementedException("Only float data types are currently supported");
            }

            if (tensorProxy.data == null)
            {
                throw new ArgumentNullException(
                    nameof(tensorProxy), "Tensor data must be allocated before it can be filled.");
            }

            for (var i = 0; i < tensorProxy.data.length; i++)
            {
                tensorProxy.data[i] = (float)randomNormal.NextDouble();
            }
        }
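
        // A minimal sketch (the RandomNormal constructor arguments here are an
        // assumption about Unity.MLAgents.Inference.Utils.RandomNormal: a seed
        // plus optional mean and standard deviation):
        //
        //     var randomNormal = new RandomNormal(seed: 1337);
        //     TensorUtils.FillTensorWithRandomNormal(proxy, randomNormal);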
    }
}