|
using UnityEngine; |
|
using Unity.MLAgents; |
|
using Unity.MLAgents.Actuators; |
|
using Unity.MLAgents.Policies; |
|
|
|
/// <summary>
/// Identifies which side an agent plays for. The numeric values are compared
/// against <c>BehaviorParameters.TeamId</c> in <c>AgentSoccer.Initialize</c>.
/// </summary>
public enum Team
{
    /// <summary>Blue side (value 0).</summary>
    Blue = 0,

    /// <summary>Purple side (value 1).</summary>
    Purple = 1
}
|
|
|
/// <summary>
/// ML-Agents soccer player. Translates three discrete action branches into
/// movement and rotation, applies a per-step role-dependent reward, and kicks
/// any object tagged "ball" that it collides with.
/// </summary>
public class AgentSoccer : Agent
{
    /// <summary>
    /// Role of the agent. The role selects the movement speeds set in
    /// <see cref="Initialize"/>, the sign of the per-step reward in
    /// <see cref="OnActionReceived"/>, and the kick strength in
    /// <c>OnCollisionEnter</c>.
    /// </summary>
    public enum Position
    {
        Striker,
        Goalie,
        Generic
    }

    /// <summary>Team of this agent; assigned in <see cref="Initialize"/> from the behavior's TeamId.</summary>
    [HideInInspector]
    public Team team;

    // 1 when the most recent action selected forward movement, otherwise 0.
    // Scales the kick force for non-goalie agents.
    float m_KickPower;

    // Multiplier on the ball-touch reward; refreshed from the "ball_touch"
    // environment parameter at the start of every episode.
    float m_BallTouch;

    /// <summary>Role of this agent.</summary>
    public Position position;

    // Base magnitude of the force applied to the ball on contact.
    const float k_Power = 2000f;

    // Magnitude of the per-step reward (1 / max steps, or 0 when no step limit is set).
    float m_Existential;

    // Multipliers applied to sideways and forward/backward movement.
    float m_LateralSpeed;
    float m_ForwardSpeed;

    /// <summary>Rigidbody of this agent; resolved in <see cref="Initialize"/>.</summary>
    [HideInInspector]
    public Rigidbody agentRb;

    SoccerSettings m_SoccerSettings;
    BehaviorParameters m_BehaviorParameters;

    /// <summary>
    /// Position computed in <see cref="Initialize"/>: the transform position shifted by
    /// -5 (Blue) or +5 (other team) on x, with y fixed at 0.5. Not read inside this class.
    /// </summary>
    public Vector3 initialPos;

    /// <summary>+1 for the Blue team, -1 otherwise. Not read inside this class.</summary>
    public float rotSign;

    EnvironmentParameters m_ResetParams;

    /// <summary>
    /// Caches component references and derives the per-step reward magnitude,
    /// team, start position, and role-specific movement speeds.
    /// </summary>
    public override void Initialize()
    {
        // Prefer the step budget of a parent SoccerEnvController; fall back to
        // the agent's own MaxStep. A budget of 0 would make 1/steps infinite,
        // so the per-step reward is disabled in that case.
        SoccerEnvController envController = GetComponentInParent<SoccerEnvController>();
        if (envController != null)
        {
            m_Existential = envController.MaxEnvironmentSteps != 0
                ? 1f / envController.MaxEnvironmentSteps
                : 0f;
        }
        else
        {
            m_Existential = MaxStep != 0 ? 1f / MaxStep : 0f;
        }

        m_BehaviorParameters = gameObject.GetComponent<BehaviorParameters>();
        if (m_BehaviorParameters.TeamId == (int)Team.Blue)
        {
            team = Team.Blue;
            initialPos = new Vector3(transform.position.x - 5f, .5f, transform.position.z);
            rotSign = 1f;
        }
        else
        {
            // Any TeamId other than Blue's is treated as Purple.
            team = Team.Purple;
            initialPos = new Vector3(transform.position.x + 5f, .5f, transform.position.z);
            rotSign = -1f;
        }

        switch (position)
        {
            case Position.Goalie:
                m_LateralSpeed = 1.0f;
                m_ForwardSpeed = 1.0f;
                break;
            case Position.Striker:
                m_LateralSpeed = 0.3f;
                m_ForwardSpeed = 1.3f;
                break;
            default:
                m_LateralSpeed = 0.3f;
                m_ForwardSpeed = 1.0f;
                break;
        }

        m_SoccerSettings = FindObjectOfType<SoccerSettings>();
        agentRb = GetComponent<Rigidbody>();
        agentRb.maxAngularVelocity = 500;

        m_ResetParams = Academy.Instance.EnvironmentParameters;
    }

    /// <summary>
    /// Applies one step of movement and rotation from the discrete actions.
    /// </summary>
    /// <param name="act">
    /// Discrete actions: index 0 is forward (1) / backward (2), index 1 is
    /// movement along +right (1) / -right (2), index 2 is rotation about
    /// -up (1) / +up (2). Any other value leaves that axis idle.
    /// </param>
    public void MoveAgent(ActionSegment<int> act)
    {
        var dirToGo = Vector3.zero;
        var rotateDir = Vector3.zero;

        // Kick power is only non-zero for the step in which the agent moves forward.
        m_KickPower = 0f;

        var forwardAxis = act[0];
        var rightAxis = act[1];
        var rotateAxis = act[2];

        switch (forwardAxis)
        {
            case 1:
                dirToGo = transform.forward * m_ForwardSpeed;
                m_KickPower = 1f;
                break;
            case 2:
                dirToGo = transform.forward * -m_ForwardSpeed;
                break;
        }

        // Note: a lateral action replaces (does not add to) the forward
        // direction chosen above; m_KickPower is left as set.
        switch (rightAxis)
        {
            case 1:
                dirToGo = transform.right * m_LateralSpeed;
                break;
            case 2:
                dirToGo = transform.right * -m_LateralSpeed;
                break;
        }

        switch (rotateAxis)
        {
            case 1:
                rotateDir = transform.up * -1f;
                break;
            case 2:
                rotateDir = transform.up * 1f;
                break;
        }

        transform.Rotate(rotateDir, Time.deltaTime * 100f);
        agentRb.AddForce(dirToGo * m_SoccerSettings.agentRunSpeed,
            ForceMode.VelocityChange);
    }

    /// <summary>
    /// Adds the per-step reward for the agent's role (positive for a Goalie,
    /// negative for a Striker, none for Generic) and then moves the agent.
    /// </summary>
    /// <param name="actionBuffers">Actions for this step; only the discrete actions are used.</param>
    public override void OnActionReceived(ActionBuffers actionBuffers)
    {
        if (position == Position.Goalie)
        {
            AddReward(m_Existential);
        }
        else if (position == Position.Striker)
        {
            AddReward(-m_Existential);
        }

        MoveAgent(actionBuffers.DiscreteActions);
    }

    /// <summary>
    /// Keyboard control: W/S write 1/2 to branch 0, E/Q write 1/2 to branch 1,
    /// A/D write 1/2 to branch 2. When both keys of a pair are held, the
    /// second one listed wins.
    /// </summary>
    /// <param name="actionsOut">Buffer whose discrete actions are filled in.</param>
    public override void Heuristic(in ActionBuffers actionsOut)
    {
        var discreteActionsOut = actionsOut.DiscreteActions;

        // Start from "no action" on every branch so a reused buffer cannot
        // carry a previous step's key press into this one.
        discreteActionsOut[0] = 0;
        discreteActionsOut[1] = 0;
        discreteActionsOut[2] = 0;

        if (Input.GetKey(KeyCode.W))
        {
            discreteActionsOut[0] = 1;
        }
        if (Input.GetKey(KeyCode.S))
        {
            discreteActionsOut[0] = 2;
        }

        if (Input.GetKey(KeyCode.A))
        {
            discreteActionsOut[2] = 1;
        }
        if (Input.GetKey(KeyCode.D))
        {
            discreteActionsOut[2] = 2;
        }

        if (Input.GetKey(KeyCode.E))
        {
            discreteActionsOut[1] = 1;
        }
        if (Input.GetKey(KeyCode.Q))
        {
            discreteActionsOut[1] = 2;
        }
    }

    /// <summary>
    /// On contact with an object tagged "ball": adds the ball-touch reward and
    /// pushes the ball away from the agent along the direction from the agent
    /// to the first contact point.
    /// </summary>
    /// <param name="c">Collision data supplied by the physics engine.</param>
    void OnCollisionEnter(Collision c)
    {
        if (!c.gameObject.CompareTag("ball"))
        {
            return;
        }

        // Goalies always kick at full power; other roles only while moving forward.
        var force = position == Position.Goalie ? k_Power : k_Power * m_KickPower;

        AddReward(.2f * m_BallTouch);

        // GetContact avoids the array allocation made by the `contacts` property.
        var dir = (c.GetContact(0).point - transform.position).normalized;
        c.gameObject.GetComponent<Rigidbody>().AddForce(dir * force);
    }

    /// <summary>
    /// Refreshes the ball-touch reward multiplier from the "ball_touch"
    /// environment parameter, defaulting to 0 when it is not set.
    /// </summary>
    public override void OnEpisodeBegin()
    {
        m_BallTouch = m_ResetParams.GetWithDefault("ball_touch", 0);
    }
}
|
|