File size: 3,994 Bytes
05c9ac2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
using System;
using System.Collections.Generic;
using System.Linq;
using NUnit.Framework;
using Unity.MLAgents.Actuators;
using Assert = UnityEngine.Assertions.Assert;

namespace Unity.MLAgents.Tests.Actuators
{
    /// <summary>
    /// Unit tests for <see cref="VectorActuator"/>: construction from an <see cref="ActionSpec"/>,
    /// forwarding of received actions to an <see cref="IActionReceiver"/>, resetting buffered
    /// action data, discrete action masking, and heuristic dispatch.
    /// </summary>
    [TestFixture]
    public class VectorActuatorTests
    {
        /// <summary>
        /// Test double that records the calls a <see cref="VectorActuator"/> forwards to it.
        /// </summary>
        sealed class TestActionReceiver : IActionReceiver, IHeuristicProvider
        {
            /// <summary>The buffers passed to the most recent <see cref="OnActionReceived"/> call.</summary>
            public ActionBuffers LastActionBuffers;

            /// <summary>Discrete branch into which <see cref="WriteDiscreteActionMask"/> writes.</summary>
            public int Branch;

            /// <summary>
            /// Action indices within <see cref="Branch"/> to disable. Defaults to empty so that writing
            /// a mask before a test configures it is a no-op rather than a NullReferenceException.
            /// </summary>
            public IList<int> Mask = Array.Empty<int>();

            // NOTE(review): never assigned by these tests, so it stays at its default value.
            // It is kept in case an implemented interface requires it; confirm before removing.
            public ActionSpec ActionSpec { get; }

            /// <summary>Set to true once <see cref="Heuristic"/> has been invoked.</summary>
            public bool HeuristicCalled;

            public void OnActionReceived(ActionBuffers actionBuffers)
            {
                LastActionBuffers = actionBuffers;
            }

            public void WriteDiscreteActionMask(IDiscreteActionMask actionMask)
            {
                foreach (var actionIndex in Mask)
                {
                    actionMask.SetActionEnabled(Branch, actionIndex, false);
                }
            }

            public void Heuristic(in ActionBuffers actionBuffersOut)
            {
                HeuristicCalled = true;
            }
        }

        /// <summary>
        /// Builds an <see cref="ActionBuffers"/> with no continuous actions and the given discrete actions.
        /// </summary>
        /// <param name="discreteActions">Discrete action values, one per branch.</param>
        static ActionBuffers MakeDiscreteBuffers(params int[] discreteActions)
        {
            return new ActionBuffers(ActionSegment<float>.Empty,
                new ActionSegment<int>(discreteActions, 0, discreteActions.Length));
        }

        [Test]
        public void TestConstruct()
        {
            var ar = new TestActionReceiver();

            var va = new VectorActuator(ar, ActionSpec.MakeDiscrete(1, 2, 3), "name");

            // Three discrete branches of sizes 1, 2 and 3, for 1 + 2 + 3 = 6 discrete choices in total.
            Assert.AreEqual(3, va.ActionSpec.NumDiscreteActions);
            Assert.AreEqual(6, va.ActionSpec.SumOfDiscreteBranchSizes);
            Assert.AreEqual(0, va.ActionSpec.NumContinuousActions);

            var va1 = new VectorActuator(ar, ActionSpec.MakeContinuous(4), "name");

            Assert.AreEqual(4, va1.ActionSpec.NumContinuousActions);
            Assert.AreEqual(0, va1.ActionSpec.SumOfDiscreteBranchSizes);
            // A continuous-only actuator appends "-Continuous" to the supplied name.
            Assert.AreEqual("name-Continuous", va1.Name);
        }

        [Test]
        public void TestOnActionReceived()
        {
            var ar = new TestActionReceiver();
            var va = new VectorActuator(ar, ActionSpec.MakeDiscrete(1, 2, 3), "name");
            var ab = MakeDiscreteBuffers(0, 1, 1);

            va.OnActionReceived(ab);

            // The actuator must forward the received buffers unchanged to its IActionReceiver.
            Assert.AreEqual(ab, ar.LastActionBuffers);
        }

        [Test]
        public void TestResetData()
        {
            var ar = new TestActionReceiver();
            var va = new VectorActuator(ar, ActionSpec.MakeDiscrete(1, 2, 3), "name");
            va.OnActionReceived(MakeDiscreteBuffers(0, 1, 1));

            va.ResetData();

            // After a reset, both action segments exposed by the actuator are empty.
            Assert.AreEqual(ActionSegment<float>.Empty, va.ActionBuffers.ContinuousActions);
            Assert.AreEqual(ActionSegment<int>.Empty, va.ActionBuffers.DiscreteActions);
        }

        [Test]
        public void TestWriteDiscreteActionMask()
        {
            var ar = new TestActionReceiver();
            var va = new VectorActuator(ar, ActionSpec.MakeDiscrete(1, 2, 3), "name");
            var bdam = new ActuatorDiscreteActionMask(new[] { va },
                va.ActionSpec.SumOfDiscreteBranchSizes, va.ActionSpec.NumDiscreteActions);

            // Flattened layout: branch 0 -> [0], branch 1 -> [1, 2], branch 2 -> [3, 4, 5].
            // A true entry marks a disabled action.
            var groundTruthMask = new[] { false, true, false, false, true, true };

            ar.Branch = 1;
            ar.Mask = new[] { 0 };
            va.WriteDiscreteActionMask(bdam);
            ar.Branch = 2;
            ar.Mask = new[] { 1, 2 };
            va.WriteDiscreteActionMask(bdam);

            var mask = bdam.GetMask();
            Assert.IsTrue(groundTruthMask.SequenceEqual(mask),
                "Unexpected mask: " + string.Join(", ", mask));
        }

        [Test]
        public void TestHeuristic()
        {
            var ar = new TestActionReceiver();
            var va = new VectorActuator(ar, ActionSpec.MakeDiscrete(1, 2, 3), "name");
            // The heuristic receives an output buffer, so give it a fresh zeroed array.
            // Wrapping va.ActionSpec.BranchSizes would pass branch sizes as action values
            // and may alias the spec's own storage.
            var actionsOut = new ActionBuffers(Array.Empty<float>(), new int[va.ActionSpec.NumDiscreteActions]);

            va.Heuristic(actionsOut);

            Assert.IsTrue(ar.HeuristicCalled);
        }
    }
}