|
using System; |
|
using System.Collections.Generic; |
|
using System.Collections.ObjectModel; |
|
using System.Linq; |
|
using UnityEngine; |
|
using Unity.Barracuda; |
|
|
|
namespace Unity.MLAgents.Sensors |
|
{ |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
public class StackingSensor : ISensor, IBuiltInSensor |
|
{ |
|
|
|
|
|
|
|
ISensor m_WrappedSensor; |
|
|
|
|
|
|
|
|
|
int m_NumStackedObservations; |
|
int m_UnstackedObservationSize; |
|
|
|
string m_Name; |
|
private ObservationSpec m_ObservationSpec; |
|
private ObservationSpec m_WrappedSpec; |
|
|
|
|
|
|
|
|
|
float[][] m_StackedObservations; |
|
|
|
byte[][] m_StackedCompressedObservations; |
|
|
|
int m_CurrentIndex; |
|
ObservationWriter m_LocalWriter = new ObservationWriter(); |
|
|
|
byte[] m_EmptyCompressedObservation; |
|
int[] m_CompressionMapping; |
|
TensorShape m_tensorShape; |
|
|
|
|
|
|
|
|
|
|
|
|
|
public StackingSensor(ISensor wrapped, int numStackedObservations) |
|
{ |
|
|
|
m_WrappedSensor = wrapped; |
|
m_NumStackedObservations = numStackedObservations; |
|
|
|
m_Name = $"StackingSensor_size{numStackedObservations}_{wrapped.GetName()}"; |
|
|
|
m_WrappedSpec = wrapped.GetObservationSpec(); |
|
|
|
m_UnstackedObservationSize = wrapped.ObservationSize(); |
|
|
|
|
|
var newShape = m_WrappedSpec.Shape; |
|
|
|
newShape[newShape.Length - 1] *= numStackedObservations; |
|
m_ObservationSpec = new ObservationSpec( |
|
newShape, m_WrappedSpec.DimensionProperties, m_WrappedSpec.ObservationType |
|
); |
|
|
|
|
|
|
|
m_StackedObservations = new float[numStackedObservations][]; |
|
for (var i = 0; i < numStackedObservations; i++) |
|
{ |
|
m_StackedObservations[i] = new float[m_UnstackedObservationSize]; |
|
} |
|
|
|
if (m_WrappedSensor.GetCompressionSpec().SensorCompressionType != SensorCompressionType.None) |
|
{ |
|
m_StackedCompressedObservations = new byte[numStackedObservations][]; |
|
m_EmptyCompressedObservation = CreateEmptyPNG(); |
|
for (var i = 0; i < numStackedObservations; i++) |
|
{ |
|
m_StackedCompressedObservations[i] = m_EmptyCompressedObservation; |
|
} |
|
m_CompressionMapping = ConstructStackedCompressedChannelMapping(wrapped); |
|
} |
|
|
|
if (m_WrappedSpec.Rank != 1) |
|
{ |
|
var wrappedShape = m_WrappedSpec.Shape; |
|
m_tensorShape = new TensorShape(0, wrappedShape[0], wrappedShape[1], wrappedShape[2]); |
|
} |
|
} |
|
|
|
|
|
public int Write(ObservationWriter writer) |
|
{ |
|
|
|
m_LocalWriter.SetTarget(m_StackedObservations[m_CurrentIndex], m_WrappedSpec, 0); |
|
m_WrappedSensor.Write(m_LocalWriter); |
|
|
|
|
|
var numWritten = 0; |
|
if (m_WrappedSpec.Rank == 1) |
|
{ |
|
for (var i = 0; i < m_NumStackedObservations; i++) |
|
{ |
|
var obsIndex = (m_CurrentIndex + 1 + i) % m_NumStackedObservations; |
|
writer.AddList(m_StackedObservations[obsIndex], numWritten); |
|
numWritten += m_UnstackedObservationSize; |
|
} |
|
} |
|
else |
|
{ |
|
for (var i = 0; i < m_NumStackedObservations; i++) |
|
{ |
|
var obsIndex = (m_CurrentIndex + 1 + i) % m_NumStackedObservations; |
|
for (var h = 0; h < m_WrappedSpec.Shape[0]; h++) |
|
{ |
|
for (var w = 0; w < m_WrappedSpec.Shape[1]; w++) |
|
{ |
|
for (var c = 0; c < m_WrappedSpec.Shape[2]; c++) |
|
{ |
|
writer[h, w, i * m_WrappedSpec.Shape[2] + c] = m_StackedObservations[obsIndex][m_tensorShape.Index(0, h, w, c)]; |
|
} |
|
} |
|
} |
|
} |
|
numWritten = m_WrappedSpec.Shape[0] * m_WrappedSpec.Shape[1] * m_WrappedSpec.Shape[2] * m_NumStackedObservations; |
|
} |
|
|
|
return numWritten; |
|
} |
|
|
|
|
|
|
|
|
|
public void Update() |
|
{ |
|
m_WrappedSensor.Update(); |
|
m_CurrentIndex = (m_CurrentIndex + 1) % m_NumStackedObservations; |
|
} |
|
|
|
|
|
public void Reset() |
|
{ |
|
m_WrappedSensor.Reset(); |
|
|
|
for (var i = 0; i < m_NumStackedObservations; i++) |
|
{ |
|
Array.Clear(m_StackedObservations[i], 0, m_StackedObservations[i].Length); |
|
} |
|
if (m_WrappedSensor.GetCompressionSpec().SensorCompressionType != SensorCompressionType.None) |
|
{ |
|
for (var i = 0; i < m_NumStackedObservations; i++) |
|
{ |
|
m_StackedCompressedObservations[i] = m_EmptyCompressedObservation; |
|
} |
|
} |
|
} |
|
|
|
|
|
public ObservationSpec GetObservationSpec() |
|
{ |
|
return m_ObservationSpec; |
|
} |
|
|
|
|
|
public string GetName() |
|
{ |
|
return m_Name; |
|
} |
|
|
|
|
|
public byte[] GetCompressedObservation() |
|
{ |
|
var compressed = m_WrappedSensor.GetCompressedObservation(); |
|
m_StackedCompressedObservations[m_CurrentIndex] = compressed; |
|
|
|
int bytesLength = 0; |
|
foreach (byte[] compressedObs in m_StackedCompressedObservations) |
|
{ |
|
bytesLength += compressedObs.Length; |
|
} |
|
|
|
byte[] outputBytes = new byte[bytesLength]; |
|
int offset = 0; |
|
for (var i = 0; i < m_NumStackedObservations; i++) |
|
{ |
|
var obsIndex = (m_CurrentIndex + 1 + i) % m_NumStackedObservations; |
|
Buffer.BlockCopy(m_StackedCompressedObservations[obsIndex], |
|
0, outputBytes, offset, m_StackedCompressedObservations[obsIndex].Length); |
|
offset += m_StackedCompressedObservations[obsIndex].Length; |
|
} |
|
|
|
return outputBytes; |
|
} |
|
|
|
|
|
public CompressionSpec GetCompressionSpec() |
|
{ |
|
var wrappedSpec = m_WrappedSensor.GetCompressionSpec(); |
|
return new CompressionSpec(wrappedSpec.SensorCompressionType, m_CompressionMapping); |
|
} |
|
|
|
|
|
|
|
|
|
internal byte[] CreateEmptyPNG() |
|
{ |
|
var shape = m_WrappedSpec.Shape; |
|
int height = shape[0]; |
|
int width = shape[1]; |
|
var texture2D = new Texture2D(width, height, TextureFormat.RGB24, false); |
|
Color32[] resetColorArray = texture2D.GetPixels32(); |
|
Color32 black = new Color32(0, 0, 0, 0); |
|
for (int i = 0; i < resetColorArray.Length; i++) |
|
{ |
|
resetColorArray[i] = black; |
|
} |
|
texture2D.SetPixels32(resetColorArray); |
|
texture2D.Apply(); |
|
return texture2D.EncodeToPNG(); |
|
} |
|
|
|
|
|
|
|
|
|
internal int[] ConstructStackedCompressedChannelMapping(ISensor wrappedSenesor) |
|
{ |
|
|
|
|
|
|
|
int[] wrappedMapping = null; |
|
int wrappedNumChannel = m_WrappedSpec.Shape[2]; |
|
|
|
wrappedMapping = wrappedSenesor.GetCompressionSpec().CompressedChannelMapping; |
|
if (wrappedMapping == null) |
|
{ |
|
if (wrappedNumChannel == 1) |
|
{ |
|
wrappedMapping = new[] { 0, 0, 0 }; |
|
} |
|
else |
|
{ |
|
wrappedMapping = Enumerable.Range(0, wrappedNumChannel).ToArray(); |
|
} |
|
} |
|
|
|
|
|
|
|
|
|
int paddedMapLength = (wrappedMapping.Length + 2) / 3 * 3; |
|
var compressionMapping = new int[paddedMapLength * m_NumStackedObservations]; |
|
for (var i = 0; i < m_NumStackedObservations; i++) |
|
{ |
|
var offset = wrappedNumChannel * i; |
|
for (var j = 0; j < paddedMapLength; j++) |
|
{ |
|
if (j < wrappedMapping.Length) |
|
{ |
|
compressionMapping[j + paddedMapLength * i] = wrappedMapping[j] >= 0 ? wrappedMapping[j] + offset : -1; |
|
} |
|
else |
|
{ |
|
compressionMapping[j + paddedMapLength * i] = -1; |
|
} |
|
} |
|
} |
|
return compressionMapping; |
|
} |
|
|
|
|
|
public BuiltInSensorType GetBuiltInSensorType() |
|
{ |
|
IBuiltInSensor wrappedBuiltInSensor = m_WrappedSensor as IBuiltInSensor; |
|
return wrappedBuiltInSensor?.GetBuiltInSensorType() ?? BuiltInSensorType.Unknown; |
|
} |
|
|
|
|
|
|
|
|
|
|
|
internal ReadOnlyCollection<float> GetStackedObservations() |
|
{ |
|
List<float> observations = new List<float>(); |
|
for (var i = 0; i < m_NumStackedObservations; i++) |
|
{ |
|
var obsIndex = (m_CurrentIndex + 1 + i) % m_NumStackedObservations; |
|
observations.AddRange(m_StackedObservations[obsIndex].ToList()); |
|
} |
|
return observations.AsReadOnly(); |
|
} |
|
} |
|
} |
|
|