|
using System; |
|
using System.Collections; |
|
using System.Collections.Generic; |
|
using UnityEngine; |
|
using UnityEngine.Profiling; |
|
|
|
namespace Unity.MLAgents.Actuators
{
    /// <summary>
    /// Manages a list of <see cref="IActuator"/>s. It does three things:
    /// combines their action specs into a single action space,
    /// owns the buffers that hold the most recently received actions and the shared discrete action mask,
    /// and hands each actuator its own slice of those buffers.
    /// </summary>
    /// <remarks>
    /// The actuator list may only be modified (Add, Insert, Remove, RemoveAt, Clear, indexer set) until the
    /// buffers are initialized. Initialization happens lazily on the first call to
    /// <see cref="GetCombinedActionSpec"/>, <see cref="UpdateActions"/>, <see cref="WriteActionMask"/>,
    /// <see cref="ApplyHeuristic"/> or <see cref="ExecuteActions"/>. At that point the actuators are sorted
    /// by name, so the layout of the combined buffers does not depend on insertion order.
    /// </remarks>
    internal class ActuatorManager : IList<IActuator>
    {
        // The managed actuators. Sorted by name when the buffers are initialized.
        readonly List<IActuator> m_Actuators;

        // Discrete action mask shared by all actuators; each actuator writes to its own branch range
        // (see WriteActionMask).
        ActuatorDiscreteActionMask m_DiscreteActionMask;

        // Concatenation of every actuator's ActionSpec, in m_Actuators order.
        ActionSpec m_CombinedActionSpec;

        // True once the buffers have been allocated; the actuator list must not change afterwards.
        bool m_ReadyForExecution;

        /// <summary>
        /// Sum of <see cref="ActionSpec.SumOfDiscreteBranchSizes"/> over all actuators.
        /// </summary>
        internal int SumOfDiscreteBranchSizes { get; private set; }

        /// <summary>
        /// Total number of discrete actions (branches) over all actuators.
        /// </summary>
        internal int NumDiscreteActions { get; private set; }

        /// <summary>
        /// Total number of continuous actions over all actuators.
        /// </summary>
        internal int NumContinuousActions { get; private set; }

        /// <summary>
        /// Number of continuous actions plus number of discrete actions over all actuators.
        /// </summary>
        public int TotalNumberOfActions => NumContinuousActions + NumDiscreteActions;

        /// <summary>
        /// The discrete action mask shared by all actuators. Not assigned until the buffers are initialized.
        /// </summary>
        public ActuatorDiscreteActionMask DiscreteActionMask => m_DiscreteActionMask;

        /// <summary>
        /// The actions most recently copied in by <see cref="UpdateActions"/>, laid out in actuator order.
        /// Not assigned until the buffers are initialized.
        /// </summary>
        public ActionBuffers StoredActions { get; private set; }

        /// <summary>
        /// Creates an empty manager.
        /// </summary>
        /// <param name="capacity">Initial capacity of the internal actuator list.</param>
        public ActuatorManager(int capacity = 0)
        {
            m_Actuators = new List<IActuator>(capacity);
        }

        /// <summary>
        /// Initializes the buffers from this manager's own actuator list and cached sizes.
        /// Does nothing if they are already initialized.
        /// </summary>
        void ReadyActuatorsForExecution()
        {
            ReadyActuatorsForExecution(m_Actuators, NumContinuousActions, SumOfDiscreteBranchSizes,
                NumDiscreteActions);
        }

        /// <summary>
        /// One-time initialization of the buffers. The steps are:
        /// sort the actuators by name, validate them (DEBUG builds only),
        /// allocate <see cref="StoredActions"/>, and build the combined action spec and the discrete action mask.
        /// Subsequent calls are no-ops.
        /// </summary>
        /// <param name="actuators">Actuators used to build the combined spec and the mask.</param>
        /// <param name="numContinuousActions">Length of the continuous action buffer.</param>
        /// <param name="sumOfDiscreteBranches">Sum of all discrete branch sizes (passed to the mask).</param>
        /// <param name="numDiscreteBranches">Length of the discrete action buffer.</param>
        internal void ReadyActuatorsForExecution(IList<IActuator> actuators, int numContinuousActions, int sumOfDiscreteBranches, int numDiscreteBranches)
        {
            if (m_ReadyForExecution)
            {
                return;
            }

            // Sort by name so the buffer layout is deterministic regardless of insertion order.
            SortActuators(m_Actuators);
#if DEBUG
            // Make sure the names are actually unique.
            ValidateActuators();
#endif

            var continuousActions = numContinuousActions == 0
                ? ActionSegment<float>.Empty
                : new ActionSegment<float>(new float[numContinuousActions]);
            var discreteActions = numDiscreteBranches == 0
                ? ActionSegment<int>.Empty
                : new ActionSegment<int>(new int[numDiscreteBranches]);

            StoredActions = new ActionBuffers(continuousActions, discreteActions);
            m_CombinedActionSpec = CombineActionSpecs(actuators);
            m_DiscreteActionMask = new ActuatorDiscreteActionMask(actuators, sumOfDiscreteBranches, numDiscreteBranches, m_CombinedActionSpec.BranchSizes);
            m_ReadyForExecution = true;
        }

        /// <summary>
        /// Builds a single <see cref="ActionSpec"/> from a list of actuators.
        /// Continuous action counts are summed and branch sizes are concatenated in list order.
        /// </summary>
        /// <param name="actuators">The actuators whose specs are combined.</param>
        /// <returns>The combined action spec.</returns>
        internal static ActionSpec CombineActionSpecs(IList<IActuator> actuators)
        {
            int numContinuousActions = 0;
            int numDiscreteActions = 0;

            foreach (var actuator in actuators)
            {
                numContinuousActions += actuator.ActionSpec.NumContinuousActions;
                numDiscreteActions += actuator.ActionSpec.NumDiscreteActions;
            }

            int[] combinedBranchSizes;
            if (numDiscreteActions == 0)
            {
                combinedBranchSizes = Array.Empty<int>();
            }
            else
            {
                combinedBranchSizes = new int[numDiscreteActions];
                var start = 0;
                for (var i = 0; i < actuators.Count; i++)
                {
                    var branchSizes = actuators[i].ActionSpec.BranchSizes;
                    if (branchSizes != null)
                    {
                        Array.Copy(branchSizes, 0, combinedBranchSizes, start, branchSizes.Length);
                        start += branchSizes.Length;
                    }
                }
            }

            return new ActionSpec(numContinuousActions, combinedBranchSizes);
        }

        /// <summary>
        /// Returns the action spec formed by combining every actuator's spec.
        /// Initializes the buffers if needed.
        /// </summary>
        /// <returns>The combined action spec.</returns>
        public ActionSpec GetCombinedActionSpec()
        {
            ReadyActuatorsForExecution();
            return m_CombinedActionSpec;
        }

        /// <summary>
        /// Copies <paramref name="actions"/> into <see cref="StoredActions"/>.
        /// Initializes the buffers if needed.
        /// An empty source segment clears the corresponding stored segment.
        /// </summary>
        /// <param name="actions">The actions to store.</param>
        public void UpdateActions(ActionBuffers actions)
        {
            Profiler.BeginSample("ActuatorManager.UpdateActions");
            ReadyActuatorsForExecution();
            UpdateActionArray(actions.ContinuousActions, StoredActions.ContinuousActions);
            UpdateActionArray(actions.DiscreteActions, StoredActions.DiscreteActions);
            Profiler.EndSample();
        }

        /// <summary>
        /// Copies <paramref name="sourceActionBuffer"/> into <paramref name="destination"/>.
        /// If the source is empty, the destination is cleared instead.
        /// </summary>
        /// <remarks>
        /// If the lengths differ, an assertion is logged. Only the overlapping prefix is copied, and any
        /// remaining destination elements are cleared. Copying destination.Length elements from a shorter
        /// source would make Array.Copy throw.
        /// </remarks>
        static void UpdateActionArray<T>(ActionSegment<T> sourceActionBuffer, ActionSegment<T> destination)
            where T : struct
        {
            if (sourceActionBuffer.Length <= 0)
            {
                destination.Clear();
                return;
            }

            // Guarded so the format arguments are only boxed when there is actually a mismatch.
            if (sourceActionBuffer.Length != destination.Length)
            {
                Debug.AssertFormat(sourceActionBuffer.Length == destination.Length,
                    "sourceActionBuffer: {0} is a different size than destination: {1}.",
                    sourceActionBuffer.Length,
                    destination.Length);
            }

            var copyLength = Math.Min(sourceActionBuffer.Length, destination.Length);
            Array.Copy(sourceActionBuffer.Array,
                sourceActionBuffer.Offset,
                destination.Array,
                destination.Offset,
                copyLength);

            // Don't leave stale values behind if the source was shorter than the destination.
            if (copyLength < destination.Length)
            {
                Array.Clear(destination.Array, destination.Offset + copyLength, destination.Length - copyLength);
            }
        }

        /// <summary>
        /// Resets the shared discrete action mask, then lets each actuator with discrete actions write its
        /// own mask entries. Before each write, <c>CurrentBranchOffset</c> is set to that actuator's first branch.
        /// Initializes the buffers if needed.
        /// </summary>
        public void WriteActionMask()
        {
            ReadyActuatorsForExecution();
            m_DiscreteActionMask.ResetMask();
            var offset = 0;
            for (var i = 0; i < m_Actuators.Count; i++)
            {
                var actuator = m_Actuators[i];
                if (actuator.ActionSpec.NumDiscreteActions > 0)
                {
                    m_DiscreteActionMask.CurrentBranchOffset = offset;
                    actuator.WriteDiscreteActionMask(m_DiscreteActionMask);
                    offset += actuator.ActionSpec.NumDiscreteActions;
                }
            }
        }

        /// <summary>
        /// Calls <see cref="IActuator.Heuristic"/> on every actuator that has at least one action.
        /// Each call receives a view into its own slice of <paramref name="actionBuffersOut"/>, laid out in the
        /// same (sorted) actuator order used by <see cref="ExecuteActions"/>.
        /// Initializes the buffers if needed.
        /// </summary>
        /// <param name="actionBuffersOut">The buffers the heuristics write into.</param>
        public void ApplyHeuristic(in ActionBuffers actionBuffersOut)
        {
            Profiler.BeginSample("ActuatorManager.ApplyHeuristic");
            // Ensure the actuators are sorted. Otherwise the heuristic layout would follow insertion order
            // while the combined spec and ExecuteActions use sorted order.
            ReadyActuatorsForExecution();
            var continuousStart = 0;
            var discreteStart = 0;
            for (var i = 0; i < m_Actuators.Count; i++)
            {
                var actuator = m_Actuators[i];
                var actionSpec = actuator.ActionSpec;
                var numContinuousActions = actionSpec.NumContinuousActions;
                var numDiscreteActions = actionSpec.NumDiscreteActions;

                if (numContinuousActions == 0 && numDiscreteActions == 0)
                {
                    continue;
                }

                var continuousActions = SliceActions(actionBuffersOut.ContinuousActions, continuousStart, numContinuousActions);
                var discreteActions = SliceActions(actionBuffersOut.DiscreteActions, discreteStart, numDiscreteActions);

                actuator.Heuristic(new ActionBuffers(continuousActions, discreteActions));
                continuousStart += numContinuousActions;
                discreteStart += numDiscreteActions;
            }
            Profiler.EndSample();
        }

        /// <summary>
        /// Calls <see cref="IActuator.OnActionReceived"/> on every actuator that has at least one action,
        /// passing a view into its own slice of <see cref="StoredActions"/>.
        /// Initializes the buffers if needed.
        /// </summary>
        public void ExecuteActions()
        {
            Profiler.BeginSample("ActuatorManager.ExecuteActions");
            ReadyActuatorsForExecution();
            var continuousStart = 0;
            var discreteStart = 0;
            for (var i = 0; i < m_Actuators.Count; i++)
            {
                var actuator = m_Actuators[i];
                var actionSpec = actuator.ActionSpec;
                var numContinuousActions = actionSpec.NumContinuousActions;
                var numDiscreteActions = actionSpec.NumDiscreteActions;

                if (numContinuousActions == 0 && numDiscreteActions == 0)
                {
                    continue;
                }

                var continuousActions = SliceActions(StoredActions.ContinuousActions, continuousStart, numContinuousActions);
                var discreteActions = SliceActions(StoredActions.DiscreteActions, discreteStart, numDiscreteActions);

                actuator.OnActionReceived(new ActionBuffers(continuousActions, discreteActions));
                continuousStart += numContinuousActions;
                discreteStart += numDiscreteActions;
            }
            Profiler.EndSample();
        }

        /// <summary>
        /// Returns a view of <paramref name="length"/> elements of <paramref name="source"/>, starting
        /// <paramref name="start"/> elements past the segment's own offset.
        /// Returns <see cref="ActionSegment{T}.Empty"/> when <paramref name="length"/> is not positive.
        /// </summary>
        static ActionSegment<T> SliceActions<T>(ActionSegment<T> source, int start, int length)
            where T : struct
        {
            return length > 0
                ? new ActionSegment<T>(source.Array, source.Offset + start, length)
                : ActionSegment<T>.Empty;
        }

        /// <summary>
        /// Clears <see cref="StoredActions"/>, calls <see cref="IActuator.ResetData"/> on every actuator and
        /// resets the discrete action mask. Does nothing if the buffers have not been initialized yet.
        /// </summary>
        public void ResetData()
        {
            if (!m_ReadyForExecution)
            {
                return;
            }
            StoredActions.Clear();
            for (var i = 0; i < m_Actuators.Count; i++)
            {
                m_Actuators[i].ResetData();
            }
            m_DiscreteActionMask.ResetMask();
        }

        /// <summary>
        /// Sorts actuators in place by <see cref="IActuator.Name"/> using an invariant-culture comparison.
        /// </summary>
        /// <param name="actuators">The list to sort.</param>
        internal static void SortActuators(List<IActuator> actuators)
        {
            actuators.Sort((x, y) => string.Compare(x.Name, y.Name, StringComparison.InvariantCulture));
        }

        /// <summary>
        /// Asserts that no two actuators share the same <see cref="IActuator.Name"/>.
        /// The check does not depend on list order.
        /// </summary>
        void ValidateActuators()
        {
            var seenNames = new HashSet<string>();
            for (var i = 0; i < m_Actuators.Count; i++)
            {
                // Evaluate Add() outside Debug.Assert: Unity strips Assert calls (including their
                // argument expressions) when assertions are disabled.
                var isUnique = seenNames.Add(m_Actuators[i].Name);
                Debug.Assert(isUnique, "Actuator names must be unique.");
            }
        }

        /// <summary>
        /// Adds the action counts of <paramref name="actuatorItem"/> to the cached totals. Ignores null.
        /// </summary>
        void AddToBufferSizes(IActuator actuatorItem)
        {
            if (actuatorItem == null)
            {
                return;
            }

            NumContinuousActions += actuatorItem.ActionSpec.NumContinuousActions;
            NumDiscreteActions += actuatorItem.ActionSpec.NumDiscreteActions;
            SumOfDiscreteBranchSizes += actuatorItem.ActionSpec.SumOfDiscreteBranchSizes;
        }

        /// <summary>
        /// Subtracts the action counts of <paramref name="actuatorItem"/> from the cached totals. Ignores null.
        /// </summary>
        void SubtractFromBufferSize(IActuator actuatorItem)
        {
            if (actuatorItem == null)
            {
                return;
            }

            NumContinuousActions -= actuatorItem.ActionSpec.NumContinuousActions;
            NumDiscreteActions -= actuatorItem.ActionSpec.NumDiscreteActions;
            SumOfDiscreteBranchSizes -= actuatorItem.ActionSpec.SumOfDiscreteBranchSizes;
        }

        /// <summary>
        /// Resets all cached action-count totals to zero.
        /// </summary>
        void ClearBufferSizes()
        {
            NumContinuousActions = NumDiscreteActions = SumOfDiscreteBranchSizes = 0;
        }

        /// <summary>
        /// Adds each actuator in <paramref name="actuators"/> via <see cref="Add"/>, in array order.
        /// </summary>
        /// <param name="actuators">The actuators to add.</param>
        public void AddActuators(IActuator[] actuators)
        {
            for (var i = 0; i < actuators.Length; i++)
            {
                Add(actuators[i]);
            }
        }

        /// <inheritdoc/>
        public IEnumerator<IActuator> GetEnumerator()
        {
            return m_Actuators.GetEnumerator();
        }

        /// <inheritdoc/>
        IEnumerator IEnumerable.GetEnumerator()
        {
            return ((IEnumerable)m_Actuators).GetEnumerator();
        }

        /// <inheritdoc/>
        public void Add(IActuator item)
        {
            Debug.Assert(m_ReadyForExecution == false,
                "Cannot add to the ActuatorManager after its buffers have been initialized");
            m_Actuators.Add(item);
            AddToBufferSizes(item);
        }

        /// <inheritdoc/>
        public void Clear()
        {
            Debug.Assert(m_ReadyForExecution == false,
                "Cannot clear the ActuatorManager after its buffers have been initialized");
            m_Actuators.Clear();
            ClearBufferSizes();
        }

        /// <inheritdoc/>
        public bool Contains(IActuator item)
        {
            return m_Actuators.Contains(item);
        }

        /// <inheritdoc/>
        public void CopyTo(IActuator[] array, int arrayIndex)
        {
            m_Actuators.CopyTo(array, arrayIndex);
        }

        /// <inheritdoc/>
        public bool Remove(IActuator item)
        {
            Debug.Assert(m_ReadyForExecution == false,
                "Cannot remove from the ActuatorManager after its buffers have been initialized");
            if (m_Actuators.Remove(item))
            {
                SubtractFromBufferSize(item);
                return true;
            }
            return false;
        }

        /// <inheritdoc/>
        public int Count => m_Actuators.Count;

        /// <inheritdoc/>
        public bool IsReadOnly => false;

        /// <inheritdoc/>
        public int IndexOf(IActuator item)
        {
            return m_Actuators.IndexOf(item);
        }

        /// <inheritdoc/>
        public void Insert(int index, IActuator item)
        {
            Debug.Assert(m_ReadyForExecution == false,
                "Cannot insert into the ActuatorManager after its buffers have been initialized");
            m_Actuators.Insert(index, item);
            AddToBufferSizes(item);
        }

        /// <inheritdoc/>
        public void RemoveAt(int index)
        {
            Debug.Assert(m_ReadyForExecution == false,
                "Cannot remove from the ActuatorManager after its buffers have been initialized");
            var actuator = m_Actuators[index];
            SubtractFromBufferSize(actuator);
            m_Actuators.RemoveAt(index);
        }

        /// <inheritdoc/>
        public IActuator this[int index]
        {
            get => m_Actuators[index];
            set
            {
                Debug.Assert(m_ReadyForExecution == false,
                    "Cannot modify the ActuatorManager after its buffers have been initialized");
                var old = m_Actuators[index];
                SubtractFromBufferSize(old);
                m_Actuators[index] = value;
                AddToBufferSizes(value);
            }
        }
    }
}
|
|