Synchronize pointer display change requests

Previously, when InputManagerService requests for PointerController to
change the pointer display, there was no way to know when the request
was completed or whether it succeeded. This could lead to a few issues:

- WM's MousePositionTracker's coordinates would not be updated until the
  next mouse event was generated, meaning the position would be out of
  sync.
- The creation of a virtual mouse device moves the pointer to a specific
  displayId. In order to test this behavior, we would need to sleep in
  the test code to wait for the system to update the pointer display and
  position, resulting in generally flaky tests.

Here, we add a way to synchonize changes to the pointer display so that
InputMangerService can know the current pointer display with certainty.
PointerController, which is updated in the InputReader thread, is the
source of truth of the pointer display. We add a policy call to notify
IMS when the pointer display changes.

When the pointer display is changed, the cursor position on the updated
display is also updated so that the VirtualMouse#getCursorPosition() API
is synchronized to the pointer display change.

Bug: 216792538
Test: atest FrameworksServicesTests:InputManagerServiceTests
Test: atest PointerIconTest
Change-Id: I578fd1aba9335e2e078d749321e55a6d05299f3b
diff --git a/core/java/android/hardware/input/InputManagerInternal.java b/core/java/android/hardware/input/InputManagerInternal.java
index b37c27c..fc6bc55 100644
--- a/core/java/android/hardware/input/InputManagerInternal.java
+++ b/core/java/android/hardware/input/InputManagerInternal.java
@@ -75,8 +75,15 @@
     /**
      * Sets the display id that the MouseCursorController will be forced to target. Pass
      * {@link android.view.Display#INVALID_DISPLAY} to clear the override.
+     *
+     * Note: This method generally blocks until the pointer display override has propagated.
+     * When setting a new override, the caller should ensure that an input device that can control
+     * the mouse pointer is connected. If a new override is set when no such input device is
+     * connected, the caller may be blocked for an arbitrary period of time.
+     *
+     * @return true if the pointer displayId was set successfully, or false if it fails.
      */
-    public abstract void setVirtualMousePointerDisplayId(int pointerDisplayId);
+    public abstract boolean setVirtualMousePointerDisplayId(int pointerDisplayId);
 
     /**
      * Gets the display id that the MouseCursorController is being forced to target. Returns
diff --git a/libs/input/PointerController.cpp b/libs/input/PointerController.cpp
index 1dc74e5..10ea651 100644
--- a/libs/input/PointerController.cpp
+++ b/libs/input/PointerController.cpp
@@ -106,6 +106,7 @@
 PointerController::~PointerController() {
     mDisplayInfoListener->onPointerControllerDestroyed();
     mUnregisterWindowInfosListener(mDisplayInfoListener);
+    mContext.getPolicy()->onPointerDisplayIdChanged(ADISPLAY_ID_NONE, 0, 0);
 }
 
 std::mutex& PointerController::getLock() const {
@@ -255,6 +256,12 @@
         getAdditionalMouseResources = true;
     }
     mCursorController.setDisplayViewport(viewport, getAdditionalMouseResources);
+    if (viewport.displayId != mLocked.pointerDisplayId) {
+        float xPos, yPos;
+        mCursorController.getPosition(&xPos, &yPos);
+        mContext.getPolicy()->onPointerDisplayIdChanged(viewport.displayId, xPos, yPos);
+        mLocked.pointerDisplayId = viewport.displayId;
+    }
 }
 
 void PointerController::updatePointerIcon(int32_t iconId) {
diff --git a/libs/input/PointerController.h b/libs/input/PointerController.h
index 2e6e851..eab030f 100644
--- a/libs/input/PointerController.h
+++ b/libs/input/PointerController.h
@@ -104,6 +104,7 @@
 
     struct Locked {
         Presentation presentation;
+        int32_t pointerDisplayId = ADISPLAY_ID_NONE;
 
         std::vector<gui::DisplayInfo> mDisplayInfos;
         std::unordered_map<int32_t /* displayId */, TouchSpotController> spotControllers;
diff --git a/libs/input/PointerControllerContext.h b/libs/input/PointerControllerContext.h
index 26a65a4..c2bc1e0 100644
--- a/libs/input/PointerControllerContext.h
+++ b/libs/input/PointerControllerContext.h
@@ -79,6 +79,7 @@
             std::map<int32_t, PointerAnimation>* outAnimationResources, int32_t displayId) = 0;
     virtual int32_t getDefaultPointerIconId() = 0;
     virtual int32_t getCustomPointerIconId() = 0;
+    virtual void onPointerDisplayIdChanged(int32_t displayId, float xPos, float yPos) = 0;
 };
 
 /*
diff --git a/libs/input/tests/PointerController_test.cpp b/libs/input/tests/PointerController_test.cpp
index dae1fcc..f9752ed 100644
--- a/libs/input/tests/PointerController_test.cpp
+++ b/libs/input/tests/PointerController_test.cpp
@@ -56,9 +56,11 @@
             std::map<int32_t, PointerAnimation>* outAnimationResources, int32_t displayId) override;
     virtual int32_t getDefaultPointerIconId() override;
     virtual int32_t getCustomPointerIconId() override;
+    virtual void onPointerDisplayIdChanged(int32_t displayId, float xPos, float yPos) override;
 
     bool allResourcesAreLoaded();
     bool noResourcesAreLoaded();
+    std::optional<int32_t> getLastReportedPointerDisplayId() { return latestPointerDisplayId; }
 
 private:
     void loadPointerIconForType(SpriteIcon* icon, int32_t cursorType);
@@ -66,6 +68,7 @@
     bool pointerIconLoaded{false};
     bool pointerResourcesLoaded{false};
     bool additionalMouseResourcesLoaded{false};
+    std::optional<int32_t /*displayId*/> latestPointerDisplayId;
 };
 
 void MockPointerControllerPolicyInterface::loadPointerIcon(SpriteIcon* icon, int32_t) {
@@ -126,12 +129,19 @@
     icon->hotSpotX = hotSpot.first;
     icon->hotSpotY = hotSpot.second;
 }
+
+void MockPointerControllerPolicyInterface::onPointerDisplayIdChanged(int32_t displayId,
+                                                                     float /*xPos*/,
+                                                                     float /*yPos*/) {
+    latestPointerDisplayId = displayId;
+}
+
 class PointerControllerTest : public Test {
 protected:
     PointerControllerTest();
     ~PointerControllerTest();
 
-    void ensureDisplayViewportIsSet();
+    void ensureDisplayViewportIsSet(int32_t displayId = ADISPLAY_ID_DEFAULT);
 
     sp<MockSprite> mPointerSprite;
     sp<MockPointerControllerPolicyInterface> mPolicy;
@@ -168,9 +178,9 @@
     mThread.join();
 }
 
-void PointerControllerTest::ensureDisplayViewportIsSet() {
+void PointerControllerTest::ensureDisplayViewportIsSet(int32_t displayId) {
     DisplayViewport viewport;
-    viewport.displayId = ADISPLAY_ID_DEFAULT;
+    viewport.displayId = displayId;
     viewport.logicalRight = 1600;
     viewport.logicalBottom = 1200;
     viewport.physicalRight = 800;
@@ -255,6 +265,30 @@
     ensureDisplayViewportIsSet();
 }
 
+TEST_F(PointerControllerTest, notifiesPolicyWhenPointerDisplayChanges) {
+    EXPECT_FALSE(mPolicy->getLastReportedPointerDisplayId())
+            << "A pointer display change does not occur when PointerController is created.";
+
+    ensureDisplayViewportIsSet(ADISPLAY_ID_DEFAULT);
+
+    const auto lastReportedPointerDisplayId = mPolicy->getLastReportedPointerDisplayId();
+    ASSERT_TRUE(lastReportedPointerDisplayId)
+            << "The policy is notified of a pointer display change when the viewport is first set.";
+    EXPECT_EQ(ADISPLAY_ID_DEFAULT, *lastReportedPointerDisplayId)
+            << "Incorrect pointer display notified.";
+
+    ensureDisplayViewportIsSet(42);
+
+    EXPECT_EQ(42, *mPolicy->getLastReportedPointerDisplayId())
+            << "The policy is notified when the pointer display changes.";
+
+    // Release the PointerController.
+    mPointerController = nullptr;
+
+    EXPECT_EQ(ADISPLAY_ID_NONE, *mPolicy->getLastReportedPointerDisplayId())
+            << "The pointer display changes to invalid when PointerController is destroyed.";
+}
+
 class PointerControllerWindowInfoListenerTest : public Test {};
 
 class TestPointerController : public PointerController {
diff --git a/services/core/java/com/android/server/input/InputManagerService.java b/services/core/java/com/android/server/input/InputManagerService.java
index e433324..5bee308 100644
--- a/services/core/java/com/android/server/input/InputManagerService.java
+++ b/services/core/java/com/android/server/input/InputManagerService.java
@@ -167,6 +167,7 @@
     private static final int MSG_UPDATE_KEYBOARD_LAYOUTS = 4;
     private static final int MSG_RELOAD_DEVICE_ALIASES = 5;
     private static final int MSG_DELIVER_TABLET_MODE_CHANGED = 6;
+    private static final int MSG_POINTER_DISPLAY_ID_CHANGED = 7;
 
     private static final int DEFAULT_VIBRATION_MAGNITUDE = 192;
 
@@ -279,11 +280,24 @@
     @GuardedBy("mAssociationLock")
     private final Map<String, String> mUniqueIdAssociations = new ArrayMap<>();
 
+    // Guards per-display input properties and properties relating to the mouse pointer.
+    // Threads can wait on this lock to be notified the next time the display on which the mouse
+    // pointer is shown has changed.
     private final Object mAdditionalDisplayInputPropertiesLock = new Object();
 
-    // Forces the MouseCursorController to target a specific display id.
+    // Forces the PointerController to target a specific display id.
     @GuardedBy("mAdditionalDisplayInputPropertiesLock")
     private int mOverriddenPointerDisplayId = Display.INVALID_DISPLAY;
+
+    // PointerController is the source of truth of the pointer display. This is the value of the
+    // latest pointer display id reported by PointerController.
+    @GuardedBy("mAdditionalDisplayInputPropertiesLock")
+    private int mAcknowledgedPointerDisplayId = Display.INVALID_DISPLAY;
+    // This is the latest display id that IMS has requested PointerController to use. If there are
+    // no devices that can control the pointer, PointerController may end up disregarding this
+    // value.
+    @GuardedBy("mAdditionalDisplayInputPropertiesLock")
+    private int mRequestedPointerDisplayId = Display.INVALID_DISPLAY;
     @GuardedBy("mAdditionalDisplayInputPropertiesLock")
     private final SparseArray<AdditionalDisplayInputProperties> mAdditionalDisplayInputProperties =
             new SparseArray<>();
@@ -292,7 +306,6 @@
     @GuardedBy("mAdditionalDisplayInputPropertiesLock")
     private PointerIcon mIcon;
 
-
     // Holds all the registered gesture monitors that are implemented as spy windows. The spy
     // windows are mapped by their InputChannel tokens.
     @GuardedBy("mInputMonitors")
@@ -388,6 +401,10 @@
         NativeInputManagerService getNativeService(InputManagerService service) {
             return new NativeInputManagerService.NativeImpl(service, mContext, mLooper.getQueue());
         }
+
+        void registerLocalService(InputManagerInternal localService) {
+            LocalServices.addService(InputManagerInternal.class, localService);
+        }
     }
 
     public InputManagerService(Context context) {
@@ -413,7 +430,8 @@
 
         mVelocityTrackerStrategy = DeviceConfig.getProperty(
                 NAMESPACE_INPUT_NATIVE_BOOT, VELOCITYTRACKER_STRATEGY_PROPERTY);
-        LocalServices.addService(InputManagerInternal.class, new LocalService());
+
+        injector.registerLocalService(new LocalService());
     }
 
     public void setWindowManagerCallbacks(WindowManagerCallbacks callbacks) {
@@ -563,6 +581,8 @@
                 vArray[i] = viewports.get(i);
             }
             mNative.setDisplayViewports(vArray);
+            // Always attempt to update the pointer display when viewports change.
+            updatePointerDisplayId();
 
             if (mOverriddenPointerDisplayId != Display.INVALID_DISPLAY) {
                 final AdditionalDisplayInputProperties properties =
@@ -1973,10 +1993,43 @@
         return result;
     }
 
-    private void setVirtualMousePointerDisplayId(int displayId) {
+    /**
+     * Update the display on which the mouse pointer is shown.
+     * If there is an overridden display for the mouse pointer, use that. Otherwise, query
+     * WindowManager for the pointer display.
+     *
+     * @return true if the pointer displayId changed, false otherwise.
+     */
+    private boolean updatePointerDisplayId() {
+        synchronized (mAdditionalDisplayInputPropertiesLock) {
+            final int pointerDisplayId = mOverriddenPointerDisplayId != Display.INVALID_DISPLAY
+                    ? mOverriddenPointerDisplayId : mWindowManagerCallbacks.getPointerDisplayId();
+            if (mRequestedPointerDisplayId == pointerDisplayId) {
+                return false;
+            }
+            mRequestedPointerDisplayId = pointerDisplayId;
+            mNative.setPointerDisplayId(pointerDisplayId);
+            return true;
+        }
+    }
+
+    private void handlePointerDisplayIdChanged(PointerDisplayIdChangedArgs args) {
+        synchronized (mAdditionalDisplayInputPropertiesLock) {
+            mAcknowledgedPointerDisplayId = args.mPointerDisplayId;
+            // Notify waiting threads that the display of the mouse pointer has changed.
+            mAdditionalDisplayInputPropertiesLock.notifyAll();
+        }
+        mWindowManagerCallbacks.notifyPointerDisplayIdChanged(
+                args.mPointerDisplayId, args.mXPosition, args.mYPosition);
+    }
+
+    private boolean setVirtualMousePointerDisplayIdBlocking(int displayId) {
+        // Indicates whether this request is for removing the override.
+        final boolean removingOverride = displayId == Display.INVALID_DISPLAY;
+
         synchronized (mAdditionalDisplayInputPropertiesLock) {
             mOverriddenPointerDisplayId = displayId;
-            if (displayId != Display.INVALID_DISPLAY) {
+            if (!removingOverride) {
                 final AdditionalDisplayInputProperties properties =
                         mAdditionalDisplayInputProperties.get(displayId);
                 if (properties != null) {
@@ -1984,9 +2037,30 @@
                     updatePointerIconVisibleLocked(properties.pointerIconVisible);
                 }
             }
+            if (!updatePointerDisplayId() && mAcknowledgedPointerDisplayId == displayId) {
+                // The requested pointer display is already set.
+                return true;
+            }
+            if (removingOverride && mAcknowledgedPointerDisplayId == Display.INVALID_DISPLAY) {
+                // The pointer display override is being removed, but the current pointer display
+                // is already invalid. This can happen when the PointerController is destroyed as a
+                // result of the removal of all input devices that can control the pointer.
+                return true;
+            }
+            try {
+                // The pointer display changed, so wait until the change has propagated.
+                mAdditionalDisplayInputPropertiesLock.wait(5_000 /*mills*/);
+            } catch (InterruptedException ignored) {
+            }
+            // This request succeeds in two cases:
+            // - This request was to remove the override, in which case the new pointer display
+            //   could be anything that WM has set.
+            // - We are setting a new override, in which case the request only succeeds if the
+            //   reported new displayId is the one we requested. This check ensures that if two
+            //   competing overrides are requested in succession, the caller can be notified if one
+            //   of them fails.
+            return  removingOverride || mAcknowledgedPointerDisplayId == displayId;
         }
-        // TODO(b/215597605): trigger MousePositionTracker update
-        mNative.notifyPointerDisplayIdChanged();
     }
 
     private int getVirtualMousePointerDisplayId() {
@@ -3168,18 +3242,6 @@
 
     // Native callback.
     @SuppressWarnings("unused")
-    private int getPointerDisplayId() {
-        synchronized (mAdditionalDisplayInputPropertiesLock) {
-            // Prefer the override to all other displays.
-            if (mOverriddenPointerDisplayId != Display.INVALID_DISPLAY) {
-                return mOverriddenPointerDisplayId;
-            }
-        }
-        return mWindowManagerCallbacks.getPointerDisplayId();
-    }
-
-    // Native callback.
-    @SuppressWarnings("unused")
     private String[] getKeyboardLayoutOverlay(InputDeviceIdentifier identifier) {
         if (!mSystemReady) {
             return null;
@@ -3218,6 +3280,26 @@
         return null;
     }
 
+    private static class PointerDisplayIdChangedArgs {
+        final int mPointerDisplayId;
+        final float mXPosition;
+        final float mYPosition;
+        PointerDisplayIdChangedArgs(int pointerDisplayId, float xPosition, float yPosition) {
+            mPointerDisplayId = pointerDisplayId;
+            mXPosition = xPosition;
+            mYPosition = yPosition;
+        }
+    }
+
+    // Native callback.
+    @SuppressWarnings("unused")
+    @VisibleForTesting
+    void onPointerDisplayIdChanged(int pointerDisplayId, float xPosition, float yPosition) {
+        mHandler.obtainMessage(MSG_POINTER_DISPLAY_ID_CHANGED,
+                new PointerDisplayIdChangedArgs(pointerDisplayId, xPosition,
+                        yPosition)).sendToTarget();
+    }
+
     /**
      * Callback interface implemented by the Window Manager.
      */
@@ -3341,6 +3423,14 @@
          */
         @Nullable
         SurfaceControl createSurfaceForGestureMonitor(String name, int displayId);
+
+        /**
+         * Notify WindowManagerService when the display of the mouse pointer changes.
+         * @param displayId The display on which the mouse pointer is shown.
+         * @param x The x coordinate of the mouse pointer.
+         * @param y The y coordinate of the mouse pointer.
+         */
+        void notifyPointerDisplayIdChanged(int displayId, float x, float y);
     }
 
     /**
@@ -3393,6 +3483,9 @@
                     boolean inTabletMode = (boolean) args.arg1;
                     deliverTabletModeChanged(whenNanos, inTabletMode);
                     break;
+                case MSG_POINTER_DISPLAY_ID_CHANGED:
+                    handlePointerDisplayIdChanged((PointerDisplayIdChangedArgs) msg.obj);
+                    break;
             }
         }
     }
@@ -3643,8 +3736,9 @@
         }
 
         @Override
-        public void setVirtualMousePointerDisplayId(int pointerDisplayId) {
-            InputManagerService.this.setVirtualMousePointerDisplayId(pointerDisplayId);
+        public boolean setVirtualMousePointerDisplayId(int pointerDisplayId) {
+            return InputManagerService.this
+                    .setVirtualMousePointerDisplayIdBlocking(pointerDisplayId);
         }
 
         @Override
diff --git a/services/core/java/com/android/server/input/NativeInputManagerService.java b/services/core/java/com/android/server/input/NativeInputManagerService.java
index 2169155..81882d2 100644
--- a/services/core/java/com/android/server/input/NativeInputManagerService.java
+++ b/services/core/java/com/android/server/input/NativeInputManagerService.java
@@ -176,6 +176,9 @@
 
     void cancelCurrentTouch();
 
+    /** Set the displayId on which the mouse cursor should be shown. */
+    void setPointerDisplayId(int displayId);
+
     /** The native implementation of InputManagerService methods. */
     class NativeImpl implements NativeInputManagerService {
         /** Pointer to native input manager service object, used by native code. */
@@ -388,5 +391,8 @@
 
         @Override
         public native void cancelCurrentTouch();
+
+        @Override
+        public native void setPointerDisplayId(int displayId);
     }
 }
diff --git a/services/core/java/com/android/server/wm/InputManagerCallback.java b/services/core/java/com/android/server/wm/InputManagerCallback.java
index 67dd89e..33cdd2e 100644
--- a/services/core/java/com/android/server/wm/InputManagerCallback.java
+++ b/services/core/java/com/android/server/wm/InputManagerCallback.java
@@ -270,6 +270,22 @@
         }
     }
 
+    @Override
+    public void notifyPointerDisplayIdChanged(int displayId, float x, float y) {
+        synchronized (mService.mGlobalLock) {
+            mService.setMousePointerDisplayId(displayId);
+            if (displayId == Display.INVALID_DISPLAY) return;
+
+            final DisplayContent dc = mService.mRoot.getDisplayContent(displayId);
+            if (dc == null) {
+                Slog.wtf(TAG, "The mouse pointer was moved to display " + displayId
+                        + " that does not have a valid DisplayContent.");
+                return;
+            }
+            mService.restorePointerIconLocked(dc, x, y);
+        }
+    }
+
     /** Waits until the built-in input devices have been configured. */
     public boolean waitForInputDevicesReady(long timeoutMillis) {
         synchronized (mInputDevicesReadyMonitor) {
diff --git a/services/core/java/com/android/server/wm/WindowManagerService.java b/services/core/java/com/android/server/wm/WindowManagerService.java
index 3bc6dbd..bbf29c4 100644
--- a/services/core/java/com/android/server/wm/WindowManagerService.java
+++ b/services/core/java/com/android/server/wm/WindowManagerService.java
@@ -7163,18 +7163,42 @@
         private float mLatestMouseX;
         private float mLatestMouseY;
 
-        void updatePosition(float x, float y) {
+        /**
+         * The display that the pointer (mouse cursor) is currently shown on. This is updated
+         * directly by InputManagerService when the pointer display changes.
+         */
+        private int mPointerDisplayId = INVALID_DISPLAY;
+
+        /**
+         * Update the mouse cursor position as a result of a mouse movement.
+         * @return true if the position was successfully updated, false otherwise.
+         */
+        boolean updatePosition(int displayId, float x, float y) {
             synchronized (this) {
                 mLatestEventWasMouse = true;
+
+                if (displayId != mPointerDisplayId) {
+                    // The display of the position update does not match the display on which the
+                    // mouse pointer is shown, so do not update the position.
+                    return false;
+                }
                 mLatestMouseX = x;
                 mLatestMouseY = y;
+                return true;
+            }
+        }
+
+        void setPointerDisplayId(int displayId) {
+            synchronized (this) {
+                mPointerDisplayId = displayId;
             }
         }
 
         @Override
         public void onPointerEvent(MotionEvent motionEvent) {
             if (motionEvent.isFromSource(InputDevice.SOURCE_MOUSE)) {
-                updatePosition(motionEvent.getRawX(), motionEvent.getRawY());
+                updatePosition(motionEvent.getDisplayId(), motionEvent.getRawX(),
+                        motionEvent.getRawY());
             } else {
                 synchronized (this) {
                     mLatestEventWasMouse = false;
@@ -7184,6 +7208,7 @@
     };
 
     void updatePointerIcon(IWindow client) {
+        int pointerDisplayId;
         float mouseX, mouseY;
 
         synchronized(mMousePositionTracker) {
@@ -7192,6 +7217,7 @@
             }
             mouseX = mMousePositionTracker.mLatestMouseX;
             mouseY = mMousePositionTracker.mLatestMouseY;
+            pointerDisplayId = mMousePositionTracker.mPointerDisplayId;
         }
 
         synchronized (mGlobalLock) {
@@ -7208,6 +7234,10 @@
             if (displayContent == null) {
                 return;
             }
+            if (pointerDisplayId != displayContent.getDisplayId()) {
+                // Do not let the pointer icon be updated by a window on a different display.
+                return;
+            }
             WindowState windowUnderPointer =
                     displayContent.getTouchableWinAtPointLocked(mouseX, mouseY);
             if (windowUnderPointer != callingWin) {
@@ -7225,7 +7255,11 @@
 
     void restorePointerIconLocked(DisplayContent displayContent, float latestX, float latestY) {
         // Mouse position tracker has not been getting updates while dragging, update it now.
-        mMousePositionTracker.updatePosition(latestX, latestY);
+        if (!mMousePositionTracker.updatePosition(
+                displayContent.getDisplayId(), latestX, latestY)) {
+            // The mouse position could not be updated, so ignore this request.
+            return;
+        }
 
         WindowState windowUnderPointer =
                 displayContent.getTouchableWinAtPointLocked(latestX, latestY);
@@ -7249,6 +7283,10 @@
         }
     }
 
+    void setMousePointerDisplayId(int displayId) {
+        mMousePositionTracker.setPointerDisplayId(displayId);
+    }
+
     /**
      * Update a tap exclude region in the window identified by the provided id. Touches down on this
      * region will not:
diff --git a/services/core/jni/com_android_server_input_InputManagerService.cpp b/services/core/jni/com_android_server_input_InputManagerService.cpp
index ffda8be..b303448 100644
--- a/services/core/jni/com_android_server_input_InputManagerService.cpp
+++ b/services/core/jni/com_android_server_input_InputManagerService.cpp
@@ -115,6 +115,7 @@
     jmethodID interceptKeyBeforeDispatching;
     jmethodID dispatchUnhandledKey;
     jmethodID checkInjectEventsPermission;
+    jmethodID onPointerDisplayIdChanged;
     jmethodID onPointerDownOutsideFocus;
     jmethodID getVirtualKeyQuietTimeMillis;
     jmethodID getExcludedDeviceNames;
@@ -128,7 +129,6 @@
     jmethodID getLongPressTimeout;
     jmethodID getPointerLayer;
     jmethodID getPointerIcon;
-    jmethodID getPointerDisplayId;
     jmethodID getKeyboardLayoutOverlay;
     jmethodID getDeviceAlias;
     jmethodID getTouchCalibrationForInputDevice;
@@ -285,6 +285,7 @@
     void setFocusedDisplay(int32_t displayId);
     void setInputDispatchMode(bool enabled, bool frozen);
     void setSystemUiLightsOut(bool lightsOut);
+    void setPointerDisplayId(int32_t displayId);
     void setPointerSpeed(int32_t speed);
     void setPointerAcceleration(float acceleration);
     void setInputDeviceEnabled(uint32_t deviceId, bool enabled);
@@ -296,7 +297,6 @@
     void requestPointerCapture(const sp<IBinder>& windowToken, bool enabled);
     void setCustomPointerIcon(const SpriteIcon& icon);
     void setMotionClassifierEnabled(bool enabled);
-    void notifyPointerDisplayIdChanged();
 
     /* --- InputReaderPolicyInterface implementation --- */
 
@@ -354,6 +354,7 @@
             std::map<int32_t, PointerAnimation>* outAnimationResources, int32_t displayId);
     virtual int32_t getDefaultPointerIconId();
     virtual int32_t getCustomPointerIconId();
+    virtual void onPointerDisplayIdChanged(int32_t displayId, float xPos, float yPos);
 
 private:
     sp<InputManagerInterface> mInputManager;
@@ -402,7 +403,6 @@
     void updateInactivityTimeoutLocked();
     void handleInterceptActions(jint wmActions, nsecs_t when, uint32_t& policyFlags);
     void ensureSpriteControllerLocked();
-    int32_t getPointerDisplayId();
     sp<SurfaceControl> getParentSurfaceForPointers(int displayId);
     static bool checkAndClearExceptionFromCallback(JNIEnv* env, const char* methodName);
 
@@ -506,13 +506,9 @@
         }
     }
 
-    // Get the preferred pointer controller displayId.
-    int32_t pointerDisplayId = getPointerDisplayId();
-
     { // acquire lock
         AutoMutex _l(mLock);
         mLocked.viewports = viewports;
-        mLocked.pointerDisplayId = pointerDisplayId;
         std::shared_ptr<PointerController> controller = mLocked.pointerController.lock();
         if (controller != nullptr) {
             controller->onDisplayViewportsUpdated(mLocked.viewports);
@@ -674,15 +670,12 @@
     return controller;
 }
 
-int32_t NativeInputManager::getPointerDisplayId() {
+void NativeInputManager::onPointerDisplayIdChanged(int32_t pointerDisplayId, float xPos,
+                                                   float yPos) {
     JNIEnv* env = jniEnv();
-    jint pointerDisplayId = env->CallIntMethod(mServiceObj,
-            gServiceClassInfo.getPointerDisplayId);
-    if (checkAndClearExceptionFromCallback(env, "getPointerDisplayId")) {
-        pointerDisplayId = ADISPLAY_ID_DEFAULT;
-    }
-
-    return pointerDisplayId;
+    env->CallVoidMethod(mServiceObj, gServiceClassInfo.onPointerDisplayIdChanged, pointerDisplayId,
+                        xPos, yPos);
+    checkAndClearExceptionFromCallback(env, "onPointerDisplayIdChanged");
 }
 
 sp<SurfaceControl> NativeInputManager::getParentSurfaceForPointers(int displayId) {
@@ -1040,6 +1033,22 @@
                                                                : InactivityTimeout::NORMAL);
 }
 
+void NativeInputManager::setPointerDisplayId(int32_t displayId) {
+    { // acquire lock
+        AutoMutex _l(mLock);
+
+        if (mLocked.pointerDisplayId == displayId) {
+            return;
+        }
+
+        ALOGI("Setting pointer display id to %d.", displayId);
+        mLocked.pointerDisplayId = displayId;
+    } // release lock
+
+    mInputManager->getReader().requestRefreshConfiguration(
+            InputReaderConfiguration::CHANGE_DISPLAY_INFO);
+}
+
 void NativeInputManager::setPointerSpeed(int32_t speed) {
     { // acquire lock
         AutoMutex _l(mLock);
@@ -1502,18 +1511,6 @@
     mInputManager->getClassifier().setMotionClassifierEnabled(enabled);
 }
 
-void NativeInputManager::notifyPointerDisplayIdChanged() {
-    int32_t pointerDisplayId = getPointerDisplayId();
-
-    { // acquire lock
-        AutoMutex _l(mLock);
-        mLocked.pointerDisplayId = pointerDisplayId;
-    } // release lock
-
-    mInputManager->getReader().requestRefreshConfiguration(
-            InputReaderConfiguration::CHANGE_DISPLAY_INFO);
-}
-
 // ----------------------------------------------------------------------------
 
 static NativeInputManager* getNativeInputManager(JNIEnv* env, jobject clazz) {
@@ -2209,11 +2206,6 @@
             InputReaderConfiguration::CHANGE_DISPLAY_INFO);
 }
 
-static void nativeNotifyPointerDisplayIdChanged(JNIEnv* env, jobject nativeImplObj) {
-    NativeInputManager* im = getNativeInputManager(env, nativeImplObj);
-    im->notifyPointerDisplayIdChanged();
-}
-
 static void nativeSetDisplayEligibilityForPointerCapture(JNIEnv* env, jobject nativeImplObj,
                                                          jint displayId, jboolean isEligible) {
     NativeInputManager* im = getNativeInputManager(env, nativeImplObj);
@@ -2331,6 +2323,11 @@
     im->getInputManager()->getDispatcher().cancelCurrentTouch();
 }
 
+static void nativeSetPointerDisplayId(JNIEnv* env, jobject nativeImplObj, jint displayId) {
+    NativeInputManager* im = getNativeInputManager(env, nativeImplObj);
+    im->setPointerDisplayId(displayId);
+}
+
 // ----------------------------------------------------------------------------
 
 static const JNINativeMethod gInputManagerMethods[] = {
@@ -2403,7 +2400,6 @@
         {"canDispatchToDisplay", "(II)Z", (void*)nativeCanDispatchToDisplay},
         {"notifyPortAssociationsChanged", "()V", (void*)nativeNotifyPortAssociationsChanged},
         {"changeUniqueIdAssociation", "()V", (void*)nativeChangeUniqueIdAssociation},
-        {"notifyPointerDisplayIdChanged", "()V", (void*)nativeNotifyPointerDisplayIdChanged},
         {"setDisplayEligibilityForPointerCapture", "(IZ)V",
          (void*)nativeSetDisplayEligibilityForPointerCapture},
         {"setMotionClassifierEnabled", "(Z)V", (void*)nativeSetMotionClassifierEnabled},
@@ -2413,6 +2409,7 @@
         {"disableSensor", "(II)V", (void*)nativeDisableSensor},
         {"flushSensor", "(II)Z", (void*)nativeFlushSensor},
         {"cancelCurrentTouch", "()V", (void*)nativeCancelCurrentTouch},
+        {"setPointerDisplayId", "(I)V", (void*)nativeSetPointerDisplayId},
 };
 
 #define FIND_CLASS(var, className) \
@@ -2508,6 +2505,9 @@
     GET_METHOD_ID(gServiceClassInfo.checkInjectEventsPermission, clazz,
                   "checkInjectEventsPermission", "(II)Z");
 
+    GET_METHOD_ID(gServiceClassInfo.onPointerDisplayIdChanged, clazz, "onPointerDisplayIdChanged",
+                  "(IFF)V");
+
     GET_METHOD_ID(gServiceClassInfo.onPointerDownOutsideFocus, clazz,
             "onPointerDownOutsideFocus", "(Landroid/os/IBinder;)V");
 
@@ -2547,9 +2547,6 @@
     GET_METHOD_ID(gServiceClassInfo.getPointerIcon, clazz,
             "getPointerIcon", "(I)Landroid/view/PointerIcon;");
 
-    GET_METHOD_ID(gServiceClassInfo.getPointerDisplayId, clazz,
-            "getPointerDisplayId", "()I");
-
     GET_METHOD_ID(gServiceClassInfo.getKeyboardLayoutOverlay, clazz,
             "getKeyboardLayoutOverlay",
             "(Landroid/hardware/input/InputDeviceIdentifier;)[Ljava/lang/String;");
diff --git a/services/tests/servicestests/src/com/android/server/companion/virtual/InputControllerTest.java b/services/tests/servicestests/src/com/android/server/companion/virtual/InputControllerTest.java
index 92e7a86..77cbb3a 100644
--- a/services/tests/servicestests/src/com/android/server/companion/virtual/InputControllerTest.java
+++ b/services/tests/servicestests/src/com/android/server/companion/virtual/InputControllerTest.java
@@ -19,7 +19,6 @@
 import static org.mockito.ArgumentMatchers.anyInt;
 import static org.mockito.ArgumentMatchers.anyString;
 import static org.mockito.ArgumentMatchers.eq;
-import static org.mockito.Mockito.doNothing;
 import static org.mockito.Mockito.doReturn;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
@@ -67,7 +66,7 @@
         mInputManagerMockHelper = new InputManagerMockHelper(
                 TestableLooper.get(this), mNativeWrapperMock, mIInputManagerMock);
 
-        doNothing().when(mInputManagerInternalMock).setVirtualMousePointerDisplayId(anyInt());
+        doReturn(true).when(mInputManagerInternalMock).setVirtualMousePointerDisplayId(anyInt());
         LocalServices.removeServiceForTest(InputManagerInternal.class);
         LocalServices.addService(InputManagerInternal.class, mInputManagerInternalMock);
 
diff --git a/services/tests/servicestests/src/com/android/server/companion/virtual/VirtualDeviceManagerServiceTest.java b/services/tests/servicestests/src/com/android/server/companion/virtual/VirtualDeviceManagerServiceTest.java
index 22152a1..cbb9fd7 100644
--- a/services/tests/servicestests/src/com/android/server/companion/virtual/VirtualDeviceManagerServiceTest.java
+++ b/services/tests/servicestests/src/com/android/server/companion/virtual/VirtualDeviceManagerServiceTest.java
@@ -180,7 +180,7 @@
         LocalServices.removeServiceForTest(DisplayManagerInternal.class);
         LocalServices.addService(DisplayManagerInternal.class, mDisplayManagerInternalMock);
 
-        doNothing().when(mInputManagerInternalMock).setVirtualMousePointerDisplayId(anyInt());
+        doReturn(true).when(mInputManagerInternalMock).setVirtualMousePointerDisplayId(anyInt());
         doNothing().when(mInputManagerInternalMock).setPointerAcceleration(anyFloat(), anyInt());
         doNothing().when(mInputManagerInternalMock).setPointerIconVisible(anyBoolean(), anyInt());
         LocalServices.removeServiceForTest(InputManagerInternal.class);
diff --git a/services/tests/servicestests/src/com/android/server/input/InputManagerServiceTests.kt b/services/tests/servicestests/src/com/android/server/input/InputManagerServiceTests.kt
new file mode 100644
index 0000000..cb97c9b
--- /dev/null
+++ b/services/tests/servicestests/src/com/android/server/input/InputManagerServiceTests.kt
@@ -0,0 +1,222 @@
+/*
+ * Copyright (C) 2022 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.server.input
+
+import android.content.Context
+import android.content.ContextWrapper
+import android.hardware.display.DisplayViewport
+import android.hardware.input.InputManagerInternal
+import android.os.test.TestLooper
+import android.platform.test.annotations.Presubmit
+import android.view.Display
+import androidx.test.InstrumentationRegistry
+import org.junit.Assert.assertFalse
+import org.junit.Assert.assertTrue
+import org.junit.Before
+import org.junit.Rule
+import org.junit.Test
+import org.mockito.ArgumentMatchers.any
+import org.mockito.ArgumentMatchers.anyInt
+import org.mockito.Mock
+import org.mockito.Mockito.`when`
+import org.mockito.Mockito.doAnswer
+import org.mockito.Mockito.never
+import org.mockito.Mockito.spy
+import org.mockito.Mockito.times
+import org.mockito.Mockito.verify
+import org.mockito.junit.MockitoJUnit
+import java.util.concurrent.CountDownLatch
+import java.util.concurrent.TimeUnit
+
+/**
+ * Tests for {@link InputManagerService}.
+ *
+ * Build/Install/Run:
+ * atest FrameworksServicesTests:InputManagerServiceTests
+ */
+@Presubmit
+class InputManagerServiceTests {
+
+    @get:Rule
+    val rule = MockitoJUnit.rule()!!
+
+    @Mock
+    private lateinit var native: NativeInputManagerService
+
+    @Mock
+    private lateinit var wmCallbacks: InputManagerService.WindowManagerCallbacks
+
+    private lateinit var service: InputManagerService
+    private lateinit var localService: InputManagerInternal
+    private lateinit var context: Context
+    private lateinit var testLooper: TestLooper
+
+    @Before
+    fun setup() {
+        context = spy(ContextWrapper(InstrumentationRegistry.getContext()))
+        testLooper = TestLooper()
+        service =
+            InputManagerService(object : InputManagerService.Injector(context, testLooper.looper) {
+                override fun getNativeService(
+                    service: InputManagerService?
+                ): NativeInputManagerService {
+                    return native
+                }
+
+                override fun registerLocalService(service: InputManagerInternal?) {
+                    localService = service!!
+                }
+            })
+        assertTrue("Local service must be registered", this::localService.isInitialized)
+        service.setWindowManagerCallbacks(wmCallbacks)
+    }
+
+    @Test
+    fun testPointerDisplayUpdatesWhenDisplayViewportsChanged() {
+        val displayId = 123
+        `when`(wmCallbacks.pointerDisplayId).thenReturn(displayId)
+        val viewports = listOf<DisplayViewport>()
+        localService.setDisplayViewports(viewports)
+        verify(native).setDisplayViewports(any(Array<DisplayViewport>::class.java))
+        verify(native).setPointerDisplayId(displayId)
+
+        val x = 42f
+        val y = 314f
+        service.onPointerDisplayIdChanged(displayId, x, y)
+        testLooper.dispatchNext()
+        verify(wmCallbacks).notifyPointerDisplayIdChanged(displayId, x, y)
+    }
+
+    @Test
+    fun testSetVirtualMousePointerDisplayId() {
+        // Set the virtual mouse pointer displayId, and ensure that the calling thread is blocked
+        // until the native callback happens.
+        var countDownLatch = CountDownLatch(1)
+        val overrideDisplayId = 123
+        Thread {
+            assertTrue("Setting virtual pointer display should succeed",
+                localService.setVirtualMousePointerDisplayId(overrideDisplayId))
+            countDownLatch.countDown()
+        }.start()
+        assertFalse("Setting virtual pointer display should block",
+            countDownLatch.await(100, TimeUnit.MILLISECONDS))
+
+        val x = 42f
+        val y = 314f
+        service.onPointerDisplayIdChanged(overrideDisplayId, x, y)
+        testLooper.dispatchNext()
+        verify(wmCallbacks).notifyPointerDisplayIdChanged(overrideDisplayId, x, y)
+        assertTrue("Native callback unblocks calling thread",
+            countDownLatch.await(100, TimeUnit.MILLISECONDS))
+        verify(native).setPointerDisplayId(overrideDisplayId)
+
+        // Ensure that setting the same override again succeeds immediately.
+        assertTrue("Setting the same virtual mouse pointer displayId again should succeed",
+            localService.setVirtualMousePointerDisplayId(overrideDisplayId))
+
+        // Ensure that we did not query WM for the pointerDisplayId when setting the override
+        verify(wmCallbacks, never()).pointerDisplayId
+
+        // Unset the virtual mouse pointer displayId, and ensure that we query WM for the new
+        // pointer displayId and the calling thread is blocked until the native callback happens.
+        countDownLatch = CountDownLatch(1)
+        val pointerDisplayId = 42
+        `when`(wmCallbacks.pointerDisplayId).thenReturn(pointerDisplayId)
+        Thread {
+            assertTrue("Unsetting virtual mouse pointer displayId should succeed",
+                localService.setVirtualMousePointerDisplayId(Display.INVALID_DISPLAY))
+            countDownLatch.countDown()
+        }.start()
+        assertFalse("Unsetting virtual mouse pointer displayId should block",
+            countDownLatch.await(100, TimeUnit.MILLISECONDS))
+
+        service.onPointerDisplayIdChanged(pointerDisplayId, x, y)
+        testLooper.dispatchNext()
+        verify(wmCallbacks).notifyPointerDisplayIdChanged(pointerDisplayId, x, y)
+        assertTrue("Native callback unblocks calling thread",
+            countDownLatch.await(100, TimeUnit.MILLISECONDS))
+        verify(native).setPointerDisplayId(pointerDisplayId)
+    }
+
+    @Test
+    fun testSetVirtualMousePointerDisplayId_unsuccessfulUpdate() {
+        // Set the virtual mouse pointer displayId, and ensure that the calling thread is blocked
+        // until the native callback happens.
+        val countDownLatch = CountDownLatch(1)
+        val overrideDisplayId = 123
+        Thread {
+            assertFalse("Setting virtual pointer display should be unsuccessful",
+                localService.setVirtualMousePointerDisplayId(overrideDisplayId))
+            countDownLatch.countDown()
+        }.start()
+        assertFalse("Setting virtual pointer display should block",
+            countDownLatch.await(100, TimeUnit.MILLISECONDS))
+
+        val x = 42f
+        val y = 314f
+        // Assume the native callback updates the pointerDisplayId to the incorrect value.
+        service.onPointerDisplayIdChanged(Display.INVALID_DISPLAY, x, y)
+        testLooper.dispatchNext()
+        verify(wmCallbacks).notifyPointerDisplayIdChanged(Display.INVALID_DISPLAY, x, y)
+        assertTrue("Native callback unblocks calling thread",
+            countDownLatch.await(100, TimeUnit.MILLISECONDS))
+        verify(native).setPointerDisplayId(overrideDisplayId)
+    }
+
+    @Test
+    fun testSetVirtualMousePointerDisplayId_competingRequests() {
+        val firstRequestSyncLatch = CountDownLatch(1)
+        doAnswer {
+            firstRequestSyncLatch.countDown()
+        }.`when`(native).setPointerDisplayId(anyInt())
+
+        val firstRequestLatch = CountDownLatch(1)
+        val firstOverride = 123
+        Thread {
+            assertFalse("Setting virtual pointer display from thread 1 should be unsuccessful",
+                localService.setVirtualMousePointerDisplayId(firstOverride))
+            firstRequestLatch.countDown()
+        }.start()
+        assertFalse("Setting virtual pointer display should block",
+            firstRequestLatch.await(100, TimeUnit.MILLISECONDS))
+
+        assertTrue("Wait for first thread's request should succeed",
+            firstRequestSyncLatch.await(100, TimeUnit.MILLISECONDS))
+
+        val secondRequestLatch = CountDownLatch(1)
+        val secondOverride = 42
+        Thread {
+            assertTrue("Setting virtual mouse pointer from thread 2 should be successful",
+                localService.setVirtualMousePointerDisplayId(secondOverride))
+            secondRequestLatch.countDown()
+        }.start()
+        assertFalse("Setting virtual mouse pointer should block",
+            secondRequestLatch.await(100, TimeUnit.MILLISECONDS))
+
+        val x = 42f
+        val y = 314f
+        // Assume the native callback updates directly to the second request.
+        service.onPointerDisplayIdChanged(secondOverride, x, y)
+        testLooper.dispatchNext()
+        verify(wmCallbacks).notifyPointerDisplayIdChanged(secondOverride, x, y)
+        assertTrue("Native callback unblocks first thread",
+            firstRequestLatch.await(100, TimeUnit.MILLISECONDS))
+        assertTrue("Native callback unblocks second thread",
+            secondRequestLatch.await(100, TimeUnit.MILLISECONDS))
+        verify(native, times(2)).setPointerDisplayId(anyInt())
+    }
+}
\ No newline at end of file