File size: 3,239 Bytes
05c9ac2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 |
using System;
using System.Linq;
using UnityEngine;
using Random = UnityEngine.Random;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;
/// <summary>
/// ML-Agents agent that moves with a single discrete action branch, is charged a
/// small per-step penalty, and ends its episode with a reward of 2 when it
/// collides with an object tagged "goal".
/// </summary>
public class PyramidAgent : Agent
{
    // Number of distinct indices drawn (without repetition) at each episode start.
    // NOTE(review): presumably these are spawn-area indices understood by
    // PyramidArea / PyramidSwitch -- confirm against those classes.
    const int k_SpawnIndexCount = 9;

    // Shuffled slots [0] (agent) and [1], [2] (switch) are consumed first; the
    // remaining slots each receive one stone pyramid.
    const int k_FirstStonePyramidSlot = 3;

    /// <summary>Object carrying the <c>PyramidArea</c> component this agent lives in.</summary>
    public GameObject area;
    PyramidArea m_MyArea;
    Rigidbody m_AgentRb;
    PyramidSwitch m_SwitchLogic;

    /// <summary>Object carrying the <c>PyramidSwitch</c> component for this area.</summary>
    public GameObject areaSwitch;

    /// <summary>When true, the switch state and local velocity are added as vector observations.</summary>
    public bool useVectorObs;

    /// <summary>
    /// Caches the agent's <c>Rigidbody</c> and the <c>PyramidArea</c> / <c>PyramidSwitch</c>
    /// components found on <see cref="area"/> and <see cref="areaSwitch"/>.
    /// </summary>
    public override void Initialize()
    {
        m_AgentRb = GetComponent<Rigidbody>();
        m_MyArea = area.GetComponent<PyramidArea>();
        m_SwitchLogic = areaSwitch.GetComponent<PyramidSwitch>();
    }

    /// <summary>
    /// Adds the switch state and the agent's velocity (expressed in the agent's
    /// local frame) to <paramref name="sensor"/>, but only when
    /// <see cref="useVectorObs"/> is enabled.
    /// </summary>
    /// <param name="sensor">Vector sensor that receives the observations.</param>
    public override void CollectObservations(VectorSensor sensor)
    {
        if (!useVectorObs)
        {
            return;
        }

        sensor.AddObservation(m_SwitchLogic.GetState());
        sensor.AddObservation(transform.InverseTransformDirection(m_AgentRb.velocity));
    }

    /// <summary>
    /// Applies the first discrete action to the agent:
    /// 1 = push forward, 2 = push backward, 3 = rotate about +up, 4 = rotate about -up.
    /// Any other value leaves both direction vectors at zero.
    /// </summary>
    /// <param name="act">Discrete action segment; only index 0 is read.</param>
    public void MoveAgent(ActionSegment<int> act)
    {
        var dirToGo = Vector3.zero;
        var rotateDir = Vector3.zero;

        var action = act[0];
        switch (action)
        {
            case 1:
                dirToGo = transform.forward * 1f;
                break;
            case 2:
                dirToGo = transform.forward * -1f;
                break;
            case 3:
                rotateDir = transform.up * 1f;
                break;
            case 4:
                rotateDir = transform.up * -1f;
                break;
        }

        transform.Rotate(rotateDir, Time.deltaTime * 200f);
        m_AgentRb.AddForce(dirToGo * 2f, ForceMode.VelocityChange);
    }

    /// <summary>
    /// Charges the per-step time penalty (the penalties sum to -1 over
    /// <c>MaxStep</c> steps) and then moves the agent.
    /// </summary>
    /// <param name="actionBuffers">Actions chosen for this step.</param>
    public override void OnActionReceived(ActionBuffers actionBuffers)
    {
        // A MaxStep of 0 would make -1f / MaxStep evaluate to -Infinity, so the
        // penalty is skipped in that case instead of adding a non-finite reward.
        if (MaxStep > 0)
        {
            AddReward(-1f / MaxStep);
        }

        MoveAgent(actionBuffers.DiscreteActions);
    }

    /// <summary>
    /// Keyboard control: D = 3, W = 1, A = 4, S = 2 (checked in that order, the
    /// first held key wins). With no key held the action is 0.
    /// </summary>
    /// <param name="actionsOut">Buffer whose first discrete action is written.</param>
    public override void Heuristic(in ActionBuffers actionsOut)
    {
        var discreteActionsOut = actionsOut.DiscreteActions;

        // Write the no-op explicitly so the result does not depend on the buffer
        // having been cleared by the caller.
        discreteActionsOut[0] = 0;

        if (Input.GetKey(KeyCode.D))
        {
            discreteActionsOut[0] = 3;
        }
        else if (Input.GetKey(KeyCode.W))
        {
            discreteActionsOut[0] = 1;
        }
        else if (Input.GetKey(KeyCode.A))
        {
            discreteActionsOut[0] = 4;
        }
        else if (Input.GetKey(KeyCode.S))
        {
            discreteActionsOut[0] = 2;
        }
    }

    /// <summary>
    /// Resets the area for a new episode: draws a random permutation of
    /// <c>0..8</c>, cleans the area, zeroes the agent's velocity, places the agent
    /// using the first index with a random yaw, resets the switch with the next
    /// two indices, and creates one stone pyramid for each remaining index.
    /// </summary>
    public override void OnEpisodeBegin()
    {
        var items = ShuffledIndices(k_SpawnIndexCount);

        m_MyArea.CleanPyramidArea();

        m_AgentRb.velocity = Vector3.zero;
        m_MyArea.PlaceObject(gameObject, items[0]);
        transform.rotation = Quaternion.Euler(new Vector3(0f, Random.Range(0, 360)));

        m_SwitchLogic.ResetSwitch(items[1], items[2]);

        for (var slot = k_FirstStonePyramidSlot; slot < items.Length; slot++)
        {
            m_MyArea.CreateStonePyramid(1, items[slot]);
        }
    }

    /// <summary>
    /// Returns the integers <c>0 .. count-1</c> in random order (Fisher-Yates).
    /// Randomness comes from <c>UnityEngine.Random</c>, the same generator used
    /// for the spawn rotation, so seeding that generator makes layouts repeatable.
    /// </summary>
    /// <param name="count">Number of indices to produce.</param>
    /// <returns>A newly allocated, shuffled array of length <paramref name="count"/>.</returns>
    static int[] ShuffledIndices(int count)
    {
        var indices = new int[count];
        for (var i = 0; i < count; i++)
        {
            indices[i] = i;
        }

        for (var i = count - 1; i > 0; i--)
        {
            // Integer Random.Range excludes the upper bound, hence i + 1.
            var j = Random.Range(0, i + 1);
            var swapped = indices[i];
            indices[i] = indices[j];
            indices[j] = swapped;
        }

        return indices;
    }

    /// <summary>
    /// Ends the episode with a total reward of 2 when the agent collides with an
    /// object tagged "goal".
    /// </summary>
    /// <param name="collision">Collision data supplied by Unity.</param>
    void OnCollisionEnter(Collision collision)
    {
        if (collision.gameObject.CompareTag("goal"))
        {
            SetReward(2f);
            EndEpisode();
        }
    }
}
|