Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 33 additions & 27 deletions src/ecooptimizer/refactorers/concrete/member_ignoring_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,15 @@
from ..multi_file_refactorer import MultiFileRefactorer
from ...data_types.smell import MIMSmell

logger = CONFIG["refactorLogger"]


class CallTransformer(cst.CSTTransformer):
METADATA_DEPENDENCIES = (PositionProvider,)

def __init__(self, class_name: str):
self.method_calls: list[tuple[str, int, str, str]] = None # type: ignore
self.class_name = class_name # Class name to replace instance calls
self.class_name = class_name # Class nme to replace instance calls
self.transformed = False

def set_calls(self, valid_calls: list[tuple[str, int, str, str]]):
Expand All @@ -34,15 +36,13 @@ def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.Cal

# Check if this call matches one from astroid (by caller, method name, and line number)
for call_caller, line, call_method, cls in self.method_calls:
CONFIG["refactorLogger"].debug(
f"cst caller: {call_caller} at line {position.start.line}"
)
logger.debug(f"cst caller: {call_caller} at line {position.start.line}")
if (
method == call_method
and position.start.line == line
and caller.deep_equals(cst.parse_expression(call_caller))
):
CONFIG["refactorLogger"].debug("transforming")
logger.debug("transforming")
# Transform `obj.method(args)` -> `ClassName.method(args)`
new_func = cst.Attribute(
value=cst.Name(cls), # Replace `obj` with class name
Expand All @@ -65,41 +65,45 @@ def find_valid_method_calls(
"""
valid_calls = []

CONFIG["refactorLogger"].info("Finding valid method calls")
logger.debug("Finding valid method calls")

for node in tree.body:
for descendant in node.nodes_of_class(nodes.Call):
if isinstance(descendant.func, nodes.Attribute):
CONFIG["refactorLogger"].debug(f"caller: {descendant.func.expr.as_string()}")
logger.debug(f"caller: {descendant.func.expr.as_string()}")
caller = descendant.func.expr # The object calling the method
method_name = descendant.func.attrname

if method_name != mim_method:
continue

inferred_types: list[str] = []
inferrences = caller.infer()

for inferred in inferrences:
CONFIG["refactorLogger"].debug(f"inferred: {inferred.repr_name()}")
if isinstance(inferred, util.UninferableBase):
hint = check_for_annotations(caller, descendant.scope())
inits = check_for_initializations(caller, descendant.scope())
if hint:
inferred_types.append(hint.as_string())
elif inits:
inferred_types.extend(inits)
try:
inferrences = caller.infer()

for inferred in inferrences:
logger.debug(f"inferred: {inferred.repr_name()}")
if isinstance(inferred, util.UninferableBase):
hint = check_for_annotations(caller, descendant.scope())
inits = check_for_initializations(caller, descendant.scope())
if hint:
inferred_types.append(hint.as_string())
elif inits:
inferred_types.extend(inits)
else:
continue
else:
continue
else:
inferred_types.append(inferred.repr_name())
inferred_types.append(inferred.repr_name())
except astroid.InferenceError as e:
print(e)
continue

CONFIG["refactorLogger"].debug(f"Inferred types: {inferred_types}")
logger.debug(f"Inferred types: {inferred_types}")

# Check if any inferred type matches a valid class
for cls in inferred_types:
if cls in valid_classes:
CONFIG["refactorLogger"].debug(
logger.debug(
f"Foud valid call: {caller.as_string()} at line {descendant.lineno}"
)
valid_calls.append(
Expand Down Expand Up @@ -127,7 +131,7 @@ def check_for_annotations(caller: nodes.NodeNG, scope: nodes.NodeNG):
return None

hint = None
CONFIG["refactorLogger"].debug(f"annotations: {scope.args}")
logger.debug(f"annotations: {scope.args}")

args = scope.args.args
anns = scope.args.annotations
Expand Down Expand Up @@ -162,6 +166,8 @@ def refactor(
self.target_line = smell.occurences[0].line
self.target_file = target_file

print("smell:", smell)

if not smell.obj:
raise TypeError("No method object found")

Expand Down Expand Up @@ -194,12 +200,12 @@ def get_subclasses(tree: nodes.Module):
subclasses.add(klass.name)
return subclasses

CONFIG["refactorLogger"].debug("find all subclasses")
logger.debug("find all subclasses")
self.traverse(directory)
for file in self.py_files:
tree = astroid.parse(file.read_text())
self.valid_classes = self.valid_classes.union(get_subclasses(tree))
CONFIG["refactorLogger"].debug(f"valid classes: {self.valid_classes}")
logger.debug(f"valid classes: {self.valid_classes}")

def _process_file(self, file: Path):
processed = False
Expand Down Expand Up @@ -228,7 +234,7 @@ def leave_FunctionDef(
if func_name and updated_node.deep_equals(original_node):
position = self.get_metadata(PositionProvider, original_node).start # type: ignore
if position.line == self.target_line and func_name == self.mim_method:
CONFIG["refactorLogger"].debug("Modifying MIM method")
logger.debug("Modifying MIM method")
decorators = [
*list(original_node.decorators),
cst.Decorator(cst.Name("staticmethod")),
Expand Down
2 changes: 1 addition & 1 deletion src/ecooptimizer/refactorers/multi_file_refactorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def traverse(self, directory: Path):
continue

CONFIG["refactorLogger"].debug(f"Entering directory: {item!s}")
self.traverse_and_process(item)
self.traverse(item)
elif item.is_file() and item.suffix == ".py":
self.py_files.append(item)

Expand Down
Loading