File size: 4,098 Bytes
05c9ac2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 |
using System;
using System.Collections.Generic;
using NUnit.Framework;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Analytics;
using Unity.MLAgents.Policies;
using UnityEditor;
namespace Unity.MLAgents.Tests.Analytics
{
/// <summary>
/// Tests for <see cref="TrainingAnalytics"/>: behavior-name parsing and hashing, the
/// remote-policy analytics event payload, and the analytics enable switch.
/// </summary>
[TestFixture]
public class TrainingAnalyticsTests
{
    /// <summary>
    /// ParseBehaviorName drops the trailing team qualifier (e.g. "?team=42") and keeps
    /// everything before it, including any earlier '?' characters; names without a
    /// qualifier are returned unchanged.
    /// </summary>
    [TestCase("foo?team=42", ExpectedResult = "foo")]
    [TestCase("foo", ExpectedResult = "foo")]
    [TestCase("foo?bar?team=1337", ExpectedResult = "foo?bar")]
    public string TestParseBehaviorName(string fullyQualifiedBehaviorName)
    {
        return TrainingAnalytics.ParseBehaviorName(fullyQualifiedBehaviorName);
    }

    /// <summary>
    /// Checks the event built by GetEventForRemotePolicy: the behavior name must not be
    /// passed through verbatim, and observation, action-spec and actuator details must be
    /// reported from the supplied sensors, ActionSpec and actuators.
    /// </summary>
    [Test]
    public void TestRemotePolicyEvent()
    {
        var behaviorName = "testBehavior";
        var sensor1 = new Test3DSensor("SensorA", 21, 20, 3);
        var sensor2 = new Test3DSensor("SensorB", 20, 22, 3);
        var sensors = new List<ISensor> { sensor1, sensor2 };
        var actionSpec = ActionSpec.MakeContinuous(2);
        var vectorActuator = new VectorActuator(null, actionSpec, "test'");
        var actuators = new IActuator[] { vectorActuator };

        var remotePolicyEvent = TrainingAnalytics.GetEventForRemotePolicy(behaviorName, sensors, actionSpec, actuators);

        // The behavior name should be hashed, not pass-through.
        Assert.AreNotEqual(behaviorName, remotePolicyEvent.BehaviorName);

        // One observation spec per sensor; spot-check the first one.
        Assert.AreEqual(2, remotePolicyEvent.ObservationSpecs.Count);
        Assert.AreEqual(3, remotePolicyEvent.ObservationSpecs[0].DimensionInfos.Length);
        // NOTE(review): expects dimension 0 to be the sensor's third constructor argument (20) — depends on Test3DSensor's shape ordering.
        Assert.AreEqual(20, remotePolicyEvent.ObservationSpecs[0].DimensionInfos[0].Size);
        Assert.AreEqual(0, remotePolicyEvent.ObservationSpecs[0].ObservationType);
        Assert.AreEqual("None", remotePolicyEvent.ObservationSpecs[0].CompressionType);
        Assert.AreEqual(Test3DSensor.k_BuiltInSensorType, remotePolicyEvent.ObservationSpecs[0].BuiltInSensorType);

        // Action spec and per-actuator info should both reflect the 2 continuous actions.
        Assert.AreEqual(2, remotePolicyEvent.ActionSpec.NumContinuousActions);
        Assert.AreEqual(0, remotePolicyEvent.ActionSpec.NumDiscreteActions);
        Assert.AreEqual(2, remotePolicyEvent.ActuatorInfos[0].NumContinuousActions);
        Assert.AreEqual(0, remotePolicyEvent.ActuatorInfos[0].NumDiscreteActions);
    }

    /// <summary>
    /// Smoke test: constructing a RemotePolicy and requesting a decision must not throw
    /// while analytics sending is disabled. Any Academy left over from a previous test is
    /// disposed first, and the Academy is disposed afterwards even if the body throws, so
    /// singleton state cannot leak into later tests.
    /// </summary>
    [Test]
    public void TestRemotePolicy()
    {
        if (Academy.IsInitialized)
        {
            Academy.Instance.Dispose();
        }

        try
        {
            using (new AnalyticsUtils.DisableAnalyticsSending())
            {
                var actionSpec = ActionSpec.MakeContinuous(3);
                var policy = new RemotePolicy(actionSpec, Array.Empty<IActuator>(), "TestBehavior?team=42");
                policy.RequestDecision(new AgentInfo(), new List<ISensor>());
            }
        }
        finally
        {
            // Cleanup must run on failure too; the guard avoids creating an Academy
            // (via Academy.Instance) solely to dispose it.
            if (Academy.IsInitialized)
            {
                Academy.Instance.Dispose();
            }
        }
    }

    /// <summary>
    /// SanitizeTrainingBehaviorInitializedEvent replaces the behavior name with its hash.
    /// The third case feeds in the hash produced for "another_name" and expects it back
    /// unchanged, i.e. a value that already looks like a hash is not hashed again.
    /// </summary>
    [TestCase("a name we expect to hash", ExpectedResult = "d084a8b6da6a6a1c097cdc9ffea95e1546da4647352113ed77cbe7b4192e6d73")]
    [TestCase("another_name", ExpectedResult = "0b74613c872e79aba11e06eda3538f2b646eb2b459e75087829ea500bd703d0b")]
    [TestCase("0b74613c872e79aba11e06eda3538f2b646eb2b459e75087829ea500bd703d0b", ExpectedResult = "0b74613c872e79aba11e06eda3538f2b646eb2b459e75087829ea500bd703d0b")]
    public string TestTrainingBehaviorInitialized(string stringToMaybeHash)
    {
        var tbiEvent = new TrainingBehaviorInitializedEvent();
        tbiEvent.BehaviorName = stringToMaybeHash;
        tbiEvent.Config = "{}";
        var sanitizedEvent = TrainingAnalytics.SanitizeTrainingBehaviorInitializedEvent(tbiEvent);
        return sanitizedEvent.BehaviorName;
    }

    /// <summary>
    /// EnableAnalytics mirrors EditorAnalytics.enabled when compiled in the editor with the
    /// analytics module available, and is always false otherwise.
    /// </summary>
    [Test]
    public void TestEnableAnalytics()
    {
#if UNITY_EDITOR && MLA_UNITY_ANALYTICS_MODULE
        // AreEqual reports both values on failure, unlike IsTrue(a == b).
        Assert.AreEqual(EditorAnalytics.enabled, TrainingAnalytics.EnableAnalytics());
#else
        Assert.IsFalse(TrainingAnalytics.EnableAnalytics());
#endif
    }
}
}
|