File size: 2,272 Bytes
05c9ac2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
namespace Unity.MLAgents.Inference.Utils
{
    /// <summary>
    /// Multinomial - Draws samples from a multinomial distribution given a (potentially unscaled)
    /// cumulative mass function (CMF). This means that the CMF need not "end" with probability
    /// mass of 1.0. For instance: [0.1, 0.2, 0.5] is a valid (unscaled) CMF. What is important is
    /// that it is a cumulative function, not a probability function. In other words,
    /// entry[i] = P(x \le i), NOT P(i - 1 \le x \lt i).
    /// (\le stands for less than or equal to while \lt is strictly less than).
    /// </summary>
    /// <remarks>
    /// Samples are reproducible for a given seed. Each instance wraps a <see cref="System.Random"/>,
    /// which is not thread-safe, so a single instance should not be used from several threads
    /// concurrently without external synchronization.
    /// </remarks>
    internal class Multinomial
    {
        readonly System.Random m_Random;

        /// <summary>
        /// Constructor.
        /// </summary>
        /// <param name="seed">
        /// Seed for the random number generator used in the sampling process.
        /// </param>
        public Multinomial(int seed)
        {
            m_Random = new System.Random(seed);
        }

        /// <summary>
        /// Samples from the Multinomial distribution defined by the provided cumulative
        /// mass function.
        /// </summary>
        /// <param name="cmf">
        /// Cumulative mass function, which may be unscaled. The entries in this array need
        /// to be monotonically non-decreasing (equal neighbouring entries denote a class with
        /// zero probability mass). If the CMF is scaled, then the entry at
        /// <paramref name="branchSize"/> - 1 will be 1.0. Entries past that index are ignored.
        /// </param>
        /// <param name="branchSize">The number of possible branches, i.e. the effective size of the cmf array.</param>
        /// <returns>A sampled index from the CMF ranging from 0 to branchSize-1.</returns>
        /// <exception cref="System.ArgumentNullException"><paramref name="cmf"/> is null.</exception>
        /// <exception cref="System.ArgumentOutOfRangeException">
        /// <paramref name="branchSize"/> is less than 1 or greater than the length of <paramref name="cmf"/>.
        /// </exception>
        public int Sample(float[] cmf, int branchSize)
        {
            if (cmf == null)
            {
                throw new System.ArgumentNullException(nameof(cmf));
            }

            if (branchSize < 1 || branchSize > cmf.Length)
            {
                throw new System.ArgumentOutOfRangeException(
                    nameof(branchSize),
                    branchSize,
                    "branchSize must be between 1 and the length of the cmf array.");
            }

            var last = branchSize - 1;

            // Scale a uniform draw by the total mass (the last effective entry) so that
            // unscaled CMFs are supported. The float product never exceeds a non-negative total.
            var p = (float)m_Random.NextDouble() * cmf[last];

            // Inverse-CDF search: pick the first index whose cumulative mass reaches p.
            // The `cls < last` bound keeps the result inside [0, branchSize - 1] even for a
            // malformed CMF (e.g. a negative total mass), which would otherwise walk past
            // the effective end of the array. For valid input the bound never triggers.
            var cls = 0;
            while (cls < last && cmf[cls] < p)
            {
                ++cls;
            }

            return cls;
        }

        /// <summary>
        /// Samples from the Multinomial distribution defined by the provided cumulative
        /// mass function, using the whole array as the effective CMF.
        /// </summary>
        /// <param name="cmf">Cumulative mass function, which may be unscaled.</param>
        /// <returns>A sampled index from the CMF ranging from 0 to cmf.Length-1.</returns>
        /// <exception cref="System.ArgumentNullException"><paramref name="cmf"/> is null.</exception>
        /// <exception cref="System.ArgumentOutOfRangeException"><paramref name="cmf"/> is empty.</exception>
        public int Sample(float[] cmf)
        {
            if (cmf == null)
            {
                throw new System.ArgumentNullException(nameof(cmf));
            }

            return Sample(cmf, cmf.Length);
        }
    }
}