Ensure prediction does not goes backwards while lifting

At the moment the prediction will be aborted if we detect that
the stylus is likely to be lifted (pressure < 0.1); however, this
causes the prediction to disappear and make it feel like it's
going backwards due to the pressure prediction not matching the
real result. This change will force the prediction to be generated
up to the time of the last one, since the previous prediction did
not predict the lift.

Bug: 302300930
Test: `gradlew :input:input-motionprediction:test`
Change-Id: I3c254e2e3b559f88482aee7f2e339ee81bae724b
diff --git a/input/input-motionprediction/src/main/java/androidx/input/motionprediction/kalman/SinglePointerPredictor.java b/input/input-motionprediction/src/main/java/androidx/input/motionprediction/kalman/SinglePointerPredictor.java
index 633cbb7..fd64a94 100644
--- a/input/input-motionprediction/src/main/java/androidx/input/motionprediction/kalman/SinglePointerPredictor.java
+++ b/input/input-motionprediction/src/main/java/androidx/input/motionprediction/kalman/SinglePointerPredictor.java
@@ -70,8 +70,8 @@
     private final PointerKalmanFilter mKalman = new PointerKalmanFilter(0.01, 1.0);
 
     private final DVector2 mLastPosition = new DVector2();
-    private long mPrevEventTime;
-    private long mPrevPredictedEventTime;
+    private long mLastSeenEventTime;
+    private long mLastPredictEventTime;
     private long mDownEventTime;
     private List<Float> mReportRates = new LinkedList<>();
     private int mExpectedPredictionSampleSize = -1;
@@ -102,8 +102,8 @@
      */
     public SinglePointerPredictor(int pointerId, int toolType) {
         mKalman.reset();
-        mPrevEventTime = 0;
-        mPrevPredictedEventTime = 0;
+        mLastSeenEventTime = 0;
+        mLastPredictEventTime = 0;
         mDownEventTime = 0;
         mPointerId = pointerId;
         mToolType = toolType;
@@ -113,7 +113,7 @@
             float tilt, long eventTime) {
         if (x == mLastPosition.a1
                 && y == mLastPosition.a2
-                && (eventTime <= (mPrevEventTime + EVENT_TIME_IGNORED_THRESHOLD_MS))) {
+                && (eventTime <= (mLastSeenEventTime + EVENT_TIME_IGNORED_THRESHOLD_MS))) {
             // Reduce Kalman filter jank by ignoring input event with similar coordinates
             // and eventTime as previous input event.
             // This is particularly useful when multiple pointer are on screen as in this case the
@@ -134,8 +134,8 @@
         // provide reliable timestamps and do not report at an even interval, so this is just
         // to be used as an estimate.
         if (mReportRates != null && mReportRates.size() < 20) {
-            if (mPrevEventTime > 0) {
-                float dt = eventTime - mPrevEventTime;
+            if (mLastSeenEventTime > 0) {
+                float dt = eventTime - mLastSeenEventTime;
                 mReportRates.add(dt);
                 float sum = 0;
                 for (float rate : mReportRates) {
@@ -144,7 +144,7 @@
                 mReportRateMs = sum / mReportRates.size();
             }
         }
-        mPrevEventTime = eventTime;
+        mLastSeenEventTime = eventTime;
     }
 
     @Override
@@ -161,8 +161,8 @@
     public boolean onTouchEvent(@NonNull MotionEvent event) {
         if (event.getActionMasked() == MotionEvent.ACTION_CANCEL) {
             mKalman.reset();
-            mPrevEventTime = 0;
-            mPrevPredictedEventTime = 0;
+            mLastSeenEventTime = 0;
+            mLastPredictEventTime = 0;
             return false;
         }
         int pointerIndex = event.findPointerIndex(mPointerId);
@@ -241,9 +241,9 @@
 
         // Predict at least as far in time as the previous prediction.
         // Otherwise, it may appear that the coordinates are going backwards.
-        if (mPrevPredictedEventTime > mPrevEventTime) {
+        if (mLastPredictEventTime > mLastSeenEventTime) {
             int minimumPredictionSampleSize = (int) Math.floor(
-                    (mPrevPredictedEventTime - mPrevEventTime) / mReportRateMs
+                    (mLastPredictEventTime - mLastSeenEventTime) / mReportRateMs
             );
             if (predictionTargetInSamples < minimumPredictionSampleSize) {
                 predictionTargetInSamples = minimumPredictionSampleSize;
@@ -256,10 +256,9 @@
             }
         }
 
-        long predictedEventTime = mPrevEventTime;
+        long predictedEventTime = mLastSeenEventTime;
         int i = 0;
         for (; i < predictionTargetInSamples; i++) {
-            predictedEventTime += Math.round(mReportRateMs);
             mAcceleration.a1 += mJank.a1 * JANK_INFLUENCE;
             mAcceleration.a2 += mJank.a2 * JANK_INFLUENCE;
             mVelocity.a1 += mAcceleration.a1 * ACCELERATION_INFLUENCE;
@@ -268,12 +267,21 @@
             mPosition.a2 += mVelocity.a2 * VELOCITY_INFLUENCE;
             mPressure += pressureChange;
 
+            // Ensure it's in the valid range
+            if (mPressure < 0) {
+                mPressure = 0;
+            } else if (mPressure > 1) {
+                mPressure = 1;
+            }
+
+            long nextPredictedEventTime = predictedEventTime + Math.round(mReportRateMs);
+
             // Abort prediction if the pen is to be lifted.
-            if (mPressure < 0.1) {
+            if (mPressure < 0.1
+                    && nextPredictedEventTime > mLastPredictEventTime) {
                 //TODO: Should we generate ACTION_UP MotionEvent instead of ACTION_MOVE?
                 break;
             }
-            mPressure = Math.min(mPressure, 1.0f);
 
             MotionEvent.PointerCoords[] coords = {new MotionEvent.PointerCoords()};
             coords[0].x = (float) mPosition.a1;
@@ -285,7 +293,7 @@
                 predictedEvent =
                         MotionEvent.obtain(
                                 mDownEventTime /* downTime */,
-                                predictedEventTime /* eventTime */,
+                                nextPredictedEventTime /* eventTime */,
                                 MotionEvent.ACTION_MOVE /* action */,
                                 1 /* pointerCount */,
                                 pointerProperties /* pointer properties */,
@@ -299,10 +307,13 @@
                                 0 /* source */,
                                 0 /* flags */);
             } else {
-                predictedEvent.addBatch(predictedEventTime, coords, 0);
+                predictedEvent.addBatch(nextPredictedEventTime, coords, 0);
             }
+            predictedEventTime = nextPredictedEventTime;
         }
-        mPrevPredictedEventTime = predictedEventTime;
+
+        // Store the last predicted time
+        mLastPredictEventTime = predictedEventTime;
 
         return predictedEvent;
     }
diff --git a/input/input-motionprediction/src/test/kotlin/androidx/input/motionprediction/MotionEventGenerator.kt b/input/input-motionprediction/src/test/kotlin/androidx/input/motionprediction/MotionEventGenerator.kt
index 22b76925..1e57a62 100644
--- a/input/input-motionprediction/src/test/kotlin/androidx/input/motionprediction/MotionEventGenerator.kt
+++ b/input/input-motionprediction/src/test/kotlin/androidx/input/motionprediction/MotionEventGenerator.kt
@@ -22,14 +22,16 @@
 class MotionEventGenerator(
     val firstXGenerator: (Long) -> Float,
     val firstYGenerator: (Long) -> Float,
+    val firstPressureGenerator: ((Long) -> Float)?,
     val secondXGenerator: ((Long) -> Float)?,
-    val secondYGenerator: ((Long) -> Float)?
+    val secondYGenerator: ((Long) -> Float)?,
+    val secondPressureGenerator: ((Long) -> Float)?,
 ) {
-
     constructor(
-        firstXGenerator: (Long) -> Float,
-        firstYGenerator: (Long) -> Float
-    ) : this(firstXGenerator, firstYGenerator, null, null)
+            firstXGenerator: (Long) -> Float,
+            firstYGenerator: (Long) -> Float,
+            firstPressureGenerator: ((Long) -> Float)?
+    ) : this(firstXGenerator, firstYGenerator, firstPressureGenerator, null, null, null)
 
     private val downEventTime: Long = 0
     private var currentEventTime: Long = downEventTime
@@ -67,7 +69,11 @@
         val coords = MotionEvent.PointerCoords()
         coords.x = firstStartX + firstXGenerator(currentEventTime - downEventTime)
         coords.y = firstStartY + firstYGenerator(currentEventTime - downEventTime)
-        coords.pressure = 1f
+        if (firstPressureGenerator == null) {
+            coords.pressure = 1f
+        } else {
+            coords.pressure = firstPressureGenerator.invoke(currentEventTime - downEventTime)
+        }
 
         motionEventBuilder.setPointer(pointerProperties, coords)
 
@@ -77,9 +83,16 @@
             secondPointerProperties.toolType = MotionEvent.TOOL_TYPE_STYLUS
 
             val secondCoords = MotionEvent.PointerCoords()
-            secondCoords.x = firstStartX + secondXGenerator.invoke(currentEventTime - downEventTime)
-            secondCoords.y = firstStartY + secondYGenerator.invoke(currentEventTime - downEventTime)
-            secondCoords.pressure = 1f
+            secondCoords.x = secondStartX +
+                    secondXGenerator.invoke(currentEventTime - downEventTime)
+            secondCoords.y = secondStartY +
+                    secondYGenerator.invoke(currentEventTime - downEventTime)
+            if (secondPressureGenerator == null) {
+                secondCoords.pressure = 1f
+            } else {
+                secondCoords.pressure =
+                        secondPressureGenerator.invoke(currentEventTime - downEventTime)
+            }
 
             motionEventBuilder.setPointer(secondPointerProperties, secondCoords)
         }
diff --git a/input/input-motionprediction/src/test/kotlin/androidx/input/motionprediction/kalman/MultiPointerPredictorTest.kt b/input/input-motionprediction/src/test/kotlin/androidx/input/motionprediction/kalman/MultiPointerPredictorTest.kt
index eacd62f..caee9e0 100644
--- a/input/input-motionprediction/src/test/kotlin/androidx/input/motionprediction/kalman/MultiPointerPredictorTest.kt
+++ b/input/input-motionprediction/src/test/kotlin/androidx/input/motionprediction/kalman/MultiPointerPredictorTest.kt
@@ -34,8 +34,10 @@
         val generator = MotionEventGenerator(
                 { delta: Long -> delta.toFloat() },
                 { delta: Long -> delta.toFloat() },
+                null,
                 { delta: Long -> delta.toFloat() },
                 { delta: Long -> delta.toFloat() },
+                null,
         )
         for (i in 1..INITIAL_FEED) {
             predictor.onTouchEvent(generator.next())
diff --git a/input/input-motionprediction/src/test/kotlin/androidx/input/motionprediction/kalman/SinglePointerPredictorTest.kt b/input/input-motionprediction/src/test/kotlin/androidx/input/motionprediction/kalman/SinglePointerPredictorTest.kt
index f31ac05..2f66349 100644
--- a/input/input-motionprediction/src/test/kotlin/androidx/input/motionprediction/kalman/SinglePointerPredictorTest.kt
+++ b/input/input-motionprediction/src/test/kotlin/androidx/input/motionprediction/kalman/SinglePointerPredictorTest.kt
@@ -53,7 +53,7 @@
                     continue
                 }
                 val predictor = constructPredictor()
-                val generator = MotionEventGenerator(xGenerator, yGenerator)
+                val generator = MotionEventGenerator(xGenerator, yGenerator, null)
                 for (i in 1..INITIAL_FEED) {
                     predictor.onTouchEvent(generator.next())
                     predictor.predict(generator.getRateMs().toInt())
@@ -74,8 +74,8 @@
     @Test
     fun predictionNeverGoesBackwards() {
         val predictor = constructPredictor()
-        val accelerationGenerator = { delta: Long -> delta.toFloat() }
-        val motionGenerator = MotionEventGenerator(accelerationGenerator, accelerationGenerator)
+        val coordGenerator = { delta: Long -> delta.toFloat() }
+        val motionGenerator = MotionEventGenerator(coordGenerator, coordGenerator, null)
         var lastPredictedTime = 0L;
         for (i in 1..INITIAL_FEED) {
             predictor.onTouchEvent(motionGenerator.next())
@@ -90,6 +90,37 @@
         val predicted = predictor.predict(motionGenerator.getRateMs().toInt())!!
         assertThat(predicted.eventTime).isAtLeast(lastPredictedTime)
     }
+
+    @Test
+    fun predictionNeverGoesBackwardsEvenWhenLifting() {
+        val predictor = constructPredictor()
+        val coordGenerator = { delta: Long -> delta.toFloat() }
+        // Pressure will be 1 at the beginning and trend to zero while never getting there
+        val pressureGenerator = fun(delta: Long): Float {
+                if (delta > 500) {
+                    return ((700 - delta) / 500).toFloat()
+                }
+                return 1f
+            }
+        val motionGenerator =
+                MotionEventGenerator(coordGenerator, coordGenerator, pressureGenerator)
+        var lastPredictedTime = 0L
+        var lastPredictedEvent: MotionEvent? = null
+        var predicted: MotionEvent?
+        var iteration = 0
+        do {
+            predictor.onTouchEvent(motionGenerator.next())
+            predicted = predictor.predict(motionGenerator.getRateMs().toInt() * 10)
+            if (predicted != null) {
+                assertThat(predicted.eventTime).isAtLeast(lastPredictedTime)
+                lastPredictedTime = predicted.eventTime
+            } else if (lastPredictedEvent != null) {
+                assertThat(lastPredictedEvent.getHistorySize()).isEqualTo(0);
+            }
+            lastPredictedEvent = predicted
+            iteration++
+        } while (predicted != null || iteration < INITIAL_FEED)
+    }
 }
 
 private fun constructPredictor(): SinglePointerPredictor = SinglePointerPredictor(