File size: 1,435 Bytes
05c9ac2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 |
using NUnit.Framework;
using Unity.MLAgents.Inference.Utils;
namespace Unity.MLAgents.Tests
{
    /// <summary>
    /// Regression tests for <see cref="Multinomial"/> sampling from a cumulative distribution.
    /// Every test constructs the sampler with the same argument (which appears to seed its
    /// internal random source) and checks the exact sequence of indices drawn, so the expected
    /// values are tied to that argument and to the order in which samples are taken.
    /// </summary>
    public class MultinomialTest
    {
        // Constructor argument used by every test; the expected sequences were recorded with it.
        const int Seed = 2018;

        /// <summary>
        /// Draws one sample from <paramref name="sampler"/> per entry of
        /// <paramref name="expected"/>, in order, and asserts each draw matches.
        /// </summary>
        /// <param name="sampler">Sampler whose successive draws are checked.</param>
        /// <param name="cdf">Cumulative distribution passed to every draw.</param>
        /// <param name="expected">Expected index for each successive draw.</param>
        static void AssertSampleSequence(Multinomial sampler, float[] cdf, params int[] expected)
        {
            foreach (var index in expected)
            {
                Assert.AreEqual(index, sampler.Sample(cdf));
            }
        }

        /// <summary>A single-category CDF that ends at 1 always yields index 0.</summary>
        [Test]
        public void TestDim1()
        {
            var sampler = new Multinomial(Seed);
            AssertSampleSequence(sampler, new[] { 1f }, 0, 0, 0);
        }

        /// <summary>
        /// A single-category CDF whose last entry is not 1 still always yields index 0.
        /// </summary>
        [Test]
        public void TestDim1Unscaled()
        {
            var sampler = new Multinomial(Seed);
            AssertSampleSequence(sampler, new[] { 0.1f }, 0, 0, 0);
        }

        /// <summary>Three-category CDF ending at 1; checks the first four seeded draws.</summary>
        [Test]
        public void TestDim3()
        {
            var sampler = new Multinomial(Seed);
            AssertSampleSequence(sampler, new[] { 0.1f, 0.3f, 1.0f }, 2, 2, 2, 1);
        }

        /// <summary>
        /// Same CDF as <see cref="TestDim3"/> with every entry halved; the expected draws are
        /// identical, i.e. the sampler is expected to treat the CDF relative to its last entry.
        /// </summary>
        [Test]
        public void TestDim3Unscaled()
        {
            var sampler = new Multinomial(Seed);
            AssertSampleSequence(sampler, new[] { 0.05f, 0.15f, 0.5f }, 2, 2, 2, 1);
        }
    }
}
|