using System.Collections.Generic; | |
using Unity.Barracuda; | |
using Unity.MLAgents.Actuators; | |
namespace Unity.MLAgents.Inference
{
    /// <summary>
    /// Mapping between the output tensor names and the <see cref="IApplier"/> that will use those
    /// output tensors to update the actions and memories of the Agents present in the batch.
    /// A TensorApplier holds a Dictionary of strings (output node names) to an <see cref="IApplier"/>.
    /// Each applier takes as input the tensor, the ids of the Agents in the current batch, and the
    /// dictionary of last actions to update.
    /// </summary>
    internal class TensorApplier
    {
        /// <summary>
        /// A tensor applier's <see cref="Apply"/> method takes a tensor, the list of Agent ids in the
        /// batch and a dictionary of Agent id to actions, and uses the data contained inside the tensor
        /// to modify the state of the Agents. The tensors are assumed to have the batch size on the
        /// first dimension and the Agents to be ordered the same way in <c>actionIds</c> and in the tensor.
        /// </summary>
        public interface IApplier
        {
            /// <summary>
            /// Applies the values in the Tensor to the Agents identified by <paramref name="actionIds"/>.
            /// </summary>
            /// <param name="tensorProxy">
            /// The Tensor containing the data to be applied to the Agents
            /// </param>
            /// <param name="actionIds"> List of Agents Ids that will be updated using the tensor's data</param>
            /// <param name="lastActions"> Dictionary of AgentId to Actions to be updated</param>
            void Apply(TensorProxy tensorProxy, IList<int> actionIds, Dictionary<int, ActionBuffers> lastActions);
        }

        /// <summary>
        /// Output tensor name to the applier responsible for consuming that tensor.
        /// </summary>
        readonly Dictionary<string, IApplier> m_Dict = new Dictionary<string, IApplier>();

        /// <summary>
        /// Returns a new TensorAppliers object.
        /// </summary>
        /// <param name="actionSpec"> Description of the actions for the Agent.</param>
        /// <param name="seed"> The seed the Appliers will be initialized with.</param>
        /// <param name="allocator"> Tensor allocator</param>
        /// <param name="memories">Dictionary of AgentInfo.id to memory used to pass to the inference model.</param>
        /// <param name="barracudaModel"> The inference model, cast to a Barracuda <c>Model</c>. When null,
        /// no appliers are registered.</param>
        /// <param name="deterministicInference"> Inference only: set to true if the action selection from model should be
        /// deterministic.</param>
        /// <exception cref="System.InvalidCastException"> <paramref name="barracudaModel"/> is non-null
        /// but is not a Barracuda <c>Model</c>.</exception>
        public TensorApplier(
            ActionSpec actionSpec,
            int seed,
            ITensorAllocator allocator,
            Dictionary<int, List<float>> memories,
            object barracudaModel = null,
            bool deterministicInference = false)
        {
            // If model is null, no inference to run and exception is thrown before reaching here.
            // Return early, leaving the applier map empty.
            if (barracudaModel == null)
            {
                return;
            }

            var model = (Model)barracudaModel;
            if (!model.SupportsContinuousAndDiscrete())
            {
                // This model cannot mix action types; presumably this throws if the spec mixes
                // continuous and discrete actions — TODO confirm.
                actionSpec.CheckAllContinuousOrDiscrete();
            }

            if (actionSpec.NumContinuousActions > 0)
            {
                var tensorName = model.ContinuousOutputName(deterministicInference);
                m_Dict[tensorName] = new ContinuousActionOutputApplier(actionSpec);
            }

            var modelVersion = model.GetVersion();
            if (actionSpec.NumDiscreteActions > 0)
            {
                var tensorName = model.DiscreteOutputName(deterministicInference);
                // The discrete output format differs between model API versions, so pick the
                // matching applier. The two versions are mutually exclusive.
                if (modelVersion == (int)BarracudaModelParamLoader.ModelApiVersion.MLAgents1_0)
                {
                    m_Dict[tensorName] = new LegacyDiscreteActionOutputApplier(actionSpec, seed, allocator);
                }
                else if (modelVersion == (int)BarracudaModelParamLoader.ModelApiVersion.MLAgents2_0)
                {
                    m_Dict[tensorName] = new DiscreteActionOutputApplier(actionSpec, seed, allocator);
                }
                // NOTE(review): for any other version no discrete applier is registered, so
                // ApplyTensors will throw if it is handed a tensor with this name — confirm that
                // other versions are rejected before reaching here.
            }

            m_Dict[TensorNames.RecurrentOutput] = new MemoryOutputApplier(memories);
        }

        /// <summary>
        /// Updates the state of the agents based on the data present in the tensor.
        /// </summary>
        /// <param name="tensors"> Enumerable of tensors containing the data.</param>
        /// <param name="actionIds"> List of Agents Ids that will be updated using the tensor's data</param>
        /// <param name="lastActions"> Dictionary of AgentId to Actions to be updated</param>
        /// <exception cref="UnityAgentsException"> One of the tensor does not have an
        /// associated applier.</exception>
        public void ApplyTensors(
            IReadOnlyList<TensorProxy> tensors, IList<int> actionIds, Dictionary<int, ActionBuffers> lastActions)
        {
            for (var tensorIndex = 0; tensorIndex < tensors.Count; tensorIndex++)
            {
                var tensor = tensors[tensorIndex];
                // Single lookup instead of ContainsKey followed by the indexer.
                if (!m_Dict.TryGetValue(tensor.name, out var applier))
                {
                    throw new UnityAgentsException(
                        $"Unknown tensorProxy expected as output : {tensor.name}");
                }
                applier.Apply(tensor, actionIds, lastActions);
            }
        }
    }
}