File size: 11,998 Bytes
05c9ac2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 |
using System;
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgentsExamples;
using Unity.MLAgents.Sensors;
using BodyPart = Unity.MLAgentsExamples.BodyPart;
using Random = UnityEngine.Random;
/// <summary>
/// ML-Agents agent that drives a jointed ragdoll towards <see cref="target"/> at a goal walking speed.
/// </summary>
/// <remarks>
/// Every physics step (<c>FixedUpdate</c>) the agent receives the product of two terms in [0, 1]:
/// (a) how closely the average body velocity matches the goal velocity along the orientation
/// cube's forward axis, and (b) how well the head's horizontal heading is aligned with that axis.
/// <see cref="TouchedTarget"/> adds a one-off bonus of 1.
/// </remarks>
public class WalkerAgent : Agent
{
    [Header("Walk Speed")]
    [Range(0.1f, 10)]
    [SerializeField]
    //The walking speed to try and achieve
    private float m_TargetWalkingSpeed = 10;

    /// <summary>
    /// Goal walking speed. Assigned values are clamped to [0.1, m_maxWalkingSpeed].
    /// </summary>
    public float MTargetWalkingSpeed // property
    {
        get { return m_TargetWalkingSpeed; }
        set { m_TargetWalkingSpeed = Mathf.Clamp(value, .1f, m_maxWalkingSpeed); }
    }

    //The max walking speed; upper bound for MTargetWalkingSpeed and for per-episode randomization
    const float m_maxWalkingSpeed = 10;

    //Should the agent sample a new goal velocity each episode?
    //If true, MTargetWalkingSpeed is set to Random.Range(0.1f, m_maxWalkingSpeed) in OnEpisodeBegin()
    //If false, the goal velocity stays at the current MTargetWalkingSpeed
    public bool randomizeWalkSpeedEachEpisode;

    [Header("Target To Walk Towards")] public Transform target; //Target the agent will walk towards during training.

    [Header("Body Parts")] public Transform hips;
    public Transform chest;
    public Transform spine;
    public Transform head;
    public Transform thighL;
    public Transform shinL;
    public Transform footL;
    public Transform thighR;
    public Transform shinR;
    public Transform footR;
    public Transform armL;
    public Transform forearmL;
    public Transform handL;
    public Transform armR;
    public Transform forearmR;
    public Transform handR;

    //This will be used as a stabilized model space reference point for observations
    //Because ragdolls can move erratically during training, using a stabilized reference transform improves learning
    OrientationCubeController m_OrientationCube;

    //The indicator graphic gameobject that points towards the target.
    //Optional: UpdateOrientationObjects() null-checks it before use.
    DirectionIndicator m_DirectionIndicator;

    //Owns per-body-part joint/rigidbody state (bodyPartsDict / bodyPartsList)
    JointDriveController m_JdController;

    // NOTE(review): assigned in Initialize() but never read in this class.
    EnvironmentParameters m_ResetParams;

    /// <summary>
    /// One-time setup: caches child components and registers every body part with the
    /// <see cref="JointDriveController"/> on this GameObject.
    /// </summary>
    public override void Initialize()
    {
        m_OrientationCube = GetComponentInChildren<OrientationCubeController>();
        m_DirectionIndicator = GetComponentInChildren<DirectionIndicator>();

        //Setup each body part
        m_JdController = GetComponent<JointDriveController>();
        m_JdController.SetupBodyPart(hips);
        m_JdController.SetupBodyPart(chest);
        m_JdController.SetupBodyPart(spine);
        m_JdController.SetupBodyPart(head);
        m_JdController.SetupBodyPart(thighL);
        m_JdController.SetupBodyPart(shinL);
        m_JdController.SetupBodyPart(footL);
        m_JdController.SetupBodyPart(thighR);
        m_JdController.SetupBodyPart(shinR);
        m_JdController.SetupBodyPart(footR);
        m_JdController.SetupBodyPart(armL);
        m_JdController.SetupBodyPart(forearmL);
        m_JdController.SetupBodyPart(handL);
        m_JdController.SetupBodyPart(armR);
        m_JdController.SetupBodyPart(forearmR);
        m_JdController.SetupBodyPart(handR);

        m_ResetParams = Academy.Instance.EnvironmentParameters;
    }

    /// <summary>
    /// Resets every body part, applies a random yaw to the hips, refreshes the orientation
    /// helpers, and (optionally) samples a new goal walking speed.
    /// </summary>
    public override void OnEpisodeBegin()
    {
        //Reset all of the body parts
        foreach (var bodyPart in m_JdController.bodyPartsDict.Values)
        {
            bodyPart.Reset(bodyPart);
        }

        //Random start rotation to help generalize
        hips.rotation = Quaternion.Euler(0, Random.Range(0.0f, 360.0f), 0);

        UpdateOrientationObjects();

        //Set our goal walking speed. The non-random branch re-assigns through the property,
        //which re-applies its clamp to the serialized value.
        MTargetWalkingSpeed =
            randomizeWalkSpeedEachEpisode ? Random.Range(0.1f, m_maxWalkingSpeed) : MTargetWalkingSpeed;
    }

    /// <summary>
    /// Add relevant information on each body part to observations.
    /// </summary>
    /// <remarks>
    /// Always adds: ground contact flag, velocity, angular velocity and position relative to the
    /// hips (all vectors expressed in the orientation cube's space). For every part except the
    /// hips and hands it also adds the local rotation and the current strength normalized by
    /// <c>maxJointForceLimit</c>.
    /// </remarks>
    /// <param name="bp">Body part to observe.</param>
    /// <param name="sensor">Sensor the observations are appended to.</param>
    public void CollectObservationBodyPart(BodyPart bp, VectorSensor sensor)
    {
        //GROUND CHECK
        sensor.AddObservation(bp.groundContact.touchingGround); // Is this bp touching the ground

        //Get velocities in the context of our orientation cube's space
        //Note: You can get these velocities in world space as well but it may not train as well.
        sensor.AddObservation(m_OrientationCube.transform.InverseTransformDirection(bp.rb.velocity));
        sensor.AddObservation(m_OrientationCube.transform.InverseTransformDirection(bp.rb.angularVelocity));

        //Get position relative to hips in the context of our orientation cube's space
        sensor.AddObservation(m_OrientationCube.transform.InverseTransformDirection(bp.rb.position - hips.position));

        //The hips and hands are not given joint-rotation/strength observations
        if (bp.rb.transform != hips && bp.rb.transform != handL && bp.rb.transform != handR)
        {
            sensor.AddObservation(bp.rb.transform.localRotation);
            sensor.AddObservation(bp.currentStrength / m_JdController.maxJointForceLimit);
        }
    }

    /// <summary>
    /// Loop over body parts to add them to observation.
    /// </summary>
    /// <remarks>
    /// The order of the AddObservation calls defines the observation vector layout and must not
    /// change without retraining.
    /// </remarks>
    public override void CollectObservations(VectorSensor sensor)
    {
        var cubeForward = m_OrientationCube.transform.forward;

        //velocity we want to match
        var velGoal = cubeForward * MTargetWalkingSpeed;
        //ragdoll's avg vel
        var avgVel = GetAvgVelocity();

        //distance between the goal velocity and the current average velocity
        sensor.AddObservation(Vector3.Distance(velGoal, avgVel));
        //avg body vel relative to cube
        sensor.AddObservation(m_OrientationCube.transform.InverseTransformDirection(avgVel));
        //vel goal relative to cube
        sensor.AddObservation(m_OrientationCube.transform.InverseTransformDirection(velGoal));

        //rotation deltas from the hips/head facing direction to the cube's forward
        sensor.AddObservation(Quaternion.FromToRotation(hips.forward, cubeForward));
        sensor.AddObservation(Quaternion.FromToRotation(head.forward, cubeForward));

        //Position of target position relative to cube
        sensor.AddObservation(m_OrientationCube.transform.InverseTransformPoint(target.position));

        foreach (var bodyPart in m_JdController.bodyPartsList)
        {
            CollectObservationBodyPart(bodyPart, sensor);
        }
    }

    /// <summary>
    /// Applies the policy's continuous actions to the joints.
    /// </summary>
    /// <remarks>
    /// Consumes 39 continuous actions in a fixed order: 26 joint target-rotation components
    /// (chest 3, spine 3, thighs 2+2, shins 1+1, feet 3+3, arms 2+2, forearms 1+1, head 2)
    /// followed by 13 joint strengths. This order is the action-space contract with any trained
    /// policy; do not reorder.
    /// </remarks>
    public override void OnActionReceived(ActionBuffers actionBuffers)
    {
        var bpDict = m_JdController.bodyPartsDict;
        var i = -1;

        var continuousActions = actionBuffers.ContinuousActions;
        bpDict[chest].SetJointTargetRotation(continuousActions[++i], continuousActions[++i], continuousActions[++i]);
        bpDict[spine].SetJointTargetRotation(continuousActions[++i], continuousActions[++i], continuousActions[++i]);

        bpDict[thighL].SetJointTargetRotation(continuousActions[++i], continuousActions[++i], 0);
        bpDict[thighR].SetJointTargetRotation(continuousActions[++i], continuousActions[++i], 0);
        bpDict[shinL].SetJointTargetRotation(continuousActions[++i], 0, 0);
        bpDict[shinR].SetJointTargetRotation(continuousActions[++i], 0, 0);
        bpDict[footR].SetJointTargetRotation(continuousActions[++i], continuousActions[++i], continuousActions[++i]);
        bpDict[footL].SetJointTargetRotation(continuousActions[++i], continuousActions[++i], continuousActions[++i]);

        bpDict[armL].SetJointTargetRotation(continuousActions[++i], continuousActions[++i], 0);
        bpDict[armR].SetJointTargetRotation(continuousActions[++i], continuousActions[++i], 0);
        bpDict[forearmL].SetJointTargetRotation(continuousActions[++i], 0, 0);
        bpDict[forearmR].SetJointTargetRotation(continuousActions[++i], 0, 0);
        bpDict[head].SetJointTargetRotation(continuousActions[++i], continuousActions[++i], 0);

        //update joint strength settings
        bpDict[chest].SetJointStrength(continuousActions[++i]);
        bpDict[spine].SetJointStrength(continuousActions[++i]);
        bpDict[head].SetJointStrength(continuousActions[++i]);
        bpDict[thighL].SetJointStrength(continuousActions[++i]);
        bpDict[shinL].SetJointStrength(continuousActions[++i]);
        bpDict[footL].SetJointStrength(continuousActions[++i]);
        bpDict[thighR].SetJointStrength(continuousActions[++i]);
        bpDict[shinR].SetJointStrength(continuousActions[++i]);
        bpDict[footR].SetJointStrength(continuousActions[++i]);
        bpDict[armL].SetJointStrength(continuousActions[++i]);
        bpDict[forearmL].SetJointStrength(continuousActions[++i]);
        bpDict[armR].SetJointStrength(continuousActions[++i]);
        bpDict[forearmR].SetJointStrength(continuousActions[++i]);
    }

    //Update OrientationCube and DirectionIndicator (the indicator is optional)
    void UpdateOrientationObjects()
    {
        m_OrientationCube.UpdateOrientation(hips, target);
        if (m_DirectionIndicator)
        {
            m_DirectionIndicator.MatchOrientation(m_OrientationCube.transform);
        }
    }

    /// <summary>
    /// Per-physics-step update: refreshes the orientation helpers and adds the shaped reward
    /// <c>matchSpeedReward * lookAtTargetReward</c>. Throws if either term is NaN.
    /// </summary>
    void FixedUpdate()
    {
        UpdateOrientationObjects();

        var cubeForward = m_OrientationCube.transform.forward;

        // Set reward for this step according to mixture of the following elements.
        // a. Match target speed
        //This reward will approach 1 if it matches perfectly and approach zero as it deviates
        var avgVel = GetAvgVelocity();
        var matchSpeedReward = GetMatchingVelocityReward(cubeForward * MTargetWalkingSpeed, avgVel);

        //Check for NaNs; report the inputs that actually fed the reward
        if (float.IsNaN(matchSpeedReward))
        {
            throw new ArgumentException(
                "NaN in matchSpeedReward.\n" +
                $" cubeForward: {cubeForward}\n" +
                $" avgVelocity: {avgVel}\n" +
                $" hips.velocity: {m_JdController.bodyPartsDict[hips].rb.velocity}\n" +
                $" targetWalkingSpeed: {MTargetWalkingSpeed}\n" +
                $" maximumWalkingSpeed: {m_maxWalkingSpeed}"
            );
        }

        // b. Rotation alignment with target direction.
        //This reward will approach 1 if it faces the target direction perfectly and approach zero as it deviates.
        //Only the horizontal component of the head's forward is used (y zeroed, not re-normalized),
        //so the dot product stays within [-1, 1] and the reward within [0, 1].
        var headForward = head.forward;
        headForward.y = 0;
        var lookAtTargetReward = (Vector3.Dot(cubeForward, headForward) + 1) * .5F;

        //Check for NaNs
        if (float.IsNaN(lookAtTargetReward))
        {
            throw new ArgumentException(
                "NaN in lookAtTargetReward.\n" +
                $" cubeForward: {cubeForward}\n" +
                $" head.forward: {head.forward}"
            );
        }

        AddReward(matchSpeedReward * lookAtTargetReward);
    }

    //Returns the average velocity of all of the body parts
    //Using the velocity of the hips only has shown to result in more erratic movement from the limbs, so...
    //...using the average helps prevent this erratic movement
    //Returns Vector3.zero when no body parts are registered (avoids a 0/0 NaN).
    Vector3 GetAvgVelocity()
    {
        Vector3 velSum = Vector3.zero;

        //ALL RBS
        int numOfRb = 0;
        foreach (var item in m_JdController.bodyPartsList)
        {
            numOfRb++;
            velSum += item.rb.velocity;
        }

        if (numOfRb == 0)
        {
            return Vector3.zero;
        }

        return velSum / numOfRb;
    }

    /// <summary>
    /// Normalized value of the difference in avg velocity vs goal velocity.
    /// </summary>
    /// <remarks>
    /// With d = |actualVelocity - velocityGoal| clamped to [0, MTargetWalkingSpeed], returns
    /// (1 - (d / MTargetWalkingSpeed)^2)^2: 1 for a perfect match, falling smoothly to 0 once the
    /// velocity error reaches the target speed.
    /// </remarks>
    /// <param name="velocityGoal">Desired velocity.</param>
    /// <param name="actualVelocity">Observed velocity.</param>
    /// <returns>A reward in [0, 1].</returns>
    public float GetMatchingVelocityReward(Vector3 velocityGoal, Vector3 actualVelocity)
    {
        //distance between our actual velocity and goal velocity
        var velDeltaMagnitude = Mathf.Clamp(Vector3.Distance(actualVelocity, velocityGoal), 0, MTargetWalkingSpeed);

        //return the value on a declining bell-shaped curve that decays from 1 to 0
        //This reward will approach 1 if it matches perfectly and approach zero as it deviates
        return Mathf.Pow(1 - Mathf.Pow(velDeltaMagnitude / MTargetWalkingSpeed, 2), 2);
    }

    /// <summary>
    /// Agent touched the target: adds a reward of 1.
    /// </summary>
    /// <remarks>
    /// NOTE(review): no caller is visible in this file; presumably invoked by a collision/trigger
    /// component on the target — confirm.
    /// </remarks>
    public void TouchedTarget()
    {
        AddReward(1f);
    }
}
|