File size: 4,935 Bytes
05c9ac2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 |
using Unity.Barracuda;
using System.Collections.Generic;
using System.Diagnostics;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Inference;
using Unity.MLAgents.Sensors;
namespace Unity.MLAgents.Policies
{
/// <summary>
/// Where to perform inference. Each member names the Barracuda worker type it corresponds to.
/// </summary>
/// <remarks>
/// The numeric values are assigned explicitly.
/// NOTE(review): presumably these values are persisted (e.g. serialized settings), so they should
/// stay stable when members are added or reordered - confirm against callers.
/// </remarks>
public enum InferenceDevice
{
/// <summary>
/// Default inference. This is currently the same as Burst, but may change in the future.
/// </summary>
Default = 0,
/// <summary>
/// GPU inference. Corresponds to WorkerFactory.Type.ComputePrecompiled in Barracuda.
/// </summary>
GPU = 1,
/// <summary>
/// CPU inference using Burst. Corresponds to WorkerFactory.Type.CSharpBurst in Barracuda.
/// </summary>
Burst = 2,
/// <summary>
/// CPU inference. Corresponds to WorkerFactory.Type.CSharp in Barracuda.
/// Burst is recommended instead; this is kept for legacy compatibility.
/// </summary>
CPU = 3,
}
/// <summary>
/// The Barracuda Policy uses a Barracuda Model to make decisions at
/// every step. It uses a ModelRunner that is shared across all
/// Barracuda Policies that use the same model and inference devices.
/// </summary>
internal class BarracudaPolicy : IPolicy
{
    /// <summary>
    /// Runner obtained from <c>Academy.Instance.GetOrCreateModelRunner</c> in the constructor.
    /// Every method of this class tolerates a null runner.
    /// </summary>
    protected ModelRunner m_ModelRunner;

    /// <summary>
    /// Result of the most recent <see cref="DecideAction"/> call; that method returns a
    /// reference to this field.
    /// </summary>
    ActionBuffers m_LastActionBuffer;

    /// <summary>
    /// Episode id taken from the <c>AgentInfo</c> of the most recent
    /// <see cref="RequestDecision"/> call; used to look up the action in the runner.
    /// </summary>
    int m_AgentId;

    /// <summary>
    /// Inference only: set to true if the action selection from model should be
    /// deterministic.
    /// </summary>
    readonly bool m_DeterministicInference;

    /// <summary>
    /// Sensor shapes for the associated Agents. All Agents must have the same shapes for their Sensors.
    /// </summary>
    List<int[]> m_SensorShapes;

    /// <summary>
    /// Action spec of the behavior; forwarded to analytics.
    /// </summary>
    readonly ActionSpec m_ActionSpec;

    /// <summary>
    /// Name of the behavior; forwarded to analytics.
    /// </summary>
    private readonly string m_BehaviorName;

    /// <summary>
    /// List of actuators, only used for analytics
    /// </summary>
    private readonly IList<IActuator> m_Actuators;

    /// <summary>
    /// Whether or not we've tried to send analytics for this model. We only ever try to send once per policy,
    /// and do additional deduplication in the analytics code.
    /// </summary>
    private bool m_AnalyticsSent;

    /// <summary>
    /// Instantiate a BarracudaPolicy with the necessary objects for it to run.
    /// </summary>
    /// <param name="actionSpec">The action spec of the behavior.</param>
    /// <param name="actuators">The actuators used for this behavior.</param>
    /// <param name="model">The Neural Network to use.</param>
    /// <param name="inferenceDevice">Which device Barracuda will run on.</param>
    /// <param name="behaviorName">The name of the behavior.</param>
    /// <param name="deterministicInference"> Inference only: set to true if the action selection from model should be
    /// deterministic. </param>
    public BarracudaPolicy(
        ActionSpec actionSpec,
        IList<IActuator> actuators,
        NNModel model,
        InferenceDevice inferenceDevice,
        string behaviorName,
        bool deterministicInference = false
    )
    {
        var modelRunner = Academy.Instance.GetOrCreateModelRunner(model, actionSpec, inferenceDevice, deterministicInference);
        m_ModelRunner = modelRunner;
        m_BehaviorName = behaviorName;
        m_ActionSpec = actionSpec;
        m_Actuators = actuators;
        m_DeterministicInference = deterministicInference;
    }

    /// <inheritdoc />
    public void RequestDecision(AgentInfo info, List<ISensor> sensors)
    {
        SendAnalytics(sensors);
        // Remember which agent asked so DecideAction can fetch its action from the runner.
        m_AgentId = info.episodeId;
        m_ModelRunner?.PutObservations(info, sensors);
    }

    /// <summary>
    /// Reports the model, behavior name, inference device, sensors, action spec and actuators
    /// to inference analytics, at most once per policy instance. Calls to this method are
    /// compiled out unless the MLA_UNITY_ANALYTICS_MODULE symbol is defined.
    /// </summary>
    /// <param name="sensors">The sensors passed with the decision request.</param>
    [Conditional("MLA_UNITY_ANALYTICS_MODULE")]
    void SendAnalytics(IList<ISensor> sensors)
    {
        // Without a runner there is no model or device to report. Skip without marking the
        // event as sent instead of throwing, matching the null tolerance of RequestDecision
        // and DecideAction.
        if (m_AnalyticsSent || m_ModelRunner == null)
        {
            return;
        }

        // Set the flag before the call so a failing analytics call is never retried.
        m_AnalyticsSent = true;
        Analytics.InferenceAnalytics.InferenceModelSet(
            m_ModelRunner.Model,
            m_BehaviorName,
            m_ModelRunner.InferenceDevice,
            sensors,
            m_ActionSpec,
            m_Actuators
        );
    }

    /// <inheritdoc />
    public ref readonly ActionBuffers DecideAction()
    {
        if (m_ModelRunner == null)
        {
            // No runner: report an empty action rather than failing.
            m_LastActionBuffer = ActionBuffers.Empty;
        }
        else
        {
            // The runner is known to be non-null here, so it is called directly.
            m_ModelRunner.DecideBatch();
            m_LastActionBuffer = m_ModelRunner.GetAction(m_AgentId);
        }
        return ref m_LastActionBuffer;
    }

    /// <summary>
    /// Intentionally does nothing; this class releases no resources here.
    /// NOTE(review): presumably the shared ModelRunner is owned and disposed elsewhere - confirm.
    /// </summary>
    public void Dispose()
    {
    }
}
}
|