File size: 4,374 Bytes
05c9ac2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 |
using System.Collections.Generic;
#if UNITY_2020_1_OR_NEWER
using UnityEngine;
#endif
using Unity.MLAgents.Sensors;
namespace Unity.MLAgents.Extensions.Sensors
{
/// <summary>
/// ISensor implementation that generates observations for a group of Rigidbodies or ArticulationBodies.
/// Pose observations are produced through a <see cref="PoseExtractor"/>, and one
/// <see cref="IJointExtractor"/> is created for every enabled body to produce joint observations.
/// </summary>
public class PhysicsBodySensor : ISensor, IBuiltInSensor
{
    // Every field is assigned exactly once, in a constructor.
    readonly ObservationSpec m_ObservationSpec;
    readonly string m_SensorName;

    readonly PoseExtractor m_PoseExtractor;
    readonly List<IJointExtractor> m_JointExtractors;
    readonly PhysicsSensorSettings m_Settings;

    /// <summary>
    /// Construct a new PhysicsBodySensor from a Rigidbody-based pose extractor.
    /// </summary>
    /// <param name="poseExtractor">
    /// Extractor that supplies the poses; a joint extractor is created for each of its enabled Rigidbodies.
    /// </param>
    /// <param name="settings">
    /// Settings passed to the pose and joint extractors to determine how many observations are
    /// generated and how they are written.
    /// </param>
    /// <param name="sensorName">Name reported by <see cref="GetName"/>.</param>
    public PhysicsBodySensor(
        RigidBodyPoseExtractor poseExtractor,
        PhysicsSensorSettings settings,
        string sensorName
    )
    {
        m_PoseExtractor = poseExtractor;
        m_SensorName = sensorName;
        m_Settings = settings;

        // One joint extractor per enabled body; tally their observation sizes as we go.
        var numJointExtractorObservations = 0;
        m_JointExtractors = new List<IJointExtractor>(poseExtractor.NumEnabledPoses);
        foreach (var rb in poseExtractor.GetEnabledRigidbodies())
        {
            var jointExtractor = new RigidBodyJointExtractor(rb);
            numJointExtractorObservations += jointExtractor.NumObservations(settings);
            m_JointExtractors.Add(jointExtractor);
        }

        // The sensor's vector size is the pose observations plus all joint observations.
        var numTransformObservations = m_PoseExtractor.GetNumPoseObservations(settings);
        m_ObservationSpec = ObservationSpec.Vector(numTransformObservations + numJointExtractorObservations);
    }

#if UNITY_2020_1_OR_NEWER
    /// <summary>
    /// Construct a new PhysicsBodySensor from a root ArticulationBody.
    /// </summary>
    /// <param name="rootBody">
    /// Body handed to the <see cref="ArticulationBodyPoseExtractor"/> that supplies the poses; a joint
    /// extractor is created for each of that extractor's enabled ArticulationBodies.
    /// </param>
    /// <param name="settings">
    /// Settings passed to the pose and joint extractors to determine how many observations are
    /// generated and how they are written.
    /// </param>
    /// <param name="sensorName">
    /// Name reported by <see cref="GetName"/>. When null or empty, a name of the form
    /// "ArticulationBodySensor:" followed by the root body's name is used instead.
    /// </param>
    public PhysicsBodySensor(ArticulationBody rootBody, PhysicsSensorSettings settings, string sensorName = null)
    {
        var poseExtractor = new ArticulationBodyPoseExtractor(rootBody);
        m_PoseExtractor = poseExtractor;

        if (string.IsNullOrEmpty(sensorName))
        {
            // Use an explicit comparison rather than "?.": UnityEngine.Object overloads the equality
            // operators so that destroyed objects compare equal to null, and the null-conditional
            // operator bypasses that overload.
            var rootName = rootBody != null ? rootBody.name : string.Empty;
            sensorName = $"ArticulationBodySensor:{rootName}";
        }
        m_SensorName = sensorName;
        m_Settings = settings;

        // One joint extractor per enabled body; tally their observation sizes as we go.
        var numJointExtractorObservations = 0;
        m_JointExtractors = new List<IJointExtractor>(poseExtractor.NumEnabledPoses);
        foreach (var articBody in poseExtractor.GetEnabledArticulationBodies())
        {
            var jointExtractor = new ArticulationBodyJointExtractor(articBody);
            numJointExtractorObservations += jointExtractor.NumObservations(settings);
            m_JointExtractors.Add(jointExtractor);
        }

        // The sensor's vector size is the pose observations plus all joint observations.
        var numTransformObservations = m_PoseExtractor.GetNumPoseObservations(settings);
        m_ObservationSpec = ObservationSpec.Vector(numTransformObservations + numJointExtractorObservations);
    }
#endif

    /// <inheritdoc/>
    public ObservationSpec GetObservationSpec()
    {
        return m_ObservationSpec;
    }

    /// <inheritdoc/>
    public int Write(ObservationWriter writer)
    {
        // Poses are written first; each joint extractor is then given the running count of values
        // written so far and its own count is added to the total.
        var numWritten = writer.WritePoses(m_Settings, m_PoseExtractor);
        foreach (var jointExtractor in m_JointExtractors)
        {
            numWritten += jointExtractor.Write(m_Settings, writer, numWritten);
        }
        return numWritten;
    }

    /// <inheritdoc/>
    public byte[] GetCompressedObservation()
    {
        // This sensor provides no compressed representation.
        return null;
    }

    /// <inheritdoc/>
    public void Update()
    {
        // Only update the pose spaces that the settings say are in use.
        if (m_Settings.UseModelSpace)
        {
            m_PoseExtractor.UpdateModelSpacePoses();
        }

        if (m_Settings.UseLocalSpace)
        {
            m_PoseExtractor.UpdateLocalSpacePoses();
        }
    }

    /// <inheritdoc/>
    public void Reset() { }

    /// <inheritdoc/>
    public CompressionSpec GetCompressionSpec()
    {
        return CompressionSpec.Default();
    }

    /// <inheritdoc/>
    public string GetName()
    {
        return m_SensorName;
    }

    /// <inheritdoc/>
    public BuiltInSensorType GetBuiltInSensorType()
    {
        return BuiltInSensorType.PhysicsBodySensor;
    }
}
}
|