File size: 5,247 Bytes
05c9ac2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
using System.IO;
using Google.Protobuf;
using System.Collections.Generic;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Policies;

namespace Unity.MLAgents.Demonstrations
{
    /// <summary>
    /// Responsible for writing demonstration data to stream (typically a file stream).
    /// </summary>
    /// <remarks>
    /// Every member is a no-op once the underlying stream reference is null, which is the
    /// state after <see cref="Close"/> has run (or when a null stream was supplied).
    /// </remarks>
    /// <seealso cref="DemonstrationRecorder"/>
    public class DemonstrationWriter
    {
        /// <summary>
        /// Number of bytes reserved for the <see cref="DemonstrationMetaData"/> at the start of the demo file.
        /// </summary>
        internal const int MetaDataBytes = 32;

        // Meta-data accumulated while recording; null until Initialize has been called.
        DemonstrationMetaData m_MetaData;

        // Destination stream; set to null by Close so later calls become no-ops.
        Stream m_Writer;

        // Sum of the rewards of every recorded step, used to compute the mean reward on Close.
        float m_CumulativeReward;

        // Passed to each sensor when its observation proto is generated.
        readonly ObservationWriter m_ObservationWriter = new ObservationWriter();

        /// <summary>
        /// Create a DemonstrationWriter that will write to the specified stream.
        /// The stream must support writes and seeking.
        /// </summary>
        /// <param name="stream">Stream that receives the demonstration data.</param>
        public DemonstrationWriter(Stream stream)
        {
            m_Writer = stream;
        }

        /// <summary>
        /// Number of steps written so far. Zero if nothing has been initialized yet.
        /// </summary>
        internal int NumSteps
        {
            get { return m_MetaData == null ? 0 : m_MetaData.numberSteps; }
        }

        /// <summary>
        /// Writes the initial data to the stream.
        /// </summary>
        /// <param name="demonstrationName">Base name of the demonstration file(s).</param>
        /// <param name="brainParameters">The parameters of the Brain the agent is attached to.</param>
        /// <param name="brainName">The name of the Brain the agent is attached to.</param>
        internal void Initialize(
            string demonstrationName, BrainParameters brainParameters, string brainName)
        {
            if (m_Writer == null)
            {
                // Already closed
                return;
            }

            // Write a first version of the meta-data; WriteMetadata overwrites it with the
            // final counters when the writer is closed.
            m_MetaData = new DemonstrationMetaData { demonstrationName = demonstrationName };
            var metaProto = m_MetaData.ToProto();
            metaProto.WriteDelimitedTo(m_Writer);

            WriteBrainParameters(brainName, brainParameters);
        }

        /// <summary>
        /// Writes meta-data. Note that this is called at the *end* of recording, but writes to the
        /// beginning of the file.
        /// </summary>
        void WriteMetadata()
        {
            if (m_Writer == null)
            {
                // Already closed
                return;
            }

            var metaProto = m_MetaData.ToProto();

            // The raw (non length-prefixed) bytes are written at the current stream position
            // before seeking back.
            // NOTE(review): this looks like a legacy part of the file layout; it is kept as-is so
            // the produced files do not change - confirm whether any reader depends on it.
            var metaProtoBytes = metaProto.ToByteArray();
            m_Writer.Write(metaProtoBytes, 0, metaProtoBytes.Length);

            // Overwrite the meta-data that Initialize wrote at the start of the stream.
            // NOTE(review): nothing here checks that the message still fits in front of the
            // brain parameters (which start at offset MetaDataBytes + 1) - presumably callers
            // keep the demonstration name short enough; verify against the caller.
            m_Writer.Seek(0, SeekOrigin.Begin);
            metaProto.WriteDelimitedTo(m_Writer);
        }

        /// <summary>
        /// Writes brain parameters to file.
        /// </summary>
        /// <param name="brainName">The name of the Brain the agent is attached to.</param>
        /// <param name="brainParameters">The parameters of the Brain the agent is attached to.</param>
        void WriteBrainParameters(string brainName, BrainParameters brainParameters)
        {
            if (m_Writer == null)
            {
                // Already closed
                return;
            }

            // Skip past the region reserved for the meta-data so that it can be rewritten later
            // without touching the brain parameters.
            // NOTE(review): the extra byte presumably accounts for the length prefix of the
            // delimited meta-data message - confirm.
            m_Writer.Seek(MetaDataBytes + 1, SeekOrigin.Begin);
            var brainProto = brainParameters.ToProto(brainName, false);
            brainProto.WriteDelimitedTo(m_Writer);
        }

        /// <summary>
        /// Write AgentInfo experience to file.
        /// </summary>
        /// <param name="info"> <see cref="AgentInfo"/> for the agent being recorded.</param>
        /// <param name="sensors">List of sensors to record for the agent.</param>
        internal void Record(AgentInfo info, List<ISensor> sensors)
        {
            if (m_Writer == null)
            {
                // Already closed
                return;
            }

            // Increment meta-data counters.
            m_MetaData.numberSteps++;
            m_CumulativeReward += info.reward;
            if (info.done)
            {
                EndEpisode();
            }

            // Generate observations and add AgentInfo to file.
            var agentProto = info.ToInfoActionPairProto();
            foreach (var sensor in sensors)
            {
                agentProto.AgentInfo.Observations.Add(sensor.GetObservationProto(m_ObservationWriter));
            }

            agentProto.WriteDelimitedTo(m_Writer);
        }

        /// <summary>
        /// Performs all clean-up necessary: finalizes the meta-data, writes it, and closes the stream.
        /// The stream is closed and released even if writing the meta-data fails, and calling this
        /// more than once is harmless.
        /// </summary>
        public void Close()
        {
            if (m_Writer == null)
            {
                // Already closed
                return;
            }

            try
            {
                // Without a prior Initialize there is no meta-data to finalize; only the
                // stream needs to be closed.
                if (m_MetaData != null)
                {
                    // The episode in progress is counted as well, so numberEpisodes is at
                    // least 1 when it is used as the divisor below.
                    EndEpisode();
                    m_MetaData.meanReward = m_CumulativeReward / m_MetaData.numberEpisodes;
                    WriteMetadata();
                }
            }
            finally
            {
                // Drop the reference first so the writer counts as closed even if Close() throws.
                var stream = m_Writer;
                m_Writer = null;
                stream.Close();
            }
        }

        /// <summary>
        /// Performs necessary episode-completion steps.
        /// </summary>
        void EndEpisode()
        {
            m_MetaData.numberEpisodes += 1;
        }
    }
}