|
using System; |
|
using System.Collections.Generic; |
|
|
|
namespace Unity.MLAgents.Actuators |
|
{ |
|
|
|
|
|
|
|
/// <summary>
/// <see cref="IDiscreteActionMask"/> implementation that records, for every action of every discrete
/// branch contributed by a list of actuators, whether that action is currently masked (disabled).
/// </summary>
/// <remarks>
/// The mask is stored as one flat <c>bool[]</c> of length <c>sumOfDiscreteBranchSizes</c>; the actions
/// of absolute branch <c>b</c> live in the half-open slice
/// <c>[m_StartingActionIndices[b], m_StartingActionIndices[b + 1])</c>.
/// Branch indices passed to <see cref="SetActionEnabled"/> are relative: <see cref="CurrentBranchOffset"/>
/// is added to them before indexing.
/// All storage is allocated lazily on the first call to <see cref="SetActionEnabled"/>.
/// </remarks>
internal class ActuatorDiscreteActionMask : IDiscreteActionMask
{
    /// <summary>
    /// Flat index of the first action of each absolute branch, produced by <c>Utilities.CumSum</c>.
    /// NOTE(review): <see cref="AreAllActionsMasked"/> reads index <c>branch + 1</c>, so this assumes
    /// CumSum returns a prefix-sum array of length NumBranches + 1 starting at 0 — confirm.
    /// </summary>
    int[] m_StartingActionIndices;

    /// <summary>Number of actions in each absolute discrete branch, across all actuators.</summary>
    int[] m_BranchSizes;

    /// <summary>Flat mask; <c>true</c> means the action is masked (disabled). Null until first use.</summary>
    bool[] m_CurrentMask;

    /// <summary>Actuators whose discrete branch sizes define the mask layout when none are supplied.</summary>
    readonly IList<IActuator> m_Actuators;

    /// <summary>Total number of discrete actions across all branches; the length of the flat mask.</summary>
    readonly int m_SumOfDiscreteBranchSizes;

    /// <summary>Total number of discrete branches across all actuators.</summary>
    readonly int m_NumBranches;

    /// <summary>
    /// Offset added to the <c>branch</c> argument of <see cref="SetActionEnabled"/> to obtain the
    /// absolute branch index inside this mask.
    /// </summary>
    public int CurrentBranchOffset { get; set; }

    /// <summary>
    /// Creates a mask covering the discrete branches of <paramref name="actuators"/>.
    /// </summary>
    /// <param name="actuators">Actuators whose <c>ActionSpec.BranchSizes</c> are concatenated (in list
    /// order) to build the branch layout when <paramref name="branchSizes"/> is null.</param>
    /// <param name="sumOfDiscreteBranchSizes">Total number of discrete actions; the flat mask length.</param>
    /// <param name="numBranches">Total number of discrete branches.</param>
    /// <param name="branchSizes">Optional precomputed size of each branch; when null the sizes are
    /// gathered from <paramref name="actuators"/> on first use.</param>
    internal ActuatorDiscreteActionMask(IList<IActuator> actuators, int sumOfDiscreteBranchSizes, int numBranches, int[] branchSizes = null)
    {
        m_Actuators = actuators;
        m_SumOfDiscreteBranchSizes = sumOfDiscreteBranchSizes;
        m_NumBranches = numBranches;
        m_BranchSizes = branchSizes;
    }

    /// <summary>
    /// Enables or disables a single action of a discrete branch.
    /// </summary>
    /// <param name="branch">Branch index, relative to <see cref="CurrentBranchOffset"/>.</param>
    /// <param name="actionIndex">Index of the action within that branch.</param>
    /// <param name="isEnabled"><c>false</c> masks the action; <c>true</c> unmasks it.</param>
    /// <exception cref="UnityAgentsException">In DEBUG builds, when the offset-adjusted branch or the
    /// action index is out of range.</exception>
    public void SetActionEnabled(int branch, int actionIndex, bool isEnabled)
    {
        LazyInitialize();
        var absoluteBranch = CurrentBranchOffset + branch;
#if DEBUG
        if (branch < 0 || absoluteBranch < 0 || actionIndex < 0)
        {
            throw new UnityAgentsException(
                "Invalid Action Masking: branch and action indices must be non-negative.");
        }

        // Validate the offset-adjusted branch: checking only the relative `branch` lets an index that is
        // in range on its own, but out of range once CurrentBranchOffset is added, slip through.
        if (absoluteBranch >= m_NumBranches || actionIndex >= m_BranchSizes[absoluteBranch])
        {
            throw new UnityAgentsException(
                "Invalid Action Masking: Action Mask is too large for specified branch.");
        }
#endif
        // The mask stores "masked" flags, hence the negation.
        m_CurrentMask[actionIndex + m_StartingActionIndices[absoluteBranch]] = !isEnabled;
    }

    /// <summary>
    /// Allocates the branch-size table (if not supplied to the constructor), the flat mask, and the
    /// starting-index table, each only if it has not been created yet.
    /// </summary>
    void LazyInitialize()
    {
        if (m_BranchSizes == null)
        {
            m_BranchSizes = new int[m_NumBranches];
            var start = 0;
            for (var i = 0; i < m_Actuators.Count; i++)
            {
                var branchSizes = m_Actuators[i].ActionSpec.BranchSizes;
                // Defensive: treat a null BranchSizes as "no discrete branches" rather than passing null
                // to Array.Copy. NOTE(review): confirm whether ActionSpec can actually expose null here.
                if (branchSizes == null)
                {
                    continue;
                }

                Array.Copy(branchSizes, 0, m_BranchSizes, start, branchSizes.Length);
                start += branchSizes.Length;
            }
        }

        if (m_CurrentMask == null)
        {
            m_CurrentMask = new bool[m_SumOfDiscreteBranchSizes];
        }

        if (m_StartingActionIndices == null)
        {
            m_StartingActionIndices = Utilities.CumSum(m_BranchSizes);
        }
    }

    /// <summary>
    /// Returns the flat mask (<c>true</c> = masked). In DEBUG builds, first verifies that no branch has
    /// every action masked.
    /// </summary>
    /// <returns>The internal mask array (not a copy), or <c>null</c> if <see cref="SetActionEnabled"/>
    /// has never been called, since the mask is only allocated there.</returns>
    /// <exception cref="UnityAgentsException">In DEBUG builds, when some branch is fully masked.</exception>
    internal bool[] GetMask()
    {
#if DEBUG
        if (m_CurrentMask != null)
        {
            AssertMask();
        }
#endif
        return m_CurrentMask;
    }

    /// <summary>
    /// In DEBUG builds, throws if any branch has all of its actions masked; does nothing otherwise.
    /// </summary>
    /// <exception cref="UnityAgentsException">When a branch is fully masked (DEBUG only).</exception>
    void AssertMask()
    {
#if DEBUG
        for (var branchIndex = 0; branchIndex < m_NumBranches; branchIndex++)
        {
            if (AreAllActionsMasked(branchIndex))
            {
                throw new UnityAgentsException(
                    "Invalid Action Masking : All the actions of branch " + branchIndex +
                    " are masked.");
            }
        }
#endif
    }

    /// <summary>
    /// Unmasks every action by clearing the flat mask in place, if it has been allocated.
    /// </summary>
    internal void ResetMask()
    {
        if (m_CurrentMask != null)
        {
            Array.Clear(m_CurrentMask, 0, m_CurrentMask.Length);
        }
    }

    /// <summary>
    /// Reports whether every action of the given absolute branch is masked.
    /// </summary>
    /// <param name="branch">Absolute branch index (no <see cref="CurrentBranchOffset"/> applied).</param>
    /// <returns><c>false</c> if the mask has not been allocated or any action in the branch is enabled;
    /// otherwise <c>true</c> (including for a zero-size branch).</returns>
    bool AreAllActionsMasked(int branch)
    {
        if (m_CurrentMask == null)
        {
            return false;
        }

        var start = m_StartingActionIndices[branch];
        var end = m_StartingActionIndices[branch + 1];
        // The branch is fully masked exactly when no `false` (enabled) entry exists in its slice.
        return Array.IndexOf(m_CurrentMask, false, start, end - start) < 0;
    }
}
|
} |
|
|