File size: 12,824 Bytes
05c9ac2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
using System;
using System.Collections.Generic;
using UnityEngine;
using UnityEngine.Profiling;

namespace Unity.MLAgents.Sensors
{
    /// <summary>
    /// Selects how a grid sensor handles the colliders detected inside a single cell.
    /// </summary>
    public enum ProcessCollidersMethod
    {
        /// <summary>
        /// Collect data from every collider detected in the cell.
        /// </summary>
        ProcessAllColliders = 0,

        /// <summary>
        /// Collect data only from the collider that is closest to the agent.
        /// </summary>
        ProcessClosestColliders = 1
    }

    /// <summary>
    /// Grid-based sensor.
    ///
    /// Holds one observation vector per cell of a 2D grid (x by z) in a flat buffer and exposes
    /// it either uncompressed through <see cref="Write"/> or as a sequence of PNG images through
    /// <see cref="GetCompressedObservation"/>. Derived classes customize what is recorded per
    /// cell by overriding <see cref="GetObjectData"/>, <see cref="GetCellObservationSize"/>,
    /// <see cref="IsDataNormalized"/> and <see cref="GetProcessCollidersMethod"/>.
    /// </summary>
    /// <remarks>
    /// The constructor calls the virtual members <see cref="GetCellObservationSize"/> and
    /// <see cref="IsDataNormalized"/> (the latter through the <see cref="CompressionType"/> setter).
    /// Those calls run before the body of any derived-class constructor, so overrides must not
    /// rely on state that is initialized there.
    /// </remarks>
    public class GridSensorBase : ISensor, IBuiltInSensor, IDisposable
    {
        // Number of observation channels packed into one RGB image when PNG-compressing.
        const int k_ChannelsPerImage = 3;

        string m_Name;
        Vector3 m_CellScale;
        Vector3Int m_GridSize;
        string[] m_DetectableTags;
        SensorCompressionType m_CompressionType;
        ObservationSpec m_ObservationSpec;
        internal IGridPerception m_GridPerception;

        // Buffers
        // Flat storage of all cell observations: m_CellObservationSize consecutive floats per cell.
        float[] m_PerceptionBuffer;
        // One color per cell, reused for every image produced by GetCompressedObservation.
        Color[] m_PerceptionColors;
        Texture2D m_PerceptionTexture;
        // Scratch buffer handed to GetObjectData for a single cell.
        float[] m_CellDataBuffer;

        // Utility Constants Calculated on Init
        int m_NumCells;
        int m_CellObservationSize;
        Vector3 m_CellCenterOffset;

        /// <summary>
        /// Create a GridSensorBase with the specified configuration.
        /// </summary>
        /// <param name="name">The sensor name</param>
        /// <param name="cellScale">The scale of each cell in the grid</param>
        /// <param name="gridSize">Number of cells on each side of the grid. The y component must be 1.</param>
        /// <param name="detectableTags">Tags to be detected by the sensor</param>
        /// <param name="compression">Compression type</param>
        /// <exception cref="UnityAgentsException">Thrown when <paramref name="gridSize"/>.y is not 1.</exception>
        public GridSensorBase(
            string name,
            Vector3 cellScale,
            Vector3Int gridSize,
            string[] detectableTags,
            SensorCompressionType compression
        )
        {
            m_Name = name;
            m_CellScale = cellScale;
            m_GridSize = gridSize;
            m_DetectableTags = detectableTags;
            // Goes through the property so that PNG is rejected for non-normalized data.
            CompressionType = compression;

            if (m_GridSize.y != 1)
            {
                throw new UnityAgentsException("GridSensor only supports 2D grids.");
            }

            m_NumCells = m_GridSize.x * m_GridSize.z;
            m_CellObservationSize = GetCellObservationSize();
            // NOTE(review): the spec receives x before z, while Write() uses z for the first
            // writer index and x for the second. Confirm these agree for non-square grids.
            m_ObservationSpec = ObservationSpec.Visual(m_GridSize.x, m_GridSize.z, m_CellObservationSize);
            m_PerceptionTexture = new Texture2D(m_GridSize.x, m_GridSize.z, TextureFormat.RGB24, false);

            ResetPerceptionBuffer();
        }

        /// <summary>
        /// The compression type used by the sensor.
        /// Setting <see cref="SensorCompressionType.PNG"/> while <see cref="IsDataNormalized"/>
        /// returns false logs a warning and leaves the current value unchanged.
        /// </summary>
        public SensorCompressionType CompressionType
        {
            get { return m_CompressionType; }
            set
            {
                if (!IsDataNormalized() && value == SensorCompressionType.PNG)
                {
                    Debug.LogWarning($"Compression type {value} is only supported with normalized data. " +
                        "The sensor will not compress the data.");
                    return;
                }
                m_CompressionType = value;
            }
        }

        internal float[] PerceptionBuffer
        {
            get { return m_PerceptionBuffer; }
        }

        /// <summary>
        /// The tags which the sensor detects.
        /// </summary>
        protected string[] DetectableTags
        {
            get { return m_DetectableTags; }
        }

        /// <inheritdoc/>
        public void Reset() { }

        /// <summary>
        /// Clears the perception buffer before loading in new data.
        /// The buffers are allocated on the first call and zeroed in place on later calls.
        /// </summary>
        public void ResetPerceptionBuffer()
        {
            if (m_PerceptionBuffer != null)
            {
                Array.Clear(m_PerceptionBuffer, 0, m_PerceptionBuffer.Length);
                Array.Clear(m_CellDataBuffer, 0, m_CellDataBuffer.Length);
            }
            else
            {
                m_PerceptionBuffer = new float[m_CellObservationSize * m_NumCells];
                m_CellDataBuffer = new float[m_CellObservationSize];
                m_PerceptionColors = new Color[m_NumCells];
            }
        }

        /// <inheritdoc/>
        public string GetName()
        {
            return m_Name;
        }

        /// <inheritdoc/>
        public CompressionSpec GetCompressionSpec()
        {
            return new CompressionSpec(CompressionType);
        }

        /// <inheritdoc/>
        public BuiltInSensorType GetBuiltInSensorType()
        {
            return BuiltInSensorType.GridSensor;
        }

        /// <summary>
        /// Encodes the perception buffer as one or more PNG images, three observation channels
        /// per image, and returns the encoded images concatenated in channel order.
        /// </summary>
        /// <returns>The concatenated PNG bytes.</returns>
        public byte[] GetCompressedObservation()
        {
            using (TimerStack.Instance.Scoped("GridSensor.GetCompressedObservation"))
            {
                var allBytes = new List<byte>();
                // Ceiling division: a partially filled last image still has to be emitted.
                var numImages = (m_CellObservationSize + k_ChannelsPerImage - 1) / k_ChannelsPerImage;
                for (int i = 0; i < numImages; i++)
                {
                    var channelIndex = k_ChannelsPerImage * i;
                    GridValuesToTexture(channelIndex, Math.Min(k_ChannelsPerImage, m_CellObservationSize - channelIndex));
                    allBytes.AddRange(m_PerceptionTexture.EncodeToPNG());
                }

                return allBytes.ToArray();
            }
        }

        /// <summary>
        /// Convert observation values to texture for PNG compression.
        /// </summary>
        /// <param name="channelIndex">Index of the first observation channel to copy for each cell.</param>
        /// <param name="numChannelsToAdd">Number of consecutive channels (1 to 3) to copy.</param>
        void GridValuesToTexture(int channelIndex, int numChannelsToAdd)
        {
            for (int i = 0; i < m_NumCells; i++)
            {
                var cellOffset = i * m_CellObservationSize + channelIndex;
                for (int j = 0; j < k_ChannelsPerImage; j++)
                {
                    // m_PerceptionColors is shared between images, so channels that this image
                    // does not use are zeroed instead of keeping the previous image's values.
                    m_PerceptionColors[i][j] = j < numChannelsToAdd ? m_PerceptionBuffer[cellOffset + j] : 0f;
                }
            }
            m_PerceptionTexture.SetPixels(m_PerceptionColors);
        }

        /// <summary>
        /// Get the observation values of the detected game object.
        /// The default implementation writes <c>tagIndex + 1</c> into the first element of the buffer.
        ///
        /// This method can be overridden to encode the observation differently or get custom data from the object.
        /// When overriding this method, <see cref="GetCellObservationSize"/> and <see cref="IsDataNormalized"/>
        /// might also need to change accordingly.
        /// </summary>
        /// <param name="detectedObject">The game object that was detected within a certain cell</param>
        /// <param name="tagIndex">The index of the detectedObject's tag in <see cref="DetectableTags"/></param>
        /// <param name="dataBuffer">The buffer to write the observation values.
        ///         The buffer size is configured by <see cref="GetCellObservationSize"/>.
        /// </param>
        /// <example>
        ///   Here is an example of overriding GetObjectData to get the velocity of a potential Rigidbody:
        ///   <code>
        ///     protected override void GetObjectData(GameObject detectedObject, int tagIndex, float[] dataBuffer)
        ///     {
        ///         if (tagIndex == Array.IndexOf(DetectableTags, "RigidBodyObject"))
        ///         {
        ///             Rigidbody rigidbody = detectedObject.GetComponent&lt;Rigidbody&gt;();
        ///             dataBuffer[0] = rigidbody.velocity.x;
        ///             dataBuffer[1] = rigidbody.velocity.y;
        ///             dataBuffer[2] = rigidbody.velocity.z;
        ///         }
        ///     }
        ///  </code>
        /// </example>
        protected virtual void GetObjectData(GameObject detectedObject, int tagIndex, float[] dataBuffer)
        {
            dataBuffer[0] = tagIndex + 1;
        }

        /// <summary>
        /// Get the observation size for each cell. This will be the size of dataBuffer for <see cref="GetObjectData"/>.
        /// If overriding <see cref="GetObjectData"/>, override this method as well to the custom observation size.
        /// </summary>
        /// <returns>The observation size of each cell.</returns>
        protected virtual int GetCellObservationSize()
        {
            return 1;
        }

        /// <summary>
        /// Whether the data is normalized within [0, 1]. The sensor can only use PNG compression if the data is normalized.
        /// If overriding <see cref="GetObjectData"/>, override this method as well according to the custom observation values.
        /// </summary>
        /// <returns>Bool value indicating whether data is normalized.</returns>
        protected virtual bool IsDataNormalized()
        {
            return false;
        }

        /// <summary>
        /// How the colliders detected in a cell are processed. The default is
        /// <see cref="ProcessCollidersMethod.ProcessClosestColliders"/>, i.e. only the collider closest to the agent is used.
        /// If overriding <see cref="GetObjectData"/>, consider overriding this method when needed.
        /// </summary>
        /// <returns>The method used to process the colliders detected in a cell.</returns>
        protected internal virtual ProcessCollidersMethod GetProcessCollidersMethod()
        {
            return ProcessCollidersMethod.ProcessClosestColliders;
        }

        /// <summary>
        /// If using PNG compression, check if the values are normalized.
        /// </summary>
        /// <param name="dataValues">The observation values written for one cell.</param>
        /// <param name="detectedObject">The object the values were read from; its name is used in the error message.</param>
        /// <exception cref="UnityAgentsException">
        /// Thrown when PNG compression is active and a value is NaN or outside [0, 1].
        /// </exception>
        void ValidateValues(float[] dataValues, GameObject detectedObject)
        {
            if (m_CompressionType != SensorCompressionType.PNG)
            {
                return;
            }

            for (int j = 0; j < dataValues.Length; j++)
            {
                var value = dataValues[j];
                // Negated in-range test: unlike "value < 0 || value > 1" this also rejects NaN.
                if (!(value >= 0f && value <= 1f))
                {
                    throw new UnityAgentsException($"When using compression type {m_CompressionType} the data value has to be normalized between 0-1. " +
                        $"Received value[{value}] for {detectedObject.name}");
                }
            }
        }

        /// <summary>
        /// Collect data from the detected object if a detectable tag is matched.
        /// Only the first matching tag is used. Nothing is written when the object is null
        /// or no tag matches.
        /// </summary>
        /// <param name="detectedObject">The object detected inside the cell.</param>
        /// <param name="cellIndex">Index of the cell whose slice of the perception buffer is updated.</param>
        internal void ProcessDetectedObject(GameObject detectedObject, int cellIndex)
        {
            Profiler.BeginSample("GridSensor.ProcessDetectedObject");
            try
            {
                // Neither condition depends on the tag being tested, so check them once up front.
                if (ReferenceEquals(detectedObject, null) || m_DetectableTags == null)
                {
                    return;
                }

                for (var i = 0; i < m_DetectableTags.Length; i++)
                {
                    if (!detectedObject.CompareTag(m_DetectableTags[i]))
                    {
                        continue;
                    }

                    var bufferOffset = cellIndex * m_CellObservationSize;
                    if (GetProcessCollidersMethod() == ProcessCollidersMethod.ProcessAllColliders)
                    {
                        // Start from what is already stored for this cell so GetObjectData can build on it.
                        Array.Copy(m_PerceptionBuffer, bufferOffset, m_CellDataBuffer, 0, m_CellObservationSize);
                    }
                    else
                    {
                        Array.Clear(m_CellDataBuffer, 0, m_CellDataBuffer.Length);
                    }

                    GetObjectData(detectedObject, i, m_CellDataBuffer);
                    ValidateValues(m_CellDataBuffer, detectedObject);
                    Array.Copy(m_CellDataBuffer, 0, m_PerceptionBuffer, bufferOffset, m_CellObservationSize);
                    break;
                }
            }
            finally
            {
                // GetObjectData and ValidateValues may throw; the sample must be closed regardless.
                Profiler.EndSample();
            }
        }

        /// <summary>
        /// Clears the perception buffer and, if a grid perception is attached, asks it to perceive.
        /// </summary>
        public void Update()
        {
            ResetPerceptionBuffer();
            using (TimerStack.Instance.Scoped("GridSensor.Update"))
            {
                if (m_GridPerception != null)
                {
                    // Presumably this reports hits back through ProcessDetectedObject - confirm in the IGridPerception implementation.
                    m_GridPerception.Perceive();
                }
            }
        }

        /// <inheritdoc/>
        public ObservationSpec GetObservationSpec()
        {
            return m_ObservationSpec;
        }

        /// <summary>
        /// Writes the perception buffer into the observation writer.
        /// The buffer is read sequentially while the first writer index runs from
        /// gridSize.z - 1 down to 0, so the first row of the buffer lands in the last row of the output.
        /// </summary>
        /// <param name="writer">Destination for the observation values.</param>
        /// <returns>The number of values written.</returns>
        public int Write(ObservationWriter writer)
        {
            using (TimerStack.Instance.Scoped("GridSensor.Write"))
            {
                int index = 0;
                for (var h = m_GridSize.z - 1; h >= 0; h--)
                {
                    for (var w = 0; w < m_GridSize.x; w++)
                    {
                        for (var d = 0; d < m_CellObservationSize; d++)
                        {
                            writer[h, w, d] = m_PerceptionBuffer[index];
                            index++;
                        }
                    }
                }
                return index;
            }
        }

        /// <summary>
        /// Clean up the internal objects.
        /// Destroys the texture used for PNG encoding; calling this more than once is harmless.
        /// </summary>
        public void Dispose()
        {
            if (!ReferenceEquals(null, m_PerceptionTexture))
            {
                Utilities.DestroyTexture(m_PerceptionTexture);
                m_PerceptionTexture = null;
            }
        }
    }
}