|
using System.Collections.Generic; |
|
using UnityEngine; |
|
|
|
namespace Unity.MLAgents.Extensions.Sensors
{
    /// <summary>
    /// Pose extractor over a hierarchy of Rigidbodies. The tree is built from the
    /// Rigidbodies found under a root, with parent/child edges taken from the Joints
    /// found under the same root (joint owner = child, <c>connectedBody</c> = parent).
    /// </summary>
    public class RigidBodyPoseExtractor : PoseExtractor
    {
        /// <summary>
        /// The tracked Rigidbodies, indexed like the poses. Null if construction bailed out
        /// early. When a virtual root is used, index 0 is null and the real root body is at index 1.
        /// </summary>
        readonly Rigidbody[] m_Bodies;

        /// <summary>
        /// Optional GameObject reported at index 0 in place of a Rigidbody; null when not used.
        /// </summary>
        readonly GameObject m_VirtualRoot;

        /// <summary>
        /// Build the extractor for the hierarchy rooted at <paramref name="rootBody"/>.
        /// </summary>
        /// <param name="rootBody">
        /// The root Rigidbody. It must be the first Rigidbody returned by
        /// <c>GetComponentsInChildren</c> on the searched object, otherwise a message is
        /// logged and the extractor is left empty. If null, the extractor is left empty.
        /// </param>
        /// <param name="rootGameObject">
        /// Optional object whose children are searched for Rigidbodies and Joints.
        /// Defaults to <paramref name="rootBody"/>.
        /// </param>
        /// <param name="virtualRoot">
        /// Optional GameObject inserted at index 0 as the parent of <paramref name="rootBody"/>.
        /// Its transform supplies the index-0 pose, and its linear velocity is reported as zero.
        /// </param>
        /// <param name="enableBodyPoses">
        /// Optional per-body enabled flags applied after setup. Bodies not in the hierarchy are ignored.
        /// </param>
        public RigidBodyPoseExtractor(Rigidbody rootBody, GameObject rootGameObject = null,
            GameObject virtualRoot = null, Dictionary<Rigidbody, bool> enableBodyPoses = null)
        {
            if (rootBody == null)
            {
                return;
            }

            Rigidbody[] rbs;
            Joint[] joints;
            if (rootGameObject == null)
            {
                rbs = rootBody.GetComponentsInChildren<Rigidbody>();
                joints = rootBody.GetComponentsInChildren<Joint>();
            }
            else
            {
                rbs = rootGameObject.GetComponentsInChildren<Rigidbody>();
                joints = rootGameObject.GetComponentsInChildren<Joint>();
            }

            if (rbs == null || rbs.Length == 0)
            {
                Debug.Log("No rigid bodies found!");
                return;
            }

            if (rbs[0] != rootBody)
            {
                Debug.Log("Expected root body at index 0");
                return;
            }

            if (virtualRoot != null)
            {
                // Shift every body up by one and leave slot 0 empty for the virtual root.
                var extendedRbs = new Rigidbody[rbs.Length + 1];
                rbs.CopyTo(extendedRbs, 1);
                rbs = extendedRbs;
            }

            var bodyToIndex = new Dictionary<Rigidbody, int>(rbs.Length);
            var parentIndices = new int[rbs.Length];
            // Index 0 is the root of the tree and has no parent.
            parentIndices[0] = -1;

            for (var i = 0; i < rbs.Length; i++)
            {
                // Slot 0 is null when a virtual root is used; Dictionary keys cannot be null.
                if (rbs[i] != null)
                {
                    bodyToIndex[rbs[i]] = i;
                }
            }

            foreach (var j in joints)
            {
                var parent = j.connectedBody;
                var child = j.GetComponent<Rigidbody>();

                // A joint anchored to the world has no connectedBody. Dictionary lookups
                // throw on null keys, so such joints contribute no edge.
                if (parent == null || child == null)
                {
                    continue;
                }

                // Joints connected to bodies outside the searched hierarchy contribute no edge.
                if (!bodyToIndex.TryGetValue(parent, out var parentIndex) ||
                    !bodyToIndex.TryGetValue(child, out var childIndex))
                {
                    continue;
                }

                // Never overwrite the root's -1 parent; that would break the tree.
                if (childIndex == 0)
                {
                    continue;
                }

                parentIndices[childIndex] = parentIndex;
            }

            if (virtualRoot != null)
            {
                // The real root body (now at index 1) hangs off the virtual root.
                parentIndices[1] = 0;
                m_VirtualRoot = virtualRoot;
            }

            m_Bodies = rbs;
            Setup(parentIndices);

            // The root pose is disabled by default; callers may re-enable it via enableBodyPoses.
            SetPoseEnabled(0, false);

            if (enableBodyPoses != null)
            {
                foreach (var pair in enableBodyPoses)
                {
                    var rb = pair.Key;
                    if (bodyToIndex.TryGetValue(rb, out var index))
                    {
                        SetPoseEnabled(index, pair.Value);
                    }
                }
            }
        }

        /// <summary>
        /// Linear velocity of the body at <paramref name="index"/>.
        /// The virtual root, if any, always reports <see cref="Vector3.zero"/>.
        /// </summary>
        protected internal override Vector3 GetLinearVelocityAt(int index)
        {
            if (index == 0 && m_VirtualRoot != null)
            {
                return Vector3.zero;
            }
            return m_Bodies[index].velocity;
        }

        /// <summary>
        /// World-space pose of the body at <paramref name="index"/>. For the virtual root,
        /// its transform's rotation and position are used.
        /// </summary>
        protected internal override Pose GetPoseAt(int index)
        {
            if (index == 0 && m_VirtualRoot != null)
            {
                return new Pose
                {
                    rotation = m_VirtualRoot.transform.rotation,
                    position = m_VirtualRoot.transform.position
                };
            }

            var body = m_Bodies[index];
            return new Pose { rotation = body.rotation, position = body.position };
        }

        /// <summary>
        /// The Unity object backing <paramref name="index"/>: the virtual root GameObject
        /// at index 0 when one is used, otherwise the Rigidbody.
        /// </summary>
        protected internal override Object GetObjectAt(int index)
        {
            if (index == 0 && m_VirtualRoot != null)
            {
                return m_VirtualRoot;
            }
            return m_Bodies[index];
        }

        /// <summary>
        /// The tracked Rigidbodies. Null if construction bailed out early.
        /// </summary>
        internal Rigidbody[] Bodies => m_Bodies;

        /// <summary>
        /// Map of every non-null tracked Rigidbody to whether its pose is enabled.
        /// Returns an empty dictionary when no bodies are tracked.
        /// </summary>
        internal Dictionary<Rigidbody, bool> GetBodyPosesEnabled()
        {
            // Match GetEnabledRigidbodies: the constructor may have left m_Bodies null.
            if (m_Bodies == null)
            {
                return new Dictionary<Rigidbody, bool>();
            }

            var bodyPosesEnabled = new Dictionary<Rigidbody, bool>(m_Bodies.Length);
            for (var i = 0; i < m_Bodies.Length; i++)
            {
                var rb = m_Bodies[i];
                if (rb == null)
                {
                    // Virtual-root slot.
                    continue;
                }

                bodyPosesEnabled[rb] = IsPoseEnabled(i);
            }

            return bodyPosesEnabled;
        }

        /// <summary>
        /// Lazily yields the tracked Rigidbodies whose poses are enabled, skipping the virtual root slot.
        /// </summary>
        internal IEnumerable<Rigidbody> GetEnabledRigidbodies()
        {
            if (m_Bodies == null)
            {
                yield break;
            }

            for (var i = 0; i < m_Bodies.Length; i++)
            {
                var rb = m_Bodies[i];
                if (rb == null)
                {
                    // Virtual-root slot.
                    continue;
                }

                if (IsPoseEnabled(i))
                {
                    yield return rb;
                }
            }
        }
    }
}
|
|