ppo-Pyramids-Training
/
com.unity.ml-agents
/Tests
/Editor
/Integrations
/Match3
/Match3SensorTests.cs
using System.Collections.Generic; | |
using System.IO; | |
using System.Reflection; | |
using NUnit.Framework; | |
using UnityEngine; | |
using Unity.MLAgents.Integrations.Match3; | |
using Unity.MLAgents.Sensors; | |
namespace Unity.MLAgents.Tests.Integrations.Match3 | |
{ | |
public class Match3SensorTests | |
{ | |
// Whether the expected PNG data should be written to a file. | |
// Only set this to true if the compressed observation format changes. | |
private bool WritePNGDataToFile = false; | |
private const string k_CellObservationPng = "match3obs_"; | |
private const string k_SpecialObservationPng = "match3obs_special_"; | |
private const string k_Suffix2x2 = "2x2_"; | |
[ | ]|
[ | ]|
public void TestVectorObservations(bool fullBoard) | |
{ | |
var boardString = | |
@"000 | |
000 | |
010"; | |
var gameObj = new GameObject("board"); | |
var board = gameObj.AddComponent<StringBoard>(); | |
board.SetBoard(boardString); | |
if (!fullBoard) | |
{ | |
board.CurrentRows = 2; | |
board.CurrentColumns = 2; | |
} | |
var sensorComponent = gameObj.AddComponent<Match3SensorComponent>(); | |
sensorComponent.ObservationType = Match3ObservationType.Vector; | |
var sensor = sensorComponent.CreateSensors()[0]; | |
var expectedShape = new InplaceArray<int>(3 * 3 * 2); | |
Assert.AreEqual(expectedShape, sensor.GetObservationSpec().Shape); | |
float[] expectedObs; | |
if (fullBoard) | |
{ | |
expectedObs = new float[] | |
{ | |
1, 0, /* 0 */ 0, 1, /* 1 */ 1, 0, /* 0 */ | |
1, 0, /* 0 */ 1, 0, /* 0 */ 1, 0, /* 0 */ | |
1, 0, /* 0 */ 1, 0, /* 0 */ 1, 0, /* 0 */ | |
}; | |
} | |
else | |
{ | |
expectedObs = new float[] | |
{ | |
1, 0, /* 0 */ 0, 1, /* 1 */ 0, 0, /* empty */ | |
1, 0, /* 0 */ 1, 0, /* 0 */ 0, 0, /* empty */ | |
0, 0, /* empty */ 0, 0, /* empty */ 0, 0, /* empty */ | |
}; | |
} | |
SensorTestHelper.CompareObservation(sensor, expectedObs); | |
} | |
[ | ]|
public void TestVectorObservationsSpecial() | |
{ | |
var boardString = | |
@"000 | |
000 | |
010"; | |
var specialString = | |
@"010 | |
200 | |
000"; | |
var gameObj = new GameObject("board"); | |
var board = gameObj.AddComponent<StringBoard>(); | |
board.SetBoard(boardString); | |
board.SetSpecial(specialString); | |
var sensorComponent = gameObj.AddComponent<Match3SensorComponent>(); | |
sensorComponent.ObservationType = Match3ObservationType.Vector; | |
var sensors = sensorComponent.CreateSensors(); | |
var cellSensor = sensors[0]; | |
var specialSensor = sensors[1]; | |
{ | |
var expectedShape = new InplaceArray<int>(3 * 3 * 2); | |
Assert.AreEqual(expectedShape, cellSensor.GetObservationSpec().Shape); | |
var expectedObs = new float[] | |
{ | |
1, 0, /* (0) */ 0, 1, /* (1) */ 1, 0, /* (0) */ | |
1, 0, /* (0) */ 1, 0, /* (0) */ 1, 0, /* (0) */ | |
1, 0, /* (0) */ 1, 0, /* (0) */ 1, 0, /* (0) */ | |
}; | |
SensorTestHelper.CompareObservation(cellSensor, expectedObs); | |
} | |
{ | |
var expectedShape = new InplaceArray<int>(3 * 3 * 3); | |
Assert.AreEqual(expectedShape, specialSensor.GetObservationSpec().Shape); | |
var expectedObs = new float[] | |
{ | |
1, 0, 0, /* (0) */ 1, 0, 0, /* (1) */ 1, 0, 0, /* (0) */ | |
0, 0, 1, /* (2) */ 1, 0, 0, /* (0) */ 1, 0, 0, /* (0) */ | |
1, 0, 0, /* (0) */ 0, 1, 0, /* (1) */ 1, 0, 0, /* (0) */ | |
}; | |
SensorTestHelper.CompareObservation(specialSensor, expectedObs); | |
} | |
} | |
[ | ]|
[ | ]|
public void TestVisualObservations(bool fullBoard) | |
{ | |
var boardString = | |
@"000 | |
000 | |
010"; | |
var gameObj = new GameObject("board"); | |
var board = gameObj.AddComponent<StringBoard>(); | |
board.SetBoard(boardString); | |
if (!fullBoard) | |
{ | |
board.CurrentRows = 2; | |
board.CurrentColumns = 2; | |
} | |
var sensorComponent = gameObj.AddComponent<Match3SensorComponent>(); | |
sensorComponent.ObservationType = Match3ObservationType.UncompressedVisual; | |
var sensor = sensorComponent.CreateSensors()[0]; | |
var expectedShape = new InplaceArray<int>(3, 3, 2); | |
Assert.AreEqual(expectedShape, sensor.GetObservationSpec().Shape); | |
Assert.AreEqual(SensorCompressionType.None, sensor.GetCompressionSpec().SensorCompressionType); | |
float[] expectedObs; | |
float[,,] expectedObs3D; | |
if (fullBoard) | |
{ | |
expectedObs = new float[] | |
{ | |
1, 0, /**/ 0, 1, /**/ 1, 0, | |
1, 0, /**/ 1, 0, /**/ 1, 0, | |
1, 0, /**/ 1, 0, /**/ 1, 0, | |
}; | |
expectedObs3D = new float[,,] | |
{ | |
{{1, 0}, {0, 1}, {1, 0}}, | |
{{1, 0}, {1, 0}, {1, 0}}, | |
{{1, 0}, {1, 0}, {1, 0}}, | |
}; | |
} | |
else | |
{ | |
expectedObs = new float[] | |
{ | |
1, 0, /* 0 */ 0, 1, /* 1 */ 0, 0, /* empty */ | |
1, 0, /* 0 */ 1, 0, /* 0 */ 0, 0, /* empty */ | |
0, 0, /* empty */ 0, 0, /* empty */ 0, 0, /* empty */ | |
}; | |
expectedObs3D = new float[,,] | |
{ | |
{{1, 0}, {0, 1}, {0, 0}}, | |
{{1, 0}, {1, 0}, {0, 0}}, | |
{{0, 0}, {0, 0}, {0, 0}}, | |
}; | |
} | |
SensorTestHelper.CompareObservation(sensor, expectedObs); | |
SensorTestHelper.CompareObservation(sensor, expectedObs3D); | |
} | |
[ | ]|
public void TestVisualObservationsSpecial() | |
{ | |
var boardString = | |
@"000 | |
000 | |
010"; | |
var specialString = | |
@"010 | |
200 | |
000"; | |
var gameObj = new GameObject("board"); | |
var board = gameObj.AddComponent<StringBoard>(); | |
board.SetBoard(boardString); | |
board.SetSpecial(specialString); | |
var sensorComponent = gameObj.AddComponent<Match3SensorComponent>(); | |
sensorComponent.ObservationType = Match3ObservationType.UncompressedVisual; | |
var sensors = sensorComponent.CreateSensors(); | |
var cellSensor = sensors[0]; | |
var specialSensor = sensors[1]; | |
{ | |
var expectedShape = new InplaceArray<int>(3, 3, 2); | |
Assert.AreEqual(expectedShape, cellSensor.GetObservationSpec().Shape); | |
Assert.AreEqual(SensorCompressionType.None, cellSensor.GetCompressionSpec().SensorCompressionType); | |
var expectedObs = new float[] | |
{ | |
1, 0, /* (0) */ 0, 1, /* (1) */ 1, 0, /* (0) */ | |
1, 0, /* (0) */ 1, 0, /* (0) */ 1, 0, /* (0) */ | |
1, 0, /* (0) */ 1, 0, /* (0) */ 1, 0, /* (0) */ | |
}; | |
SensorTestHelper.CompareObservation(cellSensor, expectedObs); | |
var expectedObs3D = new float[,,] | |
{ | |
{{1, 0}, {0, 1}, {1, 0}}, | |
{{1, 0}, {1, 0}, {1, 0}}, | |
{{1, 0}, {1, 0}, {1, 0}}, | |
}; | |
SensorTestHelper.CompareObservation(cellSensor, expectedObs3D); | |
} | |
{ | |
var expectedShape = new InplaceArray<int>(3, 3, 3); | |
Assert.AreEqual(expectedShape, specialSensor.GetObservationSpec().Shape); | |
Assert.AreEqual(SensorCompressionType.None, specialSensor.GetCompressionSpec().SensorCompressionType); | |
var expectedObs = new float[] | |
{ | |
1, 0, 0, /* (0) */ 1, 0, 0, /* (1) */ 1, 0, 0, /* (0) */ | |
0, 0, 1, /* (2) */ 1, 0, 0, /* (0) */ 1, 0, 0, /* (0) */ | |
1, 0, 0, /* (0) */ 0, 1, 0, /* (1) */ 1, 0, 0, /* (0) */ | |
}; | |
SensorTestHelper.CompareObservation(specialSensor, expectedObs); | |
var expectedObs3D = new float[,,] | |
{ | |
{{1, 0, 0}, {1, 0, 0}, {1, 0, 0}}, | |
{{0, 0, 1}, {1, 0, 0}, {1, 0, 0}}, | |
{{1, 0, 0}, {0, 1, 0}, {1, 0, 0}}, | |
}; | |
SensorTestHelper.CompareObservation(specialSensor, expectedObs3D); | |
} | |
// Test that Dispose() cleans up the component and its sensors | |
sensorComponent.Dispose(); | |
var flags = BindingFlags.Instance | BindingFlags.NonPublic; | |
var componentSensors = (ISensor[])typeof(Match3SensorComponent).GetField("m_Sensors", flags).GetValue(sensorComponent); | |
Assert.IsNull(componentSensors); | |
var cellTexture = (Texture2D)typeof(Match3Sensor).GetField("m_ObservationTexture", flags).GetValue(cellSensor); | |
Assert.IsNull(cellTexture); | |
var specialTexture = (Texture2D)typeof(Match3Sensor).GetField("m_ObservationTexture", flags).GetValue(cellSensor); | |
Assert.IsNull(specialTexture); | |
} | |
[ | ]|
[ | ]|
[ | ]|
[ | ]|
public void TestCompressedVisualObservationsSpecial(bool fullBoard, bool useSpecial) | |
{ | |
var boardString = | |
@"003 | |
000 | |
010"; | |
var specialString = | |
@"014 | |
200 | |
000"; | |
var gameObj = new GameObject("board"); | |
var board = gameObj.AddComponent<StringBoard>(); | |
board.SetBoard(boardString); | |
var paths = new List<string> { k_CellObservationPng }; | |
if (useSpecial) | |
{ | |
board.SetSpecial(specialString); | |
paths.Add(k_SpecialObservationPng); | |
} | |
if (!fullBoard) | |
{ | |
// Shrink the board, and change the paths we're using for the ground truth PNGs | |
board.CurrentRows = 2; | |
board.CurrentColumns = 2; | |
for (var i = 0; i < paths.Count; i++) | |
{ | |
paths[i] = paths[i] + k_Suffix2x2; | |
} | |
} | |
var sensorComponent = gameObj.AddComponent<Match3SensorComponent>(); | |
sensorComponent.ObservationType = Match3ObservationType.CompressedVisual; | |
var sensors = sensorComponent.CreateSensors(); | |
var expectedNumChannels = new[] { 4, 5 }; | |
for (var i = 0; i < paths.Count; i++) | |
{ | |
var sensor = sensors[i]; | |
var expectedShape = new InplaceArray<int>(3, 3, expectedNumChannels[i]); | |
Assert.AreEqual(expectedShape, sensor.GetObservationSpec().Shape); | |
Assert.AreEqual(SensorCompressionType.PNG, sensor.GetCompressionSpec().SensorCompressionType); | |
var pngData = sensor.GetCompressedObservation(); | |
if (WritePNGDataToFile) | |
{ | |
// Enable this if the format of the observation changes | |
SavePNGs(pngData, paths[i]); | |
} | |
var expectedPng = LoadPNGs(paths[i], 2); | |
Assert.AreEqual(expectedPng, pngData); | |
} | |
} | |
/// <summary> | |
/// Helper method for un-concatenating PNG observations. | |
/// </summary> | |
/// <param name="concatenated"></param> | |
/// <returns></returns> | |
List<byte[]> SplitPNGs(byte[] concatenated) | |
{ | |
var pngsOut = new List<byte[]>(); | |
var pngHeader = new byte[] { 137, 80, 78, 71, 13, 10, 26, 10 }; | |
var current = new List<byte>(); | |
for (var i = 0; i < concatenated.Length; i++) | |
{ | |
current.Add(concatenated[i]); | |
// Check if the header starts at the next position | |
// If so, we'll start a new output array. | |
var headerIsNext = false; | |
if (i + 1 < concatenated.Length - pngHeader.Length) | |
{ | |
for (var j = 0; j < pngHeader.Length; j++) | |
{ | |
if (concatenated[i + 1 + j] != pngHeader[j]) | |
{ | |
break; | |
} | |
if (j == pngHeader.Length - 1) | |
{ | |
headerIsNext = true; | |
} | |
} | |
} | |
if (headerIsNext) | |
{ | |
pngsOut.Add(current.ToArray()); | |
current = new List<byte>(); | |
} | |
} | |
pngsOut.Add(current.ToArray()); | |
return pngsOut; | |
} | |
void SavePNGs(byte[] concatenatedPngData, string pathPrefix) | |
{ | |
var splitPngs = SplitPNGs(concatenatedPngData); | |
for (var i = 0; i < splitPngs.Count; i++) | |
{ | |
var pngData = splitPngs[i]; | |
var path = $"Packages/com.unity.ml-agents/Tests/Editor/Integrations/Match3/{pathPrefix}{i}.png"; | |
using (var sw = File.Create(path)) | |
{ | |
foreach (var b in pngData) | |
{ | |
sw.WriteByte(b); | |
} | |
} | |
} | |
} | |
byte[] LoadPNGs(string pathPrefix, int numExpected) | |
{ | |
var bytesOut = new List<byte>(); | |
for (var i = 0; i < numExpected; i++) | |
{ | |
var path = $"Packages/com.unity.ml-agents/Tests/Editor/Integrations/Match3/{pathPrefix}{i}.png"; | |
var res = File.ReadAllBytes(path); | |
bytesOut.AddRange(res); | |
} | |
return bytesOut.ToArray(); | |
} | |
[ | ]|
public void TestNoBoardReturnsEmptySensors() | |
{ | |
var gameObj = new GameObject("board"); | |
var sensorComponent = gameObj.AddComponent<Match3SensorComponent>(); | |
var sensors = sensorComponent.CreateSensors(); | |
Assert.AreEqual(0, sensors.Length); | |
} | |
} | |
} | |