File size: 16,415 Bytes
e11e4fe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
from mlagents_envs.base_env import (
    ActionSpec,
    ObservationSpec,
    DimensionProperty,
    BehaviorSpec,
    DecisionSteps,
    TerminalSteps,
    ObservationType,
)
from mlagents_envs.exception import UnityObservationException
from mlagents_envs.timers import hierarchical_timer, timed
from mlagents_envs.communicator_objects.agent_info_pb2 import AgentInfoProto
from mlagents_envs.communicator_objects.observation_pb2 import (
    ObservationProto,
    NONE as COMPRESSION_TYPE_NONE,
)
from mlagents_envs.communicator_objects.brain_parameters_pb2 import BrainParametersProto
import numpy as np
import io
from typing import cast, List, Tuple, Collection, Optional, Iterable
from PIL import Image


PNG_HEADER = b"\x89PNG\r\n\x1a\n"


def behavior_spec_from_proto(
    brain_param_proto: BrainParametersProto, agent_info: AgentInfoProto
) -> BehaviorSpec:
    """
    Builds a BehaviorSpec from a brain parameters proto and an agent info proto.
    The observation specs are read from the observations of the agent info,
    the action spec is read from the brain parameters.
    :param brain_param_proto: protobuf object providing the action description.
    :param agent_info: protobuf object whose observations define the observation specs.
    :return: BehaviorSpec object.
    """

    def _dimension_properties(obs_proto):
        # Observations that declare no dimension properties get one
        # UNSPECIFIED entry per dimension of their shape.
        declared = obs_proto.dimension_properties
        if len(declared) > 0:
            return tuple(DimensionProperty(dim) for dim in declared)
        return (DimensionProperty.UNSPECIFIED,) * len(obs_proto.shape)

    observation_specs = [
        ObservationSpec(
            name=obs.name,
            shape=tuple(obs.shape),
            observation_type=ObservationType(obs.observation_type),
            dimension_property=_dimension_properties(obs),
        )
        for obs in agent_info.observations
    ]

    spec_proto = brain_param_proto.action_spec
    has_action_spec = (
        spec_proto.num_continuous_actions != 0
        or spec_proto.num_discrete_actions != 0
    )
    if has_action_spec:
        action_spec = ActionSpec(
            spec_proto.num_continuous_actions,
            tuple(spec_proto.discrete_branch_sizes),
        )
    elif brain_param_proto.vector_action_space_type_deprecated == 1:
        # proto from communicator < v1.3 does not set action spec, use deprecated
        # fields instead: here only the first deprecated size is used and there
        # are no discrete branches.
        action_spec = ActionSpec(
            brain_param_proto.vector_action_size_deprecated[0], ()
        )
    else:
        # Deprecated fields again: every deprecated size becomes a discrete branch.
        action_spec = ActionSpec(
            0, tuple(brain_param_proto.vector_action_size_deprecated)
        )
    return BehaviorSpec(observation_specs, action_spec)


class OffsetBytesIO:
    """
    Minimal read-only, file-like view over a bytes object whose logical "start"
    can be moved by assigning to ``offset``. seek() and tell() are expressed
    relative to that start. This is only used for reading concatenated PNGs,
    because Pillow always calls seek(0) at the start of reading.
    """

    __slots__ = ["fp", "offset"]

    def __init__(self, data: bytes):
        # offset is the position in the wrapped bytes that is treated as position 0.
        self.offset = 0
        self.fp = io.BytesIO(data)

    def seek(self, offset: int, whence: int = io.SEEK_SET) -> int:
        """
        Moves to ``offset`` relative to the current start and returns the new
        relative position. Only absolute seeks (io.SEEK_SET) are supported.
        """
        if whence != io.SEEK_SET:
            raise NotImplementedError()
        absolute_position = self.fp.seek(self.offset + offset)
        return absolute_position - self.offset

    def tell(self) -> int:
        """
        Returns the current position relative to the current start
        """
        return self.original_tell() - self.offset

    def read(self, size: int = -1) -> bytes:
        """
        Reads up to ``size`` bytes from the current position of the wrapped bytes
        """
        return self.fp.read(size)

    def original_tell(self) -> int:
        """
        Returns the offset into the original byte array
        """
        return self.fp.tell()


@timed
def process_pixels(
    image_bytes: bytes, expected_channels: int, mappings: Optional[List[int]] = None
) -> np.ndarray:
    """
    Decodes a byte array holding one or more concatenated PNG images into a
    single float32 numpy array with pixel values divided by 255.
    :param image_bytes: input byte array corresponding to image(s)
    :param expected_channels: Expected output channels, used when no mapping is given
    :param mappings: optional compressed channel mapping; when it is non-empty the
        decoded channels are combined according to it instead of expected_channels
    :return: processed numpy array of observation from environment
    """
    stream = OffsetBytesIO(image_bytes)
    decoded_images = []

    # The sizes of the individual images are not known up front: decode one image,
    # then search for the next PNG header after the bytes consumed so far.
    while True:
        with hierarchical_timer("image_decompress"):
            png = Image.open(stream)
            # Normally Image loads lazily, load() forces it to do loading in the timer scope.
            png.load()
        decoded_images.append(np.array(png, dtype=np.float32) / 255.0)

        next_header = image_bytes.find(PNG_HEADER, stream.original_tell())
        if next_header < 0:
            # Didn't find the header, so must be at the end.
            break
        stream.offset = next_header

    if mappings is None or len(mappings) == 0:
        return _process_images_num_channels(decoded_images, expected_channels)
    return _process_images_mapping(decoded_images, mappings)


def _process_images_mapping(image_arrays, mappings):
    """
    Helper function for processing decompressed images with compressed channel mappings.
    """
    image_arrays = np.concatenate(image_arrays, axis=2).transpose((2, 0, 1))

    if len(mappings) != len(image_arrays):
        raise UnityObservationException(
            f"Compressed observation and its mapping had different number of channels - "
            f"observation had {len(image_arrays)} channels but its mapping had {len(mappings)} channels"
        )
    if len({m for m in mappings if m > -1}) != max(mappings) + 1:
        raise UnityObservationException(
            f"Invalid Compressed Channel Mapping: the mapping {mappings} does not have the correct format."
        )
    if max(mappings) >= len(image_arrays):
        raise UnityObservationException(
            f"Invalid Compressed Channel Mapping: the mapping has index larger than the total "
            f"number of channels in observation - mapping index {max(mappings)} is"
            f"invalid for input observation with {len(image_arrays)} channels."
        )

    processed_image_arrays: List[np.array] = [[] for _ in range(max(mappings) + 1)]
    for mapping_idx, img in zip(mappings, image_arrays):
        if mapping_idx > -1:
            processed_image_arrays[mapping_idx].append(img)

    for i, img_array in enumerate(processed_image_arrays):
        processed_image_arrays[i] = np.mean(img_array, axis=0)
    img = np.stack(processed_image_arrays, axis=2)
    return img


def _process_images_num_channels(image_arrays, expected_channels):
    """
    Helper function for processing decompressed images with number of expected channels.
    This is for old API without mapping provided. Use the first n channel, n=expected_channels.
    """
    if expected_channels == 1:
        # Convert to grayscale
        img = np.mean(image_arrays[0], axis=2)
        img = np.reshape(img, [img.shape[0], img.shape[1], 1])
    else:
        img = np.concatenate(image_arrays, axis=2)
        # We can drop additional channels since they may need to be added to include
        # numbers of observation channels not divisible by 3.
        actual_channels = list(img.shape)[2]
        if actual_channels > expected_channels:
            img = img[..., 0:expected_channels]
    return img


def _check_observations_match_spec(
    obs_index: int,
    observation_spec: ObservationSpec,
    agent_info_list: Collection[AgentInfoProto],
) -> None:
    """
    Verifies that the observation at obs_index of every agent has the shape
    declared by the ObservationSpec.
    Raising here gives a nicer error than a cryptic numpy error later.
    :raises UnityObservationException: for the first agent whose shape differs.
    """
    wanted_shape = tuple(observation_spec.shape)
    for agent_info in agent_info_list:
        actual_shape = tuple(agent_info.observations[obs_index].shape)
        if actual_shape == wanted_shape:
            continue
        raise UnityObservationException(
            f"Observation at index={obs_index} for agent with "
            f"id={agent_info.id} didn't match the ObservationSpec. "
            f"Expected shape {wanted_shape} but got {actual_shape}."
        )


@timed
def _observation_to_np_array(
    obs: ObservationProto, expected_shape: Optional[Iterable[int]] = None
) -> np.ndarray:
    """
    Turns an observation proto into a numpy array with the shape declared by the proto.
    Uncompressed observations are read from the float data, compressed ones are
    decoded through process_pixels.
    :param obs: observation proto to be converted
    :param expected_shape: optional shape information, used for sanity checks.
    :return: processed numpy array of observation from environment
    :raises UnityObservationException: if the proto shape differs from expected_shape,
        or if the decompressed image does not have the proto's shape.
    """
    if expected_shape is not None and list(obs.shape) != list(expected_shape):
        raise UnityObservationException(
            f"Observation did not have the expected shape - got {obs.shape} but expected {expected_shape}"
        )

    # The third entry of the shape is taken as the number of channels.
    expected_channels = obs.shape[2]

    if obs.compression_type == COMPRESSION_TYPE_NONE:
        flat_values = np.array(obs.float_data.data, dtype=np.float32)
        return np.reshape(flat_values, obs.shape)

    decoded = process_pixels(
        obs.compressed_data, expected_channels, list(obs.compressed_channel_mapping)
    )
    # Compare decompressed image size to observation shape and make sure they match
    if list(decoded.shape) != list(obs.shape):
        raise UnityObservationException(
            f"Decompressed observation did not have the expected shape - "
            f"decompressed had {decoded.shape} but expected {obs.shape}"
        )
    return decoded


@timed
def _process_maybe_compressed_observation(
    obs_index: int,
    observation_spec: ObservationSpec,
    agent_info_list: Collection[AgentInfoProto],
) -> np.ndarray:
    """
    Batches the rank-3 observation at obs_index of every agent into one float32 array.
    :param obs_index: index of the observation within each agent's observations.
    :param observation_spec: spec whose shape every observation is checked against.
    :param agent_info_list: agents to read the observation from.
    :return: array with one leading entry per agent; for no agents, an empty
        array of shape (0, shape[0], shape[1], shape[2]).
    """
    shape = cast(Tuple[int, int, int], observation_spec.shape)
    if len(agent_info_list) == 0:
        empty_batch_shape = (0, shape[0], shape[1], shape[2])
        return np.zeros(empty_batch_shape, dtype=np.float32)

    batched_visual = []
    try:
        for agent_obs in agent_info_list:
            batched_visual.append(
                _observation_to_np_array(agent_obs.observations[obs_index], shape)
            )
    except ValueError:
        # Try to get a more useful error message
        _check_observations_match_spec(obs_index, observation_spec, agent_info_list)
        # If that didn't raise anything, raise the original error
        raise
    return np.array(batched_visual, dtype=np.float32)


def _raise_on_nan_and_inf(data: np.array, source: str) -> np.array:
    # Check for NaNs or Infinite values in the observation or reward data.
    # If there's a NaN in the observations, the np.mean() result will be NaN
    # If there's an Infinite value (either sign) then the result will be Inf
    # See https://stackoverflow.com/questions/6736590/fast-check-for-nan-in-numpy for background
    # Note that a very large values (larger than sqrt(float_max)) will result in an Inf value here
    # Raise a Runtime error in the case that NaNs or Infinite values make it into the data.
    if data.size == 0:
        return data

    d = np.mean(data)
    has_nan = np.isnan(d)
    has_inf = not np.isfinite(d)

    if has_nan:
        raise RuntimeError(f"The {source} provided had NaN values.")
    if has_inf:
        raise RuntimeError(f"The {source} provided had Infinite values.")


@timed
def _process_rank_one_or_two_observation(
    obs_index: int,
    observation_spec: ObservationSpec,
    agent_info_list: Collection[AgentInfoProto],
) -> np.ndarray:
    """
    Batches the float data of the observation at obs_index of every agent into
    one float32 array of shape (number of agents,) + observation_spec.shape.
    :param obs_index: index of the observation within each agent's observations.
    :param observation_spec: spec providing the per-agent observation shape.
    :param agent_info_list: agents to read the observation from.
    :return: the batched observations; an empty batch when there are no agents.
    :raises RuntimeError: if the batched observations contain NaN or infinite values.
    """
    num_agents = len(agent_info_list)
    if num_agents == 0:
        return np.zeros((0,) + observation_spec.shape, dtype=np.float32)

    try:
        per_agent_values = [
            agent_obs.observations[obs_index].float_data.data
            for agent_obs in agent_info_list
        ]
        np_obs = np.array(per_agent_values, dtype=np.float32).reshape(
            (num_agents,) + observation_spec.shape
        )
    except ValueError:
        # Try to get a more useful error message
        _check_observations_match_spec(obs_index, observation_spec, agent_info_list)
        # If that didn't raise anything, raise the original error
        raise

    _raise_on_nan_and_inf(np_obs, "observations")
    return np_obs


@timed
def steps_from_proto(
    agent_info_list: Collection[AgentInfoProto], behavior_spec: BehaviorSpec
) -> Tuple[DecisionSteps, TerminalSteps]:
    """
    Converts a batch of agent info protos into DecisionSteps and TerminalSteps.
    Agents whose ``done`` flag is set go into the TerminalSteps, all other agents
    go into the DecisionSteps.
    :param agent_info_list: agent info protos to convert.
    :param behavior_spec: spec providing the observation specs and the action spec
        used to interpret the protos.
    :return: tuple of (DecisionSteps, TerminalSteps).
    :raises RuntimeError: if rewards or group rewards are detected to contain NaN
        or infinite values.
    """
    decision_agent_info_list = [
        agent_info for agent_info in agent_info_list if not agent_info.done
    ]
    terminal_agent_info_list = [
        agent_info for agent_info in agent_info_list if agent_info.done
    ]
    decision_obs_list: List[np.ndarray] = []
    terminal_obs_list: List[np.ndarray] = []
    for obs_index, observation_spec in enumerate(behavior_spec.observation_specs):
        # Rank-3 observations go through the (possibly compressed) visual path,
        # everything else is read directly from the float data.
        if len(observation_spec.shape) == 3:
            process_observation = _process_maybe_compressed_observation
        else:
            process_observation = _process_rank_one_or_two_observation
        decision_obs_list.append(
            process_observation(obs_index, observation_spec, decision_agent_info_list)
        )
        terminal_obs_list.append(
            process_observation(obs_index, observation_spec, terminal_agent_info_list)
        )

    decision_rewards = np.array(
        [agent_info.reward for agent_info in decision_agent_info_list], dtype=np.float32
    )
    terminal_rewards = np.array(
        [agent_info.reward for agent_info in terminal_agent_info_list], dtype=np.float32
    )

    decision_group_rewards = np.array(
        [agent_info.group_reward for agent_info in decision_agent_info_list],
        dtype=np.float32,
    )
    terminal_group_rewards = np.array(
        [agent_info.group_reward for agent_info in terminal_agent_info_list],
        dtype=np.float32,
    )

    _raise_on_nan_and_inf(decision_rewards, "rewards")
    _raise_on_nan_and_inf(terminal_rewards, "rewards")
    _raise_on_nan_and_inf(decision_group_rewards, "group_rewards")
    _raise_on_nan_and_inf(terminal_group_rewards, "group_rewards")

    decision_group_id = [agent_info.group_id for agent_info in decision_agent_info_list]
    terminal_group_id = [agent_info.group_id for agent_info in terminal_agent_info_list]

    max_step = np.array(
        [agent_info.max_step_reached for agent_info in terminal_agent_info_list],
        dtype=bool,
    )
    decision_agent_id = np.array(
        [agent_info.id for agent_info in decision_agent_info_list], dtype=np.int32
    )
    terminal_agent_id = np.array(
        [agent_info.id for agent_info in terminal_agent_info_list], dtype=np.int32
    )

    action_mask = None
    if behavior_spec.action_spec.discrete_size > 0:
        # The condition itself is tested here. Wrapping it in a list would make
        # every element truthy regardless of the condition's value.
        if any(
            agent_info.action_mask is not None
            for agent_info in decision_agent_info_list
        ):
            n_agents = len(decision_agent_info_list)
            a_size = np.sum(behavior_spec.action_spec.discrete_branches)
            # One row per agent over all discrete actions. Rows stay all-False for
            # agents whose mask is missing or does not have exactly a_size entries.
            mask_matrix = np.zeros((n_agents, a_size), dtype=bool)
            for agent_index, agent_info in enumerate(decision_agent_info_list):
                agent_mask = agent_info.action_mask
                if agent_mask is not None and len(agent_mask) == a_size:
                    mask_matrix[agent_index, :] = [bool(flag) for flag in agent_mask]
            # Split the flat mask into one array per discrete branch.
            indices = _generate_split_indices(
                behavior_spec.action_spec.discrete_branches
            )
            action_mask = np.split(mask_matrix, indices, axis=1)
    return (
        DecisionSteps(
            decision_obs_list,
            decision_rewards,
            decision_agent_id,
            action_mask,
            decision_group_id,
            decision_group_rewards,
        ),
        TerminalSteps(
            terminal_obs_list,
            terminal_rewards,
            max_step,
            terminal_agent_id,
            terminal_group_id,
            terminal_group_rewards,
        ),
    )


def _generate_split_indices(dims):
    if len(dims) <= 1:
        return ()
    result = (dims[0],)
    for i in range(len(dims) - 2):
        result += (dims[i + 1] + result[i],)
    return result