|
using UnityEngine; |
|
using UnityEditor; |
|
using Unity.MLAgents.Policies; |
|
|
|
namespace Unity.MLAgents.Editor |
|
{ |
|
|
|
|
|
|
|
|
|
[CustomPropertyDrawer(typeof(BrainParameters))]
internal class BrainParametersDrawer : PropertyDrawer
{
    // Height reserved for each row drawn by this drawer.
    const float k_LineHeight = 17f;

    // Rows drawn by DrawVectorObservation: the section label, "Space Size" and "Stacked Vectors".
    const int k_VecObsNumLine = 3;

    // Fixed rows drawn by DrawVectorAction before the per-branch rows:
    // the "Actions" label, "Continuous Actions" and "Discrete Branches".
    const int k_ActionFixedNumLine = 3;

    // Names of the serialized fields this drawer looks up via FindPropertyRelative.
    const string k_ActionSpecName = "m_ActionSpec";
    const string k_ContinuousActionSizeName = "m_NumContinuousActions";
    const string k_DiscreteBranchSizeName = "BranchSizes";
    const string k_ActionDescriptionPropName = "VectorActionDescriptions";
    const string k_VecObsPropName = "VectorObservationSize";
    const string k_NumVecObsPropName = "NumStackedVectorObservations";

    /// <summary>
    /// Total inspector height: the vector-observation section plus the action section,
    /// whose height grows with the number of discrete branches.
    /// </summary>
    public override float GetPropertyHeight(SerializedProperty property, GUIContent label)
    {
        return GetHeightDrawVectorObservation() +
            GetHeightDrawVectorAction(property);
    }

    /// <summary>
    /// Draws the vector-observation section followed by the action section.
    /// The caller's <see cref="EditorGUI.indentLevel"/> is saved and restored.
    /// </summary>
    public override void OnGUI(Rect position, SerializedProperty property, GUIContent label)
    {
        var indent = EditorGUI.indentLevel;
        EditorGUI.indentLevel = 0;
        // Every row is drawn one line tall; sections advance position.y themselves.
        position.height = k_LineHeight;
        EditorGUI.BeginProperty(position, label, property);
        EditorGUI.indentLevel++;

        DrawVectorObservation(position, property);
        position.y += GetHeightDrawVectorObservation();

        DrawVectorAction(position, property);

        EditorGUI.EndProperty();
        EditorGUI.indentLevel = indent;
    }

    /// <summary>
    /// Draws the "Vector Observation" label and its "Space Size" and
    /// "Stacked Vectors" fields (k_VecObsNumLine rows in total).
    /// Leaves <see cref="EditorGUI.indentLevel"/> unchanged on return.
    /// </summary>
    /// <param name="position">Rect of the first row; only its y is advanced locally.</param>
    /// <param name="property">The BrainParameters serialized property.</param>
    static void DrawVectorObservation(Rect position, SerializedProperty property)
    {
        EditorGUI.LabelField(position, "Vector Observation");
        position.y += k_LineHeight;

        EditorGUI.indentLevel++;
        EditorGUI.PropertyField(position,
            property.FindPropertyRelative(k_VecObsPropName),
            new GUIContent("Space Size",
                "Length of state " +
                "vector for brain (In Continuous state space). " +
                "Or number of possible values (in Discrete state space)."));
        position.y += k_LineHeight;

        EditorGUI.PropertyField(position,
            property.FindPropertyRelative(k_NumVecObsPropName),
            new GUIContent("Stacked Vectors",
                "Number of states that will be stacked before " +
                "being fed to the neural network."));
        EditorGUI.indentLevel--;
    }

    /// <summary>Height of the section drawn by <see cref="DrawVectorObservation"/>.</summary>
    static float GetHeightDrawVectorObservation()
    {
        return k_VecObsNumLine * k_LineHeight;
    }

    /// <summary>
    /// Draws the "Actions" label, the continuous action count and the discrete branch editor.
    /// Leaves <see cref="EditorGUI.indentLevel"/> unchanged on return.
    /// </summary>
    /// <param name="position">Rect of the first row; only its y is advanced locally.</param>
    /// <param name="property">The BrainParameters serialized property.</param>
    static void DrawVectorAction(Rect position, SerializedProperty property)
    {
        EditorGUI.LabelField(position, "Actions");
        position.y += k_LineHeight;

        EditorGUI.indentLevel++;
        var actionSpecProperty = property.FindPropertyRelative(k_ActionSpecName);
        DrawContinuousVectorAction(position, actionSpecProperty);
        position.y += k_LineHeight;
        DrawDiscreteVectorAction(position, actionSpecProperty);
        // Balance the increment above so this section does not leak indentation,
        // matching DrawVectorObservation.
        EditorGUI.indentLevel--;
    }

    /// <summary>Draws the "Continuous Actions" count field on a single row.</summary>
    /// <param name="position">Rect of the row.</param>
    /// <param name="property">The action-spec serialized property.</param>
    static void DrawContinuousVectorAction(Rect position, SerializedProperty property)
    {
        var continuousActionSize = property.FindPropertyRelative(k_ContinuousActionSizeName);
        EditorGUI.PropertyField(
            position,
            continuousActionSize,
            new GUIContent("Continuous Actions", "Number of continuous actions."));
    }

    /// <summary>
    /// Draws the "Discrete Branches" count field, resizes the branch-size array to
    /// match it, then draws one indented size field per branch.
    /// </summary>
    /// <param name="position">Rect of the first row; advanced locally per branch.</param>
    /// <param name="property">The action-spec serialized property.</param>
    static void DrawDiscreteVectorAction(Rect position, SerializedProperty property)
    {
        var branchSizes = property.FindPropertyRelative(k_DiscreteBranchSizeName);
        var newSize = EditorGUI.IntField(
            position, "Discrete Branches", branchSizes.arraySize);

        // IntField accepts any integer the user types; an array size cannot be negative.
        if (newSize < 0)
        {
            newSize = 0;
        }

        // Only write arraySize when the value actually changed, so an untouched field
        // never writes to the serialized array.
        if (newSize != branchSizes.arraySize)
        {
            branchSizes.arraySize = newSize;
        }

        position.y += k_LineHeight;
        // Offset the per-branch rows to the right of the "Discrete Branches" row.
        position.x += 20;
        position.width -= 20;
        for (var branchIndex = 0; branchIndex < branchSizes.arraySize; branchIndex++)
        {
            var branchActionSize = branchSizes.GetArrayElementAtIndex(branchIndex);
            EditorGUI.PropertyField(
                position,
                branchActionSize,
                new GUIContent("Branch " + branchIndex + " Size",
                    "Number of possible actions for the branch number " + branchIndex + "."));
            position.y += k_LineHeight;
        }
    }

    /// <summary>
    /// Height of the section drawn by <see cref="DrawVectorAction"/>: the fixed rows
    /// plus one row per discrete branch.
    /// </summary>
    static float GetHeightDrawVectorAction(SerializedProperty property)
    {
        var actionSpecProperty = property.FindPropertyRelative(k_ActionSpecName);
        var numBranches = actionSpecProperty.FindPropertyRelative(k_DiscreteBranchSizeName).arraySize;
        return (k_ActionFixedNumLine + numBranches) * k_LineHeight;
    }
}
|
} |
|
|