namespace Unity.MLAgents.Inference
{
    /// <summary>
    /// Contains the names of the input and output tensors for the Inference Brain.
    /// </summary>
    internal static class TensorNames
    {
        // Input tensor (placeholder) names
        public const string BatchSizePlaceholder = "batch_size";
        public const string SequenceLengthPlaceholder = "sequence_length";
        public const string VectorObservationPlaceholder = "vector_observation";
        public const string RecurrentInPlaceholder = "recurrent_in";
        public const string VisualObservationPlaceholderPrefix = "visual_observation_";
        public const string ObservationPlaceholderPrefix = "obs_";
        public const string PreviousActionPlaceholder = "prev_action";
        public const string ActionMaskPlaceholder = "action_masks";
        public const string RandomNormalEpsilonPlaceholder = "epsilon";

        // Output and model-metadata tensor names
        public const string ValueEstimateOutput = "value_estimate";
        public const string RecurrentOutput = "recurrent_out";
        public const string MemorySize = "memory_size";
        public const string VersionNumber = "version_number";
        public const string ContinuousActionOutputShape = "continuous_action_output_shape";
        public const string DiscreteActionOutputShape = "discrete_action_output_shape";
        public const string ContinuousActionOutput = "continuous_actions";
        public const string DiscreteActionOutput = "discrete_actions";
        public const string DeterministicContinuousActionOutput = "deterministic_continuous_actions";
        public const string DeterministicDiscreteActionOutput = "deterministic_discrete_actions";

        // Deprecated tensor names, kept for backward compatibility with older models
        public const string IsContinuousControlDeprecated = "is_continuous_control";
        public const string ActionOutputDeprecated = "action";
        public const string ActionOutputShapeDeprecated = "action_output_shape";

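        /// <summary>
        /// Returns the name of the visual observation tensor with the given index,
        /// e.g. "visual_observation_0" for index 0.
        /// </summary>
        /// <param name="index">Index of the visual observation.</param>
        /// <returns>The tensor name for that visual observation.</returns>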
        public static string GetVisualObservationName(int index)
        {
            return VisualObservationPlaceholderPrefix + index;
        }

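        /// <summary>
        /// Returns the name of the observation tensor with the given index,
        /// e.g. "obs_0" for index 0.
        /// </summary>
        /// <param name="index">Index of the observation.</param>
        /// <returns>The tensor name for that observation.</returns>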
        public static string GetObservationName(int index)
        {
            return ObservationPlaceholderPrefix + index;
        }
    }
}