File size: 12,901 Bytes
05c9ac2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
using System;
using System.Collections.Generic;
using Unity.MLAgents.Sensors;
using UnityEngine;

namespace Unity.MLAgents.Integrations.Match3
{
    /// <summary>
    /// Delegate that provides integer values at a given (x,y) coordinate.
    /// </summary>
    /// <remarks>
    /// The callers in this file invoke the delegate with a row index as the first argument and a
    /// column index as the second, and use the returned value as the "hot" index of a one-hot encoding.
    /// </remarks>
    /// <param name="x">First coordinate of the cell; callers in this file pass the row index.</param>
    /// <param name="y">Second coordinate of the cell; callers in this file pass the column index.</param>
    /// <returns>The integer value associated with the cell at the given coordinate.</returns>
    public delegate int GridValueProvider(int x, int y);

    /// <summary>
    /// Type of observations to generate. Selects between a flat vector observation and a
    /// Rows x Columns visual observation, and whether the visual observation is PNG-compressed.
    /// </summary>
    public enum Match3ObservationType
    {
        /// <summary>
        /// Generate a one-hot encoding of the cell type for each cell on the board. If there are special types,
        /// these will also be one-hot encoded.
        /// The observation is a flat vector of Rows * Columns * (one-hot size) values.
        /// </summary>
        Vector,

        /// <summary>
        /// Generate a one-hot encoding of the cell type for each cell on the board, but arranged as
        /// a Rows x Columns visual observation. If there are special types, these will also be one-hot encoded.
        /// No compression is applied to the observation.
        /// </summary>
        UncompressedVisual,

        /// <summary>
        /// Generate a one-hot encoding of the cell type for each cell on the board, but arranged as
        /// a Rows x Columns visual observation. If there are special types, these will also be one-hot encoded.
        /// During training, these will be sent as a concatenated series of PNG images, with 3 channels per image.
        /// </summary>
        CompressedVisual
    }

    /// <summary>
    /// Sensor for Match3 games. Can generate either vector, compressed visual,
    /// or uncompressed visual observations. Uses a GridValueProvider to determine the observation values.
    /// </summary>
    /// <remarks>
    /// Observations are always sized for the board's maximum size. Cells that lie outside the
    /// current board size are written as all zeros.
    /// </remarks>
    public class Match3Sensor : ISensor, IBuiltInSensor, IDisposable
    {
        readonly Match3ObservationType m_ObservationType;
        readonly ObservationSpec m_ObservationSpec;
        readonly string m_Name;

        readonly AbstractBoard m_Board;
        readonly BoardSize m_MaxBoardSize;
        readonly GridValueProvider m_GridValues;
        readonly int m_OneHotSize;

        // Both are created lazily on the first call to GetCompressedObservation().
        Texture2D m_ObservationTexture;
        OneHotToTextureUtil m_TextureUtil;

        /// <summary>
        /// Create a sensor for the GridValueProvider with the specified observation type.
        /// </summary>
        /// <remarks>
        /// Use Match3Sensor.CellTypeSensor() or Match3Sensor.SpecialTypeSensor() instead of calling
        /// the constructor directly.
        /// </remarks>
        /// <param name="board">The abstract board.</param>
        /// <param name="gvp">The GridValueProvider, should be either board.GetCellType or board.GetSpecialType.</param>
        /// <param name="oneHotSize">The number of possible values that the GridValueProvider can return.</param>
        /// <param name="obsType">Whether to produce vector or visual observations</param>
        /// <param name="name">Name of the sensor.</param>
        public Match3Sensor(AbstractBoard board, GridValueProvider gvp, int oneHotSize, Match3ObservationType obsType, string name)
        {
            var maxBoardSize = board.GetMaxBoardSize();
            m_Name = name;
            m_MaxBoardSize = maxBoardSize;
            m_GridValues = gvp;
            m_OneHotSize = oneHotSize;
            m_Board = board;

            m_ObservationType = obsType;
            // Both visual observation types share the same (rows, columns, channels) shape;
            // only the vector type is flattened.
            m_ObservationSpec = obsType == Match3ObservationType.Vector
                ? ObservationSpec.Vector(maxBoardSize.Rows * maxBoardSize.Columns * oneHotSize)
                : ObservationSpec.Visual(maxBoardSize.Rows, maxBoardSize.Columns, oneHotSize);
        }

        /// <summary>
        /// Create a sensor that encodes the board cells as observations.
        /// </summary>
        /// <param name="board">The abstract board.</param>
        /// <param name="obsType">Whether to produce vector or visual observations</param>
        /// <param name="name">Name of the sensor.</param>
        /// <returns>A sensor that one-hot encodes board.GetCellType over NumCellTypes values.</returns>
        public static Match3Sensor CellTypeSensor(AbstractBoard board, Match3ObservationType obsType, string name)
        {
            var maxBoardSize = board.GetMaxBoardSize();
            return new Match3Sensor(board, board.GetCellType, maxBoardSize.NumCellTypes, obsType, name);
        }

        /// <summary>
        /// Create a sensor that encodes the cell special types as observations. Returns null if the board's
        /// NumSpecialTypes is 0 (indicating the sensor isn't needed).
        /// </summary>
        /// <param name="board">The abstract board.</param>
        /// <param name="obsType">Whether to produce vector or visual observations</param>
        /// <param name="name">Name of the sensor.</param>
        /// <returns>
        /// A sensor that one-hot encodes board.GetSpecialType over NumSpecialTypes + 1 values,
        /// or null if the board has no special types.
        /// </returns>
        public static Match3Sensor SpecialTypeSensor(AbstractBoard board, Match3ObservationType obsType, string name)
        {
            var maxBoardSize = board.GetMaxBoardSize();
            if (maxBoardSize.NumSpecialTypes == 0)
            {
                return null;
            }
            // One extra slot in addition to the NumSpecialTypes values
            // (presumably for "no special type" - confirm against AbstractBoard.GetSpecialType).
            var specialSize = maxBoardSize.NumSpecialTypes + 1;
            return new Match3Sensor(board, board.GetSpecialType, specialSize, obsType, name);
        }

        /// <inheritdoc/>
        public ObservationSpec GetObservationSpec()
        {
            return m_ObservationSpec;
        }

        /// <summary>
        /// Write the one-hot encoded board into the writer, padding every cell between the current
        /// board size and the maximum board size with zeros.
        /// </summary>
        /// <param name="writer">Destination for the observation values.</param>
        /// <returns>
        /// The number of values written, i.e. max rows * max columns * one-hot size.
        /// </returns>
        public int Write(ObservationWriter writer)
        {
            // NOTE(review): presumably validates the current size against the maximum - confirm in AbstractBoard.
            m_Board.CheckBoardSizes(m_MaxBoardSize);
            var currentBoardSize = m_Board.GetCurrentBoardSize();

            int offset = 0;
            var isVisual = m_ObservationType != Match3ObservationType.Vector;

            // This is equivalent to
            // for (var r = 0; r < m_MaxBoardSize.Rows; r++)
            //     for (var c = 0; c < m_MaxBoardSize.Columns; c++)
            //          if (r < currentBoardSize.Rows && c < currentBoardSize.Columns)
            //              WriteOneHot
            //          else
            //              WriteZero
            // but rearranged to avoid the branching.

            for (var r = 0; r < currentBoardSize.Rows; r++)
            {
                // Cells inside the current board.
                for (var c = 0; c < currentBoardSize.Columns; c++)
                {
                    var val = m_GridValues(r, c);
                    writer.WriteOneHot(offset, r, c, val, m_OneHotSize, isVisual);
                    offset += m_OneHotSize;
                }

                // Columns to the right of the current board, in a row that is in use.
                for (var c = currentBoardSize.Columns; c < m_MaxBoardSize.Columns; c++)
                {
                    writer.WriteZero(offset, r, c, m_OneHotSize, isVisual);
                    offset += m_OneHotSize;
                }
            }

            // Rows below the current board. The bound must be the maximum number of ROWS:
            // using the column count here would skip or overrun rows on non-square boards.
            for (var r = currentBoardSize.Rows; r < m_MaxBoardSize.Rows; r++)
            {
                for (var c = 0; c < m_MaxBoardSize.Columns; c++)
                {
                    writer.WriteZero(offset, r, c, m_OneHotSize, isVisual);
                    offset += m_OneHotSize;
                }
            }

            return offset;
        }

        /// <summary>
        /// Encode the observation as a concatenated series of PNG images, each image carrying
        /// up to three one-hot channels.
        /// </summary>
        /// <returns>The bytes of all PNG images, appended one after another.</returns>
        public byte[] GetCompressedObservation()
        {
            m_Board.CheckBoardSizes(m_MaxBoardSize);
            var height = m_MaxBoardSize.Rows;
            var width = m_MaxBoardSize.Columns;
            // ReferenceEquals performs a plain reference check, bypassing any overloaded
            // equality operator on the texture type.
            if (ReferenceEquals(null, m_ObservationTexture))
            {
                m_ObservationTexture = new Texture2D(width, height, TextureFormat.RGB24, false);
            }

            if (ReferenceEquals(null, m_TextureUtil))
            {
                m_TextureUtil = new OneHotToTextureUtil(height, width);
            }
            var bytesOut = new List<byte>();
            var currentBoardSize = m_Board.GetCurrentBoardSize();

            // Encode the cell types or special types as batches of PNGs
            // This is potentially wasteful, e.g. if there are 4 cell types and 1 special type, we could
            // fit them in 2 images, but we'll use 3 total (2 PNGs for the 4 cell type channels, and 1 for
            // the special types).
            // (m_OneHotSize + 2) / 3 is the number of 3-channel images needed, rounded up.
            var numCellImages = (m_OneHotSize + 2) / 3;
            for (var i = 0; i < numCellImages; i++)
            {
                m_TextureUtil.EncodeToTexture(
                    m_GridValues,
                    m_ObservationTexture,
                    3 * i,
                    currentBoardSize.Rows,
                    currentBoardSize.Columns
                );
                bytesOut.AddRange(m_ObservationTexture.EncodeToPNG());
            }

            return bytesOut.ToArray();
        }

        /// <inheritdoc/>
        public void Update()
        {
        }

        /// <inheritdoc/>
        public void Reset()
        {
        }

        /// <summary>
        /// Map the observation type to a compression type: PNG for CompressedVisual, None otherwise.
        /// </summary>
        /// <returns>The compression type used by this sensor.</returns>
        internal SensorCompressionType GetCompressionType()
        {
            return m_ObservationType == Match3ObservationType.CompressedVisual ?
                SensorCompressionType.PNG :
                SensorCompressionType.None;
        }

        /// <inheritdoc/>
        public CompressionSpec GetCompressionSpec()
        {
            return new CompressionSpec(GetCompressionType());
        }

        /// <inheritdoc/>
        public string GetName()
        {
            return m_Name;
        }

        /// <inheritdoc/>
        public BuiltInSensorType GetBuiltInSensorType()
        {
            return BuiltInSensorType.Match3Sensor;
        }

        /// <summary>
        /// Clean up the owned Texture2D.
        /// </summary>
        public void Dispose()
        {
            if (!ReferenceEquals(null, m_ObservationTexture))
            {
                Utilities.DestroyTexture(m_ObservationTexture);
                // Clear the reference so that a repeated Dispose() is a no-op.
                m_ObservationTexture = null;
            }
        }
    }

    /// <summary>
    /// Helper that renders the one-hot encoding of a grid of integer values into a texture,
    /// so that the texture can then be converted to a PNG for observations.
    /// Each call encodes three consecutive values (one per color channel) as pixels, so it should be
    /// called (maxValue + 2) / 3 times, increasing the channelOffset by 3 each time.
    /// </summary>
    internal class OneHotToTextureUtil
    {
        // Colors for the first, second and third value of the current three-value window.
        private static Color[] s_OneHotColors = { Color.red, Color.green, Color.blue };

        Color[] m_Colors;
        int m_MaxHeight;
        int m_MaxWidth;

        /// <summary>
        /// Allocate the pixel buffer for a texture of the given maximum dimensions.
        /// </summary>
        /// <param name="maxHeight">Number of pixel rows.</param>
        /// <param name="maxWidth">Number of pixel columns.</param>
        public OneHotToTextureUtil(int maxHeight, int maxWidth)
        {
            m_MaxHeight = maxHeight;
            m_MaxWidth = maxWidth;
            m_Colors = new Color[maxHeight * maxWidth];
        }

        /// <summary>
        /// Fill the texture with one pixel per cell. A cell is red, green or blue when its value is
        /// channelOffset, channelOffset + 1 or channelOffset + 2 respectively, and black otherwise.
        /// Cells outside the current board dimensions are black.
        /// </summary>
        /// <param name="gridValueProvider">Source of the per-cell values; called with (row, column).</param>
        /// <param name="texture">Texture that receives the pixels.</param>
        /// <param name="channelOffset">First value of the three-value window being encoded.</param>
        /// <param name="currentHeight">Number of rows currently in use.</param>
        /// <param name="currentWidth">Number of columns currently in use.</param>
        public void EncodeToTexture(
            GridValueProvider gridValueProvider,
            Texture2D texture,
            int channelOffset,
            int currentHeight,
            int currentWidth
        )
        {
            var pixelIndex = 0;
            // Rows are visited bottom-up: converting the texture to a PNG flips it implicitly,
            // and walking the rows in reverse cancels that flip out.
            for (var row = m_MaxHeight - 1; row >= 0; row--)
            {
                for (var col = 0; col < m_MaxWidth; col++)
                {
                    m_Colors[pixelIndex++] = PixelColor(
                        gridValueProvider, row, col, channelOffset, currentHeight, currentWidth);
                }
            }
            texture.SetPixels(m_Colors);
        }

        /// <summary>
        /// Pick the color for a single cell; black when the cell is out of bounds or its value
        /// falls outside the [channelOffset, channelOffset + 3) window.
        /// </summary>
        static Color PixelColor(
            GridValueProvider gridValueProvider,
            int row,
            int col,
            int channelOffset,
            int currentHeight,
            int currentWidth
        )
        {
            // Padding cells are never queried from the provider.
            if (row >= currentHeight || col >= currentWidth)
            {
                return Color.black;
            }

            int cellValue = gridValueProvider(row, col);
            if (cellValue < channelOffset || cellValue >= channelOffset + 3)
            {
                return Color.black;
            }

            return s_OneHotColors[cellValue - channelOffset];
        }
    }

    /// <summary>
    /// Extension methods on ObservationWriter for emitting one-hot encoded cells.
    /// </summary>
    internal static class ObservationWriterMatch3Extensions
    {
        /// <summary>
        /// Write the one-hot encoding of a value: 1.0 in the slot equal to the value, 0.0 in all others.
        /// </summary>
        /// <param name="writer">Destination writer.</param>
        /// <param name="offset">Index of the first slot; only used when isVisual is false.</param>
        /// <param name="row">Row index; only used when isVisual is true.</param>
        /// <param name="col">Column index; only used when isVisual is true.</param>
        /// <param name="value">Index of the slot to set to 1.0.</param>
        /// <param name="oneHotSize">Number of slots to write.</param>
        /// <param name="isVisual">
        /// If true, write to writer[row, col, channel]; otherwise write to consecutive flat indices from offset.
        /// </param>
        public static void WriteOneHot(this ObservationWriter writer, int offset, int row, int col, int value, int oneHotSize, bool isVisual)
        {
            for (var channel = 0; channel < oneHotSize; channel++)
            {
                var encoded = (channel == value) ? 1.0f : 0.0f;
                if (isVisual)
                {
                    writer[row, col, channel] = encoded;
                }
                else
                {
                    writer[offset + channel] = encoded;
                }
            }
        }

        /// <summary>
        /// Write zeros into every slot of a cell.
        /// </summary>
        /// <param name="writer">Destination writer.</param>
        /// <param name="offset">Index of the first slot; only used when isVisual is false.</param>
        /// <param name="row">Row index; only used when isVisual is true.</param>
        /// <param name="col">Column index; only used when isVisual is true.</param>
        /// <param name="oneHotSize">Number of slots to write.</param>
        /// <param name="isVisual">
        /// If true, write to writer[row, col, channel]; otherwise write to consecutive flat indices from offset.
        /// </param>
        public static void WriteZero(this ObservationWriter writer, int offset, int row, int col, int oneHotSize, bool isVisual)
        {
            for (var channel = 0; channel < oneHotSize; channel++)
            {
                if (isVisual)
                {
                    writer[row, col, channel] = 0.0f;
                }
                else
                {
                    writer[offset + channel] = 0.0f;
                }
            }
        }
    }
}