From 3a1bb4a3b4d6962d94ac0a11de364478f7a7fad6 Mon Sep 17 00:00:00 2001 From: Namit Nathwani Date: Thu, 5 Jun 2025 19:20:24 +0530 Subject: [PATCH 01/37] feat: adds basic plugin functionality --- supertokens_python/__init__.py | 11 +- supertokens_python/plugins.py | 326 +++++++++++++++++++++++ supertokens_python/recipe/totp/recipe.py | 12 +- supertokens_python/supertokens.py | 230 +++++++++++++++- 4 files changed, 570 insertions(+), 9 deletions(-) create mode 100644 supertokens_python/plugins.py diff --git a/supertokens_python/__init__.py b/supertokens_python/__init__.py index 43d3573b2..3b5611bdc 100644 --- a/supertokens_python/__init__.py +++ b/supertokens_python/__init__.py @@ -26,6 +26,7 @@ Supertokens = supertokens.Supertokens SupertokensConfig = supertokens.SupertokensConfig AppInfo = supertokens.AppInfo +SupertokensExperimentalConfig = supertokens.SupertokensExperimentalConfig def init( @@ -36,9 +37,17 @@ def init( mode: Optional[Literal["asgi", "wsgi"]] = None, telemetry: Optional[bool] = None, debug: Optional[bool] = None, + experimental: Optional[SupertokensExperimentalConfig] = None, ): return Supertokens.init( - app_info, framework, supertokens_config, recipe_list, mode, telemetry, debug + app_info, + framework, + supertokens_config, + recipe_list, + mode, + telemetry, + debug, + experimental=experimental, ) diff --git a/supertokens_python/plugins.py b/supertokens_python/plugins.py new file mode 100644 index 000000000..18f5c245b --- /dev/null +++ b/supertokens_python/plugins.py @@ -0,0 +1,326 @@ +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + List, + Literal, + Optional, + TypeVar, + Union, + runtime_checkable, +) + +from typing_extensions import Protocol + +from supertokens_python.framework.request import BaseRequest +from supertokens_python.framework.response import BaseResponse + +# from supertokens_python.recipe.accountlinking.types import AccountLinkingConfig +# from supertokens_python.recipe.dashboard.utils import DashboardConfig +# from supertokens_python.recipe.emailpassword.utils import EmailPasswordConfig +# from supertokens_python.recipe.emailverification.utils import EmailVerificationConfig +# from supertokens_python.recipe.jwt.utils import JWTConfig +# from supertokens_python.recipe.multifactorauth.types import MultiFactorAuthConfig +# from supertokens_python.recipe.multitenancy.utils import MultitenancyConfig +# from supertokens_python.recipe.oauth2provider.utils import OAuth2ProviderConfig +# from supertokens_python.recipe.openid.utils import OpenIdConfig +# from supertokens_python.recipe.passwordless.utils import PasswordlessConfig +if TYPE_CHECKING: + from supertokens_python.recipe.session.interfaces import ( + SessionClaimValidator, + SessionContainer, + ) + from supertokens_python.supertokens import SupertokensPublicConfig + +# from supertokens_python.recipe.session.utils import SessionConfig +# from supertokens_python.recipe.thirdparty.utils import ThirdPartyConfig +# from supertokens_python.recipe.totp.types import TOTPConfig +# from supertokens_python.recipe.usermetadata.utils import UserMetadataConfig +# from supertokens_python.recipe.userroles.utils import UserRolesConfig +from supertokens_python.types import MaybeAwaitable +from supertokens_python.types.base import UserContext +from supertokens_python.types.response import CamelCaseBaseModel + +T = TypeVar("T") +# T = TypeVar("T", bound=Union[AccountLinkingConfig, DashboardConfig, EmailPasswordConfig, +# EmailVerificationConfig, JWTConfig, MultiFactorAuthConfig, MultitenancyConfig, +# OAuth2ProviderConfig, OpenIdConfig, PasswordlessConfig, SessionConfig, +# ThirdPartyConfig, TOTPConfig, UserMetadataConfig, UserRolesConfig]) + + +# class AllRecipeConfigs: +# # These generally have no Input config type +# accountlinking: AccountLinkingConfig +# dashboard: DashboardConfig +# emailpassword: EmailPasswordConfig +# emailverification: EmailVerificationConfig +# jwt: JWTConfig +# multifactorauth: MultiFactorAuthConfig +# multitenancy: MultitenancyConfig +# oauth2provider: OAuth2ProviderConfig +# openid: OpenIdConfig +# passwordless: PasswordlessConfig +# session: SessionConfig +# thirdparty: ThirdPartyConfig +# totp: TOTPConfig # This is the input config type +# usermetadata: UserMetadataConfig +# userroles: UserRolesConfig +# # webauthn: WebauthnConfig + + +class RecipePluginOverride: + # TODO: Define a base class for the Config/RecipeInterface/ApiInterface classes, and use it here + functions: Optional[Callable[[Any], Any]] + apis: Optional[Callable[[Any], Any]] + config: Optional[Callable[[Any], Any]] + + +# export type AllRecipeConfigs = { +# accountlinking: AccountLinkingTypeInput & { override?: { apis: never } }; +# dashboard: DashboardTypeInput; +# emailpassword: EmailPasswordTypeInput; +# emailverification: EmailVerificationTypeInput; +# jwt: JWTTypeInput; +# multifactorauth: MultifactorAuthTypeInput; +# multitenancy: MultitenancyTypeInput; +# oauth2provider: OAuth2ProviderTypeInput; +# openid: OpenIdTypeInput; +# passwordless: PasswordlessTypeInput; +# session: SessionTypeInput; +# thirdparty: ThirdPartyTypeInput; +# totp: TotpTypeInput; +# usermetadata: UserMetadataTypeInput; +# userroles: UserRolesTypeInput; +# }; + +# export type RecipePluginOverride = { +# functions?: NonNullable["functions"]; +# apis?: NonNullable["apis"]; +# config?: (config: AllRecipeConfigs[T]) => AllRecipeConfigs[T]; +# }; + + +class PluginRouteHandlerResponse(CamelCaseBaseModel): + status: int + body: Any + + +@runtime_checkable +class PluginRouteHandlerHandlerFunction(Protocol): + def __call__( + self, + request: BaseRequest, + response: BaseResponse, + session: Optional["SessionContainer"], + user_context: UserContext, + ) -> BaseResponse: ... + + +@runtime_checkable +class OverrideGlobalClaimValidatorsFunction(Protocol): + def __call__( + self, + global_claim_validators: List["SessionClaimValidator"], + session: "SessionContainer", + user_context: UserContext, + ) -> MaybeAwaitable[List["SessionClaimValidator"]]: ... + + +class VerifySessionOptions(CamelCaseBaseModel): + session_required: Optional[bool] = None + anti_csrf_check: Optional[bool] = None + check_database: Optional[bool] = None + override_global_claim_validators: Optional[ + OverrideGlobalClaimValidatorsFunction + ] = None + + +class PluginRouteHandler: + method: str + path: str + verify_session_options: Optional[VerifySessionOptions] + handler: PluginRouteHandlerHandlerFunction + + +@runtime_checkable +class SuperTokensPluginInit(Protocol): + def __call__( + self, + config: "SupertokensPublicConfig", + all_plugins: List["SuperTokensPublicPlugin"], + sdk_version: str, + ) -> None: ... + + +class PluginDependenciesOkResponse(CamelCaseBaseModel): + status: Literal["OK"] = "OK" + plugins_to_add: List["SuperTokensPlugin"] + + +class PluginDependenciesErrorResponse(CamelCaseBaseModel): + status: Literal["ERROR"] = "ERROR" + message: str + + +@runtime_checkable +class SuperTokensPluginDependencies(Protocol): + def __call__( + self, + config: "SupertokensPublicConfig", + plugins_above: List["SuperTokensPublicPlugin"], + sdk_version: str, + ) -> Union[PluginDependenciesOkResponse, PluginDependenciesErrorResponse]: ... + + +class PluginRouteHandlerFunctionOkResponse(CamelCaseBaseModel): + status: Literal["OK"] = "OK" + plugins_to_add: List["SuperTokensPlugin"] + + +class PluginRouteHandlerFunctionErrorResponse(CamelCaseBaseModel): + status: Literal["ERROR"] = "ERROR" + message: str + + +@runtime_checkable +class PluginRouteHandlerFunction(Protocol): + def __call__( + self, + config: "SupertokensPublicConfig", + all_plugins: List["SuperTokensPublicPlugin"], + sdk_version: str, + ) -> Union[ + PluginRouteHandlerFunctionOkResponse, PluginRouteHandlerFunctionErrorResponse + ]: ... + + +@runtime_checkable +class PluginConfig(Protocol): + def __call__( + self, config: "SupertokensPublicConfig" + ) -> "SupertokensPublicConfig": ... + + +class SuperTokensPluginBase(CamelCaseBaseModel): + id: str + version: Optional[str] = None + compatible_sdk_versions: Union[str, List[str]] + exports: Optional[Dict[str, Any]] = None + + +OverrideMap = Dict[str, Any] + + +class SuperTokensPlugin(SuperTokensPluginBase): + init: Optional[SuperTokensPluginInit] = None + dependencies: Optional[SuperTokensPluginDependencies] = None + # TODO: Add types for recipes + # overrideMap?: { + # [recipeId in keyof AllRecipeConfigs]?: RecipePluginOverride & { + # recipeInitRequired?: boolean | ((sdkVersion: string) => boolean); + # }; + # }; + override_map: Optional[OverrideMap] = None + route_handlers: Optional[ + Union[List[PluginRouteHandler], PluginRouteHandlerFunction] + ] = None + config: Optional[PluginConfig] = None + + +class SuperTokensPublicPlugin(SuperTokensPluginBase): + initialized: bool + + @classmethod + def from_plugin(cls, plugin: SuperTokensPlugin) -> "SuperTokensPublicPlugin": + return cls( + id=plugin.id, + initialized=plugin.init is None, + version=plugin.version, + exports=plugin.exports, + compatible_sdk_versions=plugin.compatible_sdk_versions, + ) + + +class ConfigOverrideBase: + functions: Optional[Callable[[Any], Any]] = None + apis: Optional[Callable[[Any], Any]] = None + + +def apply_plugins(recipe_id: str, config: T, plugins: List[OverrideMap]) -> T: + # print("Startnig apply_plugins") + + def default_fn_override(original_implementation: T) -> T: + return original_implementation + + def default_api_override(original_implementation: T) -> T: + return original_implementation + + if config.override is None: + config.override = ConfigOverrideBase() + config.override.functions = default_fn_override + config.override.apis = default_api_override + + function_overrides = getattr(config.override, "functions", default_fn_override) + api_overrides = getattr(config.override, "apis", default_api_override) + + # print("config overrides", function_overrides, api_overrides) + + function_layers: list[Any] = [] + api_layers: list[Any] = [] + if function_overrides is not None: + function_layers.append(function_overrides) + if api_overrides is not None: + api_layers.append(api_overrides) + + # print("Starting plugin iteration") + for plugin in plugins: + # print(f"{plugin=}") + overrides = plugin[recipe_id] + # print(f"{overrides=}") + if overrides is not None: + if overrides.config is not None: + config = overrides.config(config) + + if overrides.functions is not None: + function_layers.append(overrides.functions) + if overrides.apis is not None: + api_layers.append(overrides.apis) + + # function_layers.reverse() + # api_layers.reverse() + + # Apply the user override first, followed by the plugin overrides + # Example: [user_override, plugin_dep_1, plugin_1] + # final_override(oI) -> plugin_1_oI + # plugin_1(plugin_dep_1_oI) -> plugin_1_oI + # plugin_dep_1(user_override_oI) -> plugin_dep_1_oI + # user_override(oI) -> user_override_oI + + if len(function_layers) > 0: + + def fn_override(original_implementation: Any) -> Any: + for function_layer in reversed(function_layers): + # for function_layer in function_layers: + original_implementation = function_layer(original_implementation) + return original_implementation + + config.override.functions = fn_override + + # config.override.functions = function_layers[0] + # for function_layer in function_layers[1:]: + # config.override.functions = function_layer(config.override.functions) + + if len(api_layers) > 0 and recipe_id != "accountlinking": + # config.override.apis = api_layers[0] + # for api_layer in api_layers[1:]: + # config.override.apis = api_layer(config.override.apis) + def api_override(original_implementation: Any) -> Any: + for api_layer in reversed(api_layers): + # for api_layer in api_layers: + original_implementation = api_layer(original_implementation) + return original_implementation + + config.override.apis = api_override + + return config diff --git a/supertokens_python/recipe/totp/recipe.py b/supertokens_python/recipe/totp/recipe.py index 638aac124..e4bc3f3b5 100644 --- a/supertokens_python/recipe/totp/recipe.py +++ b/supertokens_python/recipe/totp/recipe.py @@ -14,9 +14,10 @@ from __future__ import annotations from os import environ -from typing import TYPE_CHECKING, Any, Dict, List, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from supertokens_python.normalised_url_path import NormalisedURLPath +from supertokens_python.plugins import OverrideMap, apply_plugins from supertokens_python.querier import Querier from supertokens_python.recipe.multifactorauth.types import ( GetAllAvailableSecondaryFactorIdsFromOtherRecipesFunc, @@ -205,12 +206,17 @@ def get_all_cors_headers(self) -> List[str]: def init( config: Union[TOTPConfig, None] = None, ): - def func(app_info: AppInfo): + def func(app_info: AppInfo, plugins: Optional[List[OverrideMap]] = None): + if plugins is None: + plugins = [] + if TOTPRecipe.__instance is None: TOTPRecipe.__instance = TOTPRecipe( TOTPRecipe.recipe_id, app_info, - config, + apply_plugins( + recipe_id=TOTPRecipe.recipe_id, config=config, plugins=plugins + ), ) return TOTPRecipe.__instance raise Exception( diff --git a/supertokens_python/supertokens.py b/supertokens_python/supertokens.py index abcf3aed0..14f6ab65f 100644 --- a/supertokens_python/supertokens.py +++ b/supertokens_python/supertokens.py @@ -14,8 +14,20 @@ from __future__ import annotations +from dataclasses import dataclass from os import environ -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Tuple, Union +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + List, + Optional, + Set, + Tuple, + Union, + cast, +) from typing_extensions import Literal @@ -24,8 +36,15 @@ get_maybe_none_as_str, log_debug_message, ) +from supertokens_python.plugins import ( + OverrideMap, + PluginRouteHandler, + SuperTokensPlugin, + SuperTokensPluginInit, + SuperTokensPublicPlugin, +) -from .constants import FDI_KEY_HEADER, RID_KEY_HEADER, USER_COUNT +from .constants import FDI_KEY_HEADER, RID_KEY_HEADER, USER_COUNT, VERSION from .exceptions import SuperTokensError from .interfaces import ( CreateUserIdMappingOkResult, @@ -93,6 +112,40 @@ def __init__( self.disable_core_call_cache = disable_core_call_cache +@dataclass +class SupertokensExperimentalConfig: + plugins: Optional[List[SuperTokensPlugin]] = None + + +# TODO: Change to Pydantic? + + +@dataclass +class SupertokensPublicConfig: + app_info: InputAppInfo + framework: Literal["fastapi", "flask", "django"] + supertokens_config: SupertokensConfig + mode: Optional[Literal["asgi", "wsgi"]] + telemetry: Optional[bool] + debug: Optional[bool] + + +@dataclass +class SupertokensInputConfig(SupertokensPublicConfig): + recipe_list: List[Callable[[AppInfo], RecipeModule]] + experimental: Optional[SupertokensExperimentalConfig] = None + + def get_public_config(self) -> SupertokensPublicConfig: + return SupertokensPublicConfig( + app_info=self.app_info, + framework=self.framework, + supertokens_config=self.supertokens_config, + mode=self.mode, + telemetry=self.telemetry, + debug=self.debug, + ) + + class Host: def __init__(self, domain: NormalisedURLDomain, base_path: NormalisedURLPath): self.domain = domain @@ -200,7 +253,21 @@ def manage_session_post_response( class Supertokens: - __instance = None + __instance: Optional[Supertokens] = None + + recipe_modules: List[RecipeModule] + + app_info: AppInfo + + supertokens_config: SupertokensConfig + + _telemetry_status: str + + telemetry: bool + + plugin_route_handlers: List[PluginRouteHandler] + + plugin_list: List[SuperTokensPublicPlugin] def __init__( self, @@ -211,10 +278,118 @@ def __init__( mode: Optional[Literal["asgi", "wsgi"]], telemetry: Optional[bool], debug: Optional[bool], + experimental: Optional[SupertokensExperimentalConfig] = None, ): if not isinstance(app_info, InputAppInfo): # type: ignore raise ValueError("app_info must be an instance of InputAppInfo") + config = SupertokensInputConfig( + app_info=app_info, + framework=framework, + supertokens_config=supertokens_config, + recipe_list=recipe_list, + mode=mode, + telemetry=telemetry, + debug=debug, + experimental=experimental, + ) + # TODO: Probably just want to define this directly and use it + # Can build a input config from the final public config and the additional props + public_config = config.get_public_config() + + self.plugin_route_handlers = [] + + input_plugin_list = [] + final_plugin_list: List[SuperTokensPlugin] = [] + + if experimental is not None and experimental.plugins is not None: + input_plugin_list = experimental.plugins + + print(f"{input_plugin_list=}") + + for plugin in input_plugin_list: + if isinstance(plugin.compatible_sdk_versions, list): + version_constraints = plugin.compatible_sdk_versions + else: + version_constraints = [plugin.compatible_sdk_versions] + + if VERSION not in version_constraints: + # TODO: Better checks + raise Exception("Plugin version mismatch") + + if plugin.dependencies is not None: + dep_result = plugin.dependencies( + config=public_config, + plugins_above=[ + SuperTokensPublicPlugin.from_plugin(plugin) + for plugin in final_plugin_list + ], + sdk_version=VERSION, + ) + + if dep_result.status == "ERROR": + raise Exception(dep_result.message) + + if dep_result.plugins_to_add: + final_plugin_list.extend(dep_result.plugins_to_add) + + final_plugin_list.append(plugin) + + self.plugin_list = [ + SuperTokensPublicPlugin.from_plugin(plugin) for plugin in final_plugin_list + ] + print(f"{self.plugin_list=}") + print() + + for plugin_idx, plugin in enumerate(final_plugin_list): + print() + print(f"{plugin_idx=} {plugin=}") + + if plugin.config is not None: + public_config = plugin.config(public_config) + + if plugin.route_handlers is not None: + handlers: List[PluginRouteHandler] = [] + + if callable(plugin.route_handlers): + handler_result = plugin.route_handlers( + config=public_config, + all_plugins=self.plugin_list, + sdk_version=VERSION, + ) + if handler_result.status == "ERROR": + raise Exception(handler_result.message) + else: + handlers = plugin.route_handlers + + self.plugin_route_handlers.extend(handlers) + + if plugin.init is not None: + print(f"{plugin.init=}") + plugin.init( + config=public_config, + all_plugins=self.plugin_list, + sdk_version=VERSION, + ) + + # TODO: Make this a factory function to avoid weird side-effects? + def callback_factory(): + # This has to be part of the factory to ensure we pick up the correct plugin + init_fn = cast(SuperTokensPluginInit, plugin.init) + idx = plugin_idx + + def callback(): + init_fn( + config=public_config, + all_plugins=self.plugin_list, + sdk_version=VERSION, + ) + self.plugin_list[idx].initialized = True + + return callback + + PostSTInitCallbacks.add_post_init_callback(callback_factory()) + self.app_info = AppInfo( app_info.app_name, app_info.api_domain, @@ -255,6 +430,12 @@ def __init__( "Please provide at least one recipe to the supertokens.init function call" ) + override_maps = [ + plugin.override_map + for plugin in final_plugin_list + if plugin.override_map is not None + ] + multitenancy_found = False totp_found = False user_metadata_found = False @@ -263,7 +444,9 @@ def __init__( openid_found = False jwt_found = False - def make_recipe(recipe: Callable[[AppInfo], RecipeModule]) -> RecipeModule: + def make_recipe( + recipe: Callable[[AppInfo, List[OverrideMap]], RecipeModule], + ) -> RecipeModule: nonlocal \ multitenancy_found, \ totp_found, \ @@ -272,7 +455,7 @@ def make_recipe(recipe: Callable[[AppInfo], RecipeModule]) -> RecipeModule: oauth2_found, \ openid_found, \ jwt_found - recipe_module = recipe(self.app_info) + recipe_module = recipe(self.app_info, override_maps) if recipe_module.get_recipe_id() == "multitenancy": multitenancy_found = True elif recipe_module.get_recipe_id() == "usermetadata": @@ -336,6 +519,7 @@ def init( mode: Optional[Literal["asgi", "wsgi"]], telemetry: Optional[bool], debug: Optional[bool], + experimental: Optional[SupertokensExperimentalConfig] = None, ): if Supertokens.__instance is None: Supertokens.__instance = Supertokens( @@ -346,6 +530,7 @@ def init( mode, telemetry, debug, + experimental=experimental, ) PostSTInitCallbacks.run_post_init_callbacks() @@ -543,12 +728,47 @@ async def update_or_delete_user_id_mapping_info( async def middleware( self, request: BaseRequest, response: BaseResponse, user_context: Dict[str, Any] ) -> Union[BaseResponse, None]: + from supertokens_python.recipe.session.recipe import SessionRecipe + log_debug_message("middleware: Started") path = Supertokens.get_instance().app_info.api_gateway_path.append( NormalisedURLPath(request.get_path()) ) method = normalise_http_method(request.method()) + handler_from_apis: Optional[PluginRouteHandler] = None + for handler in self.plugin_route_handlers: + if ( + handler.path == path.get_as_string_dangerous() + and handler.method == method + ): + log_debug_message( + "middleware: Found matching plugin route handler for path: %s and method: %s", + path.get_as_string_dangerous(), + method, + ) + handler_from_apis = handler + break + + if handler_from_apis is not None: + session: Optional[SessionContainer] = None + if handler_from_apis.verify_session_options is not None: + # TODO: Fix verify_session_options type + session = await SessionRecipe.get_instance().verify_session( + request=request, + user_context=user_context, + **handler_from_apis.verify_session_options, + ) + handler_from_apis.handler( + request=request, + response=response, + session=session, + user_context=user_context, + ) + + # TODO: Why do we do this? + return None + if not path.startswith(Supertokens.get_instance().app_info.api_base_path): log_debug_message( "middleware: Not handling because request path did not start with api base path. Request path: %s", From 5f57509800141c331b3d88b5144572c2c3371794 Mon Sep 17 00:00:00 2001 From: Namit Nathwani Date: Fri, 6 Jun 2025 15:09:42 +0530 Subject: [PATCH 02/37] feat: make plugins functional - Matches plugin eval order to Node - Adds recursive dependency resolution - Adds tests for overrides and dependencies --- supertokens_python/__init__.py | 7 +- supertokens_python/plugins.py | 56 +++--- supertokens_python/supertokens.py | 75 +++++--- tests/plugins/__init__.py | 0 tests/plugins/api_implementation.py | 22 +++ tests/plugins/config.py | 71 +++++++ tests/plugins/plugins.py | 112 +++++++++++ tests/plugins/recipe.py | 130 +++++++++++++ tests/plugins/recipe_implementation.py | 34 ++++ tests/plugins/test_plugins.py | 249 +++++++++++++++++++++++++ tests/plugins/types.py | 10 + 11 files changed, 702 insertions(+), 64 deletions(-) create mode 100644 tests/plugins/__init__.py create mode 100644 tests/plugins/api_implementation.py create mode 100644 tests/plugins/config.py create mode 100644 tests/plugins/plugins.py create mode 100644 tests/plugins/recipe.py create mode 100644 tests/plugins/recipe_implementation.py create mode 100644 tests/plugins/test_plugins.py create mode 100644 tests/plugins/types.py diff --git a/supertokens_python/__init__.py b/supertokens_python/__init__.py index 3b5611bdc..8beee8b68 100644 --- a/supertokens_python/__init__.py +++ b/supertokens_python/__init__.py @@ -12,28 +12,29 @@ # License for the specific language governing permissions and limitations # under the License. -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Dict, List, Optional from typing_extensions import Literal from supertokens_python.framework.request import BaseRequest +from supertokens_python.recipe_module import RecipeModule from supertokens_python.types import RecipeUserId from . import supertokens -from .recipe_module import RecipeModule InputAppInfo = supertokens.InputAppInfo Supertokens = supertokens.Supertokens SupertokensConfig = supertokens.SupertokensConfig AppInfo = supertokens.AppInfo SupertokensExperimentalConfig = supertokens.SupertokensExperimentalConfig +RecipeModule = RecipeModule def init( app_info: InputAppInfo, framework: Literal["fastapi", "flask", "django"], supertokens_config: SupertokensConfig, - recipe_list: List[Callable[[supertokens.AppInfo], RecipeModule]], + recipe_list: List[supertokens.RecipeInit], mode: Optional[Literal["asgi", "wsgi"]] = None, telemetry: Optional[bool] = None, debug: Optional[bool] = None, diff --git a/supertokens_python/plugins.py b/supertokens_python/plugins.py index 18f5c245b..10b5a1b6c 100644 --- a/supertokens_python/plugins.py +++ b/supertokens_python/plugins.py @@ -1,3 +1,10 @@ +# TODOs: +# - [ ] Define base classes for: +# - Config +# - RecipeInterface +# - APIInterface +# - OverrideConfig + from typing import ( TYPE_CHECKING, Any, @@ -128,9 +135,9 @@ def __call__( class VerifySessionOptions(CamelCaseBaseModel): - session_required: Optional[bool] = None + session_required: bool anti_csrf_check: Optional[bool] = None - check_database: Optional[bool] = None + check_database: bool override_global_claim_validators: Optional[ OverrideGlobalClaimValidatorsFunction ] = None @@ -139,8 +146,8 @@ class VerifySessionOptions(CamelCaseBaseModel): class PluginRouteHandler: method: str path: str - verify_session_options: Optional[VerifySessionOptions] handler: PluginRouteHandlerHandlerFunction + verify_session_options: Optional[VerifySessionOptions] @runtime_checkable @@ -175,7 +182,7 @@ def __call__( class PluginRouteHandlerFunctionOkResponse(CamelCaseBaseModel): status: Literal["OK"] = "OK" - plugins_to_add: List["SuperTokensPlugin"] + route_handlers: List[PluginRouteHandler] class PluginRouteHandlerFunctionErrorResponse(CamelCaseBaseModel): @@ -247,12 +254,13 @@ class ConfigOverrideBase: apis: Optional[Callable[[Any], Any]] = None +# TODO: Pass in the OverrideConfig class as an arg, use it to define a default if None def apply_plugins(recipe_id: str, config: T, plugins: List[OverrideMap]) -> T: - # print("Startnig apply_plugins") - + # TODO: Change to recipe_implementation type def default_fn_override(original_implementation: T) -> T: return original_implementation + # TODO: Change to api_implementation type def default_api_override(original_implementation: T) -> T: return original_implementation @@ -264,8 +272,6 @@ def default_api_override(original_implementation: T) -> T: function_overrides = getattr(config.override, "functions", default_fn_override) api_overrides = getattr(config.override, "apis", default_api_override) - # print("config overrides", function_overrides, api_overrides) - function_layers: list[Any] = [] api_layers: list[Any] = [] if function_overrides is not None: @@ -273,11 +279,8 @@ def default_api_override(original_implementation: T) -> T: if api_overrides is not None: api_layers.append(api_overrides) - # print("Starting plugin iteration") for plugin in plugins: - # print(f"{plugin=}") overrides = plugin[recipe_id] - # print(f"{overrides=}") if overrides is not None: if overrides.config is not None: config = overrides.config(config) @@ -287,37 +290,24 @@ def default_api_override(original_implementation: T) -> T: if overrides.apis is not None: api_layers.append(overrides.apis) - # function_layers.reverse() - # api_layers.reverse() - - # Apply the user override first, followed by the plugin overrides - # Example: [user_override, plugin_dep_1, plugin_1] - # final_override(oI) -> plugin_1_oI - # plugin_1(plugin_dep_1_oI) -> plugin_1_oI - # plugin_dep_1(user_override_oI) -> plugin_dep_1_oI - # user_override(oI) -> user_override_oI - + # Apply overrides in order of definition + # Plugins: [plugin1, plugin2] would be applied as [override, plugin1, plugin2, original] if len(function_layers) > 0: - - def fn_override(original_implementation: Any) -> Any: + # TODO: Change to recipe_implementation type + def fn_override(original_implementation: T) -> T: + # The layers will get called in reversed order + # Iteration is reversed to ensure that the required order is maintained for function_layer in reversed(function_layers): - # for function_layer in function_layers: original_implementation = function_layer(original_implementation) return original_implementation config.override.functions = fn_override - # config.override.functions = function_layers[0] - # for function_layer in function_layers[1:]: - # config.override.functions = function_layer(config.override.functions) - + # AccountLinking recipe does not have an API implementation if len(api_layers) > 0 and recipe_id != "accountlinking": - # config.override.apis = api_layers[0] - # for api_layer in api_layers[1:]: - # config.override.apis = api_layer(config.override.apis) - def api_override(original_implementation: Any) -> Any: + # TODO: Change to api_implementation type + def api_override(original_implementation: T) -> T: for api_layer in reversed(api_layers): - # for api_layer in api_layers: original_implementation = api_layer(original_implementation) return original_implementation diff --git a/supertokens_python/supertokens.py b/supertokens_python/supertokens.py index 14f6ab65f..45d5b2a6f 100644 --- a/supertokens_python/supertokens.py +++ b/supertokens_python/supertokens.py @@ -23,6 +23,7 @@ Dict, List, Optional, + Protocol, Set, Tuple, Union, @@ -38,6 +39,7 @@ ) from supertokens_python.plugins import ( OverrideMap, + PluginDependenciesErrorResponse, PluginRouteHandler, SuperTokensPlugin, SuperTokensPluginInit, @@ -132,7 +134,7 @@ class SupertokensPublicConfig: @dataclass class SupertokensInputConfig(SupertokensPublicConfig): - recipe_list: List[Callable[[AppInfo], RecipeModule]] + recipe_list: List[Callable[[AppInfo, List[OverrideMap]], RecipeModule]] experimental: Optional[SupertokensExperimentalConfig] = None def get_public_config(self) -> SupertokensPublicConfig: @@ -252,6 +254,12 @@ def manage_session_post_response( mutator(response, user_context) +class RecipeInit(Protocol): + def __call__( + self, app_info: AppInfo, plugins: List[OverrideMap] + ) -> RecipeModule: ... + + class Supertokens: __instance: Optional[Supertokens] = None @@ -274,12 +282,14 @@ def __init__( app_info: InputAppInfo, framework: Literal["fastapi", "flask", "django"], supertokens_config: SupertokensConfig, - recipe_list: List[Callable[[AppInfo], RecipeModule]], + recipe_list: List[RecipeInit], mode: Optional[Literal["asgi", "wsgi"]], telemetry: Optional[bool], debug: Optional[bool], experimental: Optional[SupertokensExperimentalConfig] = None, ): + from supertokens_python.plugins import PluginRouteHandlerFunctionErrorResponse + if not isinstance(app_info, InputAppInfo): # type: ignore raise ValueError("app_info must be an instance of InputAppInfo") @@ -305,8 +315,6 @@ def __init__( if experimental is not None and experimental.plugins is not None: input_plugin_list = experimental.plugins - print(f"{input_plugin_list=}") - for plugin in input_plugin_list: if isinstance(plugin.compatible_sdk_versions, list): version_constraints = plugin.compatible_sdk_versions @@ -317,34 +325,39 @@ def __init__( # TODO: Better checks raise Exception("Plugin version mismatch") - if plugin.dependencies is not None: - dep_result = plugin.dependencies( - config=public_config, - plugins_above=[ - SuperTokensPublicPlugin.from_plugin(plugin) - for plugin in final_plugin_list - ], - sdk_version=VERSION, - ) + def recurse_deps(plugin: SuperTokensPlugin, deps: List[SuperTokensPlugin]): + if plugin.dependencies is not None: + # Get all dependencies of the plugin + dep_result = plugin.dependencies( + config=public_config, + plugins_above=[ + SuperTokensPublicPlugin.from_plugin(plugin) + for plugin in final_plugin_list + ], + sdk_version=VERSION, + ) - if dep_result.status == "ERROR": - raise Exception(dep_result.message) + # Errors fall through + if isinstance(dep_result, PluginDependenciesErrorResponse): + raise Exception(dep_result.message) - if dep_result.plugins_to_add: - final_plugin_list.extend(dep_result.plugins_to_add) + # Recurse through all dependencies and add the resultant plugins to the list + # Pre-order DFS traversal + for dep_plugin in dep_result.plugins_to_add: + recurse_deps(dep_plugin, deps) - final_plugin_list.append(plugin) + # Add the current plugin + deps.append(plugin) + return deps + + final_plugin_list.extend(recurse_deps(plugin, [])) self.plugin_list = [ SuperTokensPublicPlugin.from_plugin(plugin) for plugin in final_plugin_list ] - print(f"{self.plugin_list=}") - print() for plugin_idx, plugin in enumerate(final_plugin_list): - print() - print(f"{plugin_idx=} {plugin=}") - + # Override the public supertokens config using the config override defined in the plugin if plugin.config is not None: public_config = plugin.config(public_config) @@ -357,15 +370,18 @@ def __init__( all_plugins=self.plugin_list, sdk_version=VERSION, ) - if handler_result.status == "ERROR": + if isinstance( + handler_result, PluginRouteHandlerFunctionErrorResponse + ): raise Exception(handler_result.message) + + handlers = handler_result.route_handlers else: handlers = plugin.route_handlers self.plugin_route_handlers.extend(handlers) if plugin.init is not None: - print(f"{plugin.init=}") plugin.init( config=public_config, all_plugins=self.plugin_list, @@ -515,7 +531,7 @@ def init( app_info: InputAppInfo, framework: Literal["fastapi", "flask", "django"], supertokens_config: SupertokensConfig, - recipe_list: List[Callable[[AppInfo], RecipeModule]], + recipe_list: List[RecipeInit], mode: Optional[Literal["asgi", "wsgi"]], telemetry: Optional[bool], debug: Optional[bool], @@ -753,11 +769,14 @@ async def middleware( if handler_from_apis is not None: session: Optional[SessionContainer] = None if handler_from_apis.verify_session_options is not None: - # TODO: Fix verify_session_options type + verify_session_options = handler_from_apis.verify_session_options session = await SessionRecipe.get_instance().verify_session( request=request, user_context=user_context, - **handler_from_apis.verify_session_options, + anti_csrf_check=verify_session_options.anti_csrf_check, + session_required=verify_session_options.session_required, + check_database=verify_session_options.check_database, + override_global_claim_validators=verify_session_options.override_global_claim_validators, ) handler_from_apis.handler( request=request, diff --git a/tests/plugins/__init__.py b/tests/plugins/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/plugins/api_implementation.py b/tests/plugins/api_implementation.py new file mode 100644 index 000000000..c38297b94 --- /dev/null +++ b/tests/plugins/api_implementation.py @@ -0,0 +1,22 @@ +from abc import ABC, abstractmethod +from typing import ( + List, +) + +from .types import RecipeReturnType + + +class APIInterface(ABC): + @abstractmethod + def sign_in_post(self, message: str, stack: List[str]) -> RecipeReturnType: ... + + +class APIImplementation(APIInterface): + def sign_in_post(self, message: str, stack: List[str]) -> RecipeReturnType: + stack.append("original") + return RecipeReturnType( + type="API", + function="sign_in_post", + stack=stack, + message=message, + ) diff --git a/tests/plugins/config.py b/tests/plugins/config.py new file mode 100644 index 000000000..b5284a2a9 --- /dev/null +++ b/tests/plugins/config.py @@ -0,0 +1,71 @@ +from dataclasses import dataclass +from typing import ( + TYPE_CHECKING, + Any, + Optional, + Protocol, + TypeVar, + runtime_checkable, +) + +from supertokens_python.supertokens import ( + AppInfo, +) + +if TYPE_CHECKING: + from .api_implementation import APIInterface + from .recipe_implementation import RecipeInterface + +InterfaceType = TypeVar("InterfaceType") +"""Generic Type for use in `InterfaceOverride`""" + + +@runtime_checkable +class InterfaceOverride(Protocol[InterfaceType]): + """ + Callable signature for `WebauthnConfig.override.*`. + """ + + def __call__( + self, + original_implementation: InterfaceType, + ) -> InterfaceType: ... + + +# NOTE: Using dataclasses for these classes since validation is not required +@dataclass +class OverrideConfig: + """ + `WebauthnConfig.override` + """ + + functions: Optional[InterfaceOverride["RecipeInterface"]] = None + apis: Optional[InterfaceOverride["APIInterface"]] = None + config: Optional[InterfaceOverride[Any]] = None + + +@dataclass +class NormalizedPluginTestConfig: + override: OverrideConfig + + +@dataclass +class PluginTestConfig: + override: Optional[OverrideConfig] = None + + +def validate_and_normalise_user_input( + config: Optional[PluginTestConfig], app_info: AppInfo +) -> NormalizedPluginTestConfig: + if config is None: + config = PluginTestConfig() + + if config.override is None: + override = OverrideConfig() + else: + override = OverrideConfig( + functions=config.override.functions, + apis=config.override.apis, + ) + + return NormalizedPluginTestConfig(override=override) diff --git a/tests/plugins/plugins.py b/tests/plugins/plugins.py new file mode 100644 index 000000000..2cd7d5742 --- /dev/null +++ b/tests/plugins/plugins.py @@ -0,0 +1,112 @@ +from typing import Any, List, Optional, Union + +from supertokens_python.plugins import ( + OverrideMap, + PluginDependenciesOkResponse, + SuperTokensPlugin, + SuperTokensPluginDependencies, + SuperTokensPublicPlugin, +) +from supertokens_python.supertokens import SupertokensPublicConfig + +from .api_implementation import APIInterface +from .config import OverrideConfig +from .recipe import PluginTestRecipe +from .recipe_implementation import RecipeInterface + + +def function_override_factory(identifier: str): + def function_override(original_implementation: RecipeInterface) -> RecipeInterface: + og_sign_in = original_implementation.sign_in + + def new_sign_in(message: str, stack: List[str]): + stack.append(identifier) + return og_sign_in(message, stack) + + original_implementation.sign_in = new_sign_in + return original_implementation + + return function_override + + +def api_override_factory(identifier: str): + def function_override(original_implementation: APIInterface) -> APIInterface: + sign_in_post = original_implementation.sign_in_post + + def new_sign_in_post(message: str, stack: List[str]): + stack.append(identifier) + return sign_in_post(message, stack) + + original_implementation.sign_in_post = new_sign_in_post + return original_implementation + + return function_override + + +def init_factory(identifier: str): + def init( + config: SupertokensPublicConfig, + all_plugins: List[SuperTokensPublicPlugin], + sdk_version: str, + ): + # TODO: Test this + print(f"{identifier} init") + + return init + + +def dependency_factory(dependencies: Optional[List[SuperTokensPlugin]]): + if dependencies is None: + dependencies = [] + + def dependency( + config: SupertokensPublicConfig, + plugins_above: List[SuperTokensPublicPlugin], + sdk_version: str, + ): + added_plugin_ids = [plugin.id for plugin in plugins_above] + plugins_to_add = [ + plugin for plugin in dependencies if plugin.id not in added_plugin_ids + ] + return PluginDependenciesOkResponse(plugins_to_add=plugins_to_add) + + return dependency + + +def plugin_factory( + identifier: str, + override_functions: bool = False, + override_apis: bool = False, + deps: Optional[List[SuperTokensPlugin]] = None, +): + override_map_obj: OverrideMap = {PluginTestRecipe.recipe_id: OverrideConfig()} + + if override_functions: + override_map_obj[ + PluginTestRecipe.recipe_id + ].functions = function_override_factory(identifier) + if override_apis: + override_map_obj[PluginTestRecipe.recipe_id].apis = api_override_factory( + identifier + ) + + class Plugin(SuperTokensPlugin): + id: str = identifier + compatible_sdk_versions: Union[str, List[str]] = ["0.30.0"] + override_map: Optional[OverrideMap] = override_map_obj + init: Any = init_factory(identifier) + dependencies: Optional[SuperTokensPluginDependencies] = dependency_factory(deps) + + return Plugin() + + +Plugin1 = plugin_factory("plugin1", override_functions=True) +Plugin2 = plugin_factory("plugin2", override_functions=True) +Plugin3Dep1 = plugin_factory("plugin3dep1", override_functions=True, deps=[Plugin1]) +Plugin3Dep2_1 = plugin_factory( + "plugin3dep2_1", override_functions=True, deps=[Plugin2, Plugin1] +) +Plugin4Dep2 = plugin_factory("plugin4dep2", override_functions=True, deps=[Plugin2]) +Plugin4Dep3__2_1 = plugin_factory( + "plugin4dep3__2_1", override_functions=True, deps=[Plugin3Dep2_1] +) diff --git a/tests/plugins/recipe.py b/tests/plugins/recipe.py new file mode 100644 index 000000000..c7977494a --- /dev/null +++ b/tests/plugins/recipe.py @@ -0,0 +1,130 @@ +from typing import ( + List, + Optional, +) + +from supertokens_python.exceptions import SuperTokensError, raise_general_exception +from supertokens_python.framework.request import BaseRequest +from supertokens_python.framework.response import BaseResponse +from supertokens_python.normalised_url_path import NormalisedURLPath +from supertokens_python.plugins import OverrideMap, apply_plugins +from supertokens_python.querier import Querier +from supertokens_python.recipe_module import APIHandled, RecipeModule +from supertokens_python.supertokens import ( + AppInfo, +) +from supertokens_python.types.base import UserContext + +from .api_implementation import APIImplementation +from .config import ( + NormalizedPluginTestConfig, + PluginTestConfig, + validate_and_normalise_user_input, +) +from .recipe_implementation import RecipeImplementation + + +class PluginTestRecipe(RecipeModule): + __instance: Optional["PluginTestRecipe"] = None + recipe_id = "plugin_test" + + config: NormalizedPluginTestConfig + recipe_implementation: RecipeImplementation + api_implementation: APIImplementation + + def __init__( + self, recipe_id: str, app_info: AppInfo, config: Optional[PluginTestConfig] + ): + super().__init__(recipe_id=recipe_id, app_info=app_info) + self.config = validate_and_normalise_user_input( + app_info=app_info, config=config + ) + + querier = Querier.get_instance(rid_to_core=recipe_id) + recipe_implementation = RecipeImplementation( + querier=querier, + config=self.config, + ) + self.recipe_implementation = ( + recipe_implementation + if self.config.override.functions is None + else self.config.override.functions(recipe_implementation) + ) # type: ignore + + api_implementation = APIImplementation() + self.api_implementation = ( + api_implementation + if self.config.override.apis is None + else self.config.override.apis(api_implementation) + ) # type: ignore + + @staticmethod + def get_instance() -> "PluginTestRecipe": + if PluginTestRecipe.__instance is not None: + return PluginTestRecipe.__instance + raise_general_exception( + "Initialisation not done. Did you forget to call the SuperTokens.init function?" + ) + + @staticmethod + def get_instance_optional() -> Optional["PluginTestRecipe"]: + return PluginTestRecipe.__instance + + @staticmethod + def init(config: Optional[PluginTestConfig]): + def func(app_info: AppInfo, plugins: List[OverrideMap]): + if PluginTestRecipe.__instance is None: + PluginTestRecipe.__instance = PluginTestRecipe( + recipe_id=PluginTestRecipe.recipe_id, + app_info=app_info, + config=apply_plugins( + recipe_id=PluginTestRecipe.recipe_id, + config=config, + plugins=plugins, + ), + ) + return PluginTestRecipe.__instance + else: + raise_general_exception( + "PluginTestRecipe has already been initialised. Please check your code for bugs." + ) + + return func + + @staticmethod + def reset(): + PluginTestRecipe.__instance = None + + def get_all_cors_headers(self) -> List[str]: + return [] + + async def handle_error( + self, + request: BaseRequest, + err: SuperTokensError, + response: BaseResponse, + user_context: UserContext, + ): + raise err + + async def handle_api_request( + self, + request_id: str, + tenant_id: str, + request: BaseRequest, + path: NormalisedURLPath, + method: str, + response: BaseResponse, + user_context: UserContext, + ): + return None + + def get_apis_handled(self) -> List[APIHandled]: + return [] + + def is_error_from_this_recipe_based_on_instance(self, err: Exception) -> bool: + return False + + +def plugin_test_init(config: Optional[PluginTestConfig] = None): + return PluginTestRecipe.init(config=config) diff --git a/tests/plugins/recipe_implementation.py b/tests/plugins/recipe_implementation.py new file mode 100644 index 000000000..03988e859 --- /dev/null +++ b/tests/plugins/recipe_implementation.py @@ -0,0 +1,34 @@ +from abc import ABC, abstractmethod +from typing import ( + List, +) + +from supertokens_python.querier import Querier + +from .config import NormalizedPluginTestConfig +from .types import RecipeReturnType + + +class RecipeInterface(ABC): + @abstractmethod + def sign_in(self, message: str, stack: List[str]) -> RecipeReturnType: ... + + +class RecipeImplementation(RecipeInterface): + def __init__( + self, + querier: Querier, + config: NormalizedPluginTestConfig, + ): + super().__init__() + self.querier = querier + self.config = config + + def sign_in(self, message: str, stack: List[str]) -> RecipeReturnType: + stack.append("original") + return RecipeReturnType( + type="Recipe", + function="sign_in", + stack=stack, + message=message, + ) diff --git a/tests/plugins/test_plugins.py b/tests/plugins/test_plugins.py new file mode 100644 index 000000000..39720d6af --- /dev/null +++ b/tests/plugins/test_plugins.py @@ -0,0 +1,249 @@ +from functools import partial +from typing import Any, List + +from pytest import fixture, mark, param +from supertokens_python import ( + InputAppInfo, + SupertokensConfig, + SupertokensExperimentalConfig, + init, +) +from supertokens_python.plugins import SuperTokensPlugin + +from tests.utils import outputs, reset + +from .config import OverrideConfig, PluginTestConfig +from .plugins import ( + Plugin1, + Plugin2, + Plugin3Dep1, + Plugin3Dep2_1, + Plugin4Dep2, + Plugin4Dep3__2_1, + api_override_factory, + function_override_factory, + plugin_factory, +) +from .recipe import PluginTestRecipe, plugin_test_init +from .types import RecipeReturnType + + +@fixture(autouse=True) +def setup_and_teardown(): + reset() + PluginTestRecipe.reset() + yield + reset() + PluginTestRecipe.reset() + + +def recipe_factory(override_functions: bool = False, override_apis: bool = False): + override = OverrideConfig() + + if override_functions: + override.functions = function_override_factory("override") + if override_apis: + override.apis = api_override_factory("override") + + return plugin_test_init(config=PluginTestConfig(override=override)) + + +partial_init = partial( + init, + app_info=InputAppInfo( + app_name="plugin_test", + api_domain="api.supertokens.io", + origin="http://localhost:3001", + ), + framework="django", + supertokens_config=SupertokensConfig( + connection_uri="http://localhost:3567", + ), +) + + +@mark.parametrize( + ( + "recipe_fn_override", + "recipe_api_override", + "plugins", + "recipe_expectation", + "api_expectation", + ), + [ + param( + False, + False, + [], + outputs(["original"]), + outputs(["original"]), + id="fn_ovr=False, api_ovr=False, plugins=[]", + ), + param( + True, + False, + [], + outputs(["override", "original"]), + outputs(["original"]), + id="fn_ovr=True, api_ovr=False, plugins=[]", + ), + param( + False, + True, + [], + outputs(["original"]), + outputs(["override", "original"]), + id="fn_ovr=False, api_ovr=True, plugins=[]", + ), + param( + True, + False, + [plugin_factory("plugin1", override_functions=True)], + outputs(["override", "plugin1", "original"]), + outputs(["original"]), + id="fn_ovr=True, api_ovr=False, plugins=[Plugin1], plugin1=[fn]", + ), + param( + True, + False, + [plugin_factory("plugin1", override_apis=True)], + outputs(["override", "original"]), + outputs(["plugin1", "original"]), + id="fn_ovr=True, api_ovr=False, plugins=[Plugin1], plugin1=[api]", + ), + param( + False, + True, + [plugin_factory("plugin1", override_functions=True)], + outputs(["plugin1", "original"]), + outputs(["override", "original"]), + id="fn_ovr=False, api_ovr=True, plugins=[Plugin1], plugin1=[fn]", + ), + param( + False, + True, + [plugin_factory("plugin1", override_apis=True)], + outputs(["original"]), + outputs(["override", "plugin1", "original"]), + id="fn_ovr=False, api_ovr=True, plugins=[Plugin1], plugin1=[api]", + ), + param( + True, + False, + [Plugin1, Plugin2], + outputs(["override", "plugin1", "plugin2", "original"]), + outputs(["original"]), + id="fn_ovr=True, api_ovr=False, plugins=[Plugin1, Plugin2], plugin1=[fn], plugin2=[fn]", + ), + ], +) +def test_overrides( + recipe_fn_override: bool, + recipe_api_override: bool, + plugins: List[SuperTokensPlugin], + recipe_expectation: Any, + api_expectation: Any, +): + partial_init( + recipe_list=[ + recipe_factory( + override_functions=recipe_fn_override, override_apis=recipe_api_override + ), + ], + experimental=SupertokensExperimentalConfig( + plugins=plugins, + ), + ) + + with recipe_expectation as expected_stack: + output = PluginTestRecipe.get_instance().recipe_implementation.sign_in( + "msg", [] + ) + assert output == RecipeReturnType( + type="Recipe", + function="sign_in", + stack=expected_stack, + message="msg", + ) + + with api_expectation as expected_stack: + output = PluginTestRecipe.get_instance().api_implementation.sign_in_post( + "msg", [] + ) + assert output == RecipeReturnType( + type="API", + function="sign_in_post", + stack=expected_stack, + message="msg", + ) + + +@mark.parametrize( + ("plugins", "recipe_expectation", "api_expectation"), + [ + param( + [Plugin3Dep1], + outputs(["plugin1", "plugin3dep1", "original"]), + outputs(["original"]), + id="3->1", + ), + param( + [Plugin3Dep2_1], + outputs(["plugin2", "plugin1", "plugin3dep2_1", "original"]), + outputs(["original"]), + id="3->(2,1)", + ), + param( + [Plugin3Dep1, Plugin4Dep2], + outputs(["plugin1", "plugin3dep1", "plugin2", "plugin4dep2", "original"]), + outputs(["original"]), + id="3->1,4->2", + ), + param( + [Plugin4Dep3__2_1], + outputs( + ["plugin2", "plugin1", "plugin3dep2_1", "plugin4dep3__2_1", "original"] + ), + outputs(["original"]), + id="4->3->(2,1)", + ), + ], +) +def test_depdendencies( + plugins: List[SuperTokensPlugin], + recipe_expectation: Any, + api_expectation: Any, +): + partial_init( + recipe_list=[ + recipe_factory(), + ], + experimental=SupertokensExperimentalConfig( + plugins=plugins, + ), + ) + + with recipe_expectation as expected_stack: + output = PluginTestRecipe.get_instance().recipe_implementation.sign_in( + "msg", [] + ) + assert output == RecipeReturnType( + type="Recipe", + function="sign_in", + stack=expected_stack, + message="msg", + ) + + with api_expectation as expected_stack: + output = PluginTestRecipe.get_instance().api_implementation.sign_in_post( + "msg", [] + ) + assert output == RecipeReturnType( + type="API", + function="sign_in_post", + stack=expected_stack, + message="msg", + ) + + +# TODO: Add tests for init, route handlers, config overrides diff --git a/tests/plugins/types.py b/tests/plugins/types.py new file mode 100644 index 000000000..f14bcfeb8 --- /dev/null +++ b/tests/plugins/types.py @@ -0,0 +1,10 @@ +from typing import List, Literal + +from supertokens_python.types.response import CamelCaseBaseModel + + +class RecipeReturnType(CamelCaseBaseModel): + type: Literal["Recipe", "API"] + function: str + stack: List[str] + message: str From fc9b5299e6d8723177d9cd1aed1b318cd2aee513 Mon Sep 17 00:00:00 2001 From: Namit Nathwani Date: Fri, 6 Jun 2025 18:30:35 +0530 Subject: [PATCH 03/37] update: change plugin evaluation order --- supertokens_python/plugins.py | 80 ++++++++++++++++++++++++++----- supertokens_python/supertokens.py | 42 ++++++---------- tests/plugins/plugins.py | 1 + tests/plugins/test_plugins.py | 32 +++++++++---- 4 files changed, 107 insertions(+), 48 deletions(-) diff --git a/supertokens_python/plugins.py b/supertokens_python/plugins.py index 10b5a1b6c..83f3cff3b 100644 --- a/supertokens_python/plugins.py +++ b/supertokens_python/plugins.py @@ -5,6 +5,7 @@ # - APIInterface # - OverrideConfig +from collections import deque from typing import ( TYPE_CHECKING, Any, @@ -13,6 +14,7 @@ List, Literal, Optional, + Set, TypeVar, Union, runtime_checkable, @@ -234,6 +236,58 @@ class SuperTokensPlugin(SuperTokensPluginBase): ] = None config: Optional[PluginConfig] = None + def get_dependencies( + self, + public_config: "SupertokensPublicConfig", + plugins_above: List["SuperTokensPlugin"], + sdk_version: str, + ): + """ + Pre-order DFS traversal to get all dependencies of a plugin. + """ + + def recurse_deps( + plugin: SuperTokensPlugin, + deps: Optional[List[SuperTokensPlugin]] = None, + visited: Optional[Set[str]] = None, + ) -> List[SuperTokensPlugin]: + if deps is None: + deps = [] + + if visited is None: + visited = set() + + if plugin.id in visited: + return deps + visited.add(plugin.id) + + if plugin.dependencies is not None: + # Get all dependencies of the plugin + dep_result = plugin.dependencies( + config=public_config, + plugins_above=[ + SuperTokensPublicPlugin.from_plugin(plugin) + for plugin in plugins_above + ], + sdk_version=sdk_version, + ) + + # Errors fall through + if isinstance(dep_result, PluginDependenciesErrorResponse): + raise Exception(dep_result.message) + + # Recurse through all dependencies and add the resultant plugins to the list + # Pre-order DFS traversal + for dep_plugin in dep_result.plugins_to_add: + recurse_deps(dep_plugin, deps) + + # Add the current plugin and mark it as visited + deps.append(plugin) + + return deps + + return recurse_deps(self) + class SuperTokensPublicPlugin(SuperTokensPluginBase): initialized: bool @@ -272,12 +326,12 @@ def default_api_override(original_implementation: T) -> T: function_overrides = getattr(config.override, "functions", default_fn_override) api_overrides = getattr(config.override, "apis", default_api_override) - function_layers: list[Any] = [] - api_layers: list[Any] = [] - if function_overrides is not None: - function_layers.append(function_overrides) - if api_overrides is not None: - api_layers.append(api_overrides) + function_layers: deque[Any] = deque() + api_layers: deque[Any] = deque() + + # If we have plugins like 4->3->(2, 1) along with a recipe override, + # we want to apply them as: override, 4, 3, 1, 2, original + # Order of 1/2 does not matter since they are independent from each other. for plugin in plugins: overrides = plugin[recipe_id] @@ -290,14 +344,18 @@ def default_api_override(original_implementation: T) -> T: if overrides.apis is not None: api_layers.append(overrides.apis) + if function_overrides is not None: + function_layers.append(function_overrides) + if api_overrides is not None: + api_layers.append(api_overrides) + # Apply overrides in order of definition # Plugins: [plugin1, plugin2] would be applied as [override, plugin1, plugin2, original] if len(function_layers) > 0: - # TODO: Change to recipe_implementation type + # TODO: Change to recipe_interface type def fn_override(original_implementation: T) -> T: # The layers will get called in reversed order - # Iteration is reversed to ensure that the required order is maintained - for function_layer in reversed(function_layers): + for function_layer in function_layers: original_implementation = function_layer(original_implementation) return original_implementation @@ -305,9 +363,9 @@ def fn_override(original_implementation: T) -> T: # AccountLinking recipe does not have an API implementation if len(api_layers) > 0 and recipe_id != "accountlinking": - # TODO: Change to api_implementation type + # TODO: Change to api_interface type def api_override(original_implementation: T) -> T: - for api_layer in reversed(api_layers): + for api_layer in api_layers: original_implementation = api_layer(original_implementation) return original_implementation diff --git a/supertokens_python/supertokens.py b/supertokens_python/supertokens.py index 45d5b2a6f..c10a46018 100644 --- a/supertokens_python/supertokens.py +++ b/supertokens_python/supertokens.py @@ -39,7 +39,6 @@ ) from supertokens_python.plugins import ( OverrideMap, - PluginDependenciesErrorResponse, PluginRouteHandler, SuperTokensPlugin, SuperTokensPluginInit, @@ -309,13 +308,18 @@ def __init__( self.plugin_route_handlers = [] - input_plugin_list = [] + input_plugin_list: List[SuperTokensPlugin] = [] + input_plugin_seen_list: Set[str] = set() final_plugin_list: List[SuperTokensPlugin] = [] if experimental is not None and experimental.plugins is not None: input_plugin_list = experimental.plugins for plugin in input_plugin_list: + if plugin.id in input_plugin_seen_list: + log_debug_message(f"Skipping {plugin.id=} as it has already been added") + continue + if isinstance(plugin.compatible_sdk_versions, list): version_constraints = plugin.compatible_sdk_versions else: @@ -325,32 +329,14 @@ def __init__( # TODO: Better checks raise Exception("Plugin version mismatch") - def recurse_deps(plugin: SuperTokensPlugin, deps: List[SuperTokensPlugin]): - if plugin.dependencies is not None: - # Get all dependencies of the plugin - dep_result = plugin.dependencies( - config=public_config, - plugins_above=[ - SuperTokensPublicPlugin.from_plugin(plugin) - for plugin in final_plugin_list - ], - sdk_version=VERSION, - ) - - # Errors fall through - if isinstance(dep_result, PluginDependenciesErrorResponse): - raise Exception(dep_result.message) - - # Recurse through all dependencies and add the resultant plugins to the list - # Pre-order DFS traversal - for dep_plugin in dep_result.plugins_to_add: - recurse_deps(dep_plugin, deps) - - # Add the current plugin - deps.append(plugin) - return deps - - final_plugin_list.extend(recurse_deps(plugin, [])) + # TODO: Overkill, but could topologically sort the plugins based on dependencies + dependencies = plugin.get_dependencies( + public_config=public_config, + plugins_above=final_plugin_list, + sdk_version=VERSION, + ) + final_plugin_list.extend(dependencies) + input_plugin_seen_list.update({dep.id for dep in dependencies}) self.plugin_list = [ SuperTokensPublicPlugin.from_plugin(plugin) for plugin in final_plugin_list diff --git a/tests/plugins/plugins.py b/tests/plugins/plugins.py index 2cd7d5742..42e6b552c 100644 --- a/tests/plugins/plugins.py +++ b/tests/plugins/plugins.py @@ -106,6 +106,7 @@ class Plugin(SuperTokensPlugin): Plugin3Dep2_1 = plugin_factory( "plugin3dep2_1", override_functions=True, deps=[Plugin2, Plugin1] ) +Plugin4Dep1 = plugin_factory("plugin4dep1", override_functions=True, deps=[Plugin1]) Plugin4Dep2 = plugin_factory("plugin4dep2", override_functions=True, deps=[Plugin2]) Plugin4Dep3__2_1 = plugin_factory( "plugin4dep3__2_1", override_functions=True, deps=[Plugin3Dep2_1] diff --git a/tests/plugins/test_plugins.py b/tests/plugins/test_plugins.py index 39720d6af..bedbf8f05 100644 --- a/tests/plugins/test_plugins.py +++ b/tests/plugins/test_plugins.py @@ -18,6 +18,7 @@ Plugin2, Plugin3Dep1, Plugin3Dep2_1, + Plugin4Dep1, Plugin4Dep2, Plugin4Dep3__2_1, api_override_factory, @@ -131,7 +132,7 @@ def recipe_factory(override_functions: bool = False, override_apis: bool = False True, False, [Plugin1, Plugin2], - outputs(["override", "plugin1", "plugin2", "original"]), + outputs(["override", "plugin2", "plugin1", "original"]), outputs(["original"]), id="fn_ovr=True, api_ovr=False, plugins=[Plugin1, Plugin2], plugin1=[fn], plugin2=[fn]", ), @@ -178,34 +179,47 @@ def test_overrides( ) +# TODO: Figure out a way to add circular dependencies and test them @mark.parametrize( ("plugins", "recipe_expectation", "api_expectation"), [ + param( + [Plugin1, Plugin1], + outputs(["plugin1", "original"]), + outputs(["original"]), + id="1,1 => 1", + ), param( [Plugin3Dep1], - outputs(["plugin1", "plugin3dep1", "original"]), + outputs(["plugin3dep1", "plugin1", "original"]), outputs(["original"]), - id="3->1", + id="3->1 => 3,1", ), param( [Plugin3Dep2_1], - outputs(["plugin2", "plugin1", "plugin3dep2_1", "original"]), + outputs(["plugin3dep2_1", "plugin1", "plugin2", "original"]), outputs(["original"]), - id="3->(2,1)", + id="3->(2,1) => 3,2,1", ), param( [Plugin3Dep1, Plugin4Dep2], - outputs(["plugin1", "plugin3dep1", "plugin2", "plugin4dep2", "original"]), + outputs(["plugin4dep2", "plugin2", "plugin3dep1", "plugin1", "original"]), outputs(["original"]), - id="3->1,4->2", + id="3->1,4->2 => 4,2,3,1", ), param( [Plugin4Dep3__2_1], outputs( - ["plugin2", "plugin1", "plugin3dep2_1", "plugin4dep3__2_1", "original"] + ["plugin4dep3__2_1", "plugin3dep2_1", "plugin1", "plugin2", "original"] ), outputs(["original"]), - id="4->3->(2,1)", + id="4->3->(2,1) => 4,3,1,2", + ), + param( + [Plugin3Dep1, Plugin4Dep1], + outputs(["plugin4dep1", "plugin3dep1", "plugin1", "original"]), + outputs(["original"]), + id="3->1,4->1 => 4,3,1", ), ], ) From bc6bf7d8fd5533e4b784dbbccfda3dd47bd0d572 Mon Sep 17 00:00:00 2001 From: Namit Nathwani Date: Tue, 10 Jun 2025 10:14:36 +0530 Subject: [PATCH 04/37] update: move plugin logic to function, add config override support --- supertokens_python/plugins.py | 100 ++++++++++++++++++ supertokens_python/supertokens.py | 164 +++++++++++------------------- tests/plugins/test_plugins.py | 20 ++++ 3 files changed, 178 insertions(+), 106 deletions(-) diff --git a/supertokens_python/plugins.py b/supertokens_python/plugins.py index 83f3cff3b..66b5b5793 100644 --- a/supertokens_python/plugins.py +++ b/supertokens_python/plugins.py @@ -15,15 +15,19 @@ Literal, Optional, Set, + Tuple, TypeVar, Union, + cast, runtime_checkable, ) from typing_extensions import Protocol +from supertokens_python.constants import VERSION from supertokens_python.framework.request import BaseRequest from supertokens_python.framework.response import BaseResponse +from supertokens_python.logger import log_debug_message # from supertokens_python.recipe.accountlinking.types import AccountLinkingConfig # from supertokens_python.recipe.dashboard.utils import DashboardConfig @@ -47,6 +51,7 @@ # from supertokens_python.recipe.totp.types import TOTPConfig # from supertokens_python.recipe.usermetadata.utils import UserMetadataConfig # from supertokens_python.recipe.userroles.utils import UserRolesConfig +from supertokens_python.post_init_callbacks import PostSTInitCallbacks from supertokens_python.types import MaybeAwaitable from supertokens_python.types.base import UserContext from supertokens_python.types.response import CamelCaseBaseModel @@ -372,3 +377,98 @@ def api_override(original_implementation: T) -> T: config.override.apis = api_override return config + + +def load_plugins( + plugins: List[SuperTokensPlugin], public_config: "SupertokensPublicConfig" +) -> Tuple[ + "SupertokensPublicConfig", + List[SuperTokensPublicPlugin], + List[PluginRouteHandler], + List[OverrideMap], +]: + input_plugin_seen_list: Set[str] = set() + final_plugin_list: List[SuperTokensPlugin] = [] + plugin_route_handlers: List[PluginRouteHandler] = [] + + for plugin in plugins: + if plugin.id in input_plugin_seen_list: + log_debug_message(f"Skipping {plugin.id=} as it has already been added") + continue + + if isinstance(plugin.compatible_sdk_versions, list): + version_constraints = plugin.compatible_sdk_versions + else: + version_constraints = [plugin.compatible_sdk_versions] + + if VERSION not in version_constraints: + # TODO: Better checks + raise Exception("Plugin version mismatch") + + # TODO: Overkill, but could topologically sort the plugins based on dependencies + dependencies = plugin.get_dependencies( + public_config=public_config, + plugins_above=final_plugin_list, + sdk_version=VERSION, + ) + final_plugin_list.extend(dependencies) + input_plugin_seen_list.update({dep.id for dep in dependencies}) + + processed_plugin_list = [ + SuperTokensPublicPlugin.from_plugin(plugin) for plugin in final_plugin_list + ] + + for plugin_idx, plugin in enumerate(final_plugin_list): + # Override the public supertokens config using the config override defined in the plugin + if plugin.config is not None: + public_config = plugin.config(public_config) + + if plugin.route_handlers is not None: + handlers: List[PluginRouteHandler] = [] + + if callable(plugin.route_handlers): + handler_result = plugin.route_handlers( + config=public_config, + all_plugins=processed_plugin_list, + sdk_version=VERSION, + ) + if isinstance(handler_result, PluginRouteHandlerFunctionErrorResponse): + raise Exception(handler_result.message) + + handlers = handler_result.route_handlers + else: + handlers = plugin.route_handlers + + plugin_route_handlers.extend(handlers) + + if plugin.init is not None: + plugin.init( + config=public_config, + all_plugins=processed_plugin_list, + sdk_version=VERSION, + ) + + def callback_factory(): + # This has to be part of the factory to ensure we pick up the correct plugin + init_fn = cast(SuperTokensPluginInit, plugin.init) + idx = plugin_idx + + def callback(): + init_fn( + config=public_config, + all_plugins=processed_plugin_list, + sdk_version=VERSION, + ) + processed_plugin_list[idx].initialized = True + + return callback + + PostSTInitCallbacks.add_post_init_callback(callback_factory()) + + override_maps = [ + plugin.override_map + for plugin in final_plugin_list + if plugin.override_map is not None + ] + + return public_config, processed_plugin_list, plugin_route_handlers, override_maps diff --git a/supertokens_python/supertokens.py b/supertokens_python/supertokens.py index c10a46018..b59bad29e 100644 --- a/supertokens_python/supertokens.py +++ b/supertokens_python/supertokens.py @@ -27,7 +27,6 @@ Set, Tuple, Union, - cast, ) from typing_extensions import Literal @@ -41,11 +40,11 @@ OverrideMap, PluginRouteHandler, SuperTokensPlugin, - SuperTokensPluginInit, SuperTokensPublicPlugin, + load_plugins, ) -from .constants import FDI_KEY_HEADER, RID_KEY_HEADER, USER_COUNT, VERSION +from .constants import FDI_KEY_HEADER, RID_KEY_HEADER, USER_COUNT from .exceptions import SuperTokensError from .interfaces import ( CreateUserIdMappingOkResult, @@ -123,6 +122,10 @@ class SupertokensExperimentalConfig: @dataclass class SupertokensPublicConfig: + """ + Public properties received as input to the `Supertokens.init` function. + """ + app_info: InputAppInfo framework: Literal["fastapi", "flask", "django"] supertokens_config: SupertokensConfig @@ -133,6 +136,10 @@ class SupertokensPublicConfig: @dataclass class SupertokensInputConfig(SupertokensPublicConfig): + """ + Various properties received as input to the `Supertokens.init` function. + """ + recipe_list: List[Callable[[AppInfo, List[OverrideMap]], RecipeModule]] experimental: Optional[SupertokensExperimentalConfig] = None @@ -146,6 +153,24 @@ def get_public_config(self) -> SupertokensPublicConfig: debug=self.debug, ) + @classmethod + def from_public_config( + cls, + config: SupertokensPublicConfig, + recipe_list: List[Callable[[AppInfo, List[OverrideMap]], RecipeModule]], + experimental: Optional[SupertokensExperimentalConfig], + ) -> "SupertokensInputConfig": + return cls( + app_info=config.app_info, + framework=config.framework, + supertokens_config=config.supertokens_config, + mode=config.mode, + telemetry=config.telemetry, + debug=config.debug, + recipe_list=recipe_list, + experimental=experimental, + ) + class Host: def __init__(self, domain: NormalisedURLDomain, base_path: NormalisedURLPath): @@ -287,12 +312,10 @@ def __init__( debug: Optional[bool], experimental: Optional[SupertokensExperimentalConfig] = None, ): - from supertokens_python.plugins import PluginRouteHandlerFunctionErrorResponse - if not isinstance(app_info, InputAppInfo): # type: ignore raise ValueError("app_info must be an instance of InputAppInfo") - config = SupertokensInputConfig( + input_config = SupertokensInputConfig( app_info=app_info, framework=framework, supertokens_config=supertokens_config, @@ -304,106 +327,41 @@ def __init__( ) # TODO: Probably just want to define this directly and use it # Can build a input config from the final public config and the additional props - public_config = config.get_public_config() + input_public_config = input_config.get_public_config() + processed_public_config = input_public_config self.plugin_route_handlers = [] - - input_plugin_list: List[SuperTokensPlugin] = [] - input_plugin_seen_list: Set[str] = set() - final_plugin_list: List[SuperTokensPlugin] = [] + override_maps: List[OverrideMap] = [] if experimental is not None and experimental.plugins is not None: - input_plugin_list = experimental.plugins - - for plugin in input_plugin_list: - if plugin.id in input_plugin_seen_list: - log_debug_message(f"Skipping {plugin.id=} as it has already been added") - continue - - if isinstance(plugin.compatible_sdk_versions, list): - version_constraints = plugin.compatible_sdk_versions - else: - version_constraints = [plugin.compatible_sdk_versions] - - if VERSION not in version_constraints: - # TODO: Better checks - raise Exception("Plugin version mismatch") - - # TODO: Overkill, but could topologically sort the plugins based on dependencies - dependencies = plugin.get_dependencies( - public_config=public_config, - plugins_above=final_plugin_list, - sdk_version=VERSION, + ( + processed_public_config, + self.plugin_list, + self.plugin_route_handlers, + override_maps, + ) = load_plugins( + plugins=experimental.plugins, + public_config=input_public_config, ) - final_plugin_list.extend(dependencies) - input_plugin_seen_list.update({dep.id for dep in dependencies}) - - self.plugin_list = [ - SuperTokensPublicPlugin.from_plugin(plugin) for plugin in final_plugin_list - ] - - for plugin_idx, plugin in enumerate(final_plugin_list): - # Override the public supertokens config using the config override defined in the plugin - if plugin.config is not None: - public_config = plugin.config(public_config) - - if plugin.route_handlers is not None: - handlers: List[PluginRouteHandler] = [] - - if callable(plugin.route_handlers): - handler_result = plugin.route_handlers( - config=public_config, - all_plugins=self.plugin_list, - sdk_version=VERSION, - ) - if isinstance( - handler_result, PluginRouteHandlerFunctionErrorResponse - ): - raise Exception(handler_result.message) - - handlers = handler_result.route_handlers - else: - handlers = plugin.route_handlers - - self.plugin_route_handlers.extend(handlers) - - if plugin.init is not None: - plugin.init( - config=public_config, - all_plugins=self.plugin_list, - sdk_version=VERSION, - ) - - # TODO: Make this a factory function to avoid weird side-effects? - def callback_factory(): - # This has to be part of the factory to ensure we pick up the correct plugin - init_fn = cast(SuperTokensPluginInit, plugin.init) - idx = plugin_idx - - def callback(): - init_fn( - config=public_config, - all_plugins=self.plugin_list, - sdk_version=VERSION, - ) - self.plugin_list[idx].initialized = True - - return callback - PostSTInitCallbacks.add_post_init_callback(callback_factory()) + config = SupertokensInputConfig.from_public_config( + config=processed_public_config, + recipe_list=recipe_list, + experimental=experimental, + ) self.app_info = AppInfo( - app_info.app_name, - app_info.api_domain, - app_info.website_domain, - framework, - app_info.api_gateway_path, - app_info.api_base_path, - app_info.website_base_path, - mode, - app_info.origin, + config.app_info.app_name, + config.app_info.api_domain, + config.app_info.website_domain, + config.framework, + config.app_info.api_gateway_path, + config.app_info.api_base_path, + config.app_info.website_base_path, + config.mode, + config.app_info.origin, ) - self.supertokens_config = supertokens_config + self.supertokens_config = config.supertokens_config if debug is True: enable_debug_logging() self._telemetry_status: str = "NONE" @@ -411,7 +369,7 @@ def callback(): "Started SuperTokens with debug logging (supertokens.init called)" ) log_debug_message("app_info: %s", self.app_info.toJSON()) - log_debug_message("framework: %s", framework) + log_debug_message("framework: %s", config.framework) hosts = list( map( lambda h: Host( @@ -432,12 +390,6 @@ def callback(): "Please provide at least one recipe to the supertokens.init function call" ) - override_maps = [ - plugin.override_map - for plugin in final_plugin_list - if plugin.override_map is not None - ] - multitenancy_found = False totp_found = False user_metadata_found = False @@ -507,8 +459,8 @@ def make_recipe( self.recipe_modules.append(OAuth2ProviderRecipe.init()(self.app_info)) self.telemetry = ( - telemetry - if telemetry is not None + config.telemetry + if config.telemetry is not None else (environ.get("TEST_MODE") != "testing") ) diff --git a/tests/plugins/test_plugins.py b/tests/plugins/test_plugins.py index bedbf8f05..f57ef9ca2 100644 --- a/tests/plugins/test_plugins.py +++ b/tests/plugins/test_plugins.py @@ -4,6 +4,7 @@ from pytest import fixture, mark, param from supertokens_python import ( InputAppInfo, + Supertokens, SupertokensConfig, SupertokensExperimentalConfig, init, @@ -261,3 +262,22 @@ def test_depdendencies( # TODO: Add tests for init, route handlers, config overrides +def test_config_override(): + plugin = plugin_factory("plugin1", override_functions=False, override_apis=False) + + def config_override(cfg): + cfg.mode = "override" + return cfg + + plugin.config = config_override + + partial_init( + recipe_list=[ + recipe_factory(override_functions=False, override_apis=False), + ], + experimental=SupertokensExperimentalConfig( + plugins=[plugin], + ), + ) + + assert Supertokens.get_instance().app_info.mode == "override" From f06e08f37edff16b8b9063cad377e1ba37dd9110 Mon Sep 17 00:00:00 2001 From: Namit Nathwani Date: Wed, 11 Jun 2025 00:30:14 +0530 Subject: [PATCH 05/37] update: cleanup load_plugins --- supertokens_python/plugins.py | 32 +++++++++++++++++-------------- supertokens_python/supertokens.py | 26 ++++++++++++------------- 2 files changed, 30 insertions(+), 28 deletions(-) diff --git a/supertokens_python/plugins.py b/supertokens_python/plugins.py index 66b5b5793..f4193e66b 100644 --- a/supertokens_python/plugins.py +++ b/supertokens_python/plugins.py @@ -6,6 +6,7 @@ # - OverrideConfig from collections import deque +from dataclasses import dataclass from typing import ( TYPE_CHECKING, Any, @@ -15,7 +16,6 @@ Literal, Optional, Set, - Tuple, TypeVar, Union, cast, @@ -150,7 +150,7 @@ class VerifySessionOptions(CamelCaseBaseModel): ] = None -class PluginRouteHandler: +class PluginRouteHandler(CamelCaseBaseModel): method: str path: str handler: PluginRouteHandlerHandlerFunction @@ -379,14 +379,18 @@ def api_override(original_implementation: T) -> T: return config +# TODO: Figure out import cycles and convert to a Pydantic BaseModel +@dataclass +class LoadPluginsResponse: + public_config: "SupertokensPublicConfig" + processed_plugins: List[SuperTokensPublicPlugin] + plugin_route_handlers: List[PluginRouteHandler] + override_maps: List[OverrideMap] + + def load_plugins( plugins: List[SuperTokensPlugin], public_config: "SupertokensPublicConfig" -) -> Tuple[ - "SupertokensPublicConfig", - List[SuperTokensPublicPlugin], - List[PluginRouteHandler], - List[OverrideMap], -]: +) -> LoadPluginsResponse: input_plugin_seen_list: Set[str] = set() final_plugin_list: List[SuperTokensPlugin] = [] plugin_route_handlers: List[PluginRouteHandler] = [] @@ -442,11 +446,6 @@ def load_plugins( plugin_route_handlers.extend(handlers) if plugin.init is not None: - plugin.init( - config=public_config, - all_plugins=processed_plugin_list, - sdk_version=VERSION, - ) def callback_factory(): # This has to be part of the factory to ensure we pick up the correct plugin @@ -471,4 +470,9 @@ def callback(): if plugin.override_map is not None ] - return public_config, processed_plugin_list, plugin_route_handlers, override_maps + return LoadPluginsResponse( + public_config=public_config, + processed_plugins=processed_plugin_list, + plugin_route_handlers=plugin_route_handlers, + override_maps=override_maps, + ) diff --git a/supertokens_python/supertokens.py b/supertokens_python/supertokens.py index b59bad29e..cd9536a7e 100644 --- a/supertokens_python/supertokens.py +++ b/supertokens_python/supertokens.py @@ -334,16 +334,16 @@ def __init__( override_maps: List[OverrideMap] = [] if experimental is not None and experimental.plugins is not None: - ( - processed_public_config, - self.plugin_list, - self.plugin_route_handlers, - override_maps, - ) = load_plugins( + load_plugins_result = load_plugins( plugins=experimental.plugins, public_config=input_public_config, ) + override_maps = load_plugins_result.override_maps + processed_public_config = load_plugins_result.public_config + self.plugin_list = load_plugins_result.processed_plugins + self.plugin_route_handlers = load_plugins_result.plugin_route_handlers + config = SupertokensInputConfig.from_public_config( config=processed_public_config, recipe_list=recipe_list, @@ -716,15 +716,13 @@ async def middleware( check_database=verify_session_options.check_database, override_global_claim_validators=verify_session_options.override_global_claim_validators, ) - handler_from_apis.handler( - request=request, - response=response, - session=session, - user_context=user_context, - ) - # TODO: Why do we do this? - return None + return handler_from_apis.handler( + request=request, + response=response, + session=session, + user_context=user_context, + ) if not path.startswith(Supertokens.get_instance().app_info.api_base_path): log_debug_message( From c317628f6aaeec91f2831962cc19b730f8e22910 Mon Sep 17 00:00:00 2001 From: Namit Nathwani Date: Wed, 11 Jun 2025 00:30:45 +0530 Subject: [PATCH 06/37] test: st config overrides, route handlers --- pyproject.toml | 2 + tests/plugins/test_plugins.py | 203 ++++++++++++++++++++++++++++++++-- 2 files changed, 197 insertions(+), 8 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 69efe4a1a..86362d513 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,3 +18,5 @@ include = ["supertokens_python/", "tests/", "examples/"] addopts = " -v -p no:warnings" python_paths = "." xfail_strict = true +# Removes requirement to use `@mark.asyncio` on async tests +asyncio_mode = "auto" diff --git a/tests/plugins/test_plugins.py b/tests/plugins/test_plugins.py index f57ef9ca2..cd3bd4dd5 100644 --- a/tests/plugins/test_plugins.py +++ b/tests/plugins/test_plugins.py @@ -1,7 +1,7 @@ from functools import partial -from typing import Any, List +from typing import Any, Dict, List, Optional -from pytest import fixture, mark, param +from pytest import fixture, mark, param, raises from supertokens_python import ( InputAppInfo, Supertokens, @@ -9,7 +9,16 @@ SupertokensExperimentalConfig, init, ) -from supertokens_python.plugins import SuperTokensPlugin +from supertokens_python.framework.request import BaseRequest +from supertokens_python.framework.response import BaseResponse +from supertokens_python.plugins import ( + PluginRouteHandler, + PluginRouteHandlerFunctionErrorResponse, + PluginRouteHandlerFunctionOkResponse, + SuperTokensPlugin, +) +from supertokens_python.recipe.session import SessionContainer +from supertokens_python.supertokens import SupertokensPublicConfig from tests.utils import outputs, reset @@ -64,6 +73,87 @@ def recipe_factory(override_functions: bool = False, override_apis: bool = False ) +class DummyRequest(BaseRequest): + def get_path(self) -> str: + return "/auth/plugin1/hello" + + def get_method(self) -> str: + return "get" + + def get_original_url(self) -> Any: + raise NotImplementedError + + def get_query_param(self, key: str, default: Optional[str] = None) -> Any: + raise NotImplementedError + + def get_query_params(self) -> Any: + raise NotImplementedError + + async def json(self) -> Any: + raise NotImplementedError + + async def form_data(self) -> Any: + raise NotImplementedError + + def method(self) -> Any: + return "get" + + def get_cookie(self, key: str) -> Any: + raise NotImplementedError + + def get_header(self, key: str) -> Any: + return None + + def get_session(self) -> Any: + raise NotImplementedError + + def set_session(self, session: SessionContainer) -> Any: + raise NotImplementedError + + def set_session_as_none(self) -> Any: + raise NotImplementedError + + +class DummyResponse(BaseResponse): + def __init__(self, content: Dict[str, Any], status_code: int = 200): + self.content = content + self.status_code = status_code + + def set_cookie( + self, + key: str, + value: str, + expires: int, + path: str = "/", + domain: Optional[str] = None, + secure: bool = False, + httponly: bool = False, + samesite: str = "lax", + ) -> Any: + raise NotImplementedError + + def set_header(self, key: str, value: str) -> None: + raise NotImplementedError + + def get_header(self, key: str) -> Optional[str]: + raise NotImplementedError + + def remove_header(self, key: str) -> None: + raise NotImplementedError + + def set_status_code(self, status_code: int) -> None: + raise NotImplementedError + + def set_json_content(self, content: Dict[str, Any]) -> Any: + raise NotImplementedError + + def set_html_content(self, content: str) -> Any: + raise NotImplementedError + + def redirect(self, url: str) -> Any: + raise NotImplementedError + + @mark.parametrize( ( "recipe_fn_override", @@ -261,13 +351,13 @@ def test_depdendencies( ) -# TODO: Add tests for init, route handlers, config overrides -def test_config_override(): +# TODO: Add tests for init, recipe config override +def test_st_config_override(): plugin = plugin_factory("plugin1", override_functions=False, override_apis=False) - def config_override(cfg): - cfg.mode = "override" - return cfg + def config_override(config: SupertokensPublicConfig) -> SupertokensPublicConfig: + config.mode = "override" # type: ignore + return config plugin.config = config_override @@ -281,3 +371,100 @@ def config_override(cfg): ) assert Supertokens.get_instance().app_info.mode == "override" + + +def test_st_config_override_non_public_property(): + plugin = plugin_factory("plugin1", override_functions=False, override_apis=False) + + def config_override(config: SupertokensPublicConfig) -> SupertokensPublicConfig: + config.recipe_list = [] # type: ignore + return config + + plugin.config = config_override + + partial_init( + recipe_list=[ + recipe_factory(override_functions=False, override_apis=False), + ], + experimental=SupertokensExperimentalConfig( + plugins=[plugin], + ), + ) + + assert Supertokens.get_instance().recipe_modules != [] + + +plugin_route_handler = PluginRouteHandler( + method="get", + path="/auth/plugin1/hello", + handler=lambda *_, **__: "plugin1", # type: ignore + verify_session_options=None, +) + + +async def test_route_handlers_list(): + plugin = plugin_factory("plugin1", override_functions=False, override_apis=False) + + plugin.route_handlers = [plugin_route_handler] + + partial_init( + recipe_list=[ + recipe_factory(override_functions=False, override_apis=False), + ], + experimental=SupertokensExperimentalConfig( + plugins=[plugin], + ), + ) + + st_instance = Supertokens.get_instance() + + res = await st_instance.middleware( + request=DummyRequest(), + response=DummyResponse(content={}), + user_context={}, + ) + + assert res == "plugin1" + + +@mark.parametrize( + ("handler_response", "expectation"), + [ + param( + PluginRouteHandlerFunctionOkResponse(route_handlers=[plugin_route_handler]), + outputs("plugin1"), + id="OK response with route handler", + ), + param( + PluginRouteHandlerFunctionErrorResponse( + message="error", + ), + raises(Exception, match="error"), + id="Error response", + ), + ], +) +async def test_route_handlers_callable(handler_response: Any, expectation: Any): + plugin = plugin_factory("plugin1", override_functions=False, override_apis=False) + + plugin.route_handlers = lambda *_, **__: handler_response # type: ignore + + with expectation as expected_output: + partial_init( + recipe_list=[ + recipe_factory(override_functions=False, override_apis=False), + ], + experimental=SupertokensExperimentalConfig( + plugins=[plugin], + ), + ) + + st_instance = Supertokens.get_instance() + + res = await st_instance.middleware( + request=DummyRequest(), + response=DummyResponse(content={}), + user_context={}, + ) + + assert res == expected_output From 8c535527966d78e78918fa04164f3004c18f7db8 Mon Sep 17 00:00:00 2001 From: Namit Nathwani Date: Wed, 11 Jun 2025 00:33:08 +0530 Subject: [PATCH 07/37] refactor: move classes to new file --- tests/plugins/misc.py | 86 ++++++++++++++++++++++++++++++++++ tests/plugins/test_plugins.py | 87 +---------------------------------- 2 files changed, 88 insertions(+), 85 deletions(-) create mode 100644 tests/plugins/misc.py diff --git a/tests/plugins/misc.py b/tests/plugins/misc.py new file mode 100644 index 000000000..963327472 --- /dev/null +++ b/tests/plugins/misc.py @@ -0,0 +1,86 @@ +from typing import Any, Dict, Optional + +from supertokens_python.framework.request import BaseRequest +from supertokens_python.framework.response import BaseResponse +from supertokens_python.recipe.session import SessionContainer + + +class DummyRequest(BaseRequest): + def get_path(self) -> str: + return "/auth/plugin1/hello" + + def get_method(self) -> str: + return "get" + + def get_original_url(self) -> Any: + raise NotImplementedError + + def get_query_param(self, key: str, default: Optional[str] = None) -> Any: + raise NotImplementedError + + def get_query_params(self) -> Any: + raise NotImplementedError + + async def json(self) -> Any: + raise NotImplementedError + + async def form_data(self) -> Any: + raise NotImplementedError + + def method(self) -> Any: + return "get" + + def get_cookie(self, key: str) -> Any: + raise NotImplementedError + + def get_header(self, key: str) -> Any: + return None + + def get_session(self) -> Any: + raise NotImplementedError + + def set_session(self, session: SessionContainer) -> Any: + raise NotImplementedError + + def set_session_as_none(self) -> Any: + raise NotImplementedError + + +class DummyResponse(BaseResponse): + def __init__(self, content: Dict[str, Any], status_code: int = 200): + self.content = content + self.status_code = status_code + + def set_cookie( + self, + key: str, + value: str, + expires: int, + path: str = "/", + domain: Optional[str] = None, + secure: bool = False, + httponly: bool = False, + samesite: str = "lax", + ) -> Any: + raise NotImplementedError + + def set_header(self, key: str, value: str) -> None: + raise NotImplementedError + + def get_header(self, key: str) -> Optional[str]: + raise NotImplementedError + + def remove_header(self, key: str) -> None: + raise NotImplementedError + + def set_status_code(self, status_code: int) -> None: + raise NotImplementedError + + def set_json_content(self, content: Dict[str, Any]) -> Any: + raise NotImplementedError + + def set_html_content(self, content: str) -> Any: + raise NotImplementedError + + def redirect(self, url: str) -> Any: + raise NotImplementedError diff --git a/tests/plugins/test_plugins.py b/tests/plugins/test_plugins.py index cd3bd4dd5..8af20c26c 100644 --- a/tests/plugins/test_plugins.py +++ b/tests/plugins/test_plugins.py @@ -1,5 +1,5 @@ from functools import partial -from typing import Any, Dict, List, Optional +from typing import Any, List from pytest import fixture, mark, param, raises from supertokens_python import ( @@ -9,20 +9,18 @@ SupertokensExperimentalConfig, init, ) -from supertokens_python.framework.request import BaseRequest -from supertokens_python.framework.response import BaseResponse from supertokens_python.plugins import ( PluginRouteHandler, PluginRouteHandlerFunctionErrorResponse, PluginRouteHandlerFunctionOkResponse, SuperTokensPlugin, ) -from supertokens_python.recipe.session import SessionContainer from supertokens_python.supertokens import SupertokensPublicConfig from tests.utils import outputs, reset from .config import OverrideConfig, PluginTestConfig +from .misc import DummyRequest, DummyResponse from .plugins import ( Plugin1, Plugin2, @@ -73,87 +71,6 @@ def recipe_factory(override_functions: bool = False, override_apis: bool = False ) -class DummyRequest(BaseRequest): - def get_path(self) -> str: - return "/auth/plugin1/hello" - - def get_method(self) -> str: - return "get" - - def get_original_url(self) -> Any: - raise NotImplementedError - - def get_query_param(self, key: str, default: Optional[str] = None) -> Any: - raise NotImplementedError - - def get_query_params(self) -> Any: - raise NotImplementedError - - async def json(self) -> Any: - raise NotImplementedError - - async def form_data(self) -> Any: - raise NotImplementedError - - def method(self) -> Any: - return "get" - - def get_cookie(self, key: str) -> Any: - raise NotImplementedError - - def get_header(self, key: str) -> Any: - return None - - def get_session(self) -> Any: - raise NotImplementedError - - def set_session(self, session: SessionContainer) -> Any: - raise NotImplementedError - - def set_session_as_none(self) -> Any: - raise NotImplementedError - - -class DummyResponse(BaseResponse): - def __init__(self, content: Dict[str, Any], status_code: int = 200): - self.content = content - self.status_code = status_code - - def set_cookie( - self, - key: str, - value: str, - expires: int, - path: str = "/", - domain: Optional[str] = None, - secure: bool = False, - httponly: bool = False, - samesite: str = "lax", - ) -> Any: - raise NotImplementedError - - def set_header(self, key: str, value: str) -> None: - raise NotImplementedError - - def get_header(self, key: str) -> Optional[str]: - raise NotImplementedError - - def remove_header(self, key: str) -> None: - raise NotImplementedError - - def set_status_code(self, status_code: int) -> None: - raise NotImplementedError - - def set_json_content(self, content: Dict[str, Any]) -> Any: - raise NotImplementedError - - def set_html_content(self, content: str) -> Any: - raise NotImplementedError - - def redirect(self, url: str) -> Any: - raise NotImplementedError - - @mark.parametrize( ( "recipe_fn_override", From b7a47fe7d5ef1439e0ae70694e8a7a8de9e5a5e6 Mon Sep 17 00:00:00 2001 From: Namit Nathwani Date: Wed, 11 Jun 2025 00:36:45 +0530 Subject: [PATCH 08/37] update: import order --- supertokens_python/plugins.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/supertokens_python/plugins.py b/supertokens_python/plugins.py index f4193e66b..d7ddccce3 100644 --- a/supertokens_python/plugins.py +++ b/supertokens_python/plugins.py @@ -39,13 +39,6 @@ # from supertokens_python.recipe.oauth2provider.utils import OAuth2ProviderConfig # from supertokens_python.recipe.openid.utils import OpenIdConfig # from supertokens_python.recipe.passwordless.utils import PasswordlessConfig -if TYPE_CHECKING: - from supertokens_python.recipe.session.interfaces import ( - SessionClaimValidator, - SessionContainer, - ) - from supertokens_python.supertokens import SupertokensPublicConfig - # from supertokens_python.recipe.session.utils import SessionConfig # from supertokens_python.recipe.thirdparty.utils import ThirdPartyConfig # from supertokens_python.recipe.totp.types import TOTPConfig @@ -56,6 +49,13 @@ from supertokens_python.types.base import UserContext from supertokens_python.types.response import CamelCaseBaseModel +if TYPE_CHECKING: + from supertokens_python.recipe.session.interfaces import ( + SessionClaimValidator, + SessionContainer, + ) + from supertokens_python.supertokens import SupertokensPublicConfig + T = TypeVar("T") # T = TypeVar("T", bound=Union[AccountLinkingConfig, DashboardConfig, EmailPasswordConfig, # EmailVerificationConfig, JWTConfig, MultiFactorAuthConfig, MultitenancyConfig, From b37d23ef77d7834367278997fbc2dc4265ece74c Mon Sep 17 00:00:00 2001 From: Namit Nathwani Date: Tue, 17 Jun 2025 02:05:48 +0530 Subject: [PATCH 09/37] update: adds tests --- tests/plugins/test_plugins.py | 33 ++++++++++++++++++++++++++++++++- 1 file changed, 32 insertions(+), 1 deletion(-) diff --git a/tests/plugins/test_plugins.py b/tests/plugins/test_plugins.py index 8af20c26c..81b06496a 100644 --- a/tests/plugins/test_plugins.py +++ b/tests/plugins/test_plugins.py @@ -139,11 +139,36 @@ def recipe_factory(override_functions: bool = False, override_apis: bool = False param( True, False, - [Plugin1, Plugin2], + [ + plugin_factory("plugin1", override_functions=True), + plugin_factory("plugin2", override_functions=True), + ], outputs(["override", "plugin2", "plugin1", "original"]), outputs(["original"]), id="fn_ovr=True, api_ovr=False, plugins=[Plugin1, Plugin2], plugin1=[fn], plugin2=[fn]", ), + param( + False, + True, + [ + plugin_factory("plugin1", override_apis=True), + plugin_factory("plugin2", override_apis=True), + ], + outputs(["original"]), + outputs(["override", "plugin2", "plugin1", "original"]), + id="fn_ovr=True, api_ovr=False, plugins=[Plugin1, Plugin2], plugin1=[api], plugin2=[api]", + ), + param( + True, + True, + [ + plugin_factory("plugin1", override_functions=True, override_apis=True), + plugin_factory("plugin2", override_functions=True, override_apis=True), + ], + outputs(["override", "plugin2", "plugin1", "original"]), + outputs(["override", "plugin2", "plugin1", "original"]), + id="fn_ovr=True, api_ovr=True, plugins=[Plugin1, Plugin2], plugin1=[fn,api], plugin2=[fn,api]", + ), ], ) def test_overrides( @@ -197,6 +222,12 @@ def test_overrides( outputs(["original"]), id="1,1 => 1", ), + param( + [Plugin1, Plugin2], + outputs(["plugin2", "plugin1", "original"]), + outputs(["original"]), + id="1,2 => 2,1", + ), param( [Plugin3Dep1], outputs(["plugin3dep1", "plugin1", "original"]), From 837b0c34b8f93dd4a3ba3f9c6626a340f2d9cc2f Mon Sep 17 00:00:00 2001 From: Namit Nathwani Date: Wed, 18 Jun 2025 12:48:49 +0530 Subject: [PATCH 10/37] update: adds init tests --- tests/plugins/plugins.py | 53 ++++++++++++++++++++++++++++------- tests/plugins/recipe.py | 2 ++ tests/plugins/test_plugins.py | 19 +++++++++++-- 3 files changed, 61 insertions(+), 13 deletions(-) diff --git a/tests/plugins/plugins.py b/tests/plugins/plugins.py index 42e6b552c..0a9fc455e 100644 --- a/tests/plugins/plugins.py +++ b/tests/plugins/plugins.py @@ -49,8 +49,7 @@ def init( all_plugins: List[SuperTokensPublicPlugin], sdk_version: str, ): - # TODO: Test this - print(f"{identifier} init") + PluginTestRecipe.init_calls.append(identifier) return init @@ -78,6 +77,7 @@ def plugin_factory( override_functions: bool = False, override_apis: bool = False, deps: Optional[List[SuperTokensPlugin]] = None, + add_init: bool = False, ): override_map_obj: OverrideMap = {PluginTestRecipe.recipe_id: OverrideConfig()} @@ -90,24 +90,57 @@ def plugin_factory( identifier ) + init_fn = None + if add_init: + init_fn = init_factory(identifier) + class Plugin(SuperTokensPlugin): id: str = identifier compatible_sdk_versions: Union[str, List[str]] = ["0.30.0"] override_map: Optional[OverrideMap] = override_map_obj - init: Any = init_factory(identifier) + init: Any = init_fn dependencies: Optional[SuperTokensPluginDependencies] = dependency_factory(deps) return Plugin() -Plugin1 = plugin_factory("plugin1", override_functions=True) -Plugin2 = plugin_factory("plugin2", override_functions=True) -Plugin3Dep1 = plugin_factory("plugin3dep1", override_functions=True, deps=[Plugin1]) +Plugin1 = plugin_factory( + "plugin1", + override_functions=True, + add_init=True, +) +Plugin2 = plugin_factory( + "plugin2", + override_functions=True, + add_init=True, +) +Plugin3Dep1 = plugin_factory( + "plugin3dep1", + override_functions=True, + deps=[Plugin1], + add_init=True, +) Plugin3Dep2_1 = plugin_factory( - "plugin3dep2_1", override_functions=True, deps=[Plugin2, Plugin1] + "plugin3dep2_1", + override_functions=True, + deps=[Plugin2, Plugin1], + add_init=True, +) +Plugin4Dep1 = plugin_factory( + "plugin4dep1", + override_functions=True, + deps=[Plugin1], + add_init=True, +) +Plugin4Dep2 = plugin_factory( + "plugin4dep2", + override_functions=True, + deps=[Plugin2], + add_init=True, ) -Plugin4Dep1 = plugin_factory("plugin4dep1", override_functions=True, deps=[Plugin1]) -Plugin4Dep2 = plugin_factory("plugin4dep2", override_functions=True, deps=[Plugin2]) Plugin4Dep3__2_1 = plugin_factory( - "plugin4dep3__2_1", override_functions=True, deps=[Plugin3Dep2_1] + "plugin4dep3__2_1", + override_functions=True, + deps=[Plugin3Dep2_1], + add_init=True, ) diff --git a/tests/plugins/recipe.py b/tests/plugins/recipe.py index c7977494a..2a07dcfcb 100644 --- a/tests/plugins/recipe.py +++ b/tests/plugins/recipe.py @@ -26,6 +26,7 @@ class PluginTestRecipe(RecipeModule): __instance: Optional["PluginTestRecipe"] = None + init_calls: List[str] = [] recipe_id = "plugin_test" config: NormalizedPluginTestConfig @@ -94,6 +95,7 @@ def func(app_info: AppInfo, plugins: List[OverrideMap]): @staticmethod def reset(): PluginTestRecipe.__instance = None + PluginTestRecipe.init_calls = [] def get_all_cors_headers(self) -> List[str]: return [] diff --git a/tests/plugins/test_plugins.py b/tests/plugins/test_plugins.py index 81b06496a..d523da1ef 100644 --- a/tests/plugins/test_plugins.py +++ b/tests/plugins/test_plugins.py @@ -15,6 +15,7 @@ PluginRouteHandlerFunctionOkResponse, SuperTokensPlugin, ) +from supertokens_python.post_init_callbacks import PostSTInitCallbacks from supertokens_python.supertokens import SupertokensPublicConfig from tests.utils import outputs, reset @@ -41,9 +42,11 @@ def setup_and_teardown(): reset() PluginTestRecipe.reset() + PostSTInitCallbacks.reset() yield reset() PluginTestRecipe.reset() + PostSTInitCallbacks.reset() def recipe_factory(override_functions: bool = False, override_apis: bool = False): @@ -214,36 +217,41 @@ def test_overrides( # TODO: Figure out a way to add circular dependencies and test them @mark.parametrize( - ("plugins", "recipe_expectation", "api_expectation"), + ("plugins", "recipe_expectation", "api_expectation", "init_expectation"), [ param( [Plugin1, Plugin1], outputs(["plugin1", "original"]), outputs(["original"]), + outputs(["plugin1"]), id="1,1 => 1", ), param( [Plugin1, Plugin2], outputs(["plugin2", "plugin1", "original"]), outputs(["original"]), + outputs(["plugin1", "plugin2"]), id="1,2 => 2,1", ), param( [Plugin3Dep1], outputs(["plugin3dep1", "plugin1", "original"]), outputs(["original"]), + outputs(["plugin1", "plugin3dep1"]), id="3->1 => 3,1", ), param( [Plugin3Dep2_1], outputs(["plugin3dep2_1", "plugin1", "plugin2", "original"]), outputs(["original"]), + outputs(["plugin2", "plugin1", "plugin3dep2_1"]), id="3->(2,1) => 3,2,1", ), param( [Plugin3Dep1, Plugin4Dep2], outputs(["plugin4dep2", "plugin2", "plugin3dep1", "plugin1", "original"]), outputs(["original"]), + outputs(["plugin1", "plugin3dep1", "plugin2", "plugin4dep2"]), id="3->1,4->2 => 4,2,3,1", ), param( @@ -252,20 +260,23 @@ def test_overrides( ["plugin4dep3__2_1", "plugin3dep2_1", "plugin1", "plugin2", "original"] ), outputs(["original"]), + outputs(["plugin2", "plugin1", "plugin3dep2_1", "plugin4dep3__2_1"]), id="4->3->(2,1) => 4,3,1,2", ), param( [Plugin3Dep1, Plugin4Dep1], outputs(["plugin4dep1", "plugin3dep1", "plugin1", "original"]), outputs(["original"]), + outputs(["plugin1", "plugin3dep1", "plugin4dep1"]), id="3->1,4->1 => 4,3,1", ), ], ) -def test_depdendencies( +def test_depdendencies_and_init( plugins: List[SuperTokensPlugin], recipe_expectation: Any, api_expectation: Any, + init_expectation: Any, ): partial_init( recipe_list=[ @@ -298,8 +309,10 @@ def test_depdendencies( message="msg", ) + with init_expectation as expected_stack: + assert PluginTestRecipe.init_calls == expected_stack + -# TODO: Add tests for init, recipe config override def test_st_config_override(): plugin = plugin_factory("plugin1", override_functions=False, override_apis=False) From fefab28fd0c3735f292a826f96101315f341d441 Mon Sep 17 00:00:00 2001 From: Namit Nathwani Date: Mon, 30 Jun 2025 16:30:53 +0530 Subject: [PATCH 11/37] refactor: use base classes for configs/overrides/interfaces --- supertokens_python/plugins.py | 7 +- .../recipe/accountlinking/interfaces.py | 5 +- .../recipe/accountlinking/recipe.py | 45 ++-- .../recipe/accountlinking/types.py | 55 ++-- .../recipe/accountlinking/utils.py | 56 ++-- .../recipe/dashboard/__init__.py | 6 +- .../recipe/dashboard/interfaces.py | 7 +- supertokens_python/recipe/dashboard/recipe.py | 40 +-- supertokens_python/recipe/dashboard/utils.py | 86 +++--- .../recipe/emailpassword/__init__.py | 8 +- .../recipe/emailpassword/interfaces.py | 7 +- .../recipe/emailpassword/recipe.py | 42 +-- .../recipe/emailpassword/utils.py | 90 +++---- .../recipe/emailverification/__init__.py | 8 +- .../recipe/emailverification/interfaces.py | 7 +- .../recipe/emailverification/recipe.py | 50 ++-- .../recipe/emailverification/utils.py | 89 ++++--- supertokens_python/recipe/jwt/__init__.py | 8 +- supertokens_python/recipe/jwt/interfaces.py | 14 +- supertokens_python/recipe/jwt/recipe.py | 36 +-- supertokens_python/recipe/jwt/utils.py | 62 +++-- .../recipe/multifactorauth/__init__.py | 8 +- .../recipe/multifactorauth/interfaces.py | 8 +- .../recipe/multifactorauth/recipe.py | 48 ++-- .../recipe/multifactorauth/types.py | 38 ++- .../recipe/multifactorauth/utils.py | 30 ++- .../recipe/multitenancy/__init__.py | 7 +- .../recipe/multitenancy/interfaces.py | 7 +- .../recipe/multitenancy/recipe.py | 43 ++- .../recipe/multitenancy/utils.py | 75 +++--- .../recipe/oauth2provider/__init__.py | 8 +- .../recipe/oauth2provider/interfaces.py | 7 +- .../recipe/oauth2provider/recipe.py | 34 +-- .../recipe/oauth2provider/utils.py | 51 ++-- supertokens_python/recipe/openid/__init__.py | 8 +- .../recipe/openid/interfaces.py | 7 +- supertokens_python/recipe/openid/recipe.py | 39 +-- supertokens_python/recipe/openid/utils.py | 85 +++--- .../recipe/passwordless/__init__.py | 8 +- .../recipe/passwordless/interfaces.py | 7 +- .../recipe/passwordless/recipe.py | 68 ++--- .../recipe/passwordless/utils.py | 166 ++++++------ supertokens_python/recipe/session/__init__.py | 5 +- .../recipe/session/interfaces.py | 5 +- supertokens_python/recipe/session/recipe.py | 94 +++---- supertokens_python/recipe/session/utils.py | 246 +++++++++--------- .../recipe/thirdparty/__init__.py | 8 +- .../recipe/thirdparty/interfaces.py | 8 +- .../recipe/thirdparty/recipe.py | 39 +-- supertokens_python/recipe/thirdparty/utils.py | 59 ++--- supertokens_python/recipe/totp/__init__.py | 8 +- supertokens_python/recipe/totp/interfaces.py | 8 +- supertokens_python/recipe/totp/recipe.py | 18 +- supertokens_python/recipe/totp/types.py | 52 ++-- supertokens_python/recipe/totp/utils.py | 22 +- .../recipe/usermetadata/__init__.py | 8 +- .../recipe/usermetadata/interfaces.py | 6 +- .../recipe/usermetadata/recipe.py | 25 +- .../recipe/usermetadata/utils.py | 41 +-- .../recipe/userroles/__init__.py | 8 +- .../recipe/userroles/interfaces.py | 8 +- supertokens_python/recipe/userroles/recipe.py | 39 +-- supertokens_python/recipe/userroles/utils.py | 64 ++--- .../recipe/webauthn/interfaces/api.py | 3 +- .../recipe/webauthn/interfaces/recipe.py | 3 +- supertokens_python/recipe/webauthn/recipe.py | 19 +- .../recipe/webauthn/types/config.py | 30 +-- supertokens_python/recipe/webauthn/utils.py | 18 +- supertokens_python/supertokens.py | 18 +- supertokens_python/test.py | 14 + supertokens_python/types/config.py | 93 +++++++ supertokens_python/types/recipe.py | 7 + tests/auth-react/flask-server/app.py | 14 +- .../django2x/polls/views.py | 14 +- .../drf_async/mysite/settings.py | 14 +- .../drf_async/polls/views.py | 14 +- .../frontendIntegration/fastapi-server/app.py | 14 +- tests/frontendIntegration/flask-server/app.py | 14 +- tests/sessions/claims/utils.py | 10 +- tests/test-server/app.py | 7 +- tests/test_session.py | 8 +- 81 files changed, 1317 insertions(+), 1248 deletions(-) create mode 100644 supertokens_python/test.py create mode 100644 supertokens_python/types/config.py create mode 100644 supertokens_python/types/recipe.py diff --git a/supertokens_python/plugins.py b/supertokens_python/plugins.py index d7ddccce3..77a5706a8 100644 --- a/supertokens_python/plugins.py +++ b/supertokens_python/plugins.py @@ -335,7 +335,8 @@ def default_api_override(original_implementation: T) -> T: api_layers: deque[Any] = deque() # If we have plugins like 4->3->(2, 1) along with a recipe override, - # we want to apply them as: override, 4, 3, 1, 2, original + # we want to load/init them as: override, 2, 1, 3, 4 + # and call them as: override, 4, 3, 2, 1, original # Order of 1/2 does not matter since they are independent from each other. for plugin in plugins: @@ -354,8 +355,8 @@ def default_api_override(original_implementation: T) -> T: if api_overrides is not None: api_layers.append(api_overrides) - # Apply overrides in order of definition - # Plugins: [plugin1, plugin2] would be applied as [override, plugin1, plugin2, original] + # Apply overrides in reverse order of definition + # Plugins: [plugin1, plugin2] would be applied as [override, plugin2, plugin1, original] if len(function_layers) > 0: # TODO: Change to recipe_interface type def fn_override(original_implementation: T) -> T: diff --git a/supertokens_python/recipe/accountlinking/interfaces.py b/supertokens_python/recipe/accountlinking/interfaces.py index 105b5bda5..04274b6f0 100644 --- a/supertokens_python/recipe/accountlinking/interfaces.py +++ b/supertokens_python/recipe/accountlinking/interfaces.py @@ -13,12 +13,13 @@ # under the License. from __future__ import annotations -from abc import ABC, abstractmethod +from abc import abstractmethod from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from typing_extensions import Literal from supertokens_python.types.base import AccountInfoInput +from supertokens_python.types.recipe import BaseRecipeInterface if TYPE_CHECKING: from supertokens_python.types import ( @@ -27,7 +28,7 @@ ) -class RecipeInterface(ABC): +class RecipeInterface(BaseRecipeInterface): @abstractmethod async def get_users( self, diff --git a/supertokens_python/recipe/accountlinking/recipe.py b/supertokens_python/recipe/accountlinking/recipe.py index 3b34801ca..2b8bd057a 100644 --- a/supertokens_python/recipe/accountlinking/recipe.py +++ b/supertokens_python/recipe/accountlinking/recipe.py @@ -23,6 +23,7 @@ log_debug_message, ) from supertokens_python.normalised_url_path import NormalisedURLPath +from supertokens_python.plugins import OverrideMap, apply_plugins from supertokens_python.process_state import PROCESS_STATE, ProcessState from supertokens_python.querier import Querier from supertokens_python.recipe_module import APIHandled, RecipeModule @@ -34,6 +35,7 @@ from .types import ( AccountInfoWithRecipeId, AccountInfoWithRecipeIdAndUserId, + AccountLinkingInputConfig, InputOverrideConfig, RecipeLevelUser, ShouldAutomaticallyLink, @@ -77,35 +79,18 @@ def __init__( self, recipe_id: str, app_info: AppInfo, - on_account_linked: Optional[ - Callable[[User, RecipeLevelUser, Dict[str, Any]], Awaitable[None]] - ] = None, - should_do_automatic_account_linking: Optional[ - Callable[ - [ - AccountInfoWithRecipeIdAndUserId, - Optional[User], - Optional[SessionContainer], - str, - Dict[str, Any], - ], - Awaitable[Union[ShouldNotAutomaticallyLink, ShouldAutomaticallyLink]], - ] - ] = None, - override: Optional[InputOverrideConfig] = None, + input_config: AccountLinkingInputConfig, ): super().__init__(recipe_id, app_info) self.config = validate_and_normalise_user_input( - app_info, on_account_linked, should_do_automatic_account_linking, override + app_info, input_config=input_config ) recipe_implementation: RecipeInterface = RecipeImplementation( Querier.get_instance(recipe_id), self, self.config ) - self.recipe_implementation: RecipeInterface = ( + self.recipe_implementation: RecipeInterface = self.config.override.functions( recipe_implementation - if self.config.override.functions is None - else self.config.override.functions(recipe_implementation) ) self.email_verification_recipe: EmailVerificationRecipe | None = None @@ -164,14 +149,22 @@ def init( ] = None, override: Optional[InputOverrideConfig] = None, ): - def func(app_info: AppInfo): + input_config = AccountLinkingInputConfig( + on_account_linked=on_account_linked, + should_do_automatic_account_linking=should_do_automatic_account_linking, + override=override, + ) + + def func(app_info: AppInfo, plugins: List[OverrideMap]): if AccountLinkingRecipe.__instance is None: AccountLinkingRecipe.__instance = AccountLinkingRecipe( - AccountLinkingRecipe.recipe_id, - app_info, - on_account_linked, - should_do_automatic_account_linking, - override, + recipe_id=AccountLinkingRecipe.recipe_id, + app_info=app_info, + input_config=apply_plugins( + recipe_id=AccountLinkingRecipe.recipe_id, + config=input_config, + plugins=plugins, + ), ) return AccountLinkingRecipe.__instance raise Exception( diff --git a/supertokens_python/recipe/accountlinking/types.py b/supertokens_python/recipe/accountlinking/types.py index 22977eabe..ae8908edc 100644 --- a/supertokens_python/recipe/accountlinking/types.py +++ b/supertokens_python/recipe/accountlinking/types.py @@ -22,6 +22,12 @@ RecipeInterface, ) from supertokens_python.types import AccountInfo +from supertokens_python.types.config import ( + BaseConfigWithoutAPIOverride, + BaseInputConfigWithoutAPIOverride, + BaseInputOverrideConfigWithoutAPI, + BaseOverrideConfigWithoutAPI, +) if TYPE_CHECKING: from supertokens_python.recipe.session import SessionContainer @@ -131,29 +137,18 @@ def __init__(self, should_require_verification: bool): self.should_require_verification = should_require_verification -class OverrideConfig: - def __init__( - self, - functions: Union[Callable[[RecipeInterface], RecipeInterface], None] = None, - ): - self.functions = functions +class InputOverrideConfig(BaseInputOverrideConfigWithoutAPI[RecipeInterface]): ... -class InputOverrideConfig: - def __init__( - self, - functions: Union[Callable[[RecipeInterface], RecipeInterface], None] = None, - ): - self.functions = functions +class OverrideConfig(BaseOverrideConfigWithoutAPI[RecipeInterface]): ... -class AccountLinkingConfig: - def __init__( - self, - on_account_linked: Callable[ - [User, RecipeLevelUser, Dict[str, Any]], Awaitable[None] - ], - should_do_automatic_account_linking: Callable[ +class AccountLinkingInputConfig(BaseInputConfigWithoutAPIOverride[RecipeInterface]): + on_account_linked: Optional[ + Callable[[User, RecipeLevelUser, Dict[str, Any]], Awaitable[None]] + ] = None + should_do_automatic_account_linking: Optional[ + Callable[ [ AccountInfoWithRecipeIdAndUserId, Optional[User], @@ -162,9 +157,21 @@ def __init__( Dict[str, Any], ], Awaitable[Union[ShouldNotAutomaticallyLink, ShouldAutomaticallyLink]], + ] + ] = None + + +class AccountLinkingConfig(BaseConfigWithoutAPIOverride[RecipeInterface]): + on_account_linked: Callable[ + [User, RecipeLevelUser, Dict[str, Any]], Awaitable[None] + ] + should_do_automatic_account_linking: Callable[ + [ + AccountInfoWithRecipeIdAndUserId, + Optional[User], + Optional[SessionContainer], + str, + Dict[str, Any], ], - override: OverrideConfig, - ): - self.on_account_linked = on_account_linked - self.should_do_automatic_account_linking = should_do_automatic_account_linking - self.override = override + Awaitable[Union[ShouldNotAutomaticallyLink, ShouldAutomaticallyLink]], + ] diff --git a/supertokens_python/recipe/accountlinking/utils.py b/supertokens_python/recipe/accountlinking/utils.py index 763bfd48e..d008ad725 100644 --- a/supertokens_python/recipe/accountlinking/utils.py +++ b/supertokens_python/recipe/accountlinking/utils.py @@ -13,13 +13,14 @@ # under the License. from __future__ import annotations -from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, Optional, Union + +from supertokens_python.recipe.accountlinking.types import AccountLinkingInputConfig if TYPE_CHECKING: from .types import ( AccountInfoWithRecipeIdAndUserId, AccountLinkingConfig, - InputOverrideConfig, RecipeLevelUser, SessionContainer, ShouldAutomaticallyLink, @@ -58,51 +59,34 @@ def recipe_init_defined_should_do_automatic_account_linking() -> bool: def validate_and_normalise_user_input( _: AppInfo, - on_account_linked: Optional[ - Callable[[User, RecipeLevelUser, Dict[str, Any]], Awaitable[None]] - ] = None, - should_do_automatic_account_linking: Optional[ - Callable[ - [ - AccountInfoWithRecipeIdAndUserId, - Optional[User], - Optional[SessionContainer], - str, - Dict[str, Any], - ], - Awaitable[Union[ShouldNotAutomaticallyLink, ShouldAutomaticallyLink]], - ] - ] = None, - override: Union[InputOverrideConfig, None] = None, + input_config: AccountLinkingInputConfig, ) -> AccountLinkingConfig: - from .types import ( - AccountLinkingConfig as ALC, - ) - from .types import ( - InputOverrideConfig as IOC, - ) - from .types import ( - OverrideConfig, - ) + from .types import AccountLinkingConfig, OverrideConfig global _did_use_default_should_do_automatic_account_linking - if override is None: - override = IOC() + + override_config: OverrideConfig = OverrideConfig() + + if ( + input_config.override is not None + and input_config.override.functions is not None + ): + override_config.functions = input_config.override.functions _did_use_default_should_do_automatic_account_linking = ( - should_do_automatic_account_linking is None + input_config.should_do_automatic_account_linking is None ) - return ALC( - override=OverrideConfig(functions=override.functions), + return AccountLinkingConfig( + override=override_config, on_account_linked=( default_on_account_linked - if on_account_linked is None - else on_account_linked + if input_config.on_account_linked is None + else input_config.on_account_linked ), should_do_automatic_account_linking=( default_should_do_automatic_account_linking - if should_do_automatic_account_linking is None - else should_do_automatic_account_linking + if input_config.should_do_automatic_account_linking is None + else input_config.should_do_automatic_account_linking ), ) diff --git a/supertokens_python/recipe/dashboard/__init__.py b/supertokens_python/recipe/dashboard/__init__.py index 46f46f9a2..0d2413522 100644 --- a/supertokens_python/recipe/dashboard/__init__.py +++ b/supertokens_python/recipe/dashboard/__init__.py @@ -14,10 +14,10 @@ from __future__ import annotations -from typing import Callable, List, Optional +from typing import List, Optional -from supertokens_python import AppInfo, RecipeModule from supertokens_python.recipe.dashboard import utils +from supertokens_python.supertokens import RecipeInit from .recipe import DashboardRecipe @@ -28,7 +28,7 @@ def init( api_key: Optional[str] = None, admins: Optional[List[str]] = None, override: Optional[InputOverrideConfig] = None, -) -> Callable[[AppInfo], RecipeModule]: +) -> RecipeInit: return DashboardRecipe.init( api_key, admins, diff --git a/supertokens_python/recipe/dashboard/interfaces.py b/supertokens_python/recipe/dashboard/interfaces.py index a0e8564a4..6f5e15404 100644 --- a/supertokens_python/recipe/dashboard/interfaces.py +++ b/supertokens_python/recipe/dashboard/interfaces.py @@ -13,12 +13,13 @@ # under the License. from __future__ import annotations -from abc import ABC, abstractmethod +from abc import abstractmethod from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional, Union from typing_extensions import Literal from supertokens_python.recipe.multitenancy.interfaces import TenantConfig +from supertokens_python.types.recipe import BaseAPIInterface, BaseRecipeInterface from ...types.response import APIResponse @@ -41,7 +42,7 @@ def __init__(self, info: SessionInformationResult) -> None: self.tenant_id = info.tenant_id -class RecipeInterface(ABC): +class RecipeInterface(BaseRecipeInterface): def __init__(self): pass @@ -77,7 +78,7 @@ def __init__( self.app_info = app_info -class APIInterface: +class APIInterface(BaseAPIInterface): def __init__(self): # undefined should be allowed self.dashboard_get: Optional[ diff --git a/supertokens_python/recipe/dashboard/recipe.py b/supertokens_python/recipe/dashboard/recipe.py index f789aea66..0bb134b80 100644 --- a/supertokens_python/recipe/dashboard/recipe.py +++ b/supertokens_python/recipe/dashboard/recipe.py @@ -17,6 +17,7 @@ from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional from supertokens_python.normalised_url_path import NormalisedURLPath +from supertokens_python.plugins import OverrideMap, apply_plugins from supertokens_python.recipe.dashboard.api.multitenancy.create_or_update_third_party_config import ( handle_create_or_update_third_party_config, ) @@ -149,6 +150,7 @@ VALIDATE_KEY_API, ) from .utils import ( + DashboardInputConfig, InputOverrideConfig, validate_and_normalise_user_input, ) @@ -162,29 +164,19 @@ def __init__( self, recipe_id: str, app_info: AppInfo, - api_key: Optional[str], - admins: Optional[List[str]], - override: Optional[InputOverrideConfig] = None, + input_config: DashboardInputConfig, ): super().__init__(recipe_id, app_info) self.config = validate_and_normalise_user_input( - api_key, - admins, - override, + input_config=input_config, ) recipe_implementation = RecipeImplementation() - self.recipe_implementation = ( + self.recipe_implementation = self.config.override.functions( recipe_implementation - if self.config.override.functions is None - else self.config.override.functions(recipe_implementation) ) api_implementation = APIImplementation() - self.api_implementation = ( - api_implementation - if self.config.override.apis is None - else self.config.override.apis(api_implementation) - ) + self.api_implementation = self.config.override.apis(api_implementation) def is_error_from_this_recipe_based_on_instance(self, err: Exception) -> bool: return isinstance(err, SuperTokensError) and ( @@ -652,14 +644,22 @@ def init( admins: Optional[List[str]] = None, override: Optional[InputOverrideConfig] = None, ): - def func(app_info: AppInfo): + input_config = DashboardInputConfig( + api_key=api_key, + admins=admins, + override=override, + ) + + def func(app_info: AppInfo, plugins: List[OverrideMap]): if DashboardRecipe.__instance is None: DashboardRecipe.__instance = DashboardRecipe( - DashboardRecipe.recipe_id, - app_info, - api_key, - admins, - override, + recipe_id=DashboardRecipe.recipe_id, + app_info=app_info, + input_config=apply_plugins( + recipe_id=DashboardRecipe.recipe_id, + config=input_config, + plugins=plugins, + ), ) return DashboardRecipe.__instance raise Exception( diff --git a/supertokens_python/recipe/dashboard/utils.py b/supertokens_python/recipe/dashboard/utils.py index 50324c0fd..fdedddeab 100644 --- a/supertokens_python/recipe/dashboard/utils.py +++ b/supertokens_python/recipe/dashboard/utils.py @@ -13,11 +13,17 @@ # under the License. from __future__ import annotations -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional from typing_extensions import Literal from supertokens_python.recipe.accountlinking.recipe import AccountLinkingRecipe +from supertokens_python.types.config import ( + BaseConfig, + BaseInputConfig, + BaseInputOverrideConfig, + BaseOverrideConfig, +) if TYPE_CHECKING: from supertokens_python.framework.request import BaseRequest @@ -45,9 +51,7 @@ USERS_LIST_GET_API, VALIDATE_KEY_API, ) - -if TYPE_CHECKING: - from .interfaces import APIInterface, RecipeInterface +from .interfaces import APIInterface, RecipeInterface class UserWithMetadata: @@ -73,64 +77,52 @@ def to_json(self) -> Dict[str, Any]: return user_json -class InputOverrideConfig: - def __init__( - self, - functions: Union[Callable[[RecipeInterface], RecipeInterface], None] = None, - apis: Union[Callable[[APIInterface], APIInterface], None] = None, - ): - self.functions = functions - self.apis = apis +class InputOverrideConfig(BaseInputOverrideConfig[RecipeInterface, APIInterface]): ... -class OverrideConfig: - def __init__( - self, - functions: Union[Callable[[RecipeInterface], RecipeInterface], None] = None, - apis: Union[Callable[[APIInterface], APIInterface], None] = None, - ): - self.functions = functions - self.apis = apis +class OverrideConfig(BaseOverrideConfig[RecipeInterface, APIInterface]): ... -class DashboardConfig: - def __init__( - self, - api_key: Optional[str], - admins: Optional[List[str]], - override: OverrideConfig, - auth_mode: str, - ): - self.api_key = api_key - self.admins = admins - self.override = override - self.auth_mode = auth_mode +class DashboardInputConfig(BaseInputConfig[RecipeInterface, APIInterface]): + api_key: Optional[str] = None + admins: Optional[List[str]] = None + + +class DashboardConfig(BaseConfig[RecipeInterface, APIInterface]): + api_key: Optional[str] + admins: Optional[List[str]] + auth_mode: str def validate_and_normalise_user_input( - # app_info: AppInfo, - api_key: Union[str, None], - admins: Optional[List[str]], - override: Optional[InputOverrideConfig] = None, + input_config: DashboardInputConfig, ) -> DashboardConfig: - if override is None: - override = InputOverrideConfig() + override_config: OverrideConfig = OverrideConfig() - if api_key is not None and admins is not None: + if input_config.override is not None: + if input_config.override.functions is not None: + override_config.functions = input_config.override.functions + + if input_config.override.apis is not None: + override_config.apis = input_config.override.apis + + if input_config.api_key is not None and input_config.admins is not None: log_debug_message( "User Dashboard: Providing 'admins' has no effect when using an api key." ) - admins = [normalise_email(a) for a in admins] if admins is not None else None + admins = ( + [normalise_email(a) for a in input_config.admins] + if input_config.admins is not None + else None + ) + auth_mode = "api-key" if input_config.api_key else "email-password" return DashboardConfig( - api_key, - admins, - OverrideConfig( - functions=override.functions, - apis=override.apis, - ), - "api-key" if api_key else "email-password", + api_key=input_config.api_key, + admins=admins, + auth_mode=auth_mode, + override=override_config, ) diff --git a/supertokens_python/recipe/emailpassword/__init__.py b/supertokens_python/recipe/emailpassword/__init__.py index 731009e43..ce5e2b555 100644 --- a/supertokens_python/recipe/emailpassword/__init__.py +++ b/supertokens_python/recipe/emailpassword/__init__.py @@ -13,7 +13,7 @@ # under the License. from __future__ import annotations -from typing import TYPE_CHECKING, Callable, Union +from typing import TYPE_CHECKING, Union from supertokens_python.ingredients.emaildelivery import types as emaildelivery_types from supertokens_python.ingredients.emaildelivery.types import EmailDeliveryConfig @@ -32,16 +32,14 @@ EmailDeliveryInterface = emaildelivery_types.EmailDeliveryInterface if TYPE_CHECKING: - from supertokens_python.supertokens import AppInfo - - from ...recipe_module import RecipeModule + from supertokens_python.supertokens import RecipeInit def init( sign_up_feature: Union[utils.InputSignUpFeature, None] = None, override: Union[utils.InputOverrideConfig, None] = None, email_delivery: Union[EmailDeliveryConfig[EmailTemplateVars], None] = None, -) -> Callable[[AppInfo], RecipeModule]: +) -> RecipeInit: return EmailPasswordRecipe.init( sign_up_feature, override, diff --git a/supertokens_python/recipe/emailpassword/interfaces.py b/supertokens_python/recipe/emailpassword/interfaces.py index 96df9426e..fc4927826 100644 --- a/supertokens_python/recipe/emailpassword/interfaces.py +++ b/supertokens_python/recipe/emailpassword/interfaces.py @@ -13,12 +13,13 @@ # under the License. from __future__ import annotations -from abc import ABC, abstractmethod +from abc import abstractmethod from typing import TYPE_CHECKING, Any, Dict, List, Union from supertokens_python.ingredients.emaildelivery import EmailDeliveryIngredient from supertokens_python.recipe.emailpassword.types import EmailTemplateVars from supertokens_python.types.auth_utils import LinkingToSessionUserFailedError +from supertokens_python.types.recipe import BaseAPIInterface, BaseRecipeInterface from ...supertokens import AppInfo from ...types import ( @@ -118,7 +119,7 @@ def to_json(self) -> Dict[str, Any]: } -class RecipeInterface(ABC): +class RecipeInterface(BaseRecipeInterface): def __init__(self): pass @@ -315,7 +316,7 @@ def to_json(self) -> Dict[str, Any]: return {"status": self.status, "reason": self.reason} -class APIInterface: +class APIInterface(BaseAPIInterface): def __init__(self): self.disable_email_exists_get = False self.disable_generate_password_reset_token_post = False diff --git a/supertokens_python/recipe/emailpassword/recipe.py b/supertokens_python/recipe/emailpassword/recipe.py index 9a01031e3..f31c69501 100644 --- a/supertokens_python/recipe/emailpassword/recipe.py +++ b/supertokens_python/recipe/emailpassword/recipe.py @@ -20,6 +20,7 @@ from supertokens_python.ingredients.emaildelivery import EmailDeliveryIngredient from supertokens_python.ingredients.emaildelivery.types import EmailDeliveryConfig from supertokens_python.normalised_url_path import NormalisedURLPath +from supertokens_python.plugins import OverrideMap, apply_plugins from supertokens_python.recipe.emailpassword.types import ( EmailPasswordIngredients, EmailTemplateVars, @@ -68,6 +69,7 @@ USER_PASSWORD_RESET_TOKEN, ) from .utils import ( + EmailPasswordInputConfig, InputOverrideConfig, InputSignUpFeature, validate_and_normalise_user_input, @@ -84,25 +86,19 @@ def __init__( recipe_id: str, app_info: AppInfo, ingredients: EmailPasswordIngredients, - sign_up_feature: Union[InputSignUpFeature, None] = None, - override: Union[InputOverrideConfig, None] = None, - email_delivery: Union[EmailDeliveryConfig[EmailTemplateVars], None] = None, + input_config: EmailPasswordInputConfig, ): super().__init__(recipe_id, app_info) self.config = validate_and_normalise_user_input( app_info, - sign_up_feature, - override, - email_delivery, + input_config=input_config, ) recipe_implementation = RecipeImplementation( Querier.get_instance(recipe_id), self.config ) - self.recipe_implementation = ( + self.recipe_implementation = self.config.override.functions( recipe_implementation - if self.config.override.functions is None - else self.config.override.functions(recipe_implementation) ) email_delivery_ingredient = ingredients.email_delivery @@ -114,11 +110,7 @@ def __init__( self.email_delivery = email_delivery_ingredient api_implementation = APIImplementation() - self.api_implementation = ( - api_implementation - if self.config.override.apis is None - else self.config.override.apis(api_implementation) - ) + self.api_implementation = self.config.override.apis(api_implementation) def callback(): mfa_instance = MultiFactorAuthRecipe.get_instance() @@ -374,16 +366,24 @@ def init( override: Union[InputOverrideConfig, None] = None, email_delivery: Union[EmailDeliveryConfig[EmailTemplateVars], None] = None, ): - def func(app_info: AppInfo): + input_config = EmailPasswordInputConfig( + sign_up_feature=sign_up_feature, + email_delivery=email_delivery, + override=override, + ) + + def func(app_info: AppInfo, plugins: List[OverrideMap]): if EmailPasswordRecipe.__instance is None: ingredients = EmailPasswordIngredients(None) EmailPasswordRecipe.__instance = EmailPasswordRecipe( - EmailPasswordRecipe.recipe_id, - app_info, - ingredients, - sign_up_feature, - override, - email_delivery=email_delivery, + recipe_id=EmailPasswordRecipe.recipe_id, + app_info=app_info, + ingredients=ingredients, + input_config=apply_plugins( + recipe_id=EmailPasswordRecipe.recipe_id, + config=input_config, + plugins=plugins, + ), ) return EmailPasswordRecipe.__instance raise Exception( diff --git a/supertokens_python/recipe/emailpassword/utils.py b/supertokens_python/recipe/emailpassword/utils.py index fa267eae7..3dc72b230 100644 --- a/supertokens_python/recipe/emailpassword/utils.py +++ b/supertokens_python/recipe/emailpassword/utils.py @@ -24,6 +24,12 @@ from supertokens_python.recipe.emailpassword.emaildelivery.services.backward_compatibility import ( BackwardCompatibilityService, ) +from supertokens_python.types.config import ( + BaseConfig, + BaseInputConfig, + BaseInputOverrideConfig, + BaseOverrideConfig, +) from .interfaces import APIInterface, RecipeInterface from .types import EmailTemplateVars, InputFormField, NormalisedFormField @@ -213,82 +219,76 @@ def validate_and_normalise_reset_password_using_token_config( ) -class InputOverrideConfig: - def __init__( - self, - functions: Union[Callable[[RecipeInterface], RecipeInterface], None] = None, - apis: Union[Callable[[APIInterface], APIInterface], None] = None, - ): - self.functions = functions - self.apis = apis +class InputOverrideConfig(BaseInputOverrideConfig[RecipeInterface, APIInterface]): ... -class OverrideConfig: - def __init__( - self, - functions: Union[Callable[[RecipeInterface], RecipeInterface], None] = None, - apis: Union[Callable[[APIInterface], APIInterface], None] = None, - ): - self.functions = functions - self.apis = apis +class OverrideConfig(BaseOverrideConfig[RecipeInterface, APIInterface]): ... -class EmailPasswordConfig: - def __init__( - self, - sign_up_feature: SignUpFeature, - sign_in_feature: SignInFeature, - reset_password_using_token_feature: ResetPasswordUsingTokenFeature, - override: OverrideConfig, - get_email_delivery_config: Callable[ - [RecipeInterface], EmailDeliveryConfigWithService[EmailTemplateVars] - ], - ): - self.sign_up_feature = sign_up_feature - self.sign_in_feature = sign_in_feature - self.reset_password_using_token_feature = reset_password_using_token_feature - self.override = override - self.get_email_delivery_config = get_email_delivery_config +class EmailPasswordInputConfig(BaseInputConfig[RecipeInterface, APIInterface]): + sign_up_feature: Union[InputSignUpFeature, None] = None + email_delivery: Union[EmailDeliveryConfig[EmailTemplateVars], None] = None + + +class EmailPasswordConfig(BaseConfig[RecipeInterface, APIInterface]): + sign_up_feature: SignUpFeature + sign_in_feature: SignInFeature + reset_password_using_token_feature: ResetPasswordUsingTokenFeature + get_email_delivery_config: Callable[ + [RecipeInterface], EmailDeliveryConfigWithService[EmailTemplateVars] + ] def validate_and_normalise_user_input( app_info: AppInfo, - sign_up_feature: Union[InputSignUpFeature, None] = None, - override: Union[InputOverrideConfig, None] = None, - email_delivery: Union[EmailDeliveryConfig[EmailTemplateVars], None] = None, + input_config: EmailPasswordInputConfig, ) -> EmailPasswordConfig: # NOTE: We don't need to check the instance of sign_up_feature and override # as they will always be either None or the specified type. - if override is None: - override = InputOverrideConfig() + override_config = OverrideConfig() + if input_config.override is not None: + if input_config.override.functions is not None: + override_config.functions = input_config.override.functions + + if input_config.override.apis is not None: + override_config.apis = input_config.override.apis + sign_up_feature = input_config.sign_up_feature if sign_up_feature is None: sign_up_feature = InputSignUpFeature() def get_email_delivery_config( ep_recipe: RecipeInterface, ) -> EmailDeliveryConfigWithService[EmailTemplateVars]: - if email_delivery and email_delivery.service: + if input_config.email_delivery and input_config.email_delivery.service: return EmailDeliveryConfigWithService( - service=email_delivery.service, override=email_delivery.override + service=input_config.email_delivery.service, + override=input_config.email_delivery.override, ) email_service = BackwardCompatibilityService( app_info=app_info, recipe_interface_impl=ep_recipe, ) - if email_delivery is not None and email_delivery.override is not None: - override = email_delivery.override + if ( + input_config.email_delivery is not None + and input_config.email_delivery.override is not None + ): + override = input_config.email_delivery.override else: override = None return EmailDeliveryConfigWithService(email_service, override=override) return EmailPasswordConfig( - SignUpFeature(sign_up_feature.form_fields), - SignInFeature(normalise_sign_in_form_fields(sign_up_feature.form_fields)), - validate_and_normalise_reset_password_using_token_config(sign_up_feature), - OverrideConfig(functions=override.functions, apis=override.apis), + sign_up_feature=SignUpFeature(sign_up_feature.form_fields), + sign_in_feature=SignInFeature( + normalise_sign_in_form_fields(sign_up_feature.form_fields) + ), + reset_password_using_token_feature=validate_and_normalise_reset_password_using_token_config( + sign_up_feature + ), + override=override_config, get_email_delivery_config=get_email_delivery_config, ) diff --git a/supertokens_python/recipe/emailverification/__init__.py b/supertokens_python/recipe/emailverification/__init__.py index 4836bb3e0..b88bac6cd 100644 --- a/supertokens_python/recipe/emailverification/__init__.py +++ b/supertokens_python/recipe/emailverification/__init__.py @@ -13,7 +13,7 @@ # under the License. from __future__ import annotations -from typing import TYPE_CHECKING, Callable, Optional, Union +from typing import TYPE_CHECKING, Optional, Union from ...ingredients.emaildelivery.types import EmailDeliveryConfig from . import exceptions as ex @@ -32,9 +32,7 @@ if TYPE_CHECKING: - from supertokens_python.supertokens import AppInfo - - from ...recipe_module import RecipeModule + from supertokens_python.supertokens import RecipeInit def init( @@ -42,7 +40,7 @@ def init( email_delivery: Union[EmailDeliveryConfig[EmailTemplateVars], None] = None, get_email_for_recipe_user_id: Optional[TypeGetEmailForUserIdFunction] = None, override: Union[OverrideConfig, None] = None, -) -> Callable[[AppInfo], RecipeModule]: +) -> RecipeInit: return EmailVerificationRecipe.init( mode, email_delivery, diff --git a/supertokens_python/recipe/emailverification/interfaces.py b/supertokens_python/recipe/emailverification/interfaces.py index bbf829047..8c00c38fd 100644 --- a/supertokens_python/recipe/emailverification/interfaces.py +++ b/supertokens_python/recipe/emailverification/interfaces.py @@ -13,13 +13,14 @@ # under the License. from __future__ import annotations -from abc import ABC, abstractmethod +from abc import abstractmethod from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Optional, Union from typing_extensions import Literal from supertokens_python.ingredients.emaildelivery import EmailDeliveryIngredient from supertokens_python.types import RecipeUserId +from supertokens_python.types.recipe import BaseAPIInterface, BaseRecipeInterface from supertokens_python.types.response import APIResponse, GeneralErrorResponse from ...supertokens import AppInfo @@ -84,7 +85,7 @@ class UnverifyEmailOkResult: pass -class RecipeInterface(ABC): +class RecipeInterface(BaseRecipeInterface): def __init__(self): pass @@ -201,7 +202,7 @@ def to_json(self) -> Dict[str, Any]: return {"status": self.status} -class APIInterface(ABC): +class APIInterface(BaseAPIInterface): def __init__(self): self.disable_email_verify_post = False self.disable_is_email_verified_get = False diff --git a/supertokens_python/recipe/emailverification/recipe.py b/supertokens_python/recipe/emailverification/recipe.py index a4abfb47c..9f5cc7e7f 100644 --- a/supertokens_python/recipe/emailverification/recipe.py +++ b/supertokens_python/recipe/emailverification/recipe.py @@ -18,6 +18,7 @@ from supertokens_python.exceptions import SuperTokensError, raise_general_exception from supertokens_python.ingredients.emaildelivery import EmailDeliveryIngredient +from supertokens_python.plugins import OverrideMap, apply_plugins from supertokens_python.recipe.emailverification.exceptions import ( EmailVerificationInvalidTokenError, ) @@ -78,7 +79,12 @@ from .api import handle_email_verify_api, handle_generate_email_verify_token_api from .constants import USER_EMAIL_VERIFY, USER_EMAIL_VERIFY_TOKEN from .exceptions import SuperTokensEmailVerificationError -from .utils import MODE_TYPE, OverrideConfig, validate_and_normalise_user_input +from .utils import ( + MODE_TYPE, + EmailVerificationInputConfig, + InputOverrideConfig, + validate_and_normalise_user_input, +) class EmailVerificationRecipe(RecipeModule): @@ -91,36 +97,24 @@ def __init__( recipe_id: str, app_info: AppInfo, ingredients: EmailVerificationIngredients, - mode: MODE_TYPE, - email_delivery: Union[EmailDeliveryConfig[EmailTemplateVars], None] = None, - get_email_for_recipe_user_id: Optional[TypeGetEmailForUserIdFunction] = None, - override: Union[OverrideConfig, None] = None, + input_config: EmailVerificationInputConfig, ) -> None: super().__init__(recipe_id, app_info) self.config = validate_and_normalise_user_input( app_info, - mode, - email_delivery, - get_email_for_recipe_user_id, - override, + input_config=input_config, ) recipe_implementation = RecipeImplementation( Querier.get_instance(recipe_id), self.get_email_for_recipe_user_id, ) - self.recipe_implementation = ( + self.recipe_implementation = self.config.override.functions( recipe_implementation - if self.config.override.functions is None - else self.config.override.functions(recipe_implementation) ) api_implementation = APIImplementation() - self.api_implementation = ( - api_implementation - if self.config.override.apis is None - else self.config.override.apis(api_implementation) - ) + self.api_implementation = self.config.override.apis(api_implementation) email_delivery_ingredient = ingredients.email_delivery if email_delivery_ingredient is None: @@ -207,19 +201,29 @@ def init( mode: MODE_TYPE, email_delivery: Union[EmailDeliveryConfig[EmailTemplateVars], None] = None, get_email_for_recipe_user_id: Optional[TypeGetEmailForUserIdFunction] = None, - override: Union[OverrideConfig, None] = None, + override: Union[InputOverrideConfig, None] = None, ): - def func(app_info: AppInfo) -> EmailVerificationRecipe: + input_config = EmailVerificationInputConfig( + mode=mode, + email_delivery=email_delivery, + get_email_for_recipe_user_id=get_email_for_recipe_user_id, + override=override, + ) + + def func( + app_info: AppInfo, plugins: List[OverrideMap] + ) -> EmailVerificationRecipe: if EmailVerificationRecipe.__instance is None: ingredients = EmailVerificationIngredients(email_delivery=None) EmailVerificationRecipe.__instance = EmailVerificationRecipe( EmailVerificationRecipe.recipe_id, app_info, ingredients, - mode, - email_delivery, - get_email_for_recipe_user_id, - override, + input_config=apply_plugins( + recipe_id=EmailVerificationRecipe.recipe_id, + config=input_config, + plugins=plugins, + ), ) def callback(): diff --git a/supertokens_python/recipe/emailverification/utils.py b/supertokens_python/recipe/emailverification/utils.py index 8811f073a..251687c67 100644 --- a/supertokens_python/recipe/emailverification/utils.py +++ b/supertokens_python/recipe/emailverification/utils.py @@ -26,53 +26,55 @@ from supertokens_python.recipe.emailverification.emaildelivery.services.backward_compatibility import ( BackwardCompatibilityService, ) +from supertokens_python.types.config import ( + BaseConfig, + BaseInputConfig, + BaseInputOverrideConfig, + BaseOverrideConfig, +) + +from .interfaces import APIInterface, RecipeInterface, TypeGetEmailForUserIdFunction if TYPE_CHECKING: from typing import Callable, Union from supertokens_python.supertokens import AppInfo - from .interfaces import APIInterface, RecipeInterface, TypeGetEmailForUserIdFunction from .types import EmailTemplateVars, VerificationEmailTemplateVars -class OverrideConfig: - def __init__( - self, - functions: Union[Callable[[RecipeInterface], RecipeInterface], None] = None, - apis: Union[Callable[[APIInterface], APIInterface], None] = None, - ): - self.functions = functions - self.apis = apis +class InputOverrideConfig(BaseInputOverrideConfig[RecipeInterface, APIInterface]): ... + + +class OverrideConfig(BaseOverrideConfig[RecipeInterface, APIInterface]): ... MODE_TYPE = Literal["REQUIRED", "OPTIONAL"] -class EmailVerificationConfig: - def __init__( - self, - mode: MODE_TYPE, - get_email_delivery_config: Callable[ - [], EmailDeliveryConfigWithService[VerificationEmailTemplateVars] - ], - get_email_for_recipe_user_id: Optional[TypeGetEmailForUserIdFunction], - override: OverrideConfig, - ): - self.mode = mode - self.override = override - self.get_email_delivery_config = get_email_delivery_config - self.get_email_for_recipe_user_id = get_email_for_recipe_user_id +class EmailVerificationInputConfig(BaseInputConfig[RecipeInterface, APIInterface]): + mode: MODE_TYPE + email_delivery: Union[EmailDeliveryConfig[EmailTemplateVars], None] = None + get_email_for_recipe_user_id: Optional[TypeGetEmailForUserIdFunction] = None + + +class EmailVerificationConfig(BaseConfig[RecipeInterface, APIInterface]): + mode: MODE_TYPE + get_email_delivery_config: Callable[ + [], EmailDeliveryConfigWithService[VerificationEmailTemplateVars] + ] + get_email_for_recipe_user_id: Optional[TypeGetEmailForUserIdFunction] def validate_and_normalise_user_input( app_info: AppInfo, - mode: MODE_TYPE, - email_delivery: Union[EmailDeliveryConfig[EmailTemplateVars], None] = None, - get_email_for_recipe_user_id: Optional[TypeGetEmailForUserIdFunction] = None, - override: Union[OverrideConfig, None] = None, + input_config: EmailVerificationInputConfig, + # mode: MODE_TYPE, + # email_delivery: Union[EmailDeliveryConfig[EmailTemplateVars], None] = None, + # get_email_for_recipe_user_id: Optional[TypeGetEmailForUserIdFunction] = None, + # override: Union[OverrideConfig, None] = None, ) -> EmailVerificationConfig: - if mode not in ["REQUIRED", "OPTIONAL"]: + if input_config.mode not in ["REQUIRED", "OPTIONAL"]: raise ValueError( "Email Verification recipe mode must be one of 'REQUIRED' or 'OPTIONAL'" ) @@ -80,27 +82,36 @@ def validate_and_normalise_user_input( def get_email_delivery_config() -> EmailDeliveryConfigWithService[ VerificationEmailTemplateVars ]: - email_service = email_delivery.service if email_delivery is not None else None + email_service = ( + input_config.email_delivery.service + if input_config.email_delivery is not None + else None + ) if email_service is None: email_service = BackwardCompatibilityService(app_info) - if email_delivery is not None and email_delivery.override is not None: - override = email_delivery.override + if ( + input_config.email_delivery is not None + and input_config.email_delivery.override is not None + ): + override = input_config.email_delivery.override else: override = None return EmailDeliveryConfigWithService(email_service, override=override) - if override is not None and not isinstance(override, OverrideConfig): # type: ignore - raise ValueError("override must be of type OverrideConfig or None") + override_config = OverrideConfig() + if input_config.override is not None: + if input_config.override.functions is not None: + override_config.functions = input_config.override.functions - if override is None: - override = OverrideConfig() + if input_config.override.apis is not None: + override_config.apis = input_config.override.apis return EmailVerificationConfig( - mode, - get_email_delivery_config, - get_email_for_recipe_user_id, - override, + mode=input_config.mode, + get_email_delivery_config=get_email_delivery_config, + get_email_for_recipe_user_id=input_config.get_email_for_recipe_user_id, + override=override_config, ) diff --git a/supertokens_python/recipe/jwt/__init__.py b/supertokens_python/recipe/jwt/__init__.py index e53bb46c1..7d0f1b3f3 100644 --- a/supertokens_python/recipe/jwt/__init__.py +++ b/supertokens_python/recipe/jwt/__init__.py @@ -13,19 +13,17 @@ # under the License. from __future__ import annotations -from typing import TYPE_CHECKING, Callable, Union +from typing import TYPE_CHECKING, Union from .recipe import JWTRecipe from .utils import OverrideConfig if TYPE_CHECKING: - from supertokens_python.supertokens import AppInfo - - from ...recipe_module import RecipeModule + from supertokens_python.supertokens import RecipeInit def init( jwt_validity_seconds: Union[int, None] = None, override: Union[OverrideConfig, None] = None, -) -> Callable[[AppInfo], RecipeModule]: +) -> RecipeInit: return JWTRecipe.init(jwt_validity_seconds, override) diff --git a/supertokens_python/recipe/jwt/interfaces.py b/supertokens_python/recipe/jwt/interfaces.py index 2c398f8da..5860c09cf 100644 --- a/supertokens_python/recipe/jwt/interfaces.py +++ b/supertokens_python/recipe/jwt/interfaces.py @@ -11,13 +11,15 @@ # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. -from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional, Union +from abc import abstractmethod +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from supertokens_python.framework import BaseRequest, BaseResponse +from supertokens_python.types.recipe import BaseAPIInterface, BaseRecipeInterface from supertokens_python.types.response import APIResponse, GeneralErrorResponse -from .utils import JWTConfig +if TYPE_CHECKING: + from .utils import JWTConfig class JsonWebKey: @@ -45,7 +47,7 @@ def __init__(self, keys: List[JsonWebKey], validity_in_secs: Optional[int]): self.validity_in_secs = validity_in_secs -class RecipeInterface(ABC): +class RecipeInterface(BaseRecipeInterface): def __init__(self): pass @@ -70,7 +72,7 @@ def __init__( request: BaseRequest, response: BaseResponse, recipe_id: str, - config: JWTConfig, + config: "JWTConfig", recipe_implementation: RecipeInterface, ): self.request = request @@ -101,7 +103,7 @@ def to_json(self) -> Dict[str, Any]: return {"keys": keys} -class APIInterface: +class APIInterface(BaseAPIInterface): def __init__(self): self.disable_jwks_get = False diff --git a/supertokens_python/recipe/jwt/recipe.py b/supertokens_python/recipe/jwt/recipe.py index 4f28c5d88..6135535c1 100644 --- a/supertokens_python/recipe/jwt/recipe.py +++ b/supertokens_python/recipe/jwt/recipe.py @@ -16,6 +16,7 @@ from os import environ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union +from supertokens_python.plugins import OverrideMap, apply_plugins from supertokens_python.querier import Querier from supertokens_python.recipe.jwt.api.implementation import APIImplementation from supertokens_python.recipe.jwt.api.jwks_get import jwks_get @@ -24,7 +25,8 @@ from supertokens_python.recipe.jwt.interfaces import APIOptions from supertokens_python.recipe.jwt.recipe_implementation import RecipeImplementation from supertokens_python.recipe.jwt.utils import ( - OverrideConfig, + InputOverrideConfig, + JWTInputConfig, validate_and_normalise_user_input, ) @@ -46,26 +48,19 @@ def __init__( self, recipe_id: str, app_info: AppInfo, - jwt_validity_seconds: Union[int, None] = None, - override: Union[OverrideConfig, None] = None, + input_config: JWTInputConfig, ): super().__init__(recipe_id, app_info) - self.config = validate_and_normalise_user_input(jwt_validity_seconds, override) + self.config = validate_and_normalise_user_input(input_config=input_config) recipe_implementation = RecipeImplementation( Querier.get_instance(recipe_id), self.config, app_info ) - self.recipe_implementation = ( + self.recipe_implementation = self.config.override.functions( recipe_implementation - if self.config.override.functions is None - else self.config.override.functions(recipe_implementation) ) api_implementation = APIImplementation() - self.api_implementation = ( - api_implementation - if self.config.override.apis is None - else self.config.override.apis(api_implementation) - ) + self.api_implementation = self.config.override.apis(api_implementation) def get_apis_handled(self) -> List[APIHandled]: return [ @@ -120,12 +115,23 @@ def is_error_from_this_recipe_based_on_instance(self, err: Exception) -> bool: @staticmethod def init( jwt_validity_seconds: Union[int, None] = None, - override: Union[OverrideConfig, None] = None, + override: Union[InputOverrideConfig, None] = None, ): - def func(app_info: AppInfo): + input_config = JWTInputConfig( + jwt_validity_seconds=jwt_validity_seconds, + override=override, + ) + + def func(app_info: AppInfo, plugins: List[OverrideMap]): if JWTRecipe.__instance is None: JWTRecipe.__instance = JWTRecipe( - JWTRecipe.recipe_id, app_info, jwt_validity_seconds, override + JWTRecipe.recipe_id, + app_info, + input_config=apply_plugins( + recipe_id=JWTRecipe.recipe_id, + config=input_config, + plugins=plugins, + ), ) return JWTRecipe.__instance raise_general_exception( diff --git a/supertokens_python/recipe/jwt/utils.py b/supertokens_python/recipe/jwt/utils.py index 35dbef8e1..6ed829407 100644 --- a/supertokens_python/recipe/jwt/utils.py +++ b/supertokens_python/recipe/jwt/utils.py @@ -13,41 +13,49 @@ # under the License. from __future__ import annotations -from typing import TYPE_CHECKING, Callable, Union +from typing import Optional -if TYPE_CHECKING: - from .interfaces import APIInterface, RecipeInterface +from supertokens_python.types.config import ( + BaseConfig, + BaseInputConfig, + BaseInputOverrideConfig, + BaseOverrideConfig, +) +from .interfaces import APIInterface, RecipeInterface -class OverrideConfig: - def __init__( - self, - functions: Union[Callable[[RecipeInterface], RecipeInterface], None] = None, - apis: Union[Callable[[APIInterface], APIInterface], None] = None, - ): - self.functions = functions - self.apis = apis +class InputOverrideConfig(BaseInputOverrideConfig[RecipeInterface, APIInterface]): ... -class JWTConfig: - def __init__(self, override: OverrideConfig, jwt_validity_seconds: int): - self.override = override - self.jwt_validity_seconds = jwt_validity_seconds +class OverrideConfig(BaseOverrideConfig[RecipeInterface, APIInterface]): ... + + +class JWTInputConfig(BaseInputConfig[RecipeInterface, APIInterface]): + jwt_validity_seconds: Optional[int] = None + + +class JWTConfig(BaseConfig[RecipeInterface, APIInterface]): + jwt_validity_seconds: int -def validate_and_normalise_user_input( - jwt_validity_seconds: Union[int, None] = None, - override: Union[OverrideConfig, None] = None, -): - if jwt_validity_seconds is not None and not isinstance(jwt_validity_seconds, int): # type: ignore - raise ValueError("jwt_validity_seconds must be an integer or None") - if override is not None and not isinstance(override, OverrideConfig): # type: ignore - raise ValueError("override must be an instance of OverrideConfig or None") +def validate_and_normalise_user_input(input_config: JWTInputConfig): + override_config = OverrideConfig() + if input_config.override is not None: + if input_config.override.functions is not None: + override_config.functions = input_config.override.functions - if override is None: - override = OverrideConfig() - if jwt_validity_seconds is None: + if input_config.override.apis is not None: + override_config.apis = input_config.override.apis + + jwt_validity_seconds = input_config.jwt_validity_seconds + + if input_config.jwt_validity_seconds is None: jwt_validity_seconds = 3153600000 - return JWTConfig(override, jwt_validity_seconds) + if not isinstance(jwt_validity_seconds, int): # type: ignore + raise ValueError("jwt_validity_seconds must be an integer or None") + + return JWTConfig( + jwt_validity_seconds=jwt_validity_seconds, override=override_config + ) diff --git a/supertokens_python/recipe/multifactorauth/__init__.py b/supertokens_python/recipe/multifactorauth/__init__.py index fa3aaf6f8..074bc3bc9 100644 --- a/supertokens_python/recipe/multifactorauth/__init__.py +++ b/supertokens_python/recipe/multifactorauth/__init__.py @@ -13,22 +13,20 @@ # under the License. from __future__ import annotations -from typing import TYPE_CHECKING, Callable, List, Optional, Union +from typing import TYPE_CHECKING, List, Optional, Union from supertokens_python.recipe.multifactorauth.types import OverrideConfig from .recipe import MultiFactorAuthRecipe if TYPE_CHECKING: - from supertokens_python.supertokens import AppInfo - - from ...recipe_module import RecipeModule + from supertokens_python.supertokens import RecipeInit def init( first_factors: Optional[List[str]] = None, override: Union[OverrideConfig, None] = None, -) -> Callable[[AppInfo], RecipeModule]: +) -> RecipeInit: return MultiFactorAuthRecipe.init( first_factors, override, diff --git a/supertokens_python/recipe/multifactorauth/interfaces.py b/supertokens_python/recipe/multifactorauth/interfaces.py index b5b0ec928..03db15635 100644 --- a/supertokens_python/recipe/multifactorauth/interfaces.py +++ b/supertokens_python/recipe/multifactorauth/interfaces.py @@ -14,9 +14,11 @@ from __future__ import annotations -from abc import ABC, abstractmethod +from abc import abstractmethod from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Union +from supertokens_python.types.recipe import BaseAPIInterface, BaseRecipeInterface + from ...types.response import APIResponse, GeneralErrorResponse if TYPE_CHECKING: @@ -29,7 +31,7 @@ from .types import MFARequirementList, MultiFactorAuthConfig -class RecipeInterface(ABC): +class RecipeInterface(BaseRecipeInterface): @abstractmethod async def assert_allowed_to_setup_factor_else_throw_invalid_claim_error( self, @@ -109,7 +111,7 @@ def __init__( self.recipe_instance = recipe_instance -class APIInterface: +class APIInterface(BaseAPIInterface): def __init__(self): self.disable_resync_session_and_fetch_mfa_info_put = False diff --git a/supertokens_python/recipe/multifactorauth/recipe.py b/supertokens_python/recipe/multifactorauth/recipe.py index 995137b1a..0dc7e0a54 100644 --- a/supertokens_python/recipe/multifactorauth/recipe.py +++ b/supertokens_python/recipe/multifactorauth/recipe.py @@ -13,13 +13,13 @@ # under the License. from __future__ import annotations -import importlib from os import environ from typing import Any, Dict, List, Optional, Union from supertokens_python.exceptions import SuperTokensError, raise_general_exception from supertokens_python.framework import BaseRequest, BaseResponse from supertokens_python.normalised_url_path import NormalisedURLPath +from supertokens_python.plugins import OverrideMap, apply_plugins from supertokens_python.post_init_callbacks import PostSTInitCallbacks from supertokens_python.querier import Querier from supertokens_python.recipe.multifactorauth.api import ( @@ -37,7 +37,9 @@ from supertokens_python.supertokens import AppInfo from supertokens_python.types import RecipeUserId, User +from .api.implementation import APIImplementation from .interfaces import APIOptions +from .recipe_implementation import RecipeImplementation from .types import ( GetAllAvailableSecondaryFactorIdsFromOtherRecipesFunc, GetEmailsForFactorFromOtherRecipesFunc, @@ -47,8 +49,10 @@ GetPhoneNumbersForFactorsFromOtherRecipesFunc, GetPhoneNumbersForFactorsOkResult, GetPhoneNumbersForFactorsUnknownSessionRecipeUserIdResult, - OverrideConfig, + InputOverrideConfig, + MultiFactorAuthInputConfig, ) +from .utils import validate_and_normalise_user_input class MultiFactorAuthRecipe(RecipeModule): @@ -59,8 +63,7 @@ def __init__( self, recipe_id: str, app_info: AppInfo, - first_factors: Optional[List[str]] = None, - override: Union[OverrideConfig, None] = None, + input_config: MultiFactorAuthInputConfig, ): super().__init__(recipe_id, app_info) self.get_factors_setup_for_user_from_other_recipes_funcs: List[ @@ -77,32 +80,19 @@ def __init__( ] = [] self.is_get_mfa_requirements_for_auth_overridden: bool = False - module = importlib.import_module( - "supertokens_python.recipe.multifactorauth.utils" - ) - - self.config = module.validate_and_normalise_user_input( - first_factors, - override, + self.config = validate_and_normalise_user_input( + input_config=input_config, ) - from .recipe_implementation import RecipeImplementation recipe_implementation = RecipeImplementation( Querier.get_instance(recipe_id), self ) - self.recipe_implementation = ( + self.recipe_implementation = self.config.override.functions( recipe_implementation - if self.config.override.functions is None - else self.config.override.functions(recipe_implementation) ) - from .api.implementation import APIImplementation api_implementation = APIImplementation() - self.api_implementation = ( - api_implementation - if self.config.override.apis is None - else self.config.override.apis(api_implementation) - ) + self.api_implementation = self.config.override.apis(api_implementation) def callback(): from supertokens_python.recipe.multitenancy.recipe import MultitenancyRecipe @@ -169,15 +159,23 @@ def get_all_cors_headers(self) -> List[str]: @staticmethod def init( first_factors: Optional[List[str]] = None, - override: Union[OverrideConfig, None] = None, + override: Union[InputOverrideConfig, None] = None, ): - def func(app_info: AppInfo): + input_config = MultiFactorAuthInputConfig( + first_factors=first_factors, + override=override, + ) + + def func(app_info: AppInfo, plugins: List[OverrideMap]): if MultiFactorAuthRecipe.__instance is None: MultiFactorAuthRecipe.__instance = MultiFactorAuthRecipe( MultiFactorAuthRecipe.recipe_id, app_info, - first_factors, - override, + input_config=apply_plugins( + recipe_id=MultiFactorAuthRecipe.recipe_id, + config=input_config, + plugins=plugins, + ), ) return MultiFactorAuthRecipe.__instance raise_general_exception( diff --git a/supertokens_python/recipe/multifactorauth/types.py b/supertokens_python/recipe/multifactorauth/types.py index 1f5587189..571561e14 100644 --- a/supertokens_python/recipe/multifactorauth/types.py +++ b/supertokens_python/recipe/multifactorauth/types.py @@ -13,16 +13,20 @@ # under the License. from __future__ import annotations -from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional, Union +from typing import Any, Awaitable, Callable, Dict, List, Optional, Union from typing_extensions import Literal from supertokens_python.recipe.multitenancy.interfaces import TenantConfig from supertokens_python.types import RecipeUserId, User +from supertokens_python.types.config import ( + BaseConfig, + BaseInputConfig, + BaseInputOverrideConfig, + BaseOverrideConfig, +) -if TYPE_CHECKING: - from .interfaces import APIInterface, RecipeInterface - +from .interfaces import APIInterface, RecipeInterface MFARequirementList = List[ Union[str, Dict[Union[Literal["oneOf"], Literal["allOfInAnyOrder"]], List[str]]] @@ -38,24 +42,18 @@ def __init__(self, c: Dict[str, Any], v: bool): self.v = v -class OverrideConfig: - def __init__( - self, - functions: Union[Callable[[RecipeInterface], RecipeInterface], None] = None, - apis: Union[Callable[[APIInterface], APIInterface], None] = None, - ): - self.functions = functions - self.apis = apis +class InputOverrideConfig(BaseInputOverrideConfig[RecipeInterface, APIInterface]): ... -class MultiFactorAuthConfig: - def __init__( - self, - first_factors: Optional[List[str]], - override: OverrideConfig, - ): - self.first_factors = first_factors - self.override = override +class OverrideConfig(BaseOverrideConfig[RecipeInterface, APIInterface]): ... + + +class MultiFactorAuthInputConfig(BaseInputConfig[RecipeInterface, APIInterface]): + first_factors: Optional[List[str]] = None + + +class MultiFactorAuthConfig(BaseConfig[RecipeInterface, APIInterface]): + first_factors: Optional[List[str]] class FactorIds: diff --git a/supertokens_python/recipe/multifactorauth/utils.py b/supertokens_python/recipe/multifactorauth/utils.py index 38bd635c5..7a0037767 100644 --- a/supertokens_python/recipe/multifactorauth/utils.py +++ b/supertokens_python/recipe/multifactorauth/utils.py @@ -15,7 +15,7 @@ import math import time -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Union from typing_extensions import Literal @@ -36,28 +36,32 @@ from supertokens_python.types import RecipeUserId from supertokens_python.utils import log_debug_message -if TYPE_CHECKING: - from .types import MultiFactorAuthConfig, OverrideConfig +from .types import ( + MultiFactorAuthConfig, + MultiFactorAuthInputConfig, + OverrideConfig, +) # IMPORTANT: If this function signature is modified, please update all tha places where this function is called. # There will be no type errors cause we use importLib to dynamically import if to prevent cyclic import issues. def validate_and_normalise_user_input( - first_factors: Optional[List[str]], - override: Union[OverrideConfig, None] = None, + input_config: MultiFactorAuthInputConfig, ) -> MultiFactorAuthConfig: - if first_factors is not None and len(first_factors) == 0: + if input_config.first_factors is not None and len(input_config.first_factors) == 0: raise ValueError("'first_factors' can be either None or a non-empty list") - from .types import MultiFactorAuthConfig as MFAC - from .types import OverrideConfig as OC + override_config = OverrideConfig() + if input_config.override is not None: + if input_config.override.functions is not None: + override_config.functions = input_config.override.functions - if override is None: - override = OC() + if input_config.override.apis is not None: + override_config.apis = input_config.override.apis - return MFAC( - first_factors=first_factors, - override=override, + return MultiFactorAuthConfig( + first_factors=input_config.first_factors, + override=override_config, ) diff --git a/supertokens_python/recipe/multitenancy/__init__.py b/supertokens_python/recipe/multitenancy/__init__.py index 5a7180a45..41dc96964 100644 --- a/supertokens_python/recipe/multitenancy/__init__.py +++ b/supertokens_python/recipe/multitenancy/__init__.py @@ -13,7 +13,7 @@ # under the License. from __future__ import annotations -from typing import TYPE_CHECKING, Callable, Union +from typing import TYPE_CHECKING, Union from . import exceptions as ex from . import recipe @@ -22,9 +22,8 @@ exceptions = ex if TYPE_CHECKING: - from supertokens_python.supertokens import AppInfo + from supertokens_python.supertokens import RecipeInit - from ...recipe_module import RecipeModule from .interfaces import TypeGetAllowedDomainsForTenantId from .utils import InputOverrideConfig @@ -34,7 +33,7 @@ def init( TypeGetAllowedDomainsForTenantId, None ] = None, override: Union[InputOverrideConfig, None] = None, -) -> Callable[[AppInfo], RecipeModule]: +) -> RecipeInit: return recipe.MultitenancyRecipe.init( get_allowed_domains_for_tenant_id, override, diff --git a/supertokens_python/recipe/multitenancy/interfaces.py b/supertokens_python/recipe/multitenancy/interfaces.py index 03bd8d012..a2d6d25c5 100644 --- a/supertokens_python/recipe/multitenancy/interfaces.py +++ b/supertokens_python/recipe/multitenancy/interfaces.py @@ -13,10 +13,11 @@ # under the License. from __future__ import annotations -from abc import ABC, abstractmethod +from abc import abstractmethod from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional, Union from supertokens_python.types import RecipeUserId +from supertokens_python.types.recipe import BaseAPIInterface, BaseRecipeInterface from supertokens_python.types.response import APIResponse, GeneralErrorResponse if TYPE_CHECKING: @@ -192,7 +193,7 @@ def __init__(self, was_associated: bool): self.was_associated = was_associated -class RecipeInterface(ABC): +class RecipeInterface(BaseRecipeInterface): def __init__(self): pass @@ -366,7 +367,7 @@ def to_json(self) -> Dict[str, Any]: } -class APIInterface(ABC): +class APIInterface(BaseAPIInterface): def __init__(self): self.disable_login_methods_get = False diff --git a/supertokens_python/recipe/multitenancy/recipe.py b/supertokens_python/recipe/multitenancy/recipe.py index 38873a4e1..51d7874e8 100644 --- a/supertokens_python/recipe/multitenancy/recipe.py +++ b/supertokens_python/recipe/multitenancy/recipe.py @@ -17,6 +17,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from supertokens_python.exceptions import SuperTokensError, raise_general_exception +from supertokens_python.plugins import OverrideMap, apply_plugins from supertokens_python.recipe.session.claim_base_classes.primitive_array_claim import ( PrimitiveArrayClaim, ) @@ -45,6 +46,7 @@ from .exceptions import MultitenancyError from .utils import ( InputOverrideConfig, + MultitenancyInputConfig, validate_and_normalise_user_input, ) @@ -54,35 +56,20 @@ class MultitenancyRecipe(RecipeModule): __instance = None def __init__( - self, - recipe_id: str, - app_info: AppInfo, - get_allowed_domains_for_tenant_id: Optional[ - TypeGetAllowedDomainsForTenantId - ] = None, - override: Union[InputOverrideConfig, None] = None, + self, recipe_id: str, app_info: AppInfo, input_config: MultitenancyInputConfig ) -> None: super().__init__(recipe_id, app_info) - self.config = validate_and_normalise_user_input( - get_allowed_domains_for_tenant_id, - override, - ) + self.config = validate_and_normalise_user_input(input_config=input_config) recipe_implementation = RecipeImplementation( Querier.get_instance(recipe_id), self.config ) - self.recipe_implementation = ( + self.recipe_implementation = self.config.override.functions( recipe_implementation - if self.config.override.functions is None - else self.config.override.functions(recipe_implementation) ) api_implementation = APIImplementation() - self.api_implementation = ( - api_implementation - if self.config.override.apis is None - else self.config.override.apis(api_implementation) - ) + self.api_implementation = self.config.override.apis(api_implementation) self.static_third_party_providers: List[ProviderInput] = [] self.get_allowed_domains_for_tenant_id = ( @@ -152,13 +139,21 @@ def init( ] = None, override: Union[InputOverrideConfig, None] = None, ): - def func(app_info: AppInfo): + input_config = MultitenancyInputConfig( + get_allowed_domains_for_tenant_id=get_allowed_domains_for_tenant_id, + override=override, + ) + + def func(app_info: AppInfo, plugins: List[OverrideMap]): if MultitenancyRecipe.__instance is None: MultitenancyRecipe.__instance = MultitenancyRecipe( - MultitenancyRecipe.recipe_id, - app_info, - get_allowed_domains_for_tenant_id, - override, + recipe_id=MultitenancyRecipe.recipe_id, + app_info=app_info, + input_config=apply_plugins( + recipe_id=MultitenancyRecipe.recipe_id, + config=input_config, + plugins=plugins, + ), ) def callback(): diff --git a/supertokens_python/recipe/multitenancy/utils.py b/supertokens_python/recipe/multitenancy/utils.py index 43bab761b..3448e2e0b 100644 --- a/supertokens_python/recipe/multitenancy/utils.py +++ b/supertokens_python/recipe/multitenancy/utils.py @@ -14,22 +14,25 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Awaitable, Callable, Optional +from typing import Awaitable, Callable, Optional, Union from supertokens_python.exceptions import SuperTokensError from supertokens_python.framework import BaseRequest, BaseResponse +from supertokens_python.types.config import ( + BaseConfig, + BaseInputConfig, + BaseInputOverrideConfig, + BaseOverrideConfig, +) from supertokens_python.utils import ( resolve, ) -if TYPE_CHECKING: - from typing import Union - - from .interfaces import ( - APIInterface, - RecipeInterface, - TypeGetAllowedDomainsForTenantId, - ) +from .interfaces import ( + APIInterface, + RecipeInterface, + TypeGetAllowedDomainsForTenantId, +) class ErrorHandlers: @@ -63,47 +66,37 @@ async def on_recipe_disabled_for_tenant( ) -class InputOverrideConfig: - def __init__( - self, - functions: Union[Callable[[RecipeInterface], RecipeInterface], None] = None, - apis: Union[Callable[[APIInterface], APIInterface], None] = None, - ): - self.functions = functions - self.apis = apis +class InputOverrideConfig(BaseInputOverrideConfig[RecipeInterface, APIInterface]): ... -class OverrideConfig: - def __init__( - self, - functions: Union[Callable[[RecipeInterface], RecipeInterface], None] = None, - apis: Union[Callable[[APIInterface], APIInterface], None] = None, - ): - self.functions = functions - self.apis = apis +class OverrideConfig(BaseOverrideConfig[RecipeInterface, APIInterface]): ... -class MultitenancyConfig: - def __init__( - self, - get_allowed_domains_for_tenant_id: Optional[TypeGetAllowedDomainsForTenantId], - override: OverrideConfig, - ): - self.get_allowed_domains_for_tenant_id = get_allowed_domains_for_tenant_id - self.override = override +class MultitenancyInputConfig(BaseInputConfig[RecipeInterface, APIInterface]): + get_allowed_domains_for_tenant_id: Optional[TypeGetAllowedDomainsForTenantId] = None + + +class MultitenancyConfig(BaseConfig[RecipeInterface, APIInterface]): + get_allowed_domains_for_tenant_id: Optional[TypeGetAllowedDomainsForTenantId] def validate_and_normalise_user_input( - get_allowed_domains_for_tenant_id: Optional[TypeGetAllowedDomainsForTenantId], - override: Union[InputOverrideConfig, None] = None, + input_config: MultitenancyInputConfig, ) -> MultitenancyConfig: - if override is not None and not isinstance(override, OverrideConfig): # type: ignore - raise ValueError("override must be of type OverrideConfig or None") + if input_config.override is not None and not isinstance( + input_config.override, InputOverrideConfig + ): # type: ignore + raise ValueError("override must be of type InputOverrideConfig or None") + + override_config = OverrideConfig() + if input_config.override is not None: + if input_config.override.functions is not None: + override_config.functions = input_config.override.functions - if override is None: - override = InputOverrideConfig() + if input_config.override.apis is not None: + override_config.apis = input_config.override.apis return MultitenancyConfig( - get_allowed_domains_for_tenant_id, - OverrideConfig(override.functions, override.apis), + get_allowed_domains_for_tenant_id=input_config.get_allowed_domains_for_tenant_id, + override=override_config, ) diff --git a/supertokens_python/recipe/oauth2provider/__init__.py b/supertokens_python/recipe/oauth2provider/__init__.py index 3397b3430..3c5735030 100644 --- a/supertokens_python/recipe/oauth2provider/__init__.py +++ b/supertokens_python/recipe/oauth2provider/__init__.py @@ -13,7 +13,7 @@ # under the License. from __future__ import annotations -from typing import TYPE_CHECKING, Callable, Union +from typing import TYPE_CHECKING, Union from . import exceptions as ex from . import recipe, utils @@ -22,12 +22,10 @@ InputOverrideConfig = utils.InputOverrideConfig if TYPE_CHECKING: - from supertokens_python.supertokens import AppInfo - - from ...recipe_module import RecipeModule + from supertokens_python.supertokens import RecipeInit def init( override: Union[InputOverrideConfig, None] = None, -) -> Callable[[AppInfo], RecipeModule]: +) -> RecipeInit: return recipe.OAuth2ProviderRecipe.init(override) diff --git a/supertokens_python/recipe/oauth2provider/interfaces.py b/supertokens_python/recipe/oauth2provider/interfaces.py index 3329013a4..f96859747 100644 --- a/supertokens_python/recipe/oauth2provider/interfaces.py +++ b/supertokens_python/recipe/oauth2provider/interfaces.py @@ -13,7 +13,7 @@ # under the License. from __future__ import annotations -from abc import ABC, abstractmethod +from abc import abstractmethod from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional, Union from typing_extensions import Literal @@ -23,6 +23,7 @@ RecipeUserId, User, ) +from supertokens_python.types.recipe import BaseAPIInterface, BaseRecipeInterface from supertokens_python.types.response import APIResponse, GeneralErrorResponse from .oauth2_client import OAuth2Client @@ -1016,7 +1017,7 @@ def from_json(json: Dict[str, Any]) -> "UpdateOAuth2ClientInput": ) -class RecipeInterface(ABC): +class RecipeInterface(BaseRecipeInterface): @abstractmethod async def authorization( self, @@ -1284,7 +1285,7 @@ def __init__( self.recipe_implementation: RecipeInterface = recipe_implementation -class APIInterface: +class APIInterface(BaseAPIInterface): def __init__(self): self.disable_login_get = False self.disable_auth_get = False diff --git a/supertokens_python/recipe/oauth2provider/recipe.py b/supertokens_python/recipe/oauth2provider/recipe.py index 0532091ce..ea54c40bd 100644 --- a/supertokens_python/recipe/oauth2provider/recipe.py +++ b/supertokens_python/recipe/oauth2provider/recipe.py @@ -18,6 +18,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from supertokens_python.exceptions import SuperTokensError, raise_general_exception +from supertokens_python.plugins import OverrideMap, apply_plugins from supertokens_python.recipe.oauth2provider.exceptions import OAuth2ProviderError from supertokens_python.recipe_module import APIHandled, RecipeModule from supertokens_python.types import User @@ -68,6 +69,7 @@ from .utils import ( InputOverrideConfig, OAuth2ProviderConfig, + OAuth2ProviderInputConfig, validate_and_normalise_user_input, ) @@ -80,11 +82,11 @@ def __init__( self, recipe_id: str, app_info: AppInfo, - override: Union[InputOverrideConfig, None] = None, + input_config: OAuth2ProviderInputConfig, ) -> None: super().__init__(recipe_id, app_info) self.config: OAuth2ProviderConfig = validate_and_normalise_user_input( - override, + input_config=input_config, ) from .recipe_implementation import RecipeImplementation @@ -96,19 +98,13 @@ def __init__( self.get_default_id_token_payload, self.get_default_user_info_payload, ) - self.recipe_implementation: RecipeInterface = ( - self.config.override.functions(recipe_implementation) - if self.config.override is not None - and self.config.override.functions is not None - else recipe_implementation + self.recipe_implementation: RecipeInterface = self.config.override.functions( + recipe_implementation ) api_implementation = APIImplementation() - self.api_implementation: APIInterface = ( - self.config.override.apis(api_implementation) - if self.config.override is not None - and self.config.override.apis is not None - else api_implementation + self.api_implementation: APIInterface = self.config.override.apis( + api_implementation ) self._access_token_builders: List[PayloadBuilderFunction] = [] @@ -270,12 +266,18 @@ def get_all_cors_headers(self) -> List[str]: def init( override: Union[InputOverrideConfig, None] = None, ): - def func(app_info: AppInfo): + input_config = OAuth2ProviderInputConfig(override=override) + + def func(app_info: AppInfo, plugins: List[OverrideMap]) -> OAuth2ProviderRecipe: if OAuth2ProviderRecipe.__instance is None: OAuth2ProviderRecipe.__instance = OAuth2ProviderRecipe( - OAuth2ProviderRecipe.recipe_id, - app_info, - override, + recipe_id=OAuth2ProviderRecipe.recipe_id, + app_info=app_info, + input_config=apply_plugins( + recipe_id=OAuth2ProviderRecipe.recipe_id, + config=input_config, + plugins=plugins, + ), ) return OAuth2ProviderRecipe.__instance diff --git a/supertokens_python/recipe/oauth2provider/utils.py b/supertokens_python/recipe/oauth2provider/utils.py index 7e6ff7e84..8fc7f5cc8 100644 --- a/supertokens_python/recipe/oauth2provider/utils.py +++ b/supertokens_python/recipe/oauth2provider/utils.py @@ -13,42 +13,35 @@ # under the License. from __future__ import annotations -from typing import TYPE_CHECKING, Callable +from supertokens_python.types.config import ( + BaseConfig, + BaseInputConfig, + BaseInputOverrideConfig, + BaseOverrideConfig, +) -if TYPE_CHECKING: - from typing import Union +from .interfaces import APIInterface, RecipeInterface - from .interfaces import APIInterface, RecipeInterface +class InputOverrideConfig(BaseInputOverrideConfig[RecipeInterface, APIInterface]): ... -class InputOverrideConfig: - def __init__( - self, - functions: Union[Callable[[RecipeInterface], RecipeInterface], None] = None, - apis: Union[Callable[[APIInterface], APIInterface], None] = None, - ): - self.functions = functions - self.apis = apis +class OverrideConfig(BaseOverrideConfig[RecipeInterface, APIInterface]): ... -class OverrideConfig: - def __init__( - self, - functions: Union[Callable[[RecipeInterface], RecipeInterface], None] = None, - apis: Union[Callable[[APIInterface], APIInterface], None] = None, - ): - self.functions = functions - self.apis = apis +class OAuth2ProviderInputConfig(BaseInputConfig[RecipeInterface, APIInterface]): ... -class OAuth2ProviderConfig: - def __init__(self, override: Union[OverrideConfig, None] = None): - self.override = override +class OAuth2ProviderConfig(BaseConfig[RecipeInterface, APIInterface]): ... -def validate_and_normalise_user_input( - override: Union[InputOverrideConfig, None] = None, -): - if override is None: - return OAuth2ProviderConfig(OverrideConfig()) - return OAuth2ProviderConfig(OverrideConfig(override.functions, override.apis)) + +def validate_and_normalise_user_input(input_config: OAuth2ProviderInputConfig): + override_config = OverrideConfig() + if input_config.override is not None: + if input_config.override.functions is not None: + override_config.functions = input_config.override.functions + + if input_config.override.apis is not None: + override_config.apis = input_config.override.apis + + return OAuth2ProviderConfig(override=override_config) diff --git a/supertokens_python/recipe/openid/__init__.py b/supertokens_python/recipe/openid/__init__.py index 6ca57f2e2..06a411438 100644 --- a/supertokens_python/recipe/openid/__init__.py +++ b/supertokens_python/recipe/openid/__init__.py @@ -13,19 +13,17 @@ # under the License. from __future__ import annotations -from typing import TYPE_CHECKING, Callable, Union +from typing import TYPE_CHECKING, Union from .recipe import OpenIdRecipe from .utils import InputOverrideConfig if TYPE_CHECKING: - from supertokens_python.supertokens import AppInfo - - from ...recipe_module import RecipeModule + from supertokens_python.supertokens import RecipeInit def init( issuer: Union[str, None] = None, override: Union[InputOverrideConfig, None] = None, -) -> Callable[[AppInfo], RecipeModule]: +) -> RecipeInit: return OpenIdRecipe.init(issuer, override) diff --git a/supertokens_python/recipe/openid/interfaces.py b/supertokens_python/recipe/openid/interfaces.py index 0c2fd0ae5..691c911db 100644 --- a/supertokens_python/recipe/openid/interfaces.py +++ b/supertokens_python/recipe/openid/interfaces.py @@ -11,7 +11,7 @@ # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. -from abc import ABC, abstractmethod +from abc import abstractmethod from typing import Any, Dict, List, Optional, Union from supertokens_python.framework import BaseRequest, BaseResponse @@ -20,6 +20,7 @@ CreateJwtResultUnsupportedAlgorithm, GetJWKSResult, ) +from supertokens_python.types.recipe import BaseAPIInterface, BaseRecipeInterface from supertokens_python.types.response import APIResponse, GeneralErrorResponse from .utils import OpenIdConfig @@ -70,7 +71,7 @@ def to_json(self) -> Dict[str, Any]: } -class RecipeInterface(ABC): +class RecipeInterface(BaseRecipeInterface): def __init__(self): pass @@ -159,7 +160,7 @@ def to_json(self): } -class APIInterface: +class APIInterface(BaseAPIInterface): def __init__(self): self.disable_open_id_discovery_configuration_get = False diff --git a/supertokens_python/recipe/openid/recipe.py b/supertokens_python/recipe/openid/recipe.py index 6f7abbe59..50a72d927 100644 --- a/supertokens_python/recipe/openid/recipe.py +++ b/supertokens_python/recipe/openid/recipe.py @@ -16,6 +16,7 @@ from os import environ from typing import TYPE_CHECKING, Any, Dict, List, Union +from supertokens_python.plugins import OverrideMap, apply_plugins from supertokens_python.querier import Querier from .api.implementation import APIImplementation @@ -24,7 +25,11 @@ from .exceptions import SuperTokensOpenIdError from .interfaces import APIOptions from .recipe_implementation import RecipeImplementation -from .utils import InputOverrideConfig, validate_and_normalise_user_input +from .utils import ( + InputOverrideConfig, + OpenIdInputConfig, + validate_and_normalise_user_input, +) if TYPE_CHECKING: from supertokens_python.framework.request import BaseRequest @@ -44,13 +49,14 @@ def __init__( self, recipe_id: str, app_info: AppInfo, - issuer: Union[str, None] = None, - override: Union[InputOverrideConfig, None] = None, + input_config: OpenIdInputConfig, ): from supertokens_python.recipe.jwt import JWTRecipe super().__init__(recipe_id, app_info) - self.config = validate_and_normalise_user_input(app_info, issuer, override) + self.config = validate_and_normalise_user_input( + app_info=app_info, input_config=input_config + ) self.jwt_recipe = JWTRecipe.get_instance() recipe_implementation = RecipeImplementation( @@ -58,17 +64,11 @@ def __init__( self.config, app_info, ) - self.recipe_implementation = ( + self.recipe_implementation = self.config.override.functions( recipe_implementation - if self.config.override.functions is None - else self.config.override.functions(recipe_implementation) ) api_implementation = APIImplementation() - self.api_implementation = ( - api_implementation - if self.config.override.apis is None - else self.config.override.apis(api_implementation) - ) + self.api_implementation = self.config.override.apis(api_implementation) def get_apis_handled(self) -> List[APIHandled]: return [ @@ -131,13 +131,18 @@ def init( issuer: Union[str, None] = None, override: Union[InputOverrideConfig, None] = None, ): - def func(app_info: AppInfo): + input_config = OpenIdInputConfig(issuer=issuer, override=override) + + def func(app_info: AppInfo, plugins: List[OverrideMap]): if OpenIdRecipe.__instance is None: OpenIdRecipe.__instance = OpenIdRecipe( - OpenIdRecipe.recipe_id, - app_info, - issuer, - override, + recipe_id=OpenIdRecipe.recipe_id, + app_info=app_info, + input_config=apply_plugins( + recipe_id=OpenIdRecipe.recipe_id, + config=input_config, + plugins=plugins, + ), ) return OpenIdRecipe.__instance raise_general_exception( diff --git a/supertokens_python/recipe/openid/utils.py b/supertokens_python/recipe/openid/utils.py index cb30d0d32..f1c18a64f 100644 --- a/supertokens_python/recipe/openid/utils.py +++ b/supertokens_python/recipe/openid/utils.py @@ -13,77 +13,68 @@ # under the License. from __future__ import annotations -from typing import TYPE_CHECKING, Callable, Union +from typing import TYPE_CHECKING, Union + +from supertokens_python.types.config import ( + BaseConfig, + BaseInputConfig, + BaseInputOverrideConfig, + BaseOverrideConfig, +) if TYPE_CHECKING: from supertokens_python import AppInfo from supertokens_python.recipe.jwt import OverrideConfig as JWTOverrideConfig - from .interfaces import APIInterface, RecipeInterface from supertokens_python.normalised_url_domain import NormalisedURLDomain from supertokens_python.normalised_url_path import NormalisedURLPath +from .interfaces import APIInterface, RecipeInterface + + +class InputOverrideConfig(BaseInputOverrideConfig[RecipeInterface, APIInterface]): + jwt_feature: Union[JWTOverrideConfig, None] = None + + +class OverrideConfig(BaseOverrideConfig[RecipeInterface, APIInterface]): ... + + +class OpenIdInputConfig(BaseInputConfig[RecipeInterface, APIInterface]): + issuer: Union[str, None] = None + -class InputOverrideConfig: - def __init__( - self, - functions: Union[Callable[[RecipeInterface], RecipeInterface], None] = None, - apis: Union[Callable[[APIInterface], APIInterface], None] = None, - jwt_feature: Union[JWTOverrideConfig, None] = None, - ): - self.functions = functions - self.apis = apis - self.jwt_feature = jwt_feature - - -class OverrideConfig: - def __init__( - self, - functions: Union[Callable[[RecipeInterface], RecipeInterface], None] = None, - apis: Union[Callable[[APIInterface], APIInterface], None] = None, - ): - self.functions = functions - self.apis = apis - - -class OpenIdConfig: - def __init__( - self, - override: OverrideConfig, - issuer_domain: NormalisedURLDomain, - issuer_path: NormalisedURLPath, - ): - self.override = override - self.issuer_domain = issuer_domain - self.issuer_path = issuer_path +class OpenIdConfig(BaseConfig[RecipeInterface, APIInterface]): + issuer_domain: NormalisedURLDomain + issuer_path: NormalisedURLPath def validate_and_normalise_user_input( app_info: AppInfo, - issuer: Union[str, None] = None, - override: Union[InputOverrideConfig, None] = None, -): - if issuer is None: + input_config: OpenIdInputConfig, +) -> OpenIdConfig: + if input_config.issuer is None: issuer_domain = app_info.api_domain issuer_path = app_info.api_base_path else: - issuer_domain = NormalisedURLDomain(issuer) - issuer_path = NormalisedURLPath(issuer) + issuer_domain = NormalisedURLDomain(input_config.issuer) + issuer_path = NormalisedURLPath(input_config.issuer) if not issuer_path.equals(app_info.api_base_path): raise Exception( "The path of the issuer URL must be equal to the apiBasePath. The default value is /auth" ) - if override is not None and not isinstance(override, InputOverrideConfig): # type: ignore - raise ValueError("override must be an instance of InputOverrideConfig or None") + override_config = OverrideConfig() + if input_config.override is not None: + if input_config.override.functions is not None: + override_config.functions = input_config.override.functions - if override is None: - override = InputOverrideConfig() + if input_config.override.apis is not None: + override_config.apis = input_config.override.apis return OpenIdConfig( - OverrideConfig(functions=override.functions, apis=override.apis), - issuer_domain, - issuer_path, + issuer_domain=issuer_domain, + issuer_path=issuer_path, + override=override_config, ) diff --git a/supertokens_python/recipe/passwordless/__init__.py b/supertokens_python/recipe/passwordless/__init__.py index 4d0648384..bf53b9277 100644 --- a/supertokens_python/recipe/passwordless/__init__.py +++ b/supertokens_python/recipe/passwordless/__init__.py @@ -31,11 +31,9 @@ from .smsdelivery import services as smsdelivery_services if TYPE_CHECKING: - from supertokens_python.supertokens import AppInfo + from supertokens_python.supertokens import RecipeInit - from ...recipe_module import RecipeModule - -InputOverrideConfig = utils.OverrideConfig +InputOverrideConfig = utils.InputOverrideConfig ContactEmailOnlyConfig = utils.ContactEmailOnlyConfig ContactConfig = utils.ContactConfig PhoneOrEmailInput = utils.PhoneOrEmailInput @@ -63,7 +61,7 @@ def init( ] = None, email_delivery: Union[EmailDeliveryConfig[EmailTemplateVars], None] = None, sms_delivery: Union[SMSDeliveryConfig[SMSTemplateVars], None] = None, -) -> Callable[[AppInfo], RecipeModule]: +) -> RecipeInit: return PasswordlessRecipe.init( contact_config, flow_type, diff --git a/supertokens_python/recipe/passwordless/interfaces.py b/supertokens_python/recipe/passwordless/interfaces.py index aae9347cc..32f9e37fa 100644 --- a/supertokens_python/recipe/passwordless/interfaces.py +++ b/supertokens_python/recipe/passwordless/interfaces.py @@ -13,7 +13,7 @@ # under the License. from __future__ import annotations -from abc import ABC, abstractmethod +from abc import abstractmethod from typing import Any, Dict, List, Optional, Union from typing_extensions import Literal @@ -26,6 +26,7 @@ User, ) from supertokens_python.types.auth_utils import LinkingToSessionUserFailedError +from supertokens_python.types.recipe import BaseAPIInterface, BaseRecipeInterface from supertokens_python.types.response import APIResponse, GeneralErrorResponse from ...supertokens import AppInfo @@ -214,7 +215,7 @@ def __init__(self, reason: str): self.reason = reason -class RecipeInterface(ABC): +class RecipeInterface(BaseRecipeInterface): def __init__(self): pass @@ -504,7 +505,7 @@ def to_json(self) -> Dict[str, Any]: return {"status": self.status, "reason": self.reason} -class APIInterface: +class APIInterface(BaseAPIInterface): def __init__(self): self.disable_create_code_post = False self.disable_resend_code_post = False diff --git a/supertokens_python/recipe/passwordless/recipe.py b/supertokens_python/recipe/passwordless/recipe.py index a332cb019..fc24aeca0 100644 --- a/supertokens_python/recipe/passwordless/recipe.py +++ b/supertokens_python/recipe/passwordless/recipe.py @@ -22,6 +22,7 @@ from supertokens_python.ingredients.emaildelivery import EmailDeliveryIngredient from supertokens_python.ingredients.emaildelivery.types import EmailDeliveryConfig from supertokens_python.ingredients.smsdelivery import SMSDeliveryIngredient +from supertokens_python.plugins import OverrideMap, apply_plugins from supertokens_python.querier import Querier from supertokens_python.recipe.multifactorauth.recipe import MultiFactorAuthRecipe from supertokens_python.recipe.multifactorauth.types import ( @@ -72,7 +73,8 @@ from .recipe_implementation import RecipeImplementation from .utils import ( ContactConfig, - OverrideConfig, + InputOverrideConfig, + PasswordlessInputConfig, get_enabled_pwless_factors, validate_and_normalise_user_input, ) @@ -98,46 +100,22 @@ def __init__( self, recipe_id: str, app_info: AppInfo, - contact_config: ContactConfig, - flow_type: Literal[ - "USER_INPUT_CODE", "MAGIC_LINK", "USER_INPUT_CODE_AND_MAGIC_LINK" - ], ingredients: PasswordlessIngredients, - override: Union[OverrideConfig, None] = None, - get_custom_user_input_code: Union[ - Callable[[str, Dict[str, Any]], Awaitable[str]], None - ] = None, - email_delivery: Union[ - EmailDeliveryConfig[PasswordlessLoginEmailTemplateVars], None - ] = None, - sms_delivery: Union[ - SMSDeliveryConfig[PasswordlessLoginSMSTemplateVars], None - ] = None, + input_config: PasswordlessInputConfig, ): super().__init__(recipe_id, app_info) self.config = validate_and_normalise_user_input( - app_info, - contact_config, - flow_type, - override, - get_custom_user_input_code, - email_delivery, - sms_delivery, + app_info=app_info, + input_config=input_config, ) recipe_implementation = RecipeImplementation(Querier.get_instance(recipe_id)) - self.recipe_implementation: RecipeInterface = ( + self.recipe_implementation: RecipeInterface = self.config.override.functions( recipe_implementation - if self.config.override.functions is None - else self.config.override.functions(recipe_implementation) ) api_implementation = APIImplementation() - self.api_implementation = ( - api_implementation - if self.config.override.apis is None - else self.config.override.apis(api_implementation) - ) + self.api_implementation = self.config.override.apis(api_implementation) email_delivery_ingredient = ingredients.email_delivery if email_delivery_ingredient is None: @@ -508,7 +486,7 @@ def init( flow_type: Literal[ "USER_INPUT_CODE", "MAGIC_LINK", "USER_INPUT_CODE_AND_MAGIC_LINK" ], - override: Union[OverrideConfig, None] = None, + override: Union[InputOverrideConfig, None] = None, get_custom_user_input_code: Union[ Callable[[str, Dict[str, Any]], Awaitable[str]], None ] = None, @@ -519,19 +497,27 @@ def init( SMSDeliveryConfig[PasswordlessLoginSMSTemplateVars], None ] = None, ): - def func(app_info: AppInfo): + input_config = PasswordlessInputConfig( + contact_config=contact_config, + get_custom_user_input_code=get_custom_user_input_code, + email_delivery=email_delivery, + sms_delivery=sms_delivery, + flow_type=flow_type, + override=override, + ) + + def func(app_info: AppInfo, plugins: List[OverrideMap]): if PasswordlessRecipe.__instance is None: ingredients = PasswordlessIngredients(None, None) PasswordlessRecipe.__instance = PasswordlessRecipe( - PasswordlessRecipe.recipe_id, - app_info, - contact_config, - flow_type, - ingredients, - override, - get_custom_user_input_code, - email_delivery, - sms_delivery, + recipe_id=PasswordlessRecipe.recipe_id, + app_info=app_info, + ingredients=ingredients, + input_config=apply_plugins( + recipe_id=PasswordlessRecipe.recipe_id, + config=input_config, + plugins=plugins, + ), ) return PasswordlessRecipe.__instance raise_general_exception( diff --git a/supertokens_python/recipe/passwordless/utils.py b/supertokens_python/recipe/passwordless/utils.py index b61615de5..5f4901ab1 100644 --- a/supertokens_python/recipe/passwordless/utils.py +++ b/supertokens_python/recipe/passwordless/utils.py @@ -15,8 +15,10 @@ from __future__ import annotations from abc import ABC +from re import fullmatch from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Union +from phonenumbers import is_valid_number, parse from typing_extensions import Literal from supertokens_python.ingredients.emaildelivery.types import ( @@ -28,29 +30,30 @@ SMSDeliveryConfigWithService, ) from supertokens_python.recipe.multifactorauth.types import FactorIds -from supertokens_python.recipe.passwordless.types import ( - PasswordlessLoginSMSTemplateVars, -) - -if TYPE_CHECKING: - from supertokens_python import AppInfo - - from .interfaces import ( - APIInterface, - PasswordlessLoginEmailTemplateVars, - RecipeInterface, - ) - -from re import fullmatch - -from phonenumbers import is_valid_number, parse # type: ignore - from supertokens_python.recipe.passwordless.emaildelivery.services.backward_compatibility import ( BackwardCompatibilityService, ) from supertokens_python.recipe.passwordless.smsdelivery.services.backward_compatibility import ( BackwardCompatibilityService as SMSBackwardCompatibilityService, ) +from supertokens_python.recipe.passwordless.types import ( + PasswordlessLoginSMSTemplateVars, +) +from supertokens_python.types.config import ( + BaseConfig, + BaseInputConfig, + BaseInputOverrideConfig, + BaseOverrideConfig, +) + +from .interfaces import ( + APIInterface, + PasswordlessLoginEmailTemplateVars, + RecipeInterface, +) + +if TYPE_CHECKING: + from supertokens_python import AppInfo async def default_validate_phone_number(value: str, _tenant_id: str): @@ -68,14 +71,10 @@ async def default_validate_email(value: str, _tenant_id: str): return "Email is invalid" -class OverrideConfig: - def __init__( - self, - functions: Union[Callable[[RecipeInterface], RecipeInterface], None] = None, - apis: Union[Callable[[APIInterface], APIInterface], None] = None, - ): - self.functions = functions - self.apis = apis +class InputOverrideConfig(BaseInputOverrideConfig[RecipeInterface, APIInterface]): ... + + +class OverrideConfig(BaseOverrideConfig[RecipeInterface, APIInterface]): ... class ContactConfig(ABC): @@ -142,64 +141,67 @@ def __init__(self, phone_number: Union[str, None], email: Union[str, None]): self.email = email -class PasswordlessConfig: - def __init__( - self, - contact_config: ContactConfig, - override: OverrideConfig, - flow_type: Literal[ - "USER_INPUT_CODE", "MAGIC_LINK", "USER_INPUT_CODE_AND_MAGIC_LINK" - ], - get_email_delivery_config: Callable[ - [], EmailDeliveryConfigWithService[PasswordlessLoginEmailTemplateVars] - ], - get_sms_delivery_config: Callable[ - [], SMSDeliveryConfigWithService[PasswordlessLoginSMSTemplateVars] - ], - get_custom_user_input_code: Union[ - Callable[[str, Dict[str, Any]], Awaitable[str]], None - ] = None, - ): - self.contact_config = contact_config - self.override = override - self.flow_type: Literal[ - "USER_INPUT_CODE", "MAGIC_LINK", "USER_INPUT_CODE_AND_MAGIC_LINK" - ] = flow_type - self.get_custom_user_input_code = get_custom_user_input_code - self.get_email_delivery_config = get_email_delivery_config - self.get_sms_delivery_config = get_sms_delivery_config - - -def validate_and_normalise_user_input( - app_info: AppInfo, - contact_config: ContactConfig, +class PasswordlessInputConfig(BaseInputConfig[RecipeInterface, APIInterface]): + contact_config: ContactConfig flow_type: Literal[ "USER_INPUT_CODE", "MAGIC_LINK", "USER_INPUT_CODE_AND_MAGIC_LINK" - ], - override: Union[OverrideConfig, None] = None, + ] get_custom_user_input_code: Union[ Callable[[str, Dict[str, Any]], Awaitable[str]], None - ] = None, + ] = None email_delivery: Union[ EmailDeliveryConfig[PasswordlessLoginEmailTemplateVars], None - ] = None, - sms_delivery: Union[ - SMSDeliveryConfig[PasswordlessLoginSMSTemplateVars], None - ] = None, + ] = None + sms_delivery: Union[SMSDeliveryConfig[PasswordlessLoginSMSTemplateVars], None] = ( + None + ) + + +class PasswordlessConfig(BaseConfig[RecipeInterface, APIInterface]): + contact_config: ContactConfig + flow_type: Literal[ + "USER_INPUT_CODE", "MAGIC_LINK", "USER_INPUT_CODE_AND_MAGIC_LINK" + ] + get_email_delivery_config: Callable[ + [], EmailDeliveryConfigWithService[PasswordlessLoginEmailTemplateVars] + ] + get_sms_delivery_config: Callable[ + [], SMSDeliveryConfigWithService[PasswordlessLoginSMSTemplateVars] + ] + get_custom_user_input_code: Union[ + Callable[[str, Dict[str, Any]], Awaitable[str]], None + ] + + +def validate_and_normalise_user_input( + app_info: AppInfo, + input_config: PasswordlessInputConfig, ) -> PasswordlessConfig: - if override is None: - override = OverrideConfig() + override_config = OverrideConfig() + if input_config.override is not None: + if input_config.override.functions is not None: + override_config.functions = input_config.override.functions + + if input_config.override.apis is not None: + override_config.apis = input_config.override.apis def get_email_delivery_config() -> EmailDeliveryConfigWithService[ PasswordlessLoginEmailTemplateVars ]: - email_service = email_delivery.service if email_delivery is not None else None + email_service = ( + input_config.email_delivery.service + if input_config.email_delivery is not None + else None + ) if email_service is None: email_service = BackwardCompatibilityService(app_info) - if email_delivery is not None and email_delivery.override is not None: - override = email_delivery.override + if ( + input_config.email_delivery is not None + and input_config.email_delivery.override is not None + ): + override = input_config.email_delivery.override else: override = None @@ -208,22 +210,29 @@ def get_email_delivery_config() -> EmailDeliveryConfigWithService[ def get_sms_delivery_config() -> SMSDeliveryConfigWithService[ PasswordlessLoginSMSTemplateVars ]: - sms_service = sms_delivery.service if sms_delivery is not None else None + sms_service = ( + input_config.sms_delivery.service + if input_config.sms_delivery is not None + else None + ) if sms_service is None: sms_service = SMSBackwardCompatibilityService(app_info) - if sms_delivery is not None and sms_delivery.override is not None: - override = sms_delivery.override + if ( + input_config.sms_delivery is not None + and input_config.sms_delivery.override is not None + ): + override = input_config.sms_delivery.override else: override = None return SMSDeliveryConfigWithService(sms_service, override=override) - if not isinstance(contact_config, ContactConfig): # type: ignore user might not have linter enabled + if not isinstance(input_config.contact_config, ContactConfig): # type: ignore user might not have linter enabled raise ValueError("contact_config must be of type ContactConfig") - if flow_type not in [ + if input_config.flow_type not in [ "USER_INPUT_CODE", "MAGIC_LINK", "USER_INPUT_CODE_AND_MAGIC_LINK", @@ -232,16 +241,13 @@ def get_sms_delivery_config() -> SMSDeliveryConfigWithService[ "flow_type must be one of USER_INPUT_CODE, MAGIC_LINK, USER_INPUT_CODE_AND_MAGIC_LINK" ) - if not isinstance(override, OverrideConfig): # type: ignore user might not have linter enabled - raise ValueError("override must be of type OverrideConfig") - return PasswordlessConfig( - contact_config=contact_config, - override=OverrideConfig(functions=override.functions, apis=override.apis), - flow_type=flow_type, + contact_config=input_config.contact_config, + override=override_config, + flow_type=input_config.flow_type, get_email_delivery_config=get_email_delivery_config, get_sms_delivery_config=get_sms_delivery_config, - get_custom_user_input_code=get_custom_user_input_code, + get_custom_user_input_code=input_config.get_custom_user_input_code, ) diff --git a/supertokens_python/recipe/session/__init__.py b/supertokens_python/recipe/session/__init__.py index 5e4808f3e..7692e160b 100644 --- a/supertokens_python/recipe/session/__init__.py +++ b/supertokens_python/recipe/session/__init__.py @@ -18,9 +18,8 @@ from typing_extensions import Literal if TYPE_CHECKING: - from supertokens_python.supertokens import AppInfo, BaseRequest + from supertokens_python.supertokens import BaseRequest, RecipeInit - from ...recipe_module import RecipeModule from .utils import TokenTransferMethod from . import exceptions as ex @@ -53,7 +52,7 @@ def init( use_dynamic_access_token_signing_key: Union[bool, None] = None, expose_access_token_to_frontend_in_cookie_based_auth: Union[bool, None] = None, jwks_refresh_interval_sec: Union[int, None] = None, -) -> Callable[[AppInfo], RecipeModule]: +) -> RecipeInit: return SessionRecipe.init( cookie_domain, older_cookie_domain, diff --git a/supertokens_python/recipe/session/interfaces.py b/supertokens_python/recipe/session/interfaces.py index 30cf85f54..2fe1ec756 100644 --- a/supertokens_python/recipe/session/interfaces.py +++ b/supertokens_python/recipe/session/interfaces.py @@ -33,6 +33,7 @@ MaybeAwaitable, RecipeUserId, ) +from supertokens_python.types.recipe import BaseAPIInterface, BaseRecipeInterface from supertokens_python.types.response import APIResponse, GeneralErrorResponse from ...utils import resolve @@ -187,7 +188,7 @@ class GetSessionTokensDangerouslyDict(TypedDict): antiCsrfToken: Optional[str] -class RecipeInterface(ABC): # pylint: disable=too-many-public-methods +class RecipeInterface(BaseRecipeInterface): # pylint: disable=too-many-public-methods def __init__(self): pass @@ -383,7 +384,7 @@ def __init__( self.recipe_implementation = recipe_implementation -class APIInterface(ABC): +class APIInterface(BaseAPIInterface): def __init__(self): self.disable_refresh_post = False self.disable_signout_post = False diff --git a/supertokens_python/recipe/session/recipe.py b/supertokens_python/recipe/session/recipe.py index 5d025026a..d4476ca69 100644 --- a/supertokens_python/recipe/session/recipe.py +++ b/supertokens_python/recipe/session/recipe.py @@ -19,6 +19,7 @@ from typing_extensions import Literal from supertokens_python.framework.response import BaseResponse +from supertokens_python.plugins import OverrideMap, apply_plugins from ...types import MaybeAwaitable from .cookie_and_header import ( @@ -59,6 +60,7 @@ from .utils import ( InputErrorHandlers, InputOverrideConfig, + SessionInputConfig, TokenTransferMethod, validate_and_normalise_user_input, ) @@ -72,44 +74,12 @@ def __init__( self, recipe_id: str, app_info: AppInfo, - cookie_domain: Union[str, None] = None, - older_cookie_domain: Union[str, None] = None, - cookie_secure: Union[bool, None] = None, - cookie_same_site: Union[Literal["lax", "none", "strict"], None] = None, - session_expired_status_code: Union[int, None] = None, - anti_csrf: Union[ - Literal["VIA_TOKEN", "VIA_CUSTOM_HEADER", "NONE"], None - ] = None, - get_token_transfer_method: Union[ - Callable[ - [BaseRequest, bool, Dict[str, Any]], - Union[TokenTransferMethod, Literal["any"]], - ], - None, - ] = None, - error_handlers: Union[InputErrorHandlers, None] = None, - override: Union[InputOverrideConfig, None] = None, - invalid_claim_status_code: Union[int, None] = None, - use_dynamic_access_token_signing_key: Union[bool, None] = None, - expose_access_token_to_frontend_in_cookie_based_auth: Union[bool, None] = None, - jwks_refresh_interval_sec: Union[int, None] = None, + input_config: SessionInputConfig, ): super().__init__(recipe_id, app_info) self.config = validate_and_normalise_user_input( - app_info, - cookie_domain, - older_cookie_domain, - cookie_secure, - cookie_same_site, - session_expired_status_code, - anti_csrf, - get_token_transfer_method, - error_handlers, - override, - invalid_claim_status_code, - use_dynamic_access_token_signing_key, - expose_access_token_to_frontend_in_cookie_based_auth, - jwks_refresh_interval_sec, + app_info=app_info, + input_config=input_config, ) log_debug_message( "session init: anti_csrf: %s", self.config.anti_csrf_function_or_string @@ -123,8 +93,10 @@ def __init__( # we check the input cookie_same_site because the normalised version is # always a function. - if cookie_same_site is not None: - log_debug_message("session init: cookie_same_site: %s", cookie_same_site) + if input_config.cookie_same_site is not None: + log_debug_message( + "session init: cookie_same_site: %s", input_config.cookie_same_site + ) else: log_debug_message("session init: cookie_same_site: function") @@ -142,19 +114,15 @@ def __init__( recipe_implementation = RecipeImplementation( Querier.get_instance(recipe_id), self.config, self.app_info ) - self.recipe_implementation: RecipeInterface = ( + self.recipe_implementation: RecipeInterface = self.config.override.functions( recipe_implementation - if self.config.override.functions is None - else self.config.override.functions(recipe_implementation) ) from .api.implementation import APIImplementation api_implementation = APIImplementation() - self.api_implementation: APIInterface = ( + self.api_implementation: APIInterface = self.config.override.apis( api_implementation - if self.config.override.apis is None - else self.config.override.apis(api_implementation) ) self.claims_added_by_other_recipes: List[SessionClaim[Any]] = [] @@ -296,24 +264,32 @@ def init( expose_access_token_to_frontend_in_cookie_based_auth: Union[bool, None] = None, jwks_refresh_interval_sec: Union[int, None] = None, ): - def func(app_info: AppInfo): + input_config = SessionInputConfig( + cookie_domain=cookie_domain, + older_cookie_domain=older_cookie_domain, + cookie_secure=cookie_secure, + cookie_same_site=cookie_same_site, + session_expired_status_code=session_expired_status_code, + anti_csrf=anti_csrf, + get_token_transfer_method=get_token_transfer_method, + error_handlers=error_handlers, + override=override, + invalid_claim_status_code=invalid_claim_status_code, + use_dynamic_access_token_signing_key=use_dynamic_access_token_signing_key, + expose_access_token_to_frontend_in_cookie_based_auth=expose_access_token_to_frontend_in_cookie_based_auth, + jwks_refresh_interval_sec=jwks_refresh_interval_sec, + ) + + def func(app_info: AppInfo, plugins: List[OverrideMap]): if SessionRecipe.__instance is None: SessionRecipe.__instance = SessionRecipe( - SessionRecipe.recipe_id, - app_info, - cookie_domain, - older_cookie_domain, - cookie_secure, - cookie_same_site, - session_expired_status_code, - anti_csrf, - get_token_transfer_method, - error_handlers, - override, - invalid_claim_status_code, - use_dynamic_access_token_signing_key, - expose_access_token_to_frontend_in_cookie_based_auth, - jwks_refresh_interval_sec, + recipe_id=SessionRecipe.recipe_id, + app_info=app_info, + input_config=apply_plugins( + recipe_id=SessionRecipe.recipe_id, + config=input_config, + plugins=plugins, + ), ) return SessionRecipe.__instance raise_general_exception( diff --git a/supertokens_python/recipe/session/utils.py b/supertokens_python/recipe/session/utils.py index bae607c22..f4d137c90 100644 --- a/supertokens_python/recipe/session/utils.py +++ b/supertokens_python/recipe/session/utils.py @@ -22,6 +22,12 @@ from supertokens_python.exceptions import raise_general_exception from supertokens_python.framework import BaseResponse from supertokens_python.normalised_url_path import NormalisedURLPath +from supertokens_python.types.config import ( + BaseConfig, + BaseInputConfig, + BaseInputOverrideConfig, + BaseOverrideConfig, +) from supertokens_python.utils import ( is_an_ip_address, resolve, @@ -33,6 +39,12 @@ from ...types import MaybeAwaitable, RecipeUserId from .constants import AUTH_MODE_HEADER_KEY, SESSION_REFRESH from .exceptions import ClaimValidationError +from .interfaces import ( + APIInterface, + RecipeInterface, + SessionClaimValidator, + SessionContainer, +) if TYPE_CHECKING: from supertokens_python.framework import BaseRequest @@ -41,12 +53,6 @@ ) from supertokens_python.supertokens import AppInfo - from .interfaces import ( - APIInterface, - RecipeInterface, - SessionClaimValidator, - SessionContainer, - ) from .recipe import SessionRecipe from supertokens_python.logger import log_debug_message @@ -334,141 +340,116 @@ def get_token_transfer_method_default( return "any" -class InputOverrideConfig: - def __init__( - self, - functions: Union[Callable[[RecipeInterface], RecipeInterface], None] = None, - apis: Union[Callable[[APIInterface], APIInterface], None] = None, - openid_feature: Union[OpenIdInputOverrideConfig, None] = None, - ): - self.functions = functions - self.apis = apis - self.openid_feature = openid_feature +class InputOverrideConfig(BaseInputOverrideConfig[RecipeInterface, APIInterface]): + openid_feature: Optional[OpenIdInputOverrideConfig] = None -class OverrideConfig: - def __init__( - self, - functions: Union[Callable[[RecipeInterface], RecipeInterface], None] = None, - apis: Union[Callable[[APIInterface], APIInterface], None] = None, - ): - self.functions = functions - self.apis = apis +class OverrideConfig(BaseOverrideConfig[RecipeInterface, APIInterface]): ... TokenType = Literal["access", "refresh"] TokenTransferMethod = Literal["cookie", "header"] -class SessionConfig: - def __init__( - self, - refresh_token_path: NormalisedURLPath, - cookie_domain: Union[None, str], - older_cookie_domain: Union[None, str], - get_cookie_same_site: Callable[ - [Optional[BaseRequest], Dict[str, Any]], - Literal["lax", "strict", "none"], - ], - cookie_secure: bool, - session_expired_status_code: int, - error_handlers: ErrorHandlers, - anti_csrf_function_or_string: Union[ - Callable[ - [Optional[BaseRequest], Dict[str, Any]], - Literal["VIA_CUSTOM_HEADER", "NONE"], - ], - Literal["VIA_CUSTOM_HEADER", "NONE", "VIA_TOKEN"], - ], - get_token_transfer_method: Callable[ +class SessionInputConfig(BaseInputConfig[RecipeInterface, APIInterface]): + cookie_domain: Union[str, None] = None + older_cookie_domain: Union[str, None] = None + cookie_secure: Union[bool, None] = None + cookie_same_site: Union[Literal["lax", "strict", "none"], None] = None + session_expired_status_code: Union[int, None] = None + anti_csrf: Union[Literal["VIA_TOKEN", "VIA_CUSTOM_HEADER", "NONE"], None] = None + get_token_transfer_method: Union[ + Callable[ [BaseRequest, bool, Dict[str, Any]], Union[TokenTransferMethod, Literal["any"]], ], - override: OverrideConfig, - framework: str, - mode: str, - invalid_claim_status_code: int, - use_dynamic_access_token_signing_key: bool, - expose_access_token_to_frontend_in_cookie_based_auth: bool, - jwks_refresh_interval_sec: int, - ): - self.session_expired_status_code = session_expired_status_code - self.invalid_claim_status_code = invalid_claim_status_code - self.use_dynamic_access_token_signing_key = use_dynamic_access_token_signing_key - self.expose_access_token_to_frontend_in_cookie_based_auth = ( - expose_access_token_to_frontend_in_cookie_based_auth - ) - - self.refresh_token_path = refresh_token_path - self.cookie_domain = cookie_domain - self.older_cookie_domain = older_cookie_domain - self.get_cookie_same_site = get_cookie_same_site - self.cookie_secure = cookie_secure - self.error_handlers = error_handlers - self.anti_csrf_function_or_string = anti_csrf_function_or_string - self.get_token_transfer_method = get_token_transfer_method - self.override = override - self.framework = framework - self.mode = mode - self.jwks_refresh_interval_sec = jwks_refresh_interval_sec + None, + ] = None + error_handlers: Union[ErrorHandlers, None] = None + invalid_claim_status_code: Union[int, None] = None + use_dynamic_access_token_signing_key: Union[bool, None] = None + expose_access_token_to_frontend_in_cookie_based_auth: Union[bool, None] = None + jwks_refresh_interval_sec: Union[int, None] = None + + +class SessionConfig(BaseConfig[RecipeInterface, APIInterface]): + refresh_token_path: NormalisedURLPath + cookie_domain: Union[None, str] + older_cookie_domain: Union[None, str] + get_cookie_same_site: Callable[ + [Optional[BaseRequest], Dict[str, Any]], + Literal["lax", "strict", "none"], + ] + cookie_secure: bool + session_expired_status_code: int + error_handlers: ErrorHandlers + anti_csrf_function_or_string: Union[ + Callable[ + [Optional[BaseRequest], Dict[str, Any]], + Literal["VIA_CUSTOM_HEADER", "NONE"], + ], + Literal["VIA_CUSTOM_HEADER", "NONE", "VIA_TOKEN"], + ] + get_token_transfer_method: Callable[ + [BaseRequest, bool, Dict[str, Any]], + Union[TokenTransferMethod, Literal["any"]], + ] + # override: OverrideConfig, + framework: str + mode: str + invalid_claim_status_code: int + use_dynamic_access_token_signing_key: bool + expose_access_token_to_frontend_in_cookie_based_auth: bool + jwks_refresh_interval_sec: int def validate_and_normalise_user_input( app_info: AppInfo, - cookie_domain: Union[str, None] = None, - older_cookie_domain: Union[str, None] = None, - cookie_secure: Union[bool, None] = None, - cookie_same_site: Union[Literal["lax", "strict", "none"], None] = None, - session_expired_status_code: Union[int, None] = None, - anti_csrf: Union[Literal["VIA_TOKEN", "VIA_CUSTOM_HEADER", "NONE"], None] = None, - get_token_transfer_method: Union[ - Callable[ - [BaseRequest, bool, Dict[str, Any]], - Union[TokenTransferMethod, Literal["any"]], - ], - None, - ] = None, - error_handlers: Union[ErrorHandlers, None] = None, - override: Union[InputOverrideConfig, None] = None, - invalid_claim_status_code: Union[int, None] = None, - use_dynamic_access_token_signing_key: Union[bool, None] = None, - expose_access_token_to_frontend_in_cookie_based_auth: Union[bool, None] = None, - jwks_refresh_interval_sec: Union[int, None] = None, + input_config: SessionInputConfig, ): - _ = cookie_same_site # we have this otherwise pylint complains that cookie_same_site is unused, but it is being used in the get_cookie_same_site function. - if anti_csrf not in {"VIA_TOKEN", "VIA_CUSTOM_HEADER", "NONE", None}: + # _ = cookie_same_site # we have this otherwise pylint complains that cookie_same_site is unused, but it is being used in the get_cookie_same_site function. + if input_config.anti_csrf not in {"VIA_TOKEN", "VIA_CUSTOM_HEADER", "NONE", None}: raise ValueError( "anti_csrf must be one of VIA_TOKEN, VIA_CUSTOM_HEADER, NONE or None" ) - if error_handlers is not None and not isinstance(error_handlers, ErrorHandlers): # type: ignore + if input_config.error_handlers is not None and not isinstance( + input_config.error_handlers, ErrorHandlers + ): # type: ignore raise ValueError("error_handlers must be an instance of ErrorHandlers or None") - if override is not None and not isinstance(override, InputOverrideConfig): # type: ignore - raise ValueError("override must be an instance of InputOverrideConfig or None") + # if override is not None and not isinstance(override, InputOverrideConfig): # type: ignore + # raise ValueError("override must be an instance of InputOverrideConfig or None") cookie_domain = ( - normalise_session_scope(cookie_domain) if cookie_domain is not None else None + normalise_session_scope(input_config.cookie_domain) + if input_config.cookie_domain is not None + else None ) older_cookie_domain = ( - older_cookie_domain - if older_cookie_domain is None or older_cookie_domain == "" - else normalise_session_scope(older_cookie_domain) + input_config.older_cookie_domain + if input_config.older_cookie_domain is None + or input_config.older_cookie_domain == "" + else normalise_session_scope(input_config.older_cookie_domain) ) cookie_secure = ( - cookie_secure - if cookie_secure is not None + input_config.cookie_secure + if input_config.cookie_secure is not None else app_info.api_domain.get_as_string_dangerous().startswith("https") ) session_expired_status_code = ( - session_expired_status_code if session_expired_status_code is not None else 401 + input_config.session_expired_status_code + if input_config.session_expired_status_code is not None + else 401 ) invalid_claim_status_code = ( - invalid_claim_status_code if invalid_claim_status_code is not None else 403 + input_config.invalid_claim_status_code + if input_config.invalid_claim_status_code is not None + else 403 ) if session_expired_status_code == invalid_claim_status_code: @@ -477,21 +458,27 @@ def validate_and_normalise_user_input( f"({invalid_claim_status_code})" ) + get_token_transfer_method = input_config.get_token_transfer_method if get_token_transfer_method is None: get_token_transfer_method = get_token_transfer_method_default + error_handlers = input_config.error_handlers if error_handlers is None: error_handlers = InputErrorHandlers() - if override is None: - override = InputOverrideConfig() - + use_dynamic_access_token_signing_key = ( + input_config.use_dynamic_access_token_signing_key + ) if use_dynamic_access_token_signing_key is None: use_dynamic_access_token_signing_key = True + expose_access_token_to_frontend_in_cookie_based_auth = ( + input_config.expose_access_token_to_frontend_in_cookie_based_auth + ) if expose_access_token_to_frontend_in_cookie_based_auth is None: expose_access_token_to_frontend_in_cookie_based_auth = False + cookie_same_site = input_config.cookie_same_site if cookie_same_site is not None: # this is just so that we check that the user has provided the right # values, since normalise_same_site throws an error if the user @@ -538,29 +525,42 @@ def anti_csrf_function( ], Literal["VIA_CUSTOM_HEADER", "NONE", "VIA_TOKEN"], ] = anti_csrf_function + + anti_csrf = input_config.anti_csrf if anti_csrf is not None: anti_csrf_function_or_string = anti_csrf + jwks_refresh_interval_sec = input_config.jwks_refresh_interval_sec if jwks_refresh_interval_sec is None: jwks_refresh_interval_sec = 4 * 3600 # 4 hours + override_config = OverrideConfig() + if input_config.override is not None: + if input_config.override.functions is not None: + override_config.functions = input_config.override.functions + + if input_config.override.apis is not None: + override_config.apis = input_config.override.apis + return SessionConfig( - app_info.api_base_path.append(NormalisedURLPath(SESSION_REFRESH)), - cookie_domain, - older_cookie_domain, - get_cookie_same_site, - cookie_secure, - session_expired_status_code, - error_handlers, - anti_csrf_function_or_string, - get_token_transfer_method, - OverrideConfig(override.functions, override.apis), - app_info.framework, - app_info.mode, - invalid_claim_status_code, - use_dynamic_access_token_signing_key, - expose_access_token_to_frontend_in_cookie_based_auth, - jwks_refresh_interval_sec, + refresh_token_path=app_info.api_base_path.append( + NormalisedURLPath(SESSION_REFRESH) + ), + cookie_domain=cookie_domain, + older_cookie_domain=older_cookie_domain, + get_cookie_same_site=get_cookie_same_site, + cookie_secure=cookie_secure, + session_expired_status_code=session_expired_status_code, + error_handlers=error_handlers, + anti_csrf_function_or_string=anti_csrf_function_or_string, + get_token_transfer_method=get_token_transfer_method, + override=override_config, + framework=app_info.framework, + mode=app_info.mode, + invalid_claim_status_code=invalid_claim_status_code, + use_dynamic_access_token_signing_key=use_dynamic_access_token_signing_key, + expose_access_token_to_frontend_in_cookie_based_auth=expose_access_token_to_frontend_in_cookie_based_auth, + jwks_refresh_interval_sec=jwks_refresh_interval_sec, ) diff --git a/supertokens_python/recipe/thirdparty/__init__.py b/supertokens_python/recipe/thirdparty/__init__.py index 9e20d78e1..9f60e17c9 100644 --- a/supertokens_python/recipe/thirdparty/__init__.py +++ b/supertokens_python/recipe/thirdparty/__init__.py @@ -14,7 +14,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Callable, Optional, Union +from typing import TYPE_CHECKING, Optional, Union from . import exceptions as ex from . import provider, utils @@ -28,15 +28,13 @@ exceptions = ex if TYPE_CHECKING: - from supertokens_python.supertokens import AppInfo - - from ...recipe_module import RecipeModule + from supertokens_python.supertokens import RecipeInit def init( sign_in_and_up_feature: Optional[SignInAndUpFeature] = None, override: Union[InputOverrideConfig, None] = None, -) -> Callable[[AppInfo], RecipeModule]: +) -> RecipeInit: if sign_in_and_up_feature is None: sign_in_and_up_feature = SignInAndUpFeature() return ThirdPartyRecipe.init(sign_in_and_up_feature, override) diff --git a/supertokens_python/recipe/thirdparty/interfaces.py b/supertokens_python/recipe/thirdparty/interfaces.py index db8012b39..181b3669e 100644 --- a/supertokens_python/recipe/thirdparty/interfaces.py +++ b/supertokens_python/recipe/thirdparty/interfaces.py @@ -13,9 +13,11 @@ # under the License. from __future__ import annotations -from abc import ABC, abstractmethod +from abc import abstractmethod from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union +from supertokens_python.types.recipe import BaseAPIInterface, BaseRecipeInterface + from ...types import RecipeUserId, User from ...types.response import APIResponse, GeneralErrorResponse from .provider import Provider, ProviderInput, RedirectUriInfo @@ -79,7 +81,7 @@ def __init__(self, reason: str): self.reason = reason -class RecipeInterface(ABC): +class RecipeInterface(BaseRecipeInterface): def __init__(self): pass @@ -198,7 +200,7 @@ def to_json(self): } -class APIInterface: +class APIInterface(BaseAPIInterface): def __init__(self): self.disable_sign_in_up_post = False self.disable_authorisation_url_get = False diff --git a/supertokens_python/recipe/thirdparty/recipe.py b/supertokens_python/recipe/thirdparty/recipe.py index b3ae12be6..d0a40d4ab 100644 --- a/supertokens_python/recipe/thirdparty/recipe.py +++ b/supertokens_python/recipe/thirdparty/recipe.py @@ -17,6 +17,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Union from supertokens_python.normalised_url_path import NormalisedURLPath +from supertokens_python.plugins import OverrideMap, apply_plugins from supertokens_python.querier import Querier from supertokens_python.recipe_module import APIHandled, RecipeModule @@ -43,7 +44,7 @@ from .constants import APPLE_REDIRECT_HANDLER, AUTHORISATIONURL, SIGNINUP from .exceptions import SuperTokensThirdPartyError from .types import ThirdPartyIngredients -from .utils import validate_and_normalise_user_input +from .utils import ThirdPartyInputConfig, validate_and_normalise_user_input class ThirdPartyRecipe(RecipeModule): @@ -54,29 +55,21 @@ def __init__( self, recipe_id: str, app_info: AppInfo, - sign_in_and_up_feature: SignInAndUpFeature, + input_config: ThirdPartyInputConfig, _ingredients: ThirdPartyIngredients, - override: Union[InputOverrideConfig, None] = None, ): super().__init__(recipe_id, app_info) - self.config = validate_and_normalise_user_input( - sign_in_and_up_feature, - override, - ) + self.config = validate_and_normalise_user_input(input_config=input_config) self.providers = self.config.sign_in_and_up_feature.providers recipe_implementation = RecipeImplementation( Querier.get_instance(recipe_id), self.providers ) - self.recipe_implementation: RecipeInterface = ( + self.recipe_implementation: RecipeInterface = self.config.override.functions( recipe_implementation - if self.config.override.functions is None - else self.config.override.functions(recipe_implementation) ) api_implementation = APIImplementation() - self.api_implementation: APIInterface = ( + self.api_implementation: APIInterface = self.config.override.apis( api_implementation - if self.config.override.apis is None - else self.config.override.apis(api_implementation) ) def callback(): @@ -167,15 +160,23 @@ def init( sign_in_and_up_feature: SignInAndUpFeature, override: Union[InputOverrideConfig, None] = None, ): - def func(app_info: AppInfo): + input_config = ThirdPartyInputConfig( + sign_in_and_up_feature=sign_in_and_up_feature, + override=override, + ) + + def func(app_info: AppInfo, plugins: List[OverrideMap]): if ThirdPartyRecipe.__instance is None: ingredients = ThirdPartyIngredients() ThirdPartyRecipe.__instance = ThirdPartyRecipe( - ThirdPartyRecipe.recipe_id, - app_info, - sign_in_and_up_feature, - ingredients, - override, + recipe_id=ThirdPartyRecipe.recipe_id, + app_info=app_info, + _ingredients=ingredients, + input_config=apply_plugins( + recipe_id=ThirdPartyRecipe.recipe_id, + config=input_config, + plugins=plugins, + ), ) return ThirdPartyRecipe.__instance raise_general_exception( diff --git a/supertokens_python/recipe/thirdparty/utils.py b/supertokens_python/recipe/thirdparty/utils.py index d556d20ac..0a4479fdd 100644 --- a/supertokens_python/recipe/thirdparty/utils.py +++ b/supertokens_python/recipe/thirdparty/utils.py @@ -13,10 +13,16 @@ # under the License. from __future__ import annotations -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set from supertokens_python.exceptions import raise_bad_input_exception from supertokens_python.recipe.thirdparty.provider import ProviderInput +from supertokens_python.types.config import ( + BaseConfig, + BaseInputConfig, + BaseInputOverrideConfig, + BaseOverrideConfig, +) from .interfaces import APIInterface, RecipeInterface @@ -46,54 +52,39 @@ def __init__(self, providers: Optional[List[ProviderInput]] = None): self.providers = providers -class InputOverrideConfig: - def __init__( - self, - functions: Union[Callable[[RecipeInterface], RecipeInterface], None] = None, - apis: Union[Callable[[APIInterface], APIInterface], None] = None, - ): - self.functions = functions - self.apis = apis +class InputOverrideConfig(BaseInputOverrideConfig[RecipeInterface, APIInterface]): ... -class OverrideConfig: - def __init__( - self, - functions: Union[Callable[[RecipeInterface], RecipeInterface], None] = None, - apis: Union[Callable[[APIInterface], APIInterface], None] = None, - ): - self.functions = functions - self.apis = apis +class OverrideConfig(BaseOverrideConfig[RecipeInterface, APIInterface]): ... -class ThirdPartyConfig: - def __init__( - self, - sign_in_and_up_feature: SignInAndUpFeature, - override: OverrideConfig, - ): - self.sign_in_and_up_feature = sign_in_and_up_feature - self.override = override +class ThirdPartyInputConfig(BaseInputConfig[RecipeInterface, APIInterface]): + sign_in_and_up_feature: SignInAndUpFeature + + +class ThirdPartyConfig(BaseConfig[RecipeInterface, APIInterface]): + sign_in_and_up_feature: SignInAndUpFeature def validate_and_normalise_user_input( - sign_in_and_up_feature: SignInAndUpFeature, - override: Union[InputOverrideConfig, None] = None, + input_config: ThirdPartyInputConfig, ) -> ThirdPartyConfig: - if not isinstance(sign_in_and_up_feature, SignInAndUpFeature): # type: ignore + if not isinstance(input_config.sign_in_and_up_feature, SignInAndUpFeature): # type: ignore raise ValueError( "sign_in_and_up_feature must be an instance of SignInAndUpFeature" ) - if override is not None and not isinstance(override, InputOverrideConfig): # type: ignore - raise ValueError("override must be an instance of InputOverrideConfig or None") + override_config = OverrideConfig() + if input_config.override is not None: + if input_config.override.functions is not None: + override_config.functions = input_config.override.functions - if override is None: - override = InputOverrideConfig() + if input_config.override.apis is not None: + override_config.apis = input_config.override.apis return ThirdPartyConfig( - sign_in_and_up_feature, - OverrideConfig(functions=override.functions, apis=override.apis), + sign_in_and_up_feature=input_config.sign_in_and_up_feature, + override=override_config, ) diff --git a/supertokens_python/recipe/totp/__init__.py b/supertokens_python/recipe/totp/__init__.py index f89944688..08d313d64 100644 --- a/supertokens_python/recipe/totp/__init__.py +++ b/supertokens_python/recipe/totp/__init__.py @@ -13,21 +13,19 @@ # under the License. from __future__ import annotations -from typing import TYPE_CHECKING, Callable, Union +from typing import TYPE_CHECKING, Union from supertokens_python.recipe.totp.types import TOTPConfig from .recipe import TOTPRecipe if TYPE_CHECKING: - from supertokens_python.supertokens import AppInfo - - from ...recipe_module import RecipeModule + from supertokens_python.supertokens import RecipeInit def init( config: Union[TOTPConfig, None] = None, -) -> Callable[[AppInfo], RecipeModule]: +) -> RecipeInit: return TOTPRecipe.init( config=config, ) diff --git a/supertokens_python/recipe/totp/interfaces.py b/supertokens_python/recipe/totp/interfaces.py index bd7f36249..9ea482110 100644 --- a/supertokens_python/recipe/totp/interfaces.py +++ b/supertokens_python/recipe/totp/interfaces.py @@ -14,9 +14,11 @@ from __future__ import annotations -from abc import ABC, abstractmethod +from abc import abstractmethod from typing import TYPE_CHECKING, Any, Dict, Optional, Union +from supertokens_python.types.recipe import BaseAPIInterface, BaseRecipeInterface + if TYPE_CHECKING: from supertokens_python import AppInfo from supertokens_python.framework import BaseRequest, BaseResponse @@ -42,7 +44,7 @@ ) -class RecipeInterface(ABC): +class RecipeInterface(BaseRecipeInterface): @abstractmethod async def get_user_identifier_info_for_user_id( self, user_id: str, user_context: Dict[str, Any] @@ -143,7 +145,7 @@ def __init__( self.recipe_instance = recipe_instance -class APIInterface(ABC): +class APIInterface(BaseAPIInterface): def __init__(self): self.disable_create_device_post = False self.disable_list_devices_get = False diff --git a/supertokens_python/recipe/totp/recipe.py b/supertokens_python/recipe/totp/recipe.py index e4bc3f3b5..ea2013862 100644 --- a/supertokens_python/recipe/totp/recipe.py +++ b/supertokens_python/recipe/totp/recipe.py @@ -73,17 +73,13 @@ def __init__( recipe_implementation = RecipeImplementation( Querier.get_instance(recipe_id), self.config ) - self.recipe_implementation: RecipeInterface = ( + self.recipe_implementation: RecipeInterface = self.config.override.functions( recipe_implementation - if self.config.override.functions is None - else self.config.override.functions(recipe_implementation) ) api_implementation = APIImplementation() - self.api_implementation: APIInterface = ( + self.api_implementation: APIInterface = self.config.override.apis( api_implementation - if self.config.override.apis is None - else self.config.override.apis(api_implementation) ) def callback(): @@ -212,10 +208,12 @@ def func(app_info: AppInfo, plugins: Optional[List[OverrideMap]] = None): if TOTPRecipe.__instance is None: TOTPRecipe.__instance = TOTPRecipe( - TOTPRecipe.recipe_id, - app_info, - apply_plugins( - recipe_id=TOTPRecipe.recipe_id, config=config, plugins=plugins + recipe_id=TOTPRecipe.recipe_id, + app_info=app_info, + config=apply_plugins( + recipe_id=TOTPRecipe.recipe_id, + config=config, + plugins=plugins, ), ) return TOTPRecipe.__instance diff --git a/supertokens_python/recipe/totp/types.py b/supertokens_python/recipe/totp/types.py index d863696ba..338be8725 100644 --- a/supertokens_python/recipe/totp/types.py +++ b/supertokens_python/recipe/totp/types.py @@ -12,10 +12,16 @@ # License for the specific language governing permissions and limitations # under the License. -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Dict, List, Optional from typing_extensions import Literal +from supertokens_python.types.config import ( + BaseConfig, + BaseInputConfig, + BaseInputOverrideConfig, + BaseOverrideConfig, +) from supertokens_python.types.response import APIResponse from .interfaces import APIInterface, RecipeInterface @@ -177,39 +183,19 @@ def to_json(self) -> Dict[str, Any]: return {"status": self.status} -class OverrideConfig: - def __init__( - self, - functions: Optional[Callable[[RecipeInterface], RecipeInterface]] = None, - apis: Optional[Callable[[APIInterface], APIInterface]] = None, - ): - self.functions = functions - self.apis = apis +class OverrideConfig(BaseInputOverrideConfig[RecipeInterface, APIInterface]): ... -class TOTPConfig: - def __init__( - self, - issuer: Optional[str] = None, - default_skew: Optional[int] = None, - default_period: Optional[int] = None, - override: Optional[OverrideConfig] = None, - ): - self.issuer = issuer - self.default_skew = default_skew - self.default_period = default_period - self.override = override +class NormalisedOverrideConfig(BaseOverrideConfig[RecipeInterface, APIInterface]): ... -class TOTPNormalisedConfig: - def __init__( - self, - issuer: str, - default_skew: int, - default_period: int, - override: OverrideConfig, - ): - self.issuer = issuer - self.default_skew = default_skew - self.default_period = default_period - self.override = override +class TOTPConfig(BaseInputConfig[RecipeInterface, APIInterface]): + issuer: Optional[str] = None + default_skew: Optional[int] = None + default_period: Optional[int] = None + + +class TOTPNormalisedConfig(BaseConfig[RecipeInterface, APIInterface]): + issuer: str + default_skew: int + default_period: int diff --git a/supertokens_python/recipe/totp/utils.py b/supertokens_python/recipe/totp/utils.py index 1e3781335..5d3e04e1b 100644 --- a/supertokens_python/recipe/totp/utils.py +++ b/supertokens_python/recipe/totp/utils.py @@ -16,7 +16,11 @@ from supertokens_python import AppInfo -from .types import OverrideConfig, TOTPConfig, TOTPNormalisedConfig +from .types import ( + NormalisedOverrideConfig, + TOTPConfig, + TOTPNormalisedConfig, +) def validate_and_normalise_user_input( @@ -29,17 +33,17 @@ def validate_and_normalise_user_input( default_skew = config.default_skew if config.default_skew is not None else 1 default_period = config.default_period if config.default_period is not None else 30 - if config.override is None: - override = OverrideConfig() - else: - override = OverrideConfig( - functions=config.override.functions, - apis=config.override.apis, - ) + override_config = NormalisedOverrideConfig() + if config.override is not None: + if config.override.functions is not None: + override_config.functions = config.override.functions + + if config.override.apis is not None: + override_config.apis = config.override.apis return TOTPNormalisedConfig( issuer=issuer, default_skew=default_skew, default_period=default_period, - override=override, + override=override_config, ) diff --git a/supertokens_python/recipe/usermetadata/__init__.py b/supertokens_python/recipe/usermetadata/__init__.py index e5bdd43ed..ae632b781 100644 --- a/supertokens_python/recipe/usermetadata/__init__.py +++ b/supertokens_python/recipe/usermetadata/__init__.py @@ -13,18 +13,16 @@ # under the License. from __future__ import annotations -from typing import TYPE_CHECKING, Callable, Union +from typing import TYPE_CHECKING, Union from . import utils from .recipe import UserMetadataRecipe if TYPE_CHECKING: - from supertokens_python.supertokens import AppInfo - - from ...recipe_module import RecipeModule + from supertokens_python.supertokens import RecipeInit def init( override: Union[utils.InputOverrideConfig, None] = None, -) -> Callable[[AppInfo], RecipeModule]: +) -> RecipeInit: return UserMetadataRecipe.init(override) diff --git a/supertokens_python/recipe/usermetadata/interfaces.py b/supertokens_python/recipe/usermetadata/interfaces.py index 05c042650..601a65278 100644 --- a/supertokens_python/recipe/usermetadata/interfaces.py +++ b/supertokens_python/recipe/usermetadata/interfaces.py @@ -1,6 +1,8 @@ from abc import ABC, abstractmethod from typing import Any, Dict +from supertokens_python.types.recipe import BaseAPIInterface, BaseRecipeInterface + class MetadataResult(ABC): def __init__(self, metadata: Dict[str, Any]): @@ -11,7 +13,7 @@ class ClearUserMetadataResult: pass -class RecipeInterface(ABC): +class RecipeInterface(BaseRecipeInterface): @abstractmethod async def get_user_metadata( self, user_id: str, user_context: Dict[str, Any] @@ -34,5 +36,5 @@ async def clear_user_metadata( pass -class APIInterface(ABC): +class APIInterface(BaseAPIInterface): pass diff --git a/supertokens_python/recipe/usermetadata/recipe.py b/supertokens_python/recipe/usermetadata/recipe.py index 4eee4b066..c75894ce9 100644 --- a/supertokens_python/recipe/usermetadata/recipe.py +++ b/supertokens_python/recipe/usermetadata/recipe.py @@ -20,6 +20,7 @@ from supertokens_python.exceptions import SuperTokensError, raise_general_exception from supertokens_python.framework import BaseRequest, BaseResponse from supertokens_python.normalised_url_path import NormalisedURLPath +from supertokens_python.plugins import OverrideMap, apply_plugins from supertokens_python.querier import Querier from supertokens_python.recipe.usermetadata.exceptions import ( SuperTokensUserMetadataError, @@ -35,7 +36,7 @@ if TYPE_CHECKING: from supertokens_python.supertokens import AppInfo -from .utils import InputOverrideConfig +from .utils import InputOverrideConfig, UserMetadataInputConfig class UserMetadataRecipe(RecipeModule): @@ -46,15 +47,15 @@ def __init__( self, recipe_id: str, app_info: AppInfo, - override: Union[InputOverrideConfig, None] = None, + input_config: UserMetadataInputConfig, ): super().__init__(recipe_id, app_info) - self.config = validate_and_normalise_user_input(self, app_info, override) + self.config = validate_and_normalise_user_input( + _recipe=self, _app_info=app_info, input_config=input_config + ) recipe_implementation = RecipeImplementation(Querier.get_instance(recipe_id)) - self.recipe_implementation = ( + self.recipe_implementation = self.config.override.functions( recipe_implementation - if self.config.override.functions is None - else self.config.override.functions(recipe_implementation) ) def is_error_from_this_recipe_based_on_instance(self, err: Exception) -> bool: @@ -91,10 +92,18 @@ def get_all_cors_headers(self) -> List[str]: @staticmethod def init(override: Union[InputOverrideConfig, None] = None): - def func(app_info: AppInfo): + input_config = UserMetadataInputConfig(override=override) + + def func(app_info: AppInfo, plugins: List[OverrideMap]): if UserMetadataRecipe.__instance is None: UserMetadataRecipe.__instance = UserMetadataRecipe( - UserMetadataRecipe.recipe_id, app_info, override + recipe_id=UserMetadataRecipe.recipe_id, + app_info=app_info, + input_config=apply_plugins( + recipe_id=UserMetadataRecipe.recipe_id, + config=input_config, + plugins=plugins, + ), ) return UserMetadataRecipe.__instance raise Exception( diff --git a/supertokens_python/recipe/usermetadata/utils.py b/supertokens_python/recipe/usermetadata/utils.py index 7e059cf74..4b226aaed 100644 --- a/supertokens_python/recipe/usermetadata/utils.py +++ b/supertokens_python/recipe/usermetadata/utils.py @@ -14,42 +14,47 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Callable, Union +from typing import TYPE_CHECKING from supertokens_python.recipe.usermetadata.interfaces import ( APIInterface, RecipeInterface, ) +from supertokens_python.types.config import ( + BaseConfig, + BaseInputConfig, + BaseInputOverrideConfig, + BaseOverrideConfig, +) if TYPE_CHECKING: from supertokens_python.recipe.usermetadata.recipe import UserMetadataRecipe from supertokens_python.supertokens import AppInfo -class InputOverrideConfig: - def __init__( - self, - functions: Union[Callable[[RecipeInterface], RecipeInterface], None] = None, - apis: Union[Callable[[APIInterface], APIInterface], None] = None, - ): - self.functions = functions - self.apis = apis +class InputOverrideConfig(BaseInputOverrideConfig[RecipeInterface, APIInterface]): ... + + +class OverrideConfig(BaseOverrideConfig[RecipeInterface, APIInterface]): ... + + +class UserMetadataInputConfig(BaseInputConfig[RecipeInterface, APIInterface]): ... -class UserMetadataConfig: - def __init__(self, override: InputOverrideConfig) -> None: - self.override = override +class UserMetadataConfig(BaseConfig[RecipeInterface, APIInterface]): ... def validate_and_normalise_user_input( _recipe: UserMetadataRecipe, _app_info: AppInfo, - override: Union[InputOverrideConfig, None] = None, + input_config: UserMetadataInputConfig, ) -> UserMetadataConfig: - if override is not None and not isinstance(override, InputOverrideConfig): # type: ignore - raise ValueError("override must be an instance of InputOverrideConfig or None") + override_config = OverrideConfig() + if input_config.override is not None: + if input_config.override.functions is not None: + override_config.functions = input_config.override.functions - if override is None: - override = InputOverrideConfig() + if input_config.override.apis is not None: + override_config.apis = input_config.override.apis - return UserMetadataConfig(override=override) + return UserMetadataConfig(override=override_config) diff --git a/supertokens_python/recipe/userroles/__init__.py b/supertokens_python/recipe/userroles/__init__.py index f0e082669..52a88d08e 100644 --- a/supertokens_python/recipe/userroles/__init__.py +++ b/supertokens_python/recipe/userroles/__init__.py @@ -13,7 +13,7 @@ # under the License. from __future__ import annotations -from typing import TYPE_CHECKING, Callable, Optional, Union +from typing import TYPE_CHECKING, Optional, Union from . import recipe, utils from .recipe import UserRolesRecipe @@ -22,16 +22,14 @@ UserRoleClaim = recipe.UserRoleClaim if TYPE_CHECKING: - from supertokens_python.supertokens import AppInfo - - from ...recipe_module import RecipeModule + from supertokens_python.supertokens import RecipeInit def init( skip_adding_roles_to_access_token: Optional[bool] = None, skip_adding_permissions_to_access_token: Optional[bool] = None, override: Union[utils.InputOverrideConfig, None] = None, -) -> Callable[[AppInfo], RecipeModule]: +) -> RecipeInit: return UserRolesRecipe.init( skip_adding_roles_to_access_token, skip_adding_permissions_to_access_token, diff --git a/supertokens_python/recipe/userroles/interfaces.py b/supertokens_python/recipe/userroles/interfaces.py index 1ffc97410..e6b7fff81 100644 --- a/supertokens_python/recipe/userroles/interfaces.py +++ b/supertokens_python/recipe/userroles/interfaces.py @@ -1,6 +1,8 @@ -from abc import ABC, abstractmethod +from abc import abstractmethod from typing import Any, Dict, List, Union +from supertokens_python.types.recipe import BaseAPIInterface, BaseRecipeInterface + class AddRoleToUserOkResult: def __init__(self, did_user_already_have_role: bool): @@ -55,7 +57,7 @@ def __init__(self, roles: List[str]): self.roles = roles -class RecipeInterface(ABC): +class RecipeInterface(BaseRecipeInterface): @abstractmethod async def add_role_to_user( self, @@ -123,5 +125,5 @@ async def get_all_roles(self, user_context: Dict[str, Any]) -> GetAllRolesOkResu pass -class APIInterface(ABC): +class APIInterface(BaseAPIInterface): pass diff --git a/supertokens_python/recipe/userroles/recipe.py b/supertokens_python/recipe/userroles/recipe.py index 7a2d830d5..6b62d744d 100644 --- a/supertokens_python/recipe/userroles/recipe.py +++ b/supertokens_python/recipe/userroles/recipe.py @@ -20,6 +20,7 @@ from supertokens_python.exceptions import SuperTokensError, raise_general_exception from supertokens_python.framework import BaseRequest, BaseResponse from supertokens_python.normalised_url_path import NormalisedURLPath +from supertokens_python.plugins import OverrideMap, apply_plugins from supertokens_python.querier import Querier from supertokens_python.recipe.session.asyncio import get_session_information from supertokens_python.recipe.userroles.recipe_implementation import ( @@ -35,7 +36,7 @@ from ..session.claim_base_classes.primitive_array_claim import PrimitiveArrayClaim from .exceptions import SuperTokensUserRolesError from .interfaces import GetPermissionsForRoleOkResult, UnknownRoleError -from .utils import InputOverrideConfig +from .utils import InputOverrideConfig, UserRolesInputConfig class UserRolesRecipe(RecipeModule): @@ -46,25 +47,19 @@ def __init__( self, recipe_id: str, app_info: AppInfo, - skip_adding_roles_to_access_token: Optional[bool] = None, - skip_adding_permissions_to_access_token: Optional[bool] = None, - override: Union[InputOverrideConfig, None] = None, + input_config: UserRolesInputConfig, ): from ..oauth2provider.recipe import OAuth2ProviderRecipe super().__init__(recipe_id, app_info) self.config = validate_and_normalise_user_input( - self, - app_info, - skip_adding_roles_to_access_token, - skip_adding_permissions_to_access_token, - override, + _recipe=self, + _app_info=app_info, + input_config=input_config, ) recipe_implementation = RecipeImplementation(Querier.get_instance(recipe_id)) - self.recipe_implementation = ( + self.recipe_implementation = self.config.override.functions( recipe_implementation - if self.config.override.functions is None - else self.config.override.functions(recipe_implementation) ) def callback(): @@ -218,14 +213,22 @@ def init( skip_adding_permissions_to_access_token: Optional[bool] = None, override: Union[InputOverrideConfig, None] = None, ): - def func(app_info: AppInfo): + input_config = UserRolesInputConfig( + skip_adding_roles_to_access_token=skip_adding_roles_to_access_token, + skip_adding_permissions_to_access_token=skip_adding_permissions_to_access_token, + override=override, + ) + + def func(app_info: AppInfo, plugins: List[OverrideMap]): if UserRolesRecipe.__instance is None: UserRolesRecipe.__instance = UserRolesRecipe( - UserRolesRecipe.recipe_id, - app_info, - skip_adding_roles_to_access_token, - skip_adding_permissions_to_access_token, - override, + recipe_id=UserRolesRecipe.recipe_id, + app_info=app_info, + input_config=apply_plugins( + recipe_id=UserRolesRecipe.recipe_id, + config=input_config, + plugins=plugins, + ), ) return UserRolesRecipe.__instance raise Exception( diff --git a/supertokens_python/recipe/userroles/utils.py b/supertokens_python/recipe/userroles/utils.py index c94ab5bed..17a17e9d3 100644 --- a/supertokens_python/recipe/userroles/utils.py +++ b/supertokens_python/recipe/userroles/utils.py @@ -14,59 +14,65 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Callable, Optional, Union +from typing import TYPE_CHECKING, Optional from supertokens_python.recipe.userroles.interfaces import APIInterface, RecipeInterface from supertokens_python.supertokens import AppInfo +from supertokens_python.types.config import ( + BaseConfig, + BaseInputConfig, + BaseInputOverrideConfig, + BaseOverrideConfig, +) if TYPE_CHECKING: from supertokens_python.recipe.userroles.recipe import UserRolesRecipe -class InputOverrideConfig: - def __init__( - self, - functions: Union[Callable[[RecipeInterface], RecipeInterface], None] = None, - apis: Union[Callable[[APIInterface], APIInterface], None] = None, - ): - self.functions = functions - self.apis = apis +class InputOverrideConfig(BaseInputOverrideConfig[RecipeInterface, APIInterface]): ... -class UserRolesConfig: - def __init__( - self, - skip_adding_roles_to_access_token: bool, - skip_adding_permissions_to_access_token: bool, - override: InputOverrideConfig, - ) -> None: - self.skip_adding_roles_to_access_token = skip_adding_roles_to_access_token - self.skip_adding_permissions_to_access_token = ( - skip_adding_permissions_to_access_token - ) - self.override = override +class OverrideConfig(BaseOverrideConfig[RecipeInterface, APIInterface]): ... + + +class UserRolesInputConfig(BaseInputConfig[RecipeInterface, APIInterface]): + skip_adding_roles_to_access_token: Optional[bool] = None + skip_adding_permissions_to_access_token: Optional[bool] = None + + +class UserRolesConfig(BaseConfig[RecipeInterface, APIInterface]): + skip_adding_roles_to_access_token: bool + skip_adding_permissions_to_access_token: bool def validate_and_normalise_user_input( _recipe: UserRolesRecipe, _app_info: AppInfo, - skip_adding_roles_to_access_token: Optional[bool] = None, - skip_adding_permissions_to_access_token: Optional[bool] = None, - override: Union[InputOverrideConfig, None] = None, + input_config: UserRolesInputConfig, + # skip_adding_roles_to_access_token: Optional[bool] = None, + # skip_adding_permissions_to_access_token: Optional[bool] = None, + # override: Union[InputOverrideConfig, None] = None, ) -> UserRolesConfig: - if override is not None and not isinstance(override, InputOverrideConfig): # type: ignore - raise ValueError("override must be an instance of InputOverrideConfig or None") + override_config = OverrideConfig() + if input_config.override is not None: + if input_config.override.functions is not None: + override_config.functions = input_config.override.functions - if override is None: - override = InputOverrideConfig() + if input_config.override.apis is not None: + override_config.apis = input_config.override.apis + skip_adding_roles_to_access_token = input_config.skip_adding_roles_to_access_token if skip_adding_roles_to_access_token is None: skip_adding_roles_to_access_token = False + + skip_adding_permissions_to_access_token = ( + input_config.skip_adding_permissions_to_access_token + ) if skip_adding_permissions_to_access_token is None: skip_adding_permissions_to_access_token = False return UserRolesConfig( skip_adding_roles_to_access_token=skip_adding_roles_to_access_token, skip_adding_permissions_to_access_token=skip_adding_permissions_to_access_token, - override=override, + override=override_config, ) diff --git a/supertokens_python/recipe/webauthn/interfaces/api.py b/supertokens_python/recipe/webauthn/interfaces/api.py index 230a9086f..819969ccc 100644 --- a/supertokens_python/recipe/webauthn/interfaces/api.py +++ b/supertokens_python/recipe/webauthn/interfaces/api.py @@ -40,6 +40,7 @@ from supertokens_python.supertokens import AppInfo from supertokens_python.types import RecipeUserId, User from supertokens_python.types.base import UserContext +from supertokens_python.types.recipe import BaseAPIInterface from supertokens_python.types.response import ( CamelCaseBaseModel, GeneralErrorResponse, @@ -219,7 +220,7 @@ class RegisterOptionsPOSTKwargsInput(TypedDict): email: NotRequired[str] -class APIInterface(ABC): +class APIInterface(BaseAPIInterface): disable_register_options_post: bool = False disable_sign_in_options_post: bool = False disable_sign_up_post: bool = False diff --git a/supertokens_python/recipe/webauthn/interfaces/recipe.py b/supertokens_python/recipe/webauthn/interfaces/recipe.py index 67dcbe012..1cb49389d 100644 --- a/supertokens_python/recipe/webauthn/interfaces/recipe.py +++ b/supertokens_python/recipe/webauthn/interfaces/recipe.py @@ -30,6 +30,7 @@ from supertokens_python.types import RecipeUserId, User from supertokens_python.types.auth_utils import LinkingToSessionUserFailedError from supertokens_python.types.base import UserContext +from supertokens_python.types.recipe import BaseRecipeInterface from supertokens_python.types.response import ( CamelCaseBaseModel, OkResponseBaseModel, @@ -440,7 +441,7 @@ class RegisterOptionsKwargsInput(TypedDict): email: NotRequired[str] -class RecipeInterface(ABC): +class RecipeInterface(BaseRecipeInterface): @abstractmethod async def register_options( self, diff --git a/supertokens_python/recipe/webauthn/recipe.py b/supertokens_python/recipe/webauthn/recipe.py index 02465285b..bed22a345 100644 --- a/supertokens_python/recipe/webauthn/recipe.py +++ b/supertokens_python/recipe/webauthn/recipe.py @@ -21,6 +21,7 @@ from supertokens_python.framework.response import BaseResponse from supertokens_python.ingredients.emaildelivery import EmailDeliveryIngredient from supertokens_python.normalised_url_path import NormalisedURLPath +from supertokens_python.plugins import OverrideMap, apply_plugins from supertokens_python.post_init_callbacks import PostSTInitCallbacks from supertokens_python.querier import Querier from supertokens_python.recipe.multifactorauth.recipe import MultiFactorAuthRecipe @@ -105,18 +106,12 @@ def __init__( querier=querier, config=self.config, ) - self.recipe_implementation = ( + self.recipe_implementation = self.config.override.functions( recipe_implementation - if self.config.override.functions is None - else self.config.override.functions(recipe_implementation) ) api_implementation = APIImplementation() - self.api_implementation = ( - api_implementation - if self.config.override.apis is None - else self.config.override.apis(api_implementation) - ) + self.api_implementation = self.config.override.apis(api_implementation) if ingredients.email_delivery is None: self.email_delivery = EmailDeliveryIngredient( @@ -301,12 +296,16 @@ def get_instance_optional() -> Optional["WebauthnRecipe"]: @staticmethod def init(config: Optional[WebauthnConfig]): - def func(app_info: AppInfo): + def func(app_info: AppInfo, plugins: List[OverrideMap]): if WebauthnRecipe.__instance is None: WebauthnRecipe.__instance = WebauthnRecipe( recipe_id=WebauthnRecipe.recipe_id, app_info=app_info, - config=config, + config=apply_plugins( + recipe_id=WebauthnRecipe.recipe_id, + config=config, + plugins=plugins, + ), ingredients=WebauthnIngredients(email_delivery=None), ) return WebauthnRecipe.__instance diff --git a/supertokens_python/recipe/webauthn/types/config.py b/supertokens_python/recipe/webauthn/types/config.py index f0c36202e..37e32dd14 100644 --- a/supertokens_python/recipe/webauthn/types/config.py +++ b/supertokens_python/recipe/webauthn/types/config.py @@ -14,7 +14,6 @@ from __future__ import annotations -from dataclasses import dataclass from typing import TYPE_CHECKING, Optional, Protocol, TypeVar, Union, runtime_checkable from supertokens_python.framework import BaseRequest @@ -24,6 +23,13 @@ EmailDeliveryConfigWithService, ) from supertokens_python.types.base import UserContext +from supertokens_python.types.config import ( + BaseConfig, + BaseInputConfig, + BaseInputOverrideConfig, + BaseOverrideConfig, +) +from supertokens_python.types.response import CamelCaseBaseModel if TYPE_CHECKING: from supertokens_python.recipe.webauthn.interfaces.api import ( @@ -179,39 +185,29 @@ def __call__( ) -> InterfaceType: ... -# NOTE: Using dataclasses for these classes since validation is not required -@dataclass -class OverrideConfig: - """ - `WebauthnConfig.override` - """ +class OverrideConfig(BaseInputOverrideConfig[RecipeInterface, APIInterface]): ... + - functions: Optional[InterfaceOverride[RecipeInterface]] = None - apis: Optional[InterfaceOverride[APIInterface]] = None +class NormalisedOverrideConfig(BaseOverrideConfig[RecipeInterface, APIInterface]): ... -@dataclass -class WebauthnConfig: +class WebauthnConfig(BaseInputConfig[RecipeInterface, APIInterface]): get_relying_party_id: Optional[Union[str, GetRelyingPartyId]] = None get_relying_party_name: Optional[Union[str, GetRelyingPartyName]] = None get_origin: Optional[GetOrigin] = None email_delivery: Optional[EmailDeliveryConfig[TypeWebauthnEmailDeliveryInput]] = None validate_email_address: Optional[ValidateEmailAddress] = None - override: Optional[OverrideConfig] = None -@dataclass -class NormalisedWebauthnConfig: +class NormalisedWebauthnConfig(BaseConfig[RecipeInterface, APIInterface]): get_relying_party_id: NormalisedGetRelyingPartyId get_relying_party_name: NormalisedGetRelyingPartyName get_origin: NormalisedGetOrigin get_email_delivery_config: NormalisedGetEmailDeliveryConfig validate_email_address: NormalisedValidateEmailAddress - override: OverrideConfig -@dataclass -class WebauthnIngredients: +class WebauthnIngredients(CamelCaseBaseModel): email_delivery: Optional[ EmailDeliveryIngredient[TypeWebauthnEmailDeliveryInput] ] = None diff --git a/supertokens_python/recipe/webauthn/utils.py b/supertokens_python/recipe/webauthn/utils.py index c3a35abfb..dc0499354 100644 --- a/supertokens_python/recipe/webauthn/utils.py +++ b/supertokens_python/recipe/webauthn/utils.py @@ -33,9 +33,9 @@ NormalisedGetOrigin, NormalisedGetRelyingPartyId, NormalisedGetRelyingPartyName, + NormalisedOverrideConfig, NormalisedValidateEmailAddress, NormalisedWebauthnConfig, - OverrideConfig, ValidateEmailAddress, WebauthnConfig, ) @@ -60,13 +60,13 @@ def validate_and_normalise_user_input( config.validate_email_address ) - if config.override is None: - override = OverrideConfig() - else: - override = OverrideConfig( - functions=config.override.functions, - apis=config.override.apis, - ) + override_config = NormalisedOverrideConfig() + if config.override is not None: + if config.override.functions is not None: + override_config.functions = config.override.functions + + if config.override.apis is not None: + override_config.apis = config.override.apis def get_email_delivery_config() -> EmailDeliveryConfigWithService[ TypeWebauthnEmailDeliveryInput @@ -93,7 +93,7 @@ def get_email_delivery_config() -> EmailDeliveryConfigWithService[ get_origin=get_origin, get_email_delivery_config=get_email_delivery_config, validate_email_address=validate_email_address, - override=override, + override=override_config, ) diff --git a/supertokens_python/supertokens.py b/supertokens_python/supertokens.py index cd9536a7e..87deed255 100644 --- a/supertokens_python/supertokens.py +++ b/supertokens_python/supertokens.py @@ -431,17 +431,21 @@ def make_recipe( if not jwt_found: from supertokens_python.recipe.jwt.recipe import JWTRecipe - self.recipe_modules.append(JWTRecipe.init()(self.app_info)) + self.recipe_modules.append(JWTRecipe.init()(self.app_info, override_maps)) if not openid_found: from supertokens_python.recipe.openid.recipe import OpenIdRecipe - self.recipe_modules.append(OpenIdRecipe.init()(self.app_info)) + self.recipe_modules.append( + OpenIdRecipe.init()(self.app_info, override_maps) + ) if not multitenancy_found: from supertokens_python.recipe.multitenancy.recipe import MultitenancyRecipe - self.recipe_modules.append(MultitenancyRecipe.init()(self.app_info)) + self.recipe_modules.append( + MultitenancyRecipe.init()(self.app_info, override_maps) + ) if totp_found and not multi_factor_auth_found: raise Exception("Please initialize the MultiFactorAuth recipe to use TOTP.") @@ -449,14 +453,18 @@ def make_recipe( if not user_metadata_found: from supertokens_python.recipe.usermetadata.recipe import UserMetadataRecipe - self.recipe_modules.append(UserMetadataRecipe.init()(self.app_info)) + self.recipe_modules.append( + UserMetadataRecipe.init()(self.app_info, override_maps) + ) if not oauth2_found: from supertokens_python.recipe.oauth2provider.recipe import ( OAuth2ProviderRecipe, ) - self.recipe_modules.append(OAuth2ProviderRecipe.init()(self.app_info)) + self.recipe_modules.append( + OAuth2ProviderRecipe.init()(self.app_info, override_maps) + ) self.telemetry = ( config.telemetry diff --git a/supertokens_python/test.py b/supertokens_python/test.py new file mode 100644 index 000000000..d8b5e8c71 --- /dev/null +++ b/supertokens_python/test.py @@ -0,0 +1,14 @@ +from typing import cast + +from django.http import HttpRequest + +from supertokens_python.recipe.session import SessionContainer +from supertokens_python.recipe.session.framework.django.asyncio import verify_session + + +# highlight-start +@verify_session() +async def some_api(request: HttpRequest): + session: SessionContainer = cast(SessionContainer, request.supertokens) # type: ignore This will delete the session from the db and from the frontend (cookies) + # highlight-end + await session.revoke_session() diff --git a/supertokens_python/types/config.py b/supertokens_python/types/config.py new file mode 100644 index 000000000..43a20d3cb --- /dev/null +++ b/supertokens_python/types/config.py @@ -0,0 +1,93 @@ +from typing import Callable, Generic, Optional, TypeVar, Union + +from supertokens_python.types.recipe import BaseAPIInterface, BaseRecipeInterface +from supertokens_python.types.response import CamelCaseBaseModel + +InterfaceType = TypeVar( + "InterfaceType", bound=Union[BaseRecipeInterface, BaseAPIInterface] +) +"""Generic Type for use in `InterfaceOverride`""" +FunctionInterfaceType = TypeVar("FunctionInterfaceType", bound=BaseRecipeInterface) +"""Generic Type for use in `FunctionOverrideConfig`""" +APIInterfaceType = TypeVar("APIInterfaceType", bound=BaseAPIInterface) +"""Generic Type for use in `APIOverrideConfig`""" + + +InterfaceOverride = Callable[[InterfaceType], InterfaceType] + +# @runtime_checkable +# class InterfaceOverride(Protocol[InterfaceType]): +# """ +# Callable signature for `.override.*`. +# """ + +# def __call__( +# self, +# original_implementation: InterfaceType, +# ) -> InterfaceType: ... + + +class BaseInputOverrideConfigWithoutAPI( + CamelCaseBaseModel, Generic[FunctionInterfaceType] +): + """Base class for input override config without API overrides.""" + + functions: Optional[InterfaceOverride[FunctionInterfaceType]] = None + + +class BaseOverrideConfigWithoutAPI(CamelCaseBaseModel, Generic[FunctionInterfaceType]): + """Base class for normalized override config without API overrides.""" + + functions: InterfaceOverride[FunctionInterfaceType] = ( + lambda original_implementation: original_implementation + ) + + +class BaseInputOverrideConfig( + BaseInputOverrideConfigWithoutAPI[FunctionInterfaceType], + Generic[FunctionInterfaceType, APIInterfaceType], +): + """Base class for input override config with API overrides.""" + + apis: Optional[InterfaceOverride[APIInterfaceType]] = None + + +class BaseOverrideConfig( + BaseOverrideConfigWithoutAPI[FunctionInterfaceType], + Generic[FunctionInterfaceType, APIInterfaceType], +): + """Base class for normalized override config with API overrides.""" + + apis: InterfaceOverride[APIInterfaceType] = ( + lambda original_implementation: original_implementation + ) + + +class BaseInputConfigWithoutAPIOverride( + CamelCaseBaseModel, Generic[FunctionInterfaceType] +): + """Base class for input config of a Recipe without API overrides.""" + + override: Optional[BaseInputOverrideConfigWithoutAPI[FunctionInterfaceType]] = None + + +class BaseConfigWithoutAPIOverride(CamelCaseBaseModel, Generic[FunctionInterfaceType]): + """Base class for normalized config of a Recipe without API overrides.""" + + override: BaseOverrideConfigWithoutAPI[FunctionInterfaceType] + + +class BaseInputConfig( + CamelCaseBaseModel, Generic[FunctionInterfaceType, APIInterfaceType] +): + """Base class for input config of a Recipe with API overrides.""" + + override: Optional[ + BaseInputOverrideConfig[FunctionInterfaceType, APIInterfaceType] + ] = None + + +class BaseConfig(CamelCaseBaseModel, Generic[FunctionInterfaceType, APIInterfaceType]): + """Base class for normalized config of a Recipe with API overrides.""" + + override: BaseOverrideConfig[FunctionInterfaceType, APIInterfaceType] diff --git a/supertokens_python/types/recipe.py b/supertokens_python/types/recipe.py new file mode 100644 index 000000000..566695534 --- /dev/null +++ b/supertokens_python/types/recipe.py @@ -0,0 +1,7 @@ +from abc import ABC + + +class BaseRecipeInterface(ABC): ... + + +class BaseAPIInterface(ABC): ... diff --git a/tests/auth-react/flask-server/app.py b/tests/auth-react/flask-server/app.py index 11c8698c5..4ea94a5de 100644 --- a/tests/auth-react/flask-server/app.py +++ b/tests/auth-react/flask-server/app.py @@ -486,13 +486,11 @@ def custom_init( WebauthnRecipe.reset() def override_email_verification_apis( - original_implementation_email_verification: EmailVerificationAPIInterface, + original_implementation: EmailVerificationAPIInterface, ): - original_email_verify_post = ( - original_implementation_email_verification.email_verify_post - ) + original_email_verify_post = original_implementation.email_verify_post original_generate_email_verify_token_post = ( - original_implementation_email_verification.generate_email_verify_token_post + original_implementation.generate_email_verify_token_post ) async def email_verify_post( @@ -531,11 +529,11 @@ async def generate_email_verify_token_post( session, api_options, user_context ) - original_implementation_email_verification.email_verify_post = email_verify_post - original_implementation_email_verification.generate_email_verify_token_post = ( + original_implementation.email_verify_post = email_verify_post + original_implementation.generate_email_verify_token_post = ( generate_email_verify_token_post ) - return original_implementation_email_verification + return original_implementation def override_email_password_apis( original_implementation: EmailPasswordAPIInterface, diff --git a/tests/frontendIntegration/django2x/polls/views.py b/tests/frontendIntegration/django2x/polls/views.py index 5d2a587e4..3c3cc5dad 100644 --- a/tests/frontendIntegration/django2x/polls/views.py +++ b/tests/frontendIntegration/django2x/polls/views.py @@ -279,13 +279,13 @@ def unauthorised_f(req: BaseRequest, message: str, res: BaseResponse): return res -def apis_override_session(param: APIInterface): - param.disable_refresh_post = True - return param +def apis_override_session(original_implementation: APIInterface): + original_implementation.disable_refresh_post = True + return original_implementation -def functions_override_session(param: RecipeInterface): - original_create_new_session = param.create_new_session +def functions_override_session(original_implementation: RecipeInterface): + original_create_new_session = original_implementation.create_new_session async def create_new_session_custom( user_id: str, @@ -309,9 +309,9 @@ async def create_new_session_custom( user_context, ) - param.create_new_session = create_new_session_custom + original_implementation.create_new_session = create_new_session_custom - return param + return original_implementation def get_app_port(): diff --git a/tests/frontendIntegration/drf_async/mysite/settings.py b/tests/frontendIntegration/drf_async/mysite/settings.py index a7f9aba4f..72bf8d82f 100644 --- a/tests/frontendIntegration/drf_async/mysite/settings.py +++ b/tests/frontendIntegration/drf_async/mysite/settings.py @@ -162,13 +162,13 @@ def get_app_port(): return "8080" -def apis_override_session(param: APIInterface): - param.disable_refresh_post = True - return param +def apis_override_session(original_implementation: APIInterface): + original_implementation.disable_refresh_post = True + return original_implementation -def functions_override_session(param: RecipeInterface): - original_create_new_session = param.create_new_session +def functions_override_session(original_implementation: RecipeInterface): + original_create_new_session = original_implementation.create_new_session async def create_new_session_custom( user_id: str, @@ -192,9 +192,9 @@ async def create_new_session_custom( user_context, ) - param.create_new_session = create_new_session_custom + original_implementation.create_new_session = create_new_session_custom - return param + return original_implementation init( diff --git a/tests/frontendIntegration/drf_async/polls/views.py b/tests/frontendIntegration/drf_async/polls/views.py index f5e53db57..df1c4953d 100644 --- a/tests/frontendIntegration/drf_async/polls/views.py +++ b/tests/frontendIntegration/drf_async/polls/views.py @@ -306,13 +306,13 @@ async def unauthorised_f(req: BaseRequest, message: str, res: BaseResponse): return res -def apis_override_session(param: APIInterface): - param.disable_refresh_post = True - return param +def apis_override_session(original_implementation: APIInterface): + original_implementation.disable_refresh_post = True + return original_implementation -def functions_override_session(param: RecipeInterface): - original_create_new_session = param.create_new_session +def functions_override_session(original_implementation: RecipeInterface): + original_create_new_session = original_implementation.create_new_session async def create_new_session_custom( user_id: str, @@ -336,9 +336,9 @@ async def create_new_session_custom( user_context, ) - param.create_new_session = create_new_session_custom + original_implementation.create_new_session = create_new_session_custom - return param + return original_implementation def get_app_port(): diff --git a/tests/frontendIntegration/fastapi-server/app.py b/tests/frontendIntegration/fastapi-server/app.py index d6766074f..3f6502a23 100644 --- a/tests/frontendIntegration/fastapi-server/app.py +++ b/tests/frontendIntegration/fastapi-server/app.py @@ -131,13 +131,13 @@ async def unauthorised_f(_: BaseRequest, __: str, res: BaseResponse): return res -def apis_override_session(param: APIInterface): - param.disable_refresh_post = True - return param +def apis_override_session(original_implementation: APIInterface): + original_implementation.disable_refresh_post = True + return original_implementation -def functions_override_session(param: RecipeInterface): - original_create_new_session = param.create_new_session +def functions_override_session(original_implementation: RecipeInterface): + original_create_new_session = original_implementation.create_new_session async def create_new_session_custom( user_id: str, @@ -161,9 +161,9 @@ async def create_new_session_custom( user_context, ) - param.create_new_session = create_new_session_custom + original_implementation.create_new_session = create_new_session_custom - return param + return original_implementation def get_app_port(): diff --git a/tests/frontendIntegration/flask-server/app.py b/tests/frontendIntegration/flask-server/app.py index 27065f66f..683242e30 100644 --- a/tests/frontendIntegration/flask-server/app.py +++ b/tests/frontendIntegration/flask-server/app.py @@ -162,13 +162,13 @@ async def unauthorised_f(_: BaseRequest, __: str, res: BaseResponse): return res -def apis_override_session(param: APIInterface): - param.disable_refresh_post = True - return param +def apis_override_session(original_implementation: APIInterface): + original_implementation.disable_refresh_post = True + return original_implementation -def functions_override_session(param: RecipeInterface): - original_create_new_session = param.create_new_session +def functions_override_session(original_implementation: RecipeInterface): + original_create_new_session = original_implementation.create_new_session async def create_new_session_custom( user_id: str, @@ -192,9 +192,9 @@ async def create_new_session_custom( user_context, ) - param.create_new_session = create_new_session_custom + original_implementation.create_new_session = create_new_session_custom - return param + return original_implementation def get_app_port(): diff --git a/tests/sessions/claims/utils.py b/tests/sessions/claims/utils.py index 36c63b9ac..91d50d128 100644 --- a/tests/sessions/claims/utils.py +++ b/tests/sessions/claims/utils.py @@ -21,8 +21,10 @@ def session_functions_override_with_claim( if params is None: params = {} - def session_function_override(oi: RecipeInterface) -> RecipeInterface: - oi_create_new_session = oi.create_new_session + def session_function_override( + original_implementation: RecipeInterface, + ) -> RecipeInterface: + oi_create_new_session = original_implementation.create_new_session async def new_create_new_session( user_id: str, @@ -58,8 +60,8 @@ async def new_create_new_session( user_context, ) - oi.create_new_session = new_create_new_session - return oi + original_implementation.create_new_session = new_create_new_session + return original_implementation return session_function_override diff --git a/tests/test-server/app.py b/tests/test-server/app.py index 87647bbb7..820e1011c 100644 --- a/tests/test-server/app.py +++ b/tests/test-server/app.py @@ -17,7 +17,6 @@ from passwordless import add_passwordless_routes # pylint: disable=import-error from session import add_session_routes # pylint: disable=import-error from supertokens_python import ( - AppInfo, InputAppInfo, Supertokens, SupertokensConfig, @@ -66,7 +65,7 @@ ) from supertokens_python.recipe.webauthn.recipe import WebauthnRecipe from supertokens_python.recipe.webauthn.types.config import WebauthnConfig -from supertokens_python.recipe_module import RecipeModule +from supertokens_python.supertokens import RecipeInit from supertokens_python.types import RecipeUserId from test_functions_mapper import ( # pylint: disable=import-error get_func, @@ -251,9 +250,7 @@ def init_st(config: Dict[str, Any]): st_reset() override_logging.reset_override_logs() - recipe_list: List[Callable[[AppInfo], RecipeModule]] = [ - dashboard.init(api_key="test") - ] + recipe_list: List[RecipeInit] = [dashboard.init(api_key="test")] for recipe_config in config.get("recipeList", []): recipe_id = recipe_config.get("recipeId") if recipe_id == "emailpassword": diff --git a/tests/test_session.py b/tests/test_session.py index 4b5c3e1c7..7351c4a4b 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -505,15 +505,15 @@ async def refresh_post(api_options: APIOptions, user_context: Dict[str, Any]): async def test_revoking_session_during_refresh_and_throw_unauthorized( driver_config_client: TestClient, ): - def session_api_override(oi: APIInterface) -> APIInterface: - oi_refresh_post = oi.refresh_post + def session_api_override(original_implementation: APIInterface) -> APIInterface: + oi_refresh_post = original_implementation.refresh_post async def refresh_post(api_options: APIOptions, user_context: Dict[str, Any]): await oi_refresh_post(api_options, user_context) return raise_unauthorised_exception("unauthorized", clear_tokens=True) - oi.refresh_post = refresh_post - return oi + original_implementation.refresh_post = refresh_post + return original_implementation init_args = get_st_init_args( url=get_new_core_app_url(), From ffff56b75d80dfbb0c34522850bb534d68bed81d Mon Sep 17 00:00:00 2001 From: Namit Nathwani Date: Tue, 1 Jul 2025 10:32:00 +0530 Subject: [PATCH 12/37] fix: cyclic imports --- .../recipe/accountlinking/__init__.py | 6 +++-- .../recipe/emailverification/recipe.py | 26 +++++++++---------- .../recipe/multitenancy/utils.py | 5 ---- .../recipe/openid/interfaces.py | 7 ++--- .../recipe/passwordless/interfaces.py | 6 +++-- supertokens_python/recipe/session/__init__.py | 9 +++---- .../recipe/session/interfaces.py | 6 ++--- supertokens_python/recipe/session/utils.py | 3 +-- .../recipe/webauthn/interfaces/api.py | 6 +++-- .../recipe/webauthn/interfaces/recipe.py | 2 +- .../recipe/webauthn/types/config.py | 14 +++++----- 11 files changed, 42 insertions(+), 48 deletions(-) diff --git a/supertokens_python/recipe/accountlinking/__init__.py b/supertokens_python/recipe/accountlinking/__init__.py index 97d295bd3..16d7c9d91 100644 --- a/supertokens_python/recipe/accountlinking/__init__.py +++ b/supertokens_python/recipe/accountlinking/__init__.py @@ -13,10 +13,9 @@ # under the License. from __future__ import annotations -from typing import Any, Awaitable, Callable, Dict, Optional, Union +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Optional, Union from ...types import User -from ..session.interfaces import SessionContainer from . import types from .recipe import AccountLinkingRecipe @@ -26,6 +25,9 @@ ShouldAutomaticallyLink = types.ShouldAutomaticallyLink ShouldNotAutomaticallyLink = types.ShouldNotAutomaticallyLink +if TYPE_CHECKING: + from ..session.interfaces import SessionContainer + def init( on_account_linked: Optional[ diff --git a/supertokens_python/recipe/emailverification/recipe.py b/supertokens_python/recipe/emailverification/recipe.py index 9f5cc7e7f..6dea1f3b5 100644 --- a/supertokens_python/recipe/emailverification/recipe.py +++ b/supertokens_python/recipe/emailverification/recipe.py @@ -18,7 +18,9 @@ from supertokens_python.exceptions import SuperTokensError, raise_general_exception from supertokens_python.ingredients.emaildelivery import EmailDeliveryIngredient +from supertokens_python.normalised_url_path import NormalisedURLPath from supertokens_python.plugins import OverrideMap, apply_plugins +from supertokens_python.querier import Querier from supertokens_python.recipe.emailverification.exceptions import ( EmailVerificationInvalidTokenError, ) @@ -28,6 +30,7 @@ VerificationEmailTemplateVars, VerificationEmailTemplateVarsUser, ) +from supertokens_python.recipe.emailverification.utils import get_email_verify_link from supertokens_python.recipe_module import APIHandled, RecipeModule from ...ingredients.emaildelivery.types import EmailDeliveryConfig @@ -47,6 +50,9 @@ SessionClaimValidator, SessionContainer, ) +from .api import handle_email_verify_api, handle_generate_email_verify_token_api +from .constants import USER_EMAIL_VERIFY, USER_EMAIL_VERIFY_TOKEN +from .exceptions import SuperTokensEmailVerificationError from .interfaces import ( APIInterface, APIOptions, @@ -63,6 +69,12 @@ VerifyEmailUsingTokenOkResult, ) from .recipe_implementation import RecipeImplementation +from .utils import ( + MODE_TYPE, + EmailVerificationInputConfig, + InputOverrideConfig, + validate_and_normalise_user_input, +) if TYPE_CHECKING: from supertokens_python.framework.request import BaseRequest @@ -72,20 +84,6 @@ from ...types import MaybeAwaitable, User -from supertokens_python.normalised_url_path import NormalisedURLPath -from supertokens_python.querier import Querier -from supertokens_python.recipe.emailverification.utils import get_email_verify_link - -from .api import handle_email_verify_api, handle_generate_email_verify_token_api -from .constants import USER_EMAIL_VERIFY, USER_EMAIL_VERIFY_TOKEN -from .exceptions import SuperTokensEmailVerificationError -from .utils import ( - MODE_TYPE, - EmailVerificationInputConfig, - InputOverrideConfig, - validate_and_normalise_user_input, -) - class EmailVerificationRecipe(RecipeModule): recipe_id = "emailverification" diff --git a/supertokens_python/recipe/multitenancy/utils.py b/supertokens_python/recipe/multitenancy/utils.py index 3448e2e0b..786b1aff6 100644 --- a/supertokens_python/recipe/multitenancy/utils.py +++ b/supertokens_python/recipe/multitenancy/utils.py @@ -83,11 +83,6 @@ class MultitenancyConfig(BaseConfig[RecipeInterface, APIInterface]): def validate_and_normalise_user_input( input_config: MultitenancyInputConfig, ) -> MultitenancyConfig: - if input_config.override is not None and not isinstance( - input_config.override, InputOverrideConfig - ): # type: ignore - raise ValueError("override must be of type InputOverrideConfig or None") - override_config = OverrideConfig() if input_config.override is not None: if input_config.override.functions is not None: diff --git a/supertokens_python/recipe/openid/interfaces.py b/supertokens_python/recipe/openid/interfaces.py index 691c911db..6719b129c 100644 --- a/supertokens_python/recipe/openid/interfaces.py +++ b/supertokens_python/recipe/openid/interfaces.py @@ -12,7 +12,7 @@ # License for the specific language governing permissions and limitations # under the License. from abc import abstractmethod -from typing import Any, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from supertokens_python.framework import BaseRequest, BaseResponse from supertokens_python.recipe.jwt.interfaces import ( @@ -23,7 +23,8 @@ from supertokens_python.types.recipe import BaseAPIInterface, BaseRecipeInterface from supertokens_python.types.response import APIResponse, GeneralErrorResponse -from .utils import OpenIdConfig +if TYPE_CHECKING: + from .utils import OpenIdConfig class GetOpenIdDiscoveryConfigurationResult: @@ -102,7 +103,7 @@ def __init__( request: BaseRequest, response: BaseResponse, recipe_id: str, - config: OpenIdConfig, + config: "OpenIdConfig", recipe_implementation: RecipeInterface, ): self.request = request diff --git a/supertokens_python/recipe/passwordless/interfaces.py b/supertokens_python/recipe/passwordless/interfaces.py index 32f9e37fa..2807e0695 100644 --- a/supertokens_python/recipe/passwordless/interfaces.py +++ b/supertokens_python/recipe/passwordless/interfaces.py @@ -14,7 +14,7 @@ from __future__ import annotations from abc import abstractmethod -from typing import Any, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from typing_extensions import Literal @@ -38,7 +38,9 @@ PasswordlessLoginSMSTemplateVars, SMSDeliveryIngredient, ) -from .utils import PasswordlessConfig + +if TYPE_CHECKING: + from .utils import PasswordlessConfig class CreateCodeOkResult: diff --git a/supertokens_python/recipe/session/__init__.py b/supertokens_python/recipe/session/__init__.py index 7692e160b..799221eb7 100644 --- a/supertokens_python/recipe/session/__init__.py +++ b/supertokens_python/recipe/session/__init__.py @@ -13,18 +13,17 @@ # under the License. from __future__ import annotations -from typing import TYPE_CHECKING, Any, Callable, Dict, Union +from typing import Any, Callable, Dict, Union from typing_extensions import Literal -if TYPE_CHECKING: - from supertokens_python.supertokens import BaseRequest, RecipeInit - - from .utils import TokenTransferMethod +from supertokens_python.framework import BaseRequest +from supertokens_python.supertokens import RecipeInit from . import exceptions as ex from . import interfaces, utils from .recipe import SessionRecipe +from .utils import TokenTransferMethod InputErrorHandlers = utils.InputErrorHandlers InputOverrideConfig = utils.InputOverrideConfig diff --git a/supertokens_python/recipe/session/interfaces.py b/supertokens_python/recipe/session/interfaces.py index 2fe1ec756..3ce2ca036 100644 --- a/supertokens_python/recipe/session/interfaces.py +++ b/supertokens_python/recipe/session/interfaces.py @@ -29,6 +29,7 @@ from typing_extensions import TypedDict from supertokens_python.async_to_sync_wrapper import sync +from supertokens_python.framework import BaseRequest, BaseResponse from supertokens_python.types import ( MaybeAwaitable, RecipeUserId, @@ -38,12 +39,9 @@ from ...utils import resolve from .exceptions import ClaimValidationError -from .utils import SessionConfig, TokenTransferMethod if TYPE_CHECKING: - from supertokens_python.framework import BaseRequest - -from supertokens_python.framework import BaseResponse + from .utils import SessionConfig, TokenTransferMethod class SessionObj: diff --git a/supertokens_python/recipe/session/utils.py b/supertokens_python/recipe/session/utils.py index f4d137c90..75a0db200 100644 --- a/supertokens_python/recipe/session/utils.py +++ b/supertokens_python/recipe/session/utils.py @@ -20,7 +20,7 @@ from typing_extensions import Literal from supertokens_python.exceptions import raise_general_exception -from supertokens_python.framework import BaseResponse +from supertokens_python.framework import BaseRequest, BaseResponse from supertokens_python.normalised_url_path import NormalisedURLPath from supertokens_python.types.config import ( BaseConfig, @@ -47,7 +47,6 @@ ) if TYPE_CHECKING: - from supertokens_python.framework import BaseRequest from supertokens_python.recipe.openid import ( InputOverrideConfig as OpenIdInputOverrideConfig, ) diff --git a/supertokens_python/recipe/webauthn/interfaces/api.py b/supertokens_python/recipe/webauthn/interfaces/api.py index 819969ccc..fed8d64aa 100644 --- a/supertokens_python/recipe/webauthn/interfaces/api.py +++ b/supertokens_python/recipe/webauthn/interfaces/api.py @@ -36,7 +36,6 @@ SignInOptionsErrorResponse, UserVerification, ) -from supertokens_python.recipe.webauthn.types.config import NormalisedWebauthnConfig from supertokens_python.supertokens import AppInfo from supertokens_python.types import RecipeUserId, User from supertokens_python.types.base import UserContext @@ -48,6 +47,9 @@ StatusReasonResponseBaseModel, ) +if TYPE_CHECKING: + from supertokens_python.recipe.webauthn.types.config import NormalisedWebauthnConfig + class SignUpNotAllowedErrorResponse( StatusReasonResponseBaseModel[Literal["SIGN_UP_NOT_ALLOWED"], str] @@ -94,7 +96,7 @@ class TypeWebauthnRecoverAccountEmailDeliveryInput(CamelCaseBaseModel): class APIOptions(CamelCaseBaseModel): recipe_implementation: RecipeInterface app_info: AppInfo - config: NormalisedWebauthnConfig + config: "NormalisedWebauthnConfig" recipe_id: str req: BaseRequest res: BaseResponse diff --git a/supertokens_python/recipe/webauthn/interfaces/recipe.py b/supertokens_python/recipe/webauthn/interfaces/recipe.py index 1cb49389d..8f22da760 100644 --- a/supertokens_python/recipe/webauthn/interfaces/recipe.py +++ b/supertokens_python/recipe/webauthn/interfaces/recipe.py @@ -12,7 +12,7 @@ # License for the specific language governing permissions and limitations # under the License. -from abc import ABC, abstractmethod +from abc import abstractmethod from typing import ( Any, Dict, diff --git a/supertokens_python/recipe/webauthn/types/config.py b/supertokens_python/recipe/webauthn/types/config.py index 37e32dd14..c22254b6f 100644 --- a/supertokens_python/recipe/webauthn/types/config.py +++ b/supertokens_python/recipe/webauthn/types/config.py @@ -14,7 +14,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Optional, Protocol, TypeVar, Union, runtime_checkable +from typing import Optional, Protocol, TypeVar, Union, runtime_checkable from supertokens_python.framework import BaseRequest from supertokens_python.ingredients.emaildelivery import EmailDeliveryIngredient @@ -22,6 +22,11 @@ EmailDeliveryConfig, EmailDeliveryConfigWithService, ) +from supertokens_python.recipe.webauthn.interfaces.api import ( + APIInterface, + TypeWebauthnEmailDeliveryInput, +) +from supertokens_python.recipe.webauthn.interfaces.recipe import RecipeInterface from supertokens_python.types.base import UserContext from supertokens_python.types.config import ( BaseConfig, @@ -31,13 +36,6 @@ ) from supertokens_python.types.response import CamelCaseBaseModel -if TYPE_CHECKING: - from supertokens_python.recipe.webauthn.interfaces.api import ( - APIInterface, - TypeWebauthnEmailDeliveryInput, - ) - from supertokens_python.recipe.webauthn.interfaces.recipe import RecipeInterface - InterfaceType = TypeVar("InterfaceType") """Generic Type for use in `InterfaceOverride`""" From 4f43c373459592ed1f90d9ff134d36dedea4a42a Mon Sep 17 00:00:00 2001 From: Namit Nathwani Date: Tue, 1 Jul 2025 15:10:01 +0530 Subject: [PATCH 13/37] try to make types work with subclasses --- supertokens_python/plugins.py | 57 +++++++++++++------ .../recipe/accountlinking/types.py | 3 + supertokens_python/recipe/dashboard/utils.py | 3 + .../recipe/emailpassword/utils.py | 3 + .../recipe/emailverification/utils.py | 7 +-- supertokens_python/recipe/jwt/utils.py | 3 + .../recipe/multifactorauth/types.py | 3 + .../recipe/multitenancy/utils.py | 3 + .../recipe/oauth2provider/utils.py | 9 ++- supertokens_python/recipe/openid/utils.py | 7 ++- .../recipe/passwordless/utils.py | 5 +- supertokens_python/recipe/session/utils.py | 10 ++-- supertokens_python/recipe/thirdparty/utils.py | 3 + supertokens_python/recipe/totp/types.py | 3 + .../recipe/usermetadata/utils.py | 9 ++- supertokens_python/recipe/userroles/utils.py | 3 + .../recipe/webauthn/types/config.py | 3 + supertokens_python/types/config.py | 32 +++++------ supertokens_python/types/utils.py | 16 ++++++ 19 files changed, 131 insertions(+), 51 deletions(-) create mode 100644 supertokens_python/types/utils.py diff --git a/supertokens_python/plugins.py b/supertokens_python/plugins.py index 77a5706a8..131fb3449 100644 --- a/supertokens_python/plugins.py +++ b/supertokens_python/plugins.py @@ -47,6 +47,8 @@ from supertokens_python.post_init_callbacks import PostSTInitCallbacks from supertokens_python.types import MaybeAwaitable from supertokens_python.types.base import UserContext +from supertokens_python.types.config import BaseConfig, BaseConfigWithoutAPIOverride +from supertokens_python.types.recipe import BaseAPIInterface, BaseRecipeInterface from supertokens_python.types.response import CamelCaseBaseModel if TYPE_CHECKING: @@ -56,7 +58,10 @@ ) from supertokens_python.supertokens import SupertokensPublicConfig -T = TypeVar("T") +RecipeInterfaceType = TypeVar("RecipeInterfaceType", bound=BaseRecipeInterface) +APIInterfaceType = TypeVar("APIInterfaceType", bound=BaseAPIInterface) +ConfigType = BaseConfig[RecipeInterfaceType, APIInterfaceType] +# T = TypeVar("T", bound=ConfigType) # T = TypeVar("T", bound=Union[AccountLinkingConfig, DashboardConfig, EmailPasswordConfig, # EmailVerificationConfig, JWTConfig, MultiFactorAuthConfig, MultitenancyConfig, # OAuth2ProviderConfig, OpenIdConfig, PasswordlessConfig, SessionConfig, @@ -313,20 +318,31 @@ class ConfigOverrideBase: apis: Optional[Callable[[Any], Any]] = None -# TODO: Pass in the OverrideConfig class as an arg, use it to define a default if None -def apply_plugins(recipe_id: str, config: T, plugins: List[OverrideMap]) -> T: - # TODO: Change to recipe_implementation type - def default_fn_override(original_implementation: T) -> T: +def apply_plugins( + recipe_id: str, + config: Union[ + BaseConfig[RecipeInterfaceType, APIInterfaceType], + BaseConfigWithoutAPIOverride[RecipeInterfaceType], + ], + plugins: List[OverrideMap], +) -> Union[ + BaseConfig[RecipeInterfaceType, APIInterfaceType], + BaseConfigWithoutAPIOverride[RecipeInterfaceType], +]: + def default_fn_override( + original_implementation: RecipeInterfaceType, + ) -> RecipeInterfaceType: return original_implementation - # TODO: Change to api_implementation type - def default_api_override(original_implementation: T) -> T: + def default_api_override( + original_implementation: APIInterfaceType, + ) -> APIInterfaceType: return original_implementation - if config.override is None: - config.override = ConfigOverrideBase() - config.override.functions = default_fn_override - config.override.apis = default_api_override + if config.override is None: # type: ignore + raise TypeError( + f"Expected config.override to not be `None`. {recipe_id=} {config=}" + ) function_overrides = getattr(config.override, "functions", default_fn_override) api_overrides = getattr(config.override, "apis", default_api_override) @@ -358,8 +374,10 @@ def default_api_override(original_implementation: T) -> T: # Apply overrides in reverse order of definition # Plugins: [plugin1, plugin2] would be applied as [override, plugin2, plugin1, original] if len(function_layers) > 0: - # TODO: Change to recipe_interface type - def fn_override(original_implementation: T) -> T: + + def fn_override( + original_implementation: RecipeInterfaceType, + ) -> RecipeInterfaceType: # The layers will get called in reversed order for function_layer in function_layers: original_implementation = function_layer(original_implementation) @@ -367,10 +385,15 @@ def fn_override(original_implementation: T) -> T: config.override.functions = fn_override - # AccountLinking recipe does not have an API implementation - if len(api_layers) > 0 and recipe_id != "accountlinking": - # TODO: Change to api_interface type - def api_override(original_implementation: T) -> T: + if ( + len(api_layers) > 0 + # AccountLinking recipe does not have an API implementation, uses `BaseConfigWithoutAPIOverride` as base + and recipe_id != "accountlinking" + # `BaseConfig` is the base class for all configs with an API override. + and isinstance(config, BaseConfig) + ): + + def api_override(original_implementation: APIInterfaceType) -> APIInterfaceType: for api_layer in api_layers: original_implementation = api_layer(original_implementation) return original_implementation diff --git a/supertokens_python/recipe/accountlinking/types.py b/supertokens_python/recipe/accountlinking/types.py index ae8908edc..169430406 100644 --- a/supertokens_python/recipe/accountlinking/types.py +++ b/supertokens_python/recipe/accountlinking/types.py @@ -28,6 +28,7 @@ BaseInputOverrideConfigWithoutAPI, BaseOverrideConfigWithoutAPI, ) +from supertokens_python.types.utils import UseDefaultIfNone if TYPE_CHECKING: from supertokens_python.recipe.session import SessionContainer @@ -159,6 +160,7 @@ class AccountLinkingInputConfig(BaseInputConfigWithoutAPIOverride[RecipeInterfac Awaitable[Union[ShouldNotAutomaticallyLink, ShouldAutomaticallyLink]], ] ] = None + override: UseDefaultIfNone[Optional[InputOverrideConfig]] = InputOverrideConfig() # type: ignore - https://github.com/microsoft/pyright/issues/5933 class AccountLinkingConfig(BaseConfigWithoutAPIOverride[RecipeInterface]): @@ -175,3 +177,4 @@ class AccountLinkingConfig(BaseConfigWithoutAPIOverride[RecipeInterface]): ], Awaitable[Union[ShouldNotAutomaticallyLink, ShouldAutomaticallyLink]], ] + override: OverrideConfig # type: ignore - https://github.com/microsoft/pyright/issues/5933 diff --git a/supertokens_python/recipe/dashboard/utils.py b/supertokens_python/recipe/dashboard/utils.py index fdedddeab..552368631 100644 --- a/supertokens_python/recipe/dashboard/utils.py +++ b/supertokens_python/recipe/dashboard/utils.py @@ -24,6 +24,7 @@ BaseInputOverrideConfig, BaseOverrideConfig, ) +from supertokens_python.types.utils import UseDefaultIfNone if TYPE_CHECKING: from supertokens_python.framework.request import BaseRequest @@ -86,12 +87,14 @@ class OverrideConfig(BaseOverrideConfig[RecipeInterface, APIInterface]): ... class DashboardInputConfig(BaseInputConfig[RecipeInterface, APIInterface]): api_key: Optional[str] = None admins: Optional[List[str]] = None + override: UseDefaultIfNone[Optional[InputOverrideConfig]] = InputOverrideConfig() # type: ignore - https://github.com/microsoft/pyright/issues/5933 class DashboardConfig(BaseConfig[RecipeInterface, APIInterface]): api_key: Optional[str] admins: Optional[List[str]] auth_mode: str + override: OverrideConfig # type: ignore - https://github.com/microsoft/pyright/issues/5933 def validate_and_normalise_user_input( diff --git a/supertokens_python/recipe/emailpassword/utils.py b/supertokens_python/recipe/emailpassword/utils.py index 3dc72b230..dd19da841 100644 --- a/supertokens_python/recipe/emailpassword/utils.py +++ b/supertokens_python/recipe/emailpassword/utils.py @@ -30,6 +30,7 @@ BaseInputOverrideConfig, BaseOverrideConfig, ) +from supertokens_python.types.utils import UseDefaultIfNone from .interfaces import APIInterface, RecipeInterface from .types import EmailTemplateVars, InputFormField, NormalisedFormField @@ -228,6 +229,7 @@ class OverrideConfig(BaseOverrideConfig[RecipeInterface, APIInterface]): ... class EmailPasswordInputConfig(BaseInputConfig[RecipeInterface, APIInterface]): sign_up_feature: Union[InputSignUpFeature, None] = None email_delivery: Union[EmailDeliveryConfig[EmailTemplateVars], None] = None + override: UseDefaultIfNone[Optional[InputOverrideConfig]] = InputOverrideConfig() # type: ignore - https://github.com/microsoft/pyright/issues/5933 class EmailPasswordConfig(BaseConfig[RecipeInterface, APIInterface]): @@ -237,6 +239,7 @@ class EmailPasswordConfig(BaseConfig[RecipeInterface, APIInterface]): get_email_delivery_config: Callable[ [RecipeInterface], EmailDeliveryConfigWithService[EmailTemplateVars] ] + override: OverrideConfig # type: ignore - https://github.com/microsoft/pyright/issues/5933 def validate_and_normalise_user_input( diff --git a/supertokens_python/recipe/emailverification/utils.py b/supertokens_python/recipe/emailverification/utils.py index 251687c67..45b2720cd 100644 --- a/supertokens_python/recipe/emailverification/utils.py +++ b/supertokens_python/recipe/emailverification/utils.py @@ -32,6 +32,7 @@ BaseInputOverrideConfig, BaseOverrideConfig, ) +from supertokens_python.types.utils import UseDefaultIfNone from .interfaces import APIInterface, RecipeInterface, TypeGetEmailForUserIdFunction @@ -56,6 +57,7 @@ class EmailVerificationInputConfig(BaseInputConfig[RecipeInterface, APIInterface mode: MODE_TYPE email_delivery: Union[EmailDeliveryConfig[EmailTemplateVars], None] = None get_email_for_recipe_user_id: Optional[TypeGetEmailForUserIdFunction] = None + override: UseDefaultIfNone[Optional[InputOverrideConfig]] = InputOverrideConfig() # type: ignore - https://github.com/microsoft/pyright/issues/5933 class EmailVerificationConfig(BaseConfig[RecipeInterface, APIInterface]): @@ -64,15 +66,12 @@ class EmailVerificationConfig(BaseConfig[RecipeInterface, APIInterface]): [], EmailDeliveryConfigWithService[VerificationEmailTemplateVars] ] get_email_for_recipe_user_id: Optional[TypeGetEmailForUserIdFunction] + override: OverrideConfig # type: ignore - https://github.com/microsoft/pyright/issues/5933 def validate_and_normalise_user_input( app_info: AppInfo, input_config: EmailVerificationInputConfig, - # mode: MODE_TYPE, - # email_delivery: Union[EmailDeliveryConfig[EmailTemplateVars], None] = None, - # get_email_for_recipe_user_id: Optional[TypeGetEmailForUserIdFunction] = None, - # override: Union[OverrideConfig, None] = None, ) -> EmailVerificationConfig: if input_config.mode not in ["REQUIRED", "OPTIONAL"]: raise ValueError( diff --git a/supertokens_python/recipe/jwt/utils.py b/supertokens_python/recipe/jwt/utils.py index 6ed829407..6c3474663 100644 --- a/supertokens_python/recipe/jwt/utils.py +++ b/supertokens_python/recipe/jwt/utils.py @@ -21,6 +21,7 @@ BaseInputOverrideConfig, BaseOverrideConfig, ) +from supertokens_python.types.utils import UseDefaultIfNone from .interfaces import APIInterface, RecipeInterface @@ -33,10 +34,12 @@ class OverrideConfig(BaseOverrideConfig[RecipeInterface, APIInterface]): ... class JWTInputConfig(BaseInputConfig[RecipeInterface, APIInterface]): jwt_validity_seconds: Optional[int] = None + override: UseDefaultIfNone[Optional[InputOverrideConfig]] = InputOverrideConfig() # type: ignore - https://github.com/microsoft/pyright/issues/5933 class JWTConfig(BaseConfig[RecipeInterface, APIInterface]): jwt_validity_seconds: int + override: OverrideConfig # type: ignore - https://github.com/microsoft/pyright/issues/5933 def validate_and_normalise_user_input(input_config: JWTInputConfig): diff --git a/supertokens_python/recipe/multifactorauth/types.py b/supertokens_python/recipe/multifactorauth/types.py index 571561e14..323337e20 100644 --- a/supertokens_python/recipe/multifactorauth/types.py +++ b/supertokens_python/recipe/multifactorauth/types.py @@ -25,6 +25,7 @@ BaseInputOverrideConfig, BaseOverrideConfig, ) +from supertokens_python.types.utils import UseDefaultIfNone from .interfaces import APIInterface, RecipeInterface @@ -50,10 +51,12 @@ class OverrideConfig(BaseOverrideConfig[RecipeInterface, APIInterface]): ... class MultiFactorAuthInputConfig(BaseInputConfig[RecipeInterface, APIInterface]): first_factors: Optional[List[str]] = None + override: UseDefaultIfNone[Optional[InputOverrideConfig]] = InputOverrideConfig() # type: ignore - https://github.com/microsoft/pyright/issues/5933 class MultiFactorAuthConfig(BaseConfig[RecipeInterface, APIInterface]): first_factors: Optional[List[str]] + override: OverrideConfig # type: ignore - https://github.com/microsoft/pyright/issues/5933 class FactorIds: diff --git a/supertokens_python/recipe/multitenancy/utils.py b/supertokens_python/recipe/multitenancy/utils.py index 786b1aff6..84db1840e 100644 --- a/supertokens_python/recipe/multitenancy/utils.py +++ b/supertokens_python/recipe/multitenancy/utils.py @@ -24,6 +24,7 @@ BaseInputOverrideConfig, BaseOverrideConfig, ) +from supertokens_python.types.utils import UseDefaultIfNone from supertokens_python.utils import ( resolve, ) @@ -74,10 +75,12 @@ class OverrideConfig(BaseOverrideConfig[RecipeInterface, APIInterface]): ... class MultitenancyInputConfig(BaseInputConfig[RecipeInterface, APIInterface]): get_allowed_domains_for_tenant_id: Optional[TypeGetAllowedDomainsForTenantId] = None + override: UseDefaultIfNone[Optional[InputOverrideConfig]] = InputOverrideConfig() # type: ignore - https://github.com/microsoft/pyright/issues/5933 class MultitenancyConfig(BaseConfig[RecipeInterface, APIInterface]): get_allowed_domains_for_tenant_id: Optional[TypeGetAllowedDomainsForTenantId] + override: OverrideConfig # type: ignore - https://github.com/microsoft/pyright/issues/5933 def validate_and_normalise_user_input( diff --git a/supertokens_python/recipe/oauth2provider/utils.py b/supertokens_python/recipe/oauth2provider/utils.py index 8fc7f5cc8..f33f19782 100644 --- a/supertokens_python/recipe/oauth2provider/utils.py +++ b/supertokens_python/recipe/oauth2provider/utils.py @@ -13,12 +13,15 @@ # under the License. from __future__ import annotations +from typing import Optional + from supertokens_python.types.config import ( BaseConfig, BaseInputConfig, BaseInputOverrideConfig, BaseOverrideConfig, ) +from supertokens_python.types.utils import UseDefaultIfNone from .interfaces import APIInterface, RecipeInterface @@ -29,10 +32,12 @@ class InputOverrideConfig(BaseInputOverrideConfig[RecipeInterface, APIInterface] class OverrideConfig(BaseOverrideConfig[RecipeInterface, APIInterface]): ... -class OAuth2ProviderInputConfig(BaseInputConfig[RecipeInterface, APIInterface]): ... +class OAuth2ProviderInputConfig(BaseInputConfig[RecipeInterface, APIInterface]): + override: UseDefaultIfNone[Optional[InputOverrideConfig]] = InputOverrideConfig() # type: ignore - https://github.com/microsoft/pyright/issues/5933 -class OAuth2ProviderConfig(BaseConfig[RecipeInterface, APIInterface]): ... +class OAuth2ProviderConfig(BaseConfig[RecipeInterface, APIInterface]): + override: OverrideConfig # type: ignore - https://github.com/microsoft/pyright/issues/5933 def validate_and_normalise_user_input(input_config: OAuth2ProviderInputConfig): diff --git a/supertokens_python/recipe/openid/utils.py b/supertokens_python/recipe/openid/utils.py index f1c18a64f..821c14213 100644 --- a/supertokens_python/recipe/openid/utils.py +++ b/supertokens_python/recipe/openid/utils.py @@ -13,18 +13,19 @@ # under the License. from __future__ import annotations -from typing import TYPE_CHECKING, Union +from typing import TYPE_CHECKING, Optional, Union +from supertokens_python.recipe.jwt import OverrideConfig as JWTOverrideConfig from supertokens_python.types.config import ( BaseConfig, BaseInputConfig, BaseInputOverrideConfig, BaseOverrideConfig, ) +from supertokens_python.types.utils import UseDefaultIfNone if TYPE_CHECKING: from supertokens_python import AppInfo - from supertokens_python.recipe.jwt import OverrideConfig as JWTOverrideConfig from supertokens_python.normalised_url_domain import NormalisedURLDomain @@ -42,11 +43,13 @@ class OverrideConfig(BaseOverrideConfig[RecipeInterface, APIInterface]): ... class OpenIdInputConfig(BaseInputConfig[RecipeInterface, APIInterface]): issuer: Union[str, None] = None + override: UseDefaultIfNone[Optional[InputOverrideConfig]] = InputOverrideConfig() # type: ignore - https://github.com/microsoft/pyright/issues/5933 class OpenIdConfig(BaseConfig[RecipeInterface, APIInterface]): issuer_domain: NormalisedURLDomain issuer_path: NormalisedURLPath + override: OverrideConfig # type: ignore - https://github.com/microsoft/pyright/issues/5933 def validate_and_normalise_user_input( diff --git a/supertokens_python/recipe/passwordless/utils.py b/supertokens_python/recipe/passwordless/utils.py index 5f4901ab1..4487be33b 100644 --- a/supertokens_python/recipe/passwordless/utils.py +++ b/supertokens_python/recipe/passwordless/utils.py @@ -16,7 +16,7 @@ from abc import ABC from re import fullmatch -from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Union +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional, Union from phonenumbers import is_valid_number, parse from typing_extensions import Literal @@ -45,6 +45,7 @@ BaseInputOverrideConfig, BaseOverrideConfig, ) +from supertokens_python.types.utils import UseDefaultIfNone from .interfaces import ( APIInterface, @@ -155,6 +156,7 @@ class PasswordlessInputConfig(BaseInputConfig[RecipeInterface, APIInterface]): sms_delivery: Union[SMSDeliveryConfig[PasswordlessLoginSMSTemplateVars], None] = ( None ) + override: UseDefaultIfNone[Optional[InputOverrideConfig]] = InputOverrideConfig() # type: ignore - https://github.com/microsoft/pyright/issues/5933 class PasswordlessConfig(BaseConfig[RecipeInterface, APIInterface]): @@ -171,6 +173,7 @@ class PasswordlessConfig(BaseConfig[RecipeInterface, APIInterface]): get_custom_user_input_code: Union[ Callable[[str, Dict[str, Any]], Awaitable[str]], None ] + override: OverrideConfig # type: ignore - https://github.com/microsoft/pyright/issues/5933 def validate_and_normalise_user_input( diff --git a/supertokens_python/recipe/session/utils.py b/supertokens_python/recipe/session/utils.py index 75a0db200..65a24e6ee 100644 --- a/supertokens_python/recipe/session/utils.py +++ b/supertokens_python/recipe/session/utils.py @@ -22,12 +22,16 @@ from supertokens_python.exceptions import raise_general_exception from supertokens_python.framework import BaseRequest, BaseResponse from supertokens_python.normalised_url_path import NormalisedURLPath +from supertokens_python.recipe.openid import ( + InputOverrideConfig as OpenIdInputOverrideConfig, +) from supertokens_python.types.config import ( BaseConfig, BaseInputConfig, BaseInputOverrideConfig, BaseOverrideConfig, ) +from supertokens_python.types.utils import UseDefaultIfNone from supertokens_python.utils import ( is_an_ip_address, resolve, @@ -47,9 +51,6 @@ ) if TYPE_CHECKING: - from supertokens_python.recipe.openid import ( - InputOverrideConfig as OpenIdInputOverrideConfig, - ) from supertokens_python.supertokens import AppInfo from .recipe import SessionRecipe @@ -369,6 +370,7 @@ class SessionInputConfig(BaseInputConfig[RecipeInterface, APIInterface]): use_dynamic_access_token_signing_key: Union[bool, None] = None expose_access_token_to_frontend_in_cookie_based_auth: Union[bool, None] = None jwks_refresh_interval_sec: Union[int, None] = None + override: UseDefaultIfNone[Optional[InputOverrideConfig]] = InputOverrideConfig() # type: ignore - https://github.com/microsoft/pyright/issues/5933 class SessionConfig(BaseConfig[RecipeInterface, APIInterface]): @@ -393,13 +395,13 @@ class SessionConfig(BaseConfig[RecipeInterface, APIInterface]): [BaseRequest, bool, Dict[str, Any]], Union[TokenTransferMethod, Literal["any"]], ] - # override: OverrideConfig, framework: str mode: str invalid_claim_status_code: int use_dynamic_access_token_signing_key: bool expose_access_token_to_frontend_in_cookie_based_auth: bool jwks_refresh_interval_sec: int + override: OverrideConfig # type: ignore - https://github.com/microsoft/pyright/issues/5933 def validate_and_normalise_user_input( diff --git a/supertokens_python/recipe/thirdparty/utils.py b/supertokens_python/recipe/thirdparty/utils.py index 0a4479fdd..1b4784ac6 100644 --- a/supertokens_python/recipe/thirdparty/utils.py +++ b/supertokens_python/recipe/thirdparty/utils.py @@ -23,6 +23,7 @@ BaseInputOverrideConfig, BaseOverrideConfig, ) +from supertokens_python.types.utils import UseDefaultIfNone from .interfaces import APIInterface, RecipeInterface @@ -60,10 +61,12 @@ class OverrideConfig(BaseOverrideConfig[RecipeInterface, APIInterface]): ... class ThirdPartyInputConfig(BaseInputConfig[RecipeInterface, APIInterface]): sign_in_and_up_feature: SignInAndUpFeature + override: UseDefaultIfNone[Optional[InputOverrideConfig]] = InputOverrideConfig() # type: ignore - https://github.com/microsoft/pyright/issues/5933 class ThirdPartyConfig(BaseConfig[RecipeInterface, APIInterface]): sign_in_and_up_feature: SignInAndUpFeature + override: OverrideConfig # type: ignore - https://github.com/microsoft/pyright/issues/5933 def validate_and_normalise_user_input( diff --git a/supertokens_python/recipe/totp/types.py b/supertokens_python/recipe/totp/types.py index 338be8725..cceb22b18 100644 --- a/supertokens_python/recipe/totp/types.py +++ b/supertokens_python/recipe/totp/types.py @@ -23,6 +23,7 @@ BaseOverrideConfig, ) from supertokens_python.types.response import APIResponse +from supertokens_python.types.utils import UseDefaultIfNone from .interfaces import APIInterface, RecipeInterface @@ -193,9 +194,11 @@ class TOTPConfig(BaseInputConfig[RecipeInterface, APIInterface]): issuer: Optional[str] = None default_skew: Optional[int] = None default_period: Optional[int] = None + override: UseDefaultIfNone[Optional[OverrideConfig]] = OverrideConfig() # type: ignore - https://github.com/microsoft/pyright/issues/5933 class TOTPNormalisedConfig(BaseConfig[RecipeInterface, APIInterface]): issuer: str default_skew: int default_period: int + override: NormalisedOverrideConfig # type: ignore - https://github.com/microsoft/pyright/issues/5933 diff --git a/supertokens_python/recipe/usermetadata/utils.py b/supertokens_python/recipe/usermetadata/utils.py index 4b226aaed..c8bf260bd 100644 --- a/supertokens_python/recipe/usermetadata/utils.py +++ b/supertokens_python/recipe/usermetadata/utils.py @@ -14,7 +14,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional from supertokens_python.recipe.usermetadata.interfaces import ( APIInterface, @@ -26,6 +26,7 @@ BaseInputOverrideConfig, BaseOverrideConfig, ) +from supertokens_python.types.utils import UseDefaultIfNone if TYPE_CHECKING: from supertokens_python.recipe.usermetadata.recipe import UserMetadataRecipe @@ -38,10 +39,12 @@ class InputOverrideConfig(BaseInputOverrideConfig[RecipeInterface, APIInterface] class OverrideConfig(BaseOverrideConfig[RecipeInterface, APIInterface]): ... -class UserMetadataInputConfig(BaseInputConfig[RecipeInterface, APIInterface]): ... +class UserMetadataInputConfig(BaseInputConfig[RecipeInterface, APIInterface]): + override: UseDefaultIfNone[Optional[InputOverrideConfig]] = InputOverrideConfig() # type: ignore - https://github.com/microsoft/pyright/issues/5933 -class UserMetadataConfig(BaseConfig[RecipeInterface, APIInterface]): ... +class UserMetadataConfig(BaseConfig[RecipeInterface, APIInterface]): + override: OverrideConfig # type: ignore - https://github.com/microsoft/pyright/issues/5933 def validate_and_normalise_user_input( diff --git a/supertokens_python/recipe/userroles/utils.py b/supertokens_python/recipe/userroles/utils.py index 17a17e9d3..707640e76 100644 --- a/supertokens_python/recipe/userroles/utils.py +++ b/supertokens_python/recipe/userroles/utils.py @@ -24,6 +24,7 @@ BaseInputOverrideConfig, BaseOverrideConfig, ) +from supertokens_python.types.utils import UseDefaultIfNone if TYPE_CHECKING: from supertokens_python.recipe.userroles.recipe import UserRolesRecipe @@ -38,11 +39,13 @@ class OverrideConfig(BaseOverrideConfig[RecipeInterface, APIInterface]): ... class UserRolesInputConfig(BaseInputConfig[RecipeInterface, APIInterface]): skip_adding_roles_to_access_token: Optional[bool] = None skip_adding_permissions_to_access_token: Optional[bool] = None + override: UseDefaultIfNone[Optional[InputOverrideConfig]] = InputOverrideConfig() # type: ignore - https://github.com/microsoft/pyright/issues/5933 class UserRolesConfig(BaseConfig[RecipeInterface, APIInterface]): skip_adding_roles_to_access_token: bool skip_adding_permissions_to_access_token: bool + override: OverrideConfig # type: ignore - https://github.com/microsoft/pyright/issues/5933 def validate_and_normalise_user_input( diff --git a/supertokens_python/recipe/webauthn/types/config.py b/supertokens_python/recipe/webauthn/types/config.py index c22254b6f..8e3835b12 100644 --- a/supertokens_python/recipe/webauthn/types/config.py +++ b/supertokens_python/recipe/webauthn/types/config.py @@ -35,6 +35,7 @@ BaseOverrideConfig, ) from supertokens_python.types.response import CamelCaseBaseModel +from supertokens_python.types.utils import UseDefaultIfNone InterfaceType = TypeVar("InterfaceType") """Generic Type for use in `InterfaceOverride`""" @@ -195,6 +196,7 @@ class WebauthnConfig(BaseInputConfig[RecipeInterface, APIInterface]): get_origin: Optional[GetOrigin] = None email_delivery: Optional[EmailDeliveryConfig[TypeWebauthnEmailDeliveryInput]] = None validate_email_address: Optional[ValidateEmailAddress] = None + override: UseDefaultIfNone[Optional[OverrideConfig]] = OverrideConfig() # type: ignore - https://github.com/microsoft/pyright/issues/5933 class NormalisedWebauthnConfig(BaseConfig[RecipeInterface, APIInterface]): @@ -203,6 +205,7 @@ class NormalisedWebauthnConfig(BaseConfig[RecipeInterface, APIInterface]): get_origin: NormalisedGetOrigin get_email_delivery_config: NormalisedGetEmailDeliveryConfig validate_email_address: NormalisedValidateEmailAddress + override: NormalisedOverrideConfig # type: ignore - https://github.com/microsoft/pyright/issues/5933 class WebauthnIngredients(CamelCaseBaseModel): diff --git a/supertokens_python/types/config.py b/supertokens_python/types/config.py index 43a20d3cb..8ad5c3d95 100644 --- a/supertokens_python/types/config.py +++ b/supertokens_python/types/config.py @@ -1,11 +1,14 @@ -from typing import Callable, Generic, Optional, TypeVar, Union +from typing import Callable, Generic, Optional, TypeVar from supertokens_python.types.recipe import BaseAPIInterface, BaseRecipeInterface from supertokens_python.types.response import CamelCaseBaseModel +from supertokens_python.types.utils import UseDefaultIfNone -InterfaceType = TypeVar( - "InterfaceType", bound=Union[BaseRecipeInterface, BaseAPIInterface] -) +T = TypeVar("T") + +# InterfaceType = TypeVar( +# "InterfaceType", bound=Union[BaseRecipeInterface, BaseAPIInterface], covariant=True +# ) """Generic Type for use in `InterfaceOverride`""" FunctionInterfaceType = TypeVar("FunctionInterfaceType", bound=BaseRecipeInterface) """Generic Type for use in `FunctionOverrideConfig`""" @@ -13,18 +16,7 @@ """Generic Type for use in `APIOverrideConfig`""" -InterfaceOverride = Callable[[InterfaceType], InterfaceType] - -# @runtime_checkable -# class InterfaceOverride(Protocol[InterfaceType]): -# """ -# Callable signature for `.override.*`. -# """ - -# def __call__( -# self, -# original_implementation: InterfaceType, -# ) -> InterfaceType: ... +InterfaceOverride = Callable[[T], T] class BaseInputOverrideConfigWithoutAPI( @@ -32,7 +24,9 @@ class BaseInputOverrideConfigWithoutAPI( ): """Base class for input override config without API overrides.""" - functions: Optional[InterfaceOverride[FunctionInterfaceType]] = None + functions: UseDefaultIfNone[Optional[InterfaceOverride[FunctionInterfaceType]]] = ( + lambda original_implementation: original_implementation + ) class BaseOverrideConfigWithoutAPI(CamelCaseBaseModel, Generic[FunctionInterfaceType]): @@ -49,7 +43,9 @@ class BaseInputOverrideConfig( ): """Base class for input override config with API overrides.""" - apis: Optional[InterfaceOverride[APIInterfaceType]] = None + apis: UseDefaultIfNone[Optional[InterfaceOverride[APIInterfaceType]]] = ( + lambda original_implementation: original_implementation + ) class BaseOverrideConfig( diff --git a/supertokens_python/types/utils.py b/supertokens_python/types/utils.py new file mode 100644 index 000000000..4e50039cb --- /dev/null +++ b/supertokens_python/types/utils.py @@ -0,0 +1,16 @@ +from typing import Any, TypeVar + +from pydantic import BeforeValidator +from pydantic_core import PydanticUseDefault +from typing_extensions import Annotated + +T = TypeVar("T") + + +def default_if_none(value: Any) -> Any: + if value is None: + return PydanticUseDefault() + return value + + +UseDefaultIfNone = Annotated[T, BeforeValidator(default_if_none)] From 20b173a0a91df2b7af86722afbe71f08ee01cf81 Mon Sep 17 00:00:00 2001 From: Namit Nathwani Date: Tue, 1 Jul 2025 20:15:10 +0530 Subject: [PATCH 14/37] refactor: rename classes for consistency, fix plugin types --- supertokens_python/plugins.py | 159 +++++++----------- .../recipe/accountlinking/__init__.py | 4 +- .../recipe/accountlinking/recipe.py | 23 ++- .../accountlinking/recipe_implementation.py | 4 +- .../recipe/accountlinking/types.py | 26 ++- .../recipe/accountlinking/utils.py | 52 +++--- .../recipe/dashboard/__init__.py | 4 +- .../recipe/dashboard/interfaces.py | 8 +- supertokens_python/recipe/dashboard/recipe.py | 16 +- .../recipe/dashboard/recipe_implementation.py | 4 +- supertokens_python/recipe/dashboard/utils.py | 51 +++--- .../recipe/emailpassword/__init__.py | 4 +- .../recipe/emailpassword/interfaces.py | 6 +- .../recipe/emailpassword/recipe.py | 25 +-- .../emailpassword/recipe_implementation.py | 4 +- .../recipe/emailpassword/utils.py | 59 +++---- .../recipe/emailverification/__init__.py | 6 +- .../recipe/emailverification/interfaces.py | 4 +- .../recipe/emailverification/recipe.py | 21 +-- .../recipe/emailverification/utils.py | 59 +++---- supertokens_python/recipe/jwt/__init__.py | 4 +- supertokens_python/recipe/jwt/interfaces.py | 4 +- supertokens_python/recipe/jwt/recipe.py | 19 ++- .../recipe/jwt/recipe_implementation.py | 6 +- supertokens_python/recipe/jwt/utils.py | 40 ++--- .../recipe/multifactorauth/__init__.py | 6 +- .../recipe/multifactorauth/interfaces.py | 4 +- .../recipe/multifactorauth/recipe.py | 19 ++- .../recipe/multifactorauth/types.py | 21 ++- .../recipe/multifactorauth/utils.py | 26 +-- .../recipe/multitenancy/__init__.py | 7 +- .../recipe/multitenancy/interfaces.py | 4 +- .../recipe/multitenancy/recipe.py | 19 ++- .../multitenancy/recipe_implementation.py | 4 +- .../recipe/multitenancy/utils.py | 43 +++-- .../recipe/oauth2provider/__init__.py | 4 +- .../recipe/oauth2provider/interfaces.py | 6 +- .../recipe/oauth2provider/recipe.py | 21 +-- .../recipe/oauth2provider/utils.py | 40 ++--- supertokens_python/recipe/openid/__init__.py | 4 +- .../recipe/openid/interfaces.py | 4 +- supertokens_python/recipe/openid/recipe.py | 22 ++- .../recipe/openid/recipe_implementation.py | 4 +- supertokens_python/recipe/openid/utils.py | 59 +++---- .../recipe/passwordless/__init__.py | 4 +- .../recipe/passwordless/interfaces.py | 4 +- .../recipe/passwordless/recipe.py | 19 ++- .../recipe/passwordless/utils.py | 72 ++++---- supertokens_python/recipe/session/__init__.py | 4 +- .../recipe/session/access_token.py | 4 +- .../recipe/session/cookie_and_header.py | 22 +-- .../recipe/session/interfaces.py | 6 +- supertokens_python/recipe/session/jwks.py | 6 +- supertokens_python/recipe/session/recipe.py | 49 +++--- .../recipe/session/recipe_implementation.py | 6 +- .../session/session_request_functions.py | 8 +- supertokens_python/recipe/session/utils.py | 87 ++++------ .../recipe/thirdparty/__init__.py | 4 +- .../recipe/thirdparty/interfaces.py | 6 +- .../recipe/thirdparty/recipe.py | 19 ++- supertokens_python/recipe/thirdparty/utils.py | 45 +++-- supertokens_python/recipe/totp/interfaces.py | 4 +- supertokens_python/recipe/totp/recipe.py | 11 +- .../recipe/totp/recipe_implementation.py | 4 +- supertokens_python/recipe/totp/types.py | 19 +-- supertokens_python/recipe/totp/utils.py | 10 +- .../recipe/usermetadata/__init__.py | 2 +- .../recipe/usermetadata/recipe.py | 17 +- .../recipe/usermetadata/utils.py | 31 ++-- .../recipe/userroles/__init__.py | 2 +- supertokens_python/recipe/userroles/recipe.py | 17 +- supertokens_python/recipe/userroles/utils.py | 48 +++--- .../recipe/webauthn/__init__.py | 4 +- supertokens_python/recipe/webauthn/recipe.py | 6 +- .../recipe/webauthn/types/config.py | 19 +-- supertokens_python/recipe/webauthn/utils.py | 4 +- supertokens_python/supertokens.py | 14 +- supertokens_python/types/config.py | 47 +++--- tests/Django/test_django.py | 22 +-- tests/Fastapi/test_fastapi.py | 40 +++-- tests/Flask/test_flask.py | 12 +- tests/auth-react/django3x/mysite/utils.py | 20 ++- tests/auth-react/fastapi-server/app.py | 20 ++- tests/auth-react/flask-server/app.py | 20 ++- tests/dashboard/test_dashboard.py | 18 +- tests/emailpassword/test_emailexists.py | 2 +- tests/emailpassword/test_emailverify.py | 14 +- tests/emailpassword/test_signin.py | 2 +- .../django2x/polls/views.py | 6 +- .../django3x/polls/views.py | 6 +- .../drf_async/mysite/settings.py | 2 +- .../drf_async/polls/views.py | 6 +- .../drf_sync/mysite/settings.py | 2 +- .../drf_sync/polls/views.py | 6 +- .../frontendIntegration/fastapi-server/app.py | 6 +- tests/frontendIntegration/flask-server/app.py | 6 +- tests/jwt/test_get_JWKS.py | 4 +- tests/jwt/test_override.py | 6 +- .../claims/test_create_new_session.py | 2 +- tests/sessions/claims/test_verify_session.py | 4 +- tests/sessions/claims/utils.py | 2 +- tests/telemetry/test_telemetry.py | 4 +- tests/test-server/app.py | 22 ++- tests/test_session.py | 12 +- tests/test_user_context.py | 12 +- tests/usermetadata/test_metadata.py | 6 +- 106 files changed, 897 insertions(+), 940 deletions(-) diff --git a/supertokens_python/plugins.py b/supertokens_python/plugins.py index 131fb3449..b04ed42db 100644 --- a/supertokens_python/plugins.py +++ b/supertokens_python/plugins.py @@ -1,10 +1,3 @@ -# TODOs: -# - [ ] Define base classes for: -# - Config -# - RecipeInterface -# - APIInterface -# - OverrideConfig - from collections import deque from dataclasses import dataclass from typing import ( @@ -28,64 +21,67 @@ from supertokens_python.framework.request import BaseRequest from supertokens_python.framework.response import BaseResponse from supertokens_python.logger import log_debug_message - -# from supertokens_python.recipe.accountlinking.types import AccountLinkingConfig -# from supertokens_python.recipe.dashboard.utils import DashboardConfig -# from supertokens_python.recipe.emailpassword.utils import EmailPasswordConfig -# from supertokens_python.recipe.emailverification.utils import EmailVerificationConfig -# from supertokens_python.recipe.jwt.utils import JWTConfig -# from supertokens_python.recipe.multifactorauth.types import MultiFactorAuthConfig -# from supertokens_python.recipe.multitenancy.utils import MultitenancyConfig -# from supertokens_python.recipe.oauth2provider.utils import OAuth2ProviderConfig -# from supertokens_python.recipe.openid.utils import OpenIdConfig -# from supertokens_python.recipe.passwordless.utils import PasswordlessConfig -# from supertokens_python.recipe.session.utils import SessionConfig -# from supertokens_python.recipe.thirdparty.utils import ThirdPartyConfig -# from supertokens_python.recipe.totp.types import TOTPConfig -# from supertokens_python.recipe.usermetadata.utils import UserMetadataConfig -# from supertokens_python.recipe.userroles.utils import UserRolesConfig -from supertokens_python.post_init_callbacks import PostSTInitCallbacks from supertokens_python.types import MaybeAwaitable from supertokens_python.types.base import UserContext -from supertokens_python.types.config import BaseConfig, BaseConfigWithoutAPIOverride +from supertokens_python.types.config import ( + BaseConfig, + BaseConfigWithoutAPIOverride, + BaseOverrideConfig, +) from supertokens_python.types.recipe import BaseAPIInterface, BaseRecipeInterface from supertokens_python.types.response import CamelCaseBaseModel if TYPE_CHECKING: - from supertokens_python.recipe.session.interfaces import ( - SessionClaimValidator, - SessionContainer, - ) + from supertokens_python.post_init_callbacks import PostSTInitCallbacks from supertokens_python.supertokens import SupertokensPublicConfig +from supertokens_python.recipe.accountlinking.types import AccountLinkingConfig +from supertokens_python.recipe.dashboard.utils import DashboardConfig +from supertokens_python.recipe.emailpassword.utils import EmailPasswordConfig +from supertokens_python.recipe.emailverification.utils import ( + EmailVerificationConfig, +) +from supertokens_python.recipe.jwt.utils import JWTConfig +from supertokens_python.recipe.multifactorauth.types import MultiFactorAuthConfig +from supertokens_python.recipe.multitenancy.utils import MultitenancyConfig +from supertokens_python.recipe.oauth2provider.utils import OAuth2ProviderConfig +from supertokens_python.recipe.openid.utils import OpenIdConfig +from supertokens_python.recipe.passwordless.utils import PasswordlessConfig +from supertokens_python.recipe.session.interfaces import ( + SessionClaimValidator, + SessionContainer, +) +from supertokens_python.recipe.session.utils import SessionConfig +from supertokens_python.recipe.thirdparty.utils import ThirdPartyConfig +from supertokens_python.recipe.totp.types import TOTPConfig +from supertokens_python.recipe.usermetadata.utils import UserMetadataConfig +from supertokens_python.recipe.userroles.utils import UserRolesConfig +from supertokens_python.recipe.webauthn.types.config import WebauthnConfig + +T = TypeVar( + "T", + bound=Union[ + AccountLinkingConfig, + DashboardConfig, + EmailPasswordConfig, + EmailVerificationConfig, + JWTConfig, + MultiFactorAuthConfig, + MultitenancyConfig, + OAuth2ProviderConfig, + OpenIdConfig, + PasswordlessConfig, + SessionConfig, + ThirdPartyConfig, + TOTPConfig, + UserMetadataConfig, + UserRolesConfig, + WebauthnConfig, + ], +) + RecipeInterfaceType = TypeVar("RecipeInterfaceType", bound=BaseRecipeInterface) APIInterfaceType = TypeVar("APIInterfaceType", bound=BaseAPIInterface) -ConfigType = BaseConfig[RecipeInterfaceType, APIInterfaceType] -# T = TypeVar("T", bound=ConfigType) -# T = TypeVar("T", bound=Union[AccountLinkingConfig, DashboardConfig, EmailPasswordConfig, -# EmailVerificationConfig, JWTConfig, MultiFactorAuthConfig, MultitenancyConfig, -# OAuth2ProviderConfig, OpenIdConfig, PasswordlessConfig, SessionConfig, -# ThirdPartyConfig, TOTPConfig, UserMetadataConfig, UserRolesConfig]) - - -# class AllRecipeConfigs: -# # These generally have no Input config type -# accountlinking: AccountLinkingConfig -# dashboard: DashboardConfig -# emailpassword: EmailPasswordConfig -# emailverification: EmailVerificationConfig -# jwt: JWTConfig -# multifactorauth: MultiFactorAuthConfig -# multitenancy: MultitenancyConfig -# oauth2provider: OAuth2ProviderConfig -# openid: OpenIdConfig -# passwordless: PasswordlessConfig -# session: SessionConfig -# thirdparty: ThirdPartyConfig -# totp: TOTPConfig # This is the input config type -# usermetadata: UserMetadataConfig -# userroles: UserRolesConfig -# # webauthn: WebauthnConfig class RecipePluginOverride: @@ -95,31 +91,6 @@ class RecipePluginOverride: config: Optional[Callable[[Any], Any]] -# export type AllRecipeConfigs = { -# accountlinking: AccountLinkingTypeInput & { override?: { apis: never } }; -# dashboard: DashboardTypeInput; -# emailpassword: EmailPasswordTypeInput; -# emailverification: EmailVerificationTypeInput; -# jwt: JWTTypeInput; -# multifactorauth: MultifactorAuthTypeInput; -# multitenancy: MultitenancyTypeInput; -# oauth2provider: OAuth2ProviderTypeInput; -# openid: OpenIdTypeInput; -# passwordless: PasswordlessTypeInput; -# session: SessionTypeInput; -# thirdparty: ThirdPartyTypeInput; -# totp: TotpTypeInput; -# usermetadata: UserMetadataTypeInput; -# userroles: UserRolesTypeInput; -# }; - -# export type RecipePluginOverride = { -# functions?: NonNullable["functions"]; -# apis?: NonNullable["apis"]; -# config?: (config: AllRecipeConfigs[T]) => AllRecipeConfigs[T]; -# }; - - class PluginRouteHandlerResponse(CamelCaseBaseModel): status: int body: Any @@ -320,15 +291,14 @@ class ConfigOverrideBase: def apply_plugins( recipe_id: str, - config: Union[ - BaseConfig[RecipeInterfaceType, APIInterfaceType], - BaseConfigWithoutAPIOverride[RecipeInterfaceType], - ], + config: T, plugins: List[OverrideMap], -) -> Union[ - BaseConfig[RecipeInterfaceType, APIInterfaceType], - BaseConfigWithoutAPIOverride[RecipeInterfaceType], -]: +) -> T: + if not isinstance(config, (BaseConfig, BaseConfigWithoutAPIOverride)): # type: ignore + raise TypeError( + f"Expected config to be an instance of BaseConfig or BaseConfigWithoutAPIOverride. {recipe_id=} {config=}" + ) + def default_fn_override( original_implementation: RecipeInterfaceType, ) -> RecipeInterfaceType: @@ -339,10 +309,11 @@ def default_api_override( ) -> APIInterfaceType: return original_implementation - if config.override is None: # type: ignore - raise TypeError( - f"Expected config.override to not be `None`. {recipe_id=} {config=}" - ) + if config.override is None: + if isinstance(config, BaseConfigWithoutAPIOverride): + config.override = BaseConfigWithoutAPIOverride() # type: ignore + else: + config.override = BaseOverrideConfig() # type: ignore function_overrides = getattr(config.override, "functions", default_fn_override) api_overrides = getattr(config.override, "apis", default_api_override) @@ -383,7 +354,7 @@ def fn_override( original_implementation = function_layer(original_implementation) return original_implementation - config.override.functions = fn_override + config.override.functions = fn_override # type: ignore if ( len(api_layers) > 0 @@ -398,7 +369,7 @@ def api_override(original_implementation: APIInterfaceType) -> APIInterfaceType: original_implementation = api_layer(original_implementation) return original_implementation - config.override.apis = api_override + config.override.apis = api_override # type: ignore return config diff --git a/supertokens_python/recipe/accountlinking/__init__.py b/supertokens_python/recipe/accountlinking/__init__.py index 16d7c9d91..aac4576a3 100644 --- a/supertokens_python/recipe/accountlinking/__init__.py +++ b/supertokens_python/recipe/accountlinking/__init__.py @@ -19,7 +19,7 @@ from . import types from .recipe import AccountLinkingRecipe -InputOverrideConfig = types.InputOverrideConfig +AccountLinkingOverrideConfig = types.AccountLinkingOverrideConfig RecipeLevelUser = types.RecipeLevelUser AccountInfoWithRecipeIdAndUserId = types.AccountInfoWithRecipeIdAndUserId ShouldAutomaticallyLink = types.ShouldAutomaticallyLink @@ -45,7 +45,7 @@ def init( Awaitable[Union[ShouldNotAutomaticallyLink, ShouldAutomaticallyLink]], ] ] = None, - override: Optional[InputOverrideConfig] = None, + override: Optional[AccountLinkingOverrideConfig] = None, ): return AccountLinkingRecipe.init( on_account_linked, should_do_automatic_account_linking, override diff --git a/supertokens_python/recipe/accountlinking/recipe.py b/supertokens_python/recipe/accountlinking/recipe.py index 2b8bd057a..06462739f 100644 --- a/supertokens_python/recipe/accountlinking/recipe.py +++ b/supertokens_python/recipe/accountlinking/recipe.py @@ -23,7 +23,6 @@ log_debug_message, ) from supertokens_python.normalised_url_path import NormalisedURLPath -from supertokens_python.plugins import OverrideMap, apply_plugins from supertokens_python.process_state import PROCESS_STATE, ProcessState from supertokens_python.querier import Querier from supertokens_python.recipe_module import APIHandled, RecipeModule @@ -35,8 +34,8 @@ from .types import ( AccountInfoWithRecipeId, AccountInfoWithRecipeIdAndUserId, - AccountLinkingInputConfig, - InputOverrideConfig, + AccountLinkingConfig, + AccountLinkingOverrideConfig, RecipeLevelUser, ShouldAutomaticallyLink, ShouldNotAutomaticallyLink, @@ -79,12 +78,10 @@ def __init__( self, recipe_id: str, app_info: AppInfo, - input_config: AccountLinkingInputConfig, + config: AccountLinkingConfig, ): super().__init__(recipe_id, app_info) - self.config = validate_and_normalise_user_input( - app_info, input_config=input_config - ) + self.config = validate_and_normalise_user_input(app_info, config=config) recipe_implementation: RecipeInterface = RecipeImplementation( Querier.get_instance(recipe_id), self, self.config ) @@ -147,9 +144,11 @@ def init( Awaitable[Union[ShouldNotAutomaticallyLink, ShouldAutomaticallyLink]], ] ] = None, - override: Optional[InputOverrideConfig] = None, - ): - input_config = AccountLinkingInputConfig( + override: Optional[AccountLinkingOverrideConfig] = None, + ) -> Callable[..., AccountLinkingRecipe]: + from supertokens_python.plugins import OverrideMap, apply_plugins + + cofnfig = AccountLinkingConfig( on_account_linked=on_account_linked, should_do_automatic_account_linking=should_do_automatic_account_linking, override=override, @@ -160,9 +159,9 @@ def func(app_info: AppInfo, plugins: List[OverrideMap]): AccountLinkingRecipe.__instance = AccountLinkingRecipe( recipe_id=AccountLinkingRecipe.recipe_id, app_info=app_info, - input_config=apply_plugins( + config=apply_plugins( recipe_id=AccountLinkingRecipe.recipe_id, - config=input_config, + config=cofnfig, plugins=plugins, ), ) diff --git a/supertokens_python/recipe/accountlinking/recipe_implementation.py b/supertokens_python/recipe/accountlinking/recipe_implementation.py index 2eb9a89e6..43991ab5c 100644 --- a/supertokens_python/recipe/accountlinking/recipe_implementation.py +++ b/supertokens_python/recipe/accountlinking/recipe_implementation.py @@ -40,7 +40,7 @@ RecipeInterface, UnlinkAccountOkResult, ) -from .types import AccountLinkingConfig, RecipeLevelUser +from .types import NormalisedAccountLinkingConfig, RecipeLevelUser if TYPE_CHECKING: from supertokens_python.querier import Querier @@ -53,7 +53,7 @@ def __init__( self, querier: Querier, recipe_instance: AccountLinkingRecipe, - config: AccountLinkingConfig, + config: NormalisedAccountLinkingConfig, ): super().__init__() self.querier = querier diff --git a/supertokens_python/recipe/accountlinking/types.py b/supertokens_python/recipe/accountlinking/types.py index 169430406..8de3e89db 100644 --- a/supertokens_python/recipe/accountlinking/types.py +++ b/supertokens_python/recipe/accountlinking/types.py @@ -24,11 +24,10 @@ from supertokens_python.types import AccountInfo from supertokens_python.types.config import ( BaseConfigWithoutAPIOverride, - BaseInputConfigWithoutAPIOverride, - BaseInputOverrideConfigWithoutAPI, - BaseOverrideConfigWithoutAPI, + BaseNormalisedConfigWithoutAPIOverride, + NormalisedOverrideConfigWithoutAPI, + OverrideConfigWithoutAPI, ) -from supertokens_python.types.utils import UseDefaultIfNone if TYPE_CHECKING: from supertokens_python.recipe.session import SessionContainer @@ -40,6 +39,11 @@ User, ) +AccountLinkingOverrideConfig = OverrideConfigWithoutAPI[RecipeInterface] +NormalisedAccountLinkingOverrideConfig = NormalisedOverrideConfigWithoutAPI[ + RecipeInterface +] + class AccountInfoWithRecipeId(AccountInfo): def __init__( @@ -138,13 +142,7 @@ def __init__(self, should_require_verification: bool): self.should_require_verification = should_require_verification -class InputOverrideConfig(BaseInputOverrideConfigWithoutAPI[RecipeInterface]): ... - - -class OverrideConfig(BaseOverrideConfigWithoutAPI[RecipeInterface]): ... - - -class AccountLinkingInputConfig(BaseInputConfigWithoutAPIOverride[RecipeInterface]): +class AccountLinkingConfig(BaseConfigWithoutAPIOverride[RecipeInterface]): on_account_linked: Optional[ Callable[[User, RecipeLevelUser, Dict[str, Any]], Awaitable[None]] ] = None @@ -160,10 +158,11 @@ class AccountLinkingInputConfig(BaseInputConfigWithoutAPIOverride[RecipeInterfac Awaitable[Union[ShouldNotAutomaticallyLink, ShouldAutomaticallyLink]], ] ] = None - override: UseDefaultIfNone[Optional[InputOverrideConfig]] = InputOverrideConfig() # type: ignore - https://github.com/microsoft/pyright/issues/5933 -class AccountLinkingConfig(BaseConfigWithoutAPIOverride[RecipeInterface]): +class NormalisedAccountLinkingConfig( + BaseNormalisedConfigWithoutAPIOverride[RecipeInterface] +): on_account_linked: Callable[ [User, RecipeLevelUser, Dict[str, Any]], Awaitable[None] ] @@ -177,4 +176,3 @@ class AccountLinkingConfig(BaseConfigWithoutAPIOverride[RecipeInterface]): ], Awaitable[Union[ShouldNotAutomaticallyLink, ShouldAutomaticallyLink]], ] - override: OverrideConfig # type: ignore - https://github.com/microsoft/pyright/issues/5933 diff --git a/supertokens_python/recipe/accountlinking/utils.py b/supertokens_python/recipe/accountlinking/utils.py index d008ad725..98f9beb2b 100644 --- a/supertokens_python/recipe/accountlinking/utils.py +++ b/supertokens_python/recipe/accountlinking/utils.py @@ -15,21 +15,20 @@ from typing import TYPE_CHECKING, Any, Dict, Optional, Union -from supertokens_python.recipe.accountlinking.types import AccountLinkingInputConfig - -if TYPE_CHECKING: - from .types import ( - AccountInfoWithRecipeIdAndUserId, - AccountLinkingConfig, - RecipeLevelUser, - SessionContainer, - ShouldAutomaticallyLink, - ShouldNotAutomaticallyLink, - User, - ) +from supertokens_python.recipe.accountlinking.types import ( + AccountInfoWithRecipeIdAndUserId, + AccountLinkingConfig, + NormalisedAccountLinkingConfig, + NormalisedAccountLinkingOverrideConfig, + RecipeLevelUser, + ShouldAutomaticallyLink, + ShouldNotAutomaticallyLink, +) if TYPE_CHECKING: + from supertokens_python.recipe.session.interfaces import SessionContainer from supertokens_python.supertokens import AppInfo + from supertokens_python.types.base import User async def default_on_account_linked(_: User, __: RecipeLevelUser, ___: Dict[str, Any]): @@ -59,34 +58,31 @@ def recipe_init_defined_should_do_automatic_account_linking() -> bool: def validate_and_normalise_user_input( _: AppInfo, - input_config: AccountLinkingInputConfig, -) -> AccountLinkingConfig: - from .types import AccountLinkingConfig, OverrideConfig - + config: AccountLinkingConfig, +) -> NormalisedAccountLinkingConfig: global _did_use_default_should_do_automatic_account_linking - override_config: OverrideConfig = OverrideConfig() + override_config: NormalisedAccountLinkingOverrideConfig = ( + NormalisedAccountLinkingOverrideConfig() + ) - if ( - input_config.override is not None - and input_config.override.functions is not None - ): - override_config.functions = input_config.override.functions + if config.override is not None and config.override.functions is not None: + override_config.functions = config.override.functions _did_use_default_should_do_automatic_account_linking = ( - input_config.should_do_automatic_account_linking is None + config.should_do_automatic_account_linking is None ) - return AccountLinkingConfig( + return NormalisedAccountLinkingConfig( override=override_config, on_account_linked=( default_on_account_linked - if input_config.on_account_linked is None - else input_config.on_account_linked + if config.on_account_linked is None + else config.on_account_linked ), should_do_automatic_account_linking=( default_should_do_automatic_account_linking - if input_config.should_do_automatic_account_linking is None - else input_config.should_do_automatic_account_linking + if config.should_do_automatic_account_linking is None + else config.should_do_automatic_account_linking ), ) diff --git a/supertokens_python/recipe/dashboard/__init__.py b/supertokens_python/recipe/dashboard/__init__.py index 0d2413522..7ccf61d00 100644 --- a/supertokens_python/recipe/dashboard/__init__.py +++ b/supertokens_python/recipe/dashboard/__init__.py @@ -21,13 +21,13 @@ from .recipe import DashboardRecipe -InputOverrideConfig = utils.InputOverrideConfig +DashboardOverrideConfig = utils.DashboardOverrideConfig def init( api_key: Optional[str] = None, admins: Optional[List[str]] = None, - override: Optional[InputOverrideConfig] = None, + override: Optional[DashboardOverrideConfig] = None, ) -> RecipeInit: return DashboardRecipe.init( api_key, diff --git a/supertokens_python/recipe/dashboard/interfaces.py b/supertokens_python/recipe/dashboard/interfaces.py index 6f5e15404..59c20ff88 100644 --- a/supertokens_python/recipe/dashboard/interfaces.py +++ b/supertokens_python/recipe/dashboard/interfaces.py @@ -28,7 +28,7 @@ from supertokens_python.recipe.session.interfaces import SessionInformationResult from ...supertokens import AppInfo - from .utils import DashboardConfig, UserWithMetadata + from .utils import NormalisedDashboardConfig, UserWithMetadata class SessionInfo: @@ -54,7 +54,7 @@ async def get_dashboard_bundle_location(self, user_context: Dict[str, Any]) -> s async def should_allow_access( self, request: BaseRequest, - config: DashboardConfig, + config: NormalisedDashboardConfig, user_context: Dict[str, Any], ) -> bool: pass @@ -66,14 +66,14 @@ def __init__( request: BaseRequest, response: BaseResponse, recipe_id: str, - config: DashboardConfig, + config: NormalisedDashboardConfig, recipe_implementation: RecipeInterface, app_info: AppInfo, ): self.request: BaseRequest = request self.response: BaseResponse = response self.recipe_id: str = recipe_id - self.config: DashboardConfig = config + self.config: NormalisedDashboardConfig = config self.recipe_implementation: RecipeInterface = recipe_implementation self.app_info = app_info diff --git a/supertokens_python/recipe/dashboard/recipe.py b/supertokens_python/recipe/dashboard/recipe.py index 0bb134b80..0de83281e 100644 --- a/supertokens_python/recipe/dashboard/recipe.py +++ b/supertokens_python/recipe/dashboard/recipe.py @@ -150,8 +150,8 @@ VALIDATE_KEY_API, ) from .utils import ( - DashboardInputConfig, - InputOverrideConfig, + DashboardConfig, + DashboardOverrideConfig, validate_and_normalise_user_input, ) @@ -164,11 +164,11 @@ def __init__( self, recipe_id: str, app_info: AppInfo, - input_config: DashboardInputConfig, + config: DashboardConfig, ): super().__init__(recipe_id, app_info) self.config = validate_and_normalise_user_input( - input_config=input_config, + config=config, ) recipe_implementation = RecipeImplementation() self.recipe_implementation = self.config.override.functions( @@ -642,9 +642,9 @@ def get_all_cors_headers(self) -> List[str]: def init( api_key: Optional[str], admins: Optional[List[str]] = None, - override: Optional[InputOverrideConfig] = None, + override: Optional[DashboardOverrideConfig] = None, ): - input_config = DashboardInputConfig( + config = DashboardConfig( api_key=api_key, admins=admins, override=override, @@ -655,9 +655,9 @@ def func(app_info: AppInfo, plugins: List[OverrideMap]): DashboardRecipe.__instance = DashboardRecipe( recipe_id=DashboardRecipe.recipe_id, app_info=app_info, - input_config=apply_plugins( + config=apply_plugins( recipe_id=DashboardRecipe.recipe_id, - config=input_config, + config=config, plugins=plugins, ), ) diff --git a/supertokens_python/recipe/dashboard/recipe_implementation.py b/supertokens_python/recipe/dashboard/recipe_implementation.py index 881c14ebd..e46615ebd 100644 --- a/supertokens_python/recipe/dashboard/recipe_implementation.py +++ b/supertokens_python/recipe/dashboard/recipe_implementation.py @@ -27,7 +27,7 @@ from .exceptions import DashboardOperationNotAllowedError from .interfaces import RecipeInterface -from .utils import DashboardConfig, validate_api_key +from .utils import NormalisedDashboardConfig, validate_api_key class RecipeImplementation(RecipeInterface): @@ -37,7 +37,7 @@ async def get_dashboard_bundle_location(self, user_context: Dict[str, Any]) -> s async def should_allow_access( self, request: BaseRequest, - config: DashboardConfig, + config: NormalisedDashboardConfig, user_context: Dict[str, Any], ) -> bool: # For cases where we're not using the API key, the JWT is being used; we allow their access by default diff --git a/supertokens_python/recipe/dashboard/utils.py b/supertokens_python/recipe/dashboard/utils.py index 552368631..33761851d 100644 --- a/supertokens_python/recipe/dashboard/utils.py +++ b/supertokens_python/recipe/dashboard/utils.py @@ -20,11 +20,10 @@ from supertokens_python.recipe.accountlinking.recipe import AccountLinkingRecipe from supertokens_python.types.config import ( BaseConfig, - BaseInputConfig, - BaseInputOverrideConfig, + BaseNormalisedConfig, + BaseNormalisedOverrideConfig, BaseOverrideConfig, ) -from supertokens_python.types.utils import UseDefaultIfNone if TYPE_CHECKING: from supertokens_python.framework.request import BaseRequest @@ -78,51 +77,51 @@ def to_json(self) -> Dict[str, Any]: return user_json -class InputOverrideConfig(BaseInputOverrideConfig[RecipeInterface, APIInterface]): ... +DashboardOverrideConfig = BaseOverrideConfig[RecipeInterface, APIInterface] +NormalisedDashboardOverrideConfig = BaseNormalisedOverrideConfig[ + RecipeInterface, APIInterface +] -class OverrideConfig(BaseOverrideConfig[RecipeInterface, APIInterface]): ... - - -class DashboardInputConfig(BaseInputConfig[RecipeInterface, APIInterface]): +class DashboardConfig(BaseConfig[RecipeInterface, APIInterface]): api_key: Optional[str] = None admins: Optional[List[str]] = None - override: UseDefaultIfNone[Optional[InputOverrideConfig]] = InputOverrideConfig() # type: ignore - https://github.com/microsoft/pyright/issues/5933 -class DashboardConfig(BaseConfig[RecipeInterface, APIInterface]): +class NormalisedDashboardConfig(BaseNormalisedConfig[RecipeInterface, APIInterface]): api_key: Optional[str] admins: Optional[List[str]] auth_mode: str - override: OverrideConfig # type: ignore - https://github.com/microsoft/pyright/issues/5933 def validate_and_normalise_user_input( - input_config: DashboardInputConfig, -) -> DashboardConfig: - override_config: OverrideConfig = OverrideConfig() + config: DashboardConfig, +) -> NormalisedDashboardConfig: + override_config: NormalisedDashboardOverrideConfig = ( + NormalisedDashboardOverrideConfig() + ) - if input_config.override is not None: - if input_config.override.functions is not None: - override_config.functions = input_config.override.functions + if config.override is not None: + if config.override.functions is not None: + override_config.functions = config.override.functions - if input_config.override.apis is not None: - override_config.apis = input_config.override.apis + if config.override.apis is not None: + override_config.apis = config.override.apis - if input_config.api_key is not None and input_config.admins is not None: + if config.api_key is not None and config.admins is not None: log_debug_message( "User Dashboard: Providing 'admins' has no effect when using an api key." ) admins = ( - [normalise_email(a) for a in input_config.admins] - if input_config.admins is not None + [normalise_email(a) for a in config.admins] + if config.admins is not None else None ) - auth_mode = "api-key" if input_config.api_key else "email-password" + auth_mode = "api-key" if config.api_key else "email-password" - return DashboardConfig( - api_key=input_config.api_key, + return NormalisedDashboardConfig( + api_key=config.api_key, admins=admins, auth_mode=auth_mode, override=override_config, @@ -257,7 +256,7 @@ async def _get_user_for_recipe_id( async def validate_api_key( - req: BaseRequest, config: DashboardConfig, _user_context: Dict[str, Any] + req: BaseRequest, config: NormalisedDashboardConfig, _user_context: Dict[str, Any] ) -> bool: api_key_header_value = req.get_header("authorization") if not api_key_header_value: diff --git a/supertokens_python/recipe/emailpassword/__init__.py b/supertokens_python/recipe/emailpassword/__init__.py index ce5e2b555..8df1a4691 100644 --- a/supertokens_python/recipe/emailpassword/__init__.py +++ b/supertokens_python/recipe/emailpassword/__init__.py @@ -25,7 +25,7 @@ from .recipe import EmailPasswordRecipe exceptions = ex -InputOverrideConfig = utils.InputOverrideConfig +EmailPasswordOverrideConfig = utils.EmailPasswordOverrideConfig InputSignUpFeature = utils.InputSignUpFeature InputFormField = utils.InputFormField SMTPService = emaildelivery_services.SMTPService @@ -37,7 +37,7 @@ def init( sign_up_feature: Union[utils.InputSignUpFeature, None] = None, - override: Union[utils.InputOverrideConfig, None] = None, + override: Union[utils.EmailPasswordOverrideConfig, None] = None, email_delivery: Union[EmailDeliveryConfig[EmailTemplateVars], None] = None, ) -> RecipeInit: return EmailPasswordRecipe.init( diff --git a/supertokens_python/recipe/emailpassword/interfaces.py b/supertokens_python/recipe/emailpassword/interfaces.py index fc4927826..b3f669d71 100644 --- a/supertokens_python/recipe/emailpassword/interfaces.py +++ b/supertokens_python/recipe/emailpassword/interfaces.py @@ -33,7 +33,7 @@ from ...types import User from .types import FormField - from .utils import EmailPasswordConfig + from .utils import NormalisedEmailPasswordConfig class SignUpOkResult: @@ -219,7 +219,7 @@ def __init__( request: BaseRequest, response: BaseResponse, recipe_id: str, - config: EmailPasswordConfig, + config: NormalisedEmailPasswordConfig, recipe_implementation: RecipeInterface, app_info: AppInfo, email_delivery: EmailDeliveryIngredient[EmailTemplateVars], @@ -227,7 +227,7 @@ def __init__( self.request: BaseRequest = request self.response: BaseResponse = response self.recipe_id: str = recipe_id - self.config: EmailPasswordConfig = config + self.config: NormalisedEmailPasswordConfig = config self.recipe_implementation: RecipeInterface = recipe_implementation self.app_info = app_info self.email_delivery = email_delivery diff --git a/supertokens_python/recipe/emailpassword/recipe.py b/supertokens_python/recipe/emailpassword/recipe.py index f31c69501..5c506986d 100644 --- a/supertokens_python/recipe/emailpassword/recipe.py +++ b/supertokens_python/recipe/emailpassword/recipe.py @@ -14,13 +14,12 @@ from __future__ import annotations from os import environ -from typing import TYPE_CHECKING, Any, Dict, List, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from supertokens_python.auth_utils import is_fake_email from supertokens_python.ingredients.emaildelivery import EmailDeliveryIngredient from supertokens_python.ingredients.emaildelivery.types import EmailDeliveryConfig from supertokens_python.normalised_url_path import NormalisedURLPath -from supertokens_python.plugins import OverrideMap, apply_plugins from supertokens_python.recipe.emailpassword.types import ( EmailPasswordIngredients, EmailTemplateVars, @@ -69,8 +68,8 @@ USER_PASSWORD_RESET_TOKEN, ) from .utils import ( - EmailPasswordInputConfig, - InputOverrideConfig, + EmailPasswordConfig, + EmailPasswordOverrideConfig, InputSignUpFeature, validate_and_normalise_user_input, ) @@ -86,12 +85,12 @@ def __init__( recipe_id: str, app_info: AppInfo, ingredients: EmailPasswordIngredients, - input_config: EmailPasswordInputConfig, + config: EmailPasswordConfig, ): super().__init__(recipe_id, app_info) self.config = validate_and_normalise_user_input( app_info, - input_config=input_config, + config=config, ) recipe_implementation = RecipeImplementation( @@ -362,11 +361,13 @@ def get_all_cors_headers(self) -> List[str]: @staticmethod def init( - sign_up_feature: Union[InputSignUpFeature, None] = None, - override: Union[InputOverrideConfig, None] = None, - email_delivery: Union[EmailDeliveryConfig[EmailTemplateVars], None] = None, + sign_up_feature: Optional[InputSignUpFeature] = None, + override: Optional[EmailPasswordOverrideConfig] = None, + email_delivery: Optional[EmailDeliveryConfig[EmailTemplateVars]] = None, ): - input_config = EmailPasswordInputConfig( + from supertokens_python.plugins import OverrideMap, apply_plugins + + config = EmailPasswordConfig( sign_up_feature=sign_up_feature, email_delivery=email_delivery, override=override, @@ -379,9 +380,9 @@ def func(app_info: AppInfo, plugins: List[OverrideMap]): recipe_id=EmailPasswordRecipe.recipe_id, app_info=app_info, ingredients=ingredients, - input_config=apply_plugins( + config=apply_plugins( recipe_id=EmailPasswordRecipe.recipe_id, - config=input_config, + config=config, plugins=plugins, ), ) diff --git a/supertokens_python/recipe/emailpassword/recipe_implementation.py b/supertokens_python/recipe/emailpassword/recipe_implementation.py index 453506ef1..d21fb9eed 100644 --- a/supertokens_python/recipe/emailpassword/recipe_implementation.py +++ b/supertokens_python/recipe/emailpassword/recipe_implementation.py @@ -42,7 +42,7 @@ UpdateEmailOrPasswordOkResult, WrongCredentialsError, ) -from .utils import EmailPasswordConfig +from .utils import NormalisedEmailPasswordConfig if TYPE_CHECKING: from supertokens_python.querier import Querier @@ -52,7 +52,7 @@ class RecipeImplementation(RecipeInterface): def __init__( self, querier: Querier, - ep_config: EmailPasswordConfig, + ep_config: NormalisedEmailPasswordConfig, ): super().__init__() self.querier = querier diff --git a/supertokens_python/recipe/emailpassword/utils.py b/supertokens_python/recipe/emailpassword/utils.py index dd19da841..36953b309 100644 --- a/supertokens_python/recipe/emailpassword/utils.py +++ b/supertokens_python/recipe/emailpassword/utils.py @@ -26,22 +26,19 @@ ) from supertokens_python.types.config import ( BaseConfig, - BaseInputConfig, - BaseInputOverrideConfig, + BaseNormalisedConfig, + BaseNormalisedOverrideConfig, BaseOverrideConfig, ) -from supertokens_python.types.utils import UseDefaultIfNone +from supertokens_python.utils import get_filtered_list +from .constants import FORM_FIELD_EMAIL_ID, FORM_FIELD_PASSWORD_ID from .interfaces import APIInterface, RecipeInterface from .types import EmailTemplateVars, InputFormField, NormalisedFormField if TYPE_CHECKING: from supertokens_python.supertokens import AppInfo -from supertokens_python.utils import get_filtered_list - -from .constants import FORM_FIELD_EMAIL_ID, FORM_FIELD_PASSWORD_ID - async def default_validator(_: str, __: str) -> Union[str, None]: return None @@ -220,54 +217,54 @@ def validate_and_normalise_reset_password_using_token_config( ) -class InputOverrideConfig(BaseInputOverrideConfig[RecipeInterface, APIInterface]): ... - - -class OverrideConfig(BaseOverrideConfig[RecipeInterface, APIInterface]): ... +EmailPasswordOverrideConfig = BaseOverrideConfig[RecipeInterface, APIInterface] +NormalisedEmailPasswordOverrideConfig = BaseNormalisedOverrideConfig[ + RecipeInterface, APIInterface +] -class EmailPasswordInputConfig(BaseInputConfig[RecipeInterface, APIInterface]): +class EmailPasswordConfig(BaseConfig[RecipeInterface, APIInterface]): sign_up_feature: Union[InputSignUpFeature, None] = None email_delivery: Union[EmailDeliveryConfig[EmailTemplateVars], None] = None - override: UseDefaultIfNone[Optional[InputOverrideConfig]] = InputOverrideConfig() # type: ignore - https://github.com/microsoft/pyright/issues/5933 -class EmailPasswordConfig(BaseConfig[RecipeInterface, APIInterface]): +class NormalisedEmailPasswordConfig( + BaseNormalisedConfig[RecipeInterface, APIInterface] +): sign_up_feature: SignUpFeature sign_in_feature: SignInFeature reset_password_using_token_feature: ResetPasswordUsingTokenFeature get_email_delivery_config: Callable[ [RecipeInterface], EmailDeliveryConfigWithService[EmailTemplateVars] ] - override: OverrideConfig # type: ignore - https://github.com/microsoft/pyright/issues/5933 def validate_and_normalise_user_input( app_info: AppInfo, - input_config: EmailPasswordInputConfig, -) -> EmailPasswordConfig: + config: EmailPasswordConfig, +) -> NormalisedEmailPasswordConfig: # NOTE: We don't need to check the instance of sign_up_feature and override # as they will always be either None or the specified type. - override_config = OverrideConfig() - if input_config.override is not None: - if input_config.override.functions is not None: - override_config.functions = input_config.override.functions + override_config = NormalisedEmailPasswordOverrideConfig() + if config.override is not None: + if config.override.functions is not None: + override_config.functions = config.override.functions - if input_config.override.apis is not None: - override_config.apis = input_config.override.apis + if config.override.apis is not None: + override_config.apis = config.override.apis - sign_up_feature = input_config.sign_up_feature + sign_up_feature = config.sign_up_feature if sign_up_feature is None: sign_up_feature = InputSignUpFeature() def get_email_delivery_config( ep_recipe: RecipeInterface, ) -> EmailDeliveryConfigWithService[EmailTemplateVars]: - if input_config.email_delivery and input_config.email_delivery.service: + if config.email_delivery and config.email_delivery.service: return EmailDeliveryConfigWithService( - service=input_config.email_delivery.service, - override=input_config.email_delivery.override, + service=config.email_delivery.service, + override=config.email_delivery.override, ) email_service = BackwardCompatibilityService( @@ -275,15 +272,15 @@ def get_email_delivery_config( recipe_interface_impl=ep_recipe, ) if ( - input_config.email_delivery is not None - and input_config.email_delivery.override is not None + config.email_delivery is not None + and config.email_delivery.override is not None ): - override = input_config.email_delivery.override + override = config.email_delivery.override else: override = None return EmailDeliveryConfigWithService(email_service, override=override) - return EmailPasswordConfig( + return NormalisedEmailPasswordConfig( sign_up_feature=SignUpFeature(sign_up_feature.form_fields), sign_in_feature=SignInFeature( normalise_sign_in_form_fields(sign_up_feature.form_fields) diff --git a/supertokens_python/recipe/emailverification/__init__.py b/supertokens_python/recipe/emailverification/__init__.py index b88bac6cd..848a16a16 100644 --- a/supertokens_python/recipe/emailverification/__init__.py +++ b/supertokens_python/recipe/emailverification/__init__.py @@ -22,9 +22,9 @@ from .interfaces import TypeGetEmailForUserIdFunction from .recipe import EmailVerificationRecipe from .types import EmailTemplateVars -from .utils import MODE_TYPE, OverrideConfig +from .utils import MODE_TYPE, EmailVerificationOverrideConfig -InputOverrideConfig = utils.OverrideConfig +InputOverrideConfig = utils.EmailVerificationOverrideConfig exception = ex SMTPService = emaildelivery_services.SMTPService EmailVerificationClaim = recipe.EmailVerificationClaim @@ -39,7 +39,7 @@ def init( mode: MODE_TYPE, email_delivery: Union[EmailDeliveryConfig[EmailTemplateVars], None] = None, get_email_for_recipe_user_id: Optional[TypeGetEmailForUserIdFunction] = None, - override: Union[OverrideConfig, None] = None, + override: Union[EmailVerificationOverrideConfig, None] = None, ) -> RecipeInit: return EmailVerificationRecipe.init( mode, diff --git a/supertokens_python/recipe/emailverification/interfaces.py b/supertokens_python/recipe/emailverification/interfaces.py index 8c00c38fd..88f1f4522 100644 --- a/supertokens_python/recipe/emailverification/interfaces.py +++ b/supertokens_python/recipe/emailverification/interfaces.py @@ -30,7 +30,7 @@ from supertokens_python.framework import BaseRequest, BaseResponse from .types import EmailVerificationUser, VerificationEmailTemplateVars - from .utils import EmailVerificationConfig + from .utils import NormalisedEmailVerificationConfig class CreateEmailVerificationTokenOkResult: @@ -141,7 +141,7 @@ def __init__( request: BaseRequest, response: BaseResponse, recipe_id: str, - config: EmailVerificationConfig, + config: NormalisedEmailVerificationConfig, recipe_implementation: RecipeInterface, app_info: AppInfo, email_delivery: EmailDeliveryIngredient[VerificationEmailTemplateVars], diff --git a/supertokens_python/recipe/emailverification/recipe.py b/supertokens_python/recipe/emailverification/recipe.py index 6dea1f3b5..2db012b64 100644 --- a/supertokens_python/recipe/emailverification/recipe.py +++ b/supertokens_python/recipe/emailverification/recipe.py @@ -19,7 +19,6 @@ from supertokens_python.exceptions import SuperTokensError, raise_general_exception from supertokens_python.ingredients.emaildelivery import EmailDeliveryIngredient from supertokens_python.normalised_url_path import NormalisedURLPath -from supertokens_python.plugins import OverrideMap, apply_plugins from supertokens_python.querier import Querier from supertokens_python.recipe.emailverification.exceptions import ( EmailVerificationInvalidTokenError, @@ -71,8 +70,8 @@ from .recipe_implementation import RecipeImplementation from .utils import ( MODE_TYPE, - EmailVerificationInputConfig, - InputOverrideConfig, + EmailVerificationConfig, + EmailVerificationOverrideConfig, validate_and_normalise_user_input, ) @@ -95,12 +94,12 @@ def __init__( recipe_id: str, app_info: AppInfo, ingredients: EmailVerificationIngredients, - input_config: EmailVerificationInputConfig, + config: EmailVerificationConfig, ) -> None: super().__init__(recipe_id, app_info) self.config = validate_and_normalise_user_input( - app_info, - input_config=input_config, + app_info=app_info, + config=config, ) recipe_implementation = RecipeImplementation( @@ -199,9 +198,11 @@ def init( mode: MODE_TYPE, email_delivery: Union[EmailDeliveryConfig[EmailTemplateVars], None] = None, get_email_for_recipe_user_id: Optional[TypeGetEmailForUserIdFunction] = None, - override: Union[InputOverrideConfig, None] = None, + override: Optional[EmailVerificationOverrideConfig] = None, ): - input_config = EmailVerificationInputConfig( + from supertokens_python.plugins import OverrideMap, apply_plugins + + config = EmailVerificationConfig( mode=mode, email_delivery=email_delivery, get_email_for_recipe_user_id=get_email_for_recipe_user_id, @@ -217,9 +218,9 @@ def func( EmailVerificationRecipe.recipe_id, app_info, ingredients, - input_config=apply_plugins( + config=apply_plugins( recipe_id=EmailVerificationRecipe.recipe_id, - config=input_config, + config=config, plugins=plugins, ), ) diff --git a/supertokens_python/recipe/emailverification/utils.py b/supertokens_python/recipe/emailverification/utils.py index 45b2720cd..f032fe6ca 100644 --- a/supertokens_python/recipe/emailverification/utils.py +++ b/supertokens_python/recipe/emailverification/utils.py @@ -28,11 +28,10 @@ ) from supertokens_python.types.config import ( BaseConfig, - BaseInputConfig, - BaseInputOverrideConfig, + BaseNormalisedConfig, + BaseNormalisedOverrideConfig, BaseOverrideConfig, ) -from supertokens_python.types.utils import UseDefaultIfNone from .interfaces import APIInterface, RecipeInterface, TypeGetEmailForUserIdFunction @@ -43,37 +42,35 @@ from .types import EmailTemplateVars, VerificationEmailTemplateVars - -class InputOverrideConfig(BaseInputOverrideConfig[RecipeInterface, APIInterface]): ... - - -class OverrideConfig(BaseOverrideConfig[RecipeInterface, APIInterface]): ... - - MODE_TYPE = Literal["REQUIRED", "OPTIONAL"] +EmailVerificationOverrideConfig = BaseOverrideConfig[RecipeInterface, APIInterface] +NormalisedEmailVerificationOverrideConfig = BaseNormalisedOverrideConfig[ + RecipeInterface, APIInterface +] + -class EmailVerificationInputConfig(BaseInputConfig[RecipeInterface, APIInterface]): +class EmailVerificationConfig(BaseConfig[RecipeInterface, APIInterface]): mode: MODE_TYPE email_delivery: Union[EmailDeliveryConfig[EmailTemplateVars], None] = None get_email_for_recipe_user_id: Optional[TypeGetEmailForUserIdFunction] = None - override: UseDefaultIfNone[Optional[InputOverrideConfig]] = InputOverrideConfig() # type: ignore - https://github.com/microsoft/pyright/issues/5933 -class EmailVerificationConfig(BaseConfig[RecipeInterface, APIInterface]): +class NormalisedEmailVerificationConfig( + BaseNormalisedConfig[RecipeInterface, APIInterface] +): mode: MODE_TYPE get_email_delivery_config: Callable[ [], EmailDeliveryConfigWithService[VerificationEmailTemplateVars] ] get_email_for_recipe_user_id: Optional[TypeGetEmailForUserIdFunction] - override: OverrideConfig # type: ignore - https://github.com/microsoft/pyright/issues/5933 def validate_and_normalise_user_input( app_info: AppInfo, - input_config: EmailVerificationInputConfig, -) -> EmailVerificationConfig: - if input_config.mode not in ["REQUIRED", "OPTIONAL"]: + config: EmailVerificationConfig, +) -> NormalisedEmailVerificationConfig: + if config.mode not in ["REQUIRED", "OPTIONAL"]: raise ValueError( "Email Verification recipe mode must be one of 'REQUIRED' or 'OPTIONAL'" ) @@ -82,34 +79,32 @@ def get_email_delivery_config() -> EmailDeliveryConfigWithService[ VerificationEmailTemplateVars ]: email_service = ( - input_config.email_delivery.service - if input_config.email_delivery is not None - else None + config.email_delivery.service if config.email_delivery is not None else None ) if email_service is None: email_service = BackwardCompatibilityService(app_info) if ( - input_config.email_delivery is not None - and input_config.email_delivery.override is not None + config.email_delivery is not None + and config.email_delivery.override is not None ): - override = input_config.email_delivery.override + override = config.email_delivery.override else: override = None return EmailDeliveryConfigWithService(email_service, override=override) - override_config = OverrideConfig() - if input_config.override is not None: - if input_config.override.functions is not None: - override_config.functions = input_config.override.functions + override_config = NormalisedEmailVerificationOverrideConfig() + if config.override is not None: + if config.override.functions is not None: + override_config.functions = config.override.functions - if input_config.override.apis is not None: - override_config.apis = input_config.override.apis + if config.override.apis is not None: + override_config.apis = config.override.apis - return EmailVerificationConfig( - mode=input_config.mode, + return NormalisedEmailVerificationConfig( + mode=config.mode, get_email_delivery_config=get_email_delivery_config, - get_email_for_recipe_user_id=input_config.get_email_for_recipe_user_id, + get_email_for_recipe_user_id=config.get_email_for_recipe_user_id, override=override_config, ) diff --git a/supertokens_python/recipe/jwt/__init__.py b/supertokens_python/recipe/jwt/__init__.py index 7d0f1b3f3..baf2bfd54 100644 --- a/supertokens_python/recipe/jwt/__init__.py +++ b/supertokens_python/recipe/jwt/__init__.py @@ -16,7 +16,7 @@ from typing import TYPE_CHECKING, Union from .recipe import JWTRecipe -from .utils import OverrideConfig +from .utils import JWTOverrideConfig if TYPE_CHECKING: from supertokens_python.supertokens import RecipeInit @@ -24,6 +24,6 @@ def init( jwt_validity_seconds: Union[int, None] = None, - override: Union[OverrideConfig, None] = None, + override: Union[JWTOverrideConfig, None] = None, ) -> RecipeInit: return JWTRecipe.init(jwt_validity_seconds, override) diff --git a/supertokens_python/recipe/jwt/interfaces.py b/supertokens_python/recipe/jwt/interfaces.py index 5860c09cf..fbedd8ce6 100644 --- a/supertokens_python/recipe/jwt/interfaces.py +++ b/supertokens_python/recipe/jwt/interfaces.py @@ -19,7 +19,7 @@ from supertokens_python.types.response import APIResponse, GeneralErrorResponse if TYPE_CHECKING: - from .utils import JWTConfig + from .utils import NormalisedJWTConfig class JsonWebKey: @@ -72,7 +72,7 @@ def __init__( request: BaseRequest, response: BaseResponse, recipe_id: str, - config: "JWTConfig", + config: "NormalisedJWTConfig", recipe_implementation: RecipeInterface, ): self.request = request diff --git a/supertokens_python/recipe/jwt/recipe.py b/supertokens_python/recipe/jwt/recipe.py index 6135535c1..5cfe20ccd 100644 --- a/supertokens_python/recipe/jwt/recipe.py +++ b/supertokens_python/recipe/jwt/recipe.py @@ -16,7 +16,6 @@ from os import environ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union -from supertokens_python.plugins import OverrideMap, apply_plugins from supertokens_python.querier import Querier from supertokens_python.recipe.jwt.api.implementation import APIImplementation from supertokens_python.recipe.jwt.api.jwks_get import jwks_get @@ -25,8 +24,8 @@ from supertokens_python.recipe.jwt.interfaces import APIOptions from supertokens_python.recipe.jwt.recipe_implementation import RecipeImplementation from supertokens_python.recipe.jwt.utils import ( - InputOverrideConfig, - JWTInputConfig, + JWTConfig, + JWTOverrideConfig, validate_and_normalise_user_input, ) @@ -48,10 +47,10 @@ def __init__( self, recipe_id: str, app_info: AppInfo, - input_config: JWTInputConfig, + config: JWTConfig, ): super().__init__(recipe_id, app_info) - self.config = validate_and_normalise_user_input(input_config=input_config) + self.config = validate_and_normalise_user_input(config=config) recipe_implementation = RecipeImplementation( Querier.get_instance(recipe_id), self.config, app_info @@ -115,9 +114,11 @@ def is_error_from_this_recipe_based_on_instance(self, err: Exception) -> bool: @staticmethod def init( jwt_validity_seconds: Union[int, None] = None, - override: Union[InputOverrideConfig, None] = None, + override: Union[JWTOverrideConfig, None] = None, ): - input_config = JWTInputConfig( + from supertokens_python.plugins import OverrideMap, apply_plugins + + config = JWTConfig( jwt_validity_seconds=jwt_validity_seconds, override=override, ) @@ -127,9 +128,9 @@ def func(app_info: AppInfo, plugins: List[OverrideMap]): JWTRecipe.__instance = JWTRecipe( JWTRecipe.recipe_id, app_info, - input_config=apply_plugins( + config=apply_plugins( recipe_id=JWTRecipe.recipe_id, - config=input_config, + config=config, plugins=plugins, ), ) diff --git a/supertokens_python/recipe/jwt/recipe_implementation.py b/supertokens_python/recipe/jwt/recipe_implementation.py index 0e7a3f464..9c308a4d1 100644 --- a/supertokens_python/recipe/jwt/recipe_implementation.py +++ b/supertokens_python/recipe/jwt/recipe_implementation.py @@ -22,7 +22,7 @@ if TYPE_CHECKING: from supertokens_python.supertokens import AppInfo - from .utils import JWTConfig + from .utils import NormalisedJWTConfig from supertokens_python.recipe.jwt.interfaces import ( CreateJwtOkResult, @@ -38,7 +38,9 @@ class RecipeImplementation(RecipeInterface): - def __init__(self, querier: Querier, config: JWTConfig, app_info: AppInfo): + def __init__( + self, querier: Querier, config: NormalisedJWTConfig, app_info: AppInfo + ): super().__init__() self.querier = querier self.config = config diff --git a/supertokens_python/recipe/jwt/utils.py b/supertokens_python/recipe/jwt/utils.py index 6c3474663..37d36b8ff 100644 --- a/supertokens_python/recipe/jwt/utils.py +++ b/supertokens_python/recipe/jwt/utils.py @@ -17,48 +17,44 @@ from supertokens_python.types.config import ( BaseConfig, - BaseInputConfig, - BaseInputOverrideConfig, + BaseNormalisedConfig, + BaseNormalisedOverrideConfig, BaseOverrideConfig, ) -from supertokens_python.types.utils import UseDefaultIfNone from .interfaces import APIInterface, RecipeInterface +JWTOverrideConfig = BaseOverrideConfig[RecipeInterface, APIInterface] +NormalisedJWTOverrideConfig = BaseNormalisedOverrideConfig[ + RecipeInterface, APIInterface +] -class InputOverrideConfig(BaseInputOverrideConfig[RecipeInterface, APIInterface]): ... - -class OverrideConfig(BaseOverrideConfig[RecipeInterface, APIInterface]): ... - - -class JWTInputConfig(BaseInputConfig[RecipeInterface, APIInterface]): +class JWTConfig(BaseConfig[RecipeInterface, APIInterface]): jwt_validity_seconds: Optional[int] = None - override: UseDefaultIfNone[Optional[InputOverrideConfig]] = InputOverrideConfig() # type: ignore - https://github.com/microsoft/pyright/issues/5933 -class JWTConfig(BaseConfig[RecipeInterface, APIInterface]): +class NormalisedJWTConfig(BaseNormalisedConfig[RecipeInterface, APIInterface]): jwt_validity_seconds: int - override: OverrideConfig # type: ignore - https://github.com/microsoft/pyright/issues/5933 -def validate_and_normalise_user_input(input_config: JWTInputConfig): - override_config = OverrideConfig() - if input_config.override is not None: - if input_config.override.functions is not None: - override_config.functions = input_config.override.functions +def validate_and_normalise_user_input(config: JWTConfig): + override_config = NormalisedJWTOverrideConfig() + if config.override is not None: + if config.override.functions is not None: + override_config.functions = config.override.functions - if input_config.override.apis is not None: - override_config.apis = input_config.override.apis + if config.override.apis is not None: + override_config.apis = config.override.apis - jwt_validity_seconds = input_config.jwt_validity_seconds + jwt_validity_seconds = config.jwt_validity_seconds - if input_config.jwt_validity_seconds is None: + if config.jwt_validity_seconds is None: jwt_validity_seconds = 3153600000 if not isinstance(jwt_validity_seconds, int): # type: ignore raise ValueError("jwt_validity_seconds must be an integer or None") - return JWTConfig( + return NormalisedJWTConfig( jwt_validity_seconds=jwt_validity_seconds, override=override_config ) diff --git a/supertokens_python/recipe/multifactorauth/__init__.py b/supertokens_python/recipe/multifactorauth/__init__.py index 074bc3bc9..861a4cdb8 100644 --- a/supertokens_python/recipe/multifactorauth/__init__.py +++ b/supertokens_python/recipe/multifactorauth/__init__.py @@ -15,7 +15,9 @@ from typing import TYPE_CHECKING, List, Optional, Union -from supertokens_python.recipe.multifactorauth.types import OverrideConfig +from supertokens_python.recipe.multifactorauth.types import ( + MultiFactorAuthOverrideConfig, +) from .recipe import MultiFactorAuthRecipe @@ -25,7 +27,7 @@ def init( first_factors: Optional[List[str]] = None, - override: Union[OverrideConfig, None] = None, + override: Union[MultiFactorAuthOverrideConfig, None] = None, ) -> RecipeInit: return MultiFactorAuthRecipe.init( first_factors, diff --git a/supertokens_python/recipe/multifactorauth/interfaces.py b/supertokens_python/recipe/multifactorauth/interfaces.py index 03db15635..42dc1a4f8 100644 --- a/supertokens_python/recipe/multifactorauth/interfaces.py +++ b/supertokens_python/recipe/multifactorauth/interfaces.py @@ -28,7 +28,7 @@ from supertokens_python.types import User from ...supertokens import AppInfo - from .types import MFARequirementList, MultiFactorAuthConfig + from .types import MFARequirementList, NormalisedMultiFactorAuthConfig class RecipeInterface(BaseRecipeInterface): @@ -97,7 +97,7 @@ def __init__( request: BaseRequest, response: BaseResponse, recipe_id: str, - config: MultiFactorAuthConfig, + config: NormalisedMultiFactorAuthConfig, recipe_implementation: RecipeInterface, app_info: AppInfo, recipe_instance: MultiFactorAuthRecipe, diff --git a/supertokens_python/recipe/multifactorauth/recipe.py b/supertokens_python/recipe/multifactorauth/recipe.py index 0dc7e0a54..be28497ef 100644 --- a/supertokens_python/recipe/multifactorauth/recipe.py +++ b/supertokens_python/recipe/multifactorauth/recipe.py @@ -19,7 +19,6 @@ from supertokens_python.exceptions import SuperTokensError, raise_general_exception from supertokens_python.framework import BaseRequest, BaseResponse from supertokens_python.normalised_url_path import NormalisedURLPath -from supertokens_python.plugins import OverrideMap, apply_plugins from supertokens_python.post_init_callbacks import PostSTInitCallbacks from supertokens_python.querier import Querier from supertokens_python.recipe.multifactorauth.api import ( @@ -49,8 +48,8 @@ GetPhoneNumbersForFactorsFromOtherRecipesFunc, GetPhoneNumbersForFactorsOkResult, GetPhoneNumbersForFactorsUnknownSessionRecipeUserIdResult, - InputOverrideConfig, - MultiFactorAuthInputConfig, + MultiFactorAuthConfig, + MultiFactorAuthOverrideConfig, ) from .utils import validate_and_normalise_user_input @@ -63,7 +62,7 @@ def __init__( self, recipe_id: str, app_info: AppInfo, - input_config: MultiFactorAuthInputConfig, + config: MultiFactorAuthConfig, ): super().__init__(recipe_id, app_info) self.get_factors_setup_for_user_from_other_recipes_funcs: List[ @@ -81,7 +80,7 @@ def __init__( self.is_get_mfa_requirements_for_auth_overridden: bool = False self.config = validate_and_normalise_user_input( - input_config=input_config, + config=config, ) recipe_implementation = RecipeImplementation( @@ -159,9 +158,11 @@ def get_all_cors_headers(self) -> List[str]: @staticmethod def init( first_factors: Optional[List[str]] = None, - override: Union[InputOverrideConfig, None] = None, + override: Union[MultiFactorAuthOverrideConfig, None] = None, ): - input_config = MultiFactorAuthInputConfig( + from supertokens_python.plugins import OverrideMap, apply_plugins + + config = MultiFactorAuthConfig( first_factors=first_factors, override=override, ) @@ -171,9 +172,9 @@ def func(app_info: AppInfo, plugins: List[OverrideMap]): MultiFactorAuthRecipe.__instance = MultiFactorAuthRecipe( MultiFactorAuthRecipe.recipe_id, app_info, - input_config=apply_plugins( + config=apply_plugins( recipe_id=MultiFactorAuthRecipe.recipe_id, - config=input_config, + config=config, plugins=plugins, ), ) diff --git a/supertokens_python/recipe/multifactorauth/types.py b/supertokens_python/recipe/multifactorauth/types.py index 323337e20..ab2572530 100644 --- a/supertokens_python/recipe/multifactorauth/types.py +++ b/supertokens_python/recipe/multifactorauth/types.py @@ -21,11 +21,10 @@ from supertokens_python.types import RecipeUserId, User from supertokens_python.types.config import ( BaseConfig, - BaseInputConfig, - BaseInputOverrideConfig, + BaseNormalisedConfig, + BaseNormalisedOverrideConfig, BaseOverrideConfig, ) -from supertokens_python.types.utils import UseDefaultIfNone from .interfaces import APIInterface, RecipeInterface @@ -43,20 +42,20 @@ def __init__(self, c: Dict[str, Any], v: bool): self.v = v -class InputOverrideConfig(BaseInputOverrideConfig[RecipeInterface, APIInterface]): ... - - -class OverrideConfig(BaseOverrideConfig[RecipeInterface, APIInterface]): ... +MultiFactorAuthOverrideConfig = BaseOverrideConfig[RecipeInterface, APIInterface] +NormalisedMultiFactorAuthOverrideConfig = BaseNormalisedOverrideConfig[ + RecipeInterface, APIInterface +] -class MultiFactorAuthInputConfig(BaseInputConfig[RecipeInterface, APIInterface]): +class MultiFactorAuthConfig(BaseConfig[RecipeInterface, APIInterface]): first_factors: Optional[List[str]] = None - override: UseDefaultIfNone[Optional[InputOverrideConfig]] = InputOverrideConfig() # type: ignore - https://github.com/microsoft/pyright/issues/5933 -class MultiFactorAuthConfig(BaseConfig[RecipeInterface, APIInterface]): +class NormalisedMultiFactorAuthConfig( + BaseNormalisedConfig[RecipeInterface, APIInterface] +): first_factors: Optional[List[str]] - override: OverrideConfig # type: ignore - https://github.com/microsoft/pyright/issues/5933 class FactorIds: diff --git a/supertokens_python/recipe/multifactorauth/utils.py b/supertokens_python/recipe/multifactorauth/utils.py index 7a0037767..f7b328004 100644 --- a/supertokens_python/recipe/multifactorauth/utils.py +++ b/supertokens_python/recipe/multifactorauth/utils.py @@ -38,29 +38,29 @@ from .types import ( MultiFactorAuthConfig, - MultiFactorAuthInputConfig, - OverrideConfig, + NormalisedMultiFactorAuthConfig, + NormalisedMultiFactorAuthOverrideConfig, ) # IMPORTANT: If this function signature is modified, please update all tha places where this function is called. # There will be no type errors cause we use importLib to dynamically import if to prevent cyclic import issues. def validate_and_normalise_user_input( - input_config: MultiFactorAuthInputConfig, -) -> MultiFactorAuthConfig: - if input_config.first_factors is not None and len(input_config.first_factors) == 0: + config: MultiFactorAuthConfig, +) -> NormalisedMultiFactorAuthConfig: + if config.first_factors is not None and len(config.first_factors) == 0: raise ValueError("'first_factors' can be either None or a non-empty list") - override_config = OverrideConfig() - if input_config.override is not None: - if input_config.override.functions is not None: - override_config.functions = input_config.override.functions + override_config = NormalisedMultiFactorAuthOverrideConfig() + if config.override is not None: + if config.override.functions is not None: + override_config.functions = config.override.functions - if input_config.override.apis is not None: - override_config.apis = input_config.override.apis + if config.override.apis is not None: + override_config.apis = config.override.apis - return MultiFactorAuthConfig( - first_factors=input_config.first_factors, + return NormalisedMultiFactorAuthConfig( + first_factors=config.first_factors, override=override_config, ) diff --git a/supertokens_python/recipe/multitenancy/__init__.py b/supertokens_python/recipe/multitenancy/__init__.py index 41dc96964..e591c794c 100644 --- a/supertokens_python/recipe/multitenancy/__init__.py +++ b/supertokens_python/recipe/multitenancy/__init__.py @@ -17,6 +17,8 @@ from . import exceptions as ex from . import recipe +from .interfaces import TypeGetAllowedDomainsForTenantId +from .utils import MultitenancyOverrideConfig AllowedDomainsClaim = recipe.AllowedDomainsClaim exceptions = ex @@ -24,15 +26,12 @@ if TYPE_CHECKING: from supertokens_python.supertokens import RecipeInit - from .interfaces import TypeGetAllowedDomainsForTenantId - from .utils import InputOverrideConfig - def init( get_allowed_domains_for_tenant_id: Union[ TypeGetAllowedDomainsForTenantId, None ] = None, - override: Union[InputOverrideConfig, None] = None, + override: Union[MultitenancyOverrideConfig, None] = None, ) -> RecipeInit: return recipe.MultitenancyRecipe.init( get_allowed_domains_for_tenant_id, diff --git a/supertokens_python/recipe/multitenancy/interfaces.py b/supertokens_python/recipe/multitenancy/interfaces.py index a2d6d25c5..0af997881 100644 --- a/supertokens_python/recipe/multitenancy/interfaces.py +++ b/supertokens_python/recipe/multitenancy/interfaces.py @@ -27,7 +27,7 @@ ProviderInput, ) - from .utils import MultitenancyConfig + from .utils import NormalisedMultitenancyConfig class TenantConfig: @@ -283,7 +283,7 @@ def __init__( request: BaseRequest, response: BaseResponse, recipe_id: str, - config: MultitenancyConfig, + config: NormalisedMultitenancyConfig, recipe_implementation: RecipeInterface, static_third_party_providers: List[ProviderInput], all_available_first_factors: List[str], diff --git a/supertokens_python/recipe/multitenancy/recipe.py b/supertokens_python/recipe/multitenancy/recipe.py index 51d7874e8..e73b32ad3 100644 --- a/supertokens_python/recipe/multitenancy/recipe.py +++ b/supertokens_python/recipe/multitenancy/recipe.py @@ -17,7 +17,6 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from supertokens_python.exceptions import SuperTokensError, raise_general_exception -from supertokens_python.plugins import OverrideMap, apply_plugins from supertokens_python.recipe.session.claim_base_classes.primitive_array_claim import ( PrimitiveArrayClaim, ) @@ -45,8 +44,8 @@ from .constants import LOGIN_METHODS from .exceptions import MultitenancyError from .utils import ( - InputOverrideConfig, - MultitenancyInputConfig, + MultitenancyConfig, + MultitenancyOverrideConfig, validate_and_normalise_user_input, ) @@ -56,10 +55,10 @@ class MultitenancyRecipe(RecipeModule): __instance = None def __init__( - self, recipe_id: str, app_info: AppInfo, input_config: MultitenancyInputConfig + self, recipe_id: str, app_info: AppInfo, config: MultitenancyConfig ) -> None: super().__init__(recipe_id, app_info) - self.config = validate_and_normalise_user_input(input_config=input_config) + self.config = validate_and_normalise_user_input(config=config) recipe_implementation = RecipeImplementation( Querier.get_instance(recipe_id), self.config @@ -137,9 +136,11 @@ def init( get_allowed_domains_for_tenant_id: Union[ TypeGetAllowedDomainsForTenantId, None ] = None, - override: Union[InputOverrideConfig, None] = None, + override: Union[MultitenancyOverrideConfig, None] = None, ): - input_config = MultitenancyInputConfig( + from supertokens_python.plugins import OverrideMap, apply_plugins + + config = MultitenancyConfig( get_allowed_domains_for_tenant_id=get_allowed_domains_for_tenant_id, override=override, ) @@ -149,9 +150,9 @@ def func(app_info: AppInfo, plugins: List[OverrideMap]): MultitenancyRecipe.__instance = MultitenancyRecipe( recipe_id=MultitenancyRecipe.recipe_id, app_info=app_info, - input_config=apply_plugins( + config=apply_plugins( recipe_id=MultitenancyRecipe.recipe_id, - config=input_config, + config=config, plugins=plugins, ), ) diff --git a/supertokens_python/recipe/multitenancy/recipe_implementation.py b/supertokens_python/recipe/multitenancy/recipe_implementation.py index 13e7701e8..091f980f6 100644 --- a/supertokens_python/recipe/multitenancy/recipe_implementation.py +++ b/supertokens_python/recipe/multitenancy/recipe_implementation.py @@ -41,7 +41,7 @@ from supertokens_python.querier import Querier from supertokens_python.recipe.thirdparty.provider import ProviderConfig - from .utils import MultitenancyConfig + from .utils import NormalisedMultitenancyConfig from supertokens_python.querier import NormalisedURLPath @@ -119,7 +119,7 @@ def parse_tenant_config(tenant: Dict[str, Any]) -> TenantConfig: class RecipeImplementation(RecipeInterface): - def __init__(self, querier: Querier, config: MultitenancyConfig): + def __init__(self, querier: Querier, config: NormalisedMultitenancyConfig): super().__init__() self.querier = querier self.config = config diff --git a/supertokens_python/recipe/multitenancy/utils.py b/supertokens_python/recipe/multitenancy/utils.py index 84db1840e..8aaeca3c6 100644 --- a/supertokens_python/recipe/multitenancy/utils.py +++ b/supertokens_python/recipe/multitenancy/utils.py @@ -20,11 +20,10 @@ from supertokens_python.framework import BaseRequest, BaseResponse from supertokens_python.types.config import ( BaseConfig, - BaseInputConfig, - BaseInputOverrideConfig, + BaseNormalisedConfig, + BaseNormalisedOverrideConfig, BaseOverrideConfig, ) -from supertokens_python.types.utils import UseDefaultIfNone from supertokens_python.utils import ( resolve, ) @@ -67,34 +66,32 @@ async def on_recipe_disabled_for_tenant( ) -class InputOverrideConfig(BaseInputOverrideConfig[RecipeInterface, APIInterface]): ... +MultitenancyOverrideConfig = BaseOverrideConfig[RecipeInterface, APIInterface] +NormalisedMultitenancyOverrideConfig = BaseNormalisedOverrideConfig[ + RecipeInterface, APIInterface +] -class OverrideConfig(BaseOverrideConfig[RecipeInterface, APIInterface]): ... - - -class MultitenancyInputConfig(BaseInputConfig[RecipeInterface, APIInterface]): +class MultitenancyConfig(BaseConfig[RecipeInterface, APIInterface]): get_allowed_domains_for_tenant_id: Optional[TypeGetAllowedDomainsForTenantId] = None - override: UseDefaultIfNone[Optional[InputOverrideConfig]] = InputOverrideConfig() # type: ignore - https://github.com/microsoft/pyright/issues/5933 -class MultitenancyConfig(BaseConfig[RecipeInterface, APIInterface]): +class NormalisedMultitenancyConfig(BaseNormalisedConfig[RecipeInterface, APIInterface]): get_allowed_domains_for_tenant_id: Optional[TypeGetAllowedDomainsForTenantId] - override: OverrideConfig # type: ignore - https://github.com/microsoft/pyright/issues/5933 def validate_and_normalise_user_input( - input_config: MultitenancyInputConfig, -) -> MultitenancyConfig: - override_config = OverrideConfig() - if input_config.override is not None: - if input_config.override.functions is not None: - override_config.functions = input_config.override.functions - - if input_config.override.apis is not None: - override_config.apis = input_config.override.apis - - return MultitenancyConfig( - get_allowed_domains_for_tenant_id=input_config.get_allowed_domains_for_tenant_id, + config: MultitenancyConfig, +) -> NormalisedMultitenancyConfig: + override_config = NormalisedMultitenancyOverrideConfig() + if config.override is not None: + if config.override.functions is not None: + override_config.functions = config.override.functions + + if config.override.apis is not None: + override_config.apis = config.override.apis + + return NormalisedMultitenancyConfig( + get_allowed_domains_for_tenant_id=config.get_allowed_domains_for_tenant_id, override=override_config, ) diff --git a/supertokens_python/recipe/oauth2provider/__init__.py b/supertokens_python/recipe/oauth2provider/__init__.py index 3c5735030..bf1f02e6c 100644 --- a/supertokens_python/recipe/oauth2provider/__init__.py +++ b/supertokens_python/recipe/oauth2provider/__init__.py @@ -19,13 +19,13 @@ from . import recipe, utils exceptions = ex -InputOverrideConfig = utils.InputOverrideConfig +OAuth2ProviderOverrideConfig = utils.OAuth2ProviderOverrideConfig if TYPE_CHECKING: from supertokens_python.supertokens import RecipeInit def init( - override: Union[InputOverrideConfig, None] = None, + override: Union[OAuth2ProviderOverrideConfig, None] = None, ) -> RecipeInit: return recipe.OAuth2ProviderRecipe.init(override) diff --git a/supertokens_python/recipe/oauth2provider/interfaces.py b/supertokens_python/recipe/oauth2provider/interfaces.py index f96859747..d2f970b18 100644 --- a/supertokens_python/recipe/oauth2provider/interfaces.py +++ b/supertokens_python/recipe/oauth2provider/interfaces.py @@ -32,7 +32,7 @@ from supertokens_python.framework import BaseRequest, BaseResponse from supertokens_python.supertokens import AppInfo - from .utils import OAuth2ProviderConfig + from .utils import NormalisedOAuth2ProviderConfig class ErrorOAuth2Response(APIResponse): @@ -1274,14 +1274,14 @@ def __init__( request: BaseRequest, response: BaseResponse, recipe_id: str, - config: OAuth2ProviderConfig, + config: NormalisedOAuth2ProviderConfig, recipe_implementation: RecipeInterface, ): self.app_info: AppInfo = app_info self.request: BaseRequest = request self.response: BaseResponse = response self.recipe_id: str = recipe_id - self.config: OAuth2ProviderConfig = config + self.config: NormalisedOAuth2ProviderConfig = config self.recipe_implementation: RecipeInterface = recipe_implementation diff --git a/supertokens_python/recipe/oauth2provider/recipe.py b/supertokens_python/recipe/oauth2provider/recipe.py index ea54c40bd..a7dcb1ae4 100644 --- a/supertokens_python/recipe/oauth2provider/recipe.py +++ b/supertokens_python/recipe/oauth2provider/recipe.py @@ -18,7 +18,6 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from supertokens_python.exceptions import SuperTokensError, raise_general_exception -from supertokens_python.plugins import OverrideMap, apply_plugins from supertokens_python.recipe.oauth2provider.exceptions import OAuth2ProviderError from supertokens_python.recipe_module import APIHandled, RecipeModule from supertokens_python.types import User @@ -67,9 +66,9 @@ USER_INFO_PATH, ) from .utils import ( - InputOverrideConfig, + NormalisedOAuth2ProviderConfig, OAuth2ProviderConfig, - OAuth2ProviderInputConfig, + OAuth2ProviderOverrideConfig, validate_and_normalise_user_input, ) @@ -82,11 +81,11 @@ def __init__( self, recipe_id: str, app_info: AppInfo, - input_config: OAuth2ProviderInputConfig, + config: OAuth2ProviderConfig, ) -> None: super().__init__(recipe_id, app_info) - self.config: OAuth2ProviderConfig = validate_and_normalise_user_input( - input_config=input_config, + self.config: NormalisedOAuth2ProviderConfig = validate_and_normalise_user_input( + config=config, ) from .recipe_implementation import RecipeImplementation @@ -264,18 +263,20 @@ def get_all_cors_headers(self) -> List[str]: @staticmethod def init( - override: Union[InputOverrideConfig, None] = None, + override: Optional[OAuth2ProviderOverrideConfig] = None, ): - input_config = OAuth2ProviderInputConfig(override=override) + from supertokens_python.plugins import OverrideMap, apply_plugins + + config = OAuth2ProviderConfig(override=override) def func(app_info: AppInfo, plugins: List[OverrideMap]) -> OAuth2ProviderRecipe: if OAuth2ProviderRecipe.__instance is None: OAuth2ProviderRecipe.__instance = OAuth2ProviderRecipe( recipe_id=OAuth2ProviderRecipe.recipe_id, app_info=app_info, - input_config=apply_plugins( + config=apply_plugins( recipe_id=OAuth2ProviderRecipe.recipe_id, - config=input_config, + config=config, plugins=plugins, ), ) diff --git a/supertokens_python/recipe/oauth2provider/utils.py b/supertokens_python/recipe/oauth2provider/utils.py index f33f19782..2c6107241 100644 --- a/supertokens_python/recipe/oauth2provider/utils.py +++ b/supertokens_python/recipe/oauth2provider/utils.py @@ -13,40 +13,36 @@ # under the License. from __future__ import annotations -from typing import Optional - from supertokens_python.types.config import ( BaseConfig, - BaseInputConfig, - BaseInputOverrideConfig, + BaseNormalisedConfig, + BaseNormalisedOverrideConfig, BaseOverrideConfig, ) -from supertokens_python.types.utils import UseDefaultIfNone from .interfaces import APIInterface, RecipeInterface - -class InputOverrideConfig(BaseInputOverrideConfig[RecipeInterface, APIInterface]): ... - - -class OverrideConfig(BaseOverrideConfig[RecipeInterface, APIInterface]): ... +OAuth2ProviderOverrideConfig = BaseOverrideConfig[RecipeInterface, APIInterface] +NormalisedOAuth2ProviderOverrideConfig = BaseNormalisedOverrideConfig[ + RecipeInterface, APIInterface +] -class OAuth2ProviderInputConfig(BaseInputConfig[RecipeInterface, APIInterface]): - override: UseDefaultIfNone[Optional[InputOverrideConfig]] = InputOverrideConfig() # type: ignore - https://github.com/microsoft/pyright/issues/5933 +class OAuth2ProviderConfig(BaseConfig[RecipeInterface, APIInterface]): ... -class OAuth2ProviderConfig(BaseConfig[RecipeInterface, APIInterface]): - override: OverrideConfig # type: ignore - https://github.com/microsoft/pyright/issues/5933 +class NormalisedOAuth2ProviderConfig( + BaseNormalisedConfig[RecipeInterface, APIInterface] +): ... -def validate_and_normalise_user_input(input_config: OAuth2ProviderInputConfig): - override_config = OverrideConfig() - if input_config.override is not None: - if input_config.override.functions is not None: - override_config.functions = input_config.override.functions +def validate_and_normalise_user_input(config: OAuth2ProviderConfig): + override_config = NormalisedOAuth2ProviderOverrideConfig() + if config.override is not None: + if config.override.functions is not None: + override_config.functions = config.override.functions - if input_config.override.apis is not None: - override_config.apis = input_config.override.apis + if config.override.apis is not None: + override_config.apis = config.override.apis - return OAuth2ProviderConfig(override=override_config) + return NormalisedOAuth2ProviderConfig(override=override_config) diff --git a/supertokens_python/recipe/openid/__init__.py b/supertokens_python/recipe/openid/__init__.py index 06a411438..7ce1e58a8 100644 --- a/supertokens_python/recipe/openid/__init__.py +++ b/supertokens_python/recipe/openid/__init__.py @@ -16,7 +16,7 @@ from typing import TYPE_CHECKING, Union from .recipe import OpenIdRecipe -from .utils import InputOverrideConfig +from .utils import OpenIdOverrideConfig if TYPE_CHECKING: from supertokens_python.supertokens import RecipeInit @@ -24,6 +24,6 @@ def init( issuer: Union[str, None] = None, - override: Union[InputOverrideConfig, None] = None, + override: Union[OpenIdOverrideConfig, None] = None, ) -> RecipeInit: return OpenIdRecipe.init(issuer, override) diff --git a/supertokens_python/recipe/openid/interfaces.py b/supertokens_python/recipe/openid/interfaces.py index 6719b129c..b4dc84a37 100644 --- a/supertokens_python/recipe/openid/interfaces.py +++ b/supertokens_python/recipe/openid/interfaces.py @@ -24,7 +24,7 @@ from supertokens_python.types.response import APIResponse, GeneralErrorResponse if TYPE_CHECKING: - from .utils import OpenIdConfig + from .utils import NormalisedOpenIdConfig class GetOpenIdDiscoveryConfigurationResult: @@ -103,7 +103,7 @@ def __init__( request: BaseRequest, response: BaseResponse, recipe_id: str, - config: "OpenIdConfig", + config: "NormalisedOpenIdConfig", recipe_implementation: RecipeInterface, ): self.request = request diff --git a/supertokens_python/recipe/openid/recipe.py b/supertokens_python/recipe/openid/recipe.py index 50a72d927..f1c4325ec 100644 --- a/supertokens_python/recipe/openid/recipe.py +++ b/supertokens_python/recipe/openid/recipe.py @@ -16,7 +16,6 @@ from os import environ from typing import TYPE_CHECKING, Any, Dict, List, Union -from supertokens_python.plugins import OverrideMap, apply_plugins from supertokens_python.querier import Querier from .api.implementation import APIImplementation @@ -26,8 +25,8 @@ from .interfaces import APIOptions from .recipe_implementation import RecipeImplementation from .utils import ( - InputOverrideConfig, - OpenIdInputConfig, + OpenIdConfig, + OpenIdOverrideConfig, validate_and_normalise_user_input, ) @@ -49,13 +48,13 @@ def __init__( self, recipe_id: str, app_info: AppInfo, - input_config: OpenIdInputConfig, + config: OpenIdConfig, ): from supertokens_python.recipe.jwt import JWTRecipe super().__init__(recipe_id, app_info) self.config = validate_and_normalise_user_input( - app_info=app_info, input_config=input_config + app_info=app_info, config=config ) self.jwt_recipe = JWTRecipe.get_instance() @@ -129,18 +128,23 @@ def is_error_from_this_recipe_based_on_instance(self, err: Exception) -> bool: @staticmethod def init( issuer: Union[str, None] = None, - override: Union[InputOverrideConfig, None] = None, + override: Union[OpenIdOverrideConfig, None] = None, ): - input_config = OpenIdInputConfig(issuer=issuer, override=override) + from supertokens_python.plugins import OverrideMap, apply_plugins + + config = OpenIdConfig( + issuer=issuer, + override=override, + ) def func(app_info: AppInfo, plugins: List[OverrideMap]): if OpenIdRecipe.__instance is None: OpenIdRecipe.__instance = OpenIdRecipe( recipe_id=OpenIdRecipe.recipe_id, app_info=app_info, - input_config=apply_plugins( + config=apply_plugins( recipe_id=OpenIdRecipe.recipe_id, - config=input_config, + config=config, plugins=plugins, ), ) diff --git a/supertokens_python/recipe/openid/recipe_implementation.py b/supertokens_python/recipe/openid/recipe_implementation.py index 72bd9362b..ba97f264d 100644 --- a/supertokens_python/recipe/openid/recipe_implementation.py +++ b/supertokens_python/recipe/openid/recipe_implementation.py @@ -21,7 +21,7 @@ from supertokens_python.supertokens import AppInfo from .interfaces import CreateJwtOkResult, CreateJwtResultUnsupportedAlgorithm - from .utils import OpenIdConfig + from .utils import NormalisedOpenIdConfig from supertokens_python.normalised_url_path import NormalisedURLPath from supertokens_python.recipe.jwt.constants import GET_JWKS_API @@ -81,7 +81,7 @@ async def get_open_id_discovery_configuration( def __init__( self, querier: Querier, - config: OpenIdConfig, + config: NormalisedOpenIdConfig, app_info: AppInfo, ): super().__init__() diff --git a/supertokens_python/recipe/openid/utils.py b/supertokens_python/recipe/openid/utils.py index 821c14213..ae0887ad7 100644 --- a/supertokens_python/recipe/openid/utils.py +++ b/supertokens_python/recipe/openid/utils.py @@ -13,70 +13,63 @@ # under the License. from __future__ import annotations -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING, Optional -from supertokens_python.recipe.jwt import OverrideConfig as JWTOverrideConfig +from supertokens_python.normalised_url_domain import NormalisedURLDomain +from supertokens_python.normalised_url_path import NormalisedURLPath from supertokens_python.types.config import ( BaseConfig, - BaseInputConfig, - BaseInputOverrideConfig, + BaseNormalisedConfig, + BaseNormalisedOverrideConfig, BaseOverrideConfig, ) -from supertokens_python.types.utils import UseDefaultIfNone - -if TYPE_CHECKING: - from supertokens_python import AppInfo - - -from supertokens_python.normalised_url_domain import NormalisedURLDomain -from supertokens_python.normalised_url_path import NormalisedURLPath from .interfaces import APIInterface, RecipeInterface - -class InputOverrideConfig(BaseInputOverrideConfig[RecipeInterface, APIInterface]): - jwt_feature: Union[JWTOverrideConfig, None] = None +if TYPE_CHECKING: + from supertokens_python import AppInfo -class OverrideConfig(BaseOverrideConfig[RecipeInterface, APIInterface]): ... +OpenIdOverrideConfig = BaseOverrideConfig[RecipeInterface, APIInterface] +NormalisedOpenIdOverrideConfig = BaseNormalisedOverrideConfig[ + RecipeInterface, APIInterface +] -class OpenIdInputConfig(BaseInputConfig[RecipeInterface, APIInterface]): - issuer: Union[str, None] = None - override: UseDefaultIfNone[Optional[InputOverrideConfig]] = InputOverrideConfig() # type: ignore - https://github.com/microsoft/pyright/issues/5933 +class OpenIdConfig(BaseConfig[RecipeInterface, APIInterface]): + issuer: Optional[str] = None -class OpenIdConfig(BaseConfig[RecipeInterface, APIInterface]): +class NormalisedOpenIdConfig(BaseNormalisedConfig[RecipeInterface, APIInterface]): issuer_domain: NormalisedURLDomain issuer_path: NormalisedURLPath - override: OverrideConfig # type: ignore - https://github.com/microsoft/pyright/issues/5933 def validate_and_normalise_user_input( app_info: AppInfo, - input_config: OpenIdInputConfig, -) -> OpenIdConfig: - if input_config.issuer is None: + config: OpenIdConfig, +) -> NormalisedOpenIdConfig: + if config.issuer is None: issuer_domain = app_info.api_domain issuer_path = app_info.api_base_path else: - issuer_domain = NormalisedURLDomain(input_config.issuer) - issuer_path = NormalisedURLPath(input_config.issuer) + issuer_domain = NormalisedURLDomain(config.issuer) + issuer_path = NormalisedURLPath(config.issuer) if not issuer_path.equals(app_info.api_base_path): raise Exception( "The path of the issuer URL must be equal to the apiBasePath. The default value is /auth" ) - override_config = OverrideConfig() - if input_config.override is not None: - if input_config.override.functions is not None: - override_config.functions = input_config.override.functions + override_config = NormalisedOpenIdOverrideConfig() + if config.override is not None: + if config.override.functions is not None: + override_config.functions = config.override.functions - if input_config.override.apis is not None: - override_config.apis = input_config.override.apis + if config.override.apis is not None: + override_config.apis = config.override.apis - return OpenIdConfig( + return NormalisedOpenIdConfig( issuer_domain=issuer_domain, issuer_path=issuer_path, override=override_config, diff --git a/supertokens_python/recipe/passwordless/__init__.py b/supertokens_python/recipe/passwordless/__init__.py index bf53b9277..0ab7ef111 100644 --- a/supertokens_python/recipe/passwordless/__init__.py +++ b/supertokens_python/recipe/passwordless/__init__.py @@ -33,7 +33,7 @@ if TYPE_CHECKING: from supertokens_python.supertokens import RecipeInit -InputOverrideConfig = utils.InputOverrideConfig +PasswordlessOverrideConfig = utils.PasswordlessOverrideConfig ContactEmailOnlyConfig = utils.ContactEmailOnlyConfig ContactConfig = utils.ContactConfig PhoneOrEmailInput = utils.PhoneOrEmailInput @@ -55,7 +55,7 @@ def init( flow_type: Literal[ "USER_INPUT_CODE", "MAGIC_LINK", "USER_INPUT_CODE_AND_MAGIC_LINK" ], - override: Union[InputOverrideConfig, None] = None, + override: Union[PasswordlessOverrideConfig, None] = None, get_custom_user_input_code: Union[ Callable[[str, Dict[str, Any]], Awaitable[str]], None ] = None, diff --git a/supertokens_python/recipe/passwordless/interfaces.py b/supertokens_python/recipe/passwordless/interfaces.py index 2807e0695..165097275 100644 --- a/supertokens_python/recipe/passwordless/interfaces.py +++ b/supertokens_python/recipe/passwordless/interfaces.py @@ -40,7 +40,7 @@ ) if TYPE_CHECKING: - from .utils import PasswordlessConfig + from .utils import NormalisedPasswordlessConfig class CreateCodeOkResult: @@ -361,7 +361,7 @@ def __init__( request: BaseRequest, response: BaseResponse, recipe_id: str, - config: PasswordlessConfig, + config: NormalisedPasswordlessConfig, recipe_implementation: RecipeInterface, app_info: AppInfo, email_delivery: EmailDeliveryIngredient[PasswordlessLoginEmailTemplateVars], diff --git a/supertokens_python/recipe/passwordless/recipe.py b/supertokens_python/recipe/passwordless/recipe.py index fc24aeca0..043b9272d 100644 --- a/supertokens_python/recipe/passwordless/recipe.py +++ b/supertokens_python/recipe/passwordless/recipe.py @@ -22,7 +22,6 @@ from supertokens_python.ingredients.emaildelivery import EmailDeliveryIngredient from supertokens_python.ingredients.emaildelivery.types import EmailDeliveryConfig from supertokens_python.ingredients.smsdelivery import SMSDeliveryIngredient -from supertokens_python.plugins import OverrideMap, apply_plugins from supertokens_python.querier import Querier from supertokens_python.recipe.multifactorauth.recipe import MultiFactorAuthRecipe from supertokens_python.recipe.multifactorauth.types import ( @@ -73,8 +72,8 @@ from .recipe_implementation import RecipeImplementation from .utils import ( ContactConfig, - InputOverrideConfig, - PasswordlessInputConfig, + PasswordlessConfig, + PasswordlessOverrideConfig, get_enabled_pwless_factors, validate_and_normalise_user_input, ) @@ -101,12 +100,12 @@ def __init__( recipe_id: str, app_info: AppInfo, ingredients: PasswordlessIngredients, - input_config: PasswordlessInputConfig, + config: PasswordlessConfig, ): super().__init__(recipe_id, app_info) self.config = validate_and_normalise_user_input( app_info=app_info, - input_config=input_config, + config=config, ) recipe_implementation = RecipeImplementation(Querier.get_instance(recipe_id)) @@ -486,7 +485,7 @@ def init( flow_type: Literal[ "USER_INPUT_CODE", "MAGIC_LINK", "USER_INPUT_CODE_AND_MAGIC_LINK" ], - override: Union[InputOverrideConfig, None] = None, + override: Optional[PasswordlessOverrideConfig] = None, get_custom_user_input_code: Union[ Callable[[str, Dict[str, Any]], Awaitable[str]], None ] = None, @@ -497,7 +496,9 @@ def init( SMSDeliveryConfig[PasswordlessLoginSMSTemplateVars], None ] = None, ): - input_config = PasswordlessInputConfig( + from supertokens_python.plugins import OverrideMap, apply_plugins + + config = PasswordlessConfig( contact_config=contact_config, get_custom_user_input_code=get_custom_user_input_code, email_delivery=email_delivery, @@ -513,9 +514,9 @@ def func(app_info: AppInfo, plugins: List[OverrideMap]): recipe_id=PasswordlessRecipe.recipe_id, app_info=app_info, ingredients=ingredients, - input_config=apply_plugins( + config=apply_plugins( recipe_id=PasswordlessRecipe.recipe_id, - config=input_config, + config=config, plugins=plugins, ), ) diff --git a/supertokens_python/recipe/passwordless/utils.py b/supertokens_python/recipe/passwordless/utils.py index 4487be33b..ed1b15365 100644 --- a/supertokens_python/recipe/passwordless/utils.py +++ b/supertokens_python/recipe/passwordless/utils.py @@ -16,7 +16,7 @@ from abc import ABC from re import fullmatch -from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Union from phonenumbers import is_valid_number, parse from typing_extensions import Literal @@ -41,11 +41,10 @@ ) from supertokens_python.types.config import ( BaseConfig, - BaseInputConfig, - BaseInputOverrideConfig, + BaseNormalisedConfig, + BaseNormalisedOverrideConfig, BaseOverrideConfig, ) -from supertokens_python.types.utils import UseDefaultIfNone from .interfaces import ( APIInterface, @@ -72,10 +71,10 @@ async def default_validate_email(value: str, _tenant_id: str): return "Email is invalid" -class InputOverrideConfig(BaseInputOverrideConfig[RecipeInterface, APIInterface]): ... - - -class OverrideConfig(BaseOverrideConfig[RecipeInterface, APIInterface]): ... +PasswordlessOverrideConfig = BaseOverrideConfig[RecipeInterface, APIInterface] +NormalisedPasswordlessOverrideConfig = BaseNormalisedOverrideConfig[ + RecipeInterface, APIInterface +] class ContactConfig(ABC): @@ -142,7 +141,7 @@ def __init__(self, phone_number: Union[str, None], email: Union[str, None]): self.email = email -class PasswordlessInputConfig(BaseInputConfig[RecipeInterface, APIInterface]): +class PasswordlessConfig(BaseConfig[RecipeInterface, APIInterface]): contact_config: ContactConfig flow_type: Literal[ "USER_INPUT_CODE", "MAGIC_LINK", "USER_INPUT_CODE_AND_MAGIC_LINK" @@ -156,10 +155,9 @@ class PasswordlessInputConfig(BaseInputConfig[RecipeInterface, APIInterface]): sms_delivery: Union[SMSDeliveryConfig[PasswordlessLoginSMSTemplateVars], None] = ( None ) - override: UseDefaultIfNone[Optional[InputOverrideConfig]] = InputOverrideConfig() # type: ignore - https://github.com/microsoft/pyright/issues/5933 -class PasswordlessConfig(BaseConfig[RecipeInterface, APIInterface]): +class NormalisedPasswordlessConfig(BaseNormalisedConfig[RecipeInterface, APIInterface]): contact_config: ContactConfig flow_type: Literal[ "USER_INPUT_CODE", "MAGIC_LINK", "USER_INPUT_CODE_AND_MAGIC_LINK" @@ -173,38 +171,35 @@ class PasswordlessConfig(BaseConfig[RecipeInterface, APIInterface]): get_custom_user_input_code: Union[ Callable[[str, Dict[str, Any]], Awaitable[str]], None ] - override: OverrideConfig # type: ignore - https://github.com/microsoft/pyright/issues/5933 def validate_and_normalise_user_input( app_info: AppInfo, - input_config: PasswordlessInputConfig, -) -> PasswordlessConfig: - override_config = OverrideConfig() - if input_config.override is not None: - if input_config.override.functions is not None: - override_config.functions = input_config.override.functions + config: PasswordlessConfig, +) -> NormalisedPasswordlessConfig: + override_config = NormalisedPasswordlessOverrideConfig() + if config.override is not None: + if config.override.functions is not None: + override_config.functions = config.override.functions - if input_config.override.apis is not None: - override_config.apis = input_config.override.apis + if config.override.apis is not None: + override_config.apis = config.override.apis def get_email_delivery_config() -> EmailDeliveryConfigWithService[ PasswordlessLoginEmailTemplateVars ]: email_service = ( - input_config.email_delivery.service - if input_config.email_delivery is not None - else None + config.email_delivery.service if config.email_delivery is not None else None ) if email_service is None: email_service = BackwardCompatibilityService(app_info) if ( - input_config.email_delivery is not None - and input_config.email_delivery.override is not None + config.email_delivery is not None + and config.email_delivery.override is not None ): - override = input_config.email_delivery.override + override = config.email_delivery.override else: override = None @@ -214,28 +209,23 @@ def get_sms_delivery_config() -> SMSDeliveryConfigWithService[ PasswordlessLoginSMSTemplateVars ]: sms_service = ( - input_config.sms_delivery.service - if input_config.sms_delivery is not None - else None + config.sms_delivery.service if config.sms_delivery is not None else None ) if sms_service is None: sms_service = SMSBackwardCompatibilityService(app_info) - if ( - input_config.sms_delivery is not None - and input_config.sms_delivery.override is not None - ): - override = input_config.sms_delivery.override + if config.sms_delivery is not None and config.sms_delivery.override is not None: + override = config.sms_delivery.override else: override = None return SMSDeliveryConfigWithService(sms_service, override=override) - if not isinstance(input_config.contact_config, ContactConfig): # type: ignore user might not have linter enabled + if not isinstance(config.contact_config, ContactConfig): # type: ignore user might not have linter enabled raise ValueError("contact_config must be of type ContactConfig") - if input_config.flow_type not in [ + if config.flow_type not in [ "USER_INPUT_CODE", "MAGIC_LINK", "USER_INPUT_CODE_AND_MAGIC_LINK", @@ -244,18 +234,18 @@ def get_sms_delivery_config() -> SMSDeliveryConfigWithService[ "flow_type must be one of USER_INPUT_CODE, MAGIC_LINK, USER_INPUT_CODE_AND_MAGIC_LINK" ) - return PasswordlessConfig( - contact_config=input_config.contact_config, + return NormalisedPasswordlessConfig( + contact_config=config.contact_config, override=override_config, - flow_type=input_config.flow_type, + flow_type=config.flow_type, get_email_delivery_config=get_email_delivery_config, get_sms_delivery_config=get_sms_delivery_config, - get_custom_user_input_code=input_config.get_custom_user_input_code, + get_custom_user_input_code=config.get_custom_user_input_code, ) def get_enabled_pwless_factors( - config: PasswordlessConfig, + config: NormalisedPasswordlessConfig, ) -> List[str]: all_factors: List[str] = [] diff --git a/supertokens_python/recipe/session/__init__.py b/supertokens_python/recipe/session/__init__.py index 799221eb7..db5fc865b 100644 --- a/supertokens_python/recipe/session/__init__.py +++ b/supertokens_python/recipe/session/__init__.py @@ -26,7 +26,7 @@ from .utils import TokenTransferMethod InputErrorHandlers = utils.InputErrorHandlers -InputOverrideConfig = utils.InputOverrideConfig +SessionOverrideConfig = utils.SessionOverrideConfig SessionContainer = interfaces.SessionContainer exceptions = ex @@ -46,7 +46,7 @@ def init( None, ] = None, error_handlers: Union[InputErrorHandlers, None] = None, - override: Union[InputOverrideConfig, None] = None, + override: Union[SessionOverrideConfig, None] = None, invalid_claim_status_code: Union[int, None] = None, use_dynamic_access_token_signing_key: Union[bool, None] = None, expose_access_token_to_frontend_in_cookie_based_auth: Union[bool, None] = None, diff --git a/supertokens_python/recipe/session/access_token.py b/supertokens_python/recipe/session/access_token.py index 1205b6c7b..cdfd23abe 100644 --- a/supertokens_python/recipe/session/access_token.py +++ b/supertokens_python/recipe/session/access_token.py @@ -21,7 +21,7 @@ from supertokens_python.logger import log_debug_message from supertokens_python.recipe.multitenancy.constants import DEFAULT_TENANT_ID from supertokens_python.recipe.session.jwks import get_latest_keys -from supertokens_python.recipe.session.utils import SessionConfig +from supertokens_python.recipe.session.utils import NormalisedSessionConfig from supertokens_python.utils import get_timestamp_ms from .exceptions import raise_try_refresh_token_exception @@ -46,7 +46,7 @@ def sanitize_number(n: Any) -> Union[Union[int, float], None]: def get_info_from_access_token( - config: SessionConfig, + config: NormalisedSessionConfig, jwt_info: ParsedJWTInfo, do_anti_csrf_check: bool, ): diff --git a/supertokens_python/recipe/session/cookie_and_header.py b/supertokens_python/recipe/session/cookie_and_header.py index 1a7aa2e50..18e4a40b1 100644 --- a/supertokens_python/recipe/session/cookie_and_header.py +++ b/supertokens_python/recipe/session/cookie_and_header.py @@ -45,7 +45,7 @@ from .recipe import SessionRecipe from .utils import ( - SessionConfig, + NormalisedSessionConfig, TokenTransferMethod, TokenType, ) @@ -111,7 +111,7 @@ def get_cookie(request: BaseRequest, key: str): def _set_cookie( response: BaseResponse, - config: SessionConfig, + config: NormalisedSessionConfig, key: str, value: str, expires: int, @@ -141,7 +141,7 @@ def _set_cookie( def set_cookie_response_mutator( - config: SessionConfig, + config: NormalisedSessionConfig, key: str, value: str, expires: int, @@ -207,7 +207,7 @@ def clear_session_from_all_token_transfer_methods( def clear_session_mutator( - config: SessionConfig, + config: NormalisedSessionConfig, transfer_method: TokenTransferMethod, request: BaseRequest, ): @@ -222,7 +222,7 @@ def mutator( def _clear_session( response: BaseResponse, - config: SessionConfig, + config: NormalisedSessionConfig, transfer_method: TokenTransferMethod, request: BaseRequest, user_context: Dict[str, Any], @@ -244,7 +244,7 @@ def _clear_session( def clear_session_response_mutator( - config: SessionConfig, + config: NormalisedSessionConfig, transfer_method: TokenTransferMethod, request: BaseRequest, ): @@ -293,7 +293,7 @@ def get_token( def _set_token( response: BaseResponse, - config: SessionConfig, + config: NormalisedSessionConfig, token_type: TokenType, value: str, expires: int, @@ -323,7 +323,7 @@ def _set_token( def token_response_mutator( - config: SessionConfig, + config: NormalisedSessionConfig, token_type: TokenType, value: str, expires: int, @@ -356,7 +356,7 @@ def set_token_in_header(response: BaseResponse, name: str, value: str): def access_token_mutator( access_token: str, front_token: str, - config: SessionConfig, + config: NormalisedSessionConfig, transfer_method: TokenTransferMethod, request: BaseRequest, ): @@ -381,7 +381,7 @@ def _set_access_token_in_response( res: BaseResponse, access_token: str, front_token: str, - config: SessionConfig, + config: NormalisedSessionConfig, transfer_method: TokenTransferMethod, request: BaseRequest, user_context: Dict[str, Any], @@ -430,7 +430,7 @@ def _set_access_token_in_response( # This function checks for multiple cookies with the same name and clears the cookies for the older domain. def clear_session_cookies_from_older_cookie_domain( - request: BaseRequest, config: SessionConfig, user_context: Dict[str, Any] + request: BaseRequest, config: NormalisedSessionConfig, user_context: Dict[str, Any] ): allowed_transfer_method = config.get_token_transfer_method( request, False, user_context diff --git a/supertokens_python/recipe/session/interfaces.py b/supertokens_python/recipe/session/interfaces.py index 3ce2ca036..fd8d79f10 100644 --- a/supertokens_python/recipe/session/interfaces.py +++ b/supertokens_python/recipe/session/interfaces.py @@ -41,7 +41,7 @@ from .exceptions import ClaimValidationError if TYPE_CHECKING: - from .utils import SessionConfig, TokenTransferMethod + from .utils import NormalisedSessionConfig, TokenTransferMethod class SessionObj: @@ -372,7 +372,7 @@ def __init__( request: BaseRequest, response: Optional[BaseResponse], recipe_id: str, - config: SessionConfig, + config: NormalisedSessionConfig, recipe_implementation: RecipeInterface, ): self.request = request @@ -445,7 +445,7 @@ class SessionContainer(ABC): # pylint: disable=too-many-public-methods def __init__( self, recipe_implementation: RecipeInterface, - config: SessionConfig, + config: NormalisedSessionConfig, access_token: str, front_token: str, refresh_token: Optional[TokenInfo], diff --git a/supertokens_python/recipe/session/jwks.py b/supertokens_python/recipe/session/jwks.py index e2eb39c7d..f4f8aa35d 100644 --- a/supertokens_python/recipe/session/jwks.py +++ b/supertokens_python/recipe/session/jwks.py @@ -21,7 +21,7 @@ from supertokens_python.logger import log_debug_message from supertokens_python.querier import Querier -from supertokens_python.recipe.session.utils import SessionConfig +from supertokens_python.recipe.session.utils import NormalisedSessionConfig from supertokens_python.utils import RWLockContext, RWMutex, get_timestamp_ms @@ -88,7 +88,9 @@ def find_matching_keys( return None -def get_latest_keys(config: SessionConfig, kid: Optional[str] = None) -> List[PyJWK]: +def get_latest_keys( + config: NormalisedSessionConfig, kid: Optional[str] = None +) -> List[PyJWK]: global cached_keys if environ.get("SUPERTOKENS_ENV") == "testing": diff --git a/supertokens_python/recipe/session/recipe.py b/supertokens_python/recipe/session/recipe.py index d4476ca69..29764de5a 100644 --- a/supertokens_python/recipe/session/recipe.py +++ b/supertokens_python/recipe/session/recipe.py @@ -18,11 +18,18 @@ from typing_extensions import Literal +from supertokens_python.exceptions import SuperTokensError, raise_general_exception from supertokens_python.framework.response import BaseResponse -from supertokens_python.plugins import OverrideMap, apply_plugins +from supertokens_python.logger import log_debug_message +from supertokens_python.normalised_url_path import NormalisedURLPath +from supertokens_python.querier import Querier +from supertokens_python.recipe_module import APIHandled, RecipeModule from ...types import MaybeAwaitable +from .api import handle_refresh_api, handle_signout_api +from .constants import SESSION_REFRESH, SIGNOUT from .cookie_and_header import ( + clear_session_from_all_token_transfer_methods, get_cors_allowed_headers, ) from .exceptions import ( @@ -32,20 +39,6 @@ TokenTheftError, UnauthorisedError, ) - -if TYPE_CHECKING: - from supertokens_python.framework import BaseRequest - from supertokens_python.supertokens import AppInfo - -from supertokens_python.exceptions import SuperTokensError, raise_general_exception -from supertokens_python.logger import log_debug_message -from supertokens_python.normalised_url_path import NormalisedURLPath -from supertokens_python.querier import Querier -from supertokens_python.recipe_module import APIHandled, RecipeModule - -from .api import handle_refresh_api, handle_signout_api -from .constants import SESSION_REFRESH, SIGNOUT -from .cookie_and_header import clear_session_from_all_token_transfer_methods from .interfaces import ( APIInterface, APIOptions, @@ -59,12 +52,16 @@ ) from .utils import ( InputErrorHandlers, - InputOverrideConfig, - SessionInputConfig, + SessionConfig, + SessionOverrideConfig, TokenTransferMethod, validate_and_normalise_user_input, ) +if TYPE_CHECKING: + from supertokens_python.framework import BaseRequest + from supertokens_python.supertokens import AppInfo + class SessionRecipe(RecipeModule): recipe_id = "session" @@ -74,12 +71,12 @@ def __init__( self, recipe_id: str, app_info: AppInfo, - input_config: SessionInputConfig, + config: SessionConfig, ): super().__init__(recipe_id, app_info) self.config = validate_and_normalise_user_input( app_info=app_info, - input_config=input_config, + config=config, ) log_debug_message( "session init: anti_csrf: %s", self.config.anti_csrf_function_or_string @@ -93,9 +90,9 @@ def __init__( # we check the input cookie_same_site because the normalised version is # always a function. - if input_config.cookie_same_site is not None: + if config.cookie_same_site is not None: log_debug_message( - "session init: cookie_same_site: %s", input_config.cookie_same_site + "session init: cookie_same_site: %s", config.cookie_same_site ) else: log_debug_message("session init: cookie_same_site: function") @@ -258,13 +255,15 @@ def init( None, ] = None, error_handlers: Union[InputErrorHandlers, None] = None, - override: Union[InputOverrideConfig, None] = None, + override: Union[SessionOverrideConfig, None] = None, invalid_claim_status_code: Union[int, None] = None, use_dynamic_access_token_signing_key: Union[bool, None] = None, expose_access_token_to_frontend_in_cookie_based_auth: Union[bool, None] = None, jwks_refresh_interval_sec: Union[int, None] = None, ): - input_config = SessionInputConfig( + from supertokens_python.plugins import OverrideMap, apply_plugins + + config = SessionConfig( cookie_domain=cookie_domain, older_cookie_domain=older_cookie_domain, cookie_secure=cookie_secure, @@ -285,9 +284,9 @@ def func(app_info: AppInfo, plugins: List[OverrideMap]): SessionRecipe.__instance = SessionRecipe( recipe_id=SessionRecipe.recipe_id, app_info=app_info, - input_config=apply_plugins( + config=apply_plugins( recipe_id=SessionRecipe.recipe_id, - config=input_config, + config=config, plugins=plugins, ), ) diff --git a/supertokens_python/recipe/session/recipe_implementation.py b/supertokens_python/recipe/session/recipe_implementation.py index aa073a872..750d721e4 100644 --- a/supertokens_python/recipe/session/recipe_implementation.py +++ b/supertokens_python/recipe/session/recipe_implementation.py @@ -39,7 +39,7 @@ ) from .jwt import ParsedJWTInfo, parse_jwt_without_signature_verification from .session_class import Session -from .utils import SessionConfig, validate_claims_in_payload +from .utils import NormalisedSessionConfig, validate_claims_in_payload if TYPE_CHECKING: from typing import List, Union @@ -54,7 +54,9 @@ class RecipeImplementation(RecipeInterface): # pylint: disable=too-many-public-methods - def __init__(self, querier: Querier, config: SessionConfig, app_info: AppInfo): + def __init__( + self, querier: Querier, config: NormalisedSessionConfig, app_info: AppInfo + ): super().__init__() self.querier = querier self.config = config diff --git a/supertokens_python/recipe/session/session_request_functions.py b/supertokens_python/recipe/session/session_request_functions.py index 40e6c6d58..1e2320817 100644 --- a/supertokens_python/recipe/session/session_request_functions.py +++ b/supertokens_python/recipe/session/session_request_functions.py @@ -47,7 +47,7 @@ parse_jwt_without_signature_verification, ) from supertokens_python.recipe.session.utils import ( - SessionConfig, + NormalisedSessionConfig, TokenTransferMethod, get_auth_mode_from_header, get_required_claim_validators, @@ -75,7 +75,7 @@ async def get_session_from_request( request: Any, - config: SessionConfig, + config: NormalisedSessionConfig, recipe_interface_impl: SessionRecipeInterface, session_required: Optional[bool] = None, anti_csrf_check: Optional[bool] = None, @@ -240,7 +240,7 @@ async def create_new_session_in_request( access_token_payload: Dict[str, Any], user_id: str, recipe_user_id: RecipeUserId, - config: SessionConfig, + config: NormalisedSessionConfig, app_info: AppInfo, session_data_in_database: Dict[str, Any], tenant_id: str, @@ -353,7 +353,7 @@ async def create_new_session_in_request( async def refresh_session_in_request( request: Any, user_context: Dict[str, Any], - config: SessionConfig, + config: NormalisedSessionConfig, recipe_interface_impl: SessionRecipeInterface, ) -> SessionContainer: log_debug_message("refreshSession: Started") diff --git a/supertokens_python/recipe/session/utils.py b/supertokens_python/recipe/session/utils.py index 65a24e6ee..022cd1e7a 100644 --- a/supertokens_python/recipe/session/utils.py +++ b/supertokens_python/recipe/session/utils.py @@ -22,16 +22,12 @@ from supertokens_python.exceptions import raise_general_exception from supertokens_python.framework import BaseRequest, BaseResponse from supertokens_python.normalised_url_path import NormalisedURLPath -from supertokens_python.recipe.openid import ( - InputOverrideConfig as OpenIdInputOverrideConfig, -) from supertokens_python.types.config import ( BaseConfig, - BaseInputConfig, - BaseInputOverrideConfig, + BaseNormalisedConfig, + BaseNormalisedOverrideConfig, BaseOverrideConfig, ) -from supertokens_python.types.utils import UseDefaultIfNone from supertokens_python.utils import ( is_an_ip_address, resolve, @@ -340,18 +336,17 @@ def get_token_transfer_method_default( return "any" -class InputOverrideConfig(BaseInputOverrideConfig[RecipeInterface, APIInterface]): - openid_feature: Optional[OpenIdInputOverrideConfig] = None - - -class OverrideConfig(BaseOverrideConfig[RecipeInterface, APIInterface]): ... +SessionOverrideConfig = BaseOverrideConfig[RecipeInterface, APIInterface] +NormalisedSessionOverrideConfig = BaseNormalisedOverrideConfig[ + RecipeInterface, APIInterface +] TokenType = Literal["access", "refresh"] TokenTransferMethod = Literal["cookie", "header"] -class SessionInputConfig(BaseInputConfig[RecipeInterface, APIInterface]): +class SessionConfig(BaseConfig[RecipeInterface, APIInterface]): cookie_domain: Union[str, None] = None older_cookie_domain: Union[str, None] = None cookie_secure: Union[bool, None] = None @@ -370,10 +365,9 @@ class SessionInputConfig(BaseInputConfig[RecipeInterface, APIInterface]): use_dynamic_access_token_signing_key: Union[bool, None] = None expose_access_token_to_frontend_in_cookie_based_auth: Union[bool, None] = None jwks_refresh_interval_sec: Union[int, None] = None - override: UseDefaultIfNone[Optional[InputOverrideConfig]] = InputOverrideConfig() # type: ignore - https://github.com/microsoft/pyright/issues/5933 -class SessionConfig(BaseConfig[RecipeInterface, APIInterface]): +class NormalisedSessionConfig(BaseNormalisedConfig[RecipeInterface, APIInterface]): refresh_token_path: NormalisedURLPath cookie_domain: Union[None, str] older_cookie_domain: Union[None, str] @@ -401,55 +395,50 @@ class SessionConfig(BaseConfig[RecipeInterface, APIInterface]): use_dynamic_access_token_signing_key: bool expose_access_token_to_frontend_in_cookie_based_auth: bool jwks_refresh_interval_sec: int - override: OverrideConfig # type: ignore - https://github.com/microsoft/pyright/issues/5933 def validate_and_normalise_user_input( app_info: AppInfo, - input_config: SessionInputConfig, + config: SessionConfig, ): # _ = cookie_same_site # we have this otherwise pylint complains that cookie_same_site is unused, but it is being used in the get_cookie_same_site function. - if input_config.anti_csrf not in {"VIA_TOKEN", "VIA_CUSTOM_HEADER", "NONE", None}: + if config.anti_csrf not in {"VIA_TOKEN", "VIA_CUSTOM_HEADER", "NONE", None}: raise ValueError( "anti_csrf must be one of VIA_TOKEN, VIA_CUSTOM_HEADER, NONE or None" ) - if input_config.error_handlers is not None and not isinstance( - input_config.error_handlers, ErrorHandlers + if config.error_handlers is not None and not isinstance( + config.error_handlers, ErrorHandlers ): # type: ignore raise ValueError("error_handlers must be an instance of ErrorHandlers or None") - # if override is not None and not isinstance(override, InputOverrideConfig): # type: ignore - # raise ValueError("override must be an instance of InputOverrideConfig or None") - cookie_domain = ( - normalise_session_scope(input_config.cookie_domain) - if input_config.cookie_domain is not None + normalise_session_scope(config.cookie_domain) + if config.cookie_domain is not None else None ) older_cookie_domain = ( - input_config.older_cookie_domain - if input_config.older_cookie_domain is None - or input_config.older_cookie_domain == "" - else normalise_session_scope(input_config.older_cookie_domain) + config.older_cookie_domain + if config.older_cookie_domain is None or config.older_cookie_domain == "" + else normalise_session_scope(config.older_cookie_domain) ) cookie_secure = ( - input_config.cookie_secure - if input_config.cookie_secure is not None + config.cookie_secure + if config.cookie_secure is not None else app_info.api_domain.get_as_string_dangerous().startswith("https") ) session_expired_status_code = ( - input_config.session_expired_status_code - if input_config.session_expired_status_code is not None + config.session_expired_status_code + if config.session_expired_status_code is not None else 401 ) invalid_claim_status_code = ( - input_config.invalid_claim_status_code - if input_config.invalid_claim_status_code is not None + config.invalid_claim_status_code + if config.invalid_claim_status_code is not None else 403 ) @@ -459,27 +448,25 @@ def validate_and_normalise_user_input( f"({invalid_claim_status_code})" ) - get_token_transfer_method = input_config.get_token_transfer_method + get_token_transfer_method = config.get_token_transfer_method if get_token_transfer_method is None: get_token_transfer_method = get_token_transfer_method_default - error_handlers = input_config.error_handlers + error_handlers = config.error_handlers if error_handlers is None: error_handlers = InputErrorHandlers() - use_dynamic_access_token_signing_key = ( - input_config.use_dynamic_access_token_signing_key - ) + use_dynamic_access_token_signing_key = config.use_dynamic_access_token_signing_key if use_dynamic_access_token_signing_key is None: use_dynamic_access_token_signing_key = True expose_access_token_to_frontend_in_cookie_based_auth = ( - input_config.expose_access_token_to_frontend_in_cookie_based_auth + config.expose_access_token_to_frontend_in_cookie_based_auth ) if expose_access_token_to_frontend_in_cookie_based_auth is None: expose_access_token_to_frontend_in_cookie_based_auth = False - cookie_same_site = input_config.cookie_same_site + cookie_same_site = config.cookie_same_site if cookie_same_site is not None: # this is just so that we check that the user has provided the right # values, since normalise_same_site throws an error if the user @@ -527,23 +514,23 @@ def anti_csrf_function( Literal["VIA_CUSTOM_HEADER", "NONE", "VIA_TOKEN"], ] = anti_csrf_function - anti_csrf = input_config.anti_csrf + anti_csrf = config.anti_csrf if anti_csrf is not None: anti_csrf_function_or_string = anti_csrf - jwks_refresh_interval_sec = input_config.jwks_refresh_interval_sec + jwks_refresh_interval_sec = config.jwks_refresh_interval_sec if jwks_refresh_interval_sec is None: jwks_refresh_interval_sec = 4 * 3600 # 4 hours - override_config = OverrideConfig() - if input_config.override is not None: - if input_config.override.functions is not None: - override_config.functions = input_config.override.functions + override_config = NormalisedSessionOverrideConfig() + if config.override is not None: + if config.override.functions is not None: + override_config.functions = config.override.functions - if input_config.override.apis is not None: - override_config.apis = input_config.override.apis + if config.override.apis is not None: + override_config.apis = config.override.apis - return SessionConfig( + return NormalisedSessionConfig( refresh_token_path=app_info.api_base_path.append( NormalisedURLPath(SESSION_REFRESH) ), diff --git a/supertokens_python/recipe/thirdparty/__init__.py b/supertokens_python/recipe/thirdparty/__init__.py index 9f60e17c9..6d03ffb2f 100644 --- a/supertokens_python/recipe/thirdparty/__init__.py +++ b/supertokens_python/recipe/thirdparty/__init__.py @@ -20,7 +20,7 @@ from . import provider, utils from .recipe import ThirdPartyRecipe -InputOverrideConfig = utils.InputOverrideConfig +ThirdPartyOverrideConfig = utils.ThirdPartyOverrideConfig SignInAndUpFeature = utils.SignInAndUpFeature ProviderInput = provider.ProviderInput ProviderConfig = provider.ProviderConfig @@ -33,7 +33,7 @@ def init( sign_in_and_up_feature: Optional[SignInAndUpFeature] = None, - override: Union[InputOverrideConfig, None] = None, + override: Union[ThirdPartyOverrideConfig, None] = None, ) -> RecipeInit: if sign_in_and_up_feature is None: sign_in_and_up_feature = SignInAndUpFeature() diff --git a/supertokens_python/recipe/thirdparty/interfaces.py b/supertokens_python/recipe/thirdparty/interfaces.py index 181b3669e..f7d01ae0c 100644 --- a/supertokens_python/recipe/thirdparty/interfaces.py +++ b/supertokens_python/recipe/thirdparty/interfaces.py @@ -29,7 +29,7 @@ from supertokens_python.types.auth_utils import LinkingToSessionUserFailedError from .types import RawUserInfoFromProvider - from .utils import ThirdPartyConfig + from .utils import NormalisedThirdPartyConfig class SignInUpOkResult: @@ -137,7 +137,7 @@ def __init__( request: BaseRequest, response: BaseResponse, recipe_id: str, - config: ThirdPartyConfig, + config: NormalisedThirdPartyConfig, recipe_implementation: RecipeInterface, providers: List[ProviderInput], app_info: AppInfo, @@ -145,7 +145,7 @@ def __init__( self.request: BaseRequest = request self.response: BaseResponse = response self.recipe_id: str = recipe_id - self.config: ThirdPartyConfig = config + self.config: NormalisedThirdPartyConfig = config self.providers: List[ProviderInput] = providers self.recipe_implementation: RecipeInterface = recipe_implementation self.app_info: AppInfo = app_info diff --git a/supertokens_python/recipe/thirdparty/recipe.py b/supertokens_python/recipe/thirdparty/recipe.py index d0a40d4ab..05a26ed4f 100644 --- a/supertokens_python/recipe/thirdparty/recipe.py +++ b/supertokens_python/recipe/thirdparty/recipe.py @@ -17,7 +17,6 @@ from typing import TYPE_CHECKING, Any, Dict, List, Union from supertokens_python.normalised_url_path import NormalisedURLPath -from supertokens_python.plugins import OverrideMap, apply_plugins from supertokens_python.querier import Querier from supertokens_python.recipe_module import APIHandled, RecipeModule @@ -31,7 +30,7 @@ from supertokens_python.framework.response import BaseResponse from supertokens_python.supertokens import AppInfo - from .utils import InputOverrideConfig, SignInAndUpFeature + from .utils import SignInAndUpFeature, ThirdPartyOverrideConfig from supertokens_python.exceptions import SuperTokensError, raise_general_exception from supertokens_python.recipe.multitenancy.recipe import MultitenancyRecipe @@ -44,7 +43,7 @@ from .constants import APPLE_REDIRECT_HANDLER, AUTHORISATIONURL, SIGNINUP from .exceptions import SuperTokensThirdPartyError from .types import ThirdPartyIngredients -from .utils import ThirdPartyInputConfig, validate_and_normalise_user_input +from .utils import ThirdPartyConfig, validate_and_normalise_user_input class ThirdPartyRecipe(RecipeModule): @@ -55,11 +54,11 @@ def __init__( self, recipe_id: str, app_info: AppInfo, - input_config: ThirdPartyInputConfig, + config: ThirdPartyConfig, _ingredients: ThirdPartyIngredients, ): super().__init__(recipe_id, app_info) - self.config = validate_and_normalise_user_input(input_config=input_config) + self.config = validate_and_normalise_user_input(config=config) self.providers = self.config.sign_in_and_up_feature.providers recipe_implementation = RecipeImplementation( Querier.get_instance(recipe_id), self.providers @@ -158,9 +157,11 @@ def get_all_cors_headers(self) -> List[str]: @staticmethod def init( sign_in_and_up_feature: SignInAndUpFeature, - override: Union[InputOverrideConfig, None] = None, + override: Union[ThirdPartyOverrideConfig, None] = None, ): - input_config = ThirdPartyInputConfig( + from supertokens_python.plugins import OverrideMap, apply_plugins + + config = ThirdPartyConfig( sign_in_and_up_feature=sign_in_and_up_feature, override=override, ) @@ -172,9 +173,9 @@ def func(app_info: AppInfo, plugins: List[OverrideMap]): recipe_id=ThirdPartyRecipe.recipe_id, app_info=app_info, _ingredients=ingredients, - input_config=apply_plugins( + config=apply_plugins( recipe_id=ThirdPartyRecipe.recipe_id, - config=input_config, + config=config, plugins=plugins, ), ) diff --git a/supertokens_python/recipe/thirdparty/utils.py b/supertokens_python/recipe/thirdparty/utils.py index 1b4784ac6..baef7c4d9 100644 --- a/supertokens_python/recipe/thirdparty/utils.py +++ b/supertokens_python/recipe/thirdparty/utils.py @@ -15,23 +15,22 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set +from jwt import PyJWKClient, decode # type: ignore + from supertokens_python.exceptions import raise_bad_input_exception from supertokens_python.recipe.thirdparty.provider import ProviderInput from supertokens_python.types.config import ( BaseConfig, - BaseInputConfig, - BaseInputOverrideConfig, + BaseNormalisedConfig, + BaseNormalisedOverrideConfig, BaseOverrideConfig, ) -from supertokens_python.types.utils import UseDefaultIfNone from .interfaces import APIInterface, RecipeInterface if TYPE_CHECKING: from .provider import ProviderInput -from jwt import PyJWKClient, decode # type: ignore - class SignInAndUpFeature: def __init__(self, providers: Optional[List[ProviderInput]] = None): @@ -53,40 +52,38 @@ def __init__(self, providers: Optional[List[ProviderInput]] = None): self.providers = providers -class InputOverrideConfig(BaseInputOverrideConfig[RecipeInterface, APIInterface]): ... +ThirdPartyOverrideConfig = BaseOverrideConfig[RecipeInterface, APIInterface] +NormalisedThirdPartyOverrideConfig = BaseNormalisedOverrideConfig[ + RecipeInterface, APIInterface +] -class OverrideConfig(BaseOverrideConfig[RecipeInterface, APIInterface]): ... - - -class ThirdPartyInputConfig(BaseInputConfig[RecipeInterface, APIInterface]): +class ThirdPartyConfig(BaseConfig[RecipeInterface, APIInterface]): sign_in_and_up_feature: SignInAndUpFeature - override: UseDefaultIfNone[Optional[InputOverrideConfig]] = InputOverrideConfig() # type: ignore - https://github.com/microsoft/pyright/issues/5933 -class ThirdPartyConfig(BaseConfig[RecipeInterface, APIInterface]): +class NormalisedThirdPartyConfig(BaseNormalisedConfig[RecipeInterface, APIInterface]): sign_in_and_up_feature: SignInAndUpFeature - override: OverrideConfig # type: ignore - https://github.com/microsoft/pyright/issues/5933 def validate_and_normalise_user_input( - input_config: ThirdPartyInputConfig, -) -> ThirdPartyConfig: - if not isinstance(input_config.sign_in_and_up_feature, SignInAndUpFeature): # type: ignore + config: ThirdPartyConfig, +) -> NormalisedThirdPartyConfig: + if not isinstance(config.sign_in_and_up_feature, SignInAndUpFeature): # type: ignore raise ValueError( "sign_in_and_up_feature must be an instance of SignInAndUpFeature" ) - override_config = OverrideConfig() - if input_config.override is not None: - if input_config.override.functions is not None: - override_config.functions = input_config.override.functions + override_config = NormalisedThirdPartyOverrideConfig() + if config.override is not None: + if config.override.functions is not None: + override_config.functions = config.override.functions - if input_config.override.apis is not None: - override_config.apis = input_config.override.apis + if config.override.apis is not None: + override_config.apis = config.override.apis - return ThirdPartyConfig( - sign_in_and_up_feature=input_config.sign_in_and_up_feature, + return NormalisedThirdPartyConfig( + sign_in_and_up_feature=config.sign_in_and_up_feature, override=override_config, ) diff --git a/supertokens_python/recipe/totp/interfaces.py b/supertokens_python/recipe/totp/interfaces.py index 9ea482110..0cafe03a6 100644 --- a/supertokens_python/recipe/totp/interfaces.py +++ b/supertokens_python/recipe/totp/interfaces.py @@ -32,8 +32,8 @@ InvalidTOTPError, LimitReachedError, ListDevicesOkResult, + NormalisedTOTPConfig, RemoveDeviceOkResult, - TOTPNormalisedConfig, UnknownDeviceError, UnknownUserIdError, UpdateDeviceOkResult, @@ -131,7 +131,7 @@ def __init__( request: BaseRequest, response: BaseResponse, recipe_id: str, - config: TOTPNormalisedConfig, + config: NormalisedTOTPConfig, recipe_implementation: RecipeInterface, app_info: AppInfo, recipe_instance: TOTPRecipe, diff --git a/supertokens_python/recipe/totp/recipe.py b/supertokens_python/recipe/totp/recipe.py index ea2013862..855e615b9 100644 --- a/supertokens_python/recipe/totp/recipe.py +++ b/supertokens_python/recipe/totp/recipe.py @@ -14,10 +14,9 @@ from __future__ import annotations from os import environ -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, List, Union from supertokens_python.normalised_url_path import NormalisedURLPath -from supertokens_python.plugins import OverrideMap, apply_plugins from supertokens_python.querier import Querier from supertokens_python.recipe.multifactorauth.types import ( GetAllAvailableSecondaryFactorIdsFromOtherRecipesFunc, @@ -202,10 +201,12 @@ def get_all_cors_headers(self) -> List[str]: def init( config: Union[TOTPConfig, None] = None, ): - def func(app_info: AppInfo, plugins: Optional[List[OverrideMap]] = None): - if plugins is None: - plugins = [] + from supertokens_python.plugins import OverrideMap, apply_plugins + if config is None: + config = TOTPConfig() + + def func(app_info: AppInfo, plugins: List[OverrideMap]): if TOTPRecipe.__instance is None: TOTPRecipe.__instance = TOTPRecipe( recipe_id=TOTPRecipe.recipe_id, diff --git a/supertokens_python/recipe/totp/recipe_implementation.py b/supertokens_python/recipe/totp/recipe_implementation.py index 93ce02a2a..3a16f399b 100644 --- a/supertokens_python/recipe/totp/recipe_implementation.py +++ b/supertokens_python/recipe/totp/recipe_implementation.py @@ -29,8 +29,8 @@ InvalidTOTPError, LimitReachedError, ListDevicesOkResult, + NormalisedTOTPConfig, RemoveDeviceOkResult, - TOTPNormalisedConfig, UnknownDeviceError, UnknownUserIdError, UpdateDeviceOkResult, @@ -48,7 +48,7 @@ class RecipeImplementation(RecipeInterface): def __init__( self, querier: Querier, - config: TOTPNormalisedConfig, + config: NormalisedTOTPConfig, ): super().__init__() self.querier = querier diff --git a/supertokens_python/recipe/totp/types.py b/supertokens_python/recipe/totp/types.py index cceb22b18..e9b521766 100644 --- a/supertokens_python/recipe/totp/types.py +++ b/supertokens_python/recipe/totp/types.py @@ -18,12 +18,11 @@ from supertokens_python.types.config import ( BaseConfig, - BaseInputConfig, - BaseInputOverrideConfig, + BaseNormalisedConfig, + BaseNormalisedOverrideConfig, BaseOverrideConfig, ) from supertokens_python.types.response import APIResponse -from supertokens_python.types.utils import UseDefaultIfNone from .interfaces import APIInterface, RecipeInterface @@ -184,21 +183,19 @@ def to_json(self) -> Dict[str, Any]: return {"status": self.status} -class OverrideConfig(BaseInputOverrideConfig[RecipeInterface, APIInterface]): ... +TOTPOverrideConfig = BaseOverrideConfig[RecipeInterface, APIInterface] +NormalisedTOTPOverrideConfig = BaseNormalisedOverrideConfig[ + RecipeInterface, APIInterface +] -class NormalisedOverrideConfig(BaseOverrideConfig[RecipeInterface, APIInterface]): ... - - -class TOTPConfig(BaseInputConfig[RecipeInterface, APIInterface]): +class TOTPConfig(BaseConfig[RecipeInterface, APIInterface]): issuer: Optional[str] = None default_skew: Optional[int] = None default_period: Optional[int] = None - override: UseDefaultIfNone[Optional[OverrideConfig]] = OverrideConfig() # type: ignore - https://github.com/microsoft/pyright/issues/5933 -class TOTPNormalisedConfig(BaseConfig[RecipeInterface, APIInterface]): +class NormalisedTOTPConfig(BaseNormalisedConfig[RecipeInterface, APIInterface]): issuer: str default_skew: int default_period: int - override: NormalisedOverrideConfig # type: ignore - https://github.com/microsoft/pyright/issues/5933 diff --git a/supertokens_python/recipe/totp/utils.py b/supertokens_python/recipe/totp/utils.py index 5d3e04e1b..12b49ba08 100644 --- a/supertokens_python/recipe/totp/utils.py +++ b/supertokens_python/recipe/totp/utils.py @@ -17,15 +17,15 @@ from supertokens_python import AppInfo from .types import ( - NormalisedOverrideConfig, + NormalisedTOTPConfig, + NormalisedTOTPOverrideConfig, TOTPConfig, - TOTPNormalisedConfig, ) def validate_and_normalise_user_input( app_info: AppInfo, config: Union[TOTPConfig, None] -) -> TOTPNormalisedConfig: +) -> NormalisedTOTPConfig: if config is None: config = TOTPConfig() @@ -33,7 +33,7 @@ def validate_and_normalise_user_input( default_skew = config.default_skew if config.default_skew is not None else 1 default_period = config.default_period if config.default_period is not None else 30 - override_config = NormalisedOverrideConfig() + override_config = NormalisedTOTPOverrideConfig() if config.override is not None: if config.override.functions is not None: override_config.functions = config.override.functions @@ -41,7 +41,7 @@ def validate_and_normalise_user_input( if config.override.apis is not None: override_config.apis = config.override.apis - return TOTPNormalisedConfig( + return NormalisedTOTPConfig( issuer=issuer, default_skew=default_skew, default_period=default_period, diff --git a/supertokens_python/recipe/usermetadata/__init__.py b/supertokens_python/recipe/usermetadata/__init__.py index ae632b781..0f81b5a93 100644 --- a/supertokens_python/recipe/usermetadata/__init__.py +++ b/supertokens_python/recipe/usermetadata/__init__.py @@ -23,6 +23,6 @@ def init( - override: Union[utils.InputOverrideConfig, None] = None, + override: Union[utils.UserMetadataOverrideConfig, None] = None, ) -> RecipeInit: return UserMetadataRecipe.init(override) diff --git a/supertokens_python/recipe/usermetadata/recipe.py b/supertokens_python/recipe/usermetadata/recipe.py index c75894ce9..fc5437b18 100644 --- a/supertokens_python/recipe/usermetadata/recipe.py +++ b/supertokens_python/recipe/usermetadata/recipe.py @@ -20,7 +20,6 @@ from supertokens_python.exceptions import SuperTokensError, raise_general_exception from supertokens_python.framework import BaseRequest, BaseResponse from supertokens_python.normalised_url_path import NormalisedURLPath -from supertokens_python.plugins import OverrideMap, apply_plugins from supertokens_python.querier import Querier from supertokens_python.recipe.usermetadata.exceptions import ( SuperTokensUserMetadataError, @@ -36,7 +35,7 @@ if TYPE_CHECKING: from supertokens_python.supertokens import AppInfo -from .utils import InputOverrideConfig, UserMetadataInputConfig +from .utils import UserMetadataConfig, UserMetadataOverrideConfig class UserMetadataRecipe(RecipeModule): @@ -47,11 +46,11 @@ def __init__( self, recipe_id: str, app_info: AppInfo, - input_config: UserMetadataInputConfig, + config: UserMetadataConfig, ): super().__init__(recipe_id, app_info) self.config = validate_and_normalise_user_input( - _recipe=self, _app_info=app_info, input_config=input_config + _recipe=self, _app_info=app_info, input_config=config ) recipe_implementation = RecipeImplementation(Querier.get_instance(recipe_id)) self.recipe_implementation = self.config.override.functions( @@ -91,17 +90,19 @@ def get_all_cors_headers(self) -> List[str]: return [] @staticmethod - def init(override: Union[InputOverrideConfig, None] = None): - input_config = UserMetadataInputConfig(override=override) + def init(override: Union[UserMetadataOverrideConfig, None] = None): + from supertokens_python.plugins import OverrideMap, apply_plugins + + config = UserMetadataConfig(override=override) def func(app_info: AppInfo, plugins: List[OverrideMap]): if UserMetadataRecipe.__instance is None: UserMetadataRecipe.__instance = UserMetadataRecipe( recipe_id=UserMetadataRecipe.recipe_id, app_info=app_info, - input_config=apply_plugins( + config=apply_plugins( recipe_id=UserMetadataRecipe.recipe_id, - config=input_config, + config=config, plugins=plugins, ), ) diff --git a/supertokens_python/recipe/usermetadata/utils.py b/supertokens_python/recipe/usermetadata/utils.py index c8bf260bd..385cdf782 100644 --- a/supertokens_python/recipe/usermetadata/utils.py +++ b/supertokens_python/recipe/usermetadata/utils.py @@ -14,7 +14,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from supertokens_python.recipe.usermetadata.interfaces import ( APIInterface, @@ -22,37 +22,36 @@ ) from supertokens_python.types.config import ( BaseConfig, - BaseInputConfig, - BaseInputOverrideConfig, + BaseNormalisedConfig, + BaseNormalisedOverrideConfig, BaseOverrideConfig, ) -from supertokens_python.types.utils import UseDefaultIfNone if TYPE_CHECKING: from supertokens_python.recipe.usermetadata.recipe import UserMetadataRecipe from supertokens_python.supertokens import AppInfo -class InputOverrideConfig(BaseInputOverrideConfig[RecipeInterface, APIInterface]): ... +UserMetadataOverrideConfig = BaseOverrideConfig[RecipeInterface, APIInterface] +NormalisedUserMetadataOverrideConfig = BaseNormalisedOverrideConfig[ + RecipeInterface, APIInterface +] -class OverrideConfig(BaseOverrideConfig[RecipeInterface, APIInterface]): ... +class UserMetadataConfig(BaseConfig[RecipeInterface, APIInterface]): ... -class UserMetadataInputConfig(BaseInputConfig[RecipeInterface, APIInterface]): - override: UseDefaultIfNone[Optional[InputOverrideConfig]] = InputOverrideConfig() # type: ignore - https://github.com/microsoft/pyright/issues/5933 - - -class UserMetadataConfig(BaseConfig[RecipeInterface, APIInterface]): - override: OverrideConfig # type: ignore - https://github.com/microsoft/pyright/issues/5933 +class NormalisedUserMetadataConfig( + BaseNormalisedConfig[RecipeInterface, APIInterface] +): ... def validate_and_normalise_user_input( _recipe: UserMetadataRecipe, _app_info: AppInfo, - input_config: UserMetadataInputConfig, -) -> UserMetadataConfig: - override_config = OverrideConfig() + input_config: UserMetadataConfig, +) -> NormalisedUserMetadataConfig: + override_config = NormalisedUserMetadataOverrideConfig() if input_config.override is not None: if input_config.override.functions is not None: override_config.functions = input_config.override.functions @@ -60,4 +59,4 @@ def validate_and_normalise_user_input( if input_config.override.apis is not None: override_config.apis = input_config.override.apis - return UserMetadataConfig(override=override_config) + return NormalisedUserMetadataConfig(override=override_config) diff --git a/supertokens_python/recipe/userroles/__init__.py b/supertokens_python/recipe/userroles/__init__.py index 52a88d08e..c3f790100 100644 --- a/supertokens_python/recipe/userroles/__init__.py +++ b/supertokens_python/recipe/userroles/__init__.py @@ -28,7 +28,7 @@ def init( skip_adding_roles_to_access_token: Optional[bool] = None, skip_adding_permissions_to_access_token: Optional[bool] = None, - override: Union[utils.InputOverrideConfig, None] = None, + override: Union[utils.UserRolesOverrideConfig, None] = None, ) -> RecipeInit: return UserRolesRecipe.init( skip_adding_roles_to_access_token, diff --git a/supertokens_python/recipe/userroles/recipe.py b/supertokens_python/recipe/userroles/recipe.py index 6b62d744d..a8971d747 100644 --- a/supertokens_python/recipe/userroles/recipe.py +++ b/supertokens_python/recipe/userroles/recipe.py @@ -20,7 +20,6 @@ from supertokens_python.exceptions import SuperTokensError, raise_general_exception from supertokens_python.framework import BaseRequest, BaseResponse from supertokens_python.normalised_url_path import NormalisedURLPath -from supertokens_python.plugins import OverrideMap, apply_plugins from supertokens_python.querier import Querier from supertokens_python.recipe.session.asyncio import get_session_information from supertokens_python.recipe.userroles.recipe_implementation import ( @@ -36,7 +35,7 @@ from ..session.claim_base_classes.primitive_array_claim import PrimitiveArrayClaim from .exceptions import SuperTokensUserRolesError from .interfaces import GetPermissionsForRoleOkResult, UnknownRoleError -from .utils import InputOverrideConfig, UserRolesInputConfig +from .utils import UserRolesConfig, UserRolesOverrideConfig class UserRolesRecipe(RecipeModule): @@ -47,7 +46,7 @@ def __init__( self, recipe_id: str, app_info: AppInfo, - input_config: UserRolesInputConfig, + config: UserRolesConfig, ): from ..oauth2provider.recipe import OAuth2ProviderRecipe @@ -55,7 +54,7 @@ def __init__( self.config = validate_and_normalise_user_input( _recipe=self, _app_info=app_info, - input_config=input_config, + config=config, ) recipe_implementation = RecipeImplementation(Querier.get_instance(recipe_id)) self.recipe_implementation = self.config.override.functions( @@ -211,9 +210,11 @@ def get_all_cors_headers(self) -> List[str]: def init( skip_adding_roles_to_access_token: Optional[bool] = None, skip_adding_permissions_to_access_token: Optional[bool] = None, - override: Union[InputOverrideConfig, None] = None, + override: Union[UserRolesOverrideConfig, None] = None, ): - input_config = UserRolesInputConfig( + from supertokens_python.plugins import OverrideMap, apply_plugins + + config = UserRolesConfig( skip_adding_roles_to_access_token=skip_adding_roles_to_access_token, skip_adding_permissions_to_access_token=skip_adding_permissions_to_access_token, override=override, @@ -224,9 +225,9 @@ def func(app_info: AppInfo, plugins: List[OverrideMap]): UserRolesRecipe.__instance = UserRolesRecipe( recipe_id=UserRolesRecipe.recipe_id, app_info=app_info, - input_config=apply_plugins( + config=apply_plugins( recipe_id=UserRolesRecipe.recipe_id, - config=input_config, + config=config, plugins=plugins, ), ) diff --git a/supertokens_python/recipe/userroles/utils.py b/supertokens_python/recipe/userroles/utils.py index 707640e76..8680481ed 100644 --- a/supertokens_python/recipe/userroles/utils.py +++ b/supertokens_python/recipe/userroles/utils.py @@ -20,61 +20,55 @@ from supertokens_python.supertokens import AppInfo from supertokens_python.types.config import ( BaseConfig, - BaseInputConfig, - BaseInputOverrideConfig, + BaseNormalisedConfig, + BaseNormalisedOverrideConfig, BaseOverrideConfig, ) -from supertokens_python.types.utils import UseDefaultIfNone if TYPE_CHECKING: from supertokens_python.recipe.userroles.recipe import UserRolesRecipe -class InputOverrideConfig(BaseInputOverrideConfig[RecipeInterface, APIInterface]): ... +UserRolesOverrideConfig = BaseOverrideConfig[RecipeInterface, APIInterface] +NormalisedUserRolesOverrideConfig = BaseNormalisedOverrideConfig[ + RecipeInterface, APIInterface +] -class OverrideConfig(BaseOverrideConfig[RecipeInterface, APIInterface]): ... - - -class UserRolesInputConfig(BaseInputConfig[RecipeInterface, APIInterface]): +class UserRolesConfig(BaseConfig[RecipeInterface, APIInterface]): skip_adding_roles_to_access_token: Optional[bool] = None skip_adding_permissions_to_access_token: Optional[bool] = None - override: UseDefaultIfNone[Optional[InputOverrideConfig]] = InputOverrideConfig() # type: ignore - https://github.com/microsoft/pyright/issues/5933 -class UserRolesConfig(BaseConfig[RecipeInterface, APIInterface]): +class NormalisedUserRolesConfig(BaseNormalisedConfig[RecipeInterface, APIInterface]): skip_adding_roles_to_access_token: bool skip_adding_permissions_to_access_token: bool - override: OverrideConfig # type: ignore - https://github.com/microsoft/pyright/issues/5933 def validate_and_normalise_user_input( _recipe: UserRolesRecipe, _app_info: AppInfo, - input_config: UserRolesInputConfig, - # skip_adding_roles_to_access_token: Optional[bool] = None, - # skip_adding_permissions_to_access_token: Optional[bool] = None, - # override: Union[InputOverrideConfig, None] = None, -) -> UserRolesConfig: - override_config = OverrideConfig() - if input_config.override is not None: - if input_config.override.functions is not None: - override_config.functions = input_config.override.functions - - if input_config.override.apis is not None: - override_config.apis = input_config.override.apis - - skip_adding_roles_to_access_token = input_config.skip_adding_roles_to_access_token + config: UserRolesConfig, +) -> NormalisedUserRolesConfig: + override_config = NormalisedUserRolesOverrideConfig() + if config.override is not None: + if config.override.functions is not None: + override_config.functions = config.override.functions + + if config.override.apis is not None: + override_config.apis = config.override.apis + + skip_adding_roles_to_access_token = config.skip_adding_roles_to_access_token if skip_adding_roles_to_access_token is None: skip_adding_roles_to_access_token = False skip_adding_permissions_to_access_token = ( - input_config.skip_adding_permissions_to_access_token + config.skip_adding_permissions_to_access_token ) if skip_adding_permissions_to_access_token is None: skip_adding_permissions_to_access_token = False - return UserRolesConfig( + return NormalisedUserRolesConfig( skip_adding_roles_to_access_token=skip_adding_roles_to_access_token, skip_adding_permissions_to_access_token=skip_adding_permissions_to_access_token, override=override_config, diff --git a/supertokens_python/recipe/webauthn/__init__.py b/supertokens_python/recipe/webauthn/__init__.py index 8f48fba47..7ad1a28a7 100644 --- a/supertokens_python/recipe/webauthn/__init__.py +++ b/supertokens_python/recipe/webauthn/__init__.py @@ -39,8 +39,8 @@ from supertokens_python.recipe.webauthn.recipe import WebauthnRecipe from supertokens_python.recipe.webauthn.types.config import ( NormalisedWebauthnConfig, - OverrideConfig, WebauthnConfig, + WebauthnOverrideConfig, ) # Some Pydantic models need a rebuild to resolve ForwardRefs @@ -63,7 +63,7 @@ def init(config: Optional[WebauthnConfig] = None): "init", "APIInterface", "RecipeInterface", - "OverrideConfig", + "WebauthnOverrideConfig", "WebauthnConfig", "WebauthnRecipe", "consume_recover_account_token", diff --git a/supertokens_python/recipe/webauthn/recipe.py b/supertokens_python/recipe/webauthn/recipe.py index bed22a345..455319e96 100644 --- a/supertokens_python/recipe/webauthn/recipe.py +++ b/supertokens_python/recipe/webauthn/recipe.py @@ -21,7 +21,6 @@ from supertokens_python.framework.response import BaseResponse from supertokens_python.ingredients.emaildelivery import EmailDeliveryIngredient from supertokens_python.normalised_url_path import NormalisedURLPath -from supertokens_python.plugins import OverrideMap, apply_plugins from supertokens_python.post_init_callbacks import PostSTInitCallbacks from supertokens_python.querier import Querier from supertokens_python.recipe.multifactorauth.recipe import MultiFactorAuthRecipe @@ -296,6 +295,11 @@ def get_instance_optional() -> Optional["WebauthnRecipe"]: @staticmethod def init(config: Optional[WebauthnConfig]): + from supertokens_python.plugins import OverrideMap, apply_plugins + + if config is None: + config = WebauthnConfig() + def func(app_info: AppInfo, plugins: List[OverrideMap]): if WebauthnRecipe.__instance is None: WebauthnRecipe.__instance = WebauthnRecipe( diff --git a/supertokens_python/recipe/webauthn/types/config.py b/supertokens_python/recipe/webauthn/types/config.py index 8e3835b12..d2132cf1d 100644 --- a/supertokens_python/recipe/webauthn/types/config.py +++ b/supertokens_python/recipe/webauthn/types/config.py @@ -30,12 +30,11 @@ from supertokens_python.types.base import UserContext from supertokens_python.types.config import ( BaseConfig, - BaseInputConfig, - BaseInputOverrideConfig, + BaseNormalisedConfig, + BaseNormalisedOverrideConfig, BaseOverrideConfig, ) from supertokens_python.types.response import CamelCaseBaseModel -from supertokens_python.types.utils import UseDefaultIfNone InterfaceType = TypeVar("InterfaceType") """Generic Type for use in `InterfaceOverride`""" @@ -184,28 +183,26 @@ def __call__( ) -> InterfaceType: ... -class OverrideConfig(BaseInputOverrideConfig[RecipeInterface, APIInterface]): ... +WebauthnOverrideConfig = BaseOverrideConfig[RecipeInterface, APIInterface] +NormalisedWebauthnOverrideConfig = BaseNormalisedOverrideConfig[ + RecipeInterface, APIInterface +] -class NormalisedOverrideConfig(BaseOverrideConfig[RecipeInterface, APIInterface]): ... - - -class WebauthnConfig(BaseInputConfig[RecipeInterface, APIInterface]): +class WebauthnConfig(BaseConfig[RecipeInterface, APIInterface]): get_relying_party_id: Optional[Union[str, GetRelyingPartyId]] = None get_relying_party_name: Optional[Union[str, GetRelyingPartyName]] = None get_origin: Optional[GetOrigin] = None email_delivery: Optional[EmailDeliveryConfig[TypeWebauthnEmailDeliveryInput]] = None validate_email_address: Optional[ValidateEmailAddress] = None - override: UseDefaultIfNone[Optional[OverrideConfig]] = OverrideConfig() # type: ignore - https://github.com/microsoft/pyright/issues/5933 -class NormalisedWebauthnConfig(BaseConfig[RecipeInterface, APIInterface]): +class NormalisedWebauthnConfig(BaseNormalisedConfig[RecipeInterface, APIInterface]): get_relying_party_id: NormalisedGetRelyingPartyId get_relying_party_name: NormalisedGetRelyingPartyName get_origin: NormalisedGetOrigin get_email_delivery_config: NormalisedGetEmailDeliveryConfig validate_email_address: NormalisedValidateEmailAddress - override: NormalisedOverrideConfig # type: ignore - https://github.com/microsoft/pyright/issues/5933 class WebauthnIngredients(CamelCaseBaseModel): diff --git a/supertokens_python/recipe/webauthn/utils.py b/supertokens_python/recipe/webauthn/utils.py index dc0499354..187494203 100644 --- a/supertokens_python/recipe/webauthn/utils.py +++ b/supertokens_python/recipe/webauthn/utils.py @@ -33,9 +33,9 @@ NormalisedGetOrigin, NormalisedGetRelyingPartyId, NormalisedGetRelyingPartyName, - NormalisedOverrideConfig, NormalisedValidateEmailAddress, NormalisedWebauthnConfig, + NormalisedWebauthnOverrideConfig, ValidateEmailAddress, WebauthnConfig, ) @@ -60,7 +60,7 @@ def validate_and_normalise_user_input( config.validate_email_address ) - override_config = NormalisedOverrideConfig() + override_config = NormalisedWebauthnOverrideConfig() if config.override is not None: if config.override.functions is not None: override_config.functions = config.override.functions diff --git a/supertokens_python/supertokens.py b/supertokens_python/supertokens.py index 87deed255..ad6aa9c38 100644 --- a/supertokens_python/supertokens.py +++ b/supertokens_python/supertokens.py @@ -36,13 +36,6 @@ get_maybe_none_as_str, log_debug_message, ) -from supertokens_python.plugins import ( - OverrideMap, - PluginRouteHandler, - SuperTokensPlugin, - SuperTokensPublicPlugin, - load_plugins, -) from .constants import FDI_KEY_HEADER, RID_KEY_HEADER, USER_COUNT from .exceptions import SuperTokensError @@ -71,6 +64,13 @@ if TYPE_CHECKING: from supertokens_python.framework.request import BaseRequest from supertokens_python.framework.response import BaseResponse + from supertokens_python.plugins import ( + OverrideMap, + PluginRouteHandler, + SuperTokensPlugin, + SuperTokensPublicPlugin, + load_plugins, + ) from supertokens_python.recipe.session import SessionContainer from .recipe_module import RecipeModule diff --git a/supertokens_python/types/config.py b/supertokens_python/types/config.py index 8ad5c3d95..10752a02d 100644 --- a/supertokens_python/types/config.py +++ b/supertokens_python/types/config.py @@ -6,9 +6,6 @@ T = TypeVar("T") -# InterfaceType = TypeVar( -# "InterfaceType", bound=Union[BaseRecipeInterface, BaseAPIInterface], covariant=True -# ) """Generic Type for use in `InterfaceOverride`""" FunctionInterfaceType = TypeVar("FunctionInterfaceType", bound=BaseRecipeInterface) """Generic Type for use in `FunctionOverrideConfig`""" @@ -19,9 +16,7 @@ InterfaceOverride = Callable[[T], T] -class BaseInputOverrideConfigWithoutAPI( - CamelCaseBaseModel, Generic[FunctionInterfaceType] -): +class OverrideConfigWithoutAPI(CamelCaseBaseModel, Generic[FunctionInterfaceType]): """Base class for input override config without API overrides.""" functions: UseDefaultIfNone[Optional[InterfaceOverride[FunctionInterfaceType]]] = ( @@ -29,7 +24,9 @@ class BaseInputOverrideConfigWithoutAPI( ) -class BaseOverrideConfigWithoutAPI(CamelCaseBaseModel, Generic[FunctionInterfaceType]): +class NormalisedOverrideConfigWithoutAPI( + CamelCaseBaseModel, Generic[FunctionInterfaceType] +): """Base class for normalized override config without API overrides.""" functions: InterfaceOverride[FunctionInterfaceType] = ( @@ -37,8 +34,8 @@ class BaseOverrideConfigWithoutAPI(CamelCaseBaseModel, Generic[FunctionInterface ) -class BaseInputOverrideConfig( - BaseInputOverrideConfigWithoutAPI[FunctionInterfaceType], +class BaseOverrideConfig( + OverrideConfigWithoutAPI[FunctionInterfaceType], Generic[FunctionInterfaceType, APIInterfaceType], ): """Base class for input override config with API overrides.""" @@ -48,8 +45,8 @@ class BaseInputOverrideConfig( ) -class BaseOverrideConfig( - BaseOverrideConfigWithoutAPI[FunctionInterfaceType], +class BaseNormalisedOverrideConfig( + NormalisedOverrideConfigWithoutAPI[FunctionInterfaceType], Generic[FunctionInterfaceType, APIInterfaceType], ): """Base class for normalized override config with API overrides.""" @@ -59,31 +56,31 @@ class BaseOverrideConfig( ) -class BaseInputConfigWithoutAPIOverride( - CamelCaseBaseModel, Generic[FunctionInterfaceType] -): +class BaseConfigWithoutAPIOverride(CamelCaseBaseModel, Generic[FunctionInterfaceType]): """Base class for input config of a Recipe without API overrides.""" - override: Optional[BaseInputOverrideConfigWithoutAPI[FunctionInterfaceType]] = None + override: Optional[OverrideConfigWithoutAPI[FunctionInterfaceType]] = None -class BaseConfigWithoutAPIOverride(CamelCaseBaseModel, Generic[FunctionInterfaceType]): +class BaseNormalisedConfigWithoutAPIOverride( + CamelCaseBaseModel, Generic[FunctionInterfaceType] +): """Base class for normalized config of a Recipe without API overrides.""" - override: BaseOverrideConfigWithoutAPI[FunctionInterfaceType] + override: NormalisedOverrideConfigWithoutAPI[FunctionInterfaceType] -class BaseInputConfig( - CamelCaseBaseModel, Generic[FunctionInterfaceType, APIInterfaceType] -): +class BaseConfig(CamelCaseBaseModel, Generic[FunctionInterfaceType, APIInterfaceType]): """Base class for input config of a Recipe with API overrides.""" - override: Optional[ - BaseInputOverrideConfig[FunctionInterfaceType, APIInterfaceType] - ] = None + override: Optional[BaseOverrideConfig[FunctionInterfaceType, APIInterfaceType]] = ( + None + ) -class BaseConfig(CamelCaseBaseModel, Generic[FunctionInterfaceType, APIInterfaceType]): +class BaseNormalisedConfig( + CamelCaseBaseModel, Generic[FunctionInterfaceType, APIInterfaceType] +): """Base class for normalized config of a Recipe with API overrides.""" - override: BaseOverrideConfig[FunctionInterfaceType, APIInterfaceType] + override: BaseNormalisedOverrideConfig[FunctionInterfaceType, APIInterfaceType] diff --git a/tests/Django/test_django.py b/tests/Django/test_django.py index f294af2b5..c0036c607 100644 --- a/tests/Django/test_django.py +++ b/tests/Django/test_django.py @@ -32,9 +32,9 @@ ) from supertokens_python.querier import Querier from supertokens_python.recipe import emailpassword, session, thirdparty -from supertokens_python.recipe.dashboard import DashboardRecipe, InputOverrideConfig +from supertokens_python.recipe.dashboard import DashboardOverrideConfig, DashboardRecipe from supertokens_python.recipe.dashboard.interfaces import RecipeInterface -from supertokens_python.recipe.dashboard.utils import DashboardConfig +from supertokens_python.recipe.dashboard.utils import NormalisedDashboardConfig from supertokens_python.recipe.emailpassword.interfaces import APIInterface, APIOptions from supertokens_python.recipe.passwordless import ContactConfig, PasswordlessRecipe from supertokens_python.recipe.session import SessionContainer @@ -53,7 +53,7 @@ def override_dashboard_functions(original_implementation: RecipeInterface): async def should_allow_access( - request: BaseRequest, __: DashboardConfig, ___: Dict[str, Any] + request: BaseRequest, __: NormalisedDashboardConfig, ___: Dict[str, Any] ): auth_header = request.get_header("authorization") return auth_header == "Bearer testapikey" @@ -370,7 +370,7 @@ async def email_exists_get( mode="asgi", recipe_list=[ emailpassword.init( - override=emailpassword.InputOverrideConfig( + override=emailpassword.EmailPasswordOverrideConfig( apis=override_email_password_apis ) ) @@ -497,7 +497,7 @@ async def test_search_with_multiple_emails(self): ), DashboardRecipe.init( api_key="testapikey", - override=InputOverrideConfig( + override=DashboardOverrideConfig( functions=override_dashboard_functions ), ), @@ -548,7 +548,7 @@ async def test_search_with_email_t(self): ), DashboardRecipe.init( api_key="testapikey", - override=InputOverrideConfig( + override=DashboardOverrideConfig( functions=override_dashboard_functions ), ), @@ -597,7 +597,7 @@ async def test_search_with_email_iresh(self): ), DashboardRecipe.init( api_key="testapikey", - override=InputOverrideConfig( + override=DashboardOverrideConfig( functions=override_dashboard_functions ), ), @@ -648,7 +648,7 @@ async def test_search_with_phone_plus_one(self): ), DashboardRecipe.init( api_key="testapikey", - override=InputOverrideConfig( + override=DashboardOverrideConfig( functions=override_dashboard_functions ), ), @@ -702,7 +702,7 @@ async def test_search_with_phone_one_bracket(self): ), DashboardRecipe.init( api_key="testapikey", - override=InputOverrideConfig( + override=DashboardOverrideConfig( functions=override_dashboard_functions ), ), @@ -799,7 +799,7 @@ async def test_search_with_provider_google(self): ), DashboardRecipe.init( api_key="testapikey", - override=InputOverrideConfig( + override=DashboardOverrideConfig( functions=override_dashboard_functions ), ), @@ -892,7 +892,7 @@ async def test_search_with_provider_google_and_phone_one(self): ), DashboardRecipe.init( api_key="testapikey", - override=InputOverrideConfig( + override=DashboardOverrideConfig( functions=override_dashboard_functions ), ), diff --git a/tests/Fastapi/test_fastapi.py b/tests/Fastapi/test_fastapi.py index 4741713d5..b9c35c01a 100644 --- a/tests/Fastapi/test_fastapi.py +++ b/tests/Fastapi/test_fastapi.py @@ -22,9 +22,9 @@ from supertokens_python.framework.fastapi import get_middleware from supertokens_python.querier import Querier from supertokens_python.recipe import emailpassword, session, thirdparty -from supertokens_python.recipe.dashboard import DashboardRecipe, InputOverrideConfig +from supertokens_python.recipe.dashboard import DashboardOverrideConfig, DashboardRecipe from supertokens_python.recipe.dashboard.interfaces import RecipeInterface -from supertokens_python.recipe.dashboard.utils import DashboardConfig +from supertokens_python.recipe.dashboard.utils import NormalisedDashboardConfig from supertokens_python.recipe.emailpassword.interfaces import ( APIInterface as EPAPIInterface, ) @@ -60,7 +60,7 @@ def override_dashboard_functions(original_implementation: RecipeInterface): async def should_allow_access( - request: BaseRequest, __: DashboardConfig, ___: Dict[str, Any] + request: BaseRequest, __: NormalisedDashboardConfig, ___: Dict[str, Any] ): auth_header = request.get_header("authorization") return auth_header == "Bearer testapikey" @@ -157,7 +157,7 @@ async def test_login_refresh(driver_config_client: TestClient): anti_csrf="VIA_TOKEN", cookie_domain="supertokens.io", get_token_transfer_method=lambda _, __, ___: "cookie", - override=session.InputOverrideConfig(apis=apis_override_session), + override=session.SessionOverrideConfig(apis=apis_override_session), ) ], ) @@ -462,7 +462,7 @@ async def email_exists_get( framework="fastapi", recipe_list=[ emailpassword.init( - override=emailpassword.InputOverrideConfig( + override=emailpassword.EmailPasswordOverrideConfig( apis=override_email_password_apis ) ) @@ -570,7 +570,7 @@ async def refresh_post( recipe_list=[ session.init( anti_csrf="VIA_TOKEN", - override=session.InputOverrideConfig(apis=override_session_apis), + override=session.SessionOverrideConfig(apis=override_session_apis), ) ], ) @@ -686,7 +686,9 @@ async def test_search_with_email_t(driver_config_client: TestClient): ), DashboardRecipe.init( api_key="testapikey", - override=InputOverrideConfig(functions=override_dashboard_functions), + override=DashboardOverrideConfig( + functions=override_dashboard_functions + ), ), emailpassword.init(), ], @@ -731,7 +733,9 @@ async def test_search_with_email_multiple_email_entry(driver_config_client: Test ), DashboardRecipe.init( api_key="testapikey", - override=InputOverrideConfig(functions=override_dashboard_functions), + override=DashboardOverrideConfig( + functions=override_dashboard_functions + ), ), emailpassword.init(), ], @@ -776,7 +780,9 @@ async def test_search_with_email_iresh(driver_config_client: TestClient): ), DashboardRecipe.init( api_key="testapikey", - override=InputOverrideConfig(functions=override_dashboard_functions), + override=DashboardOverrideConfig( + functions=override_dashboard_functions + ), ), emailpassword.init(), ], @@ -821,7 +827,9 @@ async def test_search_with_phone_plus_one(driver_config_client: TestClient): ), DashboardRecipe.init( api_key="testapikey", - override=InputOverrideConfig(functions=override_dashboard_functions), + override=DashboardOverrideConfig( + functions=override_dashboard_functions + ), ), PasswordlessRecipe.init( contact_config=ContactConfig(contact_method="EMAIL"), @@ -869,7 +877,9 @@ async def test_search_with_phone_one_bracket(driver_config_client: TestClient): ), DashboardRecipe.init( api_key="testapikey", - override=InputOverrideConfig(functions=override_dashboard_functions), + override=DashboardOverrideConfig( + functions=override_dashboard_functions + ), ), PasswordlessRecipe.init( contact_config=ContactConfig(contact_method="EMAIL"), @@ -917,7 +927,9 @@ async def test_search_with_provider_google(driver_config_client: TestClient): ), DashboardRecipe.init( api_key="testapikey", - override=InputOverrideConfig(functions=override_dashboard_functions), + override=DashboardOverrideConfig( + functions=override_dashboard_functions + ), ), thirdparty.init( sign_in_and_up_feature=thirdparty.SignInAndUpFeature( @@ -1006,7 +1018,9 @@ async def test_search_with_provider_google_and_phone_1( ), DashboardRecipe.init( api_key="testapikey", - override=InputOverrideConfig(functions=override_dashboard_functions), + override=DashboardOverrideConfig( + functions=override_dashboard_functions + ), ), PasswordlessRecipe.init( contact_config=ContactConfig(contact_method="EMAIL"), diff --git a/tests/Flask/test_flask.py b/tests/Flask/test_flask.py index bdcc50658..40671df58 100644 --- a/tests/Flask/test_flask.py +++ b/tests/Flask/test_flask.py @@ -28,9 +28,9 @@ ) from supertokens_python.querier import Querier from supertokens_python.recipe import emailpassword, session, thirdparty -from supertokens_python.recipe.dashboard import DashboardRecipe, InputOverrideConfig +from supertokens_python.recipe.dashboard import DashboardOverrideConfig, DashboardRecipe from supertokens_python.recipe.dashboard.interfaces import RecipeInterface -from supertokens_python.recipe.dashboard.utils import DashboardConfig +from supertokens_python.recipe.dashboard.utils import NormalisedDashboardConfig from supertokens_python.recipe.emailpassword.interfaces import APIInterface, APIOptions from supertokens_python.recipe.passwordless import ContactConfig, PasswordlessRecipe from supertokens_python.recipe.session import SessionContainer @@ -81,7 +81,7 @@ async def email_exists_get( def override_dashboard_functions(original_implementation: RecipeInterface): async def should_allow_access( - request: BaseRequest, __: DashboardConfig, ___: Dict[str, Any] + request: BaseRequest, __: NormalisedDashboardConfig, ___: Dict[str, Any] ): auth_header = request.get_header("authorization") return auth_header == "Bearer testapikey" @@ -110,7 +110,7 @@ async def should_allow_access( get_token_transfer_method=lambda _, __, ___: "cookie", ), emailpassword.init( - override=emailpassword.InputOverrideConfig( + override=emailpassword.EmailPasswordOverrideConfig( apis=override_email_password_apis ) ), @@ -159,7 +159,9 @@ async def should_allow_access( ), DashboardRecipe.init( api_key="testapikey", - override=InputOverrideConfig(functions=override_dashboard_functions), + override=DashboardOverrideConfig( + functions=override_dashboard_functions + ), ), PasswordlessRecipe.init( contact_config=ContactConfig(contact_method="EMAIL"), diff --git a/tests/auth-react/django3x/mysite/utils.py b/tests/auth-react/django3x/mysite/utils.py index 97335f18e..42cfe682d 100644 --- a/tests/auth-react/django3x/mysite/utils.py +++ b/tests/auth-react/django3x/mysite/utils.py @@ -806,7 +806,7 @@ async def resend_code_post( contact_config=ContactPhoneOnlyConfig(), flow_type=passwordlessFlowType, # type: ignore - type expects only certain literals sms_delivery=passwordless.SMSDeliveryConfig(CustomSMSService()), - override=passwordless.InputOverrideConfig( + override=passwordless.PasswordlessOverrideConfig( apis=override_passwordless_apis ), ) @@ -817,7 +817,7 @@ async def resend_code_post( email_delivery=passwordless.EmailDeliveryConfig( CustomPlessEmailService() ), - override=passwordless.InputOverrideConfig( + override=passwordless.PasswordlessOverrideConfig( apis=override_passwordless_apis ), ) @@ -829,7 +829,7 @@ async def resend_code_post( CustomPlessEmailService() ), sms_delivery=passwordless.SMSDeliveryConfig(CustomSMSService()), - override=passwordless.InputOverrideConfig( + override=passwordless.PasswordlessOverrideConfig( apis=override_passwordless_apis ), ) @@ -839,7 +839,9 @@ async def resend_code_post( flow_type="USER_INPUT_CODE_AND_MAGIC_LINK", email_delivery=passwordless.EmailDeliveryConfig(CustomPlessEmailService()), sms_delivery=passwordless.SMSDeliveryConfig(CustomSMSService()), - override=passwordless.InputOverrideConfig(apis=override_passwordless_apis), + override=passwordless.PasswordlessOverrideConfig( + apis=override_passwordless_apis + ), ) async def get_allowed_domains_for_tenant_id( @@ -968,7 +970,7 @@ async def resync_session_and_fetch_mfa_info_put( { "id": "session", "init": session.init( - override=session.InputOverrideConfig(apis=override_session_apis) + override=session.SessionOverrideConfig(apis=override_session_apis) ), }, { @@ -988,7 +990,7 @@ async def resync_session_and_fetch_mfa_info_put( email_delivery=emailpassword.EmailDeliveryConfig( CustomEPEmailService() ), - override=emailpassword.InputOverrideConfig( + override=emailpassword.EmailPasswordOverrideConfig( apis=override_email_password_apis, ), ), @@ -1007,7 +1009,9 @@ async def resync_session_and_fetch_mfa_info_put( "id": "thirdparty", "init": thirdparty.init( sign_in_and_up_feature=thirdparty.SignInAndUpFeature(providers_list), - override=thirdparty.InputOverrideConfig(apis=override_thirdparty_apis), + override=thirdparty.ThirdPartyOverrideConfig( + apis=override_thirdparty_apis + ), ), }, { @@ -1024,7 +1028,7 @@ async def resync_session_and_fetch_mfa_info_put( "id": "multifactorauth", "init": multifactorauth.init( first_factors=mfaInfo.get("firstFactors", None), - override=multifactorauth.OverrideConfig( + override=multifactorauth.MultiFactorAuthOverrideConfig( functions=override_mfa_functions, apis=override_mfa_apis, ), diff --git a/tests/auth-react/fastapi-server/app.py b/tests/auth-react/fastapi-server/app.py index c09a67f23..c22a5e5b2 100644 --- a/tests/auth-react/fastapi-server/app.py +++ b/tests/auth-react/fastapi-server/app.py @@ -907,7 +907,7 @@ async def resend_code_post( contact_config=ContactPhoneOnlyConfig(), flow_type=passwordlessFlowType, # type: ignore - type expects only certain literals sms_delivery=passwordless.SMSDeliveryConfig(CustomSMSService()), - override=passwordless.InputOverrideConfig( + override=passwordless.PasswordlessOverrideConfig( apis=override_passwordless_apis ), ) @@ -918,7 +918,7 @@ async def resend_code_post( email_delivery=passwordless.EmailDeliveryConfig( CustomPlessEmailService() ), - override=passwordless.InputOverrideConfig( + override=passwordless.PasswordlessOverrideConfig( apis=override_passwordless_apis ), ) @@ -930,7 +930,7 @@ async def resend_code_post( CustomPlessEmailService() ), sms_delivery=passwordless.SMSDeliveryConfig(CustomSMSService()), - override=passwordless.InputOverrideConfig( + override=passwordless.PasswordlessOverrideConfig( apis=override_passwordless_apis ), ) @@ -940,7 +940,9 @@ async def resend_code_post( flow_type="USER_INPUT_CODE_AND_MAGIC_LINK", email_delivery=passwordless.EmailDeliveryConfig(CustomPlessEmailService()), sms_delivery=passwordless.SMSDeliveryConfig(CustomSMSService()), - override=passwordless.InputOverrideConfig(apis=override_passwordless_apis), + override=passwordless.PasswordlessOverrideConfig( + apis=override_passwordless_apis + ), ) async def get_allowed_domains_for_tenant_id( @@ -1069,7 +1071,7 @@ async def resync_session_and_fetch_mfa_info_put( { "id": "session", "init": session.init( - override=session.InputOverrideConfig(apis=override_session_apis) + override=session.SessionOverrideConfig(apis=override_session_apis) ), }, { @@ -1089,7 +1091,7 @@ async def resync_session_and_fetch_mfa_info_put( email_delivery=emailpassword.EmailDeliveryConfig( CustomEPEmailService() ), - override=emailpassword.InputOverrideConfig( + override=emailpassword.EmailPasswordOverrideConfig( apis=override_email_password_apis, ), ), @@ -1108,7 +1110,9 @@ async def resync_session_and_fetch_mfa_info_put( "id": "thirdparty", "init": thirdparty.init( sign_in_and_up_feature=thirdparty.SignInAndUpFeature(providers_list), - override=thirdparty.InputOverrideConfig(apis=override_thirdparty_apis), + override=thirdparty.ThirdPartyOverrideConfig( + apis=override_thirdparty_apis + ), ), }, { @@ -1125,7 +1129,7 @@ async def resync_session_and_fetch_mfa_info_put( "id": "multifactorauth", "init": multifactorauth.init( first_factors=mfaInfo.get("firstFactors", None), - override=multifactorauth.OverrideConfig( + override=multifactorauth.MultiFactorAuthOverrideConfig( functions=override_mfa_functions, apis=override_mfa_apis, ), diff --git a/tests/auth-react/flask-server/app.py b/tests/auth-react/flask-server/app.py index 4ea94a5de..6eff7f65d 100644 --- a/tests/auth-react/flask-server/app.py +++ b/tests/auth-react/flask-server/app.py @@ -886,7 +886,7 @@ async def resend_code_post( contact_config=ContactPhoneOnlyConfig(), flow_type=passwordlessFlowType, # type: ignore - type expects only certain literals sms_delivery=passwordless.SMSDeliveryConfig(CustomSMSService()), - override=passwordless.InputOverrideConfig( + override=passwordless.PasswordlessOverrideConfig( apis=override_passwordless_apis ), ) @@ -897,7 +897,7 @@ async def resend_code_post( email_delivery=passwordless.EmailDeliveryConfig( CustomPlessEmailService() ), - override=passwordless.InputOverrideConfig( + override=passwordless.PasswordlessOverrideConfig( apis=override_passwordless_apis ), ) @@ -909,7 +909,7 @@ async def resend_code_post( CustomPlessEmailService() ), sms_delivery=passwordless.SMSDeliveryConfig(CustomSMSService()), - override=passwordless.InputOverrideConfig( + override=passwordless.PasswordlessOverrideConfig( apis=override_passwordless_apis ), ) @@ -919,7 +919,9 @@ async def resend_code_post( flow_type="USER_INPUT_CODE_AND_MAGIC_LINK", email_delivery=passwordless.EmailDeliveryConfig(CustomPlessEmailService()), sms_delivery=passwordless.SMSDeliveryConfig(CustomSMSService()), - override=passwordless.InputOverrideConfig(apis=override_passwordless_apis), + override=passwordless.PasswordlessOverrideConfig( + apis=override_passwordless_apis + ), ) async def get_allowed_domains_for_tenant_id( @@ -1048,7 +1050,7 @@ async def resync_session_and_fetch_mfa_info_put( { "id": "session", "init": session.init( - override=session.InputOverrideConfig(apis=override_session_apis) + override=session.SessionOverrideConfig(apis=override_session_apis) ), }, { @@ -1068,7 +1070,7 @@ async def resync_session_and_fetch_mfa_info_put( email_delivery=emailpassword.EmailDeliveryConfig( CustomEPEmailService() ), - override=emailpassword.InputOverrideConfig( + override=emailpassword.EmailPasswordOverrideConfig( apis=override_email_password_apis, ), ), @@ -1087,7 +1089,9 @@ async def resync_session_and_fetch_mfa_info_put( "id": "thirdparty", "init": thirdparty.init( sign_in_and_up_feature=thirdparty.SignInAndUpFeature(providers_list), - override=thirdparty.InputOverrideConfig(apis=override_thirdparty_apis), + override=thirdparty.ThirdPartyOverrideConfig( + apis=override_thirdparty_apis + ), ), }, { @@ -1104,7 +1108,7 @@ async def resync_session_and_fetch_mfa_info_put( "id": "multifactorauth", "init": multifactorauth.init( first_factors=mfaInfo.get("firstFactors", None), - override=multifactorauth.OverrideConfig( + override=multifactorauth.MultiFactorAuthOverrideConfig( functions=override_mfa_functions, apis=override_mfa_apis, ), diff --git a/tests/dashboard/test_dashboard.py b/tests/dashboard/test_dashboard.py index ac348ed98..4adfd5e48 100644 --- a/tests/dashboard/test_dashboard.py +++ b/tests/dashboard/test_dashboard.py @@ -14,11 +14,11 @@ thirdparty, usermetadata, ) -from supertokens_python.recipe.dashboard import InputOverrideConfig +from supertokens_python.recipe.dashboard import DashboardOverrideConfig from supertokens_python.recipe.dashboard.interfaces import ( RecipeInterface as DashboardRI, ) -from supertokens_python.recipe.dashboard.utils import DashboardConfig +from supertokens_python.recipe.dashboard.utils import NormalisedDashboardConfig from supertokens_python.recipe.passwordless import ContactEmailOrPhoneConfig from supertokens_python.recipe.thirdparty.asyncio import manually_create_or_update_user from supertokens_python.recipe.thirdparty.interfaces import ( @@ -49,7 +49,7 @@ async def test_dashboard_recipe(app: TestClient): def override_dashboard_functions(oi: DashboardRI) -> DashboardRI: async def should_allow_access( _request: BaseRequest, - _config: DashboardConfig, + _config: NormalisedDashboardConfig, _user_context: Dict[str, Any], ) -> bool: return True @@ -63,7 +63,9 @@ async def should_allow_access( session.init(get_token_transfer_method=lambda _, __, ___: "cookie"), dashboard.init( api_key="someKey", - override=InputOverrideConfig(functions=override_dashboard_functions), + override=DashboardOverrideConfig( + functions=override_dashboard_functions + ), ), ], ) @@ -83,7 +85,7 @@ async def test_dashboard_users_get(app: TestClient): def override_dashboard_functions(oi: DashboardRI) -> DashboardRI: async def should_allow_access( _request: BaseRequest, - _config: DashboardConfig, + _config: NormalisedDashboardConfig, _user_context: Dict[str, Any], ) -> bool: return True @@ -99,7 +101,7 @@ async def should_allow_access( usermetadata.init(), dashboard.init( api_key="someKey", - override=InputOverrideConfig( + override=DashboardOverrideConfig( functions=override_dashboard_functions, ), ), @@ -211,7 +213,7 @@ async def test_that_get_user_works_with_combination_recipes(app: TestClient): def override_dashboard_functions(oi: DashboardRI) -> DashboardRI: async def should_allow_access( _request: BaseRequest, - _config: DashboardConfig, + _config: NormalisedDashboardConfig, _user_context: Dict[str, Any], ) -> bool: return True @@ -231,7 +233,7 @@ async def should_allow_access( usermetadata.init(), dashboard.init( api_key="someKey", - override=InputOverrideConfig( + override=DashboardOverrideConfig( functions=override_dashboard_functions, ), ), diff --git a/tests/emailpassword/test_emailexists.py b/tests/emailpassword/test_emailexists.py index 8536b6d0a..1dfa97f9d 100644 --- a/tests/emailpassword/test_emailexists.py +++ b/tests/emailpassword/test_emailexists.py @@ -188,7 +188,7 @@ async def test_that_if_disabling_api_the_default_email_exists_api_does_not_work( recipe_list=[ session.init(anti_csrf="VIA_TOKEN", cookie_domain="supertokens.io"), emailpassword.init( - override=emailpassword.InputOverrideConfig( + override=emailpassword.EmailPasswordOverrideConfig( apis=apis_override_email_password ) ), diff --git a/tests/emailpassword/test_emailverify.py b/tests/emailpassword/test_emailverify.py index 1b3bcb12d..1ac3a3b97 100644 --- a/tests/emailpassword/test_emailverify.py +++ b/tests/emailpassword/test_emailverify.py @@ -44,7 +44,9 @@ from supertokens_python.recipe.emailverification.types import ( EmailVerificationUser as EVUser, ) -from supertokens_python.recipe.emailverification.utils import OverrideConfig +from supertokens_python.recipe.emailverification.utils import ( + EmailVerificationOverrideConfig, +) from supertokens_python.recipe.session import SessionContainer from supertokens_python.recipe.session.asyncio import ( create_new_session, @@ -685,7 +687,9 @@ async def email_verify_post( email_delivery=emailverification.EmailDeliveryConfig( CustomEmailService() ), - override=OverrideConfig(apis=apis_override_email_password), + override=EmailVerificationOverrideConfig( + apis=apis_override_email_password + ), ), emailpassword.init(), ], @@ -917,7 +921,9 @@ async def email_verify_post( email_delivery=emailverification.EmailDeliveryConfig( CustomEmailService() ), - override=OverrideConfig(apis=apis_override_email_password), + override=EmailVerificationOverrideConfig( + apis=apis_override_email_password + ), ), emailpassword.init(), ], @@ -1025,7 +1031,7 @@ async def email_verify_post( email_delivery=emailverification.EmailDeliveryConfig( CustomEmailService() ), - override=emailverification.InputOverrideConfig( + override=emailverification.EmailVerificationOverrideConfig( apis=apis_override_email_password ), ), diff --git a/tests/emailpassword/test_signin.py b/tests/emailpassword/test_signin.py index 6d2ce23c4..5a69bcca9 100644 --- a/tests/emailpassword/test_signin.py +++ b/tests/emailpassword/test_signin.py @@ -102,7 +102,7 @@ def apis_override_email_password(param: APIInterface): framework="fastapi", recipe_list=[ emailpassword.init( - override=emailpassword.InputOverrideConfig( + override=emailpassword.EmailPasswordOverrideConfig( apis=apis_override_email_password ) ) diff --git a/tests/frontendIntegration/django2x/polls/views.py b/tests/frontendIntegration/django2x/polls/views.py index 3c3cc5dad..0859a964e 100644 --- a/tests/frontendIntegration/django2x/polls/views.py +++ b/tests/frontendIntegration/django2x/polls/views.py @@ -347,7 +347,7 @@ def config( on_unauthorised=unauthorised_f ), anti_csrf=anti_csrf, - override=session.InputOverrideConfig( + override=session.SessionOverrideConfig( apis=apis_override_session, functions=functions_override_session, ), @@ -371,7 +371,7 @@ def config( on_unauthorised=unauthorised_f ), anti_csrf=anti_csrf, - override=session.InputOverrideConfig( + override=session.SessionOverrideConfig( apis=apis_override_session, functions=functions_override_session, ), @@ -392,7 +392,7 @@ def config( session.init( error_handlers=InputErrorHandlers(on_unauthorised=unauthorised_f), anti_csrf=anti_csrf, - override=session.InputOverrideConfig(apis=apis_override_session), + override=session.SessionOverrideConfig(apis=apis_override_session), ) ], telemetry=False, diff --git a/tests/frontendIntegration/django3x/polls/views.py b/tests/frontendIntegration/django3x/polls/views.py index 0567a0c2f..5c06e1409 100644 --- a/tests/frontendIntegration/django3x/polls/views.py +++ b/tests/frontendIntegration/django3x/polls/views.py @@ -349,7 +349,7 @@ def config( on_unauthorised=unauthorised_f ), anti_csrf=anti_csrf, - override=session.InputOverrideConfig( + override=session.SessionOverrideConfig( apis=apis_override_session, functions=functions_override_session, ), @@ -373,7 +373,7 @@ def config( on_unauthorised=unauthorised_f ), anti_csrf=anti_csrf, - override=session.InputOverrideConfig( + override=session.SessionOverrideConfig( apis=apis_override_session, functions=functions_override_session, ), @@ -394,7 +394,7 @@ def config( session.init( error_handlers=InputErrorHandlers(on_unauthorised=unauthorised_f), anti_csrf=anti_csrf, - override=session.InputOverrideConfig(apis=apis_override_session), + override=session.SessionOverrideConfig(apis=apis_override_session), ) ], telemetry=False, diff --git a/tests/frontendIntegration/drf_async/mysite/settings.py b/tests/frontendIntegration/drf_async/mysite/settings.py index 72bf8d82f..ca43b8b8f 100644 --- a/tests/frontendIntegration/drf_async/mysite/settings.py +++ b/tests/frontendIntegration/drf_async/mysite/settings.py @@ -207,7 +207,7 @@ async def create_new_session_custom( framework="django", recipe_list=[ session.init( - override=session.InputOverrideConfig( + override=session.SessionOverrideConfig( apis=apis_override_session, functions=functions_override_session, ), diff --git a/tests/frontendIntegration/drf_async/polls/views.py b/tests/frontendIntegration/drf_async/polls/views.py index df1c4953d..e69f34107 100644 --- a/tests/frontendIntegration/drf_async/polls/views.py +++ b/tests/frontendIntegration/drf_async/polls/views.py @@ -374,7 +374,7 @@ def config( on_unauthorised=unauthorised_f ), anti_csrf=anti_csrf, - override=session.InputOverrideConfig( + override=session.SessionOverrideConfig( apis=apis_override_session, functions=functions_override_session, ), @@ -398,7 +398,7 @@ def config( on_unauthorised=unauthorised_f ), anti_csrf=anti_csrf, - override=session.InputOverrideConfig( + override=session.SessionOverrideConfig( apis=apis_override_session, functions=functions_override_session, ), @@ -419,7 +419,7 @@ def config( session.init( error_handlers=InputErrorHandlers(on_unauthorised=unauthorised_f), anti_csrf=anti_csrf, - override=session.InputOverrideConfig(apis=apis_override_session), + override=session.SessionOverrideConfig(apis=apis_override_session), ) ], telemetry=False, diff --git a/tests/frontendIntegration/drf_sync/mysite/settings.py b/tests/frontendIntegration/drf_sync/mysite/settings.py index 37c3dc651..70e011e9c 100644 --- a/tests/frontendIntegration/drf_sync/mysite/settings.py +++ b/tests/frontendIntegration/drf_sync/mysite/settings.py @@ -207,7 +207,7 @@ async def create_new_session_custom( framework="django", recipe_list=[ session.init( - override=session.InputOverrideConfig( + override=session.SessionOverrideConfig( apis=apis_override_session, functions=functions_override_session, ), diff --git a/tests/frontendIntegration/drf_sync/polls/views.py b/tests/frontendIntegration/drf_sync/polls/views.py index 0f869b3e2..932fe71bc 100644 --- a/tests/frontendIntegration/drf_sync/polls/views.py +++ b/tests/frontendIntegration/drf_sync/polls/views.py @@ -373,7 +373,7 @@ def config( on_unauthorised=unauthorised_f ), anti_csrf=anti_csrf, - override=session.InputOverrideConfig( + override=session.SessionOverrideConfig( apis=apis_override_session, functions=functions_override_session, ), @@ -397,7 +397,7 @@ def config( on_unauthorised=unauthorised_f ), anti_csrf=anti_csrf, - override=session.InputOverrideConfig( + override=session.SessionOverrideConfig( apis=apis_override_session, functions=functions_override_session, ), @@ -418,7 +418,7 @@ def config( session.init( error_handlers=InputErrorHandlers(on_unauthorised=unauthorised_f), anti_csrf=anti_csrf, - override=session.InputOverrideConfig(apis=apis_override_session), + override=session.SessionOverrideConfig(apis=apis_override_session), ) ], telemetry=False, diff --git a/tests/frontendIntegration/fastapi-server/app.py b/tests/frontendIntegration/fastapi-server/app.py index 3f6502a23..a7f02a383 100644 --- a/tests/frontendIntegration/fastapi-server/app.py +++ b/tests/frontendIntegration/fastapi-server/app.py @@ -199,7 +199,7 @@ def config( on_unauthorised=unauthorised_f ), anti_csrf=anti_csrf, - override=session.InputOverrideConfig( + override=session.SessionOverrideConfig( apis=apis_override_session, functions=functions_override_session, ), @@ -223,7 +223,7 @@ def config( on_unauthorised=unauthorised_f ), anti_csrf=anti_csrf, - override=session.InputOverrideConfig( + override=session.SessionOverrideConfig( apis=apis_override_session, functions=functions_override_session, ), @@ -244,7 +244,7 @@ def config( session.init( error_handlers=InputErrorHandlers(on_unauthorised=unauthorised_f), anti_csrf=anti_csrf, - override=session.InputOverrideConfig(apis=apis_override_session), + override=session.SessionOverrideConfig(apis=apis_override_session), ) ], telemetry=False, diff --git a/tests/frontendIntegration/flask-server/app.py b/tests/frontendIntegration/flask-server/app.py index 683242e30..3059bc3a5 100644 --- a/tests/frontendIntegration/flask-server/app.py +++ b/tests/frontendIntegration/flask-server/app.py @@ -230,7 +230,7 @@ def config( on_unauthorised=unauthorised_f ), anti_csrf=anti_csrf, - override=session.InputOverrideConfig( + override=session.SessionOverrideConfig( apis=apis_override_session, functions=functions_override_session, ), @@ -254,7 +254,7 @@ def config( on_unauthorised=unauthorised_f ), anti_csrf=anti_csrf, - override=session.InputOverrideConfig( + override=session.SessionOverrideConfig( apis=apis_override_session, functions=functions_override_session, ), @@ -275,7 +275,7 @@ def config( session.init( error_handlers=InputErrorHandlers(on_unauthorised=unauthorised_f), anti_csrf=anti_csrf, - override=session.InputOverrideConfig(apis=apis_override_session), + override=session.SessionOverrideConfig(apis=apis_override_session), ) ], telemetry=False, diff --git a/tests/jwt/test_get_JWKS.py b/tests/jwt/test_get_JWKS.py index dd0919d84..f091ab938 100644 --- a/tests/jwt/test_get_JWKS.py +++ b/tests/jwt/test_get_JWKS.py @@ -64,7 +64,7 @@ async def test_that_default_getJWKS_api_does_not_work_when_disabled( ), framework="fastapi", recipe_list=[ - jwt.init(override=jwt.OverrideConfig(apis=apis_override_get_JWKS)) + jwt.init(override=jwt.JWTOverrideConfig(apis=apis_override_get_JWKS)) ], ) @@ -96,7 +96,7 @@ async def get_jwks(user_context: Dict[str, Any]): website_domain="supertokens.io", ), framework="fastapi", - recipe_list=[jwt.init(override=jwt.OverrideConfig(functions=func_override))], + recipe_list=[jwt.init(override=jwt.JWTOverrideConfig(functions=func_override))], ) response = driver_config_client.get(url="/auth/jwt/jwks.json") diff --git a/tests/jwt/test_override.py b/tests/jwt/test_override.py index 856477f09..19f64c367 100644 --- a/tests/jwt/test_override.py +++ b/tests/jwt/test_override.py @@ -106,7 +106,9 @@ async def create_jwt_( website_domain="supertokens.io", ), framework="fastapi", - recipe_list=[jwt.init(override=jwt.OverrideConfig(functions=custom_functions))], + recipe_list=[ + jwt.init(override=jwt.JWTOverrideConfig(functions=custom_functions)) + ], ) response = driver_config_client.post( @@ -147,7 +149,7 @@ async def get_jwks_get(api_options: APIOptions, user_context: Dict[str, Any]): website_domain="supertokens.io", ), framework="fastapi", - recipe_list=[jwt.init(override=jwt.OverrideConfig(apis=custom_api))], + recipe_list=[jwt.init(override=jwt.JWTOverrideConfig(apis=custom_api))], ) response = driver_config_client.get(url="/auth/jwt/jwks.json") diff --git a/tests/sessions/claims/test_create_new_session.py b/tests/sessions/claims/test_create_new_session.py index 62101fa96..c3efa537b 100644 --- a/tests/sessions/claims/test_create_new_session.py +++ b/tests/sessions/claims/test_create_new_session.py @@ -52,7 +52,7 @@ async def test_should_merge_claims_and_passed_access_token_payload_obj(timestamp url=get_new_core_app_url(), recipe_list=[ session.init( - override=session.InputOverrideConfig( + override=session.SessionOverrideConfig( functions=session_functions_override_with_claim( TrueClaim, {"user-custom-claim": "foo"} ), diff --git a/tests/sessions/claims/test_verify_session.py b/tests/sessions/claims/test_verify_session.py index cbae9e39f..5bad87a72 100644 --- a/tests/sessions/claims/test_verify_session.py +++ b/tests/sessions/claims/test_verify_session.py @@ -57,7 +57,7 @@ async def new_get_global_claim_validators( session.init( anti_csrf="VIA_TOKEN", get_token_transfer_method=lambda _, __, ___: "cookie", - override=session.InputOverrideConfig( + override=session.SessionOverrideConfig( functions=session_function_override ), ) @@ -85,7 +85,7 @@ async def new_get_global_claim_validators( session.init( anti_csrf="VIA_TOKEN", get_token_transfer_method=lambda _, __, ___: "cookie", - override=session.InputOverrideConfig( + override=session.SessionOverrideConfig( functions=session_function_override ), ) diff --git a/tests/sessions/claims/utils.py b/tests/sessions/claims/utils.py index 91d50d128..1fba680c5 100644 --- a/tests/sessions/claims/utils.py +++ b/tests/sessions/claims/utils.py @@ -71,7 +71,7 @@ def get_st_init_args(claim: SessionClaim[Any] = TrueClaim): url=get_new_core_app_url(), recipe_list=[ session.init( - override=session.InputOverrideConfig( + override=session.SessionOverrideConfig( functions=session_functions_override_with_claim(claim), ), ), diff --git a/tests/telemetry/test_telemetry.py b/tests/telemetry/test_telemetry.py index 05d2812f6..ff7472f79 100644 --- a/tests/telemetry/test_telemetry.py +++ b/tests/telemetry/test_telemetry.py @@ -40,7 +40,7 @@ async def test_telemetry(): session.init( anti_csrf="VIA_TOKEN", cookie_domain="supertokens.io", - override=session.InputOverrideConfig(), + override=session.SessionOverrideConfig(), ) ], telemetry=True, @@ -73,7 +73,7 @@ async def test_read_from_env(): session.init( anti_csrf="VIA_TOKEN", cookie_domain="supertokens.io", - override=session.InputOverrideConfig(), + override=session.SessionOverrideConfig(), ) ], ) diff --git a/tests/test-server/app.py b/tests/test-server/app.py index 820e1011c..3273a0a49 100644 --- a/tests/test-server/app.py +++ b/tests/test-server/app.py @@ -277,7 +277,7 @@ def init_st(config: Dict[str, Any]): ), ) ), - override=emailpassword.InputOverrideConfig( + override=emailpassword.EmailPasswordOverrideConfig( apis=override_builder_with_logging( "EmailPassword.override.apis", recipe_config_json.get("override", {}).get("apis", None), @@ -322,7 +322,7 @@ async def custom_unauthorised_callback( use_dynamic_access_token_signing_key=recipe_config_json.get( "useDynamicAccessTokenSigningKey" ), - override=session.InputOverrideConfig( + override=session.SessionOverrideConfig( apis=override_builder_with_logging( "Session.override.apis", recipe_config_json.get("override", {}).get("apis", None), @@ -352,7 +352,7 @@ async def custom_unauthorised_callback( "AccountLinking.onAccountLinked", recipe_config_json.get("onAccountLinked"), ), - override=accountlinking.InputOverrideConfig( + override=accountlinking.AccountLinkingOverrideConfig( functions=override_builder_with_logging( "AccountLinking.override.functions", recipe_config_json.get("override", {}).get( @@ -455,7 +455,7 @@ async def custom_unauthorised_callback( sign_in_and_up_feature=thirdparty.SignInAndUpFeature( providers=providers ), - override=thirdparty.InputOverrideConfig( + override=thirdparty.ThirdPartyOverrideConfig( functions=override_builder_with_logging( "ThirdParty.override.functions", recipe_config_json.get("override", {}).get( @@ -476,7 +476,7 @@ async def custom_unauthorised_callback( UnknownUserIdError, ) from supertokens_python.recipe.emailverification.utils import ( - OverrideConfig as EmailVerificationOverrideConfig, + EmailVerificationOverrideConfig as EmailVerificationOverrideConfig, ) recipe_list.append( @@ -518,7 +518,7 @@ async def custom_unauthorised_callback( recipe_list.append( multifactorauth.init( first_factors=recipe_config_json.get("firstFactors", None), - override=multifactorauth.OverrideConfig( + override=multifactorauth.MultiFactorAuthOverrideConfig( functions=override_builder_with_logging( "MultifactorAuth.override.functions", recipe_config_json.get("override", {}).get( @@ -574,7 +574,7 @@ async def send_sms( ), contact_config=contact_config, flow_type=recipe_config_json.get("flowType"), - override=passwordless.InputOverrideConfig( + override=passwordless.PasswordlessOverrideConfig( apis=override_builder_with_logging( "Passwordless.override.apis", recipe_config_json.get("override", {}).get("apis"), @@ -587,9 +587,7 @@ async def send_sms( ) ) elif recipe_id == "totp": - from supertokens_python.recipe.totp.types import ( - OverrideConfig as TOTPOverrideConfig, - ) + from supertokens_python.recipe.totp.types import TOTPOverrideConfig recipe_config_json = json.loads(recipe_config.get("config", "{}")) recipe_list.append( @@ -615,7 +613,7 @@ async def send_sms( recipe_config_json = json.loads(recipe_config.get("config", "{}")) recipe_list.append( oauth2provider.init( - override=oauth2provider.InputOverrideConfig( + override=oauth2provider.OAuth2ProviderOverrideConfig( apis=override_builder_with_logging( "OAuth2Provider.override.apis", recipe_config_json.get("override", {}).get("apis"), @@ -629,7 +627,7 @@ async def send_sms( ) elif recipe_id == "webauthn": from supertokens_python.recipe.webauthn.types.config import ( - OverrideConfig as WebauthnOverrideConfig, + WebauthnOverrideConfig, ) class WebauthnEmailDeliveryConfig( diff --git a/tests/test_session.py b/tests/test_session.py index 7351c4a4b..9a8e91bd7 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -26,7 +26,7 @@ from supertokens_python.framework.fastapi.fastapi_middleware import get_middleware from supertokens_python.process_state import PROCESS_STATE, ProcessState from supertokens_python.recipe import session -from supertokens_python.recipe.session import InputOverrideConfig, SessionRecipe +from supertokens_python.recipe.session import SessionOverrideConfig, SessionRecipe from supertokens_python.recipe.session.asyncio import ( create_new_session as async_create_new_session, ) @@ -381,7 +381,7 @@ async def get_session_information( recipe_list=[ session.init( anti_csrf="VIA_TOKEN", - override=InputOverrideConfig( + override=SessionOverrideConfig( functions=override_session_functions, ), ) @@ -422,7 +422,7 @@ async def refresh_post(api_options: APIOptions, user_context: Dict[str, Any]): session.init( anti_csrf="VIA_TOKEN", get_token_transfer_method=lambda _, __, ___: "cookie", - override=session.InputOverrideConfig(apis=session_api_override), + override=session.SessionOverrideConfig(apis=session_api_override), ) ], ) @@ -472,7 +472,7 @@ async def refresh_post(api_options: APIOptions, user_context: Dict[str, Any]): session.init( anti_csrf="VIA_TOKEN", get_token_transfer_method=lambda _, __, ___: "cookie", - override=session.InputOverrideConfig(apis=session_api_override), + override=session.SessionOverrideConfig(apis=session_api_override), ) ], ) @@ -521,7 +521,7 @@ async def refresh_post(api_options: APIOptions, user_context: Dict[str, Any]): session.init( anti_csrf="VIA_TOKEN", get_token_transfer_method=lambda _, __, ___: "cookie", - override=session.InputOverrideConfig(apis=session_api_override), + override=session.SessionOverrideConfig(apis=session_api_override), ) ], ) @@ -584,7 +584,7 @@ async def refresh_post(api_options: APIOptions, user_context: Dict[str, Any]): session.init( anti_csrf="VIA_TOKEN", get_token_transfer_method=lambda _, __, ___: "cookie", - override=session.InputOverrideConfig(apis=session_api_override), + override=session.SessionOverrideConfig(apis=session_api_override), ) ], ) diff --git a/tests/test_user_context.py b/tests/test_user_context.py index 4ef289756..4c2e87c55 100644 --- a/tests/test_user_context.py +++ b/tests/test_user_context.py @@ -191,13 +191,13 @@ async def create_new_session( framework="fastapi", recipe_list=[ emailpassword.init( - override=emailpassword.InputOverrideConfig( + override=emailpassword.EmailPasswordOverrideConfig( apis=apis_override_email_password, functions=functions_override_email_password, ) ), session.init( - override=session.InputOverrideConfig( + override=session.SessionOverrideConfig( functions=functions_override_session ) ), @@ -319,13 +319,13 @@ async def create_new_session( framework="fastapi", recipe_list=[ emailpassword.init( - override=emailpassword.InputOverrideConfig( + override=emailpassword.EmailPasswordOverrideConfig( apis=apis_override_email_password, functions=functions_override_email_password, ) ), session.init( - override=session.InputOverrideConfig( + override=session.SessionOverrideConfig( functions=functions_override_session ) ), @@ -466,13 +466,13 @@ async def create_new_session( framework="fastapi", recipe_list=[ emailpassword.init( - override=emailpassword.InputOverrideConfig( + override=emailpassword.EmailPasswordOverrideConfig( apis=apis_override_email_password, functions=functions_override_email_password, ) ), session.init( - override=session.InputOverrideConfig( + override=session.SessionOverrideConfig( functions=functions_override_session ) ), diff --git a/tests/usermetadata/test_metadata.py b/tests/usermetadata/test_metadata.py index ded211d6c..a859ee442 100644 --- a/tests/usermetadata/test_metadata.py +++ b/tests/usermetadata/test_metadata.py @@ -27,7 +27,7 @@ ClearUserMetadataResult, RecipeInterface, ) -from supertokens_python.recipe.usermetadata.utils import InputOverrideConfig +from supertokens_python.recipe.usermetadata.utils import UserMetadataOverrideConfig from supertokens_python.utils import is_version_gte from tests.utils import get_new_core_app_url @@ -169,7 +169,9 @@ async def new_get_user_metadata(user_id: str, user_context: Dict[str, Any]): ), framework="fastapi", recipe_list=[ - usermetadata.init(override=InputOverrideConfig(functions=override_func)) + usermetadata.init( + override=UserMetadataOverrideConfig(functions=override_func) + ) ], ) From 6354579313f63d925f319cd272953ea4a87b94e0 Mon Sep 17 00:00:00 2001 From: Namit Nathwani Date: Wed, 2 Jul 2025 15:48:30 +0530 Subject: [PATCH 15/37] fix: accountlinking recipe get_instance, tests --- supertokens_python/plugins.py | 32 ++-- .../recipe/accountlinking/recipe.py | 10 +- .../recipe/accountlinking/types.py | 24 +-- .../recipe/emailverification/utils.py | 6 +- supertokens_python/supertokens.py | 21 ++- supertokens_python/types/config.py | 12 +- .../input_validation/test_input_validation.py | 142 +++++++++--------- tests/plugins/api_implementation.py | 6 +- tests/plugins/config.py | 79 ++++------ tests/plugins/plugins.py | 4 +- tests/plugins/recipe.py | 15 +- tests/plugins/recipe_implementation.py | 12 +- tests/plugins/test_plugins.py | 4 +- tests/test_config.py | 5 +- 14 files changed, 187 insertions(+), 185 deletions(-) diff --git a/supertokens_python/plugins.py b/supertokens_python/plugins.py index b04ed42db..d2bfe62a4 100644 --- a/supertokens_python/plugins.py +++ b/supertokens_python/plugins.py @@ -21,20 +21,7 @@ from supertokens_python.framework.request import BaseRequest from supertokens_python.framework.response import BaseResponse from supertokens_python.logger import log_debug_message -from supertokens_python.types import MaybeAwaitable -from supertokens_python.types.base import UserContext -from supertokens_python.types.config import ( - BaseConfig, - BaseConfigWithoutAPIOverride, - BaseOverrideConfig, -) -from supertokens_python.types.recipe import BaseAPIInterface, BaseRecipeInterface -from supertokens_python.types.response import CamelCaseBaseModel - -if TYPE_CHECKING: - from supertokens_python.post_init_callbacks import PostSTInitCallbacks - from supertokens_python.supertokens import SupertokensPublicConfig - +from supertokens_python.post_init_callbacks import PostSTInitCallbacks from supertokens_python.recipe.accountlinking.types import AccountLinkingConfig from supertokens_python.recipe.dashboard.utils import DashboardConfig from supertokens_python.recipe.emailpassword.utils import EmailPasswordConfig @@ -57,6 +44,19 @@ from supertokens_python.recipe.usermetadata.utils import UserMetadataConfig from supertokens_python.recipe.userroles.utils import UserRolesConfig from supertokens_python.recipe.webauthn.types.config import WebauthnConfig +from supertokens_python.types import MaybeAwaitable +from supertokens_python.types.base import UserContext +from supertokens_python.types.config import ( + BaseConfig, + BaseConfigWithoutAPIOverride, + BaseOverrideConfig, + BaseOverrideConfigWithoutAPI, +) +from supertokens_python.types.recipe import BaseAPIInterface, BaseRecipeInterface +from supertokens_python.types.response import CamelCaseBaseModel + +if TYPE_CHECKING: + from supertokens_python.supertokens import SupertokensPublicConfig T = TypeVar( "T", @@ -311,7 +311,7 @@ def default_api_override( if config.override is None: if isinstance(config, BaseConfigWithoutAPIOverride): - config.override = BaseConfigWithoutAPIOverride() # type: ignore + config.override = BaseOverrideConfigWithoutAPI() else: config.override = BaseOverrideConfig() # type: ignore @@ -327,7 +327,7 @@ def default_api_override( # Order of 1/2 does not matter since they are independent from each other. for plugin in plugins: - overrides = plugin[recipe_id] + overrides = plugin.get(recipe_id) if overrides is not None: if overrides.config is not None: config = overrides.config(config) diff --git a/supertokens_python/recipe/accountlinking/recipe.py b/supertokens_python/recipe/accountlinking/recipe.py index 06462739f..3e0e7c78a 100644 --- a/supertokens_python/recipe/accountlinking/recipe.py +++ b/supertokens_python/recipe/accountlinking/recipe.py @@ -26,7 +26,6 @@ from supertokens_python.process_state import PROCESS_STATE, ProcessState from supertokens_python.querier import Querier from supertokens_python.recipe_module import APIHandled, RecipeModule -from supertokens_python.supertokens import Supertokens from supertokens_python.types.base import AccountInfoInput from .interfaces import RecipeInterface @@ -175,11 +174,12 @@ def func(app_info: AppInfo, plugins: List[OverrideMap]): @staticmethod def get_instance() -> AccountLinkingRecipe: - if AccountLinkingRecipe.__instance is None: - AccountLinkingRecipe.init()(Supertokens.get_instance().app_info) + if AccountLinkingRecipe.__instance is not None: + return AccountLinkingRecipe.__instance - assert AccountLinkingRecipe.__instance is not None - return AccountLinkingRecipe.__instance + raise_general_exception( + "Initialisation not done. Did you forget to call the SuperTokens.init function?" + ) @staticmethod def reset(): diff --git a/supertokens_python/recipe/accountlinking/types.py b/supertokens_python/recipe/accountlinking/types.py index 8de3e89db..83238f8a3 100644 --- a/supertokens_python/recipe/accountlinking/types.py +++ b/supertokens_python/recipe/accountlinking/types.py @@ -21,26 +21,26 @@ from supertokens_python.recipe.accountlinking.interfaces import ( RecipeInterface, ) -from supertokens_python.types import AccountInfo +from supertokens_python.recipe.session import SessionContainer +from supertokens_python.types import ( + AccountInfo, + LoginMethod, + RecipeUserId, + User, +) from supertokens_python.types.config import ( BaseConfigWithoutAPIOverride, BaseNormalisedConfigWithoutAPIOverride, - NormalisedOverrideConfigWithoutAPI, - OverrideConfigWithoutAPI, + BaseNormalisedOverrideConfigWithoutAPI, + BaseOverrideConfigWithoutAPI, ) if TYPE_CHECKING: - from supertokens_python.recipe.session import SessionContainer from supertokens_python.recipe.thirdparty.types import ThirdPartyInfo from supertokens_python.recipe.webauthn.types.base import WebauthnInfo - from supertokens_python.types import ( - LoginMethod, - RecipeUserId, - User, - ) - -AccountLinkingOverrideConfig = OverrideConfigWithoutAPI[RecipeInterface] -NormalisedAccountLinkingOverrideConfig = NormalisedOverrideConfigWithoutAPI[ + +AccountLinkingOverrideConfig = BaseOverrideConfigWithoutAPI[RecipeInterface] +NormalisedAccountLinkingOverrideConfig = BaseNormalisedOverrideConfigWithoutAPI[ RecipeInterface ] diff --git a/supertokens_python/recipe/emailverification/utils.py b/supertokens_python/recipe/emailverification/utils.py index f032fe6ca..7875b1f69 100644 --- a/supertokens_python/recipe/emailverification/utils.py +++ b/supertokens_python/recipe/emailverification/utils.py @@ -14,7 +14,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Dict, Optional +from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union from typing_extensions import Literal @@ -34,13 +34,11 @@ ) from .interfaces import APIInterface, RecipeInterface, TypeGetEmailForUserIdFunction +from .types import EmailTemplateVars, VerificationEmailTemplateVars if TYPE_CHECKING: - from typing import Callable, Union - from supertokens_python.supertokens import AppInfo - from .types import EmailTemplateVars, VerificationEmailTemplateVars MODE_TYPE = Literal["REQUIRED", "OPTIONAL"] diff --git a/supertokens_python/supertokens.py b/supertokens_python/supertokens.py index ad6aa9c38..8f1d73fd8 100644 --- a/supertokens_python/supertokens.py +++ b/supertokens_python/supertokens.py @@ -69,7 +69,6 @@ PluginRouteHandler, SuperTokensPlugin, SuperTokensPublicPlugin, - load_plugins, ) from supertokens_python.recipe.session import SessionContainer @@ -312,6 +311,8 @@ def __init__( debug: Optional[bool], experimental: Optional[SupertokensExperimentalConfig] = None, ): + print(f"{app_info=}") + print(f"{type(app_info)=}") if not isinstance(app_info, InputAppInfo): # type: ignore raise ValueError("app_info must be an instance of InputAppInfo") @@ -334,6 +335,8 @@ def __init__( override_maps: List[OverrideMap] = [] if experimental is not None and experimental.plugins is not None: + from supertokens_python.plugins import load_plugins + load_plugins_result = load_plugins( plugins=experimental.plugins, public_config=input_public_config, @@ -397,6 +400,7 @@ def __init__( oauth2_found = False openid_found = False jwt_found = False + account_linking_found = False def make_recipe( recipe: Callable[[AppInfo, List[OverrideMap]], RecipeModule], @@ -408,7 +412,9 @@ def make_recipe( multi_factor_auth_found, \ oauth2_found, \ openid_found, \ - jwt_found + jwt_found, \ + account_linking_found + recipe_module = recipe(self.app_info, override_maps) if recipe_module.get_recipe_id() == "multitenancy": multitenancy_found = True @@ -424,10 +430,21 @@ def make_recipe( openid_found = True elif recipe_module.get_recipe_id() == "jwt": jwt_found = True + elif recipe_module.get_recipe_id() == "accountlinking": + account_linking_found = True return recipe_module self.recipe_modules: List[RecipeModule] = list(map(make_recipe, recipe_list)) + if not account_linking_found: + from supertokens_python.recipe.accountlinking.recipe import ( + AccountLinkingRecipe, + ) + + self.recipe_modules.append( + AccountLinkingRecipe.init()(self.app_info, override_maps) + ) + if not jwt_found: from supertokens_python.recipe.jwt.recipe import JWTRecipe diff --git a/supertokens_python/types/config.py b/supertokens_python/types/config.py index 10752a02d..bf1fc6553 100644 --- a/supertokens_python/types/config.py +++ b/supertokens_python/types/config.py @@ -16,7 +16,7 @@ InterfaceOverride = Callable[[T], T] -class OverrideConfigWithoutAPI(CamelCaseBaseModel, Generic[FunctionInterfaceType]): +class BaseOverrideConfigWithoutAPI(CamelCaseBaseModel, Generic[FunctionInterfaceType]): """Base class for input override config without API overrides.""" functions: UseDefaultIfNone[Optional[InterfaceOverride[FunctionInterfaceType]]] = ( @@ -24,7 +24,7 @@ class OverrideConfigWithoutAPI(CamelCaseBaseModel, Generic[FunctionInterfaceType ) -class NormalisedOverrideConfigWithoutAPI( +class BaseNormalisedOverrideConfigWithoutAPI( CamelCaseBaseModel, Generic[FunctionInterfaceType] ): """Base class for normalized override config without API overrides.""" @@ -35,7 +35,7 @@ class NormalisedOverrideConfigWithoutAPI( class BaseOverrideConfig( - OverrideConfigWithoutAPI[FunctionInterfaceType], + BaseOverrideConfigWithoutAPI[FunctionInterfaceType], Generic[FunctionInterfaceType, APIInterfaceType], ): """Base class for input override config with API overrides.""" @@ -46,7 +46,7 @@ class BaseOverrideConfig( class BaseNormalisedOverrideConfig( - NormalisedOverrideConfigWithoutAPI[FunctionInterfaceType], + BaseNormalisedOverrideConfigWithoutAPI[FunctionInterfaceType], Generic[FunctionInterfaceType, APIInterfaceType], ): """Base class for normalized override config with API overrides.""" @@ -59,7 +59,7 @@ class BaseNormalisedOverrideConfig( class BaseConfigWithoutAPIOverride(CamelCaseBaseModel, Generic[FunctionInterfaceType]): """Base class for input config of a Recipe without API overrides.""" - override: Optional[OverrideConfigWithoutAPI[FunctionInterfaceType]] = None + override: Optional[BaseOverrideConfigWithoutAPI[FunctionInterfaceType]] = None class BaseNormalisedConfigWithoutAPIOverride( @@ -67,7 +67,7 @@ class BaseNormalisedConfigWithoutAPIOverride( ): """Base class for normalized config of a Recipe without API overrides.""" - override: NormalisedOverrideConfigWithoutAPI[FunctionInterfaceType] + override: BaseNormalisedOverrideConfigWithoutAPI[FunctionInterfaceType] class BaseConfig(CamelCaseBaseModel, Generic[FunctionInterfaceType, APIInterfaceType]): diff --git a/tests/input_validation/test_input_validation.py b/tests/input_validation/test_input_validation.py index b9ce47f3c..ae6111fa9 100644 --- a/tests/input_validation/test_input_validation.py +++ b/tests/input_validation/test_input_validation.py @@ -2,6 +2,7 @@ from typing import Any, Dict, List import pytest +from pydantic import ValidationError from supertokens_python import InputAppInfo, SupertokensConfig, init from supertokens_python.recipe import ( emailpassword, @@ -23,7 +24,9 @@ @pytest.mark.asyncio async def test_init_validation_emailpassword(): - with pytest.raises(ValueError) as ex: + with pytest.raises( + ValueError, match="app_info must be an instance of InputAppInfo" + ): init( supertokens_config=SupertokensConfig(get_new_core_app_url()), app_info="AppInfo", # type: ignore @@ -32,9 +35,10 @@ async def test_init_validation_emailpassword(): emailpassword.init(), ], ) - assert "app_info must be an instance of InputAppInfo" == str(ex.value) - with pytest.raises(ValueError) as ex: + with pytest.raises( + ValidationError, match="Input should be 'REQUIRED' or 'OPTIONAL'" + ): init( supertokens_config=SupertokensConfig(get_new_core_app_url()), app_info=InputAppInfo( @@ -49,10 +53,6 @@ async def test_init_validation_emailpassword(): emailpassword.init(), ], ) - assert ( - "Email Verification recipe mode must be one of 'REQUIRED' or 'OPTIONAL'" - == str(ex.value) - ) async def get_email_for_user_id(_: RecipeUserId, __: Dict[str, Any]): @@ -61,7 +61,9 @@ async def get_email_for_user_id(_: RecipeUserId, __: Dict[str, Any]): @pytest.mark.asyncio async def test_init_validation_emailverification(): - with pytest.raises(ValueError) as ex: + with pytest.raises( + ValidationError, match="Input should be 'REQUIRED' or 'OPTIONAL'" + ): init( supertokens_config=SupertokensConfig(get_new_core_app_url()), app_info=InputAppInfo( @@ -73,12 +75,11 @@ async def test_init_validation_emailverification(): framework="fastapi", recipe_list=[emailverification.init("config")], # type: ignore ) - assert ( - "Email Verification recipe mode must be one of 'REQUIRED' or 'OPTIONAL'" - == str(ex.value) - ) - with pytest.raises(ValueError) as ex: + with pytest.raises( + ValidationError, + match="Input should be a valid dictionary or instance of BaseOverrideConfig\\[RecipeInterface, APIInterface\\]", + ): init( supertokens_config=SupertokensConfig(get_new_core_app_url()), app_info=InputAppInfo( @@ -96,26 +97,29 @@ async def test_init_validation_emailverification(): ) ], ) - assert "override must be of type OverrideConfig or None" == str(ex.value) @pytest.mark.asyncio async def test_init_validation_jwt(): - with pytest.raises(ValueError) as ex: - init( - supertokens_config=SupertokensConfig(get_new_core_app_url()), - app_info=InputAppInfo( - app_name="SuperTokens Demo", - api_domain="http://api.supertokens.io", - website_domain="http://supertokens.io", - api_base_path="/auth", - ), - framework="fastapi", - recipe_list=[jwt.init(jwt_validity_seconds="100")], # type: ignore - ) - assert "jwt_validity_seconds must be an integer or None" == str(ex.value) - - with pytest.raises(ValueError) as ex: + # NOTE: `pydantic` auto-converts strings to integers + # with pytest.raises(ValueError) as ex: + # init( + # supertokens_config=SupertokensConfig(get_new_core_app_url()), + # app_info=InputAppInfo( + # app_name="SuperTokens Demo", + # api_domain="http://api.supertokens.io", + # website_domain="http://supertokens.io", + # api_base_path="/auth", + # ), + # framework="fastapi", + # recipe_list=[jwt.init(jwt_validity_seconds="100")], # type: ignore + # ) + # assert "jwt_validity_seconds must be an integer or None" == str(ex.value) + + with pytest.raises( + ValidationError, + match="Input should be a valid dictionary or instance of BaseOverrideConfig\\[RecipeInterface, APIInterface\\]", + ): init( supertokens_config=SupertokensConfig(get_new_core_app_url()), app_info=InputAppInfo( @@ -127,12 +131,14 @@ async def test_init_validation_jwt(): framework="fastapi", recipe_list=[jwt.init(override="override")], # type: ignore ) - assert "override must be an instance of OverrideConfig or None" == str(ex.value) @pytest.mark.asyncio async def test_init_validation_openid(): - with pytest.raises(ValueError) as ex: + with pytest.raises( + ValidationError, + match="Input should be a valid dictionary or instance of BaseOverrideConfig\\[RecipeInterface, APIInterface\\]", + ): init( supertokens_config=SupertokensConfig(get_new_core_app_url()), app_info=InputAppInfo( @@ -144,9 +150,6 @@ async def test_init_validation_openid(): framework="fastapi", recipe_list=[openid.init(override="override")], # type: ignore ) - assert "override must be an instance of InputOverrideConfig or None" == str( - ex.value - ) async def send_text_message( @@ -177,7 +180,9 @@ async def send_email( ) -> None: pass - with pytest.raises(ValueError) as ex: + with pytest.raises( + ValueError, match="app_info must be an instance of InputAppInfo" + ): init( supertokens_config=SupertokensConfig(get_new_core_app_url()), app_info="AppInfo", # type: ignore @@ -195,9 +200,11 @@ async def send_email( ) ], ) - assert "app_info must be an instance of InputAppInfo" == str(ex.value) - with pytest.raises(ValueError) as ex: + with pytest.raises( + ValidationError, + match="Input should be 'USER_INPUT_CODE', 'MAGIC_LINK' or 'USER_INPUT_CODE_AND_MAGIC_LINK'", + ): init( supertokens_config=SupertokensConfig(get_new_core_app_url()), app_info=InputAppInfo( @@ -220,12 +227,10 @@ async def send_email( ) ], ) - assert ( - "flow_type must be one of USER_INPUT_CODE, MAGIC_LINK, USER_INPUT_CODE_AND_MAGIC_LINK" - == str(ex.value) - ) - with pytest.raises(ValueError) as ex: + with pytest.raises( + ValidationError, match="Input should be an instance of ContactConfig" + ): init( supertokens_config=SupertokensConfig(get_new_core_app_url()), app_info=InputAppInfo( @@ -242,9 +247,11 @@ async def send_email( ) ], ) - assert "contact_config must be of type ContactConfig" == str(ex.value) - with pytest.raises(ValueError) as ex: + with pytest.raises( + ValidationError, + match="Input should be a valid dictionary or instance of BaseOverrideConfig\\[RecipeInterface, APIInterface\\]", + ): init( supertokens_config=SupertokensConfig(get_new_core_app_url()), app_info=InputAppInfo( @@ -268,7 +275,6 @@ async def send_email( ) ], ) - assert "override must be of type OverrideConfig" == str(ex.value) providers_list: List[thirdparty.ProviderInput] = [ @@ -310,7 +316,10 @@ async def send_email( @pytest.mark.asyncio async def test_init_validation_session(): - with pytest.raises(ValueError) as ex: + with pytest.raises( + ValidationError, + match="Input should be 'VIA_TOKEN', 'VIA_CUSTOM_HEADER' or 'NONE'", + ): init( supertokens_config=SupertokensConfig(get_new_core_app_url()), app_info=InputAppInfo( @@ -322,11 +331,10 @@ async def test_init_validation_session(): framework="fastapi", recipe_list=[session.init(anti_csrf="ABCDE")], # type: ignore ) - assert "anti_csrf must be one of VIA_TOKEN, VIA_CUSTOM_HEADER, NONE or None" == str( - ex.value - ) - with pytest.raises(ValueError) as ex: + with pytest.raises( + ValidationError, match="Input should be an instance of ErrorHandlers" + ): init( supertokens_config=SupertokensConfig(get_new_core_app_url()), app_info=InputAppInfo( @@ -341,11 +349,11 @@ async def test_init_validation_session(): # on invalid type. recipe_list=[session.init(error_handlers="error handlers")], # type: ignore ) - assert "error_handlers must be an instance of ErrorHandlers or None" == str( - ex.value - ) - with pytest.raises(ValueError) as ex: + with pytest.raises( + ValidationError, + match="Input should be a valid dictionary or instance of BaseOverrideConfig\\[RecipeInterface, APIInterface\\]", + ): init( supertokens_config=SupertokensConfig(get_new_core_app_url()), app_info=InputAppInfo( @@ -357,14 +365,13 @@ async def test_init_validation_session(): framework="fastapi", recipe_list=[session.init(override="override")], # type: ignore ) - assert "override must be an instance of InputOverrideConfig or None" == str( - ex.value - ) @pytest.mark.asyncio async def test_init_validation_thirdparty(): - with pytest.raises(ValueError) as ex: + with pytest.raises( + ValidationError, match="Input should be an instance of SignInAndUpFeature" + ): init( supertokens_config=SupertokensConfig(get_new_core_app_url()), app_info=InputAppInfo( @@ -381,11 +388,11 @@ async def test_init_validation_thirdparty(): thirdparty.init(sign_in_and_up_feature="sign in up") # type: ignore ], ) - assert "sign_in_and_up_feature must be an instance of SignInAndUpFeature" == str( - ex.value - ) - with pytest.raises(ValueError) as ex: + with pytest.raises( + ValidationError, + match="Input should be a valid dictionary or instance of BaseOverrideConfig\\[RecipeInterface, APIInterface\\]", + ): init( supertokens_config=SupertokensConfig(get_new_core_app_url()), app_info=InputAppInfo( @@ -404,14 +411,14 @@ async def test_init_validation_thirdparty(): ) ], ) - assert "override must be an instance of InputOverrideConfig or None" == str( - ex.value - ) @pytest.mark.asyncio async def test_init_validation_usermetadata(): - with pytest.raises(ValueError) as ex: + with pytest.raises( + ValidationError, + match="Input should be a valid dictionary or instance of BaseOverrideConfig\\[RecipeInterface, APIInterface\\]", + ): init( supertokens_config=SupertokensConfig(get_new_core_app_url()), app_info=InputAppInfo( @@ -423,6 +430,3 @@ async def test_init_validation_usermetadata(): framework="fastapi", recipe_list=[usermetadata.init(override="override")], # type: ignore ) - assert "override must be an instance of InputOverrideConfig or None" == str( - ex.value - ) diff --git a/tests/plugins/api_implementation.py b/tests/plugins/api_implementation.py index c38297b94..062f56591 100644 --- a/tests/plugins/api_implementation.py +++ b/tests/plugins/api_implementation.py @@ -1,12 +1,14 @@ -from abc import ABC, abstractmethod +from abc import abstractmethod from typing import ( List, ) +from supertokens_python.types.recipe import BaseAPIInterface + from .types import RecipeReturnType -class APIInterface(ABC): +class APIInterface(BaseAPIInterface): @abstractmethod def sign_in_post(self, message: str, stack: List[str]) -> RecipeReturnType: ... diff --git a/tests/plugins/config.py b/tests/plugins/config.py index b5284a2a9..31d6330a9 100644 --- a/tests/plugins/config.py +++ b/tests/plugins/config.py @@ -1,57 +1,36 @@ -from dataclasses import dataclass -from typing import ( - TYPE_CHECKING, - Any, - Optional, - Protocol, - TypeVar, - runtime_checkable, -) +from typing import Any, Optional from supertokens_python.supertokens import ( AppInfo, ) +from supertokens_python.types.config import ( + BaseConfig, + BaseNormalisedConfig, + BaseNormalisedOverrideConfig, + BaseOverrideConfig, + InterfaceOverride, +) +from supertokens_python.types.utils import UseDefaultIfNone -if TYPE_CHECKING: - from .api_implementation import APIInterface - from .recipe_implementation import RecipeInterface - -InterfaceType = TypeVar("InterfaceType") -"""Generic Type for use in `InterfaceOverride`""" - - -@runtime_checkable -class InterfaceOverride(Protocol[InterfaceType]): - """ - Callable signature for `WebauthnConfig.override.*`. - """ - - def __call__( - self, - original_implementation: InterfaceType, - ) -> InterfaceType: ... +from .api_implementation import APIInterface +from .recipe_implementation import RecipeInterface +PluginTestOverrideConfig = BaseOverrideConfig[RecipeInterface, APIInterface] +NormalisedPluginTestOverrideConfig = BaseNormalisedOverrideConfig[ + RecipeInterface, APIInterface +] -# NOTE: Using dataclasses for these classes since validation is not required -@dataclass -class OverrideConfig: - """ - `WebauthnConfig.override` - """ - functions: Optional[InterfaceOverride["RecipeInterface"]] = None - apis: Optional[InterfaceOverride["APIInterface"]] = None - config: Optional[InterfaceOverride[Any]] = None +class NormalizedPluginTestConfig( + BaseNormalisedConfig[RecipeInterface, APIInterface] +): ... -@dataclass -class NormalizedPluginTestConfig: - override: OverrideConfig +class PluginTestConfig(BaseConfig[RecipeInterface, APIInterface]): ... -@dataclass -class PluginTestConfig: - override: Optional[OverrideConfig] = None +class PluginOverrideConfig(BaseOverrideConfig[RecipeInterface, APIInterface]): + config: UseDefaultIfNone[Optional[InterfaceOverride[Any]]] = lambda config: config def validate_and_normalise_user_input( @@ -60,12 +39,12 @@ def validate_and_normalise_user_input( if config is None: config = PluginTestConfig() - if config.override is None: - override = OverrideConfig() - else: - override = OverrideConfig( - functions=config.override.functions, - apis=config.override.apis, - ) + override_config = NormalisedPluginTestOverrideConfig() + if config.override is not None: + if config.override.functions is not None: + override_config.functions = config.override.functions + + if config.override.apis is not None: + override_config.apis = config.override.apis - return NormalizedPluginTestConfig(override=override) + return NormalizedPluginTestConfig(override=override_config) diff --git a/tests/plugins/plugins.py b/tests/plugins/plugins.py index 0a9fc455e..60f635a7f 100644 --- a/tests/plugins/plugins.py +++ b/tests/plugins/plugins.py @@ -10,7 +10,7 @@ from supertokens_python.supertokens import SupertokensPublicConfig from .api_implementation import APIInterface -from .config import OverrideConfig +from .config import PluginOverrideConfig from .recipe import PluginTestRecipe from .recipe_implementation import RecipeInterface @@ -79,7 +79,7 @@ def plugin_factory( deps: Optional[List[SuperTokensPlugin]] = None, add_init: bool = False, ): - override_map_obj: OverrideMap = {PluginTestRecipe.recipe_id: OverrideConfig()} + override_map_obj: OverrideMap = {PluginTestRecipe.recipe_id: PluginOverrideConfig()} if override_functions: override_map_obj[ diff --git a/tests/plugins/recipe.py b/tests/plugins/recipe.py index 2a07dcfcb..56f23d287 100644 --- a/tests/plugins/recipe.py +++ b/tests/plugins/recipe.py @@ -46,18 +46,12 @@ def __init__( querier=querier, config=self.config, ) - self.recipe_implementation = ( + self.recipe_implementation = self.config.override.functions( recipe_implementation - if self.config.override.functions is None - else self.config.override.functions(recipe_implementation) ) # type: ignore api_implementation = APIImplementation() - self.api_implementation = ( - api_implementation - if self.config.override.apis is None - else self.config.override.apis(api_implementation) - ) # type: ignore + self.api_implementation = self.config.override.apis(api_implementation) # type: ignore @staticmethod def get_instance() -> "PluginTestRecipe": @@ -73,6 +67,9 @@ def get_instance_optional() -> Optional["PluginTestRecipe"]: @staticmethod def init(config: Optional[PluginTestConfig]): + if config is None: + config = PluginTestConfig() + def func(app_info: AppInfo, plugins: List[OverrideMap]): if PluginTestRecipe.__instance is None: PluginTestRecipe.__instance = PluginTestRecipe( @@ -80,7 +77,7 @@ def func(app_info: AppInfo, plugins: List[OverrideMap]): app_info=app_info, config=apply_plugins( recipe_id=PluginTestRecipe.recipe_id, - config=config, + config=config, # type: ignore plugins=plugins, ), ) diff --git a/tests/plugins/recipe_implementation.py b/tests/plugins/recipe_implementation.py index 03988e859..f43ed1a0f 100644 --- a/tests/plugins/recipe_implementation.py +++ b/tests/plugins/recipe_implementation.py @@ -1,15 +1,19 @@ -from abc import ABC, abstractmethod +from abc import abstractmethod from typing import ( + TYPE_CHECKING, List, ) from supertokens_python.querier import Querier +from supertokens_python.types.recipe import BaseRecipeInterface -from .config import NormalizedPluginTestConfig from .types import RecipeReturnType +if TYPE_CHECKING: + from .config import NormalizedPluginTestConfig -class RecipeInterface(ABC): + +class RecipeInterface(BaseRecipeInterface): @abstractmethod def sign_in(self, message: str, stack: List[str]) -> RecipeReturnType: ... @@ -18,7 +22,7 @@ class RecipeImplementation(RecipeInterface): def __init__( self, querier: Querier, - config: NormalizedPluginTestConfig, + config: "NormalizedPluginTestConfig", ): super().__init__() self.querier = querier diff --git a/tests/plugins/test_plugins.py b/tests/plugins/test_plugins.py index d523da1ef..12ea8ff52 100644 --- a/tests/plugins/test_plugins.py +++ b/tests/plugins/test_plugins.py @@ -20,7 +20,7 @@ from tests.utils import outputs, reset -from .config import OverrideConfig, PluginTestConfig +from .config import PluginTestConfig, PluginTestOverrideConfig from .misc import DummyRequest, DummyResponse from .plugins import ( Plugin1, @@ -50,7 +50,7 @@ def setup_and_teardown(): def recipe_factory(override_functions: bool = False, override_apis: bool = False): - override = OverrideConfig() + override = PluginTestOverrideConfig() if override_functions: override.functions = function_override_factory("override") diff --git a/tests/test_config.py b/tests/test_config.py index 48486fdcf..daf1bbe45 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -11,6 +11,7 @@ # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. +import re from unittest.mock import MagicMock import pytest @@ -280,7 +281,7 @@ async def test_same_site_values(st_config: SupertokensConfig): ) test_passed = False except Exception as e: - assert str(e) == 'cookie same site must be one of "strict", "lax", or "none"' + assert re.search("Input should be 'lax', 'strict' or 'none'", str(e)) assert test_passed reset() @@ -299,7 +300,7 @@ async def test_same_site_values(st_config: SupertokensConfig): ) test_passed = False except Exception as e: - assert str(e) == 'cookie same site must be one of "strict", "lax", or "none"' + assert re.search("Input should be 'lax', 'strict' or 'none'", str(e)) assert test_passed reset() From 28e812174633cfc8251125b67b93f26b86f77500 Mon Sep 17 00:00:00 2001 From: Namit Nathwani Date: Wed, 2 Jul 2025 16:28:32 +0530 Subject: [PATCH 16/37] fix: servers --- supertokens_python/recipe/dashboard/recipe.py | 3 ++- supertokens_python/recipe/webauthn/types/config.py | 13 ++++++++++--- supertokens_python/supertokens.py | 2 -- 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/supertokens_python/recipe/dashboard/recipe.py b/supertokens_python/recipe/dashboard/recipe.py index 0de83281e..aae27c256 100644 --- a/supertokens_python/recipe/dashboard/recipe.py +++ b/supertokens_python/recipe/dashboard/recipe.py @@ -17,7 +17,6 @@ from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional from supertokens_python.normalised_url_path import NormalisedURLPath -from supertokens_python.plugins import OverrideMap, apply_plugins from supertokens_python.recipe.dashboard.api.multitenancy.create_or_update_third_party_config import ( handle_create_or_update_third_party_config, ) @@ -644,6 +643,8 @@ def init( admins: Optional[List[str]] = None, override: Optional[DashboardOverrideConfig] = None, ): + from supertokens_python.plugins import OverrideMap, apply_plugins + config = DashboardConfig( api_key=api_key, admins=admins, diff --git a/supertokens_python/recipe/webauthn/types/config.py b/supertokens_python/recipe/webauthn/types/config.py index d2132cf1d..2fd3724fb 100644 --- a/supertokens_python/recipe/webauthn/types/config.py +++ b/supertokens_python/recipe/webauthn/types/config.py @@ -17,7 +17,6 @@ from typing import Optional, Protocol, TypeVar, Union, runtime_checkable from supertokens_python.framework import BaseRequest -from supertokens_python.ingredients.emaildelivery import EmailDeliveryIngredient from supertokens_python.ingredients.emaildelivery.types import ( EmailDeliveryConfig, EmailDeliveryConfigWithService, @@ -193,7 +192,12 @@ class WebauthnConfig(BaseConfig[RecipeInterface, APIInterface]): get_relying_party_id: Optional[Union[str, GetRelyingPartyId]] = None get_relying_party_name: Optional[Union[str, GetRelyingPartyName]] = None get_origin: Optional[GetOrigin] = None - email_delivery: Optional[EmailDeliveryConfig[TypeWebauthnEmailDeliveryInput]] = None + email_delivery: Optional[ + Union[ + EmailDeliveryConfig[TypeWebauthnEmailDeliveryInput], + EmailDeliveryConfigWithService[TypeWebauthnEmailDeliveryInput], + ] + ] = None validate_email_address: Optional[ValidateEmailAddress] = None @@ -207,5 +211,8 @@ class NormalisedWebauthnConfig(BaseNormalisedConfig[RecipeInterface, APIInterfac class WebauthnIngredients(CamelCaseBaseModel): email_delivery: Optional[ - EmailDeliveryIngredient[TypeWebauthnEmailDeliveryInput] + Union[ + EmailDeliveryConfig[TypeWebauthnEmailDeliveryInput], + EmailDeliveryConfigWithService[TypeWebauthnEmailDeliveryInput], + ] ] = None diff --git a/supertokens_python/supertokens.py b/supertokens_python/supertokens.py index 8f1d73fd8..31f2e22e0 100644 --- a/supertokens_python/supertokens.py +++ b/supertokens_python/supertokens.py @@ -311,8 +311,6 @@ def __init__( debug: Optional[bool], experimental: Optional[SupertokensExperimentalConfig] = None, ): - print(f"{app_info=}") - print(f"{type(app_info)=}") if not isinstance(app_info, InputAppInfo): # type: ignore raise ValueError("app_info must be an instance of InputAppInfo") From 5ccce692cda7d1ce0f4c810aa21fc089dcbcc0b5 Mon Sep 17 00:00:00 2001 From: Namit Nathwani Date: Wed, 2 Jul 2025 17:52:23 +0530 Subject: [PATCH 17/37] fix: servers --- .../recipe/accountlinking/recipe.py | 4 ++-- .../django2x/polls/views.py | 3 +++ .../django3x/polls/views.py | 3 +++ .../drf_async/polls/views.py | 3 +++ .../drf_sync/polls/views.py | 3 +++ .../frontendIntegration/fastapi-server/app.py | 3 +++ tests/frontendIntegration/flask-server/app.py | 22 ++++++++++--------- 7 files changed, 29 insertions(+), 12 deletions(-) diff --git a/supertokens_python/recipe/accountlinking/recipe.py b/supertokens_python/recipe/accountlinking/recipe.py index 3e0e7c78a..cdaf50185 100644 --- a/supertokens_python/recipe/accountlinking/recipe.py +++ b/supertokens_python/recipe/accountlinking/recipe.py @@ -147,7 +147,7 @@ def init( ) -> Callable[..., AccountLinkingRecipe]: from supertokens_python.plugins import OverrideMap, apply_plugins - cofnfig = AccountLinkingConfig( + config = AccountLinkingConfig( on_account_linked=on_account_linked, should_do_automatic_account_linking=should_do_automatic_account_linking, override=override, @@ -160,7 +160,7 @@ def func(app_info: AppInfo, plugins: List[OverrideMap]): app_info=app_info, config=apply_plugins( recipe_id=AccountLinkingRecipe.recipe_id, - config=cofnfig, + config=config, plugins=plugins, ), ) diff --git a/tests/frontendIntegration/django2x/polls/views.py b/tests/frontendIntegration/django2x/polls/views.py index 0859a964e..a0128eb48 100644 --- a/tests/frontendIntegration/django2x/polls/views.py +++ b/tests/frontendIntegration/django2x/polls/views.py @@ -36,6 +36,7 @@ from supertokens_python.normalised_url_path import NormalisedURLPath from supertokens_python.querier import Querier from supertokens_python.recipe import session +from supertokens_python.recipe.accountlinking.recipe import AccountLinkingRecipe from supertokens_python.recipe.jwt.recipe import JWTRecipe from supertokens_python.recipe.multitenancy.recipe import MultitenancyRecipe from supertokens_python.recipe.oauth2provider.recipe import OAuth2ProviderRecipe @@ -607,6 +608,7 @@ def reinitialize(request: HttpRequest): OpenIdRecipe.reset() OAuth2ProviderRecipe.reset() JWTRecipe.reset() + AccountLinkingRecipe.reset() config( data["coreUrl"], last_set_enable_anti_csrf, @@ -627,6 +629,7 @@ def setup_st(request: HttpRequest): OpenIdRecipe.reset() OAuth2ProviderRecipe.reset() JWTRecipe.reset() + AccountLinkingRecipe.reset() config( core_url=data["coreUrl"], enable_anti_csrf=data.get("enableAntiCsrf"), diff --git a/tests/frontendIntegration/django3x/polls/views.py b/tests/frontendIntegration/django3x/polls/views.py index 5c06e1409..1f74a5ef9 100644 --- a/tests/frontendIntegration/django3x/polls/views.py +++ b/tests/frontendIntegration/django3x/polls/views.py @@ -34,6 +34,7 @@ from supertokens_python.normalised_url_path import NormalisedURLPath from supertokens_python.querier import Querier from supertokens_python.recipe import session +from supertokens_python.recipe.accountlinking.recipe import AccountLinkingRecipe from supertokens_python.recipe.jwt.recipe import JWTRecipe from supertokens_python.recipe.multitenancy.recipe import MultitenancyRecipe from supertokens_python.recipe.oauth2provider.recipe import OAuth2ProviderRecipe @@ -608,6 +609,7 @@ async def reinitialize(request: HttpRequest): OpenIdRecipe.reset() OAuth2ProviderRecipe.reset() JWTRecipe.reset() + AccountLinkingRecipe.reset() config( data["coreUrl"], last_set_enable_anti_csrf, @@ -628,6 +630,7 @@ async def setup_st(request: HttpRequest): OpenIdRecipe.reset() OAuth2ProviderRecipe.reset() JWTRecipe.reset() + AccountLinkingRecipe.reset() config( core_url=data["coreUrl"], enable_anti_csrf=data.get("enableAntiCsrf"), diff --git a/tests/frontendIntegration/drf_async/polls/views.py b/tests/frontendIntegration/drf_async/polls/views.py index e69f34107..28513fbfa 100644 --- a/tests/frontendIntegration/drf_async/polls/views.py +++ b/tests/frontendIntegration/drf_async/polls/views.py @@ -40,6 +40,7 @@ from supertokens_python.normalised_url_path import NormalisedURLPath from supertokens_python.querier import Querier from supertokens_python.recipe import session +from supertokens_python.recipe.accountlinking.recipe import AccountLinkingRecipe from supertokens_python.recipe.jwt.recipe import JWTRecipe from supertokens_python.recipe.multitenancy.recipe import MultitenancyRecipe from supertokens_python.recipe.oauth2provider.recipe import OAuth2ProviderRecipe @@ -671,6 +672,7 @@ async def reinitialize(request: Request): # type: ignore OpenIdRecipe.reset() OAuth2ProviderRecipe.reset() JWTRecipe.reset() + AccountLinkingRecipe.reset() config( data["coreUrl"], # type: ignore last_set_enable_anti_csrf, @@ -693,6 +695,7 @@ async def setup_st(request: HttpRequest): # type: ignore OpenIdRecipe.reset() OAuth2ProviderRecipe.reset() JWTRecipe.reset() + AccountLinkingRecipe.reset() config( core_url=data["coreUrl"], enable_anti_csrf=data.get("enableAntiCsrf"), diff --git a/tests/frontendIntegration/drf_sync/polls/views.py b/tests/frontendIntegration/drf_sync/polls/views.py index 932fe71bc..4fb1a1001 100644 --- a/tests/frontendIntegration/drf_sync/polls/views.py +++ b/tests/frontendIntegration/drf_sync/polls/views.py @@ -39,6 +39,7 @@ from supertokens_python.normalised_url_path import NormalisedURLPath from supertokens_python.querier import Querier from supertokens_python.recipe import session +from supertokens_python.recipe.accountlinking.recipe import AccountLinkingRecipe from supertokens_python.recipe.jwt.recipe import JWTRecipe from supertokens_python.recipe.multitenancy.recipe import MultitenancyRecipe from supertokens_python.recipe.oauth2provider.recipe import OAuth2ProviderRecipe @@ -673,6 +674,7 @@ def reinitialize(request: Request): # type: ignore OpenIdRecipe.reset() OAuth2ProviderRecipe.reset() JWTRecipe.reset() + AccountLinkingRecipe.reset() config( data["coreUrl"], # type: ignore last_set_enable_anti_csrf, @@ -695,6 +697,7 @@ def setup_st(request: HttpRequest): # type: ignore OpenIdRecipe.reset() OAuth2ProviderRecipe.reset() JWTRecipe.reset() + AccountLinkingRecipe.reset() config( core_url=data["coreUrl"], enable_anti_csrf=data.get("enableAntiCsrf"), diff --git a/tests/frontendIntegration/fastapi-server/app.py b/tests/frontendIntegration/fastapi-server/app.py index a7f02a383..ad8a46194 100644 --- a/tests/frontendIntegration/fastapi-server/app.py +++ b/tests/frontendIntegration/fastapi-server/app.py @@ -38,6 +38,7 @@ from supertokens_python.normalised_url_path import NormalisedURLPath from supertokens_python.querier import Querier from supertokens_python.recipe import session +from supertokens_python.recipe.accountlinking.recipe import AccountLinkingRecipe from supertokens_python.recipe.jwt.recipe import JWTRecipe from supertokens_python.recipe.multitenancy.recipe import MultitenancyRecipe from supertokens_python.recipe.oauth2provider.recipe import OAuth2ProviderRecipe @@ -657,6 +658,7 @@ async def reinitialize(request: Request): OpenIdRecipe.reset() OAuth2ProviderRecipe.reset() JWTRecipe.reset() + AccountLinkingRecipe.reset() config( json["coreUrl"], last_set_enable_anti_csrf, @@ -678,6 +680,7 @@ async def setup_st(request: Request): OpenIdRecipe.reset() OAuth2ProviderRecipe.reset() JWTRecipe.reset() + AccountLinkingRecipe.reset() config( core_url=json["coreUrl"], enable_anti_csrf=json.get("enableAntiCsrf"), diff --git a/tests/frontendIntegration/flask-server/app.py b/tests/frontendIntegration/flask-server/app.py index 3059bc3a5..c3134208e 100644 --- a/tests/frontendIntegration/flask-server/app.py +++ b/tests/frontendIntegration/flask-server/app.py @@ -38,6 +38,7 @@ from supertokens_python.normalised_url_path import NormalisedURLPath from supertokens_python.querier import Querier from supertokens_python.recipe import session +from supertokens_python.recipe.accountlinking.recipe import AccountLinkingRecipe from supertokens_python.recipe.jwt.recipe import JWTRecipe from supertokens_python.recipe.multitenancy.recipe import MultitenancyRecipe from supertokens_python.recipe.oauth2provider.recipe import OAuth2ProviderRecipe @@ -282,16 +283,6 @@ def config( ) -core_host = os.environ.get("SUPERTOKENS_CORE_HOST", "localhost") -core_port = os.environ.get("SUPERTOKENS_CORE_PORT", "3567") -config( - core_url=f"http://{core_host}:{core_port}", - enable_anti_csrf=True, - enable_jwt=False, - jwt_property_name=None, -) - - @app.route("/index.html", methods=["GET"]) # type: ignore def send_file(): return render_template("index.html") @@ -674,6 +665,7 @@ def reinitialize(): OpenIdRecipe.reset() OAuth2ProviderRecipe.reset() JWTRecipe.reset() + AccountLinkingRecipe.reset() config( json["coreUrl"], last_set_enable_anti_csrf, # type: ignore @@ -695,6 +687,7 @@ async def setup_st(): # type: ignore OpenIdRecipe.reset() OAuth2ProviderRecipe.reset() JWTRecipe.reset() + AccountLinkingRecipe.reset() config( core_url=json["coreUrl"], enable_anti_csrf=json.get("enableAntiCsrf"), # type: ignore @@ -733,5 +726,14 @@ def handle_exception(e): # type: ignore return Response(str(e), status=500) # type: ignore +core_host = os.environ.get("SUPERTOKENS_CORE_HOST", "localhost") +core_port = os.environ.get("SUPERTOKENS_CORE_PORT", "3567") +config( + core_url=f"http://{core_host}:{core_port}", + enable_anti_csrf=True, + enable_jwt=False, + jwt_property_name=None, +) + if __name__ == "__main__": app.run(host="0.0.0.0", port=int(get_app_port()), threaded=True) From b1d66bca899f8f49246a7183284daa85c5c3ecda Mon Sep 17 00:00:00 2001 From: Namit Nathwani Date: Thu, 3 Jul 2025 12:00:01 +0530 Subject: [PATCH 18/37] fix: webauthn email ingredient types, bump pyright --- dev-requirements.txt | 2 +- .../recipe/emailpassword/api/utils.py | 6 ++++-- .../claim_base_classes/primitive_array_claim.py | 4 ++-- .../recipe/thirdparty/providers/apple.py | 4 ++-- .../recipe/thirdparty/providers/custom.py | 4 ++-- .../recipe/webauthn/interfaces/api.py | 4 ++-- supertokens_python/recipe/webauthn/types/config.py | 13 +++---------- tests/auth-react/django3x/mysite/utils.py | 8 ++++---- tests/auth-react/fastapi-server/app.py | 8 ++++---- tests/auth-react/flask-server/app.py | 8 ++++---- tests/test-server/utils.py | 14 +++++++++----- tests/test_logger.py | 6 +++--- 12 files changed, 40 insertions(+), 41 deletions(-) diff --git a/dev-requirements.txt b/dev-requirements.txt index e39220586..2dddb9f97 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -11,7 +11,7 @@ pdoc3==0.11.0 pre-commit==3.5.0 pyfakefs==5.7.4 pylint==3.2.7 -pyright==1.1.393 +pyright==1.1.402 python-dotenv==1.0.1 pytest==8.3.3 pytest-asyncio==0.24.0 diff --git a/supertokens_python/recipe/emailpassword/api/utils.py b/supertokens_python/recipe/emailpassword/api/utils.py index 06af69e9c..d35d96234 100644 --- a/supertokens_python/recipe/emailpassword/api/utils.py +++ b/supertokens_python/recipe/emailpassword/api/utils.py @@ -13,7 +13,7 @@ # under the License. from __future__ import annotations -from typing import Any, Dict, List, Union +from typing import Any, Dict, List, Union, cast from supertokens_python.exceptions import raise_bad_input_exception from supertokens_python.recipe.emailpassword.constants import ( @@ -78,7 +78,9 @@ async def validate_form_fields_or_throw_error( form_fields: List[FormField] = [] - form_fields_list_raw: List[Dict[str, Any]] = form_fields_raw + form_fields_list_raw: List[Dict[str, Any]] = cast( + List[Dict[str, Any]], form_fields_raw + ) for current_form_field in form_fields_list_raw: if ( "id" not in current_form_field diff --git a/supertokens_python/recipe/session/claim_base_classes/primitive_array_claim.py b/supertokens_python/recipe/session/claim_base_classes/primitive_array_claim.py index 77a2ac061..0845680c1 100644 --- a/supertokens_python/recipe/session/claim_base_classes/primitive_array_claim.py +++ b/supertokens_python/recipe/session/claim_base_classes/primitive_array_claim.py @@ -12,7 +12,7 @@ # License for the specific language governing permissions and limitations # under the License. -from typing import Any, Callable, Dict, Generic, List, Optional, TypeVar, Union +from typing import Any, Callable, Dict, Generic, List, Optional, TypeVar, Union, cast from supertokens_python.types import MaybeAwaitable, RecipeUserId from supertokens_python.utils import get_timestamp_ms @@ -105,7 +105,7 @@ async def _validate( # Doing this to ensure same code in the upcoming steps irrespective of # whether self.val is Primitive or PrimitiveList - vals: List[_T] = val if isinstance(val, list) else [val] + vals: List[_T] = cast(List[_T], val if isinstance(val, list) else [val]) claim_val_set = set(claim_val) if is_include and not is_include_any: diff --git a/supertokens_python/recipe/thirdparty/providers/apple.py b/supertokens_python/recipe/thirdparty/providers/apple.py index 5c344b980..5466335cc 100644 --- a/supertokens_python/recipe/thirdparty/providers/apple.py +++ b/supertokens_python/recipe/thirdparty/providers/apple.py @@ -16,7 +16,7 @@ import json from re import sub from time import time -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, cast from jwt import encode # type: ignore @@ -106,7 +106,7 @@ async def get_user_info( if isinstance(user, str): user_dict = json.loads(user) elif isinstance(user, dict): - user_dict = user + user_dict = cast(Dict[str, Any], user) else: return response diff --git a/supertokens_python/recipe/thirdparty/providers/custom.py b/supertokens_python/recipe/thirdparty/providers/custom.py index af78ea1e2..c97de1005 100644 --- a/supertokens_python/recipe/thirdparty/providers/custom.py +++ b/supertokens_python/recipe/thirdparty/providers/custom.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Union, cast from urllib.parse import parse_qs, urlencode, urlparse import pkce @@ -67,7 +67,7 @@ def access_field(obj: Any, key: str) -> Any: key_parts = key.split(".") for part in key_parts: if isinstance(obj, dict): - obj = obj.get(part) # type: ignore + obj = cast(Dict[str, Any], obj).get(part) else: return None diff --git a/supertokens_python/recipe/webauthn/interfaces/api.py b/supertokens_python/recipe/webauthn/interfaces/api.py index fed8d64aa..8e52c15f2 100644 --- a/supertokens_python/recipe/webauthn/interfaces/api.py +++ b/supertokens_python/recipe/webauthn/interfaces/api.py @@ -12,8 +12,8 @@ # License for the specific language governing permissions and limitations # under the License. -from abc import ABC, abstractmethod -from typing import List, Literal, Optional, TypedDict, Union +from abc import abstractmethod +from typing import TYPE_CHECKING, List, Literal, Optional, TypedDict, Union from typing_extensions import NotRequired, Unpack diff --git a/supertokens_python/recipe/webauthn/types/config.py b/supertokens_python/recipe/webauthn/types/config.py index 2fd3724fb..d2132cf1d 100644 --- a/supertokens_python/recipe/webauthn/types/config.py +++ b/supertokens_python/recipe/webauthn/types/config.py @@ -17,6 +17,7 @@ from typing import Optional, Protocol, TypeVar, Union, runtime_checkable from supertokens_python.framework import BaseRequest +from supertokens_python.ingredients.emaildelivery import EmailDeliveryIngredient from supertokens_python.ingredients.emaildelivery.types import ( EmailDeliveryConfig, EmailDeliveryConfigWithService, @@ -192,12 +193,7 @@ class WebauthnConfig(BaseConfig[RecipeInterface, APIInterface]): get_relying_party_id: Optional[Union[str, GetRelyingPartyId]] = None get_relying_party_name: Optional[Union[str, GetRelyingPartyName]] = None get_origin: Optional[GetOrigin] = None - email_delivery: Optional[ - Union[ - EmailDeliveryConfig[TypeWebauthnEmailDeliveryInput], - EmailDeliveryConfigWithService[TypeWebauthnEmailDeliveryInput], - ] - ] = None + email_delivery: Optional[EmailDeliveryConfig[TypeWebauthnEmailDeliveryInput]] = None validate_email_address: Optional[ValidateEmailAddress] = None @@ -211,8 +207,5 @@ class NormalisedWebauthnConfig(BaseNormalisedConfig[RecipeInterface, APIInterfac class WebauthnIngredients(CamelCaseBaseModel): email_delivery: Optional[ - Union[ - EmailDeliveryConfig[TypeWebauthnEmailDeliveryInput], - EmailDeliveryConfigWithService[TypeWebauthnEmailDeliveryInput], - ] + EmailDeliveryIngredient[TypeWebauthnEmailDeliveryInput] ] = None diff --git a/tests/auth-react/django3x/mysite/utils.py b/tests/auth-react/django3x/mysite/utils.py index 42cfe682d..ebb40e7b5 100644 --- a/tests/auth-react/django3x/mysite/utils.py +++ b/tests/auth-react/django3x/mysite/utils.py @@ -8,7 +8,7 @@ from supertokens_python import InputAppInfo, Supertokens, SupertokensConfig, init from supertokens_python.framework.request import BaseRequest from supertokens_python.ingredients.emaildelivery.types import ( - EmailDeliveryConfigWithService, + EmailDeliveryConfig, EmailDeliveryInterface, ) from supertokens_python.recipe import ( @@ -999,9 +999,9 @@ async def resync_session_and_fetch_mfa_info_put( "id": "webauthn", "init": webauthn.init( config=WebauthnConfig( - email_delivery=EmailDeliveryConfigWithService[ - TypeWebauthnEmailDeliveryInput - ](service=CustomWebwuthnEmailService()) # type: ignore + email_delivery=EmailDeliveryConfig[TypeWebauthnEmailDeliveryInput]( + service=CustomWebwuthnEmailService() + ) ) ), }, diff --git a/tests/auth-react/fastapi-server/app.py b/tests/auth-react/fastapi-server/app.py index c22a5e5b2..84a004f46 100644 --- a/tests/auth-react/fastapi-server/app.py +++ b/tests/auth-react/fastapi-server/app.py @@ -41,7 +41,7 @@ from supertokens_python.framework.fastapi import get_middleware from supertokens_python.framework.request import BaseRequest from supertokens_python.ingredients.emaildelivery.types import ( - EmailDeliveryConfigWithService, + EmailDeliveryConfig, EmailDeliveryInterface, ) from supertokens_python.recipe import ( @@ -1100,9 +1100,9 @@ async def resync_session_and_fetch_mfa_info_put( "id": "webauthn", "init": webauthn.init( config=WebauthnConfig( - email_delivery=EmailDeliveryConfigWithService[ - TypeWebauthnEmailDeliveryInput - ](service=CustomWebwuthnEmailService()) # type: ignore + email_delivery=EmailDeliveryConfig[TypeWebauthnEmailDeliveryInput]( + service=CustomWebwuthnEmailService() + ) ) ), }, diff --git a/tests/auth-react/flask-server/app.py b/tests/auth-react/flask-server/app.py index 6eff7f65d..14e9ae5db 100644 --- a/tests/auth-react/flask-server/app.py +++ b/tests/auth-react/flask-server/app.py @@ -32,7 +32,7 @@ from supertokens_python.framework.flask.flask_middleware import Middleware from supertokens_python.framework.request import BaseRequest from supertokens_python.ingredients.emaildelivery.types import ( - EmailDeliveryConfigWithService, + EmailDeliveryConfig, EmailDeliveryInterface, ) from supertokens_python.recipe import ( @@ -1079,9 +1079,9 @@ async def resync_session_and_fetch_mfa_info_put( "id": "webauthn", "init": webauthn.init( config=WebauthnConfig( - email_delivery=EmailDeliveryConfigWithService[ - TypeWebauthnEmailDeliveryInput - ](service=CustomWebwuthnEmailService()) # type: ignore + email_delivery=EmailDeliveryConfig[TypeWebauthnEmailDeliveryInput]( + service=CustomWebwuthnEmailService() + ) ) ), }, diff --git a/tests/test-server/utils.py b/tests/test-server/utils.py index f5c959527..7008fef37 100644 --- a/tests/test-server/utils.py +++ b/tests/test-server/utils.py @@ -1,4 +1,4 @@ -from typing import Any, Dict +from typing import Any, Dict, cast from override_logging import log_override_event # pylint: disable=import-error from supertokens_python.recipe.emailverification import EmailVerificationClaim @@ -35,10 +35,14 @@ def fetch_value( }, ) - ret_val: Any = user_context.get("st-stub-arr-value") or ( - values[0] - if isinstance(values, list) and isinstance(values[0], list) - else values + ret_val: Any = cast( + Any, + user_context.get("st-stub-arr-value") + or ( + values[0] + if isinstance(values, list) and isinstance(values[0], list) + else values + ), ) log_override_event(f"claim-{key}.fetchValue", "RES", ret_val) diff --git a/tests/test_logger.py b/tests/test_logger.py index 8bf9993fc..78bc3ab32 100644 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -42,11 +42,11 @@ def test_1_json_msg_format(self, datetime_mock: MagicMock): enable_debug_logging() datetime_mock.now.return_value = real_datetime(2000, 1, 1, tzinfo=timezone.utc) - with self.assertLogs(level="DEBUG") as captured: + with self.assertLogs(level="DEBUG") as captured: # type: ignore log_debug_message("API replied with status 200") - record = captured.records[0] - out = json.loads(record.msg) + record = captured.records[0] # type: ignore + out = json.loads(record.msg) # type: ignore assert out == { "t": "2000-01-01T00:00:00+00Z", From 5a23e04d235230e8a8707ca332f52c4a3613eef5 Mon Sep 17 00:00:00 2001 From: Namit Nathwani Date: Mon, 21 Jul 2025 16:35:41 +0530 Subject: [PATCH 19/37] update: address TODOs --- supertokens_python/plugins.py | 23 +++++++++-------------- supertokens_python/supertokens.py | 17 +++++------------ 2 files changed, 14 insertions(+), 26 deletions(-) diff --git a/supertokens_python/plugins.py b/supertokens_python/plugins.py index d2bfe62a4..04605c9e7 100644 --- a/supertokens_python/plugins.py +++ b/supertokens_python/plugins.py @@ -1,5 +1,4 @@ from collections import deque -from dataclasses import dataclass from typing import ( TYPE_CHECKING, Any, @@ -84,11 +83,15 @@ APIInterfaceType = TypeVar("APIInterfaceType", bound=BaseAPIInterface) +class RecipeInitRequiredFunction(Protocol): + def __call__(self, sdk_version: str) -> bool: ... + + class RecipePluginOverride: - # TODO: Define a base class for the Config/RecipeInterface/ApiInterface classes, and use it here - functions: Optional[Callable[[Any], Any]] - apis: Optional[Callable[[Any], Any]] + functions: Optional[Callable[[BaseRecipeInterface], BaseRecipeInterface]] + apis: Optional[Callable[[BaseAPIInterface], BaseAPIInterface]] config: Optional[Callable[[Any], Any]] + recipe_init_required: Optional[Union[bool, RecipeInitRequiredFunction]] = None class PluginRouteHandlerResponse(CamelCaseBaseModel): @@ -205,13 +208,7 @@ class SuperTokensPluginBase(CamelCaseBaseModel): class SuperTokensPlugin(SuperTokensPluginBase): init: Optional[SuperTokensPluginInit] = None dependencies: Optional[SuperTokensPluginDependencies] = None - # TODO: Add types for recipes - # overrideMap?: { - # [recipeId in keyof AllRecipeConfigs]?: RecipePluginOverride & { - # recipeInitRequired?: boolean | ((sdkVersion: string) => boolean); - # }; - # }; - override_map: Optional[OverrideMap] = None + override_map: Optional[Dict[str, RecipePluginOverride]] = None route_handlers: Optional[ Union[List[PluginRouteHandler], PluginRouteHandlerFunction] ] = None @@ -374,9 +371,7 @@ def api_override(original_implementation: APIInterfaceType) -> APIInterfaceType: return config -# TODO: Figure out import cycles and convert to a Pydantic BaseModel -@dataclass -class LoadPluginsResponse: +class LoadPluginsResponse(CamelCaseBaseModel): public_config: "SupertokensPublicConfig" processed_plugins: List[SuperTokensPublicPlugin] plugin_route_handlers: List[PluginRouteHandler] diff --git a/supertokens_python/supertokens.py b/supertokens_python/supertokens.py index 31f2e22e0..6a649d377 100644 --- a/supertokens_python/supertokens.py +++ b/supertokens_python/supertokens.py @@ -14,7 +14,6 @@ from __future__ import annotations -from dataclasses import dataclass from os import environ from typing import ( TYPE_CHECKING, @@ -36,6 +35,7 @@ get_maybe_none_as_str, log_debug_message, ) +from supertokens_python.types.response import CamelCaseBaseModel from .constants import FDI_KEY_HEADER, RID_KEY_HEADER, USER_COUNT from .exceptions import SuperTokensError @@ -111,16 +111,11 @@ def __init__( self.disable_core_call_cache = disable_core_call_cache -@dataclass -class SupertokensExperimentalConfig: +class SupertokensExperimentalConfig(CamelCaseBaseModel): plugins: Optional[List[SuperTokensPlugin]] = None -# TODO: Change to Pydantic? - - -@dataclass -class SupertokensPublicConfig: +class SupertokensPublicConfig(CamelCaseBaseModel): """ Public properties received as input to the `Supertokens.init` function. """ @@ -133,7 +128,6 @@ class SupertokensPublicConfig: debug: Optional[bool] -@dataclass class SupertokensInputConfig(SupertokensPublicConfig): """ Various properties received as input to the `Supertokens.init` function. @@ -324,10 +318,9 @@ def __init__( debug=debug, experimental=experimental, ) - # TODO: Probably just want to define this directly and use it - # Can build a input config from the final public config and the additional props input_public_config = input_config.get_public_config() - processed_public_config = input_public_config + # Use the input public config by default if no plugins provided + processed_public_config: SupertokensPublicConfig = input_public_config self.plugin_route_handlers = [] override_maps: List[OverrideMap] = [] From 397e5920360d42aebce9a3b2c8bcdb1a99ad1cef Mon Sep 17 00:00:00 2001 From: Namit Nathwani Date: Mon, 21 Jul 2025 18:59:12 +0530 Subject: [PATCH 20/37] feat: setup common util for config override normalisation --- .../recipe/accountlinking/utils.py | 7 +--- supertokens_python/recipe/dashboard/utils.py | 11 +----- .../recipe/emailpassword/utils.py | 10 ++--- .../recipe/emailverification/utils.py | 10 ++--- supertokens_python/recipe/jwt/utils.py | 10 ++--- .../recipe/multifactorauth/utils.py | 10 ++--- .../recipe/multitenancy/utils.py | 10 ++--- .../recipe/oauth2provider/utils.py | 10 ++--- supertokens_python/recipe/openid/utils.py | 10 ++--- .../recipe/passwordless/utils.py | 10 ++--- supertokens_python/recipe/session/utils.py | 10 ++--- supertokens_python/recipe/thirdparty/utils.py | 10 ++--- supertokens_python/recipe/totp/utils.py | 10 ++--- .../recipe/usermetadata/utils.py | 10 ++--- supertokens_python/recipe/userroles/utils.py | 10 ++--- supertokens_python/recipe/webauthn/utils.py | 10 ++--- supertokens_python/types/config.py | 37 +++++++++++++++++++ 17 files changed, 83 insertions(+), 112 deletions(-) diff --git a/supertokens_python/recipe/accountlinking/utils.py b/supertokens_python/recipe/accountlinking/utils.py index 98f9beb2b..8a496de2a 100644 --- a/supertokens_python/recipe/accountlinking/utils.py +++ b/supertokens_python/recipe/accountlinking/utils.py @@ -62,13 +62,10 @@ def validate_and_normalise_user_input( ) -> NormalisedAccountLinkingConfig: global _did_use_default_should_do_automatic_account_linking - override_config: NormalisedAccountLinkingOverrideConfig = ( - NormalisedAccountLinkingOverrideConfig() + override_config = NormalisedAccountLinkingOverrideConfig.from_input_config( + override_config=config.override ) - if config.override is not None and config.override.functions is not None: - override_config.functions = config.override.functions - _did_use_default_should_do_automatic_account_linking = ( config.should_do_automatic_account_linking is None ) diff --git a/supertokens_python/recipe/dashboard/utils.py b/supertokens_python/recipe/dashboard/utils.py index 33761851d..f06ed695f 100644 --- a/supertokens_python/recipe/dashboard/utils.py +++ b/supertokens_python/recipe/dashboard/utils.py @@ -97,17 +97,10 @@ class NormalisedDashboardConfig(BaseNormalisedConfig[RecipeInterface, APIInterfa def validate_and_normalise_user_input( config: DashboardConfig, ) -> NormalisedDashboardConfig: - override_config: NormalisedDashboardOverrideConfig = ( - NormalisedDashboardOverrideConfig() + override_config = NormalisedDashboardOverrideConfig.from_input_config( + override_config=config.override ) - if config.override is not None: - if config.override.functions is not None: - override_config.functions = config.override.functions - - if config.override.apis is not None: - override_config.apis = config.override.apis - if config.api_key is not None and config.admins is not None: log_debug_message( "User Dashboard: Providing 'admins' has no effect when using an api key." diff --git a/supertokens_python/recipe/emailpassword/utils.py b/supertokens_python/recipe/emailpassword/utils.py index 36953b309..30ff84257 100644 --- a/supertokens_python/recipe/emailpassword/utils.py +++ b/supertokens_python/recipe/emailpassword/utils.py @@ -246,13 +246,9 @@ def validate_and_normalise_user_input( # NOTE: We don't need to check the instance of sign_up_feature and override # as they will always be either None or the specified type. - override_config = NormalisedEmailPasswordOverrideConfig() - if config.override is not None: - if config.override.functions is not None: - override_config.functions = config.override.functions - - if config.override.apis is not None: - override_config.apis = config.override.apis + override_config = NormalisedEmailPasswordOverrideConfig.from_input_config( + override_config=config.override + ) sign_up_feature = config.sign_up_feature if sign_up_feature is None: diff --git a/supertokens_python/recipe/emailverification/utils.py b/supertokens_python/recipe/emailverification/utils.py index 7875b1f69..28a7c70a4 100644 --- a/supertokens_python/recipe/emailverification/utils.py +++ b/supertokens_python/recipe/emailverification/utils.py @@ -91,13 +91,9 @@ def get_email_delivery_config() -> EmailDeliveryConfigWithService[ override = None return EmailDeliveryConfigWithService(email_service, override=override) - override_config = NormalisedEmailVerificationOverrideConfig() - if config.override is not None: - if config.override.functions is not None: - override_config.functions = config.override.functions - - if config.override.apis is not None: - override_config.apis = config.override.apis + override_config = NormalisedEmailVerificationOverrideConfig.from_input_config( + override_config=config.override + ) return NormalisedEmailVerificationConfig( mode=config.mode, diff --git a/supertokens_python/recipe/jwt/utils.py b/supertokens_python/recipe/jwt/utils.py index 37d36b8ff..959427fc5 100644 --- a/supertokens_python/recipe/jwt/utils.py +++ b/supertokens_python/recipe/jwt/utils.py @@ -39,13 +39,9 @@ class NormalisedJWTConfig(BaseNormalisedConfig[RecipeInterface, APIInterface]): def validate_and_normalise_user_input(config: JWTConfig): - override_config = NormalisedJWTOverrideConfig() - if config.override is not None: - if config.override.functions is not None: - override_config.functions = config.override.functions - - if config.override.apis is not None: - override_config.apis = config.override.apis + override_config = NormalisedJWTOverrideConfig.from_input_config( + override_config=config.override + ) jwt_validity_seconds = config.jwt_validity_seconds diff --git a/supertokens_python/recipe/multifactorauth/utils.py b/supertokens_python/recipe/multifactorauth/utils.py index f7b328004..a5990360d 100644 --- a/supertokens_python/recipe/multifactorauth/utils.py +++ b/supertokens_python/recipe/multifactorauth/utils.py @@ -51,13 +51,9 @@ def validate_and_normalise_user_input( if config.first_factors is not None and len(config.first_factors) == 0: raise ValueError("'first_factors' can be either None or a non-empty list") - override_config = NormalisedMultiFactorAuthOverrideConfig() - if config.override is not None: - if config.override.functions is not None: - override_config.functions = config.override.functions - - if config.override.apis is not None: - override_config.apis = config.override.apis + override_config = NormalisedMultiFactorAuthOverrideConfig.from_input_config( + override_config=config.override + ) return NormalisedMultiFactorAuthConfig( first_factors=config.first_factors, diff --git a/supertokens_python/recipe/multitenancy/utils.py b/supertokens_python/recipe/multitenancy/utils.py index 8aaeca3c6..8bb98ee8f 100644 --- a/supertokens_python/recipe/multitenancy/utils.py +++ b/supertokens_python/recipe/multitenancy/utils.py @@ -83,13 +83,9 @@ class NormalisedMultitenancyConfig(BaseNormalisedConfig[RecipeInterface, APIInte def validate_and_normalise_user_input( config: MultitenancyConfig, ) -> NormalisedMultitenancyConfig: - override_config = NormalisedMultitenancyOverrideConfig() - if config.override is not None: - if config.override.functions is not None: - override_config.functions = config.override.functions - - if config.override.apis is not None: - override_config.apis = config.override.apis + override_config = NormalisedMultitenancyOverrideConfig.from_input_config( + override_config=config.override + ) return NormalisedMultitenancyConfig( get_allowed_domains_for_tenant_id=config.get_allowed_domains_for_tenant_id, diff --git a/supertokens_python/recipe/oauth2provider/utils.py b/supertokens_python/recipe/oauth2provider/utils.py index 2c6107241..a4e3899a3 100644 --- a/supertokens_python/recipe/oauth2provider/utils.py +++ b/supertokens_python/recipe/oauth2provider/utils.py @@ -37,12 +37,8 @@ class NormalisedOAuth2ProviderConfig( def validate_and_normalise_user_input(config: OAuth2ProviderConfig): - override_config = NormalisedOAuth2ProviderOverrideConfig() - if config.override is not None: - if config.override.functions is not None: - override_config.functions = config.override.functions - - if config.override.apis is not None: - override_config.apis = config.override.apis + override_config = NormalisedOAuth2ProviderOverrideConfig.from_input_config( + override_config=config.override + ) return NormalisedOAuth2ProviderConfig(override=override_config) diff --git a/supertokens_python/recipe/openid/utils.py b/supertokens_python/recipe/openid/utils.py index ae0887ad7..e9f593890 100644 --- a/supertokens_python/recipe/openid/utils.py +++ b/supertokens_python/recipe/openid/utils.py @@ -61,13 +61,9 @@ def validate_and_normalise_user_input( "The path of the issuer URL must be equal to the apiBasePath. The default value is /auth" ) - override_config = NormalisedOpenIdOverrideConfig() - if config.override is not None: - if config.override.functions is not None: - override_config.functions = config.override.functions - - if config.override.apis is not None: - override_config.apis = config.override.apis + override_config = NormalisedOpenIdOverrideConfig.from_input_config( + override_config=config.override + ) return NormalisedOpenIdConfig( issuer_domain=issuer_domain, diff --git a/supertokens_python/recipe/passwordless/utils.py b/supertokens_python/recipe/passwordless/utils.py index ed1b15365..e362f89e5 100644 --- a/supertokens_python/recipe/passwordless/utils.py +++ b/supertokens_python/recipe/passwordless/utils.py @@ -177,13 +177,9 @@ def validate_and_normalise_user_input( app_info: AppInfo, config: PasswordlessConfig, ) -> NormalisedPasswordlessConfig: - override_config = NormalisedPasswordlessOverrideConfig() - if config.override is not None: - if config.override.functions is not None: - override_config.functions = config.override.functions - - if config.override.apis is not None: - override_config.apis = config.override.apis + override_config = NormalisedPasswordlessOverrideConfig.from_input_config( + override_config=config.override + ) def get_email_delivery_config() -> EmailDeliveryConfigWithService[ PasswordlessLoginEmailTemplateVars diff --git a/supertokens_python/recipe/session/utils.py b/supertokens_python/recipe/session/utils.py index 022cd1e7a..53d5ed550 100644 --- a/supertokens_python/recipe/session/utils.py +++ b/supertokens_python/recipe/session/utils.py @@ -522,13 +522,9 @@ def anti_csrf_function( if jwks_refresh_interval_sec is None: jwks_refresh_interval_sec = 4 * 3600 # 4 hours - override_config = NormalisedSessionOverrideConfig() - if config.override is not None: - if config.override.functions is not None: - override_config.functions = config.override.functions - - if config.override.apis is not None: - override_config.apis = config.override.apis + override_config = NormalisedSessionOverrideConfig.from_input_config( + override_config=config.override + ) return NormalisedSessionConfig( refresh_token_path=app_info.api_base_path.append( diff --git a/supertokens_python/recipe/thirdparty/utils.py b/supertokens_python/recipe/thirdparty/utils.py index baef7c4d9..45476eb5c 100644 --- a/supertokens_python/recipe/thirdparty/utils.py +++ b/supertokens_python/recipe/thirdparty/utils.py @@ -74,13 +74,9 @@ def validate_and_normalise_user_input( "sign_in_and_up_feature must be an instance of SignInAndUpFeature" ) - override_config = NormalisedThirdPartyOverrideConfig() - if config.override is not None: - if config.override.functions is not None: - override_config.functions = config.override.functions - - if config.override.apis is not None: - override_config.apis = config.override.apis + override_config = NormalisedThirdPartyOverrideConfig.from_input_config( + override_config=config.override + ) return NormalisedThirdPartyConfig( sign_in_and_up_feature=config.sign_in_and_up_feature, diff --git a/supertokens_python/recipe/totp/utils.py b/supertokens_python/recipe/totp/utils.py index 12b49ba08..4c49f662c 100644 --- a/supertokens_python/recipe/totp/utils.py +++ b/supertokens_python/recipe/totp/utils.py @@ -33,13 +33,9 @@ def validate_and_normalise_user_input( default_skew = config.default_skew if config.default_skew is not None else 1 default_period = config.default_period if config.default_period is not None else 30 - override_config = NormalisedTOTPOverrideConfig() - if config.override is not None: - if config.override.functions is not None: - override_config.functions = config.override.functions - - if config.override.apis is not None: - override_config.apis = config.override.apis + override_config = NormalisedTOTPOverrideConfig.from_input_config( + override_config=config.override + ) return NormalisedTOTPConfig( issuer=issuer, diff --git a/supertokens_python/recipe/usermetadata/utils.py b/supertokens_python/recipe/usermetadata/utils.py index 385cdf782..4b611b52e 100644 --- a/supertokens_python/recipe/usermetadata/utils.py +++ b/supertokens_python/recipe/usermetadata/utils.py @@ -51,12 +51,8 @@ def validate_and_normalise_user_input( _app_info: AppInfo, input_config: UserMetadataConfig, ) -> NormalisedUserMetadataConfig: - override_config = NormalisedUserMetadataOverrideConfig() - if input_config.override is not None: - if input_config.override.functions is not None: - override_config.functions = input_config.override.functions - - if input_config.override.apis is not None: - override_config.apis = input_config.override.apis + override_config = NormalisedUserMetadataOverrideConfig.from_input_config( + override_config=input_config.override + ) return NormalisedUserMetadataConfig(override=override_config) diff --git a/supertokens_python/recipe/userroles/utils.py b/supertokens_python/recipe/userroles/utils.py index 8680481ed..e70879f98 100644 --- a/supertokens_python/recipe/userroles/utils.py +++ b/supertokens_python/recipe/userroles/utils.py @@ -50,13 +50,9 @@ def validate_and_normalise_user_input( _app_info: AppInfo, config: UserRolesConfig, ) -> NormalisedUserRolesConfig: - override_config = NormalisedUserRolesOverrideConfig() - if config.override is not None: - if config.override.functions is not None: - override_config.functions = config.override.functions - - if config.override.apis is not None: - override_config.apis = config.override.apis + override_config = NormalisedUserRolesOverrideConfig.from_input_config( + override_config=config.override + ) skip_adding_roles_to_access_token = config.skip_adding_roles_to_access_token if skip_adding_roles_to_access_token is None: diff --git a/supertokens_python/recipe/webauthn/utils.py b/supertokens_python/recipe/webauthn/utils.py index 187494203..4e6b543b6 100644 --- a/supertokens_python/recipe/webauthn/utils.py +++ b/supertokens_python/recipe/webauthn/utils.py @@ -60,13 +60,9 @@ def validate_and_normalise_user_input( config.validate_email_address ) - override_config = NormalisedWebauthnOverrideConfig() - if config.override is not None: - if config.override.functions is not None: - override_config.functions = config.override.functions - - if config.override.apis is not None: - override_config.apis = config.override.apis + override_config = NormalisedWebauthnOverrideConfig.from_input_config( + override_config=config.override + ) def get_email_delivery_config() -> EmailDeliveryConfigWithService[ TypeWebauthnEmailDeliveryInput diff --git a/supertokens_python/types/config.py b/supertokens_python/types/config.py index bf1fc6553..c4793cc48 100644 --- a/supertokens_python/types/config.py +++ b/supertokens_python/types/config.py @@ -33,6 +33,22 @@ class BaseNormalisedOverrideConfigWithoutAPI( lambda original_implementation: original_implementation ) + @classmethod + def from_input_config( + cls, + override_config: Optional[BaseOverrideConfigWithoutAPI[FunctionInterfaceType]], + ) -> "BaseNormalisedOverrideConfigWithoutAPI[FunctionInterfaceType]": + """Create a normalized config from the input config.""" + normalised_config = cls() + + if override_config is None: + return normalised_config + + if override_config.functions is not None: + normalised_config.functions = override_config.functions + + return normalised_config + class BaseOverrideConfig( BaseOverrideConfigWithoutAPI[FunctionInterfaceType], @@ -55,6 +71,27 @@ class BaseNormalisedOverrideConfig( lambda original_implementation: original_implementation ) + @classmethod + def from_input_config( + cls, + override_config: Optional[ + BaseOverrideConfig[FunctionInterfaceType, APIInterfaceType] + ], + ) -> "BaseNormalisedOverrideConfig[FunctionInterfaceType, APIInterfaceType]": # type: ignore + """Create a normalized config from the input config.""" + normalised_config = cls() + + if override_config is None: + return normalised_config + + if override_config.functions is not None: + normalised_config.functions = override_config.functions + + if override_config.apis is not None: + normalised_config.apis = override_config.apis + + return normalised_config + class BaseConfigWithoutAPIOverride(CamelCaseBaseModel, Generic[FunctionInterfaceType]): """Base class for input config of a Recipe without API overrides.""" From 1ff70fdf8f35b1131e92e1c46e5d22f6d18d6e2a Mon Sep 17 00:00:00 2001 From: Namit Nathwani Date: Wed, 23 Jul 2025 11:06:09 +0530 Subject: [PATCH 21/37] fix: circular imports, add tests for config override --- setup.py | 2 +- supertokens_python/__init__.py | 7 +- supertokens_python/constants.py | 2 +- supertokens_python/plugins.py | 99 +++++++-------- supertokens_python/recipe/session/__init__.py | 6 +- .../session/session_request_functions.py | 7 +- supertokens_python/supertokens.py | 120 +++++++++--------- tests/plugins/config.py | 25 ++-- tests/plugins/plugins.py | 28 +++- tests/plugins/test_plugins.py | 47 +++++-- 10 files changed, 194 insertions(+), 149 deletions(-) diff --git a/setup.py b/setup.py index 95fa79288..31aa4d19a 100644 --- a/setup.py +++ b/setup.py @@ -82,7 +82,7 @@ setup( name="supertokens_python", - version="0.30.1", + version="0.31.0", author="SuperTokens", license="Apache 2.0", author_email="team@supertokens.com", diff --git a/supertokens_python/__init__.py b/supertokens_python/__init__.py index 8beee8b68..452dcc286 100644 --- a/supertokens_python/__init__.py +++ b/supertokens_python/__init__.py @@ -17,17 +17,18 @@ from typing_extensions import Literal from supertokens_python.framework.request import BaseRequest -from supertokens_python.recipe_module import RecipeModule from supertokens_python.types import RecipeUserId -from . import supertokens +from . import plugins, supertokens InputAppInfo = supertokens.InputAppInfo Supertokens = supertokens.Supertokens SupertokensConfig = supertokens.SupertokensConfig AppInfo = supertokens.AppInfo SupertokensExperimentalConfig = supertokens.SupertokensExperimentalConfig -RecipeModule = RecipeModule + +SupertokensPublicConfig = supertokens.SupertokensPublicConfig +plugins.LoadPluginsResponse.model_rebuild() def init( diff --git a/supertokens_python/constants.py b/supertokens_python/constants.py index c83ae6a46..34c8a4118 100644 --- a/supertokens_python/constants.py +++ b/supertokens_python/constants.py @@ -15,7 +15,7 @@ from __future__ import annotations SUPPORTED_CDI_VERSIONS = ["5.3"] -VERSION = "0.30.1" +VERSION = "0.31.0" TELEMETRY = "/telemetry" USER_COUNT = "/users/count" USER_DELETE = "/user/remove" diff --git a/supertokens_python/plugins.py b/supertokens_python/plugins.py index 04605c9e7..cc8a34a6e 100644 --- a/supertokens_python/plugins.py +++ b/supertokens_python/plugins.py @@ -21,28 +21,10 @@ from supertokens_python.framework.response import BaseResponse from supertokens_python.logger import log_debug_message from supertokens_python.post_init_callbacks import PostSTInitCallbacks -from supertokens_python.recipe.accountlinking.types import AccountLinkingConfig -from supertokens_python.recipe.dashboard.utils import DashboardConfig -from supertokens_python.recipe.emailpassword.utils import EmailPasswordConfig -from supertokens_python.recipe.emailverification.utils import ( - EmailVerificationConfig, -) -from supertokens_python.recipe.jwt.utils import JWTConfig -from supertokens_python.recipe.multifactorauth.types import MultiFactorAuthConfig -from supertokens_python.recipe.multitenancy.utils import MultitenancyConfig -from supertokens_python.recipe.oauth2provider.utils import OAuth2ProviderConfig -from supertokens_python.recipe.openid.utils import OpenIdConfig -from supertokens_python.recipe.passwordless.utils import PasswordlessConfig from supertokens_python.recipe.session.interfaces import ( SessionClaimValidator, SessionContainer, ) -from supertokens_python.recipe.session.utils import SessionConfig -from supertokens_python.recipe.thirdparty.utils import ThirdPartyConfig -from supertokens_python.recipe.totp.types import TOTPConfig -from supertokens_python.recipe.usermetadata.utils import UserMetadataConfig -from supertokens_python.recipe.userroles.utils import UserRolesConfig -from supertokens_python.recipe.webauthn.types.config import WebauthnConfig from supertokens_python.types import MaybeAwaitable from supertokens_python.types.base import UserContext from supertokens_python.types.config import ( @@ -55,42 +37,60 @@ from supertokens_python.types.response import CamelCaseBaseModel if TYPE_CHECKING: + from supertokens_python.recipe.accountlinking.types import AccountLinkingConfig + from supertokens_python.recipe.dashboard.utils import DashboardConfig + from supertokens_python.recipe.emailpassword.utils import EmailPasswordConfig + from supertokens_python.recipe.emailverification.utils import ( + EmailVerificationConfig, + ) + from supertokens_python.recipe.jwt.utils import JWTConfig + from supertokens_python.recipe.multifactorauth.types import MultiFactorAuthConfig + from supertokens_python.recipe.multitenancy.utils import MultitenancyConfig + from supertokens_python.recipe.oauth2provider.utils import OAuth2ProviderConfig + from supertokens_python.recipe.openid.utils import OpenIdConfig + from supertokens_python.recipe.passwordless.utils import PasswordlessConfig + from supertokens_python.recipe.session.utils import SessionConfig + from supertokens_python.recipe.thirdparty.utils import ThirdPartyConfig + from supertokens_python.recipe.totp.types import TOTPConfig + from supertokens_python.recipe.usermetadata.utils import UserMetadataConfig + from supertokens_python.recipe.userroles.utils import UserRolesConfig + from supertokens_python.recipe.webauthn.types.config import WebauthnConfig from supertokens_python.supertokens import SupertokensPublicConfig -T = TypeVar( - "T", - bound=Union[ - AccountLinkingConfig, - DashboardConfig, - EmailPasswordConfig, - EmailVerificationConfig, - JWTConfig, - MultiFactorAuthConfig, - MultitenancyConfig, - OAuth2ProviderConfig, - OpenIdConfig, - PasswordlessConfig, - SessionConfig, - ThirdPartyConfig, - TOTPConfig, - UserMetadataConfig, - UserRolesConfig, - WebauthnConfig, - ], -) +T = TypeVar("T") + RecipeInterfaceType = TypeVar("RecipeInterfaceType", bound=BaseRecipeInterface) APIInterfaceType = TypeVar("APIInterfaceType", bound=BaseAPIInterface) +RecipeConfigType = Union[ + "AccountLinkingConfig", + "DashboardConfig", + "EmailPasswordConfig", + "EmailVerificationConfig", + "JWTConfig", + "MultiFactorAuthConfig", + "MultitenancyConfig", + "OAuth2ProviderConfig", + "OpenIdConfig", + "PasswordlessConfig", + "SessionConfig", + "ThirdPartyConfig", + "TOTPConfig", + "UserMetadataConfig", + "UserRolesConfig", + "WebauthnConfig", +] + class RecipeInitRequiredFunction(Protocol): def __call__(self, sdk_version: str) -> bool: ... class RecipePluginOverride: - functions: Optional[Callable[[BaseRecipeInterface], BaseRecipeInterface]] - apis: Optional[Callable[[BaseAPIInterface], BaseAPIInterface]] - config: Optional[Callable[[Any], Any]] + functions: Optional[Callable[[BaseRecipeInterface], BaseRecipeInterface]] = None + apis: Optional[Callable[[BaseAPIInterface], BaseAPIInterface]] = None + config: Optional[Callable[[Any], Any]] = None recipe_init_required: Optional[Union[bool, RecipeInitRequiredFunction]] = None @@ -202,13 +202,13 @@ class SuperTokensPluginBase(CamelCaseBaseModel): exports: Optional[Dict[str, Any]] = None -OverrideMap = Dict[str, Any] +OverrideMap = Dict[str, RecipePluginOverride] class SuperTokensPlugin(SuperTokensPluginBase): init: Optional[SuperTokensPluginInit] = None dependencies: Optional[SuperTokensPluginDependencies] = None - override_map: Optional[Dict[str, RecipePluginOverride]] = None + override_map: Optional[OverrideMap] = None route_handlers: Optional[ Union[List[PluginRouteHandler], PluginRouteHandlerFunction] ] = None @@ -281,16 +281,11 @@ def from_plugin(cls, plugin: SuperTokensPlugin) -> "SuperTokensPublicPlugin": ) -class ConfigOverrideBase: - functions: Optional[Callable[[Any], Any]] = None - apis: Optional[Callable[[Any], Any]] = None - - def apply_plugins( recipe_id: str, - config: T, + config: RecipeConfigType, plugins: List[OverrideMap], -) -> T: +) -> RecipeConfigType: if not isinstance(config, (BaseConfig, BaseConfigWithoutAPIOverride)): # type: ignore raise TypeError( f"Expected config to be an instance of BaseConfig or BaseConfigWithoutAPIOverride. {recipe_id=} {config=}" @@ -310,7 +305,7 @@ def default_api_override( if isinstance(config, BaseConfigWithoutAPIOverride): config.override = BaseOverrideConfigWithoutAPI() else: - config.override = BaseOverrideConfig() # type: ignore + config.override = BaseOverrideConfig() # type: ignore - generic type invariance function_overrides = getattr(config.override, "functions", default_fn_override) api_overrides = getattr(config.override, "apis", default_api_override) diff --git a/supertokens_python/recipe/session/__init__.py b/supertokens_python/recipe/session/__init__.py index db5fc865b..077d5bc25 100644 --- a/supertokens_python/recipe/session/__init__.py +++ b/supertokens_python/recipe/session/__init__.py @@ -13,18 +13,20 @@ # under the License. from __future__ import annotations -from typing import Any, Callable, Dict, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, Union from typing_extensions import Literal from supertokens_python.framework import BaseRequest -from supertokens_python.supertokens import RecipeInit from . import exceptions as ex from . import interfaces, utils from .recipe import SessionRecipe from .utils import TokenTransferMethod +if TYPE_CHECKING: + from supertokens_python.supertokens import RecipeInit + InputErrorHandlers = utils.InputErrorHandlers SessionOverrideConfig = utils.SessionOverrideConfig SessionContainer = interfaces.SessionContainer diff --git a/supertokens_python/recipe/session/session_request_functions.py b/supertokens_python/recipe/session/session_request_functions.py index 1e2320817..1262ea1f6 100644 --- a/supertokens_python/recipe/session/session_request_functions.py +++ b/supertokens_python/recipe/session/session_request_functions.py @@ -52,7 +52,6 @@ get_auth_mode_from_header, get_required_claim_validators, ) -from supertokens_python.supertokens import Supertokens from supertokens_python.types import MaybeAwaitable, RecipeUserId from supertokens_python.utils import ( FRAMEWORKS, @@ -88,6 +87,8 @@ async def get_session_from_request( ] = None, user_context: Optional[Dict[str, Any]] = None, ) -> Optional[SessionContainer]: + from supertokens_python.supertokens import Supertokens + log_debug_message("getSession: Started") if not hasattr(request, "wrapper_used") or not request.wrapper_used: @@ -245,6 +246,8 @@ async def create_new_session_in_request( session_data_in_database: Dict[str, Any], tenant_id: str, ) -> SessionContainer: + from supertokens_python.supertokens import Supertokens + log_debug_message("createNewSession: Started") # Handling framework specific request/response wrapping @@ -356,6 +359,8 @@ async def refresh_session_in_request( config: NormalisedSessionConfig, recipe_interface_impl: SessionRecipeInterface, ) -> SessionContainer: + from supertokens_python.supertokens import Supertokens + log_debug_message("refreshSession: Started") response_mutators: List[ResponseMutator] = [] diff --git a/supertokens_python/supertokens.py b/supertokens_python/supertokens.py index 6a649d377..45dba892f 100644 --- a/supertokens_python/supertokens.py +++ b/supertokens_python/supertokens.py @@ -35,6 +35,12 @@ get_maybe_none_as_str, log_debug_message, ) +from supertokens_python.plugins import ( + OverrideMap, + PluginRouteHandler, + SuperTokensPlugin, + SuperTokensPublicPlugin, +) from supertokens_python.types.response import CamelCaseBaseModel from .constants import FDI_KEY_HEADER, RID_KEY_HEADER, USER_COUNT @@ -64,12 +70,6 @@ if TYPE_CHECKING: from supertokens_python.framework.request import BaseRequest from supertokens_python.framework.response import BaseResponse - from supertokens_python.plugins import ( - OverrideMap, - PluginRouteHandler, - SuperTokensPlugin, - SuperTokensPublicPlugin, - ) from supertokens_python.recipe.session import SessionContainer from .recipe_module import RecipeModule @@ -111,60 +111,6 @@ def __init__( self.disable_core_call_cache = disable_core_call_cache -class SupertokensExperimentalConfig(CamelCaseBaseModel): - plugins: Optional[List[SuperTokensPlugin]] = None - - -class SupertokensPublicConfig(CamelCaseBaseModel): - """ - Public properties received as input to the `Supertokens.init` function. - """ - - app_info: InputAppInfo - framework: Literal["fastapi", "flask", "django"] - supertokens_config: SupertokensConfig - mode: Optional[Literal["asgi", "wsgi"]] - telemetry: Optional[bool] - debug: Optional[bool] - - -class SupertokensInputConfig(SupertokensPublicConfig): - """ - Various properties received as input to the `Supertokens.init` function. - """ - - recipe_list: List[Callable[[AppInfo, List[OverrideMap]], RecipeModule]] - experimental: Optional[SupertokensExperimentalConfig] = None - - def get_public_config(self) -> SupertokensPublicConfig: - return SupertokensPublicConfig( - app_info=self.app_info, - framework=self.framework, - supertokens_config=self.supertokens_config, - mode=self.mode, - telemetry=self.telemetry, - debug=self.debug, - ) - - @classmethod - def from_public_config( - cls, - config: SupertokensPublicConfig, - recipe_list: List[Callable[[AppInfo, List[OverrideMap]], RecipeModule]], - experimental: Optional[SupertokensExperimentalConfig], - ) -> "SupertokensInputConfig": - return cls( - app_info=config.app_info, - framework=config.framework, - supertokens_config=config.supertokens_config, - mode=config.mode, - telemetry=config.telemetry, - debug=config.debug, - recipe_list=recipe_list, - experimental=experimental, - ) - - class Host: def __init__(self, domain: NormalisedURLDomain, base_path: NormalisedURLPath): self.domain = domain @@ -277,6 +223,60 @@ def __call__( ) -> RecipeModule: ... +class SupertokensExperimentalConfig(CamelCaseBaseModel): + plugins: Optional[List["SuperTokensPlugin"]] = None + + +class SupertokensPublicConfig(CamelCaseBaseModel): + """ + Public properties received as input to the `Supertokens.init` function. + """ + + app_info: InputAppInfo + framework: Literal["fastapi", "flask", "django"] + supertokens_config: SupertokensConfig + mode: Optional[Literal["asgi", "wsgi"]] + telemetry: Optional[bool] + debug: Optional[bool] + + +class SupertokensInputConfig(SupertokensPublicConfig): + """ + Various properties received as input to the `Supertokens.init` function. + """ + + recipe_list: List[Callable[[AppInfo, List["OverrideMap"]], "RecipeModule"]] + experimental: Optional[SupertokensExperimentalConfig] = None + + def get_public_config(self) -> SupertokensPublicConfig: + return SupertokensPublicConfig( + app_info=self.app_info, + framework=self.framework, + supertokens_config=self.supertokens_config, + mode=self.mode, + telemetry=self.telemetry, + debug=self.debug, + ) + + @classmethod + def from_public_config( + cls, + config: SupertokensPublicConfig, + recipe_list: List[Callable[[AppInfo, List["OverrideMap"]], "RecipeModule"]], + experimental: Optional[SupertokensExperimentalConfig], + ) -> "SupertokensInputConfig": + return cls( + app_info=config.app_info, + framework=config.framework, + supertokens_config=config.supertokens_config, + mode=config.mode, + telemetry=config.telemetry, + debug=config.debug, + recipe_list=recipe_list, + experimental=experimental, + ) + + class Supertokens: __instance: Optional[Supertokens] = None diff --git a/tests/plugins/config.py b/tests/plugins/config.py index 31d6330a9..59387085d 100644 --- a/tests/plugins/config.py +++ b/tests/plugins/config.py @@ -1,5 +1,6 @@ -from typing import Any, Optional +from typing import Any, List, Optional +from pydantic import Field from supertokens_python.supertokens import ( AppInfo, ) @@ -21,12 +22,12 @@ ] -class NormalizedPluginTestConfig( - BaseNormalisedConfig[RecipeInterface, APIInterface] -): ... +class NormalizedPluginTestConfig(BaseNormalisedConfig[RecipeInterface, APIInterface]): + test_property: List[str] -class PluginTestConfig(BaseConfig[RecipeInterface, APIInterface]): ... +class PluginTestConfig(BaseConfig[RecipeInterface, APIInterface]): + test_property: List[str] = Field(default_factory=lambda: ["original"]) class PluginOverrideConfig(BaseOverrideConfig[RecipeInterface, APIInterface]): @@ -39,12 +40,10 @@ def validate_and_normalise_user_input( if config is None: config = PluginTestConfig() - override_config = NormalisedPluginTestOverrideConfig() - if config.override is not None: - if config.override.functions is not None: - override_config.functions = config.override.functions + override_config = NormalisedPluginTestOverrideConfig.from_input_config( + config.override + ) - if config.override.apis is not None: - override_config.apis = config.override.apis - - return NormalizedPluginTestConfig(override=override_config) + return NormalizedPluginTestConfig( + override=override_config, test_property=config.test_property + ) diff --git a/tests/plugins/plugins.py b/tests/plugins/plugins.py index 60f635a7f..cea0d9d75 100644 --- a/tests/plugins/plugins.py +++ b/tests/plugins/plugins.py @@ -3,6 +3,7 @@ from supertokens_python.plugins import ( OverrideMap, PluginDependenciesOkResponse, + RecipePluginOverride, SuperTokensPlugin, SuperTokensPluginDependencies, SuperTokensPublicPlugin, @@ -10,7 +11,7 @@ from supertokens_python.supertokens import SupertokensPublicConfig from .api_implementation import APIInterface -from .config import PluginOverrideConfig +from .config import PluginTestConfig from .recipe import PluginTestRecipe from .recipe_implementation import RecipeInterface @@ -43,6 +44,14 @@ def new_sign_in_post(message: str, stack: List[str]): return function_override +def config_override_factory(identifier: str): + def config_override(original_config: PluginTestConfig) -> PluginTestConfig: + original_config.test_property.append(identifier) + return original_config + + return config_override + + def init_factory(identifier: str): def init( config: SupertokensPublicConfig, @@ -76,18 +85,24 @@ def plugin_factory( identifier: str, override_functions: bool = False, override_apis: bool = False, + override_config: bool = False, deps: Optional[List[SuperTokensPlugin]] = None, add_init: bool = False, ): - override_map_obj: OverrideMap = {PluginTestRecipe.recipe_id: PluginOverrideConfig()} + override_map_obj: OverrideMap = {PluginTestRecipe.recipe_id: RecipePluginOverride()} if override_functions: override_map_obj[ PluginTestRecipe.recipe_id - ].functions = function_override_factory(identifier) + ].functions = function_override_factory(identifier) # type: ignore if override_apis: override_map_obj[PluginTestRecipe.recipe_id].apis = api_override_factory( identifier + ) # type: ignore + + if override_config: + override_map_obj[PluginTestRecipe.recipe_id].config = config_override_factory( + identifier ) init_fn = None @@ -107,40 +122,47 @@ class Plugin(SuperTokensPlugin): Plugin1 = plugin_factory( "plugin1", override_functions=True, + override_config=True, add_init=True, ) Plugin2 = plugin_factory( "plugin2", override_functions=True, + override_config=True, add_init=True, ) Plugin3Dep1 = plugin_factory( "plugin3dep1", override_functions=True, + override_config=True, deps=[Plugin1], add_init=True, ) Plugin3Dep2_1 = plugin_factory( "plugin3dep2_1", override_functions=True, + override_config=True, deps=[Plugin2, Plugin1], add_init=True, ) Plugin4Dep1 = plugin_factory( "plugin4dep1", override_functions=True, + override_config=True, deps=[Plugin1], add_init=True, ) Plugin4Dep2 = plugin_factory( "plugin4dep2", override_functions=True, + override_config=True, deps=[Plugin2], add_init=True, ) Plugin4Dep3__2_1 = plugin_factory( "plugin4dep3__2_1", override_functions=True, + override_config=True, deps=[Plugin3Dep2_1], add_init=True, ) diff --git a/tests/plugins/test_plugins.py b/tests/plugins/test_plugins.py index 12ea8ff52..60d684537 100644 --- a/tests/plugins/test_plugins.py +++ b/tests/plugins/test_plugins.py @@ -217,12 +217,19 @@ def test_overrides( # TODO: Figure out a way to add circular dependencies and test them @mark.parametrize( - ("plugins", "recipe_expectation", "api_expectation", "init_expectation"), + ( + "plugins", + "recipe_expectation", + "api_expectation", + "config_expectation", + "init_expectation", + ), [ param( [Plugin1, Plugin1], outputs(["plugin1", "original"]), outputs(["original"]), + outputs(["original", "plugin1"]), outputs(["plugin1"]), id="1,1 => 1", ), @@ -230,6 +237,7 @@ def test_overrides( [Plugin1, Plugin2], outputs(["plugin2", "plugin1", "original"]), outputs(["original"]), + outputs(["original", "plugin1", "plugin2"]), outputs(["plugin1", "plugin2"]), id="1,2 => 2,1", ), @@ -237,6 +245,7 @@ def test_overrides( [Plugin3Dep1], outputs(["plugin3dep1", "plugin1", "original"]), outputs(["original"]), + outputs(["original", "plugin1", "plugin3dep1"]), outputs(["plugin1", "plugin3dep1"]), id="3->1 => 3,1", ), @@ -244,6 +253,7 @@ def test_overrides( [Plugin3Dep2_1], outputs(["plugin3dep2_1", "plugin1", "plugin2", "original"]), outputs(["original"]), + outputs(["original", "plugin2", "plugin1", "plugin3dep2_1"]), outputs(["plugin2", "plugin1", "plugin3dep2_1"]), id="3->(2,1) => 3,2,1", ), @@ -251,6 +261,7 @@ def test_overrides( [Plugin3Dep1, Plugin4Dep2], outputs(["plugin4dep2", "plugin2", "plugin3dep1", "plugin1", "original"]), outputs(["original"]), + outputs(["original", "plugin1", "plugin3dep1", "plugin2", "plugin4dep2"]), outputs(["plugin1", "plugin3dep1", "plugin2", "plugin4dep2"]), id="3->1,4->2 => 4,2,3,1", ), @@ -260,6 +271,9 @@ def test_overrides( ["plugin4dep3__2_1", "plugin3dep2_1", "plugin1", "plugin2", "original"] ), outputs(["original"]), + outputs( + ["original", "plugin2", "plugin1", "plugin3dep2_1", "plugin4dep3__2_1"] + ), outputs(["plugin2", "plugin1", "plugin3dep2_1", "plugin4dep3__2_1"]), id="4->3->(2,1) => 4,3,1,2", ), @@ -267,6 +281,7 @@ def test_overrides( [Plugin3Dep1, Plugin4Dep1], outputs(["plugin4dep1", "plugin3dep1", "plugin1", "original"]), outputs(["original"]), + outputs(["original", "plugin1", "plugin3dep1", "plugin4dep1"]), outputs(["plugin1", "plugin3dep1", "plugin4dep1"]), id="3->1,4->1 => 4,3,1", ), @@ -276,6 +291,7 @@ def test_depdendencies_and_init( plugins: List[SuperTokensPlugin], recipe_expectation: Any, api_expectation: Any, + config_expectation: Any, init_expectation: Any, ): partial_init( @@ -309,6 +325,10 @@ def test_depdendencies_and_init( message="msg", ) + with config_expectation as expected_stack: + output = PluginTestRecipe.get_instance().config.test_property + assert output == expected_stack + with init_expectation as expected_stack: assert PluginTestRecipe.init_calls == expected_stack @@ -317,7 +337,7 @@ def test_st_config_override(): plugin = plugin_factory("plugin1", override_functions=False, override_apis=False) def config_override(config: SupertokensPublicConfig) -> SupertokensPublicConfig: - config.mode = "override" # type: ignore + config.mode = "asgi" return config plugin.config = config_override @@ -331,7 +351,7 @@ def config_override(config: SupertokensPublicConfig) -> SupertokensPublicConfig: ), ) - assert Supertokens.get_instance().app_info.mode == "override" + assert Supertokens.get_instance().app_info.mode == "asgi" def test_st_config_override_non_public_property(): @@ -343,16 +363,17 @@ def config_override(config: SupertokensPublicConfig) -> SupertokensPublicConfig: plugin.config = config_override - partial_init( - recipe_list=[ - recipe_factory(override_functions=False, override_apis=False), - ], - experimental=SupertokensExperimentalConfig( - plugins=[plugin], - ), - ) - - assert Supertokens.get_instance().recipe_modules != [] + with raises( + ValueError, match='"SupertokensPublicConfig" object has no field "recipe_list"' + ): + partial_init( + recipe_list=[ + recipe_factory(override_functions=False, override_apis=False), + ], + experimental=SupertokensExperimentalConfig( + plugins=[plugin], + ), + ) plugin_route_handler = PluginRouteHandler( From 44dc8cc44d558fbd450da6de8d2d2b5a03d1b260 Mon Sep 17 00:00:00 2001 From: Namit Nathwani Date: Wed, 23 Jul 2025 17:12:35 +0530 Subject: [PATCH 22/37] update: standardizes `__init__` exports - Defines aliases for old override config classes for backward compatibility - Adds `__all__` to recipe `__init__` files to explicitly declare exports - FIxes `apply_plugins` config input/return type - Adds ruff rule to format `__all__` exports - Updates changelog --- CHANGELOG.md | 43 +++++++++++++++ pyproject.toml | 5 +- supertokens_python/plugins.py | 41 +++++++------- .../recipe/accountlinking/__init__.py | 30 +++++++--- .../recipe/accountlinking/types.py | 2 + .../recipe/dashboard/__init__.py | 12 +++- supertokens_python/recipe/dashboard/utils.py | 2 + .../recipe/emailpassword/__init__.py | 41 +++++++++----- .../recipe/emailpassword/utils.py | 2 + .../recipe/emailverification/__init__.py | 33 ++++++----- .../recipe/emailverification/utils.py | 2 + supertokens_python/recipe/jwt/__init__.py | 10 +++- supertokens_python/recipe/jwt/utils.py | 2 + .../recipe/multifactorauth/__init__.py | 9 +++ .../recipe/multifactorauth/types.py | 2 + .../recipe/multitenancy/__init__.py | 21 ++++--- .../recipe/multitenancy/utils.py | 2 + .../recipe/oauth2provider/__init__.py | 17 ++++-- .../recipe/oauth2provider/utils.py | 2 + supertokens_python/recipe/openid/__init__.py | 10 +++- supertokens_python/recipe/openid/utils.py | 2 + .../recipe/passwordless/__init__.py | 55 ++++++++++++------- .../recipe/passwordless/utils.py | 2 + supertokens_python/recipe/session/__init__.py | 26 ++++++--- supertokens_python/recipe/session/utils.py | 2 + .../recipe/thirdparty/__init__.py | 23 +++++--- supertokens_python/recipe/thirdparty/utils.py | 2 + supertokens_python/recipe/totp/__init__.py | 15 ++++- supertokens_python/recipe/totp/types.py | 2 + .../recipe/usermetadata/__init__.py | 13 ++++- .../recipe/usermetadata/utils.py | 2 + .../recipe/userroles/__init__.py | 19 +++++-- supertokens_python/recipe/userroles/utils.py | 2 + .../recipe/webauthn/__init__.py | 4 +- 34 files changed, 335 insertions(+), 122 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 65ed5469f..7b4f0f8da 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,49 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [unreleased] +## [0.31.0] - 2025-07-18 +### Adds plugins support +- Adds an `experimental` property (`SuperTokensExperimentalConfig`) to the `SuperTokensConfig` + - Plugins can be configured under using the `plugins` property in the `experimental` config +- Refactors the AccountLinking recipe to be automatically initialized on SuperTokens init + +### Breaking Changes +- `AccountLinkingRecipe.get_instance` will now raise an exception if not initialized +- Various config classes renamed for consistency across the codebase, and classes added where they were missing + - Old classes added to the recipe modules as aliases for backward compatibility, but will be removed in future versions. Prefer using the renamed classes. + - `InputOverrideConfig` renamed to `OverrideConfig` + - `OverrideConfig` renamed to `NormalisedOverrideConfig` + - Input config classes like `InputConfig` renamed to `Config` + - Normalised config classes like `Config` renamed to `NormalisedConfig` + - Changed classes: + - AccountLinking `InputOverrideConfig` -> `AccountLinkingOverrideConfig` + - Dashboard `InputOverrideConfig` -> `DashboardOverrideConfig` + - EmailPassword + - `InputOverrideConfig` -> `EmailPasswordOverrideConfig` + - `exceptions` export removed from `__init__`, import the `exceptions` module directly + - EmailVerification + - `InputOverrideConfig` -> `EmailVerificationOverrideConfig` + - `exception` export removed from `__init__`, import the `exceptions` module directly + - JWT `OverrideConfig` -> `JWTOverrideConfig` + - MultiFactorAuth `OverrideConfig` -> `MultiFactorAuthOverrideConfig` + - Multitenancy + - `InputOverrideConfig` -> `MultitenancyOverrideConfig` + - `exceptions` export removed from `__init__`, import the `exceptions` module directly + - OAuth2Provider + - `InputOverrideConfig` -> `OAuth2ProviderOverrideConfig` + - `exceptions` export removed from `__init__`, import the `exceptions` module directly + - OpenId `InputOverrideConfig` -> `OpenIdOverrideConfig` + - Passwordless `InputOverrideConfig` -> `PasswordlessOverrideConfig` + - Session + - `InputOverrideConfig` -> `SessionOverrideConfig` + - `exceptions` export removed from `__init__`, import the `exceptions` module directly + - ThirdParty + - `InputOverrideConfig` -> `ThirdPartyOverrideConfig` + - `exceptions` export removed from `__init__`, import the `exceptions` module directly + - TOTP `OverrideConfig` -> `TOTPOverrideConfig` + - UserMetadata `InputOverrideConfig` -> `UserMetadataOverrideConfig` + - UserRoles `InputOverrideConfig` -> `UserRolesOverrideConfig` + ## [0.30.1] - 2025-07-21 - Adds missing register credential endpoint to the Webauthn recipe diff --git a/pyproject.toml b/pyproject.toml index 86362d513..a7e85ceca 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,7 +3,10 @@ line-length = 88 # Match Black's formatting src = ["supertokens_python"] [tool.ruff.lint] -extend-select = ["I"] # enable import sorting +extend-select = [ + "I", # enable import sorting + "RUF022", # Sort __all__ exports +] [tool.ruff.format] quote-style = "double" # Default diff --git a/supertokens_python/plugins.py b/supertokens_python/plugins.py index cc8a34a6e..412184357 100644 --- a/supertokens_python/plugins.py +++ b/supertokens_python/plugins.py @@ -57,31 +57,32 @@ from supertokens_python.recipe.webauthn.types.config import WebauthnConfig from supertokens_python.supertokens import SupertokensPublicConfig -T = TypeVar("T") +RecipeConfigType = TypeVar( + "RecipeConfigType", + bound=Union[ + "AccountLinkingConfig", + "DashboardConfig", + "EmailPasswordConfig", + "EmailVerificationConfig", + "JWTConfig", + "MultiFactorAuthConfig", + "MultitenancyConfig", + "OAuth2ProviderConfig", + "OpenIdConfig", + "PasswordlessConfig", + "SessionConfig", + "ThirdPartyConfig", + "TOTPConfig", + "UserMetadataConfig", + "UserRolesConfig", + "WebauthnConfig", + ], +) RecipeInterfaceType = TypeVar("RecipeInterfaceType", bound=BaseRecipeInterface) APIInterfaceType = TypeVar("APIInterfaceType", bound=BaseAPIInterface) -RecipeConfigType = Union[ - "AccountLinkingConfig", - "DashboardConfig", - "EmailPasswordConfig", - "EmailVerificationConfig", - "JWTConfig", - "MultiFactorAuthConfig", - "MultitenancyConfig", - "OAuth2ProviderConfig", - "OpenIdConfig", - "PasswordlessConfig", - "SessionConfig", - "ThirdPartyConfig", - "TOTPConfig", - "UserMetadataConfig", - "UserRolesConfig", - "WebauthnConfig", -] - class RecipeInitRequiredFunction(Protocol): def __call__(self, sdk_version: str) -> bool: ... diff --git a/supertokens_python/recipe/accountlinking/__init__.py b/supertokens_python/recipe/accountlinking/__init__.py index aac4576a3..c64e1051e 100644 --- a/supertokens_python/recipe/accountlinking/__init__.py +++ b/supertokens_python/recipe/accountlinking/__init__.py @@ -15,15 +15,17 @@ from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Optional, Union -from ...types import User -from . import types -from .recipe import AccountLinkingRecipe +from supertokens_python.types import User -AccountLinkingOverrideConfig = types.AccountLinkingOverrideConfig -RecipeLevelUser = types.RecipeLevelUser -AccountInfoWithRecipeIdAndUserId = types.AccountInfoWithRecipeIdAndUserId -ShouldAutomaticallyLink = types.ShouldAutomaticallyLink -ShouldNotAutomaticallyLink = types.ShouldNotAutomaticallyLink +from .recipe import AccountLinkingRecipe +from .types import ( + AccountInfoWithRecipeIdAndUserId, + AccountLinkingOverrideConfig, + InputOverrideConfig, + RecipeLevelUser, + ShouldAutomaticallyLink, + ShouldNotAutomaticallyLink, +) if TYPE_CHECKING: from ..session.interfaces import SessionContainer @@ -50,3 +52,15 @@ def init( return AccountLinkingRecipe.init( on_account_linked, should_do_automatic_account_linking, override ) + + +__all__ = [ + "AccountInfoWithRecipeIdAndUserId", + "AccountLinkingOverrideConfig", + "AccountLinkingRecipe", + "InputOverrideConfig", # deprecated, use AccountLinkingOverrideConfig instead + "RecipeLevelUser", + "ShouldAutomaticallyLink", + "ShouldNotAutomaticallyLink", + "init", +] diff --git a/supertokens_python/recipe/accountlinking/types.py b/supertokens_python/recipe/accountlinking/types.py index 83238f8a3..74eb752fb 100644 --- a/supertokens_python/recipe/accountlinking/types.py +++ b/supertokens_python/recipe/accountlinking/types.py @@ -43,6 +43,8 @@ NormalisedAccountLinkingOverrideConfig = BaseNormalisedOverrideConfigWithoutAPI[ RecipeInterface ] +InputOverrideConfig = AccountLinkingOverrideConfig +"""Deprecated, use `AccountLinkingOverrideConfig` instead.""" class AccountInfoWithRecipeId(AccountInfo): diff --git a/supertokens_python/recipe/dashboard/__init__.py b/supertokens_python/recipe/dashboard/__init__.py index 7ccf61d00..f09638c68 100644 --- a/supertokens_python/recipe/dashboard/__init__.py +++ b/supertokens_python/recipe/dashboard/__init__.py @@ -16,12 +16,10 @@ from typing import List, Optional -from supertokens_python.recipe.dashboard import utils from supertokens_python.supertokens import RecipeInit from .recipe import DashboardRecipe - -DashboardOverrideConfig = utils.DashboardOverrideConfig +from .utils import DashboardOverrideConfig, InputOverrideConfig def init( @@ -34,3 +32,11 @@ def init( admins, override, ) + + +__all__ = [ + "DashboardOverrideConfig", + "DashboardRecipe", + "InputOverrideConfig", # deprecated, use DashboardOverrideConfig instead + "init", +] diff --git a/supertokens_python/recipe/dashboard/utils.py b/supertokens_python/recipe/dashboard/utils.py index f06ed695f..d47e945f0 100644 --- a/supertokens_python/recipe/dashboard/utils.py +++ b/supertokens_python/recipe/dashboard/utils.py @@ -81,6 +81,8 @@ def to_json(self) -> Dict[str, Any]: NormalisedDashboardOverrideConfig = BaseNormalisedOverrideConfig[ RecipeInterface, APIInterface ] +InputOverrideConfig = DashboardOverrideConfig +"""Deprecated, use `DashboardOverrideConfig` instead.""" class DashboardConfig(BaseConfig[RecipeInterface, APIInterface]): diff --git a/supertokens_python/recipe/emailpassword/__init__.py b/supertokens_python/recipe/emailpassword/__init__.py index 8df1a4691..01d78d4a9 100644 --- a/supertokens_python/recipe/emailpassword/__init__.py +++ b/supertokens_python/recipe/emailpassword/__init__.py @@ -15,29 +15,28 @@ from typing import TYPE_CHECKING, Union -from supertokens_python.ingredients.emaildelivery import types as emaildelivery_types -from supertokens_python.ingredients.emaildelivery.types import EmailDeliveryConfig +from supertokens_python.ingredients.emaildelivery.types import ( + EmailDeliveryConfig, + EmailDeliveryInterface, +) from supertokens_python.recipe.emailpassword.types import EmailTemplateVars -from . import exceptions as ex -from . import utils -from .emaildelivery import services as emaildelivery_services +from .emaildelivery.services import SMTPService from .recipe import EmailPasswordRecipe - -exceptions = ex -EmailPasswordOverrideConfig = utils.EmailPasswordOverrideConfig -InputSignUpFeature = utils.InputSignUpFeature -InputFormField = utils.InputFormField -SMTPService = emaildelivery_services.SMTPService -EmailDeliveryInterface = emaildelivery_types.EmailDeliveryInterface +from .utils import ( + EmailPasswordOverrideConfig, + InputFormField, + InputOverrideConfig, + InputSignUpFeature, +) if TYPE_CHECKING: from supertokens_python.supertokens import RecipeInit def init( - sign_up_feature: Union[utils.InputSignUpFeature, None] = None, - override: Union[utils.EmailPasswordOverrideConfig, None] = None, + sign_up_feature: Union[InputSignUpFeature, None] = None, + override: Union[EmailPasswordOverrideConfig, None] = None, email_delivery: Union[EmailDeliveryConfig[EmailTemplateVars], None] = None, ) -> RecipeInit: return EmailPasswordRecipe.init( @@ -45,3 +44,17 @@ def init( override, email_delivery, ) + + +__all__ = [ + "EmailDeliveryConfig", + "EmailDeliveryInterface", + "EmailPasswordOverrideConfig", + "EmailPasswordRecipe", + "EmailTemplateVars", + "InputFormField", + "InputOverrideConfig", # deprecated, use EmailPasswordOverrideConfig instead + "InputSignUpFeature", + "SMTPService", + "init", +] diff --git a/supertokens_python/recipe/emailpassword/utils.py b/supertokens_python/recipe/emailpassword/utils.py index 30ff84257..62d4c72e3 100644 --- a/supertokens_python/recipe/emailpassword/utils.py +++ b/supertokens_python/recipe/emailpassword/utils.py @@ -221,6 +221,8 @@ def validate_and_normalise_reset_password_using_token_config( NormalisedEmailPasswordOverrideConfig = BaseNormalisedOverrideConfig[ RecipeInterface, APIInterface ] +InputOverrideConfig = EmailPasswordOverrideConfig +"""Deprecated, use `EmailPasswordOverrideConfig` instead.""" class EmailPasswordConfig(BaseConfig[RecipeInterface, APIInterface]): diff --git a/supertokens_python/recipe/emailverification/__init__.py b/supertokens_python/recipe/emailverification/__init__.py index 848a16a16..6d53c2940 100644 --- a/supertokens_python/recipe/emailverification/__init__.py +++ b/supertokens_python/recipe/emailverification/__init__.py @@ -15,21 +15,13 @@ from typing import TYPE_CHECKING, Optional, Union -from ...ingredients.emaildelivery.types import EmailDeliveryConfig -from . import exceptions as ex -from . import recipe, types, utils -from .emaildelivery import services as emaildelivery_services -from .interfaces import TypeGetEmailForUserIdFunction -from .recipe import EmailVerificationRecipe -from .types import EmailTemplateVars -from .utils import MODE_TYPE, EmailVerificationOverrideConfig - -InputOverrideConfig = utils.EmailVerificationOverrideConfig -exception = ex -SMTPService = emaildelivery_services.SMTPService -EmailVerificationClaim = recipe.EmailVerificationClaim -EmailDeliveryInterface = types.EmailDeliveryInterface +from supertokens_python.ingredients.emaildelivery.types import EmailDeliveryConfig +from .emaildelivery.services import SMTPService +from .interfaces import TypeGetEmailForUserIdFunction +from .recipe import EmailVerificationClaim, EmailVerificationRecipe +from .types import EmailDeliveryInterface, EmailTemplateVars +from .utils import MODE_TYPE, EmailVerificationOverrideConfig, InputOverrideConfig if TYPE_CHECKING: from supertokens_python.supertokens import RecipeInit @@ -47,3 +39,16 @@ def init( get_email_for_recipe_user_id, override, ) + + +__all__ = [ + "EmailDeliveryInterface", + "EmailTemplateVars", + "EmailVerificationClaim", + "EmailVerificationOverrideConfig", + "EmailVerificationRecipe", + "InputOverrideConfig", # deprecated, use EmailVerificationOverrideConfig instead + "SMTPService", + "TypeGetEmailForUserIdFunction", + "init", +] diff --git a/supertokens_python/recipe/emailverification/utils.py b/supertokens_python/recipe/emailverification/utils.py index 28a7c70a4..5ff3754db 100644 --- a/supertokens_python/recipe/emailverification/utils.py +++ b/supertokens_python/recipe/emailverification/utils.py @@ -46,6 +46,8 @@ NormalisedEmailVerificationOverrideConfig = BaseNormalisedOverrideConfig[ RecipeInterface, APIInterface ] +InputOverrideConfig = EmailVerificationOverrideConfig +"""Deprecated, use `EmailVerificationOverrideConfig` instead.""" class EmailVerificationConfig(BaseConfig[RecipeInterface, APIInterface]): diff --git a/supertokens_python/recipe/jwt/__init__.py b/supertokens_python/recipe/jwt/__init__.py index baf2bfd54..0b12165e0 100644 --- a/supertokens_python/recipe/jwt/__init__.py +++ b/supertokens_python/recipe/jwt/__init__.py @@ -16,7 +16,7 @@ from typing import TYPE_CHECKING, Union from .recipe import JWTRecipe -from .utils import JWTOverrideConfig +from .utils import JWTOverrideConfig, OverrideConfig if TYPE_CHECKING: from supertokens_python.supertokens import RecipeInit @@ -27,3 +27,11 @@ def init( override: Union[JWTOverrideConfig, None] = None, ) -> RecipeInit: return JWTRecipe.init(jwt_validity_seconds, override) + + +__all__ = [ + "JWTOverrideConfig", + "JWTRecipe", + "OverrideConfig", # deprecated, use JWTOverrideConfig instead + "init", +] diff --git a/supertokens_python/recipe/jwt/utils.py b/supertokens_python/recipe/jwt/utils.py index 959427fc5..5fb2e007a 100644 --- a/supertokens_python/recipe/jwt/utils.py +++ b/supertokens_python/recipe/jwt/utils.py @@ -28,6 +28,8 @@ NormalisedJWTOverrideConfig = BaseNormalisedOverrideConfig[ RecipeInterface, APIInterface ] +OverrideConfig = JWTOverrideConfig +"""Deprecated, use `JWTOverrideConfig` instead.""" class JWTConfig(BaseConfig[RecipeInterface, APIInterface]): diff --git a/supertokens_python/recipe/multifactorauth/__init__.py b/supertokens_python/recipe/multifactorauth/__init__.py index 861a4cdb8..614d6bf5e 100644 --- a/supertokens_python/recipe/multifactorauth/__init__.py +++ b/supertokens_python/recipe/multifactorauth/__init__.py @@ -17,6 +17,7 @@ from supertokens_python.recipe.multifactorauth.types import ( MultiFactorAuthOverrideConfig, + OverrideConfig, ) from .recipe import MultiFactorAuthRecipe @@ -33,3 +34,11 @@ def init( first_factors, override, ) + + +__all__ = [ + "MultiFactorAuthOverrideConfig", + "MultiFactorAuthRecipe", + "OverrideConfig", # deprecated, use MultiFactorAuthOverrideConfig instead + "init", +] diff --git a/supertokens_python/recipe/multifactorauth/types.py b/supertokens_python/recipe/multifactorauth/types.py index ab2572530..e53e6362f 100644 --- a/supertokens_python/recipe/multifactorauth/types.py +++ b/supertokens_python/recipe/multifactorauth/types.py @@ -46,6 +46,8 @@ def __init__(self, c: Dict[str, Any], v: bool): NormalisedMultiFactorAuthOverrideConfig = BaseNormalisedOverrideConfig[ RecipeInterface, APIInterface ] +OverrideConfig = MultiFactorAuthOverrideConfig +"""Deprecated, use `MultiFactorAuthOverrideConfig` instead.""" class MultiFactorAuthConfig(BaseConfig[RecipeInterface, APIInterface]): diff --git a/supertokens_python/recipe/multitenancy/__init__.py b/supertokens_python/recipe/multitenancy/__init__.py index e591c794c..fcfa5d8df 100644 --- a/supertokens_python/recipe/multitenancy/__init__.py +++ b/supertokens_python/recipe/multitenancy/__init__.py @@ -15,13 +15,10 @@ from typing import TYPE_CHECKING, Union -from . import exceptions as ex -from . import recipe -from .interfaces import TypeGetAllowedDomainsForTenantId -from .utils import MultitenancyOverrideConfig +from recipe import AllowedDomainsClaim, MultitenancyRecipe -AllowedDomainsClaim = recipe.AllowedDomainsClaim -exceptions = ex +from .interfaces import TypeGetAllowedDomainsForTenantId +from .utils import InputOverrideConfig, MultitenancyOverrideConfig if TYPE_CHECKING: from supertokens_python.supertokens import RecipeInit @@ -33,7 +30,17 @@ def init( ] = None, override: Union[MultitenancyOverrideConfig, None] = None, ) -> RecipeInit: - return recipe.MultitenancyRecipe.init( + return MultitenancyRecipe.init( get_allowed_domains_for_tenant_id, override, ) + + +__all__ = [ + "AllowedDomainsClaim", + "InputOverrideConfig", # deprecated, use MultitenancyOverrideConfig instead + "MultitenancyOverrideConfig", + "MultitenancyRecipe", + "TypeGetAllowedDomainsForTenantId", + "init", +] diff --git a/supertokens_python/recipe/multitenancy/utils.py b/supertokens_python/recipe/multitenancy/utils.py index 8bb98ee8f..6f83d5518 100644 --- a/supertokens_python/recipe/multitenancy/utils.py +++ b/supertokens_python/recipe/multitenancy/utils.py @@ -70,6 +70,8 @@ async def on_recipe_disabled_for_tenant( NormalisedMultitenancyOverrideConfig = BaseNormalisedOverrideConfig[ RecipeInterface, APIInterface ] +InputOverrideConfig = MultitenancyOverrideConfig +"""Deprecated, use `MultitenancyOverrideConfig` instead.""" class MultitenancyConfig(BaseConfig[RecipeInterface, APIInterface]): diff --git a/supertokens_python/recipe/oauth2provider/__init__.py b/supertokens_python/recipe/oauth2provider/__init__.py index bf1f02e6c..c436084b5 100644 --- a/supertokens_python/recipe/oauth2provider/__init__.py +++ b/supertokens_python/recipe/oauth2provider/__init__.py @@ -15,11 +15,8 @@ from typing import TYPE_CHECKING, Union -from . import exceptions as ex -from . import recipe, utils - -exceptions = ex -OAuth2ProviderOverrideConfig = utils.OAuth2ProviderOverrideConfig +from .recipe import OAuth2ProviderRecipe +from .utils import InputOverrideConfig, OAuth2ProviderOverrideConfig if TYPE_CHECKING: from supertokens_python.supertokens import RecipeInit @@ -28,4 +25,12 @@ def init( override: Union[OAuth2ProviderOverrideConfig, None] = None, ) -> RecipeInit: - return recipe.OAuth2ProviderRecipe.init(override) + return OAuth2ProviderRecipe.init(override) + + +__all__ = [ + "InputOverrideConfig", # deprecated, use OAuth2ProviderOverrideConfig instead + "OAuth2ProviderOverrideConfig", + "OAuth2ProviderRecipe", + "init", +] diff --git a/supertokens_python/recipe/oauth2provider/utils.py b/supertokens_python/recipe/oauth2provider/utils.py index a4e3899a3..cbea697ca 100644 --- a/supertokens_python/recipe/oauth2provider/utils.py +++ b/supertokens_python/recipe/oauth2provider/utils.py @@ -26,6 +26,8 @@ NormalisedOAuth2ProviderOverrideConfig = BaseNormalisedOverrideConfig[ RecipeInterface, APIInterface ] +InputOverrideConfig = OAuth2ProviderOverrideConfig +"""Deprecated, use `OAuth2ProviderOverrideConfig` instead.""" class OAuth2ProviderConfig(BaseConfig[RecipeInterface, APIInterface]): ... diff --git a/supertokens_python/recipe/openid/__init__.py b/supertokens_python/recipe/openid/__init__.py index 7ce1e58a8..948527a58 100644 --- a/supertokens_python/recipe/openid/__init__.py +++ b/supertokens_python/recipe/openid/__init__.py @@ -16,7 +16,7 @@ from typing import TYPE_CHECKING, Union from .recipe import OpenIdRecipe -from .utils import OpenIdOverrideConfig +from .utils import InputOverrideConfig, OpenIdOverrideConfig if TYPE_CHECKING: from supertokens_python.supertokens import RecipeInit @@ -27,3 +27,11 @@ def init( override: Union[OpenIdOverrideConfig, None] = None, ) -> RecipeInit: return OpenIdRecipe.init(issuer, override) + + +__all__ = [ + "InputOverrideConfig", # deprecated, use OpenIdOverrideConfig instead + "OpenIdOverrideConfig", + "OpenIdRecipe", + "init", +] diff --git a/supertokens_python/recipe/openid/utils.py b/supertokens_python/recipe/openid/utils.py index e9f593890..43a346c5a 100644 --- a/supertokens_python/recipe/openid/utils.py +++ b/supertokens_python/recipe/openid/utils.py @@ -34,6 +34,8 @@ NormalisedOpenIdOverrideConfig = BaseNormalisedOverrideConfig[ RecipeInterface, APIInterface ] +InputOverrideConfig = OpenIdOverrideConfig +"""Deprecated, use `OpenIdOverrideConfig` instead.""" class OpenIdConfig(BaseConfig[RecipeInterface, APIInterface]): diff --git a/supertokens_python/recipe/passwordless/__init__.py b/supertokens_python/recipe/passwordless/__init__.py index 0ab7ef111..336a779c6 100644 --- a/supertokens_python/recipe/passwordless/__init__.py +++ b/supertokens_python/recipe/passwordless/__init__.py @@ -25,30 +25,28 @@ SMSTemplateVars, ) -from . import types, utils -from .emaildelivery import services as emaildelivery_services +from .emaildelivery.services import SMTPService from .recipe import PasswordlessRecipe -from .smsdelivery import services as smsdelivery_services +from .smsdelivery.services import SuperTokensSMSService, TwilioService +from .types import ( + CreateAndSendCustomEmailParameters, + CreateAndSendCustomTextMessageParameters, + EmailDeliveryInterface, + SMSDeliveryInterface, +) +from .utils import ( + ContactConfig, + ContactEmailOnlyConfig, + ContactEmailOrPhoneConfig, + ContactPhoneOnlyConfig, + InputOverrideConfig, + PasswordlessOverrideConfig, + PhoneOrEmailInput, +) if TYPE_CHECKING: from supertokens_python.supertokens import RecipeInit -PasswordlessOverrideConfig = utils.PasswordlessOverrideConfig -ContactEmailOnlyConfig = utils.ContactEmailOnlyConfig -ContactConfig = utils.ContactConfig -PhoneOrEmailInput = utils.PhoneOrEmailInput -CreateAndSendCustomTextMessageParameters = ( - types.CreateAndSendCustomTextMessageParameters -) -CreateAndSendCustomEmailParameters = types.CreateAndSendCustomEmailParameters -ContactPhoneOnlyConfig = utils.ContactPhoneOnlyConfig -ContactEmailOrPhoneConfig = utils.ContactEmailOrPhoneConfig -SMTPService = emaildelivery_services.SMTPService -TwilioService = smsdelivery_services.TwilioService -SuperTokensSMSService = smsdelivery_services.SuperTokensSMSService -EmailDeliveryInterface = types.EmailDeliveryInterface -SMSDeliveryInterface = types.SMSDeliveryInterface - def init( contact_config: ContactConfig, @@ -70,3 +68,22 @@ def init( email_delivery, sms_delivery, ) + + +__all__ = [ + "ContactConfig", + "ContactEmailOnlyConfig", + "ContactEmailOrPhoneConfig", + "ContactPhoneOnlyConfig", + "CreateAndSendCustomEmailParameters", + "CreateAndSendCustomTextMessageParameters", + "EmailDeliveryInterface", + "InputOverrideConfig", # deprecated, use PasswordlessOverrideConfig instead + "PasswordlessOverrideConfig", + "PhoneOrEmailInput", + "SMSDeliveryInterface", + "SMTPService", + "SuperTokensSMSService", + "TwilioService", + "init", +] diff --git a/supertokens_python/recipe/passwordless/utils.py b/supertokens_python/recipe/passwordless/utils.py index e362f89e5..6425b6739 100644 --- a/supertokens_python/recipe/passwordless/utils.py +++ b/supertokens_python/recipe/passwordless/utils.py @@ -75,6 +75,8 @@ async def default_validate_email(value: str, _tenant_id: str): NormalisedPasswordlessOverrideConfig = BaseNormalisedOverrideConfig[ RecipeInterface, APIInterface ] +InputOverrideConfig = PasswordlessOverrideConfig +"""Deprecated, use `PasswordlessOverrideConfig` instead.""" class ContactConfig(ABC): diff --git a/supertokens_python/recipe/session/__init__.py b/supertokens_python/recipe/session/__init__.py index 077d5bc25..497348028 100644 --- a/supertokens_python/recipe/session/__init__.py +++ b/supertokens_python/recipe/session/__init__.py @@ -19,19 +19,18 @@ from supertokens_python.framework import BaseRequest -from . import exceptions as ex -from . import interfaces, utils +from .interfaces import SessionContainer from .recipe import SessionRecipe -from .utils import TokenTransferMethod +from .utils import ( + InputErrorHandlers, + InputOverrideConfig, + SessionOverrideConfig, + TokenTransferMethod, +) if TYPE_CHECKING: from supertokens_python.supertokens import RecipeInit -InputErrorHandlers = utils.InputErrorHandlers -SessionOverrideConfig = utils.SessionOverrideConfig -SessionContainer = interfaces.SessionContainer -exceptions = ex - def init( cookie_domain: Union[str, None] = None, @@ -69,3 +68,14 @@ def init( expose_access_token_to_frontend_in_cookie_based_auth, jwks_refresh_interval_sec, ) + + +__all__ = [ + "InputErrorHandlers", + "InputOverrideConfig", # deprecated, use SessionOverrideConfig instead + "SessionContainer", + "SessionOverrideConfig", + "SessionRecipe", + "TokenTransferMethod", + "init", +] diff --git a/supertokens_python/recipe/session/utils.py b/supertokens_python/recipe/session/utils.py index 53d5ed550..30fe626e9 100644 --- a/supertokens_python/recipe/session/utils.py +++ b/supertokens_python/recipe/session/utils.py @@ -340,6 +340,8 @@ def get_token_transfer_method_default( NormalisedSessionOverrideConfig = BaseNormalisedOverrideConfig[ RecipeInterface, APIInterface ] +InputOverrideConfig = SessionOverrideConfig +"""Deprecated: Use `SessionOverrideConfig` instead.""" TokenType = Literal["access", "refresh"] diff --git a/supertokens_python/recipe/thirdparty/__init__.py b/supertokens_python/recipe/thirdparty/__init__.py index 6d03ffb2f..0e411de29 100644 --- a/supertokens_python/recipe/thirdparty/__init__.py +++ b/supertokens_python/recipe/thirdparty/__init__.py @@ -16,16 +16,9 @@ from typing import TYPE_CHECKING, Optional, Union -from . import exceptions as ex -from . import provider, utils +from .provider import ProviderClientConfig, ProviderConfig, ProviderInput from .recipe import ThirdPartyRecipe - -ThirdPartyOverrideConfig = utils.ThirdPartyOverrideConfig -SignInAndUpFeature = utils.SignInAndUpFeature -ProviderInput = provider.ProviderInput -ProviderConfig = provider.ProviderConfig -ProviderClientConfig = provider.ProviderClientConfig -exceptions = ex +from .utils import InputOverrideConfig, SignInAndUpFeature, ThirdPartyOverrideConfig if TYPE_CHECKING: from supertokens_python.supertokens import RecipeInit @@ -38,3 +31,15 @@ def init( if sign_in_and_up_feature is None: sign_in_and_up_feature = SignInAndUpFeature() return ThirdPartyRecipe.init(sign_in_and_up_feature, override) + + +__all__ = [ + "InputOverrideConfig", # deprecated, use `ThirdPartyOverrideConfig` instead + "ProviderClientConfig", + "ProviderConfig", + "ProviderInput", + "SignInAndUpFeature", + "ThirdPartyOverrideConfig", + "ThirdPartyRecipe", + "init", +] diff --git a/supertokens_python/recipe/thirdparty/utils.py b/supertokens_python/recipe/thirdparty/utils.py index 45476eb5c..fc1617838 100644 --- a/supertokens_python/recipe/thirdparty/utils.py +++ b/supertokens_python/recipe/thirdparty/utils.py @@ -56,6 +56,8 @@ def __init__(self, providers: Optional[List[ProviderInput]] = None): NormalisedThirdPartyOverrideConfig = BaseNormalisedOverrideConfig[ RecipeInterface, APIInterface ] +InputOverrideConfig = ThirdPartyOverrideConfig +"""Deprecated: Use `ThirdPartyOverrideConfig` instead.""" class ThirdPartyConfig(BaseConfig[RecipeInterface, APIInterface]): diff --git a/supertokens_python/recipe/totp/__init__.py b/supertokens_python/recipe/totp/__init__.py index 08d313d64..21baa226e 100644 --- a/supertokens_python/recipe/totp/__init__.py +++ b/supertokens_python/recipe/totp/__init__.py @@ -15,7 +15,11 @@ from typing import TYPE_CHECKING, Union -from supertokens_python.recipe.totp.types import TOTPConfig +from supertokens_python.recipe.totp.types import ( + OverrideConfig, + TOTPConfig, + TOTPOverrideConfig, +) from .recipe import TOTPRecipe @@ -29,3 +33,12 @@ def init( return TOTPRecipe.init( config=config, ) + + +__all__ = [ + "OverrideConfig", # deprecated, use `TOTPOverrideConfig` instead + "TOTPConfig", + "TOTPOverrideConfig", + "TOTPRecipe", + "init", +] diff --git a/supertokens_python/recipe/totp/types.py b/supertokens_python/recipe/totp/types.py index e9b521766..3ace4cb78 100644 --- a/supertokens_python/recipe/totp/types.py +++ b/supertokens_python/recipe/totp/types.py @@ -187,6 +187,8 @@ def to_json(self) -> Dict[str, Any]: NormalisedTOTPOverrideConfig = BaseNormalisedOverrideConfig[ RecipeInterface, APIInterface ] +OverrideConfig = TOTPOverrideConfig +"""Deprecated: Use `TOTPOverrideConfig` instead.""" class TOTPConfig(BaseConfig[RecipeInterface, APIInterface]): diff --git a/supertokens_python/recipe/usermetadata/__init__.py b/supertokens_python/recipe/usermetadata/__init__.py index 0f81b5a93..960fc474b 100644 --- a/supertokens_python/recipe/usermetadata/__init__.py +++ b/supertokens_python/recipe/usermetadata/__init__.py @@ -15,7 +15,8 @@ from typing import TYPE_CHECKING, Union -from . import utils +from utils import InputOverrideConfig, UserMetadataOverrideConfig + from .recipe import UserMetadataRecipe if TYPE_CHECKING: @@ -23,6 +24,14 @@ def init( - override: Union[utils.UserMetadataOverrideConfig, None] = None, + override: Union[UserMetadataOverrideConfig, None] = None, ) -> RecipeInit: return UserMetadataRecipe.init(override) + + +__all__ = [ + "InputOverrideConfig", # deprecated, use `UserMetadataOverrideConfig` instead + "UserMetadataOverrideConfig", + "UserMetadataRecipe", + "init", +] diff --git a/supertokens_python/recipe/usermetadata/utils.py b/supertokens_python/recipe/usermetadata/utils.py index 4b611b52e..70d0a979a 100644 --- a/supertokens_python/recipe/usermetadata/utils.py +++ b/supertokens_python/recipe/usermetadata/utils.py @@ -36,6 +36,8 @@ NormalisedUserMetadataOverrideConfig = BaseNormalisedOverrideConfig[ RecipeInterface, APIInterface ] +InputOverrideConfig = UserMetadataOverrideConfig +"""Deprecated: Use `UserMetadataOverrideConfig` instead.""" class UserMetadataConfig(BaseConfig[RecipeInterface, APIInterface]): ... diff --git a/supertokens_python/recipe/userroles/__init__.py b/supertokens_python/recipe/userroles/__init__.py index c3f790100..7356cc507 100644 --- a/supertokens_python/recipe/userroles/__init__.py +++ b/supertokens_python/recipe/userroles/__init__.py @@ -15,11 +15,8 @@ from typing import TYPE_CHECKING, Optional, Union -from . import recipe, utils -from .recipe import UserRolesRecipe - -PermissionClaim = recipe.PermissionClaim -UserRoleClaim = recipe.UserRoleClaim +from .recipe import PermissionClaim, UserRoleClaim, UserRolesRecipe +from .utils import InputOverrideConfig, UserRolesOverrideConfig if TYPE_CHECKING: from supertokens_python.supertokens import RecipeInit @@ -28,10 +25,20 @@ def init( skip_adding_roles_to_access_token: Optional[bool] = None, skip_adding_permissions_to_access_token: Optional[bool] = None, - override: Union[utils.UserRolesOverrideConfig, None] = None, + override: Union[UserRolesOverrideConfig, None] = None, ) -> RecipeInit: return UserRolesRecipe.init( skip_adding_roles_to_access_token, skip_adding_permissions_to_access_token, override, ) + + +__all__ = [ + "InputOverrideConfig", # deprecated, use `UserRolesOverrideConfig` instead + "PermissionClaim", + "UserRoleClaim", + "UserRolesOverrideConfig", + "UserRolesRecipe", + "init", +] diff --git a/supertokens_python/recipe/userroles/utils.py b/supertokens_python/recipe/userroles/utils.py index e70879f98..81eb6bdbe 100644 --- a/supertokens_python/recipe/userroles/utils.py +++ b/supertokens_python/recipe/userroles/utils.py @@ -33,6 +33,8 @@ NormalisedUserRolesOverrideConfig = BaseNormalisedOverrideConfig[ RecipeInterface, APIInterface ] +InputOverrideConfig = UserRolesOverrideConfig +"""Deprecated: Use `UserRolesOverrideConfig` instead.""" class UserRolesConfig(BaseConfig[RecipeInterface, APIInterface]): diff --git a/supertokens_python/recipe/webauthn/__init__.py b/supertokens_python/recipe/webauthn/__init__.py index 7ad1a28a7..f19c9098b 100644 --- a/supertokens_python/recipe/webauthn/__init__.py +++ b/supertokens_python/recipe/webauthn/__init__.py @@ -60,11 +60,10 @@ def init(config: Optional[WebauthnConfig] = None): __all__ = [ - "init", "APIInterface", "RecipeInterface", - "WebauthnOverrideConfig", "WebauthnConfig", + "WebauthnOverrideConfig", "WebauthnRecipe", "consume_recover_account_token", "create_recover_account_link", @@ -72,6 +71,7 @@ def init(config: Optional[WebauthnConfig] = None): "get_credential", "get_generated_options", "get_user_from_recover_account_token", + "init", "list_credentials", "recover_account", "register_credential", From 75fe3b4d8c44da67cd0f3642cc9266163a609ab2 Mon Sep 17 00:00:00 2001 From: Namit Nathwani Date: Wed, 23 Jul 2025 17:17:21 +0530 Subject: [PATCH 23/37] lint: fix lint errors --- .../recipe/dashboard/api/__init__.py | 28 +++++++++---------- supertokens_python/types/__init__.py | 2 +- supertokens_python/types/config.py | 2 +- 3 files changed, 16 insertions(+), 16 deletions(-) diff --git a/supertokens_python/recipe/dashboard/api/__init__.py b/supertokens_python/recipe/dashboard/api/__init__.py index 3ae3a869a..527f08fe2 100644 --- a/supertokens_python/recipe/dashboard/api/__init__.py +++ b/supertokens_python/recipe/dashboard/api/__init__.py @@ -33,24 +33,24 @@ from .validate_key import handle_validate_key_api __all__ = [ - "handle_dashboard_api", "api_key_protector", - "handle_users_count_get_api", - "handle_users_get_api", - "handle_validate_key_api", - "handle_user_email_verify_get", - "handle_user_get", + "handle_analytics_post", + "handle_dashboard_api", + "handle_email_verify_token_post", + "handle_emailpassword_signin_api", + "handle_emailpassword_signout_api", + "handle_get_tags", "handle_metadata_get", + "handle_metadata_put", "handle_sessions_get", "handle_user_delete", - "handle_user_put", + "handle_user_email_verify_get", "handle_user_email_verify_put", - "handle_metadata_put", - "handle_user_sessions_post", + "handle_user_get", "handle_user_password_put", - "handle_email_verify_token_post", - "handle_emailpassword_signin_api", - "handle_emailpassword_signout_api", - "handle_get_tags", - "handle_analytics_post", + "handle_user_put", + "handle_user_sessions_post", + "handle_users_count_get_api", + "handle_users_get_api", + "handle_validate_key_api", ] diff --git a/supertokens_python/types/__init__.py b/supertokens_python/types/__init__.py index 72f7cda5d..7359f2069 100644 --- a/supertokens_python/types/__init__.py +++ b/supertokens_python/types/__init__.py @@ -27,8 +27,8 @@ __all__ = ( "APIResponse", - "GeneralErrorResponse", "AccountInfo", + "GeneralErrorResponse", "LoginMethod", "MaybeAwaitable", "RecipeUserId", diff --git a/supertokens_python/types/config.py b/supertokens_python/types/config.py index c4793cc48..d6d12808f 100644 --- a/supertokens_python/types/config.py +++ b/supertokens_python/types/config.py @@ -72,7 +72,7 @@ class BaseNormalisedOverrideConfig( ) @classmethod - def from_input_config( + def from_input_config( # type: ignore - invalid override due to subclassing cls, override_config: Optional[ BaseOverrideConfig[FunctionInterfaceType, APIInterfaceType] From 4a5e8f296446c8bb9ddde12d26f045dbdd815716 Mon Sep 17 00:00:00 2001 From: Namit Nathwani Date: Wed, 23 Jul 2025 17:21:11 +0530 Subject: [PATCH 24/37] fix: broken imports --- supertokens_python/recipe/multitenancy/__init__.py | 3 +-- supertokens_python/recipe/usermetadata/__init__.py | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/supertokens_python/recipe/multitenancy/__init__.py b/supertokens_python/recipe/multitenancy/__init__.py index fcfa5d8df..a91f5fa83 100644 --- a/supertokens_python/recipe/multitenancy/__init__.py +++ b/supertokens_python/recipe/multitenancy/__init__.py @@ -15,9 +15,8 @@ from typing import TYPE_CHECKING, Union -from recipe import AllowedDomainsClaim, MultitenancyRecipe - from .interfaces import TypeGetAllowedDomainsForTenantId +from .recipe import AllowedDomainsClaim, MultitenancyRecipe from .utils import InputOverrideConfig, MultitenancyOverrideConfig if TYPE_CHECKING: diff --git a/supertokens_python/recipe/usermetadata/__init__.py b/supertokens_python/recipe/usermetadata/__init__.py index 960fc474b..c760e225b 100644 --- a/supertokens_python/recipe/usermetadata/__init__.py +++ b/supertokens_python/recipe/usermetadata/__init__.py @@ -15,9 +15,8 @@ from typing import TYPE_CHECKING, Union -from utils import InputOverrideConfig, UserMetadataOverrideConfig - from .recipe import UserMetadataRecipe +from .utils import InputOverrideConfig, UserMetadataOverrideConfig if TYPE_CHECKING: from supertokens_python.supertokens import RecipeInit From 3983522df79f6f7eee22e38ed738ceff711421fe Mon Sep 17 00:00:00 2001 From: Namit Nathwani Date: Wed, 23 Jul 2025 17:23:04 +0530 Subject: [PATCH 25/37] fix: changes test plugin versions to use current SDK version --- tests/plugins/plugins.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/plugins/plugins.py b/tests/plugins/plugins.py index cea0d9d75..463f96a7f 100644 --- a/tests/plugins/plugins.py +++ b/tests/plugins/plugins.py @@ -1,5 +1,6 @@ from typing import Any, List, Optional, Union +from supertokens_python.constants import VERSION from supertokens_python.plugins import ( OverrideMap, PluginDependenciesOkResponse, @@ -111,7 +112,7 @@ def plugin_factory( class Plugin(SuperTokensPlugin): id: str = identifier - compatible_sdk_versions: Union[str, List[str]] = ["0.30.0"] + compatible_sdk_versions: Union[str, List[str]] = [VERSION] override_map: Optional[OverrideMap] = override_map_obj init: Any = init_fn dependencies: Optional[SuperTokensPluginDependencies] = dependency_factory(deps) From c677c94e8911cf2217d2de4451f34c04480cd082 Mon Sep 17 00:00:00 2001 From: Namit Nathwani Date: Wed, 23 Jul 2025 17:39:46 +0530 Subject: [PATCH 26/37] fix: SuperTokensInputConfig model rebuild, base `__init__` imports/exports --- supertokens_python/__init__.py | 48 +++++++++++++++++++++++++++------- 1 file changed, 38 insertions(+), 10 deletions(-) diff --git a/supertokens_python/__init__.py b/supertokens_python/__init__.py index 452dcc286..ee6407ff8 100644 --- a/supertokens_python/__init__.py +++ b/supertokens_python/__init__.py @@ -17,25 +17,37 @@ from typing_extensions import Literal from supertokens_python.framework.request import BaseRequest +from supertokens_python.recipe_module import RecipeModule from supertokens_python.types import RecipeUserId -from . import plugins, supertokens +from .plugins import LoadPluginsResponse +from .supertokens import ( + AppInfo, + InputAppInfo, + RecipeInit, + Supertokens, + SupertokensConfig, + SupertokensExperimentalConfig, + SupertokensInputConfig, + SupertokensPublicConfig, +) -InputAppInfo = supertokens.InputAppInfo -Supertokens = supertokens.Supertokens -SupertokensConfig = supertokens.SupertokensConfig -AppInfo = supertokens.AppInfo -SupertokensExperimentalConfig = supertokens.SupertokensExperimentalConfig +# Some Pydantic models need a rebuild to resolve ForwardRefs +# Referencing imports here to prevent lint errors. +# Caveat: These will be available for import from this module directly. +RecipeModule # type: ignore -SupertokensPublicConfig = supertokens.SupertokensPublicConfig -plugins.LoadPluginsResponse.model_rebuild() +# LoadPluginsResponse -> SupertokensPublicConfig +LoadPluginsResponse.model_rebuild() +# SupertokensInputConfig -> RecipeModule +SupertokensInputConfig.model_rebuild() def init( app_info: InputAppInfo, framework: Literal["fastapi", "flask", "django"], supertokens_config: SupertokensConfig, - recipe_list: List[supertokens.RecipeInit], + recipe_list: List[RecipeInit], mode: Optional[Literal["asgi", "wsgi"]] = None, telemetry: Optional[bool] = None, debug: Optional[bool] = None, @@ -54,7 +66,7 @@ def init( def get_all_cors_headers() -> List[str]: - return supertokens.Supertokens.get_instance().get_all_cors_headers() + return Supertokens.get_instance().get_all_cors_headers() def get_request_from_user_context( @@ -65,3 +77,19 @@ def get_request_from_user_context( def convert_to_recipe_user_id(user_id: str) -> RecipeUserId: return RecipeUserId(user_id) + + +__all__ = [ + "AppInfo", + "InputAppInfo", + "RecipeInit", + "RecipeUserId", + "Supertokens", + "SupertokensConfig", + "SupertokensExperimentalConfig", + "SupertokensPublicConfig", + "convert_to_recipe_user_id", + "get_all_cors_headers", + "get_request_from_user_context", + "init", +] From decc4cd10a016bec8631c99b5fbbd7b50edfef13 Mon Sep 17 00:00:00 2001 From: Namit Nathwani Date: Thu, 24 Jul 2025 10:52:01 +0530 Subject: [PATCH 27/37] feat: adds way to check if recipe is initialized --- CHANGELOG.md | 1 + supertokens_python/__init__.py | 4 ++++ supertokens_python/supertokens.py | 12 ++++++++++++ 3 files changed, 17 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7b4f0f8da..535a5dc40 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Adds an `experimental` property (`SuperTokensExperimentalConfig`) to the `SuperTokensConfig` - Plugins can be configured under using the `plugins` property in the `experimental` config - Refactors the AccountLinking recipe to be automatically initialized on SuperTokens init +- Adds `is_recipe_initialized` method to check if a recipe has been initialized ### Breaking Changes - `AccountLinkingRecipe.get_instance` will now raise an exception if not initialized diff --git a/supertokens_python/__init__.py b/supertokens_python/__init__.py index ee6407ff8..7923bb72d 100644 --- a/supertokens_python/__init__.py +++ b/supertokens_python/__init__.py @@ -79,6 +79,9 @@ def convert_to_recipe_user_id(user_id: str) -> RecipeUserId: return RecipeUserId(user_id) +is_recipe_initialized = Supertokens.is_recipe_initialized + + __all__ = [ "AppInfo", "InputAppInfo", @@ -92,4 +95,5 @@ def convert_to_recipe_user_id(user_id: str) -> RecipeUserId: "get_all_cors_headers", "get_request_from_user_context", "init", + "is_recipe_initialized", ] diff --git a/supertokens_python/supertokens.py b/supertokens_python/supertokens.py index 45dba892f..39c90bc85 100644 --- a/supertokens_python/supertokens.py +++ b/supertokens_python/supertokens.py @@ -900,6 +900,18 @@ def get_request_from_user_context( return user_context.get("_default", {}).get("request") + @staticmethod + def is_recipe_initialized(recipe_id: str) -> bool: + """ + Check if a recipe is initialized. + :param recipe_id: The ID of the recipe to check. + :return: Whether the recipe is initialized. + """ + return any( + recipe.get_recipe_id() == recipe_id + for recipe in Supertokens.get_instance().recipe_modules + ) + def get_request_from_user_context( user_context: Optional[Dict[str, Any]], From e7869a9045ff632de9bb7bbccac44baf36f8eb19 Mon Sep 17 00:00:00 2001 From: Namit Nathwani Date: Fri, 25 Jul 2025 12:59:38 +0530 Subject: [PATCH 28/37] feat: use normalised app info in public config --- supertokens_python/plugins.py | 4 ++- supertokens_python/supertokens.py | 47 +++++++++++++++++++++++++------ 2 files changed, 42 insertions(+), 9 deletions(-) diff --git a/supertokens_python/plugins.py b/supertokens_python/plugins.py index 412184357..b11c16c13 100644 --- a/supertokens_python/plugins.py +++ b/supertokens_python/plugins.py @@ -393,7 +393,9 @@ def load_plugins( if VERSION not in version_constraints: # TODO: Better checks - raise Exception("Plugin version mismatch") + raise Exception( + f"Plugin version mismatch. Version {VERSION} not in {version_constraints=} for plugin {plugin.id}" + ) # TODO: Overkill, but could topologically sort the plugins based on dependencies dependencies = plugin.get_dependencies( diff --git a/supertokens_python/supertokens.py b/supertokens_python/supertokens.py index 39c90bc85..51612a04c 100644 --- a/supertokens_python/supertokens.py +++ b/supertokens_python/supertokens.py @@ -227,12 +227,11 @@ class SupertokensExperimentalConfig(CamelCaseBaseModel): plugins: Optional[List["SuperTokensPlugin"]] = None -class SupertokensPublicConfig(CamelCaseBaseModel): +class _BaseSupertokensPublicConfig(CamelCaseBaseModel): """ Public properties received as input to the `Supertokens.init` function. """ - app_info: InputAppInfo framework: Literal["fastapi", "flask", "django"] supertokens_config: SupertokensConfig mode: Optional[Literal["asgi", "wsgi"]] @@ -240,17 +239,33 @@ class SupertokensPublicConfig(CamelCaseBaseModel): debug: Optional[bool] -class SupertokensInputConfig(SupertokensPublicConfig): +class SupertokensPublicConfig(_BaseSupertokensPublicConfig): """ - Various properties received as input to the `Supertokens.init` function. + Public properties received as input to the `Supertokens.init` function. """ + app_info: AppInfo # Uses the Normalised AppInfo class + + +class _BaseSupertokensInputConfig(_BaseSupertokensPublicConfig): recipe_list: List[Callable[[AppInfo, List["OverrideMap"]], "RecipeModule"]] experimental: Optional[SupertokensExperimentalConfig] = None - def get_public_config(self) -> SupertokensPublicConfig: + +class SupertokensInputConfigWithNormalisedAppInfo(_BaseSupertokensInputConfig): + app_info: AppInfo + + +class SupertokensInputConfig(_BaseSupertokensInputConfig): + """ + Various properties received as input to the `Supertokens.init` function. + """ + + app_info: InputAppInfo + + def to_public_config(self, normalised_app_info: AppInfo) -> SupertokensPublicConfig: return SupertokensPublicConfig( - app_info=self.app_info, + app_info=normalised_app_info, framework=self.framework, supertokens_config=self.supertokens_config, mode=self.mode, @@ -262,11 +277,12 @@ def get_public_config(self) -> SupertokensPublicConfig: def from_public_config( cls, config: SupertokensPublicConfig, + app_info: InputAppInfo, recipe_list: List[Callable[[AppInfo, List["OverrideMap"]], "RecipeModule"]], experimental: Optional[SupertokensExperimentalConfig], ) -> "SupertokensInputConfig": return cls( - app_info=config.app_info, + app_info=app_info, framework=config.framework, supertokens_config=config.supertokens_config, mode=config.mode, @@ -308,6 +324,18 @@ def __init__( if not isinstance(app_info, InputAppInfo): # type: ignore raise ValueError("app_info must be an instance of InputAppInfo") + self.app_info = AppInfo( + app_info.app_name, + app_info.api_domain, + app_info.website_domain, + framework, + app_info.api_gateway_path, + app_info.api_base_path, + app_info.website_base_path, + mode, + app_info.origin, + ) + input_config = SupertokensInputConfig( app_info=app_info, framework=framework, @@ -318,7 +346,9 @@ def __init__( debug=debug, experimental=experimental, ) - input_public_config = input_config.get_public_config() + input_public_config = input_config.to_public_config( + normalised_app_info=self.app_info + ) # Use the input public config by default if no plugins provided processed_public_config: SupertokensPublicConfig = input_public_config @@ -340,6 +370,7 @@ def __init__( config = SupertokensInputConfig.from_public_config( config=processed_public_config, + app_info=input_config.app_info, recipe_list=recipe_list, experimental=experimental, ) From 5b1f37e1f508fcf1846167f5109d572cbeff0274 Mon Sep 17 00:00:00 2001 From: Namit Nathwani Date: Fri, 25 Jul 2025 14:28:53 +0530 Subject: [PATCH 29/37] feat: adds additional check for duplicate plugins --- supertokens_python/plugins.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/supertokens_python/plugins.py b/supertokens_python/plugins.py index b11c16c13..0147b41df 100644 --- a/supertokens_python/plugins.py +++ b/supertokens_python/plugins.py @@ -406,6 +406,18 @@ def load_plugins( final_plugin_list.extend(dependencies) input_plugin_seen_list.update({dep.id for dep in dependencies}) + # Secondary check to ensure no duplicate plugins + # Should ideally be handled in the dependency resolution above. + unique_plugins: Set[str] = set() + duplicate_plugins: List[str] = [] + for plugin in final_plugin_list: + if plugin.id in unique_plugins: + duplicate_plugins.append(plugin.id) + unique_plugins.add(plugin.id) + + if len(duplicate_plugins) > 0: + raise Exception(f"Duplicate plugins found: {', '.join(duplicate_plugins)}") + processed_plugin_list = [ SuperTokensPublicPlugin.from_plugin(plugin) for plugin in final_plugin_list ] From 11a28ca86432fb267e4a2739edbe1d2f2531929c Mon Sep 17 00:00:00 2001 From: Namit Nathwani Date: Wed, 6 Aug 2025 18:06:26 +0530 Subject: [PATCH 30/37] feat: improves plugin version checks - Adds dependency on `packaging` to compare versions - Versions expected to follow PEP 440 style specifiers - Sorts items in `setup.py` for clarity ref: supertokens/supertokens-node#1021 --- dev-requirements.txt | 1 + setup.py | 39 ++++++++++++++------------- supertokens_python/plugins.py | 12 +++++---- tests/plugins/plugins.py | 8 +++++- tests/plugins/test_plugins.py | 50 ++++++++++++++++++++++++++++++++++- 5 files changed, 84 insertions(+), 26 deletions(-) diff --git a/dev-requirements.txt b/dev-requirements.txt index 2dddb9f97..99ffb6fe8 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -7,6 +7,7 @@ fastapi==0.115.5 Flask==3.0.3 flask-cors==5.0.0 nest-asyncio==1.6.0 +packaging==25.0 pdoc3==0.11.0 pre-commit==3.5.0 pyfakefs==5.7.4 diff --git a/setup.py b/setup.py index 31aa4d19a..fe85d89d7 100644 --- a/setup.py +++ b/setup.py @@ -61,23 +61,23 @@ } exclude_list = [ - "tests", - "examples", - "hooks", - ".gitignore", + ".circleci", ".git", + ".github", + ".gitignore", + ".pylintrc", + "Makefile", "addDevTag", "addReleaseTag", - "frontendDriverInterfaceSupported.json", "coreDriverInterfaceSupported.json", - ".github", - ".circleci", - "html", - "pyrightconfig.json", - "Makefile", - ".pylintrc", "dev-requirements.txt", "docs-templates", + "examples", + "frontendDriverInterfaceSupported.json", + "hooks", + "html", + "pyrightconfig.json", + "tests", ] setup( @@ -112,22 +112,23 @@ ], keywords="", install_requires=[ + "Deprecated<1.3.0", # [crypto] ensures that it installs the `cryptography` library as well # based on constraints specified in https://github.com/jpadilla/pyjwt/blob/master/setup.cfg#L50 "PyJWT[crypto]>=2.5.0,<3.0.0", - "httpx>=0.15.0,<1.0.0", - "pycryptodome<3.21.0", - "tldextract<6.0.0", + "aiosmtplib>=1.1.6,<4.0.0", "asgiref>=3.4.1,<4", - "typing_extensions>=4.1.1,<5.0.0", - "Deprecated<1.3.0", + "httpx>=0.15.0,<1.0.0", + "packaging>=25.0,<26.0", "phonenumbers<9", - "twilio<10", - "aiosmtplib>=1.1.6,<4.0.0", "pkce<1.1.0", + "pycryptodome<3.21.0", + "pydantic>=2.10.6,<3.0.0", "pyotp<3", "python-dateutil<3", - "pydantic>=2.10.6,<3.0.0", + "tldextract<6.0.0", + "twilio<10", + "typing_extensions>=4.1.1,<5.0.0", ], python_requires=">=3.8", include_package_data=True, diff --git a/supertokens_python/plugins.py b/supertokens_python/plugins.py index 0147b41df..902acfe32 100644 --- a/supertokens_python/plugins.py +++ b/supertokens_python/plugins.py @@ -14,6 +14,8 @@ runtime_checkable, ) +from packaging.specifiers import SpecifierSet +from packaging.version import Version from typing_extensions import Protocol from supertokens_python.constants import VERSION @@ -387,14 +389,14 @@ def load_plugins( continue if isinstance(plugin.compatible_sdk_versions, list): - version_constraints = plugin.compatible_sdk_versions + version_constraints = ",".join(plugin.compatible_sdk_versions) else: - version_constraints = [plugin.compatible_sdk_versions] + version_constraints = plugin.compatible_sdk_versions - if VERSION not in version_constraints: - # TODO: Better checks + if not SpecifierSet(version_constraints).contains(Version(VERSION)): raise Exception( - f"Plugin version mismatch. Version {VERSION} not in {version_constraints=} for plugin {plugin.id}" + f"Incompatible SDK version for plugin {plugin.id}. " + f"Version {VERSION} not found in compatible versions {version_constraints}" ) # TODO: Overkill, but could topologically sort the plugins based on dependencies diff --git a/tests/plugins/plugins.py b/tests/plugins/plugins.py index 463f96a7f..eaee4c828 100644 --- a/tests/plugins/plugins.py +++ b/tests/plugins/plugins.py @@ -89,6 +89,7 @@ def plugin_factory( override_config: bool = False, deps: Optional[List[SuperTokensPlugin]] = None, add_init: bool = False, + compatible_sdk_versions: Optional[Union[str, List[str]]] = None, ): override_map_obj: OverrideMap = {PluginTestRecipe.recipe_id: RecipePluginOverride()} @@ -110,9 +111,14 @@ def plugin_factory( if add_init: init_fn = init_factory(identifier) + if compatible_sdk_versions is None: + sdk_versions = f"=={VERSION}" + else: + sdk_versions = compatible_sdk_versions + class Plugin(SuperTokensPlugin): id: str = identifier - compatible_sdk_versions: Union[str, List[str]] = [VERSION] + compatible_sdk_versions: Union[str, List[str]] = sdk_versions override_map: Optional[OverrideMap] = override_map_obj init: Any = init_fn dependencies: Optional[SuperTokensPluginDependencies] = dependency_factory(deps) diff --git a/tests/plugins/test_plugins.py b/tests/plugins/test_plugins.py index 60d684537..efb70e504 100644 --- a/tests/plugins/test_plugins.py +++ b/tests/plugins/test_plugins.py @@ -1,5 +1,6 @@ from functools import partial -from typing import Any, List +from typing import Any, List, Union +from unittest.mock import patch from pytest import fixture, mark, param, raises from supertokens_python import ( @@ -450,3 +451,50 @@ async def test_route_handlers_callable(handler_response: Any, expectation: Any): ) assert res == expected_output + + +@mark.parametrize( + ("sdk_version", "compatible_versions", "expectation"), + [ + param( + "1.5.0", + ">=1.0.0,<2.0.0", + outputs(None), + id="[Valid][1.5.0][>=1.0.0,<2.0.0] as string", + ), + param( + "1.5.0", + [">=1.0.0", "<2.0.0"], + outputs(None), + id="[Valid][1.5.0][>=1.0.0,<2.0.0] as list of strings", + ), + param( + "2.0.0", + [">=1.0.0,<2.0.0"], + raises(Exception, match="Incompatible SDK version for plugin plugin1."), + id="[Invalid][2.0.0][>=1.0.0,<2.0.0]", + ), + ], +) +def test_versions( + sdk_version: str, + compatible_versions: Union[str, List[str]], + expectation: Any, +): + plugin = plugin_factory( + "plugin1", + override_functions=False, + override_apis=False, + compatible_sdk_versions=compatible_versions, + ) + + with patch("supertokens_python.plugins.VERSION", sdk_version): + with expectation as _: + partial_init( + recipe_list=[ + recipe_factory(override_functions=False, override_apis=False), + ], + experimental=SupertokensExperimentalConfig( + plugins=[plugin], + ), + ) From 09b4216a5677ce5b2fcf7bc3c1ac00505c19596a Mon Sep 17 00:00:00 2001 From: Namit Nathwani Date: Wed, 6 Aug 2025 19:30:41 +0530 Subject: [PATCH 31/37] feat: add error handling for plugin route handlers - Uses a new `PluginRouteHandlerWithPluginId` class internally - Adds a new `PluginError` and handles similar to `BadInputError` - Makes `handler` functions async to be consistent with Node ref: supertokens/supertokens-node#1021 --- supertokens_python/exceptions.py | 4 ++++ supertokens_python/plugins.py | 36 ++++++++++++++++++++++++++----- supertokens_python/supertokens.py | 27 ++++++++++++++++------- tests/plugins/test_plugins.py | 9 ++++++-- 4 files changed, 61 insertions(+), 15 deletions(-) diff --git a/supertokens_python/exceptions.py b/supertokens_python/exceptions.py index 30f4f8514..d7ccae5f4 100644 --- a/supertokens_python/exceptions.py +++ b/supertokens_python/exceptions.py @@ -40,3 +40,7 @@ class GeneralError(SuperTokensError): class BadInputError(SuperTokensError): pass + + +class PluginError(SuperTokensError): + pass diff --git a/supertokens_python/plugins.py b/supertokens_python/plugins.py index 902acfe32..5e615cf39 100644 --- a/supertokens_python/plugins.py +++ b/supertokens_python/plugins.py @@ -104,7 +104,7 @@ class PluginRouteHandlerResponse(CamelCaseBaseModel): @runtime_checkable class PluginRouteHandlerHandlerFunction(Protocol): - def __call__( + async def __call__( self, request: BaseRequest, response: BaseResponse, @@ -139,6 +139,24 @@ class PluginRouteHandler(CamelCaseBaseModel): verify_session_options: Optional[VerifySessionOptions] +class PluginRouteHandlerWithPluginId(PluginRouteHandler): + plugin_id: str + """ + This is useful when multiple plugins handle the same route. + """ + + @classmethod + def from_route_handler( + cls, + route_handler: PluginRouteHandler, + plugin_id: str, + ): + return cls( + **route_handler.model_dump(), + plugin_id=plugin_id, + ) + + @runtime_checkable class SuperTokensPluginInit(Protocol): def __call__( @@ -372,16 +390,17 @@ def api_override(original_implementation: APIInterfaceType) -> APIInterfaceType: class LoadPluginsResponse(CamelCaseBaseModel): public_config: "SupertokensPublicConfig" processed_plugins: List[SuperTokensPublicPlugin] - plugin_route_handlers: List[PluginRouteHandler] + plugin_route_handlers: List[PluginRouteHandlerWithPluginId] override_maps: List[OverrideMap] def load_plugins( - plugins: List[SuperTokensPlugin], public_config: "SupertokensPublicConfig" + plugins: List[SuperTokensPlugin], + public_config: "SupertokensPublicConfig", ) -> LoadPluginsResponse: input_plugin_seen_list: Set[str] = set() final_plugin_list: List[SuperTokensPlugin] = [] - plugin_route_handlers: List[PluginRouteHandler] = [] + plugin_route_handlers: List[PluginRouteHandlerWithPluginId] = [] for plugin in plugins: if plugin.id in input_plugin_seen_list: @@ -445,7 +464,14 @@ def load_plugins( else: handlers = plugin.route_handlers - plugin_route_handlers.extend(handlers) + plugin_route_handlers.extend( + [ + PluginRouteHandlerWithPluginId.from_route_handler( + handler, plugin.id + ) + for handler in handlers + ] + ) if plugin.init is not None: diff --git a/supertokens_python/supertokens.py b/supertokens_python/supertokens.py index 51612a04c..b34f3c8cf 100644 --- a/supertokens_python/supertokens.py +++ b/supertokens_python/supertokens.py @@ -38,13 +38,14 @@ from supertokens_python.plugins import ( OverrideMap, PluginRouteHandler, + PluginRouteHandlerWithPluginId, SuperTokensPlugin, SuperTokensPublicPlugin, ) from supertokens_python.types.response import CamelCaseBaseModel from .constants import FDI_KEY_HEADER, RID_KEY_HEADER, USER_COUNT -from .exceptions import SuperTokensError +from .exceptions import PluginError, SuperTokensError from .interfaces import ( CreateUserIdMappingOkResult, DeleteUserIdMappingOkResult, @@ -306,7 +307,7 @@ class Supertokens: telemetry: bool - plugin_route_handlers: List[PluginRouteHandler] + plugin_route_handlers: List[PluginRouteHandlerWithPluginId] plugin_list: List[SuperTokensPublicPlugin] @@ -764,12 +765,22 @@ async def middleware( override_global_claim_validators=verify_session_options.override_global_claim_validators, ) - return handler_from_apis.handler( - request=request, - response=response, - session=session, - user_context=user_context, + log_debug_message( + f"middleware: Request being handled by plugin `{handler_from_apis.plugin_id}`" ) + try: + return await handler_from_apis.handler( + request=request, + response=response, + session=session, + user_context=user_context, + ) + except PluginError as err: + log_debug_message( + f"middleware: Error from plugin `{handler_from_apis.plugin_id}`: {str(err)}. " + "Transforming to SuperTokensError." + ) + raise err if not path.startswith(Supertokens.get_instance().app_info.api_base_path): log_debug_message( @@ -899,7 +910,7 @@ async def handle_supertokens_error( if isinstance(err, GeneralError): raise err - if isinstance(err, BadInputError): + if isinstance(err, (BadInputError, PluginError)): log_debug_message("errorHandler: Sending 400 status code response") return send_non_200_response_with_message(str(err), 400, response) diff --git a/tests/plugins/test_plugins.py b/tests/plugins/test_plugins.py index efb70e504..f65efcdb5 100644 --- a/tests/plugins/test_plugins.py +++ b/tests/plugins/test_plugins.py @@ -1,5 +1,5 @@ from functools import partial -from typing import Any, List, Union +from typing import Any, Dict, List, Union from unittest.mock import patch from pytest import fixture, mark, param, raises @@ -377,10 +377,15 @@ def config_override(config: SupertokensPublicConfig) -> SupertokensPublicConfig: ) +# NOTE: Returning a string here to make it easier to write/test the handler +async def handler_fn(*_, **__: Dict[str, Any]) -> Any: + return "plugin1" + + plugin_route_handler = PluginRouteHandler( method="get", path="/auth/plugin1/hello", - handler=lambda *_, **__: "plugin1", # type: ignore + handler=handler_fn, # type: ignore - returns string for simplicity verify_session_options=None, ) From fb6e57a1592266a350bcaa211953ddee8ed1e90b Mon Sep 17 00:00:00 2001 From: Namit Nathwani Date: Wed, 13 Aug 2025 10:40:59 +0530 Subject: [PATCH 32/37] test: adds workflow to test `get-versions-from-repo` action --- .github/workflows/test-version-action.yml | 51 +++++++++++++++++++++++ 1 file changed, 51 insertions(+) create mode 100644 .github/workflows/test-version-action.yml diff --git a/.github/workflows/test-version-action.yml b/.github/workflows/test-version-action.yml new file mode 100644 index 000000000..03883c91c --- /dev/null +++ b/.github/workflows/test-version-action.yml @@ -0,0 +1,51 @@ +name: Test Versions Action + +on: + pull_request: + types: + - opened + - reopened + - synchronize + + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + versions: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + + - uses: supertokens/get-supported-versions-action@main + id: versions + with: + has-fdi: true + has-cdi: true + has-web-js: false + + - uses: supertokens/actions/get-versions-from-repo@main + with: + repo: supertokens-node + cdi-versions: ${{ steps.versions.outputs.cdiVersions }} + fdi-versions: ${{ steps.versions.outputs.fdiVersions }} + + - uses: supertokens/actions/get-versions-from-repo@main + with: + repo: supertokens-auth-react + cdi-versions: ${{ steps.versions.outputs.cdiVersions }} + fdi-versions: ${{ steps.versions.outputs.fdiVersions }} + + - uses: supertokens/actions/get-versions-from-repo@main + with: + repo: supertokens-website + cdi-versions: ${{ steps.versions.outputs.cdiVersions }} + fdi-versions: ${{ steps.versions.outputs.fdiVersions }} + + - uses: supertokens/actions/get-versions-from-repo@main + with: + repo: supertokens-core + cdi-versions: ${{ steps.versions.outputs.cdiVersions }} + fdi-versions: ${{ steps.versions.outputs.fdiVersions }} From b44c44151772c2efa258670fd7e1d8b8c407a49c Mon Sep 17 00:00:00 2001 From: Namit Nathwani Date: Wed, 13 Aug 2025 10:50:06 +0530 Subject: [PATCH 33/37] update: adds github token to workflows --- .github/workflows/test-version-action.yml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.github/workflows/test-version-action.yml b/.github/workflows/test-version-action.yml index 03883c91c..a4b01ee31 100644 --- a/.github/workflows/test-version-action.yml +++ b/.github/workflows/test-version-action.yml @@ -29,23 +29,27 @@ jobs: - uses: supertokens/actions/get-versions-from-repo@main with: repo: supertokens-node + github-token: ${{ secrets.GITHUB_TOKEN }} cdi-versions: ${{ steps.versions.outputs.cdiVersions }} fdi-versions: ${{ steps.versions.outputs.fdiVersions }} - uses: supertokens/actions/get-versions-from-repo@main with: repo: supertokens-auth-react + github-token: ${{ secrets.GITHUB_TOKEN }} cdi-versions: ${{ steps.versions.outputs.cdiVersions }} fdi-versions: ${{ steps.versions.outputs.fdiVersions }} - uses: supertokens/actions/get-versions-from-repo@main with: repo: supertokens-website + github-token: ${{ secrets.GITHUB_TOKEN }} cdi-versions: ${{ steps.versions.outputs.cdiVersions }} fdi-versions: ${{ steps.versions.outputs.fdiVersions }} - uses: supertokens/actions/get-versions-from-repo@main with: repo: supertokens-core + github-token: ${{ secrets.GITHUB_TOKEN }} cdi-versions: ${{ steps.versions.outputs.cdiVersions }} fdi-versions: ${{ steps.versions.outputs.fdiVersions }} From bcd9c723aa2597674c9eb88e6170cfd977747925 Mon Sep 17 00:00:00 2001 From: Namit Nathwani Date: Wed, 13 Aug 2025 10:54:07 +0530 Subject: [PATCH 34/37] update: only use versions available in repo for inputs --- .github/workflows/test-version-action.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/test-version-action.yml b/.github/workflows/test-version-action.yml index a4b01ee31..2659aedc7 100644 --- a/.github/workflows/test-version-action.yml +++ b/.github/workflows/test-version-action.yml @@ -37,14 +37,14 @@ jobs: with: repo: supertokens-auth-react github-token: ${{ secrets.GITHUB_TOKEN }} - cdi-versions: ${{ steps.versions.outputs.cdiVersions }} + # cdi-versions: ${{ steps.versions.outputs.cdiVersions }} fdi-versions: ${{ steps.versions.outputs.fdiVersions }} - uses: supertokens/actions/get-versions-from-repo@main with: repo: supertokens-website github-token: ${{ secrets.GITHUB_TOKEN }} - cdi-versions: ${{ steps.versions.outputs.cdiVersions }} + # cdi-versions: ${{ steps.versions.outputs.cdiVersions }} fdi-versions: ${{ steps.versions.outputs.fdiVersions }} - uses: supertokens/actions/get-versions-from-repo@main @@ -52,4 +52,4 @@ jobs: repo: supertokens-core github-token: ${{ secrets.GITHUB_TOKEN }} cdi-versions: ${{ steps.versions.outputs.cdiVersions }} - fdi-versions: ${{ steps.versions.outputs.fdiVersions }} + # fdi-versions: ${{ steps.versions.outputs.fdiVersions }} From cb599e90468de49979d360a53038bb55f14e2f08 Mon Sep 17 00:00:00 2001 From: Namit Nathwani Date: Wed, 13 Aug 2025 11:03:33 +0530 Subject: [PATCH 35/37] update: check for if outputs can be accessed in workflows --- .github/workflows/test-version-action.yml | 50 +++++++++++++++++++---- 1 file changed, 41 insertions(+), 9 deletions(-) diff --git a/.github/workflows/test-version-action.yml b/.github/workflows/test-version-action.yml index 2659aedc7..523f7dedb 100644 --- a/.github/workflows/test-version-action.yml +++ b/.github/workflows/test-version-action.yml @@ -13,8 +13,12 @@ concurrency: cancel-in-progress: true jobs: - versions: + supported-versions: runs-on: ubuntu-latest + outputs: + fdiVersions: ${{ steps.versions.outputs.fdiVersions }} + cdiVersions: ${{ steps.versions.outputs.cdiVersions }} + pyVersions: '["3.8", "3.13"]' steps: - uses: actions/checkout@v4 @@ -26,30 +30,58 @@ jobs: has-cdi: true has-web-js: false + get-versions: + runs-on: ubuntu-latest + needs: supported-versions + + strategy: + fail-fast: false + matrix: + fdi-version: ${{ fromJSON(needs.supported-versions.outputs.fdiVersions) }} + cdi-version: ${{ fromJSON(needs.supported-versions.outputs.cdiVersions) }} + + + steps: - uses: supertokens/actions/get-versions-from-repo@main + id: node with: repo: supertokens-node github-token: ${{ secrets.GITHUB_TOKEN }} - cdi-versions: ${{ steps.versions.outputs.cdiVersions }} - fdi-versions: ${{ steps.versions.outputs.fdiVersions }} + cdi-versions: ${{ needs.supported-versions.outputs.cdiVersions }} + fdi-versions: ${{ needs.supported-versions.outputs.fdiVersions }} + + - run: | + echo "Node version from CDI: ${{ steps.node.outputs.cdiVersions[matrix.cdi-version] }}" + echo "Node version from FDI: ${{ steps.node.outputs.fdiVersions[matrix.fdi-version] }}" - uses: supertokens/actions/get-versions-from-repo@main + id: auth-react with: repo: supertokens-auth-react github-token: ${{ secrets.GITHUB_TOKEN }} - # cdi-versions: ${{ steps.versions.outputs.cdiVersions }} - fdi-versions: ${{ steps.versions.outputs.fdiVersions }} + # cdi-versions: ${{ needs.supported-versions.outputs.cdiVersions }} + fdi-versions: ${{ needs.supported-versions.outputs.fdiVersions }} + + - run: | + echo "Auth react version from FDI: ${{ steps.auth-react.outputs.fdiVersions[matrix.fdi-version] }}" - uses: supertokens/actions/get-versions-from-repo@main + id: website with: repo: supertokens-website github-token: ${{ secrets.GITHUB_TOKEN }} - # cdi-versions: ${{ steps.versions.outputs.cdiVersions }} - fdi-versions: ${{ steps.versions.outputs.fdiVersions }} + # cdi-versions: ${{ needs.supported-versions.outputs.cdiVersions }} + fdi-versions: ${{ needs.supported-versions.outputs.fdiVersions }} + + - run: | + echo "Website version from FDI: ${{ steps.website.outputs.fdiVersions[matrix.fdi-version] }}" - uses: supertokens/actions/get-versions-from-repo@main + id: core with: repo: supertokens-core github-token: ${{ secrets.GITHUB_TOKEN }} - cdi-versions: ${{ steps.versions.outputs.cdiVersions }} - # fdi-versions: ${{ steps.versions.outputs.fdiVersions }} + cdi-versions: ${{ needs.supported-versions.outputs.cdiVersions }} + # fdi-versions: ${{ needs.supported-versions.outputs.fdiVersions }} + - run: | + echo "Core version from CDI: ${{ steps.core.outputs.cdiVersions[matrix.cdi-version] }}" From 3e1bfbf17eb476913da5b994d1098daf728377d7 Mon Sep 17 00:00:00 2001 From: Namit Nathwani Date: Wed, 13 Aug 2025 11:06:55 +0530 Subject: [PATCH 36/37] update: try to use jq --- .github/workflows/test-version-action.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/test-version-action.yml b/.github/workflows/test-version-action.yml index 523f7dedb..9e4f04ede 100644 --- a/.github/workflows/test-version-action.yml +++ b/.github/workflows/test-version-action.yml @@ -51,7 +51,8 @@ jobs: fdi-versions: ${{ needs.supported-versions.outputs.fdiVersions }} - run: | - echo "Node version from CDI: ${{ steps.node.outputs.cdiVersions[matrix.cdi-version] }}" + nodeCdi=$( echo '${{ steps.node.outputs.cdiVersions }}' | jq -r '.["${{ matrix.cdi-version }}"]' ) + echo "Node version from CDI: $nodeCdi" echo "Node version from FDI: ${{ steps.node.outputs.fdiVersions[matrix.fdi-version] }}" - uses: supertokens/actions/get-versions-from-repo@main From d1ab29103187612d13c854807f575a5ed8e6bcdb Mon Sep 17 00:00:00 2001 From: Namit Nathwani Date: Wed, 13 Aug 2025 11:09:28 +0530 Subject: [PATCH 37/37] update: set all steps to use jq --- .github/workflows/test-version-action.yml | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/.github/workflows/test-version-action.yml b/.github/workflows/test-version-action.yml index 9e4f04ede..b3279b7a9 100644 --- a/.github/workflows/test-version-action.yml +++ b/.github/workflows/test-version-action.yml @@ -51,9 +51,10 @@ jobs: fdi-versions: ${{ needs.supported-versions.outputs.fdiVersions }} - run: | - nodeCdi=$( echo '${{ steps.node.outputs.cdiVersions }}' | jq -r '.["${{ matrix.cdi-version }}"]' ) - echo "Node version from CDI: $nodeCdi" - echo "Node version from FDI: ${{ steps.node.outputs.fdiVersions[matrix.fdi-version] }}" + cdiVersion=$( echo '${{ steps.node.outputs.cdiVersions }}' | jq -r '.["${{ matrix.cdi-version }}"]' ) + echo "Node version from CDI: $cdiVersion" + fdiVersion=$( echo '${{ steps.node.outputs.fdiVersions }}' | jq -r '.["${{ matrix.fdi-version }}"]' ) + echo "Node version from FDI: $fdiVersion" - uses: supertokens/actions/get-versions-from-repo@main id: auth-react @@ -64,7 +65,8 @@ jobs: fdi-versions: ${{ needs.supported-versions.outputs.fdiVersions }} - run: | - echo "Auth react version from FDI: ${{ steps.auth-react.outputs.fdiVersions[matrix.fdi-version] }}" + fdiVersion=$( echo '${{ steps.auth-react.outputs.fdiVersions }}' | jq -r '.["${{ matrix.fdi-version }}"]' ) + echo "Auth React version from FDI: $fdiVersion" - uses: supertokens/actions/get-versions-from-repo@main id: website @@ -75,7 +77,8 @@ jobs: fdi-versions: ${{ needs.supported-versions.outputs.fdiVersions }} - run: | - echo "Website version from FDI: ${{ steps.website.outputs.fdiVersions[matrix.fdi-version] }}" + fdiVersion=$( echo '${{ steps.website.outputs.fdiVersions }}' | jq -r '.["${{ matrix.fdi-version }}"]' ) + echo "website version from FDI: $fdiVersion" - uses: supertokens/actions/get-versions-from-repo@main id: core @@ -85,4 +88,5 @@ jobs: cdi-versions: ${{ needs.supported-versions.outputs.cdiVersions }} # fdi-versions: ${{ needs.supported-versions.outputs.fdiVersions }} - run: | - echo "Core version from CDI: ${{ steps.core.outputs.cdiVersions[matrix.cdi-version] }}" + cdiVersion=$( echo '${{ steps.core.outputs.cdiVersions }}' | jq -r '.["${{ matrix.cdi-version }}"]' ) + echo "core version from CDI: $cdiVersion"