|
#if MLA_UNITY_PHYSICS_MODULE |
|
using System.Collections.Generic; |
|
using System.Reflection; |
|
using NUnit.Framework; |
|
using UnityEngine; |
|
using Unity.MLAgents.Sensors; |
|
|
|
namespace Unity.MLAgents.Tests |
|
{ |
|
internal class TestBoxOverlapChecker : BoxOverlapChecker |
|
{ |
|
public TestBoxOverlapChecker( |
|
Vector3 cellScale, |
|
Vector3Int gridSize, |
|
bool rotateWithAgent, |
|
LayerMask colliderMask, |
|
GameObject centerObject, |
|
GameObject agentGameObject, |
|
string[] detectableTags, |
|
int initialColliderBufferSize, |
|
int maxColliderBufferSize |
|
) : base( |
|
cellScale, |
|
gridSize, |
|
rotateWithAgent, |
|
colliderMask, |
|
centerObject, |
|
agentGameObject, |
|
detectableTags, |
|
initialColliderBufferSize, |
|
maxColliderBufferSize) |
|
{} |
|
|
|
public Vector3[] CellLocalPositions |
|
{ |
|
get |
|
{ |
|
return (Vector3[])typeof(BoxOverlapChecker).GetField("m_CellLocalPositions", |
|
BindingFlags.Instance | BindingFlags.NonPublic).GetValue(this); |
|
} |
|
} |
|
|
|
public Collider[] ColliderBuffer |
|
{ |
|
get |
|
{ |
|
return (Collider[])typeof(BoxOverlapChecker).GetField("m_ColliderBuffer", |
|
BindingFlags.Instance | BindingFlags.NonPublic).GetValue(this); |
|
} |
|
} |
|
|
|
public static TestBoxOverlapChecker CreateChecker( |
|
float cellScaleX = 1f, |
|
float cellScaleZ = 1f, |
|
int gridSizeX = 10, |
|
int gridSizeZ = 10, |
|
bool rotateWithAgent = true, |
|
GameObject centerObject = null, |
|
GameObject agentGameObject = null, |
|
string[] detectableTags = null, |
|
int initialColliderBufferSize = 4, |
|
int maxColliderBufferSize = 500) |
|
{ |
|
return new TestBoxOverlapChecker( |
|
new Vector3(cellScaleX, 0.01f, cellScaleZ), |
|
new Vector3Int(gridSizeX, 1, gridSizeZ), |
|
rotateWithAgent, |
|
LayerMask.GetMask("Default"), |
|
centerObject, |
|
agentGameObject, |
|
detectableTags, |
|
initialColliderBufferSize, |
|
maxColliderBufferSize); |
|
} |
|
} |
|
|
|
public class BoxOverlapCheckerTests |
|
{ |
|
[Test] |
|
public void TestCellLocalPosition() |
|
{ |
|
var testGo = new GameObject("test"); |
|
testGo.transform.position = Vector3.zero; |
|
var boxOverlapSquare = TestBoxOverlapChecker.CreateChecker(gridSizeX: 10, gridSizeZ: 10, rotateWithAgent: false, agentGameObject: testGo); |
|
|
|
var localPos = boxOverlapSquare.CellLocalPositions; |
|
Assert.AreEqual(new Vector3(-4.5f, 0, -4.5f), localPos[0]); |
|
Assert.AreEqual(new Vector3(-4.5f, 0, 4.5f), localPos[9]); |
|
Assert.AreEqual(new Vector3(4.5f, 0, -4.5f), localPos[90]); |
|
Assert.AreEqual(new Vector3(4.5f, 0, 4.5f), localPos[99]); |
|
Object.DestroyImmediate(testGo); |
|
|
|
var testGo2 = new GameObject("test"); |
|
testGo2.transform.position = new Vector3(3.5f, 8f, 17f); |
|
var boxOverlapRect = TestBoxOverlapChecker.CreateChecker(gridSizeX: 5, gridSizeZ: 15, rotateWithAgent: true, agentGameObject: testGo); |
|
|
|
localPos = boxOverlapRect.CellLocalPositions; |
|
Assert.AreEqual(new Vector3(-2f, 0, -7f), localPos[0]); |
|
Assert.AreEqual(new Vector3(-2f, 0, 7f), localPos[14]); |
|
Assert.AreEqual(new Vector3(2f, 0, -7f), localPos[60]); |
|
Assert.AreEqual(new Vector3(2f, 0, 7f), localPos[74]); |
|
Object.DestroyImmediate(testGo2); |
|
} |
|
|
|
[Test] |
|
public void TestCellGlobalPositionNoRotate() |
|
{ |
|
var testGo = new GameObject("test"); |
|
var position = new Vector3(3.5f, 8f, 17f); |
|
testGo.transform.position = position; |
|
var boxOverlap = TestBoxOverlapChecker.CreateChecker(gridSizeX: 10, gridSizeZ: 10, rotateWithAgent: false, agentGameObject: testGo, centerObject: testGo); |
|
|
|
Assert.AreEqual(new Vector3(-4.5f, 0, -4.5f) + position, boxOverlap.GetCellGlobalPosition(0)); |
|
Assert.AreEqual(new Vector3(-4.5f, 0, 4.5f) + position, boxOverlap.GetCellGlobalPosition(9)); |
|
Assert.AreEqual(new Vector3(4.5f, 0, -4.5f) + position, boxOverlap.GetCellGlobalPosition(90)); |
|
Assert.AreEqual(new Vector3(4.5f, 0, 4.5f) + position, boxOverlap.GetCellGlobalPosition(99)); |
|
|
|
testGo.transform.Rotate(0, 90, 0); |
|
Assert.AreEqual(new Vector3(-4.5f, 0, -4.5f) + position, boxOverlap.GetCellGlobalPosition(0)); |
|
Assert.AreEqual(new Vector3(-4.5f, 0, 4.5f) + position, boxOverlap.GetCellGlobalPosition(9)); |
|
Assert.AreEqual(new Vector3(4.5f, 0, -4.5f) + position, boxOverlap.GetCellGlobalPosition(90)); |
|
Assert.AreEqual(new Vector3(4.5f, 0, 4.5f) + position, boxOverlap.GetCellGlobalPosition(99)); |
|
|
|
Object.DestroyImmediate(testGo); |
|
} |
|
|
|
[Test] |
|
public void TestCellGlobalPositionRotate() |
|
{ |
|
var testGo = new GameObject("test"); |
|
var position = new Vector3(15f, 6f, 13f); |
|
testGo.transform.position = position; |
|
var boxOverlap = TestBoxOverlapChecker.CreateChecker(gridSizeX: 5, gridSizeZ: 15, rotateWithAgent: true, agentGameObject: testGo, centerObject: testGo); |
|
|
|
Assert.AreEqual(new Vector3(-2f, 0, -7f) + position, boxOverlap.GetCellGlobalPosition(0)); |
|
Assert.AreEqual(new Vector3(-2f, 0, 7f) + position, boxOverlap.GetCellGlobalPosition(14)); |
|
Assert.AreEqual(new Vector3(2f, 0, -7f) + position, boxOverlap.GetCellGlobalPosition(60)); |
|
Assert.AreEqual(new Vector3(2f, 0, 7f) + position, boxOverlap.GetCellGlobalPosition(74)); |
|
|
|
testGo.transform.Rotate(0, 90, 0); |
|
|
|
Assert.AreEqual(Vector3Int.RoundToInt(new Vector3(-7f, 0, 2f) + position), Vector3Int.RoundToInt(boxOverlap.GetCellGlobalPosition(0))); |
|
Assert.AreEqual(Vector3Int.RoundToInt(new Vector3(7f, 0, 2f) + position), Vector3Int.RoundToInt(boxOverlap.GetCellGlobalPosition(14))); |
|
Assert.AreEqual(Vector3Int.RoundToInt(new Vector3(-7f, 0, -2f) + position), Vector3Int.RoundToInt(boxOverlap.GetCellGlobalPosition(60))); |
|
Assert.AreEqual(Vector3Int.RoundToInt(new Vector3(7f, 0, -2f) + position), Vector3Int.RoundToInt(boxOverlap.GetCellGlobalPosition(74))); |
|
|
|
Object.DestroyImmediate(testGo); |
|
} |
|
|
|
[Test] |
|
public void TestBufferResize() |
|
{ |
|
List<GameObject> testObjects = new List<GameObject>(); |
|
var testGo = new GameObject("test"); |
|
testGo.transform.position = Vector3.zero; |
|
testObjects.Add(testGo); |
|
var boxOverlap = TestBoxOverlapChecker.CreateChecker(agentGameObject: testGo, centerObject: testGo, initialColliderBufferSize: 2, maxColliderBufferSize: 5); |
|
boxOverlap.Perceive(); |
|
Assert.AreEqual(2, boxOverlap.ColliderBuffer.Length); |
|
|
|
for (var i = 0; i < 3; i++) |
|
{ |
|
var boxGo = new GameObject("test"); |
|
boxGo.transform.position = Vector3.zero; |
|
boxGo.AddComponent<BoxCollider>(); |
|
testObjects.Add(boxGo); |
|
} |
|
boxOverlap.Perceive(); |
|
Assert.AreEqual(4, boxOverlap.ColliderBuffer.Length); |
|
|
|
for (var i = 0; i < 2; i++) |
|
{ |
|
var boxGo = new GameObject("test"); |
|
boxGo.transform.position = Vector3.zero; |
|
boxGo.AddComponent<BoxCollider>(); |
|
testObjects.Add(boxGo); |
|
} |
|
boxOverlap.Perceive(); |
|
Assert.AreEqual(5, boxOverlap.ColliderBuffer.Length); |
|
|
|
Object.DestroyImmediate(testGo); |
|
foreach (var go in testObjects) |
|
{ |
|
Object.DestroyImmediate(go); |
|
} |
|
} |
|
|
|
[Test] |
|
public void TestParseCollidersClosest() |
|
{ |
|
var tag1 = "Player"; |
|
List<GameObject> testObjects = new List<GameObject>(); |
|
var testGo = new GameObject("test"); |
|
testGo.transform.position = Vector3.zero; |
|
var boxOverlap = TestBoxOverlapChecker.CreateChecker( |
|
cellScaleX: 10f, |
|
cellScaleZ: 10f, |
|
gridSizeX: 2, |
|
gridSizeZ: 2, |
|
agentGameObject: testGo, |
|
centerObject: testGo, |
|
detectableTags: new[] { tag1 }); |
|
var helper = new VerifyParseCollidersHelper(); |
|
boxOverlap.GridOverlapDetectedClosest += helper.DetectedAction; |
|
|
|
for (var i = 0; i < 3; i++) |
|
{ |
|
var boxGo = new GameObject("test"); |
|
boxGo.transform.position = new Vector3(i + 1, 0, 1); |
|
boxGo.AddComponent<BoxCollider>(); |
|
boxGo.tag = tag1; |
|
testObjects.Add(boxGo); |
|
} |
|
|
|
boxOverlap.Perceive(); |
|
helper.Verify(1, new List<GameObject> { testObjects[0] }); |
|
|
|
Object.DestroyImmediate(testGo); |
|
foreach (var go in testObjects) |
|
{ |
|
Object.DestroyImmediate(go); |
|
} |
|
} |
|
|
|
[Test] |
|
public void TestParseCollidersAll() |
|
{ |
|
var tag1 = "Player"; |
|
List<GameObject> testObjects = new List<GameObject>(); |
|
var testGo = new GameObject("test"); |
|
testGo.transform.position = Vector3.zero; |
|
var boxOverlap = TestBoxOverlapChecker.CreateChecker( |
|
cellScaleX: 10f, |
|
cellScaleZ: 10f, |
|
gridSizeX: 2, |
|
gridSizeZ: 2, |
|
agentGameObject: testGo, |
|
centerObject: testGo, |
|
detectableTags: new[] { tag1 }); |
|
var helper = new VerifyParseCollidersHelper(); |
|
boxOverlap.GridOverlapDetectedAll += helper.DetectedAction; |
|
|
|
for (var i = 0; i < 3; i++) |
|
{ |
|
var boxGo = new GameObject("test"); |
|
boxGo.transform.position = new Vector3(i + 1, 0, 1); |
|
boxGo.AddComponent<BoxCollider>(); |
|
boxGo.tag = tag1; |
|
testObjects.Add(boxGo); |
|
} |
|
|
|
boxOverlap.Perceive(); |
|
helper.Verify(3, testObjects); |
|
|
|
Object.DestroyImmediate(testGo); |
|
foreach (var go in testObjects) |
|
{ |
|
Object.DestroyImmediate(go); |
|
} |
|
} |
|
|
|
public class VerifyParseCollidersHelper |
|
{ |
|
int m_NumInvoked; |
|
List<GameObject> m_ParsedObjects = new List<GameObject>(); |
|
|
|
public void DetectedAction(GameObject go, int cellIndex) |
|
{ |
|
m_NumInvoked += 1; |
|
m_ParsedObjects.Add(go); |
|
} |
|
|
|
public void Verify(int expectNumInvoke, List<GameObject> expectedObjects) |
|
{ |
|
Assert.AreEqual(expectNumInvoke, m_NumInvoked); |
|
Assert.AreEqual(expectedObjects.Count, m_ParsedObjects.Count); |
|
foreach (var obj in expectedObjects) |
|
{ |
|
Assert.Contains(obj, m_ParsedObjects); |
|
} |
|
} |
|
} |
|
|
|
[Test] |
|
public void TestOnlyOneChecker() |
|
{ |
|
var testGo = new GameObject("test"); |
|
testGo.transform.position = Vector3.zero; |
|
var gridSensorComponent = testGo.AddComponent<SimpleTestGridSensorComponent>(); |
|
gridSensorComponent.SetComponentParameters(useGridSensorBase: true, useTestingGridSensor: true); |
|
var sensors = gridSensorComponent.CreateSensors(); |
|
int numChecker = 0; |
|
foreach (var sensor in sensors) |
|
{ |
|
var gridsensor = (GridSensorBase)sensor; |
|
if (gridsensor.m_GridPerception != null) |
|
{ |
|
numChecker += 1; |
|
} |
|
} |
|
Assert.AreEqual(1, numChecker); |
|
} |
|
} |
|
} |
|
#endif |
|
|