File size: 3,237 Bytes
05c9ac2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
using System.Collections.Generic;
using Unity.Barracuda;
using NUnit.Framework;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Inference;

namespace Unity.MLAgents.Tests
{
    public class DiscreteActionOutputApplierTest
    {
        /// <summary>
        /// Checks that <see cref="DiscreteActionOutputApplier"/> writes the per-branch action
        /// indices held in an (agents x branches) tensor into each agent's
        /// <see cref="ActionBuffers"/>, where row i of the tensor belongs to agentIds[i].
        /// </summary>
        [Test]
        public void TestDiscreteApply()
        {
            const int firstAgent = 42;
            const int secondAgent = 1337;
            const int seed = 2020;

            // Two discrete branches: branch 0 has 3 actions, branch 1 has 2.
            var actionSpec = ActionSpec.MakeDiscrete(3, 2);
            var applier = new DiscreteActionOutputApplier(actionSpec, seed, null);

            var agentIds = new List<int> { firstAgent, secondAgent };
            var actionBuffers = new Dictionary<int, ActionBuffers>
            {
                [firstAgent] = new ActionBuffers(actionSpec),
                [secondAgent] = new ActionBuffers(actionSpec),
            };

            // Each row stores the already-chosen action index for every branch.
            var chosenActions = new[]
            {
                2.0f, 1.0f, // first agent:  branch 0, branch 1
                0.0f, 0.0f  // second agent: branch 0, branch 1
            };
            var actionTensor = new TensorProxy
            {
                data = new Tensor(2, 2, chosenActions),
                shape = new long[] { 2, 2 },
                valueType = TensorProxy.TensorType.FloatingPoint
            };

            applier.Apply(actionTensor, agentIds, actionBuffers);

            // Read the segments only after Apply, in case the applier replaces dictionary entries.
            var firstActions = actionBuffers[firstAgent].DiscreteActions;
            Assert.AreEqual(2, firstActions[0]);
            Assert.AreEqual(1, firstActions[1]);

            var secondActions = actionBuffers[secondAgent].DiscreteActions;
            Assert.AreEqual(0, secondActions[0]);
            Assert.AreEqual(0, secondActions[1]);
        }
    }

    public class LegacyDiscreteActionOutputApplierTest
    {
        /// <summary>
        /// Checks that <see cref="LegacyDiscreteActionOutputApplier"/> turns per-branch
        /// log-probabilities into discrete actions, ending up on the action whose
        /// log-probability dominates its branch.
        /// </summary>
        [Test]
        public void TestDiscreteApply()
        {
            const int firstAgent = 42;
            const int secondAgent = 1337;
            const int seed = 2020;
            const float unlikely = -1000.0f;
            const float likely = -1.0f;

            // Two discrete branches: branch 0 has 3 actions, branch 1 has 2.
            var actionSpec = ActionSpec.MakeDiscrete(3, 2);

            // One row per agent. Each row is branch 0's three log-probs followed by
            // branch 1's two. The large gap between `likely` and `unlikely` is
            // presumably there so that the picked action is effectively fixed.
            // NOTE(review): this depends on the applier's sampling, which is not visible here.
            var logProbValues = new[]
            {
                unlikely, unlikely, likely,   // first agent,  branch 0 -> expect action 2
                unlikely, likely,             // first agent,  branch 1 -> expect action 1
                likely, unlikely, unlikely,   // second agent, branch 0 -> expect action 0
                likely, unlikely,             // second agent, branch 1 -> expect action 0
            };
            // NOTE(review): unlike the non-legacy test, `shape` is left unset here;
            // confirm the legacy applier does not read it.
            var logProbs = new TensorProxy
            {
                data = new Tensor(2, 5, logProbValues),
                valueType = TensorProxy.TensorType.FloatingPoint
            };

            var applier = new LegacyDiscreteActionOutputApplier(actionSpec, seed, null);
            var agentIds = new List<int> { firstAgent, secondAgent };
            var actionBuffers = new Dictionary<int, ActionBuffers>
            {
                [firstAgent] = new ActionBuffers(actionSpec),
                [secondAgent] = new ActionBuffers(actionSpec),
            };

            applier.Apply(logProbs, agentIds, actionBuffers);

            // Read the segments only after Apply, in case the applier replaces dictionary entries.
            var firstActions = actionBuffers[firstAgent].DiscreteActions;
            Assert.AreEqual(2, firstActions[0]);
            Assert.AreEqual(1, firstActions[1]);

            var secondActions = actionBuffers[secondAgent].DiscreteActions;
            Assert.AreEqual(0, secondActions[0]);
            Assert.AreEqual(0, secondActions[1]);
        }
    }
}