diff --git a/src/main/java/org/openrewrite/java/testing/mockito/SimplifyMockitoVerifyWhenGiven.java b/src/main/java/org/openrewrite/java/testing/mockito/SimplifyMockitoVerifyWhenGiven.java index 76260a777..e837aa9be 100644 --- a/src/main/java/org/openrewrite/java/testing/mockito/SimplifyMockitoVerifyWhenGiven.java +++ b/src/main/java/org/openrewrite/java/testing/mockito/SimplifyMockitoVerifyWhenGiven.java @@ -39,6 +39,7 @@ public class SimplifyMockitoVerifyWhenGiven extends Recipe { private static final MethodMatcher VERIFY_MATCHER = new MethodMatcher("org.mockito.Mockito verify(..)"); private static final MethodMatcher STUBBER_MATCHER = new MethodMatcher("org.mockito.stubbing.Stubber when(..)"); private static final MethodMatcher EQ_MATCHER = new MethodMatcher("org.mockito.ArgumentMatchers eq(..)"); + private static final MethodMatcher MOCKITO_EQ_MATCHER = new MethodMatcher("org.mockito.Mockito eq(..)"); @Override public String getDisplayName() { @@ -57,32 +58,35 @@ public Set getTags() { @Override public TreeVisitor getVisitor() { - return Preconditions.check(new UsesMethod<>(EQ_MATCHER), new JavaIsoVisitor() { - @Override - public J.MethodInvocation visitMethodInvocation(J.MethodInvocation methodInvocation, ExecutionContext ctx) { - J.MethodInvocation mi = super.visitMethodInvocation(methodInvocation, ctx); + return Preconditions.check(Preconditions.or(new UsesMethod<>(EQ_MATCHER), new UsesMethod<>(MOCKITO_EQ_MATCHER)), + new JavaIsoVisitor() { + @Override + public J.MethodInvocation visitMethodInvocation(J.MethodInvocation methodInvocation, ExecutionContext ctx) { + J.MethodInvocation mi = super.visitMethodInvocation(methodInvocation, ctx); - if ((WHEN_MATCHER.matches(mi) || GIVEN_MATCHER.matches(mi)) && mi.getArguments().get(0) instanceof J.MethodInvocation) { - List updatedArguments = new ArrayList<>(mi.getArguments()); - updatedArguments.set(0, checkAndUpdateEq((J.MethodInvocation) mi.getArguments().get(0))); - mi = mi.withArguments(updatedArguments); - } else if (VERIFY_MATCHER.matches(mi.getSelect()) || - STUBBER_MATCHER.matches(mi.getSelect())) { - mi = checkAndUpdateEq(mi); - } + if ((WHEN_MATCHER.matches(mi) || GIVEN_MATCHER.matches(mi)) && mi.getArguments().get(0) instanceof J.MethodInvocation) { + List updatedArguments = new ArrayList<>(mi.getArguments()); + updatedArguments.set(0, checkAndUpdateEq((J.MethodInvocation) mi.getArguments().get(0))); + mi = mi.withArguments(updatedArguments); + } else if (VERIFY_MATCHER.matches(mi.getSelect()) || + STUBBER_MATCHER.matches(mi.getSelect())) { + mi = checkAndUpdateEq(mi); + } - maybeRemoveImport("org.mockito.ArgumentMatchers.eq"); - return mi; - } + maybeRemoveImport("org.mockito.ArgumentMatchers.eq"); + maybeRemoveImport("org.mockito.Mockito.eq"); + return mi; + } - private J.MethodInvocation checkAndUpdateEq(J.MethodInvocation methodInvocation) { - if (methodInvocation.getArguments().stream().allMatch(EQ_MATCHER::matches)) { - return methodInvocation.withArguments(ListUtils.map(methodInvocation.getArguments(), invocation -> - ((MethodCall) invocation).getArguments().get(0).withPrefix(invocation.getPrefix()))); - } - return methodInvocation; - } - }); + private J.MethodInvocation checkAndUpdateEq(J.MethodInvocation methodInvocation) { + if (methodInvocation.getArguments().stream().allMatch(arg -> EQ_MATCHER.matches(arg) || + MOCKITO_EQ_MATCHER.matches(arg))) { + return methodInvocation.withArguments(ListUtils.map(methodInvocation.getArguments(), invocation -> + ((MethodCall) invocation).getArguments().get(0).withPrefix(invocation.getPrefix()))); + } + return methodInvocation; + } + }); } } diff --git a/src/test/java/org/openrewrite/java/testing/mockito/SimplifyMockitoVerifyWhenGivenTest.java b/src/test/java/org/openrewrite/java/testing/mockito/SimplifyMockitoVerifyWhenGivenTest.java index 69815db02..baa688cdf 100644 --- a/src/test/java/org/openrewrite/java/testing/mockito/SimplifyMockitoVerifyWhenGivenTest.java +++ b/src/test/java/org/openrewrite/java/testing/mockito/SimplifyMockitoVerifyWhenGivenTest.java @@ -18,6 +18,7 @@ import org.junit.jupiter.api.Test; import org.openrewrite.DocumentExample; import org.openrewrite.InMemoryExecutionContext; +import org.openrewrite.Issue; import org.openrewrite.java.JavaParser; import org.openrewrite.test.RecipeSpec; import org.openrewrite.test.RewriteTest; @@ -42,7 +43,7 @@ void shouldRemoveUnneccesaryEqFromVerify() { import static org.mockito.Mockito.verify; import static org.mockito.Mockito.mock; import static org.mockito.ArgumentMatchers.eq; - + class Test { void test() { var mockString = mock(String.class); @@ -52,7 +53,39 @@ void test() { """, """ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.mock; - + + class Test { + void test() { + var mockString = mock(String.class); + verify(mockString).replace("foo", "bar"); + } + } + """ + ) + ); + } + + @Test + @Issue("https://github.com/openrewrite/rewrite-testing-frameworks/issues/634") + void shouldRemoveUnneccesaryEqFromVerify_withMockitoStarImport() { + rewriteRun( + //language=Java + java( + """ + import static org.mockito.Mockito.eq; + import static org.mockito.Mockito.mock; + import static org.mockito.Mockito.verify; + + class Test { + void test() { + var mockString = mock(String.class); + verify(mockString).replace(eq("foo"), eq("bar")); + } + } + """, """ + import static org.mockito.Mockito.mock; + import static org.mockito.Mockito.verify; + class Test { void test() { var mockString = mock(String.class); @@ -73,7 +106,7 @@ void shouldRemoveUnneccesaryEqFromWhen() { import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; import static org.mockito.ArgumentMatchers.eq; - + class Test { void test() { var mockString = mock(String.class); @@ -83,7 +116,7 @@ void test() { """, """ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; - + class Test { void test() { var mockString = mock(String.class); @@ -105,7 +138,7 @@ void shouldNotRemoveEqWhenMatchersAreMixed() { import static org.mockito.Mockito.when; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.anyString; - + class Test { void test() { var mockString = mock(String.class); @@ -125,7 +158,7 @@ void shouldRemoveUnneccesaryEqFromStubber() { """ import static org.mockito.Mockito.doThrow; import static org.mockito.ArgumentMatchers.eq; - + class Test { void test() { doThrow(new RuntimeException()).when("foo").substring(eq(1)); @@ -133,7 +166,7 @@ void test() { } """, """ import static org.mockito.Mockito.doThrow; - + class Test { void test() { doThrow(new RuntimeException()).when("foo").substring(1); @@ -152,7 +185,7 @@ void shouldRemoveUnneccesaryEqFromBDDGiven() { """ import static org.mockito.BDDMockito.given; import static org.mockito.ArgumentMatchers.eq; - + class Test { void test() { given("foo".substring(eq(1))); @@ -160,7 +193,7 @@ void test() { } """, """ import static org.mockito.BDDMockito.given; - + class Test { void test() { given("foo".substring(1)); @@ -181,13 +214,13 @@ void shouldNotRemoveEqImportWhenStillNeeded() { import static org.mockito.Mockito.when; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.anyString; - + class Test { void testRemoveEq() { var mockString = mock(String.class); when(mockString.replace(eq("foo"), eq("bar"))).thenReturn("bar"); } - + void testKeepEq() { var mockString = mock(String.class); when(mockString.replace(eq("foo"), anyString())).thenReturn("bar"); @@ -198,13 +231,13 @@ void testKeepEq() { import static org.mockito.Mockito.when; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.anyString; - + class Test { void testRemoveEq() { var mockString = mock(String.class); when(mockString.replace("foo", "bar")).thenReturn("bar"); } - + void testKeepEq() { var mockString = mock(String.class); when(mockString.replace(eq("foo"), anyString())).thenReturn("bar"); @@ -227,7 +260,7 @@ void shouldFixSonarExamples() { import static org.mockito.Mockito.doThrow; import static org.mockito.BDDMockito.given; import static org.mockito.ArgumentMatchers.eq; - + class Test { void test(Object v1, Object v2, Object v3, Object v4, Object v5, Foo foo) { given(foo.bar(eq(v1), eq(v2), eq(v3))).willReturn(null); @@ -236,7 +269,7 @@ void test(Object v1, Object v2, Object v3, Object v4, Object v5, Foo foo) { verify(foo).bar(eq(v1), eq(v2), eq(v3)); } } - + class Foo { Object bar(Object v1, Object v2, Object v3) { return null; } String baz(Object v4, Object v5) { return ""; } @@ -248,7 +281,7 @@ void quux(int x) {} import static org.mockito.Mockito.verify; import static org.mockito.Mockito.doThrow; import static org.mockito.BDDMockito.given; - + class Test { void test(Object v1, Object v2, Object v3, Object v4, Object v5, Foo foo) { given(foo.bar(v1, v2, v3)).willReturn(null); @@ -257,7 +290,7 @@ void test(Object v1, Object v2, Object v3, Object v4, Object v5, Foo foo) { verify(foo).bar(v1, v2, v3); } } - + class Foo { Object bar(Object v1, Object v2, Object v3) { return null; } String baz(Object v4, Object v5) { return ""; }