fix: add SAML challenge to reauth (#819)

* fix: add SAML challenge to reauth

* add enable_reauth_refresh flag

* address comments

* fix unit test

* address comments

* update

* update

* update

* update

* 🦉 Updates from OwlBot

See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md

Co-authored-by: Owl Bot <gcf-owl-bot[bot]@users.noreply.github.com>
Co-authored-by: Tres Seaver <[email protected]>
diff --git a/google/auth/exceptions.py b/google/auth/exceptions.py
index 57f181e..e9e7377 100644
--- a/google/auth/exceptions.py
+++ b/google/auth/exceptions.py
@@ -57,3 +57,7 @@
         super(ReauthFailError, self).__init__(
             "Reauthentication failed. {0}".format(message)
         )
+
+
+class ReauthSamlChallengeFailError(ReauthFailError):
+    """An exception for SAML reauth challenge failures."""
diff --git a/google/oauth2/_credentials_async.py b/google/oauth2/_credentials_async.py
index b4878c5..e7b9637 100644
--- a/google/oauth2/_credentials_async.py
+++ b/google/oauth2/_credentials_async.py
@@ -75,6 +75,7 @@
             self._client_secret,
             scopes=self._scopes,
             rapt_token=self._rapt_token,
+            enable_reauth_refresh=self._enable_reauth_refresh,
         )
 
         self.token = access_token
diff --git a/google/oauth2/_reauth_async.py b/google/oauth2/_reauth_async.py
index 510578b..f74f50b 100644
--- a/google/oauth2/_reauth_async.py
+++ b/google/oauth2/_reauth_async.py
@@ -248,6 +248,7 @@
     client_secret,
     scopes=None,
     rapt_token=None,
+    enable_reauth_refresh=False,
 ):
     """Implements the reauthentication flow.
 
@@ -265,6 +266,9 @@
             token has a wild card scope (e.g.
             'https://www.googleapis.com/auth/any-api').
         rapt_token (Optional(str)): The rapt token for reauth.
+        enable_reauth_refresh (Optional[bool]): Whether reauth refresh flow
+            should be used. The default value is False. This option is for
+            gcloud only, other users should use the default value.
 
     Returns:
         Tuple[str, Optional[str], Optional[datetime], Mapping[str, str], str]: The
@@ -299,6 +303,11 @@
             == reauth._REAUTH_NEEDED_ERROR_RAPT_REQUIRED
         )
     ):
+        if not enable_reauth_refresh:
+            raise exceptions.RefreshError(
+                "Reauthentication is needed. Please run `gcloud auth login --update-adc` to reauthenticate."
+            )
+
         rapt_token = await get_rapt_token(
             request, client_id, client_secret, refresh_token, token_uri, scopes=scopes
         )
diff --git a/google/oauth2/challenges.py b/google/oauth2/challenges.py
index 7756a80..0baff62 100644
--- a/google/oauth2/challenges.py
+++ b/google/oauth2/challenges.py
@@ -25,6 +25,9 @@
 
 
 REAUTH_ORIGIN = "https://accounts.google.com"
+SAML_CHALLENGE_MESSAGE = (
+    "Please run `gcloud auth login` to complete reauthentication with SAML."
+)
 
 
 def get_user_password(text):
@@ -148,7 +151,30 @@
         return None
 
 
+class SamlChallenge(ReauthChallenge):
+    """Challenge that asks the users to browse to their ID Providers.
+
+    Currently SAML challenge is not supported. When obtaining the challenge
+    input, exception will be raised to instruct the users to run
+    `gcloud auth login` for reauthentication.
+    """
+
+    @property
+    def name(self):
+        return "SAML"
+
+    @property
+    def is_locally_eligible(self):
+        return True
+
+    def obtain_challenge_input(self, metadata):
+        # Magic Arch has not fully supported returning a proper dedirect URL
+        # for programmatic SAML users today. So we error our here and request
+        # users to use gcloud to complete a login.
+        raise exceptions.ReauthSamlChallengeFailError(SAML_CHALLENGE_MESSAGE)
+
+
 AVAILABLE_CHALLENGES = {
     challenge.name: challenge
-    for challenge in [SecurityKeyChallenge(), PasswordChallenge()]
+    for challenge in [SecurityKeyChallenge(), PasswordChallenge(), SamlChallenge()]
 }
diff --git a/google/oauth2/credentials.py b/google/oauth2/credentials.py
index 98fd71b..e259f78 100644
--- a/google/oauth2/credentials.py
+++ b/google/oauth2/credentials.py
@@ -54,6 +54,9 @@
 
         credentials = credentials.with_quota_project('myproject-123)
 
+    Reauth is disabled by default. To enable reauth, set the
+    `enable_reauth_refresh` parameter to True in the constructor. Note that
+    reauth feature is intended for gcloud to use only.
     If reauth is enabled, `pyu2f` dependency has to be installed in order to use security
     key reauth feature. Dependency can be installed via `pip install pyu2f` or `pip install
     google-auth[reauth]`.
@@ -73,6 +76,7 @@
         expiry=None,
         rapt_token=None,
         refresh_handler=None,
+        enable_reauth_refresh=False,
     ):
         """
         Args:
@@ -109,6 +113,8 @@
                 refresh tokens are provided and tokens are obtained by calling
                 some external process on demand. It is particularly useful for
                 retrieving downscoped tokens from a token broker.
+            enable_reauth_refresh (Optional[bool]): Whether reauth refresh flow
+                should be used. This flag is for gcloud to use only.
         """
         super(Credentials, self).__init__()
         self.token = token
@@ -123,6 +129,7 @@
         self._quota_project_id = quota_project_id
         self._rapt_token = rapt_token
         self.refresh_handler = refresh_handler
+        self._enable_reauth_refresh = enable_reauth_refresh
 
     def __getstate__(self):
         """A __getstate__ method must exist for the __setstate__ to be called
@@ -151,6 +158,7 @@
         self._client_secret = d.get("_client_secret")
         self._quota_project_id = d.get("_quota_project_id")
         self._rapt_token = d.get("_rapt_token")
+        self._enable_reauth_refresh = d.get("_enable_reauth_refresh")
         # The refresh_handler setter should be used to repopulate this.
         self._refresh_handler = None
 
@@ -241,6 +249,7 @@
             default_scopes=self.default_scopes,
             quota_project_id=quota_project_id,
             rapt_token=self.rapt_token,
+            enable_reauth_refresh=self._enable_reauth_refresh,
         )
 
     @_helpers.copy_docstring(credentials.Credentials)
@@ -296,6 +305,7 @@
             self._client_secret,
             scopes=scopes,
             rapt_token=self._rapt_token,
+            enable_reauth_refresh=self._enable_reauth_refresh,
         )
 
         self.token = access_token
@@ -366,6 +376,7 @@
             client_secret=info.get("client_secret"),
             quota_project_id=info.get("quota_project_id"),  # may not exist
             expiry=expiry,
+            rapt_token=info.get("rapt_token"),  # may not exist
         )
 
     @classmethod
diff --git a/google/oauth2/reauth.py b/google/oauth2/reauth.py
index fc2629e..1e496d1 100644
--- a/google/oauth2/reauth.py
+++ b/google/oauth2/reauth.py
@@ -275,6 +275,7 @@
     client_secret,
     scopes=None,
     rapt_token=None,
+    enable_reauth_refresh=False,
 ):
     """Implements the reauthentication flow.
 
@@ -292,6 +293,9 @@
             token has a wild card scope (e.g.
             'https://www.googleapis.com/auth/any-api').
         rapt_token (Optional(str)): The rapt token for reauth.
+        enable_reauth_refresh (Optional[bool]): Whether reauth refresh flow
+            should be used. The default value is False. This option is for
+            gcloud only, other users should use the default value.
 
     Returns:
         Tuple[str, Optional[str], Optional[datetime], Mapping[str, str], str]: The
@@ -324,6 +328,11 @@
             or response_data.get("error_subtype") == _REAUTH_NEEDED_ERROR_RAPT_REQUIRED
         )
     ):
+        if not enable_reauth_refresh:
+            raise exceptions.RefreshError(
+                "Reauthentication is needed. Please run `gcloud auth login --update-adc` to reauthenticate."
+            )
+
         rapt_token = get_rapt_token(
             request, client_id, client_secret, refresh_token, token_uri, scopes=scopes
         )
diff --git a/tests/data/authorized_user_with_rapt_token.json b/tests/data/authorized_user_with_rapt_token.json
new file mode 100644
index 0000000..64b161d
--- /dev/null
+++ b/tests/data/authorized_user_with_rapt_token.json
@@ -0,0 +1,8 @@
+{
+    "client_id": "123",
+    "client_secret": "secret",
+    "refresh_token": "alabalaportocala",
+    "type": "authorized_user",
+    "rapt_token": "rapt"
+  }
+  
\ No newline at end of file
diff --git a/tests/oauth2/test_challenges.py b/tests/oauth2/test_challenges.py
index 019b908..412895a 100644
--- a/tests/oauth2/test_challenges.py
+++ b/tests/oauth2/test_challenges.py
@@ -130,3 +130,11 @@
         assert challenges.PasswordChallenge().obtain_challenge_input({}) == {
             "credential": " "
         }
+
+
+def test_saml_challenge():
+    challenge = challenges.SamlChallenge()
+    assert challenge.is_locally_eligible
+    assert challenge.name == "SAML"
+    with pytest.raises(exceptions.ReauthSamlChallengeFailError):
+        challenge.obtain_challenge_input(None)
diff --git a/tests/oauth2/test_credentials.py b/tests/oauth2/test_credentials.py
index 4a7f66e..b6a80e3 100644
--- a/tests/oauth2/test_credentials.py
+++ b/tests/oauth2/test_credentials.py
@@ -51,6 +51,7 @@
             client_id=cls.CLIENT_ID,
             client_secret=cls.CLIENT_SECRET,
             rapt_token=cls.RAPT_TOKEN,
+            enable_reauth_refresh=True,
         )
 
     def test_default_state(self):
@@ -149,6 +150,7 @@
             self.CLIENT_SECRET,
             None,
             self.RAPT_TOKEN,
+            True,
         )
 
         # Check that the credentials have the token and expiry
@@ -219,6 +221,7 @@
             self.CLIENT_SECRET,
             None,
             self.RAPT_TOKEN,
+            False,
         )
 
         # Check that the credentials have the token and expiry
@@ -422,6 +425,7 @@
             scopes=scopes,
             default_scopes=default_scopes,
             rapt_token=self.RAPT_TOKEN,
+            enable_reauth_refresh=True,
         )
 
         # Refresh credentials
@@ -436,6 +440,7 @@
             self.CLIENT_SECRET,
             scopes,
             self.RAPT_TOKEN,
+            True,
         )
 
         # Check that the credentials have the token and expiry
@@ -484,6 +489,7 @@
             client_secret=self.CLIENT_SECRET,
             default_scopes=default_scopes,
             rapt_token=self.RAPT_TOKEN,
+            enable_reauth_refresh=True,
         )
 
         # Refresh credentials
@@ -498,6 +504,7 @@
             self.CLIENT_SECRET,
             default_scopes,
             self.RAPT_TOKEN,
+            True,
         )
 
         # Check that the credentials have the token and expiry
@@ -549,6 +556,7 @@
             client_secret=self.CLIENT_SECRET,
             scopes=scopes,
             rapt_token=self.RAPT_TOKEN,
+            enable_reauth_refresh=True,
         )
 
         # Refresh credentials
@@ -563,6 +571,7 @@
             self.CLIENT_SECRET,
             scopes,
             self.RAPT_TOKEN,
+            True,
         )
 
         # Check that the credentials have the token and expiry
@@ -615,6 +624,7 @@
             client_secret=self.CLIENT_SECRET,
             scopes=scopes,
             rapt_token=self.RAPT_TOKEN,
+            enable_reauth_refresh=True,
         )
 
         # Refresh credentials
@@ -632,6 +642,7 @@
             self.CLIENT_SECRET,
             scopes,
             self.RAPT_TOKEN,
+            True,
         )
 
         # Check that the credentials have the token and expiry
@@ -731,6 +742,7 @@
         assert creds.refresh_token == info["refresh_token"]
         assert creds.token_uri == credentials._GOOGLE_OAUTH2_TOKEN_ENDPOINT
         assert creds.scopes is None
+        assert creds.rapt_token is None
 
         scopes = ["email", "profile"]
         creds = credentials.Credentials.from_authorized_user_file(
@@ -742,6 +754,18 @@
         assert creds.token_uri == credentials._GOOGLE_OAUTH2_TOKEN_ENDPOINT
         assert creds.scopes == scopes
 
+    def test_from_authorized_user_file_with_rapt_token(self):
+        info = AUTH_USER_INFO.copy()
+        file_path = os.path.join(DATA_DIR, "authorized_user_with_rapt_token.json")
+
+        creds = credentials.Credentials.from_authorized_user_file(file_path)
+        assert creds.client_secret == info["client_secret"]
+        assert creds.client_id == info["client_id"]
+        assert creds.refresh_token == info["refresh_token"]
+        assert creds.token_uri == credentials._GOOGLE_OAUTH2_TOKEN_ENDPOINT
+        assert creds.scopes is None
+        assert creds.rapt_token == "rapt"
+
     def test_to_json(self):
         info = AUTH_USER_INFO.copy()
         expiry = datetime.datetime(2020, 8, 14, 15, 54, 1)
diff --git a/tests/oauth2/test_reauth.py b/tests/oauth2/test_reauth.py
index e9ffa8a..58d649d 100644
--- a/tests/oauth2/test_reauth.py
+++ b/tests/oauth2/test_reauth.py
@@ -270,6 +270,7 @@
                 "client_secret",
                 scopes=["foo", "bar"],
                 rapt_token="rapt_token",
+                enable_reauth_refresh=True,
             )
         assert excinfo.match(r"Bad request")
         mock_token_request.assert_called_with(
@@ -298,7 +299,12 @@
             "google.oauth2.reauth.get_rapt_token", return_value="new_rapt_token"
         ):
             assert reauth.refresh_grant(
-                MOCK_REQUEST, "token_uri", "refresh_token", "client_id", "client_secret"
+                MOCK_REQUEST,
+                "token_uri",
+                "refresh_token",
+                "client_id",
+                "client_secret",
+                enable_reauth_refresh=True,
             ) == (
                 "access_token",
                 "refresh_token",
@@ -306,3 +312,18 @@
                 {"access_token": "access_token"},
                 "new_rapt_token",
             )
+
+
+def test_refresh_grant_reauth_refresh_disabled():
+    with mock.patch(
+        "google.oauth2._client._token_endpoint_request_no_throw"
+    ) as mock_token_request:
+        mock_token_request.side_effect = [
+            (False, {"error": "invalid_grant", "error_subtype": "rapt_required"}),
+            (True, {"access_token": "access_token"}),
+        ]
+        with pytest.raises(exceptions.RefreshError) as excinfo:
+            reauth.refresh_grant(
+                MOCK_REQUEST, "token_uri", "refresh_token", "client_id", "client_secret"
+            )
+        assert excinfo.match(r"Reauthentication is needed")
diff --git a/tests_async/oauth2/test_credentials_async.py b/tests_async/oauth2/test_credentials_async.py
index 99cf16f..bc89392 100644
--- a/tests_async/oauth2/test_credentials_async.py
+++ b/tests_async/oauth2/test_credentials_async.py
@@ -43,6 +43,7 @@
             token_uri=cls.TOKEN_URI,
             client_id=cls.CLIENT_ID,
             client_secret=cls.CLIENT_SECRET,
+            enable_reauth_refresh=True,
         )
 
     def test_default_state(self):
@@ -97,6 +98,7 @@
             self.CLIENT_SECRET,
             None,
             None,
+            True,
         )
 
         # Check that the credentials have the token and expiry
@@ -169,6 +171,7 @@
             self.CLIENT_SECRET,
             scopes,
             "old_rapt_token",
+            False,
         )
 
         # Check that the credentials have the token and expiry
@@ -231,6 +234,7 @@
             self.CLIENT_SECRET,
             scopes,
             None,
+            False,
         )
 
         # Check that the credentials have the token and expiry
@@ -301,6 +305,7 @@
             self.CLIENT_SECRET,
             scopes,
             None,
+            False,
         )
 
         # Check that the credentials have the token and expiry
diff --git a/tests_async/oauth2/test_reauth_async.py b/tests_async/oauth2/test_reauth_async.py
index f144d89..d982e13 100644
--- a/tests_async/oauth2/test_reauth_async.py
+++ b/tests_async/oauth2/test_reauth_async.py
@@ -318,7 +318,12 @@
             "google.oauth2._reauth_async.get_rapt_token", return_value="new_rapt_token"
         ):
             assert await _reauth_async.refresh_grant(
-                MOCK_REQUEST, "token_uri", "refresh_token", "client_id", "client_secret"
+                MOCK_REQUEST,
+                "token_uri",
+                "refresh_token",
+                "client_id",
+                "client_secret",
+                enable_reauth_refresh=True,
             ) == (
                 "access_token",
                 "refresh_token",
@@ -326,3 +331,19 @@
                 {"access_token": "access_token"},
                 "new_rapt_token",
             )
+
+
[email protected]
+async def test_refresh_grant_reauth_refresh_disabled():
+    with mock.patch(
+        "google.oauth2._client_async._token_endpoint_request_no_throw"
+    ) as mock_token_request:
+        mock_token_request.side_effect = [
+            (False, {"error": "invalid_grant", "error_subtype": "rapt_required"}),
+            (True, {"access_token": "access_token"}),
+        ]
+        with pytest.raises(exceptions.RefreshError) as excinfo:
+            assert await _reauth_async.refresh_grant(
+                MOCK_REQUEST, "token_uri", "refresh_token", "client_id", "client_secret"
+            )
+        assert excinfo.match(r"Reauthentication is needed")