File size: 8,700 Bytes
05c9ac2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
using System.Collections.Generic;
using Unity.Barracuda;
using Unity.MLAgents.Sensors;

namespace Unity.MLAgents.Inference
{
    /// <summary>
    /// Mapping between Tensor names and generators.
    /// A TensorGenerator implements a Dictionary of strings (node names) to an
    /// <see cref="IGenerator"/>. Each generator takes as arguments the tensor, the current
    /// batch size and the list of agent infos / sensors corresponding to the current batch,
    /// and reshapes and fills the data of the tensor based on the data of the batch.
    /// When the TensorProxy is an Input to the model, the shape of the Tensor will be modified
    /// depending on the current batch size and the data of the Tensor will be filled using the
    /// agent infos of the batch.
    /// When the TensorProxy is an Output of the model, only the shape of the Tensor will be
    /// modified using the current batch size. The data will be pre-filled with zeros.
    /// </summary>
    internal class TensorGenerator
    {
        public interface IGenerator
        {
            /// <summary>
            /// Modifies the data inside a Tensor according to the information contained in the
            /// AgentInfos contained in the current batch.
            /// </summary>
            /// <param name="tensorProxy"> The tensor whose data and shape will be modified.</param>
            /// <param name="batchSize"> The number of agents present in the current batch.</param>
            /// <param name="infos">
            /// List of AgentInfos containing the information that will be used to populate
            /// the tensor's data.
            /// </param>
            void Generate(
                TensorProxy tensorProxy, int batchSize, IList<AgentInfoSensorsPair> infos);
        }

        // Tensor (node) name -> generator responsible for filling / reshaping that tensor.
        readonly Dictionary<string, IGenerator> m_Dict = new Dictionary<string, IGenerator>();

        // Version reported by the loaded model; selects the observation naming scheme in
        // InitializeObservations. Only assigned in the constructor.
        readonly int m_ApiVersion;

        /// <summary>
        /// Creates a new TensorGenerator and registers the generators for the model's fixed
        /// inputs (batch size, sequence length, recurrent input, previous action, action mask,
        /// random normal epsilon) and outputs. Observation generators are registered separately
        /// by <see cref="InitializeObservations"/>.
        /// </summary>
        /// <param name="seed"> The seed the random normal input generator will be initialized with.</param>
        /// <param name="allocator"> Tensor allocator.</param>
        /// <param name="memories">Dictionary of AgentInfo.id to memory for use in the inference model.</param>
        /// <param name="barracudaModel"> The Barracuda <c>Model</c>, passed as <c>object</c>.
        /// When null, no generators are registered.</param>
        /// <param name="deterministicInference"> Inference only: set to true if the action selection from model should be
        /// deterministic. </param>
        public TensorGenerator(
            int seed,
            ITensorAllocator allocator,
            Dictionary<int, List<float>> memories,
            object barracudaModel = null,
            bool deterministicInference = false)
        {
            // If model is null, no inference to run and exception is thrown before reaching here.
            if (barracudaModel == null)
            {
                return;
            }
            var model = (Model)barracudaModel;

            m_ApiVersion = model.GetVersion();

            // Generators for Inputs
            m_Dict[TensorNames.BatchSizePlaceholder] =
                new BatchSizeGenerator(allocator);
            m_Dict[TensorNames.SequenceLengthPlaceholder] =
                new SequenceLengthGenerator(allocator);
            m_Dict[TensorNames.RecurrentInPlaceholder] =
                new RecurrentInputGenerator(allocator, memories);

            m_Dict[TensorNames.PreviousActionPlaceholder] =
                new PreviousActionInputGenerator(allocator);
            m_Dict[TensorNames.ActionMaskPlaceholder] =
                new ActionMaskInputGenerator(allocator);
            m_Dict[TensorNames.RandomNormalEpsilonPlaceholder] =
                new RandomNormalInputGenerator(seed, allocator);

            // Generators for Outputs: action outputs are only registered when the model has them,
            // and their names depend on whether deterministic inference was requested.
            if (model.HasContinuousOutputs(deterministicInference))
            {
                m_Dict[model.ContinuousOutputName(deterministicInference)] = new BiDimensionalOutputGenerator(allocator);
            }
            if (model.HasDiscreteOutputs(deterministicInference))
            {
                m_Dict[model.DiscreteOutputName(deterministicInference)] = new BiDimensionalOutputGenerator(allocator);
            }
            m_Dict[TensorNames.RecurrentOutput] = new BiDimensionalOutputGenerator(allocator);
            m_Dict[TensorNames.ValueEstimateOutput] = new BiDimensionalOutputGenerator(allocator);
        }

        /// <summary>
        /// Registers the observation generators, based on the sensors of a representative agent.
        /// The names under which generators are registered depend on the model's API version;
        /// for any other API version nothing is registered.
        /// </summary>
        /// <param name="sensors"> Sensors of a representative agent, in sensor-index order.</param>
        /// <param name="allocator"> Tensor allocator used by the created generators.</param>
        /// <exception cref="UnityAgentsException"> (MLAgents1_0 models only) a sensor's
        /// observation rank is not 1, 2 or 3.</exception>
        public void InitializeObservations(List<ISensor> sensors, ITensorAllocator allocator)
        {
            if (m_ApiVersion == (int)BarracudaModelParamLoader.ModelApiVersion.MLAgents1_0)
            {
                RegisterObservationGeneratorsV1(sensors, allocator);
            }
            else if (m_ApiVersion == (int)BarracudaModelParamLoader.ModelApiVersion.MLAgents2_0)
            {
                RegisterObservationGeneratorsV2(sensors, allocator);
            }
        }

        // MLAgents1_0 naming: rank-1 sensors share one generator, rank-2/3 get one generator each.
        void RegisterObservationGeneratorsV1(List<ISensor> sensors, ITensorAllocator allocator)
        {
            // Counts only rank-3 (visual) sensors; used to build their tensor names.
            var visIndex = 0;
            // Created lazily and shared by every rank-1 sensor, because all vector
            // observations are concatenated into a single model input.
            ObservationGenerator vecObsGen = null;
            for (var sensorIndex = 0; sensorIndex < sensors.Count; sensorIndex++)
            {
                var sensor = sensors[sensorIndex];
                var rank = sensor.GetObservationSpec().Rank;
                ObservationGenerator obsGen;
                string obsGenName;
                switch (rank)
                {
                    case 1:
                        if (vecObsGen == null)
                        {
                            vecObsGen = new ObservationGenerator(allocator);
                        }
                        obsGen = vecObsGen;
                        obsGenName = TensorNames.VectorObservationPlaceholder;
                        break;
                    case 2:
                        // Rank-2 tensors are named after the sensor's index among all sensors.
                        obsGen = new ObservationGenerator(allocator);
                        obsGenName = TensorNames.GetObservationName(sensorIndex);
                        break;
                    case 3:
                        // Rank-3 tensors are named after the "visual observation index",
                        // which only counts the rank-3 sensors.
                        obsGen = new ObservationGenerator(allocator);
                        obsGenName = TensorNames.GetVisualObservationName(visIndex);
                        visIndex++;
                        break;
                    default:
                        throw new UnityAgentsException(
                            $"Sensor {sensor.GetName()} have an invalid rank {rank}");
                }
                obsGen.AddSensorIndex(sensorIndex);
                m_Dict[obsGenName] = obsGen;
            }
        }

        // MLAgents2_0 naming: one generator per sensor, named by its sensor index.
        void RegisterObservationGeneratorsV2(List<ISensor> sensors, ITensorAllocator allocator)
        {
            for (var sensorIndex = 0; sensorIndex < sensors.Count; sensorIndex++)
            {
                var obsGen = new ObservationGenerator(allocator);
                obsGen.AddSensorIndex(sensorIndex);
                m_Dict[TensorNames.GetObservationName(sensorIndex)] = obsGen;
            }
        }

        /// <summary>
        /// Populates the data of the tensor inputs given the data contained in the current batch
        /// of agents.
        /// </summary>
        /// <param name="tensors"> Tensors that will be modified, each looked up by name.</param>
        /// <param name="currentBatchSize"> The number of agents present in the current batch
        /// </param>
        /// <param name="infos"> List of AgentsInfos and Sensors that contains the
        /// data that will be used to modify the tensors</param>
        /// <exception cref="UnityAgentsException"> One of the tensors does not have an
        /// associated generator.</exception>
        public void GenerateTensors(
            IReadOnlyList<TensorProxy> tensors, int currentBatchSize, IList<AgentInfoSensorsPair> infos)
        {
            // Index loop (rather than foreach over the interface) avoids allocating an
            // enumerator on this per-inference-step path.
            for (var tensorIndex = 0; tensorIndex < tensors.Count; tensorIndex++)
            {
                var tensor = tensors[tensorIndex];
                // Single lookup instead of ContainsKey + indexer.
                if (!m_Dict.TryGetValue(tensor.name, out var generator))
                {
                    throw new UnityAgentsException(
                        $"Unknown tensorProxy expected as input : {tensor.name}");
                }
                generator.Generate(tensor, currentBatchSize, infos);
            }
        }
    }
}