File size: 7,403 Bytes
05c9ac2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
using UnityEngine;
using UnityEditor;
using Unity.MLAgents.Policies;

namespace Unity.MLAgents.Editor
{
    /// <summary>
    /// PropertyDrawer for BrainParameters. Defines how BrainParameters are displayed in the
    /// Inspector: a "Vector Observation" section followed by an "Actions" section.
    /// </summary>
    [CustomPropertyDrawer(typeof(BrainParameters))]
    internal class BrainParametersDrawer : PropertyDrawer
    {
        // The height of a line in the Unity Inspectors
        const float k_LineHeight = 17f;
        // Lines used by the Vector Observation section: header, space size, stacked vectors.
        const int k_VecObsNumLine = 3;
        // Lines always used by the Actions section: header, continuous size, branch count.
        // One additional line is added per discrete branch.
        const int k_ActionFixedNumLine = 3;
        // Horizontal offset applied to the per-branch size fields below "Discrete Branches".
        const float k_BranchFieldOffset = 20f;
        const string k_ActionSpecName = "m_ActionSpec";
        const string k_ContinuousActionSizeName = "m_NumContinuousActions";
        const string k_DiscreteBranchSizeName = "BranchSizes";
        // Not referenced by this drawer at the moment.
        const string k_ActionDescriptionPropName = "VectorActionDescriptions";
        const string k_VecObsPropName = "VectorObservationSize";
        const string k_NumVecObsPropName = "NumStackedVectorObservations";

        /// <inheritdoc />
        public override float GetPropertyHeight(SerializedProperty property, GUIContent label)
        {
            return GetHeightDrawVectorObservation() +
                GetHeightDrawVectorAction(property);
        }

        /// <inheritdoc />
        public override void OnGUI(Rect position, SerializedProperty property, GUIContent label)
        {
            // Remember the caller's indent level so it can be restored on exit; the
            // drawer lays its content out relative to indent level 0.
            var indent = EditorGUI.indentLevel;
            EditorGUI.indentLevel = 0;
            position.height = k_LineHeight;
            EditorGUI.BeginProperty(position, label, property);
            EditorGUI.indentLevel++;

            // Vector Observations
            DrawVectorObservation(position, property);
            position.y += GetHeightDrawVectorObservation();

            // Vector Action (last section, so the cursor does not need advancing afterwards)
            DrawVectorAction(position, property);

            EditorGUI.EndProperty();
            EditorGUI.indentLevel = indent;
        }

        /// <summary>
        /// Draws the Vector Observations for the Brain Parameters
        /// </summary>
        /// <param name="position">Rectangle on the screen to use for the property GUI.</param>
        /// <param name="property">The SerializedProperty of the BrainParameters
        /// to make the custom GUI for.</param>
        static void DrawVectorObservation(Rect position, SerializedProperty property)
        {
            EditorGUI.LabelField(position, "Vector Observation");
            position.y += k_LineHeight;

            EditorGUI.indentLevel++;
            EditorGUI.PropertyField(position,
                property.FindPropertyRelative(k_VecObsPropName),
                new GUIContent("Space Size",
                    "Length of state " +
                    "vector for brain (In Continuous state space). " +
                    "Or number of possible values (in Discrete state space)."));
            position.y += k_LineHeight;

            EditorGUI.PropertyField(position,
                property.FindPropertyRelative(k_NumVecObsPropName),
                new GUIContent("Stacked Vectors",
                    "Number of states that will be stacked before " +
                    "being fed to the neural network."));
            EditorGUI.indentLevel--;
        }

        /// <summary>
        /// The Height required to draw the Vector Observations parameters
        /// </summary>
        /// <returns>The height of the drawer of the Vector Observations </returns>
        static float GetHeightDrawVectorObservation()
        {
            return k_VecObsNumLine * k_LineHeight;
        }

        /// <summary>
        /// Draws the Vector Actions parameters for the Brain Parameters
        /// </summary>
        /// <param name="position">Rectangle on the screen to use for the property GUI.</param>
        /// <param name="property">The SerializedProperty of the BrainParameters
        /// to make the custom GUI for.</param>
        static void DrawVectorAction(Rect position, SerializedProperty property)
        {
            EditorGUI.LabelField(position, "Actions");
            position.y += k_LineHeight;

            EditorGUI.indentLevel++;
            var actionSpecProperty = property.FindPropertyRelative(k_ActionSpecName);
            DrawContinuousVectorAction(position, actionSpecProperty);
            position.y += k_LineHeight;
            DrawDiscreteVectorAction(position, actionSpecProperty);
            // Undo the indent added above so this helper leaves the indent level as it
            // found it, mirroring DrawVectorObservation.
            EditorGUI.indentLevel--;
        }

        /// <summary>
        /// Draws the Continuous Vector Actions parameters for the Brain Parameters
        /// </summary>
        /// <param name="position">Rectangle on the screen to use for the property GUI.</param>
        /// <param name="property">The SerializedProperty of the action spec that holds
        /// the continuous action size.</param>
        static void DrawContinuousVectorAction(Rect position, SerializedProperty property)
        {
            var continuousActionSize = property.FindPropertyRelative(k_ContinuousActionSizeName);
            EditorGUI.PropertyField(
                position,
                continuousActionSize,
                new GUIContent("Continuous Actions", "Number of continuous actions."));
        }

        /// <summary>
        /// Draws the Discrete Vector Actions parameters for the Brain Parameters: one field
        /// for the number of branches, followed by one size field per branch.
        /// </summary>
        /// <param name="position">Rectangle on the screen to use for the property GUI.</param>
        /// <param name="property">The SerializedProperty of the action spec that holds
        /// the discrete branch sizes.</param>
        static void DrawDiscreteVectorAction(Rect position, SerializedProperty property)
        {
            var branchSizes = property.FindPropertyRelative(k_DiscreteBranchSizeName);
            // A branch count cannot be negative, so clamp whatever the user typed.
            var newSize = Mathf.Max(
                0,
                EditorGUI.IntField(position, "Discrete Branches", branchSizes.arraySize));

            // This check is here due to:
            // https://fogbugz.unity3d.com/f/cases/1246524/
            // If this case has been resolved, please remove this if condition.
            if (newSize != branchSizes.arraySize)
            {
                branchSizes.arraySize = newSize;
            }

            // Per-branch fields start on the next line, shifted right and narrowed by the
            // same amount so their right edge stays aligned with the field above.
            position.y += k_LineHeight;
            position.x += k_BranchFieldOffset;
            position.width -= k_BranchFieldOffset;
            for (var branchIndex = 0;
                 branchIndex < branchSizes.arraySize;
                 branchIndex++)
            {
                var branchActionSize =
                    branchSizes.GetArrayElementAtIndex(branchIndex);

                EditorGUI.PropertyField(
                    position,
                    branchActionSize,
                    new GUIContent("Branch " + branchIndex + " Size",
                        "Number of possible actions for the branch number " + branchIndex + "."));
                position.y += k_LineHeight;
            }
        }

        /// <summary>
        /// The Height required to draw the Vector Action parameters.
        /// </summary>
        /// <param name="property">The SerializedProperty of the BrainParameters whose
        /// discrete branch count determines the number of lines.</param>
        /// <returns>The height of the drawer of the Vector Action.</returns>
        static float GetHeightDrawVectorAction(SerializedProperty property)
        {
            var actionSpecProperty = property.FindPropertyRelative(k_ActionSpecName);
            var branchCount = actionSpecProperty.FindPropertyRelative(k_DiscreteBranchSizeName).arraySize;
            return (k_ActionFixedNumLine + branchCount) * k_LineHeight;
        }
    }
}