File size: 4,489 Bytes
05c9ac2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
using System.Collections.Generic;
using UnityEngine;
using Unity.MLAgents.Sensors;

namespace Unity.MLAgents.Tests
{
    /// <summary>
    /// Shared static settings consulted by <see cref="SimpleTestGridSensor"/> when it reports its
    /// cell observation size, normalization flag and collider-processing method.
    /// Because the values are static, tests should call <see cref="Reset"/> to restore the defaults.
    /// </summary>
    public static class TestGridSensorConfig
    {
        /// <summary>Value returned as the per-cell observation size.</summary>
        public static int ObservationSize;

        /// <summary>Value returned as the "data is normalized" flag.</summary>
        public static bool IsNormalized;

        /// <summary>
        /// When true, selects <c>ProcessCollidersMethod.ProcessAllColliders</c>;
        /// otherwise <c>ProcessCollidersMethod.ProcessClosestColliders</c>.
        /// </summary>
        public static bool ParseAllColliders;

        /// <summary>Overwrites all three settings in one call.</summary>
        /// <param name="observationSize">New value for <see cref="ObservationSize"/>.</param>
        /// <param name="isNormalized">New value for <see cref="IsNormalized"/>.</param>
        /// <param name="parseAllColliders">New value for <see cref="ParseAllColliders"/>.</param>
        public static void SetParameters(int observationSize, bool isNormalized, bool parseAllColliders)
        {
            ObservationSize = observationSize;
            IsNormalized = isNormalized;
            ParseAllColliders = parseAllColliders;
        }

        /// <summary>Restores every setting to its default (0 / false / false).</summary>
        public static void Reset() => SetParameters(0, false, false);
    }

    /// <summary>
    /// Minimal <see cref="GridSensorBase"/> subclass for tests. Its cell size, normalization flag and
    /// collider-processing method come from <see cref="TestGridSensorConfig"/>. The data it writes for
    /// each object is the contents of <see cref="DummyData"/>.
    /// </summary>
    public class SimpleTestGridSensor : GridSensorBase
    {
        /// <summary>
        /// Values copied element-by-element into the data buffer by <see cref="GetObjectData"/>.
        /// It is dereferenced unconditionally, so it must be non-null and no longer than that buffer
        /// whenever <see cref="GetObjectData"/> runs.
        /// </summary>
        public float[] DummyData;

        /// <summary>Forwards every argument unchanged to the <see cref="GridSensorBase"/> constructor.</summary>
        public SimpleTestGridSensor(
            string name,
            Vector3 cellScale,
            Vector3Int gridSize,
            string[] detectableTags,
            SensorCompressionType compression)
            : base(name, cellScale, gridSize, detectableTags, compression)
        {
        }

        /// <summary>Reports <see cref="TestGridSensorConfig.ObservationSize"/>.</summary>
        protected override int GetCellObservationSize() => TestGridSensorConfig.ObservationSize;

        /// <summary>Reports <see cref="TestGridSensorConfig.IsNormalized"/>.</summary>
        protected override bool IsDataNormalized() => TestGridSensorConfig.IsNormalized;

        /// <summary>Maps <see cref="TestGridSensorConfig.ParseAllColliders"/> onto a processing method.</summary>
        protected internal override ProcessCollidersMethod GetProcessCollidersMethod()
        {
            if (TestGridSensorConfig.ParseAllColliders)
            {
                return ProcessCollidersMethod.ProcessAllColliders;
            }

            return ProcessCollidersMethod.ProcessClosestColliders;
        }

        /// <summary>
        /// Writes <see cref="DummyData"/> into the start of <paramref name="dataBuffer"/>. The detected
        /// object and type index are ignored.
        /// </summary>
        protected override void GetObjectData(GameObject detectedObject, int typeIndex, float[] dataBuffer)
        {
            // Manual copy on purpose: a short buffer or null DummyData raises the same exceptions as before.
            var pos = 0;
            while (pos < DummyData.Length)
            {
                dataBuffer[pos] = DummyData[pos];
                pos++;
            }
        }
    }

    /// <summary>
    /// Test-only <see cref="GridSensorComponent"/>. Flags set through
    /// <see cref="SetComponentParameters"/> choose which grid sensors it creates.
    /// </summary>
    public class SimpleTestGridSensorComponent : GridSensorComponent
    {
        // Selection flags read by GetGridSensors(); all start out false.
        bool m_UseOneHotTag;
        bool m_UseTestingGridSensor;
        bool m_UseGridSensorBase;

        /// <summary>
        /// Creates the enabled sensors in a fixed order: one-hot, then base, then testing sensor.
        /// Each one uses this component's name, cell scale, grid size, tags and compression.
        /// </summary>
        /// <returns>An array of zero to three sensors.</returns>
        protected override GridSensorBase[] GetGridSensors()
        {
            var sensors = new List<GridSensorBase>();

            if (m_UseOneHotTag)
            {
                sensors.Add(new OneHotGridSensor(SensorName, CellScale, GridSize, DetectableTags, CompressionType));
            }

            if (m_UseGridSensorBase)
            {
                sensors.Add(new GridSensorBase(SensorName, CellScale, GridSize, DetectableTags, CompressionType));
            }

            if (m_UseTestingGridSensor)
            {
                sensors.Add(new SimpleTestGridSensor(SensorName, CellScale, GridSize, DetectableTags, CompressionType));
            }

            return sensors.ToArray();
        }

        /// <summary>
        /// Configures the inherited grid settings and selects which sensors
        /// <see cref="GetGridSensors"/> will build.
        /// </summary>
        /// <param name="detectableTags">Tags to detect; null is the default and is passed through as-is.</param>
        /// <param name="cellScaleX">Cell scale along X.</param>
        /// <param name="cellScaleZ">Cell scale along Z; the Y component is always 0.01.</param>
        /// <param name="gridSizeX">Grid size along X.</param>
        /// <param name="gridSizeY">Grid size along Y.</param>
        /// <param name="gridSizeZ">Grid size along Z.</param>
        /// <param name="colliderMaskInt">Layer mask bits; a negative value (the default) selects the "Default" layer.</param>
        /// <param name="compression">Compression type assigned to the component.</param>
        /// <param name="rotateWithAgent">Value assigned to RotateWithAgent.</param>
        /// <param name="useOneHotTag">Create a <c>OneHotGridSensor</c>.</param>
        /// <param name="useTestingGridSensor">Create a <see cref="SimpleTestGridSensor"/>.</param>
        /// <param name="useGridSensorBase">Create a plain <see cref="GridSensorBase"/>.</param>
        public void SetComponentParameters(
            string[] detectableTags = null,
            float cellScaleX = 1f,
            float cellScaleZ = 1f,
            int gridSizeX = 10,
            int gridSizeY = 1,
            int gridSizeZ = 10,
            int colliderMaskInt = -1,
            SensorCompressionType compression = SensorCompressionType.None,
            bool rotateWithAgent = false,
            bool useOneHotTag = false,
            bool useTestingGridSensor = false,
            bool useGridSensorBase = false
        )
        {
            DetectableTags = detectableTags;
            CellScale = new Vector3(cellScaleX, 0.01f, cellScaleZ);
            GridSize = new Vector3Int(gridSizeX, gridSizeY, gridSizeZ);

            int layerBits;
            if (colliderMaskInt >= 0)
            {
                layerBits = colliderMaskInt;
            }
            else
            {
                // No explicit mask supplied: fall back to Unity's "Default" layer.
                layerBits = LayerMask.GetMask("Default");
            }
            ColliderMask = layerBits;

            RotateWithAgent = rotateWithAgent;
            CompressionType = compression;
            m_UseOneHotTag = useOneHotTag;
            m_UseGridSensorBase = useGridSensorBase;
            m_UseTestingGridSensor = useTestingGridSensor;
        }
    }
}