Spaces:
Building
Building
# Copyright 2024 The etils Authors. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
"""Utils for colab/jupyter. | |
Usage: | |
```python | |
from etils.ecolab import array_as_img | |
``` | |
""" | |
from __future__ import annotations | |
import functools | |
import traceback | |
from typing import Any, Optional, Tuple | |
from etils import enp | |
from etils.epy import _internal | |
with _internal.check_missing_deps(): | |
# pylint: disable=g-import-not-at-top | |
import IPython | |
import IPython.display | |
import mediapy as media | |
# pylint: enable=g-import-not-at-top | |
Array = Any | |
# Images smaller than this are displayed as text (e.g. `np.array([...`) | |
_MIN_IMG_SHAPE: Tuple[int, int] = (10, 10) | |
def show(*objs, **kwargs) -> None:
  """Displays `objs` in the notebook.

  Thin alias of `IPython.display.display`: all positional and keyword
  arguments are forwarded unchanged, and whatever `display` returns is
  returned to the caller.
  """
  display_fn = IPython.display.display
  return display_fn(*objs, **kwargs)
def auto_plot_array(
    *,
    # If updating this, also update `_array_repr_html_inner` !!!
    video_min_num_frames: int = 15,
    # Images outside this range are rescaled
    height: None | int | tuple[int, int] = (100, 250),
    show_images_kwargs: Optional[dict[str, Any]] = None,
    show_videos_kwargs: Optional[dict[str, Any]] = None,
) -> None:
  """Makes colab/jupyter render 2d/3d/4d image arrays as images or videos.

  Registers `array_repr_html` as the `text/html` formatter for the array
  types of the available frameworks (jax, TF, torch, numpy). Does nothing
  when not running inside an IPython session.

  Usage:

  >>> ecolab.auto_plot_array()
  >>> np.zeros((28, 28, 3))  # Displayed as image

  Args:
    video_min_num_frames: Video `(num_frames, h, w, c)` with less than this
      number of frames will be displayed as individual images
    height: `(min, max)` image height in pixels. Images smaller/larger will be
      reshaped. `None` to disable. If a single number, assume `min == max`.
    show_images_kwargs: Kwargs forwarded to `mediapy.show_images`
    show_videos_kwargs: Kwargs forwarded to `mediapy.show_videos`
  """
  shell = IPython.get_ipython()
  if shell is None:
    return  # Non-notebook environment

  # Bind the display options once; the formatter only passes the array.
  html_repr_fn = functools.partial(
      array_repr_html,
      video_min_num_frames=video_min_num_frames,
      height=height,
      show_images_kwargs=show_images_kwargs,
      show_videos_kwargs=show_videos_kwargs,
  )

  # Register the new representation for np, tf, torch and jax arrays
  print('Display big np/tf/jax arrays as image for nicer IPython display')
  html_formatter = shell.display_formatter.formatters['text/html']

  def _register(array_cls) -> None:
    """Uses `html_repr_fn` as HTML repr for instances of `array_cls`."""
    html_formatter.for_type(array_cls, html_repr_fn)

  # TODO(epot): How to support lazy-imports without catching everything ?

  # Jax (optional dependency)
  try:
    jnp = enp.lazy.jnp
  except ImportError:
    pass
  else:
    # The array type is not exposed in the public API (registering
    # `jnp.ndarray` does not work), so the concrete class is extracted
    # dynamically from an actual array instance.
    _register(type(jnp.zeros(shape=())))  # DeviceArrayBase

  # TensorFlow (optional dependency)
  try:
    tf = enp.lazy.tf
  except ImportError:
    pass
  else:
    _register(tf.Tensor)

  # Torch (optional dependency)
  try:
    torch = enp.lazy.torch
  except ImportError:
    pass
  else:
    _register(torch.Tensor)

  # Numpy is always registered
  _register(enp.lazy.np.ndarray)
def array_repr_html(
    array: Array,
    **kwargs: Any,
) -> Optional[str]:
  """Returns the HTML repr of `array`, or `None` if it is not an image.

  Args:
    array: The object to display.
    **kwargs: Display options, forwarded to `_array_repr_html_inner`.

  Returns:
    The HTML string, or `None` when `array` should not be rendered as image.

  Raises:
    Exception: Any error raised while building the repr is re-raised after
      its traceback has been printed.
  """
  try:
    html = _array_repr_html_inner(array, **kwargs)
  except Exception:
    # IPython display silences exceptions, so print the traceback here
    # before propagating the error.
    traceback.print_exc()
    raise
  else:
    return html
def _array_repr_html_inner(
    img: Array,
    *,
    # If updating this, also update `auto_plot_array` !!!
    video_min_num_frames: int = 15,
    height: None | int | tuple[int, int] = (100, 250),
    show_images_kwargs: Optional[dict[str, Any]] = None,
    show_videos_kwargs: Optional[dict[str, Any]] = None,
) -> Optional[str]:
  """Returns the HTML repr of an image/video array, or `None` otherwise.

  The array is interpreted from its rank:

  * 2d: `(h, w)` single-channel image
  * 3d: `(h, w, c)` image
  * 4d: `(num_frames, h, w, c)` batch of images or video

  Args:
    img: The object to display.
    video_min_num_frames: 4d arrays with less than this number of frames are
      displayed as individual images rather than as a video.
    height: `(min, max)` display height in pixels; the image height is clamped
      to this range. A single int means `min == max`. Falsy to disable.
    show_images_kwargs: Kwargs forwarded to `mediapy.show_image(s)`. An
      explicit `height` entry takes precedence over the computed one.
    show_videos_kwargs: Kwargs forwarded to `mediapy.show_video`. Explicit
      `height` / `fps` entries take precedence over the computed ones.

  Returns:
    The HTML string, or `None` if the input is not an array, is empty, is
    smaller than `_MIN_IMG_SHAPE`, has an unsupported rank or channel count,
    or is a video which does not have exactly 3 channels.
  """
  if not enp.lazy.is_array(img):  # Not an array
    return None

  # Normalize framework tensors into np.array
  if enp.lazy.is_tf(img):
    img = img.numpy()
  elif enp.lazy.is_torch(img):
    # `Tensor.numpy()` raises for tensors living on an accelerator or
    # requiring grad, so detach and move to host memory first.
    img = img.detach().cpu().numpy()

  shape = img.shape
  ndim = len(shape)

  # Infer the array type (image or video ?)
  num_frames = None  # Only meaningful for 4d inputs
  if ndim == 2:
    img_shape = shape
    num_channel = 1
  elif ndim == 3:
    img_shape = shape[:2]
    num_channel = shape[-1]
  elif ndim == 4:
    img_shape = shape[1:3]
    num_channel = shape[-1]
    num_frames = shape[0]
  else:
    return None

  # Filter non-images
  if 0 in shape:  # Empty image
    return None
  if _smaller_than(img_shape, _MIN_IMG_SHAPE):
    return None
  if num_channel not in {1, 3, 4}:
    return None

  # Work on copies so the caller's dicts (shared across calls through
  # `functools.partial`) are never mutated by `setdefault` below.
  show_images_kwargs = dict(show_images_kwargs or {})
  show_videos_kwargs = dict(show_videos_kwargs or {})

  # Resize small/large images to X pixels (otherwise, difficult to see)
  if height:
    if isinstance(height, int):
      min_height = max_height = height
    else:
      min_height, max_height = height
    # Clamp the image height `img_shape[0]` into `[min_height, max_height]`
    target_height = min(max(img_shape[0], min_height), max_height)
    show_images_kwargs.setdefault('height', target_height)
    show_videos_kwargs.setdefault('height', target_height)

  if ndim < 4:
    return media.show_image(img, return_html=True, **show_images_kwargs)
  if num_frames < video_min_num_frames:
    return media.show_images(img, return_html=True, **show_images_kwargs)

  # TODO(epot): media.show_video does not support single channel video
  if num_channel != 3:
    return None
  # Dynamically compute the frame-rate, capped at 25 FPS. The lower bound
  # avoids `fps == 0` when `video_min_num_frames` is set below 5.
  fps = max(1, min(num_frames // 5, 25.0))
  show_videos_kwargs.setdefault('fps', fps)
  return media.show_video(
      img,
      return_html=True,
      **show_videos_kwargs,
  )
def _smaller_than(shape: tuple[int, ...], min_shape: tuple[int, ...]) -> bool: | |
"""Returns True if one of the dim of `shape` is smaller than `min_shape`.""" | |
return any(dim < min_dim for dim, min_dim in zip(shape, min_shape)) | |