|
using NUnit.Framework; |
|
using Unity.MLAgents.Inference.Utils; |
|
|
|
namespace Unity.MLAgents.Tests
{
    /// <summary>
    /// Verifies the indices returned by <c>Multinomial.Sample</c> for small,
    /// hand-written cumulative distributions. Every test creates its own
    /// <c>Multinomial</c> from the same constructor argument and compares
    /// consecutive draws against a fixed list of expected indices.
    /// </summary>
    public class MultinomialTest
    {
        // Constructor argument used by every test. Presumably the RNG seed that
        // makes the expected index sequences reproducible -- confirm against Multinomial.
        const int k_Seed = 2018;

        /// <summary>
        /// Creates a fresh <c>Multinomial</c> and asserts that successive
        /// <c>Sample</c> calls on <paramref name="cdf"/> return
        /// <paramref name="expectedIndices"/>, in order.
        /// </summary>
        /// <param name="cdf">Cumulative values handed unchanged to every <c>Sample</c> call.</param>
        /// <param name="expectedIndices">Expected result of each draw; one draw is made per entry.</param>
        static void AssertSampleSequence(float[] cdf, params int[] expectedIndices)
        {
            var sampler = new Multinomial(k_Seed);

            foreach (var expected in expectedIndices)
            {
                Assert.AreEqual(expected, sampler.Sample(cdf));
            }
        }

        /// <summary>
        /// A single-entry distribution whose only value is 1 yields index 0 on every draw.
        /// </summary>
        [Test]
        public void TestDim1()
        {
            AssertSampleSequence(new[] { 1f }, 0, 0, 0);
        }

        /// <summary>
        /// A single-entry distribution whose only value is 0.1 (final entry below 1)
        /// still yields index 0 on every draw.
        /// </summary>
        [Test]
        public void TestDim1Unscaled()
        {
            AssertSampleSequence(new[] { 0.1f }, 0, 0, 0);
        }

        /// <summary>
        /// Three entries ending at 1.0: the first four draws are expected to be 2, 2, 2, 1.
        /// </summary>
        [Test]
        public void TestDim3()
        {
            AssertSampleSequence(new[] { 0.1f, 0.3f, 1.0f }, 2, 2, 2, 1);
        }

        /// <summary>
        /// The same three-entry shape with every value halved (final entry 0.5):
        /// the first four draws are expected to match the scaled case, 2, 2, 2, 1.
        /// </summary>
        [Test]
        public void TestDim3Unscaled()
        {
            AssertSampleSequence(new[] { 0.05f, 0.15f, 0.5f }, 2, 2, 2, 1);
        }
    }
}
|
|