|
using System; |
|
using System.Linq; |
|
using UnityEngine; |
|
using Random = UnityEngine.Random; |
|
using Unity.MLAgents; |
|
using Unity.MLAgents.Actuators; |
|
using Unity.MLAgents.Sensors; |
|
|
|
/// <summary>
/// ML-Agents agent that moves around a <see cref="PyramidArea"/> using a single
/// discrete action branch and is rewarded when it collides with an object
/// tagged "goal".
/// </summary>
public class PyramidAgent : Agent
{
    /// <summary>Degrees per second applied when a turn action is selected.</summary>
    const float k_TurnSpeed = 200f;

    /// <summary>Velocity change applied per step when a move action is selected.</summary>
    const float k_MoveSpeed = 2f;

    /// <summary>Number of distinct indices shuffled at the start of every episode.</summary>
    const int k_SpawnIndexCount = 9;

    /// <summary>Reward assigned when the agent touches an object tagged <c>goal</c>.</summary>
    const float k_GoalReward = 2f;

    /// <summary>Tag that marks the object ending the episode on contact.</summary>
    const string k_GoalTag = "goal";

    /// <summary>Object carrying the <see cref="PyramidArea"/> component this agent lives in.</summary>
    public GameObject area;

    PyramidArea m_MyArea;

    Rigidbody m_AgentRb;

    PyramidSwitch m_SwitchLogic;

    /// <summary>Object carrying the <see cref="PyramidSwitch"/> component for this area.</summary>
    public GameObject areaSwitch;

    /// <summary>When true, the switch state and the agent's local velocity are added as vector observations.</summary>
    public bool useVectorObs;

    /// <summary>
    /// Caches the agent's <see cref="Rigidbody"/> and the components found on
    /// <see cref="area"/> and <see cref="areaSwitch"/>.
    /// </summary>
    public override void Initialize()
    {
        m_AgentRb = GetComponent<Rigidbody>();
        m_MyArea = area.GetComponent<PyramidArea>();
        m_SwitchLogic = areaSwitch.GetComponent<PyramidSwitch>();
    }

    /// <summary>
    /// Adds the switch state and the agent's velocity (expressed in the agent's
    /// local space) to <paramref name="sensor"/>. Does nothing when
    /// <see cref="useVectorObs"/> is false.
    /// </summary>
    /// <param name="sensor">Vector sensor receiving the observations.</param>
    public override void CollectObservations(VectorSensor sensor)
    {
        if (!useVectorObs)
        {
            return;
        }

        sensor.AddObservation(m_SwitchLogic.GetState());
        sensor.AddObservation(transform.InverseTransformDirection(m_AgentRb.velocity));
    }

    /// <summary>
    /// Applies the first discrete action to the agent:
    /// 1 = forward, 2 = backward, 3 = turn around +up, 4 = turn around -up.
    /// Any other value leaves both the move and turn vectors at zero.
    /// </summary>
    /// <param name="act">Discrete action segment; only index 0 is read.</param>
    public void MoveAgent(ActionSegment<int> act)
    {
        var dirToGo = Vector3.zero;
        var rotateDir = Vector3.zero;

        switch (act[0])
        {
            case 1:
                dirToGo = transform.forward;
                break;
            case 2:
                dirToGo = -transform.forward;
                break;
            case 3:
                rotateDir = transform.up;
                break;
            case 4:
                rotateDir = -transform.up;
                break;
        }

        transform.Rotate(rotateDir, Time.deltaTime * k_TurnSpeed);
        m_AgentRb.AddForce(dirToGo * k_MoveSpeed, ForceMode.VelocityChange);
    }

    /// <summary>
    /// Applies a small per-step time penalty (summing to -1 over
    /// <c>MaxStep</c> steps) and then moves the agent.
    /// </summary>
    /// <param name="actionBuffers">Actions chosen for this step.</param>
    public override void OnActionReceived(ActionBuffers actionBuffers)
    {
        // MaxStep == 0 would make the division produce -Infinity; skip the
        // penalty in that case instead of poisoning the reward signal.
        if (MaxStep > 0)
        {
            AddReward(-1f / MaxStep);
        }

        MoveAgent(actionBuffers.DiscreteActions);
    }

    /// <summary>
    /// Keyboard control: D = 3, W = 1, A = 4, S = 2 (checked in that order, the
    /// first held key wins). Writes 0 when none of those keys is held.
    /// </summary>
    /// <param name="actionsOut">Buffer whose first discrete action is written.</param>
    public override void Heuristic(in ActionBuffers actionsOut)
    {
        // Always write a value so the outcome never depends on whatever the
        // buffer previously contained.
        var chosen = 0;
        if (Input.GetKey(KeyCode.D))
        {
            chosen = 3;
        }
        else if (Input.GetKey(KeyCode.W))
        {
            chosen = 1;
        }
        else if (Input.GetKey(KeyCode.A))
        {
            chosen = 4;
        }
        else if (Input.GetKey(KeyCode.S))
        {
            chosen = 2;
        }

        var discreteActionsOut = actionsOut.DiscreteActions;
        discreteActionsOut[0] = chosen;
    }

    /// <summary>
    /// Resets the area for a new episode: cleans it, zeroes the agent's velocity,
    /// then hands nine distinct shuffled indices (0..8) to the placement calls.
    /// Index 0 of the shuffle goes to the agent, 1 and 2 to the switch reset,
    /// and 3..8 to six stone pyramids.
    /// </summary>
    public override void OnEpisodeBegin()
    {
        var items = ShuffledIndices(k_SpawnIndexCount);

        m_MyArea.CleanPyramidArea();

        m_AgentRb.velocity = Vector3.zero;
        m_MyArea.PlaceObject(gameObject, items[0]);
        transform.rotation = Quaternion.Euler(new Vector3(0f, Random.Range(0, 360)));

        m_SwitchLogic.ResetSwitch(items[1], items[2]);

        // NOTE(review): the first argument (1) is passed through unchanged from
        // the original calls; its meaning is defined by PyramidArea.
        for (var slot = 3; slot < items.Length; slot++)
        {
            m_MyArea.CreateStonePyramid(1, items[slot]);
        }
    }

    /// <summary>
    /// Returns the integers 0..count-1 in random order (Fisher-Yates), drawing
    /// from <c>UnityEngine.Random</c> so the permutation follows Unity's random
    /// state rather than an unseedable source.
    /// </summary>
    static int[] ShuffledIndices(int count)
    {
        var indices = new int[count];
        for (var i = 0; i < count; i++)
        {
            indices[i] = i;
        }

        for (var i = count - 1; i > 0; i--)
        {
            // Integer Random.Range excludes the upper bound, hence i + 1.
            var j = Random.Range(0, i + 1);
            var tmp = indices[i];
            indices[i] = indices[j];
            indices[j] = tmp;
        }

        return indices;
    }

    /// <summary>
    /// On contact with an object tagged "goal", sets the reward to 2 and ends
    /// the episode.
    /// </summary>
    void OnCollisionEnter(Collision collision)
    {
        if (!collision.gameObject.CompareTag(k_GoalTag))
        {
            return;
        }

        SetReward(k_GoalReward);
        EndEpisode();
    }
}
|
|