File size: 4,993 Bytes
05c9ac2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 |
using System;
using System.Collections.Generic;
using System.Reflection;
using System.Runtime.CompilerServices;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Policies;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Sensors.Reflection;
[assembly: InternalsVisibleTo("Unity.ML-Agents.Editor.Tests")]
[assembly: InternalsVisibleTo("Unity.ML-Agents.Runtime.Sensor.Tests")]
[assembly: InternalsVisibleTo("Unity.ML-Agents.Extensions.EditorTests")]
namespace Unity.MLAgents.Utils.Tests
{
/// <summary>
/// <see cref="IPolicy"/> implementation used by tests. Each decision request asks every
/// supplied sensor for its observation proto and then fires an optional callback;
/// <see cref="DecideAction"/> always hands back the same shared action buffers.
/// </summary>
internal class TestPolicy : IPolicy
{
    /// <summary>
    /// Optional callback invoked at the end of every <see cref="RequestDecision"/> call.
    /// </summary>
    public Action OnRequestDecision;

    // Writer handed to every sensor while its observation proto is generated.
    private ObservationWriter m_ObsWriter = new ObservationWriter();

    // Not read anywhere in this class; retained so static initialisation is unchanged.
    private static ActionSpec s_ActionSpec = ActionSpec.MakeContinuous(1);

    // Shared buffers built from a one-element float array and an empty int array.
    private static ActionBuffers s_EmptyActionBuffers =
        new ActionBuffers(new float[1], Array.Empty<int>());

    /// <summary>
    /// Calls <c>GetObservationProto</c> on each sensor (the result is discarded), then
    /// invokes <see cref="OnRequestDecision"/> if a handler is assigned.
    /// </summary>
    /// <param name="info">Agent info for the request; not used by this policy.</param>
    /// <param name="sensors">Sensors to generate observation protos for.</param>
    public void RequestDecision(AgentInfo info, List<ISensor> sensors)
    {
        foreach (var currentSensor in sensors)
        {
            currentSensor.GetObservationProto(m_ObsWriter);
        }

        var callback = OnRequestDecision;
        if (callback != null)
        {
            callback();
        }
    }

    /// <summary>
    /// Returns a read-only reference to the shared static action buffers.
    /// </summary>
    /// <returns>Reference to the same static <see cref="ActionBuffers"/> on every call.</returns>
    public ref readonly ActionBuffers DecideAction() => ref s_EmptyActionBuffers;

    /// <summary>
    /// No-op; this policy holds nothing that needs releasing.
    /// </summary>
    public void Dispose()
    {
    }
}
/// <summary>
/// <see cref="Agent"/> subclass for tests. It counts how many times each overridden
/// callback runs and exposes the base class's non-public <c>m_Info</c> and
/// <c>m_Brain</c> fields through reflection.
/// </summary>
public class TestAgent : Agent
{
    // Binding flags used for every reflective lookup of a non-public instance field on Agent.
    const BindingFlags k_NonPublicInstance = BindingFlags.Instance | BindingFlags.NonPublic;

    /// <summary>
    /// Looks up a non-public instance field declared on <see cref="Agent"/> by name.
    /// The lookup is performed on every call (nothing is cached).
    /// </summary>
    /// <param name="fieldName">Name of the field to look up.</param>
    /// <returns>The matching <see cref="FieldInfo"/>, or null when no such field exists.</returns>
    static FieldInfo LookupAgentField(string fieldName)
    {
        return typeof(Agent).GetField(fieldName, k_NonPublicInstance);
    }

    /// <summary>
    /// Reads or writes the base class's non-public <c>m_Info</c> field via reflection.
    /// </summary>
    internal AgentInfo _Info
    {
        get => (AgentInfo)LookupAgentField("m_Info").GetValue(this);
        set => LookupAgentField("m_Info").SetValue(this, value);
    }

    /// <summary>
    /// Overwrites the base class's non-public <c>m_Brain</c> field with the given policy.
    /// </summary>
    /// <param name="policy">Policy instance to store.</param>
    internal void SetPolicy(IPolicy policy)
    {
        var brainField = LookupAgentField("m_Brain");
        brainField.SetValue(this, policy);
    }

    /// <summary>
    /// Reads the base class's non-public <c>m_Brain</c> field.
    /// </summary>
    /// <returns>The stored value cast to <see cref="IPolicy"/>.</returns>
    internal IPolicy GetPolicy()
    {
        var brainField = LookupAgentField("m_Brain");
        return (IPolicy)brainField.GetValue(this);
    }

    // Call counters. The "*ForEpisode" variants are zeroed in OnEpisodeBegin.
    public int initializeAgentCalls;
    public int collectObservationsCalls;
    public int collectObservationsCallsForEpisode;
    public int agentActionCalls;
    public int agentActionCallsForEpisode;
    public int agentOnEpisodeBeginCalls;
    public int heuristicCalls;

    // Custom sensors created in Initialize.
    public TestSensor sensor1;
    public TestSensor sensor2;

    [Observable("observableFloat")]
    public float observableFloat;

    /// <summary>
    /// Counts the call and registers two <see cref="TestSensor"/> instances.
    /// </summary>
    public override void Initialize()
    {
        initializeAgentCalls++;

        sensor1 = new TestSensor("testsensor1");
        sensor2 = new TestSensor("testsensor2")
        {
            compressionType = SensorCompressionType.PNG
        };

        // Deliberately added as sensor2 then sensor1 so tests can confirm the
        // sensors get sorted as expected.
        sensors.Add(sensor2);
        sensors.Add(sensor1);
    }

    /// <summary>
    /// Increments both observation counters and adds the per-episode count as an observation.
    /// </summary>
    /// <param name="sensor">Vector sensor that receives the observation.</param>
    public override void CollectObservations(VectorSensor sensor)
    {
        collectObservationsCalls++;
        collectObservationsCallsForEpisode++;
        sensor.AddObservation(collectObservationsCallsForEpisode);
    }

    /// <summary>
    /// Increments both action counters and adds a reward of 0.1.
    /// </summary>
    /// <param name="buffers">Received actions; not inspected.</param>
    public override void OnActionReceived(ActionBuffers buffers)
    {
        agentActionCalls++;
        agentActionCallsForEpisode++;
        AddReward(0.1f);
    }

    /// <summary>
    /// Counts the episode start and zeroes the per-episode counters.
    /// </summary>
    public override void OnEpisodeBegin()
    {
        agentOnEpisodeBeginCalls++;
        agentActionCallsForEpisode = 0;
        collectObservationsCallsForEpisode = 0;
    }

    /// <summary>
    /// Writes the first observation, cast to int, into continuous action 0 and counts the call.
    /// </summary>
    /// <param name="actionsOut">Action buffers whose first continuous action is written.</param>
    public override void Heuristic(in ActionBuffers actionsOut)
    {
        var observations = GetObservations();
        var continuous = actionsOut.ContinuousActions;
        // The (int) cast is intentional and preserved from the original behaviour.
        continuous[0] = (int)observations[0];
        heuristicCalls += 1;
    }
}
/// <summary>
/// <see cref="ISensor"/> implementation for tests. It writes no observation data and
/// simply counts calls to <see cref="Write"/>, <see cref="GetCompressedObservation"/>
/// and <see cref="Reset"/>.
/// </summary>
public class TestSensor : ISensor
{
    // Name returned by GetName.
    public string sensorName;

    // Number of times Write, GetCompressedObservation and Reset have been called.
    public int numWriteCalls;
    public int numCompressedCalls;
    public int numResetCalls;

    // Compression type reported through GetCompressionSpec.
    public SensorCompressionType compressionType = SensorCompressionType.None;

    /// <summary>
    /// Creates a sensor with the given name.
    /// </summary>
    /// <param name="n">Value stored in <see cref="sensorName"/>.</param>
    public TestSensor(string n) => sensorName = n;

    /// <summary>
    /// Returns the spec produced by <c>ObservationSpec.Vector(0)</c>.
    /// </summary>
    public ObservationSpec GetObservationSpec() => ObservationSpec.Vector(0);

    /// <summary>
    /// Counts the call; nothing is written to <paramref name="writer"/>.
    /// </summary>
    /// <param name="writer">Unused.</param>
    /// <returns>Always 0.</returns>
    public int Write(ObservationWriter writer)
    {
        numWriteCalls += 1;
        return 0;
    }

    /// <summary>
    /// Counts the call and returns a new one-element array holding a single zero byte.
    /// </summary>
    public byte[] GetCompressedObservation()
    {
        numCompressedCalls += 1;
        return new byte[] { 0 };
    }

    /// <summary>
    /// Returns a <see cref="CompressionSpec"/> built from the current <see cref="compressionType"/>.
    /// </summary>
    public CompressionSpec GetCompressionSpec() => new CompressionSpec(compressionType);

    /// <summary>
    /// Returns <see cref="sensorName"/>.
    /// </summary>
    public string GetName() => sensorName;

    /// <summary>
    /// No-op.
    /// </summary>
    public void Update()
    {
    }

    /// <summary>
    /// Counts the call; no other state is changed.
    /// </summary>
    public void Reset()
    {
        numResetCalls += 1;
    }
}
/// <summary>
/// Empty class; it declares no members.
/// </summary>
// NOTE(review): presumably kept so that a type matching the source file name exists —
// confirm against the file name and any references before removing.
public class TestClasses
{
}
}
|