1717import json
1818import logging
1919import os
20+ import re
2021import time
2122from typing import Any , Dict , List , Optional , Tuple , Union , cast
2223from urllib .parse import urlencode , urlparse
2324
25+ import boto3
2426import requests
2527from flask import Flask , jsonify , render_template , request
2628from jose import jwk , jwt
3335
3436def is_ready (logger : logging .Logger , app : Flask ) -> Any :
3537 logger .debug ("cookies: %s" , json .dumps (request .cookies ))
36- email , username = _get_user_info_from_jwt (logger )
38+ email , username , groups = _get_user_info_from_jwt (logger )
3739
3840 ready = _is_profile_ready_for_user (logger , username , email )
3941 logger .debug ("username: %s, email: %s" , username , email )
@@ -42,11 +44,18 @@ def is_ready(logger: logging.Logger, app: Flask) -> Any:
4244
4345def login (logger : logging .Logger , app : Flask ) -> Any :
4446 logger .debug ("cookies: %s" , json .dumps (request .cookies ))
45- email , username = _get_user_info_from_jwt (logger )
46- logger .debug ("username: %s, email: %s" , username , email )
47+ email , username , groups = _get_user_info_from_jwt (logger )
4748
48- groups = _get_user_groups_from_jwt (logger )
49+ # If we have groups, then the provider sent them.
50+ # Match them to the proper user groups
51+ if groups is not None :
52+ logger .info ("We got groups in the auth payload, we need to align them to teams" )
53+ user_groups = _get_user_groups_from_provider (logger , list (groups ))
54+ else :
55+ logger .info ("No groups in auth payload, we are fetchng the from the Cognito User Pool" )
56+ user_groups = _get_user_groups_from_jwt (logger )
4957
58+ logger .debug ("username: %s, email: %s, groups: %s" , username , email , user_groups )
5059 ready = _is_profile_ready_for_user (logger , username , email )
5160 logger .debug ("user space is READY? %s" , ready )
5261
@@ -60,7 +69,7 @@ def login(logger: logging.Logger, app: Flask) -> Any:
6069 logout_uri = logout_uri ,
6170 client_id = client_id ,
6271 cognito_domain = cognito_domain ,
63- teams = groups ,
72+ teams = user_groups ,
6473 env_name = env_name ,
6574 )
6675
@@ -105,7 +114,7 @@ def _get_kf_profiles(client: dynamic.DynamicClient) -> List[Dict[str, Any]]:
105114
106115
107116# https://docs.aws.amazon.com/elasticloadbalancing/latest/application/listener-authenticate-users.html
108- def _get_user_info_from_jwt (logger : logging .Logger ) -> Tuple [str , str ]:
117+ def _get_user_info_from_jwt (logger : logging .Logger ) -> Tuple [Any , Any , Optional [ Any ] ]:
109118 logger .debug ("headers: %s" , json .dumps (dict (request .headers )))
110119 encoded_jwt = request .headers ["x-amzn-oidc-data" ]
111120 logger .debug ("encoded_jwt 'x-amzn-oidc-data':\n %s" , encoded_jwt )
@@ -122,9 +131,32 @@ def _get_user_info_from_jwt(logger: logging.Logger) -> Tuple[str, str]:
122131 # Step 3: Get the payload
123132 payload = jwt .decode (encoded_jwt , pub_key , algorithms = ["ES256" ])
124133 logger .debug ("payload:\n %s" , payload )
134+
125135 username = payload ["username" ]
136+ if "preferred_username" in payload :
137+ username = payload ["preferred_username" ]
138+
126139 email = payload ["email" ]
127- return email , username
140+
141+ groups = None
142+ if "custom:groups" in payload :
143+ groups = payload ["custom:groups" ].strip ("][" ).split (", " )
144+
145+ return email , username , groups
146+
147+
148+ def _get_user_groups_from_provider (logger : logging .Logger , groups_from_provider : List [Any ]) -> List [str ]:
149+ logger .info ("Starting to get groups" )
150+ team_info = _get_auth_group_from_ssm (logger )
151+ user_groups = []
152+ for group_name in groups_from_provider :
153+ for team_name in team_info :
154+ if group_name in team_info [team_name ]:
155+ g = team_name
156+ user_groups .append (g )
157+ user_groups = list (dict .fromkeys (user_groups ))
158+ logger .info (f"User Groups: { user_groups } " )
159+ return user_groups
128160
129161
130162def _get_user_groups_from_jwt (logger : logging .Logger ) -> List [str ]:
@@ -133,7 +165,22 @@ def _get_user_groups_from_jwt(logger: logging.Logger) -> List[str]:
133165 logger .debug ("encoded_jwt 'X-Amzn-Oidc-Accesstoken':\n %s" , encoded_jwt )
134166 claims = get_claims (logger , encoded_jwt )
135167 groups : Union [List [Any ], str , int ] = claims ["cognito:groups" ] if "cognito:groups" in claims else []
136- return cast (List [str ], groups )
168+ logger .debug (f"Groups from Cognito : { groups } " )
169+ team_info = _get_auth_group_from_ssm (logger )
170+ orbit_env = os .environ ["ENV_NAME" ]
171+ user_groups = []
172+ for group_name in groups : # type: ignore
173+ if (f"{ orbit_env } -" ) in group_name :
174+ group_name = group_name .split (f"{ orbit_env } -" )[1 ]
175+ for team_name in team_info :
176+ logger .debug (
177+ f"Team Name: { team_name } group_name: { group_name } team_info[team_name] :{ team_info [team_name ]} "
178+ )
179+ if group_name in team_info [team_name ]:
180+ g = team_name
181+ user_groups .append (g )
182+ logger .info (f"User Groups: { user_groups } " )
183+ return user_groups
137184
138185
139186def _get_keys (logger : logging .Logger ) -> List [Dict [str , str ]]:
@@ -178,3 +225,28 @@ def get_claims(logger: logging.Logger, token: str) -> Dict[str, Union[str, int]]
178225 logger .debug ("claims: %s" , claims )
179226
180227 return cast (Dict [str , Union [str , int ]], claims )
228+
229+
230+ def _get_auth_group_from_ssm (logger : logging .Logger ) -> Dict [str , List [str ]]:
231+ ssm_client = boto3 .client ("ssm" )
232+
233+ team_info = {}
234+ orbit_env = os .environ ["ENV_NAME" ]
235+
236+ team_manifest_pattern = re .compile (rf"/orbit/{ orbit_env } /teams/.*/manifest" )
237+
238+ paginator = ssm_client .get_paginator ("describe_parameters" )
239+ page_iterator = paginator .paginate ()
240+
241+ for page in page_iterator :
242+ for path in page .get ("Parameters" ):
243+ param = path .get ("Name" )
244+
245+ if team_manifest_pattern .fullmatch (param ):
246+ param_value = json .loads (ssm_client .get_parameter (Name = param ).get ("Parameter" ).get ("Value" ))
247+ team = param .split ("/" )[- 2 ]
248+ auth_group_val = param_value .get ("AuthenticationGroups" )
249+ team_info [team ] = auth_group_val
250+
251+ logger .info (f"Team Info fetch: { team_info } " )
252+ return team_info
0 commit comments