File size: 5,287 Bytes
05c9ac2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 |
using System;
using System.Collections.Generic;
using Unity.Barracuda;
using Unity.MLAgents.Inference.Utils;
namespace Unity.MLAgents.Inference
{
/// <summary>
/// Tensor - A class to encapsulate a Tensor used for inference.
///
/// This class contains the Array that holds the data array, the shapes, type and the
/// placeholder in the execution graph. All the fields are editable in the inspector,
/// allowing the user to specify everything but the data in a graphical way.
/// </summary>
[Serializable]
internal class TensorProxy
{
public enum TensorType
{
Integer,
FloatingPoint
};
static readonly Dictionary<TensorType, Type> k_TypeMap =
new Dictionary<TensorType, Type>()
{
{TensorType.FloatingPoint, typeof(float)},
{TensorType.Integer, typeof(int)}
};
public string name;
public TensorType valueType;
// Since Type is not serializable, we use the DisplayType for the Inspector
public Type DataType => k_TypeMap[valueType];
public long[] shape;
public Tensor data;
public long Height
{
get { return shape.Length == 4 ? shape[1] : shape[5]; }
}
public long Width
{
get { return shape.Length == 4 ? shape[2] : shape[6]; }
}
public long Channels
{
get { return shape.Length == 4 ? shape[3] : shape[7]; }
}
}
internal static class TensorUtils
{
public static void ResizeTensor(TensorProxy tensor, int batch, ITensorAllocator allocator)
{
if (tensor.shape[0] == batch &&
tensor.data != null && tensor.data.batch == batch)
{
return;
}
tensor.data?.Dispose();
tensor.shape[0] = batch;
if (tensor.shape.Length == 4 || tensor.shape.Length == 8)
{
tensor.data = allocator.Alloc(
new TensorShape(
batch,
(int)tensor.Height,
(int)tensor.Width,
(int)tensor.Channels));
}
else
{
tensor.data = allocator.Alloc(
new TensorShape(
batch,
(int)tensor.shape[tensor.shape.Length - 1]));
}
}
internal static long[] TensorShapeFromBarracuda(TensorShape src)
{
if (src.height == 1 && src.width == 1)
{
return new long[] { src.batch, src.channels };
}
return new long[] { src.batch, src.height, src.width, src.channels };
}
public static TensorProxy TensorProxyFromBarracuda(Tensor src, string nameOverride = null)
{
var shape = TensorShapeFromBarracuda(src.shape);
return new TensorProxy
{
name = nameOverride ?? src.name,
valueType = TensorProxy.TensorType.FloatingPoint,
shape = shape,
data = src
};
}
/// <summary>
/// Fill a specific batch of a TensorProxy with a given value
/// </summary>
/// <param name="tensorProxy"></param>
/// <param name="batch">The batch index to fill.</param>
/// <param name="fillValue"></param>
public static void FillTensorBatch(TensorProxy tensorProxy, int batch, float fillValue)
{
var height = tensorProxy.data.height;
var width = tensorProxy.data.width;
var channels = tensorProxy.data.channels;
for (var h = 0; h < height; h++)
{
for (var w = 0; w < width; w++)
{
for (var c = 0; c < channels; c++)
{
tensorProxy.data[batch, h, w, c] = fillValue;
}
}
}
}
/// <summary>
/// Fill a pre-allocated Tensor with random numbers
/// </summary>
/// <param name="tensorProxy">The pre-allocated Tensor to fill</param>
/// <param name="randomNormal">RandomNormal object used to populate tensor</param>
/// <exception cref="NotImplementedException">
/// Throws when trying to fill a Tensor of type other than float
/// </exception>
/// <exception cref="ArgumentNullException">
/// Throws when the Tensor is not allocated
/// </exception>
public static void FillTensorWithRandomNormal(
TensorProxy tensorProxy, RandomNormal randomNormal)
{
if (tensorProxy.DataType != typeof(float))
{
throw new NotImplementedException("Only float data types are currently supported");
}
if (tensorProxy.data == null)
{
throw new ArgumentNullException();
}
for (var i = 0; i < tensorProxy.data.length; i++)
{
tensorProxy.data[i] = (float)randomNormal.NextDouble();
}
}
}
}
|