// File size: 6,772 Bytes
// 05c9ac2
using UnityEngine;
using NUnit.Framework;
using Unity.MLAgents.Extensions.Sensors;
namespace Unity.MLAgents.Extensions.Tests.Sensors
{
public class RigidBodyPoseExtractorTests
{
/// <summary>
/// Runs after each test and immediately destroys every GameObject returned by
/// <c>FindObjectsOfType</c>, so objects created by one test do not carry over to the next.
/// </summary>
[TearDown]
public void RemoveGameObjects()
{
    var objects = GameObject.FindObjectsOfType<GameObject>();
    foreach (var o in objects)
    {
        // Destroying a parent also destroys its transform children, so later entries of this
        // snapshot may already be gone. Unity's overloaded == reports destroyed objects as
        // null; skip them instead of destroying them a second time.
        if (o == null)
        {
            continue;
        }

        UnityEngine.Object.DestroyImmediate(o);
    }
}
/// <summary>
/// An extractor constructed with a null root body is expected to expose zero poses and to
/// tolerate both pose-update calls.
/// </summary>
[Test]
public void TestNullRoot()
{
    var extractor = new RigidBodyPoseExtractor(null);

    // With nothing to track, both updates are expected to do nothing (and not throw).
    extractor.UpdateLocalSpacePoses();
    extractor.UpdateModelSpacePoses();

    Assert.AreEqual(0, extractor.NumPoses);
}
/// <summary>
/// A lone Rigidbody is expected to yield exactly one pose, whether or not its owning
/// GameObject is also passed to the constructor.
/// </summary>
[Test]
public void TestSingleBody()
{
    var bodyObject = new GameObject();
    var body = bodyObject.AddComponent<Rigidbody>();

    // Root body only.
    var extractor = new RigidBodyPoseExtractor(body);
    Assert.AreEqual(1, extractor.NumPoses);

    // Root body together with the GameObject that owns it.
    extractor = new RigidBodyPoseExtractor(body, bodyObject);
    Assert.AreEqual(1, extractor.NumPoses);
}
/// <summary>
/// When the GameObject handed to the constructor does not contain the root body, the
/// resulting extractor is expected to be empty.
/// </summary>
[Test]
public void TestNoBodiesFound()
{
    var rootObject = new GameObject();
    var rootBody = rootObject.AddComponent<Rigidbody>();
    var unrelatedObject = new GameObject();

    // The unrelated object has no rigid bodies at all, so nothing can be found under it.
    var extractor = new RigidBodyPoseExtractor(rootBody, unrelatedObject);
    Assert.AreEqual(0, extractor.NumPoses);

    // Now the unrelated object does own a rigid body, but it still is not the root body,
    // so the extractor is expected to stay empty.
    unrelatedObject.AddComponent<Rigidbody>();
    extractor = new RigidBodyPoseExtractor(rootBody, unrelatedObject);
    Assert.AreEqual(0, extractor.NumPoses);
}
/// <summary>
/// Two jointed rigid bodies are expected to produce two poses; the root body's pose and
/// velocity are read back through the extractor, and the display nodes are checked.
/// </summary>
[Test]
public void TestTwoBodies()
{
    // Hierarchy under test:
    //   * parentObject
    //       - rootBody
    //       * childObject
    //           - childBody
    //           - joint (connected to rootBody)
    var parentObject = new GameObject();
    var rootBody = parentObject.AddComponent<Rigidbody>();

    var childObject = new GameObject();
    var childBody = childObject.AddComponent<Rigidbody>();
    childObject.transform.SetParent(parentObject.transform);
    var childJoint = childObject.AddComponent<ConfigurableJoint>();
    childJoint.connectedBody = rootBody;

    var extractor = new RigidBodyPoseExtractor(rootBody);
    Assert.AreEqual(2, extractor.NumPoses);

    // Move the root body after construction; the extractor should report the new state.
    rootBody.position = new Vector3(1, 0, 0);
    rootBody.rotation = Quaternion.Euler(0, 13.37f, 0);
    rootBody.velocity = new Vector3(2, 0, 0);

    var rootPose = extractor.GetPoseAt(0);
    Assert.AreEqual(rootBody.position, rootPose.position);
    Assert.IsTrue(rootBody.rotation == rootPose.rotation);
    Assert.AreEqual(rootBody.velocity, extractor.GetLinearVelocityAt(0));

    // The display nodes should list both bodies: the root disabled, the child enabled.
    var nodes = extractor.GetDisplayNodes();
    Assert.AreEqual(2, nodes.Count);

    var rootNode = nodes[0];
    Assert.AreEqual(rootBody, rootNode.NodeObject);
    Assert.IsFalse(rootNode.Enabled);

    var childNode = nodes[1];
    Assert.AreEqual(childBody, childNode.NodeObject);
    Assert.IsTrue(childNode.Enabled);
}
/// <summary>
/// Same two-body setup as <c>TestTwoBodies</c>, but with an extra virtual root GameObject.
/// The virtual root is expected at index 0 with no parent and zero linear velocity, and
/// the root rigid body is expected at index 1.
/// </summary>
[Test]
public void TestTwoBodiesVirtualRoot()
{
    // Objects under test (the virtual root is passed to the constructor, not parented):
    //   * vroot
    //   * parentObject
    //       - rootBody
    //       * childObject
    //           - (child rigid body)
    //           - joint (connected to rootBody)
    var vroot = new GameObject("I am vroot");

    var parentObject = new GameObject();
    var rootBody = parentObject.AddComponent<Rigidbody>();

    var childObject = new GameObject();
    childObject.AddComponent<Rigidbody>();
    childObject.transform.SetParent(parentObject.transform);
    var childJoint = childObject.AddComponent<ConfigurableJoint>();
    childJoint.connectedBody = rootBody;

    var extractor = new RigidBodyPoseExtractor(rootBody, null, vroot);
    Assert.AreEqual(3, extractor.NumPoses);

    // Index 0 (the virtual root) has no parent; index 1 hangs off index 0.
    Assert.AreEqual(-1, extractor.GetParentIndex(0));
    Assert.AreEqual(0, extractor.GetParentIndex(1));

    // Moving the virtual root's transform should show up in the pose at index 0.
    var expectedRootPosition = new Vector3(0, 2, 0);
    var expectedRootRotation = Quaternion.Euler(0, 42, 0);
    vroot.transform.position = expectedRootPosition;
    vroot.transform.rotation = expectedRootRotation;

    var vrootPose = extractor.GetPoseAt(0);
    Assert.AreEqual(expectedRootPosition, vrootPose.position);
    Assert.IsTrue(expectedRootRotation == vrootPose.rotation);
    Assert.AreEqual(Vector3.zero, extractor.GetLinearVelocityAt(0));

    // Same checks as TestTwoBodies, except the root rigid body now sits at index 1.
    rootBody.position = new Vector3(1, 0, 0);
    rootBody.rotation = Quaternion.Euler(0, 13.37f, 0);
    rootBody.velocity = new Vector3(2, 0, 0);

    var rootBodyPose = extractor.GetPoseAt(1);
    Assert.AreEqual(rootBody.position, rootBodyPose.position);
    Assert.IsTrue(rootBody.rotation == rootBodyPose.rotation);
    Assert.AreEqual(rootBody.velocity, extractor.GetLinearVelocityAt(1));
}
/// <summary>
/// Checks the default per-body enabled flags, then verifies that a modified flag
/// dictionary passed to a new extractor is reflected in its enabled poses.
/// </summary>
[Test]
public void TestBodyPosesEnabledDictionary()
{
    // Hierarchy under test:
    //   * parentObject
    //       - rootBody
    //       * childObject
    //           - childBody
    //           - joint (connected to rootBody)
    var parentObject = new GameObject();
    var rootBody = parentObject.AddComponent<Rigidbody>();

    var childObject = new GameObject();
    var childBody = childObject.AddComponent<Rigidbody>();
    childObject.transform.SetParent(parentObject.transform);
    var childJoint = childObject.AddComponent<ConfigurableJoint>();
    childJoint.connectedBody = rootBody;

    var defaultExtractor = new RigidBodyPoseExtractor(rootBody);

    // By default the root body is expected to be disabled and the jointed body enabled.
    Assert.IsFalse(defaultExtractor.IsPoseEnabled(0));
    Assert.IsTrue(defaultExtractor.IsPoseEnabled(1));

    var enabledByBody = defaultExtractor.GetBodyPosesEnabled();
    Assert.IsFalse(enabledByBody[rootBody]);
    Assert.IsTrue(enabledByBody[childBody]);

    // Invert both flags and build a second extractor from the edited dictionary.
    enabledByBody[rootBody] = true;
    enabledByBody[childBody] = false;
    var customExtractor = new RigidBodyPoseExtractor(rootBody, null, null, enabledByBody);

    Assert.IsTrue(customExtractor.IsPoseEnabled(0));
    Assert.IsFalse(customExtractor.IsPoseEnabled(1));
}
}
}
|