// ppo-Pyramids-Training/Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/JointDriveController.cs
using System.Collections.Generic; | |
using UnityEngine; | |
using UnityEngine.Serialization; | |
using Unity.MLAgents; | |
namespace Unity.MLAgentsExamples | |
{ | |
/// <summary> | |
/// Used to store relevant information for acting and learning for each body part in agent. | |
/// </summary> | |
[ | ]|
public class BodyPart | |
{ | |
[10)] public ConfigurableJoint joint; | ][Space(|
public Rigidbody rb; | |
[public Vector3 startingPos; | ]|
[public Quaternion startingRot; | ]|
[ | ]|
[ | ]|
public GroundContact groundContact; | |
public TargetContact targetContact; | |
[ | ]|
[public JointDriveController thisJdController; | ]|
[ | ]|
[ | ]|
public Vector3 currentEularJointRotation; | |
[public float currentStrength; | ]|
public float currentXNormalizedRot; | |
public float currentYNormalizedRot; | |
public float currentZNormalizedRot; | |
[ | ]|
[ | ]|
public Vector3 currentJointForce; | |
public float currentJointForceSqrMag; | |
public Vector3 currentJointTorque; | |
public float currentJointTorqueSqrMag; | |
public AnimationCurve jointForceCurve = new AnimationCurve(); | |
public AnimationCurve jointTorqueCurve = new AnimationCurve(); | |
/// <summary>
/// Restores a body part to its stored starting pose, zeroes its rigidbody velocities
/// and clears its ground/target contact flags.
/// </summary>
/// <param name="bp">
/// The body part to reset. All state is read from and written to this argument,
/// not the instance the method is invoked on.
/// </param>
public void Reset(BodyPart bp)
{
    var body = bp.rb;
    var bodyTransform = body.transform;

    // Put the rigidbody back where it was when the body part was set up.
    bodyTransform.position = bp.startingPos;
    bodyTransform.rotation = bp.startingRot;

    // Remove any residual motion.
    body.velocity = Vector3.zero;
    body.angularVelocity = Vector3.zero;

    // Contact components are optional; clear their flags only when present.
    var ground = bp.groundContact;
    if (ground)
    {
        ground.touchingGround = false;
    }

    var target = bp.targetContact;
    if (target)
    {
        target.touchingTarget = false;
    }
}
/// <summary>
/// Sets the joint's target rotation from three goal values, one per axis.
/// </summary>
/// <remarks>
/// Each argument is first transformed by <c>(v + 1) / 2</c> and then used as the
/// interpolation factor between that axis' angular limits: the low/high X limits for
/// <paramref name="x"/>, and the symmetric -limit/+limit range for <paramref name="y"/>
/// and <paramref name="z"/>. The resulting Euler angles and their normalized positions
/// within the limits are stored on this body part.
/// </remarks>
/// <param name="x">Goal value for the X axis.</param>
/// <param name="y">Goal value for the Y axis.</param>
/// <param name="z">Goal value for the Z axis.</param>
public void SetJointTargetRotation(float x, float y, float z)
{
    // Interpolation factors derived from the raw inputs.
    var tx = (x + 1f) * 0.5f;
    var ty = (y + 1f) * 0.5f;
    var tz = (z + 1f) * 0.5f;

    // Angular limits of the joint; Y and Z are symmetric around zero.
    var xLow = joint.lowAngularXLimit.limit;
    var xHigh = joint.highAngularXLimit.limit;
    var yMax = joint.angularYLimit.limit;
    var zMax = joint.angularZLimit.limit;

    var targetEuler = new Vector3(
        Mathf.Lerp(xLow, xHigh, tx),
        Mathf.Lerp(-yMax, yMax, ty),
        Mathf.Lerp(-zMax, zMax, tz));

    // Record where the target sits inside each limit range.
    currentXNormalizedRot = Mathf.InverseLerp(xLow, xHigh, targetEuler.x);
    currentYNormalizedRot = Mathf.InverseLerp(-yMax, yMax, targetEuler.y);
    currentZNormalizedRot = Mathf.InverseLerp(-zMax, zMax, targetEuler.z);

    joint.targetRotation = Quaternion.Euler(targetEuler);
    currentEularJointRotation = targetEuler;
}
/// <summary>
/// Replaces the joint's slerp drive with one whose maximum force is derived from
/// <paramref name="strength"/>, and records that force in <c>currentStrength</c>.
/// </summary>
/// <param name="strength">
/// Scaled as <c>(strength + 1) / 2 * maxJointForceLimit</c>, so -1 gives zero force and
/// +1 gives the controller's full force limit.
/// </param>
public void SetJointStrength(float strength)
{
    var controller = thisJdController;
    var forceLimit = (strength + 1f) * 0.5f * controller.maxJointForceLimit;

    // Spring and damper always come from the controller; only the force cap varies.
    var drive = new JointDrive();
    drive.positionSpring = controller.maxJointSpring;
    drive.positionDamper = controller.jointDampen;
    drive.maximumForce = forceLimit;

    joint.slerpDrive = drive;
    currentStrength = drive.maximumForce;
}
} | |
public class JointDriveController : MonoBehaviour | |
{ | |
[ | ]|
[ | ]|
public float maxJointSpring; | |
public float jointDampen; | |
public float maxJointForceLimit; | |
[public Dictionary<Transform, BodyPart> bodyPartsDict = new Dictionary<Transform, BodyPart>(); | ]|
[public List<BodyPart> bodyPartsList = new List<BodyPart>(); | ]|
const float k_MaxAngularVelocity = 50.0f; | |
/// <summary>
/// Creates a <see cref="BodyPart"/> for the given transform, configures it, and registers
/// it in both <c>bodyPartsDict</c> (keyed by the transform) and <c>bodyPartsList</c>.
/// </summary>
/// <remarks>
/// The transform's Rigidbody gets its maximum angular velocity capped, a GroundContact
/// component is reused or added and pointed at this object's Agent, and, when a
/// ConfigurableJoint is present, its slerp drive is initialized from this controller's
/// spring, damper and force-limit settings. <c>targetContact</c> is not assigned here.
/// </remarks>
/// <param name="t">Transform of the body part to register.</param>
public void SetupBodyPart(Transform t)
{
    var bp = new BodyPart
    {
        rb = t.GetComponent<Rigidbody>(),
        joint = t.GetComponent<ConfigurableJoint>(),
        startingPos = t.position,
        startingRot = t.rotation,
        thisJdController = this
    };
    bp.rb.maxAngularVelocity = k_MaxAngularVelocity;

    // Reuse an existing GroundContact or add one. Either way it reports to this
    // object's Agent, so the assignment is done once instead of in both branches.
    bp.groundContact = t.GetComponent<GroundContact>();
    if (!bp.groundContact)
    {
        bp.groundContact = t.gameObject.AddComponent<GroundContact>();
    }
    bp.groundContact.agent = gameObject.GetComponent<Agent>();

    // Body parts without a joint are still registered; they just get no drive.
    if (bp.joint)
    {
        bp.joint.slerpDrive = new JointDrive
        {
            positionSpring = maxJointSpring,
            positionDamper = jointDampen,
            maximumForce = maxJointForceLimit
        };
    }

    bodyPartsDict.Add(t, bp);
    bodyPartsList.Add(bp);
}
/// <summary>
/// Samples the current joint force and torque of every registered body part that has a
/// joint and stores them on the body part.
/// </summary>
/// <remarks>
/// When running in the editor, each sample is also appended to the body part's force and
/// torque curves at the current time; a curve holding more than 1000 keys is replaced
/// with an empty one before the new key is added.
/// NOTE(review): the <c>*SqrMag</c> fields are filled from <c>magnitude</c>, not
/// <c>sqrMagnitude</c> — confirm which is intended before relying on the field names.
/// </remarks>
public void GetCurrentJointForces()
{
    const int maxCurveKeys = 1000;
    var recordCurves = Application.isEditor;

    foreach (var part in bodyPartsDict.Values)
    {
        var partJoint = part.joint;
        if (!partJoint)
        {
            continue;
        }

        var force = partJoint.currentForce;
        var torque = partJoint.currentTorque;
        part.currentJointForce = force;
        part.currentJointForceSqrMag = force.magnitude;
        part.currentJointTorque = torque;
        part.currentJointTorqueSqrMag = torque.magnitude;

        if (!recordCurves)
        {
            continue;
        }

        // Start a fresh curve once the existing one has grown past the key budget.
        if (part.jointForceCurve.length > maxCurveKeys)
        {
            part.jointForceCurve = new AnimationCurve();
        }

        if (part.jointTorqueCurve.length > maxCurveKeys)
        {
            part.jointTorqueCurve = new AnimationCurve();
        }

        part.jointForceCurve.AddKey(Time.time, part.currentJointForceSqrMag);
        part.jointTorqueCurve.AddKey(Time.time, part.currentJointTorqueSqrMag);
    }
}
} | |
} | |