using Unity.Barracuda; | |
using System.Collections.Generic; | |
using System.Diagnostics; | |
using Unity.MLAgents.Actuators; | |
using Unity.MLAgents.Inference; | |
using Unity.MLAgents.Sensors; | |
namespace Unity.MLAgents.Policies | |
{ | |
/// <summary>
/// Where to perform inference.
/// </summary>
public enum InferenceDevice
{
    /// <summary>
    /// Default inference. This is currently the same as Burst, but may change in the future.
    /// </summary>
    Default = 0,

    /// <summary>
    /// GPU inference. Corresponds to WorkerFactory.Type.ComputePrecompiled in Barracuda.
    /// </summary>
    GPU = 1,

    /// <summary>
    /// CPU inference using Burst. Corresponds to WorkerFactory.Type.CSharpBurst in Barracuda.
    /// </summary>
    Burst = 2,

    /// <summary>
    /// CPU inference. Corresponds to WorkerFactory.Type.CSharp in Barracuda.
    /// Burst is recommended instead; this is kept for legacy compatibility.
    /// </summary>
    CPU = 3,
}
/// <summary>
/// The Barracuda Policy uses a Barracuda Model to make decisions at
/// every step. It uses a ModelRunner that is shared across all
/// Barracuda Policies that use the same model and inference devices.
/// </summary>
internal class BarracudaPolicy : IPolicy
{
    /// <summary>
    /// The runner that batches observations and executes the model. Every use in this
    /// class tolerates a null runner, since derived classes are able to reassign it.
    /// </summary>
    protected ModelRunner m_ModelRunner;

    /// <summary>
    /// Backing storage for the action most recently returned by <see cref="DecideAction"/>;
    /// it is kept as a field so it can be returned by reference.
    /// </summary>
    ActionBuffers m_LastActionBuffer;

    /// <summary>
    /// Episode id captured by the latest <see cref="RequestDecision"/> call; used to look up
    /// the matching action from the runner.
    /// </summary>
    int m_AgentId;

    /// <summary>
    /// Inference only: set to true if the action selection from model should be
    /// deterministic.
    /// </summary>
    readonly bool m_DeterministicInference;

    /// <summary>
    /// Sensor shapes for the associated Agents. All Agents must have the same shapes for their Sensors.
    /// </summary>
    List<int[]> m_SensorShapes;

    /// <summary>
    /// The action spec of the behavior, forwarded to analytics.
    /// </summary>
    readonly ActionSpec m_ActionSpec;

    /// <summary>
    /// The name of the behavior, forwarded to analytics.
    /// </summary>
    private readonly string m_BehaviorName;

    /// <summary>
    /// List of actuators, only used for analytics
    /// </summary>
    private readonly IList<IActuator> m_Actuators;

    /// <summary>
    /// Whether or not we've tried to send analytics for this model. We only ever try to send once per policy,
    /// and do additional deduplication in the analytics code.
    /// </summary>
    private bool m_AnalyticsSent;

    /// <summary>
    /// Instantiate a BarracudaPolicy with the necessary objects for it to run.
    /// </summary>
    /// <param name="actionSpec">The action spec of the behavior.</param>
    /// <param name="actuators">The actuators used for this behavior.</param>
    /// <param name="model">The Neural Network to use.</param>
    /// <param name="inferenceDevice">Which device Barracuda will run on.</param>
    /// <param name="behaviorName">The name of the behavior.</param>
    /// <param name="deterministicInference"> Inference only: set to true if the action selection from model should be
    /// deterministic. </param>
    public BarracudaPolicy(
        ActionSpec actionSpec,
        IList<IActuator> actuators,
        NNModel model,
        InferenceDevice inferenceDevice,
        string behaviorName,
        bool deterministicInference = false
    )
    {
        // The runner is obtained from the Academy rather than constructed here.
        m_ModelRunner = Academy.Instance.GetOrCreateModelRunner(
            model, actionSpec, inferenceDevice, deterministicInference);
        m_BehaviorName = behaviorName;
        m_ActionSpec = actionSpec;
        m_Actuators = actuators;
        m_DeterministicInference = deterministicInference;
    }

    /// <inheritdoc />
    public void RequestDecision(AgentInfo info, List<ISensor> sensors)
    {
        SendAnalytics(sensors);
        m_AgentId = info.episodeId;
        m_ModelRunner?.PutObservations(info, sensors);
    }

    /// <summary>
    /// Reports the model, behavior name, inference device, sensors, action spec and actuators
    /// to the inference analytics, at most once per policy instance.
    /// </summary>
    /// <param name="sensors">The sensors supplied with the current decision request.</param>
    // NOTE(review): the attribute on this method was unreadable in the source; it is restored
    // as the analytics-module conditional used by ML-Agents. Confirm the symbol name.
    [Conditional("MLA_UNITY_ANALYTICS_MODULE")]
    void SendAnalytics(IList<ISensor> sensors)
    {
        // Without a runner there is no model or device to report; skip instead of
        // dereferencing null, matching the null handling in the other members.
        if (m_AnalyticsSent || m_ModelRunner == null)
        {
            return;
        }

        // Set the flag before the call so a failing send is not retried.
        m_AnalyticsSent = true;
        Analytics.InferenceAnalytics.InferenceModelSet(
            m_ModelRunner.Model,
            m_BehaviorName,
            m_ModelRunner.InferenceDevice,
            sensors,
            m_ActionSpec,
            m_Actuators
        );
    }

    /// <inheritdoc />
    public ref readonly ActionBuffers DecideAction()
    {
        if (m_ModelRunner == null)
        {
            m_LastActionBuffer = ActionBuffers.Empty;
        }
        else
        {
            // Run the pending batch, then fetch the action for the episode id
            // recorded by the last RequestDecision call.
            m_ModelRunner.DecideBatch();
            m_LastActionBuffer = m_ModelRunner.GetAction(m_AgentId);
        }

        return ref m_LastActionBuffer;
    }

    /// <summary>
    /// Intentionally does nothing: this class releases no resources of its own.
    /// </summary>
    public void Dispose()
    {
    }
}
} | |