File size: 1,987 Bytes
05c9ac2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 |
using UnityEngine;
using Unity.MLAgents.Sensors;
namespace Unity.MLAgents.Extensions.Sensors
{
    /// <summary>
    /// <see cref="IJointExtractor"/> that reports the force and torque of the <see cref="Joint"/>
    /// component found on a <see cref="Rigidbody"/>.
    /// </summary>
    public class RigidBodyJointExtractor : IJointExtractor
    {
        // Number of values written when joint forces are enabled: 3 force + 3 torque components.
        // Shared by NumObservations and Write so the two cannot drift apart.
        const int k_NumForceTorqueObservations = 6;

        // Both are fixed at construction. m_Joint is null when the body is null/destroyed or when
        // GetComponent<Joint>() found no joint on it.
        readonly Rigidbody m_Body;
        readonly Joint m_Joint;

        /// <summary>
        /// Creates an extractor for <paramref name="body"/>, caching the result of
        /// <c>GetComponent&lt;Joint&gt;()</c> on it.
        /// </summary>
        /// <param name="body">The body whose joint is observed. May be null.</param>
        public RigidBodyJointExtractor(Rigidbody body)
        {
            m_Body = body;
            // Use Unity's overloaded != rather than the ?. operator: ?. bypasses UnityEngine.Object's
            // custom null check, so a destroyed Rigidbody would reach GetComponent and throw instead
            // of producing a null joint.
            m_Joint = body != null ? body.GetComponent<Joint>() : null;
        }

        /// <summary>
        /// Number of observations <see cref="Write"/> produces for this extractor's body and joint.
        /// </summary>
        /// <param name="settings">Sensor settings selecting which observations are enabled.</param>
        /// <returns>The observation count; 0 if the body or joint is missing.</returns>
        public int NumObservations(PhysicsSensorSettings settings)
        {
            return NumObservations(m_Body, m_Joint, settings);
        }

        /// <summary>
        /// Number of observations produced for a body/joint pair under the given settings.
        /// </summary>
        /// <param name="body">The observed body; may be null.</param>
        /// <param name="joint">The joint on that body; may be null.</param>
        /// <param name="settings">Sensor settings selecting which observations are enabled.</param>
        /// <returns>
        /// 0 if <paramref name="body"/> or <paramref name="joint"/> is null; otherwise 6 when
        /// <c>settings.UseJointForces</c> is set, else 0.
        /// </returns>
        public static int NumObservations(Rigidbody body, Joint joint, PhysicsSensorSettings settings)
        {
            if (body == null || joint == null)
            {
                return 0;
            }

            var numObservations = 0;
            if (settings.UseJointForces)
            {
                numObservations += k_NumForceTorqueObservations;
            }

            return numObservations;
        }

        /// <summary>
        /// Writes the enabled joint observations into <paramref name="writer"/> starting at
        /// <paramref name="offset"/>.
        /// </summary>
        /// <param name="settings">Sensor settings selecting which observations are enabled.</param>
        /// <param name="writer">Destination for the observation values.</param>
        /// <param name="offset">Index of the first value to write.</param>
        /// <returns>The number of values written (matches <see cref="NumObservations(PhysicsSensorSettings)"/>).</returns>
        public int Write(PhysicsSensorSettings settings, ObservationWriter writer, int offset)
        {
            if (m_Body == null || m_Joint == null)
            {
                return 0;
            }

            var currentOffset = offset;
            if (settings.UseJointForces)
            {
                // Read each vector once instead of querying the joint property per component.
                var force = m_Joint.currentForce;
                var torque = m_Joint.currentTorque;

                // Take tanh of the forces and torques to ensure they're in [-1, 1].
                writer[currentOffset++] = (float)System.Math.Tanh(force.x);
                writer[currentOffset++] = (float)System.Math.Tanh(force.y);
                writer[currentOffset++] = (float)System.Math.Tanh(force.z);
                writer[currentOffset++] = (float)System.Math.Tanh(torque.x);
                writer[currentOffset++] = (float)System.Math.Tanh(torque.y);
                writer[currentOffset++] = (float)System.Math.Tanh(torque.z);
            }

            return currentOffset - offset;
        }
    }
}
|