File size: 4,374 Bytes
05c9ac2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
using System.Collections.Generic;
#if UNITY_2020_1_OR_NEWER
using UnityEngine;
#endif
using Unity.MLAgents.Sensors;

namespace Unity.MLAgents.Extensions.Sensors
{
    /// <summary>
    /// ISensor implementation that generates observations for a group of Rigidbodies or ArticulationBodies.
    /// </summary>
    /// <remarks>
    /// The observation vector consists of the pose observations written by the pose extractor,
    /// followed by the observations of each joint extractor (one per enabled body), in the order
    /// the bodies were enumerated at construction time.
    /// </remarks>
    public class PhysicsBodySensor : ISensor, IBuiltInSensor
    {
        // All state is fixed at construction time.
        readonly ObservationSpec m_ObservationSpec;
        readonly string m_SensorName;

        readonly PoseExtractor m_PoseExtractor;
        readonly List<IJointExtractor> m_JointExtractors;
        readonly PhysicsSensorSettings m_Settings;

        /// <summary>
        /// Construct a new PhysicsBodySensor for a hierarchy of Rigidbodies.
        /// </summary>
        /// <param name="poseExtractor">
        /// Pose extractor for the Rigidbody hierarchy. One joint extractor is created for each
        /// Rigidbody returned by its <c>GetEnabledRigidbodies()</c>.
        /// </param>
        /// <param name="settings">Settings that select which observations are generated.</param>
        /// <param name="sensorName">Name returned by <see cref="GetName"/>; stored exactly as given.</param>
        /// <exception cref="System.ArgumentNullException"><paramref name="poseExtractor"/> is null.</exception>
        public PhysicsBodySensor(
            RigidBodyPoseExtractor poseExtractor,
            PhysicsSensorSettings settings,
            string sensorName
        )
        {
            // Fail with a descriptive exception instead of a NullReferenceException further down.
            if (poseExtractor == null)
            {
                throw new System.ArgumentNullException(nameof(poseExtractor));
            }

            m_PoseExtractor = poseExtractor;
            m_SensorName = sensorName;
            m_Settings = settings;

            var numJointExtractorObservations = 0;
            m_JointExtractors = new List<IJointExtractor>(poseExtractor.NumEnabledPoses);
            foreach (var rb in poseExtractor.GetEnabledRigidbodies())
            {
                var jointExtractor = new RigidBodyJointExtractor(rb);
                numJointExtractorObservations += jointExtractor.NumObservations(settings);
                m_JointExtractors.Add(jointExtractor);
            }

            // Total vector size = pose observations + all joint observations.
            var numTransformObservations = m_PoseExtractor.GetNumPoseObservations(settings);
            m_ObservationSpec = ObservationSpec.Vector(numTransformObservations + numJointExtractorObservations);
        }

#if UNITY_2020_1_OR_NEWER
        /// <summary>
        /// Construct a new PhysicsBodySensor for the ArticulationBody hierarchy rooted at
        /// <paramref name="rootBody"/>.
        /// </summary>
        /// <param name="rootBody">
        /// Root of the articulation; passed to a new <c>ArticulationBodyPoseExtractor</c>.
        /// One joint extractor is created for each of that extractor's enabled ArticulationBodies.
        /// </param>
        /// <param name="settings">Settings that select which observations are generated.</param>
        /// <param name="sensorName">
        /// Optional sensor name. If null or empty, <c>"ArticulationBodySensor:"</c> followed by the
        /// root body's name is used (just <c>"ArticulationBodySensor:"</c> if the root body is null or destroyed).
        /// </param>
        public PhysicsBodySensor(ArticulationBody rootBody, PhysicsSensorSettings settings, string sensorName = null)
        {
            var poseExtractor = new ArticulationBodyPoseExtractor(rootBody);
            m_PoseExtractor = poseExtractor;

            if (string.IsNullOrEmpty(sensorName))
            {
                // Use UnityEngine.Object's overloaded != instead of ?. : the null-conditional operator
                // bypasses Unity's lifetime check, so a destroyed body would not be treated as null and
                // reading .name on it would throw.
                var rootName = rootBody != null ? rootBody.name : null;
                m_SensorName = $"ArticulationBodySensor:{rootName}";
            }
            else
            {
                m_SensorName = sensorName;
            }

            m_Settings = settings;

            var numJointExtractorObservations = 0;
            m_JointExtractors = new List<IJointExtractor>(poseExtractor.NumEnabledPoses);
            foreach (var articBody in poseExtractor.GetEnabledArticulationBodies())
            {
                var jointExtractor = new ArticulationBodyJointExtractor(articBody);
                numJointExtractorObservations += jointExtractor.NumObservations(settings);
                m_JointExtractors.Add(jointExtractor);
            }

            // Total vector size = pose observations + all joint observations.
            var numTransformObservations = m_PoseExtractor.GetNumPoseObservations(settings);
            m_ObservationSpec = ObservationSpec.Vector(numTransformObservations + numJointExtractorObservations);
        }

#endif

        /// <inheritdoc/>
        public ObservationSpec GetObservationSpec()
        {
            return m_ObservationSpec;
        }

        /// <inheritdoc/>
        public int Write(ObservationWriter writer)
        {
            // Poses go first; each joint extractor then writes at the running offset.
            var numWritten = writer.WritePoses(m_Settings, m_PoseExtractor);
            foreach (var jointExtractor in m_JointExtractors)
            {
                numWritten += jointExtractor.Write(m_Settings, writer, numWritten);
            }
            return numWritten;
        }

        /// <inheritdoc/>
        public byte[] GetCompressedObservation()
        {
            // This sensor produces uncompressed vector observations only.
            return null;
        }

        /// <inheritdoc/>
        public void Update()
        {
            // Refresh only the pose spaces the settings ask for.
            if (m_Settings.UseModelSpace)
            {
                m_PoseExtractor.UpdateModelSpacePoses();
            }

            if (m_Settings.UseLocalSpace)
            {
                m_PoseExtractor.UpdateLocalSpacePoses();
            }
        }

        /// <inheritdoc/>
        public void Reset() { }

        /// <inheritdoc/>
        public CompressionSpec GetCompressionSpec()
        {
            return CompressionSpec.Default();
        }

        /// <inheritdoc/>
        public string GetName()
        {
            return m_SensorName;
        }

        /// <inheritdoc/>
        public BuiltInSensorType GetBuiltInSensorType()
        {
            return BuiltInSensorType.PhysicsBodySensor;
        }
    }
}