File size: 14,462 Bytes
05c9ac2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
using System;
using System.Collections.Generic;
using UnityEngine;
using Unity.Barracuda;
using System.IO;
using Unity.Barracuda.ONNX;
using Unity.MLAgents;
using Unity.MLAgents.Policies;
#if UNITY_EDITOR
using UnityEditor;
#endif

namespace Unity.MLAgentsExamples
{
    /// <summary>
    /// Utility class to allow the NNModel file for an agent to be overridden during inference.
    /// This is used internally to validate the file after training is done.
    /// The behavior name to override and file path are specified on the commandline, e.g.
    /// player.exe --mlagents-override-model-directory /path/to/models
    ///
    /// Additionally, a number of episodes to run can be specified; after this, the application will quit.
    /// Note this will only work with example scenes that have 1:1 Agent:Behaviors. More complicated scenes like WallJump
    /// probably won't override correctly.
    /// </summary>
    public class ModelOverrider : MonoBehaviour
    {
        // File extensions (without the leading dot) accepted by the override-extension flag.
        // The set is a fixed lookup table, so it is shared by all instances and never reassigned.
        static readonly HashSet<string> k_SupportedExtensions = new HashSet<string> { "nn", "onnx" };

        // Commandline flags recognized by GetAssetPathFromCommandLine.
        const string k_CommandLineModelOverrideDirectoryFlag = "--mlagents-override-model-directory";
        const string k_CommandLineModelOverrideExtensionFlag = "--mlagents-override-model-extension";
        const string k_CommandLineQuitAfterEpisodesFlag = "--mlagents-quit-after-episodes";
        const string k_CommandLineQuitAfterSeconds = "--mlagents-quit-after-seconds";
        const string k_CommandLineQuitOnLoadFailure = "--mlagents-quit-on-load-failure";

        // The attached Agent
        Agent m_Agent;

        // Whether or not the commandline args have already been processed.
        // Used to make sure that HasOverrides doesn't spam the logs if it's called multiple times.
        private bool m_HaveProcessedCommandLine;

        // Directory that override model files are loaded from; null or empty means "no override".
        string m_BehaviorNameOverrideDirectory;

        // Behavior name read from the Agent's BehaviorParameters the first time OriginalBehaviorName is used.
        private string m_OriginalBehaviorName;

        // Extensions requested on the commandline, in the order they were given.
        private List<string> m_OverrideExtensions = new List<string>();

        // Cached loaded NNModels, with the behavior name as the key.
        // A null value records that loading was already attempted and failed.
        Dictionary<string, NNModel> m_CachedModels = new Dictionary<string, NNModel>();


        // Max episodes to run. Only used if > 0
        // Will default to 1 if override models are specified, otherwise 0.
        int m_MaxEpisodes;

        // Deadline - exit if the time exceeds this
        DateTime m_Deadline = DateTime.MaxValue;

        // Number of FixedUpdate calls counted by this instance.
        int m_NumSteps;
        // Totals copied from the static counters in OnEnable.
        int m_PreviousNumSteps;
        int m_PreviousAgentCompletedEpisodes;

        // Set by the quit-on-load-failure flag: a failed override then quits with exit code 1.
        bool m_QuitOnLoadFailure;
        [Tooltip("Debug values to be used in place of the command line for overriding models.")]
        public string debugCommandLineOverride;

        // Static values to keep track of completed episodes and steps across resets
        // These are updated in OnDisable.
        static int s_PreviousAgentCompletedEpisodes;
        static int s_PreviousNumSteps;

        // Episodes carried over in m_PreviousAgentCompletedEpisodes plus those completed by the
        // attached Agent; an absent Agent contributes nothing.
        int TotalCompletedEpisodes
        {
            get
            {
                if (m_Agent == null)
                {
                    return m_PreviousAgentCompletedEpisodes;
                }

                return m_PreviousAgentCompletedEpisodes + m_Agent.CompletedEpisodes;
            }
        }

        // Steps carried over in m_PreviousNumSteps plus the steps counted by this instance.
        int TotalNumSteps => m_PreviousNumSteps + m_NumSteps;

        /// <summary>
        /// True when an override model directory has been specified.
        /// Calls GetAssetPathFromCommandLine first so the commandline has been read before checking.
        /// </summary>
        public bool HasOverrides
        {
            get
            {
                GetAssetPathFromCommandLine();

                var directoryMissing = string.IsNullOrEmpty(m_BehaviorNameOverrideDirectory);
                return !directoryMissing;
            }
        }

        /// <summary>
        /// The behavior name the agent had before any override was applied. It is read from the
        /// agent's BehaviorParameters on first use and remembered, because the actual behavior name
        /// will change when it is overridden.
        /// </summary>
        public string OriginalBehaviorName
        {
            get
            {
                // Already looked up: hand back the remembered value.
                if (!string.IsNullOrEmpty(m_OriginalBehaviorName))
                {
                    return m_OriginalBehaviorName;
                }

                m_OriginalBehaviorName = m_Agent.GetComponent<BehaviorParameters>().BehaviorName;
                return m_OriginalBehaviorName;
            }
        }

        /// <summary>
        /// Build the behavior name used for an overridden model.
        /// </summary>
        /// <param name="originalBehaviorName">The behavior name before overriding.</param>
        /// <returns>The given name prefixed with "Override_".</returns>
        public static string GetOverrideBehaviorName(string originalBehaviorName)
        {
            const string prefix = "Override_";
            return prefix + originalBehaviorName;
        }

        /// <summary>
        /// Read the override settings (model directory, extensions, episode limit, timeout and
        /// quit-on-failure switch) from the commandline arguments and store them on this component.
        /// In the Editor, a non-empty debugCommandLineOverride is split on spaces and used instead
        /// of the real commandline.
        /// Can be called multiple times - if m_HaveProcessedCommandLine is set, will have no effect.
        /// </summary>
        void GetAssetPathFromCommandLine()
        {
            if (m_HaveProcessedCommandLine)
            {
                return;
            }
            var maxEpisodes = 0;
            var timeoutSeconds = 0;

            string[] commandLineArgsOverride = null;
            if (!string.IsNullOrEmpty(debugCommandLineOverride) && Application.isEditor)
            {
                commandLineArgsOverride = debugCommandLineOverride.Split(' ');
            }

            var args = commandLineArgsOverride ?? Environment.GetCommandLineArgs();
            for (var i = 0; i < args.Length; i++)
            {
                // Every flag except the quit-on-load-failure switch reads its value from the next argument.
                var hasValue = i < args.Length - 1;

                if (args[i] == k_CommandLineModelOverrideDirectoryFlag && hasValue)
                {
                    m_BehaviorNameOverrideDirectory = args[i + 1].Trim();
                }
                else if (args[i] == k_CommandLineModelOverrideExtensionFlag && hasValue)
                {
                    // Invariant lowering keeps the comparison independent of the machine's culture.
                    var overrideExtension = args[i + 1].Trim().ToLowerInvariant();
                    if (k_SupportedExtensions.Contains(overrideExtension))
                    {
                        m_OverrideExtensions.Add(overrideExtension);
                    }
                    else
                    {
                        // Execution continues past the quit request, so the unsupported extension
                        // is deliberately not registered for loading.
                        Debug.LogError($"loading unsupported format: {overrideExtension}");
                        Application.Quit(1);
#if UNITY_EDITOR
                        EditorApplication.isPlaying = false;
#endif
                    }
                }
                else if (args[i] == k_CommandLineQuitAfterEpisodesFlag && hasValue)
                {
                    // An unparsable value leaves maxEpisodes at 0, i.e. "use the default".
                    Int32.TryParse(args[i + 1], out maxEpisodes);
                }
                else if (args[i] == k_CommandLineQuitAfterSeconds && hasValue)
                {
                    // An unparsable value leaves timeoutSeconds at 0, i.e. "no deadline".
                    Int32.TryParse(args[i + 1], out timeoutSeconds);
                }
                else if (args[i] == k_CommandLineQuitOnLoadFailure)
                {
                    m_QuitOnLoadFailure = true;
                }
            }

            if (!string.IsNullOrEmpty(m_BehaviorNameOverrideDirectory))
            {
                // If overriding models, set maxEpisodes to 1 or the command line value
                m_MaxEpisodes = maxEpisodes > 0 ? maxEpisodes : 1;
                // Log the value actually in effect, not the raw (possibly 0) parsed value.
                Debug.Log($"setting m_MaxEpisodes to {m_MaxEpisodes}");
            }

            if (timeoutSeconds > 0)
            {
                m_Deadline = DateTime.Now + TimeSpan.FromSeconds(timeoutSeconds);
                Debug.Log($"setting deadline to {timeoutSeconds} from now.");
            }

            m_HaveProcessedCommandLine = true;
        }

        // Picks up the attached Agent, restores the counters saved in the statics, reads the
        // commandline and applies the model override when one was requested.
        void OnEnable()
        {
            m_Agent = GetComponent<Agent>();

            // Start from the previous totals in the case where we're resetting scenes.
            m_PreviousAgentCompletedEpisodes = s_PreviousAgentCompletedEpisodes;
            m_PreviousNumSteps = s_PreviousNumSteps;

            GetAssetPathFromCommandLine();
            if (!HasOverrides)
            {
                return;
            }

            OverrideModel();
        }

        // Saves this instance's totals into the static counters, keeping whichever value is larger.
        void OnDisable()
        {
            // With a single agent in the scene this is a plain increment of the static counts.
            // With several agents, the statics end up tracking the Agent that completed the most episodes.
            var completedEpisodes = TotalCompletedEpisodes;
            if (completedEpisodes > s_PreviousAgentCompletedEpisodes)
            {
                s_PreviousAgentCompletedEpisodes = completedEpisodes;
            }

            var numSteps = TotalNumSteps;
            if (numSteps > s_PreviousNumSteps)
            {
                s_PreviousNumSteps = numSteps;
            }
        }

        // Counts steps and, when an episode limit is active (m_MaxEpisodes > 0), quits the
        // application once the episode/step target is reached or the deadline has passed.
        // Note that the deadline is only checked while an episode limit is active.
        void FixedUpdate()
        {
            if (m_MaxEpisodes > 0)
            {
                // TotalCompletedEpisodes tolerates a missing Agent, so treat it the same way here
                // instead of throwing a NullReferenceException on every FixedUpdate.
                var maxStep = m_Agent == null ? 0 : m_Agent.MaxStep;
                var targetSteps = m_MaxEpisodes * maxStep;

                // For Agents without maxSteps, exit as soon as we've hit the target number of episodes.
                // For Agents that specify MaxStep, also make sure we've gone at least that many steps.
                // Since we exit as soon as *any* Agent hits its target, the maxSteps condition keeps us running
                // a bit longer in case there's an early failure.
                if (TotalCompletedEpisodes >= m_MaxEpisodes && TotalNumSteps > targetSteps)
                {
                    Debug.Log($"ModelOverride reached {TotalCompletedEpisodes} episodes and {TotalNumSteps} steps. Exiting.");
                    Application.Quit(0);
#if UNITY_EDITOR
                    EditorApplication.isPlaying = false;
#endif
                }
                else if (DateTime.Now >= m_Deadline)
                {
                    Debug.Log(
                        $"Deadline exceeded. " +
                        $"{TotalCompletedEpisodes}/{m_MaxEpisodes} episodes and " +
                        $"{TotalNumSteps}/{targetSteps} steps completed. Exiting.");
                    Application.Quit(0);
#if UNITY_EDITOR
                    EditorApplication.isPlaying = false;
#endif
                }
            }

            m_NumSteps++;
        }

        /// <summary>
        /// Get the override model for a behavior name, loading it from the override directory on
        /// first use. The file looked for is "{behaviorName}.{extension}", trying each candidate
        /// extension in order and stopping at the first file that can be read. Results, including
        /// a failed file lookup, are cached per behavior name.
        /// </summary>
        /// <param name="behaviorName">The behavior name, used as the file name without extension.</param>
        /// <returns>
        /// The loaded model, or null if no override directory is set or no model file could be read.
        /// </returns>
        public NNModel GetModelForBehaviorName(string behaviorName)
        {
            // Single lookup; a cached null means an earlier attempt already failed.
            NNModel cachedModel;
            if (m_CachedModels.TryGetValue(behaviorName, out cachedModel))
            {
                return cachedModel;
            }

            if (string.IsNullOrEmpty(m_BehaviorNameOverrideDirectory))
            {
                Debug.Log($"No override directory set.");
                return null;
            }

            // Try the override extensions in order. If they weren't set, try .nn first, then .onnx.
            IList<string> overrideExtensions = m_OverrideExtensions;
            if (overrideExtensions.Count == 0)
            {
                overrideExtensions = new[] { "nn", "onnx" };
            }

            byte[] rawModel = null;
            bool isOnnx = false;
            string assetName = null;
            foreach (var overrideExtension in overrideExtensions)
            {
                var assetPath = Path.Combine(m_BehaviorNameOverrideDirectory, $"{behaviorName}.{overrideExtension}");
                try
                {
                    rawModel = File.ReadAllBytes(assetPath);
                    isOnnx = overrideExtension == "onnx";
                    assetName = "Override - " + Path.GetFileName(assetPath);
                    break;
                }
                catch (IOException)
                {
                    // Do nothing - try the next extension, or we'll exit if nothing loaded.
                }
            }

            if (rawModel == null)
            {
                Debug.Log($"Couldn't load model file(s) for {behaviorName} in {m_BehaviorNameOverrideDirectory} (full path: {Path.GetFullPath(m_BehaviorNameOverrideDirectory)})");
                // Cache the null so we don't repeatedly try to load a missing file
                m_CachedModels[behaviorName] = null;
                return null;
            }

            // ONNX files are converted first; anything else is wrapped as-is.
            var asset = isOnnx ? LoadOnnxModel(rawModel) : LoadBarracudaModel(rawModel);
            asset.name = assetName;
            m_CachedModels[behaviorName] = asset;
            return asset;
        }

        // Wrap the raw bytes, unconverted, in a newly created NNModel.
        NNModel LoadBarracudaModel(byte[] rawModel)
        {
            var model = ScriptableObject.CreateInstance<NNModel>();

            var modelData = ScriptableObject.CreateInstance<NNModelData>();
            model.modelData = modelData;
            model.modelData.Value = rawModel;

            return model;
        }

        // Convert raw ONNX bytes with ONNXModelConverter, serialize the converted model with
        // ModelWriter, and wrap the serialized bytes in a newly created NNModel.
        NNModel LoadOnnxModel(byte[] rawModel)
        {
            var convertedModel = new ONNXModelConverter(true).Convert(rawModel);

            var modelData = ScriptableObject.CreateInstance<NNModelData>();
            using (var buffer = new MemoryStream())
            {
                using (var binaryWriter = new BinaryWriter(buffer))
                {
                    ModelWriter.Save(binaryWriter, convertedModel);
                    modelData.Value = buffer.ToArray();
                }
            }
            modelData.hideFlags = HideFlags.HideInHierarchy;
            modelData.name = "Data";

            var model = ScriptableObject.CreateInstance<NNModel>();
            model.modelData = modelData;
            return model;
        }


        /// <summary>
        /// Load the NNModel file from the specified path, and give it to the attached agent.
        /// If the model cannot be found, loaded or applied, the reason is logged as a warning;
        /// when the quit-on-load-failure flag was given, the application also quits with exit code 1.
        /// </summary>
        void OverrideModel()
        {
            bool overrideOk = false;
            string overrideError = null;

            m_Agent.LazyInitialize();

            NNModel nnModel = null;
            try
            {
                nnModel = GetModelForBehaviorName(OriginalBehaviorName);
            }
            catch (Exception e)
            {
                overrideError = $"Exception calling GetModelForBehaviorName: {e}";
            }

            if (nnModel == null)
            {
                // Keep the exception text if there is one; otherwise explain the missing model.
                if (string.IsNullOrEmpty(overrideError))
                {
                    overrideError =
                        $"Didn't find a model for behaviorName {OriginalBehaviorName}. Make " +
                        "sure the behaviorName is set correctly in the commandline " +
                        "and that the model file exists";
                }
            }
            else
            {
                Debug.Log($"Overriding behavior {OriginalBehaviorName} for agent with model {nnModel.name}");
                try
                {
                    m_Agent.SetModel(GetOverrideBehaviorName(OriginalBehaviorName), nnModel);
                    overrideOk = true;
                }
                catch (Exception e)
                {
                    overrideError = $"Exception calling Agent.SetModel: {e}";
                }
            }

            if (overrideOk)
            {
                return;
            }

            // Always report the failure, so a caught exception is not lost when the quit flag is absent.
            if (!string.IsNullOrEmpty(overrideError))
            {
                Debug.LogWarning(overrideError);
            }

            if (m_QuitOnLoadFailure)
            {
                Application.Quit(1);
#if UNITY_EDITOR
                EditorApplication.isPlaying = false;
#endif
            }
        }
    }
}