|
using System; |
|
using System.Collections.Generic; |
|
using UnityEngine; |
|
using Object = UnityEngine.Object; |
|
|
|
namespace Unity.MLAgents.Extensions.Sensors |
|
{ |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/// <summary>
/// Abstract class for managing the transforms of a hierarchy of objects.
/// Subclasses report per-object poses and linear velocities through <see cref="GetPoseAt"/> and
/// <see cref="GetLinearVelocityAt"/>, and describe the hierarchy by calling <see cref="Setup"/>.
/// This class converts those values into model space (relative to the pose at index 0) and
/// local space (relative to each pose's parent), and tracks which poses are enabled.
/// Until <see cref="Setup"/> has been called, the getters below return empty results.
/// </summary>
public abstract class PoseExtractor
{
    // Parent index of each pose; -1 marks a pose with no parent. Assigned by Setup (stored by reference).
    int[] m_ParentIndices;

    // Poses relative to the pose at index 0; refreshed by UpdateModelSpacePoses.
    Pose[] m_ModelSpacePoses;

    // Poses relative to each pose's parent; refreshed by UpdateLocalSpacePoses.
    Pose[] m_LocalSpacePoses;

    // Linear velocities relative to the pose at index 0, rotated into its frame.
    Vector3[] m_ModelSpaceLinearVelocities;

    // Linear velocities relative to each pose's parent, rotated into the parent's frame.
    Vector3[] m_LocalSpaceLinearVelocities;

    // Per-pose flag selecting which entries the GetEnabled* enumerators yield. All true after Setup.
    bool[] m_PoseEnabled;

    /// <summary>
    /// Enumerates the model-space poses whose enabled flag is set, in index order.
    /// Values reflect the most recent call to <see cref="UpdateModelSpacePoses"/>.
    /// Yields nothing if <see cref="Setup"/> has not been called.
    /// </summary>
    /// <returns>The enabled model-space poses.</returns>
    public IEnumerable<Pose> GetEnabledModelSpacePoses()
    {
        if (m_ModelSpacePoses == null)
        {
            yield break;
        }

        for (var i = 0; i < m_ModelSpacePoses.Length; i++)
        {
            if (m_PoseEnabled[i])
            {
                yield return m_ModelSpacePoses[i];
            }
        }
    }

    /// <summary>
    /// Enumerates the local-space poses whose enabled flag is set, in index order.
    /// Values reflect the most recent call to <see cref="UpdateLocalSpacePoses"/>.
    /// Yields nothing if <see cref="Setup"/> has not been called.
    /// </summary>
    /// <returns>The enabled local-space poses.</returns>
    public IEnumerable<Pose> GetEnabledLocalSpacePoses()
    {
        if (m_LocalSpacePoses == null)
        {
            yield break;
        }

        for (var i = 0; i < m_LocalSpacePoses.Length; i++)
        {
            if (m_PoseEnabled[i])
            {
                yield return m_LocalSpacePoses[i];
            }
        }
    }

    /// <summary>
    /// Enumerates the model-space linear velocities whose pose is enabled, in index order.
    /// Values reflect the most recent call to <see cref="UpdateModelSpacePoses"/>.
    /// Yields nothing if <see cref="Setup"/> has not been called.
    /// </summary>
    /// <returns>The enabled model-space linear velocities.</returns>
    public IEnumerable<Vector3> GetEnabledModelSpaceVelocities()
    {
        if (m_ModelSpaceLinearVelocities == null)
        {
            yield break;
        }

        for (var i = 0; i < m_ModelSpaceLinearVelocities.Length; i++)
        {
            if (m_PoseEnabled[i])
            {
                yield return m_ModelSpaceLinearVelocities[i];
            }
        }
    }

    /// <summary>
    /// Enumerates the local-space linear velocities whose pose is enabled, in index order.
    /// Values reflect the most recent call to <see cref="UpdateLocalSpacePoses"/>.
    /// Yields nothing if <see cref="Setup"/> has not been called.
    /// </summary>
    /// <returns>The enabled local-space linear velocities.</returns>
    public IEnumerable<Vector3> GetEnabledLocalSpaceVelocities()
    {
        if (m_LocalSpaceLinearVelocities == null)
        {
            yield break;
        }

        for (var i = 0; i < m_LocalSpaceLinearVelocities.Length; i++)
        {
            if (m_PoseEnabled[i])
            {
                yield return m_LocalSpaceLinearVelocities[i];
            }
        }
    }

    /// <summary>
    /// Number of poses whose enabled flag is set; 0 before <see cref="Setup"/> is called.
    /// Recomputed by scanning the flags on every access.
    /// </summary>
    public int NumEnabledPoses
    {
        get
        {
            if (m_PoseEnabled == null)
            {
                return 0;
            }

            var numEnabled = 0;
            foreach (var enabled in m_PoseEnabled)
            {
                if (enabled)
                {
                    numEnabled++;
                }
            }

            return numEnabled;
        }
    }

    /// <summary>
    /// Total number of poses (enabled or not); 0 before <see cref="Setup"/> is called.
    /// </summary>
    public int NumPoses => m_ModelSpacePoses?.Length ?? 0;

    /// <summary>
    /// Get the parent index of the pose at the given index.
    /// </summary>
    /// <param name="index">Index of the pose.</param>
    /// <returns>The parent's index, or -1 if the pose has no parent.</returns>
    /// <exception cref="NullReferenceException">If <see cref="Setup"/> has not been called.</exception>
    public int GetParentIndex(int index)
    {
        // NOTE(review): throwing NullReferenceException directly is unidiomatic
        // (InvalidOperationException would fit better); kept so existing callers are unaffected.
        if (m_ParentIndices == null)
        {
            throw new NullReferenceException("No parent indices set");
        }

        return m_ParentIndices[index];
    }

    /// <summary>
    /// Set whether the pose at the given index is included by the GetEnabled* enumerators
    /// and counted by <see cref="NumEnabledPoses"/>. Requires <see cref="Setup"/> to have been called.
    /// </summary>
    /// <param name="index">Index of the pose.</param>
    /// <param name="val">True to enable the pose, false to disable it.</param>
    public void SetPoseEnabled(int index, bool val)
    {
        m_PoseEnabled[index] = val;
    }

    /// <summary>
    /// Whether the pose at the given index is enabled. Requires <see cref="Setup"/> to have been called.
    /// </summary>
    /// <param name="index">Index of the pose.</param>
    /// <returns>The pose's enabled flag.</returns>
    public bool IsPoseEnabled(int index)
    {
        return m_PoseEnabled[index];
    }

    /// <summary>
    /// Initialize the internal buffers for a hierarchy described by its parent indices.
    /// Every pose starts out enabled. The array is stored by reference, not copied.
    /// </summary>
    /// <param name="parentIndices">
    /// Parent index of each pose. The pose at index 0 must be the root (parent -1);
    /// this is checked only in DEBUG builds. An empty array describes an empty hierarchy.
    /// </param>
    /// <exception cref="UnityAgentsException">DEBUG builds only: parentIndices[0] is not -1.</exception>
    protected void Setup(int[] parentIndices)
    {
#if DEBUG
        // Model space is defined relative to pose 0, so it must be the root.
        // Guard the length so an empty hierarchy behaves the same in DEBUG and release builds.
        if (parentIndices.Length > 0 && parentIndices[0] != -1)
        {
            throw new UnityAgentsException($"Expected parentIndices[0] to be -1, got {parentIndices[0]}");
        }
#endif
        m_ParentIndices = parentIndices;
        var numPoses = parentIndices.Length;

        m_ModelSpacePoses = new Pose[numPoses];
        m_LocalSpacePoses = new Pose[numPoses];

        m_ModelSpaceLinearVelocities = new Vector3[numPoses];
        m_LocalSpaceLinearVelocities = new Vector3[numPoses];

        m_PoseEnabled = new bool[numPoses];
        for (var i = 0; i < numPoses; i++)
        {
            m_PoseEnabled[i] = true;
        }
    }

    /// <summary>
    /// Return the pose of the object at the given index.
    /// <see cref="UpdateModelSpacePoses"/> treats this as a world-space pose.
    /// </summary>
    /// <param name="index">Index of the pose.</param>
    /// <returns>The object's pose.</returns>
    protected internal abstract Pose GetPoseAt(int index);

    /// <summary>
    /// Return the linear velocity of the object at the given index, in the same
    /// frame as the poses returned by <see cref="GetPoseAt"/>.
    /// </summary>
    /// <param name="index">Index of the pose.</param>
    /// <returns>The object's linear velocity.</returns>
    protected internal abstract Vector3 GetLinearVelocityAt(int index);

    /// <summary>
    /// Return the underlying object at the given index, used when building display nodes
    /// (see <see cref="GetDisplayNodes"/>). The default implementation returns null.
    /// </summary>
    /// <param name="index">Index of the pose.</param>
    /// <returns>The object for the pose, or null.</returns>
    protected internal virtual Object GetObjectAt(int index)
    {
        return null;
    }

    /// <summary>
    /// Recompute model-space poses and linear velocities. Model space is the frame of the pose
    /// at index 0. Velocities are taken relative to pose 0's velocity and rotated into its frame.
    /// Does nothing before <see cref="Setup"/> or for an empty hierarchy.
    /// </summary>
    public void UpdateModelSpacePoses()
    {
        using (TimerStack.Instance.Scoped("UpdateModelSpacePoses"))
        {
            // There is no root pose to read before Setup or when the hierarchy is empty.
            if (m_ModelSpacePoses == null || m_ModelSpacePoses.Length == 0)
            {
                return;
            }

            var rootWorldTransform = GetPoseAt(0);
            var worldToModel = rootWorldTransform.Inverse();
            var rootLinearVel = GetLinearVelocityAt(0);

            for (var i = 0; i < m_ModelSpacePoses.Length; i++)
            {
                var currentWorldSpacePose = GetPoseAt(i);
                m_ModelSpacePoses[i] = worldToModel.Multiply(currentWorldSpacePose);

                // Velocity relative to the root, expressed in the root's frame.
                var relativeVelocity = GetLinearVelocityAt(i) - rootLinearVel;
                m_ModelSpaceLinearVelocities[i] = worldToModel.rotation * relativeVelocity;
            }
        }
    }

    /// <summary>
    /// Recompute local-space poses and linear velocities, each relative to its parent pose.
    /// Poses with no parent (-1) get <see cref="Pose.identity"/> and zero velocity.
    /// Does nothing before <see cref="Setup"/> is called.
    /// </summary>
    public void UpdateLocalSpacePoses()
    {
        using (TimerStack.Instance.Scoped("UpdateLocalSpacePoses"))
        {
            if (m_LocalSpacePoses == null)
            {
                return;
            }

            for (var i = 0; i < m_LocalSpacePoses.Length; i++)
            {
                var parentIndex = m_ParentIndices[i];
                if (parentIndex != -1)
                {
                    var parentTransform = GetPoseAt(parentIndex);
                    // Invert the parent once; it is used for both the pose and the velocity.
                    var invParent = parentTransform.Inverse();
                    var currentTransform = GetPoseAt(i);
                    m_LocalSpacePoses[i] = invParent.Multiply(currentTransform);

                    var parentLinearVel = GetLinearVelocityAt(parentIndex);
                    var currentLinearVel = GetLinearVelocityAt(i);
                    m_LocalSpaceLinearVelocities[i] = invParent.rotation * (currentLinearVel - parentLinearVel);
                }
                else
                {
                    m_LocalSpacePoses[i] = Pose.identity;
                    m_LocalSpaceLinearVelocities[i] = Vector3.zero;
                }
            }
        }
    }

    /// <summary>
    /// Compute the number of observation floats produced for the enabled poses under the given settings.
    /// Each enabled translation or linear velocity contributes 3 values per pose, and each rotation
    /// contributes 4 (a quaternion).
    /// </summary>
    /// <param name="settings">Which observation types are enabled.</param>
    /// <returns><see cref="NumEnabledPoses"/> times the per-pose observation count.</returns>
    public int GetNumPoseObservations(PhysicsSensorSettings settings)
    {
        int obsPerPose = 0;
        obsPerPose += settings.UseModelSpaceTranslations ? 3 : 0;
        obsPerPose += settings.UseModelSpaceRotations ? 4 : 0;
        obsPerPose += settings.UseLocalSpaceTranslations ? 3 : 0;
        obsPerPose += settings.UseLocalSpaceRotations ? 4 : 0;

        obsPerPose += settings.UseModelSpaceLinearVelocity ? 3 : 0;
        obsPerPose += settings.UseLocalSpaceLinearVelocity ? 3 : 0;

        return NumEnabledPoses * obsPerPose;
    }

    /// <summary>
    /// Debug-draw the hierarchy in model space, offset by <paramref name="offset"/>.
    /// Refreshes both local and model spaces first. It draws a line from each non-root pose to its
    /// parent, plus three short axis lines per pose.
    /// Does nothing before <see cref="Setup"/> is called.
    /// </summary>
    /// <param name="offset">Translation added to every drawn point.</param>
    internal void DrawModelSpace(Vector3 offset)
    {
        UpdateLocalSpacePoses();
        UpdateModelSpacePoses();

        var pose = m_ModelSpacePoses;
        var localPose = m_LocalSpacePoses;

        // Nothing to draw before Setup has been called.
        if (pose == null)
        {
            return;
        }

        for (var i = 0; i < pose.Length; i++)
        {
            var current = pose[i];
            if (m_ParentIndices[i] == -1)
            {
                continue;
            }

            var parent = pose[m_ParentIndices[i]];
            Debug.DrawLine(current.position + offset, parent.position + offset, Color.cyan);

            // NOTE(review): axes are oriented by the local-space (parent-relative) rotation but drawn
            // at the model-space position; confirm this is the intended visualization.
            var localUp = localPose[i].rotation * Vector3.up;
            var localFwd = localPose[i].rotation * Vector3.forward;
            var localRight = localPose[i].rotation * Vector3.right;
            Debug.DrawLine(current.position + offset, current.position + offset + .1f * localUp, Color.red);
            Debug.DrawLine(current.position + offset, current.position + offset + .1f * localFwd, Color.green);
            Debug.DrawLine(current.position + offset, current.position + offset + .1f * localRight, Color.blue);
        }
    }

    /// <summary>
    /// One entry of the flattened hierarchy produced by <see cref="GetDisplayNodes"/>.
    /// </summary>
    internal struct DisplayNode
    {
        /// <summary>
        /// The object returned by <see cref="GetObjectAt"/> for this pose (may be null).
        /// </summary>
        public Object NodeObject;

        /// <summary>
        /// Whether the pose is enabled.
        /// </summary>
        public bool Enabled;

        /// <summary>
        /// Depth in the hierarchy; the root at index 0 has depth 0.
        /// </summary>
        public int Depth;

        /// <summary>
        /// Index of the pose in this extractor's pose arrays.
        /// </summary>
        public int OriginalIndex;
    }

    /// <summary>
    /// Flatten the hierarchy into a depth-first list starting at the root (index 0).
    /// Siblings appear in ascending index order. Poses that cannot be reached from index 0
    /// are not included.
    /// </summary>
    /// <returns>The display nodes, or an empty list if there are no poses.</returns>
    internal IList<DisplayNode> GetDisplayNodes()
    {
        if (NumPoses == 0)
        {
            return Array.Empty<DisplayNode>();
        }
        var nodesOut = new List<DisplayNode>(NumPoses);

        // Build a parent -> children map. Children are appended in ascending index order.
        var childrenByParent = new Dictionary<int, List<int>>();
        for (var i = 0; i < NumPoses; i++)
        {
            var parent = GetParentIndex(i);
            if (parent == -1)
            {
                // Parentless poses are not anyone's child; traversal starts from index 0.
                continue;
            }

            if (!childrenByParent.TryGetValue(parent, out var siblings))
            {
                siblings = new List<int>();
                childrenByParent[parent] = siblings;
            }
            siblings.Add(i);
        }

        // Iterative depth-first traversal from the root, carrying (index, depth).
        var stack = new Stack<(int, int)>();
        stack.Push((0, 0));

        while (stack.Count != 0)
        {
            var (current, depth) = stack.Pop();

            nodesOut.Add(new DisplayNode
            {
                NodeObject = GetObjectAt(current),
                Enabled = IsPoseEnabled(current),
                OriginalIndex = current,
                Depth = depth
            });

            if (childrenByParent.TryGetValue(current, out var children))
            {
                // Push in reverse so that children are popped (and listed) in ascending index order.
                for (var childIdx = children.Count - 1; childIdx >= 0; childIdx--)
                {
                    stack.Push((children[childIdx], depth + 1));
                }
            }

            // Safety valve: malformed parent indices (e.g. a cycle through the root) would otherwise
            // keep the stack from ever emptying.
            if (nodesOut.Count > NumPoses)
            {
                return nodesOut;
            }
        }

        return nodesOut;
    }
}
|
|
|
|
|
|
|
|
|
/// <summary>
/// Extension methods that let <see cref="Pose"/> be inverted and composed like a rigid transform matrix.
/// </summary>
public static class PoseExtensions
{
    /// <summary>
    /// Compute the inverse of a rigid pose. Composing the result with the original pose
    /// in either order (via <see cref="Multiply(Pose, Pose)"/>) gives the identity, up to floating-point error.
    /// </summary>
    /// <param name="pose">The pose to invert.</param>
    /// <returns>A pose with the inverse rotation and the correspondingly rotated, negated translation.</returns>
    public static Pose Inverse(this Pose pose)
    {
        var inverseRotation = Quaternion.Inverse(pose.rotation);
        // The inverse translation must undo the original one in the inverted rotation's frame.
        return new Pose
        {
            position = -(inverseRotation * pose.position),
            rotation = inverseRotation,
        };
    }

    /// <summary>
    /// Compose two poses, analogous to multiplying transform matrices <c>pose * rhs</c>:
    /// <paramref name="rhs"/> is transformed by <paramref name="pose"/>.
    /// </summary>
    /// <param name="pose">The outer (left-hand) pose.</param>
    /// <param name="rhs">The inner (right-hand) pose.</param>
    /// <returns>The composed pose.</returns>
    public static Pose Multiply(this Pose pose, Pose rhs) => rhs.GetTransformedBy(pose);

    /// <summary>
    /// Transform a point by a pose: rotate it by the pose's rotation, then translate by its position.
    /// </summary>
    /// <param name="pose">The transforming pose.</param>
    /// <param name="rhs">The point to transform.</param>
    /// <returns>The transformed point.</returns>
    public static Vector3 Multiply(this Pose pose, Vector3 rhs) => pose.position + pose.rotation * rhs;
}
|
} |
|
|