Spaces:
Building
Building
# Copyright 2024 The etils Authors. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
"""Text utils.""" | |
from __future__ import annotations | |
import contextlib | |
import dataclasses | |
import difflib | |
import inspect | |
import reprlib | |
import sys | |
import textwrap | |
from typing import Any, Iterable, Iterator, Union | |
from etils.epy import py_utils | |
_BRACE_TO_BRACES = { | |
'(': ('(', ')'), | |
'[': ('[', ']'), | |
'{': ('{', '}'), | |
} | |
class _Line: | |
"""Line item.""" | |
content: str | |
indent_lvl: int | |
indent_size: int | |
class Lines:
  """Util to build multi-line text.

  Useful for pretty-print tools and human readable `__repr__`.

  Example:

  ```python
  d = {'a': 1, 'b': 2}

  lines = epy.Lines()
  lines += 'dict('
  with lines.indent():
    for k, v in d.items():
      lines += f'{k}={v},'
  lines += ')'

  text = lines.join()
  ```

  Output:

  ```
  dict(
      a=1,
      b=2,
  )
  ```
  """

  def __init__(self, *, indent: int = 4):
    """Creates an empty buffer.

    Args:
      indent: Number of spaces added per indentation level.
    """
    self._lines: list[_Line] = []
    self._indent_size = indent
    # Current depth, incremented inside `with self.indent():`.
    self._indent_lvl = 0

  def append(self, line: str) -> None:
    """Append a new line `str` at the current indentation level.

    Args:
      line: The line to add.

    Raises:
      TypeError: If `line` is not a `str`.
    """
    if not isinstance(line, str):
      raise TypeError(f'Lines should be added with `str`. Got {line!r}')
    self._lines.append(
        _Line(
            content=line,
            indent_lvl=self._indent_lvl,
            indent_size=self._indent_size,
        ),
    )

  def extend(self, iterable: Iterable[str]) -> None:
    """Append all the new line `str` from the iterable."""
    for line in iterable:
      self.append(line)

  def __iadd__(self, line: str) -> Lines:
    """Append a new line `str` (`lines += 'text'`)."""
    self.append(line)
    return self

  @contextlib.contextmanager
  def indent(self) -> Iterator[None]:
    """Indents by one level all lines added inside the `with` block.

    The previous level is restored on exit, even if the block raises.
    """
    self._indent_lvl += 1
    try:
      yield
    finally:
      self._indent_lvl -= 1

  def join(self, *, collapse: bool = False) -> str:
    """Returns the lines.

    Args:
      collapse: If `True`, all lines are merged together in a single line
        (indentation is dropped and no separator is inserted).

    Returns:
      text: All lines merged together
    """
    lines = []
    for line in self._lines:
      content = line.content
      if not collapse:
        # `content` may itself span several lines (e.g. a nested block), so
        # the indentation is added to every sub-line.
        indentation = ' ' * line.indent_lvl * line.indent_size
        content = textwrap.indent(content, indentation)
      lines.append(content)
    token = '' if collapse else '\n'
    return token.join(lines)

  @classmethod
  def make_block(
      cls,
      header: str = '',
      content: str | dict[str, Any] | list[Any] | tuple[Any, ...] = (),
      *,
      braces: Union[str, tuple[str, str]] = '(',
      equal: str = '=',
      limit: int = 20,
  ) -> str:
    """Util function to create a code block.

    Example:

    ```python
    epy.Lines.make_block('A', {}) == 'A()'
    epy.Lines.make_block('A', {'x': 1}) == 'A(x=1)'
    # Short content is collapsed on a single line (see `limit`)
    epy.Lines.make_block('A', {'x': 1, 'y': 2}) == 'A(x=1, y=2)'
    epy.Lines.make_block('A', {'x': 1, 'y': 2}, limit=0) == '''A(
        x=1,
        y=2,
    )'''
    ```

    Pattern is as:

    ```
    {header}{braces[0]}
        {k}={v},
        ...
    {braces[1]}
    ```

    Args:
      header: Prefix before the brace
      content: A `dict` renders one `{k}{equal}{pretty_repr(v)}` entry per
        item. A `list`/`tuple` renders one `pretty_repr(v)` entry per element.
        A `str` is treated as a single element, so it goes through
        `pretty_repr` too.
      braces: Brace type (`(`, `[`, `{`), can be tuple for custom open/close.
      equal: The separator (`=`, `: `)
      limit: If no entry spans several lines and the total length of all
        entries is `<= limit`, the block is collapsed on a single line. A
        block with at most one entry is collapsed unless that entry spans
        several lines.

    Returns:
      The block string

    Raises:
      TypeError: If `content` is not a `str`, `dict`, `list` or `tuple`.
      KeyError: If `braces` is a `str` other than `(`, `[`, `{`.
    """
    if isinstance(braces, str):
      braces = _BRACE_TO_BRACES[braces]
    brace_start, brace_end = braces

    if isinstance(content, str):
      content = [content]

    if isinstance(content, dict):
      parts = [f'{k}{equal}{pretty_repr(v)}' for k, v in content.items()]
    elif isinstance(content, (list, tuple)):
      parts = [f'{pretty_repr(v)}' for v in content]
    else:
      raise TypeError(f'Invalid fields {type(content)}')

    collapse = len(parts) <= 1
    if any('\n' in p for p in parts):
      # Multi-line entries can never be displayed on a single line.
      collapse = False
    elif sum(len(p) for p in parts) <= limit:
      # Also collapse blocks whose content is small.
      collapse = True

    lines = cls()
    lines += f'{header}{brace_start}'
    with lines.indent():
      if collapse:
        lines += ', '.join(parts)
      else:
        for p in parts:
          lines += f'{p},'
    lines += f'{brace_end}'
    return lines.join(collapse=collapse)
def pprint(obj: Any) -> None:
  """Writes `pretty_repr(obj)` to stdout, followed by a newline."""
  rendered = pretty_repr(obj)
  print(rendered)
@reprlib.recursive_repr()
def _pretty_repr_guarded(obj: Any) -> str:
  """`pretty_repr_top_level(obj)`, returning `...` on re-entry for `obj`."""
  return pretty_repr_top_level(obj)


def pretty_repr(obj: Any) -> str:
  """Pretty `repr(obj)` for nested list, dict, dataclasses,...

  Nested values are rendered through `pretty_repr` again. An object which
  (directly or indirectly) contains itself is therefore displayed as `...`
  at the point of recursion, instead of raising `RecursionError`.

  Args:
    obj: Object to display.

  Returns:
    The pretty representation of `obj`.
  """
  # The guard lives on a private helper: `reprlib.recursive_repr` wrappers
  # only accept a single positional argument, and callers may pass `obj=`.
  return _pretty_repr_guarded(obj)
def pretty_repr_top_level(obj: Any, *, force: bool = False) -> str:
  """Pretty `repr(obj)` for nested list, dict, dataclasses,...

  This version does not use `@reprlib.recursive_repr()`, to avoid a bug when
  used inside `__repr__`:

  ```python
  class A:

    def __repr__(self):
      return epy.pretty_repr_top_level(self)

  epy.pretty_repr(A())  # Do not display `...`
  ```

  Args:
    obj: Object to display
    force: Force the pretty_repr, even if the object has a custom `__repr__`.
      This is useful when the `__repr__` implementation itself wants to call
      `pretty_repr(self)`.

  Returns:
    Repr
  """
  # TODO(epot): Should still somehow register `self` with the `recursive_repr`,
  # should support both:
  # pretty_repr(a) == A(recursive=...)
  # a.__repr__() == A(recursive=...)
  if isinstance(obj, str):
    return repr(obj)
  elif py_utils.is_namedtuple(obj):
    # TODO(epot): Could check if obj has custom `__repr__`
    return Lines.make_block(
        header=obj.__class__.__name__,
        content={
            field_name: getattr(obj, field_name)
            for field_name in type(obj)._fields
        },
    )
  elif type(obj) in (list, tuple):  # Skip sub-class as could have custom repr
    text = Lines.make_block(
        content=obj,
        braces='[' if isinstance(obj, list) else '(',
    )
    # Singleton tuples need a trailing `,` (`(x,)`). Only the collapsed,
    # single-line form lacks it. The multi-line form already ends the element
    # with `,` and must not get a second one.
    if isinstance(obj, tuple) and len(obj) == 1 and '\n' not in text:
      text = text.removesuffix(')') + ',)'
    return text
  elif type(obj) is dict:  # pylint: disable=unidiomatic-typecheck
    return Lines.make_block(
        content={repr(k): v for k, v in obj.items()},
        braces='{',
        equal=': ',
    )
  elif _is_datclass(obj, force=force):
    all_fields = dataclasses.fields(obj)
    return Lines.make_block(
        header=obj.__class__.__name__,
        content={
            field.name: getattr(obj, field.name)
            for field in all_fields
            if field.repr
        },
    )
  elif _is_attr(obj, force=force):
    import attr  # pylint: disable=g-import-not-at-top # pytype: disable=import-error

    all_fields = attr.fields_dict(type(obj))
    return Lines.make_block(
        header=obj.__class__.__name__,
        content={
            field.name: getattr(obj, field.name)
            for field in all_fields.values()
            if field.repr
        },
    )
  else:
    return repr(obj)
def has_default_repr(cls: Any) -> bool:
  """Returns `True` if `cls.__repr__` is the one generated by `dataclasses`.

  Decorator layers (`__wrapped__`) are peeled off first. The innermost
  function's qualified name is then compared with
  `'__create_fn__.<locals>.__repr__'`, the name CPython's `dataclasses`
  gives its generated `__repr__` (an implementation detail).
  """
  generated_qualname = '__create_fn__.<locals>.__repr__'
  innermost = inspect.unwrap(cls.__repr__)
  return innermost.__qualname__ == generated_qualname
def _is_datclass(obj: Any, *, force: bool = False) -> bool: | |
"""Returns `True` if the object is a dataclass.""" | |
if isinstance(obj, type): # Class are not pretty-print | |
return False | |
if not dataclasses.is_dataclass(obj): | |
return False | |
if force: | |
return True | |
if not obj.__dataclass_params__.repr: # dataclass(repr=False) | |
return False | |
# TODO(epot): Better support for recursive `pretty_repr` to avoid infinite | |
# loops. | |
if has_default_repr(type(obj)) or type(obj).__repr__ in ( | |
pretty_repr, | |
pretty_repr_top_level, | |
): | |
return True | |
return False | |
def _is_attr(obj: Any, *, force: bool = False) -> bool: | |
"""Returns `True` if the object is a `attr` dataclass.""" | |
if 'attr' not in sys.modules: | |
return False | |
import attr # pylint: disable=g-import-not-at-top # pytype: disable=import-error | |
if not attr.has(type(obj)): | |
return False | |
if force: | |
return True | |
if not (doc := type(obj).__repr__.__doc__): | |
return False | |
if not doc.startswith('Method generated by attrs'): | |
return False | |
return True | |
def dedent(text: str) -> str:
  r"""`textwrap.dedent` followed by `str.strip()`.

  This lets a multi-line literal start on the line after the opening quotes
  (no trailing `\` needed) and close on its own line.

  Instead of:

  ```python
  text = textwrap.dedent('''\
      A(
          x=1,
      )''')
  ```

  Write:

  ```python
  text = epy.dedent('''
      A(
          x=1,
      )
  ''')
  ```

  Args:
    text: The text to dedent

  Returns:
    The dedented text, without leading/trailing whitespace.
  """
  without_common_indent = textwrap.dedent(text)
  return without_common_indent.strip()
def diff_str(a: str | object, b: str | object) -> str:
  """Pretty line-by-line diff between 2 objects.

  Non-`str` inputs are first converted with `pretty_repr`. Both sides are then
  split on `\\n` and compared with `difflib.ndiff`. This means `str` inputs
  are diffed line by line, exactly like objects, rather than character by
  character.

  Args:
    a: Object/str to compare
    b: Object/str to compare

  Returns:
    The diff string: the `difflib.ndiff` output lines (prefixed with `'  '`,
    `'- '`, `'+ '` or `'? '`) joined with `\\n`.
  """
  a_text = a if isinstance(a, str) else pretty_repr(a)
  b_text = b if isinstance(b, str) else pretty_repr(b)
  diff = difflib.ndiff(a_text.split('\n'), b_text.split('\n'))
  return '\n'.join(diff)