Add model validations for custom callbacks

Test: ./gradlew :privacysandbox:tools:tools-core:test
Bug: 249937257
Change-Id: I06f100109e218eb6b142e488977ac9d71aafc3bb
diff --git a/privacysandbox/tools/tools-apicompiler/src/test/java/androidx/privacysandbox/tools/apicompiler/parser/InterfaceParserTest.kt b/privacysandbox/tools/tools-apicompiler/src/test/java/androidx/privacysandbox/tools/apicompiler/parser/InterfaceParserTest.kt
index e79f4e4..c9c7a11 100644
--- a/privacysandbox/tools/tools-apicompiler/src/test/java/androidx/privacysandbox/tools/apicompiler/parser/InterfaceParserTest.kt
+++ b/privacysandbox/tools/tools-apicompiler/src/test/java/androidx/privacysandbox/tools/apicompiler/parser/InterfaceParserTest.kt
@@ -220,8 +220,9 @@
     fun parameterWithGenerics_fails() {
         checkSourceFails(serviceMethod("suspend fun foo(x: MutableList<Int>)"))
             .containsExactlyErrors(
-                "Error in com.mysdk.MySdk.foo: only primitives and data classes annotated with " +
-                    "@PrivacySandboxValue are supported as parameter and return types."
+                "Error in com.mysdk.MySdk.foo: only primitives, data classes annotated with " +
+                    "@PrivacySandboxValue and interfaces annotated with @PrivacySandboxCallback " +
+                    "are supported as parameter types."
             )
     }
 
@@ -229,8 +230,9 @@
     fun parameterLambda_fails() {
         checkSourceFails(serviceMethod("suspend fun foo(x: (Int) -> Int)"))
             .containsExactlyErrors(
-                "Error in com.mysdk.MySdk.foo: only primitives and data classes annotated with " +
-                    "@PrivacySandboxValue are supported as parameter and return types."
+                "Error in com.mysdk.MySdk.foo: only primitives, data classes annotated with " +
+                    "@PrivacySandboxValue and interfaces annotated with @PrivacySandboxCallback " +
+                    "are supported as parameter types."
             )
     }
 
@@ -250,7 +252,7 @@
         )
         checkSourceFails(source).containsExactlyErrors(
             "Error in com.mysdk.MySdk.foo: only primitives and data classes annotated with " +
-                "@PrivacySandboxValue are supported as parameter and return types."
+                "@PrivacySandboxValue are supported as return types."
         )
     }
 
diff --git a/privacysandbox/tools/tools-core/src/main/java/androidx/privacysandbox/tools/core/validator/ModelValidator.kt b/privacysandbox/tools/tools-core/src/main/java/androidx/privacysandbox/tools/core/validator/ModelValidator.kt
index 402ff85..131a8f2 100644
--- a/privacysandbox/tools/tools-core/src/main/java/androidx/privacysandbox/tools/core/validator/ModelValidator.kt
+++ b/privacysandbox/tools/tools-core/src/main/java/androidx/privacysandbox/tools/core/validator/ModelValidator.kt
@@ -16,6 +16,7 @@
 
 package androidx.privacysandbox.tools.core.validator
 
+import androidx.privacysandbox.tools.core.model.AnnotatedInterface
 import androidx.privacysandbox.tools.core.model.AnnotatedValue
 import androidx.privacysandbox.tools.core.model.ParsedApi
 import androidx.privacysandbox.tools.core.model.Types
@@ -32,6 +33,8 @@
         validateNonSuspendFunctionsReturnUnit()
         validateParameterAndReturnValueTypes()
         validateValuePropertyTypes()
+        callbackMethodsAreFireAndForget()
+        callbacksDontReceiveCallbacks()
         return ValidationResult(errors)
     }
 
@@ -58,18 +61,27 @@
     }
 
     private fun validateParameterAndReturnValueTypes() {
-        val allowedParameterAndReturnValueTypes =
+        val allowedParameterTypes =
+            (api.values.map(AnnotatedValue::type) +
+                api.callbacks.map(AnnotatedInterface::type) +
+                Types.primitiveTypes).toSet()
+        val allowedReturnValueTypes =
             (api.values.map(AnnotatedValue::type) + Types.primitiveTypes).toSet()
         for (service in api.services) {
             for (method in service.methods) {
-                val isAnyTypeInvalid = (method.parameters.map { it.type } + method.returnType).any {
-                    !allowedParameterAndReturnValueTypes.contains(it)
+                if (method.parameters.any { !allowedParameterTypes.contains(it.type) }) {
+                    errors.add(
+                        "Error in ${service.type.qualifiedName}.${method.name}: " +
+                            "only primitives, data classes annotated with @PrivacySandboxValue " +
+                            "and interfaces annotated with @PrivacySandboxCallback are supported " +
+                            "as parameter types."
+                    )
                 }
-                if (isAnyTypeInvalid) {
+                if (!allowedReturnValueTypes.contains(method.returnType)) {
                     errors.add(
                         "Error in ${service.type.qualifiedName}.${method.name}: " +
                             "only primitives and data classes annotated with " +
-                            "@PrivacySandboxValue are supported as parameter and return types."
+                            "@PrivacySandboxValue are supported as return types."
                     )
                 }
             }
@@ -92,7 +104,32 @@
         }
     }
 
-    // TODO: check that callback methods are fire-and-forget
+    private fun callbackMethodsAreFireAndForget() {
+        for (callback in api.callbacks) {
+            for (method in callback.methods) {
+                if (method.returnType != Types.unit || method.isSuspend) {
+                    errors.add(
+                        "Error in ${callback.type.qualifiedName}.${method.name}: callback " +
+                            "methods should be non-suspending and have no return values."
+                    )
+                }
+            }
+        }
+    }
+
+    private fun callbacksDontReceiveCallbacks() {
+        val callbackTypes = api.callbacks.map { it.type }.toSet()
+        for (callback in api.callbacks) {
+            for (method in callback.methods) {
+                if (method.parameters.any { callbackTypes.contains(it.type) }) {
+                    errors.add(
+                        "Error in ${callback.type.qualifiedName}.${method.name}: callback " +
+                            "methods cannot receive other callbacks as arguments."
+                    )
+                }
+            }
+        }
+    }
 }
 
 data class ValidationResult(val errors: List<String>) {
diff --git a/privacysandbox/tools/tools-core/src/test/java/androidx/privacysandbox/tools/core/validator/ModelValidatorTest.kt b/privacysandbox/tools/tools-core/src/test/java/androidx/privacysandbox/tools/core/validator/ModelValidatorTest.kt
index b6eaf0b..604cb6f 100644
--- a/privacysandbox/tools/tools-core/src/test/java/androidx/privacysandbox/tools/core/validator/ModelValidatorTest.kt
+++ b/privacysandbox/tools/tools-core/src/test/java/androidx/privacysandbox/tools/core/validator/ModelValidatorTest.kt
@@ -48,7 +48,14 @@
                                 Parameter(
                                     name = "foo",
                                     type = Type(packageName = "com.mysdk", simpleName = "Foo")
-                                )
+                                ),
+                                Parameter(
+                                    name = "callback",
+                                    type = Type(
+                                        packageName = "com.mysdk",
+                                        simpleName = "MySdkCallback"
+                                    )
+                                ),
                             ),
                             returnType = Types.string,
                             isSuspend = true,
@@ -76,6 +83,24 @@
                     type = Type(packageName = "com.mysdk", simpleName = "Bar"),
                     properties = emptyList(),
                 )
+            ),
+            callbacks = setOf(
+                AnnotatedInterface(
+                    type = Type(packageName = "com.mysdk", simpleName = "MySdkCallback"),
+                    methods = listOf(
+                        Method(
+                            name = "onComplete",
+                            parameters = listOf(
+                                Parameter(
+                                    name = "result",
+                                    type = Types.int
+                                ),
+                            ),
+                            returnType = Types.unit,
+                            isSuspend = false,
+                        ),
+                    )
+                )
             )
         )
         assertThat(ModelValidator.validate(api).isSuccess).isTrue()
@@ -173,9 +198,10 @@
         assertThat(validationResult.isFailure).isTrue()
         assertThat(validationResult.errors).containsExactly(
             "Error in com.mysdk.MySdk.returnFoo: only primitives and data classes annotated with " +
-                "@PrivacySandboxValue are supported as parameter and return types.",
-            "Error in com.mysdk.MySdk.receiveFoo: only primitives and data classes annotated " +
-                "with @PrivacySandboxValue are supported as parameter and return types."
+                "@PrivacySandboxValue are supported as return types.",
+            "Error in com.mysdk.MySdk.receiveFoo: only primitives, data classes annotated with " +
+                "@PrivacySandboxValue and interfaces annotated with @PrivacySandboxCallback are " +
+                "supported as parameter types."
         )
     }
 
@@ -201,4 +227,70 @@
                 "@PrivacySandboxValue are supported as properties."
         )
     }
+
+    @Test
+    fun callbackWithNonFireAndForgetMethod_throws() {
+        val api = ParsedApi(
+            services = setOf(
+                AnnotatedInterface(type = Type(packageName = "com.mysdk", simpleName = "MySdk")),
+            ),
+            callbacks = setOf(
+                AnnotatedInterface(
+                    type = Type(packageName = "com.mysdk", simpleName = "MySdkCallback"),
+                    methods = listOf(
+                        Method(
+                            name = "suspendMethod",
+                            parameters = listOf(),
+                            returnType = Types.unit,
+                            isSuspend = true,
+                        ),
+                        Method(
+                            name = "methodWithReturnValue",
+                            parameters = listOf(),
+                            returnType = Types.int,
+                            isSuspend = false,
+                        ),
+                    )
+                )
+            )
+        )
+        val validationResult = ModelValidator.validate(api)
+        assertThat(validationResult.isFailure).isTrue()
+        assertThat(validationResult.errors).containsExactly(
+            "Error in com.mysdk.MySdkCallback.suspendMethod: callback methods should be " +
+                "non-suspending and have no return values.",
+            "Error in com.mysdk.MySdkCallback.methodWithReturnValue: callback methods should be " +
+                "non-suspending and have no return values.",
+        )
+    }
+
+    @Test
+    fun callbackReceivingCallbacks_throws() {
+        val api = ParsedApi(
+            services = setOf(
+                AnnotatedInterface(type = Type(packageName = "com.mysdk", simpleName = "MySdk")),
+            ),
+            callbacks = setOf(
+                AnnotatedInterface(
+                    type = Type(packageName = "com.mysdk", simpleName = "MySdkCallback"),
+                    methods = listOf(
+                        Method(
+                            name = "foo",
+                            parameters = listOf(
+                                Parameter("otherCallback", Type("com.mysdk", "MySdkCallback"))
+                            ),
+                            returnType = Types.unit,
+                            isSuspend = false,
+                        ),
+                    )
+                )
+            )
+        )
+        val validationResult = ModelValidator.validate(api)
+        assertThat(validationResult.isFailure).isTrue()
+        assertThat(validationResult.errors).containsExactly(
+            "Error in com.mysdk.MySdkCallback.foo: callback methods cannot receive other " +
+                "callbacks as arguments."
+        )
+    }
 }
\ No newline at end of file