diff --git a/personal_python_ast_optimizer/parser/minifier.py b/personal_python_ast_optimizer/parser/minifier.py index b54f6c9..bde7bb2 100644 --- a/personal_python_ast_optimizer/parser/minifier.py +++ b/personal_python_ast_optimizer/parser/minifier.py @@ -21,6 +21,7 @@ class MinifyUnparser(_Unparser): __slots__ = ( "constant_vars_to_fold", + "is_last_node_in_body", "module_name", "target_python_version", "within_class", @@ -41,6 +42,7 @@ def __init__( constant_vars_to_fold if constant_vars_to_fold is not None else {} ) + self.is_last_node_in_body: bool = False self.within_class: bool = False self.within_function: bool = False @@ -67,12 +69,33 @@ def _update_text_to_write(self, text: str) -> str: return text + def visit_node(self, node: ast.AST, is_last_node_in_body: bool = False): + method = "visit_" + node.__class__.__name__ + visitor = getattr(self, method, self.generic_visit) + + previous_state: bool = self.is_last_node_in_body + + self.is_last_node_in_body = is_last_node_in_body + try: + return visitor(node) # type: ignore + finally: + self.is_last_node_in_body = previous_state + + def traverse(self, node: list[ast.stmt] | ast.AST) -> None: + if isinstance(node, list): + last_index = len(node) - 1 + for index, item in enumerate(node): + is_last_node_in_body: bool = index == last_index + self.visit_node(item, is_last_node_in_body) + else: + self.visit_node(node) + def visit_Pass(self, _: ast.Pass | None = None) -> None: - same_line: bool = self._last_token_was_colon() + same_line: bool = self._can_write_same_line() self.fill("pass", same_line=same_line) def visit_Return(self, node: ast.Return) -> None: - same_line: bool = self._last_token_was_colon() + same_line: bool = self._can_write_same_line() self.fill("return", same_line=same_line) if node.value and not is_return_none(node): self.write(" ") @@ -157,7 +180,12 @@ def visit_Assign(self, node: ast.Assign) -> None: if len(node.value.elts) == 1: node.value = node.value.elts[0] - super().visit_Assign(node) + self.fill(same_line=self._can_write_same_line()) + for target in node.targets: + self.set_precedence(ast._Precedence.TUPLE, target) # type: ignore + self.traverse(target) + self.write("=") + self.traverse(node.value) def visit_AugAssign(self, node: ast.AugAssign) -> None: self.fill() @@ -267,8 +295,6 @@ def _function_helper( remove_empty_annotations(node) add_pass_if_body_empty(node) - # if len(node.body) == 1: - # self._set_can_write_same_line(node.body[0]) self._write_decorators(node) @@ -301,5 +327,9 @@ def _use_version_optimization(self, python_version: tuple[int, int]) -> bool: return self.target_python_version >= python_version - def _last_token_was_colon(self): - return len(self._source) > 0 and self._source[-1] == ":" + def _can_write_same_line(self): + return ( + len(self._source) > 0 + and self._source[-1] == ":" + and self.is_last_node_in_body + ) diff --git a/pyproject.toml b/pyproject.toml index d77c82b..98a1fda 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,7 +3,7 @@ authors = [ {name = "James Demetris"}, ] name = "personal-python-ast-optimizer" -version = "1.1.0" +version = "1.1.1" readme = "README.md" requires-python = ">=3.10" dependencies = [ diff --git a/tests/parser/test_assign.py b/tests/parser/test_assign.py new file mode 100644 index 0000000..9e140a3 --- /dev/null +++ b/tests/parser/test_assign.py @@ -0,0 +1,31 @@ +import pytest + +from tests.utils import BeforeAndAfter, run_minifiyer_and_assert_correct + + +@pytest.mark.parametrize( + "before_and_after", + [ + ( + BeforeAndAfter( + """ +if a > 6: + b = 3 + c = 4 +""", + "if a>6:\n\tb=3\n\tc=4", + ) + ), + ( + BeforeAndAfter( + """ +if a > 6: + b = 3 +""", + "if a>6:b=3", + ) + ), + ], +) +def test_exclude_name_equals_main(before_and_after: BeforeAndAfter): + run_minifiyer_and_assert_correct(before_and_after)