File size: 3,914 Bytes
05c9ac2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 |
using System.Collections.Generic;
using Unity.MLAgents;
using UnityEngine;
/// <summary>
/// Drives one soccer field: registers agents into per-team multi-agent groups,
/// hands out group rewards when a goal is scored, interrupts episodes on timeout,
/// and respawns agents and the ball.
/// </summary>
public class SoccerEnvController : MonoBehaviour
{
    /// <summary>
    /// Per-agent bookkeeping for an agent participating in this environment.
    /// </summary>
    [System.Serializable]
    public class PlayerInfo
    {
        /// <summary>The agent, assigned in the Inspector.</summary>
        public AgentSoccer Agent;

        /// <summary>
        /// Agent position recorded in <c>Start</c>.
        /// NOTE(review): not read by this class; <c>ResetScene</c> respawns from <c>Agent.initialPos</c> instead.
        /// </summary>
        [HideInInspector]
        public Vector3 StartingPos;

        /// <summary>
        /// Agent rotation recorded in <c>Start</c>.
        /// NOTE(review): not read by this class; <c>ResetScene</c> derives a fresh heading from <c>Agent.rotSign</c>.
        /// </summary>
        [HideInInspector]
        public Quaternion StartingRot;

        /// <summary>Agent Rigidbody cached in <c>Start</c>; its velocities are zeroed on every scene reset.</summary>
        [HideInInspector]
        public Rigidbody Rb;
    }

    // Spawn randomization: the ball gets a uniform X/Z offset in [-k_BallSpawnJitter, k_BallSpawnJitter],
    // agents get an X offset in [-k_AgentSpawnJitterX, k_AgentSpawnJitterX] and a heading magnitude
    // (degrees) in [k_MinSpawnHeading, k_MaxSpawnHeading].
    const float k_BallSpawnJitter = 2.5f;
    const float k_AgentSpawnJitterX = 5f;
    const float k_MinSpawnHeading = 80.0f;
    const float k_MaxSpawnHeading = 100.0f;

    /// <summary>
    /// Maximum number of FixedUpdate steps before both teams' episodes are interrupted and the
    /// scene is reset. A value of 0 or less disables the limit.
    /// </summary>
    [Tooltip("Max Environment Steps")] public int MaxEnvironmentSteps = 25000;

    /// <summary>The ball. Its Rigidbody and starting position are captured in <c>Start</c>.</summary>
    public GameObject ball;

    /// <summary>Rigidbody of <see cref="ball"/>, cached in <c>Start</c>.</summary>
    [HideInInspector]
    public Rigidbody ballRb;

    // Ball position captured in Start; ResetBall respawns around this point.
    Vector3 m_BallStartingPos;

    /// <summary>Agents on this field; each is registered with its team's group in <c>Start</c>.</summary>
    public List<PlayerInfo> AgentsList = new List<PlayerInfo>();

    private SimpleMultiAgentGroup m_BlueAgentGroup;
    private SimpleMultiAgentGroup m_PurpleAgentGroup;

    // Steps since the last ResetScene; drives the timeout and the win reward's time penalty.
    private int m_ResetTimer;

    void Start()
    {
        // One group per team: rewards and episode boundaries are issued per group.
        m_BlueAgentGroup = new SimpleMultiAgentGroup();
        m_PurpleAgentGroup = new SimpleMultiAgentGroup();

        ballRb = ball.GetComponent<Rigidbody>();
        // Vector3 is a value type, so this stores a snapshot of the spawn point.
        m_BallStartingPos = ball.transform.position;

        foreach (var item in AgentsList)
        {
            item.StartingPos = item.Agent.transform.position;
            item.StartingRot = item.Agent.transform.rotation;
            item.Rb = item.Agent.GetComponent<Rigidbody>();

            // Any team other than Blue is registered with the Purple group.
            var group = item.Agent.team == Team.Blue ? m_BlueAgentGroup : m_PurpleAgentGroup;
            group.RegisterAgent(item.Agent);
        }

        ResetScene();
    }

    void FixedUpdate()
    {
        m_ResetTimer += 1;
        if (MaxEnvironmentSteps > 0 && m_ResetTimer >= MaxEnvironmentSteps)
        {
            // Timeout: interrupt both groups' episodes (a scored goal uses EndGroupEpisode instead).
            m_BlueAgentGroup.GroupEpisodeInterrupted();
            m_PurpleAgentGroup.GroupEpisodeInterrupted();
            ResetScene();
        }
    }

    /// <summary>
    /// Moves the ball to its starting position plus a random X/Z offset of up to
    /// <c>k_BallSpawnJitter</c> in each direction, and clears its linear and angular velocity.
    /// </summary>
    public void ResetBall()
    {
        var randomPosX = Random.Range(-k_BallSpawnJitter, k_BallSpawnJitter);
        var randomPosZ = Random.Range(-k_BallSpawnJitter, k_BallSpawnJitter);
        ball.transform.position = m_BallStartingPos + new Vector3(randomPosX, 0f, randomPosZ);
        ballRb.velocity = Vector3.zero;
        ballRb.angularVelocity = Vector3.zero;
    }

    /// <summary>
    /// Awards a goal: the scoring team's group receives <see cref="ComputeWinReward"/>, the other
    /// team's group receives -1, both group episodes end, and the scene is reset.
    /// NOTE(review): the caller is not visible here; presumably a goal trigger elsewhere.
    /// </summary>
    /// <param name="scoredTeam">Team credited with the goal; any value other than Blue credits Purple.</param>
    public void GoalTouched(Team scoredTeam)
    {
        var winReward = ComputeWinReward();
        if (scoredTeam == Team.Blue)
        {
            m_BlueAgentGroup.AddGroupReward(winReward);
            m_PurpleAgentGroup.AddGroupReward(-1f);
        }
        else
        {
            m_PurpleAgentGroup.AddGroupReward(winReward);
            m_BlueAgentGroup.AddGroupReward(-1f);
        }
        m_PurpleAgentGroup.EndGroupEpisode();
        m_BlueAgentGroup.EndGroupEpisode();
        ResetScene();
    }

    /// <summary>
    /// Reward for the scoring team: 1 minus the fraction of the step budget already used,
    /// so faster goals earn more. Returns 1 when no step limit is configured.
    /// </summary>
    float ComputeWinReward()
    {
        if (MaxEnvironmentSteps <= 0)
        {
            // FixedUpdate treats a non-positive limit as "no timeout"; there is no budget to scale
            // by, and dividing by it would yield -Infinity/NaN (or a reward above 1 if negative).
            return 1f;
        }
        return 1f - (float)m_ResetTimer / MaxEnvironmentSteps;
    }

    /// <summary>
    /// Resets the step counter and respawns every agent near its <c>initialPos</c>. Each agent gets
    /// a random X offset and a randomized heading, and its velocities are zeroed. Finally resets the ball.
    /// </summary>
    public void ResetScene()
    {
        m_ResetTimer = 0;

        foreach (var item in AgentsList)
        {
            var randomPosX = Random.Range(-k_AgentSpawnJitterX, k_AgentSpawnJitterX);
            var newStartPos = item.Agent.initialPos + new Vector3(randomPosX, 0f, 0f);
            // rotSign presumably flips the facing direction per team (confirm in AgentSoccer);
            // the yaw magnitude is drawn from [k_MinSpawnHeading, k_MaxSpawnHeading].
            var rot = item.Agent.rotSign * Random.Range(k_MinSpawnHeading, k_MaxSpawnHeading);
            var newRot = Quaternion.Euler(0, rot, 0);
            item.Agent.transform.SetPositionAndRotation(newStartPos, newRot);
            item.Rb.velocity = Vector3.zero;
            item.Rb.angularVelocity = Vector3.zero;
        }

        ResetBall();
    }
}
|