Skip to content

Extend InlineVariable to support local variable assignments #647

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Jul 22, 2025
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,7 @@ public J.Switch visitSwitch(J.Switch switch_, P p) {

casesWithDefaultLast = addBreakToLastCase(casesWithDefaultLast, p);
casesWithDefaultLast.addAll(maybeReorderFallthroughCases(defaultCases, p));
casesWithDefaultLast = ListUtils.mapLast(casesWithDefaultLast, this::removeBreak);
return casesWithDefaultLast;
return ListUtils.mapLast(casesWithDefaultLast, this::removeBreak);
}

private List<J.Case> maybeReorderFallthroughCases(List<J.Case> cases, P p) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@
import org.openrewrite.TreeVisitor;
import org.openrewrite.internal.ListUtils;
import org.openrewrite.java.JavaIsoVisitor;
import org.openrewrite.java.tree.*;
import org.openrewrite.java.tree.J;
import org.openrewrite.java.tree.JavaType;
import org.openrewrite.java.tree.Space;
import org.openrewrite.java.tree.Statement;
import org.openrewrite.marker.Markers;

import java.util.List;
Expand Down Expand Up @@ -69,8 +72,7 @@ public J.MethodDeclaration visitMethodDeclaration(J.MethodDeclaration methodDecl
}

List<Statement> parameters = ListUtils.map(declarations.getParameters(), FinalizeMethodArguments::updateParam);
declarations = declarations.withParameters(parameters);
return declarations;
return declarations.withParameters(parameters);
}

private void checkIfAssigned(final AtomicBoolean assigned, final Statement p) {
Expand Down Expand Up @@ -176,8 +178,7 @@ private static Statement updateParam(final Statement p) {
J.VariableDeclarations variableDeclarations = (J.VariableDeclarations) p;
if (variableDeclarations.getModifiers().isEmpty()) {
variableDeclarations = updateModifiers(variableDeclarations, !((J.VariableDeclarations) p).getLeadingAnnotations().isEmpty());
variableDeclarations = updateDeclarations(variableDeclarations);
return variableDeclarations;
return updateDeclarations(variableDeclarations);
}
}
return p;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,13 +95,12 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu
argIndex++;
}
int finalArgIndex = argIndex;
mi = mi.withArguments(ListUtils.map(mi.getArguments(), (i, arg) -> {
return mi.withArguments(ListUtils.map(mi.getArguments(), (i, arg) -> {
if (i == 0 || i < finalArgIndex) {
return arg;
}
return null;
}));
return mi;
}
return mi;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,16 +87,13 @@ public J visitForLoop(J.ForLoop forLoop, ExecutionContext ctx) {
Comparator.comparing(s -> s.printTrimmed(getCursor()), Comparator.naturalOrder())
)));

//noinspection ConstantConditions
f = f.withBody((Statement) new JavaVisitor<ExecutionContext>() {
return f.withBody((Statement) new JavaVisitor<ExecutionContext>() {

@Override
public @Nullable J visit(@Nullable Tree tree, ExecutionContext ctx) {
return tree == unary ? null : super.visit(tree, ctx);
}
}.visit(f.getBody(), ctx));

return f;
}
}
}
Expand Down
95 changes: 60 additions & 35 deletions src/main/java/org/openrewrite/staticanalysis/InlineVariable.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,8 @@
import org.openrewrite.TreeVisitor;
import org.openrewrite.internal.ListUtils;
import org.openrewrite.java.JavaIsoVisitor;
import org.openrewrite.java.tree.Expression;
import org.openrewrite.java.tree.J;
import org.openrewrite.java.tree.JavaType;
import org.openrewrite.java.tree.Statement;
import org.openrewrite.java.search.SemanticallyEqual;
import org.openrewrite.java.tree.*;

import java.time.Duration;
import java.util.Collections;
Expand All @@ -42,7 +40,8 @@ public String getDisplayName() {

@Override
public String getDescription() {
return "Inline variables when they are immediately used to return or throw.";
return "Inline variables when they are immediately used to return or throw. " +
"Supports both variable declarations and assignments to local variables.";
}

@Override
Expand All @@ -62,53 +61,79 @@ public TreeVisitor<?, ExecutionContext> getVisitor() {
public J.Block visitBlock(J.Block block, ExecutionContext ctx) {
J.Block bl = super.visitBlock(block, ctx);
List<Statement> statements = bl.getStatements();
if (statements.size() > 1) {
String identReturned = identReturned(statements);
if (1 < statements.size()) {
J.Identifier identReturned = identReturnedOrThrown(statements);
if (identReturned != null) {
if (statements.get(statements.size() - 2) instanceof J.VariableDeclarations) {
J.VariableDeclarations varDec = (J.VariableDeclarations) statements.get(statements.size() - 2);
J.VariableDeclarations.NamedVariable identDefinition = varDec.getVariables().get(0);
if (varDec.getLeadingAnnotations().isEmpty() && identDefinition.getSimpleName().equals(identReturned)) {
bl = bl.withStatements(ListUtils.map(statements, (i, statement) -> {
if (i == statements.size() - 2) {
return null;
}
if (i == statements.size() - 1) {
if (statement instanceof J.Return) {
J.Return return_ = (J.Return) statement;
return return_.withExpression(requireNonNull(identDefinition.getInitializer())
.withPrefix(requireNonNull(return_.getExpression()).getPrefix()))
.withPrefix(varDec.getPrefix().withComments(ListUtils.concatAll(varDec.getComments(), return_.getComments())));
}
if (statement instanceof J.Throw) {
J.Throw thrown = (J.Throw) statement;
return thrown.withException(requireNonNull(identDefinition.getInitializer())
.withPrefix(requireNonNull(thrown.getException()).getPrefix()))
.withPrefix(varDec.getPrefix().withComments(ListUtils.concatAll(varDec.getComments(), thrown.getComments())));
}
}
return statement;
}));
Statement secondLastStatement = statements.get(statements.size() - 2);
if (secondLastStatement instanceof J.VariableDeclarations) {
J.VariableDeclarations varDec = (J.VariableDeclarations) secondLastStatement;
// Only inline if there's exactly one variable declared
if (varDec.getVariables().size() == 1) {
J.VariableDeclarations.NamedVariable identDefinition = varDec.getVariables().get(0);
if (varDec.getLeadingAnnotations().isEmpty() &&
SemanticallyEqual.areEqual(identDefinition.getName(), identReturned)) {
return inlineExpression(identDefinition.getInitializer(), bl, statements, varDec.getPrefix(), varDec.getComments());
}
}
} else if (secondLastStatement instanceof J.Assignment) {
J.Assignment assignment = (J.Assignment) secondLastStatement;
if (assignment.getVariable() instanceof J.Identifier) {
J.Identifier assignedVar = (J.Identifier) assignment.getVariable();
// Only inline local variable assignments, not fields
if (assignedVar.getFieldType() != null &&
assignedVar.getFieldType().getOwner() instanceof JavaType.Method &&
SemanticallyEqual.areEqual(assignedVar, identReturned)) {
doAfterVisit(new RemoveUnusedLocalVariables(null, null, null).getVisitor());
return inlineExpression(assignment.getAssignment(), bl, statements, assignment.getPrefix(), assignment.getComments());
}
}
}
}
}
return bl;
}

private @Nullable String identReturned(List<Statement> stats) {
private J.Block inlineExpression(@Nullable Expression expression, J.Block bl, List<Statement> statements,
Space prefix, List<Comment> comments) {
if (expression == null) {
return bl;
}

return bl.withStatements(ListUtils.map(statements, (i, statement) -> {
if (i == statements.size() - 2) {
return null;
}
if (i == statements.size() - 1) {
if (statement instanceof J.Return) {
J.Return return_ = (J.Return) statement;
return return_
.withExpression(expression.withPrefix(requireNonNull(return_.getExpression()).getPrefix()))
.withPrefix(prefix.withComments(ListUtils.concatAll(comments, return_.getComments())));
}
if (statement instanceof J.Throw) {
J.Throw thrown = (J.Throw) statement;
return thrown.
withException(expression.withPrefix(requireNonNull(thrown.getException()).getPrefix()))
.withPrefix(prefix.withComments(ListUtils.concatAll(comments, thrown.getComments())));
}
}
return statement;
}));
}

private J.@Nullable Identifier identReturnedOrThrown(List<Statement> stats) {
Statement lastStatement = stats.get(stats.size() - 1);
if (lastStatement instanceof J.Return) {
J.Return return_ = (J.Return) lastStatement;
Expression expression = return_.getExpression();
if (expression instanceof J.Identifier &&
!(expression.getType() instanceof JavaType.Array)) {
return ((J.Identifier) expression).getSimpleName();
!(expression.getType() instanceof JavaType.Array)) {
return ((J.Identifier) expression);
}
} else if (lastStatement instanceof J.Throw) {
J.Throw thr = (J.Throw) lastStatement;
if (thr.getException() instanceof J.Identifier) {
return ((J.Identifier) thr.getException()).getSimpleName();
return ((J.Identifier) thr.getException());
}
}
return null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -424,8 +424,7 @@ public J visitBinary(J.Binary binary, Integer p) {
@Override
public J.InstanceOf visitInstanceOf(J.InstanceOf instanceOf, Integer p) {
instanceOf = (J.InstanceOf) super.visitInstanceOf(instanceOf, p);
instanceOf = replacements.processInstanceOf(instanceOf, getCursor());
return instanceOf;
return replacements.processInstanceOf(instanceOf, getCursor());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,6 @@ private boolean switchesOnEnum(J.Switch switch_) {
}

private static J.If createIfForEnum(Expression expression, Expression enumTree) {
J.If generatedIf;
if (enumTree instanceof J.Identifier) {
enumTree = new J.FieldAccess(
randomId(),
Expand All @@ -267,15 +266,14 @@ private static J.If createIfForEnum(Expression expression, Expression enumTree)
);
}
J.Binary ifCond = JavaElementFactory.newLogicalExpression(J.Binary.Type.Equal, expression, enumTree);
generatedIf = new J.If(
return new J.If(
randomId(),
Space.EMPTY,
Markers.EMPTY,
new J.ControlParentheses<>(randomId(), Space.EMPTY, Markers.EMPTY, JRightPadded.build(ifCond)),
JRightPadded.build(J.Block.createEmptyBlock()),
null
);
return generatedIf;
}

@AllArgsConstructor
Expand Down
5 changes: 2 additions & 3 deletions src/main/java/org/openrewrite/staticanalysis/NoFinalizer.java
Original file line number Diff line number Diff line change
Expand Up @@ -59,16 +59,15 @@ public TreeVisitor<?, ExecutionContext> getVisitor() {
@Override
public J.ClassDeclaration visitClassDeclaration(J.ClassDeclaration classDecl, ExecutionContext ctx) {
J.ClassDeclaration cd = super.visitClassDeclaration(classDecl, ctx);
cd = cd.withBody(cd.getBody().withStatements(ListUtils.map(cd.getBody().getStatements(), stmt -> {

return cd.withBody(cd.getBody().withStatements(ListUtils.map(cd.getBody().getStatements(), stmt -> {
if (stmt instanceof J.MethodDeclaration) {
if (FINALIZER.matches((J.MethodDeclaration) stmt, classDecl)) {
return null;
}
}
return stmt;
})));

return cd;
}
});
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -249,17 +249,14 @@ private static class PruneAssignmentExpression extends JavaIsoVisitor<ExecutionC

@Override
public <T extends J> J.ControlParentheses<T> visitControlParentheses(J.ControlParentheses<T> c, ExecutionContext ctx) {
//noinspection unchecked
c = (J.ControlParentheses<T>) new AssignmentToLiteral(assignment)
return (J.ControlParentheses<T>) new AssignmentToLiteral(assignment)
.visitNonNull(c, ctx, getCursor().getParentOrThrow());
return c;
}

@Override
public J.MethodInvocation visitMethodInvocation(J.MethodInvocation m, ExecutionContext ctx) {
AssignmentToLiteral atl = new AssignmentToLiteral(assignment);
m = m.withArguments(ListUtils.map(m.getArguments(), it -> (Expression) atl.visitNonNull(it, ctx, getCursor().getParentOrThrow())));
return m;
return m.withArguments(ListUtils.map(m.getArguments(), it -> (Expression) atl.visitNonNull(it, ctx, getCursor().getParentOrThrow())));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,8 @@ public J visit(@Nullable Tree tree, ExecutionContext ctx) {
@Override
public J.TypeCast visitTypeCast(J.TypeCast typeCast, ExecutionContext ctx) {
J.TypeCast tc = super.visitTypeCast(typeCast, ctx);
tc = (J.TypeCast) new SpacesVisitor<>(spacesStyle, null, null, tc)
return (J.TypeCast) new SpacesVisitor<>(spacesStyle, null, null, tc)
.visitNonNull(tc, ctx, getCursor().getParentTreeCursor().fork());
return tc;
}
}
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -200,8 +200,7 @@ private J.VariableDeclarations consolidateBuilder(J.VariableDeclarations consoli
);
})
);
cb = formatTabsAndIndents(cb, getCursor());
return cb;
return formatTabsAndIndents(cb, getCursor());
}
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,7 @@ public J visitForLoop(J.ForLoop forLoop, ExecutionContext ctx) {
!(forLoop.getControl().getCondition() instanceof J.Empty)
) {
J.WhileLoop w = whileLoop.apply(getCursor(), forLoop.getCoordinates().replace(), forLoop.getControl().getCondition());
w = w.withBody(forLoop.getBody());
return w;
return w.withBody(forLoop.getBody());
}
return super.visitForLoop(forLoop, ctx);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -184,8 +184,7 @@ private TypeTree annotateInnerClass(TypeTree qualifiedClassRef, J.Annotation ann
}
if (qualifiedClassRef instanceof J.ArrayType) {
J.ArrayType at = (J.ArrayType) qualifiedClassRef;
at = at.withAnnotations(ListUtils.concat(annotation.withPrefix(Space.SINGLE_SPACE), at.getAnnotations()));
return at;
return at.withAnnotations(ListUtils.concat(annotation.withPrefix(Space.SINGLE_SPACE), at.getAnnotations()));
}
return qualifiedClassRef;
}
Expand Down
Loading