File size: 6,272 Bytes
05c9ac2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 |
using System.Collections.Generic;
using NUnit.Framework;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Integrations.Match3;
using UnityEngine;
namespace Unity.MLAgents.Tests.Integrations.Match3
{
/// <summary>
/// Minimal <see cref="AbstractBoard"/> test double. Every cell type and special type is 0,
/// the result of both <see cref="IsMoveValid"/> and <see cref="MakeMove"/> is controlled by
/// <see cref="MovesAreValid"/>, and the index of the last move made is recorded.
/// </summary>
internal class SimpleBoard : AbstractBoard
{
    // Dimensions reported by GetMaxBoardSize(); tests assign these directly.
    public int Rows;
    public int Columns;
    public int NumCellTypes;
    public int NumSpecialTypes;

    // MoveIndex of the most recent Move passed to MakeMove().
    public int LastMoveIndex;

    // Value returned by both IsMoveValid() and MakeMove().
    public bool MovesAreValid = true;

    // Becomes true once Callback() has been invoked.
    public bool CallbackCalled;

    /// <summary>Builds a <see cref="BoardSize"/> from the public dimension fields.</summary>
    public override BoardSize GetMaxBoardSize() => new BoardSize
    {
        Rows = Rows,
        Columns = Columns,
        NumCellTypes = NumCellTypes,
        NumSpecialTypes = NumSpecialTypes,
    };

    /// <summary>Every cell reports type 0.</summary>
    public override int GetCellType(int row, int col) => 0;

    /// <summary>Every cell reports special type 0.</summary>
    public override int GetSpecialType(int row, int col) => 0;

    /// <summary>Reports <see cref="MovesAreValid"/> for any move.</summary>
    public override bool IsMoveValid(Move m) => MovesAreValid;

    /// <summary>Records the move's index, then reports <see cref="MovesAreValid"/>.</summary>
    public override bool MakeMove(Move m)
    {
        LastMoveIndex = m.MoveIndex;
        return MovesAreValid;
    }

    /// <summary>Flags that this method ran; tests assign it as a board callback.</summary>
    public void Callback() => CallbackCalled = true;
}
/// <summary>
/// Tests for <see cref="Match3ActuatorComponent"/> and the actuator it creates.
/// </summary>
public class Match3ActuatorTests
{
    /// <summary>
    /// Disposes any Academy left initialized by a previous test so each test starts clean.
    /// </summary>
    [SetUp]
    public void SetUp()
    {
        if (Academy.IsInitialized)
        {
            Academy.Instance.Dispose();
        }
    }

    /// <summary>
    /// Steps an agent driving a <see cref="SimpleBoard"/> whose moves are either all valid or all
    /// invalid. The board's OnNoValidMovesAction callback must fire only in the all-invalid case.
    /// In both cases LastMoveIndex must be overwritten, meaning a move was still made.
    /// </summary>
    [TestCase(true)]
    [TestCase(false)]
    public void TestValidMoves(bool movesAreValid)
    {
        // Check that a board with no valid moves doesn't raise an exception.
        var gameObj = new GameObject();
        var board = gameObj.AddComponent<SimpleBoard>();
        var agent = gameObj.AddComponent<Agent>();
        gameObj.AddComponent<Match3ActuatorComponent>();

        board.Rows = 5;
        board.Columns = 5;
        board.NumCellTypes = 5;
        board.NumSpecialTypes = 0;
        board.MovesAreValid = movesAreValid;
        board.OnNoValidMovesAction = board.Callback;
        // Sentinel: any real move index written by MakeMove differs from -1.
        board.LastMoveIndex = -1;

        agent.LazyInitialize();
        agent.RequestDecision();
        Academy.Instance.EnvironmentStep();

        // The callback is expected exactly when the board reports no valid moves.
        Assert.AreEqual(!movesAreValid, board.CallbackCalled);
        Assert.AreNotEqual(-1, board.LastMoveIndex);
    }

    /// <summary>
    /// A board with a valid size yields a single discrete branch with one action per possible move.
    /// </summary>
    [Test]
    public void TestActionSpec()
    {
        var gameObj = new GameObject();
        var board = gameObj.AddComponent<SimpleBoard>();
        var actuatorComponent = gameObj.AddComponent<Match3ActuatorComponent>();

        board.Rows = 5;
        board.Columns = 5;
        board.NumCellTypes = 5;
        board.NumSpecialTypes = 0;

        var actionSpec = actuatorComponent.ActionSpec;
        Assert.AreEqual(1, actionSpec.NumDiscreteActions);
        Assert.AreEqual(board.NumMoves(), actionSpec.BranchSizes[0]);
    }

    /// <summary>
    /// Without a board on the GameObject, the component reports an empty action spec.
    /// </summary>
    [Test]
    public void TestActionSpecNullBoard()
    {
        var gameObj = new GameObject();
        var actuatorComponent = gameObj.AddComponent<Match3ActuatorComponent>();

        var actionSpec = actuatorComponent.ActionSpec;
        Assert.AreEqual(0, actionSpec.NumDiscreteActions);
        Assert.AreEqual(0, actionSpec.NumContinuousActions);
    }

    /// <summary>
    /// <see cref="IDiscreteActionMask"/> that records, per discrete branch, the set of action
    /// indices currently disabled (masked out).
    /// </summary>
    public class HashSetActionMask : IDiscreteActionMask
    {
        // One set per discrete branch; each holds the indices that are currently disabled.
        public HashSet<int>[] HashSets;

        /// <summary>Creates one empty set for each discrete branch in <paramref name="spec"/>.</summary>
        public HashSetActionMask(ActionSpec spec)
        {
            HashSets = new HashSet<int>[spec.NumDiscreteActions];
            for (var i = 0; i < spec.NumDiscreteActions; i++)
            {
                HashSets[i] = new HashSet<int>();
            }
        }

        /// <summary>
        /// Enabling an action removes it from the branch's disabled set; disabling adds it.
        /// </summary>
        public void SetActionEnabled(int branch, int actionIndex, bool isEnabled)
        {
            var hashSet = HashSets[branch];
            if (isEnabled)
            {
                hashSet.Remove(actionIndex);
            }
            else
            {
                hashSet.Add(actionIndex);
            }
        }
    }

    /// <summary>
    /// The actuator's discrete action mask must disable exactly the moves that are not valid.
    /// The valid indices and masked indices are disjoint, and together they cover every possible move.
    /// The "Small Board" case shrinks the current row count below the maximum board size.
    /// </summary>
    [TestCase(true, TestName = "Full Board")]
    [TestCase(false, TestName = "Small Board")]
    public void TestMasking(bool fullBoard)
    {
        var gameObj = new GameObject("board");
        var board = gameObj.AddComponent<StringBoard>();
        var boardString =
            @"0105
1024
0203
2022";
        board.SetBoard(boardString);
        // Captured before shrinking, so it still describes the maximum board size.
        var boardSize = board.GetMaxBoardSize();
        if (!fullBoard)
        {
            board.CurrentRows -= 1;
        }
        var validMoves = AbstractBoardTests.GetValidMoves4x4(fullBoard, boardSize);

        var actuatorComponent = gameObj.AddComponent<Match3ActuatorComponent>();
        var actuator = actuatorComponent.CreateActuators()[0];

        var masks = new HashSetActionMask(actuator.ActionSpec);
        actuator.WriteDiscreteActionMask(masks);

        // Run through all moves and make sure those are the only valid ones
        var validIndices = new HashSet<int>();
        foreach (var m in validMoves)
        {
            validIndices.Add(m.MoveIndex);
        }

        // Valid moves and masked moves should be disjoint
        Assert.IsFalse(validIndices.Overlaps(masks.HashSets[0]));
        // And they should add up to all the potential moves (expected value first, per NUnit convention)
        Assert.AreEqual(board.NumMoves(), validIndices.Count + masks.HashSets[0].Count);
    }

    /// <summary>
    /// Without a board on the GameObject, CreateActuators returns an empty array.
    /// </summary>
    [Test]
    public void TestNoBoardReturnsEmptyActuators()
    {
        var gameObj = new GameObject("board");
        var actuatorComponent = gameObj.AddComponent<Match3ActuatorComponent>();
        var actuators = actuatorComponent.CreateActuators();
        Assert.AreEqual(0, actuators.Length);
    }
}
}
|