|
using System; |
|
using System.Collections.Generic; |
|
using Unity.Barracuda; |
|
using Unity.MLAgents.Inference; |
|
using UnityEngine; |
|
|
|
namespace Unity.MLAgents.Sensors |
|
{ |
|
|
|
|
|
|
|
/// <summary>
/// Writes observation values either into a flat list of floats or into one batch row of a
/// <see cref="TensorProxy"/>, depending on which <c>SetTarget</c> overload was called most recently.
/// </summary>
public class ObservationWriter
{
    // Flat destination buffer; non-null only while a list is the target.
    IList<float> m_Data;

    // Added to every write: a flat start index for lists, a channel/element offset for tensors.
    int m_Offset;

    // Destination tensor; non-null only while a TensorProxy is the target.
    TensorProxy m_Proxy;

    // Batch row being written. Always 0 when a list is the target.
    int m_Batch;

    // Shape the (h, w, ch) indexer uses to bounds-check and flatten writes into a list.
    TensorShape m_TensorShape;

    /// <summary>
    /// Creates a writer with no target. A <c>SetTarget</c> overload must be called before any write,
    /// otherwise writes dereference a null tensor proxy.
    /// </summary>
    public ObservationWriter() { }

    /// <summary>
    /// Targets a flat list, taking the observation layout from <paramref name="observationSpec"/>.
    /// </summary>
    /// <param name="data">List that receives the observation values.</param>
    /// <param name="observationSpec">Spec whose <c>Shape</c> describes the observation layout.</param>
    /// <param name="offset">Position in <paramref name="data"/> at which writing starts.</param>
    internal void SetTarget(IList<float> data, ObservationSpec observationSpec, int offset) =>
        SetTarget(data, observationSpec.Shape, offset);

    /// <summary>
    /// Targets a flat list and records a tensor shape for <paramref name="shape"/> so that
    /// (h, w, ch) writes can be bounds-checked and flattened.
    /// </summary>
    /// <param name="data">List that receives the observation values.</param>
    /// <param name="shape">Observation dimensions (rank 1, 2 or 3 expected).</param>
    /// <param name="offset">Position in <paramref name="data"/> at which writing starts.</param>
    internal void SetTarget(IList<float> data, InplaceArray<int> shape, int offset)
    {
        m_Proxy = null;
        m_Data = data;
        m_Offset = offset;
        m_Batch = 0;

        switch (shape.Length)
        {
            case 1:
                // presumably interpreted by Barracuda as (batch, channels) — TODO confirm
                m_TensorShape = new TensorShape(m_Batch, shape[0]);
                break;
            case 2:
                // Rank-2 observations get a height of 1: (batch, 1, shape[0], shape[1]).
                m_TensorShape = new TensorShape(new[] { m_Batch, 1, shape[0], shape[1] });
                break;
            default:
                // NOTE(review): any rank other than 1 or 2 lands here; only shape[0..2] are read.
                m_TensorShape = new TensorShape(m_Batch, shape[0], shape[1], shape[2]);
                break;
        }
    }

    /// <summary>
    /// Targets row <paramref name="batchIndex"/> of <paramref name="tensorProxy"/>, with every write
    /// shifted by <paramref name="channelOffset"/>.
    /// </summary>
    /// <param name="tensorProxy">Tensor that receives the observation values.</param>
    /// <param name="batchIndex">Batch row to write into.</param>
    /// <param name="channelOffset">Offset added to the channel/element index of each write.</param>
    internal void SetTarget(TensorProxy tensorProxy, int batchIndex, int channelOffset)
    {
        m_Data = null;
        m_Proxy = tensorProxy;
        m_Batch = batchIndex;
        m_Offset = channelOffset;
        m_TensorShape = m_Proxy.data.shape;
    }

    /// <summary>
    /// Writes a value at flat position <paramref name="index"/>, relative to the target's offset.
    /// </summary>
    /// <param name="index">Position relative to the offset given to <c>SetTarget</c>.</param>
    public float this[int index]
    {
        set
        {
            var target = index + m_Offset;
            if (m_Data == null)
            {
                m_Proxy.data[m_Batch, target] = value;
                return;
            }

            m_Data[target] = value;
        }
    }

    /// <summary>
    /// Writes a value at (height <paramref name="h"/>, width <paramref name="w"/>, channel <paramref name="ch"/>).
    /// </summary>
    /// <param name="h">Height index.</param>
    /// <param name="w">Width index.</param>
    /// <param name="ch">Channel index, relative to the target's offset.</param>
    /// <exception cref="IndexOutOfRangeException">
    /// When a list is the target and an index falls outside the shape recorded by <c>SetTarget</c>.
    /// </exception>
    public float this[int h, int w, int ch]
    {
        set
        {
            if (m_Data == null)
            {
                // Tensor target: no explicit bounds checks; the offset shifts the channel index.
                m_Proxy.data[m_Batch, h, w, ch + m_Offset] = value;
                return;
            }

            if (h < 0 || h >= m_TensorShape.height)
            {
                throw new IndexOutOfRangeException($"height value {h} must be in range [0, {m_TensorShape.height - 1}]");
            }

            if (w < 0 || w >= m_TensorShape.width)
            {
                throw new IndexOutOfRangeException($"width value {w} must be in range [0, {m_TensorShape.width - 1}]");
            }

            if (ch < 0 || ch >= m_TensorShape.channels)
            {
                throw new IndexOutOfRangeException($"channel value {ch} must be in range [0, {m_TensorShape.channels - 1}]");
            }

            // Index() is linear in the channel argument, so adding the offset here shifts the flat index by m_Offset.
            m_Data[m_TensorShape.Index(m_Batch, h, w, ch + m_Offset)] = value;
        }
    }

    /// <summary>
    /// Copies every element of <paramref name="data"/> into consecutive flat positions.
    /// </summary>
    /// <param name="data">Values to copy.</param>
    /// <param name="writeOffset">Extra offset, added to the target's offset, where the copy starts.</param>
    public void AddList(IList<float> data, int writeOffset = 0)
    {
        var start = m_Offset + writeOffset;
        if (m_Data == null)
        {
            for (var i = 0; i < data.Count; i++)
            {
                var item = data[i];
                m_Proxy.data[m_Batch, start + i] = item;
            }
            return;
        }

        for (var i = 0; i < data.Count; i++)
        {
            var item = data[i];
            m_Data[start + i] = item;
        }
    }

    /// <summary>
    /// Writes x, y, z of <paramref name="vec"/> into three consecutive flat positions.
    /// </summary>
    /// <param name="vec">Vector to write.</param>
    /// <param name="writeOffset">Extra offset, added to the target's offset, for the first component.</param>
    public void Add(Vector3 vec, int writeOffset = 0)
    {
        this[writeOffset] = vec.x;
        this[writeOffset + 1] = vec.y;
        this[writeOffset + 2] = vec.z;
    }

    /// <summary>
    /// Writes x, y, z, w of <paramref name="vec"/> into four consecutive flat positions.
    /// </summary>
    /// <param name="vec">Vector to write.</param>
    /// <param name="writeOffset">Extra offset, added to the target's offset, for the first component.</param>
    public void Add(Vector4 vec, int writeOffset = 0)
    {
        this[writeOffset] = vec.x;
        this[writeOffset + 1] = vec.y;
        this[writeOffset + 2] = vec.z;
        this[writeOffset + 3] = vec.w;
    }

    /// <summary>
    /// Writes x, y, z, w of <paramref name="quat"/> into four consecutive flat positions.
    /// </summary>
    /// <param name="quat">Quaternion to write.</param>
    /// <param name="writeOffset">Extra offset, added to the target's offset, for the first component.</param>
    public void Add(Quaternion quat, int writeOffset = 0)
    {
        this[writeOffset] = quat.x;
        this[writeOffset + 1] = quat.y;
        this[writeOffset + 2] = quat.z;
        this[writeOffset + 3] = quat.w;
    }
}
|
|
|
|
|
|
|
|
|
/// <summary>
/// Extension methods that write texture pixels into an <see cref="ObservationWriter"/>.
/// </summary>
public static class ObservationWriterExtension
{
    /// <summary>
    /// Writes the pixels of <paramref name="texture"/> into <paramref name="obsWriter"/> as an
    /// (h, w, ch) observation, scaling each channel byte from [0, 255] to [0, 1].
    /// </summary>
    /// <remarks>
    /// Pixel-array row <c>r</c> is written to observation row <c>height - 1 - r</c>, so the image is
    /// flipped vertically relative to the order of the pixel data.
    /// </remarks>
    /// <param name="obsWriter">Writer that receives the values.</param>
    /// <param name="texture">Texture to read. RGB24 textures take a raw-byte path.</param>
    /// <param name="grayScale">
    /// When true, writes the mean of r, g, b into channel 0; otherwise writes r, g, b into channels 0-2.
    /// </param>
    /// <returns>The number of floats written: height * width * (1 if grayscale, else 3).</returns>
    public static int WriteTexture(
        this ObservationWriter obsWriter,
        Texture2D texture,
        bool grayScale)
    {
        if (texture.format == TextureFormat.RGB24)
        {
            // Dedicated path that reads raw RGB24 bytes instead of calling GetPixels32.
            return obsWriter.WriteTextureRGB24(texture, grayScale);
        }

        var width = texture.width;
        var height = texture.height;
        var pixels = texture.GetPixels32();

        for (var row = 0; row < height; row++)
        {
            var h = height - 1 - row;
            var rowStart = row * width;
            for (var w = 0; w < width; w++)
            {
                var pixel = pixels[rowStart + w];
                WritePixel(obsWriter, h, w, pixel.r, pixel.g, pixel.b, grayScale);
            }
        }

        var channelCount = grayScale ? 1 : 3;
        return height * width * channelCount;
    }

    /// <summary>
    /// RGB24-specific variant of <see cref="WriteTexture"/> that reads 3 bytes per pixel from the
    /// texture's raw data.
    /// </summary>
    /// <remarks>
    /// NOTE(review): assumes the raw data is tightly packed (no row padding) with mip level 0 first — confirm.
    /// </remarks>
    /// <param name="obsWriter">Writer that receives the values.</param>
    /// <param name="texture">Texture expected to be in RGB24 format.</param>
    /// <param name="grayScale">When true, writes the mean of r, g, b into channel 0.</param>
    /// <returns>The number of floats written: height * width * (1 if grayscale, else 3).</returns>
    internal static int WriteTextureRGB24(
        this ObservationWriter obsWriter,
        Texture2D texture,
        bool grayScale
    )
    {
        var width = texture.width;
        var height = texture.height;
        var rawBytes = texture.GetRawTextureData<byte>();

        for (var row = 0; row < height; row++)
        {
            var h = height - 1 - row;
            for (var w = 0; w < width; w++)
            {
                var byteIndex = 3 * (row * width + w);
                WritePixel(
                    obsWriter, h, w,
                    rawBytes[byteIndex], rawBytes[byteIndex + 1], rawBytes[byteIndex + 2],
                    grayScale);
            }
        }

        var channelCount = grayScale ? 1 : 3;
        return height * width * channelCount;
    }

    // Writes one pixel at (h, w): either the averaged r/g/b in channel 0, or r, g, b in channels 0-2,
    // each scaled from [0, 255] to [0, 1].
    static void WritePixel(ObservationWriter obsWriter, int h, int w, byte r, byte g, byte b, bool grayScale)
    {
        if (grayScale)
        {
            obsWriter[h, w, 0] = (r + g + b) / 3f / 255.0f;
            return;
        }

        obsWriter[h, w, 0] = r / 255.0f;
        obsWriter[h, w, 1] = g / 255.0f;
        obsWriter[h, w, 2] = b / 255.0f;
    }
}
|
} |
|
|