|
using System; |
|
using System.Collections.Generic; |
|
using Unity.MLAgents.Inference.Utils; |
|
using Random = System.Random; |
|
|
|
namespace Unity.MLAgents
{
    /// <summary>
    /// Builds seeded sampling functions. Each factory method returns a closure that produces a
    /// new value every time it is invoked.
    /// </summary>
    /// <remarks>
    /// The uniform samplers close over a <see cref="Random"/> instance. Instance members of
    /// <see cref="Random"/> are not thread-safe, so a single sampler should not be invoked from
    /// several threads concurrently without external synchronization.
    /// </remarks>
    internal static class SamplerFactory
    {
        /// <summary>
        /// Creates a sampler that draws floats uniformly between <paramref name="min"/> and
        /// <paramref name="max"/>.
        /// </summary>
        /// <param name="min">Lower bound of the range.</param>
        /// <param name="max">Upper bound of the range.</param>
        /// <param name="seed">Seed for the underlying <see cref="Random"/>.</param>
        /// <returns>
        /// A function returning <c>min + u * (max - min)</c>, where <c>u</c> comes from
        /// <see cref="Random.NextDouble"/>. Because <c>u</c> is narrowed to <c>float</c>, the
        /// result can occasionally equal <paramref name="max"/>.
        /// </returns>
        public static Func<float> CreateUniformSampler(float min, float max, int seed)
        {
            var distr = new Random(seed);
            return () => min + (float)distr.NextDouble() * (max - min);
        }

        /// <summary>
        /// Creates a sampler backed by <see cref="RandomNormal"/>, constructed with the given
        /// <paramref name="mean"/> and <paramref name="stddev"/>. Presumably this is a normal
        /// (Gaussian) distribution; see <see cref="RandomNormal"/> for the exact behavior.
        /// </summary>
        /// <param name="mean">Mean passed to <see cref="RandomNormal"/>.</param>
        /// <param name="stddev">Standard deviation passed to <see cref="RandomNormal"/>.</param>
        /// <param name="seed">Seed passed to <see cref="RandomNormal"/>.</param>
        /// <returns>A function returning successive <c>RandomNormal.NextDouble()</c> draws narrowed to <c>float</c>.</returns>
        public static Func<float> CreateGaussianSampler(float mean, float stddev, int seed)
        {
            var distr = new RandomNormal(seed, mean, stddev);
            return () => (float)distr.NextDouble();
        }

        /// <summary>
        /// Creates a sampler over the union of several intervals. Each call first picks an
        /// interval, then draws a value uniformly inside it.
        /// </summary>
        /// <remarks>
        /// The interval selector (<see cref="Multinomial"/>) is handed a cumulative distribution
        /// built from the interval widths. Assuming <c>Multinomial.Sample</c> draws an index from
        /// such a CMF, each interval is picked with probability proportional to its width. When
        /// every interval has zero width, the intervals are weighted equally instead.
        /// </remarks>
        /// <param name="intervals">
        /// Flattened interval bounds <c>[min0, max0, min1, max1, ...]</c>. The list must hold a
        /// positive, even number of values, and each <c>min</c> must not exceed its <c>max</c>.
        /// </param>
        /// <param name="seed">
        /// Seed for the in-interval <see cref="Random"/>. The interval selector is seeded with
        /// <c>seed + 1</c>.
        /// </param>
        /// <returns>A function returning one sample per invocation.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="intervals"/> is null.</exception>
        /// <exception cref="ArgumentException">
        /// <paramref name="intervals"/> is empty or has an odd number of values, or one of its
        /// intervals has a minimum greater than its maximum (or a NaN bound).
        /// </exception>
        public static Func<float> CreateMultiRangeUniformSampler(IList<float> intervals, int seed)
        {
            if (intervals == null)
            {
                throw new ArgumentNullException(nameof(intervals));
            }

            // An odd count would otherwise silently drop the trailing bound. An empty list would
            // produce a sampler that can only fail when it is first invoked.
            if (intervals.Count == 0 || intervals.Count % 2 != 0)
            {
                throw new ArgumentException(
                    "Expected a non-empty, even-length list of [min, max] interval bounds.",
                    nameof(intervals));
            }

            // One generator is shared by every per-interval closure created below.
            var distr = new Random(seed);
            var numIntervals = intervals.Count / 2;

            // Holds the per-interval widths first, and is then turned in place into a
            // cumulative mass function (CMF) for the interval selector.
            var cmf = new float[numIntervals];
            var intervalFuncs = new Func<float>[numIntervals];
            var sumIntervalSizes = 0f;

            for (var i = 0; i < numIntervals; i++)
            {
                var min = intervals[2 * i];
                var max = intervals[2 * i + 1];

                // The negated comparison also rejects NaN bounds, which would poison the CMF.
                if (!(min <= max))
                {
                    throw new ArgumentException(
                        $"Interval {i} must satisfy min <= max (got min={min}, max={max}).",
                        nameof(intervals));
                }

                var intervalSize = max - min;
                cmf[i] = intervalSize;
                sumIntervalSizes += intervalSize;

                // min and intervalSize are declared inside the loop body, so each closure
                // captures its own per-iteration values.
                intervalFuncs[i] = () => min + (float)distr.NextDouble() * intervalSize;
            }

            if (sumIntervalSizes > 0f)
            {
                for (var i = 0; i < numIntervals; i++)
                {
                    cmf[i] /= sumIntervalSizes;
                }
            }
            else
            {
                // Every interval is a single point. Dividing by the zero total would fill the CMF
                // with NaN, so give each point the same weight instead.
                for (var i = 0; i < numIntervals; i++)
                {
                    cmf[i] = 1f / numIntervals;
                }
            }

            // Convert the normalized weights into a running (cumulative) sum.
            for (var i = 1; i < numIntervals; i++)
            {
                cmf[i] += cmf[i - 1];
            }

            // Rounding in the running sum can leave the last entry slightly off 1, so pin it.
            cmf[numIntervals - 1] = 1f;

            var intervalDistr = new Multinomial(seed + 1);

            float MultiRange()
            {
                var sampledInterval = intervalDistr.Sample(cmf);
                return intervalFuncs[sampledInterval]();
            }

            return MultiRange;
        }
    }
}
|
|