File size: 7,344 Bytes
05c9ac2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 |
using System.Collections.Generic;
using Unity.Barracuda;
using NUnit.Framework;
using UnityEngine;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Inference;
using Unity.MLAgents.Policies;
using Unity.MLAgents.Utils.Tests;
namespace Unity.MLAgents.Tests
{
/// <summary>
/// Edit-mode tests for the inference input generators: <c>TensorGenerator</c>,
/// <c>BatchSizeGenerator</c>, <c>SequenceLengthGenerator</c>, <c>ObservationGenerator</c>,
/// <c>PreviousActionInputGenerator</c> and <c>ActionMaskInputGenerator</c>.
/// </summary>
[TestFixture]
public class EditModeTestInternalBrainTensorGenerator
{
    /// <summary>
    /// Disposes the Academy if a previous test left it initialized, so every test
    /// in this fixture starts without an existing Academy instance.
    /// </summary>
    [SetUp]
    public void SetUp()
    {
        if (Academy.IsInitialized)
        {
            Academy.Instance.Dispose();
        }
    }

    /// <summary>
    /// Creates a GameObject carrying a <c>BehaviorParameters</c> (vector observation size 3,
    /// one stacked observation) and a <c>TestAgent</c>. The agent is NOT initialized here.
    /// </summary>
    /// <param name="name">Name given to the new GameObject.</param>
    /// <param name="observableAttributeOptions">Value assigned to ObservableAttributeHandling.</param>
    /// <returns>The TestAgent component added to the new GameObject.</returns>
    static TestAgent CreateFakeAgent(string name, ObservableAttributeOptions observableAttributeOptions)
    {
        var gameObject = new GameObject(name);

        // BehaviorParameters is added and configured before the agent component,
        // matching the original setup order.
        var behaviorParameters = gameObject.AddComponent<BehaviorParameters>();
        behaviorParameters.BrainParameters.VectorObservationSize = 3;
        behaviorParameters.BrainParameters.NumStackedVectorObservations = 1;
        behaviorParameters.ObservableAttributeHandling = observableAttributeOptions;

        return gameObject.AddComponent<TestAgent>();
    }

    /// <summary>
    /// Builds two initialized test agents ("goA" and "goB") with known observations,
    /// stored discrete actions and action masks.
    /// </summary>
    /// <param name="observableAttributeOptions">
    /// Value assigned to ObservableAttributeHandling on both agents' BehaviorParameters.
    /// </param>
    /// <returns>
    /// A list holding agent A (observation (1, 2, 3), stored discrete actions {1, 2}, no mask)
    /// followed by agent B (observation (4, 5, 6), stored discrete actions {3, 4}, and a
    /// five-entry mask whose first entry is true).
    /// </returns>
    static List<TestAgent> GetFakeAgents(ObservableAttributeOptions observableAttributeOptions = ObservableAttributeOptions.Ignore)
    {
        var agentA = CreateFakeAgent("goA", observableAttributeOptions);
        var agentB = CreateFakeAgent("goB", observableAttributeOptions);

        var agents = new List<TestAgent> { agentA, agentB };
        foreach (var agent in agents)
        {
            agent.LazyInitialize();
        }

        // Observations are written only after both agents have been initialized.
        agentA.collectObservationsSensor.AddObservation(new Vector3(1, 2, 3));
        agentB.collectObservationsSensor.AddObservation(new Vector3(4, 5, 6));

        agentA._Info = new AgentInfo
        {
            storedActions = new ActionBuffers(null, new[] { 1, 2 }),
            discreteActionMasks = null,
        };
        agentB._Info = new AgentInfo
        {
            storedActions = new ActionBuffers(null, new[] { 3, 4 }),
            discreteActionMasks = new[] { true, false, false, false, false },
        };

        return agents;
    }

    /// <summary>
    /// Pairs each agent's <c>_Info</c> with its sensors, preserving the order of
    /// <paramref name="agents"/>, in the form the generators' Generate method takes.
    /// </summary>
    /// <param name="agents">Agents to convert.</param>
    /// <returns>One AgentInfoSensorsPair per agent.</returns>
    static List<AgentInfoSensorsPair> ToInfoSensorPairs(List<TestAgent> agents)
    {
        var pairs = new List<AgentInfoSensorsPair>(agents.Count);
        foreach (var agent in agents)
        {
            pairs.Add(new AgentInfoSensorsPair { agentInfo = agent._Info, sensors = agent.sensors });
        }
        return pairs;
    }

    /// <summary>
    /// A TensorGenerator can be constructed from a seed of 0, an allocator and an empty memory map.
    /// </summary>
    [Test]
    public void Construction()
    {
        var alloc = new TensorCachingAllocator();
        try
        {
            var mem = new Dictionary<int, List<float>>();
            var tensorGenerator = new TensorGenerator(0, alloc, mem);
            Assert.IsNotNull(tensorGenerator);
        }
        finally
        {
            // Dispose even when an assertion throws, so a failing test does not leak the allocator.
            alloc.Dispose();
        }
    }

    /// <summary>
    /// BatchSizeGenerator writes the batch size into the first element of the tensor.
    /// </summary>
    [Test]
    public void GenerateBatchSize()
    {
        const int batchSize = 4;
        var inputTensor = new TensorProxy();
        var alloc = new TensorCachingAllocator();
        try
        {
            var generator = new BatchSizeGenerator(alloc);
            generator.Generate(inputTensor, batchSize, null);

            Assert.IsNotNull(inputTensor.data);
            // NUnit argument order is (expected, actual).
            Assert.AreEqual(batchSize, inputTensor.data[0]);
        }
        finally
        {
            alloc.Dispose();
        }
    }

    /// <summary>
    /// SequenceLengthGenerator writes 1 into the first element of the tensor.
    /// </summary>
    [Test]
    public void GenerateSequenceLength()
    {
        const int batchSize = 4;
        var inputTensor = new TensorProxy();
        var alloc = new TensorCachingAllocator();
        try
        {
            var generator = new SequenceLengthGenerator(alloc);
            generator.Generate(inputTensor, batchSize, null);

            Assert.IsNotNull(inputTensor.data);
            Assert.AreEqual(1, inputTensor.data[0]);
        }
        finally
        {
            alloc.Dispose();
        }
    }

    /// <summary>
    /// ObservationGenerator concatenates the registered sensors per agent: column 0 is
    /// expected to hold the observable attribute and columns 1..3 the vector observation.
    /// </summary>
    [Test]
    public void GenerateVectorObservation()
    {
        const int batchSize = 4;
        var inputTensor = new TensorProxy
        {
            shape = new long[] { 2, 4 }
        };
        var agents = GetFakeAgents(ObservableAttributeOptions.ExamineAll);
        var alloc = new TensorCachingAllocator();
        try
        {
            var generator = new ObservationGenerator(alloc);
            generator.AddSensorIndex(0); // ObservableAttribute (size 1)
            generator.AddSensorIndex(1); // TestSensor (size 0)
            generator.AddSensorIndex(2); // TestSensor (size 0)
            generator.AddSensorIndex(3); // VectorSensor (size 3)

            generator.Generate(inputTensor, batchSize, ToInfoSensorPairs(agents));

            Assert.IsNotNull(inputTensor.data);
            // First and last component of each agent's Vector3 observation.
            Assert.AreEqual(1, inputTensor.data[0, 1]);
            Assert.AreEqual(3, inputTensor.data[0, 3]);
            Assert.AreEqual(4, inputTensor.data[1, 1]);
            Assert.AreEqual(6, inputTensor.data[1, 3]);
        }
        finally
        {
            alloc.Dispose();
        }
    }

    /// <summary>
    /// PreviousActionInputGenerator copies each agent's stored discrete actions into its row.
    /// </summary>
    [Test]
    public void GeneratePreviousActionInput()
    {
        const int batchSize = 4;
        var inputTensor = new TensorProxy
        {
            shape = new long[] { 2, 2 },
            valueType = TensorProxy.TensorType.Integer
        };
        var agents = GetFakeAgents();
        var alloc = new TensorCachingAllocator();
        try
        {
            var generator = new PreviousActionInputGenerator(alloc);
            generator.Generate(inputTensor, batchSize, ToInfoSensorPairs(agents));

            Assert.IsNotNull(inputTensor.data);
            Assert.AreEqual(1, inputTensor.data[0, 0]);
            Assert.AreEqual(2, inputTensor.data[0, 1]);
            Assert.AreEqual(3, inputTensor.data[1, 0]);
            Assert.AreEqual(4, inputTensor.data[1, 1]);
        }
        finally
        {
            alloc.Dispose();
        }
    }

    /// <summary>
    /// ActionMaskInputGenerator is expected to emit 1 for every entry when the agent has no
    /// mask (agent A), and 0 for an entry whose mask flag is true (agent B, entry 0).
    /// </summary>
    [Test]
    public void GenerateActionMaskInput()
    {
        const int batchSize = 4;
        var inputTensor = new TensorProxy
        {
            shape = new long[] { 2, 5 },
            valueType = TensorProxy.TensorType.FloatingPoint
        };
        var agents = GetFakeAgents();
        var alloc = new TensorCachingAllocator();
        try
        {
            var generator = new ActionMaskInputGenerator(alloc);
            generator.Generate(inputTensor, batchSize, ToInfoSensorPairs(agents));

            Assert.IsNotNull(inputTensor.data);
            Assert.AreEqual(1, inputTensor.data[0, 0]);
            Assert.AreEqual(1, inputTensor.data[0, 4]);
            Assert.AreEqual(0, inputTensor.data[1, 0]);
            Assert.AreEqual(1, inputTensor.data[1, 4]);
        }
        finally
        {
            alloc.Dispose();
        }
    }
}
}
|