File size: 5,121 Bytes
05c9ac2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
#if UNITY_2020_1_OR_NEWER

using System.Collections.Generic;
using UnityEngine;
using Unity.MLAgents.Sensors;

namespace Unity.MLAgents.Extensions.Sensors
{
    /// <summary>
    /// Extracts joint observations (joint positions/angles and joint forces) from the
    /// inbound joint of a single <see cref="ArticulationBody"/>.
    /// Root bodies and null bodies produce no observations.
    /// </summary>
    public class ArticulationBodyJointExtractor : IJointExtractor
    {
        readonly ArticulationBody m_Body;

        /// <summary>
        /// Create an extractor for the given body.
        /// </summary>
        /// <param name="body">
        /// The body whose joint is observed. A null body is allowed and yields zero observations.
        /// </param>
        public ArticulationBodyJointExtractor(ArticulationBody body)
        {
            m_Body = body;
        }

        /// <summary>
        /// Number of observations <see cref="Write"/> will produce for this extractor's body.
        /// </summary>
        /// <param name="settings">Which kinds of joint observations are enabled.</param>
        /// <returns>The observation count; 0 for a null or root body.</returns>
        public int NumObservations(PhysicsSensorSettings settings)
        {
            return NumObservations(m_Body, settings);
        }

        /// <summary>
        /// Number of observations that would be written for <paramref name="body"/> with the given settings.
        /// This must stay in sync with what <see cref="Write"/> emits.
        /// </summary>
        /// <param name="body">The articulation body to inspect; may be null.</param>
        /// <param name="settings">Which kinds of joint observations are enabled.</param>
        /// <returns>The observation count; 0 for a null or root body.</returns>
        public static int NumObservations(ArticulationBody body, PhysicsSensorSettings settings)
        {
            // Unity overloads == so destroyed objects also compare equal to null; keep == here.
            if (body == null || body.isRoot)
            {
                return 0;
            }

            var totalCount = 0;
            if (settings.UseJointPositionsAndAngles)
            {
                switch (body.jointType)
                {
                    case ArticulationJointType.RevoluteJoint:
                    case ArticulationJointType.SphericalJoint:
                        // Both RevoluteJoint and SphericalJoint have only angular components.
                        // Each angle is observed as its sine and cosine.
                        totalCount += 2 * body.dofCount;
                        break;
                    case ArticulationJointType.FixedJoint:
                        // A FixedJoint can't move, so there are no interesting observations for it.
                        break;
                    case ArticulationJointType.PrismaticJoint:
                        // At most one normalized linear component.
                        totalCount += NumPrismaticObservations(body);
                        break;
                }
            }

            if (settings.UseJointForces)
            {
                // One force value per degree of freedom.
                totalCount += body.dofCount;
            }

            return totalCount;
        }

        /// <summary>
        /// Write this body's joint observations into <paramref name="writer"/> starting at <paramref name="offset"/>.
        /// </summary>
        /// <param name="settings">Which kinds of joint observations are enabled.</param>
        /// <param name="writer">Destination for the observation values.</param>
        /// <param name="offset">Index of the first slot to write.</param>
        /// <returns>The number of values written; equals <see cref="NumObservations(PhysicsSensorSettings)"/>.</returns>
        public int Write(PhysicsSensorSettings settings, ObservationWriter writer, int offset)
        {
            if (m_Body == null || m_Body.isRoot)
            {
                return 0;
            }

            var currentOffset = offset;

            // Write joint positions
            if (settings.UseJointPositionsAndAngles)
            {
                switch (m_Body.jointType)
                {
                    case ArticulationJointType.RevoluteJoint:
                    case ArticulationJointType.SphericalJoint:
                        // All joint positions are angular (radians); sin/cos keeps them in [-1, 1]
                        // and avoids a discontinuity at the wrap-around point.
                        for (var dofIndex = 0; dofIndex < m_Body.dofCount; dofIndex++)
                        {
                            var jointRotationRads = m_Body.jointPosition[dofIndex];
                            writer[currentOffset++] = Mathf.Sin(jointRotationRads);
                            writer[currentOffset++] = Mathf.Cos(jointRotationRads);
                        }
                        break;
                    case ArticulationJointType.FixedJoint:
                        // No observations
                        break;
                    case ArticulationJointType.PrismaticJoint:
                        // Only write when NumObservations counted a slot for it; reading
                        // jointPosition[0] is only valid when the joint has a degree of freedom.
                        if (NumPrismaticObservations(m_Body) > 0)
                        {
                            writer[currentOffset++] = GetPrismaticValue();
                        }
                        break;
                }
            }

            if (settings.UseJointForces)
            {
                for (var dofIndex = 0; dofIndex < m_Body.dofCount; dofIndex++)
                {
                    // take tanh to keep in [-1, 1]
                    writer[currentOffset++] = (float)System.Math.Tanh(m_Body.jointForce[dofIndex]);
                }
            }

            return currentOffset - offset;
        }

        /// <summary>
        /// Number of position observations for a prismatic joint: one normalized value when the
        /// joint has at least one degree of freedom, otherwise none.
        /// Shared by <see cref="NumObservations(ArticulationBody, PhysicsSensorSettings)"/> and
        /// <see cref="Write"/> so the reported size and written data always agree.
        /// </summary>
        static int NumPrismaticObservations(ArticulationBody body)
        {
            return body.dofCount > 0 ? 1 : 0;
        }

        /// <summary>
        /// Normalized position of a prismatic joint in [-1, 1].
        /// If an axis has limited motion, the position is mapped linearly between that axis's
        /// drive limits; otherwise tanh of the raw position is used.
        /// Callers must ensure the joint has at least one degree of freedom.
        /// </summary>
        float GetPrismaticValue()
        {
            var jointPos = m_Body.jointPosition[0];

            if (TryGetLimitedDrive(out var drive))
            {
                // If the axis is limited, interpolate between the limits.
                var upperLimit = drive.upperLimit;
                var lowerLimit = drive.lowerLimit;
                if (upperLimit <= lowerLimit)
                {
                    // Invalid limits (probably equal), so don't try to lerp
                    return 0;
                }
                var invLerped = Mathf.InverseLerp(lowerLimit, upperLimit, jointPos);

                // Convert [0, 1] -> [-1, 1]
                return 2.0f * invLerped - 1.0f;
            }

            // take tanh() to keep in [-1, 1]
            return (float)System.Math.Tanh(jointPos);
        }

        /// <summary>
        /// Find the drive of the first linear axis (checked in X, Y, Z order) whose lock is
        /// <see cref="ArticulationDofLock.LimitedMotion"/>.
        /// Prismatic joints should have at most one free axis.
        /// </summary>
        /// <param name="drive">The drive of the limited axis, or default if none is limited.</param>
        /// <returns>True if a limited axis was found.</returns>
        bool TryGetLimitedDrive(out ArticulationDrive drive)
        {
            if (m_Body.linearLockX == ArticulationDofLock.LimitedMotion)
            {
                drive = m_Body.xDrive;
                return true;
            }
            if (m_Body.linearLockY == ArticulationDofLock.LimitedMotion)
            {
                drive = m_Body.yDrive;
                return true;
            }
            if (m_Body.linearLockZ == ArticulationDofLock.LimitedMotion)
            {
                drive = m_Body.zDrive;
                return true;
            }

            drive = default;
            return false;
        }
    }
}
#endif