VCN: Handle sequence number leap in packet loss detector

This pacth enables IPsec packet loss detector to handle the case when
there is an intentional sequence number leap on the server's downlink.

Previously the detector assumes that sequence number will always
increase consecutively, which is not true. During load balancing the
server might add a big leap on the sequence number intentionally. In
such case a high packet loss rate does not always indicate a lossy
network. At the same time, a low loss rate does mean the network is
not lossy

Bug: 332598276
Test: atest FrameworksVcnTests (new tests) && atest CtsVcnTestCases
Change-Id: I163cb274d293a305499fd60f7ad6eed394af5a4d
diff --git a/core/java/android/net/vcn/VcnManager.java b/core/java/android/net/vcn/VcnManager.java
index 6246dd7..91cdf8d 100644
--- a/core/java/android/net/vcn/VcnManager.java
+++ b/core/java/android/net/vcn/VcnManager.java
@@ -124,6 +124,22 @@
             "vcn_network_selection_ipsec_packet_loss_percent_threshold";
 
     /**
+     * Key for detecting unusually large increases in IPsec packet sequence numbers.
+     *
+     * <p>If the sequence number increases by more than this value within a second, it may indicate
+     * an intentional leap on the server's downlink. To avoid false positives, the packet loss
+     * detector will suppress loss reporting.
+     *
+     * <p>By default, there's no maximum limit enforced, prioritizing detection of lossy networks.
+     * To reduce false positives, consider setting an appropriate maximum threshold.
+     *
+     * @hide
+     */
+    @NonNull
+    public static final String VCN_NETWORK_SELECTION_MAX_SEQ_NUM_INCREASE_PER_SECOND_KEY =
+            "vcn_network_selection_max_seq_num_increase_per_second";
+
+    /**
      * Key for the list of timeouts in minute to stop penalizing an underlying network candidate
      *
      * @hide
@@ -180,6 +196,7 @@
                 VCN_NETWORK_SELECTION_WIFI_EXIT_RSSI_THRESHOLD_KEY,
                 VCN_NETWORK_SELECTION_POLL_IPSEC_STATE_INTERVAL_SECONDS_KEY,
                 VCN_NETWORK_SELECTION_IPSEC_PACKET_LOSS_PERCENT_THRESHOLD_KEY,
+                VCN_NETWORK_SELECTION_MAX_SEQ_NUM_INCREASE_PER_SECOND_KEY,
                 VCN_NETWORK_SELECTION_PENALTY_TIMEOUT_MINUTES_LIST_KEY,
                 VCN_RESTRICTED_TRANSPORTS_INT_ARRAY_KEY,
                 VCN_SAFE_MODE_TIMEOUT_SECONDS_KEY,
diff --git a/core/java/android/net/vcn/flags.aconfig b/core/java/android/net/vcn/flags.aconfig
index e64823a..6fde398 100644
--- a/core/java/android/net/vcn/flags.aconfig
+++ b/core/java/android/net/vcn/flags.aconfig
@@ -34,4 +34,14 @@
     namespace: "vcn"
     description: "Re-evaluate IPsec packet loss on LinkProperties or NetworkCapabilities change"
     bug: "323238888"
+}
+
+flag{
+    name: "handle_seq_num_leap"
+    namespace: "vcn"
+    description: "Do not report bad network when there is a suspected sequence number leap"
+    bug: "332598276"
+    metadata {
+      purpose: PURPOSE_BUGFIX
+    }
 }
\ No newline at end of file
diff --git a/services/core/java/com/android/server/vcn/routeselection/IpSecPacketLossDetector.java b/services/core/java/com/android/server/vcn/routeselection/IpSecPacketLossDetector.java
index c5d3333..206523a 100644
--- a/services/core/java/com/android/server/vcn/routeselection/IpSecPacketLossDetector.java
+++ b/services/core/java/com/android/server/vcn/routeselection/IpSecPacketLossDetector.java
@@ -70,6 +70,7 @@
             value = {
                 PACKET_LOSS_RATE_VALID,
                 PACKET_LOSS_RATE_INVALID,
+                PACKET_LOSS_UNUSUAL_SEQ_NUM_LEAP,
             })
     @Target({ElementType.TYPE_USE})
     private @interface PacketLossResultType {}
@@ -89,6 +90,16 @@
      */
     private static final int PACKET_LOSS_RATE_INVALID = 1;
 
+    /**
+     * The sequence number increase is unusually large and might be caused an intentional leap on
+     * the server's downlink
+     *
+     * <p>Inbound sequence number will not always increase consecutively. During load balancing the
+     * server might add a big leap on the sequence number intentionally. In such case a high packet
+     * loss rate does not always indicate a lossy network
+     */
+    private static final int PACKET_LOSS_UNUSUAL_SEQ_NUM_LEAP = 2;
+
     // For VoIP, losses between 5% and 10% of the total packet stream will affect the quality
     // significantly (as per "Computer Networking for LANS to WANS: Hardware, Software and
     // Security"). For audio and video streaming, above 10-12% packet loss is unacceptable (as per
@@ -98,8 +109,12 @@
 
     private static final int POLL_IPSEC_STATE_INTERVAL_SECONDS_DEFAULT = 20;
 
+    // By default, there's no maximum limit enforced
+    private static final int MAX_SEQ_NUM_INCREASE_DEFAULT_DISABLED = -1;
+
     private long mPollIpSecStateIntervalMs;
     private final int mPacketLossRatePercentThreshold;
+    private int mMaxSeqNumIncreasePerSecond;
 
     @NonNull private final Handler mHandler;
     @NonNull private final PowerManager mPowerManager;
@@ -138,6 +153,7 @@
 
         mPollIpSecStateIntervalMs = getPollIpSecStateIntervalMs(carrierConfig);
         mPacketLossRatePercentThreshold = getPacketLossRatePercentThreshold(carrierConfig);
+        mMaxSeqNumIncreasePerSecond = getMaxSeqNumIncreasePerSecond(carrierConfig);
 
         // Register for system broadcasts to monitor idle mode change
         final IntentFilter intentFilter = new IntentFilter();
@@ -202,6 +218,24 @@
         return IPSEC_PACKET_LOSS_PERCENT_THRESHOLD_DEFAULT;
     }
 
+    @VisibleForTesting(visibility = Visibility.PRIVATE)
+    static int getMaxSeqNumIncreasePerSecond(@Nullable PersistableBundleWrapper carrierConfig) {
+        int maxSeqNumIncrease = MAX_SEQ_NUM_INCREASE_DEFAULT_DISABLED;
+        if (Flags.handleSeqNumLeap() && carrierConfig != null) {
+            maxSeqNumIncrease =
+                    carrierConfig.getInt(
+                            VcnManager.VCN_NETWORK_SELECTION_MAX_SEQ_NUM_INCREASE_PER_SECOND_KEY,
+                            MAX_SEQ_NUM_INCREASE_DEFAULT_DISABLED);
+        }
+
+        if (maxSeqNumIncrease < MAX_SEQ_NUM_INCREASE_DEFAULT_DISABLED) {
+            logE(TAG, "Invalid value of MAX_SEQ_NUM_INCREASE_PER_SECOND_KEY " + maxSeqNumIncrease);
+            return MAX_SEQ_NUM_INCREASE_DEFAULT_DISABLED;
+        }
+
+        return maxSeqNumIncrease;
+    }
+
     @Override
     protected void onSelectedUnderlyingNetworkChanged() {
         if (!isSelectedUnderlyingNetwork()) {
@@ -237,6 +271,10 @@
         // The already scheduled event will not be affected. The followup events will be scheduled
         // with the new interval
         mPollIpSecStateIntervalMs = getPollIpSecStateIntervalMs(carrierConfig);
+
+        if (Flags.handleSeqNumLeap()) {
+            mMaxSeqNumIncreasePerSecond = getMaxSeqNumIncreasePerSecond(carrierConfig);
+        }
     }
 
     @Override
@@ -339,7 +377,10 @@
 
         final PacketLossCalculationResult calculateResult =
                 mPacketLossCalculator.getPacketLossRatePercentage(
-                        mLastIpSecTransformState, state, getLogPrefix());
+                        mLastIpSecTransformState,
+                        state,
+                        mMaxSeqNumIncreasePerSecond,
+                        getLogPrefix());
 
         if (calculateResult.getResultType() == PACKET_LOSS_RATE_INVALID) {
             return;
@@ -356,11 +397,18 @@
         mLastIpSecTransformState = state;
         if (calculateResult.getPacketLossRatePercent() < mPacketLossRatePercentThreshold) {
             logV(logMsg);
+
+            // In both "valid" or "unusual_seq_num_leap" cases, notify that the network has passed
+            // the validation
             onValidationResultReceivedInternal(false /* isFailed */);
         } else {
             logInfo(logMsg);
-            onValidationResultReceivedInternal(true /* isFailed */);
 
+            if (calculateResult.getResultType() == PACKET_LOSS_RATE_VALID) {
+                onValidationResultReceivedInternal(true /* isFailed */);
+            }
+
+            // In both "valid" or "unusual_seq_num_leap" cases, trigger network validation
             if (Flags.validateNetworkOnIpsecLoss()) {
                 // Trigger re-validation of the underlying network; if it fails, the VCN will
                 // attempt to migrate away.
@@ -376,6 +424,7 @@
         public PacketLossCalculationResult getPacketLossRatePercentage(
                 @NonNull IpSecTransformState oldState,
                 @NonNull IpSecTransformState newState,
+                int maxSeqNumIncreasePerSecond,
                 String logPrefix) {
             logVIpSecTransform("oldState", oldState, logPrefix);
             logVIpSecTransform("newState", newState, logPrefix);
@@ -392,6 +441,22 @@
                 return PacketLossCalculationResult.invalid();
             }
 
+            boolean isUnusualSeqNumLeap = false;
+
+            // Handle sequence number leap
+            if (Flags.handleSeqNumLeap()
+                    && maxSeqNumIncreasePerSecond != MAX_SEQ_NUM_INCREASE_DEFAULT_DISABLED) {
+                final long timeDiffMillis =
+                        newState.getTimestampMillis() - oldState.getTimestampMillis();
+                final long maxSeqNumIncrease = timeDiffMillis * maxSeqNumIncreasePerSecond / 1000;
+
+                // Sequence numbers are unsigned 32-bit values. If maxSeqNumIncrease overflows,
+                // isUnusualSeqNumLeap can never be true.
+                if (maxSeqNumIncrease >= 0 && newSeqHi - oldSeqHi >= maxSeqNumIncrease) {
+                    isUnusualSeqNumLeap = true;
+                }
+            }
+
             // Get the expected packet count by assuming there is no packet loss. In this case, SA
             // should receive all packets whose sequence numbers are smaller than the lower bound of
             // the replay window AND the packets received within the window.
@@ -420,7 +485,9 @@
             }
 
             final int percent = 100 - (int) (actualPktCntDiff * 100 / expectedPktCntDiff);
-            return PacketLossCalculationResult.valid(percent);
+            return isUnusualSeqNumLeap
+                    ? PacketLossCalculationResult.unusualSeqNumLeap(percent)
+                    : PacketLossCalculationResult.valid(percent);
         }
     }
 
@@ -462,6 +529,11 @@
                     PACKET_LOSS_RATE_INVALID, PACKET_LOSS_PERCENT_UNAVAILABLE);
         }
 
+        /** Construct an instance indicating that there is an unusual sequence number leap */
+        public static PacketLossCalculationResult unusualSeqNumLeap(int percent) {
+            return new PacketLossCalculationResult(PACKET_LOSS_UNUSUAL_SEQ_NUM_LEAP, percent);
+        }
+
         @PacketLossResultType
         public int getResultType() {
             return mResultType;
diff --git a/services/core/java/com/android/server/vcn/routeselection/NetworkMetricMonitor.java b/services/core/java/com/android/server/vcn/routeselection/NetworkMetricMonitor.java
index a1b212f..b9b1060 100644
--- a/services/core/java/com/android/server/vcn/routeselection/NetworkMetricMonitor.java
+++ b/services/core/java/com/android/server/vcn/routeselection/NetworkMetricMonitor.java
@@ -272,6 +272,11 @@
         }
     }
 
+    protected static void logE(String className, String msgWithPrefix) {
+        Slog.w(className, msgWithPrefix);
+        LOCAL_LOG.log("[ERROR ] " + className + msgWithPrefix);
+    }
+
     protected static void logWtf(String className, String msgWithPrefix) {
         Slog.wtf(className, msgWithPrefix);
         LOCAL_LOG.log("[WTF ] " + className + msgWithPrefix);
diff --git a/tests/vcn/java/com/android/server/vcn/routeselection/IpSecPacketLossDetectorTest.java b/tests/vcn/java/com/android/server/vcn/routeselection/IpSecPacketLossDetectorTest.java
index 0a83a53..68a2ad7 100644
--- a/tests/vcn/java/com/android/server/vcn/routeselection/IpSecPacketLossDetectorTest.java
+++ b/tests/vcn/java/com/android/server/vcn/routeselection/IpSecPacketLossDetectorTest.java
@@ -17,8 +17,10 @@
 package com.android.server.vcn.routeselection;
 
 import static android.net.vcn.VcnManager.VCN_NETWORK_SELECTION_IPSEC_PACKET_LOSS_PERCENT_THRESHOLD_KEY;
+import static android.net.vcn.VcnManager.VCN_NETWORK_SELECTION_MAX_SEQ_NUM_INCREASE_PER_SECOND_KEY;
 import static android.net.vcn.VcnManager.VCN_NETWORK_SELECTION_POLL_IPSEC_STATE_INTERVAL_SECONDS_KEY;
 
+import static com.android.server.vcn.routeselection.IpSecPacketLossDetector.getMaxSeqNumIncreasePerSecond;
 import static com.android.server.vcn.util.PersistableBundleUtils.PersistableBundleWrapper;
 
 import static org.junit.Assert.assertEquals;
@@ -65,6 +67,7 @@
     private static final int REPLAY_BITMAP_LEN_BYTE = 512;
     private static final int REPLAY_BITMAP_LEN_BIT = REPLAY_BITMAP_LEN_BYTE * 8;
     private static final int IPSEC_PACKET_LOSS_PERCENT_THRESHOLD = 5;
+    private static final int MAX_SEQ_NUM_INCREASE_DEFAULT_DISABLED = -1;
     private static final long POLL_IPSEC_STATE_INTERVAL_MS = TimeUnit.SECONDS.toMillis(30L);
 
     @Mock private IpSecTransformWrapper mIpSecTransform;
@@ -91,6 +94,9 @@
                         eq(VCN_NETWORK_SELECTION_IPSEC_PACKET_LOSS_PERCENT_THRESHOLD_KEY),
                         anyInt()))
                 .thenReturn(IPSEC_PACKET_LOSS_PERCENT_THRESHOLD);
+        when(mCarrierConfig.getInt(
+                        eq(VCN_NETWORK_SELECTION_MAX_SEQ_NUM_INCREASE_PER_SECOND_KEY), anyInt()))
+                .thenReturn(MAX_SEQ_NUM_INCREASE_DEFAULT_DISABLED);
 
         when(mDependencies.getPacketLossCalculator()).thenReturn(mPacketLossCalculator);
 
@@ -112,6 +118,20 @@
                 .build();
     }
 
+    private static IpSecTransformState newNextTransformState(
+            IpSecTransformState before,
+            long timeDiffMillis,
+            long rxSeqNoDiff,
+            long packtCountDiff,
+            int packetInWin) {
+        return new IpSecTransformState.Builder()
+                .setTimestampMillis(before.getTimestampMillis() + timeDiffMillis)
+                .setRxHighestSequenceNumber(before.getRxHighestSequenceNumber() + rxSeqNoDiff)
+                .setPacketCount(before.getPacketCount() + packtCountDiff)
+                .setReplayBitmap(newReplayBitmap(packetInWin))
+                .build();
+    }
+
     private static byte[] newReplayBitmap(int receivedPktCnt) {
         final BitSet bitSet = new BitSet(REPLAY_BITMAP_LEN_BIT);
         for (int i = 0; i < receivedPktCnt; i++) {
@@ -165,7 +185,7 @@
         // Verify the first polled state is stored
         assertEquals(mTransformStateInitial, mIpSecPacketLossDetector.getLastTransformState());
         verify(mPacketLossCalculator, never())
-                .getPacketLossRatePercentage(any(), any(), anyString());
+                .getPacketLossRatePercentage(any(), any(), anyInt(), anyString());
 
         // Verify next poll is scheduled
         assertNull(mTestLooper.nextMessage());
@@ -278,7 +298,7 @@
 
         xfrmStateReceiver.onResult(newTransformState(1, 1, newReplayBitmap(1)));
         verify(mPacketLossCalculator, never())
-                .getPacketLossRatePercentage(any(), any(), anyString());
+                .getPacketLossRatePercentage(any(), any(), anyInt(), anyString());
     }
 
     @Test
@@ -289,7 +309,7 @@
 
         xfrmStateReceiver.onError(new RuntimeException("Test"));
         verify(mPacketLossCalculator, never())
-                .getPacketLossRatePercentage(any(), any(), anyString());
+                .getPacketLossRatePercentage(any(), any(), anyInt(), anyString());
     }
 
     private void checkHandleLossRate(
@@ -301,7 +321,7 @@
                 startMonitorAndCaptureStateReceiver();
         doReturn(mockPacketLossRate)
                 .when(mPacketLossCalculator)
-                .getPacketLossRatePercentage(any(), any(), anyString());
+                .getPacketLossRatePercentage(any(), any(), anyInt(), anyString());
 
         // Mock receiving two states with mTransformStateInitial and an arbitrary transformNew
         final IpSecTransformState transformNew = newTransformState(1, 1, newReplayBitmap(1));
@@ -311,7 +331,10 @@
         // Verifications
         verify(mPacketLossCalculator)
                 .getPacketLossRatePercentage(
-                        eq(mTransformStateInitial), eq(transformNew), anyString());
+                        eq(mTransformStateInitial),
+                        eq(transformNew),
+                        eq(MAX_SEQ_NUM_INCREASE_DEFAULT_DISABLED),
+                        anyString());
 
         if (isLastStateExpectedToUpdate) {
             assertEquals(transformNew, mIpSecPacketLossDetector.getLastTransformState());
@@ -351,6 +374,22 @@
                 false /* isCallbackExpected */);
     }
 
+    @Test
+    public void testHandleLossRate_unusualSeqNumLeap_highLossRate() throws Exception {
+        checkHandleLossRate(
+                PacketLossCalculationResult.unusualSeqNumLeap(22),
+                true /* isLastStateExpectedToUpdate */,
+                false /* isCallbackExpected */);
+    }
+
+    @Test
+    public void testHandleLossRate_unusualSeqNumLeap_lowLossRate() throws Exception {
+        checkHandleLossRate(
+                PacketLossCalculationResult.unusualSeqNumLeap(2),
+                true /* isLastStateExpectedToUpdate */,
+                true /* isCallbackExpected */);
+    }
+
     private void checkGetPacketLossRate(
             IpSecTransformState oldState,
             IpSecTransformState newState,
@@ -358,7 +397,8 @@
             throws Exception {
         assertEquals(
                 expectedLossRate,
-                mPacketLossCalculator.getPacketLossRatePercentage(oldState, newState, TAG));
+                mPacketLossCalculator.getPacketLossRatePercentage(
+                        oldState, newState, MAX_SEQ_NUM_INCREASE_DEFAULT_DISABLED, TAG));
     }
 
     private void checkGetPacketLossRate(
@@ -443,6 +483,45 @@
         checkGetPacketLossRate(oldState, 20000, 14000, 3000, 10);
     }
 
+    private void checkGetPktLossRate_unusualSeqNumLeap(
+            int maxSeqNumIncreasePerSecond,
+            int timeDiffMillis,
+            int rxSeqNoDiff,
+            PacketLossCalculationResult expected)
+            throws Exception {
+        final IpSecTransformState oldState = mTransformStateInitial;
+        final IpSecTransformState newState =
+                newNextTransformState(
+                        oldState,
+                        timeDiffMillis,
+                        rxSeqNoDiff,
+                        1 /* packtCountDiff */,
+                        1 /* packetInWin */);
+
+        assertEquals(
+                expected,
+                mPacketLossCalculator.getPacketLossRatePercentage(
+                        oldState, newState, maxSeqNumIncreasePerSecond, TAG));
+    }
+
+    @Test
+    public void testGetPktLossRate_unusualSeqNumLeap() throws Exception {
+        checkGetPktLossRate_unusualSeqNumLeap(
+                10000 /* maxSeqNumIncreasePerSecond */,
+                (int) TimeUnit.SECONDS.toMillis(2L),
+                30000 /* rxSeqNoDiff */,
+                PacketLossCalculationResult.unusualSeqNumLeap(100));
+    }
+
+    @Test
+    public void testGetPktLossRate_unusualSeqNumLeap_smallSeqNumDiff() throws Exception {
+        checkGetPktLossRate_unusualSeqNumLeap(
+                10000 /* maxSeqNumIncreasePerSecond */,
+                (int) TimeUnit.SECONDS.toMillis(2L),
+                5000 /* rxSeqNoDiff */,
+                PacketLossCalculationResult.valid(100));
+    }
+
     // Verify the polling event is scheduled with expected delays
     private void verifyPollEventDelayAndScheduleNext(long expectedDelayMs) {
         if (expectedDelayMs > 0) {
@@ -469,4 +548,24 @@
         // Verify the 3rd poll is scheduled with configured delay
         verifyPollEventDelayAndScheduleNext(POLL_IPSEC_STATE_INTERVAL_MS);
     }
+
+    @Test
+    public void testGetMaxSeqNumIncreasePerSecond() throws Exception {
+        final int seqNumLeapNegative = 500_000;
+        when(mCarrierConfig.getInt(
+                        eq(VCN_NETWORK_SELECTION_MAX_SEQ_NUM_INCREASE_PER_SECOND_KEY), anyInt()))
+                .thenReturn(seqNumLeapNegative);
+        assertEquals(seqNumLeapNegative, getMaxSeqNumIncreasePerSecond(mCarrierConfig));
+    }
+
+    @Test
+    public void testGetMaxSeqNumIncreasePerSecond_negativeValue() throws Exception {
+        final int seqNumLeapNegative = -10;
+        when(mCarrierConfig.getInt(
+                        eq(VCN_NETWORK_SELECTION_MAX_SEQ_NUM_INCREASE_PER_SECOND_KEY), anyInt()))
+                .thenReturn(seqNumLeapNegative);
+        assertEquals(
+                MAX_SEQ_NUM_INCREASE_DEFAULT_DISABLED,
+                getMaxSeqNumIncreasePerSecond(mCarrierConfig));
+    }
 }
diff --git a/tests/vcn/java/com/android/server/vcn/routeselection/NetworkEvaluationTestBase.java b/tests/vcn/java/com/android/server/vcn/routeselection/NetworkEvaluationTestBase.java
index af6daa1..6189fb0 100644
--- a/tests/vcn/java/com/android/server/vcn/routeselection/NetworkEvaluationTestBase.java
+++ b/tests/vcn/java/com/android/server/vcn/routeselection/NetworkEvaluationTestBase.java
@@ -123,6 +123,7 @@
 
         mSetFlagsRule.enableFlags(Flags.FLAG_VALIDATE_NETWORK_ON_IPSEC_LOSS);
         mSetFlagsRule.enableFlags(Flags.FLAG_EVALUATE_IPSEC_LOSS_ON_LP_NC_CHANGE);
+        mSetFlagsRule.enableFlags(Flags.FLAG_HANDLE_SEQ_NUM_LEAP);
 
         when(mNetwork.getNetId()).thenReturn(-1);