File size: 12,824 Bytes
05c9ac2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 |
using System;
using System.Collections.Generic;
using UnityEngine;
using UnityEngine.Profiling;
namespace Unity.MLAgents.Sensors
{
/// <summary>
/// Selects how the GridSensor handles the colliders it detects inside a single grid cell.
/// </summary>
public enum ProcessCollidersMethod
{
    /// <summary>
    /// Collect data from every collider detected in a cell.
    /// </summary>
    ProcessAllColliders = 0,

    /// <summary>
    /// Collect data only from the collider closest to the agent.
    /// </summary>
    ProcessClosestColliders = 1,
}
/// <summary>
/// Grid-based sensor.
/// Each cell of a 2D grid (<c>gridSize.x</c> by <c>gridSize.z</c>) stores
/// <see cref="GetCellObservationSize"/> floats describing the tagged object detected in it.
/// The flat buffer is exposed as a visual observation built with
/// <c>ObservationSpec.Visual(gridSize.x, gridSize.z, cellObservationSize)</c>.
/// </summary>
/// <remarks>
/// Detection is delegated to the internal grid perception object, which is run from <see cref="Update"/>.
/// NOTE(review): that perception object presumably reports hits through <c>ProcessDetectedObject</c>;
/// confirm against the IGridPerception implementations.
/// </remarks>
public class GridSensorBase : ISensor, IBuiltInSensor, IDisposable
{
    // Configuration supplied at construction.
    string m_Name;
    Vector3 m_CellScale;
    Vector3Int m_GridSize;
    string[] m_DetectableTags;
    SensorCompressionType m_CompressionType;
    ObservationSpec m_ObservationSpec;
    internal IGridPerception m_GridPerception;

    // Buffers
    // Flat per-cell observations: cell i occupies [i * m_CellObservationSize, (i + 1) * m_CellObservationSize).
    float[] m_PerceptionBuffer;
    // One color per cell; used to build a single PNG image (up to 3 channels) at a time.
    Color[] m_PerceptionColors;
    Texture2D m_PerceptionTexture;
    // Scratch buffer for one cell, handed to GetObjectData.
    float[] m_CellDataBuffer;

    // Utility constants calculated on init
    int m_NumCells;
    int m_CellObservationSize;

    /// <summary>
    /// Create a GridSensorBase with the specified configuration.
    /// </summary>
    /// <param name="name">The sensor name</param>
    /// <param name="cellScale">The scale of each cell in the grid</param>
    /// <param name="gridSize">Number of cells on each side of the grid. <c>gridSize.y</c> must be 1.</param>
    /// <param name="detectableTags">Tags to be detected by the sensor</param>
    /// <param name="compression">Compression type</param>
    /// <exception cref="UnityAgentsException">Thrown when <c>gridSize.y</c> is not 1.</exception>
    public GridSensorBase(
        string name,
        Vector3 cellScale,
        Vector3Int gridSize,
        string[] detectableTags,
        SensorCompressionType compression
    )
    {
        // Validate first so an invalid grid fails before any side effect (such as the PNG warning
        // that the CompressionType setter may log).
        if (gridSize.y != 1)
        {
            throw new UnityAgentsException("GridSensor only supports 2D grids.");
        }

        m_Name = name;
        m_CellScale = cellScale;
        m_GridSize = gridSize;
        m_DetectableTags = detectableTags;
        // The setter calls the virtual IsDataNormalized(). Derived overrides therefore run
        // before the derived constructor body, so they may only rely on base-class state.
        CompressionType = compression;

        m_NumCells = m_GridSize.x * m_GridSize.z;
        // Virtual call: derived sensors define their own per-cell observation size.
        m_CellObservationSize = GetCellObservationSize();
        m_ObservationSpec = ObservationSpec.Visual(m_GridSize.x, m_GridSize.z, m_CellObservationSize);
        m_PerceptionTexture = new Texture2D(m_GridSize.x, m_GridSize.z, TextureFormat.RGB24, false);
        ResetPerceptionBuffer();
    }

    /// <summary>
    /// The compression type used by the sensor.
    /// Setting <see cref="SensorCompressionType.PNG"/> when <see cref="IsDataNormalized"/> returns false
    /// logs a warning and leaves the current value unchanged.
    /// </summary>
    public SensorCompressionType CompressionType
    {
        get { return m_CompressionType; }
        set
        {
            if (!IsDataNormalized() && value == SensorCompressionType.PNG)
            {
                Debug.LogWarning($"Compression type {value} is only supported with normalized data. " +
                    "The sensor will not compress the data.");
                return;
            }
            m_CompressionType = value;
        }
    }

    /// <summary>
    /// The flat per-cell observation buffer (exposed for internal consumers and tests).
    /// </summary>
    internal float[] PerceptionBuffer
    {
        get { return m_PerceptionBuffer; }
    }

    /// <summary>
    /// The tags which the sensor detects.
    /// </summary>
    protected string[] DetectableTags
    {
        get { return m_DetectableTags; }
    }

    /// <inheritdoc/>
    public void Reset() { }

    /// <summary>
    /// Clears the perception buffer before loading in new data.
    /// Allocates the buffers on the first call and zeroes them on later calls.
    /// </summary>
    public void ResetPerceptionBuffer()
    {
        if (m_PerceptionBuffer != null)
        {
            Array.Clear(m_PerceptionBuffer, 0, m_PerceptionBuffer.Length);
            Array.Clear(m_CellDataBuffer, 0, m_CellDataBuffer.Length);
        }
        else
        {
            m_PerceptionBuffer = new float[m_CellObservationSize * m_NumCells];
            m_CellDataBuffer = new float[m_CellObservationSize];
            m_PerceptionColors = new Color[m_NumCells];
        }
    }

    /// <inheritdoc/>
    public string GetName()
    {
        return m_Name;
    }

    /// <inheritdoc/>
    public CompressionSpec GetCompressionSpec()
    {
        return new CompressionSpec(CompressionType);
    }

    /// <inheritdoc/>
    public BuiltInSensorType GetBuiltInSensorType()
    {
        return BuiltInSensorType.GridSensor;
    }

    /// <inheritdoc/>
    /// <remarks>
    /// The buffer is encoded as consecutive PNG images of three channels each. The last image
    /// carries the remaining one or two channels, and its unused color channels are zero.
    /// </remarks>
    public byte[] GetCompressedObservation()
    {
        using (TimerStack.Instance.Scoped("GridSensor.GetCompressedObservation"))
        {
            var allBytes = new List<byte>();
            // Ceiling division: number of 3-channel images needed to hold every channel.
            var numImages = (m_CellObservationSize + 2) / 3;
            for (int i = 0; i < numImages; i++)
            {
                var channelIndex = 3 * i;
                GridValuesToTexture(channelIndex, Math.Min(3, m_CellObservationSize - channelIndex));
                allBytes.AddRange(m_PerceptionTexture.EncodeToPNG());
            }
            return allBytes.ToArray();
        }
    }

    /// <summary>
    /// Convert observation values to texture for PNG compression.
    /// Copies channels [channelIndex, channelIndex + numChannelsToAdd) of every cell into the RGB
    /// channels of that cell's color. Channels beyond numChannelsToAdd are zeroed so that values
    /// from a previously encoded image never leak into this one.
    /// </summary>
    /// <param name="channelIndex">First per-cell channel to copy.</param>
    /// <param name="numChannelsToAdd">Number of channels to copy (at most 3).</param>
    void GridValuesToTexture(int channelIndex, int numChannelsToAdd)
    {
        for (int i = 0; i < m_NumCells; i++)
        {
            var cellOffset = i * m_CellObservationSize + channelIndex;
            for (int j = 0; j < 3; j++)
            {
                m_PerceptionColors[i][j] = j < numChannelsToAdd ? m_PerceptionBuffer[cellOffset + j] : 0f;
            }
        }
        m_PerceptionTexture.SetPixels(m_PerceptionColors);
    }

    /// <summary>
    /// Get the observation values of the detected game object.
    /// Default is to record the detected tag index.
    ///
    /// This method can be overridden to encode the observation differently or get custom data from the object.
    /// When overriding this method, <seealso cref="GetCellObservationSize"/> and <seealso cref="IsDataNormalized"/>
    /// might also need to change accordingly.
    /// </summary>
    /// <param name="detectedObject">The game object that was detected within a certain cell</param>
    /// <param name="tagIndex">The index of the detectedObject's tag in the <see cref="DetectableTags"/> array</param>
    /// <param name="dataBuffer">The buffer to write the observation values.
    /// The buffer size is configured by <seealso cref="GetCellObservationSize"/>.
    /// </param>
    /// <example>
    /// Here is an example of overriding GetObjectData to get the velocity of a potential Rigidbody:
    /// <code>
    /// protected override void GetObjectData(GameObject detectedObject, int tagIndex, float[] dataBuffer)
    /// {
    ///     if (tagIndex == Array.IndexOf(DetectableTags, "RigidBodyObject"))
    ///     {
    ///         Rigidbody rigidbody = detectedObject.GetComponent&lt;Rigidbody&gt;();
    ///         dataBuffer[0] = rigidbody.velocity.x;
    ///         dataBuffer[1] = rigidbody.velocity.y;
    ///         dataBuffer[2] = rigidbody.velocity.z;
    ///     }
    /// }
    /// </code>
    /// </example>
    protected virtual void GetObjectData(GameObject detectedObject, int tagIndex, float[] dataBuffer)
    {
        // +1 so that 0 can mean "no detectable object in this cell".
        dataBuffer[0] = tagIndex + 1;
    }

    /// <summary>
    /// Get the observation size for each cell. This will be the size of dataBuffer for <seealso cref="GetObjectData"/>.
    /// If overriding <seealso cref="GetObjectData"/>, override this method as well to the custom observation size.
    /// </summary>
    /// <returns>The observation size of each cell.</returns>
    protected virtual int GetCellObservationSize()
    {
        return 1;
    }

    /// <summary>
    /// Whether the data is normalized within [0, 1]. The sensor can only use PNG compression if the data is normalized.
    /// If overriding <seealso cref="GetObjectData"/>, override this method as well according to the custom observation values.
    /// </summary>
    /// <returns>Bool value indicating whether data is normalized.</returns>
    protected virtual bool IsDataNormalized()
    {
        return false;
    }

    /// <summary>
    /// How to process multiple colliders detected in the same cell. Defaults to
    /// <see cref="ProcessCollidersMethod.ProcessClosestColliders"/>, which only uses the one closest to the agent.
    /// If overriding <seealso cref="GetObjectData"/>, consider overriding this method when needed.
    /// </summary>
    /// <returns>The <see cref="ProcessCollidersMethod"/> to use for each cell.</returns>
    protected internal virtual ProcessCollidersMethod GetProcessCollidersMethod()
    {
        return ProcessCollidersMethod.ProcessClosestColliders;
    }

    /// <summary>
    /// If using PNG compression, check if the values are normalized.
    /// </summary>
    /// <exception cref="UnityAgentsException">
    /// Thrown under PNG compression when any value is outside [0, 1] or is NaN.
    /// </exception>
    void ValidateValues(float[] dataValues, GameObject detectedObject)
    {
        if (m_CompressionType != SensorCompressionType.PNG)
        {
            return;
        }

        for (int j = 0; j < dataValues.Length; j++)
        {
            var value = dataValues[j];
            // Written as a negated range check so NaN (for which every comparison is false) is rejected too.
            if (!(value >= 0f && value <= 1f))
            {
                throw new UnityAgentsException($"When using compression type {m_CompressionType} the data value has to be normalized between 0-1. " +
                    $"Received value[{value}] for {detectedObject.name}");
            }
        }
    }

    /// <summary>
    /// Collect data from the detected object if a detectable tag is matched.
    /// Only the first matching tag (in DetectableTags order) is used. With
    /// <see cref="ProcessCollidersMethod.ProcessAllColliders"/>, the cell's current values are passed to
    /// GetObjectData so it can combine them with the new object. Otherwise the cell starts from zeros.
    /// </summary>
    /// <param name="detectedObject">The detected object; a null reference is ignored.</param>
    /// <param name="cellIndex">Index of the cell in the flat perception buffer.</param>
    internal void ProcessDetectedObject(GameObject detectedObject, int cellIndex)
    {
        // ReferenceEquals (not ==) keeps the original pure reference check and bypasses UnityEngine.Object's operator overload.
        if (ReferenceEquals(detectedObject, null))
        {
            return;
        }

        Profiler.BeginSample("GridSensor.ProcessDetectedObject");
        try
        {
            var cellOffset = cellIndex * m_CellObservationSize;
            for (var i = 0; i < m_DetectableTags.Length; i++)
            {
                if (!detectedObject.CompareTag(m_DetectableTags[i]))
                {
                    continue;
                }

                if (GetProcessCollidersMethod() == ProcessCollidersMethod.ProcessAllColliders)
                {
                    Array.Copy(m_PerceptionBuffer, cellOffset, m_CellDataBuffer, 0, m_CellObservationSize);
                }
                else
                {
                    Array.Clear(m_CellDataBuffer, 0, m_CellDataBuffer.Length);
                }

                GetObjectData(detectedObject, i, m_CellDataBuffer);
                ValidateValues(m_CellDataBuffer, detectedObject);
                Array.Copy(m_CellDataBuffer, 0, m_PerceptionBuffer, cellOffset, m_CellObservationSize);
                break;
            }
        }
        finally
        {
            // Always close the sample, even when GetObjectData or ValidateValues throws.
            Profiler.EndSample();
        }
    }

    /// <inheritdoc/>
    /// <remarks>
    /// Clears the perception buffer, then runs the grid perception (if one is assigned) to refill it.
    /// </remarks>
    public void Update()
    {
        ResetPerceptionBuffer();
        using (TimerStack.Instance.Scoped("GridSensor.Update"))
        {
            if (m_GridPerception != null)
            {
                m_GridPerception.Perceive();
            }
        }
    }

    /// <inheritdoc/>
    public ObservationSpec GetObservationSpec()
    {
        return m_ObservationSpec;
    }

    /// <inheritdoc/>
    /// <returns>The number of floats written (the full perception buffer length).</returns>
    public int Write(ObservationWriter writer)
    {
        using (TimerStack.Instance.Scoped("GridSensor.Write"))
        {
            // The buffer is read sequentially. Its first cell lands in the last row (h = gridSize.z - 1),
            // which mirrors the bottom-up row order that Texture2D.SetPixels uses for the PNG path.
            // NOTE(review): the spec is built as Visual(gridSize.x, gridSize.z, ...), but this loop uses
            // gridSize.z for the first writer index and gridSize.x for the second. Assuming Visual takes
            // (height, width, channels), the two agree only for square grids (x == z). Confirm the
            // behaviour for non-square grids against the perception's cell-index layout.
            int index = 0;
            for (var h = m_GridSize.z - 1; h >= 0; h--)
            {
                for (var w = 0; w < m_GridSize.x; w++)
                {
                    for (var d = 0; d < m_CellObservationSize; d++)
                    {
                        writer[h, w, d] = m_PerceptionBuffer[index];
                        index++;
                    }
                }
            }
            return index;
        }
    }

    /// <summary>
    /// Clean up the internal objects.
    /// Destroys the PNG staging texture; safe to call more than once.
    /// </summary>
    public void Dispose()
    {
        if (!ReferenceEquals(null, m_PerceptionTexture))
        {
            Utilities.DestroyTexture(m_PerceptionTexture);
            m_PerceptionTexture = null;
        }
    }
}
}
|