File size: 11,684 Bytes
05c9ac2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 |
using System.Collections.Generic;
using System.Linq;
using UnityEngine;
namespace Unity.MLAgents.Sensors
{
    /// <summary>
    /// A SensorComponent that creates one or more <see cref="GridSensorBase"/> sensors.
    /// All created sensors share a single grid perception (a BoxOverlapChecker) that
    /// detects colliders in each grid cell.
    /// </summary>
    [AddComponentMenu("ML Agents/Grid Sensor", (int)MenuGroup.Sensors)]
    public class GridSensorComponent : SensorComponent
    {
        // Dummy sensor only used for the debug gizmo; it is never returned from CreateSensors().
        GridSensorBase m_DebugSensor;

        // Sensors produced by the most recent successful CreateSensors() call; null before that.
        List<GridSensorBase> m_Sensors;

        // Shared perception that performs the physics queries for every registered sensor.
        internal IGridPerception m_GridPerception;

        [HideInInspector, SerializeField]
        protected internal string m_SensorName = "GridSensor";

        /// <summary>
        /// Name of the generated <see cref="GridSensorBase"/> object.
        /// Note that changing this at runtime does not affect how the Agent sorts the sensors.
        /// </summary>
        public string SensorName
        {
            get { return m_SensorName; }
            set { m_SensorName = value; }
        }

        [HideInInspector, SerializeField]
        internal Vector3 m_CellScale = new Vector3(1f, 0.01f, 1f);

        /// <summary>
        /// The scale of each grid cell.
        /// Note that changing this after the sensor is created has no effect.
        /// </summary>
        public Vector3 CellScale
        {
            get { return m_CellScale; }
            set { m_CellScale = value; }
        }

        [HideInInspector, SerializeField]
        internal Vector3Int m_GridSize = new Vector3Int(16, 1, 16);

        /// <summary>
        /// The number of grid cells along each axis. The grid is flat: the y component is
        /// always stored as 1, whatever value is assigned.
        /// Note that changing this after the sensor is created has no effect.
        /// </summary>
        public Vector3Int GridSize
        {
            get { return m_GridSize; }
            // Always forcing y to 1 is equivalent to only rewriting the value when y != 1.
            set { m_GridSize = new Vector3Int(value.x, 1, value.z); }
        }

        [HideInInspector, SerializeField]
        internal bool m_RotateWithAgent = true;

        /// <summary>
        /// Rotate the grid based on the direction the agent is facing.
        /// Changes made after the sensor is created are forwarded to the grid perception.
        /// </summary>
        public bool RotateWithAgent
        {
            get { return m_RotateWithAgent; }
            // UpdateSensor() already treats this field as runtime-safe; propagate code changes too.
            set { m_RotateWithAgent = value; UpdateSensor(); }
        }

        [HideInInspector, SerializeField]
        internal GameObject m_AgentGameObject;

        /// <summary>
        /// The reference of the root of the agent. This is used to disambiguate objects with
        /// the same tag as the agent. Defaults to current GameObject.
        /// </summary>
        public GameObject AgentGameObject
        {
            // Unity's overloaded == also treats destroyed objects as null, so keep == here.
            get { return (m_AgentGameObject == null ? gameObject : m_AgentGameObject); }
            set { m_AgentGameObject = value; }
        }

        [HideInInspector, SerializeField]
        internal string[] m_DetectableTags;

        /// <summary>
        /// List of tags that are detected.
        /// Note that changing this after the sensor is created has no effect.
        /// </summary>
        public string[] DetectableTags
        {
            get { return m_DetectableTags; }
            set { m_DetectableTags = value; }
        }

        [HideInInspector, SerializeField]
        internal LayerMask m_ColliderMask;

        /// <summary>
        /// The layer mask used when detecting colliders.
        /// Changes made after the sensor is created are forwarded to the grid perception.
        /// </summary>
        public LayerMask ColliderMask
        {
            get { return m_ColliderMask; }
            // UpdateSensor() already treats this field as runtime-safe; propagate code changes too.
            set { m_ColliderMask = value; UpdateSensor(); }
        }

        [HideInInspector, SerializeField]
        internal int m_MaxColliderBufferSize = 500;

        /// <summary>
        /// The absolute max size of the Collider buffer used in the non-allocating Physics calls. In other words
        /// the Collider buffer will never grow beyond this number even if there are more Colliders in the Grid Cell.
        /// Note that changing this after the sensor is created has no effect.
        /// </summary>
        public int MaxColliderBufferSize
        {
            get { return m_MaxColliderBufferSize; }
            set { m_MaxColliderBufferSize = value; }
        }

        [HideInInspector, SerializeField]
        internal int m_InitialColliderBufferSize = 4;

        /// <summary>
        /// The Estimated Max Number of Colliders to expect per cell. This number is used to
        /// pre-allocate an array of Colliders in order to take advantage of the OverlapBoxNonAlloc
        /// Physics API. If the number of colliders found is >= InitialColliderBufferSize the array
        /// will be resized to double its current size, but never beyond
        /// <see cref="MaxColliderBufferSize"/> (500 by default).
        /// Note that changing this after the sensor is created has no effect.
        /// </summary>
        public int InitialColliderBufferSize
        {
            get { return m_InitialColliderBufferSize; }
            set { m_InitialColliderBufferSize = value; }
        }

        [HideInInspector, SerializeField]
        internal Color[] m_DebugColors;

        /// <summary>
        /// Array of Colors used for the grid gizmos. A cell whose debug value is k (k >= 1)
        /// is drawn with DebugColors[k - 1]; cells with no matching entry are drawn white.
        /// May be null, in which case every cell is drawn white.
        /// </summary>
        public Color[] DebugColors
        {
            get { return m_DebugColors; }
            set { m_DebugColors = value; }
        }

        [HideInInspector, SerializeField]
        internal float m_GizmoYOffset = 0f;

        /// <summary>
        /// Vertical offset applied to every cube of the gizmo grid.
        /// </summary>
        public float GizmoYOffset
        {
            get { return m_GizmoYOffset; }
            set { m_GizmoYOffset = value; }
        }

        [HideInInspector, SerializeField]
        internal bool m_ShowGizmos = false;

        /// <summary>
        /// Whether to show gizmos or not.
        /// </summary>
        public bool ShowGizmos
        {
            get { return m_ShowGizmos; }
            set { m_ShowGizmos = value; }
        }

        [HideInInspector, SerializeField]
        internal SensorCompressionType m_CompressionType = SensorCompressionType.PNG;

        /// <summary>
        /// The compression type to use for the sensor.
        /// Changes made after the sensor is created are applied to every created grid sensor.
        /// </summary>
        public SensorCompressionType CompressionType
        {
            get { return m_CompressionType; }
            set { m_CompressionType = value; UpdateSensor(); }
        }

        [HideInInspector, SerializeField]
        [Range(1, 50)]
        [Tooltip("Number of frames of observations that will be stacked before being fed to the neural network.")]
        internal int m_ObservationStacks = 1;

        /// <summary>
        /// Whether to stack previous observations. Using 1 means no previous observations.
        /// Note that changing this after the sensor is created has no effect.
        /// </summary>
        public int ObservationStacks
        {
            get { return m_ObservationStacks; }
            set { m_ObservationStacks = value; }
        }

        /// <inheritdoc/>
        /// <exception cref="UnityAgentsException">
        /// Thrown when <see cref="GetGridSensors"/> returns null or an empty array.
        /// </exception>
        public override ISensor[] CreateSensors()
        {
            m_GridPerception = new BoxOverlapChecker(
                m_CellScale,
                m_GridSize,
                m_RotateWithAgent,
                m_ColliderMask,
                gameObject,
                AgentGameObject,
                m_DetectableTags,
                m_InitialColliderBufferSize,
                m_MaxColliderBufferSize
            );

            // Debug data is a positive int value and will trigger a data validation exception
            // if SensorCompressionType is not None.
            m_DebugSensor = new GridSensorBase("DebugGridSensor", m_CellScale, m_GridSize, m_DetectableTags, SensorCompressionType.None);
            m_GridPerception.RegisterDebugSensor(m_DebugSensor);

            // Validate before materializing: Enumerable.ToList() throws ArgumentNullException on a
            // null source, which would hide the descriptive error below from overriders.
            var gridSensors = GetGridSensors();
            if (gridSensors == null || gridSensors.Length < 1)
            {
                throw new UnityAgentsException("GridSensorComponent received no sensors. Specify at least one observation type (OneHot/Counting) to use grid sensors. " +
                    "If you're overriding GridSensorComponent.GetGridSensors(), return at least one grid sensor.");
            }
            m_Sensors = gridSensors.ToList();

            // Only one sensor needs to reference the perception, so that it gets updated exactly once.
            m_Sensors[0].m_GridPerception = m_GridPerception;
            foreach (var sensor in m_Sensors)
            {
                m_GridPerception.RegisterSensor(sensor);
            }

            if (ObservationStacks != 1)
            {
                // Wrap each grid sensor in a StackingSensor holding ObservationStacks frames.
                var stackedSensors = new ISensor[m_Sensors.Count];
                for (var i = 0; i < m_Sensors.Count; i++)
                {
                    stackedSensors[i] = new StackingSensor(m_Sensors[i], ObservationStacks);
                }
                return stackedSensors;
            }

            return m_Sensors.ToArray();
        }

        /// <summary>
        /// Get an array of GridSensors to be added in this component.
        /// Override this method and return custom GridSensor implementations.
        /// </summary>
        /// <returns>
        /// Array of grid sensors to be added to the component. Must contain at least one sensor;
        /// returning null or an empty array makes <see cref="CreateSensors"/> throw.
        /// </returns>
        protected virtual GridSensorBase[] GetGridSensors()
        {
            // Default: a single one-hot grid sensor named "<SensorName>-OneHot".
            return new GridSensorBase[]
            {
                new OneHotGridSensor(m_SensorName + "-OneHot", m_CellScale, m_GridSize, m_DetectableTags, m_CompressionType),
            };
        }

        /// <summary>
        /// Update fields that are safe to change on the Sensor at runtime: RotateWithAgent and
        /// ColliderMask on the shared grid perception, and CompressionType on every grid sensor.
        /// Does nothing if CreateSensors() has not produced sensors yet.
        /// </summary>
        internal void UpdateSensor()
        {
            if (m_Sensors != null)
            {
                m_GridPerception.RotateWithAgent = m_RotateWithAgent;
                m_GridPerception.ColliderMask = m_ColliderMask;
                foreach (var sensor in m_Sensors)
                {
                    sensor.CompressionType = m_CompressionType;
                }
            }
        }

        /// <summary>
        /// Draws one semi-transparent cube per grid cell, colored from <see cref="DebugColors"/>
        /// according to the debug sensor's perception buffer. Only runs when gizmos are enabled
        /// and the sensors have been created.
        /// </summary>
        void OnDrawGizmos()
        {
            if (!m_ShowGizmos || m_GridPerception == null || m_DebugSensor == null)
            {
                return;
            }

            // Refresh the debug buffer so the gizmo reflects the current perception state.
            m_DebugSensor.ResetPerceptionBuffer();
            m_GridPerception.UpdateGizmo();

            var cellColors = m_DebugSensor.PerceptionBuffer;
            var debugColors = m_DebugColors;
            var rotation = m_GridPerception.GetGridRotation();
            var gizmoYOffset = new Vector3(0, m_GizmoYOffset, 0);
            var oldGizmoMatrix = Gizmos.matrix;
            try
            {
                for (var i = 0; i < cellColors.Length; i++)
                {
                    // Draw a unit cube; the TRS matrix places, rotates and scales it into the cell.
                    var cellPosition = m_GridPerception.GetCellGlobalPosition(i);
                    var cubeTransform = Matrix4x4.TRS(cellPosition + gizmoYOffset, rotation, m_CellScale);
                    Gizmos.matrix = oldGizmoMatrix * cubeTransform;

                    // A buffer value k maps to debugColors[k - 1]; anything without an entry
                    // (including 0, or a null DebugColors array) is drawn white.
                    var colorIndex = cellColors[i] - 1;
                    var debugRayColor = Color.white;
                    if (debugColors != null && colorIndex > -1 && debugColors.Length > colorIndex)
                    {
                        debugRayColor = debugColors[(int)colorIndex];
                    }
                    Gizmos.color = new Color(debugRayColor.r, debugRayColor.g, debugRayColor.b, .5f);
                    Gizmos.DrawCube(Vector3.zero, Vector3.one);
                }
            }
            finally
            {
                // Restore the matrix even if drawing throws, so other gizmos are not affected.
                Gizmos.matrix = oldGizmoMatrix;
            }
        }
    }
}
|