From e8e2a49321e066c492ba530f4cdcf1678c2a90ae Mon Sep 17 00:00:00 2001 From: memento Date: Fri, 2 Dec 2022 11:59:55 -0600 Subject: [PATCH 1/5] (fix) Proposing a workaround for duplicate decorators in def -> async def --- docs/requirements.txt | 1 + refactor/actions.py | 26 ++++++++- tests/test_common.py | 18 +++++++ tests/test_complete_rules.py | 102 +++++++++++++++++++++++++++++++++++ 4 files changed, 146 insertions(+), 1 deletion(-) diff --git a/docs/requirements.txt b/docs/requirements.txt index 13c5141..0f31fc5 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -2,3 +2,4 @@ myst_parser==0.15.1 furo==2022.06.21 sphinx-design sphinx-hoverxref +pytest \ No newline at end of file diff --git a/refactor/actions.py b/refactor/actions.py index 13e8d76..eb78302 100644 --- a/refactor/actions.py +++ b/refactor/actions.py @@ -78,13 +78,17 @@ def _replace_input(self, node: ast.AST) -> _LazyActionMixin[K, T]: class _ReplaceCodeSegmentAction(BaseAction): def apply(self, context: Context, source: str) -> str: + # The decorators are removed in the 'lines' but present in the 'context` + # This lead to the 'replacement' containing the decorators and the returned + # 'lines' to duplicate them. Proposed workaround is to add the decorators in + # the 'view', in case the '_resynthesize()' adds/modifies them lines = split_lines(source, encoding=context.file_info.get_encoding()) ( lineno, col_offset, end_lineno, end_col_offset, - ) = self._get_segment_span(context) + ) = self._get_decorated_segment_span(context) view = slice(lineno - 1, end_lineno) source_lines = lines[view] @@ -102,6 +106,9 @@ def apply(self, context: Context, source: str) -> str: def _get_segment_span(self, context: Context) -> PositionType: raise NotImplementedError + def _get_decorated_segment_span(self, context: Context) -> PositionType: + raise NotImplementedError + def _resynthesize(self, context: Context) -> str: raise NotImplementedError @@ -121,6 +128,13 @@ class LazyReplace(_ReplaceCodeSegmentAction, _LazyActionMixin[ast.AST, ast.AST]) def _get_segment_span(self, context: Context) -> PositionType: return position_for(self.node) + def _get_decorated_segment_span(self, context: Context) -> PositionType: + lineno, col_offset, end_lineno, end_col_offset = position_for(self.node) + # Add the decorators to the segment span to resolve an issue with def -> async def + if hasattr(self.node, "decorator_list") and len(getattr(self.node, "decorator_list")) > 0: + lineno = position_for(getattr(self.node, "decorator_list")[0])[0] + return lineno, col_offset, end_lineno, end_col_offset + def _resynthesize(self, context: Context) -> str: return context.unparse(self.build()) @@ -228,6 +242,9 @@ class _Rename(Replace): def _get_segment_span(self, context: Context) -> PositionType: return self.identifier_span + def _get_decorated_segment_span(self, context: Context) -> PositionType: + return self.identifier_span + def _resynthesize(self, context: Context) -> str: return self.target.name @@ -260,6 +277,13 @@ def is_critical_node(self, context: Context) -> bool: def _get_segment_span(self, context: Context) -> PositionType: return position_for(self.node) + def _get_decorated_segment_span(self, context: Context) -> PositionType: + lineno, col_offset, end_lineno, end_col_offset = position_for(self.node) + # Add the decorators to the segment span to resolve an issue with def -> async def + if hasattr(self.node, "decorator_list") and len(self.node["decorator_list"]) > 0: + lineno = position_for(self.node["decorator_list"][0])[0] + return lineno, col_offset, end_lineno, end_col_offset + def _resynthesize(self, context: Context) -> str: if self.is_critical_node(context): raise InvalidActionError( diff --git a/tests/test_common.py b/tests/test_common.py index 5ff4b9d..49b5085 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -137,6 +137,24 @@ def func(): assert position_for(right_node) == (3, 23, 3, 25) +def test_get_positions_with_decorator(): + source = textwrap.dedent( + """\ + @deco0 + @deco1(arg0, + arg1) + def func(): + if a > 5: + return 5 + 3 + 25 + elif b > 10: + return 1 + 3 + 5 + 7 + """ + ) + tree = ast.parse(source) + right_node = tree.body[0].body[0].body[0].value.right + assert position_for(right_node) == (6, 23, 6, 25) + + def test_singleton(): from dataclasses import dataclass diff --git a/tests/test_complete_rules.py b/tests/test_complete_rules.py index b46c14e..c8a9296 100644 --- a/tests/test_complete_rules.py +++ b/tests/test_complete_rules.py @@ -296,6 +296,107 @@ def match(self, node): return AsyncifierAction(node) +class MakeFunctionAsyncWithDecorators(Rule): + INPUT_SOURCE = """ + @deco0 + @deco1(arg0, + arg1) + def something(): + a += .1 + '''you know + this is custom + literal + ''' + print(we, + preserve, + everything + ) + return ( + right + "?") + """ + + EXPECTED_SOURCE = """ + @deco0 + @deco1(arg0, + arg1) + async def something(): + a += .1 + '''you know + this is custom + literal + ''' + print(we, + preserve, + everything + ) + return ( + right + "?") + """ + + def match(self, node): + assert isinstance(node, ast.FunctionDef) + return AsyncifierAction(node) + + +class OnlyKeywordArgumentDefaultNotSetCheckRule(Rule): + context_providers = (context.Scope,) + + INPUT_SOURCE = """ + class Klass: + def method(self, *, a): + print() + + lambda self, *, a: print + + """ + + EXPECTED_SOURCE = """ + class Klass: + def method(self, *, a=None): + print() + + lambda self, *, a=None: print + + """ + + def match(self, node: ast.AST) -> BaseAction | None: + assert isinstance(node, (ast.FunctionDef, ast.Lambda)) + assert any(kw_default is None for kw_default in node.args.kw_defaults) + + if isinstance(node, ast.Lambda) and not ( + isinstance(node.body, ast.Name) and isinstance(node.body.ctx, ast.Load) + ): + scope = self.context["scope"].resolve(node.body) + scope.definitions.get(node.body.id, []) + + elif isinstance(node, ast.FunctionDef): + for stmt in node.body: + for identifier in ast.walk(stmt): + if not ( + isinstance(identifier, ast.Name) + and isinstance(identifier.ctx, ast.Load) + ): + continue + + scope = self.context["scope"].resolve(identifier) + while not scope.definitions.get(identifier.id, []): + scope = scope.parent + if scope is None: + break + + kw_defaults = [] + for kw_default in node.args.kw_defaults: + if kw_default is None: + kw_defaults.append(ast.Constant(value=None)) + else: + kw_defaults.append(kw_default) + + target = deepcopy(node) + target.args.kw_defaults = kw_defaults + + return Replace(node, target) + + class OnlyKeywordArgumentDefaultNotSetCheckRule(Rule): context_providers = (context.Scope,) @@ -944,6 +1045,7 @@ def match(self, node: ast.AST) -> Iterator[Replace | InsertAfter]: @pytest.mark.parametrize( "rule", [ + MakeFunctionAsyncWithDecorators, ReplaceNexts, ReplacePlaceholders, PropagateConstants, From 4aa50286b9f32f94a8aa5600731a28edf1981a4d Mon Sep 17 00:00:00 2001 From: memento Date: Fri, 2 Dec 2022 12:19:25 -0600 Subject: [PATCH 2/5] (fix) Proposing a workaround for duplicate decorators in def -> async def --- docs/requirements.txt | 1 - tests/test_complete_rules.py | 60 ------------------------------------ 2 files changed, 61 deletions(-) diff --git a/docs/requirements.txt b/docs/requirements.txt index 0f31fc5..13c5141 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -2,4 +2,3 @@ myst_parser==0.15.1 furo==2022.06.21 sphinx-design sphinx-hoverxref -pytest \ No newline at end of file diff --git a/tests/test_complete_rules.py b/tests/test_complete_rules.py index c8a9296..4851a99 100644 --- a/tests/test_complete_rules.py +++ b/tests/test_complete_rules.py @@ -397,65 +397,6 @@ def match(self, node: ast.AST) -> BaseAction | None: return Replace(node, target) -class OnlyKeywordArgumentDefaultNotSetCheckRule(Rule): - context_providers = (context.Scope,) - - INPUT_SOURCE = """ - class Klass: - def method(self, *, a): - print() - - lambda self, *, a: print - - """ - - EXPECTED_SOURCE = """ - class Klass: - def method(self, *, a=None): - print() - - lambda self, *, a=None: print - - """ - - def match(self, node: ast.AST) -> BaseAction | None: - assert isinstance(node, (ast.FunctionDef, ast.Lambda)) - assert any(kw_default is None for kw_default in node.args.kw_defaults) - - if isinstance(node, ast.Lambda) and not ( - isinstance(node.body, ast.Name) and isinstance(node.body.ctx, ast.Load) - ): - scope = self.context["scope"].resolve(node.body) - scope.definitions.get(node.body.id, []) - - elif isinstance(node, ast.FunctionDef): - for stmt in node.body: - for identifier in ast.walk(stmt): - if not ( - isinstance(identifier, ast.Name) - and isinstance(identifier.ctx, ast.Load) - ): - continue - - scope = self.context["scope"].resolve(identifier) - while not scope.definitions.get(identifier.id, []): - scope = scope.parent - if scope is None: - break - - kw_defaults = [] - for kw_default in node.args.kw_defaults: - if kw_default is None: - kw_defaults.append(ast.Constant(value=None)) - else: - kw_defaults.append(kw_default) - - target = deepcopy(node) - target.args.kw_defaults = kw_defaults - - return Replace(node, target) - - class InternalizeFunctions(Rule): INPUT_SOURCE = """ __all__ = ["regular"] @@ -1045,7 +986,6 @@ def match(self, node: ast.AST) -> Iterator[Replace | InsertAfter]: @pytest.mark.parametrize( "rule", [ - MakeFunctionAsyncWithDecorators, ReplaceNexts, ReplacePlaceholders, PropagateConstants, From ea79bc292b6718d0ac8e3c63f91e7d1237ab2c87 Mon Sep 17 00:00:00 2001 From: memento Date: Mon, 12 Dec 2022 10:50:56 -0600 Subject: [PATCH 3/5] Fixing 'non-subscriptable' error --- refactor/actions.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/refactor/actions.py b/refactor/actions.py index eb78302..756577e 100644 --- a/refactor/actions.py +++ b/refactor/actions.py @@ -132,7 +132,7 @@ def _get_decorated_segment_span(self, context: Context) -> PositionType: lineno, col_offset, end_lineno, end_col_offset = position_for(self.node) # Add the decorators to the segment span to resolve an issue with def -> async def if hasattr(self.node, "decorator_list") and len(getattr(self.node, "decorator_list")) > 0: - lineno = position_for(getattr(self.node, "decorator_list")[0])[0] + lineno, _, _, _ = position_for(getattr(self.node, "decorator_list")[0]) return lineno, col_offset, end_lineno, end_col_offset def _resynthesize(self, context: Context) -> str: @@ -280,8 +280,8 @@ def _get_segment_span(self, context: Context) -> PositionType: def _get_decorated_segment_span(self, context: Context) -> PositionType: lineno, col_offset, end_lineno, end_col_offset = position_for(self.node) # Add the decorators to the segment span to resolve an issue with def -> async def - if hasattr(self.node, "decorator_list") and len(self.node["decorator_list"]) > 0: - lineno = position_for(self.node["decorator_list"][0])[0] + if hasattr(self.node, "decorator_list") and len(getattr(self.node, "decorator_list")) > 0: + lineno, _, _, _ = position_for(getattr(self.node, "decorator_list")[0]) return lineno, col_offset, end_lineno, end_col_offset def _resynthesize(self, context: Context) -> str: From aa79d08fe4d8f19ffcc587835cdcfbd86cd286b9 Mon Sep 17 00:00:00 2001 From: memento Date: Thu, 5 Jan 2023 18:08:40 -0600 Subject: [PATCH 4/5] Fixing InsertBefore for decorated functions --- refactor/actions.py | 3 ++ tests/test_actions.py | 72 ++++++++++++++++++++++++------------------- 2 files changed, 43 insertions(+), 32 deletions(-) diff --git a/refactor/actions.py b/refactor/actions.py index 4f2d85c..1812c8a 100644 --- a/refactor/actions.py +++ b/refactor/actions.py @@ -235,6 +235,9 @@ def apply(self, context: Context, source: str) -> str: replacement[-1] += lines._newline_type original_node_start = cast(int, self.node.lineno) + if hasattr(self.node, "decorator_list") and len(getattr(self.node, "decorator_list")) > 0: + original_node_start, _, _, _ = position_for(getattr(self.node, "decorator_list")[0]) + for line in reversed(replacement): lines.insert(original_node_start - 1, line) diff --git a/tests/test_actions.py b/tests/test_actions.py index a8040be..faf8057 100644 --- a/tests/test_actions.py +++ b/tests/test_actions.py @@ -10,6 +10,7 @@ from refactor import Session, common from refactor.actions import Erase, InvalidActionError, InsertAfter, Replace, InsertBefore +from refactor.common import clone from refactor.context import Context from refactor.core import Rule @@ -50,58 +51,64 @@ def foo(): class TestInsertAfterBottom(Rule): INPUT_SOURCE = """ - try: - base_tree = get_tree(base_file, module_name) - first_tree = get_tree(first_tree, module_name) - second_tree = get_tree(second_tree, module_name) - third_tree = get_tree(third_tree, module_name) - except (SyntaxError, FileNotFoundError): - continue""" + def undecorated(): + test_this()""" EXPECTED_SOURCE = """ - try: - base_tree = get_tree(base_file, module_name) - except (SyntaxError, FileNotFoundError): - continue + async def undecorated(): + test_this() await async_test()""" def match(self, node: ast.AST) -> Iterator[InsertAfter]: - assert isinstance(node, ast.Try) - assert len(node.body) >= 2 + assert isinstance(node, ast.FunctionDef) await_st = ast.parse("await async_test()") yield InsertAfter(node, cast(ast.stmt, await_st)) - new_try = common.clone(node) - new_try.body = [node.body[0]] - yield Replace(node, cast(ast.AST, new_try)) + new_node = clone(node) + new_node.__class__ = ast.AsyncFunctionDef + yield Replace(node, new_node) class TestInsertBeforeTop(Rule): INPUT_SOURCE = """ - try: - base_tree = get_tree(base_file, module_name) - first_tree = get_tree(first_tree, module_name) - second_tree = get_tree(second_tree, module_name) - third_tree = get_tree(third_tree, module_name) - except (SyntaxError, FileNotFoundError): - continue""" + def undecorated(): + test_this()""" EXPECTED_SOURCE = """ await async_test() - try: - base_tree = get_tree(base_file, module_name) - except (SyntaxError, FileNotFoundError): - continue""" + async def undecorated(): + test_this()""" def match(self, node: ast.AST) -> Iterator[InsertBefore]: - assert isinstance(node, ast.Try) - assert len(node.body) >= 2 + assert isinstance(node, ast.FunctionDef) + + await_st = ast.parse("await async_test()") + yield InsertBefore(node, cast(ast.stmt, await_st)) + new_node = clone(node) + new_node.__class__ = ast.AsyncFunctionDef + yield Replace(node, new_node) + + +class TestInsertBeforeDecoratedFunction(Rule): + INPUT_SOURCE = """ + @decorate + def decorated(): + test_this()""" + + EXPECTED_SOURCE = """ + await async_test() + @decorate + async def decorated(): + test_this()""" + + def match(self, node: ast.AST) -> Iterator[InsertBefore]: + assert isinstance(node, ast.FunctionDef) await_st = ast.parse("await async_test()") yield InsertBefore(node, cast(ast.stmt, await_st)) - new_try = common.clone(node) - new_try.body = [node.body[0]] - yield Replace(node, cast(ast.AST, new_try)) + new_node = clone(node) + new_node.__class__ = ast.AsyncFunctionDef + yield Replace(node, new_node) class TestInsertAfter(Rule): @@ -488,6 +495,7 @@ def test_erase_invalid(invalid_node): @pytest.mark.parametrize( "rule", [ + TestInsertBeforeDecoratedFunction, TestInsertAfterBottom, TestInsertBeforeTop, TestInsertAfter, From ee4f10ff189ab33b499a0e99639053e7263a27b9 Mon Sep 17 00:00:00 2001 From: memento Date: Thu, 5 Jan 2023 18:13:16 -0600 Subject: [PATCH 5/5] Adding multiple decorators in test --- tests/test_actions.py | 57 +++++++++++++++++++++++++++++++------------ 1 file changed, 42 insertions(+), 15 deletions(-) diff --git a/tests/test_actions.py b/tests/test_actions.py index faf8057..fb977df 100644 --- a/tests/test_actions.py +++ b/tests/test_actions.py @@ -49,34 +49,42 @@ def foo(): INVALID_ERASES_TREE = ast.parse(INVALID_ERASES) -class TestInsertAfterBottom(Rule): +class TestInsertBeforeDecoratedFunction(Rule): INPUT_SOURCE = """ - def undecorated(): + @decorate + def decorated(): test_this()""" EXPECTED_SOURCE = """ - async def undecorated(): - test_this() - await async_test()""" + await async_test() + @decorate + async def decorated(): + test_this()""" - def match(self, node: ast.AST) -> Iterator[InsertAfter]: + def match(self, node: ast.AST) -> Iterator[InsertBefore]: assert isinstance(node, ast.FunctionDef) await_st = ast.parse("await async_test()") - yield InsertAfter(node, cast(ast.stmt, await_st)) + yield InsertBefore(node, cast(ast.stmt, await_st)) new_node = clone(node) new_node.__class__ = ast.AsyncFunctionDef yield Replace(node, new_node) -class TestInsertBeforeTop(Rule): +class TestInsertBeforeMultipleDecorators(Rule): INPUT_SOURCE = """ - def undecorated(): + @decorate0 + @decorate1 + @decorate2 + def decorated(): test_this()""" EXPECTED_SOURCE = """ await async_test() - async def undecorated(): + @decorate0 + @decorate1 + @decorate2 + async def decorated(): test_this()""" def match(self, node: ast.AST) -> Iterator[InsertBefore]: @@ -89,16 +97,34 @@ def match(self, node: ast.AST) -> Iterator[InsertBefore]: yield Replace(node, new_node) -class TestInsertBeforeDecoratedFunction(Rule): +class TestInsertAfterBottom(Rule): INPUT_SOURCE = """ - @decorate - def decorated(): + def undecorated(): + test_this()""" + + EXPECTED_SOURCE = """ + async def undecorated(): + test_this() + await async_test()""" + + def match(self, node: ast.AST) -> Iterator[InsertAfter]: + assert isinstance(node, ast.FunctionDef) + + await_st = ast.parse("await async_test()") + yield InsertAfter(node, cast(ast.stmt, await_st)) + new_node = clone(node) + new_node.__class__ = ast.AsyncFunctionDef + yield Replace(node, new_node) + + +class TestInsertBeforeTop(Rule): + INPUT_SOURCE = """ + def undecorated(): test_this()""" EXPECTED_SOURCE = """ await async_test() - @decorate - async def decorated(): + async def undecorated(): test_this()""" def match(self, node: ast.AST) -> Iterator[InsertBefore]: @@ -496,6 +522,7 @@ def test_erase_invalid(invalid_node): "rule", [ TestInsertBeforeDecoratedFunction, + TestInsertBeforeMultipleDecorators, TestInsertAfterBottom, TestInsertBeforeTop, TestInsertAfter,