Include self for FocusOwner Key and Rotary Traversal

In the Modifier.Node world, a node could be both a focus target as well as a key input/soft keyboard interception/rotary input node. When traversing the ancestors, we want to include the active focus target node in the traversal as it could be the node we are looking for.

This will let us flatten hierarchies and delegate to focus target from e.g. Clickable instead, which would be both a focus target and key input node.

Test: FocusTargetAttachDetachTest
Change-Id: Ica5c964d244daadf321e13224b9b22c817cfb10e
diff --git a/compose/ui/ui/src/androidInstrumentedTest/kotlin/androidx/compose/ui/focus/FocusTargetAttachDetachTest.kt b/compose/ui/ui/src/androidInstrumentedTest/kotlin/androidx/compose/ui/focus/FocusTargetAttachDetachTest.kt
index 85b3f11..f446af8 100644
--- a/compose/ui/ui/src/androidInstrumentedTest/kotlin/androidx/compose/ui/focus/FocusTargetAttachDetachTest.kt
+++ b/compose/ui/ui/src/androidInstrumentedTest/kotlin/androidx/compose/ui/focus/FocusTargetAttachDetachTest.kt
@@ -20,8 +20,29 @@
 import androidx.compose.runtime.getValue
 import androidx.compose.runtime.mutableStateOf
 import androidx.compose.runtime.setValue
+import androidx.compose.ui.ExperimentalComposeUiApi
 import androidx.compose.ui.Modifier
+import androidx.compose.ui.input.InputMode
+import androidx.compose.ui.input.InputModeManager
+import androidx.compose.ui.input.key.Key
+import androidx.compose.ui.input.key.KeyEvent
+import androidx.compose.ui.input.key.KeyInputModifierNode
+import androidx.compose.ui.input.key.NativeKeyEvent
+import androidx.compose.ui.input.key.SoftKeyboardInterceptionModifierNode
+import androidx.compose.ui.input.key.key
+import androidx.compose.ui.input.pointer.elementFor
+import androidx.compose.ui.input.rotary.RotaryInputModifierNode
+import androidx.compose.ui.input.rotary.RotaryScrollEvent
+import androidx.compose.ui.node.DelegatingNode
+import androidx.compose.ui.platform.LocalInputModeManager
+import androidx.compose.ui.platform.testTag
+import androidx.compose.ui.test.ExperimentalTestApi
 import androidx.compose.ui.test.junit4.createComposeRule
+import androidx.compose.ui.test.onNodeWithTag
+import androidx.compose.ui.test.onRoot
+import androidx.compose.ui.test.performKeyInput
+import androidx.compose.ui.test.performKeyPress
+import androidx.compose.ui.test.performRotaryScrollInput
 import androidx.test.ext.junit.runners.AndroidJUnit4
 import androidx.test.filters.MediumTest
 import com.google.common.truth.Truth.assertThat
@@ -980,6 +1001,300 @@
         }
     }
 
+    @OptIn(ExperimentalComposeUiApi::class, ExperimentalTestApi::class)
+    @Test
+    fun focusTarget_nodeThatIsKeyInputNodeKind_implementing_receivesKeyEventsWhenFocused() {
+        class FocusTargetAndKeyInputNode : DelegatingNode(), KeyInputModifierNode {
+            val keyEvents = mutableListOf<KeyEvent>()
+            val focusTargetNode = FocusTargetNode()
+
+            init {
+                delegate(focusTargetNode)
+            }
+
+            override fun onKeyEvent(event: KeyEvent): Boolean {
+                keyEvents.add(event)
+                return true
+            }
+
+            override fun onPreKeyEvent(event: KeyEvent) = false
+        }
+
+        val focusTargetAndKeyInputNode = FocusTargetAndKeyInputNode()
+        val focusTargetAndKeyInputModifier = elementFor(key1 = null, focusTargetAndKeyInputNode)
+
+        val focusRequester = FocusRequester()
+        val targetTestTag = "target"
+        lateinit var inputModeManager: InputModeManager
+
+        rule.setFocusableContent(extraItemForInitialFocus = false) {
+            inputModeManager = LocalInputModeManager.current
+            Box(
+                modifier = Modifier
+                    .testTag(targetTestTag)
+                    .focusRequester(focusRequester)
+                    .then(focusTargetAndKeyInputModifier)
+            )
+        }
+
+        rule.runOnUiThread {
+            inputModeManager.requestInputMode(InputMode.Keyboard)
+            focusRequester.requestFocus()
+        }
+
+        assertThat(focusTargetAndKeyInputNode.focusTargetNode.focusState.isFocused).isTrue()
+
+        rule.onNodeWithTag(targetTestTag).performKeyInput { keyDown(Key.Enter) }
+
+        assertThat(focusTargetAndKeyInputNode.keyEvents).hasSize(1)
+        assertThat(focusTargetAndKeyInputNode.keyEvents[0].key).isEqualTo(Key.Enter)
+    }
+
+    @OptIn(ExperimentalComposeUiApi::class, ExperimentalTestApi::class)
+    @Test
+    fun focusTarget_nodeThatIsKeyInputNodeKind_delegating_receivesKeyEventsWhenFocused() {
+        class FocusTargetAndKeyInputNode : DelegatingNode() {
+            val keyEvents = mutableListOf<KeyEvent>()
+            val focusTargetNode = FocusTargetNode()
+            val keyInputNode = object : KeyInputModifierNode, Modifier.Node() {
+                override fun onKeyEvent(event: KeyEvent): Boolean {
+                    keyEvents.add(event)
+                    return true
+                }
+
+                override fun onPreKeyEvent(event: KeyEvent) = false
+            }
+
+            init {
+                delegate(focusTargetNode)
+                delegate(keyInputNode)
+            }
+        }
+
+        val focusTargetAndKeyInputNode = FocusTargetAndKeyInputNode()
+        val focusTargetAndKeyInputModifier = elementFor(key1 = null, focusTargetAndKeyInputNode)
+
+        val focusRequester = FocusRequester()
+        val targetTestTag = "target"
+        lateinit var inputModeManager: InputModeManager
+
+        rule.setFocusableContent(extraItemForInitialFocus = false) {
+            inputModeManager = LocalInputModeManager.current
+            Box(
+                modifier = Modifier
+                    .testTag(targetTestTag)
+                    .focusRequester(focusRequester)
+                    .then(focusTargetAndKeyInputModifier)
+            )
+        }
+
+        rule.runOnUiThread {
+            inputModeManager.requestInputMode(InputMode.Keyboard)
+            focusRequester.requestFocus()
+        }
+
+        assertThat(focusTargetAndKeyInputNode.focusTargetNode.focusState.isFocused).isTrue()
+
+        rule.onNodeWithTag(targetTestTag).performKeyInput { keyDown(Key.Enter) }
+
+        assertThat(focusTargetAndKeyInputNode.keyEvents).hasSize(1)
+        assertThat(focusTargetAndKeyInputNode.keyEvents[0].key).isEqualTo(Key.Enter)
+    }
+
+    @OptIn(ExperimentalComposeUiApi::class)
+    @Test
+    fun focusTarget_nodeThatIsSoftKeyInputNodeKind_implementing_receivesSoftKeyEventsWhenFocused() {
+        class FocusTargetAndSoftKeyboardNode : DelegatingNode(),
+            SoftKeyboardInterceptionModifierNode {
+            val keyEvents = mutableListOf<KeyEvent>()
+            val focusTargetNode = FocusTargetNode()
+
+            init {
+                delegate(focusTargetNode)
+            }
+
+            override fun onInterceptKeyBeforeSoftKeyboard(event: KeyEvent) = keyEvents.add(event)
+
+            override fun onPreInterceptKeyBeforeSoftKeyboard(event: KeyEvent) = false
+        }
+
+        val focusTargetAndSoftKeyboardNode = FocusTargetAndSoftKeyboardNode()
+        val focusTargetAndSoftKeyboardModifier =
+            elementFor(key1 = null, focusTargetAndSoftKeyboardNode)
+
+        val focusRequester = FocusRequester()
+        val targetTestTag = "target"
+
+        rule.setFocusableContent(extraItemForInitialFocus = false) {
+            Box(
+                modifier = Modifier
+                    .testTag(targetTestTag)
+                    .focusRequester(focusRequester)
+                    .then(focusTargetAndSoftKeyboardModifier)
+            )
+        }
+
+        rule.runOnUiThread { focusRequester.requestFocus() }
+        assertThat(focusTargetAndSoftKeyboardNode.focusTargetNode.focusState.isFocused).isTrue()
+
+        // This test specifically uses performKeyPress over performKeyInput as performKeyPress calls
+        // sendKeyEvent, which in turn notifies FocusOwner that there's a
+        // SoftKeyboardInterceptionModifierNode-interceptable key event first. performKeyInput goes
+        // through dispatchKeyEvent which does not notify SoftKeyboardInterceptionModifierNodes.
+        rule.onRoot().performKeyPress(
+            KeyEvent(
+                NativeKeyEvent(
+                    android.view.KeyEvent.ACTION_DOWN,
+                    android.view.KeyEvent.KEYCODE_ENTER
+                )
+            )
+        )
+
+        assertThat(focusTargetAndSoftKeyboardNode.keyEvents).hasSize(1)
+        assertThat(focusTargetAndSoftKeyboardNode.keyEvents[0].key).isEqualTo(Key.Enter)
+    }
+
+    @OptIn(ExperimentalComposeUiApi::class)
+    @Test
+    fun focusTarget_nodeThatIsSoftKeyInputNodeKind_delegating_receivesSoftKeyEventsWhenFocused() {
+        class FocusTargetAndSoftKeyboardNode : DelegatingNode() {
+            val keyEvents = mutableListOf<KeyEvent>()
+            val focusTargetNode = FocusTargetNode()
+            val softKeyboardInterceptionNode = object : SoftKeyboardInterceptionModifierNode,
+                Modifier.Node() {
+                override fun onInterceptKeyBeforeSoftKeyboard(event: KeyEvent) =
+                    keyEvents.add(event)
+
+                override fun onPreInterceptKeyBeforeSoftKeyboard(event: KeyEvent) = false
+            }
+
+            init {
+                delegate(focusTargetNode)
+                delegate(softKeyboardInterceptionNode)
+            }
+        }
+
+        val focusTargetAndSoftKeyboardNode = FocusTargetAndSoftKeyboardNode()
+        val focusTargetAndSoftKeyboardModifier =
+            elementFor(key1 = null, focusTargetAndSoftKeyboardNode)
+
+        val focusRequester = FocusRequester()
+        val targetTestTag = "target"
+
+        rule.setFocusableContent(extraItemForInitialFocus = false) {
+            Box(
+                modifier = Modifier
+                    .testTag(targetTestTag)
+                    .focusRequester(focusRequester)
+                    .then(focusTargetAndSoftKeyboardModifier)
+            )
+        }
+
+        rule.runOnUiThread { focusRequester.requestFocus() }
+        assertThat(focusTargetAndSoftKeyboardNode.focusTargetNode.focusState.isFocused).isTrue()
+
+        // This test specifically uses performKeyPress over performKeyInput as performKeyPress calls
+        // sendKeyEvent, which in turn notifies FocusOwner that there's a
+        // SoftKeyboardInterceptionModifierNode-interceptable key event first. performKeyInput goes
+        // through dispatchKeyEvent which does not notify SoftKeyboardInterceptionModifierNodes.
+        rule.onRoot().performKeyPress(
+            KeyEvent(
+                NativeKeyEvent(
+                    android.view.KeyEvent.ACTION_DOWN,
+                    android.view.KeyEvent.KEYCODE_ENTER
+                )
+            )
+        )
+
+        assertThat(focusTargetAndSoftKeyboardNode.keyEvents).hasSize(1)
+        assertThat(focusTargetAndSoftKeyboardNode.keyEvents[0].key).isEqualTo(Key.Enter)
+    }
+
+    @OptIn(ExperimentalTestApi::class)
+    @Test
+    fun focusTarget_nodeThatIsRotaryInputNodeKind_implementing_receivesRotaryEventsWhenFocused() {
+        class FocusTargetAndRotaryNode : DelegatingNode(), RotaryInputModifierNode {
+            val events = mutableListOf<RotaryScrollEvent>()
+            val focusTargetNode = FocusTargetNode()
+
+            init {
+                delegate(focusTargetNode)
+            }
+
+            override fun onRotaryScrollEvent(event: RotaryScrollEvent) = events.add(event)
+
+            override fun onPreRotaryScrollEvent(event: RotaryScrollEvent) = false
+        }
+
+        val focusTargetAndRotaryNode = FocusTargetAndRotaryNode()
+        val focusTargetAndRotaryModifier = elementFor(key1 = null, focusTargetAndRotaryNode)
+
+        val focusRequester = FocusRequester()
+        val targetTestTag = "target"
+
+        rule.setFocusableContent(extraItemForInitialFocus = false) {
+            Box(
+                modifier = Modifier
+                    .testTag(targetTestTag)
+                    .focusRequester(focusRequester)
+                    .then(focusTargetAndRotaryModifier)
+            )
+        }
+
+        rule.runOnUiThread { focusRequester.requestFocus() }
+        assertThat(focusTargetAndRotaryNode.focusTargetNode.focusState.isFocused).isTrue()
+
+        rule.onNodeWithTag(targetTestTag).performRotaryScrollInput {
+            rotateToScrollVertically(100f)
+        }
+
+        assertThat(focusTargetAndRotaryNode.events).hasSize(1)
+        assertThat(focusTargetAndRotaryNode.events[0].verticalScrollPixels).isEqualTo(100f)
+    }
+
+    @OptIn(ExperimentalTestApi::class)
+    @Test
+    fun focusTarget_nodeThatIsRotaryInputNodeKind_delegating_receivesRotaryEventsWhenFocused() {
+        class FocusTargetAndRotaryNode : DelegatingNode() {
+            val events = mutableListOf<RotaryScrollEvent>()
+            val focusTargetNode = FocusTargetNode()
+            val rotaryInputNode = object : RotaryInputModifierNode, Modifier.Node() {
+                override fun onRotaryScrollEvent(event: RotaryScrollEvent) = events.add(event)
+                override fun onPreRotaryScrollEvent(event: RotaryScrollEvent) = false
+            }
+
+            init {
+                delegate(focusTargetNode)
+                delegate(rotaryInputNode)
+            }
+        }
+
+        val focusTargetAndRotaryNode = FocusTargetAndRotaryNode()
+        val focusTargetAndRotaryModifier = elementFor(key1 = null, focusTargetAndRotaryNode)
+
+        val focusRequester = FocusRequester()
+        val targetTestTag = "target"
+
+        rule.setFocusableContent(extraItemForInitialFocus = false) {
+            Box(
+                modifier = Modifier
+                    .testTag(targetTestTag)
+                    .focusRequester(focusRequester)
+                    .then(focusTargetAndRotaryModifier)
+            )
+        }
+
+        rule.runOnUiThread { focusRequester.requestFocus() }
+        assertThat(focusTargetAndRotaryNode.focusTargetNode.focusState.isFocused).isTrue()
+
+        rule.onNodeWithTag(targetTestTag).performRotaryScrollInput {
+            rotateToScrollVertically(100f)
+        }
+
+        assertThat(focusTargetAndRotaryNode.events).hasSize(1)
+        assertThat(focusTargetAndRotaryNode.events[0].verticalScrollPixels).isEqualTo(100f)
+    }
+
     private inline fun Modifier.thenIf(condition: Boolean, block: () -> Modifier): Modifier {
         return if (condition) then(block()) else this
     }
diff --git a/compose/ui/ui/src/commonMain/kotlin/androidx/compose/ui/focus/FocusOwnerImpl.kt b/compose/ui/ui/src/commonMain/kotlin/androidx/compose/ui/focus/FocusOwnerImpl.kt
index c78b8ef..257922e 100644
--- a/compose/ui/ui/src/commonMain/kotlin/androidx/compose/ui/focus/FocusOwnerImpl.kt
+++ b/compose/ui/ui/src/commonMain/kotlin/androidx/compose/ui/focus/FocusOwnerImpl.kt
@@ -42,6 +42,7 @@
 import androidx.compose.ui.node.ancestors
 import androidx.compose.ui.node.dispatchForKind
 import androidx.compose.ui.node.nearestAncestor
+import androidx.compose.ui.node.visitAncestors
 import androidx.compose.ui.node.visitLocalDescendants
 import androidx.compose.ui.platform.InspectorInfo
 import androidx.compose.ui.unit.LayoutDirection
@@ -266,10 +267,10 @@
 
         val activeFocusTarget = rootFocusNode.findActiveFocusNode()
         val focusedKeyInputNode = activeFocusTarget?.lastLocalKeyInputNode()
-            ?: activeFocusTarget?.nearestAncestor(Nodes.KeyInput)?.node
+            ?: activeFocusTarget?.nearestAncestorIncludingSelf(Nodes.KeyInput)?.node
             ?: rootFocusNode.nearestAncestor(Nodes.KeyInput)?.node
 
-        focusedKeyInputNode?.traverseAncestors(
+        focusedKeyInputNode?.traverseAncestorsIncludingSelf(
             type = Nodes.KeyInput,
             onPreVisit = { if (it.onPreKeyEvent(keyEvent)) return true },
             onVisit = { if (onFocusedItem.invoke()) return true },
@@ -285,9 +286,9 @@
         }
 
         val focusedSoftKeyboardInterceptionNode = rootFocusNode.findActiveFocusNode()
-            ?.nearestAncestor(Nodes.SoftKeyboardKeyInput)
+            ?.nearestAncestorIncludingSelf(Nodes.SoftKeyboardKeyInput)
 
-        focusedSoftKeyboardInterceptionNode?.traverseAncestors(
+        focusedSoftKeyboardInterceptionNode?.traverseAncestorsIncludingSelf(
             type = Nodes.SoftKeyboardKeyInput,
             onPreVisit = { if (it.onPreInterceptKeyBeforeSoftKeyboard(keyEvent)) return true },
             onVisit = { /* TODO(b/320510084): dispatch soft keyboard events to embedded views. */ },
@@ -305,9 +306,9 @@
         }
 
         val focusedRotaryInputNode = rootFocusNode.findActiveFocusNode()
-            ?.nearestAncestor(Nodes.RotaryInput)
+            ?.nearestAncestorIncludingSelf(Nodes.RotaryInput)
 
-        focusedRotaryInputNode?.traverseAncestors(
+        focusedRotaryInputNode?.traverseAncestorsIncludingSelf(
             type = Nodes.RotaryInput,
             onPreVisit = { if (it.onPreRotaryScrollEvent(event)) return true },
             onVisit = { /* TODO(b/320510084): dispatch rotary events to embedded views. */ },
@@ -341,7 +342,7 @@
         }
     }
 
-    private inline fun <reified T : DelegatableNode> DelegatableNode.traverseAncestors(
+    private inline fun <reified T : DelegatableNode> DelegatableNode.traverseAncestorsIncludingSelf(
         type: NodeKind<T>,
         onPreVisit: (T) -> Unit,
         onVisit: () -> Unit,
@@ -355,6 +356,15 @@
         ancestors?.fastForEach(onPostVisit)
     }
 
+    private inline fun <reified T : Any> DelegatableNode.nearestAncestorIncludingSelf(
+        type: NodeKind<T>
+    ): T? {
+        visitAncestors(type, includeSelf = true) {
+            return it
+        }
+        return null
+    }
+
     /**
      * Searches for the currently focused item, and returns its coordinates as a rect.
      */