File size: 5,697 Bytes
05c9ac2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
using NUnit.Framework;
using UnityEngine;
using System.IO.Abstractions.TestingHelpers;
using System.Reflection;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.CommunicatorObjects;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Demonstrations;
using Unity.MLAgents.Policies;
using Unity.MLAgents.Utils.Tests;

namespace Unity.MLAgents.Tests
{
    [TestFixture]
    public class DemonstrationTests
    {
        const string k_DemoDirectory = "Assets/Demonstrations/";
        const string k_ExtensionType = ".demo";
        const string k_DemoName = "Test";

        /// <summary>
        /// Runs before each test. If an Academy is already initialized, it is disposed
        /// so the test does not start with a previously initialized instance.
        /// </summary>
        [SetUp]
        public void SetUp()
        {
            // Nothing to tear down when no Academy has been initialized.
            if (!Academy.IsInitialized)
            {
                return;
            }

            Academy.Instance.Dispose();
        }

        /// <summary>
        /// Checks that <c>DemonstrationRecorder.SanitizeName</c>, called with a limit of 6,
        /// turns <c>abc1234567&amp;!@</c> into <c>abc123</c>.
        /// </summary>
        [Test]
        public void TestSanitization()
        {
            const string dirtyString = "abc1234567&!@";
            const string knownCleanString = "abc123";
            var cleanString = DemonstrationRecorder.SanitizeName(dirtyString, 6);
            Assert.AreNotEqual(dirtyString, cleanString);
            // NUnit's AreEqual takes (expected, actual); keeping that order makes a
            // failure message report the known-good value as "Expected".
            Assert.AreEqual(knownCleanString, cleanString);
        }

        /// <summary>
        /// Initializes a DemonstrationRecorder against a mock file system, checks that the
        /// demonstration directory and file get created, and then calls Record and Close
        /// in several orders, including Record after the writer has been closed.
        /// </summary>
        [Test]
        public void TestStoreInitialize()
        {
            var mockFs = new MockFileSystem();
            var recorderObject = new GameObject("gameObj");

            var behaviorParams = recorderObject.AddComponent<BehaviorParameters>();
            behaviorParams.BrainParameters.VectorObservationSize = 3;
            behaviorParams.BrainParameters.NumStackedVectorObservations = 2;
            behaviorParams.BrainParameters.VectorActionDescriptions = new[] { "TestActionA", "TestActionB" };
            behaviorParams.BrainParameters.ActionSpec = ActionSpec.MakeDiscrete(2, 2);

            recorderObject.AddComponent<TestAgent>();

            // The demo directory must not exist before the recorder is initialized.
            Assert.IsFalse(mockFs.Directory.Exists(k_DemoDirectory));

            var recorder = recorderObject.AddComponent<DemonstrationRecorder>();
            recorder.Record = true;
            recorder.DemonstrationName = k_DemoName;
            recorder.DemonstrationDirectory = k_DemoDirectory;
            var writer = recorder.LazyInitialize(mockFs);

            // After initialization both the directory and the demo file must exist.
            const string demoFilePath = k_DemoDirectory + k_DemoName + k_ExtensionType;
            Assert.IsTrue(mockFs.Directory.Exists(k_DemoDirectory));
            Assert.IsTrue(mockFs.FileExists(demoFilePath));

            var info = new AgentInfo();
            info.reward = 1f;
            info.discreteActionMasks = new[] { false, true };
            info.done = true;
            info.episodeId = 5;
            info.maxStepReached = true;
            info.storedActions = new ActionBuffers(null, new[] { 0, 1 });

            writer.Record(info, new System.Collections.Generic.List<ISensor>());
            recorder.Close();

            // Closing again, through either the writer or the recorder, must be harmless.
            writer.Close();
            recorder.Close();

            // Recording after everything is closed must not raise an error either.
            writer.Record(info, new System.Collections.Generic.List<ISensor>());
        }

        /// <summary>
        /// TestAgent variant whose CollectObservations counts each call and adds the
        /// float observations 1, 2 and 3 to the sensor, in that order.
        /// </summary>
        public class ObservationAgent : TestAgent
        {
            /// <summary>
            /// Increments the call counter and writes three fixed float observations.
            /// </summary>
            /// <param name="sensor">Sensor that receives the observations.</param>
            public override void CollectObservations(VectorSensor sensor)
            {
                collectObservationsCalls++;
                foreach (var value in new[] { 1f, 2f, 3f })
                {
                    sensor.AddObservation(value);
                }
            }
        }

        /// <summary>
        /// Records one step of an ObservationAgent into a mock file system, then reads the
        /// demonstration file back and checks that the vector observations 1, 2, 3 were written.
        /// </summary>
        [Test]
        public void TestAgentWrite()
        {
            const string demoName = "TestBrain";

            var agentGo1 = new GameObject("TestAgent");
            var bpA = agentGo1.AddComponent<BehaviorParameters>();
            bpA.BrainParameters.VectorObservationSize = 3;
            bpA.BrainParameters.NumStackedVectorObservations = 1;
            bpA.BrainParameters.VectorActionDescriptions = new[] { "TestActionA", "TestActionB" };
            bpA.BrainParameters.ActionSpec = ActionSpec.MakeDiscrete(2, 2);

            // AddComponent already returns the new component; no separate GetComponent lookup is needed.
            var agent1 = agentGo1.AddComponent<ObservationAgent>();
            var demoRecorder = agentGo1.AddComponent<DemonstrationRecorder>();

            var fileSystem = new MockFileSystem();
            demoRecorder.DemonstrationDirectory = k_DemoDirectory;
            demoRecorder.DemonstrationName = demoName;
            demoRecorder.Record = true;
            demoRecorder.LazyInitialize(fileSystem);

            // OnEnable and SendInfo are looked up as non-public instance methods of Agent,
            // so they have to be invoked through reflection.
            var agentEnableMethod = typeof(Agent).GetMethod("OnEnable",
                BindingFlags.Instance | BindingFlags.NonPublic);
            var agentSendInfo = typeof(Agent).GetMethod("SendInfo",
                BindingFlags.Instance | BindingFlags.NonPublic);

            // Fail right here if a lookup breaks (e.g. after a rename) instead of silently
            // skipping the call and failing later with a misleading error.
            Assert.IsNotNull(agentEnableMethod, "Agent.OnEnable was not found via reflection.");
            Assert.IsNotNull(agentSendInfo, "Agent.SendInfo was not found via reflection.");

            agentEnableMethod.Invoke(agent1, new object[] { });

            // Step the agent
            agent1.RequestDecision();
            agentSendInfo.Invoke(agent1, new object[] { });

            demoRecorder.Close();

            // Read back the demo file and make sure observations were written.
            // The stream is disposed once the assertions are done.
            using (var reader = fileSystem.File.OpenRead(k_DemoDirectory + demoName + k_ExtensionType))
            {
                // Skip past the metadata region; the offset is measured from the start of the file.
                reader.Seek(DemonstrationWriter.MetaDataBytes + 1, System.IO.SeekOrigin.Begin);

                // The brain parameters message is parsed only to advance the stream; its value is not used.
                BrainParametersProto.Parser.ParseDelimitedFrom(reader);

                var agentInfoProto = AgentInfoActionPairProto.Parser.ParseDelimitedFrom(reader).AgentInfo;
                var obs = agentInfoProto.Observations[2]; // skip dummy sensors
                var vecObs = obs.FloatData.Data;
                Assert.AreEqual(bpA.BrainParameters.VectorObservationSize, vecObs.Count);
                for (var i = 0; i < vecObs.Count; i++)
                {
                    // ObservationAgent writes 1f, 2f, 3f, so element i must equal i + 1.
                    Assert.AreEqual((float)i + 1, vecObs[i]);
                }
            }
        }
    }
}