Merge "Add benchmarks for WASM bidding logic" into tm-mainline-prod
diff --git a/apct-tests/perftests/rubidium/Android.bp b/apct-tests/perftests/rubidium/Android.bp
index 339ef30..ba2b442 100644
--- a/apct-tests/perftests/rubidium/Android.bp
+++ b/apct-tests/perftests/rubidium/Android.bp
@@ -33,6 +33,7 @@
         "compatibility-device-util-axt",
         "platform-test-annotations",
         "adservices-service-core",
+        "androidx.core_core",
     ],
     test_suites: ["device-tests"],
     data: [":perfetto_artifacts"],
diff --git a/apct-tests/perftests/rubidium/assets/generate_bid.wasm b/apct-tests/perftests/rubidium/assets/generate_bid.wasm
new file mode 100644
index 0000000..5e7fe9e
--- /dev/null
+++ b/apct-tests/perftests/rubidium/assets/generate_bid.wasm
Binary files differ
diff --git a/apct-tests/perftests/rubidium/assets/generate_bid_using_wasm.js b/apct-tests/perftests/rubidium/assets/generate_bid_using_wasm.js
new file mode 100644
index 0000000..bc50d0a
--- /dev/null
+++ b/apct-tests/perftests/rubidium/assets/generate_bid_using_wasm.js
@@ -0,0 +1,45 @@
+/**
+ * Computes a bid for `ad` from the scores of five models exported by `wasmModule`.
+ *
+ * @param ad the ad record; `ad.metadata.input` holds the numeric model inputs and
+ *     `ad.render_url` the URL that is echoed back as `render`.
+ * @param wasmModule a compiled WebAssembly.Module exporting `memory` and
+ *     `nn_forward_model0` ... `nn_forward_model4`.
+ * @return {{ad: string, bid: number, render: string}} the bid response.
+ */
+function generateBid(ad, wasmModule) {
+  const input = ad.metadata.input;
+
+  // A new instance (no import object) per call, so each call starts from its own memory.
+  const instance = new WebAssembly.Instance(wasmModule);
+
+  const memory = instance.exports.memory;
+  // NOTE(review): inputs are written at byte offset 0 of linear memory, which assumes
+  // the module keeps none of its own data there — confirm.
+  const inputInMemory = new Float32Array(memory.buffer, 0, 200);
+  // Writes to indices >= 200 are silently ignored by the typed array.
+  for (let i = 0; i < input.length; ++i) {
+    inputInMemory[i] = input[i];
+  }
+
+  // NOTE(review): the exports presumably take (length, i32 address). Pass the byte
+  // offset explicitly; passing the Float32Array itself only worked because it coerces
+  // to NaN and thus to address 0. Confirm against the .wasm source.
+  const length = inputInMemory.length;
+  const address = inputInMemory.byteOffset;
+  const results = [
+    instance.exports.nn_forward_model0(length, address),
+    instance.exports.nn_forward_model1(length, address),
+    instance.exports.nn_forward_model2(length, address),
+    instance.exports.nn_forward_model3(length, address),
+    instance.exports.nn_forward_model4(length, address),
+  ];
+  // Each score is clamped to at least 1 before the scores are multiplied together.
+  const bid = results.map(x => Math.max(x, 1)).reduce((x, y) => x * y);
+  return {
+    ad: 'example',
+    bid: bid,
+    // JSScriptEnginePerfTests supplies the URL under `render_url`, not `renderUrl`.
+    render: ad.render_url
+  };
+}
diff --git a/apct-tests/perftests/rubidium/src/android/rubidium/js/JSScriptEnginePerfTests.java b/apct-tests/perftests/rubidium/src/android/rubidium/js/JSScriptEnginePerfTests.java
index bf9ff3a..0ddec23 100644
--- a/apct-tests/perftests/rubidium/src/android/rubidium/js/JSScriptEnginePerfTests.java
+++ b/apct-tests/perftests/rubidium/src/android/rubidium/js/JSScriptEnginePerfTests.java
@@ -24,6 +24,9 @@
 
 import static com.google.common.truth.Truth.assertThat;
 
+import static org.junit.Assume.assumeTrue;
+
+import android.annotation.SuppressLint;
 import android.content.Context;
 import android.perftests.utils.BenchmarkState;
 import android.perftests.utils.PerfStatusReporter;
@@ -45,48 +48,46 @@
 import com.google.common.util.concurrent.ListenableFuture;
 
 import org.json.JSONArray;
-import org.junit.After;
 import org.junit.Before;
 import org.junit.Rule;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 
+import java.io.IOException;
 import java.io.InputStream;
 import java.nio.charset.StandardCharsets;
 import java.util.Collections;
 import java.util.List;
 import java.util.Objects;
+import java.util.Random;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
+import java.util.concurrent.TimeUnit;
 
 /** To run the unit tests for this class, run "atest RubidiumPerfTests:JSScriptEnginePerfTests" */
 @MediumTest
 @RunWith(AndroidJUnit4.class)
 public class JSScriptEnginePerfTests {
-    private static final String TAG = JSScriptEnginePerfTests.class.getSimpleName();
+    private static final String TAG = JSScriptEngine.TAG;
     private static final Context sContext = ApplicationProvider.getApplicationContext();
     private static final ExecutorService sExecutorService = Executors.newFixedThreadPool(10);
 
-    private static JSScriptEngine sJSScriptEngine;
+    // Created once when the class is initialized and shared by all tests in this class;
+    // NOTE(review): nothing visible here shuts it down — confirm that is intended.
+    private static final JSScriptEngine sJSScriptEngine =
+            JSScriptEngine.getInstanceForTesting(
+                    sContext, Profiler.createInstance(JSScriptEngine.TAG));
 
     @Rule public PerfStatusReporter mPerfStatusReporter = new PerfStatusReporter();
 
     @Before
     public void before() throws Exception {
-        Profiler profiler = Profiler.createInstance(JSScriptEngine.TAG);
-        sJSScriptEngine = JSScriptEngine.getInstanceForTesting(sContext, profiler);
-
         // Warm up the sandbox env.
         callJSEngine(
                 "function test() { return \"hello world\";" + " }", ImmutableList.of(), "test");
     }
 
-    @After
-    public void after() {
-        sJSScriptEngine.shutdown();
-    }
-
     @Test
     public void evaluate_helloWorld() throws Exception {
         BenchmarkState state = mPerfStatusReporter.getBenchmarkState();
@@ -156,6 +155,7 @@
         runParametrizedTurtledoveScript(75);
     }
 
+    @SuppressLint("DefaultLocale")
     private void runParametrizedTurtledoveScript(int numAds) throws Exception {
         BenchmarkState state = mPerfStatusReporter.getBenchmarkState();
         state.pauseTiming();
@@ -220,7 +220,44 @@
         return arrayArg("foo", Collections.nCopies(numCustomAudiences, interestGroupArg));
     }
 
+    /**
+     * Benchmarks {@code generateBid} from {@code generate_bid_using_wasm.js} evaluated with the
+     * WASM binary {@code generate_bid.wasm}.
+     *
+     * <p>Skipped (via {@code assumeTrue}) when the engine reports that WASM is not supported.
+     */
+    @Test
+    public void evaluate_turtledoveWasm() throws Exception {
+        assumeTrue(sJSScriptEngine.isWasmSupported().get(3, TimeUnit.SECONDS));
+
+        BenchmarkState state = mPerfStatusReporter.getBenchmarkState();
+        state.pauseTiming();
+
+        String jsTestFile = readAsset("generate_bid_using_wasm.js");
+        byte[] wasmTestFile = readBinaryAsset("generate_bid.wasm");
+        // generate_bid_using_wasm.js copies at most 200 inputs into WASM memory.
+        JSScriptArgument[] inputBytes = new JSScriptArgument[200];
+        Random rand = new Random();
+        for (int i = 0; i < inputBytes.length; i++) {
+            // Uniform over the full byte range [Byte.MIN_VALUE, Byte.MAX_VALUE].
+            byte value =
+                    (byte) (rand.nextInt(Byte.MAX_VALUE - Byte.MIN_VALUE + 1) + Byte.MIN_VALUE);
+            inputBytes[i] = JSScriptArgument.numericArg("_", value);
+        }
+        JSScriptArgument adDataArgument =
+                recordArg(
+                        "ad",
+                        stringArg("render_url", "http://google.com"),
+                        recordArg("metadata", JSScriptArgument.arrayArg("input", inputBytes)));
+        List<JSScriptArgument> args = ImmutableList.of(adDataArgument);
+
+        state.resumeTiming();
+        while (state.keepRunning()) {
+            callJSEngine(jsTestFile, wasmTestFile, args, "generateBid");
+        }
+    }
+
     private static String callJSEngine(
             @NonNull String jsScript,
             @NonNull List<JSScriptArgument> args,
             @NonNull String functionName)
@@ -228,6 +255,19 @@
         return callJSEngine(sJSScriptEngine, jsScript, args, functionName);
     }
 
+    /**
+     * Runs {@code functionName} from {@code jsScript} with the WASM binary {@code wasmScript} on
+     * the shared {@code sJSScriptEngine} and blocks until the result (or failure) is available.
+     */
+    private static String callJSEngine(
+            @NonNull String jsScript,
+            @NonNull byte[] wasmScript,
+            @NonNull List<JSScriptArgument> args,
+            @NonNull String functionName)
+            throws Exception {
+        return callJSEngine(sJSScriptEngine, jsScript, wasmScript, args, functionName);
+    }
+
     private static String callJSEngine(
             @NonNull JSScriptEngine jsScriptEngine,
             @NonNull String jsScript,
@@ -241,6 +281,25 @@
         return futureResult.get();
     }
 
+    /**
+     * Evaluates {@code functionName} from {@code jsScript} with the WASM binary {@code wasmScript}
+     * on {@code jsScriptEngine} and blocks until the result (or failure) is available.
+     */
+    private static String callJSEngine(
+            @NonNull JSScriptEngine jsScriptEngine,
+            @NonNull String jsScript,
+            @NonNull byte[] wasmScript,
+            @NonNull List<JSScriptArgument> args,
+            @NonNull String functionName)
+            throws Exception {
+        CountDownLatch resultLatch = new CountDownLatch(1);
+        ListenableFuture<String> futureResult =
+                callJSEngineAsync(
+                        jsScriptEngine, jsScript, wasmScript, args, functionName, resultLatch);
+        resultLatch.await();
+        return futureResult.get();
+    }
+
     private static ListenableFuture<String> callJSEngineAsync(
             @NonNull String jsScript,
             @NonNull List<JSScriptArgument> args,
@@ -261,4 +320,35 @@
         result.addListener(resultLatch::countDown, sExecutorService);
         return result;
     }
+
+    /**
+     * Starts evaluating {@code functionName} from {@code jsScript} with the WASM binary
+     * {@code wasmScript} on {@code engine}, counting down {@code resultLatch} on completion.
+     */
+    private static ListenableFuture<String> callJSEngineAsync(
+            @NonNull JSScriptEngine engine,
+            @NonNull String jsScript,
+            @NonNull byte[] wasmScript,
+            @NonNull List<JSScriptArgument> args,
+            @NonNull String functionName,
+            @NonNull CountDownLatch resultLatch) {
+        Objects.requireNonNull(engine);
+        Objects.requireNonNull(resultLatch);
+        ListenableFuture<String> result = engine.evaluate(jsScript, wasmScript, args, functionName);
+        // The listener runs on success and on failure alike, so awaiting callers are released.
+        result.addListener(resultLatch::countDown, sExecutorService);
+        return result;
+    }
+
+    /** Reads the whole asset {@code assetName} from {@code sContext}, closing the stream. */
+    private static byte[] readBinaryAsset(@NonNull String assetName) throws IOException {
+        try (InputStream in = sContext.getAssets().open(assetName)) {
+            return in.readAllBytes();
+        }
+    }
+
+    /** Reads the asset {@code assetName} and decodes it as UTF-8 text. */
+    private static String readAsset(@NonNull String assetName) throws IOException {
+        return new String(readBinaryAsset(assetName), StandardCharsets.UTF_8);
+    }
 }