import ast
import collections
import io
import sys
import token
import tokenize
from abc import ABCMeta
from ast import Module, expr, AST
from typing import Callable, Dict, Iterable, Iterator, List, Optional, Tuple, Union, cast, Any, TYPE_CHECKING

from six import iteritems


if TYPE_CHECKING:
  from .astroid_compat import NodeNG

  # Stub used only for type checking: it widens AST with the extra attributes that
  # this library attaches to nodes. It is never instantiated at runtime.
  class EnhancedAST(AST):
    first_token = None
    last_token = None
    lineno = 0

  AstNode = Union[EnhancedAST, NodeNG]

  if sys.version_info[0] == 2:
    TokenInfo = Tuple[int, str, Tuple[int, int], Tuple[int, int], str]
  else:
    TokenInfo = tokenize.TokenInfo


def token_repr(tok_type, string):
  """Returns a human-friendly representation of a token with the given type and string."""
  # repr() prefixes unicode strings with 'u' on Python 2 but not Python 3; strip it
  # for consistent output across versions.
  return '%s:%s' % (token.tok_name[tok_type], repr(string).lstrip('u'))


class Token(collections.namedtuple('Token', 'type string start end line index startpos endpos')):
  """
  Token is an 8-tuple containing the same 5 fields as the tokens produced by the tokenize
  module, and 3 additional ones useful for this module:

  - [0] .type     Token type (see token.py)
  - [1] .string   Token (a string)
  - [2] .start    Starting (row, column) indices of the token (a 2-tuple of ints)
  - [3] .end      Ending (row, column) indices of the token (a 2-tuple of ints)
  - [4] .line     Original line (string)
  - [5] .index    Index of the token in the list of tokens that it belongs to.
  - [6] .startpos Starting character offset into the input text.
  - [7] .endpos   Ending character offset into the input text.
  """
  def __str__(self):
    return token_repr(self.type, self.string)


if sys.version_info >= (3, 6):
  AstConstant = ast.Constant
else:
  # Stand-in for Pythons without ast.Constant; its `value` attribute is a fresh
  # sentinel object that never equals a real constant's value.
  class AstConstant:
    value = object()


def match_token(token, tok_type, tok_str=None):
  """Returns true if token is of the given type and, if a string is given, has that string."""
  return token.type == tok_type and (tok_str is None or token.string == tok_str)


def expect_token(token, tok_type, tok_str=None):
  """
  Verifies that the given token is of the expected type. If tok_str is given, the token string
  is verified too. If the token doesn't match, raises an informative ValueError.
  """
  if not match_token(token, tok_type, tok_str):
    raise ValueError("Expected token %s, got %s on line %s col %s" % (
      token_repr(tok_type, tok_str), str(token),
      token.start[0], token.start[1] + 1))
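

# Usage sketch (illustrative, not part of the module's API; `toks` is a hypothetical
# local): checking tokens produced by the standard tokenizer.
#
#   import io, token, tokenize
#   toks = list(tokenize.generate_tokens(io.StringIO("x = 1\n").readline))
#   match_token(toks[0], token.NAME, "x")   # -> True
#   expect_token(toks[1], token.OP, "=")    # passes; raises ValueError on mismatch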


if sys.version_info >= (3, 7):
  def is_non_coding_token(token_type):
    """
    These are considered non-coding tokens, as they don't affect the syntax tree.
    """
    return token_type in (token.NL, token.COMMENT, token.ENCODING)
else:
  def is_non_coding_token(token_type):
    """
    These are considered non-coding tokens, as they don't affect the syntax tree.
    """
    # Before Python 3.7, tokenize exposed NL/COMMENT/ENCODING as pseudo-token
    # values at or above token.N_TOKENS.
    return token_type >= token.N_TOKENS


def generate_tokens(text):
  """
  Generates standard library tokens for the given code.
  """
  # The cast only satisfies type checkers: io.StringIO(text).readline can already be
  # called with no arguments and returns str.
  return tokenize.generate_tokens(cast(Callable[[], str], io.StringIO(text).readline))


def iter_children_func(node):
  """
  Returns a function which yields all direct children of an AST node,
  skipping children that are singleton nodes.
  The function depends on whether ``node`` is from the ``ast`` or the ``astroid`` module.
  """
  return iter_children_astroid if hasattr(node, 'get_children') else iter_children_ast


def iter_children_astroid(node, include_joined_str=False):
  if not include_joined_str and is_joined_str(node):
    return []

  return node.get_children()


# Context and operator classes (Load, Store, Add, Not, ...): they carry no position
# information, so they are skipped when iterating children.
SINGLETONS = {c for n, c in iteritems(ast.__dict__) if isinstance(c, type) and
              issubclass(c, (ast.expr_context, ast.boolop, ast.operator, ast.unaryop, ast.cmpop))}


def iter_children_ast(node, include_joined_str=False):
  if not include_joined_str and is_joined_str(node):
    return

  if isinstance(node, ast.Dict):
    # Rather than all keys followed by all values, yield keys and values in source
    # order: key1, value1, key2, value2, ... A key may be None for `**` dict unpacking.
    for (key, value) in zip(node.keys, node.values):
      if key is not None:
        yield key
      yield value
    return

  for child in ast.iter_child_nodes(node):
    # Skip singleton children; they don't reflect particular positions in the code and
    # would break the assumption that the tree consists of distinct nodes.
    if child.__class__ not in SINGLETONS:
      yield child
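

# For illustration (a sketch, Python 3.8+ node names): in `{**base, "k": v}` the
# `base` value has key None, so iter_children_ast yields Name, Constant, Name in
# source order.
#
#   import ast
#   d = ast.parse('{**base, "k": v}').body[0].value   # an ast.Dict
#   [type(c).__name__ for c in iter_children_ast(d)]  # -> ['Name', 'Constant', 'Name']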


stmt_class_names = {n for n, c in iteritems(ast.__dict__)
                    if isinstance(c, type) and issubclass(c, ast.stmt)}
expr_class_names = ({n for n, c in iteritems(ast.__dict__)
                     if isinstance(c, type) and issubclass(c, ast.expr)} |
                    # astroid-specific node class names that count as expressions.
                    {'AssignName', 'DelName', 'Const', 'AssignAttr', 'DelAttr'})


def is_expr(node):
  """Returns whether node is an expression node."""
  return node.__class__.__name__ in expr_class_names


def is_stmt(node):
  """Returns whether node is a statement node."""
  return node.__class__.__name__ in stmt_class_names
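

# For illustration (a sketch): a bare expression statement is both wrapped in a
# statement node and contains an expression node.
#
#   import ast
#   mod = ast.parse("x + 1")
#   is_stmt(mod.body[0])         # -> True (the ast.Expr statement)
#   is_expr(mod.body[0].value)   # -> True (the ast.BinOp expression)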


def is_module(node):
  """Returns whether node is a module node."""
  return node.__class__.__name__ == 'Module'


def is_joined_str(node):
  """Returns whether node is a JoinedStr node, used to represent f-strings."""
  # Checking the class name rather than using isinstance() keeps this working for
  # both ast and astroid trees, including Pythons that predate ast.JoinedStr.
  return node.__class__.__name__ == 'JoinedStr'


def is_starred(node):
  """Returns whether node is a starred expression node."""
  return node.__class__.__name__ == 'Starred'


def is_slice(node):
  """Returns whether node represents a slice, e.g. `1:2` in `x[1:2]`."""
  # In Python 3.9+, an extended slice such as `x[1:2, 3]` is represented by a Tuple
  # whose elements include at least one Slice; older versions use ExtSlice.
  return (
    node.__class__.__name__ in ('Slice', 'ExtSlice')
    or (
      node.__class__.__name__ == 'Tuple'
      and any(map(is_slice, cast(ast.Tuple, node).elts))
    )
  )
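

# For illustration (a sketch): the extended slice in `x[1:2, 3]` contains a Slice
# element, so is_slice() is true for it.
#
#   import ast
#   sub = ast.parse("x[1:2, 3]").body[0].value   # an ast.Subscript
#   is_slice(sub.slice)                          # -> True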


def is_empty_astroid_slice(node):
  # An astroid Slice with no lower/upper/step, i.e. the bare `:` in `x[:]`. The
  # isinstance check excludes ast.Slice nodes, which are handled separately.
  return (
    node.__class__.__name__ == "Slice"
    and not isinstance(node, ast.AST)
    and node.lower is node.upper is node.step is None
  )


# Sentinel used by visit_tree() to mark stack entries that still need their previsit call.
_PREVISIT = object()


def visit_tree(node, previsit, postvisit):
  """
  Scans the tree under the node depth-first using an explicit stack. It avoids implicit
  recursion via the function call stack, so deep trees don't hit Python's recursion limit.

  It calls ``previsit()`` and ``postvisit()`` as follows:

  * ``previsit(node, par_value)`` - should return ``(par_value, value)``
        ``par_value`` is as returned from ``previsit()`` of the parent.

  * ``postvisit(node, par_value, value)`` - should return ``value``
        ``par_value`` is as returned from ``previsit()`` of the parent, and ``value`` is as
        returned from ``previsit()`` of this node itself. The return ``value`` is ignored except
        the one for the root node, which is returned from the overall ``visit_tree()`` call.

  For the initial node, ``par_value`` is None. ``postvisit`` may be None.
  """
  if not postvisit:
    postvisit = lambda node, pvalue, value: None

  iter_children = iter_children_func(node)
  done = set()
  ret = None
  stack = [(node, None, _PREVISIT)]
  while stack:
    current, par_value, value = stack.pop()
    if value is _PREVISIT:
      assert current not in done    # protect against infinite loop in case of a bad tree.
      done.add(current)

      pvalue, post_value = previsit(current, par_value)
      stack.append((current, par_value, post_value))

      # Insert all children at a fixed position, so that the first child ends up on
      # top of the stack and is visited first.
      ins = len(stack)
      for n in iter_children(current):
        stack.insert(ins, (n, pvalue, _PREVISIT))
    else:
      ret = postvisit(current, par_value, cast(Optional[Token], value))
  return ret
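

# Usage sketch (illustrative; all names below are hypothetical): counting how many
# nodes sit at each depth. ``previsit`` threads each parent's depth to its children.
#
#   import ast
#   from collections import Counter
#   counts = Counter()
#   def count_depth(node, par_depth):
#     depth = 0 if par_depth is None else par_depth + 1
#     counts[depth] += 1
#     return (depth, None)   # (par_value for the children, value for postvisit)
#   visit_tree(ast.parse("x = f(y)"), count_depth, None)
#   # counts == {0: 1, 1: 1, 2: 2, 3: 2}  (Module; Assign; Name, Call; Name, Name)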


def walk(node, include_joined_str=False):
  """
  Recursively yield all descendant nodes in the tree starting at ``node`` (including ``node``
  itself), using depth-first pre-order traversal (yielding parents before their children).

  This is similar to ``ast.walk()``, but with a different order, and it works for both ``ast``
  and ``astroid`` trees. Also, like the ``iter_children_*`` functions, it skips singleton nodes
  generated by ``ast``.

  By default, ``JoinedStr`` (f-string) nodes and their contents are skipped
  because they previously couldn't be handled. Set ``include_joined_str`` to True to include them.
  """
  iter_children = iter_children_func(node)
  done = set()
  stack = [node]
  while stack:
    current = stack.pop()
    assert current not in done    # protect against infinite loop in case of a bad tree.
    done.add(current)

    yield current

    # Insert all children at a fixed position, so that the first child ends up on
    # top of the stack and is yielded first.
    ins = len(stack)
    for c in iter_children(current, include_joined_str):
      stack.insert(ins, c)
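

# Pre-order sketch (illustrative, Python 3.8+ node names): parents are yielded before
# their children, in source order, with singleton context nodes skipped.
#
#   import ast
#   tree = ast.parse("if x:\n    y = 1\n")
#   [type(n).__name__ for n in walk(tree)]
#   # -> ['Module', 'If', 'Name', 'Assign', 'Name', 'Constant']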


def replace(text, replacements):
  """
  Replaces multiple slices of text with new values. This is a convenience method for making code
  modifications of ranges e.g. as identified by ``ASTTokens.get_text_range(node)``. Replacements
  is an iterable of ``(start, end, new_text)`` tuples.

  For example, ``replace("this is a test", [(0, 4, "X"), (8, 9, "THE")])`` produces
  ``"X is THE test"``.
  """
  p = 0
  parts = []
  for (start, end, new_text) in sorted(replacements):
    parts.append(text[p:start])
    parts.append(new_text)
    p = end
  parts.append(text[p:])
  return ''.join(parts)


class NodeMethods(object):
  """
  Helper to get `visit_{node_type}` methods given a node's class and cache the results.
  """
  def __init__(self):
    self._cache = {}

  def get(self, obj, cls):
    """
    Using the lowercase name of the class as node_type, returns `obj.visit_{node_type}`,
    or `obj.visit_default` if the type-specific method is not found.
    """
    method = self._cache.get(cls)
    if not method:
      name = "visit_" + cls.__name__.lower()
      method = getattr(obj, name, obj.visit_default)
      self._cache[cls] = method
    return method
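

# Dispatch sketch (illustrative; ``Visitor`` is hypothetical): NodeMethods maps a node
# class to a ``visit_<classname>`` method on a given object, with a default fallback.
#
#   import ast
#   class Visitor(object):
#     def visit_name(self, node):
#       return "name: " + node.id
#     def visit_default(self, node):
#       return type(node).__name__
#   methods = NodeMethods()
#   node = ast.parse("x").body[0].value        # an ast.Name
#   methods.get(Visitor(), type(node))(node)   # -> "name: x"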


if sys.version_info[0] == 2:
  # Python 2 doesn't support non-ASCII identifiers, so there is nothing to patch;
  # pass the tokens through unchanged.
  def patched_generate_tokens(original_tokens):
    return iter(original_tokens)
else:
  def patched_generate_tokens(original_tokens):
    """
    Fixes tokens yielded by `tokenize.generate_tokens` to handle more non-ASCII characters
    in identifiers. Workaround for https://github.com/python/cpython/issues/68382.
    Should only be used when tokenizing a string that is known to be valid syntax,
    because it assumes that error tokens are not actually errors.
    Combines groups of consecutive NAME, NUMBER, and/or ERRORTOKEN tokens into a single NAME token.
    """
    group = []
    for tok in original_tokens:
      if (
          tok.type in (tokenize.NAME, tokenize.ERRORTOKEN, tokenize.NUMBER)
          # Only combine tokens that are adjacent, with no whitespace between them.
          and (not group or group[-1].end == tok.start)
      ):
        group.append(tok)
      else:
        for combined_token in combine_tokens(group):
          yield combined_token
        group = []
        yield tok
    # Flush any group still pending at the end of the token stream.
    for combined_token in combine_tokens(group):
      yield combined_token

  def combine_tokens(group):
    # Collapse the group into a single NAME token only if it contains an ERRORTOKEN
    # and lies on a single line; otherwise return it unchanged.
    if not any(tok.type == tokenize.ERRORTOKEN for tok in group) or len({tok.line for tok in group}) != 1:
      return group
    return [
      tokenize.TokenInfo(
        type=tokenize.NAME,
        string="".join(t.string for t in group),
        start=group[0].start,
        end=group[-1].end,
        line=group[0].line,
      )
    ]
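

# Typical wiring (a sketch; `source_text` is a hypothetical variable): wrap the raw
# token stream from generate_tokens() above, so identifiers affected by the cpython
# bug arrive as single NAME tokens.
#
#   tokens = list(patched_generate_tokens(generate_tokens(source_text)))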


def last_stmt(node):
  """
  If the given AST node contains multiple statements, return the last one.
  Otherwise, just return the node.
  """
  child_stmts = [
    child for child in iter_children_func(node)(node)
    if is_stmt(child) or type(child).__name__ in (
      # Block-like constructs that aren't stmt subclasses but contain statements
      # (both the ast and astroid spellings).
      "excepthandler",
      "ExceptHandler",
      "match_case",
      "MatchCase",
      "TryExcept",
      "TryFinally",
    )
  ]
  if child_stmts:
    return last_stmt(child_stmts[-1])
  return node
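

# For example (a sketch): last_stmt() keeps descending into the last statement child
# until none remain.
#
#   import ast
#   tree = ast.parse("def f():\n    a = 1\n    return a\n")
#   type(last_stmt(tree)).__name__   # -> 'Return'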


if sys.version_info[:2] >= (3, 8):
  from functools import lru_cache

  @lru_cache(maxsize=None)
  def fstring_positions_work():
    """
    The positions attached to nodes inside f-string FormattedValues have some bugs
    that were fixed in Python 3.9.7 in https://github.com/python/cpython/pull/27729.
    This checks for those bugs more concretely without relying on the Python version.
    Specifically this checks:
     - Values with a format spec or conversion
     - Repeated (i.e. identical-looking) expressions
     - f-strings implicitly concatenated over multiple lines.
     - Multiline, triple-quoted f-strings.
    """
    source = """(
      f"a {b}{b} c {d!r} e {f:g} h {i:{j}} k {l:{m:n}}"
      f"a {b}{b} c {d!r} e {f:g} h {i:{j}} k {l:{m:n}}"
      f"{x + y + z} {x} {y} {z} {z} {z!a} {z:z}"
      f'''
      {s} {t}
      {u} {v}
      '''
    )"""
    tree = ast.parse(source)
    name_nodes = [node for node in ast.walk(tree) if isinstance(node, ast.Name)]
    name_positions = [(node.lineno, node.col_offset) for node in name_nodes]
    positions_are_unique = len(set(name_positions)) == len(name_positions)
    correct_source_segments = all(
      ast.get_source_segment(source, node) == node.id
      for node in name_nodes
    )
    return positions_are_unique and correct_source_segments

  def annotate_fstring_nodes(tree):
    """
    Add a special attribute `_broken_positions` to nodes inside f-strings
    if the lineno/col_offset cannot be trusted.
    """
    if sys.version_info >= (3, 12):
      # In Python 3.12+ (PEP 701), f-strings are parsed by the regular parser and
      # nodes inside them have reliable positions.
      return
    for joinedstr in walk(tree, include_joined_str=True):
      if not isinstance(joinedstr, ast.JoinedStr):
        continue
      for part in joinedstr.values:
        # The positions of the FormattedValue/Constant parts span the whole f-string.
        setattr(part, '_broken_positions', True)

        if isinstance(part, ast.FormattedValue):
          if not fstring_positions_work():
            for child in walk(part.value):
              setattr(child, '_broken_positions', True)

          if part.format_spec:
            # The format spec is itself a JoinedStr whose position also spans the
            # full f-string.
            setattr(part.format_spec, '_broken_positions', True)
else:
  def fstring_positions_work():
    return False

  def annotate_fstring_nodes(_tree):
    pass
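

# Usage sketch (illustrative; `source_text` and `node` are hypothetical variables):
# annotate a freshly parsed tree, then consult the flag before trusting positions.
#
#   tree = ast.parse(source_text)
#   annotate_fstring_nodes(tree)
#   position_ok = not getattr(node, '_broken_positions', False)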