using System.Collections;
using System.Collections.Generic;
using Unity.MLAgents;
using Unity.MLAgents.Policies;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Sensors.Reflection;
using NUnit.Framework;
using Unity.MLAgents.Actuators;
using UnityEngine;
using UnityEngine.TestTools;
namespace Tests
{
public class PublicApiAgent : Agent
{
public int numHeuristicCalls;
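// Fields and properties tagged [Observable] are picked up via reflection and added as
// observations when ObservableAttributeHandling on the BehaviorParameters allows it
// (ExamineAll is used in the test below).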
[Observable]
public float ObservableFloat;
public override void Heuristic(in ActionBuffers actionsOut)
{
numHeuristicCalls++;
base.Heuristic(actionsOut);
}
}
// Simple SensorComponent that sets up a StackingSensor
public class StackingComponent : SensorComponent
{
public SensorComponent wrappedComponent;
public int numStacks;
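// Each StackingSensor wraps one of the inner sensors and stacks its numStacks most
// recent observations into a single, larger observation.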
public override ISensor[] CreateSensors()
{
var wrappedSensors = wrappedComponent.CreateSensors();
var sensorsOut = new ISensor[wrappedSensors.Length];
for (var i = 0; i < wrappedSensors.Length; i++)
{
sensorsOut[i] = new StackingSensor(wrappedSensors[i], numStacks);
}
return sensorsOut;
}
}
public class RuntimeApiTest
{
[SetUp]
public static void Setup()
{
if (Academy.IsInitialized)
{
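// Dispose of any Academy left over from a previous test so each run starts from a clean state.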
Academy.Instance.Dispose();
}
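// Disable automatic stepping; the test drives the simulation manually via EnvironmentStep() below.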
Academy.Instance.AutomaticSteppingEnabled = false;
}
[UnityTest]
public IEnumerator RuntimeApiTestWithEnumeratorPasses()
{
Academy.Instance.InferenceSeed = 1337;
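// Seed for stochastic inference. No model is loaded in this test, so this mainly exercises the API.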
var gameObject = new GameObject();
var behaviorParams = gameObject.AddComponent<BehaviorParameters>();
behaviorParams.BrainParameters.VectorObservationSize = 3;
behaviorParams.BrainParameters.NumStackedVectorObservations = 2;
behaviorParams.BrainParameters.VectorActionDescriptions = new[] { "Continuous1", "TestActionA", "TestActionB" };
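// Hybrid action space: one continuous action plus two discrete branches with two choices each.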
behaviorParams.BrainParameters.ActionSpec = new ActionSpec(1, new[] { 2, 2 });
behaviorParams.BehaviorName = "TestBehavior";
behaviorParams.TeamId = 42;
behaviorParams.UseChildSensors = true;
behaviorParams.DeterministicInference = false;
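// ExamineAll makes the [Observable] ObservableFloat field on PublicApiAgent contribute an observation.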
behaviorParams.ObservableAttributeHandling = ObservableAttributeOptions.ExamineAll;
// Can't actually create an Agent with InferenceOnly and no model, so use the Default behavior type here.
behaviorParams.BehaviorType = BehaviorType.Default;
#if MLA_UNITY_PHYSICS_MODULE
var sensorComponent = gameObject.AddComponent<RayPerceptionSensorComponent3D>();
sensorComponent.SensorName = "ray3d";
sensorComponent.DetectableTags = new List<string> { "Player", "Respawn" };
sensorComponent.RaysPerDirection = 3;
// Make a StackingSensor that wraps the RayPerceptionSensorComponent3D
// This isn't necessarily practical, just to ensure that it can be done
var wrappingSensorComponent = gameObject.AddComponent<StackingComponent>();
wrappingSensorComponent.wrappedComponent = sensorComponent;
wrappingSensorComponent.numStacks = 3;
// ISensor isn't set up yet.
Assert.IsNull(sensorComponent.RaySensor);
#endif
// Make sure we can set the behavior type correctly after the agent is initialized
// (this creates a new policy).
behaviorParams.BehaviorType = BehaviorType.HeuristicOnly;
// Agent needs to be added after everything else is setup.
var agent = gameObject.AddComponent<PublicApiAgent>();
// DecisionRequester has to be added after Agent.
var decisionRequester = gameObject.AddComponent<DecisionRequester>();
decisionRequester.DecisionPeriod = 2;
decisionRequester.TakeActionsBetweenDecisions = true;
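// With DecisionPeriod = 2, the agent requests a new decision only every other Academy step;
// the heuristic call counts asserted below depend on this cadence.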
#if MLA_UNITY_PHYSICS_MODULE
// Initialization should set up the sensors
Assert.IsNotNull(sensorComponent.RaySensor);
#endif
// Let's change the inference device
var otherDevice = behaviorParams.InferenceDevice == InferenceDevice.CPU ? InferenceDevice.GPU : InferenceDevice.CPU;
agent.SetModel(behaviorParams.BehaviorName, behaviorParams.Model, otherDevice);
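// Re-setting the (still null) model with a different inference device forces the policy to be
// rebuilt; the agent keeps using the heuristic policy because BehaviorType is HeuristicOnly.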
agent.AddReward(1.0f);
// skip a frame.
yield return null;
Academy.Instance.EnvironmentStep();
var actions = agent.GetStoredActionBuffers().DiscreteActions;
// default Heuristic implementation should return zero actions.
Assert.AreEqual(new ActionSegment<int>(new[] { 0, 0 }), actions);
Assert.AreEqual(1, agent.numHeuristicCalls);
Academy.Instance.EnvironmentStep();
Assert.AreEqual(1, agent.numHeuristicCalls);
Academy.Instance.EnvironmentStep();
Assert.AreEqual(2, agent.numHeuristicCalls);
}
}
}