feat: add reauth support to async user credentials (#738)

diff --git a/google/oauth2/_client_async.py b/google/oauth2/_client_async.py
index 4817ea4..cf51211 100644
--- a/google/oauth2/_client_async.py
+++ b/google/oauth2/_client_async.py
@@ -30,53 +30,16 @@
 from six.moves import http_client
 from six.moves import urllib
 
-from google.auth import _helpers
 from google.auth import exceptions
 from google.auth import jwt
 from google.oauth2 import _client as client
 
 
-def _handle_error_response(response_body):
-    """"Translates an error response into an exception.
-
-    Args:
-        response_body (str): The decoded response data.
-
-    Raises:
-        google.auth.exceptions.RefreshError
-    """
-    try:
-        error_data = json.loads(response_body)
-        error_details = "{}: {}".format(
-            error_data["error"], error_data.get("error_description")
-        )
-    # If no details could be extracted, use the response data.
-    except (KeyError, ValueError):
-        error_details = response_body
-
-    raise exceptions.RefreshError(error_details, response_body)
-
-
-def _parse_expiry(response_data):
-    """Parses the expiry field from a response into a datetime.
-
-    Args:
-        response_data (Mapping): The JSON-parsed response data.
-
-    Returns:
-        Optional[datetime]: The expiration or ``None`` if no expiration was
-            specified.
-    """
-    expires_in = response_data.get("expires_in", None)
-
-    if expires_in is not None:
-        return _helpers.utcnow() + datetime.timedelta(seconds=expires_in)
-    else:
-        return None
-
-
-async def _token_endpoint_request(request, token_uri, body):
+async def _token_endpoint_request_no_throw(
+    request, token_uri, body, access_token=None, use_json=False
+):
     """Makes a request to the OAuth 2.0 authorization server's token endpoint.
+    This function doesn't throw on response errors.
 
     Args:
         request (google.auth.transport.Request): A callable used to make
@@ -84,16 +47,23 @@
         token_uri (str): The OAuth 2.0 authorizations server's token endpoint
             URI.
         body (Mapping[str, str]): The parameters to send in the request body.
+        access_token (Optional(str)): The access token needed to make the request.
+        use_json (Optional(bool)): Use urlencoded format or json format for the
+            content type. The default value is False.
 
     Returns:
-        Mapping[str, str]: The JSON-decoded response data.
-
-    Raises:
-        google.auth.exceptions.RefreshError: If the token endpoint returned
-            an error.
+        Tuple(bool, Mapping[str, str]): A boolean indicating if the request is
+            successful, and a mapping for the JSON-decoded response data.
     """
-    body = urllib.parse.urlencode(body).encode("utf-8")
-    headers = {"content-type": client._URLENCODED_CONTENT_TYPE}
+    if use_json:
+        headers = {"Content-Type": client._JSON_CONTENT_TYPE}
+        body = json.dumps(body).encode("utf-8")
+    else:
+        headers = {"Content-Type": client._URLENCODED_CONTENT_TYPE}
+        body = urllib.parse.urlencode(body).encode("utf-8")
+
+    if access_token:
+        headers["Authorization"] = "Bearer {}".format(access_token)
 
     retry = 0
     # retry to fetch token for maximum of two times if any internal failure
@@ -126,8 +96,38 @@
             ):
                 retry += 1
                 continue
-            _handle_error_response(response_body)
+            return response.status == http_client.OK, response_data
 
+    return response.status == http_client.OK, response_data
+
+
+async def _token_endpoint_request(
+    request, token_uri, body, access_token=None, use_json=False
+):
+    """Makes a request to the OAuth 2.0 authorization server's token endpoint.
+
+    Args:
+        request (google.auth.transport.Request): A callable used to make
+            HTTP requests.
+        token_uri (str): The OAuth 2.0 authorizations server's token endpoint
+            URI.
+        body (Mapping[str, str]): The parameters to send in the request body.
+        access_token (Optional(str)): The access token needed to make the request.
+        use_json (Optional(bool)): Use urlencoded format or json format for the
+            content type. The default value is False.
+
+    Returns:
+        Mapping[str, str]: The JSON-decoded response data.
+
+    Raises:
+        google.auth.exceptions.RefreshError: If the token endpoint returned
+            an error.
+    """
+    response_status_ok, response_data = await _token_endpoint_request_no_throw(
+        request, token_uri, body, access_token=access_token, use_json=use_json
+    )
+    if not response_status_ok:
+        client._handle_error_response(response_data)
     return response_data
 
 
@@ -163,7 +163,7 @@
         new_exc = exceptions.RefreshError("No access token in response.", response_data)
         six.raise_from(new_exc, caught_exc)
 
-    expiry = _parse_expiry(response_data)
+    expiry = client._parse_expiry(response_data)
 
     return access_token, expiry, response_data
 
@@ -210,7 +210,13 @@
 
 
 async def refresh_grant(
-    request, token_uri, refresh_token, client_id, client_secret, scopes=None
+    request,
+    token_uri,
+    refresh_token,
+    client_id,
+    client_secret,
+    scopes=None,
+    rapt_token=None,
 ):
     """Implements the OAuth 2.0 refresh token grant.
 
@@ -229,10 +235,11 @@
             scopes must be authorized for the refresh token. Useful if refresh
             token has a wild card scope (e.g.
             'https://www.googleapis.com/auth/any-api').
+        rapt_token (Optional(str)): The reauth Proof Token.
 
     Returns:
         Tuple[str, Optional[str], Optional[datetime], Mapping[str, str]]: The
-            access token, new refresh token, expiration, and additional data
+            access token, new or current refresh token, expiration, and additional data
             returned by the token endpoint.
 
     Raises:
@@ -249,16 +256,8 @@
     }
     if scopes:
         body["scope"] = " ".join(scopes)
+    if rapt_token:
+        body["rapt"] = rapt_token
 
     response_data = await _token_endpoint_request(request, token_uri, body)
-
-    try:
-        access_token = response_data["access_token"]
-    except KeyError as caught_exc:
-        new_exc = exceptions.RefreshError("No access token in response.", response_data)
-        six.raise_from(new_exc, caught_exc)
-
-    refresh_token = response_data.get("refresh_token", refresh_token)
-    expiry = _parse_expiry(response_data)
-
-    return access_token, refresh_token, expiry, response_data
+    return client._handle_refresh_grant_response(response_data, refresh_token)
diff --git a/google/oauth2/_credentials_async.py b/google/oauth2/_credentials_async.py
index eb3e97c..b4878c5 100644
--- a/google/oauth2/_credentials_async.py
+++ b/google/oauth2/_credentials_async.py
@@ -34,7 +34,7 @@
 from google.auth import _credentials_async as credentials
 from google.auth import _helpers
 from google.auth import exceptions
-from google.oauth2 import _client_async as _client
+from google.oauth2 import _reauth_async as reauth
 from google.oauth2 import credentials as oauth2_credentials
 
 
@@ -66,23 +66,26 @@
             refresh_token,
             expiry,
             grant_response,
-        ) = await _client.refresh_grant(
+            rapt_token,
+        ) = await reauth.refresh_grant(
             request,
             self._token_uri,
             self._refresh_token,
             self._client_id,
             self._client_secret,
-            self._scopes,
+            scopes=self._scopes,
+            rapt_token=self._rapt_token,
         )
 
         self.token = access_token
         self.expiry = expiry
         self._refresh_token = refresh_token
         self._id_token = grant_response.get("id_token")
+        self._rapt_token = rapt_token
 
-        if self._scopes and "scopes" in grant_response:
+        if self._scopes and "scope" in grant_response:
             requested_scopes = frozenset(self._scopes)
-            granted_scopes = frozenset(grant_response["scopes"].split())
+            granted_scopes = frozenset(grant_response["scope"].split())
             scopes_requested_but_not_granted = requested_scopes - granted_scopes
             if scopes_requested_but_not_granted:
                 raise exceptions.RefreshError(
diff --git a/google/oauth2/_reauth_async.py b/google/oauth2/_reauth_async.py
new file mode 100644
index 0000000..09e0760
--- /dev/null
+++ b/google/oauth2/_reauth_async.py
@@ -0,0 +1,320 @@
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""A module that provides functions for handling rapt authentication.
+
+Reauth is a process of obtaining additional authentication (such as password,
+security token, etc.) while refreshing OAuth 2.0 credentials for a user.
+
+Credentials that use the Reauth flow must have the reauth scope,
+``https://www.googleapis.com/auth/accounts.reauth``.
+
+This module provides a high-level function for executing the Reauth process,
+:func:`refresh_grant`, and lower-level helpers for doing the individual
+steps of the reauth process.
+
+Those steps are:
+
+1. Obtaining a list of challenges from the reauth server.
+2. Running through each challenge and sending the result back to the reauth
+   server.
+3. Refreshing the access token using the returned rapt token.
+"""
+
+import sys
+
+from six.moves import range
+
+from google.auth import exceptions
+from google.oauth2 import _client
+from google.oauth2 import _client_async
+from google.oauth2 import challenges
+from google.oauth2 import reauth
+
+
+async def _get_challenges(
+    request, supported_challenge_types, access_token, requested_scopes=None
+):
+    """Does initial request to reauth API to get the challenges.
+
+    Args:
+        request (google.auth.transport.Request): A callable used to make
+            HTTP requests. This must be an aiohttp request.
+        supported_challenge_types (Sequence[str]): list of challenge names
+            supported by the manager.
+        access_token (str): Access token with reauth scopes.
+        requested_scopes (Optional(Sequence[str])): Authorized scopes for the credentials.
+
+    Returns:
+        dict: The response from the reauth API.
+    """
+    body = {"supportedChallengeTypes": supported_challenge_types}
+    if requested_scopes:
+        body["oauthScopesForDomainPolicyLookup"] = requested_scopes
+
+    return await _client_async._token_endpoint_request(
+        request,
+        reauth._REAUTH_API + ":start",
+        body,
+        access_token=access_token,
+        use_json=True,
+    )
+
+
+async def _send_challenge_result(
+    request, session_id, challenge_id, client_input, access_token
+):
+    """Attempt to refresh access token by sending next challenge result.
+
+    Args:
+        request (google.auth.transport.Request): A callable used to make
+            HTTP requests. This must be an aiohttp request.
+        session_id (str): session id returned by the initial reauth call.
+        challenge_id (str): challenge id returned by the initial reauth call.
+        client_input: dict with a challenge-specific client input. For example:
+            ``{'credential': password}`` for password challenge.
+        access_token (str): Access token with reauth scopes.
+
+    Returns:
+        dict: The response from the reauth API.
+    """
+    body = {
+        "sessionId": session_id,
+        "challengeId": challenge_id,
+        "action": "RESPOND",
+        "proposalResponse": client_input,
+    }
+
+    return await _client_async._token_endpoint_request(
+        request,
+        reauth._REAUTH_API + "/{}:continue".format(session_id),
+        body,
+        access_token=access_token,
+        use_json=True,
+    )
+
+
+async def _run_next_challenge(msg, request, access_token):
+    """Get the next challenge from msg and run it.
+
+    Args:
+        msg (dict): Reauth API response body (either from the initial request to
+            https://reauth.googleapis.com/v2/sessions:start or from sending the
+            previous challenge response to
+            https://reauth.googleapis.com/v2/sessions/id:continue)
+        request (google.auth.transport.Request): A callable used to make
+            HTTP requests. This must be an aiohttp request.
+        access_token (str): reauth access token
+
+    Returns:
+        dict: The response from the reauth API.
+
+    Raises:
+        google.auth.exceptions.ReauthError: if reauth failed.
+    """
+    for challenge in msg["challenges"]:
+        if challenge["status"] != "READY":
+            # Skip non-activated challenges.
+            continue
+        c = challenges.AVAILABLE_CHALLENGES.get(challenge["challengeType"], None)
+        if not c:
+            raise exceptions.ReauthFailError(
+                "Unsupported challenge type {0}. Supported types: {1}".format(
+                    challenge["challengeType"],
+                    ",".join(list(challenges.AVAILABLE_CHALLENGES.keys())),
+                )
+            )
+        if not c.is_locally_eligible:
+            raise exceptions.ReauthFailError(
+                "Challenge {0} is not locally eligible".format(
+                    challenge["challengeType"]
+                )
+            )
+        client_input = c.obtain_challenge_input(challenge)
+        if not client_input:
+            return None
+        return await _send_challenge_result(
+            request,
+            msg["sessionId"],
+            challenge["challengeId"],
+            client_input,
+            access_token,
+        )
+    return None
+
+
+async def _obtain_rapt(request, access_token, requested_scopes):
+    """Given an http request method and reauth access token, get rapt token.
+
+    Args:
+        request (google.auth.transport.Request): A callable used to make
+            HTTP requests. This must be an aiohttp request.
+        access_token (str): reauth access token
+        requested_scopes (Sequence[str]): scopes required by the client application
+
+    Returns:
+        str: The rapt token.
+
+    Raises:
+        google.auth.exceptions.ReauthError: if reauth failed
+    """
+    msg = await _get_challenges(
+        request,
+        list(challenges.AVAILABLE_CHALLENGES.keys()),
+        access_token,
+        requested_scopes,
+    )
+
+    if msg["status"] == reauth._AUTHENTICATED:
+        return msg["encodedProofOfReauthToken"]
+
+    for _ in range(0, reauth.RUN_CHALLENGE_RETRY_LIMIT):
+        if not (
+            msg["status"] == reauth._CHALLENGE_REQUIRED
+            or msg["status"] == reauth._CHALLENGE_PENDING
+        ):
+            raise exceptions.ReauthFailError(
+                "Reauthentication challenge failed due to API error: {}".format(
+                    msg["status"]
+                )
+            )
+
+        if not reauth.is_interactive():
+            raise exceptions.ReauthFailError(
+                "Reauthentication challenge could not be answered because you are not"
+                " in an interactive session."
+            )
+
+        msg = await _run_next_challenge(msg, request, access_token)
+
+        if msg["status"] == reauth._AUTHENTICATED:
+            return msg["encodedProofOfReauthToken"]
+
+    # If we got here it means we didn't get authenticated.
+    raise exceptions.ReauthFailError("Failed to obtain rapt token.")
+
+
+async def get_rapt_token(
+    request, client_id, client_secret, refresh_token, token_uri, scopes=None
+):
+    """Given an http request method and refresh_token, get rapt token.
+
+    Args:
+        request (google.auth.transport.Request): A callable used to make
+            HTTP requests. This must be an aiohttp request.
+        client_id (str): client id to get access token for reauth scope.
+        client_secret (str): client secret for the client_id
+        refresh_token (str): refresh token to refresh access token
+        token_uri (str): uri to refresh access token
+        scopes (Optional(Sequence[str])): scopes required by the client application
+
+    Returns:
+        str: The rapt token.
+    Raises:
+        google.auth.exceptions.RefreshError: If reauth failed.
+    """
+    sys.stderr.write("Reauthentication required.\n")
+
+    # Get access token for reauth.
+    access_token, _, _, _ = await _client_async.refresh_grant(
+        request=request,
+        client_id=client_id,
+        client_secret=client_secret,
+        refresh_token=refresh_token,
+        token_uri=token_uri,
+        scopes=[reauth._REAUTH_SCOPE],
+    )
+
+    # Get rapt token from reauth API.
+    rapt_token = await _obtain_rapt(request, access_token, requested_scopes=scopes)
+
+    return rapt_token
+
+
+async def refresh_grant(
+    request,
+    token_uri,
+    refresh_token,
+    client_id,
+    client_secret,
+    scopes=None,
+    rapt_token=None,
+):
+    """Implements the reauthentication flow.
+
+    Args:
+        request (google.auth.transport.Request): A callable used to make
+            HTTP requests. This must be an aiohttp request.
+        token_uri (str): The OAuth 2.0 authorizations server's token endpoint
+            URI.
+        refresh_token (str): The refresh token to use to get a new access
+            token.
+        client_id (str): The OAuth 2.0 application's client ID.
+        client_secret (str): The Oauth 2.0 appliaction's client secret.
+        scopes (Optional(Sequence[str])): Scopes to request. If present, all
+            scopes must be authorized for the refresh token. Useful if refresh
+            token has a wild card scope (e.g.
+            'https://www.googleapis.com/auth/any-api').
+        rapt_token (Optional(str)): The rapt token for reauth.
+
+    Returns:
+        Tuple[str, Optional[str], Optional[datetime], Mapping[str, str], str]: The
+            access token, new refresh token, expiration, the additional data
+            returned by the token endpoint, and the rapt token.
+
+    Raises:
+        google.auth.exceptions.RefreshError: If the token endpoint returned
+            an error.
+    """
+    body = {
+        "grant_type": _client._REFRESH_GRANT_TYPE,
+        "client_id": client_id,
+        "client_secret": client_secret,
+        "refresh_token": refresh_token,
+    }
+    if scopes:
+        body["scope"] = " ".join(scopes)
+    if rapt_token:
+        body["rapt"] = rapt_token
+
+    response_status_ok, response_data = await _client_async._token_endpoint_request_no_throw(
+        request, token_uri, body
+    )
+    if (
+        not response_status_ok
+        and response_data.get("error") == reauth._REAUTH_NEEDED_ERROR
+        and (
+            response_data.get("error_subtype")
+            == reauth._REAUTH_NEEDED_ERROR_INVALID_RAPT
+            or response_data.get("error_subtype")
+            == reauth._REAUTH_NEEDED_ERROR_RAPT_REQUIRED
+        )
+    ):
+        rapt_token = await get_rapt_token(
+            request, client_id, client_secret, refresh_token, token_uri, scopes=scopes
+        )
+        body["rapt"] = rapt_token
+        (
+            response_status_ok,
+            response_data,
+        ) = await _client_async._token_endpoint_request_no_throw(
+            request, token_uri, body
+        )
+
+    if not response_status_ok:
+        _client._handle_error_response(response_data)
+    refresh_response = _client._handle_refresh_grant_response(
+        response_data, refresh_token
+    )
+    return refresh_response + (rapt_token,)
diff --git a/google/oauth2/reauth.py b/google/oauth2/reauth.py
index d539d7c..d914fe9 100644
--- a/google/oauth2/reauth.py
+++ b/google/oauth2/reauth.py
@@ -296,9 +296,9 @@
         rapt_token (Optional(str)): The rapt token for reauth.
 
     Returns:
-        Tuple[str, Optional[str], Optional[datetime], Mapping[str, str]]: The
-            access token, new refresh token, expiration, and additional data
-            returned by the token endpoint.
+        Tuple[str, Optional[str], Optional[datetime], Mapping[str, str], str]: The
+            access token, new refresh token, expiration, the additional data
+            returned by the token endpoint, and the rapt token.
 
     Raises:
         google.auth.exceptions.RefreshError: If the token endpoint returned
diff --git a/tests_async/oauth2/test__client_async.py b/tests_async/oauth2/test__client_async.py
index 458937a..6e48c45 100644
--- a/tests_async/oauth2/test__client_async.py
+++ b/tests_async/oauth2/test__client_async.py
@@ -29,34 +29,6 @@
 from tests.oauth2 import test__client as test_client
 
 
-def test__handle_error_response():
-    response_data = json.dumps({"error": "help", "error_description": "I'm alive"})
-
-    with pytest.raises(exceptions.RefreshError) as excinfo:
-        _client._handle_error_response(response_data)
-
-    assert excinfo.match(r"help: I\'m alive")
-
-
-def test__handle_error_response_non_json():
-    response_data = "Help, I'm alive"
-
-    with pytest.raises(exceptions.RefreshError) as excinfo:
-        _client._handle_error_response(response_data)
-
-    assert excinfo.match(r"Help, I\'m alive")
-
-
[email protected]("google.auth._helpers.utcnow", return_value=datetime.datetime.min)
-def test__parse_expiry(unused_utcnow):
-    result = _client._parse_expiry({"expires_in": 500})
-    assert result == datetime.datetime.min + datetime.timedelta(seconds=500)
-
-
-def test__parse_expiry_none():
-    assert _client._parse_expiry({}) is None
-
-
 def make_request(response_data, status=http_client.OK):
     response = mock.AsyncMock(spec=["transport.Response"])
     response.status = status
@@ -82,7 +54,7 @@
     request.assert_called_with(
         method="POST",
         url="http://example.com",
-        headers={"content-type": "application/x-www-form-urlencoded"},
+        headers={"Content-Type": "application/x-www-form-urlencoded"},
         body="test=params".encode("utf-8"),
     )
 
@@ -91,6 +63,35 @@
 
 
 @pytest.mark.asyncio
+async def test__token_endpoint_request_json():
+
+    request = make_request({"test": "response"})
+    access_token = "access_token"
+
+    result = await _client._token_endpoint_request(
+        request,
+        "http://example.com",
+        {"test": "params"},
+        access_token=access_token,
+        use_json=True,
+    )
+
+    # Check request call
+    request.assert_called_with(
+        method="POST",
+        url="http://example.com",
+        headers={
+            "Content-Type": "application/json",
+            "Authorization": "Bearer access_token",
+        },
+        body=b'{"test": "params"}',
+    )
+
+    # Check result
+    assert result == {"test": "response"}
+
+
[email protected]
 async def test__token_endpoint_request_error():
     request = make_request({}, status=http_client.BAD_REQUEST)
 
@@ -218,7 +219,12 @@
     )
 
     token, refresh_token, expiry, extra_data = await _client.refresh_grant(
-        request, "http://example.com", "refresh_token", "client_id", "client_secret"
+        request,
+        "http://example.com",
+        "refresh_token",
+        "client_id",
+        "client_secret",
+        rapt_token="rapt_token",
     )
 
     # Check request call
@@ -229,6 +235,7 @@
             "refresh_token": "refresh_token",
             "client_id": "client_id",
             "client_secret": "client_secret",
+            "rapt": "rapt_token",
         },
     )
 
diff --git a/tests_async/oauth2/test_credentials_async.py b/tests_async/oauth2/test_credentials_async.py
index 5c883d6..99cf16f 100644
--- a/tests_async/oauth2/test_credentials_async.py
+++ b/tests_async/oauth2/test_credentials_async.py
@@ -58,7 +58,7 @@
         assert credentials.client_id == self.CLIENT_ID
         assert credentials.client_secret == self.CLIENT_SECRET
 
-    @mock.patch("google.oauth2._client_async.refresh_grant", autospec=True)
+    @mock.patch("google.oauth2._reauth_async.refresh_grant", autospec=True)
     @mock.patch(
         "google.auth._helpers.utcnow",
         return_value=datetime.datetime.min + _helpers.CLOCK_SKEW,
@@ -68,6 +68,7 @@
         token = "token"
         expiry = _helpers.utcnow() + datetime.timedelta(seconds=500)
         grant_response = {"id_token": mock.sentinel.id_token}
+        rapt_token = "rapt_token"
         refresh_grant.return_value = (
             # Access token
             token,
@@ -77,6 +78,8 @@
             expiry,
             # Extra data
             grant_response,
+            # Rapt token
+            rapt_token,
         )
 
         request = mock.AsyncMock(spec=["transport.Request"])
@@ -93,12 +96,14 @@
             self.CLIENT_ID,
             self.CLIENT_SECRET,
             None,
+            None,
         )
 
         # Check that the credentials have the token and expiry
         assert creds.token == token
         assert creds.expiry == expiry
         assert creds.id_token == mock.sentinel.id_token
+        assert creds.rapt_token == rapt_token
 
         # Check that the credentials are valid (have a token and are not
         # expired)
@@ -114,7 +119,7 @@
 
         request.assert_not_called()
 
-    @mock.patch("google.oauth2._client_async.refresh_grant", autospec=True)
+    @mock.patch("google.oauth2._reauth_async.refresh_grant", autospec=True)
     @mock.patch(
         "google.auth._helpers.utcnow",
         return_value=datetime.datetime.min + _helpers.CLOCK_SKEW,
@@ -127,6 +132,7 @@
         token = "token"
         expiry = _helpers.utcnow() + datetime.timedelta(seconds=500)
         grant_response = {"id_token": mock.sentinel.id_token}
+        rapt_token = "rapt_token"
         refresh_grant.return_value = (
             # Access token
             token,
@@ -136,6 +142,8 @@
             expiry,
             # Extra data
             grant_response,
+            # Rapt token
+            rapt_token,
         )
 
         request = mock.AsyncMock(spec=["transport.Request"])
@@ -146,6 +154,7 @@
             client_id=self.CLIENT_ID,
             client_secret=self.CLIENT_SECRET,
             scopes=scopes,
+            rapt_token="old_rapt_token",
         )
 
         # Refresh credentials
@@ -159,6 +168,7 @@
             self.CLIENT_ID,
             self.CLIENT_SECRET,
             scopes,
+            "old_rapt_token",
         )
 
         # Check that the credentials have the token and expiry
@@ -166,12 +176,13 @@
         assert creds.expiry == expiry
         assert creds.id_token == mock.sentinel.id_token
         assert creds.has_scopes(scopes)
+        assert creds.rapt_token == rapt_token
 
         # Check that the credentials are valid (have a token and are not
         # expired.)
         assert creds.valid
 
-    @mock.patch("google.oauth2._client_async.refresh_grant", autospec=True)
+    @mock.patch("google.oauth2._reauth_async.refresh_grant", autospec=True)
     @mock.patch(
         "google.auth._helpers.utcnow",
         return_value=datetime.datetime.min + _helpers.CLOCK_SKEW,
@@ -183,10 +194,8 @@
         scopes = ["email", "profile"]
         token = "token"
         expiry = _helpers.utcnow() + datetime.timedelta(seconds=500)
-        grant_response = {
-            "id_token": mock.sentinel.id_token,
-            "scopes": " ".join(scopes),
-        }
+        grant_response = {"id_token": mock.sentinel.id_token, "scope": " ".join(scopes)}
+        rapt_token = "rapt_token"
         refresh_grant.return_value = (
             # Access token
             token,
@@ -196,6 +205,8 @@
             expiry,
             # Extra data
             grant_response,
+            # Rapt token
+            rapt_token,
         )
 
         request = mock.AsyncMock(spec=["transport.Request"])
@@ -219,6 +230,7 @@
             self.CLIENT_ID,
             self.CLIENT_SECRET,
             scopes,
+            None,
         )
 
         # Check that the credentials have the token and expiry
@@ -226,12 +238,13 @@
         assert creds.expiry == expiry
         assert creds.id_token == mock.sentinel.id_token
         assert creds.has_scopes(scopes)
+        assert creds.rapt_token == rapt_token
 
         # Check that the credentials are valid (have a token and are not
         # expired.)
         assert creds.valid
 
-    @mock.patch("google.oauth2._client_async.refresh_grant", autospec=True)
+    @mock.patch("google.oauth2._reauth_async.refresh_grant", autospec=True)
     @mock.patch(
         "google.auth._helpers.utcnow",
         return_value=datetime.datetime.min + _helpers.CLOCK_SKEW,
@@ -246,8 +259,9 @@
         expiry = _helpers.utcnow() + datetime.timedelta(seconds=500)
         grant_response = {
             "id_token": mock.sentinel.id_token,
-            "scopes": " ".join(scopes_returned),
+            "scope": " ".join(scopes_returned),
         }
+        rapt_token = "rapt_token"
         refresh_grant.return_value = (
             # Access token
             token,
@@ -257,6 +271,8 @@
             expiry,
             # Extra data
             grant_response,
+            # Rapt token
+            rapt_token,
         )
 
         request = mock.AsyncMock(spec=["transport.Request"])
@@ -267,6 +283,7 @@
             client_id=self.CLIENT_ID,
             client_secret=self.CLIENT_SECRET,
             scopes=scopes,
+            rapt_token=None,
         )
 
         # Refresh credentials
@@ -283,6 +300,7 @@
             self.CLIENT_ID,
             self.CLIENT_SECRET,
             scopes,
+            None,
         )
 
         # Check that the credentials have the token and expiry
diff --git a/tests_async/oauth2/test_reauth_async.py b/tests_async/oauth2/test_reauth_async.py
new file mode 100644
index 0000000..f144d89
--- /dev/null
+++ b/tests_async/oauth2/test_reauth_async.py
@@ -0,0 +1,328 @@
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+
+import mock
+import pytest
+
+from google.auth import exceptions
+from google.oauth2 import _reauth_async
+from google.oauth2 import reauth
+
+
+MOCK_REQUEST = mock.AsyncMock(spec=["transport.Request"])
+CHALLENGES_RESPONSE_TEMPLATE = {
+    "status": "CHALLENGE_REQUIRED",
+    "sessionId": "123",
+    "challenges": [
+        {
+            "status": "READY",
+            "challengeId": 1,
+            "challengeType": "PASSWORD",
+            "securityKey": {},
+        }
+    ],
+}
+CHALLENGES_RESPONSE_AUTHENTICATED = {
+    "status": "AUTHENTICATED",
+    "sessionId": "123",
+    "encodedProofOfReauthToken": "new_rapt_token",
+}
+
+
+class MockChallenge(object):
+    def __init__(self, name, locally_eligible, challenge_input):
+        self.name = name
+        self.is_locally_eligible = locally_eligible
+        self.challenge_input = challenge_input
+
+    def obtain_challenge_input(self, metadata):
+        return self.challenge_input
+
+
[email protected]
+async def test__get_challenges():
+    with mock.patch(
+        "google.oauth2._client_async._token_endpoint_request"
+    ) as mock_token_endpoint_request:
+        await _reauth_async._get_challenges(MOCK_REQUEST, ["SAML"], "token")
+        mock_token_endpoint_request.assert_called_with(
+            MOCK_REQUEST,
+            reauth._REAUTH_API + ":start",
+            {"supportedChallengeTypes": ["SAML"]},
+            access_token="token",
+            use_json=True,
+        )
+
+
[email protected]
+async def test__get_challenges_with_scopes():
+    with mock.patch(
+        "google.oauth2._client_async._token_endpoint_request"
+    ) as mock_token_endpoint_request:
+        await _reauth_async._get_challenges(
+            MOCK_REQUEST, ["SAML"], "token", requested_scopes=["scope"]
+        )
+        mock_token_endpoint_request.assert_called_with(
+            MOCK_REQUEST,
+            reauth._REAUTH_API + ":start",
+            {
+                "supportedChallengeTypes": ["SAML"],
+                "oauthScopesForDomainPolicyLookup": ["scope"],
+            },
+            access_token="token",
+            use_json=True,
+        )
+
+
[email protected]
+async def test__send_challenge_result():
+    with mock.patch(
+        "google.oauth2._client_async._token_endpoint_request"
+    ) as mock_token_endpoint_request:
+        await _reauth_async._send_challenge_result(
+            MOCK_REQUEST, "123", "1", {"credential": "password"}, "token"
+        )
+        mock_token_endpoint_request.assert_called_with(
+            MOCK_REQUEST,
+            reauth._REAUTH_API + "/123:continue",
+            {
+                "sessionId": "123",
+                "challengeId": "1",
+                "action": "RESPOND",
+                "proposalResponse": {"credential": "password"},
+            },
+            access_token="token",
+            use_json=True,
+        )
+
+
[email protected]
+async def test__run_next_challenge_not_ready():
+    challenges_response = copy.deepcopy(CHALLENGES_RESPONSE_TEMPLATE)
+    challenges_response["challenges"][0]["status"] = "STATUS_UNSPECIFIED"
+    assert (
+        await _reauth_async._run_next_challenge(
+            challenges_response, MOCK_REQUEST, "token"
+        )
+        is None
+    )
+
+
[email protected]
+async def test__run_next_challenge_not_supported():
+    challenges_response = copy.deepcopy(CHALLENGES_RESPONSE_TEMPLATE)
+    challenges_response["challenges"][0]["challengeType"] = "CHALLENGE_TYPE_UNSPECIFIED"
+    with pytest.raises(exceptions.ReauthFailError) as excinfo:
+        await _reauth_async._run_next_challenge(
+            challenges_response, MOCK_REQUEST, "token"
+        )
+    assert excinfo.match(r"Unsupported challenge type CHALLENGE_TYPE_UNSPECIFIED")
+
+
[email protected]
+async def test__run_next_challenge_not_locally_eligible():
+    mock_challenge = MockChallenge("PASSWORD", False, "challenge_input")
+    with mock.patch(
+        "google.oauth2.challenges.AVAILABLE_CHALLENGES", {"PASSWORD": mock_challenge}
+    ):
+        with pytest.raises(exceptions.ReauthFailError) as excinfo:
+            await _reauth_async._run_next_challenge(
+                CHALLENGES_RESPONSE_TEMPLATE, MOCK_REQUEST, "token"
+            )
+        assert excinfo.match(r"Challenge PASSWORD is not locally eligible")
+
+
[email protected]
+async def test__run_next_challenge_no_challenge_input():
+    mock_challenge = MockChallenge("PASSWORD", True, None)
+    with mock.patch(
+        "google.oauth2.challenges.AVAILABLE_CHALLENGES", {"PASSWORD": mock_challenge}
+    ):
+        assert (
+            await _reauth_async._run_next_challenge(
+                CHALLENGES_RESPONSE_TEMPLATE, MOCK_REQUEST, "token"
+            )
+            is None
+        )
+
+
[email protected]
+async def test__run_next_challenge_success():
+    mock_challenge = MockChallenge("PASSWORD", True, {"credential": "password"})
+    with mock.patch(
+        "google.oauth2.challenges.AVAILABLE_CHALLENGES", {"PASSWORD": mock_challenge}
+    ):
+        with mock.patch(
+            "google.oauth2._reauth_async._send_challenge_result"
+        ) as mock_send_challenge_result:
+            await _reauth_async._run_next_challenge(
+                CHALLENGES_RESPONSE_TEMPLATE, MOCK_REQUEST, "token"
+            )
+            mock_send_challenge_result.assert_called_with(
+                MOCK_REQUEST, "123", 1, {"credential": "password"}, "token"
+            )
+
+
[email protected]
+async def test__obtain_rapt_authenticated():
+    with mock.patch(
+        "google.oauth2._reauth_async._get_challenges",
+        return_value=CHALLENGES_RESPONSE_AUTHENTICATED,
+    ):
+        new_rapt_token = await _reauth_async._obtain_rapt(MOCK_REQUEST, "token", None)
+        assert new_rapt_token == "new_rapt_token"
+
+
[email protected]
+async def test__obtain_rapt_authenticated_after_run_next_challenge():
+    with mock.patch(
+        "google.oauth2._reauth_async._get_challenges",
+        return_value=CHALLENGES_RESPONSE_TEMPLATE,
+    ):
+        with mock.patch(
+            "google.oauth2._reauth_async._run_next_challenge",
+            side_effect=[
+                CHALLENGES_RESPONSE_TEMPLATE,
+                CHALLENGES_RESPONSE_AUTHENTICATED,
+            ],
+        ):
+            with mock.patch("google.oauth2.reauth.is_interactive", return_value=True):
+                assert (
+                    await _reauth_async._obtain_rapt(MOCK_REQUEST, "token", None)
+                    == "new_rapt_token"
+                )
+
+
[email protected]
+async def test__obtain_rapt_unsupported_status():
+    challenges_response = copy.deepcopy(CHALLENGES_RESPONSE_TEMPLATE)
+    challenges_response["status"] = "STATUS_UNSPECIFIED"
+    with mock.patch(
+        "google.oauth2._reauth_async._get_challenges", return_value=challenges_response
+    ):
+        with pytest.raises(exceptions.ReauthFailError) as excinfo:
+            await _reauth_async._obtain_rapt(MOCK_REQUEST, "token", None)
+        assert excinfo.match(r"API error: STATUS_UNSPECIFIED")
+
+
[email protected]
+async def test__obtain_rapt_not_interactive():
+    with mock.patch(
+        "google.oauth2._reauth_async._get_challenges",
+        return_value=CHALLENGES_RESPONSE_TEMPLATE,
+    ):
+        with mock.patch("google.oauth2.reauth.is_interactive", return_value=False):
+            with pytest.raises(exceptions.ReauthFailError) as excinfo:
+                await _reauth_async._obtain_rapt(MOCK_REQUEST, "token", None)
+            assert excinfo.match(r"not in an interactive session")
+
+
[email protected]
+async def test__obtain_rapt_not_authenticated():
+    with mock.patch(
+        "google.oauth2._reauth_async._get_challenges",
+        return_value=CHALLENGES_RESPONSE_TEMPLATE,
+    ):
+        with mock.patch("google.oauth2.reauth.RUN_CHALLENGE_RETRY_LIMIT", 0):
+            with pytest.raises(exceptions.ReauthFailError) as excinfo:
+                await _reauth_async._obtain_rapt(MOCK_REQUEST, "token", None)
+            assert excinfo.match(r"Reauthentication failed")
+
+
[email protected]
+async def test_get_rapt_token():
+    with mock.patch(
+        "google.oauth2._client_async.refresh_grant",
+        return_value=("token", None, None, None),
+    ) as mock_refresh_grant:
+        with mock.patch(
+            "google.oauth2._reauth_async._obtain_rapt", return_value="new_rapt_token"
+        ) as mock_obtain_rapt:
+            assert (
+                await _reauth_async.get_rapt_token(
+                    MOCK_REQUEST,
+                    "client_id",
+                    "client_secret",
+                    "refresh_token",
+                    "token_uri",
+                )
+                == "new_rapt_token"
+            )
+            mock_refresh_grant.assert_called_with(
+                request=MOCK_REQUEST,
+                client_id="client_id",
+                client_secret="client_secret",
+                refresh_token="refresh_token",
+                token_uri="token_uri",
+                scopes=[reauth._REAUTH_SCOPE],
+            )
+            mock_obtain_rapt.assert_called_with(
+                MOCK_REQUEST, "token", requested_scopes=None
+            )
+
+
[email protected]
+async def test_refresh_grant_failed():
+    with mock.patch(
+        "google.oauth2._client_async._token_endpoint_request_no_throw"
+    ) as mock_token_request:
+        mock_token_request.return_value = (False, {"error": "Bad request"})
+        with pytest.raises(exceptions.RefreshError) as excinfo:
+            await _reauth_async.refresh_grant(
+                MOCK_REQUEST,
+                "token_uri",
+                "refresh_token",
+                "client_id",
+                "client_secret",
+                scopes=["foo", "bar"],
+                rapt_token="rapt_token",
+            )
+        assert excinfo.match(r"Bad request")
+        mock_token_request.assert_called_with(
+            MOCK_REQUEST,
+            "token_uri",
+            {
+                "grant_type": "refresh_token",
+                "client_id": "client_id",
+                "client_secret": "client_secret",
+                "refresh_token": "refresh_token",
+                "scope": "foo bar",
+                "rapt": "rapt_token",
+            },
+        )
+
+
[email protected]
+async def test_refresh_grant_success():
+    with mock.patch(
+        "google.oauth2._client_async._token_endpoint_request_no_throw"
+    ) as mock_token_request:
+        mock_token_request.side_effect = [
+            (False, {"error": "invalid_grant", "error_subtype": "rapt_required"}),
+            (True, {"access_token": "access_token"}),
+        ]
+        with mock.patch(
+            "google.oauth2._reauth_async.get_rapt_token", return_value="new_rapt_token"
+        ):
+            assert await _reauth_async.refresh_grant(
+                MOCK_REQUEST, "token_uri", "refresh_token", "client_id", "client_secret"
+            ) == (
+                "access_token",
+                "refresh_token",
+                None,
+                {"access_token": "access_token"},
+                "new_rapt_token",
+            )