|
using System.Collections.Generic; |
|
using System.Linq; |
|
using UnityEngine; |
|
|
|
namespace Unity.MLAgents.Sensors
{
    /// <summary>
    /// Sensor component that builds one or more grid sensors (see <see cref="GetGridSensors"/>)
    /// on top of a shared <see cref="IGridPerception"/> created in <see cref="CreateSensors"/>.
    /// </summary>
    [AddComponentMenu("ML Agents/Grid Sensor", (int)MenuGroup.Sensors)]
    public class GridSensorComponent : SensorComponent
    {
        // Sensor registered with the perception purely to feed OnDrawGizmos.
        GridSensorBase m_DebugSensor;

        // Sensors returned by GetGridSensors() during the last CreateSensors() call.
        List<GridSensorBase> m_Sensors;

        // Perception shared by all sensors of this component; null until CreateSensors() runs.
        internal IGridPerception m_GridPerception;

        [HideInInspector, SerializeField]
        protected internal string m_SensorName = "GridSensor";

        /// <summary>
        /// Base name for the sensors created by this component. The default
        /// <see cref="GetGridSensors"/> appends "-OneHot" to it.
        /// Only read when the sensors are created.
        /// </summary>
        public string SensorName
        {
            get { return m_SensorName; }
            set { m_SensorName = value; }
        }

        [HideInInspector, SerializeField]
        internal Vector3 m_CellScale = new Vector3(1f, 0.01f, 1f);

        /// <summary>
        /// Scale of a single grid cell. It is handed to the perception and the sensors in
        /// <see cref="CreateSensors"/> and used as the cube scale when drawing gizmos.
        /// <see cref="UpdateSensor"/> does not propagate later changes to existing sensors.
        /// </summary>
        public Vector3 CellScale
        {
            get { return m_CellScale; }
            set { m_CellScale = value; }
        }

        [HideInInspector, SerializeField]
        internal Vector3Int m_GridSize = new Vector3Int(16, 1, 16);

        /// <summary>
        /// Number of cells along each axis. The y component is always stored as 1,
        /// whatever value is assigned.
        /// <see cref="UpdateSensor"/> does not propagate later changes to existing sensors.
        /// </summary>
        public Vector3Int GridSize
        {
            get { return m_GridSize; }
            // Rebuilding the vector with y = 1 covers both the "already 1" and "needs fixing" cases.
            set { m_GridSize = new Vector3Int(value.x, 1, value.z); }
        }

        [HideInInspector, SerializeField]
        internal bool m_RotateWithAgent = true;

        /// <summary>
        /// Forwarded to <see cref="IGridPerception.RotateWithAgent"/>; presumably makes the
        /// grid follow the agent's rotation (confirm against the perception implementation).
        /// Setting it after the sensors exist updates the live perception.
        /// </summary>
        public bool RotateWithAgent
        {
            get { return m_RotateWithAgent; }
            // UpdateSensor() is what pushes this value into the perception, so call it here
            // just like the CompressionType setter does; it is a no-op before CreateSensors().
            set { m_RotateWithAgent = value; UpdateSensor(); }
        }

        [HideInInspector, SerializeField]
        internal GameObject m_AgentGameObject;

        /// <summary>
        /// GameObject passed to the perception as the agent. When none has been assigned,
        /// the GameObject this component is attached to is returned instead.
        /// Only read when the sensors are created.
        /// </summary>
        public GameObject AgentGameObject
        {
            get { return (m_AgentGameObject == null ? gameObject : m_AgentGameObject); }
            set { m_AgentGameObject = value; }
        }

        [HideInInspector, SerializeField]
        internal string[] m_DetectableTags;

        /// <summary>
        /// Tags handed to the perception and to every sensor in <see cref="CreateSensors"/>.
        /// Only read when the sensors are created.
        /// </summary>
        public string[] DetectableTags
        {
            get { return m_DetectableTags; }
            set { m_DetectableTags = value; }
        }

        [HideInInspector, SerializeField]
        internal LayerMask m_ColliderMask;

        /// <summary>
        /// Layer mask forwarded to <see cref="IGridPerception.ColliderMask"/>.
        /// Setting it after the sensors exist updates the live perception.
        /// </summary>
        public LayerMask ColliderMask
        {
            get { return m_ColliderMask; }
            // Keep the live perception in sync; no-op before CreateSensors().
            set { m_ColliderMask = value; UpdateSensor(); }
        }

        [HideInInspector, SerializeField]
        internal int m_MaxColliderBufferSize = 500;

        /// <summary>
        /// Forwarded to the <see cref="BoxOverlapChecker"/> built in <see cref="CreateSensors"/>;
        /// presumably the upper bound of its collider buffer (confirm against BoxOverlapChecker).
        /// Only read when the sensors are created.
        /// </summary>
        public int MaxColliderBufferSize
        {
            get { return m_MaxColliderBufferSize; }
            set { m_MaxColliderBufferSize = value; }
        }

        [HideInInspector, SerializeField]
        internal int m_InitialColliderBufferSize = 4;

        /// <summary>
        /// Forwarded to the <see cref="BoxOverlapChecker"/> built in <see cref="CreateSensors"/>;
        /// presumably the starting size of its collider buffer (confirm against BoxOverlapChecker).
        /// Only read when the sensors are created.
        /// </summary>
        public int InitialColliderBufferSize
        {
            get { return m_InitialColliderBufferSize; }
            set { m_InitialColliderBufferSize = value; }
        }

        [HideInInspector, SerializeField]
        internal Color[] m_DebugColors;

        /// <summary>
        /// Colors used by the gizmo drawing. A cell whose debug value is <c>v</c> is drawn with
        /// <c>DebugColors[v - 1]</c>; cells without a matching entry are drawn white.
        /// </summary>
        public Color[] DebugColors
        {
            get { return m_DebugColors; }
            set { m_DebugColors = value; }
        }

        [HideInInspector, SerializeField]
        internal float m_GizmoYOffset = 0f;

        /// <summary>
        /// Vertical offset added to every cell position when drawing gizmos.
        /// </summary>
        public float GizmoYOffset
        {
            get { return m_GizmoYOffset; }
            set { m_GizmoYOffset = value; }
        }

        [HideInInspector, SerializeField]
        internal bool m_ShowGizmos = false;

        /// <summary>
        /// Whether the grid cells are drawn as gizmos.
        /// </summary>
        public bool ShowGizmos
        {
            get { return m_ShowGizmos; }
            set { m_ShowGizmos = value; }
        }

        [HideInInspector, SerializeField]
        internal SensorCompressionType m_CompressionType = SensorCompressionType.PNG;

        /// <summary>
        /// Compression type applied to the sensors. Setting it after the sensors exist
        /// updates every created sensor.
        /// </summary>
        public SensorCompressionType CompressionType
        {
            get { return m_CompressionType; }
            set { m_CompressionType = value; UpdateSensor(); }
        }

        [HideInInspector, SerializeField]
        [Range(1, 50)]
        [Tooltip("Number of frames of observations that will be stacked before being fed to the neural network.")]
        internal int m_ObservationStacks = 1;

        /// <summary>
        /// Number of observation frames to stack. Any value other than 1 makes
        /// <see cref="CreateSensors"/> wrap each sensor in a <see cref="StackingSensor"/>.
        /// Only read when the sensors are created.
        /// </summary>
        public int ObservationStacks
        {
            get { return m_ObservationStacks; }
            set { m_ObservationStacks = value; }
        }

        /// <summary>
        /// Creates the shared perception, the debug sensor and the sensors returned by
        /// <see cref="GetGridSensors"/>, and registers all of them with the perception.
        /// </summary>
        /// <returns>
        /// The created sensors, each wrapped in a <see cref="StackingSensor"/> when
        /// <see cref="ObservationStacks"/> is not 1.
        /// </returns>
        /// <exception cref="UnityAgentsException">
        /// Thrown when <see cref="GetGridSensors"/> returns null or an empty array.
        /// </exception>
        public override ISensor[] CreateSensors()
        {
            m_GridPerception = new BoxOverlapChecker(
                m_CellScale,
                m_GridSize,
                m_RotateWithAgent,
                m_ColliderMask,
                gameObject,
                AgentGameObject,
                m_DetectableTags,
                m_InitialColliderBufferSize,
                m_MaxColliderBufferSize
            );

            // Uncompressed sensor whose buffer is only read by OnDrawGizmos().
            m_DebugSensor = new GridSensorBase("DebugGridSensor", m_CellScale, m_GridSize, m_DetectableTags, SensorCompressionType.None);
            m_GridPerception.RegisterDebugSensor(m_DebugSensor);

            // Check for null BEFORE calling ToList(): an override returning null must produce
            // the descriptive UnityAgentsException below, not an ArgumentNullException from LINQ.
            var gridSensors = GetGridSensors();
            m_Sensors = gridSensors != null ? gridSensors.ToList() : new List<GridSensorBase>();
            if (m_Sensors.Count < 1)
            {
                throw new UnityAgentsException("GridSensorComponent received no sensors. Specify at least one observation type (OneHot/Counting) to use grid sensors. " +
                    "If you're overriding GridSensorComponent.GetGridSensors(), return at least one grid sensor.");
            }

            // Only the first sensor receives a reference to the perception; every sensor is
            // still registered with it. NOTE(review): presumably so the perception is driven
            // once per step rather than once per sensor - confirm in GridSensorBase.
            m_Sensors[0].m_GridPerception = m_GridPerception;
            foreach (var sensor in m_Sensors)
            {
                m_GridPerception.RegisterSensor(sensor);
            }

            if (ObservationStacks == 1)
            {
                return m_Sensors.ToArray();
            }

            var stackedSensors = new ISensor[m_Sensors.Count];
            for (var i = 0; i < m_Sensors.Count; i++)
            {
                stackedSensors[i] = new StackingSensor(m_Sensors[i], ObservationStacks);
            }
            return stackedSensors;
        }

        /// <summary>
        /// Returns the grid sensors this component should use. The default implementation
        /// returns a single <see cref="OneHotGridSensor"/> named "&lt;SensorName&gt;-OneHot".
        /// Overrides must return at least one sensor, otherwise <see cref="CreateSensors"/> throws.
        /// </summary>
        /// <returns>Array of grid sensors to register with the perception.</returns>
        protected virtual GridSensorBase[] GetGridSensors()
        {
            var sensor = new OneHotGridSensor(m_SensorName + "-OneHot", m_CellScale, m_GridSize, m_DetectableTags, m_CompressionType);
            return new GridSensorBase[] { sensor };
        }

        /// <summary>
        /// Pushes the settings that can change after creation (rotate-with-agent, collider
        /// mask and compression type) into the existing perception and sensors.
        /// Does nothing until <see cref="CreateSensors"/> has run.
        /// </summary>
        internal void UpdateSensor()
        {
            if (m_Sensors == null || m_GridPerception == null)
            {
                return;
            }

            m_GridPerception.RotateWithAgent = m_RotateWithAgent;
            m_GridPerception.ColliderMask = m_ColliderMask;
            foreach (var sensor in m_Sensors)
            {
                sensor.CompressionType = m_CompressionType;
            }
        }

        /// <summary>
        /// Draws one semi-transparent cube per grid cell, colored from <see cref="DebugColors"/>
        /// according to the debug sensor's buffer. Does nothing when gizmos are disabled or
        /// the sensors have not been created yet.
        /// </summary>
        void OnDrawGizmos()
        {
            if (!m_ShowGizmos || m_GridPerception == null || m_DebugSensor == null)
            {
                return;
            }

            // Refresh the debug buffer before reading it.
            m_DebugSensor.ResetPerceptionBuffer();
            m_GridPerception.UpdateGizmo();
            var cellColors = m_DebugSensor.PerceptionBuffer;
            var rotation = m_GridPerception.GetGridRotation();

            var scale = m_CellScale;
            var gizmoYOffset = new Vector3(0, m_GizmoYOffset, 0);
            var oldGizmoMatrix = Gizmos.matrix;
            for (var i = 0; i < cellColors.Length; i++)
            {
                var cellPosition = m_GridPerception.GetCellGlobalPosition(i);
                var cubeTransform = Matrix4x4.TRS(cellPosition + gizmoYOffset, rotation, scale);
                Gizmos.matrix = oldGizmoMatrix * cubeTransform;

                // Buffer values are 1-based color indices; anything that does not map to an
                // entry of m_DebugColors (including an unassigned array) falls back to white.
                var colorIndex = cellColors[i] - 1;
                var debugRayColor = Color.white;
                if (m_DebugColors != null && colorIndex > -1 && m_DebugColors.Length > colorIndex)
                {
                    debugRayColor = m_DebugColors[(int)colorIndex];
                }
                Gizmos.color = new Color(debugRayColor.r, debugRayColor.g, debugRayColor.b, .5f);
                Gizmos.DrawCube(Vector3.zero, Vector3.one);
            }

            Gizmos.matrix = oldGizmoMatrix;
        }
    }
}
|
|