File size: 14,462 Bytes
05c9ac2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 |
using System;
using System.Collections.Generic;
using UnityEngine;
using Unity.Barracuda;
using System.IO;
using Unity.Barracuda.ONNX;
using Unity.MLAgents;
using Unity.MLAgents.Policies;
#if UNITY_EDITOR
using UnityEditor;
#endif
namespace Unity.MLAgentsExamples
{
/// <summary>
/// Utility class to allow the NNModel file for an agent to be overridden during inference.
/// This is used internally to validate the file after training is done.
/// The directory containing the override model files is specified on the commandline, e.g.
/// player.exe --mlagents-override-model-directory /path/to/models
///
/// Additionally, a number of episodes to run can be specified; after this, the application will quit.
/// Note this will only work with example scenes that have 1:1 Agent:Behaviors. More complicated scenes like WallJump
/// probably won't override correctly.
/// </summary>
public class ModelOverrider : MonoBehaviour
{
// File extensions (without the leading dot) accepted by --mlagents-override-model-extension.
// Shared by every instance and never reassigned, so it is static readonly rather than a
// per-instance mutable field.
static readonly HashSet<string> k_SupportedExtensions = new HashSet<string> { "nn", "onnx" };

// Commandline flags understood by this component.
// The directory, extension, episodes and seconds flags each take the following argument as their value;
// the quit-on-load-failure flag is a bare switch.
const string k_CommandLineModelOverrideDirectoryFlag = "--mlagents-override-model-directory";
const string k_CommandLineModelOverrideExtensionFlag = "--mlagents-override-model-extension";
const string k_CommandLineQuitAfterEpisodesFlag = "--mlagents-quit-after-episodes";
const string k_CommandLineQuitAfterSeconds = "--mlagents-quit-after-seconds";
const string k_CommandLineQuitOnLoadFailure = "--mlagents-quit-on-load-failure";

// The attached Agent (looked up on this GameObject in OnEnable).
Agent m_Agent;

// Whether or not the commandline args have already been processed.
// Used to make sure that HasOverrides doesn't spam the logs if it's called multiple times.
private bool m_HaveProcessedCommandLine;

// Directory that override model files are loaded from; null/empty means "no override requested".
string m_BehaviorNameOverrideDirectory;

// Lazily cached behavior name of the agent, read before any override renames it.
private string m_OriginalBehaviorName;

// Extensions requested on the commandline, in the order they were given.
private List<string> m_OverrideExtensions = new List<string>();

// Cached loaded NNModels, with the behavior name as the key.
// A null value records that loading was already attempted and failed.
Dictionary<string, NNModel> m_CachedModels = new Dictionary<string, NNModel>();

// Max episodes to run. Only used if > 0
// Will default to 1 if override models are specified, otherwise 0.
int m_MaxEpisodes;

// Deadline - exit if the time exceeds this
DateTime m_Deadline = DateTime.MaxValue;

// Number of FixedUpdate calls seen by this instance since it was created.
int m_NumSteps;

// Step and episode counts carried over from the static totals when this instance was enabled.
int m_PreviousNumSteps;
int m_PreviousAgentCompletedEpisodes;

// When true, a failed model override quits the application with exit code 1.
bool m_QuitOnLoadFailure;

[Tooltip("Debug values to be used in place of the command line for overriding models.")]
public string debugCommandLineOverride;

// Static values to keep track of completed episodes and steps across resets
// These are updated in OnDisable.
static int s_PreviousAgentCompletedEpisodes;
static int s_PreviousNumSteps;
/// <summary>
/// Episodes completed so far: the count carried over when this component was enabled,
/// plus the attached Agent's own CompletedEpisodes (nothing is added when there is no Agent).
/// </summary>
int TotalCompletedEpisodes
{
    get
    {
        if (m_Agent == null)
        {
            return m_PreviousAgentCompletedEpisodes;
        }

        return m_PreviousAgentCompletedEpisodes + m_Agent.CompletedEpisodes;
    }
}
/// <summary>
/// Steps taken so far: the count carried over when this component was enabled,
/// plus the steps counted by this instance.
/// </summary>
int TotalNumSteps => m_PreviousNumSteps + m_NumSteps;
/// <summary>
/// True when an override model directory was supplied.
/// Reading this property processes the commandline first if that has not happened yet.
/// </summary>
public bool HasOverrides
{
    get
    {
        GetAssetPathFromCommandLine();
        var directoryMissing = string.IsNullOrEmpty(m_BehaviorNameOverrideDirectory);
        return !directoryMissing;
    }
}
/// <summary>
/// The behavior name the agent had before any override was applied.
/// It is read from the agent's BehaviorParameters on first access and cached afterwards,
/// because the actual behavior name changes once the model is overridden.
/// </summary>
public string OriginalBehaviorName
{
    get
    {
        // Already cached - nothing to look up.
        if (!string.IsNullOrEmpty(m_OriginalBehaviorName))
        {
            return m_OriginalBehaviorName;
        }

        var behaviorParameters = m_Agent.GetComponent<BehaviorParameters>();
        m_OriginalBehaviorName = behaviorParameters.BehaviorName;
        return m_OriginalBehaviorName;
    }
}
/// <summary>
/// Build the behavior name used for an overridden agent by prefixing the original name with "Override_".
/// </summary>
/// <param name="originalBehaviorName">The agent's behavior name before the override.</param>
/// <returns>The prefixed behavior name.</returns>
public static string GetOverrideBehaviorName(string originalBehaviorName)
{
    const string prefix = "Override_";
    return prefix + originalBehaviorName;
}
/// <summary>
/// Read the override settings from the commandline arguments (or, in the editor, from
/// debugCommandLineOverride when it is set): the model directory, the model extensions to try,
/// the number of episodes and/or seconds after which to quit, and whether to quit on load failure.
/// Can be called multiple times - if m_HaveProcessedCommandLine is set, will have no effect.
/// </summary>
void GetAssetPathFromCommandLine()
{
    if (m_HaveProcessedCommandLine)
    {
        return;
    }

    var maxEpisodes = 0;
    var timeoutSeconds = 0;

    // In the editor, the debug string (split on single spaces) replaces the real commandline.
    string[] commandLineArgsOverride = null;
    if (!string.IsNullOrEmpty(debugCommandLineOverride) && Application.isEditor)
    {
        commandLineArgsOverride = debugCommandLineOverride.Split(' ');
    }

    var args = commandLineArgsOverride ?? Environment.GetCommandLineArgs();
    for (var i = 0; i < args.Length; i++)
    {
        // Flags that take a value are only honored when a following argument exists.
        var hasValue = i < args.Length - 1;

        if (args[i] == k_CommandLineModelOverrideDirectoryFlag && hasValue)
        {
            m_BehaviorNameOverrideDirectory = args[i + 1].Trim();
        }
        else if (args[i] == k_CommandLineModelOverrideExtensionFlag && hasValue)
        {
            // Invariant lowering so the comparison does not depend on the machine's culture.
            var overrideExtension = args[i + 1].Trim().ToLowerInvariant();
            if (k_SupportedExtensions.Contains(overrideExtension))
            {
                m_OverrideExtensions.Add(overrideExtension);
            }
            else
            {
                // The quit request below does not stop this method from running to completion,
                // so the unsupported extension is deliberately NOT recorded - otherwise a later
                // load could pick up a file in an unknown format.
                Debug.LogError($"loading unsupported format: {overrideExtension}");
                Application.Quit(1);
#if UNITY_EDITOR
                EditorApplication.isPlaying = false;
#endif
            }
        }
        else if (args[i] == k_CommandLineQuitAfterEpisodesFlag && hasValue)
        {
            // On a parse failure maxEpisodes is left at 0, i.e. "not specified".
            Int32.TryParse(args[i + 1], out maxEpisodes);
        }
        else if (args[i] == k_CommandLineQuitAfterSeconds && hasValue)
        {
            // On a parse failure timeoutSeconds is left at 0, i.e. "no deadline".
            Int32.TryParse(args[i + 1], out timeoutSeconds);
        }
        else if (args[i] == k_CommandLineQuitOnLoadFailure)
        {
            m_QuitOnLoadFailure = true;
        }
    }

    if (!string.IsNullOrEmpty(m_BehaviorNameOverrideDirectory))
    {
        // If overriding models, set maxEpisodes to 1 or the command line value
        m_MaxEpisodes = maxEpisodes > 0 ? maxEpisodes : 1;
        // Log the effective value, not the raw commandline value (which is 0 when unspecified).
        Debug.Log($"setting m_MaxEpisodes to {m_MaxEpisodes}");
    }

    if (timeoutSeconds > 0)
    {
        // Kept on DateTime.Now because FixedUpdate compares the deadline against DateTime.Now.
        m_Deadline = DateTime.Now + TimeSpan.FromSeconds(timeoutSeconds);
        Debug.Log($"setting deadline to {timeoutSeconds} from now.");
    }

    m_HaveProcessedCommandLine = true;
}
/// <summary>
/// Restore the running totals, look up the attached Agent, process the commandline,
/// and apply the model override when one was requested.
/// </summary>
void OnEnable()
{
    // When a scene is reset, continue counting from the totals recorded by earlier instances.
    m_PreviousNumSteps = s_PreviousNumSteps;
    m_PreviousAgentCompletedEpisodes = s_PreviousAgentCompletedEpisodes;

    m_Agent = GetComponent<Agent>();

    GetAssetPathFromCommandLine();
    if (!HasOverrides)
    {
        return;
    }

    OverrideModel();
}
/// <summary>
/// Fold this instance's episode and step totals into the static totals.
/// </summary>
void OnDisable()
{
    // With a single agent in the scene this is a plain increment. With several agents,
    // taking the maximum means the totals follow whichever agent got furthest.
    var completedEpisodes = TotalCompletedEpisodes;
    var numSteps = TotalNumSteps;

    s_PreviousAgentCompletedEpisodes = Mathf.Max(s_PreviousAgentCompletedEpisodes, completedEpisodes);
    s_PreviousNumSteps = Mathf.Max(s_PreviousNumSteps, numSteps);
}
/// <summary>
/// Count one step and, when an episode limit is active, quit the application once the
/// episode/step targets are met or the deadline has passed.
/// </summary>
void FixedUpdate()
{
    if (m_MaxEpisodes > 0)
    {
        // Guard against a missing Agent the same way TotalCompletedEpisodes does, instead of
        // throwing a NullReferenceException on every physics step.
        var maxStep = m_Agent == null ? 0 : m_Agent.MaxStep;
        var targetSteps = m_MaxEpisodes * maxStep;

        // For Agents without maxSteps, exit as soon as we've hit the target number of episodes.
        // For Agents that specify MaxStep, also make sure we've gone at least that many steps.
        // Since we exit as soon as *any* Agent hits its target, the maxSteps condition keeps us running
        // a bit longer in case there's an early failure.
        if (TotalCompletedEpisodes >= m_MaxEpisodes && TotalNumSteps > targetSteps)
        {
            Debug.Log($"ModelOverride reached {TotalCompletedEpisodes} episodes and {TotalNumSteps} steps. Exiting.");
            Application.Quit(0);
#if UNITY_EDITOR
            EditorApplication.isPlaying = false;
#endif
        }
        // NOTE(review): the deadline is only checked while an episode limit is active, so a
        // timeout given without an override directory is never enforced - confirm this is intended.
        else if (DateTime.Now >= m_Deadline)
        {
            Debug.Log(
                $"Deadline exceeded. " +
                $"{TotalCompletedEpisodes}/{m_MaxEpisodes} episodes and " +
                $"{TotalNumSteps}/{targetSteps} steps completed. Exiting.");
            Application.Quit(0);
#if UNITY_EDITOR
            EditorApplication.isPlaying = false;
#endif
        }
    }

    m_NumSteps++;
}
/// <summary>
/// Load (or fetch from the cache) the override model for a behavior name.
/// The file is looked up as "&lt;behaviorName&gt;.&lt;extension&gt;" inside the override directory, trying
/// the commandline extensions in order, or "nn" then "onnx" when none were given.
/// </summary>
/// <param name="behaviorName">Behavior name, used as the cache key and as the file name stem.</param>
/// <returns>
/// The loaded model, or null when no override directory is set or no file could be read.
/// A failed lookup is cached, so later calls for the same name return null without touching the disk.
/// </returns>
public NNModel GetModelForBehaviorName(string behaviorName)
{
    // Single lookup; a cached null (earlier failed load) is returned as-is.
    NNModel cachedModel;
    if (m_CachedModels.TryGetValue(behaviorName, out cachedModel))
    {
        return cachedModel;
    }

    if (string.IsNullOrEmpty(m_BehaviorNameOverrideDirectory))
    {
        Debug.Log($"No override directory set.");
        return null;
    }

    // Try the override extensions in order. If they weren't set, try .nn first, then .onnx.
    var overrideExtensions = (m_OverrideExtensions.Count > 0)
        ? m_OverrideExtensions.ToArray()
        : new[] { "nn", "onnx" };

    byte[] rawModel = null;
    bool isOnnx = false;
    string assetName = null;
    foreach (var overrideExtension in overrideExtensions)
    {
        var assetPath = Path.Combine(m_BehaviorNameOverrideDirectory, $"{behaviorName}.{overrideExtension}");
        try
        {
            rawModel = File.ReadAllBytes(assetPath);
            isOnnx = overrideExtension.Equals("onnx");
            assetName = "Override - " + Path.GetFileName(assetPath);
            break;
        }
        catch (IOException)
        {
            // Do nothing - try the next extension, or we'll exit if nothing loaded.
        }
    }

    if (rawModel == null)
    {
        Debug.Log($"Couldn't load model file(s) for {behaviorName} in {m_BehaviorNameOverrideDirectory} (full path: {Path.GetFullPath(m_BehaviorNameOverrideDirectory)})");
        // Cache the null so we don't repeatedly try to load a missing file
        m_CachedModels[behaviorName] = null;
        return null;
    }

    // ONNX files are converted first; anything else is treated as raw Barracuda model bytes.
    var asset = isOnnx ? LoadOnnxModel(rawModel) : LoadBarracudaModel(rawModel);
    asset.name = assetName;
    m_CachedModels[behaviorName] = asset;
    return asset;
}
/// <summary>
/// Wrap raw model bytes in a freshly created NNModel without any conversion.
/// </summary>
/// <param name="rawModel">Contents of the model file.</param>
/// <returns>A new NNModel whose model data holds the given bytes.</returns>
NNModel LoadBarracudaModel(byte[] rawModel)
{
    var model = ScriptableObject.CreateInstance<NNModel>();

    var data = ScriptableObject.CreateInstance<NNModelData>();
    data.Value = rawModel;

    model.modelData = data;
    return model;
}
/// <summary>
/// Convert raw ONNX bytes with ONNXModelConverter, serialize the result with ModelWriter,
/// and wrap the serialized bytes in a freshly created NNModel.
/// </summary>
/// <param name="rawModel">Contents of the .onnx file.</param>
/// <returns>A new NNModel holding the converted model.</returns>
NNModel LoadOnnxModel(byte[] rawModel)
{
    var convertedModel = new ONNXModelConverter(true).Convert(rawModel);

    var data = ScriptableObject.CreateInstance<NNModelData>();
    using (var buffer = new MemoryStream())
    {
        using (var bufferWriter = new BinaryWriter(buffer))
        {
            // Serialize into memory, then copy the bytes out before the stream is disposed.
            ModelWriter.Save(bufferWriter, convertedModel);
            data.Value = buffer.ToArray();
        }
    }

    data.name = "Data";
    data.hideFlags = HideFlags.HideInHierarchy;

    var model = ScriptableObject.CreateInstance<NNModel>();
    model.modelData = data;
    return model;
}
/// <summary>
/// Load the override NNModel for the attached agent's original behavior name and give it to the
/// agent under the "Override_" behavior name. Any failure is logged as a warning; if
/// --mlagents-quit-on-load-failure was given, the application is also asked to quit with exit code 1.
/// </summary>
void OverrideModel()
{
    bool overrideOk = false;
    string overrideError = null;

    m_Agent.LazyInitialize();

    NNModel nnModel = null;
    try
    {
        nnModel = GetModelForBehaviorName(OriginalBehaviorName);
    }
    catch (Exception e)
    {
        overrideError = $"Exception calling GetModelForBehaviorName: {e}";
    }

    if (nnModel == null)
    {
        // No exception but still no model: the file was missing or no directory was set.
        if (string.IsNullOrEmpty(overrideError))
        {
            overrideError =
                $"Didn't find a model for behaviorName {OriginalBehaviorName}. Make " +
                "sure the behaviorName is set correctly in the commandline " +
                "and that the model file exists";
        }
    }
    else
    {
        Debug.Log($"Overriding behavior {OriginalBehaviorName} for agent with model {nnModel.name}");
        try
        {
            m_Agent.SetModel(GetOverrideBehaviorName(OriginalBehaviorName), nnModel);
            overrideOk = true;
        }
        catch (Exception e)
        {
            overrideError = $"Exception calling Agent.SetModel: {e}";
        }
    }

    if (overrideOk)
    {
        return;
    }

    // Always report why the override failed, even when we keep running; previously the reason
    // (including caught exceptions) was dropped unless quit-on-load-failure was set.
    if (!string.IsNullOrEmpty(overrideError))
    {
        Debug.LogWarning(overrideError);
    }

    if (m_QuitOnLoadFailure)
    {
        Application.Quit(1);
#if UNITY_EDITOR
        EditorApplication.isPlaying = false;
#endif
    }
}
}
}
|