File size: 4,022 Bytes
05c9ac2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 |
using NUnit.Framework;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Policies;
using UnityEngine;
namespace Unity.MLAgents.Tests.Policies
{
    /// <summary>
    /// Checks that the action buffers handed to heuristics (both the Agent's own Heuristic and an
    /// IActuator's Heuristic) arrive zeroed on every decision step, even though each heuristic call
    /// writes non-zero values into them.
    /// </summary>
    [TestFixture]
    public class HeuristicPolicyTest
    {
        /// <summary>
        /// Dispose any Academy left initialized by an earlier test so this test does not reuse it.
        /// </summary>
        [SetUp]
        public void SetUp()
        {
            if (Academy.IsInitialized)
            {
                Academy.Instance.Dispose();
            }
        }

        /// <summary>
        /// Assert that the action buffers are initialized to zero, and then set them to non-zero values.
        /// Because every element is overwritten with a non-zero value, a later call only sees zeros
        /// again if the buffers were cleared in between.
        /// </summary>
        /// <param name="actionsOut">
        /// Buffers passed to a heuristic; each continuous and discrete element is checked for zero and
        /// then overwritten in place with 1.
        /// </param>
        static void CheckAndSetBuffer(in ActionBuffers actionsOut)
        {
            var continuousActions = actionsOut.ContinuousActions;
            for (var continuousIndex = 0; continuousIndex < continuousActions.Length; continuousIndex++)
            {
                // NUnit's AreEqual takes (expected, actual); the order determines the failure message.
                Assert.AreEqual(0.0f, continuousActions[continuousIndex],
                    $"Continuous action {continuousIndex} was not cleared before the heuristic call.");
                continuousActions[continuousIndex] = 1.0f;
            }

            var discreteActions = actionsOut.DiscreteActions;
            for (var discreteIndex = 0; discreteIndex < discreteActions.Length; discreteIndex++)
            {
                Assert.AreEqual(0, discreteActions[discreteIndex],
                    $"Discrete action {discreteIndex} was not cleared before the heuristic call.");
                discreteActions[discreteIndex] = 1;
            }
        }

        /// <summary>
        /// Agent whose heuristic validates/dirties its action buffers and counts how often it runs.
        /// </summary>
        class ActionClearedAgent : Agent
        {
            /// <summary>Number of times <c>Heuristic</c> has been invoked on this agent.</summary>
            public int HeuristicCalls;

            public override void Heuristic(in ActionBuffers actionsOut)
            {
                CheckAndSetBuffer(actionsOut);
                HeuristicCalls++;
            }
        }

        /// <summary>
        /// Actuator whose heuristic validates/dirties its action buffers and counts how often it runs.
        /// All other IActuator members are no-ops.
        /// </summary>
        class ActionClearedActuator : IActuator
        {
            /// <summary>Number of times <c>Heuristic</c> has been invoked on this actuator.</summary>
            public int HeuristicCalls;

            /// <param name="actionSpec">Action layout reported through <see cref="ActionSpec"/>.</param>
            public ActionClearedActuator(ActionSpec actionSpec)
            {
                ActionSpec = actionSpec;
                Name = GetType().Name;
            }

            public void OnActionReceived(ActionBuffers actionBuffers)
            {
            }

            public void WriteDiscreteActionMask(IDiscreteActionMask actionMask)
            {
            }

            public void Heuristic(in ActionBuffers actionBuffersOut)
            {
                CheckAndSetBuffer(actionBuffersOut);
                HeuristicCalls++;
            }

            public ActionSpec ActionSpec { get; }

            public string Name { get; }

            public void ResetData()
            {
            }
        }

        /// <summary>
        /// Component that creates a single <see cref="ActionClearedActuator"/> with 2 continuous actions
        /// and two discrete branches of size 3, and keeps a reference to it for assertions.
        /// </summary>
        class ActionClearedActuatorComponent : ActuatorComponent
        {
            /// <summary>The actuator created by the last <see cref="CreateActuators"/> call; null before then.</summary>
            public ActionClearedActuator ActionClearedActuator;

            public override IActuator[] CreateActuators()
            {
                ActionClearedActuator = new ActionClearedActuator(ActionSpec);
                return new IActuator[] { ActionClearedActuator };
            }

            public override ActionSpec ActionSpec { get; } = new ActionSpec(2, new[] { 3, 3 });
        }

        /// <summary>
        /// Runs several HeuristicOnly decision steps and checks that both the agent's and the actuator's
        /// heuristics run once per step. The zero-buffer checks inside CheckAndSetBuffer fail on any
        /// step where the buffers were not cleared after the previous step wrote 1s into them.
        /// </summary>
        [Test]
        public void TestActionsCleared()
        {
            var gameObj = new GameObject();
            var agent = gameObj.AddComponent<ActionClearedAgent>();
            // NOTE(review): BehaviorParameters is presumably added automatically alongside the Agent
            // (e.g. via RequireComponent); it is not added explicitly here.
            var behaviorParameters = agent.GetComponent<BehaviorParameters>();
            // Agent's own actions: 1 continuous action and one discrete branch of size 4.
            behaviorParameters.BrainParameters.ActionSpec = new ActionSpec(1, new[] { 4 });
            behaviorParameters.BrainParameters.VectorObservationSize = 0;
            behaviorParameters.BehaviorType = BehaviorType.HeuristicOnly;
            // Added before LazyInitialize(), presumably so its actuator is created during initialization.
            var actuatorComponent = gameObj.AddComponent<ActionClearedActuatorComponent>();
            agent.LazyInitialize();

            const int k_NumSteps = 5;
            for (var i = 0; i < k_NumSteps; i++)
            {
                agent.RequestDecision();
                Academy.Instance.EnvironmentStep();
            }

            // NUnit's AreEqual takes (expected, actual).
            Assert.AreEqual(k_NumSteps, agent.HeuristicCalls);
            Assert.AreEqual(k_NumSteps, actuatorComponent.ActionClearedActuator.HeuristicCalls);
        }
    }
}
|