diff --git a/requests_oauthlib/__init__.py b/requests_oauthlib/__init__.py index 1bb919ee..1f9418ae 100644 --- a/requests_oauthlib/__init__.py +++ b/requests_oauthlib/__init__.py @@ -1,5 +1,6 @@ import logging +from .exc import TokenRequestDenied from .oauth1_auth import OAuth1 from .oauth1_session import OAuth1Session from .oauth2_auth import OAuth2 diff --git a/requests_oauthlib/exc.py b/requests_oauthlib/exc.py new file mode 100644 index 00000000..7be01912 --- /dev/null +++ b/requests_oauthlib/exc.py @@ -0,0 +1,10 @@ +class TokenRequestDenied(ValueError): + + def __init__(self, message, response): + super(TokenRequestDenied, self).__init__(message) + self.response = response + + @property + def status_code(self): + """For backwards-compatibility purposes""" + return self.response.status_code diff --git a/requests_oauthlib/oauth1_session.py b/requests_oauthlib/oauth1_session.py index 53b7b6d1..f6b94421 100644 --- a/requests_oauthlib/oauth1_session.py +++ b/requests_oauthlib/oauth1_session.py @@ -14,6 +14,7 @@ ) import requests +from . import exc from . import OAuth1 @@ -29,18 +30,6 @@ def urldecode(body): return json.loads(body) -class TokenRequestDenied(ValueError): - - def __init__(self, message, response): - super(TokenRequestDenied, self).__init__(message) - self.response = response - - @property - def status_code(self): - """For backwards-compatibility purposes""" - return self.response.status_code - - class TokenMissing(ValueError): def __init__(self, message, response): super(TokenMissing, self).__init__(message) @@ -365,7 +354,7 @@ def _fetch_token(self, url, **request_kwargs): if r.status_code >= 400: error = "Token request failed with code %s, response was '%s'." - raise TokenRequestDenied(error % (r.status_code, r.text), r) + raise exc.TokenRequestDenied(error % (r.status_code, r.text), r) log.debug('Decoding token from response "%s"', r.text) try: diff --git a/requests_oauthlib/oauth2_session.py b/requests_oauthlib/oauth2_session.py index e5ad72c2..82f57bb6 100644 --- a/requests_oauthlib/oauth2_session.py +++ b/requests_oauthlib/oauth2_session.py @@ -9,6 +9,8 @@ log = logging.getLogger(__name__) +from . import exc + class TokenUpdated(Warning): def __init__(self, token): @@ -80,6 +82,7 @@ def __init__(self, client_id=None, client=None, auto_refresh_url=None, 'access_token_response': set(), 'refresh_token_response': set(), 'protected_request': set(), + 'token_request': set(), } def new_state(self): @@ -210,24 +213,16 @@ def fetch_token(self, token_url, code=None, authorization_response=None, log.debug('Encoding username, password as Basic auth credentials.') auth = requests.auth.HTTPBasicAuth(username, password) - headers = headers or { - 'Accept': 'application/json', - 'Content-Type': 'application/x-www-form-urlencoded;charset=UTF-8', - } self.token = {} - if method.upper() == 'POST': - r = self.post(token_url, data=dict(urldecode(body)), - timeout=timeout, headers=headers, auth=auth, - verify=verify, proxies=proxies) - log.debug('Prepared fetch token request body %s', body) - elif method.upper() == 'GET': - # if method is not 'POST', switch body to querystring and GET - r = self.get(token_url, params=dict(urldecode(body)), - timeout=timeout, headers=headers, auth=auth, - verify=verify, proxies=proxies) - log.debug('Prepared fetch token request querystring %s', body) - else: - raise ValueError('The method kwarg must be POST or GET.') + + r = self._auth_request(method.upper(), + token_url, + body, + timeout=timeout, + headers=headers, + auth=auth, + verify=verify, + proxies=proxies) log.debug('Request to fetch token completed with status %s.', r.status_code) @@ -286,16 +281,16 @@ def refresh_token(self, token_url, refresh_token=None, body='', auth=None, refresh_token=refresh_token, scope=self.scope, **kwargs) log.debug('Prepared refresh token request body %s', body) - if headers is None: - headers = { - 'Accept': 'application/json', - 'Content-Type': ( - 'application/x-www-form-urlencoded;charset=UTF-8' - ), - } + r = self._auth_request('POST', + token_url, + body, + auth=auth, + timeout=timeout, + headers=headers, + verify=verify, + withhold_token=True, + proxies=proxies) - r = self.post(token_url, data=dict(urldecode(body)), auth=auth, - timeout=timeout, headers=headers, verify=verify, withhold_token=True, proxies=proxies) log.debug('Request to refresh token completed with status %s.', r.status_code) log.debug('Response headers were %s and content %s.', @@ -312,6 +307,34 @@ def refresh_token(self, token_url, refresh_token=None, body='', auth=None, self.token['refresh_token'] = refresh_token return self.token + def _auth_request(self, method, url, body, **kwargs): + method = method.upper() + data = dict(urldecode(body)) + kwargs.setdefault('headers', { + 'Accept': 'application/json', + 'Content-Type': 'application/x-www-form-urlencoded;charset=UTF-8', + }) + + if method == 'POST': + kwargs['data'] = data + log.debug('Prepared fetch token request body %s', body) + elif method == 'GET': + kwargs['params'] = data + log.debug('Prepared fetch token request querystring %s', body) + else: + raise ValueError('The method kwarg must be POST or GET.') + + for hook in self.compliance_hook['token_request']: + method, url, kwargs = hook(method, url, **kwargs) + + r = self.request(method, url, **kwargs) + + if not r.ok: + error = "Token request failed with code %s, response was '%s'." + raise exc.TokenRequestDenied(error % (r.status_code, r.text), r) + + return r + def request(self, method, url, data=None, headers=None, withhold_token=False, client_id=None, client_secret=None, **kwargs): """Intercept all requests and add the OAuth 2 token if present.""" diff --git a/tests/test_oauth1_session.py b/tests/test_oauth1_session.py index 183244b2..2f4c454f 100644 --- a/tests/test_oauth1_session.py +++ b/tests/test_oauth1_session.py @@ -3,6 +3,7 @@ import unittest import sys import requests +import requests_mock from io import StringIO from oauthlib.oauth1 import SIGNATURE_TYPE_QUERY, SIGNATURE_TYPE_BODY @@ -62,6 +63,9 @@ "jPkI%2FkWMvpxtMrU3Z3KN31WQ%3D%3D" ) +TEST_URL = 'https://i.b' +TEST_TOKEN_URL = 'https://example.com/token' + class OAuth1SessionTest(unittest.TestCase): @@ -70,30 +74,32 @@ def setUp(self): if not hasattr(self, 'assertIn'): self.assertIn = lambda a, b: self.assertTrue(a in b) + self.requests_mock = requests_mock.mock() + self.requests_mock.start() + self.addCleanup(self.requests_mock.stop) + + def assert_signature(self, signature): + header = self.requests_mock.last_request.headers['Authorization'] + self.assertEqual(signature, header.decode('utf-8')) + + return header + def test_signature_types(self): - def verify_signature(getter): - def fake_send(r, **kwargs): - signature = getter(r) - if isinstance(signature, bytes_type): - signature = signature.decode('utf-8') - self.assertIn('oauth_signature', signature) - resp = mock.MagicMock(spec=requests.Response) - resp.cookies = [] - return resp - return fake_send + self.requests_mock.post(TEST_URL) header = OAuth1Session('foo') - header.send = verify_signature(lambda r: r.headers['Authorization']) - header.post('https://i.b') + header.post(TEST_URL) + self.assertIn('oauth_signature', + self.requests_mock.last_request.headers['Authorization'].decode('utf-8')) query = OAuth1Session('foo', signature_type=SIGNATURE_TYPE_QUERY) - query.send = verify_signature(lambda r: r.url) - query.post('https://i.b') + query.post(TEST_URL) + self.assertIn('oauth_signature', self.requests_mock.last_request.url) body = OAuth1Session('foo', signature_type=SIGNATURE_TYPE_BODY) headers = {'Content-Type': 'application/x-www-form-urlencoded'} - body.send = verify_signature(lambda r: r.body) - body.post('https://i.b', headers=headers, data='') + body.post(TEST_URL, headers=headers, data='') + self.assertIn('oauth_signature', self.requests_mock.last_request.text) @mock.patch('oauthlib.oauth1.rfc5849.generate_timestamp') @mock.patch('oauthlib.oauth1.rfc5849.generate_nonce') @@ -101,18 +107,19 @@ def test_signature_methods(self, generate_nonce, generate_timestamp): if not cryptography: raise unittest.SkipTest('cryptography module is required') + self.requests_mock.post(TEST_URL) generate_nonce.return_value = 'abc' generate_timestamp.return_value = '123' signature = 'OAuth oauth_nonce="abc", oauth_timestamp="123", oauth_version="1.0", oauth_signature_method="HMAC-SHA1", oauth_consumer_key="foo", oauth_signature="h2sRqLArjhlc5p3FTkuNogVHlKE%3D"' auth = OAuth1Session('foo') - auth.send = self.verify_signature(signature) - auth.post('https://i.b') + auth.post(TEST_URL) + self.assert_signature(signature) signature = 'OAuth oauth_nonce="abc", oauth_timestamp="123", oauth_version="1.0", oauth_signature_method="PLAINTEXT", oauth_consumer_key="foo", oauth_signature="%26"' auth = OAuth1Session('foo', signature_method=SIGNATURE_PLAINTEXT) - auth.send = self.verify_signature(signature) - auth.post('https://i.b') + auth.post(TEST_URL) + self.assert_signature(signature) signature = ('OAuth ' 'oauth_nonce="abc", oauth_timestamp="123", oauth_version="1.0", ' @@ -121,30 +128,33 @@ def test_signature_methods(self, generate_nonce, generate_timestamp): ).format(sig=TEST_RSA_OAUTH_SIGNATURE) auth = OAuth1Session('foo', signature_method=SIGNATURE_RSA, rsa_key=TEST_RSA_KEY) - auth.send = self.verify_signature(signature) - auth.post('https://i.b') + auth.post(TEST_URL) + self.assert_signature(signature) @mock.patch('oauthlib.oauth1.rfc5849.generate_timestamp') @mock.patch('oauthlib.oauth1.rfc5849.generate_nonce') def test_binary_upload(self, generate_nonce, generate_timestamp): + self.requests_mock.post(TEST_URL) + generate_nonce.return_value = 'abc' generate_timestamp.return_value = '123' fake_xml = StringIO('hello world') headers = {'Content-Type': 'application/xml'} signature = 'OAuth oauth_nonce="abc", oauth_timestamp="123", oauth_version="1.0", oauth_signature_method="HMAC-SHA1", oauth_consumer_key="foo", oauth_signature="h2sRqLArjhlc5p3FTkuNogVHlKE%3D"' auth = OAuth1Session('foo') - auth.send = self.verify_signature(signature) - auth.post('https://i.b', headers=headers, files=[('fake', fake_xml)]) + auth.post(TEST_URL, headers=headers, files=[('fake', fake_xml)]) + self.assert_signature(signature) @mock.patch('oauthlib.oauth1.rfc5849.generate_timestamp') @mock.patch('oauthlib.oauth1.rfc5849.generate_nonce') def test_nonascii(self, generate_nonce, generate_timestamp): + self.requests_mock.post('https://i.b') generate_nonce.return_value = 'abc' generate_timestamp.return_value = '123' signature = 'OAuth oauth_nonce="abc", oauth_timestamp="123", oauth_version="1.0", oauth_signature_method="HMAC-SHA1", oauth_consumer_key="foo", oauth_signature="W0haoue5IZAZoaJiYCtfqwMf8x8%3D"' auth = OAuth1Session('foo') - auth.send = self.verify_signature(signature) auth.post('https://i.b?cjk=%E5%95%A6%E5%95%A6') + self.assert_signature(signature) def test_authorization_url(self): auth = OAuth1Session('foo') @@ -165,68 +175,72 @@ def test_parse_response_url(self): def test_fetch_request_token(self): auth = OAuth1Session('foo') - auth.send = self.fake_body('oauth_token=foo') - resp = auth.fetch_request_token('https://example.com/token') + self.requests_mock.post(TEST_TOKEN_URL, text='oauth_token=foo') + resp = auth.fetch_request_token(TEST_TOKEN_URL) self.assertEqual(resp['oauth_token'], 'foo') for k, v in resp.items(): self.assertTrue(isinstance(k, unicode_type)) self.assertTrue(isinstance(v, unicode_type)) + self.assertTrue(self.requests_mock.called_once) def test_fetch_request_token_with_optional_arguments(self): auth = OAuth1Session('foo') - auth.send = self.fake_body('oauth_token=foo') - resp = auth.fetch_request_token('https://example.com/token', - verify=False, stream=True) + self.requests_mock.post(TEST_TOKEN_URL, text='oauth_token=foo') + resp = auth.fetch_request_token(TEST_TOKEN_URL, verify=False, stream=True) self.assertEqual(resp['oauth_token'], 'foo') for k, v in resp.items(): self.assertTrue(isinstance(k, unicode_type)) self.assertTrue(isinstance(v, unicode_type)) + self.assertTrue(self.requests_mock.called_once) + self.assertFalse(self.requests_mock.last_request.verify) def test_fetch_access_token(self): auth = OAuth1Session('foo', verifier='bar') - auth.send = self.fake_body('oauth_token=foo') - resp = auth.fetch_access_token('https://example.com/token') + self.requests_mock.post(TEST_TOKEN_URL, text='oauth_token=foo') + resp = auth.fetch_access_token(TEST_TOKEN_URL) self.assertEqual(resp['oauth_token'], 'foo') for k, v in resp.items(): self.assertTrue(isinstance(k, unicode_type)) self.assertTrue(isinstance(v, unicode_type)) + self.assertTrue(self.requests_mock.called_once) def test_fetch_access_token_with_optional_arguments(self): auth = OAuth1Session('foo', verifier='bar') - auth.send = self.fake_body('oauth_token=foo') - resp = auth.fetch_access_token('https://example.com/token', - verify=False, stream=True) + self.requests_mock.post(TEST_TOKEN_URL, text='oauth_token=foo') + resp = auth.fetch_access_token(TEST_TOKEN_URL, verify=False, stream=True) self.assertEqual(resp['oauth_token'], 'foo') for k, v in resp.items(): self.assertTrue(isinstance(k, unicode_type)) self.assertTrue(isinstance(v, unicode_type)) + self.assertTrue(self.requests_mock.called_once) + self.assertFalse(self.requests_mock.last_request.verify) def _test_fetch_access_token_raises_error(self, auth): """Assert that an error is being raised whenever there's no verifier passed in to the client. """ - auth.send = self.fake_body('oauth_token=foo') + self.requests_mock.post(TEST_TOKEN_URL, text='oauth_token=foo') # Use a try-except block so that we can assert on the exception message # being raised and also keep the Python2.6 compatibility where # assertRaises is not a context manager. try: - auth.fetch_access_token('https://example.com/token') + auth.fetch_access_token(TEST_TOKEN_URL) except ValueError as exc: self.assertEqual('No client verifier has been set.', str(exc)) def test_fetch_token_invalid_response(self): auth = OAuth1Session('foo') - auth.send = self.fake_body('not valid urlencoded response!') - self.assertRaises(ValueError, auth.fetch_request_token, - 'https://example.com/token') + self.requests_mock.post(TEST_TOKEN_URL, + text='not valid urlencoded response!') + self.assertRaises(ValueError, auth.fetch_request_token, TEST_TOKEN_URL) for code in (400, 401, 403): - auth.send = self.fake_body('valid=response', code) + self.requests_mock.post(TEST_TOKEN_URL, status_code=code) # use try/catch rather than self.assertRaises, so we can # assert on the properties of the exception try: - auth.fetch_request_token('https://example.com/token') + auth.fetch_request_token(TEST_TOKEN_URL) except ValueError as err: self.assertEqual(err.status_code, code) self.assertTrue(isinstance(err.response, requests.Response)) @@ -289,13 +303,13 @@ def test_authorized_false_rsa(self): ).format(sig=TEST_RSA_OAUTH_SIGNATURE) sess = OAuth1Session('foo', signature_method=SIGNATURE_RSA, rsa_key=TEST_RSA_KEY) - sess.send = self.verify_signature(signature) self.assertFalse(sess.authorized) def test_authorized_true(self): + self.requests_mock.post(TEST_TOKEN_URL, + text='oauth_token=foo&oauth_token_secret=bar') sess = OAuth1Session('key', 'secret', verifier='bar') - sess.send = self.fake_body('oauth_token=foo&oauth_token_secret=bar') - sess.fetch_access_token('https://example.com/token') + sess.fetch_access_token(TEST_TOKEN_URL) self.assertTrue(sess.authorized) @mock.patch('oauthlib.oauth1.rfc5849.generate_timestamp') @@ -313,26 +327,7 @@ def test_authorized_true_rsa(self, generate_nonce, generate_timestamp): ).format(sig=TEST_RSA_OAUTH_SIGNATURE) sess = OAuth1Session('key', 'secret', signature_method=SIGNATURE_RSA, rsa_key=TEST_RSA_KEY, verifier='bar') - sess.send = self.fake_body('oauth_token=foo&oauth_token_secret=bar') - sess.fetch_access_token('https://example.com/token') + self.requests_mock.post(TEST_TOKEN_URL, + text='oauth_token=foo&oauth_token_secret=bar') + sess.fetch_access_token(TEST_TOKEN_URL) self.assertTrue(sess.authorized) - - def verify_signature(self, signature): - def fake_send(r, **kwargs): - auth_header = r.headers['Authorization'] - if isinstance(auth_header, bytes_type): - auth_header = auth_header.decode('utf-8') - self.assertEqual(auth_header, signature) - resp = mock.MagicMock(spec=requests.Response) - resp.cookies = [] - return resp - return fake_send - - def fake_body(self, body, status_code=200): - def fake_send(r, **kwargs): - resp = mock.MagicMock(spec=requests.Response) - resp.cookies = [] - resp.text = body - resp.status_code = status_code - return resp - return fake_send diff --git a/tests/test_oauth2_session.py b/tests/test_oauth2_session.py index e5892cab..011256ea 100644 --- a/tests/test_oauth2_session.py +++ b/tests/test_oauth2_session.py @@ -11,21 +11,13 @@ from oauthlib.oauth2 import MismatchingStateError from oauthlib.oauth2 import WebApplicationClient, MobileApplicationClient from oauthlib.oauth2 import LegacyApplicationClient, BackendApplicationClient -from requests_oauthlib import OAuth2Session, TokenUpdated +from requests_oauthlib import OAuth2Session, TokenUpdated, TokenRequestDenied +import requests_mock fake_time = time.time() - -def fake_token(token): - def fake_send(r, **kwargs): - resp = mock.MagicMock() - resp.text = json.dumps(token) - return resp - return fake_send - - class OAuth2SessionTest(TestCase): def setUp(self): @@ -48,20 +40,24 @@ def setUp(self): ] self.all_clients = self.clients + [MobileApplicationClient(self.client_id)] - def test_add_token(self): - token = 'Bearer ' + self.token['access_token'] + self.requests_mock = requests_mock.mock() + self.requests_mock.start() + self.addCleanup(self.requests_mock.stop) - def verifier(r, **kwargs): - auth_header = r.headers.get(str('Authorization'), None) - self.assertEqual(auth_header, token) - resp = mock.MagicMock() - resp.cookes = [] - return resp + def test_add_token(self): + self.requests_mock.get('https://i.b', text='Ok') for client in self.all_clients: auth = OAuth2Session(client=client, token=self.token) - auth.send = verifier - auth.get('https://i.b') + resp = auth.get('https://i.b') + self.assertEqual(200, resp.status_code) + + self.assertEqual(len(self.all_clients), + len(self.requests_mock.request_history)) + + token = 'Bearer ' + self.token['access_token'] + for r in self.requests_mock.request_history: + self.assertEqual(token, r.headers.get(str('Authorization'), None)) def test_authorization_url(self): url = 'https://example.com/authorize?foo=bar' @@ -81,58 +77,86 @@ def test_authorization_url(self): self.assertIn('response_type=token', auth_url) @mock.patch("time.time", new=lambda: fake_time) - def test_refresh_token_request(self): + def test_refresh_token_request_no_refresh(self): self.expired_token = dict(self.token) self.expired_token['expires_in'] = '-1' del self.expired_token['expires_at'] - def fake_refresh(r, **kwargs): - if "/refresh" in r.url: - self.assertNotIn("Authorization", r.headers) - resp = mock.MagicMock() - resp.text = json.dumps(self.token) - return resp - # No auto refresh setup for client in self.clients: auth = OAuth2Session(client=client, token=self.expired_token) self.assertRaises(TokenExpiredError, auth.get, 'https://i.b') + self.assertFalse(self.requests_mock.called) + + @mock.patch("time.time", new=lambda: fake_time) + def test_refresh_token_request_refresh_no_update(self): + self.expired_token = dict(self.token) + self.expired_token['expires_in'] = '-1' + del self.expired_token['expires_at'] + + m1 = self.requests_mock.get('https://i.b') + m2 = self.requests_mock.post('https://i.b/refresh', json=self.token) + # Auto refresh but no auto update for client in self.clients: auth = OAuth2Session(client=client, token=self.expired_token, auto_refresh_url='https://i.b/refresh') - auth.send = fake_refresh self.assertRaises(TokenUpdated, auth.get, 'https://i.b') - # Auto refresh and auto update - def token_updater(token): - self.assertEqual(token, self.token) + self.assertFalse(m1.called) + self.assertEquals(len(self.clients), m2.call_count) + + @mock.patch("time.time", new=lambda: fake_time) + def test_refresh_token_request_refresh_and_update(self): + self.expired_token = dict(self.token) + self.expired_token['expires_in'] = '-1' + del self.expired_token['expires_at'] + + m1 = self.requests_mock.get('https://i.b') + m2 = self.requests_mock.post('https://i.b/refresh', json=self.token) + + token_updater = mock.MagicMock() for client in self.clients: auth = OAuth2Session(client=client, token=self.expired_token, auto_refresh_url='https://i.b/refresh', token_updater=token_updater) - auth.send = fake_refresh - auth.get('https://i.b') - - def fake_refresh_with_auth(r, **kwargs): - if "/refresh" in r.url: - self.assertIn("Authorization", r.headers) - encoded = b64encode(b"foo:bar") - content = (b"Basic " + encoded).decode('latin1') - self.assertEqual(r.headers["Authorization"], content) - resp = mock.MagicMock() - resp.text = json.dumps(self.token) - return resp + resp = auth.get('https://i.b') + self.assertEqual(200, resp.status_code) + + self.assertEquals(len(self.clients), m1.call_count) + self.assertEquals(len(self.clients), m2.call_count) + self.assertEquals(len(self.clients), token_updater.call_count) + + @mock.patch("time.time", new=lambda: fake_time) + def test_refresh_token_request_refresh_and_update_2(self): + self.expired_token = dict(self.token) + self.expired_token['expires_in'] = '-1' + del self.expired_token['expires_at'] + + m1 = self.requests_mock.get('https://i.b') + m2 = self.requests_mock.post('https://i.b/refresh', json=self.token) + + token_updater = mock.MagicMock() for client in self.clients: auth = OAuth2Session(client=client, token=self.expired_token, auto_refresh_url='https://i.b/refresh', token_updater=token_updater) - auth.send = fake_refresh_with_auth auth.get('https://i.b', client_id='foo', client_secret='bar') + self.assertEquals(len(self.clients), m1.call_count) + self.assertEquals(len(self.clients), m2.call_count) + self.assertEquals(len(self.clients), token_updater.call_count) + + token = (b"Basic " + b64encode(b"foo:bar")).decode('latin1') + for r in m2.request_history: + self.assertEquals(token, r.headers["Authorization"]) + + for c in token_updater.call_args_list: + self.assertEqual(c, mock.call(self.token)) + @mock.patch("time.time", new=lambda: fake_time) def test_token_from_fragment(self): mobile = MobileApplicationClient(self.client_id) @@ -141,20 +165,27 @@ def test_token_from_fragment(self): self.assertEqual(auth.token_from_fragment(response_url), self.token) @mock.patch("time.time", new=lambda: fake_time) - def test_fetch_token(self): + def test_fetch_token_good(self): url = 'https://example.com/token' + self.requests_mock.post(url, json=self.token) for client in self.clients: auth = OAuth2Session(client=client, token=self.token) - auth.send = fake_token(self.token) self.assertEqual(auth.fetch_token(url), self.token) - error = {'error': 'invalid_request'} + self.assertEqual(len(self.clients), self.requests_mock.call_count) + + @mock.patch("time.time", new=lambda: fake_time) + def test_fetch_token_invalid(self): + url = 'https://example.com/token' + self.requests_mock.post(url, json={'error': 'invalid_request'}) + for client in self.clients: auth = OAuth2Session(client=client, token=self.token) - auth.send = fake_token(error) self.assertRaises(OAuth2Error, auth.fetch_token, url) + self.assertEqual(len(self.clients), self.requests_mock.call_count) + def test_cleans_previous_token_before_fetching_new_one(self): """Makes sure the previous token is cleaned before fetching a new one. @@ -170,12 +201,14 @@ def test_cleans_previous_token_before_fetching_new_one(self): new_token['expires_at'] = now + 3600 url = 'https://example.com/token' + self.requests_mock.post(url, json=new_token) + with mock.patch('time.time', lambda: now): for client in self.clients: auth = OAuth2Session(client=client, token=self.token) - auth.send = fake_token(new_token) self.assertEqual(auth.fetch_token(url), new_token) + self.assertTrue(len(self.clients), self.requests_mock.call_count) def test_web_app_fetch_token(self): # Ensure the state parameter is used, see issue #105. @@ -229,17 +262,29 @@ def test_authorized_false(self): @mock.patch("time.time", new=lambda: fake_time) def test_authorized_true(self): - def fake_token(token): - def fake_send(r, **kwargs): - resp = mock.MagicMock() - resp.text = json.dumps(token) - return resp - return fake_send url = 'https://example.com/token' + self.requests_mock.post(url, json=self.token) for client in self.clients: sess = OAuth2Session(client=client) - sess.send = fake_token(self.token) self.assertFalse(sess.authorized) sess.fetch_token(url) self.assertTrue(sess.authorized) + + self.assertEqual(len(self.clients), self.requests_mock.call_count) + + def test_token_fetch_invalid_status_code(self): + url = 'https://example.com/token' + self.requests_mock.post(url, + json={'message': 'Failure'}, + status_code=403) + + for client in self.clients: + sess = OAuth2Session(client=client) + self.assertRaises( + TokenRequestDenied, + sess.fetch_token, + url + ) + + self.assertEqual(len(self.clients), self.requests_mock.call_count)