PTWZ's picture
Upload folder using huggingface_hub
f5f3483 verified
# Copyright 2024 The etils Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Text utils."""
from __future__ import annotations
import contextlib
import dataclasses
import difflib
import inspect
import reprlib
import sys
import textwrap
from typing import Any, Iterable, Iterator, Union
from etils.epy import py_utils
# Maps an opening brace to its `(opening, closing)` pair. Used by
# `Lines.make_block` to expand the single-character `braces` shorthand.
_BRACE_TO_BRACES = {
    opening: (opening, closing) for opening, closing in ('()', '[]', '{}')
}
@dataclasses.dataclass
class _Line:
    """A single entry stored in a `Lines` buffer.

    Attributes:
        content: Text of the entry (can itself contain several sub-lines).
        indent_lvl: Indentation depth that was active when the entry was added.
        indent_size: Number of spaces used per indentation level.
    """

    content: str
    indent_lvl: int
    indent_size: int
class Lines:
    """Incrementally builds an indented, multi-line text.

    Useful for pretty-print tools and human readable `__repr__`.

    Example:

    ```python
    d = {'a': 1, 'b': 2}

    lines = epy.Lines()
    lines += 'dict('
    with lines.indent():
        for k, v in d.items():
            lines += f'{k}={v},'
    lines += ')'
    text = lines.join()
    ```

    Output:

    ```
    dict(
        a=1,
        b=2,
    )
    ```
    """

    def __init__(self, *, indent: int = 4):
        """Creates an empty buffer.

        Args:
            indent: Number of spaces added per indentation level.
        """
        self._indent_size = indent
        self._indent_lvl = 0
        self._lines: list[_Line] = []

    def append(self, line: str) -> None:
        """Adds `line` at the current indentation level.

        Args:
            line: The text to add.

        Raises:
            TypeError: If `line` is not a `str`.
        """
        if not isinstance(line, str):
            raise TypeError(f'Lines should be added with `str`. Got {line!r}')
        # The indentation is recorded per entry, so that it is applied at
        # `join()` time, whatever the indentation level is by then.
        entry = _Line(
            content=line,
            indent_lvl=self._indent_lvl,
            indent_size=self._indent_size,
        )
        self._lines.append(entry)

    def extend(self, iterable: Iterable[str]) -> None:
        """Adds every `str` yielded by `iterable`, one entry per item."""
        for item in iterable:
            self.append(item)

    def __iadd__(self, line: str) -> Lines:
        """Implements `lines += 'text'` (same as `append`)."""
        self.append(line)
        return self

    @contextlib.contextmanager
    def indent(self) -> Iterator[None]:
        """Context manager adding one indentation level to the lines appended inside."""
        self._indent_lvl += 1
        try:
            yield
        finally:
            # Restore the previous level, even if the body raised.
            self._indent_lvl -= 1

    def join(self, *, collapse: bool = False) -> str:
        """Returns the lines merged into a single `str`.

        Args:
            collapse: If `True`, all lines are concatenated on a single line (no
                indentation, no separator). Otherwise lines are indented and
                separated by `\\n`.

        Returns:
            text: All lines merged together
        """
        if collapse:
            return ''.join(entry.content for entry in self._lines)
        # `textwrap.indent` prefixes every sub-line of a multi-line entry.
        return '\n'.join(
            textwrap.indent(
                entry.content, ' ' * entry.indent_lvl * entry.indent_size
            )
            for entry in self._lines
        )

    @classmethod
    def make_block(
        cls,
        header: str = '',
        content: str | dict[str, Any] | list[Any] | tuple[Any, ...] = (),
        *,
        braces: Union[str, tuple[str, str]] = '(',
        equal: str = '=',
        limit: int = 20,
    ) -> str:
        """Util function to create a code block.

        Example:

        ```python
        epy.Lines.make_block('A', {}) == 'A()'
        epy.Lines.make_block('A', {'x': 1}) == 'A(x=1)'
        epy.Lines.make_block('A', {'x': 1, 'y': 2}, limit=0) == '''A(
            x=1,
            y=2,
        )'''
        ```

        When not collapsed, the pattern is:

        ```
        {header}{braces[0]}
            {k}{equal}{v},
            ...
        {braces[1]}
        ```

        Args:
            header: Prefix before the opening brace.
            content: Items to display. A `dict` gives one `{k}{equal}{v}` item
                per entry, a `list` / `tuple` gives one item per element, and a
                `str` is handled as a single-element list. Values are formatted
                with `pretty_repr`.
            braces: Brace type (`(`, `[`, `{`), can be tuple for custom
                open/close.
            equal: The separator between key and value (`=`, `: `).
            limit: A block with several items is still collapsed on a single
                line when the summed length of its items is `<= limit`. A block
                with at most one item is always collapsed. Any item containing
                a `\\n` disables collapsing.

        Returns:
            The block string

        Raises:
            TypeError: If `content` is not a `str`, `dict`, `list` or `tuple`.
            KeyError: If `braces` is a `str` which is not a known opening brace.
        """
        if isinstance(braces, str):
            braces = _BRACE_TO_BRACES[braces]
        brace_start, brace_end = braces

        if isinstance(content, str):
            content = [content]

        if isinstance(content, dict):
            parts = [f'{k}{equal}{pretty_repr(v)}' for k, v in content.items()]
        elif isinstance(content, (list, tuple)):
            parts = [f'{pretty_repr(v)}' for v in content]
        else:
            raise TypeError(f'Invalid fields {type(content)}')

        if any('\n' in part for part in parts):
            # A multi-line item can never be displayed on a single line.
            collapse = False
        else:
            is_short = sum(len(part) for part in parts) <= limit
            collapse = is_short or len(parts) <= 1

        block = cls()
        block += f'{header}{brace_start}'
        with block.indent():
            if collapse:
                block += ', '.join(parts)
            else:
                # One item per line, each with a trailing `,`
                for part in parts:
                    block += f'{part},'
        block += f'{brace_end}'
        return block.join(collapse=collapse)
def pprint(obj: Any) -> None:
    """Prints `pretty_repr(obj)` to the standard output."""
    text = pretty_repr(obj)
    print(text)
@reprlib.recursive_repr()
def pretty_repr(obj: Any) -> str:
    """Pretty `repr(obj)` for nested list, dict, dataclasses,...

    The `reprlib.recursive_repr()` decorator guards against objects which
    reference themselves: a recursive call on the same object returns `...`
    rather than recursing.

    Args:
        obj: Object to display

    Returns:
        The (possibly multi-line) representation
    """
    text = pretty_repr_top_level(obj)
    return text
def pretty_repr_top_level(obj: Any, *, force: bool = False) -> str:
    """Pretty `repr(obj)` for nested list, dict, dataclasses,...

    This version does not use `@reprlib.recursive_repr()`, to avoid a bug when
    used inside `__repr__`:

    ```python
    @dataclasses.dataclass(repr=False)
    class A:
        x: int

        def __repr__(self):
            return epy.pretty_repr_top_level(self, force=True)

    epy.pretty_repr(A(1))  # `A(x=1)` (do not display `...`)
    ```

    Supported types are `str`, namedtuples, `list`, `tuple`, `dict` (exact
    types only, as sub-classes could have a custom repr), `dataclasses` and
    `attr` classes. Everything else falls back to `repr(obj)`.

    Args:
        obj: Object to display
        force: Force the pretty_repr, even if the object has a custom
            `__repr__`. This is useful when the `__repr__` implementation
            itself wants to call `pretty_repr(self)`.

    Returns:
        Repr
    """
    # TODO(epot): Should still somehow register `self` with the `recursive_repr`,
    # should support both:
    # pretty_repr(a) == A(recursive=...)
    # a.__repr__() == A(recursive=...)
    if isinstance(obj, str):
        return repr(obj)
    elif py_utils.is_namedtuple(obj):
        # TODO(epot): Could check if obj has custom `__repr__`
        return Lines.make_block(
            header=obj.__class__.__name__,
            content={
                field_name: getattr(obj, field_name)
                for field_name in type(obj)._fields
            },
        )
    elif type(obj) in (list, tuple):  # Skip sub-class as could have custom repr
        text = Lines.make_block(
            content=obj,
            braces='[' if isinstance(obj, list) else '(',
        )
        # Singleton tuple have a trailing `,`.
        # The expanded (multi-line) form already ends every item with `,`:
        # `(\n    item,\n)`. Only the collapsed `(item)` form has to be
        # patched, otherwise the output would end with a stray `,)` line.
        if (
            isinstance(obj, tuple)
            and len(obj) == 1
            and not text.endswith(',\n)')
        ):
            text = text.removesuffix(')') + ',)'
        return text
    elif type(obj) is dict:  # pylint: disable=unidiomatic-typecheck
        return Lines.make_block(
            # Keys are `repr`-ed here, values are formatted by `make_block`.
            content={repr(k): v for k, v in obj.items()},
            braces='{',
            equal=': ',
        )
    elif _is_datclass(obj, force=force):
        all_fields = dataclasses.fields(obj)
        return Lines.make_block(
            header=obj.__class__.__name__,
            content={
                field.name: getattr(obj, field.name)
                for field in all_fields
                if field.repr  # Skip `dataclasses.field(repr=False)`
            },
        )
    elif _is_attr(obj, force=force):
        import attr  # pylint: disable=g-import-not-at-top  # pytype: disable=import-error

        all_fields = attr.fields_dict(type(obj))
        return Lines.make_block(
            header=obj.__class__.__name__,
            content={
                field.name: getattr(obj, field.name)
                for field in all_fields.values()
                if field.repr
            },
        )
    else:
        return repr(obj)
def has_default_repr(cls: Any) -> bool:
    """Returns `True` if the dataclass does not overwrite `__repr__`.

    The `__repr__` generated by `dataclasses` is recognized by the qualified
    name of the function found after unwrapping `cls.__repr__`.

    Args:
        cls: The class to inspect

    Returns:
        `True` if `cls.__repr__` is the `dataclasses`-generated one.
    """
    repr_fn = inspect.unwrap(cls.__repr__)
    # Not every callable exposes `__qualname__` (e.g. `functools.partial`
    # objects). Those are custom `__repr__`, not a reason to raise.
    qualname = getattr(repr_fn, '__qualname__', None)
    return qualname == '__create_fn__.<locals>.__repr__'
def _is_datclass(obj: Any, *, force: bool = False) -> bool:
    """Returns `True` if `obj` is a dataclass instance to pretty-print.

    Args:
        obj: The object to check
        force: If `True`, every dataclass instance is accepted, whatever its
            `__repr__` is.

    Returns:
        `True` if `obj` is a dataclass instance (dataclass types are rejected)
        and, unless `force` is set, its `__repr__` is either the one generated
        by `dataclasses` or one of the `pretty_repr` functions.
    """
    # Class are not pretty-print, only their instances.
    if isinstance(obj, type) or not dataclasses.is_dataclass(obj):
        return False
    if force:
        return True
    if not obj.__dataclass_params__.repr:  # dataclass(repr=False)
        return False
    # TODO(epot): Better support for recursive `pretty_repr` to avoid infinite
    # loops.
    cls = type(obj)
    if has_default_repr(cls):
        return True
    return cls.__repr__ in (pretty_repr, pretty_repr_top_level)
def _is_attr(obj: Any, *, force: bool = False) -> bool:
    """Returns `True` if `obj` is an `attr` instance to pretty-print.

    Args:
        obj: The object to check
        force: If `True`, every `attr` instance is accepted, whatever its
            `__repr__` is.

    Returns:
        `True` if `type(obj)` is an `attr` class and, unless `force` is set,
        the docstring of its `__repr__` starts with the attrs-generated marker.
    """
    # Only use `attr` if something else already imported it (presumably to
    # avoid a hard dependency -- an attrs class requires `attr` to be loaded).
    if 'attr' not in sys.modules:
        return False
    import attr  # pylint: disable=g-import-not-at-top  # pytype: disable=import-error

    cls = type(obj)
    if not attr.has(cls):
        return False
    if force:
        return True
    # A missing or empty docstring means `__repr__` was not generated by attrs.
    doc = cls.__repr__.__doc__
    return bool(doc) and doc.startswith('Method generated by attrs')
def dedent(text: str) -> str:
    r"""Dedents `text`, then strips its leading and trailing whitespace.

    Wrapper around `textwrap.dedent` followed by `str.strip()`, so multi-line
    literals can start and end on their own line.

    Before:

    ```python
    text = textwrap.dedent(
        '''\
        A(
            x=1,
        )'''
    )
    ```

    After:

    ```python
    text = epy.dedent(
        '''
        A(
            x=1,
        )
        '''
    )
    ```

    Args:
        text: The text to dedent

    Returns:
        The dedented text, without surrounding whitespace
    """
    dedented = textwrap.dedent(text)
    return dedented.strip()
def diff_str(a: str | object, b: str | object) -> str:
    """Pretty diff between 2 objects.

    Non-`str` objects are first converted with `pretty_repr`. Both texts are
    then compared line by line with `difflib.ndiff`.

    Args:
        a: Object/str to compare
        b: Object/str to compare

    Returns:
        The diff string
    """
    text_a = a if isinstance(a, str) else pretty_repr(a)
    text_b = b if isinstance(b, str) else pretty_repr(b)
    # `ndiff` expects sequences of lines. Passing a raw `str` would compare
    # it character by character, so both sides are always split.
    diff = difflib.ndiff(text_a.split('\n'), text_b.split('\n'))
    return '\n'.join(diff)