File size: 12,901 Bytes
05c9ac2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 |
using System;
using System.Collections.Generic;
using Unity.MLAgents.Sensors;
using UnityEngine;
namespace Unity.MLAgents.Integrations.Match3
{
/// <summary>
/// Delegate that provides integer values at a given (x,y) coordinate.
/// The callers in this file always pass the row index as the first argument
/// and the column index as the second.
/// </summary>
/// <param name="x">First coordinate; the callers in this file pass the row index.</param>
/// <param name="y">Second coordinate; the callers in this file pass the column index.</param>
/// <returns>
/// The integer value of the cell. The callers in this file use it as the index of the
/// "hot" entry in a one-hot encoding.
/// </returns>
public delegate int GridValueProvider(int x, int y);
/// <summary>
/// Type of observations to generate. Selects between a flat vector observation and a
/// Rows x Columns visual observation, and whether the visual observation is PNG-compressed.
/// </summary>
public enum Match3ObservationType
{
/// <summary>
/// Generate a one-hot encoding of the cell type for each cell on the board. If there are special types,
/// these will also be one-hot encoded. The result is a single vector observation.
/// </summary>
Vector,
/// <summary>
/// Generate a one-hot encoding of the cell type for each cell on the board, but arranged as
/// a Rows x Columns visual observation. If there are special types, these will also be one-hot encoded.
/// No compression is applied.
/// </summary>
UncompressedVisual,
/// <summary>
/// Generate a one-hot encoding of the cell type for each cell on the board, but arranged as
/// a Rows x Columns visual observation. If there are special types, these will also be one-hot encoded.
/// During training, these will be sent as a concatenated series of PNG images, with 3 channels per image.
/// </summary>
CompressedVisual
}
/// <summary>
/// Sensor for Match3 games. Can generate either vector, compressed visual,
/// or uncompressed visual observations. Uses a GridValueProvider to determine the observation values.
/// </summary>
/// <remarks>
/// Observations are always sized for the board's maximum size; cells outside the current
/// board size are written as all zeros.
/// </remarks>
public class Match3Sensor : ISensor, IBuiltInSensor, IDisposable
{
    // Everything below except the texture and its helper is fixed at construction time.
    readonly Match3ObservationType m_ObservationType;
    readonly ObservationSpec m_ObservationSpec;
    readonly string m_Name;

    readonly AbstractBoard m_Board;
    readonly BoardSize m_MaxBoardSize;
    readonly GridValueProvider m_GridValues;
    readonly int m_OneHotSize;

    // Created lazily on the first call to GetCompressedObservation().
    Texture2D m_ObservationTexture;
    OneHotToTextureUtil m_TextureUtil;

    /// <summary>
    /// Create a sensor for the GridValueProvider with the specified observation type.
    /// </summary>
    /// <remarks>
    /// Use Match3Sensor.CellTypeSensor() or Match3Sensor.SpecialTypeSensor() instead of calling
    /// the constructor directly.
    /// </remarks>
    /// <param name="board">The abstract board.</param>
    /// <param name="gvp">The GridValueProvider, should be either board.GetCellType or board.GetSpecialType.</param>
    /// <param name="oneHotSize">The number of possible values that the GridValueProvider can return.</param>
    /// <param name="obsType">Whether to produce vector or visual observations</param>
    /// <param name="name">Name of the sensor.</param>
    public Match3Sensor(AbstractBoard board, GridValueProvider gvp, int oneHotSize, Match3ObservationType obsType, string name)
    {
        var maxBoardSize = board.GetMaxBoardSize();

        m_Name = name;
        m_MaxBoardSize = maxBoardSize;
        m_GridValues = gvp;
        m_OneHotSize = oneHotSize;
        m_Board = board;

        m_ObservationType = obsType;
        // Vector observations are flat (rows * columns * oneHotSize); both visual
        // variants use a (rows, columns, oneHotSize) shape.
        m_ObservationSpec = obsType == Match3ObservationType.Vector
            ? ObservationSpec.Vector(maxBoardSize.Rows * maxBoardSize.Columns * oneHotSize)
            : ObservationSpec.Visual(maxBoardSize.Rows, maxBoardSize.Columns, oneHotSize);
    }

    /// <summary>
    /// Create a sensor that encodes the board cells as observations.
    /// </summary>
    /// <param name="board">The abstract board.</param>
    /// <param name="obsType">Whether to produce vector or visual observations</param>
    /// <param name="name">Name of the sensor.</param>
    /// <returns>A sensor that reads board.GetCellType with NumCellTypes one-hot values per cell.</returns>
    public static Match3Sensor CellTypeSensor(AbstractBoard board, Match3ObservationType obsType, string name)
    {
        var maxBoardSize = board.GetMaxBoardSize();
        return new Match3Sensor(board, board.GetCellType, maxBoardSize.NumCellTypes, obsType, name);
    }

    /// <summary>
    /// Create a sensor that encodes the cell special types as observations. Returns null if the board's
    /// NumSpecialTypes is 0 (indicating the sensor isn't needed).
    /// </summary>
    /// <param name="board">The abstract board.</param>
    /// <param name="obsType">Whether to produce vector or visual observations</param>
    /// <param name="name">Name of the sensor.</param>
    /// <returns>
    /// A sensor that reads board.GetSpecialType with NumSpecialTypes + 1 one-hot values per cell,
    /// or null when the board has no special types.
    /// </returns>
    public static Match3Sensor SpecialTypeSensor(AbstractBoard board, Match3ObservationType obsType, string name)
    {
        var maxBoardSize = board.GetMaxBoardSize();
        if (maxBoardSize.NumSpecialTypes == 0)
        {
            return null;
        }

        // NOTE(review): the extra slot is presumably for "no special type" (value 0) —
        // confirm against AbstractBoard.GetSpecialType.
        var specialSize = maxBoardSize.NumSpecialTypes + 1;
        return new Match3Sensor(board, board.GetSpecialType, specialSize, obsType, name);
    }

    /// <inheritdoc/>
    public ObservationSpec GetObservationSpec()
    {
        return m_ObservationSpec;
    }

    /// <summary>
    /// Write the one-hot encoded board into the writer, padding with zeros up to the
    /// maximum board size.
    /// </summary>
    /// <param name="writer">Destination for the observation values.</param>
    /// <returns>The number of values written (max rows * max columns * one-hot size).</returns>
    public int Write(ObservationWriter writer)
    {
        m_Board.CheckBoardSizes(m_MaxBoardSize);
        var currentBoardSize = m_Board.GetCurrentBoardSize();

        int offset = 0;
        var isVisual = m_ObservationType != Match3ObservationType.Vector;

        // This is equivalent to
        // for (var r = 0; r < m_MaxBoardSize.Rows; r++)
        //     for (var c = 0; c < m_MaxBoardSize.Columns; c++)
        //          if (r < currentBoardSize.Rows && c < currentBoardSize.Columns)
        //              WriteOneHot
        //          else
        //              WriteZero
        // but rearranged to avoid the branching.

        // Rows that exist on the current board: real cells first, then zero padding
        // for the columns beyond the current width.
        for (var r = 0; r < currentBoardSize.Rows; r++)
        {
            for (var c = 0; c < currentBoardSize.Columns; c++)
            {
                var val = m_GridValues(r, c);
                writer.WriteOneHot(offset, r, c, val, m_OneHotSize, isVisual);
                offset += m_OneHotSize;
            }

            for (var c = currentBoardSize.Columns; c < m_MaxBoardSize.Columns; c++)
            {
                writer.WriteZero(offset, r, c, m_OneHotSize, isVisual);
                offset += m_OneHotSize;
            }
        }

        // Rows beyond the current board height are entirely zero. The bound must be the
        // maximum number of ROWS: bounding by Columns over- or under-fills the observation
        // whenever the maximum board is not square.
        for (var r = currentBoardSize.Rows; r < m_MaxBoardSize.Rows; r++)
        {
            for (var c = 0; c < m_MaxBoardSize.Columns; c++)
            {
                writer.WriteZero(offset, r, c, m_OneHotSize, isVisual);
                offset += m_OneHotSize;
            }
        }

        return offset;
    }

    /// <summary>
    /// Encode the board as a concatenated series of PNG images, three one-hot
    /// channels per image.
    /// </summary>
    /// <returns>The bytes of all PNG images, appended one after another.</returns>
    public byte[] GetCompressedObservation()
    {
        m_Board.CheckBoardSizes(m_MaxBoardSize);
        var height = m_MaxBoardSize.Rows;
        var width = m_MaxBoardSize.Columns;

        // Lazily create the texture and helper so they are only allocated when compressed
        // observations are actually requested.
        // NOTE(review): ReferenceEquals is presumably used to bypass UnityEngine.Object's
        // overloaded equality operator — confirm.
        if (ReferenceEquals(null, m_ObservationTexture))
        {
            m_ObservationTexture = new Texture2D(width, height, TextureFormat.RGB24, false);
        }

        if (ReferenceEquals(null, m_TextureUtil))
        {
            m_TextureUtil = new OneHotToTextureUtil(height, width);
        }

        var bytesOut = new List<byte>();
        var currentBoardSize = m_Board.GetCurrentBoardSize();

        // Encode the cell types or special types as batches of PNGs
        // This is potentially wasteful, e.g. if there are 4 cell types and 1 special type, we could
        // fit it in 2 images, but we'll use 3 total (2 PNGs for the 4 cell type channels, and 1 for
        // the special types).
        // Integer ceiling of m_OneHotSize / 3.
        var numCellImages = (m_OneHotSize + 2) / 3;
        for (var i = 0; i < numCellImages; i++)
        {
            m_TextureUtil.EncodeToTexture(
                m_GridValues,
                m_ObservationTexture,
                3 * i,
                currentBoardSize.Rows,
                currentBoardSize.Columns
            );
            bytesOut.AddRange(m_ObservationTexture.EncodeToPNG());
        }

        return bytesOut.ToArray();
    }

    /// <inheritdoc/>
    public void Update()
    {
    }

    /// <inheritdoc/>
    public void Reset()
    {
    }

    /// <summary>
    /// Map the observation type to the compression used when sending observations.
    /// </summary>
    /// <returns>PNG for CompressedVisual, None otherwise.</returns>
    internal SensorCompressionType GetCompressionType()
    {
        return m_ObservationType == Match3ObservationType.CompressedVisual ?
            SensorCompressionType.PNG :
            SensorCompressionType.None;
    }

    /// <inheritdoc/>
    public CompressionSpec GetCompressionSpec()
    {
        return new CompressionSpec(GetCompressionType());
    }

    /// <inheritdoc/>
    public string GetName()
    {
        return m_Name;
    }

    /// <inheritdoc/>
    public BuiltInSensorType GetBuiltInSensorType()
    {
        return BuiltInSensorType.Match3Sensor;
    }

    /// <summary>
    /// Clean up the owned Texture2D.
    /// </summary>
    public void Dispose()
    {
        if (!ReferenceEquals(null, m_ObservationTexture))
        {
            Utilities.DestroyTexture(m_ObservationTexture);
            // Clear the reference so a second Dispose() is a no-op and a later
            // GetCompressedObservation() recreates the texture.
            m_ObservationTexture = null;
        }
    }
}
/// <summary>
/// Helper that turns a grid of one-hot indices into texture pixels so that the
/// texture can be encoded as a PNG observation.
/// Each call covers three consecutive one-hot values (one per colour channel), so it
/// should be called (maxValue + 2) / 3 times, increasing the channelOffset by 3 each time.
/// </summary>
internal class OneHotToTextureUtil
{
    Color[] m_Colors;
    int m_MaxHeight;
    int m_MaxWidth;
    private static Color[] s_OneHotColors = { Color.red, Color.green, Color.blue };

    /// <summary>
    /// Allocate a reusable pixel buffer for a board of at most maxHeight x maxWidth cells.
    /// </summary>
    /// <param name="maxHeight">Maximum number of rows.</param>
    /// <param name="maxWidth">Maximum number of columns.</param>
    public OneHotToTextureUtil(int maxHeight, int maxWidth)
    {
        m_MaxHeight = maxHeight;
        m_MaxWidth = maxWidth;
        m_Colors = new Color[maxHeight * maxWidth];
    }

    /// <summary>
    /// Fill the texture with one pixel per cell: red, green or blue when the cell's value is
    /// channelOffset, channelOffset + 1 or channelOffset + 2 respectively, black otherwise.
    /// Cells outside the current board size are always black.
    /// </summary>
    /// <param name="gridValueProvider">Source of the per-cell value, called as (row, column).</param>
    /// <param name="texture">Texture that receives the pixels.</param>
    /// <param name="channelOffset">First one-hot value represented by this image.</param>
    /// <param name="currentHeight">Number of rows currently in use.</param>
    /// <param name="currentWidth">Number of columns currently in use.</param>
    public void EncodeToTexture(
        GridValueProvider gridValueProvider,
        Texture2D texture,
        int channelOffset,
        int currentHeight,
        int currentWidth
    )
    {
        var pixelIndex = 0;

        // Rows are visited bottom-up to cancel out the vertical flip that happens
        // when the texture is converted to PNG.
        for (var row = m_MaxHeight - 1; row >= 0; row--)
        {
            var rowOnBoard = row < currentHeight;
            for (var col = 0; col < m_MaxWidth; col++)
            {
                Color pixel;
                if (rowOnBoard && col < currentWidth)
                {
                    pixel = ColorForValue(gridValueProvider(row, col), channelOffset);
                }
                else
                {
                    pixel = Color.black;
                }

                m_Colors[pixelIndex] = pixel;
                pixelIndex++;
            }
        }

        texture.SetPixels(m_Colors);
    }

    /// <summary>
    /// Pick the channel colour for a value, or black when the value falls outside the
    /// three values starting at channelOffset.
    /// </summary>
    static Color ColorForValue(int oneHotValue, int channelOffset)
    {
        if (oneHotValue < channelOffset || oneHotValue >= channelOffset + 3)
        {
            return Color.black;
        }

        return s_OneHotColors[oneHotValue - channelOffset];
    }
}
/// <summary>
/// Extension helpers on ObservationWriter for emitting one cell's worth of
/// one-hot (or all-zero) values, in either visual or flat vector layout.
/// </summary>
internal static class ObservationWriterMatch3Extensions
{
    /// <summary>
    /// Write oneHotSize values for a single cell: 1 at index <paramref name="value"/>, 0 elsewhere.
    /// If value is outside [0, oneHotSize), every entry is 0.
    /// </summary>
    /// <param name="writer">Destination writer.</param>
    /// <param name="offset">Start index in the flat vector; only used when isVisual is false.</param>
    /// <param name="row">Cell row; only used when isVisual is true.</param>
    /// <param name="col">Cell column; only used when isVisual is true.</param>
    /// <param name="value">Index of the entry to set to 1.</param>
    /// <param name="oneHotSize">Number of entries to write.</param>
    /// <param name="isVisual">True to index by (row, col, channel), false to index by flat offset.</param>
    public static void WriteOneHot(this ObservationWriter writer, int offset, int row, int col, int value, int oneHotSize, bool isVisual)
    {
        for (var channel = 0; channel < oneHotSize; channel++)
        {
            var entry = (channel == value) ? 1.0f : 0.0f;
            if (isVisual)
            {
                writer[row, col, channel] = entry;
            }
            else
            {
                writer[offset + channel] = entry;
            }
        }
    }

    /// <summary>
    /// Write oneHotSize zeros for a single cell.
    /// </summary>
    /// <param name="writer">Destination writer.</param>
    /// <param name="offset">Start index in the flat vector; only used when isVisual is false.</param>
    /// <param name="row">Cell row; only used when isVisual is true.</param>
    /// <param name="col">Cell column; only used when isVisual is true.</param>
    /// <param name="oneHotSize">Number of entries to write.</param>
    /// <param name="isVisual">True to index by (row, col, channel), false to index by flat offset.</param>
    public static void WriteZero(this ObservationWriter writer, int offset, int row, int col, int oneHotSize, bool isVisual)
    {
        for (var channel = 0; channel < oneHotSize; channel++)
        {
            if (isVisual)
            {
                writer[row, col, channel] = 0.0f;
            }
            else
            {
                writer[offset + channel] = 0.0f;
            }
        }
    }
}
}
|