ppo-Pyramids-Training
/
com.unity.ml-agents
/Tests
/Editor
/Integrations
/Match3
/Match3ActuatorTests.cs
using System.Collections.Generic; | |
using NUnit.Framework; | |
using Unity.MLAgents.Actuators; | |
using Unity.MLAgents.Integrations.Match3; | |
using UnityEngine; | |
namespace Unity.MLAgents.Tests.Integrations.Match3 | |
{ | |
/// <summary>
/// Minimal <see cref="AbstractBoard"/> used by the actuator tests. Board dimensions and
/// move validity are driven entirely by public fields so each test can configure them directly.
/// </summary>
internal class SimpleBoard : AbstractBoard
{
    /// <summary>Row count reported by <see cref="GetMaxBoardSize"/>.</summary>
    public int Rows;

    /// <summary>Column count reported by <see cref="GetMaxBoardSize"/>.</summary>
    public int Columns;

    /// <summary>Number of cell types reported by <see cref="GetMaxBoardSize"/>.</summary>
    public int NumCellTypes;

    /// <summary>Number of special types reported by <see cref="GetMaxBoardSize"/>.</summary>
    public int NumSpecialTypes;

    /// <summary>Move index of the last move passed to <see cref="MakeMove"/>.</summary>
    public int LastMoveIndex;

    /// <summary>Value returned by both <see cref="IsMoveValid"/> and <see cref="MakeMove"/>.</summary>
    public bool MovesAreValid = true;

    /// <summary>Set to true once <see cref="Callback"/> has been invoked.</summary>
    public bool CallbackCalled;

    /// <summary>
    /// Builds a <see cref="BoardSize"/> from the current values of the public fields.
    /// </summary>
    /// <returns>A board size mirroring Rows, Columns, NumCellTypes and NumSpecialTypes.</returns>
    public override BoardSize GetMaxBoardSize()
    {
        return new BoardSize
        {
            Rows = Rows,
            Columns = Columns,
            NumCellTypes = NumCellTypes,
            NumSpecialTypes = NumSpecialTypes
        };
    }

    /// <summary>Always reports cell type 0, regardless of position.</summary>
    public override int GetCellType(int row, int col)
    {
        return 0;
    }

    /// <summary>Always reports special type 0, regardless of position.</summary>
    public override int GetSpecialType(int row, int col)
    {
        return 0;
    }

    /// <summary>Reports every move as valid or invalid according to <see cref="MovesAreValid"/>.</summary>
    public override bool IsMoveValid(Move m)
    {
        return MovesAreValid;
    }

    /// <summary>
    /// Records the index of the requested move in <see cref="LastMoveIndex"/> so tests can
    /// verify that a move was attempted.
    /// </summary>
    /// <param name="m">The move being applied.</param>
    /// <returns>The current value of <see cref="MovesAreValid"/>.</returns>
    public override bool MakeMove(Move m)
    {
        LastMoveIndex = m.MoveIndex;
        return MovesAreValid;
    }

    /// <summary>Marks that the callback ran by setting <see cref="CallbackCalled"/>.</summary>
    public void Callback()
    {
        CallbackCalled = true;
    }
}
/// <summary>
/// Tests for <see cref="Match3ActuatorComponent"/>: action spec sizing, action masking,
/// and behavior when no board is attached or the board has no valid moves.
/// </summary>
public class Match3ActuatorTests
{
    /// <summary>
    /// Disposes any already-initialized Academy so each test starts from a fresh instance.
    /// </summary>
    [SetUp]
    public void SetUp()
    {
        if (Academy.IsInitialized)
        {
            Academy.Instance.Dispose();
        }
    }

    /// <summary>
    /// Steps the environment once with a board whose moves are all valid or all invalid, and
    /// checks that the no-valid-moves callback fires only in the latter case.
    /// </summary>
    /// <param name="movesAreValid">Whether the board reports its moves as valid.</param>
    [TestCase(true)]
    [TestCase(false)]
    public void TestValidMoves(bool movesAreValid)
    {
        // Check that a board with no valid moves doesn't raise an exception.
        var gameObj = new GameObject();
        var board = gameObj.AddComponent<SimpleBoard>();
        var agent = gameObj.AddComponent<Agent>();
        gameObj.AddComponent<Match3ActuatorComponent>();

        board.Rows = 5;
        board.Columns = 5;
        board.NumCellTypes = 5;
        board.NumSpecialTypes = 0;

        board.MovesAreValid = movesAreValid;
        board.OnNoValidMovesAction = board.Callback;
        // Sentinel: SimpleBoard.MakeMove overwrites this, so -1 afterwards means no move was made.
        board.LastMoveIndex = -1;

        agent.LazyInitialize();
        agent.RequestDecision();
        Academy.Instance.EnvironmentStep();

        if (movesAreValid)
        {
            Assert.IsFalse(board.CallbackCalled);
        }
        else
        {
            Assert.IsTrue(board.CallbackCalled);
        }

        // A move must have been attempted in both cases.
        Assert.AreNotEqual(-1, board.LastMoveIndex);
    }

    /// <summary>
    /// With a board attached, the action spec has a single discrete branch whose size equals
    /// the board's number of moves.
    /// </summary>
    [Test]
    public void TestActionSpec()
    {
        var gameObj = new GameObject();
        var board = gameObj.AddComponent<SimpleBoard>();
        var actuator = gameObj.AddComponent<Match3ActuatorComponent>();

        board.Rows = 5;
        board.Columns = 5;
        board.NumCellTypes = 5;
        board.NumSpecialTypes = 0;

        var actionSpec = actuator.ActionSpec;
        Assert.AreEqual(1, actionSpec.NumDiscreteActions);
        Assert.AreEqual(board.NumMoves(), actionSpec.BranchSizes[0]);
    }

    /// <summary>
    /// Without a board component, the action spec reports no discrete and no continuous actions.
    /// </summary>
    [Test]
    public void TestActionSpecNullBoard()
    {
        var gameObj = new GameObject();
        var actuator = gameObj.AddComponent<Match3ActuatorComponent>();
        var actionSpec = actuator.ActionSpec;
        Assert.AreEqual(0, actionSpec.NumDiscreteActions);
        Assert.AreEqual(0, actionSpec.NumContinuousActions);
    }

    /// <summary>
    /// <see cref="IDiscreteActionMask"/> implementation that records, per branch, the set of
    /// action indices that have been disabled.
    /// </summary>
    public class HashSetActionMask : IDiscreteActionMask
    {
        /// <summary>One set of disabled action indices per discrete branch.</summary>
        public HashSet<int>[] HashSets;

        /// <summary>Creates one empty set for each discrete branch in <paramref name="spec"/>.</summary>
        /// <param name="spec">Action spec whose NumDiscreteActions determines the number of sets.</param>
        public HashSetActionMask(ActionSpec spec)
        {
            HashSets = new HashSet<int>[spec.NumDiscreteActions];
            for (var i = 0; i < spec.NumDiscreteActions; i++)
            {
                HashSets[i] = new HashSet<int>();
            }
        }

        /// <summary>
        /// Adds the action to the branch's disabled set when <paramref name="isEnabled"/> is
        /// false, and removes it when true.
        /// </summary>
        /// <param name="branch">Index of the discrete branch.</param>
        /// <param name="actionIndex">Index of the action within the branch.</param>
        /// <param name="isEnabled">Whether the action should be enabled.</param>
        public void SetActionEnabled(int branch, int actionIndex, bool isEnabled)
        {
            var hashSet = HashSets[branch];
            if (isEnabled)
            {
                hashSet.Remove(actionIndex);
            }
            else
            {
                hashSet.Add(actionIndex);
            }
        }
    }

    /// <summary>
    /// Checks that the moves masked by the actuator and the expected valid moves are disjoint
    /// and together account for every potential move on the board.
    /// </summary>
    /// <param name="fullBoard">When false, the board's current row count is reduced by one.</param>
    [TestCase(true)]
    [TestCase(false)]
    public void TestMasking(bool fullBoard)
    {
        var gameObj = new GameObject("board");
        var board = gameObj.AddComponent<StringBoard>();

        // Continuation lines are flush-left so the literal carries no leading whitespace.
        // NOTE(review): confirm StringBoard.SetBoard tolerates indentation before re-indenting.
        var boardString =
            @"0105
1024
0203
2022";
        board.SetBoard(boardString);
        var boardSize = board.GetMaxBoardSize();
        if (!fullBoard)
        {
            board.CurrentRows -= 1;
        }

        var validMoves = AbstractBoardTests.GetValidMoves4x4(fullBoard, boardSize);

        var actuatorComponent = gameObj.AddComponent<Match3ActuatorComponent>();
        var actuator = actuatorComponent.CreateActuators()[0];

        var masks = new HashSetActionMask(actuator.ActionSpec);
        actuator.WriteDiscreteActionMask(masks);

        // Run through all moves and make sure those are the only valid ones
        HashSet<int> validIndices = new HashSet<int>();
        foreach (var m in validMoves)
        {
            validIndices.Add(m.MoveIndex);
        }

        // Valid moves and masked moves should be disjoint
        Assert.IsFalse(validIndices.Overlaps(masks.HashSets[0]));
        // And they should add up to all the potential moves
        Assert.AreEqual(validIndices.Count + masks.HashSets[0].Count, board.NumMoves());
    }

    /// <summary>
    /// Without a board component, <c>CreateActuators</c> returns an empty array.
    /// </summary>
    [Test]
    public void TestNoBoardReturnsEmptyActuators()
    {
        var gameObj = new GameObject("board");
        var actuatorComponent = gameObj.AddComponent<Match3ActuatorComponent>();
        var actuators = actuatorComponent.CreateActuators();
        Assert.AreEqual(0, actuators.Length);
    }
}
} | |