diff --git a/requests_oauthlib/oauth2_session.py b/requests_oauthlib/oauth2_session.py index bfec8e8d..0b91b23b 100644 --- a/requests_oauthlib/oauth2_session.py +++ b/requests_oauthlib/oauth2_session.py @@ -7,6 +7,7 @@ from oauthlib.oauth2 import LegacyApplicationClient from oauthlib.oauth2 import TokenExpiredError, is_secure_transport import requests +from requests.exceptions import HTTPError log = logging.getLogger(__name__) @@ -230,9 +231,9 @@ def fetch_token( `auth` tuple. If the value is `None`, it will be omitted from the request, however if the value is an empty string, an empty string will be sent. - :param cert: Client certificate to send for OAuth 2.0 Mutual-TLS Client - Authentication (draft-ietf-oauth-mtls). Can either be the - path of a file containing the private key and certificate or + :param cert: Client certificate to send for OAuth 2.0 Mutual-TLS Client + Authentication (draft-ietf-oauth-mtls). Can either be the + path of a file containing the private key and certificate or a tuple of two filenames for certificate and key. :param kwargs: Extra parameters to include in the token request. :return: A token dict @@ -363,6 +364,8 @@ def fetch_token( log.debug("Invoking hook %s.", hook) r = hook(r) + self._raise_for_5xx(response=r) + self._client.parse_request_body_response(r.text, scope=self.scope) self.token = self._client.token log.debug("Obtained token %s.", self.token) @@ -449,6 +452,8 @@ def refresh_token( log.debug("Invoking hook %s.", hook) r = hook(r) + self._raise_for_5xx(response=r) + self.token = self._client.parse_request_body_response(r.text, scope=self.scope) if not "refresh_token" in self.token: log.debug("No new refresh token given. Re-using old.") @@ -538,3 +543,40 @@ def register_compliance_hook(self, hook_type, hook): "Hook type %s is not in %s.", hook_type, self.compliance_hook ) self.compliance_hook[hook_type].add(hook) + + def _raise_for_5xx(self, response): + # type: (requests.models.Response) -> None + """ + Raise requests.HTTPError if response is an HTTP 5XX error. + + Just like the existing Response.raise_for_status() but ignores 4XX + errors. + + :param response: HTTP response object from requests + Raises :class:`requests.exceptions.HTTPError`, if a 5XX error occurred. + """ + http_error_msg = "" + if isinstance(response.reason, bytes): + # We attempt to decode utf-8 first because some servers + # choose to localize their reason strings. If the string + # isn't utf-8, we fall back to iso-8859-1 for all other + # encodings. (See psf/requests PR #3538) + try: + reason = response.reason.decode("utf-8") + except UnicodeDecodeError: + reason = response.reason.decode("iso-8859-1") + else: + reason = response.reason + + if 400 <= response.status_code < 500: + pass # ignored + + elif 500 <= response.status_code < 600: + http_error_msg = "%s Server Error: %s for url: %s" % ( + response.status_code, + reason, + response.url, + ) + + if http_error_msg: + raise HTTPError(http_error_msg, response=response) diff --git a/tests/test_oauth2_session.py b/tests/test_oauth2_session.py index cfc62368..282bf2f2 100644 --- a/tests/test_oauth2_session.py +++ b/tests/test_oauth2_session.py @@ -22,15 +22,17 @@ import requests from requests.auth import _basic_auth_str +from requests.exceptions import HTTPError fake_time = time.time() CODE = "asdf345xdf" -def fake_token(token): +def fake_token(token, status_code=200): def fake_send(r, **kwargs): resp = mock.MagicMock() + resp.status_code = status_code resp.text = json.dumps(token) return resp @@ -70,6 +72,7 @@ def verifier(r, **kwargs): auth_header = r.headers.get(str("Authorization"), None) self.assertEqual(auth_header, token) resp = mock.MagicMock() + resp.status_code = 200 resp.cookes = [] return resp @@ -89,6 +92,7 @@ def verifier(r, **kwargs): self.assertEqual(cert, kwargs["cert"]) self.assertIn("client_id=" + self.client_id, r.body) resp = mock.MagicMock() + resp.status_code = 200 resp.text = json.dumps(self.token) return resp @@ -130,10 +134,11 @@ def test_refresh_token_request(self): self.expired_token["expires_in"] = "-1" del self.expired_token["expires_at"] - def fake_refresh(r, **kwargs): + def fake_refresh(r, status_code=200, **kwargs): if "/refresh" in r.url: self.assertNotIn("Authorization", r.headers) resp = mock.MagicMock() + resp.status_code = status_code resp.text = json.dumps(self.token) return resp @@ -166,6 +171,19 @@ def token_updater(token): sess.send = fake_refresh sess.get("https://i.b") + # test 5xx error handler + for client in self.clients: + sess = OAuth2Session( + client=client, + token=self.expired_token, + auto_refresh_url="https://i.b/refresh", + token_updater=token_updater, + ) + sess.send = lambda r, **kwargs: fake_refresh( + r=r, status_code=503, kwargs=kwargs + ) + self.assertRaises(HTTPError, sess.get, "https://i.b") + def fake_refresh_with_auth(r, **kwargs): if "/refresh" in r.url: self.assertIn("Authorization", r.headers) @@ -177,6 +195,7 @@ def fake_refresh_with_auth(r, **kwargs): content = "Basic {encoded}".format(encoded=encoded.decode("latin1")) self.assertEqual(r.headers["Authorization"], content) resp = mock.MagicMock() + resp.status_code = 200 resp.text = json.dumps(self.token) return resp @@ -251,6 +270,23 @@ def test_fetch_token(self): else: self.assertRaises(OAuth2Error, sess.fetch_token, url) + # test 5xx error responses + error = {"error": "server error!"} + for client in self.clients: + sess = OAuth2Session(client=client, token=self.token) + sess.send = fake_token(error, status_code=500) + if isinstance(client, LegacyApplicationClient): + # this client requires a username+password + self.assertRaises( + HTTPError, + sess.fetch_token, + url, + username="username1", + password="password1", + ) + else: + self.assertRaises(HTTPError, sess.fetch_token, url) + # there are different scenarios in which the `client_id` can be specified # reference `oauthlib.tests.oauth2.rfc6749.clients.test_web_application.WebApplicationClientTest.test_prepare_request_body` # this only needs to test WebApplicationClient @@ -263,6 +299,7 @@ def test_fetch_token(self): def fake_token_history(token): def fake_send(r, **kwargs): resp = mock.MagicMock() + resp.status_code = 200 resp.text = json.dumps(token) _fetch_history.append( (r.url, r.body, r.headers.get("Authorization", None)) @@ -470,6 +507,7 @@ def test_authorized_true(self): def fake_token(token): def fake_send(r, **kwargs): resp = mock.MagicMock() + resp.status_code = 200 resp.text = json.dumps(token) return resp @@ -497,6 +535,31 @@ def fake_send(r, **kwargs): sess.fetch_token(url) self.assertTrue(sess.authorized) + def test_raise_for_5xx(self): + for reason_bytes in [ + b"\xa1An error occurred!", # iso-8859-i + b"\xc2\xa1An error occurred!", # utf-8 + ]: + fake_resp = mock.MagicMock() + fake_resp.status_code = 504 + fake_resp.reason = reason_bytes + reason_unicode = "\u00A1An error occurred!" + fake_resp.url = "https://example.com/token" + expected = ( + "504 Server Error: " + reason_unicode + " for url: " + fake_resp.url + ) + + # Make sure our expected unicode string literal is indeed unicode + # in both py2 and py3 + self.assertEqual(reason_unicode[0].encode("utf-8"), b"\xc2\xa1") + + sess = OAuth2Session("test-id") + + with self.assertRaises(HTTPError) as cm: + sess._raise_for_5xx(fake_resp) + + self.assertEqual(cm.exception.args[0], expected) + class OAuth2SessionNetrcTest(OAuth2SessionTest): """Ensure that there is no magic auth handling.