|
using System.Collections.Generic; |
|
using Unity.Barracuda; |
|
using Unity.MLAgents.Sensors; |
|
|
|
namespace Unity.MLAgents.Inference |
|
{ |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
internal class TensorGenerator
{
    /// <summary>
    /// Maps tensor names to <see cref="IGenerator"/> instances and runs them over the
    /// tensors passed to <see cref="GenerateTensors"/>.
    /// </summary>
    /// <remarks>
    /// The constructor registers generators for a fixed set of placeholder/output names.
    /// Observation generators are registered separately by <see cref="InitializeObservations"/>,
    /// because their tensor names depend on the sensors that are supplied.
    /// </remarks>
    public interface IGenerator
    {
        /// <summary>
        /// Prepares <paramref name="tensorProxy"/> for the current batch.
        /// </summary>
        /// <param name="tensorProxy">The tensor to prepare.</param>
        /// <param name="batchSize">The current batch size.</param>
        /// <param name="infos">Agent info/sensor pairs for the current batch.</param>
        /// <remarks>
        /// How the tensor's shape and contents are derived from <paramref name="batchSize"/>
        /// and <paramref name="infos"/> is decided by each implementation (not visible here).
        /// </remarks>
        void Generate(
            TensorProxy tensorProxy, int batchSize, IList<AgentInfoSensorsPair> infos);
    }

    // Tensor name -> generator responsible for preparing that tensor.
    readonly Dictionary<string, IGenerator> m_Dict = new Dictionary<string, IGenerator>();

    // API version reported by the model; left at its default when no model is supplied.
    readonly int m_ApiVersion;

    /// <summary>
    /// Creates the generator table and registers generators for the model's fixed
    /// input placeholders and outputs.
    /// </summary>
    /// <param name="seed">Seed forwarded to the <see cref="RandomNormalInputGenerator"/>.</param>
    /// <param name="allocator">Allocator passed to every generator created here.</param>
    /// <param name="memories">Forwarded to the <see cref="RecurrentInputGenerator"/>.</param>
    /// <param name="barracudaModel">
    /// The model, passed as <see cref="object"/> and cast to <see cref="Model"/>. When null,
    /// no generators are registered and the API version keeps its default value.
    /// </param>
    /// <param name="deterministicInference">
    /// Forwarded to the model queries that decide which action output names get a generator.
    /// </param>
    /// <exception cref="System.InvalidCastException">
    /// Thrown if <paramref name="barracudaModel"/> is non-null but is not a <see cref="Model"/>.
    /// </exception>
    public TensorGenerator(
        int seed,
        ITensorAllocator allocator,
        Dictionary<int, List<float>> memories,
        object barracudaModel = null,
        bool deterministicInference = false)
    {
        // Without a model there is nothing to describe the expected tensors.
        if (barracudaModel == null)
        {
            return;
        }
        var model = (Model)barracudaModel;

        m_ApiVersion = model.GetVersion();

        // Generators for the fixed input placeholders.
        m_Dict[TensorNames.BatchSizePlaceholder] =
            new BatchSizeGenerator(allocator);
        m_Dict[TensorNames.SequenceLengthPlaceholder] =
            new SequenceLengthGenerator(allocator);
        m_Dict[TensorNames.RecurrentInPlaceholder] =
            new RecurrentInputGenerator(allocator, memories);

        m_Dict[TensorNames.PreviousActionPlaceholder] =
            new PreviousActionInputGenerator(allocator);
        m_Dict[TensorNames.ActionMaskPlaceholder] =
            new ActionMaskInputGenerator(allocator);
        m_Dict[TensorNames.RandomNormalEpsilonPlaceholder] =
            new RandomNormalInputGenerator(seed, allocator);

        // Output tensors. Action output names are model-dependent, so they are only
        // registered when the model reports that it has them.
        if (model.HasContinuousOutputs(deterministicInference))
        {
            m_Dict[model.ContinuousOutputName(deterministicInference)] = new BiDimensionalOutputGenerator(allocator);
        }
        if (model.HasDiscreteOutputs(deterministicInference))
        {
            m_Dict[model.DiscreteOutputName(deterministicInference)] = new BiDimensionalOutputGenerator(allocator);
        }
        m_Dict[TensorNames.RecurrentOutput] = new BiDimensionalOutputGenerator(allocator);
        m_Dict[TensorNames.ValueEstimateOutput] = new BiDimensionalOutputGenerator(allocator);
    }

    /// <summary>
    /// Registers observation generators for <paramref name="sensors"/>. The naming scheme
    /// depends on the model API version.
    /// </summary>
    /// <param name="sensors">The sensors to register, in order; the list index is the sensor index.</param>
    /// <param name="allocator">Allocator passed to each new <see cref="ObservationGenerator"/>.</param>
    /// <remarks>
    /// <list type="bullet">
    /// <item>MLAgents1_0: all rank-1 sensors share one generator registered under
    /// <see cref="TensorNames.VectorObservationPlaceholder"/>. Rank-2 sensors are named by their
    /// sensor index. Rank-3 sensors are named by a counter that only counts rank-3 sensors.</item>
    /// <item>MLAgents2_0: every sensor gets its own generator, named by its sensor index.</item>
    /// </list>
    /// Any other API version registers nothing.
    /// </remarks>
    /// <exception cref="UnityAgentsException">
    /// Thrown for an MLAgents1_0 model when a sensor's rank is not 1, 2 or 3.
    /// </exception>
    public void InitializeObservations(List<ISensor> sensors, ITensorAllocator allocator)
    {
        if (m_ApiVersion == (int)BarracudaModelParamLoader.ModelApiVersion.MLAgents1_0)
        {
            InitializeObservationsMLAgents1(sensors, allocator);
        }
        else if (m_ApiVersion == (int)BarracudaModelParamLoader.ModelApiVersion.MLAgents2_0)
        {
            InitializeObservationsMLAgents2(sensors, allocator);
        }
    }

    // Registers observation generators using the MLAgents1_0 rank-based naming scheme.
    void InitializeObservationsMLAgents1(List<ISensor> sensors, ITensorAllocator allocator)
    {
        // Counts only rank-3 sensors; used to name the visual observation tensors.
        var visIndex = 0;
        // Shared by every rank-1 sensor; created lazily on the first one.
        ObservationGenerator vecObsGen = null;
        for (var sensorIndex = 0; sensorIndex < sensors.Count; sensorIndex++)
        {
            var sensor = sensors[sensorIndex];
            var rank = sensor.GetObservationSpec().Rank;
            ObservationGenerator obsGen;
            string obsGenName;
            switch (rank)
            {
                case 1:
                    if (vecObsGen == null)
                    {
                        vecObsGen = new ObservationGenerator(allocator);
                    }
                    obsGen = vecObsGen;
                    obsGenName = TensorNames.VectorObservationPlaceholder;
                    break;
                case 2:
                    // Rank-2 tensors are named by the sensor's position in the list.
                    obsGen = new ObservationGenerator(allocator);
                    obsGenName = TensorNames.GetObservationName(sensorIndex);
                    break;
                case 3:
                    // Rank-3 tensors are named by the visual index, which skips other ranks.
                    obsGen = new ObservationGenerator(allocator);
                    obsGenName = TensorNames.GetVisualObservationName(visIndex);
                    visIndex++;
                    break;
                default:
                    throw new UnityAgentsException(
                        $"Sensor {sensor.GetName()} have an invalid rank {rank}");
            }
            obsGen.AddSensorIndex(sensorIndex);
            m_Dict[obsGenName] = obsGen;
        }
    }

    // Registers one observation generator per sensor using the MLAgents2_0 index-based naming.
    void InitializeObservationsMLAgents2(List<ISensor> sensors, ITensorAllocator allocator)
    {
        for (var sensorIndex = 0; sensorIndex < sensors.Count; sensorIndex++)
        {
            var obsGen = new ObservationGenerator(allocator);
            var obsGenName = TensorNames.GetObservationName(sensorIndex);
            obsGen.AddSensorIndex(sensorIndex);
            m_Dict[obsGenName] = obsGen;
        }
    }

    /// <summary>
    /// Runs the registered generator for each tensor in <paramref name="tensors"/>, in order.
    /// </summary>
    /// <param name="tensors">Tensors to prepare; each is matched to a generator by its name.</param>
    /// <param name="currentBatchSize">Batch size forwarded to every generator.</param>
    /// <param name="infos">Agent info/sensor pairs forwarded to every generator.</param>
    /// <exception cref="UnityAgentsException">
    /// Thrown when a tensor's name has no registered generator. Tensors earlier in the list
    /// will already have been processed at that point.
    /// </exception>
    public void GenerateTensors(
        IReadOnlyList<TensorProxy> tensors, int currentBatchSize, IList<AgentInfoSensorsPair> infos)
    {
        for (var tensorIndex = 0; tensorIndex < tensors.Count; tensorIndex++)
        {
            var tensor = tensors[tensorIndex];
            // Single dictionary lookup instead of ContainsKey followed by the indexer.
            if (!m_Dict.TryGetValue(tensor.name, out var generator))
            {
                throw new UnityAgentsException(
                    $"Unknown tensorProxy expected as input : {tensor.name}");
            }
            generator.Generate(tensor, currentBatchSize, infos);
        }
    }
}
|
} |
|
|