Merge "Refactor VelocityTracker Logic Into One Dimension" into androidx-main
diff --git a/compose/ui/ui/src/commonMain/kotlin/androidx/compose/ui/input/pointer/util/VelocityTracker.kt b/compose/ui/ui/src/commonMain/kotlin/androidx/compose/ui/input/pointer/util/VelocityTracker.kt
index 4edb9c4..0e9ab68 100644
--- a/compose/ui/ui/src/commonMain/kotlin/androidx/compose/ui/input/pointer/util/VelocityTracker.kt
+++ b/compose/ui/ui/src/commonMain/kotlin/androidx/compose/ui/input/pointer/util/VelocityTracker.kt
@@ -46,9 +46,9 @@
  */
 class VelocityTracker {
 
-    // Circular buffer; current sample at index.
-    private val samples: Array<PointAtTime?> = Array(HistorySize) { null }
-    private var index: Int = 0
+    private val xVelocityTracker = VelocityTracker1D()
+    private val yVelocityTracker = VelocityTracker1D()
+
     internal var currentPointerPositionAccumulator = Offset.Zero
 
     /**
@@ -62,8 +62,8 @@
     //   positions. For velocity tracking, the only thing that is important is the change in
     //   position over time.
     fun addPosition(timeMillis: Long, position: Offset) {
-        index = (index + 1) % HistorySize
-        samples[index] = PointAtTime(position, timeMillis)
+        xVelocityTracker.addDataPoint(timeMillis, position.x)
+        yVelocityTracker.addDataPoint(timeMillis, position.y)
     }
 
     /**
@@ -72,42 +72,85 @@
      * This can be expensive. Only call this when you need the velocity.
      */
     fun calculateVelocity(): Velocity {
-        val estimate = getVelocityEstimate().pixelsPerSecond
-        return Velocity(estimate.x, estimate.y)
+        return Velocity(xVelocityTracker.calculateVelocity(), yVelocityTracker.calculateVelocity())
     }
 
     /**
      * Clears the tracked positions added by [addPosition].
      */
     fun resetTracking() {
+        xVelocityTracker.resetTracking()
+        yVelocityTracker.resetTracking()
+    }
+}
+
+/**
+ * A velocity tracker calculating velocity in 1 dimension.
+ *
+ * Add displacement data points using [addDataPoint], and obtain velocity using [calculateVelocity].
+ */
+private class VelocityTracker1D {
+    // Circular buffer; current sample at index.
+    private val samples: Array<DataPointAtTime?> = Array(HistorySize) { null }
+    private var index: Int = 0
+
+    /**
+     * Adds a data point for velocity calculation. A data point should represent a position along
+     * the tracked axis at a given time, [timeMillis].
+     *
+     * Use the same units for the data points provided. For example, having some data points in `cm`
+     * and some in `m` will result in incorrect velocity calculations, as this method (and the
+     * tracker) has no knowledge of the units used.
+     */
+    fun addDataPoint(timeMillis: Long, dataPoint: Float) {
+        index = (index + 1) % HistorySize
+        samples[index] = DataPointAtTime(dataPoint, timeMillis)
+    }
+
+    /**
+     * Computes the estimated velocity at the time of the last provided data point. The units of
+     * velocity will be `units/second`, where `units` is the units of the data points provided via
+     * [addDataPoint].
+     *
+     * This can be expensive. Only call this when you need the velocity.
+     */
+    fun calculateVelocity(): Float {
+        return getVelocityEstimate().velocity
+    }
+
+    /**
+     * Clears the tracked positions added by [addDataPoint].
+     */
+    fun resetTracking() {
         samples.fill(element = null)
+        index = 0
     }
 
     /**
      * Returns an estimate of the velocity of the object being tracked by the
      * tracker given the current information available to the tracker.
      *
-     * Information is added using [addPosition].
+     * Information is added using [addDataPoint].
      *
-     * Returns null if there is no data on which to base an estimate.
+     * Returns an estimate of 0 velocity if there is no data on which to base an estimate, or if
+     * the available data points are too few (or otherwise unsuitable) for a polynomial fit.
      */
     private fun getVelocityEstimate(): VelocityEstimate {
-        val x: MutableList<Float> = mutableListOf()
-        val y: MutableList<Float> = mutableListOf()
+        val dataPoints: MutableList<Float> = mutableListOf()
         val time: MutableList<Float> = mutableListOf()
         var sampleCount = 0
         var index: Int = index
 
         // The sample at index is our newest sample.  If it is null, we have no samples so return.
-        val newestSample: PointAtTime = samples[index] ?: return VelocityEstimate.None
+        val newestSample: DataPointAtTime = samples[index] ?: return VelocityEstimate.None
 
-        var previousSample: PointAtTime = newestSample
-        var oldestSample: PointAtTime = newestSample
+        var previousSample: DataPointAtTime = newestSample
+        var oldestSample: DataPointAtTime = newestSample
 
-        // Starting with the most recent PointAtTime sample, iterate backwards while
+        // Starting with the most recent DataPointAtTime sample, iterate backwards while
         // the samples represent continuous motion.
         do {
-            val sample: PointAtTime = samples[index] ?: break
+            val sample: DataPointAtTime = samples[index] ?: break
 
             val age: Float = (newestSample.time - sample.time).toFloat()
             val delta: Float =
@@ -118,9 +161,7 @@
             }
 
             oldestSample = sample
-            val position: Offset = sample.point
-            x.add(position.x)
-            y.add(position.y)
+            dataPoints.add(sample.dataPoint)
             time.add(-age)
             index = (if (index == 0) HistorySize else index) - 1
 
@@ -129,23 +170,17 @@
 
         if (sampleCount >= MinSampleSize) {
             try {
-                val xFit: PolynomialFit = polyFitLeastSquares(time, x, 2)
-                val yFit: PolynomialFit = polyFitLeastSquares(time, y, 2)
-
+                val fit = polyFitLeastSquares(time, dataPoints, 2)
                 // The 2nd coefficient is the derivative of the quadratic polynomial at
                 // x = 0, and that happens to be the last timestamp that we end up
-                // passing to polyFitLeastSquares for both x and y.
-                val xSlope = xFit.coefficients[1]
-                val ySlope = yFit.coefficients[1]
+                // passing to polyFitLeastSquares.
+                val slope = fit.coefficients[1]
                 return VelocityEstimate(
-                    pixelsPerSecond = Offset(
-                        // Convert from pixels/ms to pixels/s
-                        (xSlope * 1000),
-                        (ySlope * 1000)
-                    ),
-                    confidence = xFit.confidence * yFit.confidence,
+                    // Convert from units/ms to units/s
+                    velocity = slope * 1000,
+                    confidence = fit.confidence,
                     durationMillis = newestSample.time - oldestSample.time,
-                    offset = newestSample.point - oldestSample.point
+                    offset = newestSample.dataPoint - oldestSample.dataPoint
                 )
             } catch (exception: IllegalArgumentException) {
                 // TODO(b/129494918): Is catching an exception here something we really want to do?
@@ -156,10 +191,10 @@
         // We're unable to make a velocity estimate but we did have at least one
         // valid pointer position.
         return VelocityEstimate(
-            pixelsPerSecond = Offset.Zero,
+            velocity = 0f,
             confidence = 1.0f,
             durationMillis = newestSample.time - oldestSample.time,
-            offset = newestSample.point - oldestSample.point
+            offset = newestSample.dataPoint - oldestSample.dataPoint
         )
     }
 }
@@ -214,13 +249,12 @@
     addPosition(event.uptimeMillis, currentPointerPositionAccumulator)
 }
 
-private data class PointAtTime(val point: Offset, val time: Long)
+private data class DataPointAtTime(val dataPoint: Float, val time: Long)
 
 /**
- * A two dimensional velocity estimate.
+ * A velocity estimate.
  *
- * VelocityEstimates are computed by [VelocityTracker.getImpulseVelocity]. An
- * estimate's [confidence] measures how well the velocity tracker's position
+ * An estimate's [confidence] measures how well the velocity tracker's position
- * data fit a straight line, [durationMillis] is the time that elapsed between the
+ * data fit a quadratic polynomial, [durationMillis] is the time that elapsed between the
  * first and last position sample used to compute the velocity, and [offset]
  * is similarly the difference between the first and last positions.
@@ -228,12 +262,10 @@
  * See also:
  *
  *  * VelocityTracker, which computes [VelocityEstimate]s.
- *  * Velocity, which encapsulates (just) a velocity vector and provides some
- *    useful velocity operations.
  */
 private data class VelocityEstimate(
-    /** The number of pixels per second of velocity in the x and y directions. */
-    val pixelsPerSecond: Offset,
+    /** The velocity, in units per second. */
+    val velocity: Float,
     /**
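-     * A value between 0.0 and 1.0 that indicates how well [VelocityTracker]
-     * was able to fit a straight line to its position data.
+     * A value between 0.0 and 1.0 that indicates how well [VelocityTracker1D]
+     * was able to fit a quadratic polynomial to its data points.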
@@ -243,17 +275,17 @@
     val confidence: Float,
     /**
      * The time that elapsed between the first and last position sample used
-     * to compute [pixelsPerSecond].
+     * to compute [velocity].
      */
     val durationMillis: Long,
     /**
-     * The difference between the first and last position sample used
-     * to compute [pixelsPerSecond].
+     * The difference between the first and last data point used
+     * to compute [velocity].
      */
-    val offset: Offset
+    val offset: Float
 ) {
     companion object {
-        val None = VelocityEstimate(Offset.Zero, 1f, 0, Offset.Zero)
+        val None = VelocityEstimate(0f, 1f, 0, 0f)
     }
 }