File size: 3,693 Bytes
05c9ac2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
using System.Collections.Generic;
using UnityEngine;
using Unity.MLAgents.Sensors;

namespace Unity.MLAgents.Extensions.Sensors
{
    /// <summary>
    /// Editor component that creates a PhysicsBodySensor for the Agent.
    /// </summary>
    /// <remarks>
    /// Null checks on <see cref="RootBody"/> and <see cref="VirtualRoot"/> deliberately use
    /// Unity's overloaded equality operators rather than <c>?.</c> or <c>ReferenceEquals</c>,
    /// so that references which are unassigned in the Inspector or already destroyed are
    /// treated as null instead of throwing when dereferenced.
    /// </remarks>
    public class RigidBodySensorComponent : SensorComponent
    {
        /// <summary>
        /// The root Rigidbody of the system.
        /// </summary>
        public Rigidbody RootBody;

        /// <summary>
        /// Optional GameObject used to determine the root of the poses.
        /// </summary>
        public GameObject VirtualRoot;

        /// <summary>
        /// Settings defining what types of observations will be generated.
        /// </summary>
        [SerializeField]
        public PhysicsSensorSettings Settings = PhysicsSensorSettings.Default();

        /// <summary>
        /// Optional sensor name. This must be unique for each Agent.
        /// </summary>
        [SerializeField]
        public string sensorName;

        [SerializeField]
        [HideInInspector]
        RigidBodyPoseExtractor m_PoseExtractor;

        /// <summary>
        /// Creates a PhysicsBodySensor.
        /// </summary>
        /// <returns>
        /// A single-element array holding the PhysicsBodySensor. The sensor is named
        /// <see cref="sensorName"/> when that is set; otherwise the name is
        /// "PhysicsBodySensor:" followed by the name of <see cref="RootBody"/>
        /// (nothing is appended when there is no root body).
        /// </returns>
        public override ISensor[] CreateSensors()
        {
            string resolvedName = sensorName;
            if (string.IsNullOrEmpty(resolvedName))
            {
                // Use Unity's null check here: `RootBody?.name` would not short-circuit for an
                // unassigned or destroyed Rigidbody and would throw when reading `name`.
                var rootName = RootBody != null ? RootBody.name : null;
                resolvedName = $"PhysicsBodySensor:{rootName}";
            }

            return new ISensor[] { new PhysicsBodySensor(GetPoseExtractor(), Settings, resolvedName) };
        }

        /// <summary>
        /// Get the DisplayNodes of the hierarchy.
        /// </summary>
        /// <returns>The display nodes reported by the pose extractor.</returns>
        internal IList<PoseExtractor.DisplayNode> GetDisplayNodes()
        {
            return GetPoseExtractor().GetDisplayNodes();
        }

        /// <summary>
        /// Lazy construction of the PoseExtractor.
        /// </summary>
        /// <returns>The existing pose extractor, or a newly created one if none existed yet.</returns>
        RigidBodyPoseExtractor GetPoseExtractor()
        {
            if (m_PoseExtractor == null)
            {
                ResetPoseExtractor();
            }

            return m_PoseExtractor;
        }

        /// <summary>
        /// Reset the pose extractor, trying to keep the enabled state of the corresponding poses the same.
        /// </summary>
        internal void ResetPoseExtractor()
        {
            // Get the current enabled state of each body, so that we can reinitialize with them.
            Dictionary<Rigidbody, bool> bodyPosesEnabled = null;
            if (m_PoseExtractor != null)
            {
                bodyPosesEnabled = m_PoseExtractor.GetBodyPosesEnabled();
            }
            m_PoseExtractor = new RigidBodyPoseExtractor(RootBody, gameObject, VirtualRoot, bodyPosesEnabled);
        }

        /// <summary>
        /// Toggle the pose at the given index.
        /// </summary>
        /// <param name="index">Index of the pose, forwarded to the pose extractor.</param>
        /// <param name="enabled">Whether the pose should be enabled.</param>
        internal void SetPoseEnabled(int index, bool enabled)
        {
            GetPoseExtractor().SetPoseEnabled(index, enabled);
        }

        /// <summary>
        /// Checks whether the configured hierarchy is trivial: the root body has no Joint
        /// components on itself or its children, and the virtual root is either not set or is
        /// the root body's own GameObject.
        /// </summary>
        /// <returns>
        /// True if the hierarchy is trivial as described above. False otherwise, including when
        /// <see cref="RootBody"/> is not set (or has been destroyed).
        /// </returns>
        internal bool IsTrivial()
        {
            // Unity's overloaded == also reports unassigned/destroyed objects as null;
            // ReferenceEquals would let those through and the call below would throw.
            if (RootBody == null)
            {
                // It *is* trivial, but this will happen when the sensor is being set up, so don't warn then.
                return false;
            }

            var joints = RootBody.GetComponentsInChildren<Joint>();
            if (joints.Length != 0)
            {
                return false;
            }

            // No joints: only a distinct, live virtual root makes the setup non-trivial.
            return VirtualRoot == null || VirtualRoot == RootBody.gameObject;
        }
    }
}