File size: 4,961 Bytes
05c9ac2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
using System;
using System.Collections.Generic;

namespace Unity.MLAgents.Actuators
{
    /// <summary>
    /// Implementation of IDiscreteActionMask that allows writing to the action mask from an <see cref="IActuator"/>.
    /// </summary>
    /// <remarks>
    /// The mask is stored as one flat boolean array covering every discrete branch back to back.
    /// An entry that is <c>true</c> means the corresponding action is masked (disabled).
    /// All internal arrays are allocated lazily on the first call to <see cref="SetActionEnabled"/>.
    /// </remarks>
    internal class ActuatorDiscreteActionMask : IDiscreteActionMask
    {
        /// <summary>
        /// Starting index, inside the flat mask, of each branch when all the branches are
        /// concatenated with each other. Built from <see cref="m_BranchSizes"/> on first use.
        /// </summary>
        int[] m_StartingActionIndices;

        /// <summary>
        /// Number of actions in each branch. Either supplied to the constructor or collected
        /// from the actuators on first use.
        /// </summary>
        int[] m_BranchSizes;

        /// <summary>
        /// The flat mask; <c>null</c> until an action is enabled or disabled for the first time.
        /// </summary>
        bool[] m_CurrentMask;

        readonly IList<IActuator> m_Actuators;

        readonly int m_SumOfDiscreteBranchSizes;
        readonly int m_NumBranches;

        /// <summary>
        /// The offset into the branches array that is used when actuators are writing to the action mask.
        /// </summary>
        public int CurrentBranchOffset { get; set; }

        /// <summary>
        /// Creates a mask over the discrete branches of the given actuators.
        /// </summary>
        /// <param name="actuators">Actuators whose branch sizes are read lazily when
        /// <paramref name="branchSizes"/> is not supplied.</param>
        /// <param name="sumOfDiscreteBranchSizes">Length of the flat mask array.</param>
        /// <param name="numBranches">Total number of discrete branches.</param>
        /// <param name="branchSizes">Optional precomputed size of each branch.</param>
        internal ActuatorDiscreteActionMask(IList<IActuator> actuators, int sumOfDiscreteBranchSizes, int numBranches, int[] branchSizes = null)
        {
            m_Actuators = actuators;
            m_SumOfDiscreteBranchSizes = sumOfDiscreteBranchSizes;
            m_NumBranches = numBranches;
            m_BranchSizes = branchSizes;
        }

        /// <inheritdoc/>
        public void SetActionEnabled(int branch, int actionIndex, bool isEnabled)
        {
            LazyInitialize();

            // The branch is relative to the actuator currently writing; shift it into the
            // combined branch list before touching any array.
            var absoluteBranch = CurrentBranchOffset + branch;
#if DEBUG
            // Negative values would otherwise either crash with a raw index exception or,
            // for actionIndex, silently write into a neighbouring branch's slots.
            if (branch < 0 || actionIndex < 0 || absoluteBranch < 0)
            {
                throw new UnityAgentsException(
                    "Invalid Action Masking: branch and action index must not be negative.");
            }

            // Validate the shifted branch, not the relative one, so that an out-of-range
            // branch on an actuator with a non-zero offset is reported properly.
            if (absoluteBranch >= m_NumBranches ||
                absoluteBranch >= m_BranchSizes.Length ||
                actionIndex >= m_BranchSizes[absoluteBranch])
            {
                throw new UnityAgentsException(
                    "Invalid Action Masking: Action Mask is too large for specified branch.");
            }
#endif
            // The mask stores "is masked", which is the inverse of "is enabled".
            m_CurrentMask[m_StartingActionIndices[absoluteBranch] + actionIndex] = !isEnabled;
        }

        /// <summary>
        /// Allocates the branch sizes, the mask and the starting indices the first time they are needed.
        /// </summary>
        void LazyInitialize()
        {
            if (m_BranchSizes == null)
            {
                // No sizes were given to the constructor: concatenate the branch sizes of
                // every actuator, in actuator order.
                m_BranchSizes = new int[m_NumBranches];
                var destinationIndex = 0;
                for (var i = 0; i < m_Actuators.Count; i++)
                {
                    var actuatorBranchSizes = m_Actuators[i].ActionSpec.BranchSizes;
                    Array.Copy(actuatorBranchSizes, 0, m_BranchSizes, destinationIndex, actuatorBranchSizes.Length);
                    destinationIndex += actuatorBranchSizes.Length;
                }
            }

            // By default, the mask is null. When a mask is first requested we allocate it with
            // every entry false, i.e. no action is masked.
            if (m_CurrentMask == null)
            {
                m_CurrentMask = new bool[m_SumOfDiscreteBranchSizes];
            }

            // If this is the first time the masked actions are used, we generate the starting
            // indices for each branch.
            if (m_StartingActionIndices == null)
            {
                m_StartingActionIndices = Utilities.CumSum(m_BranchSizes);
            }
        }

        /// <summary>
        /// Get the current mask for an agent.
        /// </summary>
        /// <returns>A mask for the agent: a boolean array with one entry per discrete action,
        /// or <c>null</c> if no action has been enabled or disabled yet.</returns>
        internal bool[] GetMask()
        {
#if DEBUG
            if (m_CurrentMask != null)
            {
                AssertMask();
            }
#endif
            return m_CurrentMask;
        }

        /// <summary>
        /// Makes sure that the current mask is usable, i.e. that no branch has all of its
        /// actions masked. Only performs the check in DEBUG builds.
        /// </summary>
        void AssertMask()
        {
#if DEBUG
            for (var branchIndex = 0; branchIndex < m_NumBranches; branchIndex++)
            {
                if (AreAllActionsMasked(branchIndex))
                {
                    throw new UnityAgentsException(
                        $"Invalid Action Masking : All the actions of branch {branchIndex} are masked.");
                }
            }
#endif
        }

        /// <summary>
        /// Resets the current mask for an agent so that no action is masked.
        /// Does nothing if the mask was never allocated.
        /// </summary>
        internal void ResetMask()
        {
            if (m_CurrentMask != null)
            {
                Array.Clear(m_CurrentMask, 0, m_CurrentMask.Length);
            }
        }

        /// <summary>
        /// Checks if all the actions in the input branch are masked.
        /// </summary>
        /// <param name="branch"> The index of the branch to check.</param>
        /// <returns> True if all the actions of the branch are masked.</returns>
        bool AreAllActionsMasked(int branch)
        {
            if (m_CurrentMask == null)
            {
                return false;
            }

            // The branch occupies the half-open range [first, pastLast) of the flat mask.
            var first = m_StartingActionIndices[branch];
            var pastLast = m_StartingActionIndices[branch + 1];
            for (var i = first; i < pastLast; i++)
            {
                if (!m_CurrentMask[i])
                {
                    // At least one action is still available.
                    return false;
                }
            }
            return true;
        }
    }
}