|
using System.Collections.Generic; |
|
using UnityEngine; |
|
|
|
namespace Unity.MLAgents.Sensors |
|
{ |
|
public class SensorShapeValidator
{
    // Observation specs recorded from the first sensor list passed to ValidateSensors.
    // Stays null until that first call.
    List<ObservationSpec> m_SensorShapes;

    /// <summary>
    /// Records or validates a list of sensors.
    /// </summary>
    /// <remarks>
    /// The first call stores each sensor's <see cref="ObservationSpec"/>. Every later call
    /// checks that the number of sensors matches the stored count. It also checks that each
    /// sensor's shape matches the stored shape at the same index.
    ///
    /// Mismatches are reported through <c>Debug.AssertFormat</c>; nothing is thrown here.
    /// When the counts differ, only the first min(stored, current) sensors are compared
    /// shape-by-shape.
    ///
    /// NOTE(review): Unity's Debug.Assert* methods are conditionally compiled
    /// (UNITY_ASSERTIONS). These checks may therefore be inactive in builds where
    /// assertions are stripped. Confirm this is the intended reporting channel.
    /// </remarks>
    /// <param name="sensors">Sensors to record (first call) or to check against the recorded specs.</param>
    public void ValidateSensors(List<ISensor> sensors)
    {
        if (m_SensorShapes == null)
        {
            // First call: allocate the cache before filling it. If a sensor throws part-way,
            // the partially filled cache is kept, exactly as before.
            m_SensorShapes = new List<ObservationSpec>(sensors.Count);
            foreach (var current in sensors)
            {
                m_SensorShapes.Add(current.GetObservationSpec());
            }
            return;
        }

        var expectedCount = m_SensorShapes.Count;
        var actualCount = sensors.Count;

        // Only call AssertFormat on a mismatch. This presumably avoids building its params
        // array (and boxing the ints) on the common, matching path.
        if (expectedCount != actualCount)
        {
            Debug.AssertFormat(
                expectedCount == actualCount,
                "Number of Sensors must match. {0} != {1}",
                expectedCount,
                actualCount
            );
        }

        // Compare only the indices that exist in both lists.
        var overlap = expectedCount < actualCount ? expectedCount : actualCount;
        for (var index = 0; index < overlap; index++)
        {
            var expectedShape = m_SensorShapes[index].Shape;
            var actualShape = sensors[index].GetObservationSpec().Shape;
            if (expectedShape != actualShape)
            {
                Debug.AssertFormat(
                    expectedShape == actualShape,
                    "Sensor shapes must match. {0} != {1}",
                    expectedShape,
                    actualShape
                );
            }
        }
    }
}
|
} |
|
|