Clean up IKE SPI resources if Create IKE is interrupted

This CL ensures IKE SPI resources are released if the IKE Session
is terminated after a Create IKE request (IKE INIT or Rekey) is
sent but before any response is received.

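Without this fix, SPI resources leak until a CloseGuard reclaims
them:
- IKE enters CreateIkeLocalIkeInit and allocates an SPI
  for mLocalIkeSpiResource
- IkeSession#killSession or handleIkeFatalError is called;
  the IKE state machine quits
- mLocalIkeSpiResource still holds the SPI resource until
  it is released by the CloseGuard

With this fix, CreateIkeLocalIkeInit and RekeyIkeLocalCreate
release their SPI resources when the state exits. SPIs that were
used to build an IkeSaRecord are bound to it, so closing them on
state exit is a no-op; they are released when the IkeSaRecord is
closed.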

This is not a problem for production code, because the IKE SPI
space is too large to exhaust in practice. However, tests that run
test-mode IKE and expect a specific IKE SPI value may fail, because
that SPI may still be reserved by a previous test.

Bug: 357151634
Test: atest FrameworksIkeTests (new tests)
      atest CtsIkeTestCases
Flag: EXEMPT low risk bug fix
(cherry picked from https://android-review.googlesource.com/q/commit:7a1df1c43b5b8985e1da3a248c0c5204b71e4401)
Merged-In: I9059236989667ade113f844145f5813826cebf6a
Change-Id: I9059236989667ade113f844145f5813826cebf6a
diff --git a/src/java/com/android/internal/net/ipsec/ike/ChildSessionStateMachine.java b/src/java/com/android/internal/net/ipsec/ike/ChildSessionStateMachine.java
index e7cee92..2fd9b07 100644
--- a/src/java/com/android/internal/net/ipsec/ike/ChildSessionStateMachine.java
+++ b/src/java/com/android/internal/net/ipsec/ike/ChildSessionStateMachine.java
@@ -2473,7 +2473,7 @@
                     IkePayload.getPayloadForTypeInProvidedList(
                             IkePayload.PAYLOAD_TYPE_SA, IkeSaPayload.class, reqPayloads);
             if (saPayload != null) {
-                saPayload.releaseChildSpiResourcesIfExists();
+                saPayload.releaseSpiResources();
             }
         }
 
diff --git a/src/java/com/android/internal/net/ipsec/ike/IkeSessionStateMachine.java b/src/java/com/android/internal/net/ipsec/ike/IkeSessionStateMachine.java
index 228b299..f28b170 100644
--- a/src/java/com/android/internal/net/ipsec/ike/IkeSessionStateMachine.java
+++ b/src/java/com/android/internal/net/ipsec/ike/IkeSessionStateMachine.java
@@ -564,7 +564,7 @@
                                 CMD_ALARM_FIRED,
                                 CMD_SEND_KEEPALIVE,
                                 this));
-        mIkeSpiGenerator = new IkeSpiGenerator(mIkeContext.getRandomnessFactory());
+        mIkeSpiGenerator = mDeps.newIkeSpiGenerator(mIkeContext.getRandomnessFactory());
         mIpSecSpiGenerator =
                 new IpSecSpiGenerator(mIpSecManager, mIkeContext.getRandomnessFactory());
 
@@ -798,6 +798,11 @@
         public IkeAlarm newExactAndAllowWhileIdleAlarm(IkeAlarmConfig alarmConfig) {
             return IkeAlarm.newExactAndAllowWhileIdleAlarm(alarmConfig);
         }
+
+        /** Builds a new IkeSpiGenerator; a separate method so tests can inject a spy. */
+        public IkeSpiGenerator newIkeSpiGenerator(RandomnessFactory randomnessFactory) {
+            return new IkeSpiGenerator(randomnessFactory);
+        }
     }
 
     private boolean hasChildSessionCallback(ChildSessionCallback callback) {
@@ -3333,13 +3338,6 @@
 
         @Override
         protected void handleResponseIkeMessage(IkeMessage ikeMessage) {
-            // IKE_SA_INIT exchange and IKE SA setup succeed
-            boolean ikeInitSuccess = false;
-
-            // IKE INIT is not finished. IKE_SA_INIT request was re-sent with Notify-Cookie,
-            // and the same INIT SPI and other payloads.
-            boolean ikeInitRetriedWithCookie = false;
-
             try {
                 int exchangeType = ikeMessage.ikeHeader.exchangeType;
                 if (exchangeType != IkeHeader.EXCHANGE_TYPE_IKE_SA_INIT) {
@@ -3356,7 +3354,6 @@
                             buildReqWithCookie(mRetransmitter.getMessage(), outCookiePayload);
 
                     sendRequest(initReq);
-                    ikeInitRetriedWithCookie = true;
                     return;
                 }
 
@@ -3375,7 +3372,6 @@
                                 buildSaLifetimeAlarmScheduler(mRemoteIkeSpiResource.getSpi()));
 
                 addIkeSaRecord(mCurrentIkeSaRecord);
-                ikeInitSuccess = true;
 
                 List<Integer> integrityAlgorithms = mSaProposal.getIntegrityAlgorithms();
 
@@ -3439,17 +3435,6 @@
                 }
 
                 handleIkeFatalError(e);
-            } finally {
-                if (!ikeInitSuccess && !ikeInitRetriedWithCookie) {
-                    if (mLocalIkeSpiResource != null) {
-                        mLocalIkeSpiResource.close();
-                        mLocalIkeSpiResource = null;
-                    }
-                    if (mRemoteIkeSpiResource != null) {
-                        mRemoteIkeSpiResource.close();
-                        mRemoteIkeSpiResource = null;
-                    }
-                }
             }
         }
 
@@ -3694,6 +3679,15 @@
             if (mRetransmitter != null) {
                 mRetransmitter.stopRetransmitting();
             }
+
+            if (mLocalIkeSpiResource != null) {
+                mLocalIkeSpiResource.close();
+                mLocalIkeSpiResource = null;
+            }
+            if (mRemoteIkeSpiResource != null) {
+                mRemoteIkeSpiResource.close();
+                mRemoteIkeSpiResource = null;
+            }
         }
 
         private class UnencryptedRetransmitter extends Retransmitter {
@@ -5024,10 +5018,13 @@
 
     /** RekeyIkeLocalCreate represents state when IKE library initiates Rekey IKE exchange. */
     class RekeyIkeLocalCreate extends RekeyIkeHandlerBase {
+        private IkeMessage mRekeyRequestMsg;
+
         @Override
         public void enterState() {
             try {
-                mRetransmitter = new EncryptedRetransmitter(buildIkeRekeyReq());
+                mRekeyRequestMsg = buildIkeRekeyReq();
+                mRetransmitter = new EncryptedRetransmitter(mRekeyRequestMsg);
             } catch (IOException e) {
                 loge("Fail to assign IKE SPI for rekey. Schedule a retry.", e);
                 mCurrentIkeSaRecord.rescheduleRekey(RETRY_INTERVAL_MS);
@@ -5036,6 +5033,27 @@
         }
 
+        /**
+         * Releases the SPI resources reserved by the outbound Rekey request's SA payload.
+         *
+         * <p>Per {@link IkeSaPayload#releaseSpiResources()}, this is also safe after a successful
+         * rekey because the new IkeSaRecord holds the SPI resource.
+         */
+        @Override
+        public void exitState() {
+            // mRekeyRequestMsg is still null if buildIkeRekeyReq() threw in enterState(); without
+            // this check, exiting after that failure would throw a NullPointerException.
+            if (mRekeyRequestMsg != null) {
+                IkeSaPayload saPayload =
+                        mRekeyRequestMsg.getPayloadForType(
+                                IkePayload.PAYLOAD_TYPE_SA, IkeSaPayload.class);
+                if (saPayload != null) {
+                    saPayload.releaseSpiResources();
+                }
+                mRekeyRequestMsg = null;
+            }
+        }
+
         @Override
         protected void triggerRetransmit() {
             mRetransmitter.retransmit();
         }
diff --git a/src/java/com/android/internal/net/ipsec/ike/SaRecord.java b/src/java/com/android/internal/net/ipsec/ike/SaRecord.java
index b46d4bb..16aea21 100644
--- a/src/java/com/android/internal/net/ipsec/ike/SaRecord.java
+++ b/src/java/com/android/internal/net/ipsec/ike/SaRecord.java
@@ -675,6 +675,8 @@
 
             mInitiatorSpiResource = initSpi;
             mResponderSpiResource = respSpi;
+            mInitiatorSpiResource.bindToIkeSaRecord();
+            mResponderSpiResource.bindToIkeSaRecord();
 
             mSkD = skD;
             mSkPi = skPi;
@@ -925,6 +927,8 @@
         @Override
         public void close() {
             super.close();
+            mInitiatorSpiResource.unbindFromIkeSaRecord();
+            mResponderSpiResource.unbindFromIkeSaRecord();
             mInitiatorSpiResource.close();
             mResponderSpiResource.close();
         }
diff --git a/src/java/com/android/internal/net/ipsec/ike/message/IkeSaPayload.java b/src/java/com/android/internal/net/ipsec/ike/message/IkeSaPayload.java
index 9c20a7a..fa38731 100644
--- a/src/java/com/android/internal/net/ipsec/ike/message/IkeSaPayload.java
+++ b/src/java/com/android/internal/net/ipsec/ike/message/IkeSaPayload.java
@@ -583,17 +583,16 @@
     }
 
     /**
-     * Release IPsec SPI resources in the outbound Create Child request
+     * Release SPI resources in the outbound Create IKE/Child request
      *
-     * <p>This method is usually called when an IKE library fails to receive a Create Child response
-     * before it is terminated. It is also safe to call after the Create Child exchange has
-     * succeeded because the newly created IpSecTransform pair will hold the IPsec SPI resource.
+     * <p>This method is usually called when the IKE library fails to receive a Create IKE/Child
+     * response before it is terminated. It is also safe to call after the exchange has succeeded:
+     * the new IkeSaRecord binds its IKE SPIs, and the new ChildSaRecord's IpSecTransform pair
+     * holds its IPsec SPIs, so the SPI resources stay reserved.
      */
-    public void releaseChildSpiResourcesIfExists() {
+    public void releaseSpiResources() {
         for (Proposal proposal : proposalList) {
-            if (proposal instanceof ChildProposal) {
-                proposal.releaseSpiResourceIfExists();
-            }
+            proposal.releaseSpiResourceIfExists();
         }
     }
 
diff --git a/src/java/com/android/internal/net/ipsec/ike/utils/IkeSecurityParameterIndex.java b/src/java/com/android/internal/net/ipsec/ike/utils/IkeSecurityParameterIndex.java
index 7336948..c7c28ca 100644
--- a/src/java/com/android/internal/net/ipsec/ike/utils/IkeSecurityParameterIndex.java
+++ b/src/java/com/android/internal/net/ipsec/ike/utils/IkeSecurityParameterIndex.java
@@ -48,10 +48,17 @@
     private final long mSpi;
     private final CloseGuard mCloseGuard = new CloseGuard();
 
+    /**
+     * Whether this SPI is bound to an IkeSaRecord. While bound, {@link #close()} is a no-op, so the
+     * SPI stays reserved until the IkeSaRecord unbinds it.
+     */
+    private boolean mIsBoundToIkeSaRecord;
+
     // Package private constructor that MUST only be called from IkeSpiGenerator
     IkeSecurityParameterIndex(InetAddress sourceAddress, long spi) {
         mSourceAddress = sourceAddress;
         mSpi = spi;
+        mIsBoundToIkeSaRecord = false;
         mCloseGuard.open("close");
     }
 
@@ -73,6 +80,10 @@
     /** Release an SPI that was previously reserved. */
     @Override
     public void close() {
+        if (mIsBoundToIkeSaRecord) {
+            return;
+        }
+
         sAssignedIkeSpis.remove(new Pair<InetAddress, Long>(mSourceAddress, mSpi));
         mCloseGuard.close();
     }
@@ -106,4 +117,25 @@
         sAssignedIkeSpis.remove(new Pair<InetAddress, Long>(mSourceAddress, mSpi));
         mSourceAddress = newSourceAddress;
     }
+
+    /**
+     * Bind this SPI to an IkeSaRecord so that {@link #close()} no longer releases it.
+     *
+     * <p>MUST ONLY be called from an IkeSaRecord. Throws IllegalStateException if already bound.
+     */
+    public void bindToIkeSaRecord() {
+        if (mIsBoundToIkeSaRecord) {
+            throw new IllegalStateException("Already bound");
+        }
+        mIsBoundToIkeSaRecord = true;
+    }
+
+    /**
+     * Unbind this SPI from an IkeSaRecord so that a later {@link #close()} releases it.
+     *
+     * <p>MUST ONLY be called from an IkeSaRecord. Unbinding an unbound SPI has no effect.
+     */
+    public void unbindFromIkeSaRecord() {
+        mIsBoundToIkeSaRecord = false;
+    }
 }
diff --git a/tests/iketests/src/java/com/android/internal/net/ipsec/ike/IkeSessionStateMachineTest.java b/tests/iketests/src/java/com/android/internal/net/ipsec/ike/IkeSessionStateMachineTest.java
index 99a127c..30ac326 100644
--- a/tests/iketests/src/java/com/android/internal/net/ipsec/ike/IkeSessionStateMachineTest.java
+++ b/tests/iketests/src/java/com/android/internal/net/ipsec/ike/IkeSessionStateMachineTest.java
@@ -450,6 +450,7 @@
     private EapAuthenticator mMockEapAuthenticator;
 
     private IkeConnectionController mSpyIkeConnectionCtrl;
+    private IkeSpiGenerator mSpyIkeSpiGenerator;
 
     private Ike3gppDataListener mMockIke3gppDataListener;
     private Ike3gppExtension mIke3gppExtension;
@@ -891,6 +892,9 @@
                 .when(spyDeps)
                 .newIkeConnectionController(
                         any(IkeContext.class), any(IkeConnectionController.Config.class));
+        doReturn(mSpyIkeSpiGenerator)
+                .when(spyDeps)
+                .newIkeSpiGenerator(any(RandomnessFactory.class));
         injectChildSessionInSpyDeps(spyDeps, child, childCb);
 
 
@@ -995,6 +999,8 @@
                                         CMD_SEND_KEEPALIVE,
                                         mockIkeConnectionCtrlCb),
                                 spyIkeConnectionCtrlDeps));
+        mSpyIkeSpiGenerator = spy(new IkeSpiGenerator(createMockRandomFactory()));
+
         mSpyDeps =
                 buildSpyDepsWithChildSession(
                         mMockChildSessionStateMachine, mMockChildSessionCallback);
@@ -1724,6 +1730,48 @@
     }
 
     @Test
+    public void testCreateIkeLocalIkeInit_closeSpi_ikeTerminated() throws Exception {
+        // Setup: hand out a mock IKE SPI so that its close() can be verified
+        final IkeSecurityParameterIndex mockIkeSpi = mock(IkeSecurityParameterIndex.class);
+        doReturn(mockIkeSpi).when(mSpyIkeSpiGenerator).allocateSpi(any(InetAddress.class));
+
+        // Send out IKE INIT request, which allocates the local IKE SPI
+        mIkeSessionStateMachine.sendMessage(IkeSessionStateMachine.CMD_LOCAL_REQUEST_CREATE_IKE);
+        mLooper.dispatchAll();
+
+        // Verify the SPI was allocated, then kill the session before any response arrives
+        verify(mSpyIkeSpiGenerator).allocateSpi(any(InetAddress.class));
+
+        mIkeSessionStateMachine.killSession();
+        mLooper.dispatchAll();
+
+        verify(mockIkeSpi).close();
+    }
+
+    @Test
+    public void testRekeyIkeLocalCreate_closeSpi_ikeTerminated() throws Exception {
+        // Setup: hand out a mock IKE SPI so that its close() can be verified
+        setupIdleStateMachine();
+        final IkeSecurityParameterIndex mockIkeSpi = mock(IkeSecurityParameterIndex.class);
+        doReturn(mockIkeSpi).when(mSpyIkeSpiGenerator).allocateSpi(any(InetAddress.class));
+
+        // Send Rekey-Create request, which allocates the new local IKE SPI
+        mIkeSessionStateMachine.sendMessage(
+                IkeSessionStateMachine.CMD_EXECUTE_LOCAL_REQ,
+                mLocalRequestFactory.getIkeLocalRequest(
+                        IkeSessionStateMachine.CMD_LOCAL_REQUEST_REKEY_IKE));
+        mLooper.dispatchAll();
+
+        // Verify the SPI was allocated, then kill the session before any response arrives
+        verify(mSpyIkeSpiGenerator).allocateSpi(any(InetAddress.class));
+
+        mIkeSessionStateMachine.killSession();
+        mLooper.dispatchAll();
+
+        verify(mockIkeSpi).close();
+    }
+
+    @Test
     public void testCreateIkeLocalIkeInitNegotiatesDhGroup() throws Exception {
         // Clear the calls triggered by starting IkeSessionStateMachine in #setup()
         reset(mSpyIkeConnectionCtrl);
diff --git a/tests/iketests/src/java/com/android/internal/net/ipsec/ike/SaRecordTest.java b/tests/iketests/src/java/com/android/internal/net/ipsec/ike/SaRecordTest.java
index 8551dce..344b53d 100644
--- a/tests/iketests/src/java/com/android/internal/net/ipsec/ike/SaRecordTest.java
+++ b/tests/iketests/src/java/com/android/internal/net/ipsec/ike/SaRecordTest.java
@@ -21,12 +21,14 @@
 import static org.junit.Assert.assertArrayEquals;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
 import static org.mockito.AdditionalMatchers.aryEq;
 import static org.mockito.Matchers.anyInt;
 import static org.mockito.Matchers.anyObject;
 import static org.mockito.Matchers.anyString;
 import static org.mockito.Matchers.eq;
 import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.spy;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
 
@@ -177,9 +179,9 @@
         byte[] nonceResp = TestUtils.hexStringToByteArray(IKE_NONCE_RESP_HEX_STRING);
 
         IkeSecurityParameterIndex ikeInitSpi =
-                IKE_SPI_GENERATOR.allocateSpi(LOCAL_ADDRESS, IKE_INIT_SPI);
+                spy(IKE_SPI_GENERATOR.allocateSpi(LOCAL_ADDRESS, IKE_INIT_SPI));
         IkeSecurityParameterIndex ikeRespSpi =
-                IKE_SPI_GENERATOR.allocateSpi(REMOTE_ADDRESS, IKE_RESP_SPI);
+                spy(IKE_SPI_GENERATOR.allocateSpi(REMOTE_ADDRESS, IKE_RESP_SPI));
         IkeSaRecordConfig ikeSaRecordConfig =
                 new IkeSaRecordConfig(
                         ikeInitSpi,
@@ -221,9 +223,13 @@
         assertArrayEquals(
                 TestUtils.hexStringToByteArray(IKE_SK_PRF_RESP_HEX_STRING), ikeSaRecord.getSkPr());
         verify(mMockLifetimeAlarmScheduler).scheduleLifetimeExpiryAlarm(anyString());
+        verify(ikeInitSpi).bindToIkeSaRecord();
+        verify(ikeRespSpi).bindToIkeSaRecord();
 
         ikeSaRecord.close();
         verify(mMockLifetimeAlarmScheduler).cancelLifetimeExpiryAlarm(anyString());
+        verify(ikeInitSpi).unbindFromIkeSaRecord();
+        verify(ikeRespSpi).unbindFromIkeSaRecord();
     }
 
     // Test generating keying material and building IpSecTransform for making Child SA.
@@ -379,4 +385,33 @@
     public void testRemoteInitChildKeyExchange() throws Exception {
         verifyChildKeyExchange(false /* isLocalInit */);
     }
+
+    @Test
+    public void testBindIkeSpiToSaRecord() throws Exception {
+        IkeSecurityParameterIndex ikeInitSpi =
+                IKE_SPI_GENERATOR.allocateSpi(LOCAL_ADDRESS, IKE_INIT_SPI);
+
+        // While bound to an IKE SA, close() must keep the SPI reserved
+        ikeInitSpi.bindToIkeSaRecord();
+        ikeInitSpi.close();
+
+        try {
+            IKE_SPI_GENERATOR.allocateSpi(LOCAL_ADDRESS, IKE_INIT_SPI);
+            fail("Expect to fail since this SPI-address combo is not released");
+        } catch (Exception expected) {
+            // Expected: the SPI-address combo is still reserved by the bound SPI
+        }
+
+        // Once unbound, close() releases the SPI so it can be allocated again
+        ikeInitSpi.unbindFromIkeSaRecord();
+        ikeInitSpi.close();
+
+        IkeSecurityParameterIndex ikeInitSpiAnother =
+                IKE_SPI_GENERATOR.allocateSpi(LOCAL_ADDRESS, IKE_INIT_SPI);
+        assertEquals(LOCAL_ADDRESS, ikeInitSpiAnother.getSourceAddress());
+        assertEquals(IKE_INIT_SPI, ikeInitSpiAnother.getSpi());
+
+        // Release the SPI; otherwise other tests allocating the same SPI-address combo may fail
+        ikeInitSpiAnother.close();
+    }
 }