File size: 2,458 Bytes
05c9ac2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 |
namespace Unity.MLAgents.Inference
{
/// <summary>
/// Central list of the tensor name strings used by the Inference Brain,
/// plus helpers that build the per-index observation names.
/// </summary>
internal static class TensorNames
{
    // Names carrying the "Placeholder" suffix.
    public const string BatchSizePlaceholder = "batch_size";
    public const string SequenceLengthPlaceholder = "sequence_length";
    public const string VectorObservationPlaceholder = "vector_observation";
    public const string RecurrentInPlaceholder = "recurrent_in";
    public const string VisualObservationPlaceholderPrefix = "visual_observation_";
    public const string ObservationPlaceholderPrefix = "obs_";
    public const string PreviousActionPlaceholder = "prev_action";
    public const string ActionMaskPlaceholder = "action_masks";
    public const string RandomNormalEpsilonPlaceholder = "epsilon";

    // Remaining (non-placeholder) names.
    public const string ValueEstimateOutput = "value_estimate";
    public const string RecurrentOutput = "recurrent_out";
    public const string MemorySize = "memory_size";
    public const string VersionNumber = "version_number";
    public const string ContinuousActionOutputShape = "continuous_action_output_shape";
    public const string DiscreteActionOutputShape = "discrete_action_output_shape";
    public const string ContinuousActionOutput = "continuous_actions";
    public const string DiscreteActionOutput = "discrete_actions";
    public const string DeterministicContinuousActionOutput = "deterministic_continuous_actions";
    public const string DeterministicDiscreteActionOutput = "deterministic_discrete_actions";

    // Deprecated entries, retained for backward compatibility.
    public const string IsContinuousControlDeprecated = "is_continuous_control";
    public const string ActionOutputDeprecated = "action";
    public const string ActionOutputShapeDeprecated = "action_output_shape";

    /// <summary>
    /// Builds the name of the visual observation tensor for a given index.
    /// </summary>
    /// <param name="index">Number appended after the visual observation prefix.</param>
    /// <returns>
    /// <see cref="VisualObservationPlaceholderPrefix"/> immediately followed by
    /// the text form of <paramref name="index"/>.
    /// </returns>
    public static string GetVisualObservationName(int index) =>
        $"{VisualObservationPlaceholderPrefix}{index}";

    /// <summary>
    /// Builds the name of the observation tensor for a given index.
    /// </summary>
    /// <param name="index">Number appended after the observation prefix.</param>
    /// <returns>
    /// <see cref="ObservationPlaceholderPrefix"/> immediately followed by
    /// the text form of <paramref name="index"/>.
    /// </returns>
    public static string GetObservationName(int index) =>
        $"{ObservationPlaceholderPrefix}{index}";
}
}
|