|
using System.IO; |
|
using Google.Protobuf; |
|
using System.Collections.Generic; |
|
using Unity.MLAgents.Sensors; |
|
using Unity.MLAgents.Policies; |
|
|
|
namespace Unity.MLAgents.Demonstrations |
|
{ |
|
|
|
|
|
|
|
|
|
/// <summary>
/// Writes a demonstration to a <see cref="Stream"/> as a sequence of length-delimited
/// protobuf messages: a metadata header, the brain parameters, then one agent
/// info/action record per recorded step.
/// </summary>
/// <remarks>
/// Layout produced by this writer:
/// <list type="bullet">
/// <item>Offset 0: delimited metadata. It is written by <see cref="Initialize"/> and rewritten in
/// place by <see cref="Close"/> with the final totals. It is expected to fit within the first
/// <see cref="MetaDataBytes"/> + 1 bytes; otherwise it would overlap the brain parameters.</item>
/// <item>Offset <see cref="MetaDataBytes"/> + 1: delimited brain parameters.</item>
/// <item>After that: one delimited agent info/action pair per <see cref="Record"/> call.</item>
/// </list>
/// When the stream is null, or after <see cref="Close"/>, the calls to
/// <see cref="Initialize"/>, <see cref="Record"/> and <see cref="Close"/> do nothing.
/// </remarks>
public class DemonstrationWriter
{
    /// <summary>
    /// Number of bytes reserved at the start of the stream for the metadata header.
    /// The brain parameters are written at offset <c>MetaDataBytes + 1</c>.
    /// </summary>
    internal const int MetaDataBytes = 32;

    DemonstrationMetaData m_MetaData;
    Stream m_Writer;
    float m_CumulativeReward;

    // True when at least one non-terminal step has been recorded since the last terminal
    // step, i.e. there is an unfinished episode that Close() still has to count.
    bool m_EpisodeInProgress;

    readonly ObservationWriter m_ObservationWriter = new ObservationWriter();

    /// <summary>
    /// Creates a writer that emits the demonstration to <paramref name="stream"/>.
    /// </summary>
    /// <param name="stream">
    /// Destination stream. It must support seeking, because the header is rewritten at
    /// offset 0 when the writer is closed. The writer closes this stream in <see cref="Close"/>.
    /// If null, the writer does nothing.
    /// </param>
    public DemonstrationWriter(Stream stream)
    {
        m_Writer = stream;
    }

    /// <summary>
    /// Number of steps recorded so far, as tracked in the metadata.
    /// </summary>
    /// <remarks>
    /// NOTE(review): reads the metadata created by <see cref="Initialize"/>; presumably it must
    /// not be queried before initialization. Confirm with callers.
    /// </remarks>
    internal int NumSteps
    {
        get { return m_MetaData.numberSteps; }
    }

    /// <summary>
    /// Writes the initial metadata header at the current position (expected to be the start of
    /// the stream), then the brain parameters at offset <see cref="MetaDataBytes"/> + 1.
    /// Call this before <see cref="Record"/> or <see cref="Close"/>, which update the metadata
    /// created here.
    /// </summary>
    /// <param name="demonstrationName">Name stored in the metadata header.</param>
    /// <param name="brainParameters">Brain parameters to serialize after the header.</param>
    /// <param name="brainName">Brain name passed to the brain-parameters serializer.</param>
    internal void Initialize(
        string demonstrationName, BrainParameters brainParameters, string brainName)
    {
        if (m_Writer == null)
        {
            // No stream, or already closed: nothing to write.
            return;
        }

        m_MetaData = new DemonstrationMetaData { demonstrationName = demonstrationName };
        var metaProto = m_MetaData.ToProto();
        metaProto.WriteDelimitedTo(m_Writer);

        WriteBrainParameters(brainName, brainParameters);
    }

    /// <summary>
    /// Rewrites the metadata header at offset 0 with the current step/episode/reward values.
    /// </summary>
    void WriteMetadata()
    {
        if (m_Writer == null)
        {
            // No stream, or already closed.
            return;
        }

        var metaProto = m_MetaData.ToProto();
        var metaProtoBytes = metaProto.ToByteArray();

        // NOTE(review): this appends the raw, non-delimited metadata bytes at the current stream
        // position (after the last record written) before the header is rewritten. Its purpose is
        // not evident from this file; it is kept unchanged to preserve the on-disk format.
        m_Writer.Write(metaProtoBytes, 0, metaProtoBytes.Length);

        // Overwrite the header written by Initialize().
        m_Writer.Seek(0, SeekOrigin.Begin);
        metaProto.WriteDelimitedTo(m_Writer);
    }

    /// <summary>
    /// Writes the delimited brain parameters right after the space reserved for the header.
    /// </summary>
    /// <param name="brainName">Brain name passed to the serializer.</param>
    /// <param name="brainParameters">Parameters to serialize.</param>
    void WriteBrainParameters(string brainName, BrainParameters brainParameters)
    {
        if (m_Writer == null)
        {
            // No stream, or already closed.
            return;
        }

        // Skip the reserved header region so the header can later be rewritten in place
        // without shifting anything that follows it.
        m_Writer.Seek(MetaDataBytes + 1, SeekOrigin.Begin);
        var brainProto = brainParameters.ToProto(brainName, false);
        brainProto.WriteDelimitedTo(m_Writer);
    }

    /// <summary>
    /// Appends one recorded step: the agent info/action pair plus the observation of each sensor.
    /// </summary>
    /// <param name="info">
    /// Agent state for this step. Its reward is added to the running total, and a terminal
    /// step (<c>done</c>) ends the current episode.
    /// </param>
    /// <param name="sensors">Sensors whose observations are attached to the record, in order.</param>
    internal void Record(AgentInfo info, List<ISensor> sensors)
    {
        if (m_Writer == null)
        {
            // No stream, or already closed.
            return;
        }

        // Update the running metadata; it is written to the header on Close().
        m_MetaData.numberSteps++;
        m_CumulativeReward += info.reward;
        if (info.done)
        {
            EndEpisode();
        }
        else
        {
            m_EpisodeInProgress = true;
        }

        var agentProto = info.ToInfoActionPairProto();
        foreach (var sensor in sensors)
        {
            agentProto.AgentInfo.Observations.Add(sensor.GetObservationProto(m_ObservationWriter));
        }

        agentProto.WriteDelimitedTo(m_Writer);
    }

    /// <summary>
    /// Finalizes the demonstration and closes the stream. It counts any unfinished episode,
    /// stores the mean reward per episode, and rewrites the metadata header.
    /// The stream is closed even if writing the header fails. Later calls do nothing.
    /// </summary>
    public void Close()
    {
        if (m_Writer == null)
        {
            // No stream, or already closed.
            return;
        }

        try
        {
            // Only count a trailing episode if steps were recorded after the last terminal
            // step. An episode that ended on a terminal step was already counted in Record().
            if (m_EpisodeInProgress)
            {
                EndEpisode();
            }

            // Avoid 0/0 (NaN) when no episode was recorded at all.
            m_MetaData.meanReward = m_MetaData.numberEpisodes > 0
                ? m_CumulativeReward / m_MetaData.numberEpisodes
                : 0f;
            WriteMetadata();
        }
        finally
        {
            // Mark the writer closed before closing the stream, so that even a failing
            // Close() leaves every later call as a no-op.
            var stream = m_Writer;
            m_Writer = null;
            stream.Close();
        }
    }

    /// <summary>
    /// Counts one completed episode in the metadata and clears the in-progress flag.
    /// </summary>
    void EndEpisode()
    {
        m_MetaData.numberEpisodes += 1;
        m_EpisodeInProgress = false;
    }
}
|
} |
|
|