File size: 1,435 Bytes
05c9ac2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
using NUnit.Framework;
using Unity.MLAgents.Inference.Utils;

namespace Unity.MLAgents.Tests
{
    /// <summary>
    /// Seeded regression tests for <see cref="Multinomial"/> sampling.
    /// Every test starts from a freshly constructed sampler using the same seed, so the
    /// expected index sequences below are pinned to that seed and to the sampler's
    /// current implementation. If either changes, the expected values must be regenerated.
    /// </summary>
    public class MultinomialTest
    {
        // Fixed seed so each test consumes the same reproducible random stream.
        private const int Seed = 2018;

        /// <summary>
        /// Creates a new sampler seeded with <see cref="Seed"/> and asserts that successive
        /// calls to <c>Sample</c> on <paramref name="cdf"/> return
        /// <paramref name="expectedIndices"/>, in order.
        /// </summary>
        /// <param name="cdf">Cumulative distribution passed unchanged to every draw.</param>
        /// <param name="expectedIndices">Expected index for each consecutive draw.</param>
        private static void AssertSampleSequence(float[] cdf, params int[] expectedIndices)
        {
            var sampler = new Multinomial(Seed);
            foreach (var expectedIndex in expectedIndices)
            {
                Assert.AreEqual(expectedIndex, sampler.Sample(cdf));
            }
        }

        [Test]
        public void TestDim1()
        {
            // A single bucket: every draw is expected to land on index 0.
            AssertSampleSequence(new[] { 1f }, 0, 0, 0);
        }

        [Test]
        public void TestDim1Unscaled()
        {
            // Same as TestDim1, but the single CDF entry is 0.1 rather than 1.
            AssertSampleSequence(new[] { 0.1f }, 0, 0, 0);
        }

        [Test]
        public void TestDim3()
        {
            // Three buckets; the expected sequence is specific to Seed.
            AssertSampleSequence(new[] { 0.1f, 0.3f, 1.0f }, 2, 2, 2, 1);
        }

        [Test]
        public void TestDim3Unscaled()
        {
            // Every entry is half of TestDim3's CDF, and the same index sequence is expected,
            // i.e. this pins that a CDF whose last entry is not 1 samples like its rescaled form.
            AssertSampleSequence(new[] { 0.05f, 0.15f, 0.5f }, 2, 2, 2, 1);
        }
    }
}