|
using System;
using System.Collections.Generic;
using System.Reflection;
using NUnit.Framework;
using UnityEngine;
|
|
|
namespace Unity.MLAgents.Tests
{
    /// <summary>
    /// Edit-mode tests for <see cref="SimpleMultiAgentGroup"/>: agent registration and
    /// unregistration, group-reward propagation, disposal, and group-id allocation.
    /// </summary>
    public class MultiAgentGroupTests
    {
        // Tolerance for reward comparisons: accumulated float sums (e.g. 0.1f + 0.5f) are not
        // guaranteed to be bit-identical to the corresponding literal (0.6f).
        const float k_RewardTolerance = 1e-6f;

        // GameObjects created by the current test. They are destroyed in TearDown so they do
        // not remain in the scene and influence later tests.
        readonly List<GameObject> m_CreatedGameObjects = new List<GameObject>();

        /// <summary>
        /// Agent subclass that exposes non-public <see cref="Agent"/> state via reflection so the
        /// tests can observe the side effects of group registration.
        /// </summary>
        class TestAgent : Agent
        {
            /// <summary>Current value of the non-public <c>Agent.m_GroupId</c> field.</summary>
            internal int _GroupId => GetAgentField<int>("m_GroupId");

            /// <summary>Current value of the non-public <c>Agent.m_GroupReward</c> field.</summary>
            internal float _GroupReward => GetAgentField<float>("m_GroupReward");

            /// <summary>Current value of the non-public <c>Agent.OnAgentDisabled</c> delegate field.</summary>
            internal Action<Agent> _OnAgentDisabledActions => GetAgentField<Action<Agent>>("OnAgentDisabled");

            // Reads a non-public instance field declared on Agent. If the field has been renamed
            // or removed, the test fails with a descriptive message instead of a
            // NullReferenceException from calling GetValue on a null FieldInfo.
            T GetAgentField<T>(string fieldName)
            {
                var field = typeof(Agent).GetField(fieldName, BindingFlags.Instance | BindingFlags.NonPublic);
                Assert.IsNotNull(field, $"Agent has no non-public instance field named '{fieldName}'.");
                return (T)field.GetValue(this);
            }
        }

        /// <summary>
        /// Destroys every GameObject created through <see cref="CreateTestAgent"/> during the test.
        /// </summary>
        [TearDown]
        public void TearDown()
        {
            foreach (var go in m_CreatedGameObjects)
            {
                if (go != null)
                {
                    // Fully qualified: both System.Object and UnityEngine.Object are in scope.
                    UnityEngine.Object.DestroyImmediate(go);
                }
            }
            m_CreatedGameObjects.Clear();
        }

        // Creates a GameObject with a TestAgent attached and records it for cleanup in TearDown.
        TestAgent CreateTestAgent()
        {
            var agentGo = new GameObject("TestAgent");
            m_CreatedGameObjects.Add(agentGo);
            return agentGo.AddComponent<TestAgent>();
        }

        [Test]
        public void TestRegisteredAgentGroupId()
        {
            var agent = CreateTestAgent();

            // Registering sets the agent's group id and installs an OnAgentDisabled callback.
            var agentGroup1 = new SimpleMultiAgentGroup();
            agentGroup1.RegisterAgent(agent);
            Assert.AreEqual(agentGroup1.GetId(), agent._GroupId);
            Assert.IsNotNull(agent._OnAgentDisabledActions);

            // Registering an agent that already belongs to a group throws, and the existing
            // membership is left unchanged.
            var agentGroup2 = new SimpleMultiAgentGroup();
            Assert.Throws<UnityAgentsException>(() => agentGroup2.RegisterAgent(agent));
            Assert.AreEqual(agentGroup1.GetId(), agent._GroupId);

            // Unregistering resets the group id to 0 and clears the OnAgentDisabled callback.
            agentGroup1.UnregisterAgent(agent);
            Assert.AreEqual(0, agent._GroupId);
            Assert.IsNull(agent._OnAgentDisabledActions);

            // Once unregistered, the agent can join a different group.
            agentGroup2.RegisterAgent(agent);
            Assert.AreEqual(agentGroup2.GetId(), agent._GroupId);
            Assert.IsNotNull(agent._OnAgentDisabledActions);
        }

        [Test]
        public void TestRegisterMultipleAgent()
        {
            var agent1 = CreateTestAgent();
            var agent2 = CreateTestAgent();

            var agentGroup = new SimpleMultiAgentGroup();
            agentGroup.RegisterAgent(agent1);
            Assert.AreEqual(1, agentGroup.GetRegisteredAgents().Count);

            // Unregistering an agent that is not in the group leaves the group unchanged.
            agentGroup.UnregisterAgent(agent2);
            Assert.AreEqual(1, agentGroup.GetRegisteredAgents().Count);
            agentGroup.UnregisterAgent(agent1);
            Assert.AreEqual(0, agentGroup.GetRegisteredAgents().Count);

            // Registering the same agent twice does not add a duplicate entry.
            agentGroup.RegisterAgent(agent1);
            agentGroup.RegisterAgent(agent1);
            Assert.AreEqual(1, agentGroup.GetRegisteredAgents().Count);
            agentGroup.RegisterAgent(agent2);
            Assert.AreEqual(2, agentGroup.GetRegisteredAgents().Count);

            // AddGroupReward accumulates and SetGroupReward overwrites the group reward.
            agentGroup.AddGroupReward(0.1f);
            Assert.AreEqual(0.1f, agent1._GroupReward, k_RewardTolerance);
            agentGroup.AddGroupReward(0.5f);
            Assert.AreEqual(0.6f, agent1._GroupReward, k_RewardTolerance);
            agentGroup.SetGroupReward(0.3f);
            Assert.AreEqual(0.3f, agent1._GroupReward, k_RewardTolerance);

            // Rewards added after an agent leaves the group only reach the remaining members.
            agentGroup.UnregisterAgent(agent1);
            agentGroup.AddGroupReward(0.2f);
            Assert.AreEqual(0.3f, agent1._GroupReward, k_RewardTolerance);
            Assert.AreEqual(0.5f, agent2._GroupReward, k_RewardTolerance);

            // Disposing the group resets the group id of the agents.
            agentGroup.Dispose();
            Assert.AreEqual(0, agent1._GroupId);
            Assert.AreEqual(0, agent2._GroupId);
        }

        [Test]
        public void TestGroupIdCounter()
        {
            // Each new group receives a distinct id.
            var group1 = new SimpleMultiAgentGroup();
            var group2 = new SimpleMultiAgentGroup();

            Assert.AreNotEqual(group1.GetId(), group2.GetId());
        }
    }
}
|
|