Respect caller config to disable VCN safe mode

Update VcnGatewayConnection to support disabling safe mode

Bug: 276358140
Test: atest FrameworksVcnTests (new tests)
Test: atest CtsVcnTestCases
Change-Id: If767dbace925c7e13705db2ea6f23c890992405c
diff --git a/services/core/java/com/android/server/vcn/VcnContext.java b/services/core/java/com/android/server/vcn/VcnContext.java
index d958222..9213d96 100644
--- a/services/core/java/com/android/server/vcn/VcnContext.java
+++ b/services/core/java/com/android/server/vcn/VcnContext.java
@@ -18,6 +18,8 @@
 
 import android.annotation.NonNull;
 import android.content.Context;
+import android.net.vcn.FeatureFlags;
+import android.net.vcn.FeatureFlagsImpl;
 import android.os.Looper;
 
 import java.util.Objects;
@@ -31,6 +33,7 @@
     @NonNull private final Context mContext;
     @NonNull private final Looper mLooper;
     @NonNull private final VcnNetworkProvider mVcnNetworkProvider;
+    @NonNull private final FeatureFlags mFeatureFlags;
     private final boolean mIsInTestMode;
 
     public VcnContext(
@@ -42,6 +45,9 @@
         mLooper = Objects.requireNonNull(looper, "Missing looper");
         mVcnNetworkProvider = Objects.requireNonNull(vcnNetworkProvider, "Missing networkProvider");
         mIsInTestMode = isInTestMode;
+
+        // Auto-generated class
+        mFeatureFlags = new FeatureFlagsImpl();
     }
 
     @NonNull
@@ -63,6 +69,11 @@
         return mIsInTestMode;
     }
 
+    @NonNull
+    public FeatureFlags getFeatureFlags() {
+        return mFeatureFlags;
+    }
+
     /**
      * Verifies that the caller is running on the VcnContext Thread.
      *
diff --git a/services/core/java/com/android/server/vcn/VcnGatewayConnection.java b/services/core/java/com/android/server/vcn/VcnGatewayConnection.java
index d480ddb..54c97dd 100644
--- a/services/core/java/com/android/server/vcn/VcnGatewayConnection.java
+++ b/services/core/java/com/android/server/vcn/VcnGatewayConnection.java
@@ -1222,6 +1222,14 @@
 
     @VisibleForTesting(visibility = Visibility.PRIVATE)
     void setSafeModeAlarm() {
+        final boolean isFlagSafeModeConfigEnabled = mVcnContext.getFeatureFlags().safeModeConfig();
+        logVdbg("isFlagSafeModeConfigEnabled " + isFlagSafeModeConfigEnabled);
+
+        if (isFlagSafeModeConfigEnabled && !mConnectionConfig.isSafeModeEnabled()) {
+            logVdbg("setSafeModeAlarm: safe mode disabled");
+            return;
+        }
+
         logVdbg("Setting safe mode alarm; mCurrentToken: " + mCurrentToken);
 
         // Only schedule a NEW alarm if none is already set.
diff --git a/tests/vcn/java/android/net/vcn/VcnGatewayConnectionConfigTest.java b/tests/vcn/java/android/net/vcn/VcnGatewayConnectionConfigTest.java
index 359ef83..cb37821 100644
--- a/tests/vcn/java/android/net/vcn/VcnGatewayConnectionConfigTest.java
+++ b/tests/vcn/java/android/net/vcn/VcnGatewayConnectionConfigTest.java
@@ -117,6 +117,16 @@
         return buildTestConfig(UNDERLYING_NETWORK_TEMPLATES);
     }
 
+    // Public for use in VcnGatewayConnectionTest
+    public static VcnGatewayConnectionConfig.Builder newTestBuilderMinimal() {
+        final VcnGatewayConnectionConfig.Builder builder = newBuilder();
+        for (int caps : EXPOSED_CAPS) {
+            builder.addExposedCapability(caps);
+        }
+
+        return builder;
+    }
+
     private static VcnGatewayConnectionConfig.Builder newBuilder() {
         // Append a unique identifier to the name prefix to guarantee that all created
         // VcnGatewayConnectionConfigs have a unique name (required by VcnConfig).
diff --git a/tests/vcn/java/com/android/server/vcn/VcnGatewayConnectionConnectedStateTest.java b/tests/vcn/java/com/android/server/vcn/VcnGatewayConnectionConnectedStateTest.java
index 302af52..bf73198 100644
--- a/tests/vcn/java/com/android/server/vcn/VcnGatewayConnectionConnectedStateTest.java
+++ b/tests/vcn/java/com/android/server/vcn/VcnGatewayConnectionConnectedStateTest.java
@@ -75,6 +75,9 @@
 import androidx.test.runner.AndroidJUnit4;
 
 import com.android.server.vcn.VcnGatewayConnection.VcnChildSessionCallback;
+import com.android.server.vcn.VcnGatewayConnection.VcnChildSessionConfiguration;
+import com.android.server.vcn.VcnGatewayConnection.VcnIkeSession;
+import com.android.server.vcn.VcnGatewayConnection.VcnNetworkAgent;
 import com.android.server.vcn.routeselection.UnderlyingNetworkRecord;
 import com.android.server.vcn.util.MtuUtils;
 
@@ -651,6 +654,74 @@
         verifySafeModeStateAndCallbackFired(2 /* invocationCount */, true /* isInSafeMode */);
     }
 
+    private void verifySetSafeModeAlarm(
+            boolean safeModeEnabledByCaller,
+            boolean safeModeConfigFlagEnabled,
+            boolean expectingSafeModeEnabled)
+            throws Exception {
+        final VcnGatewayConnectionConfig config =
+                VcnGatewayConnectionConfigTest.newTestBuilderMinimal()
+                        .enableSafeMode(safeModeEnabledByCaller)
+                        .build();
+        final VcnGatewayConnection.Dependencies deps =
+                mock(VcnGatewayConnection.Dependencies.class);
+        setUpWakeupMessage(
+                mSafeModeTimeoutAlarm, VcnGatewayConnection.SAFEMODE_TIMEOUT_ALARM, deps);
+        doReturn(safeModeConfigFlagEnabled).when(mFeatureFlags).safeModeConfig();
+
+        final VcnGatewayConnection connection =
+                new VcnGatewayConnection(
+                        mVcnContext,
+                        TEST_SUB_GRP,
+                        TEST_SUBSCRIPTION_SNAPSHOT,
+                        config,
+                        mGatewayStatusCallback,
+                        true /* isMobileDataEnabled */,
+                        deps);
+
+        connection.setSafeModeAlarm();
+
+        final int expectedCallCnt = expectingSafeModeEnabled ? 1 : 0;
+        verify(deps, times(expectedCallCnt))
+                .newWakeupMessage(
+                        eq(mVcnContext),
+                        any(),
+                        eq(VcnGatewayConnection.SAFEMODE_TIMEOUT_ALARM),
+                        any());
+    }
+
+    @Test
+    public void testSafeModeEnabled_configFlagEnabled() throws Exception {
+        verifySetSafeModeAlarm(
+                true /* safeModeEnabledByCaller */,
+                true /* safeModeConfigFlagEnabled */,
+                true /* expectingSafeModeEnabled */);
+    }
+
+    @Test
+    public void testSafeModeEnabled_configFlagDisabled() throws Exception {
+        verifySetSafeModeAlarm(
+                true /* safeModeEnabledByCaller */,
+                false /* safeModeConfigFlagEnabled */,
+                true /* expectingSafeModeEnabled */);
+    }
+
+    @Test
+    public void testSafeModeDisabled_configFlagEnabled() throws Exception {
+        verifySetSafeModeAlarm(
+                false /* safeModeEnabledByCaller */,
+                true /* safeModeConfigFlagEnabled */,
+                false /* expectingSafeModeEnabled */);
+    }
+
+    @Test
+    public void testSafeModeDisabled_configFlagDisabled() throws Exception {
+        verifySetSafeModeAlarm(
+                false /* safeModeEnabledByCaller */,
+                false /* safeModeConfigFlagEnabled */,
+                true /* expectingSafeModeEnabled */);
+    }
+
     private Consumer<VcnNetworkAgent> setupNetworkAndGetUnwantedCallback() {
         triggerChildOpened();
         mTestLooper.dispatchAll();
diff --git a/tests/vcn/java/com/android/server/vcn/VcnGatewayConnectionTestBase.java b/tests/vcn/java/com/android/server/vcn/VcnGatewayConnectionTestBase.java
index 5efbf59..edced87 100644
--- a/tests/vcn/java/com/android/server/vcn/VcnGatewayConnectionTestBase.java
+++ b/tests/vcn/java/com/android/server/vcn/VcnGatewayConnectionTestBase.java
@@ -53,6 +53,7 @@
 import android.net.ipsec.ike.IkeSessionCallback;
 import android.net.ipsec.ike.IkeSessionConfiguration;
 import android.net.ipsec.ike.IkeSessionConnectionInfo;
+import android.net.vcn.FeatureFlags;
 import android.net.vcn.VcnGatewayConnectionConfig;
 import android.net.vcn.VcnGatewayConnectionConfigTest;
 import android.os.ParcelUuid;
@@ -165,6 +166,7 @@
     @NonNull protected final Context mContext;
     @NonNull protected final TestLooper mTestLooper;
     @NonNull protected final VcnNetworkProvider mVcnNetworkProvider;
+    @NonNull protected final FeatureFlags mFeatureFlags;
     @NonNull protected final VcnContext mVcnContext;
     @NonNull protected final VcnGatewayConnectionConfig mConfig;
     @NonNull protected final VcnGatewayStatusCallback mGatewayStatusCallback;
@@ -190,6 +192,7 @@
         mContext = mock(Context.class);
         mTestLooper = new TestLooper();
         mVcnNetworkProvider = mock(VcnNetworkProvider.class);
+        mFeatureFlags = mock(FeatureFlags.class);
         mVcnContext = mock(VcnContext.class);
         mConfig = VcnGatewayConnectionConfigTest.buildTestConfig();
         mGatewayStatusCallback = mock(VcnGatewayStatusCallback.class);
@@ -222,6 +225,7 @@
         doReturn(mContext).when(mVcnContext).getContext();
         doReturn(mTestLooper.getLooper()).when(mVcnContext).getLooper();
         doReturn(mVcnNetworkProvider).when(mVcnContext).getVcnNetworkProvider();
+        doReturn(mFeatureFlags).when(mVcnContext).getFeatureFlags();
 
         doReturn(mUnderlyingNetworkController)
                 .when(mDeps)
@@ -241,8 +245,15 @@
         doReturn(ELAPSED_REAL_TIME).when(mDeps).getElapsedRealTime();
     }
 
+    protected void setUpWakeupMessage(
+            @NonNull WakeupMessage msg,
+            @NonNull String cmdName,
+            VcnGatewayConnection.Dependencies deps) {
+        doReturn(msg).when(deps).newWakeupMessage(eq(mVcnContext), any(), eq(cmdName), any());
+    }
+
     private void setUpWakeupMessage(@NonNull WakeupMessage msg, @NonNull String cmdName) {
-        doReturn(msg).when(mDeps).newWakeupMessage(eq(mVcnContext), any(), eq(cmdName), any());
+        setUpWakeupMessage(msg, cmdName, mDeps);
     }
 
     @Before