Detect missing test annotations

Test: ./ravenwood/run-ravenwood-tests.sh
Bug: 292141694
Change-Id: I135939464d917024d667feae9998bf324e522831
diff --git a/ravenwood/junit-impl-src/android/platform/test/ravenwood/RavenwoodRuleImpl.java b/ravenwood/junit-impl-src/android/platform/test/ravenwood/RavenwoodRuleImpl.java
index 7b5932b..1d5c79c 100644
--- a/ravenwood/junit-impl-src/android/platform/test/ravenwood/RavenwoodRuleImpl.java
+++ b/ravenwood/junit-impl-src/android/platform/test/ravenwood/RavenwoodRuleImpl.java
@@ -16,6 +16,8 @@
 
 package android.platform.test.ravenwood;
 
+import static org.junit.Assert.assertFalse;
+
 import android.app.ActivityManager;
 import android.app.Instrumentation;
 import android.os.Build;
@@ -28,12 +30,19 @@
 
 import com.android.internal.os.RuntimeInit;
 
+import org.junit.After;
 import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
 import org.junit.runner.Description;
 import org.junit.runner.RunWith;
 import org.junit.runners.model.Statement;
 
 import java.io.PrintStream;
+import java.lang.reflect.Method;
+import java.lang.reflect.Modifier;
+import java.util.ArrayList;
+import java.util.List;
 import java.util.Map;
 import java.util.Objects;
 import java.util.concurrent.Executors;
@@ -183,6 +192,7 @@
     public static void validate(Statement base, Description description,
             boolean enableOptionalValidation) {
         validateTestRunner(base, description, enableOptionalValidation);
+        validateTestAnnotations(base, description, enableOptionalValidation);
     }
 
     private static void validateTestRunner(Statement base, Description description,
@@ -206,4 +216,63 @@
             }
         }
     }
+
+    private static void validateTestAnnotations(Statement base, Description description,
+            boolean enableOptionalValidation) {
+        final var testClass = description.getTestClass();
+
+        final var message = new StringBuilder();
+
+        boolean hasErrors = false;
+        for (Method m : collectMethods(testClass)) {
+            if (Modifier.isPublic(m.getModifiers()) && m.getName().startsWith("test")) {
+                if (m.getAnnotation(Test.class) == null) {
+                    message.append("\nMethod " + m.getName() + "() doesn't have @Test");
+                    hasErrors = true;
+                }
+            }
+            if ("setUp".equals(m.getName())) {
+                if (m.getAnnotation(Before.class) == null) {
+                    message.append("\nMethod " + m.getName() + "() doesn't have @Before");
+                    hasErrors = true;
+                }
+                if (!Modifier.isPublic(m.getModifiers())) {
+                    message.append("\nMethod " + m.getName() + "() must be public");
+                    hasErrors = true;
+                }
+            }
+            if ("tearDown".equals(m.getName())) {
+                if (m.getAnnotation(After.class) == null) {
+                    message.append("\nMethod " + m.getName() + "() doesn't have @After");
+                    hasErrors = true;
+                }
+                if (!Modifier.isPublic(m.getModifiers())) {
+                    message.append("\nMethod " + m.getName() + "() must be public");
+                    hasErrors = true;
+                }
+            }
+        }
+        assertFalse("Problem(s) detected in class " + testClass.getCanonicalName() + ":"
+                + message, hasErrors);
+    }
+
+    /**
+     * Collect all (public or private or any) methods in a class, including inherited methods.
+     */
+    private static List<Method> collectMethods(Class<?> clazz) {
+        var ret = new ArrayList<Method>();
+        collectMethods(clazz, ret);
+        return ret;
+    }
+
+    private static void collectMethods(Class<?> clazz, List<Method> result) {
+        // Class.getMethods() only return public methods, so we need to use getDeclaredMethods()
+        // instead, and recurse.
+        for (var m : clazz.getDeclaredMethods()) {
+            result.add(m);
+        }
+        if (clazz.getSuperclass() != null) {
+            collectMethods(clazz.getSuperclass(), result);
+        }
+    }
 }