Skip to content

Commit

Permalink
Added handling of org.mockito.Mockito.eq(..) when simplifying mockito…
Browse files Browse the repository at this point in the history
… matchers (#635)

* Add test case for issue #634

* Add handling of org.mockito.Mockito.eq

* Apply formatter to minimize diff

* Add issue reference to document the change

---------

Co-authored-by: Tim te Beek <[email protected]>
  • Loading branch information
adambir and timtebeek authored Nov 5, 2024
1 parent 6d665a0 commit 671addc
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ public class SimplifyMockitoVerifyWhenGiven extends Recipe {
private static final MethodMatcher VERIFY_MATCHER = new MethodMatcher("org.mockito.Mockito verify(..)");
private static final MethodMatcher STUBBER_MATCHER = new MethodMatcher("org.mockito.stubbing.Stubber when(..)");
private static final MethodMatcher EQ_MATCHER = new MethodMatcher("org.mockito.ArgumentMatchers eq(..)");
private static final MethodMatcher MOCKITO_EQ_MATCHER = new MethodMatcher("org.mockito.Mockito eq(..)");

@Override
public String getDisplayName() {
Expand All @@ -57,32 +58,35 @@ public Set<String> getTags() {

@Override
public TreeVisitor<?, ExecutionContext> getVisitor() {
return Preconditions.check(new UsesMethod<>(EQ_MATCHER), new JavaIsoVisitor<ExecutionContext>() {
@Override
public J.MethodInvocation visitMethodInvocation(J.MethodInvocation methodInvocation, ExecutionContext ctx) {
J.MethodInvocation mi = super.visitMethodInvocation(methodInvocation, ctx);
return Preconditions.check(Preconditions.or(new UsesMethod<>(EQ_MATCHER), new UsesMethod<>(MOCKITO_EQ_MATCHER)),
new JavaIsoVisitor<ExecutionContext>() {
@Override
public J.MethodInvocation visitMethodInvocation(J.MethodInvocation methodInvocation, ExecutionContext ctx) {
J.MethodInvocation mi = super.visitMethodInvocation(methodInvocation, ctx);

if ((WHEN_MATCHER.matches(mi) || GIVEN_MATCHER.matches(mi)) && mi.getArguments().get(0) instanceof J.MethodInvocation) {
List<Expression> updatedArguments = new ArrayList<>(mi.getArguments());
updatedArguments.set(0, checkAndUpdateEq((J.MethodInvocation) mi.getArguments().get(0)));
mi = mi.withArguments(updatedArguments);
} else if (VERIFY_MATCHER.matches(mi.getSelect()) ||
STUBBER_MATCHER.matches(mi.getSelect())) {
mi = checkAndUpdateEq(mi);
}
if ((WHEN_MATCHER.matches(mi) || GIVEN_MATCHER.matches(mi)) && mi.getArguments().get(0) instanceof J.MethodInvocation) {
List<Expression> updatedArguments = new ArrayList<>(mi.getArguments());
updatedArguments.set(0, checkAndUpdateEq((J.MethodInvocation) mi.getArguments().get(0)));
mi = mi.withArguments(updatedArguments);
} else if (VERIFY_MATCHER.matches(mi.getSelect()) ||
STUBBER_MATCHER.matches(mi.getSelect())) {
mi = checkAndUpdateEq(mi);
}

maybeRemoveImport("org.mockito.ArgumentMatchers.eq");
return mi;
}
maybeRemoveImport("org.mockito.ArgumentMatchers.eq");
maybeRemoveImport("org.mockito.Mockito.eq");
return mi;
}

private J.MethodInvocation checkAndUpdateEq(J.MethodInvocation methodInvocation) {
if (methodInvocation.getArguments().stream().allMatch(EQ_MATCHER::matches)) {
return methodInvocation.withArguments(ListUtils.map(methodInvocation.getArguments(), invocation ->
((MethodCall) invocation).getArguments().get(0).withPrefix(invocation.getPrefix())));
}
return methodInvocation;
}
});
private J.MethodInvocation checkAndUpdateEq(J.MethodInvocation methodInvocation) {
if (methodInvocation.getArguments().stream().allMatch(arg -> EQ_MATCHER.matches(arg) ||
MOCKITO_EQ_MATCHER.matches(arg))) {
return methodInvocation.withArguments(ListUtils.map(methodInvocation.getArguments(), invocation ->
((MethodCall) invocation).getArguments().get(0).withPrefix(invocation.getPrefix())));
}
return methodInvocation;
}
});
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import org.junit.jupiter.api.Test;
import org.openrewrite.DocumentExample;
import org.openrewrite.InMemoryExecutionContext;
import org.openrewrite.Issue;
import org.openrewrite.java.JavaParser;
import org.openrewrite.test.RecipeSpec;
import org.openrewrite.test.RewriteTest;
Expand All @@ -42,7 +43,7 @@ void shouldRemoveUnneccesaryEqFromVerify() {
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.mock;
import static org.mockito.ArgumentMatchers.eq;
class Test {
void test() {
var mockString = mock(String.class);
Expand All @@ -52,7 +53,39 @@ void test() {
""", """
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.mock;
class Test {
void test() {
var mockString = mock(String.class);
verify(mockString).replace("foo", "bar");
}
}
"""
)
);
}

@Test
@Issue("https://github.com/openrewrite/rewrite-testing-frameworks/issues/634")
void shouldRemoveUnneccesaryEqFromVerify_withMockitoStarImport() {
rewriteRun(
//language=Java
java(
"""
import static org.mockito.Mockito.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
class Test {
void test() {
var mockString = mock(String.class);
verify(mockString).replace(eq("foo"), eq("bar"));
}
}
""", """
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
class Test {
void test() {
var mockString = mock(String.class);
Expand All @@ -73,7 +106,7 @@ void shouldRemoveUnneccesaryEqFromWhen() {
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import static org.mockito.ArgumentMatchers.eq;
class Test {
void test() {
var mockString = mock(String.class);
Expand All @@ -83,7 +116,7 @@ void test() {
""", """
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
class Test {
void test() {
var mockString = mock(String.class);
Expand All @@ -105,7 +138,7 @@ void shouldNotRemoveEqWhenMatchersAreMixed() {
import static org.mockito.Mockito.when;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.anyString;
class Test {
void test() {
var mockString = mock(String.class);
Expand All @@ -125,15 +158,15 @@ void shouldRemoveUnneccesaryEqFromStubber() {
"""
import static org.mockito.Mockito.doThrow;
import static org.mockito.ArgumentMatchers.eq;
class Test {
void test() {
doThrow(new RuntimeException()).when("foo").substring(eq(1));
}
}
""", """
import static org.mockito.Mockito.doThrow;
class Test {
void test() {
doThrow(new RuntimeException()).when("foo").substring(1);
Expand All @@ -152,15 +185,15 @@ void shouldRemoveUnneccesaryEqFromBDDGiven() {
"""
import static org.mockito.BDDMockito.given;
import static org.mockito.ArgumentMatchers.eq;
class Test {
void test() {
given("foo".substring(eq(1)));
}
}
""", """
import static org.mockito.BDDMockito.given;
class Test {
void test() {
given("foo".substring(1));
Expand All @@ -181,13 +214,13 @@ void shouldNotRemoveEqImportWhenStillNeeded() {
import static org.mockito.Mockito.when;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.anyString;
class Test {
void testRemoveEq() {
var mockString = mock(String.class);
when(mockString.replace(eq("foo"), eq("bar"))).thenReturn("bar");
}
void testKeepEq() {
var mockString = mock(String.class);
when(mockString.replace(eq("foo"), anyString())).thenReturn("bar");
Expand All @@ -198,13 +231,13 @@ void testKeepEq() {
import static org.mockito.Mockito.when;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.anyString;
class Test {
void testRemoveEq() {
var mockString = mock(String.class);
when(mockString.replace("foo", "bar")).thenReturn("bar");
}
void testKeepEq() {
var mockString = mock(String.class);
when(mockString.replace(eq("foo"), anyString())).thenReturn("bar");
Expand All @@ -227,7 +260,7 @@ void shouldFixSonarExamples() {
import static org.mockito.Mockito.doThrow;
import static org.mockito.BDDMockito.given;
import static org.mockito.ArgumentMatchers.eq;
class Test {
void test(Object v1, Object v2, Object v3, Object v4, Object v5, Foo foo) {
given(foo.bar(eq(v1), eq(v2), eq(v3))).willReturn(null);
Expand All @@ -236,7 +269,7 @@ void test(Object v1, Object v2, Object v3, Object v4, Object v5, Foo foo) {
verify(foo).bar(eq(v1), eq(v2), eq(v3));
}
}
class Foo {
Object bar(Object v1, Object v2, Object v3) { return null; }
String baz(Object v4, Object v5) { return ""; }
Expand All @@ -248,7 +281,7 @@ void quux(int x) {}
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.doThrow;
import static org.mockito.BDDMockito.given;
class Test {
void test(Object v1, Object v2, Object v3, Object v4, Object v5, Foo foo) {
given(foo.bar(v1, v2, v3)).willReturn(null);
Expand All @@ -257,7 +290,7 @@ void test(Object v1, Object v2, Object v3, Object v4, Object v5, Foo foo) {
verify(foo).bar(v1, v2, v3);
}
}
class Foo {
Object bar(Object v1, Object v2, Object v3) { return null; }
String baz(Object v4, Object v5) { return ""; }
Expand Down

0 comments on commit 671addc

Please sign in to comment.