|
using System; |
|
using System.Collections.Generic; |
|
using System.Reflection; |
|
using UnityEngine; |
|
|
|
namespace Unity.MLAgents.Sensors.Reflection |
|
{ |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/// <summary>
/// Marks a field or property of an object so that its value is read, via reflection,
/// as an observation. For each marked member of a supported type, an <see cref="ISensor"/>
/// is created that reads the member's value from the target object.
/// </summary>
/// <remarks>
/// Supported member types are <c>int</c>, <c>bool</c>, <c>float</c>, <see cref="Vector2"/>,
/// <see cref="Vector3"/>, <see cref="Vector4"/>, <see cref="Quaternion"/>, and enum types.
/// Members of any other type, and write-only properties, do not produce a sensor and are
/// reported as errors when the total observation size is computed.
/// </remarks>
[AttributeUsage(AttributeTargets.Field | AttributeTargets.Property)]
public class ObservableAttribute : Attribute
{
    // Optional sensor name; when null or empty, a name is derived from the member.
    readonly string m_Name;

    // Requested number of stacked observations, as passed to the constructor.
    // Values of 1 or less mean "no stacking" (see EffectiveNumStackedObservations).
    readonly int m_NumStackedObservations;

    // Instance members, both public and non-public. Note: Type.GetFields/GetProperties with
    // NonPublic do not return private members declared on base classes.
    const BindingFlags k_BindingFlags = BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic;

    // Supported non-enum member types, mapped to:
    //   ObservationSize - number of floats in a single, unstacked observation of that type;
    //   SensorType      - sensor class constructed (via Activator) from a ReflectionSensorInfo.
    // Enum types are not listed here; they are handled by EnumReflectionSensor.
    static readonly Dictionary<Type, (int ObservationSize, Type SensorType)> s_TypeToSensorInfo =
        new Dictionary<Type, (int ObservationSize, Type SensorType)>()
    {
        {typeof(int), (1, typeof(IntReflectionSensor))},
        {typeof(bool), (1, typeof(BoolReflectionSensor))},
        {typeof(float), (1, typeof(FloatReflectionSensor))},

        {typeof(Vector2), (2, typeof(Vector2ReflectionSensor))},
        {typeof(Vector3), (3, typeof(Vector3ReflectionSensor))},
        {typeof(Vector4), (4, typeof(Vector4ReflectionSensor))},
        {typeof(Quaternion), (4, typeof(QuaternionReflectionSensor))},
    };

    /// <summary>
    /// Creates an ObservableAttribute.
    /// </summary>
    /// <param name="name">Name for the generated sensor. If null or empty, the name
    /// "ObservableAttribute:{DeclaringType}.{MemberName}" is used.</param>
    /// <param name="numStackedObservations">Number of observations to stack. Values greater
    /// than 1 wrap the generated sensor in a <see cref="StackingSensor"/>; values of 1 or
    /// less disable stacking.</param>
    public ObservableAttribute(string name = null, int numStackedObservations = 1)
    {
        m_Name = name;
        m_NumStackedObservations = numStackedObservations;
    }

    // The stacking count actually applied to the sensor. CreateReflectionSensor only wraps
    // in a StackingSensor when m_NumStackedObservations > 1, so smaller values count as 1.
    // Both sensor creation and size computation use this so they stay consistent.
    int EffectiveNumStackedObservations => m_NumStackedObservations > 1 ? m_NumStackedObservations : 1;

    // Binding flags used to discover members, optionally restricted to members declared
    // directly on the runtime type.
    static BindingFlags GetBindingFlags(bool excludeInherited)
    {
        return excludeInherited ? k_BindingFlags | BindingFlags.DeclaredOnly : k_BindingFlags;
    }

    /// <summary>
    /// Lazily enumerates the instance fields of <paramref name="o"/>'s runtime type that carry
    /// an ObservableAttribute, together with that attribute.
    /// </summary>
    /// <param name="o">Object whose runtime type is inspected.</param>
    /// <param name="excludeInherited">If true, only fields declared on the runtime type itself
    /// are considered.</param>
    static IEnumerable<(FieldInfo, ObservableAttribute)> GetObservableFields(object o, bool excludeInherited)
    {
        var fields = o.GetType().GetFields(GetBindingFlags(excludeInherited));
        foreach (var field in fields)
        {
            var attr = (ObservableAttribute)GetCustomAttribute(field, typeof(ObservableAttribute));
            if (attr != null)
            {
                yield return (field, attr);
            }
        }
    }

    /// <summary>
    /// Lazily enumerates the instance properties of <paramref name="o"/>'s runtime type that
    /// carry an ObservableAttribute, together with that attribute.
    /// </summary>
    /// <param name="o">Object whose runtime type is inspected.</param>
    /// <param name="excludeInherited">If true, only properties declared on the runtime type
    /// itself are considered.</param>
    static IEnumerable<(PropertyInfo, ObservableAttribute)> GetObservableProperties(object o, bool excludeInherited)
    {
        var properties = o.GetType().GetProperties(GetBindingFlags(excludeInherited));
        foreach (var prop in properties)
        {
            var attr = (ObservableAttribute)GetCustomAttribute(prop, typeof(ObservableAttribute));
            if (attr != null)
            {
                yield return (prop, attr);
            }
        }
    }

    /// <summary>
    /// Creates one sensor for every observable field and readable observable property of
    /// <paramref name="o"/> whose type is supported.
    /// </summary>
    /// <param name="o">Object whose members are observed.</param>
    /// <param name="excludeInherited">If true, only members declared on the runtime type
    /// itself are considered.</param>
    /// <returns>The created sensors: field sensors first, then property sensors.
    /// Unsupported members are skipped.</returns>
    internal static List<ISensor> CreateObservableSensors(object o, bool excludeInherited)
    {
        var sensorsOut = new List<ISensor>();
        foreach (var (field, attr) in GetObservableFields(o, excludeInherited))
        {
            var sensor = CreateReflectionSensor(o, field, null, attr);
            if (sensor != null)
            {
                sensorsOut.Add(sensor);
            }
        }

        foreach (var (prop, attr) in GetObservableProperties(o, excludeInherited))
        {
            if (!prop.CanRead)
            {
                // Write-only properties cannot be observed; GetTotalObservationSize reports
                // them as errors, so they are silently skipped here.
                continue;
            }

            var sensor = CreateReflectionSensor(o, null, prop, attr);
            if (sensor != null)
            {
                sensorsOut.Add(sensor);
            }
        }

        return sensorsOut;
    }

    /// <summary>
    /// Creates a reflection sensor for a single member. Exactly one of
    /// <paramref name="fieldInfo"/> and <paramref name="propertyInfo"/> is expected to be
    /// non-null; <paramref name="fieldInfo"/> takes precedence.
    /// </summary>
    /// <param name="o">Object whose member is observed.</param>
    /// <param name="fieldInfo">The observed field, or null when observing a property.</param>
    /// <param name="propertyInfo">The observed property, or null when observing a field.</param>
    /// <param name="observableAttribute">The attribute attached to the member.</param>
    /// <returns>The sensor, wrapped in a <see cref="StackingSensor"/> when stacking is
    /// requested, or null if the member's type is not supported.</returns>
    static ISensor CreateReflectionSensor(object o, FieldInfo fieldInfo, PropertyInfo propertyInfo, ObservableAttribute observableAttribute)
    {
        MemberInfo member;
        Type memberType;
        if (fieldInfo != null)
        {
            member = fieldInfo;
            memberType = fieldInfo.FieldType;
        }
        else
        {
            member = propertyInfo;
            memberType = propertyInfo.PropertyType;
        }

        // Resolve the sensor class up front; enums use EnumReflectionSensor (sensorType stays null).
        Type sensorType = null;
        if (!memberType.IsEnum)
        {
            if (!s_TypeToSensorInfo.TryGetValue(memberType, out var sensorInfo))
            {
                // Unsupported type; GetTotalObservationSize reports this case as an error.
                return null;
            }
            sensorType = sensorInfo.SensorType;
        }

        var sensorName = string.IsNullOrEmpty(observableAttribute.m_Name)
            ? $"ObservableAttribute:{member.DeclaringType.Name}.{member.Name}"
            : observableAttribute.m_Name;

        var reflectionSensorInfo = new ReflectionSensorInfo
        {
            Object = o,
            FieldInfo = fieldInfo,
            PropertyInfo = propertyInfo,
            ObservableAttribute = observableAttribute,
            SensorName = sensorName
        };

        ISensor sensor;
        if (sensorType == null)
        {
            sensor = new EnumReflectionSensor(reflectionSensorInfo);
        }
        else
        {
            sensor = (ISensor)Activator.CreateInstance(sensorType, reflectionSensorInfo);
        }

        var numStacked = observableAttribute.EffectiveNumStackedObservations;
        if (numStacked > 1)
        {
            return new StackingSensor(sensor, numStacked);
        }

        return sensor;
    }

    // Size of a single, unstacked observation for a supported member type.
    // Returns false (with observationSize = 0) if the type is neither in the table nor an enum.
    static bool TryGetObservationSize(Type memberType, out int observationSize)
    {
        if (s_TypeToSensorInfo.TryGetValue(memberType, out var sensorInfo))
        {
            observationSize = sensorInfo.ObservationSize;
            return true;
        }

        if (memberType.IsEnum)
        {
            observationSize = EnumReflectionSensor.GetEnumObservationSize(memberType);
            return true;
        }

        observationSize = 0;
        return false;
    }

    /// <summary>
    /// Computes the total observation size produced by the observable members of
    /// <paramref name="o"/>, including stacking, for every member type (enums included).
    /// </summary>
    /// <param name="o">Object whose members are observed.</param>
    /// <param name="excludeInherited">If true, only members declared on the runtime type
    /// itself are considered.</param>
    /// <param name="errorsOut">Receives a message for each unsupported or write-only member.
    /// These members contribute nothing to the total.</param>
    /// <returns>The sum of (single observation size × effective stack count) over all
    /// supported members.</returns>
    internal static int GetTotalObservationSize(object o, bool excludeInherited, List<string> errorsOut)
    {
        int sizeOut = 0;
        foreach (var (field, attr) in GetObservableFields(o, excludeInherited))
        {
            if (TryGetObservationSize(field.FieldType, out var obsSize))
            {
                // Enum sensors are stacked too, so every supported type is multiplied.
                sizeOut += obsSize * attr.EffectiveNumStackedObservations;
            }
            else
            {
                errorsOut.Add($"Unsupported Observable type {field.FieldType.Name} on field {field.Name}");
            }
        }

        foreach (var (prop, attr) in GetObservableProperties(o, excludeInherited))
        {
            if (!prop.CanRead)
            {
                errorsOut.Add($"Observable property {prop.Name} is write-only.");
            }
            else if (TryGetObservationSize(prop.PropertyType, out var obsSize))
            {
                sizeOut += obsSize * attr.EffectiveNumStackedObservations;
            }
            else
            {
                errorsOut.Add($"Unsupported Observable type {prop.PropertyType.Name} on property {prop.Name}");
            }
        }

        return sizeOut;
    }
}
|
} |
|
|