Detect missing test annotations
Test: ./ravenwood/run-ravenwood-tests.sh
Bug: 292141694
Change-Id: I135939464d917024d667feae9998bf324e522831
diff --git a/ravenwood/junit-impl-src/android/platform/test/ravenwood/RavenwoodRuleImpl.java b/ravenwood/junit-impl-src/android/platform/test/ravenwood/RavenwoodRuleImpl.java
index 7b5932b..1d5c79c 100644
--- a/ravenwood/junit-impl-src/android/platform/test/ravenwood/RavenwoodRuleImpl.java
+++ b/ravenwood/junit-impl-src/android/platform/test/ravenwood/RavenwoodRuleImpl.java
@@ -16,6 +16,8 @@
package android.platform.test.ravenwood;
+import static org.junit.Assert.assertFalse;
+
import android.app.ActivityManager;
import android.app.Instrumentation;
import android.os.Build;
@@ -28,12 +30,19 @@
import com.android.internal.os.RuntimeInit;
+import org.junit.After;
import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
import org.junit.runner.Description;
import org.junit.runner.RunWith;
import org.junit.runners.model.Statement;
import java.io.PrintStream;
+import java.lang.reflect.Method;
+import java.lang.reflect.Modifier;
+import java.util.ArrayList;
+import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.Executors;
@@ -183,6 +192,7 @@
public static void validate(Statement base, Description description,
boolean enableOptionalValidation) {
validateTestRunner(base, description, enableOptionalValidation);
+ validateTestAnnotations(base, description, enableOptionalValidation);
}
private static void validateTestRunner(Statement base, Description description,
@@ -206,4 +216,63 @@
}
}
}
+
+ private static void validateTestAnnotations(Statement base, Description description,
+ boolean enableOptionalValidation) {
+ final var testClass = description.getTestClass();
+
+ final var message = new StringBuilder();
+
+ boolean hasErrors = false;
+ for (Method m : collectMethods(testClass)) {
+ if (Modifier.isPublic(m.getModifiers()) && m.getName().startsWith("test")) {
+ if (m.getAnnotation(Test.class) == null) {
+ message.append("\nMethod " + m.getName() + "() doesn't have @Test");
+ hasErrors = true;
+ }
+ }
+ if ("setUp".equals(m.getName())) {
+ if (m.getAnnotation(Before.class) == null) {
+ message.append("\nMethod " + m.getName() + "() doesn't have @Before");
+ hasErrors = true;
+ }
+ if (!Modifier.isPublic(m.getModifiers())) {
+ message.append("\nMethod " + m.getName() + "() must be public");
+ hasErrors = true;
+ }
+ }
+ if ("tearDown".equals(m.getName())) {
+ if (m.getAnnotation(After.class) == null) {
+ message.append("\nMethod " + m.getName() + "() doesn't have @After");
+ hasErrors = true;
+ }
+ if (!Modifier.isPublic(m.getModifiers())) {
+ message.append("\nMethod " + m.getName() + "() must be public");
+ hasErrors = true;
+ }
+ }
+ }
+ assertFalse("Problem(s) detected in class " + testClass.getCanonicalName() + ":"
+ + message, hasErrors);
+ }
+
+ /**
+ * Collect all (public or private or any) methods in a class, including inherited methods.
+ */
+ private static List<Method> collectMethods(Class<?> clazz) {
+ var ret = new ArrayList<Method>();
+ collectMethods(clazz, ret);
+ return ret;
+ }
+
+ private static void collectMethods(Class<?> clazz, List<Method> result) {
+ // Class.getMethods() only return public methods, so we need to use getDeclaredMethods()
+ // instead, and recurse.
+ for (var m : clazz.getDeclaredMethods()) {
+ result.add(m);
+ }
+ if (clazz.getSuperclass() != null) {
+ collectMethods(clazz.getSuperclass(), result);
+ }
+ }
}