File size: 4,241 Bytes
05c9ac2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 |
using NUnit.Framework;
using System.IO;
using Unity.MLAgents.SideChannels;
namespace Unity.MLAgents.Tests
{
/// <summary>
/// Verifies that sampler definitions delivered through the
/// <see cref="EnvironmentParametersChannel"/> yield the expected values when
/// the parameter is read back with <c>GetWithDefault</c>.
/// </summary>
public class SamplerTests
{
    // Seed written into every sampler message; the expected values asserted
    // below were recorded for this seed.
    const int k_Seed = 1337;

    // Tolerance used when comparing sampled floats against expected values.
    const double k_Epsilon = 0.0001;

    // Tag written right after the parameter name; 1 indicates this message is a Sampler.
    const int k_SamplerMessageType = 1;

    // Fallback handed to GetWithDefault; none of the expected values equal it.
    const float k_DefaultValue = 1.0f;

    readonly EnvironmentParametersChannel m_Channel;

    /// <summary>
    /// Reuses the already-registered <see cref="EnvironmentParametersChannel"/>
    /// when there is one, otherwise creates and registers a new instance.
    /// </summary>
    public SamplerTests()
    {
        m_Channel = SideChannelManager.GetSideChannel<EnvironmentParametersChannel>();
        // if running test on its own
        if (m_Channel == null)
        {
            m_Channel = new EnvironmentParametersChannel();
            SideChannelManager.RegisterSideChannel(m_Channel);
        }
    }

    /// <summary>
    /// Sends a Uniform sampler with bounds [1, 2] and checks the first two reads.
    /// </summary>
    [Test]
    public void UniformSamplerTest()
    {
        const float minValue = 1.0f;
        const float maxValue = 2.0f;
        const string parameter = "parameter1";

        using (var outgoingMsg = new OutgoingMessage())
        {
            WriteSamplerHeader(outgoingMsg, parameter, SamplerType.Uniform);
            outgoingMsg.WriteFloat32(minValue);
            outgoingMsg.WriteFloat32(maxValue);
            Send(outgoingMsg);
        }

        // Two consecutive reads are expected to return two different values.
        Assert.AreEqual(1.208888f, m_Channel.GetWithDefault(parameter, k_DefaultValue), k_Epsilon);
        Assert.AreEqual(1.118017f, m_Channel.GetWithDefault(parameter, k_DefaultValue), k_Epsilon);
    }

    /// <summary>
    /// Sends a Gaussian sampler (mean 3.0, stddev 0.2) and checks the first two reads.
    /// </summary>
    [Test]
    public void GaussianSamplerTest()
    {
        const float mean = 3.0f;
        const float stddev = 0.2f;
        const string parameter = "parameter2";

        using (var outgoingMsg = new OutgoingMessage())
        {
            WriteSamplerHeader(outgoingMsg, parameter, SamplerType.Gaussian);
            outgoingMsg.WriteFloat32(mean);
            outgoingMsg.WriteFloat32(stddev);
            Send(outgoingMsg);
        }

        // Two consecutive reads are expected to return two different values.
        Assert.AreEqual(2.936162f, m_Channel.GetWithDefault(parameter, k_DefaultValue), k_Epsilon);
        Assert.AreEqual(2.951348f, m_Channel.GetWithDefault(parameter, k_DefaultValue), k_Epsilon);
    }

    /// <summary>
    /// Sends a MultiRangeUniform sampler described by a four-element float list
    /// and checks the first two reads.
    /// </summary>
    [Test]
    public void MultiRangeUniformSamplerTest()
    {
        // NOTE(review): presumably consecutive (min, max) pairs, i.e. the ranges
        // [1.2, 2] and [3.2, 4.1] - confirm against the sampler implementation.
        float[] intervals = { 1.2f, 2f, 3.2f, 4.1f };
        const string parameter = "parameter3";

        using (var outgoingMsg = new OutgoingMessage())
        {
            WriteSamplerHeader(outgoingMsg, parameter, SamplerType.MultiRangeUniform);
            outgoingMsg.WriteFloatList(intervals);
            Send(outgoingMsg);
        }

        // Two consecutive reads are expected to return two different values.
        Assert.AreEqual(3.387999f, m_Channel.GetWithDefault(parameter, k_DefaultValue), k_Epsilon);
        Assert.AreEqual(1.294413f, m_Channel.GetWithDefault(parameter, k_DefaultValue), k_Epsilon);
    }

    /// <summary>
    /// Wraps an outgoing message in the byte layout handed to
    /// <c>SideChannelManager.ProcessSideChannelData</c>: the channel id bytes,
    /// the payload length as a 32-bit integer, then the payload itself.
    /// </summary>
    /// <param name="sideChannel">Channel whose id prefixes the message.</param>
    /// <param name="msg">Message whose serialized bytes form the payload.</param>
    /// <returns>The framed message as a single byte array.</returns>
    internal static byte[] GetByteMessage(SideChannel sideChannel, OutgoingMessage msg)
    {
        byte[] payload = msg.ToByteArray();
        using (var memStream = new MemoryStream())
        {
            using (var binaryWriter = new BinaryWriter(memStream))
            {
                binaryWriter.Write(sideChannel.ChannelId.ToByteArray());
                binaryWriter.Write(payload.Length);
                binaryWriter.Write(payload);
            }
            // The writer has been disposed (and therefore flushed) at this point;
            // MemoryStream.ToArray still works after the stream is closed.
            return memStream.ToArray();
        }
    }

    /// <summary>
    /// Writes the fields shared by every sampler message, in wire order:
    /// parameter name, sampler tag, seed, sampler type.
    /// </summary>
    /// <param name="msg">Message being built.</param>
    /// <param name="parameter">Name of the environment parameter.</param>
    /// <param name="samplerType">Kind of sampler the remaining fields describe.</param>
    static void WriteSamplerHeader(OutgoingMessage msg, string parameter, SamplerType samplerType)
    {
        msg.WriteString(parameter);
        msg.WriteInt32(k_SamplerMessageType);
        msg.WriteInt32(k_Seed);
        msg.WriteInt32((int)samplerType);
    }

    /// <summary>
    /// Frames <paramref name="msg"/> for the test channel and passes it to the
    /// side channel manager for processing.
    /// </summary>
    /// <param name="msg">Fully written sampler message.</param>
    void Send(OutgoingMessage msg)
    {
        byte[] framed = GetByteMessage(m_Channel, msg);
        SideChannelManager.ProcessSideChannelData(framed);
    }
}
}
|