File size: 5,121 Bytes
05c9ac2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 |
#if UNITY_2020_1_OR_NEWER
using System.Collections.Generic;
using UnityEngine;
using Unity.MLAgents.Sensors;
namespace Unity.MLAgents.Extensions.Sensors
{
/// <summary>
/// <see cref="IJointExtractor"/> implementation that reads joint observations
/// (joint positions/angles and joint forces) from an <see cref="ArticulationBody"/>.
/// </summary>
public class ArticulationBodyJointExtractor : IJointExtractor
{
    // Only assigned in the constructor.
    readonly ArticulationBody m_Body;

    /// <summary>
    /// Create an extractor for the joint of the given body.
    /// </summary>
    /// <param name="body">
    /// The body whose joint is observed. If null (or the root of the articulation),
    /// no observations are produced.
    /// </param>
    public ArticulationBodyJointExtractor(ArticulationBody body)
    {
        m_Body = body;
    }

    /// <summary>
    /// Number of observation values <see cref="Write"/> produces for this extractor's body.
    /// </summary>
    /// <param name="settings">Which kinds of joint observations are enabled.</param>
    /// <returns>The number of floats that will be written.</returns>
    public int NumObservations(PhysicsSensorSettings settings)
    {
        return NumObservations(m_Body, settings);
    }

    /// <summary>
    /// Number of observation values produced for the joint of <paramref name="body"/>.
    /// </summary>
    /// <param name="body">The body to inspect; null or a root body yields 0.</param>
    /// <param name="settings">Which kinds of joint observations are enabled.</param>
    /// <returns>The number of floats an extractor for this body will write.</returns>
    public static int NumObservations(ArticulationBody body, PhysicsSensorSettings settings)
    {
        // A root body has no parent joint, so there is nothing to observe.
        if (body == null || body.isRoot)
        {
            return 0;
        }

        var totalCount = 0;
        if (settings.UseJointPositionsAndAngles)
        {
            switch (body.jointType)
            {
                case ArticulationJointType.RevoluteJoint:
                case ArticulationJointType.SphericalJoint:
                    // Both RevoluteJoint and SphericalJoint have only angular components.
                    // We use the sine and cosine of each angle for the observations.
                    totalCount += 2 * body.dofCount;
                    break;
                case ArticulationJointType.FixedJoint:
                    // Since a FixedJoint can't move, there aren't any interesting observations for it.
                    break;
                case ArticulationJointType.PrismaticJoint:
                    // At most one linear component; zero if every axis is locked.
                    totalCount += body.dofCount;
                    break;
            }
        }

        if (settings.UseJointForces)
        {
            totalCount += body.dofCount;
        }

        return totalCount;
    }

    /// <summary>
    /// Write the joint observations for this extractor's body.
    /// </summary>
    /// <param name="settings">Which kinds of joint observations are enabled.</param>
    /// <param name="writer">Destination for the observation values.</param>
    /// <param name="offset">Index in <paramref name="writer"/> of the first value to write.</param>
    /// <returns>
    /// The number of values written; always equal to <see cref="NumObservations(PhysicsSensorSettings)"/>.
    /// </returns>
    public int Write(PhysicsSensorSettings settings, ObservationWriter writer, int offset)
    {
        if (m_Body == null || m_Body.isRoot)
        {
            return 0;
        }

        var currentOffset = offset;

        // Write joint positions
        if (settings.UseJointPositionsAndAngles)
        {
            switch (m_Body.jointType)
            {
                case ArticulationJointType.RevoluteJoint:
                case ArticulationJointType.SphericalJoint:
                    // All joint positions are angular (radians).
                    for (var dofIndex = 0; dofIndex < m_Body.dofCount; dofIndex++)
                    {
                        var jointRotationRads = m_Body.jointPosition[dofIndex];
                        writer[currentOffset++] = Mathf.Sin(jointRotationRads);
                        writer[currentOffset++] = Mathf.Cos(jointRotationRads);
                    }
                    break;
                case ArticulationJointType.FixedJoint:
                    // No observations
                    break;
                case ArticulationJointType.PrismaticJoint:
                    // A fully locked prismatic joint has no DOF: NumObservations reserves
                    // nothing for it, and jointPosition[0] would not exist.
                    if (m_Body.dofCount > 0)
                    {
                        writer[currentOffset++] = GetPrismaticValue();
                    }
                    break;
            }
        }

        if (settings.UseJointForces)
        {
            for (var dofIndex = 0; dofIndex < m_Body.dofCount; dofIndex++)
            {
                // take tanh to keep in [-1, 1]
                writer[currentOffset++] = (float)System.Math.Tanh(m_Body.jointForce[dofIndex]);
            }
        }

        return currentOffset - offset;
    }

    /// <summary>
    /// Normalized position of a prismatic joint in [-1, 1].
    /// If the moving axis has limits, the position is mapped linearly from
    /// [lowerLimit, upperLimit] to [-1, 1]; otherwise tanh of the raw position is used.
    /// Callers must ensure the body has at least one DOF.
    /// </summary>
    float GetPrismaticValue()
    {
        var jointPos = m_Body.jointPosition[0];

        if (TryGetLimitedDrive(out var drive))
        {
            // If limited, interpolate between the limits.
            var upperLimit = drive.upperLimit;
            var lowerLimit = drive.lowerLimit;
            if (upperLimit <= lowerLimit)
            {
                // Invalid limits (probably equal), so don't try to lerp
                return 0;
            }

            var invLerped = Mathf.InverseLerp(lowerLimit, upperLimit, jointPos);
            // Convert [0, 1] -> [-1, 1]
            return 2.0f * invLerped - 1.0f;
        }

        // take tanh() to keep in [-1, 1]
        return (float)System.Math.Tanh(jointPos);
    }

    /// <summary>
    /// Find the drive of the first linear axis (checked in X, Y, Z order) whose lock is
    /// <see cref="ArticulationDofLock.LimitedMotion"/>. Prismatic joints should have at most
    /// one free axis.
    /// </summary>
    /// <param name="drive">The drive of that axis, or default if none is limited.</param>
    /// <returns>True if a limited axis was found.</returns>
    bool TryGetLimitedDrive(out ArticulationDrive drive)
    {
        if (m_Body.linearLockX == ArticulationDofLock.LimitedMotion)
        {
            drive = m_Body.xDrive;
            return true;
        }

        if (m_Body.linearLockY == ArticulationDofLock.LimitedMotion)
        {
            drive = m_Body.yDrive;
            return true;
        }

        if (m_Body.linearLockZ == ArticulationDofLock.LimitedMotion)
        {
            drive = m_Body.zDrive;
            return true;
        }

        drive = default;
        return false;
    }
}
}
#endif
|