Merge "Force prefetch if item is needed in the next frame for lazy layouts" into androidx-main
diff --git a/compose/foundation/foundation/api/current.txt b/compose/foundation/foundation/api/current.txt
index e026507..f8dc027 100644
--- a/compose/foundation/foundation/api/current.txt
+++ b/compose/foundation/foundation/api/current.txt
@@ -1132,6 +1132,7 @@
 
   public static sealed interface LazyLayoutPrefetchState.PrefetchHandle {
     method public void cancel();
+    method public void markAsUrgent();
   }
 
   public final class Lazy_androidKt {
@@ -1156,8 +1157,7 @@
   }
 
   @SuppressCompatibility @androidx.compose.foundation.ExperimentalFoundationApi public interface PrefetchRequestScope {
-    method public long getAvailableTimeNanos();
-    property public abstract long availableTimeNanos;
+    method public long availableTimeNanos();
   }
 
   @SuppressCompatibility @androidx.compose.foundation.ExperimentalFoundationApi public interface PrefetchScheduler {
diff --git a/compose/foundation/foundation/api/restricted_current.txt b/compose/foundation/foundation/api/restricted_current.txt
index cc020da..8ffe7e9 100644
--- a/compose/foundation/foundation/api/restricted_current.txt
+++ b/compose/foundation/foundation/api/restricted_current.txt
@@ -1134,6 +1134,7 @@
 
   public static sealed interface LazyLayoutPrefetchState.PrefetchHandle {
     method public void cancel();
+    method public void markAsUrgent();
   }
 
   public final class Lazy_androidKt {
@@ -1158,8 +1159,7 @@
   }
 
   @SuppressCompatibility @androidx.compose.foundation.ExperimentalFoundationApi public interface PrefetchRequestScope {
-    method public long getAvailableTimeNanos();
-    property public abstract long availableTimeNanos;
+    method public long availableTimeNanos();
   }
 
   @SuppressCompatibility @androidx.compose.foundation.ExperimentalFoundationApi public interface PrefetchScheduler {
diff --git a/compose/foundation/foundation/integration-tests/lazy-tests/src/androidTest/kotlin/androidx/compose/foundation/lazy/grid/LazyGridPrefetcherTest.kt b/compose/foundation/foundation/integration-tests/lazy-tests/src/androidTest/kotlin/androidx/compose/foundation/lazy/grid/LazyGridPrefetcherTest.kt
index 087d8db..c4d5b3a 100644
--- a/compose/foundation/foundation/integration-tests/lazy-tests/src/androidTest/kotlin/androidx/compose/foundation/lazy/grid/LazyGridPrefetcherTest.kt
+++ b/compose/foundation/foundation/integration-tests/lazy-tests/src/androidTest/kotlin/androidx/compose/foundation/lazy/grid/LazyGridPrefetcherTest.kt
@@ -14,14 +14,20 @@
  * limitations under the License.
  */
 
+@file:Suppress("INVISIBLE_MEMBER", "INVISIBLE_REFERENCE")
+
 package androidx.compose.foundation.lazy.grid
 
 import androidx.compose.foundation.AutoTestFrameClock
+import androidx.compose.foundation.ExperimentalFoundationApi
 import androidx.compose.foundation.gestures.Orientation
 import androidx.compose.foundation.gestures.scrollBy
 import androidx.compose.foundation.layout.PaddingValues
 import androidx.compose.foundation.layout.Spacer
+import androidx.compose.foundation.lazy.layout.TestPrefetchScheduler
+import androidx.compose.runtime.Composable
 import androidx.compose.runtime.DisposableEffect
+import androidx.compose.runtime.remember
 import androidx.compose.ui.Modifier
 import androidx.compose.ui.layout.Remeasurement
 import androidx.compose.ui.layout.RemeasurementModifier
@@ -58,6 +64,16 @@
     val itemsSizeDp = with(rule.density) { itemsSizePx.toDp() }
 
     lateinit var state: LazyGridState
+    private val scheduler = TestPrefetchScheduler()
+
+    @OptIn(ExperimentalFoundationApi::class)
+    @Composable
+    fun rememberState(
+        initialFirstVisibleItemIndex: Int = 0,
+        initialFirstVisibleItemScrollOffset: Int = 0
+    ): LazyGridState = remember {
+        LazyGridState(initialFirstVisibleItemIndex, initialFirstVisibleItemScrollOffset, scheduler)
+    }
 
     @Test
     fun notPrefetchingForwardInitially() {
@@ -85,8 +101,8 @@
             }
         }
 
-        waitForPrefetch(4)
-        waitForPrefetch(5)
+        waitForPrefetch()
+        waitForPrefetch()
 
         rule.onNodeWithTag("4")
             .assertExists()
@@ -106,8 +122,8 @@
             }
         }
 
-        waitForPrefetch(2)
-        waitForPrefetch(3)
+        waitForPrefetch()
+        waitForPrefetch()
 
         rule.onNodeWithTag("2")
             .assertExists()
@@ -127,8 +143,8 @@
             }
         }
 
-        waitForPrefetch(6)
-        waitForPrefetch(7)
+        waitForPrefetch()
+        waitForPrefetch()
 
         rule.onNodeWithTag("6")
             .assertExists()
@@ -144,8 +160,8 @@
             }
         }
 
-        waitForPrefetch(0)
-        waitForPrefetch(1)
+        waitForPrefetch()
+        waitForPrefetch()
 
         rule.onNodeWithTag("0")
             .assertExists()
@@ -165,7 +181,7 @@
             }
         }
 
-        waitForPrefetch(4)
+        waitForPrefetch()
 
         rule.runOnIdle {
             runBlocking {
@@ -174,7 +190,7 @@
             }
         }
 
-        waitForPrefetch(6)
+        waitForPrefetch()
 
         rule.onNodeWithTag("4")
             .assertIsDisplayed()
@@ -194,7 +210,7 @@
             }
         }
 
-        waitForPrefetch(4)
+        waitForPrefetch()
 
         rule.runOnIdle {
             runBlocking {
@@ -203,7 +219,7 @@
             }
         }
 
-        waitForPrefetch(2)
+        waitForPrefetch()
 
         rule.onNodeWithTag("4")
             .assertIsDisplayed()
@@ -225,8 +241,7 @@
             }
         }
 
-        waitForPrefetch(6)
-        waitForPrefetch(7)
+        waitForPrefetch()
 
         rule.onNodeWithTag("6")
             .assertExists()
@@ -244,8 +259,7 @@
             }
         }
 
-        waitForPrefetch(0)
-        waitForPrefetch(1)
+        waitForPrefetch()
 
         rule.onNodeWithTag("0")
             .assertExists()
@@ -283,7 +297,7 @@
             }
         }
 
-        waitForPrefetch(6)
+        waitForPrefetch()
 
         rule.onNodeWithTag("8")
             .assertExists()
@@ -296,7 +310,7 @@
             }
         }
 
-        waitForPrefetch(0)
+        waitForPrefetch()
 
         rule.onNodeWithTag("0")
             .assertExists()
@@ -316,7 +330,7 @@
             ) { constraints ->
                 val placeable = if (emit) {
                     subcompose(Unit) {
-                        state = rememberLazyGridState()
+                        state = rememberState()
                         LazyGrid(
                             2,
                             Modifier.mainAxisSize(itemsSizeDp * 1.5f),
@@ -355,7 +369,7 @@
     fun snappingToOtherPositionWhilePrefetchIsScheduled() {
         val composedItems = mutableListOf<Int>()
         rule.setContent {
-            state = rememberLazyGridState()
+            state = rememberState()
             LazyGrid(
                 1,
                 Modifier.mainAxisSize(itemsSizeDp * 1.5f),
@@ -410,7 +424,7 @@
             }
         }
 
-        waitForPrefetch(13)
+        waitForPrefetch()
 
         rule.runOnIdle {
             runBlocking(AutoTestFrameClock()) {
@@ -424,14 +438,13 @@
         }
     }
 
-    private fun waitForPrefetch(index: Int) {
-        rule.waitUntil {
-            activeNodes.contains(index) && activeMeasuredNodes.contains(index)
+    private fun waitForPrefetch() {
+        rule.runOnIdle {
+            scheduler.executeActiveRequests()
         }
     }
 
     private val activeNodes = mutableSetOf<Int>()
-    private val activeMeasuredNodes = mutableSetOf<Int>()
 
     private fun composeGrid(
         firstItem: Int = 0,
@@ -440,7 +453,7 @@
         contentPadding: PaddingValues = PaddingValues(0.dp)
     ) {
         rule.setContent {
-            state = rememberLazyGridState(
+            state = rememberState(
                 initialFirstVisibleItemIndex = firstItem,
                 initialFirstVisibleItemScrollOffset = itemOffset
             )
@@ -456,7 +469,6 @@
                         activeNodes.add(it)
                         onDispose {
                             activeNodes.remove(it)
-                            activeMeasuredNodes.remove(it)
                         }
                     }
                     Spacer(
@@ -465,7 +477,6 @@
                             .testTag("$it")
                             .layout { measurable, constraints ->
                                 val placeable = measurable.measure(constraints)
-                                activeMeasuredNodes.add(it)
                                 layout(placeable.width, placeable.height) {
                                     placeable.place(0, 0)
                                 }
diff --git a/compose/foundation/foundation/integration-tests/lazy-tests/src/androidTest/kotlin/androidx/compose/foundation/lazy/layout/LazyLayoutTest.kt b/compose/foundation/foundation/integration-tests/lazy-tests/src/androidTest/kotlin/androidx/compose/foundation/lazy/layout/LazyLayoutTest.kt
index bb6d01a..5596f14 100644
--- a/compose/foundation/foundation/integration-tests/lazy-tests/src/androidTest/kotlin/androidx/compose/foundation/lazy/layout/LazyLayoutTest.kt
+++ b/compose/foundation/foundation/integration-tests/lazy-tests/src/androidTest/kotlin/androidx/compose/foundation/lazy/layout/LazyLayoutTest.kt
@@ -257,7 +257,8 @@
                     .then(modifier))
         }
         var needToCompose by mutableStateOf(false)
-        val prefetchState = LazyLayoutPrefetchState()
+        val scheduler = TestPrefetchScheduler()
+        val prefetchState = LazyLayoutPrefetchState(scheduler)
         rule.setContent {
             LazyLayout(itemProvider, prefetchState = prefetchState) {
                 val item = if (needToCompose) {
@@ -273,9 +274,10 @@
             assertThat(measureCount).isEqualTo(0)
 
             prefetchState.schedulePrefetch(0, constraints)
-        }
 
-        rule.waitUntil { measureCount == 1 }
+            scheduler.executeActiveRequests()
+            assertThat(measureCount).isEqualTo(1)
+        }
 
         rule.onNodeWithTag("0").assertIsNotDisplayed()
 
@@ -303,20 +305,18 @@
                 }
             }
         }
-        val prefetchState = LazyLayoutPrefetchState()
+        val scheduler = TestPrefetchScheduler()
+        val prefetchState = LazyLayoutPrefetchState(scheduler)
         rule.setContent {
             LazyLayout(itemProvider, prefetchState = prefetchState) {
                 layout(100, 100) {}
             }
         }
 
-        val handle = rule.runOnIdle {
-            prefetchState.schedulePrefetch(0, Constraints.fixed(50, 50))
-        }
-
-        rule.waitUntil { composed }
-
         rule.runOnIdle {
+            val handle = prefetchState.schedulePrefetch(0, Constraints.fixed(50, 50))
+            scheduler.executeActiveRequests()
+            assertThat(composed).isTrue()
             handle.cancel()
         }
 
diff --git a/compose/foundation/foundation/integration-tests/lazy-tests/src/androidTest/kotlin/androidx/compose/foundation/lazy/layout/TestPrefetchScheduler.kt b/compose/foundation/foundation/integration-tests/lazy-tests/src/androidTest/kotlin/androidx/compose/foundation/lazy/layout/TestPrefetchScheduler.kt
new file mode 100644
index 0000000..ad3b8a6
--- /dev/null
+++ b/compose/foundation/foundation/integration-tests/lazy-tests/src/androidTest/kotlin/androidx/compose/foundation/lazy/layout/TestPrefetchScheduler.kt
@@ -0,0 +1,57 @@
+/*
+ * Copyright 2024 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package androidx.compose.foundation.lazy.layout
+
+import androidx.compose.foundation.ExperimentalFoundationApi
+
+/**
+ * A [PrefetchScheduler] for tests that never runs prefetch work on its own.
+ *
+ * Requests passed to [schedulePrefetch] are only queued; they run when the test explicitly
+ * calls [executeActiveRequests], so prefetching happens at a deterministic point instead of
+ * depending on frame timing.
+ */
+@OptIn(ExperimentalFoundationApi::class)
+internal class TestPrefetchScheduler : PrefetchScheduler {
+
+    private val activeRequests = mutableListOf<PrefetchRequest>()
+
+    /** Queues [prefetchRequest]; it runs on the next [executeActiveRequests] call. */
+    override fun schedulePrefetch(prefetchRequest: PrefetchRequest) {
+        activeRequests.add(prefetchRequest)
+    }
+
+    /**
+     * Executes every request queued so far, each with an effectively unlimited time budget.
+     *
+     * The queue is snapshotted and cleared before anything runs, so a request scheduled while
+     * another one executes stays queued for the next call instead of breaking the iteration
+     * with a ConcurrentModificationException.
+     */
+    fun executeActiveRequests() {
+        val pending = activeRequests.toList()
+        activeRequests.clear()
+        pending.forEach { request ->
+            with(request) { scope.execute() }
+        }
+    }
+
+    /** Reports [Long.MAX_VALUE] as the remaining time, so budget checks always see time left. */
+    private val scope = object : PrefetchRequestScope {
+        override fun availableTimeNanos(): Long = Long.MAX_VALUE
+    }
+}
diff --git a/compose/foundation/foundation/integration-tests/lazy-tests/src/androidTest/kotlin/androidx/compose/foundation/lazy/list/LazyListFocusMoveCompositionCountTest.kt b/compose/foundation/foundation/integration-tests/lazy-tests/src/androidTest/kotlin/androidx/compose/foundation/lazy/list/LazyListFocusMoveCompositionCountTest.kt
index b8a37d5..21f1531 100644
--- a/compose/foundation/foundation/integration-tests/lazy-tests/src/androidTest/kotlin/androidx/compose/foundation/lazy/list/LazyListFocusMoveCompositionCountTest.kt
+++ b/compose/foundation/foundation/integration-tests/lazy-tests/src/androidTest/kotlin/androidx/compose/foundation/lazy/list/LazyListFocusMoveCompositionCountTest.kt
@@ -14,11 +14,14 @@
  * limitations under the License.
  */
 
+@file:Suppress("INVISIBLE_MEMBER", "INVISIBLE_REFERENCE")
+
 package androidx.compose.foundation.lazy.list
 
 import androidx.compose.foundation.focusable
 import androidx.compose.foundation.layout.Box
 import androidx.compose.foundation.layout.size
+import androidx.compose.foundation.lazy.LazyListState
 import androidx.compose.foundation.lazy.LazyRow
 import androidx.compose.runtime.SideEffect
 import androidx.compose.ui.Modifier
@@ -45,6 +48,10 @@
 
     private val composedItems = mutableSetOf<Int>()
 
+    private val state = LazyListState().also {
+        it.prefetchingEnabled = false
+    }
+
     @Test
     fun moveFocus() {
         // Arrange.
@@ -52,7 +59,7 @@
         lateinit var focusManager: FocusManager
         rule.setContent {
             focusManager = LocalFocusManager.current
-            LazyRow(Modifier.size(rowSize)) {
+            LazyRow(Modifier.size(rowSize), state) {
                 items(100) { index ->
                     Box(
                         Modifier
@@ -71,7 +78,7 @@
         rule.runOnIdle { focusManager.moveFocus(FocusDirection.Right) }
 
         // Assert
-        rule.runOnIdle { assertThat(composedItems).containsExactly(5, 6) }
+        rule.runOnIdle { assertThat(composedItems).containsExactly(5) }
     }
 
     @Test
@@ -81,7 +88,7 @@
         lateinit var focusManager: FocusManager
         rule.setContent {
             focusManager = LocalFocusManager.current
-            LazyRow(Modifier.size(rowSize)) {
+            LazyRow(Modifier.size(rowSize), state) {
                 items(100) { index ->
                     Box(Modifier.size(itemSize).focusable()) {
                         Box(Modifier.size(itemSize).focusable().testTag("$index"))
@@ -97,7 +104,7 @@
         rule.runOnIdle { focusManager.moveFocus(FocusDirection.Right) }
 
         // Assert
-        rule.runOnIdle { assertThat(composedItems).containsExactly(5, 6) }
+        rule.runOnIdle { assertThat(composedItems).containsExactly(5) }
     }
 
     @Test
@@ -107,7 +114,7 @@
         lateinit var focusManager: FocusManager
         rule.setContent {
             focusManager = LocalFocusManager.current
-            LazyRow(Modifier.size(rowSize)) {
+            LazyRow(Modifier.size(rowSize), state) {
                 items(100) { index ->
                     Box(Modifier.size(itemSize).focusable()) {
                         Box(Modifier.size(itemSize).focusable()) {
@@ -125,6 +132,6 @@
         rule.runOnIdle { focusManager.moveFocus(FocusDirection.Right) }
 
         // Assert
-        rule.runOnIdle { assertThat(composedItems).containsExactly(5, 6) }
+        rule.runOnIdle { assertThat(composedItems).containsExactly(5) }
     }
 }
diff --git a/compose/foundation/foundation/integration-tests/lazy-tests/src/androidTest/kotlin/androidx/compose/foundation/lazy/list/LazyListNestedPrefetchingTest.kt b/compose/foundation/foundation/integration-tests/lazy-tests/src/androidTest/kotlin/androidx/compose/foundation/lazy/list/LazyListNestedPrefetchingTest.kt
index ebd489b..229f84a 100644
--- a/compose/foundation/foundation/integration-tests/lazy-tests/src/androidTest/kotlin/androidx/compose/foundation/lazy/list/LazyListNestedPrefetchingTest.kt
+++ b/compose/foundation/foundation/integration-tests/lazy-tests/src/androidTest/kotlin/androidx/compose/foundation/lazy/list/LazyListNestedPrefetchingTest.kt
@@ -26,6 +26,8 @@
 import androidx.compose.foundation.lazy.LazyListPrefetchStrategy
 import androidx.compose.foundation.lazy.LazyListState
 import androidx.compose.foundation.lazy.layout.NestedPrefetchScope
+import androidx.compose.foundation.lazy.layout.PrefetchScheduler
+import androidx.compose.foundation.lazy.layout.TestPrefetchScheduler
 import androidx.compose.runtime.Composable
 import androidx.compose.runtime.DisposableEffect
 import androidx.compose.runtime.remember
@@ -70,11 +72,19 @@
     private val itemsSizePx = 30
     private val itemsSizeDp = with(rule.density) { itemsSizePx.toDp() }
     private val activeNodes = mutableSetOf<String>()
-    private val activeMeasuredNodes = mutableSetOf<String>()
+    private val scheduler = TestPrefetchScheduler()
+
+    @OptIn(ExperimentalFoundationApi::class)
+    private val strategy = object : LazyListPrefetchStrategy by LazyListPrefetchStrategy() {
+        override val prefetchScheduler: PrefetchScheduler = scheduler
+    }
+
+    @OptIn(ExperimentalFoundationApi::class)
+    private fun createState(): LazyListState = LazyListState(prefetchStrategy = strategy)
 
     @Test
     fun nestedPrefetchingForwardAfterSmallScroll() {
-        val state = LazyListState()
+        val state = createState()
         composeList(state)
 
         val prefetchIndex = 2
@@ -85,7 +95,7 @@
                 }
             }
 
-            waitForPrefetch(tagFor(prefetchIndex))
+            waitForPrefetch()
         }
 
         // We want to make sure nested children were precomposed before the parent was premeasured
@@ -111,7 +121,7 @@
 
     @Test
     fun cancelingPrefetchCancelsItsNestedPrefetches() {
-        val state = LazyListState()
+        val state = createState()
         composeList(state)
 
         rule.runOnIdle {
@@ -122,7 +132,7 @@
             }
         }
 
-        waitForPrefetch(tagFor(3))
+        waitForPrefetch()
 
         rule.runOnIdle {
             assertThat(activeNodes).contains(tagFor(3))
@@ -141,7 +151,7 @@
             }
         }
 
-        waitForPrefetch(tagFor(7))
+        waitForPrefetch()
 
         rule.runOnIdle {
             runBlocking(AutoTestFrameClock()) {
@@ -160,7 +170,7 @@
     @OptIn(ExperimentalFoundationApi::class)
     @Test
     fun overridingNestedPrefetchCountIsRespected() {
-        val state = LazyListState()
+        val state = createState()
         composeList(
             state,
             createNestedLazyListState = {
@@ -177,7 +187,7 @@
                 }
             }
 
-            waitForPrefetch(tagFor(prefetchIndex))
+            waitForPrefetch()
         }
 
         // Since the nested prefetch count on the strategy is 1, we only expect index 0 to be
@@ -197,7 +207,7 @@
     fun nestedPrefetchIsMeasuredWithProvidedConstraints() {
         val nestedConstraints =
             Constraints(minWidth = 20, minHeight = 20, maxWidth = 20, maxHeight = 20)
-        val state = LazyListState()
+        val state = createState()
         composeList(
             state,
             createNestedLazyListState = {
@@ -214,7 +224,7 @@
                 }
             }
 
-            waitForPrefetch(tagFor(prefetchIndex))
+            waitForPrefetch()
         }
 
         assertThat(actions).containsExactly(
@@ -232,7 +242,7 @@
 
     @Test
     fun nestedPrefetchStartsFromFirstVisibleItemIndex() {
-        val state = LazyListState()
+        val state = createState()
         composeList(
             state,
             createNestedLazyListState = {
@@ -247,7 +257,7 @@
                 }
             }
 
-            waitForPrefetch(tagFor(prefetchIndex))
+            waitForPrefetch()
         }
 
         assertThat(actions).containsExactly(
@@ -273,9 +283,9 @@
         }
     }
 
-    private fun waitForPrefetch(tag: String) {
-        rule.waitUntil {
-            activeNodes.contains(tag) && activeMeasuredNodes.contains(tag)
+    private fun waitForPrefetch() {
+        rule.runOnIdle {
+            scheduler.executeActiveRequests()
         }
     }
 
@@ -332,17 +342,14 @@
             actions?.add(Action.Compose(index, nestedIndex))
             onDispose {
                 activeNodes.remove(tag)
-                activeMeasuredNodes.remove(tag)
             }
         }
     }
 
     private fun Modifier.trackWhenMeasured(index: Int, nestedIndex: Int? = null): Modifier {
-        val tag = tagFor(index, nestedIndex)
         return this then Modifier.layout { measurable, constraints ->
             actions?.add(Action.Measure(index, nestedIndex))
             val placeable = measurable.measure(constraints)
-            activeMeasuredNodes.add(tag)
             layout(placeable.width, placeable.height) {
                 placeable.place(0, 0)
             }
diff --git a/compose/foundation/foundation/integration-tests/lazy-tests/src/androidTest/kotlin/androidx/compose/foundation/lazy/list/LazyListPrefetchStrategyTest.kt b/compose/foundation/foundation/integration-tests/lazy-tests/src/androidTest/kotlin/androidx/compose/foundation/lazy/list/LazyListPrefetchStrategyTest.kt
index 31dd332..0d773c2 100644
--- a/compose/foundation/foundation/integration-tests/lazy-tests/src/androidTest/kotlin/androidx/compose/foundation/lazy/list/LazyListPrefetchStrategyTest.kt
+++ b/compose/foundation/foundation/integration-tests/lazy-tests/src/androidTest/kotlin/androidx/compose/foundation/lazy/list/LazyListPrefetchStrategyTest.kt
@@ -29,9 +29,10 @@
 import androidx.compose.foundation.lazy.LazyListState
 import androidx.compose.foundation.lazy.layout.LazyLayoutPrefetchState
 import androidx.compose.foundation.lazy.layout.NestedPrefetchScope
+import androidx.compose.foundation.lazy.layout.PrefetchScheduler
+import androidx.compose.foundation.lazy.layout.TestPrefetchScheduler
 import androidx.compose.foundation.lazy.list.LazyListPrefetchStrategyTest.RecordingLazyListPrefetchStrategy.Callback
 import androidx.compose.foundation.lazy.rememberLazyListState
-import androidx.compose.runtime.DisposableEffect
 import androidx.compose.runtime.MutableState
 import androidx.compose.runtime.mutableStateOf
 import androidx.compose.ui.Modifier
@@ -74,10 +75,11 @@
     private val itemsSizeDp = with(rule.density) { itemsSizePx.toDp() }
 
     lateinit var state: LazyListState
+    private val scheduler = TestPrefetchScheduler()
 
     @Test
     fun callbacksTriggered_whenScrollForwardsWithoutVisibleItemsChanged() {
-        val strategy = RecordingLazyListPrefetchStrategy()
+        val strategy = RecordingLazyListPrefetchStrategy(scheduler)
 
         composeList(prefetchStrategy = strategy)
 
@@ -104,7 +106,7 @@
 
     @Test
     fun callbacksTriggered_whenScrollBackwardsWithoutVisibleItemsChanged() {
-        val strategy = RecordingLazyListPrefetchStrategy()
+        val strategy = RecordingLazyListPrefetchStrategy(scheduler)
 
         composeList(firstItem = 10, itemOffset = 10, prefetchStrategy = strategy)
 
@@ -131,7 +133,7 @@
 
     @Test
     fun callbacksTriggered_whenScrollWithVisibleItemsChanged() {
-        val strategy = RecordingLazyListPrefetchStrategy()
+        val strategy = RecordingLazyListPrefetchStrategy(scheduler)
 
         composeList(prefetchStrategy = strategy)
 
@@ -161,7 +163,7 @@
 
     @Test
     fun callbacksTriggered_whenItemsChangedWithoutScroll() {
-        val strategy = RecordingLazyListPrefetchStrategy()
+        val strategy = RecordingLazyListPrefetchStrategy(scheduler)
         val numItems = mutableStateOf(100)
 
         composeList(prefetchStrategy = strategy, numItems = numItems)
@@ -196,20 +198,17 @@
             }
         }
 
-        waitForPrefetch(2)
+        waitForPrefetch()
         rule.onNodeWithTag("2")
             .assertExists()
     }
 
-    private fun waitForPrefetch(index: Int) {
-        rule.waitUntil {
-            activeNodes.contains(index) && activeMeasuredNodes.contains(index)
+    private fun waitForPrefetch() {
+        rule.runOnIdle {
+            scheduler.executeActiveRequests()
         }
     }
 
-    private val activeNodes = mutableSetOf<Int>()
-    private val activeMeasuredNodes = mutableSetOf<Int>()
-
     @OptIn(ExperimentalFoundationApi::class)
     private fun composeList(
         firstItem: Int = 0,
@@ -228,13 +227,6 @@
                 state,
             ) {
                 items(numItems.value) {
-                    DisposableEffect(it) {
-                        activeNodes.add(it)
-                        onDispose {
-                            activeNodes.remove(it)
-                            activeMeasuredNodes.remove(it)
-                        }
-                    }
                     Spacer(
                         Modifier
                             .mainAxisSize(itemsSizeDp)
@@ -242,7 +234,6 @@
                             .testTag("$it")
                             .layout { measurable, constraints ->
                                 val placeable = measurable.measure(constraints)
-                                activeMeasuredNodes.add(it)
                                 layout(placeable.width, placeable.height) {
                                     placeable.place(0, 0)
                                 }
@@ -256,7 +247,10 @@
     /**
      * LazyListPrefetchStrategy that just records callbacks without scheduling prefetches.
      */
-    private class RecordingLazyListPrefetchStrategy : LazyListPrefetchStrategy {
+    private class RecordingLazyListPrefetchStrategy(
+        override val prefetchScheduler: PrefetchScheduler?
+    ) : LazyListPrefetchStrategy {
+
         sealed interface Callback {
             data class OnScroll(val delta: Float, val visibleIndices: List<Int>) : Callback
             data class OnVisibleItemsUpdated(val visibleIndices: List<Int>) : Callback
diff --git a/compose/foundation/foundation/integration-tests/lazy-tests/src/androidTest/kotlin/androidx/compose/foundation/lazy/list/LazyListPrefetcherTest.kt b/compose/foundation/foundation/integration-tests/lazy-tests/src/androidTest/kotlin/androidx/compose/foundation/lazy/list/LazyListPrefetcherTest.kt
index 70a88d8..ef1aed5 100644
--- a/compose/foundation/foundation/integration-tests/lazy-tests/src/androidTest/kotlin/androidx/compose/foundation/lazy/list/LazyListPrefetcherTest.kt
+++ b/compose/foundation/foundation/integration-tests/lazy-tests/src/androidTest/kotlin/androidx/compose/foundation/lazy/list/LazyListPrefetcherTest.kt
@@ -22,7 +22,10 @@
 import androidx.compose.foundation.gestures.scrollBy
 import androidx.compose.foundation.layout.PaddingValues
 import androidx.compose.foundation.layout.Spacer
+import androidx.compose.foundation.lazy.LazyListPrefetchStrategy
 import androidx.compose.foundation.lazy.LazyListState
+import androidx.compose.foundation.lazy.layout.PrefetchScheduler
+import androidx.compose.foundation.lazy.layout.TestPrefetchScheduler
 import androidx.compose.foundation.lazy.rememberLazyListState
 import androidx.compose.runtime.DisposableEffect
 import androidx.compose.ui.Modifier
@@ -71,6 +74,13 @@
 
     lateinit var state: LazyListState
 
+    private val scheduler = TestPrefetchScheduler()
+
+    @OptIn(ExperimentalFoundationApi::class)
+    private val strategy = object : LazyListPrefetchStrategy by LazyListPrefetchStrategy() {
+        override val prefetchScheduler: PrefetchScheduler = scheduler
+    }
+
     @Test
     fun notPrefetchingForwardInitially() {
         composeList()
@@ -97,7 +107,7 @@
             }
         }
 
-        waitForPrefetch(preFetchIndex)
+        waitForPrefetch()
 
         rule.onNodeWithTag("$preFetchIndex")
             .assertExists()
@@ -115,7 +125,7 @@
             }
         }
 
-        waitForPrefetch(1)
+        waitForPrefetch()
 
         rule.onNodeWithTag("1")
             .assertExists()
@@ -134,7 +144,7 @@
             }
         }
         var prefetchIndex = initialIndex + 2
-        waitForPrefetch(prefetchIndex)
+        waitForPrefetch()
 
         rule.onNodeWithTag("$prefetchIndex")
             .assertExists()
@@ -149,7 +159,7 @@
         }
 
         prefetchIndex -= 3
-        waitForPrefetch(prefetchIndex)
+        waitForPrefetch()
 
         rule.onNodeWithTag("$prefetchIndex")
             .assertExists()
@@ -167,7 +177,7 @@
             }
         }
 
-        waitForPrefetch(2)
+        waitForPrefetch()
 
         rule.runOnIdle {
             runBlocking {
@@ -178,7 +188,7 @@
 
         val prefetchIndex = 3
 
-        waitForPrefetch(prefetchIndex)
+        waitForPrefetch()
 
         rule.onNodeWithTag("${prefetchIndex - 1}")
             .assertIsDisplayed()
@@ -198,7 +208,7 @@
             }
         }
 
-        waitForPrefetch(2)
+        waitForPrefetch()
 
         rule.runOnIdle {
             runBlocking {
@@ -207,7 +217,7 @@
             }
         }
 
-        waitForPrefetch(1)
+        waitForPrefetch()
 
         rule.onNodeWithTag("2")
             .assertIsDisplayed()
@@ -230,7 +240,7 @@
 
         var prefetchIndex = initialIndex + 2
 
-        waitForPrefetch(prefetchIndex)
+        waitForPrefetch()
 
         rule.onNodeWithTag("$prefetchIndex")
             .assertExists()
@@ -245,7 +255,7 @@
         }
 
         prefetchIndex -= 3
-        waitForPrefetch(prefetchIndex)
+        waitForPrefetch()
 
         rule.onNodeWithTag("$prefetchIndex")
             .assertExists()
@@ -281,7 +291,7 @@
         }
 
         var prefetchIndex = initialIndex + 1
-        waitForPrefetch(prefetchIndex)
+        waitForPrefetch()
 
         rule.onNodeWithTag("${prefetchIndex + 1}")
             .assertExists()
@@ -295,7 +305,7 @@
         }
 
         prefetchIndex -= 3
-        waitForPrefetch(prefetchIndex)
+        waitForPrefetch()
 
         rule.onNodeWithTag("$prefetchIndex")
             .assertExists()
@@ -458,7 +468,7 @@
             }
         }
 
-        waitForPrefetch(7)
+        waitForPrefetch()
 
         rule.runOnIdle {
             runBlocking(AutoTestFrameClock()) {
@@ -472,14 +482,13 @@
         }
     }
 
-    private fun waitForPrefetch(index: Int) {
-        rule.waitUntil {
-            activeNodes.contains(index) && activeMeasuredNodes.contains(index)
+    private fun waitForPrefetch() {
+        rule.runOnIdle {
+            scheduler.executeActiveRequests()
         }
     }
 
     private val activeNodes = mutableSetOf<Int>()
-    private val activeMeasuredNodes = mutableSetOf<Int>()
 
     private fun composeList(
         firstItem: Int = 0,
@@ -488,9 +497,11 @@
         contentPadding: PaddingValues = PaddingValues(0.dp)
     ) {
         rule.setContent {
+            @OptIn(ExperimentalFoundationApi::class)
             state = rememberLazyListState(
                 initialFirstVisibleItemIndex = firstItem,
-                initialFirstVisibleItemScrollOffset = itemOffset
+                initialFirstVisibleItemScrollOffset = itemOffset,
+                prefetchStrategy = strategy
             )
             LazyColumnOrRow(
                 Modifier.mainAxisSize(itemsSizeDp * 1.5f),
@@ -504,7 +515,6 @@
                         activeNodes.add(it)
                         onDispose {
                             activeNodes.remove(it)
-                            activeMeasuredNodes.remove(it)
                         }
                     }
                     Spacer(
@@ -514,7 +524,6 @@
                             .testTag("$it")
                             .layout { measurable, constraints ->
                                 val placeable = measurable.measure(constraints)
-                                activeMeasuredNodes.add(it)
                                 layout(placeable.width, placeable.height) {
                                     placeable.place(0, 0)
                                 }
diff --git a/compose/foundation/foundation/integration-tests/lazy-tests/src/androidTest/kotlin/androidx/compose/foundation/lazy/staggeredgrid/LazyStaggeredGridPrefetcherTest.kt b/compose/foundation/foundation/integration-tests/lazy-tests/src/androidTest/kotlin/androidx/compose/foundation/lazy/staggeredgrid/LazyStaggeredGridPrefetcherTest.kt
index 224a1cc..72084b7 100644
--- a/compose/foundation/foundation/integration-tests/lazy-tests/src/androidTest/kotlin/androidx/compose/foundation/lazy/staggeredgrid/LazyStaggeredGridPrefetcherTest.kt
+++ b/compose/foundation/foundation/integration-tests/lazy-tests/src/androidTest/kotlin/androidx/compose/foundation/lazy/staggeredgrid/LazyStaggeredGridPrefetcherTest.kt
@@ -14,6 +14,8 @@
  * limitations under the License.
  */
 
+@file:Suppress("INVISIBLE_MEMBER", "INVISIBLE_REFERENCE")
+
 package androidx.compose.foundation.lazy.staggeredgrid
 
 import androidx.compose.foundation.AutoTestFrameClock
@@ -22,7 +24,10 @@
 import androidx.compose.foundation.gestures.Orientation
 import androidx.compose.foundation.gestures.scrollBy
 import androidx.compose.foundation.layout.Spacer
+import androidx.compose.foundation.lazy.layout.TestPrefetchScheduler
+import androidx.compose.runtime.Composable
 import androidx.compose.runtime.DisposableEffect
+import androidx.compose.runtime.remember
 import androidx.compose.ui.Modifier
 import androidx.compose.ui.graphics.Color
 import androidx.compose.ui.layout.Remeasurement
@@ -63,6 +68,20 @@
     val itemsSizeDp = with(rule.density) { itemsSizePx.toDp() }
 
     internal lateinit var state: LazyStaggeredGridState
+    private val scheduler = TestPrefetchScheduler()
+
+    /** Remembers a [LazyStaggeredGridState] that uses the test [scheduler]. */
+    @OptIn(ExperimentalFoundationApi::class)
+    @Composable
+    fun rememberState(
+        initialFirstVisibleItemIndex: Int = 0,
+        initialFirstVisibleItemOffset: Int = 0
+    ): LazyStaggeredGridState = remember {
+        // The constructor takes arrays; wrap each initial value in a one-element array.
+        val initialIndices = intArrayOf(initialFirstVisibleItemIndex)
+        val initialOffsets = intArrayOf(initialFirstVisibleItemOffset)
+        LazyStaggeredGridState(initialIndices, initialOffsets, scheduler)
+    }
 
     @Test
     fun notPrefetchingForwardInitially() {
@@ -90,7 +109,7 @@
             }
         }
 
-        waitForPrefetch(5)
+        waitForPrefetch()
 
         rule.onNodeWithTag("4")
             .assertExists()
@@ -110,7 +129,7 @@
             }
         }
 
-        waitForPrefetch(2)
+        waitForPrefetch()
 
         rule.onNodeWithTag("2")
             .assertExists()
@@ -130,7 +149,7 @@
             }
         }
 
-        waitForPrefetch(9)
+        waitForPrefetch()
 
         rule.onNodeWithTag("8")
             .assertExists()
@@ -145,7 +164,7 @@
             }
         }
 
-        waitForPrefetch(2)
+        waitForPrefetch()
 
         rule.onNodeWithTag("2")
             .assertExists()
@@ -165,7 +184,7 @@
             }
         }
 
-        waitForPrefetch(4)
+        waitForPrefetch()
 
         rule.runOnIdle {
             runBlocking {
@@ -174,7 +193,7 @@
             }
         }
 
-        waitForPrefetch(6)
+        waitForPrefetch()
 
         rule.onNodeWithTag("4")
             .assertIsDisplayed()
@@ -194,7 +213,7 @@
             }
         }
 
-        waitForPrefetch(4)
+        waitForPrefetch()
 
         rule.runOnIdle {
             runBlocking {
@@ -203,7 +222,7 @@
             }
         }
 
-        waitForPrefetch(2)
+        waitForPrefetch()
 
         rule.onNodeWithTag("4")
             .assertIsDisplayed()
@@ -226,7 +245,7 @@
             }
         }
 
-        waitForPrefetch(13)
+        waitForPrefetch()
 
         rule.onNodeWithTag("12")
             .assertExists()
@@ -378,7 +397,7 @@
     fun snappingToOtherPositionWhilePrefetchIsScheduled() {
         val composedItems = mutableListOf<Int>()
         rule.setContent {
-            state = rememberLazyStaggeredGridState()
+            state = rememberState()
             LazyStaggeredGrid(
                 1,
                 Modifier.mainAxisSize(itemsSizeDp * 1.5f),
@@ -433,7 +452,7 @@
             }
         }
 
-        waitForPrefetch(13)
+        waitForPrefetch()
 
         rule.runOnIdle {
             runBlocking(AutoTestFrameClock()) {
@@ -450,7 +469,7 @@
     @Test
     fun scrollingWithStaggeredItemsPrefetchesCorrectly() {
         rule.setContent {
-            state = rememberLazyStaggeredGridState()
+            state = rememberState()
             LazyStaggeredGrid(
                 2,
                 Modifier.mainAxisSize(itemsSizeDp * 5f),
@@ -461,7 +480,6 @@
                         activeNodes.add(it)
                         onDispose {
                             activeNodes.remove(it)
-                            activeMeasuredNodes.remove(it)
                         }
                     }
                     Spacer(
@@ -471,7 +489,6 @@
                             .testTag("$it")
                             .layout { measurable, constraints ->
                                 val placeable = measurable.measure(constraints)
-                                activeMeasuredNodes.add(it)
                                 layout(placeable.width, placeable.height) {
                                     placeable.place(0, 0)
                                 }
@@ -495,8 +512,8 @@
             }
         }
 
-        waitForPrefetch(7)
-        waitForPrefetch(8)
+        waitForPrefetch()
+        waitForPrefetch()
 
         // ┌─┬─┐
         // │2├─┤
@@ -520,14 +537,14 @@
         // │6├─┤
         // └─┴─┘
 
-        waitForPrefetch(9)
+        waitForPrefetch()
     }
 
     @Test
     fun fullSpanIsPrefetchedCorrectly() {
         val nodeConstraints = mutableMapOf<Int, Constraints>()
         rule.setContent {
-            state = rememberLazyStaggeredGridState()
+            state = rememberState()
             LazyStaggeredGrid(
                 2,
                 Modifier.mainAxisSize(itemsSizeDp * 5f).crossAxisSize(itemsSizeDp * 2f),
@@ -546,7 +563,6 @@
                         activeNodes.add(it)
                         onDispose {
                             activeNodes.remove(it)
-                            activeMeasuredNodes.remove(it)
                         }
                     }
                     Spacer(
@@ -555,7 +571,6 @@
                             .testTag("$it")
                             .layout { measurable, constraints ->
                                 val placeable = measurable.measure(constraints)
-                                activeMeasuredNodes.add(it)
                                 nodeConstraints.put(it, constraints)
                                 layout(placeable.width, placeable.height) {
                                     placeable.place(0, 0)
@@ -577,7 +592,7 @@
         state.scrollBy(itemsSizeDp * 5f)
         assertThat(activeNodes).contains(9)
 
-        waitForPrefetch(10)
+        waitForPrefetch()
         val expectedConstraints = if (vertical) {
             Constraints.fixedWidth(itemsSizePx * 2)
         } else {
@@ -589,7 +604,7 @@
     @Test
     fun fullSpanIsPrefetchedCorrectly_scrollingBack() {
         rule.setContent {
-            state = rememberLazyStaggeredGridState()
+            state = rememberState()
             LazyStaggeredGrid(
                 2,
                 Modifier.mainAxisSize(itemsSizeDp * 5f),
@@ -608,7 +623,6 @@
                         activeNodes.add(it)
                         onDispose {
                             activeNodes.remove(it)
-                            activeMeasuredNodes.remove(it)
                         }
                     }
                     Spacer(
@@ -618,7 +632,6 @@
                             .testTag("$it")
                             .layout { measurable, constraints ->
                                 val placeable = measurable.measure(constraints)
-                                activeMeasuredNodes.add(it)
                                 layout(placeable.width, placeable.height) {
                                     placeable.place(0, 0)
                                 }
@@ -647,27 +660,26 @@
 
         state.scrollBy(-1.dp)
 
-        waitForPrefetch(10)
+        waitForPrefetch()
     }
 
-    private fun waitForPrefetch(index: Int) {
-        rule.waitUntil {
-            activeNodes.contains(index) && activeMeasuredNodes.contains(index)
+    private fun waitForPrefetch() {
+        rule.runOnIdle {
+            scheduler.executeActiveRequests()
         }
     }
 
     private val activeNodes = mutableSetOf<Int>()
-    private val activeMeasuredNodes = mutableSetOf<Int>()
 
     private fun composeStaggeredGrid(
         firstItem: Int = 0,
         itemOffset: Int = 0,
     ) {
-        state = LazyStaggeredGridState(
-            initialFirstVisibleItemIndex = firstItem,
-            initialFirstVisibleItemOffset = itemOffset
-        )
         rule.setContent {
+            state = rememberState(
+                initialFirstVisibleItemIndex = firstItem,
+                initialFirstVisibleItemOffset = itemOffset
+            )
             LazyStaggeredGrid(
                 2,
                 Modifier.mainAxisSize(itemsSizeDp * 1.5f),
@@ -678,7 +690,6 @@
                         activeNodes.add(it)
                         onDispose {
                             activeNodes.remove(it)
-                            activeMeasuredNodes.remove(it)
                         }
                     }
                     Spacer(
@@ -688,7 +699,6 @@
                             .testTag("$it")
                             .layout { measurable, constraints ->
                                 val placeable = measurable.measure(constraints)
-                                activeMeasuredNodes.add(it)
                                 layout(placeable.width, placeable.height) {
                                     placeable.place(0, 0)
                                 }
diff --git a/compose/foundation/foundation/src/androidInstrumentedTest/kotlin/androidx/compose/foundation/pager/BasePagerTest.kt b/compose/foundation/foundation/src/androidInstrumentedTest/kotlin/androidx/compose/foundation/pager/BasePagerTest.kt
index bf5658f..dd8c1ce 100644
--- a/compose/foundation/foundation/src/androidInstrumentedTest/kotlin/androidx/compose/foundation/pager/BasePagerTest.kt
+++ b/compose/foundation/foundation/src/androidInstrumentedTest/kotlin/androidx/compose/foundation/pager/BasePagerTest.kt
@@ -30,9 +30,11 @@
 import androidx.compose.foundation.layout.fillMaxHeight
 import androidx.compose.foundation.layout.fillMaxSize
 import androidx.compose.foundation.layout.fillMaxWidth
+import androidx.compose.foundation.lazy.layout.PrefetchScheduler
 import androidx.compose.foundation.text.BasicText
 import androidx.compose.runtime.Composable
 import androidx.compose.runtime.CompositionLocalProvider
+import androidx.compose.runtime.remember
 import androidx.compose.runtime.rememberCoroutineScope
 import androidx.compose.ui.Alignment
 import androidx.compose.ui.Modifier
@@ -131,13 +133,21 @@
         key: ((index: Int) -> Any)? = null,
         snapPosition: SnapPosition = config.snapPosition.first,
         flingBehavior: TargetedFlingBehavior? = null,
+        prefetchScheduler: PrefetchScheduler? = null,
         pageContent: @Composable PagerScope.(page: Int) -> Unit = { Page(index = it) }
     ) {
 
         rule.setContent {
-            val state = rememberPagerState(initialPage, initialPageOffsetFraction, pageCount).also {
-                pagerState = it
+            val state = if (prefetchScheduler == null) {
+                rememberPagerState(initialPage, initialPageOffsetFraction, pageCount)
+            } else {
+                remember {
+                    object : PagerState(initialPage, initialPageOffsetFraction, prefetchScheduler) {
+                        override val pageCount: Int get() = pageCount()
+                    }
+                }
             }
+            pagerState = state
             composeView = LocalView.current
             focusManager = LocalFocusManager.current
             CompositionLocalProvider(
diff --git a/compose/foundation/foundation/src/androidInstrumentedTest/kotlin/androidx/compose/foundation/pager/PagerAccessibilityTest.kt b/compose/foundation/foundation/src/androidInstrumentedTest/kotlin/androidx/compose/foundation/pager/PagerAccessibilityTest.kt
index d72d1e3..19d6748 100644
--- a/compose/foundation/foundation/src/androidInstrumentedTest/kotlin/androidx/compose/foundation/pager/PagerAccessibilityTest.kt
+++ b/compose/foundation/foundation/src/androidInstrumentedTest/kotlin/androidx/compose/foundation/pager/PagerAccessibilityTest.kt
@@ -17,6 +17,7 @@
 package androidx.compose.foundation.pager
 
 import android.view.accessibility.AccessibilityNodeProvider
+import androidx.compose.foundation.ExperimentalFoundationApi
 import androidx.compose.foundation.focusable
 import androidx.compose.foundation.layout.Box
 import androidx.compose.foundation.layout.fillMaxSize
@@ -41,6 +42,7 @@
 import org.junit.runner.RunWith
 import org.junit.runners.Parameterized
 
+@OptIn(ExperimentalFoundationApi::class)
 @RunWith(Parameterized::class)
 class PagerAccessibilityTest(config: ParamConfig) : BasePagerTest(config = config) {
 
diff --git a/compose/foundation/foundation/src/androidInstrumentedTest/kotlin/androidx/compose/foundation/pager/PagerPrefetcherTest.kt b/compose/foundation/foundation/src/androidInstrumentedTest/kotlin/androidx/compose/foundation/pager/PagerPrefetcherTest.kt
index 1f9cc0b..f9414d3 100644
--- a/compose/foundation/foundation/src/androidInstrumentedTest/kotlin/androidx/compose/foundation/pager/PagerPrefetcherTest.kt
+++ b/compose/foundation/foundation/src/androidInstrumentedTest/kotlin/androidx/compose/foundation/pager/PagerPrefetcherTest.kt
@@ -56,6 +56,7 @@
     var pageSizePx = 300
     val pageSizeDp = with(rule.density) { pageSizePx.toDp() }
     var touchSlope: Float = 0.0f
+    private val scheduler = TestPrefetchScheduler()
 
     @Test
     fun notPrefetchingForwardInitially() {
@@ -83,7 +84,7 @@
             }
         }
 
-        waitForPrefetch(preFetchIndex)
+        waitForPrefetch()
 
         rule.onNodeWithTag("$preFetchIndex")
             .assertExists()
@@ -102,7 +103,7 @@
             }
         }
 
-        waitForPrefetch(preFetchIndex)
+        waitForPrefetch()
 
         rule.onNodeWithTag("$preFetchIndex")
             .assertExists()
@@ -126,7 +127,7 @@
             up()
         }
 
-        waitForPrefetch(preFetchIndex)
+        waitForPrefetch()
 
         rule.onNodeWithTag("$preFetchIndex")
             .assertExists()
@@ -151,7 +152,7 @@
             up()
         }
 
-        waitForPrefetch(preFetchIndex)
+        waitForPrefetch()
 
         rule.onNodeWithTag("$preFetchIndex")
             .assertExists()
@@ -170,7 +171,7 @@
             }
         }
         var prefetchIndex = initialIndex + 2
-        waitForPrefetch(prefetchIndex)
+        waitForPrefetch()
 
         rule.onNodeWithTag("$prefetchIndex")
             .assertExists()
@@ -185,7 +186,7 @@
         }
 
         prefetchIndex -= 3
-        waitForPrefetch(prefetchIndex)
+        waitForPrefetch()
 
         rule.onNodeWithTag("$prefetchIndex")
             .assertExists()
@@ -203,7 +204,7 @@
             }
         }
 
-        waitForPrefetch(2)
+        waitForPrefetch()
 
         rule.runOnIdle {
             runBlocking {
@@ -214,7 +215,7 @@
 
         val prefetchIndex = 3
 
-        waitForPrefetch(prefetchIndex)
+        waitForPrefetch()
 
         rule.onNodeWithTag("${prefetchIndex - 1}")
             .assertIsDisplayed()
@@ -236,7 +237,7 @@
             }
         }
 
-        waitForPrefetch(preFetchIndex)
+        waitForPrefetch()
 
         rule.runOnIdle {
             runBlocking {
@@ -245,7 +246,7 @@
             }
         }
 
-        waitForPrefetch(preFetchIndex - 1)
+        waitForPrefetch()
 
         rule.onNodeWithTag("$preFetchIndex")
             .assertIsDisplayed()
@@ -268,7 +269,7 @@
 
         var prefetchIndex = initialIndex + 2
 
-        waitForPrefetch(prefetchIndex)
+        waitForPrefetch()
 
         rule.onNodeWithTag("$prefetchIndex")
             .assertExists()
@@ -283,7 +284,7 @@
         }
 
         prefetchIndex -= 3
-        waitForPrefetch(prefetchIndex)
+        waitForPrefetch()
 
         rule.onNodeWithTag("$prefetchIndex")
             .assertExists()
@@ -319,7 +320,7 @@
         }
 
         var prefetchIndex = initialIndex + 1
-        waitForPrefetch(prefetchIndex)
+        waitForPrefetch()
 
         rule.onNodeWithTag("${prefetchIndex + 1}")
             .assertExists()
@@ -333,7 +334,7 @@
         }
 
         prefetchIndex -= 3
-        waitForPrefetch(prefetchIndex)
+        waitForPrefetch()
 
         rule.onNodeWithTag("$prefetchIndex")
             .assertExists()
@@ -457,7 +458,7 @@
             }
         }
 
-        waitForPrefetch(7)
+        waitForPrefetch()
 
         rule.runOnIdle {
             runBlocking(AutoTestFrameClock()) {
@@ -477,14 +478,13 @@
         return consumed
     }
 
-    private fun waitForPrefetch(index: Int) {
-        rule.waitUntil {
-            activeNodes.contains(index) && activeMeasuredNodes.contains(index)
+    private fun waitForPrefetch() {
+        rule.runOnIdle {
+            scheduler.executeActiveRequests()
         }
     }
 
     private val activeNodes = mutableSetOf<Int>()
-    private val activeMeasuredNodes = mutableSetOf<Int>()
 
     private fun composePager(
         initialPage: Int = 0,
@@ -499,6 +499,7 @@
             beyondViewportPageCount = paramConfig.beyondViewportPageCount,
             initialPage = initialPage,
             initialPageOffsetFraction = initialPageOffsetFraction,
+            prefetchScheduler = scheduler,
             pageCount = { 100 },
             pageSize = {
                 object : PageSize {
@@ -516,7 +517,6 @@
                 activeNodes.add(it)
                 onDispose {
                     activeNodes.remove(it)
-                    activeMeasuredNodes.remove(it)
                 }
             }
 
@@ -527,7 +527,6 @@
                     .testTag("$it")
                     .layout { measurable, constraints ->
                         val placeable = measurable.measure(constraints)
-                        activeMeasuredNodes.add(it)
                         layout(placeable.width, placeable.height) {
                             placeable.place(0, 0)
                         }
diff --git a/compose/foundation/foundation/src/androidInstrumentedTest/kotlin/androidx/compose/foundation/pager/TestPrefetchScheduler.kt b/compose/foundation/foundation/src/androidInstrumentedTest/kotlin/androidx/compose/foundation/pager/TestPrefetchScheduler.kt
new file mode 100644
index 0000000..04c60b750
--- /dev/null
+++ b/compose/foundation/foundation/src/androidInstrumentedTest/kotlin/androidx/compose/foundation/pager/TestPrefetchScheduler.kt
@@ -0,0 +1,63 @@
+/*
+ * Copyright 2024 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package androidx.compose.foundation.pager
+
+import androidx.compose.foundation.ExperimentalFoundationApi
+import androidx.compose.foundation.lazy.layout.PrefetchRequest
+import androidx.compose.foundation.lazy.layout.PrefetchRequestScope
+import androidx.compose.foundation.lazy.layout.PrefetchScheduler
+
+/**
+ * A [PrefetchScheduler] for tests that never runs prefetch work on its own.
+ *
+ * Requests passed to [schedulePrefetch] are only queued; they run when a test calls
+ * [executeActiveRequests], with an unlimited time budget, so a test decides exactly when
+ * prefetching happens instead of waiting for frames.
+ */
+@OptIn(ExperimentalFoundationApi::class)
+internal class TestPrefetchScheduler : PrefetchScheduler {
+
+    private val activeRequests = mutableListOf<PrefetchRequest>()
+
+    override fun schedulePrefetch(prefetchRequest: PrefetchRequest) {
+        activeRequests.add(prefetchRequest)
+    }
+
+    /**
+     * Executes every request queued so far, once each.
+     *
+     * The queue is snapshotted and cleared before anything runs, so a request scheduled while
+     * another one is executing neither triggers a [ConcurrentModificationException] nor gets
+     * lost; it stays queued for the next call. A request whose execution reports remaining
+     * work is queued again so that work is not silently discarded.
+     */
+    fun executeActiveRequests() {
+        val requests = activeRequests.toList()
+        activeRequests.clear()
+        requests.forEach { request ->
+            val hasMoreWorkToDo = with(request) { scope.execute() }
+            if (hasMoreWorkToDo) {
+                activeRequests.add(request)
+            }
+        }
+    }
+
+    /** Scope that always reports an unlimited amount of remaining time. */
+    private val scope = object : PrefetchRequestScope {
+        override fun availableTimeNanos(): Long = Long.MAX_VALUE
+    }
+}
diff --git a/compose/foundation/foundation/src/androidMain/kotlin/androidx/compose/foundation/lazy/layout/PrefetchScheduler.android.kt b/compose/foundation/foundation/src/androidMain/kotlin/androidx/compose/foundation/lazy/layout/PrefetchScheduler.android.kt
index 54abefd..fd93543 100644
--- a/compose/foundation/foundation/src/androidMain/kotlin/androidx/compose/foundation/lazy/layout/PrefetchScheduler.android.kt
+++ b/compose/foundation/foundation/src/androidMain/kotlin/androidx/compose/foundation/lazy/layout/PrefetchScheduler.android.kt
@@ -126,16 +126,19 @@
         }
         val latestFrameVsyncNs = TimeUnit.MILLISECONDS.toNanos(view.drawingTime)
         val nextFrameNs = latestFrameVsyncNs + frameIntervalNs
-        val oneOverTimeTaskAllowed = System.nanoTime() > nextFrameNs
-        val scope = PrefetchRequestScopeImpl(nextFrameNs, oneOverTimeTaskAllowed)
+        val scope = PrefetchRequestScopeImpl(nextFrameNs)
         var scheduleForNextFrame = false
         while (prefetchRequests.isNotEmpty() && !scheduleForNextFrame) {
-            val request = prefetchRequests[0]
-            val hasMoreWorkToDo = with(request) { scope.execute() }
-            if (hasMoreWorkToDo) {
-                scheduleForNextFrame = true
+            if (scope.availableTimeNanos() > 0) {
+                val request = prefetchRequests[0]
+                val hasMoreWorkToDo = with(request) { scope.execute() }
+                if (hasMoreWorkToDo) {
+                    scheduleForNextFrame = true
+                } else {
+                    prefetchRequests.removeAt(0)
+                }
             } else {
-                prefetchRequests.removeAt(0)
+                scheduleForNextFrame = true
             }
         }
 
@@ -182,24 +185,10 @@
 
     class PrefetchRequestScopeImpl(
         private val nextFrameTimeNs: Long,
-        isOneOverTimeTaskAllowed: Boolean
     ) : PrefetchRequestScope {
 
-        private var canDoOverTimeTask = isOneOverTimeTaskAllowed
-
-        override val availableTimeNanos: Long
-            get() {
-                // This logic is meant to be temporary until we replace the isOneOverTimeTaskAllowed
-                // logic with something more general. For now, we assume that a PrefetchRequest
-                // impl will check availableTimeNanos once per task and we give it a large amount
-                // of time the first time it checks if we allow an overtime task.
-                return if (canDoOverTimeTask) {
-                    canDoOverTimeTask = false
-                    Long.MAX_VALUE
-                } else {
-                    max(0, nextFrameTimeNs - System.nanoTime())
-                }
-            }
+        override fun availableTimeNanos() =
+            max(0, nextFrameTimeNs - System.nanoTime())
     }
 
     companion object {
diff --git a/compose/foundation/foundation/src/commonMain/kotlin/androidx/compose/foundation/lazy/LazyListPrefetchStrategy.kt b/compose/foundation/foundation/src/commonMain/kotlin/androidx/compose/foundation/lazy/LazyListPrefetchStrategy.kt
index dc29f8c..83840e9 100644
--- a/compose/foundation/foundation/src/commonMain/kotlin/androidx/compose/foundation/lazy/LazyListPrefetchStrategy.kt
+++ b/compose/foundation/foundation/src/commonMain/kotlin/androidx/compose/foundation/lazy/LazyListPrefetchStrategy.kt
@@ -147,21 +147,38 @@
             } else {
                 layoutInfo.visibleItemsInfo.first().index - 1
             }
-            if (indexToPrefetch != [email protected] &&
-                indexToPrefetch in 0 until layoutInfo.totalItemsCount
-            ) {
-                if (wasScrollingForward != scrollingForward) {
-                    // the scrolling direction has been changed which means the last prefetched
-                    // is not going to be reached anytime soon so it is safer to dispose it.
-                    // if this item is already visible it is safe to call the method anyway
-                    // as it will be no-op
-                    currentPrefetchHandle?.cancel()
+            if (indexToPrefetch in 0 until layoutInfo.totalItemsCount) {
+                if (indexToPrefetch != [email protected]) {
+                    if (wasScrollingForward != scrollingForward) {
+                        // the scrolling direction has been changed which means the last prefetched
+                        // is not going to be reached anytime soon so it is safer to dispose it.
+                        // if this item is already visible it is safe to call the method anyway
+                        // as it will be no-op
+                        currentPrefetchHandle?.cancel()
+                    }
+                    [email protected] = scrollingForward
+                    [email protected] = indexToPrefetch
+                    currentPrefetchHandle = schedulePrefetch(
+                        indexToPrefetch
+                    )
                 }
-                [email protected] = scrollingForward
-                [email protected] = indexToPrefetch
-                currentPrefetchHandle = schedulePrefetch(
-                    indexToPrefetch
-                )
+                if (scrollingForward) {
+                    val lastItem = layoutInfo.visibleItemsInfo.last()
+                    val spacing = layoutInfo.mainAxisItemSpacing
+                    val distanceToPrefetchItem =
+                        lastItem.offset + lastItem.size + spacing - layoutInfo.viewportEndOffset
+                    // if in the next frame we will get the same delta will we reach the item?
+                    if (distanceToPrefetchItem < -delta) {
+                        currentPrefetchHandle?.markAsUrgent()
+                    }
+                } else {
+                    val firstItem = layoutInfo.visibleItemsInfo.first()
+                    val distanceToPrefetchItem = layoutInfo.viewportStartOffset - firstItem.offset
+                    // if in the next frame we will get the same delta will we reach the item?
+                    if (distanceToPrefetchItem < delta) {
+                        currentPrefetchHandle?.markAsUrgent()
+                    }
+                }
             }
         }
     }
diff --git a/compose/foundation/foundation/src/commonMain/kotlin/androidx/compose/foundation/lazy/grid/LazyGridState.kt b/compose/foundation/foundation/src/commonMain/kotlin/androidx/compose/foundation/lazy/grid/LazyGridState.kt
index ab042b2..d12d075 100644
--- a/compose/foundation/foundation/src/commonMain/kotlin/androidx/compose/foundation/lazy/grid/LazyGridState.kt
+++ b/compose/foundation/foundation/src/commonMain/kotlin/androidx/compose/foundation/lazy/grid/LazyGridState.kt
@@ -22,6 +22,7 @@
 import androidx.compose.foundation.gestures.Orientation
 import androidx.compose.foundation.gestures.ScrollScope
 import androidx.compose.foundation.gestures.ScrollableState
+import androidx.compose.foundation.gestures.snapping.offsetOnMainAxis
 import androidx.compose.foundation.gestures.stopScroll
 import androidx.compose.foundation.interaction.InteractionSource
 import androidx.compose.foundation.interaction.MutableInteractionSource
@@ -31,6 +32,7 @@
 import androidx.compose.foundation.lazy.layout.LazyLayoutPinnedItemList
 import androidx.compose.foundation.lazy.layout.LazyLayoutPrefetchState
 import androidx.compose.foundation.lazy.layout.ObservableScopeInvalidator
+import androidx.compose.foundation.lazy.layout.PrefetchScheduler
 import androidx.compose.foundation.lazy.layout.animateScrollToItem
 import androidx.compose.runtime.Composable
 import androidx.compose.runtime.Stable
@@ -81,19 +83,26 @@
  * A state object that can be hoisted to control and observe scrolling.
  *
  * In most cases, this will be created via [rememberLazyGridState].
- *
- * @param firstVisibleItemIndex the initial value for [LazyGridState.firstVisibleItemIndex]
- * @param firstVisibleItemScrollOffset the initial value for
- * [LazyGridState.firstVisibleItemScrollOffset]
  */
 @OptIn(ExperimentalFoundationApi::class)
 @Stable
-class LazyGridState constructor(
+class LazyGridState internal constructor(
     firstVisibleItemIndex: Int = 0,
-    firstVisibleItemScrollOffset: Int = 0
+    firstVisibleItemScrollOffset: Int = 0,
+    prefetchScheduler: PrefetchScheduler?,
 ) : ScrollableState {
 
     /**
+     * @param firstVisibleItemIndex the initial value for [LazyGridState.firstVisibleItemIndex]
+     * @param firstVisibleItemScrollOffset the initial value for
+     * [LazyGridState.firstVisibleItemScrollOffset]
+     */
+    constructor(
+        firstVisibleItemIndex: Int = 0,
+        firstVisibleItemScrollOffset: Int = 0
+    ) : this(firstVisibleItemIndex, firstVisibleItemScrollOffset, null)
+
+    /**
      * The holder class for the current scroll position.
      */
     private val scrollPosition =
@@ -413,23 +422,40 @@
                 }
                 closestNextItemToPrefetch = info.visibleItemsInfo.first().index - 1
             }
-            if (lineToPrefetch != this.lineToPrefetch &&
-                closestNextItemToPrefetch in 0 until info.totalItemsCount
-            ) {
-                if (wasScrollingForward != scrollingForward) {
-                    // the scrolling direction has been changed which means the last prefetched
-                    // is not going to be reached anytime soon so it is safer to dispose it.
-                    // if this line is already visible it is safe to call the method anyway
-                    // as it will be no-op
-                    currentLinePrefetchHandles.forEach { it.cancel() }
+            if (closestNextItemToPrefetch in 0 until info.totalItemsCount) {
+                if (lineToPrefetch != this.lineToPrefetch) {
+                    if (wasScrollingForward != scrollingForward) {
+                        // the scrolling direction has been changed which means the last prefetched
+                        // is not going to be reached anytime soon so it is safer to dispose it.
+                        // if this line is already visible it is safe to call the method anyway
+                        // as it will be no-op
+                        currentLinePrefetchHandles.forEach { it.cancel() }
+                    }
+                    this.wasScrollingForward = scrollingForward
+                    this.lineToPrefetch = lineToPrefetch
+                    currentLinePrefetchHandles.clear()
+                    info.prefetchInfoRetriever(lineToPrefetch).fastForEach {
+                        currentLinePrefetchHandles.add(
+                            prefetchState.schedulePrefetch(it.first, it.second)
+                        )
+                    }
                 }
-                this.wasScrollingForward = scrollingForward
-                this.lineToPrefetch = lineToPrefetch
-                currentLinePrefetchHandles.clear()
-                info.prefetchInfoRetriever(lineToPrefetch).fastForEach {
-                    currentLinePrefetchHandles.add(
-                        prefetchState.schedulePrefetch(it.first, it.second)
-                    )
+                if (scrollingForward) {
+                    val lastItem = info.visibleItemsInfo.last()
+                    val distanceToPrefetchItem = lastItem.offsetOnMainAxis(info.orientation) +
+                        lastItem.mainAxisSizeWithSpacings - info.viewportEndOffset
+                    // if in the next frame we will get the same delta will we reach the item?
+                    if (distanceToPrefetchItem < -delta) {
+                        currentLinePrefetchHandles.forEach { it.markAsUrgent() }
+                    }
+                } else {
+                    val firstItem = info.visibleItemsInfo.first()
+                    val distanceToPrefetchItem = info.viewportStartOffset -
+                        firstItem.offsetOnMainAxis(info.orientation)
+                    // if in the next frame we will get the same delta will we reach the item?
+                    if (distanceToPrefetchItem < delta) {
+                        currentLinePrefetchHandles.forEach { it.markAsUrgent() }
+                    }
                 }
             }
         }
@@ -454,7 +480,7 @@
         }
     }
 
-    internal val prefetchState = LazyLayoutPrefetchState()
+    internal val prefetchState = LazyLayoutPrefetchState(prefetchScheduler)
 
     private val numOfItemsToTeleport: Int get() = 100 * slotsPerLine
 
diff --git a/compose/foundation/foundation/src/commonMain/kotlin/androidx/compose/foundation/lazy/layout/LazyLayoutPrefetchState.kt b/compose/foundation/foundation/src/commonMain/kotlin/androidx/compose/foundation/lazy/layout/LazyLayoutPrefetchState.kt
index 11c2a13..20d6ca5 100644
--- a/compose/foundation/foundation/src/commonMain/kotlin/androidx/compose/foundation/lazy/layout/LazyLayoutPrefetchState.kt
+++ b/compose/foundation/foundation/src/commonMain/kotlin/androidx/compose/foundation/lazy/layout/LazyLayoutPrefetchState.kt
@@ -79,6 +79,15 @@
          * was precomposed already it will be disposed.
          */
         fun cancel()
+
+        /**
+         * Marks this prefetch request as urgent, signalling that the requested item is
+         * expected to be needed during the next frame.
+         *
+         * Urgent requests may proceed with the prefetch even if the time available in the
+         * frame is less than the time similar prefetch requests take on average.
+         */
+        fun markAsUrgent()
     }
 
     private inner class NestedPrefetchScopeImpl : NestedPrefetchScope {
@@ -169,6 +178,7 @@
 @ExperimentalFoundationApi
 private object DummyHandle : PrefetchHandle {
     override fun cancel() {}
+    override fun markAsUrgent() {}
 }
 
 /**
@@ -212,6 +222,7 @@
         private val isComposed get() = precomposeHandle != null
         private var hasResolvedNestedPrefetches = false
         private var nestedPrefetchController: NestedPrefetchController? = null
+        private var isUrgent = false
 
         private val isValid
             get() = !isCanceled &&
@@ -225,13 +236,24 @@
             }
         }
 
+        override fun markAsUrgent() {
+            isUrgent = true
+        }
+
+        private fun PrefetchRequestScope.shouldExecute(average: Long): Boolean {
+            val available = availableTimeNanos()
+            // Urgent requests skip the comparison against the average, but still require some
+            // available time: if there is none, return early and let the next frame do the work.
+            return (isUrgent && available > 0) || average < available
+        }
+
         override fun PrefetchRequestScope.execute(): Boolean {
             if (!isValid) {
                 return false
             }
 
             if (!isComposed) {
-                if (prefetchMetrics.averageCompositionTimeNanos < availableTimeNanos) {
+                if (shouldExecute(prefetchMetrics.averageCompositionTimeNanos)) {
                     prefetchMetrics.recordCompositionTiming {
                         trace("compose:lazy:prefetch:compose") {
                             performComposition()
@@ -242,27 +264,35 @@
                 }
             }
 
-            // Nested prefetch logic is best-effort: if nested LazyLayout children are
-            // added/removed/updated after we've resolved nested prefetch states here or resolved
-            // nestedPrefetchRequests below, those changes won't be taken into account.
-            if (!hasResolvedNestedPrefetches) {
-                if (availableTimeNanos > 0) {
-                    trace("compose:lazy:prefetch:resolve-nested") {
-                        nestedPrefetchController = resolveNestedPrefetchStates()
-                        hasResolvedNestedPrefetches = true
+            // If the request is urgent, we'd better proceed with measuring straight away instead
+            // of spending time trying to split the work further via nested prefetch. Nested
+            // prefetch is always an estimate and could potentially do work we won't need in the
+            // end, whereas measuring only does exactly the work that is needed (including
+            // composing nested lazy layouts).
+            if (!isUrgent) {
+                // Nested prefetch logic is best-effort: if nested LazyLayout children are
+                // added/removed/updated after we've resolved nested prefetch states here or resolved
+                // nestedPrefetchRequests below, those changes won't be taken into account.
+                if (!hasResolvedNestedPrefetches) {
+                    if (availableTimeNanos() > 0) {
+                        trace("compose:lazy:prefetch:resolve-nested") {
+                            nestedPrefetchController = resolveNestedPrefetchStates()
+                            hasResolvedNestedPrefetches = true
+                        }
+                    } else {
+                        return true
                     }
-                } else {
+                }
+
+                val hasMoreWork =
+                    nestedPrefetchController?.run { executeNestedPrefetches() } ?: false
+                if (hasMoreWork) {
                     return true
                 }
             }
 
-            val hasMoreWork = nestedPrefetchController?.run { executeNestedPrefetches() } ?: false
-            if (hasMoreWork) {
-                return true
-            }
-
             if (!isMeasured && constraints != null) {
-                if (prefetchMetrics.averageMeasureTimeNanos < availableTimeNanos) {
+                if (shouldExecute(prefetchMetrics.averageMeasureTimeNanos)) {
                     prefetchMetrics.recordMeasureTiming {
                         trace("compose:lazy:prefetch:measure") {
                             performMeasure(constraints)
@@ -349,7 +379,7 @@
                 trace("compose:lazy:prefetch:nested") {
                     while (stateIndex < states.size) {
                         if (requestsByState[stateIndex] == null) {
-                            if (availableTimeNanos <= 0) {
+                            if (availableTimeNanos() <= 0) {
                                 // When we have time again, we'll resolve nested requests for this
                                 // state
                                 return true
diff --git a/compose/foundation/foundation/src/commonMain/kotlin/androidx/compose/foundation/lazy/layout/PrefetchScheduler.kt b/compose/foundation/foundation/src/commonMain/kotlin/androidx/compose/foundation/lazy/layout/PrefetchScheduler.kt
index d3497f8..131eb4f 100644
--- a/compose/foundation/foundation/src/commonMain/kotlin/androidx/compose/foundation/lazy/layout/PrefetchScheduler.kt
+++ b/compose/foundation/foundation/src/commonMain/kotlin/androidx/compose/foundation/lazy/layout/PrefetchScheduler.kt
@@ -75,5 +75,5 @@
      * How much time is available to do prefetch work. Implementations of [PrefetchRequest] should
      * do their best to fit their work into this time without going over.
      */
-    val availableTimeNanos: Long
+    fun availableTimeNanos(): Long
 }
diff --git a/compose/foundation/foundation/src/commonMain/kotlin/androidx/compose/foundation/lazy/staggeredgrid/LazyStaggeredGridState.kt b/compose/foundation/foundation/src/commonMain/kotlin/androidx/compose/foundation/lazy/staggeredgrid/LazyStaggeredGridState.kt
index 09cb286..3a15ecb 100644
--- a/compose/foundation/foundation/src/commonMain/kotlin/androidx/compose/foundation/lazy/staggeredgrid/LazyStaggeredGridState.kt
+++ b/compose/foundation/foundation/src/commonMain/kotlin/androidx/compose/foundation/lazy/staggeredgrid/LazyStaggeredGridState.kt
@@ -34,6 +34,7 @@
 import androidx.compose.foundation.lazy.layout.LazyLayoutPrefetchState
 import androidx.compose.foundation.lazy.layout.LazyLayoutPrefetchState.PrefetchHandle
 import androidx.compose.foundation.lazy.layout.ObservableScopeInvalidator
+import androidx.compose.foundation.lazy.layout.PrefetchScheduler
 import androidx.compose.foundation.lazy.layout.animateScrollToItem
 import androidx.compose.foundation.lazy.staggeredgrid.LazyStaggeredGridLaneInfo.Companion.FullSpan
 import androidx.compose.foundation.lazy.staggeredgrid.LazyStaggeredGridLaneInfo.Companion.Unset
@@ -82,9 +83,10 @@
  * In most cases, it should be created via [rememberLazyStaggeredGridState].
  */
 @OptIn(ExperimentalFoundationApi::class)
-class LazyStaggeredGridState private constructor(
+class LazyStaggeredGridState internal constructor(
     initialFirstVisibleItems: IntArray,
     initialFirstVisibleOffsets: IntArray,
+    prefetchScheduler: PrefetchScheduler?
 ) : ScrollableState {
     /**
      * @param initialFirstVisibleItemIndex initial value for [firstVisibleItemIndex]
@@ -95,7 +97,8 @@
         initialFirstVisibleItemOffset: Int = 0
     ) : this(
         intArrayOf(initialFirstVisibleItemIndex),
-        intArrayOf(initialFirstVisibleItemOffset)
+        intArrayOf(initialFirstVisibleItemOffset),
+        null
     )
 
     /**
@@ -178,7 +181,7 @@
     internal var prefetchingEnabled: Boolean = true
 
     /** prefetch state used for precomputing items in the direction of scroll */
-    internal val prefetchState: LazyLayoutPrefetchState = LazyLayoutPrefetchState()
+    internal val prefetchState: LazyLayoutPrefetchState = LazyLayoutPrefetchState(prefetchScheduler)
 
     /** state controlling the scroll */
     private val scrollableState = ScrollableState { -onScroll(-it) }
@@ -584,7 +587,7 @@
                 )
             },
             restore = {
-                LazyStaggeredGridState(it[0], it[1])
+                LazyStaggeredGridState(it[0], it[1], null)
             }
         )
     }
diff --git a/compose/foundation/foundation/src/commonMain/kotlin/androidx/compose/foundation/pager/PagerState.kt b/compose/foundation/foundation/src/commonMain/kotlin/androidx/compose/foundation/pager/PagerState.kt
index edf220c..a4310c5 100644
--- a/compose/foundation/foundation/src/commonMain/kotlin/androidx/compose/foundation/pager/PagerState.kt
+++ b/compose/foundation/foundation/src/commonMain/kotlin/androidx/compose/foundation/pager/PagerState.kt
@@ -36,6 +36,7 @@
 import androidx.compose.foundation.lazy.layout.LazyLayoutPinnedItemList
 import androidx.compose.foundation.lazy.layout.LazyLayoutPrefetchState
 import androidx.compose.foundation.lazy.layout.ObservableScopeInvalidator
+import androidx.compose.foundation.lazy.layout.PrefetchScheduler
 import androidx.compose.runtime.Composable
 import androidx.compose.runtime.Stable
 import androidx.compose.runtime.derivedStateOf
@@ -147,18 +148,26 @@
 
 /**
  * The state that can be used to control [VerticalPager] and [HorizontalPager]
- * @param currentPage The initial page to be displayed
- * @param currentPageOffsetFraction The offset of the initial page with respect to the start of
- * the layout.
  */
 @OptIn(ExperimentalFoundationApi::class)
 @Stable
-abstract class PagerState(
+abstract class PagerState internal constructor(
     currentPage: Int = 0,
-    @FloatRange(from = -0.5, to = 0.5) currentPageOffsetFraction: Float = 0f
+    @FloatRange(from = -0.5, to = 0.5) currentPageOffsetFraction: Float = 0f,
+    prefetchScheduler: PrefetchScheduler? = null
 ) : ScrollableState {
 
     /**
+     * @param currentPage The initial page to be displayed
+     * @param currentPageOffsetFraction The offset of the initial page with respect to the start of
+     * the layout.
+     */
+    constructor(
+        currentPage: Int = 0,
+        @FloatRange(from = -0.5, to = 0.5) currentPageOffsetFraction: Float = 0f
+    ) : this(currentPage, currentPageOffsetFraction, null)
+
+    /**
      * The total amount of pages present in this pager. The source of this data should be
      * observable.
      */
@@ -431,7 +440,7 @@
      */
     val currentPageOffsetFraction: Float get() = scrollPosition.currentPageOffsetFraction
 
-    internal val prefetchState = LazyLayoutPrefetchState()
+    internal val prefetchState = LazyLayoutPrefetchState(prefetchScheduler)
 
     internal val beyondBoundsInfo = LazyLayoutBeyondBoundsInfo()
 
@@ -716,21 +725,38 @@
             } else {
                 info.visiblePagesInfo.first().index - info.beyondViewportPageCount - PagesToPrefetch
             }
-            if (indexToPrefetch != this.indexToPrefetch &&
-                indexToPrefetch in 0 until pageCount
-            ) {
-                if (wasPrefetchingForward != isPrefetchingForward) {
-                    // the scrolling direction has been changed which means the last prefetched
-                    // is not going to be reached anytime soon so it is safer to dispose it.
-                    // if this item is already visible it is safe to call the method anyway
-                    // as it will be no-op
-                    currentPrefetchHandle?.cancel()
+            if (indexToPrefetch in 0 until pageCount) {
+                if (indexToPrefetch != this.indexToPrefetch) {
+                    if (wasPrefetchingForward != isPrefetchingForward) {
+                        // the scrolling direction has been changed which means the last prefetched
+                        // is not going to be reached anytime soon so it is safer to dispose it.
+                        // if this item is already visible it is safe to call the method anyway
+                        // as it will be no-op
+                        currentPrefetchHandle?.cancel()
+                    }
+                    this.wasPrefetchingForward = isPrefetchingForward
+                    this.indexToPrefetch = indexToPrefetch
+                    currentPrefetchHandle = prefetchState.schedulePrefetch(
+                        indexToPrefetch, premeasureConstraints
+                    )
                 }
-                this.wasPrefetchingForward = isPrefetchingForward
-                this.indexToPrefetch = indexToPrefetch
-                currentPrefetchHandle = prefetchState.schedulePrefetch(
-                    indexToPrefetch, premeasureConstraints
-                )
+                if (isPrefetchingForward) {
+                    val lastItem = info.visiblePagesInfo.last()
+                    val pageSize = info.pageSize + info.pageSpacing
+                    val distanceToReachNextItem =
+                        lastItem.offset + pageSize - info.viewportEndOffset
+                    // if we get the same delta in the next frame, will we reach the item?
+                    if (distanceToReachNextItem < delta) {
+                        currentPrefetchHandle?.markAsUrgent()
+                    }
+                } else {
+                    val firstItem = info.visiblePagesInfo.first()
+                    val distanceToReachNextItem = info.viewportStartOffset - firstItem.offset
+                    // if we get the same delta in the next frame, will we reach the item?
+                    if (distanceToReachNextItem < -delta) {
+                        currentPrefetchHandle?.markAsUrgent()
+                    }
+                }
             }
         }
     }
diff --git a/compose/integration-tests/macrobenchmark/src/main/java/androidx/compose/integration/macrobenchmark/ComplexNestedListsScrollBenchmark.kt b/compose/integration-tests/macrobenchmark/src/main/java/androidx/compose/integration/macrobenchmark/ComplexNestedListsScrollBenchmark.kt
index 003b307..c070d43 100644
--- a/compose/integration-tests/macrobenchmark/src/main/java/androidx/compose/integration/macrobenchmark/ComplexNestedListsScrollBenchmark.kt
+++ b/compose/integration-tests/macrobenchmark/src/main/java/androidx/compose/integration/macrobenchmark/ComplexNestedListsScrollBenchmark.kt
@@ -19,6 +19,8 @@
 import android.content.Intent
 import android.graphics.Point
 import androidx.benchmark.macro.CompilationMode
+import androidx.benchmark.macro.ExperimentalMetricApi
+import androidx.benchmark.macro.FrameTimingGfxInfoMetric
 import androidx.benchmark.macro.FrameTimingMetric
 import androidx.benchmark.macro.junit4.MacrobenchmarkRule
 import androidx.test.platform.app.InstrumentationRegistry
@@ -41,11 +43,12 @@
         device = UiDevice.getInstance(instrumentation)
     }
 
+    @OptIn(ExperimentalMetricApi::class)
     @Test
     fun start() {
         benchmarkRule.measureRepeated(
             packageName = PACKAGE_NAME,
-            metrics = listOf(FrameTimingMetric()),
+            metrics = listOf(FrameTimingMetric(), FrameTimingGfxInfoMetric()),
             compilationMode = CompilationMode.Full(),
             iterations = 8,
             setupBlock = {