using System.Collections.Generic;
using Unity.Barracuda;
using Unity.MLAgents.Sensors;
namespace Unity.MLAgents.Inference
{
/// <summary>
/// Mapping between Tensor names and generators.
/// A TensorGenerator maintains a Dictionary of strings (node names) to IGenerator instances.
/// Each generator takes as arguments the tensor, the current batch size, and the list of
/// AgentInfo/sensor pairs corresponding to the current batch.
/// Each Generator reshapes and fills the data of the tensor based on the data of the batch.
/// When the TensorProxy is an Input to the model, the shape of the Tensor will be modified
/// depending on the current batch size and the data of the Tensor will be filled using the
/// list of AgentInfo/sensor pairs.
/// When the TensorProxy is an Output of the model, only the shape of the Tensor will be
/// modified using the current batch size. The data will be pre-filled with zeros.
/// </summary>
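/// <remarks>
/// A minimal usage sketch, for illustration only. The variables seed, memories, sensors,
/// inferenceInputs and infos are assumed to be supplied by the caller (the inference
/// pipeline that owns the Barracuda model); they are not defined in this file.
/// <code>
/// var allocator = new TensorCachingAllocator();
/// var generator = new TensorGenerator(seed, allocator, memories, barracudaModel);
/// generator.InitializeObservations(sensors, allocator);
/// // Each inference step: resize and fill the model inputs for the current batch.
/// generator.GenerateTensors(inferenceInputs, infos.Count, infos);
/// </code>
/// </remarks>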
internal class TensorGenerator
{
public interface IGenerator
{
/// <summary>
/// Modifies the data inside a Tensor according to the information contained in the
/// AgentInfos contained in the current batch.
/// </summary>
/// <param name="tensorProxy"> The tensor the data and shape will be modified.</param>
/// <param name="batchSize"> The number of agents present in the current batch.</param>
/// <param name="infos">
/// List of AgentInfos containing the information that will be used to populate
/// the tensor's data.
/// </param>
void Generate(
TensorProxy tensorProxy, int batchSize, IList<AgentInfoSensorsPair> infos);
}
readonly Dictionary<string, IGenerator> m_Dict = new Dictionary<string, IGenerator>();
int m_ApiVersion;
/// <summary>
/// Constructs a new TensorGenerator object.
/// </summary>
/// <param name="seed"> The seed the Generators will be initialized with.</param>
/// <param name="allocator"> Tensor allocator.</param>
/// <param name="memories">Dictionary of AgentInfo.id to memory for use in the inference model.</param>
/// <param name="barracudaModel"></param>
/// <param name="deterministicInference"> Inference only: set to true if the action selection from model should be
/// deterministic. </param>
public TensorGenerator(
int seed,
ITensorAllocator allocator,
Dictionary<int, List<float>> memories,
object barracudaModel = null,
bool deterministicInference = false)
{
// If the model is null, there is no inference to run and an exception is thrown before reaching here.
if (barracudaModel == null)
{
return;
}
var model = (Model)barracudaModel;
m_ApiVersion = model.GetVersion();
// Generator for Inputs
m_Dict[TensorNames.BatchSizePlaceholder] =
new BatchSizeGenerator(allocator);
m_Dict[TensorNames.SequenceLengthPlaceholder] =
new SequenceLengthGenerator(allocator);
m_Dict[TensorNames.RecurrentInPlaceholder] =
new RecurrentInputGenerator(allocator, memories);
m_Dict[TensorNames.PreviousActionPlaceholder] =
new PreviousActionInputGenerator(allocator);
m_Dict[TensorNames.ActionMaskPlaceholder] =
new ActionMaskInputGenerator(allocator);
m_Dict[TensorNames.RandomNormalEpsilonPlaceholder] =
new RandomNormalInputGenerator(seed, allocator);
// Generators for Outputs
if (model.HasContinuousOutputs(deterministicInference))
{
m_Dict[model.ContinuousOutputName(deterministicInference)] = new BiDimensionalOutputGenerator(allocator);
}
if (model.HasDiscreteOutputs(deterministicInference))
{
m_Dict[model.DiscreteOutputName(deterministicInference)] = new BiDimensionalOutputGenerator(allocator);
}
m_Dict[TensorNames.RecurrentOutput] = new BiDimensionalOutputGenerator(allocator);
m_Dict[TensorNames.ValueEstimateOutput] = new BiDimensionalOutputGenerator(allocator);
}
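/// <summary>
/// Creates the observation input generators for the sensors of a representative agent,
/// keyed by the observation input name expected by the loaded model's API version.
/// </summary>
/// <param name="sensors"> List of sensors on a representative agent.</param>
/// <param name="allocator"> Tensor allocator.</param>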
public void InitializeObservations(List<ISensor> sensors, ITensorAllocator allocator)
{
if (m_ApiVersion == (int)BarracudaModelParamLoader.ModelApiVersion.MLAgents1_0)
{
// Loop through the sensors on a representative agent.
// All vector observations use a shared ObservationGenerator since they are concatenated.
// All other observations use a unique ObservationInputGenerator
var visIndex = 0;
ObservationGenerator vecObsGen = null;
for (var sensorIndex = 0; sensorIndex < sensors.Count; sensorIndex++)
{
var sensor = sensors[sensorIndex];
var rank = sensor.GetObservationSpec().Rank;
ObservationGenerator obsGen = null;
string obsGenName = null;
switch (rank)
{
case 1:
if (vecObsGen == null)
{
vecObsGen = new ObservationGenerator(allocator);
}
obsGen = vecObsGen;
obsGenName = TensorNames.VectorObservationPlaceholder;
break;
case 2:
// If the tensor is of rank 2, we use the index of the sensor
// to create the name
obsGen = new ObservationGenerator(allocator);
obsGenName = TensorNames.GetObservationName(sensorIndex);
break;
case 3:
// If the tensor is of rank 3, we use the "visual observation
// index", which only counts the rank 3 sensors
obsGen = new ObservationGenerator(allocator);
obsGenName = TensorNames.GetVisualObservationName(visIndex);
visIndex++;
break;
default:
throw new UnityAgentsException(
$"Sensor {sensor.GetName()} have an invalid rank {rank}");
}
obsGen.AddSensorIndex(sensorIndex);
m_Dict[obsGenName] = obsGen;
}
}
if (m_ApiVersion == (int)BarracudaModelParamLoader.ModelApiVersion.MLAgents2_0)
{
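// In 2.0 models every sensor, including rank 1 (vector) sensors, has its own named
// observation input, so each one gets a dedicated ObservationGenerator.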
for (var sensorIndex = 0; sensorIndex < sensors.Count; sensorIndex++)
{
var obsGen = new ObservationGenerator(allocator);
var obsGenName = TensorNames.GetObservationName(sensorIndex);
obsGen.AddSensorIndex(sensorIndex);
m_Dict[obsGenName] = obsGen;
}
}
}
/// <summary>
/// Populates the data of the tensor inputs given the data contained in the current batch
/// of agents.
/// </summary>
/// <param name="tensors"> Enumerable of tensors that will be modified.</param>
/// <param name="currentBatchSize"> The number of agents present in the current batch
/// </param>
/// <param name="infos"> List of AgentsInfos and Sensors that contains the
/// data that will be used to modify the tensors</param>
/// <exception cref="UnityAgentsException"> One of the tensor does not have an
/// associated generator.</exception>
public void GenerateTensors(
IReadOnlyList<TensorProxy> tensors, int currentBatchSize, IList<AgentInfoSensorsPair> infos)
{
for (var tensorIndex = 0; tensorIndex < tensors.Count; tensorIndex++)
{
var tensor = tensors[tensorIndex];
if (!m_Dict.ContainsKey(tensor.name))
{
throw new UnityAgentsException(
$"Unknown tensorProxy expected as input : {tensor.name}");
}
m_Dict[tensor.name].Generate(tensor, currentBatchSize, infos);
}
}
}
}