File size: 4,489 Bytes
05c9ac2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 |
using System.Collections.Generic;
using UnityEngine;
using Unity.MLAgents.Sensors;
namespace Unity.MLAgents.Tests
{
/// <summary>
/// Static settings read by the overrides of <c>SimpleTestGridSensor</c>.
/// Because these are static fields, a value stays in effect for every sensor
/// instance until a test changes it again via <c>SetParameters</c> or <c>Reset</c>.
/// </summary>
public static class TestGridSensorConfig
{
    /// <summary>Returned by <c>SimpleTestGridSensor.GetCellObservationSize</c>.</summary>
    public static int ObservationSize;

    /// <summary>Returned by <c>SimpleTestGridSensor.IsDataNormalized</c>.</summary>
    public static bool IsNormalized;

    /// <summary>
    /// When true, <c>SimpleTestGridSensor</c> reports <c>ProcessAllColliders</c>;
    /// otherwise it reports <c>ProcessClosestColliders</c>.
    /// </summary>
    public static bool ParseAllColliders;

    /// <summary>Overwrites all three settings in one call.</summary>
    /// <param name="observationSize">New value for <c>ObservationSize</c>.</param>
    /// <param name="isNormalized">New value for <c>IsNormalized</c>.</param>
    /// <param name="parseAllColliders">New value for <c>ParseAllColliders</c>.</param>
    public static void SetParameters(int observationSize, bool isNormalized, bool parseAllColliders)
    {
        ParseAllColliders = parseAllColliders;
        IsNormalized = isNormalized;
        ObservationSize = observationSize;
    }

    /// <summary>Restores every setting to its zero value (0 / false / false).</summary>
    public static void Reset() => SetParameters(0, false, false);
}
/// <summary>
/// Minimal <c>GridSensorBase</c> subclass for tests. Its observation size,
/// normalization flag and collider-processing mode come from the static
/// <c>TestGridSensorConfig</c>, and every detected object is encoded as a copy of
/// <see cref="DummyData"/>.
/// </summary>
public class SimpleTestGridSensor : GridSensorBase
{
    /// <summary>
    /// Values copied into the front of the per-object data buffer.
    /// A null value, or an array longer than the buffer, makes
    /// <c>GetObjectData</c> throw.
    /// </summary>
    public float[] DummyData;

    /// <summary>Passes every argument through unchanged to the <c>GridSensorBase</c> constructor.</summary>
    public SimpleTestGridSensor(string name, Vector3 cellScale, Vector3Int gridSize, string[] detectableTags, SensorCompressionType compression)
        : base(name, cellScale, gridSize, detectableTags, compression)
    {
    }

    /// <summary>Returns <c>TestGridSensorConfig.ObservationSize</c>.</summary>
    protected override int GetCellObservationSize() => TestGridSensorConfig.ObservationSize;

    /// <summary>Returns <c>TestGridSensorConfig.IsNormalized</c>.</summary>
    protected override bool IsDataNormalized() => TestGridSensorConfig.IsNormalized;

    /// <summary>
    /// <c>ProcessAllColliders</c> when <c>TestGridSensorConfig.ParseAllColliders</c> is set,
    /// <c>ProcessClosestColliders</c> otherwise.
    /// </summary>
    protected internal override ProcessCollidersMethod GetProcessCollidersMethod()
    {
        if (TestGridSensorConfig.ParseAllColliders)
        {
            return ProcessCollidersMethod.ProcessAllColliders;
        }

        return ProcessCollidersMethod.ProcessClosestColliders;
    }

    /// <summary>
    /// Writes <see cref="DummyData"/> element by element into the start of
    /// <paramref name="dataBuffer"/>; any trailing buffer entries are left untouched.
    /// <paramref name="detectedObject"/> and <paramref name="typeIndex"/> are not used,
    /// so every detected object receives identical data.
    /// </summary>
    protected override void GetObjectData(GameObject detectedObject, int typeIndex, float[] dataBuffer)
    {
        var source = DummyData;
        for (var index = 0; index < source.Length; ++index)
        {
            dataBuffer[index] = source[index];
        }
    }
}
/// <summary>
/// <c>GridSensorComponent</c> test double. <see cref="SetComponentParameters"/> writes the
/// inherited configuration properties and chooses which grid sensors
/// <see cref="GetGridSensors"/> will construct.
/// </summary>
public class SimpleTestGridSensorComponent : GridSensorComponent
{
    // Sensor-selection flags, written only by SetComponentParameters.
    bool m_UseOneHotTag;
    bool m_UseTestingGridSensor;
    bool m_UseGridSensorBase;

    /// <summary>
    /// Builds one sensor per enabled flag. Every sensor is constructed from this component's
    /// current SensorName, CellScale, GridSize, DetectableTags and CompressionType.
    /// The order is always OneHotGridSensor, GridSensorBase, SimpleTestGridSensor.
    /// The array is empty when no flag is set.
    /// </summary>
    protected override GridSensorBase[] GetGridSensors()
    {
        var sensors = new List<GridSensorBase>();

        if (m_UseOneHotTag)
        {
            sensors.Add(new OneHotGridSensor(SensorName, CellScale, GridSize, DetectableTags, CompressionType));
        }

        if (m_UseGridSensorBase)
        {
            sensors.Add(new GridSensorBase(SensorName, CellScale, GridSize, DetectableTags, CompressionType));
        }

        if (m_UseTestingGridSensor)
        {
            sensors.Add(new SimpleTestGridSensor(SensorName, CellScale, GridSize, DetectableTags, CompressionType));
        }

        return sensors.ToArray();
    }

    /// <summary>
    /// Configures the component for a test in one call.
    /// </summary>
    /// <param name="detectableTags">Assigned to DetectableTags as-is (may be null).</param>
    /// <param name="cellScaleX">X component of CellScale.</param>
    /// <param name="cellScaleZ">Z component of CellScale; the Y component is always 0.01f.</param>
    /// <param name="gridSizeX">X component of GridSize.</param>
    /// <param name="gridSizeY">Y component of GridSize.</param>
    /// <param name="gridSizeZ">Z component of GridSize.</param>
    /// <param name="colliderMaskInt">Raw collider mask; any negative value selects <c>LayerMask.GetMask("Default")</c> instead.</param>
    /// <param name="compression">Assigned to CompressionType.</param>
    /// <param name="rotateWithAgent">Assigned to RotateWithAgent.</param>
    /// <param name="useOneHotTag">Include a OneHotGridSensor in GetGridSensors.</param>
    /// <param name="useTestingGridSensor">Include a SimpleTestGridSensor in GetGridSensors.</param>
    /// <param name="useGridSensorBase">Include a plain GridSensorBase in GetGridSensors.</param>
    public void SetComponentParameters(
        string[] detectableTags = null,
        float cellScaleX = 1f,
        float cellScaleZ = 1f,
        int gridSizeX = 10,
        int gridSizeY = 1,
        int gridSizeZ = 10,
        int colliderMaskInt = -1,
        SensorCompressionType compression = SensorCompressionType.None,
        bool rotateWithAgent = false,
        bool useOneHotTag = false,
        bool useTestingGridSensor = false,
        bool useGridSensorBase = false
    )
    {
        // Inherited configuration, written in the same order as before.
        DetectableTags = detectableTags;
        CellScale = new Vector3(cellScaleX, 0.01f, cellScaleZ);
        GridSize = new Vector3Int(gridSizeX, gridSizeY, gridSizeZ);
        ColliderMask = colliderMaskInt >= 0 ? colliderMaskInt : LayerMask.GetMask("Default");
        RotateWithAgent = rotateWithAgent;
        CompressionType = compression;

        // Which sensors GetGridSensors will build.
        m_UseOneHotTag = useOneHotTag;
        m_UseGridSensorBase = useGridSensorBase;
        m_UseTestingGridSensor = useTestingGridSensor;
    }
}
}
|