File size: 4,935 Bytes
05c9ac2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
using Unity.Barracuda;
using System.Collections.Generic;
using System.Diagnostics;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Inference;
using Unity.MLAgents.Sensors;

namespace Unity.MLAgents.Policies
{
    /// <summary>
    /// Where to perform inference. Each value documents the Barracuda
    /// WorkerFactory.Type it corresponds to.
    /// </summary>
    public enum InferenceDevice
    {
        /// <summary>
        /// Default inference. This is currently the same as Burst, but may change in the future.
        /// </summary>
        Default = 0,

        /// <summary>
        /// GPU inference. Corresponds to WorkerFactory.Type.ComputePrecompiled in Barracuda.
        /// </summary>
        GPU = 1,

        /// <summary>
        /// CPU inference using Burst. Corresponds to WorkerFactory.Type.CSharpBurst in Barracuda.
        /// </summary>
        Burst = 2,

        /// <summary>
        /// CPU inference. Corresponds to WorkerFactory.Type.CSharp in Barracuda.
        /// Burst is recommended instead; this is kept for legacy compatibility.
        /// </summary>
        CPU = 3,
    }

    /// <summary>
    /// The Barracuda Policy uses a Barracuda Model to make decisions at
    /// every step. It uses a ModelRunner that is shared across all
    /// Barracuda Policies that use the same model and inference devices.
    /// </summary>
    internal class BarracudaPolicy : IPolicy
    {
        /// <summary>
        /// The runner that performs inference. Every use in this class tolerates a null runner,
        /// in which case observations are dropped and empty actions are returned.
        /// </summary>
        protected ModelRunner m_ModelRunner;

        /// <summary>
        /// Backing storage for the reference returned by <see cref="DecideAction"/>.
        /// </summary>
        ActionBuffers m_LastActionBuffer;

        /// <summary>
        /// Episode id taken from the most recent <see cref="RequestDecision"/> call; used to
        /// look up that agent's action in the ModelRunner.
        /// </summary>
        int m_AgentId;

        /// <summary>
        /// Inference only: set to true if the action selection from model should be
        /// deterministic.
        /// </summary>
        bool m_DeterministicInference;

        /// <summary>
        /// Sensor shapes for the associated Agents. All Agents must have the same shapes for their Sensors.
        /// </summary>
        List<int[]> m_SensorShapes;
        ActionSpec m_ActionSpec;

        private string m_BehaviorName;

        /// <summary>
        /// List of actuators, only used for analytics
        /// </summary>
        private IList<IActuator> m_Actuators;

        /// <summary>
        /// Whether or not we've tried to send analytics for this model. We only ever try to send once per policy,
        /// and do additional deduplication in the analytics code.
        /// </summary>
        private bool m_AnalyticsSent;

        /// <summary>
        /// Instantiate a BarracudaPolicy with the necessary objects for it to run.
        /// </summary>
        /// <param name="actionSpec">The action spec of the behavior.</param>
        /// <param name="actuators">The actuators used for this behavior.</param>
        /// <param name="model">The Neural Network to use.</param>
        /// <param name="inferenceDevice">Which device Barracuda will run on.</param>
        /// <param name="behaviorName">The name of the behavior.</param>
        /// <param name="deterministicInference"> Inference only: set to true if the action selection from model should be
        /// deterministic. </param>
        public BarracudaPolicy(
            ActionSpec actionSpec,
            IList<IActuator> actuators,
            NNModel model,
            InferenceDevice inferenceDevice,
            string behaviorName,
            bool deterministicInference = false
        )
        {
            // The runner is requested from the Academy rather than constructed here.
            m_ModelRunner = Academy.Instance.GetOrCreateModelRunner(model, actionSpec, inferenceDevice, deterministicInference);
            m_BehaviorName = behaviorName;
            m_ActionSpec = actionSpec;
            m_Actuators = actuators;
            m_DeterministicInference = deterministicInference;
        }

        /// <inheritdoc />
        public void RequestDecision(AgentInfo info, List<ISensor> sensors)
        {
            SendAnalytics(sensors);
            m_AgentId = info.episodeId;
            m_ModelRunner?.PutObservations(info, sensors);
        }

        /// <summary>
        /// Reports the model, behavior name, inference device, sensors, action spec and actuators
        /// to the inference analytics, at most once per policy. Calls to this method are compiled
        /// out unless the MLA_UNITY_ANALYTICS_MODULE symbol is defined.
        /// </summary>
        /// <param name="sensors">The sensors passed to the current decision request.</param>
        [Conditional("MLA_UNITY_ANALYTICS_MODULE")]
        void SendAnalytics(IList<ISensor> sensors)
        {
            // Without a runner there is no model or device to report. Bail out without marking
            // the analytics as sent, so that the rest of RequestDecision (which is null-safe)
            // is not aborted by a NullReferenceException thrown from here.
            if (m_AnalyticsSent || m_ModelRunner == null)
            {
                return;
            }

            // Flag first: we only ever *try* once, even if the call below fails.
            m_AnalyticsSent = true;
            Analytics.InferenceAnalytics.InferenceModelSet(
                m_ModelRunner.Model,
                m_BehaviorName,
                m_ModelRunner.InferenceDevice,
                sensors,
                m_ActionSpec,
                m_Actuators
            );
        }

        /// <inheritdoc />
        public ref readonly ActionBuffers DecideAction()
        {
            if (m_ModelRunner == null)
            {
                m_LastActionBuffer = ActionBuffers.Empty;
            }
            else
            {
                // The runner is known to be non-null in this branch, so call it directly.
                m_ModelRunner.DecideBatch();
                m_LastActionBuffer = m_ModelRunner.GetAction(m_AgentId);
            }
            return ref m_LastActionBuffer;
        }

        /// <summary>
        /// Intentionally empty. The ModelRunner is obtained from the Academy and shared between
        /// policies, so this policy does not dispose it.
        /// </summary>
        public void Dispose()
        {
        }
    }
}