ppo-Pyramids-Training
/
com.unity.ml-agents
/Tests
/Editor
/Actuators
/ActuatorDiscreteActionMaskTests.cs
using System.Collections.Generic; | |
using NUnit.Framework; | |
using Unity.MLAgents.Actuators; | |
namespace Unity.MLAgents.Tests.Actuators | |
{ | |
[ | ]|
public class ActuatorDiscreteActionMaskTests | |
{ | |
[ | ]|
public void Construction() | |
{ | |
var masker = new ActuatorDiscreteActionMask(new List<IActuator>(), 0, 0); | |
Assert.IsNotNull(masker); | |
} | |
[ | ]|
public void NullMask() | |
{ | |
var masker = new ActuatorDiscreteActionMask(new List<IActuator>(), 0, 0); | |
var mask = masker.GetMask(); | |
Assert.IsNull(mask); | |
} | |
[ | ]|
public void FirstBranchMask() | |
{ | |
var actuator1 = new TestActuator(ActionSpec.MakeDiscrete(new[] { 4, 5, 6 }), "actuator1"); | |
var masker = new ActuatorDiscreteActionMask(new IActuator[] { actuator1 }, 15, 3); | |
var mask = masker.GetMask(); | |
Assert.IsNull(mask); | |
masker.SetActionEnabled(0, 1, false); | |
masker.SetActionEnabled(0, 2, false); | |
masker.SetActionEnabled(0, 3, false); | |
mask = masker.GetMask(); | |
Assert.IsFalse(mask[0]); | |
Assert.IsTrue(mask[1]); | |
Assert.IsTrue(mask[2]); | |
Assert.IsTrue(mask[3]); | |
Assert.IsFalse(mask[4]); | |
Assert.AreEqual(mask.Length, 15); | |
} | |
[ | ]|
public void CanOverwriteMask() | |
{ | |
var actuator1 = new TestActuator(ActionSpec.MakeDiscrete(new[] { 4, 5, 6 }), "actuator1"); | |
var masker = new ActuatorDiscreteActionMask(new IActuator[] { actuator1 }, 15, 3); | |
masker.SetActionEnabled(0, 1, false); | |
var mask = masker.GetMask(); | |
Assert.IsTrue(mask[1]); | |
masker.SetActionEnabled(0, 1, true); | |
Assert.IsFalse(mask[1]); | |
} | |
[ | ]|
public void SecondBranchMask() | |
{ | |
var actuator1 = new TestActuator(ActionSpec.MakeDiscrete(new[] { 4, 5, 6 }), "actuator1"); | |
var masker = new ActuatorDiscreteActionMask(new[] { actuator1 }, 15, 3); | |
masker.SetActionEnabled(1, 1, false); | |
masker.SetActionEnabled(1, 2, false); | |
masker.SetActionEnabled(1, 3, false); | |
var mask = masker.GetMask(); | |
Assert.IsFalse(mask[0]); | |
Assert.IsFalse(mask[4]); | |
Assert.IsTrue(mask[5]); | |
Assert.IsTrue(mask[6]); | |
Assert.IsTrue(mask[7]); | |
Assert.IsFalse(mask[8]); | |
Assert.IsFalse(mask[9]); | |
} | |
[ | ]|
public void MaskReset() | |
{ | |
var actuator1 = new TestActuator(ActionSpec.MakeDiscrete(new[] { 4, 5, 6 }), "actuator1"); | |
var masker = new ActuatorDiscreteActionMask(new IActuator[] { actuator1 }, 15, 3); | |
masker.SetActionEnabled(1, 1, false); | |
masker.SetActionEnabled(1, 2, false); | |
masker.SetActionEnabled(1, 3, false); | |
masker.ResetMask(); | |
var mask = masker.GetMask(); | |
for (var i = 0; i < 15; i++) | |
{ | |
Assert.IsFalse(mask[i]); | |
} | |
} | |
[ | ]|
public void ThrowsError() | |
{ | |
var actuator1 = new TestActuator(ActionSpec.MakeDiscrete(new[] { 4, 5, 6 }), "actuator1"); | |
var masker = new ActuatorDiscreteActionMask(new IActuator[] { actuator1 }, 15, 3); | |
Assert.Catch<UnityAgentsException>( | |
() => masker.SetActionEnabled(0, 5, false)); | |
Assert.Catch<UnityAgentsException>( | |
() => masker.SetActionEnabled(1, 5, false)); | |
masker.SetActionEnabled(2, 5, false); | |
Assert.Catch<UnityAgentsException>( | |
() => masker.SetActionEnabled(3, 1, false)); | |
masker.GetMask(); | |
masker.ResetMask(); | |
masker.SetActionEnabled(0, 0, false); | |
masker.SetActionEnabled(0, 1, false); | |
masker.SetActionEnabled(0, 2, false); | |
masker.SetActionEnabled(0, 3, false); | |
Assert.Catch<UnityAgentsException>( | |
() => masker.GetMask()); | |
} | |
[ | ]|
public void MultipleMaskEdit() | |
{ | |
var actuator1 = new TestActuator(ActionSpec.MakeDiscrete(new[] { 4, 5, 6 }), "actuator1"); | |
var masker = new ActuatorDiscreteActionMask(new IActuator[] { actuator1 }, 15, 3); | |
masker.SetActionEnabled(0, 0, false); | |
masker.SetActionEnabled(0, 1, false); | |
masker.SetActionEnabled(0, 3, false); | |
masker.SetActionEnabled(2, 1, false); | |
var mask = masker.GetMask(); | |
for (var i = 0; i < 15; i++) | |
{ | |
if ((i == 0) || (i == 1) || (i == 3) || (i == 10)) | |
{ | |
Assert.IsTrue(mask[i]); | |
} | |
else | |
{ | |
Assert.IsFalse(mask[i]); | |
} | |
} | |
} | |
} | |
} | |