Skip to content

Commit 439a36d

Browse files
katroganwild-endeavor
authored andcommitted
OAuth2 support for flyte-cli and SDK engine (#23)
This change adds authentication support for flyte-cli and the pyflyte CLIs. # New authorization code Specifically this change introduces an **AuthorizationClient** which implements the [PKCE authorization flow](https://www.oauth.com/oauth2-servers/pkce/authorization-code-exchange/) for untrusted clients. This client handles requesting an initial access token, spinning up a callback server to receive the access token and using that to retrieve an authorization code. The client also handles refreshing expired authorization tokens. This change also includes a lightweight **DiscoveryClient** for retrieving authorization endpoint metadata defined in the [OAuth 2.0 Authorization Server Metadata](https://tools.ietf.org/id/draft-ietf-oauth-discovery-08.html) draft document. An authorization client singleton is lazily initialized for use by flyte-cli. # Pyflyte changes (basic auth) Requests an authorization token using a username and password. # Flyte-cli changes (standard auth) Requests an authorization token using the PKCE flow. # Raw client changes Wraps RPC calls to flyteadmin in a retry handler that initiates the appropriate authentication flow defined in the flytekit config in response to `HTTP 401 unauthorized` response codes.
1 parent 603a918 commit 439a36d

File tree

24 files changed

+941
-44
lines changed

24 files changed

+941
-44
lines changed

flytekit/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
from __future__ import absolute_import
22
import flytekit.plugins
33

4-
__version__ = '0.3.1'
4+
__version__ = '0.4.0'

flytekit/clients/helpers.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
11

2+
from flytekit.clis.auth import credentials as _credentials_access
3+
4+
25

36
def iterate_node_executions(
47
client,
@@ -75,3 +78,4 @@ def iterate_task_executions(client, node_execution_identifier, limit=None, filte
7578
if not next_token:
7679
break
7780
token = next_token
81+

flytekit/clients/raw.py

Lines changed: 114 additions & 28 deletions
Large diffs are not rendered by default.

flytekit/clis/auth/__init__.py

Whitespace-only changes.

flytekit/clis/auth/auth.py

Lines changed: 280 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,280 @@
1+
import base64 as _base64
2+
import hashlib as _hashlib
3+
import keyring as _keyring
4+
import os as _os
5+
import re as _re
6+
import requests as _requests
7+
import webbrowser as _webbrowser
8+
9+
from multiprocessing import Process as _Process, Queue as _Queue
10+
11+
try: # Python 3.5+
12+
from http import HTTPStatus as _StatusCodes
13+
except ImportError:
14+
try: # Python 3
15+
from http import client as _StatusCodes
16+
except ImportError: # Python 2
17+
import httplib as _StatusCodes
18+
try: # Python 3
19+
import http.server as _BaseHTTPServer
20+
except ImportError: # Python 2
21+
import BaseHTTPServer as _BaseHTTPServer
22+
23+
try: # Python 3
24+
import urllib.parse as _urlparse
25+
from urllib.parse import urlencode as _urlencode
26+
except ImportError: # Python 2
27+
import urlparse as _urlparse
28+
from urllib import urlencode as _urlencode
29+
30+
_code_verifier_length = 64
31+
_random_seed_length = 40
32+
_utf_8 = 'utf-8'
33+
34+
35+
# Identifies the service used for storing passwords in keyring
36+
_keyring_service_name = "flyteauth"
37+
# Identifies the key used for storing and fetching from keyring. In our case, instead of a username as the keyring docs
38+
# suggest, we are storing a user's oidc.
39+
_keyring_access_token_storage_key = "access_token"
40+
_keyring_refresh_token_storage_key = "refresh_token"
41+
42+
43+
def _generate_code_verifier():
44+
"""
45+
Generates a 'code_verifier' as described in https://tools.ietf.org/html/rfc7636#section-4.1
46+
Adapted from https://github.com/openstack/deb-python-oauth2client/blob/master/oauth2client/_pkce.py.
47+
:return str:
48+
"""
49+
code_verifier = _base64.urlsafe_b64encode(_os.urandom(_code_verifier_length)).decode(_utf_8)
50+
# Eliminate invalid characters.
51+
code_verifier = _re.sub(r'[^a-zA-Z0-9_\-.~]+', '', code_verifier)
52+
if len(code_verifier) < 43:
53+
raise ValueError("Verifier too short. number of bytes must be > 30.")
54+
elif len(code_verifier) > 128:
55+
raise ValueError("Verifier too long. number of bytes must be < 97.")
56+
return code_verifier
57+
58+
59+
def _generate_state_parameter():
60+
state = _base64.urlsafe_b64encode(_os.urandom(_random_seed_length)).decode(_utf_8)
61+
# Eliminate invalid characters.
62+
code_verifier = _re.sub('[^a-zA-Z0-9-_.,]+', '', state)
63+
return code_verifier
64+
65+
66+
def _create_code_challenge(code_verifier):
67+
"""
68+
Adapted from https://github.com/openstack/deb-python-oauth2client/blob/master/oauth2client/_pkce.py.
69+
:param str code_verifier: represents a code verifier generated by generate_code_verifier()
70+
:return str: urlsafe base64-encoded sha256 hash digest
71+
"""
72+
code_challenge = _hashlib.sha256(code_verifier.encode(_utf_8)).digest()
73+
code_challenge = _base64.urlsafe_b64encode(code_challenge).decode(_utf_8)
74+
# Eliminate invalid characters
75+
code_challenge = code_challenge.replace('=', '')
76+
return code_challenge
77+
78+
79+
class AuthorizationCode(object):
80+
def __init__(self, code, state):
81+
self._code = code
82+
self._state = state
83+
84+
@property
85+
def code(self):
86+
return self._code
87+
88+
@property
89+
def state(self):
90+
return self._state
91+
92+
93+
class OAuthCallbackHandler(_BaseHTTPServer.BaseHTTPRequestHandler):
94+
"""
95+
A simple wrapper around BaseHTTPServer.BaseHTTPRequestHandler that handles a callback URL that accepts an
96+
authorization token.
97+
"""
98+
99+
def do_GET(self):
100+
url = _urlparse.urlparse(self.path)
101+
if url.path == self.server.redirect_path:
102+
self.send_response(_StatusCodes.OK)
103+
self.end_headers()
104+
self.handle_login(dict(_urlparse.parse_qsl(url.query)))
105+
else:
106+
self.send_response(_StatusCodes.NOT_FOUND)
107+
108+
def handle_login(self, data):
109+
self.server.handle_authorization_code(AuthorizationCode(data['code'], data['state']))
110+
111+
112+
class OAuthHTTPServer(_BaseHTTPServer.HTTPServer):
113+
"""
114+
A simple wrapper around the BaseHTTPServer.HTTPServer implementation that binds an authorization_client for handling
115+
authorization code callbacks.
116+
"""
117+
def __init__(self, server_address, RequestHandlerClass, bind_and_activate=True,
118+
redirect_path=None, queue=None):
119+
_BaseHTTPServer.HTTPServer.__init__(self, server_address, RequestHandlerClass, bind_and_activate)
120+
self._redirect_path = redirect_path
121+
self._auth_code = None
122+
self._queue = queue
123+
124+
@property
125+
def redirect_path(self):
126+
return self._redirect_path
127+
128+
def handle_authorization_code(self, auth_code):
129+
self._queue.put(auth_code)
130+
131+
132+
class Credentials(object):
133+
def __init__(self, access_token=None):
134+
self._access_token = access_token
135+
136+
@property
137+
def access_token(self):
138+
return self._access_token
139+
140+
141+
class AuthorizationClient(object):
142+
def __init__(self, auth_endpoint=None, token_endpoint=None, client_id=None, redirect_uri=None):
143+
self._auth_endpoint = auth_endpoint
144+
self._token_endpoint = token_endpoint
145+
self._client_id = client_id
146+
self._redirect_uri = redirect_uri
147+
self._code_verifier = _generate_code_verifier()
148+
code_challenge = _create_code_challenge(self._code_verifier)
149+
self._code_challenge = code_challenge
150+
state = _generate_state_parameter()
151+
self._state = state
152+
self._credentials = None
153+
self._refresh_token = None
154+
self._headers = {'content-type': "application/x-www-form-urlencoded"}
155+
self._expired = False
156+
157+
self._params = {
158+
"client_id": client_id, # This must match the Client ID of the OAuth application.
159+
"response_type": "code", # Indicates the authorization code grant
160+
"scope": "openid offline_access", # ensures that the /token endpoint returns an ID and refresh token
161+
# callback location where the user-agent will be directed to.
162+
"redirect_uri": self._redirect_uri,
163+
"state": state,
164+
"code_challenge": code_challenge,
165+
"code_challenge_method": "S256",
166+
}
167+
168+
# Prefer to use already-fetched token values when they've been set globally.
169+
self._refresh_token = _keyring.get_password(_keyring_service_name, _keyring_refresh_token_storage_key)
170+
access_token = _keyring.get_password(_keyring_service_name, _keyring_access_token_storage_key)
171+
if access_token:
172+
self._credentials = Credentials(access_token=access_token)
173+
return
174+
175+
# In the absence of globally-set token values, initiate the token request flow
176+
q = _Queue()
177+
# First prepare the callback server in the background
178+
server = self._create_callback_server(q)
179+
server_process = _Process(target=server.handle_request)
180+
server_process.start()
181+
182+
# Send the call to request the authorization code
183+
self._request_authorization_code()
184+
185+
# Request the access token once the auth code has been received.
186+
auth_code = q.get()
187+
server_process.terminate()
188+
self.request_access_token(auth_code)
189+
190+
def _create_callback_server(self, q):
191+
server_url = _urlparse.urlparse(self._redirect_uri)
192+
server_address = (server_url.hostname, server_url.port)
193+
return OAuthHTTPServer(server_address, OAuthCallbackHandler, redirect_path=server_url.path, queue=q)
194+
195+
def _request_authorization_code(self):
196+
scheme, netloc, path, _, _, _ = _urlparse.urlparse(self._auth_endpoint)
197+
query = _urlencode(self._params)
198+
endpoint = _urlparse.urlunparse((scheme, netloc, path, None, query, None))
199+
_webbrowser.open_new_tab(endpoint)
200+
201+
def _initialize_credentials(self, auth_token_resp):
202+
203+
"""
204+
The auth_token_resp body is of the form:
205+
{
206+
"access_token": "foo",
207+
"refresh_token": "bar",
208+
"id_token": "baz",
209+
"token_type": "Bearer"
210+
}
211+
"""
212+
response_body = auth_token_resp.json()
213+
if "access_token" not in response_body:
214+
raise ValueError('Expected "access_token" in response from oauth server')
215+
if "refresh_token" in response_body:
216+
self._refresh_token = response_body["refresh_token"]
217+
218+
access_token = response_body["access_token"]
219+
refresh_token = response_body["refresh_token"]
220+
221+
_keyring.set_password(_keyring_service_name, _keyring_access_token_storage_key, access_token)
222+
_keyring.set_password(_keyring_service_name, _keyring_refresh_token_storage_key, refresh_token)
223+
self._credentials = Credentials(access_token=access_token)
224+
225+
def request_access_token(self, auth_code):
226+
if self._state != auth_code.state:
227+
raise ValueError("Unexpected state parameter [{}] passed".format(auth_code.state))
228+
self._params.update({
229+
"code": auth_code.code,
230+
"code_verifier": self._code_verifier,
231+
"grant_type": "authorization_code",
232+
})
233+
resp = _requests.post(
234+
url=self._token_endpoint,
235+
data=self._params,
236+
headers=self._headers,
237+
allow_redirects=False
238+
)
239+
if resp.status_code != _StatusCodes.OK:
240+
# TODO: handle expected (?) error cases:
241+
# https://auth0.com/docs/flows/guides/device-auth/call-api-device-auth#token-responses
242+
raise Exception('Failed to request access token with response: [{}] {}'.format(
243+
resp.status_code, resp.content))
244+
self._initialize_credentials(resp)
245+
246+
def refresh_access_token(self):
247+
if self._refresh_token is None:
248+
raise ValueError("no refresh token available with which to refresh authorization credentials")
249+
250+
resp = _requests.post(
251+
url=self._token_endpoint,
252+
data={'grant_type': 'refresh_token',
253+
'client_id': self._client_id,
254+
'refresh_token': self._refresh_token},
255+
headers=self._headers,
256+
allow_redirects=False
257+
)
258+
if resp.status_code != _StatusCodes.OK:
259+
self._expired = True
260+
# In the absence of a successful response, assume the refresh token is expired. This should indicate
261+
# to the caller that the AuthorizationClient is defunct and a new one needs to be re-initialized.
262+
263+
_keyring.delete_password(_keyring_service_name, _keyring_access_token_storage_key)
264+
_keyring.delete_password(_keyring_service_name, _keyring_refresh_token_storage_key)
265+
return
266+
self._initialize_credentials(resp)
267+
268+
@property
269+
def credentials(self):
270+
"""
271+
:return flytekit.clis.auth.auth.Credentials:
272+
"""
273+
return self._credentials
274+
275+
@property
276+
def expired(self):
277+
"""
278+
:return bool:
279+
"""
280+
return self._expired

flytekit/clis/auth/credentials.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
from __future__ import absolute_import
2+
from flytekit.clis.auth.auth import AuthorizationClient as _AuthorizationClient
3+
from flytekit.clis.auth.discovery import DiscoveryClient as _DiscoveryClient
4+
5+
from flytekit.configuration.creds import (
6+
REDIRECT_URI as _REDIRECT_URI,
7+
CLIENT_ID as _CLIENT_ID
8+
)
9+
from flytekit.configuration.platform import URL as _URL, INSECURE as _INSECURE
10+
11+
try: # Python 3
12+
import urllib.parse as _urlparse
13+
except ImportError: # Python 2
14+
import urlparse as _urlparse
15+
16+
# Default, well known-URI string used for fetching JSON metadata. See https://tools.ietf.org/html/rfc8414#section-3.
17+
discovery_endpoint_path = ".well-known/oauth-authorization-server"
18+
19+
20+
def _get_discovery_endpoint():
21+
if _INSECURE.get():
22+
return _urlparse.urljoin('http://{}/'.format(_URL.get()), discovery_endpoint_path)
23+
return _urlparse.urljoin('https://{}/'.format(_URL.get()), discovery_endpoint_path)
24+
25+
26+
# Lazy initialized authorization client singleton
27+
_authorization_client = None
28+
29+
30+
def get_client():
31+
global _authorization_client
32+
if _authorization_client is not None and not _authorization_client.expired:
33+
return _authorization_client
34+
authorization_endpoints = get_authorization_endpoints()
35+
36+
_authorization_client =\
37+
_AuthorizationClient(redirect_uri=_REDIRECT_URI.get(), client_id=_CLIENT_ID.get(),
38+
auth_endpoint=authorization_endpoints.auth_endpoint,
39+
token_endpoint=authorization_endpoints.token_endpoint)
40+
return _authorization_client
41+
42+
43+
def get_authorization_endpoints():
44+
discovery_endpoint = _get_discovery_endpoint()
45+
discovery_client = _DiscoveryClient(discovery_url=discovery_endpoint)
46+
return discovery_client.get_authorization_endpoints()

0 commit comments

Comments
 (0)