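# NOTE(review): the lines "Spaces: / Paused / Paused" that preceded this
# docstring appear to be web-page header residue from extraction, not part
# of the module -- confirm and drop.
"""API for traversing the AST nodes. Implemented by the compiler and
meta introspection.
"""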
import typing as t | |
from .nodes import Node | |
# Only evaluated by static type checkers, so ``typing_extensions`` is not
# imported at runtime.
if t.TYPE_CHECKING:
    import typing_extensions as te

    class VisitCallable(te.Protocol):
        """Structural type of a visitor function: it is called with the
        node plus any extra positional and keyword arguments and may
        return anything.
        """

        def __call__(self, node: Node, *args: t.Any, **kwargs: t.Any) -> t.Any: ...
class NodeVisitor:
    """Walks the abstract syntax tree and calls visitor functions for every
    node found.  The visitor functions may return values which will be
    forwarded by the `visit` method.

    Per default the visitor functions for the nodes are ``'visit_'`` +
    class name of the node.  So a `TryFinally` node visit function would
    be `visit_TryFinally`.  This behavior can be changed by overriding
    the `get_visitor` function.  If no visitor function exists for a node
    (return value `None`) the `generic_visit` visitor is used instead.
    """

    def get_visitor(self, node: Node) -> "t.Optional[VisitCallable]":
        """Return the visitor function for this node or `None` if no visitor
        exists for this node.  In that case the generic visit function is
        used instead.
        """
        # The lookup key is the node's concrete class name, so a subclass
        # of a node type does not inherit its parent's visitor.
        return getattr(self, f"visit_{type(node).__name__}", None)

    def visit(self, node: Node, *args: t.Any, **kwargs: t.Any) -> t.Any:
        """Visit a node.

        Dispatches to the visitor returned by `get_visitor`, falling back
        to `generic_visit`.  Extra positional and keyword arguments are
        passed through unchanged, and the visitor's return value is
        returned to the caller.
        """
        f = self.get_visitor(node)

        if f is not None:
            return f(node, *args, **kwargs)

        return self.generic_visit(node, *args, **kwargs)

    def generic_visit(self, node: Node, *args: t.Any, **kwargs: t.Any) -> t.Any:
        """Called if no explicit visitor function exists for a node.

        Visits every child yielded by ``node.iter_child_nodes()`` with the
        same extra arguments.  The children's return values are discarded
        and `None` is returned.
        """
        for child_node in node.iter_child_nodes():
            self.visit(child_node, *args, **kwargs)
class NodeTransformer(NodeVisitor):
    """Walks the abstract syntax tree and allows modifications of nodes.

    The `NodeTransformer` will walk the AST and use the return value of the
    visitor functions to replace or remove the old node.  If the return
    value of the visitor function is `None` the node will be removed
    from the previous location otherwise it's replaced with the return
    value.  The return value may be the original node in which case no
    replacement takes place.
    """

    def generic_visit(self, node: Node, *args: t.Any, **kwargs: t.Any) -> Node:
        """Visit every field of ``node`` and apply the visitors' results
        in place.

        For a list field, each `Node` item is visited: a `None` result
        drops the item, a non-`Node` result is spliced in with ``extend``,
        and a `Node` result replaces the item.  Items that are not nodes
        are kept untouched.  For a field holding a single `Node`, a `None`
        result deletes the attribute and any other result is assigned to
        it.  Other fields are left alone.

        :return: the same ``node`` object, modified in place.
        """
        for field, old_value in node.iter_fields():
            if isinstance(old_value, list):
                new_values = []

                for value in old_value:
                    if isinstance(value, Node):
                        value = self.visit(value, *args, **kwargs)

                        if value is None:
                            # The visitor asked for this node to be removed.
                            continue
                        elif not isinstance(value, Node):
                            # The visitor returned several replacements;
                            # splice them in where the old node was.
                            new_values.extend(value)
                            continue

                    new_values.append(value)

                # Slice-assign so the list object held by the node keeps
                # its identity while its contents are replaced.
                old_value[:] = new_values
            elif isinstance(old_value, Node):
                new_node = self.visit(old_value, *args, **kwargs)

                if new_node is None:
                    delattr(node, field)
                else:
                    setattr(node, field, new_node)

        return node

    def visit_list(self, node: Node, *args: t.Any, **kwargs: t.Any) -> t.List[Node]:
        """As transformers may return lists in some places this method
        can be used to enforce a list as return value.

        Any non-list result is wrapped in a one-element list; note that
        this includes `None`, which comes back as ``[None]``.
        """
        rv = self.visit(node, *args, **kwargs)

        if not isinstance(rv, list):
            return [rv]

        return rv