|
using System; |
|
using System.Collections.Generic; |
|
using Unity.Collections; |
|
using Unity.Jobs; |
|
using UnityEngine; |
|
|
|
namespace Unity.MLAgents.Sensors |
|
{ |
|
|
|
|
|
|
|
/// <summary>
/// Selects which physics system a ray perception sensor casts against.
/// </summary>
public enum RayPerceptionCastType
{
    /// <summary>
    /// Rays lie in the local XY plane and are cast with the 2D physics API
    /// (Physics2D.Raycast / Physics2D.CircleCast).
    /// </summary>
    Cast2D = 0,

    /// <summary>
    /// Rays lie in the local XZ plane and are cast with the 3D physics API
    /// (Physics.Raycast / Physics.SphereCast, or their batched command equivalents).
    /// </summary>
    Cast3D = 1,
}
|
|
|
|
|
|
|
|
|
/// <summary>
/// Settings describing a fan of perception rays: how long they are, in which directions they
/// point, what they may hit, and which transform they are cast from.
/// </summary>
public struct RayPerceptionInput
{
    /// <summary>
    /// Length of every ray in the local space of <see cref="Transform"/>. The world-space length
    /// therefore also depends on that transform's scale.
    /// </summary>
    public float RayLength;

    /// <summary>
    /// Tags that hit objects are compared against. Each ray's observation reserves one slot per
    /// entry. A null list is treated as empty by <see cref="OutputSize"/>.
    /// </summary>
    public IReadOnlyList<string> DetectableTags;

    /// <summary>
    /// Direction of each ray in degrees, one ray per entry. Angles are measured from local +X
    /// toward local +Z (3D) or local +Y (2D), so 90 points along +Z / +Y.
    /// A null list is treated as zero rays by <see cref="OutputSize"/>.
    /// </summary>
    public IReadOnlyList<float> Angles;

    /// <summary>
    /// Local Y offset added to every ray's start point. Only applied for 3D casts.
    /// </summary>
    public float StartOffset;

    /// <summary>
    /// Local Y offset added to every ray's end point. Only applied for 3D casts.
    /// </summary>
    public float EndOffset;

    /// <summary>
    /// Radius of the sphere (3D) or circle (2D) swept along each ray, in local units.
    /// Zero or negative values make the sensor cast plain rays instead.
    /// </summary>
    public float CastRadius;

    /// <summary>
    /// Transform the rays are cast from. Local ray endpoints are converted to world space with
    /// its TransformPoint method.
    /// </summary>
    public Transform Transform;

    /// <summary>
    /// Whether the rays are cast with 2D or 3D physics.
    /// </summary>
    public RayPerceptionCastType CastType;

    /// <summary>
    /// Layer mask passed to the physics queries.
    /// </summary>
    public int LayerMask;

    /// <summary>
    /// Requests job-based batched casting. RayPerceptionSensor.Update only uses it for 3D casts.
    /// </summary>
    public bool UseBatchedRaycasts;

    /// <summary>
    /// Number of floats in one observation. Every ray contributes one slot per detectable tag,
    /// plus a "nothing hit" flag and a hit fraction.
    /// </summary>
    /// <returns>(tag count + 2) * ray count, with null lists counted as empty.</returns>
    public int OutputSize()
    {
        var tagCount = DetectableTags == null ? 0 : DetectableTags.Count;
        var rayCount = Angles == null ? 0 : Angles.Count;
        return (tagCount + 2) * rayCount;
    }

    /// <summary>
    /// Computes where a single ray starts and ends in world space.
    /// </summary>
    /// <param name="rayIndex">Index into <see cref="Angles"/> selecting the ray.</param>
    /// <returns>The world-space start and end points of the ray.</returns>
    public (Vector3 StartPositionWorld, Vector3 EndPositionWorld) RayExtents(int rayIndex)
    {
        var angleDegrees = Angles[rayIndex];

        Vector3 localStart;
        Vector3 localEnd;
        switch (CastType)
        {
            case RayPerceptionCastType.Cast3D:
            {
                // 3D rays lie in the local XZ plane; the offsets raise or lower each end along Y.
                localStart = new Vector3(0, StartOffset, 0);
                var planarEnd = PolarToCartesian3D(RayLength, angleDegrees);
                localEnd = new Vector3(planarEnd.x, planarEnd.y + EndOffset, planarEnd.z);
                break;
            }
            default:
            {
                // 2D rays start at the local origin and lie in the XY plane. Widening the Vector2 to a
                // Vector3 sets z to 0. The offsets are not used in 2D.
                localStart = Vector3.zero;
                localEnd = (Vector3)PolarToCartesian2D(RayLength, angleDegrees);
                break;
            }
        }

        return (Transform.TransformPoint(localStart), Transform.TransformPoint(localEnd));
    }

    /// <summary>
    /// Converts polar coordinates to a point in the XZ plane (y is always 0).
    /// </summary>
    /// <param name="radius">Distance from the origin.</param>
    /// <param name="angleDegrees">Angle from +X toward +Z, in degrees.</param>
    internal static Vector3 PolarToCartesian3D(float radius, float angleDegrees)
    {
        var radians = Mathf.Deg2Rad * angleDegrees;
        return new Vector3(radius * Mathf.Cos(radians), 0f, radius * Mathf.Sin(radians));
    }

    /// <summary>
    /// Converts polar coordinates to a point in the XY plane.
    /// </summary>
    /// <param name="radius">Distance from the origin.</param>
    /// <param name="angleDegrees">Angle from +X toward +Y, in degrees.</param>
    internal static Vector2 PolarToCartesian2D(float radius, float angleDegrees)
    {
        var radians = Mathf.Deg2Rad * angleDegrees;
        return new Vector2(radius * Mathf.Cos(radians), radius * Mathf.Sin(radians));
    }
}
|
|
|
|
|
|
|
|
|
/// <summary>
/// Per-ray results produced by a ray perception sensor.
/// </summary>
public class RayPerceptionOutput
{
    /// <summary>
    /// Result of casting a single ray.
    /// </summary>
    public struct RayOutput
    {
        /// <summary>
        /// Whether the cast hit any collider.
        /// </summary>
        public bool HasHit;

        /// <summary>
        /// Whether the hit object matched one of the input's detectable tags.
        /// </summary>
        public bool HitTaggedObject;

        /// <summary>
        /// Index into the input's detectable tags of the tag the hit object matched. Only
        /// meaningful when <see cref="HitTaggedObject"/> is true. The sensor stores -1 otherwise.
        /// </summary>
        public int HitTagIndex;

        /// <summary>
        /// Distance to the hit divided by the ray's world-space length. 1 when nothing was hit.
        /// </summary>
        public float HitFraction;

        /// <summary>
        /// GameObject owning the collider that was hit, or null when nothing was hit.
        /// </summary>
        public GameObject HitGameObject;

        /// <summary>
        /// World-space start point of the ray.
        /// </summary>
        public Vector3 StartPositionWorld;

        /// <summary>
        /// World-space end point of the ray.
        /// </summary>
        public Vector3 EndPositionWorld;

        /// <summary>
        /// World-space length of the ray (distance from <see cref="StartPositionWorld"/> to
        /// <see cref="EndPositionWorld"/>). Differs from the input's RayLength when the casting
        /// transform is scaled.
        /// </summary>
        public float ScaledRayLength => (EndPositionWorld - StartPositionWorld).magnitude;

        /// <summary>
        /// Cast radius scaled by the same factor as the ray length. Zero or less means a plain
        /// ray cast was used.
        /// </summary>
        public float ScaledCastRadius;

        /// <summary>
        /// Writes this ray's observation into <paramref name="buffer"/>. Each ray owns a
        /// contiguous slice of <c>numDetectableTags + 2</c> floats starting at
        /// <c>(numDetectableTags + 2) * rayIndex</c>:
        /// <list type="number">
        /// <item>slots [0, numDetectableTags): one-hot encoding of the matched tag. These slots are
        /// left untouched when no tag matched, so callers clear the buffer first;</item>
        /// <item>slot numDetectableTags: 1 if the ray hit nothing, 0 if it hit anything;</item>
        /// <item>slot numDetectableTags + 1: <see cref="HitFraction"/>.</item>
        /// </list>
        /// </summary>
        /// <param name="numDetectableTags">Number of detectable tags in the current input.</param>
        /// <param name="rayIndex">Index of this ray, which selects the slice that is written.</param>
        /// <param name="buffer">Destination array; must hold at least <c>rayIndex + 1</c> slices.</param>
        public void ToFloatArray(int numDetectableTags, int rayIndex, float[] buffer)
        {
            var bufferOffset = (numDetectableTags + 2) * rayIndex;

            // Only set the one-hot entry when the index falls inside this ray's tag slots. The tag
            // list can change after the cast was made. A stale index would otherwise overwrite the
            // flag or fraction, spill into a neighbouring ray's slice, or run past the buffer.
            if (HitTaggedObject && HitTagIndex >= 0 && HitTagIndex < numDetectableTags)
            {
                buffer[bufferOffset + HitTagIndex] = 1f;
            }

            buffer[bufferOffset + numDetectableTags] = HasHit ? 0f : 1f;
            buffer[bufferOffset + numDetectableTags + 1] = HitFraction;
        }
    }

    /// <summary>
    /// Result for each ray, in the same order as the input angles.
    /// </summary>
    public RayOutput[] RayOutputs;
}
|
|
|
|
|
|
|
|
|
/// <summary>
/// Sensor that casts a fan of rays (or sphere/circle casts) from a Transform. For each ray it
/// encodes which detectable tag was hit, whether anything was hit, and how far along the ray
/// the hit occurred.
/// </summary>
public class RayPerceptionSensor : ISensor, IBuiltInSensor
{
    float[] m_Observations;
    ObservationSpec m_ObservationSpec;
    string m_Name;

    RayPerceptionInput m_RayPerceptionInput;
    RayPerceptionOutput m_RayPerceptionOutput;

    // Copy of RayPerceptionInput.UseBatchedRaycasts, kept in sync by the constructor and
    // SetRayPerceptionInput.
    bool m_UseBatchedRaycasts;

    // Time.frameCount recorded at construction and on every Update().
    int m_DebugLastFrameCount;

    /// <summary>
    /// Frame number (Time.frameCount) at which the sensor was constructed or last updated.
    /// </summary>
    internal int DebugLastFrameCount
    {
        get { return m_DebugLastFrameCount; }
    }

    /// <summary>
    /// Creates a sensor with the given name and ray settings.
    /// </summary>
    /// <param name="name">Name returned by <see cref="GetName"/>.</param>
    /// <param name="rayInput">Ray settings; they also determine the observation size.</param>
    public RayPerceptionSensor(string name, RayPerceptionInput rayInput)
    {
        m_Name = name;
        m_RayPerceptionInput = rayInput;
        m_UseBatchedRaycasts = rayInput.UseBatchedRaycasts;

        SetNumObservations(rayInput.OutputSize());

        m_DebugLastFrameCount = Time.frameCount;
        m_RayPerceptionOutput = new RayPerceptionOutput();
    }

    /// <summary>
    /// Results of the most recent <see cref="Update"/>. Its RayOutputs array is null until
    /// Update has run.
    /// </summary>
    public RayPerceptionOutput RayPerceptionOutput
    {
        get { return m_RayPerceptionOutput; }
    }

    // (Re)allocates the observation buffer and spec for the given number of floats.
    void SetNumObservations(int numObservations)
    {
        m_ObservationSpec = ObservationSpec.Vector(numObservations);
        m_Observations = new float[numObservations];
    }

    /// <summary>
    /// Replaces the ray settings used by subsequent <see cref="Update"/> and <see cref="Write"/> calls.
    /// </summary>
    /// <param name="rayInput">New ray settings.</param>
    internal void SetRayPerceptionInput(RayPerceptionInput rayInput)
    {
        if (m_RayPerceptionInput.OutputSize() != rayInput.OutputSize())
        {
            Debug.Log(
                "Changing the number of tags or rays at runtime is not " +
                "supported and may cause errors in training or inference."
            );
            // Resize the buffer anyway so the sensor stays internally consistent, even though the
            // shape change may break consumers of the observation.
            SetNumObservations(rayInput.OutputSize());
        }
        m_RayPerceptionInput = rayInput;
        // Keep the cached batching flag in sync so toggling it at runtime takes effect.
        m_UseBatchedRaycasts = rayInput.UseBatchedRaycasts;
    }

    /// <summary>
    /// Encodes the results of the last <see cref="Update"/> into the observation buffer and
    /// appends the buffer to <paramref name="writer"/>.
    /// </summary>
    /// <param name="writer">Destination for the observation.</param>
    /// <returns>Number of floats written.</returns>
    public int Write(ObservationWriter writer)
    {
        using (TimerStack.Instance.Scoped("RayPerceptionSensor.Perceive"))
        {
            Array.Clear(m_Observations, 0, m_Observations.Length);
            var numDetectableTags = m_RayPerceptionInput.DetectableTags?.Count ?? 0;
            var rayOutputs = m_RayPerceptionOutput.RayOutputs;

            if (rayOutputs != null)
            {
                // The input may have changed since the last Update(). Only encode rays that have
                // both a result and a slice in the current buffer.
                var numRays = Math.Min(m_RayPerceptionInput.Angles?.Count ?? 0, rayOutputs.Length);
                for (var rayIndex = 0; rayIndex < numRays; rayIndex++)
                {
                    rayOutputs[rayIndex].ToFloatArray(numDetectableTags, rayIndex, m_Observations);
                }
            }

            writer.AddList(m_Observations);
        }
        return m_Observations.Length;
    }

    /// <summary>
    /// Casts every ray and stores the results in <see cref="RayPerceptionOutput"/>.
    /// </summary>
    public void Update()
    {
        m_DebugLastFrameCount = Time.frameCount;
        var numRays = m_RayPerceptionInput.Angles?.Count ?? 0;

        // Reuse the results array across frames unless the number of rays changed.
        if (m_RayPerceptionOutput.RayOutputs == null || m_RayPerceptionOutput.RayOutputs.Length != numRays)
        {
            m_RayPerceptionOutput.RayOutputs = new RayPerceptionOutput.RayOutput[numRays];
        }

        // Batched casting issues 3D physics commands, so 2D casts always go one ray at a time.
        if (m_UseBatchedRaycasts && m_RayPerceptionInput.CastType == RayPerceptionCastType.Cast3D)
        {
            PerceiveBatchedRays(ref m_RayPerceptionOutput.RayOutputs, m_RayPerceptionInput);
        }
        else
        {
            for (var rayIndex = 0; rayIndex < numRays; rayIndex++)
            {
                m_RayPerceptionOutput.RayOutputs[rayIndex] = PerceiveSingleRay(m_RayPerceptionInput, rayIndex);
            }
        }
    }

    /// <inheritdoc/>
    public void Reset() { }

    /// <inheritdoc/>
    public ObservationSpec GetObservationSpec()
    {
        return m_ObservationSpec;
    }

    /// <inheritdoc/>
    public string GetName()
    {
        return m_Name;
    }

    /// <summary>
    /// Ray observations are not compressed; always returns null.
    /// </summary>
    public virtual byte[] GetCompressedObservation()
    {
        return null;
    }

    /// <inheritdoc/>
    public CompressionSpec GetCompressionSpec()
    {
        return CompressionSpec.Default();
    }

    /// <inheritdoc/>
    public BuiltInSensorType GetBuiltInSensorType()
    {
        return BuiltInSensorType.RayPerceptionSensor;
    }

    /// <summary>
    /// Casts all rays described by <paramref name="input"/> and returns the results. Does not
    /// require a sensor instance.
    /// </summary>
    /// <param name="input">Ray settings.</param>
    /// <param name="batched">
    /// Requests batched casting. Only honored for 3D casts; 2D casts are always performed one ray
    /// at a time, matching <see cref="Update"/>.
    /// </param>
    /// <returns>Results for each ray, in the order of <c>input.Angles</c>.</returns>
    public static RayPerceptionOutput Perceive(RayPerceptionInput input, bool batched)
    {
        var numRays = input.Angles?.Count ?? 0;
        var output = new RayPerceptionOutput
        {
            RayOutputs = new RayPerceptionOutput.RayOutput[numRays]
        };

        // The batched path only issues 3D physics commands, which cannot hit 2D colliders.
        if (batched && input.CastType == RayPerceptionCastType.Cast3D)
        {
            PerceiveBatchedRays(ref output.RayOutputs, input);
        }
        else
        {
            for (var rayIndex = 0; rayIndex < numRays; rayIndex++)
            {
                output.RayOutputs[rayIndex] = PerceiveSingleRay(input, rayIndex);
            }
        }

        return output;
    }

    /// <summary>
    /// Casts all rays in <paramref name="input"/> as one batched 3D physics job (RaycastCommand or
    /// SpherecastCommand) and fills <paramref name="batchedRaycastOutputs"/> with the results.
    /// </summary>
    /// <param name="batchedRaycastOutputs">
    /// Destination array; must have at least as many elements as <c>input.Angles</c>.
    /// </param>
    /// <param name="input">Ray settings. Cast with 3D physics regardless of CastType.</param>
    internal static void PerceiveBatchedRays(ref RayPerceptionOutput.RayOutput[] batchedRaycastOutputs, RayPerceptionInput input)
    {
        var numRays = input.Angles?.Count ?? 0;
        if (numRays == 0)
        {
            return;
        }

        var unscaledRayLength = input.RayLength;
        var unscaledCastRadius = input.CastRadius;

        // A batch is either all raycasts or all spherecasts, so pick the command kind once and use
        // it for every ray. Choosing per ray from the *scaled* radius (which collapses to zero for
        // a zero-length ray) would write into the command array that was allocated empty.
        var useSpherecast = unscaledCastRadius > 0f;

        var queryParameters = QueryParameters.Default;
        queryParameters.layerMask = input.LayerMask;

        var results = default(NativeArray<RaycastHit>);
        var raycastCommands = default(NativeArray<RaycastCommand>);
        var spherecastCommands = default(NativeArray<SpherecastCommand>);
        try
        {
            results = new NativeArray<RaycastHit>(numRays, Allocator.TempJob);
            if (useSpherecast)
            {
                spherecastCommands = new NativeArray<SpherecastCommand>(numRays, Allocator.TempJob);
            }
            else
            {
                raycastCommands = new NativeArray<RaycastCommand>(numRays, Allocator.TempJob);
            }

            for (var i = 0; i < numRays; i++)
            {
                var extents = input.RayExtents(i);
                var startPositionWorld = extents.StartPositionWorld;
                var endPositionWorld = extents.EndPositionWorld;
                var rayDirection = endPositionWorld - startPositionWorld;

                // Measure the length in world space so the transform's scale is accounted for,
                // and scale the cast radius by the same factor.
                var scaledRayLength = rayDirection.magnitude;
                var scaledCastRadius = unscaledRayLength > 0 ?
                    unscaledCastRadius * scaledRayLength / unscaledRayLength :
                    unscaledCastRadius;

                var rayDirectionNormalized = rayDirection.normalized;
                if (useSpherecast)
                {
                    spherecastCommands[i] = new SpherecastCommand(startPositionWorld, scaledCastRadius, rayDirectionNormalized, queryParameters, scaledRayLength);
                }
                else
                {
                    raycastCommands[i] = new RaycastCommand(startPositionWorld, rayDirectionNormalized, queryParameters, scaledRayLength);
                }

                batchedRaycastOutputs[i] = new RayPerceptionOutput.RayOutput
                {
                    HitTaggedObject = false,
                    HitTagIndex = -1,
                    StartPositionWorld = startPositionWorld,
                    EndPositionWorld = endPositionWorld,
                    ScaledCastRadius = scaledCastRadius
                };
            }

            var handle = useSpherecast
                ? SpherecastCommand.ScheduleBatch(spherecastCommands, results, 1, 1, default(JobHandle))
                : RaycastCommand.ScheduleBatch(raycastCommands, results, 1, 1, default(JobHandle));
            handle.Complete();

            for (var i = 0; i < numRays; i++)
            {
                var hit = results[i];
                var castHit = hit.collider != null;
                var scaledRayLength = useSpherecast ? spherecastCommands[i].distance : raycastCommands[i].distance;

                var hitFraction = 1.0f;
                GameObject hitObject = null;
                if (castHit)
                {
                    hitFraction = scaledRayLength > 0 ? hit.distance / scaledRayLength : 0.0f;
                    hitObject = hit.collider.gameObject;

                    var tagIndex = FindDetectableTagIndex(hitObject, input.DetectableTags);
                    if (tagIndex >= 0)
                    {
                        batchedRaycastOutputs[i].HitTaggedObject = true;
                        batchedRaycastOutputs[i].HitTagIndex = tagIndex;
                    }
                }

                batchedRaycastOutputs[i].HasHit = castHit;
                batchedRaycastOutputs[i].HitFraction = hitFraction;
                batchedRaycastOutputs[i].HitGameObject = hitObject;
            }
        }
        finally
        {
            // Release the TempJob allocations even if building commands or reading results threw.
            if (results.IsCreated)
            {
                results.Dispose();
            }
            if (raycastCommands.IsCreated)
            {
                raycastCommands.Dispose();
            }
            if (spherecastCommands.IsCreated)
            {
                spherecastCommands.Dispose();
            }
        }
    }

    /// <summary>
    /// Casts a single ray (or sphere/circle cast when the scaled radius is positive) and
    /// returns its result.
    /// </summary>
    /// <param name="input">Ray settings.</param>
    /// <param name="rayIndex">Index into <c>input.Angles</c> selecting the ray.</param>
    /// <returns>The result for the ray.</returns>
    internal static RayPerceptionOutput.RayOutput PerceiveSingleRay(
        RayPerceptionInput input,
        int rayIndex
    )
    {
        var unscaledRayLength = input.RayLength;
        var unscaledCastRadius = input.CastRadius;

        var extents = input.RayExtents(rayIndex);
        var startPositionWorld = extents.StartPositionWorld;
        var endPositionWorld = extents.EndPositionWorld;
        var rayDirection = endPositionWorld - startPositionWorld;

        // Measure the length in world space so the transform's scale is accounted for, and scale
        // the cast radius by the same factor.
        var scaledRayLength = rayDirection.magnitude;
        var scaledCastRadius = unscaledRayLength > 0 ?
            unscaledCastRadius * scaledRayLength / unscaledRayLength :
            unscaledCastRadius;

        // These defaults remain if the matching physics module is not compiled in.
        var castHit = false;
        var hitFraction = 1.0f;
        GameObject hitObject = null;

        if (input.CastType == RayPerceptionCastType.Cast3D)
        {
#if MLA_UNITY_PHYSICS_MODULE
            RaycastHit rayHit;
            if (scaledCastRadius > 0f)
            {
                castHit = Physics.SphereCast(startPositionWorld, scaledCastRadius, rayDirection, out rayHit,
                    scaledRayLength, input.LayerMask);
            }
            else
            {
                castHit = Physics.Raycast(startPositionWorld, rayDirection, out rayHit,
                    scaledRayLength, input.LayerMask);
            }

            hitFraction = castHit ? (scaledRayLength > 0 ? rayHit.distance / scaledRayLength : 0.0f) : 1.0f;
            hitObject = castHit ? rayHit.collider.gameObject : null;
#endif
        }
        else
        {
#if MLA_UNITY_PHYSICS2D_MODULE
            RaycastHit2D rayHit;
            if (scaledCastRadius > 0f)
            {
                rayHit = Physics2D.CircleCast(startPositionWorld, scaledCastRadius, rayDirection,
                    scaledRayLength, input.LayerMask);
            }
            else
            {
                rayHit = Physics2D.Raycast(startPositionWorld, rayDirection, scaledRayLength, input.LayerMask);
            }

            castHit = rayHit;
            hitFraction = castHit ? rayHit.fraction : 1.0f;
            hitObject = castHit ? rayHit.collider.gameObject : null;
#endif
        }

        var rayOutput = new RayPerceptionOutput.RayOutput
        {
            HasHit = castHit,
            HitFraction = hitFraction,
            HitTaggedObject = false,
            HitTagIndex = -1,
            HitGameObject = hitObject,
            StartPositionWorld = startPositionWorld,
            EndPositionWorld = endPositionWorld,
            ScaledCastRadius = scaledCastRadius
        };

        if (castHit)
        {
            var tagIndex = FindDetectableTagIndex(hitObject, input.DetectableTags);
            if (tagIndex >= 0)
            {
                rayOutput.HitTaggedObject = true;
                rayOutput.HitTagIndex = tagIndex;
            }
        }

        return rayOutput;
    }

    /// <summary>
    /// Finds the first non-empty tag in <paramref name="detectableTags"/> that
    /// <paramref name="hitObject"/> has.
    /// </summary>
    /// <param name="hitObject">Object that was hit; must not be null.</param>
    /// <param name="detectableTags">Candidate tags; may be null.</param>
    /// <returns>Index of the first matching tag, or -1 if none match.</returns>
    static int FindDetectableTagIndex(GameObject hitObject, IReadOnlyList<string> detectableTags)
    {
        var numTags = detectableTags?.Count ?? 0;
        for (var i = 0; i < numTags; i++)
        {
            var tag = detectableTags[i];
            if (string.IsNullOrEmpty(tag))
            {
                continue;
            }

            try
            {
                if (hitObject.CompareTag(tag))
                {
                    return i;
                }
            }
            catch (UnityException)
            {
                // Deliberately best-effort: CompareTag presumably raises this for tags that are not
                // defined in the project. Treat such a tag as a non-match rather than failing the
                // whole observation.
            }
        }
        return -1;
    }
}
|
} |
|
|