File size: 14,241 Bytes
f5f3483
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
# Copyright 2024 The etils Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Auto display statements ending with `;` on Colab."""

from __future__ import annotations

import ast
import dataclasses
import enum
import functools
import re
import traceback
import typing
from typing import Any, Callable, TypeVar

from etils import epy
from etils.ecolab import array_as_img
from etils.ecolab import highlight_util
from etils.ecolab.inspects import core as inspects
from etils.etree import jax as etree  # pylint: disable=g-importing-member
import IPython
import packaging
import packaging.version

# Generic type variable, used to annotate helpers returning their input type.
_T = TypeVar('_T')
# Type alias for a callable taking any object and returning `None`.
_DisplayFn = Callable[[Any], None]

# Old API (before IPython 7.0)
# `True` when the installed IPython version is strictly lower than `7`. This
# flag selects which transformer lists are used to register the line recorder.
_IS_LEGACY_API = packaging.version.parse(
    IPython.__version__
) < packaging.version.parse('7')


class _Options(enum.StrEnum):
  """Available options."""

  SPEC = 's'
  INSPECT = 'i'
  ARRAY = 'a'
  PPRINT = 'p'
  SYNTAX_HIGHLIGHT = 'h'
  LINE = 'l'
  QUIET = 'q'

  @classmethod
  @property
  def all_letters(cls) -> set[str]:
    return {option.value for option in cls}


def auto_display(activate: bool = True) -> None:
  r"""Toggles auto-display of statements ending with `;` (on by default).

  Appending a trailing `;` to a statement (assignment, expression, `return`
  or `assert` statement) displays the value of that line.
  This calls `IPython.display.display()` for pretty display.

  This changes the default IPython behavior, where a `;` added to the last
  statement of the cell silences the output.

  ```python
  x = my_fn();  # Display `my_fn()`

  my_fn();  # Display `my_fn()`
  ```

  The `;` behavior can be disabled with `ecolab.auto_display(False)`.

  Format:

  *   `my_obj;`: Alias for `IPython.display.display(my_obj)`
  *   `my_obj;s`: (`spec`) Alias for
      `IPython.display.display(etree.spec_like(my_obj))`
  *   `my_obj;i`: (`inspect`) Alias for `ecolab.inspect(my_obj)`
  *   `my_obj;a`: (`array`) Alias for `media.show_images(my_obj)` /
      `media.show_videos(my_obj)` (`ecolab.auto_plot_array` behavior)
  *   `my_obj;p`: (`pretty_display`) Alias for
      `print(epy.pretty_repr(my_obj))`. Can be combined with `s`. Used to
      pretty print `dataclasses` or to print strings containing new lines
      (rather than displaying `\n`).
  *   `my_obj;h`: (`syntax_highlight`) Add Python code syntax highlighting (
      using `ecolab.highlight_html`)
  *   `my_obj;q`: (`quiet`) Don't display the line (e.g. last line)
  *   `my_obj;l`: (`line`) Also display the line (can be combined with previous
      statements). Has to be at the end (`;sl` is valid but not `;ls`).

  `p`, `s`, `h`, `l` can be combined.

  A functional API exists:

  ```python
  ecolab.disp(obj, mode='sp')  # Equivalent to `obj;sp`
  ```

  Args:
    activate: Set to `False` to disable `auto_display`
  """
  if not epy.is_notebook():  # Nothing to register outside of a notebook
    return

  ip = IPython.get_ipython()
  shell = ip.kernel.shell

  # Remove any transformation left by a previous call, so that reloading and
  # disabling both work.
  _clear_transform(shell.ast_transformers)
  if _IS_LEGACY_API:
    _clear_transform(shell.input_transformer_manager.python_line_transforms)
    _clear_transform(shell.input_splitter.python_line_transforms)
  else:
    _clear_transform(ip.input_transformers_post)

  if activate:
    # The `ast` drops the trailing `;` information and `ast_transformers` have
    # no access to the original source code, so 2 cooperating hooks are used:
    # 1) A line transformer which only records the source code lines
    # 2) An `ast` transformer which reads those lines to add the display calls
    recorder = _RecordLines()

    if _IS_LEGACY_API:
      line_hooks = (
          shell.input_splitter.python_line_transforms,
          shell.input_transformer_manager.python_line_transforms,
      )
    else:
      line_hooks = (ip.input_transformers_post,)

    for hooks in line_hooks:
      _append_transform(hooks, recorder)
    _append_transform(shell.ast_transformers, _AddDisplayStatement(recorder))


def _clear_transform(list_):
  # Support `ecolab` adhoc `reload`
  for elem in list(list_):
    if hasattr(elem, '__is_ecolab__'):
      list_.remove(elem)


def _append_transform(list_, transform):
  """Append `transform` to `list_` (with `ecolab` reload support)."""
  transform.__is_ecolab__ = True  # Marker to support `ecolab` adhoc `reload`
  list_.append(transform)


if _IS_LEGACY_API and not typing.TYPE_CHECKING:

  class _RecordLines(IPython.core.inputtransformer.InputTransformer):
    """Records the source lines (legacy `InputTransformer` implementation)."""

    def __init__(self):
      # Lines pushed since the last `reset()`
      self._lines = []
      # Lines of the last non-empty recording, split on `\n`
      self.last_lines = []

      # Maps line number -> line info for the lines with a trailing statement
      self.trailing_stmt_line_nums = {}
      super().__init__()

    def push(self, line):
      """Records `line` and returns it unchanged."""
      self._lines.append(line)
      return line

    def reset(self):
      """Saves the recorded lines into `last_lines` and clears the state."""
      pending = self._lines
      # `reset` is called multiple times on empty output, so `last_lines` is
      # only overwritten when something was actually recorded.
      if pending:
        self.last_lines = [
            piece for chunk in pending for piece in chunk.split('\n')
        ]
      pending.clear()

      self.trailing_stmt_line_nums.clear()

else:

  class _RecordLines:
    """Records the source lines (callable transformer implementation)."""

    def __init__(self):
      # Lines of the last cell, without their trailing `\n`
      self.last_lines = []
      # Additional state (reset at each cell) to keep track of which lines
      # contain trailing statements
      self.trailing_stmt_line_nums = {}

    def __call__(self, lines: list[str]) -> list[str]:
      """Records `lines` and returns them unchanged."""
      self.last_lines = [text.rstrip('\n') for text in lines]
      self.trailing_stmt_line_nums = dict()
      return lines


# TODO(epot): During the first parsing stage, we could implement fault-tolerant
# parsing to support `line;=` rather than `line;l`.
# Something like the sketch below, except that it would have to chunk the code
# into blocks of valid statements:
# def fault_tolerant_parsing(code: str):
#   lines = code.split('\n')
#   for i in range(len(lines)):
#     try:
#       last_valid = ast.parse('\n'.join(lines[:i]))
#     except SyntaxError:
#       break

#   return ast.unparse(last_valid)


def _reraise_error(fn: _T) -> _T:
  @functools.wraps(fn)
  def decorated(self, node: ast.AST):
    try:
      return fn(self, node)  # pytype: disable=wrong-arg-types
    except Exception as e:  # pylint: disable=broad-exception-caught
      code = '\n'.join(self.lines_recorder.last_lines)
      print(f'Error for code:\n-----\n{code}\n-----')
      traceback.print_exception(e)

  return decorated


class _AddDisplayStatement(ast.NodeTransformer):
  """`ast` transformer wrapping the `;`-terminated statements in a display call.

  The trailing `;` is not part of the `ast`, so it is detected from the source
  lines saved by the `lines_recorder`.
  """

  def __init__(self, lines_recorder: _RecordLines):
    self.lines_recorder = lines_recorder
    super().__init__()

  def _maybe_display(
      self, node: ast.Assign | ast.AnnAssign | ast.Expr
  ) -> ast.AST:
    """Wraps `node.value` in a `_display_and_return()` call when requested.

    Args:
      node: Statement exposing its expression as `node.value`.

    Returns:
      `ast.Pass()` for an alias statement (option letters), the updated node
      when the statement has a trailing `;`, or the node unchanged otherwise.
    """
    # Option letters (e.g. the `s` of `x;s`) are parsed as a standalone
    # expression statement, which should be a no-op.
    if self._is_alias_stmt(node):
      return ast.Pass()

    recorder = self.lines_recorder
    info = _has_trailing_semicolon(recorder.last_lines, node)
    # `AnnAssign().value` can be `None` (`a: int`): there is nothing to display
    if not info.has_trailing or node.value is None:
      return node

    letters = ''.join(option.value for option in info.options)
    call_kwargs = [ast.keyword('options', ast.Constant(letters))]
    if info.print_line:
      call_kwargs.append(ast.keyword('line_code', _unparse_line(node)))

    # NOTE(review): the generated call refers to `ecolab` by name, so it
    # presumably expects `ecolab` to be available in the user namespace.
    node.value = ast.Call(
        func=_parse_expr('ecolab.auto_display_utils._display_and_return'),
        args=[node.value],
        keywords=call_kwargs,
    )
    node = ast.fix_missing_locations(node)
    # Remember the line, so the option letters following it are recognized
    recorder.trailing_stmt_line_nums[info.line_num] = info
    return node

  def _is_alias_stmt(self, node: ast.AST) -> bool:
    """Returns `True` if `node` is the option letters following a `;`."""
    # Only a bare name expression (e.g. `sp`) can be an alias
    if not (isinstance(node, ast.Expr) and isinstance(node.value, ast.Name)):
      return False
    known_letters = _Options.all_letters
    if not all(letter in known_letters for letter in node.value.id):
      return False
    # The alias has to be on the same line as a trailing `;`
    return node.end_lineno - 1 in self.lines_recorder.trailing_stmt_line_nums

  @_reraise_error
  def visit_Assert(self, node: ast.Assert) -> ast.AST:  # pylint: disable=invalid-name
    """Same as the other visitors, displaying the `assert` test expression."""
    # `Assert` stores its expression in `node.test`, so the node is wrapped to
    # expose the `node.value` API which `_maybe_display` expects.
    wrapped = _WrapAssertNode(node)
    result = self._maybe_display(wrapped)  # pytype: disable=wrong-arg-types
    assert isinstance(result, _WrapAssertNode)
    return result._node  # Unwrap  # pylint: disable=protected-access

  # pylint: disable=invalid-name
  visit_Assign = _reraise_error(_maybe_display)
  visit_AnnAssign = _reraise_error(_maybe_display)
  visit_Expr = _reraise_error(_maybe_display)
  visit_Return = _reraise_error(_maybe_display)
  # pylint: enable=invalid-name


def _parse_expr(code: str) -> ast.AST:
  return ast.parse(code, mode='eval').body


class _WrapAssertNode(ast.Assert):
  """Like `Assert`, but rename `node.test` to `node.value`."""

  def __init__(self, node: ast.Assert) -> None:
    self._node = node

  def __getattribute__(self, name: str) -> Any:
    if name in ('value', '_node'):
      return super().__getattribute__(name)
    return getattr(self._node, name)

  @property
  def value(self) -> ast.AST:
    return self._node.test

  @value.setter
  def value(self, value: ast.AST) -> None:
    self._node.test = value


@dataclasses.dataclass(frozen=True)
class _LineInfo:
  has_trailing: bool
  options: set[_Options]
  line_num: int

  @property
  def print_line(self) -> bool:
    return _Options.LINE in self.options


def _has_trailing_semicolon(
    code_lines: list[str],
    node: ast.AST,
) -> _LineInfo:
  """Detects whether the `node` statement is followed by a trailing `;`.

  Args:
    code_lines: Source code lines of the cell.
    node: Statement to check.

  Returns:
    The detection result, including the options following the `;` (if any).
  """
  # `AnnAssign().value` can be `None` (`a: int`), do not print anything
  if isinstance(node, ast.AnnAssign) and node.value is None:
    return _LineInfo(has_trailing=False, options=set(), line_num=-1)

  # `end_lineno` starts at `1`
  line_num = node.end_lineno - 1

  # `node.end_col_offset` is a byte offset into the UTF-8 encoded line, so the
  # line is sliced in its encoded form.
  encoded_line = code_lines[line_num].encode('utf-8')
  after_stmt = encoded_line[node.end_col_offset :].decode('utf-8')

  found = _detect_trailing_regex().match(after_stmt)
  if found is None:  # What follows the statement is not a trailing `;`
    return _LineInfo(has_trailing=False, options=set(), line_num=line_num)

  letters = found.group('options') or ''
  return _LineInfo(
      has_trailing=True,
      options={_Options(letter) for letter in letters},
      line_num=line_num,
  )


@functools.cache
def _detect_trailing_regex() -> re.Pattern[str]:
  """Returns the (cached) regex matching what follows a displayed statement.

  The pattern is meant to be matched against the end of the line, starting
  right after the statement.
  """
  # Match:
  # * `; a`
  # * `; a # Some comment`
  # * `; # Some comment`
  # Do not match:
  # * `; a; b`
  # * `; a=1`

  letters = ''.join(sorted(_Options.all_letters))  # pytype: disable=wrong-arg-types
  pattern_parts = (
      ' *; *',  # Trailing `;` (surrounded by spaces)
      f'(?P<options>[{letters}]*)?',  # Optionally some `option` letters
      ' *(?:#.*)?$',  # Line can end by a `# comment`
  )
  return re.compile(''.join(pattern_parts))


def _unparse_line(node: ast.AST) -> ast.Constant:
  """Extract the line code."""
  if isinstance(node, ast.Assign):
    node = node.targets
  elif isinstance(node, ast.AnnAssign):
    node = node.target
  return ast.Constant(ast.unparse(node))


def disp(obj: Any, *, mode: str = '') -> None:
  """Displays `obj`, as the `;` auto display magic would.

  This is the functional API of `ecolab.auto_display()`.

  Args:
    obj: The object to display
    mode: Option letters, as supported by `ecolab.auto_display()` (except the
      line mode)

  Raises:
    NotImplementedError: If the line mode is requested.
  """
  line_requested = _Options.LINE in mode
  if line_requested:
    raise NotImplementedError('Line mode not supported in `disp()`')
  # The displayed object is intentionally not returned, otherwise it would be
  # displayed a second time when `disp()` is the last instruction of a cell.
  _display_and_return(obj, options=mode)


def _display_and_return(
    x: _T,
    *,
    options: str,
    line_code: str | None = None,
) -> _T:
  """Displays `x` according to `options` and returns `x` unchanged.

  Args:
    x: Object to display.
    options: Option letters (each one has to be a valid `_Options` value).
    line_code: When set, source code displayed as `<line_code> = ` before the
      value.

  Returns:
    The original `x`.
  """
  enabled = {_Options(letter) for letter in options}

  if _Options.QUIET in enabled:  # Do not display anything
    return x

  # What gets displayed (the original `x` is what gets returned)
  shown = etree.spec_like(x) if _Options.SPEC in enabled else x

  highlight = _Options.SYNTAX_HIGHLIGHT in enabled
  pretty = _Options.PPRINT in enabled

  # With syntax highlighting, the line is part of the highlighted output
  # instead of being printed here.
  if line_code and not highlight:
    print(line_code + ' = ', end='')

  if _Options.INSPECT in enabled:
    inspects.inspect(shown)
    return x

  if _Options.ARRAY in enabled:
    html = array_as_img.array_repr_html(shown)
    if html is not None:
      _html_display(html)
    else:
      print(f'Invalid array to display: {type(shown)}')
    return x

  if highlight:
    x_repr = epy.pretty_repr(shown) if pretty else repr(shown)
    if line_code:
      x_repr = f'{line_code} = {x_repr}'
    _html_display(highlight_util.highlight_html(x_repr))
    return x

  if pretty:
    _pretty_display(shown)
  elif line_code:
    # `IPython.display.display()` creates a new <div> section, so the value
    # would be displayed on a new line. The standard `print` keeps the value
    # on the same line as the `<line_code> = ` prefix.
    print(repr(shown))
  else:
    IPython.display.display(shown)
  return x


def _html_display(html):
  """Displays the `html` content with `IPython.display`."""
  rendered = IPython.display.HTML(html)
  IPython.display.display(rendered)


def _pretty_display(x):
  """Print `x` and return `x`."""
  # 2 main use-case:
  # * Print strings (including `\n`)
  # * Pretty-print dataclasses
  if isinstance(x, str):
    print(x)
  else:
    print(epy.pretty_repr(x))