File size: 1,987 Bytes
05c9ac2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
using UnityEngine;
using Unity.MLAgents.Sensors;

namespace Unity.MLAgents.Extensions.Sensors
{
    /// <summary>
    /// <see cref="IJointExtractor"/> implementation that reports the force and torque
    /// of the <see cref="Joint"/> component found on a <see cref="Rigidbody"/>.
    /// </summary>
    public class RigidBodyJointExtractor : IJointExtractor
    {
        readonly Rigidbody m_Body;
        readonly Joint m_Joint;

        /// <summary>
        /// Creates an extractor for the given body and looks up the Joint component on it.
        /// </summary>
        /// <param name="body">The Rigidbody to observe. May be null, in which case
        /// the extractor produces no observations.</param>
        public RigidBodyJointExtractor(Rigidbody body)
        {
            m_Body = body;
            // Compare explicitly instead of using `?.`: the null-conditional operator only
            // checks the managed reference and bypasses UnityEngine.Object's overloaded
            // equality, so a destroyed Rigidbody would still have GetComponent called on it.
            m_Joint = body != null ? body.GetComponent<Joint>() : null;
        }

        /// <summary>
        /// Number of float observations this extractor writes for the given settings.
        /// </summary>
        /// <param name="settings">Settings selecting which observations are enabled.</param>
        /// <returns>The number of observations for the wrapped body and joint.</returns>
        public int NumObservations(PhysicsSensorSettings settings)
        {
            return NumObservations(m_Body, m_Joint, settings);
        }

        /// <summary>
        /// Number of float observations that would be written for a body/joint pair.
        /// </summary>
        /// <param name="body">The Rigidbody the joint belongs to.</param>
        /// <param name="joint">The Joint to observe.</param>
        /// <param name="settings">Settings selecting which observations are enabled.</param>
        /// <returns>0 if either <paramref name="body"/> or <paramref name="joint"/> is null;
        /// otherwise 6 when joint forces are enabled in the settings, and 0 when they are not.</returns>
        public static int NumObservations(Rigidbody body, Joint joint, PhysicsSensorSettings settings)
        {
            if (body == null || joint == null)
            {
                return 0;
            }

            var numObservations = 0;
            if (settings.UseJointForces)
            {
                // 3 force and 3 torque values
                numObservations += 6;
            }

            return numObservations;
        }

        /// <summary>
        /// Writes the enabled joint observations into the writer, starting at the given offset.
        /// </summary>
        /// <param name="settings">Settings selecting which observations are enabled.</param>
        /// <param name="writer">Destination for the observation values.</param>
        /// <param name="offset">Index in the writer at which the first value is placed.</param>
        /// <returns>The number of values written; 0 if the body or joint is missing.</returns>
        public int Write(PhysicsSensorSettings settings, ObservationWriter writer, int offset)
        {
            if (m_Body == null || m_Joint == null)
            {
                return 0;
            }

            var currentOffset = offset;
            if (settings.UseJointForces)
            {
                // Read each vector once rather than once per component.
                var force = m_Joint.currentForce;
                var torque = m_Joint.currentTorque;

                // Take tanh of the forces and torques to ensure they're in [-1, 1]
                writer[currentOffset++] = (float)System.Math.Tanh(force.x);
                writer[currentOffset++] = (float)System.Math.Tanh(force.y);
                writer[currentOffset++] = (float)System.Math.Tanh(force.z);

                writer[currentOffset++] = (float)System.Math.Tanh(torque.x);
                writer[currentOffset++] = (float)System.Math.Tanh(torque.y);
                writer[currentOffset++] = (float)System.Math.Tanh(torque.z);
            }
            return currentOffset - offset;
        }
    }
}