|
using System.Collections.Generic; |
|
using UnityEditor; |
|
using Unity.Barracuda; |
|
using Unity.MLAgents.Actuators; |
|
using Unity.MLAgents.Policies; |
|
using Unity.MLAgents.Sensors; |
|
using Unity.MLAgents.Sensors.Reflection; |
|
using CheckTypeEnum = Unity.MLAgents.Inference.BarracudaModelParamLoader.FailedCheck.CheckTypeEnum; |
|
|
|
namespace Unity.MLAgents.Editor |
|
{ |
|
|
|
|
|
|
|
/// <summary>
/// Custom inspector for <see cref="BehaviorParameters"/>.
/// Draws the serialized behavior properties and refreshes the agent's policy when a
/// policy-relevant property changes. It also shows the results of validating the assigned
/// model against the agent's sensors, actuators, and brain parameters.
/// </summary>
[CustomEditor(typeof(BehaviorParameters))]
[CanEditMultipleObjects]
internal class BehaviorParametersEditor : UnityEditor.Editor
{
    // Intended minimum interval (seconds) between model reloads.
    // NOTE(review): m_TimeSinceModelReload is never advanced anywhere in this class, so the
    // throttle in DisplayFailedModelChecks never fires and the model is re-loaded on every
    // GUI pass. Wiring up a real timer is left as a follow-up.
    const float k_TimeBetweenModelReloads = 2f;

    // Elapsed time since the last model reload; only ever reset to 0 in this class.
    float m_TimeSinceModelReload;

    // True when any property drawn by OnInspectorGUI changed during the last pass.
    bool m_RequireReload;

    // Names of the serialized properties on BehaviorParameters drawn by this inspector.
    const string k_BehaviorName = "m_BehaviorName";
    const string k_BrainParametersName = "m_BrainParameters";
    const string k_ModelName = "m_Model";
    const string k_InferenceDeviceName = "m_InferenceDevice";
    const string k_DeterministicInference = "m_DeterministicInference";
    const string k_BehaviorTypeName = "m_BehaviorType";
    const string k_TeamIdName = "TeamId";
    const string k_UseChildSensorsName = "m_UseChildSensors";
    const string k_ObservableAttributeHandlingName = "m_ObservableAttributeHandling";

    /// <summary>
    /// Draws the inspector.
    /// <see cref="UpdateAgentPolicy"/> runs after the modified properties are applied if any
    /// of these changed: behavior name, model, inference device, deterministic-inference
    /// flag, or behavior type.
    /// </summary>
    public override void OnInspectorGUI()
    {
        var so = serializedObject;
        so.Update();

        var behaviorParameters = (BehaviorParameters)target;
        var agent = behaviorParameters.gameObject.GetComponent<Agent>();
        if (agent == null)
        {
            EditorGUILayout.HelpBox(
                "No Agent is associated with this Behavior Parameters. Attach an Agent to " +
                "this GameObject to configure your Agent with these behavior parameters.",
                MessageType.Warning);
        }

        EditorGUI.indentLevel++;

        // Outer ("global") change check covering every field drawn below. It is closed right
        // before the model checks run.
        // Every BeginChangeCheck below must be paired with exactly one EndChangeCheck, or
        // EditorGUI's internal change stack becomes unbalanced.
        EditorGUI.BeginChangeCheck();

        // Tracks whether the name, model, inference device, deterministic flag, or behavior
        // type changed; any of these requires the agent's policy to be rebuilt.
        bool needPolicyUpdate;

        EditorGUI.BeginChangeCheck();
        {
            EditorGUILayout.PropertyField(so.FindProperty(k_BehaviorName));
        }
        needPolicyUpdate = EditorGUI.EndChangeCheck();

        // Brain parameters do not affect needPolicyUpdate, so no nested change check is opened
        // here. The outer check still records edits to them.
        EditorGUI.BeginDisabledGroup(!EditorUtilities.CanUpdateModelProperties());
        {
            EditorGUILayout.PropertyField(so.FindProperty(k_BrainParametersName), true);
        }
        EditorGUI.EndDisabledGroup();

        EditorGUI.BeginChangeCheck();
        {
            EditorGUILayout.PropertyField(so.FindProperty(k_ModelName), true);
            EditorGUI.indentLevel++;
            EditorGUILayout.PropertyField(so.FindProperty(k_InferenceDeviceName), true);
            EditorGUILayout.PropertyField(so.FindProperty(k_DeterministicInference), true);
            EditorGUI.indentLevel--;
        }
        // `|=` (not `||`) so EndChangeCheck always runs and always pops its stack entry,
        // even when needPolicyUpdate is already true.
        needPolicyUpdate |= EditorGUI.EndChangeCheck();

        EditorGUI.BeginChangeCheck();
        {
            EditorGUILayout.PropertyField(so.FindProperty(k_BehaviorTypeName));
        }
        needPolicyUpdate |= EditorGUI.EndChangeCheck();

        EditorGUILayout.PropertyField(so.FindProperty(k_TeamIdName));
        EditorGUI.BeginDisabledGroup(!EditorUtilities.CanUpdateModelProperties());
        {
            EditorGUILayout.PropertyField(so.FindProperty(k_UseChildSensorsName), true);
            EditorGUILayout.PropertyField(so.FindProperty(k_ObservableAttributeHandlingName), true);
        }
        EditorGUI.EndDisabledGroup();

        EditorGUI.indentLevel--;

        // Closes the outer change check opened above.
        m_RequireReload = EditorGUI.EndChangeCheck();

        DisplayFailedModelChecks();
        so.ApplyModifiedProperties();

        if (needPolicyUpdate)
        {
            UpdateAgentPolicy();
        }
    }

    /// <summary>
    /// Validates the assigned model against the agent's sensors, actuators, observable
    /// attributes, and brain parameters, and renders each failed check as a HelpBox.
    /// Returns early without drawing anything if the GameObject has no <see cref="Agent"/>.
    /// </summary>
    void DisplayFailedModelChecks()
    {
        // NOTE(review): inert. m_TimeSinceModelReload is never advanced, so this condition is
        // always false.
        if (m_RequireReload && m_TimeSinceModelReload > k_TimeBetweenModelReloads)
        {
            m_RequireReload = false;
            m_TimeSinceModelReload = 0;
        }

        // Disable Barracuda's debug logger; presumably to keep ModelLoader.Load quiet — confirm.
        // Note that this global flag is not restored afterwards.
        D.logEnabled = false;
        Model barracudaModel = null;
        var model = (NNModel)serializedObject.FindProperty(k_ModelName).objectReferenceValue;
        var behaviorParameters = (BehaviorParameters)target;

        var agent = behaviorParameters.gameObject.GetComponent<Agent>();
        if (agent == null)
        {
            return;
        }

        // Rebuild the agent's sensor list so the checks see the current sensor setup.
        // This replaces agent.sensors on every GUI pass.
        agent.sensors = new List<ISensor>();
        agent.InitializeSensors();
        var sensors = agent.sensors.ToArray();

        var actuatorComponents = behaviorParameters.UseChildActuators
            ? behaviorParameters.GetComponentsInChildren<ActuatorComponent>()
            : behaviorParameters.GetComponents<ActuatorComponent>();

        // agent is known to be non-null here (early return above).
        var observableAttributeSensorTotalSize = 0;
        if (behaviorParameters.ObservableAttributeHandling != ObservableAttributeOptions.Ignore)
        {
            var observableErrors = new List<string>();
            observableAttributeSensorTotalSize = ObservableAttribute.GetTotalObservationSize(agent, false, observableErrors);
            foreach (var error in observableErrors)
            {
                EditorGUILayout.HelpBox(error, MessageType.Warning);
            }
        }

        var brainParameters = behaviorParameters.BrainParameters;
        if (model != null)
        {
            barracudaModel = ModelLoader.Load(model);
        }

        if (brainParameters == null)
        {
            return;
        }

        var failedChecks = Inference.BarracudaModelParamLoader.CheckModel(
            barracudaModel, brainParameters, sensors, actuatorComponents,
            observableAttributeSensorTotalSize, behaviorParameters.BehaviorType, behaviorParameters.DeterministicInference
        );
        foreach (var check in failedChecks)
        {
            if (check == null)
            {
                continue;
            }

            MessageType messageType;
            if (TryGetMessageType(check.CheckType, out messageType))
            {
                EditorGUILayout.HelpBox(check.Message, messageType);
            }
        }
    }

    /// <summary>
    /// Maps a model-check severity to the HelpBox message type used to display it.
    /// </summary>
    /// <param name="checkType">Severity reported by the model check.</param>
    /// <param name="messageType">The matching HelpBox type, or <c>MessageType.None</c> if unmapped.</param>
    /// <returns><c>true</c> if the severity should be displayed; <c>false</c> for unrecognized values.</returns>
    static bool TryGetMessageType(CheckTypeEnum checkType, out MessageType messageType)
    {
        switch (checkType)
        {
            case CheckTypeEnum.Info:
                messageType = MessageType.Info;
                return true;
            case CheckTypeEnum.Warning:
                messageType = MessageType.Warning;
                return true;
            case CheckTypeEnum.Error:
                messageType = MessageType.Error;
                return true;
            default:
                messageType = MessageType.None;
                return false;
        }
    }

    /// <summary>
    /// Asks the inspected <see cref="BehaviorParameters"/> to rebuild its agent's policy.
    /// </summary>
    void UpdateAgentPolicy()
    {
        var behaviorParameters = (BehaviorParameters)target;
        behaviorParameters.UpdateAgentPolicy();
    }
}
|
} |
|
|