|
using System; |
|
using System.Collections.Generic; |
|
using Unity.MLAgents.Sensors; |
|
using UnityEngine; |
|
|
|
namespace Unity.MLAgents.Integrations.Match3 |
|
{ |
|
|
|
|
|
|
|
|
|
|
|
/// <summary>
/// Callback that returns the integer value stored at a board position. The Match-3 sensor
/// one-hot encodes this value for each cell it observes.
/// </summary>
/// <remarks>
/// The sensor code invokes this as (row, column), so <paramref name="x"/> is the row index
/// and <paramref name="y"/> is the column index, despite the parameter names.
/// </remarks>
/// <param name="x">Row index of the cell.</param>
/// <param name="y">Column index of the cell.</param>
/// <returns>
/// The value to one-hot encode. In the vector/visual write path, a value outside
/// [0, oneHotSize) produces an all-zero encoding for that cell.
/// </returns>
public delegate int GridValueProvider(int x, int y);
|
|
|
|
|
|
|
|
|
|
|
/// <summary>
/// Selects how a Match-3 sensor exposes its one-hot encoded board.
/// </summary>
public enum Match3ObservationType
{
    /// <summary>
    /// A flat vector of length maxRows * maxColumns * oneHotSize.
    /// </summary>
    Vector = 0,

    /// <summary>
    /// A (maxRows, maxColumns, oneHotSize) visual observation, sent without compression.
    /// </summary>
    UncompressedVisual = 1,

    /// <summary>
    /// A (maxRows, maxColumns, oneHotSize) visual observation, sent as concatenated PNG images
    /// that each carry three one-hot channels.
    /// </summary>
    CompressedVisual = 2,
}
|
|
|
|
|
|
|
|
|
|
|
/// <summary>
/// Sensor that observes a Match-3 board by one-hot encoding an integer value for every cell.
/// </summary>
/// <remarks>
/// The observation always covers the board's maximum size (taken from
/// <c>AbstractBoard.GetMaxBoardSize()</c> at construction). Cells that lie inside the maximum
/// size but outside the current board size are written as all-zero one-hot vectors.
/// The data is exposed as a flat vector, an uncompressed visual observation, or PNG-compressed
/// images, depending on the <see cref="Match3ObservationType"/> passed in.
/// </remarks>
public class Match3Sensor : ISensor, IBuiltInSensor, IDisposable
{
    readonly Match3ObservationType m_ObservationType;
    readonly ObservationSpec m_ObservationSpec;
    readonly string m_Name;

    readonly AbstractBoard m_Board;
    // Board dimensions captured at construction; every observation is padded to this size.
    readonly BoardSize m_MaxBoardSize;
    // Supplies the value to one-hot encode for each (row, column).
    readonly GridValueProvider m_GridValues;
    // Number of one-hot channels per cell.
    readonly int m_OneHotSize;

    // Lazily allocated by GetCompressedObservation; the texture is released in Dispose.
    Texture2D m_ObservationTexture;
    OneHotToTextureUtil m_TextureUtil;

    /// <summary>
    /// Create a sensor that one-hot encodes the values returned by <paramref name="gvp"/>.
    /// </summary>
    /// <param name="board">Board to observe. Its maximum size fixes the observation shape.</param>
    /// <param name="gvp">Callback returning the value to encode for a (row, column) position.</param>
    /// <param name="oneHotSize">Number of one-hot channels per cell.</param>
    /// <param name="obsType">Whether to expose a vector, an uncompressed visual, or a PNG-compressed visual observation.</param>
    /// <param name="name">Sensor name returned by <see cref="GetName"/>.</param>
    public Match3Sensor(AbstractBoard board, GridValueProvider gvp, int oneHotSize, Match3ObservationType obsType, string name)
    {
        var maxBoardSize = board.GetMaxBoardSize();
        m_Name = name;
        m_MaxBoardSize = maxBoardSize;
        m_GridValues = gvp;
        m_OneHotSize = oneHotSize;
        m_Board = board;

        m_ObservationType = obsType;
        // Vector observations flatten rows * columns * channels. Both visual variants use a
        // (rows, columns, channels) shape.
        m_ObservationSpec = obsType == Match3ObservationType.Vector
            ? ObservationSpec.Vector(maxBoardSize.Rows * maxBoardSize.Columns * oneHotSize)
            : ObservationSpec.Visual(maxBoardSize.Rows, maxBoardSize.Columns, oneHotSize);
    }

    /// <summary>
    /// Create a sensor that observes the board's cell types.
    /// </summary>
    /// <param name="board">Board to observe.</param>
    /// <param name="obsType">Observation encoding to use.</param>
    /// <param name="name">Sensor name.</param>
    /// <returns>
    /// A sensor reading <c>board.GetCellType</c>, with one channel per cell type
    /// (<c>NumCellTypes</c> of the maximum board size).
    /// </returns>
    public static Match3Sensor CellTypeSensor(AbstractBoard board, Match3ObservationType obsType, string name)
    {
        var maxBoardSize = board.GetMaxBoardSize();
        return new Match3Sensor(board, board.GetCellType, maxBoardSize.NumCellTypes, obsType, name);
    }

    /// <summary>
    /// Create a sensor that observes the board's special types.
    /// </summary>
    /// <param name="board">Board to observe.</param>
    /// <param name="obsType">Observation encoding to use.</param>
    /// <param name="name">Sensor name.</param>
    /// <returns>
    /// <c>null</c> when the maximum board size declares no special types. Otherwise, a sensor
    /// reading <c>board.GetSpecialType</c> with <c>NumSpecialTypes + 1</c> channels. The extra
    /// channel presumably encodes "no special type" (value 0); confirm against
    /// <c>AbstractBoard.GetSpecialType</c>.
    /// </returns>
    public static Match3Sensor SpecialTypeSensor(AbstractBoard board, Match3ObservationType obsType, string name)
    {
        var maxBoardSize = board.GetMaxBoardSize();
        if (maxBoardSize.NumSpecialTypes == 0)
        {
            return null;
        }
        var specialSize = maxBoardSize.NumSpecialTypes + 1;
        return new Match3Sensor(board, board.GetSpecialType, specialSize, obsType, name);
    }

    /// <inheritdoc/>
    public ObservationSpec GetObservationSpec()
    {
        return m_ObservationSpec;
    }

    /// <summary>
    /// Write the one-hot encoded board into <paramref name="writer"/>.
    /// </summary>
    /// <remarks>
    /// Cells are visited in row-major order over the maximum board size. Cells inside the current
    /// board receive their one-hot value; all remaining cells (extra columns of the current rows,
    /// then the extra rows) are zero-padded.
    /// </remarks>
    /// <param name="writer">Destination for the observation.</param>
    /// <returns>
    /// The number of floats written, which is always maxRows * maxColumns * oneHotSize.
    /// </returns>
    public int Write(ObservationWriter writer)
    {
        m_Board.CheckBoardSizes(m_MaxBoardSize);
        var currentBoardSize = m_Board.GetCurrentBoardSize();

        int offset = 0;
        var isVisual = m_ObservationType != Match3ObservationType.Vector;

        for (var r = 0; r < currentBoardSize.Rows; r++)
        {
            for (var c = 0; c < currentBoardSize.Columns; c++)
            {
                var val = m_GridValues(r, c);
                writer.WriteOneHot(offset, r, c, val, m_OneHotSize, isVisual);
                offset += m_OneHotSize;
            }

            // Zero-pad the columns of this row that lie beyond the current board width.
            for (var c = currentBoardSize.Columns; c < m_MaxBoardSize.Columns; c++)
            {
                writer.WriteZero(offset, r, c, m_OneHotSize, isVisual);
                offset += m_OneHotSize;
            }
        }

        // Zero-pad the rows beyond the current board height. The bound must be the maximum ROW
        // count. Bounding by the column count would under-write non-square boards (stale data,
        // wrong count) or write past the declared observation shape.
        for (var r = currentBoardSize.Rows; r < m_MaxBoardSize.Rows; r++)
        {
            for (var c = 0; c < m_MaxBoardSize.Columns; c++)
            {
                writer.WriteZero(offset, r, c, m_OneHotSize, isVisual);
                offset += m_OneHotSize;
            }
        }

        return offset;
    }

    /// <summary>
    /// Encode the board as concatenated PNG images.
    /// </summary>
    /// <remarks>
    /// Each image covers the maximum board size and carries three consecutive one-hot channels
    /// (starting at 0, 3, 6, ...). The image count is therefore ceil(oneHotSize / 3).
    /// </remarks>
    /// <returns>The bytes of all PNG images, concatenated in channel order.</returns>
    public byte[] GetCompressedObservation()
    {
        m_Board.CheckBoardSizes(m_MaxBoardSize);
        var height = m_MaxBoardSize.Rows;
        var width = m_MaxBoardSize.Columns;
        if (ReferenceEquals(null, m_ObservationTexture))
        {
            m_ObservationTexture = new Texture2D(width, height, TextureFormat.RGB24, false);
        }

        if (ReferenceEquals(null, m_TextureUtil))
        {
            m_TextureUtil = new OneHotToTextureUtil(height, width);
        }
        var bytesOut = new List<byte>();
        var currentBoardSize = m_Board.GetCurrentBoardSize();

        // ceil(m_OneHotSize / 3): each RGB image holds three channels.
        var numCellImages = (m_OneHotSize + 2) / 3;
        for (var i = 0; i < numCellImages; i++)
        {
            m_TextureUtil.EncodeToTexture(
                m_GridValues,
                m_ObservationTexture,
                3 * i,
                currentBoardSize.Rows,
                currentBoardSize.Columns
            );
            bytesOut.AddRange(m_ObservationTexture.EncodeToPNG());
        }

        return bytesOut.ToArray();
    }

    /// <inheritdoc/>
    public void Update()
    {
    }

    /// <inheritdoc/>
    public void Reset()
    {
    }

    /// <summary>
    /// PNG compression for <see cref="Match3ObservationType.CompressedVisual"/>; no compression otherwise.
    /// </summary>
    internal SensorCompressionType GetCompressionType()
    {
        return m_ObservationType == Match3ObservationType.CompressedVisual ?
            SensorCompressionType.PNG :
            SensorCompressionType.None;
    }

    /// <inheritdoc/>
    public CompressionSpec GetCompressionSpec()
    {
        return new CompressionSpec(GetCompressionType());
    }

    /// <inheritdoc/>
    public string GetName()
    {
        return m_Name;
    }

    /// <inheritdoc/>
    public BuiltInSensorType GetBuiltInSensorType()
    {
        return BuiltInSensorType.Match3Sensor;
    }

    /// <summary>
    /// Release the texture allocated by <see cref="GetCompressedObservation"/>, if any.
    /// </summary>
    /// <remarks>
    /// Safe to call more than once. A later call to GetCompressedObservation allocates a new texture.
    /// </remarks>
    public void Dispose()
    {
        if (!ReferenceEquals(null, m_ObservationTexture))
        {
            Utilities.DestroyTexture(m_ObservationTexture);
            m_ObservationTexture = null;
        }
    }
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/// <summary>
/// Paints three consecutive one-hot channels of a grid into a <see cref="Texture2D"/>.
/// Values channelOffset, channelOffset + 1 and channelOffset + 2 become red, green and blue.
/// Every other value, and every cell outside the current board, becomes black.
/// </summary>
internal class OneHotToTextureUtil
{
    // Pixel buffer reused on every call; holds maxHeight * maxWidth entries.
    readonly Color[] m_Colors;
    readonly int m_MaxHeight;
    readonly int m_MaxWidth;
    private static Color[] s_OneHotColors = { Color.red, Color.green, Color.blue };

    /// <param name="maxHeight">Number of pixel rows (board rows) in the target texture.</param>
    /// <param name="maxWidth">Number of pixel columns (board columns) in the target texture.</param>
    public OneHotToTextureUtil(int maxHeight, int maxWidth)
    {
        m_MaxHeight = maxHeight;
        m_MaxWidth = maxWidth;
        m_Colors = new Color[maxHeight * maxWidth];
    }

    /// <summary>
    /// Fill <paramref name="texture"/> with the colors for channels
    /// [<paramref name="channelOffset"/>, <paramref name="channelOffset"/> + 3).
    /// </summary>
    /// <param name="gridValueProvider">Returns the value of the cell at (row, column).</param>
    /// <param name="texture">Texture receiving the pixels via SetPixels.</param>
    /// <param name="channelOffset">First one-hot value mapped to red.</param>
    /// <param name="currentHeight">Rows of the current board; rows at or beyond this are black.</param>
    /// <param name="currentWidth">Columns of the current board; columns at or beyond this are black.</param>
    public void EncodeToTexture(
        GridValueProvider gridValueProvider,
        Texture2D texture,
        int channelOffset,
        int currentHeight,
        int currentWidth
    )
    {
        var pixelIndex = 0;
        // Rows are emitted from the highest index down. This presumably counteracts the
        // bottom-to-top pixel order of Texture2D.SetPixels, so board row 0 ends up at the top
        // of the encoded image; verify against the decoder.
        for (var row = m_MaxHeight - 1; row >= 0; row--)
        {
            for (var col = 0; col < m_MaxWidth; col++)
            {
                var insideBoard = row < currentHeight && col < currentWidth;
                m_Colors[pixelIndex] = insideBoard
                    ? PickColor(gridValueProvider(row, col), channelOffset)
                    : Color.black;
                pixelIndex++;
            }
        }
        texture.SetPixels(m_Colors);
    }

    // Maps a cell value to red/green/blue when it falls in the three-channel window
    // starting at channelOffset; otherwise returns black.
    static Color PickColor(int oneHotValue, int channelOffset)
    {
        var inWindow = oneHotValue >= channelOffset && oneHotValue < channelOffset + 3;
        return inWindow ? s_OneHotColors[oneHotValue - channelOffset] : Color.black;
    }
}
|
|
|
|
|
|
|
|
|
/// <summary>
/// One-hot write helpers for <see cref="ObservationWriter"/> used by the Match-3 sensor.
/// </summary>
internal static class ObservationWriterMatch3Extensions
{
    /// <summary>
    /// Write <paramref name="value"/> as a one-hot vector of <paramref name="oneHotSize"/> floats.
    /// </summary>
    /// <param name="writer">Destination writer.</param>
    /// <param name="offset">Flat start index; used only when <paramref name="isVisual"/> is false.</param>
    /// <param name="row">Row index; used only when <paramref name="isVisual"/> is true.</param>
    /// <param name="col">Column index; used only when <paramref name="isVisual"/> is true.</param>
    /// <param name="value">Channel to set to 1. A value outside [0, oneHotSize) yields all zeros.</param>
    /// <param name="oneHotSize">Number of channels written.</param>
    /// <param name="isVisual">True to address the writer as (row, col, channel); false for flat indices.</param>
    public static void WriteOneHot(this ObservationWriter writer, int offset, int row, int col, int value, int oneHotSize, bool isVisual)
    {
        for (var channel = 0; channel < oneHotSize; channel++)
        {
            var entry = channel == value ? 1.0f : 0.0f;
            if (isVisual)
            {
                writer[row, col, channel] = entry;
            }
            else
            {
                writer[offset + channel] = entry;
            }
        }
    }

    /// <summary>
    /// Write <paramref name="oneHotSize"/> zeros for one cell, addressed the same way as
    /// <see cref="WriteOneHot"/>.
    /// </summary>
    public static void WriteZero(this ObservationWriter writer, int offset, int row, int col, int oneHotSize, bool isVisual)
    {
        // Channel indices start at 0, so -1 never matches and every entry is written as 0.
        writer.WriteOneHot(offset, row, col, -1, oneHotSize, isVisual);
    }
}
|
} |
|
|