|
using System.Collections.Generic; |
|
using System.Collections.ObjectModel; |
|
using UnityEngine; |
|
|
|
namespace Unity.MLAgents.Sensors
{
    /// <summary>
    /// A sensor that accumulates float observations in a list and writes them out
    /// as a fixed-size vector.
    /// </summary>
    /// <remarks>
    /// Values are appended through the <c>AddObservation</c> overloads and
    /// <see cref="AddOneHotObservation"/>. <see cref="Write"/> truncates or zero-pads the
    /// accumulated values so that exactly the configured number of floats is written.
    /// </remarks>
    public class VectorSensor : ISensor, IBuiltInSensor
    {
        // Backing store for the observations collected since the last Update()/Reset().
        readonly List<float> m_Observations;

        // Left non-readonly on purpose: ObservationSpec is declared outside this file, and if
        // it is a mutable struct a readonly field would force defensive copies on member access.
        ObservationSpec m_ObservationSpec;

        readonly string m_Name;

        /// <summary>
        /// Initializes the sensor.
        /// </summary>
        /// <param name="observationSize">
        /// Number of floats the sensor is expected to hold; also used as the initial capacity
        /// of the internal list.
        /// </param>
        /// <param name="name">
        /// Name of the sensor. When null or empty, the name <c>VectorSensor_size{observationSize}</c>
        /// is generated, with <c>_{observationType}</c> appended if
        /// <paramref name="observationType"/> is not <c>ObservationType.Default</c>.
        /// </param>
        /// <param name="observationType">
        /// Observation type forwarded to <c>ObservationSpec.Vector</c>.
        /// </param>
        public VectorSensor(int observationSize, string name = null, ObservationType observationType = ObservationType.Default)
        {
            if (string.IsNullOrEmpty(name))
            {
                name = $"VectorSensor_size{observationSize}";
                if (observationType != ObservationType.Default)
                {
                    // Interpolation already calls ToString() on the enum value.
                    name += $"_{observationType}";
                }
            }

            m_Observations = new List<float>(observationSize);
            m_Name = name;
            m_ObservationSpec = ObservationSpec.Vector(observationSize, observationType);
        }

        /// <summary>
        /// Writes the accumulated observations to <paramref name="writer"/>.
        /// </summary>
        /// <remarks>
        /// If more values were added than the first dimension of the observation spec, the
        /// surplus is removed from the internal list; if fewer were added, the list is padded
        /// with zeros. A warning is logged in either case.
        /// </remarks>
        /// <param name="writer">Destination that receives the observation list.</param>
        /// <returns>The expected number of observations (first dimension of the spec's shape).</returns>
        public int Write(ObservationWriter writer)
        {
            var expectedObservations = m_ObservationSpec.Shape[0];
            if (m_Observations.Count > expectedObservations)
            {
                // Too many observations: warn, then drop everything past the expected size.
                Debug.LogWarningFormat(
                    "More observations ({0}) made than vector observation size ({1}). The observations will be truncated.",
                    m_Observations.Count, expectedObservations
                );
                m_Observations.RemoveRange(expectedObservations, m_Observations.Count - expectedObservations);
            }
            else if (m_Observations.Count < expectedObservations)
            {
                // Too few observations: warn, then fill the remainder with zeros.
                Debug.LogWarningFormat(
                    "Fewer observations ({0}) made than vector observation size ({1}). The observations will be padded.",
                    m_Observations.Count, expectedObservations
                );
                for (int i = m_Observations.Count; i < expectedObservations; i++)
                {
                    m_Observations.Add(0);
                }
            }

            writer.AddList(m_Observations);
            return expectedObservations;
        }

        /// <summary>
        /// Returns a read-only wrapper around the observations added so far.
        /// </summary>
        /// <returns>A read-only view of the internal observation list.</returns>
        internal ReadOnlyCollection<float> GetObservations()
        {
            return m_Observations.AsReadOnly();
        }

        /// <summary>
        /// Clears all accumulated observations.
        /// </summary>
        public void Update()
        {
            Clear();
        }

        /// <summary>
        /// Clears all accumulated observations.
        /// </summary>
        public void Reset()
        {
            Clear();
        }

        /// <summary>
        /// Returns the observation spec created in the constructor.
        /// </summary>
        /// <returns>The sensor's <c>ObservationSpec</c>.</returns>
        public ObservationSpec GetObservationSpec()
        {
            return m_ObservationSpec;
        }

        /// <summary>
        /// Returns the sensor name supplied to, or generated by, the constructor.
        /// </summary>
        /// <returns>The sensor name.</returns>
        public string GetName()
        {
            return m_Name;
        }

        /// <summary>
        /// Returns the compressed form of the observation; this implementation has none.
        /// </summary>
        /// <returns>Always <c>null</c> unless overridden.</returns>
        public virtual byte[] GetCompressedObservation()
        {
            return null;
        }

        /// <summary>
        /// Returns the compression spec for this sensor.
        /// </summary>
        /// <returns>The result of <c>CompressionSpec.Default()</c>.</returns>
        public CompressionSpec GetCompressionSpec()
        {
            return CompressionSpec.Default();
        }

        /// <summary>
        /// Identifies this sensor as a built-in vector sensor.
        /// </summary>
        /// <returns><c>BuiltInSensorType.VectorSensor</c>.</returns>
        public BuiltInSensorType GetBuiltInSensorType()
        {
            return BuiltInSensorType.VectorSensor;
        }

        // Empties the observation list; shared by Update() and Reset().
        void Clear()
        {
            m_Observations.Clear();
        }

        // Single funnel through which every public Add* method appends a value.
        // NOTE(review): Utilities.DebugCheckNanAndInfinity is defined elsewhere; presumably it
        // reports NaN/infinite values in debug builds - confirm against its implementation.
        void AddFloatObs(float obs)
        {
            Utilities.DebugCheckNanAndInfinity(obs, nameof(obs), nameof(AddFloatObs));
            m_Observations.Add(obs);
        }

        /// <summary>
        /// Adds a float observation to the vector observations of the agent.
        /// </summary>
        /// <param name="observation">Observation.</param>
        public void AddObservation(float observation)
        {
            AddFloatObs(observation);
        }

        /// <summary>
        /// Adds an integer observation to the vector observations of the agent.
        /// The value is implicitly converted to float.
        /// </summary>
        /// <param name="observation">Observation.</param>
        public void AddObservation(int observation)
        {
            AddFloatObs(observation);
        }

        /// <summary>
        /// Adds a Vector3 observation as three floats, in x, y, z order.
        /// </summary>
        /// <param name="observation">Observation.</param>
        public void AddObservation(Vector3 observation)
        {
            AddFloatObs(observation.x);
            AddFloatObs(observation.y);
            AddFloatObs(observation.z);
        }

        /// <summary>
        /// Adds a Vector2 observation as two floats, in x, y order.
        /// </summary>
        /// <param name="observation">Observation.</param>
        public void AddObservation(Vector2 observation)
        {
            AddFloatObs(observation.x);
            AddFloatObs(observation.y);
        }

        /// <summary>
        /// Adds every element of a list of floats, in index order.
        /// </summary>
        /// <param name="observation">Observations to append; must not be null.</param>
        /// <exception cref="System.ArgumentNullException">
        /// Thrown when <paramref name="observation"/> is null.
        /// </exception>
        public void AddObservation(IList<float> observation)
        {
            // Fail with a descriptive exception instead of a NullReferenceException from
            // the loop header below.
            if (observation == null)
            {
                throw new System.ArgumentNullException(nameof(observation));
            }

            for (var i = 0; i < observation.Count; i++)
            {
                AddFloatObs(observation[i]);
            }
        }

        /// <summary>
        /// Adds a Quaternion observation as four floats, in x, y, z, w order.
        /// </summary>
        /// <param name="observation">Observation.</param>
        public void AddObservation(Quaternion observation)
        {
            AddFloatObs(observation.x);
            AddFloatObs(observation.y);
            AddFloatObs(observation.z);
            AddFloatObs(observation.w);
        }

        /// <summary>
        /// Adds a boolean observation, encoded as 1 for true and 0 for false.
        /// </summary>
        /// <param name="observation">Observation.</param>
        public void AddObservation(bool observation)
        {
            AddFloatObs(observation ? 1f : 0f);
        }

        /// <summary>
        /// Adds a one-hot encoding of <paramref name="observation"/>: <paramref name="range"/>
        /// floats, with 1 at index <paramref name="observation"/> and 0 everywhere else.
        /// </summary>
        /// <remarks>
        /// No validation is performed. If <paramref name="observation"/> is outside
        /// <c>[0, range)</c>, all <paramref name="range"/> values are 0; if
        /// <paramref name="range"/> is zero or negative, nothing is added.
        /// </remarks>
        /// <param name="observation">The index of the entry to set to 1.</param>
        /// <param name="range">The number of floats to add.</param>
        public void AddOneHotObservation(int observation, int range)
        {
            for (var i = 0; i < range; i++)
            {
                AddFloatObs(i == observation ? 1.0f : 0.0f);
            }
        }
    }
}
|
|