File size: 11,637 Bytes
05c9ac2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
using System;
using System.Collections.Generic;
using System.Reflection;
using UnityEngine;

namespace Unity.MLAgents.Sensors.Reflection
{
    /// <summary>
    /// Specify that a field or property should be used to generate observations for an Agent.
    /// For each field or property that uses ObservableAttribute, a corresponding
    /// <see cref="ISensor"/> will be created during Agent initialization, and this
    /// sensor will read the values during training and inference.
    /// </summary>
    /// <remarks>
    /// ObservableAttribute is intended to make initial setup of an Agent easier. Because it
    /// uses reflection to read the values of fields and properties at runtime, this may
    /// be much slower than reading the values directly. If the performance of
    /// ObservableAttribute is an issue, you can get the same functionality by overriding
    /// <see cref="Agent.CollectObservations(VectorSensor)"/> or creating a custom
    /// <see cref="ISensor"/> implementation to read the values without reflection.
    ///
    /// Supported member types are int, bool, float, Vector2, Vector3, Vector4, Quaternion,
    /// and enums. Members of any other type, as well as write-only properties, do not
    /// produce a sensor.
    ///
    /// Note that you do not need to adjust the VectorObservationSize in
    /// <see cref="Unity.MLAgents.Policies.BrainParameters"/> when adding ObservableAttribute
    /// to fields or properties.
    /// </remarks>
    /// <example>
    /// This sample class will produce two observations, one for the m_Health field, and one
    /// for the HealthPercent property.
    /// <code>
    /// using Unity.MLAgents;
    /// using Unity.MLAgents.Sensors.Reflection;
    ///
    /// public class MyAgent : Agent
    /// {
    ///     [Observable]
    ///     int m_Health;
    ///
    ///     int m_MaxHealth = 100;
    ///
    ///     [Observable]
    ///     float HealthPercent
    ///     {
    ///         get => 100.0f * m_Health / m_MaxHealth;
    ///     }
    /// }
    /// </code>
    /// </example>
    [AttributeUsage(AttributeTargets.Field | AttributeTargets.Property)]
    public class ObservableAttribute : Attribute
    {
        readonly string m_Name;
        readonly int m_NumStackedObservations;

        /// <summary>
        /// Default binding flags used for reflection of members and properties.
        /// </summary>
        const BindingFlags k_BindingFlags = BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic;

        /// <summary>
        /// Supported types and their observation sizes and corresponding sensor type.
        /// Enums are handled separately and are not listed here.
        /// </summary>
        static readonly Dictionary<Type, (int ObservationSize, Type SensorType)> s_TypeToSensorInfo =
            new Dictionary<Type, (int ObservationSize, Type SensorType)>()
        {
            {typeof(int), (1, typeof(IntReflectionSensor))},
            {typeof(bool), (1, typeof(BoolReflectionSensor))},
            {typeof(float), (1, typeof(FloatReflectionSensor))},

            {typeof(Vector2), (2, typeof(Vector2ReflectionSensor))},
            {typeof(Vector3), (3, typeof(Vector3ReflectionSensor))},
            {typeof(Vector4), (4, typeof(Vector4ReflectionSensor))},
            {typeof(Quaternion), (4, typeof(QuaternionReflectionSensor))},
        };

        /// <summary>
        /// ObservableAttribute constructor.
        /// </summary>
        /// <param name="name">Optional override for the sensor name. Note that all sensors for an Agent
        /// must have a unique name.</param>
        /// <param name="numStackedObservations">Number of frames to concatenate observations from.
        /// Values less than or equal to 1 disable stacking.</param>
        public ObservableAttribute(string name = null, int numStackedObservations = 1)
        {
            m_Name = name;
            m_NumStackedObservations = numStackedObservations;
        }

        /// <summary>
        /// Number of frames actually stacked for this member. Values below 1 behave as 1, because
        /// <see cref="CreateReflectionSensor"/> only wraps the sensor in a StackingSensor when
        /// m_NumStackedObservations is greater than 1. Both sensor creation and size computation
        /// use this value so that they always agree.
        /// </summary>
        int StackSize => Math.Max(1, m_NumStackedObservations);

        /// <summary>
        /// Binding flags for reflecting members of an object, optionally restricted to the
        /// members declared directly on the object's runtime type.
        /// </summary>
        static BindingFlags GetBindingFlags(bool excludeInherited)
        {
            return k_BindingFlags | (excludeInherited ? BindingFlags.DeclaredOnly : 0);
        }

        /// <summary>
        /// Returns a FieldInfo for all fields that have an ObservableAttribute.
        /// </summary>
        /// <param name="o">Object being reflected</param>
        /// <param name="excludeInherited">Whether to exclude inherited fields or not.</param>
        /// <returns>Each matching field paired with its ObservableAttribute.</returns>
        static IEnumerable<(FieldInfo, ObservableAttribute)> GetObservableFields(object o, bool excludeInherited)
        {
            // TODO cache these (and properties) by type, so that we only have to reflect once.
            var fields = o.GetType().GetFields(GetBindingFlags(excludeInherited));
            foreach (var field in fields)
            {
                var attr = (ObservableAttribute)GetCustomAttribute(field, typeof(ObservableAttribute));
                if (attr != null)
                {
                    yield return (field, attr);
                }
            }
        }

        /// <summary>
        /// Returns a PropertyInfo for all properties that have an ObservableAttribute.
        /// </summary>
        /// <param name="o">Object being reflected</param>
        /// <param name="excludeInherited">Whether to exclude inherited properties or not.</param>
        /// <returns>Each matching property paired with its ObservableAttribute.</returns>
        static IEnumerable<(PropertyInfo, ObservableAttribute)> GetObservableProperties(object o, bool excludeInherited)
        {
            var properties = o.GetType().GetProperties(GetBindingFlags(excludeInherited));
            foreach (var prop in properties)
            {
                var attr = (ObservableAttribute)GetCustomAttribute(prop, typeof(ObservableAttribute));
                if (attr != null)
                {
                    yield return (prop, attr);
                }
            }
        }

        /// <summary>
        /// Creates sensors for each field and property with ObservableAttribute.
        /// Members of unsupported types and write-only properties are skipped.
        /// </summary>
        /// <param name="o">Object being reflected</param>
        /// <param name="excludeInherited">Whether to exclude inherited members or not.</param>
        /// <returns>The created sensors, fields first, then properties.</returns>
        internal static List<ISensor> CreateObservableSensors(object o, bool excludeInherited)
        {
            var sensorsOut = new List<ISensor>();
            foreach (var (field, attr) in GetObservableFields(o, excludeInherited))
            {
                var sensor = CreateReflectionSensor(o, field, null, attr);
                if (sensor != null)
                {
                    sensorsOut.Add(sensor);
                }
            }

            foreach (var (prop, attr) in GetObservableProperties(o, excludeInherited))
            {
                if (!prop.CanRead)
                {
                    // Skip unreadable properties.
                    continue;
                }
                var sensor = CreateReflectionSensor(o, null, prop, attr);
                if (sensor != null)
                {
                    sensorsOut.Add(sensor);
                }
            }

            return sensorsOut;
        }

        /// <summary>
        /// Create the ISensor for either the field or property on the provided object.
        /// If the data type is unsupported, or the property is write-only, returns null.
        /// </summary>
        /// <param name="o">Object that owns the member.</param>
        /// <param name="fieldInfo">The field to observe, or null if observing a property.</param>
        /// <param name="propertyInfo">The property to observe; used only when fieldInfo is null.</param>
        /// <param name="observableAttribute">The attribute attached to the member.</param>
        /// <returns>The sensor, wrapped in a StackingSensor when stacking is requested; or null.</returns>
        /// <exception cref="UnityAgentsException"></exception>
        static ISensor CreateReflectionSensor(object o, FieldInfo fieldInfo, PropertyInfo propertyInfo, ObservableAttribute observableAttribute)
        {
            string memberName;
            string declaringTypeName;
            Type memberType;
            if (fieldInfo != null)
            {
                declaringTypeName = fieldInfo.DeclaringType.Name;
                memberName = fieldInfo.Name;
                memberType = fieldInfo.FieldType;
            }
            else
            {
                if (!propertyInfo.CanRead)
                {
                    // Write-only properties have no value to observe.
                    return null;
                }
                declaringTypeName = propertyInfo.DeclaringType.Name;
                memberName = propertyInfo.Name;
                memberType = propertyInfo.PropertyType;
            }

            // Enums are handled by EnumReflectionSensor; everything else must be in the lookup table.
            Type sensorType = null;
            if (!memberType.IsEnum)
            {
                if (!s_TypeToSensorInfo.TryGetValue(memberType, out var typeInfo))
                {
                    // For unsupported types, return null and we'll filter them out later.
                    return null;
                }
                sensorType = typeInfo.SensorType;
            }

            var sensorName = string.IsNullOrEmpty(observableAttribute.m_Name)
                ? $"ObservableAttribute:{declaringTypeName}.{memberName}"
                : observableAttribute.m_Name;

            var reflectionSensorInfo = new ReflectionSensorInfo
            {
                Object = o,
                FieldInfo = fieldInfo,
                PropertyInfo = propertyInfo,
                ObservableAttribute = observableAttribute,
                SensorName = sensorName
            };

            ISensor sensor;
            if (sensorType == null)
            {
                sensor = new EnumReflectionSensor(reflectionSensorInfo);
            }
            else
            {
                sensor = (ISensor)Activator.CreateInstance(sensorType, reflectionSensorInfo);
            }

            // Wrap the base sensor in a StackingSensor if we're using stacking.
            var stackSize = observableAttribute.StackSize;
            if (stackSize > 1)
            {
                return new StackingSensor(sensor, stackSize);
            }

            return sensor;
        }

        /// <summary>
        /// Computes the total observation size (including stacking) for a member of the given type.
        /// This mirrors the sensors built by <see cref="CreateReflectionSensor"/>: stacking is applied
        /// to every supported type, enums included.
        /// </summary>
        /// <param name="memberType">Type of the field or property.</param>
        /// <param name="attr">The member's ObservableAttribute.</param>
        /// <param name="size">The stacked observation size, or 0 if unsupported.</param>
        /// <returns>False if the type is not supported.</returns>
        static bool TryGetObservationSize(Type memberType, ObservableAttribute attr, out int size)
        {
            int baseSize;
            if (s_TypeToSensorInfo.TryGetValue(memberType, out var typeInfo))
            {
                baseSize = typeInfo.ObservationSize;
            }
            else if (memberType.IsEnum)
            {
                baseSize = EnumReflectionSensor.GetEnumObservationSize(memberType);
            }
            else
            {
                size = 0;
                return false;
            }

            size = baseSize * attr.StackSize;
            return true;
        }

        /// <summary>
        /// Gets the sum of the observation sizes of the Observable fields and properties on an object.
        /// Also appends errors to the errorsOut list.
        /// </summary>
        /// <param name="o">Object being reflected</param>
        /// <param name="excludeInherited">Whether to exclude inherited members or not.</param>
        /// <param name="errorsOut">Receives a message for each unsupported or write-only member.</param>
        /// <returns>The total observation size, including stacking.</returns>
        internal static int GetTotalObservationSize(object o, bool excludeInherited, List<string> errorsOut)
        {
            int sizeOut = 0;
            foreach (var (field, attr) in GetObservableFields(o, excludeInherited))
            {
                if (TryGetObservationSize(field.FieldType, attr, out var fieldSize))
                {
                    sizeOut += fieldSize;
                }
                else
                {
                    errorsOut.Add($"Unsupported Observable type {field.FieldType.Name} on field {field.Name}");
                }
            }

            foreach (var (prop, attr) in GetObservableProperties(o, excludeInherited))
            {
                if (!prop.CanRead)
                {
                    errorsOut.Add($"Observable property {prop.Name} is write-only.");
                }
                else if (TryGetObservationSize(prop.PropertyType, attr, out var propSize))
                {
                    sizeOut += propSize;
                }
                else
                {
                    errorsOut.Add($"Unsupported Observable type {prop.PropertyType.Name} on property {prop.Name}");
                }
            }

            return sizeOut;
        }
    }
}