File size: 5,247 Bytes
05c9ac2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 |
using System.IO;
using Google.Protobuf;
using System.Collections.Generic;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Policies;
namespace Unity.MLAgents.Demonstrations
{
/// <summary>
/// Responsible for writing demonstration data to stream (typically a file stream).
/// </summary>
/// <remarks>
/// The brain parameters are written at absolute offset <c>MetaDataBytes + 1</c>, each recorded
/// step is appended as a length-delimited message, and <see cref="Close"/> rewrites the
/// meta-data at offset 0 once the final counters are known.
/// </remarks>
/// <seealso cref="DemonstrationRecorder"/>
public class DemonstrationWriter
{
    /// <summary>
    /// Number of bytes reserved for the <see cref="DemonstrationMetaData"/> at the start of the demo file.
    /// </summary>
    internal const int MetaDataBytes = 32;

    DemonstrationMetaData m_MetaData;

    // Destination stream; null means the writer has been closed (or was never given a stream).
    Stream m_Writer;

    // Sum of the rewards of every recorded step; turned into the mean reward in Close().
    float m_CumulativeReward;

    ObservationWriter m_ObservationWriter = new ObservationWriter();

    /// <summary>
    /// Create a DemonstrationWriter that will write to the specified stream.
    /// The stream must support writes and seeking.
    /// </summary>
    /// <param name="stream">
    /// Stream that receives the demonstration data. If it is null, the methods that write
    /// (<see cref="Initialize"/>, <see cref="Record"/>, <see cref="Close"/>) return without doing anything.
    /// </param>
    public DemonstrationWriter(Stream stream)
    {
        m_Writer = stream;
    }

    /// <summary>
    /// Number of steps written so far.
    /// </summary>
    internal int NumSteps
    {
        get { return m_MetaData.numberSteps; }
    }

    /// <summary>
    /// Writes the initial data to the stream: a first copy of the meta-data followed by the
    /// brain parameters. Does nothing if the writer is already closed.
    /// </summary>
    /// <param name="demonstrationName">Base name of the demonstration file(s).</param>
    /// <param name="brainParameters">The parameters of the Brain the agent is attached to.</param>
    /// <param name="brainName">The name of the Brain the agent is attached to.</param>
    internal void Initialize(
        string demonstrationName, BrainParameters brainParameters, string brainName)
    {
        if (m_Writer == null)
        {
            // Already closed
            return;
        }

        m_MetaData = new DemonstrationMetaData { demonstrationName = demonstrationName };
        var metaProto = m_MetaData.ToProto();
        // Written at the current stream position; Close() overwrites it at offset 0 with the
        // final step/episode counters.
        metaProto.WriteDelimitedTo(m_Writer);

        WriteBrainParameters(brainName, brainParameters);
    }

    /// <summary>
    /// Writes meta-data. Note that this is called at the *end* of recording, but writes to the
    /// beginning of the file.
    /// </summary>
    void WriteMetadata()
    {
        if (m_Writer == null)
        {
            // Already closed
            return;
        }

        var metaProto = m_MetaData.ToProto();
        var metaProtoBytes = metaProto.ToByteArray();
        // NOTE(review): this writes the raw (non-delimited) meta-data bytes at the current stream
        // position before seeking back. The reason is not evident from this file; it changes the
        // bytes on disk, so confirm with the demo readers before removing it.
        m_Writer.Write(metaProtoBytes, 0, metaProtoBytes.Length);
        m_Writer.Seek(0, SeekOrigin.Begin);
        metaProto.WriteDelimitedTo(m_Writer);
    }

    /// <summary>
    /// Writes brain parameters to file.
    /// </summary>
    /// <param name="brainName">The name of the Brain the agent is attached to.</param>
    /// <param name="brainParameters">The parameters of the Brain the agent is attached to.</param>
    void WriteBrainParameters(string brainName, BrainParameters brainParameters)
    {
        if (m_Writer == null)
        {
            // Already closed
            return;
        }

        // Skip past the region reserved for the meta-data, whatever its current encoded size is.
        m_Writer.Seek(MetaDataBytes + 1, SeekOrigin.Begin);
        var brainProto = brainParameters.ToProto(brainName, false);
        brainProto.WriteDelimitedTo(m_Writer);
    }

    /// <summary>
    /// Write AgentInfo experience to file.
    /// </summary>
    /// <param name="info"> <see cref="AgentInfo"/> for the agent being recorded.</param>
    /// <param name="sensors">List of sensors to record for the agent.</param>
    internal void Record(AgentInfo info, List<ISensor> sensors)
    {
        if (m_Writer == null)
        {
            // Already closed
            return;
        }

        // Increment meta-data counters.
        m_MetaData.numberSteps++;
        m_CumulativeReward += info.reward;
        if (info.done)
        {
            // NOTE(review): Close() also calls EndEpisode() unconditionally, so a recording whose
            // last step has done == true counts one extra episode - confirm this is intended.
            EndEpisode();
        }

        // Generate observations and add AgentInfo to file.
        var agentProto = info.ToInfoActionPairProto();
        foreach (var sensor in sensors)
        {
            agentProto.AgentInfo.Observations.Add(sensor.GetObservationProto(m_ObservationWriter));
        }

        agentProto.WriteDelimitedTo(m_Writer);
    }

    /// <summary>
    /// Performs all clean-up necessary: ends the current episode, stores the mean reward,
    /// writes the final meta-data and closes the stream. Calling it again is a no-op.
    /// </summary>
    public void Close()
    {
        if (m_Writer == null)
        {
            // Already closed
            return;
        }

        try
        {
            EndEpisode();
            m_MetaData.meanReward = m_CumulativeReward / m_MetaData.numberEpisodes;
            WriteMetadata();
        }
        finally
        {
            // Release the stream even when finalizing the meta-data throws, and mark the writer
            // as closed first so a failing Stream.Close() cannot leave it half-open.
            var stream = m_Writer;
            m_Writer = null;
            stream.Close();
        }
    }

    /// <summary>
    /// Performs necessary episode-completion steps.
    /// </summary>
    void EndEpisode()
    {
        m_MetaData.numberEpisodes += 1;
    }
}
}
|