|
using System.Collections.Generic; |
|
#if UNITY_2020_1_OR_NEWER |
|
using UnityEngine; |
|
#endif |
|
using Unity.MLAgents.Sensors; |
|
|
|
namespace Unity.MLAgents.Extensions.Sensors |
|
{ |
|
|
|
|
|
|
|
/// <summary>
/// Sensor that writes the poses produced by a <see cref="PoseExtractor"/>, followed by
/// the observations of one joint extractor per enabled body, into a single vector observation.
/// </summary>
public class PhysicsBodySensor : ISensor, IBuiltInSensor
{
    // All state is fixed at construction time; nothing below is reassigned afterwards.
    readonly ObservationSpec m_ObservationSpec;
    readonly string m_SensorName;

    readonly PoseExtractor m_PoseExtractor;
    readonly List<IJointExtractor> m_JointExtractors;
    readonly PhysicsSensorSettings m_Settings;

    /// <summary>
    /// Construct a new PhysicsBodySensor from an existing Rigidbody pose extractor.
    /// One <see cref="RigidBodyJointExtractor"/> is created for each Rigidbody that the
    /// extractor reports as enabled.
    /// </summary>
    /// <param name="poseExtractor">Source of the poses and of the enabled Rigidbodies. Must not be null.</param>
    /// <param name="settings">Settings that determine which observations are produced.</param>
    /// <param name="sensorName">Name returned by <see cref="GetName"/>; stored as given.</param>
    /// <exception cref="System.ArgumentNullException">
    /// Thrown when <paramref name="poseExtractor"/> is null.
    /// </exception>
    public PhysicsBodySensor(
        RigidBodyPoseExtractor poseExtractor,
        PhysicsSensorSettings settings,
        string sensorName
    )
    {
        // Fail fast with a descriptive error instead of a NullReferenceException
        // part-way through construction.
        if (poseExtractor == null)
        {
            throw new System.ArgumentNullException(nameof(poseExtractor));
        }

        m_PoseExtractor = poseExtractor;
        m_SensorName = sensorName;
        m_Settings = settings;

        var numJointExtractorObservations = 0;
        m_JointExtractors = new List<IJointExtractor>(poseExtractor.NumEnabledPoses);
        foreach (var rb in poseExtractor.GetEnabledRigidbodies())
        {
            var jointExtractor = new RigidBodyJointExtractor(rb);
            numJointExtractorObservations += jointExtractor.NumObservations(settings);
            m_JointExtractors.Add(jointExtractor);
        }

        // The observation vector holds the pose observations plus every joint observation.
        var numTransformObservations = m_PoseExtractor.GetNumPoseObservations(settings);
        m_ObservationSpec = ObservationSpec.Vector(numTransformObservations + numJointExtractorObservations);
    }

#if UNITY_2020_1_OR_NEWER
    /// <summary>
    /// Construct a new PhysicsBodySensor from the root of an ArticulationBody hierarchy.
    /// One <see cref="ArticulationBodyJointExtractor"/> is created for each ArticulationBody
    /// that the pose extractor reports as enabled.
    /// </summary>
    /// <param name="rootBody">Root ArticulationBody handed to the pose extractor.</param>
    /// <param name="settings">Settings that determine which observations are produced.</param>
    /// <param name="sensorName">
    /// Name returned by <see cref="GetName"/>. When null or empty, a name of the form
    /// "ArticulationBodySensor:{root body name}" is used instead.
    /// </param>
    public PhysicsBodySensor(ArticulationBody rootBody, PhysicsSensorSettings settings, string sensorName = null)
    {
        var poseExtractor = new ArticulationBodyPoseExtractor(rootBody);
        m_PoseExtractor = poseExtractor;

        // Use an explicit comparison rather than `?.`: the null-conditional operator
        // bypasses UnityEngine.Object's overloaded null check, so a destroyed body
        // would be dereferenced instead of being treated as null.
        var rootBodyName = rootBody != null ? rootBody.name : null;
        m_SensorName = string.IsNullOrEmpty(sensorName) ? $"ArticulationBodySensor:{rootBodyName}" : sensorName;
        m_Settings = settings;

        var numJointExtractorObservations = 0;
        m_JointExtractors = new List<IJointExtractor>(poseExtractor.NumEnabledPoses);
        foreach (var articBody in poseExtractor.GetEnabledArticulationBodies())
        {
            var jointExtractor = new ArticulationBodyJointExtractor(articBody);
            numJointExtractorObservations += jointExtractor.NumObservations(settings);
            m_JointExtractors.Add(jointExtractor);
        }

        // The observation vector holds the pose observations plus every joint observation.
        var numTransformObservations = m_PoseExtractor.GetNumPoseObservations(settings);
        m_ObservationSpec = ObservationSpec.Vector(numTransformObservations + numJointExtractorObservations);
    }

#endif

    /// <summary>
    /// Returns the vector observation spec computed in the constructor.
    /// </summary>
    /// <returns>The spec sized to the pose observations plus all joint observations.</returns>
    public ObservationSpec GetObservationSpec()
    {
        return m_ObservationSpec;
    }

    /// <summary>
    /// Writes the pose observations first, then lets each joint extractor write in turn.
    /// </summary>
    /// <param name="writer">Destination for the observations.</param>
    /// <returns>The total count reported by the pose write and all joint extractor writes.</returns>
    public int Write(ObservationWriter writer)
    {
        var numWritten = writer.WritePoses(m_Settings, m_PoseExtractor);
        foreach (var jointExtractor in m_JointExtractors)
        {
            // The running total is passed along so each extractor continues where
            // the previous write stopped (presumably as the write offset).
            numWritten += jointExtractor.Write(m_Settings, writer, numWritten);
        }
        return numWritten;
    }

    /// <summary>
    /// This sensor provides no compressed observation.
    /// </summary>
    /// <returns>Always null.</returns>
    public byte[] GetCompressedObservation()
    {
        return null;
    }

    /// <summary>
    /// Refreshes the poses held by the pose extractor, limited to the spaces
    /// (model and/or local) that the settings enable.
    /// </summary>
    public void Update()
    {
        if (m_Settings.UseModelSpace)
        {
            m_PoseExtractor.UpdateModelSpacePoses();
        }

        if (m_Settings.UseLocalSpace)
        {
            m_PoseExtractor.UpdateLocalSpacePoses();
        }
    }

    /// <summary>
    /// Intentionally does nothing.
    /// </summary>
    public void Reset() { }

    /// <summary>
    /// Returns the default compression spec.
    /// </summary>
    /// <returns>The value of <c>CompressionSpec.Default()</c>.</returns>
    public CompressionSpec GetCompressionSpec()
    {
        return CompressionSpec.Default();
    }

    /// <summary>
    /// Returns the sensor name chosen at construction time.
    /// </summary>
    /// <returns>The sensor name.</returns>
    public string GetName()
    {
        return m_SensorName;
    }

    /// <summary>
    /// Identifies this sensor as the built-in physics body sensor.
    /// </summary>
    /// <returns><see cref="BuiltInSensorType.PhysicsBodySensor"/>.</returns>
    public BuiltInSensorType GetBuiltInSensorType()
    {
        return BuiltInSensorType.PhysicsBodySensor;
    }
}
|
} |
|
|