|
from textwrap import dedent |
|
|
|
from parso import split_lines |
|
|
|
from jedi import debug |
|
from jedi.api.exceptions import RefactoringError |
|
from jedi.api.refactoring import Refactoring, EXPRESSION_PARTS |
|
from jedi.common import indent_block |
|
from jedi.parser_utils import function_is_classmethod, function_is_staticmethod |
|
|
|
|
|
# Node types that mark "statement level": a node whose parent has one of these
# types is treated as a whole statement by _get_parent_definition.
_DEFINITION_SCOPES = ('suite', 'file_input')

# Node types accepted as the first extracted node by extract_variable
# (checked in _is_expression_with_error).  NOTE(review): the misspelled name
# is kept as-is since other modules may reference it.
_VARIABLE_EXCTRACTABLE = EXPRESSION_PARTS + \
    ('atom testlist_star_expr testlist test lambdef lambdef_nocond '
     'keyword name number string fstring').split()
|
|
|
|
|
def extract_variable(inference_state, path, module_node, name, pos, until_pos):
    """
    Create a refactoring that moves the selected expression into a new
    variable called ``name``, assigned in front of the enclosing statement,
    and puts ``name`` where the expression was.

    :param pos: start position of the selection.
    :param until_pos: end position of the selection, or ``None`` to use the
        expression found at ``pos``.
    :raises RefactoringError: if the selection cannot be extracted.
    """
    selected_nodes = _find_nodes(module_node, pos, until_pos)
    debug.dbg('Extracting nodes: %s', selected_nodes)

    extractable, reason = _is_expression_with_error(selected_nodes)
    if not extractable:
        raise RefactoringError(reason)

    assignment = ' = '.join((name, _expression_nodes_to_string(selected_nodes)))
    node_changes = _replace(selected_nodes, name, assignment, pos)
    return Refactoring(inference_state, {path: node_changes})
|
|
|
|
|
def _is_expression_with_error(nodes): |
|
""" |
|
Returns a tuple (is_expression, error_string). |
|
""" |
|
if any(node.type == 'name' and node.is_definition() for node in nodes): |
|
return False, 'Cannot extract a name that defines something' |
|
|
|
if nodes[0].type not in _VARIABLE_EXCTRACTABLE: |
|
return False, 'Cannot extract a "%s"' % nodes[0].type |
|
return True, '' |
|
|
|
|
|
def _find_nodes(module_node, pos, until_pos):
    """
    Looks up a module and tries to find the appropriate amount of nodes that
    are in there.

    :param module_node: parso module tree to search.
    :param pos: start position of the selection (compared against parso
        ``start_pos``/``end_pos`` tuples).
    :param until_pos: end position of the selection, or ``None`` to pick the
        outermost expression around ``pos``.
    :returns: list of parso nodes/leaves to extract.
    :raises RefactoringError: if there is no leaf at or before ``until_pos``.
    """
    start_node = module_node.get_leaf_for_position(pos, include_prefixes=True)

    if until_pos is None:
        # If the position is exactly the start of the leaf following an
        # operator, prefer that leaf over the operator.
        if start_node.type == 'operator':
            next_leaf = start_node.get_next_leaf()
            if next_leaf is not None and next_leaf.start_pos == pos:
                start_node = next_leaf

        # Operators and most keywords can't be extracted alone; use the
        # enclosing node instead.
        if _is_not_extractable_syntax(start_node):
            start_node = start_node.parent

        # Inside a trailer (call/attribute/subscript): use the node that owns
        # the trailer.
        if start_node.parent.type == 'trailer':
            start_node = start_node.parent.parent
        # Widen to the outermost enclosing expression.
        while start_node.parent.type in EXPRESSION_PARTS:
            start_node = start_node.parent

        nodes = [start_node]
    else:
        # Get the next leaf if we are at the end of a leaf.
        if start_node.end_pos == pos:
            next_leaf = start_node.get_next_leaf()
            if next_leaf is not None:
                start_node = next_leaf

        # Some syntax is not extractable, just use its parent.
        if _is_not_extractable_syntax(start_node):
            start_node = start_node.parent

        # Find the last leaf that starts at or before until_pos.
        end_leaf = module_node.get_leaf_for_position(until_pos, include_prefixes=True)
        if end_leaf.start_pos > until_pos:
            end_leaf = end_leaf.get_previous_leaf()
            if end_leaf is None:
                raise RefactoringError('Cannot extract anything from that')

        # Climb until a single node spans from the start to the end leaf.
        parent_node = start_node
        while parent_node.end_pos < end_leaf.end_pos:
            parent_node = parent_node.parent

        nodes = _remove_unwanted_expression_nodes(parent_node, pos, until_pos)

    # A selected `return`/`yield` statement is narrowed to its value
    # (children[1]) instead of the whole statement.
    if len(nodes) == 1 and start_node.type in ('return_stmt', 'yield_expr'):
        return [nodes[0].children[1]]
    return nodes
|
|
|
|
|
def _replace(nodes, expression_replacement, extracted, pos,
             insert_before_leaf=None, remaining_prefix=None):
    """
    Build the node -> new-code mapping for an extraction.

    ``nodes[0]`` is replaced by ``expression_replacement``, the other nodes
    by ``''``, and ``extracted`` is inserted on its own line(s) in front of
    ``insert_before_leaf``.

    :param nodes: the selected parso nodes (non-empty).
    :param expression_replacement: code that takes the place of ``nodes``.
    :param extracted: code to insert (e.g. an assignment or a function).
    :param pos: not used by this function.
    :param insert_before_leaf: leaf in front of which ``extracted`` goes;
        defaults to the first leaf of the statement containing ``nodes[0]``.
    :param remaining_prefix: if given, prefix text to keep in front of the
        replaced code instead of the original prefix lines.
    :returns: dict mapping parso nodes/leaves to their new code.
    """
    # Now try to replace the nodes found with a variable and move the code
    # before the current statement.
    definition = _get_parent_definition(nodes[0])
    if insert_before_leaf is None:
        insert_before_leaf = definition.get_first_leaf()
    first_node_leaf = nodes[0].get_first_leaf()

    lines = split_lines(insert_before_leaf.prefix, keepends=True)
    if first_node_leaf is insert_before_leaf:
        if remaining_prefix is not None:
            # The remaining prefix has already been calculated.  Assigning a
            # string to a list slice splices in its characters; ''.join below
            # reassembles them unchanged.
            lines[:-1] = remaining_prefix
    # Insert the extracted code just before the prefix's last line;
    # indent_block gets that last line (presumably the indentation of
    # insert_before_leaf) as indentation.
    lines[-1:-1] = [indent_block(extracted, lines[-1]) + '\n']
    extracted_prefix = ''.join(lines)

    replacement_dct = {}
    if first_node_leaf is insert_before_leaf:
        # The selection starts the statement: insertion and replacement share
        # the same leaf.
        replacement_dct[nodes[0]] = extracted_prefix + expression_replacement
    else:
        if remaining_prefix is None:
            p = first_node_leaf.prefix
        else:
            p = remaining_prefix + _get_indentation(nodes[0])
        replacement_dct[nodes[0]] = p + expression_replacement
        replacement_dct[insert_before_leaf] = extracted_prefix + insert_before_leaf.value

    # All further selected nodes disappear.
    for node in nodes[1:]:
        replacement_dct[node] = ''
    return replacement_dct
|
|
|
|
|
def _expression_nodes_to_string(nodes): |
|
return ''.join(n.get_code(include_prefix=i != 0) for i, n in enumerate(nodes)) |
|
|
|
|
|
def _suite_nodes_to_string(nodes, pos):
    """
    Render the statement nodes ``nodes`` as code.

    The first node's prefix is split at the line before ``pos[0]``; the part
    above it is returned separately, the rest stays with the code.

    :returns: ``(prefix, code)``
    """
    head = nodes[0]
    tail = nodes[1:]
    kept_prefix, leading = _split_prefix_at(head.get_first_leaf(), pos[0] - 1)
    pieces = [leading, head.get_code(include_prefix=False)]
    pieces.extend(node.get_code() for node in tail)
    return kept_prefix, ''.join(pieces)
|
|
|
|
|
def _split_prefix_at(leaf, until_line):
    """
    Returns a tuple of the leaf's prefix, split at the until_line
    position.

    The second part is made of the last ``leaf.start_pos[0] - until_line``
    prefix lines, i.e. the lines after ``until_line`` (the prefix's last line
    is on the leaf's own line); the first part is everything before them.
    """
    # Number of prefix lines located after until_line.  A count of 0 makes
    # both slices use -0 (== 0): the first part is then '' and the second
    # part is the whole prefix.  NOTE(review): a negative count (until_line
    # after the leaf's line) is not guarded against.
    second_line_count = leaf.start_pos[0] - until_line
    lines = split_lines(leaf.prefix, keepends=True)
    return ''.join(lines[:-second_line_count]), ''.join(lines[-second_line_count:])
|
|
|
|
|
def _get_indentation(node):
    """
    Return the last line of the prefix in front of ``node``'s first leaf,
    i.e. the text preceding the node on its own line (normally its
    indentation).
    """
    prefix = node.get_first_leaf().prefix
    prefix_lines = split_lines(prefix)
    return prefix_lines[len(prefix_lines) - 1]
|
|
|
|
|
def _get_parent_definition(node): |
|
""" |
|
Returns the statement where a node is defined. |
|
""" |
|
while node is not None: |
|
if node.parent.type in _DEFINITION_SCOPES: |
|
return node |
|
node = node.parent |
|
raise NotImplementedError('We should never even get here') |
|
|
|
|
|
def _remove_unwanted_expression_nodes(parent_node, pos, until_pos):
    """
    This function makes it so for `1 * 2 + 3` you can extract `2 + 3`, even
    though it is not part of the expression.

    Returns the children of ``parent_node`` touched by the selection from
    ``pos`` to ``until_pos``, with the first and last of them narrowed
    recursively.  Only expressions (``EXPRESSION_PARTS``) and suites/modules
    are split; any other node is returned as ``[parent_node]``.

    A selection starting on an operator also takes the operand on its left.
    A selection ending on an operator or a non-extractable keyword is
    extended to the following operand, so it is not cut in the middle of a
    binary operation.

    :returns: list of parso nodes/leaves.
    """
    typ = parent_node.type
    is_suite_part = typ in ('suite', 'file_input')
    if typ not in EXPRESSION_PARTS and not is_suite_part:
        return [parent_node]

    children = parent_node.children
    # Fallbacks for a child that was pulled in only because the selection
    # started/ended on an adjacent operator.  Such a child lies completely
    # outside the selection, so no child matches below and it is kept whole.
    start_index = 0
    end_index = len(children) - 1

    for i, n in enumerate(children):
        if n.end_pos > pos:
            start_index = i
            if n.type == 'operator':
                # Starting on an operator: include its left operand.
                start_index -= 1
            break

    for i, n in reversed(list(enumerate(children))):
        if n.start_pos < until_pos:
            end_index = i
            # Don't cut right after an operator (`1 + 2 + 3` after the first
            # `+`) or a keyword (`not foo or bar` after `not`).  Skip such
            # syntax, starting with ``n`` itself, up to the next operand.
            # Operators are counted here only once.
            for n2 in children[i:]:
                if _is_not_extractable_syntax(n2):
                    end_index += 1
                else:
                    break
            break

    nodes = children[start_index:end_index + 1]
    if not is_suite_part:
        nodes[0:1] = _remove_unwanted_expression_nodes(nodes[0], pos, until_pos)
        nodes[-1:] = _remove_unwanted_expression_nodes(nodes[-1], pos, until_pos)
    return nodes
|
|
|
|
|
def _is_not_extractable_syntax(node): |
|
return node.type == 'operator' \ |
|
or node.type == 'keyword' and node.value not in ('None', 'True', 'False') |
|
|
|
|
|
def extract_function(inference_state, path, module_context, name, pos, until_pos):
    """
    Create a refactoring that moves the selected code into a new function
    called ``name`` and replaces the selection with a call to it.

    An expression selection becomes the function's return value.  For a
    statement selection, the function returns the defined names that are
    read after the selection, or the last defined name if none are.  A
    trailing ``return`` statement is kept as the function's own return.

    :param module_context: context of the module containing the selection.
    :param pos: start position of the selection.
    :param until_pos: end position of the selection.  Statement selections
        read ``until_pos[0]``; NOTE(review): a ``None`` would fail there.
    :raises RefactoringError: if the selection contains a ``return`` or
        ``yield`` that cannot be moved.
    """
    nodes = _find_nodes(module_context.tree_node, pos, until_pos)
    # _find_nodes is expected to produce at least one node.
    assert len(nodes)

    is_expression, _ = _is_expression_with_error(nodes)
    context = module_context.create_context(nodes[0])
    is_bound_method = context.is_bound_method()
    # params: names read from outside the selection;
    # return_variables: names defined inside it.
    params, return_variables = list(_find_inputs_and_outputs(module_context, context, nodes))

    # Where to put the new function: in a module, before the statement that
    # contains the selection (resolved in _replace); otherwise before the
    # node chosen by _get_code_insertion_node.
    if context.is_module():
        insert_before_leaf = None  # Leaf will be determined later
    else:
        node = _get_code_insertion_node(context.tree_node, is_bound_method)
        insert_before_leaf = node.get_first_leaf()
    if is_expression:
        code_block = 'return ' + _expression_nodes_to_string(nodes) + '\n'
        remaining_prefix = None
        has_ending_return_stmt = False
    else:
        has_ending_return_stmt = _is_node_ending_return_stmt(nodes[-1])
        if not has_ending_return_stmt:
            # Keep only the defined names that are read after the selection.
            # If none are, fall back to the last defined one so the function
            # still returns something.  Parsed as
            # `(needed or [return_variables[-1]]) if return_variables else []`.
            return_variables = list(_find_needed_output_variables(
                context,
                nodes[0].parent,
                nodes[-1].end_pos,
                return_variables
            )) or [return_variables[-1]] if return_variables else []

        remaining_prefix, code_block = _suite_nodes_to_string(nodes, pos)
        # Prefix lines of the following leaf up to until_pos's line move
        # into the function body; the rest stays in front of that leaf.
        after_leaf = nodes[-1].get_next_leaf()
        first, second = _split_prefix_at(after_leaf, until_pos[0])
        code_block += first

        code_block = dedent(code_block)
        if not has_ending_return_stmt:
            output_var_str = ', '.join(return_variables)
            code_block += 'return ' + output_var_str + '\n'

    # A trailing return statement is allowed; any other return/yield in the
    # selection raises RefactoringError.
    _check_for_non_extractables(nodes[:-1] if has_ending_return_stmt else nodes)

    decorator = ''
    self_param = None
    if is_bound_method:
        if not function_is_staticmethod(context.tree_node):
            # Reuse the enclosing method's first parameter as the receiver
            # of the call instead of passing it as an argument.
            function_param_names = context.get_value().get_param_names()
            if len(function_param_names):
                self_param = function_param_names[0].string_name
                params = [p for p in params if p != self_param]

        if function_is_classmethod(context.tree_node):
            decorator = '@classmethod\n'
    else:
        code_block += '\n'

    function_code = '%sdef %s(%s):\n%s' % (
        decorator,
        name,
        ', '.join(params if self_param is None else [self_param] + params),
        indent_block(code_block)
    )

    function_call = '%s(%s)' % (
        ('' if self_param is None else self_param + '.') + name,
        ', '.join(params)
    )
    if is_expression:
        replacement = function_call
    else:
        if has_ending_return_stmt:
            replacement = 'return ' + function_call + '\n'
        else:
            replacement = output_var_str + ' = ' + function_call + '\n'

    replacement_dct = _replace(nodes, replacement, function_code, pos,
                               insert_before_leaf, remaining_prefix)
    if not is_expression:
        # Re-attach the part of the following leaf's prefix that stayed.
        replacement_dct[after_leaf] = second + after_leaf.value
    file_to_node_changes = {path: replacement_dct}
    return Refactoring(inference_state, file_to_node_changes)
|
|
|
|
|
def _check_for_non_extractables(nodes): |
|
for n in nodes: |
|
try: |
|
children = n.children |
|
except AttributeError: |
|
if n.value == 'return': |
|
raise RefactoringError( |
|
'Can only extract return statements if they are at the end.') |
|
if n.value == 'yield': |
|
raise RefactoringError('Cannot extract yield statements.') |
|
else: |
|
_check_for_non_extractables(children) |
|
|
|
|
|
def _is_name_input(module_context, names, first, last): |
|
for name in names: |
|
if name.api_type == 'param' or not name.parent_context.is_module(): |
|
if name.get_root_context() is not module_context: |
|
return True |
|
if name.start_pos is None or not (first <= name.start_pos < last): |
|
return True |
|
return False |
|
|
|
|
|
def _find_inputs_and_outputs(module_context, context, nodes):
    """
    Split the names used in ``nodes`` into inputs and outputs.

    :param module_context: context of the module containing ``nodes``.
    :param context: context used to look up the definitions of read names
        (via ``goto``).
    :param nodes: the selected parso nodes (non-empty).
    :returns: ``(inputs, outputs)``, two lists of name strings.  ``inputs``
        are names read in ``nodes`` that have no definition, or (per
        ``_is_name_input``) a definition outside ``nodes``.  ``outputs`` are
        the names defined in ``nodes``, each listed once and ordered by
        their last definition.
    """
    first = nodes[0].start_pos
    last = nodes[-1].end_pos

    inputs = []
    outputs = []
    for name in _find_non_global_names(nodes):
        value = name.value
        if name.is_definition():
            # Compare by string value, since ``outputs`` holds strings, not
            # Name objects.  A redefinition moves the name to the end, so
            # ``outputs[-1]`` stays the most recently defined name.
            if value in outputs:
                outputs.remove(value)
            outputs.append(value)
        elif value not in inputs:
            name_definitions = context.goto(name, name.start_pos)
            if not name_definitions \
                    or _is_name_input(module_context, name_definitions, first, last):
                inputs.append(value)

    return inputs, outputs
|
|
|
|
|
def _find_non_global_names(nodes): |
|
for node in nodes: |
|
try: |
|
children = node.children |
|
except AttributeError: |
|
if node.type == 'name': |
|
yield node |
|
else: |
|
|
|
if node.type == 'trailer' and node.children[0] == '.': |
|
continue |
|
|
|
yield from _find_non_global_names(children) |
|
|
|
|
|
def _get_code_insertion_node(node, is_bound_method): |
|
if not is_bound_method or function_is_staticmethod(node): |
|
while node.parent.type != 'file_input': |
|
node = node.parent |
|
|
|
while node.parent.type in ('async_funcdef', 'decorated', 'async_stmt'): |
|
node = node.parent |
|
return node |
|
|
|
|
|
def _find_needed_output_variables(context, search_node, at_least_pos, return_variables):
    """
    Searches everything after at_least_pos in a node and checks if any of the
    return_variables are used in there and returns those.

    Only children of ``search_node`` starting at or after ``at_least_pos``
    are searched.  Each variable is yielded at most once, in the order it is
    first read.  ``context`` is not used.
    """
    remaining = set(return_variables)
    later_children = (child for child in search_node.children
                      if not child.start_pos < at_least_pos)
    for child in later_children:
        for name in _find_non_global_names([child]):
            if name.is_definition():
                continue
            if name.value in remaining:
                remaining.remove(name.value)
                yield name.value
|
|
|
|
|
def _is_node_ending_return_stmt(node): |
|
t = node.type |
|
if t == 'simple_stmt': |
|
return _is_node_ending_return_stmt(node.children[0]) |
|
return t == 'return_stmt' |
|
|