|
using System; |
|
|
|
namespace Unity.MLAgents.Sensors.Reflection |
|
{ |
|
/// <summary>
/// Reflection sensor for enum-typed members. It emits one float per value
/// reported by <see cref="Enum.GetValues(Type)"/> for the member's type.
/// </summary>
internal class EnumReflectionSensor : ReflectionSensorBase
{
    // Every value declared by the observed enum type, in Enum.GetValues order.
    Array m_Values;

    // True when the enum type itself is decorated with [Flags].
    bool m_IsFlags;

    /// <summary>
    /// Builds the sensor, sizing the observation to the number of enum values
    /// and caching those values plus whether the type is a flags enum.
    /// </summary>
    /// <param name="reflectionSensorInfo">Describes the reflected member; its member type is used as the enum type.</param>
    internal EnumReflectionSensor(ReflectionSensorInfo reflectionSensorInfo)
        : base(reflectionSensorInfo, GetEnumObservationSize(reflectionSensorInfo.GetMemberType()))
    {
        Type enumType = reflectionSensorInfo.GetMemberType();
        m_Values = Enum.GetValues(enumType);

        // inherit: false -- only an attribute declared directly on the type counts.
        m_IsFlags = enumType.IsDefined(typeof(FlagsAttribute), false);
    }

    /// <summary>
    /// Writes 1.0 or 0.0 into <paramref name="writer"/> for each cached enum value.
    /// For flags enums a slot is 1.0 when the current value has that flag set;
    /// otherwise a slot is 1.0 only when it equals the current value.
    /// </summary>
    /// <param name="writer">Destination indexed by the position of each enum value.</param>
    internal override void WriteReflectedField(ObservationWriter writer)
    {
        var current = (Enum)GetReflectedValue();

        for (var slot = 0; slot < m_Values.Length; slot++)
        {
            object candidate = m_Values.GetValue(slot);

            // Note: Enum.HasFlag returns true for a flag whose underlying value is zero,
            // so a zero-valued member of a flags enum is always reported as set.
            bool isActive = m_IsFlags
                ? current.HasFlag((Enum)candidate)
                : candidate.Equals(current);

            writer[slot] = isActive ? 1.0f : 0.0f;
        }
    }

    /// <summary>
    /// Returns how many floats are needed to observe an enum of type <paramref name="t"/>.
    /// </summary>
    /// <param name="t">The enum type to inspect.</param>
    /// <returns>The length of the array returned by <see cref="Enum.GetValues(Type)"/>.</returns>
    internal static int GetEnumObservationSize(Type t)
    {
        return Enum.GetValues(t).Length;
    }
}
|
} |
|
|