File size: 3,991 Bytes
05c9ac2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
using UnityEngine.Profiling;

namespace Unity.MLAgents.Actuators
{
    /// <summary>
    /// IActuator implementation that forwards calls to an <see cref="IActionReceiver"/> and an <see cref="IHeuristicProvider"/>.
    /// </summary>
    internal class VectorActuator : IActuator, IBuiltInActuator
    {
        // Target of OnActionReceived and WriteDiscreteActionMask; assigned once in the constructor.
        readonly IActionReceiver m_ActionReceiver;

        // Target of Heuristic; may be null, in which case Heuristic does nothing.
        readonly IHeuristicProvider m_HeuristicProvider;

        // Backing store for ActionBuffers; set by OnActionReceived and cleared by ResetData.
        ActionBuffers m_ActionBuffers;

        /// <summary>
        /// The buffers most recently passed to <see cref="OnActionReceived"/>.
        /// After <see cref="ResetData"/> this holds <c>ActionBuffers.Empty</c>.
        /// </summary>
        internal ActionBuffers ActionBuffers
        {
            get => m_ActionBuffers;
            private set => m_ActionBuffers = value;
        }

        /// <summary>
        /// Create a VectorActuator that forwards to the provided IActionReceiver.
        /// </summary>
        /// <param name="actionReceiver">The <see cref="IActionReceiver"/> used for OnActionReceived and WriteDiscreteActionMask.
        /// If this parameter also implements <see cref="IHeuristicProvider"/> it will be cast and used to forward calls to
        /// <see cref="IHeuristicProvider.Heuristic"/>.</param>
        /// <param name="actionSpec">The <see cref="ActionSpec"/> exposed by this actuator; its continuous and
        /// discrete action counts also determine the suffix appended to <paramref name="name"/>.</param>
        /// <param name="name">Base name of the actuator; a suffix derived from <paramref name="actionSpec"/>
        /// is appended to produce <see cref="Name"/>.</param>
        public VectorActuator(IActionReceiver actionReceiver,
                              ActionSpec actionSpec,
                              string name = "VectorActuator")
            : this(actionReceiver, actionReceiver as IHeuristicProvider, actionSpec, name) { }

        /// <summary>
        /// Create a VectorActuator that forwards to the provided IActionReceiver.
        /// </summary>
        /// <param name="actionReceiver">The <see cref="IActionReceiver"/> used for OnActionReceived and WriteDiscreteActionMask.</param>
        /// <param name="heuristicProvider">The <see cref="IHeuristicProvider"/> used to fill the <see cref="ActionBuffers"/>
        /// for Heuristic Policies. May be null, in which case <see cref="Heuristic"/> does nothing.</param>
        /// <param name="actionSpec">The <see cref="ActionSpec"/> exposed by this actuator; its continuous and
        /// discrete action counts also determine the suffix appended to <paramref name="name"/>.</param>
        /// <param name="name">Base name of the actuator. The final <see cref="Name"/> is this value plus:
        /// "-Discrete" when there are no continuous actions (this includes the case of no actions at all),
        /// "-Continuous" when there are no discrete actions, and
        /// "-Continuous-{count}-Discrete-{count}" when both kinds are present.</param>
        public VectorActuator(IActionReceiver actionReceiver,
                              IHeuristicProvider heuristicProvider,
                              ActionSpec actionSpec,
                              string name = "VectorActuator")
        {
            m_ActionReceiver = actionReceiver;
            m_HeuristicProvider = heuristicProvider;
            ActionSpec = actionSpec;
            string suffix;
            if (actionSpec.NumContinuousActions == 0)
            {
                suffix = "-Discrete";
            }
            else if (actionSpec.NumDiscreteActions == 0)
            {
                suffix = "-Continuous";
            }
            else
            {
                suffix = $"-Continuous-{actionSpec.NumContinuousActions}-Discrete-{actionSpec.NumDiscreteActions}";
            }
            Name = name + suffix;
        }

        /// <inheritdoc />
        public void ResetData()
        {
            m_ActionBuffers = ActionBuffers.Empty;
        }

        /// <inheritdoc />
        public void OnActionReceived(ActionBuffers actionBuffers)
        {
            Profiler.BeginSample("VectorActuator.OnActionReceived");
            try
            {
                // Remember the buffers before forwarding so ActionBuffers reflects the latest actions.
                m_ActionBuffers = actionBuffers;
                m_ActionReceiver.OnActionReceived(m_ActionBuffers);
            }
            finally
            {
                // Close the sample even if the receiver throws, so the profiler
                // is not left with an unmatched BeginSample.
                Profiler.EndSample();
            }
        }

        /// <inheritdoc />
        public void Heuristic(in ActionBuffers actionBuffersOut)
        {
            Profiler.BeginSample("VectorActuator.Heuristic");
            try
            {
                // No heuristic provider means the output buffers are left untouched.
                m_HeuristicProvider?.Heuristic(actionBuffersOut);
            }
            finally
            {
                // Keep Begin/End samples balanced if the provider throws.
                Profiler.EndSample();
            }
        }

        /// <inheritdoc />
        public void WriteDiscreteActionMask(IDiscreteActionMask actionMask)
        {
            m_ActionReceiver.WriteDiscreteActionMask(actionMask);
        }

        /// <inheritdoc/>
        public ActionSpec ActionSpec { get; }

        /// <inheritdoc />
        public string Name { get; }

        /// <inheritdoc />
        public virtual BuiltInActuatorType GetBuiltInActuatorType()
        {
            return BuiltInActuatorType.VectorActuator;
        }
    }
}