File size: 5,697 Bytes
05c9ac2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 |
using NUnit.Framework;
using UnityEngine;
using System.IO.Abstractions.TestingHelpers;
using System.Reflection;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.CommunicatorObjects;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Demonstrations;
using Unity.MLAgents.Policies;
using Unity.MLAgents.Utils.Tests;
namespace Unity.MLAgents.Tests
{
[TestFixture]
public class DemonstrationTests
{
    const string k_DemoDirectory = "Assets/Demonstrations/";
    const string k_ExtensionType = ".demo";
    const string k_DemoName = "Test";

    /// <summary>
    /// Disposes any Academy left initialized by a previous test so each test
    /// starts without an initialized Academy instance.
    /// </summary>
    [SetUp]
    public void SetUp()
    {
        if (Academy.IsInitialized)
        {
            Academy.Instance.Dispose();
        }
    }

    /// <summary>
    /// Checks that <c>DemonstrationRecorder.SanitizeName</c> changes a string containing
    /// symbols and produces the expected value when limited to 6 characters.
    /// </summary>
    [Test]
    public void TestSanitization()
    {
        const string dirtyString = "abc1234567&!@";
        const string knownCleanString = "abc123";
        var cleanString = DemonstrationRecorder.SanitizeName(dirtyString, 6);
        Assert.AreNotEqual(dirtyString, cleanString);
        // NUnit's AreEqual takes (expected, actual); keeping that order makes
        // failure messages report the values the right way round.
        Assert.AreEqual(knownCleanString, cleanString);
    }

    /// <summary>
    /// Checks that lazily initializing a <see cref="DemonstrationRecorder"/> creates the
    /// demo directory and file on the (mock) file system, and that recording, repeated
    /// closing, and recording after close do not throw.
    /// </summary>
    [Test]
    public void TestStoreInitialize()
    {
        var fileSystem = new MockFileSystem();

        var gameobj = new GameObject("gameObj");
        var bp = gameobj.AddComponent<BehaviorParameters>();
        bp.BrainParameters.VectorObservationSize = 3;
        bp.BrainParameters.NumStackedVectorObservations = 2;
        bp.BrainParameters.VectorActionDescriptions = new[] { "TestActionA", "TestActionB" };
        bp.BrainParameters.ActionSpec = ActionSpec.MakeDiscrete(2, 2);
        gameobj.AddComponent<TestAgent>();

        Assert.IsFalse(fileSystem.Directory.Exists(k_DemoDirectory));

        var demoRec = gameobj.AddComponent<DemonstrationRecorder>();
        demoRec.Record = true;
        demoRec.DemonstrationName = k_DemoName;
        demoRec.DemonstrationDirectory = k_DemoDirectory;
        var demoWriter = demoRec.LazyInitialize(fileSystem);

        Assert.IsTrue(fileSystem.Directory.Exists(k_DemoDirectory));
        Assert.IsTrue(fileSystem.FileExists(k_DemoDirectory + k_DemoName + k_ExtensionType));

        var agentInfo = new AgentInfo
        {
            reward = 1f,
            discreteActionMasks = new[] { false, true },
            done = true,
            episodeId = 5,
            maxStepReached = true,
            storedActions = new ActionBuffers(null, new[] { 0, 1 }),
        };

        demoWriter.Record(agentInfo, new System.Collections.Generic.List<ISensor>());
        demoRec.Close();

        // Make sure close can be called multiple times
        demoWriter.Close();
        demoRec.Close();

        // Make sure trying to write after closing doesn't raise an error.
        demoWriter.Record(agentInfo, new System.Collections.Generic.List<ISensor>());
    }

    /// <summary>
    /// Test agent that writes the fixed observations 1, 2, 3 into its vector sensor
    /// and counts how often observations were collected.
    /// </summary>
    public class ObservationAgent : TestAgent
    {
        public override void CollectObservations(VectorSensor sensor)
        {
            collectObservationsCalls += 1;
            sensor.AddObservation(1f);
            sensor.AddObservation(2f);
            sensor.AddObservation(3f);
        }
    }

    /// <summary>
    /// Steps an <see cref="ObservationAgent"/> once with a recorder attached, then reads
    /// the written demo file back and checks the recorded vector observations are 1, 2, 3.
    /// </summary>
    [Test]
    public void TestAgentWrite()
    {
        const string demoName = "TestBrain";

        var agentGo1 = new GameObject("TestAgent");
        var bpA = agentGo1.AddComponent<BehaviorParameters>();
        bpA.BrainParameters.VectorObservationSize = 3;
        bpA.BrainParameters.NumStackedVectorObservations = 1;
        bpA.BrainParameters.VectorActionDescriptions = new[] { "TestActionA", "TestActionB" };
        bpA.BrainParameters.ActionSpec = ActionSpec.MakeDiscrete(2, 2);

        // AddComponent already returns the new component; no GetComponent lookup needed.
        var agent1 = agentGo1.AddComponent<ObservationAgent>();
        var demoRecorder = agentGo1.AddComponent<DemonstrationRecorder>();

        var fileSystem = new MockFileSystem();
        demoRecorder.DemonstrationDirectory = k_DemoDirectory;
        demoRecorder.DemonstrationName = demoName;
        demoRecorder.Record = true;
        demoRecorder.LazyInitialize(fileSystem);

        var agentEnableMethod = typeof(Agent).GetMethod("OnEnable",
            BindingFlags.Instance | BindingFlags.NonPublic);
        var agentSendInfo = typeof(Agent).GetMethod("SendInfo",
            BindingFlags.Instance | BindingFlags.NonPublic);
        // Fail loudly if these non-public members are renamed, rather than silently
        // skipping the calls and failing later with a confusing parse error.
        Assert.IsNotNull(agentEnableMethod, "Agent.OnEnable not found via reflection");
        Assert.IsNotNull(agentSendInfo, "Agent.SendInfo not found via reflection");

        agentEnableMethod.Invoke(agent1, new object[] { });

        // Step the agent
        agent1.RequestDecision();
        agentSendInfo.Invoke(agent1, new object[] { });
        demoRecorder.Close();

        // Read back the demo file and make sure observations were written.
        var demoPath = k_DemoDirectory + demoName + k_ExtensionType;
        using (var reader = fileSystem.File.OpenRead(demoPath))
        {
            // NOTE(review): presumably skips the fixed-size metadata header written by
            // DemonstrationWriter; the +1 offset mirrors the original test — confirm.
            reader.Seek(DemonstrationWriter.MetaDataBytes + 1, System.IO.SeekOrigin.Begin);
            BrainParametersProto.Parser.ParseDelimitedFrom(reader);
            var agentInfoProto = AgentInfoActionPairProto.Parser.ParseDelimitedFrom(reader).AgentInfo;

            // Index 2 skips the dummy sensors that precede the vector observations
            // (presumably added by TestAgent — verify if that class changes).
            var vecObs = agentInfoProto.Observations[2].FloatData.Data;
            Assert.AreEqual(bpA.BrainParameters.VectorObservationSize, vecObs.Count);
            for (var i = 0; i < vecObs.Count; i++)
            {
                Assert.AreEqual((float)i + 1, vecObs[i]);
            }
        }
    }
}
}
|