From 7086a11ce64e692ff2161ac102f3517db4b584b9 Mon Sep 17 00:00:00 2001 From: Afonso Barracha Date: Sun, 10 Aug 2025 21:50:59 +1200 Subject: [PATCH 01/23] chore: update project.md --- project.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/project.md b/project.md index 36f2bea..8c132a0 100644 --- a/project.md +++ b/project.md @@ -23,16 +23,16 @@ - backend - device - service + - mcp ### IDP On-Going -- Add MCP app support - -### IDP Todo - - Add OAuth Dynamic Registration for: - accounts - apps + +### IDP Todo + - Account key generation - Dynamic OIDC configs - User authentication for each app type: From 6ee98b63fdb09bbafa2e373e4806ae3102bf802a Mon Sep 17 00:00:00 2001 From: Afonso Barracha Date: Wed, 13 Aug 2025 19:27:02 +1200 Subject: [PATCH 02/23] feat(idp): add full dynamic registration options to account credentials --- idp/initial_schema.dbml | 100 +++--- .../controllers/account_credentials.go | 20 +- .../controllers/bodies/account_credentials.go | 31 +- idp/internal/controllers/bodies/accounts.go | 4 +- idp/internal/controllers/bodies/apps.go | 4 +- idp/internal/controllers/bodies/auth.go | 6 +- idp/internal/controllers/bodies/users.go | 8 +- .../database/account_credentials.sql.go | 208 ++++++++++--- idp/internal/providers/database/apps.sql.go | 9 +- ...0241213231542_create_initial_schema.up.sql | 76 +++-- idp/internal/providers/database/models.go | 104 ++++--- .../database/queries/account_credentials.sql | 43 ++- .../providers/database/queries/apps.sql | 3 +- idp/internal/services/account_credentials.go | 286 +++++++++++++++--- idp/internal/services/apps.go | 31 +- .../services/dtos/account_credentials.go | 89 +++++- idp/internal/services/helpers.go | 24 +- idp/internal/services/oauth.go | 3 +- idp/tests/account_credentials_test.go | 36 +-- idp/tests/oauth_test.go | 4 +- 20 files changed, 760 insertions(+), 329 deletions(-) diff --git a/idp/initial_schema.dbml b/idp/initial_schema.dbml index 9dbd92f..1571a53 100644 --- a/idp/initial_schema.dbml +++ b/idp/initial_schema.dbml @@ -118,9 +118,9 @@ Table accounts as A { id serial [pk] public_id uuid [not null] - given_name varchar(50) [not null] - family_name varchar(50) [not null] - username varchar(63) [not null] + given_name varchar(100) [not null] + family_name varchar(100) [not null] + username varchar(63) [not null] // maximum length of a DNS label email varchar(250) [not null] organization varchar(50) password text @@ -307,37 +307,11 @@ Enum account_credentials_scope { } Enum account_credentials_type { - "client" + "native" + "service" "mcp" } -Table account_credentials as AC { - id serial [pk] - - account_id integer [not null] - account_public_id uuid [not null] - - credentials_type account_credentials_type [not null] - scopes "account_credentials_scope[]" [not null] - token_endpoint_auth_method "auth_method" [not null] - issuers "varchar(255)[]" [not null] - - alias varchar(100) [not null] - client_id varchar(22) [not null] - - created_at timestamptz [not null, default: `now()`] - updated_at timestamptz [not null, default: `now()`] - - Indexes { - (client_id) [unique, name: 'account_credentials_client_id_uidx'] - (account_id) [name: 'account_credentials_account_id_idx'] - (account_public_id) [name: 'account_credentials_account_public_id_idx'] - (account_public_id, client_id) [name: 'account_credentials_account_public_id_client_id_idx'] - (alias, account_id) [unique, name: 'account_credentials_alias_account_id_uidx'] - } -} -Ref: AC.account_id > A.id [delete: cascade] - Enum transport { "http" "https" @@ -350,40 +324,45 @@ Enum creation_method { "dynamic_registration" } -Table account_credentials_mcps as ACM { +Table account_credentials as AC { id serial [pk] account_id integer [not null] account_public_id uuid [not null] - account_credentials_id integer [not null] - account_credentials_client_id varchar(22) [not null] - creation_method creation_method [not null] + client_id varchar(22) [not null] + name varchar(255) [not null] + domain "varchar(250)" [not null] + credentials_type account_credentials_type [not null] + scopes "account_credentials_scope[]" [not null] + token_endpoint_auth_method "auth_method" [not null] + grant_types "grant_type[]" [not null] + version integer [not null, default: 1] + transport transport [not null] + creation_method creation_method [not null] - response_types "response_type[]" [not null] - callback_uris "varchar(2048)[]" [not null] client_uri varchar(512) [not null] + redirect_uris "varchar(2048)[]" [not null] logo_uri varchar(512) [null] policy_uri varchar(512) [null] tos_uri varchar(512) [null] software_id varchar(512) [not null] software_version varchar(512) [null] - contacts "varchar(512)[]" [not null, default: '{}'] + contacts "varchar(250)[]" [not null] created_at timestamptz [not null, default: `now()`] updated_at timestamptz [not null, default: `now()`] Indexes { - (account_id) [name: 'account_credentials_mcp_account_id_idx'] - (account_public_id) [name: 'account_credentials_mcp_account_public_id_idx'] - (account_credentials_id) [unique, name: 'account_credentials_mcp_account_credentials_id_uidx'] - (account_credentials_client_id) [name: 'account_credentials_mcp_account_credentials_client_id_idx'] - (account_credentials_id, software_id) [unique, name: 'account_credentials_mcp_account_credentials_id_software_id_uidx'] + (client_id) [unique, name: 'account_credentials_client_id_uidx'] + (account_id) [name: 'account_credentials_account_id_idx'] + (account_public_id) [name: 'account_credentials_account_public_id_idx'] + (account_public_id, client_id) [name: 'account_credentials_account_public_id_client_id_idx'] + (name, account_id) [unique, name: 'account_credentials_name_account_id_uidx'] } } -Ref: ACM.account_id > A.id [delete: cascade] -Ref: ACM.account_credentials_id > AC.id [delete: cascade] +Ref: AC.account_id > A.id [delete: cascade] Table account_credentials_secrets as ACS { account_id integer [not null] @@ -532,7 +511,7 @@ Table users as U { account_id integer [not null] email varchar(250) [not null] - username varchar(250) [not null] + username varchar(63) [not null] // maximum length of a DNS label password text version integer [not null, default: 1] email_verified boolean [not null, default: false] @@ -721,7 +700,7 @@ Table apps as APP { account_public_id uuid [not null] app_type app_type [not null] - name varchar(100) [not null] + name varchar(255) [not null] client_id varchar(22) [not null] version integer [not null, default: 1] creation_method creation_method [not null] @@ -888,7 +867,28 @@ Enum software_statement_verification_method { "jwks_uri" } -Table dynamic_registration_configs as DRC { +Table account_dynamic_registration_configs as ADRC { + id serial [pk] + + account_id integer [not null] + + whitelisted_domains "varchar(250)[]" [not null] + require_software_statement boolean [not null] + software_statement_verification_methods "software_statement_verification_method[]" [not null] + + require_initial_access_token boolean [not null] + initial_access_token_generation_methods "initial_access_token_generation_method[]" [not null] + + created_at timestamptz [not null, default: `now()`] + updated_at timestamptz [not null, default: `now()`] + + Indexes { + (account_id) [name: 'account_dynamic_registration_configs_account_id_idx'] + } +} +Ref: ADRC.account_id > A.id [delete: cascade] + +Table app_dynamic_registration_configs as APDRC { id serial [pk] account_id integer [not null] @@ -918,10 +918,10 @@ Table dynamic_registration_configs as DRC { updated_at timestamptz [not null, default: `now()`] Indexes { - (account_id) [name: 'dynamic_registrations_configs_account_id_idx'] + (account_id) [name: 'app_dynamic_registration_configs_account_id_idx'] } } -Ref: DRC.account_id > A.id [delete: cascade] +Ref: APDRC.account_id > A.id [delete: cascade] Enum app_profile_type { "human" diff --git a/idp/internal/controllers/account_credentials.go b/idp/internal/controllers/account_credentials.go index 48e1da5..18a7411 100644 --- a/idp/internal/controllers/account_credentials.go +++ b/idp/internal/controllers/account_credentials.go @@ -49,10 +49,15 @@ func (c *Controllers) CreateAccountCredentials(ctx *fiber.Ctx) error { RequestID: requestID, AccountPublicID: accountClaims.AccountID, AccountVersion: accountClaims.AccountVersion, - Alias: body.Alias, + Name: body.Name, Scopes: body.Scopes, - AuthMethod: body.AuthMethod, - Issuers: body.Issuers, + AuthMethod: body.TokenEndpointAuthMethod, + Domain: body.Domain, + ClientURI: body.ClientURI, + RedirectURIs: body.RedirectURIs, + LogoURI: body.LogoURI, + TOSURI: body.TOSURI, + PolicyURI: body.PolicyURI, }, ) if serviceErr != nil { @@ -170,8 +175,13 @@ func (c *Controllers) UpdateAccountCredentials(ctx *fiber.Ctx) error { AccountVersion: accountClaims.AccountVersion, ClientID: urlParams.ClientID, Scopes: body.Scopes, - Alias: body.Alias, - Issuers: body.Issuers, + Transport: body.Transport, + Domain: body.Domain, + ClientURI: body.ClientURI, + RedirectURIs: body.RedirectURIs, + LogoURI: body.LogoURI, + TOSURI: body.TOSURI, + PolicyURI: body.PolicyURI, }, ) if serviceErr != nil { diff --git a/idp/internal/controllers/bodies/account_credentials.go b/idp/internal/controllers/bodies/account_credentials.go index e30d344..2609e6b 100644 --- a/idp/internal/controllers/bodies/account_credentials.go +++ b/idp/internal/controllers/bodies/account_credentials.go @@ -7,15 +7,30 @@ package bodies type CreateAccountCredentialsBody struct { - Scopes []string `json:"scopes" validate:"required,unique,dive,oneof=account:admin account:users:read account:users:write account:apps:read account:apps:write account:credentials:read account:credentials:write account:auth_providers:read"` - Alias string `json:"alias" validate:"required,min=1,max=50,slug"` - AuthMethod string `json:"auth_method" validate:"required,oneof=client_secret_basic client_secret_post client_secret_jwt private_key_jwt"` - Issuers []string `json:"issuers,omitempty" validate:"required_if=AuthMethod private_key_jwt,unique,dive,url"` - Algorithm string `json:"algorithm,omitempty" validate:"omitempty,oneof=ES256 EdDSA"` + Type string `json:"type" validate:"required,oneof=native service mcp"` + Name string `json:"name" validate:"required,min=1,max=255"` + Scopes []string `json:"scopes" validate:"required,unique,dive,oneof=email profile account:admin account:users:read account:users:write account:apps:read account:apps:write account:credentials:read account:credentials:write account:auth_providers:read"` + TokenEndpointAuthMethod string `json:"token_endpoint_auth_method" validate:"required,oneof=client_secret_basic client_secret_post client_secret_jwt private_key_jwt"` + Domain string `json:"domain,omitempty" validate:"omitempty,fqdn,max=250"` + ClientURI string `json:"client_uri" validate:"required,uri"` + RedirectURIs []string `json:"redirect_uris,omitempty" validate:"omitempty,unique,dive,uri"` + LogoURI string `json:"logo_uri,omitempty" validate:"omitempty,uri"` + TOSURI string `json:"tos_uri,omitempty" validate:"omitempty,uri"` + PolicyURI string `json:"policy_uri,omitempty" validate:"omitempty,uri"` + SoftwareID string `json:"software_id" validate:"required,min=1,max=100"` + SoftwareVersion string `json:"software_version" validate:"required,min=1,max=100"` + Algorithm string `json:"algorithm,omitempty" validate:"omitempty,oneof=ES256 EdDSA"` } type UpdateAccountCredentialsBody struct { - Scopes []string `json:"scopes" validate:"required,unique,dive,oneof=account:admin account:users:read account:users:write account:apps:read account:apps:write account:credentials:read account:credentials:write account:auth_providers:read"` - Alias string `json:"alias" validate:"required,min=1,max=50,slug"` - Issuers []string `json:"issuers" validate:"required,unique,dive,url"` + Name string `json:"name" validate:"required,min=1,max=255"` + Scopes []string `json:"scopes" validate:"required,unique,dive,oneof=account:admin account:users:read account:users:write account:apps:read account:apps:write account:credentials:read account:credentials:write account:auth_providers:read"` + Transport string `json:"transport" validate:"required,oneof=http https stdio streamable_http"` + Domain string `json:"domain,omitempty" validate:"omitempty,fqdn,max=250"` + ClientURI string `json:"client_uri" validate:"required,uri"` + RedirectURIs []string `json:"redirect_uris,omitempty" validate:"omitempty,unique,dive,uri"` + LogoURI string `json:"logo_uri,omitempty" validate:"omitempty,uri"` + TOSURI string `json:"tos_uri,omitempty" validate:"omitempty,uri"` + PolicyURI string `json:"policy_uri,omitempty" validate:"omitempty,uri"` + SoftwareVersion string `json:"software_version,omitempty" validate:"omitempty,min=1,max=100"` } diff --git a/idp/internal/controllers/bodies/accounts.go b/idp/internal/controllers/bodies/accounts.go index fb2d17f..5b41cf4 100644 --- a/idp/internal/controllers/bodies/accounts.go +++ b/idp/internal/controllers/bodies/accounts.go @@ -7,6 +7,6 @@ package bodies type UpdateAccountBody struct { - GivenName string `json:"given_name" validate:"required,min=2,max=50"` - FamilyName string `json:"family_name" validate:"required,min=2,max=50"` + GivenName string `json:"given_name" validate:"required,min=2,max=100"` + FamilyName string `json:"family_name" validate:"required,min=2,max=100"` } diff --git a/idp/internal/controllers/bodies/apps.go b/idp/internal/controllers/bodies/apps.go index 169864a..ad13063 100644 --- a/idp/internal/controllers/bodies/apps.go +++ b/idp/internal/controllers/bodies/apps.go @@ -8,7 +8,7 @@ package bodies type CreateAppBodyBase struct { Type string `json:"type" validate:"required,oneof=web spa native backend device service mcp"` - Name string `json:"name" validate:"required,min=3,max=50"` + Name string `json:"name" validate:"required,min=1,max=255"` Domain string `json:"domain" validate:"omitempty,fqdn,max=250"` ClientURI string `json:"client_uri" validate:"required,url"` LogoURI string `json:"logo_uri,omitempty" validate:"omitempty,url"` @@ -25,7 +25,7 @@ type CreateAppBodyBase struct { } type UpdateAppBodyBase struct { - Name string `json:"name" validate:"required,max=50,min=3"` + Name string `json:"name" validate:"required,max=255,min=1"` Domain string `json:"domain" validate:"omitempty,fqdn,max=250"` ClientURI string `json:"client_uri" validate:"required,url"` LogoURI string `json:"logo_uri,omitempty" validate:"omitempty,url"` diff --git a/idp/internal/controllers/bodies/auth.go b/idp/internal/controllers/bodies/auth.go index 87fd7f0..8131f6b 100644 --- a/idp/internal/controllers/bodies/auth.go +++ b/idp/internal/controllers/bodies/auth.go @@ -7,9 +7,9 @@ package bodies type RegisterAccountBody struct { - Email string `json:"email" validate:"required,email"` - GivenName string `json:"given_name" validate:"required,min=2,max=50"` - FamilyName string `json:"family_name" validate:"required,min=2,max=50"` + Email string `json:"email" validate:"required,email,max=250"` + GivenName string `json:"given_name" validate:"required,min=2,max=100"` + FamilyName string `json:"family_name" validate:"required,min=2,max=100"` Username string `json:"username,omitempty" validate:"omitempty,min=3,max=63,slug"` Password string `json:"password" validate:"required,min=8,max=100,password"` Password2 string `json:"password2" validate:"required,eqfield=Password"` diff --git a/idp/internal/controllers/bodies/users.go b/idp/internal/controllers/bodies/users.go index 6876773..7275f41 100644 --- a/idp/internal/controllers/bodies/users.go +++ b/idp/internal/controllers/bodies/users.go @@ -3,15 +3,15 @@ package bodies type UserData = map[string]any type CreateUserBody struct { - Email string `json:"email" validate:"required,email"` - Username string `json:"username,omitempty" validate:"omitempty,min=3,max=100,slug"` + Email string `json:"email" validate:"required,email,max=250"` + Username string `json:"username,omitempty" validate:"omitempty,min=3,max=63,slug"` Password string `json:"password" validate:"required,min=8,max=100,password"` UserData } type UpdateUserBody struct { - Email string `json:"email" validate:"omitempty,email"` - Username string `json:"username,omitempty" validate:"omitempty,min=3,max=100,slug"` + Email string `json:"email" validate:"omitempty,email,max=250"` + Username string `json:"username,omitempty" validate:"omitempty,min=3,max=63,slug"` IsActive bool `json:"is_active"` UserData } diff --git a/idp/internal/providers/database/account_credentials.sql.go b/idp/internal/providers/database/account_credentials.sql.go index caac08d..691b54d 100644 --- a/idp/internal/providers/database/account_credentials.sql.go +++ b/idp/internal/providers/database/account_credentials.sql.go @@ -9,6 +9,7 @@ import ( "context" "github.com/google/uuid" + "github.com/jackc/pgx/v5/pgtype" ) const countAccountCredentialsByAccountPublicID = `-- name: CountAccountCredentialsByAccountPublicID :one @@ -24,18 +25,18 @@ func (q *Queries) CountAccountCredentialsByAccountPublicID(ctx context.Context, return count, err } -const countAccountCredentialsByAliasAndAccountID = `-- name: CountAccountCredentialsByAliasAndAccountID :one +const countAccountCredentialsByNameAndAccountID = `-- name: CountAccountCredentialsByNameAndAccountID :one SELECT COUNT(*) FROM "account_credentials" -WHERE "account_id" = $1 AND "alias" = $2 +WHERE "account_id" = $1 AND "name" = $2 ` -type CountAccountCredentialsByAliasAndAccountIDParams struct { +type CountAccountCredentialsByNameAndAccountIDParams struct { AccountID int32 - Alias string + Name string } -func (q *Queries) CountAccountCredentialsByAliasAndAccountID(ctx context.Context, arg CountAccountCredentialsByAliasAndAccountIDParams) (int64, error) { - row := q.db.QueryRow(ctx, countAccountCredentialsByAliasAndAccountID, arg.AccountID, arg.Alias) +func (q *Queries) CountAccountCredentialsByNameAndAccountID(ctx context.Context, arg CountAccountCredentialsByNameAndAccountIDParams) (int64, error) { + row := q.db.QueryRow(ctx, countAccountCredentialsByNameAndAccountID, arg.AccountID, arg.Name) var count int64 err := row.Scan(&count) return count, err @@ -47,10 +48,20 @@ INSERT INTO "account_credentials" ( "account_id", "account_public_id", "credentials_type", - "alias", + "name", "scopes", "token_endpoint_auth_method", - "issuers" + "domain", + "client_uri", + "redirect_uris", + "logo_uri", + "policy_uri", + "tos_uri", + "software_id", + "software_version", + "contacts", + "creation_method", + "transport" ) VALUES ( $1, $2, @@ -59,8 +70,18 @@ INSERT INTO "account_credentials" ( $5, $6, $7, - $8 -) RETURNING id, account_id, account_public_id, credentials_type, scopes, token_endpoint_auth_method, issuers, alias, client_id, created_at, updated_at + $8, + $9, + $10, + $11, + $12, + $13, + $14, + $15, + $16, + $17, + $18 +) RETURNING id, account_id, account_public_id, client_id, name, domain, credentials_type, scopes, token_endpoint_auth_method, grant_types, version, transport, creation_method, client_uri, redirect_uris, logo_uri, policy_uri, tos_uri, software_id, software_version, contacts, created_at, updated_at ` type CreateAccountCredentialsParams struct { @@ -68,10 +89,20 @@ type CreateAccountCredentialsParams struct { AccountID int32 AccountPublicID uuid.UUID CredentialsType AccountCredentialsType - Alias string + Name string Scopes []AccountCredentialsScope TokenEndpointAuthMethod AuthMethod - Issuers []string + Domain string + ClientUri string + RedirectUris []string + LogoUri pgtype.Text + PolicyUri pgtype.Text + TosUri pgtype.Text + SoftwareID string + SoftwareVersion pgtype.Text + Contacts []string + CreationMethod CreationMethod + Transport Transport } func (q *Queries) CreateAccountCredentials(ctx context.Context, arg CreateAccountCredentialsParams) (AccountCredential, error) { @@ -80,22 +111,44 @@ func (q *Queries) CreateAccountCredentials(ctx context.Context, arg CreateAccoun arg.AccountID, arg.AccountPublicID, arg.CredentialsType, - arg.Alias, + arg.Name, arg.Scopes, arg.TokenEndpointAuthMethod, - arg.Issuers, + arg.Domain, + arg.ClientUri, + arg.RedirectUris, + arg.LogoUri, + arg.PolicyUri, + arg.TosUri, + arg.SoftwareID, + arg.SoftwareVersion, + arg.Contacts, + arg.CreationMethod, + arg.Transport, ) var i AccountCredential err := row.Scan( &i.ID, &i.AccountID, &i.AccountPublicID, + &i.ClientID, + &i.Name, + &i.Domain, &i.CredentialsType, &i.Scopes, &i.TokenEndpointAuthMethod, - &i.Issuers, - &i.Alias, - &i.ClientID, + &i.GrantTypes, + &i.Version, + &i.Transport, + &i.CreationMethod, + &i.ClientUri, + &i.RedirectUris, + &i.LogoUri, + &i.PolicyUri, + &i.TosUri, + &i.SoftwareID, + &i.SoftwareVersion, + &i.Contacts, &i.CreatedAt, &i.UpdatedAt, ) @@ -122,7 +175,7 @@ func (q *Queries) DeleteAllAccountCredentials(ctx context.Context) error { } const findAccountCredentialsByAccountPublicIDAndClientID = `-- name: FindAccountCredentialsByAccountPublicIDAndClientID :one -SELECT id, account_id, account_public_id, credentials_type, scopes, token_endpoint_auth_method, issuers, alias, client_id, created_at, updated_at FROM "account_credentials" +SELECT id, account_id, account_public_id, client_id, name, domain, credentials_type, scopes, token_endpoint_auth_method, grant_types, version, transport, creation_method, client_uri, redirect_uris, logo_uri, policy_uri, tos_uri, software_id, software_version, contacts, created_at, updated_at FROM "account_credentials" WHERE "account_public_id" = $1 AND "client_id" = $2 LIMIT 1 ` @@ -139,12 +192,24 @@ func (q *Queries) FindAccountCredentialsByAccountPublicIDAndClientID(ctx context &i.ID, &i.AccountID, &i.AccountPublicID, + &i.ClientID, + &i.Name, + &i.Domain, &i.CredentialsType, &i.Scopes, &i.TokenEndpointAuthMethod, - &i.Issuers, - &i.Alias, - &i.ClientID, + &i.GrantTypes, + &i.Version, + &i.Transport, + &i.CreationMethod, + &i.ClientUri, + &i.RedirectUris, + &i.LogoUri, + &i.PolicyUri, + &i.TosUri, + &i.SoftwareID, + &i.SoftwareVersion, + &i.Contacts, &i.CreatedAt, &i.UpdatedAt, ) @@ -153,7 +218,7 @@ func (q *Queries) FindAccountCredentialsByAccountPublicIDAndClientID(ctx context const findAccountCredentialsByClientID = `-- name: FindAccountCredentialsByClientID :one -SELECT id, account_id, account_public_id, credentials_type, scopes, token_endpoint_auth_method, issuers, alias, client_id, created_at, updated_at FROM "account_credentials" +SELECT id, account_id, account_public_id, client_id, name, domain, credentials_type, scopes, token_endpoint_auth_method, grant_types, version, transport, creation_method, client_uri, redirect_uris, logo_uri, policy_uri, tos_uri, software_id, software_version, contacts, created_at, updated_at FROM "account_credentials" WHERE "client_id" = $1 LIMIT 1 ` @@ -170,12 +235,24 @@ func (q *Queries) FindAccountCredentialsByClientID(ctx context.Context, clientID &i.ID, &i.AccountID, &i.AccountPublicID, + &i.ClientID, + &i.Name, + &i.Domain, &i.CredentialsType, &i.Scopes, &i.TokenEndpointAuthMethod, - &i.Issuers, - &i.Alias, - &i.ClientID, + &i.GrantTypes, + &i.Version, + &i.Transport, + &i.CreationMethod, + &i.ClientUri, + &i.RedirectUris, + &i.LogoUri, + &i.PolicyUri, + &i.TosUri, + &i.SoftwareID, + &i.SoftwareVersion, + &i.Contacts, &i.CreatedAt, &i.UpdatedAt, ) @@ -183,7 +260,7 @@ func (q *Queries) FindAccountCredentialsByClientID(ctx context.Context, clientID } const findPaginatedAccountCredentialsByAccountPublicID = `-- name: FindPaginatedAccountCredentialsByAccountPublicID :many -SELECT id, account_id, account_public_id, credentials_type, scopes, token_endpoint_auth_method, issuers, alias, client_id, created_at, updated_at FROM "account_credentials" +SELECT id, account_id, account_public_id, client_id, name, domain, credentials_type, scopes, token_endpoint_auth_method, grant_types, version, transport, creation_method, client_uri, redirect_uris, logo_uri, policy_uri, tos_uri, software_id, software_version, contacts, created_at, updated_at FROM "account_credentials" WHERE "account_public_id" = $1 ORDER BY "id" DESC OFFSET $2 LIMIT $3 @@ -208,12 +285,24 @@ func (q *Queries) FindPaginatedAccountCredentialsByAccountPublicID(ctx context.C &i.ID, &i.AccountID, &i.AccountPublicID, + &i.ClientID, + &i.Name, + &i.Domain, &i.CredentialsType, &i.Scopes, &i.TokenEndpointAuthMethod, - &i.Issuers, - &i.Alias, - &i.ClientID, + &i.GrantTypes, + &i.Version, + &i.Transport, + &i.CreationMethod, + &i.ClientUri, + &i.RedirectUris, + &i.LogoUri, + &i.PolicyUri, + &i.TosUri, + &i.SoftwareID, + &i.SoftwareVersion, + &i.Contacts, &i.CreatedAt, &i.UpdatedAt, ); err != nil { @@ -230,38 +319,75 @@ func (q *Queries) FindPaginatedAccountCredentialsByAccountPublicID(ctx context.C const updateAccountCredentials = `-- name: UpdateAccountCredentials :one UPDATE "account_credentials" SET "scopes" = $2, - "alias" = $3, - "issuers" = $4, + "name" = $3, + "domain" = $4, + "client_uri" = $5, + "redirect_uris" = $6, + "logo_uri" = $7, + "policy_uri" = $8, + "tos_uri" = $9, + "software_version" = $10, + "contacts" = $11, + "transport" = $12, + "version" = "version" + 1, "updated_at" = now() WHERE "id" = $1 -RETURNING id, account_id, account_public_id, credentials_type, scopes, token_endpoint_auth_method, issuers, alias, client_id, created_at, updated_at +RETURNING id, account_id, account_public_id, client_id, name, domain, credentials_type, scopes, token_endpoint_auth_method, grant_types, version, transport, creation_method, client_uri, redirect_uris, logo_uri, policy_uri, tos_uri, software_id, software_version, contacts, created_at, updated_at ` type UpdateAccountCredentialsParams struct { - ID int32 - Scopes []AccountCredentialsScope - Alias string - Issuers []string + ID int32 + Scopes []AccountCredentialsScope + Name string + Domain string + ClientUri string + RedirectUris []string + LogoUri pgtype.Text + PolicyUri pgtype.Text + TosUri pgtype.Text + SoftwareVersion pgtype.Text + Contacts []string + Transport Transport } func (q *Queries) UpdateAccountCredentials(ctx context.Context, arg UpdateAccountCredentialsParams) (AccountCredential, error) { row := q.db.QueryRow(ctx, updateAccountCredentials, arg.ID, arg.Scopes, - arg.Alias, - arg.Issuers, + arg.Name, + arg.Domain, + arg.ClientUri, + arg.RedirectUris, + arg.LogoUri, + arg.PolicyUri, + arg.TosUri, + arg.SoftwareVersion, + arg.Contacts, + arg.Transport, ) var i AccountCredential err := row.Scan( &i.ID, &i.AccountID, &i.AccountPublicID, + &i.ClientID, + &i.Name, + &i.Domain, &i.CredentialsType, &i.Scopes, &i.TokenEndpointAuthMethod, - &i.Issuers, - &i.Alias, - &i.ClientID, + &i.GrantTypes, + &i.Version, + &i.Transport, + &i.CreationMethod, + &i.ClientUri, + &i.RedirectUris, + &i.LogoUri, + &i.PolicyUri, + &i.TosUri, + &i.SoftwareID, + &i.SoftwareVersion, + &i.Contacts, &i.CreatedAt, &i.UpdatedAt, ) diff --git a/idp/internal/providers/database/apps.sql.go b/idp/internal/providers/database/apps.sql.go index 5146727..cfe6f64 100644 --- a/idp/internal/providers/database/apps.sql.go +++ b/idp/internal/providers/database/apps.sql.go @@ -1129,7 +1129,7 @@ SET "name" = $2, "logo_uri" = $5, "tos_uri" = $6, "policy_uri" = $7, - "software_id" = $8, + "auth_providers" = $8, "software_version" = $9, "contacts" = $10, "domain" = $11, @@ -1137,7 +1137,6 @@ SET "name" = $2, "redirect_uris" = $13, "allow_user_registration" = $14, "response_types" = $15, - "auth_providers" = $16, "version" = "version" + 1, "updated_at" = now() WHERE "id" = $1 @@ -1152,7 +1151,7 @@ type UpdateAppParams struct { LogoUri pgtype.Text TosUri pgtype.Text PolicyUri pgtype.Text - SoftwareID string + AuthProviders []AuthProvider SoftwareVersion pgtype.Text Contacts []string Domain string @@ -1160,7 +1159,6 @@ type UpdateAppParams struct { RedirectUris []string AllowUserRegistration bool ResponseTypes []ResponseType - AuthProviders []AuthProvider } func (q *Queries) UpdateApp(ctx context.Context, arg UpdateAppParams) (App, error) { @@ -1172,7 +1170,7 @@ func (q *Queries) UpdateApp(ctx context.Context, arg UpdateAppParams) (App, erro arg.LogoUri, arg.TosUri, arg.PolicyUri, - arg.SoftwareID, + arg.AuthProviders, arg.SoftwareVersion, arg.Contacts, arg.Domain, @@ -1180,7 +1178,6 @@ func (q *Queries) UpdateApp(ctx context.Context, arg UpdateAppParams) (App, erro arg.RedirectUris, arg.AllowUserRegistration, arg.ResponseTypes, - arg.AuthProviders, ) var i App err := row.Scan( diff --git a/idp/internal/providers/database/migrations/20241213231542_create_initial_schema.up.sql b/idp/internal/providers/database/migrations/20241213231542_create_initial_schema.up.sql index f705b51..69a2d91 100644 --- a/idp/internal/providers/database/migrations/20241213231542_create_initial_schema.up.sql +++ b/idp/internal/providers/database/migrations/20241213231542_create_initial_schema.up.sql @@ -1,6 +1,6 @@ -- SQL dump generated using DBML (dbml.dbdiagram.io) -- Database: PostgreSQL --- Generated at: 2025-08-10T08:33:39.963Z +-- Generated at: 2025-08-12T22:36:20.011Z CREATE TYPE "kek_usage" AS ENUM ( 'global', @@ -83,7 +83,8 @@ CREATE TYPE "account_credentials_scope" AS ENUM ( ); CREATE TYPE "account_credentials_type" AS ENUM ( - 'client', + 'native', + 'service', 'mcp' ); @@ -226,8 +227,8 @@ CREATE TABLE "token_signing_keys" ( CREATE TABLE "accounts" ( "id" serial PRIMARY KEY, "public_id" uuid NOT NULL, - "given_name" varchar(50) NOT NULL, - "family_name" varchar(50) NOT NULL, + "given_name" varchar(100) NOT NULL, + "family_name" varchar(100) NOT NULL, "username" varchar(63) NOT NULL, "email" varchar(250) NOT NULL, "organization" varchar(50), @@ -304,33 +305,24 @@ CREATE TABLE "account_credentials" ( "id" serial PRIMARY KEY, "account_id" integer NOT NULL, "account_public_id" uuid NOT NULL, + "client_id" varchar(22) NOT NULL, + "name" varchar(255) NOT NULL, + "domain" varchar(250) NOT NULL, "credentials_type" account_credentials_type NOT NULL, "scopes" account_credentials_scope[] NOT NULL, "token_endpoint_auth_method" auth_method NOT NULL, - "issuers" varchar(255)[] NOT NULL, - "alias" varchar(100) NOT NULL, - "client_id" varchar(22) NOT NULL, - "created_at" timestamptz NOT NULL DEFAULT (now()), - "updated_at" timestamptz NOT NULL DEFAULT (now()) -); - -CREATE TABLE "account_credentials_mcps" ( - "id" serial PRIMARY KEY, - "account_id" integer NOT NULL, - "account_public_id" uuid NOT NULL, - "account_credentials_id" integer NOT NULL, - "account_credentials_client_id" varchar(22) NOT NULL, - "creation_method" creation_method NOT NULL, + "grant_types" grant_type[] NOT NULL, + "version" integer NOT NULL DEFAULT 1, "transport" transport NOT NULL, - "response_types" response_type[] NOT NULL, - "callback_uris" varchar(2048)[] NOT NULL, + "creation_method" creation_method NOT NULL, "client_uri" varchar(512) NOT NULL, + "redirect_uris" varchar(2048)[] NOT NULL, "logo_uri" varchar(512), "policy_uri" varchar(512), "tos_uri" varchar(512), "software_id" varchar(512) NOT NULL, "software_version" varchar(512), - "contacts" varchar(512)[] NOT NULL DEFAULT '{}', + "contacts" varchar(250)[] NOT NULL, "created_at" timestamptz NOT NULL DEFAULT (now()), "updated_at" timestamptz NOT NULL DEFAULT (now()) ); @@ -387,7 +379,7 @@ CREATE TABLE "users" ( "public_id" uuid NOT NULL, "account_id" integer NOT NULL, "email" varchar(250) NOT NULL, - "username" varchar(250) NOT NULL, + "username" varchar(63) NOT NULL, "password" text, "version" integer NOT NULL DEFAULT 1, "email_verified" boolean NOT NULL DEFAULT false, @@ -460,7 +452,7 @@ CREATE TABLE "apps" ( "account_id" integer NOT NULL, "account_public_id" uuid NOT NULL, "app_type" app_type NOT NULL, - "name" varchar(100) NOT NULL, + "name" varchar(255) NOT NULL, "client_id" varchar(22) NOT NULL, "version" integer NOT NULL DEFAULT 1, "creation_method" creation_method NOT NULL, @@ -539,7 +531,19 @@ CREATE TABLE "app_designs" ( "updated_at" timestamptz NOT NULL DEFAULT (now()) ); -CREATE TABLE "dynamic_registration_configs" ( +CREATE TABLE "account_dynamic_registration_configs" ( + "id" serial PRIMARY KEY, + "account_id" integer NOT NULL, + "whitelisted_domains" varchar(250)[] NOT NULL, + "require_software_statement" boolean NOT NULL, + "software_statement_verification_methods" software_statement_verification_method[] NOT NULL, + "require_initial_access_token" boolean NOT NULL, + "initial_access_token_generation_methods" initial_access_token_generation_method[] NOT NULL, + "created_at" timestamptz NOT NULL DEFAULT (now()), + "updated_at" timestamptz NOT NULL DEFAULT (now()) +); + +CREATE TABLE "app_dynamic_registration_configs" ( "id" serial PRIMARY KEY, "account_id" integer NOT NULL, "allowed_app_types" app_type[] NOT NULL, @@ -671,17 +675,7 @@ CREATE INDEX "account_credentials_account_public_id_idx" ON "account_credentials CREATE INDEX "account_credentials_account_public_id_client_id_idx" ON "account_credentials" ("account_public_id", "client_id"); -CREATE UNIQUE INDEX "account_credentials_alias_account_id_uidx" ON "account_credentials" ("alias", "account_id"); - -CREATE INDEX "account_credentials_mcp_account_id_idx" ON "account_credentials_mcps" ("account_id"); - -CREATE INDEX "account_credentials_mcp_account_public_id_idx" ON "account_credentials_mcps" ("account_public_id"); - -CREATE UNIQUE INDEX "account_credentials_mcp_account_credentials_id_uidx" ON "account_credentials_mcps" ("account_credentials_id"); - -CREATE INDEX "account_credentials_mcp_account_credentials_client_id_idx" ON "account_credentials_mcps" ("account_credentials_client_id"); - -CREATE UNIQUE INDEX "account_credentials_mcp_account_credentials_id_software_id_uidx" ON "account_credentials_mcps" ("account_credentials_id", "software_id"); +CREATE UNIQUE INDEX "account_credentials_name_account_id_uidx" ON "account_credentials" ("name", "account_id"); CREATE INDEX "account_credential_secrets_account_id_idx" ON "account_credentials_secrets" ("account_id"); @@ -833,7 +827,9 @@ CREATE INDEX "app_designs_account_id_idx" ON "app_designs" ("account_id"); CREATE UNIQUE INDEX "app_designs_app_id_uidx" ON "app_designs" ("app_id"); -CREATE INDEX "dynamic_registrations_configs_account_id_idx" ON "dynamic_registration_configs" ("account_id"); +CREATE INDEX "account_dynamic_registration_configs_account_id_idx" ON "account_dynamic_registration_configs" ("account_id"); + +CREATE INDEX "app_dynamic_registration_configs_account_id_idx" ON "app_dynamic_registration_configs" ("account_id"); CREATE INDEX "user_profiles_app_id_idx" ON "app_profiles" ("app_id"); @@ -877,10 +873,6 @@ ALTER TABLE "account_totps" ADD FOREIGN KEY ("totp_id") REFERENCES "totps" ("id" ALTER TABLE "account_credentials" ADD FOREIGN KEY ("account_id") REFERENCES "accounts" ("id") ON DELETE CASCADE; -ALTER TABLE "account_credentials_mcps" ADD FOREIGN KEY ("account_id") REFERENCES "accounts" ("id") ON DELETE CASCADE; - -ALTER TABLE "account_credentials_mcps" ADD FOREIGN KEY ("account_credentials_id") REFERENCES "account_credentials" ("id") ON DELETE CASCADE; - ALTER TABLE "account_credentials_secrets" ADD FOREIGN KEY ("account_id") REFERENCES "accounts" ("id") ON DELETE CASCADE; ALTER TABLE "account_credentials_secrets" ADD FOREIGN KEY ("credentials_secret_id") REFERENCES "credentials_secrets" ("id") ON DELETE CASCADE; @@ -967,7 +959,9 @@ ALTER TABLE "app_designs" ADD FOREIGN KEY ("account_id") REFERENCES "accounts" ( ALTER TABLE "app_designs" ADD FOREIGN KEY ("app_id") REFERENCES "apps" ("id") ON DELETE CASCADE; -ALTER TABLE "dynamic_registration_configs" ADD FOREIGN KEY ("account_id") REFERENCES "accounts" ("id") ON DELETE CASCADE; +ALTER TABLE "account_dynamic_registration_configs" ADD FOREIGN KEY ("account_id") REFERENCES "accounts" ("id") ON DELETE CASCADE; + +ALTER TABLE "app_dynamic_registration_configs" ADD FOREIGN KEY ("account_id") REFERENCES "accounts" ("id") ON DELETE CASCADE; ALTER TABLE "app_profiles" ADD FOREIGN KEY ("app_id") REFERENCES "apps" ("id") ON DELETE CASCADE; diff --git a/idp/internal/providers/database/models.go b/idp/internal/providers/database/models.go index 8e2103a..bfbb4ff 100644 --- a/idp/internal/providers/database/models.go +++ b/idp/internal/providers/database/models.go @@ -66,8 +66,9 @@ func (ns NullAccountCredentialsScope) Value() (driver.Value, error) { type AccountCredentialsType string const ( - AccountCredentialsTypeClient AccountCredentialsType = "client" - AccountCredentialsTypeMcp AccountCredentialsType = "mcp" + AccountCredentialsTypeNative AccountCredentialsType = "native" + AccountCredentialsTypeService AccountCredentialsType = "service" + AccountCredentialsTypeMcp AccountCredentialsType = "mcp" ) func (e *AccountCredentialsType) Scan(src interface{}) error { @@ -1150,12 +1151,24 @@ type AccountCredential struct { ID int32 AccountID int32 AccountPublicID uuid.UUID + ClientID string + Name string + Domain string CredentialsType AccountCredentialsType Scopes []AccountCredentialsScope TokenEndpointAuthMethod AuthMethod - Issuers []string - Alias string - ClientID string + GrantTypes []GrantType + Version int32 + Transport Transport + CreationMethod CreationMethod + ClientUri string + RedirectUris []string + LogoUri pgtype.Text + PolicyUri pgtype.Text + TosUri pgtype.Text + SoftwareID string + SoftwareVersion pgtype.Text + Contacts []string CreatedAt time.Time UpdatedAt time.Time } @@ -1169,27 +1182,6 @@ type AccountCredentialsKey struct { CreatedAt time.Time } -type AccountCredentialsMcp struct { - ID int32 - AccountID int32 - AccountPublicID uuid.UUID - AccountCredentialsID int32 - AccountCredentialsClientID string - CreationMethod CreationMethod - Transport Transport - ResponseTypes []ResponseType - CallbackUris []string - ClientUri string - LogoUri pgtype.Text - PolicyUri pgtype.Text - TosUri pgtype.Text - SoftwareID string - SoftwareVersion pgtype.Text - Contacts []string - CreatedAt time.Time - UpdatedAt time.Time -} - type AccountCredentialsSecret struct { AccountID int32 CredentialsSecretID int32 @@ -1205,6 +1197,18 @@ type AccountDataEncryptionKey struct { CreatedAt time.Time } +type AccountDynamicRegistrationConfig struct { + ID int32 + AccountID int32 + WhitelistedDomains []string + RequireSoftwareStatement bool + SoftwareStatementVerificationMethods []SoftwareStatementVerificationMethod + RequireInitialAccessToken bool + InitialAccessTokenGenerationMethods []InitialAccessTokenGenerationMethod + CreatedAt time.Time + UpdatedAt time.Time +} + type AccountKeyEncryptionKey struct { AccountID int32 KeyEncryptionKeyID int32 @@ -1271,6 +1275,30 @@ type AppDesign struct { UpdatedAt time.Time } +type AppDynamicRegistrationConfig struct { + ID int32 + AccountID int32 + AllowedAppTypes []AppType + WhitelistedDomains []string + DefaultAllowUserRegistration bool + DefaultAuthProviders []AuthProvider + DefaultUsernameColumn AppUsernameColumn + DefaultAllowedScopes []Scopes + DefaultScopes []Scopes + RequireSoftwareStatementAppTypes []AppType + SoftwareStatementVerificationMethods []SoftwareStatementVerificationMethod + RequireInitialAccessTokenAppTypes []AppType + InitialAccessTokenGenerationMethods []InitialAccessTokenGenerationMethod + InitialAccessTokenTtl int32 + InitialAccessTokenMaxUses int32 + AllowedGrantTypes []GrantType + AllowedResponseTypes []ResponseType + AllowedTokenEndpointAuthMethods []AuthMethod + MaxRedirectUris int32 + CreatedAt time.Time + UpdatedAt time.Time +} + type AppKey struct { AppID int32 CredentialsKeyID int32 @@ -1351,30 +1379,6 @@ type DataEncryptionKey struct { UpdatedAt time.Time } -type DynamicRegistrationConfig struct { - ID int32 - AccountID int32 - AllowedAppTypes []AppType - WhitelistedDomains []string - DefaultAllowUserRegistration bool - DefaultAuthProviders []AuthProvider - DefaultUsernameColumn AppUsernameColumn - DefaultAllowedScopes []Scopes - DefaultScopes []Scopes - RequireSoftwareStatementAppTypes []AppType - SoftwareStatementVerificationMethods []SoftwareStatementVerificationMethod - RequireInitialAccessTokenAppTypes []AppType - InitialAccessTokenGenerationMethods []InitialAccessTokenGenerationMethod - InitialAccessTokenTtl int32 - InitialAccessTokenMaxUses int32 - AllowedGrantTypes []GrantType - AllowedResponseTypes []ResponseType - AllowedTokenEndpointAuthMethods []AuthMethod - MaxRedirectUris int32 - CreatedAt time.Time - UpdatedAt time.Time -} - type KeyEncryptionKey struct { ID int32 Kid uuid.UUID diff --git a/idp/internal/providers/database/queries/account_credentials.sql b/idp/internal/providers/database/queries/account_credentials.sql index 93411b8..3e1bfa7 100644 --- a/idp/internal/providers/database/queries/account_credentials.sql +++ b/idp/internal/providers/database/queries/account_credentials.sql @@ -20,10 +20,20 @@ INSERT INTO "account_credentials" ( "account_id", "account_public_id", "credentials_type", - "alias", + "name", "scopes", "token_endpoint_auth_method", - "issuers" + "domain", + "client_uri", + "redirect_uris", + "logo_uri", + "policy_uri", + "tos_uri", + "software_id", + "software_version", + "contacts", + "creation_method", + "transport" ) VALUES ( $1, $2, @@ -32,21 +42,40 @@ INSERT INTO "account_credentials" ( $5, $6, $7, - $8 + $8, + $9, + $10, + $11, + $12, + $13, + $14, + $15, + $16, + $17, + $18 ) RETURNING *; -- name: UpdateAccountCredentials :one UPDATE "account_credentials" SET "scopes" = $2, - "alias" = $3, - "issuers" = $4, + "name" = $3, + "domain" = $4, + "client_uri" = $5, + "redirect_uris" = $6, + "logo_uri" = $7, + "policy_uri" = $8, + "tos_uri" = $9, + "software_version" = $10, + "contacts" = $11, + "transport" = $12, + "version" = "version" + 1, "updated_at" = now() WHERE "id" = $1 RETURNING *; --- name: CountAccountCredentialsByAliasAndAccountID :one +-- name: CountAccountCredentialsByNameAndAccountID :one SELECT COUNT(*) FROM "account_credentials" -WHERE "account_id" = $1 AND "alias" = $2; +WHERE "account_id" = $1 AND "name" = $2; -- name: DeleteAccountCredentials :exec DELETE FROM "account_credentials" diff --git a/idp/internal/providers/database/queries/apps.sql b/idp/internal/providers/database/queries/apps.sql index 6cc314d..f748ba4 100644 --- a/idp/internal/providers/database/queries/apps.sql +++ b/idp/internal/providers/database/queries/apps.sql @@ -92,7 +92,7 @@ SET "name" = $2, "logo_uri" = $5, "tos_uri" = $6, "policy_uri" = $7, - "software_id" = $8, + "auth_providers" = $8, "software_version" = $9, "contacts" = $10, "domain" = $11, @@ -100,7 +100,6 @@ SET "name" = $2, "redirect_uris" = $13, "allow_user_registration" = $14, "response_types" = $15, - "auth_providers" = $16, "version" = "version" + 1, "updated_at" = now() WHERE "id" = $1 diff --git a/idp/internal/services/account_credentials.go b/idp/internal/services/account_credentials.go index afcc61c..baa3c2c 100644 --- a/idp/internal/services/account_credentials.go +++ b/idp/internal/services/account_credentials.go @@ -9,6 +9,7 @@ package services import ( "context" "fmt" + "strings" "time" "github.com/google/uuid" @@ -16,7 +17,6 @@ import ( "github.com/tugascript/devlogs/idp/internal/exceptions" "github.com/tugascript/devlogs/idp/internal/providers/cache" "github.com/tugascript/devlogs/idp/internal/providers/database" - "github.com/tugascript/devlogs/idp/internal/providers/tokens" "github.com/tugascript/devlogs/idp/internal/services/dtos" "github.com/tugascript/devlogs/idp/internal/utils" ) @@ -28,10 +28,89 @@ const ( accountCredentialsKeysCacheKeyPrefix string = "account_credentials_keys" ) +func mapAccountCredentialsTransport( + transport string, + credentialType database.AccountCredentialsType, +) (database.Transport, *exceptions.ServiceError) { + if credentialType == database.AccountCredentialsTypeMcp { + switch transport { + case transportSTDIO: + return database.TransportStdio, nil + case transportStreamableHTTP: + return database.TransportStreamableHttp, nil + default: + return "", exceptions.NewValidationError("invalid transport: " + transport) + } + } + if credentialType == database.AccountCredentialsTypeService || credentialType == database.AccountCredentialsTypeNative { + switch transport { + case transportHTTP: + return database.TransportHttp, nil + case transportHTTPS: + return database.TransportHttps, nil + default: + return "", exceptions.NewValidationError("invalid transport: " + transport) + } + } + + return "", exceptions.NewValidationError("invalid credentials type: " + string(credentialType)) +} + +func mapAccountCredentialsType(credentialsType string) (database.AccountCredentialsType, *exceptions.ServiceError) { + acType := database.AccountCredentialsType(credentialsType) + switch acType { + case database.AccountCredentialsTypeService: + return acType, nil + case database.AccountCredentialsTypeMcp: + return acType, nil + case database.AccountCredentialsTypeNative: + return "", exceptions.NewValidationError("Native credentials are not supported") + default: + return "", exceptions.NewValidationError("invalid credentials type: " + credentialsType) + } +} + +func mapAccountCredentialsTokenEndpointAuthMethod( + authMethod string, + credentialType database.AccountCredentialsType, + transport database.Transport, +) (database.AuthMethod, *exceptions.ServiceError) { + switch credentialType { + case database.AccountCredentialsTypeNative: + if authMethod != "" && authMethod != AuthMethodNone { + return "", exceptions.NewValidationError("auth method is not supported for native credentials") + } + + return database.AuthMethodNone, nil + case database.AccountCredentialsTypeService: + if authMethod == "" || authMethod == AuthMethodNone { + return "", exceptions.NewValidationError("auth method is required for service credentials") + } + + return mapAuthMethod(authMethod) + case database.AccountCredentialsTypeMcp: + if transport == database.TransportStdio { + if authMethod != "" && authMethod != AuthMethodNone { + return "", exceptions.NewValidationError("auth method is not supported for stdio mcp credentials") + } + + return database.AuthMethodNone, nil + } + if transport == database.TransportStreamableHttp { + return mapAuthMethod(authMethod) + } + + return "", exceptions.NewValidationError("invalid transport: " + string(transport)) + default: + return "", exceptions.NewValidationError("invalid credentials type: " + string(credentialType)) + } +} + func mapAccountCredentialsScope(scope string) (database.AccountCredentialsScope, *exceptions.ServiceError) { acScope := database.AccountCredentialsScope(scope) switch acScope { - case database.AccountCredentialsScopeAccountAdmin, database.AccountCredentialsScopeAccountAuthProvidersRead, + case database.AccountCredentialsScopeEmail, database.AccountCredentialsScopeProfile, + database.AccountCredentialsScopeAccountAdmin, database.AccountCredentialsScopeAccountAuthProvidersRead, database.AccountCredentialsScopeAccountUsersRead, database.AccountCredentialsScopeAccountUsersWrite, database.AccountCredentialsScopeAccountAppsRead, database.AccountCredentialsScopeAccountAppsWrite, database.AccountCredentialsScopeAccountCredentialsRead, database.AccountCredentialsScopeAccountCredentialsWrite: @@ -41,14 +120,12 @@ func mapAccountCredentialsScope(scope string) (database.AccountCredentialsScope, return "", exceptions.NewValidationError("invalid scope: " + scope) } -// NOTE: using a map will lead to a null pointer dereference even if the slice is not empty func mapAccountCredentialsScopes(scopes []string) ([]database.AccountCredentialsScope, *exceptions.ServiceError) { scopesSet := utils.SliceToHashSet(scopes) if scopesSet.IsEmpty() { return nil, exceptions.NewValidationError("scopes cannot be empty") } - // return utils.MapSliceWithErr(scopesSet.Items(), mapAccountCredentialsScope) mappedScopes := make([]database.AccountCredentialsScope, 0, scopesSet.Size()) for _, scope := range scopesSet.Items() { mappedScope, serviceErr := mapAccountCredentialsScope(scope) @@ -64,10 +141,21 @@ type CreateAccountCredentialsOptions struct { RequestID string AccountPublicID uuid.UUID AccountVersion int32 - Alias string + CredentialsType string + Name string + Domain string + ClientURI string + RedirectURIs []string + LogoURI string + TOSURI string + PolicyURI string + SoftwareID string + SoftwareVersion string + Contacts []string + CreationMethod database.CreationMethod + Transport string Scopes []string AuthMethod string - Issuers []string Algorithm string } @@ -78,12 +166,28 @@ func (s *Services) CreateAccountCredentials( logger := s.buildLogger(opts.RequestID, accountCredentialsLocation, "CreateAccountCredentials").With( "accountPublicID", opts.AccountPublicID, "scopes", opts.Scopes, - "alias", opts.Alias, + "name", opts.Name, "authMethod", opts.AuthMethod, ) logger.InfoContext(ctx, "Creating account keys...") - authMethod, serviceErr := mapAuthMethod(opts.AuthMethod) + credentialsType, serviceErr := mapAccountCredentialsType(opts.CredentialsType) + if serviceErr != nil { + logger.WarnContext(ctx, "Failed to map credentials type", "serviceError", serviceErr) + return dtos.AccountCredentialsDTO{}, serviceErr + } + + transport, serviceErr := mapAccountCredentialsTransport(opts.Transport, credentialsType) + if serviceErr != nil { + logger.WarnContext(ctx, "Failed to map transport", "serviceError", serviceErr) + return dtos.AccountCredentialsDTO{}, serviceErr + } + + authMethod, serviceErr := mapAccountCredentialsTokenEndpointAuthMethod( + opts.AuthMethod, + credentialsType, + transport, + ) if serviceErr != nil { logger.WarnContext(ctx, "Failed to map auth method", "serviceError", serviceErr) return dtos.AccountCredentialsDTO{}, serviceErr @@ -95,6 +199,12 @@ func (s *Services) CreateAccountCredentials( return dtos.AccountCredentialsDTO{}, serviceErr } + domain, serviceErr := mapDomain(opts.ClientURI, opts.Domain) + if serviceErr != nil { + logger.WarnContext(ctx, "Failed to map domain", "serviceError", serviceErr) + return dtos.AccountCredentialsDTO{}, serviceErr + } + accountID, serviceErr := s.GetAccountIDByPublicIDAndVersion(ctx, GetAccountIDByPublicIDAndVersionOptions{ RequestID: opts.RequestID, PublicID: opts.AccountPublicID, @@ -105,21 +215,55 @@ func (s *Services) CreateAccountCredentials( return dtos.AccountCredentialsDTO{}, serviceErr } - alias := utils.Lowered(opts.Alias) - count, err := s.database.CountAccountCredentialsByAliasAndAccountID( + name := strings.TrimSpace(opts.Name) + count, err := s.database.CountAccountCredentialsByNameAndAccountID( ctx, - database.CountAccountCredentialsByAliasAndAccountIDParams{ + database.CountAccountCredentialsByNameAndAccountIDParams{ AccountID: accountID, - Alias: alias, + Name: name, }, ) if err != nil { - logger.ErrorContext(ctx, "Failed to count account credentials by alias", "error", err) + logger.ErrorContext(ctx, "Failed to count account credentials by name", "error", err) return dtos.AccountCredentialsDTO{}, exceptions.NewInternalServerError() } if count > 0 { - logger.WarnContext(ctx, "Account credentials alias already exists", "alias", alias) - return dtos.AccountCredentialsDTO{}, exceptions.NewConflictError("Account credentials alias already exists") + logger.WarnContext(ctx, "Account credentials name already exists", "name", name) + return dtos.AccountCredentialsDTO{}, exceptions.NewConflictError("Account credentials name already exists") + } + + if authMethod == database.AuthMethodNone { + accountCredentials, err := s.database.CreateAccountCredentials( + ctx, + database.CreateAccountCredentialsParams{ + ClientID: utils.Base62UUID(), + AccountID: accountID, + AccountPublicID: opts.AccountPublicID, + CredentialsType: credentialsType, + Name: name, + Scopes: scopes, + TokenEndpointAuthMethod: authMethod, + Domain: domain, + ClientUri: utils.ProcessURL(opts.ClientURI), + RedirectUris: utils.MapSlice(opts.RedirectURIs, func(uri *string) string { + return utils.ProcessURL(*uri) + }), + LogoUri: mapEmptyURL(opts.LogoURI), + PolicyUri: mapEmptyURL(opts.PolicyURI), + TosUri: mapEmptyURL(opts.TOSURI), + SoftwareID: opts.SoftwareID, + SoftwareVersion: mapEmptyString(opts.SoftwareVersion), + Contacts: opts.Contacts, + CreationMethod: opts.CreationMethod, + Transport: transport, + }, + ) + if err != nil { + logger.ErrorContext(ctx, "Failed to create account credentials", "error", err) + return dtos.AccountCredentialsDTO{}, exceptions.FromDBError(err) + } + + return dtos.MapAccountCredentialsToDTO(&accountCredentials), nil } qrs, txn, err := s.database.BeginTx(ctx) @@ -132,18 +276,31 @@ func (s *Services) CreateAccountCredentials( s.database.FinalizeTx(ctx, txn, err, serviceErr) }() - accountCredentials, err := qrs.CreateAccountCredentials(ctx, database.CreateAccountCredentialsParams{ - ClientID: utils.Base62UUID(), - AccountID: accountID, - AccountPublicID: opts.AccountPublicID, - CredentialsType: database.AccountCredentialsTypeClient, - Scopes: scopes, - TokenEndpointAuthMethod: authMethod, - Alias: alias, - Issuers: utils.MapSlice(opts.Issuers, func(url *string) string { - return utils.ProcessURL(*url) - }), - }) + accountCredentials, err := qrs.CreateAccountCredentials( + ctx, + database.CreateAccountCredentialsParams{ + ClientID: utils.Base62UUID(), + AccountID: accountID, + AccountPublicID: opts.AccountPublicID, + CredentialsType: credentialsType, + Name: name, + Scopes: scopes, + TokenEndpointAuthMethod: authMethod, + Domain: domain, + ClientUri: utils.ProcessURL(opts.ClientURI), + RedirectUris: utils.MapSlice(opts.RedirectURIs, func(uri *string) string { + return utils.ProcessURL(*uri) + }), + LogoUri: mapEmptyURL(opts.LogoURI), + PolicyUri: mapEmptyURL(opts.PolicyURI), + TosUri: mapEmptyURL(opts.TOSURI), + SoftwareID: opts.SoftwareID, + SoftwareVersion: mapEmptyString(opts.SoftwareVersion), + Contacts: opts.Contacts, + CreationMethod: opts.CreationMethod, + Transport: transport, + }, + ) if err != nil { logger.ErrorContext(ctx, "Failed to create account credentials", "error", err) serviceErr = exceptions.FromDBError(err) @@ -367,14 +524,38 @@ func (s *Services) ListAccountCredentialsByAccountPublicID( return utils.MapSlice(accountCredentials, dtos.MapAccountCredentialsToDTO), count, nil } +func mapAccountCredentialsUpdateTransport( + transport string, + currentTransport database.Transport, + credentialsType database.AccountCredentialsType, +) (database.Transport, *exceptions.ServiceError) { + if credentialsType == database.AccountCredentialsTypeMcp { + if transport != "" { + return "", exceptions.NewValidationError("Transport update is not allowed for MCP credentials") + } + + return currentTransport, nil + } + + return mapAccountCredentialsTransport(transport, credentialsType) +} + type UpdateAccountCredentialsScopesOptions struct { RequestID string AccountPublicID uuid.UUID AccountVersion int32 ClientID string - Alias string - Scopes []tokens.AccountScope - Issuers []string + Name string + Domain string + Scopes []string + ClientURI string + RedirectURIs []string + LogoURI string + TOSURI string + PolicyURI string + SoftwareVersion string + Contacts []string + Transport string } func (s *Services) UpdateAccountCredentials( @@ -406,13 +587,13 @@ func (s *Services) UpdateAccountCredentials( return dtos.AccountCredentialsDTO{}, serviceErr } - alias := utils.Lowered(opts.Alias) - if alias != accountCredentialsDTO.Alias { - count, err := s.database.CountAccountCredentialsByAliasAndAccountID( + name := strings.TrimSpace(opts.Name) + if name != accountCredentialsDTO.Name { + count, err := s.database.CountAccountCredentialsByNameAndAccountID( ctx, - database.CountAccountCredentialsByAliasAndAccountIDParams{ + database.CountAccountCredentialsByNameAndAccountIDParams{ AccountID: accountCredentialsDTO.AccountID(), - Alias: alias, + Name: name, }, ) if err != nil { @@ -420,18 +601,39 @@ func (s *Services) UpdateAccountCredentials( return dtos.AccountCredentialsDTO{}, exceptions.NewInternalServerError() } if count > 0 { - logger.WarnContext(ctx, "Account credentials alias already exists", "alias", alias) + logger.WarnContext(ctx, "Account credentials alias already exists", "name", name) return dtos.AccountCredentialsDTO{}, exceptions.NewConflictError("Account credentials alias already exists") } } + transport, serviceErr := mapAccountCredentialsUpdateTransport( + opts.Transport, + accountCredentialsDTO.Transport, + accountCredentialsDTO.Type, + ) + if serviceErr != nil { + return dtos.AccountCredentialsDTO{}, serviceErr + } + + domain, serviceErr := mapDomain(opts.ClientURI, opts.Domain) + if serviceErr != nil { + logger.WarnContext(ctx, "Failed to map domain", "serviceError", serviceErr) + return dtos.AccountCredentialsDTO{}, serviceErr + } + accountCredentials, err := s.database.UpdateAccountCredentials(ctx, database.UpdateAccountCredentialsParams{ - ID: accountCredentialsDTO.ID(), - Scopes: scopes, - Alias: alias, - Issuers: utils.MapSlice(opts.Issuers, func(url *string) string { - return utils.ProcessURL(*url) - }), + ID: accountCredentialsDTO.ID(), + Scopes: scopes, + Name: name, + Domain: domain, + ClientUri: opts.ClientURI, + RedirectUris: opts.RedirectURIs, + LogoUri: mapEmptyURL(opts.LogoURI), + TosUri: mapEmptyURL(opts.TOSURI), + PolicyUri: mapEmptyURL(opts.PolicyURI), + SoftwareVersion: mapEmptyString(opts.SoftwareVersion), + Contacts: opts.Contacts, + Transport: transport, }) if err != nil { logger.ErrorContext(ctx, "Failed to update account keys scopes", "error", err) diff --git a/idp/internal/services/apps.go b/idp/internal/services/apps.go index bcfa925..5841c37 100644 --- a/idp/internal/services/apps.go +++ b/idp/internal/services/apps.go @@ -8,7 +8,6 @@ package services import ( "context" - "net/url" "strings" "time" @@ -30,6 +29,7 @@ const ( transportSTDIO string = "stdio" transportStreamableHTTP string = "streamable_http" transportHTTP string = "http" + transportHTTPS string = "https" ) var authCodeAppGrantTypes = []database.GrantType{database.GrantTypeAuthorizationCode, database.GrantTypeRefreshToken} @@ -621,27 +621,6 @@ func (s *Services) checkForDuplicateApps( return nil } -func mapDomain(clientURI string, domain string) (string, *exceptions.ServiceError) { - trimmed := strings.TrimSpace(domain) - if trimmed != "" { - return trimmed, nil - } - - parsed, err := url.Parse(strings.TrimSpace(clientURI)) - if err != nil || parsed == nil { - return "", exceptions.NewValidationError("Invalid client URI") - } - if parsed.Scheme != "http" && parsed.Scheme != "https" { - return "", exceptions.NewValidationError("Invalid client URI") - } - - host := parsed.Hostname() - if strings.TrimSpace(host) == "" { - return "", exceptions.NewValidationError("Invalid client URI") - } - return host, nil -} - type createAppOptions struct { requestID string accountID int32 @@ -2032,7 +2011,6 @@ type updateAppOptions struct { logoURI string tosURI string policyURI string - softwareID string softwareVersion string contacts []string redirectURIs []string @@ -2086,7 +2064,6 @@ func (s *Services) updateApp( LogoUri: mapEmptyURL(opts.logoURI), TosUri: mapEmptyURL(opts.tosURI), PolicyUri: mapEmptyURL(opts.policyURI), - SoftwareID: opts.softwareID, SoftwareVersion: softwareVersion, Domain: derivedDomain, Transport: opts.transport, @@ -2154,7 +2131,6 @@ func (s *Services) updateSingleApp( LogoUri: mapEmptyURL(opts.logoURI), TosUri: mapEmptyURL(opts.tosURI), PolicyUri: mapEmptyURL(opts.policyURI), - SoftwareID: opts.softwareID, SoftwareVersion: softwareVersion, Domain: derivedDomain, Transport: opts.transport, @@ -2280,7 +2256,6 @@ func (s *Services) UpdateWebSPANativeApp( logoURI: opts.LogoURI, tosURI: opts.TOSURI, policyURI: opts.PolicyURI, - softwareID: opts.SoftwareID, softwareVersion: opts.SoftwareVersion, contacts: opts.Contacts, redirectURIs: opts.RedirectURIs, @@ -2359,7 +2334,6 @@ func (s *Services) UpdateBackendApp( logoURI: opts.LogoURI, tosURI: opts.TOSURI, policyURI: opts.PolicyURI, - softwareID: opts.SoftwareID, softwareVersion: opts.SoftwareVersion, contacts: opts.Contacts, redirectURIs: make([]string, 0), @@ -2492,7 +2466,6 @@ func (s *Services) UpdateDeviceApp( logoURI: opts.LogoURI, tosURI: opts.TOSURI, policyURI: opts.PolicyURI, - softwareID: opts.SoftwareID, softwareVersion: opts.SoftwareVersion, contacts: opts.Contacts, redirectURIs: make([]string, 0), @@ -2610,7 +2583,6 @@ func (s *Services) UpdateServiceApp( logoURI: opts.LogoURI, tosURI: opts.TOSURI, policyURI: opts.PolicyURI, - softwareID: opts.SoftwareID, softwareVersion: opts.SoftwareVersion, contacts: opts.Contacts, redirectURIs: make([]string, 0), @@ -2710,7 +2682,6 @@ func (s *Services) UpdateMCPApp( logoURI: opts.LogoURI, tosURI: opts.TOSURI, policyURI: opts.PolicyURI, - softwareID: opts.SoftwareID, softwareVersion: opts.SoftwareVersion, contacts: opts.Contacts, redirectURIs: utils.ToEmptySlice(opts.RedirectURIs), diff --git a/idp/internal/services/dtos/account_credentials.go b/idp/internal/services/dtos/account_credentials.go index e2a1ce1..20d5ede 100644 --- a/idp/internal/services/dtos/account_credentials.go +++ b/idp/internal/services/dtos/account_credentials.go @@ -17,10 +17,21 @@ import ( type AccountCredentialsDTO struct { ClientID string `json:"client_id"` - Alias string `json:"alias"` + Type database.AccountCredentialsType `json:"type"` + Name string `json:"name"` + Domain string `json:"domain"` Scopes []database.AccountCredentialsScope `json:"scopes"` TokenEndpointAuthMethod database.AuthMethod `json:"token_endpoint_auth_method"` - Issuers []string `json:"issuers,omitempty"` + Transport database.Transport `json:"transport"` + CreationMethod database.CreationMethod `json:"creation_method"` + ClientURI string `json:"client_uri"` + RedirectURIs []string `json:"redirect_uris"` + LogoURI string `json:"logo_uri,omitempty"` + TOSURI string `json:"tos_uri,omitempty"` + PolicyURI string `json:"policy_uri,omitempty"` + SoftwareID string `json:"software_id"` + SoftwareVersion string `json:"software_version,omitempty"` + Contacts []string `json:"contacts,omitempty"` ClientSecretID string `json:"client_secret_id,omitempty"` ClientSecret string `json:"client_secret,omitempty"` ClientSecretJWK utils.JWK `json:"client_secret_jwk,omitempty"` @@ -65,12 +76,32 @@ func (ak *AccountCredentialsDTO) UnmarshalJSON(data []byte) error { func MapAccountCredentialsToDTO( accountCredential *database.AccountCredential, ) AccountCredentialsDTO { + var redirectURIs []string + if len(accountCredential.RedirectUris) > 0 { + redirectURIs = accountCredential.RedirectUris + } + + var contacts []string + if len(accountCredential.Contacts) > 0 { + contacts = accountCredential.Contacts + } + return AccountCredentialsDTO{ id: accountCredential.ID, ClientID: accountCredential.ClientID, - Alias: accountCredential.Alias, - Scopes: accountCredential.Scopes, - Issuers: accountCredential.Issuers, + Type: accountCredential.CredentialsType, + Name: accountCredential.Name, + Domain: accountCredential.Domain, + ClientURI: accountCredential.ClientUri, + RedirectURIs: redirectURIs, + LogoURI: accountCredential.LogoUri.String, + TOSURI: accountCredential.TosUri.String, + PolicyURI: accountCredential.PolicyUri.String, + SoftwareID: accountCredential.SoftwareID, + SoftwareVersion: accountCredential.SoftwareVersion.String, + Contacts: contacts, + CreationMethod: accountCredential.CreationMethod, + Transport: accountCredential.Transport, TokenEndpointAuthMethod: accountCredential.TokenEndpointAuthMethod, accountId: accountCredential.AccountID, } @@ -81,17 +112,33 @@ func MapAccountCredentialsToDTOWithJWK( jwk utils.JWK, exp time.Time, ) AccountCredentialsDTO { + var contacts []string + if len(accountKeys.Contacts) > 0 { + contacts = accountKeys.Contacts + } + return AccountCredentialsDTO{ id: accountKeys.ID, - Alias: accountKeys.Alias, + Type: accountKeys.CredentialsType, + Name: accountKeys.Name, + Domain: accountKeys.Domain, + ClientURI: accountKeys.ClientUri, + RedirectURIs: accountKeys.RedirectUris, + LogoURI: accountKeys.LogoUri.String, + TOSURI: accountKeys.TosUri.String, + PolicyURI: accountKeys.PolicyUri.String, + SoftwareID: accountKeys.SoftwareID, + SoftwareVersion: accountKeys.SoftwareVersion.String, + Contacts: contacts, + CreationMethod: accountKeys.CreationMethod, + Transport: accountKeys.Transport, + TokenEndpointAuthMethod: accountKeys.TokenEndpointAuthMethod, + accountId: accountKeys.AccountID, ClientID: accountKeys.ClientID, ClientSecretID: jwk.GetKeyID(), ClientSecretJWK: jwk, ClientSecretExp: exp.Unix(), Scopes: accountKeys.Scopes, - Issuers: accountKeys.Issuers, - TokenEndpointAuthMethod: accountKeys.TokenEndpointAuthMethod, - accountId: accountKeys.AccountID, } } @@ -101,16 +148,32 @@ func MapAccountCredentialsToDTOWithSecret( secret string, exp time.Time, ) AccountCredentialsDTO { + var contacts []string + if len(accountKeys.Contacts) > 0 { + contacts = accountKeys.Contacts + } + return AccountCredentialsDTO{ id: accountKeys.ID, - Alias: accountKeys.Alias, + Type: accountKeys.CredentialsType, + Name: accountKeys.Name, + Domain: accountKeys.Domain, + ClientURI: accountKeys.ClientUri, + RedirectURIs: accountKeys.RedirectUris, + LogoURI: accountKeys.LogoUri.String, + TOSURI: accountKeys.TosUri.String, + PolicyURI: accountKeys.PolicyUri.String, + SoftwareID: accountKeys.SoftwareID, + SoftwareVersion: accountKeys.SoftwareVersion.String, + Contacts: contacts, + CreationMethod: accountKeys.CreationMethod, + Transport: accountKeys.Transport, + TokenEndpointAuthMethod: accountKeys.TokenEndpointAuthMethod, + accountId: accountKeys.AccountID, ClientID: accountKeys.ClientID, ClientSecretID: secretID, ClientSecret: fmt.Sprintf("%s.%s", secretID, secret), ClientSecretExp: exp.Unix(), Scopes: accountKeys.Scopes, - Issuers: accountKeys.Issuers, - TokenEndpointAuthMethod: accountKeys.TokenEndpointAuthMethod, - accountId: accountKeys.AccountID, } } diff --git a/idp/internal/services/helpers.go b/idp/internal/services/helpers.go index ed8b61f..f5f1b7b 100644 --- a/idp/internal/services/helpers.go +++ b/idp/internal/services/helpers.go @@ -8,6 +8,7 @@ package services import ( "log/slog" + "net/url" "strings" "github.com/jackc/pgx/v5/pgtype" @@ -81,7 +82,7 @@ func mapAuthMethod(authMethod string) (database.AuthMethod, *exceptions.ServiceE return database.AuthMethodClientSecretPost, nil case AuthMethodClientSecretJWT: return database.AuthMethodClientSecretJwt, nil - case AuthMethodNone: + case AuthMethodNone, "": return database.AuthMethodNone, nil default: return "", exceptions.NewValidationError("invalid auth method") @@ -164,6 +165,27 @@ func mapScope(scope string) (database.Scopes, *exceptions.ServiceError) { } } +func mapDomain(baseURI string, domain string) (string, *exceptions.ServiceError) { + trimmed := strings.TrimSpace(domain) + if trimmed != "" { + return trimmed, nil + } + + parsed, err := url.Parse(strings.TrimSpace(baseURI)) + if err != nil || parsed == nil { + return "", exceptions.NewValidationError("Invalid client URI") + } + if parsed.Scheme != "http" && parsed.Scheme != "https" { + return "", exceptions.NewValidationError("Invalid client URI") + } + + host := parsed.Hostname() + if strings.TrimSpace(host) == "" { + return "", exceptions.NewValidationError("Invalid client URI") + } + return host, nil +} + func mapTwoFactorType(twoFactorType string) (database.TwoFactorType, *exceptions.ServiceError) { if len(twoFactorType) < 4 { return "", exceptions.NewValidationError("invalid two factor type") diff --git a/idp/internal/services/oauth.go b/idp/internal/services/oauth.go index 7193e0e..fadf035 100644 --- a/idp/internal/services/oauth.go +++ b/idp/internal/services/oauth.go @@ -11,7 +11,6 @@ import ( "fmt" "log/slog" "net/url" - "slices" "strings" "time" @@ -624,7 +623,7 @@ func (s *Services) validateAccountJWTClaims( } if opts.claims.Issuer == "" || !utils.IsValidURL(opts.claims.Issuer) || - !slices.Contains(accountClientsDTO.Issuers, utils.ProcessURL(opts.claims.Issuer)) { + fmt.Sprintf("https://%s", accountClientsDTO.Domain) != opts.claims.Issuer { logger.WarnContext(ctx, "JWT Bearer token issuer is not allowed", "issuer", opts.claims.Issuer) return dtos.AccountCredentialsDTO{}, exceptions.NewForbiddenError() } diff --git a/idp/tests/account_credentials_test.go b/idp/tests/account_credentials_test.go index 0e50956..52de15e 100644 --- a/idp/tests/account_credentials_test.go +++ b/idp/tests/account_credentials_test.go @@ -237,7 +237,7 @@ func TestCreateAccountCredentials(t *testing.T) { RequestID: uuid.NewString(), AccountPublicID: account.PublicID, AccountVersion: account.Version(), - Alias: "existing-alias", + Name: "existing-alias", Scopes: []string{"account:users:read", "account:users:write"}, AuthMethod: "private_key_jwt", }); err != nil { @@ -331,7 +331,7 @@ func TestListAccountCredentials(t *testing.T) { RequestID: uuid.NewString(), AccountPublicID: account.PublicID, AccountVersion: account.Version(), - Alias: alias, + Name: alias, Scopes: scopes, AuthMethod: authMethods, Issuers: issuers, @@ -470,7 +470,7 @@ func TestGetAccountCredentials(t *testing.T) { RequestID: uuid.NewString(), AccountPublicID: account.PublicID, AccountVersion: account.Version(), - Alias: "get-cred", + Name: "get-cred", Scopes: []string{"account:admin"}, AuthMethod: "client_secret_basic", Issuers: []string{"https://issuer.example.com"}, @@ -536,7 +536,7 @@ func TestGetAccountCredentials(t *testing.T) { RequestID: uuid.NewString(), AccountPublicID: account.PublicID, AccountVersion: account.Version(), - Alias: "forbidden-cred", + Name: "forbidden-cred", Scopes: []string{"account:admin"}, AuthMethod: "client_secret_basic", }) @@ -575,7 +575,7 @@ func TestUpdateAccountCredentials(t *testing.T) { RequestID: uuid.NewString(), AccountPublicID: account.PublicID, AccountVersion: account.Version(), - Alias: "update-cred", + Name: "update-cred", Scopes: []string{"account:admin"}, AuthMethod: "client_secret_basic", Issuers: []string{"https://issuer.example.com"}, @@ -641,7 +641,7 @@ func TestUpdateAccountCredentials(t *testing.T) { RequestID: uuid.NewString(), AccountPublicID: account.PublicID, AccountVersion: account.Version(), - Alias: "existing-alias", + Name: "existing-alias", Scopes: []string{"account:users:read"}, AuthMethod: "client_secret_basic", Issuers: []string{"updated.example.com"}, @@ -653,7 +653,7 @@ func TestUpdateAccountCredentials(t *testing.T) { RequestID: uuid.NewString(), AccountPublicID: account.PublicID, AccountVersion: account.Version(), - Alias: "other-alias", + Name: "other-alias", Scopes: []string{"account:users:read"}, Issuers: []string{"https://updated.example.com"}, AuthMethod: "client_secret_basic", @@ -719,7 +719,7 @@ func TestUpdateAccountCredentials(t *testing.T) { RequestID: uuid.NewString(), AccountPublicID: account.PublicID, AccountVersion: account.Version(), - Alias: "forbidden-cred", + Name: "forbidden-cred", Scopes: []string{"account:admin"}, AuthMethod: "client_secret_basic", Issuers: []string{"https://issuer.example.com"}, @@ -763,7 +763,7 @@ func TestDeleteAccountCredentials(t *testing.T) { RequestID: uuid.NewString(), AccountPublicID: account.PublicID, AccountVersion: account.Version(), - Alias: "delete-cred", + Name: "delete-cred", Scopes: []string{"account:admin"}, AuthMethod: "client_secret_basic", Issuers: []string{"https://issuer.example.com"}, @@ -824,7 +824,7 @@ func TestDeleteAccountCredentials(t *testing.T) { RequestID: uuid.NewString(), AccountPublicID: account.PublicID, AccountVersion: account.Version(), - Alias: "forbidden-cred", + Name: "forbidden-cred", Scopes: []string{"account:admin"}, AuthMethod: "client_secret_basic", }) @@ -864,7 +864,7 @@ func TestRevokeAccountCredentialsSecret(t *testing.T) { RequestID: uuid.NewString(), AccountPublicID: account.PublicID, AccountVersion: account.Version(), - Alias: "revoke-cred", + Name: "revoke-cred", Scopes: []string{"account:admin"}, Issuers: []string{"https://issuer.example.com"}, AuthMethod: authMethods, @@ -979,7 +979,7 @@ func TestRevokeAccountCredentialsSecret(t *testing.T) { RequestID: uuid.NewString(), AccountPublicID: account.PublicID, AccountVersion: account.Version(), - Alias: "forbidden-cred", + Name: "forbidden-cred", Scopes: []string{"account:admin"}, AuthMethod: "client_secret_basic", }) @@ -1017,7 +1017,7 @@ func TestListAccountCredentialsSecret(t *testing.T) { RequestID: uuid.NewString(), AccountPublicID: account.PublicID, AccountVersion: account.Version(), - Alias: "list-cred", + Name: "list-cred", Scopes: []string{"account:admin"}, Issuers: []string{"https://issuer.example.com"}, AuthMethod: authMethods, @@ -1094,7 +1094,7 @@ func TestListAccountCredentialsSecret(t *testing.T) { RequestID: uuid.NewString(), AccountPublicID: account.PublicID, AccountVersion: account.Version(), - Alias: "forbidden-cred", + Name: "forbidden-cred", Scopes: []string{"account:admin"}, Issuers: []string{"https://issuer.example.com"}, AuthMethod: "client_secret_basic", @@ -1135,7 +1135,7 @@ func TestCreateAccountCredentialsSecret(t *testing.T) { RequestID: uuid.NewString(), AccountPublicID: account.PublicID, AccountVersion: account.Version(), - Alias: "create-secret-cred", + Name: "create-secret-cred", Scopes: []string{"account:admin"}, Issuers: []string{"https://issuer.example.com"}, AuthMethod: authMethods, @@ -1416,7 +1416,7 @@ func TestCreateAccountCredentialsSecret(t *testing.T) { RequestID: uuid.NewString(), AccountPublicID: account.PublicID, AccountVersion: account.Version(), - Alias: "forbidden-cred", + Name: "forbidden-cred", Scopes: []string{"account:admin"}, AuthMethod: "client_secret_basic", }) @@ -1453,7 +1453,7 @@ func TestGetAccountCredentialsSecret(t *testing.T) { RequestID: uuid.NewString(), AccountPublicID: account.PublicID, AccountVersion: account.Version(), - Alias: "get-secret-cred", + Name: "get-secret-cred", Scopes: []string{"account:admin"}, Issuers: []string{"https://issuer.example.com"}, AuthMethod: authMethods, @@ -1542,7 +1542,7 @@ func TestGetAccountCredentialsSecret(t *testing.T) { RequestID: uuid.NewString(), AccountPublicID: account.PublicID, AccountVersion: account.Version(), - Alias: "forbidden-cred", + Name: "forbidden-cred", Scopes: []string{"account:admin"}, Issuers: []string{"https://issuer.example.com"}, AuthMethod: "client_secret_basic", diff --git a/idp/tests/oauth_test.go b/idp/tests/oauth_test.go index d931516..b0c2909 100644 --- a/idp/tests/oauth_test.go +++ b/idp/tests/oauth_test.go @@ -750,7 +750,7 @@ func TestOAuthToken(t *testing.T) { RequestID: uuid.NewString(), AccountPublicID: account.PublicID, AccountVersion: account.Version(), - Alias: "update-cred", + Name: "update-cred", Scopes: []string{"account:admin"}, AuthMethod: "private_key_jwt", Issuers: []string{"https://issuer.example.com"}, @@ -826,7 +826,7 @@ func TestOAuthToken(t *testing.T) { RequestID: uuid.NewString(), AccountPublicID: account.PublicID, AccountVersion: account.Version(), - Alias: "update-cred", + Name: "update-cred", Scopes: []string{"account:admin"}, AuthMethod: am, Issuers: []string{"https://issuer.example.com"}, From 3cb4adf52937be4869507e06d4ebede4d70955c1 Mon Sep 17 00:00:00 2001 From: Afonso Barracha Date: Sat, 16 Aug 2025 13:41:39 +1200 Subject: [PATCH 03/23] feat(idp): add account dynamic registration config --- idp/go.mod | 48 +- idp/go.sum | 48 + idp/initial_schema.dbml | 9 +- .../controllers/account_credentials.go | 7 + .../account_dynamic_registration_configs.go | 96 ++ idp/internal/controllers/apps.go | 1 + .../controllers/bodies/account_credentials.go | 3 +- .../account_dynamics_registration_configs.go | 16 + idp/internal/controllers/middleware.go | 11 +- idp/internal/controllers/paths/common.go | 1 + .../controllers/paths/dynamic_registration.go | 13 + idp/internal/controllers/paths/oauth.go | 1 + ...ccount_dynamic_registration_configs.sql.go | 185 +++ ...0241213231542_create_initial_schema.up.sql | 12 +- idp/internal/providers/database/models.go | 20 +- .../account_dynamic_registration_configs.sql | 48 + idp/internal/server/routes.go | 1 + .../routes/account_dynamic_registration.go | 36 + .../account_credentials_registration.go | 7 + .../account_dynamic_registration_configs.go | 233 +++ .../account_dynamic_registration_config.go | 38 + idp/tests/account_credentials_test.go | 1269 ++++++++--------- idp/tests/oauth_test.go | 18 +- 23 files changed, 1407 insertions(+), 714 deletions(-) create mode 100644 idp/internal/controllers/account_dynamic_registration_configs.go create mode 100644 idp/internal/controllers/bodies/account_dynamics_registration_configs.go create mode 100644 idp/internal/controllers/paths/dynamic_registration.go create mode 100644 idp/internal/providers/database/account_dynamic_registration_configs.sql.go create mode 100644 idp/internal/providers/database/queries/account_dynamic_registration_configs.sql create mode 100644 idp/internal/server/routes/account_dynamic_registration.go create mode 100644 idp/internal/services/account_credentials_registration.go create mode 100644 idp/internal/services/account_dynamic_registration_configs.go create mode 100644 idp/internal/services/dtos/account_dynamic_registration_config.go diff --git a/idp/go.mod b/idp/go.mod index 0c020e9..c2176eb 100644 --- a/idp/go.mod +++ b/idp/go.mod @@ -1,13 +1,13 @@ module github.com/tugascript/devlogs/idp -go 1.24.0 +go 1.25.0 require ( github.com/biter777/countries v1.7.5 github.com/go-faker/faker/v4 v4.6.1 github.com/go-playground/validator/v10 v10.27.0 github.com/gofiber/fiber/v2 v2.52.9 - github.com/gofiber/storage/redis/v3 v3.2.0 + github.com/gofiber/storage/redis/v3 v3.4.0 github.com/golang-jwt/jwt/v5 v5.3.0 github.com/google/uuid v1.6.0 github.com/h2non/gock v1.2.0 @@ -16,19 +16,19 @@ require ( github.com/openbao/openbao/api/auth/approle/v2 v2.3.1 github.com/openbao/openbao/api/v2 v2.3.1 github.com/pquerna/otp v1.5.0 - github.com/redis/go-redis/v9 v9.11.0 - golang.org/x/crypto v0.40.0 + github.com/redis/go-redis/v9 v9.12.1 + golang.org/x/crypto v0.41.0 golang.org/x/oauth2 v0.30.0 - golang.org/x/text v0.27.0 - google.golang.org/api v0.244.0 + golang.org/x/text v0.28.0 + google.golang.org/api v0.247.0 ) require ( - cloud.google.com/go/auth v0.16.3 // indirect + cloud.google.com/go/auth v0.16.4 // indirect cloud.google.com/go/auth/oauth2adapt v0.2.8 // indirect - cloud.google.com/go/compute/metadata v0.7.0 // indirect - github.com/andybalholm/brotli v1.1.1 // indirect - github.com/boombuler/barcode v1.0.2 // indirect + cloud.google.com/go/compute/metadata v0.8.0 // indirect + github.com/andybalholm/brotli v1.2.0 // indirect + github.com/boombuler/barcode v1.1.0 // indirect github.com/cenkalti/backoff/v4 v4.3.0 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect @@ -39,7 +39,7 @@ require ( github.com/go-logr/stdr v1.2.2 // indirect github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect - github.com/go-viper/mapstructure/v2 v2.3.0 // indirect + github.com/go-viper/mapstructure/v2 v2.4.0 // indirect github.com/google/s2a-go v0.1.9 // indirect github.com/googleapis/enterprise-certificate-proxy v0.3.6 // indirect github.com/googleapis/gax-go/v2 v2.15.0 // indirect @@ -47,11 +47,11 @@ require ( github.com/hashicorp/errwrap v1.1.0 // indirect github.com/hashicorp/go-cleanhttp v0.5.2 // indirect github.com/hashicorp/go-multierror v1.1.1 // indirect - github.com/hashicorp/go-retryablehttp v0.7.7 // indirect + github.com/hashicorp/go-retryablehttp v0.7.8 // indirect github.com/hashicorp/go-secure-stdlib/parseutil v0.2.0 // indirect github.com/hashicorp/go-secure-stdlib/strutil v0.1.2 // indirect github.com/hashicorp/go-sockaddr v1.0.7 // indirect - github.com/hashicorp/hcl v1.0.1-vault-5 // indirect + github.com/hashicorp/hcl v1.0.1-vault-7 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect github.com/jackc/puddle/v2 v2.2.2 // indirect @@ -61,22 +61,24 @@ require ( github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-runewidth v0.0.16 // indirect github.com/mitchellh/mapstructure v1.5.0 // indirect - github.com/philhofer/fwd v1.1.3-0.20240916144458-20a13a1f6b7c // indirect + github.com/philhofer/fwd v1.2.0 // indirect github.com/rivo/uniseg v0.4.7 // indirect github.com/ryanuber/go-glob v1.0.0 // indirect github.com/tinylib/msgp v1.3.0 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect - github.com/valyala/fasthttp v1.62.0 // indirect + github.com/valyala/fasthttp v1.64.0 // indirect go.opentelemetry.io/auto/sdk v1.1.0 // indirect - go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.61.0 // indirect - go.opentelemetry.io/otel v1.36.0 // indirect - go.opentelemetry.io/otel/metric v1.36.0 // indirect - go.opentelemetry.io/otel/trace v1.36.0 // indirect - golang.org/x/net v0.42.0 // indirect + go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.62.0 // indirect + go.opentelemetry.io/otel v1.37.0 // indirect + go.opentelemetry.io/otel/metric v1.37.0 // indirect + go.opentelemetry.io/otel/trace v1.37.0 // indirect + golang.org/x/mod v0.26.0 // indirect + golang.org/x/net v0.43.0 // indirect golang.org/x/sync v0.16.0 // indirect - golang.org/x/sys v0.34.0 // indirect + golang.org/x/sys v0.35.0 // indirect golang.org/x/time v0.12.0 // indirect - google.golang.org/genproto/googleapis/rpc v0.0.0-20250728155136-f173205681a0 // indirect + golang.org/x/tools v0.35.0 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20250811230008-5f3141c8851a // indirect google.golang.org/grpc v1.74.2 // indirect - google.golang.org/protobuf v1.36.6 // indirect + google.golang.org/protobuf v1.36.7 // indirect ) diff --git a/idp/go.sum b/idp/go.sum index be6b0ee..f5c3c2b 100644 --- a/idp/go.sum +++ b/idp/go.sum @@ -1,9 +1,13 @@ cloud.google.com/go/auth v0.16.3 h1:kabzoQ9/bobUmnseYnBO6qQG7q4a/CffFRlJSxv2wCc= cloud.google.com/go/auth v0.16.3/go.mod h1:NucRGjaXfzP1ltpcQ7On/VTZ0H4kWB5Jy+Y9Dnm76fA= +cloud.google.com/go/auth v0.16.4 h1:fXOAIQmkApVvcIn7Pc2+5J8QTMVbUGLscnSVNl11su8= +cloud.google.com/go/auth v0.16.4/go.mod h1:j10ncYwjX/g3cdX7GpEzsdM+d+ZNsXAbb6qXA7p1Y5M= cloud.google.com/go/auth/oauth2adapt v0.2.8 h1:keo8NaayQZ6wimpNSmW5OPc283g65QNIiLpZnkHRbnc= cloud.google.com/go/auth/oauth2adapt v0.2.8/go.mod h1:XQ9y31RkqZCcwJWNSx2Xvric3RrU88hAYYbjDWYDL+c= cloud.google.com/go/compute/metadata v0.7.0 h1:PBWF+iiAerVNe8UCHxdOt6eHLVc3ydFeOCw78U8ytSU= cloud.google.com/go/compute/metadata v0.7.0/go.mod h1:j5MvL9PprKL39t166CoB1uVHfQMs4tFQZZcKwksXUjo= +cloud.google.com/go/compute/metadata v0.8.0 h1:HxMRIbao8w17ZX6wBnjhcDkW6lTFpgcaobyVfZWqRLA= +cloud.google.com/go/compute/metadata v0.8.0/go.mod h1:sYOGTp851OV9bOFJ9CH7elVvyzopvWQFNNghtDQ/Biw= dario.cat/mergo v1.0.1 h1:Ra4+bf83h2ztPIQYNP99R6m+Y7KfnARDfID+a+vLl4s= dario.cat/mergo v1.0.1/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk= github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 h1:UQHMgLO+TxOElx5B5HZ4hJQsoJ/PvUvKRhJHDQXO8P8= @@ -12,11 +16,15 @@ github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERo github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= github.com/andybalholm/brotli v1.1.1 h1:PR2pgnyFznKEugtsUo0xLdDop5SKXd5Qf5ysW+7XdTA= github.com/andybalholm/brotli v1.1.1/go.mod h1:05ib4cKhjx3OQYUY22hTVd34Bc8upXjOLL2rKwwZBoA= +github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ= +github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY= github.com/biter777/countries v1.7.5 h1:MJ+n3+rSxWQdqVJU8eBy9RqcdH6ePPn4PJHocVWUa+Q= github.com/biter777/countries v1.7.5/go.mod h1:1HSpZ526mYqKJcpT5Ti1kcGQ0L0SrXWIaptUWjFfv2E= github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8= github.com/boombuler/barcode v1.0.2 h1:79yrbttoZrLGkL/oOI8hBrUKucwOL0oOjUgEguGMcJ4= github.com/boombuler/barcode v1.0.2/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8= +github.com/boombuler/barcode v1.1.0 h1:ChaYjBR63fr4LFyGn8E8nt7dBSt3MiU3zMOZqFvVkHo= +github.com/boombuler/barcode v1.1.0/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8= github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= @@ -76,10 +84,14 @@ github.com/go-test/deep v1.1.0 h1:WOcxcdHcvdgThNXjw0t76K42FXTU7HpNQWHpA2HHNlg= github.com/go-test/deep v1.1.0/go.mod h1:5C2ZWiW0ErCdrYzpqxLbTX7MG14M9iiw8DgHncVwcsE= github.com/go-viper/mapstructure/v2 v2.3.0 h1:27XbWsHIqhbdR5TIC911OfYvgSaW93HM+dX7970Q7jk= github.com/go-viper/mapstructure/v2 v2.3.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= +github.com/go-viper/mapstructure/v2 v2.4.0 h1:EBsztssimR/CONLSZZ04E8qAkxNYq4Qp9LvH92wZUgs= +github.com/go-viper/mapstructure/v2 v2.4.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= github.com/gofiber/fiber/v2 v2.52.9 h1:YjKl5DOiyP3j0mO61u3NTmK7or8GzzWzCFzkboyP5cw= github.com/gofiber/fiber/v2 v2.52.9/go.mod h1:YEcBbO/FB+5M1IZNBP9FO3J9281zgPAreiI1oqg8nDw= github.com/gofiber/storage/redis/v3 v3.2.0 h1:1cmxmH6ZniZcWHvMpp6LzfcSK5o7CgqiouRqrVCNY9A= github.com/gofiber/storage/redis/v3 v3.2.0/go.mod h1:fffHK3QnjOxOUZGtq08YVNU1lqKvE+pAKJ5roSnM7FE= +github.com/gofiber/storage/redis/v3 v3.4.0 h1:FbtVgHsWkHFaogObFyNbBkNkZL9/zYxQkS1PV0rA5Ss= +github.com/gofiber/storage/redis/v3 v3.4.0/go.mod h1:5efv+XbKwSQju9j7tokMgFWZ1JwlZvSsIL4RNJSDyf0= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo= @@ -112,6 +124,8 @@ github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+l github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= github.com/hashicorp/go-retryablehttp v0.7.7 h1:C8hUCYzor8PIfXHa4UrZkU4VvK8o9ISHxT2Q8+VepXU= github.com/hashicorp/go-retryablehttp v0.7.7/go.mod h1:pkQpWZeYWskR+D1tR2O5OcBFOxfA7DoAO6xtkuQnHTk= +github.com/hashicorp/go-retryablehttp v0.7.8 h1:ylXZWnqa7Lhqpk0L1P1LzDtGcCR0rPVUrx/c8Unxc48= +github.com/hashicorp/go-retryablehttp v0.7.8/go.mod h1:rjiScheydd+CxvumBsIrFKlx3iS0jrZ7LvzFGFmuKbw= github.com/hashicorp/go-secure-stdlib/parseutil v0.2.0 h1:U+kC2dOhMFQctRfhK0gRctKAPTloZdMU5ZJxaesJ/VM= github.com/hashicorp/go-secure-stdlib/parseutil v0.2.0/go.mod h1:Ll013mhdmsVDuoIXVfBtvgGJsXDYkTw1kooNcoCXuE0= github.com/hashicorp/go-secure-stdlib/strutil v0.1.2 h1:kes8mmyCpxJsI7FTwtzRqEy9CdjCtrXrXGuOpxEA7Ts= @@ -120,6 +134,8 @@ github.com/hashicorp/go-sockaddr v1.0.7 h1:G+pTkSO01HpR5qCxg7lxfsFEZaG+C0VssTy/9 github.com/hashicorp/go-sockaddr v1.0.7/go.mod h1:FZQbEYa1pxkQ7WLpyXJ6cbjpT8q0YgQaK/JakXqGyWw= github.com/hashicorp/hcl v1.0.1-vault-5 h1:kI3hhbbyzr4dldA8UdTb7ZlVVlI2DACdCfz31RPDgJM= github.com/hashicorp/hcl v1.0.1-vault-5/go.mod h1:XYhtn6ijBSAj6n4YqAaf7RBPS4I06AItNorpy+MoQNM= +github.com/hashicorp/hcl v1.0.1-vault-7 h1:ag5OxFVy3QYTFTJODRzTKVZ6xvdfLLCA1cy/Y6xGI0I= +github.com/hashicorp/hcl v1.0.1-vault-7/go.mod h1:XYhtn6ijBSAj6n4YqAaf7RBPS4I06AItNorpy+MoQNM= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= @@ -174,6 +190,8 @@ github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJw github.com/opencontainers/image-spec v1.1.1/go.mod h1:qpqAh3Dmcf36wStyyWU+kCeDgrGnAve2nCC8+7h8Q0M= github.com/philhofer/fwd v1.1.3-0.20240916144458-20a13a1f6b7c h1:dAMKvw0MlJT1GshSTtih8C2gDs04w8dReiOGXrGLNoY= github.com/philhofer/fwd v1.1.3-0.20240916144458-20a13a1f6b7c/go.mod h1:RqIHx9QI14HlwKwm98g9Re5prTQ6LdeRQn+gXJFxsJM= +github.com/philhofer/fwd v1.2.0 h1:e6DnBTl7vGY+Gz322/ASL4Gyp1FspeMvx1RNDoToZuM= +github.com/philhofer/fwd v1.2.0/go.mod h1:RqIHx9QI14HlwKwm98g9Re5prTQ6LdeRQn+gXJFxsJM= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= @@ -185,6 +203,8 @@ github.com/pquerna/otp v1.5.0 h1:NMMR+WrmaqXU4EzdGJEE1aUUI0AMRzsp96fFFWNPwxs= github.com/pquerna/otp v1.5.0/go.mod h1:dkJfzwRKNiegxyNb54X/3fLwhCynbMspSyWKnvi1AEg= github.com/redis/go-redis/v9 v9.11.0 h1:E3S08Gl/nJNn5vkxd2i78wZxWAPNZgUNTp8WIJUAiIs= github.com/redis/go-redis/v9 v9.11.0/go.mod h1:huWgSWd8mW6+m0VPhJjSSQ+d6Nh1VICQ6Q5lHuCH/Iw= +github.com/redis/go-redis/v9 v9.12.1 h1:k5iquqv27aBtnTm2tIkROUDp8JBXhXZIVu1InSgvovg= +github.com/redis/go-redis/v9 v9.12.1/go.mod h1:huWgSWd8mW6+m0VPhJjSSQ+d6Nh1VICQ6Q5lHuCH/Iw= github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= @@ -213,6 +233,8 @@ github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6Kllzaw github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= github.com/valyala/fasthttp v1.62.0 h1:8dKRBX/y2rCzyc6903Zu1+3qN0H/d2MsxPPmVNamiH0= github.com/valyala/fasthttp v1.62.0/go.mod h1:FCINgr4GKdKqV8Q0xv8b+UxPV+H/O5nNFo3D+r54Htg= +github.com/valyala/fasthttp v1.64.0 h1:QBygLLQmiAyiXuRhthf0tuRkqAFcrC42dckN2S+N3og= +github.com/valyala/fasthttp v1.64.0/go.mod h1:dGmFxwkWXSK0NbOSJuF7AMVzU+lkHz0wQVvVITv2UQA= github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU= github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= @@ -222,23 +244,35 @@ go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJyS go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A= go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.61.0 h1:F7Jx+6hwnZ41NSFTO5q4LYDtJRXBf2PD0rNBkeB/lus= go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.61.0/go.mod h1:UHB22Z8QsdRDrnAtX4PntOl36ajSxcdUMt1sF7Y6E7Q= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.62.0 h1:Hf9xI/XLML9ElpiHVDNwvqI0hIFlzV8dgIr35kV1kRU= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.62.0/go.mod h1:NfchwuyNoMcZ5MLHwPrODwUF1HWCXWrL31s8gSAdIKY= go.opentelemetry.io/otel v1.36.0 h1:UumtzIklRBY6cI/lllNZlALOF5nNIzJVb16APdvgTXg= go.opentelemetry.io/otel v1.36.0/go.mod h1:/TcFMXYjyRNh8khOAO9ybYkqaDBb/70aVwkNML4pP8E= +go.opentelemetry.io/otel v1.37.0 h1:9zhNfelUvx0KBfu/gb+ZgeAfAgtWrfHJZcAqFC228wQ= +go.opentelemetry.io/otel v1.37.0/go.mod h1:ehE/umFRLnuLa/vSccNq9oS1ErUlkkK71gMcN34UG8I= go.opentelemetry.io/otel/metric v1.36.0 h1:MoWPKVhQvJ+eeXWHFBOPoBOi20jh6Iq2CcCREuTYufE= go.opentelemetry.io/otel/metric v1.36.0/go.mod h1:zC7Ks+yeyJt4xig9DEw9kuUFe5C3zLbVjV2PzT6qzbs= +go.opentelemetry.io/otel/metric v1.37.0 h1:mvwbQS5m0tbmqML4NqK+e3aDiO02vsf/WgbsdpcPoZE= +go.opentelemetry.io/otel/metric v1.37.0/go.mod h1:04wGrZurHYKOc+RKeye86GwKiTb9FKm1WHtO+4EVr2E= go.opentelemetry.io/otel/sdk v1.36.0 h1:b6SYIuLRs88ztox4EyrvRti80uXIFy+Sqzoh9kFULbs= go.opentelemetry.io/otel/sdk v1.36.0/go.mod h1:+lC+mTgD+MUWfjJubi2vvXWcVxyr9rmlshZni72pXeY= go.opentelemetry.io/otel/sdk/metric v1.36.0 h1:r0ntwwGosWGaa0CrSt8cuNuTcccMXERFwHX4dThiPis= go.opentelemetry.io/otel/sdk/metric v1.36.0/go.mod h1:qTNOhFDfKRwX0yXOqJYegL5WRaW376QbB7P4Pb0qva4= go.opentelemetry.io/otel/trace v1.36.0 h1:ahxWNuqZjpdiFAyrIoQ4GIiAIhxAunQR6MUoKrsNd4w= go.opentelemetry.io/otel/trace v1.36.0/go.mod h1:gQ+OnDZzrybY4k4seLzPAWNwVBBVlF2szhehOBB/tGA= +go.opentelemetry.io/otel/trace v1.37.0 h1:HLdcFNbRQBE2imdSEgm/kwqmQj1Or1l/7bW6mxVK7z4= +go.opentelemetry.io/otel/trace v1.37.0/go.mod h1:TlgrlQ+PtQO5XFerSPUYG0JSgGyryXewPGyayAWSBS0= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= golang.org/x/crypto v0.40.0 h1:r4x+VvoG5Fm+eJcxMaY8CQM7Lb0l1lsmjGBQ6s8BfKM= golang.org/x/crypto v0.40.0/go.mod h1:Qr1vMER5WyS2dfPHAlsOj01wgLbsyWtFn/aY+5+ZdxY= +golang.org/x/crypto v0.41.0 h1:WKYxWedPGCTVVl5+WHSSrOBT0O8lx32+zxmHxijgXp4= +golang.org/x/crypto v0.41.0/go.mod h1:pO5AFd7FA68rFak7rOAGVuygIISepHftHnr8dr6+sUc= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= +golang.org/x/mod v0.26.0 h1:EGMPT//Ezu+ylkCijjPc+f4Aih7sZvaAr+O3EHBxvZg= +golang.org/x/mod v0.26.0/go.mod h1:/j6NAhSk8iQ723BGAUyoAcn7SlD7s15Dp9Nd/SfeaFQ= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= @@ -246,6 +280,8 @@ golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= golang.org/x/net v0.42.0 h1:jzkYrhi3YQWD6MLBJcsklgQsoAcw89EcZbJw8Z614hs= golang.org/x/net v0.42.0/go.mod h1:FF1RA5d3u7nAYA4z2TkclSCKh68eSXtiFwcWQpPXdt8= +golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE= +golang.org/x/net v0.43.0/go.mod h1:vhO1fvI4dGsIjh73sWfUVjj3N7CA9WkKJNQm2svM6Jg= golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI= golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -264,6 +300,8 @@ golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.34.0 h1:H5Y5sJ2L2JRdyv7ROF1he/lPdvFsd0mJHFw2ThKHxLA= golang.org/x/sys v0.34.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI= +golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= @@ -277,25 +315,35 @@ golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/text v0.27.0 h1:4fGWRpyh641NLlecmyl4LOe6yDdfaYNrGb2zdfo4JV4= golang.org/x/text v0.27.0/go.mod h1:1D28KMCvyooCX9hBiosv5Tz/+YLxj0j7XhWjpSUF7CU= +golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng= +golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU= golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE= golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= +golang.org/x/tools v0.35.0 h1:mBffYraMEf7aa0sB+NuKnuCy8qI/9Bughn8dC2Gu5r0= +golang.org/x/tools v0.35.0/go.mod h1:NKdj5HkL/73byiZSJjqJgKn3ep7KjFkBOkR/Hps3VPw= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/api v0.244.0 h1:lpkP8wVibSKr++NCD36XzTk/IzeKJ3klj7vbj+XU5pE= google.golang.org/api v0.244.0/go.mod h1:dMVhVcylamkirHdzEBAIQWUCgqY885ivNeZYd7VAVr8= +google.golang.org/api v0.247.0 h1:tSd/e0QrUlLsrwMKmkbQhYVa109qIintOls2Wh6bngc= +google.golang.org/api v0.247.0/go.mod h1:r1qZOPmxXffXg6xS5uhx16Fa/UFY8QU/K4bfKrnvovM= google.golang.org/genproto v0.0.0-20250603155806-513f23925822 h1:rHWScKit0gvAPuOnu87KpaYtjK5zBMLcULh7gxkCXu4= google.golang.org/genproto v0.0.0-20250603155806-513f23925822/go.mod h1:HubltRL7rMh0LfnQPkMH4NPDFEWp0jw3vixw7jEM53s= google.golang.org/genproto/googleapis/api v0.0.0-20250603155806-513f23925822 h1:oWVWY3NzT7KJppx2UKhKmzPq4SRe0LdCijVRwvGeikY= google.golang.org/genproto/googleapis/api v0.0.0-20250603155806-513f23925822/go.mod h1:h3c4v36UTKzUiuaOKQ6gr3S+0hovBtUrXzTG/i3+XEc= google.golang.org/genproto/googleapis/rpc v0.0.0-20250728155136-f173205681a0 h1:MAKi5q709QWfnkkpNQ0M12hYJ1+e8qYVDyowc4U1XZM= google.golang.org/genproto/googleapis/rpc v0.0.0-20250728155136-f173205681a0/go.mod h1:qQ0YXyHHx3XkvlzUtpXDkS29lDSafHMZBAZDc03LQ3A= +google.golang.org/genproto/googleapis/rpc v0.0.0-20250811230008-5f3141c8851a h1:tPE/Kp+x9dMSwUm/uM0JKK0IfdiJkwAbSMSeZBXXJXc= +google.golang.org/genproto/googleapis/rpc v0.0.0-20250811230008-5f3141c8851a/go.mod h1:gw1tLEfykwDz2ET4a12jcXt4couGAm7IwsVaTy0Sflo= google.golang.org/grpc v1.74.2 h1:WoosgB65DlWVC9FqI82dGsZhWFNBSLjQ84bjROOpMu4= google.golang.org/grpc v1.74.2/go.mod h1:CtQ+BGjaAIXHs/5YS3i473GqwBBa1zGQNevxdeBEXrM= google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY= google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY= +google.golang.org/protobuf v1.36.7 h1:IgrO7UwFQGJdRNXH/sQux4R1Dj1WAKcLElzeeRaXV2A= +google.golang.org/protobuf v1.36.7/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/idp/initial_schema.dbml b/idp/initial_schema.dbml index 1571a53..651c4a7 100644 --- a/idp/initial_schema.dbml +++ b/idp/initial_schema.dbml @@ -871,19 +871,22 @@ Table account_dynamic_registration_configs as ADRC { id serial [pk] account_id integer [not null] + account_public_id uuid [not null] + account_credentials_types "account_credentials_type[]" [not null] whitelisted_domains "varchar(250)[]" [not null] - require_software_statement boolean [not null] + require_software_statement_credential_types "account_credentials_type[]" [not null] software_statement_verification_methods "software_statement_verification_method[]" [not null] - require_initial_access_token boolean [not null] + require_initial_access_token_credential_types "account_credentials_type[]" [not null] initial_access_token_generation_methods "initial_access_token_generation_method[]" [not null] created_at timestamptz [not null, default: `now()`] updated_at timestamptz [not null, default: `now()`] Indexes { - (account_id) [name: 'account_dynamic_registration_configs_account_id_idx'] + (account_id) [unique, name: 'account_dynamic_registration_configs_account_id_uidx'] + (account_public_id) [name: 'account_dynamic_registration_configs_account_public_id_idx'] } } Ref: ADRC.account_id > A.id [delete: cascade] diff --git a/idp/internal/controllers/account_credentials.go b/idp/internal/controllers/account_credentials.go index 18a7411..3ee4a26 100644 --- a/idp/internal/controllers/account_credentials.go +++ b/idp/internal/controllers/account_credentials.go @@ -49,6 +49,7 @@ func (c *Controllers) CreateAccountCredentials(ctx *fiber.Ctx) error { RequestID: requestID, AccountPublicID: accountClaims.AccountID, AccountVersion: accountClaims.AccountVersion, + CredentialsType: body.Type, Name: body.Name, Scopes: body.Scopes, AuthMethod: body.TokenEndpointAuthMethod, @@ -58,6 +59,10 @@ func (c *Controllers) CreateAccountCredentials(ctx *fiber.Ctx) error { LogoURI: body.LogoURI, TOSURI: body.TOSURI, PolicyURI: body.PolicyURI, + SoftwareID: body.SoftwareID, + SoftwareVersion: body.SoftwareVersion, + Algorithm: body.Algorithm, + Transport: body.Transport, }, ) if serviceErr != nil { @@ -174,6 +179,7 @@ func (c *Controllers) UpdateAccountCredentials(ctx *fiber.Ctx) error { AccountPublicID: accountClaims.AccountID, AccountVersion: accountClaims.AccountVersion, ClientID: urlParams.ClientID, + Name: body.Name, Scopes: body.Scopes, Transport: body.Transport, Domain: body.Domain, @@ -182,6 +188,7 @@ func (c *Controllers) UpdateAccountCredentials(ctx *fiber.Ctx) error { LogoURI: body.LogoURI, TOSURI: body.TOSURI, PolicyURI: body.PolicyURI, + SoftwareVersion: body.SoftwareVersion, }, ) if serviceErr != nil { diff --git a/idp/internal/controllers/account_dynamic_registration_configs.go b/idp/internal/controllers/account_dynamic_registration_configs.go new file mode 100644 index 0000000..1a9f916 --- /dev/null +++ b/idp/internal/controllers/account_dynamic_registration_configs.go @@ -0,0 +1,96 @@ +// Copyright (c) 2025 Afonso Barracha +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +package controllers + +import ( + "github.com/gofiber/fiber/v2" + + "github.com/tugascript/devlogs/idp/internal/controllers/bodies" + "github.com/tugascript/devlogs/idp/internal/services" +) + +const ( + accountDynamicRegistrationConfigsLocation string = "account_dynamic_registration_configs" +) + +func (c *Controllers) UpsertAccountDynamicRegistrationConfig(ctx *fiber.Ctx) error { + requestID := getRequestID(ctx) + logger := c.buildLogger( + requestID, + accountDynamicRegistrationConfigsLocation, + "UpsertAccountDynamicRegistrationConfig", + ) + logRequest(logger, ctx) + + accountClaims, serviceErr := getAccountClaims(ctx) + if serviceErr != nil { + return serviceErrorResponse(logger, ctx, serviceErr) + } + + body := new(bodies.AccountDynamicRegistrationConfigBody) + if err := ctx.BodyParser(body); err != nil { + return parseRequestErrorResponse(logger, ctx, err) + } + if err := c.validate.StructCtx(ctx.UserContext(), body); err != nil { + return validateBodyErrorResponse(logger, ctx, err) + } + + dto, created, serviceErr := c.services.SaveAccountDynamicRegistrationConfig( + ctx.UserContext(), + services.SaveAccountDynamicRegistrationConfigOptions{ + RequestID: requestID, + AccountPublicID: accountClaims.AccountID, + AccountVersion: accountClaims.AccountVersion, + AccountCredentialsTypes: body.AccountCredentialsTypes, + WhitelistedDomains: body.WhitelistedDomains, + RequireSoftwareStatementCredentialTypes: body.RequireSoftwareStatementCredentialTypes, + SoftwareStatementVerificationMethods: body.SoftwareStatementVerificationMethods, + RequireInitialAccessTokenCredentialTypes: body.RequireInitialAccessTokenCredentialTypes, + InitialAccessTokenGenerationMethods: body.InitialAccessTokenGenerationMethods, + }, + ) + if serviceErr != nil { + return serviceErrorResponse(logger, ctx, serviceErr) + } + + if created { + logResponse(logger, ctx, fiber.StatusCreated) + return ctx.Status(fiber.StatusCreated).JSON(&dto) + } + + logResponse(logger, ctx, fiber.StatusOK) + return ctx.Status(fiber.StatusOK).JSON(&dto) +} + +func (c *Controllers) GetAccountDynamicRegistrationConfig(ctx *fiber.Ctx) error { + requestID := getRequestID(ctx) + logger := c.buildLogger( + requestID, + accountDynamicRegistrationConfigsLocation, + "GetAccountDynamicRegistrationConfig", + ) + logRequest(logger, ctx) + + accountClaims, serviceErr := getAccountClaims(ctx) + if serviceErr != nil { + return serviceErrorResponse(logger, ctx, serviceErr) + } + + dto, serviceErr := c.services.GetAccountDynamicRegistrationConfig( + ctx.UserContext(), + services.GetAccountDynamicRegistrationConfigOptions{ + RequestID: requestID, + AccountPublicID: accountClaims.AccountID, + }, + ) + if serviceErr != nil { + return serviceErrorResponse(logger, ctx, serviceErr) + } + + logResponse(logger, ctx, fiber.StatusOK) + return ctx.Status(fiber.StatusOK).JSON(&dto) +} diff --git a/idp/internal/controllers/apps.go b/idp/internal/controllers/apps.go index ecea21d..7b1c0cf 100644 --- a/idp/internal/controllers/apps.go +++ b/idp/internal/controllers/apps.go @@ -41,6 +41,7 @@ func (c *Controllers) createWebApp( baseBody *bodies.CreateAppBodyBase, ) error { logger := c.buildLogger(requestID, appsLocation, "createWebApp") + logRequest(logger, ctx) body := new(bodies.CreateAppBodyWeb) if err := ctx.BodyParser(body); err != nil { diff --git a/idp/internal/controllers/bodies/account_credentials.go b/idp/internal/controllers/bodies/account_credentials.go index 2609e6b..e8b6a0a 100644 --- a/idp/internal/controllers/bodies/account_credentials.go +++ b/idp/internal/controllers/bodies/account_credentials.go @@ -10,6 +10,7 @@ type CreateAccountCredentialsBody struct { Type string `json:"type" validate:"required,oneof=native service mcp"` Name string `json:"name" validate:"required,min=1,max=255"` Scopes []string `json:"scopes" validate:"required,unique,dive,oneof=email profile account:admin account:users:read account:users:write account:apps:read account:apps:write account:credentials:read account:credentials:write account:auth_providers:read"` + Transport string `json:"transport,omitempty" validate:"required_if=Type mcp,oneof=http https stdio streamable_http"` TokenEndpointAuthMethod string `json:"token_endpoint_auth_method" validate:"required,oneof=client_secret_basic client_secret_post client_secret_jwt private_key_jwt"` Domain string `json:"domain,omitempty" validate:"omitempty,fqdn,max=250"` ClientURI string `json:"client_uri" validate:"required,uri"` @@ -25,7 +26,7 @@ type CreateAccountCredentialsBody struct { type UpdateAccountCredentialsBody struct { Name string `json:"name" validate:"required,min=1,max=255"` Scopes []string `json:"scopes" validate:"required,unique,dive,oneof=account:admin account:users:read account:users:write account:apps:read account:apps:write account:credentials:read account:credentials:write account:auth_providers:read"` - Transport string `json:"transport" validate:"required,oneof=http https stdio streamable_http"` + Transport string `json:"transport,omitempty" validate:"omitempty,oneof=http https"` Domain string `json:"domain,omitempty" validate:"omitempty,fqdn,max=250"` ClientURI string `json:"client_uri" validate:"required,uri"` RedirectURIs []string `json:"redirect_uris,omitempty" validate:"omitempty,unique,dive,uri"` diff --git a/idp/internal/controllers/bodies/account_dynamics_registration_configs.go b/idp/internal/controllers/bodies/account_dynamics_registration_configs.go new file mode 100644 index 0000000..409f3a2 --- /dev/null +++ b/idp/internal/controllers/bodies/account_dynamics_registration_configs.go @@ -0,0 +1,16 @@ +// Copyright (c) 2025 Afonso Barracha +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +package bodies + +type AccountDynamicRegistrationConfigBody struct { + AccountCredentialsTypes []string `json:"account_credentials_types" validate:"required,unique,min=1,max=3,oneof=native service mcp"` + WhitelistedDomains []string `json:"whitelisted_domains" validate:"omitempty,unique,min=1,max=250,dive,fqdn"` + RequireSoftwareStatementCredentialTypes []string `json:"require_software_statement_credential_types" validate:"omitempty,unique,min=1,max=3,oneof=native service mcp"` + SoftwareStatementVerificationMethods []string `json:"software_statement_verification_methods" validate:"omitempty,unique,min=1,max=2,oneof=manual jwks_uri"` + RequireInitialAccessTokenCredentialTypes []string `json:"require_initial_access_token_credential_types" validate:"omitempty,unique,min=1,max=3,oneof=native service mcp"` + InitialAccessTokenGenerationMethods []string `json:"initial_access_token_generation_methods" validate:"omitempty,unique,min=1,max=2,oneof=manual authorization_code"` +} diff --git a/idp/internal/controllers/middleware.go b/idp/internal/controllers/middleware.go index ef93a3c..0cafc4e 100644 --- a/idp/internal/controllers/middleware.go +++ b/idp/internal/controllers/middleware.go @@ -195,7 +195,11 @@ func (c *Controllers) AdminScopeMiddleware(ctx *fiber.Ctx) error { return ctx.Next() } -func processHost(host string) (string, error) { +func processHost(backendDomain string, host string) (string, error) { + if !strings.HasSuffix(host, "."+backendDomain) { + return "", errors.New("invalid host") + } + hostArr := strings.Split(host, ".") if len(hostArr) < 2 { return "", errors.New("host must contain at least two parts") @@ -212,12 +216,13 @@ func processHost(host string) (string, error) { func (c *Controllers) AccountHostMiddleware(ctx *fiber.Ctx) error { requestID := getRequestID(ctx) logger := c.buildLogger(requestID, middlewareLocation, "AccountHostMiddleware") - host := ctx.Get("Host") + host := ctx.Hostname() if host == "" { + logger.DebugContext(ctx.UserContext(), "no host found") return serviceErrorResponse(logger, ctx, exceptions.NewNotFoundError()) } - username, err := processHost(host) + username, err := processHost(c.backendDomain, host) if err != nil { logger.DebugContext(ctx.UserContext(), "invalid host", "error", err) return serviceErrorResponse(logger, ctx, exceptions.NewNotFoundError()) diff --git a/idp/internal/controllers/paths/common.go b/idp/internal/controllers/paths/common.go index 3029f3d..bb5e9c8 100644 --- a/idp/internal/controllers/paths/common.go +++ b/idp/internal/controllers/paths/common.go @@ -11,4 +11,5 @@ const ( Keys string = "/keys" Confirm string = "/confirm" Recover string = "/recover" + Config string = "/config" ) diff --git a/idp/internal/controllers/paths/dynamic_registration.go b/idp/internal/controllers/paths/dynamic_registration.go new file mode 100644 index 0000000..b4cffa6 --- /dev/null +++ b/idp/internal/controllers/paths/dynamic_registration.go @@ -0,0 +1,13 @@ +// Copyright (c) 2025 Afonso Barracha +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +package paths + +const ( + DynamicRegistrationBase string = "/dynamic-registration" + + Domains string = "/domains" +) diff --git a/idp/internal/controllers/paths/oauth.go b/idp/internal/controllers/paths/oauth.go index 6703810..61ddfeb 100644 --- a/idp/internal/controllers/paths/oauth.go +++ b/idp/internal/controllers/paths/oauth.go @@ -14,6 +14,7 @@ const ( OAuthUserInfo string = "/userinfo" OAuthToken string = "/token" OAuthRevoke string = "/revoke" + OAuthRegister string = "/register" OAuthIntrospect string = "/introspect" OAuthDeviceAuth string = "/auth/device" diff --git a/idp/internal/providers/database/account_dynamic_registration_configs.sql.go b/idp/internal/providers/database/account_dynamic_registration_configs.sql.go new file mode 100644 index 0000000..8a0e69a --- /dev/null +++ b/idp/internal/providers/database/account_dynamic_registration_configs.sql.go @@ -0,0 +1,185 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.29.0 +// source: account_dynamic_registration_configs.sql + +package database + +import ( + "context" + + "github.com/google/uuid" +) + +const createAccountDynamicRegistrationConfig = `-- name: CreateAccountDynamicRegistrationConfig :one + +INSERT INTO "account_dynamic_registration_configs" ( + "account_id", + "account_public_id", + "account_credentials_types", + "whitelisted_domains", + "require_software_statement_credential_types", + "software_statement_verification_methods", + "require_initial_access_token_credential_types", + "initial_access_token_generation_methods" +) VALUES ( + $1, + $2, + $3, + $4, + $5, + $6, + $7, + $8 +) RETURNING id, account_id, account_public_id, account_credentials_types, whitelisted_domains, require_software_statement_credential_types, software_statement_verification_methods, require_initial_access_token_credential_types, initial_access_token_generation_methods, created_at, updated_at +` + +type CreateAccountDynamicRegistrationConfigParams struct { + AccountID int32 + AccountPublicID uuid.UUID + AccountCredentialsTypes []AccountCredentialsType + WhitelistedDomains []string + RequireSoftwareStatementCredentialTypes []AccountCredentialsType + SoftwareStatementVerificationMethods []SoftwareStatementVerificationMethod + RequireInitialAccessTokenCredentialTypes []AccountCredentialsType + InitialAccessTokenGenerationMethods []InitialAccessTokenGenerationMethod +} + +// Copyright (c) 2025 Afonso Barracha +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. +func (q *Queries) CreateAccountDynamicRegistrationConfig(ctx context.Context, arg CreateAccountDynamicRegistrationConfigParams) (AccountDynamicRegistrationConfig, error) { + row := q.db.QueryRow(ctx, createAccountDynamicRegistrationConfig, + arg.AccountID, + arg.AccountPublicID, + arg.AccountCredentialsTypes, + arg.WhitelistedDomains, + arg.RequireSoftwareStatementCredentialTypes, + arg.SoftwareStatementVerificationMethods, + arg.RequireInitialAccessTokenCredentialTypes, + arg.InitialAccessTokenGenerationMethods, + ) + var i AccountDynamicRegistrationConfig + err := row.Scan( + &i.ID, + &i.AccountID, + &i.AccountPublicID, + &i.AccountCredentialsTypes, + &i.WhitelistedDomains, + &i.RequireSoftwareStatementCredentialTypes, + &i.SoftwareStatementVerificationMethods, + &i.RequireInitialAccessTokenCredentialTypes, + &i.InitialAccessTokenGenerationMethods, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + +const deleteAccountDynamicRegistrationConfig = `-- name: DeleteAccountDynamicRegistrationConfig :exec +DELETE FROM "account_dynamic_registration_configs" WHERE "id" = $1 +` + +func (q *Queries) DeleteAccountDynamicRegistrationConfig(ctx context.Context, id int32) error { + _, err := q.db.Exec(ctx, deleteAccountDynamicRegistrationConfig, id) + return err +} + +const findAccountDynamicRegistrationConfigByAccountID = `-- name: FindAccountDynamicRegistrationConfigByAccountID :one +SELECT id, account_id, account_public_id, account_credentials_types, whitelisted_domains, require_software_statement_credential_types, software_statement_verification_methods, require_initial_access_token_credential_types, initial_access_token_generation_methods, created_at, updated_at FROM "account_dynamic_registration_configs" +WHERE "account_id" = $1 LIMIT 1 +` + +func (q *Queries) FindAccountDynamicRegistrationConfigByAccountID(ctx context.Context, accountID int32) (AccountDynamicRegistrationConfig, error) { + row := q.db.QueryRow(ctx, findAccountDynamicRegistrationConfigByAccountID, accountID) + var i AccountDynamicRegistrationConfig + err := row.Scan( + &i.ID, + &i.AccountID, + &i.AccountPublicID, + &i.AccountCredentialsTypes, + &i.WhitelistedDomains, + &i.RequireSoftwareStatementCredentialTypes, + &i.SoftwareStatementVerificationMethods, + &i.RequireInitialAccessTokenCredentialTypes, + &i.InitialAccessTokenGenerationMethods, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + +const findAccountDynamicRegistrationConfigByAccountPublicID = `-- name: FindAccountDynamicRegistrationConfigByAccountPublicID :one +SELECT id, account_id, account_public_id, account_credentials_types, whitelisted_domains, require_software_statement_credential_types, software_statement_verification_methods, require_initial_access_token_credential_types, initial_access_token_generation_methods, created_at, updated_at FROM "account_dynamic_registration_configs" +WHERE "account_public_id" = $1 LIMIT 1 +` + +func (q *Queries) FindAccountDynamicRegistrationConfigByAccountPublicID(ctx context.Context, accountPublicID uuid.UUID) (AccountDynamicRegistrationConfig, error) { + row := q.db.QueryRow(ctx, findAccountDynamicRegistrationConfigByAccountPublicID, accountPublicID) + var i AccountDynamicRegistrationConfig + err := row.Scan( + &i.ID, + &i.AccountID, + &i.AccountPublicID, + &i.AccountCredentialsTypes, + &i.WhitelistedDomains, + &i.RequireSoftwareStatementCredentialTypes, + &i.SoftwareStatementVerificationMethods, + &i.RequireInitialAccessTokenCredentialTypes, + &i.InitialAccessTokenGenerationMethods, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + +const updateAccountDynamicRegistrationConfig = `-- name: UpdateAccountDynamicRegistrationConfig :one +UPDATE "account_dynamic_registration_configs" SET + "account_credentials_types" = $2, + "whitelisted_domains" = $3, + "require_software_statement_credential_types" = $4, + "software_statement_verification_methods" = $5, + "require_initial_access_token_credential_types" = $6, + "initial_access_token_generation_methods" = $7 +WHERE "id" = $1 +RETURNING id, account_id, account_public_id, account_credentials_types, whitelisted_domains, require_software_statement_credential_types, software_statement_verification_methods, require_initial_access_token_credential_types, initial_access_token_generation_methods, created_at, updated_at +` + +type UpdateAccountDynamicRegistrationConfigParams struct { + ID int32 + AccountCredentialsTypes []AccountCredentialsType + WhitelistedDomains []string + RequireSoftwareStatementCredentialTypes []AccountCredentialsType + SoftwareStatementVerificationMethods []SoftwareStatementVerificationMethod + RequireInitialAccessTokenCredentialTypes []AccountCredentialsType + InitialAccessTokenGenerationMethods []InitialAccessTokenGenerationMethod +} + +func (q *Queries) UpdateAccountDynamicRegistrationConfig(ctx context.Context, arg UpdateAccountDynamicRegistrationConfigParams) (AccountDynamicRegistrationConfig, error) { + row := q.db.QueryRow(ctx, updateAccountDynamicRegistrationConfig, + arg.ID, + arg.AccountCredentialsTypes, + arg.WhitelistedDomains, + arg.RequireSoftwareStatementCredentialTypes, + arg.SoftwareStatementVerificationMethods, + arg.RequireInitialAccessTokenCredentialTypes, + arg.InitialAccessTokenGenerationMethods, + ) + var i AccountDynamicRegistrationConfig + err := row.Scan( + &i.ID, + &i.AccountID, + &i.AccountPublicID, + &i.AccountCredentialsTypes, + &i.WhitelistedDomains, + &i.RequireSoftwareStatementCredentialTypes, + &i.SoftwareStatementVerificationMethods, + &i.RequireInitialAccessTokenCredentialTypes, + &i.InitialAccessTokenGenerationMethods, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} diff --git a/idp/internal/providers/database/migrations/20241213231542_create_initial_schema.up.sql b/idp/internal/providers/database/migrations/20241213231542_create_initial_schema.up.sql index 69a2d91..d96b79d 100644 --- a/idp/internal/providers/database/migrations/20241213231542_create_initial_schema.up.sql +++ b/idp/internal/providers/database/migrations/20241213231542_create_initial_schema.up.sql @@ -1,6 +1,6 @@ -- SQL dump generated using DBML (dbml.dbdiagram.io) -- Database: PostgreSQL --- Generated at: 2025-08-12T22:36:20.011Z +-- Generated at: 2025-08-13T20:40:02.036Z CREATE TYPE "kek_usage" AS ENUM ( 'global', @@ -534,10 +534,12 @@ CREATE TABLE "app_designs" ( CREATE TABLE "account_dynamic_registration_configs" ( "id" serial PRIMARY KEY, "account_id" integer NOT NULL, + "account_public_id" uuid NOT NULL, + "account_credentials_types" account_credentials_type[] NOT NULL, "whitelisted_domains" varchar(250)[] NOT NULL, - "require_software_statement" boolean NOT NULL, + "require_software_statement_credential_types" account_credentials_type[] NOT NULL, "software_statement_verification_methods" software_statement_verification_method[] NOT NULL, - "require_initial_access_token" boolean NOT NULL, + "require_initial_access_token_credential_types" account_credentials_type[] NOT NULL, "initial_access_token_generation_methods" initial_access_token_generation_method[] NOT NULL, "created_at" timestamptz NOT NULL DEFAULT (now()), "updated_at" timestamptz NOT NULL DEFAULT (now()) @@ -827,7 +829,9 @@ CREATE INDEX "app_designs_account_id_idx" ON "app_designs" ("account_id"); CREATE UNIQUE INDEX "app_designs_app_id_uidx" ON "app_designs" ("app_id"); -CREATE INDEX "account_dynamic_registration_configs_account_id_idx" ON "account_dynamic_registration_configs" ("account_id"); +CREATE UNIQUE INDEX "account_dynamic_registration_configs_account_id_uidx" ON "account_dynamic_registration_configs" ("account_id"); + +CREATE INDEX "account_dynamic_registration_configs_account_public_id_idx" ON "account_dynamic_registration_configs" ("account_public_id"); CREATE INDEX "app_dynamic_registration_configs_account_id_idx" ON "app_dynamic_registration_configs" ("account_id"); diff --git a/idp/internal/providers/database/models.go b/idp/internal/providers/database/models.go index bfbb4ff..f52c300 100644 --- a/idp/internal/providers/database/models.go +++ b/idp/internal/providers/database/models.go @@ -1198,15 +1198,17 @@ type AccountDataEncryptionKey struct { } type AccountDynamicRegistrationConfig struct { - ID int32 - AccountID int32 - WhitelistedDomains []string - RequireSoftwareStatement bool - SoftwareStatementVerificationMethods []SoftwareStatementVerificationMethod - RequireInitialAccessToken bool - InitialAccessTokenGenerationMethods []InitialAccessTokenGenerationMethod - CreatedAt time.Time - UpdatedAt time.Time + ID int32 + AccountID int32 + AccountPublicID uuid.UUID + AccountCredentialsTypes []AccountCredentialsType + WhitelistedDomains []string + RequireSoftwareStatementCredentialTypes []AccountCredentialsType + SoftwareStatementVerificationMethods []SoftwareStatementVerificationMethod + RequireInitialAccessTokenCredentialTypes []AccountCredentialsType + InitialAccessTokenGenerationMethods []InitialAccessTokenGenerationMethod + CreatedAt time.Time + UpdatedAt time.Time } type AccountKeyEncryptionKey struct { diff --git a/idp/internal/providers/database/queries/account_dynamic_registration_configs.sql b/idp/internal/providers/database/queries/account_dynamic_registration_configs.sql new file mode 100644 index 0000000..697d1a3 --- /dev/null +++ b/idp/internal/providers/database/queries/account_dynamic_registration_configs.sql @@ -0,0 +1,48 @@ +-- Copyright (c) 2025 Afonso Barracha +-- +-- This Source Code Form is subject to the terms of the Mozilla Public +-- License, v. 2.0. If a copy of the MPL was not distributed with this +-- file, You can obtain one at https://mozilla.org/MPL/2.0/. + +-- name: CreateAccountDynamicRegistrationConfig :one +INSERT INTO "account_dynamic_registration_configs" ( + "account_id", + "account_public_id", + "account_credentials_types", + "whitelisted_domains", + "require_software_statement_credential_types", + "software_statement_verification_methods", + "require_initial_access_token_credential_types", + "initial_access_token_generation_methods" +) VALUES ( + $1, + $2, + $3, + $4, + $5, + $6, + $7, + $8 +) RETURNING *; + +-- name: UpdateAccountDynamicRegistrationConfig :one +UPDATE "account_dynamic_registration_configs" SET + "account_credentials_types" = $2, + "whitelisted_domains" = $3, + "require_software_statement_credential_types" = $4, + "software_statement_verification_methods" = $5, + "require_initial_access_token_credential_types" = $6, + "initial_access_token_generation_methods" = $7 +WHERE "id" = $1 +RETURNING *; + +-- name: FindAccountDynamicRegistrationConfigByAccountID :one +SELECT * FROM "account_dynamic_registration_configs" +WHERE "account_id" = $1 LIMIT 1; + +-- name: FindAccountDynamicRegistrationConfigByAccountPublicID :one +SELECT * FROM "account_dynamic_registration_configs" +WHERE "account_public_id" = $1 LIMIT 1; + +-- name: DeleteAccountDynamicRegistrationConfig :exec +DELETE FROM "account_dynamic_registration_configs" WHERE "id" = $1; \ No newline at end of file diff --git a/idp/internal/server/routes.go b/idp/internal/server/routes.go index 4cde724..5b1729d 100644 --- a/idp/internal/server/routes.go +++ b/idp/internal/server/routes.go @@ -8,6 +8,7 @@ package server func (s *FiberServer) RegisterFiberRoutes() { s.routes.HealthRoutes(s.App) + s.routes.AccountDynamicRegistrationRoutes(s.App) s.routes.OAuthRoutes(s.App) s.routes.AuthRoutes(s.App) s.routes.AccountCredentialsRoutes(s.App) diff --git a/idp/internal/server/routes/account_dynamic_registration.go b/idp/internal/server/routes/account_dynamic_registration.go new file mode 100644 index 0000000..3d0c647 --- /dev/null +++ b/idp/internal/server/routes/account_dynamic_registration.go @@ -0,0 +1,36 @@ +// Copyright (c) 2025 Afonso Barracha +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +package routes + +import ( + "github.com/gofiber/fiber/v2" + + "github.com/tugascript/devlogs/idp/internal/controllers/paths" + "github.com/tugascript/devlogs/idp/internal/providers/tokens" +) + +func (r *Routes) AccountDynamicRegistrationRoutes(app *fiber.App) { + router := v1PathRouter(app).Group( + paths.AccountsBase+paths.CredentialsBase+paths.DynamicRegistrationBase, + r.controllers.AccountAccessClaimsMiddleware, + r.controllers.AdminScopeMiddleware, + ) + + credentialsWriteScopeMiddleware := r.controllers.ScopeMiddleware(tokens.AccountScopeCredentialsWrite) + credentialsReadScopeMiddleware := r.controllers.ScopeMiddleware(tokens.AccountScopeCredentialsRead) + + router.Get( + paths.Config, + credentialsReadScopeMiddleware, + r.controllers.GetAccountDynamicRegistrationConfig, + ) + router.Put( + paths.Config, + credentialsWriteScopeMiddleware, + r.controllers.UpsertAccountDynamicRegistrationConfig, + ) +} diff --git a/idp/internal/services/account_credentials_registration.go b/idp/internal/services/account_credentials_registration.go new file mode 100644 index 0000000..1df7cfd --- /dev/null +++ b/idp/internal/services/account_credentials_registration.go @@ -0,0 +1,7 @@ +// Copyright (c) 2025 Afonso Barracha +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +package services diff --git a/idp/internal/services/account_dynamic_registration_configs.go b/idp/internal/services/account_dynamic_registration_configs.go new file mode 100644 index 0000000..a5d353f --- /dev/null +++ b/idp/internal/services/account_dynamic_registration_configs.go @@ -0,0 +1,233 @@ +// Copyright (c) 2025 Afonso Barracha +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +package services + +import ( + "context" + + "github.com/google/uuid" + + "github.com/tugascript/devlogs/idp/internal/exceptions" + "github.com/tugascript/devlogs/idp/internal/providers/database" + "github.com/tugascript/devlogs/idp/internal/services/dtos" + "github.com/tugascript/devlogs/idp/internal/utils" +) + +const ( + accountDynamicRegistrationConfigsLocation string = "account_dynamic_registration_configs" + + softwareStatementVerificationMethodJwksUri string = "jwks_uri" + softwareStatementVerificationMethodManual string = "manual" + + initialAccessTokenGenerationMethodAuthorizationCode string = "authorization_code" + initialAccessTokenGenerationMethodManual string = "manual" +) + +func mapAccountCredentialsTypes(credentialsTypes []string) ([]database.AccountCredentialsType, *exceptions.ServiceError) { + accountCredentialsTypes := make([]database.AccountCredentialsType, 0, len(credentialsTypes)) + for _, credentialsType := range credentialsTypes { + accountCredentialsType, serviceErr := mapAccountCredentialsType(credentialsType) + if serviceErr != nil { + return nil, serviceErr + } + accountCredentialsTypes = append(accountCredentialsTypes, accountCredentialsType) + } + return accountCredentialsTypes, nil +} + +func mapSoftwareStatementVerificationMethod( + softwareStatementVerificationMethod string, +) (database.SoftwareStatementVerificationMethod, *exceptions.ServiceError) { + switch softwareStatementVerificationMethod { + case softwareStatementVerificationMethodJwksUri: + return database.SoftwareStatementVerificationMethodJwksUri, nil + case softwareStatementVerificationMethodManual: + return database.SoftwareStatementVerificationMethodManual, nil + default: + return "", exceptions.NewValidationError("Invalid software statement verification method: " + softwareStatementVerificationMethod) + } +} + +func mapSoftwareStatementVerificationMethods( + ssvms []string, +) ([]database.SoftwareStatementVerificationMethod, *exceptions.ServiceError) { + softwareStatementVerificationMethods := make([]database.SoftwareStatementVerificationMethod, 0, len(ssvms)) + for _, ssvm := range ssvms { + softwareStatementVerificationMethod, serviceErr := mapSoftwareStatementVerificationMethod(ssvm) + if serviceErr != nil { + return nil, serviceErr + } + softwareStatementVerificationMethods = append(softwareStatementVerificationMethods, softwareStatementVerificationMethod) + } + return softwareStatementVerificationMethods, nil +} + +func mapInitialAccessTokenGenerationMethod( + initialAccessTokenGenerationMethod string, +) (database.InitialAccessTokenGenerationMethod, *exceptions.ServiceError) { + switch initialAccessTokenGenerationMethod { + case initialAccessTokenGenerationMethodAuthorizationCode: + return database.InitialAccessTokenGenerationMethodAuthorizationCode, nil + case initialAccessTokenGenerationMethodManual: + return database.InitialAccessTokenGenerationMethodManual, nil + default: + return "", exceptions.NewValidationError("Invalid initial access token generation method: " + initialAccessTokenGenerationMethod) + } +} + +func mapInitialAccessTokenGenerationMethods( + iatgms []string, +) ([]database.InitialAccessTokenGenerationMethod, *exceptions.ServiceError) { + initialAccessTokenGenerationMethods := make([]database.InitialAccessTokenGenerationMethod, 0, len(iatgms)) + for _, iatgm := range iatgms { + initialAccessTokenGenerationMethod, serviceErr := mapInitialAccessTokenGenerationMethod(iatgm) + if serviceErr != nil { + return nil, serviceErr + } + initialAccessTokenGenerationMethods = append(initialAccessTokenGenerationMethods, initialAccessTokenGenerationMethod) + } + return initialAccessTokenGenerationMethods, nil +} + +type SaveAccountDynamicRegistrationConfigOptions struct { + RequestID string + AccountPublicID uuid.UUID + AccountVersion int32 + AccountCredentialsTypes []string + WhitelistedDomains []string + RequireSoftwareStatementCredentialTypes []string + SoftwareStatementVerificationMethods []string + RequireInitialAccessTokenCredentialTypes []string + InitialAccessTokenGenerationMethods []string +} + +func (s *Services) SaveAccountDynamicRegistrationConfig( + ctx context.Context, + opts SaveAccountDynamicRegistrationConfigOptions, +) (dtos.AccountDynamicRegistrationConfigDTO, bool, *exceptions.ServiceError) { + logger := s.buildLogger(opts.RequestID, accountDynamicRegistrationConfigsLocation, "CreateAccountDynamicRegistrationConfig").With( + "accountPublicID", opts.AccountPublicID, + "accountVersion", opts.AccountVersion, + ) + logger.InfoContext(ctx, "Creating account dynamic registration config...") + + credentialsTypes, serviceErr := mapAccountCredentialsTypes(opts.AccountCredentialsTypes) + if serviceErr != nil { + logger.WarnContext(ctx, "Failed to map credentials types", "serviceError", serviceErr) + return dtos.AccountDynamicRegistrationConfigDTO{}, false, serviceErr + } + + requireSoftwareStatementCredentialTypes, serviceErr := mapAccountCredentialsTypes(opts.RequireSoftwareStatementCredentialTypes) + if serviceErr != nil { + logger.WarnContext(ctx, "Failed to map require software statement credential types", "serviceError", serviceErr) + return dtos.AccountDynamicRegistrationConfigDTO{}, false, serviceErr + } + + requireInitialAccessTokenCredentialTypes, serviceErr := mapAccountCredentialsTypes(opts.RequireInitialAccessTokenCredentialTypes) + if serviceErr != nil { + logger.WarnContext(ctx, "Failed to map require initial access token credential types", "serviceError", serviceErr) + return dtos.AccountDynamicRegistrationConfigDTO{}, false, serviceErr + } + + softwareStatementVerificationMethods, serviceErr := mapSoftwareStatementVerificationMethods(opts.SoftwareStatementVerificationMethods) + if serviceErr != nil { + logger.WarnContext(ctx, "Failed to map software statement verification methods", "serviceError", serviceErr) + return dtos.AccountDynamicRegistrationConfigDTO{}, false, serviceErr + } + + initialAccessTokenGenerationMethods, serviceErr := mapInitialAccessTokenGenerationMethods(opts.InitialAccessTokenGenerationMethods) + if serviceErr != nil { + logger.WarnContext(ctx, "Failed to map initial access token generation methods", "serviceError", serviceErr) + return dtos.AccountDynamicRegistrationConfigDTO{}, false, serviceErr + } + + accountID, serviceErr := s.GetAccountIDByPublicIDAndVersion(ctx, GetAccountIDByPublicIDAndVersionOptions{ + RequestID: opts.RequestID, + PublicID: opts.AccountPublicID, + Version: opts.AccountVersion, + }) + if serviceErr != nil { + logger.WarnContext(ctx, "Failed to get account", "serviceError", serviceErr) + return dtos.AccountDynamicRegistrationConfigDTO{}, false, serviceErr + } + + accountDynamicRegistrationConfig, err := s.database.FindAccountDynamicRegistrationConfigByAccountID(ctx, accountID) + if err != nil { + serviceErr := exceptions.FromDBError(err) + if serviceErr.Code != exceptions.CodeNotFound { + logger.ErrorContext(ctx, "Failed to find account dynamic registration config", "error", err) + return dtos.AccountDynamicRegistrationConfigDTO{}, false, serviceErr + } + + logger.InfoContext(ctx, "Account dynamic registration config not found, creating new one...") + accountDynamicRegistrationConfig, err = s.database.CreateAccountDynamicRegistrationConfig( + ctx, + database.CreateAccountDynamicRegistrationConfigParams{ + AccountID: accountID, + AccountPublicID: opts.AccountPublicID, + AccountCredentialsTypes: credentialsTypes, + WhitelistedDomains: utils.ToEmptySlice(opts.WhitelistedDomains), + RequireSoftwareStatementCredentialTypes: requireSoftwareStatementCredentialTypes, + SoftwareStatementVerificationMethods: softwareStatementVerificationMethods, + RequireInitialAccessTokenCredentialTypes: requireInitialAccessTokenCredentialTypes, + InitialAccessTokenGenerationMethods: initialAccessTokenGenerationMethods, + }, + ) + if err != nil { + logger.ErrorContext(ctx, "Failed to create account dynamic registration config", "error", err) + return dtos.AccountDynamicRegistrationConfigDTO{}, false, exceptions.FromDBError(err) + } + + return dtos.MapAccountDynamicRegistrationConfigToDTO(&accountDynamicRegistrationConfig), true, nil + + } + + accountDynamicRegistrationConfig, err = s.database.UpdateAccountDynamicRegistrationConfig(ctx, database.UpdateAccountDynamicRegistrationConfigParams{ + ID: accountDynamicRegistrationConfig.ID, + AccountCredentialsTypes: credentialsTypes, + WhitelistedDomains: utils.ToEmptySlice(opts.WhitelistedDomains), + RequireSoftwareStatementCredentialTypes: requireSoftwareStatementCredentialTypes, + SoftwareStatementVerificationMethods: softwareStatementVerificationMethods, + RequireInitialAccessTokenCredentialTypes: requireInitialAccessTokenCredentialTypes, + InitialAccessTokenGenerationMethods: initialAccessTokenGenerationMethods, + }) + if err != nil { + logger.ErrorContext(ctx, "Failed to update account dynamic registration config", "error", err) + return dtos.AccountDynamicRegistrationConfigDTO{}, false, exceptions.FromDBError(err) + } + + return dtos.MapAccountDynamicRegistrationConfigToDTO(&accountDynamicRegistrationConfig), false, nil +} + +type GetAccountDynamicRegistrationConfigOptions struct { + RequestID string + AccountPublicID uuid.UUID +} + +func (s *Services) GetAccountDynamicRegistrationConfig( + ctx context.Context, + opts GetAccountDynamicRegistrationConfigOptions, +) (dtos.AccountDynamicRegistrationConfigDTO, *exceptions.ServiceError) { + logger := s.buildLogger(opts.RequestID, accountDynamicRegistrationConfigsLocation, "GetAccountDynamicRegistrationConfig").With( + "accountPublicID", opts.AccountPublicID, + ) + logger.InfoContext(ctx, "Retrieving account dynamic registration config...") + + accountDynamicRegistrationConfig, err := s.database.FindAccountDynamicRegistrationConfigByAccountPublicID(ctx, opts.AccountPublicID) + if err != nil { + serviceErr := exceptions.FromDBError(err) + if serviceErr.Code != exceptions.CodeNotFound { + logger.ErrorContext(ctx, "Failed to find account dynamic registration config", "error", err) + return dtos.AccountDynamicRegistrationConfigDTO{}, serviceErr + } + + logger.InfoContext(ctx, "Account dynamic registration config not found", "error", err) + return dtos.AccountDynamicRegistrationConfigDTO{}, nil + } + + return dtos.MapAccountDynamicRegistrationConfigToDTO(&accountDynamicRegistrationConfig), nil +} diff --git a/idp/internal/services/dtos/account_dynamic_registration_config.go b/idp/internal/services/dtos/account_dynamic_registration_config.go new file mode 100644 index 0000000..964ece5 --- /dev/null +++ b/idp/internal/services/dtos/account_dynamic_registration_config.go @@ -0,0 +1,38 @@ +// Copyright (c) 2025 Afonso Barracha +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +package dtos + +import "github.com/tugascript/devlogs/idp/internal/providers/database" + +type AccountDynamicRegistrationConfigDTO struct { + id int32 + + CredentialsTypes []database.AccountCredentialsType `json:"credentials_types"` + WhitelistedDomains []string `json:"whitelisted_domains"` + RequireSoftwareStatementCredentialTypes []database.AccountCredentialsType `json:"require_software_statement_credential_types"` + SoftwareStatementVerificationMethods []database.SoftwareStatementVerificationMethod `json:"software_statement_verification_methods"` + RequireInitialAccessTokenCredentialTypes []database.AccountCredentialsType `json:"require_initial_access_token_credential_types"` + InitialAccessTokenGenerationMethods []database.InitialAccessTokenGenerationMethod `json:"initial_access_token_generation_methods"` +} + +func (a *AccountDynamicRegistrationConfigDTO) ID() int32 { + return a.id +} + +func MapAccountDynamicRegistrationConfigToDTO( + config *database.AccountDynamicRegistrationConfig, +) AccountDynamicRegistrationConfigDTO { + return AccountDynamicRegistrationConfigDTO{ + id: config.ID, + CredentialsTypes: config.AccountCredentialsTypes, + WhitelistedDomains: config.WhitelistedDomains, + RequireSoftwareStatementCredentialTypes: config.RequireSoftwareStatementCredentialTypes, + SoftwareStatementVerificationMethods: config.SoftwareStatementVerificationMethods, + RequireInitialAccessTokenCredentialTypes: config.RequireInitialAccessTokenCredentialTypes, + InitialAccessTokenGenerationMethods: config.InitialAccessTokenGenerationMethods, + } +} diff --git a/idp/tests/account_credentials_test.go b/idp/tests/account_credentials_test.go index 52de15e..654ab98 100644 --- a/idp/tests/account_credentials_test.go +++ b/idp/tests/account_credentials_test.go @@ -8,11 +8,8 @@ package tests import ( "context" - rand2 "math/rand/v2" "net/http" - "strings" "testing" - "time" "github.com/google/uuid" @@ -31,7 +28,7 @@ func accountCredentialsCleanUp(t *testing.T) func() { db := GetTestDatabase(t) if err := db.DeleteAllAccountCredentials(context.Background()); err != nil { - t.Fatal("Failed to delete all accounts", err) + t.Fatal("Failed to delete all account credentials", err) } if err := db.DeleteAllCredentialsKeys(context.Background()); err != nil { t.Fatal("Failed to delete all credentials keys", err) @@ -50,14 +47,19 @@ func TestCreateAccountCredentials(t *testing.T) { testCases := []TestRequestCase[bodies.CreateAccountCredentialsBody]{ { - Name: "Should return 201 CREATED with secret and client_secret_jwt", + Name: "Should create service credentials with client_secret_jwt", ReqFn: func(t *testing.T) (bodies.CreateAccountCredentialsBody, string) { account := CreateTestAccount(t, GenerateFakeAccountData(t, services.AuthProviderGoogle)) accessToken, _ := GenerateTestAccountAuthTokens(t, &account) return bodies.CreateAccountCredentialsBody{ - Scopes: []string{"account:admin"}, - Alias: "admin", - AuthMethod: "client_secret_jwt", + Type: "service", + Name: "admin-service", + Scopes: []string{"account:admin"}, + TokenEndpointAuthMethod: "client_secret_jwt", + Transport: "https", + ClientURI: "https://admin.example.com", + SoftwareID: "admin-service", + SoftwareVersion: "1.0.0", }, accessToken }, ExpStatus: http.StatusCreated, @@ -69,20 +71,25 @@ func TestCreateAccountCredentials(t *testing.T) { AssertNotEmpty(t, resBody.ClientSecretExp) AssertEmpty(t, resBody.ClientSecretJWK) AssertEqual(t, resBody.TokenEndpointAuthMethod, database.AuthMethodClientSecretJwt) - AssertEqual(t, len(resBody.Issuers), 0) + AssertEqual(t, resBody.Type, database.AccountCredentialsTypeService) + AssertEqual(t, resBody.Transport, database.TransportHttps) }, }, { - Name: "Should return 201 CREATED with secret and private key JWT with ES256 algorithm", + Name: "Should create service credentials with private_key_jwt and ES256", ReqFn: func(t *testing.T) (bodies.CreateAccountCredentialsBody, string) { account := CreateTestAccount(t, GenerateFakeAccountData(t, services.AuthProviderGoogle)) accessToken, _ := GenerateTestAccountAuthTokens(t, &account) return bodies.CreateAccountCredentialsBody{ - Scopes: []string{"account:credentials:read", "account:credentials:write"}, - Alias: "super-key", - AuthMethod: "private_key_jwt", - Issuers: []string{"https://issuer.example.com"}, - Algorithm: "ES256", + Type: "service", + Name: "super-service", + Scopes: []string{"account:credentials:read", "account:credentials:write"}, + TokenEndpointAuthMethod: "private_key_jwt", + Transport: "https", + ClientURI: "https://super.example.com", + SoftwareID: "super-service", + SoftwareVersion: "2.0.0", + Algorithm: "ES256", }, accessToken }, ExpStatus: http.StatusCreated, @@ -94,19 +101,24 @@ func TestCreateAccountCredentials(t *testing.T) { AssertNotEmpty(t, resBody.ClientSecretExp) AssertNotEmpty(t, resBody.ClientSecretJWK) AssertEqual(t, resBody.TokenEndpointAuthMethod, database.AuthMethodPrivateKeyJwt) + AssertEqual(t, resBody.Type, database.AccountCredentialsTypeService) }, }, { - Name: "Should return 201 CREATED with secret and private key JWT with EdDSA algorithm", + Name: "Should create service credentials with private_key_jwt and EdDSA", ReqFn: func(t *testing.T) (bodies.CreateAccountCredentialsBody, string) { account := CreateTestAccount(t, GenerateFakeAccountData(t, services.AuthProviderGoogle)) accessToken, _ := GenerateTestAccountAuthTokens(t, &account) return bodies.CreateAccountCredentialsBody{ - Scopes: []string{"account:credentials:read", "account:credentials:write"}, - Alias: "super-key", - AuthMethod: "private_key_jwt", - Issuers: []string{"https://issuer.example.com"}, - Algorithm: "EdDSA", + Type: "service", + Name: "eddsa-service", + Scopes: []string{"account:credentials:read", "account:credentials:write"}, + TokenEndpointAuthMethod: "private_key_jwt", + Transport: "https", + ClientURI: "https://eddsa.example.com", + SoftwareID: "eddsa-service", + SoftwareVersion: "1.0.0", + Algorithm: "EdDSA", }, accessToken }, ExpStatus: http.StatusCreated, @@ -118,18 +130,23 @@ func TestCreateAccountCredentials(t *testing.T) { AssertNotEmpty(t, resBody.ClientSecretExp) AssertNotEmpty(t, resBody.ClientSecretJWK) AssertEqual(t, resBody.TokenEndpointAuthMethod, database.AuthMethodPrivateKeyJwt) + AssertEqual(t, resBody.Type, database.AccountCredentialsTypeService) }, }, { - Name: "Should return 201 CREATED with secret and private key JWT with default algorithm", + Name: "Should create service credentials with client_secret_post", ReqFn: func(t *testing.T) (bodies.CreateAccountCredentialsBody, string) { account := CreateTestAccount(t, GenerateFakeAccountData(t, services.AuthProviderGoogle)) accessToken, _ := GenerateTestAccountAuthTokens(t, &account) return bodies.CreateAccountCredentialsBody{ - Scopes: []string{"account:credentials:read", "account:credentials:write"}, - Alias: "super-key", - AuthMethod: "private_key_jwt", - Issuers: []string{"https://issuer.example.com"}, + Type: "service", + Name: "app-service", + Scopes: []string{"account:apps:read", "account:apps:write"}, + TokenEndpointAuthMethod: "client_secret_post", + Transport: "https", + ClientURI: "https://app.example.com", + SoftwareID: "app-service", + SoftwareVersion: "1.0.0", }, accessToken }, ExpStatus: http.StatusCreated, @@ -137,21 +154,27 @@ func TestCreateAccountCredentials(t *testing.T) { resBody := AssertTestResponseBody(t, res, dtos.AccountCredentialsDTO{}) AssertNotEmpty(t, resBody.ClientID) AssertNotEmpty(t, resBody.ClientSecretID) - AssertEmpty(t, resBody.ClientSecret) + AssertNotEmpty(t, resBody.ClientSecret) AssertNotEmpty(t, resBody.ClientSecretExp) - AssertNotEmpty(t, resBody.ClientSecretJWK) - AssertEqual(t, resBody.TokenEndpointAuthMethod, database.AuthMethodPrivateKeyJwt) + AssertEmpty(t, resBody.ClientSecretJWK) + AssertEqual(t, resBody.TokenEndpointAuthMethod, database.AuthMethodClientSecretPost) + AssertEqual(t, resBody.Type, database.AccountCredentialsTypeService) }, }, { - Name: "Should return 201 CREATED with secret and client secret post", + Name: "Should create service credentials with client_secret_basic", ReqFn: func(t *testing.T) (bodies.CreateAccountCredentialsBody, string) { account := CreateTestAccount(t, GenerateFakeAccountData(t, services.AuthProviderGoogle)) - accessToken, _ := GenerateTestAccountAuthTokens(t, &account) + accessToken := GenerateScopedAccountAccessToken(t, &account, []string{tokens.AccountScopeCredentialsWrite}) return bodies.CreateAccountCredentialsBody{ - Scopes: []string{"account:apps:read", "account:apps:write"}, - Alias: "app-keys", - AuthMethod: "client_secret_post", + Type: "service", + Name: "user-service", + Scopes: []string{"account:users:read", "account:users:write"}, + TokenEndpointAuthMethod: "client_secret_basic", + Transport: "https", + ClientURI: "https://user.example.com", + SoftwareID: "user-service", + SoftwareVersion: "1.0.0", }, accessToken }, ExpStatus: http.StatusCreated, @@ -162,18 +185,24 @@ func TestCreateAccountCredentials(t *testing.T) { AssertNotEmpty(t, resBody.ClientSecret) AssertNotEmpty(t, resBody.ClientSecretExp) AssertEmpty(t, resBody.ClientSecretJWK) - AssertEqual(t, resBody.TokenEndpointAuthMethod, database.AuthMethodClientSecretPost) + AssertEqual(t, resBody.TokenEndpointAuthMethod, database.AuthMethodClientSecretBasic) + AssertEqual(t, resBody.Type, database.AccountCredentialsTypeService) }, }, { - Name: "Should return 201 CREATED with secret and client secret basic", + Name: "Should create MCP credentials with streamable_http transport", ReqFn: func(t *testing.T) (bodies.CreateAccountCredentialsBody, string) { account := CreateTestAccount(t, GenerateFakeAccountData(t, services.AuthProviderGoogle)) - accessToken := GenerateScopedAccountAccessToken(t, &account, []string{tokens.AccountScopeCredentialsWrite}) + accessToken, _ := GenerateTestAccountAuthTokens(t, &account) return bodies.CreateAccountCredentialsBody{ - Scopes: []string{"account:users:read", "account:users:write"}, - Alias: "user-keys", - AuthMethod: "client_secret_basic", + Type: "mcp", + Name: "mcp-client", + Scopes: []string{"account:admin"}, + TokenEndpointAuthMethod: "client_secret_basic", + Transport: "streamable_http", + ClientURI: "https://mcp.example.com", + SoftwareID: "mcp-client", + SoftwareVersion: "1.0.0", }, accessToken }, ExpStatus: http.StatusCreated, @@ -185,86 +214,136 @@ func TestCreateAccountCredentials(t *testing.T) { AssertNotEmpty(t, resBody.ClientSecretExp) AssertEmpty(t, resBody.ClientSecretJWK) AssertEqual(t, resBody.TokenEndpointAuthMethod, database.AuthMethodClientSecretBasic) + AssertEqual(t, resBody.Type, database.AccountCredentialsTypeMcp) + AssertEqual(t, resBody.Transport, database.TransportStreamableHttp) }, }, { - Name: "Should return 400 BAD REQUEST with auth method of private_key_jwt but no issuers", + Name: "Should create MCP credentials with stdio transport", ReqFn: func(t *testing.T) (bodies.CreateAccountCredentialsBody, string) { account := CreateTestAccount(t, GenerateFakeAccountData(t, services.AuthProviderGoogle)) accessToken, _ := GenerateTestAccountAuthTokens(t, &account) return bodies.CreateAccountCredentialsBody{ - Scopes: []string{"account:credentials:read", "account:credentials:write"}, - Alias: "super-key", - AuthMethod: "private_key_jwt", + Type: "mcp", + Name: "mcp-stdio", + Scopes: []string{"account:admin"}, + TokenEndpointAuthMethod: "private_key_jwt", + Transport: "stdio", + ClientURI: "https://mcp-stdio.example.com", + SoftwareID: "mcp-stdio", + SoftwareVersion: "1.0.0", + Algorithm: "ES256", + }, accessToken + }, + ExpStatus: http.StatusCreated, + AssertFn: func(t *testing.T, _ bodies.CreateAccountCredentialsBody, res *http.Response) { + resBody := AssertTestResponseBody(t, res, dtos.AccountCredentialsDTO{}) + AssertNotEmpty(t, resBody.ClientID) + AssertNotEmpty(t, resBody.ClientSecretID) + AssertEmpty(t, resBody.ClientSecret) + AssertNotEmpty(t, resBody.ClientSecretExp) + AssertNotEmpty(t, resBody.ClientSecretJWK) + AssertEqual(t, resBody.TokenEndpointAuthMethod, database.AuthMethodPrivateKeyJwt) + AssertEqual(t, resBody.Type, database.AccountCredentialsTypeMcp) + AssertEqual(t, resBody.Transport, database.TransportStdio) + }, + }, + { + Name: "Should reject native credentials creation", + ReqFn: func(t *testing.T) (bodies.CreateAccountCredentialsBody, string) { + account := CreateTestAccount(t, GenerateFakeAccountData(t, services.AuthProviderGoogle)) + accessToken, _ := GenerateTestAccountAuthTokens(t, &account) + return bodies.CreateAccountCredentialsBody{ + Type: "native", + Name: "native-client", + Scopes: []string{"account:admin"}, + TokenEndpointAuthMethod: "client_secret_basic", + Transport: "https", + ClientURI: "https://native.example.com", + SoftwareID: "native-client", + SoftwareVersion: "1.0.0", }, accessToken }, ExpStatus: http.StatusBadRequest, AssertFn: func(t *testing.T, _ bodies.CreateAccountCredentialsBody, res *http.Response) { - resBody := AssertTestResponseBody(t, res, exceptions.ValidationErrorResponse{}) - AssertEqual(t, len(resBody.Fields), 1) - AssertEqual(t, resBody.Fields[0].Param, "issuers") + resBody := AssertTestResponseBody(t, res, exceptions.ErrorResponse{}) + AssertEqual(t, resBody.Message, "Native credentials are not supported") }, }, { - Name: "Should return 400 BAD REQUEST with bad values", + Name: "Should return 400 BAD REQUEST with invalid data", ReqFn: func(t *testing.T) (bodies.CreateAccountCredentialsBody, string) { account := CreateTestAccount(t, GenerateFakeAccountData(t, services.AuthProviderGoogle)) accessToken, _ := GenerateTestAccountAuthTokens(t, &account) return bodies.CreateAccountCredentialsBody{ - Scopes: []string{"invalid:scope", "account:users:readsd"}, - Alias: "invalid asdfasd ### scope", - AuthMethod: "client_secret_not_valid", - Issuers: []string{"https://issuer.example.com"}, + Type: "service", + Name: "", + Scopes: []string{"invalid:scope", "account:users:readsd"}, + TokenEndpointAuthMethod: "invalid_auth_method", + Transport: "invalid_transport", + ClientURI: "not-a-uri", + SoftwareID: "", + SoftwareVersion: "", }, accessToken }, ExpStatus: http.StatusBadRequest, AssertFn: func(t *testing.T, _ bodies.CreateAccountCredentialsBody, res *http.Response) { resBody := AssertTestResponseBody(t, res, exceptions.ValidationErrorResponse{}) - AssertEqual(t, len(resBody.Fields), 4) - AssertEqual(t, resBody.Fields[0].Param, "scopes[0]") - AssertEqual(t, resBody.Fields[1].Param, "scopes[1]") - AssertEqual(t, resBody.Fields[2].Param, "alias") - AssertEqual(t, resBody.Fields[3].Param, "auth_method") + AssertEqual(t, len(resBody.Fields) >= 5, true) }, }, { - Name: "Should return 409 CONFLICT with existing alias", + Name: "Should return 409 CONFLICT with existing name", ReqFn: func(t *testing.T) (bodies.CreateAccountCredentialsBody, string) { account := CreateTestAccount(t, GenerateFakeAccountData(t, services.AuthProviderGoogle)) accessToken, _ := GenerateTestAccountAuthTokens(t, &account) + // Create initial credentials if _, err := GetTestServices(t).CreateAccountCredentials(context.Background(), services.CreateAccountCredentialsOptions{ RequestID: uuid.NewString(), AccountPublicID: account.PublicID, AccountVersion: account.Version(), - Name: "existing-alias", - Scopes: []string{"account:users:read", "account:users:write"}, - AuthMethod: "private_key_jwt", + CredentialsType: "service", + Name: "existing-name", + Scopes: []string{"account:admin"}, + AuthMethod: "client_secret_basic", + Transport: "https", + ClientURI: "https://existing.example.com", + SoftwareID: "existing-service", + SoftwareVersion: "1.0.0", }); err != nil { t.Fatal("Failed to create initial account credentials", err) } return bodies.CreateAccountCredentialsBody{ - Scopes: []string{"account:admin"}, - Alias: "existing-alias", - AuthMethod: "client_secret_basic", - Issuers: []string{"https://issuer.example.com"}, + Type: "service", + Name: "existing-name", + Scopes: []string{"account:admin"}, + TokenEndpointAuthMethod: "client_secret_basic", + Transport: "https", + ClientURI: "https://new.example.com", + SoftwareID: "new-service", + SoftwareVersion: "1.0.0", }, accessToken }, ExpStatus: http.StatusConflict, AssertFn: func(t *testing.T, _ bodies.CreateAccountCredentialsBody, res *http.Response) { resBody := AssertTestResponseBody(t, res, exceptions.ErrorResponse{}) - AssertEqual(t, resBody.Message, "Account credentials alias already exists") + AssertEqual(t, resBody.Message, "Account credentials name already exists") }, }, { Name: "Should return 401 UNAUTHORIZED without access token", ReqFn: func(t *testing.T) (bodies.CreateAccountCredentialsBody, string) { return bodies.CreateAccountCredentialsBody{ - Scopes: []string{"account:credentials:write", "account:auth_providers:read"}, - Alias: "user-keys", - AuthMethod: "client_secret_basic", - Issuers: []string{"https://issuer.example.com"}, + Type: "service", + Name: "unauthorized-service", + Scopes: []string{"account:credentials:write", "account:auth_providers:read"}, + TokenEndpointAuthMethod: "client_secret_basic", + Transport: "https", + ClientURI: "https://unauthorized.example.com", + SoftwareID: "unauthorized-service", + SoftwareVersion: "1.0.0", }, "" }, ExpStatus: http.StatusUnauthorized, @@ -276,10 +355,14 @@ func TestCreateAccountCredentials(t *testing.T) { account := CreateTestAccount(t, GenerateFakeAccountData(t, services.AuthProviderGoogle)) accessToken := GenerateScopedAccountAccessToken(t, &account, []string{tokens.AccountScopeCredentialsRead}) return bodies.CreateAccountCredentialsBody{ - Scopes: []string{"account:apps:read", "account:apps:write"}, - Alias: "app-keys", - AuthMethod: "client_secret_post", - Issuers: []string{"https://issuer.example.com"}, + Type: "service", + Name: "forbidden-service", + Scopes: []string{"account:apps:read", "account:apps:write"}, + TokenEndpointAuthMethod: "client_secret_post", + Transport: "https", + ClientURI: "https://forbidden.example.com", + SoftwareID: "forbidden-service", + SoftwareVersion: "1.0.0", }, accessToken }, ExpStatus: http.StatusForbidden, @@ -296,173 +379,11 @@ func TestCreateAccountCredentials(t *testing.T) { t.Cleanup(accountCredentialsCleanUp(t)) } -func TestListAccountCredentials(t *testing.T) { - const accountCredentialsPath = v1Path + paths.AccountsBase + paths.CredentialsBase - - listAccountBeforeEach := func(t *testing.T, n int) string { - account := CreateTestAccount(t, GenerateFakeAccountData(t, services.AuthProviderGoogle)) - accessToken := GenerateScopedAccountAccessToken(t, &account, []string{tokens.AccountScopeCredentialsRead, tokens.AccountScopeCredentialsWrite}) - - authMethodsList := []string{ - "client_secret_basic", - "client_secret_post", - "client_secret_jwt", - "private_key_jwt", - } - scopesList := [][]string{ - {"account:admin"}, - {"account:credentials:read", "account:credentials:write"}, - {"account:apps:read", "account:apps:write"}, - {"account:users:read", "account:users:write"}, - } - issuersList := [][]string{ - {"https://issuer1.example.com"}, - {"https://issuer2.example.com"}, - {"https://issuer3.example.com"}, - } - - for i := 0; i < n; i++ { - authMethods := authMethodsList[rand2.IntN(len(authMethodsList))] - scopes := scopesList[rand2.IntN(len(scopesList))] - issuers := issuersList[rand2.IntN(len(issuersList))] - alias := "cred-" + uuid.NewString() - - _, err := GetTestServices(t).CreateAccountCredentials(context.Background(), services.CreateAccountCredentialsOptions{ - RequestID: uuid.NewString(), - AccountPublicID: account.PublicID, - AccountVersion: account.Version(), - Name: alias, - Scopes: scopes, - AuthMethod: authMethods, - Issuers: issuers, - }) - if err != nil { - t.Fatalf("Failed to create account credentials: %v", err) - } - } - - return accessToken - } - - testCases := []TestRequestCase[any]{ - { - Name: "Should return 200 OK without any account credentials", - ReqFn: func(t *testing.T) (any, string) { - account := CreateTestAccount(t, GenerateFakeAccountData(t, services.AuthProviderGoogle)) - accessToken, _ := GenerateTestAccountAuthTokens(t, &account) - return nil, accessToken - }, - ExpStatus: http.StatusOK, - AssertFn: func(t *testing.T, _ any, res *http.Response) { - resBody := AssertTestResponseBody(t, res, dtos.PaginationDTO[dtos.AccountCredentialsDTO]{}) - AssertEqual(t, len(resBody.Items), 0) - AssertEmpty(t, resBody.Next) - AssertEmpty(t, resBody.Previous) - AssertEqual(t, resBody.Total, 0) - }, - Path: accountCredentialsPath, - }, - { - Name: "Should return 200 OK with paginated account credentials", - ReqFn: func(t *testing.T) (any, string) { - accessToken := listAccountBeforeEach(t, 30) - return nil, accessToken - }, - ExpStatus: http.StatusOK, - AssertFn: func(t *testing.T, _ any, res *http.Response) { - resBody := AssertTestResponseBody(t, res, dtos.PaginationDTO[dtos.AccountCredentialsDTO]{}) - AssertEqual(t, len(resBody.Items), 20) - AssertEqual(t, resBody.Total, 30) - AssertEqual( - t, - strings.Split(resBody.Next, GetTestConfig(t).BackendDomain())[1], - "/v1/accounts/credentials?offset=20&limit=20", - ) - AssertEmpty(t, resBody.Previous) - }, - Path: accountCredentialsPath, - }, - { - Name: "Should return 200 OK with paginated account credentials and previous link", - ReqFn: func(t *testing.T) (any, string) { - accessToken := listAccountBeforeEach(t, 12) - return nil, accessToken - }, - ExpStatus: http.StatusOK, - AssertFn: func(t *testing.T, _ any, res *http.Response) { - resBody := AssertTestResponseBody(t, res, dtos.PaginationDTO[dtos.AccountCredentialsDTO]{}) - AssertEqual(t, len(resBody.Items), 2) - AssertEqual(t, resBody.Total, 12) - AssertEmpty(t, resBody.Next) - AssertEqual( - t, - strings.Split(resBody.Previous, GetTestConfig(t).BackendDomain())[1], - "/v1/accounts/credentials?offset=0&limit=20", - ) - }, - Path: accountCredentialsPath + "?offset=10&limit=20", - }, - { - Name: "Should return 200 OK with paginated account with next and previous link", - ReqFn: func(t *testing.T) (any, string) { - accessToken := listAccountBeforeEach(t, 20) - return nil, accessToken - }, - ExpStatus: http.StatusOK, - AssertFn: func(t *testing.T, _ any, res *http.Response) { - resBody := AssertTestResponseBody(t, res, dtos.PaginationDTO[dtos.AccountCredentialsDTO]{}) - backendDomain := GetTestConfig(t).BackendDomain() - AssertEqual(t, len(resBody.Items), 5) - AssertEqual(t, resBody.Total, 20) - AssertEqual( - t, - strings.Split(resBody.Next, backendDomain)[1], - "/v1/accounts/credentials?offset=15&limit=5", - ) - AssertEqual( - t, - strings.Split(resBody.Previous, backendDomain)[1], - "/v1/accounts/credentials?offset=5&limit=5", - ) - }, - Path: accountCredentialsPath + "?offset=10&limit=5", - }, - { - Name: "Should return 401 UNAUTHORIZED without access token", - ReqFn: func(t *testing.T) (any, string) { - return nil, "" - }, - ExpStatus: http.StatusUnauthorized, - AssertFn: AssertUnauthorizedError[any], - Path: accountCredentialsPath, - }, - { - Name: "Should return 403 FORBIDDEN without account:credentials:read scope", - ReqFn: func(t *testing.T) (any, string) { - account := CreateTestAccount(t, GenerateFakeAccountData(t, services.AuthProviderGoogle)) - accessToken := GenerateScopedAccountAccessToken(t, &account, []string{tokens.AccountScopeCredentialsWrite}) - return nil, accessToken - }, - ExpStatus: http.StatusForbidden, - AssertFn: AssertForbiddenError[any], - Path: accountCredentialsPath, - }, - } - - for _, tc := range testCases { - t.Run(tc.Name, func(t *testing.T) { - PerformTestRequestCase(t, http.MethodGet, tc.Path, tc) - }) - } - - t.Cleanup(accountCredentialsCleanUp(t)) -} - -func TestGetAccountCredentials(t *testing.T) { +func TestUpdateAccountCredentials(t *testing.T) { const accountCredentialsPath = v1Path + paths.AccountsBase + paths.CredentialsBase var clientID string - getAccountCredentialBeforeEach := func(t *testing.T) string { + updateAccountCredentialBeforeEach := func(t *testing.T) string { account := CreateTestAccount(t, GenerateFakeAccountData(t, services.AuthProviderGoogle)) accessToken := GenerateScopedAccountAccessToken(t, &account, []string{tokens.AccountScopeCredentialsRead, tokens.AccountScopeCredentialsWrite}) @@ -470,141 +391,87 @@ func TestGetAccountCredentials(t *testing.T) { RequestID: uuid.NewString(), AccountPublicID: account.PublicID, AccountVersion: account.Version(), - Name: "get-cred", + CredentialsType: "service", + Name: "update-cred", Scopes: []string{"account:admin"}, AuthMethod: "client_secret_basic", - Issuers: []string{"https://issuer.example.com"}, + Transport: "https", + ClientURI: "https://update.example.com", + SoftwareID: "update-service", + SoftwareVersion: "1.0.0", }) if err != nil { - t.Fatalf("Failed to create account credentials: %v", err) + t.Fatalf("Failed to create initial account credentials: %v", err) } clientID = cred.ClientID return accessToken } - testCases := []TestRequestCase[any]{ + testCases := []TestRequestCase[bodies.UpdateAccountCredentialsBody]{ { - Name: "Should return 200 OK with account credential", - ReqFn: func(t *testing.T) (any, string) { - accessToken := getAccountCredentialBeforeEach(t) - return nil, accessToken + Name: "Should update service credentials name and scopes", + ReqFn: func(t *testing.T) (bodies.UpdateAccountCredentialsBody, string) { + accessToken := updateAccountCredentialBeforeEach(t) + return bodies.UpdateAccountCredentialsBody{ + Name: "updated-service-name", + Scopes: []string{"account:users:read"}, + Transport: "https", + ClientURI: "https://updated.example.com", + SoftwareVersion: "2.0.0", + }, accessToken }, ExpStatus: http.StatusOK, - AssertFn: func(t *testing.T, _ any, res *http.Response) { + AssertFn: func(t *testing.T, _ bodies.UpdateAccountCredentialsBody, res *http.Response) { resBody := AssertTestResponseBody(t, res, dtos.AccountCredentialsDTO{}) - AssertNotEmpty(t, resBody.ClientID) - AssertNotEmpty(t, resBody.Alias) - AssertEqual(t, resBody.TokenEndpointAuthMethod, database.AuthMethodClientSecretBasic) - AssertEmpty(t, resBody.ClientSecret) - AssertEmpty(t, resBody.ClientSecretJWK) - }, - PathFn: func() string { - return accountCredentialsPath + "/" + clientID - }, - }, - { - Name: "Should return 404 NOT FOUND for non-existent credential", - ReqFn: func(t *testing.T) (any, string) { - account := CreateTestAccount(t, GenerateFakeAccountData(t, services.AuthProviderGoogle)) - accessToken := GenerateScopedAccountAccessToken(t, &account, []string{tokens.AccountScopeCredentialsRead}) - return nil, accessToken - }, - ExpStatus: http.StatusNotFound, - AssertFn: AssertNotFoundError[any], - PathFn: func() string { - return accountCredentialsPath + "/" + utils.Base62UUID() - }, - }, - { - Name: "Should return 401 UNAUTHORIZED without access token", - ReqFn: func(t *testing.T) (any, string) { - getAccountCredentialBeforeEach(t) - return nil, "" + AssertEqual(t, resBody.Name, "updated-service-name") + AssertEqual(t, len(resBody.Scopes), 1) + AssertEqual(t, resBody.Scopes[0], "account:users:read") + AssertEqual(t, resBody.SoftwareVersion, "2.0.0") }, - ExpStatus: http.StatusUnauthorized, - AssertFn: AssertUnauthorizedError[any], PathFn: func() string { return accountCredentialsPath + "/" + clientID }, }, { - Name: "Should return 403 FORBIDDEN without account:credentials:read scope", - ReqFn: func(t *testing.T) (any, string) { + Name: "Should update MCP credentials scopes and software version", + ReqFn: func(t *testing.T) (bodies.UpdateAccountCredentialsBody, string) { account := CreateTestAccount(t, GenerateFakeAccountData(t, services.AuthProviderGoogle)) - accessToken := GenerateScopedAccountAccessToken(t, &account, []string{tokens.AccountScopeCredentialsWrite}) + accessToken := GenerateScopedAccountAccessToken(t, &account, []string{tokens.AccountScopeCredentialsRead, tokens.AccountScopeCredentialsWrite}) + + // Create MCP credentials first cred, err := GetTestServices(t).CreateAccountCredentials(context.Background(), services.CreateAccountCredentialsOptions{ RequestID: uuid.NewString(), AccountPublicID: account.PublicID, AccountVersion: account.Version(), - Name: "forbidden-cred", + CredentialsType: "mcp", + Name: "mcp-update", Scopes: []string{"account:admin"}, AuthMethod: "client_secret_basic", + Transport: "streamable_http", + ClientURI: "https://mcp-update.example.com", + SoftwareID: "mcp-update", + SoftwareVersion: "1.0.0", }) if err != nil { - t.Fatalf("Failed to create account credentials: %v", err) + t.Fatalf("Failed to create MCP credentials: %v", err) } clientID = cred.ClientID - return nil, accessToken - }, - ExpStatus: http.StatusForbidden, - AssertFn: AssertForbiddenError[any], - PathFn: func() string { - return accountCredentialsPath + "/" + clientID - }, - }, - } - - for _, tc := range testCases { - t.Run(tc.Name, func(t *testing.T) { - PerformTestRequestCaseWithPathFn(t, http.MethodGet, tc) - }) - } - - t.Cleanup(accountCredentialsCleanUp(t)) -} - -func TestUpdateAccountCredentials(t *testing.T) { - const accountCredentialsPath = v1Path + paths.AccountsBase + paths.CredentialsBase - - var clientID string - updateAccountCredentialBeforeEach := func(t *testing.T) string { - account := CreateTestAccount(t, GenerateFakeAccountData(t, services.AuthProviderGoogle)) - accessToken := GenerateScopedAccountAccessToken(t, &account, []string{tokens.AccountScopeCredentialsRead, tokens.AccountScopeCredentialsWrite}) - - cred, err := GetTestServices(t).CreateAccountCredentials(context.Background(), services.CreateAccountCredentialsOptions{ - RequestID: uuid.NewString(), - AccountPublicID: account.PublicID, - AccountVersion: account.Version(), - Name: "update-cred", - Scopes: []string{"account:admin"}, - AuthMethod: "client_secret_basic", - Issuers: []string{"https://issuer.example.com"}, - }) - if err != nil { - t.Fatalf("Failed to create account credentials: %v", err) - } - clientID = cred.ClientID - return accessToken - } - testCases := []TestRequestCase[bodies.UpdateAccountCredentialsBody]{ - { - Name: "Should return 200 OK and update alias and scopes", - ReqFn: func(t *testing.T) (bodies.UpdateAccountCredentialsBody, string) { - accessToken := updateAccountCredentialBeforeEach(t) return bodies.UpdateAccountCredentialsBody{ - Alias: "updated-alias", - Scopes: []string{"account:users:read"}, - Issuers: []string{"https://issuer-updated.example.com"}, + Name: "updated-mcp-name", + Scopes: []string{"account:users:read", "account:apps:read"}, + ClientURI: "https://updated-mcp.example.com", + SoftwareVersion: "2.0.0", }, accessToken }, ExpStatus: http.StatusOK, AssertFn: func(t *testing.T, _ bodies.UpdateAccountCredentialsBody, res *http.Response) { resBody := AssertTestResponseBody(t, res, dtos.AccountCredentialsDTO{}) - AssertEqual(t, resBody.Alias, "updated-alias") - AssertEqual(t, len(resBody.Scopes), 1) + AssertEqual(t, resBody.Name, "updated-mcp-name") + AssertEqual(t, len(resBody.Scopes), 2) AssertEqual(t, resBody.Scopes[0], "account:users:read") - AssertEqual(t, resBody.Issuers[0], "https://issuer-updated.example.com") + AssertEqual(t, resBody.Scopes[1], "account:apps:read") + AssertEqual(t, resBody.SoftwareVersion, "2.0.0") }, PathFn: func() string { return accountCredentialsPath + "/" + clientID @@ -615,63 +482,76 @@ func TestUpdateAccountCredentials(t *testing.T) { ReqFn: func(t *testing.T) (bodies.UpdateAccountCredentialsBody, string) { accessToken := updateAccountCredentialBeforeEach(t) return bodies.UpdateAccountCredentialsBody{ - Alias: "invalid alias ###", - Scopes: []string{"account:users:read", "invalid:scope"}, - Issuers: []string{"https://issuer-updated.example.com"}, + Name: "", + Scopes: []string{"account:users:read", "invalid:scope"}, + Transport: "invalid_transport", + ClientURI: "not-a-uri", + SoftwareVersion: "", }, accessToken }, ExpStatus: http.StatusBadRequest, AssertFn: func(t *testing.T, _ bodies.UpdateAccountCredentialsBody, res *http.Response) { resBody := AssertTestResponseBody(t, res, exceptions.ValidationErrorResponse{}) - AssertEqual(t, len(resBody.Fields), 2) - AssertEqual(t, resBody.Fields[0].Param, "scopes[1]") - AssertEqual(t, resBody.Fields[1].Param, "alias") + AssertEqual(t, len(resBody.Fields) >= 3, true) }, PathFn: func() string { return accountCredentialsPath + "/" + clientID }, }, { - Name: "Should return 409 conflict and update alias and scopes", + Name: "Should return 409 CONFLICT with existing name", ReqFn: func(t *testing.T) (bodies.UpdateAccountCredentialsBody, string) { - account := CreateTestAccount(t, GenerateFakeAccountData(t, services.AuthProviderFacebook)) + account := CreateTestAccount(t, GenerateFakeAccountData(t, services.AuthProviderGoogle)) accessToken := GenerateScopedAccountAccessToken(t, &account, []string{tokens.AccountScopeCredentialsWrite}) - testS := GetTestServices(t) - if _, err := testS.CreateAccountCredentials(context.Background(), services.CreateAccountCredentialsOptions{ + + // Create first credentials + if _, err := GetTestServices(t).CreateAccountCredentials(context.Background(), services.CreateAccountCredentialsOptions{ RequestID: uuid.NewString(), AccountPublicID: account.PublicID, AccountVersion: account.Version(), - Name: "existing-alias", - Scopes: []string{"account:users:read"}, + CredentialsType: "service", + Name: "existing-name", + Scopes: []string{"account:admin"}, AuthMethod: "client_secret_basic", - Issuers: []string{"updated.example.com"}, + Transport: "https", + ClientURI: "https://existing.example.com", + SoftwareID: "existing-service", + SoftwareVersion: "1.0.0", }); err != nil { - t.Fatalf("Failed to create initial account credentials: %v", err) + t.Fatalf("Failed to create first credentials: %v", err) } - clientCreds, err := testS.CreateAccountCredentials(context.Background(), services.CreateAccountCredentialsOptions{ + // Create second credentials to update + cred, err := GetTestServices(t).CreateAccountCredentials(context.Background(), services.CreateAccountCredentialsOptions{ RequestID: uuid.NewString(), AccountPublicID: account.PublicID, AccountVersion: account.Version(), - Name: "other-alias", - Scopes: []string{"account:users:read"}, - Issuers: []string{"https://updated.example.com"}, + CredentialsType: "service", + Name: "other-name", + Scopes: []string{"account:admin"}, AuthMethod: "client_secret_basic", + Transport: "https", + ClientURI: "https://other.example.com", + SoftwareID: "other-service", + SoftwareVersion: "1.0.0", }) if err != nil { - t.Fatalf("Failed to create initial account credentials: %v", err) + t.Fatalf("Failed to create second credentials: %v", err) } - clientID = clientCreds.ClientID + clientID = cred.ClientID + return bodies.UpdateAccountCredentialsBody{ - Alias: "existing-alias", - Scopes: []string{"account:users:read"}, - Issuers: []string{"https://issuer-updated.example.com"}, + Name: "existing-name", + Scopes: []string{"account:users:read"}, + Transport: "https", + ClientURI: "https://updated.example.com", + SoftwareVersion: "2.0.0", }, accessToken }, ExpStatus: http.StatusConflict, AssertFn: func(t *testing.T, _ bodies.UpdateAccountCredentialsBody, res *http.Response) { resBody := AssertTestResponseBody(t, res, exceptions.ErrorResponse{}) - AssertEqual(t, resBody.Message, "Account credentials alias already exists") + AssertEqual(t, resBody.Message, "Account credentials name already exists") }, PathFn: func() string { return accountCredentialsPath + "/" + clientID @@ -683,9 +563,11 @@ func TestUpdateAccountCredentials(t *testing.T) { account := CreateTestAccount(t, GenerateFakeAccountData(t, services.AuthProviderGoogle)) accessToken := GenerateScopedAccountAccessToken(t, &account, []string{tokens.AccountScopeCredentialsWrite}) return bodies.UpdateAccountCredentialsBody{ - Alias: "new-alias", - Scopes: []string{"account:users:read"}, - Issuers: []string{"https://issuer-updated.example.com"}, + Name: "new-name", + Scopes: []string{"account:users:read"}, + Transport: "https", + ClientURI: "https://new.example.com", + SoftwareVersion: "1.0.0", }, accessToken }, ExpStatus: http.StatusNotFound, @@ -699,9 +581,11 @@ func TestUpdateAccountCredentials(t *testing.T) { ReqFn: func(t *testing.T) (bodies.UpdateAccountCredentialsBody, string) { updateAccountCredentialBeforeEach(t) return bodies.UpdateAccountCredentialsBody{ - Alias: "updated-alias", - Scopes: []string{"account:users:read"}, - Issuers: []string{"https://issuer-updated.example.com"}, + Name: "updated-name", + Scopes: []string{"account:users:read"}, + Transport: "https", + ClientURI: "https://updated.example.com", + SoftwareVersion: "2.0.0", }, "" }, ExpStatus: http.StatusUnauthorized, @@ -715,23 +599,31 @@ func TestUpdateAccountCredentials(t *testing.T) { ReqFn: func(t *testing.T) (bodies.UpdateAccountCredentialsBody, string) { account := CreateTestAccount(t, GenerateFakeAccountData(t, services.AuthProviderGoogle)) accessToken := GenerateScopedAccountAccessToken(t, &account, []string{tokens.AccountScopeCredentialsRead}) + cred, err := GetTestServices(t).CreateAccountCredentials(context.Background(), services.CreateAccountCredentialsOptions{ RequestID: uuid.NewString(), AccountPublicID: account.PublicID, AccountVersion: account.Version(), - Name: "forbidden-cred", + CredentialsType: "service", + Name: "forbidden-update", Scopes: []string{"account:admin"}, AuthMethod: "client_secret_basic", - Issuers: []string{"https://issuer.example.com"}, + Transport: "https", + ClientURI: "https://forbidden.example.com", + SoftwareID: "forbidden-service", + SoftwareVersion: "1.0.0", }) if err != nil { - t.Fatalf("Failed to create account credentials: %v", err) + t.Fatalf("Failed to create credentials: %v", err) } clientID = cred.ClientID + return bodies.UpdateAccountCredentialsBody{ - Alias: "updated-alias", - Scopes: []string{"account:users:read"}, - Issuers: []string{"https://issuer-updated.example.com"}, + Name: "updated-name", + Scopes: []string{"account:users:read"}, + Transport: "https", + ClientURI: "https://updated.example.com", + SoftwareVersion: "2.0.0", }, accessToken }, ExpStatus: http.StatusForbidden, @@ -751,11 +643,114 @@ func TestUpdateAccountCredentials(t *testing.T) { t.Cleanup(accountCredentialsCleanUp(t)) } -func TestDeleteAccountCredentials(t *testing.T) { +func TestListAccountCredentials(t *testing.T) { + const accountCredentialsPath = v1Path + paths.AccountsBase + paths.CredentialsBase + + listAccountBeforeEach := func(t *testing.T, n int) string { + account := CreateTestAccount(t, GenerateFakeAccountData(t, services.AuthProviderGoogle)) + accessToken := GenerateScopedAccountAccessToken(t, &account, []string{tokens.AccountScopeCredentialsRead, tokens.AccountScopeCredentialsWrite}) + + types := []string{"service", "mcp"} + authMethods := []string{"client_secret_basic", "client_secret_post", "client_secret_jwt", "private_key_jwt"} + transports := []string{"https", "streamable_http", "stdio"} + + for i := 0; i < n; i++ { + credType := types[i%len(types)] + authMethod := authMethods[i%len(authMethods)] + transport := transports[i%len(transports)] + name := "cred-" + uuid.NewString() + + _, err := GetTestServices(t).CreateAccountCredentials(context.Background(), services.CreateAccountCredentialsOptions{ + RequestID: uuid.NewString(), + AccountPublicID: account.PublicID, + AccountVersion: account.Version(), + CredentialsType: credType, + Name: name, + Scopes: []string{"account:admin"}, + AuthMethod: authMethod, + Transport: transport, + ClientURI: "https://" + name + ".example.com", + SoftwareID: name + "-service", + SoftwareVersion: "1.0.0", + }) + if err != nil { + t.Fatalf("Failed to create account credentials: %v", err) + } + } + + return accessToken + } + + testCases := []TestRequestCase[any]{ + { + Name: "Should return 200 OK without any account credentials", + ReqFn: func(t *testing.T) (any, string) { + account := CreateTestAccount(t, GenerateFakeAccountData(t, services.AuthProviderGoogle)) + accessToken, _ := GenerateTestAccountAuthTokens(t, &account) + return nil, accessToken + }, + ExpStatus: http.StatusOK, + AssertFn: func(t *testing.T, _ any, res *http.Response) { + resBody := AssertTestResponseBody(t, res, dtos.PaginationDTO[dtos.AccountCredentialsDTO]{}) + AssertEqual(t, len(resBody.Items), 0) + AssertEmpty(t, resBody.Next) + AssertEmpty(t, resBody.Previous) + AssertEqual(t, resBody.Total, 0) + }, + Path: accountCredentialsPath, + }, + { + Name: "Should return 200 OK with paginated account credentials", + ReqFn: func(t *testing.T) (any, string) { + accessToken := listAccountBeforeEach(t, 30) + return nil, accessToken + }, + ExpStatus: http.StatusOK, + AssertFn: func(t *testing.T, _ any, res *http.Response) { + resBody := AssertTestResponseBody(t, res, dtos.PaginationDTO[dtos.AccountCredentialsDTO]{}) + AssertEqual(t, len(resBody.Items), 20) + AssertEqual(t, resBody.Total, 30) + AssertNotEmpty(t, resBody.Next) + AssertEmpty(t, resBody.Previous) + }, + Path: accountCredentialsPath, + }, + { + Name: "Should return 401 UNAUTHORIZED without access token", + ReqFn: func(t *testing.T) (any, string) { + return nil, "" + }, + ExpStatus: http.StatusUnauthorized, + AssertFn: AssertUnauthorizedError[any], + Path: accountCredentialsPath, + }, + { + Name: "Should return 403 FORBIDDEN without account:credentials:read scope", + ReqFn: func(t *testing.T) (any, string) { + account := CreateTestAccount(t, GenerateFakeAccountData(t, services.AuthProviderGoogle)) + accessToken := GenerateScopedAccountAccessToken(t, &account, []string{tokens.AccountScopeCredentialsWrite}) + return nil, accessToken + }, + ExpStatus: http.StatusForbidden, + AssertFn: AssertForbiddenError[any], + Path: accountCredentialsPath, + }, + } + + for _, tc := range testCases { + t.Run(tc.Name, func(t *testing.T) { + PerformTestRequestCase(t, http.MethodGet, tc.Path, tc) + }) + } + + t.Cleanup(accountCredentialsCleanUp(t)) +} + +func TestGetSingleAccountCredentials(t *testing.T) { const accountCredentialsPath = v1Path + paths.AccountsBase + paths.CredentialsBase var clientID string - deleteAccountCredentialBeforeEach := func(t *testing.T) string { + getAccountCredentialBeforeEach := func(t *testing.T) string { account := CreateTestAccount(t, GenerateFakeAccountData(t, services.AuthProviderGoogle)) accessToken := GenerateScopedAccountAccessToken(t, &account, []string{tokens.AccountScopeCredentialsRead, tokens.AccountScopeCredentialsWrite}) @@ -763,10 +758,14 @@ func TestDeleteAccountCredentials(t *testing.T) { RequestID: uuid.NewString(), AccountPublicID: account.PublicID, AccountVersion: account.Version(), - Name: "delete-cred", + CredentialsType: "service", + Name: "get-cred", Scopes: []string{"account:admin"}, AuthMethod: "client_secret_basic", - Issuers: []string{"https://issuer.example.com"}, + Transport: "https", + ClientURI: "https://get.example.com", + SoftwareID: "get-service", + SoftwareVersion: "1.0.0", }) if err != nil { t.Fatalf("Failed to create account credentials: %v", err) @@ -777,14 +776,19 @@ func TestDeleteAccountCredentials(t *testing.T) { testCases := []TestRequestCase[any]{ { - Name: "Should return 204 NO CONTENT on successful delete", + Name: "Should return 200 OK with account credential", ReqFn: func(t *testing.T) (any, string) { - accessToken := deleteAccountCredentialBeforeEach(t) + accessToken := getAccountCredentialBeforeEach(t) return nil, accessToken }, - ExpStatus: http.StatusNoContent, + ExpStatus: http.StatusOK, AssertFn: func(t *testing.T, _ any, res *http.Response) { - AssertEqual(t, res.StatusCode, http.StatusNoContent) + resBody := AssertTestResponseBody(t, res, dtos.AccountCredentialsDTO{}) + AssertNotEmpty(t, resBody.ClientID) + AssertNotEmpty(t, resBody.Name) + AssertEqual(t, resBody.TokenEndpointAuthMethod, database.AuthMethodClientSecretBasic) + AssertEmpty(t, resBody.ClientSecret) + AssertEmpty(t, resBody.ClientSecretJWK) }, PathFn: func() string { return accountCredentialsPath + "/" + clientID @@ -794,7 +798,7 @@ func TestDeleteAccountCredentials(t *testing.T) { Name: "Should return 404 NOT FOUND for non-existent credential", ReqFn: func(t *testing.T) (any, string) { account := CreateTestAccount(t, GenerateFakeAccountData(t, services.AuthProviderGoogle)) - accessToken := GenerateScopedAccountAccessToken(t, &account, []string{tokens.AccountScopeCredentialsWrite}) + accessToken := GenerateScopedAccountAccessToken(t, &account, []string{tokens.AccountScopeCredentialsRead}) return nil, accessToken }, ExpStatus: http.StatusNotFound, @@ -806,7 +810,7 @@ func TestDeleteAccountCredentials(t *testing.T) { { Name: "Should return 401 UNAUTHORIZED without access token", ReqFn: func(t *testing.T) (any, string) { - deleteAccountCredentialBeforeEach(t) + getAccountCredentialBeforeEach(t) return nil, "" }, ExpStatus: http.StatusUnauthorized, @@ -816,17 +820,23 @@ func TestDeleteAccountCredentials(t *testing.T) { }, }, { - Name: "Should return 403 FORBIDDEN without account:credentials:write scope", + Name: "Should return 403 FORBIDDEN without account:credentials:read scope", ReqFn: func(t *testing.T) (any, string) { account := CreateTestAccount(t, GenerateFakeAccountData(t, services.AuthProviderGoogle)) - accessToken := GenerateScopedAccountAccessToken(t, &account, []string{tokens.AccountScopeCredentialsRead}) + accessToken := GenerateScopedAccountAccessToken(t, &account, []string{tokens.AccountScopeCredentialsWrite}) + cred, err := GetTestServices(t).CreateAccountCredentials(context.Background(), services.CreateAccountCredentialsOptions{ RequestID: uuid.NewString(), AccountPublicID: account.PublicID, AccountVersion: account.Version(), + CredentialsType: "service", Name: "forbidden-cred", Scopes: []string{"account:admin"}, AuthMethod: "client_secret_basic", + Transport: "https", + ClientURI: "https://forbidden.example.com", + SoftwareID: "forbidden-service", + SoftwareVersion: "1.0.0", }) if err != nil { t.Fatalf("Failed to create account credentials: %v", err) @@ -844,155 +854,111 @@ func TestDeleteAccountCredentials(t *testing.T) { for _, tc := range testCases { t.Run(tc.Name, func(t *testing.T) { - PerformTestRequestCaseWithPathFn(t, http.MethodDelete, tc) + PerformTestRequestCaseWithPathFn(t, http.MethodGet, tc) }) } t.Cleanup(accountCredentialsCleanUp(t)) } -func TestRevokeAccountCredentialsSecret(t *testing.T) { +func TestDeleteAccountCredentials(t *testing.T) { const accountCredentialsPath = v1Path + paths.AccountsBase + paths.CredentialsBase var clientID string - var secretID string - revokeAccountCredentialBeforeEach := func(t *testing.T, authMethods string) string { + deleteAccountCredentialBeforeEach := func(t *testing.T) string { account := CreateTestAccount(t, GenerateFakeAccountData(t, services.AuthProviderGoogle)) - accessToken := GenerateScopedAccountAccessToken(t, &account, []string{tokens.AccountScopeCredentialsWrite}) + accessToken := GenerateScopedAccountAccessToken(t, &account, []string{tokens.AccountScopeCredentialsRead, tokens.AccountScopeCredentialsWrite}) cred, err := GetTestServices(t).CreateAccountCredentials(context.Background(), services.CreateAccountCredentialsOptions{ RequestID: uuid.NewString(), AccountPublicID: account.PublicID, AccountVersion: account.Version(), - Name: "revoke-cred", + CredentialsType: "service", + Name: "delete-cred", Scopes: []string{"account:admin"}, - Issuers: []string{"https://issuer.example.com"}, - AuthMethod: authMethods, + AuthMethod: "client_secret_basic", + Transport: "https", + ClientURI: "https://delete.example.com", + SoftwareID: "delete-service", + SoftwareVersion: "1.0.0", }) if err != nil { t.Fatalf("Failed to create account credentials: %v", err) - } - clientID = cred.ClientID - secretID = cred.ClientSecretID - return accessToken - } - - generateFakeKeyID := func(t *testing.T) string { - key, err := utils.GenerateBase64Secret(16) - if err != nil { - t.Fatalf("Failed to generate fake key: %v", err) - } - return key - } - - pathFN := func() string { - return accountCredentialsPath + "/" + clientID + "/secrets/" + secretID - } - - testCases := []TestRequestCase[any]{ - { - Name: "Should return 200 OK on successful secret revoke for client_secret_post", - ReqFn: func(t *testing.T) (any, string) { - accessToken := revokeAccountCredentialBeforeEach(t, "client_secret_post") - return nil, accessToken - }, - ExpStatus: http.StatusOK, - AssertFn: func(t *testing.T, _ any, res *http.Response) { - resBody := AssertTestResponseBody(t, res, dtos.ClientCredentialsSecretDTO{}) - AssertEqual(t, resBody.PublicID, secretID) - AssertEqual(t, resBody.Status, "revoked") - AssertEmpty(t, resBody.ClientSecretJWK) - AssertEmpty(t, resBody.ClientSecret) - }, - PathFn: pathFN, - }, - { - Name: "Should return 200 OK on successful secret revoke for client_secret_jwt", - ReqFn: func(t *testing.T) (any, string) { - accessToken := revokeAccountCredentialBeforeEach(t, "client_secret_jwt") - return nil, accessToken - }, - ExpStatus: http.StatusOK, - AssertFn: func(t *testing.T, _ any, res *http.Response) { - resBody := AssertTestResponseBody(t, res, dtos.ClientCredentialsSecretDTO{}) - AssertEqual(t, resBody.PublicID, secretID) - AssertEqual(t, resBody.Status, "revoked") - AssertEmpty(t, resBody.ClientSecretJWK) - AssertEmpty(t, resBody.ClientSecret) - }, - PathFn: pathFN, - }, + } + clientID = cred.ClientID + return accessToken + } + + testCases := []TestRequestCase[any]{ { - Name: "Should return 200 OK on successful secret revoke for private_key_jwt", + Name: "Should return 204 NO CONTENT on successful delete", ReqFn: func(t *testing.T) (any, string) { - accessToken := revokeAccountCredentialBeforeEach(t, "private_key_jwt") + accessToken := deleteAccountCredentialBeforeEach(t) return nil, accessToken }, - ExpStatus: http.StatusOK, + ExpStatus: http.StatusNoContent, AssertFn: func(t *testing.T, _ any, res *http.Response) { - resBody := AssertTestResponseBody(t, res, dtos.ClientCredentialsSecretDTO{}) - AssertEqual(t, resBody.PublicID, secretID) - AssertEqual(t, resBody.Status, "revoked") - AssertEmpty(t, resBody.ClientSecretJWK) - AssertEmpty(t, resBody.ClientSecret) + AssertEqual(t, res.StatusCode, http.StatusNoContent) + }, + PathFn: func() string { + return accountCredentialsPath + "/" + clientID }, - PathFn: pathFN, }, { Name: "Should return 404 NOT FOUND for non-existent credential", ReqFn: func(t *testing.T) (any, string) { - accessToken := revokeAccountCredentialBeforeEach(t, "client_secret_post") - clientID = utils.Base62UUID() + account := CreateTestAccount(t, GenerateFakeAccountData(t, services.AuthProviderGoogle)) + accessToken := GenerateScopedAccountAccessToken(t, &account, []string{tokens.AccountScopeCredentialsWrite}) return nil, accessToken }, ExpStatus: http.StatusNotFound, AssertFn: AssertNotFoundError[any], - PathFn: pathFN, - }, - { - Name: "Should return 404 NOT FOUND for non-existent secret", - ReqFn: func(t *testing.T) (any, string) { - accessToken := revokeAccountCredentialBeforeEach(t, "client_secret_basic") - secretID = generateFakeKeyID(t) - return nil, accessToken + PathFn: func() string { + return accountCredentialsPath + "/" + utils.Base62UUID() }, - ExpStatus: http.StatusNotFound, - AssertFn: AssertNotFoundError[any], - PathFn: pathFN, }, { Name: "Should return 401 UNAUTHORIZED without access token", ReqFn: func(t *testing.T) (any, string) { - revokeAccountCredentialBeforeEach(t, "client_secret_post") + deleteAccountCredentialBeforeEach(t) return nil, "" }, ExpStatus: http.StatusUnauthorized, AssertFn: AssertUnauthorizedError[any], - PathFn: pathFN, + PathFn: func() string { + return accountCredentialsPath + "/" + clientID + }, }, { Name: "Should return 403 FORBIDDEN without account:credentials:write scope", ReqFn: func(t *testing.T) (any, string) { account := CreateTestAccount(t, GenerateFakeAccountData(t, services.AuthProviderGoogle)) accessToken := GenerateScopedAccountAccessToken(t, &account, []string{tokens.AccountScopeCredentialsRead}) + cred, err := GetTestServices(t).CreateAccountCredentials(context.Background(), services.CreateAccountCredentialsOptions{ RequestID: uuid.NewString(), AccountPublicID: account.PublicID, AccountVersion: account.Version(), - Name: "forbidden-cred", + CredentialsType: "service", + Name: "forbidden-delete", Scopes: []string{"account:admin"}, AuthMethod: "client_secret_basic", + Transport: "https", + ClientURI: "https://forbidden-delete.example.com", + SoftwareID: "forbidden-delete-service", + SoftwareVersion: "1.0.0", }) if err != nil { t.Fatalf("Failed to create account credentials: %v", err) } clientID = cred.ClientID - secretID = cred.ClientSecretID return nil, accessToken }, ExpStatus: http.StatusForbidden, AssertFn: AssertForbiddenError[any], - PathFn: pathFN, + PathFn: func() string { + return accountCredentialsPath + "/" + clientID + }, }, } @@ -1005,7 +971,7 @@ func TestRevokeAccountCredentialsSecret(t *testing.T) { t.Cleanup(accountCredentialsCleanUp(t)) } -func TestListAccountCredentialsSecret(t *testing.T) { +func TestListAccountCredentialsSecrets(t *testing.T) { const accountCredentialsPath = v1Path + paths.AccountsBase + paths.CredentialsBase var clientID string @@ -1017,10 +983,14 @@ func TestListAccountCredentialsSecret(t *testing.T) { RequestID: uuid.NewString(), AccountPublicID: account.PublicID, AccountVersion: account.Version(), + CredentialsType: "service", Name: "list-cred", Scopes: []string{"account:admin"}, - Issuers: []string{"https://issuer.example.com"}, AuthMethod: authMethods, + Transport: "https", + ClientURI: "https://list.example.com", + SoftwareID: "list-service", + SoftwareVersion: "1.0.0", }) if err != nil { t.Fatalf("Failed to create account credentials: %v", err) @@ -1050,7 +1020,7 @@ func TestListAccountCredentialsSecret(t *testing.T) { PathFn: pathFN, }, { - Name: "Should return 200 OK with secrets for private_key_jwt", + Name: "Should return 200 OK with keys for private_key_jwt", ReqFn: func(t *testing.T) (any, string) { accessToken := listAccountCredentialBeforeEach(t, "private_key_jwt") return nil, accessToken @@ -1090,14 +1060,19 @@ func TestListAccountCredentialsSecret(t *testing.T) { ReqFn: func(t *testing.T) (any, string) { account := CreateTestAccount(t, GenerateFakeAccountData(t, services.AuthProviderGoogle)) accessToken := GenerateScopedAccountAccessToken(t, &account, []string{tokens.AccountScopeCredentialsWrite}) + cred, err := GetTestServices(t).CreateAccountCredentials(context.Background(), services.CreateAccountCredentialsOptions{ RequestID: uuid.NewString(), AccountPublicID: account.PublicID, AccountVersion: account.Version(), - Name: "forbidden-cred", + CredentialsType: "service", + Name: "forbidden-list", Scopes: []string{"account:admin"}, - Issuers: []string{"https://issuer.example.com"}, AuthMethod: "client_secret_basic", + Transport: "https", + ClientURI: "https://forbidden-list.example.com", + SoftwareID: "forbidden-list-service", + SoftwareVersion: "1.0.0", }) if err != nil { t.Fatalf("Failed to create account credentials: %v", err) @@ -1123,10 +1098,9 @@ func TestListAccountCredentialsSecret(t *testing.T) { func TestCreateAccountCredentialsSecret(t *testing.T) { const accountCredentialsPath = v1Path + paths.AccountsBase + paths.CredentialsBase - var account dtos.AccountDTO - var clientID, secretID string + var clientID string createAccountCredentialBeforeEach := func(t *testing.T, authMethods string) string { - account = CreateTestAccount(t, GenerateFakeAccountData(t, services.AuthProviderGoogle)) + account := CreateTestAccount(t, GenerateFakeAccountData(t, services.AuthProviderGoogle)) accessToken := GenerateScopedAccountAccessToken(t, &account, []string{tokens.AccountScopeCredentialsWrite}) cred, err := GetTestServices(t).CreateAccountCredentials( @@ -1135,10 +1109,14 @@ func TestCreateAccountCredentialsSecret(t *testing.T) { RequestID: uuid.NewString(), AccountPublicID: account.PublicID, AccountVersion: account.Version(), + CredentialsType: "service", Name: "create-secret-cred", Scopes: []string{"account:admin"}, - Issuers: []string{"https://issuer.example.com"}, AuthMethod: authMethods, + Transport: "https", + ClientURI: "https://create-secret.example.com", + SoftwareID: "create-secret-service", + SoftwareVersion: "1.0.0", }, ) if err != nil { @@ -1146,7 +1124,6 @@ func TestCreateAccountCredentialsSecret(t *testing.T) { } clientID = cred.ClientID - secretID = cred.ClientSecretID return accessToken } @@ -1156,48 +1133,9 @@ func TestCreateAccountCredentialsSecret(t *testing.T) { testCases := []TestRequestCase[any]{ { - Name: "Should return 201 CREATED and create new secret for client_secret_post when the main one is revoked", + Name: "Should return 201 CREATED and create new secret for client_secret_post", ReqFn: func(t *testing.T) (any, string) { accessToken := createAccountCredentialBeforeEach(t, "client_secret_post") - if _, err := GetTestServices(t).RevokeAccountCredentialsSecretOrKey( - context.Background(), - services.RevokeAccountCredentialsSecretOrKeyOptions{ - RequestID: uuid.NewString(), - AccountPublicID: account.PublicID, - AccountVersion: account.Version(), - ClientID: clientID, - SecretID: secretID, - }, - ); err != nil { - t.Fatalf("Failed to revoke account credentials secret: %v", err) - } - return nil, accessToken - }, - ExpStatus: http.StatusCreated, - AssertFn: func(t *testing.T, _ any, res *http.Response) { - resBody := AssertTestResponseBody(t, res, dtos.ClientCredentialsSecretDTO{}) - AssertNotEmpty(t, resBody.PublicID) - AssertEqual(t, resBody.Status, "active") - AssertNotEmpty(t, resBody.ClientSecret) - }, - PathFn: pathFN, - }, - { - Name: "Should return 201 CREATED and create new secret for client_secret_jwt when the main one is revoked", - ReqFn: func(t *testing.T) (any, string) { - accessToken := createAccountCredentialBeforeEach(t, "client_secret_jwt") - if _, err := GetTestServices(t).RevokeAccountCredentialsSecretOrKey( - context.Background(), - services.RevokeAccountCredentialsSecretOrKeyOptions{ - RequestID: uuid.NewString(), - AccountPublicID: account.PublicID, - AccountVersion: account.Version(), - ClientID: clientID, - SecretID: secretID, - }, - ); err != nil { - t.Fatalf("Failed to revoke account credentials secret: %v", err) - } return nil, accessToken }, ExpStatus: http.StatusCreated, @@ -1210,21 +1148,9 @@ func TestCreateAccountCredentialsSecret(t *testing.T) { PathFn: pathFN, }, { - Name: "Should return 201 CREATED and create new secret for private_key_jwt when the main one is revoked", + Name: "Should return 201 CREATED and create new key for private_key_jwt", ReqFn: func(t *testing.T) (any, string) { accessToken := createAccountCredentialBeforeEach(t, "private_key_jwt") - if _, err := GetTestServices(t).RevokeAccountCredentialsSecretOrKey( - context.Background(), - services.RevokeAccountCredentialsSecretOrKeyOptions{ - RequestID: uuid.NewString(), - AccountPublicID: account.PublicID, - AccountVersion: account.Version(), - ClientID: clientID, - SecretID: secretID, - }, - ); err != nil { - t.Fatalf("Failed to revoke account credentials secret: %v", err) - } return nil, accessToken }, ExpStatus: http.StatusCreated, @@ -1237,160 +1163,146 @@ func TestCreateAccountCredentialsSecret(t *testing.T) { PathFn: pathFN, }, { - Name: "Should return 201 CREATED and create new secret for client_secret_post when the main one is almost expired", + Name: "Should return 404 NOT FOUND for non-existent credential", ReqFn: func(t *testing.T) (any, string) { accessToken := createAccountCredentialBeforeEach(t, "client_secret_post") - if err := GetTestDatabase(t).UpdateCredentialsSecretExpiresAtAndCreatedAt( - context.Background(), - database.UpdateCredentialsSecretExpiresAtAndCreatedAtParams{ - SecretID: secretID, - ExpiresAt: time.Now().Add(24 * time.Hour), // Set to 24 hours from now - CreatedAt: time.Now().Add(-24 * 365 * time.Hour), // Set to 1 year ago - }, - ); err != nil { - t.Fatalf("Failed to revoke account credentials secret: %v", err) - } + clientID = utils.Base62UUID() return nil, accessToken }, - ExpStatus: http.StatusCreated, - AssertFn: func(t *testing.T, _ any, res *http.Response) { - resBody := AssertTestResponseBody(t, res, dtos.ClientCredentialsSecretDTO{}) - AssertNotEmpty(t, resBody.PublicID) - AssertEqual(t, resBody.Status, "active") - AssertNotEmpty(t, resBody.ClientSecret) - }, - PathFn: pathFN, + ExpStatus: http.StatusNotFound, + AssertFn: AssertNotFoundError[any], + PathFn: pathFN, }, { - Name: "Should return 201 CREATED and create new secret for client_secret_jwt when the main one is almost expired", + Name: "Should return 401 UNAUTHORIZED without access token", ReqFn: func(t *testing.T) (any, string) { - accessToken := createAccountCredentialBeforeEach(t, "client_secret_jwt") - if err := GetTestDatabase(t).UpdateCredentialsSecretExpiresAtAndCreatedAt( - context.Background(), - database.UpdateCredentialsSecretExpiresAtAndCreatedAtParams{ - SecretID: secretID, - ExpiresAt: time.Now().Add(24 * time.Hour), // Set to 24 hours from now - CreatedAt: time.Now().Add(-24 * 365 * time.Hour), // Set to 1 year ago - }, - ); err != nil { - t.Fatalf("Failed to revoke account credentials secret: %v", err) - } - return nil, accessToken - }, - ExpStatus: http.StatusCreated, - AssertFn: func(t *testing.T, _ any, res *http.Response) { - resBody := AssertTestResponseBody(t, res, dtos.ClientCredentialsSecretDTO{}) - AssertNotEmpty(t, resBody.PublicID) - AssertEqual(t, resBody.Status, "active") - AssertNotEmpty(t, resBody.ClientSecret) + createAccountCredentialBeforeEach(t, "client_secret_post") + return nil, "" }, - PathFn: pathFN, + ExpStatus: http.StatusUnauthorized, + AssertFn: AssertUnauthorizedError[any], + PathFn: pathFN, }, { - Name: "Should return 201 CREATED and create new secret for private_key_jwt when the main one is almost expired", + Name: "Should return 403 FORBIDDEN without account:credentials:write scope", ReqFn: func(t *testing.T) (any, string) { - accessToken := createAccountCredentialBeforeEach(t, "private_key_jwt") - if err := GetTestDatabase(t).UpdateCredentialsKeyExpiresAtAndCreatedAt( - context.Background(), - database.UpdateCredentialsKeyExpiresAtAndCreatedAtParams{ - PublicKid: secretID, - ExpiresAt: time.Now().Add(24 * time.Hour), // Set to 24 hours from now - CreatedAt: time.Now().Add(-24 * 365 * time.Hour), // Set to 1 year ago - }, - ); err != nil { - t.Fatalf("Failed to revoke account credentials secret: %v", err) + account := CreateTestAccount(t, GenerateFakeAccountData(t, services.AuthProviderGoogle)) + accessToken := GenerateScopedAccountAccessToken(t, &account, []string{tokens.AccountScopeCredentialsRead}) + + cred, err := GetTestServices(t).CreateAccountCredentials(context.Background(), services.CreateAccountCredentialsOptions{ + RequestID: uuid.NewString(), + AccountPublicID: account.PublicID, + AccountVersion: account.Version(), + CredentialsType: "service", + Name: "forbidden-create-secret", + Scopes: []string{"account:admin"}, + AuthMethod: "client_secret_basic", + Transport: "https", + ClientURI: "https://forbidden-create-secret.example.com", + SoftwareID: "forbidden-create-secret-service", + SoftwareVersion: "1.0.0", + }) + if err != nil { + t.Fatalf("Failed to create account credentials: %v", err) } + clientID = cred.ClientID return nil, accessToken }, - ExpStatus: http.StatusCreated, - AssertFn: func(t *testing.T, _ any, res *http.Response) { - resBody := AssertTestResponseBody(t, res, dtos.ClientCredentialsSecretDTO{}) - AssertNotEmpty(t, resBody.PublicID) - AssertEqual(t, resBody.Status, "active") - AssertNotEmpty(t, resBody.ClientSecretJWK) - }, - PathFn: pathFN, + ExpStatus: http.StatusForbidden, + AssertFn: AssertForbiddenError[any], + PathFn: pathFN, }, + } + + for _, tc := range testCases { + t.Run(tc.Name, func(t *testing.T) { + PerformTestRequestCaseWithPathFn(t, http.MethodPost, tc) + }) + } + + t.Cleanup(accountCredentialsCleanUp(t)) +} + +func TestGetAccountCredentialsSecret(t *testing.T) { + const accountCredentialsPath = v1Path + paths.AccountsBase + paths.CredentialsBase + + var clientID, secretID string + getAccountCredentialSecretBeforeEach := func(t *testing.T, authMethods string) string { + account := CreateTestAccount(t, GenerateFakeAccountData(t, services.AuthProviderGoogle)) + accessToken := GenerateScopedAccountAccessToken(t, &account, []string{tokens.AccountScopeCredentialsRead, tokens.AccountScopeCredentialsWrite}) + + cred, err := GetTestServices(t).CreateAccountCredentials(context.Background(), services.CreateAccountCredentialsOptions{ + RequestID: uuid.NewString(), + AccountPublicID: account.PublicID, + AccountVersion: account.Version(), + CredentialsType: "service", + Name: "get-secret-cred", + Scopes: []string{"account:admin"}, + AuthMethod: authMethods, + Transport: "https", + ClientURI: "https://get-secret.example.com", + SoftwareID: "get-secret-service", + SoftwareVersion: "1.0.0", + }) + if err != nil { + t.Fatalf("Failed to create account credentials: %v", err) + } + clientID = cred.ClientID + secretID = cred.ClientSecretID + return accessToken + } + + pathFN := func() string { + return accountCredentialsPath + "/" + clientID + "/secrets/" + secretID + } + + testCases := []TestRequestCase[any]{ { - Name: "Should return 201 CREATED and create new secret for client_secret_post when the main one is expired", + Name: "Should return 200 OK with secret for client_secret_post", ReqFn: func(t *testing.T) (any, string) { - accessToken := createAccountCredentialBeforeEach(t, "client_secret_post") - if err := GetTestDatabase(t).UpdateCredentialsSecretExpiresAtAndCreatedAt( - context.Background(), - database.UpdateCredentialsSecretExpiresAtAndCreatedAtParams{ - SecretID: secretID, - ExpiresAt: time.Now().Add(-24 * time.Hour), // Set to 24 hours from now - CreatedAt: time.Now().Add(-24 * 366 * time.Hour), // Set to 1 year ago - }, - ); err != nil { - t.Fatalf("Failed to revoke account credentials secret: %v", err) - } + accessToken := getAccountCredentialSecretBeforeEach(t, "client_secret_post") return nil, accessToken }, - ExpStatus: http.StatusCreated, + ExpStatus: http.StatusOK, AssertFn: func(t *testing.T, _ any, res *http.Response) { resBody := AssertTestResponseBody(t, res, dtos.ClientCredentialsSecretDTO{}) - AssertNotEmpty(t, resBody.PublicID) + AssertEqual(t, resBody.PublicID, secretID) AssertEqual(t, resBody.Status, "active") - AssertNotEmpty(t, resBody.ClientSecret) + AssertEmpty(t, resBody.ClientSecret) }, PathFn: pathFN, }, { - Name: "Should return 201 CREATED and create new secret for client_secret_jwt when the main one is expired", + Name: "Should return 200 OK with key for private_key_jwt", ReqFn: func(t *testing.T) (any, string) { - accessToken := createAccountCredentialBeforeEach(t, "client_secret_jwt") - if err := GetTestDatabase(t).UpdateCredentialsSecretExpiresAtAndCreatedAt( - context.Background(), - database.UpdateCredentialsSecretExpiresAtAndCreatedAtParams{ - SecretID: secretID, - ExpiresAt: time.Now().Add(-24 * time.Hour), // Set to 24 hours from now - CreatedAt: time.Now().Add(-24 * 366 * time.Hour), // Set to 1 year ago - }, - ); err != nil { - t.Fatalf("Failed to revoke account credentials secret: %v", err) - } + accessToken := getAccountCredentialSecretBeforeEach(t, "private_key_jwt") return nil, accessToken }, - ExpStatus: http.StatusCreated, + ExpStatus: http.StatusOK, AssertFn: func(t *testing.T, _ any, res *http.Response) { resBody := AssertTestResponseBody(t, res, dtos.ClientCredentialsSecretDTO{}) - AssertNotEmpty(t, resBody.PublicID) + AssertEqual(t, resBody.PublicID, secretID) AssertEqual(t, resBody.Status, "active") - AssertNotEmpty(t, resBody.ClientSecret) + AssertNotEmpty(t, resBody.ClientSecretJWK) }, PathFn: pathFN, }, { - Name: "Should return 201 CREATED and create new secret for private_key_jwt when the main one is expired", + Name: "Should return 404 NOT FOUND for non-existent credential", ReqFn: func(t *testing.T) (any, string) { - accessToken := createAccountCredentialBeforeEach(t, "private_key_jwt") - if err := GetTestDatabase(t).UpdateCredentialsKeyExpiresAtAndCreatedAt( - context.Background(), - database.UpdateCredentialsKeyExpiresAtAndCreatedAtParams{ - PublicKid: secretID, - ExpiresAt: time.Now().Add(-24 * time.Hour), // Set to 24 hours from now - CreatedAt: time.Now().Add(-24 * 366 * time.Hour), // Set to 1 year ago - }, - ); err != nil { - t.Fatalf("Failed to revoke account credentials secret: %v", err) - } + accessToken := getAccountCredentialSecretBeforeEach(t, "client_secret_post") + clientID = utils.Base62UUID() return nil, accessToken }, - ExpStatus: http.StatusCreated, - AssertFn: func(t *testing.T, _ any, res *http.Response) { - resBody := AssertTestResponseBody(t, res, dtos.ClientCredentialsSecretDTO{}) - AssertNotEmpty(t, resBody.PublicID) - AssertEqual(t, resBody.Status, "active") - AssertNotEmpty(t, resBody.ClientSecretJWK) - }, - PathFn: pathFN, + ExpStatus: http.StatusNotFound, + AssertFn: AssertNotFoundError[any], + PathFn: pathFN, }, { - Name: "Should return 404 NOT FOUND for non-existent credential", + Name: "Should return 404 NOT FOUND for non-existent secret", ReqFn: func(t *testing.T) (any, string) { - accessToken := createAccountCredentialBeforeEach(t, "client_secret_post") - clientID = utils.Base62UUID() + accessToken := getAccountCredentialSecretBeforeEach(t, "client_secret_basic") + secretID = utils.Base62UUID() return nil, accessToken }, ExpStatus: http.StatusNotFound, @@ -1400,7 +1312,7 @@ func TestCreateAccountCredentialsSecret(t *testing.T) { { Name: "Should return 401 UNAUTHORIZED without access token", ReqFn: func(t *testing.T) (any, string) { - createAccountCredentialBeforeEach(t, "client_secret_post") + getAccountCredentialSecretBeforeEach(t, "client_secret_post") return nil, "" }, ExpStatus: http.StatusUnauthorized, @@ -1408,22 +1320,29 @@ func TestCreateAccountCredentialsSecret(t *testing.T) { PathFn: pathFN, }, { - Name: "Should return 403 FORBIDDEN without account:credentials:write scope", + Name: "Should return 403 FORBIDDEN without account:credentials:read scope", ReqFn: func(t *testing.T) (any, string) { account := CreateTestAccount(t, GenerateFakeAccountData(t, services.AuthProviderGoogle)) - accessToken := GenerateScopedAccountAccessToken(t, &account, []string{tokens.AccountScopeCredentialsRead}) + accessToken := GenerateScopedAccountAccessToken(t, &account, []string{tokens.AccountScopeCredentialsWrite}) + cred, err := GetTestServices(t).CreateAccountCredentials(context.Background(), services.CreateAccountCredentialsOptions{ RequestID: uuid.NewString(), AccountPublicID: account.PublicID, AccountVersion: account.Version(), - Name: "forbidden-cred", + CredentialsType: "service", + Name: "forbidden-get-secret", Scopes: []string{"account:admin"}, AuthMethod: "client_secret_basic", + Transport: "https", + ClientURI: "https://forbidden-get-secret.example.com", + SoftwareID: "forbidden-get-secret-service", + SoftwareVersion: "1.0.0", }) if err != nil { t.Fatalf("Failed to create account credentials: %v", err) } clientID = cred.ClientID + secretID = cred.ClientSecretID return nil, accessToken }, ExpStatus: http.StatusForbidden, @@ -1434,29 +1353,34 @@ func TestCreateAccountCredentialsSecret(t *testing.T) { for _, tc := range testCases { t.Run(tc.Name, func(t *testing.T) { - PerformTestRequestCaseWithPathFn(t, http.MethodPost, tc) + PerformTestRequestCaseWithPathFn(t, http.MethodGet, tc) }) } t.Cleanup(accountCredentialsCleanUp(t)) } -func TestGetAccountCredentialsSecret(t *testing.T) { +func TestRevokeAccountCredentialsSecret(t *testing.T) { const accountCredentialsPath = v1Path + paths.AccountsBase + paths.CredentialsBase - var clientID, secretID string - getAccountCredentialSecretBeforeEach := func(t *testing.T, authMethods string) string { + var clientID string + var secretID string + revokeAccountCredentialBeforeEach := func(t *testing.T, authMethods string) string { account := CreateTestAccount(t, GenerateFakeAccountData(t, services.AuthProviderGoogle)) - accessToken := GenerateScopedAccountAccessToken(t, &account, []string{tokens.AccountScopeCredentialsRead, tokens.AccountScopeCredentialsWrite}) + accessToken := GenerateScopedAccountAccessToken(t, &account, []string{tokens.AccountScopeCredentialsWrite}) cred, err := GetTestServices(t).CreateAccountCredentials(context.Background(), services.CreateAccountCredentialsOptions{ RequestID: uuid.NewString(), AccountPublicID: account.PublicID, AccountVersion: account.Version(), - Name: "get-secret-cred", + CredentialsType: "service", + Name: "revoke-cred", Scopes: []string{"account:admin"}, - Issuers: []string{"https://issuer.example.com"}, AuthMethod: authMethods, + Transport: "https", + ClientURI: "https://revoke.example.com", + SoftwareID: "revoke-service", + SoftwareVersion: "1.0.0", }) if err != nil { t.Fatalf("Failed to create account credentials: %v", err) @@ -1472,39 +1396,41 @@ func TestGetAccountCredentialsSecret(t *testing.T) { testCases := []TestRequestCase[any]{ { - Name: "Should return 200 OK with secret for client_secret_post", + Name: "Should return 200 OK on successful secret revoke for client_secret_post", ReqFn: func(t *testing.T) (any, string) { - accessToken := getAccountCredentialSecretBeforeEach(t, "client_secret_post") + accessToken := revokeAccountCredentialBeforeEach(t, "client_secret_post") return nil, accessToken }, ExpStatus: http.StatusOK, AssertFn: func(t *testing.T, _ any, res *http.Response) { resBody := AssertTestResponseBody(t, res, dtos.ClientCredentialsSecretDTO{}) AssertEqual(t, resBody.PublicID, secretID) - AssertEqual(t, resBody.Status, "active") + AssertEqual(t, resBody.Status, "revoked") + AssertEmpty(t, resBody.ClientSecretJWK) AssertEmpty(t, resBody.ClientSecret) }, PathFn: pathFN, }, { - Name: "Should return 200 OK with secret for private_key_jwt", + Name: "Should return 200 OK on successful key revoke for private_key_jwt", ReqFn: func(t *testing.T) (any, string) { - accessToken := getAccountCredentialSecretBeforeEach(t, "private_key_jwt") + accessToken := revokeAccountCredentialBeforeEach(t, "private_key_jwt") return nil, accessToken }, ExpStatus: http.StatusOK, AssertFn: func(t *testing.T, _ any, res *http.Response) { resBody := AssertTestResponseBody(t, res, dtos.ClientCredentialsSecretDTO{}) AssertEqual(t, resBody.PublicID, secretID) - AssertEqual(t, resBody.Status, "active") + AssertEqual(t, resBody.Status, "revoked") AssertEmpty(t, resBody.ClientSecretJWK) + AssertEmpty(t, resBody.ClientSecret) }, PathFn: pathFN, }, { Name: "Should return 404 NOT FOUND for non-existent credential", ReqFn: func(t *testing.T) (any, string) { - accessToken := getAccountCredentialSecretBeforeEach(t, "client_secret_post") + accessToken := revokeAccountCredentialBeforeEach(t, "client_secret_post") clientID = utils.Base62UUID() return nil, accessToken }, @@ -1515,7 +1441,7 @@ func TestGetAccountCredentialsSecret(t *testing.T) { { Name: "Should return 404 NOT FOUND for non-existent secret", ReqFn: func(t *testing.T) (any, string) { - accessToken := getAccountCredentialSecretBeforeEach(t, "client_secret_basic") + accessToken := revokeAccountCredentialBeforeEach(t, "client_secret_basic") secretID = utils.Base62UUID() return nil, accessToken }, @@ -1526,7 +1452,7 @@ func TestGetAccountCredentialsSecret(t *testing.T) { { Name: "Should return 401 UNAUTHORIZED without access token", ReqFn: func(t *testing.T) (any, string) { - getAccountCredentialSecretBeforeEach(t, "client_secret_post") + revokeAccountCredentialBeforeEach(t, "client_secret_post") return nil, "" }, ExpStatus: http.StatusUnauthorized, @@ -1534,18 +1460,23 @@ func TestGetAccountCredentialsSecret(t *testing.T) { PathFn: pathFN, }, { - Name: "Should return 403 FORBIDDEN without account:credentials:read scope", + Name: "Should return 403 FORBIDDEN without account:credentials:write scope", ReqFn: func(t *testing.T) (any, string) { account := CreateTestAccount(t, GenerateFakeAccountData(t, services.AuthProviderGoogle)) - accessToken := GenerateScopedAccountAccessToken(t, &account, []string{tokens.AccountScopeCredentialsWrite}) + accessToken := GenerateScopedAccountAccessToken(t, &account, []string{tokens.AccountScopeCredentialsRead}) + cred, err := GetTestServices(t).CreateAccountCredentials(context.Background(), services.CreateAccountCredentialsOptions{ RequestID: uuid.NewString(), AccountPublicID: account.PublicID, AccountVersion: account.Version(), - Name: "forbidden-cred", + CredentialsType: "service", + Name: "forbidden-revoke", Scopes: []string{"account:admin"}, - Issuers: []string{"https://issuer.example.com"}, AuthMethod: "client_secret_basic", + Transport: "https", + ClientURI: "https://forbidden-revoke.example.com", + SoftwareID: "forbidden-revoke-service", + SoftwareVersion: "1.0.0", }) if err != nil { t.Fatalf("Failed to create account credentials: %v", err) @@ -1562,7 +1493,7 @@ func TestGetAccountCredentialsSecret(t *testing.T) { for _, tc := range testCases { t.Run(tc.Name, func(t *testing.T) { - PerformTestRequestCaseWithPathFn(t, http.MethodGet, tc) + PerformTestRequestCaseWithPathFn(t, http.MethodDelete, tc) }) } diff --git a/idp/tests/oauth_test.go b/idp/tests/oauth_test.go index b0c2909..f3aa753 100644 --- a/idp/tests/oauth_test.go +++ b/idp/tests/oauth_test.go @@ -750,10 +750,17 @@ func TestOAuthToken(t *testing.T) { RequestID: uuid.NewString(), AccountPublicID: account.PublicID, AccountVersion: account.Version(), + CredentialsType: string(database.AccountCredentialsTypeService), Name: "update-cred", Scopes: []string{"account:admin"}, AuthMethod: "private_key_jwt", - Issuers: []string{"https://issuer.example.com"}, + Domain: "issuer.example.com", + ClientURI: "https://issuer.example.com", + Transport: "https", + SoftwareID: "test-software", + SoftwareVersion: "1.0.0", + Contacts: []string{"test@example.com"}, + CreationMethod: database.CreationMethodManual, Algorithm: string(algorithm), }) if serviceErr != nil { @@ -826,10 +833,17 @@ func TestOAuthToken(t *testing.T) { RequestID: uuid.NewString(), AccountPublicID: account.PublicID, AccountVersion: account.Version(), + CredentialsType: string(database.AccountCredentialsTypeService), Name: "update-cred", Scopes: []string{"account:admin"}, AuthMethod: am, - Issuers: []string{"https://issuer.example.com"}, + Domain: "issuer.example.com", + ClientURI: "https://issuer.example.com", + Transport: "https", + SoftwareID: "test-software", + SoftwareVersion: "1.0.0", + Contacts: []string{"test@example.com"}, + CreationMethod: database.CreationMethodManual, }) if serviceErr != nil { t.Fatalf("Failed to create account credentials: %v", serviceErr) From 4af5f8a8c5f859e391aad789679c13abbd923b7f Mon Sep 17 00:00:00 2001 From: Afonso Barracha Date: Sat, 16 Aug 2025 14:27:54 +1200 Subject: [PATCH 04/23] feat(idp): add dynamic configuration domain configs --- idp/dbml-error.log | 3 + idp/initial_schema.dbml | 21 +++++ idp/internal/config/config.go | 92 ++++++++++++------- ...0241213231542_create_initial_schema.up.sql | 25 ++++- .../account_dynamic_registration_tokens.go | 7 ++ 5 files changed, 116 insertions(+), 32 deletions(-) create mode 100644 idp/internal/services/account_dynamic_registration_tokens.go diff --git a/idp/dbml-error.log b/idp/dbml-error.log index 22438cb..fca0d60 100644 --- a/idp/dbml-error.log +++ b/idp/dbml-error.log @@ -31,3 +31,6 @@ undefined 2025-08-03T02:07:21.986Z undefined +2025-08-16T01:54:09.751Z +undefined + diff --git a/idp/initial_schema.dbml b/idp/initial_schema.dbml index 651c4a7..2eb28d8 100644 --- a/idp/initial_schema.dbml +++ b/idp/initial_schema.dbml @@ -891,6 +891,27 @@ Table account_dynamic_registration_configs as ADRC { } Ref: ADRC.account_id > A.id [delete: cascade] +Table account_dynamic_registration_domains as ADRD { + id serial [pk] + + account_id integer [not null] + + domain varchar(250) [not null] + verification_host varchar(50) [not null] + verification_code text [not null] + expires_at timestamptz [not null] + + created_at timestamptz [not null, default: `now()`] + updated_at timestamptz [not null, default: `now()`] + + Indexes { + (account_id) [name: 'accounts_totps_account_id_idx'] + (domain) [name: 'account_dynamic_registration_domains_domain_idx'] + (account_id, domain) [unique, name: 'account_dynamic_registration_domains_account_id_domain_uidx'] + } +} +Ref: ADRD.account_id > A.id [delete: cascade] + Table app_dynamic_registration_configs as APDRC { id serial [pk] diff --git a/idp/internal/config/config.go b/idp/internal/config/config.go index c8f9f94..fde9102 100644 --- a/idp/internal/config/config.go +++ b/idp/internal/config/config.go @@ -17,32 +17,36 @@ import ( ) type Config struct { - port int64 - env string - maxProcs int64 - databaseURL string - valkeyURL string - frontendDomain string - backendDomain string - cookieSecret string - cookieName string - emailPubChannel string - encryptionSecret string - serviceID uuid.UUID - serviceName string - loggerConfig LoggerConfig - tokensConfig TokensConfig - oAuthProvidersConfig OAuthProvidersConfig - rateLimiterConfig RateLimiterConfig - openBaoConfig OpenBaoConfig - cryptoConfig CryptoConfig - distributedCache DistributedCache - kekExpirationDays int64 - dekExpirationDays int64 - jwkExpirationDays int64 - accountCCExpDays int64 - userCCExpDays int64 - appCCExpDays int64 + port int64 + env string + maxProcs int64 + databaseURL string + valkeyURL string + frontendDomain string + backendDomain string + cookieSecret string + cookieName string + emailPubChannel string + encryptionSecret string + serviceID uuid.UUID + serviceName string + loggerConfig LoggerConfig + tokensConfig TokensConfig + oAuthProvidersConfig OAuthProvidersConfig + rateLimiterConfig RateLimiterConfig + openBaoConfig OpenBaoConfig + cryptoConfig CryptoConfig + distributedCache DistributedCache + kekExpirationDays int64 + dekExpirationDays int64 + jwkExpirationDays int64 + accountCCExpDays int64 + userCCExpDays int64 + appCCExpDays int64 + accountDomainVerificationHost string + appsDomainVerificationHost string + accountDomainVerificationTTL int64 + appsDomainVerificationTTL int64 } func (c *Config) Port() int64 { @@ -149,7 +153,23 @@ func (c *Config) AppCCExpDays() int64 { return c.appCCExpDays } -var variables = [45]string{ +func (c *Config) AccountDomainVerificationHost() string { + return c.accountDomainVerificationHost +} + +func (c *Config) AppsDomainVerificationHost() string { + return c.appsDomainVerificationHost +} + +func (c *Config) AccountDomainVerificationTTL() int64 { + return c.accountDomainVerificationTTL +} + +func (c *Config) AppsDomainVerificationTTL() int64 { + return c.appsDomainVerificationTTL +} + +var variables = [49]string{ "PORT", "ENV", "DEBUG", @@ -195,6 +215,10 @@ var variables = [45]string{ "APP_CLIENT_CREDENTIALS_EXPIRATION_DAYS", "OAUTH_STATE_TTL_SEC", "OAUTH_CODE_TTL_SEC", + "ACCOUNT_CREDENTIALS_DOMAIN_VERIFICATION_HOST", + "ACCOUNT_CREDENTIALS_DOMAIN_VERIFICATION_TTL_SEC", + "APPS_DOMAIN_VERIFICATION_HOST", + "APPS_DOMAIN_VERIFICATION_TTL_SEC", } var optionalVariables = [10]string{ @@ -210,7 +234,7 @@ var optionalVariables = [10]string{ "MICROSOFT_CLIENT_SECRET", } -var numerics = [29]string{ +var numerics = [31]string{ "PORT", "MAX_PROCS", "JWT_ACCESS_TTL_SEC", @@ -240,6 +264,8 @@ var numerics = [29]string{ "APP_CLIENT_CREDENTIALS_EXPIRATION_DAYS", "OAUTH_STATE_TTL_SEC", "OAUTH_CODE_TTL_SEC", + "ACCOUNT_CREDENTIALS_DOMAIN_VERIFICATION_TTL_SEC", + "APPS_DOMAIN_VERIFICATION_TTL_SEC", } func NewConfig(logger *slog.Logger, envPath string) Config { @@ -337,8 +363,12 @@ func NewConfig(logger *slog.Logger, envPath string) Config { intMap["OAUTH_STATE_TTL_SEC"], intMap["OAUTH_CODE_TTL_SEC"], ), - accountCCExpDays: intMap["ACCOUNT_CLIENT_CREDENTIALS_EXPIRATION_DAYS"], - userCCExpDays: intMap["USER_CLIENT_CREDENTIALS_EXPIRATION_DAYS"], - appCCExpDays: intMap["APP_CLIENT_CREDENTIALS_EXPIRATION_DAYS"], + accountCCExpDays: intMap["ACCOUNT_CLIENT_CREDENTIALS_EXPIRATION_DAYS"], + userCCExpDays: intMap["USER_CLIENT_CREDENTIALS_EXPIRATION_DAYS"], + appCCExpDays: intMap["APP_CLIENT_CREDENTIALS_EXPIRATION_DAYS"], + accountDomainVerificationHost: variablesMap["ACCOUNT_CREDENTIALS_DOMAIN_VERIFICATION_HOST"], + appsDomainVerificationHost: variablesMap["APPS_DOMAIN_VERIFICATION_HOST"], + accountDomainVerificationTTL: intMap["ACCOUNT_CREDENTIALS_DOMAIN_VERIFICATION_TTL_SEC"], + appsDomainVerificationTTL: intMap["APPS_DOMAIN_VERIFICATION_TTL_SEC"], } } diff --git a/idp/internal/providers/database/migrations/20241213231542_create_initial_schema.up.sql b/idp/internal/providers/database/migrations/20241213231542_create_initial_schema.up.sql index d96b79d..019dbe4 100644 --- a/idp/internal/providers/database/migrations/20241213231542_create_initial_schema.up.sql +++ b/idp/internal/providers/database/migrations/20241213231542_create_initial_schema.up.sql @@ -1,6 +1,6 @@ -- SQL dump generated using DBML (dbml.dbdiagram.io) -- Database: PostgreSQL --- Generated at: 2025-08-13T20:40:02.036Z +-- Generated at: 2025-08-16T01:54:22.889Z CREATE TYPE "kek_usage" AS ENUM ( 'global', @@ -545,6 +545,17 @@ CREATE TABLE "account_dynamic_registration_configs" ( "updated_at" timestamptz NOT NULL DEFAULT (now()) ); +CREATE TABLE "account_dynamic_registration_domains" ( + "id" serial PRIMARY KEY, + "account_id" integer NOT NULL, + "domain" varchar(250) NOT NULL, + "verification_host" varchar(50) NOT NULL, + "verification_code" text NOT NULL, + "dek_kid" varchar(22) NOT NULL, + "created_at" timestamptz NOT NULL DEFAULT (now()), + "updated_at" timestamptz NOT NULL DEFAULT (now()) +); + CREATE TABLE "app_dynamic_registration_configs" ( "id" serial PRIMARY KEY, "account_id" integer NOT NULL, @@ -833,6 +844,14 @@ CREATE UNIQUE INDEX "account_dynamic_registration_configs_account_id_uidx" ON "a CREATE INDEX "account_dynamic_registration_configs_account_public_id_idx" ON "account_dynamic_registration_configs" ("account_public_id"); +CREATE INDEX "accounts_totps_dek_kid_idx" ON "account_dynamic_registration_domains" ("dek_kid"); + +CREATE INDEX "accounts_totps_account_id_idx" ON "account_dynamic_registration_domains" ("account_id"); + +CREATE INDEX "account_dynamic_registration_domains_domain_idx" ON "account_dynamic_registration_domains" ("domain"); + +CREATE UNIQUE INDEX "account_dynamic_registration_domains_account_id_domain_uidx" ON "account_dynamic_registration_domains" ("account_id", "domain"); + CREATE INDEX "app_dynamic_registration_configs_account_id_idx" ON "app_dynamic_registration_configs" ("account_id"); CREATE INDEX "user_profiles_app_id_idx" ON "app_profiles" ("app_id"); @@ -965,6 +984,10 @@ ALTER TABLE "app_designs" ADD FOREIGN KEY ("app_id") REFERENCES "apps" ("id") ON ALTER TABLE "account_dynamic_registration_configs" ADD FOREIGN KEY ("account_id") REFERENCES "accounts" ("id") ON DELETE CASCADE; +ALTER TABLE "account_dynamic_registration_domains" ADD FOREIGN KEY ("account_id") REFERENCES "accounts" ("id") ON DELETE CASCADE; + +ALTER TABLE "account_dynamic_registration_domains" ADD FOREIGN KEY ("dek_kid") REFERENCES "data_encryption_keys" ("kid") ON DELETE CASCADE ON UPDATE CASCADE; + ALTER TABLE "app_dynamic_registration_configs" ADD FOREIGN KEY ("account_id") REFERENCES "accounts" ("id") ON DELETE CASCADE; ALTER TABLE "app_profiles" ADD FOREIGN KEY ("app_id") REFERENCES "apps" ("id") ON DELETE CASCADE; diff --git a/idp/internal/services/account_dynamic_registration_tokens.go b/idp/internal/services/account_dynamic_registration_tokens.go new file mode 100644 index 0000000..1df7cfd --- /dev/null +++ b/idp/internal/services/account_dynamic_registration_tokens.go @@ -0,0 +1,7 @@ +// Copyright (c) 2025 Afonso Barracha +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +package services From f53231a1f3bbeda90355c8aa8ad0d210cf32182e Mon Sep 17 00:00:00 2001 From: Afonso Barracha Date: Sat, 16 Aug 2025 22:28:33 +1200 Subject: [PATCH 05/23] feat(idp): add account registration domain creation --- idp/initial_schema.dbml | 63 ++++++- idp/internal/config/config.go | 24 ++- idp/internal/config/encryption.go | 14 +- idp/internal/exceptions/services.go | 8 + idp/internal/providers/crypto/encryption.go | 5 - idp/internal/providers/crypto/hmac.go | 158 ++++++++++++++++++ ...t_dynamic_registration_domain_codes.sql.go | 60 +++++++ ...ccount_dynamic_registration_domains.sql.go | 85 ++++++++++ .../database/account_hmac_secrets.sql.go | 125 ++++++++++++++ ...0241213231542_create_initial_schema.up.sql | 64 ++++++- idp/internal/providers/database/models.go | 78 +++++++++ ...ount_dynamic_registration_domain_codes.sql | 24 +++ .../account_dynamic_registration_domains.sql | 21 +++ .../database/queries/account_hmac_secrets.sql | 40 +++++ idp/internal/server/server.go | 3 + .../account_credentials_registration.go | 156 +++++++++++++++++ idp/internal/services/account_hmac_secrets.go | 157 +++++++++++++++++ idp/internal/services/deks.go | 32 ++-- ...account_credentials_registration_domain.go | 50 ++++++ idp/internal/services/helpers.go | 7 + idp/internal/services/keks.go | 34 ++-- idp/internal/services/services.go | 61 ++++--- idp/tests/common_test.go | 3 + 23 files changed, 1184 insertions(+), 88 deletions(-) create mode 100644 idp/internal/providers/crypto/hmac.go create mode 100644 idp/internal/providers/database/account_dynamic_registration_domain_codes.sql.go create mode 100644 idp/internal/providers/database/account_dynamic_registration_domains.sql.go create mode 100644 idp/internal/providers/database/account_hmac_secrets.sql.go create mode 100644 idp/internal/providers/database/queries/account_dynamic_registration_domain_codes.sql create mode 100644 idp/internal/providers/database/queries/account_dynamic_registration_domains.sql create mode 100644 idp/internal/providers/database/queries/account_hmac_secrets.sql create mode 100644 idp/internal/services/account_hmac_secrets.go create mode 100644 idp/internal/services/dtos/account_credentials_registration_domain.go diff --git a/idp/initial_schema.dbml b/idp/initial_schema.dbml index 2eb28d8..2aab25f 100644 --- a/idp/initial_schema.dbml +++ b/idp/initial_schema.dbml @@ -263,6 +263,30 @@ Table account_data_encryption_keys as ADEK { Ref: ADEK.account_id > A.id [delete: cascade] Ref: ADEK.data_encryption_key_id > DEK.id [delete: cascade] +Table account_hmac_secrets as AHS { + id serial [pk] + + account_id integer [not null] + + secret_id varchar(22) [not null] + secret text [not null] + dek_kid varchar(22) [not null] + is_revoked boolean [not null, default: false] + expires_at timestamptz [not null] + + created_at timestamptz [not null, default: `now()`] + + Indexes { + (account_id) [name: 'account_hmac_secrets_account_id_idx'] + (secret_id) [unique, name: 'account_hmac_secrets_secret_id_uidx'] + (dek_kid) [name: 'account_hmac_secrets_dek_kid_idx'] + (account_id, secret_id) [name: 'account_hmac_secrets_account_id_secret_id_idx'] + (account_id, is_revoked, expires_at) [name: 'account_hmac_secrets_account_id_is_revoked_expires_at_idx'] + } +} +Ref: AHS.account_id > A.id [delete: cascade] +Ref: AHS.dek_kid > DEK.kid [delete: cascade, update: cascade] + Table account_totps as AT { account_id integer [not null] totp_id integer [not null] @@ -891,27 +915,58 @@ Table account_dynamic_registration_configs as ADRC { } Ref: ADRC.account_id > A.id [delete: cascade] +Enum domain_verification_method { + "authorization_code" + "software_statement" + "dns_txt_record" +} + Table account_dynamic_registration_domains as ADRD { id serial [pk] account_id integer [not null] + account_public_id uuid [not null] domain varchar(250) [not null] - verification_host varchar(50) [not null] - verification_code text [not null] - expires_at timestamptz [not null] + verified_at timestamptz [null] + verification_method domain_verification_method [not null] created_at timestamptz [not null, default: `now()`] updated_at timestamptz [not null, default: `now()`] Indexes { (account_id) [name: 'accounts_totps_account_id_idx'] + (account_public_id) [name: 'account_dynamic_registration_domains_account_public_id_idx'] (domain) [name: 'account_dynamic_registration_domains_domain_idx'] - (account_id, domain) [unique, name: 'account_dynamic_registration_domains_account_id_domain_uidx'] + (account_public_id, domain) [unique, name: 'account_dynamic_registration_domains_account_public_id_domain_uidx'] } } Ref: ADRD.account_id > A.id [delete: cascade] +Table account_dynamic_registration_domain_codes as ADRDC { + id serial [pk] + + account_id integer [not null] + account_dynamic_registration_domain_id integer [not null] + + verification_host varchar(50) [not null] + verification_code text [not null] + hmac_secret_id varchar(22) [not null] + verification_prefix varchar(70) [not null] + expires_at timestamptz [not null] + + created_at timestamptz [not null, default: `now()`] + updated_at timestamptz [not null, default: `now()`] + + Indexes { + (account_id) [name: 'account_dynamic_registration_domain_codes_account_id_idx'] + (account_dynamic_registration_domain_id) [name: 'account_dynamic_registration_domain_codes_account_dynamic_registration_domain_id_idx'] + } +} +Ref: ADRDC.account_id > A.id [delete: cascade] +Ref: ADRDC.account_dynamic_registration_domain_id > ADRD.id [delete: cascade] +Ref: ADRDC.hmac_secret_id > AHS.secret_id [delete: cascade, update: cascade] + Table app_dynamic_registration_configs as APDRC { id serial [pk] diff --git a/idp/internal/config/config.go b/idp/internal/config/config.go index fde9102..423746d 100644 --- a/idp/internal/config/config.go +++ b/idp/internal/config/config.go @@ -40,6 +40,7 @@ type Config struct { kekExpirationDays int64 dekExpirationDays int64 jwkExpirationDays int64 + hmacSecretExpDays int64 accountCCExpDays int64 userCCExpDays int64 appCCExpDays int64 @@ -141,6 +142,10 @@ func (c *Config) JWKExpirationDays() int64 { return c.jwkExpirationDays } +func (c *Config) HMACSecretExpDays() int64 { + return c.hmacSecretExpDays +} + func (c *Config) AccountCCExpDays() int64 { return c.accountCCExpDays } @@ -169,7 +174,7 @@ func (c *Config) AppsDomainVerificationTTL() int64 { return c.appsDomainVerificationTTL } -var variables = [49]string{ +var variables = [45]string{ "PORT", "ENV", "DEBUG", @@ -197,14 +202,10 @@ var variables = [49]string{ "OPENBAO_ROLE_ID", "OPENBAO_SECRET_ID", "KEK_PATH", - "DEK_TTL_SEC", - "JWK_TTL_SEC", "KEK_EXPIRATION_DAYS", "DEK_EXPIRATION_DAYS", "JWK_EXPIRATION_DAYS", - "KEK_CACHE_TTL_SEC", - "DECRYPT_DEK_CACHE_TTL_SEC", - "ENCRYPT_DEK_CACHE_TTL_SEC", + "HMAC_SECRET_EXPIRATION_DAYS", "PUBLIC_JWK_CACHE_TTL_SEC", "PRIVATE_JWK_CACHE_TTL_SEC", "PUBLIC_JWKS_CACHE_TTL_SEC", @@ -234,7 +235,7 @@ var optionalVariables = [10]string{ "MICROSOFT_CLIENT_SECRET", } -var numerics = [31]string{ +var numerics = [27]string{ "PORT", "MAX_PROCS", "JWT_ACCESS_TTL_SEC", @@ -246,14 +247,10 @@ var numerics = [31]string{ "JWT_APPS_TTL_SEC", "RATE_LIMITER_MAX", "RATE_LIMITER_EXP_SEC", - "DEK_TTL_SEC", - "JWK_TTL_SEC", "KEK_EXPIRATION_DAYS", "DEK_EXPIRATION_DAYS", "JWK_EXPIRATION_DAYS", - "KEK_CACHE_TTL_SEC", - "DECRYPT_DEK_CACHE_TTL_SEC", - "ENCRYPT_DEK_CACHE_TTL_SEC", + "HMAC_SECRET_EXPIRATION_DAYS", "PUBLIC_JWK_CACHE_TTL_SEC", "PRIVATE_JWK_CACHE_TTL_SEC", "PUBLIC_JWKS_CACHE_TTL_SEC", @@ -345,12 +342,11 @@ func NewConfig(logger *slog.Logger, envPath string) Config { ), cryptoConfig: NewEncryptionConfig( variablesMap["KEK_PATH"], - intMap["DEK_TTL_SEC"], - intMap["JWK_TTL_SEC"], ), kekExpirationDays: intMap["KEK_EXPIRATION_DAYS"], dekExpirationDays: intMap["DEK_EXPIRATION_DAYS"], jwkExpirationDays: intMap["JWK_EXPIRATION_DAYS"], + hmacSecretExpDays: intMap["HMAC_SECRET_EXPIRATION_DAYS"], distributedCache: NewDistributedCache( intMap["KEK_CACHE_TTL_SEC"], intMap["DECRYPT_DEK_CACHE_TTL_SEC"], diff --git a/idp/internal/config/encryption.go b/idp/internal/config/encryption.go index 75dd3ab..54e8d9a 100644 --- a/idp/internal/config/encryption.go +++ b/idp/internal/config/encryption.go @@ -8,26 +8,14 @@ package config type CryptoConfig struct { kekPath string - dekTTL int64 - jwkTTL int64 } func (cc *CryptoConfig) KEKPath() string { return cc.kekPath } -func (cc *CryptoConfig) DEKTTL() int64 { - return cc.dekTTL -} - -func (cc *CryptoConfig) JWKTTL() int64 { - return cc.jwkTTL -} - -func NewEncryptionConfig(kekPath string, dekTTL, jwkTTL int64) CryptoConfig { +func NewEncryptionConfig(kekPath string) CryptoConfig { return CryptoConfig{ kekPath: kekPath, - dekTTL: dekTTL, - jwkTTL: jwkTTL, } } diff --git a/idp/internal/exceptions/services.go b/idp/internal/exceptions/services.go index a66ecc5..c2191b6 100644 --- a/idp/internal/exceptions/services.go +++ b/idp/internal/exceptions/services.go @@ -67,6 +67,10 @@ func NewValidationError(message string) *ServiceError { return NewError(CodeValidation, message) } +func NewNotFoundValidationError(message string) *ServiceError { + return NewError(CodeNotFound, message) +} + func NewInternalServerError() *ServiceError { return NewError(CodeInternalServerError, MessageUnknown) } @@ -87,6 +91,10 @@ func NewForbiddenError() *ServiceError { return NewError(CodeForbidden, MessageForbidden) } +func NewForbiddenValidationError(message string) *ServiceError { + return NewError(CodeForbidden, message) +} + func (e *ServiceError) Error() string { return e.Message } diff --git a/idp/internal/providers/crypto/encryption.go b/idp/internal/providers/crypto/encryption.go index 8691bda..5326841 100644 --- a/idp/internal/providers/crypto/encryption.go +++ b/idp/internal/providers/crypto/encryption.go @@ -8,7 +8,6 @@ package crypto import ( "log/slog" - "time" openbao "github.com/openbao/openbao/api/v2" @@ -23,8 +22,6 @@ type Crypto struct { opLogical *openbao.Logical serviceName string kekPath string - dekTTL time.Duration - jwkTTL time.Duration } func NewCrypto( @@ -38,7 +35,5 @@ func NewCrypto( opLogical: op.Logical(), kekPath: encCfg.KEKPath(), serviceName: utils.Capitalized(serviceName), - dekTTL: time.Duration(encCfg.DEKTTL()) * time.Second, - jwkTTL: time.Duration(encCfg.JWKTTL()) * time.Second, } } diff --git a/idp/internal/providers/crypto/hmac.go b/idp/internal/providers/crypto/hmac.go new file mode 100644 index 0000000..e969b08 --- /dev/null +++ b/idp/internal/providers/crypto/hmac.go @@ -0,0 +1,158 @@ +// Copyright (c) 2025 Afonso Barracha +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +package crypto + +import ( + "context" + "crypto/hmac" + "crypto/sha256" + "encoding/base64" + "encoding/hex" + + "github.com/tugascript/devlogs/idp/internal/exceptions" + "github.com/tugascript/devlogs/idp/internal/utils" +) + +const ( + hmacLocation string = "hmac" + + hmacSecretByteLength int = 32 +) + +func encodeHMACSecret(secret []byte) string { + return base64.StdEncoding.EncodeToString(secret) +} + +func decodeHMACSecret(secret string) ([]byte, error) { + return base64.StdEncoding.DecodeString(secret) +} + +type StoreHMACSecret = func(dekID string, secretID string, encryptedSecret string) (int32, *exceptions.ServiceError) + +type GenerateHMACSecretOptions struct { + RequestID string + StoreFN StoreHMACSecret + GetDEKfn GetDEKtoEncrypt +} + +func (e *Crypto) GenerateHMACSecret(ctx context.Context, opts GenerateHMACSecretOptions) (string, *exceptions.ServiceError) { + logger := utils.BuildLogger(e.logger, utils.LoggerOptions{ + Location: hmacLocation, + Method: "GenerateHMACSecret", + RequestID: opts.RequestID, + }) + logger.DebugContext(ctx, "Generating HMAC secret...") + + secretBytes, err := utils.GenerateRandomBytes(hmacSecretByteLength) + if err != nil { + logger.ErrorContext(ctx, "Failed to generate HMAC secret", "error", err) + return "", exceptions.NewInternalServerError() + } + secretID := utils.ExtractSecretID(secretBytes) + + dekID, encryptedSecret, serviceErr := e.EncryptWithDEK(ctx, EncryptWithDEKOptions{ + RequestID: opts.RequestID, + GetDEKfn: opts.GetDEKfn, + PlainText: encodeHMACSecret(secretBytes), + }) + if serviceErr != nil { + logger.ErrorContext(ctx, "Failed to encrypt HMAC secret", "serviceError", serviceErr) + return "", exceptions.NewInternalServerError() + } + + dbID, serviceErr := opts.StoreFN(dekID, secretID, encryptedSecret) + if serviceErr != nil { + logger.ErrorContext(ctx, "Failed to store HMAC secret", "serviceError", serviceErr) + return "", exceptions.NewInternalServerError() + } + + logger.InfoContext(ctx, "HMAC secret generated and stored successfully", "dekID", dekID, "dbID", dbID) + return secretID, nil +} + +func encodeHMACData(data []byte) string { + return hex.EncodeToString(data) +} + +// func decodeHMACData(data string) ([]byte, error) { +// return hex.DecodeString(data) +// } + +type GetHMACSecretFN = func() (string, DEKCiphertext, *exceptions.ServiceError) + +type StoreHashedData = func(secretID string, hashedData string) *exceptions.ServiceError + +type HMACSha256HashOptions struct { + RequestID string + PlainText string + GetHMACSecretFN GetHMACSecretFN + StoreHashedDataFN StoreHashedData + GetDecryptDEKfn GetDEKtoDecrypt + GetEncryptDEKfn GetDEKtoEncrypt + StoreReEncryptedHMACSecretFN StoreReEncryptedData +} + +func (e *Crypto) HMACSha256Hash(ctx context.Context, opts HMACSha256HashOptions) *exceptions.ServiceError { + logger := utils.BuildLogger(e.logger, utils.LoggerOptions{ + Location: hmacLocation, + Method: "HMACSha256Hash", + RequestID: opts.RequestID, + }) + logger.DebugContext(ctx, "Calculating HMAC SHA256...") + + secretID, dekCiphertext, serviceErr := opts.GetHMACSecretFN() + if serviceErr != nil { + logger.ErrorContext(ctx, "Failed to get HMAC secret", "serviceError", serviceErr) + return exceptions.NewInternalServerError() + } + + encodedSecret, serviceErr := e.DecryptWithDEK(ctx, DecryptWithDEKOptions{ + RequestID: opts.RequestID, + GetDecryptDEKfn: opts.GetDecryptDEKfn, + GetEncryptDEKfn: opts.GetEncryptDEKfn, + StoreReEncryptedDataFn: opts.StoreReEncryptedHMACSecretFN, + EntityID: secretID, + Ciphertext: dekCiphertext, + }) + if serviceErr != nil { + logger.ErrorContext(ctx, "Failed to decrypt HMAC secret", "serviceError", serviceErr) + return exceptions.NewInternalServerError() + } + + secret, err := decodeHMACSecret(encodedSecret) + if err != nil { + logger.ErrorContext(ctx, "Failed to decode HMAC secret", "error", err) + return exceptions.NewInternalServerError() + } + + mac := hmac.New(sha256.New, secret) + mac.Write([]byte(opts.PlainText)) + if serviceErr := opts.StoreHashedDataFN(secretID, encodeHMACData(mac.Sum(nil))); serviceErr != nil { + logger.ErrorContext(ctx, "Failed to store hashed data", "serviceError", serviceErr) + return exceptions.NewInternalServerError() + } + + return nil +} + +// type HMACSha256CompareHashOptions struct { +// RequestID string +// PlainText string +// GetHMACSecretFN GetHMACSecretFN +// GetDecryptDEKfn GetDEKtoDecrypt +// GetEncryptDEKfn GetDEKtoEncrypt +// StoreReEncryptedHMACSecretFN StoreReEncryptedData +// } + +// func (e *Crypto) HMACSha256CompareHash(ctx context.Context, opts HMACSha256CompareHashOptions) *exceptions.ServiceError { +// logger := utils.BuildLogger(e.logger, utils.LoggerOptions{ +// Location: hmacLocation, +// Method: "HMACSha256CompareHash", +// RequestID: opts.RequestID, +// }) +// logger.DebugContext(ctx, "Comparing HMAC SHA256...") +// } diff --git a/idp/internal/providers/database/account_dynamic_registration_domain_codes.sql.go b/idp/internal/providers/database/account_dynamic_registration_domain_codes.sql.go new file mode 100644 index 0000000..448d0b8 --- /dev/null +++ b/idp/internal/providers/database/account_dynamic_registration_domain_codes.sql.go @@ -0,0 +1,60 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.29.0 +// source: account_dynamic_registration_domain_codes.sql + +package database + +import ( + "context" + "time" +) + +const createAccountDynamicRegistrationDomainCode = `-- name: CreateAccountDynamicRegistrationDomainCode :exec + +INSERT INTO "account_dynamic_registration_domain_codes" ( + "account_id", + "account_dynamic_registration_domain_id", + "verification_host", + "verification_code", + "verification_prefix", + "hmac_secret_id", + "expires_at" +) VALUES ( + $1, + $2, + $3, + $4, + $5, + $6, + $7 +) +` + +type CreateAccountDynamicRegistrationDomainCodeParams struct { + AccountID int32 + AccountDynamicRegistrationDomainID int32 + VerificationHost string + VerificationCode string + VerificationPrefix string + HmacSecretID string + ExpiresAt time.Time +} + +// Copyright (c) 2025 Afonso Barracha +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. +func (q *Queries) CreateAccountDynamicRegistrationDomainCode(ctx context.Context, arg CreateAccountDynamicRegistrationDomainCodeParams) error { + _, err := q.db.Exec(ctx, createAccountDynamicRegistrationDomainCode, + arg.AccountID, + arg.AccountDynamicRegistrationDomainID, + arg.VerificationHost, + arg.VerificationCode, + arg.VerificationPrefix, + arg.HmacSecretID, + arg.ExpiresAt, + ) + return err +} diff --git a/idp/internal/providers/database/account_dynamic_registration_domains.sql.go b/idp/internal/providers/database/account_dynamic_registration_domains.sql.go new file mode 100644 index 0000000..47db56a --- /dev/null +++ b/idp/internal/providers/database/account_dynamic_registration_domains.sql.go @@ -0,0 +1,85 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.29.0 +// source: account_dynamic_registration_domains.sql + +package database + +import ( + "context" + + "github.com/google/uuid" +) + +const createAccountDynamicRegistrationDomain = `-- name: CreateAccountDynamicRegistrationDomain :one + +INSERT INTO "account_dynamic_registration_domains" ( + "account_id", + "account_public_id", + "domain", + "verification_method" +) VALUES ( + $1, + $2, + $3, + $4 +) RETURNING id, account_id, account_public_id, domain, verified_at, verification_method, created_at, updated_at +` + +type CreateAccountDynamicRegistrationDomainParams struct { + AccountID int32 + AccountPublicID uuid.UUID + Domain string + VerificationMethod DomainVerificationMethod +} + +// Copyright (c) 2025 Afonso Barracha +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. +func (q *Queries) CreateAccountDynamicRegistrationDomain(ctx context.Context, arg CreateAccountDynamicRegistrationDomainParams) (AccountDynamicRegistrationDomain, error) { + row := q.db.QueryRow(ctx, createAccountDynamicRegistrationDomain, + arg.AccountID, + arg.AccountPublicID, + arg.Domain, + arg.VerificationMethod, + ) + var i AccountDynamicRegistrationDomain + err := row.Scan( + &i.ID, + &i.AccountID, + &i.AccountPublicID, + &i.Domain, + &i.VerifiedAt, + &i.VerificationMethod, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + +const findAccountDynamicRegistrationDomainByAccountPublicIDAndDomain = `-- name: FindAccountDynamicRegistrationDomainByAccountPublicIDAndDomain :one +SELECT id, account_id, account_public_id, domain, verified_at, verification_method, created_at, updated_at FROM "account_dynamic_registration_domains" WHERE "account_public_id" = $1 AND "domain" = $2 LIMIT 1 +` + +type FindAccountDynamicRegistrationDomainByAccountPublicIDAndDomainParams struct { + AccountPublicID uuid.UUID + Domain string +} + +func (q *Queries) FindAccountDynamicRegistrationDomainByAccountPublicIDAndDomain(ctx context.Context, arg FindAccountDynamicRegistrationDomainByAccountPublicIDAndDomainParams) (AccountDynamicRegistrationDomain, error) { + row := q.db.QueryRow(ctx, findAccountDynamicRegistrationDomainByAccountPublicIDAndDomain, arg.AccountPublicID, arg.Domain) + var i AccountDynamicRegistrationDomain + err := row.Scan( + &i.ID, + &i.AccountID, + &i.AccountPublicID, + &i.Domain, + &i.VerifiedAt, + &i.VerificationMethod, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} diff --git a/idp/internal/providers/database/account_hmac_secrets.sql.go b/idp/internal/providers/database/account_hmac_secrets.sql.go new file mode 100644 index 0000000..6a63a35 --- /dev/null +++ b/idp/internal/providers/database/account_hmac_secrets.sql.go @@ -0,0 +1,125 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.29.0 +// source: account_hmac_secrets.sql + +package database + +import ( + "context" + "time" +) + +const createAccountHMACSecret = `-- name: CreateAccountHMACSecret :one + +INSERT INTO "account_hmac_secrets" ( + "account_id", + "secret_id", + "secret", + "dek_kid", + "expires_at" +) VALUES ( + $1, + $2, + $3, + $4, + $5 +) RETURNING "id" +` + +type CreateAccountHMACSecretParams struct { + AccountID int32 + SecretID string + Secret string + DekKid string + ExpiresAt time.Time +} + +// Copyright (c) 2025 Afonso Barracha +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. +func (q *Queries) CreateAccountHMACSecret(ctx context.Context, arg CreateAccountHMACSecretParams) (int32, error) { + row := q.db.QueryRow(ctx, createAccountHMACSecret, + arg.AccountID, + arg.SecretID, + arg.Secret, + arg.DekKid, + arg.ExpiresAt, + ) + var id int32 + err := row.Scan(&id) + return id, err +} + +const findAccountHMACSecretByAccountIDAndSecretID = `-- name: FindAccountHMACSecretByAccountIDAndSecretID :one +SELECT id, account_id, secret_id, secret, dek_kid, is_revoked, expires_at, created_at FROM "account_hmac_secrets" +WHERE "account_id" = $1 AND "secret_id" = $2 +LIMIT 1 +` + +type FindAccountHMACSecretByAccountIDAndSecretIDParams struct { + AccountID int32 + SecretID string +} + +func (q *Queries) FindAccountHMACSecretByAccountIDAndSecretID(ctx context.Context, arg FindAccountHMACSecretByAccountIDAndSecretIDParams) (AccountHmacSecret, error) { + row := q.db.QueryRow(ctx, findAccountHMACSecretByAccountIDAndSecretID, arg.AccountID, arg.SecretID) + var i AccountHmacSecret + err := row.Scan( + &i.ID, + &i.AccountID, + &i.SecretID, + &i.Secret, + &i.DekKid, + &i.IsRevoked, + &i.ExpiresAt, + &i.CreatedAt, + ) + return i, err +} + +const findValidHMACSecretByAccountID = `-- name: FindValidHMACSecretByAccountID :one +SELECT id, account_id, secret_id, secret, dek_kid, is_revoked, expires_at, created_at FROM "account_hmac_secrets" +WHERE + "account_id" = $1 AND + "is_revoked" = false AND + "expires_at" > now() +LIMIT 1 +` + +func (q *Queries) FindValidHMACSecretByAccountID(ctx context.Context, accountID int32) (AccountHmacSecret, error) { + row := q.db.QueryRow(ctx, findValidHMACSecretByAccountID, accountID) + var i AccountHmacSecret + err := row.Scan( + &i.ID, + &i.AccountID, + &i.SecretID, + &i.Secret, + &i.DekKid, + &i.IsRevoked, + &i.ExpiresAt, + &i.CreatedAt, + ) + return i, err +} + +const updateAccountHMACSecret = `-- name: UpdateAccountHMACSecret :exec +UPDATE "account_hmac_secrets" SET + "secret" = $2, + "dek_kid" = $3, + "updated_at" = now() +WHERE "id" = $1 +` + +type UpdateAccountHMACSecretParams struct { + ID int32 + Secret string + DekKid string +} + +func (q *Queries) UpdateAccountHMACSecret(ctx context.Context, arg UpdateAccountHMACSecretParams) error { + _, err := q.db.Exec(ctx, updateAccountHMACSecret, arg.ID, arg.Secret, arg.DekKid) + return err +} diff --git a/idp/internal/providers/database/migrations/20241213231542_create_initial_schema.up.sql b/idp/internal/providers/database/migrations/20241213231542_create_initial_schema.up.sql index 019dbe4..f648f51 100644 --- a/idp/internal/providers/database/migrations/20241213231542_create_initial_schema.up.sql +++ b/idp/internal/providers/database/migrations/20241213231542_create_initial_schema.up.sql @@ -1,6 +1,6 @@ -- SQL dump generated using DBML (dbml.dbdiagram.io) -- Database: PostgreSQL --- Generated at: 2025-08-16T01:54:22.889Z +-- Generated at: 2025-08-16T10:00:02.079Z CREATE TYPE "kek_usage" AS ENUM ( 'global', @@ -174,6 +174,12 @@ CREATE TYPE "software_statement_verification_method" AS ENUM ( 'jwks_uri' ); +CREATE TYPE "domain_verification_method" AS ENUM ( + 'authorization_code', + 'software_statement', + 'dns_txt_record' +); + CREATE TYPE "app_profile_type" AS ENUM ( 'human', 'machine', @@ -294,6 +300,17 @@ CREATE TABLE "account_data_encryption_keys" ( PRIMARY KEY ("account_id", "data_encryption_key_id") ); +CREATE TABLE "account_hmac_secrets" ( + "id" serial PRIMARY KEY, + "account_id" integer NOT NULL, + "secret_id" varchar(22) NOT NULL, + "secret" text NOT NULL, + "dek_kid" varchar(22) NOT NULL, + "is_revoked" boolean NOT NULL DEFAULT false, + "expires_at" timestamptz NOT NULL, + "created_at" timestamptz NOT NULL DEFAULT (now()) +); + CREATE TABLE "account_totps" ( "account_id" integer NOT NULL, "totp_id" integer NOT NULL, @@ -548,10 +565,23 @@ CREATE TABLE "account_dynamic_registration_configs" ( CREATE TABLE "account_dynamic_registration_domains" ( "id" serial PRIMARY KEY, "account_id" integer NOT NULL, + "account_public_id" uuid NOT NULL, "domain" varchar(250) NOT NULL, + "verified_at" timestamptz, + "verification_method" domain_verification_method NOT NULL, + "created_at" timestamptz NOT NULL DEFAULT (now()), + "updated_at" timestamptz NOT NULL DEFAULT (now()) +); + +CREATE TABLE "account_dynamic_registration_domain_codes" ( + "id" serial PRIMARY KEY, + "account_id" integer NOT NULL, + "account_dynamic_registration_domain_id" integer NOT NULL, "verification_host" varchar(50) NOT NULL, "verification_code" text NOT NULL, - "dek_kid" varchar(22) NOT NULL, + "hmac_secret_id" varchar(22) NOT NULL, + "verification_prefix" varchar(70) NOT NULL, + "expires_at" timestamptz NOT NULL, "created_at" timestamptz NOT NULL DEFAULT (now()), "updated_at" timestamptz NOT NULL DEFAULT (now()) ); @@ -674,6 +704,16 @@ CREATE UNIQUE INDEX "account_data_encryption_keys_data_encryption_key_id_uidx" O CREATE UNIQUE INDEX "account_data_encryption_keys_account_id_data_encryption_key_id_uidx" ON "account_data_encryption_keys" ("account_id", "data_encryption_key_id"); +CREATE INDEX "account_hmac_secrets_account_id_idx" ON "account_hmac_secrets" ("account_id"); + +CREATE UNIQUE INDEX "account_hmac_secrets_secret_id_uidx" ON "account_hmac_secrets" ("secret_id"); + +CREATE INDEX "account_hmac_secrets_dek_kid_idx" ON "account_hmac_secrets" ("dek_kid"); + +CREATE INDEX "account_hmac_secrets_account_id_secret_id_idx" ON "account_hmac_secrets" ("account_id", "secret_id"); + +CREATE INDEX "account_hmac_secrets_account_id_is_revoked_expires_at_idx" ON "account_hmac_secrets" ("account_id", "is_revoked", "expires_at"); + CREATE UNIQUE INDEX "accounts_totps_account_id_uidx" ON "account_totps" ("account_id"); CREATE UNIQUE INDEX "accounts_totps_totp_id_uidx" ON "account_totps" ("totp_id"); @@ -844,13 +884,17 @@ CREATE UNIQUE INDEX "account_dynamic_registration_configs_account_id_uidx" ON "a CREATE INDEX "account_dynamic_registration_configs_account_public_id_idx" ON "account_dynamic_registration_configs" ("account_public_id"); -CREATE INDEX "accounts_totps_dek_kid_idx" ON "account_dynamic_registration_domains" ("dek_kid"); - CREATE INDEX "accounts_totps_account_id_idx" ON "account_dynamic_registration_domains" ("account_id"); +CREATE INDEX "account_dynamic_registration_domains_account_public_id_idx" ON "account_dynamic_registration_domains" ("account_public_id"); + CREATE INDEX "account_dynamic_registration_domains_domain_idx" ON "account_dynamic_registration_domains" ("domain"); -CREATE UNIQUE INDEX "account_dynamic_registration_domains_account_id_domain_uidx" ON "account_dynamic_registration_domains" ("account_id", "domain"); +CREATE UNIQUE INDEX "account_dynamic_registration_domains_account_public_id_domain_uidx" ON "account_dynamic_registration_domains" ("account_public_id", "domain"); + +CREATE INDEX "account_dynamic_registration_domain_codes_account_id_idx" ON "account_dynamic_registration_domain_codes" ("account_id"); + +CREATE INDEX "account_dynamic_registration_domain_codes_account_dynamic_registration_domain_id_idx" ON "account_dynamic_registration_domain_codes" ("account_dynamic_registration_domain_id"); CREATE INDEX "app_dynamic_registration_configs_account_id_idx" ON "app_dynamic_registration_configs" ("account_id"); @@ -890,6 +934,10 @@ ALTER TABLE "account_data_encryption_keys" ADD FOREIGN KEY ("account_id") REFERE ALTER TABLE "account_data_encryption_keys" ADD FOREIGN KEY ("data_encryption_key_id") REFERENCES "data_encryption_keys" ("id") ON DELETE CASCADE; +ALTER TABLE "account_hmac_secrets" ADD FOREIGN KEY ("account_id") REFERENCES "accounts" ("id") ON DELETE CASCADE; + +ALTER TABLE "account_hmac_secrets" ADD FOREIGN KEY ("dek_kid") REFERENCES "data_encryption_keys" ("kid") ON DELETE CASCADE ON UPDATE CASCADE; + ALTER TABLE "account_totps" ADD FOREIGN KEY ("account_id") REFERENCES "accounts" ("id") ON DELETE CASCADE; ALTER TABLE "account_totps" ADD FOREIGN KEY ("totp_id") REFERENCES "totps" ("id") ON DELETE CASCADE; @@ -986,7 +1034,11 @@ ALTER TABLE "account_dynamic_registration_configs" ADD FOREIGN KEY ("account_id" ALTER TABLE "account_dynamic_registration_domains" ADD FOREIGN KEY ("account_id") REFERENCES "accounts" ("id") ON DELETE CASCADE; -ALTER TABLE "account_dynamic_registration_domains" ADD FOREIGN KEY ("dek_kid") REFERENCES "data_encryption_keys" ("kid") ON DELETE CASCADE ON UPDATE CASCADE; +ALTER TABLE "account_dynamic_registration_domain_codes" ADD FOREIGN KEY ("account_id") REFERENCES "accounts" ("id") ON DELETE CASCADE; + +ALTER TABLE "account_dynamic_registration_domain_codes" ADD FOREIGN KEY ("account_dynamic_registration_domain_id") REFERENCES "account_dynamic_registration_domains" ("id") ON DELETE CASCADE; + +ALTER TABLE "account_dynamic_registration_domain_codes" ADD FOREIGN KEY ("hmac_secret_id") REFERENCES "account_hmac_secrets" ("secret_id") ON DELETE CASCADE ON UPDATE CASCADE; ALTER TABLE "app_dynamic_registration_configs" ADD FOREIGN KEY ("account_id") REFERENCES "accounts" ("id") ON DELETE CASCADE; diff --git a/idp/internal/providers/database/models.go b/idp/internal/providers/database/models.go index f52c300..398c09a 100644 --- a/idp/internal/providers/database/models.go +++ b/idp/internal/providers/database/models.go @@ -518,6 +518,49 @@ func (ns NullDekUsage) Value() (driver.Value, error) { return string(ns.DekUsage), nil } +type DomainVerificationMethod string + +const ( + DomainVerificationMethodAuthorizationCode DomainVerificationMethod = "authorization_code" + DomainVerificationMethodSoftwareStatement DomainVerificationMethod = "software_statement" + DomainVerificationMethodDnsTxtRecord DomainVerificationMethod = "dns_txt_record" +) + +func (e *DomainVerificationMethod) Scan(src interface{}) error { + switch s := src.(type) { + case []byte: + *e = DomainVerificationMethod(s) + case string: + *e = DomainVerificationMethod(s) + default: + return fmt.Errorf("unsupported scan type for DomainVerificationMethod: %T", src) + } + return nil +} + +type NullDomainVerificationMethod struct { + DomainVerificationMethod DomainVerificationMethod + Valid bool // Valid is true if DomainVerificationMethod is not NULL +} + +// Scan implements the Scanner interface. +func (ns *NullDomainVerificationMethod) Scan(value interface{}) error { + if value == nil { + ns.DomainVerificationMethod, ns.Valid = "", false + return nil + } + ns.Valid = true + return ns.DomainVerificationMethod.Scan(value) +} + +// Value implements the driver Valuer interface. +func (ns NullDomainVerificationMethod) Value() (driver.Value, error) { + if !ns.Valid { + return nil, nil + } + return string(ns.DomainVerificationMethod), nil +} + type GrantType string const ( @@ -1211,6 +1254,41 @@ type AccountDynamicRegistrationConfig struct { UpdatedAt time.Time } +type AccountDynamicRegistrationDomain struct { + ID int32 + AccountID int32 + AccountPublicID uuid.UUID + Domain string + VerifiedAt pgtype.Timestamptz + VerificationMethod DomainVerificationMethod + CreatedAt time.Time + UpdatedAt time.Time +} + +type AccountDynamicRegistrationDomainCode struct { + ID int32 + AccountID int32 + AccountDynamicRegistrationDomainID int32 + VerificationHost string + VerificationCode string + HmacSecretID string + VerificationPrefix string + ExpiresAt time.Time + CreatedAt time.Time + UpdatedAt time.Time +} + +type AccountHmacSecret struct { + ID int32 + AccountID int32 + SecretID string + Secret string + DekKid string + IsRevoked bool + ExpiresAt time.Time + CreatedAt time.Time +} + type AccountKeyEncryptionKey struct { AccountID int32 KeyEncryptionKeyID int32 diff --git a/idp/internal/providers/database/queries/account_dynamic_registration_domain_codes.sql b/idp/internal/providers/database/queries/account_dynamic_registration_domain_codes.sql new file mode 100644 index 0000000..9c99bdf --- /dev/null +++ b/idp/internal/providers/database/queries/account_dynamic_registration_domain_codes.sql @@ -0,0 +1,24 @@ +-- Copyright (c) 2025 Afonso Barracha +-- +-- This Source Code Form is subject to the terms of the Mozilla Public +-- License, v. 2.0. If a copy of the MPL was not distributed with this +-- file, You can obtain one at https://mozilla.org/MPL/2.0/. + +-- name: CreateAccountDynamicRegistrationDomainCode :exec +INSERT INTO "account_dynamic_registration_domain_codes" ( + "account_id", + "account_dynamic_registration_domain_id", + "verification_host", + "verification_code", + "verification_prefix", + "hmac_secret_id", + "expires_at" +) VALUES ( + $1, + $2, + $3, + $4, + $5, + $6, + $7 +); \ No newline at end of file diff --git a/idp/internal/providers/database/queries/account_dynamic_registration_domains.sql b/idp/internal/providers/database/queries/account_dynamic_registration_domains.sql new file mode 100644 index 0000000..65895af --- /dev/null +++ b/idp/internal/providers/database/queries/account_dynamic_registration_domains.sql @@ -0,0 +1,21 @@ +-- Copyright (c) 2025 Afonso Barracha +-- +-- This Source Code Form is subject to the terms of the Mozilla Public +-- License, v. 2.0. If a copy of the MPL was not distributed with this +-- file, You can obtain one at https://mozilla.org/MPL/2.0/. + +-- name: CreateAccountDynamicRegistrationDomain :one +INSERT INTO "account_dynamic_registration_domains" ( + "account_id", + "account_public_id", + "domain", + "verification_method" +) VALUES ( + $1, + $2, + $3, + $4 +) RETURNING *; + +-- name: FindAccountDynamicRegistrationDomainByAccountPublicIDAndDomain :one +SELECT * FROM "account_dynamic_registration_domains" WHERE "account_public_id" = $1 AND "domain" = $2 LIMIT 1; \ No newline at end of file diff --git a/idp/internal/providers/database/queries/account_hmac_secrets.sql b/idp/internal/providers/database/queries/account_hmac_secrets.sql new file mode 100644 index 0000000..192e8ad --- /dev/null +++ b/idp/internal/providers/database/queries/account_hmac_secrets.sql @@ -0,0 +1,40 @@ +-- Copyright (c) 2025 Afonso Barracha +-- +-- This Source Code Form is subject to the terms of the Mozilla Public +-- License, v. 2.0. If a copy of the MPL was not distributed with this +-- file, You can obtain one at https://mozilla.org/MPL/2.0/. + +-- name: CreateAccountHMACSecret :one +INSERT INTO "account_hmac_secrets" ( + "account_id", + "secret_id", + "secret", + "dek_kid", + "expires_at" +) VALUES ( + $1, + $2, + $3, + $4, + $5 +) RETURNING "id"; + +-- name: UpdateAccountHMACSecret :exec +UPDATE "account_hmac_secrets" SET + "secret" = $2, + "dek_kid" = $3, + "updated_at" = now() +WHERE "id" = $1; + +-- name: FindAccountHMACSecretByAccountIDAndSecretID :one +SELECT * FROM "account_hmac_secrets" +WHERE "account_id" = $1 AND "secret_id" = $2 +LIMIT 1; + +-- name: FindValidHMACSecretByAccountID :one +SELECT * FROM "account_hmac_secrets" +WHERE + "account_id" = $1 AND + "is_revoked" = false AND + "expires_at" > now() +LIMIT 1; \ No newline at end of file diff --git a/idp/internal/server/server.go b/idp/internal/server/server.go index c2f8fd2..50022a0 100644 --- a/idp/internal/server/server.go +++ b/idp/internal/server/server.go @@ -283,6 +283,9 @@ func New( cfg.AccountCCExpDays(), cfg.UserCCExpDays(), cfg.AppCCExpDays(), + cfg.HMACSecretExpDays(), + cfg.AccountDomainVerificationHost(), + cfg.AccountDomainVerificationTTL(), ) logger.InfoContext(ctx, "Finished building services") diff --git a/idp/internal/services/account_credentials_registration.go b/idp/internal/services/account_credentials_registration.go index 1df7cfd..3d57cd8 100644 --- a/idp/internal/services/account_credentials_registration.go +++ b/idp/internal/services/account_credentials_registration.go @@ -5,3 +5,159 @@ // file, You can obtain one at https://mozilla.org/MPL/2.0/. package services + +import ( + "context" + "fmt" + "slices" + "time" + + "github.com/google/uuid" + + "github.com/tugascript/devlogs/idp/internal/exceptions" + "github.com/tugascript/devlogs/idp/internal/providers/crypto" + "github.com/tugascript/devlogs/idp/internal/providers/database" + "github.com/tugascript/devlogs/idp/internal/services/dtos" + "github.com/tugascript/devlogs/idp/internal/utils" +) + +const ( + accountCredentialsRegistrationDomainLocation string = "account_credentials_registration_domain" + + domainCodeByteLength int = 32 +) + +type CreateAccountCredentialsRegistrationDomainOptions struct { + RequestID string + AccountPublicID uuid.UUID + AccountVersion int32 + Domain string +} + +func (s *Services) CreateAccountCredentialsRegistrationDomain( + ctx context.Context, + opts CreateAccountCredentialsRegistrationDomainOptions, +) (dtos.AccountCredentialsRegistrationDomainDTO, *exceptions.ServiceError) { + logger := s.buildLogger(opts.RequestID, accountCredentialsRegistrationDomainLocation, "CreateAccountCredentialsRegistrationDomain").With( + "accountPublicID", opts.AccountPublicID, + "domain", opts.Domain, + ) + logger.InfoContext(ctx, "Creating account credentials registration domain...") + + dynamicRegistrationConfig, serviceErr := s.GetAccountDynamicRegistrationConfig(ctx, GetAccountDynamicRegistrationConfigOptions{ + RequestID: opts.RequestID, + AccountPublicID: opts.AccountPublicID, + }) + if serviceErr != nil { + if serviceErr.Code != exceptions.CodeNotFound { + logger.WarnContext(ctx, "Account dynamic registration config not found", "serviceError", serviceErr) + return dtos.AccountCredentialsRegistrationDomainDTO{}, exceptions.NewNotFoundValidationError("Dynamic registration config not found") + } + return dtos.AccountCredentialsRegistrationDomainDTO{}, serviceErr + } + if len(dynamicRegistrationConfig.WhitelistedDomains) > 0 && !slices.Contains(dynamicRegistrationConfig.WhitelistedDomains, opts.Domain) { + logger.WarnContext(ctx, "Domain is not whitelisted", "domain", opts.Domain) + return dtos.AccountCredentialsRegistrationDomainDTO{}, exceptions.NewForbiddenValidationError("Domain is not whitelisted") + } + + if _, err := s.database.FindAccountDynamicRegistrationDomainByAccountPublicIDAndDomain(ctx, database.FindAccountDynamicRegistrationDomainByAccountPublicIDAndDomainParams{ + AccountPublicID: opts.AccountPublicID, + Domain: opts.Domain, + }); err != nil { + serviceErr := exceptions.FromDBError(err) + if serviceErr.Code != exceptions.CodeNotFound { + logger.WarnContext(ctx, "Failed to find account dynamic registration domain", "error", err) + return dtos.AccountCredentialsRegistrationDomainDTO{}, serviceErr + } + } else { + logger.InfoContext(ctx, "Account dynamic registration domain already exists", "domain", opts.Domain) + return dtos.AccountCredentialsRegistrationDomainDTO{}, exceptions.NewConflictError("Account credentials registration domain already exists") + } + + accountDTO, serviceErr := s.GetAccountByPublicIDAndVersion(ctx, GetAccountByPublicIDAndVersionOptions{ + RequestID: opts.RequestID, + PublicID: opts.AccountPublicID, + Version: opts.AccountVersion, + }) + if serviceErr != nil { + logger.WarnContext(ctx, "Failed to get account ID", "serviceError", serviceErr) + return dtos.AccountCredentialsRegistrationDomainDTO{}, serviceErr + } + + qrs, txn, err := s.database.BeginTx(ctx) + if err != nil { + logger.ErrorContext(ctx, "Failed to start transaction", "error", err) + return dtos.AccountCredentialsRegistrationDomainDTO{}, exceptions.FromDBError(err) + } + defer func() { + logger.DebugContext(ctx, "Finalizing transaction") + s.database.FinalizeTx(ctx, txn, err, serviceErr) + }() + + domain, err := qrs.CreateAccountDynamicRegistrationDomain(ctx, database.CreateAccountDynamicRegistrationDomainParams{ + AccountID: accountDTO.ID(), + AccountPublicID: opts.AccountPublicID, + Domain: opts.Domain, + VerificationMethod: database.DomainVerificationMethodDnsTxtRecord, + }) + if err != nil { + logger.ErrorContext(ctx, "Failed to create account dynamic registration domain", "error", err) + serviceErr = exceptions.FromDBError(err) + return dtos.AccountCredentialsRegistrationDomainDTO{}, serviceErr + } + + code, err := utils.GenerateBase64Secret(domainCodeByteLength) + if err != nil { + logger.ErrorContext(ctx, "Failed to generate domain code", "error", err) + serviceErr = exceptions.NewInternalServerError() + return dtos.AccountCredentialsRegistrationDomainDTO{}, serviceErr + } + + verificationPrefix := fmt.Sprintf("%s-verification", accountDTO.Username) + exp := time.Now().Add(s.accountDomainVerificationTTL) + if serviceErr = s.crypto.HMACSha256Hash(ctx, crypto.HMACSha256HashOptions{ + RequestID: opts.RequestID, + PlainText: code, + GetDecryptDEKfn: s.BuildGetDecAccountDEKFn(ctx, BuildGetDecAccountDEKFnOptions{ + RequestID: opts.RequestID, + AccountID: accountDTO.ID(), + Queries: qrs, + }), + GetEncryptDEKfn: s.BuildGetEncAccountDEKfn(ctx, BuildGetEncAccountDEKOptions{ + RequestID: opts.RequestID, + AccountID: accountDTO.ID(), + Queries: qrs, + }), + GetHMACSecretFN: s.BuildGetHMACSecretFN(ctx, BuildGetHMACSecretFNOptions{ + RequestID: opts.RequestID, + AccountID: accountDTO.ID(), + Queries: qrs, + }), + StoreReEncryptedHMACSecretFN: s.BuildUpdateHMACSecretFN(ctx, BuildUpdateHMACSecretFNOptions{ + RequestID: opts.RequestID, + AccountID: accountDTO.ID(), + Queries: qrs, + }), + StoreHashedDataFN: func(secretID string, hashedData string) *exceptions.ServiceError { + if err := qrs.CreateAccountDynamicRegistrationDomainCode(ctx, database.CreateAccountDynamicRegistrationDomainCodeParams{ + AccountID: accountDTO.ID(), + AccountDynamicRegistrationDomainID: domain.ID, + VerificationCode: hashedData, + VerificationPrefix: verificationPrefix, + VerificationHost: s.accountDomainVerificationHost, + HmacSecretID: secretID, + ExpiresAt: exp, + }); err != nil { + logger.ErrorContext(ctx, "Failed to create account dynamic registration domain code", "error", err) + return exceptions.FromDBError(err) + } + return nil + }, + }); serviceErr != nil { + logger.ErrorContext(ctx, "Failed to hash code", "serviceError", serviceErr) + return dtos.AccountCredentialsRegistrationDomainDTO{}, serviceErr + } + + logger.InfoContext(ctx, "Created account dynamic registration domain successfully") + return dtos.MapAccountCredentialsRegistrationDomainToDTOWithCode(&domain, s.accountDomainVerificationHost, verificationPrefix, code, exp), nil +} diff --git a/idp/internal/services/account_hmac_secrets.go b/idp/internal/services/account_hmac_secrets.go new file mode 100644 index 0000000..5b1be4a --- /dev/null +++ b/idp/internal/services/account_hmac_secrets.go @@ -0,0 +1,157 @@ +// Copyright (c) 2025 Afonso Barracha +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +package services + +import ( + "context" + "time" + + "github.com/jackc/pgx/v5" + "github.com/tugascript/devlogs/idp/internal/exceptions" + "github.com/tugascript/devlogs/idp/internal/providers/crypto" + "github.com/tugascript/devlogs/idp/internal/providers/database" +) + +const accountHMACSecretsLocation = "account_hmac_secrets" + +type buildStoreAccountHMACSecretOptions struct { + requestID string + accountID int32 + data map[string]string + queries *database.Queries +} + +func (s *Services) buildStoreAccountHMACSecretFn( + ctx context.Context, + opts buildStoreAccountHMACSecretOptions, +) crypto.StoreHMACSecret { + logger := s.buildLogger(opts.requestID, accountHMACSecretsLocation, "buildStoreAccountHMACSecretFn") + logger.InfoContext(ctx, "Building store function for account HMAC secret...") + + return func(dekID string, secretID string, encryptedSecret string) (int32, *exceptions.ServiceError) { + var qrs *database.Queries + var txn pgx.Tx + var err error + var serviceErr *exceptions.ServiceError + if opts.queries != nil { + qrs = opts.queries + } else { + qrs, txn, err = s.database.BeginTx(ctx) + if err != nil { + logger.ErrorContext(ctx, "Failed to start transaction", "error", err) + return 0, exceptions.FromDBError(err) + } + defer func() { + logger.DebugContext(ctx, "Finalizing transaction") + s.database.FinalizeTx(ctx, txn, err, serviceErr) + }() + } + + id, err := qrs.CreateAccountHMACSecret(ctx, database.CreateAccountHMACSecretParams{ + AccountID: opts.accountID, + SecretID: secretID, + Secret: encryptedSecret, + DekKid: dekID, + ExpiresAt: time.Now().Add(s.hmacSecretExpDays), + }) + if err != nil { + logger.ErrorContext(ctx, "Failed to create account HMAC secret", "error", err) + serviceErr = exceptions.FromDBError(err) + return 0, serviceErr + } + + opts.data["secretID"] = secretID + opts.data["encryptedSecret"] = encryptedSecret + logger.InfoContext(ctx, "Created account HMAC secret", "id", id) + return id, nil + } +} + +type BuildGetHMACSecretFNOptions struct { + RequestID string + AccountID int32 + Queries *database.Queries +} + +func (s *Services) BuildGetHMACSecretFN( + ctx context.Context, + opts BuildGetHMACSecretFNOptions, +) crypto.GetHMACSecretFN { + logger := s.buildLogger(opts.RequestID, accountHMACSecretsLocation, "BuildGetHMACSecretFN") + logger.InfoContext(ctx, "Building get HMAC secret function...") + + return func() (string, crypto.DEKCiphertext, *exceptions.ServiceError) { + logger.InfoContext(ctx, "Getting HMAC secret...") + + secret, err := s.mapQueries(opts.Queries).FindValidHMACSecretByAccountID(ctx, opts.AccountID) + if err != nil { + serviceErr := exceptions.FromDBError(err) + if serviceErr.Code != exceptions.CodeNotFound { + logger.ErrorContext(ctx, "Failed to find account HMAC secret", "error", err) + return "", "", serviceErr + } + + data := make(map[string]string) + if _, serviceErr := s.crypto.GenerateHMACSecret(ctx, crypto.GenerateHMACSecretOptions{ + RequestID: opts.RequestID, + StoreFN: s.buildStoreAccountHMACSecretFn(ctx, buildStoreAccountHMACSecretOptions{ + requestID: opts.RequestID, + accountID: opts.AccountID, + queries: opts.Queries, + data: data, + }), + }); serviceErr != nil { + logger.ErrorContext(ctx, "Failed to generate account HMAC secret", "serviceError", serviceErr) + return "", "", serviceErr + } + + return data["secretID"], data["encryptedSecret"], nil + } + + return secret.SecretID, secret.Secret, nil + } +} + +type BuildUpdateHMACSecretFNOptions struct { + RequestID string + AccountID int32 + Queries *database.Queries +} + +func (s *Services) BuildUpdateHMACSecretFN( + ctx context.Context, + opts BuildUpdateHMACSecretFNOptions, +) crypto.StoreReEncryptedData { + logger := s.buildLogger(opts.RequestID, accountHMACSecretsLocation, "BuildUpdateHMACSecretFN") + logger.InfoContext(ctx, "Building update HMAC secret function...") + + return func(secretID crypto.EntityID, dekID crypto.DEKID, encPrivKey crypto.DEKCiphertext) *exceptions.ServiceError { + logger.InfoContext(ctx, "Updating HMAC secret...") + + qrs := s.mapQueries(opts.Queries) + secret, err := qrs.FindAccountHMACSecretByAccountIDAndSecretID(ctx, database.FindAccountHMACSecretByAccountIDAndSecretIDParams{ + AccountID: opts.AccountID, + SecretID: secretID, + }) + if err != nil { + logger.ErrorContext(ctx, "Failed to find account HMAC secret", "error", err) + return exceptions.FromDBError(err) + } + + if err := qrs.UpdateAccountHMACSecret(ctx, database.UpdateAccountHMACSecretParams{ + ID: secret.ID, + Secret: encPrivKey, + DekKid: dekID, + }); err != nil { + logger.ErrorContext(ctx, "Failed to update account HMAC secret", "error", err) + return exceptions.FromDBError(err) + } + + logger.InfoContext(ctx, "Updated HMAC secret successfully") + return nil + } +} diff --git a/idp/internal/services/deks.go b/idp/internal/services/deks.go index 44ac62a..93574d6 100644 --- a/idp/internal/services/deks.go +++ b/idp/internal/services/deks.go @@ -12,6 +12,7 @@ import ( "time" "github.com/google/uuid" + "github.com/jackc/pgx/v5" "github.com/tugascript/devlogs/idp/internal/exceptions" "github.com/tugascript/devlogs/idp/internal/providers/cache" @@ -201,6 +202,7 @@ type buildStoreAccountDEKOptions struct { requestID string accountID int32 data map[string]string + queries *database.Queries } func (s *Services) buildStoreAccountDEKfn( @@ -211,16 +213,23 @@ func (s *Services) buildStoreAccountDEKfn( logger.InfoContext(ctx, "Building store function for account DEK...") return func(dekID string, encryptedDEK string, kekID uuid.UUID) (int32, *exceptions.ServiceError) { + var qrs *database.Queries + var txn pgx.Tx + var err error var serviceErr *exceptions.ServiceError - qrs, txn, err := s.database.BeginTx(ctx) - if err != nil { - logger.ErrorContext(ctx, "Failed to start transaction", "error", err) - return 0, exceptions.FromDBError(err) + if opts.queries != nil { + qrs = opts.queries + } else { + qrs, txn, err = s.database.BeginTx(ctx) + if err != nil { + logger.ErrorContext(ctx, "Failed to start transaction", "error", err) + return 0, exceptions.FromDBError(err) + } + defer func() { + logger.DebugContext(ctx, "Finalizing transaction") + s.database.FinalizeTx(ctx, txn, err, serviceErr) + }() } - defer func() { - logger.DebugContext(ctx, "Finalizing transaction") - s.database.FinalizeTx(ctx, txn, err, serviceErr) - }() dekEnt, err := qrs.CreateDataEncryptionKey(ctx, database.CreateDataEncryptionKeyParams{ Kid: dekID, @@ -266,6 +275,7 @@ func (s *Services) buildStoreAccountDEKfn( type BuildGetEncAccountDEKOptions struct { RequestID string AccountID int32 + Queries *database.Queries } func (s *Services) BuildGetEncAccountDEKfn( @@ -294,7 +304,8 @@ func (s *Services) BuildGetEncAccountDEKfn( } logger.InfoContext(ctx, "DEK not found in cache, checking database...") - dekEnt, err := s.database.FindAccountDataEncryptionKeyByAccountID( + qrs := s.mapQueries(opts.Queries) + dekEnt, err := qrs.FindAccountDataEncryptionKeyByAccountID( ctx, database.FindAccountDataEncryptionKeyByAccountIDParams{ AccountID: opts.AccountID, @@ -366,6 +377,7 @@ func (s *Services) BuildGetEncAccountDEKfn( type BuildGetDecAccountDEKFnOptions struct { RequestID string AccountID int32 + Queries *database.Queries } func (s *Services) BuildGetDecAccountDEKFn( @@ -397,7 +409,7 @@ func (s *Services) BuildGetDecAccountDEKFn( } logger.InfoContext(ctx, "DEK not found in cache, checking database...") - dekEnt, err := s.database.FindAccountDataEncryptionKeyByAccountIDAndKID( + dekEnt, err := s.mapQueries(opts.Queries).FindAccountDataEncryptionKeyByAccountIDAndKID( ctx, database.FindAccountDataEncryptionKeyByAccountIDAndKIDParams{ AccountID: opts.AccountID, diff --git a/idp/internal/services/dtos/account_credentials_registration_domain.go b/idp/internal/services/dtos/account_credentials_registration_domain.go new file mode 100644 index 0000000..3f27857 --- /dev/null +++ b/idp/internal/services/dtos/account_credentials_registration_domain.go @@ -0,0 +1,50 @@ +// Copyright (c) 2025 Afonso Barracha +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +package dtos + +import ( + "fmt" + "time" + + "github.com/tugascript/devlogs/idp/internal/providers/database" +) + +type AccountCredentialsRegistrationDomainDTO struct { + id int32 + + Domain string `json:"domain"` + Verified bool `json:"verified"` + + VerificationHost string `json:"verification_host,omitempty"` + VerificationPrefix string `json:"verification_prefix,omitempty"` + VerificationCode string `json:"verification_code,omitempty"` + VerificationValue string `json:"verification_value,omitempty"` + VerificationCodeExpiresAt int64 `json:"verification_code_expires_at,omitempty"` +} + +func (a *AccountCredentialsRegistrationDomainDTO) ID() int32 { + return a.id +} + +func MapAccountCredentialsRegistrationDomainToDTOWithCode( + domain *database.AccountDynamicRegistrationDomain, + verificationHost string, + verificationPrefix string, + verificationCode string, + expiresAt time.Time, +) AccountCredentialsRegistrationDomainDTO { + return AccountCredentialsRegistrationDomainDTO{ + id: domain.ID, + Domain: domain.Domain, + VerificationHost: verificationHost, + VerificationPrefix: verificationPrefix, + VerificationCode: verificationCode, + VerificationValue: fmt.Sprintf("%s=%s", verificationPrefix, verificationCode), + VerificationCodeExpiresAt: expiresAt.Unix(), + Verified: false, + } +} diff --git a/idp/internal/services/helpers.go b/idp/internal/services/helpers.go index f5f1b7b..1fa8be5 100644 --- a/idp/internal/services/helpers.go +++ b/idp/internal/services/helpers.go @@ -56,6 +56,13 @@ func (s *Services) buildLogger(requestID, location, function string) *slog.Logge }) } +func (s *Services) mapQueries(qrs *database.Queries) *database.Queries { + if qrs != nil { + return qrs + } + return s.database.Queries +} + func extractAuthHeaderToken(ah string) (string, *exceptions.ServiceError) { if ah == "" { return "", exceptions.NewUnauthorizedError() diff --git a/idp/internal/services/keks.go b/idp/internal/services/keks.go index 87665ef..c8a6470 100644 --- a/idp/internal/services/keks.go +++ b/idp/internal/services/keks.go @@ -12,6 +12,7 @@ import ( "time" "github.com/google/uuid" + "github.com/jackc/pgx/v5" "github.com/tugascript/devlogs/idp/internal/exceptions" "github.com/tugascript/devlogs/idp/internal/providers/cache" @@ -152,6 +153,7 @@ func (s *Services) GetOrCreateGlobalKEK( type createAndCacheAccountKEKOptions struct { requestID string accountID int32 + queries *database.Queries } func (s *Services) createAndCacheAccountKEK( @@ -161,16 +163,23 @@ func (s *Services) createAndCacheAccountKEK( logger := s.buildLogger(opts.requestID, keksLocation, "createAndCacheAccountKEK") logger.InfoContext(ctx, "Creating and caching account KEK...") + var qrs *database.Queries + var txn pgx.Tx + var err error var serviceErr *exceptions.ServiceError - qrs, txn, err := s.database.BeginTx(ctx) - if err != nil { - logger.ErrorContext(ctx, "Failed to start transaction", "error", err) - return uuid.Nil, exceptions.FromDBError(err) + if opts.queries != nil { + qrs = opts.queries + } else { + qrs, txn, err = s.database.BeginTx(ctx) + if err != nil { + logger.ErrorContext(ctx, "Failed to start transaction", "error", err) + return uuid.Nil, exceptions.FromDBError(err) + } + defer func() { + logger.DebugContext(ctx, "Finalizing transaction") + s.database.FinalizeTx(ctx, txn, err, serviceErr) + }() } - defer func() { - logger.DebugContext(ctx, "Finalizing transaction") - s.database.FinalizeTx(ctx, txn, err, serviceErr) - }() dbID, keyID, err := s.crypto.GenerateKEK(ctx, crypto.GenerateKEKOptions{ RequestID: opts.requestID, @@ -213,6 +222,7 @@ func (s *Services) createAndCacheAccountKEK( type getAndCacheAccountKEKOptions struct { requestID string accountID int32 + queries *database.Queries } func (s *Services) getAndCacheAccountKEK( @@ -236,7 +246,8 @@ func (s *Services) getAndCacheAccountKEK( return kek, nil } - kekEntity, err := s.database.FindAccountKeyEncryptionKeyByAccountID(ctx, opts.accountID) + qrs := s.mapQueries(opts.queries) + kekEntity, err := qrs.FindAccountKeyEncryptionKeyByAccountID(ctx, opts.accountID) if err != nil { serviceErr := exceptions.FromDBError(err) if serviceErr.Code != exceptions.CodeNotFound { @@ -267,7 +278,7 @@ func (s *Services) getAndCacheAccountKEK( RequestID: opts.requestID, KEKid: kekEntity.Kid, StoreFN: func(_ uuid.UUID) (int32, error) { - return s.database.RotateKeyEncryptionKey(ctx, database.RotateKeyEncryptionKeyParams{ + return qrs.RotateKeyEncryptionKey(ctx, database.RotateKeyEncryptionKeyParams{ ID: kekEntity.ID, NextRotationAt: time.Now().Add(s.kekExpDays), }) @@ -292,6 +303,7 @@ func (s *Services) getAndCacheAccountKEK( type GetOrCreateAccountKEKOptions struct { RequestID string AccountID int32 + Queries *database.Queries } func (s *Services) GetOrCreateAccountKEK( @@ -304,6 +316,7 @@ func (s *Services) GetOrCreateAccountKEK( kek, serviceErr := s.getAndCacheAccountKEK(ctx, getAndCacheAccountKEKOptions{ requestID: opts.RequestID, accountID: opts.AccountID, + queries: opts.Queries, }) if serviceErr != nil { if serviceErr.Code == exceptions.CodeNotFound { @@ -311,6 +324,7 @@ func (s *Services) GetOrCreateAccountKEK( return s.createAndCacheAccountKEK(ctx, createAndCacheAccountKEKOptions{ requestID: opts.RequestID, accountID: opts.AccountID, + queries: opts.Queries, }) } diff --git a/idp/internal/services/services.go b/idp/internal/services/services.go index a70ac1b..d955264 100644 --- a/idp/internal/services/services.go +++ b/idp/internal/services/services.go @@ -20,19 +20,22 @@ import ( ) type Services struct { - logger *slog.Logger - database *database.Database - cache *cache.Cache - mail *mailer.EmailPublisher - jwt *tokens.Tokens - crypto *crypto.Crypto - oauthProviders *oauth.Providers - kekExpDays time.Duration - dekExpDays time.Duration - jwkExpDays time.Duration - accountCCExpDays time.Duration - appCCExpDays time.Duration - userCCExpDays time.Duration + logger *slog.Logger + database *database.Database + cache *cache.Cache + mail *mailer.EmailPublisher + jwt *tokens.Tokens + crypto *crypto.Crypto + oauthProviders *oauth.Providers + kekExpDays time.Duration + dekExpDays time.Duration + jwkExpDays time.Duration + accountCCExpDays time.Duration + appCCExpDays time.Duration + userCCExpDays time.Duration + hmacSecretExpDays time.Duration + accountDomainVerificationHost string + accountDomainVerificationTTL time.Duration } func NewServices( @@ -49,20 +52,26 @@ func NewServices( accountCCExpDays int64, appCCExpDays int64, userCCExpDays int64, + hmacSecretExpDays int64, + accountDomainVerificationHost string, + accountDomainVerificationTTL int64, ) *Services { return &Services{ - logger: logger.With(utils.BaseLayer, utils.ServicesLogLayer), - database: database, - cache: cache, - mail: mail, - jwt: jwt, - crypto: encrypt, - oauthProviders: oauthProv, - kekExpDays: utils.ToDaysDuration(kekExpDays), - dekExpDays: utils.ToDaysDuration(dekExpDays), - jwkExpDays: utils.ToDaysDuration(jwkExpDays), - accountCCExpDays: utils.ToDaysDuration(accountCCExpDays), - appCCExpDays: utils.ToDaysDuration(appCCExpDays), - userCCExpDays: utils.ToDaysDuration(userCCExpDays), + logger: logger.With(utils.BaseLayer, utils.ServicesLogLayer), + database: database, + cache: cache, + mail: mail, + jwt: jwt, + crypto: encrypt, + oauthProviders: oauthProv, + kekExpDays: utils.ToDaysDuration(kekExpDays), + dekExpDays: utils.ToDaysDuration(dekExpDays), + jwkExpDays: utils.ToDaysDuration(jwkExpDays), + accountCCExpDays: utils.ToDaysDuration(accountCCExpDays), + appCCExpDays: utils.ToDaysDuration(appCCExpDays), + userCCExpDays: utils.ToDaysDuration(userCCExpDays), + hmacSecretExpDays: utils.ToDaysDuration(hmacSecretExpDays), + accountDomainVerificationHost: accountDomainVerificationHost, + accountDomainVerificationTTL: utils.ToSecondsDuration(accountDomainVerificationTTL), } } diff --git a/idp/tests/common_test.go b/idp/tests/common_test.go index d8e736f..8edc2d8 100644 --- a/idp/tests/common_test.go +++ b/idp/tests/common_test.go @@ -205,6 +205,9 @@ func initTestServicesAndApp(t *testing.T) { cfg.AccountCCExpDays(), cfg.AppCCExpDays(), cfg.UserCCExpDays(), + cfg.HMACSecretExpDays(), + cfg.AccountDomainVerificationHost(), + cfg.AccountDomainVerificationTTL(), ) _testServer = server.New(ctx, logger, *_testConfig) From 93b7abe4ed0a225f9aff47db9f66f185c9e8214a Mon Sep 17 00:00:00 2001 From: Afonso Barracha Date: Sun, 17 Aug 2025 21:34:33 +1200 Subject: [PATCH 06/23] feat(idp): add dynamic registration credentials domain configuration --- idp/go.mod | 15 +- idp/go.sum | 134 ++- idp/initial_schema.dbml | 29 +- ...ccount_credentials_registration_domains.go | 303 +++++++ .../account_dynamic_registration_configs.go | 30 + .../bodies/dynamic_registration_domains.go | 11 + idp/internal/controllers/middleware.go | 4 +- .../params/dynamic_registration_domains.go | 18 + idp/internal/controllers/paths/domains.go | 15 + .../controllers/paths/dynamic_registration.go | 2 - idp/internal/providers/crypto/hmac.go | 99 ++- ...t_dynamic_registration_domain_codes.sql.go | 57 +- ...ccount_dynamic_registration_domains.sql.go | 255 ++++++ .../dynamic_registration_domain_codes.sql.go | 99 +++ ...0241213231542_create_initial_schema.up.sql | 32 +- idp/internal/providers/database/models.go | 61 +- ...ount_dynamic_registration_domain_codes.sql | 22 +- .../account_dynamic_registration_domains.sql | 52 +- .../dynamic_registration_domain_codes.sql | 35 + idp/internal/providers/tokens/accounts.go | 24 +- idp/internal/server/routes.go | 2 +- .../routes/account_dynamic_registration.go | 60 +- idp/internal/server/validations/scope.go | 5 +- .../account_credentials_registration.go | 163 ---- ...ccount_credentials_registration_domains.go | 819 ++++++++++++++++++ .../account_dynamic_registration_configs.go | 45 + idp/internal/services/account_hmac_secrets.go | 56 +- ...main.go => dynamic_registration_domain.go} | 32 +- .../dtos/dynamic_registration_domain_code.go | 55 ++ idp/internal/services/helpers.go | 41 + 30 files changed, 2174 insertions(+), 401 deletions(-) create mode 100644 idp/internal/controllers/account_credentials_registration_domains.go create mode 100644 idp/internal/controllers/bodies/dynamic_registration_domains.go create mode 100644 idp/internal/controllers/params/dynamic_registration_domains.go create mode 100644 idp/internal/controllers/paths/domains.go create mode 100644 idp/internal/providers/database/dynamic_registration_domain_codes.sql.go create mode 100644 idp/internal/providers/database/queries/dynamic_registration_domain_codes.sql delete mode 100644 idp/internal/services/account_credentials_registration.go create mode 100644 idp/internal/services/account_credentials_registration_domains.go rename idp/internal/services/dtos/{account_credentials_registration_domain.go => dynamic_registration_domain.go} (57%) create mode 100644 idp/internal/services/dtos/dynamic_registration_domain_code.go diff --git a/idp/go.mod b/idp/go.mod index c2176eb..1c83a41 100644 --- a/idp/go.mod +++ b/idp/go.mod @@ -24,22 +24,27 @@ require ( ) require ( - cloud.google.com/go/auth v0.16.4 // indirect + cloud.google.com/go/auth v0.16.5 // indirect cloud.google.com/go/auth/oauth2adapt v0.2.8 // indirect cloud.google.com/go/compute/metadata v0.8.0 // indirect + dario.cat/mergo v1.0.2 // indirect github.com/andybalholm/brotli v1.2.0 // indirect github.com/boombuler/barcode v1.1.0 // indirect github.com/cenkalti/backoff/v4 v4.3.0 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect + github.com/docker/docker v28.3.3+incompatible // indirect + github.com/docker/go-connections v0.6.0 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect github.com/gabriel-vasile/mimetype v1.4.9 // indirect github.com/go-jose/go-jose/v3 v3.0.4 // indirect github.com/go-logr/logr v1.4.3 // indirect github.com/go-logr/stdr v1.2.2 // indirect + github.com/go-ole/go-ole v1.3.0 // indirect github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect github.com/go-viper/mapstructure/v2 v2.4.0 // indirect + github.com/gofiber/storage/testhelpers/redis v0.0.0-20250815074620-1386290f7fd5 // indirect github.com/google/s2a-go v0.1.9 // indirect github.com/googleapis/enterprise-certificate-proxy v0.3.6 // indirect github.com/googleapis/gax-go/v2 v2.15.0 // indirect @@ -57,27 +62,29 @@ require ( github.com/jackc/puddle/v2 v2.2.2 // indirect github.com/klauspost/compress v1.18.0 // indirect github.com/leodido/go-urn v1.4.0 // indirect + github.com/lufia/plan9stats v0.0.0-20250317134145-8bc96cf8fc35 // indirect github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-runewidth v0.0.16 // indirect github.com/mitchellh/mapstructure v1.5.0 // indirect + github.com/moby/term v0.5.2 // indirect github.com/philhofer/fwd v1.2.0 // indirect + github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 // indirect github.com/rivo/uniseg v0.4.7 // indirect github.com/ryanuber/go-glob v1.0.0 // indirect + github.com/shirou/gopsutil/v4 v4.25.7 // indirect github.com/tinylib/msgp v1.3.0 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect - github.com/valyala/fasthttp v1.64.0 // indirect + github.com/valyala/fasthttp v1.65.0 // indirect go.opentelemetry.io/auto/sdk v1.1.0 // indirect go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.62.0 // indirect go.opentelemetry.io/otel v1.37.0 // indirect go.opentelemetry.io/otel/metric v1.37.0 // indirect go.opentelemetry.io/otel/trace v1.37.0 // indirect - golang.org/x/mod v0.26.0 // indirect golang.org/x/net v0.43.0 // indirect golang.org/x/sync v0.16.0 // indirect golang.org/x/sys v0.35.0 // indirect golang.org/x/time v0.12.0 // indirect - golang.org/x/tools v0.35.0 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20250811230008-5f3141c8851a // indirect google.golang.org/grpc v1.74.2 // indirect google.golang.org/protobuf v1.36.7 // indirect diff --git a/idp/go.sum b/idp/go.sum index f5c3c2b..90b40f4 100644 --- a/idp/go.sum +++ b/idp/go.sum @@ -1,28 +1,22 @@ -cloud.google.com/go/auth v0.16.3 h1:kabzoQ9/bobUmnseYnBO6qQG7q4a/CffFRlJSxv2wCc= -cloud.google.com/go/auth v0.16.3/go.mod h1:NucRGjaXfzP1ltpcQ7On/VTZ0H4kWB5Jy+Y9Dnm76fA= cloud.google.com/go/auth v0.16.4 h1:fXOAIQmkApVvcIn7Pc2+5J8QTMVbUGLscnSVNl11su8= cloud.google.com/go/auth v0.16.4/go.mod h1:j10ncYwjX/g3cdX7GpEzsdM+d+ZNsXAbb6qXA7p1Y5M= +cloud.google.com/go/auth v0.16.5 h1:mFWNQ2FEVWAliEQWpAdH80omXFokmrnbDhUS9cBywsI= +cloud.google.com/go/auth v0.16.5/go.mod h1:utzRfHMP+Vv0mpOkTRQoWD2q3BatTOoWbA7gCc2dUhQ= cloud.google.com/go/auth/oauth2adapt v0.2.8 h1:keo8NaayQZ6wimpNSmW5OPc283g65QNIiLpZnkHRbnc= cloud.google.com/go/auth/oauth2adapt v0.2.8/go.mod h1:XQ9y31RkqZCcwJWNSx2Xvric3RrU88hAYYbjDWYDL+c= -cloud.google.com/go/compute/metadata v0.7.0 h1:PBWF+iiAerVNe8UCHxdOt6eHLVc3ydFeOCw78U8ytSU= -cloud.google.com/go/compute/metadata v0.7.0/go.mod h1:j5MvL9PprKL39t166CoB1uVHfQMs4tFQZZcKwksXUjo= cloud.google.com/go/compute/metadata v0.8.0 h1:HxMRIbao8w17ZX6wBnjhcDkW6lTFpgcaobyVfZWqRLA= cloud.google.com/go/compute/metadata v0.8.0/go.mod h1:sYOGTp851OV9bOFJ9CH7elVvyzopvWQFNNghtDQ/Biw= -dario.cat/mergo v1.0.1 h1:Ra4+bf83h2ztPIQYNP99R6m+Y7KfnARDfID+a+vLl4s= -dario.cat/mergo v1.0.1/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk= -github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 h1:UQHMgLO+TxOElx5B5HZ4hJQsoJ/PvUvKRhJHDQXO8P8= -github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E= +dario.cat/mergo v1.0.2 h1:85+piFYR1tMbRrLcDwR18y4UKJ3aH1Tbzi24VRW1TK8= +dario.cat/mergo v1.0.2/go.mod h1:E/hbnu0NxMFBjpMIE34DRGLWqDy0g5FuKDhCb31ngxA= +github.com/Azure/go-ansiterm v0.0.0-20250102033503-faa5f7b0171c h1:udKWzYgxTojEKWjV8V+WSxDXJ4NFATAsZjh8iIbsQIg= +github.com/Azure/go-ansiterm v0.0.0-20250102033503-faa5f7b0171c/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E= github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY= github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= -github.com/andybalholm/brotli v1.1.1 h1:PR2pgnyFznKEugtsUo0xLdDop5SKXd5Qf5ysW+7XdTA= -github.com/andybalholm/brotli v1.1.1/go.mod h1:05ib4cKhjx3OQYUY22hTVd34Bc8upXjOLL2rKwwZBoA= github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ= github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY= github.com/biter777/countries v1.7.5 h1:MJ+n3+rSxWQdqVJU8eBy9RqcdH6ePPn4PJHocVWUa+Q= github.com/biter777/countries v1.7.5/go.mod h1:1HSpZ526mYqKJcpT5Ti1kcGQ0L0SrXWIaptUWjFfv2E= github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8= -github.com/boombuler/barcode v1.0.2 h1:79yrbttoZrLGkL/oOI8hBrUKucwOL0oOjUgEguGMcJ4= -github.com/boombuler/barcode v1.0.2/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8= github.com/boombuler/barcode v1.1.0 h1:ChaYjBR63fr4LFyGn8E8nt7dBSt3MiU3zMOZqFvVkHo= github.com/boombuler/barcode v1.1.0/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8= github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= @@ -33,6 +27,10 @@ github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK3 github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/containerd/errdefs v1.0.0 h1:tg5yIfIlQIrxYtu9ajqY42W3lpS19XqdxRQeEwYG8PI= +github.com/containerd/errdefs v1.0.0/go.mod h1:+YBYIdtsnF4Iw6nWZhJcqGSg/dwvV7tyJ/kCkyJ2k+M= +github.com/containerd/errdefs/pkg v0.3.0 h1:9IKJ06FvyNlexW690DXuQNx2KA2cUJXx151Xdx3ZPPE= +github.com/containerd/errdefs/pkg v0.3.0/go.mod h1:NJw6s9HwNuRhnjJhM7pylWwMyAkmCQvQ4GpJHEqRLVk= github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I= github.com/containerd/log v0.1.0/go.mod h1:VRRf09a7mHDIRezVKTRCrOq78v577GXq3bSa3EhrzVo= github.com/containerd/platforms v0.2.1 h1:zvwtM3rz2YHPQsF2CHYM8+KtB5dvhISiXh5ZpSBQv6A= @@ -47,14 +45,14 @@ github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/r github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk= github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E= -github.com/docker/docker v28.0.1+incompatible h1:FCHjSRdXhNRFjlHMTv4jUNlIBbTeRjrWfeFuJp7jpo0= -github.com/docker/docker v28.0.1+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk= -github.com/docker/go-connections v0.5.0 h1:USnMq7hx7gwdVZq1L49hLXaFtUdTADjXGp+uj1Br63c= -github.com/docker/go-connections v0.5.0/go.mod h1:ov60Kzw0kKElRwhNs9UlUHAE/F9Fe6GLaXnqyDdmEXc= +github.com/docker/docker v28.3.3+incompatible h1:Dypm25kh4rmk49v1eiVbsAtpAsYURjYkaKubwuBdxEI= +github.com/docker/docker v28.3.3+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk= +github.com/docker/go-connections v0.6.0 h1:LlMG9azAe1TqfR7sO+NJttz1gy6KO7VJBh+pMmjSD94= +github.com/docker/go-connections v0.6.0/go.mod h1:AahvXYshr6JgfUJGdDCs2b5EZG/vmaMAntpSFH5BFKE= github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4= github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk= -github.com/ebitengine/purego v0.8.2 h1:jPPGWs2sZ1UgOSgD2bClL0MJIqu58nOmIcBuXr62z1I= -github.com/ebitengine/purego v0.8.2/go.mod h1:iIjxzd6CiRiOG0UyXP+V1+jWqUXVjPKLAI0mRfJZTmQ= +github.com/ebitengine/purego v0.8.4 h1:CF7LEKg5FFOsASUj0+QwaXf8Ht6TlFxg09+S9wz0omw= +github.com/ebitengine/purego v0.8.4/go.mod h1:iIjxzd6CiRiOG0UyXP+V1+jWqUXVjPKLAI0mRfJZTmQ= github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM= github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU= github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= @@ -70,8 +68,8 @@ github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= -github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY= -github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0= +github.com/go-ole/go-ole v1.3.0 h1:Dt6ye7+vXGIKZ7Xtk4s6/xVdGDQynvom7xCFEdWr6uE= +github.com/go-ole/go-ole v1.3.0/go.mod h1:5LS6F96DhAwUc7C+1HLexzMXY1xGRSryjyPPKW6zv78= github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= @@ -82,16 +80,14 @@ github.com/go-playground/validator/v10 v10.27.0 h1:w8+XrWVMhGkxOaaowyKH35gFydVHO github.com/go-playground/validator/v10 v10.27.0/go.mod h1:I5QpIEbmr8On7W0TktmJAumgzX4CA1XNl4ZmDuVHKKo= github.com/go-test/deep v1.1.0 h1:WOcxcdHcvdgThNXjw0t76K42FXTU7HpNQWHpA2HHNlg= github.com/go-test/deep v1.1.0/go.mod h1:5C2ZWiW0ErCdrYzpqxLbTX7MG14M9iiw8DgHncVwcsE= -github.com/go-viper/mapstructure/v2 v2.3.0 h1:27XbWsHIqhbdR5TIC911OfYvgSaW93HM+dX7970Q7jk= -github.com/go-viper/mapstructure/v2 v2.3.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= github.com/go-viper/mapstructure/v2 v2.4.0 h1:EBsztssimR/CONLSZZ04E8qAkxNYq4Qp9LvH92wZUgs= github.com/go-viper/mapstructure/v2 v2.4.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= github.com/gofiber/fiber/v2 v2.52.9 h1:YjKl5DOiyP3j0mO61u3NTmK7or8GzzWzCFzkboyP5cw= github.com/gofiber/fiber/v2 v2.52.9/go.mod h1:YEcBbO/FB+5M1IZNBP9FO3J9281zgPAreiI1oqg8nDw= -github.com/gofiber/storage/redis/v3 v3.2.0 h1:1cmxmH6ZniZcWHvMpp6LzfcSK5o7CgqiouRqrVCNY9A= -github.com/gofiber/storage/redis/v3 v3.2.0/go.mod h1:fffHK3QnjOxOUZGtq08YVNU1lqKvE+pAKJ5roSnM7FE= github.com/gofiber/storage/redis/v3 v3.4.0 h1:FbtVgHsWkHFaogObFyNbBkNkZL9/zYxQkS1PV0rA5Ss= github.com/gofiber/storage/redis/v3 v3.4.0/go.mod h1:5efv+XbKwSQju9j7tokMgFWZ1JwlZvSsIL4RNJSDyf0= +github.com/gofiber/storage/testhelpers/redis v0.0.0-20250815074620-1386290f7fd5 h1:vC79Z8gkydKoxsq+7+IhnTd3z2J7qs1Zi5wXTP29/C4= +github.com/gofiber/storage/testhelpers/redis v0.0.0-20250815074620-1386290f7fd5/go.mod h1:PU9dj9E5K6+TLw7pF87y4yOf5HUH6S9uxTlhuRAVMEY= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo= @@ -122,8 +118,6 @@ github.com/hashicorp/go-hclog v1.6.3 h1:Qr2kF+eVWjTiYmU7Y31tYlP1h0q/X3Nl3tPGdaB1 github.com/hashicorp/go-hclog v1.6.3/go.mod h1:W4Qnvbt70Wk/zYJryRzDRU/4r0kIg0PVHBcfoyhpF5M= github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= -github.com/hashicorp/go-retryablehttp v0.7.7 h1:C8hUCYzor8PIfXHa4UrZkU4VvK8o9ISHxT2Q8+VepXU= -github.com/hashicorp/go-retryablehttp v0.7.7/go.mod h1:pkQpWZeYWskR+D1tR2O5OcBFOxfA7DoAO6xtkuQnHTk= github.com/hashicorp/go-retryablehttp v0.7.8 h1:ylXZWnqa7Lhqpk0L1P1LzDtGcCR0rPVUrx/c8Unxc48= github.com/hashicorp/go-retryablehttp v0.7.8/go.mod h1:rjiScheydd+CxvumBsIrFKlx3iS0jrZ7LvzFGFmuKbw= github.com/hashicorp/go-secure-stdlib/parseutil v0.2.0 h1:U+kC2dOhMFQctRfhK0gRctKAPTloZdMU5ZJxaesJ/VM= @@ -132,8 +126,6 @@ github.com/hashicorp/go-secure-stdlib/strutil v0.1.2 h1:kes8mmyCpxJsI7FTwtzRqEy9 github.com/hashicorp/go-secure-stdlib/strutil v0.1.2/go.mod h1:Gou2R9+il93BqX25LAKCLuM+y9U2T4hlwvT1yprcna4= github.com/hashicorp/go-sockaddr v1.0.7 h1:G+pTkSO01HpR5qCxg7lxfsFEZaG+C0VssTy/9dbT+Fw= github.com/hashicorp/go-sockaddr v1.0.7/go.mod h1:FZQbEYa1pxkQ7WLpyXJ6cbjpT8q0YgQaK/JakXqGyWw= -github.com/hashicorp/hcl v1.0.1-vault-5 h1:kI3hhbbyzr4dldA8UdTb7ZlVVlI2DACdCfz31RPDgJM= -github.com/hashicorp/hcl v1.0.1-vault-5/go.mod h1:XYhtn6ijBSAj6n4YqAaf7RBPS4I06AItNorpy+MoQNM= github.com/hashicorp/hcl v1.0.1-vault-7 h1:ag5OxFVy3QYTFTJODRzTKVZ6xvdfLLCA1cy/Y6xGI0I= github.com/hashicorp/hcl v1.0.1-vault-7/go.mod h1:XYhtn6ijBSAj6n4YqAaf7RBPS4I06AItNorpy+MoQNM= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= @@ -150,8 +142,8 @@ github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zt github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= -github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 h1:6E+4a0GO5zZEnZ81pIr0yLvtUWk2if982qA3F3QD6H4= -github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I= +github.com/lufia/plan9stats v0.0.0-20250317134145-8bc96cf8fc35 h1:PpXWgLPs+Fqr325bN2FD2ISlRRztXibcX6e8f5FR5Dc= +github.com/lufia/plan9stats v0.0.0-20250317134145-8bc96cf8fc35/go.mod h1:autxFIvghDt3jPTLoqZ9OZ7s9qTGNAWmYCjVFWPX/zg= github.com/magiconair/properties v1.8.10 h1:s31yESBquKXCV9a/ScB3ESkOjUYYv+X0rg8SYxI99mE= github.com/magiconair/properties v1.8.10/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0= github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= @@ -166,16 +158,18 @@ github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyua github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0= github.com/moby/docker-image-spec v1.3.1/go.mod h1:eKmb5VW8vQEh/BAr2yvVNvuiJuY6UIocYsFu/DxxRpo= +github.com/moby/go-archive v0.1.0 h1:Kk/5rdW/g+H8NHdJW2gsXyZ7UnzvJNOy6VKJqueWdcQ= +github.com/moby/go-archive v0.1.0/go.mod h1:G9B+YoujNohJmrIYFBpSd54GTUB4lt9S+xVQvsJyFuo= github.com/moby/patternmatcher v0.6.0 h1:GmP9lR19aU5GqSSFko+5pRqHi+Ohk1O69aFiKkVGiPk= github.com/moby/patternmatcher v0.6.0/go.mod h1:hDPoyOpDY7OrrMDLaYoY3hf52gNCR/YOUYxkhApJIxc= -github.com/moby/sys/sequential v0.5.0 h1:OPvI35Lzn9K04PBbCLW0g4LcFAJgHsvXsRyewg5lXtc= -github.com/moby/sys/sequential v0.5.0/go.mod h1:tH2cOOs5V9MlPiXcQzRC+eEyab644PWKGRYaaV5ZZlo= -github.com/moby/sys/user v0.1.0 h1:WmZ93f5Ux6het5iituh9x2zAG7NFY9Aqi49jjE1PaQg= -github.com/moby/sys/user v0.1.0/go.mod h1:fKJhFOnsCN6xZ5gSfbM6zaHGgDJMrqt9/reuj4T7MmU= +github.com/moby/sys/sequential v0.6.0 h1:qrx7XFUd/5DxtqcoH1h438hF5TmOvzC/lspjy7zgvCU= +github.com/moby/sys/sequential v0.6.0/go.mod h1:uyv8EUTrca5PnDsdMGXhZe6CCe8U/UiTWd+lL+7b/Ko= +github.com/moby/sys/user v0.4.0 h1:jhcMKit7SA80hivmFJcbB1vqmw//wU61Zdui2eQXuMs= +github.com/moby/sys/user v0.4.0/go.mod h1:bG+tYYYJgaMtRKgEmuueC0hJEAZWwtIbZTB+85uoHjs= github.com/moby/sys/userns v0.1.0 h1:tVLXkFOxVu9A64/yh59slHVv9ahO9UIev4JZusOLG/g= github.com/moby/sys/userns v0.1.0/go.mod h1:IHUYgu/kao6N8YZlp9Cf444ySSvCmDlmzUcYfDHOl28= -github.com/moby/term v0.5.0 h1:xt8Q1nalod/v7BqbG21f8mQPqH+xAaC9C3N3wfWbVP0= -github.com/moby/term v0.5.0/go.mod h1:8FzsFHVUBGZdbDsJw/ot+X+d5HLUbvklYLJ9uGfcI3Y= +github.com/moby/term v0.5.2 h1:6qk3FJAFDs6i/q3W/pQ97SX192qKfZgGjCQqfCJkgzQ= +github.com/moby/term v0.5.2/go.mod h1:d3djjFCrjnB+fl8NJux+EJzu0msscUP+f8it8hPkFLc= github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A= github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc= github.com/nbio/st v0.0.0-20140626010706-e9e8d9816f32 h1:W6apQkHrMkS0Muv8G/TipAy/FJl/rCYT0+EuS8+Z0z4= @@ -188,8 +182,6 @@ github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8 github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM= github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040= github.com/opencontainers/image-spec v1.1.1/go.mod h1:qpqAh3Dmcf36wStyyWU+kCeDgrGnAve2nCC8+7h8Q0M= -github.com/philhofer/fwd v1.1.3-0.20240916144458-20a13a1f6b7c h1:dAMKvw0MlJT1GshSTtih8C2gDs04w8dReiOGXrGLNoY= -github.com/philhofer/fwd v1.1.3-0.20240916144458-20a13a1f6b7c/go.mod h1:RqIHx9QI14HlwKwm98g9Re5prTQ6LdeRQn+gXJFxsJM= github.com/philhofer/fwd v1.2.0 h1:e6DnBTl7vGY+Gz322/ASL4Gyp1FspeMvx1RNDoToZuM= github.com/philhofer/fwd v1.2.0/go.mod h1:RqIHx9QI14HlwKwm98g9Re5prTQ6LdeRQn+gXJFxsJM= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= @@ -197,12 +189,10 @@ github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINE github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c h1:ncq/mPwQF4JjgDlrVEn3C11VoGHZN7m8qihwgMEtzYw= -github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE= +github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 h1:o4JXh1EVt9k/+g42oCprj/FisM4qX9L3sZB3upGN2ZU= +github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE= github.com/pquerna/otp v1.5.0 h1:NMMR+WrmaqXU4EzdGJEE1aUUI0AMRzsp96fFFWNPwxs= github.com/pquerna/otp v1.5.0/go.mod h1:dkJfzwRKNiegxyNb54X/3fLwhCynbMspSyWKnvi1AEg= -github.com/redis/go-redis/v9 v9.11.0 h1:E3S08Gl/nJNn5vkxd2i78wZxWAPNZgUNTp8WIJUAiIs= -github.com/redis/go-redis/v9 v9.11.0/go.mod h1:huWgSWd8mW6+m0VPhJjSSQ+d6Nh1VICQ6Q5lHuCH/Iw= github.com/redis/go-redis/v9 v9.12.1 h1:k5iquqv27aBtnTm2tIkROUDp8JBXhXZIVu1InSgvovg= github.com/redis/go-redis/v9 v9.12.1/go.mod h1:huWgSWd8mW6+m0VPhJjSSQ+d6Nh1VICQ6Q5lHuCH/Iw= github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= @@ -210,8 +200,8 @@ github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= github.com/ryanuber/go-glob v1.0.0 h1:iQh3xXAumdQ+4Ufa5b25cRpC5TYKlno6hsv6Cb3pkBk= github.com/ryanuber/go-glob v1.0.0/go.mod h1:807d1WSdnB0XRJzKNil9Om6lcp/3a0v4qIHxIXzX/Yc= -github.com/shirou/gopsutil/v4 v4.25.1 h1:QSWkTc+fu9LTAWfkZwZ6j8MSUk4A2LV7rbH0ZqmLjXs= -github.com/shirou/gopsutil/v4 v4.25.1/go.mod h1:RoUCUpndaJFtT+2zsZzzmhvbfGoDCJ7nFXKJf8GqJbI= +github.com/shirou/gopsutil/v4 v4.25.7 h1:bNb2JuqKuAu3tRlPv5piSmBZyMfecwQ+t/ILq+1JqVM= +github.com/shirou/gopsutil/v4 v4.25.7/go.mod h1:XV/egmwJtd3ZQjBpJVY5kndsiOO4IRqy9TQnmm6VP7U= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= @@ -219,22 +209,22 @@ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UV github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= -github.com/testcontainers/testcontainers-go v0.37.0 h1:L2Qc0vkTw2EHWQ08djon0D2uw7Z/PtHS/QzZZ5Ra/hg= -github.com/testcontainers/testcontainers-go v0.37.0/go.mod h1:QPzbxZhQ6Bclip9igjLFj6z0hs01bU8lrl2dHQmgFGM= -github.com/testcontainers/testcontainers-go/modules/redis v0.37.0 h1:9HIY28I9ME/Zmb+zey1p/I1mto5+5ch0wLX+nJdOsQ4= -github.com/testcontainers/testcontainers-go/modules/redis v0.37.0/go.mod h1:Abu9g/25Qv+FkYVx3U4Voaynou1c+7D0HIhaQJXvk6E= +github.com/testcontainers/testcontainers-go v0.38.0 h1:d7uEapLcv2P8AvH8ahLqDMMxda2W9gQN1nRbHS28HBw= +github.com/testcontainers/testcontainers-go v0.38.0/go.mod h1:C52c9MoHpWO+C4aqmgSU+hxlR5jlEayWtgYrb8Pzz1w= +github.com/testcontainers/testcontainers-go/modules/redis v0.38.0 h1:289pn0BFmGqDrd6BrImZAprFef9aaPZacx07YOQaPV4= +github.com/testcontainers/testcontainers-go/modules/redis v0.38.0/go.mod h1:EcKPWRzOglnQfYe+ekA8RPEIWSNJTGwaC5oE5bQV+D0= github.com/tinylib/msgp v1.3.0 h1:ULuf7GPooDaIlbyvgAxBV/FI7ynli6LZ1/nVUNu+0ww= github.com/tinylib/msgp v1.3.0/go.mod h1:ykjzy2wzgrlvpDCRc4LA8UXy6D8bzMSuAF3WD57Gok0= -github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU= -github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI= -github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+Fk= -github.com/tklauser/numcpus v0.6.1/go.mod h1:1XfjsgE2zo8GVw7POkMbHENHzVg3GzmoZ9fESEdAacY= +github.com/tklauser/go-sysconf v0.3.15 h1:VE89k0criAymJ/Os65CSn1IXaol+1wrsFHEB8Ol49K4= +github.com/tklauser/go-sysconf v0.3.15/go.mod h1:Dmjwr6tYFIseJw7a3dRLJfsHAMXZ3nEnL/aZY+0IuI4= +github.com/tklauser/numcpus v0.10.0 h1:18njr6LDBk1zuna922MgdjQuJFjrdppsZG60sHGfjso= +github.com/tklauser/numcpus v0.10.0/go.mod h1:BiTKazU708GQTYF4mB+cmlpT2Is1gLk7XVuEeem8LsQ= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= -github.com/valyala/fasthttp v1.62.0 h1:8dKRBX/y2rCzyc6903Zu1+3qN0H/d2MsxPPmVNamiH0= -github.com/valyala/fasthttp v1.62.0/go.mod h1:FCINgr4GKdKqV8Q0xv8b+UxPV+H/O5nNFo3D+r54Htg= github.com/valyala/fasthttp v1.64.0 h1:QBygLLQmiAyiXuRhthf0tuRkqAFcrC42dckN2S+N3og= github.com/valyala/fasthttp v1.64.0/go.mod h1:dGmFxwkWXSK0NbOSJuF7AMVzU+lkHz0wQVvVITv2UQA= +github.com/valyala/fasthttp v1.65.0 h1:j/u3uzFEGFfRxw79iYzJN+TteTJwbYkru9uDp3d0Yf8= +github.com/valyala/fasthttp v1.65.0/go.mod h1:P/93/YkKPMsKSnATEeELUCkG8a7Y+k99uxNHVbKINr4= github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU= github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= @@ -242,44 +232,30 @@ github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0= go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA= go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A= -go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.61.0 h1:F7Jx+6hwnZ41NSFTO5q4LYDtJRXBf2PD0rNBkeB/lus= -go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.61.0/go.mod h1:UHB22Z8QsdRDrnAtX4PntOl36ajSxcdUMt1sF7Y6E7Q= go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.62.0 h1:Hf9xI/XLML9ElpiHVDNwvqI0hIFlzV8dgIr35kV1kRU= go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.62.0/go.mod h1:NfchwuyNoMcZ5MLHwPrODwUF1HWCXWrL31s8gSAdIKY= -go.opentelemetry.io/otel v1.36.0 h1:UumtzIklRBY6cI/lllNZlALOF5nNIzJVb16APdvgTXg= -go.opentelemetry.io/otel v1.36.0/go.mod h1:/TcFMXYjyRNh8khOAO9ybYkqaDBb/70aVwkNML4pP8E= go.opentelemetry.io/otel v1.37.0 h1:9zhNfelUvx0KBfu/gb+ZgeAfAgtWrfHJZcAqFC228wQ= go.opentelemetry.io/otel v1.37.0/go.mod h1:ehE/umFRLnuLa/vSccNq9oS1ErUlkkK71gMcN34UG8I= -go.opentelemetry.io/otel/metric v1.36.0 h1:MoWPKVhQvJ+eeXWHFBOPoBOi20jh6Iq2CcCREuTYufE= -go.opentelemetry.io/otel/metric v1.36.0/go.mod h1:zC7Ks+yeyJt4xig9DEw9kuUFe5C3zLbVjV2PzT6qzbs= go.opentelemetry.io/otel/metric v1.37.0 h1:mvwbQS5m0tbmqML4NqK+e3aDiO02vsf/WgbsdpcPoZE= go.opentelemetry.io/otel/metric v1.37.0/go.mod h1:04wGrZurHYKOc+RKeye86GwKiTb9FKm1WHtO+4EVr2E= -go.opentelemetry.io/otel/sdk v1.36.0 h1:b6SYIuLRs88ztox4EyrvRti80uXIFy+Sqzoh9kFULbs= -go.opentelemetry.io/otel/sdk v1.36.0/go.mod h1:+lC+mTgD+MUWfjJubi2vvXWcVxyr9rmlshZni72pXeY= -go.opentelemetry.io/otel/sdk/metric v1.36.0 h1:r0ntwwGosWGaa0CrSt8cuNuTcccMXERFwHX4dThiPis= -go.opentelemetry.io/otel/sdk/metric v1.36.0/go.mod h1:qTNOhFDfKRwX0yXOqJYegL5WRaW376QbB7P4Pb0qva4= -go.opentelemetry.io/otel/trace v1.36.0 h1:ahxWNuqZjpdiFAyrIoQ4GIiAIhxAunQR6MUoKrsNd4w= -go.opentelemetry.io/otel/trace v1.36.0/go.mod h1:gQ+OnDZzrybY4k4seLzPAWNwVBBVlF2szhehOBB/tGA= +go.opentelemetry.io/otel/sdk v1.37.0 h1:ItB0QUqnjesGRvNcmAcU0LyvkVyGJ2xftD29bWdDvKI= +go.opentelemetry.io/otel/sdk v1.37.0/go.mod h1:VredYzxUvuo2q3WRcDnKDjbdvmO0sCzOvVAiY+yUkAg= +go.opentelemetry.io/otel/sdk/metric v1.37.0 h1:90lI228XrB9jCMuSdA0673aubgRobVZFhbjxHHspCPc= +go.opentelemetry.io/otel/sdk/metric v1.37.0/go.mod h1:cNen4ZWfiD37l5NhS+Keb5RXVWZWpRE+9WyVCpbo5ps= go.opentelemetry.io/otel/trace v1.37.0 h1:HLdcFNbRQBE2imdSEgm/kwqmQj1Or1l/7bW6mxVK7z4= go.opentelemetry.io/otel/trace v1.37.0/go.mod h1:TlgrlQ+PtQO5XFerSPUYG0JSgGyryXewPGyayAWSBS0= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= -golang.org/x/crypto v0.40.0 h1:r4x+VvoG5Fm+eJcxMaY8CQM7Lb0l1lsmjGBQ6s8BfKM= -golang.org/x/crypto v0.40.0/go.mod h1:Qr1vMER5WyS2dfPHAlsOj01wgLbsyWtFn/aY+5+ZdxY= golang.org/x/crypto v0.41.0 h1:WKYxWedPGCTVVl5+WHSSrOBT0O8lx32+zxmHxijgXp4= golang.org/x/crypto v0.41.0/go.mod h1:pO5AFd7FA68rFak7rOAGVuygIISepHftHnr8dr6+sUc= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= -golang.org/x/mod v0.26.0 h1:EGMPT//Ezu+ylkCijjPc+f4Aih7sZvaAr+O3EHBxvZg= -golang.org/x/mod v0.26.0/go.mod h1:/j6NAhSk8iQ723BGAUyoAcn7SlD7s15Dp9Nd/SfeaFQ= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= -golang.org/x/net v0.42.0 h1:jzkYrhi3YQWD6MLBJcsklgQsoAcw89EcZbJw8Z614hs= -golang.org/x/net v0.42.0/go.mod h1:FF1RA5d3u7nAYA4z2TkclSCKh68eSXtiFwcWQpPXdt8= golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE= golang.org/x/net v0.43.0/go.mod h1:vhO1fvI4dGsIjh73sWfUVjj3N7CA9WkKJNQm2svM6Jg= golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI= @@ -291,15 +267,15 @@ golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw= golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.34.0 h1:H5Y5sJ2L2JRdyv7ROF1he/lPdvFsd0mJHFw2ThKHxLA= -golang.org/x/sys v0.34.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI= golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= @@ -313,8 +289,6 @@ golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= -golang.org/x/text v0.27.0 h1:4fGWRpyh641NLlecmyl4LOe6yDdfaYNrGb2zdfo4JV4= -golang.org/x/text v0.27.0/go.mod h1:1D28KMCvyooCX9hBiosv5Tz/+YLxj0j7XhWjpSUF7CU= golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng= golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU= golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE= @@ -323,25 +297,17 @@ golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGm golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= -golang.org/x/tools v0.35.0 h1:mBffYraMEf7aa0sB+NuKnuCy8qI/9Bughn8dC2Gu5r0= -golang.org/x/tools v0.35.0/go.mod h1:NKdj5HkL/73byiZSJjqJgKn3ep7KjFkBOkR/Hps3VPw= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -google.golang.org/api v0.244.0 h1:lpkP8wVibSKr++NCD36XzTk/IzeKJ3klj7vbj+XU5pE= -google.golang.org/api v0.244.0/go.mod h1:dMVhVcylamkirHdzEBAIQWUCgqY885ivNeZYd7VAVr8= google.golang.org/api v0.247.0 h1:tSd/e0QrUlLsrwMKmkbQhYVa109qIintOls2Wh6bngc= google.golang.org/api v0.247.0/go.mod h1:r1qZOPmxXffXg6xS5uhx16Fa/UFY8QU/K4bfKrnvovM= google.golang.org/genproto v0.0.0-20250603155806-513f23925822 h1:rHWScKit0gvAPuOnu87KpaYtjK5zBMLcULh7gxkCXu4= google.golang.org/genproto v0.0.0-20250603155806-513f23925822/go.mod h1:HubltRL7rMh0LfnQPkMH4NPDFEWp0jw3vixw7jEM53s= google.golang.org/genproto/googleapis/api v0.0.0-20250603155806-513f23925822 h1:oWVWY3NzT7KJppx2UKhKmzPq4SRe0LdCijVRwvGeikY= google.golang.org/genproto/googleapis/api v0.0.0-20250603155806-513f23925822/go.mod h1:h3c4v36UTKzUiuaOKQ6gr3S+0hovBtUrXzTG/i3+XEc= -google.golang.org/genproto/googleapis/rpc v0.0.0-20250728155136-f173205681a0 h1:MAKi5q709QWfnkkpNQ0M12hYJ1+e8qYVDyowc4U1XZM= -google.golang.org/genproto/googleapis/rpc v0.0.0-20250728155136-f173205681a0/go.mod h1:qQ0YXyHHx3XkvlzUtpXDkS29lDSafHMZBAZDc03LQ3A= google.golang.org/genproto/googleapis/rpc v0.0.0-20250811230008-5f3141c8851a h1:tPE/Kp+x9dMSwUm/uM0JKK0IfdiJkwAbSMSeZBXXJXc= google.golang.org/genproto/googleapis/rpc v0.0.0-20250811230008-5f3141c8851a/go.mod h1:gw1tLEfykwDz2ET4a12jcXt4couGAm7IwsVaTy0Sflo= google.golang.org/grpc v1.74.2 h1:WoosgB65DlWVC9FqI82dGsZhWFNBSLjQ84bjROOpMu4= google.golang.org/grpc v1.74.2/go.mod h1:CtQ+BGjaAIXHs/5YS3i473GqwBBa1zGQNevxdeBEXrM= -google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY= -google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY= google.golang.org/protobuf v1.36.7 h1:IgrO7UwFQGJdRNXH/sQux4R1Dj1WAKcLElzeeRaXV2A= google.golang.org/protobuf v1.36.7/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/idp/initial_schema.dbml b/idp/initial_schema.dbml index 2aab25f..f5aebe0 100644 --- a/idp/initial_schema.dbml +++ b/idp/initial_schema.dbml @@ -76,6 +76,7 @@ Enum token_key_type { "email_verification" "password_reset" "2fa_authentication" + "dynamic_registration" } Table token_signing_keys as TS { @@ -325,8 +326,12 @@ Enum account_credentials_scope { "account:users:write" "account:apps:read" "account:apps:write" + "account:apps:configs:read" + "account:apps:configs:write" "account:credentials:read" "account:credentials:write" + "account:credentials:configs:read" + "account:credentials:configs:write" "account:auth_providers:read" } @@ -943,12 +948,10 @@ Table account_dynamic_registration_domains as ADRD { } Ref: ADRD.account_id > A.id [delete: cascade] -Table account_dynamic_registration_domain_codes as ADRDC { +Table dynamic_registration_domain_codes as DRDC { id serial [pk] account_id integer [not null] - account_dynamic_registration_domain_id integer [not null] - verification_host varchar(50) [not null] verification_code text [not null] hmac_secret_id varchar(22) [not null] @@ -960,12 +963,28 @@ Table account_dynamic_registration_domain_codes as ADRDC { Indexes { (account_id) [name: 'account_dynamic_registration_domain_codes_account_id_idx'] - (account_dynamic_registration_domain_id) [name: 'account_dynamic_registration_domain_codes_account_dynamic_registration_domain_id_idx'] + } +} +Ref: DRDC.account_id > A.id [delete: cascade] +Ref: DRDC.hmac_secret_id > AHS.secret_id [delete: cascade, update: cascade] + +Table account_dynamic_registration_domain_codes as ADRDC { + account_dynamic_registration_domain_id integer [not null] + dynamic_registration_domain_code_id integer [not null] + + account_id integer [not null] + created_at timestamptz [not null, default: `now()`] + + Indexes { + (account_dynamic_registration_domain_id, dynamic_registration_domain_code_id) [pk] + (account_id) [name: 'account_dynamic_registration_domain_codes_account_id_idx'] + (account_dynamic_registration_domain_id) [unique, name: 'account_dynamic_registration_domain_codes_account_dynamic_registration_domain_id_uidx'] + (dynamic_registration_domain_code_id) [unique, name: 'account_dynamic_registration_domain_codes_dynamic_registration_domain_code_id_uidx'] } } Ref: ADRDC.account_id > A.id [delete: cascade] Ref: ADRDC.account_dynamic_registration_domain_id > ADRD.id [delete: cascade] -Ref: ADRDC.hmac_secret_id > AHS.secret_id [delete: cascade, update: cascade] +Ref: ADRDC.dynamic_registration_domain_code_id > DRDC.id [delete: cascade] Table app_dynamic_registration_configs as APDRC { id serial [pk] diff --git a/idp/internal/controllers/account_credentials_registration_domains.go b/idp/internal/controllers/account_credentials_registration_domains.go new file mode 100644 index 0000000..6d29ef5 --- /dev/null +++ b/idp/internal/controllers/account_credentials_registration_domains.go @@ -0,0 +1,303 @@ +// Copyright (c) 2025 Afonso Barracha +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +package controllers + +import ( + "github.com/gofiber/fiber/v2" + + "github.com/tugascript/devlogs/idp/internal/controllers/bodies" + "github.com/tugascript/devlogs/idp/internal/controllers/params" + "github.com/tugascript/devlogs/idp/internal/controllers/paths" + "github.com/tugascript/devlogs/idp/internal/services" + "github.com/tugascript/devlogs/idp/internal/services/dtos" +) + +const ( + accountCredentialsRegistrationDomainsLocation string = "account_credentials_registration_domains" +) + +func (c *Controllers) CreateAccountCredentialsRegistrationDomain(ctx *fiber.Ctx) error { + requestID := getRequestID(ctx) + logger := c.buildLogger(requestID, accountCredentialsRegistrationDomainsLocation, "CreateAccountDynamicRegistrationDomain") + logRequest(logger, ctx) + + accountClaims, serviceErr := getAccountClaims(ctx) + if serviceErr != nil { + return serviceErrorResponse(logger, ctx, serviceErr) + } + + body := new(bodies.CreateDynamicRegistrationDomainBody) + if err := ctx.BodyParser(body); err != nil { + return parseRequestErrorResponse(logger, ctx, err) + } + if err := c.validate.StructCtx(ctx.UserContext(), body); err != nil { + return validateBodyErrorResponse(logger, ctx, err) + } + + domainDTO, serviceErr := c.services.CreateAccountCredentialsRegistrationDomain( + ctx.UserContext(), + services.CreateAccountCredentialsRegistrationDomainOptions{ + RequestID: requestID, + AccountPublicID: accountClaims.AccountID, + AccountVersion: accountClaims.AccountVersion, + Domain: body.Domain, + }, + ) + if serviceErr != nil { + return serviceErrorResponse(logger, ctx, serviceErr) + } + + logResponse(logger, ctx, fiber.StatusCreated) + return ctx.Status(fiber.StatusCreated).JSON(domainDTO) +} + +func (c *Controllers) ListAccountCredentialsRegistrationDomains(ctx *fiber.Ctx) error { + requestID := getRequestID(ctx) + logger := c.buildLogger(requestID, accountCredentialsRegistrationDomainsLocation, "ListAccountCredentialsRegistrationDomains") + logRequest(logger, ctx) + + accountClaims, serviceErr := getAccountClaims(ctx) + if serviceErr != nil { + return serviceErrorResponse(logger, ctx, serviceErr) + } + + queryParams := params.DynamicRegistrationDomainQueryParams{ + Limit: ctx.QueryInt("limit", 10), + Offset: ctx.QueryInt("offset", 0), + Order: ctx.Query("order", "date"), + Search: ctx.Query("search"), + } + if err := c.validate.StructCtx(ctx.UserContext(), &queryParams); err != nil { + return validateQueryParamsErrorResponse(logger, ctx, err) + } + + var domains []dtos.DynamicRegistrationDomainDTO + var count int64 + if queryParams.Search != "" { + domains, count, serviceErr = c.services.FilterAccountCredentialsRegistrationDomains( + ctx.UserContext(), + services.FilterAccountCredentialsRegistrationDomainsOptions{ + RequestID: requestID, + AccountPublicID: accountClaims.AccountID, + Search: queryParams.Search, + Limit: int32(queryParams.Limit), + Offset: int32(queryParams.Offset), + Order: queryParams.Order, + }, + ) + } else { + domains, count, serviceErr = c.services.ListAccountCredentialsRegistrationDomains( + ctx.UserContext(), + services.ListAccountCredentialsRegistrationDomainsOptions{ + RequestID: requestID, + AccountPublicID: accountClaims.AccountID, + Limit: int32(queryParams.Limit), + Offset: int32(queryParams.Offset), + Order: queryParams.Order, + }, + ) + } + + logResponse(logger, ctx, fiber.StatusOK) + return ctx.Status(fiber.StatusOK).JSON(dtos.NewPaginationDTO( + domains, + count, + c.backendDomain, + paths.AccountsBase+paths.CredentialsBase+paths.DynamicRegistrationBase+paths.Domains, + queryParams.Limit, + queryParams.Offset, + "order", queryParams.Order, + )) +} + +func (c *Controllers) GetAccountCredentialsRegistrationDomain(ctx *fiber.Ctx) error { + requestID := getRequestID(ctx) + logger := c.buildLogger(requestID, accountCredentialsRegistrationDomainsLocation, "GetAccountCredentialsRegistrationDomain") + logRequest(logger, ctx) + + accountClaims, serviceErr := getAccountClaims(ctx) + if serviceErr != nil { + return serviceErrorResponse(logger, ctx, serviceErr) + } + + urlParams := params.DynamicRegistrationDomainURLParams{Domain: ctx.Params("domain")} + if err := c.validate.StructCtx(ctx.UserContext(), &urlParams); err != nil { + return validateURLParamsErrorResponse(logger, ctx, err) + } + + domainDTO, serviceErr := c.services.GetAccountCredentialsRegistrationDomain( + ctx.UserContext(), + services.GetAccountCredentialsRegistrationDomainOptions{ + RequestID: requestID, + AccountPublicID: accountClaims.AccountID, + Domain: urlParams.Domain, + }, + ) + if serviceErr != nil { + return serviceErrorResponse(logger, ctx, serviceErr) + } + + logResponse(logger, ctx, fiber.StatusOK) + return ctx.Status(fiber.StatusOK).JSON(domainDTO) +} + +func (c *Controllers) DeleteAccountCredentialsRegistrationDomain(ctx *fiber.Ctx) error { + requestID := getRequestID(ctx) + logger := c.buildLogger(requestID, accountCredentialsRegistrationDomainsLocation, "DeleteAccountCredentialsRegistrationDomain") + logRequest(logger, ctx) + + accountClaims, serviceErr := getAccountClaims(ctx) + if serviceErr != nil { + return serviceErrorResponse(logger, ctx, serviceErr) + } + + urlParams := params.DynamicRegistrationDomainURLParams{Domain: ctx.Params("domain")} + if err := c.validate.StructCtx(ctx.UserContext(), &urlParams); err != nil { + return validateURLParamsErrorResponse(logger, ctx, err) + } + + if serviceErr := c.services.DeleteAccountCredentialsRegistrationDomain( + ctx.UserContext(), + services.DeleteAccountCredentialsRegistrationDomainOptions{ + RequestID: requestID, + AccountPublicID: accountClaims.AccountID, + Domain: urlParams.Domain, + }, + ); serviceErr != nil { + return serviceErrorResponse(logger, ctx, serviceErr) + } + + logResponse(logger, ctx, fiber.StatusNoContent) + return ctx.SendStatus(fiber.StatusNoContent) +} + +func (c *Controllers) VerifyAccountCredentialsRegistrationDomain(ctx *fiber.Ctx) error { + requestID := getRequestID(ctx) + logger := c.buildLogger(requestID, accountCredentialsRegistrationDomainsLocation, "VerifyAccountCredentialsRegistrationDomain") + logRequest(logger, ctx) + + accountClaims, serviceErr := getAccountClaims(ctx) + if serviceErr != nil { + return serviceErrorResponse(logger, ctx, serviceErr) + } + + urlParams := params.DynamicRegistrationDomainURLParams{Domain: ctx.Params("domain")} + if err := c.validate.StructCtx(ctx.UserContext(), &urlParams); err != nil { + return validateURLParamsErrorResponse(logger, ctx, err) + } + + domainDTO, serviceErr := c.services.VerifyAccountCredentialsRegistrationDomain( + ctx.UserContext(), + services.VerifyAccountCredentialsRegistrationDomainOptions{ + RequestID: requestID, + AccountPublicID: accountClaims.AccountID, + Domain: urlParams.Domain, + }, + ) + if serviceErr != nil { + return serviceErrorResponse(logger, ctx, serviceErr) + } + + logResponse(logger, ctx, fiber.StatusOK) + return ctx.Status(fiber.StatusOK).JSON(domainDTO) +} + +func (c *Controllers) UpsertAccountCredentialsRegistrationDomainCode(ctx *fiber.Ctx) error { + requestID := getRequestID(ctx) + logger := c.buildLogger(requestID, accountCredentialsRegistrationDomainsLocation, "UpsertAccountCredentialsRegistrationDomain") + logRequest(logger, ctx) + + accountClaims, serviceErr := getAccountClaims(ctx) + if serviceErr != nil { + return serviceErrorResponse(logger, ctx, serviceErr) + } + + body := new(bodies.CreateDynamicRegistrationDomainBody) + if err := ctx.BodyParser(body); err != nil { + return parseRequestErrorResponse(logger, ctx, err) + } + if err := c.validate.StructCtx(ctx.UserContext(), body); err != nil { + return validateBodyErrorResponse(logger, ctx, err) + } + + domainDTO, serviceErr := c.services.SaveAccountCredentialsRegistrationDomainCode( + ctx.UserContext(), + services.SaveAccountCredentialsRegistrationDomainCodeOptions{ + RequestID: requestID, + AccountPublicID: accountClaims.AccountID, + AccountVersion: accountClaims.AccountVersion, + Domain: body.Domain, + }, + ) + if serviceErr != nil { + return serviceErrorResponse(logger, ctx, serviceErr) + } + + logResponse(logger, ctx, fiber.StatusOK) + return ctx.Status(fiber.StatusOK).JSON(domainDTO) +} + +func (c *Controllers) GetAccountCredentialsRegistrationDomainCode(ctx *fiber.Ctx) error { + requestID := getRequestID(ctx) + logger := c.buildLogger(requestID, accountCredentialsRegistrationDomainsLocation, "GetAccountCredentialsRegistrationDomainCode") + logRequest(logger, ctx) + + accountClaims, serviceErr := getAccountClaims(ctx) + if serviceErr != nil { + return serviceErrorResponse(logger, ctx, serviceErr) + } + + urlParams := params.DynamicRegistrationDomainURLParams{Domain: ctx.Params("domain")} + if err := c.validate.StructCtx(ctx.UserContext(), &urlParams); err != nil { + return validateURLParamsErrorResponse(logger, ctx, err) + } + + codeDTO, serviceErr := c.services.GetAccountCredentialsRegistrationDomainCode( + ctx.UserContext(), + services.GetAccountCredentialsRegistrationDomainCodeOptions{ + RequestID: requestID, + AccountPublicID: accountClaims.AccountID, + Domain: urlParams.Domain, + }, + ) + if serviceErr != nil { + return serviceErrorResponse(logger, ctx, serviceErr) + } + + logResponse(logger, ctx, fiber.StatusOK) + return ctx.Status(fiber.StatusOK).JSON(codeDTO) +} + +func (c *Controllers) DeleteAccountCredentialsRegistrationDomainCode(ctx *fiber.Ctx) error { + requestID := getRequestID(ctx) + logger := c.buildLogger(requestID, accountCredentialsRegistrationDomainsLocation, "DeleteAccountCredentialsRegistrationDomainCode") + logRequest(logger, ctx) + + accountClaims, serviceErr := getAccountClaims(ctx) + if serviceErr != nil { + return serviceErrorResponse(logger, ctx, serviceErr) + } + + urlParams := params.DynamicRegistrationDomainURLParams{Domain: ctx.Params("domain")} + if err := c.validate.StructCtx(ctx.UserContext(), &urlParams); err != nil { + return validateURLParamsErrorResponse(logger, ctx, err) + } + + if serviceErr := c.services.DeleteAccountCredentialsRegistrationDomainCode( + ctx.UserContext(), + services.DeleteAccountCredentialsRegistrationDomainCodeOptions{ + RequestID: requestID, + AccountPublicID: accountClaims.AccountID, + Domain: urlParams.Domain, + }, + ); serviceErr != nil { + return serviceErrorResponse(logger, ctx, serviceErr) + } + + logResponse(logger, ctx, fiber.StatusNoContent) + return ctx.SendStatus(fiber.StatusNoContent) +} diff --git a/idp/internal/controllers/account_dynamic_registration_configs.go b/idp/internal/controllers/account_dynamic_registration_configs.go index 1a9f916..cc1cd10 100644 --- a/idp/internal/controllers/account_dynamic_registration_configs.go +++ b/idp/internal/controllers/account_dynamic_registration_configs.go @@ -94,3 +94,33 @@ func (c *Controllers) GetAccountDynamicRegistrationConfig(ctx *fiber.Ctx) error logResponse(logger, ctx, fiber.StatusOK) return ctx.Status(fiber.StatusOK).JSON(&dto) } + +func (c *Controllers) DeleteAccountDynamicRegistrationConfig(ctx *fiber.Ctx) error { + requestID := getRequestID(ctx) + logger := c.buildLogger( + requestID, + accountDynamicRegistrationConfigsLocation, + "DeleteAccountDynamicRegistrationConfig", + ) + logRequest(logger, ctx) + + accountClaims, serviceErr := getAccountClaims(ctx) + if serviceErr != nil { + return serviceErrorResponse(logger, ctx, serviceErr) + } + + serviceErr = c.services.DeleteAccountDynamicRegistrationConfig( + ctx.UserContext(), + services.DeleteAccountDynamicRegistrationConfigOptions{ + RequestID: requestID, + AccountPublicID: accountClaims.AccountID, + AccountVersion: accountClaims.AccountVersion, + }, + ) + if serviceErr != nil { + return serviceErrorResponse(logger, ctx, serviceErr) + } + + logResponse(logger, ctx, fiber.StatusNoContent) + return ctx.SendStatus(fiber.StatusNoContent) +} diff --git a/idp/internal/controllers/bodies/dynamic_registration_domains.go b/idp/internal/controllers/bodies/dynamic_registration_domains.go new file mode 100644 index 0000000..8966c1b --- /dev/null +++ b/idp/internal/controllers/bodies/dynamic_registration_domains.go @@ -0,0 +1,11 @@ +// Copyright (c) 2025 Afonso Barracha +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +package bodies + +type CreateDynamicRegistrationDomainBody struct { + Domain string `json:"domain" validate:"required,fqdn,max=250"` +} diff --git a/idp/internal/controllers/middleware.go b/idp/internal/controllers/middleware.go index 0cafc4e..c2d1b90 100644 --- a/idp/internal/controllers/middleware.go +++ b/idp/internal/controllers/middleware.go @@ -201,8 +201,8 @@ func processHost(backendDomain string, host string) (string, error) { } hostArr := strings.Split(host, ".") - if len(hostArr) < 2 { - return "", errors.New("host must contain at least two parts") + if len(hostArr) < 3 { + return "", errors.New("host must contain at least three parts") } username := hostArr[0] diff --git a/idp/internal/controllers/params/dynamic_registration_domains.go b/idp/internal/controllers/params/dynamic_registration_domains.go new file mode 100644 index 0000000..e6a88b3 --- /dev/null +++ b/idp/internal/controllers/params/dynamic_registration_domains.go @@ -0,0 +1,18 @@ +// Copyright (c) 2025 Afonso Barracha +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +package params + +type DynamicRegistrationDomainURLParams struct { + Domain string `validate:"required,fqdn,max=250"` +} + +type DynamicRegistrationDomainQueryParams struct { + Limit int `validate:"min=1,max=100"` + Offset int `validate:"min=0"` + Order string `validate:"oneof=date domain"` + Search string `validate:"omitempty,min=1,max=250"` +} diff --git a/idp/internal/controllers/paths/domains.go b/idp/internal/controllers/paths/domains.go new file mode 100644 index 0000000..4112b24 --- /dev/null +++ b/idp/internal/controllers/paths/domains.go @@ -0,0 +1,15 @@ +// Copyright (c) 2025 Afonso Barracha +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +package paths + +const ( + Domains string = "/domains" + + SingleDomain string = "/:domain" + VerifyDomain string = "/:domain/verify" + DomainCode string = "/:domain/code" +) diff --git a/idp/internal/controllers/paths/dynamic_registration.go b/idp/internal/controllers/paths/dynamic_registration.go index b4cffa6..689f090 100644 --- a/idp/internal/controllers/paths/dynamic_registration.go +++ b/idp/internal/controllers/paths/dynamic_registration.go @@ -8,6 +8,4 @@ package paths const ( DynamicRegistrationBase string = "/dynamic-registration" - - Domains string = "/domains" ) diff --git a/idp/internal/providers/crypto/hmac.go b/idp/internal/providers/crypto/hmac.go index e969b08..3e30718 100644 --- a/idp/internal/providers/crypto/hmac.go +++ b/idp/internal/providers/crypto/hmac.go @@ -31,7 +31,9 @@ func decodeHMACSecret(secret string) ([]byte, error) { return base64.StdEncoding.DecodeString(secret) } -type StoreHMACSecret = func(dekID string, secretID string, encryptedSecret string) (int32, *exceptions.ServiceError) +type SecretID = string + +type StoreHMACSecret = func(dekID string, secretID SecretID, encryptedSecret string) (int32, *exceptions.ServiceError) type GenerateHMACSecretOptions struct { RequestID string @@ -78,9 +80,9 @@ func encodeHMACData(data []byte) string { return hex.EncodeToString(data) } -// func decodeHMACData(data string) ([]byte, error) { -// return hex.DecodeString(data) -// } +func decodeHMACData(data string) ([]byte, error) { + return hex.DecodeString(data) +} type GetHMACSecretFN = func() (string, DEKCiphertext, *exceptions.ServiceError) @@ -107,7 +109,7 @@ func (e *Crypto) HMACSha256Hash(ctx context.Context, opts HMACSha256HashOptions) secretID, dekCiphertext, serviceErr := opts.GetHMACSecretFN() if serviceErr != nil { logger.ErrorContext(ctx, "Failed to get HMAC secret", "serviceError", serviceErr) - return exceptions.NewInternalServerError() + return serviceErr } encodedSecret, serviceErr := e.DecryptWithDEK(ctx, DecryptWithDEKOptions{ @@ -120,7 +122,7 @@ func (e *Crypto) HMACSha256Hash(ctx context.Context, opts HMACSha256HashOptions) }) if serviceErr != nil { logger.ErrorContext(ctx, "Failed to decrypt HMAC secret", "serviceError", serviceErr) - return exceptions.NewInternalServerError() + return serviceErr } secret, err := decodeHMACSecret(encodedSecret) @@ -139,20 +141,71 @@ func (e *Crypto) HMACSha256Hash(ctx context.Context, opts HMACSha256HashOptions) return nil } -// type HMACSha256CompareHashOptions struct { -// RequestID string -// PlainText string -// GetHMACSecretFN GetHMACSecretFN -// GetDecryptDEKfn GetDEKtoDecrypt -// GetEncryptDEKfn GetDEKtoEncrypt -// StoreReEncryptedHMACSecretFN StoreReEncryptedData -// } - -// func (e *Crypto) HMACSha256CompareHash(ctx context.Context, opts HMACSha256CompareHashOptions) *exceptions.ServiceError { -// logger := utils.BuildLogger(e.logger, utils.LoggerOptions{ -// Location: hmacLocation, -// Method: "HMACSha256CompareHash", -// RequestID: opts.RequestID, -// }) -// logger.DebugContext(ctx, "Comparing HMAC SHA256...") -// } +type GetHMACSecretByIDfn = func(secretID SecretID) (DEKCiphertext, *exceptions.ServiceError) + +type GetHashedSecretFN = func() (SecretID, string, *exceptions.ServiceError) + +type HMACSha256CompareHashOptions struct { + RequestID string + PlainText string + HashedSecretFN GetHashedSecretFN + GetHMACSecretByIDFN GetHMACSecretByIDfn + GetDecryptDEKfn GetDEKtoDecrypt + GetEncryptDEKfn GetDEKtoEncrypt + StoreReEncryptedHMACSecretFN StoreReEncryptedData +} + +func (e *Crypto) HMACSha256CompareHash(ctx context.Context, opts HMACSha256CompareHashOptions) *exceptions.ServiceError { + logger := utils.BuildLogger(e.logger, utils.LoggerOptions{ + Location: hmacLocation, + Method: "HMACSha256CompareHash", + RequestID: opts.RequestID, + }) + logger.DebugContext(ctx, "Comparing HMAC SHA256...") + + secretID, hashedSecret, serviceErr := opts.HashedSecretFN() + if serviceErr != nil { + logger.ErrorContext(ctx, "Failed to get hashed secret", "serviceError", serviceErr) + return serviceErr + } + + dekCiphertext, serviceErr := opts.GetHMACSecretByIDFN(secretID) + if serviceErr != nil { + logger.ErrorContext(ctx, "Failed to get HMAC secret by ID", "serviceError", serviceErr) + return serviceErr + } + + encodedSecret, serviceErr := e.DecryptWithDEK(ctx, DecryptWithDEKOptions{ + RequestID: opts.RequestID, + GetDecryptDEKfn: opts.GetDecryptDEKfn, + GetEncryptDEKfn: opts.GetEncryptDEKfn, + StoreReEncryptedDataFn: opts.StoreReEncryptedHMACSecretFN, + EntityID: secretID, + Ciphertext: dekCiphertext, + }) + if serviceErr != nil { + logger.ErrorContext(ctx, "Failed to decrypt HMAC secret", "serviceError", serviceErr) + return serviceErr + } + + secret, err := decodeHMACSecret(encodedSecret) + if err != nil { + logger.ErrorContext(ctx, "Failed to decode HMAC secret", "error", err) + return exceptions.NewInternalServerError() + } + + hashedSecretBytes, err := decodeHMACData(hashedSecret) + if err != nil { + logger.ErrorContext(ctx, "Failed to decode hashed secret", "error", err) + return exceptions.NewInternalServerError() + } + + mac := hmac.New(sha256.New, secret) + mac.Write([]byte(opts.PlainText)) + if !utils.CompareSha256(mac.Sum(nil), hashedSecretBytes) { + logger.WarnContext(ctx, "HMAC SHA256 hash mismatch", "plainText", opts.PlainText, "hashedSecret", hashedSecret) + return exceptions.NewUnauthorizedError() + } + + return nil +} diff --git a/idp/internal/providers/database/account_dynamic_registration_domain_codes.sql.go b/idp/internal/providers/database/account_dynamic_registration_domain_codes.sql.go index 448d0b8..1abe6eb 100644 --- a/idp/internal/providers/database/account_dynamic_registration_domain_codes.sql.go +++ b/idp/internal/providers/database/account_dynamic_registration_domain_codes.sql.go @@ -7,38 +7,25 @@ package database import ( "context" - "time" ) const createAccountDynamicRegistrationDomainCode = `-- name: CreateAccountDynamicRegistrationDomainCode :exec INSERT INTO "account_dynamic_registration_domain_codes" ( - "account_id", "account_dynamic_registration_domain_id", - "verification_host", - "verification_code", - "verification_prefix", - "hmac_secret_id", - "expires_at" + "dynamic_registration_domain_code_id", + "account_id" ) VALUES ( $1, $2, - $3, - $4, - $5, - $6, - $7 + $3 ) ` type CreateAccountDynamicRegistrationDomainCodeParams struct { - AccountID int32 AccountDynamicRegistrationDomainID int32 - VerificationHost string - VerificationCode string - VerificationPrefix string - HmacSecretID string - ExpiresAt time.Time + DynamicRegistrationDomainCodeID int32 + AccountID int32 } // Copyright (c) 2025 Afonso Barracha @@ -47,14 +34,30 @@ type CreateAccountDynamicRegistrationDomainCodeParams struct { // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at https://mozilla.org/MPL/2.0/. func (q *Queries) CreateAccountDynamicRegistrationDomainCode(ctx context.Context, arg CreateAccountDynamicRegistrationDomainCodeParams) error { - _, err := q.db.Exec(ctx, createAccountDynamicRegistrationDomainCode, - arg.AccountID, - arg.AccountDynamicRegistrationDomainID, - arg.VerificationHost, - arg.VerificationCode, - arg.VerificationPrefix, - arg.HmacSecretID, - arg.ExpiresAt, - ) + _, err := q.db.Exec(ctx, createAccountDynamicRegistrationDomainCode, arg.AccountDynamicRegistrationDomainID, arg.DynamicRegistrationDomainCodeID, arg.AccountID) return err } + +const findDynamicRegistrationDomainCodeByAccountDynamicRegistrationDomainID = `-- name: FindDynamicRegistrationDomainCodeByAccountDynamicRegistrationDomainID :one +SELECT d.id, d.account_id, d.verification_host, d.verification_code, d.hmac_secret_id, d.verification_prefix, d.expires_at, d.created_at, d.updated_at FROM "dynamic_registration_domain_codes" "d" +LEFT JOIN "account_dynamic_registration_domain_codes" "a" ON "d"."id" = "a"."dynamic_registration_domain_code_id" +WHERE "a"."account_dynamic_registration_domain_id" = $1 +LIMIT 1 +` + +func (q *Queries) FindDynamicRegistrationDomainCodeByAccountDynamicRegistrationDomainID(ctx context.Context, accountDynamicRegistrationDomainID int32) (DynamicRegistrationDomainCode, error) { + row := q.db.QueryRow(ctx, findDynamicRegistrationDomainCodeByAccountDynamicRegistrationDomainID, accountDynamicRegistrationDomainID) + var i DynamicRegistrationDomainCode + err := row.Scan( + &i.ID, + &i.AccountID, + &i.VerificationHost, + &i.VerificationCode, + &i.HmacSecretID, + &i.VerificationPrefix, + &i.ExpiresAt, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} diff --git a/idp/internal/providers/database/account_dynamic_registration_domains.sql.go b/idp/internal/providers/database/account_dynamic_registration_domains.sql.go index 47db56a..352dd5a 100644 --- a/idp/internal/providers/database/account_dynamic_registration_domains.sql.go +++ b/idp/internal/providers/database/account_dynamic_registration_domains.sql.go @@ -11,6 +11,38 @@ import ( "github.com/google/uuid" ) +const countAccountDynamicRegistrationDomainsByAccountPublicID = `-- name: CountAccountDynamicRegistrationDomainsByAccountPublicID :one +SELECT COUNT(*) FROM "account_dynamic_registration_domains" +WHERE "account_public_id" = $1 +` + +func (q *Queries) CountAccountDynamicRegistrationDomainsByAccountPublicID(ctx context.Context, accountPublicID uuid.UUID) (int64, error) { + row := q.db.QueryRow(ctx, countAccountDynamicRegistrationDomainsByAccountPublicID, accountPublicID) + var count int64 + err := row.Scan(&count) + return count, err +} + +const countFilteredAccountDynamicRegistrationDomainsByAccountPublicID = `-- name: CountFilteredAccountDynamicRegistrationDomainsByAccountPublicID :one +SELECT COUNT(*) FROM "account_dynamic_registration_domains" +WHERE + "account_public_id" = $1 AND + "domain" ILIKE $2 +LIMIT 1 +` + +type CountFilteredAccountDynamicRegistrationDomainsByAccountPublicIDParams struct { + AccountPublicID uuid.UUID + Domain string +} + +func (q *Queries) CountFilteredAccountDynamicRegistrationDomainsByAccountPublicID(ctx context.Context, arg CountFilteredAccountDynamicRegistrationDomainsByAccountPublicIDParams) (int64, error) { + row := q.db.QueryRow(ctx, countFilteredAccountDynamicRegistrationDomainsByAccountPublicID, arg.AccountPublicID, arg.Domain) + var count int64 + err := row.Scan(&count) + return count, err +} + const createAccountDynamicRegistrationDomain = `-- name: CreateAccountDynamicRegistrationDomain :one INSERT INTO "account_dynamic_registration_domains" ( @@ -59,6 +91,116 @@ func (q *Queries) CreateAccountDynamicRegistrationDomain(ctx context.Context, ar return i, err } +const deleteAccountDynamicRegistrationDomain = `-- name: DeleteAccountDynamicRegistrationDomain :exec +DELETE FROM "account_dynamic_registration_domains" +WHERE "id" = $1 +` + +func (q *Queries) DeleteAccountDynamicRegistrationDomain(ctx context.Context, id int32) error { + _, err := q.db.Exec(ctx, deleteAccountDynamicRegistrationDomain, id) + return err +} + +const filterAccountDynamicRegistrationDomainsByAccountPublicIDOrderedByDomain = `-- name: FilterAccountDynamicRegistrationDomainsByAccountPublicIDOrderedByDomain :many +SELECT id, account_id, account_public_id, domain, verified_at, verification_method, created_at, updated_at FROM "account_dynamic_registration_domains" +WHERE + "account_public_id" = $1 AND + "domain" ILIKE $2 +ORDER BY "domain" ASC +LIMIT $3 OFFSET $4 +` + +type FilterAccountDynamicRegistrationDomainsByAccountPublicIDOrderedByDomainParams struct { + AccountPublicID uuid.UUID + Domain string + Limit int32 + Offset int32 +} + +func (q *Queries) FilterAccountDynamicRegistrationDomainsByAccountPublicIDOrderedByDomain(ctx context.Context, arg FilterAccountDynamicRegistrationDomainsByAccountPublicIDOrderedByDomainParams) ([]AccountDynamicRegistrationDomain, error) { + rows, err := q.db.Query(ctx, filterAccountDynamicRegistrationDomainsByAccountPublicIDOrderedByDomain, + arg.AccountPublicID, + arg.Domain, + arg.Limit, + arg.Offset, + ) + if err != nil { + return nil, err + } + defer rows.Close() + items := []AccountDynamicRegistrationDomain{} + for rows.Next() { + var i AccountDynamicRegistrationDomain + if err := rows.Scan( + &i.ID, + &i.AccountID, + &i.AccountPublicID, + &i.Domain, + &i.VerifiedAt, + &i.VerificationMethod, + &i.CreatedAt, + &i.UpdatedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const filterAccountDynamicRegistrationDomainsByAccountPublicIDOrderedByID = `-- name: FilterAccountDynamicRegistrationDomainsByAccountPublicIDOrderedByID :many +SELECT id, account_id, account_public_id, domain, verified_at, verification_method, created_at, updated_at FROM "account_dynamic_registration_domains" +WHERE + "account_public_id" = $1 AND + "domain" ILIKE $2 +ORDER BY "id" DESC +LIMIT $3 OFFSET $4 +` + +type FilterAccountDynamicRegistrationDomainsByAccountPublicIDOrderedByIDParams struct { + AccountPublicID uuid.UUID + Domain string + Limit int32 + Offset int32 +} + +func (q *Queries) FilterAccountDynamicRegistrationDomainsByAccountPublicIDOrderedByID(ctx context.Context, arg FilterAccountDynamicRegistrationDomainsByAccountPublicIDOrderedByIDParams) ([]AccountDynamicRegistrationDomain, error) { + rows, err := q.db.Query(ctx, filterAccountDynamicRegistrationDomainsByAccountPublicIDOrderedByID, + arg.AccountPublicID, + arg.Domain, + arg.Limit, + arg.Offset, + ) + if err != nil { + return nil, err + } + defer rows.Close() + items := []AccountDynamicRegistrationDomain{} + for rows.Next() { + var i AccountDynamicRegistrationDomain + if err := rows.Scan( + &i.ID, + &i.AccountID, + &i.AccountPublicID, + &i.Domain, + &i.VerifiedAt, + &i.VerificationMethod, + &i.CreatedAt, + &i.UpdatedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const findAccountDynamicRegistrationDomainByAccountPublicIDAndDomain = `-- name: FindAccountDynamicRegistrationDomainByAccountPublicIDAndDomain :one SELECT id, account_id, account_public_id, domain, verified_at, verification_method, created_at, updated_at FROM "account_dynamic_registration_domains" WHERE "account_public_id" = $1 AND "domain" = $2 LIMIT 1 ` @@ -83,3 +225,116 @@ func (q *Queries) FindAccountDynamicRegistrationDomainByAccountPublicIDAndDomain ) return i, err } + +const findPaginatedAccountDynamicRegistrationDomainsByAccountPublicIDOrderedByDomain = `-- name: FindPaginatedAccountDynamicRegistrationDomainsByAccountPublicIDOrderedByDomain :many +SELECT id, account_id, account_public_id, domain, verified_at, verification_method, created_at, updated_at FROM "account_dynamic_registration_domains" +WHERE "account_public_id" = $1 +ORDER BY "domain" ASC +LIMIT $2 OFFSET $3 +` + +type FindPaginatedAccountDynamicRegistrationDomainsByAccountPublicIDOrderedByDomainParams struct { + AccountPublicID uuid.UUID + Limit int32 + Offset int32 +} + +func (q *Queries) FindPaginatedAccountDynamicRegistrationDomainsByAccountPublicIDOrderedByDomain(ctx context.Context, arg FindPaginatedAccountDynamicRegistrationDomainsByAccountPublicIDOrderedByDomainParams) ([]AccountDynamicRegistrationDomain, error) { + rows, err := q.db.Query(ctx, findPaginatedAccountDynamicRegistrationDomainsByAccountPublicIDOrderedByDomain, arg.AccountPublicID, arg.Limit, arg.Offset) + if err != nil { + return nil, err + } + defer rows.Close() + items := []AccountDynamicRegistrationDomain{} + for rows.Next() { + var i AccountDynamicRegistrationDomain + if err := rows.Scan( + &i.ID, + &i.AccountID, + &i.AccountPublicID, + &i.Domain, + &i.VerifiedAt, + &i.VerificationMethod, + &i.CreatedAt, + &i.UpdatedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const findPaginatedAccountDynamicRegistrationDomainsByAccountPublicIDOrderedByID = `-- name: FindPaginatedAccountDynamicRegistrationDomainsByAccountPublicIDOrderedByID :many +SELECT id, account_id, account_public_id, domain, verified_at, verification_method, created_at, updated_at FROM "account_dynamic_registration_domains" +WHERE "account_public_id" = $1 +ORDER BY "id" DESC +LIMIT $2 OFFSET $3 +` + +type FindPaginatedAccountDynamicRegistrationDomainsByAccountPublicIDOrderedByIDParams struct { + AccountPublicID uuid.UUID + Limit int32 + Offset int32 +} + +func (q *Queries) FindPaginatedAccountDynamicRegistrationDomainsByAccountPublicIDOrderedByID(ctx context.Context, arg FindPaginatedAccountDynamicRegistrationDomainsByAccountPublicIDOrderedByIDParams) ([]AccountDynamicRegistrationDomain, error) { + rows, err := q.db.Query(ctx, findPaginatedAccountDynamicRegistrationDomainsByAccountPublicIDOrderedByID, arg.AccountPublicID, arg.Limit, arg.Offset) + if err != nil { + return nil, err + } + defer rows.Close() + items := []AccountDynamicRegistrationDomain{} + for rows.Next() { + var i AccountDynamicRegistrationDomain + if err := rows.Scan( + &i.ID, + &i.AccountID, + &i.AccountPublicID, + &i.Domain, + &i.VerifiedAt, + &i.VerificationMethod, + &i.CreatedAt, + &i.UpdatedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const verifyAccountDynamicRegistrationDomain = `-- name: VerifyAccountDynamicRegistrationDomain :one +UPDATE "account_dynamic_registration_domains" +SET + "verified_at" = NOW(), + "verification_method" = $2 +WHERE "id" = $1 RETURNING id, account_id, account_public_id, domain, verified_at, verification_method, created_at, updated_at +` + +type VerifyAccountDynamicRegistrationDomainParams struct { + ID int32 + VerificationMethod DomainVerificationMethod +} + +func (q *Queries) VerifyAccountDynamicRegistrationDomain(ctx context.Context, arg VerifyAccountDynamicRegistrationDomainParams) (AccountDynamicRegistrationDomain, error) { + row := q.db.QueryRow(ctx, verifyAccountDynamicRegistrationDomain, arg.ID, arg.VerificationMethod) + var i AccountDynamicRegistrationDomain + err := row.Scan( + &i.ID, + &i.AccountID, + &i.AccountPublicID, + &i.Domain, + &i.VerifiedAt, + &i.VerificationMethod, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} diff --git a/idp/internal/providers/database/dynamic_registration_domain_codes.sql.go b/idp/internal/providers/database/dynamic_registration_domain_codes.sql.go new file mode 100644 index 0000000..0c7de03 --- /dev/null +++ b/idp/internal/providers/database/dynamic_registration_domain_codes.sql.go @@ -0,0 +1,99 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.29.0 +// source: dynamic_registration_domain_codes.sql + +package database + +import ( + "context" + "time" +) + +const createDynamicRegistrationDomainCode = `-- name: CreateDynamicRegistrationDomainCode :one + +INSERT INTO "dynamic_registration_domain_codes" ( + "account_id", + "verification_host", + "verification_code", + "verification_prefix", + "hmac_secret_id", + "expires_at" +) VALUES ( + $1, + $2, + $3, + $4, + $5, + $6 +) RETURNING "id" +` + +type CreateDynamicRegistrationDomainCodeParams struct { + AccountID int32 + VerificationHost string + VerificationCode string + VerificationPrefix string + HmacSecretID string + ExpiresAt time.Time +} + +// Copyright (c) 2025 Afonso Barracha +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. +func (q *Queries) CreateDynamicRegistrationDomainCode(ctx context.Context, arg CreateDynamicRegistrationDomainCodeParams) (int32, error) { + row := q.db.QueryRow(ctx, createDynamicRegistrationDomainCode, + arg.AccountID, + arg.VerificationHost, + arg.VerificationCode, + arg.VerificationPrefix, + arg.HmacSecretID, + arg.ExpiresAt, + ) + var id int32 + err := row.Scan(&id) + return id, err +} + +const deleteDynamicRegistrationDomainCode = `-- name: DeleteDynamicRegistrationDomainCode :exec +DELETE FROM "dynamic_registration_domain_codes" +WHERE "id" = $1 +` + +func (q *Queries) DeleteDynamicRegistrationDomainCode(ctx context.Context, id int32) error { + _, err := q.db.Exec(ctx, deleteDynamicRegistrationDomainCode, id) + return err +} + +const updateDynamicRegistrationDomainCode = `-- name: UpdateDynamicRegistrationDomainCode :exec +UPDATE "dynamic_registration_domain_codes" SET + "verification_host" = $2, + "verification_code" = $3, + "verification_prefix" = $4, + "hmac_secret_id" = $5, + "expires_at" = $6 +WHERE "id" = $1 +` + +type UpdateDynamicRegistrationDomainCodeParams struct { + ID int32 + VerificationHost string + VerificationCode string + VerificationPrefix string + HmacSecretID string + ExpiresAt time.Time +} + +func (q *Queries) UpdateDynamicRegistrationDomainCode(ctx context.Context, arg UpdateDynamicRegistrationDomainCodeParams) error { + _, err := q.db.Exec(ctx, updateDynamicRegistrationDomainCode, + arg.ID, + arg.VerificationHost, + arg.VerificationCode, + arg.VerificationPrefix, + arg.HmacSecretID, + arg.ExpiresAt, + ) + return err +} diff --git a/idp/internal/providers/database/migrations/20241213231542_create_initial_schema.up.sql b/idp/internal/providers/database/migrations/20241213231542_create_initial_schema.up.sql index f648f51..7eadfe2 100644 --- a/idp/internal/providers/database/migrations/20241213231542_create_initial_schema.up.sql +++ b/idp/internal/providers/database/migrations/20241213231542_create_initial_schema.up.sql @@ -1,6 +1,6 @@ -- SQL dump generated using DBML (dbml.dbdiagram.io) -- Database: PostgreSQL --- Generated at: 2025-08-16T10:00:02.079Z +-- Generated at: 2025-08-17T08:47:57.041Z CREATE TYPE "kek_usage" AS ENUM ( 'global', @@ -30,7 +30,8 @@ CREATE TYPE "token_key_type" AS ENUM ( 'client_credentials', 'email_verification', 'password_reset', - '2fa_authentication' + '2fa_authentication', + 'dynamic_registration' ); CREATE TYPE "two_factor_type" AS ENUM ( @@ -77,8 +78,12 @@ CREATE TYPE "account_credentials_scope" AS ENUM ( 'account:users:write', 'account:apps:read', 'account:apps:write', + 'account:apps:configs:read', + 'account:apps:configs:write', 'account:credentials:read', 'account:credentials:write', + 'account:credentials:configs:read', + 'account:credentials:configs:write', 'account:auth_providers:read' ); @@ -573,10 +578,9 @@ CREATE TABLE "account_dynamic_registration_domains" ( "updated_at" timestamptz NOT NULL DEFAULT (now()) ); -CREATE TABLE "account_dynamic_registration_domain_codes" ( +CREATE TABLE "dynamic_registration_domain_codes" ( "id" serial PRIMARY KEY, "account_id" integer NOT NULL, - "account_dynamic_registration_domain_id" integer NOT NULL, "verification_host" varchar(50) NOT NULL, "verification_code" text NOT NULL, "hmac_secret_id" varchar(22) NOT NULL, @@ -586,6 +590,14 @@ CREATE TABLE "account_dynamic_registration_domain_codes" ( "updated_at" timestamptz NOT NULL DEFAULT (now()) ); +CREATE TABLE "account_dynamic_registration_domain_codes" ( + "account_dynamic_registration_domain_id" integer NOT NULL, + "dynamic_registration_domain_code_id" integer NOT NULL, + "account_id" integer NOT NULL, + "created_at" timestamptz NOT NULL DEFAULT (now()), + PRIMARY KEY ("account_dynamic_registration_domain_id", "dynamic_registration_domain_code_id") +); + CREATE TABLE "app_dynamic_registration_configs" ( "id" serial PRIMARY KEY, "account_id" integer NOT NULL, @@ -892,9 +904,13 @@ CREATE INDEX "account_dynamic_registration_domains_domain_idx" ON "account_dynam CREATE UNIQUE INDEX "account_dynamic_registration_domains_account_public_id_domain_uidx" ON "account_dynamic_registration_domains" ("account_public_id", "domain"); +CREATE INDEX "account_dynamic_registration_domain_codes_account_id_idx" ON "dynamic_registration_domain_codes" ("account_id"); + CREATE INDEX "account_dynamic_registration_domain_codes_account_id_idx" ON "account_dynamic_registration_domain_codes" ("account_id"); -CREATE INDEX "account_dynamic_registration_domain_codes_account_dynamic_registration_domain_id_idx" ON "account_dynamic_registration_domain_codes" ("account_dynamic_registration_domain_id"); +CREATE UNIQUE INDEX "account_dynamic_registration_domain_codes_account_dynamic_registration_domain_id_uidx" ON "account_dynamic_registration_domain_codes" ("account_dynamic_registration_domain_id"); + +CREATE UNIQUE INDEX "account_dynamic_registration_domain_codes_dynamic_registration_domain_code_id_uidx" ON "account_dynamic_registration_domain_codes" ("dynamic_registration_domain_code_id"); CREATE INDEX "app_dynamic_registration_configs_account_id_idx" ON "app_dynamic_registration_configs" ("account_id"); @@ -1034,11 +1050,15 @@ ALTER TABLE "account_dynamic_registration_configs" ADD FOREIGN KEY ("account_id" ALTER TABLE "account_dynamic_registration_domains" ADD FOREIGN KEY ("account_id") REFERENCES "accounts" ("id") ON DELETE CASCADE; +ALTER TABLE "dynamic_registration_domain_codes" ADD FOREIGN KEY ("account_id") REFERENCES "accounts" ("id") ON DELETE CASCADE; + +ALTER TABLE "dynamic_registration_domain_codes" ADD FOREIGN KEY ("hmac_secret_id") REFERENCES "account_hmac_secrets" ("secret_id") ON DELETE CASCADE ON UPDATE CASCADE; + ALTER TABLE "account_dynamic_registration_domain_codes" ADD FOREIGN KEY ("account_id") REFERENCES "accounts" ("id") ON DELETE CASCADE; ALTER TABLE "account_dynamic_registration_domain_codes" ADD FOREIGN KEY ("account_dynamic_registration_domain_id") REFERENCES "account_dynamic_registration_domains" ("id") ON DELETE CASCADE; -ALTER TABLE "account_dynamic_registration_domain_codes" ADD FOREIGN KEY ("hmac_secret_id") REFERENCES "account_hmac_secrets" ("secret_id") ON DELETE CASCADE ON UPDATE CASCADE; +ALTER TABLE "account_dynamic_registration_domain_codes" ADD FOREIGN KEY ("dynamic_registration_domain_code_id") REFERENCES "dynamic_registration_domain_codes" ("id") ON DELETE CASCADE; ALTER TABLE "app_dynamic_registration_configs" ADD FOREIGN KEY ("account_id") REFERENCES "accounts" ("id") ON DELETE CASCADE; diff --git a/idp/internal/providers/database/models.go b/idp/internal/providers/database/models.go index 398c09a..0ca084a 100644 --- a/idp/internal/providers/database/models.go +++ b/idp/internal/providers/database/models.go @@ -16,16 +16,20 @@ import ( type AccountCredentialsScope string const ( - AccountCredentialsScopeEmail AccountCredentialsScope = "email" - AccountCredentialsScopeProfile AccountCredentialsScope = "profile" - AccountCredentialsScopeAccountAdmin AccountCredentialsScope = "account:admin" - AccountCredentialsScopeAccountUsersRead AccountCredentialsScope = "account:users:read" - AccountCredentialsScopeAccountUsersWrite AccountCredentialsScope = "account:users:write" - AccountCredentialsScopeAccountAppsRead AccountCredentialsScope = "account:apps:read" - AccountCredentialsScopeAccountAppsWrite AccountCredentialsScope = "account:apps:write" - AccountCredentialsScopeAccountCredentialsRead AccountCredentialsScope = "account:credentials:read" - AccountCredentialsScopeAccountCredentialsWrite AccountCredentialsScope = "account:credentials:write" - AccountCredentialsScopeAccountAuthProvidersRead AccountCredentialsScope = "account:auth_providers:read" + AccountCredentialsScopeEmail AccountCredentialsScope = "email" + AccountCredentialsScopeProfile AccountCredentialsScope = "profile" + AccountCredentialsScopeAccountAdmin AccountCredentialsScope = "account:admin" + AccountCredentialsScopeAccountUsersRead AccountCredentialsScope = "account:users:read" + AccountCredentialsScopeAccountUsersWrite AccountCredentialsScope = "account:users:write" + AccountCredentialsScopeAccountAppsRead AccountCredentialsScope = "account:apps:read" + AccountCredentialsScopeAccountAppsWrite AccountCredentialsScope = "account:apps:write" + AccountCredentialsScopeAccountAppsConfigsRead AccountCredentialsScope = "account:apps:configs:read" + AccountCredentialsScopeAccountAppsConfigsWrite AccountCredentialsScope = "account:apps:configs:write" + AccountCredentialsScopeAccountCredentialsRead AccountCredentialsScope = "account:credentials:read" + AccountCredentialsScopeAccountCredentialsWrite AccountCredentialsScope = "account:credentials:write" + AccountCredentialsScopeAccountCredentialsConfigsRead AccountCredentialsScope = "account:credentials:configs:read" + AccountCredentialsScopeAccountCredentialsConfigsWrite AccountCredentialsScope = "account:credentials:configs:write" + AccountCredentialsScopeAccountAuthProvidersRead AccountCredentialsScope = "account:auth_providers:read" ) func (e *AccountCredentialsScope) Scan(src interface{}) error { @@ -907,13 +911,14 @@ func (ns NullTokenCryptoSuite) Value() (driver.Value, error) { type TokenKeyType string const ( - TokenKeyTypeAccess TokenKeyType = "access" - TokenKeyTypeRefresh TokenKeyType = "refresh" - TokenKeyTypeIDToken TokenKeyType = "id_token" - TokenKeyTypeClientCredentials TokenKeyType = "client_credentials" - TokenKeyTypeEmailVerification TokenKeyType = "email_verification" - TokenKeyTypePasswordReset TokenKeyType = "password_reset" - TokenKeyType2faAuthentication TokenKeyType = "2fa_authentication" + TokenKeyTypeAccess TokenKeyType = "access" + TokenKeyTypeRefresh TokenKeyType = "refresh" + TokenKeyTypeIDToken TokenKeyType = "id_token" + TokenKeyTypeClientCredentials TokenKeyType = "client_credentials" + TokenKeyTypeEmailVerification TokenKeyType = "email_verification" + TokenKeyTypePasswordReset TokenKeyType = "password_reset" + TokenKeyType2faAuthentication TokenKeyType = "2fa_authentication" + TokenKeyTypeDynamicRegistration TokenKeyType = "dynamic_registration" ) func (e *TokenKeyType) Scan(src interface{}) error { @@ -1266,16 +1271,10 @@ type AccountDynamicRegistrationDomain struct { } type AccountDynamicRegistrationDomainCode struct { - ID int32 - AccountID int32 AccountDynamicRegistrationDomainID int32 - VerificationHost string - VerificationCode string - HmacSecretID string - VerificationPrefix string - ExpiresAt time.Time + DynamicRegistrationDomainCodeID int32 + AccountID int32 CreatedAt time.Time - UpdatedAt time.Time } type AccountHmacSecret struct { @@ -1459,6 +1458,18 @@ type DataEncryptionKey struct { UpdatedAt time.Time } +type DynamicRegistrationDomainCode struct { + ID int32 + AccountID int32 + VerificationHost string + VerificationCode string + HmacSecretID string + VerificationPrefix string + ExpiresAt time.Time + CreatedAt time.Time + UpdatedAt time.Time +} + type KeyEncryptionKey struct { ID int32 Kid uuid.UUID diff --git a/idp/internal/providers/database/queries/account_dynamic_registration_domain_codes.sql b/idp/internal/providers/database/queries/account_dynamic_registration_domain_codes.sql index 9c99bdf..a92cd0f 100644 --- a/idp/internal/providers/database/queries/account_dynamic_registration_domain_codes.sql +++ b/idp/internal/providers/database/queries/account_dynamic_registration_domain_codes.sql @@ -6,19 +6,17 @@ -- name: CreateAccountDynamicRegistrationDomainCode :exec INSERT INTO "account_dynamic_registration_domain_codes" ( - "account_id", "account_dynamic_registration_domain_id", - "verification_host", - "verification_code", - "verification_prefix", - "hmac_secret_id", - "expires_at" + "dynamic_registration_domain_code_id", + "account_id" ) VALUES ( $1, $2, - $3, - $4, - $5, - $6, - $7 -); \ No newline at end of file + $3 +); + +-- name: FindDynamicRegistrationDomainCodeByAccountDynamicRegistrationDomainID :one +SELECT "d".* FROM "dynamic_registration_domain_codes" "d" +LEFT JOIN "account_dynamic_registration_domain_codes" "a" ON "d"."id" = "a"."dynamic_registration_domain_code_id" +WHERE "a"."account_dynamic_registration_domain_id" = $1 +LIMIT 1; diff --git a/idp/internal/providers/database/queries/account_dynamic_registration_domains.sql b/idp/internal/providers/database/queries/account_dynamic_registration_domains.sql index 65895af..f892c6d 100644 --- a/idp/internal/providers/database/queries/account_dynamic_registration_domains.sql +++ b/idp/internal/providers/database/queries/account_dynamic_registration_domains.sql @@ -18,4 +18,54 @@ INSERT INTO "account_dynamic_registration_domains" ( ) RETURNING *; -- name: FindAccountDynamicRegistrationDomainByAccountPublicIDAndDomain :one -SELECT * FROM "account_dynamic_registration_domains" WHERE "account_public_id" = $1 AND "domain" = $2 LIMIT 1; \ No newline at end of file +SELECT * FROM "account_dynamic_registration_domains" WHERE "account_public_id" = $1 AND "domain" = $2 LIMIT 1; + +-- name: VerifyAccountDynamicRegistrationDomain :one +UPDATE "account_dynamic_registration_domains" +SET + "verified_at" = NOW(), + "verification_method" = $2 +WHERE "id" = $1 RETURNING *; + +-- name: FindPaginatedAccountDynamicRegistrationDomainsByAccountPublicIDOrderedByID :many +SELECT * FROM "account_dynamic_registration_domains" +WHERE "account_public_id" = $1 +ORDER BY "id" DESC +LIMIT $2 OFFSET $3; + +-- name: FindPaginatedAccountDynamicRegistrationDomainsByAccountPublicIDOrderedByDomain :many +SELECT * FROM "account_dynamic_registration_domains" +WHERE "account_public_id" = $1 +ORDER BY "domain" ASC +LIMIT $2 OFFSET $3; + +-- name: CountAccountDynamicRegistrationDomainsByAccountPublicID :one +SELECT COUNT(*) FROM "account_dynamic_registration_domains" +WHERE "account_public_id" = $1; + +-- name: FilterAccountDynamicRegistrationDomainsByAccountPublicIDOrderedByID :many +SELECT * FROM "account_dynamic_registration_domains" +WHERE + "account_public_id" = $1 AND + "domain" ILIKE $2 +ORDER BY "id" DESC +LIMIT $3 OFFSET $4; + +-- name: FilterAccountDynamicRegistrationDomainsByAccountPublicIDOrderedByDomain :many +SELECT * FROM "account_dynamic_registration_domains" +WHERE + "account_public_id" = $1 AND + "domain" ILIKE $2 +ORDER BY "domain" ASC +LIMIT $3 OFFSET $4; + +-- name: CountFilteredAccountDynamicRegistrationDomainsByAccountPublicID :one +SELECT COUNT(*) FROM "account_dynamic_registration_domains" +WHERE + "account_public_id" = $1 AND + "domain" ILIKE $2 +LIMIT 1; + +-- name: DeleteAccountDynamicRegistrationDomain :exec +DELETE FROM "account_dynamic_registration_domains" +WHERE "id" = $1; diff --git a/idp/internal/providers/database/queries/dynamic_registration_domain_codes.sql b/idp/internal/providers/database/queries/dynamic_registration_domain_codes.sql new file mode 100644 index 0000000..9b73f9e --- /dev/null +++ b/idp/internal/providers/database/queries/dynamic_registration_domain_codes.sql @@ -0,0 +1,35 @@ +-- Copyright (c) 2025 Afonso Barracha +-- +-- This Source Code Form is subject to the terms of the Mozilla Public +-- License, v. 2.0. If a copy of the MPL was not distributed with this +-- file, You can obtain one at https://mozilla.org/MPL/2.0/. + +-- name: CreateDynamicRegistrationDomainCode :one +INSERT INTO "dynamic_registration_domain_codes" ( + "account_id", + "verification_host", + "verification_code", + "verification_prefix", + "hmac_secret_id", + "expires_at" +) VALUES ( + $1, + $2, + $3, + $4, + $5, + $6 +) RETURNING "id"; + +-- name: UpdateDynamicRegistrationDomainCode :exec +UPDATE "dynamic_registration_domain_codes" SET + "verification_host" = $2, + "verification_code" = $3, + "verification_prefix" = $4, + "hmac_secret_id" = $5, + "expires_at" = $6 +WHERE "id" = $1; + +-- name: DeleteDynamicRegistrationDomainCode :exec +DELETE FROM "dynamic_registration_domain_codes" +WHERE "id" = $1; diff --git a/idp/internal/providers/tokens/accounts.go b/idp/internal/providers/tokens/accounts.go index 0412b73..050e1e3 100644 --- a/idp/internal/providers/tokens/accounts.go +++ b/idp/internal/providers/tokens/accounts.go @@ -22,16 +22,20 @@ import ( type AccountScope = string const ( - AccountScopeEmail AccountScope = "email" - AccountScopeProfile AccountScope = "profile" - AccountScopeAdmin AccountScope = "account:admin" - AccountScopeUsersRead AccountScope = "account:users:read" - AccountScopeUsersWrite AccountScope = "account:users:write" - AccountScopeAppsRead AccountScope = "account:apps:read" - AccountScopeAppsWrite AccountScope = "account:apps:write" - AccountScopeCredentialsRead AccountScope = "account:credentials:read" - AccountScopeCredentialsWrite AccountScope = "account:credentials:write" - AccountScopeAuthProvidersRead AccountScope = "account:auth_providers:read" + AccountScopeEmail AccountScope = "email" + AccountScopeProfile AccountScope = "profile" + AccountScopeAdmin AccountScope = "account:admin" + AccountScopeUsersRead AccountScope = "account:users:read" + AccountScopeUsersWrite AccountScope = "account:users:write" + AccountScopeAppsRead AccountScope = "account:apps:read" + AccountScopeAppsWrite AccountScope = "account:apps:write" + AccountScopeAppsConfigsRead AccountScope = "account:apps:configs:read" + AccountScopeAppsConfigsWrite AccountScope = "account:apps:configs:write" + AccountScopeCredentialsRead AccountScope = "account:credentials:read" + AccountScopeCredentialsWrite AccountScope = "account:credentials:write" + AccountScopeCredentialsConfigsRead AccountScope = "account:credentials:configs:read" + AccountScopeCredentialsConfigsWrite AccountScope = "account:credentials:configs:write" + AccountScopeAuthProvidersRead AccountScope = "account:auth_providers:read" ) var baseAuthScopes = []AccountScope{AccountScopeEmail, AccountScopeProfile} diff --git a/idp/internal/server/routes.go b/idp/internal/server/routes.go index 5b1729d..ddc942d 100644 --- a/idp/internal/server/routes.go +++ b/idp/internal/server/routes.go @@ -8,7 +8,7 @@ package server func (s *FiberServer) RegisterFiberRoutes() { s.routes.HealthRoutes(s.App) - s.routes.AccountDynamicRegistrationRoutes(s.App) + s.routes.AccountDynamicRegistrationConfigurationRoutes(s.App) s.routes.OAuthRoutes(s.App) s.routes.AuthRoutes(s.App) s.routes.AccountCredentialsRoutes(s.App) diff --git a/idp/internal/server/routes/account_dynamic_registration.go b/idp/internal/server/routes/account_dynamic_registration.go index 3d0c647..915f68c 100644 --- a/idp/internal/server/routes/account_dynamic_registration.go +++ b/idp/internal/server/routes/account_dynamic_registration.go @@ -13,24 +13,72 @@ import ( "github.com/tugascript/devlogs/idp/internal/providers/tokens" ) -func (r *Routes) AccountDynamicRegistrationRoutes(app *fiber.App) { +func (r *Routes) AccountDynamicRegistrationConfigurationRoutes(app *fiber.App) { router := v1PathRouter(app).Group( paths.AccountsBase+paths.CredentialsBase+paths.DynamicRegistrationBase, r.controllers.AccountAccessClaimsMiddleware, - r.controllers.AdminScopeMiddleware, ) - credentialsWriteScopeMiddleware := r.controllers.ScopeMiddleware(tokens.AccountScopeCredentialsWrite) - credentialsReadScopeMiddleware := r.controllers.ScopeMiddleware(tokens.AccountScopeCredentialsRead) + credentialsConfigsWriteScopeMiddleware := r.controllers.ScopeMiddleware(tokens.AccountScopeCredentialsConfigsWrite) + credentialsConfigsReadScopeMiddleware := r.controllers.ScopeMiddleware(tokens.AccountScopeCredentialsConfigsRead) + // Dynamic Registration Config router.Get( paths.Config, - credentialsReadScopeMiddleware, + credentialsConfigsReadScopeMiddleware, r.controllers.GetAccountDynamicRegistrationConfig, ) router.Put( paths.Config, - credentialsWriteScopeMiddleware, + credentialsConfigsWriteScopeMiddleware, r.controllers.UpsertAccountDynamicRegistrationConfig, ) + router.Delete( + paths.Config, + credentialsConfigsWriteScopeMiddleware, + r.controllers.DeleteAccountDynamicRegistrationConfig, + ) + + // Dynamic Registration Domains + router.Post( + paths.Domains, + credentialsConfigsWriteScopeMiddleware, + r.controllers.CreateAccountCredentialsRegistrationDomain, + ) + router.Get( + paths.Domains, + credentialsConfigsReadScopeMiddleware, + r.controllers.ListAccountCredentialsRegistrationDomains, + ) + router.Get( + paths.Domains+paths.SingleDomain, + credentialsConfigsReadScopeMiddleware, + r.controllers.GetAccountCredentialsRegistrationDomain, + ) + router.Delete( + paths.Domains+paths.SingleDomain, + credentialsConfigsWriteScopeMiddleware, + r.controllers.DeleteAccountCredentialsRegistrationDomain, + ) + router.Post( + paths.Domains+paths.VerifyDomain, + credentialsConfigsWriteScopeMiddleware, + r.controllers.VerifyAccountCredentialsRegistrationDomain, + ) + // Dynamic Registration Domains Code + router.Get( + paths.Domains+paths.DomainCode, + credentialsConfigsReadScopeMiddleware, + r.controllers.GetAccountCredentialsRegistrationDomainCode, + ) + router.Put( + paths.Domains+paths.DomainCode, + credentialsConfigsWriteScopeMiddleware, + r.controllers.UpsertAccountCredentialsRegistrationDomainCode, + ) + router.Delete( + paths.Domains+paths.DomainCode, + credentialsConfigsWriteScopeMiddleware, + r.controllers.DeleteAccountCredentialsRegistrationDomainCode, + ) } diff --git a/idp/internal/server/validations/scope.go b/idp/internal/server/validations/scope.go index 6627e9d..a67e663 100644 --- a/idp/internal/server/validations/scope.go +++ b/idp/internal/server/validations/scope.go @@ -7,15 +7,16 @@ package validations import ( - "github.com/go-playground/validator/v10" "regexp" + + "github.com/go-playground/validator/v10" ) const singleScopeValidatorTag string = "single_scope" const multipleScopeValidatorTag string = "multiple_scope" -var singleScopeRegex = regexp.MustCompile(`^[a-z\d]+(?:([-_:])[a-z\d]+)*$`) +var singleScopeRegex = regexp.MustCompile(`^[a-z\d]+(?:([-_:\.])[a-z\d]+)*$`) var spacesRegex = regexp.MustCompile(`\s+`) func singleScopeValidator(fl validator.FieldLevel) bool { diff --git a/idp/internal/services/account_credentials_registration.go b/idp/internal/services/account_credentials_registration.go deleted file mode 100644 index 3d57cd8..0000000 --- a/idp/internal/services/account_credentials_registration.go +++ /dev/null @@ -1,163 +0,0 @@ -// Copyright (c) 2025 Afonso Barracha -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at https://mozilla.org/MPL/2.0/. - -package services - -import ( - "context" - "fmt" - "slices" - "time" - - "github.com/google/uuid" - - "github.com/tugascript/devlogs/idp/internal/exceptions" - "github.com/tugascript/devlogs/idp/internal/providers/crypto" - "github.com/tugascript/devlogs/idp/internal/providers/database" - "github.com/tugascript/devlogs/idp/internal/services/dtos" - "github.com/tugascript/devlogs/idp/internal/utils" -) - -const ( - accountCredentialsRegistrationDomainLocation string = "account_credentials_registration_domain" - - domainCodeByteLength int = 32 -) - -type CreateAccountCredentialsRegistrationDomainOptions struct { - RequestID string - AccountPublicID uuid.UUID - AccountVersion int32 - Domain string -} - -func (s *Services) CreateAccountCredentialsRegistrationDomain( - ctx context.Context, - opts CreateAccountCredentialsRegistrationDomainOptions, -) (dtos.AccountCredentialsRegistrationDomainDTO, *exceptions.ServiceError) { - logger := s.buildLogger(opts.RequestID, accountCredentialsRegistrationDomainLocation, "CreateAccountCredentialsRegistrationDomain").With( - "accountPublicID", opts.AccountPublicID, - "domain", opts.Domain, - ) - logger.InfoContext(ctx, "Creating account credentials registration domain...") - - dynamicRegistrationConfig, serviceErr := s.GetAccountDynamicRegistrationConfig(ctx, GetAccountDynamicRegistrationConfigOptions{ - RequestID: opts.RequestID, - AccountPublicID: opts.AccountPublicID, - }) - if serviceErr != nil { - if serviceErr.Code != exceptions.CodeNotFound { - logger.WarnContext(ctx, "Account dynamic registration config not found", "serviceError", serviceErr) - return dtos.AccountCredentialsRegistrationDomainDTO{}, exceptions.NewNotFoundValidationError("Dynamic registration config not found") - } - return dtos.AccountCredentialsRegistrationDomainDTO{}, serviceErr - } - if len(dynamicRegistrationConfig.WhitelistedDomains) > 0 && !slices.Contains(dynamicRegistrationConfig.WhitelistedDomains, opts.Domain) { - logger.WarnContext(ctx, "Domain is not whitelisted", "domain", opts.Domain) - return dtos.AccountCredentialsRegistrationDomainDTO{}, exceptions.NewForbiddenValidationError("Domain is not whitelisted") - } - - if _, err := s.database.FindAccountDynamicRegistrationDomainByAccountPublicIDAndDomain(ctx, database.FindAccountDynamicRegistrationDomainByAccountPublicIDAndDomainParams{ - AccountPublicID: opts.AccountPublicID, - Domain: opts.Domain, - }); err != nil { - serviceErr := exceptions.FromDBError(err) - if serviceErr.Code != exceptions.CodeNotFound { - logger.WarnContext(ctx, "Failed to find account dynamic registration domain", "error", err) - return dtos.AccountCredentialsRegistrationDomainDTO{}, serviceErr - } - } else { - logger.InfoContext(ctx, "Account dynamic registration domain already exists", "domain", opts.Domain) - return dtos.AccountCredentialsRegistrationDomainDTO{}, exceptions.NewConflictError("Account credentials registration domain already exists") - } - - accountDTO, serviceErr := s.GetAccountByPublicIDAndVersion(ctx, GetAccountByPublicIDAndVersionOptions{ - RequestID: opts.RequestID, - PublicID: opts.AccountPublicID, - Version: opts.AccountVersion, - }) - if serviceErr != nil { - logger.WarnContext(ctx, "Failed to get account ID", "serviceError", serviceErr) - return dtos.AccountCredentialsRegistrationDomainDTO{}, serviceErr - } - - qrs, txn, err := s.database.BeginTx(ctx) - if err != nil { - logger.ErrorContext(ctx, "Failed to start transaction", "error", err) - return dtos.AccountCredentialsRegistrationDomainDTO{}, exceptions.FromDBError(err) - } - defer func() { - logger.DebugContext(ctx, "Finalizing transaction") - s.database.FinalizeTx(ctx, txn, err, serviceErr) - }() - - domain, err := qrs.CreateAccountDynamicRegistrationDomain(ctx, database.CreateAccountDynamicRegistrationDomainParams{ - AccountID: accountDTO.ID(), - AccountPublicID: opts.AccountPublicID, - Domain: opts.Domain, - VerificationMethod: database.DomainVerificationMethodDnsTxtRecord, - }) - if err != nil { - logger.ErrorContext(ctx, "Failed to create account dynamic registration domain", "error", err) - serviceErr = exceptions.FromDBError(err) - return dtos.AccountCredentialsRegistrationDomainDTO{}, serviceErr - } - - code, err := utils.GenerateBase64Secret(domainCodeByteLength) - if err != nil { - logger.ErrorContext(ctx, "Failed to generate domain code", "error", err) - serviceErr = exceptions.NewInternalServerError() - return dtos.AccountCredentialsRegistrationDomainDTO{}, serviceErr - } - - verificationPrefix := fmt.Sprintf("%s-verification", accountDTO.Username) - exp := time.Now().Add(s.accountDomainVerificationTTL) - if serviceErr = s.crypto.HMACSha256Hash(ctx, crypto.HMACSha256HashOptions{ - RequestID: opts.RequestID, - PlainText: code, - GetDecryptDEKfn: s.BuildGetDecAccountDEKFn(ctx, BuildGetDecAccountDEKFnOptions{ - RequestID: opts.RequestID, - AccountID: accountDTO.ID(), - Queries: qrs, - }), - GetEncryptDEKfn: s.BuildGetEncAccountDEKfn(ctx, BuildGetEncAccountDEKOptions{ - RequestID: opts.RequestID, - AccountID: accountDTO.ID(), - Queries: qrs, - }), - GetHMACSecretFN: s.BuildGetHMACSecretFN(ctx, BuildGetHMACSecretFNOptions{ - RequestID: opts.RequestID, - AccountID: accountDTO.ID(), - Queries: qrs, - }), - StoreReEncryptedHMACSecretFN: s.BuildUpdateHMACSecretFN(ctx, BuildUpdateHMACSecretFNOptions{ - RequestID: opts.RequestID, - AccountID: accountDTO.ID(), - Queries: qrs, - }), - StoreHashedDataFN: func(secretID string, hashedData string) *exceptions.ServiceError { - if err := qrs.CreateAccountDynamicRegistrationDomainCode(ctx, database.CreateAccountDynamicRegistrationDomainCodeParams{ - AccountID: accountDTO.ID(), - AccountDynamicRegistrationDomainID: domain.ID, - VerificationCode: hashedData, - VerificationPrefix: verificationPrefix, - VerificationHost: s.accountDomainVerificationHost, - HmacSecretID: secretID, - ExpiresAt: exp, - }); err != nil { - logger.ErrorContext(ctx, "Failed to create account dynamic registration domain code", "error", err) - return exceptions.FromDBError(err) - } - return nil - }, - }); serviceErr != nil { - logger.ErrorContext(ctx, "Failed to hash code", "serviceError", serviceErr) - return dtos.AccountCredentialsRegistrationDomainDTO{}, serviceErr - } - - logger.InfoContext(ctx, "Created account dynamic registration domain successfully") - return dtos.MapAccountCredentialsRegistrationDomainToDTOWithCode(&domain, s.accountDomainVerificationHost, verificationPrefix, code, exp), nil -} diff --git a/idp/internal/services/account_credentials_registration_domains.go b/idp/internal/services/account_credentials_registration_domains.go new file mode 100644 index 0000000..ea32e06 --- /dev/null +++ b/idp/internal/services/account_credentials_registration_domains.go @@ -0,0 +1,819 @@ +// Copyright (c) 2025 Afonso Barracha +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +package services + +import ( + "context" + "fmt" + "slices" + "time" + + "github.com/google/uuid" + + "github.com/tugascript/devlogs/idp/internal/exceptions" + "github.com/tugascript/devlogs/idp/internal/providers/crypto" + "github.com/tugascript/devlogs/idp/internal/providers/database" + "github.com/tugascript/devlogs/idp/internal/services/dtos" + "github.com/tugascript/devlogs/idp/internal/utils" +) + +const ( + accountCredentialsRegistrationDomainsLocation string = "account_credentials_registration_domains" + + domainCodeByteLength int = 32 +) + +type CreateAccountCredentialsRegistrationDomainOptions struct { + RequestID string + AccountPublicID uuid.UUID + AccountVersion int32 + Domain string +} + +func (s *Services) CreateAccountCredentialsRegistrationDomain( + ctx context.Context, + opts CreateAccountCredentialsRegistrationDomainOptions, +) (dtos.DynamicRegistrationDomainDTO, *exceptions.ServiceError) { + logger := s.buildLogger(opts.RequestID, accountCredentialsRegistrationDomainsLocation, "CreateAccountCredentialsRegistrationDomain").With( + "accountPublicID", opts.AccountPublicID, + "domain", opts.Domain, + ) + logger.InfoContext(ctx, "Creating account credentials registration domain...") + + dynamicRegistrationConfig, serviceErr := s.GetAccountDynamicRegistrationConfig(ctx, GetAccountDynamicRegistrationConfigOptions{ + RequestID: opts.RequestID, + AccountPublicID: opts.AccountPublicID, + }) + if serviceErr != nil { + if serviceErr.Code != exceptions.CodeNotFound { + logger.WarnContext(ctx, "Account dynamic registration config not found", "serviceError", serviceErr) + return dtos.DynamicRegistrationDomainDTO{}, exceptions.NewNotFoundValidationError("Dynamic registration config not found") + } + return dtos.DynamicRegistrationDomainDTO{}, serviceErr + } + if len(dynamicRegistrationConfig.WhitelistedDomains) > 0 && !slices.Contains(dynamicRegistrationConfig.WhitelistedDomains, opts.Domain) { + logger.WarnContext(ctx, "Domain is not whitelisted", "domain", opts.Domain) + return dtos.DynamicRegistrationDomainDTO{}, exceptions.NewForbiddenValidationError("Domain is not whitelisted") + } + + if _, err := s.database.FindAccountDynamicRegistrationDomainByAccountPublicIDAndDomain(ctx, database.FindAccountDynamicRegistrationDomainByAccountPublicIDAndDomainParams{ + AccountPublicID: opts.AccountPublicID, + Domain: opts.Domain, + }); err != nil { + serviceErr := exceptions.FromDBError(err) + if serviceErr.Code != exceptions.CodeNotFound { + logger.WarnContext(ctx, "Failed to find account dynamic registration domain", "error", err) + return dtos.DynamicRegistrationDomainDTO{}, serviceErr + } + } else { + logger.InfoContext(ctx, "Account dynamic registration domain already exists", "domain", opts.Domain) + return dtos.DynamicRegistrationDomainDTO{}, exceptions.NewConflictError("Account credentials registration domain already exists") + } + + accountDTO, serviceErr := s.GetAccountByPublicIDAndVersion(ctx, GetAccountByPublicIDAndVersionOptions{ + RequestID: opts.RequestID, + PublicID: opts.AccountPublicID, + Version: opts.AccountVersion, + }) + if serviceErr != nil { + logger.WarnContext(ctx, "Failed to get account ID", "serviceError", serviceErr) + return dtos.DynamicRegistrationDomainDTO{}, serviceErr + } + + qrs, txn, err := s.database.BeginTx(ctx) + if err != nil { + logger.ErrorContext(ctx, "Failed to start transaction", "error", err) + return dtos.DynamicRegistrationDomainDTO{}, exceptions.FromDBError(err) + } + defer func() { + logger.DebugContext(ctx, "Finalizing transaction") + s.database.FinalizeTx(ctx, txn, err, serviceErr) + }() + + domain, err := qrs.CreateAccountDynamicRegistrationDomain(ctx, database.CreateAccountDynamicRegistrationDomainParams{ + AccountID: accountDTO.ID(), + AccountPublicID: opts.AccountPublicID, + Domain: opts.Domain, + VerificationMethod: database.DomainVerificationMethodDnsTxtRecord, + }) + if err != nil { + logger.ErrorContext(ctx, "Failed to create account dynamic registration domain", "error", err) + serviceErr = exceptions.FromDBError(err) + return dtos.DynamicRegistrationDomainDTO{}, serviceErr + } + + code, err := utils.GenerateBase64Secret(domainCodeByteLength) + if err != nil { + logger.ErrorContext(ctx, "Failed to generate domain code", "error", err) + serviceErr = exceptions.NewInternalServerError() + return dtos.DynamicRegistrationDomainDTO{}, serviceErr + } + + verificationPrefix := fmt.Sprintf("%s-verification", accountDTO.Username) + exp := time.Now().Add(s.accountDomainVerificationTTL) + if serviceErr = s.crypto.HMACSha256Hash(ctx, crypto.HMACSha256HashOptions{ + RequestID: opts.RequestID, + PlainText: code, + GetDecryptDEKfn: s.BuildGetDecAccountDEKFn(ctx, BuildGetDecAccountDEKFnOptions{ + RequestID: opts.RequestID, + AccountID: accountDTO.ID(), + Queries: qrs, + }), + GetEncryptDEKfn: s.BuildGetEncAccountDEKfn(ctx, BuildGetEncAccountDEKOptions{ + RequestID: opts.RequestID, + AccountID: accountDTO.ID(), + Queries: qrs, + }), + GetHMACSecretFN: s.BuildGetHMACSecretFN(ctx, BuildGetHMACSecretFNOptions{ + RequestID: opts.RequestID, + AccountID: accountDTO.ID(), + Queries: qrs, + }), + StoreReEncryptedHMACSecretFN: s.BuildUpdateHMACSecretFN(ctx, BuildUpdateHMACSecretFNOptions{ + RequestID: opts.RequestID, + AccountID: accountDTO.ID(), + Queries: qrs, + }), + StoreHashedDataFN: func(secretID string, hashedData string) *exceptions.ServiceError { + codeID, err := qrs.CreateDynamicRegistrationDomainCode( + ctx, + database.CreateDynamicRegistrationDomainCodeParams{ + AccountID: accountDTO.ID(), + VerificationCode: hashedData, + VerificationPrefix: verificationPrefix, + VerificationHost: s.accountDomainVerificationHost, + HmacSecretID: secretID, + ExpiresAt: exp, + }, + ) + if err != nil { + logger.ErrorContext(ctx, "Failed to create account dynamic registration domain code", "error", err) + return exceptions.FromDBError(err) + } + if err := qrs.CreateAccountDynamicRegistrationDomainCode( + ctx, + database.CreateAccountDynamicRegistrationDomainCodeParams{ + AccountDynamicRegistrationDomainID: domain.ID, + DynamicRegistrationDomainCodeID: codeID, + AccountID: accountDTO.ID(), + }, + ); err != nil { + logger.ErrorContext(ctx, "Failed to create account dynamic registration domain code association", "error", err) + return exceptions.FromDBError(err) + } + return nil + }, + }); serviceErr != nil { + logger.ErrorContext(ctx, "Failed to hash code", "serviceError", serviceErr) + return dtos.DynamicRegistrationDomainDTO{}, serviceErr + } + + logger.InfoContext(ctx, "Created account dynamic registration domain successfully") + return dtos.MapAccountCredentialsRegistrationDomainToDTOWithCode(&domain, s.accountDomainVerificationHost, verificationPrefix, code, exp), nil +} + +type GetAccountCredentialsRegistrationDomainOptions struct { + RequestID string + AccountPublicID uuid.UUID + Domain string +} + +func (s *Services) GetAccountCredentialsRegistrationDomain( + ctx context.Context, + opts GetAccountCredentialsRegistrationDomainOptions, +) (dtos.DynamicRegistrationDomainDTO, *exceptions.ServiceError) { + logger := s.buildLogger(opts.RequestID, accountCredentialsRegistrationDomainsLocation, "GetAccountCredentialsRegistrationDomain").With( + "accountPublicID", opts.AccountPublicID, + "domain", opts.Domain, + ) + logger.InfoContext(ctx, "Getting account credentials registration domain...") + + domainDTO, err := s.database.FindAccountDynamicRegistrationDomainByAccountPublicIDAndDomain(ctx, database.FindAccountDynamicRegistrationDomainByAccountPublicIDAndDomainParams{ + AccountPublicID: opts.AccountPublicID, + Domain: opts.Domain, + }) + if err != nil { + serviceErr := exceptions.FromDBError(err) + if serviceErr.Code != exceptions.CodeNotFound { + logger.ErrorContext(ctx, "Failed to find account dynamic registration domain", "error", err) + return dtos.DynamicRegistrationDomainDTO{}, serviceErr + } + + logger.WarnContext(ctx, "Account dynamic registration domain not found", "domain", opts.Domain) + return dtos.DynamicRegistrationDomainDTO{}, serviceErr + } + + logger.InfoContext(ctx, "Found account dynamic registration domain", "domain", opts.Domain) + return dtos.MapAccountCredentialsRegistrationDomainToDTO(&domainDTO), nil +} + +type ListAccountCredentialsRegistrationDomainsOptions struct { + RequestID string + AccountPublicID uuid.UUID + Offset int32 + Limit int32 + Order string +} + +func (s *Services) ListAccountCredentialsRegistrationDomains( + ctx context.Context, + opts ListAccountCredentialsRegistrationDomainsOptions, +) ([]dtos.DynamicRegistrationDomainDTO, int64, *exceptions.ServiceError) { + logger := s.buildLogger(opts.RequestID, accountCredentialsRegistrationDomainsLocation, "ListAccountCredentialsRegistrationDomains").With( + "accountPublicID", opts.AccountPublicID, + "offset", opts.Offset, + "limit", opts.Limit, + "order", opts.Order, + ) + logger.InfoContext(ctx, "Listing account credentials registration domains...") + + order := utils.Lowered(opts.Order) + var domains []database.AccountDynamicRegistrationDomain + var err error + switch order { + case "date": + domains, err = s.database.FindPaginatedAccountDynamicRegistrationDomainsByAccountPublicIDOrderedByID( + ctx, + database.FindPaginatedAccountDynamicRegistrationDomainsByAccountPublicIDOrderedByIDParams{ + AccountPublicID: opts.AccountPublicID, + Limit: opts.Limit, + Offset: opts.Offset, + }, + ) + case "domain": + domains, err = s.database.FindPaginatedAccountDynamicRegistrationDomainsByAccountPublicIDOrderedByDomain( + ctx, + database.FindPaginatedAccountDynamicRegistrationDomainsByAccountPublicIDOrderedByDomainParams{ + AccountPublicID: opts.AccountPublicID, + Limit: opts.Limit, + Offset: opts.Offset, + }, + ) + default: + logger.WarnContext(ctx, "Invalid order parameter", "order", opts.Order) + return nil, 0, exceptions.NewValidationError("Invalid order parameter") + } + if err != nil { + logger.ErrorContext(ctx, "Failed to find account dynamic registration domains", "error", err) + return nil, 0, exceptions.FromDBError(err) + } + + count, err := s.database.CountAccountDynamicRegistrationDomainsByAccountPublicID(ctx, opts.AccountPublicID) + if err != nil { + logger.ErrorContext(ctx, "Failed to count account dynamic registration domains", "error", err) + return nil, 0, exceptions.FromDBError(err) + } + + logger.InfoContext(ctx, "Listed account dynamic registration domains successfully") + return utils.MapSlice(domains, dtos.MapAccountCredentialsRegistrationDomainToDTO), count, nil +} + +type FilterAccountCredentialsRegistrationDomainsOptions struct { + RequestID string + AccountPublicID uuid.UUID + Search string + Offset int32 + Limit int32 + Order string +} + +func (s *Services) FilterAccountCredentialsRegistrationDomains( + ctx context.Context, + opts FilterAccountCredentialsRegistrationDomainsOptions, +) ([]dtos.DynamicRegistrationDomainDTO, int64, *exceptions.ServiceError) { + logger := s.buildLogger(opts.RequestID, accountCredentialsRegistrationDomainsLocation, "FilterAccountCredentialsRegistrationDomains").With( + "accountPublicID", opts.AccountPublicID, + "search", opts.Search, + "offset", opts.Offset, + "limit", opts.Limit, + "order", opts.Order, + ) + logger.InfoContext(ctx, "Filtering account credentials registration domains...") + + domainSearch := utils.DbSearch(opts.Search) + order := utils.Lowered(opts.Order) + var domains []database.AccountDynamicRegistrationDomain + var err error + + switch order { + case "date": + domains, err = s.database.FilterAccountDynamicRegistrationDomainsByAccountPublicIDOrderedByID( + ctx, + database.FilterAccountDynamicRegistrationDomainsByAccountPublicIDOrderedByIDParams{ + AccountPublicID: opts.AccountPublicID, + Domain: domainSearch, + Limit: opts.Limit, + Offset: opts.Offset, + }, + ) + case "domain": + domains, err = s.database.FilterAccountDynamicRegistrationDomainsByAccountPublicIDOrderedByDomain( + ctx, + database.FilterAccountDynamicRegistrationDomainsByAccountPublicIDOrderedByDomainParams{ + AccountPublicID: opts.AccountPublicID, + Domain: domainSearch, + Limit: opts.Limit, + Offset: opts.Offset, + }, + ) + default: + logger.WarnContext(ctx, "Invalid order parameter", "order", opts.Order) + return nil, 0, exceptions.NewValidationError("Invalid order parameter") + } + if err != nil { + logger.ErrorContext(ctx, "Failed to filter account dynamic registration domains", "error", err) + return nil, 0, exceptions.FromDBError(err) + } + + count, err := s.database.CountFilteredAccountDynamicRegistrationDomainsByAccountPublicID( + ctx, + database.CountFilteredAccountDynamicRegistrationDomainsByAccountPublicIDParams{ + AccountPublicID: opts.AccountPublicID, + Domain: domainSearch, + }, + ) + if err != nil { + logger.ErrorContext(ctx, "Failed to count filtered account dynamic registration domains", "error", err) + return nil, 0, exceptions.FromDBError(err) + } + + logger.InfoContext(ctx, "Filtered account dynamic registration domains successfully") + return utils.MapSlice(domains, dtos.MapAccountCredentialsRegistrationDomainToDTO), count, nil +} + +type DeleteAccountCredentialsRegistrationDomainOptions struct { + RequestID string + AccountPublicID uuid.UUID + AccountVersion int32 + Domain string +} + +func (s *Services) DeleteAccountCredentialsRegistrationDomain( + ctx context.Context, + opts DeleteAccountCredentialsRegistrationDomainOptions, +) *exceptions.ServiceError { + logger := s.buildLogger(opts.RequestID, accountCredentialsRegistrationDomainsLocation, "DeleteAccountCredentialsRegistrationDomain").With( + "accountPublicID", opts.AccountPublicID, + "domain", opts.Domain, + ) + logger.InfoContext(ctx, "Deleting account credentials registration domain...") + + if _, serviceErr := s.GetAccountIDByPublicIDAndVersion(ctx, GetAccountIDByPublicIDAndVersionOptions{ + RequestID: "", + PublicID: uuid.UUID{}, + Version: 0, + }); serviceErr != nil { + logger.WarnContext(ctx, "Failed to get account ID", "serviceError", serviceErr) + return serviceErr + } + + domainDTO, serviceErr := s.GetAccountCredentialsRegistrationDomain(ctx, GetAccountCredentialsRegistrationDomainOptions{ + RequestID: opts.RequestID, + AccountPublicID: opts.AccountPublicID, + Domain: opts.Domain, + }) + if serviceErr != nil { + logger.WarnContext(ctx, "Failed to get account credentials registration domain", "error", serviceErr) + return serviceErr + } + if err := s.database.DeleteAccountDynamicRegistrationDomain(ctx, domainDTO.ID()); err != nil { + logger.ErrorContext(ctx, "Failed to delete account dynamic registration domain", "error", err) + return exceptions.FromDBError(err) + } + + logger.InfoContext(ctx, "Deleted account credentials registration domain") + return nil +} + +type VerifyAccountCredentialsRegistrationDomainOptions struct { + RequestID string + AccountPublicID uuid.UUID + AccountVersion int32 + Domain string + VerificationCode string +} + +func (s *Services) VerifyAccountCredentialsRegistrationDomain( + ctx context.Context, + opts VerifyAccountCredentialsRegistrationDomainOptions, +) (dtos.DynamicRegistrationDomainDTO, *exceptions.ServiceError) { + logger := s.buildLogger(opts.RequestID, accountCredentialsRegistrationDomainsLocation, "VerifyAccountCredentialsRegistrationDomain").With( + "accountPublicID", opts.AccountPublicID, + "domain", opts.Domain, + "verificationCode", opts.VerificationCode, + ) + logger.InfoContext(ctx, "Verifying account credentials registration domain...") + + domainDTO, serviceErr := s.GetAccountCredentialsRegistrationDomain(ctx, GetAccountCredentialsRegistrationDomainOptions{ + RequestID: opts.RequestID, + AccountPublicID: opts.AccountPublicID, + Domain: opts.Domain, + }) + if serviceErr != nil { + return dtos.DynamicRegistrationDomainDTO{}, serviceErr + } + + if domainDTO.Verified { + logger.InfoContext(ctx, "Account credentials registration domain already verified", "domain", opts.Domain) + return dtos.DynamicRegistrationDomainDTO{}, exceptions.NewConflictError("Account credentials registration domain already verified") + } + + if domainDTO.VerificationMethod != database.DomainVerificationMethodDnsTxtRecord { + logger.WarnContext(ctx, "Invalid verification method", "verificationMethod", domainDTO.VerificationMethod) + return dtos.DynamicRegistrationDomainDTO{}, exceptions.NewValidationError("Invalid verification method") + } + + accountID, serviceErr := s.GetAccountIDByPublicIDAndVersion(ctx, GetAccountIDByPublicIDAndVersionOptions{ + RequestID: opts.RequestID, + PublicID: opts.AccountPublicID, + Version: opts.AccountVersion, + }) + if serviceErr != nil { + return dtos.DynamicRegistrationDomainDTO{}, serviceErr + } + + code, err := s.database.FindDynamicRegistrationDomainCodeByAccountDynamicRegistrationDomainID(ctx, domainDTO.ID()) + if err != nil { + serviceErr := exceptions.FromDBError(err) + if serviceErr.Code != exceptions.CodeNotFound { + logger.ErrorContext(ctx, "Failed to find account dynamic registration domain code", "error", err) + return dtos.DynamicRegistrationDomainDTO{}, serviceErr + } + + logger.WarnContext(ctx, "Account dynamic registration domain code not found", "error", err) + return dtos.DynamicRegistrationDomainDTO{}, exceptions.NewNotFoundValidationError("Account dynamic registration domain code not found") + } + + if code.ExpiresAt.Before(time.Now()) { + logger.WarnContext(ctx, "Account dynamic registration domain code expired", "expiresAt", code.ExpiresAt.Unix()) + if err := s.database.DeleteDynamicRegistrationDomainCode(ctx, code.ID); err != nil { + logger.ErrorContext(ctx, "Failed to delete account dynamic registration domain code", "error", err) + return dtos.DynamicRegistrationDomainDTO{}, exceptions.FromDBError(err) + } + + return dtos.DynamicRegistrationDomainDTO{}, exceptions.NewValidationError("Registration domain code expired, generate a new one") + } + + if serviceErr := s.crypto.HMACSha256CompareHash(ctx, crypto.HMACSha256CompareHashOptions{ + RequestID: opts.RequestID, + PlainText: code.VerificationCode, + HashedSecretFN: func() (string, string, *exceptions.ServiceError) { + return code.HmacSecretID, code.VerificationCode, nil + }, + GetHMACSecretByIDFN: s.BuildGetHMACSecretByIDFN(ctx, BuildGetHMACSecretByIDFNOptions{ + RequestID: opts.RequestID, + AccountID: accountID, + }), + GetDecryptDEKfn: s.BuildGetDecAccountDEKFn(ctx, BuildGetDecAccountDEKFnOptions{ + RequestID: opts.RequestID, + AccountID: accountID, + }), + GetEncryptDEKfn: s.BuildGetEncAccountDEKfn(ctx, BuildGetEncAccountDEKOptions{ + RequestID: opts.RequestID, + AccountID: accountID, + }), + StoreReEncryptedHMACSecretFN: s.BuildUpdateHMACSecretFN(ctx, BuildUpdateHMACSecretFNOptions{ + RequestID: opts.RequestID, + AccountID: accountID, + }), + }); serviceErr != nil { + logger.WarnContext(ctx, "Failed to verify account credentials registration domain", "serviceError", serviceErr) + return dtos.DynamicRegistrationDomainDTO{}, serviceErr + } + + if serviceErr := s.verifyTXTRecord(ctx, verifyTXTRecordOptions{ + requestID: opts.RequestID, + host: code.VerificationHost, + domain: opts.Domain, + prefix: code.VerificationPrefix, + code: opts.VerificationCode, + }); serviceErr != nil { + logger.WarnContext(ctx, "Failed to verify TXT record", "serviceError", serviceErr) + return dtos.DynamicRegistrationDomainDTO{}, serviceErr + } + + qrs, txn, err := s.database.BeginTx(ctx) + if err != nil { + logger.ErrorContext(ctx, "Failed to start transaction", "error", err) + serviceErr = exceptions.FromDBError(err) + return dtos.DynamicRegistrationDomainDTO{}, serviceErr + } + defer func() { + logger.DebugContext(ctx, "Finalizing transaction") + s.database.FinalizeTx(ctx, txn, err, serviceErr) + }() + + domain, err := qrs.VerifyAccountDynamicRegistrationDomain( + ctx, + database.VerifyAccountDynamicRegistrationDomainParams{ + ID: domainDTO.ID(), + VerificationMethod: database.DomainVerificationMethodDnsTxtRecord, + }, + ) + if err != nil { + logger.ErrorContext(ctx, "Failed to verify account dynamic registration domain", "error", err) + serviceErr = exceptions.FromDBError(err) + return dtos.DynamicRegistrationDomainDTO{}, serviceErr + } + if err = qrs.DeleteDynamicRegistrationDomainCode(ctx, code.ID); err != nil { + logger.ErrorContext(ctx, "Failed to delete account dynamic registration domain code", "error", err) + serviceErr = exceptions.FromDBError(err) + return dtos.DynamicRegistrationDomainDTO{}, serviceErr + } + + logger.InfoContext(ctx, "Verified account credentials registration domain successfully", "domain", opts.Domain) + return dtos.MapAccountCredentialsRegistrationDomainToDTO(&domain), nil +} + +type GetAccountCredentialsRegistrationDomainCodeOptions struct { + RequestID string + AccountPublicID uuid.UUID + Domain string +} + +func (s *Services) GetAccountCredentialsRegistrationDomainCode( + ctx context.Context, + opts GetAccountCredentialsRegistrationDomainCodeOptions, +) (dtos.DynamicRegistrationDomainCodeDTO, *exceptions.ServiceError) { + logger := s.buildLogger(opts.RequestID, accountCredentialsRegistrationDomainsLocation, "GetAccountCredentialsRegistrationDomainCode").With( + "accountPublicID", opts.AccountPublicID, + "domain", opts.Domain, + ) + logger.InfoContext(ctx, "Getting account credentials registration domain code...") + + domainDTO, serviceErr := s.GetAccountCredentialsRegistrationDomain(ctx, GetAccountCredentialsRegistrationDomainOptions{ + RequestID: opts.RequestID, + AccountPublicID: opts.AccountPublicID, + Domain: opts.Domain, + }) + if serviceErr != nil { + return dtos.DynamicRegistrationDomainCodeDTO{}, serviceErr + } + + if domainDTO.VerificationMethod != database.DomainVerificationMethodDnsTxtRecord { + logger.WarnContext(ctx, "Invalid verification method", "verificationMethod", domainDTO.VerificationMethod) + return dtos.DynamicRegistrationDomainCodeDTO{}, exceptions.NewValidationError("Invalid verification method") + } + if domainDTO.Verified { + logger.InfoContext(ctx, "Verification code not available for verified domain") + return dtos.DynamicRegistrationDomainCodeDTO{}, exceptions.NewConflictError("Verification code not available for verified domain") + } + + code, err := s.database.FindDynamicRegistrationDomainCodeByAccountDynamicRegistrationDomainID(ctx, domainDTO.ID()) + if err != nil { + serviceErr := exceptions.FromDBError(err) + if serviceErr.Code != exceptions.CodeNotFound { + logger.ErrorContext(ctx, "Failed to find account dynamic registration domain code", "error", err) + return dtos.DynamicRegistrationDomainCodeDTO{}, serviceErr + } + + logger.WarnContext(ctx, "Account dynamic registration domain code not found", "error", err) + return dtos.DynamicRegistrationDomainCodeDTO{}, exceptions.NewNotFoundValidationError("Account dynamic registration domain code not found") + } + + logger.InfoContext(ctx, "Found account dynamic registration domain code") + return dtos.MapDynamicRegistrationDomainCodeToDTO(&code), nil +} + +type SaveAccountCredentialsRegistrationDomainCodeOptions struct { + RequestID string + AccountPublicID uuid.UUID + AccountVersion int32 + Domain string +} + +func (s *Services) SaveAccountCredentialsRegistrationDomainCode( + ctx context.Context, + opts SaveAccountCredentialsRegistrationDomainCodeOptions, +) (dtos.DynamicRegistrationDomainCodeDTO, *exceptions.ServiceError) { + logger := s.buildLogger(opts.RequestID, accountCredentialsRegistrationDomainsLocation, "SaveAccountCredentialsRegistrationDomainCode").With( + "accountPublicID", opts.AccountPublicID, + "domain", opts.Domain, + ) + logger.InfoContext(ctx, "Saving account credentials registration domain code...") + + domainDTO, serviceErr := s.GetAccountCredentialsRegistrationDomain(ctx, GetAccountCredentialsRegistrationDomainOptions{ + RequestID: opts.RequestID, + AccountPublicID: opts.AccountPublicID, + Domain: opts.Domain, + }) + if serviceErr != nil { + if serviceErr.Code != exceptions.CodeNotFound { + logger.ErrorContext(ctx, "Failed to find account dynamic registration domain code", "serviceError", serviceErr) + return dtos.DynamicRegistrationDomainCodeDTO{}, serviceErr + } + + logger.WarnContext(ctx, "Account dynamic registration domain not found") + return dtos.DynamicRegistrationDomainCodeDTO{}, exceptions.NewNotFoundValidationError("Account dynamic registration domain code not found") + } + + if domainDTO.VerificationMethod != database.DomainVerificationMethodDnsTxtRecord { + logger.WarnContext(ctx, "Invalid verification method", "verificationMethod", domainDTO.VerificationMethod) + return dtos.DynamicRegistrationDomainCodeDTO{}, exceptions.NewValidationError("Invalid verification method") + } + if domainDTO.Verified { + logger.InfoContext(ctx, "Verification code not available for verified domain") + return dtos.DynamicRegistrationDomainCodeDTO{}, exceptions.NewConflictError("Verification code not available for verified domain") + } + + accountDTO, serviceErr := s.GetAccountByPublicIDAndVersion(ctx, GetAccountByPublicIDAndVersionOptions{ + RequestID: opts.RequestID, + PublicID: opts.AccountPublicID, + Version: opts.AccountVersion, + }) + if serviceErr != nil { + return dtos.DynamicRegistrationDomainCodeDTO{}, serviceErr + } + + verificationCode, err := utils.GenerateBase64Secret(domainCodeByteLength) + if err != nil { + logger.ErrorContext(ctx, "Failed to generate domain code", "error", err) + return dtos.DynamicRegistrationDomainCodeDTO{}, exceptions.NewInternalServerError() + } + + verificationPrefix := fmt.Sprintf("%s-verification", accountDTO.Username) + exp := time.Now().Add(s.accountDomainVerificationTTL) + code, err := s.database.FindDynamicRegistrationDomainCodeByAccountDynamicRegistrationDomainID(ctx, domainDTO.ID()) + if err != nil { + serviceErr := exceptions.FromDBError(err) + if serviceErr.Code != exceptions.CodeNotFound { + logger.ErrorContext(ctx, "Failed to find account dynamic registration domain code", "error", err) + return dtos.DynamicRegistrationDomainCodeDTO{}, serviceErr + } + + if serviceErr := s.crypto.HMACSha256Hash(ctx, crypto.HMACSha256HashOptions{ + RequestID: opts.RequestID, + PlainText: verificationCode, + GetDecryptDEKfn: s.BuildGetDecAccountDEKFn(ctx, BuildGetDecAccountDEKFnOptions{ + RequestID: opts.RequestID, + AccountID: accountDTO.ID(), + }), + GetEncryptDEKfn: s.BuildGetEncAccountDEKfn(ctx, BuildGetEncAccountDEKOptions{ + RequestID: opts.RequestID, + AccountID: accountDTO.ID(), + }), + GetHMACSecretFN: s.BuildGetHMACSecretFN(ctx, BuildGetHMACSecretFNOptions{ + RequestID: opts.RequestID, + AccountID: accountDTO.ID(), + }), + StoreReEncryptedHMACSecretFN: s.BuildUpdateHMACSecretFN(ctx, BuildUpdateHMACSecretFNOptions{ + RequestID: opts.RequestID, + AccountID: accountDTO.ID(), + }), + StoreHashedDataFN: func(secretID string, hashedData string) *exceptions.ServiceError { + var serviceErr *exceptions.ServiceError + qrs, txn, err := s.database.BeginTx(ctx) + if err != nil { + logger.ErrorContext(ctx, "Failed to start transaction", "error", err) + serviceErr = exceptions.FromDBError(err) + return serviceErr + } + defer func() { + logger.DebugContext(ctx, "Finalizing transaction") + s.database.FinalizeTx(ctx, txn, err, serviceErr) + }() + + codeID, err := qrs.CreateDynamicRegistrationDomainCode( + ctx, + database.CreateDynamicRegistrationDomainCodeParams{ + AccountID: accountDTO.ID(), + VerificationCode: hashedData, + VerificationPrefix: verificationPrefix, + VerificationHost: s.accountDomainVerificationHost, + HmacSecretID: secretID, + ExpiresAt: exp, + }, + ) + if err != nil { + logger.ErrorContext(ctx, "Failed to create account dynamic registration domain code", "error", err) + serviceErr = exceptions.FromDBError(err) + return serviceErr + } + if err := qrs.CreateAccountDynamicRegistrationDomainCode( + ctx, + database.CreateAccountDynamicRegistrationDomainCodeParams{ + AccountDynamicRegistrationDomainID: domainDTO.ID(), + DynamicRegistrationDomainCodeID: codeID, + AccountID: accountDTO.ID(), + }, + ); err != nil { + logger.ErrorContext(ctx, "Failed to create account dynamic registration domain code association", "error", err) + serviceErr = exceptions.FromDBError(err) + return serviceErr + } + return nil + }, + }); serviceErr != nil { + logger.ErrorContext(ctx, "Failed to hash code", "serviceError", serviceErr) + return dtos.DynamicRegistrationDomainCodeDTO{}, serviceErr + } + + return dtos.CreateDynamicRegistrationDomainCodeDTO( + s.accountDomainVerificationHost, + verificationPrefix, + verificationCode, + exp, + ), nil + } + + if serviceErr := s.crypto.HMACSha256Hash(ctx, crypto.HMACSha256HashOptions{ + RequestID: opts.RequestID, + PlainText: verificationCode, + GetDecryptDEKfn: s.BuildGetDecAccountDEKFn(ctx, BuildGetDecAccountDEKFnOptions{ + RequestID: opts.RequestID, + AccountID: accountDTO.ID(), + }), + GetEncryptDEKfn: s.BuildGetEncAccountDEKfn(ctx, BuildGetEncAccountDEKOptions{ + RequestID: opts.RequestID, + AccountID: accountDTO.ID(), + }), + GetHMACSecretFN: s.BuildGetHMACSecretFN(ctx, BuildGetHMACSecretFNOptions{ + RequestID: opts.RequestID, + AccountID: accountDTO.ID(), + }), + StoreReEncryptedHMACSecretFN: s.BuildUpdateHMACSecretFN(ctx, BuildUpdateHMACSecretFNOptions{ + RequestID: opts.RequestID, + AccountID: accountDTO.ID(), + }), + StoreHashedDataFN: func(secretID string, hashedData string) *exceptions.ServiceError { + if err := s.database.UpdateDynamicRegistrationDomainCode( + ctx, + database.UpdateDynamicRegistrationDomainCodeParams{ + ID: code.ID, + VerificationCode: hashedData, + VerificationPrefix: verificationPrefix, + VerificationHost: s.accountDomainVerificationHost, + HmacSecretID: secretID, + ExpiresAt: exp, + }, + ); err != nil { + logger.ErrorContext(ctx, "Failed to create account dynamic registration domain code", "error", err) + serviceErr = exceptions.FromDBError(err) + return serviceErr + } + return nil + }, + }); serviceErr != nil { + logger.ErrorContext(ctx, "Failed to hash code", "serviceError", serviceErr) + return dtos.DynamicRegistrationDomainCodeDTO{}, serviceErr + } + + logger.InfoContext(ctx, "Saved account dynamic registration domain code successfully") + return dtos.CreateDynamicRegistrationDomainCodeDTO( + s.accountDomainVerificationHost, + verificationPrefix, + verificationCode, + exp, + ), nil +} + +type DeleteAccountCredentialsRegistrationDomainCodeOptions struct { + RequestID string + AccountPublicID uuid.UUID + AccountVersion int32 + Domain string +} + +func (s *Services) DeleteAccountCredentialsRegistrationDomainCode( + ctx context.Context, + opts DeleteAccountCredentialsRegistrationDomainCodeOptions, +) *exceptions.ServiceError { + logger := s.buildLogger(opts.RequestID, accountCredentialsRegistrationDomainsLocation, "DeleteAccountCredentialsRegistrationDomainCode").With( + "accountPublicID", opts.AccountPublicID, + "domain", opts.Domain, + ) + logger.InfoContext(ctx, "Deleting account credentials registration domain...") + + domainCodeDTO, serviceErr := s.GetAccountCredentialsRegistrationDomainCode( + ctx, + GetAccountCredentialsRegistrationDomainCodeOptions{ + RequestID: opts.RequestID, + AccountPublicID: opts.AccountPublicID, + Domain: opts.Domain, + }, + ) + if serviceErr != nil { + logger.WarnContext(ctx, "Failed to get account credentials registration domain code", "serviceError", serviceErr) + return serviceErr + } + if _, serviceErr := s.GetAccountIDByPublicIDAndVersion(ctx, GetAccountIDByPublicIDAndVersionOptions{ + RequestID: opts.RequestID, + PublicID: opts.AccountPublicID, + Version: opts.AccountVersion, + }); serviceErr != nil { + logger.WarnContext(ctx, "Failed to get account ID", "serviceError", serviceErr) + return serviceErr + } + + if err := s.database.DeleteDynamicRegistrationDomainCode(ctx, domainCodeDTO.ID()); err != nil { + logger.ErrorContext(ctx, "Failed to delete account dynamic registration domain", "error", err) + return exceptions.FromDBError(err) + } + + logger.InfoContext(ctx, "Deleted account credentials registration domain successfully") + return nil +} diff --git a/idp/internal/services/account_dynamic_registration_configs.go b/idp/internal/services/account_dynamic_registration_configs.go index a5d353f..3724bd7 100644 --- a/idp/internal/services/account_dynamic_registration_configs.go +++ b/idp/internal/services/account_dynamic_registration_configs.go @@ -231,3 +231,48 @@ func (s *Services) GetAccountDynamicRegistrationConfig( return dtos.MapAccountDynamicRegistrationConfigToDTO(&accountDynamicRegistrationConfig), nil } + +type DeleteAccountDynamicRegistrationConfigOptions struct { + RequestID string + AccountPublicID uuid.UUID + AccountVersion int32 +} + +func (s *Services) DeleteAccountDynamicRegistrationConfig( + ctx context.Context, + opts DeleteAccountDynamicRegistrationConfigOptions, +) *exceptions.ServiceError { + logger := s.buildLogger(opts.RequestID, accountDynamicRegistrationConfigsLocation, "DeleteAccountDynamicRegistrationConfig").With( + "accountPublicID", opts.AccountPublicID, + "accountVersion", opts.AccountVersion, + ) + logger.InfoContext(ctx, "Deleting account dynamic registration config...") + + dynamicRegistratioDTO, serviceErr := s.GetAccountDynamicRegistrationConfig( + ctx, + GetAccountDynamicRegistrationConfigOptions{ + RequestID: opts.RequestID, + AccountPublicID: opts.AccountPublicID, + }, + ) + if serviceErr != nil { + logger.ErrorContext(ctx, "Failed to get account dynamic registration config", "serviceError", serviceErr) + return serviceErr + } + + if _, serviceErr := s.GetAccountIDByPublicIDAndVersion(ctx, GetAccountIDByPublicIDAndVersionOptions{ + RequestID: opts.RequestID, + PublicID: opts.AccountPublicID, + Version: opts.AccountVersion, + }); serviceErr != nil { + logger.ErrorContext(ctx, "Failed to get account ID by public ID and version", "serviceError", serviceErr) + return serviceErr + } + + if err := s.database.DeleteAccountDynamicRegistrationConfig(ctx, dynamicRegistratioDTO.ID()); err != nil { + logger.ErrorContext(ctx, "Failed to delete account dynamic registration config", "error", err) + return exceptions.FromDBError(err) + } + + return nil +} diff --git a/idp/internal/services/account_hmac_secrets.go b/idp/internal/services/account_hmac_secrets.go index 5b1be4a..50e9f24 100644 --- a/idp/internal/services/account_hmac_secrets.go +++ b/idp/internal/services/account_hmac_secrets.go @@ -10,7 +10,6 @@ import ( "context" "time" - "github.com/jackc/pgx/v5" "github.com/tugascript/devlogs/idp/internal/exceptions" "github.com/tugascript/devlogs/idp/internal/providers/crypto" "github.com/tugascript/devlogs/idp/internal/providers/database" @@ -33,25 +32,7 @@ func (s *Services) buildStoreAccountHMACSecretFn( logger.InfoContext(ctx, "Building store function for account HMAC secret...") return func(dekID string, secretID string, encryptedSecret string) (int32, *exceptions.ServiceError) { - var qrs *database.Queries - var txn pgx.Tx - var err error - var serviceErr *exceptions.ServiceError - if opts.queries != nil { - qrs = opts.queries - } else { - qrs, txn, err = s.database.BeginTx(ctx) - if err != nil { - logger.ErrorContext(ctx, "Failed to start transaction", "error", err) - return 0, exceptions.FromDBError(err) - } - defer func() { - logger.DebugContext(ctx, "Finalizing transaction") - s.database.FinalizeTx(ctx, txn, err, serviceErr) - }() - } - - id, err := qrs.CreateAccountHMACSecret(ctx, database.CreateAccountHMACSecretParams{ + id, err := s.mapQueries(opts.queries).CreateAccountHMACSecret(ctx, database.CreateAccountHMACSecretParams{ AccountID: opts.accountID, SecretID: secretID, Secret: encryptedSecret, @@ -60,8 +41,7 @@ func (s *Services) buildStoreAccountHMACSecretFn( }) if err != nil { logger.ErrorContext(ctx, "Failed to create account HMAC secret", "error", err) - serviceErr = exceptions.FromDBError(err) - return 0, serviceErr + return 0, exceptions.FromDBError(err) } opts.data["secretID"] = secretID @@ -155,3 +135,35 @@ func (s *Services) BuildUpdateHMACSecretFN( return nil } } + +type BuildGetHMACSecretByIDFNOptions struct { + RequestID string + AccountID int32 + Queries *database.Queries +} + +func (s *Services) BuildGetHMACSecretByIDFN( + ctx context.Context, + opts BuildGetHMACSecretByIDFNOptions, +) crypto.GetHMACSecretByIDfn { + logger := s.buildLogger(opts.RequestID, accountHMACSecretsLocation, "BuildGetHMACSecretByIDFN") + logger.InfoContext(ctx, "Building get HMAC secret by ID function...") + + return func(secretID crypto.SecretID) (crypto.DEKCiphertext, *exceptions.ServiceError) { + logger.InfoContext(ctx, "Getting HMAC secret by ID...", "secretID", secretID) + + secret, err := s.mapQueries(opts.Queries).FindAccountHMACSecretByAccountIDAndSecretID( + ctx, + database.FindAccountHMACSecretByAccountIDAndSecretIDParams{ + AccountID: opts.AccountID, + SecretID: secretID, + }, + ) + if err != nil { + logger.ErrorContext(ctx, "Failed to find account HMAC secret", "error", err) + return "", exceptions.FromDBError(err) + } + + return secret.Secret, nil + } +} diff --git a/idp/internal/services/dtos/account_credentials_registration_domain.go b/idp/internal/services/dtos/dynamic_registration_domain.go similarity index 57% rename from idp/internal/services/dtos/account_credentials_registration_domain.go rename to idp/internal/services/dtos/dynamic_registration_domain.go index 3f27857..d15f07b 100644 --- a/idp/internal/services/dtos/account_credentials_registration_domain.go +++ b/idp/internal/services/dtos/dynamic_registration_domain.go @@ -13,11 +13,13 @@ import ( "github.com/tugascript/devlogs/idp/internal/providers/database" ) -type AccountCredentialsRegistrationDomainDTO struct { +type DynamicRegistrationDomainDTO struct { id int32 - Domain string `json:"domain"` - Verified bool `json:"verified"` + Domain string `json:"domain"` + Verified bool `json:"verified"` + VerificationMethod database.DomainVerificationMethod `json:"verification_method"` + VerifiedAt int64 `json:"verified_at,omitempty"` VerificationHost string `json:"verification_host,omitempty"` VerificationPrefix string `json:"verification_prefix,omitempty"` @@ -26,7 +28,7 @@ type AccountCredentialsRegistrationDomainDTO struct { VerificationCodeExpiresAt int64 `json:"verification_code_expires_at,omitempty"` } -func (a *AccountCredentialsRegistrationDomainDTO) ID() int32 { +func (a *DynamicRegistrationDomainDTO) ID() int32 { return a.id } @@ -36,10 +38,11 @@ func MapAccountCredentialsRegistrationDomainToDTOWithCode( verificationPrefix string, verificationCode string, expiresAt time.Time, -) AccountCredentialsRegistrationDomainDTO { - return AccountCredentialsRegistrationDomainDTO{ +) DynamicRegistrationDomainDTO { + return DynamicRegistrationDomainDTO{ id: domain.ID, Domain: domain.Domain, + VerificationMethod: domain.VerificationMethod, VerificationHost: verificationHost, VerificationPrefix: verificationPrefix, VerificationCode: verificationCode, @@ -48,3 +51,20 @@ func MapAccountCredentialsRegistrationDomainToDTOWithCode( Verified: false, } } + +func MapAccountCredentialsRegistrationDomainToDTO( + domain *database.AccountDynamicRegistrationDomain, +) DynamicRegistrationDomainDTO { + verifiedAt := int64(0) + if domain.VerifiedAt.Valid { + verifiedAt = domain.VerifiedAt.Time.Unix() + } + + return DynamicRegistrationDomainDTO{ + id: domain.ID, + Domain: domain.Domain, + Verified: verifiedAt > 0, + VerifiedAt: verifiedAt, + VerificationMethod: domain.VerificationMethod, + } +} diff --git a/idp/internal/services/dtos/dynamic_registration_domain_code.go b/idp/internal/services/dtos/dynamic_registration_domain_code.go new file mode 100644 index 0000000..7fd4ce6 --- /dev/null +++ b/idp/internal/services/dtos/dynamic_registration_domain_code.go @@ -0,0 +1,55 @@ +// Copyright (c) 2025 Afonso Barracha +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +package dtos + +import ( + "fmt" + "time" + + "github.com/tugascript/devlogs/idp/internal/providers/database" +) + +type DynamicRegistrationDomainCodeDTO struct { + id int32 + + VerificationHost string `json:"verification_host"` + VerificationPrefix string `json:"verification_prefix"` + VerificationCode string `json:"verification_code,omitempty"` + VerificationValue string `json:"verification_value,omitempty"` + VerificationCodeExpiresAt int64 `json:"verification_code_expires_at"` +} + +func (a *DynamicRegistrationDomainCodeDTO) ID() int32 { + return a.id +} + +func MapDynamicRegistrationDomainCodeToDTO( + domainCode *database.DynamicRegistrationDomainCode, +) DynamicRegistrationDomainCodeDTO { + return DynamicRegistrationDomainCodeDTO{ + id: domainCode.ID, + VerificationHost: domainCode.VerificationHost, + VerificationPrefix: domainCode.VerificationPrefix, + VerificationCode: domainCode.VerificationCode, + VerificationCodeExpiresAt: domainCode.ExpiresAt.Unix(), + } +} + +func CreateDynamicRegistrationDomainCodeDTO( + verificationHost string, + verificationPrefix string, + verificationCode string, + expiresAt time.Time, +) DynamicRegistrationDomainCodeDTO { + return DynamicRegistrationDomainCodeDTO{ + VerificationHost: verificationHost, + VerificationPrefix: verificationPrefix, + VerificationCode: verificationCode, + VerificationValue: fmt.Sprintf("%s=%s", verificationPrefix, verificationCode), + VerificationCodeExpiresAt: expiresAt.Unix(), + } +} diff --git a/idp/internal/services/helpers.go b/idp/internal/services/helpers.go index 1fa8be5..6d35bbd 100644 --- a/idp/internal/services/helpers.go +++ b/idp/internal/services/helpers.go @@ -7,7 +7,10 @@ package services import ( + "context" + "fmt" "log/slog" + "net" "net/url" "strings" @@ -19,6 +22,8 @@ import ( ) const ( + helpersLocation string = "helpers" + AuthMethodPrivateKeyJwt string = "private_key_jwt" AuthMethodClientSecretBasic string = "client_secret_basic" AuthMethodClientSecretPost string = "client_secret_post" @@ -244,3 +249,39 @@ func mapEmptyString(str string) pgtype.Text { return pgtype.Text{String: strings.TrimSpace(str), Valid: true} } + +type verifyTXTRecordOptions struct { + requestID string + host string + domain string + prefix string + code string +} + +func (s *Services) verifyTXTRecord( + ctx context.Context, + opts verifyTXTRecordOptions, +) *exceptions.ServiceError { + logger := s.buildLogger(opts.requestID, helpersLocation, "verifyTXTRecord").With( + "host", opts.host, + "domain", opts.domain, + "prefix", opts.prefix, + ) + logger.InfoContext(ctx, "Verifying TXT record...") + + records, err := net.LookupTXT(fmt.Sprintf("%s.%s", opts.host, opts.domain)) + if err != nil { + logger.ErrorContext(ctx, "Failed to lookup TXT record: %s", err) + return exceptions.NewValidationError("TXT record not found") + } + + hashSet := utils.SliceToHashSet(records) + value := fmt.Sprintf("%s=%s", opts.prefix, opts.code) + if !hashSet.Contains(value) { + logger.InfoContext(ctx, "TXT code not found in records") + return exceptions.NewValidationError("TXT code not found in records") + } + + logger.InfoContext(ctx, "TXT code found in records") + return nil +} From e2495364d38ebef58afeb59d71f6a29a53d5f891 Mon Sep 17 00:00:00 2001 From: Afonso Barracha Date: Sun, 24 Aug 2025 14:04:59 +1200 Subject: [PATCH 07/23] feat(idp): start adding auth templates --- idp/initial_schema.dbml | 22 + idp/internal/config/config.go | 7 +- idp/internal/config/tokens.go | 36 +- .../bodies/oauth_dynamic_registration.go | 12 + .../templates/account_dynamic_registration.go | 537 ++++++++++++++++++ idp/internal/controllers/templates/login.html | 464 +++++++++++++++ ...ccount_credentials_dynamic_registration.go | 7 + idp/internal/providers/cache/oauth_code.go | 16 +- ...0241213231542_create_initial_schema.up.sql | 26 +- idp/internal/providers/database/models.go | 10 + .../providers/tokens/dynamic_registration.go | 60 ++ idp/internal/providers/tokens/tokens.go | 39 +- idp/internal/server/server.go | 1 + ...ccount_credentials_registration_domains.go | 6 +- .../account_credentials_registration_iat.go | 85 +++ idp/internal/services/auth.go | 6 +- idp/internal/services/dtos/auth_provider.go | 6 +- idp/internal/services/helpers.go | 2 +- idp/internal/utils/secrets.go | 9 + idp/tests/common_test.go | 1 + 20 files changed, 1301 insertions(+), 51 deletions(-) create mode 100644 idp/internal/controllers/bodies/oauth_dynamic_registration.go create mode 100644 idp/internal/controllers/templates/account_dynamic_registration.go create mode 100644 idp/internal/controllers/templates/login.html create mode 100644 idp/internal/providers/cache/account_credentials_dynamic_registration.go create mode 100644 idp/internal/providers/tokens/dynamic_registration.go create mode 100644 idp/internal/services/account_credentials_registration_iat.go diff --git a/idp/initial_schema.dbml b/idp/initial_schema.dbml index f5aebe0..4892a90 100644 --- a/idp/initial_schema.dbml +++ b/idp/initial_schema.dbml @@ -59,6 +59,7 @@ Table data_encryption_keys as DEK { Ref: DEK.kek_kid > KEK.kid [delete: cascade, update: cascade] Enum token_crypto_suite { + "RS256" "ES256" "EdDSA" } @@ -986,6 +987,27 @@ Ref: ADRDC.account_id > A.id [delete: cascade] Ref: ADRDC.account_dynamic_registration_domain_id > ADRD.id [delete: cascade] Ref: ADRDC.dynamic_registration_domain_code_id > DRDC.id [delete: cascade] +Table account_dynamic_registration_software_statement_keys as ADRSK { + id serial [pk] + + account_id integer [not null] + account_public_id uuid [not null] + credentials_key_id integer [not null] + account_dynamic_registration_domain_id integer [not null] + + created_at timestamptz [not null, default: `now()`] + + Indexes { + (account_id) [name: 'account_dynamic_registration_software_statement_keys_account_id_idx'] + (account_public_id) [name: 'account_dynamic_registration_software_statement_keys_account_public_id_idx'] + (credentials_key_id) [unique, name: 'account_dynamic_registration_software_statement_keys_credentials_key_id_uidx'] + (account_dynamic_registration_domain_id) [unique, name: 'account_dynamic_registration_software_statement_keys_account_dynamic_registration_domain_id_uidx'] + } +} +Ref: ADRSK.account_id > A.id [delete: cascade] +Ref: ADRSK.credentials_key_id > CK.id [delete: cascade] +Ref: ADRSK.account_dynamic_registration_domain_id > ADRD.id [delete: cascade] + Table app_dynamic_registration_configs as APDRC { id serial [pk] diff --git a/idp/internal/config/config.go b/idp/internal/config/config.go index 423746d..a0b2d4a 100644 --- a/idp/internal/config/config.go +++ b/idp/internal/config/config.go @@ -174,7 +174,7 @@ func (c *Config) AppsDomainVerificationTTL() int64 { return c.appsDomainVerificationTTL } -var variables = [45]string{ +var variables = [46]string{ "PORT", "ENV", "DEBUG", @@ -197,6 +197,7 @@ var variables = [45]string{ "JWT_RESET_TTL_SEC", "JWT_2FA_TTL_SEC", "JWT_APPS_TTL_SEC", + "JWT_DYNAMIC_REGISTRATION_TTL_SEC", "OPENBAO_URL", "OPENBAO_DEV_TOKEN", "OPENBAO_ROLE_ID", @@ -235,7 +236,7 @@ var optionalVariables = [10]string{ "MICROSOFT_CLIENT_SECRET", } -var numerics = [27]string{ +var numerics = [28]string{ "PORT", "MAX_PROCS", "JWT_ACCESS_TTL_SEC", @@ -245,6 +246,7 @@ var numerics = [27]string{ "JWT_RESET_TTL_SEC", "JWT_2FA_TTL_SEC", "JWT_APPS_TTL_SEC", + "JWT_DYNAMIC_REGISTRATION_TTL_SEC", "RATE_LIMITER_MAX", "RATE_LIMITER_EXP_SEC", "KEK_EXPIRATION_DAYS", @@ -322,6 +324,7 @@ func NewConfig(logger *slog.Logger, envPath string) Config { intMap["JWT_RESET_TTL_SEC"], intMap["JWT_2FA_TTL_SEC"], intMap["JWT_APPS_TTL_SEC"], + intMap["JWT_DYNAMIC_REGISTRATION_TTL_SEC"], ), oAuthProvidersConfig: NewOAuthProviders( NewOAuthProvider(variablesMap["GITHUB_CLIENT_ID"], variablesMap["GITHUB_CLIENT_SECRET"]), diff --git a/idp/internal/config/tokens.go b/idp/internal/config/tokens.go index 96478bf..260ce6a 100644 --- a/idp/internal/config/tokens.go +++ b/idp/internal/config/tokens.go @@ -7,24 +7,26 @@ package config type TokensConfig struct { - accessTTL int64 - accountCredentialsTTL int64 - refreshTTL int64 - confirmTTL int64 - resetTTL int64 - twoFATTL int64 - appsTTL int64 + accessTTL int64 + accountCredentialsTTL int64 + refreshTTL int64 + confirmTTL int64 + resetTTL int64 + twoFATTL int64 + appsTTL int64 + dynamicRegistrationTTL int64 } -func NewTokensConfig(access, accountCredentials, refresh, confirm, reset, twoFA, apps int64) TokensConfig { +func NewTokensConfig(access, accountCredentials, refresh, confirm, reset, twoFA, apps, dynamicRegistration int64) TokensConfig { return TokensConfig{ - accessTTL: access, - accountCredentialsTTL: accountCredentials, - refreshTTL: refresh, - confirmTTL: confirm, - resetTTL: reset, - twoFATTL: twoFA, - appsTTL: apps, + accessTTL: access, + accountCredentialsTTL: accountCredentials, + refreshTTL: refresh, + confirmTTL: confirm, + resetTTL: reset, + twoFATTL: twoFA, + appsTTL: apps, + dynamicRegistrationTTL: dynamicRegistration, } } @@ -57,3 +59,7 @@ func (t TokensConfig) TwoFATTL() int64 { func (t TokensConfig) AppsTTL() int64 { return t.appsTTL } + +func (t TokensConfig) DynamicRegistrationTTL() int64 { + return t.dynamicRegistrationTTL +} diff --git a/idp/internal/controllers/bodies/oauth_dynamic_registration.go b/idp/internal/controllers/bodies/oauth_dynamic_registration.go new file mode 100644 index 0000000..72fc2fb --- /dev/null +++ b/idp/internal/controllers/bodies/oauth_dynamic_registration.go @@ -0,0 +1,12 @@ +// Copyright (c) 2025 Afonso Barracha +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +package bodies + +type OAuthDynamicClientRegistrationBody struct { + RedirectURIs []string `json:"redirect_uris" validate:"required,min=1,dive,uri"` + ResponseTypes []string `json:"response_types" validate:"required,min=1,dive,oneof=code id_token 'code id_token'"` +} diff --git a/idp/internal/controllers/templates/account_dynamic_registration.go b/idp/internal/controllers/templates/account_dynamic_registration.go new file mode 100644 index 0000000..e5b261d --- /dev/null +++ b/idp/internal/controllers/templates/account_dynamic_registration.go @@ -0,0 +1,537 @@ +// Copyright (c) 2025 Afonso Barracha +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +package templates + +import ( + "bytes" + "errors" + "fmt" + "html/template" + + "github.com/tugascript/devlogs/idp/internal/controllers/paths" + "github.com/tugascript/devlogs/idp/internal/providers/database" + "github.com/tugascript/devlogs/idp/internal/services/dtos" +) + +const accountDynamicRegistrationBaseTemplate = ` + + + + + + + + OAuth2.0 Dynamic Registration Authorization + + + +
+
+
+ +
+

{{.Header}}

+
+ %s +
+ + + + +` + +func buildEntryAccountDynamicRegistrationTemplate(body string) string { + return fmt.Sprintf(accountDynamicRegistrationBaseTemplate, body) +} + +const baseAccountDynamicRegistrationLoginTitle = "Account Credentials Dynamic Registration" + +const loginForm = ` +
+ + + +
+` + +const divider = ` +
+ OR +
+` + +const appleLoginButton = ` + +` + +const facebookLoginButton = ` + +` + +const githubLoginButton = ` + +` + +const googleLoginButton = ` + +` + +const microsoftLoginButton = ` + +` + +const accountDynamicRegistrationLoginTemplateName = "login" + +type accountDynamicRegistrationLoginTemplateData struct { + Title string + Header string + LoginURL string + AppleLoginURL string + FacebookLoginURL string + GithubLoginURL string + GoogleLoginURL string + MicrosoftLoginURL string +} + +func BuildAccountDynamicRegistrationLoginTemplate( + account *dtos.AccountDTO, + authProviders []dtos.AuthProviderDTO, +) (string, error) { + if len(authProviders) == 0 { + return "", errors.New("no auth providers found") + } + + baseURL := paths.AccountsBase + "/" + account.PublicID.String() + paths.CredentialsBase + paths.OAuthBase + data := accountDynamicRegistrationLoginTemplateData{ + Title: fmt.Sprintf("%s %s", baseAccountDynamicRegistrationLoginTitle, account.GivenName), + Header: fmt.Sprintf("Confirm Account Credentials Client Registration %s", account.GivenName), + } + baseTemplateBody := "" + for _, provider := range authProviders { + switch provider.Provider { + case database.AuthProviderLocal: + data.LoginURL = baseURL + paths.AuthLogin + if len(baseTemplateBody) == 0 { + baseTemplateBody += loginForm + continue + } + + baseTemplateBody = loginForm + divider + baseTemplateBody + case database.AuthProviderApple: + data.AppleLoginURL = baseURL + paths.OAuthAuth + "?client_id=apple&response_type=code" + baseTemplateBody += appleLoginButton + case database.AuthProviderFacebook: + baseTemplateBody += facebookLoginButton + "?client_id=facebook&response_type=code" + case database.AuthProviderGithub: + baseTemplateBody += githubLoginButton + "?client_id=github&response_type=code" + case database.AuthProviderGoogle: + baseTemplateBody += googleLoginButton + "?client_id=google&response_type=code" + case database.AuthProviderMicrosoft: + baseTemplateBody += microsoftLoginButton + "?client_id=microsoft&response_type=code" + default: + return "", fmt.Errorf("unsupported auth provider: %s", provider.Provider) + } + } + + loginTemplate := buildEntryAccountDynamicRegistrationTemplate(baseTemplateBody) + t, err := template.New(accountDynamicRegistrationLoginTemplateName).Parse(loginTemplate) + if err != nil { + return "", nil + } + var loginTemplateContent bytes.Buffer + if err := t.Execute(&loginTemplateContent, data); err != nil { + return "", err + } + + return loginTemplateContent.String(), nil +} diff --git a/idp/internal/controllers/templates/login.html b/idp/internal/controllers/templates/login.html new file mode 100644 index 0000000..0333dc6 --- /dev/null +++ b/idp/internal/controllers/templates/login.html @@ -0,0 +1,464 @@ + + + + + + + + Login - DevLogs + + + +
+
+
+ +
+

Welcome back {{.Name}}

+
+ +
+ + + + +
+ +
+ OR +
+ +
+ + + + + + + + + +
+
+ + + + + \ No newline at end of file diff --git a/idp/internal/providers/cache/account_credentials_dynamic_registration.go b/idp/internal/providers/cache/account_credentials_dynamic_registration.go new file mode 100644 index 0000000..057a849 --- /dev/null +++ b/idp/internal/providers/cache/account_credentials_dynamic_registration.go @@ -0,0 +1,7 @@ +// Copyright (c) 2025 Afonso Barracha +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +package cache diff --git a/idp/internal/providers/cache/oauth_code.go b/idp/internal/providers/cache/oauth_code.go index b47cc7b..f4279ee 100644 --- a/idp/internal/providers/cache/oauth_code.go +++ b/idp/internal/providers/cache/oauth_code.go @@ -20,8 +20,14 @@ import ( const ( oauthCodePrefix string = "oauth_code" oauthCodeLocation string = "oauth_code" + + codeByteLen int = 16 ) +func buildOAuthCodeKey(codeID string) string { + return fmt.Sprintf("%s:%s", oauthCodePrefix, utils.Sha256HashHex([]byte(codeID))) +} + type OAuthCodeData struct { Email string `json:"email"` GivenName string `json:"given_name"` @@ -49,8 +55,12 @@ func (c *Cache) GenerateOAuthCode(ctx context.Context, opts GenerateOAuthCodeOpt logger.DebugContext(ctx, "Generating OAuth code...") codeID := utils.Base62UUID() - code := utils.Base62UUID() - key := fmt.Sprintf("%s:%s", oauthCodePrefix, codeID) + code, err := utils.GenerateBase62Secret(codeByteLen) + if err != nil { + logger.ErrorContext(ctx, "Error generating OAuth code", "error", err) + return "", err + } + key := buildOAuthCodeKey(codeID) data := OAuthCodeData{ Email: opts.Email, @@ -98,7 +108,7 @@ func (c *Cache) VerifyOAuthCode(ctx context.Context, opts VerifyOAuthCodeOptions return OAuthCodeData{}, false, nil } - key := fmt.Sprintf("%s:%s", oauthCodePrefix, parts[0]) + key := buildOAuthCodeKey(parts[0]) valByte, err := c.storage.Get(key) if err != nil { logger.ErrorContext(ctx, "Error getting OAuth code", "error", err) diff --git a/idp/internal/providers/database/migrations/20241213231542_create_initial_schema.up.sql b/idp/internal/providers/database/migrations/20241213231542_create_initial_schema.up.sql index 7eadfe2..e9f1420 100644 --- a/idp/internal/providers/database/migrations/20241213231542_create_initial_schema.up.sql +++ b/idp/internal/providers/database/migrations/20241213231542_create_initial_schema.up.sql @@ -1,6 +1,6 @@ -- SQL dump generated using DBML (dbml.dbdiagram.io) -- Database: PostgreSQL --- Generated at: 2025-08-17T08:47:57.041Z +-- Generated at: 2025-08-18T07:42:22.764Z CREATE TYPE "kek_usage" AS ENUM ( 'global', @@ -14,6 +14,7 @@ CREATE TYPE "dek_usage" AS ENUM ( ); CREATE TYPE "token_crypto_suite" AS ENUM ( + 'RS256', 'ES256', 'EdDSA' ); @@ -598,6 +599,15 @@ CREATE TABLE "account_dynamic_registration_domain_codes" ( PRIMARY KEY ("account_dynamic_registration_domain_id", "dynamic_registration_domain_code_id") ); +CREATE TABLE "account_dynamic_registration_software_statement_keys" ( + "id" serial PRIMARY KEY, + "account_id" integer NOT NULL, + "account_public_id" uuid NOT NULL, + "credentials_key_id" integer NOT NULL, + "account_dynamic_registration_domain_id" integer NOT NULL, + "created_at" timestamptz NOT NULL DEFAULT (now()) +); + CREATE TABLE "app_dynamic_registration_configs" ( "id" serial PRIMARY KEY, "account_id" integer NOT NULL, @@ -912,6 +922,14 @@ CREATE UNIQUE INDEX "account_dynamic_registration_domain_codes_account_dynamic_r CREATE UNIQUE INDEX "account_dynamic_registration_domain_codes_dynamic_registration_domain_code_id_uidx" ON "account_dynamic_registration_domain_codes" ("dynamic_registration_domain_code_id"); +CREATE INDEX "account_dynamic_registration_software_statement_keys_account_id_idx" ON "account_dynamic_registration_software_statement_keys" ("account_id"); + +CREATE INDEX "account_dynamic_registration_software_statement_keys_account_public_id_idx" ON "account_dynamic_registration_software_statement_keys" ("account_public_id"); + +CREATE UNIQUE INDEX "account_dynamic_registration_software_statement_keys_credentials_key_id_uidx" ON "account_dynamic_registration_software_statement_keys" ("credentials_key_id"); + +CREATE UNIQUE INDEX "account_dynamic_registration_software_statement_keys_account_dynamic_registration_domain_id_uidx" ON "account_dynamic_registration_software_statement_keys" ("account_dynamic_registration_domain_id"); + CREATE INDEX "app_dynamic_registration_configs_account_id_idx" ON "app_dynamic_registration_configs" ("account_id"); CREATE INDEX "user_profiles_app_id_idx" ON "app_profiles" ("app_id"); @@ -1060,6 +1078,12 @@ ALTER TABLE "account_dynamic_registration_domain_codes" ADD FOREIGN KEY ("accoun ALTER TABLE "account_dynamic_registration_domain_codes" ADD FOREIGN KEY ("dynamic_registration_domain_code_id") REFERENCES "dynamic_registration_domain_codes" ("id") ON DELETE CASCADE; +ALTER TABLE "account_dynamic_registration_software_statement_keys" ADD FOREIGN KEY ("account_id") REFERENCES "accounts" ("id") ON DELETE CASCADE; + +ALTER TABLE "account_dynamic_registration_software_statement_keys" ADD FOREIGN KEY ("credentials_key_id") REFERENCES "credentials_keys" ("id") ON DELETE CASCADE; + +ALTER TABLE "account_dynamic_registration_software_statement_keys" ADD FOREIGN KEY ("account_dynamic_registration_domain_id") REFERENCES "account_dynamic_registration_domains" ("id") ON DELETE CASCADE; + ALTER TABLE "app_dynamic_registration_configs" ADD FOREIGN KEY ("account_id") REFERENCES "accounts" ("id") ON DELETE CASCADE; ALTER TABLE "app_profiles" ADD FOREIGN KEY ("app_id") REFERENCES "apps" ("id") ON DELETE CASCADE; diff --git a/idp/internal/providers/database/models.go b/idp/internal/providers/database/models.go index 0ca084a..51ff16c 100644 --- a/idp/internal/providers/database/models.go +++ b/idp/internal/providers/database/models.go @@ -869,6 +869,7 @@ func (ns NullSoftwareStatementVerificationMethod) Value() (driver.Value, error) type TokenCryptoSuite string const ( + TokenCryptoSuiteRS256 TokenCryptoSuite = "RS256" TokenCryptoSuiteES256 TokenCryptoSuite = "ES256" TokenCryptoSuiteEdDSA TokenCryptoSuite = "EdDSA" ) @@ -1277,6 +1278,15 @@ type AccountDynamicRegistrationDomainCode struct { CreatedAt time.Time } +type AccountDynamicRegistrationSoftwareStatementKey struct { + ID int32 + AccountID int32 + AccountPublicID uuid.UUID + CredentialsKeyID int32 + AccountDynamicRegistrationDomainID int32 + CreatedAt time.Time +} + type AccountHmacSecret struct { ID int32 AccountID int32 diff --git a/idp/internal/providers/tokens/dynamic_registration.go b/idp/internal/providers/tokens/dynamic_registration.go new file mode 100644 index 0000000..f884816 --- /dev/null +++ b/idp/internal/providers/tokens/dynamic_registration.go @@ -0,0 +1,60 @@ +// Copyright (c) 2025 Afonso Barracha +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +package tokens + +import ( + "fmt" + "time" + + "github.com/golang-jwt/jwt/v5" + "github.com/google/uuid" +) + +type accountCredentialsDynamicRegistrationClaims struct { + AccountClaims + Domain string `json:"domain"` + jwt.RegisteredClaims +} + +type AccountCredentialsDynamicRegistrationTokenOptions struct { + AccountPublicID uuid.UUID + AccountVersion int32 + Domain string +} + +func (t *Tokens) CreateAccountCredentialsDynamicRegistrationToken( + opts AccountCredentialsDynamicRegistrationTokenOptions, +) *jwt.Token { + now := time.Now() + iat := jwt.NewNumericDate(now) + exp := jwt.NewNumericDate(now.Add(time.Second * time.Duration(t.dynamicRegistrationTTL))) + return jwt.NewWithClaims( + jwt.SigningMethodEdDSA, + accountCredentialsDynamicRegistrationClaims{ + AccountClaims: AccountClaims{ + AccountID: opts.AccountPublicID, + AccountVersion: opts.AccountVersion, + }, + Domain: opts.Domain, + RegisteredClaims: jwt.RegisteredClaims{ + Issuer: fmt.Sprintf("https://%s", t.backendDomain), + Audience: []string{ + fmt.Sprintf("https://%s", opts.Domain), + }, + Subject: opts.Domain, + IssuedAt: iat, + NotBefore: iat, + ExpiresAt: exp, + ID: uuid.NewString(), + }, + }, + ) +} + +func (t *Tokens) GetDynamicRegistrationTTL() int64 { + return t.dynamicRegistrationTTL +} diff --git a/idp/internal/providers/tokens/tokens.go b/idp/internal/providers/tokens/tokens.go index a12d179..cb24f3f 100644 --- a/idp/internal/providers/tokens/tokens.go +++ b/idp/internal/providers/tokens/tokens.go @@ -44,15 +44,16 @@ const ( ) type Tokens struct { - logger *slog.Logger - backendDomain string - accessTTL int64 - accountCredentialsTTL int64 - appsTTL int64 - refreshTTL int64 - confirmationTTL int64 - resetTTL int64 - twoFATTL int64 + logger *slog.Logger + backendDomain string + accessTTL int64 + accountCredentialsTTL int64 + appsTTL int64 + refreshTTL int64 + confirmationTTL int64 + resetTTL int64 + twoFATTL int64 + dynamicRegistrationTTL int64 } func NewTokens( @@ -65,16 +66,18 @@ func NewTokens( confirmationTTL int64, resetTTL int64, twoFATTL int64, + dynamicRegistrationTTL int64, ) *Tokens { return &Tokens{ - logger: logger.With(utils.BaseLayer, logLayer), - accessTTL: accessTTL, - accountCredentialsTTL: accountCredentialsTTL, - appsTTL: appsTTL, - refreshTTL: refreshTTL, - confirmationTTL: confirmationTTL, - resetTTL: resetTTL, - twoFATTL: twoFATTL, - backendDomain: backendDomain, + logger: logger.With(utils.BaseLayer, logLayer), + accessTTL: accessTTL, + accountCredentialsTTL: accountCredentialsTTL, + appsTTL: appsTTL, + refreshTTL: refreshTTL, + confirmationTTL: confirmationTTL, + resetTTL: resetTTL, + twoFATTL: twoFATTL, + backendDomain: backendDomain, + dynamicRegistrationTTL: dynamicRegistrationTTL, } } diff --git a/idp/internal/server/server.go b/idp/internal/server/server.go index 50022a0..b061a7f 100644 --- a/idp/internal/server/server.go +++ b/idp/internal/server/server.go @@ -191,6 +191,7 @@ func New( tokensCfg.ConfirmTTL(), tokensCfg.ResetTTL(), tokensCfg.TwoFATTL(), + tokensCfg.DynamicRegistrationTTL(), ) logger.InfoContext(ctx, "Finished building JWT tokens keys") diff --git a/idp/internal/services/account_credentials_registration_domains.go b/idp/internal/services/account_credentials_registration_domains.go index ea32e06..1ea7811 100644 --- a/idp/internal/services/account_credentials_registration_domains.go +++ b/idp/internal/services/account_credentials_registration_domains.go @@ -545,11 +545,7 @@ func (s *Services) GetAccountCredentialsRegistrationDomainCode( ) logger.InfoContext(ctx, "Getting account credentials registration domain code...") - domainDTO, serviceErr := s.GetAccountCredentialsRegistrationDomain(ctx, GetAccountCredentialsRegistrationDomainOptions{ - RequestID: opts.RequestID, - AccountPublicID: opts.AccountPublicID, - Domain: opts.Domain, - }) + domainDTO, serviceErr := s.GetAccountCredentialsRegistrationDomain(ctx, GetAccountCredentialsRegistrationDomainOptions(opts)) if serviceErr != nil { return dtos.DynamicRegistrationDomainCodeDTO{}, serviceErr } diff --git a/idp/internal/services/account_credentials_registration_iat.go b/idp/internal/services/account_credentials_registration_iat.go new file mode 100644 index 0000000..6f2b522 --- /dev/null +++ b/idp/internal/services/account_credentials_registration_iat.go @@ -0,0 +1,85 @@ +// Copyright (c) 2025 Afonso Barracha +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +package services + +import ( + "context" + + "github.com/google/uuid" + + "github.com/tugascript/devlogs/idp/internal/exceptions" + "github.com/tugascript/devlogs/idp/internal/providers/crypto" + "github.com/tugascript/devlogs/idp/internal/providers/database" + "github.com/tugascript/devlogs/idp/internal/providers/tokens" +) + +const accountCredentialsRegistrationIATLocation = "account_credentials_registration_iat" + +type CreateAccountCredentialsRegistrationIATOptions struct { + RequestID string + AccountPublicID uuid.UUID + AccountVersion int32 + Domain string +} + +func (s *Services) CreateAccountCredentialsRegistrationIAT( + ctx context.Context, + opts CreateAccountCredentialsRegistrationIATOptions, +) (string, *exceptions.ServiceError) { + logger := s.buildLogger(opts.RequestID, accountCredentialsRegistrationIATLocation, "CreateAccountCredentialsRegistrationIAT").With( + "accountPublicId", opts.AccountPublicID, + "domain", opts.Domain, + ) + logger.InfoContext(ctx, "Creating account credentials registration IAT...") + + domainDTO, serviceErr := s.GetAccountCredentialsRegistrationDomain(ctx, GetAccountCredentialsRegistrationDomainOptions{ + RequestID: opts.RequestID, + AccountPublicID: opts.AccountPublicID, + Domain: opts.Domain, + }) + if serviceErr != nil { + logger.ErrorContext(ctx, "Failed to get account credentials registration domain", "serviceError", serviceErr) + return "", serviceErr + } + if !domainDTO.Verified { + logger.ErrorContext(ctx, "Account credentials registration domain is not verified") + return "", exceptions.NewValidationError("account credentials registration domain is not verified") + } + + if _, serviceErr := s.GetAccountIDByPublicIDAndVersion(ctx, GetAccountIDByPublicIDAndVersionOptions{ + RequestID: opts.RequestID, + PublicID: opts.AccountPublicID, + Version: opts.AccountVersion, + }); serviceErr != nil { + logger.ErrorContext(ctx, "Failed to get account", "serviceError", serviceErr) + return "", serviceErr + } + + signedToken, serviceErr := s.crypto.SignToken(ctx, crypto.SignTokenOptions{ + RequestID: opts.RequestID, + Token: s.jwt.CreateAccountCredentialsDynamicRegistrationToken(tokens.AccountCredentialsDynamicRegistrationTokenOptions{ + AccountPublicID: opts.AccountPublicID, + AccountVersion: opts.AccountVersion, + Domain: opts.Domain, + }), + GetJWKfn: s.BuildGetGlobalEncryptedJWKFn(ctx, BuildEncryptedJWKFnOptions{ + RequestID: opts.RequestID, + KeyType: database.TokenKeyTypeDynamicRegistration, + TTL: s.jwt.GetDynamicRegistrationTTL(), + }), + GetDecryptDEKfn: s.BuildGetGlobalDecDEKFn(ctx, opts.RequestID), + GetEncryptDEKfn: s.BuildGetEncGlobalDEKFn(ctx, opts.RequestID), + StoreFN: s.BuildUpdateJWKDEKFn(ctx, opts.RequestID), + }) + if serviceErr != nil { + logger.ErrorContext(ctx, "Failed to sign account credentials registration IAT", "serviceError", serviceErr) + return "", serviceErr + } + + logger.InfoContext(ctx, "Created account credentials registration IAT successfully") + return signedToken, nil +} diff --git a/idp/internal/services/auth.go b/idp/internal/services/auth.go index fd113d3..63c99a5 100644 --- a/idp/internal/services/auth.go +++ b/idp/internal/services/auth.go @@ -127,7 +127,7 @@ func (s *Services) sendConfirmationEmail( accountDTO *dtos.AccountDTO, ) *exceptions.ServiceError { logger.InfoContext(ctx, "Sending confirmation email...") - signedToken, err := s.crypto.SignToken(ctx, crypto.SignTokenOptions{ + signedToken, serviceErr := s.crypto.SignToken(ctx, crypto.SignTokenOptions{ RequestID: requestID, Token: s.jwt.CreateConfirmationToken(tokens.AccountConfirmationTokenOptions{ PublicID: accountDTO.PublicID, @@ -142,8 +142,8 @@ func (s *Services) sendConfirmationEmail( GetEncryptDEKfn: s.BuildGetEncGlobalDEKFn(ctx, requestID), StoreFN: s.BuildUpdateJWKDEKFn(ctx, requestID), }) - if err != nil { - logger.ErrorContext(ctx, "Failed to sign confirmation token", "error", err) + if serviceErr != nil { + logger.ErrorContext(ctx, "Failed to sign confirmation token", "serviceError", serviceErr) return exceptions.NewInternalServerError() } diff --git a/idp/internal/services/dtos/auth_provider.go b/idp/internal/services/dtos/auth_provider.go index fd768d7..64b34b3 100644 --- a/idp/internal/services/dtos/auth_provider.go +++ b/idp/internal/services/dtos/auth_provider.go @@ -13,8 +13,8 @@ import ( ) type AuthProviderDTO struct { - Provider string `json:"provider"` - RegisteredAt string `json:"registered_at"` + Provider database.AuthProvider `json:"provider"` + RegisteredAt string `json:"registered_at"` id int32 } @@ -26,7 +26,7 @@ func (a *AuthProviderDTO) ID() int32 { func MapAccountAuthProviderToDTO(provider *database.AccountAuthProvider) AuthProviderDTO { return AuthProviderDTO{ id: provider.ID, - Provider: string(provider.Provider), + Provider: provider.Provider, RegisteredAt: provider.CreatedAt.Format(time.RFC3339), } } diff --git a/idp/internal/services/helpers.go b/idp/internal/services/helpers.go index 6d35bbd..d67bb09 100644 --- a/idp/internal/services/helpers.go +++ b/idp/internal/services/helpers.go @@ -271,7 +271,7 @@ func (s *Services) verifyTXTRecord( records, err := net.LookupTXT(fmt.Sprintf("%s.%s", opts.host, opts.domain)) if err != nil { - logger.ErrorContext(ctx, "Failed to lookup TXT record: %s", err) + logger.ErrorContext(ctx, "Failed to lookup TXT record", "error", err) return exceptions.NewValidationError("TXT record not found") } diff --git a/idp/internal/utils/secrets.go b/idp/internal/utils/secrets.go index 89a5d60..558b959 100644 --- a/idp/internal/utils/secrets.go +++ b/idp/internal/utils/secrets.go @@ -32,6 +32,15 @@ func GenerateBase64Secret(byteLen int) (string, error) { return base64.RawURLEncoding.EncodeToString(randomBytes), nil } +func GenerateBase62Secret(byteLen int) (string, error) { + randomBytes, err := GenerateRandomBytes(byteLen) + if err != nil { + return "", err + } + + return Base62Encode(randomBytes), nil +} + func DecodeBase64Secret(secret string) ([]byte, error) { decoded, err := base64.RawURLEncoding.DecodeString(secret) if err != nil { diff --git a/idp/tests/common_test.go b/idp/tests/common_test.go index 8edc2d8..1fc2e99 100644 --- a/idp/tests/common_test.go +++ b/idp/tests/common_test.go @@ -142,6 +142,7 @@ func initTestServicesAndApp(t *testing.T) { tokensCfg.ConfirmTTL(), tokensCfg.ResetTTL(), tokensCfg.TwoFATTL(), + tokensCfg.DynamicRegistrationTTL(), ) logger.InfoContext(ctx, "Finished building JWT tokens keys") From 4ca38e933a4ea33bb54252213f702ad4cd05eff1 Mon Sep 17 00:00:00 2001 From: Afonso Barracha Date: Sun, 24 Aug 2025 23:50:19 +1200 Subject: [PATCH 08/23] feat(idp): add init registration IAT for code exchange --- .../bodies/oauth_dynamic_registration.go | 18 +++- ...ccount_credentials_dynamic_registration.go | 102 ++++++++++++++++++ .../providers/tokens/dynamic_registration.go | 8 +- .../account_credentials_registration_iat.go | 71 ++++++++++++ .../services/oauth_dynamic_registration.go | 26 +++++ .../templates/account_dynamic_registration.go | 66 ++++++------ .../templates/login.html | 55 +++------- 7 files changed, 267 insertions(+), 79 deletions(-) create mode 100644 idp/internal/services/oauth_dynamic_registration.go rename idp/internal/{controllers => services}/templates/account_dynamic_registration.go (91%) rename idp/internal/{controllers => services}/templates/login.html (90%) diff --git a/idp/internal/controllers/bodies/oauth_dynamic_registration.go b/idp/internal/controllers/bodies/oauth_dynamic_registration.go index 72fc2fb..23a8fd7 100644 --- a/idp/internal/controllers/bodies/oauth_dynamic_registration.go +++ b/idp/internal/controllers/bodies/oauth_dynamic_registration.go @@ -7,6 +7,20 @@ package bodies type OAuthDynamicClientRegistrationBody struct { - RedirectURIs []string `json:"redirect_uris" validate:"required,min=1,dive,uri"` - ResponseTypes []string `json:"response_types" validate:"required,min=1,dive,oneof=code id_token 'code id_token'"` + RedirectURIs []string `json:"redirect_uris" validate:"required,min=1,dive,uri"` + TokenEndpointAuthMethod string `json:"token_endpoint_auth_method,omitempty" validate:"omitempty,oneof=none client_secret_basic client_secret_post client_secret_jwt private_key_jwt"` + ResponseTypes []string `json:"response_types,omitempty" validate:"omitempty,dive,oneof=none code"` + GrantTypes []string `json:"grant_types,omitempty" validate:"omitempty,min=1,dive,oneof=authorization_code refresh_token client_credentials urn:ietf:params:oauth:grant-type:jwt-bearer"` + ApplicationType string `json:"application_type" validate:"required,oneof=native service mcp"` + ClientName string `json:"client_name" validate:"required,min=1,max=255"` + ClientURI string `json:"client_uri" validate:"required,url"` + LogoURI string `json:"logo_uri,omitempty" validate:"omitempty,url"` + Scope string `json:"scope" validate:"required,multiple_scope"` + Contacts []string `json:"contacts,omitempty" validate:"omitempty,unique,dive,email"` + TOSURI string `json:"tos_uri,omitempty" validate:"omitempty,url"` + PolicyURI string `json:"policy_uri,omitempty" validate:"omitempty,url"` + JWKsURI string `json:"jwks_uri,omitempty" validate:"omitempty,url"` + JWKs []string `json:"jwks,omitempty" validate:"omitempty,json"` + SoftwareID string `json:"software_id,omitempty" validate:"omitempty,max=250"` + SoftwareVersion string `json:"software_version,omitempty" validate:"omitempty,max=250"` } diff --git a/idp/internal/providers/cache/account_credentials_dynamic_registration.go b/idp/internal/providers/cache/account_credentials_dynamic_registration.go index 057a849..d1e61a2 100644 --- a/idp/internal/providers/cache/account_credentials_dynamic_registration.go +++ b/idp/internal/providers/cache/account_credentials_dynamic_registration.go @@ -5,3 +5,105 @@ // file, You can obtain one at https://mozilla.org/MPL/2.0/. package cache + +import ( + "context" + "fmt" + "strings" + + "github.com/google/uuid" + + "github.com/tugascript/devlogs/idp/internal/utils" +) + +const ( + accountCredentialsDynamicRegistrationLocation string = "account_credentials_dynamic_registration" + + accountCredentialsDynamicRegistrationIATPrefix string = "account_credentials_dynamic_registration_iat" +) + +func buildAccountCredentialsDynamicRegistrationIATCacheKey(clientID string) string { + return fmt.Sprintf("%s:%s", accountCredentialsDynamicRegistrationIATPrefix, clientID) +} + +func buildAccountCredentialsDynamicRegistrationIATData(accountPublicID uuid.UUID, domain string) []byte { + return []byte(fmt.Sprintf("%s|%s", accountPublicID.String(), domain)) +} + +type SaveAccountCredentialsDynamicRegistrationIATOptions struct { + RequestID string + AccountPublicID uuid.UUID + Domain string +} + +func (c *Cache) SaveAccountCredentialsDynamicRegistrationIAT( + ctx context.Context, + opts SaveAccountCredentialsDynamicRegistrationIATOptions, +) (string, error) { + logger := utils.BuildLogger(c.logger, utils.LoggerOptions{ + Location: accountCredentialsDynamicRegistrationLocation, + Method: "SaveAccountCredentialsDynamicRegistrationIAT", + RequestID: opts.RequestID, + }).With( + "accountPublicId", opts.AccountPublicID, + "domain", opts.Domain, + ) + logger.DebugContext(ctx, "Saving account credentials dynamic registration IAT sessions...") + clientID := utils.Base62UUID() + return clientID, c.storage.Set( + buildAccountCredentialsDynamicRegistrationIATCacheKey(clientID), + buildAccountCredentialsDynamicRegistrationIATData(opts.AccountPublicID, opts.Domain), + c.oauthCodeTTL, + ) +} + +func parseAccountCredentialsDynamicRegistrationIATData(data []byte) (uuid.UUID, string, error) { + parsedData := strings.Split(string(data), "|") + if len(parsedData) != 2 { + return uuid.Nil, "", fmt.Errorf("invalid account credentials dynamic registration IAT data") + } + + accountPublicID, err := uuid.Parse(parsedData[0]) + if err != nil { + return uuid.Nil, "", fmt.Errorf("invalid account public ID in account credentials dynamic registration IAT data: %w", err) + } + + return accountPublicID, parsedData[1], nil +} + +type GetAccountCredentialsDynamicRegistrationIATOptions struct { + RequestID string + ClientID string +} + +func (c *Cache) GetAccountCredentialsDynamicRegistrationIAT( + ctx context.Context, + opts GetAccountCredentialsDynamicRegistrationIATOptions, +) (bool, uuid.UUID, string, error) { + logger := utils.BuildLogger(c.logger, utils.LoggerOptions{ + Location: accountCredentialsDynamicRegistrationLocation, + Method: "GetAccountCredentialsDynamicRegistrationIAT", + RequestID: opts.RequestID, + }).With( + "clientId", opts.ClientID, + ) + logger.DebugContext(ctx, "Getting account credentials dynamic registration IAT...") + + data, err := c.storage.Get(buildAccountCredentialsDynamicRegistrationIATCacheKey(opts.ClientID)) + if err != nil { + logger.ErrorContext(ctx, "Failed to get account credentials dynamic registration IAT", "error", err) + return false, uuid.Nil, "", err + } + if data == nil { + logger.DebugContext(ctx, "Account credentials dynamic registration IAT not found") + return false, uuid.Nil, "", nil + } + + accountPublicID, domain, err := parseAccountCredentialsDynamicRegistrationIATData(data) + if err != nil { + logger.ErrorContext(ctx, "Failed to parse account credentials dynamic registration IAT data", "error", err) + return false, uuid.Nil, "", err + } + + return true, accountPublicID, domain, nil +} diff --git a/idp/internal/providers/tokens/dynamic_registration.go b/idp/internal/providers/tokens/dynamic_registration.go index f884816..dd05f44 100644 --- a/idp/internal/providers/tokens/dynamic_registration.go +++ b/idp/internal/providers/tokens/dynamic_registration.go @@ -12,11 +12,14 @@ import ( "github.com/golang-jwt/jwt/v5" "github.com/google/uuid" + + "github.com/tugascript/devlogs/idp/internal/utils" ) type accountCredentialsDynamicRegistrationClaims struct { AccountClaims - Domain string `json:"domain"` + Domain string `json:"domain"` + ClientID string `json:"client_id"` jwt.RegisteredClaims } @@ -39,7 +42,8 @@ func (t *Tokens) CreateAccountCredentialsDynamicRegistrationToken( AccountID: opts.AccountPublicID, AccountVersion: opts.AccountVersion, }, - Domain: opts.Domain, + Domain: opts.Domain, + ClientID: utils.Base62UUID(), RegisteredClaims: jwt.RegisteredClaims{ Issuer: fmt.Sprintf("https://%s", t.backendDomain), Audience: []string{ diff --git a/idp/internal/services/account_credentials_registration_iat.go b/idp/internal/services/account_credentials_registration_iat.go index 6f2b522..87b1097 100644 --- a/idp/internal/services/account_credentials_registration_iat.go +++ b/idp/internal/services/account_credentials_registration_iat.go @@ -12,9 +12,11 @@ import ( "github.com/google/uuid" "github.com/tugascript/devlogs/idp/internal/exceptions" + "github.com/tugascript/devlogs/idp/internal/providers/cache" "github.com/tugascript/devlogs/idp/internal/providers/crypto" "github.com/tugascript/devlogs/idp/internal/providers/database" "github.com/tugascript/devlogs/idp/internal/providers/tokens" + "github.com/tugascript/devlogs/idp/internal/services/templates" ) const accountCredentialsRegistrationIATLocation = "account_credentials_registration_iat" @@ -83,3 +85,72 @@ func (s *Services) CreateAccountCredentialsRegistrationIAT( logger.InfoContext(ctx, "Created account credentials registration IAT successfully") return signedToken, nil } + +type InitiateAccountCredentialsRegistrationIATOptions struct { + RequestID string + AccountPublicID uuid.UUID + Domain string +} + +func (s *Services) InitiateAccountCredentialsRegistrationIAT( + ctx context.Context, + opts InitiateAccountCredentialsRegistrationIATOptions, +) (string, *exceptions.ServiceError) { + logger := s.buildLogger(opts.RequestID, accountCredentialsRegistrationIATLocation, "InitiateAccountCredentialsRegistrationIAT").With( + "accountPublicId", opts.AccountPublicID, + "domain", opts.Domain, + ) + logger.InfoContext(ctx, "Initiating account credentials registration IAT generation...") + + domainDTO, serviceErr := s.GetAccountCredentialsRegistrationDomain(ctx, GetAccountCredentialsRegistrationDomainOptions{ + RequestID: opts.RequestID, + AccountPublicID: opts.AccountPublicID, + Domain: opts.Domain, + }) + if serviceErr != nil { + logger.ErrorContext(ctx, "Failed to get account credentials registration domain", "serviceError", serviceErr) + return "", serviceErr + } + if !domainDTO.Verified { + logger.ErrorContext(ctx, "Account credentials registration domain is not verified") + return "", exceptions.NewValidationError("account credentials registration domain is not verified") + } + + accountDTO, serviceErr := s.GetAccountByPublicID(ctx, GetAccountByPublicIDOptions{ + RequestID: opts.RequestID, + PublicID: opts.AccountPublicID, + }) + if serviceErr != nil { + logger.ErrorContext(ctx, "Failed to get account", "serviceError", serviceErr) + return "", serviceErr + } + + authProviders, serviceErr := s.ListAccountAuthProviders(ctx, ListAccountAuthProvidersOptions{ + RequestID: opts.RequestID, + PublicID: accountDTO.PublicID, + }) + if serviceErr != nil { + logger.ErrorContext(ctx, "Failed to list account auth providers", "serviceError", serviceErr) + return "", serviceErr + } + + clientID, err := s.cache.SaveAccountCredentialsDynamicRegistrationIAT( + ctx, + cache.SaveAccountCredentialsDynamicRegistrationIATOptions{ + RequestID: opts.RequestID, + AccountPublicID: accountDTO.PublicID, + Domain: opts.Domain, + }, + ) + if err != nil { + logger.ErrorContext(ctx, "Failed to save account credentials dynamic registration IAT", "error", err) + return "", exceptions.NewInternalServerError() + } + + loginHTML, err := templates.BuildAccountDynamicRegistrationLoginTemplate(clientID, &accountDTO, authProviders) + if err != nil { + logger.ErrorContext(ctx, "Failed to build account dynamic registration login template", "error", err) + return "", exceptions.NewInternalServerError() + } + return loginHTML, nil +} diff --git a/idp/internal/services/oauth_dynamic_registration.go b/idp/internal/services/oauth_dynamic_registration.go new file mode 100644 index 0000000..e19b7c6 --- /dev/null +++ b/idp/internal/services/oauth_dynamic_registration.go @@ -0,0 +1,26 @@ +// Copyright (c) 2025 Afonso Barracha +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +package services + +type OauthDynamicRegistrationOptions struct { + RedirectURIs []string + TokenEndpointAuthMethod string + ResponseTypes []string + GrantTypes []string + ApplicationType string + ClientName string + ClientURI string + LogoURI string + Scope string + Contacts []string + TOSURI string + PolicyURI string + JWKsURI string + JWKs []string + SoftwareID string + SoftwareVersion string +} diff --git a/idp/internal/controllers/templates/account_dynamic_registration.go b/idp/internal/services/templates/account_dynamic_registration.go similarity index 91% rename from idp/internal/controllers/templates/account_dynamic_registration.go rename to idp/internal/services/templates/account_dynamic_registration.go index e5b261d..f4e66c0 100644 --- a/idp/internal/controllers/templates/account_dynamic_registration.go +++ b/idp/internal/services/templates/account_dynamic_registration.go @@ -281,16 +281,18 @@ const accountDynamicRegistrationBaseTemplate = ` %s - @@ -391,49 +387,49 @@ const divider = ` ` const appleLoginButton = ` - ` const facebookLoginButton = ` - ` const githubLoginButton = ` - ` const googleLoginButton = ` - ` const microsoftLoginButton = ` - @@ -484,6 +479,7 @@ type accountDynamicRegistrationLoginTemplateData struct { } func BuildAccountDynamicRegistrationLoginTemplate( + clientID string, account *dtos.AccountDTO, authProviders []dtos.AuthProviderDTO, ) (string, error) { @@ -491,7 +487,7 @@ func BuildAccountDynamicRegistrationLoginTemplate( return "", errors.New("no auth providers found") } - baseURL := paths.AccountsBase + "/" + account.PublicID.String() + paths.CredentialsBase + paths.OAuthBase + baseURL := paths.AccountsBase + paths.CredentialsBase + "/" + clientID + paths.OAuthBase data := accountDynamicRegistrationLoginTemplateData{ Title: fmt.Sprintf("%s %s", baseAccountDynamicRegistrationLoginTitle, account.GivenName), Header: fmt.Sprintf("Confirm Account Credentials Client Registration %s", account.GivenName), @@ -511,13 +507,17 @@ func BuildAccountDynamicRegistrationLoginTemplate( data.AppleLoginURL = baseURL + paths.OAuthAuth + "?client_id=apple&response_type=code" baseTemplateBody += appleLoginButton case database.AuthProviderFacebook: - baseTemplateBody += facebookLoginButton + "?client_id=facebook&response_type=code" + data.FacebookLoginURL = baseURL + paths.OAuthAuth + "?client_id=facebook&response_type=code" + baseTemplateBody += facebookLoginButton case database.AuthProviderGithub: - baseTemplateBody += githubLoginButton + "?client_id=github&response_type=code" + data.GithubLoginURL = baseURL + paths.OAuthAuth + "?client_id=github&response_type=code" + baseTemplateBody += githubLoginButton case database.AuthProviderGoogle: - baseTemplateBody += googleLoginButton + "?client_id=google&response_type=code" + data.GoogleLoginURL = baseURL + paths.OAuthAuth + "?client_id=google&response_type=code" + baseTemplateBody += googleLoginButton case database.AuthProviderMicrosoft: - baseTemplateBody += microsoftLoginButton + "?client_id=microsoft&response_type=code" + data.MicrosoftLoginURL = baseURL + paths.OAuthAuth + "?client_id=microsoft" + baseTemplateBody += microsoftLoginButton default: return "", fmt.Errorf("unsupported auth provider: %s", provider.Provider) } diff --git a/idp/internal/controllers/templates/login.html b/idp/internal/services/templates/login.html similarity index 90% rename from idp/internal/controllers/templates/login.html rename to idp/internal/services/templates/login.html index 0333dc6..c339213 100644 --- a/idp/internal/controllers/templates/login.html +++ b/idp/internal/services/templates/login.html @@ -275,7 +275,7 @@

Welcome back {{.Name}}

- - - - -
- @@ -370,12 +288,18 @@ func buildEntryAccountDynamicRegistrationTemplate(body string) string { return fmt.Sprintf(accountDynamicRegistrationBaseTemplate, body) } -const baseAccountDynamicRegistrationLoginTitle = "Account Credentials Dynamic Registration" +const baseAccountLoginTitle = "Account Login" const loginForm = `
- - + + + + + + + +
` @@ -387,7 +311,7 @@ const divider = ` ` const appleLoginButton = ` - + + + + +` + +type accountDynamicRegistrationIAT2FAData struct { + TwoFAURL string + ClientID string + CSRFToken string + State string + CodeChallenge string + CodeChallengeMethod string + RedirectURI string +} + +type AccountDynamicRegistrationIAT2FAOptions struct { + ClientID string + CSRFToken string + State string + CodeChallenge string + CodeChallengeMethod string + RedirectURI string +} + +func BuildAccountDynamicRegistrationIAT2FATemplate(opts AccountDynamicRegistrationIATAuthOptions) (string, error) { + baseURL := paths.AccountsBase + paths.CredentialsBase + opts.ClientID + paths.InitialAccessToken + data := accountDynamicRegistrationIAT2FAData{ + TwoFAURL: baseURL + paths.AuthLogin + paths.Auth2FA, + RedirectURI: opts.RedirectURI, + CodeChallenge: opts.CodeChallenge, + CodeChallengeMethod: opts.CodeChallengeMethod, + State: opts.State, + CSRFToken: opts.CSRFToken, + } + + t, err := template.New("two_fa").Parse(twoFaTemplate) + if err != nil { + return "", nil + } + var twoFATemplateContent bytes.Buffer + if err := t.Execute(&twoFATemplateContent, data); err != nil { + return "", err + } + + return twoFATemplateContent.String(), nil +} diff --git a/idp/internal/services/templates/login.html b/idp/internal/services/templates/login.html index c339213..d50764a 100644 --- a/idp/internal/services/templates/login.html +++ b/idp/internal/services/templates/login.html @@ -264,9 +264,14 @@

Welcome back {{.Name}}

- - - + + + + + + + +
@@ -287,7 +292,7 @@

Welcome back {{.Name}}

- + + + + \ No newline at end of file diff --git a/idp/internal/utils/hasher.go b/idp/internal/utils/hasher.go index f33d5f5..3b32b90 100644 --- a/idp/internal/utils/hasher.go +++ b/idp/internal/utils/hasher.go @@ -93,8 +93,8 @@ func Sha256HashHex(str string) string { return hex.EncodeToString(hash[:]) } -func Sha256HashBase64(bytes []byte) string { - hash := sha256.Sum256(bytes) +func Sha256HashBase64(str string) string { + hash := sha256.Sum256([]byte(str)) return base64.RawURLEncoding.EncodeToString(hash[:]) } From 32d0bc3d111523bc617e7b9c6248cd8908923983 Mon Sep 17 00:00:00 2001 From: Afonso Barracha Date: Thu, 28 Aug 2025 22:33:33 +1200 Subject: [PATCH 11/23] feat(idp): start adding support for 2fa --- ...ccount_credentials_dynamic_registration.go | 66 +++++++++++ idp/internal/providers/cache/two_factor.go | 14 ++- .../account_credentials_registration_iat.go | 2 +- .../services/oauth_dynamic_registration.go | 105 ++++++++++++++---- .../templates/account_dynamic_registration.go | 10 +- .../services/templates/two_factor.html | 1 + 6 files changed, 166 insertions(+), 32 deletions(-) diff --git a/idp/internal/providers/cache/account_credentials_dynamic_registration.go b/idp/internal/providers/cache/account_credentials_dynamic_registration.go index 31bc778..15ab010 100644 --- a/idp/internal/providers/cache/account_credentials_dynamic_registration.go +++ b/idp/internal/providers/cache/account_credentials_dynamic_registration.go @@ -12,6 +12,7 @@ import ( "fmt" "net/url" "strings" + "time" "github.com/google/uuid" @@ -149,6 +150,71 @@ func (c *Cache) DeleteAccountCredentialsDynamicRegistrationIATAuth( return c.storage.DeleteWithContext(ctx, buildAccountCredentialsDynamicRegistrationIATAuthCacheKey(opts.ClientID)) } +type AccountCredentialsDynamicRegistrationIAT2FAData struct { + AccountPublicID uuid.UUID `json:"account_public_id"` + AccountVersion int32 `json:"account_version"` + RedirectURI string `json:"redirect_uri"` + Challenge string `json:"challenge"` + ClientID string `json:"clientId"` + Domain string `json:"domain"` + State string `json:"state"` +} + +func buildAccountCredentialsDynamicRegistrationIAT2FACacheKey(sessionID string) string { + return fmt.Sprintf("%s:2fa:%s", accountCredentialsDynamicRegistrationIATPrefix, utils.Sha256HashHex(sessionID)) +} + +type SaveAccountCredentialsDynamicRegistrationIAT2FAOptions struct { + RequestID string + AccountPublicID uuid.UUID + AccountVersion int32 + RedirectURI string + Domain string + ClientID string + Challenge string + State string + TwoFATTL int64 +} + +func (c *Cache) SaveAccountCredentialsDynamicRegistrationIAT2FA( + ctx context.Context, + opts SaveAccountCredentialsDynamicRegistrationIAT2FAOptions, +) (string, error) { + logger := utils.BuildLogger(c.logger, utils.LoggerOptions{ + Location: accountCredentialsDynamicRegistrationLocation, + Method: "SaveAccountCredentialsDynamicRegistrationIAT2FA", + RequestID: opts.RequestID, + }).With( + "accountPublicId", opts.AccountPublicID, + "domain", opts.Domain, + ) + logger.DebugContext(ctx, "Saving account credentials dynamic registration IAT2FA...") + + sessionId := utils.Base64UUID() + data := AccountCredentialsDynamicRegistrationIAT2FAData{ + AccountPublicID: opts.AccountPublicID, + AccountVersion: opts.AccountVersion, + RedirectURI: opts.RedirectURI, + Domain: opts.Domain, + ClientID: opts.ClientID, + Challenge: opts.Challenge, + State: opts.State, + } + dataBytes, err := json.Marshal(data) + if err != nil { + logger.ErrorContext(ctx, "Failed to marshal account credentials dynamic registration IAT2FA data", "error", err) + return "", err + } + + return sessionId, c.storage.SetWithContext( + ctx, + buildAccountCredentialsDynamicRegistrationIAT2FACacheKey(sessionId), + dataBytes, + time.Duration(opts.TwoFATTL)*time.Second, + ) + +} + func buildAccountCredentialsDynamicRegistrationIATCodeCacheKey(codeID string) string { return fmt.Sprintf("%s:code:%s", accountCredentialsDynamicRegistrationIATPrefix, codeID) } diff --git a/idp/internal/providers/cache/two_factor.go b/idp/internal/providers/cache/two_factor.go index 7a6f66c..14a47da 100644 --- a/idp/internal/providers/cache/two_factor.go +++ b/idp/internal/providers/cache/two_factor.go @@ -23,13 +23,17 @@ const ( twoFactorUserPrefix string = "user" ) -func generateCode() (string, error) { - const codeLength = 6 - const digits = "0123456789" +const ( + codeLength int = 6 + digits string = "0123456789" + digitsLen int64 = 10 +) + +func generate2FACode() (string, error) { code := make([]byte, codeLength) for i := 0; i < codeLength; i++ { - num, err := rand.Int(rand.Reader, big.NewInt(int64(len(digits)))) + num, err := rand.Int(rand.Reader, big.NewInt(digitsLen)) if err != nil { return "", err } @@ -65,7 +69,7 @@ func (c *Cache) AddTwoFactorCode(ctx context.Context, opts AddTwoFactorCodeOptio ) logger.DebugContext(ctx, "Adding two factor code...") - code, err := generateCode() + code, err := generate2FACode() if err != nil { logger.ErrorContext(ctx, "Error generating two factor code", "error", err) return "", err diff --git a/idp/internal/services/account_credentials_registration_iat.go b/idp/internal/services/account_credentials_registration_iat.go index 978aec2..05fd925 100644 --- a/idp/internal/services/account_credentials_registration_iat.go +++ b/idp/internal/services/account_credentials_registration_iat.go @@ -10,13 +10,13 @@ import ( "context" "github.com/google/uuid" - "github.com/tugascript/devlogs/idp/internal/utils" "github.com/tugascript/devlogs/idp/internal/exceptions" "github.com/tugascript/devlogs/idp/internal/providers/cache" "github.com/tugascript/devlogs/idp/internal/providers/crypto" "github.com/tugascript/devlogs/idp/internal/providers/database" "github.com/tugascript/devlogs/idp/internal/providers/tokens" + "github.com/tugascript/devlogs/idp/internal/utils" ) const accountCredentialsRegistrationIATLocation = "account_credentials_registration_iat" diff --git a/idp/internal/services/oauth_dynamic_registration.go b/idp/internal/services/oauth_dynamic_registration.go index b7f1bad..42c0648 100644 --- a/idp/internal/services/oauth_dynamic_registration.go +++ b/idp/internal/services/oauth_dynamic_registration.go @@ -11,10 +11,12 @@ import ( "net/url" "github.com/google/uuid" + "github.com/tugascript/devlogs/idp/internal/controllers/paths" "github.com/tugascript/devlogs/idp/internal/exceptions" "github.com/tugascript/devlogs/idp/internal/providers/cache" "github.com/tugascript/devlogs/idp/internal/providers/database" + "github.com/tugascript/devlogs/idp/internal/providers/mailer" "github.com/tugascript/devlogs/idp/internal/services/templates" "github.com/tugascript/devlogs/idp/internal/utils" ) @@ -143,6 +145,7 @@ func (s *Services) createAccountCredentialsRegistrationIATCode( type OAuthDynamicRegistrationIATLoginOptions struct { RequestID string ClientID string + CSRFToken string CodeChallenge string CodeChallengeMethod string State string @@ -155,7 +158,7 @@ type OAuthDynamicRegistrationIATLoginOptions struct { func (s *Services) OAuthDynamicRegistrationIATLogin( ctx context.Context, opts OAuthDynamicRegistrationIATLoginOptions, -) (string, *exceptions.ServiceError) { +) (string, string, *exceptions.ServiceError) { logger := s.buildLogger(opts.RequestID, oauthDynamicRegistrationLocation, "OAuthDynamicRegistrationIATLogin").With( "clientId", opts.ClientID, "email", opts.Email, @@ -168,30 +171,30 @@ func (s *Services) OAuthDynamicRegistrationIATLogin( }) if err != nil { logger.ErrorContext(ctx, "Failed to get account credentials dynamic registration IAT", "error", err) - return "", exceptions.NewInternalServerError() + return "", "", exceptions.NewInternalServerError() } if !found { logger.ErrorContext(ctx, "Account credentials dynamic registration IAT not found") - return "", exceptions.NewNotFoundValidationError("invalid client ID") + return "", "", exceptions.NewNotFoundValidationError("invalid client ID") } hashedChallenge, err := hashChallenge(opts.CodeChallenge, opts.CodeChallengeMethod) if err != nil { logger.ErrorContext(ctx, "Invalid code challenge", "error", err) - return "", exceptions.NewInternalServerError() + return "", "", exceptions.NewInternalServerError() } // Note this is not the verifier so standard comparison is ok if hashedChallenge != data.Challenge { logger.WarnContext(ctx, "OAuth Code challenge verification failed") - return "", exceptions.NewUnauthorizedError() + return "", "", exceptions.NewUnauthorizedError() } if data.State != opts.State { logger.WarnContext(ctx, "OAuth State does not match") - return "", exceptions.NewUnauthorizedError() + return "", "", exceptions.NewUnauthorizedError() } if data.RedirectURI != opts.RedirectURI { logger.WarnContext(ctx, "OAuth Redirect URI does not match") - return "", exceptions.NewUnauthorizedError() + return "", "", exceptions.NewUnauthorizedError() } accountDTO, serviceErr := s.GetAccountByEmail(ctx, GetAccountByEmailOptions{ @@ -200,11 +203,11 @@ func (s *Services) OAuthDynamicRegistrationIATLogin( }) if serviceErr != nil { if serviceErr.Code != exceptions.CodeNotFound { - return "", serviceErr + return "", "", serviceErr } logger.WarnContext(ctx, "Account was not found", "error", serviceErr) - return "", exceptions.NewUnauthorizedError() + return "", "", exceptions.NewUnauthorizedError() } if _, err := s.database.FindAccountAuthProviderByAccountPublicIdAndProvider( ctx, @@ -216,28 +219,85 @@ func (s *Services) OAuthDynamicRegistrationIATLogin( serviceErr := exceptions.FromDBError(err) if serviceErr.Code != exceptions.CodeNotFound { logger.ErrorContext(ctx, "Failed to find account auth provider", "error", err) - return "", serviceErr + return "", "", serviceErr } logger.WarnContext(ctx, "Account auth provider not found", "error", err) - return "", exceptions.NewUnauthorizedError() + return "", "", exceptions.NewUnauthorizedError() } passwordVerified, err := utils.Argon2CompareHash(opts.Password, accountDTO.Password()) if err != nil { logger.ErrorContext(ctx, "Failed to verify password", "error", err) - return "", exceptions.NewInternalServerError() + return "", "", exceptions.NewInternalServerError() } if !passwordVerified { logger.WarnContext(ctx, "Passwords do not match") - return "", exceptions.NewUnauthorizedError() + return "", "", exceptions.NewUnauthorizedError() } if !accountDTO.EmailVerified() { logger.InfoContext(ctx, "Account is not confirmed") - return "", exceptions.NewForbiddenError() + return "", "", exceptions.NewForbiddenError() } - // TODO: add 2FA redirect here + if accountDTO.TwoFactorType != database.TwoFactorTypeNone { + logger.InfoContext(ctx, "Two-Factor is enabled, proceeding to 2FA step") + sessionID, err := s.cache.SaveAccountCredentialsDynamicRegistrationIAT2FA( + ctx, + cache.SaveAccountCredentialsDynamicRegistrationIAT2FAOptions{ + RequestID: opts.RequestID, + AccountPublicID: accountDTO.PublicID, + AccountVersion: accountDTO.Version(), + RedirectURI: opts.RedirectURI, + Domain: data.Domain, + ClientID: opts.ClientID, + Challenge: data.Challenge, + State: data.State, + TwoFATTL: s.jwt.Get2FATTL(), + }, + ) + if err != nil { + logger.ErrorContext(ctx, "Failed to save account credentials dynamic registration IAT 2FA", "error", err) + return "", "", exceptions.NewInternalServerError() + } + + if accountDTO.TwoFactorType == database.TwoFactorTypeEmail { + code, err := s.cache.AddTwoFactorCode(ctx, cache.AddTwoFactorCodeOptions{ + RequestID: opts.RequestID, + AccountID: accountDTO.ID(), + TTL: s.jwt.Get2FATTL(), + }) + if err != nil { + logger.ErrorContext(ctx, "Failed to add two factor code", "error", err) + return "", "", exceptions.NewInternalServerError() + } + + if err := s.mail.Publish2FAEmail(ctx, mailer.TwoFactorEmailOptions{ + RequestID: opts.RequestID, + Email: accountDTO.Email, + Name: accountDTO.GivenName, + Code: code, + }); err != nil { + logger.ErrorContext(ctx, "Failed to send two factor code email", "error", err) + return "", "", exceptions.NewInternalServerError() + } + + logger.InfoContext(ctx, "Sent two factor code email successfully") + } + + if err := s.cache.DeleteAccountCredentialsDynamicRegistrationIATAuth( + ctx, + cache.DeleteAccountCredentialsDynamicRegistrationIATAuthOptions{ + RequestID: opts.RequestID, + ClientID: opts.ClientID, + }, + ); err != nil { + logger.ErrorContext(ctx, "Failed to delete account credentials dynamic registration IAT auth", "error", err) + return "", "", exceptions.NewInternalServerError() + } + + return paths.AccountsBase + paths.CredentialsBase + "/" + opts.ClientID + paths.InitialAccessToken + paths.AuthLogin + paths.Auth2FA, sessionID, nil + } domainDTO, serviceErr := s.GetAccountCredentialsRegistrationDomain(ctx, GetAccountCredentialsRegistrationDomainOptions{ RequestID: opts.RequestID, @@ -247,7 +307,7 @@ func (s *Services) OAuthDynamicRegistrationIATLogin( if serviceErr != nil { if serviceErr.Code != exceptions.CodeNotFound { logger.ErrorContext(ctx, "Failed to get account credentials registration domain", "serviceError", serviceErr) - return "", serviceErr + return "", "", serviceErr } if err := s.cache.DeleteAccountCredentialsDynamicRegistrationIATAuth( @@ -258,15 +318,15 @@ func (s *Services) OAuthDynamicRegistrationIATLogin( }, ); err != nil { logger.ErrorContext(ctx, "Failed to delete account credentials dynamic registration IAT auth", "error", err) - return "", exceptions.NewInternalServerError() + return "", "", exceptions.NewInternalServerError() } logger.WarnContext(ctx, "Account credentials registration domain not found") - return "", exceptions.NewForbiddenError() + return "", "", exceptions.NewForbiddenError() } if !domainDTO.Verified { logger.ErrorContext(ctx, "Account credentials registration domain is not verified") - return "", exceptions.NewForbiddenError() + return "", "", exceptions.NewForbiddenError() } code, serviceErr := s.createAccountCredentialsRegistrationIATCode( @@ -282,7 +342,7 @@ func (s *Services) OAuthDynamicRegistrationIATLogin( ) if serviceErr != nil { logger.ErrorContext(ctx, "Failed to create account credentials registration IAT code", "serviceError", serviceErr) - return "", serviceErr + return "", "", serviceErr } if err := s.cache.DeleteAccountCredentialsDynamicRegistrationIATAuth( ctx, @@ -292,19 +352,20 @@ func (s *Services) OAuthDynamicRegistrationIATLogin( }, ); err != nil { logger.ErrorContext(ctx, "Failed to delete account credentials dynamic registration IAT auth", "error", err) - return "", exceptions.NewInternalServerError() + return "", "", exceptions.NewInternalServerError() } queryParams := make(url.Values) queryParams.Add("code", code) queryParams.Add("state", data.State) queryParams.Add("iss", "https://"+opts.BackendDomain) - return data.RedirectURI + "?" + queryParams.Encode(), nil + return data.RedirectURI + "?" + queryParams.Encode(), "", nil } type OAuthDynamicRegistrationIAT2FAOptions struct { RequestID string ClientID string + SessionID string Code string } diff --git a/idp/internal/services/templates/account_dynamic_registration.go b/idp/internal/services/templates/account_dynamic_registration.go index 76170b3..879d47c 100644 --- a/idp/internal/services/templates/account_dynamic_registration.go +++ b/idp/internal/services/templates/account_dynamic_registration.go @@ -422,7 +422,7 @@ type AccountDynamicRegistrationIATAuthOptions struct { } func BuildAccountDynamicRegistrationIATAuthTemplate(opts AccountDynamicRegistrationIATAuthOptions) (string, error) { - baseURL := paths.AccountsBase + paths.CredentialsBase + opts.ClientID + paths.InitialAccessToken + baseURL := paths.AccountsBase + paths.CredentialsBase + "/" + opts.ClientID + paths.InitialAccessToken data := accountDynamicRegistrationLoginTemplateData{ Title: baseAccountLoginTitle, Header: "OAuth Dynamic Client Registration Initial Access Token Login", @@ -651,7 +651,7 @@ const twoFaTemplate = ` type accountDynamicRegistrationIAT2FAData struct { TwoFAURL string - ClientID string + SessionID string CSRFToken string State string CodeChallenge string @@ -661,6 +661,7 @@ type accountDynamicRegistrationIAT2FAData struct { type AccountDynamicRegistrationIAT2FAOptions struct { ClientID string + SessionID string CSRFToken string State string CodeChallenge string @@ -668,8 +669,8 @@ type AccountDynamicRegistrationIAT2FAOptions struct { RedirectURI string } -func BuildAccountDynamicRegistrationIAT2FATemplate(opts AccountDynamicRegistrationIATAuthOptions) (string, error) { - baseURL := paths.AccountsBase + paths.CredentialsBase + opts.ClientID + paths.InitialAccessToken +func BuildAccountDynamicRegistrationIAT2FATemplate(opts AccountDynamicRegistrationIAT2FAOptions) (string, error) { + baseURL := paths.AccountsBase + paths.CredentialsBase + "/" + opts.ClientID + paths.InitialAccessToken data := accountDynamicRegistrationIAT2FAData{ TwoFAURL: baseURL + paths.AuthLogin + paths.Auth2FA, RedirectURI: opts.RedirectURI, @@ -677,6 +678,7 @@ func BuildAccountDynamicRegistrationIAT2FATemplate(opts AccountDynamicRegistrati CodeChallengeMethod: opts.CodeChallengeMethod, State: opts.State, CSRFToken: opts.CSRFToken, + SessionID: opts.SessionID, } t, err := template.New("two_fa").Parse(twoFaTemplate) diff --git a/idp/internal/services/templates/two_factor.html b/idp/internal/services/templates/two_factor.html index b80cf5f..b0dd761 100644 --- a/idp/internal/services/templates/two_factor.html +++ b/idp/internal/services/templates/two_factor.html @@ -136,6 +136,7 @@

Confirm {{.Name}}

+ From 785a1b4d063962c118c823622296bb484073e2c3 Mon Sep 17 00:00:00 2001 From: Afonso Barracha Date: Sun, 31 Aug 2025 02:45:33 +1200 Subject: [PATCH 12/23] feat(idp): add session key to oauth 2.0 flow for iat --- idp/internal/config/config.go | 1 + idp/internal/controllers/apps.go | 1 + idp/internal/controllers/auth.go | 22 +- idp/internal/controllers/controllers.go | 21 +- idp/internal/controllers/helpers.go | 71 +- idp/internal/controllers/oauth.go | 1 - .../controllers/oauth_dynamic_registration.go | 144 ++++ idp/internal/controllers/params/oauth.go | 6 + .../params/oauth_dynamic_registration.go | 23 + idp/internal/controllers/paths/common.go | 1 + idp/internal/exceptions/controllers.go | 4 +- ...ccount_credentials_dynamic_registration.go | 346 +++++++- idp/internal/providers/cache/cache.go | 6 + .../database/account_credentials.sql.go | 18 + .../database/queries/account_credentials.sql | 5 + .../routes/account_dynamic_registration.go | 51 +- idp/internal/server/routes/common.go | 8 +- idp/internal/services/accounts.go | 32 +- idp/internal/services/auth.go | 50 +- .../services/oauth_dynamic_registration.go | 739 ++++++++++++++---- .../templates/account_dynamic_registration.go | 11 +- idp/internal/services/templates/error.go | 308 ++++++++ idp/internal/services/templates/error.html | 123 +++ 23 files changed, 1697 insertions(+), 295 deletions(-) create mode 100644 idp/internal/controllers/oauth_dynamic_registration.go create mode 100644 idp/internal/controllers/params/oauth_dynamic_registration.go create mode 100644 idp/internal/services/templates/error.go create mode 100644 idp/internal/services/templates/error.html diff --git a/idp/internal/config/config.go b/idp/internal/config/config.go index a0b2d4a..502c33c 100644 --- a/idp/internal/config/config.go +++ b/idp/internal/config/config.go @@ -26,6 +26,7 @@ type Config struct { backendDomain string cookieSecret string cookieName string + sessionCookieName string emailPubChannel string encryptionSecret string serviceID uuid.UUID diff --git a/idp/internal/controllers/apps.go b/idp/internal/controllers/apps.go index 7b1c0cf..b067e64 100644 --- a/idp/internal/controllers/apps.go +++ b/idp/internal/controllers/apps.go @@ -10,6 +10,7 @@ import ( "fmt" "github.com/gofiber/fiber/v2" + "github.com/tugascript/devlogs/idp/internal/controllers/bodies" "github.com/tugascript/devlogs/idp/internal/controllers/params" "github.com/tugascript/devlogs/idp/internal/controllers/paths" diff --git a/idp/internal/controllers/auth.go b/idp/internal/controllers/auth.go index d53e8c6..075a725 100644 --- a/idp/internal/controllers/auth.go +++ b/idp/internal/controllers/auth.go @@ -8,6 +8,7 @@ package controllers import ( "github.com/gofiber/fiber/v2" + "github.com/tugascript/devlogs/idp/internal/controllers/paths" "github.com/tugascript/devlogs/idp/internal/controllers/bodies" "github.com/tugascript/devlogs/idp/internal/controllers/params" @@ -20,11 +21,11 @@ const authLocation string = "auth" func (c *Controllers) saveAccountRefreshCookie(ctx *fiber.Ctx, token string) { ctx.Cookie(&fiber.Cookie{ - Name: c.refreshCookieName, + Name: c.cookieName + refreshCookieSuffix, Value: token, - Path: "/auth", + Path: paths.V1 + paths.AuthBase, HTTPOnly: true, - SameSite: "None", + SameSite: fiber.CookieSameSiteNoneMode, Secure: true, MaxAge: int(c.services.GetRefreshTTL()), }) @@ -32,11 +33,12 @@ func (c *Controllers) saveAccountRefreshCookie(ctx *fiber.Ctx, token string) { func (c *Controllers) clearAccountRefreshCookie(ctx *fiber.Ctx) { ctx.Cookie(&fiber.Cookie{ - Name: c.refreshCookieName, + Name: c.cookieName + refreshCookieSuffix, Value: "", + Path: paths.V1 + paths.AuthBase, HTTPOnly: true, Secure: true, - SameSite: "None", + SameSite: fiber.CookieSameSiteNoneMode, MaxAge: -1, }) } @@ -195,7 +197,7 @@ func (c *Controllers) LogoutAccount(ctx *fiber.Ctx) error { requestID := getRequestID(ctx) logger := c.buildLogger(requestID, authLocation, "LogoutAccount") - refreshToken := ctx.Cookies(c.refreshCookieName) + refreshToken := ctx.Cookies(c.cookieName + refreshCookieSuffix) if refreshToken == "" { body := new(bodies.RefreshTokenBody) if err := ctx.BodyParser(body); err != nil { @@ -215,6 +217,7 @@ func (c *Controllers) LogoutAccount(ctx *fiber.Ctx) error { return serviceErrorResponse(logger, ctx, serviceErr) } + c.clearAccountRefreshCookie(ctx) logResponse(logger, ctx, fiber.StatusNoContent) return ctx.SendStatus(fiber.StatusNoContent) } @@ -224,7 +227,8 @@ func (c *Controllers) RefreshAccount(ctx *fiber.Ctx) error { logger := c.buildLogger(requestID, authLocation, "RefreshAccount") logRequest(logger, ctx) - refreshToken := ctx.Cookies(c.refreshCookieName) + refreshToken := ctx.Cookies(c.cookieName + refreshCookieSuffix) + isCookie := true if refreshToken == "" { body := new(bodies.RefreshTokenBody) if err := ctx.BodyParser(body); err != nil { @@ -234,6 +238,7 @@ func (c *Controllers) RefreshAccount(ctx *fiber.Ctx) error { return validateBodyErrorResponse(logger, ctx, err) } + isCookie = false refreshToken = body.RefreshToken } @@ -242,6 +247,9 @@ func (c *Controllers) RefreshAccount(ctx *fiber.Ctx) error { RefreshToken: refreshToken, }) if serviceErr != nil { + if isCookie { + c.clearAccountRefreshCookie(ctx) + } return serviceErrorResponse(logger, ctx, serviceErr) } diff --git a/idp/internal/controllers/controllers.go b/idp/internal/controllers/controllers.go index 667cc6e..83e59f8 100644 --- a/idp/internal/controllers/controllers.go +++ b/idp/internal/controllers/controllers.go @@ -21,23 +21,24 @@ type Controllers struct { validate *validator.Validate frontendDomain string backendDomain string - refreshCookieName string + cookieName string + sessionCookieName string } func NewControllers( logger *slog.Logger, services *services.Services, validate *validator.Validate, - frontendDomain, - backendDomain, - refreshCookieName string, + frontendDomain string, + backendDomain string, + cookieName string, ) *Controllers { return &Controllers{ - logger: logger.With(utils.BaseLayer, utils.ControllersLogLayer), - services: services, - validate: validate, - frontendDomain: frontendDomain, - backendDomain: backendDomain, - refreshCookieName: refreshCookieName, + logger: logger.With(utils.BaseLayer, utils.ControllersLogLayer), + services: services, + validate: validate, + frontendDomain: frontendDomain, + backendDomain: backendDomain, + cookieName: cookieName, } } diff --git a/idp/internal/controllers/helpers.go b/idp/internal/controllers/helpers.go index b56a8ae..77e2f1d 100644 --- a/idp/internal/controllers/helpers.go +++ b/idp/internal/controllers/helpers.go @@ -14,6 +14,7 @@ import ( "github.com/go-playground/validator/v10" "github.com/gofiber/fiber/v2" "github.com/google/uuid" + "github.com/tugascript/devlogs/idp/internal/services/templates" "github.com/tugascript/devlogs/idp/internal/exceptions" "github.com/tugascript/devlogs/idp/internal/utils" @@ -22,6 +23,8 @@ import ( const ( cacheControlNoStore string = "no-store, no-cache, must-revalidate, private" + refreshCookieSuffix = "_rt" + grantTypeRefresh string = "refresh_token" grantTypeAuthorization string = "authorization_code" grantTypeClientCredentials string = "client_credentials" @@ -59,33 +62,58 @@ func logResponse(logger *slog.Logger, ctx *fiber.Ctx, status int) { ) } -func validateErrorResponse(logger *slog.Logger, ctx *fiber.Ctx, location string, err error) error { - logger.WarnContext(ctx.UserContext(), "Failed to validate request", "error", err) - logResponse(logger, ctx, fiber.StatusBadRequest) - +func validationErrorException(location string, err error) *exceptions.ValidationErrorResponse { var errs validator.ValidationErrors ok := errors.As(err, &errs) if !ok { - return ctx. - Status(fiber.StatusBadRequest). - JSON(exceptions.NewEmptyValidationErrorResponse(location)) + return exceptions.NewEmptyValidationErrorResponse(location) } + return exceptions.ValidationErrorResponseFromErr(&errs, location) +} + +func validateErrorJSONResponse(logger *slog.Logger, ctx *fiber.Ctx, location string, err error) error { + logger.WarnContext(ctx.UserContext(), "Failed to validate request", "error", err) + logResponse(logger, ctx, fiber.StatusBadRequest) return ctx. Status(fiber.StatusBadRequest). - JSON(exceptions.ValidationErrorResponseFromErr(&errs, location)) + JSON(validationErrorException(location, err)) } func validateBodyErrorResponse(logger *slog.Logger, ctx *fiber.Ctx, err error) error { - return validateErrorResponse(logger, ctx, exceptions.ValidationResponseLocationBody, err) + return validateErrorJSONResponse(logger, ctx, exceptions.ValidationResponseLocationBody, err) } func validateURLParamsErrorResponse(logger *slog.Logger, ctx *fiber.Ctx, err error) error { - return validateErrorResponse(logger, ctx, exceptions.ValidationResponseLocationParams, err) + return validateErrorJSONResponse(logger, ctx, exceptions.ValidationResponseLocationParams, err) } func validateQueryParamsErrorResponse(logger *slog.Logger, ctx *fiber.Ctx, err error) error { - return validateErrorResponse(logger, ctx, exceptions.ValidationResponseLocationQuery, err) + return validateErrorJSONResponse(logger, ctx, exceptions.ValidationResponseLocationQuery, err) +} + +func validationErrorHTMLResponse(logger *slog.Logger, ctx *fiber.Ctx, location string, err error) error { + logger.WarnContext(ctx.UserContext(), "Failed to validate request", "error", err) + expt := validationErrorException(location, err) + errHtml, err := templates.BuildErrorTemplate( + templates.ErrorTemplateOptions{ + Status: fiber.StatusBadRequest, + ErrorCode: expt.Code, + MessageTitle: expt.Message, + Messages: utils.MapSlice(expt.Fields, func(f *exceptions.FieldError) string { + return fmt.Sprintf("Field '%s' - Value '%s': %s", f.Param, f.Value, f.Message) + }), + }, + ) + if err != nil { + logger.ErrorContext(ctx.UserContext(), "Failed to build error template", "error", err) + logResponse(logger, ctx, fiber.StatusInternalServerError) + return ctx.Status(fiber.StatusInternalServerError). + Type("html"). + SendString(templates.InternalServerErrorTemplate) + } + + return ctx.Status(fiber.StatusBadRequest).Type("html").SendString(errHtml) } func serviceErrorResponse(logger *slog.Logger, ctx *fiber.Ctx, serviceErr *exceptions.ServiceError) error { @@ -103,6 +131,27 @@ func serviceErrorWithFieldsResponse(logger *slog.Logger, ctx *fiber.Ctx, service )) } +func serviceErrorHTMLResponse(logger *slog.Logger, ctx *fiber.Ctx, serviceErr *exceptions.ServiceError) error { + status := exceptions.NewRequestErrorStatus(serviceErr.Code) + errHtml, err := templates.BuildErrorTemplate( + templates.ErrorTemplateOptions{ + Status: status, + ErrorCode: serviceErr.Code, + MessageTitle: serviceErr.Message, + }, + ) + if err != nil { + logger.ErrorContext(ctx.UserContext(), "Failed to build error template", "error", err) + logResponse(logger, ctx, fiber.StatusInternalServerError) + return ctx.Status(fiber.StatusInternalServerError). + Type("html"). + SendString(templates.InternalServerErrorTemplate) + } + + logResponse(logger, ctx, status) + return ctx.Status(status).Type("html").SendString(errHtml) +} + func oauthErrorResponse(logger *slog.Logger, ctx *fiber.Ctx, message string) error { resErr := exceptions.NewOAuthError(message) diff --git a/idp/internal/controllers/oauth.go b/idp/internal/controllers/oauth.go index eb41f98..a9cc491 100644 --- a/idp/internal/controllers/oauth.go +++ b/idp/internal/controllers/oauth.go @@ -12,7 +12,6 @@ import ( "log/slog" "github.com/gofiber/fiber/v2" - "github.com/tugascript/devlogs/idp/internal/controllers/bodies" "github.com/tugascript/devlogs/idp/internal/controllers/params" "github.com/tugascript/devlogs/idp/internal/exceptions" diff --git a/idp/internal/controllers/oauth_dynamic_registration.go b/idp/internal/controllers/oauth_dynamic_registration.go new file mode 100644 index 0000000..2034aaa --- /dev/null +++ b/idp/internal/controllers/oauth_dynamic_registration.go @@ -0,0 +1,144 @@ +// Copyright (c) 2025 Afonso Barracha +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +package controllers + +import ( + "github.com/gofiber/fiber/v2" + + "github.com/tugascript/devlogs/idp/internal/controllers/params" + "github.com/tugascript/devlogs/idp/internal/controllers/paths" + "github.com/tugascript/devlogs/idp/internal/exceptions" + "github.com/tugascript/devlogs/idp/internal/services" +) + +const ( + oauthDynamicRegistration string = "oauth_dynamic_registration" + + accountsIATCookieSuffix string = "_acc_iat" + accountsIAT2FACookieSuffix string = "_acc_iat_2fa" +) + +func (c *Controllers) OAuthDynamicRegistrationIATAuth(ctx *fiber.Ctx) error { + requestID := getRequestID(ctx) + logger := c.buildLogger(requestID, oauthDynamicRegistration, "OAuthDynamicRegistrationIATAuth") + logRequest(logger, ctx) + + qPrms := params.OAuthDynamicRegistrationIATAuthQueryParams{ + ClientID: ctx.Query("client_id"), + ResponseType: ctx.Query("response_type"), + Challenge: ctx.Query("code_challenge"), + ChallengeMethod: ctx.Query("code_challenge_method"), + State: ctx.Query("state"), + RedirectURI: ctx.Query("redirect_uri"), + } + if err := c.validate.StructCtx(ctx.UserContext(), qPrms); err != nil { + return validationErrorHTMLResponse(logger, ctx, exceptions.ValidationResponseLocationQuery, err) + } + + redirectURL, serviceErr := c.services.InitiateOAuthDynamicRegistrationIATAuth( + ctx.UserContext(), + services.InitiateOAuthDynamicRegistrationIATAuthOptions{ + RequestID: requestID, + Domain: qPrms.ClientID, + State: qPrms.State, + SessionKey: ctx.Cookies(c.cookieName + accountsIATCookieSuffix), + RefreshToken: ctx.Cookies(c.cookieName + refreshCookieSuffix), + Challenge: qPrms.Challenge, + ChallengeMethod: qPrms.ChallengeMethod, + RedirectURI: qPrms.RedirectURI, + BackendDomain: c.backendDomain, + }, + ) + if serviceErr != nil { + return serviceErrorHTMLResponse(logger, ctx, serviceErr) + } + + logResponse(logger, ctx, fiber.StatusFound) + return ctx.Redirect(redirectURL, fiber.StatusFound) +} + +func (c *Controllers) OAuthDynamicRegistrationIATLoginGet(ctx *fiber.Ctx) error { + requestID := getRequestID(ctx) + logger := c.buildLogger(requestID, oauthDynamicRegistration, "OAuthDynamicRegistrationIATLoginGet") + logRequest(logger, ctx) + + qPrms := params.OAuthDynamicRegistrationIATAuthQueryParams{ + ClientID: ctx.Query("client_id"), + ResponseType: ctx.Query("response_type"), + Challenge: ctx.Query("code_challenge"), + ChallengeMethod: ctx.Query("code_challenge_method"), + State: ctx.Query("state"), + RedirectURI: ctx.Query("redirect_uri"), + } + if err := c.validate.StructCtx(ctx.UserContext(), qPrms); err != nil { + return validationErrorHTMLResponse(logger, ctx, exceptions.ValidationResponseLocationQuery, err) + } + + loginHTML, serviceErr := c.services.OAuthDynamicRegistrationIATAuthRender( + ctx.UserContext(), + services.OAuthDynamicRegistrationIATAuthRenderOptions{ + RequestID: requestID, + State: qPrms.State, + CodeChallenge: qPrms.Challenge, + CodeChallengeMethod: qPrms.ChallengeMethod, + RedirectURI: qPrms.RedirectURI, + Domain: qPrms.ClientID, + }, + ) + if serviceErr != nil { + return serviceErrorHTMLResponse(logger, ctx, serviceErr) + } + + logResponse(logger, ctx, fiber.StatusOK) + return ctx.Status(fiber.StatusOK).Type("html").SendString(loginHTML) +} + +func (c *Controllers) saveAccountIATCookie( + ctx *fiber.Ctx, + sessionKey string, +) { + ctx.Cookie(&fiber.Cookie{ + Name: c.cookieName + accountsIATCookieSuffix, + Value: sessionKey, + Path: paths.AccountsBase + paths.CredentialsBase + paths.InitialAccessToken + + paths.OAuthAuth, + HTTPOnly: true, + SameSite: fiber.CookieSameSiteLaxMode, + Secure: true, + MaxAge: int(c.services.GetOAuthCodeTTL()), + }) +} + +func (c *Controllers) removeAccountIATCookie(ctx *fiber.Ctx) { + ctx.Cookie(&fiber.Cookie{ + Name: c.cookieName + accountsIATCookieSuffix, + Value: "", + Path: paths.AccountsBase + paths.CredentialsBase + paths.InitialAccessToken + + paths.OAuthAuth, + HTTPOnly: true, + Secure: true, + SameSite: fiber.CookieSameSiteNoneMode, + MaxAge: -1, + }) +} + +func (c *Controllers) saveAccountIAT2FACookie( + ctx *fiber.Ctx, + sessionID string, + clientID string, +) { + ctx.Cookie(&fiber.Cookie{ + Name: c.cookieName + accountsIAT2FACookieSuffix, + Value: sessionID, + Path: paths.AccountsBase + paths.CredentialsBase + paths.InitialAccessToken + + clientID + paths.AuthLogin + paths.Auth2FA, + HTTPOnly: true, + SameSite: fiber.CookieSameSiteLaxMode, + Secure: true, + MaxAge: int(c.services.GetOAuthCodeTTL()), + }) +} diff --git a/idp/internal/controllers/params/oauth.go b/idp/internal/controllers/params/oauth.go index 6a29b19..0763796 100644 --- a/idp/internal/controllers/params/oauth.go +++ b/idp/internal/controllers/params/oauth.go @@ -1,3 +1,9 @@ +// Copyright (c) 2025 Afonso Barracha +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + package params type OAuthQueryParams struct { diff --git a/idp/internal/controllers/params/oauth_dynamic_registration.go b/idp/internal/controllers/params/oauth_dynamic_registration.go new file mode 100644 index 0000000..1cdc583 --- /dev/null +++ b/idp/internal/controllers/params/oauth_dynamic_registration.go @@ -0,0 +1,23 @@ +// Copyright (c) 2025 Afonso Barracha +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +package params + +type OAuthDynamicRegistrationIATAuthQueryParams struct { + ClientID string `validate:"required,fqdn"` + ResponseType string `validate:"required,oneof=code"` + Challenge string `validate:"required,min=1"` + ChallengeMethod string `validate:"omitempty,oneof=plain s256 S256"` + State string `validate:"required,min=1"` + RedirectURI string `validate:"required,uri"` +} + +type OAuthDynamicRegistrationIATAuthLoginGetQueryParams struct { + Challenge string `validate:"required,min=1"` + ChallengeMethod string `validate:"omitempty,oneof=plain s256 S256"` + RedirectURI string `validate:"required,url"` + State string `validate:"required,min=1"` +} diff --git a/idp/internal/controllers/paths/common.go b/idp/internal/controllers/paths/common.go index bb5e9c8..4e20871 100644 --- a/idp/internal/controllers/paths/common.go +++ b/idp/internal/controllers/paths/common.go @@ -8,6 +8,7 @@ package paths const ( Base string = "/" + V1 string = "/v1" Keys string = "/keys" Confirm string = "/confirm" Recover string = "/recover" diff --git a/idp/internal/exceptions/controllers.go b/idp/internal/exceptions/controllers.go index 156b1d8..ede05e6 100644 --- a/idp/internal/exceptions/controllers.go +++ b/idp/internal/exceptions/controllers.go @@ -233,7 +233,7 @@ func buildFieldErrorMessage(tag string, val any) string { } } -func ValidationErrorResponseFromErr(err *validator.ValidationErrors, location string) ValidationErrorResponse { +func ValidationErrorResponseFromErr(err *validator.ValidationErrors, location string) *ValidationErrorResponse { fields := make([]FieldError, len(*err)) for i, field := range *err { @@ -245,7 +245,7 @@ func ValidationErrorResponseFromErr(err *validator.ValidationErrors, location st } } - return ValidationErrorResponse{ + return &ValidationErrorResponse{ Code: StatusValidation, Message: ValidationResponseMessage, Fields: fields, diff --git a/idp/internal/providers/cache/account_credentials_dynamic_registration.go b/idp/internal/providers/cache/account_credentials_dynamic_registration.go index 15ab010..41bcb1a 100644 --- a/idp/internal/providers/cache/account_credentials_dynamic_registration.go +++ b/idp/internal/providers/cache/account_credentials_dynamic_registration.go @@ -10,13 +10,11 @@ import ( "context" "encoding/json" "fmt" - "net/url" "strings" "time" "github.com/google/uuid" - "github.com/tugascript/devlogs/idp/internal/exceptions" "github.com/tugascript/devlogs/idp/internal/utils" ) @@ -24,13 +22,16 @@ const ( accountCredentialsDynamicRegistrationLocation string = "account_credentials_dynamic_registration" accountCredentialsDynamicRegistrationIATPrefix string = "account_credentials_dynamic_registration_iat" + + csrfTokenByteLen int = 16 + sessionKeyByteLen int = 32 ) -func buildAccountCredentialsDynamicRegistrationIATAuthCacheKey(clientID string) string { - return fmt.Sprintf("%s:auth:%s", accountCredentialsDynamicRegistrationIATPrefix, clientID) +func buildAccountCredentialsDynamicRegistrationIATLoginCacheKey(clientID string) string { + return fmt.Sprintf("%s:login:%s", accountCredentialsDynamicRegistrationIATPrefix, clientID) } -type AccountCredentialsDynamicRegistrationIATAuthData struct { +type AccountCredentialsDynamicRegistrationIATLoginData struct { RedirectURI string `json:"redirect_uri"` Challenge string `json:"challenge"` CSRFToken string `json:"csrf_token"` @@ -38,42 +39,37 @@ type AccountCredentialsDynamicRegistrationIATAuthData struct { State string `json:"state"` } -type SaveAccountCredentialsDynamicRegistrationIATAuthOptions struct { +type SaveAccountCredentialsDynamicRegistrationIATLoginOptions struct { + Domain string RequestID string Challenge string State string RedirectURI string } -func (c *Cache) SaveAccountCredentialsDynamicRegistrationIATAuth( +func (c *Cache) SaveAccountCredentialsDynamicRegistrationIATLogin( ctx context.Context, - opts SaveAccountCredentialsDynamicRegistrationIATAuthOptions, + opts SaveAccountCredentialsDynamicRegistrationIATLoginOptions, ) (string, string, error) { logger := utils.BuildLogger(c.logger, utils.LoggerOptions{ Location: accountCredentialsDynamicRegistrationLocation, - Method: "SaveAccountCredentialsDynamicRegistrationIATAuth", + Method: "SaveAccountCredentialsDynamicRegistrationIATLogin", RequestID: opts.RequestID, }).With( "redirectUri", opts.RedirectURI, ) logger.DebugContext(ctx, "Saving account credentials dynamic registration IAT sessions...") - u, err := url.Parse(opts.RedirectURI) - if err != nil { - logger.ErrorContext(ctx, "Invalid redirect URI", "error", err) - return "", "", exceptions.NewValidationError("invalid redirect URI") - } - - csrfToken, err := utils.GenerateBase64Secret(16) + csrfToken, err := utils.GenerateBase64Secret(csrfTokenByteLen) if err != nil { logger.ErrorContext(ctx, "Error generating CSRF token", "error", err) return "", "", err } - data := AccountCredentialsDynamicRegistrationIATAuthData{ + data := AccountCredentialsDynamicRegistrationIATLoginData{ Challenge: opts.Challenge, State: opts.State, - Domain: u.Hostname(), + Domain: opts.Domain, RedirectURI: opts.RedirectURI, CSRFToken: utils.Sha256HashHex(csrfToken), } @@ -86,7 +82,7 @@ func (c *Cache) SaveAccountCredentialsDynamicRegistrationIATAuth( clientID := utils.Base62UUID() return clientID, csrfToken, c.storage.SetWithContext( ctx, - buildAccountCredentialsDynamicRegistrationIATAuthCacheKey(clientID), + buildAccountCredentialsDynamicRegistrationIATLoginCacheKey(clientID), dataBytes, c.oauthStateTTL, ) @@ -100,7 +96,7 @@ type GetAccountCredentialsDynamicRegistrationIATAuthOptions struct { func (c *Cache) GetAccountCredentialsDynamicRegistrationAuthIAT( ctx context.Context, opts GetAccountCredentialsDynamicRegistrationIATAuthOptions, -) (AccountCredentialsDynamicRegistrationIATAuthData, bool, error) { +) (AccountCredentialsDynamicRegistrationIATLoginData, bool, error) { logger := utils.BuildLogger(c.logger, utils.LoggerOptions{ Location: accountCredentialsDynamicRegistrationLocation, Method: "GetAccountCredentialsDynamicRegistrationAuthIAT", @@ -110,20 +106,20 @@ func (c *Cache) GetAccountCredentialsDynamicRegistrationAuthIAT( ) logger.DebugContext(ctx, "Getting account credentials dynamic registration IAT...") - data, err := c.storage.GetWithContext(ctx, buildAccountCredentialsDynamicRegistrationIATAuthCacheKey(opts.ClientID)) + data, err := c.storage.GetWithContext(ctx, buildAccountCredentialsDynamicRegistrationIATLoginCacheKey(opts.ClientID)) if err != nil { logger.ErrorContext(ctx, "Failed to get account credentials dynamic registration IAT", "error", err) - return AccountCredentialsDynamicRegistrationIATAuthData{}, false, err + return AccountCredentialsDynamicRegistrationIATLoginData{}, false, err } if data == nil { logger.DebugContext(ctx, "Account credentials dynamic registration IAT not found") - return AccountCredentialsDynamicRegistrationIATAuthData{}, false, nil + return AccountCredentialsDynamicRegistrationIATLoginData{}, false, nil } - var authData AccountCredentialsDynamicRegistrationIATAuthData + var authData AccountCredentialsDynamicRegistrationIATLoginData if err := json.Unmarshal(data, &authData); err != nil { logger.ErrorContext(ctx, "Failed to unmarshal account credentials dynamic registration IAT data", "error", err) - return AccountCredentialsDynamicRegistrationIATAuthData{}, false, err + return AccountCredentialsDynamicRegistrationIATLoginData{}, false, err } return authData, true, nil @@ -147,7 +143,7 @@ func (c *Cache) DeleteAccountCredentialsDynamicRegistrationIATAuth( ) logger.DebugContext(ctx, "Deleting account credentials dynamic registration IAT...") - return c.storage.DeleteWithContext(ctx, buildAccountCredentialsDynamicRegistrationIATAuthCacheKey(opts.ClientID)) + return c.storage.DeleteWithContext(ctx, buildAccountCredentialsDynamicRegistrationIATLoginCacheKey(opts.ClientID)) } type AccountCredentialsDynamicRegistrationIAT2FAData struct { @@ -212,7 +208,150 @@ func (c *Cache) SaveAccountCredentialsDynamicRegistrationIAT2FA( dataBytes, time.Duration(opts.TwoFATTL)*time.Second, ) +} + +type GetAccountCredentialsDynamicRegistrationIAT2FAOptions struct { + RequestID string + SessionID string +} + +func (c *Cache) GetAccountCredentialsDynamicRegistrationIAT2FA( + ctx context.Context, + opts GetAccountCredentialsDynamicRegistrationIAT2FAOptions, +) (AccountCredentialsDynamicRegistrationIAT2FAData, bool, error) { + logger := utils.BuildLogger(c.logger, utils.LoggerOptions{ + Location: accountCredentialsDynamicRegistrationLocation, + Method: "GetAccountCredentialsDynamicRegistrationIAT2FA", + RequestID: opts.RequestID, + }).With( + "sessionId", opts.SessionID, + ) + logger.DebugContext(ctx, "Getting account credentials dynamic registration IAT2FA...") + + data, err := c.storage.GetWithContext(ctx, buildAccountCredentialsDynamicRegistrationIAT2FACacheKey(opts.SessionID)) + if err != nil { + logger.ErrorContext(ctx, "Failed to get account credentials dynamic registration IAT2FA", "error", err) + return AccountCredentialsDynamicRegistrationIAT2FAData{}, false, err + } + if data == nil { + logger.DebugContext(ctx, "Account credentials dynamic registration IAT2FA not found") + return AccountCredentialsDynamicRegistrationIAT2FAData{}, false, nil + } + + var twoFAData AccountCredentialsDynamicRegistrationIAT2FAData + if err := json.Unmarshal(data, &twoFAData); err != nil { + logger.ErrorContext(ctx, "Failed to unmarshal account credentials dynamic registration IAT2FA data", "error", err) + return AccountCredentialsDynamicRegistrationIAT2FAData{}, false, err + } + + return twoFAData, true, nil +} + +func buildAccountCredentialsDynamicRegistrationIAT2FACSRFCacheKey(sessionID string) string { + return fmt.Sprintf("%s:2fa-csrf:%s", accountCredentialsDynamicRegistrationIATPrefix, utils.Sha256HashHex(sessionID)) +} + +type SaveAccountCredentialsDynamicRegistrationIAT2FACSRFTokenOptions struct { + RequestID string + SessionID string + TwoFATTL int64 +} + +func (c *Cache) SaveAccountCredentialsDynamicRegistrationIAT2FACSRFToken( + ctx context.Context, + opts SaveAccountCredentialsDynamicRegistrationIAT2FACSRFTokenOptions, +) (string, error) { + logger := utils.BuildLogger(c.logger, utils.LoggerOptions{ + Location: accountCredentialsDynamicRegistrationLocation, + Method: "SaveAccountCredentialsDynamicRegistrationIAT2FACSRFToken", + RequestID: opts.RequestID, + }).With( + "sessionId", opts.SessionID, + ) + logger.DebugContext(ctx, "Saving account credentials dynamic registration IAT2FA CSRF token...") + + csrfToken, err := utils.GenerateBase64Secret(csrfTokenByteLen) + if err != nil { + logger.ErrorContext(ctx, "Error generating CSRF token", "error", err) + return "", err + } + + return csrfToken, c.storage.SetWithContext( + ctx, + buildAccountCredentialsDynamicRegistrationIAT2FACSRFCacheKey(opts.SessionID), + []byte(utils.Sha256HashHex(csrfToken)), + time.Duration(opts.TwoFATTL)*time.Second, + ) +} +type VerifyAccountCredentialsDynamicRegistrationIAT2FACSRFTokenOptions struct { + RequestID string + SessionID string + CSRFToken string +} + +func (c *Cache) VerifyAccountCredentialsDynamicRegistrationIAT2FACSRFToken( + ctx context.Context, + opts VerifyAccountCredentialsDynamicRegistrationIAT2FACSRFTokenOptions, +) (bool, error) { + logger := utils.BuildLogger(c.logger, utils.LoggerOptions{ + Location: accountCredentialsDynamicRegistrationLocation, + Method: "VerifyAccountCredentialsDynamicRegistrationIAT2FACSRFToken", + RequestID: opts.RequestID, + }).With( + "sessionId", opts.SessionID, + ) + logger.DebugContext(ctx, "Verifying account credentials dynamic registration IAT2FA CSRF token...") + + hashedCSRFToken, err := c.storage.GetWithContext(ctx, buildAccountCredentialsDynamicRegistrationIAT2FACSRFCacheKey(opts.SessionID)) + if err != nil { + logger.ErrorContext(ctx, "Failed to get account credentials dynamic registration IAT2FA CSRF token", "error", err) + return false, err + } + if hashedCSRFToken == nil { + logger.DebugContext(ctx, "Account credentials dynamic registration IAT2FA CSRF token not found") + return false, nil + } + + ok, err := utils.CompareShaHex(opts.CSRFToken, string(hashedCSRFToken)) + if err != nil { + logger.ErrorContext(ctx, "Error comparing CSRF token", "error", err) + return false, err + } + if !ok { + logger.DebugContext(ctx, "Invalid CSRF token") + return false, nil + } + if err := c.storage.DeleteWithContext( + ctx, + buildAccountCredentialsDynamicRegistrationIAT2FACSRFCacheKey(opts.SessionID), + ); err != nil { + logger.ErrorContext(ctx, "Error deleting CSRF token", "error", err) + return false, err + } + + return true, nil +} + +type DeleteAccountCredentialsDynamicRegistrationIAT2FACSRFTokenOptions struct { + RequestID string + SessionID string +} + +func (c *Cache) DeleteAccountCredentialsDynamicRegistrationIAT2FACSRFToken( + ctx context.Context, + opts DeleteAccountCredentialsDynamicRegistrationIAT2FACSRFTokenOptions, +) error { + logger := utils.BuildLogger(c.logger, utils.LoggerOptions{ + Location: accountCredentialsDynamicRegistrationLocation, + Method: "DeleteAccountCredentialsDynamicRegistrationIAT2FACSRFToken", + RequestID: opts.RequestID, + }) + logger.DebugContext(ctx, "Deleting account credentials dynamic registration IAT2FA CSRF token...") + return c.storage.DeleteWithContext( + ctx, + buildAccountCredentialsDynamicRegistrationIAT2FACSRFCacheKey(opts.SessionID), + ) } func buildAccountCredentialsDynamicRegistrationIATCodeCacheKey(codeID string) string { @@ -347,3 +486,158 @@ func (c *Cache) VerifyAccountCredentialsRegistrationIATCode( } return codeData, true, nil } + +type AccountCredentialsDynamicRegistrationSessionData struct { + AccountPublicID uuid.UUID `json:"account_public_id"` + AccountVersion int32 `json:"account_version"` + SessionKey string `json:"session_key"` +} + +func buildAccountCredentialsDynamicRegistrationSessionCacheKey(domain string, clientID string) string { + return fmt.Sprintf("%s:session:%s:%s", accountCredentialsDynamicRegistrationIATPrefix, domain, clientID) +} + +func formatSessionKey(clientID, sessionKey string) string { + return fmt.Sprintf("%s.%s", clientID, sessionKey) +} + +func parseSessionKey(sessionKey string) (string, string, bool) { + parts := strings.Split(sessionKey, ".") + if len(parts) != 2 { + return "", "", false + } + return parts[0], parts[1], true +} + +type CreateAccountCredentialsRegistrationSessionKeyOptions struct { + RequestID string + ClientID string + Domain string + AccountPublicID uuid.UUID + AccountVersion int32 +} + +func (c *Cache) CreateAccountCredentialsRegistrationSessionKey( + ctx context.Context, + opts CreateAccountCredentialsRegistrationSessionKeyOptions, +) (string, error) { + logger := utils.BuildLogger(c.logger, utils.LoggerOptions{ + Location: accountCredentialsDynamicRegistrationLocation, + Method: "CreateAccountCredentialsRegistrationSessionKey", + RequestID: opts.RequestID, + }).With( + "clientId", opts.ClientID, + "domain", opts.Domain, + "accountPublicId", opts.AccountPublicID, + ) + logger.DebugContext(ctx, "Creating account credentials registration session key...") + + sessionKey, err := utils.GenerateBase64Secret(sessionKeyByteLen) + if err != nil { + logger.ErrorContext(ctx, "Error generating session key", "error", err) + return "", err + } + + data := AccountCredentialsDynamicRegistrationSessionData{ + AccountPublicID: opts.AccountPublicID, + AccountVersion: opts.AccountVersion, + SessionKey: utils.Sha256HashHex(sessionKey), + } + dataBytes, err := json.Marshal(data) + if err != nil { + logger.ErrorContext(ctx, "Failed to marshal account credentials registration session data", "error", err) + return "", err + } + + if err := c.storage.SetWithContext( + ctx, + buildAccountCredentialsDynamicRegistrationSessionCacheKey(opts.Domain, opts.ClientID), + dataBytes, + c.oauthCodeTTL, + ); err != nil { + logger.ErrorContext(ctx, "Failed to set account credentials registration session in cache", "error", err) + return "", err + } + + return formatSessionKey(opts.ClientID, sessionKey), nil +} + +type VerifyAccountCredentialsRegistrationSessionKeyOptions struct { + RequestID string + Domain string + SessionKey string +} + +func (c *Cache) VerifyAccountCredentialsRegistrationSessionKey( + ctx context.Context, + opts VerifyAccountCredentialsRegistrationSessionKeyOptions, +) (AccountCredentialsDynamicRegistrationSessionData, string, bool, bool, error) { + logger := utils.BuildLogger(c.logger, utils.LoggerOptions{ + Location: accountCredentialsDynamicRegistrationLocation, + Method: "VerifyAccountCredentialsRegistrationSessionKey", + RequestID: opts.RequestID, + }).With( + "domain", opts.Domain, + ) + logger.DebugContext(ctx, "Verifying account credentials registration session key...") + + clientID, sessionKey, ok := parseSessionKey(opts.SessionKey) + if !ok { + logger.DebugContext(ctx, "Invalid account credentials registration session key format") + return AccountCredentialsDynamicRegistrationSessionData{}, "", false, true, nil + } + + data, err := c.storage.GetWithContext(ctx, buildAccountCredentialsDynamicRegistrationSessionCacheKey(opts.Domain, clientID)) + if err != nil { + logger.ErrorContext(ctx, "Failed to get account credentials registration session from cache", "error", err) + return AccountCredentialsDynamicRegistrationSessionData{}, "", false, false, err + } + if data == nil { + logger.DebugContext(ctx, "Account credentials registration session not found") + return AccountCredentialsDynamicRegistrationSessionData{}, "", false, false, nil + } + + var sessionData AccountCredentialsDynamicRegistrationSessionData + if err := json.Unmarshal(data, &sessionData); err != nil { + logger.ErrorContext(ctx, "Failed to unmarshal account credentials registration session data", "error", err) + return AccountCredentialsDynamicRegistrationSessionData{}, "", false, false, err + } + + ok, err = utils.CompareShaHex(sessionKey, sessionData.SessionKey) + if err != nil { + logger.ErrorContext(ctx, "Error comparing session key", "error", err) + return AccountCredentialsDynamicRegistrationSessionData{}, "", false, false, err + } + if !ok { + logger.DebugContext(ctx, "Invalid session key") + return AccountCredentialsDynamicRegistrationSessionData{}, clientID, false, true, nil + } + + return sessionData, clientID, true, true, nil +} + +type DeleteAccountCredentialsRegistrationSessionKeyOptions struct { + RequestID string + Domain string + ClientID string +} + +func (c *Cache) DeleteAccountCredentialsRegistrationSessionKey( + ctx context.Context, + opts DeleteAccountCredentialsRegistrationSessionKeyOptions, +) error { + logger := utils.BuildLogger(c.logger, utils.LoggerOptions{ + Location: accountCredentialsDynamicRegistrationLocation, + Method: "DeleteAccountCredentialsRegistrationSessionKey", + RequestID: opts.RequestID, + }).With( + "domain", opts.Domain, + "clientId", opts.ClientID, + ) + logger.DebugContext(ctx, "Deleting account credentials registration session key...") + + return c.storage.DeleteWithContext( + ctx, + buildAccountCredentialsDynamicRegistrationSessionCacheKey(opts.Domain, opts.ClientID), + ) +} diff --git a/idp/internal/providers/cache/cache.go b/idp/internal/providers/cache/cache.go index cc101b9..25b8df8 100644 --- a/idp/internal/providers/cache/cache.go +++ b/idp/internal/providers/cache/cache.go @@ -32,6 +32,7 @@ type Cache struct { wellKnownTTL time.Duration oauthStateTTL time.Duration oauthCodeTTL time.Duration + oauthCodeSec int64 } func NewCache( @@ -61,9 +62,14 @@ func NewCache( wellKnownTTL: time.Duration(wellKnownTTL) * time.Second, oauthStateTTL: time.Duration(oauthStateTTL) * time.Second, oauthCodeTTL: time.Duration(oauthCodeTTL) * time.Second, + oauthCodeSec: oauthCodeTTL, } } +func (c *Cache) OAuthCodeTTL() int64 { + return c.oauthCodeSec +} + func (c *Cache) ResetCache() error { return c.storage.Reset() } diff --git a/idp/internal/providers/database/account_credentials.sql.go b/idp/internal/providers/database/account_credentials.sql.go index 691b54d..ef4b466 100644 --- a/idp/internal/providers/database/account_credentials.sql.go +++ b/idp/internal/providers/database/account_credentials.sql.go @@ -25,6 +25,24 @@ func (q *Queries) CountAccountCredentialsByAccountPublicID(ctx context.Context, return count, err } +const countAccountCredentialsByAccountPublicIDAndClientID = `-- name: CountAccountCredentialsByAccountPublicIDAndClientID :one +SELECT COUNT(*) FROM "account_credentials" +WHERE "account_public_id" = $1 AND "client_id" = $2 +LIMIT 1 +` + +type CountAccountCredentialsByAccountPublicIDAndClientIDParams struct { + AccountPublicID uuid.UUID + ClientID string +} + +func (q *Queries) CountAccountCredentialsByAccountPublicIDAndClientID(ctx context.Context, arg CountAccountCredentialsByAccountPublicIDAndClientIDParams) (int64, error) { + row := q.db.QueryRow(ctx, countAccountCredentialsByAccountPublicIDAndClientID, arg.AccountPublicID, arg.ClientID) + var count int64 + err := row.Scan(&count) + return count, err +} + const countAccountCredentialsByNameAndAccountID = `-- name: CountAccountCredentialsByNameAndAccountID :one SELECT COUNT(*) FROM "account_credentials" WHERE "account_id" = $1 AND "name" = $2 diff --git a/idp/internal/providers/database/queries/account_credentials.sql b/idp/internal/providers/database/queries/account_credentials.sql index 3e1bfa7..c47e40c 100644 --- a/idp/internal/providers/database/queries/account_credentials.sql +++ b/idp/internal/providers/database/queries/account_credentials.sql @@ -14,6 +14,11 @@ SELECT * FROM "account_credentials" WHERE "account_public_id" = $1 AND "client_id" = $2 LIMIT 1; +-- name: CountAccountCredentialsByAccountPublicIDAndClientID :one +SELECT COUNT(*) FROM "account_credentials" +WHERE "account_public_id" = $1 AND "client_id" = $2 +LIMIT 1; + -- name: CreateAccountCredentials :one INSERT INTO "account_credentials" ( "client_id", diff --git a/idp/internal/server/routes/account_dynamic_registration.go b/idp/internal/server/routes/account_dynamic_registration.go index 915f68c..7bf907b 100644 --- a/idp/internal/server/routes/account_dynamic_registration.go +++ b/idp/internal/server/routes/account_dynamic_registration.go @@ -14,70 +14,69 @@ import ( ) func (r *Routes) AccountDynamicRegistrationConfigurationRoutes(app *fiber.App) { - router := v1PathRouter(app).Group( - paths.AccountsBase+paths.CredentialsBase+paths.DynamicRegistrationBase, - r.controllers.AccountAccessClaimsMiddleware, - ) + router := v1PathRouter(app).Group(paths.AccountsBase + paths.CredentialsBase + paths.DynamicRegistrationBase) credentialsConfigsWriteScopeMiddleware := r.controllers.ScopeMiddleware(tokens.AccountScopeCredentialsConfigsWrite) credentialsConfigsReadScopeMiddleware := r.controllers.ScopeMiddleware(tokens.AccountScopeCredentialsConfigsRead) // Dynamic Registration Config - router.Get( - paths.Config, + configRouter := router.Group(paths.Config, r.controllers.AccountAccessClaimsMiddleware) + configRouter.Get( + paths.Base, credentialsConfigsReadScopeMiddleware, r.controllers.GetAccountDynamicRegistrationConfig, ) - router.Put( - paths.Config, + configRouter.Put( + paths.Base, credentialsConfigsWriteScopeMiddleware, r.controllers.UpsertAccountDynamicRegistrationConfig, ) - router.Delete( - paths.Config, + configRouter.Delete( + paths.Base, credentialsConfigsWriteScopeMiddleware, r.controllers.DeleteAccountDynamicRegistrationConfig, ) // Dynamic Registration Domains - router.Post( - paths.Domains, + domainsRouter := router.Group(paths.Domains, r.controllers.AccountAccessClaimsMiddleware) + domainsRouter.Post( + paths.Base, credentialsConfigsWriteScopeMiddleware, r.controllers.CreateAccountCredentialsRegistrationDomain, ) - router.Get( - paths.Domains, + domainsRouter.Get( + paths.Base, credentialsConfigsReadScopeMiddleware, r.controllers.ListAccountCredentialsRegistrationDomains, ) - router.Get( - paths.Domains+paths.SingleDomain, + domainsRouter.Get( + paths.SingleDomain, credentialsConfigsReadScopeMiddleware, r.controllers.GetAccountCredentialsRegistrationDomain, ) - router.Delete( - paths.Domains+paths.SingleDomain, + domainsRouter.Delete( + paths.SingleDomain, credentialsConfigsWriteScopeMiddleware, r.controllers.DeleteAccountCredentialsRegistrationDomain, ) - router.Post( - paths.Domains+paths.VerifyDomain, + domainsRouter.Post( + paths.VerifyDomain, credentialsConfigsWriteScopeMiddleware, r.controllers.VerifyAccountCredentialsRegistrationDomain, ) // Dynamic Registration Domains Code - router.Get( - paths.Domains+paths.DomainCode, + domainsRouter.Get( + paths.DomainCode, credentialsConfigsReadScopeMiddleware, r.controllers.GetAccountCredentialsRegistrationDomainCode, ) - router.Put( - paths.Domains+paths.DomainCode, + domainsRouter.Put( + paths.DomainCode, credentialsConfigsWriteScopeMiddleware, r.controllers.UpsertAccountCredentialsRegistrationDomainCode, ) - router.Delete( - paths.Domains+paths.DomainCode, + domainsRouter.Delete( + paths.DomainCode, credentialsConfigsWriteScopeMiddleware, r.controllers.DeleteAccountCredentialsRegistrationDomainCode, ) diff --git a/idp/internal/server/routes/common.go b/idp/internal/server/routes/common.go index e83dacb..1b97e77 100644 --- a/idp/internal/server/routes/common.go +++ b/idp/internal/server/routes/common.go @@ -6,10 +6,12 @@ package routes -import "github.com/gofiber/fiber/v2" +import ( + "github.com/gofiber/fiber/v2" -const V1Path string = "/v1" + "github.com/tugascript/devlogs/idp/internal/controllers/paths" +) func v1PathRouter(app *fiber.App) fiber.Router { - return app.Group(V1Path) + return app.Group(paths.V1) } diff --git a/idp/internal/services/accounts.go b/idp/internal/services/accounts.go index 9b1dc9a..1e9d148 100644 --- a/idp/internal/services/accounts.go +++ b/idp/internal/services/accounts.go @@ -488,13 +488,7 @@ func (s *Services) ConfirmUpdateAccountEmail( return dtos.AuthDTO{}, serviceErr } - if serviceErr := s.verifyAccountTwoFactor( - ctx, - logger, - opts.RequestID, - &accountDTO, - opts.Code, - ); serviceErr != nil { + if serviceErr := s.verifyAccountTwoFactor(ctx, opts.RequestID, &accountDTO, opts.Code); serviceErr != nil { return dtos.AuthDTO{}, serviceErr } @@ -685,13 +679,7 @@ func (s *Services) ConfirmUpdateAccountPassword( return dtos.AuthDTO{}, exceptions.NewUnauthorizedError() } - if serviceErr := s.verifyAccountTwoFactor( - ctx, - logger, - opts.RequestID, - &accountDTO, - opts.Code, - ); serviceErr != nil { + if serviceErr := s.verifyAccountTwoFactor(ctx, opts.RequestID, &accountDTO, opts.Code); serviceErr != nil { return dtos.AuthDTO{}, serviceErr } @@ -1010,13 +998,7 @@ func (s *Services) ConfirmUpdateAccountUsername( return dtos.AuthDTO{}, serviceErr } - if serviceErr := s.verifyAccountTwoFactor( - ctx, - logger, - opts.RequestID, - &accountDTO, - opts.Code, - ); serviceErr != nil { + if serviceErr := s.verifyAccountTwoFactor(ctx, opts.RequestID, &accountDTO, opts.Code); serviceErr != nil { return dtos.AuthDTO{}, serviceErr } @@ -1170,13 +1152,7 @@ func (s *Services) ConfirmDeleteAccount( return serviceErr } - if serviceErr := s.verifyAccountTwoFactor( - ctx, - logger, - opts.RequestID, - &accountDTO, - opts.Code, - ); serviceErr != nil { + if serviceErr := s.verifyAccountTwoFactor(ctx, opts.RequestID, &accountDTO, opts.Code); serviceErr != nil { return serviceErr } diff --git a/idp/internal/services/auth.go b/idp/internal/services/auth.go index 9c9b39e..d2072b8 100644 --- a/idp/internal/services/auth.go +++ b/idp/internal/services/auth.go @@ -120,6 +120,14 @@ func (s *Services) GetRefreshTTL() int64 { return s.jwt.GetRefreshTTL() } +func (s *Services) Get2FATTL() int64 { + return s.jwt.Get2FATTL() +} + +func (s *Services) GetOAuthCodeTTL() int64 { + return s.cache.OAuthCodeTTL() +} + func (s *Services) sendConfirmationEmail( ctx context.Context, logger *slog.Logger, @@ -577,11 +585,16 @@ func (s *Services) VerifyAccountTotp( func (s *Services) verifyAccountTwoFactor( ctx context.Context, - logger *slog.Logger, requestID string, accountDTO *dtos.AccountDTO, code string, ) *exceptions.ServiceError { + logger := s.buildLogger(requestID, authLocation, "verifyAccountTwoFactor").With( + "accountPublicId", accountDTO.PublicID, + "twoFactorType", accountDTO.TwoFactorType, + ) + logger.InfoContext(ctx, "Verifying account two factor...") + switch accountDTO.TwoFactorType { case database.TwoFactorTypeNone: logger.WarnContext(ctx, "User has two factor inactive") @@ -656,13 +669,7 @@ func (s *Services) TwoFactorLoginAccount( return dtos.AuthDTO{}, exceptions.NewUnauthorizedError() } - if serviceErr := s.verifyAccountTwoFactor( - ctx, - logger, - opts.RequestID, - &accountDTO, - opts.Code, - ); serviceErr != nil { + if serviceErr := s.verifyAccountTwoFactor(ctx, opts.RequestID, &accountDTO, opts.Code); serviceErr != nil { return dtos.AuthDTO{}, serviceErr } @@ -780,29 +787,16 @@ func (s *Services) RefreshTokenAccount( return dtos.AuthDTO{}, exceptions.NewUnauthorizedError() } - accountDTO, serviceErr := s.GetAccountByPublicID(ctx, GetAccountByPublicIDOptions{ + accountDTO, serviceErr := s.GetAccountByPublicIDAndVersion(ctx, GetAccountByPublicIDAndVersionOptions{ RequestID: opts.RequestID, PublicID: data.AccountClaims.AccountID, + Version: data.AccountClaims.AccountVersion, }) if serviceErr != nil { - if serviceErr.Code != exceptions.CodeNotFound { - logger.WarnContext(ctx, "Account not found", "error", serviceErr) - return dtos.AuthDTO{}, exceptions.NewUnauthorizedError() - } - - logger.ErrorContext(ctx, "Failed to get account", "error", serviceErr) + logger.WarnContext(ctx, "Failed to get account by public ID and version", "serviceErr", serviceErr) return dtos.AuthDTO{}, serviceErr } - accountVersion := accountDTO.Version() - if accountVersion != data.AccountClaims.AccountVersion { - logger.WarnContext(ctx, "Account versions do not match", - "claimsVersion", data.AccountClaims.AccountVersion, - "accountVersion", accountVersion, - ) - return dtos.AuthDTO{}, exceptions.NewUnauthorizedError() - } - if err := s.database.RevokeToken(ctx, database.RevokeTokenParams{ TokenID: data.TokenID, AccountID: accountDTO.ID(), @@ -1619,13 +1613,7 @@ func (s *Services) ConfirmUpdateAccount2FAUpdate( return dtos.AuthDTO{}, serviceErr } - if serviceErr := s.verifyAccountTwoFactor( - ctx, - logger, - opts.RequestID, - &accountDTO, - opts.Code, - ); serviceErr != nil { + if serviceErr := s.verifyAccountTwoFactor(ctx, opts.RequestID, &accountDTO, opts.Code); serviceErr != nil { return dtos.AuthDTO{}, serviceErr } diff --git a/idp/internal/services/oauth_dynamic_registration.go b/idp/internal/services/oauth_dynamic_registration.go index 42c0648..378e23d 100644 --- a/idp/internal/services/oauth_dynamic_registration.go +++ b/idp/internal/services/oauth_dynamic_registration.go @@ -9,40 +9,382 @@ package services import ( "context" "net/url" + "slices" "github.com/google/uuid" - "github.com/tugascript/devlogs/idp/internal/controllers/paths" + "github.com/tugascript/devlogs/idp/internal/controllers/paths" "github.com/tugascript/devlogs/idp/internal/exceptions" "github.com/tugascript/devlogs/idp/internal/providers/cache" "github.com/tugascript/devlogs/idp/internal/providers/database" "github.com/tugascript/devlogs/idp/internal/providers/mailer" + "github.com/tugascript/devlogs/idp/internal/providers/tokens" "github.com/tugascript/devlogs/idp/internal/services/templates" "github.com/tugascript/devlogs/idp/internal/utils" ) const oauthDynamicRegistrationLocation string = "oauth_dynamic_registration" -type OAuthDynamicRegistrationIATAuthOptions struct { +const ( + oauthDynamicRegistrationIATPath string = paths.V1 + paths.AccountsBase + paths.CredentialsBase + paths.DynamicRegistrationBase + paths.InitialAccessToken + oauthDynamicRegistrationIATAuthPath string = oauthDynamicRegistrationIATPath + paths.OAuthAuth +) + +type buildOAuthDynamicRegistrationIATLoginURLOptions struct { + domain string + state string + challenge string + challengeMethod string + redirectURI string +} + +func buildOAuthDynamicRegistrationIATLoginURL(opts buildOAuthDynamicRegistrationIATLoginURLOptions) string { + queryParams := make(url.Values) + queryParams.Add("client_id", opts.domain) + queryParams.Add("response_type", "code") + queryParams.Add("state", opts.state) + queryParams.Add("code_challenge", opts.challenge) + if opts.challengeMethod != "" { + queryParams.Add("code_challenge_method", opts.challengeMethod) + } + return oauthDynamicRegistrationIATAuthPath + paths.AuthLogin + "?" + queryParams.Encode() +} + +type buildOAuthDynamicRegistrationIATCallbackURLOptions struct { + redirectURI string + code string + state string + backendDomain string +} + +func buildOAuthDynamicRegistrationIATCallbackURL(opts buildOAuthDynamicRegistrationIATCallbackURLOptions) string { + queryParams := make(url.Values) + queryParams.Add("code", opts.code) + queryParams.Add("state", opts.state) + queryParams.Add("iss", "https://"+opts.backendDomain) + return opts.redirectURI + "?" + queryParams.Encode() +} + +type generateOAuthDynamicRegistrationIATCallbackOptions struct { + requestID string + clientID string + accountPublicID uuid.UUID + accountVersion int32 + challenge string + challengeMethod string + domain string + redirectURI string + state string + backendDomain string +} + +func (s *Services) generateOAuthDynamicRegistrationIATCallback( + ctx context.Context, + opts generateOAuthDynamicRegistrationIATCallbackOptions, +) (string, *exceptions.ServiceError) { + logger := s.buildLogger( + opts.requestID, + oauthDynamicRegistrationLocation, + "generateOAuthDynamicRegistrationIATCallback", + ).With( + "clientId", opts.clientID, + "accountPublicId", opts.accountPublicID, + ) + logger.InfoContext(ctx, "Generating OAuth dynamic registration IAT callback...") + + hashedChallenge, serviceErr := hashChallenge(opts.challenge, opts.challengeMethod) + if serviceErr != nil { + logger.ErrorContext(ctx, "Invalid code challenge", "serviceError", serviceErr) + return "", serviceErr + } + + code, err := s.cache.GenerateAccountCredentialsRegistrationIATCode( + ctx, + cache.GenerateAccountCredentialsRegistrationIATCodeOptions{ + RequestID: opts.requestID, + ClientID: opts.clientID, + AccountPublicID: opts.accountPublicID, + AccountVersion: opts.accountVersion, + Challenge: hashedChallenge, + Domain: opts.domain, + }, + ) + if err != nil { + logger.ErrorContext(ctx, "Failed to generate account credentials registration IAT code", "error", err) + return "", exceptions.NewInternalServerError() + } + + return buildOAuthDynamicRegistrationIATCallbackURL(buildOAuthDynamicRegistrationIATCallbackURLOptions{ + redirectURI: opts.redirectURI, + code: code, + state: opts.state, + backendDomain: opts.backendDomain, + }), nil +} + +type refreshTokenOAuthDynamicRegistrationIATLoginOptions struct { + requestID string + refreshToken string + challenge string + challengeMethod string + domain string + redirectURI string + state string + backendDomain string +} + +func (s *Services) refreshTokenOAuthDynamicRegistrationIATLogin( + ctx context.Context, + opts refreshTokenOAuthDynamicRegistrationIATLoginOptions, +) (string, *exceptions.ServiceError) { + logger := s.buildLogger( + opts.requestID, + oauthDynamicRegistrationLocation, + "refreshTokenOAuthDynamicRegistrationIATLogin", + ).With( + "redirectUri", opts.redirectURI, + ) + logger.InfoContext(ctx, "Refreshing OAuth dynamic registration IAT callback...") + + data, err := s.jwt.VerifyRefreshToken( + opts.refreshToken, + s.BuildGetGlobalPublicKeyFn(ctx, BuildGetGlobalVerifyKeyFnOptions{ + RequestID: opts.requestID, + KeyType: database.TokenKeyTypeRefresh, + }), + ) + if err != nil { + logger.WarnContext(ctx, "Invalid refresh token", "error", err) + return buildOAuthDynamicRegistrationIATLoginURL(buildOAuthDynamicRegistrationIATLoginURLOptions{ + domain: opts.domain, + state: opts.state, + challenge: opts.challenge, + challengeMethod: opts.challengeMethod, + redirectURI: opts.redirectURI, + }), nil + } + + if !slices.ContainsFunc(data.Scopes, func(s string) bool { + return s == tokens.AccountScopeAdmin || s == tokens.AccountScopeCredentialsConfigsWrite + }) { + logger.WarnContext(ctx, "Refresh token missing offline_access scope") + return buildOAuthDynamicRegistrationIATLoginURL(buildOAuthDynamicRegistrationIATLoginURLOptions{ + domain: opts.domain, + state: opts.state, + challenge: opts.challenge, + challengeMethod: opts.challengeMethod, + redirectURI: opts.redirectURI, + }), nil + } + + blt, err := s.database.GetRevokedToken(ctx, data.TokenID) + if err != nil { + if exceptions.FromDBError(err).Code != exceptions.CodeNotFound { + logger.ErrorContext(ctx, "Failed to get blacklisted token", "error", err) + return "", exceptions.NewInternalServerError() + } + } else { + logger.WarnContext(ctx, "Token is revoked", "revokedAt", blt.CreatedAt) + return buildOAuthDynamicRegistrationIATLoginURL(buildOAuthDynamicRegistrationIATLoginURLOptions{ + domain: opts.domain, + state: opts.state, + challenge: opts.challenge, + challengeMethod: opts.challengeMethod, + redirectURI: opts.redirectURI, + }), nil + } + + accountDTO, serviceErr := s.GetAccountByPublicIDAndVersion(ctx, GetAccountByPublicIDAndVersionOptions{ + RequestID: opts.requestID, + PublicID: data.AccountClaims.AccountID, + Version: data.AccountClaims.AccountVersion, + }) + if serviceErr != nil { + if serviceErr.Code != exceptions.CodeNotFound && serviceErr.Code != exceptions.CodeUnauthorized { + logger.ErrorContext(ctx, "Failed to get account by public ID and version", "serviceError", serviceErr) + return "", serviceErr + } + + logger.WarnContext(ctx, "Account not found or version mismatch", "serviceError", serviceErr) + return buildOAuthDynamicRegistrationIATLoginURL(buildOAuthDynamicRegistrationIATLoginURLOptions{ + domain: opts.domain, + state: opts.state, + challenge: opts.challenge, + challengeMethod: opts.challengeMethod, + redirectURI: opts.redirectURI, + }), nil + } + + cbURL, serviceErr := s.generateOAuthDynamicRegistrationIATCallback( + ctx, + generateOAuthDynamicRegistrationIATCallbackOptions{ + requestID: opts.requestID, + clientID: utils.Base62UUID(), + accountPublicID: accountDTO.PublicID, + accountVersion: accountDTO.Version(), + challenge: opts.challenge, + challengeMethod: opts.challengeMethod, + domain: opts.domain, + redirectURI: opts.redirectURI, + state: opts.state, + backendDomain: opts.backendDomain, + }, + ) + if serviceErr != nil { + logger.ErrorContext(ctx, "Failed to generate OAuth dynamic registration IAT callback", "serviceErr", serviceErr) + return "", serviceErr + } + + return cbURL, nil +} + +type InitiateOAuthDynamicRegistrationIATAuthOptions struct { + RequestID string + Domain string + State string + SessionKey string + RefreshToken string + Challenge string + ChallengeMethod string + RedirectURI string + BackendDomain string +} + +func (s *Services) InitiateOAuthDynamicRegistrationIATAuth( + ctx context.Context, + opts InitiateOAuthDynamicRegistrationIATAuthOptions, +) (string, *exceptions.ServiceError) { + logger := s.buildLogger( + opts.RequestID, + oauthDynamicRegistrationLocation, + "InitiateOAuthDynamicRegistrationIATAuth", + ).With( + "redirectUri", opts.RedirectURI, + ) + logger.InfoContext(ctx, "Starting OAuth dynamic registration IAT authorization...") + + if opts.SessionKey == "" { + if opts.RefreshToken == "" { + logger.InfoContext(ctx, "No session key or refresh token provided, redirecting to login") + return buildOAuthDynamicRegistrationIATLoginURL(buildOAuthDynamicRegistrationIATLoginURLOptions{ + domain: opts.Domain, + state: opts.State, + challenge: opts.Challenge, + challengeMethod: opts.ChallengeMethod, + redirectURI: opts.RedirectURI, + }), nil + } + + logger.InfoContext(ctx, "No session key provided, attempting to refresh with refresh token") + return s.refreshTokenOAuthDynamicRegistrationIATLogin( + ctx, + refreshTokenOAuthDynamicRegistrationIATLoginOptions{ + requestID: opts.RequestID, + refreshToken: opts.RefreshToken, + challenge: opts.Challenge, + challengeMethod: opts.ChallengeMethod, + domain: opts.Domain, + redirectURI: opts.RedirectURI, + state: opts.State, + backendDomain: opts.BackendDomain, + }, + ) + } + + data, credsClientID, verified, found, err := s.cache.VerifyAccountCredentialsRegistrationSessionKey( + ctx, + cache.VerifyAccountCredentialsRegistrationSessionKeyOptions{ + RequestID: opts.RequestID, + SessionKey: opts.SessionKey, + Domain: opts.Domain, + }, + ) + if err != nil { + logger.ErrorContext(ctx, "Failed to verify account credentials registration session key", "error", err) + return "", exceptions.NewInternalServerError() + } + if !found { + logger.InfoContext(ctx, "Account credentials registration session key not found") + return buildOAuthDynamicRegistrationIATLoginURL(buildOAuthDynamicRegistrationIATLoginURLOptions{ + domain: opts.Domain, + state: opts.State, + challenge: opts.Challenge, + challengeMethod: opts.ChallengeMethod, + redirectURI: opts.RedirectURI, + }), nil + } + + if !verified { + logger.WarnContext(ctx, "Account credentials registration session key is not verified") + return "", exceptions.NewUnauthorizedError() + } + + if err := s.cache.DeleteAccountCredentialsRegistrationSessionKey( + ctx, + cache.DeleteAccountCredentialsRegistrationSessionKeyOptions{ + RequestID: opts.RequestID, + ClientID: credsClientID, + }, + ); err != nil { + logger.ErrorContext(ctx, "Failed to delete account credentials registration session key", "error", err) + return "", exceptions.NewInternalServerError() + } + + accountDTO, serviceErr := s.GetAccountByPublicIDAndVersion(ctx, GetAccountByPublicIDAndVersionOptions{ + RequestID: opts.RequestID, + PublicID: data.AccountPublicID, + Version: data.AccountVersion, + }) + if serviceErr != nil { + logger.WarnContext(ctx, "Failed to get account by public ID and version", "serviceError", serviceErr) + return "", serviceErr + } + + logger.InfoContext(ctx, "Successfully verified account credentials registration IAT session key, creating code...") + cbURL, serviceErr := s.generateOAuthDynamicRegistrationIATCallback( + ctx, + generateOAuthDynamicRegistrationIATCallbackOptions{ + requestID: opts.RequestID, + clientID: credsClientID, + accountPublicID: accountDTO.PublicID, + accountVersion: accountDTO.Version(), + challenge: opts.Challenge, + challengeMethod: opts.ChallengeMethod, + domain: opts.Domain, + redirectURI: opts.RedirectURI, + state: opts.State, + backendDomain: opts.BackendDomain, + }, + ) + if serviceErr != nil { + logger.ErrorContext(ctx, "Failed to generate OAuth dynamic registration IAT callback", "serviceErr", serviceErr) + return "", serviceErr + } + + return cbURL, nil +} + +type OAuthDynamicRegistrationIATAuthRenderOptions struct { RequestID string State string + Domain string CodeChallenge string CodeChallengeMethod string RedirectURI string } -func (s *Services) OAuthDynamicRegistrationIATAuth( +func (s *Services) OAuthDynamicRegistrationIATAuthRender( ctx context.Context, - opts OAuthDynamicRegistrationIATAuthOptions, + opts OAuthDynamicRegistrationIATAuthRenderOptions, ) (string, *exceptions.ServiceError) { logger := s.buildLogger( opts.RequestID, oauthDynamicRegistrationLocation, - "OAuthDynamicRegistrationIATAuth", + "OAuthDynamicRegistrationIATAuthRender", ).With( "redirectUri", opts.RedirectURI, ) - logger.InfoContext(ctx, "Starting OAuth dynamic registration IAT authorization...") + logger.InfoContext(ctx, "Starting OAuth dynamic registration IAT authorization html render...") hashedChallenge, serviceErr := hashChallenge(opts.CodeChallenge, opts.CodeChallengeMethod) if serviceErr != nil { @@ -50,13 +392,14 @@ func (s *Services) OAuthDynamicRegistrationIATAuth( return "", serviceErr } - clientID, csrfToken, err := s.cache.SaveAccountCredentialsDynamicRegistrationIATAuth( + clientID, csrfToken, err := s.cache.SaveAccountCredentialsDynamicRegistrationIATLogin( ctx, - cache.SaveAccountCredentialsDynamicRegistrationIATAuthOptions{ + cache.SaveAccountCredentialsDynamicRegistrationIATLoginOptions{ RequestID: opts.RequestID, Challenge: hashedChallenge, State: opts.State, RedirectURI: opts.RedirectURI, + Domain: opts.Domain, }, ) if err != nil { @@ -66,8 +409,9 @@ func (s *Services) OAuthDynamicRegistrationIATAuth( loginHTML, err := templates.BuildAccountDynamicRegistrationIATAuthTemplate( templates.AccountDynamicRegistrationIATAuthOptions{ - ClientID: clientID, + ACCClientID: clientID, CSRFToken: csrfToken, + Domain: opts.Domain, State: opts.State, CodeChallenge: opts.CodeChallenge, CodeChallengeMethod: opts.CodeChallengeMethod, @@ -87,64 +431,10 @@ func (s *Services) OAuthDynamicRegistrationIATAuth( return loginHTML, nil } -type createAccountCredentialsRegistrationIATCodeOptions struct { - requestID string - clientID string - accountPublicID uuid.UUID - accountVersion int32 - challenge string - domain string -} - -func (s *Services) createAccountCredentialsRegistrationIATCode( - ctx context.Context, - opts createAccountCredentialsRegistrationIATCodeOptions, -) (string, *exceptions.ServiceError) { - logger := s.buildLogger(opts.requestID, accountCredentialsRegistrationIATLocation, "createAccountCredentialsRegistrationIATCode").With( - "clientId", opts.clientID, - "accountPublicId", opts.accountPublicID, - ) - logger.InfoContext(ctx, "Creating account credentials registration IAT code...") - - count, err := s.database.CountAppsByClientIDAndAccountPublicID( - ctx, - database.CountAppsByClientIDAndAccountPublicIDParams{ - ClientID: opts.clientID, - AccountPublicID: opts.accountPublicID, - }, - ) - if err != nil { - logger.ErrorContext(ctx, "Failed to count apps by client ID and account public ID", "error", err) - return "", exceptions.FromDBError(err) - } - if count > 0 { - logger.WarnContext(ctx, "App with the same client ID already exists for this account") - return "", exceptions.NewUnauthorizedError() - } - - code, err := s.cache.GenerateAccountCredentialsRegistrationIATCode( - ctx, - cache.GenerateAccountCredentialsRegistrationIATCodeOptions{ - RequestID: opts.requestID, - ClientID: opts.clientID, - AccountPublicID: opts.accountPublicID, - AccountVersion: opts.accountVersion, - Challenge: opts.challenge, - Domain: opts.domain, - }, - ) - if err != nil { - logger.ErrorContext(ctx, "Failed to generate account credentials registration IAT code", "error", err) - return "", exceptions.NewInternalServerError() - } - - logger.InfoContext(ctx, "Created account credentials registration IAT code successfully") - return code, nil -} - type OAuthDynamicRegistrationIATLoginOptions struct { RequestID string - ClientID string + ACCClientID string + Domain string CSRFToken string CodeChallenge string CodeChallengeMethod string @@ -158,43 +448,37 @@ type OAuthDynamicRegistrationIATLoginOptions struct { func (s *Services) OAuthDynamicRegistrationIATLogin( ctx context.Context, opts OAuthDynamicRegistrationIATLoginOptions, -) (string, string, *exceptions.ServiceError) { - logger := s.buildLogger(opts.RequestID, oauthDynamicRegistrationLocation, "OAuthDynamicRegistrationIATLogin").With( - "clientId", opts.ClientID, - "email", opts.Email, +) (string, string, bool, *exceptions.ServiceError) { + logger := s.buildLogger(opts.RequestID, oauthDynamicRegistrationLocation, "OAuthDynamicRegistrationIATLoginPost").With( + "clientId", opts.ACCClientID, + "domain", opts.Domain, ) logger.InfoContext(ctx, "Logging in with OAuth dynamic registration IAT...") data, found, err := s.cache.GetAccountCredentialsDynamicRegistrationAuthIAT(ctx, cache.GetAccountCredentialsDynamicRegistrationIATAuthOptions{ RequestID: opts.RequestID, - ClientID: opts.ClientID, + ClientID: opts.ACCClientID, }) if err != nil { logger.ErrorContext(ctx, "Failed to get account credentials dynamic registration IAT", "error", err) - return "", "", exceptions.NewInternalServerError() + return "", "", false, exceptions.NewInternalServerError() } if !found { logger.ErrorContext(ctx, "Account credentials dynamic registration IAT not found") - return "", "", exceptions.NewNotFoundValidationError("invalid client ID") + return "", "", false, exceptions.NewNotFoundValidationError("invalid client ID") } - hashedChallenge, err := hashChallenge(opts.CodeChallenge, opts.CodeChallengeMethod) - if err != nil { - logger.ErrorContext(ctx, "Invalid code challenge", "error", err) - return "", "", exceptions.NewInternalServerError() - } - // Note this is not the verifier so standard comparison is ok - if hashedChallenge != data.Challenge { - logger.WarnContext(ctx, "OAuth Code challenge verification failed") - return "", "", exceptions.NewUnauthorizedError() + if data.Domain != opts.Domain { + logger.WarnContext(ctx, "OAuth Domain does not match", "dataDomain", data.Domain) + return "", "", false, exceptions.NewUnauthorizedError() } if data.State != opts.State { logger.WarnContext(ctx, "OAuth State does not match") - return "", "", exceptions.NewUnauthorizedError() + return "", "", false, exceptions.NewUnauthorizedError() } if data.RedirectURI != opts.RedirectURI { logger.WarnContext(ctx, "OAuth Redirect URI does not match") - return "", "", exceptions.NewUnauthorizedError() + return "", "", false, exceptions.NewUnauthorizedError() } accountDTO, serviceErr := s.GetAccountByEmail(ctx, GetAccountByEmailOptions{ @@ -203,11 +487,11 @@ func (s *Services) OAuthDynamicRegistrationIATLogin( }) if serviceErr != nil { if serviceErr.Code != exceptions.CodeNotFound { - return "", "", serviceErr + return "", "", false, serviceErr } logger.WarnContext(ctx, "Account was not found", "error", serviceErr) - return "", "", exceptions.NewUnauthorizedError() + return "", "", false, exceptions.NewUnauthorizedError() } if _, err := s.database.FindAccountAuthProviderByAccountPublicIdAndProvider( ctx, @@ -219,25 +503,25 @@ func (s *Services) OAuthDynamicRegistrationIATLogin( serviceErr := exceptions.FromDBError(err) if serviceErr.Code != exceptions.CodeNotFound { logger.ErrorContext(ctx, "Failed to find account auth provider", "error", err) - return "", "", serviceErr + return "", "", false, serviceErr } logger.WarnContext(ctx, "Account auth provider not found", "error", err) - return "", "", exceptions.NewUnauthorizedError() + return "", "", false, exceptions.NewUnauthorizedError() } passwordVerified, err := utils.Argon2CompareHash(opts.Password, accountDTO.Password()) if err != nil { logger.ErrorContext(ctx, "Failed to verify password", "error", err) - return "", "", exceptions.NewInternalServerError() + return "", "", false, exceptions.NewInternalServerError() } if !passwordVerified { logger.WarnContext(ctx, "Passwords do not match") - return "", "", exceptions.NewUnauthorizedError() + return "", "", false, exceptions.NewUnauthorizedError() } if !accountDTO.EmailVerified() { logger.InfoContext(ctx, "Account is not confirmed") - return "", "", exceptions.NewForbiddenError() + return "", "", false, exceptions.NewForbiddenError() } if accountDTO.TwoFactorType != database.TwoFactorTypeNone { @@ -250,7 +534,7 @@ func (s *Services) OAuthDynamicRegistrationIATLogin( AccountVersion: accountDTO.Version(), RedirectURI: opts.RedirectURI, Domain: data.Domain, - ClientID: opts.ClientID, + ClientID: opts.ACCClientID, Challenge: data.Challenge, State: data.State, TwoFATTL: s.jwt.Get2FATTL(), @@ -258,7 +542,7 @@ func (s *Services) OAuthDynamicRegistrationIATLogin( ) if err != nil { logger.ErrorContext(ctx, "Failed to save account credentials dynamic registration IAT 2FA", "error", err) - return "", "", exceptions.NewInternalServerError() + return "", "", false, exceptions.NewInternalServerError() } if accountDTO.TwoFactorType == database.TwoFactorTypeEmail { @@ -269,7 +553,7 @@ func (s *Services) OAuthDynamicRegistrationIATLogin( }) if err != nil { logger.ErrorContext(ctx, "Failed to add two factor code", "error", err) - return "", "", exceptions.NewInternalServerError() + return "", "", false, exceptions.NewInternalServerError() } if err := s.mail.Publish2FAEmail(ctx, mailer.TwoFactorEmailOptions{ @@ -279,7 +563,7 @@ func (s *Services) OAuthDynamicRegistrationIATLogin( Code: code, }); err != nil { logger.ErrorContext(ctx, "Failed to send two factor code email", "error", err) - return "", "", exceptions.NewInternalServerError() + return "", "", false, exceptions.NewInternalServerError() } logger.InfoContext(ctx, "Sent two factor code email successfully") @@ -289,14 +573,23 @@ func (s *Services) OAuthDynamicRegistrationIATLogin( ctx, cache.DeleteAccountCredentialsDynamicRegistrationIATAuthOptions{ RequestID: opts.RequestID, - ClientID: opts.ClientID, + ClientID: opts.ACCClientID, }, ); err != nil { logger.ErrorContext(ctx, "Failed to delete account credentials dynamic registration IAT auth", "error", err) - return "", "", exceptions.NewInternalServerError() + return "", "", false, exceptions.NewInternalServerError() } - return paths.AccountsBase + paths.CredentialsBase + "/" + opts.ClientID + paths.InitialAccessToken + paths.AuthLogin + paths.Auth2FA, sessionID, nil + queryParams := make(url.Values) + queryParams.Add("client_id", data.Domain) + queryParams.Add("redirect_uri", opts.RedirectURI) + queryParams.Add("state", data.State) + queryParams.Add("code_challenge", opts.CodeChallenge) + if opts.CodeChallengeMethod != "" { + queryParams.Add("code_challenge_method", opts.CodeChallengeMethod) + } + return oauthDynamicRegistrationIATPath + "/" + opts.ACCClientID + + paths.AuthLogin + paths.Auth2FA + queryParams.Encode(), sessionID, false, nil } domainDTO, serviceErr := s.GetAccountCredentialsRegistrationDomain(ctx, GetAccountCredentialsRegistrationDomainOptions{ @@ -307,83 +600,237 @@ func (s *Services) OAuthDynamicRegistrationIATLogin( if serviceErr != nil { if serviceErr.Code != exceptions.CodeNotFound { logger.ErrorContext(ctx, "Failed to get account credentials registration domain", "serviceError", serviceErr) - return "", "", serviceErr + return "", "", false, serviceErr } if err := s.cache.DeleteAccountCredentialsDynamicRegistrationIATAuth( ctx, cache.DeleteAccountCredentialsDynamicRegistrationIATAuthOptions{ RequestID: opts.RequestID, - ClientID: opts.ClientID, + ClientID: opts.ACCClientID, }, ); err != nil { logger.ErrorContext(ctx, "Failed to delete account credentials dynamic registration IAT auth", "error", err) - return "", "", exceptions.NewInternalServerError() + return "", "", false, exceptions.NewInternalServerError() } logger.WarnContext(ctx, "Account credentials registration domain not found") - return "", "", exceptions.NewForbiddenError() + return "", "", false, exceptions.NewForbiddenError() } if !domainDTO.Verified { logger.ErrorContext(ctx, "Account credentials registration domain is not verified") - return "", "", exceptions.NewForbiddenError() + return "", "", false, exceptions.NewForbiddenError() } - code, serviceErr := s.createAccountCredentialsRegistrationIATCode( + sessionKey, err := s.cache.CreateAccountCredentialsRegistrationSessionKey( ctx, - createAccountCredentialsRegistrationIATCodeOptions{ - requestID: opts.RequestID, - clientID: opts.ClientID, - accountPublicID: accountDTO.PublicID, - accountVersion: accountDTO.Version(), - challenge: data.Challenge, - domain: data.Domain, + cache.CreateAccountCredentialsRegistrationSessionKeyOptions{ + RequestID: opts.RequestID, + ClientID: opts.ACCClientID, + Domain: opts.Domain, + AccountPublicID: accountDTO.PublicID, + AccountVersion: accountDTO.Version(), }, ) - if serviceErr != nil { - logger.ErrorContext(ctx, "Failed to create account credentials registration IAT code", "serviceError", serviceErr) - return "", "", serviceErr + if err != nil { + logger.ErrorContext(ctx, "Failed to create account credentials registration session", "error", err) + return "", "", false, exceptions.NewInternalServerError() } - if err := s.cache.DeleteAccountCredentialsDynamicRegistrationIATAuth( + + return oauthDynamicRegistrationIATAuthPath, sessionKey, true, nil +} + +type OAuthDynamicRegistrationIAT2FAOptions struct { + RequestID string + Domain string + ACCClientID string + SessionID string + CodeChallenge string + CodeChallengeMethod string +} + +func (s *Services) OAuthDynamicRegistrationIAT2FA( + ctx context.Context, + opts OAuthDynamicRegistrationIAT2FAOptions, +) (string, *exceptions.ServiceError) { + logger := s.buildLogger(opts.RequestID, oauthDynamicRegistrationLocation, "OAuthDynamicRegistrationIAT2FA").With( + "clientId", opts.ACCClientID, + ) + logger.InfoContext(ctx, "Handling OAuth dynamic registration IAT 2FA...") + + data, found, err := s.cache.GetAccountCredentialsDynamicRegistrationIAT2FA(ctx, cache.GetAccountCredentialsDynamicRegistrationIAT2FAOptions{ + RequestID: opts.RequestID, + SessionID: opts.SessionID, + }) + if err != nil { + logger.ErrorContext(ctx, "Failed to get account credentials dynamic registration IAT 2FA", "error", err) + return "", exceptions.NewInternalServerError() + } + if !found { + logger.WarnContext(ctx, "Failed to get account credentials dynamic registration IAT 2FA") + return "", exceptions.NewUnauthorizedError() + } + if opts.ACCClientID != data.ClientID { + logger.WarnContext(ctx, "Client IDs do not match", "sessionClientId", data.ClientID) + return "", exceptions.NewUnauthorizedError() + } + + csrfToken, err := s.cache.SaveAccountCredentialsDynamicRegistrationIAT2FACSRFToken( ctx, - cache.DeleteAccountCredentialsDynamicRegistrationIATAuthOptions{ + cache.SaveAccountCredentialsDynamicRegistrationIAT2FACSRFTokenOptions{ RequestID: opts.RequestID, - ClientID: opts.ClientID, + SessionID: opts.SessionID, + TwoFATTL: s.jwt.Get2FATTL(), }, - ); err != nil { - logger.ErrorContext(ctx, "Failed to delete account credentials dynamic registration IAT auth", "error", err) - return "", "", exceptions.NewInternalServerError() + ) + if err != nil { + logger.ErrorContext(ctx, "Failed to save account credentials dynamic registration IAT 2FA CSRF token", "error", err) + return "", exceptions.NewInternalServerError() } - queryParams := make(url.Values) - queryParams.Add("code", code) - queryParams.Add("state", data.State) - queryParams.Add("iss", "https://"+opts.BackendDomain) - return data.RedirectURI + "?" + queryParams.Encode(), "", nil + twoFAhtml, err := templates.BuildAccountDynamicRegistrationIAT2FATemplate( + templates.AccountDynamicRegistrationIAT2FAOptions{ + ClientID: opts.ACCClientID, + SessionID: opts.SessionID, + CSRFToken: csrfToken, + State: data.State, + CodeChallenge: opts.CodeChallenge, + CodeChallengeMethod: opts.CodeChallengeMethod, + RedirectURI: data.RedirectURI, + }, + ) + if err != nil { + logger.ErrorContext(ctx, "Failed to build account dynamic registration IAT 2FA template", "error", err) + return "", exceptions.NewInternalServerError() + } + + return twoFAhtml, nil } -type OAuthDynamicRegistrationIAT2FAOptions struct { - RequestID string - ClientID string - SessionID string - Code string +type OAuthDynamicRegistrationIATVerify2FACodeOptions struct { + RequestID string + ACCClientID string + Domain string + SessionID string + CSRFToken string + CodeChallenge string + CodeChallengeMethod string + BackendDomain string + Code string } -type OAuthDynamicRegistrationOptions struct { - RedirectURIs []string - TokenEndpointAuthMethod string - ResponseTypes []string - GrantTypes []string - ApplicationType string - ClientName string - ClientURI string - LogoURI string - Scope string - Contacts []string - TOSURI string - PolicyURI string - JWKsURI string - JWKs []string - SoftwareID string - SoftwareVersion string +func (s *Services) OAuthDynamicRegistrationIATVerify2FACode( + ctx context.Context, + opts OAuthDynamicRegistrationIATVerify2FACodeOptions, +) (string, string, *exceptions.ServiceError) { + logger := s.buildLogger(opts.RequestID, oauthDynamicRegistrationLocation, "OAuthDynamicRegistrationIATVerify2FACode").With( + "clientId", opts.ACCClientID, + "sessionId", opts.SessionID, + ) + logger.InfoContext(ctx, "Verifying OAuth dynamic registration IAT 2FA...") + + data, found, err := s.cache.GetAccountCredentialsDynamicRegistrationIAT2FA(ctx, cache.GetAccountCredentialsDynamicRegistrationIAT2FAOptions{ + RequestID: opts.RequestID, + SessionID: opts.SessionID, + }) + if err != nil { + logger.ErrorContext(ctx, "Failed to get account credentials dynamic registration IAT 2FA", "error", err) + return "", "", exceptions.NewInternalServerError() + } + if !found { + logger.WarnContext(ctx, "Failed to get account credentials dynamic registration IAT 2FA") + return "", "", exceptions.NewUnauthorizedError() + } + if opts.ACCClientID != data.ClientID { + logger.WarnContext(ctx, "Client IDs do not match", "sessionClientId", data.ClientID) + return "", "", exceptions.NewUnauthorizedError() + } + + csrfTokenValid, err := s.cache.VerifyAccountCredentialsDynamicRegistrationIAT2FACSRFToken( + ctx, + cache.VerifyAccountCredentialsDynamicRegistrationIAT2FACSRFTokenOptions{ + RequestID: opts.RequestID, + SessionID: opts.SessionID, + CSRFToken: opts.CSRFToken, + }, + ) + if err != nil { + logger.ErrorContext(ctx, "Failed to verify account credentials dynamic registration IAT 2FA CSRF token", "error", err) + return "", "", exceptions.NewInternalServerError() + } + if !csrfTokenValid { + logger.WarnContext(ctx, "Invalid CSRF token") + return "", "", exceptions.NewForbiddenError() + } + + hashedChallenge, serviceErr := hashChallenge(opts.CodeChallenge, opts.CodeChallengeMethod) + if serviceErr != nil { + logger.ErrorContext(ctx, "Invalid code challenge", "serviceErr", serviceErr) + return "", "", exceptions.NewInternalServerError() + } + if hashedChallenge != data.Challenge { + logger.WarnContext(ctx, "OAuth Code challenge verification failed") + return "", "", exceptions.NewUnauthorizedError() + } + + accountDTO, serviceErr := s.GetAccountByPublicIDAndVersion(ctx, GetAccountByPublicIDAndVersionOptions{ + RequestID: opts.RequestID, + PublicID: data.AccountPublicID, + Version: data.AccountVersion, + }) + if serviceErr != nil { + logger.WarnContext(ctx, "Failed to get account by public ID and version", "serviceError", serviceErr) + return "", "", serviceErr + } + if serviceErr := s.verifyAccountTwoFactor(ctx, opts.RequestID, &accountDTO, opts.Code); serviceErr != nil { + logger.WarnContext(ctx, "Failed to verify account two factor", "serviceError", serviceErr) + return "", "", serviceErr + } + + domainDTO, serviceErr := s.GetAccountCredentialsRegistrationDomain(ctx, GetAccountCredentialsRegistrationDomainOptions{ + RequestID: opts.RequestID, + AccountPublicID: accountDTO.PublicID, + Domain: data.Domain, + }) + if serviceErr != nil { + if serviceErr.Code != exceptions.CodeNotFound { + logger.ErrorContext(ctx, "Failed to get account credentials registration domain", "serviceError", serviceErr) + return "", "", serviceErr + } + + if err := s.cache.DeleteAccountCredentialsDynamicRegistrationIATAuth( + ctx, + cache.DeleteAccountCredentialsDynamicRegistrationIATAuthOptions{ + RequestID: opts.RequestID, + ClientID: opts.ACCClientID, + }, + ); err != nil { + logger.ErrorContext(ctx, "Failed to delete account credentials dynamic registration IAT auth", "error", err) + return "", "", exceptions.NewInternalServerError() + } + + logger.WarnContext(ctx, "Account credentials registration domain not found") + return "", "", exceptions.NewForbiddenError() + } + if !domainDTO.Verified { + logger.ErrorContext(ctx, "Account credentials registration domain is not verified") + return "", "", exceptions.NewForbiddenError() + } + + sessionKey, err := s.cache.CreateAccountCredentialsRegistrationSessionKey( + ctx, + cache.CreateAccountCredentialsRegistrationSessionKeyOptions{ + RequestID: opts.RequestID, + ClientID: opts.ACCClientID, + Domain: opts.Domain, + AccountPublicID: accountDTO.PublicID, + AccountVersion: accountDTO.Version(), + }, + ) + if err != nil { + logger.ErrorContext(ctx, "Failed to create account credentials registration session", "error", err) + return "", "", exceptions.NewInternalServerError() + } + + return oauthDynamicRegistrationIATAuthPath, sessionKey, nil } diff --git a/idp/internal/services/templates/account_dynamic_registration.go b/idp/internal/services/templates/account_dynamic_registration.go index 879d47c..71b5d33 100644 --- a/idp/internal/services/templates/account_dynamic_registration.go +++ b/idp/internal/services/templates/account_dynamic_registration.go @@ -408,7 +408,8 @@ type accountDynamicRegistrationLoginTemplateData struct { } type AccountDynamicRegistrationIATAuthOptions struct { - ClientID string + ACCClientID string + Domain string CSRFToken string State string CodeChallenge string @@ -422,7 +423,8 @@ type AccountDynamicRegistrationIATAuthOptions struct { } func BuildAccountDynamicRegistrationIATAuthTemplate(opts AccountDynamicRegistrationIATAuthOptions) (string, error) { - baseURL := paths.AccountsBase + paths.CredentialsBase + "/" + opts.ClientID + paths.InitialAccessToken + baseURL := paths.V1 + paths.AccountsBase + paths.CredentialsBase + paths.DynamicRegistrationBase + + paths.InitialAccessToken + "/" + opts.ACCClientID data := accountDynamicRegistrationLoginTemplateData{ Title: baseAccountLoginTitle, Header: "OAuth Dynamic Client Registration Initial Access Token Login", @@ -436,6 +438,7 @@ func BuildAccountDynamicRegistrationIATAuthTemplate(opts AccountDynamicRegistrat baseTemplateBody := loginForm + divider urlParams := make(url.Values) + urlParams.Add("client_id", opts.Domain) urlParams.Add("response_type", "code") urlParams.Add("state", opts.State) urlParams.Add("code_challenge", opts.CodeChallenge) @@ -670,9 +673,9 @@ type AccountDynamicRegistrationIAT2FAOptions struct { } func BuildAccountDynamicRegistrationIAT2FATemplate(opts AccountDynamicRegistrationIAT2FAOptions) (string, error) { - baseURL := paths.AccountsBase + paths.CredentialsBase + "/" + opts.ClientID + paths.InitialAccessToken data := accountDynamicRegistrationIAT2FAData{ - TwoFAURL: baseURL + paths.AuthLogin + paths.Auth2FA, + TwoFAURL: paths.AccountsBase + paths.CredentialsBase + paths.DynamicRegistrationBase + + paths.InitialAccessToken + "/" + opts.ClientID + paths.AuthLogin + paths.Auth2FA, RedirectURI: opts.RedirectURI, CodeChallenge: opts.CodeChallenge, CodeChallengeMethod: opts.CodeChallengeMethod, diff --git a/idp/internal/services/templates/error.go b/idp/internal/services/templates/error.go new file mode 100644 index 0000000..69bf978 --- /dev/null +++ b/idp/internal/services/templates/error.go @@ -0,0 +1,308 @@ +// Copyright (c) 2025 Afonso Barracha +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +package templates + +import ( + "bytes" + "fmt" + "html/template" + "strings" + + "github.com/tugascript/devlogs/idp/internal/utils" +) + +const errorTemplate = ` + + + + + + {{.Status}} {{.ErrorCode}} - Dev Logs + + +
+
+
+ +
+

{{.Status}} {{.ErrorCode}}

+
+ +
+

{{.MessageTitle}}

+ %s +
+
+ + +` + +const InternalServerErrorTemplate = ` + + + + + + 500 Internal Server Error - Dev Logs + + +
+
+
+ +
+

500 Internal Server Error

+
+ +
+

Internal Server Error

+

Something Error

+
+
+ + +` + +const errorTemplateName = "error" + +type errorTemplateData struct { + Status int + ErrorCode string + MessageTitle string +} + +type ErrorTemplateOptions struct { + Status int + ErrorCode string + MessageTitle string + Messages []string +} + +func BuildErrorTemplate(options ErrorTemplateOptions) (string, error) { + messageParagraphs := utils.MapSlice(options.Messages, func(msg *string) string { + return fmt.Sprintf("

%s

", *msg) + }) + errTemplate := fmt.Sprintf(errorTemplate, strings.Join(messageParagraphs, "\n")) + + data := errorTemplateData{ + Status: options.Status, + ErrorCode: options.ErrorCode, + MessageTitle: options.MessageTitle, + } + + t, err := template.New(errorTemplateName).Parse(errTemplate) + if err != nil { + return "", err + } + + var errorContent bytes.Buffer + if err := t.Execute(&errorContent, data); err != nil { + return "", err + } + + return errorContent.String(), nil +} diff --git a/idp/internal/services/templates/error.html b/idp/internal/services/templates/error.html new file mode 100644 index 0000000..0e83969 --- /dev/null +++ b/idp/internal/services/templates/error.html @@ -0,0 +1,123 @@ + + + + + + 500 InternalServerError - Dev Logs + + +
+
+
+ +
+

500 Internal Server Error

+
+ +
+

Internal Server Error

+

Something Error

+
+
+ + \ No newline at end of file From ecc116e8d574d937f94f6a7d4e78f9d85adccdfb Mon Sep 17 00:00:00 2001 From: Afonso Barracha Date: Mon, 1 Sep 2025 18:23:57 +1200 Subject: [PATCH 13/23] feat(idp): add IAT code controllers --- ...ccount_credentials_registration_domains.go | 3 + .../controllers/bodies/common_auth.go | 2 +- .../bodies/oauth_dynamic_registration.go | 17 + idp/internal/controllers/controllers.go | 13 +- idp/internal/controllers/helpers.go | 35 ++ idp/internal/controllers/oauth.go | 101 +++-- .../controllers/oauth_dynamic_registration.go | 417 ++++++++++++++++-- .../params/oauth_dynamic_registration.go | 14 +- .../controllers/paths/dynamic_registration.go | 6 +- idp/internal/exceptions/controllers.go | 15 +- ...ccount_credentials_dynamic_registration.go | 8 +- .../account_credentials_registration_iat.go | 87 ---- .../services/dtos/initial_access_token.go | 7 + .../services/oauth_dynamic_registration.go | 269 +++++++++-- .../templates/account_dynamic_registration.go | 150 ++++--- idp/internal/services/templates/login.html | 14 +- 16 files changed, 864 insertions(+), 294 deletions(-) create mode 100644 idp/internal/services/dtos/initial_access_token.go diff --git a/idp/internal/controllers/account_credentials_registration_domains.go b/idp/internal/controllers/account_credentials_registration_domains.go index 6d29ef5..bd05720 100644 --- a/idp/internal/controllers/account_credentials_registration_domains.go +++ b/idp/internal/controllers/account_credentials_registration_domains.go @@ -101,6 +101,9 @@ func (c *Controllers) ListAccountCredentialsRegistrationDomains(ctx *fiber.Ctx) }, ) } + if serviceErr != nil { + return serviceErrorResponse(logger, ctx, serviceErr) + } logResponse(logger, ctx, fiber.StatusOK) return ctx.Status(fiber.StatusOK).JSON(dtos.NewPaginationDTO( diff --git a/idp/internal/controllers/bodies/common_auth.go b/idp/internal/controllers/bodies/common_auth.go index 5821947..4900f2b 100644 --- a/idp/internal/controllers/bodies/common_auth.go +++ b/idp/internal/controllers/bodies/common_auth.go @@ -15,7 +15,7 @@ type ConfirmationTokenBody struct { } type LoginBody struct { - Email string `json:"email" validate:"required,email"` + Email string `json:"email" validate:"required,email,max=250"` Password string `json:"password" validate:"required,min=1"` } diff --git a/idp/internal/controllers/bodies/oauth_dynamic_registration.go b/idp/internal/controllers/bodies/oauth_dynamic_registration.go index 23a8fd7..f733318 100644 --- a/idp/internal/controllers/bodies/oauth_dynamic_registration.go +++ b/idp/internal/controllers/bodies/oauth_dynamic_registration.go @@ -24,3 +24,20 @@ type OAuthDynamicClientRegistrationBody struct { SoftwareID string `json:"software_id,omitempty" validate:"omitempty,max=250"` SoftwareVersion string `json:"software_version,omitempty" validate:"omitempty,max=250"` } + +type OAuthDynamicRegistrationIATAuthHiddenFieldsBody struct { + CSRFToken string `json:"csrf_token" validate:"required,min=21,base64rawurl"` + ClientID string `json:"client_id" validate:"required,fqdn"` + ResponseType string `json:"response_type" validate:"required,oneof=code"` + CodeChallenge string `json:"code_challenge" validate:"required,min=1"` + CodeChallengeMethod string `json:"code_challenge_method" validate:"omitempty,oneof=plain s256 S256"` + State string `json:"state" validate:"required,min=1"` + RedirectURI string `json:"redirect_uri" validate:"required,uri"` +} + +type OAuthDynamicRegistrationIATTokenBody struct { + ClientID string `json:"client_id" validate:"required,fqdn"` + GrantType string `json:"grant_type" validate:"required,eq=authorization_code"` + Code string `json:"code" validate:"required,min=1"` + CodeVerifier string `json:"code_verifier" validate:"required,min=1"` +} diff --git a/idp/internal/controllers/controllers.go b/idp/internal/controllers/controllers.go index 83e59f8..f454d74 100644 --- a/idp/internal/controllers/controllers.go +++ b/idp/internal/controllers/controllers.go @@ -16,13 +16,12 @@ import ( ) type Controllers struct { - logger *slog.Logger - services *services.Services - validate *validator.Validate - frontendDomain string - backendDomain string - cookieName string - sessionCookieName string + logger *slog.Logger + services *services.Services + validate *validator.Validate + frontendDomain string + backendDomain string + cookieName string } func NewControllers( diff --git a/idp/internal/controllers/helpers.go b/idp/internal/controllers/helpers.go index 77e2f1d..2742497 100644 --- a/idp/internal/controllers/helpers.go +++ b/idp/internal/controllers/helpers.go @@ -10,6 +10,7 @@ import ( "errors" "fmt" "log/slog" + "net/url" "github.com/go-playground/validator/v10" "github.com/gofiber/fiber/v2" @@ -179,3 +180,37 @@ func parseRequestErrorResponse(logger *slog.Logger, ctx *fiber.Ctx, err error) e Status(fiber.StatusBadRequest). JSON(exceptions.NewEmptyValidationErrorResponse(exceptions.ValidationResponseLocationBody)) } + +func (c *Controllers) redirectErrorCallback( + logger *slog.Logger, + ctx *fiber.Ctx, + redirectURI string, + state string, + errMsg string, +) error { + qPrams := make(url.Values) + qPrams.Add("error", errMsg) + if state != "" { + qPrams.Add("state", state) + } + qPrams.Add("iss", fmt.Sprintf("https://%s", c.backendDomain)) + logResponse(logger, ctx, fiber.StatusFound) + return ctx.Redirect(redirectURI+"?"+qPrams.Encode(), fiber.StatusFound) +} + +func (c *Controllers) redirectServiceErrorCallback( + logger *slog.Logger, + ctx *fiber.Ctx, + redirectURI string, + state string, + serviceErr *exceptions.ServiceError, +) error { + switch serviceErr.Code { + case exceptions.CodeUnauthorized, exceptions.CodeForbidden: + return c.redirectErrorCallback(logger, ctx, redirectURI, state, exceptions.OAuthErrorAccessDenied) + case exceptions.CodeNotFound, exceptions.CodeValidation: + return c.redirectErrorCallback(logger, ctx, redirectURI, state, exceptions.OAuthErrorInvalidRequest) + default: + return c.redirectErrorCallback(logger, ctx, redirectURI, state, exceptions.OAuthServerError) + } +} diff --git a/idp/internal/controllers/oauth.go b/idp/internal/controllers/oauth.go index a9cc491..0650122 100644 --- a/idp/internal/controllers/oauth.go +++ b/idp/internal/controllers/oauth.go @@ -10,8 +10,10 @@ import ( "encoding/json" "fmt" "log/slog" + "net/url" "github.com/gofiber/fiber/v2" + "github.com/tugascript/devlogs/idp/internal/controllers/bodies" "github.com/tugascript/devlogs/idp/internal/controllers/params" "github.com/tugascript/devlogs/idp/internal/exceptions" @@ -29,6 +31,38 @@ func formatAccountRedirectURL(backendDomain, provider string) string { return fmt.Sprintf("https://%s/v1/auth/oauth2/%s/callback", backendDomain, provider) } +func (c *Controllers) errorCallback(logger *slog.Logger, ctx *fiber.Ctx, state string, errStr string) error { + qPrams := make(url.Values) + qPrams.Add("error", errStr) + if state != "" { + qPrams.Add("state", state) + } + + qPrams.Add("iss", fmt.Sprintf("https://%s", c.backendDomain)) + ctx.Set(fiber.HeaderCacheControl, cacheControlNoStore) + logResponse(logger, ctx, fiber.StatusFound) + return ctx.Redirect( + fmt.Sprintf("https://%s/auth/callback?error=%s", c.frontendDomain, qPrams.Encode()), + fiber.StatusFound, + ) +} + +func (c *Controllers) serviceErrorCallback( + logger *slog.Logger, + ctx *fiber.Ctx, + state string, + serviceErr *exceptions.ServiceError, +) error { + switch serviceErr.Code { + case exceptions.CodeUnauthorized, exceptions.CodeForbidden: + return c.errorCallback(logger, ctx, state, exceptions.OAuthErrorAccessDenied) + case exceptions.CodeNotFound, exceptions.CodeValidation: + return c.errorCallback(logger, ctx, state, exceptions.OAuthErrorInvalidRequest) + default: + return c.errorCallback(logger, ctx, state, exceptions.OAuthServerError) + } +} + func (c *Controllers) AccountOAuthURL(ctx *fiber.Ctx) error { requestID := getRequestID(ctx) logger := c.buildLogger(requestID, oauthLocation, "AccountOAuthURL") @@ -42,10 +76,10 @@ func (c *Controllers) AccountOAuthURL(ctx *fiber.Ctx) error { State: ctx.Query("state"), } if err := c.validate.StructCtx(ctx.UserContext(), qPrms); err != nil { - return validateQueryParamsErrorResponse(logger, ctx, err) + return c.errorCallback(logger, ctx, qPrms.State, exceptions.OAuthErrorInvalidRequest) } - url, serviceErr := c.services.AccountOAuthURL(ctx.UserContext(), services.AccountOAuthURLOptions{ + oAuthURL, serviceErr := c.services.AccountOAuthURL(ctx.UserContext(), services.AccountOAuthURLOptions{ RequestID: requestID, Provider: qPrms.ClientID, RedirectURL: formatAccountRedirectURL(c.backendDomain, qPrms.ClientID), @@ -54,11 +88,11 @@ func (c *Controllers) AccountOAuthURL(ctx *fiber.Ctx) error { State: qPrms.State, }) if serviceErr != nil { - return serviceErrorResponse(logger, ctx, serviceErr) + return c.serviceErrorCallback(logger, ctx, qPrms.State, serviceErr) } logResponse(logger, ctx, fiber.StatusFound) - return ctx.Redirect(url, fiber.StatusFound) + return ctx.Redirect(oAuthURL, fiber.StatusFound) } func (c *Controllers) acceptCallback(logger *slog.Logger, ctx *fiber.Ctx, oauthParams string) error { @@ -70,15 +104,6 @@ func (c *Controllers) acceptCallback(logger *slog.Logger, ctx *fiber.Ctx, oauthP ) } -func (c *Controllers) errorCallback(logger *slog.Logger, ctx *fiber.Ctx, errStr string) error { - ctx.Set(fiber.HeaderCacheControl, cacheControlNoStore) - logResponse(logger, ctx, fiber.StatusFound) - return ctx.Redirect( - fmt.Sprintf("https://%s/auth/callback?error=%s", c.frontendDomain, errStr), - fiber.StatusFound, - ) -} - func (c *Controllers) AccountOAuthCallback(ctx *fiber.Ctx) error { requestID := getRequestID(ctx) logger := c.buildLogger(requestID, oauthLocation, "AccountOAuthCallback") @@ -89,35 +114,28 @@ func (c *Controllers) AccountOAuthCallback(ctx *fiber.Ctx) error { return validateURLParamsErrorResponse(logger, ctx, err) } - queryParams := params.OAuthCallbackQueryParams{ + qPrms := params.OAuthCallbackQueryParams{ Code: ctx.Query("code"), State: ctx.Query("state"), } - if err := c.validate.StructCtx(ctx.UserContext(), queryParams); err != nil { + if err := c.validate.StructCtx(ctx.UserContext(), &qPrms); err != nil { errQuery := ctx.Query("error") if errQuery != "" { - return c.errorCallback(logger, ctx, errQuery) + return c.errorCallback(logger, ctx, qPrms.State, errQuery) } - return c.errorCallback(logger, ctx, exceptions.OAuthErrorInvalidRequest) + return c.errorCallback(logger, ctx, qPrms.State, exceptions.OAuthErrorInvalidRequest) } oauthParams, serviceErr := c.services.ExtLoginAccount(ctx.UserContext(), services.ExtLoginAccountOptions{ RequestID: requestID, Provider: urlParams.Provider, - Code: queryParams.Code, - State: queryParams.State, + Code: qPrms.Code, + State: qPrms.State, RedirectURL: formatAccountRedirectURL(c.backendDomain, urlParams.Provider), }) if serviceErr != nil { - switch serviceErr.Code { - case exceptions.CodeUnauthorized, exceptions.CodeForbidden: - return c.errorCallback(logger, ctx, exceptions.OAuthErrorAccessDenied) - case exceptions.CodeNotFound, exceptions.CodeValidation: - return c.errorCallback(logger, ctx, exceptions.OAuthErrorInvalidRequest) - default: - return c.errorCallback(logger, ctx, exceptions.OAuthServerError) - } + return c.serviceErrorCallback(logger, ctx, qPrms.State, serviceErr) } return c.acceptCallback(logger, ctx, oauthParams) @@ -129,24 +147,24 @@ func (c *Controllers) AccountAppleCallback(ctx *fiber.Ctx) error { logRequest(logger, ctx) if ctx.Get("Content-Type") != "application/x-www-form-urlencoded" { - return c.errorCallback(logger, ctx, exceptions.OAuthErrorInvalidRequest) + return c.errorCallback(logger, ctx, "", exceptions.OAuthErrorInvalidRequest) } body := new(bodies.AppleLoginBody) if err := ctx.BodyParser(body); err != nil { - return c.errorCallback(logger, ctx, exceptions.OAuthErrorInvalidRequest) + return c.errorCallback(logger, ctx, "", exceptions.OAuthErrorInvalidRequest) } if err := c.validate.StructCtx(ctx.UserContext(), body); err != nil { - return c.errorCallback(logger, ctx, exceptions.OAuthErrorInvalidRequest) + return c.errorCallback(logger, ctx, "", exceptions.OAuthErrorInvalidRequest) } user := new(bodies.AppleUser) if err := json.Unmarshal([]byte(body.User), user); err != nil { - return c.errorCallback(logger, ctx, exceptions.OAuthErrorInvalidScope) + return c.errorCallback(logger, ctx, "", exceptions.OAuthErrorInvalidScope) } if err := c.validate.StructCtx(ctx.UserContext(), user); err != nil { logger.WarnContext(ctx.UserContext(), "Failed to parse apple user data") - return c.errorCallback(logger, ctx, exceptions.OAuthErrorInvalidScope) + return c.errorCallback(logger, ctx, "", exceptions.OAuthErrorInvalidScope) } oauthParams, serviceErr := c.services.AppleLoginAccount(ctx.UserContext(), services.AppleLoginAccountOptions{ @@ -158,14 +176,7 @@ func (c *Controllers) AccountAppleCallback(ctx *fiber.Ctx) error { State: body.State, }) if serviceErr != nil { - switch serviceErr.Code { - case exceptions.CodeUnauthorized, exceptions.CodeForbidden: - return c.errorCallback(logger, ctx, exceptions.OAuthErrorAccessDenied) - case exceptions.CodeNotFound, exceptions.CodeValidation: - return c.errorCallback(logger, ctx, exceptions.OAuthErrorInvalidRequest) - default: - return c.errorCallback(logger, ctx, exceptions.OAuthServerError) - } + return c.serviceErrorCallback(logger, ctx, "", serviceErr) } return c.acceptCallback(logger, ctx, oauthParams) @@ -392,18 +403,10 @@ func (c *Controllers) AccountOAuthToken(ctx *fiber.Ctx) error { logRequest(logger, ctx) if ctx.Get("Content-Type") != "application/x-www-form-urlencoded" { - return serviceErrorResponse(logger, ctx, exceptions.NewUnsupportedMediaTypeError( - "Content-Type must be application/x-www-form-urlencoded", - )) + return oauthErrorResponse(logger, ctx, exceptions.OAuthErrorInvalidRequest) } grantType := ctx.FormValue("grant_type") - if grantType == "" { - logger.WarnContext(ctx.UserContext(), "Missing grant_type") - logResponse(logger, ctx, fiber.StatusBadRequest) - - } - switch grantType { case grantTypeRefresh: return c.accountRefreshToken(ctx, requestID) diff --git a/idp/internal/controllers/oauth_dynamic_registration.go b/idp/internal/controllers/oauth_dynamic_registration.go index 2034aaa..b5df165 100644 --- a/idp/internal/controllers/oauth_dynamic_registration.go +++ b/idp/internal/controllers/oauth_dynamic_registration.go @@ -7,12 +7,16 @@ package controllers import ( + "fmt" + "github.com/gofiber/fiber/v2" + "github.com/tugascript/devlogs/idp/internal/controllers/bodies" "github.com/tugascript/devlogs/idp/internal/controllers/params" "github.com/tugascript/devlogs/idp/internal/controllers/paths" "github.com/tugascript/devlogs/idp/internal/exceptions" "github.com/tugascript/devlogs/idp/internal/services" + "github.com/tugascript/devlogs/idp/internal/utils" ) const ( @@ -27,29 +31,41 @@ func (c *Controllers) OAuthDynamicRegistrationIATAuth(ctx *fiber.Ctx) error { logger := c.buildLogger(requestID, oauthDynamicRegistration, "OAuthDynamicRegistrationIATAuth") logRequest(logger, ctx) + baseQPrms := params.OAuthDynamicRegistrationIATAuthBaseQueryParams{ + ClientID: ctx.Query("client_id"), + RedirectURI: ctx.Query("redirect_uri"), + } + if err := c.validate.StructCtx(ctx.UserContext(), baseQPrms); err != nil { + return serviceErrorHTMLResponse(logger, ctx, exceptions.NewForbiddenError()) + } + + responseType := ctx.Query("response_type") + state := ctx.Query("state") + if responseType != "code" { + return c.redirectErrorCallback(logger, ctx, baseQPrms.RedirectURI, state, exceptions.OAuthErrorUnsupportedResponseType) + } + qPrms := params.OAuthDynamicRegistrationIATAuthQueryParams{ - ClientID: ctx.Query("client_id"), - ResponseType: ctx.Query("response_type"), + ResponseType: responseType, Challenge: ctx.Query("code_challenge"), ChallengeMethod: ctx.Query("code_challenge_method"), - State: ctx.Query("state"), - RedirectURI: ctx.Query("redirect_uri"), + State: state, } if err := c.validate.StructCtx(ctx.UserContext(), qPrms); err != nil { - return validationErrorHTMLResponse(logger, ctx, exceptions.ValidationResponseLocationQuery, err) + return c.redirectErrorCallback(logger, ctx, baseQPrms.RedirectURI, state, exceptions.OAuthErrorInvalidRequest) } redirectURL, serviceErr := c.services.InitiateOAuthDynamicRegistrationIATAuth( ctx.UserContext(), services.InitiateOAuthDynamicRegistrationIATAuthOptions{ RequestID: requestID, - Domain: qPrms.ClientID, + Domain: baseQPrms.ClientID, State: qPrms.State, SessionKey: ctx.Cookies(c.cookieName + accountsIATCookieSuffix), RefreshToken: ctx.Cookies(c.cookieName + refreshCookieSuffix), Challenge: qPrms.Challenge, ChallengeMethod: qPrms.ChallengeMethod, - RedirectURI: qPrms.RedirectURI, + RedirectURI: baseQPrms.RedirectURI, BackendDomain: c.backendDomain, }, ) @@ -66,16 +82,22 @@ func (c *Controllers) OAuthDynamicRegistrationIATLoginGet(ctx *fiber.Ctx) error logger := c.buildLogger(requestID, oauthDynamicRegistration, "OAuthDynamicRegistrationIATLoginGet") logRequest(logger, ctx) + baseQPrms := params.OAuthDynamicRegistrationIATAuthBaseQueryParams{ + ClientID: ctx.Query("client_id"), + RedirectURI: ctx.Query("redirect_uri"), + } + if err := c.validate.StructCtx(ctx.UserContext(), baseQPrms); err != nil { + return serviceErrorHTMLResponse(logger, ctx, exceptions.NewForbiddenError()) + } + qPrms := params.OAuthDynamicRegistrationIATAuthQueryParams{ - ClientID: ctx.Query("client_id"), ResponseType: ctx.Query("response_type"), Challenge: ctx.Query("code_challenge"), ChallengeMethod: ctx.Query("code_challenge_method"), State: ctx.Query("state"), - RedirectURI: ctx.Query("redirect_uri"), } if err := c.validate.StructCtx(ctx.UserContext(), qPrms); err != nil { - return validationErrorHTMLResponse(logger, ctx, exceptions.ValidationResponseLocationQuery, err) + return serviceErrorHTMLResponse(logger, ctx, exceptions.NewForbiddenError()) } loginHTML, serviceErr := c.services.OAuthDynamicRegistrationIATAuthRender( @@ -83,10 +105,10 @@ func (c *Controllers) OAuthDynamicRegistrationIATLoginGet(ctx *fiber.Ctx) error services.OAuthDynamicRegistrationIATAuthRenderOptions{ RequestID: requestID, State: qPrms.State, + Domain: baseQPrms.ClientID, CodeChallenge: qPrms.Challenge, CodeChallengeMethod: qPrms.ChallengeMethod, - RedirectURI: qPrms.RedirectURI, - Domain: qPrms.ClientID, + RedirectURI: baseQPrms.RedirectURI, }, ) if serviceErr != nil { @@ -102,10 +124,9 @@ func (c *Controllers) saveAccountIATCookie( sessionKey string, ) { ctx.Cookie(&fiber.Cookie{ - Name: c.cookieName + accountsIATCookieSuffix, - Value: sessionKey, - Path: paths.AccountsBase + paths.CredentialsBase + paths.InitialAccessToken + - paths.OAuthAuth, + Name: c.cookieName + accountsIATCookieSuffix, + Value: sessionKey, + Path: paths.V1 + paths.AccountsBase + paths.CredentialsBase + paths.InitialAccessToken + paths.OAuthAuth, HTTPOnly: true, SameSite: fiber.CookieSameSiteLaxMode, Secure: true, @@ -115,10 +136,9 @@ func (c *Controllers) saveAccountIATCookie( func (c *Controllers) removeAccountIATCookie(ctx *fiber.Ctx) { ctx.Cookie(&fiber.Cookie{ - Name: c.cookieName + accountsIATCookieSuffix, - Value: "", - Path: paths.AccountsBase + paths.CredentialsBase + paths.InitialAccessToken + - paths.OAuthAuth, + Name: c.cookieName + accountsIATCookieSuffix, + Value: "", + Path: paths.V1 + paths.AccountsBase + paths.CredentialsBase + paths.InitialAccessToken + paths.OAuthAuth, HTTPOnly: true, Secure: true, SameSite: fiber.CookieSameSiteNoneMode, @@ -126,19 +146,358 @@ func (c *Controllers) removeAccountIATCookie(ctx *fiber.Ctx) { }) } -func (c *Controllers) saveAccountIAT2FACookie( - ctx *fiber.Ctx, - sessionID string, - clientID string, -) { +func (c *Controllers) saveAccountIAT2FACookie(ctx *fiber.Ctx, sessionID, clientID string) { ctx.Cookie(&fiber.Cookie{ - Name: c.cookieName + accountsIAT2FACookieSuffix, - Value: sessionID, - Path: paths.AccountsBase + paths.CredentialsBase + paths.InitialAccessToken + - clientID + paths.AuthLogin + paths.Auth2FA, + Name: c.cookieName + accountsIAT2FACookieSuffix, + Value: sessionID, + Path: paths.AccountsBase + paths.CredentialsBase + paths.InitialAccessToken + "/" + clientID + paths.OAuthAuth, HTTPOnly: true, SameSite: fiber.CookieSameSiteLaxMode, Secure: true, MaxAge: int(c.services.GetOAuthCodeTTL()), }) } + +func (c *Controllers) removeAccountIAT2FACookie(ctx *fiber.Ctx, clientID string) { + ctx.Cookie(&fiber.Cookie{ + Name: c.cookieName + accountsIAT2FACookieSuffix, + Value: "", + Path: paths.AccountsBase + paths.CredentialsBase + paths.InitialAccessToken + "/" + clientID + paths.OAuthAuth, + HTTPOnly: true, + SameSite: fiber.CookieSameSiteLaxMode, + Secure: true, + MaxAge: -1, + }) +} + +func (c *Controllers) OAuthDynamicRegistrationIATLoginPost(ctx *fiber.Ctx) error { + requestID := getRequestID(ctx) + logger := c.buildLogger(requestID, oauthDynamicRegistration, "OAuthDynamicRegistrationIATLoginPost") + logRequest(logger, ctx) + + uPrms := params.OAuthDynamicRegistrationIATAuthURLParams{ + ACCClientID: ctx.Params("accClientID"), + } + if err := c.validate.StructCtx(ctx.UserContext(), &uPrms); err != nil { + return serviceErrorHTMLResponse(logger, ctx, exceptions.NewNotFoundError()) + } + + if ctx.Get("Content-Type") != "application/x-www-form-urlencoded" { + return serviceErrorHTMLResponse(logger, ctx, exceptions.NewUnsupportedMediaTypeError("Only application/x-www-form-urlencoded is supported")) + } + + hiddenFields := bodies.OAuthDynamicRegistrationIATAuthHiddenFieldsBody{ + CSRFToken: ctx.FormValue("csrf_token"), + ClientID: ctx.FormValue("client_id"), + ResponseType: ctx.FormValue("response_type"), + CodeChallenge: ctx.FormValue("code_challenge"), + CodeChallengeMethod: ctx.FormValue("code_challenge_method"), + State: ctx.FormValue("state"), + RedirectURI: ctx.FormValue("redirect_uri"), + } + if err := c.validate.StructCtx(ctx.UserContext(), &hiddenFields); err != nil { + return serviceErrorHTMLResponse(logger, ctx, exceptions.NewForbiddenError()) + } + + loginBody := bodies.LoginBody{ + Email: ctx.FormValue("email"), + Password: ctx.FormValue("password"), + } + if err := c.validate.StructCtx(ctx.UserContext(), &loginBody); err != nil { + valErr := validationErrorException(exceptions.ValidationResponseLocationBody, err) + loginHTML, serviceErr := c.services.OAuthDynamicRegistrationIATAuthReRender( + ctx.UserContext(), + services.OAuthDynamicRegistrationIATAuthReRenderOptions{ + RequestID: requestID, + Errors: utils.MapSlice(valErr.Fields, func(t *exceptions.FieldError) string { + return fmt.Sprintf("%s %s", t.Param, t.Message) + }), + CSRFToken: hiddenFields.CSRFToken, + ACCClientID: uPrms.ACCClientID, + State: hiddenFields.State, + Domain: hiddenFields.ClientID, + CodeChallenge: hiddenFields.CodeChallenge, + CodeChallengeMethod: hiddenFields.CodeChallengeMethod, + RedirectURI: hiddenFields.RedirectURI, + }, + ) + if serviceErr != nil { + return serviceErrorHTMLResponse(logger, ctx, serviceErr) + } + + logResponse(logger, ctx, fiber.StatusOK) + return ctx. + Status(fiber.StatusOK). + Type("html"). + SendString(loginHTML) + } + + redirectURL, sessionKey, loggedIn, serviceErr := c.services.OAuthDynamicRegistrationIATLogin( + ctx.UserContext(), + services.OAuthDynamicRegistrationIATLoginOptions{ + RequestID: requestID, + ACCClientID: uPrms.ACCClientID, + Domain: hiddenFields.ClientID, + CSRFToken: hiddenFields.CSRFToken, + CodeChallenge: hiddenFields.CodeChallenge, + CodeChallengeMethod: hiddenFields.CodeChallengeMethod, + State: hiddenFields.State, + RedirectURI: hiddenFields.RedirectURI, + Email: loginBody.Email, + Password: loginBody.Password, + BackendDomain: c.backendDomain, + }, + ) + if serviceErr != nil { + if serviceErr.Code == exceptions.CodeUnauthorized { + loginHTML, serviceErr := c.services.OAuthDynamicRegistrationIATAuthReRender( + ctx.UserContext(), + services.OAuthDynamicRegistrationIATAuthReRenderOptions{ + RequestID: requestID, + Errors: []string{"Invalid credentials"}, + CSRFToken: hiddenFields.CSRFToken, + ACCClientID: uPrms.ACCClientID, + State: hiddenFields.State, + Domain: hiddenFields.ClientID, + CodeChallenge: hiddenFields.CodeChallenge, + CodeChallengeMethod: hiddenFields.CodeChallengeMethod, + RedirectURI: hiddenFields.RedirectURI, + }, + ) + if serviceErr != nil { + return serviceErrorHTMLResponse(logger, ctx, serviceErr) + } + + logResponse(logger, ctx, fiber.StatusOK) + return ctx. + Status(fiber.StatusOK). + Type("html"). + SendString(loginHTML) + } + + return serviceErrorHTMLResponse(logger, ctx, serviceErr) + } + + if loggedIn { + c.saveAccountIAT2FACookie(ctx, sessionKey, uPrms.ACCClientID) + logResponse(logger, ctx, fiber.StatusSeeOther) + return ctx.Redirect(redirectURL, fiber.StatusSeeOther) + } + + c.saveAccountIATCookie(ctx, sessionKey) + logResponse(logger, ctx, fiber.StatusSeeOther) + return ctx.Redirect(redirectURL, fiber.StatusSeeOther) +} + +func (c *Controllers) OAuthDynamicRegistrationIAT2FAGet(ctx *fiber.Ctx) error { + requestID := getRequestID(ctx) + logger := c.buildLogger(requestID, oauthDynamicRegistration, "OAuthDynamicRegistrationIAT2FAGet") + logRequest(logger, ctx) + + uPrms := params.OAuthDynamicRegistrationIATAuthURLParams{ + ACCClientID: ctx.Params("accClientID"), + } + if err := c.validate.StructCtx(ctx.UserContext(), &uPrms); err != nil { + return serviceErrorHTMLResponse(logger, ctx, exceptions.NewNotFoundError()) + } + + baseQPrms := params.OAuthDynamicRegistrationIATAuthBaseQueryParams{ + ClientID: ctx.Query("client_id"), + RedirectURI: ctx.Query("redirect_uri"), + } + if err := c.validate.StructCtx(ctx.UserContext(), baseQPrms); err != nil { + return serviceErrorHTMLResponse(logger, ctx, exceptions.NewForbiddenError()) + } + + qPrms := params.OAuthDynamicRegistrationIATAuthQueryParams{ + ResponseType: ctx.Query("response_type"), + Challenge: ctx.Query("code_challenge"), + ChallengeMethod: ctx.Query("code_challenge_method"), + State: ctx.Query("state"), + } + if err := c.validate.StructCtx(ctx.UserContext(), qPrms); err != nil { + return serviceErrorHTMLResponse(logger, ctx, exceptions.NewForbiddenError()) + } + + sessionID := ctx.Cookies(c.cookieName + accountsIAT2FACookieSuffix) + if sessionID == "" { + return serviceErrorHTMLResponse(logger, ctx, exceptions.NewUnauthorizedError()) + } + + twoFAHTML, serviceErr := c.services.OAuthDynamicRegistrationIAT2FARender( + ctx.UserContext(), + services.OAuthDynamicRegistrationIAT2FARenderOptions{ + RequestID: requestID, + Domain: baseQPrms.ClientID, + ACCClientID: uPrms.ACCClientID, + SessionID: sessionID, + Challenge: qPrms.Challenge, + ChallengeMethod: qPrms.ChallengeMethod, + State: qPrms.State, + RedirectURI: baseQPrms.RedirectURI, + }, + ) + if serviceErr != nil { + return serviceErrorHTMLResponse(logger, ctx, serviceErr) + } + + logResponse(logger, ctx, fiber.StatusOK) + return ctx.Status(fiber.StatusOK).Type("html").SendString(twoFAHTML) +} + +func (c *Controllers) OAuthDynamicRegistrationIAT2FAPost(ctx *fiber.Ctx) error { + requestID := getRequestID(ctx) + logger := c.buildLogger(requestID, oauthDynamicRegistration, "OAuthDynamicRegistrationIAT2FAPost") + logRequest(logger, ctx) + + uPrms := params.OAuthDynamicRegistrationIATAuthURLParams{ + ACCClientID: ctx.Params("accClientID"), + } + if err := c.validate.StructCtx(ctx.UserContext(), &uPrms); err != nil { + return serviceErrorHTMLResponse(logger, ctx, exceptions.NewNotFoundError()) + } + + sessionID := ctx.Cookies(c.cookieName + accountsIAT2FACookieSuffix) + if sessionID == "" { + return serviceErrorHTMLResponse(logger, ctx, exceptions.NewUnauthorizedError()) + } + + if ctx.Get("Content-Type") != "application/x-www-form-urlencoded" { + return serviceErrorHTMLResponse(logger, ctx, exceptions.NewUnsupportedMediaTypeError("Only application/x-www-form-urlencoded is supported")) + } + + hiddenFields := bodies.OAuthDynamicRegistrationIATAuthHiddenFieldsBody{ + CSRFToken: ctx.FormValue("csrf_token"), + ClientID: ctx.FormValue("client_id"), + ResponseType: ctx.FormValue("response_type"), + CodeChallenge: ctx.FormValue("code_challenge"), + CodeChallengeMethod: ctx.FormValue("code_challenge_method"), + State: ctx.FormValue("state"), + RedirectURI: ctx.FormValue("redirect_uri"), + } + if err := c.validate.StructCtx(ctx.UserContext(), &hiddenFields); err != nil { + return serviceErrorHTMLResponse(logger, ctx, exceptions.NewForbiddenError()) + } + + twoFABody := bodies.TwoFactorLoginBody{ + Code: ctx.FormValue("code"), + } + if err := c.validate.StructCtx(ctx.UserContext(), &twoFABody); err != nil { + valErr := validationErrorException(exceptions.ValidationResponseLocationBody, err) + twoFAHTML, serviceErr := c.services.OAuthDynamicRegistrationIAT2FAReRender( + ctx.UserContext(), + services.OAuthDynamicRegistrationIAT2FAReRenderOptions{ + RequestID: requestID, + Domain: hiddenFields.ClientID, + ACCClientID: uPrms.ACCClientID, + SessionID: sessionID, + Errors: utils.MapSlice(valErr.Fields, func(t *exceptions.FieldError) string { + return fmt.Sprintf("%s %s", t.Param, t.Message) + }), + CSRFToken: hiddenFields.CSRFToken, + Challenge: hiddenFields.CodeChallenge, + ChallengeMethod: hiddenFields.CodeChallengeMethod, + State: hiddenFields.State, + RedirectURI: hiddenFields.RedirectURI, + }, + ) + if serviceErr != nil { + return serviceErrorHTMLResponse(logger, ctx, serviceErr) + } + + logResponse(logger, ctx, fiber.StatusOK) + return ctx. + Status(fiber.StatusOK). + Type("html"). + SendString(twoFAHTML) + } + + redirectURL, sessionKey, serviceErr := c.services.OAuthDynamicRegistrationIATVerify2FACode( + ctx.UserContext(), + services.OAuthDynamicRegistrationIATVerify2FACodeOptions{ + RequestID: requestID, + ACCClientID: uPrms.ACCClientID, + Domain: hiddenFields.ClientID, + SessionID: sessionID, + CSRFToken: hiddenFields.CSRFToken, + Code: twoFABody.Code, + BackendDomain: c.backendDomain, + }, + ) + if serviceErr != nil { + if serviceErr.Code == exceptions.CodeUnauthorized { + twoFAHTML, serviceErr := c.services.OAuthDynamicRegistrationIAT2FAReRender( + ctx.UserContext(), + services.OAuthDynamicRegistrationIAT2FAReRenderOptions{ + RequestID: requestID, + Domain: hiddenFields.ClientID, + ACCClientID: uPrms.ACCClientID, + SessionID: sessionID, + Errors: []string{"Invalid 2FA code"}, + CSRFToken: hiddenFields.CSRFToken, + Challenge: hiddenFields.CodeChallenge, + ChallengeMethod: hiddenFields.CodeChallengeMethod, + State: hiddenFields.State, + RedirectURI: hiddenFields.RedirectURI, + }, + ) + if serviceErr != nil { + return serviceErrorHTMLResponse(logger, ctx, serviceErr) + } + + logResponse(logger, ctx, fiber.StatusOK) + return ctx. + Status(fiber.StatusOK). + Type("html"). + SendString(twoFAHTML) + } + + return serviceErrorHTMLResponse(logger, ctx, serviceErr) + } + + c.removeAccountIAT2FACookie(ctx, uPrms.ACCClientID) + c.saveAccountIATCookie(ctx, sessionKey) + logResponse(logger, ctx, fiber.StatusSeeOther) + return ctx.Redirect(redirectURL, fiber.StatusSeeOther) +} + +func (c *Controllers) OAuthDynamicRegistrationIATToken(ctx *fiber.Ctx) error { + requestID := getRequestID(ctx) + logger := c.buildLogger(requestID, oauthDynamicRegistration, "OAuthDynamicRegistrationIATToken") + logRequest(logger, ctx) + + if ctx.Get("Content-Type") != "application/x-www-form-urlencoded" { + return oauthErrorResponse(logger, ctx, exceptions.OAuthErrorInvalidRequest) + } + + grantType := ctx.Get("grant_type") + if grantType != "authorization_code" { + return oauthErrorResponse(logger, ctx, exceptions.OAuthErrorUnsupportedGrantType) + } + + body := bodies.OAuthDynamicRegistrationIATTokenBody{ + GrantType: grantType, + Code: ctx.FormValue("code"), + ClientID: ctx.FormValue("client_id"), + CodeVerifier: ctx.FormValue("code_verifier"), + } + if err := c.validate.StructCtx(ctx.UserContext(), &body); err != nil { + return oauthErrorResponse(logger, ctx, exceptions.OAuthErrorInvalidRequest) + } + + authDTO, serviceErr := c.services.VerifyOAuthDynamicRegistrationIATCode( + ctx.UserContext(), + services.VerifyOAuthDynamicRegistrationIATCodeOptions{ + RequestID: requestID, + Code: body.Code, + CodeVerifier: body.CodeVerifier, + Domain: body.ClientID, + }, + ) + if serviceErr != nil { + return oauthErrorResponseMapper(logger, ctx, serviceErr) + } + + logResponse(logger, ctx, fiber.StatusOK) + return ctx.Status(fiber.StatusOK).JSON(authDTO) +} diff --git a/idp/internal/controllers/params/oauth_dynamic_registration.go b/idp/internal/controllers/params/oauth_dynamic_registration.go index 1cdc583..803c13f 100644 --- a/idp/internal/controllers/params/oauth_dynamic_registration.go +++ b/idp/internal/controllers/params/oauth_dynamic_registration.go @@ -6,18 +6,18 @@ package params +type OAuthDynamicRegistrationIATAuthBaseQueryParams struct { + ClientID string `validate:"required,fqdn"` + RedirectURI string `validate:"required,uri"` +} + type OAuthDynamicRegistrationIATAuthQueryParams struct { - ClientID string `validate:"required,fqdn"` ResponseType string `validate:"required,oneof=code"` Challenge string `validate:"required,min=1"` ChallengeMethod string `validate:"omitempty,oneof=plain s256 S256"` State string `validate:"required,min=1"` - RedirectURI string `validate:"required,uri"` } -type OAuthDynamicRegistrationIATAuthLoginGetQueryParams struct { - Challenge string `validate:"required,min=1"` - ChallengeMethod string `validate:"omitempty,oneof=plain s256 S256"` - RedirectURI string `validate:"required,url"` - State string `validate:"required,min=1"` +type OAuthDynamicRegistrationIATAuthURLParams struct { + ACCClientID string `validate:"required,min=22,max=22,alphanum"` } diff --git a/idp/internal/controllers/paths/dynamic_registration.go b/idp/internal/controllers/paths/dynamic_registration.go index 35172f9..6395527 100644 --- a/idp/internal/controllers/paths/dynamic_registration.go +++ b/idp/internal/controllers/paths/dynamic_registration.go @@ -7,6 +7,8 @@ package paths const ( - DynamicRegistrationBase string = "/dynamic-registration" - InitialAccessToken string = "/initial-access-token" + DynamicRegistrationBase string = "/dynamic-registration" + InitialAccessToken string = "/initial-access-token" + InitialAccessTokenAuthEXT string = "/ext" + InitialAccessTokenSingle string = "/:accClientID" ) diff --git a/idp/internal/exceptions/controllers.go b/idp/internal/exceptions/controllers.go index ede05e6..54d0742 100644 --- a/idp/internal/exceptions/controllers.go +++ b/idp/internal/exceptions/controllers.go @@ -24,13 +24,14 @@ const ( StatusForbidden string = "Forbidden" StatusValidation string = "Validation" - OAuthErrorInvalidRequest string = "invalid_request" - OAuthErrorInvalidGrant string = "invalid_grant" - OAuthErrorUnauthorizedClient string = "unauthorized_client" - OAuthErrorAccessDenied string = "access_denied" - OAuthServerError string = "server_error" - OAuthErrorInvalidScope string = "invalid_scope" - OAuthErrorUnsupportedGrantType string = "unsupported_grant_type" + OAuthErrorInvalidRequest string = "invalid_request" + OAuthErrorInvalidGrant string = "invalid_grant" + OAuthErrorUnauthorizedClient string = "unauthorized_client" + OAuthErrorAccessDenied string = "access_denied" + OAuthServerError string = "server_error" + OAuthErrorInvalidScope string = "invalid_scope" + OAuthErrorUnsupportedGrantType string = "unsupported_grant_type" + OAuthErrorUnsupportedResponseType string = "unsupported_response_type" ) type ErrorResponse struct { diff --git a/idp/internal/providers/cache/account_credentials_dynamic_registration.go b/idp/internal/providers/cache/account_credentials_dynamic_registration.go index 41bcb1a..460a8a9 100644 --- a/idp/internal/providers/cache/account_credentials_dynamic_registration.go +++ b/idp/internal/providers/cache/account_credentials_dynamic_registration.go @@ -33,7 +33,6 @@ func buildAccountCredentialsDynamicRegistrationIATLoginCacheKey(clientID string) type AccountCredentialsDynamicRegistrationIATLoginData struct { RedirectURI string `json:"redirect_uri"` - Challenge string `json:"challenge"` CSRFToken string `json:"csrf_token"` Domain string `json:"domain"` State string `json:"state"` @@ -42,7 +41,6 @@ type AccountCredentialsDynamicRegistrationIATLoginData struct { type SaveAccountCredentialsDynamicRegistrationIATLoginOptions struct { Domain string RequestID string - Challenge string State string RedirectURI string } @@ -67,7 +65,6 @@ func (c *Cache) SaveAccountCredentialsDynamicRegistrationIATLogin( } data := AccountCredentialsDynamicRegistrationIATLoginData{ - Challenge: opts.Challenge, State: opts.State, Domain: opts.Domain, RedirectURI: opts.RedirectURI, @@ -150,7 +147,6 @@ type AccountCredentialsDynamicRegistrationIAT2FAData struct { AccountPublicID uuid.UUID `json:"account_public_id"` AccountVersion int32 `json:"account_version"` RedirectURI string `json:"redirect_uri"` - Challenge string `json:"challenge"` ClientID string `json:"clientId"` Domain string `json:"domain"` State string `json:"state"` @@ -167,7 +163,6 @@ type SaveAccountCredentialsDynamicRegistrationIAT2FAOptions struct { RedirectURI string Domain string ClientID string - Challenge string State string TwoFATTL int64 } @@ -193,7 +188,6 @@ func (c *Cache) SaveAccountCredentialsDynamicRegistrationIAT2FA( RedirectURI: opts.RedirectURI, Domain: opts.Domain, ClientID: opts.ClientID, - Challenge: opts.Challenge, State: opts.State, } dataBytes, err := json.Marshal(data) @@ -436,7 +430,7 @@ func (c *Cache) VerifyAccountCredentialsRegistrationIATCode( ) (AccountCredentialsDynamicRegistrationIATCodeData, bool, error) { logger := utils.BuildLogger(c.logger, utils.LoggerOptions{ Location: accountCredentialsDynamicRegistrationLocation, - Method: "VerifyAccountCredentialsRegistrationIATCode", + Method: "VerifyOAuthDynamicRegistrationIATCode", RequestID: opts.RequestID, }) logger.DebugContext(ctx, "Verifying account credentials registration IAT code...") diff --git a/idp/internal/services/account_credentials_registration_iat.go b/idp/internal/services/account_credentials_registration_iat.go index 05fd925..4cf26d7 100644 --- a/idp/internal/services/account_credentials_registration_iat.go +++ b/idp/internal/services/account_credentials_registration_iat.go @@ -12,7 +12,6 @@ import ( "github.com/google/uuid" "github.com/tugascript/devlogs/idp/internal/exceptions" - "github.com/tugascript/devlogs/idp/internal/providers/cache" "github.com/tugascript/devlogs/idp/internal/providers/crypto" "github.com/tugascript/devlogs/idp/internal/providers/database" "github.com/tugascript/devlogs/idp/internal/providers/tokens" @@ -86,89 +85,3 @@ func (s *Services) CreateAccountCredentialsRegistrationIAT( logger.InfoContext(ctx, "Created account credentials registration IAT successfully") return signedToken, nil } - -type VerifyAccountCredentialsRegistrationIATCodeOptions struct { - RequestID string - Code string - CodeVerifier string -} - -func (s *Services) VerifyAccountCredentialsRegistrationIATCode( - ctx context.Context, - opts VerifyAccountCredentialsRegistrationIATCodeOptions, -) (string, *exceptions.ServiceError) { - logger := s.buildLogger(opts.RequestID, accountCredentialsRegistrationIATLocation, "VerifyAccountCredentialsRegistrationIATCode") - logger.InfoContext(ctx, "Verifying account credentials registration IAT code...") - - data, found, err := s.cache.VerifyAccountCredentialsRegistrationIATCode(ctx, cache.VerifyAccountCredentialsRegistrationIATCodeOptions{ - RequestID: opts.RequestID, - Code: opts.Code, - }) - if err != nil { - logger.ErrorContext(ctx, "Failed to verify account credentials registration IAT code", "error", err) - return "", exceptions.NewInternalServerError() - } - if !found { - logger.DebugContext(ctx, "Account credentials registration IAT code not found or invalid") - return "", exceptions.NewUnauthorizedError() - } - - ok, err := utils.CompareShaBase64(data.Challenge, opts.CodeVerifier) - if err != nil { - logger.ErrorContext(ctx, "Failed to compare challenge", "error", err) - return "", exceptions.NewInternalServerError() - } - if !ok { - logger.WarnContext(ctx, "OAuth Code challenge verification failed") - return "", exceptions.NewUnauthorizedError() - } - - accountDTO, serviceErr := s.GetAccountByPublicIDAndVersion(ctx, GetAccountByPublicIDAndVersionOptions{ - RequestID: opts.RequestID, - PublicID: data.AccountPublicID, - Version: data.AccountVersion, - }) - if serviceErr != nil { - logger.ErrorContext(ctx, "Failed to get account", "serviceError", serviceErr) - return "", serviceErr - } - - domainDTO, serviceErr := s.GetAccountCredentialsRegistrationDomain(ctx, GetAccountCredentialsRegistrationDomainOptions{ - RequestID: opts.RequestID, - AccountPublicID: accountDTO.PublicID, - Domain: data.Domain, - }) - if serviceErr != nil { - logger.ErrorContext(ctx, "Failed to get account credentials registration domain", "serviceError", serviceErr) - return "", serviceErr - } - if !domainDTO.Verified { - logger.ErrorContext(ctx, "Account credentials registration domain is not verified") - return "", exceptions.NewValidationError("account credentials registration domain is not verified") - } - - signedToken, serviceErr := s.crypto.SignToken(ctx, crypto.SignTokenOptions{ - RequestID: opts.RequestID, - Token: s.jwt.CreateAccountCredentialsDynamicRegistrationToken(tokens.AccountCredentialsDynamicRegistrationTokenOptions{ - AccountPublicID: accountDTO.PublicID, - AccountVersion: accountDTO.Version(), - Domain: data.Domain, - ClientID: data.ClientID, - }), - GetJWKfn: s.BuildGetGlobalEncryptedJWKFn(ctx, BuildEncryptedJWKFnOptions{ - RequestID: opts.RequestID, - KeyType: database.TokenKeyTypeDynamicRegistration, - TTL: s.jwt.GetDynamicRegistrationTTL(), - }), - GetDecryptDEKfn: s.BuildGetGlobalDecDEKFn(ctx, opts.RequestID), - GetEncryptDEKfn: s.BuildGetEncGlobalDEKFn(ctx, opts.RequestID), - StoreFN: s.BuildUpdateJWKDEKFn(ctx, opts.RequestID), - }) - if serviceErr != nil { - logger.ErrorContext(ctx, "Failed to sign account credentials registration IAT", "serviceError", serviceErr) - return "", serviceErr - } - - logger.InfoContext(ctx, "Verified account credentials registration IAT code successfully") - return signedToken, nil -} diff --git a/idp/internal/services/dtos/initial_access_token.go b/idp/internal/services/dtos/initial_access_token.go new file mode 100644 index 0000000..56e91a3 --- /dev/null +++ b/idp/internal/services/dtos/initial_access_token.go @@ -0,0 +1,7 @@ +// Copyright (c) 2025 Afonso Barracha +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +package dtos diff --git a/idp/internal/services/oauth_dynamic_registration.go b/idp/internal/services/oauth_dynamic_registration.go index 378e23d..c5ee3bb 100644 --- a/idp/internal/services/oauth_dynamic_registration.go +++ b/idp/internal/services/oauth_dynamic_registration.go @@ -12,6 +12,8 @@ import ( "slices" "github.com/google/uuid" + "github.com/tugascript/devlogs/idp/internal/providers/crypto" + "github.com/tugascript/devlogs/idp/internal/services/dtos" "github.com/tugascript/devlogs/idp/internal/controllers/paths" "github.com/tugascript/devlogs/idp/internal/exceptions" @@ -386,17 +388,10 @@ func (s *Services) OAuthDynamicRegistrationIATAuthRender( ) logger.InfoContext(ctx, "Starting OAuth dynamic registration IAT authorization html render...") - hashedChallenge, serviceErr := hashChallenge(opts.CodeChallenge, opts.CodeChallengeMethod) - if serviceErr != nil { - logger.ErrorContext(ctx, "Invalid code challenge", "serviceError", serviceErr) - return "", serviceErr - } - clientID, csrfToken, err := s.cache.SaveAccountCredentialsDynamicRegistrationIATLogin( ctx, cache.SaveAccountCredentialsDynamicRegistrationIATLoginOptions{ RequestID: opts.RequestID, - Challenge: hashedChallenge, State: opts.State, RedirectURI: opts.RedirectURI, Domain: opts.Domain, @@ -431,6 +426,57 @@ func (s *Services) OAuthDynamicRegistrationIATAuthRender( return loginHTML, nil } +type OAuthDynamicRegistrationIATAuthReRenderOptions struct { + RequestID string + Errors []string + CSRFToken string + ACCClientID string + State string + Domain string + CodeChallenge string + CodeChallengeMethod string + RedirectURI string +} + +func (s *Services) OAuthDynamicRegistrationIATAuthReRender( + ctx context.Context, + opts OAuthDynamicRegistrationIATAuthReRenderOptions, +) (string, *exceptions.ServiceError) { + logger := s.buildLogger( + opts.RequestID, + oauthDynamicRegistrationLocation, + "OAuthDynamicRegistrationIATAuthReRender", + ).With( + "clientId", opts.ACCClientID, + "redirectUri", opts.RedirectURI, + ) + logger.InfoContext(ctx, "Re-rendering OAuth dynamic registration IAT authorization html...") + + loginHTML, err := templates.BuildAccountDynamicRegistrationIATAuthTemplate( + templates.AccountDynamicRegistrationIATAuthOptions{ + Errors: opts.Errors, + ACCClientID: opts.ACCClientID, + CSRFToken: opts.CSRFToken, + Domain: opts.Domain, + State: opts.State, + CodeChallenge: opts.CodeChallenge, + CodeChallengeMethod: opts.CodeChallengeMethod, + RedirectURI: opts.RedirectURI, + AppleEnabled: s.oauthProviders.IsAppleEnabled(), + FacebookEnabled: s.oauthProviders.IsFacebookEnabled(), + GitHubEnabled: s.oauthProviders.IsGitHubEnabled(), + GoogleEnabled: s.oauthProviders.IsGoogleEnabled(), + MicrosoftEnabled: s.oauthProviders.IsMicrosoftEnabled(), + }, + ) + if err != nil { + logger.ErrorContext(ctx, "Failed to build account dynamic registration IAT auth template", "error", err) + return "", exceptions.NewInternalServerError() + } + + return loginHTML, nil +} + type OAuthDynamicRegistrationIATLoginOptions struct { RequestID string ACCClientID string @@ -465,7 +511,7 @@ func (s *Services) OAuthDynamicRegistrationIATLogin( } if !found { logger.ErrorContext(ctx, "Account credentials dynamic registration IAT not found") - return "", "", false, exceptions.NewNotFoundValidationError("invalid client ID") + return "", "", false, exceptions.NewNotFoundError() } if data.Domain != opts.Domain { @@ -535,7 +581,6 @@ func (s *Services) OAuthDynamicRegistrationIATLogin( RedirectURI: opts.RedirectURI, Domain: data.Domain, ClientID: opts.ACCClientID, - Challenge: data.Challenge, State: data.State, TwoFATTL: s.jwt.Get2FATTL(), }, @@ -588,7 +633,7 @@ func (s *Services) OAuthDynamicRegistrationIATLogin( if opts.CodeChallengeMethod != "" { queryParams.Add("code_challenge_method", opts.CodeChallengeMethod) } - return oauthDynamicRegistrationIATPath + "/" + opts.ACCClientID + + return oauthDynamicRegistrationIATPath + "/" + opts.ACCClientID + paths.OAuthAuth + paths.AuthLogin + paths.Auth2FA + queryParams.Encode(), sessionID, false, nil } @@ -640,20 +685,22 @@ func (s *Services) OAuthDynamicRegistrationIATLogin( return oauthDynamicRegistrationIATAuthPath, sessionKey, true, nil } -type OAuthDynamicRegistrationIAT2FAOptions struct { - RequestID string - Domain string - ACCClientID string - SessionID string - CodeChallenge string - CodeChallengeMethod string +type OAuthDynamicRegistrationIAT2FARenderOptions struct { + RequestID string + Domain string + ACCClientID string + SessionID string + Challenge string + ChallengeMethod string + State string + RedirectURI string } -func (s *Services) OAuthDynamicRegistrationIAT2FA( +func (s *Services) OAuthDynamicRegistrationIAT2FARender( ctx context.Context, - opts OAuthDynamicRegistrationIAT2FAOptions, + opts OAuthDynamicRegistrationIAT2FARenderOptions, ) (string, *exceptions.ServiceError) { - logger := s.buildLogger(opts.RequestID, oauthDynamicRegistrationLocation, "OAuthDynamicRegistrationIAT2FA").With( + logger := s.buildLogger(opts.RequestID, oauthDynamicRegistrationLocation, "OAuthDynamicRegistrationIAT2FARender").With( "clientId", opts.ACCClientID, ) logger.InfoContext(ctx, "Handling OAuth dynamic registration IAT 2FA...") @@ -670,10 +717,23 @@ func (s *Services) OAuthDynamicRegistrationIAT2FA( logger.WarnContext(ctx, "Failed to get account credentials dynamic registration IAT 2FA") return "", exceptions.NewUnauthorizedError() } + if opts.ACCClientID != data.ClientID { logger.WarnContext(ctx, "Client IDs do not match", "sessionClientId", data.ClientID) return "", exceptions.NewUnauthorizedError() } + if data.Domain != opts.Domain { + logger.WarnContext(ctx, "OAuth Domain does not match", "dataDomain", data.Domain) + return "", exceptions.NewUnauthorizedError() + } + if data.State != opts.State { + logger.WarnContext(ctx, "OAuth State does not match") + return "", exceptions.NewUnauthorizedError() + } + if data.RedirectURI != opts.RedirectURI { + logger.WarnContext(ctx, "OAuth Redirect URI does not match") + return "", exceptions.NewUnauthorizedError() + } csrfToken, err := s.cache.SaveAccountCredentialsDynamicRegistrationIAT2FACSRFToken( ctx, @@ -690,12 +750,13 @@ func (s *Services) OAuthDynamicRegistrationIAT2FA( twoFAhtml, err := templates.BuildAccountDynamicRegistrationIAT2FATemplate( templates.AccountDynamicRegistrationIAT2FAOptions{ - ClientID: opts.ACCClientID, + ACCClientID: opts.ACCClientID, + Domain: opts.Domain, SessionID: opts.SessionID, CSRFToken: csrfToken, State: data.State, - CodeChallenge: opts.CodeChallenge, - CodeChallengeMethod: opts.CodeChallengeMethod, + CodeChallenge: opts.Challenge, + CodeChallengeMethod: opts.ChallengeMethod, RedirectURI: data.RedirectURI, }, ) @@ -707,16 +768,57 @@ func (s *Services) OAuthDynamicRegistrationIAT2FA( return twoFAhtml, nil } +type OAuthDynamicRegistrationIAT2FAReRenderOptions struct { + RequestID string + Domain string + ACCClientID string + SessionID string + Errors []string + CSRFToken string + Challenge string + ChallengeMethod string + State string + RedirectURI string +} + +func (s *Services) OAuthDynamicRegistrationIAT2FAReRender( + ctx context.Context, + opts OAuthDynamicRegistrationIAT2FAReRenderOptions, +) (string, *exceptions.ServiceError) { + logger := s.buildLogger(opts.RequestID, oauthDynamicRegistrationLocation, "OAuthDynamicRegistrationIAT2FAReRender").With( + "clientId", opts.ACCClientID, + ) + logger.InfoContext(ctx, "Re-rendering OAuth dynamic registration IAT 2FA...") + + twoFAhtml, err := templates.BuildAccountDynamicRegistrationIAT2FATemplate( + templates.AccountDynamicRegistrationIAT2FAOptions{ + Errors: opts.Errors, + ACCClientID: opts.ACCClientID, + Domain: opts.Domain, + SessionID: opts.SessionID, + CSRFToken: opts.CSRFToken, + State: opts.State, + CodeChallenge: opts.Challenge, + CodeChallengeMethod: opts.ChallengeMethod, + RedirectURI: opts.RedirectURI, + }, + ) + if err != nil { + logger.ErrorContext(ctx, "Failed to build account dynamic registration IAT 2FA template", "error", err) + return "", exceptions.NewInternalServerError() + } + + return twoFAhtml, nil +} + type OAuthDynamicRegistrationIATVerify2FACodeOptions struct { - RequestID string - ACCClientID string - Domain string - SessionID string - CSRFToken string - CodeChallenge string - CodeChallengeMethod string - BackendDomain string - Code string + RequestID string + ACCClientID string + Domain string + SessionID string + CSRFToken string + Code string + BackendDomain string } func (s *Services) OAuthDynamicRegistrationIATVerify2FACode( @@ -763,16 +865,6 @@ func (s *Services) OAuthDynamicRegistrationIATVerify2FACode( return "", "", exceptions.NewForbiddenError() } - hashedChallenge, serviceErr := hashChallenge(opts.CodeChallenge, opts.CodeChallengeMethod) - if serviceErr != nil { - logger.ErrorContext(ctx, "Invalid code challenge", "serviceErr", serviceErr) - return "", "", exceptions.NewInternalServerError() - } - if hashedChallenge != data.Challenge { - logger.WarnContext(ctx, "OAuth Code challenge verification failed") - return "", "", exceptions.NewUnauthorizedError() - } - accountDTO, serviceErr := s.GetAccountByPublicIDAndVersion(ctx, GetAccountByPublicIDAndVersionOptions{ RequestID: opts.RequestID, PublicID: data.AccountPublicID, @@ -834,3 +926,96 @@ func (s *Services) OAuthDynamicRegistrationIATVerify2FACode( return oauthDynamicRegistrationIATAuthPath, sessionKey, nil } + +type VerifyOAuthDynamicRegistrationIATCodeOptions struct { + RequestID string + Code string + CodeVerifier string + Domain string +} + +func (s *Services) VerifyOAuthDynamicRegistrationIATCode( + ctx context.Context, + opts VerifyOAuthDynamicRegistrationIATCodeOptions, +) (dtos.AuthDTO, *exceptions.ServiceError) { + logger := s.buildLogger(opts.RequestID, accountCredentialsRegistrationIATLocation, "VerifyOAuthDynamicRegistrationIATCode") + logger.InfoContext(ctx, "Verifying account credentials registration IAT code...") + + data, found, err := s.cache.VerifyAccountCredentialsRegistrationIATCode(ctx, cache.VerifyAccountCredentialsRegistrationIATCodeOptions{ + RequestID: opts.RequestID, + Code: opts.Code, + }) + if err != nil { + logger.ErrorContext(ctx, "Failed to verify account credentials registration IAT code", "error", err) + return dtos.AuthDTO{}, exceptions.NewInternalServerError() + } + if !found { + logger.DebugContext(ctx, "Account credentials registration IAT code not found or invalid") + return dtos.AuthDTO{}, exceptions.NewUnauthorizedError() + } + + if data.Domain != opts.Domain { + logger.WarnContext(ctx, "OAuth Domain does not match", "dataDomain", data.Domain) + return dtos.AuthDTO{}, exceptions.NewUnauthorizedError() + } + + ok, err := utils.CompareShaBase64(data.Challenge, opts.CodeVerifier) + if err != nil { + logger.ErrorContext(ctx, "Failed to compare challenge", "error", err) + return dtos.AuthDTO{}, exceptions.NewInternalServerError() + } + if !ok { + logger.WarnContext(ctx, "OAuth Code challenge verification failed") + return dtos.AuthDTO{}, exceptions.NewUnauthorizedError() + } + + accountDTO, serviceErr := s.GetAccountByPublicIDAndVersion(ctx, GetAccountByPublicIDAndVersionOptions{ + RequestID: opts.RequestID, + PublicID: data.AccountPublicID, + Version: data.AccountVersion, + }) + if serviceErr != nil { + logger.ErrorContext(ctx, "Failed to get account", "serviceError", serviceErr) + return dtos.AuthDTO{}, serviceErr + } + + domainDTO, serviceErr := s.GetAccountCredentialsRegistrationDomain(ctx, GetAccountCredentialsRegistrationDomainOptions{ + RequestID: opts.RequestID, + AccountPublicID: accountDTO.PublicID, + Domain: data.Domain, + }) + if serviceErr != nil { + logger.ErrorContext(ctx, "Failed to get account credentials registration domain", "serviceError", serviceErr) + return dtos.AuthDTO{}, serviceErr + } + if !domainDTO.Verified { + logger.ErrorContext(ctx, "Account credentials registration domain is not verified") + return dtos.AuthDTO{}, exceptions.NewValidationError("account credentials registration domain is not verified") + } + + tokenTTL := s.jwt.GetDynamicRegistrationTTL() + signedToken, serviceErr := s.crypto.SignToken(ctx, crypto.SignTokenOptions{ + RequestID: opts.RequestID, + Token: s.jwt.CreateAccountCredentialsDynamicRegistrationToken(tokens.AccountCredentialsDynamicRegistrationTokenOptions{ + AccountPublicID: accountDTO.PublicID, + AccountVersion: accountDTO.Version(), + Domain: data.Domain, + ClientID: data.ClientID, + }), + GetJWKfn: s.BuildGetGlobalEncryptedJWKFn(ctx, BuildEncryptedJWKFnOptions{ + RequestID: opts.RequestID, + KeyType: database.TokenKeyTypeDynamicRegistration, + TTL: tokenTTL, + }), + GetDecryptDEKfn: s.BuildGetGlobalDecDEKFn(ctx, opts.RequestID), + GetEncryptDEKfn: s.BuildGetEncGlobalDEKFn(ctx, opts.RequestID), + StoreFN: s.BuildUpdateJWKDEKFn(ctx, opts.RequestID), + }) + if serviceErr != nil { + logger.ErrorContext(ctx, "Failed to sign account credentials registration IAT", "serviceError", serviceErr) + return dtos.AuthDTO{}, serviceErr + } + + logger.InfoContext(ctx, "Verified account credentials registration IAT code successfully") + return dtos.NewAuthDTO(signedToken, tokenTTL), nil +} diff --git a/idp/internal/services/templates/account_dynamic_registration.go b/idp/internal/services/templates/account_dynamic_registration.go index 71b5d33..0eab5dc 100644 --- a/idp/internal/services/templates/account_dynamic_registration.go +++ b/idp/internal/services/templates/account_dynamic_registration.go @@ -11,6 +11,7 @@ import ( "fmt" "html/template" "net/url" + "strings" "github.com/tugascript/devlogs/idp/internal/controllers/paths" ) @@ -152,7 +153,7 @@ const accountDynamicRegistrationBaseTemplate = ` align-items: center; gap: 0.75rem; position: relative; - width: 100%; + width: 100%%; } .oauth-button-text { @@ -256,6 +257,12 @@ const accountDynamicRegistrationBaseTemplate = ` text-align: center; color: #222; letter-spacing: 1px; + } + + #form-errors { + color: #C62828; + text-align: center; + margin-bottom: 1em; } {{.Title}} @@ -290,16 +297,27 @@ func buildEntryAccountDynamicRegistrationTemplate(body string) string { const baseAccountLoginTitle = "Account Login" +const formErrors = ` +
+ %s +
+` + +func buildFormErrors(errors []string) string { + return fmt.Sprintf(formErrors, strings.Join(errors, "\n")) +} + const loginForm = ` + - - + + ` @@ -394,6 +412,7 @@ const accountDynamicRegistrationLoginTemplateName = "login" type accountDynamicRegistrationLoginTemplateData struct { Title string Header string + ClientID string LoginURL string AppleLoginURL string FacebookLoginURL string @@ -415,6 +434,7 @@ type AccountDynamicRegistrationIATAuthOptions struct { CodeChallenge string CodeChallengeMethod string RedirectURI string + Errors []string AppleEnabled bool FacebookEnabled bool GitHubEnabled bool @@ -423,58 +443,61 @@ type AccountDynamicRegistrationIATAuthOptions struct { } func BuildAccountDynamicRegistrationIATAuthTemplate(opts AccountDynamicRegistrationIATAuthOptions) (string, error) { + baseTemplateBody := "" + if len(opts.Errors) > 0 { + baseTemplateBody += buildFormErrors(opts.Errors) + } + baseURL := paths.V1 + paths.AccountsBase + paths.CredentialsBase + paths.DynamicRegistrationBase + - paths.InitialAccessToken + "/" + opts.ACCClientID + paths.InitialAccessToken + "/" + opts.ACCClientID + paths.OAuthAuth data := accountDynamicRegistrationLoginTemplateData{ Title: baseAccountLoginTitle, Header: "OAuth Dynamic Client Registration Initial Access Token Login", LoginURL: baseURL + paths.AuthLogin, RedirectURI: opts.RedirectURI, + ClientID: opts.Domain, CodeChallenge: opts.CodeChallenge, CodeChallengeMethod: opts.CodeChallengeMethod, State: opts.State, CSRFToken: opts.CSRFToken, } - baseTemplateBody := loginForm + divider - - urlParams := make(url.Values) - urlParams.Add("client_id", opts.Domain) - urlParams.Add("response_type", "code") - urlParams.Add("state", opts.State) - urlParams.Add("code_challenge", opts.CodeChallenge) - if opts.CodeChallengeMethod != "" { - urlParams.Add("code_challenge_method", opts.CodeChallengeMethod) - } - urlParams.Add("redirect_uri", opts.RedirectURI) - - if opts.AppleEnabled { - urlParams.Add("client_id", "apple") - data.AppleLoginURL = baseURL + paths.OAuthAuth + "?" + urlParams.Encode() - baseTemplateBody += appleLoginButton - } - if opts.FacebookEnabled { - urlParams.Del("client_id") - urlParams.Add("client_id", "facebook") - data.FacebookLoginURL = baseURL + paths.OAuthAuth + "?" + urlParams.Encode() - baseTemplateBody += facebookLoginButton - } - if opts.GitHubEnabled { - urlParams.Del("client_id") - urlParams.Add("client_id", "github") - data.GithubLoginURL = baseURL + paths.OAuthAuth + "?" + urlParams.Encode() - baseTemplateBody += githubLoginButton - } - if opts.GoogleEnabled { - urlParams.Del("client_id") - urlParams.Add("client_id", "google") - data.GoogleLoginURL = baseURL + paths.OAuthAuth + "?" + urlParams.Encode() - baseTemplateBody += googleLoginButton - } - if opts.MicrosoftEnabled { - urlParams.Del("client_id") - urlParams.Add("client_id", "microsoft") - data.MicrosoftLoginURL = baseURL + paths.OAuthAuth + "?" + urlParams.Encode() - baseTemplateBody += microsoftLoginButton + baseTemplateBody += loginForm + + if opts.AppleEnabled || opts.FacebookEnabled || opts.GitHubEnabled || opts.GoogleEnabled || opts.MicrosoftEnabled { + baseTemplateBody += divider + extAuthURL := baseURL + paths.InitialAccessTokenAuthEXT + + // Common URL parameters for all OAuth providers + urlParams := make(url.Values) + urlParams.Add("client_id", opts.Domain) + urlParams.Add("response_type", "code") + urlParams.Add("state", opts.State) + urlParams.Add("code_challenge", opts.CodeChallenge) + if opts.CodeChallengeMethod != "" { + urlParams.Add("code_challenge_method", opts.CodeChallengeMethod) + } + urlParams.Add("redirect_uri", opts.RedirectURI) + + if opts.AppleEnabled { + data.AppleLoginURL = extAuthURL + "/apple" + "?" + urlParams.Encode() + baseTemplateBody += appleLoginButton + } + if opts.FacebookEnabled { + data.FacebookLoginURL = extAuthURL + "/facebook" + "?" + urlParams.Encode() + baseTemplateBody += facebookLoginButton + } + if opts.GitHubEnabled { + data.GithubLoginURL = extAuthURL + "/github" + "?" + urlParams.Encode() + baseTemplateBody += githubLoginButton + } + if opts.GoogleEnabled { + data.GoogleLoginURL = extAuthURL + "/google" + "?" + urlParams.Encode() + baseTemplateBody += googleLoginButton + } + if opts.MicrosoftEnabled { + data.MicrosoftLoginURL = extAuthURL + "/microsoft" + "?" + urlParams.Encode() + baseTemplateBody += microsoftLoginButton + } } loginTemplate := buildEntryAccountDynamicRegistrationTemplate(baseTemplateBody) @@ -604,6 +627,12 @@ const twoFaTemplate = ` text-align: center; color: #222; letter-spacing: 1px; + } + + #form-errors { + color: #C62828; + text-align: center; + margin-bottom: 1em; } Title @@ -627,17 +656,21 @@ const twoFaTemplate = `

Two-Factor Authentication

- + %s
- - - - - + + + + + + + + 0 { + errDiv = buildFormErrors(opts.Errors) + } + data := accountDynamicRegistrationIAT2FAData{ - TwoFAURL: paths.AccountsBase + paths.CredentialsBase + paths.DynamicRegistrationBase + - paths.InitialAccessToken + "/" + opts.ClientID + paths.AuthLogin + paths.Auth2FA, + TwoFAURL: paths.V1 + paths.AccountsBase + paths.CredentialsBase + paths.DynamicRegistrationBase + + paths.InitialAccessToken + "/" + opts.ACCClientID + paths.OAuthAuth + paths.AuthLogin + paths.Auth2FA, + ClientID: opts.Domain, RedirectURI: opts.RedirectURI, CodeChallenge: opts.CodeChallenge, CodeChallengeMethod: opts.CodeChallengeMethod, @@ -684,7 +726,7 @@ func BuildAccountDynamicRegistrationIAT2FATemplate(opts AccountDynamicRegistrati SessionID: opts.SessionID, } - t, err := template.New("two_fa").Parse(twoFaTemplate) + t, err := template.New("two_fa").Parse(fmt.Sprintf(twoFaTemplate, errDiv)) if err != nil { return "", nil } diff --git a/idp/internal/services/templates/login.html b/idp/internal/services/templates/login.html index d50764a..b4ca881 100644 --- a/idp/internal/services/templates/login.html +++ b/idp/internal/services/templates/login.html @@ -239,6 +239,12 @@ color: #222; letter-spacing: 1px; } + + #form-errors { + color: #C62828; + text-align: center; + margin-bottom: 1em; + } Login - DevLogs @@ -263,6 +269,10 @@

Welcome back {{.Name}}

+
+

Invalid credentials

+
+ @@ -270,8 +280,8 @@

Welcome back {{.Name}}

- - + +
From c5a3400b4653748df317ce803c327af29cdb0488 Mon Sep 17 00:00:00 2001 From: Afonso Barracha Date: Mon, 1 Sep 2025 23:10:46 +1200 Subject: [PATCH 14/23] feat(idp): add iat code flow routes --- idp/internal/controllers/auth.go | 4 +- .../controllers/oauth_dynamic_registration.go | 18 +- ...ccount_credentials_dynamic_registration.go | 138 +++++++-- ...ccount_dynamic_registration_domains.sql.go | 68 +++++ .../account_dynamic_registration_domains.sql | 26 ++ .../routes/account_dynamic_registration.go | 13 + .../services/oauth_dynamic_registration.go | 265 ++++++++++++++---- 7 files changed, 445 insertions(+), 87 deletions(-) diff --git a/idp/internal/controllers/auth.go b/idp/internal/controllers/auth.go index 075a725..43acee7 100644 --- a/idp/internal/controllers/auth.go +++ b/idp/internal/controllers/auth.go @@ -8,10 +8,10 @@ package controllers import ( "github.com/gofiber/fiber/v2" - "github.com/tugascript/devlogs/idp/internal/controllers/paths" "github.com/tugascript/devlogs/idp/internal/controllers/bodies" "github.com/tugascript/devlogs/idp/internal/controllers/params" + "github.com/tugascript/devlogs/idp/internal/controllers/paths" "github.com/tugascript/devlogs/idp/internal/exceptions" "github.com/tugascript/devlogs/idp/internal/services" "github.com/tugascript/devlogs/idp/internal/services/dtos" @@ -23,7 +23,7 @@ func (c *Controllers) saveAccountRefreshCookie(ctx *fiber.Ctx, token string) { ctx.Cookie(&fiber.Cookie{ Name: c.cookieName + refreshCookieSuffix, Value: token, - Path: paths.V1 + paths.AuthBase, + Path: paths.V1, HTTPOnly: true, SameSite: fiber.CookieSameSiteNoneMode, Secure: true, diff --git a/idp/internal/controllers/oauth_dynamic_registration.go b/idp/internal/controllers/oauth_dynamic_registration.go index b5df165..7a2854a 100644 --- a/idp/internal/controllers/oauth_dynamic_registration.go +++ b/idp/internal/controllers/oauth_dynamic_registration.go @@ -55,13 +55,19 @@ func (c *Controllers) OAuthDynamicRegistrationIATAuth(ctx *fiber.Ctx) error { return c.redirectErrorCallback(logger, ctx, baseQPrms.RedirectURI, state, exceptions.OAuthErrorInvalidRequest) } + sessionKey := ctx.Cookies(c.cookieName + accountsIATCookieSuffix) + if sessionKey != "" { + // This ensures that the key is only used once + c.removeAccountIATCookie(ctx) + } + redirectURL, serviceErr := c.services.InitiateOAuthDynamicRegistrationIATAuth( ctx.UserContext(), services.InitiateOAuthDynamicRegistrationIATAuthOptions{ RequestID: requestID, Domain: baseQPrms.ClientID, State: qPrms.State, - SessionKey: ctx.Cookies(c.cookieName + accountsIATCookieSuffix), + SessionKey: sessionKey, RefreshToken: ctx.Cookies(c.cookieName + refreshCookieSuffix), Challenge: qPrms.Challenge, ChallengeMethod: qPrms.ChallengeMethod, @@ -70,7 +76,7 @@ func (c *Controllers) OAuthDynamicRegistrationIATAuth(ctx *fiber.Ctx) error { }, ) if serviceErr != nil { - return serviceErrorHTMLResponse(logger, ctx, serviceErr) + return c.redirectServiceErrorCallback(logger, ctx, baseQPrms.RedirectURI, state, serviceErr) } logResponse(logger, ctx, fiber.StatusFound) @@ -82,6 +88,13 @@ func (c *Controllers) OAuthDynamicRegistrationIATLoginGet(ctx *fiber.Ctx) error logger := c.buildLogger(requestID, oauthDynamicRegistration, "OAuthDynamicRegistrationIATLoginGet") logRequest(logger, ctx) + uPrms := params.OAuthDynamicRegistrationIATAuthURLParams{ + ACCClientID: ctx.Params("accClientID"), + } + if err := c.validate.StructCtx(ctx.UserContext(), &uPrms); err != nil { + return serviceErrorHTMLResponse(logger, ctx, exceptions.NewNotFoundError()) + } + baseQPrms := params.OAuthDynamicRegistrationIATAuthBaseQueryParams{ ClientID: ctx.Query("client_id"), RedirectURI: ctx.Query("redirect_uri"), @@ -104,6 +117,7 @@ func (c *Controllers) OAuthDynamicRegistrationIATLoginGet(ctx *fiber.Ctx) error ctx.UserContext(), services.OAuthDynamicRegistrationIATAuthRenderOptions{ RequestID: requestID, + ACCClientID: uPrms.ACCClientID, State: qPrms.State, Domain: baseQPrms.ClientID, CodeChallenge: qPrms.Challenge, diff --git a/idp/internal/providers/cache/account_credentials_dynamic_registration.go b/idp/internal/providers/cache/account_credentials_dynamic_registration.go index 460a8a9..122c76f 100644 --- a/idp/internal/providers/cache/account_credentials_dynamic_registration.go +++ b/idp/internal/providers/cache/account_credentials_dynamic_registration.go @@ -27,59 +27,54 @@ const ( sessionKeyByteLen int = 32 ) -func buildAccountCredentialsDynamicRegistrationIATLoginCacheKey(clientID string) string { - return fmt.Sprintf("%s:login:%s", accountCredentialsDynamicRegistrationIATPrefix, clientID) +func buildAccountCredentialsDynamicRegistrationIATAuthCacheKey(clientID string) string { + return fmt.Sprintf("%s:auth:%s", accountCredentialsDynamicRegistrationIATPrefix, clientID) } -type AccountCredentialsDynamicRegistrationIATLoginData struct { +type AccountCredentialsDynamicRegistrationIATAuthData struct { RedirectURI string `json:"redirect_uri"` - CSRFToken string `json:"csrf_token"` Domain string `json:"domain"` State string `json:"state"` + Challenge string `json:"challenge"` } -type SaveAccountCredentialsDynamicRegistrationIATLoginOptions struct { +type SaveAccountCredentialsDynamicRegistrationIATAuthOptions struct { Domain string RequestID string State string RedirectURI string + Challenge string } -func (c *Cache) SaveAccountCredentialsDynamicRegistrationIATLogin( +func (c *Cache) SaveAccountCredentialsDynamicRegistrationIATAuth( ctx context.Context, - opts SaveAccountCredentialsDynamicRegistrationIATLoginOptions, -) (string, string, error) { + opts SaveAccountCredentialsDynamicRegistrationIATAuthOptions, +) (string, error) { logger := utils.BuildLogger(c.logger, utils.LoggerOptions{ Location: accountCredentialsDynamicRegistrationLocation, - Method: "SaveAccountCredentialsDynamicRegistrationIATLogin", + Method: "SaveAccountCredentialsDynamicRegistrationIATAuth", RequestID: opts.RequestID, }).With( "redirectUri", opts.RedirectURI, ) logger.DebugContext(ctx, "Saving account credentials dynamic registration IAT sessions...") - csrfToken, err := utils.GenerateBase64Secret(csrfTokenByteLen) - if err != nil { - logger.ErrorContext(ctx, "Error generating CSRF token", "error", err) - return "", "", err - } - - data := AccountCredentialsDynamicRegistrationIATLoginData{ + data := AccountCredentialsDynamicRegistrationIATAuthData{ State: opts.State, Domain: opts.Domain, RedirectURI: opts.RedirectURI, - CSRFToken: utils.Sha256HashHex(csrfToken), + Challenge: opts.Challenge, } dataBytes, err := json.Marshal(data) if err != nil { logger.ErrorContext(ctx, "Failed to marshal account credentials dynamic registration IAT data", "error", err) - return "", "", err + return "", err } clientID := utils.Base62UUID() - return clientID, csrfToken, c.storage.SetWithContext( + return clientID, c.storage.SetWithContext( ctx, - buildAccountCredentialsDynamicRegistrationIATLoginCacheKey(clientID), + buildAccountCredentialsDynamicRegistrationIATAuthCacheKey(clientID), dataBytes, c.oauthStateTTL, ) @@ -93,7 +88,7 @@ type GetAccountCredentialsDynamicRegistrationIATAuthOptions struct { func (c *Cache) GetAccountCredentialsDynamicRegistrationAuthIAT( ctx context.Context, opts GetAccountCredentialsDynamicRegistrationIATAuthOptions, -) (AccountCredentialsDynamicRegistrationIATLoginData, bool, error) { +) (AccountCredentialsDynamicRegistrationIATAuthData, bool, error) { logger := utils.BuildLogger(c.logger, utils.LoggerOptions{ Location: accountCredentialsDynamicRegistrationLocation, Method: "GetAccountCredentialsDynamicRegistrationAuthIAT", @@ -103,20 +98,20 @@ func (c *Cache) GetAccountCredentialsDynamicRegistrationAuthIAT( ) logger.DebugContext(ctx, "Getting account credentials dynamic registration IAT...") - data, err := c.storage.GetWithContext(ctx, buildAccountCredentialsDynamicRegistrationIATLoginCacheKey(opts.ClientID)) + data, err := c.storage.GetWithContext(ctx, buildAccountCredentialsDynamicRegistrationIATAuthCacheKey(opts.ClientID)) if err != nil { logger.ErrorContext(ctx, "Failed to get account credentials dynamic registration IAT", "error", err) - return AccountCredentialsDynamicRegistrationIATLoginData{}, false, err + return AccountCredentialsDynamicRegistrationIATAuthData{}, false, err } if data == nil { logger.DebugContext(ctx, "Account credentials dynamic registration IAT not found") - return AccountCredentialsDynamicRegistrationIATLoginData{}, false, nil + return AccountCredentialsDynamicRegistrationIATAuthData{}, false, nil } - var authData AccountCredentialsDynamicRegistrationIATLoginData + var authData AccountCredentialsDynamicRegistrationIATAuthData if err := json.Unmarshal(data, &authData); err != nil { logger.ErrorContext(ctx, "Failed to unmarshal account credentials dynamic registration IAT data", "error", err) - return AccountCredentialsDynamicRegistrationIATLoginData{}, false, err + return AccountCredentialsDynamicRegistrationIATAuthData{}, false, err } return authData, true, nil @@ -140,7 +135,96 @@ func (c *Cache) DeleteAccountCredentialsDynamicRegistrationIATAuth( ) logger.DebugContext(ctx, "Deleting account credentials dynamic registration IAT...") - return c.storage.DeleteWithContext(ctx, buildAccountCredentialsDynamicRegistrationIATLoginCacheKey(opts.ClientID)) + return c.storage.DeleteWithContext(ctx, buildAccountCredentialsDynamicRegistrationIATAuthCacheKey(opts.ClientID)) +} + +func buildAccountCredentialsDynamicRegistrationIATLoginCSRFKey(domain, clientID string) string { + return fmt.Sprintf("%s:login:%s:%s", accountCredentialsDynamicRegistrationIATPrefix, domain, clientID) +} + +type SaveAccountCredentialsDynamicRegistrationIATLoginCSRFOptions struct { + RequestID string + ClientID string + Domain string +} + +func (c *Cache) SaveAccountCredentialsDynamicRegistrationIATLoginCSRF( + ctx context.Context, + opts SaveAccountCredentialsDynamicRegistrationIATLoginCSRFOptions, +) (string, error) { + logger := utils.BuildLogger(c.logger, utils.LoggerOptions{ + Location: accountCredentialsDynamicRegistrationLocation, + Method: "SaveAccountCredentialsDynamicRegistrationIATLoginCSRF", + RequestID: opts.RequestID, + }).With( + "clientId", opts.ClientID, + "domain", opts.Domain, + ) + logger.DebugContext(ctx, "Saving account credentials dynamic registration IAT login CSRF token...") + + csrfToken, err := utils.GenerateBase64Secret(csrfTokenByteLen) + if err != nil { + logger.ErrorContext(ctx, "Error generating CSRF token", "error", err) + return "", err + } + + return csrfToken, c.storage.SetWithContext( + ctx, + buildAccountCredentialsDynamicRegistrationIATLoginCSRFKey(opts.Domain, opts.ClientID), + []byte(utils.Sha256HashHex(csrfToken)), + c.oauthStateTTL, + ) +} + +type VerifyAccountCredentialsDynamicRegistrationIATLoginCSRFOptions struct { + RequestID string + ClientID string + Domain string + CSRFToken string +} + +func (c *Cache) VerifyAccountCredentialsDynamicRegistrationIATLoginCSRF( + ctx context.Context, + opts VerifyAccountCredentialsDynamicRegistrationIATLoginCSRFOptions, +) (bool, error) { + logger := utils.BuildLogger(c.logger, utils.LoggerOptions{ + Location: accountCredentialsDynamicRegistrationLocation, + Method: "VerifyAccountCredentialsDynamicRegistrationIATLoginCSRF", + RequestID: opts.RequestID, + }).With( + "clientId", opts.ClientID, + "domain", opts.Domain, + ) + logger.DebugContext(ctx, "Verifying account credentials dynamic registration IAT login CSRF token...") + + hashedCSRFToken, err := c.storage.GetWithContext(ctx, buildAccountCredentialsDynamicRegistrationIATLoginCSRFKey(opts.Domain, opts.ClientID)) + if err != nil { + logger.ErrorContext(ctx, "Failed to get account credentials dynamic registration IAT login CSRF token", "error", err) + return false, err + } + if hashedCSRFToken == nil { + logger.DebugContext(ctx, "Account credentials dynamic registration IAT login CSRF token not found") + return false, nil + } + + ok, err := utils.CompareShaHex(opts.CSRFToken, string(hashedCSRFToken)) + if err != nil { + logger.ErrorContext(ctx, "Error comparing CSRF token", "error", err) + return false, err + } + if !ok { + logger.DebugContext(ctx, "Invalid CSRF token") + return false, nil + } + if err := c.storage.DeleteWithContext( + ctx, + buildAccountCredentialsDynamicRegistrationIATLoginCSRFKey(opts.Domain, opts.ClientID), + ); err != nil { + logger.ErrorContext(ctx, "Error deleting CSRF token", "error", err) + return false, err + } + + return true, nil } type AccountCredentialsDynamicRegistrationIAT2FAData struct { diff --git a/idp/internal/providers/database/account_dynamic_registration_domains.sql.go b/idp/internal/providers/database/account_dynamic_registration_domains.sql.go index 352dd5a..4137186 100644 --- a/idp/internal/providers/database/account_dynamic_registration_domains.sql.go +++ b/idp/internal/providers/database/account_dynamic_registration_domains.sql.go @@ -43,6 +43,74 @@ func (q *Queries) CountFilteredAccountDynamicRegistrationDomainsByAccountPublicI return count, err } +const countVerifiedAccountDynamicRegistrationDomainsByDomain = `-- name: CountVerifiedAccountDynamicRegistrationDomainsByDomain :one +SELECT COUNT(*) FROM "account_dynamic_registration_domains" +WHERE "domain" = $1 AND "verified_at" IS NOT NULL +LIMIT 1 +` + +func (q *Queries) CountVerifiedAccountDynamicRegistrationDomainsByDomain(ctx context.Context, domain string) (int64, error) { + row := q.db.QueryRow(ctx, countVerifiedAccountDynamicRegistrationDomainsByDomain, domain) + var count int64 + err := row.Scan(&count) + return count, err +} + +const countVerifiedAccountDynamicRegistrationDomainsByDomainAndAccountPublicID = `-- name: CountVerifiedAccountDynamicRegistrationDomainsByDomainAndAccountPublicID :one +SELECT COUNT(*) FROM "account_dynamic_registration_domains" +WHERE + "account_public_id" = $1 AND + "domain" = $2 AND + "verified_at" IS NOT NULL +LIMIT 1 +` + +type CountVerifiedAccountDynamicRegistrationDomainsByDomainAndAccountPublicIDParams struct { + AccountPublicID uuid.UUID + Domain string +} + +func (q *Queries) CountVerifiedAccountDynamicRegistrationDomainsByDomainAndAccountPublicID(ctx context.Context, arg CountVerifiedAccountDynamicRegistrationDomainsByDomainAndAccountPublicIDParams) (int64, error) { + row := q.db.QueryRow(ctx, countVerifiedAccountDynamicRegistrationDomainsByDomainAndAccountPublicID, arg.AccountPublicID, arg.Domain) + var count int64 + err := row.Scan(&count) + return count, err +} + +const countVerifiedAccountDynamicRegistrationDomainsByDomains = `-- name: CountVerifiedAccountDynamicRegistrationDomainsByDomains :one +SELECT COUNT(*) FROM "account_dynamic_registration_domains" +WHERE "domain" IN ($1) AND "verified_at" IS NOT NULL +LIMIT 1 +` + +func (q *Queries) CountVerifiedAccountDynamicRegistrationDomainsByDomains(ctx context.Context, domains []string) (int64, error) { + row := q.db.QueryRow(ctx, countVerifiedAccountDynamicRegistrationDomainsByDomains, domains) + var count int64 + err := row.Scan(&count) + return count, err +} + +const countVerifiedAccountDynamicRegistrationDomainsByDomainsAndAccountPublicID = `-- name: CountVerifiedAccountDynamicRegistrationDomainsByDomainsAndAccountPublicID :one +SELECT COUNT(*) FROM "account_dynamic_registration_domains" +WHERE + "account_public_id" = $1 AND + "domain" IN ($2) AND + "verified_at" IS NOT NULL +LIMIT 1 +` + +type CountVerifiedAccountDynamicRegistrationDomainsByDomainsAndAccountPublicIDParams struct { + AccountPublicID uuid.UUID + Domains []string +} + +func (q *Queries) CountVerifiedAccountDynamicRegistrationDomainsByDomainsAndAccountPublicID(ctx context.Context, arg CountVerifiedAccountDynamicRegistrationDomainsByDomainsAndAccountPublicIDParams) (int64, error) { + row := q.db.QueryRow(ctx, countVerifiedAccountDynamicRegistrationDomainsByDomainsAndAccountPublicID, arg.AccountPublicID, arg.Domains) + var count int64 + err := row.Scan(&count) + return count, err +} + const createAccountDynamicRegistrationDomain = `-- name: CreateAccountDynamicRegistrationDomain :one INSERT INTO "account_dynamic_registration_domains" ( diff --git a/idp/internal/providers/database/queries/account_dynamic_registration_domains.sql b/idp/internal/providers/database/queries/account_dynamic_registration_domains.sql index f892c6d..60a0436 100644 --- a/idp/internal/providers/database/queries/account_dynamic_registration_domains.sql +++ b/idp/internal/providers/database/queries/account_dynamic_registration_domains.sql @@ -66,6 +66,32 @@ WHERE "domain" ILIKE $2 LIMIT 1; +-- name: CountVerifiedAccountDynamicRegistrationDomainsByDomain :one +SELECT COUNT(*) FROM "account_dynamic_registration_domains" +WHERE "domain" = $1 AND "verified_at" IS NOT NULL +LIMIT 1; + +-- name: CountVerifiedAccountDynamicRegistrationDomainsByDomains :one +SELECT COUNT(*) FROM "account_dynamic_registration_domains" +WHERE "domain" IN (sqlc.slice('domains')) AND "verified_at" IS NOT NULL +LIMIT 1; + +-- name: CountVerifiedAccountDynamicRegistrationDomainsByDomainsAndAccountPublicID :one +SELECT COUNT(*) FROM "account_dynamic_registration_domains" +WHERE + "account_public_id" = $1 AND + "domain" IN (sqlc.slice('domains')) AND + "verified_at" IS NOT NULL +LIMIT 1; + +-- name: CountVerifiedAccountDynamicRegistrationDomainsByDomainAndAccountPublicID :one +SELECT COUNT(*) FROM "account_dynamic_registration_domains" +WHERE + "account_public_id" = $1 AND + "domain" = $2 AND + "verified_at" IS NOT NULL +LIMIT 1; + -- name: DeleteAccountDynamicRegistrationDomain :exec DELETE FROM "account_dynamic_registration_domains" WHERE "id" = $1; diff --git a/idp/internal/server/routes/account_dynamic_registration.go b/idp/internal/server/routes/account_dynamic_registration.go index 7bf907b..eebd5a5 100644 --- a/idp/internal/server/routes/account_dynamic_registration.go +++ b/idp/internal/server/routes/account_dynamic_registration.go @@ -80,4 +80,17 @@ func (r *Routes) AccountDynamicRegistrationConfigurationRoutes(app *fiber.App) { credentialsConfigsWriteScopeMiddleware, r.controllers.DeleteAccountCredentialsRegistrationDomainCode, ) + + // Initial Access Token (IAT) routes + iatRouter := router.Group(paths.InitialAccessToken) + + // Dynamic Registration IAT Code Exchange flow + iatRouter.Get(paths.OAuthAuth, r.controllers.OAuthDynamicRegistrationIATAuth) + const loginRoute = paths.InitialAccessTokenSingle + paths.AuthLogin + iatRouter.Get(loginRoute, r.controllers.OAuthDynamicRegistrationIATLoginGet) + iatRouter.Post(loginRoute, r.controllers.OAuthDynamicRegistrationIATLoginPost) + iatRouter.Get(loginRoute+paths.Auth2FA, r.controllers.OAuthDynamicRegistrationIAT2FAGet) + iatRouter.Post(loginRoute+paths.Auth2FA, r.controllers.OAuthDynamicRegistrationIAT2FAPost) + iatRouter.Post(paths.OAuthToken, r.controllers.OAuthDynamicRegistrationIATToken) + } diff --git a/idp/internal/services/oauth_dynamic_registration.go b/idp/internal/services/oauth_dynamic_registration.go index c5ee3bb..61a2a3f 100644 --- a/idp/internal/services/oauth_dynamic_registration.go +++ b/idp/internal/services/oauth_dynamic_registration.go @@ -12,15 +12,16 @@ import ( "slices" "github.com/google/uuid" - "github.com/tugascript/devlogs/idp/internal/providers/crypto" - "github.com/tugascript/devlogs/idp/internal/services/dtos" + "golang.org/x/net/publicsuffix" "github.com/tugascript/devlogs/idp/internal/controllers/paths" "github.com/tugascript/devlogs/idp/internal/exceptions" "github.com/tugascript/devlogs/idp/internal/providers/cache" + "github.com/tugascript/devlogs/idp/internal/providers/crypto" "github.com/tugascript/devlogs/idp/internal/providers/database" "github.com/tugascript/devlogs/idp/internal/providers/mailer" "github.com/tugascript/devlogs/idp/internal/providers/tokens" + "github.com/tugascript/devlogs/idp/internal/services/dtos" "github.com/tugascript/devlogs/idp/internal/services/templates" "github.com/tugascript/devlogs/idp/internal/utils" ) @@ -33,6 +34,7 @@ const ( ) type buildOAuthDynamicRegistrationIATLoginURLOptions struct { + accClientID string domain string state string challenge string @@ -49,7 +51,7 @@ func buildOAuthDynamicRegistrationIATLoginURL(opts buildOAuthDynamicRegistration if opts.challengeMethod != "" { queryParams.Add("code_challenge_method", opts.challengeMethod) } - return oauthDynamicRegistrationIATAuthPath + paths.AuthLogin + "?" + queryParams.Encode() + return oauthDynamicRegistrationIATPath + "/" + opts.accClientID + paths.OAuthAuth + paths.AuthLogin + "?" + queryParams.Encode() } type buildOAuthDynamicRegistrationIATCallbackURLOptions struct { @@ -124,6 +126,92 @@ func (s *Services) generateOAuthDynamicRegistrationIATCallback( }), nil } +type oauthDynamicRegistrationIATAuthOptions struct { + requestID string + challenge string + challengeMethod string + domain string + redirectURI string + state string +} + +func (s *Services) oauthDynamicRegistrationIATAuth( + ctx context.Context, + opts oauthDynamicRegistrationIATAuthOptions, +) (string, *exceptions.ServiceError) { + logger := s.buildLogger( + opts.requestID, + oauthDynamicRegistrationLocation, + "oauthDynamicRegistrationIATAuth", + ).With( + "domain", opts.domain, + "redirectUri", opts.redirectURI, + ) + logger.InfoContext(ctx, "Handling OAuth dynamic registration IAT auth...") + + tldOneDomain, err := publicsuffix.EffectiveTLDPlusOne(opts.domain) + if err != nil { + logger.WarnContext(ctx, "Invalid domain", "error", err) + return "", exceptions.NewValidationError("invalid client_id") + } + + var count int64 + if tldOneDomain != opts.domain { + count, err = s.database.CountVerifiedAccountDynamicRegistrationDomainsByDomains( + ctx, + []string{opts.domain, tldOneDomain}, + ) + if err != nil { + logger.ErrorContext(ctx, "Failed to count account dynamic registration domains by domains", "error", err) + return "", exceptions.NewInternalServerError() + } + if count == 0 { + logger.WarnContext(ctx, "Domain not registered for dynamic registration") + return "", exceptions.NewForbiddenError() + } + } else { + count, err = s.database.CountVerifiedAccountDynamicRegistrationDomainsByDomain(ctx, opts.domain) + } + if err != nil { + logger.ErrorContext(ctx, "Failed to count account dynamic registration domains by domains", "error", err) + return "", exceptions.NewInternalServerError() + } + if count == 0 { + logger.WarnContext(ctx, "Domain not registered for dynamic registration") + return "", exceptions.NewForbiddenError() + } + + hashedChallenge, serviceErr := hashChallenge(opts.challenge, opts.challengeMethod) + if serviceErr != nil { + logger.ErrorContext(ctx, "Invalid code challenge", "serviceError", serviceErr) + return "", serviceErr + } + + clientID, err := s.cache.SaveAccountCredentialsDynamicRegistrationIATAuth( + ctx, + cache.SaveAccountCredentialsDynamicRegistrationIATAuthOptions{ + Domain: opts.domain, + RequestID: opts.requestID, + State: opts.state, + RedirectURI: opts.redirectURI, + Challenge: hashedChallenge, + }, + ) + if err != nil { + logger.ErrorContext(ctx, "Failed to save account credentials dynamic registration IAT auth", "error", err) + return "", exceptions.NewInternalServerError() + } + + return buildOAuthDynamicRegistrationIATLoginURL(buildOAuthDynamicRegistrationIATLoginURLOptions{ + accClientID: clientID, + domain: opts.domain, + state: opts.state, + challenge: opts.challenge, + challengeMethod: opts.challengeMethod, + redirectURI: opts.redirectURI, + }), nil +} + type refreshTokenOAuthDynamicRegistrationIATLoginOptions struct { requestID string refreshToken string @@ -144,6 +232,7 @@ func (s *Services) refreshTokenOAuthDynamicRegistrationIATLogin( oauthDynamicRegistrationLocation, "refreshTokenOAuthDynamicRegistrationIATLogin", ).With( + "domain", opts.domain, "redirectUri", opts.redirectURI, ) logger.InfoContext(ctx, "Refreshing OAuth dynamic registration IAT callback...") @@ -208,13 +297,14 @@ func (s *Services) refreshTokenOAuthDynamicRegistrationIATLogin( } logger.WarnContext(ctx, "Account not found or version mismatch", "serviceError", serviceErr) - return buildOAuthDynamicRegistrationIATLoginURL(buildOAuthDynamicRegistrationIATLoginURLOptions{ - domain: opts.domain, - state: opts.state, + return s.oauthDynamicRegistrationIATAuth(ctx, oauthDynamicRegistrationIATAuthOptions{ + requestID: opts.requestID, challenge: opts.challenge, challengeMethod: opts.challengeMethod, + domain: opts.domain, redirectURI: opts.redirectURI, - }), nil + state: opts.state, + }) } cbURL, serviceErr := s.generateOAuthDynamicRegistrationIATCallback( @@ -268,13 +358,14 @@ func (s *Services) InitiateOAuthDynamicRegistrationIATAuth( if opts.SessionKey == "" { if opts.RefreshToken == "" { logger.InfoContext(ctx, "No session key or refresh token provided, redirecting to login") - return buildOAuthDynamicRegistrationIATLoginURL(buildOAuthDynamicRegistrationIATLoginURLOptions{ - domain: opts.Domain, - state: opts.State, + return s.oauthDynamicRegistrationIATAuth(ctx, oauthDynamicRegistrationIATAuthOptions{ + requestID: opts.RequestID, challenge: opts.Challenge, challengeMethod: opts.ChallengeMethod, + domain: opts.Domain, redirectURI: opts.RedirectURI, - }), nil + state: opts.State, + }) } logger.InfoContext(ctx, "No session key provided, attempting to refresh with refresh token") @@ -306,14 +397,8 @@ func (s *Services) InitiateOAuthDynamicRegistrationIATAuth( return "", exceptions.NewInternalServerError() } if !found { - logger.InfoContext(ctx, "Account credentials registration session key not found") - return buildOAuthDynamicRegistrationIATLoginURL(buildOAuthDynamicRegistrationIATLoginURLOptions{ - domain: opts.Domain, - state: opts.State, - challenge: opts.Challenge, - challengeMethod: opts.ChallengeMethod, - redirectURI: opts.RedirectURI, - }), nil + logger.WarnContext(ctx, "Account credentials registration session key not found") + return "", exceptions.NewUnauthorizedError() } if !verified { @@ -368,6 +453,7 @@ func (s *Services) InitiateOAuthDynamicRegistrationIATAuth( type OAuthDynamicRegistrationIATAuthRenderOptions struct { RequestID string + ACCClientID string State string Domain string CodeChallenge string @@ -388,23 +474,51 @@ func (s *Services) OAuthDynamicRegistrationIATAuthRender( ) logger.InfoContext(ctx, "Starting OAuth dynamic registration IAT authorization html render...") - clientID, csrfToken, err := s.cache.SaveAccountCredentialsDynamicRegistrationIATLogin( + data, found, err := s.cache.GetAccountCredentialsDynamicRegistrationAuthIAT( ctx, - cache.SaveAccountCredentialsDynamicRegistrationIATLoginOptions{ - RequestID: opts.RequestID, - State: opts.State, - RedirectURI: opts.RedirectURI, - Domain: opts.Domain, + cache.GetAccountCredentialsDynamicRegistrationIATAuthOptions{ + RequestID: opts.RequestID, + ClientID: opts.ACCClientID, }, ) if err != nil { - logger.ErrorContext(ctx, "Failed to save account credentials dynamic registration IAT auth", "error", err) + logger.ErrorContext(ctx, "Failed to get account credentials dynamic registration IAT", "error", err) + return "", exceptions.NewInternalServerError() + } + if !found { + logger.WarnContext(ctx, "Account credentials dynamic registration IAT not found") + return "", exceptions.NewForbiddenError() + } + + if data.Domain != opts.Domain { + logger.WarnContext(ctx, "OAuth Domain does not match", "dataDomain", data.Domain) + return "", exceptions.NewUnauthorizedError() + } + if data.State != opts.State { + logger.WarnContext(ctx, "OAuth State does not match") + return "", exceptions.NewUnauthorizedError() + } + if data.RedirectURI != opts.RedirectURI { + logger.WarnContext(ctx, "OAuth Redirect URI does not match") + return "", exceptions.NewUnauthorizedError() + } + + csrfToken, err := s.cache.SaveAccountCredentialsDynamicRegistrationIATLoginCSRF( + ctx, + cache.SaveAccountCredentialsDynamicRegistrationIATLoginCSRFOptions{ + RequestID: opts.RequestID, + ClientID: opts.ACCClientID, + Domain: opts.Domain, + }, + ) + if err != nil { + logger.ErrorContext(ctx, "Failed to save account credentials dynamic registration IAT auth CSRF token", "error", err) return "", exceptions.NewInternalServerError() } loginHTML, err := templates.BuildAccountDynamicRegistrationIATAuthTemplate( templates.AccountDynamicRegistrationIATAuthOptions{ - ACCClientID: clientID, + ACCClientID: opts.ACCClientID, CSRFToken: csrfToken, Domain: opts.Domain, State: opts.State, @@ -501,6 +615,24 @@ func (s *Services) OAuthDynamicRegistrationIATLogin( ) logger.InfoContext(ctx, "Logging in with OAuth dynamic registration IAT...") + validCSRF, err := s.cache.VerifyAccountCredentialsDynamicRegistrationIATLoginCSRF( + ctx, + cache.VerifyAccountCredentialsDynamicRegistrationIATLoginCSRFOptions{ + RequestID: opts.RequestID, + ClientID: opts.ACCClientID, + Domain: opts.Domain, + CSRFToken: opts.CSRFToken, + }, + ) + if err != nil { + logger.ErrorContext(ctx, "Failed to verify account credentials dynamic registration IAT auth CSRF token", "error", err) + return "", "", false, exceptions.NewInternalServerError() + } + if !validCSRF { + logger.WarnContext(ctx, "Invalid CSRF token") + return "", "", false, exceptions.NewForbiddenError() + } + data, found, err := s.cache.GetAccountCredentialsDynamicRegistrationAuthIAT(ctx, cache.GetAccountCredentialsDynamicRegistrationIATAuthOptions{ RequestID: opts.RequestID, ClientID: opts.ACCClientID, @@ -663,7 +795,7 @@ func (s *Services) OAuthDynamicRegistrationIATLogin( return "", "", false, exceptions.NewForbiddenError() } if !domainDTO.Verified { - logger.ErrorContext(ctx, "Account credentials registration domain is not verified") + logger.ErrorContext(ctx, "Account credentials registration domain is not validCSRF") return "", "", false, exceptions.NewForbiddenError() } @@ -831,6 +963,23 @@ func (s *Services) OAuthDynamicRegistrationIATVerify2FACode( ) logger.InfoContext(ctx, "Verifying OAuth dynamic registration IAT 2FA...") + verifiedCSRF, err := s.cache.VerifyAccountCredentialsDynamicRegistrationIAT2FACSRFToken( + ctx, + cache.VerifyAccountCredentialsDynamicRegistrationIAT2FACSRFTokenOptions{ + RequestID: opts.RequestID, + SessionID: opts.SessionID, + CSRFToken: opts.CSRFToken, + }, + ) + if err != nil { + logger.ErrorContext(ctx, "Failed to verify account credentials dynamic registration IAT 2FA CSRF token", "error", err) + return "", "", exceptions.NewInternalServerError() + } + if !verifiedCSRF { + logger.WarnContext(ctx, "Invalid CSRF token") + return "", "", exceptions.NewForbiddenError() + } + data, found, err := s.cache.GetAccountCredentialsDynamicRegistrationIAT2FA(ctx, cache.GetAccountCredentialsDynamicRegistrationIAT2FAOptions{ RequestID: opts.RequestID, SessionID: opts.SessionID, @@ -848,23 +997,6 @@ func (s *Services) OAuthDynamicRegistrationIATVerify2FACode( return "", "", exceptions.NewUnauthorizedError() } - csrfTokenValid, err := s.cache.VerifyAccountCredentialsDynamicRegistrationIAT2FACSRFToken( - ctx, - cache.VerifyAccountCredentialsDynamicRegistrationIAT2FACSRFTokenOptions{ - RequestID: opts.RequestID, - SessionID: opts.SessionID, - CSRFToken: opts.CSRFToken, - }, - ) - if err != nil { - logger.ErrorContext(ctx, "Failed to verify account credentials dynamic registration IAT 2FA CSRF token", "error", err) - return "", "", exceptions.NewInternalServerError() - } - if !csrfTokenValid { - logger.WarnContext(ctx, "Invalid CSRF token") - return "", "", exceptions.NewForbiddenError() - } - accountDTO, serviceErr := s.GetAccountByPublicIDAndVersion(ctx, GetAccountByPublicIDAndVersionOptions{ RequestID: opts.RequestID, PublicID: data.AccountPublicID, @@ -927,6 +1059,8 @@ func (s *Services) OAuthDynamicRegistrationIATVerify2FACode( return oauthDynamicRegistrationIATAuthPath, sessionKey, nil } +// TODO: add external callbacks + type VerifyOAuthDynamicRegistrationIATCodeOptions struct { RequestID string Code string @@ -979,18 +1113,37 @@ func (s *Services) VerifyOAuthDynamicRegistrationIATCode( return dtos.AuthDTO{}, serviceErr } - domainDTO, serviceErr := s.GetAccountCredentialsRegistrationDomain(ctx, GetAccountCredentialsRegistrationDomainOptions{ - RequestID: opts.RequestID, - AccountPublicID: accountDTO.PublicID, - Domain: data.Domain, - }) - if serviceErr != nil { - logger.ErrorContext(ctx, "Failed to get account credentials registration domain", "serviceError", serviceErr) - return dtos.AuthDTO{}, serviceErr + tldOneDomain, err := publicsuffix.EffectiveTLDPlusOne(opts.Domain) + if err != nil { + logger.WarnContext(ctx, "Invalid domain", "error", err) + return dtos.AuthDTO{}, exceptions.NewValidationError("invalid client_id") } - if !domainDTO.Verified { - logger.ErrorContext(ctx, "Account credentials registration domain is not verified") - return dtos.AuthDTO{}, exceptions.NewValidationError("account credentials registration domain is not verified") + + var count int64 + if tldOneDomain != data.Domain { + count, err = s.database.CountVerifiedAccountDynamicRegistrationDomainsByDomainsAndAccountPublicID( + ctx, + database.CountVerifiedAccountDynamicRegistrationDomainsByDomainsAndAccountPublicIDParams{ + AccountPublicID: accountDTO.PublicID, + Domains: []string{data.Domain, tldOneDomain}, + }, + ) + } else { + count, err = s.database.CountVerifiedAccountDynamicRegistrationDomainsByDomainAndAccountPublicID( + ctx, + database.CountVerifiedAccountDynamicRegistrationDomainsByDomainAndAccountPublicIDParams{ + AccountPublicID: accountDTO.PublicID, + Domain: data.Domain, + }, + ) + } + if err != nil { + logger.ErrorContext(ctx, "Failed to count verified account dynamic registration domains by domains and account public ID", "error", err) + return dtos.AuthDTO{}, exceptions.NewInternalServerError() + } + if count == 0 { + logger.WarnContext(ctx, "Account does not have any verified dynamic registration domains matching the OAuth Domain") + return dtos.AuthDTO{}, exceptions.NewForbiddenError() } tokenTTL := s.jwt.GetDynamicRegistrationTTL() From 637c4ddf7b028967d8accf4917bb2d9bf4466d83 Mon Sep 17 00:00:00 2001 From: Afonso Barracha Date: Thu, 4 Sep 2025 00:00:27 +1200 Subject: [PATCH 15/23] feat(idp): start adding ext auth for iat --- idp/go.mod | 70 +++-- idp/go.sum | 73 ++++- .../controllers/oauth_dynamic_registration.go | 104 +++++++ .../params/oauth_dynamic_registration.go | 5 + .../controllers/paths/dynamic_registration.go | 10 +- ...ccount_credentials_dynamic_registration.go | 90 +++++++ .../routes/account_dynamic_registration.go | 16 +- .../services/oauth_dynamic_registration.go | 255 +++++++++++++++++- 8 files changed, 582 insertions(+), 41 deletions(-) diff --git a/idp/go.mod b/idp/go.mod index 1c83a41..1f11010 100644 --- a/idp/go.mod +++ b/idp/go.mod @@ -7,20 +7,21 @@ require ( github.com/go-faker/faker/v4 v4.6.1 github.com/go-playground/validator/v10 v10.27.0 github.com/gofiber/fiber/v2 v2.52.9 - github.com/gofiber/storage/redis/v3 v3.4.0 + github.com/gofiber/storage/redis/v3 v3.4.1 github.com/golang-jwt/jwt/v5 v5.3.0 github.com/google/uuid v1.6.0 github.com/h2non/gock v1.2.0 github.com/jackc/pgx/v5 v5.7.5 github.com/joho/godotenv v1.5.1 - github.com/openbao/openbao/api/auth/approle/v2 v2.3.1 - github.com/openbao/openbao/api/v2 v2.3.1 + github.com/openbao/openbao/api/auth/approle/v2 v2.4.0 + github.com/openbao/openbao/api/v2 v2.4.0 github.com/pquerna/otp v1.5.0 github.com/redis/go-redis/v9 v9.12.1 golang.org/x/crypto v0.41.0 + golang.org/x/net v0.43.0 golang.org/x/oauth2 v0.30.0 golang.org/x/text v0.28.0 - google.golang.org/api v0.247.0 + google.golang.org/api v0.248.0 ) require ( @@ -28,23 +29,36 @@ require ( cloud.google.com/go/auth/oauth2adapt v0.2.8 // indirect cloud.google.com/go/compute/metadata v0.8.0 // indirect dario.cat/mergo v1.0.2 // indirect + github.com/Azure/go-ansiterm v0.0.0-20250102033503-faa5f7b0171c // indirect + github.com/Microsoft/go-winio v0.6.2 // indirect github.com/andybalholm/brotli v1.2.0 // indirect github.com/boombuler/barcode v1.1.0 // indirect github.com/cenkalti/backoff/v4 v4.3.0 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/containerd/errdefs v1.0.0 // indirect + github.com/containerd/errdefs/pkg v0.3.0 // indirect + github.com/containerd/log v0.1.0 // indirect + github.com/containerd/platforms v0.2.1 // indirect + github.com/cpuguy83/dockercfg v0.3.2 // indirect + github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect + github.com/distribution/reference v0.6.0 // indirect github.com/docker/docker v28.3.3+incompatible // indirect github.com/docker/go-connections v0.6.0 // indirect + github.com/docker/go-units v0.5.0 // indirect + github.com/ebitengine/purego v0.8.4 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect - github.com/gabriel-vasile/mimetype v1.4.9 // indirect + github.com/gabriel-vasile/mimetype v1.4.10 // indirect github.com/go-jose/go-jose/v3 v3.0.4 // indirect + github.com/go-jose/go-jose/v4 v4.1.2 // indirect github.com/go-logr/logr v1.4.3 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/go-ole/go-ole v1.3.0 // indirect github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect github.com/go-viper/mapstructure/v2 v2.4.0 // indirect - github.com/gofiber/storage/testhelpers/redis v0.0.0-20250815074620-1386290f7fd5 // indirect + github.com/gofiber/storage/testhelpers/redis v0.0.0-20250829072152-23fd56bd1077 // indirect + github.com/gogo/protobuf v1.3.2 // indirect github.com/google/s2a-go v0.1.9 // indirect github.com/googleapis/enterprise-certificate-proxy v0.3.6 // indirect github.com/googleapis/gax-go/v2 v2.15.0 // indirect @@ -62,30 +76,52 @@ require ( github.com/jackc/puddle/v2 v2.2.2 // indirect github.com/klauspost/compress v1.18.0 // indirect github.com/leodido/go-urn v1.4.0 // indirect - github.com/lufia/plan9stats v0.0.0-20250317134145-8bc96cf8fc35 // indirect + github.com/lufia/plan9stats v0.0.0-20250827001030-24949be3fa54 // indirect + github.com/magiconair/properties v1.8.10 // indirect github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-runewidth v0.0.16 // indirect + github.com/mdelapenya/tlscert v0.2.0 // indirect github.com/mitchellh/mapstructure v1.5.0 // indirect + github.com/moby/docker-image-spec v1.3.1 // indirect + github.com/moby/go-archive v0.1.0 // indirect + github.com/moby/patternmatcher v0.6.0 // indirect + github.com/moby/sys/sequential v0.6.0 // indirect + github.com/moby/sys/user v0.4.0 // indirect + github.com/moby/sys/userns v0.1.0 // indirect github.com/moby/term v0.5.2 // indirect + github.com/morikuni/aec v1.0.0 // indirect + github.com/opencontainers/go-digest v1.0.0 // indirect + github.com/opencontainers/image-spec v1.1.1 // indirect github.com/philhofer/fwd v1.2.0 // indirect + github.com/pkg/errors v0.9.1 // indirect + github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 // indirect github.com/rivo/uniseg v0.4.7 // indirect github.com/ryanuber/go-glob v1.0.0 // indirect - github.com/shirou/gopsutil/v4 v4.25.7 // indirect - github.com/tinylib/msgp v1.3.0 // indirect + github.com/shirou/gopsutil/v4 v4.25.8 // indirect + github.com/sirupsen/logrus v1.9.3 // indirect + github.com/stretchr/testify v1.11.1 // indirect + github.com/testcontainers/testcontainers-go v0.38.0 // indirect + github.com/testcontainers/testcontainers-go/modules/redis v0.38.0 // indirect + github.com/tinylib/msgp v1.4.0 // indirect + github.com/tklauser/go-sysconf v0.3.15 // indirect + github.com/tklauser/numcpus v0.10.0 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect github.com/valyala/fasthttp v1.65.0 // indirect + github.com/yusufpapurcu/wmi v1.2.4 // indirect go.opentelemetry.io/auto/sdk v1.1.0 // indirect - go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.62.0 // indirect - go.opentelemetry.io/otel v1.37.0 // indirect - go.opentelemetry.io/otel/metric v1.37.0 // indirect - go.opentelemetry.io/otel/trace v1.37.0 // indirect - golang.org/x/net v0.43.0 // indirect + go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.63.0 // indirect + go.opentelemetry.io/otel v1.38.0 // indirect + go.opentelemetry.io/otel/metric v1.38.0 // indirect + go.opentelemetry.io/otel/trace v1.38.0 // indirect + golang.org/x/mod v0.26.0 // indirect golang.org/x/sync v0.16.0 // indirect golang.org/x/sys v0.35.0 // indirect golang.org/x/time v0.12.0 // indirect - google.golang.org/genproto/googleapis/rpc v0.0.0-20250811230008-5f3141c8851a // indirect - google.golang.org/grpc v1.74.2 // indirect - google.golang.org/protobuf v1.36.7 // indirect + golang.org/x/tools v0.35.0 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20250826171959-ef028d996bc1 // indirect + google.golang.org/grpc v1.75.0 // indirect + google.golang.org/protobuf v1.36.8 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/idp/go.sum b/idp/go.sum index 90b40f4..30b891a 100644 --- a/idp/go.sum +++ b/idp/go.sum @@ -1,5 +1,3 @@ -cloud.google.com/go/auth v0.16.4 h1:fXOAIQmkApVvcIn7Pc2+5J8QTMVbUGLscnSVNl11su8= -cloud.google.com/go/auth v0.16.4/go.mod h1:j10ncYwjX/g3cdX7GpEzsdM+d+ZNsXAbb6qXA7p1Y5M= cloud.google.com/go/auth v0.16.5 h1:mFWNQ2FEVWAliEQWpAdH80omXFokmrnbDhUS9cBywsI= cloud.google.com/go/auth v0.16.5/go.mod h1:utzRfHMP+Vv0mpOkTRQoWD2q3BatTOoWbA7gCc2dUhQ= cloud.google.com/go/auth/oauth2adapt v0.2.8 h1:keo8NaayQZ6wimpNSmW5OPc283g65QNIiLpZnkHRbnc= @@ -59,15 +57,22 @@ github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2 github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= github.com/gabriel-vasile/mimetype v1.4.9 h1:5k+WDwEsD9eTLL8Tz3L0VnmVh9QxGjRmjBvAG7U/oYY= github.com/gabriel-vasile/mimetype v1.4.9/go.mod h1:WnSQhFKJuBlRyLiKohA/2DtIlPFAbguNaG7QCHcyGok= +github.com/gabriel-vasile/mimetype v1.4.10 h1:zyueNbySn/z8mJZHLt6IPw0KoZsiQNszIpU+bX4+ZK0= +github.com/gabriel-vasile/mimetype v1.4.10/go.mod h1:d+9Oxyo1wTzWdyVUPMmXFvp4F9tea18J8ufA774AB3s= github.com/go-faker/faker/v4 v4.6.1 h1:xUyVpAjEtB04l6XFY0V/29oR332rOSPWV4lU8RwDt4k= github.com/go-faker/faker/v4 v4.6.1/go.mod h1:arSdxNCSt7mOhdk8tEolvHeIJ7eX4OX80wXjKKvkKBY= github.com/go-jose/go-jose/v3 v3.0.4 h1:Wp5HA7bLQcKnf6YYao/4kpRpVMp/yf6+pJKV8WFSaNY= github.com/go-jose/go-jose/v3 v3.0.4/go.mod h1:5b+7YgP7ZICgJDBdfjZaIt+H/9L9T/YQrVfLAMboGkQ= +github.com/go-jose/go-jose/v4 v4.1.1 h1:JYhSgy4mXXzAdF3nUx3ygx347LRXJRrpgyU3adRmkAI= +github.com/go-jose/go-jose/v4 v4.1.1/go.mod h1:BdsZGqgdO3b6tTc6LSE56wcDbMMLuPsw5d4ZD5f94kA= +github.com/go-jose/go-jose/v4 v4.1.2 h1:TK/7NqRQZfgAh+Td8AlsrvtPoUyiHh0LqVvokh+1vHI= +github.com/go-jose/go-jose/v4 v4.1.2/go.mod h1:22cg9HWM1pOlnRiY+9cQYJ9XHmya1bYW8OeDM6Ku6Oo= github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= +github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0= github.com/go-ole/go-ole v1.3.0 h1:Dt6ye7+vXGIKZ7Xtk4s6/xVdGDQynvom7xCFEdWr6uE= github.com/go-ole/go-ole v1.3.0/go.mod h1:5LS6F96DhAwUc7C+1HLexzMXY1xGRSryjyPPKW6zv78= github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= @@ -86,8 +91,12 @@ github.com/gofiber/fiber/v2 v2.52.9 h1:YjKl5DOiyP3j0mO61u3NTmK7or8GzzWzCFzkboyP5 github.com/gofiber/fiber/v2 v2.52.9/go.mod h1:YEcBbO/FB+5M1IZNBP9FO3J9281zgPAreiI1oqg8nDw= github.com/gofiber/storage/redis/v3 v3.4.0 h1:FbtVgHsWkHFaogObFyNbBkNkZL9/zYxQkS1PV0rA5Ss= github.com/gofiber/storage/redis/v3 v3.4.0/go.mod h1:5efv+XbKwSQju9j7tokMgFWZ1JwlZvSsIL4RNJSDyf0= +github.com/gofiber/storage/redis/v3 v3.4.1 h1:feZc1xv1UuW+a1qnpISPaak7r/r0SkNVFHmg9R7PJ/c= +github.com/gofiber/storage/redis/v3 v3.4.1/go.mod h1:rbycYIeewyFZ1uMf9I6t/C3RHZWIOmSRortjvyErhyA= github.com/gofiber/storage/testhelpers/redis v0.0.0-20250815074620-1386290f7fd5 h1:vC79Z8gkydKoxsq+7+IhnTd3z2J7qs1Zi5wXTP29/C4= github.com/gofiber/storage/testhelpers/redis v0.0.0-20250815074620-1386290f7fd5/go.mod h1:PU9dj9E5K6+TLw7pF87y4yOf5HUH6S9uxTlhuRAVMEY= +github.com/gofiber/storage/testhelpers/redis v0.0.0-20250829072152-23fd56bd1077 h1:AQiZAq2FaKjRu08sPHVO8sOnFxUk4+nrvmaJO42YlSA= +github.com/gofiber/storage/testhelpers/redis v0.0.0-20250829072152-23fd56bd1077/go.mod h1:PU9dj9E5K6+TLw7pF87y4yOf5HUH6S9uxTlhuRAVMEY= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo= @@ -138,12 +147,16 @@ github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= +github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= +github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= github.com/lufia/plan9stats v0.0.0-20250317134145-8bc96cf8fc35 h1:PpXWgLPs+Fqr325bN2FD2ISlRRztXibcX6e8f5FR5Dc= github.com/lufia/plan9stats v0.0.0-20250317134145-8bc96cf8fc35/go.mod h1:autxFIvghDt3jPTLoqZ9OZ7s9qTGNAWmYCjVFWPX/zg= +github.com/lufia/plan9stats v0.0.0-20250827001030-24949be3fa54 h1:mFWunSatvkQQDhpdyuFAYwyAan3hzCuma+Pz8sqvOfg= +github.com/lufia/plan9stats v0.0.0-20250827001030-24949be3fa54/go.mod h1:autxFIvghDt3jPTLoqZ9OZ7s9qTGNAWmYCjVFWPX/zg= github.com/magiconair/properties v1.8.10 h1:s31yESBquKXCV9a/ScB3ESkOjUYYv+X0rg8SYxI99mE= github.com/magiconair/properties v1.8.10/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0= github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= @@ -176,8 +189,12 @@ github.com/nbio/st v0.0.0-20140626010706-e9e8d9816f32 h1:W6apQkHrMkS0Muv8G/TipAy github.com/nbio/st v0.0.0-20140626010706-e9e8d9816f32/go.mod h1:9wM+0iRr9ahx58uYLpLIr5fm8diHn0JbqRycJi6w0Ms= github.com/openbao/openbao/api/auth/approle/v2 v2.3.1 h1:g2m00OqV+T6cn9oMwkVKcw2gs0TA4uHK9yH9ASWybtE= github.com/openbao/openbao/api/auth/approle/v2 v2.3.1/go.mod h1:Bn4PFuu1mRG2Vcoz2KFdIynAtN90fa4kFUgExNEm8Cs= +github.com/openbao/openbao/api/auth/approle/v2 v2.4.0 h1:e2CScj+WeCTbh/cbap4aeAJ2XWU2CZ+x5lFfGaS3DI4= +github.com/openbao/openbao/api/auth/approle/v2 v2.4.0/go.mod h1:n77uPZESGOsxNXcnLlJKvJCvNZeIDQ9ZWsNry4dUlDA= github.com/openbao/openbao/api/v2 v2.3.1 h1:+Ho5A1jWedZonDz+HDViSOXTieotUT6w7r2Q8Sc8GNM= github.com/openbao/openbao/api/v2 v2.3.1/go.mod h1:oEeWVQSz1LeJJGwwCiPzHX6seppRh8jYXaw6W6yYvao= +github.com/openbao/openbao/api/v2 v2.4.0 h1:OcHJgexGt65qFNcpNNqM2v3otaWt8YhfD7Q5Sy6CWZc= +github.com/openbao/openbao/api/v2 v2.4.0/go.mod h1:ULxn1SwPo/txs19I1VHEBBqMspG8wiZ17qe9DMjCwP0= github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U= github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM= github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040= @@ -202,6 +219,8 @@ github.com/ryanuber/go-glob v1.0.0 h1:iQh3xXAumdQ+4Ufa5b25cRpC5TYKlno6hsv6Cb3pkB github.com/ryanuber/go-glob v1.0.0/go.mod h1:807d1WSdnB0XRJzKNil9Om6lcp/3a0v4qIHxIXzX/Yc= github.com/shirou/gopsutil/v4 v4.25.7 h1:bNb2JuqKuAu3tRlPv5piSmBZyMfecwQ+t/ILq+1JqVM= github.com/shirou/gopsutil/v4 v4.25.7/go.mod h1:XV/egmwJtd3ZQjBpJVY5kndsiOO4IRqy9TQnmm6VP7U= +github.com/shirou/gopsutil/v4 v4.25.8 h1:NnAsw9lN7587WHxjJA9ryDnqhJpFH6A+wagYWTOH970= +github.com/shirou/gopsutil/v4 v4.25.8/go.mod h1:q9QdMmfAOVIw7a+eF86P7ISEU6ka+NLgkUxlopV4RwI= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= @@ -209,24 +228,28 @@ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UV github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/testcontainers/testcontainers-go v0.38.0 h1:d7uEapLcv2P8AvH8ahLqDMMxda2W9gQN1nRbHS28HBw= github.com/testcontainers/testcontainers-go v0.38.0/go.mod h1:C52c9MoHpWO+C4aqmgSU+hxlR5jlEayWtgYrb8Pzz1w= github.com/testcontainers/testcontainers-go/modules/redis v0.38.0 h1:289pn0BFmGqDrd6BrImZAprFef9aaPZacx07YOQaPV4= github.com/testcontainers/testcontainers-go/modules/redis v0.38.0/go.mod h1:EcKPWRzOglnQfYe+ekA8RPEIWSNJTGwaC5oE5bQV+D0= github.com/tinylib/msgp v1.3.0 h1:ULuf7GPooDaIlbyvgAxBV/FI7ynli6LZ1/nVUNu+0ww= github.com/tinylib/msgp v1.3.0/go.mod h1:ykjzy2wzgrlvpDCRc4LA8UXy6D8bzMSuAF3WD57Gok0= +github.com/tinylib/msgp v1.4.0 h1:SYOeDRiydzOw9kSiwdYp9UcBgPFtLU2WDHaJXyHruf8= +github.com/tinylib/msgp v1.4.0/go.mod h1:cvjFkb4RiC8qSBOPMGPSzSAx47nAsfhLVTCZZNuHv5o= github.com/tklauser/go-sysconf v0.3.15 h1:VE89k0criAymJ/Os65CSn1IXaol+1wrsFHEB8Ol49K4= github.com/tklauser/go-sysconf v0.3.15/go.mod h1:Dmjwr6tYFIseJw7a3dRLJfsHAMXZ3nEnL/aZY+0IuI4= github.com/tklauser/numcpus v0.10.0 h1:18njr6LDBk1zuna922MgdjQuJFjrdppsZG60sHGfjso= github.com/tklauser/numcpus v0.10.0/go.mod h1:BiTKazU708GQTYF4mB+cmlpT2Is1gLk7XVuEeem8LsQ= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= -github.com/valyala/fasthttp v1.64.0 h1:QBygLLQmiAyiXuRhthf0tuRkqAFcrC42dckN2S+N3og= -github.com/valyala/fasthttp v1.64.0/go.mod h1:dGmFxwkWXSK0NbOSJuF7AMVzU+lkHz0wQVvVITv2UQA= github.com/valyala/fasthttp v1.65.0 h1:j/u3uzFEGFfRxw79iYzJN+TteTJwbYkru9uDp3d0Yf8= github.com/valyala/fasthttp v1.65.0/go.mod h1:P/93/YkKPMsKSnATEeELUCkG8a7Y+k99uxNHVbKINr4= github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU= github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E= +github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= +github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0= github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0= @@ -234,24 +257,43 @@ go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJyS go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A= go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.62.0 h1:Hf9xI/XLML9ElpiHVDNwvqI0hIFlzV8dgIr35kV1kRU= go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.62.0/go.mod h1:NfchwuyNoMcZ5MLHwPrODwUF1HWCXWrL31s8gSAdIKY= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.63.0 h1:RbKq8BG0FI8OiXhBfcRtqqHcZcka+gU3cskNuf05R18= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.63.0/go.mod h1:h06DGIukJOevXaj/xrNjhi/2098RZzcLTbc0jDAUbsg= go.opentelemetry.io/otel v1.37.0 h1:9zhNfelUvx0KBfu/gb+ZgeAfAgtWrfHJZcAqFC228wQ= go.opentelemetry.io/otel v1.37.0/go.mod h1:ehE/umFRLnuLa/vSccNq9oS1ErUlkkK71gMcN34UG8I= +go.opentelemetry.io/otel v1.38.0 h1:RkfdswUDRimDg0m2Az18RKOsnI8UDzppJAtj01/Ymk8= +go.opentelemetry.io/otel v1.38.0/go.mod h1:zcmtmQ1+YmQM9wrNsTGV/q/uyusom3P8RxwExxkZhjM= go.opentelemetry.io/otel/metric v1.37.0 h1:mvwbQS5m0tbmqML4NqK+e3aDiO02vsf/WgbsdpcPoZE= go.opentelemetry.io/otel/metric v1.37.0/go.mod h1:04wGrZurHYKOc+RKeye86GwKiTb9FKm1WHtO+4EVr2E= +go.opentelemetry.io/otel/metric v1.38.0 h1:Kl6lzIYGAh5M159u9NgiRkmoMKjvbsKtYRwgfrA6WpA= +go.opentelemetry.io/otel/metric v1.38.0/go.mod h1:kB5n/QoRM8YwmUahxvI3bO34eVtQf2i4utNVLr9gEmI= go.opentelemetry.io/otel/sdk v1.37.0 h1:ItB0QUqnjesGRvNcmAcU0LyvkVyGJ2xftD29bWdDvKI= go.opentelemetry.io/otel/sdk v1.37.0/go.mod h1:VredYzxUvuo2q3WRcDnKDjbdvmO0sCzOvVAiY+yUkAg= +go.opentelemetry.io/otel/sdk v1.38.0 h1:l48sr5YbNf2hpCUj/FoGhW9yDkl+Ma+LrVl8qaM5b+E= go.opentelemetry.io/otel/sdk/metric v1.37.0 h1:90lI228XrB9jCMuSdA0673aubgRobVZFhbjxHHspCPc= go.opentelemetry.io/otel/sdk/metric v1.37.0/go.mod h1:cNen4ZWfiD37l5NhS+Keb5RXVWZWpRE+9WyVCpbo5ps= +go.opentelemetry.io/otel/sdk/metric v1.38.0 h1:aSH66iL0aZqo//xXzQLYozmWrXxyFkBJ6qT5wthqPoM= go.opentelemetry.io/otel/trace v1.37.0 h1:HLdcFNbRQBE2imdSEgm/kwqmQj1Or1l/7bW6mxVK7z4= go.opentelemetry.io/otel/trace v1.37.0/go.mod h1:TlgrlQ+PtQO5XFerSPUYG0JSgGyryXewPGyayAWSBS0= +go.opentelemetry.io/otel/trace v1.38.0 h1:Fxk5bKrDZJUH+AMyyIXGcFAPah0oRcT+LuNtJrmcNLE= +go.opentelemetry.io/otel/trace v1.38.0/go.mod h1:j1P9ivuFsTceSWe1oY+EeW3sc+Pp42sO++GHkg4wwhs= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= golang.org/x/crypto v0.41.0 h1:WKYxWedPGCTVVl5+WHSSrOBT0O8lx32+zxmHxijgXp4= golang.org/x/crypto v0.41.0/go.mod h1:pO5AFd7FA68rFak7rOAGVuygIISepHftHnr8dr6+sUc= +golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= +golang.org/x/mod v0.26.0 h1:EGMPT//Ezu+ylkCijjPc+f4Aih7sZvaAr+O3EHBxvZg= +golang.org/x/mod v0.26.0/go.mod h1:/j6NAhSk8iQ723BGAUyoAcn7SlD7s15Dp9Nd/SfeaFQ= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= @@ -261,15 +303,22 @@ golang.org/x/net v0.43.0/go.mod h1:vhO1fvI4dGsIjh73sWfUVjj3N7CA9WkKJNQm2svM6Jg= golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI= golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw= golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -295,21 +344,37 @@ golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE= golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= +golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= +golang.org/x/tools v0.35.0 h1:mBffYraMEf7aa0sB+NuKnuCy8qI/9Bughn8dC2Gu5r0= +golang.org/x/tools v0.35.0/go.mod h1:NKdj5HkL/73byiZSJjqJgKn3ep7KjFkBOkR/Hps3VPw= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/api v0.247.0 h1:tSd/e0QrUlLsrwMKmkbQhYVa109qIintOls2Wh6bngc= google.golang.org/api v0.247.0/go.mod h1:r1qZOPmxXffXg6xS5uhx16Fa/UFY8QU/K4bfKrnvovM= +google.golang.org/api v0.248.0 h1:hUotakSkcwGdYUqzCRc5yGYsg4wXxpkKlW5ryVqvC1Y= +google.golang.org/api v0.248.0/go.mod h1:yAFUAF56Li7IuIQbTFoLwXTCI6XCFKueOlS7S9e4F9k= google.golang.org/genproto v0.0.0-20250603155806-513f23925822 h1:rHWScKit0gvAPuOnu87KpaYtjK5zBMLcULh7gxkCXu4= google.golang.org/genproto v0.0.0-20250603155806-513f23925822/go.mod h1:HubltRL7rMh0LfnQPkMH4NPDFEWp0jw3vixw7jEM53s= google.golang.org/genproto/googleapis/api v0.0.0-20250603155806-513f23925822 h1:oWVWY3NzT7KJppx2UKhKmzPq4SRe0LdCijVRwvGeikY= google.golang.org/genproto/googleapis/api v0.0.0-20250603155806-513f23925822/go.mod h1:h3c4v36UTKzUiuaOKQ6gr3S+0hovBtUrXzTG/i3+XEc= +google.golang.org/genproto/googleapis/api v0.0.0-20250707201910-8d1bb00bc6a7 h1:FiusG7LWj+4byqhbvmB+Q93B/mOxJLN2DTozDuZm4EU= google.golang.org/genproto/googleapis/rpc v0.0.0-20250811230008-5f3141c8851a h1:tPE/Kp+x9dMSwUm/uM0JKK0IfdiJkwAbSMSeZBXXJXc= google.golang.org/genproto/googleapis/rpc v0.0.0-20250811230008-5f3141c8851a/go.mod h1:gw1tLEfykwDz2ET4a12jcXt4couGAm7IwsVaTy0Sflo= +google.golang.org/genproto/googleapis/rpc v0.0.0-20250826171959-ef028d996bc1 h1:pmJpJEvT846VzausCQ5d7KreSROcDqmO388w5YbnltA= +google.golang.org/genproto/googleapis/rpc v0.0.0-20250826171959-ef028d996bc1/go.mod h1:GmFNa4BdJZ2a8G+wCe9Bg3wwThLrJun751XstdJt5Og= google.golang.org/grpc v1.74.2 h1:WoosgB65DlWVC9FqI82dGsZhWFNBSLjQ84bjROOpMu4= google.golang.org/grpc v1.74.2/go.mod h1:CtQ+BGjaAIXHs/5YS3i473GqwBBa1zGQNevxdeBEXrM= +google.golang.org/grpc v1.75.0 h1:+TW+dqTd2Biwe6KKfhE5JpiYIBWq865PhKGSXiivqt4= +google.golang.org/grpc v1.75.0/go.mod h1:JtPAzKiq4v1xcAB2hydNlWI2RnF85XXcV0mhKXr2ecQ= google.golang.org/protobuf v1.36.7 h1:IgrO7UwFQGJdRNXH/sQux4R1Dj1WAKcLElzeeRaXV2A= google.golang.org/protobuf v1.36.7/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY= +google.golang.org/protobuf v1.36.8 h1:xHScyCOEuuwZEc6UtSOvPbAT4zRh0xcNRYekJwfqyMc= +google.golang.org/protobuf v1.36.8/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/idp/internal/controllers/oauth_dynamic_registration.go b/idp/internal/controllers/oauth_dynamic_registration.go index 7a2854a..3d94719 100644 --- a/idp/internal/controllers/oauth_dynamic_registration.go +++ b/idp/internal/controllers/oauth_dynamic_registration.go @@ -475,6 +475,110 @@ func (c *Controllers) OAuthDynamicRegistrationIAT2FAPost(ctx *fiber.Ctx) error { return ctx.Redirect(redirectURL, fiber.StatusSeeOther) } +func (c *Controllers) OAuthDynamicRegistrationIATExtAuthGet(ctx *fiber.Ctx) error { + requestID := getRequestID(ctx) + logger := c.buildLogger(requestID, oauthDynamicRegistration, "OAuthDynamicRegistrationIATExtAuthGet") + logRequest(logger, ctx) + + uPrms := params.OAuthDynamicRegistrationIATExtAuthURLParams{ + ACCClientID: ctx.Params("accClientID"), + Provider: ctx.Params("provider"), + } + if err := c.validate.StructCtx(ctx.UserContext(), &uPrms); err != nil { + return serviceErrorHTMLResponse(logger, ctx, exceptions.NewNotFoundError()) + } + + baseQPrms := params.OAuthDynamicRegistrationIATAuthBaseQueryParams{ + ClientID: ctx.Query("client_id"), + RedirectURI: ctx.Query("redirect_uri"), + } + if err := c.validate.StructCtx(ctx.UserContext(), baseQPrms); err != nil { + return serviceErrorHTMLResponse(logger, ctx, exceptions.NewForbiddenError()) + } + + responseType := ctx.Query("response_type") + state := ctx.Query("state") + if responseType != "code" { + return c.redirectErrorCallback(logger, ctx, baseQPrms.RedirectURI, state, exceptions.OAuthErrorUnsupportedResponseType) + } + + qPrms := params.OAuthDynamicRegistrationIATAuthQueryParams{ + ResponseType: responseType, + Challenge: ctx.Query("code_challenge"), + ChallengeMethod: ctx.Query("code_challenge_method"), + State: state, + } + if err := c.validate.StructCtx(ctx.UserContext(), qPrms); err != nil { + return c.redirectErrorCallback(logger, ctx, baseQPrms.RedirectURI, state, exceptions.OAuthErrorInvalidRequest) + } + + authURL, serviceErr := c.services.OAuthDynamicRegistrationIATExtGet( + ctx.UserContext(), + services.OAuthDynamicRegistrationIATExtGetOptions{ + RequestID: requestID, + ACCClientID: uPrms.ACCClientID, + Provider: uPrms.Provider, + Domain: baseQPrms.ClientID, + CallbackURL: baseQPrms.RedirectURI, + RedirectURI: baseQPrms.RedirectURI, + State: qPrms.State, + BackendDomain: c.backendDomain, + }, + ) + if serviceErr != nil { + return serviceErrorHTMLResponse(logger, ctx, serviceErr) + } + + logResponse(logger, ctx, fiber.StatusFound) + return ctx.Redirect(authURL, fiber.StatusFound) +} + +func (c *Controllers) OAuthDynamicRegistrationIATExtCB(ctx *fiber.Ctx) error { + requestID := getRequestID(ctx) + logger := c.buildLogger(requestID, oauthDynamicRegistration, "OAuthDynamicRegistrationIATExtCB") + logRequest(logger, ctx) + + uPrms := params.OAuthDynamicRegistrationIATExtAuthURLParams{ + ACCClientID: ctx.Params("accClientID"), + Provider: ctx.Params("provider"), + } + if err := c.validate.StructCtx(ctx.UserContext(), &uPrms); err != nil { + return serviceErrorHTMLResponse(logger, ctx, exceptions.NewNotFoundError()) + } + + qPrms := params.OAuthCallbackQueryParams{ + Code: ctx.Query("code"), + State: ctx.Query("state"), + } + if err := c.validate.StructCtx(ctx.UserContext(), &qPrms); err != nil { + return serviceErrorHTMLResponse(logger, ctx, exceptions.NewForbiddenError()) + } + + cbURL, serviceErr := c.services.OAuthDynamicRegistrationIATExtCB( + ctx.UserContext(), + services.OAuthDynamicRegistrationIATExtCBOptions{ + RequestID: requestID, + ACCClientID: uPrms.ACCClientID, + Provider: uPrms.Provider, + State: qPrms.State, + Code: qPrms.Code, + RedirectURL: "https://" + c.backendDomain + paths.V1 + paths.AccountsBase + + paths.CredentialsBase + paths.DynamicRegistrationBase + paths.InitialAccessToken + + "/" + uPrms.ACCClientID + paths.OAuthAuth + paths.InitialAccessTokenAuthEXT + "/" + + uPrms.Provider + paths.InitialAccessTokenCallback, + BackendDomain: c.backendDomain, + }, + ) + if serviceErr != nil { + return serviceErrorHTMLResponse(logger, ctx, serviceErr) + } + + logResponse(logger, ctx, fiber.StatusFound) + return ctx.Redirect(cbURL, fiber.StatusFound) +} + +// TODO: add Apple callback + func (c *Controllers) OAuthDynamicRegistrationIATToken(ctx *fiber.Ctx) error { requestID := getRequestID(ctx) logger := c.buildLogger(requestID, oauthDynamicRegistration, "OAuthDynamicRegistrationIATToken") diff --git a/idp/internal/controllers/params/oauth_dynamic_registration.go b/idp/internal/controllers/params/oauth_dynamic_registration.go index 803c13f..515c706 100644 --- a/idp/internal/controllers/params/oauth_dynamic_registration.go +++ b/idp/internal/controllers/params/oauth_dynamic_registration.go @@ -21,3 +21,8 @@ type OAuthDynamicRegistrationIATAuthQueryParams struct { type OAuthDynamicRegistrationIATAuthURLParams struct { ACCClientID string `validate:"required,min=22,max=22,alphanum"` } + +type OAuthDynamicRegistrationIATExtAuthURLParams struct { + ACCClientID string `validate:"required,min=22,max=22,alphanum"` + Provider string `validate:"required,oneof=apple facebook github google microsoft"` +} diff --git a/idp/internal/controllers/paths/dynamic_registration.go b/idp/internal/controllers/paths/dynamic_registration.go index 6395527..82ddc25 100644 --- a/idp/internal/controllers/paths/dynamic_registration.go +++ b/idp/internal/controllers/paths/dynamic_registration.go @@ -7,8 +7,10 @@ package paths const ( - DynamicRegistrationBase string = "/dynamic-registration" - InitialAccessToken string = "/initial-access-token" - InitialAccessTokenAuthEXT string = "/ext" - InitialAccessTokenSingle string = "/:accClientID" + DynamicRegistrationBase string = "/dynamic-registration" + InitialAccessToken string = "/initial-access-token" + InitialAccessTokenAuthEXT string = "/ext" + InitialAccessTokenCallback string = "/callback" + InitialAccessTokenProvider string = "/:provider" + InitialAccessTokenSingle string = "/:accClientID" ) diff --git a/idp/internal/providers/cache/account_credentials_dynamic_registration.go b/idp/internal/providers/cache/account_credentials_dynamic_registration.go index 122c76f..72a913b 100644 --- a/idp/internal/providers/cache/account_credentials_dynamic_registration.go +++ b/idp/internal/providers/cache/account_credentials_dynamic_registration.go @@ -719,3 +719,93 @@ func (c *Cache) DeleteAccountCredentialsRegistrationSessionKey( buildAccountCredentialsDynamicRegistrationSessionCacheKey(opts.Domain, opts.ClientID), ) } + +func buildAccountCredentialsDynamicRegistrationIATExtAuthCacheKey(provider, state string) string { + return fmt.Sprintf("%s:ext-auth:%s:%s", accountCredentialsDynamicRegistrationIATPrefix, provider, utils.Sha256HashHex(state)) +} + +type AccountCredentialsDynamicRegistrationIATExtAuthData struct { + ClientID string `json:"client_id"` + Domain string `json:"domain"` + RequestState string `json:"request_state"` +} + +type SaveAccountCredentialsDynamicRegistrationIATExtAuthOptions struct { + RequestID string + ClientID string + Domain string + Provider string + State string + RequestState string +} + +func (c *Cache) SaveAccountCredentialsDynamicRegistrationIATExtAuth( + ctx context.Context, + opts SaveAccountCredentialsDynamicRegistrationIATExtAuthOptions, +) error { + logger := utils.BuildLogger(c.logger, utils.LoggerOptions{ + Location: accountCredentialsDynamicRegistrationLocation, + Method: "SaveAccountCredentialsDynamicRegistrationIATExtAuth", + RequestID: opts.RequestID, + }).With( + "clientId", opts.ClientID, + "provider", opts.Provider, + ) + logger.DebugContext(ctx, "Saving account credentials dynamic registration IAT external auth...") + + data := AccountCredentialsDynamicRegistrationIATExtAuthData{ + ClientID: opts.ClientID, + Domain: opts.Domain, + RequestState: opts.RequestState, + } + dataBytes, err := json.Marshal(data) + if err != nil { + logger.ErrorContext(ctx, "Failed to marshal account credentials dynamic registration IAT external auth data", "error", err) + return err + } + + return c.storage.SetWithContext( + ctx, + buildAccountCredentialsDynamicRegistrationIATExtAuthCacheKey(opts.Provider, opts.State), + dataBytes, + c.oauthStateTTL, + ) +} + +type GetAccountCredentialsDynamicRegistrationIATExtAuthOptions struct { + RequestID string + Provider string + State string +} + +func (c *Cache) GetAccountCredentialsDynamicRegistrationIATExtAuth( + ctx context.Context, + opts GetAccountCredentialsDynamicRegistrationIATExtAuthOptions, +) (AccountCredentialsDynamicRegistrationIATExtAuthData, bool, error) { + logger := utils.BuildLogger(c.logger, utils.LoggerOptions{ + Location: accountCredentialsDynamicRegistrationLocation, + Method: "GetAccountCredentialsDynamicRegistrationIATExtAuth", + RequestID: opts.RequestID, + }).With( + "provider", opts.Provider, + ) + logger.DebugContext(ctx, "Getting account credentials dynamic registration IAT external auth...") + + data, err := c.storage.GetWithContext(ctx, buildAccountCredentialsDynamicRegistrationIATExtAuthCacheKey(opts.Provider, opts.State)) + if err != nil { + logger.ErrorContext(ctx, "Failed to get account credentials dynamic registration IAT external auth", "error", err) + return AccountCredentialsDynamicRegistrationIATExtAuthData{}, false, err + } + if data == nil { + logger.DebugContext(ctx, "Account credentials dynamic registration IAT external auth not found") + return AccountCredentialsDynamicRegistrationIATExtAuthData{}, false, nil + } + + var authData AccountCredentialsDynamicRegistrationIATExtAuthData + if err := json.Unmarshal(data, &authData); err != nil { + logger.ErrorContext(ctx, "Failed to unmarshal account credentials dynamic registration IAT external auth data", "error", err) + return AccountCredentialsDynamicRegistrationIATExtAuthData{}, false, err + } + + return authData, true, nil +} diff --git a/idp/internal/server/routes/account_dynamic_registration.go b/idp/internal/server/routes/account_dynamic_registration.go index eebd5a5..d879706 100644 --- a/idp/internal/server/routes/account_dynamic_registration.go +++ b/idp/internal/server/routes/account_dynamic_registration.go @@ -86,11 +86,21 @@ func (r *Routes) AccountDynamicRegistrationConfigurationRoutes(app *fiber.App) { // Dynamic Registration IAT Code Exchange flow iatRouter.Get(paths.OAuthAuth, r.controllers.OAuthDynamicRegistrationIATAuth) + iatRouter.Post(paths.OAuthToken, r.controllers.OAuthDynamicRegistrationIATToken) + + // Dynamic Registration IAT Login flow const loginRoute = paths.InitialAccessTokenSingle + paths.AuthLogin iatRouter.Get(loginRoute, r.controllers.OAuthDynamicRegistrationIATLoginGet) iatRouter.Post(loginRoute, r.controllers.OAuthDynamicRegistrationIATLoginPost) - iatRouter.Get(loginRoute+paths.Auth2FA, r.controllers.OAuthDynamicRegistrationIAT2FAGet) - iatRouter.Post(loginRoute+paths.Auth2FA, r.controllers.OAuthDynamicRegistrationIAT2FAPost) - iatRouter.Post(paths.OAuthToken, r.controllers.OAuthDynamicRegistrationIATToken) + // Dynamic Registration IAT 2FA flow + const twoFAAuthRoute = loginRoute + paths.Auth2FA + iatRouter.Get(twoFAAuthRoute, r.controllers.OAuthDynamicRegistrationIAT2FAGet) + iatRouter.Post(twoFAAuthRoute, r.controllers.OAuthDynamicRegistrationIAT2FAPost) + + // Dynamic Registration IAT External Auth flow + const extAuthRoute = paths.InitialAccessTokenSingle + paths.OAuthAuth + paths.InitialAccessTokenAuthEXT + iatRouter.Get(extAuthRoute+paths.InitialAccessTokenProvider, r.controllers.OAuthDynamicRegistrationIATExtAuthGet) + // TODO: add Apple callback + iatRouter.Get(extAuthRoute+paths.OAuthCallback, r.controllers.OAuthDynamicRegistrationIATExtCB) } diff --git a/idp/internal/services/oauth_dynamic_registration.go b/idp/internal/services/oauth_dynamic_registration.go index 61a2a3f..ba26698 100644 --- a/idp/internal/services/oauth_dynamic_registration.go +++ b/idp/internal/services/oauth_dynamic_registration.go @@ -20,6 +20,7 @@ import ( "github.com/tugascript/devlogs/idp/internal/providers/crypto" "github.com/tugascript/devlogs/idp/internal/providers/database" "github.com/tugascript/devlogs/idp/internal/providers/mailer" + "github.com/tugascript/devlogs/idp/internal/providers/oauth" "github.com/tugascript/devlogs/idp/internal/providers/tokens" "github.com/tugascript/devlogs/idp/internal/services/dtos" "github.com/tugascript/devlogs/idp/internal/services/templates" @@ -75,7 +76,6 @@ type generateOAuthDynamicRegistrationIATCallbackOptions struct { accountPublicID uuid.UUID accountVersion int32 challenge string - challengeMethod string domain string redirectURI string state string @@ -96,12 +96,6 @@ func (s *Services) generateOAuthDynamicRegistrationIATCallback( ) logger.InfoContext(ctx, "Generating OAuth dynamic registration IAT callback...") - hashedChallenge, serviceErr := hashChallenge(opts.challenge, opts.challengeMethod) - if serviceErr != nil { - logger.ErrorContext(ctx, "Invalid code challenge", "serviceError", serviceErr) - return "", serviceErr - } - code, err := s.cache.GenerateAccountCredentialsRegistrationIATCode( ctx, cache.GenerateAccountCredentialsRegistrationIATCodeOptions{ @@ -109,7 +103,7 @@ func (s *Services) generateOAuthDynamicRegistrationIATCallback( ClientID: opts.clientID, AccountPublicID: opts.accountPublicID, AccountVersion: opts.accountVersion, - Challenge: hashedChallenge, + Challenge: opts.challenge, Domain: opts.domain, }, ) @@ -256,7 +250,7 @@ func (s *Services) refreshTokenOAuthDynamicRegistrationIATLogin( } if !slices.ContainsFunc(data.Scopes, func(s string) bool { - return s == tokens.AccountScopeAdmin || s == tokens.AccountScopeCredentialsConfigsWrite + return s == tokens.AccountScopeAdmin || s == tokens.AccountScopeCredentialsWrite }) { logger.WarnContext(ctx, "Refresh token missing offline_access scope") return buildOAuthDynamicRegistrationIATLoginURL(buildOAuthDynamicRegistrationIATLoginURLOptions{ @@ -307,6 +301,12 @@ func (s *Services) refreshTokenOAuthDynamicRegistrationIATLogin( }) } + hashedChallenge, serviceErr := hashChallenge(opts.challenge, opts.challengeMethod) + if serviceErr != nil { + logger.ErrorContext(ctx, "Invalid code challenge", "serviceError", serviceErr) + return "", serviceErr + } + cbURL, serviceErr := s.generateOAuthDynamicRegistrationIATCallback( ctx, generateOAuthDynamicRegistrationIATCallbackOptions{ @@ -314,8 +314,7 @@ func (s *Services) refreshTokenOAuthDynamicRegistrationIATLogin( clientID: utils.Base62UUID(), accountPublicID: accountDTO.PublicID, accountVersion: accountDTO.Version(), - challenge: opts.challenge, - challengeMethod: opts.challengeMethod, + challenge: hashedChallenge, domain: opts.domain, redirectURI: opts.redirectURI, state: opts.state, @@ -427,6 +426,12 @@ func (s *Services) InitiateOAuthDynamicRegistrationIATAuth( return "", serviceErr } + hashedChallenge, serviceErr := hashChallenge(opts.Challenge, opts.ChallengeMethod) + if serviceErr != nil { + logger.ErrorContext(ctx, "Invalid code challenge", "serviceError", serviceErr) + return "", serviceErr + } + logger.InfoContext(ctx, "Successfully verified account credentials registration IAT session key, creating code...") cbURL, serviceErr := s.generateOAuthDynamicRegistrationIATCallback( ctx, @@ -435,8 +440,7 @@ func (s *Services) InitiateOAuthDynamicRegistrationIATAuth( clientID: credsClientID, accountPublicID: accountDTO.PublicID, accountVersion: accountDTO.Version(), - challenge: opts.Challenge, - challengeMethod: opts.ChallengeMethod, + challenge: hashedChallenge, domain: opts.Domain, redirectURI: opts.RedirectURI, state: opts.State, @@ -1172,3 +1176,228 @@ func (s *Services) VerifyOAuthDynamicRegistrationIATCode( logger.InfoContext(ctx, "Verified account credentials registration IAT code successfully") return dtos.NewAuthDTO(signedToken, tokenTTL), nil } + +type OAuthDynamicRegistrationIATExtGetOptions struct { + RequestID string + ACCClientID string + Domain string + Provider string + CallbackURL string + RedirectURI string + State string + BackendDomain string +} + +func (s *Services) OAuthDynamicRegistrationIATExtGet( + ctx context.Context, + opts OAuthDynamicRegistrationIATExtGetOptions, +) (string, *exceptions.ServiceError) { + logger := s.buildLogger(opts.RequestID, oauthLocation, "OAuthDynamicRegistrationIATExtGet").With( + "Provider", opts.Provider, + ) + logger.InfoContext(ctx, "External logging in account...") + + authUrlOpts := oauth.AuthorizationURLOptions{ + RequestID: opts.RequestID, + Scopes: oauthScopes, + RedirectURL: opts.CallbackURL, + } + var oauthUrl, state string + var serviceErr *exceptions.ServiceError + switch opts.Provider { + case AuthProviderApple: + oauthUrl, state, serviceErr = s.oauthProviders.GetAppleAuthorizationURL(ctx, authUrlOpts) + case AuthProviderFacebook: + oauthUrl, state, serviceErr = s.oauthProviders.GetFacebookAuthorizationURL(ctx, authUrlOpts) + case AuthProviderGitHub: + oauthUrl, state, serviceErr = s.oauthProviders.GetGithubAuthorizationURL(ctx, authUrlOpts) + case AuthProviderGoogle: + oauthUrl, state, serviceErr = s.oauthProviders.GetGoogleAuthorizationURL(ctx, authUrlOpts) + case AuthProviderMicrosoft: + oauthUrl, state, serviceErr = s.oauthProviders.GetMicrosoftAuthorizationURL(ctx, authUrlOpts) + default: + logger.ErrorContext(ctx, "Provider must be 'apple', 'facebook', 'github', 'google' and 'microsoft'") + return "", exceptions.NewInternalServerError() + } + if serviceErr != nil { + logger.ErrorContext(ctx, "Failed to get authorization url or State", "error", serviceErr) + return "", serviceErr + } + + data, found, err := s.cache.GetAccountCredentialsDynamicRegistrationAuthIAT(ctx, cache.GetAccountCredentialsDynamicRegistrationIATAuthOptions{ + RequestID: opts.RequestID, + ClientID: opts.ACCClientID, + }) + if err != nil { + logger.ErrorContext(ctx, "Failed to get account credentials dynamic registration IAT", "error", err) + return "", exceptions.NewInternalServerError() + } + if !found { + logger.ErrorContext(ctx, "Account credentials dynamic registration IAT not found") + return "", exceptions.NewNotFoundError() + } + + if data.Domain != opts.Domain { + logger.WarnContext(ctx, "OAuth Domain does not match", "dataDomain", data.Domain) + return "", exceptions.NewUnauthorizedError() + } + if data.State != opts.State { + logger.WarnContext(ctx, "OAuth State does not match") + return "", exceptions.NewUnauthorizedError() + } + if data.RedirectURI != opts.RedirectURI { + logger.WarnContext(ctx, "OAuth Redirect URI does not match") + return "", exceptions.NewUnauthorizedError() + } + + if err := s.cache.SaveAccountCredentialsDynamicRegistrationIATExtAuth(ctx, cache.SaveAccountCredentialsDynamicRegistrationIATExtAuthOptions{ + RequestID: opts.RequestID, + ClientID: opts.ACCClientID, + Domain: opts.Domain, + Provider: opts.Provider, + State: state, + RequestState: opts.State, + }); err != nil { + logger.ErrorContext(ctx, "Failed to save account credentials dynamic registration IAT external auth", "error", err) + return "", exceptions.NewInternalServerError() + } + + logger.InfoContext(ctx, "Saved account credentials dynamic registration IAT external auth successfully") + return oauthUrl, nil +} + +type OAuthDynamicRegistrationIATExtCBOptions struct { + RequestID string + ACCClientID string + Provider string + State string + Code string + RedirectURL string + BackendDomain string +} + +func (s *Services) OAuthDynamicRegistrationIATExtCB( + ctx context.Context, + opts OAuthDynamicRegistrationIATExtCBOptions, +) (string, *exceptions.ServiceError) { + logger := s.buildLogger(opts.RequestID, oauthLocation, "OAuthDynamicRegistrationIATExtCB") + logger.InfoContext(ctx, "External callback for account...") + + data, found, err := s.cache.GetAccountCredentialsDynamicRegistrationIATExtAuth(ctx, cache.GetAccountCredentialsDynamicRegistrationIATExtAuthOptions{ + RequestID: opts.RequestID, + Provider: opts.Provider, + State: opts.State, + }) + if err != nil { + logger.ErrorContext(ctx, "Failed to get account credentials dynamic registration IAT external auth", "error", err) + return "", exceptions.NewInternalServerError() + } + if !found { + logger.ErrorContext(ctx, "Account credentials dynamic registration IAT external auth not found") + return "", exceptions.NewNotFoundError() + } + + if data.ClientID != opts.ACCClientID { + logger.WarnContext(ctx, "Client IDs do not match", "dataClientId", data.ClientID) + return "", exceptions.NewUnauthorizedError() + } + + accessTokenOpts := oauth.AccessTokenOptions{ + RequestID: opts.RequestID, + Code: opts.Code, + Scopes: oauthScopes, + RedirectURL: opts.RedirectURL, + } + var token string + var serviceErr *exceptions.ServiceError + switch opts.Provider { + case AuthProviderFacebook: + token, serviceErr = s.oauthProviders.GetFacebookAccessToken(ctx, accessTokenOpts) + case AuthProviderGitHub: + token, serviceErr = s.oauthProviders.GetGithubAccessToken(ctx, accessTokenOpts) + case AuthProviderGoogle: + token, serviceErr = s.oauthProviders.GetGoogleAccessToken(ctx, accessTokenOpts) + case AuthProviderMicrosoft: + token, serviceErr = s.oauthProviders.GetMicrosoftAccessToken(ctx, accessTokenOpts) + default: + logger.ErrorContext(ctx, "Provider must be 'facebook', 'github', 'google' and 'microsoft'") + return "", exceptions.NewInternalServerError() + } + if serviceErr != nil { + logger.WarnContext(ctx, "Failed to get oauth access token", "error", serviceErr) + return "", serviceErr + } + + authData, found, err := s.cache.GetAccountCredentialsDynamicRegistrationAuthIAT(ctx, cache.GetAccountCredentialsDynamicRegistrationIATAuthOptions{ + RequestID: opts.RequestID, + ClientID: opts.ACCClientID, + }) + if err != nil { + logger.ErrorContext(ctx, "Failed to get account credentials dynamic registration IAT", "error", err) + return "", exceptions.NewInternalServerError() + } + if !found { + logger.ErrorContext(ctx, "Account credentials dynamic registration IAT not found") + return "", exceptions.NewNotFoundError() + } + + if authData.Domain != data.Domain { + logger.WarnContext(ctx, "OAuth Domain does not match", "dataDomain", authData.Domain) + return "", exceptions.NewUnauthorizedError() + } + if authData.State != data.RequestState { + logger.WarnContext(ctx, "OAuth State does not match") + return "", exceptions.NewUnauthorizedError() + } + + userData, serviceErr := s.extOAuthUser(ctx, logger, extOAuthUserOptions{ + requestID: opts.RequestID, + provider: opts.Provider, + token: token, + }) + if serviceErr != nil { + return "", serviceErr + } + + accountDTO, serviceErr := s.GetAccountByEmail(ctx, GetAccountByEmailOptions{ + RequestID: opts.RequestID, + Email: userData.Email, + }) + if serviceErr != nil { + return "", serviceErr + } + if _, serviceErr := s.GetAccountAuthProvider(ctx, GetAccountAuthProviderOptions{ + RequestID: opts.RequestID, + PublicID: accountDTO.PublicID, + Provider: opts.Provider, + }); serviceErr != nil { + if serviceErr.Code != exceptions.CodeNotFound { + logger.ErrorContext(ctx, "Failed to get account auth provider", "serviceError", serviceErr) + return "", serviceErr + } + + logger.WarnContext(ctx, "Account auth provider not found", "serviceError", serviceErr) + return "", exceptions.NewUnauthorizedError() + } + + cbURL, serviceErr := s.generateOAuthDynamicRegistrationIATCallback( + ctx, + generateOAuthDynamicRegistrationIATCallbackOptions{ + requestID: opts.RequestID, + clientID: opts.ACCClientID, + accountPublicID: accountDTO.PublicID, + accountVersion: accountDTO.Version(), + challenge: authData.Challenge, + domain: authData.Domain, + redirectURI: authData.RedirectURI, + state: authData.State, + backendDomain: opts.BackendDomain, + }, + ) + if serviceErr != nil { + logger.ErrorContext(ctx, "Failed to generate OAuth dynamic registration IAT callback", "serviceError", serviceErr) + return "", serviceErr + } + + return cbURL, nil +} From b89f31ddbfcb97688d5a270cf20e5dfd2c874768 Mon Sep 17 00:00:00 2001 From: Afonso Barracha Date: Thu, 4 Sep 2025 07:04:42 +1200 Subject: [PATCH 16/23] feat(idp): add apple callback for dynamic iat --- .../bodies/oauth_dynamic_registration.go | 10 ++ idp/internal/controllers/oauth.go | 10 +- .../controllers/oauth_dynamic_registration.go | 57 +++++++- .../params/oauth_dynamic_registration.go | 6 +- .../routes/account_dynamic_registration.go | 2 +- .../services/oauth_dynamic_registration.go | 127 +++++++++++++++++- 6 files changed, 203 insertions(+), 9 deletions(-) diff --git a/idp/internal/controllers/bodies/oauth_dynamic_registration.go b/idp/internal/controllers/bodies/oauth_dynamic_registration.go index f733318..3ef4948 100644 --- a/idp/internal/controllers/bodies/oauth_dynamic_registration.go +++ b/idp/internal/controllers/bodies/oauth_dynamic_registration.go @@ -41,3 +41,13 @@ type OAuthDynamicRegistrationIATTokenBody struct { Code string `json:"code" validate:"required,min=1"` CodeVerifier string `json:"code_verifier" validate:"required,min=1"` } + +type OAuthDynamicRegistrationIATExtAppleUserBody struct { + Email string `json:"email" validate:"required,email"` +} + +type OAuthDynamicRegistrationIATExtAppleBody struct { + Code string `json:"code" validate:"required,min=1"` + State string `json:"state" validate:"required,min=1"` + User string `json:"user" validate:"required,json"` +} diff --git a/idp/internal/controllers/oauth.go b/idp/internal/controllers/oauth.go index 0650122..8e57b77 100644 --- a/idp/internal/controllers/oauth.go +++ b/idp/internal/controllers/oauth.go @@ -152,19 +152,19 @@ func (c *Controllers) AccountAppleCallback(ctx *fiber.Ctx) error { body := new(bodies.AppleLoginBody) if err := ctx.BodyParser(body); err != nil { - return c.errorCallback(logger, ctx, "", exceptions.OAuthErrorInvalidRequest) + return c.errorCallback(logger, ctx, body.State, exceptions.OAuthErrorInvalidRequest) } if err := c.validate.StructCtx(ctx.UserContext(), body); err != nil { - return c.errorCallback(logger, ctx, "", exceptions.OAuthErrorInvalidRequest) + return c.errorCallback(logger, ctx, body.State, exceptions.OAuthErrorInvalidRequest) } user := new(bodies.AppleUser) if err := json.Unmarshal([]byte(body.User), user); err != nil { - return c.errorCallback(logger, ctx, "", exceptions.OAuthErrorInvalidScope) + return c.errorCallback(logger, ctx, body.State, exceptions.OAuthErrorInvalidScope) } if err := c.validate.StructCtx(ctx.UserContext(), user); err != nil { logger.WarnContext(ctx.UserContext(), "Failed to parse apple user data") - return c.errorCallback(logger, ctx, "", exceptions.OAuthErrorInvalidScope) + return c.errorCallback(logger, ctx, body.State, exceptions.OAuthErrorInvalidScope) } oauthParams, serviceErr := c.services.AppleLoginAccount(ctx.UserContext(), services.AppleLoginAccountOptions{ @@ -176,7 +176,7 @@ func (c *Controllers) AccountAppleCallback(ctx *fiber.Ctx) error { State: body.State, }) if serviceErr != nil { - return c.serviceErrorCallback(logger, ctx, "", serviceErr) + return c.serviceErrorCallback(logger, ctx, body.State, serviceErr) } return c.acceptCallback(logger, ctx, oauthParams) diff --git a/idp/internal/controllers/oauth_dynamic_registration.go b/idp/internal/controllers/oauth_dynamic_registration.go index 3d94719..e7a49d1 100644 --- a/idp/internal/controllers/oauth_dynamic_registration.go +++ b/idp/internal/controllers/oauth_dynamic_registration.go @@ -7,6 +7,7 @@ package controllers import ( + "encoding/json" "fmt" "github.com/gofiber/fiber/v2" @@ -577,7 +578,61 @@ func (c *Controllers) OAuthDynamicRegistrationIATExtCB(ctx *fiber.Ctx) error { return ctx.Redirect(cbURL, fiber.StatusFound) } -// TODO: add Apple callback +func (c *Controllers) OAuthDynamicRegistrationIATExtAppleCB(ctx *fiber.Ctx) error { + requestID := getRequestID(ctx) + logger := c.buildLogger(requestID, oauthDynamicRegistration, "OAuthDynamicRegistrationIATExtAppleCB") + logRequest(logger, ctx) + + uPrms := params.OAuthDynamicRegistrationIATExtAppleURLParams{ + ACCClientID: ctx.Params("accClientID"), + } + if err := c.validate.StructCtx(ctx.UserContext(), &uPrms); err != nil { + return serviceErrorHTMLResponse(logger, ctx, exceptions.NewNotFoundError()) + } + + if ctx.Get("Content-Type") != "application/x-www-form-urlencoded" { + return serviceErrorHTMLResponse(logger, ctx, exceptions.NewUnsupportedMediaTypeError("Only application/x-www-form-urlencoded is supported")) + } + + qPrms := bodies.OAuthDynamicRegistrationIATExtAppleBody{ + Code: ctx.FormValue("code"), + State: ctx.FormValue("state"), + User: ctx.FormValue("user"), + } + if err := c.validate.StructCtx(ctx.UserContext(), &qPrms); err != nil { + return serviceErrorHTMLResponse(logger, ctx, exceptions.NewForbiddenError()) + } + + user := new(bodies.OAuthDynamicRegistrationIATExtAppleUserBody) + if err := json.Unmarshal([]byte(qPrms.User), user); err != nil { + return serviceErrorHTMLResponse(logger, ctx, exceptions.NewForbiddenError()) + } + if err := c.validate.StructCtx(ctx.UserContext(), user); err != nil { + return serviceErrorHTMLResponse(logger, ctx, exceptions.NewForbiddenError()) + } + + cbURL, serviceErr := c.services.OAuthDynamicRegistrationIATExtAppleCB( + ctx.UserContext(), + services.OAuthDynamicRegistrationIATExtAppleCBOptions{ + RequestID: requestID, + ACCClientID: uPrms.ACCClientID, + Email: user.Email, + Code: qPrms.Code, + State: qPrms.State, + RedirectURL: "https://" + c.backendDomain + paths.V1 + paths.AccountsBase + + paths.CredentialsBase + paths.DynamicRegistrationBase + paths.InitialAccessToken + + "/" + uPrms.ACCClientID + paths.OAuthAuth + paths.InitialAccessTokenAuthEXT + "/" + + services.AuthProviderApple + paths.InitialAccessTokenCallback, + BackendDomain: c.backendDomain, + }, + ) + if serviceErr != nil { + return serviceErrorHTMLResponse(logger, ctx, serviceErr) + } + + logResponse(logger, ctx, fiber.StatusFound) + return ctx.Redirect(cbURL, fiber.StatusFound) +} func (c *Controllers) OAuthDynamicRegistrationIATToken(ctx *fiber.Ctx) error { requestID := getRequestID(ctx) diff --git a/idp/internal/controllers/params/oauth_dynamic_registration.go b/idp/internal/controllers/params/oauth_dynamic_registration.go index 515c706..5b8f8fa 100644 --- a/idp/internal/controllers/params/oauth_dynamic_registration.go +++ b/idp/internal/controllers/params/oauth_dynamic_registration.go @@ -24,5 +24,9 @@ type OAuthDynamicRegistrationIATAuthURLParams struct { type OAuthDynamicRegistrationIATExtAuthURLParams struct { ACCClientID string `validate:"required,min=22,max=22,alphanum"` - Provider string `validate:"required,oneof=apple facebook github google microsoft"` + Provider string `validate:"required,oneof=facebook github google microsoft"` +} + +type OAuthDynamicRegistrationIATExtAppleURLParams struct { + ACCClientID string `validate:"required,min=22,max=22,alphanum"` } diff --git a/idp/internal/server/routes/account_dynamic_registration.go b/idp/internal/server/routes/account_dynamic_registration.go index d879706..6ba8980 100644 --- a/idp/internal/server/routes/account_dynamic_registration.go +++ b/idp/internal/server/routes/account_dynamic_registration.go @@ -101,6 +101,6 @@ func (r *Routes) AccountDynamicRegistrationConfigurationRoutes(app *fiber.App) { // Dynamic Registration IAT External Auth flow const extAuthRoute = paths.InitialAccessTokenSingle + paths.OAuthAuth + paths.InitialAccessTokenAuthEXT iatRouter.Get(extAuthRoute+paths.InitialAccessTokenProvider, r.controllers.OAuthDynamicRegistrationIATExtAuthGet) - // TODO: add Apple callback + iatRouter.Post(extAuthRoute+paths.OAuthAppleCallback, r.controllers.OAuthDynamicRegistrationIATExtAppleCB) iatRouter.Get(extAuthRoute+paths.OAuthCallback, r.controllers.OAuthDynamicRegistrationIATExtCB) } diff --git a/idp/internal/services/oauth_dynamic_registration.go b/idp/internal/services/oauth_dynamic_registration.go index ba26698..0a8464c 100644 --- a/idp/internal/services/oauth_dynamic_registration.go +++ b/idp/internal/services/oauth_dynamic_registration.go @@ -1199,7 +1199,7 @@ func (s *Services) OAuthDynamicRegistrationIATExtGet( authUrlOpts := oauth.AuthorizationURLOptions{ RequestID: opts.RequestID, - Scopes: oauthScopes, + Scopes: make([]oauth.Scope, 0), RedirectURL: opts.CallbackURL, } var oauthUrl, state string @@ -1401,3 +1401,128 @@ func (s *Services) OAuthDynamicRegistrationIATExtCB( return cbURL, nil } + +type OAuthDynamicRegistrationIATExtAppleCBOptions struct { + RequestID string + ACCClientID string + Email string + Code string + State string + RedirectURL string + BackendDomain string +} + +func (s *Services) OAuthDynamicRegistrationIATExtAppleCB( + ctx context.Context, + opts OAuthDynamicRegistrationIATExtAppleCBOptions, +) (string, *exceptions.ServiceError) { + logger := s.buildLogger(opts.RequestID, oauthLocation, "OAuthDynamicRegistrationIATExtAppleCB") + logger.InfoContext(ctx, "External callback for account...") + + data, found, err := s.cache.GetAccountCredentialsDynamicRegistrationIATExtAuth(ctx, cache.GetAccountCredentialsDynamicRegistrationIATExtAuthOptions{ + RequestID: opts.RequestID, + Provider: AuthProviderApple, + State: opts.State, + }) + if err != nil { + logger.ErrorContext(ctx, "Failed to get account credentials dynamic registration IAT external auth", "error", err) + return "", exceptions.NewInternalServerError() + } + if !found { + logger.ErrorContext(ctx, "Account credentials dynamic registration IAT external auth not found") + return "", exceptions.NewNotFoundError() + } + + if data.ClientID != opts.ACCClientID { + logger.WarnContext(ctx, "Client IDs do not match", "dataClientId", data.ClientID) + return "", exceptions.NewUnauthorizedError() + } + + idToken, serviceErr := s.oauthProviders.GetAppleIDToken(ctx, oauth.AccessTokenOptions{ + RequestID: opts.RequestID, + Code: opts.Code, + Scopes: oauthScopes, + }) + if serviceErr != nil { + logger.WarnContext(ctx, "Failed to get apple AccountID token", "error", serviceErr) + return "", serviceErr + } + + authData, found, err := s.cache.GetAccountCredentialsDynamicRegistrationAuthIAT(ctx, cache.GetAccountCredentialsDynamicRegistrationIATAuthOptions{ + RequestID: opts.RequestID, + ClientID: opts.ACCClientID, + }) + if err != nil { + logger.ErrorContext(ctx, "Failed to get account credentials dynamic registration IAT", "error", err) + return "", exceptions.NewInternalServerError() + } + if !found { + logger.ErrorContext(ctx, "Account credentials dynamic registration IAT not found") + return "", exceptions.NewNotFoundError() + } + + if authData.Domain != data.Domain { + logger.WarnContext(ctx, "OAuth Domain does not match", "dataDomain", authData.Domain) + return "", exceptions.NewUnauthorizedError() + } + if authData.State != data.RequestState { + logger.WarnContext(ctx, "OAuth State does not match") + return "", exceptions.NewUnauthorizedError() + } + + ok, serviceErr := s.oauthProviders.ValidateAppleIDToken(ctx, oauth.ValidateAppleIDTokenOptions{ + RequestID: opts.RequestID, + Token: idToken, + Email: opts.Email, + }) + if serviceErr != nil { + logger.ErrorContext(ctx, "Failed to validate apple AccountID token", "error", serviceErr) + return "", serviceErr + } + if !ok { + logger.WarnContext(ctx, "Apple account is not verified") + return "", exceptions.NewUnauthorizedError() + } + + accountDTO, serviceErr := s.GetAccountByEmail(ctx, GetAccountByEmailOptions{ + RequestID: opts.RequestID, + Email: opts.Email, + }) + if serviceErr != nil { + return "", serviceErr + } + if _, serviceErr := s.GetAccountAuthProvider(ctx, GetAccountAuthProviderOptions{ + RequestID: opts.RequestID, + PublicID: accountDTO.PublicID, + Provider: AuthProviderApple, + }); serviceErr != nil { + if serviceErr.Code != exceptions.CodeNotFound { + logger.ErrorContext(ctx, "Failed to get account auth provider", "serviceError", serviceErr) + return "", serviceErr + } + + logger.WarnContext(ctx, "Account auth provider not found", "serviceError", serviceErr) + return "", exceptions.NewUnauthorizedError() + } + + cbURL, serviceErr := s.generateOAuthDynamicRegistrationIATCallback( + ctx, + generateOAuthDynamicRegistrationIATCallbackOptions{ + requestID: opts.RequestID, + clientID: opts.ACCClientID, + accountPublicID: accountDTO.PublicID, + accountVersion: accountDTO.Version(), + challenge: authData.Challenge, + domain: authData.Domain, + redirectURI: authData.RedirectURI, + state: authData.State, + backendDomain: opts.BackendDomain, + }, + ) + if serviceErr != nil { + logger.ErrorContext(ctx, "Failed to generate OAuth dynamic registration IAT callback", "serviceError", serviceErr) + return "", serviceErr + } + + return cbURL, nil +} From adcdbf0e2e693db64dfdf9b81be37b4735dde069 Mon Sep 17 00:00:00 2001 From: Afonso Barracha Date: Thu, 4 Sep 2025 07:09:20 +1200 Subject: [PATCH 17/23] fix(idp): fix linter warnings --- idp/internal/controllers/helpers.go | 24 ------------------------ idp/internal/providers/oauth/oauth.go | 2 +- idp/tests/auth_test.go | 12 ++++++------ idp/tests/oauth_test.go | 6 +++--- 4 files changed, 10 insertions(+), 34 deletions(-) diff --git a/idp/internal/controllers/helpers.go b/idp/internal/controllers/helpers.go index 2742497..3e3d1a4 100644 --- a/idp/internal/controllers/helpers.go +++ b/idp/internal/controllers/helpers.go @@ -93,30 +93,6 @@ func validateQueryParamsErrorResponse(logger *slog.Logger, ctx *fiber.Ctx, err e return validateErrorJSONResponse(logger, ctx, exceptions.ValidationResponseLocationQuery, err) } -func validationErrorHTMLResponse(logger *slog.Logger, ctx *fiber.Ctx, location string, err error) error { - logger.WarnContext(ctx.UserContext(), "Failed to validate request", "error", err) - expt := validationErrorException(location, err) - errHtml, err := templates.BuildErrorTemplate( - templates.ErrorTemplateOptions{ - Status: fiber.StatusBadRequest, - ErrorCode: expt.Code, - MessageTitle: expt.Message, - Messages: utils.MapSlice(expt.Fields, func(f *exceptions.FieldError) string { - return fmt.Sprintf("Field '%s' - Value '%s': %s", f.Param, f.Value, f.Message) - }), - }, - ) - if err != nil { - logger.ErrorContext(ctx.UserContext(), "Failed to build error template", "error", err) - logResponse(logger, ctx, fiber.StatusInternalServerError) - return ctx.Status(fiber.StatusInternalServerError). - Type("html"). - SendString(templates.InternalServerErrorTemplate) - } - - return ctx.Status(fiber.StatusBadRequest).Type("html").SendString(errHtml) -} - func serviceErrorResponse(logger *slog.Logger, ctx *fiber.Ctx, serviceErr *exceptions.ServiceError) error { status := exceptions.NewRequestErrorStatus(serviceErr.Code) resErr := exceptions.NewErrorResponse(serviceErr) diff --git a/idp/internal/providers/oauth/oauth.go b/idp/internal/providers/oauth/oauth.go index 1fb6ce5..61742ec 100644 --- a/idp/internal/providers/oauth/oauth.go +++ b/idp/internal/providers/oauth/oauth.go @@ -214,7 +214,7 @@ func NewProviders( }, Enabled: microsoftCfg.Enabled(), }, - logger: log, + logger: log.With(utils.BaseLayer, logLayer), } } diff --git a/idp/tests/auth_test.go b/idp/tests/auth_test.go index ef68cab..7a08370 100644 --- a/idp/tests/auth_test.go +++ b/idp/tests/auth_test.go @@ -1480,7 +1480,7 @@ func TestListAccountAuthProviders(t *testing.T) { AssertFn: func(t *testing.T, _ string, res *http.Response) { resBody := AssertTestResponseBody(t, res, dtos.ItemsDTO[dtos.AuthProviderDTO]{}) AssertEqual(t, len(resBody.Items), 1) - AssertEqual(t, resBody.Items[0].Provider, services.AuthProviderLocal) + AssertEqual(t, resBody.Items[0].Provider, database.AuthProviderLocal) AssertNotEmpty(t, resBody.Items[0].RegisteredAt) }, }, @@ -1549,11 +1549,11 @@ func TestListAccountAuthProviders(t *testing.T) { AssertFn: func(t *testing.T, _ string, res *http.Response) { resBody := AssertTestResponseBody(t, res, dtos.ItemsDTO[dtos.AuthProviderDTO]{}) AssertEqual(t, len(resBody.Items), 3) - AssertEqual(t, resBody.Items[0].Provider, services.AuthProviderGoogle) + AssertEqual(t, resBody.Items[0].Provider, database.AuthProviderGoogle) AssertNotEmpty(t, resBody.Items[0].RegisteredAt) - AssertEqual(t, resBody.Items[1].Provider, services.AuthProviderMicrosoft) + AssertEqual(t, resBody.Items[1].Provider, database.AuthProviderMicrosoft) AssertNotEmpty(t, resBody.Items[1].RegisteredAt) - AssertEqual(t, resBody.Items[2].Provider, services.AuthProviderLocal) + AssertEqual(t, resBody.Items[2].Provider, database.AuthProviderLocal) AssertNotEmpty(t, resBody.Items[2].RegisteredAt) }, }, @@ -1639,7 +1639,7 @@ func TestGetAccountAuthProvider(t *testing.T) { ExpStatus: http.StatusOK, AssertFn: func(t *testing.T, provider string, res *http.Response) { resBody := AssertTestResponseBody(t, res, dtos.AuthProviderDTO{}) - AssertEqual(t, resBody.Provider, services.AuthProviderApple) + AssertEqual(t, resBody.Provider, database.AuthProviderApple) AssertNotEmpty(t, resBody.RegisteredAt) }, Path: authProviderPath + "/apple", @@ -1690,7 +1690,7 @@ func TestGetAccountAuthProvider(t *testing.T) { ExpStatus: http.StatusOK, AssertFn: func(t *testing.T, provider string, res *http.Response) { resBody := AssertTestResponseBody(t, res, dtos.AuthProviderDTO{}) - AssertEqual(t, resBody.Provider, services.AuthProviderGoogle) + AssertEqual(t, resBody.Provider, database.AuthProviderGoogle) AssertNotEmpty(t, resBody.RegisteredAt) }, Path: authProviderPath + "/google", diff --git a/idp/tests/oauth_test.go b/idp/tests/oauth_test.go index f3aa753..00a84c8 100644 --- a/idp/tests/oauth_test.go +++ b/idp/tests/oauth_test.go @@ -46,7 +46,7 @@ func TestAccountOAuthURL(t *testing.T) { params.Add("response_type", "code") params.Add("scope", "email profile") params.Add("state", generateState(t)) - params.Add("code_challenge", utils.Sha256HashBase64([]byte(generateState(t)))) + params.Add("code_challenge", utils.Sha256HashBase64(generateState(t))) params.Add("code_challenge_method", "S256") return "?" + params.Encode() } @@ -165,7 +165,7 @@ func callbackBeforeEach(t *testing.T, provider string) (string, string, string) testCache := GetTestCache(t) requestID := uuid.NewString() - challenge := utils.Sha256HashBase64([]byte(state + requestID)) + challenge := utils.Sha256HashBase64(state + requestID) stateOpts := cache.SaveOAuthStateDataOptions{ RequestID: requestID, State: state, @@ -695,7 +695,7 @@ func TestOAuthToken(t *testing.T) { GivenName: account.GivenName, FamilyName: account.FamilyName, Provider: provider, - Challenge: utils.Sha256HashBase64([]byte(challenge)), + Challenge: utils.Sha256HashBase64(challenge), }) if err != nil { t.Fatal("Failed to generate OAuth code", err) From ce7522f8a887ff877e12d02f34fb9a9503d295a6 Mon Sep 17 00:00:00 2001 From: Afonso Barracha Date: Sun, 7 Sep 2025 14:23:25 +1200 Subject: [PATCH 18/23] feat(idp): add multiple 2FA providers support --- idp/initial_schema.dbml | 65 +- .../controllers/account_2fa_configs.go | 209 +++++ idp/internal/controllers/auth.go | 83 +- .../controllers/bodies/account_2fa_configs.go | 12 + idp/internal/controllers/middleware.go | 35 +- .../controllers/params/account_2fa_configs.go | 11 + idp/internal/controllers/paths/two_fa.go | 12 + idp/internal/controllers/users_auth.go | 41 +- ...ccount_credentials_dynamic_registration.go | 3 + .../providers/cache/sensitive_requests.go | 130 +-- .../database/account_2fa_configs.sql.go | 196 +++++ .../account_credential_secrets.sql.go | 5 +- .../database/account_credentials_keys.sql.go | 5 +- .../providers/database/accounts.sql.go | 102 +-- ...0241213231542_create_initial_schema.up.sql | 61 +- idp/internal/providers/database/models.go | 112 ++- .../database/queries/account_2fa_configs.sql | 49 ++ .../providers/database/queries/accounts.sql | 14 +- .../providers/database/queries/users.sql | 10 +- idp/internal/providers/database/users.sql.go | 101 +-- idp/internal/providers/tokens/twoFactor.go | 65 +- idp/internal/server/routes/auth.go | 48 +- idp/internal/server/routes/users_auth.go | 8 +- idp/internal/services/account_2fa_configs.go | 830 ++++++++++++++++++ .../account_credentials_registration_iat.go | 12 +- idp/internal/services/accounts.go | 160 +++- idp/internal/services/auth.go | 738 ++++------------ idp/internal/services/deks.go | 39 +- idp/internal/services/dtos/account.go | 12 +- .../services/dtos/account_2fa_config.go | 94 ++ idp/internal/services/dtos/user.go | 8 +- idp/internal/services/helpers.go | 14 - idp/internal/services/jwks.go | 28 +- idp/internal/services/oauth.go | 13 +- .../services/oauth_dynamic_registration.go | 50 +- idp/internal/services/users.go | 1 - idp/internal/services/users_auth.go | 202 +---- 37 files changed, 2372 insertions(+), 1206 deletions(-) create mode 100644 idp/internal/controllers/account_2fa_configs.go create mode 100644 idp/internal/controllers/bodies/account_2fa_configs.go create mode 100644 idp/internal/controllers/params/account_2fa_configs.go create mode 100644 idp/internal/controllers/paths/two_fa.go create mode 100644 idp/internal/providers/database/account_2fa_configs.sql.go create mode 100644 idp/internal/providers/database/queries/account_2fa_configs.sql create mode 100644 idp/internal/services/account_2fa_configs.go create mode 100644 idp/internal/services/dtos/account_2fa_config.go diff --git a/idp/initial_schema.dbml b/idp/initial_schema.dbml index 4892a90..89a9017 100644 --- a/idp/initial_schema.dbml +++ b/idp/initial_schema.dbml @@ -110,10 +110,10 @@ Table token_signing_keys as TS { } Ref: TS.dek_kid > DEK.kid [delete: cascade, update: cascade] -Enum two_factor_type { - "none" - "totp" - "email" +Enum activity_status { + "active" + "suspended" + "blocked" } Table accounts as A { @@ -128,9 +128,7 @@ Table accounts as A { password text version integer [not null, default: 1] email_verified boolean [not null, default: false] - - is_active boolean [not null, default: true] - two_factor_type two_factor_type [not null, default: 'none'] + activity_status activity_status [not null, default: 'active'] created_at timestamptz [not null, default: `now()`] updated_at timestamptz [not null, default: `now()`] @@ -143,6 +141,33 @@ Table accounts as A { } } +Enum two_factor_type { + "totp" + "email" +} + +Table account_2fa_configs as A2FA { + id serial [pk] + + account_id integer [not null] + account_public_id uuid [not null] + + two_factor_type two_factor_type [not null] + is_default boolean [not null, default: false] + is_active boolean [not null, default: false] + + created_at timestamptz [not null, default: `now()`] + updated_at timestamptz [not null, default: `now()`] + + Indexes { + (account_id) [name: 'account_2fa_configs_account_id_idx'] + (account_public_id) [name: 'account_2fa_configs_account_public_id_idx'] + (account_public_id, is_default) [name: 'account_2fa_configs_account_public_id_is_default_idx'] + (account_public_id, two_factor_type) [name: 'account_2fa_configs_account_public_id_two_factor_type_idx'] + } +} +Ref: A2FA.account_id > A.id [delete: cascade] + Enum totp_usage { "account" "user" @@ -545,9 +570,7 @@ Table users as U { password text version integer [not null, default: 1] email_verified boolean [not null, default: false] - - is_active boolean [not null, default: true] - two_factor_type two_factor_type [not null, default: 'none'] + activity_status activity_status [not null, default: 'active'] user_data jsonb [not null, default: '{}'] @@ -564,6 +587,27 @@ Table users as U { } Ref: U.account_id > A.id [delete: cascade] +Table user_2fa_configs as U2FA { + id serial [pk] + account_id integer [not null] + + user_id integer [not null] + two_factor_type two_factor_type [not null] + is_default boolean [not null, default: false] + + created_at timestamptz [not null, default: `now()`] + updated_at timestamptz [not null, default: `now()`] + + Indexes { + (account_id) [name: 'user_2fa_configs_account_id_idx'] + (user_id) [name: 'user_2fa_configs_user_id_idx'] + (two_factor_type) [name: 'user_2fa_configs_two_factor_type_idx'] + (user_id, two_factor_type) [unique, name: 'user_2fa_configs_user_id_two_factor_type_uidx'] + } +} +Ref: U2FA.account_id > A.id [delete: cascade] +Ref: U2FA.user_id > U.id [delete: cascade] + Table user_data_encryption_keys as UDEK { user_id integer [not null] data_encryption_key_id integer [not null] @@ -895,6 +939,7 @@ Enum initial_access_token_generation_method { Enum software_statement_verification_method { "manual" "jwks_uri" + "jwk_x5_parameters" } Table account_dynamic_registration_configs as ADRC { diff --git a/idp/internal/controllers/account_2fa_configs.go b/idp/internal/controllers/account_2fa_configs.go new file mode 100644 index 0000000..12ff4f7 --- /dev/null +++ b/idp/internal/controllers/account_2fa_configs.go @@ -0,0 +1,209 @@ +// Copyright (c) 2025 Afonso Barracha +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +package controllers + +import ( + "github.com/gofiber/fiber/v2" + + "github.com/tugascript/devlogs/idp/internal/controllers/bodies" + "github.com/tugascript/devlogs/idp/internal/controllers/params" + "github.com/tugascript/devlogs/idp/internal/services" +) + +const account2FAConfigsLocation = "account_2fa_configs" + +func (c *Controllers) GetDefaultAccount2FAConfig(ctx *fiber.Ctx) error { + requestID := getRequestID(ctx) + logger := c.buildLogger(requestID, account2FAConfigsLocation, "GetDefaultAccount2FAConfig") + logRequest(logger, ctx) + + accountClaims, serviceErr := getAccountClaims(ctx) + if serviceErr != nil { + return serviceErrorResponse(logger, ctx, serviceErr) + } + + account2FAConfigDTO, serviceErr := c.services.GetDefaultAccount2FAConfig(ctx.UserContext(), services.GetDefaultAccount2FAConfigOptions{ + RequestID: requestID, + AccountPublicID: accountClaims.AccountID, + }) + if serviceErr != nil { + return serviceErrorResponse(logger, ctx, serviceErr) + } + + logResponse(logger, ctx, fiber.StatusOK) + return ctx.Status(fiber.StatusOK).JSON(&account2FAConfigDTO) +} + +func (c *Controllers) GetAccount2FAConfig(ctx *fiber.Ctx) error { + requestID := getRequestID(ctx) + logger := c.buildLogger(requestID, account2FAConfigsLocation, "GetAccount2FAConfig") + logRequest(logger, ctx) + + accountClaims, serviceErr := getAccountClaims(ctx) + if serviceErr != nil { + return serviceErrorResponse(logger, ctx, serviceErr) + } + + urlParams := params.GetAccount2FAConfigURLParams{TwoFAType: ctx.Params("twoFAType")} + if err := c.validate.StructCtx(ctx.UserContext(), &urlParams); err != nil { + return validateURLParamsErrorResponse(logger, ctx, err) + } + + twoFAType, serviceErr := services.Map2FAType(urlParams.TwoFAType) + if serviceErr != nil { + return serviceErrorResponse(logger, ctx, serviceErr) + } + + account2FAConfigDTO, serviceErr := c.services.GetAccount2FAConfig(ctx.UserContext(), services.GetAccount2FAConfigOptions{ + RequestID: requestID, + AccountPublicID: accountClaims.AccountID, + TwoFAType: twoFAType, + }) + if serviceErr != nil { + return serviceErrorResponse(logger, ctx, serviceErr) + } + + logResponse(logger, ctx, fiber.StatusOK) + return ctx.Status(fiber.StatusOK).JSON(&account2FAConfigDTO) +} + +func (c *Controllers) CreateAccount2FAConfig(ctx *fiber.Ctx) error { + requestID := getRequestID(ctx) + logger := c.buildLogger(requestID, account2FAConfigsLocation, "CreateAccount2FAConfig") + logRequest(logger, ctx) + + accountClaims, serviceErr := getAccountClaims(ctx) + if serviceErr != nil { + return serviceErrorResponse(logger, ctx, serviceErr) + } + + body := new(bodies.Account2FAConfigBody) + if err := ctx.BodyParser(body); err != nil { + return parseRequestErrorResponse(logger, ctx, err) + } + if err := c.validate.StructCtx(ctx.UserContext(), body); err != nil { + return validateBodyErrorResponse(logger, ctx, err) + } + + account2FAConfigDTO, serviceErr := c.services.CreateAccount2FAConfig( + ctx.UserContext(), + services.CreateAccount2FAConfigOptions{ + RequestID: requestID, + AccountPublicID: accountClaims.AccountID, + AccountVersion: accountClaims.AccountVersion, + TwoFAType: body.TwoFAType, + IsDefault: body.IsDefault, + }) + if serviceErr != nil { + return serviceErrorResponse(logger, ctx, serviceErr) + } + + logResponse(logger, ctx, fiber.StatusCreated) + return ctx.Status(fiber.StatusCreated).JSON(&account2FAConfigDTO) +} + +func (c *Controllers) SetAccount2FAConfigDefault(ctx *fiber.Ctx) error { + requestID := getRequestID(ctx) + logger := c.buildLogger(requestID, account2FAConfigsLocation, "SetAccount2FAConfigDefault") + logRequest(logger, ctx) + + accountClaims, serviceErr := getAccountClaims(ctx) + if serviceErr != nil { + return serviceErrorResponse(logger, ctx, serviceErr) + } + + urlParams := params.GetAccount2FAConfigURLParams{TwoFAType: ctx.Params("twoFAType")} + if err := c.validate.StructCtx(ctx.UserContext(), &urlParams); err != nil { + return validateURLParamsErrorResponse(logger, ctx, err) + } + + account2FAConfigDTO, serviceErr := c.services.SetAccount2FAConfigDefault( + ctx.UserContext(), + services.SetAccount2FAConfigDefaultOptions{ + RequestID: requestID, + AccountPublicID: accountClaims.AccountID, + AccountVersion: accountClaims.AccountVersion, + TwoFAType: urlParams.TwoFAType, + }, + ) + if serviceErr != nil { + return serviceErrorResponse(logger, ctx, serviceErr) + } + + logResponse(logger, ctx, fiber.StatusOK) + return ctx.Status(fiber.StatusOK).JSON(&account2FAConfigDTO) +} + +func (c *Controllers) DeleteAccount2FAConfig(ctx *fiber.Ctx) error { + requestID := getRequestID(ctx) + logger := c.buildLogger(requestID, account2FAConfigsLocation, "DeleteAccount2FAConfig") + logRequest(logger, ctx) + + accountClaims, serviceErr := getAccountClaims(ctx) + if serviceErr != nil { + return serviceErrorResponse(logger, ctx, serviceErr) + } + + urlParams := params.GetAccount2FAConfigURLParams{TwoFAType: ctx.Params("twoFAType")} + if err := c.validate.StructCtx(ctx.UserContext(), &urlParams); err != nil { + return validateURLParamsErrorResponse(logger, ctx, err) + } + + account2FAConfigDTO, serviceErr := c.services.DeleteAccount2FAConfig(ctx.UserContext(), services.DeleteAccount2FAConfigOptions{ + RequestID: requestID, + AccountPublicID: accountClaims.AccountID, + AccountVersion: accountClaims.AccountVersion, + TwoFAType: urlParams.TwoFAType, + }) + if serviceErr != nil { + return serviceErrorResponse(logger, ctx, serviceErr) + } + + logResponse(logger, ctx, fiber.StatusOK) + return ctx.Status(fiber.StatusOK).JSON(&account2FAConfigDTO) +} + +func (c *Controllers) ConfirmDeleteAccount2FAConfig(ctx *fiber.Ctx) error { + requestID := getRequestID(ctx) + logger := c.buildLogger(requestID, account2FAConfigsLocation, "ConfirmDeleteAccount2FAConfig") + logRequest(logger, ctx) + + accountClaims, serviceErr := getAccountClaims(ctx) + if serviceErr != nil { + return serviceErrorResponse(logger, ctx, serviceErr) + } + + urlParams := params.GetAccount2FAConfigURLParams{TwoFAType: ctx.Params("twoFAType")} + if err := c.validate.StructCtx(ctx.UserContext(), &urlParams); err != nil { + return validateURLParamsErrorResponse(logger, ctx, err) + } + + body := new(bodies.TwoFactorLoginBody) + if err := ctx.BodyParser(body); err != nil { + return parseRequestErrorResponse(logger, ctx, err) + } + if err := c.validate.StructCtx(ctx.UserContext(), body); err != nil { + return validateBodyErrorResponse(logger, ctx, err) + } + + authDTO, serviceErr := c.services.ConfirmDeleteAccount2FAConfig( + ctx.UserContext(), + services.ConfirmDeleteAccount2FAConfigOptions{ + RequestID: requestID, + PublicID: accountClaims.AccountID, + Version: accountClaims.AccountVersion, + TwoFAType: urlParams.TwoFAType, + Code: body.Code, + }, + ) + if serviceErr != nil { + return serviceErrorResponse(logger, ctx, serviceErr) + } + + logResponse(logger, ctx, fiber.StatusOK) + return ctx.Status(fiber.StatusOK).JSON(&authDTO) +} diff --git a/idp/internal/controllers/auth.go b/idp/internal/controllers/auth.go index 43acee7..1550168 100644 --- a/idp/internal/controllers/auth.go +++ b/idp/internal/controllers/auth.go @@ -133,7 +133,7 @@ func (c *Controllers) TwoFactorLoginAccount(ctx *fiber.Ctx) error { logger := c.buildLogger(requestID, authLocation, "TwoFactorLoginAccount") logRequest(logger, ctx) - accountClaims, serviceErr := getAccountClaims(ctx) + accountClaims, twoFAType, serviceErr := getAccounts2FAClaims(ctx) if serviceErr != nil { return serviceErrorResponse(logger, ctx, serviceErr) } @@ -146,11 +146,12 @@ func (c *Controllers) TwoFactorLoginAccount(ctx *fiber.Ctx) error { return validateBodyErrorResponse(logger, ctx, err) } - authDTO, serviceErr := c.services.TwoFactorLoginAccount(ctx.UserContext(), services.TwoFactorLoginAccountOptions{ - RequestID: requestID, - PublicID: accountClaims.AccountID, - Version: accountClaims.AccountVersion, - Code: body.Code, + authDTO, serviceErr := c.services.VerifyAccount2FA(ctx.UserContext(), services.VerifyAccount2FAOptions{ + RequestID: requestID, + AccountPublicID: accountClaims.AccountID, + AccountVersion: accountClaims.AccountVersion, + TwoFAType: twoFAType, + Code: body.Code, }) if serviceErr != nil { return serviceErrorResponse(logger, ctx, serviceErr) @@ -366,73 +367,3 @@ func (c *Controllers) GetAccountAuthProvider(ctx *fiber.Ctx) error { logResponse(logger, ctx, fiber.StatusOK) return ctx.Status(fiber.StatusOK).JSON(&authProviderDTO) } - -func (c *Controllers) UpdateAccount2FA(ctx *fiber.Ctx) error { - requestID := getRequestID(ctx) - logger := c.buildLogger(requestID, authLocation, "UpdateAccount2FA") - logRequest(logger, ctx) - - accountClaims, serviceErr := getAccountClaims(ctx) - if serviceErr != nil { - return serviceErrorResponse(logger, ctx, serviceErr) - } - - body := new(bodies.Update2FABody) - if err := ctx.BodyParser(body); err != nil { - return parseRequestErrorResponse(logger, ctx, err) - } - if err := c.validate.StructCtx(ctx.UserContext(), body); err != nil { - return validateBodyErrorResponse(logger, ctx, err) - } - - authDTO, serviceErr := c.services.UpdateAccount2FA(ctx.UserContext(), services.UpdateAccount2FAOptions{ - RequestID: requestID, - PublicID: accountClaims.AccountID, - Version: accountClaims.AccountVersion, - TwoFactorType: body.TwoFactorType, - Password: body.Password, - }) - if serviceErr != nil { - return serviceErrorResponse(logger, ctx, serviceErr) - } - - if authDTO.RefreshToken != "" { - c.saveAccountRefreshCookie(ctx, authDTO.RefreshToken) - } - - logResponse(logger, ctx, fiber.StatusOK) - return ctx.Status(fiber.StatusOK).JSON(&authDTO) -} - -func (c *Controllers) ConfirmUpdateAccount2FA(ctx *fiber.Ctx) error { - requestID := getRequestID(ctx) - logger := c.buildLogger(requestID, authLocation, "ConfirmUpdateAccount2FAUpdate") - logRequest(logger, ctx) - - accountClaims, serviceErr := getAccountClaims(ctx) - if serviceErr != nil { - return serviceErrorResponse(logger, ctx, serviceErr) - } - - body := new(bodies.TwoFactorLoginBody) - if err := ctx.BodyParser(body); err != nil { - return parseRequestErrorResponse(logger, ctx, err) - } - if err := c.validate.StructCtx(ctx.UserContext(), body); err != nil { - return validateBodyErrorResponse(logger, ctx, err) - } - - authDTO, serviceErr := c.services.ConfirmUpdateAccount2FAUpdate(ctx.UserContext(), services.ConfirmUpdateAccount2FAUpdateOptions{ - RequestID: requestID, - PublicID: accountClaims.AccountID, - Version: accountClaims.AccountVersion, - Code: body.Code, - }) - if serviceErr != nil { - return serviceErrorResponse(logger, ctx, serviceErr) - } - - logResponse(logger, ctx, fiber.StatusOK) - c.saveAccountRefreshCookie(ctx, authDTO.RefreshToken) - return ctx.Status(fiber.StatusOK).JSON(&authDTO) -} diff --git a/idp/internal/controllers/bodies/account_2fa_configs.go b/idp/internal/controllers/bodies/account_2fa_configs.go new file mode 100644 index 0000000..57b2540 --- /dev/null +++ b/idp/internal/controllers/bodies/account_2fa_configs.go @@ -0,0 +1,12 @@ +// Copyright (c) 2025 Afonso Barracha +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +package bodies + +type Account2FAConfigBody struct { + TwoFAType string `json:"two_factor_type" validate:"required,oneof=email totp"` + IsDefault bool `json:"is_default" validate:"required,boolean"` +} diff --git a/idp/internal/controllers/middleware.go b/idp/internal/controllers/middleware.go index c2d1b90..7840ec5 100644 --- a/idp/internal/controllers/middleware.go +++ b/idp/internal/controllers/middleware.go @@ -111,16 +111,17 @@ func (c *Controllers) AccountAccessClaimsMiddleware(ctx *fiber.Ctx) error { } func (c *Controllers) TwoFAAccessClaimsMiddleware(ctx *fiber.Ctx) error { + requestID := getRequestID(ctx) logger := c.buildLogger(getRequestID(ctx), middlewareLocation, "TwoFAAccessClaimsMiddleware") authHeader := ctx.Get("Authorization") if authHeader == "" { return serviceErrorResponse(logger, ctx, exceptions.NewUnauthorizedError()) } - accountClaims, serviceErr := c.services.Process2FAAuthHeader( + accountClaims, twoFAType, serviceErr := c.services.Process2FAAuthHeader( ctx.UserContext(), services.ProcessAuthHeaderOptions{ - RequestID: getRequestID(ctx), + RequestID: requestID, AuthHeader: authHeader, }, ) @@ -129,6 +130,7 @@ func (c *Controllers) TwoFAAccessClaimsMiddleware(ctx *fiber.Ctx) error { } ctx.Locals("account", accountClaims) + ctx.Locals("twoFAType", twoFAType) return ctx.Next() } @@ -250,7 +252,6 @@ func (c *Controllers) AccountHostMiddleware(ctx *fiber.Ctx) error { func getAccountClaims(ctx *fiber.Ctx) (tokens.AccountClaims, *exceptions.ServiceError) { account, ok := ctx.Locals("account").(tokens.AccountClaims) - if !ok || account.AccountID == uuid.Nil { return tokens.AccountClaims{}, exceptions.NewUnauthorizedError() } @@ -258,6 +259,20 @@ func getAccountClaims(ctx *fiber.Ctx) (tokens.AccountClaims, *exceptions.Service return account, nil } +func getAccounts2FAClaims(ctx *fiber.Ctx) (tokens.AccountClaims, tokens.TwoFAType, *exceptions.ServiceError) { + account, ok := ctx.Locals("account").(tokens.AccountClaims) + if !ok || account.AccountID == uuid.Nil { + return tokens.AccountClaims{}, "", exceptions.NewUnauthorizedError() + } + + twoFAType, ok := ctx.Locals("twoFAType").(tokens.TwoFAType) + if !ok || twoFAType == "" { + return tokens.AccountClaims{}, "", exceptions.NewUnauthorizedError() + } + + return account, twoFAType, nil +} + func getScopes(ctx *fiber.Ctx) ([]tokens.AccountScope, *exceptions.ServiceError) { scopes, ok := ctx.Locals("scopes").([]tokens.AccountScope) if !ok || scopes == nil { @@ -295,20 +310,6 @@ func getUserAccessClaims(ctx *fiber.Ctx) (tokens.UserAuthClaims, tokens.AppClaim return user, app, scopes, nil } -func getUserPurposeClaims(ctx *fiber.Ctx) (tokens.UserPurposeClaims, tokens.AppClaims, *exceptions.ServiceError) { - user, ok := ctx.Locals("user").(tokens.UserPurposeClaims) - if !ok || user.UserID == uuid.Nil { - return tokens.UserPurposeClaims{}, tokens.AppClaims{}, exceptions.NewUnauthorizedError() - } - - app, ok := ctx.Locals("app").(tokens.AppClaims) - if !ok || app.ClientID == "" { - return tokens.UserPurposeClaims{}, tokens.AppClaims{}, exceptions.NewUnauthorizedError() - } - - return user, app, nil -} - func getHostAccount(ctx *fiber.Ctx) (string, int32, *exceptions.ServiceError) { accountUsername, ok := ctx.Locals("accountUsername").(string) if !ok || accountUsername == "" { diff --git a/idp/internal/controllers/params/account_2fa_configs.go b/idp/internal/controllers/params/account_2fa_configs.go new file mode 100644 index 0000000..a4c8bb3 --- /dev/null +++ b/idp/internal/controllers/params/account_2fa_configs.go @@ -0,0 +1,11 @@ +// Copyright (c) 2025 Afonso Barracha +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +package params + +type GetAccount2FAConfigURLParams struct { + TwoFAType string `validate:"required,oneof=email totp"` +} diff --git a/idp/internal/controllers/paths/two_fa.go b/idp/internal/controllers/paths/two_fa.go new file mode 100644 index 0000000..1ee357b --- /dev/null +++ b/idp/internal/controllers/paths/two_fa.go @@ -0,0 +1,12 @@ +// Copyright (c) 2025 Afonso Barracha +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +package paths + +const ( + TwoFASingle string = "/:twoFAType" + TwoFADefault string = "/default" +) diff --git a/idp/internal/controllers/users_auth.go b/idp/internal/controllers/users_auth.go index d160bdc..0d8a4c6 100644 --- a/idp/internal/controllers/users_auth.go +++ b/idp/internal/controllers/users_auth.go @@ -152,46 +152,7 @@ func (c *Controllers) LoginUser(ctx *fiber.Ctx) error { return ctx.Status(fiber.StatusOK).JSON(&authDTO) } -func (c *Controllers) TwoFactorLoginUser(ctx *fiber.Ctx) error { - requestID := getRequestID(ctx) - logger := c.buildLogger(requestID, usersAuthLocation, "TwoFactorLoginUser") - logRequest(logger, ctx) - - accountUsername, accountID, serviceErr := getHostAccount(ctx) - if serviceErr != nil { - return serviceErrorResponse(logger, ctx, serviceErr) - } - - userClaims, appClaims, serviceErr := getUserPurposeClaims(ctx) - if serviceErr != nil { - return serviceErrorResponse(logger, ctx, serviceErr) - } - - body := new(bodies.TwoFactorLoginBody) - if err := ctx.BodyParser(body); err != nil { - return parseRequestErrorResponse(logger, ctx, err) - } - if err := c.validate.StructCtx(ctx.UserContext(), body); err != nil { - return validateBodyErrorResponse(logger, ctx, err) - } - - authDTO, serviceErr := c.services.TwoFactorLoginUser(ctx.UserContext(), services.TwoFactorLoginUserOptions{ - RequestID: requestID, - AccountID: accountID, - AccountUsername: accountUsername, - AppClientID: appClaims.ClientID, - AppVersion: appClaims.Version, - UserPublicID: userClaims.UserID, - UserVersion: userClaims.UserVersion, - Code: body.Code, - }) - if serviceErr != nil { - return serviceErrorResponse(logger, ctx, serviceErr) - } - - logResponse(logger, ctx, fiber.StatusOK) - return ctx.Status(fiber.StatusOK).JSON(&authDTO) -} +// TODO: Add 2FA Login func (c *Controllers) LogoutUser(ctx *fiber.Ctx) error { requestID := getRequestID(ctx) diff --git a/idp/internal/providers/cache/account_credentials_dynamic_registration.go b/idp/internal/providers/cache/account_credentials_dynamic_registration.go index 72a913b..f776a48 100644 --- a/idp/internal/providers/cache/account_credentials_dynamic_registration.go +++ b/idp/internal/providers/cache/account_credentials_dynamic_registration.go @@ -234,6 +234,7 @@ type AccountCredentialsDynamicRegistrationIAT2FAData struct { ClientID string `json:"clientId"` Domain string `json:"domain"` State string `json:"state"` + TwoFAType string `json:"two_factor_type"` } func buildAccountCredentialsDynamicRegistrationIAT2FACacheKey(sessionID string) string { @@ -248,6 +249,7 @@ type SaveAccountCredentialsDynamicRegistrationIAT2FAOptions struct { Domain string ClientID string State string + TwoFAType string TwoFATTL int64 } @@ -273,6 +275,7 @@ func (c *Cache) SaveAccountCredentialsDynamicRegistrationIAT2FA( Domain: opts.Domain, ClientID: opts.ClientID, State: opts.State, + TwoFAType: opts.TwoFAType, } dataBytes, err := json.Marshal(data) if err != nil { diff --git a/idp/internal/providers/cache/sensitive_requests.go b/idp/internal/providers/cache/sensitive_requests.go index beaa23c..de8f7db 100644 --- a/idp/internal/providers/cache/sensitive_requests.go +++ b/idp/internal/providers/cache/sensitive_requests.go @@ -13,7 +13,6 @@ import ( "github.com/google/uuid" - "github.com/tugascript/devlogs/idp/internal/providers/database" "github.com/tugascript/devlogs/idp/internal/utils" ) @@ -28,8 +27,8 @@ const ( emailUpdatePrefix string = "email_update" passwordUpdatePrefix string = "password_update" deleteAccountPrefix string = "delete_account" - twoFactorUpdatePrefix string = "two_factor_update" usernameUpdatePrefix string = "username_update" + twoFactorDeletePrefix string = "two_factor_delete" ) type SaveUpdateEmailRequestOptions struct { @@ -219,129 +218,162 @@ func (c *Cache) GetDeleteAccountRequest(ctx context.Context, opts GetDeleteAccou return true, nil } -type SaveTwoFactorUpdateRequestOptions struct { +type SaveUpdateUsernameRequestOptions struct { RequestID string PrefixType SensitiveRequestPrefixType PublicID uuid.UUID - TwoFactorType database.TwoFactorType + Username string DurationSeconds int64 } -func (c *Cache) SaveTwoFactorUpdateRequest(ctx context.Context, opts SaveTwoFactorUpdateRequestOptions) error { +func (c *Cache) SaveUpdateUsernameRequest(ctx context.Context, opts SaveUpdateUsernameRequestOptions) error { logger := utils.BuildLogger(c.logger, utils.LoggerOptions{ Location: sensitiveRequestsLocation, - Method: "SaveTwoFactorUpdateRequest", + Method: "SaveUpdateUsernameRequest", RequestID: opts.RequestID, }).With( "prefixType", opts.PrefixType, "publicID", opts.PublicID, - "twoFactorType", opts.TwoFactorType, + "username", opts.Username, ) - logger.DebugContext(ctx, "Saving two-factor update request...") + logger.DebugContext(ctx, "Saving update username request...") - key := fmt.Sprintf("%s:%s:%s", twoFactorUpdatePrefix, opts.PrefixType, opts.PublicID.String()) - val := []byte(opts.TwoFactorType) + key := fmt.Sprintf("%s:%s:%s", usernameUpdatePrefix, opts.PrefixType, opts.PublicID.String()) + val := []byte(opts.Username) exp := time.Duration(opts.DurationSeconds) * time.Second if err := c.storage.SetWithContext(ctx, key, val, exp); err != nil { - logger.ErrorContext(ctx, "Error caching two-factor update request", "error", err) + logger.ErrorContext(ctx, "Error caching update username request", "error", err) return err } return nil } -type GetTwoFactorUpdateRequestOptions struct { +type GetUpdateUsernameRequestOptions struct { RequestID string PrefixType SensitiveRequestPrefixType PublicID uuid.UUID } -func (c *Cache) GetTwoFactorUpdateRequest( - ctx context.Context, - opts GetTwoFactorUpdateRequestOptions, -) (database.TwoFactorType, error) { +func (c *Cache) GetUpdateUsernameRequest(ctx context.Context, opts GetUpdateUsernameRequestOptions) (string, error) { logger := utils.BuildLogger(c.logger, utils.LoggerOptions{ Location: sensitiveRequestsLocation, - Method: "GetTwoFactorUpdateRequest", + Method: "GetUpdateUsernameRequest", RequestID: opts.RequestID, }).With( "prefixType", opts.PrefixType, "publicID", opts.PublicID, ) - logger.DebugContext(ctx, "Getting two-factor update request...") + logger.DebugContext(ctx, "Getting update username request...") - key := fmt.Sprintf("%s:%s:%s", twoFactorUpdatePrefix, opts.PrefixType, opts.PublicID.String()) + key := fmt.Sprintf("%s:%s:%s", usernameUpdatePrefix, opts.PrefixType, opts.PublicID.String()) val, err := c.storage.GetWithContext(ctx, key) if err != nil { - logger.ErrorContext(ctx, "Error getting the two-factor update request", "error", err) + logger.ErrorContext(ctx, "Error getting the update username request", "error", err) return "", err } if val == nil { - logger.DebugContext(ctx, "Two-factor update request not found") + logger.DebugContext(ctx, "Update username request not found") return "", nil } - return database.TwoFactorType(val), nil + return string(val), nil } -type SaveUpdateUsernameRequestOptions struct { - RequestID string - PrefixType SensitiveRequestPrefixType - PublicID uuid.UUID - Username string - DurationSeconds int64 +type SaveDelete2FAConfigRequestOptions struct { + RequestID string + PrefixType SensitiveRequestPrefixType + PublicID uuid.UUID + TwoFAType string + TTL int64 } -func (c *Cache) SaveUpdateUsernameRequest(ctx context.Context, opts SaveUpdateUsernameRequestOptions) error { +func (c *Cache) SaveDelete2FAConfigRequest( + ctx context.Context, + opts SaveDelete2FAConfigRequestOptions, +) (string, error) { logger := utils.BuildLogger(c.logger, utils.LoggerOptions{ Location: sensitiveRequestsLocation, - Method: "SaveUpdateUsernameRequest", + Method: "SaveDelete2FAConfigRequest", RequestID: opts.RequestID, }).With( "prefixType", opts.PrefixType, "publicID", opts.PublicID, - "username", opts.Username, + "twoFAType", opts.TwoFAType, ) - logger.DebugContext(ctx, "Saving update username request...") + logger.DebugContext(ctx, "Saving delete 2FA config request...") + + var code, hashedCode string + if opts.TwoFAType == "email" { + var err error + code, err = generate2FACode() + if err != nil { + logger.ErrorContext(ctx, "Error generating 2FA code", "error", err) + return "", err + } + + hashedCode = utils.Sha256HashHex(code) + } - key := fmt.Sprintf("%s:%s:%s", usernameUpdatePrefix, opts.PrefixType, opts.PublicID.String()) - val := []byte(opts.Username) - exp := time.Duration(opts.DurationSeconds) * time.Second - if err := c.storage.SetWithContext(ctx, key, val, exp); err != nil { - logger.ErrorContext(ctx, "Error caching update username request", "error", err) - return err + key := fmt.Sprintf("%s:%s:%s:%s", twoFactorDeletePrefix, opts.PrefixType, opts.PublicID.String(), opts.TwoFAType) + if err := c.storage.SetWithContext(ctx, key, []byte(hashedCode), time.Duration(opts.TTL)*time.Second); err != nil { + logger.ErrorContext(ctx, "Error caching delete 2FA config request", "error", err) + return "", err } - return nil + return code, nil } -type GetUpdateUsernameRequestOptions struct { +type VerifyDelete2FAConfigRequestOptions struct { RequestID string PrefixType SensitiveRequestPrefixType PublicID uuid.UUID + TwoFAType string + Code string } -func (c *Cache) GetUpdateUsernameRequest(ctx context.Context, opts GetUpdateUsernameRequestOptions) (string, error) { +func (c *Cache) VerifyDelete2FAConfigRequest( + ctx context.Context, + opts VerifyDelete2FAConfigRequestOptions, +) (bool, error) { logger := utils.BuildLogger(c.logger, utils.LoggerOptions{ Location: sensitiveRequestsLocation, - Method: "GetUpdateUsernameRequest", + Method: "VerifyDelete2FAConfigRequest", RequestID: opts.RequestID, }).With( "prefixType", opts.PrefixType, "publicID", opts.PublicID, + "twoFAType", opts.TwoFAType, ) - logger.DebugContext(ctx, "Getting update username request...") + logger.DebugContext(ctx, "Verifying delete 2FA config request...") - key := fmt.Sprintf("%s:%s:%s", usernameUpdatePrefix, opts.PrefixType, opts.PublicID.String()) + key := fmt.Sprintf("%s:%s:%s:%s", twoFactorDeletePrefix, opts.PrefixType, opts.PublicID.String(), opts.TwoFAType) val, err := c.storage.GetWithContext(ctx, key) if err != nil { - logger.ErrorContext(ctx, "Error getting the update username request", "error", err) - return "", err + logger.ErrorContext(ctx, "Error getting the delete 2FA config request", "error", err) + return false, err } if val == nil { - logger.DebugContext(ctx, "Update username request not found") - return "", nil + logger.DebugContext(ctx, "Delete 2FA config request not found") + return false, nil } - return string(val), nil + if opts.TwoFAType == "email" { + ok, err := utils.CompareShaHex(opts.Code, string(val)) + if err != nil { + logger.ErrorContext(ctx, "Error comparing delete 2FA config request", "error", err) + return false, err + } + if !ok { + logger.DebugContext(ctx, "Delete 2FA config request does not match") + return false, nil + } + } + + if err := c.storage.DeleteWithContext(ctx, key); err != nil { + logger.ErrorContext(ctx, "Error deleting delete 2FA config request", "error", err) + return true, err + } + + return true, nil } diff --git a/idp/internal/providers/database/account_2fa_configs.sql.go b/idp/internal/providers/database/account_2fa_configs.sql.go new file mode 100644 index 0000000..9d84eae --- /dev/null +++ b/idp/internal/providers/database/account_2fa_configs.sql.go @@ -0,0 +1,196 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.29.0 +// source: account_2fa_configs.sql + +package database + +import ( + "context" + + "github.com/google/uuid" +) + +const countAccount2FAConfigsByAccountID = `-- name: CountAccount2FAConfigsByAccountID :one +SELECT COUNT(*) FROM "account_2fa_configs" +WHERE "account_id" = $1 +LIMIT 1 +` + +func (q *Queries) CountAccount2FAConfigsByAccountID(ctx context.Context, accountID int32) (int64, error) { + row := q.db.QueryRow(ctx, countAccount2FAConfigsByAccountID, accountID) + var count int64 + err := row.Scan(&count) + return count, err +} + +const createAccount2FAConfig = `-- name: CreateAccount2FAConfig :one + +INSERT INTO "account_2fa_configs" ( + "account_id", + "account_public_id", + "two_factor_type", + "is_default" +) VALUES ( + $1, + $2, + $3, + $4 +) RETURNING id, account_id, account_public_id, two_factor_type, is_default, is_active, created_at, updated_at +` + +type CreateAccount2FAConfigParams struct { + AccountID int32 + AccountPublicID uuid.UUID + TwoFactorType TwoFactorType + IsDefault bool +} + +// Copyright (c) 2025 Afonso Barracha +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. +func (q *Queries) CreateAccount2FAConfig(ctx context.Context, arg CreateAccount2FAConfigParams) (Account2faConfig, error) { + row := q.db.QueryRow(ctx, createAccount2FAConfig, + arg.AccountID, + arg.AccountPublicID, + arg.TwoFactorType, + arg.IsDefault, + ) + var i Account2faConfig + err := row.Scan( + &i.ID, + &i.AccountID, + &i.AccountPublicID, + &i.TwoFactorType, + &i.IsDefault, + &i.IsActive, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + +const deleteAccount2FAConfig = `-- name: DeleteAccount2FAConfig :exec +DELETE FROM "account_2fa_configs" +WHERE "id" = $1 +` + +func (q *Queries) DeleteAccount2FAConfig(ctx context.Context, id int32) error { + _, err := q.db.Exec(ctx, deleteAccount2FAConfig, id) + return err +} + +const findAccount2FAConfigByAccountPublicIDAndType = `-- name: FindAccount2FAConfigByAccountPublicIDAndType :one +SELECT id, account_id, account_public_id, two_factor_type, is_default, is_active, created_at, updated_at FROM "account_2fa_configs" +WHERE "account_public_id" = $1 AND "two_factor_type" = $2 +LIMIT 1 +` + +type FindAccount2FAConfigByAccountPublicIDAndTypeParams struct { + AccountPublicID uuid.UUID + TwoFactorType TwoFactorType +} + +func (q *Queries) FindAccount2FAConfigByAccountPublicIDAndType(ctx context.Context, arg FindAccount2FAConfigByAccountPublicIDAndTypeParams) (Account2faConfig, error) { + row := q.db.QueryRow(ctx, findAccount2FAConfigByAccountPublicIDAndType, arg.AccountPublicID, arg.TwoFactorType) + var i Account2faConfig + err := row.Scan( + &i.ID, + &i.AccountID, + &i.AccountPublicID, + &i.TwoFactorType, + &i.IsDefault, + &i.IsActive, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + +const findAccount2FAConfigsByAccountPublicID = `-- name: FindAccount2FAConfigsByAccountPublicID :many +SELECT id, account_id, account_public_id, two_factor_type, is_default, is_active, created_at, updated_at FROM "account_2fa_configs" +WHERE "account_public_id" = $1 +ORDER BY "id" DESC +` + +func (q *Queries) FindAccount2FAConfigsByAccountPublicID(ctx context.Context, accountPublicID uuid.UUID) ([]Account2faConfig, error) { + rows, err := q.db.Query(ctx, findAccount2FAConfigsByAccountPublicID, accountPublicID) + if err != nil { + return nil, err + } + defer rows.Close() + items := []Account2faConfig{} + for rows.Next() { + var i Account2faConfig + if err := rows.Scan( + &i.ID, + &i.AccountID, + &i.AccountPublicID, + &i.TwoFactorType, + &i.IsDefault, + &i.IsActive, + &i.CreatedAt, + &i.UpdatedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const findDefaultAccount2FAConfigByAccountPublicID = `-- name: FindDefaultAccount2FAConfigByAccountPublicID :one +SELECT id, account_id, account_public_id, two_factor_type, is_default, is_active, created_at, updated_at FROM "account_2fa_configs" +WHERE "account_public_id" = $1 AND "is_default" = true +LIMIT 1 +` + +func (q *Queries) FindDefaultAccount2FAConfigByAccountPublicID(ctx context.Context, accountPublicID uuid.UUID) (Account2faConfig, error) { + row := q.db.QueryRow(ctx, findDefaultAccount2FAConfigByAccountPublicID, accountPublicID) + var i Account2faConfig + err := row.Scan( + &i.ID, + &i.AccountID, + &i.AccountPublicID, + &i.TwoFactorType, + &i.IsDefault, + &i.IsActive, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + +const updateAccount2FAConfig = `-- name: UpdateAccount2FAConfig :one +UPDATE "account_2fa_configs" SET + "is_default" = $2, + "updated_at" = now() +WHERE "id" = $1 +RETURNING id, account_id, account_public_id, two_factor_type, is_default, is_active, created_at, updated_at +` + +type UpdateAccount2FAConfigParams struct { + ID int32 + IsDefault bool +} + +func (q *Queries) UpdateAccount2FAConfig(ctx context.Context, arg UpdateAccount2FAConfigParams) (Account2faConfig, error) { + row := q.db.QueryRow(ctx, updateAccount2FAConfig, arg.ID, arg.IsDefault) + var i Account2faConfig + err := row.Scan( + &i.ID, + &i.AccountID, + &i.AccountPublicID, + &i.TwoFactorType, + &i.IsDefault, + &i.IsActive, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} diff --git a/idp/internal/providers/database/account_credential_secrets.sql.go b/idp/internal/providers/database/account_credential_secrets.sql.go index be6eae8..2ea3eca 100644 --- a/idp/internal/providers/database/account_credential_secrets.sql.go +++ b/idp/internal/providers/database/account_credential_secrets.sql.go @@ -100,7 +100,7 @@ func (q *Queries) FindAccountCredentialSecretByAccountCredentialIDAndCredentials } const findAccountCredentialsSecretAccountByAccountCredentialIDAndSecretID = `-- name: FindAccountCredentialsSecretAccountByAccountCredentialIDAndSecretID :one -SELECT a.id, a.public_id, a.given_name, a.family_name, a.username, a.email, a.organization, a.password, a.version, a.email_verified, a.is_active, a.two_factor_type, a.created_at, a.updated_at FROM "accounts" AS "a" +SELECT a.id, a.public_id, a.given_name, a.family_name, a.username, a.email, a.organization, a.password, a.version, a.email_verified, a.activity_status, a.created_at, a.updated_at FROM "accounts" AS "a" LEFT JOIN "account_credentials_secrets" AS "acs" ON "acs"."account_id" = "a"."id" WHERE "acs"."account_credentials_id" = $1 AND @@ -127,8 +127,7 @@ func (q *Queries) FindAccountCredentialsSecretAccountByAccountCredentialIDAndSec &i.Password, &i.Version, &i.EmailVerified, - &i.IsActive, - &i.TwoFactorType, + &i.ActivityStatus, &i.CreatedAt, &i.UpdatedAt, ) diff --git a/idp/internal/providers/database/account_credentials_keys.sql.go b/idp/internal/providers/database/account_credentials_keys.sql.go index aa5dc79..26f8441 100644 --- a/idp/internal/providers/database/account_credentials_keys.sql.go +++ b/idp/internal/providers/database/account_credentials_keys.sql.go @@ -99,7 +99,7 @@ func (q *Queries) FindAccountCredentialKeyByAccountCredentialIDAndPublicKID(ctx } const findAccountCredentialsKeyAccountByAccountCredentialIDAndJWKKID = `-- name: FindAccountCredentialsKeyAccountByAccountCredentialIDAndJWKKID :one -SELECT a.id, a.public_id, a.given_name, a.family_name, a.username, a.email, a.organization, a.password, a.version, a.email_verified, a.is_active, a.two_factor_type, a.created_at, a.updated_at FROM "accounts" AS "a" +SELECT a.id, a.public_id, a.given_name, a.family_name, a.username, a.email, a.organization, a.password, a.version, a.email_verified, a.activity_status, a.created_at, a.updated_at FROM "accounts" AS "a" LEFT JOIN "account_credentials_keys" AS "ack" ON "ack"."account_id" = "a"."id" WHERE "ack"."account_credentials_id" = $1 AND @@ -126,8 +126,7 @@ func (q *Queries) FindAccountCredentialsKeyAccountByAccountCredentialIDAndJWKKID &i.Password, &i.Version, &i.EmailVerified, - &i.IsActive, - &i.TwoFactorType, + &i.ActivityStatus, &i.CreatedAt, &i.UpdatedAt, ) diff --git a/idp/internal/providers/database/accounts.sql.go b/idp/internal/providers/database/accounts.sql.go index 83ed343..8bef59b 100644 --- a/idp/internal/providers/database/accounts.sql.go +++ b/idp/internal/providers/database/accounts.sql.go @@ -18,7 +18,7 @@ UPDATE "accounts" SET "version" = "version" + 1, "updated_at" = now() WHERE "id" = $1 -RETURNING id, public_id, given_name, family_name, username, email, organization, password, version, email_verified, is_active, two_factor_type, created_at, updated_at +RETURNING id, public_id, given_name, family_name, username, email, organization, password, version, email_verified, activity_status, created_at, updated_at ` func (q *Queries) ConfirmAccount(ctx context.Context, id int32) (Account, error) { @@ -35,8 +35,7 @@ func (q *Queries) ConfirmAccount(ctx context.Context, id int32) (Account, error) &i.Password, &i.Version, &i.EmailVerified, - &i.IsActive, - &i.TwoFactorType, + &i.ActivityStatus, &i.CreatedAt, &i.UpdatedAt, ) @@ -83,7 +82,7 @@ INSERT INTO "accounts" ( $4, $5, $6 -) RETURNING id, public_id, given_name, family_name, username, email, organization, password, version, email_verified, is_active, two_factor_type, created_at, updated_at +) RETURNING id, public_id, given_name, family_name, username, email, organization, password, version, email_verified, activity_status, created_at, updated_at ` type CreateAccountWithPasswordParams struct { @@ -121,8 +120,7 @@ func (q *Queries) CreateAccountWithPassword(ctx context.Context, arg CreateAccou &i.Password, &i.Version, &i.EmailVerified, - &i.IsActive, - &i.TwoFactorType, + &i.ActivityStatus, &i.CreatedAt, &i.UpdatedAt, ) @@ -146,7 +144,7 @@ INSERT INTO "accounts" ( $5, 2, true -) RETURNING id, public_id, given_name, family_name, username, email, organization, password, version, email_verified, is_active, two_factor_type, created_at, updated_at +) RETURNING id, public_id, given_name, family_name, username, email, organization, password, version, email_verified, activity_status, created_at, updated_at ` type CreateAccountWithoutPasswordParams struct { @@ -177,8 +175,7 @@ func (q *Queries) CreateAccountWithoutPassword(ctx context.Context, arg CreateAc &i.Password, &i.Version, &i.EmailVerified, - &i.IsActive, - &i.TwoFactorType, + &i.ActivityStatus, &i.CreatedAt, &i.UpdatedAt, ) @@ -205,7 +202,7 @@ func (q *Queries) DeleteAllAccounts(ctx context.Context) error { } const findAccountByEmail = `-- name: FindAccountByEmail :one -SELECT id, public_id, given_name, family_name, username, email, organization, password, version, email_verified, is_active, two_factor_type, created_at, updated_at FROM "accounts" +SELECT id, public_id, given_name, family_name, username, email, organization, password, version, email_verified, activity_status, created_at, updated_at FROM "accounts" WHERE "email" = $1 LIMIT 1 ` @@ -223,8 +220,7 @@ func (q *Queries) FindAccountByEmail(ctx context.Context, email string) (Account &i.Password, &i.Version, &i.EmailVerified, - &i.IsActive, - &i.TwoFactorType, + &i.ActivityStatus, &i.CreatedAt, &i.UpdatedAt, ) @@ -232,7 +228,7 @@ func (q *Queries) FindAccountByEmail(ctx context.Context, email string) (Account } const findAccountById = `-- name: FindAccountById :one -SELECT id, public_id, given_name, family_name, username, email, organization, password, version, email_verified, is_active, two_factor_type, created_at, updated_at FROM "accounts" +SELECT id, public_id, given_name, family_name, username, email, organization, password, version, email_verified, activity_status, created_at, updated_at FROM "accounts" WHERE "id" = $1 LIMIT 1 ` @@ -250,8 +246,7 @@ func (q *Queries) FindAccountById(ctx context.Context, id int32) (Account, error &i.Password, &i.Version, &i.EmailVerified, - &i.IsActive, - &i.TwoFactorType, + &i.ActivityStatus, &i.CreatedAt, &i.UpdatedAt, ) @@ -259,7 +254,7 @@ func (q *Queries) FindAccountById(ctx context.Context, id int32) (Account, error } const findAccountByPublicID = `-- name: FindAccountByPublicID :one -SELECT id, public_id, given_name, family_name, username, email, organization, password, version, email_verified, is_active, two_factor_type, created_at, updated_at FROM "accounts" +SELECT id, public_id, given_name, family_name, username, email, organization, password, version, email_verified, activity_status, created_at, updated_at FROM "accounts" WHERE "public_id" = $1 LIMIT 1 ` @@ -277,8 +272,7 @@ func (q *Queries) FindAccountByPublicID(ctx context.Context, publicID uuid.UUID) &i.Password, &i.Version, &i.EmailVerified, - &i.IsActive, - &i.TwoFactorType, + &i.ActivityStatus, &i.CreatedAt, &i.UpdatedAt, ) @@ -286,7 +280,7 @@ func (q *Queries) FindAccountByPublicID(ctx context.Context, publicID uuid.UUID) } const findAccountByPublicIDAndVersion = `-- name: FindAccountByPublicIDAndVersion :one -SELECT id, public_id, given_name, family_name, username, email, organization, password, version, email_verified, is_active, two_factor_type, created_at, updated_at FROM "accounts" +SELECT id, public_id, given_name, family_name, username, email, organization, password, version, email_verified, activity_status, created_at, updated_at FROM "accounts" WHERE "public_id" = $1 AND "version" = $2 LIMIT 1 ` @@ -309,8 +303,7 @@ func (q *Queries) FindAccountByPublicIDAndVersion(ctx context.Context, arg FindA &i.Password, &i.Version, &i.EmailVerified, - &i.IsActive, - &i.TwoFactorType, + &i.ActivityStatus, &i.CreatedAt, &i.UpdatedAt, ) @@ -352,7 +345,7 @@ UPDATE "accounts" SET "family_name" = $2, "updated_at" = now() WHERE "id" = $3 -RETURNING id, public_id, given_name, family_name, username, email, organization, password, version, email_verified, is_active, two_factor_type, created_at, updated_at +RETURNING id, public_id, given_name, family_name, username, email, organization, password, version, email_verified, activity_status, created_at, updated_at ` type UpdateAccountParams struct { @@ -375,8 +368,7 @@ func (q *Queries) UpdateAccount(ctx context.Context, arg UpdateAccountParams) (A &i.Password, &i.Version, &i.EmailVerified, - &i.IsActive, - &i.TwoFactorType, + &i.ActivityStatus, &i.CreatedAt, &i.UpdatedAt, ) @@ -389,7 +381,7 @@ UPDATE "accounts" SET "version" = "version" + 1, "updated_at" = now() WHERE "id" = $2 -RETURNING id, public_id, given_name, family_name, username, email, organization, password, version, email_verified, is_active, two_factor_type, created_at, updated_at +RETURNING id, public_id, given_name, family_name, username, email, organization, password, version, email_verified, activity_status, created_at, updated_at ` type UpdateAccountEmailParams struct { @@ -411,8 +403,7 @@ func (q *Queries) UpdateAccountEmail(ctx context.Context, arg UpdateAccountEmail &i.Password, &i.Version, &i.EmailVerified, - &i.IsActive, - &i.TwoFactorType, + &i.ActivityStatus, &i.CreatedAt, &i.UpdatedAt, ) @@ -425,7 +416,7 @@ UPDATE "accounts" SET "version" = "version" + 1, "updated_at" = now() WHERE "id" = $2 -RETURNING id, public_id, given_name, family_name, username, email, organization, password, version, email_verified, is_active, two_factor_type, created_at, updated_at +RETURNING id, public_id, given_name, family_name, username, email, organization, password, version, email_verified, activity_status, created_at, updated_at ` type UpdateAccountPasswordParams struct { @@ -447,39 +438,20 @@ func (q *Queries) UpdateAccountPassword(ctx context.Context, arg UpdateAccountPa &i.Password, &i.Version, &i.EmailVerified, - &i.IsActive, - &i.TwoFactorType, + &i.ActivityStatus, &i.CreatedAt, &i.UpdatedAt, ) return i, err } -const updateAccountTwoFactorType = `-- name: UpdateAccountTwoFactorType :exec -UPDATE "accounts" SET - "two_factor_type" = $1, - "version" = "version" + 1, - "updated_at" = now() -WHERE "id" = $2 -` - -type UpdateAccountTwoFactorTypeParams struct { - TwoFactorType TwoFactorType - ID int32 -} - -func (q *Queries) UpdateAccountTwoFactorType(ctx context.Context, arg UpdateAccountTwoFactorTypeParams) error { - _, err := q.db.Exec(ctx, updateAccountTwoFactorType, arg.TwoFactorType, arg.ID) - return err -} - const updateAccountUsername = `-- name: UpdateAccountUsername :one UPDATE "accounts" SET "username" = $1, "version" = "version" + 1, "updated_at" = now() WHERE "id" = $2 -RETURNING id, public_id, given_name, family_name, username, email, organization, password, version, email_verified, is_active, two_factor_type, created_at, updated_at +RETURNING id, public_id, given_name, family_name, username, email, organization, password, version, email_verified, activity_status, created_at, updated_at ` type UpdateAccountUsernameParams struct { @@ -501,8 +473,36 @@ func (q *Queries) UpdateAccountUsername(ctx context.Context, arg UpdateAccountUs &i.Password, &i.Version, &i.EmailVerified, - &i.IsActive, - &i.TwoFactorType, + &i.ActivityStatus, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + +const updateAccountVersion = `-- name: UpdateAccountVersion :one +UPDATE "accounts" SET + "version" = "version" + 1, + "updated_at" = now() +WHERE "id" = $1 +RETURNING id, public_id, given_name, family_name, username, email, organization, password, version, email_verified, activity_status, created_at, updated_at +` + +func (q *Queries) UpdateAccountVersion(ctx context.Context, id int32) (Account, error) { + row := q.db.QueryRow(ctx, updateAccountVersion, id) + var i Account + err := row.Scan( + &i.ID, + &i.PublicID, + &i.GivenName, + &i.FamilyName, + &i.Username, + &i.Email, + &i.Organization, + &i.Password, + &i.Version, + &i.EmailVerified, + &i.ActivityStatus, &i.CreatedAt, &i.UpdatedAt, ) diff --git a/idp/internal/providers/database/migrations/20241213231542_create_initial_schema.up.sql b/idp/internal/providers/database/migrations/20241213231542_create_initial_schema.up.sql index e9f1420..c5150f3 100644 --- a/idp/internal/providers/database/migrations/20241213231542_create_initial_schema.up.sql +++ b/idp/internal/providers/database/migrations/20241213231542_create_initial_schema.up.sql @@ -1,6 +1,6 @@ -- SQL dump generated using DBML (dbml.dbdiagram.io) -- Database: PostgreSQL --- Generated at: 2025-08-18T07:42:22.764Z +-- Generated at: 2025-09-06T04:54:24.517Z CREATE TYPE "kek_usage" AS ENUM ( 'global', @@ -35,8 +35,13 @@ CREATE TYPE "token_key_type" AS ENUM ( 'dynamic_registration' ); +CREATE TYPE "activity_status" AS ENUM ( + 'active', + 'suspended', + 'blocked' +); + CREATE TYPE "two_factor_type" AS ENUM ( - 'none', 'totp', 'email' ); @@ -177,7 +182,8 @@ CREATE TYPE "initial_access_token_generation_method" AS ENUM ( CREATE TYPE "software_statement_verification_method" AS ENUM ( 'manual', - 'jwks_uri' + 'jwks_uri', + 'jwk_x5_parameters' ); CREATE TYPE "domain_verification_method" AS ENUM ( @@ -247,8 +253,18 @@ CREATE TABLE "accounts" ( "password" text, "version" integer NOT NULL DEFAULT 1, "email_verified" boolean NOT NULL DEFAULT false, - "is_active" boolean NOT NULL DEFAULT true, - "two_factor_type" two_factor_type NOT NULL DEFAULT 'none', + "activity_status" activity_status NOT NULL DEFAULT 'active', + "created_at" timestamptz NOT NULL DEFAULT (now()), + "updated_at" timestamptz NOT NULL DEFAULT (now()) +); + +CREATE TABLE "account_2fa_configs" ( + "id" serial PRIMARY KEY, + "account_id" integer NOT NULL, + "account_public_id" uuid NOT NULL, + "two_factor_type" two_factor_type NOT NULL, + "is_default" boolean NOT NULL DEFAULT false, + "is_active" boolean NOT NULL DEFAULT false, "created_at" timestamptz NOT NULL DEFAULT (now()), "updated_at" timestamptz NOT NULL DEFAULT (now()) ); @@ -406,13 +422,22 @@ CREATE TABLE "users" ( "password" text, "version" integer NOT NULL DEFAULT 1, "email_verified" boolean NOT NULL DEFAULT false, - "is_active" boolean NOT NULL DEFAULT true, - "two_factor_type" two_factor_type NOT NULL DEFAULT 'none', + "activity_status" activity_status NOT NULL DEFAULT 'active', "user_data" jsonb NOT NULL DEFAULT '{}', "created_at" timestamptz NOT NULL DEFAULT (now()), "updated_at" timestamptz NOT NULL DEFAULT (now()) ); +CREATE TABLE "user_2fa_configs" ( + "id" serial PRIMARY KEY, + "account_id" integer NOT NULL, + "user_id" integer NOT NULL, + "two_factor_type" two_factor_type NOT NULL, + "is_default" boolean NOT NULL DEFAULT false, + "created_at" timestamptz NOT NULL DEFAULT (now()), + "updated_at" timestamptz NOT NULL DEFAULT (now()) +); + CREATE TABLE "user_data_encryption_keys" ( "user_id" integer NOT NULL, "data_encryption_key_id" integer NOT NULL, @@ -690,6 +715,14 @@ CREATE INDEX "accounts_public_id_version_idx" ON "accounts" ("public_id", "versi CREATE UNIQUE INDEX "accounts_username_uidx" ON "accounts" ("username"); +CREATE INDEX "account_2fa_configs_account_id_idx" ON "account_2fa_configs" ("account_id"); + +CREATE INDEX "account_2fa_configs_account_public_id_idx" ON "account_2fa_configs" ("account_public_id"); + +CREATE INDEX "account_2fa_configs_account_public_id_is_default_idx" ON "account_2fa_configs" ("account_public_id", "is_default"); + +CREATE INDEX "account_2fa_configs_account_public_id_two_factor_type_idx" ON "account_2fa_configs" ("account_public_id", "two_factor_type"); + CREATE INDEX "accounts_totps_dek_kid_idx" ON "totps" ("dek_kid"); CREATE INDEX "accounts_totps_account_id_idx" ON "totps" ("account_id"); @@ -798,6 +831,14 @@ CREATE UNIQUE INDEX "users_public_id_uidx" ON "users" ("public_id"); CREATE INDEX "users_public_id_version_idx" ON "users" ("public_id", "version"); +CREATE INDEX "user_2fa_configs_account_id_idx" ON "user_2fa_configs" ("account_id"); + +CREATE INDEX "user_2fa_configs_user_id_idx" ON "user_2fa_configs" ("user_id"); + +CREATE INDEX "user_2fa_configs_two_factor_type_idx" ON "user_2fa_configs" ("two_factor_type"); + +CREATE UNIQUE INDEX "user_2fa_configs_user_id_two_factor_type_uidx" ON "user_2fa_configs" ("user_id", "two_factor_type"); + CREATE INDEX "user_data_encryption_keys_user_id_idx" ON "user_data_encryption_keys" ("user_id"); CREATE UNIQUE INDEX "user_data_encryption_keys_data_encryption_key_id_uidx" ON "user_data_encryption_keys" ("data_encryption_key_id"); @@ -950,6 +991,8 @@ ALTER TABLE "data_encryption_keys" ADD FOREIGN KEY ("kek_kid") REFERENCES "key_e ALTER TABLE "token_signing_keys" ADD FOREIGN KEY ("dek_kid") REFERENCES "data_encryption_keys" ("kid") ON DELETE CASCADE ON UPDATE CASCADE; +ALTER TABLE "account_2fa_configs" ADD FOREIGN KEY ("account_id") REFERENCES "accounts" ("id") ON DELETE CASCADE; + ALTER TABLE "totps" ADD FOREIGN KEY ("dek_kid") REFERENCES "data_encryption_keys" ("kid") ON DELETE CASCADE ON UPDATE CASCADE; ALTER TABLE "totps" ADD FOREIGN KEY ("account_id") REFERENCES "accounts" ("id") ON DELETE CASCADE; @@ -998,6 +1041,10 @@ ALTER TABLE "account_token_signing_keys" ADD FOREIGN KEY ("token_signing_key_id" ALTER TABLE "users" ADD FOREIGN KEY ("account_id") REFERENCES "accounts" ("id") ON DELETE CASCADE; +ALTER TABLE "user_2fa_configs" ADD FOREIGN KEY ("account_id") REFERENCES "accounts" ("id") ON DELETE CASCADE; + +ALTER TABLE "user_2fa_configs" ADD FOREIGN KEY ("user_id") REFERENCES "users" ("id") ON DELETE CASCADE; + ALTER TABLE "user_data_encryption_keys" ADD FOREIGN KEY ("account_id") REFERENCES "accounts" ("id") ON DELETE CASCADE; ALTER TABLE "user_data_encryption_keys" ADD FOREIGN KEY ("user_id") REFERENCES "users" ("id") ON DELETE CASCADE; diff --git a/idp/internal/providers/database/models.go b/idp/internal/providers/database/models.go index 51ff16c..251a769 100644 --- a/idp/internal/providers/database/models.go +++ b/idp/internal/providers/database/models.go @@ -110,6 +110,49 @@ func (ns NullAccountCredentialsType) Value() (driver.Value, error) { return string(ns.AccountCredentialsType), nil } +type ActivityStatus string + +const ( + ActivityStatusActive ActivityStatus = "active" + ActivityStatusSuspended ActivityStatus = "suspended" + ActivityStatusBlocked ActivityStatus = "blocked" +) + +func (e *ActivityStatus) Scan(src interface{}) error { + switch s := src.(type) { + case []byte: + *e = ActivityStatus(s) + case string: + *e = ActivityStatus(s) + default: + return fmt.Errorf("unsupported scan type for ActivityStatus: %T", src) + } + return nil +} + +type NullActivityStatus struct { + ActivityStatus ActivityStatus + Valid bool // Valid is true if ActivityStatus is not NULL +} + +// Scan implements the Scanner interface. +func (ns *NullActivityStatus) Scan(value interface{}) error { + if value == nil { + ns.ActivityStatus, ns.Valid = "", false + return nil + } + ns.Valid = true + return ns.ActivityStatus.Scan(value) +} + +// Value implements the driver Valuer interface. +func (ns NullActivityStatus) Value() (driver.Value, error) { + if !ns.Valid { + return nil, nil + } + return string(ns.ActivityStatus), nil +} + type AppProfileType string const ( @@ -827,8 +870,9 @@ func (ns NullSecretStorageMode) Value() (driver.Value, error) { type SoftwareStatementVerificationMethod string const ( - SoftwareStatementVerificationMethodManual SoftwareStatementVerificationMethod = "manual" - SoftwareStatementVerificationMethodJwksUri SoftwareStatementVerificationMethod = "jwks_uri" + SoftwareStatementVerificationMethodManual SoftwareStatementVerificationMethod = "manual" + SoftwareStatementVerificationMethodJwksUri SoftwareStatementVerificationMethod = "jwks_uri" + SoftwareStatementVerificationMethodJwkX5Parameters SoftwareStatementVerificationMethod = "jwk_x5_parameters" ) func (e *SoftwareStatementVerificationMethod) Scan(src interface{}) error { @@ -1130,7 +1174,6 @@ func (ns NullTransport) Value() (driver.Value, error) { type TwoFactorType string const ( - TwoFactorTypeNone TwoFactorType = "none" TwoFactorTypeTotp TwoFactorType = "totp" TwoFactorTypeEmail TwoFactorType = "email" ) @@ -1171,20 +1214,30 @@ func (ns NullTwoFactorType) Value() (driver.Value, error) { } type Account struct { - ID int32 - PublicID uuid.UUID - GivenName string - FamilyName string - Username string - Email string - Organization pgtype.Text - Password pgtype.Text - Version int32 - EmailVerified bool - IsActive bool - TwoFactorType TwoFactorType - CreatedAt time.Time - UpdatedAt time.Time + ID int32 + PublicID uuid.UUID + GivenName string + FamilyName string + Username string + Email string + Organization pgtype.Text + Password pgtype.Text + Version int32 + EmailVerified bool + ActivityStatus ActivityStatus + CreatedAt time.Time + UpdatedAt time.Time +} + +type Account2faConfig struct { + ID int32 + AccountID int32 + AccountPublicID uuid.UUID + TwoFactorType TwoFactorType + IsDefault bool + IsActive bool + CreatedAt time.Time + UpdatedAt time.Time } type AccountAuthProvider struct { @@ -1542,17 +1595,26 @@ type Totp struct { } type User struct { + ID int32 + PublicID uuid.UUID + AccountID int32 + Email string + Username string + Password pgtype.Text + Version int32 + EmailVerified bool + ActivityStatus ActivityStatus + UserData []byte + CreatedAt time.Time + UpdatedAt time.Time +} + +type User2faConfig struct { ID int32 - PublicID uuid.UUID AccountID int32 - Email string - Username string - Password pgtype.Text - Version int32 - EmailVerified bool - IsActive bool + UserID int32 TwoFactorType TwoFactorType - UserData []byte + IsDefault bool CreatedAt time.Time UpdatedAt time.Time } diff --git a/idp/internal/providers/database/queries/account_2fa_configs.sql b/idp/internal/providers/database/queries/account_2fa_configs.sql new file mode 100644 index 0000000..b32432f --- /dev/null +++ b/idp/internal/providers/database/queries/account_2fa_configs.sql @@ -0,0 +1,49 @@ +-- Copyright (c) 2025 Afonso Barracha +-- +-- This Source Code Form is subject to the terms of the Mozilla Public +-- License, v. 2.0. If a copy of the MPL was not distributed with this +-- file, You can obtain one at https://mozilla.org/MPL/2.0/. + +-- name: CreateAccount2FAConfig :one +INSERT INTO "account_2fa_configs" ( + "account_id", + "account_public_id", + "two_factor_type", + "is_default" +) VALUES ( + $1, + $2, + $3, + $4 +) RETURNING *; + +-- name: FindDefaultAccount2FAConfigByAccountPublicID :one +SELECT * FROM "account_2fa_configs" +WHERE "account_public_id" = $1 AND "is_default" = true +LIMIT 1; + +-- name: FindAccount2FAConfigByAccountPublicIDAndType :one +SELECT * FROM "account_2fa_configs" +WHERE "account_public_id" = $1 AND "two_factor_type" = $2 +LIMIT 1; + +-- name: FindAccount2FAConfigsByAccountPublicID :many +SELECT * FROM "account_2fa_configs" +WHERE "account_public_id" = $1 +ORDER BY "id" DESC; + +-- name: UpdateAccount2FAConfig :one +UPDATE "account_2fa_configs" SET + "is_default" = $2, + "updated_at" = now() +WHERE "id" = $1 +RETURNING *; + +-- name: DeleteAccount2FAConfig :exec +DELETE FROM "account_2fa_configs" +WHERE "id" = $1; + +-- name: CountAccount2FAConfigsByAccountID :one +SELECT COUNT(*) FROM "account_2fa_configs" +WHERE "account_id" = $1 +LIMIT 1; \ No newline at end of file diff --git a/idp/internal/providers/database/queries/accounts.sql b/idp/internal/providers/database/queries/accounts.sql index 3c763a2..6bc22ef 100644 --- a/idp/internal/providers/database/queries/accounts.sql +++ b/idp/internal/providers/database/queries/accounts.sql @@ -64,6 +64,13 @@ UPDATE "accounts" SET WHERE "id" = $2 RETURNING *; +-- name: UpdateAccountVersion :one +UPDATE "accounts" SET + "version" = "version" + 1, + "updated_at" = now() +WHERE "id" = $1 +RETURNING *; + -- name: FindAccountByEmail :one SELECT * FROM "accounts" WHERE "email" = $1 LIMIT 1; @@ -92,13 +99,6 @@ UPDATE "accounts" SET WHERE "id" = $1 RETURNING *; --- name: UpdateAccountTwoFactorType :exec -UPDATE "accounts" SET - "two_factor_type" = $1, - "version" = "version" + 1, - "updated_at" = now() -WHERE "id" = $2; - -- name: DeleteAllAccounts :exec DELETE FROM "accounts"; diff --git a/idp/internal/providers/database/queries/users.sql b/idp/internal/providers/database/queries/users.sql index 57efc46..6aebab7 100644 --- a/idp/internal/providers/database/queries/users.sql +++ b/idp/internal/providers/database/queries/users.sql @@ -69,14 +69,14 @@ LIMIT 1; -- name: UpdateUser :one UPDATE "users" SET - "email" = $1, - "username" = $2, - "user_data" = $3, - "is_active" = $4, + "email" = $2, + "username" = $3, + "user_data" = $4, "email_verified" = $5, + "activity_status" = $6, "version" = "version" + 1, "updated_at" = now() -WHERE "id" = $6 +WHERE "id" = $1 RETURNING *; -- name: UpdateUserPassword :one diff --git a/idp/internal/providers/database/users.sql.go b/idp/internal/providers/database/users.sql.go index 2e2029c..6ecd08f 100644 --- a/idp/internal/providers/database/users.sql.go +++ b/idp/internal/providers/database/users.sql.go @@ -18,7 +18,7 @@ UPDATE "users" SET "version" = "version" + 1, "updated_at" = now() WHERE "id" = $1 -RETURNING id, public_id, account_id, email, username, password, version, email_verified, is_active, two_factor_type, user_data, created_at, updated_at +RETURNING id, public_id, account_id, email, username, password, version, email_verified, activity_status, user_data, created_at, updated_at ` func (q *Queries) ConfirmUser(ctx context.Context, id int32) (User, error) { @@ -33,8 +33,7 @@ func (q *Queries) ConfirmUser(ctx context.Context, id int32) (User, error) { &i.Password, &i.Version, &i.EmailVerified, - &i.IsActive, - &i.TwoFactorType, + &i.ActivityStatus, &i.UserData, &i.CreatedAt, &i.UpdatedAt, @@ -126,7 +125,7 @@ INSERT INTO "users" ( $4, $5, $6 -) RETURNING id, public_id, account_id, email, username, password, version, email_verified, is_active, two_factor_type, user_data, created_at, updated_at +) RETURNING id, public_id, account_id, email, username, password, version, email_verified, activity_status, user_data, created_at, updated_at ` type CreateUserWithPasswordParams struct { @@ -162,8 +161,7 @@ func (q *Queries) CreateUserWithPassword(ctx context.Context, arg CreateUserWith &i.Password, &i.Version, &i.EmailVerified, - &i.IsActive, - &i.TwoFactorType, + &i.ActivityStatus, &i.UserData, &i.CreatedAt, &i.UpdatedAt, @@ -184,7 +182,7 @@ INSERT INTO "users" ( $3, $4, $5 -) RETURNING id, public_id, account_id, email, username, password, version, email_verified, is_active, two_factor_type, user_data, created_at, updated_at +) RETURNING id, public_id, account_id, email, username, password, version, email_verified, activity_status, user_data, created_at, updated_at ` type CreateUserWithoutPasswordParams struct { @@ -213,8 +211,7 @@ func (q *Queries) CreateUserWithoutPassword(ctx context.Context, arg CreateUserW &i.Password, &i.Version, &i.EmailVerified, - &i.IsActive, - &i.TwoFactorType, + &i.ActivityStatus, &i.UserData, &i.CreatedAt, &i.UpdatedAt, @@ -233,7 +230,7 @@ func (q *Queries) DeleteUser(ctx context.Context, id int32) error { } const filterUsersByEmailOrUsernameAndByAccountIDOrderedByEmail = `-- name: FilterUsersByEmailOrUsernameAndByAccountIDOrderedByEmail :many -SELECT id, public_id, account_id, email, username, password, version, email_verified, is_active, two_factor_type, user_data, created_at, updated_at FROM "users" +SELECT id, public_id, account_id, email, username, password, version, email_verified, activity_status, user_data, created_at, updated_at FROM "users" WHERE "account_id" = $1 AND ("email" ILIKE $2 OR "username" ILIKE $3) ORDER BY "email" ASC OFFSET $4 LIMIT $5 @@ -271,8 +268,7 @@ func (q *Queries) FilterUsersByEmailOrUsernameAndByAccountIDOrderedByEmail(ctx c &i.Password, &i.Version, &i.EmailVerified, - &i.IsActive, - &i.TwoFactorType, + &i.ActivityStatus, &i.UserData, &i.CreatedAt, &i.UpdatedAt, @@ -288,7 +284,7 @@ func (q *Queries) FilterUsersByEmailOrUsernameAndByAccountIDOrderedByEmail(ctx c } const filterUsersByEmailOrUsernameAndByAccountIDOrderedByID = `-- name: FilterUsersByEmailOrUsernameAndByAccountIDOrderedByID :many -SELECT id, public_id, account_id, email, username, password, version, email_verified, is_active, two_factor_type, user_data, created_at, updated_at FROM "users" +SELECT id, public_id, account_id, email, username, password, version, email_verified, activity_status, user_data, created_at, updated_at FROM "users" WHERE "account_id" = $1 AND ("email" ILIKE $2 OR "username" ILIKE $3) ORDER BY "id" DESC OFFSET $4 LIMIT $5 @@ -326,8 +322,7 @@ func (q *Queries) FilterUsersByEmailOrUsernameAndByAccountIDOrderedByID(ctx cont &i.Password, &i.Version, &i.EmailVerified, - &i.IsActive, - &i.TwoFactorType, + &i.ActivityStatus, &i.UserData, &i.CreatedAt, &i.UpdatedAt, @@ -343,7 +338,7 @@ func (q *Queries) FilterUsersByEmailOrUsernameAndByAccountIDOrderedByID(ctx cont } const filterUsersByEmailOrUsernameAndByAccountIDOrderedByUsername = `-- name: FilterUsersByEmailOrUsernameAndByAccountIDOrderedByUsername :many -SELECT id, public_id, account_id, email, username, password, version, email_verified, is_active, two_factor_type, user_data, created_at, updated_at FROM "users" +SELECT id, public_id, account_id, email, username, password, version, email_verified, activity_status, user_data, created_at, updated_at FROM "users" WHERE "account_id" = $1 AND ("email" ILIKE $2 OR "username" ILIKE $3) ORDER BY "username" ASC OFFSET $4 LIMIT $5 @@ -381,8 +376,7 @@ func (q *Queries) FilterUsersByEmailOrUsernameAndByAccountIDOrderedByUsername(ct &i.Password, &i.Version, &i.EmailVerified, - &i.IsActive, - &i.TwoFactorType, + &i.ActivityStatus, &i.UserData, &i.CreatedAt, &i.UpdatedAt, @@ -398,7 +392,7 @@ func (q *Queries) FilterUsersByEmailOrUsernameAndByAccountIDOrderedByUsername(ct } const findPaginatedUsersByAccountIDOrderedByEmail = `-- name: FindPaginatedUsersByAccountIDOrderedByEmail :many -SELECT id, public_id, account_id, email, username, password, version, email_verified, is_active, two_factor_type, user_data, created_at, updated_at FROM "users" +SELECT id, public_id, account_id, email, username, password, version, email_verified, activity_status, user_data, created_at, updated_at FROM "users" WHERE "account_id" = $1 ORDER BY "email" ASC OFFSET $2 LIMIT $3 @@ -428,8 +422,7 @@ func (q *Queries) FindPaginatedUsersByAccountIDOrderedByEmail(ctx context.Contex &i.Password, &i.Version, &i.EmailVerified, - &i.IsActive, - &i.TwoFactorType, + &i.ActivityStatus, &i.UserData, &i.CreatedAt, &i.UpdatedAt, @@ -445,7 +438,7 @@ func (q *Queries) FindPaginatedUsersByAccountIDOrderedByEmail(ctx context.Contex } const findPaginatedUsersByAccountIDOrderedByID = `-- name: FindPaginatedUsersByAccountIDOrderedByID :many -SELECT id, public_id, account_id, email, username, password, version, email_verified, is_active, two_factor_type, user_data, created_at, updated_at FROM "users" +SELECT id, public_id, account_id, email, username, password, version, email_verified, activity_status, user_data, created_at, updated_at FROM "users" WHERE "account_id" = $1 ORDER BY "id" DESC OFFSET $2 LIMIT $3 @@ -475,8 +468,7 @@ func (q *Queries) FindPaginatedUsersByAccountIDOrderedByID(ctx context.Context, &i.Password, &i.Version, &i.EmailVerified, - &i.IsActive, - &i.TwoFactorType, + &i.ActivityStatus, &i.UserData, &i.CreatedAt, &i.UpdatedAt, @@ -492,7 +484,7 @@ func (q *Queries) FindPaginatedUsersByAccountIDOrderedByID(ctx context.Context, } const findPaginatedUsersByAccountIDOrderedByUsername = `-- name: FindPaginatedUsersByAccountIDOrderedByUsername :many -SELECT id, public_id, account_id, email, username, password, version, email_verified, is_active, two_factor_type, user_data, created_at, updated_at FROM "users" +SELECT id, public_id, account_id, email, username, password, version, email_verified, activity_status, user_data, created_at, updated_at FROM "users" WHERE "account_id" = $1 ORDER BY "username" ASC OFFSET $2 LIMIT $3 @@ -522,8 +514,7 @@ func (q *Queries) FindPaginatedUsersByAccountIDOrderedByUsername(ctx context.Con &i.Password, &i.Version, &i.EmailVerified, - &i.IsActive, - &i.TwoFactorType, + &i.ActivityStatus, &i.UserData, &i.CreatedAt, &i.UpdatedAt, @@ -539,7 +530,7 @@ func (q *Queries) FindPaginatedUsersByAccountIDOrderedByUsername(ctx context.Con } const findUserByEmailAndAccountID = `-- name: FindUserByEmailAndAccountID :one -SELECT id, public_id, account_id, email, username, password, version, email_verified, is_active, two_factor_type, user_data, created_at, updated_at FROM "users" +SELECT id, public_id, account_id, email, username, password, version, email_verified, activity_status, user_data, created_at, updated_at FROM "users" WHERE "email" = $1 AND "account_id" = $2 LIMIT 1 ` @@ -561,8 +552,7 @@ func (q *Queries) FindUserByEmailAndAccountID(ctx context.Context, arg FindUserB &i.Password, &i.Version, &i.EmailVerified, - &i.IsActive, - &i.TwoFactorType, + &i.ActivityStatus, &i.UserData, &i.CreatedAt, &i.UpdatedAt, @@ -571,7 +561,7 @@ func (q *Queries) FindUserByEmailAndAccountID(ctx context.Context, arg FindUserB } const findUserByID = `-- name: FindUserByID :one -SELECT id, public_id, account_id, email, username, password, version, email_verified, is_active, two_factor_type, user_data, created_at, updated_at FROM "users" +SELECT id, public_id, account_id, email, username, password, version, email_verified, activity_status, user_data, created_at, updated_at FROM "users" WHERE "id" = $1 LIMIT 1 ` @@ -587,8 +577,7 @@ func (q *Queries) FindUserByID(ctx context.Context, id int32) (User, error) { &i.Password, &i.Version, &i.EmailVerified, - &i.IsActive, - &i.TwoFactorType, + &i.ActivityStatus, &i.UserData, &i.CreatedAt, &i.UpdatedAt, @@ -597,7 +586,7 @@ func (q *Queries) FindUserByID(ctx context.Context, id int32) (User, error) { } const findUserByPublicIDAndVersion = `-- name: FindUserByPublicIDAndVersion :one -SELECT id, public_id, account_id, email, username, password, version, email_verified, is_active, two_factor_type, user_data, created_at, updated_at FROM "users" +SELECT id, public_id, account_id, email, username, password, version, email_verified, activity_status, user_data, created_at, updated_at FROM "users" WHERE "public_id" = $1 AND "version" = $2 LIMIT 1 ` @@ -618,8 +607,7 @@ func (q *Queries) FindUserByPublicIDAndVersion(ctx context.Context, arg FindUser &i.Password, &i.Version, &i.EmailVerified, - &i.IsActive, - &i.TwoFactorType, + &i.ActivityStatus, &i.UserData, &i.CreatedAt, &i.UpdatedAt, @@ -628,7 +616,7 @@ func (q *Queries) FindUserByPublicIDAndVersion(ctx context.Context, arg FindUser } const findUserByUsernameAndAccountID = `-- name: FindUserByUsernameAndAccountID :one -SELECT id, public_id, account_id, email, username, password, version, email_verified, is_active, two_factor_type, user_data, created_at, updated_at FROM "users" +SELECT id, public_id, account_id, email, username, password, version, email_verified, activity_status, user_data, created_at, updated_at FROM "users" WHERE "username" = $1 AND "account_id" = $2 LIMIT 1 ` @@ -650,8 +638,7 @@ func (q *Queries) FindUserByUsernameAndAccountID(ctx context.Context, arg FindUs &i.Password, &i.Version, &i.EmailVerified, - &i.IsActive, - &i.TwoFactorType, + &i.ActivityStatus, &i.UserData, &i.CreatedAt, &i.UpdatedAt, @@ -661,34 +648,34 @@ func (q *Queries) FindUserByUsernameAndAccountID(ctx context.Context, arg FindUs const updateUser = `-- name: UpdateUser :one UPDATE "users" SET - "email" = $1, - "username" = $2, - "user_data" = $3, - "is_active" = $4, + "email" = $2, + "username" = $3, + "user_data" = $4, "email_verified" = $5, + "activity_status" = $6, "version" = "version" + 1, "updated_at" = now() -WHERE "id" = $6 -RETURNING id, public_id, account_id, email, username, password, version, email_verified, is_active, two_factor_type, user_data, created_at, updated_at +WHERE "id" = $1 +RETURNING id, public_id, account_id, email, username, password, version, email_verified, activity_status, user_data, created_at, updated_at ` type UpdateUserParams struct { - Email string - Username string - UserData []byte - IsActive bool - EmailVerified bool - ID int32 + ID int32 + Email string + Username string + UserData []byte + EmailVerified bool + ActivityStatus ActivityStatus } func (q *Queries) UpdateUser(ctx context.Context, arg UpdateUserParams) (User, error) { row := q.db.QueryRow(ctx, updateUser, + arg.ID, arg.Email, arg.Username, arg.UserData, - arg.IsActive, arg.EmailVerified, - arg.ID, + arg.ActivityStatus, ) var i User err := row.Scan( @@ -700,8 +687,7 @@ func (q *Queries) UpdateUser(ctx context.Context, arg UpdateUserParams) (User, e &i.Password, &i.Version, &i.EmailVerified, - &i.IsActive, - &i.TwoFactorType, + &i.ActivityStatus, &i.UserData, &i.CreatedAt, &i.UpdatedAt, @@ -715,7 +701,7 @@ UPDATE "users" SET "version" = "version" + 1, "updated_at" = now() WHERE "id" = $2 -RETURNING id, public_id, account_id, email, username, password, version, email_verified, is_active, two_factor_type, user_data, created_at, updated_at +RETURNING id, public_id, account_id, email, username, password, version, email_verified, activity_status, user_data, created_at, updated_at ` type UpdateUserPasswordParams struct { @@ -735,8 +721,7 @@ func (q *Queries) UpdateUserPassword(ctx context.Context, arg UpdateUserPassword &i.Password, &i.Version, &i.EmailVerified, - &i.IsActive, - &i.TwoFactorType, + &i.ActivityStatus, &i.UserData, &i.CreatedAt, &i.UpdatedAt, diff --git a/idp/internal/providers/tokens/twoFactor.go b/idp/internal/providers/tokens/twoFactor.go index 2d4c7f2..e2a71e7 100644 --- a/idp/internal/providers/tokens/twoFactor.go +++ b/idp/internal/providers/tokens/twoFactor.go @@ -7,27 +7,76 @@ package tokens import ( + "fmt" + "time" + "github.com/golang-jwt/jwt/v5" "github.com/google/uuid" "github.com/tugascript/devlogs/idp/internal/controllers/paths" + "github.com/tugascript/devlogs/idp/internal/utils" +) + +type TwoFAType string + +const ( + TwoFATypeTOTP TwoFAType = "totp" + TwoFATypeEmail TwoFAType = "email" ) +type account2FATokenClaims struct { + AccountClaims + Purpose TokenPurpose `json:"purpose"` + TwoFAType TwoFAType `json:"two_fa_type"` + jwt.RegisteredClaims +} + type Account2FATokenOptions struct { - PublicID uuid.UUID - Version int32 + PublicID uuid.UUID + Version int32 + TwoFAType TwoFAType } func (t *Tokens) Create2FAToken(opts Account2FATokenOptions) *jwt.Token { - return t.createPurposeToken(accountPurposeTokenOptions{ - ttlSec: t.twoFATTL, - accountPublicID: opts.PublicID, - accountVersion: opts.Version, - path: paths.AuthBase + paths.AuthLogin + paths.Auth2FA, - purpose: TokenPurpose2FA, + now := time.Now() + iat := jwt.NewNumericDate(now) + exp := jwt.NewNumericDate(now.Add(time.Second * time.Duration(t.twoFATTL))) + + return jwt.NewWithClaims(jwt.SigningMethodEdDSA, account2FATokenClaims{ + AccountClaims: AccountClaims{ + AccountID: opts.PublicID, + AccountVersion: opts.Version, + }, + RegisteredClaims: jwt.RegisteredClaims{ + Audience: jwt.ClaimStrings{ + buildPathAudience(t.backendDomain, paths.V1+paths.AuthBase+paths.AuthLogin+paths.Auth2FA), + }, + Issuer: fmt.Sprintf("https://%s", t.backendDomain), + Subject: opts.PublicID.String(), + IssuedAt: iat, + NotBefore: iat, + ExpiresAt: exp, + ID: uuid.NewString(), + }, + Purpose: TokenPurpose2FA, + TwoFAType: opts.TwoFAType, }) } +func (t *Tokens) Verify2FAToken(token string, getPublicJWK GetPublicJWK) (AccountClaims, TwoFAType, error) { + claims := new(account2FATokenClaims) + + if _, err := jwt.ParseWithClaims( + token, + claims, + buildVerifyKey(utils.SupportedCryptoSuiteEd25519, getPublicJWK), + ); err != nil { + return AccountClaims{}, "", err + } + + return claims.AccountClaims, claims.TwoFAType, nil +} + func (t *Tokens) Get2FATTL() int64 { return t.twoFATTL } diff --git a/idp/internal/server/routes/auth.go b/idp/internal/server/routes/auth.go index 3b78f56..a48d64c 100644 --- a/idp/internal/server/routes/auth.go +++ b/idp/internal/server/routes/auth.go @@ -32,17 +32,6 @@ func (r *Routes) AuthRoutes(app *fiber.App) { r.controllers.TwoFAAccessClaimsMiddleware, r.controllers.RecoverAccount, ) - router.Put( - paths.Auth2FA, - r.controllers.AccountAccessClaimsMiddleware, - r.controllers.AdminScopeMiddleware, - r.controllers.UpdateAccount2FA, - ) - router.Post( - paths.Auth2FA+paths.Confirm, - r.controllers.TwoFAAccessClaimsMiddleware, - r.controllers.ConfirmUpdateAccount2FA, - ) router.Post(paths.AuthRefresh, r.controllers.RefreshAccount) router.Post(paths.AuthLogout, r.controllers.AccountAccessClaimsMiddleware, r.controllers.LogoutAccount) router.Post(paths.AuthForgotPassword, r.controllers.ForgotAccountPassword) @@ -59,4 +48,41 @@ func (r *Routes) AuthRoutes(app *fiber.App) { authProvsReaderMW, r.controllers.GetAccountAuthProvider, ) + + // 2FA routes + router.Post( + paths.Auth2FA, + r.controllers.AccountAccessClaimsMiddleware, + r.controllers.AdminScopeMiddleware, + r.controllers.CreateAccount2FAConfig, + ) + router.Get( + paths.Auth2FA+paths.TwoFADefault, + r.controllers.AccountAccessClaimsMiddleware, + r.controllers.AdminScopeMiddleware, + r.controllers.GetDefaultAccount2FAConfig, + ) + router.Get( + paths.Auth2FA+paths.TwoFASingle, + r.controllers.AccountAccessClaimsMiddleware, + r.controllers.AdminScopeMiddleware, + r.controllers.GetAccount2FAConfig, + ) + router.Patch( + paths.Auth2FA+paths.TwoFASingle, + r.controllers.AccountAccessClaimsMiddleware, + r.controllers.AdminScopeMiddleware, + r.controllers.SetAccount2FAConfigDefault, + ) + router.Delete( + paths.Auth2FA+paths.TwoFASingle, + r.controllers.AccountAccessClaimsMiddleware, + r.controllers.AdminScopeMiddleware, + r.controllers.DeleteAccount2FAConfig, + ) + router.Post( + paths.Auth2FA+paths.TwoFASingle+paths.Confirm, + r.controllers.TwoFAAccessClaimsMiddleware, + r.controllers.ConfirmDeleteAccount2FAConfig, + ) } diff --git a/idp/internal/server/routes/users_auth.go b/idp/internal/server/routes/users_auth.go index 4a9546b..ba425d0 100644 --- a/idp/internal/server/routes/users_auth.go +++ b/idp/internal/server/routes/users_auth.go @@ -18,11 +18,9 @@ func (r *Routes) UsersAuthRoutes(app *fiber.App) { router.Post(paths.AuthRegister, r.controllers.AppAccessClaimsMiddleware, r.controllers.RegisterUser) router.Post(paths.AuthConfirmEmail, r.controllers.AppAccessClaimsMiddleware, r.controllers.ConfirmUser) router.Post(paths.AuthLogin, r.controllers.AppAccessClaimsMiddleware, r.controllers.LoginUser) - router.Post( - paths.AuthLogin+paths.Auth2FA, - r.controllers.User2FAClaimsMiddleware, - r.controllers.TwoFactorLoginUser, - ) + + // TODO: Add 2FA Login + router.Post(paths.AuthRefresh, r.controllers.AppAccessClaimsMiddleware, r.controllers.RefreshUser) router.Post( paths.AuthLogout, diff --git a/idp/internal/services/account_2fa_configs.go b/idp/internal/services/account_2fa_configs.go new file mode 100644 index 0000000..ed0f1f0 --- /dev/null +++ b/idp/internal/services/account_2fa_configs.go @@ -0,0 +1,830 @@ +// Copyright (c) 2025 Afonso Barracha +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +package services + +import ( + "context" + "fmt" + "slices" + + "github.com/google/uuid" + + "github.com/tugascript/devlogs/idp/internal/exceptions" + "github.com/tugascript/devlogs/idp/internal/providers/cache" + "github.com/tugascript/devlogs/idp/internal/providers/crypto" + "github.com/tugascript/devlogs/idp/internal/providers/database" + "github.com/tugascript/devlogs/idp/internal/providers/mailer" + "github.com/tugascript/devlogs/idp/internal/providers/tokens" + "github.com/tugascript/devlogs/idp/internal/services/dtos" +) + +const ( + account2FAConfigLocation = "account_2fa_configs" + + TwoFactorTypeEmail string = "email" + TwoFactorTypeTotp string = "totp" +) + +type GetDefaultAccount2FAConfigOptions struct { + RequestID string + AccountPublicID uuid.UUID +} + +func (s *Services) GetDefaultAccount2FAConfig( + ctx context.Context, + opts GetDefaultAccount2FAConfigOptions, +) (dtos.Account2FAConfigDTO, *exceptions.ServiceError) { + logger := s.buildLogger(opts.RequestID, account2FAConfigLocation, "GetDefaultAccount2FAConfig").With( + "accountPublicID", opts.AccountPublicID, + ) + logger.InfoContext(ctx, "Getting default account 2FA config...") + + config, err := s.database.FindDefaultAccount2FAConfigByAccountPublicID(ctx, opts.AccountPublicID) + if err != nil { + serviceErr := exceptions.FromDBError(err) + if serviceErr.Code == exceptions.CodeNotFound { + logger.WarnContext(ctx, "Default account 2FA config not found", "error", err) + return dtos.Account2FAConfigDTO{}, serviceErr + } + + logger.ErrorContext(ctx, "Failed to get default account 2FA config", "error", err) + return dtos.Account2FAConfigDTO{}, serviceErr + } + + logger.InfoContext(ctx, "Default account 2FA config found") + return dtos.MapAccount2FAConfigToDTO(&config), nil +} + +type GetAccount2FAConfigOptions struct { + RequestID string + AccountPublicID uuid.UUID + TwoFAType database.TwoFactorType +} + +func (s *Services) GetAccount2FAConfig( + ctx context.Context, + opts GetAccount2FAConfigOptions, +) (dtos.Account2FAConfigDTO, *exceptions.ServiceError) { + logger := s.buildLogger(opts.RequestID, account2FAConfigLocation, "GetAccount2FAConfig").With( + "accountPublicID", opts.AccountPublicID, + "twoFAType", opts.TwoFAType, + ) + logger.InfoContext(ctx, "Getting account 2FA config...") + + config, err := s.database.FindAccount2FAConfigByAccountPublicIDAndType(ctx, database.FindAccount2FAConfigByAccountPublicIDAndTypeParams{ + AccountPublicID: opts.AccountPublicID, + TwoFactorType: opts.TwoFAType, + }) + if err != nil { + serviceErr := exceptions.FromDBError(err) + if serviceErr.Code == exceptions.CodeNotFound { + logger.WarnContext(ctx, "Account 2FA config not found", "error", err) + return dtos.Account2FAConfigDTO{}, serviceErr + } + + logger.ErrorContext(ctx, "Failed to get account 2FA config", "error", err) + return dtos.Account2FAConfigDTO{}, serviceErr + } + + logger.InfoContext(ctx, "Account 2FA config found") + return dtos.MapAccount2FAConfigToDTO(&config), nil +} + +type getDefaultAccount2FAConfigInternalOptions struct { + requestID string + accountPublicID uuid.UUID +} + +func (s *Services) getDefaultAccount2FAConfigInternal( + ctx context.Context, + opts getDefaultAccount2FAConfigInternalOptions, +) (*dtos.Account2FAConfigDTO, *exceptions.ServiceError) { + logger := s.buildLogger(opts.requestID, account2FAConfigLocation, "getDefaultAccount2FAConfigInternal").With( + "accountPublicID", opts.accountPublicID, + ) + logger.InfoContext(ctx, "Getting default account 2FA config...") + + config, err := s.database.FindDefaultAccount2FAConfigByAccountPublicID(ctx, opts.accountPublicID) + if err != nil { + serviceErr := exceptions.FromDBError(err) + if serviceErr.Code == exceptions.CodeNotFound { + logger.InfoContext(ctx, "Default account 2FA config not found", "error", err) + return nil, nil + } + + logger.ErrorContext(ctx, "Failed to get default account 2FA config", "error", err) + return nil, serviceErr + } + + dto := dtos.MapAccount2FAConfigToDTO(&config) + logger.InfoContext(ctx, "Default account 2FA config found") + return &dto, nil +} + +func Map2FAType(twoFAType string) (database.TwoFactorType, *exceptions.ServiceError) { + switch twoFAType { + case TwoFactorTypeEmail: + return database.TwoFactorTypeEmail, nil + case TwoFactorTypeTotp: + return database.TwoFactorTypeTotp, nil + default: + return "", exceptions.NewValidationError("invalid two factor type") + } +} + +type buildStoreAccountTOTPOptions struct { + requestID string + accountID int32 + queries *database.Queries +} + +func (s *Services) buildStoreAccountTOTP( + ctx context.Context, + opts buildStoreAccountTOTPOptions, +) crypto.StoreTOTP { + logger := s.buildLogger(opts.requestID, authLocation, "buildStoreAccountTOTP").With( + "AccountID", opts.accountID, + ) + logger.InfoContext(ctx, "Building store account TOTP function...") + + return func(dekKID, encSecret string, hashedCode []byte, url string) *exceptions.ServiceError { + var serviceErr *exceptions.ServiceError + + qrs := s.mapQueries(opts.queries) + id, err := qrs.CreateTotp(ctx, database.CreateTotpParams{ + DekKid: dekKID, + Url: url, + Secret: encSecret, + RecoveryCodes: hashedCode, + Usage: database.TotpUsageAccount, + AccountID: opts.accountID, + }) + if err != nil { + logger.ErrorContext(ctx, "Failed to create TOTP", "error", err) + serviceErr = exceptions.FromDBError(err) + return serviceErr + } + + if err = qrs.CreateAccountTotp(ctx, database.CreateAccountTotpParams{ + AccountID: opts.accountID, + TotpID: id, + }); err != nil { + logger.ErrorContext(ctx, "Failed to create account recovery keys", "error", err) + serviceErr = exceptions.FromDBError(err) + return serviceErr + } + + return nil + } +} + +type createAccount2FAConfigInternalOptions struct { + requestID string + isDefault bool + accountDTO dtos.AccountDTO + defaultConfig *dtos.Account2FAConfigDTO +} + +func (s *Services) createTOTPAccount2FAConfig( + ctx context.Context, + opts createAccount2FAConfigInternalOptions, +) (dtos.Account2FAConfigDTO, *exceptions.ServiceError) { + logger := s.buildLogger(opts.requestID, account2FAConfigLocation, "createTOTPAccount2FAConfig").With( + "accountID", opts.accountDTO.ID(), + "isDefault", opts.isDefault, + ) + logger.InfoContext(ctx, "Creating account 2FA config...") + + var serviceErr *exceptions.ServiceError + qrs, txn, err := s.database.BeginTx(ctx) + if err != nil { + logger.ErrorContext(ctx, "Failed to start transaction", "error", err) + return dtos.Account2FAConfigDTO{}, exceptions.FromDBError(err) + } + defer func() { + logger.DebugContext(ctx, "Finalizing transaction") + s.database.FinalizeTx(ctx, txn, err, serviceErr) + }() + + twoFAConfig, err := qrs.CreateAccount2FAConfig(ctx, database.CreateAccount2FAConfigParams{ + AccountID: opts.accountDTO.ID(), + AccountPublicID: opts.accountDTO.PublicID, + TwoFactorType: database.TwoFactorTypeTotp, + IsDefault: opts.isDefault || opts.defaultConfig == nil, + }) + if err != nil { + logger.ErrorContext(ctx, "Failed to create account 2FA config", "error", err) + serviceErr = exceptions.FromDBError(err) + return dtos.Account2FAConfigDTO{}, serviceErr + } + + totpKey, err := s.crypto.GenerateTotpKey(ctx, crypto.GenerateTotpKeyOptions{ + RequestID: opts.requestID, + Email: opts.accountDTO.Email, + GetDEKfn: s.BuildGetEncAccountDEKfn(ctx, BuildGetEncAccountDEKOptions{ + RequestID: opts.requestID, + AccountID: opts.accountDTO.ID(), + Queries: qrs, + }), + StoreTOTPfn: s.buildStoreAccountTOTP(ctx, buildStoreAccountTOTPOptions{ + requestID: opts.requestID, + accountID: opts.accountDTO.ID(), + queries: qrs, + }), + }) + if err != nil { + logger.ErrorContext(ctx, "Failed to generate TOTP", "error", err) + serviceErr = exceptions.NewInternalServerError() + return dtos.Account2FAConfigDTO{}, serviceErr + } + + if opts.defaultConfig != nil && opts.isDefault { + if _, err := qrs.UpdateAccount2FAConfig(ctx, database.UpdateAccount2FAConfigParams{ + ID: opts.defaultConfig.ID(), + IsDefault: false, + }); err != nil { + logger.ErrorContext(ctx, "Failed to update default account 2FA config", "error", err) + serviceErr = exceptions.FromDBError(err) + return dtos.Account2FAConfigDTO{}, serviceErr + } + } + + account, err := qrs.UpdateAccountVersion(ctx, opts.accountDTO.ID()) + if err != nil { + logger.ErrorContext(ctx, "Failed to update account version", "error", err) + serviceErr = exceptions.FromDBError(err) + return dtos.Account2FAConfigDTO{}, serviceErr + } + + signedToken, serviceErr := s.crypto.SignToken(ctx, crypto.SignTokenOptions{ + RequestID: opts.requestID, + Token: s.jwt.Create2FAToken(tokens.Account2FATokenOptions{ + PublicID: account.PublicID, + Version: account.Version, + }), + GetJWKfn: s.BuildGetGlobalEncryptedJWKFn(ctx, BuildEncryptedJWKFnOptions{ + RequestID: opts.requestID, + KeyType: database.TokenKeyType2faAuthentication, + TTL: s.jwt.Get2FATTL(), + Queries: qrs, + }), + GetDecryptDEKfn: s.BuildGetGlobalDecDEKFn(ctx, BuildGetGlobalDEKFnOptions{ + RequestID: opts.requestID, + Queries: qrs, + }), + GetEncryptDEKfn: s.BuildGetEncGlobalDEKFn(ctx, BuildGetGlobalDEKFnOptions{ + RequestID: opts.requestID, + Queries: qrs, + }), + StoreFN: s.BuildUpdateJWKDEKFn(ctx, BuildUpdateJWKDEKFnOptions{ + RequestID: opts.requestID, + Queries: qrs, + }), + }) + if serviceErr != nil { + logger.ErrorContext(ctx, "Failed to sign 2FA token", "serviceError", serviceErr) + return dtos.Account2FAConfigDTO{}, serviceErr + } + + logger.InfoContext(ctx, "Account 2FA config created successfully") + return dtos.MapAccount2FAConfigTOTPToDTO( + &twoFAConfig, + signedToken, + totpKey.Img(), + totpKey.Codes(), + s.jwt.Get2FATTL(), + "Please scan QR Code with your authentication app", + ), nil +} + +type createEmailAccount2FAConfigInternalOptions struct { + requestID string + isDefault bool + accountDTO dtos.AccountDTO + defaultConfig *dtos.Account2FAConfigDTO +} + +func (s *Services) createEmailAccount2FAConfig( + ctx context.Context, + opts createEmailAccount2FAConfigInternalOptions, +) (dtos.Account2FAConfigDTO, *exceptions.ServiceError) { + logger := s.buildLogger(opts.requestID, account2FAConfigLocation, "createEmailAccount2FAConfig").With( + "accountID", opts.accountDTO.ID(), + "isDefault", opts.isDefault, + ) + logger.InfoContext(ctx, "Creating account 2FA config...") + + var serviceErr *exceptions.ServiceError + qrs, txn, err := s.database.BeginTx(ctx) + if err != nil { + logger.ErrorContext(ctx, "Failed to start transaction", "error", err) + return dtos.Account2FAConfigDTO{}, exceptions.FromDBError(err) + } + defer func() { + logger.DebugContext(ctx, "Finalizing transaction") + s.database.FinalizeTx(ctx, txn, err, serviceErr) + }() + + twoFAConfig, err := qrs.CreateAccount2FAConfig(ctx, database.CreateAccount2FAConfigParams{ + AccountID: opts.accountDTO.ID(), + AccountPublicID: opts.accountDTO.PublicID, + TwoFactorType: database.TwoFactorTypeEmail, + IsDefault: opts.isDefault || opts.defaultConfig == nil, + }) + if err != nil { + logger.ErrorContext(ctx, "Failed to create account 2FA config", "error", err) + serviceErr = exceptions.FromDBError(err) + return dtos.Account2FAConfigDTO{}, serviceErr + } + + if opts.defaultConfig != nil && opts.isDefault { + if _, err := qrs.UpdateAccount2FAConfig(ctx, database.UpdateAccount2FAConfigParams{ + ID: opts.defaultConfig.ID(), + IsDefault: false, + }); err != nil { + logger.ErrorContext(ctx, "Failed to update default account 2FA config", "error", err) + serviceErr = exceptions.FromDBError(err) + return dtos.Account2FAConfigDTO{}, serviceErr + } + } + + code, err := s.cache.AddTwoFactorCode(ctx, cache.AddTwoFactorCodeOptions{ + RequestID: opts.requestID, + AccountID: opts.accountDTO.ID(), + TTL: s.jwt.Get2FATTL(), + }) + if err != nil { + logger.ErrorContext(ctx, "Failed to generate two factor Code", "error", err) + serviceErr = exceptions.NewInternalServerError() + return dtos.Account2FAConfigDTO{}, serviceErr + } + + if err := s.mail.Publish2FAEmail(ctx, mailer.TwoFactorEmailOptions{ + RequestID: opts.requestID, + Email: opts.accountDTO.Email, + Name: fmt.Sprintf("%s %s", opts.accountDTO.GivenName, opts.accountDTO.FamilyName), + Code: code, + }); err != nil { + logger.ErrorContext(ctx, "Failed to publish two factor email", "error", err) + serviceErr = exceptions.NewInternalServerError() + return dtos.Account2FAConfigDTO{}, serviceErr + } + + account, err := qrs.UpdateAccountVersion(ctx, opts.accountDTO.ID()) + if err != nil { + logger.ErrorContext(ctx, "Failed to update account version", "error", err) + serviceErr = exceptions.FromDBError(err) + return dtos.Account2FAConfigDTO{}, serviceErr + } + + signedToken, serviceErr := s.crypto.SignToken(ctx, crypto.SignTokenOptions{ + RequestID: opts.requestID, + Token: s.jwt.Create2FAToken(tokens.Account2FATokenOptions{ + PublicID: account.PublicID, + Version: account.Version, + }), + GetJWKfn: s.BuildGetGlobalEncryptedJWKFn(ctx, BuildEncryptedJWKFnOptions{ + RequestID: opts.requestID, + KeyType: database.TokenKeyType2faAuthentication, + TTL: s.jwt.Get2FATTL(), + Queries: qrs, + }), + GetDecryptDEKfn: s.BuildGetGlobalDecDEKFn(ctx, BuildGetGlobalDEKFnOptions{ + RequestID: opts.requestID, + Queries: qrs, + }), + GetEncryptDEKfn: s.BuildGetEncGlobalDEKFn(ctx, BuildGetGlobalDEKFnOptions{ + RequestID: opts.requestID, + Queries: qrs, + }), + StoreFN: s.BuildUpdateJWKDEKFn(ctx, BuildUpdateJWKDEKFnOptions{ + RequestID: opts.requestID, + Queries: qrs, + }), + }) + if serviceErr != nil { + logger.ErrorContext(ctx, "Failed to sign 2FA token", "serviceError", serviceErr) + return dtos.Account2FAConfigDTO{}, serviceErr + } + + logger.InfoContext(ctx, "Account 2FA config created successfully") + return dtos.MapAccount2FAConfigCodeToDTO( + &twoFAConfig, + signedToken, + s.jwt.Get2FATTL(), + "Please enter the code sent to your email", + ), nil +} + +type CreateAccount2FAConfigOptions struct { + RequestID string + AccountPublicID uuid.UUID + AccountVersion int32 + TwoFAType string + IsDefault bool +} + +func (s *Services) CreateAccount2FAConfig( + ctx context.Context, + opts CreateAccount2FAConfigOptions, +) (dtos.Account2FAConfigDTO, *exceptions.ServiceError) { + logger := s.buildLogger(opts.RequestID, account2FAConfigLocation, "CreateAccount2FAConfig").With( + "accountPublicID", opts.AccountPublicID, + "accountVersion", opts.AccountVersion, + "twoFAType", opts.TwoFAType, + "isDefault", opts.IsDefault, + ) + logger.InfoContext(ctx, "Creating account 2FA config...") + + twoFAType, serviceErr := Map2FAType(opts.TwoFAType) + if serviceErr != nil { + logger.WarnContext(ctx, "Failed to map two factor type", "serviceError", serviceErr) + return dtos.Account2FAConfigDTO{}, serviceErr + } + + accountDTO, serviceErr := s.GetAccountByPublicIDAndVersion(ctx, GetAccountByPublicIDAndVersionOptions{ + RequestID: opts.RequestID, + PublicID: opts.AccountPublicID, + Version: opts.AccountVersion, + }) + if serviceErr != nil { + logger.WarnContext(ctx, "Failed to get account ID by public ID and version", "serviceError", serviceErr) + return dtos.Account2FAConfigDTO{}, serviceErr + } + + configDTO, serviceErr := s.getDefaultAccount2FAConfigInternal(ctx, getDefaultAccount2FAConfigInternalOptions{ + requestID: opts.RequestID, + accountPublicID: opts.AccountPublicID, + }) + if serviceErr != nil { + logger.ErrorContext(ctx, "Failed to get default account 2FA config", "serviceError", serviceErr) + return dtos.Account2FAConfigDTO{}, serviceErr + } + + switch twoFAType { + case database.TwoFactorTypeTotp: + return s.createTOTPAccount2FAConfig(ctx, createAccount2FAConfigInternalOptions{ + requestID: opts.RequestID, + accountDTO: accountDTO, + isDefault: opts.IsDefault, + defaultConfig: configDTO, + }) + case database.TwoFactorTypeEmail: + return s.createEmailAccount2FAConfig(ctx, createEmailAccount2FAConfigInternalOptions{ + requestID: opts.RequestID, + accountDTO: accountDTO, + isDefault: opts.IsDefault, + defaultConfig: configDTO, + }) + default: + logger.WarnContext(ctx, "Invalid two factor type", "twoFAType", opts.TwoFAType) + return dtos.Account2FAConfigDTO{}, exceptions.NewValidationError("invalid two factor type") + } +} + +type SetAccount2FAConfigDefaultOptions struct { + RequestID string + AccountPublicID uuid.UUID + AccountVersion int32 + TwoFAType string +} + +func (s *Services) SetAccount2FAConfigDefault( + ctx context.Context, + opts SetAccount2FAConfigDefaultOptions, +) (dtos.Account2FAConfigDTO, *exceptions.ServiceError) { + logger := s.buildLogger(opts.RequestID, account2FAConfigLocation, "SetAccount2FAConfigDefault").With( + "accountPublicID", opts.AccountPublicID, + "accountVersion", opts.AccountVersion, + "twoFAType", opts.TwoFAType, + ) + logger.InfoContext(ctx, "Setting account 2FA config default...") + + twoFAType, serviceErr := Map2FAType(opts.TwoFAType) + if serviceErr != nil { + logger.WarnContext(ctx, "Failed to map two factor type", "serviceError", serviceErr) + return dtos.Account2FAConfigDTO{}, serviceErr + } + + if _, serviceErr := s.GetAccountIDByPublicIDAndVersion(ctx, GetAccountIDByPublicIDAndVersionOptions{ + RequestID: opts.RequestID, + PublicID: opts.AccountPublicID, + Version: opts.AccountVersion, + }); serviceErr != nil { + logger.WarnContext(ctx, "Failed to get account ID by public ID and version", "serviceError", serviceErr) + return dtos.Account2FAConfigDTO{}, serviceErr + } + + configDTO, serviceErr := s.GetAccount2FAConfig(ctx, GetAccount2FAConfigOptions{ + RequestID: opts.RequestID, + AccountPublicID: opts.AccountPublicID, + TwoFAType: twoFAType, + }) + if serviceErr != nil { + logger.WarnContext(ctx, "Failed to get account 2FA config", "serviceError", serviceErr) + return dtos.Account2FAConfigDTO{}, serviceErr + } + + if configDTO.IsDefault { + logger.WarnContext(ctx, "Account 2FA config is already default", "twoFAType", opts.TwoFAType) + return dtos.Account2FAConfigDTO{}, exceptions.NewForbiddenError() + } + + defaultConfig, serviceErr := s.getDefaultAccount2FAConfigInternal(ctx, getDefaultAccount2FAConfigInternalOptions{ + requestID: opts.RequestID, + accountPublicID: opts.AccountPublicID, + }) + if serviceErr != nil { + logger.ErrorContext(ctx, "Failed to get default account 2FA config", "serviceError", serviceErr) + return dtos.Account2FAConfigDTO{}, serviceErr + } + + qrs, txn, err := s.database.BeginTx(ctx) + if err != nil { + logger.ErrorContext(ctx, "Failed to start transaction", "error", err) + return dtos.Account2FAConfigDTO{}, exceptions.FromDBError(err) + } + defer func() { + logger.DebugContext(ctx, "Finalizing transaction") + s.database.FinalizeTx(ctx, txn, err, serviceErr) + }() + + config, err := qrs.UpdateAccount2FAConfig(ctx, database.UpdateAccount2FAConfigParams{ + ID: configDTO.ID(), + IsDefault: true, + }) + if err != nil { + logger.ErrorContext(ctx, "Failed to update account 2FA config", "error", err) + return dtos.Account2FAConfigDTO{}, exceptions.FromDBError(err) + } + + if defaultConfig != nil { + if _, err := qrs.UpdateAccount2FAConfig(ctx, database.UpdateAccount2FAConfigParams{ + ID: defaultConfig.ID(), + IsDefault: false, + }); err != nil { + logger.ErrorContext(ctx, "Failed to update default account 2FA config", "error", err) + return dtos.Account2FAConfigDTO{}, exceptions.FromDBError(err) + } + } + + logger.InfoContext(ctx, "Account 2FA config set as default successfully") + return dtos.MapAccount2FAConfigToDTO(&config), nil +} + +type DeleteAccount2FAConfigOptions struct { + RequestID string + AccountPublicID uuid.UUID + AccountVersion int32 + TwoFAType string +} + +func (s *Services) DeleteAccount2FAConfig( + ctx context.Context, + opts DeleteAccount2FAConfigOptions, +) (dtos.Account2FAConfigDTO, *exceptions.ServiceError) { + logger := s.buildLogger(opts.RequestID, account2FAConfigLocation, "DeleteAccount2FAConfig").With( + "accountPublicID", opts.AccountPublicID, + "accountVersion", opts.AccountVersion, + "twoFAType", opts.TwoFAType, + ) + logger.InfoContext(ctx, "Deleting account 2FA config...") + + twoFAType, serviceErr := Map2FAType(opts.TwoFAType) + if serviceErr != nil { + logger.WarnContext(ctx, "Failed to map two factor type", "serviceError", serviceErr) + return dtos.Account2FAConfigDTO{}, serviceErr + } + + accountDTO, serviceErr := s.GetAccountByPublicIDAndVersion(ctx, GetAccountByPublicIDAndVersionOptions{ + RequestID: opts.RequestID, + PublicID: opts.AccountPublicID, + Version: opts.AccountVersion, + }) + if serviceErr != nil { + logger.WarnContext(ctx, "Failed to get account by public ID and version", "serviceError", serviceErr) + return dtos.Account2FAConfigDTO{}, serviceErr + } + + config, err := s.database.FindAccount2FAConfigByAccountPublicIDAndType(ctx, database.FindAccount2FAConfigByAccountPublicIDAndTypeParams{ + AccountPublicID: opts.AccountPublicID, + TwoFactorType: twoFAType, + }) + if err != nil { + serviceErr := exceptions.FromDBError(err) + if serviceErr.Code == exceptions.CodeNotFound { + logger.WarnContext(ctx, "Account 2FA config not found", "error", err) + return dtos.Account2FAConfigDTO{}, serviceErr + } + + logger.ErrorContext(ctx, "Failed to get account 2FA config", "error", err) + return dtos.Account2FAConfigDTO{}, serviceErr + } + + code, err := s.cache.SaveDelete2FAConfigRequest(ctx, cache.SaveDelete2FAConfigRequestOptions{ + RequestID: opts.RequestID, + PrefixType: cache.SensitiveRequestAccountPrefix, + PublicID: opts.AccountPublicID, + TwoFAType: opts.TwoFAType, + TTL: s.jwt.Get2FATTL(), + }) + if err != nil { + logger.ErrorContext(ctx, "Failed to save delete 2FA config request", "error", err) + return dtos.Account2FAConfigDTO{}, exceptions.NewInternalServerError() + } + + if opts.TwoFAType == "email" { + if err := s.mail.Publish2FAEmail(ctx, mailer.TwoFactorEmailOptions{ + RequestID: opts.RequestID, + Email: accountDTO.Email, + Name: fmt.Sprintf("%s %s", accountDTO.GivenName, accountDTO.FamilyName), + Code: code, + }); err != nil { + logger.ErrorContext(ctx, "Failed to publish two factor email", "error", err) + return dtos.Account2FAConfigDTO{}, exceptions.NewInternalServerError() + } + } + + signedToken, serviceErr := s.crypto.SignToken(ctx, crypto.SignTokenOptions{ + RequestID: opts.RequestID, + Token: s.jwt.Create2FAToken(tokens.Account2FATokenOptions{ + PublicID: accountDTO.PublicID, + Version: accountDTO.Version(), + }), + GetJWKfn: s.BuildGetGlobalEncryptedJWKFn(ctx, BuildEncryptedJWKFnOptions{ + RequestID: opts.RequestID, + KeyType: database.TokenKeyType2faAuthentication, + TTL: s.jwt.Get2FATTL(), + }), + GetDecryptDEKfn: s.BuildGetGlobalDecDEKFn(ctx, BuildGetGlobalDEKFnOptions{ + RequestID: opts.RequestID, + }), + GetEncryptDEKfn: s.BuildGetEncGlobalDEKFn(ctx, BuildGetGlobalDEKFnOptions{ + RequestID: opts.RequestID, + }), + StoreFN: s.BuildUpdateJWKDEKFn(ctx, BuildUpdateJWKDEKFnOptions{ + RequestID: opts.RequestID, + }), + }) + if serviceErr != nil { + logger.ErrorContext(ctx, "Failed to sign 2FA token", "serviceError", serviceErr) + return dtos.Account2FAConfigDTO{}, serviceErr + } + + msg := "Please enter the code sent to your email" + if opts.TwoFAType == "totp" { + msg = "Please enter the code from your authentication app" + } + + return dtos.MapAccount2FAConfigCodeToDTO( + &config, + signedToken, + s.jwt.Get2FATTL(), + msg, + ), nil +} + +type ConfirmDeleteAccount2FAConfigOptions struct { + RequestID string + PublicID uuid.UUID + Version int32 + TwoFAType string + Code string +} + +func (s *Services) ConfirmDeleteAccount2FAConfig( + ctx context.Context, + opts ConfirmDeleteAccount2FAConfigOptions, +) (dtos.AuthDTO, *exceptions.ServiceError) { + logger := s.buildLogger(opts.RequestID, account2FAConfigLocation, "ConfirmDeleteAccount2FAConfig").With( + "publicID", opts.PublicID, + "version", opts.Version, + "twoFAType", opts.TwoFAType, + "code", opts.Code, + ) + logger.InfoContext(ctx, "Confirming delete account 2FA config...") + + twoFAType, serviceErr := Map2FAType(opts.TwoFAType) + if serviceErr != nil { + logger.WarnContext(ctx, "Failed to map two factor type", "serviceError", serviceErr) + return dtos.AuthDTO{}, serviceErr + } + + accountDTO, serviceErr := s.GetAccountByPublicIDAndVersion(ctx, GetAccountByPublicIDAndVersionOptions{ + RequestID: opts.RequestID, + PublicID: opts.PublicID, + Version: opts.Version, + }) + if serviceErr != nil { + logger.WarnContext(ctx, "Failed to get account by public ID and version", "serviceError", serviceErr) + return dtos.AuthDTO{}, serviceErr + } + + ok, err := s.cache.VerifyDelete2FAConfigRequest(ctx, cache.VerifyDelete2FAConfigRequestOptions{ + RequestID: opts.RequestID, + PrefixType: cache.SensitiveRequestAccountPrefix, + PublicID: opts.PublicID, + TwoFAType: opts.TwoFAType, + Code: opts.Code, + }) + if err != nil { + logger.ErrorContext(ctx, "Failed to verify delete 2FA config request", "error", err) + return dtos.AuthDTO{}, exceptions.NewInternalServerError() + } + if !ok { + logger.WarnContext(ctx, "Delete 2FA config request does not match") + return dtos.AuthDTO{}, exceptions.NewUnauthorizedError() + } + + configs, err := s.database.FindAccount2FAConfigsByAccountPublicID(ctx, accountDTO.PublicID) + if err != nil { + logger.ErrorContext(ctx, "Failed to get account 2FA configs", "error", err) + return dtos.AuthDTO{}, exceptions.NewInternalServerError() + } + + length := len(configs) + if length == 0 { + logger.WarnContext(ctx, "Account 2FA configs not found") + return dtos.AuthDTO{}, exceptions.NewUnauthorizedError() + } + + qrs, txn, err := s.database.BeginTx(ctx) + if err != nil { + logger.ErrorContext(ctx, "Failed to start transaction", "error", err) + return dtos.AuthDTO{}, exceptions.FromDBError(err) + } + defer func() { + logger.DebugContext(ctx, "Finalizing transaction") + s.database.FinalizeTx(ctx, txn, err, serviceErr) + }() + + if length == 1 { + if err := qrs.DeleteAccount2FAConfig(ctx, configs[0].ID); err != nil { + logger.ErrorContext(ctx, "Failed to delete account 2FA config", "error", err) + serviceErr = exceptions.FromDBError(err) + return dtos.AuthDTO{}, serviceErr + } + + account, err := qrs.UpdateAccountVersion(ctx, accountDTO.ID()) + if err != nil { + logger.ErrorContext(ctx, "Failed to update account version", "error", err) + serviceErr = exceptions.FromDBError(err) + return dtos.AuthDTO{}, serviceErr + } + + accountDTO = dtos.MapAccountToDTO(&account) + return s.GenerateFullAuthDTO( + ctx, + logger, + qrs, + opts.RequestID, + &accountDTO, + []tokens.AccountScope{tokens.AccountScopeAdmin}, + "Account 2FA config deleted successfully", + ) + } + + idx := slices.IndexFunc(configs, func(config database.Account2faConfig) bool { + return config.TwoFactorType == twoFAType + }) + if idx == -1 { + logger.WarnContext(ctx, "Account 2FA config not found") + return dtos.AuthDTO{}, exceptions.NewUnauthorizedError() + } + + config := configs[idx] + if err := qrs.DeleteAccount2FAConfig(ctx, config.ID); err != nil { + logger.ErrorContext(ctx, "Failed to delete account 2FA config", "error", err) + serviceErr = exceptions.FromDBError(err) + return dtos.AuthDTO{}, serviceErr + } + if config.IsDefault { + uIdx := 0 + if idx == 0 { + uIdx = 1 + } + + if _, err := qrs.UpdateAccount2FAConfig(ctx, database.UpdateAccount2FAConfigParams{ + ID: configs[uIdx].ID, + IsDefault: true, + }); err != nil { + logger.ErrorContext(ctx, "Failed to update account 2FA config", "error", err) + serviceErr = exceptions.FromDBError(err) + return dtos.AuthDTO{}, serviceErr + } + } + + return s.GenerateFullAuthDTO( + ctx, + logger, + qrs, + opts.RequestID, + &accountDTO, + []tokens.AccountScope{tokens.AccountScopeAdmin}, + "Account 2FA config deleted successfully", + ) +} diff --git a/idp/internal/services/account_credentials_registration_iat.go b/idp/internal/services/account_credentials_registration_iat.go index 4cf26d7..b6d4021 100644 --- a/idp/internal/services/account_credentials_registration_iat.go +++ b/idp/internal/services/account_credentials_registration_iat.go @@ -73,9 +73,15 @@ func (s *Services) CreateAccountCredentialsRegistrationIAT( KeyType: database.TokenKeyTypeDynamicRegistration, TTL: s.jwt.GetDynamicRegistrationTTL(), }), - GetDecryptDEKfn: s.BuildGetGlobalDecDEKFn(ctx, opts.RequestID), - GetEncryptDEKfn: s.BuildGetEncGlobalDEKFn(ctx, opts.RequestID), - StoreFN: s.BuildUpdateJWKDEKFn(ctx, opts.RequestID), + GetDecryptDEKfn: s.BuildGetGlobalDecDEKFn(ctx, BuildGetGlobalDEKFnOptions{ + RequestID: opts.RequestID, + }), + GetEncryptDEKfn: s.BuildGetEncGlobalDEKFn(ctx, BuildGetGlobalDEKFnOptions{ + RequestID: opts.RequestID, + }), + StoreFN: s.BuildUpdateJWKDEKFn(ctx, BuildUpdateJWKDEKFnOptions{ + RequestID: opts.RequestID, + }), }) if serviceErr != nil { logger.ErrorContext(ctx, "Failed to sign account credentials registration IAT", "serviceError", serviceErr) diff --git a/idp/internal/services/accounts.go b/idp/internal/services/accounts.go index 1e9d148..90d3aeb 100644 --- a/idp/internal/services/accounts.go +++ b/idp/internal/services/accounts.go @@ -399,59 +399,69 @@ func (s *Services) UpdateAccountEmail( return dtos.AuthDTO{}, exceptions.NewConflictError("Email already in use") } - if accountDTO.TwoFactorType != database.TwoFactorTypeNone { - logger.InfoContext(ctx, "Account has 2FA enabled", "twoFactorType", accountDTO.TwoFactorType) - - err = s.cache.SaveUpdateEmailRequest(ctx, cache.SaveUpdateEmailRequestOptions{ - RequestID: opts.RequestID, - PrefixType: cache.SensitiveRequestAccountPrefix, - PublicID: accountDTO.PublicID, - Email: newEmail, - DurationSeconds: s.jwt.Get2FATTL(), - }) + default2FAConfig, serviceErr := s.getDefaultAccount2FAConfigInternal(ctx, getDefaultAccount2FAConfigInternalOptions{ + requestID: opts.RequestID, + accountPublicID: accountDTO.PublicID, + }) + if serviceErr != nil { + logger.ErrorContext(ctx, "Failed to get default 2FA config", "serviceError", serviceErr) + return dtos.AuthDTO{}, serviceErr + } + if default2FAConfig == nil { + account, err := s.updateAccountEmailInDB(ctx, logger, accountDTO.ID(), accountDTO.Email, newEmail) if err != nil { - logger.ErrorContext(ctx, "Failed to cache email update request", "error", err) - return dtos.AuthDTO{}, exceptions.NewInternalServerError() + logger.ErrorContext(ctx, "Failed to update account email", "error", err) + serviceErr = exceptions.FromDBError(err) + return dtos.AuthDTO{}, serviceErr } - authDTO, serviceErr := s.generate2FAAuth( + logger.InfoContext(ctx, "Updated account email successfully") + accountDTO = dtos.MapAccountToDTO(&account) + return s.GenerateFullAuthDTO( ctx, logger, + s.database.Queries, opts.RequestID, &accountDTO, - "Please provide two factor code to confirm email update", + []tokens.AccountScope{tokens.AccountScopeAdmin}, + "Email updated successfully", ) - if serviceErr != nil { - return dtos.AuthDTO{}, serviceErr - } - - logger.InfoContext(ctx, "Email update request cached successfully") - return authDTO, serviceErr } - account, err := s.updateAccountEmailInDB(ctx, logger, accountDTO.ID(), accountDTO.Email, newEmail) + logger.InfoContext(ctx, "Account has 2FA enabled", "twoFactorType", default2FAConfig.TwoFactorType) + err = s.cache.SaveUpdateEmailRequest(ctx, cache.SaveUpdateEmailRequestOptions{ + RequestID: opts.RequestID, + PrefixType: cache.SensitiveRequestAccountPrefix, + PublicID: accountDTO.PublicID, + Email: newEmail, + DurationSeconds: s.jwt.Get2FATTL(), + }) if err != nil { - logger.ErrorContext(ctx, "Failed to update account email", "error", err) - serviceErr = exceptions.FromDBError(err) - return dtos.AuthDTO{}, serviceErr + logger.ErrorContext(ctx, "Failed to cache email update request", "error", err) + return dtos.AuthDTO{}, exceptions.NewInternalServerError() } - logger.InfoContext(ctx, "Updated account email successfully") - accountDTO = dtos.MapAccountToDTO(&account) - return s.GenerateFullAuthDTO( + authDTO, serviceErr := s.generate2FAAuth( ctx, logger, opts.RequestID, &accountDTO, - []tokens.AccountScope{tokens.AccountScopeAdmin}, - "Email updated successfully", + default2FAConfig.TwoFactorType, + "Please provide two factor code to confirm email update", ) + if serviceErr != nil { + return dtos.AuthDTO{}, serviceErr + } + + logger.InfoContext(ctx, "Email update request cached successfully") + return authDTO, serviceErr } type ConfirmUpdateAccountEmailOptions struct { RequestID string PublicID uuid.UUID Version int32 + TwoFAType tokens.TwoFAType Code string } @@ -488,7 +498,15 @@ func (s *Services) ConfirmUpdateAccountEmail( return dtos.AuthDTO{}, serviceErr } - if serviceErr := s.verifyAccountTwoFactor(ctx, opts.RequestID, &accountDTO, opts.Code); serviceErr != nil { + if serviceErr := s.verifyAccount2FAInternal(ctx, verifyAccount2FAInternalOptions{ + requestID: opts.RequestID, + accountID: accountDTO.ID(), + accountPublicID: opts.PublicID, + accountVersion: opts.Version, + twoFAType: opts.TwoFAType, + code: opts.Code, + }); serviceErr != nil { + logger.ErrorContext(ctx, "Failed to verify account two factor", "serviceError", serviceErr) return dtos.AuthDTO{}, serviceErr } @@ -503,6 +521,7 @@ func (s *Services) ConfirmUpdateAccountEmail( return s.GenerateFullAuthDTO( ctx, logger, + s.database.Queries, opts.RequestID, &accountDTO, []tokens.AccountScope{tokens.AccountScopeAdmin}, @@ -586,17 +605,24 @@ func (s *Services) UpdateAccountPassword( return dtos.AuthDTO{}, exceptions.NewValidationError("Invalid password") } - if accountDTO.TwoFactorType != database.TwoFactorTypeNone { - logger.InfoContext(ctx, "Account has 2FA enabled", "twoFactorType", accountDTO.TwoFactorType) + default2FAConfig, serviceErr := s.getDefaultAccount2FAConfigInternal(ctx, getDefaultAccount2FAConfigInternalOptions{ + requestID: opts.RequestID, + accountPublicID: accountDTO.PublicID, + }) + if serviceErr != nil { + logger.ErrorContext(ctx, "Failed to get default 2FA config", "serviceError", serviceErr) + return dtos.AuthDTO{}, serviceErr + } - err = s.cache.SaveUpdatePasswordRequest(ctx, cache.SaveUpdatePasswordRequestOptions{ + if default2FAConfig != nil { + logger.InfoContext(ctx, "Account has 2FA enabled", "twoFactorType", default2FAConfig.TwoFactorType) + if err := s.cache.SaveUpdatePasswordRequest(ctx, cache.SaveUpdatePasswordRequestOptions{ RequestID: opts.RequestID, PrefixType: cache.SensitiveRequestAccountPrefix, PublicID: accountDTO.PublicID, NewPassword: opts.NewPassword, DurationSeconds: s.jwt.Get2FATTL(), - }) - if err != nil { + }); err != nil { logger.ErrorContext(ctx, "Failed to cache password update request", "error", err) return dtos.AuthDTO{}, exceptions.NewInternalServerError() } @@ -606,6 +632,7 @@ func (s *Services) UpdateAccountPassword( logger, opts.RequestID, &accountDTO, + default2FAConfig.TwoFactorType, "Please provide two factor code to confirm password update", ) if serviceErr != nil { @@ -632,6 +659,7 @@ func (s *Services) UpdateAccountPassword( return s.GenerateFullAuthDTO( ctx, logger, + s.database.Queries, opts.RequestID, &accountDTO, []tokens.AccountScope{tokens.AccountScopeAdmin}, @@ -643,6 +671,7 @@ type ConfirmUpdateAccountPasswordOptions struct { RequestID string PublicID uuid.UUID Version int32 + TwoFAType tokens.TwoFAType Code string } @@ -679,7 +708,15 @@ func (s *Services) ConfirmUpdateAccountPassword( return dtos.AuthDTO{}, exceptions.NewUnauthorizedError() } - if serviceErr := s.verifyAccountTwoFactor(ctx, opts.RequestID, &accountDTO, opts.Code); serviceErr != nil { + if serviceErr := s.verifyAccount2FAInternal(ctx, verifyAccount2FAInternalOptions{ + requestID: opts.RequestID, + accountID: accountDTO.ID(), + accountPublicID: opts.PublicID, + accountVersion: opts.Version, + twoFAType: opts.TwoFAType, + code: opts.Code, + }); serviceErr != nil { + logger.ErrorContext(ctx, "Failed to verify account two factor", "serviceError", serviceErr) return dtos.AuthDTO{}, serviceErr } @@ -694,6 +731,7 @@ func (s *Services) ConfirmUpdateAccountPassword( return s.GenerateFullAuthDTO( ctx, logger, + s.database.Queries, opts.RequestID, &accountDTO, []tokens.AccountScope{tokens.AccountScopeAdmin}, @@ -789,6 +827,7 @@ func (s *Services) CreateAccountPassword( return s.GenerateFullAuthDTO( ctx, logger, + s.database.Queries, opts.RequestID, &accountDTO, []tokens.AccountScope{tokens.AccountScopeAdmin}, @@ -908,8 +947,16 @@ func (s *Services) UpdateAccountUsername( return dtos.AuthDTO{}, exceptions.NewConflictError("Username already in use") } - if accountDTO.TwoFactorType != database.TwoFactorTypeNone { - logger.InfoContext(ctx, "Account has 2FA enabled", "twoFactorType", accountDTO.TwoFactorType) + default2FAConfig, serviceErr := s.getDefaultAccount2FAConfigInternal(ctx, getDefaultAccount2FAConfigInternalOptions{ + requestID: opts.RequestID, + accountPublicID: accountDTO.PublicID, + }) + if serviceErr != nil { + logger.ErrorContext(ctx, "Failed to get default 2FA config", "serviceError", serviceErr) + return dtos.AuthDTO{}, serviceErr + } + if default2FAConfig != nil { + logger.InfoContext(ctx, "Account has 2FA enabled", "twoFactorType", default2FAConfig.TwoFactorType) if err := s.cache.SaveUpdateUsernameRequest(ctx, cache.SaveUpdateUsernameRequestOptions{ RequestID: opts.RequestID, @@ -927,6 +974,7 @@ func (s *Services) UpdateAccountUsername( logger, opts.RequestID, &accountDTO, + default2FAConfig.TwoFactorType, "Please provide two factor code to confirm username update", ) if serviceErr != nil { @@ -951,6 +999,7 @@ func (s *Services) UpdateAccountUsername( return s.GenerateFullAuthDTO( ctx, logger, + s.database.Queries, opts.RequestID, &accountDTO, []tokens.AccountScope{tokens.AccountScopeAdmin}, @@ -962,6 +1011,7 @@ type ConfirmUpdateAccountUsernameOptions struct { RequestID string PublicID uuid.UUID Version int32 + TwoFAType tokens.TwoFAType Code string } @@ -998,7 +1048,15 @@ func (s *Services) ConfirmUpdateAccountUsername( return dtos.AuthDTO{}, serviceErr } - if serviceErr := s.verifyAccountTwoFactor(ctx, opts.RequestID, &accountDTO, opts.Code); serviceErr != nil { + if serviceErr := s.verifyAccount2FAInternal(ctx, verifyAccount2FAInternalOptions{ + requestID: opts.RequestID, + accountID: accountDTO.ID(), + accountPublicID: opts.PublicID, + accountVersion: opts.Version, + twoFAType: opts.TwoFAType, + code: opts.Code, + }); serviceErr != nil { + logger.ErrorContext(ctx, "Failed to verify account two factor", "serviceError", serviceErr) return dtos.AuthDTO{}, serviceErr } @@ -1016,6 +1074,7 @@ func (s *Services) ConfirmUpdateAccountUsername( return s.GenerateFullAuthDTO( ctx, logger, + s.database.Queries, opts.RequestID, &accountDTO, []tokens.AccountScope{tokens.AccountScopeAdmin}, @@ -1075,9 +1134,17 @@ func (s *Services) DeleteAccount( } } - if accountDTO.TwoFactorType != database.TwoFactorTypeNone { - logger.InfoContext(ctx, "Account has 2FA enabled", "twoFactorType", accountDTO.TwoFactorType) + default2FAConfig, serviceErr := s.getDefaultAccount2FAConfigInternal(ctx, getDefaultAccount2FAConfigInternalOptions{ + requestID: opts.RequestID, + accountPublicID: accountDTO.PublicID, + }) + if serviceErr != nil { + logger.ErrorContext(ctx, "Failed to get default 2FA config", "serviceError", serviceErr) + return false, dtos.AuthDTO{}, serviceErr + } + if default2FAConfig != nil { + logger.InfoContext(ctx, "Account has 2FA enabled", "twoFactorType", default2FAConfig.TwoFactorType) if err := s.cache.SaveDeleteAccountRequest(ctx, cache.SaveDeleteAccountRequestOptions{ RequestID: opts.RequestID, PrefixType: cache.SensitiveRequestAccountPrefix, @@ -1093,6 +1160,7 @@ func (s *Services) DeleteAccount( logger, opts.RequestID, &accountDTO, + default2FAConfig.TwoFactorType, "Please provide two factor code to confirm account deletion", ) if serviceErr != nil { @@ -1116,6 +1184,7 @@ type ConfirmDeleteAccountOptions struct { RequestID string PublicID uuid.UUID Version int32 + TwoFAType tokens.TwoFAType Code string } @@ -1152,7 +1221,14 @@ func (s *Services) ConfirmDeleteAccount( return serviceErr } - if serviceErr := s.verifyAccountTwoFactor(ctx, opts.RequestID, &accountDTO, opts.Code); serviceErr != nil { + if serviceErr := s.verifyAccount2FAInternal(ctx, verifyAccount2FAInternalOptions{ + requestID: opts.RequestID, + accountID: accountDTO.ID(), + accountPublicID: opts.PublicID, + accountVersion: opts.Version, + twoFAType: opts.TwoFAType, + code: opts.Code, + }); serviceErr != nil { return serviceErr } diff --git a/idp/internal/services/auth.go b/idp/internal/services/auth.go index d2072b8..93d2efd 100644 --- a/idp/internal/services/auth.go +++ b/idp/internal/services/auth.go @@ -104,16 +104,28 @@ func (s *Services) ProcessAccountAuthHeader( func (s *Services) Process2FAAuthHeader( ctx context.Context, opts ProcessAuthHeaderOptions, -) (tokens.AccountClaims, *exceptions.ServiceError) { - return s.processPurposeAuthHeader( - ctx, - processPurposeAuthHeaderOptions{ - requestID: opts.RequestID, - authHeader: opts.AuthHeader, - tokenPurpose: tokens.TokenPurpose2FA, - tokenKeyType: database.TokenKeyType2faAuthentication, - }, +) (tokens.AccountClaims, tokens.TwoFAType, *exceptions.ServiceError) { + logger := s.buildLogger(opts.RequestID, authLocation, "Process2FAAuthHeader") + logger.InfoContext(ctx, "Processing purpose auth header...") + + token, serviceErr := extractAuthHeaderToken(opts.AuthHeader) + if serviceErr != nil { + return tokens.AccountClaims{}, "", serviceErr + } + + accountClaims, twoFAType, err := s.jwt.Verify2FAToken( + token, + s.BuildGetGlobalPublicKeyFn(ctx, BuildGetGlobalVerifyKeyFnOptions{ + RequestID: opts.RequestID, + KeyType: database.TokenKeyType2faAuthentication, + }), ) + if err != nil { + logger.ErrorContext(ctx, "Failed to verify purpose token", "error", err) + return tokens.AccountClaims{}, "", exceptions.NewUnauthorizedError() + } + + return accountClaims, twoFAType, nil } func (s *Services) GetRefreshTTL() int64 { @@ -146,9 +158,15 @@ func (s *Services) sendConfirmationEmail( KeyType: database.TokenKeyTypeEmailVerification, TTL: s.jwt.GetConfirmationTTL(), }), - GetDecryptDEKfn: s.BuildGetGlobalDecDEKFn(ctx, requestID), - GetEncryptDEKfn: s.BuildGetEncGlobalDEKFn(ctx, requestID), - StoreFN: s.BuildUpdateJWKDEKFn(ctx, requestID), + GetDecryptDEKfn: s.BuildGetGlobalDecDEKFn(ctx, BuildGetGlobalDEKFnOptions{ + RequestID: requestID, + }), + GetEncryptDEKfn: s.BuildGetEncGlobalDEKFn(ctx, BuildGetGlobalDEKFnOptions{ + RequestID: requestID, + }), + StoreFN: s.BuildUpdateJWKDEKFn(ctx, BuildUpdateJWKDEKFnOptions{ + RequestID: requestID, + }), }) if serviceErr != nil { logger.ErrorContext(ctx, "Failed to sign confirmation token", "serviceError", serviceErr) @@ -217,6 +235,7 @@ func (s *Services) RegisterAccount( func (s *Services) GenerateFullAuthDTO( ctx context.Context, logger *slog.Logger, + qrs *database.Queries, requestID string, accountDTO *dtos.AccountDTO, scopes []tokens.AccountScope, @@ -240,10 +259,20 @@ func (s *Services) GenerateFullAuthDTO( RequestID: requestID, KeyType: database.TokenKeyTypeAccess, TTL: s.jwt.GetAccessTTL(), + Queries: qrs, + }), + GetDecryptDEKfn: s.BuildGetGlobalDecDEKFn(ctx, BuildGetGlobalDEKFnOptions{ + RequestID: requestID, + Queries: qrs, + }), + GetEncryptDEKfn: s.BuildGetEncGlobalDEKFn(ctx, BuildGetGlobalDEKFnOptions{ + RequestID: requestID, + Queries: qrs, + }), + StoreFN: s.BuildUpdateJWKDEKFn(ctx, BuildUpdateJWKDEKFnOptions{ + RequestID: requestID, + Queries: qrs, }), - GetDecryptDEKfn: s.BuildGetGlobalDecDEKFn(ctx, requestID), - GetEncryptDEKfn: s.BuildGetEncGlobalDEKFn(ctx, requestID), - StoreFN: s.BuildUpdateJWKDEKFn(ctx, requestID), }) if serviceErr != nil { logger.ErrorContext(ctx, "Failed to sign access token", "serviceError", serviceErr) @@ -267,10 +296,20 @@ func (s *Services) GenerateFullAuthDTO( RequestID: requestID, KeyType: database.TokenKeyTypeRefresh, TTL: s.jwt.GetRefreshTTL(), + Queries: qrs, + }), + GetDecryptDEKfn: s.BuildGetGlobalDecDEKFn(ctx, BuildGetGlobalDEKFnOptions{ + RequestID: requestID, + Queries: qrs, + }), + GetEncryptDEKfn: s.BuildGetEncGlobalDEKFn(ctx, BuildGetGlobalDEKFnOptions{ + RequestID: requestID, + Queries: qrs, + }), + StoreFN: s.BuildUpdateJWKDEKFn(ctx, BuildUpdateJWKDEKFnOptions{ + RequestID: requestID, + Queries: qrs, }), - GetDecryptDEKfn: s.BuildGetGlobalDecDEKFn(ctx, requestID), - GetEncryptDEKfn: s.BuildGetEncGlobalDEKFn(ctx, requestID), - StoreFN: s.BuildUpdateJWKDEKFn(ctx, requestID), }) if serviceErr != nil { logger.ErrorContext(ctx, "Failed to sign refresh token", "serviceError", serviceErr) @@ -332,6 +371,7 @@ func (s *Services) ConfirmAccount( return s.GenerateFullAuthDTO( ctx, logger, + s.database.Queries, opts.RequestID, &accountDTO, []tokens.AccountScope{tokens.AccountScopeAdmin}, @@ -344,6 +384,7 @@ func (s *Services) generate2FAAuth( logger *slog.Logger, requestID string, accountDTO *dtos.AccountDTO, + twoFAType database.TwoFactorType, msg string, ) (dtos.AuthDTO, *exceptions.ServiceError) { twoFAToken, err := s.crypto.SignToken(ctx, crypto.SignTokenOptions{ @@ -357,16 +398,22 @@ func (s *Services) generate2FAAuth( KeyType: database.TokenKeyType2faAuthentication, TTL: s.jwt.Get2FATTL(), }), - GetDecryptDEKfn: s.BuildGetGlobalDecDEKFn(ctx, requestID), - GetEncryptDEKfn: s.BuildGetEncGlobalDEKFn(ctx, requestID), - StoreFN: s.BuildUpdateJWKDEKFn(ctx, requestID), + GetDecryptDEKfn: s.BuildGetGlobalDecDEKFn(ctx, BuildGetGlobalDEKFnOptions{ + RequestID: requestID, + }), + GetEncryptDEKfn: s.BuildGetEncGlobalDEKFn(ctx, BuildGetGlobalDEKFnOptions{ + RequestID: requestID, + }), + StoreFN: s.BuildUpdateJWKDEKFn(ctx, BuildUpdateJWKDEKFnOptions{ + RequestID: requestID, + }), }) if err != nil { logger.ErrorContext(ctx, "Failed to sign 2FA token", "error", err) return dtos.AuthDTO{}, exceptions.NewInternalServerError() } - if accountDTO.TwoFactorType == database.TwoFactorTypeEmail { + if twoFAType == database.TwoFactorTypeEmail { code, err := s.cache.AddTwoFactorCode(ctx, cache.AddTwoFactorCodeOptions{ RequestID: requestID, AccountID: accountDTO.ID(), @@ -455,13 +502,21 @@ func (s *Services) LoginAccount( } } - switch accountDTO.TwoFactorType { - case database.TwoFactorTypeEmail, database.TwoFactorTypeTotp: + default2FaConfig, serviceErr := s.getDefaultAccount2FAConfigInternal(ctx, getDefaultAccount2FAConfigInternalOptions{ + requestID: opts.RequestID, + accountPublicID: accountDTO.PublicID, + }) + if serviceErr != nil { + logger.ErrorContext(ctx, "Failed to get default account 2FA config", "serviceError", serviceErr) + return dtos.AuthDTO{}, serviceErr + } + if default2FaConfig != nil { authDTO, serviceErr := s.generate2FAAuth( ctx, logger, opts.RequestID, &accountDTO, + default2FaConfig.TwoFactorType, "Please provide two factor code", ) if serviceErr != nil { @@ -473,6 +528,7 @@ func (s *Services) LoginAccount( return s.GenerateFullAuthDTO( ctx, logger, + s.database.Queries, opts.RequestID, &accountDTO, []tokens.AccountScope{tokens.AccountScopeAdmin}, @@ -583,30 +639,65 @@ func (s *Services) VerifyAccountTotp( return verified, nil } -func (s *Services) verifyAccountTwoFactor( +func mapTokens2FAType(twoFAType tokens.TwoFAType) (database.TwoFactorType, *exceptions.ServiceError) { + switch twoFAType { + case tokens.TwoFATypeTOTP: + return database.TwoFactorTypeTotp, nil + case tokens.TwoFATypeEmail: + return database.TwoFactorTypeEmail, nil + default: + return "", exceptions.NewUnauthorizedError() + } +} + +type verifyAccount2FAInternalOptions struct { + requestID string + accountID int32 + accountPublicID uuid.UUID + accountVersion int32 + twoFAType tokens.TwoFAType + code string +} + +func (s *Services) verifyAccount2FAInternal( ctx context.Context, - requestID string, - accountDTO *dtos.AccountDTO, - code string, + opts verifyAccount2FAInternalOptions, ) *exceptions.ServiceError { - logger := s.buildLogger(requestID, authLocation, "verifyAccountTwoFactor").With( - "accountPublicId", accountDTO.PublicID, - "twoFactorType", accountDTO.TwoFactorType, + logger := s.buildLogger(opts.requestID, authLocation, "verifyDefaultAccount2FA").With( + "accountPublicId", opts.accountPublicID, ) logger.InfoContext(ctx, "Verifying account two factor...") - switch accountDTO.TwoFactorType { - case database.TwoFactorTypeNone: - logger.WarnContext(ctx, "User has two factor inactive") - return exceptions.NewForbiddenError() + twoFAType, serviceErr := mapTokens2FAType(opts.twoFAType) + if serviceErr != nil { + logger.WarnContext(ctx, "Failed to map two factor type", "serviceError", serviceErr) + return serviceErr + } + + configDTO, serviceErr := s.GetAccount2FAConfig(ctx, GetAccount2FAConfigOptions{ + RequestID: opts.requestID, + AccountPublicID: opts.accountPublicID, + TwoFAType: twoFAType, + }) + if serviceErr != nil { + if serviceErr.Code == exceptions.CodeNotFound { + logger.WarnContext(ctx, "Account 2FA config not found", "serviceError", serviceErr) + return exceptions.NewForbiddenError() + } + + logger.ErrorContext(ctx, "Failed to get account 2FA config", "serviceError", serviceErr) + return serviceErr + } + + switch configDTO.TwoFactorType { case database.TwoFactorTypeTotp: ok, serviceErr := s.VerifyAccountTotp(ctx, VerifyAccountTotpOptions{ - RequestID: requestID, - ID: accountDTO.ID(), - Code: code, + RequestID: opts.requestID, + ID: opts.accountID, + Code: opts.code, }) if serviceErr != nil { - logger.ErrorContext(ctx, "Failed to verify TOTP Code", "error", serviceErr) + logger.ErrorContext(ctx, "Failed to verify TOTP Code", "serviceError", serviceErr) return serviceErr } if !ok { @@ -615,9 +706,9 @@ func (s *Services) verifyAccountTwoFactor( } case database.TwoFactorTypeEmail: ok, err := s.cache.VerifyTwoFactorCode(ctx, cache.VerifyTwoFactorCodeOptions{ - RequestID: requestID, - AccountID: accountDTO.ID(), - Code: code, + RequestID: opts.requestID, + AccountID: opts.accountID, + Code: opts.code, }) if err != nil { logger.ErrorContext(ctx, "Error verifying Code", "error", err) @@ -627,55 +718,59 @@ func (s *Services) verifyAccountTwoFactor( logger.WarnContext(ctx, "Failed to verify Code") return exceptions.NewUnauthorizedError() } + default: + logger.WarnContext(ctx, "Invalid two factor type", "twoFactorType", configDTO.TwoFactorType) + return exceptions.NewUnauthorizedError() } + logger.InfoContext(ctx, "Account two factor verified successfully") return nil } -type TwoFactorLoginAccountOptions struct { - RequestID string - PublicID uuid.UUID - Version int32 - Code string +type VerifyAccount2FAOptions struct { + RequestID string + AccountPublicID uuid.UUID + AccountVersion int32 + TwoFAType tokens.TwoFAType + Code string } -func (s *Services) TwoFactorLoginAccount( +func (s *Services) VerifyAccount2FA( ctx context.Context, - opts TwoFactorLoginAccountOptions, + opts VerifyAccount2FAOptions, ) (dtos.AuthDTO, *exceptions.ServiceError) { - logger := s.buildLogger(opts.RequestID, authLocation, "TwoFactorLoginAccount") - logger.InfoContext(ctx, "2FA logging in account...") + logger := s.buildLogger(opts.RequestID, authLocation, "VerifyAccount2FA").With( + "accountPublicId", opts.AccountPublicID, + "twoFactorType", opts.TwoFAType, + ) + logger.InfoContext(ctx, "Verifying account two factor...") - accountDTO, serviceErr := s.GetAccountByPublicID(ctx, GetAccountByPublicIDOptions{ + accountDTO, serviceErr := s.GetAccountByPublicIDAndVersion(ctx, GetAccountByPublicIDAndVersionOptions{ RequestID: opts.RequestID, - PublicID: opts.PublicID, + PublicID: opts.AccountPublicID, + Version: opts.AccountVersion, }) if serviceErr != nil { - if serviceErr.Code != exceptions.CodeNotFound { - return dtos.AuthDTO{}, serviceErr - } - - logger.WarnContext(ctx, "Account was not found", "error", serviceErr) - return dtos.AuthDTO{}, exceptions.NewUnauthorizedError() - } - - logger = logger.With("accountId", accountDTO.ID()) - accountVersion := accountDTO.Version() - if accountVersion != opts.Version { - logger.WarnContext(ctx, "Account versions do not match", - "accessTokenVersion", opts.Version, - "accountVersion", accountVersion, - ) - return dtos.AuthDTO{}, exceptions.NewUnauthorizedError() + logger.ErrorContext(ctx, "Failed to get account ID by public ID and version", "serviceError", serviceErr) + return dtos.AuthDTO{}, serviceErr } - if serviceErr := s.verifyAccountTwoFactor(ctx, opts.RequestID, &accountDTO, opts.Code); serviceErr != nil { + if serviceErr := s.verifyAccount2FAInternal(ctx, verifyAccount2FAInternalOptions{ + requestID: opts.RequestID, + accountID: accountDTO.ID(), + accountPublicID: opts.AccountPublicID, + accountVersion: opts.AccountVersion, + twoFAType: opts.TwoFAType, + code: opts.Code, + }); serviceErr != nil { + logger.ErrorContext(ctx, "Failed to verify account two factor", "serviceError", serviceErr) return dtos.AuthDTO{}, serviceErr } return s.GenerateFullAuthDTO( ctx, logger, + s.database.Queries, opts.RequestID, &accountDTO, []tokens.AccountScope{tokens.AccountScopeAdmin}, @@ -812,6 +907,7 @@ func (s *Services) RefreshTokenAccount( return s.GenerateFullAuthDTO( ctx, logger, + s.database.Queries, opts.RequestID, &accountDTO, data.Scopes, @@ -853,9 +949,15 @@ func (s *Services) ForgotAccountPassword( KeyType: database.TokenKeyTypePasswordReset, TTL: s.jwt.GetResetTTL(), }), - GetDecryptDEKfn: s.BuildGetGlobalDecDEKFn(ctx, opts.RequestID), - GetEncryptDEKfn: s.BuildGetEncGlobalDEKFn(ctx, opts.RequestID), - StoreFN: s.BuildUpdateJWKDEKFn(ctx, opts.RequestID), + GetDecryptDEKfn: s.BuildGetGlobalDecDEKFn(ctx, BuildGetGlobalDEKFnOptions{ + RequestID: opts.RequestID, + }), + GetEncryptDEKfn: s.BuildGetEncGlobalDEKFn(ctx, BuildGetGlobalDEKFnOptions{ + RequestID: opts.RequestID, + }), + StoreFN: s.BuildUpdateJWKDEKFn(ctx, BuildUpdateJWKDEKFnOptions{ + RequestID: opts.RequestID, + }), }) if err != nil { logger.ErrorContext(ctx, "Failed to generate rest token", "error", err) @@ -1042,8 +1144,12 @@ func (s *Services) RecoverAccount( logger.ErrorContext(ctx, "Failed to get account by public ID and version", "error", serviceErr) return dtos.AuthDTO{}, serviceErr } - if accountDTO.TwoFactorType != database.TwoFactorTypeTotp { - logger.WarnContext(ctx, "Account does not have TOTP enabled") + if _, serviceErr := s.GetAccount2FAConfig(ctx, GetAccount2FAConfigOptions{ + RequestID: opts.RequestID, + AccountPublicID: accountDTO.PublicID, + TwoFAType: database.TwoFactorTypeTotp, + }); serviceErr != nil { + logger.WarnContext(ctx, "Account does not have TOTP enabled", "serviceError", serviceErr) return dtos.AuthDTO{}, exceptions.NewUnauthorizedError() } @@ -1082,9 +1188,15 @@ func (s *Services) RecoverAccount( KeyType: database.TokenKeyType2faAuthentication, TTL: s.jwt.Get2FATTL(), }), - GetDecryptDEKfn: s.BuildGetGlobalDecDEKFn(ctx, opts.RequestID), - GetEncryptDEKfn: s.BuildGetEncGlobalDEKFn(ctx, opts.RequestID), - StoreFN: s.BuildUpdateJWKDEKFn(ctx, opts.RequestID), + GetDecryptDEKfn: s.BuildGetGlobalDecDEKFn(ctx, BuildGetGlobalDEKFnOptions{ + RequestID: opts.RequestID, + }), + GetEncryptDEKfn: s.BuildGetEncGlobalDEKFn(ctx, BuildGetGlobalDEKFnOptions{ + RequestID: opts.RequestID, + }), + StoreFN: s.BuildUpdateJWKDEKFn(ctx, BuildUpdateJWKDEKFnOptions{ + RequestID: opts.RequestID, + }), }) if serviceErr != nil { logger.ErrorContext(ctx, "Failed to sign 2FA token", "serviceError", serviceErr) @@ -1167,473 +1279,3 @@ func (s *Services) GetAccountAuthProvider( logger.InfoContext(ctx, "Retrieved account auth provider successfully") return dtos.MapAccountAuthProviderToDTO(&authProvider), nil } - -type buildStoreAccountTOTPOptions struct { - requestID string - accountID int32 -} - -func (s *Services) buildStoreAccountTOTP( - ctx context.Context, - opts buildStoreAccountTOTPOptions, -) crypto.StoreTOTP { - logger := s.buildLogger(opts.requestID, authLocation, "buildStoreAccountTOTP").With( - "AccountID", opts.accountID, - ) - logger.InfoContext(ctx, "Building store account TOTP function...") - - return func(dekKID, encSecret string, hashedCode []byte, url string) *exceptions.ServiceError { - var serviceErr *exceptions.ServiceError - qrs, txn, err := s.database.BeginTx(ctx) - if err != nil { - logger.ErrorContext(ctx, "Failed to start transaction", "error", err) - return exceptions.FromDBError(err) - } - defer func() { - logger.DebugContext(ctx, "Finalizing transaction") - s.database.FinalizeTx(ctx, txn, err, serviceErr) - }() - - id, err := qrs.CreateTotp(ctx, database.CreateTotpParams{ - DekKid: dekKID, - Url: url, - Secret: encSecret, - RecoveryCodes: hashedCode, - Usage: database.TotpUsageAccount, - AccountID: opts.accountID, - }) - if err != nil { - logger.ErrorContext(ctx, "Failed to create TOTP", "error", err) - serviceErr = exceptions.FromDBError(err) - return serviceErr - } - - if err = qrs.CreateAccountTotp(ctx, database.CreateAccountTotpParams{ - AccountID: opts.accountID, - TotpID: id, - }); err != nil { - logger.ErrorContext(ctx, "Failed to create account recovery keys", "error", err) - serviceErr = exceptions.FromDBError(err) - return serviceErr - } - - if err = qrs.UpdateAccountTwoFactorType(ctx, database.UpdateAccountTwoFactorTypeParams{ - TwoFactorType: database.TwoFactorTypeTotp, - ID: opts.accountID, - }); err != nil { - logger.ErrorContext(ctx, "Failed to update account 2FA", "error", err) - serviceErr = exceptions.FromDBError(err) - return serviceErr - } - - return nil - } -} - -type updateAccount2FAOptions struct { - requestID string - id int32 - email string - prev2FAType database.TwoFactorType -} - -func (s *Services) updateAccountTOTP2FA( - ctx context.Context, - opts updateAccount2FAOptions, -) (dtos.AuthDTO, *exceptions.ServiceError) { - logger := s.buildLogger(opts.requestID, authLocation, "updateAccountTOTP2FA").With( - "id", opts.id, - ) - logger.InfoContext(ctx, "Update account TOTP 2FA...") - - totpKey, err := s.crypto.GenerateTotpKey(ctx, crypto.GenerateTotpKeyOptions{ - RequestID: opts.requestID, - Email: opts.email, - GetDEKfn: s.BuildGetEncAccountDEKfn(ctx, BuildGetEncAccountDEKOptions{ - RequestID: opts.requestID, - AccountID: opts.id, - }), - StoreTOTPfn: s.buildStoreAccountTOTP(ctx, buildStoreAccountTOTPOptions{ - requestID: opts.requestID, - accountID: opts.id, - }), - }) - if err != nil { - logger.ErrorContext(ctx, "Failed to generate TOTP", "error", err) - return dtos.AuthDTO{}, exceptions.NewInternalServerError() - } - - accountDTO, serviceErr := s.GetAccountByID(ctx, GetAccountByIDOptions{ - RequestID: opts.requestID, - ID: opts.id, - }) - if serviceErr != nil { - logger.ErrorContext(ctx, "Failed to get account by ID", "error", serviceErr) - return dtos.AuthDTO{}, serviceErr - } - - signedToken, serviceErr := s.crypto.SignToken(ctx, crypto.SignTokenOptions{ - RequestID: opts.requestID, - Token: s.jwt.Create2FAToken(tokens.Account2FATokenOptions{ - PublicID: accountDTO.PublicID, - Version: accountDTO.Version(), - }), - GetJWKfn: s.BuildGetGlobalEncryptedJWKFn(ctx, BuildEncryptedJWKFnOptions{ - RequestID: opts.requestID, - KeyType: database.TokenKeyType2faAuthentication, - TTL: s.jwt.Get2FATTL(), - }), - GetDecryptDEKfn: s.BuildGetGlobalDecDEKFn(ctx, opts.requestID), - GetEncryptDEKfn: s.BuildGetEncGlobalDEKFn(ctx, opts.requestID), - StoreFN: s.BuildUpdateJWKDEKFn(ctx, opts.requestID), - }) - if serviceErr != nil { - logger.ErrorContext(ctx, "Failed to sign 2FA token", "serviceError", serviceErr) - return dtos.AuthDTO{}, serviceErr - } - - return dtos.NewAuthDTOWithData( - signedToken, - "Please scan QR Code with your authentication app", - map[string]string{ - "image": totpKey.Img(), - "recovery_keys": totpKey.Codes(), - }, - s.jwt.Get2FATTL(), - ), nil -} - -func (s *Services) updateAccountEmail2FA( - ctx context.Context, - opts updateAccount2FAOptions, -) (dtos.AuthDTO, *exceptions.ServiceError) { - logger := s.buildLogger(opts.requestID, authLocation, "updateAccountEmail2FA").With( - "id", opts.id, - ) - logger.InfoContext(ctx, "Update account email 2FA...") - - code, err := s.cache.AddTwoFactorCode(ctx, cache.AddTwoFactorCodeOptions{ - RequestID: opts.requestID, - AccountID: opts.id, - }) - if err != nil { - logger.ErrorContext(ctx, "Failed to generate two factor Code", "error", err) - return dtos.AuthDTO{}, exceptions.NewInternalServerError() - } - - if opts.prev2FAType == database.TwoFactorTypeTotp { - var serviceErr *exceptions.ServiceError - qrs, txn, err := s.database.BeginTx(ctx) - if err != nil { - logger.ErrorContext(ctx, "Failed to start transaction", "error", err) - return dtos.AuthDTO{}, exceptions.FromDBError(err) - } - defer func() { - logger.DebugContext(ctx, "Finalizing transaction") - s.database.FinalizeTx(ctx, txn, err, serviceErr) - }() - - if err = qrs.UpdateAccountTwoFactorType(ctx, database.UpdateAccountTwoFactorTypeParams{ - TwoFactorType: database.TwoFactorTypeEmail, - ID: opts.id, - }); err != nil { - logger.ErrorContext(ctx, "Failed to enable 2FA email", "error", err) - serviceErr = exceptions.FromDBError(err) - return dtos.AuthDTO{}, serviceErr - } - - if err := qrs.DeleteAccountRecoveryKeys(ctx, opts.id); err != nil { - logger.ErrorContext(ctx, "Failed to delete recovery keys", "error", err) - serviceErr = exceptions.FromDBError(err) - return dtos.AuthDTO{}, serviceErr - } - } else { - if err = s.database.UpdateAccountTwoFactorType(ctx, database.UpdateAccountTwoFactorTypeParams{ - TwoFactorType: database.TwoFactorTypeEmail, - ID: opts.id, - }); err != nil { - logger.ErrorContext(ctx, "Failed to enable 2FA email", "error", err) - return dtos.AuthDTO{}, exceptions.FromDBError(err) - } - } - - accountDTO, serviceErr := s.GetAccountByID(ctx, GetAccountByIDOptions{ - RequestID: opts.requestID, - ID: opts.id, - }) - if serviceErr != nil { - logger.ErrorContext(ctx, "Failed to get account by ID", "serviceError", serviceErr) - return dtos.AuthDTO{}, serviceErr - } - - if err := s.mail.Publish2FAEmail(ctx, mailer.TwoFactorEmailOptions{ - RequestID: opts.requestID, - Email: accountDTO.Email, - Name: fmt.Sprintf("%s %s", accountDTO.GivenName, accountDTO.FamilyName), - Code: code, - }); err != nil { - logger.ErrorContext(ctx, "Failed to publish two factor email", "error", err) - return dtos.AuthDTO{}, exceptions.NewInternalServerError() - } - - signedToken, serviceErr := s.crypto.SignToken(ctx, crypto.SignTokenOptions{ - RequestID: opts.requestID, - Token: s.jwt.Create2FAToken(tokens.Account2FATokenOptions{ - PublicID: accountDTO.PublicID, - Version: accountDTO.Version(), - }), - GetJWKfn: s.BuildGetGlobalEncryptedJWKFn(ctx, BuildEncryptedJWKFnOptions{ - RequestID: opts.requestID, - KeyType: database.TokenKeyType2faAuthentication, - TTL: s.jwt.Get2FATTL(), - }), - GetDecryptDEKfn: s.BuildGetGlobalDecDEKFn(ctx, opts.requestID), - GetEncryptDEKfn: s.BuildGetEncGlobalDEKFn(ctx, opts.requestID), - StoreFN: s.BuildUpdateJWKDEKFn(ctx, opts.requestID), - }) - if serviceErr != nil { - logger.ErrorContext(ctx, "Failed to sign 2FA token", "serviceError", serviceErr) - return dtos.AuthDTO{}, exceptions.NewInternalServerError() - } - - return dtos.NewTempAuthDTO(signedToken, "Please provide email two factor code", s.jwt.Get2FATTL()), nil -} - -func (s *Services) disableAccount2FA( - ctx context.Context, - opts updateAccount2FAOptions, -) (dtos.AuthDTO, *exceptions.ServiceError) { - logger := s.buildLogger(opts.requestID, authLocation, "disableAccount2FA").With( - "id", opts.id, - ) - logger.InfoContext(ctx, "Update account TOTP 2FA...") - - if opts.prev2FAType == database.TwoFactorTypeTotp { - var serviceErr *exceptions.ServiceError - qrs, txn, err := s.database.BeginTx(ctx) - if err != nil { - logger.ErrorContext(ctx, "Failed to start transaction", "error", err) - return dtos.AuthDTO{}, exceptions.FromDBError(err) - } - defer func() { - logger.DebugContext(ctx, "Finalizing transaction") - s.database.FinalizeTx(ctx, txn, err, serviceErr) - }() - - if err = qrs.UpdateAccountTwoFactorType(ctx, database.UpdateAccountTwoFactorTypeParams{ - TwoFactorType: database.TwoFactorTypeNone, - ID: opts.id, - }); err != nil { - logger.ErrorContext(ctx, "Failed to disable 2FA", "error", err) - serviceErr = exceptions.FromDBError(err) - return dtos.AuthDTO{}, serviceErr - } - - if err := qrs.DeleteAccountRecoveryKeys(ctx, opts.id); err != nil { - logger.ErrorContext(ctx, "Failed to delete recovery keys", "error", err) - serviceErr = exceptions.FromDBError(err) - return dtos.AuthDTO{}, serviceErr - } - } else { - if err := s.database.UpdateAccountTwoFactorType(ctx, database.UpdateAccountTwoFactorTypeParams{ - TwoFactorType: database.TwoFactorTypeNone, - ID: opts.id, - }); err != nil { - logger.ErrorContext(ctx, "Failed to disable 2FA", "error", err) - return dtos.AuthDTO{}, exceptions.FromDBError(err) - } - } - - accountDTO, serviceErr := s.GetAccountByID(ctx, GetAccountByIDOptions{ - RequestID: opts.requestID, - ID: opts.id, - }) - if serviceErr != nil { - logger.ErrorContext(ctx, "Failed to get account by ID", "serviceError", serviceErr) - return dtos.AuthDTO{}, serviceErr - } - - return s.GenerateFullAuthDTO( - ctx, - logger, - opts.requestID, - &accountDTO, - []tokens.AccountScope{tokens.AccountScopeAdmin}, - "Successfully disabled oauth", - ) -} - -type UpdateAccount2FAOptions struct { - RequestID string - PublicID uuid.UUID - Version int32 - TwoFactorType string - Password string -} - -func (s *Services) UpdateAccount2FA( - ctx context.Context, - opts UpdateAccount2FAOptions, -) (dtos.AuthDTO, *exceptions.ServiceError) { - logger := s.buildLogger(opts.RequestID, authLocation, "UpdateAccount2FA").With( - "publicID", opts.PublicID, - "twoFactorType", opts.TwoFactorType, - ) - logger.InfoContext(ctx, "Updating account 2FA...") - - twoFactorType, serviceErr := mapTwoFactorType(opts.TwoFactorType) - if serviceErr != nil { - logger.WarnContext(ctx, "Invalid two factor type", "serviceError", serviceErr) - return dtos.AuthDTO{}, serviceErr - } - - accountDTO, serviceErr := s.GetAccountByPublicIDAndVersion(ctx, GetAccountByPublicIDAndVersionOptions{ - RequestID: opts.RequestID, - PublicID: opts.PublicID, - Version: opts.Version, - }) - if serviceErr != nil { - return dtos.AuthDTO{}, serviceErr - } - - count, err := s.database.CountAccountAuthProvidersByEmailAndProvider( - ctx, - database.CountAccountAuthProvidersByEmailAndProviderParams{ - Email: accountDTO.Email, - Provider: database.AuthProviderLocal, - }, - ) - if err != nil { - logger.ErrorContext(ctx, "Failed to count auth providers", "error", err) - return dtos.AuthDTO{}, exceptions.NewInternalServerError() - } - if count > 0 { - if opts.Password == "" { - logger.WarnContext(ctx, "Password is required for email auth Provider") - return dtos.AuthDTO{}, exceptions.NewValidationError("password is required") - } - - ok, err := utils.Argon2CompareHash(opts.Password, accountDTO.Password()) - if err != nil { - logger.ErrorContext(ctx, "Failed to compare password hashes", "error", err) - return dtos.AuthDTO{}, exceptions.NewInternalServerError() - } - if !ok { - logger.WarnContext(ctx, "Passwords do not match") - return dtos.AuthDTO{}, exceptions.NewValidationError("Invalid password") - } - } - - if accountDTO.TwoFactorType == twoFactorType { - logger.WarnContext(ctx, "Account already uses given 2FA type", "twoFactorType", twoFactorType) - return dtos.AuthDTO{}, exceptions.NewValidationError("Account already uses given 2FA type") - } - - updateOpts := updateAccount2FAOptions{ - requestID: opts.RequestID, - id: accountDTO.ID(), - email: accountDTO.Email, - prev2FAType: accountDTO.TwoFactorType, - } - if accountDTO.TwoFactorType == database.TwoFactorTypeNone { - switch twoFactorType { - case database.TwoFactorTypeTotp: - logger.InfoContext(ctx, "Enabling TOTP 2FA") - return s.updateAccountTOTP2FA(ctx, updateOpts) - case database.TwoFactorTypeEmail: - logger.InfoContext(ctx, "Enabling email 2FA") - return s.updateAccountEmail2FA(ctx, updateOpts) - default: - logger.WarnContext(ctx, "Unknown two factor type, it must be 'totp' or 'email'") - return dtos.AuthDTO{}, exceptions.NewForbiddenError() - } - } - - if err := s.cache.SaveTwoFactorUpdateRequest(ctx, cache.SaveTwoFactorUpdateRequestOptions{ - RequestID: opts.RequestID, - PrefixType: cache.SensitiveRequestAccountPrefix, - PublicID: accountDTO.PublicID, - TwoFactorType: database.TwoFactorType(opts.TwoFactorType), - DurationSeconds: s.jwt.Get2FATTL(), - }); err != nil { - logger.ErrorContext(ctx, "Failed to save two-factor update request", "error", err) - return dtos.AuthDTO{}, exceptions.NewInternalServerError() - } - - authDTO, serviceErr := s.generate2FAAuth( - ctx, - logger, - opts.RequestID, - &accountDTO, - "Please provide two factor code to confirm two factor update", - ) - if serviceErr != nil { - return dtos.AuthDTO{}, serviceErr - } - - return authDTO, nil -} - -type ConfirmUpdateAccount2FAUpdateOptions struct { - RequestID string - PublicID uuid.UUID - Version int32 - Code string -} - -func (s *Services) ConfirmUpdateAccount2FAUpdate( - ctx context.Context, - opts ConfirmUpdateAccount2FAUpdateOptions, -) (dtos.AuthDTO, *exceptions.ServiceError) { - logger := s.buildLogger(opts.RequestID, authLocation, "ConfirmUpdateAccount2FAUpdate").With( - "publicID", opts.PublicID, - ) - logger.InfoContext(ctx, "Confirming account 2FA update...") - - twoFactorType, err := s.cache.GetTwoFactorUpdateRequest(ctx, cache.GetTwoFactorUpdateRequestOptions{ - RequestID: opts.RequestID, - PrefixType: cache.SensitiveRequestAccountPrefix, - PublicID: opts.PublicID, - }) - if err != nil { - logger.ErrorContext(ctx, "Failed to get two-factor update request", "error", err) - return dtos.AuthDTO{}, exceptions.NewInternalServerError() - } - if twoFactorType == "" { - logger.WarnContext(ctx, "Two-factor update request not found") - return dtos.AuthDTO{}, exceptions.NewUnauthorizedError() - } - - accountDTO, serviceErr := s.GetAccountByPublicIDAndVersion(ctx, GetAccountByPublicIDAndVersionOptions{ - RequestID: opts.RequestID, - PublicID: opts.PublicID, - Version: opts.Version, - }) - if serviceErr != nil { - return dtos.AuthDTO{}, serviceErr - } - - if serviceErr := s.verifyAccountTwoFactor(ctx, opts.RequestID, &accountDTO, opts.Code); serviceErr != nil { - return dtos.AuthDTO{}, serviceErr - } - - updateOpts := updateAccount2FAOptions{ - requestID: opts.RequestID, - id: accountDTO.ID(), - email: accountDTO.Email, - prev2FAType: accountDTO.TwoFactorType, - } - switch twoFactorType { - case database.TwoFactorTypeTotp: - logger.InfoContext(ctx, "Enabling TOTP 2FA") - return s.updateAccountTOTP2FA(ctx, updateOpts) - case database.TwoFactorTypeEmail: - logger.InfoContext(ctx, "Enabling email 2FA") - return s.updateAccountEmail2FA(ctx, updateOpts) - case database.TwoFactorTypeNone: - logger.InfoContext(ctx, "Disabling 2FA") - return s.disableAccount2FA(ctx, updateOpts) - default: - return dtos.AuthDTO{}, exceptions.NewForbiddenError() - } -} diff --git a/idp/internal/services/deks.go b/idp/internal/services/deks.go index 93574d6..a5411d3 100644 --- a/idp/internal/services/deks.go +++ b/idp/internal/services/deks.go @@ -53,11 +53,13 @@ func (s *Services) buildStoreGlobalDEKfn( ctx context.Context, requestID string, data map[string]string, + queries *database.Queries, ) crypto.StoreDEK { logger := s.buildLogger(requestID, deksLocation, "storeGlobalDEK") logger.InfoContext(ctx, "Building store function for global DEK...") return func(dekID string, encryptedDEK string, kekID uuid.UUID) (int32, *exceptions.ServiceError) { - dekEnt, err := s.database.CreateDataEncryptionKey(ctx, database.CreateDataEncryptionKeyParams{ + qrs := s.mapQueries(queries) + dekEnt, err := qrs.CreateDataEncryptionKey(ctx, database.CreateDataEncryptionKeyParams{ Kid: dekID, KekKid: kekID, Usage: database.DekUsageGlobal, @@ -87,15 +89,20 @@ func (s *Services) buildStoreGlobalDEKfn( } } +type BuildGetGlobalDEKFnOptions struct { + RequestID string + Queries *database.Queries +} + func (s *Services) BuildGetEncGlobalDEKFn( ctx context.Context, - requestID string, + opts BuildGetGlobalDEKFnOptions, ) crypto.GetDEKtoEncrypt { - logger := s.buildLogger(requestID, deksLocation, "BuildGetEncGlobalDEKFn") + logger := s.buildLogger(opts.RequestID, deksLocation, "BuildGetEncGlobalDEKFn") logger.InfoContext(ctx, "Build GetDEKtoEncrypt function...") return func() (crypto.DEKID, crypto.EncryptedDEK, uuid.UUID, *exceptions.ServiceError) { kid, dek, kekKID, ok, err := s.cache.GetEncDEK(ctx, cache.GetEncDEKOptions{ - RequestID: requestID, + RequestID: opts.RequestID, Suffix: "global", }) if err != nil { @@ -107,7 +114,8 @@ func (s *Services) BuildGetEncGlobalDEKFn( return kid, dek, kekKID, nil } - dekEnt, err := s.database.FindValidGlobalDataEncryptionKey(ctx, time.Now().Add(-2*time.Hour)) + qrs := s.mapQueries(opts.Queries) + dekEnt, err := qrs.FindValidGlobalDataEncryptionKey(ctx, time.Now().Add(-2*time.Hour)) if err != nil { serviceErr := exceptions.FromDBError(err) if serviceErr.Code != exceptions.CodeNotFound { @@ -115,7 +123,7 @@ func (s *Services) BuildGetEncGlobalDEKFn( return "", "", uuid.Nil, serviceErr } - kekKID, serviceErr := s.GetOrCreateGlobalKEK(ctx, requestID) + kekKID, serviceErr := s.GetOrCreateGlobalKEK(ctx, opts.RequestID) if serviceErr != nil { logger.ErrorContext(ctx, "Failed to get or create global KEK", "error", serviceErr) return "", "", uuid.Nil, serviceErr @@ -123,9 +131,9 @@ func (s *Services) BuildGetEncGlobalDEKFn( data := make(map[string]string) if serviceErr := s.createDEK(ctx, createDEKOptions{ - requestID: requestID, + requestID: opts.RequestID, kekKID: kekKID, - storeFN: s.buildStoreGlobalDEKfn(ctx, requestID, data), + storeFN: s.buildStoreGlobalDEKfn(ctx, opts.RequestID, data, qrs), }); serviceErr != nil { logger.ErrorContext(ctx, "Failed to create global DEK", "serviceError", serviceErr) return "", "", uuid.Nil, serviceErr @@ -153,14 +161,14 @@ func (s *Services) BuildGetEncGlobalDEKFn( func (s *Services) BuildGetGlobalDecDEKFn( ctx context.Context, - requestID string, + opts BuildGetGlobalDEKFnOptions, ) crypto.GetDEKtoDecrypt { - logger := s.buildLogger(requestID, deksLocation, "BuildGetGlobalDecDEKFn") + logger := s.buildLogger(opts.RequestID, deksLocation, "BuildGetGlobalDecDEKFn") logger.InfoContext(ctx, "Building GetDEKtoDecrypt function for global DEK...") return func(kid string) (crypto.EncryptedDEK, crypto.KEKID, crypto.IsExpiredDEK, *exceptions.ServiceError) { dek, kekKID, expiresAt, ok, err := s.cache.GetDecDEK(ctx, cache.GetDecDEKOptions{ - RequestID: requestID, + RequestID: opts.RequestID, KID: kid, Prefix: "global", }) @@ -175,14 +183,15 @@ func (s *Services) BuildGetGlobalDecDEKFn( return dek, kekKID, now.After(expiresAt), nil } - dekEnt, err := s.database.FindDataEncryptionKeyByKID(ctx, kid) + qrs := s.mapQueries(opts.Queries) + dekEnt, err := qrs.FindDataEncryptionKeyByKID(ctx, kid) if err != nil { logger.ErrorContext(ctx, "Failed to get DEK", "error", err) return "", uuid.Nil, false, exceptions.FromDBError(err) } if err := s.cache.SaveDecDEK(ctx, cache.SaveDecDEKOptions{ - RequestID: requestID, + RequestID: opts.RequestID, DEK: dekEnt.Dek, KID: dekEnt.Kid, KEKid: dekEnt.KekKid, @@ -336,6 +345,7 @@ func (s *Services) BuildGetEncAccountDEKfn( requestID: opts.RequestID, accountID: opts.AccountID, data: data, + queries: qrs, }), }, ); serviceErr != nil { @@ -408,8 +418,9 @@ func (s *Services) BuildGetDecAccountDEKFn( return dek, kekKID, now.After(expiresAt), nil } + qrs := s.mapQueries(opts.Queries) logger.InfoContext(ctx, "DEK not found in cache, checking database...") - dekEnt, err := s.mapQueries(opts.Queries).FindAccountDataEncryptionKeyByAccountIDAndKID( + dekEnt, err := qrs.FindAccountDataEncryptionKeyByAccountIDAndKID( ctx, database.FindAccountDataEncryptionKeyByAccountIDAndKIDParams{ AccountID: opts.AccountID, diff --git a/idp/internal/services/dtos/account.go b/idp/internal/services/dtos/account.go index 3ebd99b..f2d0dcc 100644 --- a/idp/internal/services/dtos/account.go +++ b/idp/internal/services/dtos/account.go @@ -13,12 +13,11 @@ import ( ) type AccountDTO struct { - PublicID uuid.UUID `json:"id"` - GivenName string `json:"given_name"` - FamilyName string `json:"family_name"` - Email string `json:"email"` - Username string `json:"username"` - TwoFactorType database.TwoFactorType `json:"two_factor_type"` + PublicID uuid.UUID `json:"id"` + GivenName string `json:"given_name"` + FamilyName string `json:"family_name"` + Email string `json:"email"` + Username string `json:"username"` id int32 version int32 @@ -50,7 +49,6 @@ func MapAccountToDTO(account *database.Account) AccountDTO { GivenName: account.GivenName, FamilyName: account.FamilyName, Email: account.Email, - TwoFactorType: account.TwoFactorType, Username: account.Username, emailVerified: account.EmailVerified, password: account.Password.String, diff --git a/idp/internal/services/dtos/account_2fa_config.go b/idp/internal/services/dtos/account_2fa_config.go new file mode 100644 index 0000000..15ae54f --- /dev/null +++ b/idp/internal/services/dtos/account_2fa_config.go @@ -0,0 +1,94 @@ +// Copyright (c) 2025 Afonso Barracha +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +package dtos + +import ( + "github.com/tugascript/devlogs/idp/internal/providers/database" +) + +type Account2FATOTPConfigDTO struct { + AccessToken string `json:"access_token"` + Image string `json:"image"` + RecoveryKeys string `json:"recovery_keys"` + ExpiresIn int64 `json:"expires_in"` + Message string `json:"message"` +} + +type Account2FACodeConfigDTO struct { + AccessToken string `json:"access_token"` + ExpiresIn int64 `json:"expires_in"` + Message string `json:"message"` +} + +type Account2FAConfigDTO struct { + id int32 + + TwoFactorType database.TwoFactorType `json:"two_factor_type"` + IsDefault bool `json:"is_default"` + CreatedAt int64 `json:"created_at"` + + // TOTP 2FA + *Account2FATOTPConfigDTO + + // Email & TOTP Deletion 2FA + *Account2FACodeConfigDTO +} + +func (a *Account2FAConfigDTO) ID() int32 { + return a.id +} + +func MapAccount2FAConfigToDTO(account2FAConfig *database.Account2faConfig) Account2FAConfigDTO { + return Account2FAConfigDTO{ + id: account2FAConfig.ID, + TwoFactorType: account2FAConfig.TwoFactorType, + IsDefault: account2FAConfig.IsDefault, + CreatedAt: account2FAConfig.CreatedAt.Unix(), + } +} + +func MapAccount2FAConfigTOTPToDTO( + account2FAConfig *database.Account2faConfig, + accessToken string, + image string, + recoveryKeys string, + expiresIn int64, + message string, +) Account2FAConfigDTO { + return Account2FAConfigDTO{ + id: account2FAConfig.ID, + TwoFactorType: account2FAConfig.TwoFactorType, + IsDefault: account2FAConfig.IsDefault, + CreatedAt: account2FAConfig.CreatedAt.Unix(), + Account2FATOTPConfigDTO: &Account2FATOTPConfigDTO{ + AccessToken: accessToken, + Image: image, + RecoveryKeys: recoveryKeys, + ExpiresIn: expiresIn, + Message: message, + }, + } +} + +func MapAccount2FAConfigCodeToDTO( + account2FAConfig *database.Account2faConfig, + accessToken string, + expiresIn int64, + message string, +) Account2FAConfigDTO { + return Account2FAConfigDTO{ + id: account2FAConfig.ID, + TwoFactorType: account2FAConfig.TwoFactorType, + IsDefault: account2FAConfig.IsDefault, + CreatedAt: account2FAConfig.CreatedAt.Unix(), + Account2FACodeConfigDTO: &Account2FACodeConfigDTO{ + AccessToken: accessToken, + ExpiresIn: expiresIn, + Message: message, + }, + } +} diff --git a/idp/internal/services/dtos/user.go b/idp/internal/services/dtos/user.go index d8b69d6..22ce422 100644 --- a/idp/internal/services/dtos/user.go +++ b/idp/internal/services/dtos/user.go @@ -16,10 +16,9 @@ import ( ) type UserDTO struct { - PublicID uuid.UUID `json:"id"` - Email string `json:"email"` - Username string `json:"username"` - TwoFactorType database.TwoFactorType `json:"two_factor_type"` + PublicID uuid.UUID `json:"id"` + Email string `json:"email"` + Username string `json:"username"` DataDTO id int32 @@ -55,7 +54,6 @@ func MapUserToDTO(user *database.User) (UserDTO, *exceptions.ServiceError) { PublicID: user.PublicID, Email: user.Email, Username: user.Username, - TwoFactorType: user.TwoFactorType, DataDTO: data, version: user.Version, emailVerified: user.EmailVerified, diff --git a/idp/internal/services/helpers.go b/idp/internal/services/helpers.go index 94ca3a9..1f82d3e 100644 --- a/idp/internal/services/helpers.go +++ b/idp/internal/services/helpers.go @@ -198,20 +198,6 @@ func mapDomain(baseURI string, domain string) (string, *exceptions.ServiceError) return host, nil } -func mapTwoFactorType(twoFactorType string) (database.TwoFactorType, *exceptions.ServiceError) { - if len(twoFactorType) < 4 { - return "", exceptions.NewValidationError("invalid two factor type") - } - - dbTwoFactorType := database.TwoFactorType(twoFactorType) - switch dbTwoFactorType { - case database.TwoFactorTypeNone, database.TwoFactorTypeEmail, database.TwoFactorTypeTotp: - return dbTwoFactorType, nil - default: - return "", exceptions.NewValidationError("invalid two factor type") - } -} - func mapCCSecretStorageMode(authMethod string) database.SecretStorageMode { if authMethod == AuthMethodClientSecretJWT { return database.SecretStorageModeEncrypted diff --git a/idp/internal/services/jwks.go b/idp/internal/services/jwks.go index c45c489..0973312 100644 --- a/idp/internal/services/jwks.go +++ b/idp/internal/services/jwks.go @@ -91,6 +91,7 @@ type buildStoreGlobalJWKfnOptions struct { requestID string keyType database.TokenKeyType data map[string]string + queries *database.Queries } func (s *Services) buildStoreGlobalJWKfn( @@ -113,7 +114,8 @@ func (s *Services) buildStoreGlobalJWKfn( } logger.InfoContext(ctx, "Storing global JWK", "kid", kid, "cryptoSuite", dbCryptoSuite) - id, err := s.database.CreateTokenSigningKey(ctx, database.CreateTokenSigningKeyParams{ + qrs := s.mapQueries(opts.queries) + id, err := qrs.CreateTokenSigningKey(ctx, database.CreateTokenSigningKeyParams{ Kid: kid, KeyType: opts.keyType, PublicKey: pubKeyBytes, @@ -151,6 +153,7 @@ type BuildEncryptedJWKFnOptions struct { RequestID string KeyType database.TokenKeyType TTL int64 + Queries *database.Queries } func (s *Services) BuildGetGlobalEncryptedJWKFn( @@ -183,7 +186,8 @@ func (s *Services) BuildGetGlobalEncryptedJWKFn( } logger.InfoContext(ctx, "JWK not found in cache, checking database...") - jwkEnt, err := s.database.FindGlobalTokenSigningKey(ctx, database.FindGlobalTokenSigningKeyParams{ + qrs := s.mapQueries(opts.Queries) + jwkEnt, err := qrs.FindGlobalTokenSigningKey(ctx, database.FindGlobalTokenSigningKeyParams{ KeyType: opts.KeyType, ExpiresAt: time.Now().Add(-1 * (time.Hour + time.Duration(opts.TTL)*time.Second)), }) @@ -200,11 +204,15 @@ func (s *Services) BuildGetGlobalEncryptedJWKFn( requestID: opts.RequestID, keyType: opts.KeyType, cryptoSuite: cryptoSuite, - getDEKfn: s.BuildGetEncGlobalDEKFn(ctx, opts.RequestID), + getDEKfn: s.BuildGetEncGlobalDEKFn(ctx, BuildGetGlobalDEKFnOptions{ + RequestID: opts.RequestID, + Queries: qrs, + }), storeFN: s.buildStoreGlobalJWKfn(ctx, buildStoreGlobalJWKfnOptions{ requestID: opts.RequestID, keyType: opts.KeyType, data: data, + queries: qrs, }), }); serviceErr != nil { logger.ErrorContext(ctx, "Failed to create JWK", "serviceError", serviceErr) @@ -730,22 +738,28 @@ func (s *Services) GetAndCacheAccountDistributedJWK( return etag, jwks, nil } +type BuildUpdateJWKDEKFnOptions struct { + RequestID string + Queries *database.Queries +} + func (s *Services) BuildUpdateJWKDEKFn( ctx context.Context, - requestID string, + opts BuildUpdateJWKDEKFnOptions, ) crypto.StoreReEncryptedData { - logger := s.buildLogger(requestID, jwkLocation, "BuildUpdateJWKDEKFn") + logger := s.buildLogger(opts.RequestID, jwkLocation, "BuildUpdateJWKDEKFn") logger.InfoContext(ctx, "Building update JWK DEK function...") return func(kid crypto.EntityID, dekID crypto.DEKID, encPrivKey crypto.DEKCiphertext) *exceptions.ServiceError { logger.InfoContext(ctx, "Updating JWK DEK...") - jwkEnt, err := s.database.FindTokenSigningKeyByKID(ctx, kid) + qrs := s.mapQueries(opts.Queries) + jwkEnt, err := qrs.FindTokenSigningKeyByKID(ctx, kid) if err != nil { logger.ErrorContext(ctx, "Failed to get JWK from database", "error", err) return exceptions.FromDBError(err) } - if err := s.database.UpdateTokenSigningKeyDEKAndPrivateKey( + if err := qrs.UpdateTokenSigningKeyDEKAndPrivateKey( ctx, database.UpdateTokenSigningKeyDEKAndPrivateKeyParams{ ID: jwkEnt.ID, diff --git a/idp/internal/services/oauth.go b/idp/internal/services/oauth.go index 2a70149..fa25070 100644 --- a/idp/internal/services/oauth.go +++ b/idp/internal/services/oauth.go @@ -469,6 +469,7 @@ func (s *Services) OAuthLoginAccount( return s.GenerateFullAuthDTO( ctx, logger, + s.database.Queries, opts.RequestID, &accountDTO, []tokens.AccountScope{tokens.AccountScopeAdmin}, @@ -717,9 +718,15 @@ func (s *Services) generateClientCredentialsAuthentication( KeyType: database.TokenKeyTypeClientCredentials, TTL: s.jwt.GetAccountCredentialsTTL(), }), - GetDecryptDEKfn: s.BuildGetGlobalDecDEKFn(ctx, opts.requestID), - GetEncryptDEKfn: s.BuildGetEncGlobalDEKFn(ctx, opts.requestID), - StoreFN: s.BuildUpdateJWKDEKFn(ctx, opts.requestID), + GetDecryptDEKfn: s.BuildGetGlobalDecDEKFn(ctx, BuildGetGlobalDEKFnOptions{ + RequestID: opts.requestID, + }), + GetEncryptDEKfn: s.BuildGetEncGlobalDEKFn(ctx, BuildGetGlobalDEKFnOptions{ + RequestID: opts.requestID, + }), + StoreFN: s.BuildUpdateJWKDEKFn(ctx, BuildUpdateJWKDEKFnOptions{ + RequestID: opts.requestID, + }), }) if serviceErr != nil { logger.ErrorContext(ctx, "Failed to sign access token", "serviceError", serviceErr) diff --git a/idp/internal/services/oauth_dynamic_registration.go b/idp/internal/services/oauth_dynamic_registration.go index 0a8464c..0996752 100644 --- a/idp/internal/services/oauth_dynamic_registration.go +++ b/idp/internal/services/oauth_dynamic_registration.go @@ -706,7 +706,15 @@ func (s *Services) OAuthDynamicRegistrationIATLogin( return "", "", false, exceptions.NewForbiddenError() } - if accountDTO.TwoFactorType != database.TwoFactorTypeNone { + default2FAConfig, serviceErr := s.getDefaultAccount2FAConfigInternal(ctx, getDefaultAccount2FAConfigInternalOptions{ + requestID: opts.RequestID, + accountPublicID: accountDTO.PublicID, + }) + if serviceErr != nil { + logger.ErrorContext(ctx, "Failed to get default 2FA config", "serviceError", serviceErr) + return "", "", false, serviceErr + } + if default2FAConfig != nil { logger.InfoContext(ctx, "Two-Factor is enabled, proceeding to 2FA step") sessionID, err := s.cache.SaveAccountCredentialsDynamicRegistrationIAT2FA( ctx, @@ -718,6 +726,7 @@ func (s *Services) OAuthDynamicRegistrationIATLogin( Domain: data.Domain, ClientID: opts.ACCClientID, State: data.State, + TwoFAType: string(default2FAConfig.TwoFactorType), TwoFATTL: s.jwt.Get2FATTL(), }, ) @@ -726,7 +735,7 @@ func (s *Services) OAuthDynamicRegistrationIATLogin( return "", "", false, exceptions.NewInternalServerError() } - if accountDTO.TwoFactorType == database.TwoFactorTypeEmail { + if default2FAConfig.TwoFactorType == database.TwoFactorTypeEmail { code, err := s.cache.AddTwoFactorCode(ctx, cache.AddTwoFactorCodeOptions{ RequestID: opts.RequestID, AccountID: accountDTO.ID(), @@ -947,6 +956,17 @@ func (s *Services) OAuthDynamicRegistrationIAT2FAReRender( return twoFAhtml, nil } +func map2FATypeTokens(twoFAType string) (tokens.TwoFAType, *exceptions.ServiceError) { + switch twoFAType { + case TwoFactorTypeEmail: + return tokens.TwoFATypeEmail, nil + case TwoFactorTypeTotp: + return tokens.TwoFATypeTOTP, nil + default: + return "", exceptions.NewValidationError("invalid two factor type") + } +} + type OAuthDynamicRegistrationIATVerify2FACodeOptions struct { RequestID string ACCClientID string @@ -1000,6 +1020,11 @@ func (s *Services) OAuthDynamicRegistrationIATVerify2FACode( logger.WarnContext(ctx, "Client IDs do not match", "sessionClientId", data.ClientID) return "", "", exceptions.NewUnauthorizedError() } + twoFAType, serviceErr := map2FATypeTokens(data.TwoFAType) + if serviceErr != nil { + logger.WarnContext(ctx, "Failed to map two factor type", "serviceError", serviceErr) + return "", "", serviceErr + } accountDTO, serviceErr := s.GetAccountByPublicIDAndVersion(ctx, GetAccountByPublicIDAndVersionOptions{ RequestID: opts.RequestID, @@ -1010,7 +1035,14 @@ func (s *Services) OAuthDynamicRegistrationIATVerify2FACode( logger.WarnContext(ctx, "Failed to get account by public ID and version", "serviceError", serviceErr) return "", "", serviceErr } - if serviceErr := s.verifyAccountTwoFactor(ctx, opts.RequestID, &accountDTO, opts.Code); serviceErr != nil { + if serviceErr := s.verifyAccount2FAInternal(ctx, verifyAccount2FAInternalOptions{ + requestID: opts.RequestID, + accountID: accountDTO.ID(), + accountPublicID: accountDTO.PublicID, + accountVersion: accountDTO.Version(), + twoFAType: twoFAType, + code: opts.Code, + }); serviceErr != nil { logger.WarnContext(ctx, "Failed to verify account two factor", "serviceError", serviceErr) return "", "", serviceErr } @@ -1164,9 +1196,15 @@ func (s *Services) VerifyOAuthDynamicRegistrationIATCode( KeyType: database.TokenKeyTypeDynamicRegistration, TTL: tokenTTL, }), - GetDecryptDEKfn: s.BuildGetGlobalDecDEKFn(ctx, opts.RequestID), - GetEncryptDEKfn: s.BuildGetEncGlobalDEKFn(ctx, opts.RequestID), - StoreFN: s.BuildUpdateJWKDEKFn(ctx, opts.RequestID), + GetDecryptDEKfn: s.BuildGetGlobalDecDEKFn(ctx, BuildGetGlobalDEKFnOptions{ + RequestID: opts.RequestID, + }), + GetEncryptDEKfn: s.BuildGetEncGlobalDEKFn(ctx, BuildGetGlobalDEKFnOptions{ + RequestID: opts.RequestID, + }), + StoreFN: s.BuildUpdateJWKDEKFn(ctx, BuildUpdateJWKDEKFnOptions{ + RequestID: opts.RequestID, + }), }) if serviceErr != nil { logger.ErrorContext(ctx, "Failed to sign account credentials registration IAT", "serviceError", serviceErr) diff --git a/idp/internal/services/users.go b/idp/internal/services/users.go index fc07ea3..23f6786 100644 --- a/idp/internal/services/users.go +++ b/idp/internal/services/users.go @@ -512,7 +512,6 @@ func (s *Services) UpdateUser( Email: email, Username: username, UserData: data, - IsActive: opts.IsActive, EmailVerified: opts.EmailVerified, }) if err != nil { diff --git a/idp/internal/services/users_auth.go b/idp/internal/services/users_auth.go index 81a6675..e15a36a 100644 --- a/idp/internal/services/users_auth.go +++ b/idp/internal/services/users_auth.go @@ -201,7 +201,9 @@ func (s *Services) sendUserConfirmationEmail( RequestID: opts.requestID, AccountID: opts.accountID, }), - StoreFN: s.BuildUpdateJWKDEKFn(ctx, opts.requestID), + StoreFN: s.BuildUpdateJWKDEKFn(ctx, BuildUpdateJWKDEKFnOptions{ + RequestID: opts.requestID, + }), }) if serviceErr != nil { logger.ErrorContext(ctx, "Failed to sign user token", "serviceError", serviceErr) @@ -334,7 +336,9 @@ func (s *Services) generateFullUserAuthDTO( RequestID: requestID, AccountID: accountID, }), - StoreFN: s.BuildUpdateJWKDEKFn(ctx, requestID), + StoreFN: s.BuildUpdateJWKDEKFn(ctx, BuildUpdateJWKDEKFnOptions{ + RequestID: requestID, + }), }) if serviceErr != nil { logger.ErrorContext(ctx, "Failed to sign access token", "serviceError", serviceErr) @@ -373,7 +377,9 @@ func (s *Services) generateFullUserAuthDTO( RequestID: requestID, AccountID: accountID, }), - StoreFN: s.BuildUpdateJWKDEKFn(ctx, requestID), + StoreFN: s.BuildUpdateJWKDEKFn(ctx, BuildUpdateJWKDEKFnOptions{ + RequestID: requestID, + }), }) if serviceErr != nil { logger.ErrorContext(ctx, "Failed to sign access token", "serviceError", serviceErr) @@ -633,75 +639,7 @@ func (s *Services) LoginUser( } } - switch userDTO.TwoFactorType { - case database.TwoFactorTypeEmail, database.TwoFactorTypeTotp: - logger.WarnContext(ctx, "User has two-factor authentication enabled") - twoFAToken, err := s.jwt.CreateUserPurposeToken(tokens.UserPurposeTokenOptions{ - TokenType: tokens.PurposeTokenTypeTwoFA, - AccountUsername: opts.AccountUsername, - UserPublicID: userDTO.PublicID, - UserVersion: userDTO.Version(), - AppClientID: appDTO.ClientID, - AppVersion: appDTO.Version(), - Path: paths.AppsBase + paths.UsersBase + paths.AuthLogin + paths.Auth2FA, - }) - if err != nil { - logger.ErrorContext(ctx, "Failed to create two-factor authentication token", "error", err) - return dtos.AuthDTO{}, exceptions.NewInternalServerError() - } - - signedTwoFAToken, serviceErr := s.crypto.SignToken(ctx, crypto.SignTokenOptions{ - RequestID: opts.RequestID, - Token: twoFAToken, - GetJWKfn: s.BuildGetEncryptedAccountJWKFn(ctx, BuildGetEncryptedAccountJWKFnOptions{ - RequestID: opts.RequestID, - KeyType: database.TokenKeyType2faAuthentication, - AccountID: opts.AccountID, - }), - GetDecryptDEKfn: s.BuildGetDecAccountDEKFn(ctx, BuildGetDecAccountDEKFnOptions{ - RequestID: opts.RequestID, - AccountID: opts.AccountID, - }), - GetEncryptDEKfn: s.BuildGetEncAccountDEKfn(ctx, BuildGetEncAccountDEKOptions{ - RequestID: opts.RequestID, - AccountID: opts.AccountID, - }), - StoreFN: s.BuildUpdateJWKDEKFn(ctx, opts.RequestID), - }) - if serviceErr != nil { - logger.ErrorContext(ctx, "Failed to sign two-factor authentication token", "serviceError", serviceErr) - return dtos.AuthDTO{}, serviceErr - } - - if userDTO.TwoFactorType == database.TwoFactorTypeEmail { - code, err := s.cache.AddTwoFactorCode(ctx, cache.AddTwoFactorCodeOptions{ - RequestID: opts.RequestID, - AccountID: opts.AccountID, - UserID: userDTO.ID(), - TTL: s.jwt.Get2FATTL(), - }) - if err != nil { - logger.ErrorContext(ctx, "Failed to add two-factor Code", "error", err) - return dtos.AuthDTO{}, exceptions.NewInternalServerError() - } - - if err := s.mail.PublishUser2FAEmail(ctx, mailer.User2FAEmailOptions{ - RequestID: opts.RequestID, - AppName: appDTO.Name, - Email: userDTO.Email, - Code: code, - }); err != nil { - logger.ErrorContext(ctx, "Failed to publish 2FA email", "error", err) - return dtos.AuthDTO{}, exceptions.NewInternalServerError() - } - } - - return dtos.NewTempAuthDTO( - signedTwoFAToken, - "Please provide two factor Code", - s.jwt.Get2FATTL(), - ), nil - } + // TODO: add two factor login return s.generateFullUserAuthDTO( ctx, @@ -716,17 +654,17 @@ func (s *Services) LoginUser( ) } -type verifyUserTotpOptions struct { +type VerifyUserTOTPOptions struct { requestID string userID int32 code string } -func (s *Services) verifyUserTotp( +func (s *Services) VerifyUserTOTP( ctx context.Context, - opts verifyUserTotpOptions, + opts VerifyUserTOTPOptions, ) (bool, *exceptions.ServiceError) { - logger := s.buildLogger(opts.requestID, usersAuthLocation, "verifyUserTotp").With( + logger := s.buildLogger(opts.requestID, usersAuthLocation, "VerifyUserTOTP").With( "userId", opts.userID, ) logger.InfoContext(ctx, "Verifying user TOTP...") @@ -765,18 +703,18 @@ func (s *Services) verifyUserTotp( return true, nil } -type verifierUserEmailCodeOptions struct { +type VerifyUserEmailCodeOptions struct { requestID string accountID int32 userID int32 code string } -func (s *Services) verifyUserEmailCode( +func (s *Services) VerifyUserEmailCode( ctx context.Context, - opts verifierUserEmailCodeOptions, + opts VerifyUserEmailCodeOptions, ) (bool, *exceptions.ServiceError) { - logger := s.buildLogger(opts.requestID, usersAuthLocation, "verifyUserEmailCode").With( + logger := s.buildLogger(opts.requestID, usersAuthLocation, "VerifyUserEmailCode").With( "accountId", opts.accountID, "userId", opts.userID, ) @@ -802,106 +740,6 @@ func (s *Services) verifyUserEmailCode( return true, nil } -type TwoFactorLoginUserOptions struct { - RequestID string - AccountID int32 - AccountUsername string - AppClientID string - AppVersion int32 - UserPublicID uuid.UUID - UserVersion int32 - Code string -} - -func (s *Services) TwoFactorLoginUser( - ctx context.Context, - opts TwoFactorLoginUserOptions, -) (dtos.AuthDTO, *exceptions.ServiceError) { - logger := s.buildLogger(opts.RequestID, usersAuthLocation, "TwoFactorLoginUser").With( - "accountId", opts.AccountID, - "appClientId", opts.AppClientID, - "userPublicId", opts.UserPublicID, - ) - logger.InfoContext(ctx, "Two-factor login for user...") - - appDTO, serviceErr := s.GetAppByClientIDVersionAndAccountID(ctx, GetAppByClientIDVersionAndAccountIDOptions{ - RequestID: opts.RequestID, - ClientID: opts.AppClientID, - Version: opts.AppVersion, - AccountID: opts.AccountID, - }) - if serviceErr != nil { - logger.ErrorContext(ctx, "Failed to get app by ID", "error", serviceErr) - return dtos.AuthDTO{}, serviceErr - } - - userDTO, serviceErr := s.GetUserByPublicIDAndVersion(ctx, GetUserByPublicIDAndVersionOptions{ - RequestID: opts.RequestID, - AccountID: opts.AccountID, - PublicID: opts.UserPublicID, - Version: opts.UserVersion, - }) - if serviceErr != nil { - logger.ErrorContext(ctx, "Failed to get user by ID", "error", serviceErr) - return dtos.AuthDTO{}, serviceErr - } - - if _, err := s.database.FindAppProfileByAppIDAndUserID(ctx, database.FindAppProfileByAppIDAndUserIDParams{ - AppID: appDTO.ID(), - UserID: userDTO.ID(), - }); err != nil { - serviceErr := exceptions.FromDBError(err) - if serviceErr.Code == exceptions.CodeNotFound { - logger.WarnContext(ctx, "App profile not found") - return dtos.AuthDTO{}, exceptions.NewUnauthorizedError() - } - - logger.ErrorContext(ctx, "Failed to get app profile", "error", serviceErr) - return dtos.AuthDTO{}, serviceErr - } - - // Verify the two-factor Code based on the user's two-factor type - var verified bool - switch userDTO.TwoFactorType { - case database.TwoFactorTypeTotp: - verified, serviceErr = s.verifyUserTotp(ctx, verifyUserTotpOptions{ - requestID: opts.RequestID, - userID: userDTO.ID(), - code: opts.Code, - }) - case database.TwoFactorTypeEmail: - verified, serviceErr = s.verifyUserEmailCode(ctx, verifierUserEmailCodeOptions{ - requestID: opts.RequestID, - accountID: opts.AccountID, - userID: userDTO.ID(), - code: opts.Code, - }) - default: - logger.WarnContext(ctx, "Invalid two-factor type", "twoFactorType", userDTO.TwoFactorType) - return dtos.AuthDTO{}, exceptions.NewUnauthorizedError() - } - - if serviceErr != nil { - return dtos.AuthDTO{}, serviceErr - } - if !verified { - logger.WarnContext(ctx, "Two-factor Code verification failed") - return dtos.AuthDTO{}, exceptions.NewUnauthorizedError() - } - - return s.generateFullUserAuthDTO( - ctx, - logger, - opts.RequestID, - opts.AccountID, - &userDTO, - &appDTO, - appDTO.DefaultScopes, - opts.AccountUsername, - "User two-factor login successful", - ) -} - type LogoutUserOptions struct { RequestID string AccountID int32 @@ -1196,7 +1034,9 @@ func (s *Services) ForgotUserPassword( RequestID: opts.RequestID, AccountID: opts.AccountID, }), - StoreFN: s.BuildUpdateJWKDEKFn(ctx, opts.RequestID), + StoreFN: s.BuildUpdateJWKDEKFn(ctx, BuildUpdateJWKDEKFnOptions{ + RequestID: opts.RequestID, + }), }) if serviceErr != nil { logger.ErrorContext(ctx, "Failed to sign reset token", "serviceError", serviceErr) From f8e85ce2857c1f94585fbe1b97424f7b143d9aa7 Mon Sep 17 00:00:00 2001 From: Afonso Barracha Date: Sun, 7 Sep 2025 14:26:23 +1200 Subject: [PATCH 19/23] fix(idp): fix confimation 2FA config delete check --- idp/internal/controllers/account_2fa_configs.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/idp/internal/controllers/account_2fa_configs.go b/idp/internal/controllers/account_2fa_configs.go index 12ff4f7..eb5d180 100644 --- a/idp/internal/controllers/account_2fa_configs.go +++ b/idp/internal/controllers/account_2fa_configs.go @@ -11,6 +11,7 @@ import ( "github.com/tugascript/devlogs/idp/internal/controllers/bodies" "github.com/tugascript/devlogs/idp/internal/controllers/params" + "github.com/tugascript/devlogs/idp/internal/exceptions" "github.com/tugascript/devlogs/idp/internal/services" ) @@ -172,7 +173,7 @@ func (c *Controllers) ConfirmDeleteAccount2FAConfig(ctx *fiber.Ctx) error { logger := c.buildLogger(requestID, account2FAConfigsLocation, "ConfirmDeleteAccount2FAConfig") logRequest(logger, ctx) - accountClaims, serviceErr := getAccountClaims(ctx) + accountClaims, twoFAType, serviceErr := getAccounts2FAClaims(ctx) if serviceErr != nil { return serviceErrorResponse(logger, ctx, serviceErr) } @@ -181,6 +182,9 @@ func (c *Controllers) ConfirmDeleteAccount2FAConfig(ctx *fiber.Ctx) error { if err := c.validate.StructCtx(ctx.UserContext(), &urlParams); err != nil { return validateURLParamsErrorResponse(logger, ctx, err) } + if string(twoFAType) != urlParams.TwoFAType { + return serviceErrorResponse(logger, ctx, exceptions.NewUnauthorizedError()) + } body := new(bodies.TwoFactorLoginBody) if err := ctx.BodyParser(body); err != nil { From 5bb2847773d6ca6784d55639d3a326dff5de3677 Mon Sep 17 00:00:00 2001 From: Afonso Barracha Date: Sun, 2 Nov 2025 14:21:41 +1300 Subject: [PATCH 20/23] feat(idp): add software_statement validation --- idp/dbml-error.log | 6 + idp/initial_schema.dbml | 254 +++-- .../account_dynamic_registration_configs.go | 1 - .../bodies/oauth_dynamic_registration.go | 51 +- ...ins.go => dynamic_registration_domains.go} | 8 +- idp/internal/providers/cache/response.go | 97 +- .../database/account_credentials.sql.go | 395 ++++++-- .../database/account_credentials_keys.sql.go | 12 +- ...ccount_dynamic_registration_configs.sql.go | 33 +- ...t_dynamic_registration_domain_codes.sql.go | 63 -- ...ccount_dynamic_registration_domains.sql.go | 408 -------- .../providers/database/app_keys.sql.go | 6 +- .../database/app_related_apps.sql.go | 42 +- idp/internal/providers/database/apps.sql.go | 748 ++++++++++---- .../database/credentials_keys.sql.go | 31 +- .../dynamic_registration_domain_codes.sql.go | 53 +- .../dynamic_registration_domains.sql.go | 485 +++++++++ ...egistration_software_statement_keys.sql.go | 75 ++ ...0241213231542_create_initial_schema.up.sql | 238 +++-- idp/internal/providers/database/models.go | 399 ++++++-- .../database/queries/account_credentials.sql | 70 +- .../account_dynamic_registration_configs.sql | 13 +- ...ount_dynamic_registration_domain_codes.sql | 22 - .../account_dynamic_registration_domains.sql | 97 -- .../database/queries/app_related_apps.sql | 2 +- .../providers/database/queries/apps.sql | 28 +- .../database/queries/credentials_keys.sql | 5 + .../dynamic_registration_domain_codes.sql | 12 +- .../queries/dynamic_registration_domains.sql | 123 +++ ...c_registration_software_statement_keys.sql | 16 + ...ynamic_registration_software_statements.go | 101 ++ idp/internal/providers/tokens/jwks.go | 106 ++ .../routes/account_dynamic_registration.go | 2 +- idp/internal/server/server.go | 9 +- idp/internal/services/account_credentials.go | 126 ++- .../account_credentials_registration.go | 943 ++++++++++++++++++ .../account_dynamic_registration_configs.go | 68 +- idp/internal/services/apps.go | 58 +- .../services/dtos/account_credentials.go | 297 ++++-- .../account_dynamic_registration_config.go | 4 +- idp/internal/services/dtos/app.go | 70 +- .../dtos/dynamic_registration_domain.go | 4 +- ...ins.go => dynamic_registration_domains.go} | 172 ++-- idp/internal/services/helpers.go | 169 +++- .../services/oauth_dynamic_registration.go | 12 +- idp/internal/services/services.go | 5 + idp/internal/services/software_statement.go | 524 ++++++++++ idp/internal/services/users_auth.go | 6 +- idp/internal/utils/jwk.go | 58 ++ idp/tests/account_credentials_test.go | 6 +- 50 files changed, 4944 insertions(+), 1589 deletions(-) rename idp/internal/controllers/{account_credentials_registration_domains.go => dynamic_registration_domains.go} (97%) delete mode 100644 idp/internal/providers/database/account_dynamic_registration_domain_codes.sql.go delete mode 100644 idp/internal/providers/database/account_dynamic_registration_domains.sql.go create mode 100644 idp/internal/providers/database/dynamic_registration_domains.sql.go create mode 100644 idp/internal/providers/database/dynamic_registration_software_statement_keys.sql.go delete mode 100644 idp/internal/providers/database/queries/account_dynamic_registration_domain_codes.sql delete mode 100644 idp/internal/providers/database/queries/account_dynamic_registration_domains.sql create mode 100644 idp/internal/providers/database/queries/dynamic_registration_domains.sql create mode 100644 idp/internal/providers/database/queries/dynamic_registration_software_statement_keys.sql create mode 100644 idp/internal/providers/tokens/dynamic_registration_software_statements.go create mode 100644 idp/internal/providers/tokens/jwks.go create mode 100644 idp/internal/services/account_credentials_registration.go rename idp/internal/services/{account_credentials_registration_domains.go => dynamic_registration_domains.go} (80%) create mode 100644 idp/internal/services/software_statement.go diff --git a/idp/dbml-error.log b/idp/dbml-error.log index fca0d60..8337ecd 100644 --- a/idp/dbml-error.log +++ b/idp/dbml-error.log @@ -34,3 +34,9 @@ undefined 2025-08-16T01:54:09.751Z undefined +2025-10-10T22:09:18.834Z +undefined + +2025-10-30T18:50:05.002Z +undefined + diff --git a/idp/initial_schema.dbml b/idp/initial_schema.dbml index 89a9017..b274791 100644 --- a/idp/initial_schema.dbml +++ b/idp/initial_schema.dbml @@ -64,6 +64,21 @@ Enum token_crypto_suite { "EdDSA" } +Enum token_encryption_algorithm { + "RSA-OAEP-256" + "ECDH-ES" + "ECDH-ES+A256KW" +} + +Enum token_encryption_encoding { + "A128CBC-HS256" + "A192CBC-HS384" + "A256CBC-HS512" + "A128GCM" + "A192GCM" + "A256GCM" +} + Enum token_key_usage { "global" "account" @@ -241,6 +256,7 @@ Table credentials_keys as CK { public_key jsonb [not null] crypto_suite token_crypto_suite [not null] is_revoked boolean [not null, default: false] + is_external boolean [not null, default: false] usage credentials_usage [not null] account_id integer [not null] @@ -340,7 +356,6 @@ Enum auth_method { Enum response_type { "code" - "id_token" "code id_token" } @@ -374,6 +389,11 @@ Enum transport { "streamable_http" } +Enum client_subject_type { + "public" + "pairwise" +} + Enum creation_method { "manual" "dynamic_registration" @@ -385,26 +405,53 @@ Table account_credentials as AC { account_id integer [not null] account_public_id uuid [not null] - client_id varchar(22) [not null] - name varchar(255) [not null] + // Internal for checks domain "varchar(250)" [not null] - credentials_type account_credentials_type [not null] - scopes "account_credentials_scope[]" [not null] - token_endpoint_auth_method "auth_method" [not null] - grant_types "grant_type[]" [not null] + creation_method creation_method [not null] + transport transport [not null] version integer [not null, default: 1] - transport transport [not null] - creation_method creation_method [not null] + // Client ID generated internally + client_id varchar(22) [not null] - client_uri varchar(512) [not null] + // RFC 7591 Client Metadata redirect_uris "varchar(2048)[]" [not null] + token_endpoint_auth_method "auth_method" [not null] + grant_types "grant_type[]" [not null] + response_types "response_type[]" [not null] + client_name varchar(255) [not null] + client_uri varchar(512) [not null] logo_uri varchar(512) [null] - policy_uri varchar(512) [null] + scopes "account_credentials_scope[]" [not null] // scope devided by ' ' + contacts "varchar(250)[]" [not null] tos_uri varchar(512) [null] - software_id varchar(512) [not null] + policy_uri varchar(512) [null] + jwks_uri varchar(512) [null] + jwks jsonb [null] + software_id varchar(512) [null] software_version varchar(512) [null] - contacts "varchar(250)[]" [not null] + + // OpenID Connect Dynamic Client Registration 1.0 + credentials_type account_credentials_type [not null] // application type + sector_identifier_uri varchar(512) [null] + subject_type client_subject_type [null] + id_token_signed_response_alg token_crypto_suite [not null] + id_token_encrypted_response_alg token_encryption_algorithm [null] + id_token_encrypted_response_enc token_encryption_encoding [null] + userinfo_signed_response_alg token_crypto_suite [null] + userinfo_encrypted_response_alg token_encryption_algorithm [null] + userinfo_encrypted_response_enc token_encryption_encoding [null] + request_object_signing_alg token_crypto_suite [null] + request_object_encryption_alg token_encryption_algorithm [null] + request_object_encryption_enc token_encryption_encoding [null] + token_endpoint_auth_signing_alg token_crypto_suite [null] + default_max_age bigint [null] + require_auth_time boolean [not null, default: false] + default_acr_values "varchar(100)[]" [null] + initiate_login_uri varchar(512) [null] + request_uris "varchar(2048)[]" [null] + // Custom field for selecting access token signing algorithm + access_token_signing_alg token_crypto_suite [not null] created_at timestamptz [not null, default: `now()`] updated_at timestamptz [not null, default: `now()`] @@ -414,7 +461,7 @@ Table account_credentials as AC { (account_id) [name: 'account_credentials_account_id_idx'] (account_public_id) [name: 'account_credentials_account_public_id_idx'] (account_public_id, client_id) [name: 'account_credentials_account_public_id_client_id_idx'] - (name, account_id) [unique, name: 'account_credentials_name_account_id_uidx'] + (client_name, account_id) [unique, name: 'account_credentials_client_name_account_id_uidx'] } } Ref: AC.account_id > A.id [delete: cascade] @@ -773,24 +820,28 @@ Table apps as APP { account_public_id uuid [not null] - app_type app_type [not null] - name varchar(255) [not null] client_id varchar(22) [not null] version integer [not null, default: 1] creation_method creation_method [not null] // Common dynamic registration fields + // RFC 7591 Client Metadata + redirect_uris "varchar(2048)[]" [not null] + token_endpoint_auth_method "auth_method" [not null] + grant_types "grant_type[]" [not null] + response_types "response_type[]" [not null] + client_name varchar(255) [not null] client_uri varchar(512) [not null] logo_uri varchar(512) [null] + scopes "scopes[]" [not null] // scope devided by ' ' + custom_scopes "varchar(512)[]" [not null] + contacts "varchar(250)[]" [not null] tos_uri varchar(512) [null] policy_uri varchar(512) [null] - software_id varchar(250) [not null] - software_version varchar(250) [null] - contacts "varchar(250)[]" [not null] - token_endpoint_auth_method "auth_method" [not null] - scopes "scopes[]" [not null] - custom_scopes "varchar(512)[]" [not null] - grant_types "grant_type[]" [not null] + jwks_uri varchar(512) [null] + jwks jsonb [null] + software_id varchar(512) [null] + software_version varchar(512) [null] // Common on all OAuth2 apps domain varchar(250) [not null] @@ -801,9 +852,28 @@ Table apps as APP { default_scopes "scopes[]" [not null] default_custom_scopes "varchar(512)[]" [not null] - // Apps with redirects - redirect_uris "varchar(2048)[]" [not null] - response_types "response_type[]" [not null] + // Additional OIDC registration metadata + // OpenID Connect Dynamic Client Registration 1.0 + app_type app_type [not null] // application type + sector_identifier_uri varchar(512) [null] + subject_type client_subject_type [null] + id_token_signed_response_alg token_crypto_suite [not null] + id_token_encrypted_response_alg token_encryption_algorithm [null] + id_token_encrypted_response_enc token_encryption_encoding [null] + userinfo_signed_response_alg token_crypto_suite [null] + userinfo_encrypted_response_alg token_encryption_algorithm [null] + userinfo_encrypted_response_enc token_encryption_encoding [null] + request_object_signing_alg token_crypto_suite [null] + request_object_encryption_alg token_encryption_algorithm [null] + request_object_encryption_enc token_encryption_encoding [null] + token_endpoint_auth_signing_alg token_crypto_suite [null] + default_max_age integer [null] + require_auth_time boolean [not null, default: false] + default_acr_values "varchar(100)[]" [null] + initiate_login_uri varchar(512) [null] + request_uris "varchar(2048)[]" [null] + // Custom field for selecting access token signing algorithm + access_token_signing_alg token_crypto_suite [not null] // Tokens TTLs id_token_ttl integer [not null, default: 300] // 5 minutes @@ -819,8 +889,8 @@ Table apps as APP { (client_id) [unique, name: 'apps_client_id_uidx'] (client_id, account_public_id) [name: 'apps_client_id_account_public_id_idx'] (account_public_id) [name: 'apps_account_public_id_idx'] - (name) [name: 'apps_name_idx'] - (account_id, name) [unique, name: 'apps_account_id_name_uidx'] + (client_name) [name: 'apps_client_name_idx'] + (account_id, client_name) [unique, name: 'apps_account_id_client_name_uidx'] (account_id, app_type) [name: 'apps_account_id_app_type_idx'] } } @@ -939,7 +1009,6 @@ Enum initial_access_token_generation_method { Enum software_statement_verification_method { "manual" "jwks_uri" - "jwk_x5_parameters" } Table account_dynamic_registration_configs as ADRC { @@ -949,9 +1018,9 @@ Table account_dynamic_registration_configs as ADRC { account_public_id uuid [not null] account_credentials_types "account_credentials_type[]" [not null] - whitelisted_domains "varchar(250)[]" [not null] require_software_statement_credential_types "account_credentials_type[]" [not null] software_statement_verification_methods "software_statement_verification_method[]" [not null] + require_verified_domains_credentials_type "account_credentials_type[]" [not null] require_initial_access_token_credential_types "account_credentials_type[]" [not null] initial_access_token_generation_methods "initial_access_token_generation_method[]" [not null] @@ -966,13 +1035,54 @@ Table account_dynamic_registration_configs as ADRC { } Ref: ADRC.account_id > A.id [delete: cascade] +Table app_dynamic_registration_configs as APDRC { + id serial [pk] + + account_id integer [not null] + + allowed_app_types "app_type[]" [not null] + whitelisted_domains "varchar(250)[]" [not null] + default_allow_user_registration boolean [not null] + default_auth_providers "auth_provider[]" [not null] + default_username_column app_username_column [not null] + default_allowed_scopes "scopes[]" [not null] + default_scopes "scopes[]" [not null] + + require_verified_domains_app_types "app_type[]" [not null] + + require_software_statement_app_types "app_type[]" [not null] + software_statement_verification_methods "software_statement_verification_method[]" [not null] + + require_initial_access_token_app_types "app_type[]" [not null] + initial_access_token_generation_methods "initial_access_token_generation_method[]" [not null] + initial_access_token_ttl integer [not null, default: 3600] // 1 hour + initial_access_token_max_uses int [not null, default: 1] + + allowed_grant_types "grant_type[]" [not null, default: '{ "authorization_code", "refresh_token", "client_credentials", "urn:ietf:params:oauth:grant-type:device_code", "urn:ietf:params:oauth:grant-type:jwt-bearer" }'] + allowed_response_types "response_type[]" [not null, default: '{ "code", "id_token", "code id_token" }'] + allowed_token_endpoint_auth_methods "auth_method[]" [not null, default: '{ "none", "client_secret_post", "client_secret_basic", "client_secret_jwt", "private_key_jwt" }'] + max_redirect_uris int [not null, default: 10] + + created_at timestamptz [not null, default: `now()`] + updated_at timestamptz [not null, default: `now()`] + + Indexes { + (account_id) [name: 'app_dynamic_registration_configs_account_id_idx'] + } +} +Ref: APDRC.account_id > A.id [delete: cascade] + +Enum dynamic_registration_usage { + "account" + "app" +} + Enum domain_verification_method { "authorization_code" - "software_statement" "dns_txt_record" } -Table account_dynamic_registration_domains as ADRD { +Table dynamic_registration_domains as ADRD { id serial [pk] account_id integer [not null] @@ -981,6 +1091,7 @@ Table account_dynamic_registration_domains as ADRD { domain varchar(250) [not null] verified_at timestamptz [null] verification_method domain_verification_method [not null] + usages "dynamic_registration_usage[]" [not null] created_at timestamptz [not null, default: `now()`] updated_at timestamptz [not null, default: `now()`] @@ -998,6 +1109,8 @@ Table dynamic_registration_domain_codes as DRDC { id serial [pk] account_id integer [not null] + dynamic_registration_domain_id integer [not null] + verification_host varchar(50) [not null] verification_code text [not null] hmac_secret_id varchar(22) [not null] @@ -1008,85 +1121,35 @@ Table dynamic_registration_domain_codes as DRDC { updated_at timestamptz [not null, default: `now()`] Indexes { - (account_id) [name: 'account_dynamic_registration_domain_codes_account_id_idx'] + (account_id) [name: 'dynamic_registration_domain_codes_account_id_idx'] + (dynamic_registration_domain_id) [name: 'dynamic_registration_domain_codes_dynamic_registration_domain_id_idx'] } } Ref: DRDC.account_id > A.id [delete: cascade] +Ref: DRDC.dynamic_registration_domain_id > ADRD.id [delete: cascade] Ref: DRDC.hmac_secret_id > AHS.secret_id [delete: cascade, update: cascade] -Table account_dynamic_registration_domain_codes as ADRDC { - account_dynamic_registration_domain_id integer [not null] - dynamic_registration_domain_code_id integer [not null] - - account_id integer [not null] - created_at timestamptz [not null, default: `now()`] - - Indexes { - (account_dynamic_registration_domain_id, dynamic_registration_domain_code_id) [pk] - (account_id) [name: 'account_dynamic_registration_domain_codes_account_id_idx'] - (account_dynamic_registration_domain_id) [unique, name: 'account_dynamic_registration_domain_codes_account_dynamic_registration_domain_id_uidx'] - (dynamic_registration_domain_code_id) [unique, name: 'account_dynamic_registration_domain_codes_dynamic_registration_domain_code_id_uidx'] - } -} -Ref: ADRDC.account_id > A.id [delete: cascade] -Ref: ADRDC.account_dynamic_registration_domain_id > ADRD.id [delete: cascade] -Ref: ADRDC.dynamic_registration_domain_code_id > DRDC.id [delete: cascade] - -Table account_dynamic_registration_software_statement_keys as ADRSK { +Table dynamic_registration_software_statement_keys as DRSK { id serial [pk] account_id integer [not null] account_public_id uuid [not null] credentials_key_id integer [not null] - account_dynamic_registration_domain_id integer [not null] + credentials_key_kid varchar(22) [not null] + root_domain varchar(250) [not null] created_at timestamptz [not null, default: `now()`] Indexes { - (account_id) [name: 'account_dynamic_registration_software_statement_keys_account_id_idx'] - (account_public_id) [name: 'account_dynamic_registration_software_statement_keys_account_public_id_idx'] - (credentials_key_id) [unique, name: 'account_dynamic_registration_software_statement_keys_credentials_key_id_uidx'] - (account_dynamic_registration_domain_id) [unique, name: 'account_dynamic_registration_software_statement_keys_account_dynamic_registration_domain_id_uidx'] - } -} -Ref: ADRSK.account_id > A.id [delete: cascade] -Ref: ADRSK.credentials_key_id > CK.id [delete: cascade] -Ref: ADRSK.account_dynamic_registration_domain_id > ADRD.id [delete: cascade] - -Table app_dynamic_registration_configs as APDRC { - id serial [pk] - - account_id integer [not null] - - allowed_app_types "app_type[]" [not null] - whitelisted_domains "varchar(250)[]" [not null] - default_allow_user_registration boolean [not null] - default_auth_providers "auth_provider[]" [not null] - default_username_column app_username_column [not null] - default_allowed_scopes "scopes[]" [not null] - default_scopes "scopes[]" [not null] - - require_software_statement_app_types "app_type[]" [not null] - software_statement_verification_methods "software_statement_verification_method[]" [not null] - - require_initial_access_token_app_types "app_type[]" [not null] - initial_access_token_generation_methods "initial_access_token_generation_method[]" [not null] - initial_access_token_ttl integer [not null, default: 3600] // 1 hour - initial_access_token_max_uses int [not null, default: 1] - - allowed_grant_types "grant_type[]" [not null, default: '{ "authorization_code", "refresh_token", "client_credentials", "urn:ietf:params:oauth:grant-type:device_code", "urn:ietf:params:oauth:grant-type:jwt-bearer" }'] - allowed_response_types "response_type[]" [not null, default: '{ "code", "id_token", "code id_token" }'] - allowed_token_endpoint_auth_methods "auth_method[]" [not null, default: '{ "none", "client_secret_post", "client_secret_basic", "client_secret_jwt", "private_key_jwt" }'] - max_redirect_uris int [not null, default: 10] - - created_at timestamptz [not null, default: `now()`] - updated_at timestamptz [not null, default: `now()`] - - Indexes { - (account_id) [name: 'app_dynamic_registration_configs_account_id_idx'] + (account_id) [name: 'drs_statement_keys_account_id_idx'] + (account_public_id) [name: 'drs_statement_keys_account_public_id_idx'] + (credentials_key_id) [unique, name: 'drs_statement_keys_credentials_key_id_uidx'] + (root_domain, account_public_id) [name: 'drs_statement_keys_root_domain_account_public_id_idx'] + (credentials_key_kid, account_public_id) [name: 'drs_statement_keys_credentials_key_kid_account_public_id_idx'] } } -Ref: APDRC.account_id > A.id [delete: cascade] +Ref: DRSK.account_id > A.id [delete: cascade] +Ref: DRSK.credentials_key_id > CK.id [delete: cascade] Enum app_profile_type { "human" @@ -1140,4 +1203,3 @@ Table revoked_tokens as RT { } } Ref: RT.account_id > A.id [delete: cascade] - diff --git a/idp/internal/controllers/account_dynamic_registration_configs.go b/idp/internal/controllers/account_dynamic_registration_configs.go index cc1cd10..40876b2 100644 --- a/idp/internal/controllers/account_dynamic_registration_configs.go +++ b/idp/internal/controllers/account_dynamic_registration_configs.go @@ -46,7 +46,6 @@ func (c *Controllers) UpsertAccountDynamicRegistrationConfig(ctx *fiber.Ctx) err AccountPublicID: accountClaims.AccountID, AccountVersion: accountClaims.AccountVersion, AccountCredentialsTypes: body.AccountCredentialsTypes, - WhitelistedDomains: body.WhitelistedDomains, RequireSoftwareStatementCredentialTypes: body.RequireSoftwareStatementCredentialTypes, SoftwareStatementVerificationMethods: body.SoftwareStatementVerificationMethods, RequireInitialAccessTokenCredentialTypes: body.RequireInitialAccessTokenCredentialTypes, diff --git a/idp/internal/controllers/bodies/oauth_dynamic_registration.go b/idp/internal/controllers/bodies/oauth_dynamic_registration.go index 3ef4948..5efbf4e 100644 --- a/idp/internal/controllers/bodies/oauth_dynamic_registration.go +++ b/idp/internal/controllers/bodies/oauth_dynamic_registration.go @@ -7,22 +7,41 @@ package bodies type OAuthDynamicClientRegistrationBody struct { - RedirectURIs []string `json:"redirect_uris" validate:"required,min=1,dive,uri"` - TokenEndpointAuthMethod string `json:"token_endpoint_auth_method,omitempty" validate:"omitempty,oneof=none client_secret_basic client_secret_post client_secret_jwt private_key_jwt"` - ResponseTypes []string `json:"response_types,omitempty" validate:"omitempty,dive,oneof=none code"` - GrantTypes []string `json:"grant_types,omitempty" validate:"omitempty,min=1,dive,oneof=authorization_code refresh_token client_credentials urn:ietf:params:oauth:grant-type:jwt-bearer"` - ApplicationType string `json:"application_type" validate:"required,oneof=native service mcp"` - ClientName string `json:"client_name" validate:"required,min=1,max=255"` - ClientURI string `json:"client_uri" validate:"required,url"` - LogoURI string `json:"logo_uri,omitempty" validate:"omitempty,url"` - Scope string `json:"scope" validate:"required,multiple_scope"` - Contacts []string `json:"contacts,omitempty" validate:"omitempty,unique,dive,email"` - TOSURI string `json:"tos_uri,omitempty" validate:"omitempty,url"` - PolicyURI string `json:"policy_uri,omitempty" validate:"omitempty,url"` - JWKsURI string `json:"jwks_uri,omitempty" validate:"omitempty,url"` - JWKs []string `json:"jwks,omitempty" validate:"omitempty,json"` - SoftwareID string `json:"software_id,omitempty" validate:"omitempty,max=250"` - SoftwareVersion string `json:"software_version,omitempty" validate:"omitempty,max=250"` + RedirectURIs []string `json:"redirect_uris,omitempty" validate:"omitempty,min=1,dive,uri"` + TokenEndpointAuthMethod string `json:"token_endpoint_auth_method,omitempty" validate:"omitempty,oneof=none client_secret_basic client_secret_post client_secret_jwt private_key_jwt"` + ResponseTypes []string `json:"response_types,omitempty" validate:"omitempty,dive,oneof=code 'code id_token'"` + GrantTypes []string `json:"grant_types,omitempty" validate:"omitempty,min=1,dive,oneof=authorization_code refresh_token client_credentials urn:ietf:params:oauth:grant-type:jwt-bearer"` + ApplicationType string `json:"application_type" validate:"required,oneof=native service mcp"` + ClientName string `json:"client_name" validate:"required,min=1,max=255"` + ClientURI string `json:"client_uri" validate:"required,url"` + LogoURI string `json:"logo_uri,omitempty" validate:"omitempty,url"` + Scope string `json:"scope" validate:"required,multiple_scope"` + Contacts []string `json:"contacts,omitempty" validate:"omitempty,unique,dive,email"` + TOSURI string `json:"tos_uri,omitempty" validate:"omitempty,url"` + PolicyURI string `json:"policy_uri,omitempty" validate:"omitempty,url"` + JWKsURI string `json:"jwks_uri,omitempty" validate:"omitempty,url"` + JWKs []string `json:"jwks,omitempty" validate:"omitempty,json"` + SoftwareID string `json:"software_id,omitempty" validate:"omitempty,max=512"` + SoftwareVersion string `json:"software_version,omitempty" validate:"omitempty,max=512"` + SubjectType string `json:"subject_type,omitempty" validate:"omitempty,oneof=public pairwise"` + SectorIdentifierURI string `json:"sector_identifier_uri,omitempty" validate:"omitempty,url"` + DefaultMaxAge int64 `json:"default_max_age,omitempty" validate:"omitempty,min=0"` + RequireAuthTime bool `json:"require_auth_time,omitempty" validate:"omitempty,bool"` + DefaultACRValues []string `json:"default_acr_values,omitempty" validate:"omitempty,unique,dive,max=100"` + InitiateLoginURI string `json:"initiate_login_uri,omitempty" validate:"omitempty,url"` + RequestURIs []string `json:"request_uris,omitempty" validate:"omitempty,unique,dive,url"` + IDTokenSignedResponseAlg string `json:"id_token_signed_response_alg,omitempty" validate:"omitempty,oneof=RS256 ES256 EdDSA"` + IDTokenEncryptedResponseAlg string `json:"id_token_encrypted_response_alg,omitempty" validate:"omitempty,oneof=RSA-OAEP-256 ECDH-ES ECDH-ES+A256KW"` + IDTokenEncryptedResponseEnc string `json:"id_token_encrypted_response_enc,omitempty" validate:"omitempty,oneof=A128CBC-HS256 A192CBC-HS384 A256CBC-HS512 A128GCM A192GCM A256GCM"` + UserInfoSignedResponseAlg string `json:"userinfo_signed_response_alg,omitempty" validate:"omitempty,oneof=RS256 ES256 EdDSA"` + UserInfoEncryptedResponseAlg string `json:"userinfo_encrypted_response_alg,omitempty" validate:"omitempty,oneof=RSA-OAEP-256 ECDH-ES ECDH-ES+A256KW"` + UserInfoEncryptedResponseEnc string `json:"userinfo_encrypted_response_enc,omitempty" validate:"omitempty,oneof=A128CBC-HS256 A192CBC-HS384 A256CBC-HS512 A128GCM A192GCM A256GCM"` + RequestObjectSigningAlg string `json:"request_object_signing_alg,omitempty" validate:"omitempty,oneof=RS256 ES256 EdDSA"` + RequestObjectEncryptionAlg string `json:"request_object_encryption_alg,omitempty" validate:"omitempty,oneof=RSA-OAEP-256 ECDH-ES ECDH-ES+A256KW"` + RequestObjectEncryptionEnc string `json:"request_object_encryption_enc,omitempty" validate:"omitempty,oneof=A128CBC-HS256 A192CBC-HS384 A256CBC-HS512 A128GCM A192GCM A256GCM"` + TokenEndpointAuthSigningAlg string `json:"token_endpoint_auth_signing_alg,omitempty" validate:"omitempty,oneof=RS256 ES256 EdDSA"` + AccessTokenSigningAlg string `json:"access_token_signing_alg,omitempty" validate:"omitempty,oneof=RS256 ES256 EdDSA"` + SoftwareStatement string `json:"software_statement,omitempty" validate:"omitempty,jwt"` } type OAuthDynamicRegistrationIATAuthHiddenFieldsBody struct { diff --git a/idp/internal/controllers/account_credentials_registration_domains.go b/idp/internal/controllers/dynamic_registration_domains.go similarity index 97% rename from idp/internal/controllers/account_credentials_registration_domains.go rename to idp/internal/controllers/dynamic_registration_domains.go index bd05720..a2e7338 100644 --- a/idp/internal/controllers/account_credentials_registration_domains.go +++ b/idp/internal/controllers/dynamic_registration_domains.go @@ -20,9 +20,9 @@ const ( accountCredentialsRegistrationDomainsLocation string = "account_credentials_registration_domains" ) -func (c *Controllers) CreateAccountCredentialsRegistrationDomain(ctx *fiber.Ctx) error { +func (c *Controllers) CreateDynamicRegistrationDomain(ctx *fiber.Ctx) error { requestID := getRequestID(ctx) - logger := c.buildLogger(requestID, accountCredentialsRegistrationDomainsLocation, "CreateAccountDynamicRegistrationDomain") + logger := c.buildLogger(requestID, accountCredentialsRegistrationDomainsLocation, "CreateDynamicRegistrationDomain") logRequest(logger, ctx) accountClaims, serviceErr := getAccountClaims(ctx) @@ -38,9 +38,9 @@ func (c *Controllers) CreateAccountCredentialsRegistrationDomain(ctx *fiber.Ctx) return validateBodyErrorResponse(logger, ctx, err) } - domainDTO, serviceErr := c.services.CreateAccountCredentialsRegistrationDomain( + domainDTO, serviceErr := c.services.CreateDynamicRegistrationDomain( ctx.UserContext(), - services.CreateAccountCredentialsRegistrationDomainOptions{ + services.CreateDynamicRegistrationDomainOptions{ RequestID: requestID, AccountPublicID: accountClaims.AccountID, AccountVersion: accountClaims.AccountVersion, diff --git a/idp/internal/providers/cache/response.go b/idp/internal/providers/cache/response.go index 1958f83..08b64ba 100644 --- a/idp/internal/providers/cache/response.go +++ b/idp/internal/providers/cache/response.go @@ -19,7 +19,7 @@ const responseLocation string = "response" type SaveResponseOptions[T any] struct { RequestID string Key string - TTL int + TTL time.Duration Value T } @@ -41,7 +41,7 @@ func SaveResponse[T any]( return "", err } - if err := c.storage.SetWithContext(ctx, opts.Key, responseBytes, time.Duration(opts.TTL)*time.Second); err != nil { + if err := c.storage.SetWithContext(ctx, opts.Key, responseBytes, opts.TTL); err != nil { logger.ErrorContext(ctx, "Error saving response", "error", err) return "", err } @@ -50,6 +50,33 @@ func SaveResponse[T any]( return utils.GenerateETag(responseBytes), nil } +func SaveResponseWithoutETag[T any]( + c *Cache, + ctx context.Context, + opts SaveResponseOptions[T], +) error { + logger := utils.BuildLogger(c.logger, utils.LoggerOptions{ + Location: responseLocation, + Method: "SaveResponseWithoutETag", + RequestID: opts.RequestID, + }).With("key", opts.Key) + logger.DebugContext(ctx, "Saving response without ETag...") + + responseBytes, err := json.Marshal(opts.Value) + if err != nil { + logger.ErrorContext(ctx, "Error marshalling response", "error", err) + return err + } + + if err := c.storage.SetWithContext(ctx, opts.Key, responseBytes, opts.TTL); err != nil { + logger.ErrorContext(ctx, "Error saving response", "error", err) + return err + } + + logger.DebugContext(ctx, "Response saved successfully") + return nil +} + type GetResponseOptions[T any] struct { RequestID string Key string @@ -59,7 +86,7 @@ func GetResponse[T any]( c *Cache, ctx context.Context, opts GetResponseOptions[T], -) (T, string, error) { +) (T, string, bool, error) { logger := utils.BuildLogger(c.logger, utils.LoggerOptions{ Location: responseLocation, Method: "GetResponse", @@ -71,17 +98,73 @@ func GetResponse[T any]( responseBytes, err := c.storage.GetWithContext(ctx, opts.Key) if err != nil { logger.ErrorContext(ctx, "Error getting cached response", "error", err) - return response, "", err + return response, "", false, err + } + if responseBytes == nil { + logger.DebugContext(ctx, "No cached response found") + return response, "", false, nil + } + + if err := json.Unmarshal(responseBytes, &response); err != nil { + logger.ErrorContext(ctx, "Error unmarshalling response", "error", err) + return response, "", false, err + } + + return response, utils.GenerateETag(responseBytes), true, nil +} + +func GetResponseWithoutETag[T any]( + c *Cache, + ctx context.Context, + opts GetResponseOptions[T], +) (T, bool, error) { + logger := utils.BuildLogger(c.logger, utils.LoggerOptions{ + Location: responseLocation, + Method: "GetResponseWithoutETag", + RequestID: opts.RequestID, + }).With("key", opts.Key) + logger.DebugContext(ctx, "Getting cached response without ETag...") + + var response T + responseBytes, err := c.storage.GetWithContext(ctx, opts.Key) + if err != nil { + logger.ErrorContext(ctx, "Error getting cached response", "error", err) + return response, false, err } if responseBytes == nil { logger.DebugContext(ctx, "No cached response found") - return response, "", nil + return response, false, nil } if err := json.Unmarshal(responseBytes, &response); err != nil { logger.ErrorContext(ctx, "Error unmarshalling response", "error", err) - return response, "", err + return response, false, err + } + + return response, true, nil +} + +type DeleteResponseOptions struct { + RequestID string + Key string +} + +func (c *Cache) DeleteResponse( + ctx context.Context, + opts DeleteResponseOptions, +) error { + logger := utils.BuildLogger(c.logger, utils.LoggerOptions{ + Location: responseLocation, + Method: "DeleteResponse", + RequestID: opts.RequestID, + }).With("key", opts.Key) + logger.DebugContext(ctx, "Deleting cached response...") + + if err := c.storage.DeleteWithContext(ctx, opts.Key); err != nil { + logger.ErrorContext(ctx, "Error deleting cached response", "error", err) + return err } - return response, utils.GenerateETag(responseBytes), nil + logger.DebugContext(ctx, "Cached response deleted successfully") + return nil } diff --git a/idp/internal/providers/database/account_credentials.sql.go b/idp/internal/providers/database/account_credentials.sql.go index ef4b466..d9915cd 100644 --- a/idp/internal/providers/database/account_credentials.sql.go +++ b/idp/internal/providers/database/account_credentials.sql.go @@ -45,16 +45,16 @@ func (q *Queries) CountAccountCredentialsByAccountPublicIDAndClientID(ctx contex const countAccountCredentialsByNameAndAccountID = `-- name: CountAccountCredentialsByNameAndAccountID :one SELECT COUNT(*) FROM "account_credentials" -WHERE "account_id" = $1 AND "name" = $2 +WHERE "account_id" = $1 AND "client_name" = $2 ` type CountAccountCredentialsByNameAndAccountIDParams struct { - AccountID int32 - Name string + AccountID int32 + ClientName string } func (q *Queries) CountAccountCredentialsByNameAndAccountID(ctx context.Context, arg CountAccountCredentialsByNameAndAccountIDParams) (int64, error) { - row := q.db.QueryRow(ctx, countAccountCredentialsByNameAndAccountID, arg.AccountID, arg.Name) + row := q.db.QueryRow(ctx, countAccountCredentialsByNameAndAccountID, arg.AccountID, arg.ClientName) var count int64 err := row.Scan(&count) return count, err @@ -62,24 +62,46 @@ func (q *Queries) CountAccountCredentialsByNameAndAccountID(ctx context.Context, const createAccountCredentials = `-- name: CreateAccountCredentials :one INSERT INTO "account_credentials" ( - "client_id", "account_id", "account_public_id", - "credentials_type", - "name", - "scopes", - "token_endpoint_auth_method", "domain", - "client_uri", + "creation_method", + "transport", + "client_id", "redirect_uris", + "token_endpoint_auth_method", + "grant_types", + "response_types", + "client_name", + "client_uri", "logo_uri", - "policy_uri", + "scopes", + "contacts", "tos_uri", + "policy_uri", + "jwks_uri", + "jwks", "software_id", "software_version", - "contacts", - "creation_method", - "transport" + "credentials_type", + "sector_identifier_uri", + "subject_type", + "id_token_signed_response_alg", + "id_token_encrypted_response_alg", + "id_token_encrypted_response_enc", + "userinfo_signed_response_alg", + "userinfo_encrypted_response_alg", + "userinfo_encrypted_response_enc", + "request_object_signing_alg", + "request_object_encryption_alg", + "request_object_encryption_enc", + "token_endpoint_auth_signing_alg", + "default_max_age", + "require_auth_time", + "default_acr_values", + "initiate_login_uri", + "request_uris", + "access_token_signing_alg" ) VALUES ( $1, $2, @@ -98,75 +120,162 @@ INSERT INTO "account_credentials" ( $15, $16, $17, - $18 -) RETURNING id, account_id, account_public_id, client_id, name, domain, credentials_type, scopes, token_endpoint_auth_method, grant_types, version, transport, creation_method, client_uri, redirect_uris, logo_uri, policy_uri, tos_uri, software_id, software_version, contacts, created_at, updated_at + $18, + $19, + $20, + $21, + $22, + $23, + $24, + $25, + $26, + $27, + $28, + $29, + $30, + $31, + $32, + $33, + $34, + $35, + $36, + $37, + $38, + $39, + $40 +) RETURNING id, account_id, account_public_id, domain, creation_method, transport, version, client_id, redirect_uris, token_endpoint_auth_method, grant_types, response_types, client_name, client_uri, logo_uri, scopes, contacts, tos_uri, policy_uri, jwks_uri, jwks, software_id, software_version, credentials_type, sector_identifier_uri, subject_type, id_token_signed_response_alg, id_token_encrypted_response_alg, id_token_encrypted_response_enc, userinfo_signed_response_alg, userinfo_encrypted_response_alg, userinfo_encrypted_response_enc, request_object_signing_alg, request_object_encryption_alg, request_object_encryption_enc, token_endpoint_auth_signing_alg, default_max_age, require_auth_time, default_acr_values, initiate_login_uri, request_uris, access_token_signing_alg, created_at, updated_at ` type CreateAccountCredentialsParams struct { - ClientID string - AccountID int32 - AccountPublicID uuid.UUID - CredentialsType AccountCredentialsType - Name string - Scopes []AccountCredentialsScope - TokenEndpointAuthMethod AuthMethod - Domain string - ClientUri string - RedirectUris []string - LogoUri pgtype.Text - PolicyUri pgtype.Text - TosUri pgtype.Text - SoftwareID string - SoftwareVersion pgtype.Text - Contacts []string - CreationMethod CreationMethod - Transport Transport + AccountID int32 + AccountPublicID uuid.UUID + Domain string + CreationMethod CreationMethod + Transport Transport + ClientID string + RedirectUris []string + TokenEndpointAuthMethod AuthMethod + GrantTypes []GrantType + ResponseTypes []ResponseType + ClientName string + ClientUri string + LogoUri pgtype.Text + Scopes []AccountCredentialsScope + Contacts []string + TosUri pgtype.Text + PolicyUri pgtype.Text + JwksUri pgtype.Text + Jwks []byte + SoftwareID pgtype.Text + SoftwareVersion pgtype.Text + CredentialsType AccountCredentialsType + SectorIdentifierUri pgtype.Text + SubjectType NullClientSubjectType + IDTokenSignedResponseAlg TokenCryptoSuite + IDTokenEncryptedResponseAlg NullTokenEncryptionAlgorithm + IDTokenEncryptedResponseEnc NullTokenEncryptionEncoding + UserinfoSignedResponseAlg NullTokenCryptoSuite + UserinfoEncryptedResponseAlg NullTokenEncryptionAlgorithm + UserinfoEncryptedResponseEnc NullTokenEncryptionEncoding + RequestObjectSigningAlg NullTokenCryptoSuite + RequestObjectEncryptionAlg NullTokenEncryptionAlgorithm + RequestObjectEncryptionEnc NullTokenEncryptionEncoding + TokenEndpointAuthSigningAlg NullTokenCryptoSuite + DefaultMaxAge pgtype.Int8 + RequireAuthTime bool + DefaultAcrValues []string + InitiateLoginUri pgtype.Text + RequestUris []string + AccessTokenSigningAlg TokenCryptoSuite } func (q *Queries) CreateAccountCredentials(ctx context.Context, arg CreateAccountCredentialsParams) (AccountCredential, error) { row := q.db.QueryRow(ctx, createAccountCredentials, - arg.ClientID, arg.AccountID, arg.AccountPublicID, - arg.CredentialsType, - arg.Name, - arg.Scopes, - arg.TokenEndpointAuthMethod, arg.Domain, - arg.ClientUri, + arg.CreationMethod, + arg.Transport, + arg.ClientID, arg.RedirectUris, + arg.TokenEndpointAuthMethod, + arg.GrantTypes, + arg.ResponseTypes, + arg.ClientName, + arg.ClientUri, arg.LogoUri, - arg.PolicyUri, + arg.Scopes, + arg.Contacts, arg.TosUri, + arg.PolicyUri, + arg.JwksUri, + arg.Jwks, arg.SoftwareID, arg.SoftwareVersion, - arg.Contacts, - arg.CreationMethod, - arg.Transport, + arg.CredentialsType, + arg.SectorIdentifierUri, + arg.SubjectType, + arg.IDTokenSignedResponseAlg, + arg.IDTokenEncryptedResponseAlg, + arg.IDTokenEncryptedResponseEnc, + arg.UserinfoSignedResponseAlg, + arg.UserinfoEncryptedResponseAlg, + arg.UserinfoEncryptedResponseEnc, + arg.RequestObjectSigningAlg, + arg.RequestObjectEncryptionAlg, + arg.RequestObjectEncryptionEnc, + arg.TokenEndpointAuthSigningAlg, + arg.DefaultMaxAge, + arg.RequireAuthTime, + arg.DefaultAcrValues, + arg.InitiateLoginUri, + arg.RequestUris, + arg.AccessTokenSigningAlg, ) var i AccountCredential err := row.Scan( &i.ID, &i.AccountID, &i.AccountPublicID, - &i.ClientID, - &i.Name, &i.Domain, - &i.CredentialsType, - &i.Scopes, + &i.CreationMethod, + &i.Transport, + &i.Version, + &i.ClientID, + &i.RedirectUris, &i.TokenEndpointAuthMethod, &i.GrantTypes, - &i.Version, - &i.Transport, - &i.CreationMethod, + &i.ResponseTypes, + &i.ClientName, &i.ClientUri, - &i.RedirectUris, &i.LogoUri, - &i.PolicyUri, + &i.Scopes, + &i.Contacts, &i.TosUri, + &i.PolicyUri, + &i.JwksUri, + &i.Jwks, &i.SoftwareID, &i.SoftwareVersion, - &i.Contacts, + &i.CredentialsType, + &i.SectorIdentifierUri, + &i.SubjectType, + &i.IDTokenSignedResponseAlg, + &i.IDTokenEncryptedResponseAlg, + &i.IDTokenEncryptedResponseEnc, + &i.UserinfoSignedResponseAlg, + &i.UserinfoEncryptedResponseAlg, + &i.UserinfoEncryptedResponseEnc, + &i.RequestObjectSigningAlg, + &i.RequestObjectEncryptionAlg, + &i.RequestObjectEncryptionEnc, + &i.TokenEndpointAuthSigningAlg, + &i.DefaultMaxAge, + &i.RequireAuthTime, + &i.DefaultAcrValues, + &i.InitiateLoginUri, + &i.RequestUris, + &i.AccessTokenSigningAlg, &i.CreatedAt, &i.UpdatedAt, ) @@ -193,7 +302,7 @@ func (q *Queries) DeleteAllAccountCredentials(ctx context.Context) error { } const findAccountCredentialsByAccountPublicIDAndClientID = `-- name: FindAccountCredentialsByAccountPublicIDAndClientID :one -SELECT id, account_id, account_public_id, client_id, name, domain, credentials_type, scopes, token_endpoint_auth_method, grant_types, version, transport, creation_method, client_uri, redirect_uris, logo_uri, policy_uri, tos_uri, software_id, software_version, contacts, created_at, updated_at FROM "account_credentials" +SELECT id, account_id, account_public_id, domain, creation_method, transport, version, client_id, redirect_uris, token_endpoint_auth_method, grant_types, response_types, client_name, client_uri, logo_uri, scopes, contacts, tos_uri, policy_uri, jwks_uri, jwks, software_id, software_version, credentials_type, sector_identifier_uri, subject_type, id_token_signed_response_alg, id_token_encrypted_response_alg, id_token_encrypted_response_enc, userinfo_signed_response_alg, userinfo_encrypted_response_alg, userinfo_encrypted_response_enc, request_object_signing_alg, request_object_encryption_alg, request_object_encryption_enc, token_endpoint_auth_signing_alg, default_max_age, require_auth_time, default_acr_values, initiate_login_uri, request_uris, access_token_signing_alg, created_at, updated_at FROM "account_credentials" WHERE "account_public_id" = $1 AND "client_id" = $2 LIMIT 1 ` @@ -210,24 +319,45 @@ func (q *Queries) FindAccountCredentialsByAccountPublicIDAndClientID(ctx context &i.ID, &i.AccountID, &i.AccountPublicID, - &i.ClientID, - &i.Name, &i.Domain, - &i.CredentialsType, - &i.Scopes, + &i.CreationMethod, + &i.Transport, + &i.Version, + &i.ClientID, + &i.RedirectUris, &i.TokenEndpointAuthMethod, &i.GrantTypes, - &i.Version, - &i.Transport, - &i.CreationMethod, + &i.ResponseTypes, + &i.ClientName, &i.ClientUri, - &i.RedirectUris, &i.LogoUri, - &i.PolicyUri, + &i.Scopes, + &i.Contacts, &i.TosUri, + &i.PolicyUri, + &i.JwksUri, + &i.Jwks, &i.SoftwareID, &i.SoftwareVersion, - &i.Contacts, + &i.CredentialsType, + &i.SectorIdentifierUri, + &i.SubjectType, + &i.IDTokenSignedResponseAlg, + &i.IDTokenEncryptedResponseAlg, + &i.IDTokenEncryptedResponseEnc, + &i.UserinfoSignedResponseAlg, + &i.UserinfoEncryptedResponseAlg, + &i.UserinfoEncryptedResponseEnc, + &i.RequestObjectSigningAlg, + &i.RequestObjectEncryptionAlg, + &i.RequestObjectEncryptionEnc, + &i.TokenEndpointAuthSigningAlg, + &i.DefaultMaxAge, + &i.RequireAuthTime, + &i.DefaultAcrValues, + &i.InitiateLoginUri, + &i.RequestUris, + &i.AccessTokenSigningAlg, &i.CreatedAt, &i.UpdatedAt, ) @@ -236,7 +366,7 @@ func (q *Queries) FindAccountCredentialsByAccountPublicIDAndClientID(ctx context const findAccountCredentialsByClientID = `-- name: FindAccountCredentialsByClientID :one -SELECT id, account_id, account_public_id, client_id, name, domain, credentials_type, scopes, token_endpoint_auth_method, grant_types, version, transport, creation_method, client_uri, redirect_uris, logo_uri, policy_uri, tos_uri, software_id, software_version, contacts, created_at, updated_at FROM "account_credentials" +SELECT id, account_id, account_public_id, domain, creation_method, transport, version, client_id, redirect_uris, token_endpoint_auth_method, grant_types, response_types, client_name, client_uri, logo_uri, scopes, contacts, tos_uri, policy_uri, jwks_uri, jwks, software_id, software_version, credentials_type, sector_identifier_uri, subject_type, id_token_signed_response_alg, id_token_encrypted_response_alg, id_token_encrypted_response_enc, userinfo_signed_response_alg, userinfo_encrypted_response_alg, userinfo_encrypted_response_enc, request_object_signing_alg, request_object_encryption_alg, request_object_encryption_enc, token_endpoint_auth_signing_alg, default_max_age, require_auth_time, default_acr_values, initiate_login_uri, request_uris, access_token_signing_alg, created_at, updated_at FROM "account_credentials" WHERE "client_id" = $1 LIMIT 1 ` @@ -253,24 +383,45 @@ func (q *Queries) FindAccountCredentialsByClientID(ctx context.Context, clientID &i.ID, &i.AccountID, &i.AccountPublicID, - &i.ClientID, - &i.Name, &i.Domain, - &i.CredentialsType, - &i.Scopes, + &i.CreationMethod, + &i.Transport, + &i.Version, + &i.ClientID, + &i.RedirectUris, &i.TokenEndpointAuthMethod, &i.GrantTypes, - &i.Version, - &i.Transport, - &i.CreationMethod, + &i.ResponseTypes, + &i.ClientName, &i.ClientUri, - &i.RedirectUris, &i.LogoUri, - &i.PolicyUri, + &i.Scopes, + &i.Contacts, &i.TosUri, + &i.PolicyUri, + &i.JwksUri, + &i.Jwks, &i.SoftwareID, &i.SoftwareVersion, - &i.Contacts, + &i.CredentialsType, + &i.SectorIdentifierUri, + &i.SubjectType, + &i.IDTokenSignedResponseAlg, + &i.IDTokenEncryptedResponseAlg, + &i.IDTokenEncryptedResponseEnc, + &i.UserinfoSignedResponseAlg, + &i.UserinfoEncryptedResponseAlg, + &i.UserinfoEncryptedResponseEnc, + &i.RequestObjectSigningAlg, + &i.RequestObjectEncryptionAlg, + &i.RequestObjectEncryptionEnc, + &i.TokenEndpointAuthSigningAlg, + &i.DefaultMaxAge, + &i.RequireAuthTime, + &i.DefaultAcrValues, + &i.InitiateLoginUri, + &i.RequestUris, + &i.AccessTokenSigningAlg, &i.CreatedAt, &i.UpdatedAt, ) @@ -278,7 +429,7 @@ func (q *Queries) FindAccountCredentialsByClientID(ctx context.Context, clientID } const findPaginatedAccountCredentialsByAccountPublicID = `-- name: FindPaginatedAccountCredentialsByAccountPublicID :many -SELECT id, account_id, account_public_id, client_id, name, domain, credentials_type, scopes, token_endpoint_auth_method, grant_types, version, transport, creation_method, client_uri, redirect_uris, logo_uri, policy_uri, tos_uri, software_id, software_version, contacts, created_at, updated_at FROM "account_credentials" +SELECT id, account_id, account_public_id, domain, creation_method, transport, version, client_id, redirect_uris, token_endpoint_auth_method, grant_types, response_types, client_name, client_uri, logo_uri, scopes, contacts, tos_uri, policy_uri, jwks_uri, jwks, software_id, software_version, credentials_type, sector_identifier_uri, subject_type, id_token_signed_response_alg, id_token_encrypted_response_alg, id_token_encrypted_response_enc, userinfo_signed_response_alg, userinfo_encrypted_response_alg, userinfo_encrypted_response_enc, request_object_signing_alg, request_object_encryption_alg, request_object_encryption_enc, token_endpoint_auth_signing_alg, default_max_age, require_auth_time, default_acr_values, initiate_login_uri, request_uris, access_token_signing_alg, created_at, updated_at FROM "account_credentials" WHERE "account_public_id" = $1 ORDER BY "id" DESC OFFSET $2 LIMIT $3 @@ -303,24 +454,45 @@ func (q *Queries) FindPaginatedAccountCredentialsByAccountPublicID(ctx context.C &i.ID, &i.AccountID, &i.AccountPublicID, - &i.ClientID, - &i.Name, &i.Domain, - &i.CredentialsType, - &i.Scopes, + &i.CreationMethod, + &i.Transport, + &i.Version, + &i.ClientID, + &i.RedirectUris, &i.TokenEndpointAuthMethod, &i.GrantTypes, - &i.Version, - &i.Transport, - &i.CreationMethod, + &i.ResponseTypes, + &i.ClientName, &i.ClientUri, - &i.RedirectUris, &i.LogoUri, - &i.PolicyUri, + &i.Scopes, + &i.Contacts, &i.TosUri, + &i.PolicyUri, + &i.JwksUri, + &i.Jwks, &i.SoftwareID, &i.SoftwareVersion, - &i.Contacts, + &i.CredentialsType, + &i.SectorIdentifierUri, + &i.SubjectType, + &i.IDTokenSignedResponseAlg, + &i.IDTokenEncryptedResponseAlg, + &i.IDTokenEncryptedResponseEnc, + &i.UserinfoSignedResponseAlg, + &i.UserinfoEncryptedResponseAlg, + &i.UserinfoEncryptedResponseEnc, + &i.RequestObjectSigningAlg, + &i.RequestObjectEncryptionAlg, + &i.RequestObjectEncryptionEnc, + &i.TokenEndpointAuthSigningAlg, + &i.DefaultMaxAge, + &i.RequireAuthTime, + &i.DefaultAcrValues, + &i.InitiateLoginUri, + &i.RequestUris, + &i.AccessTokenSigningAlg, &i.CreatedAt, &i.UpdatedAt, ); err != nil { @@ -337,7 +509,7 @@ func (q *Queries) FindPaginatedAccountCredentialsByAccountPublicID(ctx context.C const updateAccountCredentials = `-- name: UpdateAccountCredentials :one UPDATE "account_credentials" SET "scopes" = $2, - "name" = $3, + "client_name" = $3, "domain" = $4, "client_uri" = $5, "redirect_uris" = $6, @@ -350,13 +522,13 @@ UPDATE "account_credentials" SET "version" = "version" + 1, "updated_at" = now() WHERE "id" = $1 -RETURNING id, account_id, account_public_id, client_id, name, domain, credentials_type, scopes, token_endpoint_auth_method, grant_types, version, transport, creation_method, client_uri, redirect_uris, logo_uri, policy_uri, tos_uri, software_id, software_version, contacts, created_at, updated_at +RETURNING id, account_id, account_public_id, domain, creation_method, transport, version, client_id, redirect_uris, token_endpoint_auth_method, grant_types, response_types, client_name, client_uri, logo_uri, scopes, contacts, tos_uri, policy_uri, jwks_uri, jwks, software_id, software_version, credentials_type, sector_identifier_uri, subject_type, id_token_signed_response_alg, id_token_encrypted_response_alg, id_token_encrypted_response_enc, userinfo_signed_response_alg, userinfo_encrypted_response_alg, userinfo_encrypted_response_enc, request_object_signing_alg, request_object_encryption_alg, request_object_encryption_enc, token_endpoint_auth_signing_alg, default_max_age, require_auth_time, default_acr_values, initiate_login_uri, request_uris, access_token_signing_alg, created_at, updated_at ` type UpdateAccountCredentialsParams struct { ID int32 Scopes []AccountCredentialsScope - Name string + ClientName string Domain string ClientUri string RedirectUris []string @@ -372,7 +544,7 @@ func (q *Queries) UpdateAccountCredentials(ctx context.Context, arg UpdateAccoun row := q.db.QueryRow(ctx, updateAccountCredentials, arg.ID, arg.Scopes, - arg.Name, + arg.ClientName, arg.Domain, arg.ClientUri, arg.RedirectUris, @@ -388,24 +560,45 @@ func (q *Queries) UpdateAccountCredentials(ctx context.Context, arg UpdateAccoun &i.ID, &i.AccountID, &i.AccountPublicID, - &i.ClientID, - &i.Name, &i.Domain, - &i.CredentialsType, - &i.Scopes, + &i.CreationMethod, + &i.Transport, + &i.Version, + &i.ClientID, + &i.RedirectUris, &i.TokenEndpointAuthMethod, &i.GrantTypes, - &i.Version, - &i.Transport, - &i.CreationMethod, + &i.ResponseTypes, + &i.ClientName, &i.ClientUri, - &i.RedirectUris, &i.LogoUri, - &i.PolicyUri, + &i.Scopes, + &i.Contacts, &i.TosUri, + &i.PolicyUri, + &i.JwksUri, + &i.Jwks, &i.SoftwareID, &i.SoftwareVersion, - &i.Contacts, + &i.CredentialsType, + &i.SectorIdentifierUri, + &i.SubjectType, + &i.IDTokenSignedResponseAlg, + &i.IDTokenEncryptedResponseAlg, + &i.IDTokenEncryptedResponseEnc, + &i.UserinfoSignedResponseAlg, + &i.UserinfoEncryptedResponseAlg, + &i.UserinfoEncryptedResponseEnc, + &i.RequestObjectSigningAlg, + &i.RequestObjectEncryptionAlg, + &i.RequestObjectEncryptionEnc, + &i.TokenEndpointAuthSigningAlg, + &i.DefaultMaxAge, + &i.RequireAuthTime, + &i.DefaultAcrValues, + &i.InitiateLoginUri, + &i.RequestUris, + &i.AccessTokenSigningAlg, &i.CreatedAt, &i.UpdatedAt, ) diff --git a/idp/internal/providers/database/account_credentials_keys.sql.go b/idp/internal/providers/database/account_credentials_keys.sql.go index 26f8441..706474a 100644 --- a/idp/internal/providers/database/account_credentials_keys.sql.go +++ b/idp/internal/providers/database/account_credentials_keys.sql.go @@ -67,7 +67,7 @@ func (q *Queries) CreateAccountCredentialKey(ctx context.Context, arg CreateAcco } const findAccountCredentialKeyByAccountCredentialIDAndPublicKID = `-- name: FindAccountCredentialKeyByAccountCredentialIDAndPublicKID :one -SELECT ckr.id, ckr.public_kid, ckr.public_key, ckr.crypto_suite, ckr.is_revoked, ckr.usage, ckr.account_id, ckr.expires_at, ckr.created_at, ckr.updated_at FROM "credentials_keys" "ckr" +SELECT ckr.id, ckr.public_kid, ckr.public_key, ckr.crypto_suite, ckr.is_revoked, ckr.is_external, ckr.usage, ckr.account_id, ckr.expires_at, ckr.created_at, ckr.updated_at FROM "credentials_keys" "ckr" LEFT JOIN "account_credentials_keys" "ack" ON "ack"."credentials_key_id" = "ckr"."id" WHERE "ack"."account_credentials_id" = $1 AND @@ -89,6 +89,7 @@ func (q *Queries) FindAccountCredentialKeyByAccountCredentialIDAndPublicKID(ctx &i.PublicKey, &i.CryptoSuite, &i.IsRevoked, + &i.IsExternal, &i.Usage, &i.AccountID, &i.ExpiresAt, @@ -134,7 +135,7 @@ func (q *Queries) FindAccountCredentialsKeyAccountByAccountCredentialIDAndJWKKID } const findActiveAccountCredentialKeysByAccountPublicID = `-- name: FindActiveAccountCredentialKeysByAccountPublicID :many -SELECT ckr.id, ckr.public_kid, ckr.public_key, ckr.crypto_suite, ckr.is_revoked, ckr.usage, ckr.account_id, ckr.expires_at, ckr.created_at, ckr.updated_at FROM "credentials_keys" "ckr" +SELECT ckr.id, ckr.public_kid, ckr.public_key, ckr.crypto_suite, ckr.is_revoked, ckr.is_external, ckr.usage, ckr.account_id, ckr.expires_at, ckr.created_at, ckr.updated_at FROM "credentials_keys" "ckr" LEFT JOIN "account_credentials_keys" "ack" ON "ack"."credentials_key_id" = "ckr"."id" WHERE "ack"."account_public_id" = $1 AND @@ -158,6 +159,7 @@ func (q *Queries) FindActiveAccountCredentialKeysByAccountPublicID(ctx context.C &i.PublicKey, &i.CryptoSuite, &i.IsRevoked, + &i.IsExternal, &i.Usage, &i.AccountID, &i.ExpiresAt, @@ -175,7 +177,7 @@ func (q *Queries) FindActiveAccountCredentialKeysByAccountPublicID(ctx context.C } const findCurrentAccountCredentialKeyByAccountCredentialID = `-- name: FindCurrentAccountCredentialKeyByAccountCredentialID :one -SELECT ckr.id, ckr.public_kid, ckr.public_key, ckr.crypto_suite, ckr.is_revoked, ckr.usage, ckr.account_id, ckr.expires_at, ckr.created_at, ckr.updated_at FROM "credentials_keys" "ckr" +SELECT ckr.id, ckr.public_kid, ckr.public_key, ckr.crypto_suite, ckr.is_revoked, ckr.is_external, ckr.usage, ckr.account_id, ckr.expires_at, ckr.created_at, ckr.updated_at FROM "credentials_keys" "ckr" LEFT JOIN "account_credentials_keys" "ack" ON "ack"."credentials_key_id" = "ckr"."id" WHERE "ack"."account_credentials_id" = $1 AND @@ -193,6 +195,7 @@ func (q *Queries) FindCurrentAccountCredentialKeyByAccountCredentialID(ctx conte &i.PublicKey, &i.CryptoSuite, &i.IsRevoked, + &i.IsExternal, &i.Usage, &i.AccountID, &i.ExpiresAt, @@ -203,7 +206,7 @@ func (q *Queries) FindCurrentAccountCredentialKeyByAccountCredentialID(ctx conte } const findPaginatedAccountCredentialKeysByAccountCredentialID = `-- name: FindPaginatedAccountCredentialKeysByAccountCredentialID :many -SELECT ckr.id, ckr.public_kid, ckr.public_key, ckr.crypto_suite, ckr.is_revoked, ckr.usage, ckr.account_id, ckr.expires_at, ckr.created_at, ckr.updated_at FROM "credentials_keys" "ckr" +SELECT ckr.id, ckr.public_kid, ckr.public_key, ckr.crypto_suite, ckr.is_revoked, ckr.is_external, ckr.usage, ckr.account_id, ckr.expires_at, ckr.created_at, ckr.updated_at FROM "credentials_keys" "ckr" LEFT JOIN "account_credentials_keys" "ack" ON "ack"."credentials_key_id" = "ckr"."id" WHERE "ack"."account_credentials_id" = $1 ORDER BY "ckr"."expires_at" DESC @@ -231,6 +234,7 @@ func (q *Queries) FindPaginatedAccountCredentialKeysByAccountCredentialID(ctx co &i.PublicKey, &i.CryptoSuite, &i.IsRevoked, + &i.IsExternal, &i.Usage, &i.AccountID, &i.ExpiresAt, diff --git a/idp/internal/providers/database/account_dynamic_registration_configs.sql.go b/idp/internal/providers/database/account_dynamic_registration_configs.sql.go index 8a0e69a..a850e8d 100644 --- a/idp/internal/providers/database/account_dynamic_registration_configs.sql.go +++ b/idp/internal/providers/database/account_dynamic_registration_configs.sql.go @@ -17,7 +17,6 @@ INSERT INTO "account_dynamic_registration_configs" ( "account_id", "account_public_id", "account_credentials_types", - "whitelisted_domains", "require_software_statement_credential_types", "software_statement_verification_methods", "require_initial_access_token_credential_types", @@ -29,16 +28,14 @@ INSERT INTO "account_dynamic_registration_configs" ( $4, $5, $6, - $7, - $8 -) RETURNING id, account_id, account_public_id, account_credentials_types, whitelisted_domains, require_software_statement_credential_types, software_statement_verification_methods, require_initial_access_token_credential_types, initial_access_token_generation_methods, created_at, updated_at + $7 +) RETURNING id, account_id, account_public_id, account_credentials_types, require_software_statement_credential_types, software_statement_verification_methods, require_verified_domains_credentials_type, require_initial_access_token_credential_types, initial_access_token_generation_methods, created_at, updated_at ` type CreateAccountDynamicRegistrationConfigParams struct { AccountID int32 AccountPublicID uuid.UUID AccountCredentialsTypes []AccountCredentialsType - WhitelistedDomains []string RequireSoftwareStatementCredentialTypes []AccountCredentialsType SoftwareStatementVerificationMethods []SoftwareStatementVerificationMethod RequireInitialAccessTokenCredentialTypes []AccountCredentialsType @@ -55,7 +52,6 @@ func (q *Queries) CreateAccountDynamicRegistrationConfig(ctx context.Context, ar arg.AccountID, arg.AccountPublicID, arg.AccountCredentialsTypes, - arg.WhitelistedDomains, arg.RequireSoftwareStatementCredentialTypes, arg.SoftwareStatementVerificationMethods, arg.RequireInitialAccessTokenCredentialTypes, @@ -67,9 +63,9 @@ func (q *Queries) CreateAccountDynamicRegistrationConfig(ctx context.Context, ar &i.AccountID, &i.AccountPublicID, &i.AccountCredentialsTypes, - &i.WhitelistedDomains, &i.RequireSoftwareStatementCredentialTypes, &i.SoftwareStatementVerificationMethods, + &i.RequireVerifiedDomainsCredentialsType, &i.RequireInitialAccessTokenCredentialTypes, &i.InitialAccessTokenGenerationMethods, &i.CreatedAt, @@ -88,7 +84,7 @@ func (q *Queries) DeleteAccountDynamicRegistrationConfig(ctx context.Context, id } const findAccountDynamicRegistrationConfigByAccountID = `-- name: FindAccountDynamicRegistrationConfigByAccountID :one -SELECT id, account_id, account_public_id, account_credentials_types, whitelisted_domains, require_software_statement_credential_types, software_statement_verification_methods, require_initial_access_token_credential_types, initial_access_token_generation_methods, created_at, updated_at FROM "account_dynamic_registration_configs" +SELECT id, account_id, account_public_id, account_credentials_types, require_software_statement_credential_types, software_statement_verification_methods, require_verified_domains_credentials_type, require_initial_access_token_credential_types, initial_access_token_generation_methods, created_at, updated_at FROM "account_dynamic_registration_configs" WHERE "account_id" = $1 LIMIT 1 ` @@ -100,9 +96,9 @@ func (q *Queries) FindAccountDynamicRegistrationConfigByAccountID(ctx context.Co &i.AccountID, &i.AccountPublicID, &i.AccountCredentialsTypes, - &i.WhitelistedDomains, &i.RequireSoftwareStatementCredentialTypes, &i.SoftwareStatementVerificationMethods, + &i.RequireVerifiedDomainsCredentialsType, &i.RequireInitialAccessTokenCredentialTypes, &i.InitialAccessTokenGenerationMethods, &i.CreatedAt, @@ -112,7 +108,7 @@ func (q *Queries) FindAccountDynamicRegistrationConfigByAccountID(ctx context.Co } const findAccountDynamicRegistrationConfigByAccountPublicID = `-- name: FindAccountDynamicRegistrationConfigByAccountPublicID :one -SELECT id, account_id, account_public_id, account_credentials_types, whitelisted_domains, require_software_statement_credential_types, software_statement_verification_methods, require_initial_access_token_credential_types, initial_access_token_generation_methods, created_at, updated_at FROM "account_dynamic_registration_configs" +SELECT id, account_id, account_public_id, account_credentials_types, require_software_statement_credential_types, software_statement_verification_methods, require_verified_domains_credentials_type, require_initial_access_token_credential_types, initial_access_token_generation_methods, created_at, updated_at FROM "account_dynamic_registration_configs" WHERE "account_public_id" = $1 LIMIT 1 ` @@ -124,9 +120,9 @@ func (q *Queries) FindAccountDynamicRegistrationConfigByAccountPublicID(ctx cont &i.AccountID, &i.AccountPublicID, &i.AccountCredentialsTypes, - &i.WhitelistedDomains, &i.RequireSoftwareStatementCredentialTypes, &i.SoftwareStatementVerificationMethods, + &i.RequireVerifiedDomainsCredentialsType, &i.RequireInitialAccessTokenCredentialTypes, &i.InitialAccessTokenGenerationMethods, &i.CreatedAt, @@ -138,19 +134,17 @@ func (q *Queries) FindAccountDynamicRegistrationConfigByAccountPublicID(ctx cont const updateAccountDynamicRegistrationConfig = `-- name: UpdateAccountDynamicRegistrationConfig :one UPDATE "account_dynamic_registration_configs" SET "account_credentials_types" = $2, - "whitelisted_domains" = $3, - "require_software_statement_credential_types" = $4, - "software_statement_verification_methods" = $5, - "require_initial_access_token_credential_types" = $6, - "initial_access_token_generation_methods" = $7 + "require_software_statement_credential_types" = $3, + "software_statement_verification_methods" = $4, + "require_initial_access_token_credential_types" = $5, + "initial_access_token_generation_methods" = $6 WHERE "id" = $1 -RETURNING id, account_id, account_public_id, account_credentials_types, whitelisted_domains, require_software_statement_credential_types, software_statement_verification_methods, require_initial_access_token_credential_types, initial_access_token_generation_methods, created_at, updated_at +RETURNING id, account_id, account_public_id, account_credentials_types, require_software_statement_credential_types, software_statement_verification_methods, require_verified_domains_credentials_type, require_initial_access_token_credential_types, initial_access_token_generation_methods, created_at, updated_at ` type UpdateAccountDynamicRegistrationConfigParams struct { ID int32 AccountCredentialsTypes []AccountCredentialsType - WhitelistedDomains []string RequireSoftwareStatementCredentialTypes []AccountCredentialsType SoftwareStatementVerificationMethods []SoftwareStatementVerificationMethod RequireInitialAccessTokenCredentialTypes []AccountCredentialsType @@ -161,7 +155,6 @@ func (q *Queries) UpdateAccountDynamicRegistrationConfig(ctx context.Context, ar row := q.db.QueryRow(ctx, updateAccountDynamicRegistrationConfig, arg.ID, arg.AccountCredentialsTypes, - arg.WhitelistedDomains, arg.RequireSoftwareStatementCredentialTypes, arg.SoftwareStatementVerificationMethods, arg.RequireInitialAccessTokenCredentialTypes, @@ -173,9 +166,9 @@ func (q *Queries) UpdateAccountDynamicRegistrationConfig(ctx context.Context, ar &i.AccountID, &i.AccountPublicID, &i.AccountCredentialsTypes, - &i.WhitelistedDomains, &i.RequireSoftwareStatementCredentialTypes, &i.SoftwareStatementVerificationMethods, + &i.RequireVerifiedDomainsCredentialsType, &i.RequireInitialAccessTokenCredentialTypes, &i.InitialAccessTokenGenerationMethods, &i.CreatedAt, diff --git a/idp/internal/providers/database/account_dynamic_registration_domain_codes.sql.go b/idp/internal/providers/database/account_dynamic_registration_domain_codes.sql.go deleted file mode 100644 index 1abe6eb..0000000 --- a/idp/internal/providers/database/account_dynamic_registration_domain_codes.sql.go +++ /dev/null @@ -1,63 +0,0 @@ -// Code generated by sqlc. DO NOT EDIT. -// versions: -// sqlc v1.29.0 -// source: account_dynamic_registration_domain_codes.sql - -package database - -import ( - "context" -) - -const createAccountDynamicRegistrationDomainCode = `-- name: CreateAccountDynamicRegistrationDomainCode :exec - -INSERT INTO "account_dynamic_registration_domain_codes" ( - "account_dynamic_registration_domain_id", - "dynamic_registration_domain_code_id", - "account_id" -) VALUES ( - $1, - $2, - $3 -) -` - -type CreateAccountDynamicRegistrationDomainCodeParams struct { - AccountDynamicRegistrationDomainID int32 - DynamicRegistrationDomainCodeID int32 - AccountID int32 -} - -// Copyright (c) 2025 Afonso Barracha -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at https://mozilla.org/MPL/2.0/. -func (q *Queries) CreateAccountDynamicRegistrationDomainCode(ctx context.Context, arg CreateAccountDynamicRegistrationDomainCodeParams) error { - _, err := q.db.Exec(ctx, createAccountDynamicRegistrationDomainCode, arg.AccountDynamicRegistrationDomainID, arg.DynamicRegistrationDomainCodeID, arg.AccountID) - return err -} - -const findDynamicRegistrationDomainCodeByAccountDynamicRegistrationDomainID = `-- name: FindDynamicRegistrationDomainCodeByAccountDynamicRegistrationDomainID :one -SELECT d.id, d.account_id, d.verification_host, d.verification_code, d.hmac_secret_id, d.verification_prefix, d.expires_at, d.created_at, d.updated_at FROM "dynamic_registration_domain_codes" "d" -LEFT JOIN "account_dynamic_registration_domain_codes" "a" ON "d"."id" = "a"."dynamic_registration_domain_code_id" -WHERE "a"."account_dynamic_registration_domain_id" = $1 -LIMIT 1 -` - -func (q *Queries) FindDynamicRegistrationDomainCodeByAccountDynamicRegistrationDomainID(ctx context.Context, accountDynamicRegistrationDomainID int32) (DynamicRegistrationDomainCode, error) { - row := q.db.QueryRow(ctx, findDynamicRegistrationDomainCodeByAccountDynamicRegistrationDomainID, accountDynamicRegistrationDomainID) - var i DynamicRegistrationDomainCode - err := row.Scan( - &i.ID, - &i.AccountID, - &i.VerificationHost, - &i.VerificationCode, - &i.HmacSecretID, - &i.VerificationPrefix, - &i.ExpiresAt, - &i.CreatedAt, - &i.UpdatedAt, - ) - return i, err -} diff --git a/idp/internal/providers/database/account_dynamic_registration_domains.sql.go b/idp/internal/providers/database/account_dynamic_registration_domains.sql.go deleted file mode 100644 index 4137186..0000000 --- a/idp/internal/providers/database/account_dynamic_registration_domains.sql.go +++ /dev/null @@ -1,408 +0,0 @@ -// Code generated by sqlc. DO NOT EDIT. -// versions: -// sqlc v1.29.0 -// source: account_dynamic_registration_domains.sql - -package database - -import ( - "context" - - "github.com/google/uuid" -) - -const countAccountDynamicRegistrationDomainsByAccountPublicID = `-- name: CountAccountDynamicRegistrationDomainsByAccountPublicID :one -SELECT COUNT(*) FROM "account_dynamic_registration_domains" -WHERE "account_public_id" = $1 -` - -func (q *Queries) CountAccountDynamicRegistrationDomainsByAccountPublicID(ctx context.Context, accountPublicID uuid.UUID) (int64, error) { - row := q.db.QueryRow(ctx, countAccountDynamicRegistrationDomainsByAccountPublicID, accountPublicID) - var count int64 - err := row.Scan(&count) - return count, err -} - -const countFilteredAccountDynamicRegistrationDomainsByAccountPublicID = `-- name: CountFilteredAccountDynamicRegistrationDomainsByAccountPublicID :one -SELECT COUNT(*) FROM "account_dynamic_registration_domains" -WHERE - "account_public_id" = $1 AND - "domain" ILIKE $2 -LIMIT 1 -` - -type CountFilteredAccountDynamicRegistrationDomainsByAccountPublicIDParams struct { - AccountPublicID uuid.UUID - Domain string -} - -func (q *Queries) CountFilteredAccountDynamicRegistrationDomainsByAccountPublicID(ctx context.Context, arg CountFilteredAccountDynamicRegistrationDomainsByAccountPublicIDParams) (int64, error) { - row := q.db.QueryRow(ctx, countFilteredAccountDynamicRegistrationDomainsByAccountPublicID, arg.AccountPublicID, arg.Domain) - var count int64 - err := row.Scan(&count) - return count, err -} - -const countVerifiedAccountDynamicRegistrationDomainsByDomain = `-- name: CountVerifiedAccountDynamicRegistrationDomainsByDomain :one -SELECT COUNT(*) FROM "account_dynamic_registration_domains" -WHERE "domain" = $1 AND "verified_at" IS NOT NULL -LIMIT 1 -` - -func (q *Queries) CountVerifiedAccountDynamicRegistrationDomainsByDomain(ctx context.Context, domain string) (int64, error) { - row := q.db.QueryRow(ctx, countVerifiedAccountDynamicRegistrationDomainsByDomain, domain) - var count int64 - err := row.Scan(&count) - return count, err -} - -const countVerifiedAccountDynamicRegistrationDomainsByDomainAndAccountPublicID = `-- name: CountVerifiedAccountDynamicRegistrationDomainsByDomainAndAccountPublicID :one -SELECT COUNT(*) FROM "account_dynamic_registration_domains" -WHERE - "account_public_id" = $1 AND - "domain" = $2 AND - "verified_at" IS NOT NULL -LIMIT 1 -` - -type CountVerifiedAccountDynamicRegistrationDomainsByDomainAndAccountPublicIDParams struct { - AccountPublicID uuid.UUID - Domain string -} - -func (q *Queries) CountVerifiedAccountDynamicRegistrationDomainsByDomainAndAccountPublicID(ctx context.Context, arg CountVerifiedAccountDynamicRegistrationDomainsByDomainAndAccountPublicIDParams) (int64, error) { - row := q.db.QueryRow(ctx, countVerifiedAccountDynamicRegistrationDomainsByDomainAndAccountPublicID, arg.AccountPublicID, arg.Domain) - var count int64 - err := row.Scan(&count) - return count, err -} - -const countVerifiedAccountDynamicRegistrationDomainsByDomains = `-- name: CountVerifiedAccountDynamicRegistrationDomainsByDomains :one -SELECT COUNT(*) FROM "account_dynamic_registration_domains" -WHERE "domain" IN ($1) AND "verified_at" IS NOT NULL -LIMIT 1 -` - -func (q *Queries) CountVerifiedAccountDynamicRegistrationDomainsByDomains(ctx context.Context, domains []string) (int64, error) { - row := q.db.QueryRow(ctx, countVerifiedAccountDynamicRegistrationDomainsByDomains, domains) - var count int64 - err := row.Scan(&count) - return count, err -} - -const countVerifiedAccountDynamicRegistrationDomainsByDomainsAndAccountPublicID = `-- name: CountVerifiedAccountDynamicRegistrationDomainsByDomainsAndAccountPublicID :one -SELECT COUNT(*) FROM "account_dynamic_registration_domains" -WHERE - "account_public_id" = $1 AND - "domain" IN ($2) AND - "verified_at" IS NOT NULL -LIMIT 1 -` - -type CountVerifiedAccountDynamicRegistrationDomainsByDomainsAndAccountPublicIDParams struct { - AccountPublicID uuid.UUID - Domains []string -} - -func (q *Queries) CountVerifiedAccountDynamicRegistrationDomainsByDomainsAndAccountPublicID(ctx context.Context, arg CountVerifiedAccountDynamicRegistrationDomainsByDomainsAndAccountPublicIDParams) (int64, error) { - row := q.db.QueryRow(ctx, countVerifiedAccountDynamicRegistrationDomainsByDomainsAndAccountPublicID, arg.AccountPublicID, arg.Domains) - var count int64 - err := row.Scan(&count) - return count, err -} - -const createAccountDynamicRegistrationDomain = `-- name: CreateAccountDynamicRegistrationDomain :one - -INSERT INTO "account_dynamic_registration_domains" ( - "account_id", - "account_public_id", - "domain", - "verification_method" -) VALUES ( - $1, - $2, - $3, - $4 -) RETURNING id, account_id, account_public_id, domain, verified_at, verification_method, created_at, updated_at -` - -type CreateAccountDynamicRegistrationDomainParams struct { - AccountID int32 - AccountPublicID uuid.UUID - Domain string - VerificationMethod DomainVerificationMethod -} - -// Copyright (c) 2025 Afonso Barracha -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at https://mozilla.org/MPL/2.0/. -func (q *Queries) CreateAccountDynamicRegistrationDomain(ctx context.Context, arg CreateAccountDynamicRegistrationDomainParams) (AccountDynamicRegistrationDomain, error) { - row := q.db.QueryRow(ctx, createAccountDynamicRegistrationDomain, - arg.AccountID, - arg.AccountPublicID, - arg.Domain, - arg.VerificationMethod, - ) - var i AccountDynamicRegistrationDomain - err := row.Scan( - &i.ID, - &i.AccountID, - &i.AccountPublicID, - &i.Domain, - &i.VerifiedAt, - &i.VerificationMethod, - &i.CreatedAt, - &i.UpdatedAt, - ) - return i, err -} - -const deleteAccountDynamicRegistrationDomain = `-- name: DeleteAccountDynamicRegistrationDomain :exec -DELETE FROM "account_dynamic_registration_domains" -WHERE "id" = $1 -` - -func (q *Queries) DeleteAccountDynamicRegistrationDomain(ctx context.Context, id int32) error { - _, err := q.db.Exec(ctx, deleteAccountDynamicRegistrationDomain, id) - return err -} - -const filterAccountDynamicRegistrationDomainsByAccountPublicIDOrderedByDomain = `-- name: FilterAccountDynamicRegistrationDomainsByAccountPublicIDOrderedByDomain :many -SELECT id, account_id, account_public_id, domain, verified_at, verification_method, created_at, updated_at FROM "account_dynamic_registration_domains" -WHERE - "account_public_id" = $1 AND - "domain" ILIKE $2 -ORDER BY "domain" ASC -LIMIT $3 OFFSET $4 -` - -type FilterAccountDynamicRegistrationDomainsByAccountPublicIDOrderedByDomainParams struct { - AccountPublicID uuid.UUID - Domain string - Limit int32 - Offset int32 -} - -func (q *Queries) FilterAccountDynamicRegistrationDomainsByAccountPublicIDOrderedByDomain(ctx context.Context, arg FilterAccountDynamicRegistrationDomainsByAccountPublicIDOrderedByDomainParams) ([]AccountDynamicRegistrationDomain, error) { - rows, err := q.db.Query(ctx, filterAccountDynamicRegistrationDomainsByAccountPublicIDOrderedByDomain, - arg.AccountPublicID, - arg.Domain, - arg.Limit, - arg.Offset, - ) - if err != nil { - return nil, err - } - defer rows.Close() - items := []AccountDynamicRegistrationDomain{} - for rows.Next() { - var i AccountDynamicRegistrationDomain - if err := rows.Scan( - &i.ID, - &i.AccountID, - &i.AccountPublicID, - &i.Domain, - &i.VerifiedAt, - &i.VerificationMethod, - &i.CreatedAt, - &i.UpdatedAt, - ); err != nil { - return nil, err - } - items = append(items, i) - } - if err := rows.Err(); err != nil { - return nil, err - } - return items, nil -} - -const filterAccountDynamicRegistrationDomainsByAccountPublicIDOrderedByID = `-- name: FilterAccountDynamicRegistrationDomainsByAccountPublicIDOrderedByID :many -SELECT id, account_id, account_public_id, domain, verified_at, verification_method, created_at, updated_at FROM "account_dynamic_registration_domains" -WHERE - "account_public_id" = $1 AND - "domain" ILIKE $2 -ORDER BY "id" DESC -LIMIT $3 OFFSET $4 -` - -type FilterAccountDynamicRegistrationDomainsByAccountPublicIDOrderedByIDParams struct { - AccountPublicID uuid.UUID - Domain string - Limit int32 - Offset int32 -} - -func (q *Queries) FilterAccountDynamicRegistrationDomainsByAccountPublicIDOrderedByID(ctx context.Context, arg FilterAccountDynamicRegistrationDomainsByAccountPublicIDOrderedByIDParams) ([]AccountDynamicRegistrationDomain, error) { - rows, err := q.db.Query(ctx, filterAccountDynamicRegistrationDomainsByAccountPublicIDOrderedByID, - arg.AccountPublicID, - arg.Domain, - arg.Limit, - arg.Offset, - ) - if err != nil { - return nil, err - } - defer rows.Close() - items := []AccountDynamicRegistrationDomain{} - for rows.Next() { - var i AccountDynamicRegistrationDomain - if err := rows.Scan( - &i.ID, - &i.AccountID, - &i.AccountPublicID, - &i.Domain, - &i.VerifiedAt, - &i.VerificationMethod, - &i.CreatedAt, - &i.UpdatedAt, - ); err != nil { - return nil, err - } - items = append(items, i) - } - if err := rows.Err(); err != nil { - return nil, err - } - return items, nil -} - -const findAccountDynamicRegistrationDomainByAccountPublicIDAndDomain = `-- name: FindAccountDynamicRegistrationDomainByAccountPublicIDAndDomain :one -SELECT id, account_id, account_public_id, domain, verified_at, verification_method, created_at, updated_at FROM "account_dynamic_registration_domains" WHERE "account_public_id" = $1 AND "domain" = $2 LIMIT 1 -` - -type FindAccountDynamicRegistrationDomainByAccountPublicIDAndDomainParams struct { - AccountPublicID uuid.UUID - Domain string -} - -func (q *Queries) FindAccountDynamicRegistrationDomainByAccountPublicIDAndDomain(ctx context.Context, arg FindAccountDynamicRegistrationDomainByAccountPublicIDAndDomainParams) (AccountDynamicRegistrationDomain, error) { - row := q.db.QueryRow(ctx, findAccountDynamicRegistrationDomainByAccountPublicIDAndDomain, arg.AccountPublicID, arg.Domain) - var i AccountDynamicRegistrationDomain - err := row.Scan( - &i.ID, - &i.AccountID, - &i.AccountPublicID, - &i.Domain, - &i.VerifiedAt, - &i.VerificationMethod, - &i.CreatedAt, - &i.UpdatedAt, - ) - return i, err -} - -const findPaginatedAccountDynamicRegistrationDomainsByAccountPublicIDOrderedByDomain = `-- name: FindPaginatedAccountDynamicRegistrationDomainsByAccountPublicIDOrderedByDomain :many -SELECT id, account_id, account_public_id, domain, verified_at, verification_method, created_at, updated_at FROM "account_dynamic_registration_domains" -WHERE "account_public_id" = $1 -ORDER BY "domain" ASC -LIMIT $2 OFFSET $3 -` - -type FindPaginatedAccountDynamicRegistrationDomainsByAccountPublicIDOrderedByDomainParams struct { - AccountPublicID uuid.UUID - Limit int32 - Offset int32 -} - -func (q *Queries) FindPaginatedAccountDynamicRegistrationDomainsByAccountPublicIDOrderedByDomain(ctx context.Context, arg FindPaginatedAccountDynamicRegistrationDomainsByAccountPublicIDOrderedByDomainParams) ([]AccountDynamicRegistrationDomain, error) { - rows, err := q.db.Query(ctx, findPaginatedAccountDynamicRegistrationDomainsByAccountPublicIDOrderedByDomain, arg.AccountPublicID, arg.Limit, arg.Offset) - if err != nil { - return nil, err - } - defer rows.Close() - items := []AccountDynamicRegistrationDomain{} - for rows.Next() { - var i AccountDynamicRegistrationDomain - if err := rows.Scan( - &i.ID, - &i.AccountID, - &i.AccountPublicID, - &i.Domain, - &i.VerifiedAt, - &i.VerificationMethod, - &i.CreatedAt, - &i.UpdatedAt, - ); err != nil { - return nil, err - } - items = append(items, i) - } - if err := rows.Err(); err != nil { - return nil, err - } - return items, nil -} - -const findPaginatedAccountDynamicRegistrationDomainsByAccountPublicIDOrderedByID = `-- name: FindPaginatedAccountDynamicRegistrationDomainsByAccountPublicIDOrderedByID :many -SELECT id, account_id, account_public_id, domain, verified_at, verification_method, created_at, updated_at FROM "account_dynamic_registration_domains" -WHERE "account_public_id" = $1 -ORDER BY "id" DESC -LIMIT $2 OFFSET $3 -` - -type FindPaginatedAccountDynamicRegistrationDomainsByAccountPublicIDOrderedByIDParams struct { - AccountPublicID uuid.UUID - Limit int32 - Offset int32 -} - -func (q *Queries) FindPaginatedAccountDynamicRegistrationDomainsByAccountPublicIDOrderedByID(ctx context.Context, arg FindPaginatedAccountDynamicRegistrationDomainsByAccountPublicIDOrderedByIDParams) ([]AccountDynamicRegistrationDomain, error) { - rows, err := q.db.Query(ctx, findPaginatedAccountDynamicRegistrationDomainsByAccountPublicIDOrderedByID, arg.AccountPublicID, arg.Limit, arg.Offset) - if err != nil { - return nil, err - } - defer rows.Close() - items := []AccountDynamicRegistrationDomain{} - for rows.Next() { - var i AccountDynamicRegistrationDomain - if err := rows.Scan( - &i.ID, - &i.AccountID, - &i.AccountPublicID, - &i.Domain, - &i.VerifiedAt, - &i.VerificationMethod, - &i.CreatedAt, - &i.UpdatedAt, - ); err != nil { - return nil, err - } - items = append(items, i) - } - if err := rows.Err(); err != nil { - return nil, err - } - return items, nil -} - -const verifyAccountDynamicRegistrationDomain = `-- name: VerifyAccountDynamicRegistrationDomain :one -UPDATE "account_dynamic_registration_domains" -SET - "verified_at" = NOW(), - "verification_method" = $2 -WHERE "id" = $1 RETURNING id, account_id, account_public_id, domain, verified_at, verification_method, created_at, updated_at -` - -type VerifyAccountDynamicRegistrationDomainParams struct { - ID int32 - VerificationMethod DomainVerificationMethod -} - -func (q *Queries) VerifyAccountDynamicRegistrationDomain(ctx context.Context, arg VerifyAccountDynamicRegistrationDomainParams) (AccountDynamicRegistrationDomain, error) { - row := q.db.QueryRow(ctx, verifyAccountDynamicRegistrationDomain, arg.ID, arg.VerificationMethod) - var i AccountDynamicRegistrationDomain - err := row.Scan( - &i.ID, - &i.AccountID, - &i.AccountPublicID, - &i.Domain, - &i.VerifiedAt, - &i.VerificationMethod, - &i.CreatedAt, - &i.UpdatedAt, - ) - return i, err -} diff --git a/idp/internal/providers/database/app_keys.sql.go b/idp/internal/providers/database/app_keys.sql.go index 70971f3..54d8fa8 100644 --- a/idp/internal/providers/database/app_keys.sql.go +++ b/idp/internal/providers/database/app_keys.sql.go @@ -53,7 +53,7 @@ func (q *Queries) CreateAppKey(ctx context.Context, arg CreateAppKeyParams) erro } const findAppKeyByAppIDAndPublicKID = `-- name: FindAppKeyByAppIDAndPublicKID :one -SELECT ckr.id, ckr.public_kid, ckr.public_key, ckr.crypto_suite, ckr.is_revoked, ckr.usage, ckr.account_id, ckr.expires_at, ckr.created_at, ckr.updated_at FROM "credentials_keys" "ckr" +SELECT ckr.id, ckr.public_kid, ckr.public_key, ckr.crypto_suite, ckr.is_revoked, ckr.is_external, ckr.usage, ckr.account_id, ckr.expires_at, ckr.created_at, ckr.updated_at FROM "credentials_keys" "ckr" LEFT JOIN "app_keys" "ak" ON "ak"."credentials_key_id" = "ckr"."id" WHERE "ak"."app_id" = $1 AND @@ -75,6 +75,7 @@ func (q *Queries) FindAppKeyByAppIDAndPublicKID(ctx context.Context, arg FindApp &i.PublicKey, &i.CryptoSuite, &i.IsRevoked, + &i.IsExternal, &i.Usage, &i.AccountID, &i.ExpiresAt, @@ -85,7 +86,7 @@ func (q *Queries) FindAppKeyByAppIDAndPublicKID(ctx context.Context, arg FindApp } const findPaginatedAppKeysByAppID = `-- name: FindPaginatedAppKeysByAppID :many -SELECT ckr.id, ckr.public_kid, ckr.public_key, ckr.crypto_suite, ckr.is_revoked, ckr.usage, ckr.account_id, ckr.expires_at, ckr.created_at, ckr.updated_at FROM "credentials_keys" "ckr" +SELECT ckr.id, ckr.public_kid, ckr.public_key, ckr.crypto_suite, ckr.is_revoked, ckr.is_external, ckr.usage, ckr.account_id, ckr.expires_at, ckr.created_at, ckr.updated_at FROM "credentials_keys" "ckr" LEFT JOIN "app_keys" "ak" ON "ak"."credentials_key_id" = "ckr"."id" WHERE "ak"."app_id" = $1 ORDER BY "ckr"."expires_at" DESC @@ -113,6 +114,7 @@ func (q *Queries) FindPaginatedAppKeysByAppID(ctx context.Context, arg FindPagin &i.PublicKey, &i.CryptoSuite, &i.IsRevoked, + &i.IsExternal, &i.Usage, &i.AccountID, &i.ExpiresAt, diff --git a/idp/internal/providers/database/app_related_apps.sql.go b/idp/internal/providers/database/app_related_apps.sql.go index 4894b01..29d7c22 100644 --- a/idp/internal/providers/database/app_related_apps.sql.go +++ b/idp/internal/providers/database/app_related_apps.sql.go @@ -54,10 +54,10 @@ func (q *Queries) DeleteAppRelatedAppsByAppIDAndRelatedAppIDs(ctx context.Contex } const findRelatedAppsByAppID = `-- name: FindRelatedAppsByAppID :many -SELECT a.id, a.account_id, a.account_public_id, a.app_type, a.name, a.client_id, a.version, a.creation_method, a.client_uri, a.logo_uri, a.tos_uri, a.policy_uri, a.software_id, a.software_version, a.contacts, a.token_endpoint_auth_method, a.scopes, a.custom_scopes, a.grant_types, a.domain, a.transport, a.allow_user_registration, a.auth_providers, a.username_column, a.default_scopes, a.default_custom_scopes, a.redirect_uris, a.response_types, a.id_token_ttl, a.token_ttl, a.refresh_token_ttl, a.created_at, a.updated_at FROM "apps" a +SELECT a.id, a.account_id, a.account_public_id, a.client_id, a.version, a.creation_method, a.redirect_uris, a.token_endpoint_auth_method, a.grant_types, a.response_types, a.client_name, a.client_uri, a.logo_uri, a.scopes, a.custom_scopes, a.contacts, a.tos_uri, a.policy_uri, a.jwks_uri, a.jwks, a.software_id, a.software_version, a.domain, a.transport, a.allow_user_registration, a.auth_providers, a.username_column, a.default_scopes, a.default_custom_scopes, a.app_type, a.sector_identifier_uri, a.subject_type, a.id_token_signed_response_alg, a.id_token_encrypted_response_alg, a.id_token_encrypted_response_enc, a.userinfo_signed_response_alg, a.userinfo_encrypted_response_alg, a.userinfo_encrypted_response_enc, a.request_object_signing_alg, a.request_object_encryption_alg, a.request_object_encryption_enc, a.token_endpoint_auth_signing_alg, a.default_max_age, a.require_auth_time, a.default_acr_values, a.initiate_login_uri, a.request_uris, a.access_token_signing_alg, a.id_token_ttl, a.token_ttl, a.refresh_token_ttl, a.created_at, a.updated_at FROM "apps" a INNER JOIN "app_related_apps" ara ON a.id = ara.related_app_id WHERE ara.app_id = $1 -ORDER BY a.name ASC +ORDER BY a.client_name ASC ` func (q *Queries) FindRelatedAppsByAppID(ctx context.Context, appID int32) ([]App, error) { @@ -73,22 +73,25 @@ func (q *Queries) FindRelatedAppsByAppID(ctx context.Context, appID int32) ([]Ap &i.ID, &i.AccountID, &i.AccountPublicID, - &i.AppType, - &i.Name, &i.ClientID, &i.Version, &i.CreationMethod, + &i.RedirectUris, + &i.TokenEndpointAuthMethod, + &i.GrantTypes, + &i.ResponseTypes, + &i.ClientName, &i.ClientUri, &i.LogoUri, + &i.Scopes, + &i.CustomScopes, + &i.Contacts, &i.TosUri, &i.PolicyUri, + &i.JwksUri, + &i.Jwks, &i.SoftwareID, &i.SoftwareVersion, - &i.Contacts, - &i.TokenEndpointAuthMethod, - &i.Scopes, - &i.CustomScopes, - &i.GrantTypes, &i.Domain, &i.Transport, &i.AllowUserRegistration, @@ -96,8 +99,25 @@ func (q *Queries) FindRelatedAppsByAppID(ctx context.Context, appID int32) ([]Ap &i.UsernameColumn, &i.DefaultScopes, &i.DefaultCustomScopes, - &i.RedirectUris, - &i.ResponseTypes, + &i.AppType, + &i.SectorIdentifierUri, + &i.SubjectType, + &i.IDTokenSignedResponseAlg, + &i.IDTokenEncryptedResponseAlg, + &i.IDTokenEncryptedResponseEnc, + &i.UserinfoSignedResponseAlg, + &i.UserinfoEncryptedResponseAlg, + &i.UserinfoEncryptedResponseEnc, + &i.RequestObjectSigningAlg, + &i.RequestObjectEncryptionAlg, + &i.RequestObjectEncryptionEnc, + &i.TokenEndpointAuthSigningAlg, + &i.DefaultMaxAge, + &i.RequireAuthTime, + &i.DefaultAcrValues, + &i.InitiateLoginUri, + &i.RequestUris, + &i.AccessTokenSigningAlg, &i.IDTokenTtl, &i.TokenTtl, &i.RefreshTokenTtl, diff --git a/idp/internal/providers/database/apps.sql.go b/idp/internal/providers/database/apps.sql.go index cb76eb8..8c0ffb6 100644 --- a/idp/internal/providers/database/apps.sql.go +++ b/idp/internal/providers/database/apps.sql.go @@ -14,17 +14,17 @@ import ( const countAppsByAccountIDAndName = `-- name: CountAppsByAccountIDAndName :one SELECT COUNT(*) FROM "apps" -WHERE "account_id" = $1 AND "name" = $2 +WHERE "account_id" = $1 AND "client_name" = $2 LIMIT 1 ` type CountAppsByAccountIDAndNameParams struct { - AccountID int32 - Name string + AccountID int32 + ClientName string } func (q *Queries) CountAppsByAccountIDAndName(ctx context.Context, arg CountAppsByAccountIDAndNameParams) (int64, error) { - row := q.db.QueryRow(ctx, countAppsByAccountIDAndName, arg.AccountID, arg.Name) + row := q.db.QueryRow(ctx, countAppsByAccountIDAndName, arg.AccountID, arg.ClientName) var count int64 err := row.Scan(&count) return count, err @@ -63,17 +63,17 @@ func (q *Queries) CountAppsByClientIDAndAccountPublicID(ctx context.Context, arg const countFilteredAppsByNameAndByAccountPublicID = `-- name: CountFilteredAppsByNameAndByAccountPublicID :one SELECT COUNT(*) FROM "apps" -WHERE "account_public_id" = $1 AND "name" ILIKE $2 +WHERE "account_public_id" = $1 AND "client_name" ILIKE $2 LIMIT 1 ` type CountFilteredAppsByNameAndByAccountPublicIDParams struct { AccountPublicID uuid.UUID - Name string + ClientName string } func (q *Queries) CountFilteredAppsByNameAndByAccountPublicID(ctx context.Context, arg CountFilteredAppsByNameAndByAccountPublicIDParams) (int64, error) { - row := q.db.QueryRow(ctx, countFilteredAppsByNameAndByAccountPublicID, arg.AccountPublicID, arg.Name) + row := q.db.QueryRow(ctx, countFilteredAppsByNameAndByAccountPublicID, arg.AccountPublicID, arg.ClientName) var count int64 err := row.Scan(&count) return count, err @@ -82,19 +82,19 @@ func (q *Queries) CountFilteredAppsByNameAndByAccountPublicID(ctx context.Contex const countFilteredAppsByNameAndTypeAndByAccountPublicID = `-- name: CountFilteredAppsByNameAndTypeAndByAccountPublicID :one SELECT COUNT(*) FROM "apps" WHERE "account_public_id" = $1 AND - "name" ILIKE $2 AND + "client_name" ILIKE $2 AND "app_type" = $3 LIMIT 1 ` type CountFilteredAppsByNameAndTypeAndByAccountPublicIDParams struct { AccountPublicID uuid.UUID - Name string + ClientName string AppType AppType } func (q *Queries) CountFilteredAppsByNameAndTypeAndByAccountPublicID(ctx context.Context, arg CountFilteredAppsByNameAndTypeAndByAccountPublicIDParams) (int64, error) { - row := q.db.QueryRow(ctx, countFilteredAppsByNameAndTypeAndByAccountPublicID, arg.AccountPublicID, arg.Name, arg.AppType) + row := q.db.QueryRow(ctx, countFilteredAppsByNameAndTypeAndByAccountPublicID, arg.AccountPublicID, arg.ClientName, arg.AppType) var count int64 err := row.Scan(&count) return count, err @@ -124,7 +124,7 @@ INSERT INTO "apps" ( "account_id", "account_public_id", "app_type", - "name", + "client_name", "client_id", "client_uri", "username_column", @@ -174,14 +174,14 @@ INSERT INTO "apps" ( $24, $25, $26 -) RETURNING id, account_id, account_public_id, app_type, name, client_id, version, creation_method, client_uri, logo_uri, tos_uri, policy_uri, software_id, software_version, contacts, token_endpoint_auth_method, scopes, custom_scopes, grant_types, domain, transport, allow_user_registration, auth_providers, username_column, default_scopes, default_custom_scopes, redirect_uris, response_types, id_token_ttl, token_ttl, refresh_token_ttl, created_at, updated_at +) RETURNING id, account_id, account_public_id, client_id, version, creation_method, redirect_uris, token_endpoint_auth_method, grant_types, response_types, client_name, client_uri, logo_uri, scopes, custom_scopes, contacts, tos_uri, policy_uri, jwks_uri, jwks, software_id, software_version, domain, transport, allow_user_registration, auth_providers, username_column, default_scopes, default_custom_scopes, app_type, sector_identifier_uri, subject_type, id_token_signed_response_alg, id_token_encrypted_response_alg, id_token_encrypted_response_enc, userinfo_signed_response_alg, userinfo_encrypted_response_alg, userinfo_encrypted_response_enc, request_object_signing_alg, request_object_encryption_alg, request_object_encryption_enc, token_endpoint_auth_signing_alg, default_max_age, require_auth_time, default_acr_values, initiate_login_uri, request_uris, access_token_signing_alg, id_token_ttl, token_ttl, refresh_token_ttl, created_at, updated_at ` type CreateAppParams struct { AccountID int32 AccountPublicID uuid.UUID AppType AppType - Name string + ClientName string ClientID string ClientUri string UsernameColumn AppUsernameColumn @@ -192,7 +192,7 @@ type CreateAppParams struct { TosUri pgtype.Text PolicyUri pgtype.Text Contacts []string - SoftwareID string + SoftwareID pgtype.Text SoftwareVersion pgtype.Text Scopes []Scopes DefaultScopes []Scopes @@ -216,7 +216,7 @@ func (q *Queries) CreateApp(ctx context.Context, arg CreateAppParams) (App, erro arg.AccountID, arg.AccountPublicID, arg.AppType, - arg.Name, + arg.ClientName, arg.ClientID, arg.ClientUri, arg.UsernameColumn, @@ -245,22 +245,25 @@ func (q *Queries) CreateApp(ctx context.Context, arg CreateAppParams) (App, erro &i.ID, &i.AccountID, &i.AccountPublicID, - &i.AppType, - &i.Name, &i.ClientID, &i.Version, &i.CreationMethod, + &i.RedirectUris, + &i.TokenEndpointAuthMethod, + &i.GrantTypes, + &i.ResponseTypes, + &i.ClientName, &i.ClientUri, &i.LogoUri, + &i.Scopes, + &i.CustomScopes, + &i.Contacts, &i.TosUri, &i.PolicyUri, + &i.JwksUri, + &i.Jwks, &i.SoftwareID, &i.SoftwareVersion, - &i.Contacts, - &i.TokenEndpointAuthMethod, - &i.Scopes, - &i.CustomScopes, - &i.GrantTypes, &i.Domain, &i.Transport, &i.AllowUserRegistration, @@ -268,8 +271,25 @@ func (q *Queries) CreateApp(ctx context.Context, arg CreateAppParams) (App, erro &i.UsernameColumn, &i.DefaultScopes, &i.DefaultCustomScopes, - &i.RedirectUris, - &i.ResponseTypes, + &i.AppType, + &i.SectorIdentifierUri, + &i.SubjectType, + &i.IDTokenSignedResponseAlg, + &i.IDTokenEncryptedResponseAlg, + &i.IDTokenEncryptedResponseEnc, + &i.UserinfoSignedResponseAlg, + &i.UserinfoEncryptedResponseAlg, + &i.UserinfoEncryptedResponseEnc, + &i.RequestObjectSigningAlg, + &i.RequestObjectEncryptionAlg, + &i.RequestObjectEncryptionEnc, + &i.TokenEndpointAuthSigningAlg, + &i.DefaultMaxAge, + &i.RequireAuthTime, + &i.DefaultAcrValues, + &i.InitiateLoginUri, + &i.RequestUris, + &i.AccessTokenSigningAlg, &i.IDTokenTtl, &i.TokenTtl, &i.RefreshTokenTtl, @@ -299,15 +319,15 @@ func (q *Queries) DeleteApp(ctx context.Context, id int32) error { } const filterAppsByNameAndByAccountPublicIDOrderedByID = `-- name: FilterAppsByNameAndByAccountPublicIDOrderedByID :many -SELECT id, account_id, account_public_id, app_type, name, client_id, version, creation_method, client_uri, logo_uri, tos_uri, policy_uri, software_id, software_version, contacts, token_endpoint_auth_method, scopes, custom_scopes, grant_types, domain, transport, allow_user_registration, auth_providers, username_column, default_scopes, default_custom_scopes, redirect_uris, response_types, id_token_ttl, token_ttl, refresh_token_ttl, created_at, updated_at FROM "apps" -WHERE "account_public_id" = $1 AND "name" ILIKE $2 +SELECT id, account_id, account_public_id, client_id, version, creation_method, redirect_uris, token_endpoint_auth_method, grant_types, response_types, client_name, client_uri, logo_uri, scopes, custom_scopes, contacts, tos_uri, policy_uri, jwks_uri, jwks, software_id, software_version, domain, transport, allow_user_registration, auth_providers, username_column, default_scopes, default_custom_scopes, app_type, sector_identifier_uri, subject_type, id_token_signed_response_alg, id_token_encrypted_response_alg, id_token_encrypted_response_enc, userinfo_signed_response_alg, userinfo_encrypted_response_alg, userinfo_encrypted_response_enc, request_object_signing_alg, request_object_encryption_alg, request_object_encryption_enc, token_endpoint_auth_signing_alg, default_max_age, require_auth_time, default_acr_values, initiate_login_uri, request_uris, access_token_signing_alg, id_token_ttl, token_ttl, refresh_token_ttl, created_at, updated_at FROM "apps" +WHERE "account_public_id" = $1 AND "client_name" ILIKE $2 ORDER BY "id" DESC OFFSET $3 LIMIT $4 ` type FilterAppsByNameAndByAccountPublicIDOrderedByIDParams struct { AccountPublicID uuid.UUID - Name string + ClientName string Offset int32 Limit int32 } @@ -315,7 +335,7 @@ type FilterAppsByNameAndByAccountPublicIDOrderedByIDParams struct { func (q *Queries) FilterAppsByNameAndByAccountPublicIDOrderedByID(ctx context.Context, arg FilterAppsByNameAndByAccountPublicIDOrderedByIDParams) ([]App, error) { rows, err := q.db.Query(ctx, filterAppsByNameAndByAccountPublicIDOrderedByID, arg.AccountPublicID, - arg.Name, + arg.ClientName, arg.Offset, arg.Limit, ) @@ -330,22 +350,25 @@ func (q *Queries) FilterAppsByNameAndByAccountPublicIDOrderedByID(ctx context.Co &i.ID, &i.AccountID, &i.AccountPublicID, - &i.AppType, - &i.Name, &i.ClientID, &i.Version, &i.CreationMethod, + &i.RedirectUris, + &i.TokenEndpointAuthMethod, + &i.GrantTypes, + &i.ResponseTypes, + &i.ClientName, &i.ClientUri, &i.LogoUri, + &i.Scopes, + &i.CustomScopes, + &i.Contacts, &i.TosUri, &i.PolicyUri, + &i.JwksUri, + &i.Jwks, &i.SoftwareID, &i.SoftwareVersion, - &i.Contacts, - &i.TokenEndpointAuthMethod, - &i.Scopes, - &i.CustomScopes, - &i.GrantTypes, &i.Domain, &i.Transport, &i.AllowUserRegistration, @@ -353,8 +376,25 @@ func (q *Queries) FilterAppsByNameAndByAccountPublicIDOrderedByID(ctx context.Co &i.UsernameColumn, &i.DefaultScopes, &i.DefaultCustomScopes, - &i.RedirectUris, - &i.ResponseTypes, + &i.AppType, + &i.SectorIdentifierUri, + &i.SubjectType, + &i.IDTokenSignedResponseAlg, + &i.IDTokenEncryptedResponseAlg, + &i.IDTokenEncryptedResponseEnc, + &i.UserinfoSignedResponseAlg, + &i.UserinfoEncryptedResponseAlg, + &i.UserinfoEncryptedResponseEnc, + &i.RequestObjectSigningAlg, + &i.RequestObjectEncryptionAlg, + &i.RequestObjectEncryptionEnc, + &i.TokenEndpointAuthSigningAlg, + &i.DefaultMaxAge, + &i.RequireAuthTime, + &i.DefaultAcrValues, + &i.InitiateLoginUri, + &i.RequestUris, + &i.AccessTokenSigningAlg, &i.IDTokenTtl, &i.TokenTtl, &i.RefreshTokenTtl, @@ -372,15 +412,15 @@ func (q *Queries) FilterAppsByNameAndByAccountPublicIDOrderedByID(ctx context.Co } const filterAppsByNameAndByAccountPublicIDOrderedByName = `-- name: FilterAppsByNameAndByAccountPublicIDOrderedByName :many -SELECT id, account_id, account_public_id, app_type, name, client_id, version, creation_method, client_uri, logo_uri, tos_uri, policy_uri, software_id, software_version, contacts, token_endpoint_auth_method, scopes, custom_scopes, grant_types, domain, transport, allow_user_registration, auth_providers, username_column, default_scopes, default_custom_scopes, redirect_uris, response_types, id_token_ttl, token_ttl, refresh_token_ttl, created_at, updated_at FROM "apps" -WHERE "account_public_id" = $1 AND "name" ILIKE $2 -ORDER BY "name" ASC +SELECT id, account_id, account_public_id, client_id, version, creation_method, redirect_uris, token_endpoint_auth_method, grant_types, response_types, client_name, client_uri, logo_uri, scopes, custom_scopes, contacts, tos_uri, policy_uri, jwks_uri, jwks, software_id, software_version, domain, transport, allow_user_registration, auth_providers, username_column, default_scopes, default_custom_scopes, app_type, sector_identifier_uri, subject_type, id_token_signed_response_alg, id_token_encrypted_response_alg, id_token_encrypted_response_enc, userinfo_signed_response_alg, userinfo_encrypted_response_alg, userinfo_encrypted_response_enc, request_object_signing_alg, request_object_encryption_alg, request_object_encryption_enc, token_endpoint_auth_signing_alg, default_max_age, require_auth_time, default_acr_values, initiate_login_uri, request_uris, access_token_signing_alg, id_token_ttl, token_ttl, refresh_token_ttl, created_at, updated_at FROM "apps" +WHERE "account_public_id" = $1 AND "client_name" ILIKE $2 +ORDER BY "client_name" ASC OFFSET $3 LIMIT $4 ` type FilterAppsByNameAndByAccountPublicIDOrderedByNameParams struct { AccountPublicID uuid.UUID - Name string + ClientName string Offset int32 Limit int32 } @@ -388,7 +428,7 @@ type FilterAppsByNameAndByAccountPublicIDOrderedByNameParams struct { func (q *Queries) FilterAppsByNameAndByAccountPublicIDOrderedByName(ctx context.Context, arg FilterAppsByNameAndByAccountPublicIDOrderedByNameParams) ([]App, error) { rows, err := q.db.Query(ctx, filterAppsByNameAndByAccountPublicIDOrderedByName, arg.AccountPublicID, - arg.Name, + arg.ClientName, arg.Offset, arg.Limit, ) @@ -403,22 +443,25 @@ func (q *Queries) FilterAppsByNameAndByAccountPublicIDOrderedByName(ctx context. &i.ID, &i.AccountID, &i.AccountPublicID, - &i.AppType, - &i.Name, &i.ClientID, &i.Version, &i.CreationMethod, + &i.RedirectUris, + &i.TokenEndpointAuthMethod, + &i.GrantTypes, + &i.ResponseTypes, + &i.ClientName, &i.ClientUri, &i.LogoUri, + &i.Scopes, + &i.CustomScopes, + &i.Contacts, &i.TosUri, &i.PolicyUri, + &i.JwksUri, + &i.Jwks, &i.SoftwareID, &i.SoftwareVersion, - &i.Contacts, - &i.TokenEndpointAuthMethod, - &i.Scopes, - &i.CustomScopes, - &i.GrantTypes, &i.Domain, &i.Transport, &i.AllowUserRegistration, @@ -426,8 +469,25 @@ func (q *Queries) FilterAppsByNameAndByAccountPublicIDOrderedByName(ctx context. &i.UsernameColumn, &i.DefaultScopes, &i.DefaultCustomScopes, - &i.RedirectUris, - &i.ResponseTypes, + &i.AppType, + &i.SectorIdentifierUri, + &i.SubjectType, + &i.IDTokenSignedResponseAlg, + &i.IDTokenEncryptedResponseAlg, + &i.IDTokenEncryptedResponseEnc, + &i.UserinfoSignedResponseAlg, + &i.UserinfoEncryptedResponseAlg, + &i.UserinfoEncryptedResponseEnc, + &i.RequestObjectSigningAlg, + &i.RequestObjectEncryptionAlg, + &i.RequestObjectEncryptionEnc, + &i.TokenEndpointAuthSigningAlg, + &i.DefaultMaxAge, + &i.RequireAuthTime, + &i.DefaultAcrValues, + &i.InitiateLoginUri, + &i.RequestUris, + &i.AccessTokenSigningAlg, &i.IDTokenTtl, &i.TokenTtl, &i.RefreshTokenTtl, @@ -445,9 +505,9 @@ func (q *Queries) FilterAppsByNameAndByAccountPublicIDOrderedByName(ctx context. } const filterAppsByNameAndTypeAndByAccountPublicIDOrderedByID = `-- name: FilterAppsByNameAndTypeAndByAccountPublicIDOrderedByID :many -SELECT id, account_id, account_public_id, app_type, name, client_id, version, creation_method, client_uri, logo_uri, tos_uri, policy_uri, software_id, software_version, contacts, token_endpoint_auth_method, scopes, custom_scopes, grant_types, domain, transport, allow_user_registration, auth_providers, username_column, default_scopes, default_custom_scopes, redirect_uris, response_types, id_token_ttl, token_ttl, refresh_token_ttl, created_at, updated_at FROM "apps" +SELECT id, account_id, account_public_id, client_id, version, creation_method, redirect_uris, token_endpoint_auth_method, grant_types, response_types, client_name, client_uri, logo_uri, scopes, custom_scopes, contacts, tos_uri, policy_uri, jwks_uri, jwks, software_id, software_version, domain, transport, allow_user_registration, auth_providers, username_column, default_scopes, default_custom_scopes, app_type, sector_identifier_uri, subject_type, id_token_signed_response_alg, id_token_encrypted_response_alg, id_token_encrypted_response_enc, userinfo_signed_response_alg, userinfo_encrypted_response_alg, userinfo_encrypted_response_enc, request_object_signing_alg, request_object_encryption_alg, request_object_encryption_enc, token_endpoint_auth_signing_alg, default_max_age, require_auth_time, default_acr_values, initiate_login_uri, request_uris, access_token_signing_alg, id_token_ttl, token_ttl, refresh_token_ttl, created_at, updated_at FROM "apps" WHERE "account_public_id" = $1 AND - "name" ILIKE $2 AND + "client_name" ILIKE $2 AND "app_type" = $3 ORDER BY "id" DESC OFFSET $4 LIMIT $5 @@ -455,7 +515,7 @@ OFFSET $4 LIMIT $5 type FilterAppsByNameAndTypeAndByAccountPublicIDOrderedByIDParams struct { AccountPublicID uuid.UUID - Name string + ClientName string AppType AppType Offset int32 Limit int32 @@ -464,7 +524,7 @@ type FilterAppsByNameAndTypeAndByAccountPublicIDOrderedByIDParams struct { func (q *Queries) FilterAppsByNameAndTypeAndByAccountPublicIDOrderedByID(ctx context.Context, arg FilterAppsByNameAndTypeAndByAccountPublicIDOrderedByIDParams) ([]App, error) { rows, err := q.db.Query(ctx, filterAppsByNameAndTypeAndByAccountPublicIDOrderedByID, arg.AccountPublicID, - arg.Name, + arg.ClientName, arg.AppType, arg.Offset, arg.Limit, @@ -480,22 +540,25 @@ func (q *Queries) FilterAppsByNameAndTypeAndByAccountPublicIDOrderedByID(ctx con &i.ID, &i.AccountID, &i.AccountPublicID, - &i.AppType, - &i.Name, &i.ClientID, &i.Version, &i.CreationMethod, + &i.RedirectUris, + &i.TokenEndpointAuthMethod, + &i.GrantTypes, + &i.ResponseTypes, + &i.ClientName, &i.ClientUri, &i.LogoUri, + &i.Scopes, + &i.CustomScopes, + &i.Contacts, &i.TosUri, &i.PolicyUri, + &i.JwksUri, + &i.Jwks, &i.SoftwareID, &i.SoftwareVersion, - &i.Contacts, - &i.TokenEndpointAuthMethod, - &i.Scopes, - &i.CustomScopes, - &i.GrantTypes, &i.Domain, &i.Transport, &i.AllowUserRegistration, @@ -503,8 +566,25 @@ func (q *Queries) FilterAppsByNameAndTypeAndByAccountPublicIDOrderedByID(ctx con &i.UsernameColumn, &i.DefaultScopes, &i.DefaultCustomScopes, - &i.RedirectUris, - &i.ResponseTypes, + &i.AppType, + &i.SectorIdentifierUri, + &i.SubjectType, + &i.IDTokenSignedResponseAlg, + &i.IDTokenEncryptedResponseAlg, + &i.IDTokenEncryptedResponseEnc, + &i.UserinfoSignedResponseAlg, + &i.UserinfoEncryptedResponseAlg, + &i.UserinfoEncryptedResponseEnc, + &i.RequestObjectSigningAlg, + &i.RequestObjectEncryptionAlg, + &i.RequestObjectEncryptionEnc, + &i.TokenEndpointAuthSigningAlg, + &i.DefaultMaxAge, + &i.RequireAuthTime, + &i.DefaultAcrValues, + &i.InitiateLoginUri, + &i.RequestUris, + &i.AccessTokenSigningAlg, &i.IDTokenTtl, &i.TokenTtl, &i.RefreshTokenTtl, @@ -522,17 +602,17 @@ func (q *Queries) FilterAppsByNameAndTypeAndByAccountPublicIDOrderedByID(ctx con } const filterAppsByNameAndTypeAndByAccountPublicIDOrderedByName = `-- name: FilterAppsByNameAndTypeAndByAccountPublicIDOrderedByName :many -SELECT id, account_id, account_public_id, app_type, name, client_id, version, creation_method, client_uri, logo_uri, tos_uri, policy_uri, software_id, software_version, contacts, token_endpoint_auth_method, scopes, custom_scopes, grant_types, domain, transport, allow_user_registration, auth_providers, username_column, default_scopes, default_custom_scopes, redirect_uris, response_types, id_token_ttl, token_ttl, refresh_token_ttl, created_at, updated_at FROM "apps" +SELECT id, account_id, account_public_id, client_id, version, creation_method, redirect_uris, token_endpoint_auth_method, grant_types, response_types, client_name, client_uri, logo_uri, scopes, custom_scopes, contacts, tos_uri, policy_uri, jwks_uri, jwks, software_id, software_version, domain, transport, allow_user_registration, auth_providers, username_column, default_scopes, default_custom_scopes, app_type, sector_identifier_uri, subject_type, id_token_signed_response_alg, id_token_encrypted_response_alg, id_token_encrypted_response_enc, userinfo_signed_response_alg, userinfo_encrypted_response_alg, userinfo_encrypted_response_enc, request_object_signing_alg, request_object_encryption_alg, request_object_encryption_enc, token_endpoint_auth_signing_alg, default_max_age, require_auth_time, default_acr_values, initiate_login_uri, request_uris, access_token_signing_alg, id_token_ttl, token_ttl, refresh_token_ttl, created_at, updated_at FROM "apps" WHERE "account_public_id" = $1 AND - "name" ILIKE $2 AND + "client_name" ILIKE $2 AND "app_type" = $3 -ORDER BY "name" ASC +ORDER BY "client_name" ASC OFFSET $4 LIMIT $5 ` type FilterAppsByNameAndTypeAndByAccountPublicIDOrderedByNameParams struct { AccountPublicID uuid.UUID - Name string + ClientName string AppType AppType Offset int32 Limit int32 @@ -541,7 +621,7 @@ type FilterAppsByNameAndTypeAndByAccountPublicIDOrderedByNameParams struct { func (q *Queries) FilterAppsByNameAndTypeAndByAccountPublicIDOrderedByName(ctx context.Context, arg FilterAppsByNameAndTypeAndByAccountPublicIDOrderedByNameParams) ([]App, error) { rows, err := q.db.Query(ctx, filterAppsByNameAndTypeAndByAccountPublicIDOrderedByName, arg.AccountPublicID, - arg.Name, + arg.ClientName, arg.AppType, arg.Offset, arg.Limit, @@ -557,22 +637,25 @@ func (q *Queries) FilterAppsByNameAndTypeAndByAccountPublicIDOrderedByName(ctx c &i.ID, &i.AccountID, &i.AccountPublicID, - &i.AppType, - &i.Name, &i.ClientID, &i.Version, &i.CreationMethod, + &i.RedirectUris, + &i.TokenEndpointAuthMethod, + &i.GrantTypes, + &i.ResponseTypes, + &i.ClientName, &i.ClientUri, &i.LogoUri, + &i.Scopes, + &i.CustomScopes, + &i.Contacts, &i.TosUri, &i.PolicyUri, + &i.JwksUri, + &i.Jwks, &i.SoftwareID, &i.SoftwareVersion, - &i.Contacts, - &i.TokenEndpointAuthMethod, - &i.Scopes, - &i.CustomScopes, - &i.GrantTypes, &i.Domain, &i.Transport, &i.AllowUserRegistration, @@ -580,8 +663,25 @@ func (q *Queries) FilterAppsByNameAndTypeAndByAccountPublicIDOrderedByName(ctx c &i.UsernameColumn, &i.DefaultScopes, &i.DefaultCustomScopes, - &i.RedirectUris, - &i.ResponseTypes, + &i.AppType, + &i.SectorIdentifierUri, + &i.SubjectType, + &i.IDTokenSignedResponseAlg, + &i.IDTokenEncryptedResponseAlg, + &i.IDTokenEncryptedResponseEnc, + &i.UserinfoSignedResponseAlg, + &i.UserinfoEncryptedResponseAlg, + &i.UserinfoEncryptedResponseEnc, + &i.RequestObjectSigningAlg, + &i.RequestObjectEncryptionAlg, + &i.RequestObjectEncryptionEnc, + &i.TokenEndpointAuthSigningAlg, + &i.DefaultMaxAge, + &i.RequireAuthTime, + &i.DefaultAcrValues, + &i.InitiateLoginUri, + &i.RequestUris, + &i.AccessTokenSigningAlg, &i.IDTokenTtl, &i.TokenTtl, &i.RefreshTokenTtl, @@ -599,7 +699,7 @@ func (q *Queries) FilterAppsByNameAndTypeAndByAccountPublicIDOrderedByName(ctx c } const filterAppsByTypeAndByAccountPublicIDOrderedByID = `-- name: FilterAppsByTypeAndByAccountPublicIDOrderedByID :many -SELECT id, account_id, account_public_id, app_type, name, client_id, version, creation_method, client_uri, logo_uri, tos_uri, policy_uri, software_id, software_version, contacts, token_endpoint_auth_method, scopes, custom_scopes, grant_types, domain, transport, allow_user_registration, auth_providers, username_column, default_scopes, default_custom_scopes, redirect_uris, response_types, id_token_ttl, token_ttl, refresh_token_ttl, created_at, updated_at FROM "apps" +SELECT id, account_id, account_public_id, client_id, version, creation_method, redirect_uris, token_endpoint_auth_method, grant_types, response_types, client_name, client_uri, logo_uri, scopes, custom_scopes, contacts, tos_uri, policy_uri, jwks_uri, jwks, software_id, software_version, domain, transport, allow_user_registration, auth_providers, username_column, default_scopes, default_custom_scopes, app_type, sector_identifier_uri, subject_type, id_token_signed_response_alg, id_token_encrypted_response_alg, id_token_encrypted_response_enc, userinfo_signed_response_alg, userinfo_encrypted_response_alg, userinfo_encrypted_response_enc, request_object_signing_alg, request_object_encryption_alg, request_object_encryption_enc, token_endpoint_auth_signing_alg, default_max_age, require_auth_time, default_acr_values, initiate_login_uri, request_uris, access_token_signing_alg, id_token_ttl, token_ttl, refresh_token_ttl, created_at, updated_at FROM "apps" WHERE "account_public_id" = $1 AND "app_type" = $2 ORDER BY "id" DESC OFFSET $3 LIMIT $4 @@ -630,22 +730,25 @@ func (q *Queries) FilterAppsByTypeAndByAccountPublicIDOrderedByID(ctx context.Co &i.ID, &i.AccountID, &i.AccountPublicID, - &i.AppType, - &i.Name, &i.ClientID, &i.Version, &i.CreationMethod, + &i.RedirectUris, + &i.TokenEndpointAuthMethod, + &i.GrantTypes, + &i.ResponseTypes, + &i.ClientName, &i.ClientUri, &i.LogoUri, + &i.Scopes, + &i.CustomScopes, + &i.Contacts, &i.TosUri, &i.PolicyUri, + &i.JwksUri, + &i.Jwks, &i.SoftwareID, &i.SoftwareVersion, - &i.Contacts, - &i.TokenEndpointAuthMethod, - &i.Scopes, - &i.CustomScopes, - &i.GrantTypes, &i.Domain, &i.Transport, &i.AllowUserRegistration, @@ -653,8 +756,25 @@ func (q *Queries) FilterAppsByTypeAndByAccountPublicIDOrderedByID(ctx context.Co &i.UsernameColumn, &i.DefaultScopes, &i.DefaultCustomScopes, - &i.RedirectUris, - &i.ResponseTypes, + &i.AppType, + &i.SectorIdentifierUri, + &i.SubjectType, + &i.IDTokenSignedResponseAlg, + &i.IDTokenEncryptedResponseAlg, + &i.IDTokenEncryptedResponseEnc, + &i.UserinfoSignedResponseAlg, + &i.UserinfoEncryptedResponseAlg, + &i.UserinfoEncryptedResponseEnc, + &i.RequestObjectSigningAlg, + &i.RequestObjectEncryptionAlg, + &i.RequestObjectEncryptionEnc, + &i.TokenEndpointAuthSigningAlg, + &i.DefaultMaxAge, + &i.RequireAuthTime, + &i.DefaultAcrValues, + &i.InitiateLoginUri, + &i.RequestUris, + &i.AccessTokenSigningAlg, &i.IDTokenTtl, &i.TokenTtl, &i.RefreshTokenTtl, @@ -672,9 +792,9 @@ func (q *Queries) FilterAppsByTypeAndByAccountPublicIDOrderedByID(ctx context.Co } const filterAppsByTypeAndByAccountPublicIDOrderedByName = `-- name: FilterAppsByTypeAndByAccountPublicIDOrderedByName :many -SELECT id, account_id, account_public_id, app_type, name, client_id, version, creation_method, client_uri, logo_uri, tos_uri, policy_uri, software_id, software_version, contacts, token_endpoint_auth_method, scopes, custom_scopes, grant_types, domain, transport, allow_user_registration, auth_providers, username_column, default_scopes, default_custom_scopes, redirect_uris, response_types, id_token_ttl, token_ttl, refresh_token_ttl, created_at, updated_at FROM "apps" +SELECT id, account_id, account_public_id, client_id, version, creation_method, redirect_uris, token_endpoint_auth_method, grant_types, response_types, client_name, client_uri, logo_uri, scopes, custom_scopes, contacts, tos_uri, policy_uri, jwks_uri, jwks, software_id, software_version, domain, transport, allow_user_registration, auth_providers, username_column, default_scopes, default_custom_scopes, app_type, sector_identifier_uri, subject_type, id_token_signed_response_alg, id_token_encrypted_response_alg, id_token_encrypted_response_enc, userinfo_signed_response_alg, userinfo_encrypted_response_alg, userinfo_encrypted_response_enc, request_object_signing_alg, request_object_encryption_alg, request_object_encryption_enc, token_endpoint_auth_signing_alg, default_max_age, require_auth_time, default_acr_values, initiate_login_uri, request_uris, access_token_signing_alg, id_token_ttl, token_ttl, refresh_token_ttl, created_at, updated_at FROM "apps" WHERE "account_public_id" = $1 AND "app_type" = $2 -ORDER BY "name" ASC +ORDER BY "client_name" ASC OFFSET $3 LIMIT $4 ` @@ -703,22 +823,25 @@ func (q *Queries) FilterAppsByTypeAndByAccountPublicIDOrderedByName(ctx context. &i.ID, &i.AccountID, &i.AccountPublicID, - &i.AppType, - &i.Name, &i.ClientID, &i.Version, &i.CreationMethod, + &i.RedirectUris, + &i.TokenEndpointAuthMethod, + &i.GrantTypes, + &i.ResponseTypes, + &i.ClientName, &i.ClientUri, &i.LogoUri, + &i.Scopes, + &i.CustomScopes, + &i.Contacts, &i.TosUri, &i.PolicyUri, + &i.JwksUri, + &i.Jwks, &i.SoftwareID, &i.SoftwareVersion, - &i.Contacts, - &i.TokenEndpointAuthMethod, - &i.Scopes, - &i.CustomScopes, - &i.GrantTypes, &i.Domain, &i.Transport, &i.AllowUserRegistration, @@ -726,8 +849,25 @@ func (q *Queries) FilterAppsByTypeAndByAccountPublicIDOrderedByName(ctx context. &i.UsernameColumn, &i.DefaultScopes, &i.DefaultCustomScopes, - &i.RedirectUris, - &i.ResponseTypes, + &i.AppType, + &i.SectorIdentifierUri, + &i.SubjectType, + &i.IDTokenSignedResponseAlg, + &i.IDTokenEncryptedResponseAlg, + &i.IDTokenEncryptedResponseEnc, + &i.UserinfoSignedResponseAlg, + &i.UserinfoEncryptedResponseAlg, + &i.UserinfoEncryptedResponseEnc, + &i.RequestObjectSigningAlg, + &i.RequestObjectEncryptionAlg, + &i.RequestObjectEncryptionEnc, + &i.TokenEndpointAuthSigningAlg, + &i.DefaultMaxAge, + &i.RequireAuthTime, + &i.DefaultAcrValues, + &i.InitiateLoginUri, + &i.RequestUris, + &i.AccessTokenSigningAlg, &i.IDTokenTtl, &i.TokenTtl, &i.RefreshTokenTtl, @@ -745,7 +885,7 @@ func (q *Queries) FilterAppsByTypeAndByAccountPublicIDOrderedByName(ctx context. } const findAppByClientID = `-- name: FindAppByClientID :one -SELECT id, account_id, account_public_id, app_type, name, client_id, version, creation_method, client_uri, logo_uri, tos_uri, policy_uri, software_id, software_version, contacts, token_endpoint_auth_method, scopes, custom_scopes, grant_types, domain, transport, allow_user_registration, auth_providers, username_column, default_scopes, default_custom_scopes, redirect_uris, response_types, id_token_ttl, token_ttl, refresh_token_ttl, created_at, updated_at FROM "apps" +SELECT id, account_id, account_public_id, client_id, version, creation_method, redirect_uris, token_endpoint_auth_method, grant_types, response_types, client_name, client_uri, logo_uri, scopes, custom_scopes, contacts, tos_uri, policy_uri, jwks_uri, jwks, software_id, software_version, domain, transport, allow_user_registration, auth_providers, username_column, default_scopes, default_custom_scopes, app_type, sector_identifier_uri, subject_type, id_token_signed_response_alg, id_token_encrypted_response_alg, id_token_encrypted_response_enc, userinfo_signed_response_alg, userinfo_encrypted_response_alg, userinfo_encrypted_response_enc, request_object_signing_alg, request_object_encryption_alg, request_object_encryption_enc, token_endpoint_auth_signing_alg, default_max_age, require_auth_time, default_acr_values, initiate_login_uri, request_uris, access_token_signing_alg, id_token_ttl, token_ttl, refresh_token_ttl, created_at, updated_at FROM "apps" WHERE "client_id" = $1 LIMIT 1 ` @@ -756,22 +896,25 @@ func (q *Queries) FindAppByClientID(ctx context.Context, clientID string) (App, &i.ID, &i.AccountID, &i.AccountPublicID, - &i.AppType, - &i.Name, &i.ClientID, &i.Version, &i.CreationMethod, + &i.RedirectUris, + &i.TokenEndpointAuthMethod, + &i.GrantTypes, + &i.ResponseTypes, + &i.ClientName, &i.ClientUri, &i.LogoUri, + &i.Scopes, + &i.CustomScopes, + &i.Contacts, &i.TosUri, &i.PolicyUri, + &i.JwksUri, + &i.Jwks, &i.SoftwareID, &i.SoftwareVersion, - &i.Contacts, - &i.TokenEndpointAuthMethod, - &i.Scopes, - &i.CustomScopes, - &i.GrantTypes, &i.Domain, &i.Transport, &i.AllowUserRegistration, @@ -779,8 +922,25 @@ func (q *Queries) FindAppByClientID(ctx context.Context, clientID string) (App, &i.UsernameColumn, &i.DefaultScopes, &i.DefaultCustomScopes, - &i.RedirectUris, - &i.ResponseTypes, + &i.AppType, + &i.SectorIdentifierUri, + &i.SubjectType, + &i.IDTokenSignedResponseAlg, + &i.IDTokenEncryptedResponseAlg, + &i.IDTokenEncryptedResponseEnc, + &i.UserinfoSignedResponseAlg, + &i.UserinfoEncryptedResponseAlg, + &i.UserinfoEncryptedResponseEnc, + &i.RequestObjectSigningAlg, + &i.RequestObjectEncryptionAlg, + &i.RequestObjectEncryptionEnc, + &i.TokenEndpointAuthSigningAlg, + &i.DefaultMaxAge, + &i.RequireAuthTime, + &i.DefaultAcrValues, + &i.InitiateLoginUri, + &i.RequestUris, + &i.AccessTokenSigningAlg, &i.IDTokenTtl, &i.TokenTtl, &i.RefreshTokenTtl, @@ -791,7 +951,7 @@ func (q *Queries) FindAppByClientID(ctx context.Context, clientID string) (App, } const findAppByClientIDAndAccountPublicID = `-- name: FindAppByClientIDAndAccountPublicID :one -SELECT id, account_id, account_public_id, app_type, name, client_id, version, creation_method, client_uri, logo_uri, tos_uri, policy_uri, software_id, software_version, contacts, token_endpoint_auth_method, scopes, custom_scopes, grant_types, domain, transport, allow_user_registration, auth_providers, username_column, default_scopes, default_custom_scopes, redirect_uris, response_types, id_token_ttl, token_ttl, refresh_token_ttl, created_at, updated_at FROM "apps" +SELECT id, account_id, account_public_id, client_id, version, creation_method, redirect_uris, token_endpoint_auth_method, grant_types, response_types, client_name, client_uri, logo_uri, scopes, custom_scopes, contacts, tos_uri, policy_uri, jwks_uri, jwks, software_id, software_version, domain, transport, allow_user_registration, auth_providers, username_column, default_scopes, default_custom_scopes, app_type, sector_identifier_uri, subject_type, id_token_signed_response_alg, id_token_encrypted_response_alg, id_token_encrypted_response_enc, userinfo_signed_response_alg, userinfo_encrypted_response_alg, userinfo_encrypted_response_enc, request_object_signing_alg, request_object_encryption_alg, request_object_encryption_enc, token_endpoint_auth_signing_alg, default_max_age, require_auth_time, default_acr_values, initiate_login_uri, request_uris, access_token_signing_alg, id_token_ttl, token_ttl, refresh_token_ttl, created_at, updated_at FROM "apps" WHERE "client_id" = $1 AND "account_public_id" = $2 LIMIT 1 ` @@ -808,22 +968,25 @@ func (q *Queries) FindAppByClientIDAndAccountPublicID(ctx context.Context, arg F &i.ID, &i.AccountID, &i.AccountPublicID, - &i.AppType, - &i.Name, &i.ClientID, &i.Version, &i.CreationMethod, + &i.RedirectUris, + &i.TokenEndpointAuthMethod, + &i.GrantTypes, + &i.ResponseTypes, + &i.ClientName, &i.ClientUri, &i.LogoUri, + &i.Scopes, + &i.CustomScopes, + &i.Contacts, &i.TosUri, &i.PolicyUri, + &i.JwksUri, + &i.Jwks, &i.SoftwareID, &i.SoftwareVersion, - &i.Contacts, - &i.TokenEndpointAuthMethod, - &i.Scopes, - &i.CustomScopes, - &i.GrantTypes, &i.Domain, &i.Transport, &i.AllowUserRegistration, @@ -831,8 +994,25 @@ func (q *Queries) FindAppByClientIDAndAccountPublicID(ctx context.Context, arg F &i.UsernameColumn, &i.DefaultScopes, &i.DefaultCustomScopes, - &i.RedirectUris, - &i.ResponseTypes, + &i.AppType, + &i.SectorIdentifierUri, + &i.SubjectType, + &i.IDTokenSignedResponseAlg, + &i.IDTokenEncryptedResponseAlg, + &i.IDTokenEncryptedResponseEnc, + &i.UserinfoSignedResponseAlg, + &i.UserinfoEncryptedResponseAlg, + &i.UserinfoEncryptedResponseEnc, + &i.RequestObjectSigningAlg, + &i.RequestObjectEncryptionAlg, + &i.RequestObjectEncryptionEnc, + &i.TokenEndpointAuthSigningAlg, + &i.DefaultMaxAge, + &i.RequireAuthTime, + &i.DefaultAcrValues, + &i.InitiateLoginUri, + &i.RequestUris, + &i.AccessTokenSigningAlg, &i.IDTokenTtl, &i.TokenTtl, &i.RefreshTokenTtl, @@ -843,7 +1023,7 @@ func (q *Queries) FindAppByClientIDAndAccountPublicID(ctx context.Context, arg F } const findAppByClientIDAndVersion = `-- name: FindAppByClientIDAndVersion :one -SELECT id, account_id, account_public_id, app_type, name, client_id, version, creation_method, client_uri, logo_uri, tos_uri, policy_uri, software_id, software_version, contacts, token_endpoint_auth_method, scopes, custom_scopes, grant_types, domain, transport, allow_user_registration, auth_providers, username_column, default_scopes, default_custom_scopes, redirect_uris, response_types, id_token_ttl, token_ttl, refresh_token_ttl, created_at, updated_at FROM "apps" +SELECT id, account_id, account_public_id, client_id, version, creation_method, redirect_uris, token_endpoint_auth_method, grant_types, response_types, client_name, client_uri, logo_uri, scopes, custom_scopes, contacts, tos_uri, policy_uri, jwks_uri, jwks, software_id, software_version, domain, transport, allow_user_registration, auth_providers, username_column, default_scopes, default_custom_scopes, app_type, sector_identifier_uri, subject_type, id_token_signed_response_alg, id_token_encrypted_response_alg, id_token_encrypted_response_enc, userinfo_signed_response_alg, userinfo_encrypted_response_alg, userinfo_encrypted_response_enc, request_object_signing_alg, request_object_encryption_alg, request_object_encryption_enc, token_endpoint_auth_signing_alg, default_max_age, require_auth_time, default_acr_values, initiate_login_uri, request_uris, access_token_signing_alg, id_token_ttl, token_ttl, refresh_token_ttl, created_at, updated_at FROM "apps" WHERE "client_id" = $1 AND "version" = $2 LIMIT 1 ` @@ -859,22 +1039,25 @@ func (q *Queries) FindAppByClientIDAndVersion(ctx context.Context, arg FindAppBy &i.ID, &i.AccountID, &i.AccountPublicID, - &i.AppType, - &i.Name, &i.ClientID, &i.Version, &i.CreationMethod, + &i.RedirectUris, + &i.TokenEndpointAuthMethod, + &i.GrantTypes, + &i.ResponseTypes, + &i.ClientName, &i.ClientUri, &i.LogoUri, + &i.Scopes, + &i.CustomScopes, + &i.Contacts, &i.TosUri, &i.PolicyUri, + &i.JwksUri, + &i.Jwks, &i.SoftwareID, &i.SoftwareVersion, - &i.Contacts, - &i.TokenEndpointAuthMethod, - &i.Scopes, - &i.CustomScopes, - &i.GrantTypes, &i.Domain, &i.Transport, &i.AllowUserRegistration, @@ -882,8 +1065,25 @@ func (q *Queries) FindAppByClientIDAndVersion(ctx context.Context, arg FindAppBy &i.UsernameColumn, &i.DefaultScopes, &i.DefaultCustomScopes, - &i.RedirectUris, - &i.ResponseTypes, + &i.AppType, + &i.SectorIdentifierUri, + &i.SubjectType, + &i.IDTokenSignedResponseAlg, + &i.IDTokenEncryptedResponseAlg, + &i.IDTokenEncryptedResponseEnc, + &i.UserinfoSignedResponseAlg, + &i.UserinfoEncryptedResponseAlg, + &i.UserinfoEncryptedResponseEnc, + &i.RequestObjectSigningAlg, + &i.RequestObjectEncryptionAlg, + &i.RequestObjectEncryptionEnc, + &i.TokenEndpointAuthSigningAlg, + &i.DefaultMaxAge, + &i.RequireAuthTime, + &i.DefaultAcrValues, + &i.InitiateLoginUri, + &i.RequestUris, + &i.AccessTokenSigningAlg, &i.IDTokenTtl, &i.TokenTtl, &i.RefreshTokenTtl, @@ -894,7 +1094,7 @@ func (q *Queries) FindAppByClientIDAndVersion(ctx context.Context, arg FindAppBy } const findAppByID = `-- name: FindAppByID :one -SELECT id, account_id, account_public_id, app_type, name, client_id, version, creation_method, client_uri, logo_uri, tos_uri, policy_uri, software_id, software_version, contacts, token_endpoint_auth_method, scopes, custom_scopes, grant_types, domain, transport, allow_user_registration, auth_providers, username_column, default_scopes, default_custom_scopes, redirect_uris, response_types, id_token_ttl, token_ttl, refresh_token_ttl, created_at, updated_at FROM "apps" +SELECT id, account_id, account_public_id, client_id, version, creation_method, redirect_uris, token_endpoint_auth_method, grant_types, response_types, client_name, client_uri, logo_uri, scopes, custom_scopes, contacts, tos_uri, policy_uri, jwks_uri, jwks, software_id, software_version, domain, transport, allow_user_registration, auth_providers, username_column, default_scopes, default_custom_scopes, app_type, sector_identifier_uri, subject_type, id_token_signed_response_alg, id_token_encrypted_response_alg, id_token_encrypted_response_enc, userinfo_signed_response_alg, userinfo_encrypted_response_alg, userinfo_encrypted_response_enc, request_object_signing_alg, request_object_encryption_alg, request_object_encryption_enc, token_endpoint_auth_signing_alg, default_max_age, require_auth_time, default_acr_values, initiate_login_uri, request_uris, access_token_signing_alg, id_token_ttl, token_ttl, refresh_token_ttl, created_at, updated_at FROM "apps" WHERE "id" = $1 LIMIT 1 ` @@ -905,22 +1105,25 @@ func (q *Queries) FindAppByID(ctx context.Context, id int32) (App, error) { &i.ID, &i.AccountID, &i.AccountPublicID, - &i.AppType, - &i.Name, &i.ClientID, &i.Version, &i.CreationMethod, + &i.RedirectUris, + &i.TokenEndpointAuthMethod, + &i.GrantTypes, + &i.ResponseTypes, + &i.ClientName, &i.ClientUri, &i.LogoUri, + &i.Scopes, + &i.CustomScopes, + &i.Contacts, &i.TosUri, &i.PolicyUri, + &i.JwksUri, + &i.Jwks, &i.SoftwareID, &i.SoftwareVersion, - &i.Contacts, - &i.TokenEndpointAuthMethod, - &i.Scopes, - &i.CustomScopes, - &i.GrantTypes, &i.Domain, &i.Transport, &i.AllowUserRegistration, @@ -928,8 +1131,25 @@ func (q *Queries) FindAppByID(ctx context.Context, id int32) (App, error) { &i.UsernameColumn, &i.DefaultScopes, &i.DefaultCustomScopes, - &i.RedirectUris, - &i.ResponseTypes, + &i.AppType, + &i.SectorIdentifierUri, + &i.SubjectType, + &i.IDTokenSignedResponseAlg, + &i.IDTokenEncryptedResponseAlg, + &i.IDTokenEncryptedResponseEnc, + &i.UserinfoSignedResponseAlg, + &i.UserinfoEncryptedResponseAlg, + &i.UserinfoEncryptedResponseEnc, + &i.RequestObjectSigningAlg, + &i.RequestObjectEncryptionAlg, + &i.RequestObjectEncryptionEnc, + &i.TokenEndpointAuthSigningAlg, + &i.DefaultMaxAge, + &i.RequireAuthTime, + &i.DefaultAcrValues, + &i.InitiateLoginUri, + &i.RequestUris, + &i.AccessTokenSigningAlg, &i.IDTokenTtl, &i.TokenTtl, &i.RefreshTokenTtl, @@ -940,9 +1160,9 @@ func (q *Queries) FindAppByID(ctx context.Context, id int32) (App, error) { } const findAppsByClientIDsAndAccountID = `-- name: FindAppsByClientIDsAndAccountID :many -SELECT id, account_id, account_public_id, app_type, name, client_id, version, creation_method, client_uri, logo_uri, tos_uri, policy_uri, software_id, software_version, contacts, token_endpoint_auth_method, scopes, custom_scopes, grant_types, domain, transport, allow_user_registration, auth_providers, username_column, default_scopes, default_custom_scopes, redirect_uris, response_types, id_token_ttl, token_ttl, refresh_token_ttl, created_at, updated_at FROM "apps" +SELECT id, account_id, account_public_id, client_id, version, creation_method, redirect_uris, token_endpoint_auth_method, grant_types, response_types, client_name, client_uri, logo_uri, scopes, custom_scopes, contacts, tos_uri, policy_uri, jwks_uri, jwks, software_id, software_version, domain, transport, allow_user_registration, auth_providers, username_column, default_scopes, default_custom_scopes, app_type, sector_identifier_uri, subject_type, id_token_signed_response_alg, id_token_encrypted_response_alg, id_token_encrypted_response_enc, userinfo_signed_response_alg, userinfo_encrypted_response_alg, userinfo_encrypted_response_enc, request_object_signing_alg, request_object_encryption_alg, request_object_encryption_enc, token_endpoint_auth_signing_alg, default_max_age, require_auth_time, default_acr_values, initiate_login_uri, request_uris, access_token_signing_alg, id_token_ttl, token_ttl, refresh_token_ttl, created_at, updated_at FROM "apps" WHERE "client_id" IN ($3) AND "account_id" = $1 -ORDER BY "name" ASC LIMIT $2 +ORDER BY "client_name" ASC LIMIT $2 ` type FindAppsByClientIDsAndAccountIDParams struct { @@ -964,22 +1184,25 @@ func (q *Queries) FindAppsByClientIDsAndAccountID(ctx context.Context, arg FindA &i.ID, &i.AccountID, &i.AccountPublicID, - &i.AppType, - &i.Name, &i.ClientID, &i.Version, &i.CreationMethod, + &i.RedirectUris, + &i.TokenEndpointAuthMethod, + &i.GrantTypes, + &i.ResponseTypes, + &i.ClientName, &i.ClientUri, &i.LogoUri, + &i.Scopes, + &i.CustomScopes, + &i.Contacts, &i.TosUri, &i.PolicyUri, + &i.JwksUri, + &i.Jwks, &i.SoftwareID, &i.SoftwareVersion, - &i.Contacts, - &i.TokenEndpointAuthMethod, - &i.Scopes, - &i.CustomScopes, - &i.GrantTypes, &i.Domain, &i.Transport, &i.AllowUserRegistration, @@ -987,8 +1210,25 @@ func (q *Queries) FindAppsByClientIDsAndAccountID(ctx context.Context, arg FindA &i.UsernameColumn, &i.DefaultScopes, &i.DefaultCustomScopes, - &i.RedirectUris, - &i.ResponseTypes, + &i.AppType, + &i.SectorIdentifierUri, + &i.SubjectType, + &i.IDTokenSignedResponseAlg, + &i.IDTokenEncryptedResponseAlg, + &i.IDTokenEncryptedResponseEnc, + &i.UserinfoSignedResponseAlg, + &i.UserinfoEncryptedResponseAlg, + &i.UserinfoEncryptedResponseEnc, + &i.RequestObjectSigningAlg, + &i.RequestObjectEncryptionAlg, + &i.RequestObjectEncryptionEnc, + &i.TokenEndpointAuthSigningAlg, + &i.DefaultMaxAge, + &i.RequireAuthTime, + &i.DefaultAcrValues, + &i.InitiateLoginUri, + &i.RequestUris, + &i.AccessTokenSigningAlg, &i.IDTokenTtl, &i.TokenTtl, &i.RefreshTokenTtl, @@ -1006,7 +1246,7 @@ func (q *Queries) FindAppsByClientIDsAndAccountID(ctx context.Context, arg FindA } const findPaginatedAppsByAccountPublicIDOrderedByID = `-- name: FindPaginatedAppsByAccountPublicIDOrderedByID :many -SELECT id, account_id, account_public_id, app_type, name, client_id, version, creation_method, client_uri, logo_uri, tos_uri, policy_uri, software_id, software_version, contacts, token_endpoint_auth_method, scopes, custom_scopes, grant_types, domain, transport, allow_user_registration, auth_providers, username_column, default_scopes, default_custom_scopes, redirect_uris, response_types, id_token_ttl, token_ttl, refresh_token_ttl, created_at, updated_at FROM "apps" +SELECT id, account_id, account_public_id, client_id, version, creation_method, redirect_uris, token_endpoint_auth_method, grant_types, response_types, client_name, client_uri, logo_uri, scopes, custom_scopes, contacts, tos_uri, policy_uri, jwks_uri, jwks, software_id, software_version, domain, transport, allow_user_registration, auth_providers, username_column, default_scopes, default_custom_scopes, app_type, sector_identifier_uri, subject_type, id_token_signed_response_alg, id_token_encrypted_response_alg, id_token_encrypted_response_enc, userinfo_signed_response_alg, userinfo_encrypted_response_alg, userinfo_encrypted_response_enc, request_object_signing_alg, request_object_encryption_alg, request_object_encryption_enc, token_endpoint_auth_signing_alg, default_max_age, require_auth_time, default_acr_values, initiate_login_uri, request_uris, access_token_signing_alg, id_token_ttl, token_ttl, refresh_token_ttl, created_at, updated_at FROM "apps" WHERE "account_public_id" = $1 ORDER BY "id" DESC OFFSET $2 LIMIT $3 @@ -1031,22 +1271,25 @@ func (q *Queries) FindPaginatedAppsByAccountPublicIDOrderedByID(ctx context.Cont &i.ID, &i.AccountID, &i.AccountPublicID, - &i.AppType, - &i.Name, &i.ClientID, &i.Version, &i.CreationMethod, + &i.RedirectUris, + &i.TokenEndpointAuthMethod, + &i.GrantTypes, + &i.ResponseTypes, + &i.ClientName, &i.ClientUri, &i.LogoUri, + &i.Scopes, + &i.CustomScopes, + &i.Contacts, &i.TosUri, &i.PolicyUri, + &i.JwksUri, + &i.Jwks, &i.SoftwareID, &i.SoftwareVersion, - &i.Contacts, - &i.TokenEndpointAuthMethod, - &i.Scopes, - &i.CustomScopes, - &i.GrantTypes, &i.Domain, &i.Transport, &i.AllowUserRegistration, @@ -1054,8 +1297,25 @@ func (q *Queries) FindPaginatedAppsByAccountPublicIDOrderedByID(ctx context.Cont &i.UsernameColumn, &i.DefaultScopes, &i.DefaultCustomScopes, - &i.RedirectUris, - &i.ResponseTypes, + &i.AppType, + &i.SectorIdentifierUri, + &i.SubjectType, + &i.IDTokenSignedResponseAlg, + &i.IDTokenEncryptedResponseAlg, + &i.IDTokenEncryptedResponseEnc, + &i.UserinfoSignedResponseAlg, + &i.UserinfoEncryptedResponseAlg, + &i.UserinfoEncryptedResponseEnc, + &i.RequestObjectSigningAlg, + &i.RequestObjectEncryptionAlg, + &i.RequestObjectEncryptionEnc, + &i.TokenEndpointAuthSigningAlg, + &i.DefaultMaxAge, + &i.RequireAuthTime, + &i.DefaultAcrValues, + &i.InitiateLoginUri, + &i.RequestUris, + &i.AccessTokenSigningAlg, &i.IDTokenTtl, &i.TokenTtl, &i.RefreshTokenTtl, @@ -1073,9 +1333,9 @@ func (q *Queries) FindPaginatedAppsByAccountPublicIDOrderedByID(ctx context.Cont } const findPaginatedAppsByAccountPublicIDOrderedByName = `-- name: FindPaginatedAppsByAccountPublicIDOrderedByName :many -SELECT id, account_id, account_public_id, app_type, name, client_id, version, creation_method, client_uri, logo_uri, tos_uri, policy_uri, software_id, software_version, contacts, token_endpoint_auth_method, scopes, custom_scopes, grant_types, domain, transport, allow_user_registration, auth_providers, username_column, default_scopes, default_custom_scopes, redirect_uris, response_types, id_token_ttl, token_ttl, refresh_token_ttl, created_at, updated_at FROM "apps" +SELECT id, account_id, account_public_id, client_id, version, creation_method, redirect_uris, token_endpoint_auth_method, grant_types, response_types, client_name, client_uri, logo_uri, scopes, custom_scopes, contacts, tos_uri, policy_uri, jwks_uri, jwks, software_id, software_version, domain, transport, allow_user_registration, auth_providers, username_column, default_scopes, default_custom_scopes, app_type, sector_identifier_uri, subject_type, id_token_signed_response_alg, id_token_encrypted_response_alg, id_token_encrypted_response_enc, userinfo_signed_response_alg, userinfo_encrypted_response_alg, userinfo_encrypted_response_enc, request_object_signing_alg, request_object_encryption_alg, request_object_encryption_enc, token_endpoint_auth_signing_alg, default_max_age, require_auth_time, default_acr_values, initiate_login_uri, request_uris, access_token_signing_alg, id_token_ttl, token_ttl, refresh_token_ttl, created_at, updated_at FROM "apps" WHERE "account_public_id" = $1 -ORDER BY "name" ASC +ORDER BY "client_name" ASC OFFSET $2 LIMIT $3 ` @@ -1098,22 +1358,25 @@ func (q *Queries) FindPaginatedAppsByAccountPublicIDOrderedByName(ctx context.Co &i.ID, &i.AccountID, &i.AccountPublicID, - &i.AppType, - &i.Name, &i.ClientID, &i.Version, &i.CreationMethod, + &i.RedirectUris, + &i.TokenEndpointAuthMethod, + &i.GrantTypes, + &i.ResponseTypes, + &i.ClientName, &i.ClientUri, &i.LogoUri, + &i.Scopes, + &i.CustomScopes, + &i.Contacts, &i.TosUri, &i.PolicyUri, + &i.JwksUri, + &i.Jwks, &i.SoftwareID, &i.SoftwareVersion, - &i.Contacts, - &i.TokenEndpointAuthMethod, - &i.Scopes, - &i.CustomScopes, - &i.GrantTypes, &i.Domain, &i.Transport, &i.AllowUserRegistration, @@ -1121,8 +1384,25 @@ func (q *Queries) FindPaginatedAppsByAccountPublicIDOrderedByName(ctx context.Co &i.UsernameColumn, &i.DefaultScopes, &i.DefaultCustomScopes, - &i.RedirectUris, - &i.ResponseTypes, + &i.AppType, + &i.SectorIdentifierUri, + &i.SubjectType, + &i.IDTokenSignedResponseAlg, + &i.IDTokenEncryptedResponseAlg, + &i.IDTokenEncryptedResponseEnc, + &i.UserinfoSignedResponseAlg, + &i.UserinfoEncryptedResponseAlg, + &i.UserinfoEncryptedResponseEnc, + &i.RequestObjectSigningAlg, + &i.RequestObjectEncryptionAlg, + &i.RequestObjectEncryptionEnc, + &i.TokenEndpointAuthSigningAlg, + &i.DefaultMaxAge, + &i.RequireAuthTime, + &i.DefaultAcrValues, + &i.InitiateLoginUri, + &i.RequestUris, + &i.AccessTokenSigningAlg, &i.IDTokenTtl, &i.TokenTtl, &i.RefreshTokenTtl, @@ -1141,7 +1421,7 @@ func (q *Queries) FindPaginatedAppsByAccountPublicIDOrderedByName(ctx context.Co const updateApp = `-- name: UpdateApp :one UPDATE "apps" -SET "name" = $2, +SET "client_name" = $2, "username_column" = $3, "client_uri" = $4, "logo_uri" = $5, @@ -1158,12 +1438,12 @@ SET "name" = $2, "version" = "version" + 1, "updated_at" = now() WHERE "id" = $1 -RETURNING id, account_id, account_public_id, app_type, name, client_id, version, creation_method, client_uri, logo_uri, tos_uri, policy_uri, software_id, software_version, contacts, token_endpoint_auth_method, scopes, custom_scopes, grant_types, domain, transport, allow_user_registration, auth_providers, username_column, default_scopes, default_custom_scopes, redirect_uris, response_types, id_token_ttl, token_ttl, refresh_token_ttl, created_at, updated_at +RETURNING id, account_id, account_public_id, client_id, version, creation_method, redirect_uris, token_endpoint_auth_method, grant_types, response_types, client_name, client_uri, logo_uri, scopes, custom_scopes, contacts, tos_uri, policy_uri, jwks_uri, jwks, software_id, software_version, domain, transport, allow_user_registration, auth_providers, username_column, default_scopes, default_custom_scopes, app_type, sector_identifier_uri, subject_type, id_token_signed_response_alg, id_token_encrypted_response_alg, id_token_encrypted_response_enc, userinfo_signed_response_alg, userinfo_encrypted_response_alg, userinfo_encrypted_response_enc, request_object_signing_alg, request_object_encryption_alg, request_object_encryption_enc, token_endpoint_auth_signing_alg, default_max_age, require_auth_time, default_acr_values, initiate_login_uri, request_uris, access_token_signing_alg, id_token_ttl, token_ttl, refresh_token_ttl, created_at, updated_at ` type UpdateAppParams struct { ID int32 - Name string + ClientName string UsernameColumn AppUsernameColumn ClientUri string LogoUri pgtype.Text @@ -1182,7 +1462,7 @@ type UpdateAppParams struct { func (q *Queries) UpdateApp(ctx context.Context, arg UpdateAppParams) (App, error) { row := q.db.QueryRow(ctx, updateApp, arg.ID, - arg.Name, + arg.ClientName, arg.UsernameColumn, arg.ClientUri, arg.LogoUri, @@ -1202,22 +1482,25 @@ func (q *Queries) UpdateApp(ctx context.Context, arg UpdateAppParams) (App, erro &i.ID, &i.AccountID, &i.AccountPublicID, - &i.AppType, - &i.Name, &i.ClientID, &i.Version, &i.CreationMethod, + &i.RedirectUris, + &i.TokenEndpointAuthMethod, + &i.GrantTypes, + &i.ResponseTypes, + &i.ClientName, &i.ClientUri, &i.LogoUri, + &i.Scopes, + &i.CustomScopes, + &i.Contacts, &i.TosUri, &i.PolicyUri, + &i.JwksUri, + &i.Jwks, &i.SoftwareID, &i.SoftwareVersion, - &i.Contacts, - &i.TokenEndpointAuthMethod, - &i.Scopes, - &i.CustomScopes, - &i.GrantTypes, &i.Domain, &i.Transport, &i.AllowUserRegistration, @@ -1225,8 +1508,25 @@ func (q *Queries) UpdateApp(ctx context.Context, arg UpdateAppParams) (App, erro &i.UsernameColumn, &i.DefaultScopes, &i.DefaultCustomScopes, - &i.RedirectUris, - &i.ResponseTypes, + &i.AppType, + &i.SectorIdentifierUri, + &i.SubjectType, + &i.IDTokenSignedResponseAlg, + &i.IDTokenEncryptedResponseAlg, + &i.IDTokenEncryptedResponseEnc, + &i.UserinfoSignedResponseAlg, + &i.UserinfoEncryptedResponseAlg, + &i.UserinfoEncryptedResponseEnc, + &i.RequestObjectSigningAlg, + &i.RequestObjectEncryptionAlg, + &i.RequestObjectEncryptionEnc, + &i.TokenEndpointAuthSigningAlg, + &i.DefaultMaxAge, + &i.RequireAuthTime, + &i.DefaultAcrValues, + &i.InitiateLoginUri, + &i.RequestUris, + &i.AccessTokenSigningAlg, &i.IDTokenTtl, &i.TokenTtl, &i.RefreshTokenTtl, @@ -1245,7 +1545,7 @@ SET "scopes" = $2, "version" = "version" + 1, "updated_at" = now() WHERE "id" = $1 -RETURNING id, account_id, account_public_id, app_type, name, client_id, version, creation_method, client_uri, logo_uri, tos_uri, policy_uri, software_id, software_version, contacts, token_endpoint_auth_method, scopes, custom_scopes, grant_types, domain, transport, allow_user_registration, auth_providers, username_column, default_scopes, default_custom_scopes, redirect_uris, response_types, id_token_ttl, token_ttl, refresh_token_ttl, created_at, updated_at +RETURNING id, account_id, account_public_id, client_id, version, creation_method, redirect_uris, token_endpoint_auth_method, grant_types, response_types, client_name, client_uri, logo_uri, scopes, custom_scopes, contacts, tos_uri, policy_uri, jwks_uri, jwks, software_id, software_version, domain, transport, allow_user_registration, auth_providers, username_column, default_scopes, default_custom_scopes, app_type, sector_identifier_uri, subject_type, id_token_signed_response_alg, id_token_encrypted_response_alg, id_token_encrypted_response_enc, userinfo_signed_response_alg, userinfo_encrypted_response_alg, userinfo_encrypted_response_enc, request_object_signing_alg, request_object_encryption_alg, request_object_encryption_enc, token_endpoint_auth_signing_alg, default_max_age, require_auth_time, default_acr_values, initiate_login_uri, request_uris, access_token_signing_alg, id_token_ttl, token_ttl, refresh_token_ttl, created_at, updated_at ` type UpdateAppScopesParams struct { @@ -1269,22 +1569,25 @@ func (q *Queries) UpdateAppScopes(ctx context.Context, arg UpdateAppScopesParams &i.ID, &i.AccountID, &i.AccountPublicID, - &i.AppType, - &i.Name, &i.ClientID, &i.Version, &i.CreationMethod, + &i.RedirectUris, + &i.TokenEndpointAuthMethod, + &i.GrantTypes, + &i.ResponseTypes, + &i.ClientName, &i.ClientUri, &i.LogoUri, + &i.Scopes, + &i.CustomScopes, + &i.Contacts, &i.TosUri, &i.PolicyUri, + &i.JwksUri, + &i.Jwks, &i.SoftwareID, &i.SoftwareVersion, - &i.Contacts, - &i.TokenEndpointAuthMethod, - &i.Scopes, - &i.CustomScopes, - &i.GrantTypes, &i.Domain, &i.Transport, &i.AllowUserRegistration, @@ -1292,8 +1595,25 @@ func (q *Queries) UpdateAppScopes(ctx context.Context, arg UpdateAppScopesParams &i.UsernameColumn, &i.DefaultScopes, &i.DefaultCustomScopes, - &i.RedirectUris, - &i.ResponseTypes, + &i.AppType, + &i.SectorIdentifierUri, + &i.SubjectType, + &i.IDTokenSignedResponseAlg, + &i.IDTokenEncryptedResponseAlg, + &i.IDTokenEncryptedResponseEnc, + &i.UserinfoSignedResponseAlg, + &i.UserinfoEncryptedResponseAlg, + &i.UserinfoEncryptedResponseEnc, + &i.RequestObjectSigningAlg, + &i.RequestObjectEncryptionAlg, + &i.RequestObjectEncryptionEnc, + &i.TokenEndpointAuthSigningAlg, + &i.DefaultMaxAge, + &i.RequireAuthTime, + &i.DefaultAcrValues, + &i.InitiateLoginUri, + &i.RequestUris, + &i.AccessTokenSigningAlg, &i.IDTokenTtl, &i.TokenTtl, &i.RefreshTokenTtl, @@ -1308,7 +1628,7 @@ UPDATE "apps" SET "version" = "version" + 1, "updated_at" = now() WHERE "id" = $1 -RETURNING id, account_id, account_public_id, app_type, name, client_id, version, creation_method, client_uri, logo_uri, tos_uri, policy_uri, software_id, software_version, contacts, token_endpoint_auth_method, scopes, custom_scopes, grant_types, domain, transport, allow_user_registration, auth_providers, username_column, default_scopes, default_custom_scopes, redirect_uris, response_types, id_token_ttl, token_ttl, refresh_token_ttl, created_at, updated_at +RETURNING id, account_id, account_public_id, client_id, version, creation_method, redirect_uris, token_endpoint_auth_method, grant_types, response_types, client_name, client_uri, logo_uri, scopes, custom_scopes, contacts, tos_uri, policy_uri, jwks_uri, jwks, software_id, software_version, domain, transport, allow_user_registration, auth_providers, username_column, default_scopes, default_custom_scopes, app_type, sector_identifier_uri, subject_type, id_token_signed_response_alg, id_token_encrypted_response_alg, id_token_encrypted_response_enc, userinfo_signed_response_alg, userinfo_encrypted_response_alg, userinfo_encrypted_response_enc, request_object_signing_alg, request_object_encryption_alg, request_object_encryption_enc, token_endpoint_auth_signing_alg, default_max_age, require_auth_time, default_acr_values, initiate_login_uri, request_uris, access_token_signing_alg, id_token_ttl, token_ttl, refresh_token_ttl, created_at, updated_at ` func (q *Queries) UpdateAppVersion(ctx context.Context, id int32) (App, error) { @@ -1318,22 +1638,25 @@ func (q *Queries) UpdateAppVersion(ctx context.Context, id int32) (App, error) { &i.ID, &i.AccountID, &i.AccountPublicID, - &i.AppType, - &i.Name, &i.ClientID, &i.Version, &i.CreationMethod, + &i.RedirectUris, + &i.TokenEndpointAuthMethod, + &i.GrantTypes, + &i.ResponseTypes, + &i.ClientName, &i.ClientUri, &i.LogoUri, + &i.Scopes, + &i.CustomScopes, + &i.Contacts, &i.TosUri, &i.PolicyUri, + &i.JwksUri, + &i.Jwks, &i.SoftwareID, &i.SoftwareVersion, - &i.Contacts, - &i.TokenEndpointAuthMethod, - &i.Scopes, - &i.CustomScopes, - &i.GrantTypes, &i.Domain, &i.Transport, &i.AllowUserRegistration, @@ -1341,8 +1664,25 @@ func (q *Queries) UpdateAppVersion(ctx context.Context, id int32) (App, error) { &i.UsernameColumn, &i.DefaultScopes, &i.DefaultCustomScopes, - &i.RedirectUris, - &i.ResponseTypes, + &i.AppType, + &i.SectorIdentifierUri, + &i.SubjectType, + &i.IDTokenSignedResponseAlg, + &i.IDTokenEncryptedResponseAlg, + &i.IDTokenEncryptedResponseEnc, + &i.UserinfoSignedResponseAlg, + &i.UserinfoEncryptedResponseAlg, + &i.UserinfoEncryptedResponseEnc, + &i.RequestObjectSigningAlg, + &i.RequestObjectEncryptionAlg, + &i.RequestObjectEncryptionEnc, + &i.TokenEndpointAuthSigningAlg, + &i.DefaultMaxAge, + &i.RequireAuthTime, + &i.DefaultAcrValues, + &i.InitiateLoginUri, + &i.RequestUris, + &i.AccessTokenSigningAlg, &i.IDTokenTtl, &i.TokenTtl, &i.RefreshTokenTtl, diff --git a/idp/internal/providers/database/credentials_keys.sql.go b/idp/internal/providers/database/credentials_keys.sql.go index 98c9006..3dadbcc 100644 --- a/idp/internal/providers/database/credentials_keys.sql.go +++ b/idp/internal/providers/database/credentials_keys.sql.go @@ -26,7 +26,7 @@ INSERT INTO "credentials_keys" ( $4, $5, $6 -) RETURNING id, public_kid, public_key, crypto_suite, is_revoked, usage, account_id, expires_at, created_at, updated_at +) RETURNING id, public_kid, public_key, crypto_suite, is_revoked, is_external, usage, account_id, expires_at, created_at, updated_at ` type CreateCredentialsKeyParams struct { @@ -59,6 +59,7 @@ func (q *Queries) CreateCredentialsKey(ctx context.Context, arg CreateCredential &i.PublicKey, &i.CryptoSuite, &i.IsRevoked, + &i.IsExternal, &i.Usage, &i.AccountID, &i.ExpiresAt, @@ -77,6 +78,31 @@ func (q *Queries) DeleteAllCredentialsKeys(ctx context.Context) error { return err } +const findCredentialsKeyByID = `-- name: FindCredentialsKeyByID :one +SELECT id, public_kid, public_key, crypto_suite, is_revoked, is_external, usage, account_id, expires_at, created_at, updated_at FROM "credentials_keys" +WHERE "id" = $1 +LIMIT 1 +` + +func (q *Queries) FindCredentialsKeyByID(ctx context.Context, id int32) (CredentialsKey, error) { + row := q.db.QueryRow(ctx, findCredentialsKeyByID, id) + var i CredentialsKey + err := row.Scan( + &i.ID, + &i.PublicKid, + &i.PublicKey, + &i.CryptoSuite, + &i.IsRevoked, + &i.IsExternal, + &i.Usage, + &i.AccountID, + &i.ExpiresAt, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + const findCredentialsKeyPublicKeyByPublicKIDCryptoSuiteAndUsage = `-- name: FindCredentialsKeyPublicKeyByPublicKIDCryptoSuiteAndUsage :one SELECT "public_key" FROM "credentials_keys" WHERE @@ -106,7 +132,7 @@ UPDATE "credentials_keys" SET "is_revoked" = true, "updated_at" = now() WHERE "id" = $1 -RETURNING id, public_kid, public_key, crypto_suite, is_revoked, usage, account_id, expires_at, created_at, updated_at +RETURNING id, public_kid, public_key, crypto_suite, is_revoked, is_external, usage, account_id, expires_at, created_at, updated_at ` func (q *Queries) RevokeCredentialsKey(ctx context.Context, id int32) (CredentialsKey, error) { @@ -118,6 +144,7 @@ func (q *Queries) RevokeCredentialsKey(ctx context.Context, id int32) (Credentia &i.PublicKey, &i.CryptoSuite, &i.IsRevoked, + &i.IsExternal, &i.Usage, &i.AccountID, &i.ExpiresAt, diff --git a/idp/internal/providers/database/dynamic_registration_domain_codes.sql.go b/idp/internal/providers/database/dynamic_registration_domain_codes.sql.go index 0c7de03..5bad396 100644 --- a/idp/internal/providers/database/dynamic_registration_domain_codes.sql.go +++ b/idp/internal/providers/database/dynamic_registration_domain_codes.sql.go @@ -10,10 +10,11 @@ import ( "time" ) -const createDynamicRegistrationDomainCode = `-- name: CreateDynamicRegistrationDomainCode :one +const createDynamicRegistrationDomainCode = `-- name: CreateDynamicRegistrationDomainCode :exec INSERT INTO "dynamic_registration_domain_codes" ( "account_id", + "dynamic_registration_domain_id", "verification_host", "verification_code", "verification_prefix", @@ -25,17 +26,19 @@ INSERT INTO "dynamic_registration_domain_codes" ( $3, $4, $5, - $6 -) RETURNING "id" + $6, + $7 +) ` type CreateDynamicRegistrationDomainCodeParams struct { - AccountID int32 - VerificationHost string - VerificationCode string - VerificationPrefix string - HmacSecretID string - ExpiresAt time.Time + AccountID int32 + DynamicRegistrationDomainID int32 + VerificationHost string + VerificationCode string + VerificationPrefix string + HmacSecretID string + ExpiresAt time.Time } // Copyright (c) 2025 Afonso Barracha @@ -43,18 +46,17 @@ type CreateDynamicRegistrationDomainCodeParams struct { // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at https://mozilla.org/MPL/2.0/. -func (q *Queries) CreateDynamicRegistrationDomainCode(ctx context.Context, arg CreateDynamicRegistrationDomainCodeParams) (int32, error) { - row := q.db.QueryRow(ctx, createDynamicRegistrationDomainCode, +func (q *Queries) CreateDynamicRegistrationDomainCode(ctx context.Context, arg CreateDynamicRegistrationDomainCodeParams) error { + _, err := q.db.Exec(ctx, createDynamicRegistrationDomainCode, arg.AccountID, + arg.DynamicRegistrationDomainID, arg.VerificationHost, arg.VerificationCode, arg.VerificationPrefix, arg.HmacSecretID, arg.ExpiresAt, ) - var id int32 - err := row.Scan(&id) - return id, err + return err } const deleteDynamicRegistrationDomainCode = `-- name: DeleteDynamicRegistrationDomainCode :exec @@ -67,6 +69,29 @@ func (q *Queries) DeleteDynamicRegistrationDomainCode(ctx context.Context, id in return err } +const findDynamicRegistrationDomainCodeByDynamicRegistrationDomainID = `-- name: FindDynamicRegistrationDomainCodeByDynamicRegistrationDomainID :one +SELECT id, account_id, dynamic_registration_domain_id, verification_host, verification_code, hmac_secret_id, verification_prefix, expires_at, created_at, updated_at FROM "dynamic_registration_domain_codes" +WHERE "dynamic_registration_domain_id" = $1 +` + +func (q *Queries) FindDynamicRegistrationDomainCodeByDynamicRegistrationDomainID(ctx context.Context, dynamicRegistrationDomainID int32) (DynamicRegistrationDomainCode, error) { + row := q.db.QueryRow(ctx, findDynamicRegistrationDomainCodeByDynamicRegistrationDomainID, dynamicRegistrationDomainID) + var i DynamicRegistrationDomainCode + err := row.Scan( + &i.ID, + &i.AccountID, + &i.DynamicRegistrationDomainID, + &i.VerificationHost, + &i.VerificationCode, + &i.HmacSecretID, + &i.VerificationPrefix, + &i.ExpiresAt, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + const updateDynamicRegistrationDomainCode = `-- name: UpdateDynamicRegistrationDomainCode :exec UPDATE "dynamic_registration_domain_codes" SET "verification_host" = $2, diff --git a/idp/internal/providers/database/dynamic_registration_domains.sql.go b/idp/internal/providers/database/dynamic_registration_domains.sql.go new file mode 100644 index 0000000..1e45f98 --- /dev/null +++ b/idp/internal/providers/database/dynamic_registration_domains.sql.go @@ -0,0 +1,485 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.29.0 +// source: dynamic_registration_domains.sql + +package database + +import ( + "context" + + "github.com/google/uuid" +) + +const countDynamicRegistrationDomainsByAccountPublicID = `-- name: CountDynamicRegistrationDomainsByAccountPublicID :one +SELECT COUNT(*) FROM "dynamic_registration_domains" +WHERE "account_public_id" = $1 +` + +func (q *Queries) CountDynamicRegistrationDomainsByAccountPublicID(ctx context.Context, accountPublicID uuid.UUID) (int64, error) { + row := q.db.QueryRow(ctx, countDynamicRegistrationDomainsByAccountPublicID, accountPublicID) + var count int64 + err := row.Scan(&count) + return count, err +} + +const countDynamicRegistrationDomainsByDomain = `-- name: CountDynamicRegistrationDomainsByDomain :one +SELECT COUNT(*) FROM "dynamic_registration_domains" +WHERE "domain" = $1 +LIMIT 1 +` + +func (q *Queries) CountDynamicRegistrationDomainsByDomain(ctx context.Context, domain string) (int64, error) { + row := q.db.QueryRow(ctx, countDynamicRegistrationDomainsByDomain, domain) + var count int64 + err := row.Scan(&count) + return count, err +} + +const countDynamicRegistrationDomainsByDomainAndAccountPublicID = `-- name: CountDynamicRegistrationDomainsByDomainAndAccountPublicID :one +SELECT COUNT(*) FROM "dynamic_registration_domains" +WHERE + "account_public_id" = $1 AND + "domain" = $2 +LIMIT 1 +` + +type CountDynamicRegistrationDomainsByDomainAndAccountPublicIDParams struct { + AccountPublicID uuid.UUID + Domain string +} + +func (q *Queries) CountDynamicRegistrationDomainsByDomainAndAccountPublicID(ctx context.Context, arg CountDynamicRegistrationDomainsByDomainAndAccountPublicIDParams) (int64, error) { + row := q.db.QueryRow(ctx, countDynamicRegistrationDomainsByDomainAndAccountPublicID, arg.AccountPublicID, arg.Domain) + var count int64 + err := row.Scan(&count) + return count, err +} + +const countDynamicRegistrationDomainsByDomains = `-- name: CountDynamicRegistrationDomainsByDomains :one +SELECT COUNT(*) FROM "dynamic_registration_domains" +WHERE "domain" IN ($1) +LIMIT 1 +` + +func (q *Queries) CountDynamicRegistrationDomainsByDomains(ctx context.Context, domains []string) (int64, error) { + row := q.db.QueryRow(ctx, countDynamicRegistrationDomainsByDomains, domains) + var count int64 + err := row.Scan(&count) + return count, err +} + +const countDynamicRegistrationDomainsByDomainsAndAccountPublicID = `-- name: CountDynamicRegistrationDomainsByDomainsAndAccountPublicID :one +SELECT COUNT(*) FROM "dynamic_registration_domains" +WHERE + "account_public_id" = $1 AND + "domain" IN ($2) +LIMIT 1 +` + +type CountDynamicRegistrationDomainsByDomainsAndAccountPublicIDParams struct { + AccountPublicID uuid.UUID + Domains []string +} + +func (q *Queries) CountDynamicRegistrationDomainsByDomainsAndAccountPublicID(ctx context.Context, arg CountDynamicRegistrationDomainsByDomainsAndAccountPublicIDParams) (int64, error) { + row := q.db.QueryRow(ctx, countDynamicRegistrationDomainsByDomainsAndAccountPublicID, arg.AccountPublicID, arg.Domains) + var count int64 + err := row.Scan(&count) + return count, err +} + +const countFilteredDynamicRegistrationDomainsByAccountPublicID = `-- name: CountFilteredDynamicRegistrationDomainsByAccountPublicID :one +SELECT COUNT(*) FROM "dynamic_registration_domains" +WHERE + "account_public_id" = $1 AND + "domain" ILIKE $2 +LIMIT 1 +` + +type CountFilteredDynamicRegistrationDomainsByAccountPublicIDParams struct { + AccountPublicID uuid.UUID + Domain string +} + +func (q *Queries) CountFilteredDynamicRegistrationDomainsByAccountPublicID(ctx context.Context, arg CountFilteredDynamicRegistrationDomainsByAccountPublicIDParams) (int64, error) { + row := q.db.QueryRow(ctx, countFilteredDynamicRegistrationDomainsByAccountPublicID, arg.AccountPublicID, arg.Domain) + var count int64 + err := row.Scan(&count) + return count, err +} + +const countVerifiedDynamicRegistrationDomainsByDomain = `-- name: CountVerifiedDynamicRegistrationDomainsByDomain :one +SELECT COUNT(*) FROM "dynamic_registration_domains" +WHERE "domain" = $1 AND "verified_at" IS NOT NULL +LIMIT 1 +` + +func (q *Queries) CountVerifiedDynamicRegistrationDomainsByDomain(ctx context.Context, domain string) (int64, error) { + row := q.db.QueryRow(ctx, countVerifiedDynamicRegistrationDomainsByDomain, domain) + var count int64 + err := row.Scan(&count) + return count, err +} + +const countVerifiedDynamicRegistrationDomainsByDomainAndAccountPublicID = `-- name: CountVerifiedDynamicRegistrationDomainsByDomainAndAccountPublicID :one +SELECT COUNT(*) FROM "dynamic_registration_domains" +WHERE + "account_public_id" = $1 AND + "domain" = $2 AND + "verified_at" IS NOT NULL +LIMIT 1 +` + +type CountVerifiedDynamicRegistrationDomainsByDomainAndAccountPublicIDParams struct { + AccountPublicID uuid.UUID + Domain string +} + +func (q *Queries) CountVerifiedDynamicRegistrationDomainsByDomainAndAccountPublicID(ctx context.Context, arg CountVerifiedDynamicRegistrationDomainsByDomainAndAccountPublicIDParams) (int64, error) { + row := q.db.QueryRow(ctx, countVerifiedDynamicRegistrationDomainsByDomainAndAccountPublicID, arg.AccountPublicID, arg.Domain) + var count int64 + err := row.Scan(&count) + return count, err +} + +const countVerifiedDynamicRegistrationDomainsByDomains = `-- name: CountVerifiedDynamicRegistrationDomainsByDomains :one +SELECT COUNT(*) FROM "dynamic_registration_domains" +WHERE "domain" IN ($1) AND "verified_at" IS NOT NULL +LIMIT 1 +` + +func (q *Queries) CountVerifiedDynamicRegistrationDomainsByDomains(ctx context.Context, domains []string) (int64, error) { + row := q.db.QueryRow(ctx, countVerifiedDynamicRegistrationDomainsByDomains, domains) + var count int64 + err := row.Scan(&count) + return count, err +} + +const countVerifiedDynamicRegistrationDomainsByDomainsAndAccountPublicID = `-- name: CountVerifiedDynamicRegistrationDomainsByDomainsAndAccountPublicID :one +SELECT COUNT(*) FROM "dynamic_registration_domains" +WHERE + "account_public_id" = $1 AND + "domain" IN ($2) AND + "verified_at" IS NOT NULL +LIMIT 1 +` + +type CountVerifiedDynamicRegistrationDomainsByDomainsAndAccountPublicIDParams struct { + AccountPublicID uuid.UUID + Domains []string +} + +func (q *Queries) CountVerifiedDynamicRegistrationDomainsByDomainsAndAccountPublicID(ctx context.Context, arg CountVerifiedDynamicRegistrationDomainsByDomainsAndAccountPublicIDParams) (int64, error) { + row := q.db.QueryRow(ctx, countVerifiedDynamicRegistrationDomainsByDomainsAndAccountPublicID, arg.AccountPublicID, arg.Domains) + var count int64 + err := row.Scan(&count) + return count, err +} + +const createDynamicRegistrationDomain = `-- name: CreateDynamicRegistrationDomain :one + +INSERT INTO "dynamic_registration_domains" ( + "account_id", + "account_public_id", + "domain", + "verification_method", + "usages" +) VALUES ( + $1, + $2, + $3, + $4, + $5 +) RETURNING id, account_id, account_public_id, domain, verified_at, verification_method, usages, created_at, updated_at +` + +type CreateDynamicRegistrationDomainParams struct { + AccountID int32 + AccountPublicID uuid.UUID + Domain string + VerificationMethod DomainVerificationMethod + Usages []DynamicRegistrationUsage +} + +// Copyright (c) 2025 Afonso Barracha +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. +func (q *Queries) CreateDynamicRegistrationDomain(ctx context.Context, arg CreateDynamicRegistrationDomainParams) (DynamicRegistrationDomain, error) { + row := q.db.QueryRow(ctx, createDynamicRegistrationDomain, + arg.AccountID, + arg.AccountPublicID, + arg.Domain, + arg.VerificationMethod, + arg.Usages, + ) + var i DynamicRegistrationDomain + err := row.Scan( + &i.ID, + &i.AccountID, + &i.AccountPublicID, + &i.Domain, + &i.VerifiedAt, + &i.VerificationMethod, + &i.Usages, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + +const deleteDynamicRegistrationDomain = `-- name: DeleteDynamicRegistrationDomain :exec +DELETE FROM "dynamic_registration_domains" +WHERE "id" = $1 +` + +func (q *Queries) DeleteDynamicRegistrationDomain(ctx context.Context, id int32) error { + _, err := q.db.Exec(ctx, deleteDynamicRegistrationDomain, id) + return err +} + +const filterDynamicRegistrationDomainsByAccountPublicIDOrderedByDomain = `-- name: FilterDynamicRegistrationDomainsByAccountPublicIDOrderedByDomain :many +SELECT id, account_id, account_public_id, domain, verified_at, verification_method, usages, created_at, updated_at FROM "dynamic_registration_domains" +WHERE + "account_public_id" = $1 AND + "domain" ILIKE $2 +ORDER BY "domain" ASC +LIMIT $3 OFFSET $4 +` + +type FilterDynamicRegistrationDomainsByAccountPublicIDOrderedByDomainParams struct { + AccountPublicID uuid.UUID + Domain string + Limit int32 + Offset int32 +} + +func (q *Queries) FilterDynamicRegistrationDomainsByAccountPublicIDOrderedByDomain(ctx context.Context, arg FilterDynamicRegistrationDomainsByAccountPublicIDOrderedByDomainParams) ([]DynamicRegistrationDomain, error) { + rows, err := q.db.Query(ctx, filterDynamicRegistrationDomainsByAccountPublicIDOrderedByDomain, + arg.AccountPublicID, + arg.Domain, + arg.Limit, + arg.Offset, + ) + if err != nil { + return nil, err + } + defer rows.Close() + items := []DynamicRegistrationDomain{} + for rows.Next() { + var i DynamicRegistrationDomain + if err := rows.Scan( + &i.ID, + &i.AccountID, + &i.AccountPublicID, + &i.Domain, + &i.VerifiedAt, + &i.VerificationMethod, + &i.Usages, + &i.CreatedAt, + &i.UpdatedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const filterDynamicRegistrationDomainsByAccountPublicIDOrderedByID = `-- name: FilterDynamicRegistrationDomainsByAccountPublicIDOrderedByID :many +SELECT id, account_id, account_public_id, domain, verified_at, verification_method, usages, created_at, updated_at FROM "dynamic_registration_domains" +WHERE + "account_public_id" = $1 AND + "domain" ILIKE $2 +ORDER BY "id" DESC +LIMIT $3 OFFSET $4 +` + +type FilterDynamicRegistrationDomainsByAccountPublicIDOrderedByIDParams struct { + AccountPublicID uuid.UUID + Domain string + Limit int32 + Offset int32 +} + +func (q *Queries) FilterDynamicRegistrationDomainsByAccountPublicIDOrderedByID(ctx context.Context, arg FilterDynamicRegistrationDomainsByAccountPublicIDOrderedByIDParams) ([]DynamicRegistrationDomain, error) { + rows, err := q.db.Query(ctx, filterDynamicRegistrationDomainsByAccountPublicIDOrderedByID, + arg.AccountPublicID, + arg.Domain, + arg.Limit, + arg.Offset, + ) + if err != nil { + return nil, err + } + defer rows.Close() + items := []DynamicRegistrationDomain{} + for rows.Next() { + var i DynamicRegistrationDomain + if err := rows.Scan( + &i.ID, + &i.AccountID, + &i.AccountPublicID, + &i.Domain, + &i.VerifiedAt, + &i.VerificationMethod, + &i.Usages, + &i.CreatedAt, + &i.UpdatedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const findDynamicRegistrationDomainByAccountPublicIDAndDomain = `-- name: FindDynamicRegistrationDomainByAccountPublicIDAndDomain :one +SELECT id, account_id, account_public_id, domain, verified_at, verification_method, usages, created_at, updated_at FROM "dynamic_registration_domains" WHERE "account_public_id" = $1 AND "domain" = $2 LIMIT 1 +` + +type FindDynamicRegistrationDomainByAccountPublicIDAndDomainParams struct { + AccountPublicID uuid.UUID + Domain string +} + +func (q *Queries) FindDynamicRegistrationDomainByAccountPublicIDAndDomain(ctx context.Context, arg FindDynamicRegistrationDomainByAccountPublicIDAndDomainParams) (DynamicRegistrationDomain, error) { + row := q.db.QueryRow(ctx, findDynamicRegistrationDomainByAccountPublicIDAndDomain, arg.AccountPublicID, arg.Domain) + var i DynamicRegistrationDomain + err := row.Scan( + &i.ID, + &i.AccountID, + &i.AccountPublicID, + &i.Domain, + &i.VerifiedAt, + &i.VerificationMethod, + &i.Usages, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + +const findPaginatedDynamicRegistrationDomainsByAccountPublicIDOrderedByDomain = `-- name: FindPaginatedDynamicRegistrationDomainsByAccountPublicIDOrderedByDomain :many +SELECT id, account_id, account_public_id, domain, verified_at, verification_method, usages, created_at, updated_at FROM "dynamic_registration_domains" +WHERE "account_public_id" = $1 +ORDER BY "domain" ASC +LIMIT $2 OFFSET $3 +` + +type FindPaginatedDynamicRegistrationDomainsByAccountPublicIDOrderedByDomainParams struct { + AccountPublicID uuid.UUID + Limit int32 + Offset int32 +} + +func (q *Queries) FindPaginatedDynamicRegistrationDomainsByAccountPublicIDOrderedByDomain(ctx context.Context, arg FindPaginatedDynamicRegistrationDomainsByAccountPublicIDOrderedByDomainParams) ([]DynamicRegistrationDomain, error) { + rows, err := q.db.Query(ctx, findPaginatedDynamicRegistrationDomainsByAccountPublicIDOrderedByDomain, arg.AccountPublicID, arg.Limit, arg.Offset) + if err != nil { + return nil, err + } + defer rows.Close() + items := []DynamicRegistrationDomain{} + for rows.Next() { + var i DynamicRegistrationDomain + if err := rows.Scan( + &i.ID, + &i.AccountID, + &i.AccountPublicID, + &i.Domain, + &i.VerifiedAt, + &i.VerificationMethod, + &i.Usages, + &i.CreatedAt, + &i.UpdatedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const findPaginatedDynamicRegistrationDomainsByAccountPublicIDOrderedByID = `-- name: FindPaginatedDynamicRegistrationDomainsByAccountPublicIDOrderedByID :many +SELECT id, account_id, account_public_id, domain, verified_at, verification_method, usages, created_at, updated_at FROM "dynamic_registration_domains" +WHERE "account_public_id" = $1 +ORDER BY "id" DESC +LIMIT $2 OFFSET $3 +` + +type FindPaginatedDynamicRegistrationDomainsByAccountPublicIDOrderedByIDParams struct { + AccountPublicID uuid.UUID + Limit int32 + Offset int32 +} + +func (q *Queries) FindPaginatedDynamicRegistrationDomainsByAccountPublicIDOrderedByID(ctx context.Context, arg FindPaginatedDynamicRegistrationDomainsByAccountPublicIDOrderedByIDParams) ([]DynamicRegistrationDomain, error) { + rows, err := q.db.Query(ctx, findPaginatedDynamicRegistrationDomainsByAccountPublicIDOrderedByID, arg.AccountPublicID, arg.Limit, arg.Offset) + if err != nil { + return nil, err + } + defer rows.Close() + items := []DynamicRegistrationDomain{} + for rows.Next() { + var i DynamicRegistrationDomain + if err := rows.Scan( + &i.ID, + &i.AccountID, + &i.AccountPublicID, + &i.Domain, + &i.VerifiedAt, + &i.VerificationMethod, + &i.Usages, + &i.CreatedAt, + &i.UpdatedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const verifyDynamicRegistrationDomain = `-- name: VerifyDynamicRegistrationDomain :one +UPDATE "dynamic_registration_domains" +SET + "verified_at" = NOW(), + "verification_method" = $2 +WHERE "id" = $1 RETURNING id, account_id, account_public_id, domain, verified_at, verification_method, usages, created_at, updated_at +` + +type VerifyDynamicRegistrationDomainParams struct { + ID int32 + VerificationMethod DomainVerificationMethod +} + +func (q *Queries) VerifyDynamicRegistrationDomain(ctx context.Context, arg VerifyDynamicRegistrationDomainParams) (DynamicRegistrationDomain, error) { + row := q.db.QueryRow(ctx, verifyDynamicRegistrationDomain, arg.ID, arg.VerificationMethod) + var i DynamicRegistrationDomain + err := row.Scan( + &i.ID, + &i.AccountID, + &i.AccountPublicID, + &i.Domain, + &i.VerifiedAt, + &i.VerificationMethod, + &i.Usages, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} diff --git a/idp/internal/providers/database/dynamic_registration_software_statement_keys.sql.go b/idp/internal/providers/database/dynamic_registration_software_statement_keys.sql.go new file mode 100644 index 0000000..02d1de4 --- /dev/null +++ b/idp/internal/providers/database/dynamic_registration_software_statement_keys.sql.go @@ -0,0 +1,75 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.29.0 +// source: dynamic_registration_software_statement_keys.sql + +package database + +import ( + "context" + + "github.com/google/uuid" +) + +const findDynamicRegistrationSoftwareStatementKeysByCredentialsKeyKIDAndAccountPublicID = `-- name: FindDynamicRegistrationSoftwareStatementKeysByCredentialsKeyKIDAndAccountPublicID :one +SELECT id, account_id, account_public_id, credentials_key_id, credentials_key_kid, root_domain, created_at FROM "dynamic_registration_software_statement_keys" +WHERE "credentials_key_kid" = $1 AND "account_public_id" = $2 +LIMIT 1 +` + +type FindDynamicRegistrationSoftwareStatementKeysByCredentialsKeyKIDAndAccountPublicIDParams struct { + CredentialsKeyKid string + AccountPublicID uuid.UUID +} + +func (q *Queries) FindDynamicRegistrationSoftwareStatementKeysByCredentialsKeyKIDAndAccountPublicID(ctx context.Context, arg FindDynamicRegistrationSoftwareStatementKeysByCredentialsKeyKIDAndAccountPublicIDParams) (DynamicRegistrationSoftwareStatementKey, error) { + row := q.db.QueryRow(ctx, findDynamicRegistrationSoftwareStatementKeysByCredentialsKeyKIDAndAccountPublicID, arg.CredentialsKeyKid, arg.AccountPublicID) + var i DynamicRegistrationSoftwareStatementKey + err := row.Scan( + &i.ID, + &i.AccountID, + &i.AccountPublicID, + &i.CredentialsKeyID, + &i.CredentialsKeyKid, + &i.RootDomain, + &i.CreatedAt, + ) + return i, err +} + +const findDynamicRegistrationSoftwareStatementKeysByRootDomainAndAccountPublicID = `-- name: FindDynamicRegistrationSoftwareStatementKeysByRootDomainAndAccountPublicID :one + +SELECT c.id, c.public_kid, c.public_key, c.crypto_suite, c.is_revoked, c.is_external, c.usage, c.account_id, c.expires_at, c.created_at, c.updated_at FROM "credentials_keys" AS "c" +LEFT JOIN "dynamic_registration_software_statement_keys" AS "d" ON "c"."id" = "d"."credential_key_id" +WHERE "d"."root_domain" = $1 AND "d"."account_public_id" = $2 +LIMIT 1 +` + +type FindDynamicRegistrationSoftwareStatementKeysByRootDomainAndAccountPublicIDParams struct { + RootDomain string + AccountPublicID uuid.UUID +} + +// Copyright (c) 2025 Afonso Barracha +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. +func (q *Queries) FindDynamicRegistrationSoftwareStatementKeysByRootDomainAndAccountPublicID(ctx context.Context, arg FindDynamicRegistrationSoftwareStatementKeysByRootDomainAndAccountPublicIDParams) (CredentialsKey, error) { + row := q.db.QueryRow(ctx, findDynamicRegistrationSoftwareStatementKeysByRootDomainAndAccountPublicID, arg.RootDomain, arg.AccountPublicID) + var i CredentialsKey + err := row.Scan( + &i.ID, + &i.PublicKid, + &i.PublicKey, + &i.CryptoSuite, + &i.IsRevoked, + &i.IsExternal, + &i.Usage, + &i.AccountID, + &i.ExpiresAt, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} diff --git a/idp/internal/providers/database/migrations/20241213231542_create_initial_schema.up.sql b/idp/internal/providers/database/migrations/20241213231542_create_initial_schema.up.sql index c5150f3..71d24f8 100644 --- a/idp/internal/providers/database/migrations/20241213231542_create_initial_schema.up.sql +++ b/idp/internal/providers/database/migrations/20241213231542_create_initial_schema.up.sql @@ -1,6 +1,6 @@ -- SQL dump generated using DBML (dbml.dbdiagram.io) -- Database: PostgreSQL --- Generated at: 2025-09-06T04:54:24.517Z +-- Generated at: 2025-11-02T00:22:09.380Z CREATE TYPE "kek_usage" AS ENUM ( 'global', @@ -19,6 +19,21 @@ CREATE TYPE "token_crypto_suite" AS ENUM ( 'EdDSA' ); +CREATE TYPE "token_encryption_algorithm" AS ENUM ( + 'RSA-OAEP-256', + 'ECDH-ES', + 'ECDH-ES+A256KW' +); + +CREATE TYPE "token_encryption_encoding" AS ENUM ( + 'A128CBC-HS256', + 'A192CBC-HS384', + 'A256CBC-HS512', + 'A128GCM', + 'A192GCM', + 'A256GCM' +); + CREATE TYPE "token_key_usage" AS ENUM ( 'global', 'account' @@ -72,7 +87,6 @@ CREATE TYPE "auth_method" AS ENUM ( CREATE TYPE "response_type" AS ENUM ( 'code', - 'id_token', 'code id_token' ); @@ -106,6 +120,11 @@ CREATE TYPE "transport" AS ENUM ( 'streamable_http' ); +CREATE TYPE "client_subject_type" AS ENUM ( + 'public', + 'pairwise' +); + CREATE TYPE "creation_method" AS ENUM ( 'manual', 'dynamic_registration' @@ -182,13 +201,16 @@ CREATE TYPE "initial_access_token_generation_method" AS ENUM ( CREATE TYPE "software_statement_verification_method" AS ENUM ( 'manual', - 'jwks_uri', - 'jwk_x5_parameters' + 'jwks_uri' +); + +CREATE TYPE "dynamic_registration_usage" AS ENUM ( + 'account', + 'app' ); CREATE TYPE "domain_verification_method" AS ENUM ( 'authorization_code', - 'software_statement', 'dns_txt_record' ); @@ -301,6 +323,7 @@ CREATE TABLE "credentials_keys" ( "public_key" jsonb NOT NULL, "crypto_suite" token_crypto_suite NOT NULL, "is_revoked" boolean NOT NULL DEFAULT false, + "is_external" boolean NOT NULL DEFAULT false, "usage" credentials_usage NOT NULL, "account_id" integer NOT NULL, "expires_at" timestamptz NOT NULL, @@ -344,24 +367,45 @@ CREATE TABLE "account_credentials" ( "id" serial PRIMARY KEY, "account_id" integer NOT NULL, "account_public_id" uuid NOT NULL, - "client_id" varchar(22) NOT NULL, - "name" varchar(255) NOT NULL, "domain" varchar(250) NOT NULL, - "credentials_type" account_credentials_type NOT NULL, - "scopes" account_credentials_scope[] NOT NULL, + "creation_method" creation_method NOT NULL, + "transport" transport NOT NULL, + "version" integer NOT NULL DEFAULT 1, + "client_id" varchar(22) NOT NULL, + "redirect_uris" varchar(2048)[] NOT NULL, "token_endpoint_auth_method" auth_method NOT NULL, "grant_types" grant_type[] NOT NULL, - "version" integer NOT NULL DEFAULT 1, - "transport" transport NOT NULL, - "creation_method" creation_method NOT NULL, + "response_types" response_type[] NOT NULL, + "client_name" varchar(255) NOT NULL, "client_uri" varchar(512) NOT NULL, - "redirect_uris" varchar(2048)[] NOT NULL, "logo_uri" varchar(512), - "policy_uri" varchar(512), + "scopes" account_credentials_scope[] NOT NULL, + "contacts" varchar(250)[] NOT NULL, "tos_uri" varchar(512), - "software_id" varchar(512) NOT NULL, + "policy_uri" varchar(512), + "jwks_uri" varchar(512), + "jwks" jsonb, + "software_id" varchar(512), "software_version" varchar(512), - "contacts" varchar(250)[] NOT NULL, + "credentials_type" account_credentials_type NOT NULL, + "sector_identifier_uri" varchar(512), + "subject_type" client_subject_type, + "id_token_signed_response_alg" token_crypto_suite NOT NULL, + "id_token_encrypted_response_alg" token_encryption_algorithm, + "id_token_encrypted_response_enc" token_encryption_encoding, + "userinfo_signed_response_alg" token_crypto_suite, + "userinfo_encrypted_response_alg" token_encryption_algorithm, + "userinfo_encrypted_response_enc" token_encryption_encoding, + "request_object_signing_alg" token_crypto_suite, + "request_object_encryption_alg" token_encryption_algorithm, + "request_object_encryption_enc" token_encryption_encoding, + "token_endpoint_auth_signing_alg" token_crypto_suite, + "default_max_age" bigint, + "require_auth_time" boolean NOT NULL DEFAULT false, + "default_acr_values" varchar(100)[], + "initiate_login_uri" varchar(512), + "request_uris" varchar(2048)[], + "access_token_signing_alg" token_crypto_suite NOT NULL, "created_at" timestamptz NOT NULL DEFAULT (now()), "updated_at" timestamptz NOT NULL DEFAULT (now()) ); @@ -499,22 +543,25 @@ CREATE TABLE "apps" ( "id" serial PRIMARY KEY, "account_id" integer NOT NULL, "account_public_id" uuid NOT NULL, - "app_type" app_type NOT NULL, - "name" varchar(255) NOT NULL, "client_id" varchar(22) NOT NULL, "version" integer NOT NULL DEFAULT 1, "creation_method" creation_method NOT NULL, + "redirect_uris" varchar(2048)[] NOT NULL, + "token_endpoint_auth_method" auth_method NOT NULL, + "grant_types" grant_type[] NOT NULL, + "response_types" response_type[] NOT NULL, + "client_name" varchar(255) NOT NULL, "client_uri" varchar(512) NOT NULL, "logo_uri" varchar(512), - "tos_uri" varchar(512), - "policy_uri" varchar(512), - "software_id" varchar(250) NOT NULL, - "software_version" varchar(250), - "contacts" varchar(250)[] NOT NULL, - "token_endpoint_auth_method" auth_method NOT NULL, "scopes" scopes[] NOT NULL, "custom_scopes" varchar(512)[] NOT NULL, - "grant_types" grant_type[] NOT NULL, + "contacts" varchar(250)[] NOT NULL, + "tos_uri" varchar(512), + "policy_uri" varchar(512), + "jwks_uri" varchar(512), + "jwks" jsonb, + "software_id" varchar(512), + "software_version" varchar(512), "domain" varchar(250) NOT NULL, "transport" transport NOT NULL, "allow_user_registration" bool NOT NULL, @@ -522,8 +569,25 @@ CREATE TABLE "apps" ( "username_column" app_username_column NOT NULL, "default_scopes" scopes[] NOT NULL, "default_custom_scopes" varchar(512)[] NOT NULL, - "redirect_uris" varchar(2048)[] NOT NULL, - "response_types" response_type[] NOT NULL, + "app_type" app_type NOT NULL, + "sector_identifier_uri" varchar(512), + "subject_type" client_subject_type, + "id_token_signed_response_alg" token_crypto_suite NOT NULL, + "id_token_encrypted_response_alg" token_encryption_algorithm, + "id_token_encrypted_response_enc" token_encryption_encoding, + "userinfo_signed_response_alg" token_crypto_suite, + "userinfo_encrypted_response_alg" token_encryption_algorithm, + "userinfo_encrypted_response_enc" token_encryption_encoding, + "request_object_signing_alg" token_crypto_suite, + "request_object_encryption_alg" token_encryption_algorithm, + "request_object_encryption_enc" token_encryption_encoding, + "token_endpoint_auth_signing_alg" token_crypto_suite, + "default_max_age" integer, + "require_auth_time" boolean NOT NULL DEFAULT false, + "default_acr_values" varchar(100)[], + "initiate_login_uri" varchar(512), + "request_uris" varchar(2048)[], + "access_token_signing_alg" token_crypto_suite NOT NULL, "id_token_ttl" integer NOT NULL DEFAULT 300, "token_ttl" integer NOT NULL DEFAULT 300, "refresh_token_ttl" integer NOT NULL DEFAULT 604800, @@ -584,22 +648,48 @@ CREATE TABLE "account_dynamic_registration_configs" ( "account_id" integer NOT NULL, "account_public_id" uuid NOT NULL, "account_credentials_types" account_credentials_type[] NOT NULL, - "whitelisted_domains" varchar(250)[] NOT NULL, "require_software_statement_credential_types" account_credentials_type[] NOT NULL, "software_statement_verification_methods" software_statement_verification_method[] NOT NULL, + "require_verified_domains_credentials_type" account_credentials_type[] NOT NULL, "require_initial_access_token_credential_types" account_credentials_type[] NOT NULL, "initial_access_token_generation_methods" initial_access_token_generation_method[] NOT NULL, "created_at" timestamptz NOT NULL DEFAULT (now()), "updated_at" timestamptz NOT NULL DEFAULT (now()) ); -CREATE TABLE "account_dynamic_registration_domains" ( +CREATE TABLE "app_dynamic_registration_configs" ( + "id" serial PRIMARY KEY, + "account_id" integer NOT NULL, + "allowed_app_types" app_type[] NOT NULL, + "whitelisted_domains" varchar(250)[] NOT NULL, + "default_allow_user_registration" boolean NOT NULL, + "default_auth_providers" auth_provider[] NOT NULL, + "default_username_column" app_username_column NOT NULL, + "default_allowed_scopes" scopes[] NOT NULL, + "default_scopes" scopes[] NOT NULL, + "require_verified_domains_app_types" app_type[] NOT NULL, + "require_software_statement_app_types" app_type[] NOT NULL, + "software_statement_verification_methods" software_statement_verification_method[] NOT NULL, + "require_initial_access_token_app_types" app_type[] NOT NULL, + "initial_access_token_generation_methods" initial_access_token_generation_method[] NOT NULL, + "initial_access_token_ttl" integer NOT NULL DEFAULT 3600, + "initial_access_token_max_uses" int NOT NULL DEFAULT 1, + "allowed_grant_types" grant_type[] NOT NULL DEFAULT '{ "authorization_code", "refresh_token", "client_credentials", "urn:ietf:params:oauth:grant-type:device_code", "urn:ietf:params:oauth:grant-type:jwt-bearer" }', + "allowed_response_types" response_type[] NOT NULL DEFAULT '{ "code", "id_token", "code id_token" }', + "allowed_token_endpoint_auth_methods" auth_method[] NOT NULL DEFAULT '{ "none", "client_secret_post", "client_secret_basic", "client_secret_jwt", "private_key_jwt" }', + "max_redirect_uris" int NOT NULL DEFAULT 10, + "created_at" timestamptz NOT NULL DEFAULT (now()), + "updated_at" timestamptz NOT NULL DEFAULT (now()) +); + +CREATE TABLE "dynamic_registration_domains" ( "id" serial PRIMARY KEY, "account_id" integer NOT NULL, "account_public_id" uuid NOT NULL, "domain" varchar(250) NOT NULL, "verified_at" timestamptz, "verification_method" domain_verification_method NOT NULL, + "usages" dynamic_registration_usage[] NOT NULL, "created_at" timestamptz NOT NULL DEFAULT (now()), "updated_at" timestamptz NOT NULL DEFAULT (now()) ); @@ -607,6 +697,7 @@ CREATE TABLE "account_dynamic_registration_domains" ( CREATE TABLE "dynamic_registration_domain_codes" ( "id" serial PRIMARY KEY, "account_id" integer NOT NULL, + "dynamic_registration_domain_id" integer NOT NULL, "verification_host" varchar(50) NOT NULL, "verification_code" text NOT NULL, "hmac_secret_id" varchar(22) NOT NULL, @@ -616,47 +707,16 @@ CREATE TABLE "dynamic_registration_domain_codes" ( "updated_at" timestamptz NOT NULL DEFAULT (now()) ); -CREATE TABLE "account_dynamic_registration_domain_codes" ( - "account_dynamic_registration_domain_id" integer NOT NULL, - "dynamic_registration_domain_code_id" integer NOT NULL, - "account_id" integer NOT NULL, - "created_at" timestamptz NOT NULL DEFAULT (now()), - PRIMARY KEY ("account_dynamic_registration_domain_id", "dynamic_registration_domain_code_id") -); - -CREATE TABLE "account_dynamic_registration_software_statement_keys" ( +CREATE TABLE "dynamic_registration_software_statement_keys" ( "id" serial PRIMARY KEY, "account_id" integer NOT NULL, "account_public_id" uuid NOT NULL, "credentials_key_id" integer NOT NULL, - "account_dynamic_registration_domain_id" integer NOT NULL, + "credentials_key_kid" varchar(22) NOT NULL, + "root_domain" varchar(250) NOT NULL, "created_at" timestamptz NOT NULL DEFAULT (now()) ); -CREATE TABLE "app_dynamic_registration_configs" ( - "id" serial PRIMARY KEY, - "account_id" integer NOT NULL, - "allowed_app_types" app_type[] NOT NULL, - "whitelisted_domains" varchar(250)[] NOT NULL, - "default_allow_user_registration" boolean NOT NULL, - "default_auth_providers" auth_provider[] NOT NULL, - "default_username_column" app_username_column NOT NULL, - "default_allowed_scopes" scopes[] NOT NULL, - "default_scopes" scopes[] NOT NULL, - "require_software_statement_app_types" app_type[] NOT NULL, - "software_statement_verification_methods" software_statement_verification_method[] NOT NULL, - "require_initial_access_token_app_types" app_type[] NOT NULL, - "initial_access_token_generation_methods" initial_access_token_generation_method[] NOT NULL, - "initial_access_token_ttl" integer NOT NULL DEFAULT 3600, - "initial_access_token_max_uses" int NOT NULL DEFAULT 1, - "allowed_grant_types" grant_type[] NOT NULL DEFAULT '{ "authorization_code", "refresh_token", "client_credentials", "urn:ietf:params:oauth:grant-type:device_code", "urn:ietf:params:oauth:grant-type:jwt-bearer" }', - "allowed_response_types" response_type[] NOT NULL DEFAULT '{ "code", "id_token", "code id_token" }', - "allowed_token_endpoint_auth_methods" auth_method[] NOT NULL DEFAULT '{ "none", "client_secret_post", "client_secret_basic", "client_secret_jwt", "private_key_jwt" }', - "max_redirect_uris" int NOT NULL DEFAULT 10, - "created_at" timestamptz NOT NULL DEFAULT (now()), - "updated_at" timestamptz NOT NULL DEFAULT (now()) -); - CREATE TABLE "app_profiles" ( "app_id" integer NOT NULL, "user_id" integer NOT NULL, @@ -783,7 +843,7 @@ CREATE INDEX "account_credentials_account_public_id_idx" ON "account_credentials CREATE INDEX "account_credentials_account_public_id_client_id_idx" ON "account_credentials" ("account_public_id", "client_id"); -CREATE UNIQUE INDEX "account_credentials_name_account_id_uidx" ON "account_credentials" ("name", "account_id"); +CREATE UNIQUE INDEX "account_credentials_client_name_account_id_uidx" ON "account_credentials" ("client_name", "account_id"); CREATE INDEX "account_credential_secrets_account_id_idx" ON "account_credentials_secrets" ("account_id"); @@ -905,9 +965,9 @@ CREATE INDEX "apps_client_id_account_public_id_idx" ON "apps" ("client_id", "acc CREATE INDEX "apps_account_public_id_idx" ON "apps" ("account_public_id"); -CREATE INDEX "apps_name_idx" ON "apps" ("name"); +CREATE INDEX "apps_client_name_idx" ON "apps" ("client_name"); -CREATE UNIQUE INDEX "apps_account_id_name_uidx" ON "apps" ("account_id", "name"); +CREATE UNIQUE INDEX "apps_account_id_client_name_uidx" ON "apps" ("account_id", "client_name"); CREATE INDEX "apps_account_id_app_type_idx" ON "apps" ("account_id", "app_type"); @@ -947,31 +1007,29 @@ CREATE UNIQUE INDEX "account_dynamic_registration_configs_account_id_uidx" ON "a CREATE INDEX "account_dynamic_registration_configs_account_public_id_idx" ON "account_dynamic_registration_configs" ("account_public_id"); -CREATE INDEX "accounts_totps_account_id_idx" ON "account_dynamic_registration_domains" ("account_id"); - -CREATE INDEX "account_dynamic_registration_domains_account_public_id_idx" ON "account_dynamic_registration_domains" ("account_public_id"); +CREATE INDEX "app_dynamic_registration_configs_account_id_idx" ON "app_dynamic_registration_configs" ("account_id"); -CREATE INDEX "account_dynamic_registration_domains_domain_idx" ON "account_dynamic_registration_domains" ("domain"); +CREATE INDEX "accounts_totps_account_id_idx" ON "dynamic_registration_domains" ("account_id"); -CREATE UNIQUE INDEX "account_dynamic_registration_domains_account_public_id_domain_uidx" ON "account_dynamic_registration_domains" ("account_public_id", "domain"); +CREATE INDEX "account_dynamic_registration_domains_account_public_id_idx" ON "dynamic_registration_domains" ("account_public_id"); -CREATE INDEX "account_dynamic_registration_domain_codes_account_id_idx" ON "dynamic_registration_domain_codes" ("account_id"); +CREATE INDEX "account_dynamic_registration_domains_domain_idx" ON "dynamic_registration_domains" ("domain"); -CREATE INDEX "account_dynamic_registration_domain_codes_account_id_idx" ON "account_dynamic_registration_domain_codes" ("account_id"); +CREATE UNIQUE INDEX "account_dynamic_registration_domains_account_public_id_domain_uidx" ON "dynamic_registration_domains" ("account_public_id", "domain"); -CREATE UNIQUE INDEX "account_dynamic_registration_domain_codes_account_dynamic_registration_domain_id_uidx" ON "account_dynamic_registration_domain_codes" ("account_dynamic_registration_domain_id"); +CREATE INDEX "dynamic_registration_domain_codes_account_id_idx" ON "dynamic_registration_domain_codes" ("account_id"); -CREATE UNIQUE INDEX "account_dynamic_registration_domain_codes_dynamic_registration_domain_code_id_uidx" ON "account_dynamic_registration_domain_codes" ("dynamic_registration_domain_code_id"); +CREATE INDEX "dynamic_registration_domain_codes_dynamic_registration_domain_id_idx" ON "dynamic_registration_domain_codes" ("dynamic_registration_domain_id"); -CREATE INDEX "account_dynamic_registration_software_statement_keys_account_id_idx" ON "account_dynamic_registration_software_statement_keys" ("account_id"); +CREATE INDEX "drs_statement_keys_account_id_idx" ON "dynamic_registration_software_statement_keys" ("account_id"); -CREATE INDEX "account_dynamic_registration_software_statement_keys_account_public_id_idx" ON "account_dynamic_registration_software_statement_keys" ("account_public_id"); +CREATE INDEX "drs_statement_keys_account_public_id_idx" ON "dynamic_registration_software_statement_keys" ("account_public_id"); -CREATE UNIQUE INDEX "account_dynamic_registration_software_statement_keys_credentials_key_id_uidx" ON "account_dynamic_registration_software_statement_keys" ("credentials_key_id"); +CREATE UNIQUE INDEX "drs_statement_keys_credentials_key_id_uidx" ON "dynamic_registration_software_statement_keys" ("credentials_key_id"); -CREATE UNIQUE INDEX "account_dynamic_registration_software_statement_keys_account_dynamic_registration_domain_id_uidx" ON "account_dynamic_registration_software_statement_keys" ("account_dynamic_registration_domain_id"); +CREATE INDEX "drs_statement_keys_root_domain_account_public_id_idx" ON "dynamic_registration_software_statement_keys" ("root_domain", "account_public_id"); -CREATE INDEX "app_dynamic_registration_configs_account_id_idx" ON "app_dynamic_registration_configs" ("account_id"); +CREATE INDEX "drs_statement_keys_credentials_key_kid_account_public_id_idx" ON "dynamic_registration_software_statement_keys" ("credentials_key_kid", "account_public_id"); CREATE INDEX "user_profiles_app_id_idx" ON "app_profiles" ("app_id"); @@ -1113,25 +1171,19 @@ ALTER TABLE "app_designs" ADD FOREIGN KEY ("app_id") REFERENCES "apps" ("id") ON ALTER TABLE "account_dynamic_registration_configs" ADD FOREIGN KEY ("account_id") REFERENCES "accounts" ("id") ON DELETE CASCADE; -ALTER TABLE "account_dynamic_registration_domains" ADD FOREIGN KEY ("account_id") REFERENCES "accounts" ("id") ON DELETE CASCADE; - -ALTER TABLE "dynamic_registration_domain_codes" ADD FOREIGN KEY ("account_id") REFERENCES "accounts" ("id") ON DELETE CASCADE; - -ALTER TABLE "dynamic_registration_domain_codes" ADD FOREIGN KEY ("hmac_secret_id") REFERENCES "account_hmac_secrets" ("secret_id") ON DELETE CASCADE ON UPDATE CASCADE; - -ALTER TABLE "account_dynamic_registration_domain_codes" ADD FOREIGN KEY ("account_id") REFERENCES "accounts" ("id") ON DELETE CASCADE; +ALTER TABLE "app_dynamic_registration_configs" ADD FOREIGN KEY ("account_id") REFERENCES "accounts" ("id") ON DELETE CASCADE; -ALTER TABLE "account_dynamic_registration_domain_codes" ADD FOREIGN KEY ("account_dynamic_registration_domain_id") REFERENCES "account_dynamic_registration_domains" ("id") ON DELETE CASCADE; +ALTER TABLE "dynamic_registration_domains" ADD FOREIGN KEY ("account_id") REFERENCES "accounts" ("id") ON DELETE CASCADE; -ALTER TABLE "account_dynamic_registration_domain_codes" ADD FOREIGN KEY ("dynamic_registration_domain_code_id") REFERENCES "dynamic_registration_domain_codes" ("id") ON DELETE CASCADE; +ALTER TABLE "dynamic_registration_domain_codes" ADD FOREIGN KEY ("account_id") REFERENCES "accounts" ("id") ON DELETE CASCADE; -ALTER TABLE "account_dynamic_registration_software_statement_keys" ADD FOREIGN KEY ("account_id") REFERENCES "accounts" ("id") ON DELETE CASCADE; +ALTER TABLE "dynamic_registration_domain_codes" ADD FOREIGN KEY ("dynamic_registration_domain_id") REFERENCES "dynamic_registration_domains" ("id") ON DELETE CASCADE; -ALTER TABLE "account_dynamic_registration_software_statement_keys" ADD FOREIGN KEY ("credentials_key_id") REFERENCES "credentials_keys" ("id") ON DELETE CASCADE; +ALTER TABLE "dynamic_registration_domain_codes" ADD FOREIGN KEY ("hmac_secret_id") REFERENCES "account_hmac_secrets" ("secret_id") ON DELETE CASCADE ON UPDATE CASCADE; -ALTER TABLE "account_dynamic_registration_software_statement_keys" ADD FOREIGN KEY ("account_dynamic_registration_domain_id") REFERENCES "account_dynamic_registration_domains" ("id") ON DELETE CASCADE; +ALTER TABLE "dynamic_registration_software_statement_keys" ADD FOREIGN KEY ("account_id") REFERENCES "accounts" ("id") ON DELETE CASCADE; -ALTER TABLE "app_dynamic_registration_configs" ADD FOREIGN KEY ("account_id") REFERENCES "accounts" ("id") ON DELETE CASCADE; +ALTER TABLE "dynamic_registration_software_statement_keys" ADD FOREIGN KEY ("credentials_key_id") REFERENCES "credentials_keys" ("id") ON DELETE CASCADE; ALTER TABLE "app_profiles" ADD FOREIGN KEY ("app_id") REFERENCES "apps" ("id") ON DELETE CASCADE; diff --git a/idp/internal/providers/database/models.go b/idp/internal/providers/database/models.go index 251a769..c23b977 100644 --- a/idp/internal/providers/database/models.go +++ b/idp/internal/providers/database/models.go @@ -437,6 +437,48 @@ func (ns NullClaims) Value() (driver.Value, error) { return string(ns.Claims), nil } +type ClientSubjectType string + +const ( + ClientSubjectTypePublic ClientSubjectType = "public" + ClientSubjectTypePairwise ClientSubjectType = "pairwise" +) + +func (e *ClientSubjectType) Scan(src interface{}) error { + switch s := src.(type) { + case []byte: + *e = ClientSubjectType(s) + case string: + *e = ClientSubjectType(s) + default: + return fmt.Errorf("unsupported scan type for ClientSubjectType: %T", src) + } + return nil +} + +type NullClientSubjectType struct { + ClientSubjectType ClientSubjectType + Valid bool // Valid is true if ClientSubjectType is not NULL +} + +// Scan implements the Scanner interface. +func (ns *NullClientSubjectType) Scan(value interface{}) error { + if value == nil { + ns.ClientSubjectType, ns.Valid = "", false + return nil + } + ns.Valid = true + return ns.ClientSubjectType.Scan(value) +} + +// Value implements the driver Valuer interface. +func (ns NullClientSubjectType) Value() (driver.Value, error) { + if !ns.Valid { + return nil, nil + } + return string(ns.ClientSubjectType), nil +} + type CreationMethod string const ( @@ -569,7 +611,6 @@ type DomainVerificationMethod string const ( DomainVerificationMethodAuthorizationCode DomainVerificationMethod = "authorization_code" - DomainVerificationMethodSoftwareStatement DomainVerificationMethod = "software_statement" DomainVerificationMethodDnsTxtRecord DomainVerificationMethod = "dns_txt_record" ) @@ -608,6 +649,48 @@ func (ns NullDomainVerificationMethod) Value() (driver.Value, error) { return string(ns.DomainVerificationMethod), nil } +type DynamicRegistrationUsage string + +const ( + DynamicRegistrationUsageAccount DynamicRegistrationUsage = "account" + DynamicRegistrationUsageApp DynamicRegistrationUsage = "app" +) + +func (e *DynamicRegistrationUsage) Scan(src interface{}) error { + switch s := src.(type) { + case []byte: + *e = DynamicRegistrationUsage(s) + case string: + *e = DynamicRegistrationUsage(s) + default: + return fmt.Errorf("unsupported scan type for DynamicRegistrationUsage: %T", src) + } + return nil +} + +type NullDynamicRegistrationUsage struct { + DynamicRegistrationUsage DynamicRegistrationUsage + Valid bool // Valid is true if DynamicRegistrationUsage is not NULL +} + +// Scan implements the Scanner interface. +func (ns *NullDynamicRegistrationUsage) Scan(value interface{}) error { + if value == nil { + ns.DynamicRegistrationUsage, ns.Valid = "", false + return nil + } + ns.Valid = true + return ns.DynamicRegistrationUsage.Scan(value) +} + +// Value implements the driver Valuer interface. +func (ns NullDynamicRegistrationUsage) Value() (driver.Value, error) { + if !ns.Valid { + return nil, nil + } + return string(ns.DynamicRegistrationUsage), nil +} + type GrantType string const ( @@ -741,7 +824,6 @@ type ResponseType string const ( ResponseTypeCode ResponseType = "code" - ResponseTypeIDToken ResponseType = "id_token" ResponseTypeCodeidToken ResponseType = "code id_token" ) @@ -870,9 +952,8 @@ func (ns NullSecretStorageMode) Value() (driver.Value, error) { type SoftwareStatementVerificationMethod string const ( - SoftwareStatementVerificationMethodManual SoftwareStatementVerificationMethod = "manual" - SoftwareStatementVerificationMethodJwksUri SoftwareStatementVerificationMethod = "jwks_uri" - SoftwareStatementVerificationMethodJwkX5Parameters SoftwareStatementVerificationMethod = "jwk_x5_parameters" + SoftwareStatementVerificationMethodManual SoftwareStatementVerificationMethod = "manual" + SoftwareStatementVerificationMethodJwksUri SoftwareStatementVerificationMethod = "jwks_uri" ) func (e *SoftwareStatementVerificationMethod) Scan(src interface{}) error { @@ -953,6 +1034,95 @@ func (ns NullTokenCryptoSuite) Value() (driver.Value, error) { return string(ns.TokenCryptoSuite), nil } +type TokenEncryptionAlgorithm string + +const ( + TokenEncryptionAlgorithmRSAOAEP256 TokenEncryptionAlgorithm = "RSA-OAEP-256" + TokenEncryptionAlgorithmECDHES TokenEncryptionAlgorithm = "ECDH-ES" + TokenEncryptionAlgorithmECDHESA256KW TokenEncryptionAlgorithm = "ECDH-ES+A256KW" +) + +func (e *TokenEncryptionAlgorithm) Scan(src interface{}) error { + switch s := src.(type) { + case []byte: + *e = TokenEncryptionAlgorithm(s) + case string: + *e = TokenEncryptionAlgorithm(s) + default: + return fmt.Errorf("unsupported scan type for TokenEncryptionAlgorithm: %T", src) + } + return nil +} + +type NullTokenEncryptionAlgorithm struct { + TokenEncryptionAlgorithm TokenEncryptionAlgorithm + Valid bool // Valid is true if TokenEncryptionAlgorithm is not NULL +} + +// Scan implements the Scanner interface. +func (ns *NullTokenEncryptionAlgorithm) Scan(value interface{}) error { + if value == nil { + ns.TokenEncryptionAlgorithm, ns.Valid = "", false + return nil + } + ns.Valid = true + return ns.TokenEncryptionAlgorithm.Scan(value) +} + +// Value implements the driver Valuer interface. +func (ns NullTokenEncryptionAlgorithm) Value() (driver.Value, error) { + if !ns.Valid { + return nil, nil + } + return string(ns.TokenEncryptionAlgorithm), nil +} + +type TokenEncryptionEncoding string + +const ( + TokenEncryptionEncodingA128CBCHS256 TokenEncryptionEncoding = "A128CBC-HS256" + TokenEncryptionEncodingA192CBCHS384 TokenEncryptionEncoding = "A192CBC-HS384" + TokenEncryptionEncodingA256CBCHS512 TokenEncryptionEncoding = "A256CBC-HS512" + TokenEncryptionEncodingA128GCM TokenEncryptionEncoding = "A128GCM" + TokenEncryptionEncodingA192GCM TokenEncryptionEncoding = "A192GCM" + TokenEncryptionEncodingA256GCM TokenEncryptionEncoding = "A256GCM" +) + +func (e *TokenEncryptionEncoding) Scan(src interface{}) error { + switch s := src.(type) { + case []byte: + *e = TokenEncryptionEncoding(s) + case string: + *e = TokenEncryptionEncoding(s) + default: + return fmt.Errorf("unsupported scan type for TokenEncryptionEncoding: %T", src) + } + return nil +} + +type NullTokenEncryptionEncoding struct { + TokenEncryptionEncoding TokenEncryptionEncoding + Valid bool // Valid is true if TokenEncryptionEncoding is not NULL +} + +// Scan implements the Scanner interface. +func (ns *NullTokenEncryptionEncoding) Scan(value interface{}) error { + if value == nil { + ns.TokenEncryptionEncoding, ns.Valid = "", false + return nil + } + ns.Valid = true + return ns.TokenEncryptionEncoding.Scan(value) +} + +// Value implements the driver Valuer interface. +func (ns NullTokenEncryptionEncoding) Value() (driver.Value, error) { + if !ns.Valid { + return nil, nil + } + return string(ns.TokenEncryptionEncoding), nil +} + type TokenKeyType string const ( @@ -1250,29 +1420,50 @@ type AccountAuthProvider struct { } type AccountCredential struct { - ID int32 - AccountID int32 - AccountPublicID uuid.UUID - ClientID string - Name string - Domain string - CredentialsType AccountCredentialsType - Scopes []AccountCredentialsScope - TokenEndpointAuthMethod AuthMethod - GrantTypes []GrantType - Version int32 - Transport Transport - CreationMethod CreationMethod - ClientUri string - RedirectUris []string - LogoUri pgtype.Text - PolicyUri pgtype.Text - TosUri pgtype.Text - SoftwareID string - SoftwareVersion pgtype.Text - Contacts []string - CreatedAt time.Time - UpdatedAt time.Time + ID int32 + AccountID int32 + AccountPublicID uuid.UUID + Domain string + CreationMethod CreationMethod + Transport Transport + Version int32 + ClientID string + RedirectUris []string + TokenEndpointAuthMethod AuthMethod + GrantTypes []GrantType + ResponseTypes []ResponseType + ClientName string + ClientUri string + LogoUri pgtype.Text + Scopes []AccountCredentialsScope + Contacts []string + TosUri pgtype.Text + PolicyUri pgtype.Text + JwksUri pgtype.Text + Jwks []byte + SoftwareID pgtype.Text + SoftwareVersion pgtype.Text + CredentialsType AccountCredentialsType + SectorIdentifierUri pgtype.Text + SubjectType NullClientSubjectType + IDTokenSignedResponseAlg TokenCryptoSuite + IDTokenEncryptedResponseAlg NullTokenEncryptionAlgorithm + IDTokenEncryptedResponseEnc NullTokenEncryptionEncoding + UserinfoSignedResponseAlg NullTokenCryptoSuite + UserinfoEncryptedResponseAlg NullTokenEncryptionAlgorithm + UserinfoEncryptedResponseEnc NullTokenEncryptionEncoding + RequestObjectSigningAlg NullTokenCryptoSuite + RequestObjectEncryptionAlg NullTokenEncryptionAlgorithm + RequestObjectEncryptionEnc NullTokenEncryptionEncoding + TokenEndpointAuthSigningAlg NullTokenCryptoSuite + DefaultMaxAge pgtype.Int8 + RequireAuthTime bool + DefaultAcrValues []string + InitiateLoginUri pgtype.Text + RequestUris []string + AccessTokenSigningAlg TokenCryptoSuite + CreatedAt time.Time + UpdatedAt time.Time } type AccountCredentialsKey struct { @@ -1304,42 +1495,15 @@ type AccountDynamicRegistrationConfig struct { AccountID int32 AccountPublicID uuid.UUID AccountCredentialsTypes []AccountCredentialsType - WhitelistedDomains []string RequireSoftwareStatementCredentialTypes []AccountCredentialsType SoftwareStatementVerificationMethods []SoftwareStatementVerificationMethod + RequireVerifiedDomainsCredentialsType []AccountCredentialsType RequireInitialAccessTokenCredentialTypes []AccountCredentialsType InitialAccessTokenGenerationMethods []InitialAccessTokenGenerationMethod CreatedAt time.Time UpdatedAt time.Time } -type AccountDynamicRegistrationDomain struct { - ID int32 - AccountID int32 - AccountPublicID uuid.UUID - Domain string - VerifiedAt pgtype.Timestamptz - VerificationMethod DomainVerificationMethod - CreatedAt time.Time - UpdatedAt time.Time -} - -type AccountDynamicRegistrationDomainCode struct { - AccountDynamicRegistrationDomainID int32 - DynamicRegistrationDomainCodeID int32 - AccountID int32 - CreatedAt time.Time -} - -type AccountDynamicRegistrationSoftwareStatementKey struct { - ID int32 - AccountID int32 - AccountPublicID uuid.UUID - CredentialsKeyID int32 - AccountDynamicRegistrationDomainID int32 - CreatedAt time.Time -} - type AccountHmacSecret struct { ID int32 AccountID int32 @@ -1370,39 +1534,59 @@ type AccountTotp struct { } type App struct { - ID int32 - AccountID int32 - AccountPublicID uuid.UUID - AppType AppType - Name string - ClientID string - Version int32 - CreationMethod CreationMethod - ClientUri string - LogoUri pgtype.Text - TosUri pgtype.Text - PolicyUri pgtype.Text - SoftwareID string - SoftwareVersion pgtype.Text - Contacts []string - TokenEndpointAuthMethod AuthMethod - Scopes []Scopes - CustomScopes []string - GrantTypes []GrantType - Domain string - Transport Transport - AllowUserRegistration bool - AuthProviders []AuthProvider - UsernameColumn AppUsernameColumn - DefaultScopes []Scopes - DefaultCustomScopes []string - RedirectUris []string - ResponseTypes []ResponseType - IDTokenTtl int32 - TokenTtl int32 - RefreshTokenTtl int32 - CreatedAt time.Time - UpdatedAt time.Time + ID int32 + AccountID int32 + AccountPublicID uuid.UUID + ClientID string + Version int32 + CreationMethod CreationMethod + RedirectUris []string + TokenEndpointAuthMethod AuthMethod + GrantTypes []GrantType + ResponseTypes []ResponseType + ClientName string + ClientUri string + LogoUri pgtype.Text + Scopes []Scopes + CustomScopes []string + Contacts []string + TosUri pgtype.Text + PolicyUri pgtype.Text + JwksUri pgtype.Text + Jwks []byte + SoftwareID pgtype.Text + SoftwareVersion pgtype.Text + Domain string + Transport Transport + AllowUserRegistration bool + AuthProviders []AuthProvider + UsernameColumn AppUsernameColumn + DefaultScopes []Scopes + DefaultCustomScopes []string + AppType AppType + SectorIdentifierUri pgtype.Text + SubjectType NullClientSubjectType + IDTokenSignedResponseAlg TokenCryptoSuite + IDTokenEncryptedResponseAlg NullTokenEncryptionAlgorithm + IDTokenEncryptedResponseEnc NullTokenEncryptionEncoding + UserinfoSignedResponseAlg NullTokenCryptoSuite + UserinfoEncryptedResponseAlg NullTokenEncryptionAlgorithm + UserinfoEncryptedResponseEnc NullTokenEncryptionEncoding + RequestObjectSigningAlg NullTokenCryptoSuite + RequestObjectEncryptionAlg NullTokenEncryptionAlgorithm + RequestObjectEncryptionEnc NullTokenEncryptionEncoding + TokenEndpointAuthSigningAlg NullTokenCryptoSuite + DefaultMaxAge pgtype.Int4 + RequireAuthTime bool + DefaultAcrValues []string + InitiateLoginUri pgtype.Text + RequestUris []string + AccessTokenSigningAlg TokenCryptoSuite + IDTokenTtl int32 + TokenTtl int32 + RefreshTokenTtl int32 + CreatedAt time.Time + UpdatedAt time.Time } type AppDesign struct { @@ -1427,6 +1611,7 @@ type AppDynamicRegistrationConfig struct { DefaultUsernameColumn AppUsernameColumn DefaultAllowedScopes []Scopes DefaultScopes []Scopes + RequireVerifiedDomainsAppTypes []AppType RequireSoftwareStatementAppTypes []AppType SoftwareStatementVerificationMethods []SoftwareStatementVerificationMethod RequireInitialAccessTokenAppTypes []AppType @@ -1488,6 +1673,7 @@ type CredentialsKey struct { PublicKey []byte CryptoSuite TokenCryptoSuite IsRevoked bool + IsExternal bool Usage CredentialsUsage AccountID int32 ExpiresAt time.Time @@ -1521,18 +1707,41 @@ type DataEncryptionKey struct { UpdatedAt time.Time } -type DynamicRegistrationDomainCode struct { +type DynamicRegistrationDomain struct { ID int32 AccountID int32 - VerificationHost string - VerificationCode string - HmacSecretID string - VerificationPrefix string - ExpiresAt time.Time + AccountPublicID uuid.UUID + Domain string + VerifiedAt pgtype.Timestamptz + VerificationMethod DomainVerificationMethod + Usages []DynamicRegistrationUsage CreatedAt time.Time UpdatedAt time.Time } +type DynamicRegistrationDomainCode struct { + ID int32 + AccountID int32 + DynamicRegistrationDomainID int32 + VerificationHost string + VerificationCode string + HmacSecretID string + VerificationPrefix string + ExpiresAt time.Time + CreatedAt time.Time + UpdatedAt time.Time +} + +type DynamicRegistrationSoftwareStatementKey struct { + ID int32 + AccountID int32 + AccountPublicID uuid.UUID + CredentialsKeyID int32 + CredentialsKeyKid string + RootDomain string + CreatedAt time.Time +} + type KeyEncryptionKey struct { ID int32 Kid uuid.UUID diff --git a/idp/internal/providers/database/queries/account_credentials.sql b/idp/internal/providers/database/queries/account_credentials.sql index c47e40c..bb3451e 100644 --- a/idp/internal/providers/database/queries/account_credentials.sql +++ b/idp/internal/providers/database/queries/account_credentials.sql @@ -21,24 +21,46 @@ LIMIT 1; -- name: CreateAccountCredentials :one INSERT INTO "account_credentials" ( - "client_id", "account_id", "account_public_id", - "credentials_type", - "name", - "scopes", - "token_endpoint_auth_method", "domain", - "client_uri", + "creation_method", + "transport", + "client_id", "redirect_uris", + "token_endpoint_auth_method", + "grant_types", + "response_types", + "client_name", + "client_uri", "logo_uri", - "policy_uri", + "scopes", + "contacts", "tos_uri", + "policy_uri", + "jwks_uri", + "jwks", "software_id", "software_version", - "contacts", - "creation_method", - "transport" + "credentials_type", + "sector_identifier_uri", + "subject_type", + "id_token_signed_response_alg", + "id_token_encrypted_response_alg", + "id_token_encrypted_response_enc", + "userinfo_signed_response_alg", + "userinfo_encrypted_response_alg", + "userinfo_encrypted_response_enc", + "request_object_signing_alg", + "request_object_encryption_alg", + "request_object_encryption_enc", + "token_endpoint_auth_signing_alg", + "default_max_age", + "require_auth_time", + "default_acr_values", + "initiate_login_uri", + "request_uris", + "access_token_signing_alg" ) VALUES ( $1, $2, @@ -57,13 +79,35 @@ INSERT INTO "account_credentials" ( $15, $16, $17, - $18 + $18, + $19, + $20, + $21, + $22, + $23, + $24, + $25, + $26, + $27, + $28, + $29, + $30, + $31, + $32, + $33, + $34, + $35, + $36, + $37, + $38, + $39, + $40 ) RETURNING *; -- name: UpdateAccountCredentials :one UPDATE "account_credentials" SET "scopes" = $2, - "name" = $3, + "client_name" = $3, "domain" = $4, "client_uri" = $5, "redirect_uris" = $6, @@ -80,7 +124,7 @@ RETURNING *; -- name: CountAccountCredentialsByNameAndAccountID :one SELECT COUNT(*) FROM "account_credentials" -WHERE "account_id" = $1 AND "name" = $2; +WHERE "account_id" = $1 AND "client_name" = $2; -- name: DeleteAccountCredentials :exec DELETE FROM "account_credentials" diff --git a/idp/internal/providers/database/queries/account_dynamic_registration_configs.sql b/idp/internal/providers/database/queries/account_dynamic_registration_configs.sql index 697d1a3..eeece27 100644 --- a/idp/internal/providers/database/queries/account_dynamic_registration_configs.sql +++ b/idp/internal/providers/database/queries/account_dynamic_registration_configs.sql @@ -9,7 +9,6 @@ INSERT INTO "account_dynamic_registration_configs" ( "account_id", "account_public_id", "account_credentials_types", - "whitelisted_domains", "require_software_statement_credential_types", "software_statement_verification_methods", "require_initial_access_token_credential_types", @@ -21,18 +20,16 @@ INSERT INTO "account_dynamic_registration_configs" ( $4, $5, $6, - $7, - $8 + $7 ) RETURNING *; -- name: UpdateAccountDynamicRegistrationConfig :one UPDATE "account_dynamic_registration_configs" SET "account_credentials_types" = $2, - "whitelisted_domains" = $3, - "require_software_statement_credential_types" = $4, - "software_statement_verification_methods" = $5, - "require_initial_access_token_credential_types" = $6, - "initial_access_token_generation_methods" = $7 + "require_software_statement_credential_types" = $3, + "software_statement_verification_methods" = $4, + "require_initial_access_token_credential_types" = $5, + "initial_access_token_generation_methods" = $6 WHERE "id" = $1 RETURNING *; diff --git a/idp/internal/providers/database/queries/account_dynamic_registration_domain_codes.sql b/idp/internal/providers/database/queries/account_dynamic_registration_domain_codes.sql deleted file mode 100644 index a92cd0f..0000000 --- a/idp/internal/providers/database/queries/account_dynamic_registration_domain_codes.sql +++ /dev/null @@ -1,22 +0,0 @@ --- Copyright (c) 2025 Afonso Barracha --- --- This Source Code Form is subject to the terms of the Mozilla Public --- License, v. 2.0. If a copy of the MPL was not distributed with this --- file, You can obtain one at https://mozilla.org/MPL/2.0/. - --- name: CreateAccountDynamicRegistrationDomainCode :exec -INSERT INTO "account_dynamic_registration_domain_codes" ( - "account_dynamic_registration_domain_id", - "dynamic_registration_domain_code_id", - "account_id" -) VALUES ( - $1, - $2, - $3 -); - --- name: FindDynamicRegistrationDomainCodeByAccountDynamicRegistrationDomainID :one -SELECT "d".* FROM "dynamic_registration_domain_codes" "d" -LEFT JOIN "account_dynamic_registration_domain_codes" "a" ON "d"."id" = "a"."dynamic_registration_domain_code_id" -WHERE "a"."account_dynamic_registration_domain_id" = $1 -LIMIT 1; diff --git a/idp/internal/providers/database/queries/account_dynamic_registration_domains.sql b/idp/internal/providers/database/queries/account_dynamic_registration_domains.sql deleted file mode 100644 index 60a0436..0000000 --- a/idp/internal/providers/database/queries/account_dynamic_registration_domains.sql +++ /dev/null @@ -1,97 +0,0 @@ --- Copyright (c) 2025 Afonso Barracha --- --- This Source Code Form is subject to the terms of the Mozilla Public --- License, v. 2.0. If a copy of the MPL was not distributed with this --- file, You can obtain one at https://mozilla.org/MPL/2.0/. - --- name: CreateAccountDynamicRegistrationDomain :one -INSERT INTO "account_dynamic_registration_domains" ( - "account_id", - "account_public_id", - "domain", - "verification_method" -) VALUES ( - $1, - $2, - $3, - $4 -) RETURNING *; - --- name: FindAccountDynamicRegistrationDomainByAccountPublicIDAndDomain :one -SELECT * FROM "account_dynamic_registration_domains" WHERE "account_public_id" = $1 AND "domain" = $2 LIMIT 1; - --- name: VerifyAccountDynamicRegistrationDomain :one -UPDATE "account_dynamic_registration_domains" -SET - "verified_at" = NOW(), - "verification_method" = $2 -WHERE "id" = $1 RETURNING *; - --- name: FindPaginatedAccountDynamicRegistrationDomainsByAccountPublicIDOrderedByID :many -SELECT * FROM "account_dynamic_registration_domains" -WHERE "account_public_id" = $1 -ORDER BY "id" DESC -LIMIT $2 OFFSET $3; - --- name: FindPaginatedAccountDynamicRegistrationDomainsByAccountPublicIDOrderedByDomain :many -SELECT * FROM "account_dynamic_registration_domains" -WHERE "account_public_id" = $1 -ORDER BY "domain" ASC -LIMIT $2 OFFSET $3; - --- name: CountAccountDynamicRegistrationDomainsByAccountPublicID :one -SELECT COUNT(*) FROM "account_dynamic_registration_domains" -WHERE "account_public_id" = $1; - --- name: FilterAccountDynamicRegistrationDomainsByAccountPublicIDOrderedByID :many -SELECT * FROM "account_dynamic_registration_domains" -WHERE - "account_public_id" = $1 AND - "domain" ILIKE $2 -ORDER BY "id" DESC -LIMIT $3 OFFSET $4; - --- name: FilterAccountDynamicRegistrationDomainsByAccountPublicIDOrderedByDomain :many -SELECT * FROM "account_dynamic_registration_domains" -WHERE - "account_public_id" = $1 AND - "domain" ILIKE $2 -ORDER BY "domain" ASC -LIMIT $3 OFFSET $4; - --- name: CountFilteredAccountDynamicRegistrationDomainsByAccountPublicID :one -SELECT COUNT(*) FROM "account_dynamic_registration_domains" -WHERE - "account_public_id" = $1 AND - "domain" ILIKE $2 -LIMIT 1; - --- name: CountVerifiedAccountDynamicRegistrationDomainsByDomain :one -SELECT COUNT(*) FROM "account_dynamic_registration_domains" -WHERE "domain" = $1 AND "verified_at" IS NOT NULL -LIMIT 1; - --- name: CountVerifiedAccountDynamicRegistrationDomainsByDomains :one -SELECT COUNT(*) FROM "account_dynamic_registration_domains" -WHERE "domain" IN (sqlc.slice('domains')) AND "verified_at" IS NOT NULL -LIMIT 1; - --- name: CountVerifiedAccountDynamicRegistrationDomainsByDomainsAndAccountPublicID :one -SELECT COUNT(*) FROM "account_dynamic_registration_domains" -WHERE - "account_public_id" = $1 AND - "domain" IN (sqlc.slice('domains')) AND - "verified_at" IS NOT NULL -LIMIT 1; - --- name: CountVerifiedAccountDynamicRegistrationDomainsByDomainAndAccountPublicID :one -SELECT COUNT(*) FROM "account_dynamic_registration_domains" -WHERE - "account_public_id" = $1 AND - "domain" = $2 AND - "verified_at" IS NOT NULL -LIMIT 1; - --- name: DeleteAccountDynamicRegistrationDomain :exec -DELETE FROM "account_dynamic_registration_domains" -WHERE "id" = $1; diff --git a/idp/internal/providers/database/queries/app_related_apps.sql b/idp/internal/providers/database/queries/app_related_apps.sql index 3649e69..4774393 100644 --- a/idp/internal/providers/database/queries/app_related_apps.sql +++ b/idp/internal/providers/database/queries/app_related_apps.sql @@ -19,7 +19,7 @@ INSERT INTO "app_related_apps" ( SELECT a.* FROM "apps" a INNER JOIN "app_related_apps" ara ON a.id = ara.related_app_id WHERE ara.app_id = $1 -ORDER BY a.name ASC; +ORDER BY a.client_name ASC; -- name: DeleteAppRelatedAppsByAppIDAndRelatedAppIDs :exec DELETE FROM "app_related_apps" diff --git a/idp/internal/providers/database/queries/apps.sql b/idp/internal/providers/database/queries/apps.sql index 42a8637..4bfb39d 100644 --- a/idp/internal/providers/database/queries/apps.sql +++ b/idp/internal/providers/database/queries/apps.sql @@ -9,7 +9,7 @@ INSERT INTO "apps" ( "account_id", "account_public_id", "app_type", - "name", + "client_name", "client_id", "client_uri", "username_column", @@ -64,7 +64,7 @@ INSERT INTO "apps" ( -- name: CountAppsByAccountIDAndName :one SELECT COUNT(*) FROM "apps" -WHERE "account_id" = $1 AND "name" = $2 +WHERE "account_id" = $1 AND "client_name" = $2 LIMIT 1; -- name: FindAppByClientID :one @@ -86,7 +86,7 @@ WHERE "id" = $1 LIMIT 1; -- name: UpdateApp :one UPDATE "apps" -SET "name" = $2, +SET "client_name" = $2, "username_column" = $3, "client_uri" = $4, "logo_uri" = $5, @@ -129,7 +129,7 @@ OFFSET $2 LIMIT $3; -- name: FindPaginatedAppsByAccountPublicIDOrderedByName :many SELECT * FROM "apps" WHERE "account_public_id" = $1 -ORDER BY "name" ASC +ORDER BY "client_name" ASC OFFSET $2 LIMIT $3; -- name: CountAppsByAccountPublicID :one @@ -139,7 +139,7 @@ LIMIT 1; -- name: FilterAppsByNameAndByAccountPublicIDOrderedByID :many SELECT * FROM "apps" -WHERE "account_public_id" = $1 AND "name" ILIKE $2 +WHERE "account_public_id" = $1 AND "client_name" ILIKE $2 ORDER BY "id" DESC OFFSET $3 LIMIT $4; @@ -152,34 +152,34 @@ OFFSET $3 LIMIT $4; -- name: FilterAppsByNameAndTypeAndByAccountPublicIDOrderedByID :many SELECT * FROM "apps" WHERE "account_public_id" = $1 AND - "name" ILIKE $2 AND + "client_name" ILIKE $2 AND "app_type" = $3 ORDER BY "id" DESC OFFSET $4 LIMIT $5; -- name: FilterAppsByNameAndByAccountPublicIDOrderedByName :many SELECT * FROM "apps" -WHERE "account_public_id" = $1 AND "name" ILIKE $2 -ORDER BY "name" ASC +WHERE "account_public_id" = $1 AND "client_name" ILIKE $2 +ORDER BY "client_name" ASC OFFSET $3 LIMIT $4; -- name: FilterAppsByTypeAndByAccountPublicIDOrderedByName :many SELECT * FROM "apps" WHERE "account_public_id" = $1 AND "app_type" = $2 -ORDER BY "name" ASC +ORDER BY "client_name" ASC OFFSET $3 LIMIT $4; -- name: FilterAppsByNameAndTypeAndByAccountPublicIDOrderedByName :many SELECT * FROM "apps" WHERE "account_public_id" = $1 AND - "name" ILIKE $2 AND + "client_name" ILIKE $2 AND "app_type" = $3 -ORDER BY "name" ASC +ORDER BY "client_name" ASC OFFSET $4 LIMIT $5; -- name: CountFilteredAppsByNameAndByAccountPublicID :one SELECT COUNT(*) FROM "apps" -WHERE "account_public_id" = $1 AND "name" ILIKE $2 +WHERE "account_public_id" = $1 AND "client_name" ILIKE $2 LIMIT 1; -- name: CountFilteredAppsByTypeAndByAccountPublicID :one @@ -190,7 +190,7 @@ LIMIT 1; -- name: CountFilteredAppsByNameAndTypeAndByAccountPublicID :one SELECT COUNT(*) FROM "apps" WHERE "account_public_id" = $1 AND - "name" ILIKE $2 AND + "client_name" ILIKE $2 AND "app_type" = $3 LIMIT 1; @@ -204,7 +204,7 @@ RETURNING *; -- name: FindAppsByClientIDsAndAccountID :many SELECT * FROM "apps" WHERE "client_id" IN (sqlc.slice('client_ids')) AND "account_id" = $1 -ORDER BY "name" ASC LIMIT $2; +ORDER BY "client_name" ASC LIMIT $2; -- name: CountAppsByClientIDAndAccountPublicID :one SELECT COUNT(*) FROM "apps" diff --git a/idp/internal/providers/database/queries/credentials_keys.sql b/idp/internal/providers/database/queries/credentials_keys.sql index 50e9e11..bdbaea0 100644 --- a/idp/internal/providers/database/queries/credentials_keys.sql +++ b/idp/internal/providers/database/queries/credentials_keys.sql @@ -47,3 +47,8 @@ LIMIT 1; -- name: DeleteAllCredentialsKeys :exec DELETE FROM "credentials_keys"; + +-- name: FindCredentialsKeyByID :one +SELECT * FROM "credentials_keys" +WHERE "id" = $1 +LIMIT 1; diff --git a/idp/internal/providers/database/queries/dynamic_registration_domain_codes.sql b/idp/internal/providers/database/queries/dynamic_registration_domain_codes.sql index 9b73f9e..6c002a0 100644 --- a/idp/internal/providers/database/queries/dynamic_registration_domain_codes.sql +++ b/idp/internal/providers/database/queries/dynamic_registration_domain_codes.sql @@ -4,9 +4,10 @@ -- License, v. 2.0. If a copy of the MPL was not distributed with this -- file, You can obtain one at https://mozilla.org/MPL/2.0/. --- name: CreateDynamicRegistrationDomainCode :one +-- name: CreateDynamicRegistrationDomainCode :exec INSERT INTO "dynamic_registration_domain_codes" ( "account_id", + "dynamic_registration_domain_id", "verification_host", "verification_code", "verification_prefix", @@ -18,8 +19,9 @@ INSERT INTO "dynamic_registration_domain_codes" ( $3, $4, $5, - $6 -) RETURNING "id"; + $6, + $7 +); -- name: UpdateDynamicRegistrationDomainCode :exec UPDATE "dynamic_registration_domain_codes" SET @@ -33,3 +35,7 @@ WHERE "id" = $1; -- name: DeleteDynamicRegistrationDomainCode :exec DELETE FROM "dynamic_registration_domain_codes" WHERE "id" = $1; + +-- name: FindDynamicRegistrationDomainCodeByDynamicRegistrationDomainID :one +SELECT * FROM "dynamic_registration_domain_codes" +WHERE "dynamic_registration_domain_id" = $1; diff --git a/idp/internal/providers/database/queries/dynamic_registration_domains.sql b/idp/internal/providers/database/queries/dynamic_registration_domains.sql new file mode 100644 index 0000000..f646140 --- /dev/null +++ b/idp/internal/providers/database/queries/dynamic_registration_domains.sql @@ -0,0 +1,123 @@ +-- Copyright (c) 2025 Afonso Barracha +-- +-- This Source Code Form is subject to the terms of the Mozilla Public +-- License, v. 2.0. If a copy of the MPL was not distributed with this +-- file, You can obtain one at https://mozilla.org/MPL/2.0/. + +-- name: CreateDynamicRegistrationDomain :one +INSERT INTO "dynamic_registration_domains" ( + "account_id", + "account_public_id", + "domain", + "verification_method", + "usages" +) VALUES ( + $1, + $2, + $3, + $4, + $5 +) RETURNING *; + +-- name: FindDynamicRegistrationDomainByAccountPublicIDAndDomain :one +SELECT * FROM "dynamic_registration_domains" WHERE "account_public_id" = $1 AND "domain" = $2 LIMIT 1; + +-- name: VerifyDynamicRegistrationDomain :one +UPDATE "dynamic_registration_domains" +SET + "verified_at" = NOW(), + "verification_method" = $2 +WHERE "id" = $1 RETURNING *; + +-- name: FindPaginatedDynamicRegistrationDomainsByAccountPublicIDOrderedByID :many +SELECT * FROM "dynamic_registration_domains" +WHERE "account_public_id" = $1 +ORDER BY "id" DESC +LIMIT $2 OFFSET $3; + +-- name: FindPaginatedDynamicRegistrationDomainsByAccountPublicIDOrderedByDomain :many +SELECT * FROM "dynamic_registration_domains" +WHERE "account_public_id" = $1 +ORDER BY "domain" ASC +LIMIT $2 OFFSET $3; + +-- name: CountDynamicRegistrationDomainsByAccountPublicID :one +SELECT COUNT(*) FROM "dynamic_registration_domains" +WHERE "account_public_id" = $1; + +-- name: FilterDynamicRegistrationDomainsByAccountPublicIDOrderedByID :many +SELECT * FROM "dynamic_registration_domains" +WHERE + "account_public_id" = $1 AND + "domain" ILIKE $2 +ORDER BY "id" DESC +LIMIT $3 OFFSET $4; + +-- name: FilterDynamicRegistrationDomainsByAccountPublicIDOrderedByDomain :many +SELECT * FROM "dynamic_registration_domains" +WHERE + "account_public_id" = $1 AND + "domain" ILIKE $2 +ORDER BY "domain" ASC +LIMIT $3 OFFSET $4; + +-- name: CountFilteredDynamicRegistrationDomainsByAccountPublicID :one +SELECT COUNT(*) FROM "dynamic_registration_domains" +WHERE + "account_public_id" = $1 AND + "domain" ILIKE $2 +LIMIT 1; + +-- name: CountVerifiedDynamicRegistrationDomainsByDomain :one +SELECT COUNT(*) FROM "dynamic_registration_domains" +WHERE "domain" = $1 AND "verified_at" IS NOT NULL +LIMIT 1; + +-- name: CountDynamicRegistrationDomainsByDomain :one +SELECT COUNT(*) FROM "dynamic_registration_domains" +WHERE "domain" = $1 +LIMIT 1; + +-- name: CountVerifiedDynamicRegistrationDomainsByDomains :one +SELECT COUNT(*) FROM "dynamic_registration_domains" +WHERE "domain" IN (sqlc.slice('domains')) AND "verified_at" IS NOT NULL +LIMIT 1; + +-- name: CountDynamicRegistrationDomainsByDomains :one +SELECT COUNT(*) FROM "dynamic_registration_domains" +WHERE "domain" IN (sqlc.slice('domains')) +LIMIT 1; + +-- name: CountVerifiedDynamicRegistrationDomainsByDomainsAndAccountPublicID :one +SELECT COUNT(*) FROM "dynamic_registration_domains" +WHERE + "account_public_id" = $1 AND + "domain" IN (sqlc.slice('domains')) AND + "verified_at" IS NOT NULL +LIMIT 1; + +-- name: CountVerifiedDynamicRegistrationDomainsByDomainAndAccountPublicID :one +SELECT COUNT(*) FROM "dynamic_registration_domains" +WHERE + "account_public_id" = $1 AND + "domain" = $2 AND + "verified_at" IS NOT NULL +LIMIT 1; + +-- name: CountDynamicRegistrationDomainsByDomainsAndAccountPublicID :one +SELECT COUNT(*) FROM "dynamic_registration_domains" +WHERE + "account_public_id" = $1 AND + "domain" IN (sqlc.slice('domains')) +LIMIT 1; + +-- name: CountDynamicRegistrationDomainsByDomainAndAccountPublicID :one +SELECT COUNT(*) FROM "dynamic_registration_domains" +WHERE + "account_public_id" = $1 AND + "domain" = $2 +LIMIT 1; + +-- name: DeleteDynamicRegistrationDomain :exec +DELETE FROM "dynamic_registration_domains" +WHERE "id" = $1; diff --git a/idp/internal/providers/database/queries/dynamic_registration_software_statement_keys.sql b/idp/internal/providers/database/queries/dynamic_registration_software_statement_keys.sql new file mode 100644 index 0000000..7342eed --- /dev/null +++ b/idp/internal/providers/database/queries/dynamic_registration_software_statement_keys.sql @@ -0,0 +1,16 @@ +-- Copyright (c) 2025 Afonso Barracha +-- +-- This Source Code Form is subject to the terms of the Mozilla Public +-- License, v. 2.0. If a copy of the MPL was not distributed with this +-- file, You can obtain one at https://mozilla.org/MPL/2.0/. + +-- name: FindDynamicRegistrationSoftwareStatementKeysByRootDomainAndAccountPublicID :one +SELECT "c".* FROM "credentials_keys" AS "c" +LEFT JOIN "dynamic_registration_software_statement_keys" AS "d" ON "c"."id" = "d"."credential_key_id" +WHERE "d"."root_domain" = $1 AND "d"."account_public_id" = $2 +LIMIT 1; + +-- name: FindDynamicRegistrationSoftwareStatementKeysByCredentialsKeyKIDAndAccountPublicID :one +SELECT * FROM "dynamic_registration_software_statement_keys" +WHERE "credentials_key_kid" = $1 AND "account_public_id" = $2 +LIMIT 1; diff --git a/idp/internal/providers/tokens/dynamic_registration_software_statements.go b/idp/internal/providers/tokens/dynamic_registration_software_statements.go new file mode 100644 index 0000000..0770eec --- /dev/null +++ b/idp/internal/providers/tokens/dynamic_registration_software_statements.go @@ -0,0 +1,101 @@ +// Copyright (c) 2025 Afonso Barracha +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +package tokens + +import ( + "context" + + "github.com/golang-jwt/jwt/v5" + + "github.com/tugascript/devlogs/idp/internal/utils" +) + +const dynamicRegistrationSoftwareStatementsLocation = "dynamic_registration_software_statements" + +type SoftwareStatementClaims struct { + RedirectURIs []string `json:"redirect_uris,omitempty" validate:"omitempty,min=1,dive,uri"` + TokenEndpointAuthMethod string `json:"token_endpoint_auth_method,omitempty" validate:"omitempty,oneof=none client_secret_basic client_secret_post client_secret_jwt private_key_jwt"` + GrantTypes []string `json:"grant_types,omitempty" validate:"omitempty,min=1,dive,oneof=authorization_code refresh_token client_credentials urn:ietf:params:oauth:grant-type:jwt-bearer"` + ResponseTypes []string `json:"response_types,omitempty" validate:"omitempty,dive,oneof=none code 'code id_token'"` + ApplicationType string `json:"application_type,omitempty" validate:"omitempty,oneof=native service mcp"` + ClientName string `json:"client_name,omitempty" validate:"omitempty,min=1,max=255"` + ClientURI string `json:"client_uri,omitempty" validate:"omitempty,url"` + LogoURI string `json:"logo_uri,omitempty" validate:"omitempty,url"` + Scope string `json:"scope,omitempty" validate:"omitempty,multiple_scope"` + Contacts []string `json:"contacts,omitempty" validate:"omitempty,unique,dive,email"` + TOSURI string `json:"tos_uri,omitempty" validate:"omitempty,url"` + PolicyURI string `json:"policy_uri,omitempty" validate:"omitempty,url"` + JWKsURI string `json:"jwks_uri,omitempty" validate:"omitempty,url"` + JWKs []string `json:"jwks,omitempty" validate:"omitempty,json"` + SoftwareID string `json:"software_id,omitempty" validate:"omitempty,max=512"` + SoftwareVersion string `json:"software_version,omitempty" validate:"omitempty,max=512"` + SubjectType string `json:"subject_type,omitempty" validate:"omitempty,oneof=public pairwise"` + SectorIdentifierURI string `json:"sector_identifier_uri,omitempty" validate:"omitempty,url"` + DefaultMaxAge int64 `json:"default_max_age,omitempty" validate:"omitempty,min=0"` + RequireAuthTime bool `json:"require_auth_time,omitempty" validate:"omitempty,bool"` + DefaultACRValues []string `json:"default_acr_values,omitempty" validate:"omitempty,unique,dive,max=100"` + InitiateLoginURI string `json:"initiate_login_uri,omitempty" validate:"omitempty,url"` + RequestURIs []string `json:"request_uris,omitempty" validate:"omitempty,unique,dive,url"` + IDTokenSignedResponseAlg string `json:"id_token_signed_response_alg,omitempty" validate:"omitempty,oneof=RS256 ES256 EdDSA"` + IDTokenEncryptedResponseAlg string `json:"id_token_encrypted_response_alg,omitempty" validate:"omitempty,oneof=RSA-OAEP-256 ECDH-ES ECDH-ES+A256KW"` + IDTokenEncryptedResponseEnc string `json:"id_token_encrypted_response_enc,omitempty" validate:"omitempty,oneof=A128CBC-HS256 A192CBC-HS384 A256CBC-HS512 A128GCM A192GCM A256GCM"` + UserInfoSignedResponseAlg string `json:"userinfo_signed_response_alg,omitempty" validate:"omitempty,oneof=RS256 ES256 EdDSA"` + UserInfoEncryptedResponseAlg string `json:"userinfo_encrypted_response_alg,omitempty" validate:"omitempty,oneof=RSA-OAEP-256 ECDH-ES ECDH-ES+A256KW"` + UserInfoEncryptedResponseEnc string `json:"userinfo_encrypted_response_enc,omitempty" validate:"omitempty,oneof=A128CBC-HS256 A192CBC-HS384 A256CBC-HS512 A128GCM A192GCM A256GCM"` + RequestObjectSigningAlg string `json:"request_object_signing_alg,omitempty" validate:"omitempty,oneof=RS256 ES256 EdDSA"` + RequestObjectEncryptionAlg string `json:"request_object_encryption_alg,omitempty" validate:"omitempty,oneof=RSA-OAEP-256 ECDH-ES ECDH-ES+A256KW"` + RequestObjectEncryptionEnc string `json:"request_object_encryption_enc,omitempty" validate:"omitempty,oneof=A128CBC-HS256 A192CBC-HS384 A256CBC-HS512 A128GCM A192GCM A256GCM"` + TokenEndpointAuthSigningAlg string `json:"token_endpoint_auth_signing_alg,omitempty" validate:"omitempty,oneof=RS256 ES256 EdDSA"` + AccessTokenSigningAlg string `json:"access_token_signing_alg,omitempty" validate:"omitempty,oneof=RS256 ES256 EdDSA"` +} + +type GetUnknownPublicJWK = func(kid string) (utils.JWK, error) + +type softwareStatementJWTClaims struct { + SoftwareStatementClaims + jwt.RegisteredClaims +} + +type VerifySoftwareStatementOptions struct { + RequestID string + SoftwareStatement string + GetPublicJWK GetUnknownPublicJWK +} + +func (t *Tokens) VerifySoftwareStatement( + ctx context.Context, + opts VerifySoftwareStatementOptions, +) (SoftwareStatementClaims, jwt.RegisteredClaims, error) { + logger := utils.BuildLogger(t.logger, utils.LoggerOptions{ + Location: dynamicRegistrationSoftwareStatementsLocation, + Method: "VerifySoftwareStatementToken", + RequestID: opts.RequestID, + }) + logger.DebugContext(ctx, "Verifying software statement token") + + var claims softwareStatementJWTClaims + if _, err := jwt.ParseWithClaims(opts.SoftwareStatement, &claims, func(token *jwt.Token) (interface{}, error) { + kid, err := extractTokenKID(token) + if err != nil { + logger.DebugContext(ctx, "Failed to extract KID from software statement token", "error", err) + return nil, err + } + + jwk, err := opts.GetPublicJWK(kid) + if err != nil { + logger.WarnContext(ctx, "Failed to get public JWK for software statement token", "error", err, "kid", kid) + return nil, err + } + + return jwk.ToUsableKey() + }); err != nil { + logger.WarnContext(ctx, "Failed to verify software statement token", "error", err) + return SoftwareStatementClaims{}, jwt.RegisteredClaims{}, err + } + + return claims.SoftwareStatementClaims, claims.RegisteredClaims, nil +} diff --git a/idp/internal/providers/tokens/jwks.go b/idp/internal/providers/tokens/jwks.go new file mode 100644 index 0000000..4bd14c4 --- /dev/null +++ b/idp/internal/providers/tokens/jwks.go @@ -0,0 +1,106 @@ +// Copyright (c) 2025 Afonso Barracha +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +package tokens + +import ( + "context" + "encoding/json" + "errors" + "io" + "net/http" + "time" + + "github.com/tugascript/devlogs/idp/internal/utils" +) + +const ( + jwksLocation string = "jwks" + + maxRequestBodySize = 128 * 1024 // 128 KB +) + +var client = &http.Client{ + Timeout: 10 * time.Second, +} + +func httpGetWithinLimit(url string) ([]byte, error) { + resp, err := client.Get(url) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + limited := http.MaxBytesReader(nil, resp.Body, maxRequestBodySize) + return io.ReadAll(limited) +} + +type PublicJWKsResult struct { + Keys []utils.JWK `json:"keys"` +} + +func (p *PublicJWKsResult) UnmarshalJSON(data []byte) error { + type Alias PublicJWKsResult + aux := &struct { + Keys []json.RawMessage `json:"keys"` + *Alias + }{ + Alias: (*Alias)(p), + } + + if err := json.Unmarshal(data, &aux); err != nil { + return err + } + + p.Keys = make([]utils.JWK, 0, len(aux.Keys)) + for _, raw := range aux.Keys { + jwk, err := utils.JsonToJWK(raw) + if err != nil { + return err + } + + if jwk.GetKeyID() == "" { + return errors.New("JWK is missing 'kid' field") + } + + p.Keys = append(p.Keys, jwk) + } + + return nil +} + +type GetPublicJWKsOptions struct { + RequestID string + URL string +} + +func (t *Tokens) GetPublicJWKs(ctx context.Context, opts GetPublicJWKsOptions) (PublicJWKsResult, error) { + logger := utils.BuildLogger(t.logger, utils.LoggerOptions{ + Location: jwksLocation, + Method: "GetPublicJWKs", + RequestID: opts.RequestID, + }) + logger.DebugContext(ctx, "Fetching public JWKs") + + bytes, err := httpGetWithinLimit(opts.URL) + if err != nil { + logger.ErrorContext(ctx, "Failed to fetch public JWKs", "error", err) + return PublicJWKsResult{}, err + } + + var result PublicJWKsResult + if err := json.Unmarshal(bytes, &result); err != nil { + logger.ErrorContext(ctx, "Failed to parse public JWKs", "error", err) + return PublicJWKsResult{}, err + } + if len(result.Keys) == 0 { + logger.ErrorContext(ctx, "No JWKs found in the response") + return PublicJWKsResult{}, errors.New("no JWKs found in the response") + } + + logger.InfoContext(ctx, "Successfully fetched public JWKs", "keyCount", len(result.Keys)) + return result, nil +} diff --git a/idp/internal/server/routes/account_dynamic_registration.go b/idp/internal/server/routes/account_dynamic_registration.go index 6ba8980..ca95cd2 100644 --- a/idp/internal/server/routes/account_dynamic_registration.go +++ b/idp/internal/server/routes/account_dynamic_registration.go @@ -42,7 +42,7 @@ func (r *Routes) AccountDynamicRegistrationConfigurationRoutes(app *fiber.App) { domainsRouter.Post( paths.Base, credentialsConfigsWriteScopeMiddleware, - r.controllers.CreateAccountCredentialsRegistrationDomain, + r.controllers.CreateDynamicRegistrationDomain, ) domainsRouter.Get( paths.Base, diff --git a/idp/internal/server/server.go b/idp/internal/server/server.go index b061a7f..8d96fc1 100644 --- a/idp/internal/server/server.go +++ b/idp/internal/server/server.go @@ -269,6 +269,10 @@ func New( ) logger.InfoContext(ctx, "Finished building OAuth provider") + logger.InfoContext(ctx, "Loading validators...") + vld := validations.NewValidator(logger) + logger.InfoContext(ctx, "Finished loading validators") + logger.InfoContext(ctx, "Building services...") newServices := services.NewServices( logger, @@ -278,6 +282,7 @@ func New( jwts, cryp, oauthProviders, + vld, cfg.KEKExpirationDays(), cfg.DEKExpirationDays(), cfg.JWKExpirationDays(), @@ -290,10 +295,6 @@ func New( ) logger.InfoContext(ctx, "Finished building services") - logger.InfoContext(ctx, "Loading validators...") - vld := validations.NewValidator(logger) - logger.InfoContext(ctx, "Finished loading validators") - server := &FiberServer{ App: fiber.New(fiber.Config{ ServerHeader: "idp", diff --git a/idp/internal/services/account_credentials.go b/idp/internal/services/account_credentials.go index baa3c2c..0c5d689 100644 --- a/idp/internal/services/account_credentials.go +++ b/idp/internal/services/account_credentials.go @@ -9,6 +9,7 @@ package services import ( "context" "fmt" + "slices" "strings" "time" @@ -24,8 +25,8 @@ import ( const ( accountCredentialsLocation string = "account_credentials" - accountCredentialsKeysCacheTTL int = 900 // 15 minutes - accountCredentialsKeysCacheKeyPrefix string = "account_credentials_keys" + accountCredentialsKeysCacheTTL time.Duration = 15 * time.Minute + accountCredentialsKeysCacheKeyPrefix string = "account_credentials_keys" ) func mapAccountCredentialsTransport( @@ -106,6 +107,79 @@ func mapAccountCredentialsTokenEndpointAuthMethod( } } +func mapAccountCredentialsGrantTypes( + applicationType database.AccountCredentialsType, + grantTypes []string, +) ([]database.GrantType, *exceptions.ServiceError) { + switch applicationType { + case database.AccountCredentialsTypeMcp: + if len(grantTypes) == 0 { + return []database.GrantType{database.GrantTypeAuthorizationCode, database.GrantTypeRefreshToken}, nil + } + + gts := make([]database.GrantType, 0, len(grantTypes)) + for _, grantType := range grantTypes { + mappedGrantType, serviceErr := mapGrantType(grantType) + if serviceErr != nil { + return nil, serviceErr + } + + if mappedGrantType != database.GrantTypeAuthorizationCode && mappedGrantType != database.GrantTypeRefreshToken { + return nil, exceptions.NewValidationError("only authorization_code and refresh_token grant types are supported for mcp credentials") + } + + gts = append(gts, mappedGrantType) + } + + return gts, nil + case database.AccountCredentialsTypeService: + if len(grantTypes) == 0 { + return []database.GrantType{database.GrantTypeClientCredentials, database.GrantTypeUrnIetfParamsOauthGrantTypeJwtBearer}, nil + } + + gts := make([]database.GrantType, 0, len(grantTypes)) + for _, grantType := range grantTypes { + mappedGrantType, serviceErr := mapGrantType(grantType) + if serviceErr != nil { + return nil, serviceErr + } + + if mappedGrantType != database.GrantTypeClientCredentials && mappedGrantType != database.GrantTypeUrnIetfParamsOauthGrantTypeJwtBearer { + return nil, exceptions.NewValidationError("only client_credentials and urn:ietf:params:oauth:grant-type:jwt-bearer grant types are supported for service credentials") + } + + gts = append(gts, mappedGrantType) + } + + return gts, nil + case database.AccountCredentialsTypeNative: + if len(grantTypes) == 0 { + return []database.GrantType{database.GrantTypeAuthorizationCode, database.GrantTypeRefreshToken}, nil + } + + gts := make([]database.GrantType, 0, len(grantTypes)) + for _, grantType := range grantTypes { + mappedGrantType, serviceErr := mapGrantType(grantType) + if serviceErr != nil { + return nil, serviceErr + } + + if mappedGrantType != database.GrantTypeAuthorizationCode && mappedGrantType != database.GrantTypeRefreshToken { + return nil, exceptions.NewValidationError("only authorization_code and refresh_token grant types are supported for native credentials") + } + + gts = append(gts, mappedGrantType) + } + if !slices.Contains(gts, database.GrantTypeAuthorizationCode) { + return nil, exceptions.NewValidationError("authorization_code grant type is required for native credentials") + } + + return gts, nil + default: + return nil, exceptions.NewValidationError("invalid credentials type: " + string(applicationType)) + } +} + func mapAccountCredentialsScope(scope string) (database.AccountCredentialsScope, *exceptions.ServiceError) { acScope := database.AccountCredentialsScope(scope) switch acScope { @@ -219,8 +293,8 @@ func (s *Services) CreateAccountCredentials( count, err := s.database.CountAccountCredentialsByNameAndAccountID( ctx, database.CountAccountCredentialsByNameAndAccountIDParams{ - AccountID: accountID, - Name: name, + AccountID: accountID, + ClientName: name, }, ) if err != nil { @@ -240,7 +314,7 @@ func (s *Services) CreateAccountCredentials( AccountID: accountID, AccountPublicID: opts.AccountPublicID, CredentialsType: credentialsType, - Name: name, + ClientName: name, Scopes: scopes, TokenEndpointAuthMethod: authMethod, Domain: domain, @@ -251,7 +325,7 @@ func (s *Services) CreateAccountCredentials( LogoUri: mapEmptyURL(opts.LogoURI), PolicyUri: mapEmptyURL(opts.PolicyURI), TosUri: mapEmptyURL(opts.TOSURI), - SoftwareID: opts.SoftwareID, + SoftwareID: mapEmptyString(opts.SoftwareID), SoftwareVersion: mapEmptyString(opts.SoftwareVersion), Contacts: opts.Contacts, CreationMethod: opts.CreationMethod, @@ -263,7 +337,7 @@ func (s *Services) CreateAccountCredentials( return dtos.AccountCredentialsDTO{}, exceptions.FromDBError(err) } - return dtos.MapAccountCredentialsToDTO(&accountCredentials), nil + return dtos.MapAccountCredentialsToDTO(&accountCredentials) } qrs, txn, err := s.database.BeginTx(ctx) @@ -283,7 +357,7 @@ func (s *Services) CreateAccountCredentials( AccountID: accountID, AccountPublicID: opts.AccountPublicID, CredentialsType: credentialsType, - Name: name, + ClientName: name, Scopes: scopes, TokenEndpointAuthMethod: authMethod, Domain: domain, @@ -294,7 +368,7 @@ func (s *Services) CreateAccountCredentials( LogoUri: mapEmptyURL(opts.LogoURI), PolicyUri: mapEmptyURL(opts.PolicyURI), TosUri: mapEmptyURL(opts.TOSURI), - SoftwareID: opts.SoftwareID, + SoftwareID: mapEmptyString(opts.SoftwareID), SoftwareVersion: mapEmptyString(opts.SoftwareVersion), Contacts: opts.Contacts, CreationMethod: opts.CreationMethod, @@ -344,7 +418,7 @@ func (s *Services) CreateAccountCredentials( return dtos.AccountCredentialsDTO{}, serviceErr } - return dtos.MapAccountCredentialsToDTOWithJWK(&accountCredentials, jwk, dbPrms.ExpiresAt), nil + return dtos.MapAccountCredentialsToDTOWithJWK(&accountCredentials, jwk, dbPrms.ExpiresAt) case AuthMethodClientSecretBasic, AuthMethodClientSecretPost, AuthMethodClientSecretJWT: var ccID int32 var secretID, secret string @@ -377,7 +451,7 @@ func (s *Services) CreateAccountCredentials( return dtos.AccountCredentialsDTO{}, serviceErr } - return dtos.MapAccountCredentialsToDTOWithSecret(&accountCredentials, secretID, secret, exp), nil + return dtos.MapAccountCredentialsToDTOWithSecret(&accountCredentials, secretID, secret, exp) default: logger.ErrorContext(ctx, "Invalid auth method", "authMethod", opts.AuthMethod) serviceErr = exceptions.NewInternalServerError() @@ -420,7 +494,7 @@ func (s *Services) GetAccountCredentialsByClientIDAndAccountPublicID( } logger.InfoContext(ctx, "Got account keys by client id and account public id successfully") - return dtos.MapAccountCredentialsToDTO(&accountCredentials), nil + return dtos.MapAccountCredentialsToDTO(&accountCredentials) } type GetAccountCredentialsByPublicIDOptions struct { @@ -449,7 +523,7 @@ func (s *Services) GetAccountCredentialsByPublicID( return dtos.AccountCredentialsDTO{}, serviceErr } - return dtos.MapAccountCredentialsToDTO(&accountClients), nil + return dtos.MapAccountCredentialsToDTO(&accountClients) } type getAccountCredentialsForMutationOptions struct { @@ -520,8 +594,18 @@ func (s *Services) ListAccountCredentialsByAccountPublicID( return nil, 0, exceptions.NewInternalServerError() } + accountCredentialsDTOs := make([]dtos.AccountCredentialsDTO, len(accountCredentials)) + for i, accountCredential := range accountCredentials { + var serviceErr *exceptions.ServiceError + accountCredentialsDTOs[i], serviceErr = dtos.MapAccountCredentialsToDTO(&accountCredential) + if serviceErr != nil { + logger.ErrorContext(ctx, "Failed to map account credentials to DTO", "serviceError", serviceErr) + return nil, 0, serviceErr + } + } + logger.InfoContext(ctx, "Successfully listed account keys by account id") - return utils.MapSlice(accountCredentials, dtos.MapAccountCredentialsToDTO), count, nil + return accountCredentialsDTOs, count, nil } func mapAccountCredentialsUpdateTransport( @@ -588,12 +672,12 @@ func (s *Services) UpdateAccountCredentials( } name := strings.TrimSpace(opts.Name) - if name != accountCredentialsDTO.Name { + if name != accountCredentialsDTO.ClientName { count, err := s.database.CountAccountCredentialsByNameAndAccountID( ctx, database.CountAccountCredentialsByNameAndAccountIDParams{ - AccountID: accountCredentialsDTO.AccountID(), - Name: name, + AccountID: accountCredentialsDTO.AccountID(), + ClientName: name, }, ) if err != nil { @@ -624,7 +708,7 @@ func (s *Services) UpdateAccountCredentials( accountCredentials, err := s.database.UpdateAccountCredentials(ctx, database.UpdateAccountCredentialsParams{ ID: accountCredentialsDTO.ID(), Scopes: scopes, - Name: name, + ClientName: name, Domain: domain, ClientUri: opts.ClientURI, RedirectUris: opts.RedirectURIs, @@ -640,7 +724,7 @@ func (s *Services) UpdateAccountCredentials( return dtos.AccountCredentialsDTO{}, exceptions.FromDBError(err) } - return dtos.MapAccountCredentialsToDTO(&accountCredentials), nil + return dtos.MapAccountCredentialsToDTO(&accountCredentials) } type DeleteAccountCredentialsOptions struct { @@ -1382,7 +1466,7 @@ func (s *Services) ListActiveAccountCredentialsKeysWithCache( logger.InfoContext(ctx, "Listing account credentials keys with cache...") cacheKey := fmt.Sprintf("%s:%s", accountCredentialsKeysCacheKeyPrefix, opts.AccountPublicID) - jwksDTO, etag, err := cache.GetResponse(s.cache, ctx, cache.GetResponseOptions[dtos.JWKsDTO]{ + jwksDTO, etag, found, err := cache.GetResponse(s.cache, ctx, cache.GetResponseOptions[dtos.JWKsDTO]{ RequestID: opts.RequestID, Key: cacheKey, }) @@ -1390,7 +1474,7 @@ func (s *Services) ListActiveAccountCredentialsKeysWithCache( logger.ErrorContext(ctx, "Failed to get cached account credentials keys", "error", err) return dtos.JWKsDTO{}, "", exceptions.NewInternalServerError() } - if etag != "" { + if found && etag != "" { logger.InfoContext(ctx, "Found cached account credentials keys", "etag", etag) return jwksDTO, etag, nil } diff --git a/idp/internal/services/account_credentials_registration.go b/idp/internal/services/account_credentials_registration.go new file mode 100644 index 0000000..e5038c1 --- /dev/null +++ b/idp/internal/services/account_credentials_registration.go @@ -0,0 +1,943 @@ +// Copyright (c) 2025 Afonso Barracha +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +package services + +import ( + "context" + "errors" + "net/url" + "slices" + "strings" + "time" + + "github.com/google/uuid" + "golang.org/x/net/publicsuffix" + + "github.com/tugascript/devlogs/idp/internal/exceptions" + "github.com/tugascript/devlogs/idp/internal/providers/database" + "github.com/tugascript/devlogs/idp/internal/providers/tokens" + "github.com/tugascript/devlogs/idp/internal/services/dtos" + "github.com/tugascript/devlogs/idp/internal/utils" +) + +const ( + accountCredentialsRegistrationLocation = "account_credentials_registration" +) + +var allowedAccountCredentialsScopes []string = []string{ + string(database.AccountCredentialsScopeEmail), + string(database.AccountCredentialsScopeProfile), + string(database.AccountCredentialsScopeAccountAdmin), + string(database.AccountCredentialsScopeAccountUsersRead), + string(database.AccountCredentialsScopeAccountUsersWrite), + string(database.AccountCredentialsScopeAccountAppsRead), + string(database.AccountCredentialsScopeAccountAppsWrite), + string(database.AccountCredentialsScopeAccountAppsConfigsRead), + string(database.AccountCredentialsScopeAccountAppsConfigsWrite), + string(database.AccountCredentialsScopeAccountCredentialsRead), + string(database.AccountCredentialsScopeAccountCredentialsWrite), + string(database.AccountCredentialsScopeAccountCredentialsConfigsRead), + string(database.AccountCredentialsScopeAccountCredentialsConfigsWrite), + string(database.AccountCredentialsScopeAccountAuthProvidersRead), +} + +type checkAccountCRDomainOptions struct { + requestID string + accountPublicID uuid.UUID + domain string + requireVerifiedDomains bool +} + +func (s *Services) checkAccountCRDomain( + ctx context.Context, + opts checkAccountCRDomainOptions, +) (string, *exceptions.ServiceError) { + logger := s.buildLogger(opts.requestID, dynamicRegistrationDomainsLocation, "checkAccountCRDomain").With( + "domain", opts.domain, + ) + logger.InfoContext(ctx, "Checking account credential domain validity") + + baseDomain, err := publicsuffix.EffectiveTLDPlusOne(opts.domain) + if err != nil { + logger.WarnContext(ctx, "Failed to parse base domain", "error", err) + return "", exceptions.NewValidationError("invalid client URI") + } + + var count int64 + if baseDomain != opts.domain { + if opts.requireVerifiedDomains { + count, err = s.database.CountVerifiedDynamicRegistrationDomainsByDomainsAndAccountPublicID( + ctx, + database.CountVerifiedDynamicRegistrationDomainsByDomainsAndAccountPublicIDParams{ + AccountPublicID: opts.accountPublicID, + Domains: []string{opts.domain, baseDomain}, + }, + ) + } else { + count, err = s.database.CountDynamicRegistrationDomainsByDomainsAndAccountPublicID( + ctx, + database.CountDynamicRegistrationDomainsByDomainsAndAccountPublicIDParams{ + AccountPublicID: opts.accountPublicID, + Domains: []string{opts.domain, baseDomain}, + }, + ) + } + } else { + if opts.requireVerifiedDomains { + count, err = s.database.CountVerifiedDynamicRegistrationDomainsByDomainAndAccountPublicID( + ctx, + database.CountVerifiedDynamicRegistrationDomainsByDomainAndAccountPublicIDParams{ + AccountPublicID: opts.accountPublicID, + Domain: opts.domain, + }, + ) + } else { + count, err = s.database.CountDynamicRegistrationDomainsByDomainAndAccountPublicID( + ctx, + database.CountDynamicRegistrationDomainsByDomainAndAccountPublicIDParams{ + AccountPublicID: opts.accountPublicID, + Domain: opts.domain, + }, + ) + } + } + + if err != nil { + logger.ErrorContext(ctx, "Failed to count verified account dynamic registration domains", "error", err) + return "", exceptions.FromDBError(err) + } + if count > 0 { + logger.InfoContext(ctx, "Account credential domain is whitelisted") + return baseDomain, nil + } + + logger.InfoContext(ctx, "Account credential domain is not whitelisted or verified") + return "", exceptions.NewUnauthorizedError() +} + +type buildAccountCRSoftwareStatementFuncOptions struct { + requestID string + accountPublicID uuid.UUID + verificationMethods []database.SoftwareStatementVerificationMethod + jwksURI string + jwks []string + domain string + baseDomain string +} + +func (s *Services) buildAccountCRSoftwareStatementFunc( + ctx context.Context, + opts buildAccountCRSoftwareStatementFuncOptions, +) tokens.GetUnknownPublicJWK { + logger := s.buildLogger(opts.requestID, accountCredentialsRegistrationLocation, "buildAccountCRSoftwareStatementFunc").With( + "accountPublicID", opts.accountPublicID, + ) + logger.InfoContext(ctx, "Checking account credential software statement validity") + + if slices.Contains(opts.verificationMethods, database.SoftwareStatementVerificationMethodJwksUri) && opts.jwksURI != "" { + return func(kid string) (utils.JWK, error) { + parsedURI, err := url.Parse(opts.jwksURI) + if err != nil { + logger.ErrorContext(ctx, "Failed to parse JWKs URI", "error", err) + return nil, errors.New("invalid JWKs URI") + } + if parsedURI.Host != opts.baseDomain || !strings.Contains(parsedURI.Host, "."+opts.baseDomain) { + logger.WarnContext(ctx, "JWKs URI parsedURI does not match client URI parsedURI") + return nil, errors.New("JWKs URI parsedURI does not match client URI parsedURI") + } + + jwks, err := s.jwt.GetPublicJWKs(ctx, tokens.GetPublicJWKsOptions{ + RequestID: opts.requestID, + URL: opts.jwksURI, + }) + if err != nil { + logger.WarnContext(ctx, "Failed to get public JWKs from JWKs URI", "error", err) + return nil, errors.New("failed to get public JWKs from JWKs URI") + } + + jwkIdx := slices.IndexFunc(jwks.Keys, func(jwk utils.JWK) bool { + return jwk.GetKeyID() == kid + }) + if jwkIdx == -1 { + logger.WarnContext(ctx, "No matching JWK found for KID in JWKs URI", "kid", kid) + return nil, errors.New("no matching JWK found for KID in JWKs URI") + } + + return jwks.Keys[jwkIdx], nil + } + } + if slices.Contains(opts.verificationMethods, database.SoftwareStatementVerificationMethodManual) { + if len(opts.jwks) > 0 { + return func(kid string) (utils.JWK, error) { + jwks := make([]utils.JWK, 0, len(opts.jwks)) + for _, rawJWK := range opts.jwks { + jwk, err := utils.JsonToJWK([]byte(rawJWK)) + if err != nil { + logger.ErrorContext(ctx, "Failed to parse manual JWK", "error", err) + return nil, errors.New("failed to parse manual JWK") + } + jwks = append(jwks, jwk) + } + + jwkIdx := slices.IndexFunc(jwks, func(jwk utils.JWK) bool { + return jwk.GetKeyID() == kid + }) + if jwkIdx == -1 { + logger.WarnContext(ctx, "No matching manual JWK found for KID", "kid", kid) + return nil, errors.New("no matching manual JWK found for KID") + } + + sliceJWK := jwks[jwkIdx] + jwkRefEnt, err := s.database.FindDynamicRegistrationSoftwareStatementKeysByCredentialsKeyKIDAndAccountPublicID( + ctx, + database.FindDynamicRegistrationSoftwareStatementKeysByCredentialsKeyKIDAndAccountPublicIDParams{ + CredentialsKeyKid: kid, + AccountPublicID: opts.accountPublicID, + }, + ) + if err != nil { + serviceErr := exceptions.FromDBError(err) + if serviceErr.Code == exceptions.CodeNotFound { + logger.WarnContext(ctx, "No database entry found for manual JWK", "kid", kid, "error", err) + return nil, errors.New("no database entry found for manual JWK") + } + + logger.ErrorContext(ctx, "Failed to find database entry for manual JWK", "kid", kid, "error", err) + return nil, errors.New("failed to find database entry for manual JWK") + } + if jwkRefEnt.RootDomain != opts.baseDomain { + logger.WarnContext(ctx, "Manual JWK root domain does not match client URI base domain", + "kid", kid, "jwkRootDomain", jwkRefEnt.RootDomain, "baseDomain", opts.baseDomain, + ) + return nil, errors.New("manual JWK root domain does not match client URI base domain") + } + + jwkEnt, err := s.database.FindCredentialsKeyByID(ctx, jwkRefEnt.CredentialsKeyID) + if err != nil { + serviceErr := exceptions.FromDBError(err) + if serviceErr.Code == exceptions.CodeNotFound { + logger.WarnContext(ctx, "No credentials key found for manual JWK", "kid", kid, "error", err) + return nil, errors.New("no credentials key found for manual JWK") + } + + logger.ErrorContext(ctx, "Failed to find credentials key for manual JWK", "kid", kid, "error", err) + return nil, errors.New("failed to find credentials key for manual JWK") + } + + entJWK, err := utils.JsonToJWK(jwkEnt.PublicKey) + if err != nil { + logger.ErrorContext(ctx, "Failed to parse manual JWK", "error", err) + return nil, errors.New("failed to parse manual JWK") + } + if !entJWK.ComparePublicKey(sliceJWK) { + logger.WarnContext(ctx, "Manual JWK does not match database credentials key", "kid", kid) + return nil, errors.New("manual JWK does not match database credentials key") + } + + return sliceJWK, nil + } + } + + return func(kid string) (utils.JWK, error) { + jwkEntity, err := s.database.FindDynamicRegistrationSoftwareStatementKeysByRootDomainAndAccountPublicID( + ctx, + database.FindDynamicRegistrationSoftwareStatementKeysByRootDomainAndAccountPublicIDParams{ + RootDomain: opts.baseDomain, + AccountPublicID: opts.accountPublicID, + }, + ) + if err != nil { + if exceptions.FromDBError(err).Code == exceptions.CodeNotFound { + logger.WarnContext(ctx, "No manual JWKs found for software statement", "error", err) + return nil, errors.New("no manual JWKs found for software statement") + } + + logger.ErrorContext(ctx, "Failed to find manual JWKs for software statement", "error", err) + return nil, errors.New("failed to find manual JWKs for software statement") + } + if jwkEntity.PublicKid != kid { + logger.WarnContext(ctx, "No matching manual JWK found for KID", + "kid", kid, "publicKID", jwkEntity.PublicKid, + ) + return nil, errors.New("no matching manual JWK found for KID") + } + + jwk, err := utils.JsonToJWK(jwkEntity.PublicKey) + if err != nil { + logger.ErrorContext(ctx, "Failed to parse manual JWK for software statement", "error", err) + return nil, errors.New("failed to parse manual JWK for software statement") + } + + return jwk, nil + } + } + + return func(kid string) (utils.JWK, error) { + logger.WarnContext(ctx, "No verification method available for software statement") + return nil, errors.New("no verification method available") + } +} + +func mapAccountCredentialsDRTransport(applicationType database.AccountCredentialsType) database.Transport { + if applicationType == database.AccountCredentialsTypeMcp { + return database.TransportStreamableHttp + } + + return database.TransportHttps +} + +type mapAccountCredentialsRegistrationDataToDBParamsOptions struct { + applicationType database.AccountCredentialsType + accountPublicID uuid.UUID + accountID int32 + domain string + requestID string + tokenEndpointAuthMethod database.AuthMethod + transport database.Transport + scopes []database.AccountCredentialsScope + data *ApplicationRegistrationData + claims *tokens.SoftwareStatementClaims +} + +func (s *Services) mapAccountCredentialsRegistrationDataToDBParams( + ctx context.Context, + opts mapAccountCredentialsRegistrationDataToDBParamsOptions, +) (database.CreateAccountCredentialsParams, *exceptions.ServiceError) { + logger := s.buildLogger(opts.requestID, accountCredentialsRegistrationLocation, "mapAccountCredentialsRegistrationDataToDBParams").With( + "accountPublicID", opts.accountPublicID, + "accountID", opts.accountID, + "domain", opts.domain, + "data", opts.data, + "claims", opts.claims, + ) + logger.InfoContext(ctx, "Mapping account credentials registration data to database params") + + accessTokenSigningAlg, serviceErr := mapTokenCryptoSuiteWithDefault(opts.data.AccessTokenSigningAlg) + if serviceErr != nil { + logger.ErrorContext(ctx, "Failed to map access token signing alg", "serviceError", serviceErr) + return database.CreateAccountCredentialsParams{}, serviceErr + } + + responseTypes, serviceErr := mapResponseTypesWithDefault(opts.data.ResponseTypes) + if serviceErr != nil { + logger.ErrorContext(ctx, "Failed to map response types", "serviceError", serviceErr) + return database.CreateAccountCredentialsParams{}, serviceErr + } + + grantTypes, serviceErr := mapAccountCredentialsGrantTypes(opts.applicationType, opts.data.GrantTypes) + if serviceErr != nil { + logger.ErrorContext(ctx, "Failed to map grant types", "serviceError", serviceErr) + return database.CreateAccountCredentialsParams{}, serviceErr + } + + subjectType, serviceErr := mapEmptySubjectType(opts.data.SubjectType) + if serviceErr != nil { + logger.ErrorContext(ctx, "Failed to map subject type", "serviceError", serviceErr) + return database.CreateAccountCredentialsParams{}, serviceErr + } + + tokenEndpointAuthSigningAlg, serviceErr := mapEmptyTokenCryptoSuite(opts.data.TokenEndpointAuthSigningAlg) + if serviceErr != nil { + logger.ErrorContext(ctx, "Failed to map token endpoint auth signing alg", "serviceError", serviceErr) + return database.CreateAccountCredentialsParams{}, serviceErr + } + + idSignAlg, serviceErr := mapTokenCryptoSuiteWithDefault(opts.data.IDTokenSignedResponseAlg) + if serviceErr != nil { + logger.ErrorContext(ctx, "Failed to map ID token signed response alg", "serviceError", serviceErr) + return database.CreateAccountCredentialsParams{}, serviceErr + } + + idEncAlg, serviceErr := mapEmptyTokenEncryptionAlgorithm(opts.data.IDTokenEncryptedResponseAlg) + if serviceErr != nil { + logger.ErrorContext(ctx, "Failed to map ID token encrypted response alg", "serviceError", serviceErr) + return database.CreateAccountCredentialsParams{}, serviceErr + } + + idEncEnc, serviceErr := mapEmptyTokenEncryptionEncoding(idEncAlg, opts.data.IDTokenEncryptedResponseEnc) + if serviceErr != nil { + logger.ErrorContext(ctx, "Failed to map ID token encrypted response enc", "serviceError", serviceErr) + return database.CreateAccountCredentialsParams{}, serviceErr + } + + userInfoSignAlg, serviceErr := mapEmptyTokenCryptoSuite(opts.data.UserInfoSignedResponseAlg) + if serviceErr != nil { + logger.ErrorContext(ctx, "Failed to map user info signed response alg", "serviceError", serviceErr) + return database.CreateAccountCredentialsParams{}, serviceErr + } + + userInfoEncAlg, serviceErr := mapEmptyTokenEncryptionAlgorithm(opts.data.UserInfoEncryptedResponseAlg) + if serviceErr != nil { + logger.ErrorContext(ctx, "Failed to map user info encrypted response alg", "serviceError", serviceErr) + return database.CreateAccountCredentialsParams{}, serviceErr + } + + userInfoEncEnc, serviceErr := mapEmptyTokenEncryptionEncoding(userInfoEncAlg, opts.data.UserInfoEncryptedResponseEnc) + if serviceErr != nil { + logger.ErrorContext(ctx, "Failed to map user info encrypted response enc", "serviceError", serviceErr) + return database.CreateAccountCredentialsParams{}, serviceErr + } + + requestObjectSigningAlg, serviceErr := mapEmptyTokenCryptoSuite(opts.data.RequestObjectSigningAlg) + if serviceErr != nil { + logger.ErrorContext(ctx, "Failed to map request object signing alg", "serviceError", serviceErr) + return database.CreateAccountCredentialsParams{}, serviceErr + } + + requestObjectEncryptionAlg, serviceErr := mapEmptyTokenEncryptionAlgorithm(opts.data.RequestObjectEncryptionAlg) + if serviceErr != nil { + logger.ErrorContext(ctx, "Failed to map request object encryption alg", "serviceError", serviceErr) + return database.CreateAccountCredentialsParams{}, serviceErr + } + + requestObjectEncryptionEnc, serviceErr := mapEmptyTokenEncryptionEncoding(requestObjectEncryptionAlg, opts.data.RequestObjectEncryptionEnc) + if serviceErr != nil { + logger.ErrorContext(ctx, "Failed to map request object encryption enc", "serviceError", serviceErr) + return database.CreateAccountCredentialsParams{}, serviceErr + } + + jwks, serviceErr := mapEmptyJWKs(logger, ctx, opts.data.JWKs) + if serviceErr != nil { + logger.ErrorContext(ctx, "Failed to map JWKs", "serviceError", serviceErr) + return database.CreateAccountCredentialsParams{}, serviceErr + } + + params := database.CreateAccountCredentialsParams{ + AccountID: opts.accountID, + AccountPublicID: opts.accountPublicID, + Domain: opts.domain, + CreationMethod: database.CreationMethodDynamicRegistration, + Transport: opts.transport, + ClientID: utils.Base62UUID(), + RedirectUris: utils.MapSlice(opts.data.RedirectURIs, func(uri *string) string { + return utils.ProcessURL(*uri) + }), + TokenEndpointAuthMethod: opts.tokenEndpointAuthMethod, + TokenEndpointAuthSigningAlg: tokenEndpointAuthSigningAlg, + AccessTokenSigningAlg: accessTokenSigningAlg, + GrantTypes: grantTypes, + ResponseTypes: responseTypes, + ClientName: opts.data.ClientName, + ClientUri: utils.ProcessURL(opts.data.ClientURI), + LogoUri: mapEmptyURL(opts.data.LogoURI), + Scopes: opts.scopes, + Contacts: opts.data.Contacts, + TosUri: mapEmptyURL(opts.data.TOSURI), + PolicyUri: mapEmptyURL(opts.data.PolicyURI), + JwksUri: mapEmptyURL(opts.data.JWKsURI), + Jwks: jwks, + SoftwareID: mapEmptyString(opts.data.SoftwareID), + SoftwareVersion: mapEmptyString(opts.data.SoftwareVersion), + CredentialsType: opts.applicationType, + SectorIdentifierUri: mapEmptyURL(opts.data.SectorIdentifierURI), + SubjectType: subjectType, + IDTokenSignedResponseAlg: idSignAlg, + IDTokenEncryptedResponseAlg: idEncAlg, + IDTokenEncryptedResponseEnc: idEncEnc, + UserinfoSignedResponseAlg: userInfoSignAlg, + UserinfoEncryptedResponseAlg: userInfoEncAlg, + UserinfoEncryptedResponseEnc: userInfoEncEnc, + RequestObjectSigningAlg: requestObjectSigningAlg, + RequestObjectEncryptionAlg: requestObjectEncryptionAlg, + RequestObjectEncryptionEnc: requestObjectEncryptionEnc, + DefaultMaxAge: mapEmptyBigInt(opts.data.DefaultMaxAge), + RequireAuthTime: opts.data.RequireAuthTime, + DefaultAcrValues: opts.data.DefaultACRValues, + InitiateLoginUri: mapEmptyURL(opts.data.InitiateLoginURI), + RequestUris: utils.MapSlice(opts.data.RequestURIs, func(uri *string) string { + return utils.ProcessURL(*uri) + }), + } + + if opts.claims != nil { + if opts.claims.ClientName != "" { + params.ClientName = opts.claims.ClientName + } + if opts.claims.ClientURI != "" { + params.ClientUri = utils.ProcessURL(opts.claims.ClientURI) + } + if opts.claims.LogoURI != "" { + params.LogoUri = mapEmptyURL(opts.claims.LogoURI) + } + if len(opts.claims.RedirectURIs) > 0 { + params.RedirectUris = utils.MapSlice(opts.claims.RedirectURIs, func(uri *string) string { + return utils.ProcessURL(*uri) + }) + } + if opts.claims.TOSURI != "" { + params.TosUri = mapEmptyURL(opts.claims.TOSURI) + } + if opts.claims.PolicyURI != "" { + params.PolicyUri = mapEmptyURL(opts.claims.PolicyURI) + } + if opts.claims.JWKsURI != "" { + params.JwksUri = mapEmptyURL(opts.claims.JWKsURI) + } + if len(opts.claims.JWKs) > 0 { + jwks, serviceErr := mapEmptyJWKs(logger, ctx, opts.claims.JWKs) + if serviceErr != nil { + logger.ErrorContext(ctx, "Failed to map JWKs", "serviceError", serviceErr) + return database.CreateAccountCredentialsParams{}, serviceErr + } + + params.Jwks = jwks + } + if opts.claims.SoftwareID != "" { + params.SoftwareID = mapEmptyString(opts.claims.SoftwareID) + } + if opts.claims.SoftwareVersion != "" { + params.SoftwareVersion = mapEmptyString(opts.claims.SoftwareVersion) + } + if opts.claims.SectorIdentifierURI != "" { + params.SectorIdentifierUri = mapEmptyURL(opts.claims.SectorIdentifierURI) + } + if opts.claims.SubjectType != "" { + subjectType, _ := mapEmptySubjectType(opts.claims.SubjectType) + params.SubjectType = subjectType + } + if len(opts.claims.RequestURIs) > 0 { + params.RequestUris = utils.MapSlice(opts.claims.RequestURIs, func(uri *string) string { + return utils.ProcessURL(*uri) + }) + } + if opts.claims.IDTokenSignedResponseAlg != "" { + idSignAlg, _ := mapTokenCryptoSuiteWithDefault(opts.claims.IDTokenSignedResponseAlg) + params.IDTokenSignedResponseAlg = idSignAlg + } + if opts.claims.IDTokenEncryptedResponseAlg != "" { + idEncAlg, _ := mapEmptyTokenEncryptionAlgorithm(opts.claims.IDTokenEncryptedResponseAlg) + params.IDTokenEncryptedResponseAlg = idEncAlg + } + if opts.claims.IDTokenEncryptedResponseEnc != "" { + idEncEnc, _ := mapEmptyTokenEncryptionEncoding(params.IDTokenEncryptedResponseAlg, opts.claims.IDTokenEncryptedResponseEnc) + params.IDTokenEncryptedResponseEnc = idEncEnc + } + if opts.claims.UserInfoSignedResponseAlg != "" { + userInfoSignAlg, _ := mapEmptyTokenCryptoSuite(opts.claims.UserInfoSignedResponseAlg) + params.UserinfoSignedResponseAlg = userInfoSignAlg + } + if opts.claims.UserInfoEncryptedResponseAlg != "" { + userInfoEncAlg, _ := mapEmptyTokenEncryptionAlgorithm(opts.claims.UserInfoEncryptedResponseAlg) + params.UserinfoEncryptedResponseAlg = userInfoEncAlg + } + if opts.claims.UserInfoEncryptedResponseEnc != "" { + userInfoEncEnc, _ := mapEmptyTokenEncryptionEncoding(params.UserinfoEncryptedResponseAlg, opts.claims.UserInfoEncryptedResponseEnc) + params.UserinfoEncryptedResponseEnc = userInfoEncEnc + } + if opts.claims.RequestObjectSigningAlg != "" { + requestObjectSigningAlg, _ := mapEmptyTokenCryptoSuite(opts.claims.RequestObjectSigningAlg) + params.RequestObjectSigningAlg = requestObjectSigningAlg + } + if opts.claims.RequestObjectEncryptionAlg != "" { + requestObjectEncryptionAlg, _ := mapEmptyTokenEncryptionAlgorithm(opts.claims.RequestObjectEncryptionAlg) + params.RequestObjectEncryptionAlg = requestObjectEncryptionAlg + } + if opts.claims.RequestObjectEncryptionEnc != "" { + requestObjectEncryptionEnc, _ := mapEmptyTokenEncryptionEncoding(params.RequestObjectEncryptionAlg, opts.claims.RequestObjectEncryptionEnc) + params.RequestObjectEncryptionEnc = requestObjectEncryptionEnc + } + if opts.claims.TokenEndpointAuthSigningAlg != "" { + tokenEndpointAuthSigningAlg, _ := mapEmptyTokenCryptoSuite(opts.claims.TokenEndpointAuthSigningAlg) + params.TokenEndpointAuthSigningAlg = tokenEndpointAuthSigningAlg + } + if opts.claims.AccessTokenSigningAlg != "" { + accessTokenSigningAlg, _ := mapTokenCryptoSuiteWithDefault(opts.claims.AccessTokenSigningAlg) + params.AccessTokenSigningAlg = accessTokenSigningAlg + } + if opts.claims.RequireAuthTime { + params.RequireAuthTime = opts.claims.RequireAuthTime + } + if opts.claims.DefaultMaxAge > 0 { + params.DefaultMaxAge = mapEmptyBigInt(opts.claims.DefaultMaxAge) + } + if opts.claims.DefaultACRValues != nil { + params.DefaultAcrValues = opts.claims.DefaultACRValues + } + if opts.claims.InitiateLoginURI != "" { + params.InitiateLoginUri = mapEmptyURL(opts.claims.InitiateLoginURI) + } + if len(opts.claims.GrantTypes) > 0 { + params.GrantTypes = utils.MapSlice(opts.claims.GrantTypes, func(grantType *string) database.GrantType { + return database.GrantType(*grantType) + }) + } + if len(opts.claims.ResponseTypes) > 0 { + params.ResponseTypes = utils.MapSlice(opts.claims.ResponseTypes, func(responseType *string) database.ResponseType { + return database.ResponseType(*responseType) + }) + } + if opts.claims.Scope != "" { + params.Scopes = utils.MapSlice(strings.Fields(opts.claims.Scope), func(scope *string) database.AccountCredentialsScope { + return database.AccountCredentialsScope(*scope) + }) + } + if len(opts.claims.Contacts) > 0 { + params.Contacts = opts.claims.Contacts + } + } + + return params, nil +} + +type CreateAccountCredentialsRegistrationOptions struct { + RequestID string + AccountPublicID uuid.UUID + AccountVersion int32 + IsAuthenticated bool + ApplicationType string + RedirectURIs []string + TokenEndpointAuthMethod string + GrantTypes []string + ResponseTypes []string + ClientName string + ClientURI string + LogoURI string + TOSURI string + PolicyURI string + Contacts []string + SoftwareID string + SoftwareVersion string + SoftwareStatement string + JWKsURI string + JWKs []string + FrontendDomain string + BackendDomain string + RequireAuthTime bool + DefaultMaxAge int64 + SubjectType string + IDTokenSignedResponseAlg string + IDTokenEncryptedResponseAlg string + IDTokenEncryptedResponseEnc string + RequestObjectSigningAlg string + RequestObjectEncryptionAlg string + RequestObjectEncryptionEnc string + DefaultACRValues []string + Scope string + SectorIdentifierURI string + InitiateLoginURI string + RequestURIs []string + UserInfoSignedResponseAlg string + UserInfoEncryptedResponseAlg string + UserInfoEncryptedResponseEnc string + TokenEndpointAuthSigningAlg string + AccessTokenSigningAlg string +} + +func (s *Services) CreateAccountCredentialsRegistration( + ctx context.Context, + opts CreateAccountCredentialsRegistrationOptions, +) (dtos.AccountCredentialsDTO, *exceptions.ServiceError) { + logger := s.buildLogger( + opts.RequestID, + accountCredentialsRegistrationLocation, + "CreateAccountCredentialsRegistration", + ).With( + "accountPublicID", opts.AccountPublicID, + ) + logger.InfoContext(ctx, "Creating account credentials registration...") + + applicationType, serviceErr := mapAccountCredentialsType(opts.ApplicationType) + if serviceErr != nil { + logger.ErrorContext(ctx, "Failed to map application type", "serviceError", serviceErr) + return dtos.AccountCredentialsDTO{}, serviceErr + } + + scopes, serviceErr := mapAccountCredentialsScopes(strings.Fields(opts.Scope)) + if serviceErr != nil { + logger.ErrorContext(ctx, "Failed to map scopes", "serviceError", serviceErr) + return dtos.AccountCredentialsDTO{}, serviceErr + } + + transport := mapAccountCredentialsDRTransport(applicationType) + tokenEndpointAuthMethod, serviceErr := mapAccountCredentialsTokenEndpointAuthMethod( + opts.TokenEndpointAuthMethod, + applicationType, + transport, + ) + if serviceErr != nil { + logger.ErrorContext(ctx, "Failed to map token endpoint auth method", "serviceError", serviceErr) + return dtos.AccountCredentialsDTO{}, serviceErr + } + + if !validateEncryptionAlgorithmPair(opts.IDTokenEncryptedResponseAlg, opts.IDTokenEncryptedResponseEnc) { + logger.WarnContext(ctx, "id_token encryption algorithm and encoding must both be set or both be unset") + return dtos.AccountCredentialsDTO{}, exceptions.NewValidationError("id_token encryption algorithm and encoding mismatch") + } + if !validateEncryptionAlgorithmPair(opts.UserInfoEncryptedResponseAlg, opts.UserInfoEncryptedResponseEnc) { + logger.WarnContext(ctx, "userinfo encryption algorithm and encoding must both be set or both be unset") + return dtos.AccountCredentialsDTO{}, exceptions.NewValidationError("userinfo encryption algorithm and encoding mismatch") + } + if !validateEncryptionAlgorithmPair(opts.RequestObjectEncryptionAlg, opts.RequestObjectEncryptionEnc) { + logger.WarnContext(ctx, "request_object encryption algorithm and encoding must both be set or both be unset") + return dtos.AccountCredentialsDTO{}, exceptions.NewValidationError("request_object encryption algorithm and encoding mismatch") + } + + accountDRConfigDTO, serviceErr := s.GetAndCacheAccountDynamicRegistrationConfig(ctx, GetAndCacheAccountDynamicRegistrationConfigOptions{ + RequestID: opts.RequestID, + AccountPublicID: opts.AccountPublicID, + }) + if serviceErr != nil { + if serviceErr.Code == exceptions.CodeNotFound { + logger.InfoContext(ctx, "Account dynamic registration config not found") + return dtos.AccountCredentialsDTO{}, exceptions.NewForbiddenError() + } + + logger.ErrorContext(ctx, "Failed to get account dynamic registration config", "serviceError", serviceErr) + return dtos.AccountCredentialsDTO{}, serviceErr + } + + accountID, serviceErr := s.GetAccountIDByPublicIDAndVersion(ctx, GetAccountIDByPublicIDAndVersionOptions{ + RequestID: opts.RequestID, + PublicID: opts.AccountPublicID, + Version: opts.AccountVersion, + }) + if serviceErr != nil { + logger.ErrorContext(ctx, "Failed to get account ID", "serviceError", serviceErr) + return dtos.AccountCredentialsDTO{}, serviceErr + } + + if slices.Contains(accountDRConfigDTO.RequireInitialAccessTokenCredentialTypes, applicationType) && + !opts.IsAuthenticated { + logger.WarnContext(ctx, "Account dynamic registration configuration needs to contain initial access token") + return dtos.AccountCredentialsDTO{}, exceptions.NewUnauthorizedError() + } + + if slices.Contains(accountDRConfigDTO.RequireSoftwareStatementCredentialTypes, applicationType) && + opts.SoftwareStatement == "" { + logger.WarnContext(ctx, "Account dynamic registration configuration needs to contain software statement") + return dtos.AccountCredentialsDTO{}, exceptions.NewUnauthorizedError() + } + + _, serviceErr = s.GetAccountByPublicIDAndVersion(ctx, GetAccountByPublicIDAndVersionOptions{ + RequestID: opts.RequestID, + PublicID: opts.AccountPublicID, + Version: opts.AccountVersion, + }) + if serviceErr != nil { + logger.WarnContext(ctx, "Failed to get account", "serviceError", serviceErr) + return dtos.AccountCredentialsDTO{}, serviceErr + } + + parsedClientURI, err := url.Parse(opts.ClientURI) + if err != nil { + logger.WarnContext(ctx, "Failed to parse client URI", "error", err) + return dtos.AccountCredentialsDTO{}, exceptions.NewValidationError("invalid client URI") + } + + domain := parsedClientURI.Hostname() + baseDomain, serviceErr := s.checkAccountCRDomain(ctx, checkAccountCRDomainOptions{ + requestID: opts.RequestID, + accountPublicID: opts.AccountPublicID, + domain: domain, + requireVerifiedDomains: slices.Contains(accountDRConfigDTO.RequireVerifiedDomainsCredentialsType, applicationType), + }) + if serviceErr != nil { + logger.WarnContext(ctx, "Failed to check domain validity", "serviceError", serviceErr) + return dtos.AccountCredentialsDTO{}, serviceErr + } + + data := ApplicationRegistrationData{ + RedirectURIs: opts.RedirectURIs, + TokenEndpointAuthMethod: opts.TokenEndpointAuthMethod, + ResponseTypes: opts.ResponseTypes, + GrantTypes: opts.GrantTypes, + ApplicationType: opts.ApplicationType, + ClientName: opts.ClientName, + ClientURI: opts.ClientURI, + LogoURI: opts.LogoURI, + Scope: opts.Scope, + Contacts: opts.Contacts, + TOSURI: opts.TOSURI, + PolicyURI: opts.PolicyURI, + JWKsURI: opts.JWKsURI, + JWKs: opts.JWKs, + SoftwareID: opts.SoftwareID, + SoftwareVersion: opts.SoftwareVersion, + SubjectType: opts.SubjectType, + SectorIdentifierURI: opts.SectorIdentifierURI, + DefaultMaxAge: opts.DefaultMaxAge, + RequireAuthTime: opts.RequireAuthTime, + DefaultACRValues: opts.DefaultACRValues, + InitiateLoginURI: opts.InitiateLoginURI, + RequestURIs: opts.RequestURIs, + IDTokenSignedResponseAlg: opts.IDTokenSignedResponseAlg, + IDTokenEncryptedResponseAlg: opts.IDTokenEncryptedResponseAlg, + IDTokenEncryptedResponseEnc: opts.IDTokenEncryptedResponseEnc, + UserInfoSignedResponseAlg: opts.UserInfoSignedResponseAlg, + UserInfoEncryptedResponseAlg: opts.UserInfoEncryptedResponseAlg, + UserInfoEncryptedResponseEnc: opts.UserInfoEncryptedResponseEnc, + RequestObjectSigningAlg: opts.RequestObjectSigningAlg, + RequestObjectEncryptionAlg: opts.RequestObjectEncryptionAlg, + RequestObjectEncryptionEnc: opts.RequestObjectEncryptionEnc, + TokenEndpointAuthSigningAlg: opts.TokenEndpointAuthSigningAlg, + AccessTokenSigningAlg: opts.AccessTokenSigningAlg, + } + var ssClaimsReference *tokens.SoftwareStatementClaims + if opts.SoftwareStatement != "" { + ssClaims, stdClaims, err := s.jwt.VerifySoftwareStatement(ctx, tokens.VerifySoftwareStatementOptions{ + RequestID: opts.RequestID, + SoftwareStatement: opts.SoftwareStatement, + GetPublicJWK: s.buildAccountCRSoftwareStatementFunc(ctx, buildAccountCRSoftwareStatementFuncOptions{ + requestID: opts.RequestID, + accountPublicID: opts.AccountPublicID, + verificationMethods: accountDRConfigDTO.SoftwareStatementVerificationMethods, + jwksURI: opts.JWKsURI, + jwks: opts.JWKs, + domain: domain, + baseDomain: baseDomain, + }), + }) + if err != nil { + logger.WarnContext(ctx, "Failed to verify software statement", "error", err) + return dtos.AccountCredentialsDTO{}, exceptions.NewUnauthorizedError() + } + if serviceErr := s.verifySoftwareStatementSTDClaims(ctx, verifySoftwareStatementSTDClaimsOptions{ + requestID: opts.RequestID, + backendDomain: opts.BackendDomain, + frontendDomain: opts.FrontendDomain, + domain: domain, + baseDomain: baseDomain, + claims: &stdClaims, + }); serviceErr != nil { + logger.WarnContext(ctx, "Failed to verify software statement standard claims", "serviceError", serviceErr) + return dtos.AccountCredentialsDTO{}, serviceErr + } + + if serviceErr := s.validateSoftwareStatementClaims(ctx, validateSoftwareStatementClaimsOptions{ + requestID: opts.RequestID, + claims: &ssClaims, + data: &data, + allowedScopes: utils.SliceToHashSet(allowedAccountCredentialsScopes), + }); serviceErr != nil { + logger.WarnContext(ctx, "Failed to validate software statement claims", "serviceError", serviceErr) + return dtos.AccountCredentialsDTO{}, serviceErr + } + + ssClaimsReference = &ssClaims + } + + params, serviceErr := s.mapAccountCredentialsRegistrationDataToDBParams(ctx, mapAccountCredentialsRegistrationDataToDBParamsOptions{ + applicationType: applicationType, + accountPublicID: opts.AccountPublicID, + accountID: opts.AccountVersion, + domain: domain, + requestID: opts.RequestID, + tokenEndpointAuthMethod: tokenEndpointAuthMethod, + transport: transport, + scopes: scopes, + data: &data, + claims: ssClaimsReference, + }) + if serviceErr != nil { + logger.ErrorContext(ctx, "Failed to map account credentials registration data to database params", "serviceError", serviceErr) + return dtos.AccountCredentialsDTO{}, serviceErr + } + + if tokenEndpointAuthMethod == database.AuthMethodNone { + accountCredentials, err := s.database.CreateAccountCredentials(ctx, params) + if err != nil { + logger.ErrorContext(ctx, "Failed to create account credentials", "error", err) + return dtos.AccountCredentialsDTO{}, exceptions.FromDBError(err) + } + + logger.InfoContext(ctx, "Created account credentials successfully") + return dtos.MapAccountCredentialsToDTO(&accountCredentials) + } + + qrs, txn, err := s.database.BeginTx(ctx) + if err != nil { + logger.ErrorContext(ctx, "Failed to start transaction", "error", err) + return dtos.AccountCredentialsDTO{}, exceptions.FromDBError(err) + } + defer func() { + logger.DebugContext(ctx, "Finalizing transaction") + s.database.FinalizeTx(ctx, txn, err, serviceErr) + }() + + accountCredentials, err := s.database.CreateAccountCredentials(ctx, params) + if err != nil { + logger.ErrorContext(ctx, "Failed to create account credentials", "error", err) + return dtos.AccountCredentialsDTO{}, exceptions.FromDBError(err) + } + + switch tokenEndpointAuthMethod { + case database.AuthMethodPrivateKeyJwt: + var dbPrms database.CreateCredentialsKeyParams + var jwk utils.JWK + dbPrms, jwk, serviceErr = s.clientCredentialsKey(ctx, clientCredentialsKeyOptions{ + requestID: opts.RequestID, + accountID: accountID, + accountPublicID: opts.AccountPublicID, + expiresIn: s.accountCCExpDays, + usage: database.CredentialsUsageAccount, + cryptoSuite: utils.SupportedCryptoSuite(params.AccessTokenSigningAlg), + }) + if serviceErr != nil { + logger.ErrorContext(ctx, "Failed to generate client credentials key", "serviceError", serviceErr) + return dtos.AccountCredentialsDTO{}, serviceErr + } + + var clientKey database.CredentialsKey + clientKey, err = qrs.CreateCredentialsKey(ctx, dbPrms) + if err != nil { + logger.ErrorContext(ctx, "Failed to create client key", "error", err) + serviceErr = exceptions.FromDBError(err) + return dtos.AccountCredentialsDTO{}, serviceErr + } + + if err = qrs.CreateAccountCredentialKey(ctx, database.CreateAccountCredentialKeyParams{ + AccountID: accountID, + AccountCredentialsID: accountCredentials.ID, + CredentialsKeyID: clientKey.ID, + AccountPublicID: opts.AccountPublicID, + JwkKid: clientKey.PublicKid, + }); err != nil { + logger.ErrorContext(ctx, "Failed to create account credential key", "error", err) + serviceErr = exceptions.FromDBError(err) + return dtos.AccountCredentialsDTO{}, serviceErr + } + + return dtos.MapAccountCredentialsToDTOWithJWK(&accountCredentials, jwk, dbPrms.ExpiresAt) + case database.AuthMethodClientSecretBasic, database.AuthMethodClientSecretPost, database.AuthMethodClientSecretJwt: + var ccID int32 + var secretID, secret string + var exp time.Time + ccID, secretID, secret, exp, serviceErr = s.clientCredentialsSecret(ctx, qrs, clientCredentialsSecretOptions{ + requestID: opts.RequestID, + accountID: accountID, + storageMode: mapCCSecretStorageMode(string(tokenEndpointAuthMethod)), + expiresIn: s.appCCExpDays, + usage: database.CredentialsUsageAccount, + dekFN: s.BuildGetEncAccountDEKfn(ctx, BuildGetEncAccountDEKOptions{ + RequestID: opts.RequestID, + AccountID: accountID, + }), + }) + if serviceErr != nil { + logger.ErrorContext(ctx, "Failed to create client credentials secret", "serviceError", serviceErr) + return dtos.AccountCredentialsDTO{}, serviceErr + } + + if err = qrs.CreateAccountCredentialSecret(ctx, database.CreateAccountCredentialSecretParams{ + AccountID: accountID, + AccountPublicID: opts.AccountPublicID, + AccountCredentialsID: accountCredentials.ID, + CredentialsSecretID: ccID, + SecretID: secretID, + }); err != nil { + logger.ErrorContext(ctx, "Failed to create account credential secret", "error", err) + serviceErr = exceptions.FromDBError(err) + return dtos.AccountCredentialsDTO{}, serviceErr + } + + return dtos.MapAccountCredentialsToDTOWithSecret(&accountCredentials, secretID, secret, exp) + default: + logger.ErrorContext(ctx, "Invalid token endpoint auth method", "tokenEndpointAuthMethod", tokenEndpointAuthMethod) + serviceErr = exceptions.NewInternalServerError() + return dtos.AccountCredentialsDTO{}, serviceErr + } +} diff --git a/idp/internal/services/account_dynamic_registration_configs.go b/idp/internal/services/account_dynamic_registration_configs.go index 3724bd7..8317817 100644 --- a/idp/internal/services/account_dynamic_registration_configs.go +++ b/idp/internal/services/account_dynamic_registration_configs.go @@ -8,13 +8,15 @@ package services import ( "context" + "fmt" + "time" "github.com/google/uuid" "github.com/tugascript/devlogs/idp/internal/exceptions" + "github.com/tugascript/devlogs/idp/internal/providers/cache" "github.com/tugascript/devlogs/idp/internal/providers/database" "github.com/tugascript/devlogs/idp/internal/services/dtos" - "github.com/tugascript/devlogs/idp/internal/utils" ) const ( @@ -25,8 +27,14 @@ const ( initialAccessTokenGenerationMethodAuthorizationCode string = "authorization_code" initialAccessTokenGenerationMethodManual string = "manual" + + accountDynamicRegistrationConfigCacheTTL time.Duration = 24 * time.Hour ) +func buildAccountDynamicRegistrationConfigCacheKey(accountPublicID uuid.UUID) string { + return fmt.Sprintf("%s:%s", accountDynamicRegistrationConfigsLocation, accountPublicID.String()) +} + func mapAccountCredentialsTypes(credentialsTypes []string) ([]database.AccountCredentialsType, *exceptions.ServiceError) { accountCredentialsTypes := make([]database.AccountCredentialsType, 0, len(credentialsTypes)) for _, credentialsType := range credentialsTypes { @@ -98,7 +106,6 @@ type SaveAccountDynamicRegistrationConfigOptions struct { AccountPublicID uuid.UUID AccountVersion int32 AccountCredentialsTypes []string - WhitelistedDomains []string RequireSoftwareStatementCredentialTypes []string SoftwareStatementVerificationMethods []string RequireInitialAccessTokenCredentialTypes []string @@ -170,7 +177,6 @@ func (s *Services) SaveAccountDynamicRegistrationConfig( AccountID: accountID, AccountPublicID: opts.AccountPublicID, AccountCredentialsTypes: credentialsTypes, - WhitelistedDomains: utils.ToEmptySlice(opts.WhitelistedDomains), RequireSoftwareStatementCredentialTypes: requireSoftwareStatementCredentialTypes, SoftwareStatementVerificationMethods: softwareStatementVerificationMethods, RequireInitialAccessTokenCredentialTypes: requireInitialAccessTokenCredentialTypes, @@ -183,13 +189,11 @@ func (s *Services) SaveAccountDynamicRegistrationConfig( } return dtos.MapAccountDynamicRegistrationConfigToDTO(&accountDynamicRegistrationConfig), true, nil - } accountDynamicRegistrationConfig, err = s.database.UpdateAccountDynamicRegistrationConfig(ctx, database.UpdateAccountDynamicRegistrationConfigParams{ ID: accountDynamicRegistrationConfig.ID, AccountCredentialsTypes: credentialsTypes, - WhitelistedDomains: utils.ToEmptySlice(opts.WhitelistedDomains), RequireSoftwareStatementCredentialTypes: requireSoftwareStatementCredentialTypes, SoftwareStatementVerificationMethods: softwareStatementVerificationMethods, RequireInitialAccessTokenCredentialTypes: requireInitialAccessTokenCredentialTypes, @@ -200,6 +204,14 @@ func (s *Services) SaveAccountDynamicRegistrationConfig( return dtos.AccountDynamicRegistrationConfigDTO{}, false, exceptions.FromDBError(err) } + if err := s.cache.DeleteResponse(ctx, cache.DeleteResponseOptions{ + RequestID: opts.RequestID, + Key: buildAccountDynamicRegistrationConfigCacheKey(opts.AccountPublicID), + }); err != nil { + logger.ErrorContext(ctx, "Failed to delete cached account dynamic registration config", "error", err) + return dtos.AccountDynamicRegistrationConfigDTO{}, false, exceptions.NewInternalServerError() + } + return dtos.MapAccountDynamicRegistrationConfigToDTO(&accountDynamicRegistrationConfig), false, nil } @@ -232,6 +244,52 @@ func (s *Services) GetAccountDynamicRegistrationConfig( return dtos.MapAccountDynamicRegistrationConfigToDTO(&accountDynamicRegistrationConfig), nil } +type GetAndCacheAccountDynamicRegistrationConfigOptions struct { + RequestID string + AccountPublicID uuid.UUID +} + +func (s *Services) GetAndCacheAccountDynamicRegistrationConfig( + ctx context.Context, + opts GetAndCacheAccountDynamicRegistrationConfigOptions, +) (dtos.AccountDynamicRegistrationConfigDTO, *exceptions.ServiceError) { + logger := s.buildLogger(opts.RequestID, accountDynamicRegistrationConfigsLocation, "GetAndCacheAccountDynamicRegistrationConfig").With( + "accountPublicID", opts.AccountPublicID, + ) + logger.InfoContext(ctx, "Getting and caching account dynamic registration config...") + + accountDRConfigDTO, found, err := cache.GetResponseWithoutETag(s.cache, ctx, cache.GetResponseOptions[dtos.AccountDynamicRegistrationConfigDTO]{ + RequestID: opts.RequestID, + Key: buildAccountDynamicRegistrationConfigCacheKey(opts.AccountPublicID), + }) + if err != nil { + logger.ErrorContext(ctx, "Failed to get cached account dynamic registration config", "error", err) + return dtos.AccountDynamicRegistrationConfigDTO{}, exceptions.NewInternalServerError() + } + if found { + logger.InfoContext(ctx, "Account dynamic registration config found in cache") + return accountDRConfigDTO, nil + } + + accountDRConfigDTO, serviceErr := s.GetAccountDynamicRegistrationConfig(ctx, GetAccountDynamicRegistrationConfigOptions(opts)) + if serviceErr != nil { + logger.ErrorContext(ctx, "Failed to get account dynamic registration config", "serviceError", serviceErr) + return dtos.AccountDynamicRegistrationConfigDTO{}, serviceErr + } + + if err := cache.SaveResponseWithoutETag(s.cache, ctx, cache.SaveResponseOptions[dtos.AccountDynamicRegistrationConfigDTO]{ + RequestID: opts.RequestID, + Key: buildAccountDynamicRegistrationConfigCacheKey(opts.AccountPublicID), + TTL: accountDynamicRegistrationConfigCacheTTL, + Value: accountDRConfigDTO, + }); err != nil { + logger.ErrorContext(ctx, "Failed to save account dynamic registration config to cache", "error", err) + return dtos.AccountDynamicRegistrationConfigDTO{}, exceptions.NewInternalServerError() + } + + return accountDRConfigDTO, nil +} + type DeleteAccountDynamicRegistrationConfigOptions struct { RequestID string AccountPublicID uuid.UUID diff --git a/idp/internal/services/apps.go b/idp/internal/services/apps.go index 5841c37..7b24315 100644 --- a/idp/internal/services/apps.go +++ b/idp/internal/services/apps.go @@ -272,7 +272,7 @@ func (s *Services) FilterAccountAppsByName( apps, err = s.database.FilterAppsByNameAndByAccountPublicIDOrderedByID(ctx, database.FilterAppsByNameAndByAccountPublicIDOrderedByIDParams{ AccountPublicID: opts.AccountPublicID, - Name: name, + ClientName: name, Offset: opts.Offset, Limit: opts.Limit, }, @@ -281,7 +281,7 @@ func (s *Services) FilterAccountAppsByName( apps, err = s.database.FilterAppsByNameAndByAccountPublicIDOrderedByName(ctx, database.FilterAppsByNameAndByAccountPublicIDOrderedByNameParams{ AccountPublicID: opts.AccountPublicID, - Name: name, + ClientName: name, Offset: opts.Offset, Limit: opts.Limit, }, @@ -298,7 +298,7 @@ func (s *Services) FilterAccountAppsByName( count, err := s.database.CountFilteredAppsByNameAndByAccountPublicID(ctx, database.CountFilteredAppsByNameAndByAccountPublicIDParams{ AccountPublicID: opts.AccountPublicID, - Name: name, + ClientName: name, }, ) if err != nil { @@ -444,7 +444,7 @@ func (s *Services) FilterAccountAppsByNameAndType( apps, err = s.database.FilterAppsByNameAndTypeAndByAccountPublicIDOrderedByID(ctx, database.FilterAppsByNameAndTypeAndByAccountPublicIDOrderedByIDParams{ AccountPublicID: opts.AccountPublicID, - Name: name, + ClientName: name, Offset: opts.Offset, Limit: opts.Limit, AppType: appType, @@ -454,7 +454,7 @@ func (s *Services) FilterAccountAppsByNameAndType( apps, err = s.database.FilterAppsByNameAndTypeAndByAccountPublicIDOrderedByName(ctx, database.FilterAppsByNameAndTypeAndByAccountPublicIDOrderedByNameParams{ AccountPublicID: opts.AccountPublicID, - Name: name, + ClientName: name, Offset: opts.Offset, Limit: opts.Limit, AppType: appType, @@ -472,7 +472,7 @@ func (s *Services) FilterAccountAppsByNameAndType( count, err := s.database.CountFilteredAppsByNameAndTypeAndByAccountPublicID(ctx, database.CountFilteredAppsByNameAndTypeAndByAccountPublicIDParams{ AccountPublicID: opts.AccountPublicID, - Name: name, + ClientName: name, AppType: appType, }, ) @@ -605,8 +605,8 @@ func (s *Services) checkForDuplicateApps( logger.InfoContext(ctx, "Checking for duplicate apps...") count, err := s.database.CountAppsByAccountIDAndName(ctx, database.CountAppsByAccountIDAndNameParams{ - AccountID: opts.accountID, - Name: opts.name, + AccountID: opts.accountID, + ClientName: opts.name, }) if err != nil { logger.ErrorContext(ctx, "Failed to count apps by name", "error", err) @@ -693,7 +693,7 @@ func (s *Services) createApp( AccountPublicID: opts.accountPublicID, CreationMethod: opts.creationMethod, AppType: opts.appType, - Name: opts.name, + ClientName: opts.name, ClientID: clientID, ClientUri: utils.ProcessURL(opts.clientURI), AllowUserRegistration: opts.allowUserRegistration, @@ -703,7 +703,7 @@ func (s *Services) createApp( LogoUri: mapEmptyURL(opts.logoURI), TosUri: mapEmptyURL(opts.tosURI), PolicyUri: mapEmptyURL(opts.policyURI), - SoftwareID: opts.softwareID, + SoftwareID: mapEmptyString(opts.softwareID), SoftwareVersion: mapEmptyString(opts.softwareVersion), Scopes: stdScopes, DefaultScopes: defaultStdScopes, @@ -773,7 +773,7 @@ func (s *Services) createSingleApp( AccountPublicID: opts.accountPublicID, CreationMethod: opts.creationMethod, AppType: opts.appType, - Name: opts.name, + ClientName: opts.name, ClientID: clientID, ClientUri: utils.ProcessURL(opts.clientURI), AllowUserRegistration: opts.allowUserRegistration, @@ -783,7 +783,7 @@ func (s *Services) createSingleApp( LogoUri: mapEmptyURL(opts.logoURI), TosUri: mapEmptyURL(opts.tosURI), PolicyUri: mapEmptyURL(opts.policyURI), - SoftwareID: opts.softwareID, + SoftwareID: mapEmptyString(opts.softwareID), SoftwareVersion: mapEmptyString(opts.softwareVersion), Scopes: stdScopes, DefaultScopes: defaultStdScopes, @@ -860,7 +860,7 @@ func (s *Services) CreateWebApp( return dtos.AppDTO{}, serviceErr } - responseTypes, serviceErr := mapResponseTypes(opts.ResponseTypes) + responseTypes, serviceErr := mapResponseTypesWithDefault(opts.ResponseTypes) if serviceErr != nil { logger.ErrorContext(ctx, "Failed to map response types", "serviceError", serviceErr) return dtos.AppDTO{}, serviceErr @@ -1040,7 +1040,7 @@ func (s *Services) CreateSPANativeApp( ) logger.InfoContext(ctx, "Creating SPA or Native app...") - responseTypes, serviceErr := mapResponseTypes(opts.ResponseTypes) + responseTypes, serviceErr := mapResponseTypesWithDefault(opts.ResponseTypes) if serviceErr != nil { logger.ErrorContext(ctx, "Failed to map response types", "serviceError", serviceErr) return dtos.AppDTO{}, serviceErr @@ -2026,7 +2026,7 @@ func (s *Services) updateApp( ) (database.App, error) { logger := s.buildLogger(opts.requestID, appsLocation, "updateApp").With( "appID", appDTO.ID(), - "appName", appDTO.Name, + "appClientName", appDTO.ClientName, ) logger.InfoContext(ctx, "Updating base app...") @@ -2058,7 +2058,7 @@ func (s *Services) updateApp( app, err := qrs.UpdateApp(ctx, database.UpdateAppParams{ ID: appDTO.ID(), - Name: opts.name, + ClientName: opts.name, UsernameColumn: usernameColumn, ClientUri: opts.clientURI, LogoUri: mapEmptyURL(opts.logoURI), @@ -2093,7 +2093,7 @@ func (s *Services) updateSingleApp( ) (database.App, *exceptions.ServiceError) { logger := s.buildLogger(opts.requestID, appsLocation, "updateApp").With( "appID", appDTO.ID(), - "appName", appDTO.Name, + "appClientName", appDTO.ClientName, ) logger.InfoContext(ctx, "Updating base app...") @@ -2125,7 +2125,7 @@ func (s *Services) updateSingleApp( app, err := s.database.UpdateApp(ctx, database.UpdateAppParams{ ID: appDTO.ID(), - Name: opts.name, + ClientName: opts.name, UsernameColumn: usernameColumn, ClientUri: opts.clientURI, LogoUri: mapEmptyURL(opts.logoURI), @@ -2177,8 +2177,6 @@ func mapResponseTypesUpdate( switch utils.Lowered(rt) { case ResponseTypeCode: dbResponseTypes = append(dbResponseTypes, database.ResponseTypeCode) - case ResponseTypeIdToken: - dbResponseTypes = append(dbResponseTypes, database.ResponseTypeIDToken) case ResponseTypeCodeIdToken: dbResponseTypes = append(dbResponseTypes, database.ResponseTypeCodeidToken) default: @@ -2216,7 +2214,7 @@ func (s *Services) UpdateWebSPANativeApp( ) (dtos.AppDTO, *exceptions.ServiceError) { logger := s.buildLogger(opts.RequestID, appsLocation, "UpdateWebSPANativeApp").With( "appID", appDTO.ID(), - "appName", appDTO.Name, + "appClientName", appDTO.ClientName, "appType", appDTO.AppType, ) logger.InfoContext(ctx, "Updating web or SPA or native app...") @@ -2228,7 +2226,7 @@ func (s *Services) UpdateWebSPANativeApp( } name := strings.TrimSpace(opts.Name) - if appDTO.Name != name { + if appDTO.ClientName != name { if serviceErr := s.checkForDuplicateApps(ctx, checkForDuplicateAppsOptions{ requestID: opts.RequestID, accountID: opts.AccountID, @@ -2296,12 +2294,12 @@ func (s *Services) UpdateBackendApp( ) (dtos.AppDTO, *exceptions.ServiceError) { logger := s.buildLogger(opts.RequestID, appsLocation, "UpdateBackendApp").With( "appID", appDTO.ID(), - "appName", appDTO.Name, + "appClientName", appDTO.ClientName, ) logger.InfoContext(ctx, "Updating backend app...") name := strings.TrimSpace(opts.Name) - if appDTO.Name != name { + if appDTO.ClientName != name { if serviceErr := s.checkForDuplicateApps(ctx, checkForDuplicateAppsOptions{ requestID: opts.RequestID, accountID: opts.AccountID, @@ -2378,12 +2376,12 @@ func (s *Services) UpdateDeviceApp( ) (dtos.AppDTO, *exceptions.ServiceError) { logger := s.buildLogger(opts.RequestID, appsLocation, "UpdateDeviceApp").With( "appID", appDTO.ID(), - "appName", appDTO.Name, + "appClientName", appDTO.ClientName, ) logger.InfoContext(ctx, "Updating device app...") name := strings.TrimSpace(opts.Name) - if appDTO.Name != name { + if appDTO.ClientName != name { if serviceErr := s.checkForDuplicateApps(ctx, checkForDuplicateAppsOptions{ requestID: opts.RequestID, accountID: opts.AccountID, @@ -2546,12 +2544,12 @@ func (s *Services) UpdateServiceApp( ) (dtos.AppDTO, *exceptions.ServiceError) { logger := s.buildLogger(opts.RequestID, appsLocation, "UpdateServiceApp").With( "appID", appDTO.ID(), - "appName", appDTO.Name, + "appClientName", appDTO.ClientName, ) logger.InfoContext(ctx, "Updating service app...") name := strings.TrimSpace(opts.Name) - if appDTO.Name != name { + if appDTO.ClientName != name { if serviceErr := s.checkForDuplicateApps(ctx, checkForDuplicateAppsOptions{ requestID: opts.RequestID, accountID: opts.AccountID, @@ -2636,12 +2634,12 @@ func (s *Services) UpdateMCPApp( ) (dtos.AppDTO, *exceptions.ServiceError) { logger := s.buildLogger(opts.RequestID, appsLocation, "UpdateMCPApp").With( "appID", appDTO.ID(), - "appName", appDTO.Name, + "appClientName", appDTO.ClientName, ) logger.InfoContext(ctx, "Updating MCP app...") name := strings.TrimSpace(opts.Name) - if appDTO.Name != name { + if appDTO.ClientName != name { if serviceErr := s.checkForDuplicateApps(ctx, checkForDuplicateAppsOptions{ requestID: opts.RequestID, accountID: opts.AccountID, diff --git a/idp/internal/services/dtos/account_credentials.go b/idp/internal/services/dtos/account_credentials.go index 20d5ede..fcf6290 100644 --- a/idp/internal/services/dtos/account_credentials.go +++ b/idp/internal/services/dtos/account_credentials.go @@ -11,14 +11,15 @@ import ( "fmt" "time" + "github.com/tugascript/devlogs/idp/internal/exceptions" "github.com/tugascript/devlogs/idp/internal/providers/database" "github.com/tugascript/devlogs/idp/internal/utils" ) type AccountCredentialsDTO struct { ClientID string `json:"client_id"` - Type database.AccountCredentialsType `json:"type"` - Name string `json:"name"` + Type database.AccountCredentialsType `json:"application_type"` + ClientName string `json:"client_name"` Domain string `json:"domain"` Scopes []database.AccountCredentialsScope `json:"scopes"` TokenEndpointAuthMethod database.AuthMethod `json:"token_endpoint_auth_method"` @@ -32,10 +33,32 @@ type AccountCredentialsDTO struct { SoftwareID string `json:"software_id"` SoftwareVersion string `json:"software_version,omitempty"` Contacts []string `json:"contacts,omitempty"` - ClientSecretID string `json:"client_secret_id,omitempty"` - ClientSecret string `json:"client_secret,omitempty"` - ClientSecretJWK utils.JWK `json:"client_secret_jwk,omitempty"` - ClientSecretExp int64 `json:"client_secret_exp,omitempty"` + JWKsURI string `json:"jwks_uri,omitempty"` + JWKs []utils.JWK `json:"jwks,omitempty"` + + SectorIdentifierURI string `json:"sector_identifier_uri,omitempty"` + SubjectType database.ClientSubjectType `json:"subject_type,omitempty"` + IDTokenSignedResponseAlg database.TokenCryptoSuite `json:"id_token_signed_response_alg"` + IDTokenEncryptedResponseAlg database.TokenEncryptionAlgorithm `json:"id_token_encrypted_response_alg,omitempty"` + IDTokenEncryptedResponseEnc database.TokenEncryptionEncoding `json:"id_token_encrypted_response_enc,omitempty"` + UserInfoSignedResponseAlg database.TokenCryptoSuite `json:"userinfo_signed_response_alg,omitempty"` + UserInfoEncryptedResponseAlg database.TokenEncryptionAlgorithm `json:"userinfo_encrypted_response_alg,omitempty"` + UserInfoEncryptedResponseEnc database.TokenEncryptionEncoding `json:"userinfo_encrypted_response_enc,omitempty"` + RequestObjectSigningAlg database.TokenCryptoSuite `json:"request_object_signing_alg,omitempty"` + RequestObjectEncryptionAlg database.TokenEncryptionAlgorithm `json:"request_object_encryption_alg,omitempty"` + RequestObjectEncryptionEnc database.TokenEncryptionEncoding `json:"request_object_encryption_enc,omitempty"` + TokenEndpointAuthSigningAlg database.TokenCryptoSuite `json:"token_endpoint_auth_signing_alg,omitempty"` + AccessTokenSigningAlg database.TokenCryptoSuite `json:"access_token_signing_alg"` + DefaultMaxAge int64 `json:"default_max_age,omitempty"` + RequireAuthTime bool `json:"require_auth_time,omitempty"` + DefaultACRValues []string `json:"default_acr_values,omitempty"` + InitiateLoginURI string `json:"initiate_login_uri,omitempty"` + RequestURIs []string `json:"request_uris,omitempty"` + + ClientSecretID string `json:"client_secret_id,omitempty"` + ClientSecret string `json:"client_secret,omitempty"` + ClientSecretJWK utils.JWK `json:"client_secret_jwk,omitempty"` + ClientSecretExp int64 `json:"client_secret_exp,omitempty"` id int32 accountId int32 @@ -52,7 +75,8 @@ func (ak *AccountCredentialsDTO) ID() int32 { func (ak *AccountCredentialsDTO) UnmarshalJSON(data []byte) error { type Alias AccountCredentialsDTO aux := &struct { - ClientSecretJWK json.RawMessage `json:"client_secret_jwk"` + ClientSecretJWK json.RawMessage `json:"client_secret_jwk"` + JWKs []json.RawMessage `json:"jwks"` *Alias }{ Alias: (*Alias)(ak), @@ -70,12 +94,24 @@ func (ak *AccountCredentialsDTO) UnmarshalJSON(data []byte) error { ak.ClientSecretJWK = jwk } + if aux.JWKs != nil { + jwks := make([]utils.JWK, 0, len(aux.JWKs)) + for _, raw := range aux.JWKs { + jwk, err := utils.JsonToJWK(raw) + if err != nil { + return err + } + jwks = append(jwks, jwk) + } + ak.JWKs = jwks + } + return nil } func MapAccountCredentialsToDTO( accountCredential *database.AccountCredential, -) AccountCredentialsDTO { +) (AccountCredentialsDTO, *exceptions.ServiceError) { var redirectURIs []string if len(accountCredential.RedirectUris) > 0 { redirectURIs = accountCredential.RedirectUris @@ -86,94 +122,199 @@ func MapAccountCredentialsToDTO( contacts = accountCredential.Contacts } - return AccountCredentialsDTO{ - id: accountCredential.ID, - ClientID: accountCredential.ClientID, - Type: accountCredential.CredentialsType, - Name: accountCredential.Name, - Domain: accountCredential.Domain, - ClientURI: accountCredential.ClientUri, - RedirectURIs: redirectURIs, - LogoURI: accountCredential.LogoUri.String, - TOSURI: accountCredential.TosUri.String, - PolicyURI: accountCredential.PolicyUri.String, - SoftwareID: accountCredential.SoftwareID, - SoftwareVersion: accountCredential.SoftwareVersion.String, - Contacts: contacts, - CreationMethod: accountCredential.CreationMethod, - Transport: accountCredential.Transport, - TokenEndpointAuthMethod: accountCredential.TokenEndpointAuthMethod, - accountId: accountCredential.AccountID, + jwks := make([]utils.JWK, 0) + if accountCredential.Jwks != nil { + var rawJwks []json.RawMessage + if err := json.Unmarshal(accountCredential.Jwks, &rawJwks); err != nil { + return AccountCredentialsDTO{}, exceptions.NewInternalServerError() + } + for _, raw := range rawJwks { + jwk, err := utils.JsonToJWK(raw) + if err != nil { + return AccountCredentialsDTO{}, exceptions.NewInternalServerError() + } + jwks = append(jwks, jwk) + } } + + return AccountCredentialsDTO{ + id: accountCredential.ID, + ClientID: accountCredential.ClientID, + Type: accountCredential.CredentialsType, + ClientName: accountCredential.ClientName, + Domain: accountCredential.Domain, + ClientURI: accountCredential.ClientUri, + RedirectURIs: redirectURIs, + LogoURI: accountCredential.LogoUri.String, + TOSURI: accountCredential.TosUri.String, + PolicyURI: accountCredential.PolicyUri.String, + SoftwareID: accountCredential.SoftwareID.String, + SoftwareVersion: accountCredential.SoftwareVersion.String, + Contacts: contacts, + CreationMethod: accountCredential.CreationMethod, + Transport: accountCredential.Transport, + TokenEndpointAuthMethod: accountCredential.TokenEndpointAuthMethod, + accountId: accountCredential.AccountID, + JWKsURI: accountCredential.JwksUri.String, + JWKs: jwks, + SectorIdentifierURI: accountCredential.SectorIdentifierUri.String, + SubjectType: accountCredential.SubjectType.ClientSubjectType, + IDTokenSignedResponseAlg: accountCredential.IDTokenSignedResponseAlg, + IDTokenEncryptedResponseAlg: accountCredential.IDTokenEncryptedResponseAlg.TokenEncryptionAlgorithm, + IDTokenEncryptedResponseEnc: accountCredential.IDTokenEncryptedResponseEnc.TokenEncryptionEncoding, + UserInfoSignedResponseAlg: accountCredential.UserinfoSignedResponseAlg.TokenCryptoSuite, + UserInfoEncryptedResponseAlg: accountCredential.UserinfoEncryptedResponseAlg.TokenEncryptionAlgorithm, + UserInfoEncryptedResponseEnc: accountCredential.UserinfoEncryptedResponseEnc.TokenEncryptionEncoding, + RequestObjectSigningAlg: accountCredential.RequestObjectSigningAlg.TokenCryptoSuite, + RequestObjectEncryptionAlg: accountCredential.RequestObjectEncryptionAlg.TokenEncryptionAlgorithm, + RequestObjectEncryptionEnc: accountCredential.RequestObjectEncryptionEnc.TokenEncryptionEncoding, + TokenEndpointAuthSigningAlg: accountCredential.TokenEndpointAuthSigningAlg.TokenCryptoSuite, + AccessTokenSigningAlg: accountCredential.AccessTokenSigningAlg, + DefaultMaxAge: accountCredential.DefaultMaxAge.Int64, + RequireAuthTime: accountCredential.RequireAuthTime, + DefaultACRValues: accountCredential.DefaultAcrValues, + InitiateLoginURI: accountCredential.InitiateLoginUri.String, + RequestURIs: accountCredential.RequestUris, + }, nil } func MapAccountCredentialsToDTOWithJWK( - accountKeys *database.AccountCredential, + accountCredential *database.AccountCredential, jwk utils.JWK, exp time.Time, -) AccountCredentialsDTO { +) (AccountCredentialsDTO, *exceptions.ServiceError) { var contacts []string - if len(accountKeys.Contacts) > 0 { - contacts = accountKeys.Contacts + if len(accountCredential.Contacts) > 0 { + contacts = accountCredential.Contacts } - return AccountCredentialsDTO{ - id: accountKeys.ID, - Type: accountKeys.CredentialsType, - Name: accountKeys.Name, - Domain: accountKeys.Domain, - ClientURI: accountKeys.ClientUri, - RedirectURIs: accountKeys.RedirectUris, - LogoURI: accountKeys.LogoUri.String, - TOSURI: accountKeys.TosUri.String, - PolicyURI: accountKeys.PolicyUri.String, - SoftwareID: accountKeys.SoftwareID, - SoftwareVersion: accountKeys.SoftwareVersion.String, - Contacts: contacts, - CreationMethod: accountKeys.CreationMethod, - Transport: accountKeys.Transport, - TokenEndpointAuthMethod: accountKeys.TokenEndpointAuthMethod, - accountId: accountKeys.AccountID, - ClientID: accountKeys.ClientID, - ClientSecretID: jwk.GetKeyID(), - ClientSecretJWK: jwk, - ClientSecretExp: exp.Unix(), - Scopes: accountKeys.Scopes, + jwks := make([]utils.JWK, 0) + if accountCredential.Jwks != nil { + var rawJwks []json.RawMessage + if err := json.Unmarshal(accountCredential.Jwks, &rawJwks); err != nil { + return AccountCredentialsDTO{}, exceptions.NewInternalServerError() + } + for _, raw := range rawJwks { + jwk, err := utils.JsonToJWK(raw) + if err != nil { + return AccountCredentialsDTO{}, exceptions.NewInternalServerError() + } + jwks = append(jwks, jwk) + } } + + return AccountCredentialsDTO{ + id: accountCredential.ID, + Type: accountCredential.CredentialsType, + ClientName: accountCredential.ClientName, + Domain: accountCredential.Domain, + ClientURI: accountCredential.ClientUri, + RedirectURIs: accountCredential.RedirectUris, + LogoURI: accountCredential.LogoUri.String, + TOSURI: accountCredential.TosUri.String, + PolicyURI: accountCredential.PolicyUri.String, + SoftwareID: accountCredential.SoftwareID.String, + SoftwareVersion: accountCredential.SoftwareVersion.String, + Contacts: contacts, + CreationMethod: accountCredential.CreationMethod, + Transport: accountCredential.Transport, + TokenEndpointAuthMethod: accountCredential.TokenEndpointAuthMethod, + accountId: accountCredential.AccountID, + ClientID: accountCredential.ClientID, + ClientSecretID: jwk.GetKeyID(), + ClientSecretJWK: jwk, + ClientSecretExp: exp.Unix(), + Scopes: accountCredential.Scopes, + JWKsURI: accountCredential.JwksUri.String, + JWKs: jwks, + SectorIdentifierURI: accountCredential.SectorIdentifierUri.String, + SubjectType: accountCredential.SubjectType.ClientSubjectType, + IDTokenSignedResponseAlg: accountCredential.IDTokenSignedResponseAlg, + IDTokenEncryptedResponseAlg: accountCredential.IDTokenEncryptedResponseAlg.TokenEncryptionAlgorithm, + IDTokenEncryptedResponseEnc: accountCredential.IDTokenEncryptedResponseEnc.TokenEncryptionEncoding, + UserInfoSignedResponseAlg: accountCredential.UserinfoSignedResponseAlg.TokenCryptoSuite, + UserInfoEncryptedResponseAlg: accountCredential.UserinfoEncryptedResponseAlg.TokenEncryptionAlgorithm, + UserInfoEncryptedResponseEnc: accountCredential.UserinfoEncryptedResponseEnc.TokenEncryptionEncoding, + RequestObjectSigningAlg: accountCredential.RequestObjectSigningAlg.TokenCryptoSuite, + RequestObjectEncryptionAlg: accountCredential.RequestObjectEncryptionAlg.TokenEncryptionAlgorithm, + RequestObjectEncryptionEnc: accountCredential.RequestObjectEncryptionEnc.TokenEncryptionEncoding, + TokenEndpointAuthSigningAlg: accountCredential.TokenEndpointAuthSigningAlg.TokenCryptoSuite, + AccessTokenSigningAlg: accountCredential.AccessTokenSigningAlg, + DefaultMaxAge: accountCredential.DefaultMaxAge.Int64, + RequireAuthTime: accountCredential.RequireAuthTime, + DefaultACRValues: accountCredential.DefaultAcrValues, + InitiateLoginURI: accountCredential.InitiateLoginUri.String, + RequestURIs: accountCredential.RequestUris, + }, nil } func MapAccountCredentialsToDTOWithSecret( - accountKeys *database.AccountCredential, + accountCredential *database.AccountCredential, secretID, secret string, exp time.Time, -) AccountCredentialsDTO { +) (AccountCredentialsDTO, *exceptions.ServiceError) { var contacts []string - if len(accountKeys.Contacts) > 0 { - contacts = accountKeys.Contacts + if len(accountCredential.Contacts) > 0 { + contacts = accountCredential.Contacts } - return AccountCredentialsDTO{ - id: accountKeys.ID, - Type: accountKeys.CredentialsType, - Name: accountKeys.Name, - Domain: accountKeys.Domain, - ClientURI: accountKeys.ClientUri, - RedirectURIs: accountKeys.RedirectUris, - LogoURI: accountKeys.LogoUri.String, - TOSURI: accountKeys.TosUri.String, - PolicyURI: accountKeys.PolicyUri.String, - SoftwareID: accountKeys.SoftwareID, - SoftwareVersion: accountKeys.SoftwareVersion.String, - Contacts: contacts, - CreationMethod: accountKeys.CreationMethod, - Transport: accountKeys.Transport, - TokenEndpointAuthMethod: accountKeys.TokenEndpointAuthMethod, - accountId: accountKeys.AccountID, - ClientID: accountKeys.ClientID, - ClientSecretID: secretID, - ClientSecret: fmt.Sprintf("%s.%s", secretID, secret), - ClientSecretExp: exp.Unix(), - Scopes: accountKeys.Scopes, + jwks := make([]utils.JWK, 0) + if accountCredential.Jwks != nil { + var rawJwks []json.RawMessage + if err := json.Unmarshal(accountCredential.Jwks, &rawJwks); err != nil { + return AccountCredentialsDTO{}, exceptions.NewInternalServerError() + } + for _, raw := range rawJwks { + jwk, err := utils.JsonToJWK(raw) + if err != nil { + return AccountCredentialsDTO{}, exceptions.NewInternalServerError() + } + jwks = append(jwks, jwk) + } } + + return AccountCredentialsDTO{ + id: accountCredential.ID, + Type: accountCredential.CredentialsType, + ClientName: accountCredential.ClientName, + Domain: accountCredential.Domain, + ClientURI: accountCredential.ClientUri, + RedirectURIs: accountCredential.RedirectUris, + LogoURI: accountCredential.LogoUri.String, + TOSURI: accountCredential.TosUri.String, + PolicyURI: accountCredential.PolicyUri.String, + SoftwareID: accountCredential.SoftwareID.String, + SoftwareVersion: accountCredential.SoftwareVersion.String, + Contacts: contacts, + CreationMethod: accountCredential.CreationMethod, + Transport: accountCredential.Transport, + TokenEndpointAuthMethod: accountCredential.TokenEndpointAuthMethod, + accountId: accountCredential.AccountID, + ClientID: accountCredential.ClientID, + ClientSecretID: secretID, + ClientSecret: fmt.Sprintf("%s.%s", secretID, secret), + ClientSecretExp: exp.Unix(), + Scopes: accountCredential.Scopes, + JWKsURI: accountCredential.JwksUri.String, + JWKs: jwks, + SectorIdentifierURI: accountCredential.SectorIdentifierUri.String, + SubjectType: accountCredential.SubjectType.ClientSubjectType, + IDTokenSignedResponseAlg: accountCredential.IDTokenSignedResponseAlg, + IDTokenEncryptedResponseAlg: accountCredential.IDTokenEncryptedResponseAlg.TokenEncryptionAlgorithm, + IDTokenEncryptedResponseEnc: accountCredential.IDTokenEncryptedResponseEnc.TokenEncryptionEncoding, + UserInfoSignedResponseAlg: accountCredential.UserinfoSignedResponseAlg.TokenCryptoSuite, + UserInfoEncryptedResponseAlg: accountCredential.UserinfoEncryptedResponseAlg.TokenEncryptionAlgorithm, + UserInfoEncryptedResponseEnc: accountCredential.UserinfoEncryptedResponseEnc.TokenEncryptionEncoding, + RequestObjectSigningAlg: accountCredential.RequestObjectSigningAlg.TokenCryptoSuite, + RequestObjectEncryptionAlg: accountCredential.RequestObjectEncryptionAlg.TokenEncryptionAlgorithm, + RequestObjectEncryptionEnc: accountCredential.RequestObjectEncryptionEnc.TokenEncryptionEncoding, + TokenEndpointAuthSigningAlg: accountCredential.TokenEndpointAuthSigningAlg.TokenCryptoSuite, + AccessTokenSigningAlg: accountCredential.AccessTokenSigningAlg, + DefaultMaxAge: accountCredential.DefaultMaxAge.Int64, + RequireAuthTime: accountCredential.RequireAuthTime, + DefaultACRValues: accountCredential.DefaultAcrValues, + InitiateLoginURI: accountCredential.InitiateLoginUri.String, + RequestURIs: accountCredential.RequestUris, + }, nil } diff --git a/idp/internal/services/dtos/account_dynamic_registration_config.go b/idp/internal/services/dtos/account_dynamic_registration_config.go index 964ece5..9215960 100644 --- a/idp/internal/services/dtos/account_dynamic_registration_config.go +++ b/idp/internal/services/dtos/account_dynamic_registration_config.go @@ -12,8 +12,8 @@ type AccountDynamicRegistrationConfigDTO struct { id int32 CredentialsTypes []database.AccountCredentialsType `json:"credentials_types"` - WhitelistedDomains []string `json:"whitelisted_domains"` RequireSoftwareStatementCredentialTypes []database.AccountCredentialsType `json:"require_software_statement_credential_types"` + RequireVerifiedDomainsCredentialsType []database.AccountCredentialsType `json:"require_verified_domains_credentials_type"` SoftwareStatementVerificationMethods []database.SoftwareStatementVerificationMethod `json:"software_statement_verification_methods"` RequireInitialAccessTokenCredentialTypes []database.AccountCredentialsType `json:"require_initial_access_token_credential_types"` InitialAccessTokenGenerationMethods []database.InitialAccessTokenGenerationMethod `json:"initial_access_token_generation_methods"` @@ -29,8 +29,8 @@ func MapAccountDynamicRegistrationConfigToDTO( return AccountDynamicRegistrationConfigDTO{ id: config.ID, CredentialsTypes: config.AccountCredentialsTypes, - WhitelistedDomains: config.WhitelistedDomains, RequireSoftwareStatementCredentialTypes: config.RequireSoftwareStatementCredentialTypes, + RequireVerifiedDomainsCredentialsType: config.RequireVerifiedDomainsCredentialsType, SoftwareStatementVerificationMethods: config.SoftwareStatementVerificationMethods, RequireInitialAccessTokenCredentialTypes: config.RequireInitialAccessTokenCredentialTypes, InitialAccessTokenGenerationMethods: config.InitialAccessTokenGenerationMethods, diff --git a/idp/internal/services/dtos/app.go b/idp/internal/services/dtos/app.go index 71b0877..2413cfc 100644 --- a/idp/internal/services/dtos/app.go +++ b/idp/internal/services/dtos/app.go @@ -17,10 +17,10 @@ import ( ) type RelatedAppDTO struct { - AppType database.AppType `json:"app_type"` - Name string `json:"name"` - ClientID string `json:"client_id"` - Links LinksSelfDTO `json:"links"` + AppType database.AppType `json:"app_type"` + ClientName string `json:"client_name"` + ClientID string `json:"client_id"` + Links LinksSelfDTO `json:"links"` } func newRelatedAppDTO( @@ -29,10 +29,10 @@ func newRelatedAppDTO( route string, ) RelatedAppDTO { return RelatedAppDTO{ - AppType: app.AppType, - Name: app.Name, - ClientID: app.ClientID, - Links: NewLinksSelfDTO(backendDomain, route), + AppType: app.AppType, + ClientName: app.ClientName, + ClientID: app.ClientID, + Links: NewLinksSelfDTO(backendDomain, route), } } @@ -42,7 +42,7 @@ type AppDTO struct { version int32 AppType database.AppType `json:"app_type"` - Name string `json:"name"` + ClientName string `json:"client_name"` ClientID string `json:"client_id"` Domain string `json:"domain"` Transport database.Transport `json:"transport"` @@ -130,7 +130,7 @@ func MapAppToDTO(app *database.App) AppDTO { accountID: app.AccountID, version: app.Version, AppType: app.AppType, - Name: app.Name, + ClientName: app.ClientName, ClientID: app.ClientID, Domain: app.Domain, Transport: app.Transport, @@ -139,7 +139,7 @@ func MapAppToDTO(app *database.App) AppDTO { LogoURI: app.LogoUri.String, TosURI: app.TosUri.String, PolicyURI: app.PolicyUri.String, - SoftwareID: app.SoftwareID, + SoftwareID: app.SoftwareID.String, SoftwareVersion: app.SoftwareVersion.String, TokenEndpointAuthMethod: app.TokenEndpointAuthMethod, GrantTypes: app.GrantTypes, @@ -161,7 +161,7 @@ func MapWebNativeSPAMCPAppToDTO(app *database.App) AppDTO { accountID: app.AccountID, version: app.Version, AppType: app.AppType, - Name: app.Name, + ClientName: app.ClientName, ClientID: app.ClientID, Domain: app.Domain, Transport: app.Transport, @@ -170,7 +170,7 @@ func MapWebNativeSPAMCPAppToDTO(app *database.App) AppDTO { LogoURI: app.LogoUri.String, TosURI: app.TosUri.String, PolicyURI: app.PolicyUri.String, - SoftwareID: app.SoftwareID, + SoftwareID: app.SoftwareID.String, SoftwareVersion: app.SoftwareVersion.String, TokenEndpointAuthMethod: app.TokenEndpointAuthMethod, GrantTypes: app.GrantTypes, @@ -197,7 +197,7 @@ func MapWebAppWithSecretToDTO( accountID: app.AccountID, version: app.Version, AppType: app.AppType, - Name: app.Name, + ClientName: app.ClientName, ClientID: app.ClientID, Domain: app.Domain, Transport: app.Transport, @@ -206,7 +206,7 @@ func MapWebAppWithSecretToDTO( LogoURI: app.LogoUri.String, TosURI: app.TosUri.String, PolicyURI: app.PolicyUri.String, - SoftwareID: app.SoftwareID, + SoftwareID: app.SoftwareID.String, SoftwareVersion: app.SoftwareVersion.String, TokenEndpointAuthMethod: app.TokenEndpointAuthMethod, GrantTypes: app.GrantTypes, @@ -231,7 +231,7 @@ func MapWebAppWithJWKToDTO(app *database.App, jwk utils.JWK, exp time.Time) AppD accountID: app.AccountID, version: app.Version, AppType: app.AppType, - Name: app.Name, + ClientName: app.ClientName, ClientID: app.ClientID, Domain: app.Domain, Transport: app.Transport, @@ -240,7 +240,7 @@ func MapWebAppWithJWKToDTO(app *database.App, jwk utils.JWK, exp time.Time) AppD LogoURI: app.LogoUri.String, TosURI: app.TosUri.String, PolicyURI: app.PolicyUri.String, - SoftwareID: app.SoftwareID, + SoftwareID: app.SoftwareID.String, SoftwareVersion: app.SoftwareVersion.String, TokenEndpointAuthMethod: app.TokenEndpointAuthMethod, GrantTypes: app.GrantTypes, @@ -265,7 +265,7 @@ func MapBackendAppWithJWKToDTO(app *database.App, jwk utils.JWK, exp time.Time) accountID: app.AccountID, version: app.Version, AppType: app.AppType, - Name: app.Name, + ClientName: app.ClientName, ClientID: app.ClientID, Domain: app.Domain, Transport: app.Transport, @@ -274,7 +274,7 @@ func MapBackendAppWithJWKToDTO(app *database.App, jwk utils.JWK, exp time.Time) LogoURI: app.LogoUri.String, TosURI: app.TosUri.String, PolicyURI: app.PolicyUri.String, - SoftwareID: app.SoftwareID, + SoftwareID: app.SoftwareID.String, SoftwareVersion: app.SoftwareVersion.String, TokenEndpointAuthMethod: app.TokenEndpointAuthMethod, GrantTypes: app.GrantTypes, @@ -297,7 +297,7 @@ func MapBackendAppWithSecretToDTO(app *database.App, secretID string, secret str accountID: app.AccountID, version: app.Version, AppType: app.AppType, - Name: app.Name, + ClientName: app.ClientName, ClientID: app.ClientID, Domain: app.Domain, Transport: app.Transport, @@ -306,7 +306,7 @@ func MapBackendAppWithSecretToDTO(app *database.App, secretID string, secret str LogoURI: app.LogoUri.String, TosURI: app.TosUri.String, PolicyURI: app.PolicyUri.String, - SoftwareID: app.SoftwareID, + SoftwareID: app.SoftwareID.String, SoftwareVersion: app.SoftwareVersion.String, TokenEndpointAuthMethod: app.TokenEndpointAuthMethod, GrantTypes: app.GrantTypes, @@ -329,7 +329,7 @@ func MapDeviceAppToDTO(app *database.App, relatedApps []database.App, backendDom accountID: app.AccountID, version: app.Version, AppType: app.AppType, - Name: app.Name, + ClientName: app.ClientName, Domain: app.Domain, Transport: app.Transport, CreationMethod: app.CreationMethod, @@ -338,7 +338,7 @@ func MapDeviceAppToDTO(app *database.App, relatedApps []database.App, backendDom LogoURI: app.LogoUri.String, TosURI: app.TosUri.String, PolicyURI: app.PolicyUri.String, - SoftwareID: app.SoftwareID, + SoftwareID: app.SoftwareID.String, SoftwareVersion: app.SoftwareVersion.String, TokenEndpointAuthMethod: app.TokenEndpointAuthMethod, GrantTypes: app.GrantTypes, @@ -366,7 +366,7 @@ func MapServiceAppWithJWKToDTO( accountID: app.AccountID, version: app.Version, AppType: app.AppType, - Name: app.Name, + ClientName: app.ClientName, ClientID: app.ClientID, Domain: app.Domain, Transport: app.Transport, @@ -375,7 +375,7 @@ func MapServiceAppWithJWKToDTO( LogoURI: app.LogoUri.String, TosURI: app.TosUri.String, PolicyURI: app.PolicyUri.String, - SoftwareID: app.SoftwareID, + SoftwareID: app.SoftwareID.String, SoftwareVersion: app.SoftwareVersion.String, TokenEndpointAuthMethod: app.TokenEndpointAuthMethod, GrantTypes: app.GrantTypes, @@ -407,7 +407,7 @@ func MapServiceAppWithSecretToDTO( accountID: app.AccountID, version: app.Version, AppType: app.AppType, - Name: app.Name, + ClientName: app.ClientName, ClientID: app.ClientID, Domain: app.Domain, Transport: app.Transport, @@ -416,7 +416,7 @@ func MapServiceAppWithSecretToDTO( LogoURI: app.LogoUri.String, TosURI: app.TosUri.String, PolicyURI: app.PolicyUri.String, - SoftwareID: app.SoftwareID, + SoftwareID: app.SoftwareID.String, SoftwareVersion: app.SoftwareVersion.String, TokenEndpointAuthMethod: app.TokenEndpointAuthMethod, GrantTypes: app.GrantTypes, @@ -442,7 +442,7 @@ func MapBackendAppToDTO(app *database.App) AppDTO { accountID: app.AccountID, version: app.Version, AppType: app.AppType, - Name: app.Name, + ClientName: app.ClientName, ClientID: app.ClientID, Domain: app.Domain, Transport: app.Transport, @@ -451,7 +451,7 @@ func MapBackendAppToDTO(app *database.App) AppDTO { LogoURI: app.LogoUri.String, TosURI: app.TosUri.String, PolicyURI: app.PolicyUri.String, - SoftwareID: app.SoftwareID, + SoftwareID: app.SoftwareID.String, SoftwareVersion: app.SoftwareVersion.String, TokenEndpointAuthMethod: app.TokenEndpointAuthMethod, GrantTypes: app.GrantTypes, @@ -474,7 +474,7 @@ func MapServiceAppToDTO( accountID: app.AccountID, version: app.Version, AppType: app.AppType, - Name: app.Name, + ClientName: app.ClientName, ClientID: app.ClientID, Domain: app.Domain, Transport: app.Transport, @@ -483,7 +483,7 @@ func MapServiceAppToDTO( LogoURI: app.LogoUri.String, TosURI: app.TosUri.String, PolicyURI: app.PolicyUri.String, - SoftwareID: app.SoftwareID, + SoftwareID: app.SoftwareID.String, SoftwareVersion: app.SoftwareVersion.String, TokenEndpointAuthMethod: app.TokenEndpointAuthMethod, GrantTypes: app.GrantTypes, @@ -506,7 +506,7 @@ func MapMCPAppWithJWKToDTO(app *database.App, jwk utils.JWK, exp time.Time) AppD accountID: app.AccountID, version: app.Version, AppType: app.AppType, - Name: app.Name, + ClientName: app.ClientName, ClientID: app.ClientID, Domain: app.Domain, Transport: app.Transport, @@ -515,7 +515,7 @@ func MapMCPAppWithJWKToDTO(app *database.App, jwk utils.JWK, exp time.Time) AppD LogoURI: app.LogoUri.String, TosURI: app.TosUri.String, PolicyURI: app.PolicyUri.String, - SoftwareID: app.SoftwareID, + SoftwareID: app.SoftwareID.String, SoftwareVersion: app.SoftwareVersion.String, TokenEndpointAuthMethod: app.TokenEndpointAuthMethod, GrantTypes: app.GrantTypes, @@ -542,7 +542,7 @@ func MapMCPAppWithSecretToDTO( accountID: app.AccountID, version: app.Version, AppType: app.AppType, - Name: app.Name, + ClientName: app.ClientName, ClientID: app.ClientID, Domain: app.Domain, Transport: app.Transport, @@ -551,7 +551,7 @@ func MapMCPAppWithSecretToDTO( LogoURI: app.LogoUri.String, TosURI: app.TosUri.String, PolicyURI: app.PolicyUri.String, - SoftwareID: app.SoftwareID, + SoftwareID: app.SoftwareID.String, SoftwareVersion: app.SoftwareVersion.String, TokenEndpointAuthMethod: app.TokenEndpointAuthMethod, GrantTypes: app.GrantTypes, diff --git a/idp/internal/services/dtos/dynamic_registration_domain.go b/idp/internal/services/dtos/dynamic_registration_domain.go index d15f07b..509b335 100644 --- a/idp/internal/services/dtos/dynamic_registration_domain.go +++ b/idp/internal/services/dtos/dynamic_registration_domain.go @@ -33,7 +33,7 @@ func (a *DynamicRegistrationDomainDTO) ID() int32 { } func MapAccountCredentialsRegistrationDomainToDTOWithCode( - domain *database.AccountDynamicRegistrationDomain, + domain *database.DynamicRegistrationDomain, verificationHost string, verificationPrefix string, verificationCode string, @@ -53,7 +53,7 @@ func MapAccountCredentialsRegistrationDomainToDTOWithCode( } func MapAccountCredentialsRegistrationDomainToDTO( - domain *database.AccountDynamicRegistrationDomain, + domain *database.DynamicRegistrationDomain, ) DynamicRegistrationDomainDTO { verifiedAt := int64(0) if domain.VerifiedAt.Valid { diff --git a/idp/internal/services/account_credentials_registration_domains.go b/idp/internal/services/dynamic_registration_domains.go similarity index 80% rename from idp/internal/services/account_credentials_registration_domains.go rename to idp/internal/services/dynamic_registration_domains.go index 1ea7811..e6d9815 100644 --- a/idp/internal/services/account_credentials_registration_domains.go +++ b/idp/internal/services/dynamic_registration_domains.go @@ -9,7 +9,6 @@ package services import ( "context" "fmt" - "slices" "time" "github.com/google/uuid" @@ -22,45 +21,63 @@ import ( ) const ( - accountCredentialsRegistrationDomainsLocation string = "account_credentials_registration_domains" + dynamicRegistrationDomainsLocation string = "dynamic_registration_domains" domainCodeByteLength int = 32 ) -type CreateAccountCredentialsRegistrationDomainOptions struct { +func mapDynamicRegistrationDomainUsage(usage string) (database.DynamicRegistrationUsage, *exceptions.ServiceError) { + switch utils.Lowered(usage) { + case "account": + return database.DynamicRegistrationUsageAccount, nil + case "app": + return database.DynamicRegistrationUsageApp, nil + } + return "", exceptions.NewValidationError("Invalid usage") +} + +func mapDynamicRegistrationDomainUsages(usages []string) ([]database.DynamicRegistrationUsage, *exceptions.ServiceError) { + if len(usages) == 0 { + return nil, exceptions.NewValidationError("Usages cannot be empty") + } + + dbUsages := make([]database.DynamicRegistrationUsage, 0, len(usages)) + for _, usage := range usages { + usage, serviceErr := mapDynamicRegistrationDomainUsage(usage) + if serviceErr != nil { + return nil, serviceErr + } + dbUsages = append(dbUsages, usage) + } + + return dbUsages, nil +} + +type CreateDynamicRegistrationDomainOptions struct { RequestID string AccountPublicID uuid.UUID AccountVersion int32 Domain string + Usages []string } -func (s *Services) CreateAccountCredentialsRegistrationDomain( +func (s *Services) CreateDynamicRegistrationDomain( ctx context.Context, - opts CreateAccountCredentialsRegistrationDomainOptions, + opts CreateDynamicRegistrationDomainOptions, ) (dtos.DynamicRegistrationDomainDTO, *exceptions.ServiceError) { - logger := s.buildLogger(opts.RequestID, accountCredentialsRegistrationDomainsLocation, "CreateAccountCredentialsRegistrationDomain").With( + logger := s.buildLogger(opts.RequestID, dynamicRegistrationDomainsLocation, "CreateDynamicRegistrationDomain").With( "accountPublicID", opts.AccountPublicID, "domain", opts.Domain, ) logger.InfoContext(ctx, "Creating account credentials registration domain...") - dynamicRegistrationConfig, serviceErr := s.GetAccountDynamicRegistrationConfig(ctx, GetAccountDynamicRegistrationConfigOptions{ - RequestID: opts.RequestID, - AccountPublicID: opts.AccountPublicID, - }) + usages, serviceErr := mapDynamicRegistrationDomainUsages(opts.Usages) if serviceErr != nil { - if serviceErr.Code != exceptions.CodeNotFound { - logger.WarnContext(ctx, "Account dynamic registration config not found", "serviceError", serviceErr) - return dtos.DynamicRegistrationDomainDTO{}, exceptions.NewNotFoundValidationError("Dynamic registration config not found") - } + logger.WarnContext(ctx, "Failed to map usages", "serviceError", serviceErr) return dtos.DynamicRegistrationDomainDTO{}, serviceErr } - if len(dynamicRegistrationConfig.WhitelistedDomains) > 0 && !slices.Contains(dynamicRegistrationConfig.WhitelistedDomains, opts.Domain) { - logger.WarnContext(ctx, "Domain is not whitelisted", "domain", opts.Domain) - return dtos.DynamicRegistrationDomainDTO{}, exceptions.NewForbiddenValidationError("Domain is not whitelisted") - } - if _, err := s.database.FindAccountDynamicRegistrationDomainByAccountPublicIDAndDomain(ctx, database.FindAccountDynamicRegistrationDomainByAccountPublicIDAndDomainParams{ + if _, err := s.database.FindDynamicRegistrationDomainByAccountPublicIDAndDomain(ctx, database.FindDynamicRegistrationDomainByAccountPublicIDAndDomainParams{ AccountPublicID: opts.AccountPublicID, Domain: opts.Domain, }); err != nil { @@ -94,11 +111,12 @@ func (s *Services) CreateAccountCredentialsRegistrationDomain( s.database.FinalizeTx(ctx, txn, err, serviceErr) }() - domain, err := qrs.CreateAccountDynamicRegistrationDomain(ctx, database.CreateAccountDynamicRegistrationDomainParams{ + domain, err := qrs.CreateDynamicRegistrationDomain(ctx, database.CreateDynamicRegistrationDomainParams{ AccountID: accountDTO.ID(), AccountPublicID: opts.AccountPublicID, Domain: opts.Domain, VerificationMethod: database.DomainVerificationMethodDnsTxtRecord, + Usages: usages, }) if err != nil { logger.ErrorContext(ctx, "Failed to create account dynamic registration domain", "error", err) @@ -139,30 +157,19 @@ func (s *Services) CreateAccountCredentialsRegistrationDomain( Queries: qrs, }), StoreHashedDataFN: func(secretID string, hashedData string) *exceptions.ServiceError { - codeID, err := qrs.CreateDynamicRegistrationDomainCode( + if err := qrs.CreateDynamicRegistrationDomainCode( ctx, database.CreateDynamicRegistrationDomainCodeParams{ - AccountID: accountDTO.ID(), - VerificationCode: hashedData, - VerificationPrefix: verificationPrefix, - VerificationHost: s.accountDomainVerificationHost, - HmacSecretID: secretID, - ExpiresAt: exp, - }, - ) - if err != nil { - logger.ErrorContext(ctx, "Failed to create account dynamic registration domain code", "error", err) - return exceptions.FromDBError(err) - } - if err := qrs.CreateAccountDynamicRegistrationDomainCode( - ctx, - database.CreateAccountDynamicRegistrationDomainCodeParams{ - AccountDynamicRegistrationDomainID: domain.ID, - DynamicRegistrationDomainCodeID: codeID, - AccountID: accountDTO.ID(), + AccountID: accountDTO.ID(), + VerificationCode: hashedData, + VerificationPrefix: verificationPrefix, + VerificationHost: s.accountDomainVerificationHost, + DynamicRegistrationDomainID: domain.ID, + HmacSecretID: secretID, + ExpiresAt: exp, }, ); err != nil { - logger.ErrorContext(ctx, "Failed to create account dynamic registration domain code association", "error", err) + logger.ErrorContext(ctx, "Failed to create account dynamic registration domain code", "error", err) return exceptions.FromDBError(err) } return nil @@ -186,13 +193,13 @@ func (s *Services) GetAccountCredentialsRegistrationDomain( ctx context.Context, opts GetAccountCredentialsRegistrationDomainOptions, ) (dtos.DynamicRegistrationDomainDTO, *exceptions.ServiceError) { - logger := s.buildLogger(opts.RequestID, accountCredentialsRegistrationDomainsLocation, "GetAccountCredentialsRegistrationDomain").With( + logger := s.buildLogger(opts.RequestID, dynamicRegistrationDomainsLocation, "GetAccountCredentialsRegistrationDomain").With( "accountPublicID", opts.AccountPublicID, "domain", opts.Domain, ) logger.InfoContext(ctx, "Getting account credentials registration domain...") - domainDTO, err := s.database.FindAccountDynamicRegistrationDomainByAccountPublicIDAndDomain(ctx, database.FindAccountDynamicRegistrationDomainByAccountPublicIDAndDomainParams{ + domainDTO, err := s.database.FindDynamicRegistrationDomainByAccountPublicIDAndDomain(ctx, database.FindDynamicRegistrationDomainByAccountPublicIDAndDomainParams{ AccountPublicID: opts.AccountPublicID, Domain: opts.Domain, }) @@ -223,7 +230,7 @@ func (s *Services) ListAccountCredentialsRegistrationDomains( ctx context.Context, opts ListAccountCredentialsRegistrationDomainsOptions, ) ([]dtos.DynamicRegistrationDomainDTO, int64, *exceptions.ServiceError) { - logger := s.buildLogger(opts.RequestID, accountCredentialsRegistrationDomainsLocation, "ListAccountCredentialsRegistrationDomains").With( + logger := s.buildLogger(opts.RequestID, dynamicRegistrationDomainsLocation, "ListAccountCredentialsRegistrationDomains").With( "accountPublicID", opts.AccountPublicID, "offset", opts.Offset, "limit", opts.Limit, @@ -232,22 +239,22 @@ func (s *Services) ListAccountCredentialsRegistrationDomains( logger.InfoContext(ctx, "Listing account credentials registration domains...") order := utils.Lowered(opts.Order) - var domains []database.AccountDynamicRegistrationDomain + var domains []database.DynamicRegistrationDomain var err error switch order { case "date": - domains, err = s.database.FindPaginatedAccountDynamicRegistrationDomainsByAccountPublicIDOrderedByID( + domains, err = s.database.FindPaginatedDynamicRegistrationDomainsByAccountPublicIDOrderedByID( ctx, - database.FindPaginatedAccountDynamicRegistrationDomainsByAccountPublicIDOrderedByIDParams{ + database.FindPaginatedDynamicRegistrationDomainsByAccountPublicIDOrderedByIDParams{ AccountPublicID: opts.AccountPublicID, Limit: opts.Limit, Offset: opts.Offset, }, ) case "domain": - domains, err = s.database.FindPaginatedAccountDynamicRegistrationDomainsByAccountPublicIDOrderedByDomain( + domains, err = s.database.FindPaginatedDynamicRegistrationDomainsByAccountPublicIDOrderedByDomain( ctx, - database.FindPaginatedAccountDynamicRegistrationDomainsByAccountPublicIDOrderedByDomainParams{ + database.FindPaginatedDynamicRegistrationDomainsByAccountPublicIDOrderedByDomainParams{ AccountPublicID: opts.AccountPublicID, Limit: opts.Limit, Offset: opts.Offset, @@ -262,7 +269,7 @@ func (s *Services) ListAccountCredentialsRegistrationDomains( return nil, 0, exceptions.FromDBError(err) } - count, err := s.database.CountAccountDynamicRegistrationDomainsByAccountPublicID(ctx, opts.AccountPublicID) + count, err := s.database.CountDynamicRegistrationDomainsByAccountPublicID(ctx, opts.AccountPublicID) if err != nil { logger.ErrorContext(ctx, "Failed to count account dynamic registration domains", "error", err) return nil, 0, exceptions.FromDBError(err) @@ -285,7 +292,7 @@ func (s *Services) FilterAccountCredentialsRegistrationDomains( ctx context.Context, opts FilterAccountCredentialsRegistrationDomainsOptions, ) ([]dtos.DynamicRegistrationDomainDTO, int64, *exceptions.ServiceError) { - logger := s.buildLogger(opts.RequestID, accountCredentialsRegistrationDomainsLocation, "FilterAccountCredentialsRegistrationDomains").With( + logger := s.buildLogger(opts.RequestID, dynamicRegistrationDomainsLocation, "FilterAccountCredentialsRegistrationDomains").With( "accountPublicID", opts.AccountPublicID, "search", opts.Search, "offset", opts.Offset, @@ -296,14 +303,14 @@ func (s *Services) FilterAccountCredentialsRegistrationDomains( domainSearch := utils.DbSearch(opts.Search) order := utils.Lowered(opts.Order) - var domains []database.AccountDynamicRegistrationDomain + var domains []database.DynamicRegistrationDomain var err error switch order { case "date": - domains, err = s.database.FilterAccountDynamicRegistrationDomainsByAccountPublicIDOrderedByID( + domains, err = s.database.FilterDynamicRegistrationDomainsByAccountPublicIDOrderedByID( ctx, - database.FilterAccountDynamicRegistrationDomainsByAccountPublicIDOrderedByIDParams{ + database.FilterDynamicRegistrationDomainsByAccountPublicIDOrderedByIDParams{ AccountPublicID: opts.AccountPublicID, Domain: domainSearch, Limit: opts.Limit, @@ -311,9 +318,9 @@ func (s *Services) FilterAccountCredentialsRegistrationDomains( }, ) case "domain": - domains, err = s.database.FilterAccountDynamicRegistrationDomainsByAccountPublicIDOrderedByDomain( + domains, err = s.database.FilterDynamicRegistrationDomainsByAccountPublicIDOrderedByDomain( ctx, - database.FilterAccountDynamicRegistrationDomainsByAccountPublicIDOrderedByDomainParams{ + database.FilterDynamicRegistrationDomainsByAccountPublicIDOrderedByDomainParams{ AccountPublicID: opts.AccountPublicID, Domain: domainSearch, Limit: opts.Limit, @@ -329,9 +336,9 @@ func (s *Services) FilterAccountCredentialsRegistrationDomains( return nil, 0, exceptions.FromDBError(err) } - count, err := s.database.CountFilteredAccountDynamicRegistrationDomainsByAccountPublicID( + count, err := s.database.CountFilteredDynamicRegistrationDomainsByAccountPublicID( ctx, - database.CountFilteredAccountDynamicRegistrationDomainsByAccountPublicIDParams{ + database.CountFilteredDynamicRegistrationDomainsByAccountPublicIDParams{ AccountPublicID: opts.AccountPublicID, Domain: domainSearch, }, @@ -356,7 +363,7 @@ func (s *Services) DeleteAccountCredentialsRegistrationDomain( ctx context.Context, opts DeleteAccountCredentialsRegistrationDomainOptions, ) *exceptions.ServiceError { - logger := s.buildLogger(opts.RequestID, accountCredentialsRegistrationDomainsLocation, "DeleteAccountCredentialsRegistrationDomain").With( + logger := s.buildLogger(opts.RequestID, dynamicRegistrationDomainsLocation, "DeleteAccountCredentialsRegistrationDomain").With( "accountPublicID", opts.AccountPublicID, "domain", opts.Domain, ) @@ -380,7 +387,7 @@ func (s *Services) DeleteAccountCredentialsRegistrationDomain( logger.WarnContext(ctx, "Failed to get account credentials registration domain", "error", serviceErr) return serviceErr } - if err := s.database.DeleteAccountDynamicRegistrationDomain(ctx, domainDTO.ID()); err != nil { + if err := s.database.DeleteDynamicRegistrationDomain(ctx, domainDTO.ID()); err != nil { logger.ErrorContext(ctx, "Failed to delete account dynamic registration domain", "error", err) return exceptions.FromDBError(err) } @@ -401,7 +408,7 @@ func (s *Services) VerifyAccountCredentialsRegistrationDomain( ctx context.Context, opts VerifyAccountCredentialsRegistrationDomainOptions, ) (dtos.DynamicRegistrationDomainDTO, *exceptions.ServiceError) { - logger := s.buildLogger(opts.RequestID, accountCredentialsRegistrationDomainsLocation, "VerifyAccountCredentialsRegistrationDomain").With( + logger := s.buildLogger(opts.RequestID, dynamicRegistrationDomainsLocation, "VerifyAccountCredentialsRegistrationDomain").With( "accountPublicID", opts.AccountPublicID, "domain", opts.Domain, "verificationCode", opts.VerificationCode, @@ -436,7 +443,7 @@ func (s *Services) VerifyAccountCredentialsRegistrationDomain( return dtos.DynamicRegistrationDomainDTO{}, serviceErr } - code, err := s.database.FindDynamicRegistrationDomainCodeByAccountDynamicRegistrationDomainID(ctx, domainDTO.ID()) + code, err := s.database.FindDynamicRegistrationDomainCodeByDynamicRegistrationDomainID(ctx, domainDTO.ID()) if err != nil { serviceErr := exceptions.FromDBError(err) if serviceErr.Code != exceptions.CodeNotFound { @@ -507,9 +514,9 @@ func (s *Services) VerifyAccountCredentialsRegistrationDomain( s.database.FinalizeTx(ctx, txn, err, serviceErr) }() - domain, err := qrs.VerifyAccountDynamicRegistrationDomain( + domain, err := qrs.VerifyDynamicRegistrationDomain( ctx, - database.VerifyAccountDynamicRegistrationDomainParams{ + database.VerifyDynamicRegistrationDomainParams{ ID: domainDTO.ID(), VerificationMethod: database.DomainVerificationMethodDnsTxtRecord, }, @@ -539,7 +546,7 @@ func (s *Services) GetAccountCredentialsRegistrationDomainCode( ctx context.Context, opts GetAccountCredentialsRegistrationDomainCodeOptions, ) (dtos.DynamicRegistrationDomainCodeDTO, *exceptions.ServiceError) { - logger := s.buildLogger(opts.RequestID, accountCredentialsRegistrationDomainsLocation, "GetAccountCredentialsRegistrationDomainCode").With( + logger := s.buildLogger(opts.RequestID, dynamicRegistrationDomainsLocation, "GetAccountCredentialsRegistrationDomainCode").With( "accountPublicID", opts.AccountPublicID, "domain", opts.Domain, ) @@ -559,7 +566,7 @@ func (s *Services) GetAccountCredentialsRegistrationDomainCode( return dtos.DynamicRegistrationDomainCodeDTO{}, exceptions.NewConflictError("Verification code not available for verified domain") } - code, err := s.database.FindDynamicRegistrationDomainCodeByAccountDynamicRegistrationDomainID(ctx, domainDTO.ID()) + code, err := s.database.FindDynamicRegistrationDomainCodeByDynamicRegistrationDomainID(ctx, domainDTO.ID()) if err != nil { serviceErr := exceptions.FromDBError(err) if serviceErr.Code != exceptions.CodeNotFound { @@ -586,7 +593,7 @@ func (s *Services) SaveAccountCredentialsRegistrationDomainCode( ctx context.Context, opts SaveAccountCredentialsRegistrationDomainCodeOptions, ) (dtos.DynamicRegistrationDomainCodeDTO, *exceptions.ServiceError) { - logger := s.buildLogger(opts.RequestID, accountCredentialsRegistrationDomainsLocation, "SaveAccountCredentialsRegistrationDomainCode").With( + logger := s.buildLogger(opts.RequestID, dynamicRegistrationDomainsLocation, "SaveAccountCredentialsRegistrationDomainCode").With( "accountPublicID", opts.AccountPublicID, "domain", opts.Domain, ) @@ -633,7 +640,7 @@ func (s *Services) SaveAccountCredentialsRegistrationDomainCode( verificationPrefix := fmt.Sprintf("%s-verification", accountDTO.Username) exp := time.Now().Add(s.accountDomainVerificationTTL) - code, err := s.database.FindDynamicRegistrationDomainCodeByAccountDynamicRegistrationDomainID(ctx, domainDTO.ID()) + code, err := s.database.FindDynamicRegistrationDomainCodeByDynamicRegistrationDomainID(ctx, domainDTO.ID()) if err != nil { serviceErr := exceptions.FromDBError(err) if serviceErr.Code != exceptions.CodeNotFound { @@ -661,19 +668,7 @@ func (s *Services) SaveAccountCredentialsRegistrationDomainCode( AccountID: accountDTO.ID(), }), StoreHashedDataFN: func(secretID string, hashedData string) *exceptions.ServiceError { - var serviceErr *exceptions.ServiceError - qrs, txn, err := s.database.BeginTx(ctx) - if err != nil { - logger.ErrorContext(ctx, "Failed to start transaction", "error", err) - serviceErr = exceptions.FromDBError(err) - return serviceErr - } - defer func() { - logger.DebugContext(ctx, "Finalizing transaction") - s.database.FinalizeTx(ctx, txn, err, serviceErr) - }() - - codeID, err := qrs.CreateDynamicRegistrationDomainCode( + if err := s.database.CreateDynamicRegistrationDomainCode( ctx, database.CreateDynamicRegistrationDomainCodeParams{ AccountID: accountDTO.ID(), @@ -683,21 +678,8 @@ func (s *Services) SaveAccountCredentialsRegistrationDomainCode( HmacSecretID: secretID, ExpiresAt: exp, }, - ) - if err != nil { - logger.ErrorContext(ctx, "Failed to create account dynamic registration domain code", "error", err) - serviceErr = exceptions.FromDBError(err) - return serviceErr - } - if err := qrs.CreateAccountDynamicRegistrationDomainCode( - ctx, - database.CreateAccountDynamicRegistrationDomainCodeParams{ - AccountDynamicRegistrationDomainID: domainDTO.ID(), - DynamicRegistrationDomainCodeID: codeID, - AccountID: accountDTO.ID(), - }, ); err != nil { - logger.ErrorContext(ctx, "Failed to create account dynamic registration domain code association", "error", err) + logger.ErrorContext(ctx, "Failed to create account dynamic registration domain code", "error", err) serviceErr = exceptions.FromDBError(err) return serviceErr } @@ -778,7 +760,7 @@ func (s *Services) DeleteAccountCredentialsRegistrationDomainCode( ctx context.Context, opts DeleteAccountCredentialsRegistrationDomainCodeOptions, ) *exceptions.ServiceError { - logger := s.buildLogger(opts.RequestID, accountCredentialsRegistrationDomainsLocation, "DeleteAccountCredentialsRegistrationDomainCode").With( + logger := s.buildLogger(opts.RequestID, dynamicRegistrationDomainsLocation, "DeleteAccountCredentialsRegistrationDomainCode").With( "accountPublicID", opts.AccountPublicID, "domain", opts.Domain, ) diff --git a/idp/internal/services/helpers.go b/idp/internal/services/helpers.go index 1f82d3e..98fbeae 100644 --- a/idp/internal/services/helpers.go +++ b/idp/internal/services/helpers.go @@ -8,6 +8,7 @@ package services import ( "context" + "encoding/json" "fmt" "log/slog" "net" @@ -42,15 +43,34 @@ const ( TwoFactorTotp string = "totp" ResponseTypeCode string = "code" - ResponseTypeIdToken string = "id_token" ResponseTypeCodeIdToken string = "code id_token" UsernameColumnEmail string = "email" UsernameColumnUsername string = "username" UsernameColumnBoth string = "both" + GrantTypeAuthorizationCode string = "authorization_code" + GrantTypeRefreshToken string = "refresh_token" + GrantTypeClientCredentials string = "client_credentials" + GrantTypeDeviceCode string = "urn:ietf:params:oauth:grant-type:device_code" + GrantTypeJwtBearer string = "urn:ietf:params:oauth:grant-type:jwt-bearer" + + SubjectTypePublic string = "public" + SubjectTypePairwise string = "pairwise" + ChallengeMethodPlain = "plain" ChallengeMethodS256 = "s256" + + TokenEncryptionAlgorithmRSAOAEP256 string = "RSA-OAEP-256" + TokenEncryptionAlgorithmECDHES string = "ECDH-ES" + TokenEncryptionAlgorithmECDHESA256KW string = "ECDH-ES+A256KW" + + TokenEncryptionEncodingA128CBCHS256 string = "A128CBC-HS256" + TokenEncryptionEncodingA192CBCHS384 string = "A192CBC-HS384" + TokenEncryptionEncodingA256CBCHS512 string = "A256CBC-HS512" + TokenEncryptionEncodingA128GCM string = "A128GCM" + TokenEncryptionEncodingA192GCM string = "A192GCM" + TokenEncryptionEncodingA256GCM string = "A256GCM" ) func (s *Services) buildLogger(requestID, location, function string) *slog.Logger { @@ -101,11 +121,10 @@ func mapAuthMethod(authMethod string) (database.AuthMethod, *exceptions.ServiceE } } -func mapResponseTypes(responseTypes []string) ([]database.ResponseType, *exceptions.ServiceError) { +func mapResponseTypesWithDefault(responseTypes []string) ([]database.ResponseType, *exceptions.ServiceError) { if len(responseTypes) == 0 { return []database.ResponseType{ database.ResponseTypeCode, - database.ResponseTypeIDToken, database.ResponseTypeCodeidToken, }, nil } @@ -115,8 +134,6 @@ func mapResponseTypes(responseTypes []string) ([]database.ResponseType, *excepti switch utils.Lowered(rt) { case ResponseTypeCode: dbResponseTypes = append(dbResponseTypes, database.ResponseTypeCode) - case ResponseTypeIdToken: - dbResponseTypes = append(dbResponseTypes, database.ResponseTypeIDToken) case ResponseTypeCodeIdToken: dbResponseTypes = append(dbResponseTypes, database.ResponseTypeCodeidToken) default: @@ -236,6 +253,101 @@ func mapEmptyString(str string) pgtype.Text { return pgtype.Text{String: strings.TrimSpace(str), Valid: true} } +func mapEmptyBigInt(bigInt int64) pgtype.Int8 { + if bigInt == 0 { + return pgtype.Int8{Valid: false} + } + + return pgtype.Int8{Int64: bigInt, Valid: true} +} + +func mapEmptySubjectType(subjectType string) (database.NullClientSubjectType, *exceptions.ServiceError) { + if subjectType == "" { + return database.NullClientSubjectType{Valid: false}, nil + } + + switch utils.Lowered(subjectType) { + case SubjectTypePublic: + return database.NullClientSubjectType{ClientSubjectType: database.ClientSubjectTypePublic, Valid: true}, nil + case SubjectTypePairwise: + return database.NullClientSubjectType{ClientSubjectType: database.ClientSubjectTypePairwise, Valid: true}, nil + default: + return database.NullClientSubjectType{Valid: false}, exceptions.NewValidationError("invalid subject type: " + subjectType) + } +} + +func mapEmptyTokenCryptoSuite(tokenCryptoSuite string) (database.NullTokenCryptoSuite, *exceptions.ServiceError) { + if tokenCryptoSuite == "" { + return database.NullTokenCryptoSuite{Valid: false}, nil + } + + cryptoSuite, err := mapCryptoSuite(utils.SupportedCryptoSuite(tokenCryptoSuite)) + if err != nil { + return database.NullTokenCryptoSuite{Valid: false}, exceptions.NewValidationError("invalid token crypto suite: " + tokenCryptoSuite) + } + + return database.NullTokenCryptoSuite{TokenCryptoSuite: cryptoSuite, Valid: true}, nil +} + +func mapTokenCryptoSuiteWithDefault(tokenCryptoSuite string) (database.TokenCryptoSuite, *exceptions.ServiceError) { + if tokenCryptoSuite == "" { + return database.TokenCryptoSuiteES256, nil + } + + cryptoSuite, err := mapCryptoSuite(utils.SupportedCryptoSuite(tokenCryptoSuite)) + if err != nil { + return "", exceptions.NewValidationError("invalid token crypto suite: " + tokenCryptoSuite) + } + + return cryptoSuite, nil +} + +func mapEmptyTokenEncryptionAlgorithm(tokenEncryptionAlgorithm string) (database.NullTokenEncryptionAlgorithm, *exceptions.ServiceError) { + if tokenEncryptionAlgorithm == "" { + return database.NullTokenEncryptionAlgorithm{Valid: false}, nil + } + + switch utils.Lowered(tokenEncryptionAlgorithm) { + case TokenEncryptionAlgorithmRSAOAEP256: + return database.NullTokenEncryptionAlgorithm{TokenEncryptionAlgorithm: database.TokenEncryptionAlgorithmRSAOAEP256, Valid: true}, nil + case TokenEncryptionAlgorithmECDHES: + return database.NullTokenEncryptionAlgorithm{TokenEncryptionAlgorithm: database.TokenEncryptionAlgorithmECDHES, Valid: true}, nil + case TokenEncryptionAlgorithmECDHESA256KW: + return database.NullTokenEncryptionAlgorithm{TokenEncryptionAlgorithm: database.TokenEncryptionAlgorithmECDHESA256KW, Valid: true}, nil + default: + return database.NullTokenEncryptionAlgorithm{Valid: false}, exceptions.NewValidationError("invalid token encryption algorithm: " + tokenEncryptionAlgorithm) + } +} + +func mapEmptyTokenEncryptionEncoding( + nullTokenEncryptionAlgorithm database.NullTokenEncryptionAlgorithm, + tokenEncryptionEncoding string, +) (database.NullTokenEncryptionEncoding, *exceptions.ServiceError) { + if !nullTokenEncryptionAlgorithm.Valid { + return database.NullTokenEncryptionEncoding{Valid: false}, nil + } + if tokenEncryptionEncoding == "" { + return database.NullTokenEncryptionEncoding{Valid: true, TokenEncryptionEncoding: database.TokenEncryptionEncodingA128CBCHS256}, nil + } + + switch utils.Lowered(tokenEncryptionEncoding) { + case TokenEncryptionEncodingA128CBCHS256: + return database.NullTokenEncryptionEncoding{Valid: true, TokenEncryptionEncoding: database.TokenEncryptionEncodingA128CBCHS256}, nil + case TokenEncryptionEncodingA192CBCHS384: + return database.NullTokenEncryptionEncoding{Valid: true, TokenEncryptionEncoding: database.TokenEncryptionEncodingA192CBCHS384}, nil + case TokenEncryptionEncodingA256CBCHS512: + return database.NullTokenEncryptionEncoding{Valid: true, TokenEncryptionEncoding: database.TokenEncryptionEncodingA256CBCHS512}, nil + case TokenEncryptionEncodingA128GCM: + return database.NullTokenEncryptionEncoding{Valid: true, TokenEncryptionEncoding: database.TokenEncryptionEncodingA128GCM}, nil + case TokenEncryptionEncodingA192GCM: + return database.NullTokenEncryptionEncoding{Valid: true, TokenEncryptionEncoding: database.TokenEncryptionEncodingA192GCM}, nil + case TokenEncryptionEncodingA256GCM: + return database.NullTokenEncryptionEncoding{Valid: true, TokenEncryptionEncoding: database.TokenEncryptionEncodingA256GCM}, nil + default: + return database.NullTokenEncryptionEncoding{Valid: false}, exceptions.NewValidationError("invalid token encryption encoding: " + tokenEncryptionEncoding) + } +} + type verifyTXTRecordOptions struct { requestID string host string @@ -271,3 +383,50 @@ func (s *Services) verifyTXTRecord( logger.InfoContext(ctx, "TXT code found in records") return nil } + +func mapEmptyJWKs(logger *slog.Logger, ctx context.Context, jsonJWKs []string) ([]byte, *exceptions.ServiceError) { + var jwks []byte + + if len(jsonJWKs) > 0 { + rawJWKs := make([]json.RawMessage, 0, len(jsonJWKs)) + for _, jwk := range jsonJWKs { + jwk, err := utils.JsonToJWK([]byte(jwk)) + if err != nil { + logger.ErrorContext(ctx, "Failed to parse JWK", "error", err) + return nil, exceptions.NewInternalServerError() + } + jwkBytes, err := jwk.MarshalJSON() + if err != nil { + logger.ErrorContext(ctx, "Failed to marshal JWK", "error", err) + return nil, exceptions.NewInternalServerError() + } + rawJWKs = append(rawJWKs, jwkBytes) + } + + var err error + jwks, err = json.Marshal(rawJWKs) + if err != nil { + logger.ErrorContext(ctx, "Failed to marshal JWKS", "error", err) + return nil, exceptions.NewInternalServerError() + } + } + + return jwks, nil +} + +func mapGrantType(grantType string) (database.GrantType, *exceptions.ServiceError) { + switch utils.Lowered(grantType) { + case GrantTypeAuthorizationCode: + return database.GrantTypeAuthorizationCode, nil + case GrantTypeRefreshToken: + return database.GrantTypeRefreshToken, nil + case GrantTypeClientCredentials: + return database.GrantTypeClientCredentials, nil + case GrantTypeDeviceCode: + return database.GrantTypeUrnIetfParamsOauthGrantTypeDeviceCode, nil + case GrantTypeJwtBearer: + return database.GrantTypeUrnIetfParamsOauthGrantTypeJwtBearer, nil + default: + return "", exceptions.NewValidationError("invalid grant type: " + grantType) + } +} diff --git a/idp/internal/services/oauth_dynamic_registration.go b/idp/internal/services/oauth_dynamic_registration.go index 0996752..6bc428c 100644 --- a/idp/internal/services/oauth_dynamic_registration.go +++ b/idp/internal/services/oauth_dynamic_registration.go @@ -151,7 +151,7 @@ func (s *Services) oauthDynamicRegistrationIATAuth( var count int64 if tldOneDomain != opts.domain { - count, err = s.database.CountVerifiedAccountDynamicRegistrationDomainsByDomains( + count, err = s.database.CountVerifiedDynamicRegistrationDomainsByDomains( ctx, []string{opts.domain, tldOneDomain}, ) @@ -164,7 +164,7 @@ func (s *Services) oauthDynamicRegistrationIATAuth( return "", exceptions.NewForbiddenError() } } else { - count, err = s.database.CountVerifiedAccountDynamicRegistrationDomainsByDomain(ctx, opts.domain) + count, err = s.database.CountVerifiedDynamicRegistrationDomainsByDomain(ctx, opts.domain) } if err != nil { logger.ErrorContext(ctx, "Failed to count account dynamic registration domains by domains", "error", err) @@ -1157,17 +1157,17 @@ func (s *Services) VerifyOAuthDynamicRegistrationIATCode( var count int64 if tldOneDomain != data.Domain { - count, err = s.database.CountVerifiedAccountDynamicRegistrationDomainsByDomainsAndAccountPublicID( + count, err = s.database.CountVerifiedDynamicRegistrationDomainsByDomainsAndAccountPublicID( ctx, - database.CountVerifiedAccountDynamicRegistrationDomainsByDomainsAndAccountPublicIDParams{ + database.CountVerifiedDynamicRegistrationDomainsByDomainsAndAccountPublicIDParams{ AccountPublicID: accountDTO.PublicID, Domains: []string{data.Domain, tldOneDomain}, }, ) } else { - count, err = s.database.CountVerifiedAccountDynamicRegistrationDomainsByDomainAndAccountPublicID( + count, err = s.database.CountVerifiedDynamicRegistrationDomainsByDomainAndAccountPublicID( ctx, - database.CountVerifiedAccountDynamicRegistrationDomainsByDomainAndAccountPublicIDParams{ + database.CountVerifiedDynamicRegistrationDomainsByDomainAndAccountPublicIDParams{ AccountPublicID: accountDTO.PublicID, Domain: data.Domain, }, diff --git a/idp/internal/services/services.go b/idp/internal/services/services.go index d955264..b11d4d0 100644 --- a/idp/internal/services/services.go +++ b/idp/internal/services/services.go @@ -10,6 +10,8 @@ import ( "log/slog" "time" + "github.com/go-playground/validator/v10" + "github.com/tugascript/devlogs/idp/internal/providers/cache" "github.com/tugascript/devlogs/idp/internal/providers/crypto" "github.com/tugascript/devlogs/idp/internal/providers/database" @@ -27,6 +29,7 @@ type Services struct { jwt *tokens.Tokens crypto *crypto.Crypto oauthProviders *oauth.Providers + validate *validator.Validate kekExpDays time.Duration dekExpDays time.Duration jwkExpDays time.Duration @@ -46,6 +49,7 @@ func NewServices( jwt *tokens.Tokens, encrypt *crypto.Crypto, oauthProv *oauth.Providers, + validate *validator.Validate, kekExpDays int64, dekExpDays int64, jwkExpDays int64, @@ -64,6 +68,7 @@ func NewServices( jwt: jwt, crypto: encrypt, oauthProviders: oauthProv, + validate: validate, kekExpDays: utils.ToDaysDuration(kekExpDays), dekExpDays: utils.ToDaysDuration(dekExpDays), jwkExpDays: utils.ToDaysDuration(jwkExpDays), diff --git a/idp/internal/services/software_statement.go b/idp/internal/services/software_statement.go new file mode 100644 index 0000000..d6d771c --- /dev/null +++ b/idp/internal/services/software_statement.go @@ -0,0 +1,524 @@ +// Copyright (c) 2025 Afonso Barracha +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +package services + +import ( + "context" + "fmt" + "slices" + "strings" + "time" + + "github.com/golang-jwt/jwt/v5" + + "github.com/tugascript/devlogs/idp/internal/exceptions" + "github.com/tugascript/devlogs/idp/internal/providers/tokens" + "github.com/tugascript/devlogs/idp/internal/utils" +) + +type ApplicationRegistrationData struct { + RedirectURIs []string + TokenEndpointAuthMethod string + ResponseTypes []string + GrantTypes []string + ApplicationType string + ClientName string + ClientURI string + LogoURI string + Scope string + Contacts []string + TOSURI string + PolicyURI string + JWKsURI string + JWKs []string + SoftwareID string + SoftwareVersion string + SubjectType string + SectorIdentifierURI string + DefaultMaxAge int64 + RequireAuthTime bool + DefaultACRValues []string + InitiateLoginURI string + RequestURIs []string + IDTokenSignedResponseAlg string + IDTokenEncryptedResponseAlg string + IDTokenEncryptedResponseEnc string + UserInfoSignedResponseAlg string + UserInfoEncryptedResponseAlg string + UserInfoEncryptedResponseEnc string + RequestObjectSigningAlg string + RequestObjectEncryptionAlg string + RequestObjectEncryptionEnc string + TokenEndpointAuthSigningAlg string + AccessTokenSigningAlg string +} + +type verifySoftwareStatementSTDClaimsOptions struct { + requestID string + backendDomain string + frontendDomain string + domain string + baseDomain string + claims *jwt.RegisteredClaims +} + +func (s *Services) verifySoftwareStatementSTDClaims( + ctx context.Context, + opts verifySoftwareStatementSTDClaimsOptions, +) *exceptions.ServiceError { + logger := s.buildLogger(opts.requestID, accountCredentialsRegistrationLocation, "verifySoftwareStatementSTDClaims").With( + "domain", opts.domain, + "baseDomain", opts.baseDomain, + ) + logger.InfoContext(ctx, "Verifying software statement standard claims") + + if opts.claims.Issuer != fmt.Sprintf("https://%s", opts.baseDomain) && + opts.claims.Issuer != fmt.Sprintf("https://%s", opts.domain) { + logger.WarnContext(ctx, "Software statement issuer does not match client URI domain or base domain", + "issuer", opts.claims.Issuer, + ) + return exceptions.NewUnauthorizedError() + } + if opts.claims.Audience == nil || !slices.ContainsFunc(opts.claims.Audience, func(aud string) bool { + return aud == fmt.Sprintf("https://%s", opts.frontendDomain) || aud == fmt.Sprintf("https://%s", opts.backendDomain) + }) { + logger.WarnContext(ctx, "Software statement audience does not match frontend or backend domain", + "audience", opts.claims.Audience, + ) + return exceptions.NewUnauthorizedError() + } + if opts.claims.IssuedAt == nil || opts.claims.IssuedAt.Time.IsZero() || opts.claims.IssuedAt.Time.After(time.Now()) { + logger.WarnContext(ctx, "Software statement issued at claim is invalid", + "issuedAt", opts.claims.IssuedAt, + ) + return exceptions.NewUnauthorizedError() + } + if opts.claims.NotBefore != nil && !opts.claims.NotBefore.Time.IsZero() && opts.claims.NotBefore.Time.After(time.Now()) { + logger.WarnContext(ctx, "Software statement not before claim is invalid", + "notBefore", opts.claims.NotBefore, + ) + return exceptions.NewUnauthorizedError() + } + if opts.claims.ExpiresAt == nil || opts.claims.ExpiresAt.Time.IsZero() || opts.claims.ExpiresAt.Time.Before(time.Now()) { + logger.WarnContext(ctx, "Software statement expiration claim is invalid", + "expiresAt", opts.claims.ExpiresAt, + ) + return exceptions.NewUnauthorizedError() + } + + logger.InfoContext(ctx, "Verified software statement standard claims") + return nil +} + +type verifySoftwareStatementClaimsOptions struct { + requestID string + clientName string + clientURI string + logoURI string + tosURI string + policyURI string + contacts []string + softwareID string + softwareVersion string + jwksURI string + jwks []string + claims *tokens.SoftwareStatementClaims +} + +func (s *Services) verifySoftwareStatementClaims( + ctx context.Context, + opts verifySoftwareStatementClaimsOptions, +) error { + logger := s.buildLogger( + opts.requestID, + accountCredentialsRegistrationLocation, + "verifySoftwareStatementClaims", + ) + + if opts.claims.ClientName != opts.clientName { + logger.WarnContext(ctx, "Client name in software statement does not match", + "expected", opts.clientName, "got", opts.claims.ClientName, + ) + return exceptions.NewUnauthorizedError() + } + if opts.claims.ClientURI != opts.clientURI { + logger.WarnContext(ctx, "Client URI in software statement does not match", + "expected", opts.clientURI, "got", opts.claims.ClientURI, + ) + return exceptions.NewUnauthorizedError() + } + + if opts.claims.LogoURI != "" && opts.logoURI != "" && opts.claims.LogoURI != opts.logoURI { + logger.WarnContext(ctx, "Logo URI in software statement does not match", + "expected", opts.logoURI, "got", opts.claims.LogoURI, + ) + return exceptions.NewUnauthorizedError() + } + if opts.claims.TOSURI != "" && opts.tosURI != "" && opts.claims.TOSURI != opts.tosURI { + logger.WarnContext(ctx, "Terms of Service URI in software statement does not match", + "expected", opts.tosURI, "got", opts.claims.TOSURI, + ) + return exceptions.NewUnauthorizedError() + } + if opts.claims.PolicyURI != "" && opts.policyURI != "" && opts.claims.PolicyURI != opts.policyURI { + logger.WarnContext(ctx, "Policy URI in software statement does not match", + "expected", opts.policyURI, "got", opts.claims.PolicyURI, + ) + return exceptions.NewUnauthorizedError() + } + if opts.claims.SoftwareID != "" && opts.softwareID != "" && opts.claims.SoftwareID != opts.softwareID { + logger.WarnContext(ctx, "Software id in software statement does not match", + "expected", opts.softwareID, "got", opts.claims.SoftwareID, + ) + return exceptions.NewUnauthorizedError() + } + if opts.claims.SoftwareVersion != "" && opts.softwareVersion != "" && opts.claims.SoftwareVersion != opts.softwareVersion { + logger.WarnContext(ctx, "Software version in software statement does not match", + "expected", opts.softwareVersion, "got", opts.claims.SoftwareVersion, + ) + return exceptions.NewUnauthorizedError() + } + if opts.claims.JWKsURI != "" && opts.jwksURI != "" && opts.claims.JWKsURI != opts.jwksURI { + logger.WarnContext(ctx, "JWKs URI in software statement does not match", "expected", + opts.jwksURI, "got", opts.claims.JWKsURI, + ) + return exceptions.NewUnauthorizedError() + } + + if len(opts.claims.Contacts) > 0 && len(opts.contacts) > 0 { + claimsSet := utils.SliceToHashSet(opts.claims.Contacts) + for _, c := range opts.contacts { + if !claimsSet.Contains(c) { + logger.WarnContext(ctx, "Contact in registration not present in software statement", "contact", c) + return exceptions.NewUnauthorizedError() + } + } + } + + logger.InfoContext(ctx, "Verified software statement registration claims") + return nil +} + +// validateEncryptionAlgorithmPair validates that encryption algorithm and encoding are both set or both unset +func validateEncryptionAlgorithmPair(alg, enc string) bool { + if enc != "" && alg == "" { + return false + } + + return true +} + +type validateSoftwareStatementClaimsOptions struct { + requestID string + claims *tokens.SoftwareStatementClaims + data *ApplicationRegistrationData + allowedScopes utils.HashSet[string] +} + +func (s *Services) validateSoftwareStatementClaims( + ctx context.Context, + opts validateSoftwareStatementClaimsOptions, +) *exceptions.ServiceError { + logger := s.buildLogger( + opts.requestID, + accountCredentialsRegistrationLocation, + "validateSoftwareStatementClaims", + ) + logger.InfoContext(ctx, "Validating software statement claims") + + if err := s.validate.StructCtx(ctx, opts.claims); err != nil { + logger.WarnContext(ctx, "Invalid software statement claims", "error", err) + return exceptions.NewValidationError("Invalid software statement claims") + } + + if opts.claims.Scope != "" { + scopes := strings.Fields(opts.claims.Scope) + if len(scopes) == 0 { + logger.WarnContext(ctx, "Invalid scope format in software statement") + return exceptions.NewValidationError("invalid scope") + } + for _, scope := range scopes { + if !opts.allowedScopes.Contains(scope) { + logger.WarnContext(ctx, "Invalid scope in software statement", "scope", scope) + return exceptions.NewValidationError("invalid scope") + } + } + + scopesSet := utils.SliceToHashSet(scopes) + if scopesSet.Size() != len(scopes) { + logger.WarnContext(ctx, "Duplicate scopes in software statement", "scopes", scopes) + return exceptions.NewValidationError("duplicate scopes") + } + + dataScopes := strings.Fields(opts.data.Scope) + if len(dataScopes) != scopesSet.Size() { + logger.WarnContext(ctx, "Scope count mismatch", "expected", len(dataScopes), "got", scopesSet.Size()) + return exceptions.NewValidationError("scope count mismatch") + } + + for _, scope := range dataScopes { + if !scopesSet.Contains(scope) { + logger.WarnContext(ctx, "Scope mismatch", "expected", scope, "got", scopesSet.Contains(scope)) + return exceptions.NewValidationError("scope mismatch") + } + } + } + + if len(opts.claims.JWKs) > 0 { + jwks := make([]utils.JWK, len(opts.claims.JWKs)) + indexMap := make(map[string]int) + for i, rawJWK := range opts.claims.JWKs { + jwk, err := utils.JsonToJWK([]byte(rawJWK)) + if err != nil { + logger.WarnContext(ctx, "Invalid JWK JSON in software statement", "error", err) + return exceptions.NewValidationError("invalid jwks") + } + jwks[i] = jwk + indexMap[jwk.GetKeyID()] = i + } + + if len(opts.data.JWKs) > 0 { + if len(jwks) != len(opts.data.JWKs) { + logger.WarnContext(ctx, "JWK count mismatch", "expected", len(opts.data.JWKs), "got", len(jwks)) + return exceptions.NewValidationError("jwk count mismatch") + } + + for _, rawJWK := range opts.data.JWKs { + jwk, err := utils.JsonToJWK([]byte(rawJWK)) + if err != nil { + logger.WarnContext(ctx, "Invalid JWK JSON in software statement", "error", err) + return exceptions.NewValidationError("invalid jwks") + } + + index, ok := indexMap[jwk.GetKeyID()] + if !ok { + logger.WarnContext(ctx, "JWK not found in software statement", "jwk", jwk.GetKeyID()) + return exceptions.NewValidationError("jwk not found in software statement") + } + if jwks[index].ComparePublicKey(jwk) { + logger.WarnContext(ctx, "JWK mismatch", "expected", jwks[index].GetKeyID(), "got", jwk.GetKeyID()) + return exceptions.NewValidationError("jwk mismatch") + } + } + } + } + + if !validateEncryptionAlgorithmPair(opts.claims.IDTokenEncryptedResponseAlg, opts.claims.IDTokenEncryptedResponseEnc) { + logger.WarnContext(ctx, "id_token encryption algorithm and encoding must both be set or both be unset") + return exceptions.NewValidationError("id_token encryption algorithm and encoding mismatch") + } + + if !validateEncryptionAlgorithmPair(opts.claims.UserInfoEncryptedResponseAlg, opts.claims.UserInfoEncryptedResponseEnc) { + logger.WarnContext(ctx, "userinfo encryption algorithm and encoding must both be set or both be unset") + return exceptions.NewValidationError("userinfo encryption algorithm and encoding mismatch") + } + + if !validateEncryptionAlgorithmPair(opts.claims.RequestObjectEncryptionAlg, opts.claims.RequestObjectEncryptionEnc) { + logger.WarnContext(ctx, "request_object encryption algorithm and encoding must both be set or both be unset") + return exceptions.NewValidationError("request_object encryption algorithm and encoding mismatch") + } + + if len(opts.data.RedirectURIs) > 0 && len(opts.claims.RedirectURIs) > 0 { + if len(opts.data.RedirectURIs) != len(opts.claims.RedirectURIs) { + logger.WarnContext(ctx, "Redirect URI count mismatch", "expected", len(opts.data.RedirectURIs), "got", len(opts.claims.RedirectURIs)) + return exceptions.NewValidationError("redirect URI count mismatch") + } + + redirectURIsSet := utils.SliceToHashSet(opts.claims.RedirectURIs) + if redirectURIsSet.Size() != len(opts.claims.RedirectURIs) { + logger.WarnContext(ctx, "Duplicate redirect URIs in software statement", "redirectURIs", opts.claims.RedirectURIs) + return exceptions.NewValidationError("duplicate redirect URIs") + } + + for _, redirectURI := range opts.data.RedirectURIs { + if !redirectURIsSet.Contains(redirectURI) { + logger.WarnContext(ctx, "Redirect URI not found in software statement", "redirectURI", redirectURI) + return exceptions.NewValidationError("redirect URI not found in software statement") + } + } + } + + if opts.claims.TokenEndpointAuthMethod != "" && opts.data.TokenEndpointAuthMethod != "" && opts.claims.TokenEndpointAuthMethod != opts.data.TokenEndpointAuthMethod { + logger.WarnContext(ctx, "Token endpoint auth method mismatch", "expected", opts.data.TokenEndpointAuthMethod, "got", opts.claims.TokenEndpointAuthMethod) + return exceptions.NewValidationError("token endpoint auth method mismatch") + } + + if len(opts.claims.ResponseTypes) > 0 && len(opts.data.ResponseTypes) > 0 { + if len(opts.claims.ResponseTypes) != len(opts.data.ResponseTypes) { + logger.WarnContext(ctx, "Response type count mismatch", "expected", len(opts.data.ResponseTypes), "got", len(opts.claims.ResponseTypes)) + return exceptions.NewValidationError("response type count mismatch") + } + + responseTypesSet := utils.SliceToHashSet(opts.claims.ResponseTypes) + if responseTypesSet.Size() != len(opts.claims.ResponseTypes) { + logger.WarnContext(ctx, "Duplicate response types in software statement", "responseTypes", opts.claims.ResponseTypes) + return exceptions.NewValidationError("duplicate response types") + } + + for _, responseType := range opts.data.ResponseTypes { + if !responseTypesSet.Contains(responseType) { + logger.WarnContext(ctx, "Response type not found in software statement", "responseType", responseType) + return exceptions.NewValidationError("response type not found in software statement") + } + } + } + + if len(opts.claims.GrantTypes) > 0 && len(opts.data.GrantTypes) > 0 { + if len(opts.claims.GrantTypes) != len(opts.data.GrantTypes) { + logger.WarnContext(ctx, "Grant type count mismatch", "expected", len(opts.data.GrantTypes), "got", len(opts.claims.GrantTypes)) + return exceptions.NewValidationError("grant type count mismatch") + } + + grantTypesSet := utils.SliceToHashSet(opts.claims.GrantTypes) + if grantTypesSet.Size() != len(opts.claims.GrantTypes) { + logger.WarnContext(ctx, "Duplicate grant types in software statement", "grantTypes", opts.claims.GrantTypes) + return exceptions.NewValidationError("duplicate grant types") + } + + for _, grantType := range opts.data.GrantTypes { + if !grantTypesSet.Contains(grantType) { + logger.WarnContext(ctx, "Grant type not found in software statement", "grantType", grantType) + return exceptions.NewValidationError("grant type not found in software statement") + } + } + } + + if opts.claims.ApplicationType != "" && opts.data.ApplicationType != "" && opts.claims.ApplicationType != opts.data.ApplicationType { + logger.WarnContext(ctx, "Application type mismatch", "expected", opts.data.ApplicationType, "got", opts.claims.ApplicationType) + return exceptions.NewValidationError("application type mismatch") + } + if opts.claims.ClientName != "" && opts.data.ClientName != "" && opts.claims.ClientName != opts.data.ClientName { + logger.WarnContext(ctx, "Client name mismatch", "expected", opts.data.ClientName, "got", opts.claims.ClientName) + return exceptions.NewValidationError("client name mismatch") + } + if opts.claims.ClientURI != "" && opts.data.ClientURI != "" && opts.claims.ClientURI != opts.data.ClientURI { + logger.WarnContext(ctx, "Client URI mismatch", "expected", opts.data.ClientURI, "got", opts.claims.ClientURI) + return exceptions.NewValidationError("client URI mismatch") + } + if opts.claims.LogoURI != "" && opts.data.LogoURI != "" && opts.claims.LogoURI != opts.data.LogoURI { + logger.WarnContext(ctx, "Logo URI mismatch", "expected", opts.data.LogoURI, "got", opts.claims.LogoURI) + return exceptions.NewValidationError("logo URI mismatch") + } + if opts.claims.TOSURI != "" && opts.data.TOSURI != "" && opts.claims.TOSURI != opts.data.TOSURI { + logger.WarnContext(ctx, "Terms of Service URI mismatch", "expected", opts.data.TOSURI, "got", opts.claims.TOSURI) + return exceptions.NewValidationError("terms of service URI mismatch") + } + if opts.claims.PolicyURI != "" && opts.data.PolicyURI != "" && opts.claims.PolicyURI != opts.data.PolicyURI { + logger.WarnContext(ctx, "Policy URI mismatch", "expected", opts.data.PolicyURI, "got", opts.claims.PolicyURI) + return exceptions.NewValidationError("policy URI mismatch") + } + if opts.claims.SoftwareID != "" && opts.data.SoftwareID != "" && opts.claims.SoftwareID != opts.data.SoftwareID { + logger.WarnContext(ctx, "Software ID mismatch", "expected", opts.data.SoftwareID, "got", opts.claims.SoftwareID) + return exceptions.NewValidationError("software ID mismatch") + } + if opts.claims.SoftwareVersion != "" && opts.data.SoftwareVersion != "" && opts.claims.SoftwareVersion != opts.data.SoftwareVersion { + logger.WarnContext(ctx, "Software version mismatch", "expected", opts.data.SoftwareVersion, "got", opts.claims.SoftwareVersion) + return exceptions.NewValidationError("software version mismatch") + } + if opts.claims.SubjectType != "" && opts.data.SubjectType != "" && opts.claims.SubjectType != opts.data.SubjectType { + logger.WarnContext(ctx, "Subject type mismatch", "expected", opts.data.SubjectType, "got", opts.claims.SubjectType) + return exceptions.NewValidationError("subject type mismatch") + } + if opts.claims.SectorIdentifierURI != "" && opts.data.SectorIdentifierURI != "" && opts.claims.SectorIdentifierURI != opts.data.SectorIdentifierURI { + logger.WarnContext(ctx, "Sector identifier URI mismatch", "expected", opts.data.SectorIdentifierURI, "got", opts.claims.SectorIdentifierURI) + return exceptions.NewValidationError("sector identifier URI mismatch") + } + if opts.claims.DefaultMaxAge != 0 && opts.data.DefaultMaxAge != 0 && opts.claims.DefaultMaxAge != opts.data.DefaultMaxAge { + logger.WarnContext(ctx, "Default max age mismatch", "expected", opts.data.DefaultMaxAge, "got", opts.claims.DefaultMaxAge) + return exceptions.NewValidationError("default max age mismatch") + } + if opts.claims.RequireAuthTime != false && opts.data.RequireAuthTime != false && opts.claims.RequireAuthTime != opts.data.RequireAuthTime { + logger.WarnContext(ctx, "Require auth time mismatch", "expected", opts.data.RequireAuthTime, "got", opts.claims.RequireAuthTime) + return exceptions.NewValidationError("require auth time mismatch") + } + if len(opts.claims.DefaultACRValues) > 0 && len(opts.data.DefaultACRValues) > 0 { + if len(opts.claims.DefaultACRValues) != len(opts.data.DefaultACRValues) { + logger.WarnContext(ctx, "Default ACR value count mismatch", "expected", len(opts.data.DefaultACRValues), "got", len(opts.claims.DefaultACRValues)) + return exceptions.NewValidationError("default ACR value count mismatch") + } + defaultACRValuesSet := utils.SliceToHashSet(opts.claims.DefaultACRValues) + if defaultACRValuesSet.Size() != len(opts.claims.DefaultACRValues) { + logger.WarnContext(ctx, "Duplicate default ACR values in software statement", "defaultACRValues", opts.claims.DefaultACRValues) + return exceptions.NewValidationError("duplicate default ACR values") + } + for _, defaultACRValue := range opts.data.DefaultACRValues { + if !defaultACRValuesSet.Contains(defaultACRValue) { + logger.WarnContext(ctx, "Default ACR value not found in software statement", "defaultACRValue", defaultACRValue) + return exceptions.NewValidationError("default ACR value not found in software statement") + } + } + } + if opts.claims.InitiateLoginURI != "" && opts.data.InitiateLoginURI != "" && opts.claims.InitiateLoginURI != opts.data.InitiateLoginURI { + logger.WarnContext(ctx, "Initiate login URI mismatch", "expected", opts.data.InitiateLoginURI, "got", opts.claims.InitiateLoginURI) + return exceptions.NewValidationError("initiate login URI mismatch") + } + if len(opts.claims.RequestURIs) > 0 && len(opts.data.RequestURIs) > 0 { + if len(opts.claims.RequestURIs) != len(opts.data.RequestURIs) { + logger.WarnContext(ctx, "Request URI count mismatch", "expected", len(opts.data.RequestURIs), "got", len(opts.claims.RequestURIs)) + return exceptions.NewValidationError("request URI count mismatch") + } + + requestURIsSet := utils.SliceToHashSet(opts.claims.RequestURIs) + if requestURIsSet.Size() != len(opts.claims.RequestURIs) { + logger.WarnContext(ctx, "Duplicate request URIs in software statement", "requestURIs", opts.claims.RequestURIs) + return exceptions.NewValidationError("duplicate request URIs") + } + for _, requestURI := range opts.data.RequestURIs { + if !requestURIsSet.Contains(requestURI) { + logger.WarnContext(ctx, "Request URI not found in software statement", "requestURI", requestURI) + return exceptions.NewValidationError("request URI not found in software statement") + } + } + } + if opts.claims.IDTokenSignedResponseAlg != "" && opts.data.IDTokenSignedResponseAlg != "" && opts.claims.IDTokenSignedResponseAlg != opts.data.IDTokenSignedResponseAlg { + logger.WarnContext(ctx, "ID token signed response algorithm mismatch", "expected", opts.data.IDTokenSignedResponseAlg, "got", opts.claims.IDTokenSignedResponseAlg) + return exceptions.NewValidationError("id token signed response algorithm mismatch") + } + if opts.claims.IDTokenEncryptedResponseAlg != "" && opts.data.IDTokenEncryptedResponseAlg != "" && opts.claims.IDTokenEncryptedResponseAlg != opts.data.IDTokenEncryptedResponseAlg { + logger.WarnContext(ctx, "ID token encrypted response algorithm mismatch", "expected", opts.data.IDTokenEncryptedResponseAlg, "got", opts.claims.IDTokenEncryptedResponseAlg) + return exceptions.NewValidationError("id token encrypted response algorithm mismatch") + } + if opts.claims.IDTokenEncryptedResponseEnc != "" && opts.data.IDTokenEncryptedResponseEnc != "" && opts.claims.IDTokenEncryptedResponseEnc != opts.data.IDTokenEncryptedResponseEnc { + logger.WarnContext(ctx, "ID token encrypted response encoding mismatch", "expected", opts.data.IDTokenEncryptedResponseEnc, "got", opts.claims.IDTokenEncryptedResponseEnc) + return exceptions.NewValidationError("id token encrypted response encoding mismatch") + } + if opts.claims.UserInfoSignedResponseAlg != "" && opts.data.UserInfoSignedResponseAlg != "" && opts.claims.UserInfoSignedResponseAlg != opts.data.UserInfoSignedResponseAlg { + logger.WarnContext(ctx, "User info signed response algorithm mismatch", "expected", opts.data.UserInfoSignedResponseAlg, "got", opts.claims.UserInfoSignedResponseAlg) + return exceptions.NewValidationError("user info signed response algorithm mismatch") + } + if opts.claims.UserInfoEncryptedResponseAlg != "" && opts.data.UserInfoEncryptedResponseAlg != "" && opts.claims.UserInfoEncryptedResponseAlg != opts.data.UserInfoEncryptedResponseAlg { + logger.WarnContext(ctx, "User info encrypted response algorithm mismatch", "expected", opts.data.UserInfoEncryptedResponseAlg, "got", opts.claims.UserInfoEncryptedResponseAlg) + return exceptions.NewValidationError("user info encrypted response algorithm mismatch") + } + if opts.claims.UserInfoEncryptedResponseEnc != "" && opts.data.UserInfoEncryptedResponseEnc != "" && opts.claims.UserInfoEncryptedResponseEnc != opts.data.UserInfoEncryptedResponseEnc { + logger.WarnContext(ctx, "User info encrypted response encoding mismatch", "expected", opts.data.UserInfoEncryptedResponseEnc, "got", opts.claims.UserInfoEncryptedResponseEnc) + return exceptions.NewValidationError("user info encrypted response encoding mismatch") + } + if opts.claims.RequestObjectSigningAlg != "" && opts.data.RequestObjectSigningAlg != "" && opts.claims.RequestObjectSigningAlg != opts.data.RequestObjectSigningAlg { + logger.WarnContext(ctx, "Request object signed response algorithm mismatch", "expected", opts.data.RequestObjectSigningAlg, "got", opts.claims.RequestObjectSigningAlg) + return exceptions.NewValidationError("request object signed response algorithm mismatch") + } + if opts.claims.RequestObjectEncryptionAlg != "" && opts.data.RequestObjectEncryptionAlg != "" && opts.claims.RequestObjectEncryptionAlg != opts.data.RequestObjectEncryptionAlg { + logger.WarnContext(ctx, "Request object encrypted response algorithm mismatch", "expected", opts.data.RequestObjectEncryptionAlg, "got", opts.claims.RequestObjectEncryptionAlg) + return exceptions.NewValidationError("request object encrypted response algorithm mismatch") + } + if opts.claims.RequestObjectEncryptionEnc != "" && opts.data.RequestObjectEncryptionEnc != "" && opts.claims.RequestObjectEncryptionEnc != opts.data.RequestObjectEncryptionEnc { + logger.WarnContext(ctx, "Request object encrypted response encoding mismatch", "expected", opts.data.RequestObjectEncryptionEnc, "got", opts.claims.RequestObjectEncryptionEnc) + return exceptions.NewValidationError("request object encrypted response encoding mismatch") + } + if opts.claims.TokenEndpointAuthSigningAlg != "" && opts.data.TokenEndpointAuthSigningAlg != "" && opts.claims.TokenEndpointAuthSigningAlg != opts.data.TokenEndpointAuthSigningAlg { + logger.WarnContext(ctx, "Token endpoint auth signing algorithm mismatch", "expected", opts.data.TokenEndpointAuthSigningAlg, "got", opts.claims.TokenEndpointAuthSigningAlg) + return exceptions.NewValidationError("token endpoint auth signing algorithm mismatch") + } + if opts.claims.AccessTokenSigningAlg != "" && opts.data.AccessTokenSigningAlg != "" && opts.claims.AccessTokenSigningAlg != opts.data.AccessTokenSigningAlg { + logger.WarnContext(ctx, "Access token signing algorithm mismatch", "expected", opts.data.AccessTokenSigningAlg, "got", opts.claims.AccessTokenSigningAlg) + return exceptions.NewValidationError("access token signing algorithm mismatch") + } + + logger.InfoContext(ctx, "Validated software statement claims") + return nil +} diff --git a/idp/internal/services/users_auth.go b/idp/internal/services/users_auth.go index e15a36a..b56c68e 100644 --- a/idp/internal/services/users_auth.go +++ b/idp/internal/services/users_auth.go @@ -283,7 +283,7 @@ func (s *Services) RegisterUser( accountUsername: opts.AccountUsername, appVersion: appDTO.Version(), appClientID: appDTO.ClientID, - appName: appDTO.Name, + appName: appDTO.ClientName, // TODO: add from app type appConfirmationURI: appDTO.ConfirmationURI, }); serviceErr != nil { return dtos.MessageDTO{}, serviceErr @@ -632,7 +632,7 @@ func (s *Services) LoginUser( accountUsername: opts.AccountUsername, appVersion: appDTO.Version(), appClientID: appDTO.ClientID, - appName: appDTO.Name, + appName: appDTO.ClientName, // TODO: add from app type appConfirmationURI: appDTO.ConfirmationURI, }); serviceErr != nil { return dtos.AuthDTO{}, serviceErr @@ -1045,7 +1045,7 @@ func (s *Services) ForgotUserPassword( if err := s.mail.PublishUserResetEmail(ctx, mailer.UserResetEmailOptions{ RequestID: opts.RequestID, - AppName: appDTO.Name, + AppName: appDTO.ClientName, Email: userDTO.Email, ResetToken: signedResetToken, // TODO: add from app type ResetURI: appDTO.ResetURI, diff --git a/idp/internal/utils/jwk.go b/idp/internal/utils/jwk.go index 2eea180..7fb5700 100644 --- a/idp/internal/utils/jwk.go +++ b/idp/internal/utils/jwk.go @@ -27,6 +27,7 @@ const ( SupportedCryptoSuiteEd25519 SupportedCryptoSuite = "EdDSA" SupportedCryptoSuiteES256 SupportedCryptoSuite = "ES256" SupportedCryptoSuiteHS256 SupportedCryptoSuite = "HS256" + SupportedCryptoSuiteRS256 SupportedCryptoSuite = "RS256" ) func GetSupportedCryptoSuite(cryptoSuite string) (SupportedCryptoSuite, error) { @@ -46,6 +47,7 @@ type JWK interface { ToUsableKey() (any, error) MarshalJSON() ([]byte, error) ToPrivateKey() (any, error) + ComparePublicKey(other JWK) bool } type Ed25519JWK struct { @@ -79,6 +81,15 @@ func (j *Ed25519JWK) ToPrivateKey() (any, error) { return DecodeEd25519JwkPrivate(j) } +func (j *Ed25519JWK) ComparePublicKey(other JWK) bool { + otherEdJwk, ok := other.(*Ed25519JWK) + if !ok { + return false + } + + return otherEdJwk.X == j.X && otherEdJwk.Kty == j.Kty && otherEdJwk.Crv == j.Crv && otherEdJwk.Alg == j.Alg +} + type ES256JWK struct { Kty string `json:"kty"` // Key Type (EC for Elliptic Curve) Crv string `json:"crv"` // Curve (P-256) @@ -111,6 +122,16 @@ func (j *ES256JWK) ToPrivateKey() (any, error) { return DecodeP256JwkPrivate(j) } +func (j *ES256JWK) ComparePublicKey(other JWK) bool { + otherESJwk, ok := other.(*ES256JWK) + if !ok { + return false + } + + return otherESJwk.X == j.X && otherESJwk.Y == j.Y && otherESJwk.Kty == j.Kty && + otherESJwk.Crv == j.Crv && otherESJwk.Alg == j.Alg +} + type RS256JWK struct { Kty string `json:"kty"` Kid string `json:"kid"` @@ -121,6 +142,35 @@ type RS256JWK struct { KeyOps []string `json:"key_ops,omitempty"` } +func (j *RS256JWK) ComparePublicKey(other JWK) bool { + otherRSJwk, ok := other.(*RS256JWK) + if !ok { + return false + } + + return otherRSJwk.N == j.N && otherRSJwk.E == j.E && otherRSJwk.Kty == j.Kty && otherRSJwk.Alg == j.Alg +} + +func (j *RS256JWK) GetKeyType() string { + return j.Kty +} + +func (j *RS256JWK) GetKeyID() string { + return j.Kid +} + +func (j *RS256JWK) ToUsableKey() (any, error) { + return DecodeRS256Jwk(j) +} + +func (j *RS256JWK) MarshalJSON() ([]byte, error) { + return json.Marshal(*j) +} + +func (j *RS256JWK) ToPrivateKey() (any, error) { + return nil, fmt.Errorf("not implemented") +} + const ( okpKty string = "OKP" ed25519Crv string = "Ed25519" @@ -132,6 +182,8 @@ const ( alg string = "EdDSA" verify string = "verify" sign string = "sign" + + rsaKty string = "RSA" ) func bigIntToPaddedBytes(n *big.Int, length int) []byte { @@ -342,6 +394,12 @@ func JsonToJWK(jsonBytes []byte) (JWK, error) { return nil, err } return &jwk, nil + case rsaKty: + var jwk RS256JWK + if err := json.Unmarshal(jsonBytes, &jwk); err != nil { + return nil, err + } + return &jwk, nil default: return nil, fmt.Errorf("unsupported key type: %s", kty) } diff --git a/idp/tests/account_credentials_test.go b/idp/tests/account_credentials_test.go index 654ab98..c214afd 100644 --- a/idp/tests/account_credentials_test.go +++ b/idp/tests/account_credentials_test.go @@ -423,7 +423,7 @@ func TestUpdateAccountCredentials(t *testing.T) { ExpStatus: http.StatusOK, AssertFn: func(t *testing.T, _ bodies.UpdateAccountCredentialsBody, res *http.Response) { resBody := AssertTestResponseBody(t, res, dtos.AccountCredentialsDTO{}) - AssertEqual(t, resBody.Name, "updated-service-name") + AssertEqual(t, resBody.ClientName, "updated-service-name") AssertEqual(t, len(resBody.Scopes), 1) AssertEqual(t, resBody.Scopes[0], "account:users:read") AssertEqual(t, resBody.SoftwareVersion, "2.0.0") @@ -467,7 +467,7 @@ func TestUpdateAccountCredentials(t *testing.T) { ExpStatus: http.StatusOK, AssertFn: func(t *testing.T, _ bodies.UpdateAccountCredentialsBody, res *http.Response) { resBody := AssertTestResponseBody(t, res, dtos.AccountCredentialsDTO{}) - AssertEqual(t, resBody.Name, "updated-mcp-name") + AssertEqual(t, resBody.ClientName, "updated-mcp-name") AssertEqual(t, len(resBody.Scopes), 2) AssertEqual(t, resBody.Scopes[0], "account:users:read") AssertEqual(t, resBody.Scopes[1], "account:apps:read") @@ -785,7 +785,7 @@ func TestGetSingleAccountCredentials(t *testing.T) { AssertFn: func(t *testing.T, _ any, res *http.Response) { resBody := AssertTestResponseBody(t, res, dtos.AccountCredentialsDTO{}) AssertNotEmpty(t, resBody.ClientID) - AssertNotEmpty(t, resBody.Name) + AssertNotEmpty(t, resBody.ClientName) AssertEqual(t, resBody.TokenEndpointAuthMethod, database.AuthMethodClientSecretBasic) AssertEmpty(t, resBody.ClientSecret) AssertEmpty(t, resBody.ClientSecretJWK) From 69b85ec89ae18aeb477c0d01244ecbdac0d07874 Mon Sep 17 00:00:00 2001 From: Afonso Barracha Date: Sun, 2 Nov 2025 22:25:16 +1300 Subject: [PATCH 21/23] feat(idp): add account dynamic registration route --- idp/internal/controllers/helpers.go | 31 +- idp/internal/controllers/middleware.go | 34 +- idp/internal/controllers/oauth.go | 6 +- .../controllers/oauth_dynamic_registration.go | 710 ++---------------- .../oauth_dynamic_registration_iat.go | 676 +++++++++++++++++ idp/internal/exceptions/controllers.go | 8 +- idp/internal/exceptions/services.go | 10 + .../providers/tokens/dynamic_registration.go | 63 -- .../tokens/dynamic_registration_iat.go | 106 +++ idp/internal/server/routes/oauth.go | 7 + .../account_credentials_registration.go | 39 +- .../account_credentials_registration_iat.go | 38 + .../services/dtos/account_credentials.go | 9 +- idp/internal/services/software_statement.go | 10 +- 14 files changed, 1016 insertions(+), 731 deletions(-) create mode 100644 idp/internal/controllers/oauth_dynamic_registration_iat.go delete mode 100644 idp/internal/providers/tokens/dynamic_registration.go create mode 100644 idp/internal/providers/tokens/dynamic_registration_iat.go diff --git a/idp/internal/controllers/helpers.go b/idp/internal/controllers/helpers.go index 3e3d1a4..abcf2ce 100644 --- a/idp/internal/controllers/helpers.go +++ b/idp/internal/controllers/helpers.go @@ -134,17 +134,21 @@ func oauthErrorResponse(logger *slog.Logger, ctx *fiber.Ctx, message string) err switch message { case exceptions.OAuthErrorInvalidRequest, exceptions.OAuthErrorInvalidGrant, - exceptions.OAuthErrorInvalidScope, exceptions.OAuthErrorUnsupportedGrantType: + exceptions.OAuthErrorInvalidScope, exceptions.OAuthErrorUnsupportedGrantType, + exceptions.OAuthErrorInvalidRedirectURI, exceptions.OAuthErrorInvalidClientMetadata, + exceptions.OAuthErrorInvalidSoftwareStatement, exceptions.OAuthErrorUnapprovedSoftwareStatement, + exceptions.OAuthErrorUnsupportedResponseType: logResponse(logger, ctx, fiber.StatusBadRequest) return ctx.Status(fiber.StatusBadRequest).JSON(&resErr) - case exceptions.OAuthErrorUnauthorizedClient, exceptions.OAuthErrorAccessDenied: + case exceptions.OAuthErrorUnauthorizedClient, exceptions.OAuthErrorAccessDenied, exceptions.OAuthErrorInvalidToken: logResponse(logger, ctx, fiber.StatusUnauthorized) return ctx.Status(fiber.StatusUnauthorized).JSON(&resErr) - case exceptions.OAuthServerError: + case exceptions.OAuthErrorServerError: logResponse(logger, ctx, fiber.StatusInternalServerError) return ctx.Status(fiber.StatusInternalServerError).JSON(&resErr) default: logResponse(logger, ctx, fiber.StatusBadRequest) + resErr = exceptions.NewOAuthError(exceptions.OAuthErrorInvalidRequest) return ctx.Status(fiber.StatusBadRequest).JSON(&resErr) } } @@ -187,6 +191,25 @@ func (c *Controllers) redirectServiceErrorCallback( case exceptions.CodeNotFound, exceptions.CodeValidation: return c.redirectErrorCallback(logger, ctx, redirectURI, state, exceptions.OAuthErrorInvalidRequest) default: - return c.redirectErrorCallback(logger, ctx, redirectURI, state, exceptions.OAuthServerError) + return c.redirectErrorCallback(logger, ctx, redirectURI, state, exceptions.OAuthErrorServerError) + } +} + +func dynamicRegistrationServiceError( + logger *slog.Logger, + ctx *fiber.Ctx, + serviceErr *exceptions.ServiceError, +) error { + switch serviceErr.Code { + case exceptions.CodeUnauthorized, exceptions.CodeForbidden: + return oauthErrorResponse(logger, ctx, exceptions.OAuthErrorUnauthorizedClient) + case exceptions.CodeNotFound, exceptions.CodeValidation: + return oauthErrorResponse(logger, ctx, exceptions.OAuthErrorInvalidClientMetadata) + case exceptions.CodeInvalidToken: + return oauthErrorResponse(logger, ctx, exceptions.OAuthErrorInvalidSoftwareStatement) + case exceptions.CodeUnauthorizedToken: + return oauthErrorResponse(logger, ctx, exceptions.OAuthErrorUnapprovedSoftwareStatement) + default: + return oauthErrorResponse(logger, ctx, exceptions.OAuthErrorServerError) } } diff --git a/idp/internal/controllers/middleware.go b/idp/internal/controllers/middleware.go index 7840ec5..fee05eb 100644 --- a/idp/internal/controllers/middleware.go +++ b/idp/internal/controllers/middleware.go @@ -135,7 +135,8 @@ func (c *Controllers) TwoFAAccessClaimsMiddleware(ctx *fiber.Ctx) error { } func (c *Controllers) AppAccessClaimsMiddleware(ctx *fiber.Ctx) error { - logger := c.buildLogger(getRequestID(ctx), middlewareLocation, "AppAccessClaimsMiddleware") + requestID := getRequestID(ctx) + logger := c.buildLogger(requestID, middlewareLocation, "AppAccessClaimsMiddleware") authHeader := ctx.Get("Authorization") if authHeader == "" { @@ -150,7 +151,7 @@ func (c *Controllers) AppAccessClaimsMiddleware(ctx *fiber.Ctx) error { appClaims, serviceErr := c.services.ProcessAppAuthHeader( ctx.UserContext(), services.ProcessAppAuthHeaderOptions{ - RequestID: getRequestID(ctx), + RequestID: requestID, AuthHeader: authHeader, AccountID: accountID, }, @@ -163,6 +164,35 @@ func (c *Controllers) AppAccessClaimsMiddleware(ctx *fiber.Ctx) error { return ctx.Next() } +func (c *Controllers) AccountCredentialsDRIATMiddleware(ctx *fiber.Ctx) error { + requestID := getRequestID(ctx) + logger := c.buildLogger(requestID, middlewareLocation, "AccountCredentialsDRIATMiddleware") + authHeader := ctx.Get("Authorization") + + if authHeader == "" { + logger.InfoContext(ctx.UserContext(), "No Authorization header found") + ctx.Locals("isAuthenticated", false) + return ctx.Next() + } + + domain, accountClaims, serviceErr := c.services.ProcessAccountCredentialsRegistrationIATAuth( + ctx.UserContext(), + services.ProcessAccountCredentialsRegistrationIATAuthOptions{ + RequestID: requestID, + AuthHeader: authHeader, + }, + ) + if serviceErr != nil { + ctx.Set(fiber.HeaderWWWAuthenticate, "Bearer realm=\"accounts\", error=\"invalid_token\"") + return oauthErrorResponse(logger, ctx, exceptions.OAuthErrorInvalidToken) + } + + ctx.Locals("account", accountClaims) + ctx.Locals("domain", domain) + ctx.Locals("isAuthenticated", true) + return ctx.Next() +} + func (c *Controllers) ScopeMiddleware(scope tokens.AccountScope) func(*fiber.Ctx) error { return func(ctx *fiber.Ctx) error { logger := c.buildLogger(getRequestID(ctx), middlewareLocation, "ScopeMiddleware") diff --git a/idp/internal/controllers/oauth.go b/idp/internal/controllers/oauth.go index 8e57b77..20f4de6 100644 --- a/idp/internal/controllers/oauth.go +++ b/idp/internal/controllers/oauth.go @@ -59,7 +59,7 @@ func (c *Controllers) serviceErrorCallback( case exceptions.CodeNotFound, exceptions.CodeValidation: return c.errorCallback(logger, ctx, state, exceptions.OAuthErrorInvalidRequest) default: - return c.errorCallback(logger, ctx, state, exceptions.OAuthServerError) + return c.errorCallback(logger, ctx, state, exceptions.OAuthErrorServerError) } } @@ -191,7 +191,7 @@ func oauthErrorResponseMapper(logger *slog.Logger, ctx *fiber.Ctx, serviceErr *e case exceptions.CodeForbidden: return oauthErrorResponse(logger, ctx, exceptions.OAuthErrorUnauthorizedClient) default: - return oauthErrorResponse(logger, ctx, exceptions.OAuthServerError) + return oauthErrorResponse(logger, ctx, exceptions.OAuthErrorServerError) } } @@ -204,7 +204,7 @@ func oauthClientCredentialsErrorResponse(logger *slog.Logger, ctx *fiber.Ctx, se case exceptions.CodeForbidden: return oauthErrorResponse(logger, ctx, exceptions.OAuthErrorUnauthorizedClient) default: - return oauthErrorResponse(logger, ctx, exceptions.OAuthServerError) + return oauthErrorResponse(logger, ctx, exceptions.OAuthErrorServerError) } } diff --git a/idp/internal/controllers/oauth_dynamic_registration.go b/idp/internal/controllers/oauth_dynamic_registration.go index e7a49d1..71f75f1 100644 --- a/idp/internal/controllers/oauth_dynamic_registration.go +++ b/idp/internal/controllers/oauth_dynamic_registration.go @@ -7,670 +7,110 @@ package controllers import ( - "encoding/json" - "fmt" - "github.com/gofiber/fiber/v2" + "github.com/google/uuid" "github.com/tugascript/devlogs/idp/internal/controllers/bodies" "github.com/tugascript/devlogs/idp/internal/controllers/params" - "github.com/tugascript/devlogs/idp/internal/controllers/paths" "github.com/tugascript/devlogs/idp/internal/exceptions" + "github.com/tugascript/devlogs/idp/internal/providers/tokens" "github.com/tugascript/devlogs/idp/internal/services" - "github.com/tugascript/devlogs/idp/internal/utils" -) - -const ( - oauthDynamicRegistration string = "oauth_dynamic_registration" - - accountsIATCookieSuffix string = "_acc_iat" - accountsIAT2FACookieSuffix string = "_acc_iat_2fa" ) -func (c *Controllers) OAuthDynamicRegistrationIATAuth(ctx *fiber.Ctx) error { - requestID := getRequestID(ctx) - logger := c.buildLogger(requestID, oauthDynamicRegistration, "OAuthDynamicRegistrationIATAuth") - logRequest(logger, ctx) - - baseQPrms := params.OAuthDynamicRegistrationIATAuthBaseQueryParams{ - ClientID: ctx.Query("client_id"), - RedirectURI: ctx.Query("redirect_uri"), - } - if err := c.validate.StructCtx(ctx.UserContext(), baseQPrms); err != nil { - return serviceErrorHTMLResponse(logger, ctx, exceptions.NewForbiddenError()) - } - - responseType := ctx.Query("response_type") - state := ctx.Query("state") - if responseType != "code" { - return c.redirectErrorCallback(logger, ctx, baseQPrms.RedirectURI, state, exceptions.OAuthErrorUnsupportedResponseType) - } - - qPrms := params.OAuthDynamicRegistrationIATAuthQueryParams{ - ResponseType: responseType, - Challenge: ctx.Query("code_challenge"), - ChallengeMethod: ctx.Query("code_challenge_method"), - State: state, - } - if err := c.validate.StructCtx(ctx.UserContext(), qPrms); err != nil { - return c.redirectErrorCallback(logger, ctx, baseQPrms.RedirectURI, state, exceptions.OAuthErrorInvalidRequest) - } - - sessionKey := ctx.Cookies(c.cookieName + accountsIATCookieSuffix) - if sessionKey != "" { - // This ensures that the key is only used once - c.removeAccountIATCookie(ctx) - } - - redirectURL, serviceErr := c.services.InitiateOAuthDynamicRegistrationIATAuth( - ctx.UserContext(), - services.InitiateOAuthDynamicRegistrationIATAuthOptions{ - RequestID: requestID, - Domain: baseQPrms.ClientID, - State: qPrms.State, - SessionKey: sessionKey, - RefreshToken: ctx.Cookies(c.cookieName + refreshCookieSuffix), - Challenge: qPrms.Challenge, - ChallengeMethod: qPrms.ChallengeMethod, - RedirectURI: baseQPrms.RedirectURI, - BackendDomain: c.backendDomain, - }, - ) - if serviceErr != nil { - return c.redirectServiceErrorCallback(logger, ctx, baseQPrms.RedirectURI, state, serviceErr) - } - - logResponse(logger, ctx, fiber.StatusFound) - return ctx.Redirect(redirectURL, fiber.StatusFound) -} - -func (c *Controllers) OAuthDynamicRegistrationIATLoginGet(ctx *fiber.Ctx) error { - requestID := getRequestID(ctx) - logger := c.buildLogger(requestID, oauthDynamicRegistration, "OAuthDynamicRegistrationIATLoginGet") - logRequest(logger, ctx) - - uPrms := params.OAuthDynamicRegistrationIATAuthURLParams{ - ACCClientID: ctx.Params("accClientID"), - } - if err := c.validate.StructCtx(ctx.UserContext(), &uPrms); err != nil { - return serviceErrorHTMLResponse(logger, ctx, exceptions.NewNotFoundError()) - } - - baseQPrms := params.OAuthDynamicRegistrationIATAuthBaseQueryParams{ - ClientID: ctx.Query("client_id"), - RedirectURI: ctx.Query("redirect_uri"), - } - if err := c.validate.StructCtx(ctx.UserContext(), baseQPrms); err != nil { - return serviceErrorHTMLResponse(logger, ctx, exceptions.NewForbiddenError()) - } - - qPrms := params.OAuthDynamicRegistrationIATAuthQueryParams{ - ResponseType: ctx.Query("response_type"), - Challenge: ctx.Query("code_challenge"), - ChallengeMethod: ctx.Query("code_challenge_method"), - State: ctx.Query("state"), - } - if err := c.validate.StructCtx(ctx.UserContext(), qPrms); err != nil { - return serviceErrorHTMLResponse(logger, ctx, exceptions.NewForbiddenError()) - } - - loginHTML, serviceErr := c.services.OAuthDynamicRegistrationIATAuthRender( - ctx.UserContext(), - services.OAuthDynamicRegistrationIATAuthRenderOptions{ - RequestID: requestID, - ACCClientID: uPrms.ACCClientID, - State: qPrms.State, - Domain: baseQPrms.ClientID, - CodeChallenge: qPrms.Challenge, - CodeChallengeMethod: qPrms.ChallengeMethod, - RedirectURI: baseQPrms.RedirectURI, - }, - ) - if serviceErr != nil { - return serviceErrorHTMLResponse(logger, ctx, serviceErr) - } - - logResponse(logger, ctx, fiber.StatusOK) - return ctx.Status(fiber.StatusOK).Type("html").SendString(loginHTML) -} - -func (c *Controllers) saveAccountIATCookie( - ctx *fiber.Ctx, - sessionKey string, -) { - ctx.Cookie(&fiber.Cookie{ - Name: c.cookieName + accountsIATCookieSuffix, - Value: sessionKey, - Path: paths.V1 + paths.AccountsBase + paths.CredentialsBase + paths.InitialAccessToken + paths.OAuthAuth, - HTTPOnly: true, - SameSite: fiber.CookieSameSiteLaxMode, - Secure: true, - MaxAge: int(c.services.GetOAuthCodeTTL()), - }) -} - -func (c *Controllers) removeAccountIATCookie(ctx *fiber.Ctx) { - ctx.Cookie(&fiber.Cookie{ - Name: c.cookieName + accountsIATCookieSuffix, - Value: "", - Path: paths.V1 + paths.AccountsBase + paths.CredentialsBase + paths.InitialAccessToken + paths.OAuthAuth, - HTTPOnly: true, - Secure: true, - SameSite: fiber.CookieSameSiteNoneMode, - MaxAge: -1, - }) -} - -func (c *Controllers) saveAccountIAT2FACookie(ctx *fiber.Ctx, sessionID, clientID string) { - ctx.Cookie(&fiber.Cookie{ - Name: c.cookieName + accountsIAT2FACookieSuffix, - Value: sessionID, - Path: paths.AccountsBase + paths.CredentialsBase + paths.InitialAccessToken + "/" + clientID + paths.OAuthAuth, - HTTPOnly: true, - SameSite: fiber.CookieSameSiteLaxMode, - Secure: true, - MaxAge: int(c.services.GetOAuthCodeTTL()), - }) -} +const oauthDynamicRegistration string = "oauth_dynamic_registration" -func (c *Controllers) removeAccountIAT2FACookie(ctx *fiber.Ctx, clientID string) { - ctx.Cookie(&fiber.Cookie{ - Name: c.cookieName + accountsIAT2FACookieSuffix, - Value: "", - Path: paths.AccountsBase + paths.CredentialsBase + paths.InitialAccessToken + "/" + clientID + paths.OAuthAuth, - HTTPOnly: true, - SameSite: fiber.CookieSameSiteLaxMode, - Secure: true, - MaxAge: -1, - }) -} - -func (c *Controllers) OAuthDynamicRegistrationIATLoginPost(ctx *fiber.Ctx) error { +func (c *Controllers) OAuthDynamicRegistration(ctx *fiber.Ctx) error { requestID := getRequestID(ctx) - logger := c.buildLogger(requestID, oauthDynamicRegistration, "OAuthDynamicRegistrationIATLoginPost") + logger := c.buildLogger(requestID, oauthDynamicRegistration, "OAuthDynamicRegistration") logRequest(logger, ctx) - uPrms := params.OAuthDynamicRegistrationIATAuthURLParams{ - ACCClientID: ctx.Params("accClientID"), - } - if err := c.validate.StructCtx(ctx.UserContext(), &uPrms); err != nil { - return serviceErrorHTMLResponse(logger, ctx, exceptions.NewNotFoundError()) - } - - if ctx.Get("Content-Type") != "application/x-www-form-urlencoded" { - return serviceErrorHTMLResponse(logger, ctx, exceptions.NewUnsupportedMediaTypeError("Only application/x-www-form-urlencoded is supported")) - } - - hiddenFields := bodies.OAuthDynamicRegistrationIATAuthHiddenFieldsBody{ - CSRFToken: ctx.FormValue("csrf_token"), - ClientID: ctx.FormValue("client_id"), - ResponseType: ctx.FormValue("response_type"), - CodeChallenge: ctx.FormValue("code_challenge"), - CodeChallengeMethod: ctx.FormValue("code_challenge_method"), - State: ctx.FormValue("state"), - RedirectURI: ctx.FormValue("redirect_uri"), - } - if err := c.validate.StructCtx(ctx.UserContext(), &hiddenFields); err != nil { - return serviceErrorHTMLResponse(logger, ctx, exceptions.NewForbiddenError()) - } - - loginBody := bodies.LoginBody{ - Email: ctx.FormValue("email"), - Password: ctx.FormValue("password"), - } - if err := c.validate.StructCtx(ctx.UserContext(), &loginBody); err != nil { - valErr := validationErrorException(exceptions.ValidationResponseLocationBody, err) - loginHTML, serviceErr := c.services.OAuthDynamicRegistrationIATAuthReRender( - ctx.UserContext(), - services.OAuthDynamicRegistrationIATAuthReRenderOptions{ - RequestID: requestID, - Errors: utils.MapSlice(valErr.Fields, func(t *exceptions.FieldError) string { - return fmt.Sprintf("%s %s", t.Param, t.Message) - }), - CSRFToken: hiddenFields.CSRFToken, - ACCClientID: uPrms.ACCClientID, - State: hiddenFields.State, - Domain: hiddenFields.ClientID, - CodeChallenge: hiddenFields.CodeChallenge, - CodeChallengeMethod: hiddenFields.CodeChallengeMethod, - RedirectURI: hiddenFields.RedirectURI, - }, - ) - if serviceErr != nil { - return serviceErrorHTMLResponse(logger, ctx, serviceErr) - } - - logResponse(logger, ctx, fiber.StatusOK) - return ctx. - Status(fiber.StatusOK). - Type("html"). - SendString(loginHTML) - } - - redirectURL, sessionKey, loggedIn, serviceErr := c.services.OAuthDynamicRegistrationIATLogin( - ctx.UserContext(), - services.OAuthDynamicRegistrationIATLoginOptions{ - RequestID: requestID, - ACCClientID: uPrms.ACCClientID, - Domain: hiddenFields.ClientID, - CSRFToken: hiddenFields.CSRFToken, - CodeChallenge: hiddenFields.CodeChallenge, - CodeChallengeMethod: hiddenFields.CodeChallengeMethod, - State: hiddenFields.State, - RedirectURI: hiddenFields.RedirectURI, - Email: loginBody.Email, - Password: loginBody.Password, - BackendDomain: c.backendDomain, - }, - ) - if serviceErr != nil { - if serviceErr.Code == exceptions.CodeUnauthorized { - loginHTML, serviceErr := c.services.OAuthDynamicRegistrationIATAuthReRender( - ctx.UserContext(), - services.OAuthDynamicRegistrationIATAuthReRenderOptions{ - RequestID: requestID, - Errors: []string{"Invalid credentials"}, - CSRFToken: hiddenFields.CSRFToken, - ACCClientID: uPrms.ACCClientID, - State: hiddenFields.State, - Domain: hiddenFields.ClientID, - CodeChallenge: hiddenFields.CodeChallenge, - CodeChallengeMethod: hiddenFields.CodeChallengeMethod, - RedirectURI: hiddenFields.RedirectURI, - }, - ) - if serviceErr != nil { - return serviceErrorHTMLResponse(logger, ctx, serviceErr) - } - - logResponse(logger, ctx, fiber.StatusOK) - return ctx. - Status(fiber.StatusOK). - Type("html"). - SendString(loginHTML) - } - - return serviceErrorHTMLResponse(logger, ctx, serviceErr) - } - - if loggedIn { - c.saveAccountIAT2FACookie(ctx, sessionKey, uPrms.ACCClientID) - logResponse(logger, ctx, fiber.StatusSeeOther) - return ctx.Redirect(redirectURL, fiber.StatusSeeOther) - } - - c.saveAccountIATCookie(ctx, sessionKey) - logResponse(logger, ctx, fiber.StatusSeeOther) - return ctx.Redirect(redirectURL, fiber.StatusSeeOther) -} - -func (c *Controllers) OAuthDynamicRegistrationIAT2FAGet(ctx *fiber.Ctx) error { - requestID := getRequestID(ctx) - logger := c.buildLogger(requestID, oauthDynamicRegistration, "OAuthDynamicRegistrationIAT2FAGet") - logRequest(logger, ctx) - - uPrms := params.OAuthDynamicRegistrationIATAuthURLParams{ - ACCClientID: ctx.Params("accClientID"), - } - if err := c.validate.StructCtx(ctx.UserContext(), &uPrms); err != nil { - return serviceErrorHTMLResponse(logger, ctx, exceptions.NewNotFoundError()) - } - - baseQPrms := params.OAuthDynamicRegistrationIATAuthBaseQueryParams{ - ClientID: ctx.Query("client_id"), - RedirectURI: ctx.Query("redirect_uri"), - } - if err := c.validate.StructCtx(ctx.UserContext(), baseQPrms); err != nil { - return serviceErrorHTMLResponse(logger, ctx, exceptions.NewForbiddenError()) - } - - qPrms := params.OAuthDynamicRegistrationIATAuthQueryParams{ - ResponseType: ctx.Query("response_type"), - Challenge: ctx.Query("code_challenge"), - ChallengeMethod: ctx.Query("code_challenge_method"), - State: ctx.Query("state"), - } - if err := c.validate.StructCtx(ctx.UserContext(), qPrms); err != nil { - return serviceErrorHTMLResponse(logger, ctx, exceptions.NewForbiddenError()) - } - - sessionID := ctx.Cookies(c.cookieName + accountsIAT2FACookieSuffix) - if sessionID == "" { - return serviceErrorHTMLResponse(logger, ctx, exceptions.NewUnauthorizedError()) - } - - twoFAHTML, serviceErr := c.services.OAuthDynamicRegistrationIAT2FARender( - ctx.UserContext(), - services.OAuthDynamicRegistrationIAT2FARenderOptions{ - RequestID: requestID, - Domain: baseQPrms.ClientID, - ACCClientID: uPrms.ACCClientID, - SessionID: sessionID, - Challenge: qPrms.Challenge, - ChallengeMethod: qPrms.ChallengeMethod, - State: qPrms.State, - RedirectURI: baseQPrms.RedirectURI, - }, - ) - if serviceErr != nil { - return serviceErrorHTMLResponse(logger, ctx, serviceErr) - } - - logResponse(logger, ctx, fiber.StatusOK) - return ctx.Status(fiber.StatusOK).Type("html").SendString(twoFAHTML) -} - -func (c *Controllers) OAuthDynamicRegistrationIAT2FAPost(ctx *fiber.Ctx) error { - requestID := getRequestID(ctx) - logger := c.buildLogger(requestID, oauthDynamicRegistration, "OAuthDynamicRegistrationIAT2FAPost") - logRequest(logger, ctx) - - uPrms := params.OAuthDynamicRegistrationIATAuthURLParams{ - ACCClientID: ctx.Params("accClientID"), - } - if err := c.validate.StructCtx(ctx.UserContext(), &uPrms); err != nil { - return serviceErrorHTMLResponse(logger, ctx, exceptions.NewNotFoundError()) - } - - sessionID := ctx.Cookies(c.cookieName + accountsIAT2FACookieSuffix) - if sessionID == "" { - return serviceErrorHTMLResponse(logger, ctx, exceptions.NewUnauthorizedError()) - } - - if ctx.Get("Content-Type") != "application/x-www-form-urlencoded" { - return serviceErrorHTMLResponse(logger, ctx, exceptions.NewUnsupportedMediaTypeError("Only application/x-www-form-urlencoded is supported")) - } - - hiddenFields := bodies.OAuthDynamicRegistrationIATAuthHiddenFieldsBody{ - CSRFToken: ctx.FormValue("csrf_token"), - ClientID: ctx.FormValue("client_id"), - ResponseType: ctx.FormValue("response_type"), - CodeChallenge: ctx.FormValue("code_challenge"), - CodeChallengeMethod: ctx.FormValue("code_challenge_method"), - State: ctx.FormValue("state"), - RedirectURI: ctx.FormValue("redirect_uri"), - } - if err := c.validate.StructCtx(ctx.UserContext(), &hiddenFields); err != nil { - return serviceErrorHTMLResponse(logger, ctx, exceptions.NewForbiddenError()) - } - - twoFABody := bodies.TwoFactorLoginBody{ - Code: ctx.FormValue("code"), - } - if err := c.validate.StructCtx(ctx.UserContext(), &twoFABody); err != nil { - valErr := validationErrorException(exceptions.ValidationResponseLocationBody, err) - twoFAHTML, serviceErr := c.services.OAuthDynamicRegistrationIAT2FAReRender( - ctx.UserContext(), - services.OAuthDynamicRegistrationIAT2FAReRenderOptions{ - RequestID: requestID, - Domain: hiddenFields.ClientID, - ACCClientID: uPrms.ACCClientID, - SessionID: sessionID, - Errors: utils.MapSlice(valErr.Fields, func(t *exceptions.FieldError) string { - return fmt.Sprintf("%s %s", t.Param, t.Message) - }), - CSRFToken: hiddenFields.CSRFToken, - Challenge: hiddenFields.CodeChallenge, - ChallengeMethod: hiddenFields.CodeChallengeMethod, - State: hiddenFields.State, - RedirectURI: hiddenFields.RedirectURI, - }, - ) - if serviceErr != nil { - return serviceErrorHTMLResponse(logger, ctx, serviceErr) - } - - logResponse(logger, ctx, fiber.StatusOK) - return ctx. - Status(fiber.StatusOK). - Type("html"). - SendString(twoFAHTML) - } - - redirectURL, sessionKey, serviceErr := c.services.OAuthDynamicRegistrationIATVerify2FACode( - ctx.UserContext(), - services.OAuthDynamicRegistrationIATVerify2FACodeOptions{ - RequestID: requestID, - ACCClientID: uPrms.ACCClientID, - Domain: hiddenFields.ClientID, - SessionID: sessionID, - CSRFToken: hiddenFields.CSRFToken, - Code: twoFABody.Code, - BackendDomain: c.backendDomain, - }, - ) - if serviceErr != nil { - if serviceErr.Code == exceptions.CodeUnauthorized { - twoFAHTML, serviceErr := c.services.OAuthDynamicRegistrationIAT2FAReRender( - ctx.UserContext(), - services.OAuthDynamicRegistrationIAT2FAReRenderOptions{ - RequestID: requestID, - Domain: hiddenFields.ClientID, - ACCClientID: uPrms.ACCClientID, - SessionID: sessionID, - Errors: []string{"Invalid 2FA code"}, - CSRFToken: hiddenFields.CSRFToken, - Challenge: hiddenFields.CodeChallenge, - ChallengeMethod: hiddenFields.CodeChallengeMethod, - State: hiddenFields.State, - RedirectURI: hiddenFields.RedirectURI, - }, - ) - if serviceErr != nil { - return serviceErrorHTMLResponse(logger, ctx, serviceErr) - } - - logResponse(logger, ctx, fiber.StatusOK) - return ctx. - Status(fiber.StatusOK). - Type("html"). - SendString(twoFAHTML) - } - - return serviceErrorHTMLResponse(logger, ctx, serviceErr) - } - - c.removeAccountIAT2FACookie(ctx, uPrms.ACCClientID) - c.saveAccountIATCookie(ctx, sessionKey) - logResponse(logger, ctx, fiber.StatusSeeOther) - return ctx.Redirect(redirectURL, fiber.StatusSeeOther) -} - -func (c *Controllers) OAuthDynamicRegistrationIATExtAuthGet(ctx *fiber.Ctx) error { - requestID := getRequestID(ctx) - logger := c.buildLogger(requestID, oauthDynamicRegistration, "OAuthDynamicRegistrationIATExtAuthGet") - logRequest(logger, ctx) - - uPrms := params.OAuthDynamicRegistrationIATExtAuthURLParams{ - ACCClientID: ctx.Params("accClientID"), - Provider: ctx.Params("provider"), - } - if err := c.validate.StructCtx(ctx.UserContext(), &uPrms); err != nil { - return serviceErrorHTMLResponse(logger, ctx, exceptions.NewNotFoundError()) - } - - baseQPrms := params.OAuthDynamicRegistrationIATAuthBaseQueryParams{ - ClientID: ctx.Query("client_id"), - RedirectURI: ctx.Query("redirect_uri"), - } - if err := c.validate.StructCtx(ctx.UserContext(), baseQPrms); err != nil { - return serviceErrorHTMLResponse(logger, ctx, exceptions.NewForbiddenError()) - } - - responseType := ctx.Query("response_type") - state := ctx.Query("state") - if responseType != "code" { - return c.redirectErrorCallback(logger, ctx, baseQPrms.RedirectURI, state, exceptions.OAuthErrorUnsupportedResponseType) - } - - qPrms := params.OAuthDynamicRegistrationIATAuthQueryParams{ - ResponseType: responseType, - Challenge: ctx.Query("code_challenge"), - ChallengeMethod: ctx.Query("code_challenge_method"), - State: state, - } - if err := c.validate.StructCtx(ctx.UserContext(), qPrms); err != nil { - return c.redirectErrorCallback(logger, ctx, baseQPrms.RedirectURI, state, exceptions.OAuthErrorInvalidRequest) - } - - authURL, serviceErr := c.services.OAuthDynamicRegistrationIATExtGet( - ctx.UserContext(), - services.OAuthDynamicRegistrationIATExtGetOptions{ - RequestID: requestID, - ACCClientID: uPrms.ACCClientID, - Provider: uPrms.Provider, - Domain: baseQPrms.ClientID, - CallbackURL: baseQPrms.RedirectURI, - RedirectURI: baseQPrms.RedirectURI, - State: qPrms.State, - BackendDomain: c.backendDomain, - }, - ) - if serviceErr != nil { - return serviceErrorHTMLResponse(logger, ctx, serviceErr) - } - - logResponse(logger, ctx, fiber.StatusFound) - return ctx.Redirect(authURL, fiber.StatusFound) -} - -func (c *Controllers) OAuthDynamicRegistrationIATExtCB(ctx *fiber.Ctx) error { - requestID := getRequestID(ctx) - logger := c.buildLogger(requestID, oauthDynamicRegistration, "OAuthDynamicRegistrationIATExtCB") - logRequest(logger, ctx) - - uPrms := params.OAuthDynamicRegistrationIATExtAuthURLParams{ - ACCClientID: ctx.Params("accClientID"), - Provider: ctx.Params("provider"), - } - if err := c.validate.StructCtx(ctx.UserContext(), &uPrms); err != nil { - return serviceErrorHTMLResponse(logger, ctx, exceptions.NewNotFoundError()) - } - - qPrms := params.OAuthCallbackQueryParams{ - Code: ctx.Query("code"), - State: ctx.Query("state"), - } - if err := c.validate.StructCtx(ctx.UserContext(), &qPrms); err != nil { - return serviceErrorHTMLResponse(logger, ctx, exceptions.NewForbiddenError()) - } - - cbURL, serviceErr := c.services.OAuthDynamicRegistrationIATExtCB( - ctx.UserContext(), - services.OAuthDynamicRegistrationIATExtCBOptions{ - RequestID: requestID, - ACCClientID: uPrms.ACCClientID, - Provider: uPrms.Provider, - State: qPrms.State, - Code: qPrms.Code, - RedirectURL: "https://" + c.backendDomain + paths.V1 + paths.AccountsBase + - paths.CredentialsBase + paths.DynamicRegistrationBase + paths.InitialAccessToken + - "/" + uPrms.ACCClientID + paths.OAuthAuth + paths.InitialAccessTokenAuthEXT + "/" + - uPrms.Provider + paths.InitialAccessTokenCallback, - BackendDomain: c.backendDomain, - }, - ) - if serviceErr != nil { - return serviceErrorHTMLResponse(logger, ctx, serviceErr) - } - - logResponse(logger, ctx, fiber.StatusFound) - return ctx.Redirect(cbURL, fiber.StatusFound) -} - -func (c *Controllers) OAuthDynamicRegistrationIATExtAppleCB(ctx *fiber.Ctx) error { - requestID := getRequestID(ctx) - logger := c.buildLogger(requestID, oauthDynamicRegistration, "OAuthDynamicRegistrationIATExtAppleCB") - logRequest(logger, ctx) - - uPrms := params.OAuthDynamicRegistrationIATExtAppleURLParams{ - ACCClientID: ctx.Params("accClientID"), - } - if err := c.validate.StructCtx(ctx.UserContext(), &uPrms); err != nil { - return serviceErrorHTMLResponse(logger, ctx, exceptions.NewNotFoundError()) + urlParams := params.AccountURLParams{AccountPublicID: ctx.Params("accountPublicID")} + if err := c.validate.StructCtx(ctx.UserContext(), &urlParams); err != nil { + return oauthErrorResponse(logger, ctx, exceptions.OAuthErrorInvalidRequest) } - if ctx.Get("Content-Type") != "application/x-www-form-urlencoded" { - return serviceErrorHTMLResponse(logger, ctx, exceptions.NewUnsupportedMediaTypeError("Only application/x-www-form-urlencoded is supported")) + accountPublicID, err := uuid.Parse(urlParams.AccountPublicID) + if err != nil { + return oauthErrorResponse(logger, ctx, exceptions.OAuthErrorInvalidRequest) } - qPrms := bodies.OAuthDynamicRegistrationIATExtAppleBody{ - Code: ctx.FormValue("code"), - State: ctx.FormValue("state"), - User: ctx.FormValue("user"), + body := new(bodies.OAuthDynamicClientRegistrationBody) + if err := ctx.BodyParser(body); err != nil { + return oauthErrorResponse(logger, ctx, exceptions.OAuthErrorInvalidClientMetadata) } - if err := c.validate.StructCtx(ctx.UserContext(), &qPrms); err != nil { - return serviceErrorHTMLResponse(logger, ctx, exceptions.NewForbiddenError()) + if err := c.validate.StructCtx(ctx.UserContext(), body); err != nil { + return oauthErrorResponse(logger, ctx, exceptions.OAuthErrorInvalidClientMetadata) } - user := new(bodies.OAuthDynamicRegistrationIATExtAppleUserBody) - if err := json.Unmarshal([]byte(qPrms.User), user); err != nil { - return serviceErrorHTMLResponse(logger, ctx, exceptions.NewForbiddenError()) - } - if err := c.validate.StructCtx(ctx.UserContext(), user); err != nil { - return serviceErrorHTMLResponse(logger, ctx, exceptions.NewForbiddenError()) + isAuthenticated, ok := ctx.Locals("isAuthenticated").(bool) + if !ok { + logger.ErrorContext(ctx.UserContext(), "isAuthenticated should be set in context by middleware") + return oauthErrorResponse(logger, ctx, exceptions.OAuthErrorServerError) } - cbURL, serviceErr := c.services.OAuthDynamicRegistrationIATExtAppleCB( - ctx.UserContext(), - services.OAuthDynamicRegistrationIATExtAppleCBOptions{ - RequestID: requestID, - ACCClientID: uPrms.ACCClientID, - Email: user.Email, - Code: qPrms.Code, - State: qPrms.State, - RedirectURL: "https://" + c.backendDomain + paths.V1 + paths.AccountsBase + - paths.CredentialsBase + paths.DynamicRegistrationBase + paths.InitialAccessToken + - "/" + uPrms.ACCClientID + paths.OAuthAuth + paths.InitialAccessTokenAuthEXT + "/" + - services.AuthProviderApple + paths.InitialAccessTokenCallback, - BackendDomain: c.backendDomain, - }, - ) - if serviceErr != nil { - return serviceErrorHTMLResponse(logger, ctx, serviceErr) + domain, ok := ctx.Locals("domain").(string) + if isAuthenticated && !ok { + logger.ErrorContext(ctx.UserContext(), "domain should be set in context by middleware") + return oauthErrorResponse(logger, ctx, exceptions.OAuthErrorServerError) } - logResponse(logger, ctx, fiber.StatusFound) - return ctx.Redirect(cbURL, fiber.StatusFound) -} - -func (c *Controllers) OAuthDynamicRegistrationIATToken(ctx *fiber.Ctx) error { - requestID := getRequestID(ctx) - logger := c.buildLogger(requestID, oauthDynamicRegistration, "OAuthDynamicRegistrationIATToken") - logRequest(logger, ctx) - - if ctx.Get("Content-Type") != "application/x-www-form-urlencoded" { - return oauthErrorResponse(logger, ctx, exceptions.OAuthErrorInvalidRequest) - } - - grantType := ctx.Get("grant_type") - if grantType != "authorization_code" { - return oauthErrorResponse(logger, ctx, exceptions.OAuthErrorUnsupportedGrantType) - } - - body := bodies.OAuthDynamicRegistrationIATTokenBody{ - GrantType: grantType, - Code: ctx.FormValue("code"), - ClientID: ctx.FormValue("client_id"), - CodeVerifier: ctx.FormValue("code_verifier"), - } - if err := c.validate.StructCtx(ctx.UserContext(), &body); err != nil { - return oauthErrorResponse(logger, ctx, exceptions.OAuthErrorInvalidRequest) + account, ok := ctx.Locals("account").(tokens.AccountClaims) + if isAuthenticated && !ok { + logger.ErrorContext(ctx.UserContext(), "account should be set in context by middleware") + return oauthErrorResponse(logger, ctx, exceptions.OAuthErrorServerError) } - authDTO, serviceErr := c.services.VerifyOAuthDynamicRegistrationIATCode( + accountCredentialsDTO, serviceErr := c.services.CreateAccountCredentialsRegistration( ctx.UserContext(), - services.VerifyOAuthDynamicRegistrationIATCodeOptions{ - RequestID: requestID, - Code: body.Code, - CodeVerifier: body.CodeVerifier, - Domain: body.ClientID, + services.CreateAccountCredentialsRegistrationOptions{ + RequestID: requestID, + AccountPublicID: accountPublicID, + IsAuthenticated: isAuthenticated, + IATDomain: domain, + AccountVersion: account.AccountVersion, + ApplicationType: body.ApplicationType, + RedirectURIs: body.RedirectURIs, + TokenEndpointAuthMethod: body.TokenEndpointAuthMethod, + GrantTypes: body.GrantTypes, + ResponseTypes: body.ResponseTypes, + ClientName: body.ClientName, + ClientURI: body.ClientURI, + LogoURI: body.LogoURI, + TOSURI: body.TOSURI, + PolicyURI: body.PolicyURI, + Contacts: body.Contacts, + SoftwareID: body.SoftwareID, + SoftwareVersion: body.SoftwareVersion, + SoftwareStatement: body.SoftwareStatement, + JWKsURI: body.JWKsURI, + JWKs: body.JWKs, + FrontendDomain: c.frontendDomain, + BackendDomain: c.backendDomain, + RequireAuthTime: body.RequireAuthTime, + DefaultMaxAge: body.DefaultMaxAge, + SubjectType: body.SubjectType, + IDTokenSignedResponseAlg: body.IDTokenSignedResponseAlg, + IDTokenEncryptedResponseAlg: body.IDTokenEncryptedResponseAlg, + IDTokenEncryptedResponseEnc: body.IDTokenEncryptedResponseEnc, + RequestObjectSigningAlg: body.RequestObjectSigningAlg, + RequestObjectEncryptionAlg: body.RequestObjectEncryptionAlg, + RequestObjectEncryptionEnc: body.RequestObjectEncryptionEnc, + DefaultACRValues: body.DefaultACRValues, + Scope: body.Scope, + SectorIdentifierURI: body.SectorIdentifierURI, + InitiateLoginURI: body.InitiateLoginURI, + RequestURIs: body.RequestURIs, + UserInfoSignedResponseAlg: body.UserInfoSignedResponseAlg, + UserInfoEncryptedResponseAlg: body.UserInfoEncryptedResponseAlg, + UserInfoEncryptedResponseEnc: body.UserInfoEncryptedResponseEnc, + TokenEndpointAuthSigningAlg: body.TokenEndpointAuthSigningAlg, + AccessTokenSigningAlg: body.AccessTokenSigningAlg, }, ) if serviceErr != nil { - return oauthErrorResponseMapper(logger, ctx, serviceErr) + return dynamicRegistrationServiceError(logger, ctx, serviceErr) } - logResponse(logger, ctx, fiber.StatusOK) - return ctx.Status(fiber.StatusOK).JSON(authDTO) + logResponse(logger, ctx, fiber.StatusCreated) + return ctx.Status(fiber.StatusCreated).JSON(&accountCredentialsDTO) } diff --git a/idp/internal/controllers/oauth_dynamic_registration_iat.go b/idp/internal/controllers/oauth_dynamic_registration_iat.go new file mode 100644 index 0000000..fdc607b --- /dev/null +++ b/idp/internal/controllers/oauth_dynamic_registration_iat.go @@ -0,0 +1,676 @@ +// Copyright (c) 2025 Afonso Barracha +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +package controllers + +import ( + "encoding/json" + "fmt" + + "github.com/gofiber/fiber/v2" + + "github.com/tugascript/devlogs/idp/internal/controllers/bodies" + "github.com/tugascript/devlogs/idp/internal/controllers/params" + "github.com/tugascript/devlogs/idp/internal/controllers/paths" + "github.com/tugascript/devlogs/idp/internal/exceptions" + "github.com/tugascript/devlogs/idp/internal/services" + "github.com/tugascript/devlogs/idp/internal/utils" +) + +const ( + oauthDynamicRegistrationIAT string = "oauth_dynamic_registration_iat" + + accountsIATCookieSuffix string = "_acc_iat" + accountsIAT2FACookieSuffix string = "_acc_iat_2fa" +) + +func (c *Controllers) OAuthDynamicRegistrationIATAuth(ctx *fiber.Ctx) error { + requestID := getRequestID(ctx) + logger := c.buildLogger(requestID, oauthDynamicRegistrationIAT, "OAuthDynamicRegistrationIATAuth") + logRequest(logger, ctx) + + baseQPrms := params.OAuthDynamicRegistrationIATAuthBaseQueryParams{ + ClientID: ctx.Query("client_id"), + RedirectURI: ctx.Query("redirect_uri"), + } + if err := c.validate.StructCtx(ctx.UserContext(), baseQPrms); err != nil { + return serviceErrorHTMLResponse(logger, ctx, exceptions.NewForbiddenError()) + } + + responseType := ctx.Query("response_type") + state := ctx.Query("state") + if responseType != "code" { + return c.redirectErrorCallback(logger, ctx, baseQPrms.RedirectURI, state, exceptions.OAuthErrorUnsupportedResponseType) + } + + qPrms := params.OAuthDynamicRegistrationIATAuthQueryParams{ + ResponseType: responseType, + Challenge: ctx.Query("code_challenge"), + ChallengeMethod: ctx.Query("code_challenge_method"), + State: state, + } + if err := c.validate.StructCtx(ctx.UserContext(), qPrms); err != nil { + return c.redirectErrorCallback(logger, ctx, baseQPrms.RedirectURI, state, exceptions.OAuthErrorInvalidRequest) + } + + sessionKey := ctx.Cookies(c.cookieName + accountsIATCookieSuffix) + if sessionKey != "" { + // This ensures that the key is only used once + c.removeAccountIATCookie(ctx) + } + + redirectURL, serviceErr := c.services.InitiateOAuthDynamicRegistrationIATAuth( + ctx.UserContext(), + services.InitiateOAuthDynamicRegistrationIATAuthOptions{ + RequestID: requestID, + Domain: baseQPrms.ClientID, + State: qPrms.State, + SessionKey: sessionKey, + RefreshToken: ctx.Cookies(c.cookieName + refreshCookieSuffix), + Challenge: qPrms.Challenge, + ChallengeMethod: qPrms.ChallengeMethod, + RedirectURI: baseQPrms.RedirectURI, + BackendDomain: c.backendDomain, + }, + ) + if serviceErr != nil { + return c.redirectServiceErrorCallback(logger, ctx, baseQPrms.RedirectURI, state, serviceErr) + } + + logResponse(logger, ctx, fiber.StatusFound) + return ctx.Redirect(redirectURL, fiber.StatusFound) +} + +func (c *Controllers) OAuthDynamicRegistrationIATLoginGet(ctx *fiber.Ctx) error { + requestID := getRequestID(ctx) + logger := c.buildLogger(requestID, oauthDynamicRegistrationIAT, "OAuthDynamicRegistrationIATLoginGet") + logRequest(logger, ctx) + + uPrms := params.OAuthDynamicRegistrationIATAuthURLParams{ + ACCClientID: ctx.Params("accClientID"), + } + if err := c.validate.StructCtx(ctx.UserContext(), &uPrms); err != nil { + return serviceErrorHTMLResponse(logger, ctx, exceptions.NewNotFoundError()) + } + + baseQPrms := params.OAuthDynamicRegistrationIATAuthBaseQueryParams{ + ClientID: ctx.Query("client_id"), + RedirectURI: ctx.Query("redirect_uri"), + } + if err := c.validate.StructCtx(ctx.UserContext(), baseQPrms); err != nil { + return serviceErrorHTMLResponse(logger, ctx, exceptions.NewForbiddenError()) + } + + qPrms := params.OAuthDynamicRegistrationIATAuthQueryParams{ + ResponseType: ctx.Query("response_type"), + Challenge: ctx.Query("code_challenge"), + ChallengeMethod: ctx.Query("code_challenge_method"), + State: ctx.Query("state"), + } + if err := c.validate.StructCtx(ctx.UserContext(), qPrms); err != nil { + return serviceErrorHTMLResponse(logger, ctx, exceptions.NewForbiddenError()) + } + + loginHTML, serviceErr := c.services.OAuthDynamicRegistrationIATAuthRender( + ctx.UserContext(), + services.OAuthDynamicRegistrationIATAuthRenderOptions{ + RequestID: requestID, + ACCClientID: uPrms.ACCClientID, + State: qPrms.State, + Domain: baseQPrms.ClientID, + CodeChallenge: qPrms.Challenge, + CodeChallengeMethod: qPrms.ChallengeMethod, + RedirectURI: baseQPrms.RedirectURI, + }, + ) + if serviceErr != nil { + return serviceErrorHTMLResponse(logger, ctx, serviceErr) + } + + logResponse(logger, ctx, fiber.StatusOK) + return ctx.Status(fiber.StatusOK).Type("html").SendString(loginHTML) +} + +func (c *Controllers) saveAccountIATCookie( + ctx *fiber.Ctx, + sessionKey string, +) { + ctx.Cookie(&fiber.Cookie{ + Name: c.cookieName + accountsIATCookieSuffix, + Value: sessionKey, + Path: paths.V1 + paths.AccountsBase + paths.CredentialsBase + paths.InitialAccessToken + paths.OAuthAuth, + HTTPOnly: true, + SameSite: fiber.CookieSameSiteLaxMode, + Secure: true, + MaxAge: int(c.services.GetOAuthCodeTTL()), + }) +} + +func (c *Controllers) removeAccountIATCookie(ctx *fiber.Ctx) { + ctx.Cookie(&fiber.Cookie{ + Name: c.cookieName + accountsIATCookieSuffix, + Value: "", + Path: paths.V1 + paths.AccountsBase + paths.CredentialsBase + paths.InitialAccessToken + paths.OAuthAuth, + HTTPOnly: true, + Secure: true, + SameSite: fiber.CookieSameSiteNoneMode, + MaxAge: -1, + }) +} + +func (c *Controllers) saveAccountIAT2FACookie(ctx *fiber.Ctx, sessionID, clientID string) { + ctx.Cookie(&fiber.Cookie{ + Name: c.cookieName + accountsIAT2FACookieSuffix, + Value: sessionID, + Path: paths.AccountsBase + paths.CredentialsBase + paths.InitialAccessToken + "/" + clientID + paths.OAuthAuth, + HTTPOnly: true, + SameSite: fiber.CookieSameSiteLaxMode, + Secure: true, + MaxAge: int(c.services.GetOAuthCodeTTL()), + }) +} + +func (c *Controllers) removeAccountIAT2FACookie(ctx *fiber.Ctx, clientID string) { + ctx.Cookie(&fiber.Cookie{ + Name: c.cookieName + accountsIAT2FACookieSuffix, + Value: "", + Path: paths.AccountsBase + paths.CredentialsBase + paths.InitialAccessToken + "/" + clientID + paths.OAuthAuth, + HTTPOnly: true, + SameSite: fiber.CookieSameSiteLaxMode, + Secure: true, + MaxAge: -1, + }) +} + +func (c *Controllers) OAuthDynamicRegistrationIATLoginPost(ctx *fiber.Ctx) error { + requestID := getRequestID(ctx) + logger := c.buildLogger(requestID, oauthDynamicRegistrationIAT, "OAuthDynamicRegistrationIATLoginPost") + logRequest(logger, ctx) + + uPrms := params.OAuthDynamicRegistrationIATAuthURLParams{ + ACCClientID: ctx.Params("accClientID"), + } + if err := c.validate.StructCtx(ctx.UserContext(), &uPrms); err != nil { + return serviceErrorHTMLResponse(logger, ctx, exceptions.NewNotFoundError()) + } + + if ctx.Get("Content-Type") != "application/x-www-form-urlencoded" { + return serviceErrorHTMLResponse(logger, ctx, exceptions.NewUnsupportedMediaTypeError("Only application/x-www-form-urlencoded is supported")) + } + + hiddenFields := bodies.OAuthDynamicRegistrationIATAuthHiddenFieldsBody{ + CSRFToken: ctx.FormValue("csrf_token"), + ClientID: ctx.FormValue("client_id"), + ResponseType: ctx.FormValue("response_type"), + CodeChallenge: ctx.FormValue("code_challenge"), + CodeChallengeMethod: ctx.FormValue("code_challenge_method"), + State: ctx.FormValue("state"), + RedirectURI: ctx.FormValue("redirect_uri"), + } + if err := c.validate.StructCtx(ctx.UserContext(), &hiddenFields); err != nil { + return serviceErrorHTMLResponse(logger, ctx, exceptions.NewForbiddenError()) + } + + loginBody := bodies.LoginBody{ + Email: ctx.FormValue("email"), + Password: ctx.FormValue("password"), + } + if err := c.validate.StructCtx(ctx.UserContext(), &loginBody); err != nil { + valErr := validationErrorException(exceptions.ValidationResponseLocationBody, err) + loginHTML, serviceErr := c.services.OAuthDynamicRegistrationIATAuthReRender( + ctx.UserContext(), + services.OAuthDynamicRegistrationIATAuthReRenderOptions{ + RequestID: requestID, + Errors: utils.MapSlice(valErr.Fields, func(t *exceptions.FieldError) string { + return fmt.Sprintf("%s %s", t.Param, t.Message) + }), + CSRFToken: hiddenFields.CSRFToken, + ACCClientID: uPrms.ACCClientID, + State: hiddenFields.State, + Domain: hiddenFields.ClientID, + CodeChallenge: hiddenFields.CodeChallenge, + CodeChallengeMethod: hiddenFields.CodeChallengeMethod, + RedirectURI: hiddenFields.RedirectURI, + }, + ) + if serviceErr != nil { + return serviceErrorHTMLResponse(logger, ctx, serviceErr) + } + + logResponse(logger, ctx, fiber.StatusOK) + return ctx. + Status(fiber.StatusOK). + Type("html"). + SendString(loginHTML) + } + + redirectURL, sessionKey, loggedIn, serviceErr := c.services.OAuthDynamicRegistrationIATLogin( + ctx.UserContext(), + services.OAuthDynamicRegistrationIATLoginOptions{ + RequestID: requestID, + ACCClientID: uPrms.ACCClientID, + Domain: hiddenFields.ClientID, + CSRFToken: hiddenFields.CSRFToken, + CodeChallenge: hiddenFields.CodeChallenge, + CodeChallengeMethod: hiddenFields.CodeChallengeMethod, + State: hiddenFields.State, + RedirectURI: hiddenFields.RedirectURI, + Email: loginBody.Email, + Password: loginBody.Password, + BackendDomain: c.backendDomain, + }, + ) + if serviceErr != nil { + if serviceErr.Code == exceptions.CodeUnauthorized { + loginHTML, serviceErr := c.services.OAuthDynamicRegistrationIATAuthReRender( + ctx.UserContext(), + services.OAuthDynamicRegistrationIATAuthReRenderOptions{ + RequestID: requestID, + Errors: []string{"Invalid credentials"}, + CSRFToken: hiddenFields.CSRFToken, + ACCClientID: uPrms.ACCClientID, + State: hiddenFields.State, + Domain: hiddenFields.ClientID, + CodeChallenge: hiddenFields.CodeChallenge, + CodeChallengeMethod: hiddenFields.CodeChallengeMethod, + RedirectURI: hiddenFields.RedirectURI, + }, + ) + if serviceErr != nil { + return serviceErrorHTMLResponse(logger, ctx, serviceErr) + } + + logResponse(logger, ctx, fiber.StatusOK) + return ctx. + Status(fiber.StatusOK). + Type("html"). + SendString(loginHTML) + } + + return serviceErrorHTMLResponse(logger, ctx, serviceErr) + } + + if loggedIn { + c.saveAccountIAT2FACookie(ctx, sessionKey, uPrms.ACCClientID) + logResponse(logger, ctx, fiber.StatusSeeOther) + return ctx.Redirect(redirectURL, fiber.StatusSeeOther) + } + + c.saveAccountIATCookie(ctx, sessionKey) + logResponse(logger, ctx, fiber.StatusSeeOther) + return ctx.Redirect(redirectURL, fiber.StatusSeeOther) +} + +func (c *Controllers) OAuthDynamicRegistrationIAT2FAGet(ctx *fiber.Ctx) error { + requestID := getRequestID(ctx) + logger := c.buildLogger(requestID, oauthDynamicRegistrationIAT, "OAuthDynamicRegistrationIAT2FAGet") + logRequest(logger, ctx) + + uPrms := params.OAuthDynamicRegistrationIATAuthURLParams{ + ACCClientID: ctx.Params("accClientID"), + } + if err := c.validate.StructCtx(ctx.UserContext(), &uPrms); err != nil { + return serviceErrorHTMLResponse(logger, ctx, exceptions.NewNotFoundError()) + } + + baseQPrms := params.OAuthDynamicRegistrationIATAuthBaseQueryParams{ + ClientID: ctx.Query("client_id"), + RedirectURI: ctx.Query("redirect_uri"), + } + if err := c.validate.StructCtx(ctx.UserContext(), baseQPrms); err != nil { + return serviceErrorHTMLResponse(logger, ctx, exceptions.NewForbiddenError()) + } + + qPrms := params.OAuthDynamicRegistrationIATAuthQueryParams{ + ResponseType: ctx.Query("response_type"), + Challenge: ctx.Query("code_challenge"), + ChallengeMethod: ctx.Query("code_challenge_method"), + State: ctx.Query("state"), + } + if err := c.validate.StructCtx(ctx.UserContext(), qPrms); err != nil { + return serviceErrorHTMLResponse(logger, ctx, exceptions.NewForbiddenError()) + } + + sessionID := ctx.Cookies(c.cookieName + accountsIAT2FACookieSuffix) + if sessionID == "" { + return serviceErrorHTMLResponse(logger, ctx, exceptions.NewUnauthorizedError()) + } + + twoFAHTML, serviceErr := c.services.OAuthDynamicRegistrationIAT2FARender( + ctx.UserContext(), + services.OAuthDynamicRegistrationIAT2FARenderOptions{ + RequestID: requestID, + Domain: baseQPrms.ClientID, + ACCClientID: uPrms.ACCClientID, + SessionID: sessionID, + Challenge: qPrms.Challenge, + ChallengeMethod: qPrms.ChallengeMethod, + State: qPrms.State, + RedirectURI: baseQPrms.RedirectURI, + }, + ) + if serviceErr != nil { + return serviceErrorHTMLResponse(logger, ctx, serviceErr) + } + + logResponse(logger, ctx, fiber.StatusOK) + return ctx.Status(fiber.StatusOK).Type("html").SendString(twoFAHTML) +} + +func (c *Controllers) OAuthDynamicRegistrationIAT2FAPost(ctx *fiber.Ctx) error { + requestID := getRequestID(ctx) + logger := c.buildLogger(requestID, oauthDynamicRegistrationIAT, "OAuthDynamicRegistrationIAT2FAPost") + logRequest(logger, ctx) + + uPrms := params.OAuthDynamicRegistrationIATAuthURLParams{ + ACCClientID: ctx.Params("accClientID"), + } + if err := c.validate.StructCtx(ctx.UserContext(), &uPrms); err != nil { + return serviceErrorHTMLResponse(logger, ctx, exceptions.NewNotFoundError()) + } + + sessionID := ctx.Cookies(c.cookieName + accountsIAT2FACookieSuffix) + if sessionID == "" { + return serviceErrorHTMLResponse(logger, ctx, exceptions.NewUnauthorizedError()) + } + + if ctx.Get("Content-Type") != "application/x-www-form-urlencoded" { + return serviceErrorHTMLResponse(logger, ctx, exceptions.NewUnsupportedMediaTypeError("Only application/x-www-form-urlencoded is supported")) + } + + hiddenFields := bodies.OAuthDynamicRegistrationIATAuthHiddenFieldsBody{ + CSRFToken: ctx.FormValue("csrf_token"), + ClientID: ctx.FormValue("client_id"), + ResponseType: ctx.FormValue("response_type"), + CodeChallenge: ctx.FormValue("code_challenge"), + CodeChallengeMethod: ctx.FormValue("code_challenge_method"), + State: ctx.FormValue("state"), + RedirectURI: ctx.FormValue("redirect_uri"), + } + if err := c.validate.StructCtx(ctx.UserContext(), &hiddenFields); err != nil { + return serviceErrorHTMLResponse(logger, ctx, exceptions.NewForbiddenError()) + } + + twoFABody := bodies.TwoFactorLoginBody{ + Code: ctx.FormValue("code"), + } + if err := c.validate.StructCtx(ctx.UserContext(), &twoFABody); err != nil { + valErr := validationErrorException(exceptions.ValidationResponseLocationBody, err) + twoFAHTML, serviceErr := c.services.OAuthDynamicRegistrationIAT2FAReRender( + ctx.UserContext(), + services.OAuthDynamicRegistrationIAT2FAReRenderOptions{ + RequestID: requestID, + Domain: hiddenFields.ClientID, + ACCClientID: uPrms.ACCClientID, + SessionID: sessionID, + Errors: utils.MapSlice(valErr.Fields, func(t *exceptions.FieldError) string { + return fmt.Sprintf("%s %s", t.Param, t.Message) + }), + CSRFToken: hiddenFields.CSRFToken, + Challenge: hiddenFields.CodeChallenge, + ChallengeMethod: hiddenFields.CodeChallengeMethod, + State: hiddenFields.State, + RedirectURI: hiddenFields.RedirectURI, + }, + ) + if serviceErr != nil { + return serviceErrorHTMLResponse(logger, ctx, serviceErr) + } + + logResponse(logger, ctx, fiber.StatusOK) + return ctx. + Status(fiber.StatusOK). + Type("html"). + SendString(twoFAHTML) + } + + redirectURL, sessionKey, serviceErr := c.services.OAuthDynamicRegistrationIATVerify2FACode( + ctx.UserContext(), + services.OAuthDynamicRegistrationIATVerify2FACodeOptions{ + RequestID: requestID, + ACCClientID: uPrms.ACCClientID, + Domain: hiddenFields.ClientID, + SessionID: sessionID, + CSRFToken: hiddenFields.CSRFToken, + Code: twoFABody.Code, + BackendDomain: c.backendDomain, + }, + ) + if serviceErr != nil { + if serviceErr.Code == exceptions.CodeUnauthorized { + twoFAHTML, serviceErr := c.services.OAuthDynamicRegistrationIAT2FAReRender( + ctx.UserContext(), + services.OAuthDynamicRegistrationIAT2FAReRenderOptions{ + RequestID: requestID, + Domain: hiddenFields.ClientID, + ACCClientID: uPrms.ACCClientID, + SessionID: sessionID, + Errors: []string{"Invalid 2FA code"}, + CSRFToken: hiddenFields.CSRFToken, + Challenge: hiddenFields.CodeChallenge, + ChallengeMethod: hiddenFields.CodeChallengeMethod, + State: hiddenFields.State, + RedirectURI: hiddenFields.RedirectURI, + }, + ) + if serviceErr != nil { + return serviceErrorHTMLResponse(logger, ctx, serviceErr) + } + + logResponse(logger, ctx, fiber.StatusOK) + return ctx. + Status(fiber.StatusOK). + Type("html"). + SendString(twoFAHTML) + } + + return serviceErrorHTMLResponse(logger, ctx, serviceErr) + } + + c.removeAccountIAT2FACookie(ctx, uPrms.ACCClientID) + c.saveAccountIATCookie(ctx, sessionKey) + logResponse(logger, ctx, fiber.StatusSeeOther) + return ctx.Redirect(redirectURL, fiber.StatusSeeOther) +} + +func (c *Controllers) OAuthDynamicRegistrationIATExtAuthGet(ctx *fiber.Ctx) error { + requestID := getRequestID(ctx) + logger := c.buildLogger(requestID, oauthDynamicRegistrationIAT, "OAuthDynamicRegistrationIATExtAuthGet") + logRequest(logger, ctx) + + uPrms := params.OAuthDynamicRegistrationIATExtAuthURLParams{ + ACCClientID: ctx.Params("accClientID"), + Provider: ctx.Params("provider"), + } + if err := c.validate.StructCtx(ctx.UserContext(), &uPrms); err != nil { + return serviceErrorHTMLResponse(logger, ctx, exceptions.NewNotFoundError()) + } + + baseQPrms := params.OAuthDynamicRegistrationIATAuthBaseQueryParams{ + ClientID: ctx.Query("client_id"), + RedirectURI: ctx.Query("redirect_uri"), + } + if err := c.validate.StructCtx(ctx.UserContext(), baseQPrms); err != nil { + return serviceErrorHTMLResponse(logger, ctx, exceptions.NewForbiddenError()) + } + + responseType := ctx.Query("response_type") + state := ctx.Query("state") + if responseType != "code" { + return c.redirectErrorCallback(logger, ctx, baseQPrms.RedirectURI, state, exceptions.OAuthErrorUnsupportedResponseType) + } + + qPrms := params.OAuthDynamicRegistrationIATAuthQueryParams{ + ResponseType: responseType, + Challenge: ctx.Query("code_challenge"), + ChallengeMethod: ctx.Query("code_challenge_method"), + State: state, + } + if err := c.validate.StructCtx(ctx.UserContext(), qPrms); err != nil { + return c.redirectErrorCallback(logger, ctx, baseQPrms.RedirectURI, state, exceptions.OAuthErrorInvalidRequest) + } + + authURL, serviceErr := c.services.OAuthDynamicRegistrationIATExtGet( + ctx.UserContext(), + services.OAuthDynamicRegistrationIATExtGetOptions{ + RequestID: requestID, + ACCClientID: uPrms.ACCClientID, + Provider: uPrms.Provider, + Domain: baseQPrms.ClientID, + CallbackURL: baseQPrms.RedirectURI, + RedirectURI: baseQPrms.RedirectURI, + State: qPrms.State, + BackendDomain: c.backendDomain, + }, + ) + if serviceErr != nil { + return serviceErrorHTMLResponse(logger, ctx, serviceErr) + } + + logResponse(logger, ctx, fiber.StatusFound) + return ctx.Redirect(authURL, fiber.StatusFound) +} + +func (c *Controllers) OAuthDynamicRegistrationIATExtCB(ctx *fiber.Ctx) error { + requestID := getRequestID(ctx) + logger := c.buildLogger(requestID, oauthDynamicRegistrationIAT, "OAuthDynamicRegistrationIATExtCB") + logRequest(logger, ctx) + + uPrms := params.OAuthDynamicRegistrationIATExtAuthURLParams{ + ACCClientID: ctx.Params("accClientID"), + Provider: ctx.Params("provider"), + } + if err := c.validate.StructCtx(ctx.UserContext(), &uPrms); err != nil { + return serviceErrorHTMLResponse(logger, ctx, exceptions.NewNotFoundError()) + } + + qPrms := params.OAuthCallbackQueryParams{ + Code: ctx.Query("code"), + State: ctx.Query("state"), + } + if err := c.validate.StructCtx(ctx.UserContext(), &qPrms); err != nil { + return serviceErrorHTMLResponse(logger, ctx, exceptions.NewForbiddenError()) + } + + cbURL, serviceErr := c.services.OAuthDynamicRegistrationIATExtCB( + ctx.UserContext(), + services.OAuthDynamicRegistrationIATExtCBOptions{ + RequestID: requestID, + ACCClientID: uPrms.ACCClientID, + Provider: uPrms.Provider, + State: qPrms.State, + Code: qPrms.Code, + RedirectURL: "https://" + c.backendDomain + paths.V1 + paths.AccountsBase + + paths.CredentialsBase + paths.DynamicRegistrationBase + paths.InitialAccessToken + + "/" + uPrms.ACCClientID + paths.OAuthAuth + paths.InitialAccessTokenAuthEXT + "/" + + uPrms.Provider + paths.InitialAccessTokenCallback, + BackendDomain: c.backendDomain, + }, + ) + if serviceErr != nil { + return serviceErrorHTMLResponse(logger, ctx, serviceErr) + } + + logResponse(logger, ctx, fiber.StatusFound) + return ctx.Redirect(cbURL, fiber.StatusFound) +} + +func (c *Controllers) OAuthDynamicRegistrationIATExtAppleCB(ctx *fiber.Ctx) error { + requestID := getRequestID(ctx) + logger := c.buildLogger(requestID, oauthDynamicRegistrationIAT, "OAuthDynamicRegistrationIATExtAppleCB") + logRequest(logger, ctx) + + uPrms := params.OAuthDynamicRegistrationIATExtAppleURLParams{ + ACCClientID: ctx.Params("accClientID"), + } + if err := c.validate.StructCtx(ctx.UserContext(), &uPrms); err != nil { + return serviceErrorHTMLResponse(logger, ctx, exceptions.NewNotFoundError()) + } + + if ctx.Get("Content-Type") != "application/x-www-form-urlencoded" { + return serviceErrorHTMLResponse(logger, ctx, exceptions.NewUnsupportedMediaTypeError("Only application/x-www-form-urlencoded is supported")) + } + + qPrms := bodies.OAuthDynamicRegistrationIATExtAppleBody{ + Code: ctx.FormValue("code"), + State: ctx.FormValue("state"), + User: ctx.FormValue("user"), + } + if err := c.validate.StructCtx(ctx.UserContext(), &qPrms); err != nil { + return serviceErrorHTMLResponse(logger, ctx, exceptions.NewForbiddenError()) + } + + user := new(bodies.OAuthDynamicRegistrationIATExtAppleUserBody) + if err := json.Unmarshal([]byte(qPrms.User), user); err != nil { + return serviceErrorHTMLResponse(logger, ctx, exceptions.NewForbiddenError()) + } + if err := c.validate.StructCtx(ctx.UserContext(), user); err != nil { + return serviceErrorHTMLResponse(logger, ctx, exceptions.NewForbiddenError()) + } + + cbURL, serviceErr := c.services.OAuthDynamicRegistrationIATExtAppleCB( + ctx.UserContext(), + services.OAuthDynamicRegistrationIATExtAppleCBOptions{ + RequestID: requestID, + ACCClientID: uPrms.ACCClientID, + Email: user.Email, + Code: qPrms.Code, + State: qPrms.State, + RedirectURL: "https://" + c.backendDomain + paths.V1 + paths.AccountsBase + + paths.CredentialsBase + paths.DynamicRegistrationBase + paths.InitialAccessToken + + "/" + uPrms.ACCClientID + paths.OAuthAuth + paths.InitialAccessTokenAuthEXT + "/" + + services.AuthProviderApple + paths.InitialAccessTokenCallback, + BackendDomain: c.backendDomain, + }, + ) + if serviceErr != nil { + return serviceErrorHTMLResponse(logger, ctx, serviceErr) + } + + logResponse(logger, ctx, fiber.StatusFound) + return ctx.Redirect(cbURL, fiber.StatusFound) +} + +func (c *Controllers) OAuthDynamicRegistrationIATToken(ctx *fiber.Ctx) error { + requestID := getRequestID(ctx) + logger := c.buildLogger(requestID, oauthDynamicRegistrationIAT, "OAuthDynamicRegistrationIATToken") + logRequest(logger, ctx) + + if ctx.Get("Content-Type") != "application/x-www-form-urlencoded" { + return oauthErrorResponse(logger, ctx, exceptions.OAuthErrorInvalidRequest) + } + + grantType := ctx.Get("grant_type") + if grantType != "authorization_code" { + return oauthErrorResponse(logger, ctx, exceptions.OAuthErrorUnsupportedGrantType) + } + + body := bodies.OAuthDynamicRegistrationIATTokenBody{ + GrantType: grantType, + Code: ctx.FormValue("code"), + ClientID: ctx.FormValue("client_id"), + CodeVerifier: ctx.FormValue("code_verifier"), + } + if err := c.validate.StructCtx(ctx.UserContext(), &body); err != nil { + return oauthErrorResponse(logger, ctx, exceptions.OAuthErrorInvalidRequest) + } + + authDTO, serviceErr := c.services.VerifyOAuthDynamicRegistrationIATCode( + ctx.UserContext(), + services.VerifyOAuthDynamicRegistrationIATCodeOptions{ + RequestID: requestID, + Code: body.Code, + CodeVerifier: body.CodeVerifier, + Domain: body.ClientID, + }, + ) + if serviceErr != nil { + return oauthErrorResponseMapper(logger, ctx, serviceErr) + } + + logResponse(logger, ctx, fiber.StatusOK) + return ctx.Status(fiber.StatusOK).JSON(authDTO) +} diff --git a/idp/internal/exceptions/controllers.go b/idp/internal/exceptions/controllers.go index 54d0742..a67da73 100644 --- a/idp/internal/exceptions/controllers.go +++ b/idp/internal/exceptions/controllers.go @@ -28,10 +28,16 @@ const ( OAuthErrorInvalidGrant string = "invalid_grant" OAuthErrorUnauthorizedClient string = "unauthorized_client" OAuthErrorAccessDenied string = "access_denied" - OAuthServerError string = "server_error" + OAuthErrorServerError string = "server_error" OAuthErrorInvalidScope string = "invalid_scope" OAuthErrorUnsupportedGrantType string = "unsupported_grant_type" OAuthErrorUnsupportedResponseType string = "unsupported_response_type" + + OAuthErrorInvalidRedirectURI string = "invalid_redirect_uri" + OAuthErrorInvalidClientMetadata string = "invalid_client_metadata" + OAuthErrorInvalidSoftwareStatement string = "invalid_software_statement" + OAuthErrorUnapprovedSoftwareStatement string = "unapproved_software_statement" + OAuthErrorInvalidToken string = "invalid_token" ) type ErrorResponse struct { diff --git a/idp/internal/exceptions/services.go b/idp/internal/exceptions/services.go index c2191b6..7a34dc0 100644 --- a/idp/internal/exceptions/services.go +++ b/idp/internal/exceptions/services.go @@ -23,6 +23,8 @@ const ( CodeUnauthorized string = "UNAUTHORIZED" CodeForbidden string = "FORBIDDEN" CodeUnsupportedMediaType string = "UNSUPPORTED_MEDIA_TYPE" + CodeInvalidToken string = "INVALID_TOKEN" + CodeUnauthorizedToken string = "UNAUTHORIZED_TOKEN" ) const ( @@ -91,6 +93,14 @@ func NewForbiddenError() *ServiceError { return NewError(CodeForbidden, MessageForbidden) } +func NewInvalidTokenError(message string) *ServiceError { + return NewError(CodeInvalidToken, message) +} + +func NewUnauthorizedTokenError(message string) *ServiceError { + return NewError(CodeUnauthorizedToken, message) +} + func NewForbiddenValidationError(message string) *ServiceError { return NewError(CodeForbidden, message) } diff --git a/idp/internal/providers/tokens/dynamic_registration.go b/idp/internal/providers/tokens/dynamic_registration.go deleted file mode 100644 index 1a73d72..0000000 --- a/idp/internal/providers/tokens/dynamic_registration.go +++ /dev/null @@ -1,63 +0,0 @@ -// Copyright (c) 2025 Afonso Barracha -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at https://mozilla.org/MPL/2.0/. - -package tokens - -import ( - "fmt" - "time" - - "github.com/golang-jwt/jwt/v5" - "github.com/google/uuid" -) - -type accountCredentialsDynamicRegistrationClaims struct { - AccountClaims - Domain string `json:"domain"` - ClientID string `json:"client_id"` - jwt.RegisteredClaims -} - -type AccountCredentialsDynamicRegistrationTokenOptions struct { - AccountPublicID uuid.UUID - AccountVersion int32 - Domain string - ClientID string -} - -func (t *Tokens) CreateAccountCredentialsDynamicRegistrationToken( - opts AccountCredentialsDynamicRegistrationTokenOptions, -) *jwt.Token { - now := time.Now() - iat := jwt.NewNumericDate(now) - exp := jwt.NewNumericDate(now.Add(time.Second * time.Duration(t.dynamicRegistrationTTL))) - return jwt.NewWithClaims( - jwt.SigningMethodEdDSA, - accountCredentialsDynamicRegistrationClaims{ - AccountClaims: AccountClaims{ - AccountID: opts.AccountPublicID, - AccountVersion: opts.AccountVersion, - }, - Domain: opts.Domain, - ClientID: opts.ClientID, - RegisteredClaims: jwt.RegisteredClaims{ - Issuer: fmt.Sprintf("https://%s", t.backendDomain), - Audience: []string{ - fmt.Sprintf("https://%s", opts.Domain), - }, - Subject: opts.Domain, - IssuedAt: iat, - NotBefore: iat, - ExpiresAt: exp, - ID: uuid.NewString(), - }, - }, - ) -} - -func (t *Tokens) GetDynamicRegistrationTTL() int64 { - return t.dynamicRegistrationTTL -} diff --git a/idp/internal/providers/tokens/dynamic_registration_iat.go b/idp/internal/providers/tokens/dynamic_registration_iat.go new file mode 100644 index 0000000..d2b59c3 --- /dev/null +++ b/idp/internal/providers/tokens/dynamic_registration_iat.go @@ -0,0 +1,106 @@ +// Copyright (c) 2025 Afonso Barracha +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +package tokens + +import ( + "context" + "fmt" + "time" + + "github.com/golang-jwt/jwt/v5" + "github.com/google/uuid" + "github.com/tugascript/devlogs/idp/internal/utils" +) + +const dynamicRegistrationIATLocation = "dynamic_registration_iat" + +type accountCredentialsDynamicRegistrationClaims struct { + AccountClaims + Domain string `json:"domain"` + ClientID string `json:"client_id"` + jwt.RegisteredClaims +} + +type AccountCredentialsDynamicRegistrationTokenOptions struct { + AccountPublicID uuid.UUID + AccountVersion int32 + Domain string + ClientID string +} + +func (t *Tokens) CreateAccountCredentialsDynamicRegistrationToken( + opts AccountCredentialsDynamicRegistrationTokenOptions, +) *jwt.Token { + now := time.Now() + iat := jwt.NewNumericDate(now) + exp := jwt.NewNumericDate(now.Add(time.Second * time.Duration(t.dynamicRegistrationTTL))) + iss := fmt.Sprintf("https://%s", t.backendDomain) + return jwt.NewWithClaims( + jwt.SigningMethodEdDSA, + accountCredentialsDynamicRegistrationClaims{ + AccountClaims: AccountClaims{ + AccountID: opts.AccountPublicID, + AccountVersion: opts.AccountVersion, + }, + Domain: opts.Domain, + ClientID: opts.ClientID, + RegisteredClaims: jwt.RegisteredClaims{ + Issuer: iss, + Audience: []string{iss}, + Subject: opts.Domain, + IssuedAt: iat, + NotBefore: iat, + ExpiresAt: exp, + ID: uuid.NewString(), + }, + }, + ) +} + +type VerifyAccountCredentialsDynamicRegistrationTokenOptions struct { + RequestID string + IAT string + GetPublicJWK GetPublicJWK +} + +func (t *Tokens) VerifyAccountCredentialsDynamicRegistrationToken( + ctx context.Context, + opts VerifyAccountCredentialsDynamicRegistrationTokenOptions, +) (string, AccountClaims, error) { + logger := utils.BuildLogger(t.logger, utils.LoggerOptions{ + Location: dynamicRegistrationIATLocation, + Method: "VerifyAccountCredentialsDynamicRegistrationToken", + RequestID: opts.RequestID, + }) + logger.DebugContext(ctx, "Verifying account credentials dynamic registration IAT...") + + claims := new(accountCredentialsDynamicRegistrationClaims) + if _, err := jwt.ParseWithClaims(opts.IAT, claims, func(token *jwt.Token) (interface{}, error) { + kid, err := extractTokenKID(token) + if err != nil { + logger.DebugContext(ctx, "Failed to extract KID from account credentials dynamic registration IAT", "error", err) + return nil, err + } + + jwk, err := opts.GetPublicJWK(kid, utils.SupportedCryptoSuiteEd25519) + if err != nil { + logger.WarnContext(ctx, "Failed to get public JWK for account credentials dynamic registration IAT", "error", err, "kid", kid) + return nil, err + } + + return jwk.ToUsableKey() + }); err != nil { + logger.WarnContext(ctx, "Failed to verify account credentials dynamic registration IAT", "error", err) + return "", AccountClaims{}, err + } + + return claims.Domain, claims.AccountClaims, nil +} + +func (t *Tokens) GetDynamicRegistrationTTL() int64 { + return t.dynamicRegistrationTTL +} diff --git a/idp/internal/server/routes/oauth.go b/idp/internal/server/routes/oauth.go index 599aaa1..6255059 100644 --- a/idp/internal/server/routes/oauth.go +++ b/idp/internal/server/routes/oauth.go @@ -23,4 +23,11 @@ func (r *Routes) OAuthRoutes(app *fiber.App) { // OAuth2 Callbacks router.Post(paths.OAuthAppleCallback, r.controllers.AccountAppleCallback) router.Get(paths.OAuthCallback, r.controllers.AccountOAuthCallback) + + // Register + router.Post( + paths.AccountsBase+paths.AccountsSingle+paths.OAuthRegister, + r.controllers.AccountCredentialsDRIATMiddleware, + r.controllers.OAuthDynamicRegistration, + ) } diff --git a/idp/internal/services/account_credentials_registration.go b/idp/internal/services/account_credentials_registration.go index e5038c1..bcbdc09 100644 --- a/idp/internal/services/account_credentials_registration.go +++ b/idp/internal/services/account_credentials_registration.go @@ -49,6 +49,7 @@ type checkAccountCRDomainOptions struct { requestID string accountPublicID uuid.UUID domain string + iatDomain string requireVerifiedDomains bool } @@ -66,6 +67,13 @@ func (s *Services) checkAccountCRDomain( logger.WarnContext(ctx, "Failed to parse base domain", "error", err) return "", exceptions.NewValidationError("invalid client URI") } + if opts.iatDomain != "" && opts.domain != opts.iatDomain && baseDomain != opts.iatDomain { + logger.WarnContext(ctx, "Client URI base domain does not match IAT domain", + "baseDomain", baseDomain, + "iatDomain", opts.iatDomain, + ) + return "", exceptions.NewUnauthorizedError() + } var count int64 if baseDomain != opts.domain { @@ -586,8 +594,9 @@ func (s *Services) mapAccountCredentialsRegistrationDataToDBParams( type CreateAccountCredentialsRegistrationOptions struct { RequestID string AccountPublicID uuid.UUID - AccountVersion int32 IsAuthenticated bool + IATDomain string + AccountVersion int32 ApplicationType string RedirectURIs []string TokenEndpointAuthMethod string @@ -690,16 +699,6 @@ func (s *Services) CreateAccountCredentialsRegistration( return dtos.AccountCredentialsDTO{}, serviceErr } - accountID, serviceErr := s.GetAccountIDByPublicIDAndVersion(ctx, GetAccountIDByPublicIDAndVersionOptions{ - RequestID: opts.RequestID, - PublicID: opts.AccountPublicID, - Version: opts.AccountVersion, - }) - if serviceErr != nil { - logger.ErrorContext(ctx, "Failed to get account ID", "serviceError", serviceErr) - return dtos.AccountCredentialsDTO{}, serviceErr - } - if slices.Contains(accountDRConfigDTO.RequireInitialAccessTokenCredentialTypes, applicationType) && !opts.IsAuthenticated { logger.WarnContext(ctx, "Account dynamic registration configuration needs to contain initial access token") @@ -712,16 +711,23 @@ func (s *Services) CreateAccountCredentialsRegistration( return dtos.AccountCredentialsDTO{}, exceptions.NewUnauthorizedError() } - _, serviceErr = s.GetAccountByPublicIDAndVersion(ctx, GetAccountByPublicIDAndVersionOptions{ + accountDTO, serviceErr := s.GetAccountByPublicID(ctx, GetAccountByPublicIDOptions{ RequestID: opts.RequestID, PublicID: opts.AccountPublicID, - Version: opts.AccountVersion, }) if serviceErr != nil { - logger.WarnContext(ctx, "Failed to get account", "serviceError", serviceErr) + logger.ErrorContext(ctx, "Failed to get account ID", "serviceError", serviceErr) return dtos.AccountCredentialsDTO{}, serviceErr } + if opts.AccountVersion != 0 && accountDTO.Version() != opts.AccountVersion { + logger.WarnContext(ctx, "Account version mismatch", + "providedVersion", opts.AccountVersion, + "currentVersion", accountDTO.Version(), + ) + return dtos.AccountCredentialsDTO{}, exceptions.NewUnauthorizedError() + } + accountID := accountDTO.ID() parsedClientURI, err := url.Parse(opts.ClientURI) if err != nil { logger.WarnContext(ctx, "Failed to parse client URI", "error", err) @@ -732,6 +738,7 @@ func (s *Services) CreateAccountCredentialsRegistration( baseDomain, serviceErr := s.checkAccountCRDomain(ctx, checkAccountCRDomainOptions{ requestID: opts.RequestID, accountPublicID: opts.AccountPublicID, + iatDomain: opts.IATDomain, domain: domain, requireVerifiedDomains: slices.Contains(accountDRConfigDTO.RequireVerifiedDomainsCredentialsType, applicationType), }) @@ -793,7 +800,7 @@ func (s *Services) CreateAccountCredentialsRegistration( }) if err != nil { logger.WarnContext(ctx, "Failed to verify software statement", "error", err) - return dtos.AccountCredentialsDTO{}, exceptions.NewUnauthorizedError() + return dtos.AccountCredentialsDTO{}, exceptions.NewInvalidTokenError("invalid software statement") } if serviceErr := s.verifySoftwareStatementSTDClaims(ctx, verifySoftwareStatementSTDClaimsOptions{ requestID: opts.RequestID, @@ -823,7 +830,7 @@ func (s *Services) CreateAccountCredentialsRegistration( params, serviceErr := s.mapAccountCredentialsRegistrationDataToDBParams(ctx, mapAccountCredentialsRegistrationDataToDBParamsOptions{ applicationType: applicationType, accountPublicID: opts.AccountPublicID, - accountID: opts.AccountVersion, + accountID: accountDTO.Version(), domain: domain, requestID: opts.RequestID, tokenEndpointAuthMethod: tokenEndpointAuthMethod, diff --git a/idp/internal/services/account_credentials_registration_iat.go b/idp/internal/services/account_credentials_registration_iat.go index b6d4021..b11131a 100644 --- a/idp/internal/services/account_credentials_registration_iat.go +++ b/idp/internal/services/account_credentials_registration_iat.go @@ -91,3 +91,41 @@ func (s *Services) CreateAccountCredentialsRegistrationIAT( logger.InfoContext(ctx, "Created account credentials registration IAT successfully") return signedToken, nil } + +type ProcessAccountCredentialsRegistrationIATAuthOptions struct { + RequestID string + AuthHeader string +} + +func (s *Services) ProcessAccountCredentialsRegistrationIATAuth( + ctx context.Context, + opts ProcessAccountCredentialsRegistrationIATAuthOptions, +) (string, tokens.AccountClaims, *exceptions.ServiceError) { + logger := s.buildLogger(opts.RequestID, accountCredentialsRegistrationIATLocation, "ProcessAccountCredentialsRegistrationIATAuth") + logger.InfoContext(ctx, "Processing account credentials registration IAT auth...") + + token, serviceErr := extractAuthHeaderToken(opts.AuthHeader) + if serviceErr != nil { + logger.WarnContext(ctx, "Failed to extract token from auth header", "serviceError", serviceErr) + return "", tokens.AccountClaims{}, serviceErr + } + + domain, accountClaims, err := s.jwt.VerifyAccountCredentialsDynamicRegistrationToken( + ctx, + tokens.VerifyAccountCredentialsDynamicRegistrationTokenOptions{ + RequestID: opts.RequestID, + IAT: token, + GetPublicJWK: s.BuildGetGlobalPublicKeyFn(ctx, BuildGetGlobalVerifyKeyFnOptions{ + RequestID: opts.RequestID, + KeyType: database.TokenKeyTypeDynamicRegistration, + }), + }, + ) + if err != nil { + logger.WarnContext(ctx, "Failed to verify account credentials registration IAT", "error", err) + return "", tokens.AccountClaims{}, exceptions.NewUnauthorizedError() + } + + logger.InfoContext(ctx, "Processed account credentials registration IAT auth successfully") + return domain, accountClaims, nil +} diff --git a/idp/internal/services/dtos/account_credentials.go b/idp/internal/services/dtos/account_credentials.go index fcf6290..7cb2d31 100644 --- a/idp/internal/services/dtos/account_credentials.go +++ b/idp/internal/services/dtos/account_credentials.go @@ -17,7 +17,9 @@ import ( ) type AccountCredentialsDTO struct { - ClientID string `json:"client_id"` + ClientID string `json:"client_id"` + ClientIDIAT int64 `json:"client_idiat"` + Type database.AccountCredentialsType `json:"application_type"` ClientName string `json:"client_name"` Domain string `json:"domain"` @@ -58,7 +60,7 @@ type AccountCredentialsDTO struct { ClientSecretID string `json:"client_secret_id,omitempty"` ClientSecret string `json:"client_secret,omitempty"` ClientSecretJWK utils.JWK `json:"client_secret_jwk,omitempty"` - ClientSecretExp int64 `json:"client_secret_exp,omitempty"` + ClientSecretExp int64 `json:"client_secret_expires_at,omitempty"` id int32 accountId int32 @@ -140,6 +142,7 @@ func MapAccountCredentialsToDTO( return AccountCredentialsDTO{ id: accountCredential.ID, ClientID: accountCredential.ClientID, + ClientIDIAT: accountCredential.CreatedAt.Unix(), Type: accountCredential.CredentialsType, ClientName: accountCredential.ClientName, Domain: accountCredential.Domain, @@ -221,6 +224,7 @@ func MapAccountCredentialsToDTOWithJWK( TokenEndpointAuthMethod: accountCredential.TokenEndpointAuthMethod, accountId: accountCredential.AccountID, ClientID: accountCredential.ClientID, + ClientIDIAT: accountCredential.CreatedAt.Unix(), ClientSecretID: jwk.GetKeyID(), ClientSecretJWK: jwk, ClientSecretExp: exp.Unix(), @@ -292,6 +296,7 @@ func MapAccountCredentialsToDTOWithSecret( TokenEndpointAuthMethod: accountCredential.TokenEndpointAuthMethod, accountId: accountCredential.AccountID, ClientID: accountCredential.ClientID, + ClientIDIAT: accountCredential.CreatedAt.Unix(), ClientSecretID: secretID, ClientSecret: fmt.Sprintf("%s.%s", secretID, secret), ClientSecretExp: exp.Unix(), diff --git a/idp/internal/services/software_statement.go b/idp/internal/services/software_statement.go index d6d771c..29538f1 100644 --- a/idp/internal/services/software_statement.go +++ b/idp/internal/services/software_statement.go @@ -81,7 +81,7 @@ func (s *Services) verifySoftwareStatementSTDClaims( logger.WarnContext(ctx, "Software statement issuer does not match client URI domain or base domain", "issuer", opts.claims.Issuer, ) - return exceptions.NewUnauthorizedError() + return exceptions.NewUnauthorizedTokenError("issuer does not match client URI domain or base domain") } if opts.claims.Audience == nil || !slices.ContainsFunc(opts.claims.Audience, func(aud string) bool { return aud == fmt.Sprintf("https://%s", opts.frontendDomain) || aud == fmt.Sprintf("https://%s", opts.backendDomain) @@ -89,25 +89,25 @@ func (s *Services) verifySoftwareStatementSTDClaims( logger.WarnContext(ctx, "Software statement audience does not match frontend or backend domain", "audience", opts.claims.Audience, ) - return exceptions.NewUnauthorizedError() + return exceptions.NewUnauthorizedTokenError("audience does not match frontend or backend") } if opts.claims.IssuedAt == nil || opts.claims.IssuedAt.Time.IsZero() || opts.claims.IssuedAt.Time.After(time.Now()) { logger.WarnContext(ctx, "Software statement issued at claim is invalid", "issuedAt", opts.claims.IssuedAt, ) - return exceptions.NewUnauthorizedError() + return exceptions.NewUnauthorizedTokenError("issued at claim is invalid") } if opts.claims.NotBefore != nil && !opts.claims.NotBefore.Time.IsZero() && opts.claims.NotBefore.Time.After(time.Now()) { logger.WarnContext(ctx, "Software statement not before claim is invalid", "notBefore", opts.claims.NotBefore, ) - return exceptions.NewUnauthorizedError() + return exceptions.NewUnauthorizedTokenError("not before claim is invalid") } if opts.claims.ExpiresAt == nil || opts.claims.ExpiresAt.Time.IsZero() || opts.claims.ExpiresAt.Time.Before(time.Now()) { logger.WarnContext(ctx, "Software statement expiration claim is invalid", "expiresAt", opts.claims.ExpiresAt, ) - return exceptions.NewUnauthorizedError() + return exceptions.NewUnauthorizedTokenError("expiresAt claim is invalid") } logger.InfoContext(ctx, "Verified software statement standard claims") From 5c6dafa94309362542a7a574968b0097c309298d Mon Sep 17 00:00:00 2001 From: Afonso Barracha Date: Tue, 4 Nov 2025 23:53:32 +1300 Subject: [PATCH 22/23] feat(idp): add custom host aware routes --- idp/initial_schema.dbml | 12 +- .../app_dynamic_registration_configs.go | 137 ++++ .../app_dynamic_registration_configs.go | 28 + .../bodies/oauth_dynamic_registration.go | 2 +- idp/internal/controllers/middleware.go | 27 +- idp/internal/controllers/oauth.go | 6 +- .../controllers/oauth_dynamic_registration.go | 91 +++ idp/internal/controllers/paths/common.go | 13 +- idp/internal/controllers/users_oauth.go | 47 ++ idp/internal/controllers/well_known.go | 30 - .../account_token_signing_keys.sql.go | 1 + .../app_dynamic_registration_configs.sql.go | 306 ++++++++ idp/internal/providers/database/apps.sql.go | 19 + .../dynamic_registration_domains.sql.go | 70 +- ...0241213231542_create_initial_schema.up.sql | 20 +- idp/internal/providers/database/models.go | 2 +- .../queries/account_token_signing_keys.sql | 1 + .../app_dynamic_registration_configs.sql | 82 +++ .../providers/database/queries/apps.sql | 5 + .../queries/dynamic_registration_domains.sql | 26 +- .../tokens/dynamic_registration_iat.go | 46 +- idp/internal/server/routes.go | 2 +- .../server/routes/account_credentials.go | 6 +- .../routes/account_dynamic_registration.go | 25 +- idp/internal/server/routes/accounts.go | 2 +- idp/internal/server/routes/app_designs.go | 2 +- idp/internal/server/routes/apps.go | 30 +- idp/internal/server/routes/auth.go | 38 +- idp/internal/server/routes/common.go | 35 +- idp/internal/server/routes/oauth.go | 45 +- idp/internal/server/routes/oidc_configs.go | 2 +- idp/internal/server/routes/users.go | 2 +- idp/internal/server/routes/users_auth.go | 32 - idp/internal/server/routes/well_known.go | 7 +- .../account_credentials_registration.go | 252 +------ .../account_credentials_registration_iat.go | 27 +- .../services/app_dynamic_registration.go | 662 ++++++++++++++++++ .../app_dynamic_registration_configs.go | 420 +++++++++++ .../services/app_dynamic_registration_iat.go | 138 ++++ idp/internal/services/apps.go | 103 +-- idp/internal/services/auth.go | 35 - .../dtos/app_dynamic_registration_config.go | 61 ++ .../services/dynamic_registration_domains.go | 136 ++++ idp/internal/services/oauth.go | 8 +- .../services/oauth_dynamic_registration.go | 11 +- idp/internal/services/software_statement.go | 264 ++++--- idp/internal/services/users_oauth.go | 7 + 47 files changed, 2716 insertions(+), 607 deletions(-) create mode 100644 idp/internal/controllers/app_dynamic_registration_configs.go create mode 100644 idp/internal/controllers/bodies/app_dynamic_registration_configs.go create mode 100644 idp/internal/controllers/users_oauth.go create mode 100644 idp/internal/providers/database/app_dynamic_registration_configs.sql.go create mode 100644 idp/internal/providers/database/queries/app_dynamic_registration_configs.sql delete mode 100644 idp/internal/server/routes/users_auth.go create mode 100644 idp/internal/services/app_dynamic_registration.go create mode 100644 idp/internal/services/app_dynamic_registration_configs.go create mode 100644 idp/internal/services/app_dynamic_registration_iat.go create mode 100644 idp/internal/services/dtos/app_dynamic_registration_config.go create mode 100644 idp/internal/services/users_oauth.go diff --git a/idp/initial_schema.dbml b/idp/initial_schema.dbml index b274791..3763f41 100644 --- a/idp/initial_schema.dbml +++ b/idp/initial_schema.dbml @@ -106,7 +106,7 @@ Table token_signing_keys as TS { crypto_suite token_crypto_suite [not null] expires_at timestamptz [not null] - usage token_key_usage [not null, default: 'account'] + usage token_key_usage [not null] is_distributed boolean [not null, default: false] is_revoked boolean [not null, default: false] @@ -116,7 +116,6 @@ Table token_signing_keys as TS { Indexes { (kid) [unique, name: 'token_signing_keys_kid_uidx'] (expires_at) [name: 'token_signing_keys_expires_at_idx'] - (is_distributed, is_revoked, expires_at) [name: 'token_signing_keys_is_distributed_is_revoked_expires_at_idx'] (key_type, usage, is_revoked, expires_at) [name: 'token_signing_keys_key_type_usage_is_revoked_expires_at_idx'] (usage, is_distributed, is_revoked, expires_at) [name: 'token_signing_keys_usage_is_distributed_is_revoked_expires_at_idx'] (kid, is_revoked) [name: 'token_signing_keys_kid_is_revoked_idx'] @@ -1039,9 +1038,9 @@ Table app_dynamic_registration_configs as APDRC { id serial [pk] account_id integer [not null] + account_public_id uuid [not null] allowed_app_types "app_type[]" [not null] - whitelisted_domains "varchar(250)[]" [not null] default_allow_user_registration boolean [not null] default_auth_providers "auth_provider[]" [not null] default_username_column app_username_column [not null] @@ -1059,7 +1058,7 @@ Table app_dynamic_registration_configs as APDRC { initial_access_token_max_uses int [not null, default: 1] allowed_grant_types "grant_type[]" [not null, default: '{ "authorization_code", "refresh_token", "client_credentials", "urn:ietf:params:oauth:grant-type:device_code", "urn:ietf:params:oauth:grant-type:jwt-bearer" }'] - allowed_response_types "response_type[]" [not null, default: '{ "code", "id_token", "code id_token" }'] + allowed_response_types "response_type[]" [not null, default: '{ "code", "code id_token" }'] allowed_token_endpoint_auth_methods "auth_method[]" [not null, default: '{ "none", "client_secret_post", "client_secret_basic", "client_secret_jwt", "private_key_jwt" }'] max_redirect_uris int [not null, default: 10] @@ -1068,6 +1067,7 @@ Table app_dynamic_registration_configs as APDRC { Indexes { (account_id) [name: 'app_dynamic_registration_configs_account_id_idx'] + (account_public_id) [name: 'app_dynamic_registration_configs_account_public_id_idx'] } } Ref: APDRC.account_id > A.id [delete: cascade] @@ -1099,8 +1099,10 @@ Table dynamic_registration_domains as ADRD { Indexes { (account_id) [name: 'accounts_totps_account_id_idx'] (account_public_id) [name: 'account_dynamic_registration_domains_account_public_id_idx'] - (domain) [name: 'account_dynamic_registration_domains_domain_idx'] (account_public_id, domain) [unique, name: 'account_dynamic_registration_domains_account_public_id_domain_uidx'] + (account_public_id, domain, verified_at) [name: 'account_dynamic_registration_domains_account_public_id_domain_verified_at_idx'] + (account_public_id, domain, usages) [name: 'account_dynamic_registration_domains_account_public_id_domain_usages_idx'] + (account_public_id, domain, usages, verified_at) [name: 'account_dynamic_registration_domains_account_public_id_domain_usages_verified_at_idx'] } } Ref: ADRD.account_id > A.id [delete: cascade] diff --git a/idp/internal/controllers/app_dynamic_registration_configs.go b/idp/internal/controllers/app_dynamic_registration_configs.go new file mode 100644 index 0000000..470e9fa --- /dev/null +++ b/idp/internal/controllers/app_dynamic_registration_configs.go @@ -0,0 +1,137 @@ +// Copyright (c) 2025 Afonso Barracha +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +package controllers + +import ( + "github.com/gofiber/fiber/v2" + + "github.com/tugascript/devlogs/idp/internal/controllers/bodies" + "github.com/tugascript/devlogs/idp/internal/services" +) + +const ( + appDynamicRegistrationConfigsLocation string = "app_dynamic_registration_configs" +) + +func (c *Controllers) UpsertAppDynamicRegistrationConfig(ctx *fiber.Ctx) error { + requestID := getRequestID(ctx) + logger := c.buildLogger( + requestID, + appDynamicRegistrationConfigsLocation, + "UpsertAppDynamicRegistrationConfig", + ) + logRequest(logger, ctx) + + accountClaims, serviceErr := getAccountClaims(ctx) + if serviceErr != nil { + return serviceErrorResponse(logger, ctx, serviceErr) + } + + body := new(bodies.AppDynamicRegistrationConfigBody) + if err := ctx.BodyParser(body); err != nil { + return parseRequestErrorResponse(logger, ctx, err) + } + if err := c.validate.StructCtx(ctx.UserContext(), body); err != nil { + return validateBodyErrorResponse(logger, ctx, err) + } + + dto, created, serviceErr := c.services.SaveAppDynamicRegistrationConfig( + ctx.UserContext(), + services.SaveAppDynamicRegistrationConfigOptions{ + RequestID: requestID, + AccountPublicID: accountClaims.AccountID, + AccountVersion: accountClaims.AccountVersion, + AllowedAppTypes: body.AllowedAppTypes, + DefaultAllowUserRegistration: body.DefaultAllowUserRegistration, + DefaultAuthProviders: body.DefaultAuthProviders, + DefaultUsernameColumn: body.DefaultUsernameColumn, + DefaultAllowedScopes: body.DefaultAllowedScopes, + DefaultScopes: body.DefaultScopes, + RequireVerifiedDomainsAppTypes: body.RequireVerifiedDomainsAppTypes, + RequireSoftwareStatementAppTypes: body.RequireSoftwareStatementAppTypes, + SoftwareStatementVerificationMethods: body.SoftwareStatementVerificationMethods, + RequireInitialAccessTokenAppTypes: body.RequireInitialAccessTokenAppTypes, + InitialAccessTokenGenerationMethods: body.InitialAccessTokenGenerationMethods, + InitialAccessTokenTtl: body.InitialAccessTokenTtl, + InitialAccessTokenMaxUses: body.InitialAccessTokenMaxUses, + AllowedGrantTypes: body.AllowedGrantTypes, + AllowedResponseTypes: body.AllowedResponseTypes, + AllowedTokenEndpointAuthMethods: body.AllowedTokenEndpointAuthMethods, + MaxRedirectUris: body.MaxRedirectUris, + }, + ) + if serviceErr != nil { + return serviceErrorResponse(logger, ctx, serviceErr) + } + + if created { + logResponse(logger, ctx, fiber.StatusCreated) + return ctx.Status(fiber.StatusCreated).JSON(&dto) + } + + logResponse(logger, ctx, fiber.StatusOK) + return ctx.Status(fiber.StatusOK).JSON(&dto) +} + +func (c *Controllers) GetAppDynamicRegistrationConfig(ctx *fiber.Ctx) error { + requestID := getRequestID(ctx) + logger := c.buildLogger( + requestID, + appDynamicRegistrationConfigsLocation, + "GetAppDynamicRegistrationConfig", + ) + logRequest(logger, ctx) + + accountClaims, serviceErr := getAccountClaims(ctx) + if serviceErr != nil { + return serviceErrorResponse(logger, ctx, serviceErr) + } + + dto, serviceErr := c.services.GetAppDynamicRegistrationConfig( + ctx.UserContext(), + services.GetAppDynamicRegistrationConfigOptions{ + RequestID: requestID, + AccountPublicID: accountClaims.AccountID, + }, + ) + if serviceErr != nil { + return serviceErrorResponse(logger, ctx, serviceErr) + } + + logResponse(logger, ctx, fiber.StatusOK) + return ctx.Status(fiber.StatusOK).JSON(&dto) +} + +func (c *Controllers) DeleteAppDynamicRegistrationConfig(ctx *fiber.Ctx) error { + requestID := getRequestID(ctx) + logger := c.buildLogger( + requestID, + appDynamicRegistrationConfigsLocation, + "DeleteAppDynamicRegistrationConfig", + ) + logRequest(logger, ctx) + + accountClaims, serviceErr := getAccountClaims(ctx) + if serviceErr != nil { + return serviceErrorResponse(logger, ctx, serviceErr) + } + + serviceErr = c.services.DeleteAppDynamicRegistrationConfig( + ctx.UserContext(), + services.DeleteAppDynamicRegistrationConfigOptions{ + RequestID: requestID, + AccountPublicID: accountClaims.AccountID, + AccountVersion: accountClaims.AccountVersion, + }, + ) + if serviceErr != nil { + return serviceErrorResponse(logger, ctx, serviceErr) + } + + logResponse(logger, ctx, fiber.StatusNoContent) + return ctx.SendStatus(fiber.StatusNoContent) +} diff --git a/idp/internal/controllers/bodies/app_dynamic_registration_configs.go b/idp/internal/controllers/bodies/app_dynamic_registration_configs.go new file mode 100644 index 0000000..559e5c2 --- /dev/null +++ b/idp/internal/controllers/bodies/app_dynamic_registration_configs.go @@ -0,0 +1,28 @@ +// Copyright (c) 2025 Afonso Barracha +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +package bodies + +type AppDynamicRegistrationConfigBody struct { + AllowedAppTypes []string `json:"allowed_app_types" validate:"required,unique,min=1,dive,oneof=web spa native backend device service mcp"` + DefaultAllowUserRegistration bool `json:"default_allow_user_registration"` + DefaultAuthProviders []string `json:"default_auth_providers,omitempty" validate:"omitempty,unique,dive,oneof=local apple facebook github google microsoft"` + DefaultUsernameColumn string `json:"default_username_column,omitempty" validate:"omitempty,oneof=email username both"` + DefaultAllowedScopes []string `json:"default_allowed_scopes,omitempty" validate:"omitempty,unique,dive,single_scope"` + DefaultScopes []string `json:"default_scopes,omitempty" validate:"omitempty,unique,dive,single_scope"` + RequireVerifiedDomainsAppTypes []string `json:"require_verified_domains_app_types,omitempty" validate:"omitempty,unique,dive,oneof=web spa native backend device service mcp"` + RequireSoftwareStatementAppTypes []string `json:"require_software_statement_app_types,omitempty" validate:"omitempty,unique,dive,oneof=web spa native backend device service mcp"` + SoftwareStatementVerificationMethods []string `json:"software_statement_verification_methods,omitempty" validate:"omitempty,unique,min=1,max=2,dive,oneof=manual jwks_uri"` + RequireInitialAccessTokenAppTypes []string `json:"require_initial_access_token_app_types,omitempty" validate:"omitempty,unique,dive,oneof=web spa native backend device service mcp"` + InitialAccessTokenGenerationMethods []string `json:"initial_access_token_generation_methods,omitempty" validate:"omitempty,unique,min=1,max=2,dive,oneof=manual authorization_code"` + InitialAccessTokenTtl int32 `json:"initial_access_token_ttl,omitempty" validate:"omitempty,min=1"` + InitialAccessTokenMaxUses int32 `json:"initial_access_token_max_uses,omitempty" validate:"omitempty,min=1"` + AllowedGrantTypes []string `json:"allowed_grant_types,omitempty" validate:"omitempty,unique,min=1,dive,oneof=authorization_code refresh_token client_credentials urn:ietf:params:oauth:grant-type:device_code urn:ietf:params:oauth:grant-type:jwt-bearer"` + AllowedResponseTypes []string `json:"allowed_response_types,omitempty" validate:"omitempty,unique,dive,oneof=code 'code id_token'"` + AllowedTokenEndpointAuthMethods []string `json:"allowed_token_endpoint_auth_methods,omitempty" validate:"omitempty,unique,dive,oneof=none client_secret_post client_secret_basic client_secret_jwt private_key_jwt"` + MaxRedirectUris int32 `json:"max_redirect_uris,omitempty" validate:"omitempty,min=1"` +} + diff --git a/idp/internal/controllers/bodies/oauth_dynamic_registration.go b/idp/internal/controllers/bodies/oauth_dynamic_registration.go index 5efbf4e..67995e5 100644 --- a/idp/internal/controllers/bodies/oauth_dynamic_registration.go +++ b/idp/internal/controllers/bodies/oauth_dynamic_registration.go @@ -11,7 +11,7 @@ type OAuthDynamicClientRegistrationBody struct { TokenEndpointAuthMethod string `json:"token_endpoint_auth_method,omitempty" validate:"omitempty,oneof=none client_secret_basic client_secret_post client_secret_jwt private_key_jwt"` ResponseTypes []string `json:"response_types,omitempty" validate:"omitempty,dive,oneof=code 'code id_token'"` GrantTypes []string `json:"grant_types,omitempty" validate:"omitempty,min=1,dive,oneof=authorization_code refresh_token client_credentials urn:ietf:params:oauth:grant-type:jwt-bearer"` - ApplicationType string `json:"application_type" validate:"required,oneof=native service mcp"` + ApplicationType string `json:"application_type" validate:"required,oneof=native service mcp web spa backend device"` ClientName string `json:"client_name" validate:"required,min=1,max=255"` ClientURI string `json:"client_uri" validate:"required,url"` LogoURI string `json:"logo_uri,omitempty" validate:"omitempty,url"` diff --git a/idp/internal/controllers/middleware.go b/idp/internal/controllers/middleware.go index fee05eb..906e942 100644 --- a/idp/internal/controllers/middleware.go +++ b/idp/internal/controllers/middleware.go @@ -178,8 +178,9 @@ func (c *Controllers) AccountCredentialsDRIATMiddleware(ctx *fiber.Ctx) error { domain, accountClaims, serviceErr := c.services.ProcessAccountCredentialsRegistrationIATAuth( ctx.UserContext(), services.ProcessAccountCredentialsRegistrationIATAuthOptions{ - RequestID: requestID, - AuthHeader: authHeader, + RequestID: requestID, + AuthHeader: authHeader, + BackendDomain: c.backendDomain, }, ) if serviceErr != nil { @@ -245,13 +246,14 @@ func processHost(backendDomain string, host string) (string, error) { return username, nil } -func (c *Controllers) AccountHostMiddleware(ctx *fiber.Ctx) error { +func (c *Controllers) HostMiddleware(ctx *fiber.Ctx) error { requestID := getRequestID(ctx) - logger := c.buildLogger(requestID, middlewareLocation, "AccountHostMiddleware") + logger := c.buildLogger(requestID, middlewareLocation, "HostMiddleware") host := ctx.Hostname() - if host == "" { - logger.DebugContext(ctx.UserContext(), "no host found") - return serviceErrorResponse(logger, ctx, exceptions.NewNotFoundError()) + if host == "" || host == c.backendDomain { + logger.InfoContext(ctx.UserContext(), "Base url found") + ctx.Locals("hasAccountHost", false) + return ctx.Next() } username, err := processHost(c.backendDomain, host) @@ -275,11 +277,22 @@ func (c *Controllers) AccountHostMiddleware(ctx *fiber.Ctx) error { return serviceErrorResponse(logger, ctx, serviceErr) } + ctx.Locals("hasAccountHost", true) ctx.Locals("accountUsername", username) ctx.Locals("accountID", accountID) return ctx.Next() } +func (c *Controllers) NoHostMiddleware(ctx *fiber.Ctx) error { + logger := c.buildLogger(getRequestID(ctx), middlewareLocation, "NoHostMiddleware") + host := ctx.Hostname() + if host == "" || host == c.backendDomain { + return ctx.Next() + } + + return serviceErrorResponse(logger, ctx, exceptions.NewNotFoundError()) +} + func getAccountClaims(ctx *fiber.Ctx) (tokens.AccountClaims, *exceptions.ServiceError) { account, ok := ctx.Locals("account").(tokens.AccountClaims) if !ok || account.AccountID == uuid.Nil { diff --git a/idp/internal/controllers/oauth.go b/idp/internal/controllers/oauth.go index 20f4de6..5cd7858 100644 --- a/idp/internal/controllers/oauth.go +++ b/idp/internal/controllers/oauth.go @@ -422,12 +422,12 @@ func (c *Controllers) AccountOAuthToken(ctx *fiber.Ctx) error { } } -func (c *Controllers) AccountOAuthPublicJWKs(ctx *fiber.Ctx) error { +func (c *Controllers) GlobalOAuthPublicJWKs(ctx *fiber.Ctx) error { requestID := getRequestID(ctx) - logger := c.buildLogger(requestID, oauthLocation, "AccountOAuthPublicJWKs") + logger := c.buildLogger(requestID, oauthLocation, "GlobalOAuthPublicJWKs") logRequest(logger, ctx) - etag, jwksDTO, serviceErr := c.services.GetAccountPublicJWKs(ctx.UserContext(), requestID) + etag, jwksDTO, serviceErr := c.services.GetGlobalPublicJWKs(ctx.UserContext(), requestID) if serviceErr != nil { return serviceErrorResponse(logger, ctx, serviceErr) } diff --git a/idp/internal/controllers/oauth_dynamic_registration.go b/idp/internal/controllers/oauth_dynamic_registration.go index 71f75f1..c438a0b 100644 --- a/idp/internal/controllers/oauth_dynamic_registration.go +++ b/idp/internal/controllers/oauth_dynamic_registration.go @@ -114,3 +114,94 @@ func (c *Controllers) OAuthDynamicRegistration(ctx *fiber.Ctx) error { logResponse(logger, ctx, fiber.StatusCreated) return ctx.Status(fiber.StatusCreated).JSON(&accountCredentialsDTO) } + +func (c *Controllers) OAuthAppDynamicRegistration(ctx *fiber.Ctx) error { + requestID := getRequestID(ctx) + logger := c.buildLogger(requestID, oauthDynamicRegistration, "OAuthAppDynamicRegistration") + logRequest(logger, ctx) + + _, accountID, serviceErr := getHostAccount(ctx) + if serviceErr != nil { + return serviceErrorResponse(logger, ctx, serviceErr) + } + + body := new(bodies.OAuthDynamicClientRegistrationBody) + if err := ctx.BodyParser(body); err != nil { + return oauthErrorResponse(logger, ctx, exceptions.OAuthErrorInvalidClientMetadata) + } + if err := c.validate.StructCtx(ctx.UserContext(), body); err != nil { + return oauthErrorResponse(logger, ctx, exceptions.OAuthErrorInvalidClientMetadata) + } + + isAuthenticated, ok := ctx.Locals("isAuthenticated").(bool) + if !ok { + logger.ErrorContext(ctx.UserContext(), "isAuthenticated should be set in context by middleware") + return oauthErrorResponse(logger, ctx, exceptions.OAuthErrorServerError) + } + + domain, ok := ctx.Locals("domain").(string) + if isAuthenticated && !ok { + logger.ErrorContext(ctx.UserContext(), "domain should be set in context by middleware") + return oauthErrorResponse(logger, ctx, exceptions.OAuthErrorServerError) + } + + account, ok := ctx.Locals("account").(tokens.AccountClaims) + if isAuthenticated && !ok { + logger.ErrorContext(ctx.UserContext(), "account should be set in context by middleware") + return oauthErrorResponse(logger, ctx, exceptions.OAuthErrorServerError) + } + + appDTO, serviceErr := c.services.CreateAppCredentialsRegistration( + ctx.UserContext(), + services.CreateAppCredentialsRegistrationOptions{ + RequestID: requestID, + IsAuthenticated: isAuthenticated, + AccountID: accountID, + IATDomain: domain, + AccountVersion: account.AccountVersion, + ApplicationType: body.ApplicationType, + RedirectURIs: body.RedirectURIs, + TokenEndpointAuthMethod: body.TokenEndpointAuthMethod, + GrantTypes: body.GrantTypes, + ResponseTypes: body.ResponseTypes, + ClientName: body.ClientName, + ClientURI: body.ClientURI, + LogoURI: body.LogoURI, + TOSURI: body.TOSURI, + PolicyURI: body.PolicyURI, + Contacts: body.Contacts, + SoftwareID: body.SoftwareID, + SoftwareVersion: body.SoftwareVersion, + SoftwareStatement: body.SoftwareStatement, + JWKsURI: body.JWKsURI, + JWKs: body.JWKs, + FrontendDomain: c.frontendDomain, + BackendDomain: c.backendDomain, + RequireAuthTime: body.RequireAuthTime, + DefaultMaxAge: body.DefaultMaxAge, + SubjectType: body.SubjectType, + IDTokenSignedResponseAlg: body.IDTokenSignedResponseAlg, + IDTokenEncryptedResponseAlg: body.IDTokenEncryptedResponseAlg, + IDTokenEncryptedResponseEnc: body.IDTokenEncryptedResponseEnc, + RequestObjectSigningAlg: body.RequestObjectSigningAlg, + RequestObjectEncryptionAlg: body.RequestObjectEncryptionAlg, + RequestObjectEncryptionEnc: body.RequestObjectEncryptionEnc, + DefaultACRValues: body.DefaultACRValues, + Scope: body.Scope, + SectorIdentifierURI: body.SectorIdentifierURI, + InitiateLoginURI: body.InitiateLoginURI, + RequestURIs: body.RequestURIs, + UserInfoSignedResponseAlg: body.UserInfoSignedResponseAlg, + UserInfoEncryptedResponseAlg: body.UserInfoEncryptedResponseAlg, + UserInfoEncryptedResponseEnc: body.UserInfoEncryptedResponseEnc, + TokenEndpointAuthSigningAlg: body.TokenEndpointAuthSigningAlg, + AccessTokenSigningAlg: body.AccessTokenSigningAlg, + }, + ) + if serviceErr != nil { + return dynamicRegistrationServiceError(logger, ctx, serviceErr) + } + + logResponse(logger, ctx, fiber.StatusCreated) + return ctx.Status(fiber.StatusCreated).JSON(&appDTO) +} diff --git a/idp/internal/controllers/paths/common.go b/idp/internal/controllers/paths/common.go index 4e20871..95754eb 100644 --- a/idp/internal/controllers/paths/common.go +++ b/idp/internal/controllers/paths/common.go @@ -7,10 +7,11 @@ package paths const ( - Base string = "/" - V1 string = "/v1" - Keys string = "/keys" - Confirm string = "/confirm" - Recover string = "/recover" - Config string = "/config" + Base string = "/" + V1 string = "/v1" + AccountsV1 string = "/v1/a" + Keys string = "/keys" + Confirm string = "/confirm" + Recover string = "/recover" + Config string = "/config" ) diff --git a/idp/internal/controllers/users_oauth.go b/idp/internal/controllers/users_oauth.go new file mode 100644 index 0000000..03f0c43 --- /dev/null +++ b/idp/internal/controllers/users_oauth.go @@ -0,0 +1,47 @@ +// Copyright (c) 2025 Afonso Barracha +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +package controllers + +import ( + "github.com/gofiber/fiber/v2" + + "github.com/tugascript/devlogs/idp/internal/services" +) + +const usersOAuthLocation string = "users_oauth" + +func (c *Controllers) AccountDistributedOAuthPublicJWKs(ctx *fiber.Ctx) error { + requestID := getRequestID(ctx) + logger := c.buildLogger(requestID, usersOAuthLocation, "AccountDistributedOAuthPublicJWKs") + logRequest(logger, ctx) + + _, accountID, serviceErr := getHostAccount(ctx) + if serviceErr != nil { + return serviceErrorResponse(logger, ctx, serviceErr) + } + + etag, jwksDTO, serviceErr := c.services.GetAndCacheAccountDistributedJWK( + ctx.UserContext(), + services.GetAndCacheAccountDistributedJWKOptions{ + RequestID: requestID, + AccountID: accountID, + }, + ) + if serviceErr != nil { + return serviceErrorResponse(logger, ctx, serviceErr) + } + + if match := ctx.Get(fiber.HeaderIfNoneMatch); match == etag { + logResponse(logger, ctx, fiber.StatusNotModified) + return ctx.SendStatus(fiber.StatusNotModified) + } + + ctx.Set(fiber.HeaderCacheControl, publicJWKsCacheControl) + ctx.Set(fiber.HeaderETag, etag) + logResponse(logger, ctx, fiber.StatusOK) + return ctx.Status(fiber.StatusOK).JSON(&jwksDTO) +} diff --git a/idp/internal/controllers/well_known.go b/idp/internal/controllers/well_known.go index 21d20d8..bf59510 100644 --- a/idp/internal/controllers/well_known.go +++ b/idp/internal/controllers/well_known.go @@ -14,39 +14,9 @@ import ( const ( wellKnownLocation string = "well_known" - wellKnownJWKsCacheControl string = "public, max-age=300, must-revalidate" wellKnownOIDCCacheControl string = "public, max-age=3600, must-revalidate" ) -func (c *Controllers) WellKnownJWKs(ctx *fiber.Ctx) error { - requestID := getRequestID(ctx) - logger := c.buildLogger(requestID, wellKnownLocation, "WellKnownJWKs") - logRequest(logger, ctx) - - _, accountID, serviceErr := getHostAccount(ctx) - if serviceErr != nil { - return serviceErrorResponse(logger, ctx, serviceErr) - } - - etag, jwksDTO, serviceErr := c.services.WellKnownJWKs(ctx.UserContext(), services.WellKnownJWKsOptions{ - RequestID: requestID, - AccountID: accountID, - }) - if serviceErr != nil { - return serviceErrorResponse(logger, ctx, serviceErr) - } - - if match := ctx.Get(fiber.HeaderIfNoneMatch); match == etag { - logResponse(logger, ctx, fiber.StatusNotModified) - return ctx.SendStatus(fiber.StatusNotModified) - } - - ctx.Set(fiber.HeaderCacheControl, wellKnownJWKsCacheControl) - ctx.Set(fiber.HeaderETag, etag) - logResponse(logger, ctx, fiber.StatusOK) - return ctx.Status(fiber.StatusOK).JSON(&jwksDTO) -} - func (c *Controllers) WellKnownOIDCConfiguration(ctx *fiber.Ctx) error { requestID := getRequestID(ctx) logger := c.buildLogger(requestID, wellKnownLocation, "WellKnownOIDCConfiguration") diff --git a/idp/internal/providers/database/account_token_signing_keys.sql.go b/idp/internal/providers/database/account_token_signing_keys.sql.go index 06545b9..09869f8 100644 --- a/idp/internal/providers/database/account_token_signing_keys.sql.go +++ b/idp/internal/providers/database/account_token_signing_keys.sql.go @@ -39,6 +39,7 @@ const findAccountDistributedTokenSigningKeyPublicKeysByAccountID = `-- name: Fin SELECT "t"."public_key" FROM "token_signing_keys" AS "t" LEFT JOIN "account_token_signing_keys" AS "atsk" ON "t"."id" = "atsk"."token_signing_key_id" WHERE "atsk"."account_id" = $1 AND + "t"."usage" = 'account' AND "t"."is_distributed" = true AND "t"."is_revoked" = false AND "t"."expires_at" > NOW() diff --git a/idp/internal/providers/database/app_dynamic_registration_configs.sql.go b/idp/internal/providers/database/app_dynamic_registration_configs.sql.go new file mode 100644 index 0000000..518225b --- /dev/null +++ b/idp/internal/providers/database/app_dynamic_registration_configs.sql.go @@ -0,0 +1,306 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.29.0 +// source: app_dynamic_registration_configs.sql + +package database + +import ( + "context" + + "github.com/google/uuid" +) + +const createAppDynamicRegistrationConfig = `-- name: CreateAppDynamicRegistrationConfig :one + +INSERT INTO "app_dynamic_registration_configs" ( + "account_id", + "account_public_id", + "allowed_app_types", + "default_allow_user_registration", + "default_auth_providers", + "default_username_column", + "default_allowed_scopes", + "default_scopes", + "require_verified_domains_app_types", + "require_software_statement_app_types", + "software_statement_verification_methods", + "require_initial_access_token_app_types", + "initial_access_token_generation_methods", + "initial_access_token_ttl", + "initial_access_token_max_uses", + "allowed_grant_types", + "allowed_response_types", + "allowed_token_endpoint_auth_methods", + "max_redirect_uris" +) VALUES ( + $1, + $2, + $3, + $4, + $5, + $6, + $7, + $8, + $9, + $10, + $11, + $12, + $13, + $14, + $15, + $16, + $17, + $18, + $19 +) RETURNING id, account_id, account_public_id, allowed_app_types, default_allow_user_registration, default_auth_providers, default_username_column, default_allowed_scopes, default_scopes, require_verified_domains_app_types, require_software_statement_app_types, software_statement_verification_methods, require_initial_access_token_app_types, initial_access_token_generation_methods, initial_access_token_ttl, initial_access_token_max_uses, allowed_grant_types, allowed_response_types, allowed_token_endpoint_auth_methods, max_redirect_uris, created_at, updated_at +` + +type CreateAppDynamicRegistrationConfigParams struct { + AccountID int32 + AccountPublicID uuid.UUID + AllowedAppTypes []AppType + DefaultAllowUserRegistration bool + DefaultAuthProviders []AuthProvider + DefaultUsernameColumn AppUsernameColumn + DefaultAllowedScopes []Scopes + DefaultScopes []Scopes + RequireVerifiedDomainsAppTypes []AppType + RequireSoftwareStatementAppTypes []AppType + SoftwareStatementVerificationMethods []SoftwareStatementVerificationMethod + RequireInitialAccessTokenAppTypes []AppType + InitialAccessTokenGenerationMethods []InitialAccessTokenGenerationMethod + InitialAccessTokenTtl int32 + InitialAccessTokenMaxUses int32 + AllowedGrantTypes []GrantType + AllowedResponseTypes []ResponseType + AllowedTokenEndpointAuthMethods []AuthMethod + MaxRedirectUris int32 +} + +// Copyright (c) 2025 Afonso Barracha +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. +func (q *Queries) CreateAppDynamicRegistrationConfig(ctx context.Context, arg CreateAppDynamicRegistrationConfigParams) (AppDynamicRegistrationConfig, error) { + row := q.db.QueryRow(ctx, createAppDynamicRegistrationConfig, + arg.AccountID, + arg.AccountPublicID, + arg.AllowedAppTypes, + arg.DefaultAllowUserRegistration, + arg.DefaultAuthProviders, + arg.DefaultUsernameColumn, + arg.DefaultAllowedScopes, + arg.DefaultScopes, + arg.RequireVerifiedDomainsAppTypes, + arg.RequireSoftwareStatementAppTypes, + arg.SoftwareStatementVerificationMethods, + arg.RequireInitialAccessTokenAppTypes, + arg.InitialAccessTokenGenerationMethods, + arg.InitialAccessTokenTtl, + arg.InitialAccessTokenMaxUses, + arg.AllowedGrantTypes, + arg.AllowedResponseTypes, + arg.AllowedTokenEndpointAuthMethods, + arg.MaxRedirectUris, + ) + var i AppDynamicRegistrationConfig + err := row.Scan( + &i.ID, + &i.AccountID, + &i.AccountPublicID, + &i.AllowedAppTypes, + &i.DefaultAllowUserRegistration, + &i.DefaultAuthProviders, + &i.DefaultUsernameColumn, + &i.DefaultAllowedScopes, + &i.DefaultScopes, + &i.RequireVerifiedDomainsAppTypes, + &i.RequireSoftwareStatementAppTypes, + &i.SoftwareStatementVerificationMethods, + &i.RequireInitialAccessTokenAppTypes, + &i.InitialAccessTokenGenerationMethods, + &i.InitialAccessTokenTtl, + &i.InitialAccessTokenMaxUses, + &i.AllowedGrantTypes, + &i.AllowedResponseTypes, + &i.AllowedTokenEndpointAuthMethods, + &i.MaxRedirectUris, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + +const deleteAppDynamicRegistrationConfig = `-- name: DeleteAppDynamicRegistrationConfig :exec +DELETE FROM "app_dynamic_registration_configs" WHERE "id" = $1 +` + +func (q *Queries) DeleteAppDynamicRegistrationConfig(ctx context.Context, id int32) error { + _, err := q.db.Exec(ctx, deleteAppDynamicRegistrationConfig, id) + return err +} + +const findAppDynamicRegistrationConfigByAccountID = `-- name: FindAppDynamicRegistrationConfigByAccountID :one +SELECT id, account_id, account_public_id, allowed_app_types, default_allow_user_registration, default_auth_providers, default_username_column, default_allowed_scopes, default_scopes, require_verified_domains_app_types, require_software_statement_app_types, software_statement_verification_methods, require_initial_access_token_app_types, initial_access_token_generation_methods, initial_access_token_ttl, initial_access_token_max_uses, allowed_grant_types, allowed_response_types, allowed_token_endpoint_auth_methods, max_redirect_uris, created_at, updated_at FROM "app_dynamic_registration_configs" +WHERE "account_id" = $1 LIMIT 1 +` + +func (q *Queries) FindAppDynamicRegistrationConfigByAccountID(ctx context.Context, accountID int32) (AppDynamicRegistrationConfig, error) { + row := q.db.QueryRow(ctx, findAppDynamicRegistrationConfigByAccountID, accountID) + var i AppDynamicRegistrationConfig + err := row.Scan( + &i.ID, + &i.AccountID, + &i.AccountPublicID, + &i.AllowedAppTypes, + &i.DefaultAllowUserRegistration, + &i.DefaultAuthProviders, + &i.DefaultUsernameColumn, + &i.DefaultAllowedScopes, + &i.DefaultScopes, + &i.RequireVerifiedDomainsAppTypes, + &i.RequireSoftwareStatementAppTypes, + &i.SoftwareStatementVerificationMethods, + &i.RequireInitialAccessTokenAppTypes, + &i.InitialAccessTokenGenerationMethods, + &i.InitialAccessTokenTtl, + &i.InitialAccessTokenMaxUses, + &i.AllowedGrantTypes, + &i.AllowedResponseTypes, + &i.AllowedTokenEndpointAuthMethods, + &i.MaxRedirectUris, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + +const findAppDynamicRegistrationConfigByAccountPublicID = `-- name: FindAppDynamicRegistrationConfigByAccountPublicID :one +SELECT id, account_id, account_public_id, allowed_app_types, default_allow_user_registration, default_auth_providers, default_username_column, default_allowed_scopes, default_scopes, require_verified_domains_app_types, require_software_statement_app_types, software_statement_verification_methods, require_initial_access_token_app_types, initial_access_token_generation_methods, initial_access_token_ttl, initial_access_token_max_uses, allowed_grant_types, allowed_response_types, allowed_token_endpoint_auth_methods, max_redirect_uris, created_at, updated_at FROM "app_dynamic_registration_configs" +WHERE "account_public_id" = $1 LIMIT 1 +` + +func (q *Queries) FindAppDynamicRegistrationConfigByAccountPublicID(ctx context.Context, accountPublicID uuid.UUID) (AppDynamicRegistrationConfig, error) { + row := q.db.QueryRow(ctx, findAppDynamicRegistrationConfigByAccountPublicID, accountPublicID) + var i AppDynamicRegistrationConfig + err := row.Scan( + &i.ID, + &i.AccountID, + &i.AccountPublicID, + &i.AllowedAppTypes, + &i.DefaultAllowUserRegistration, + &i.DefaultAuthProviders, + &i.DefaultUsernameColumn, + &i.DefaultAllowedScopes, + &i.DefaultScopes, + &i.RequireVerifiedDomainsAppTypes, + &i.RequireSoftwareStatementAppTypes, + &i.SoftwareStatementVerificationMethods, + &i.RequireInitialAccessTokenAppTypes, + &i.InitialAccessTokenGenerationMethods, + &i.InitialAccessTokenTtl, + &i.InitialAccessTokenMaxUses, + &i.AllowedGrantTypes, + &i.AllowedResponseTypes, + &i.AllowedTokenEndpointAuthMethods, + &i.MaxRedirectUris, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + +const updateAppDynamicRegistrationConfig = `-- name: UpdateAppDynamicRegistrationConfig :one +UPDATE "app_dynamic_registration_configs" SET + "allowed_app_types" = $2, + "default_allow_user_registration" = $3, + "default_auth_providers" = $4, + "default_username_column" = $5, + "default_allowed_scopes" = $6, + "default_scopes" = $7, + "require_verified_domains_app_types" = $8, + "require_software_statement_app_types" = $9, + "software_statement_verification_methods" = $10, + "require_initial_access_token_app_types" = $11, + "initial_access_token_generation_methods" = $12, + "initial_access_token_ttl" = $13, + "initial_access_token_max_uses" = $14, + "allowed_grant_types" = $15, + "allowed_response_types" = $16, + "allowed_token_endpoint_auth_methods" = $17, + "max_redirect_uris" = $18 +WHERE "id" = $1 +RETURNING id, account_id, account_public_id, allowed_app_types, default_allow_user_registration, default_auth_providers, default_username_column, default_allowed_scopes, default_scopes, require_verified_domains_app_types, require_software_statement_app_types, software_statement_verification_methods, require_initial_access_token_app_types, initial_access_token_generation_methods, initial_access_token_ttl, initial_access_token_max_uses, allowed_grant_types, allowed_response_types, allowed_token_endpoint_auth_methods, max_redirect_uris, created_at, updated_at +` + +type UpdateAppDynamicRegistrationConfigParams struct { + ID int32 + AllowedAppTypes []AppType + DefaultAllowUserRegistration bool + DefaultAuthProviders []AuthProvider + DefaultUsernameColumn AppUsernameColumn + DefaultAllowedScopes []Scopes + DefaultScopes []Scopes + RequireVerifiedDomainsAppTypes []AppType + RequireSoftwareStatementAppTypes []AppType + SoftwareStatementVerificationMethods []SoftwareStatementVerificationMethod + RequireInitialAccessTokenAppTypes []AppType + InitialAccessTokenGenerationMethods []InitialAccessTokenGenerationMethod + InitialAccessTokenTtl int32 + InitialAccessTokenMaxUses int32 + AllowedGrantTypes []GrantType + AllowedResponseTypes []ResponseType + AllowedTokenEndpointAuthMethods []AuthMethod + MaxRedirectUris int32 +} + +func (q *Queries) UpdateAppDynamicRegistrationConfig(ctx context.Context, arg UpdateAppDynamicRegistrationConfigParams) (AppDynamicRegistrationConfig, error) { + row := q.db.QueryRow(ctx, updateAppDynamicRegistrationConfig, + arg.ID, + arg.AllowedAppTypes, + arg.DefaultAllowUserRegistration, + arg.DefaultAuthProviders, + arg.DefaultUsernameColumn, + arg.DefaultAllowedScopes, + arg.DefaultScopes, + arg.RequireVerifiedDomainsAppTypes, + arg.RequireSoftwareStatementAppTypes, + arg.SoftwareStatementVerificationMethods, + arg.RequireInitialAccessTokenAppTypes, + arg.InitialAccessTokenGenerationMethods, + arg.InitialAccessTokenTtl, + arg.InitialAccessTokenMaxUses, + arg.AllowedGrantTypes, + arg.AllowedResponseTypes, + arg.AllowedTokenEndpointAuthMethods, + arg.MaxRedirectUris, + ) + var i AppDynamicRegistrationConfig + err := row.Scan( + &i.ID, + &i.AccountID, + &i.AccountPublicID, + &i.AllowedAppTypes, + &i.DefaultAllowUserRegistration, + &i.DefaultAuthProviders, + &i.DefaultUsernameColumn, + &i.DefaultAllowedScopes, + &i.DefaultScopes, + &i.RequireVerifiedDomainsAppTypes, + &i.RequireSoftwareStatementAppTypes, + &i.SoftwareStatementVerificationMethods, + &i.RequireInitialAccessTokenAppTypes, + &i.InitialAccessTokenGenerationMethods, + &i.InitialAccessTokenTtl, + &i.InitialAccessTokenMaxUses, + &i.AllowedGrantTypes, + &i.AllowedResponseTypes, + &i.AllowedTokenEndpointAuthMethods, + &i.MaxRedirectUris, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} diff --git a/idp/internal/providers/database/apps.sql.go b/idp/internal/providers/database/apps.sql.go index 8c0ffb6..137f20b 100644 --- a/idp/internal/providers/database/apps.sql.go +++ b/idp/internal/providers/database/apps.sql.go @@ -12,6 +12,25 @@ import ( "github.com/jackc/pgx/v5/pgtype" ) +const countAppsByAccountIDAndCliantNameOrSoftwareID = `-- name: CountAppsByAccountIDAndCliantNameOrSoftwareID :one +SELECT COUNT(*) FROM "apps" +WHERE "account_id" = $1 AND ("client_name" = $2 OR "software_id" = $3) +LIMIT 1 +` + +type CountAppsByAccountIDAndCliantNameOrSoftwareIDParams struct { + AccountID int32 + ClientName string + SoftwareID pgtype.Text +} + +func (q *Queries) CountAppsByAccountIDAndCliantNameOrSoftwareID(ctx context.Context, arg CountAppsByAccountIDAndCliantNameOrSoftwareIDParams) (int64, error) { + row := q.db.QueryRow(ctx, countAppsByAccountIDAndCliantNameOrSoftwareID, arg.AccountID, arg.ClientName, arg.SoftwareID) + var count int64 + err := row.Scan(&count) + return count, err +} + const countAppsByAccountIDAndName = `-- name: CountAppsByAccountIDAndName :one SELECT COUNT(*) FROM "apps" WHERE "account_id" = $1 AND "client_name" = $2 diff --git a/idp/internal/providers/database/dynamic_registration_domains.sql.go b/idp/internal/providers/database/dynamic_registration_domains.sql.go index 1e45f98..cc0d50a 100644 --- a/idp/internal/providers/database/dynamic_registration_domains.sql.go +++ b/idp/internal/providers/database/dynamic_registration_domains.sql.go @@ -36,21 +36,23 @@ func (q *Queries) CountDynamicRegistrationDomainsByDomain(ctx context.Context, d return count, err } -const countDynamicRegistrationDomainsByDomainAndAccountPublicID = `-- name: CountDynamicRegistrationDomainsByDomainAndAccountPublicID :one +const countDynamicRegistrationDomainsByDomainAndAccountPublicIDAndUsages = `-- name: CountDynamicRegistrationDomainsByDomainAndAccountPublicIDAndUsages :one SELECT COUNT(*) FROM "dynamic_registration_domains" WHERE "account_public_id" = $1 AND - "domain" = $2 + "domain" = $2 AND + "usages" @> $3 LIMIT 1 ` -type CountDynamicRegistrationDomainsByDomainAndAccountPublicIDParams struct { +type CountDynamicRegistrationDomainsByDomainAndAccountPublicIDAndUsagesParams struct { AccountPublicID uuid.UUID Domain string + Usages []DynamicRegistrationUsage } -func (q *Queries) CountDynamicRegistrationDomainsByDomainAndAccountPublicID(ctx context.Context, arg CountDynamicRegistrationDomainsByDomainAndAccountPublicIDParams) (int64, error) { - row := q.db.QueryRow(ctx, countDynamicRegistrationDomainsByDomainAndAccountPublicID, arg.AccountPublicID, arg.Domain) +func (q *Queries) CountDynamicRegistrationDomainsByDomainAndAccountPublicIDAndUsages(ctx context.Context, arg CountDynamicRegistrationDomainsByDomainAndAccountPublicIDAndUsagesParams) (int64, error) { + row := q.db.QueryRow(ctx, countDynamicRegistrationDomainsByDomainAndAccountPublicIDAndUsages, arg.AccountPublicID, arg.Domain, arg.Usages) var count int64 err := row.Scan(&count) return count, err @@ -69,21 +71,23 @@ func (q *Queries) CountDynamicRegistrationDomainsByDomains(ctx context.Context, return count, err } -const countDynamicRegistrationDomainsByDomainsAndAccountPublicID = `-- name: CountDynamicRegistrationDomainsByDomainsAndAccountPublicID :one +const countDynamicRegistrationDomainsByDomainsAccountPublicIDAndUsages = `-- name: CountDynamicRegistrationDomainsByDomainsAccountPublicIDAndUsages :one SELECT COUNT(*) FROM "dynamic_registration_domains" WHERE "account_public_id" = $1 AND - "domain" IN ($2) + "usages" @> $2 AND + "domain" IN ($3) LIMIT 1 ` -type CountDynamicRegistrationDomainsByDomainsAndAccountPublicIDParams struct { +type CountDynamicRegistrationDomainsByDomainsAccountPublicIDAndUsagesParams struct { AccountPublicID uuid.UUID + Usages []DynamicRegistrationUsage Domains []string } -func (q *Queries) CountDynamicRegistrationDomainsByDomainsAndAccountPublicID(ctx context.Context, arg CountDynamicRegistrationDomainsByDomainsAndAccountPublicIDParams) (int64, error) { - row := q.db.QueryRow(ctx, countDynamicRegistrationDomainsByDomainsAndAccountPublicID, arg.AccountPublicID, arg.Domains) +func (q *Queries) CountDynamicRegistrationDomainsByDomainsAccountPublicIDAndUsages(ctx context.Context, arg CountDynamicRegistrationDomainsByDomainsAccountPublicIDAndUsagesParams) (int64, error) { + row := q.db.QueryRow(ctx, countDynamicRegistrationDomainsByDomainsAccountPublicIDAndUsages, arg.AccountPublicID, arg.Usages, arg.Domains) var count int64 err := row.Scan(&count) return count, err @@ -122,6 +126,29 @@ func (q *Queries) CountVerifiedDynamicRegistrationDomainsByDomain(ctx context.Co return count, err } +const countVerifiedDynamicRegistrationDomainsByDomainAccountPublicIDAndUsages = `-- name: CountVerifiedDynamicRegistrationDomainsByDomainAccountPublicIDAndUsages :one +SELECT COUNT(*) FROM "dynamic_registration_domains" +WHERE + "account_public_id" = $1 AND + "domain" = $2 AND + "usages" @> $3 AND + "verified_at" IS NOT NULL +LIMIT 1 +` + +type CountVerifiedDynamicRegistrationDomainsByDomainAccountPublicIDAndUsagesParams struct { + AccountPublicID uuid.UUID + Domain string + Usages []DynamicRegistrationUsage +} + +func (q *Queries) CountVerifiedDynamicRegistrationDomainsByDomainAccountPublicIDAndUsages(ctx context.Context, arg CountVerifiedDynamicRegistrationDomainsByDomainAccountPublicIDAndUsagesParams) (int64, error) { + row := q.db.QueryRow(ctx, countVerifiedDynamicRegistrationDomainsByDomainAccountPublicIDAndUsages, arg.AccountPublicID, arg.Domain, arg.Usages) + var count int64 + err := row.Scan(&count) + return count, err +} + const countVerifiedDynamicRegistrationDomainsByDomainAndAccountPublicID = `-- name: CountVerifiedDynamicRegistrationDomainsByDomainAndAccountPublicID :one SELECT COUNT(*) FROM "dynamic_registration_domains" WHERE @@ -156,6 +183,29 @@ func (q *Queries) CountVerifiedDynamicRegistrationDomainsByDomains(ctx context.C return count, err } +const countVerifiedDynamicRegistrationDomainsByDomainsAccountPublicIDAndUsages = `-- name: CountVerifiedDynamicRegistrationDomainsByDomainsAccountPublicIDAndUsages :one +SELECT COUNT(*) FROM "dynamic_registration_domains" +WHERE + "account_public_id" = $1 AND + "usages" @> $2 AND + "domain" IN ($3) AND + "verified_at" IS NOT NULL +LIMIT 1 +` + +type CountVerifiedDynamicRegistrationDomainsByDomainsAccountPublicIDAndUsagesParams struct { + AccountPublicID uuid.UUID + Usages []DynamicRegistrationUsage + Domains []string +} + +func (q *Queries) CountVerifiedDynamicRegistrationDomainsByDomainsAccountPublicIDAndUsages(ctx context.Context, arg CountVerifiedDynamicRegistrationDomainsByDomainsAccountPublicIDAndUsagesParams) (int64, error) { + row := q.db.QueryRow(ctx, countVerifiedDynamicRegistrationDomainsByDomainsAccountPublicIDAndUsages, arg.AccountPublicID, arg.Usages, arg.Domains) + var count int64 + err := row.Scan(&count) + return count, err +} + const countVerifiedDynamicRegistrationDomainsByDomainsAndAccountPublicID = `-- name: CountVerifiedDynamicRegistrationDomainsByDomainsAndAccountPublicID :one SELECT COUNT(*) FROM "dynamic_registration_domains" WHERE diff --git a/idp/internal/providers/database/migrations/20241213231542_create_initial_schema.up.sql b/idp/internal/providers/database/migrations/20241213231542_create_initial_schema.up.sql index 71d24f8..263f2da 100644 --- a/idp/internal/providers/database/migrations/20241213231542_create_initial_schema.up.sql +++ b/idp/internal/providers/database/migrations/20241213231542_create_initial_schema.up.sql @@ -1,6 +1,6 @@ -- SQL dump generated using DBML (dbml.dbdiagram.io) -- Database: PostgreSQL --- Generated at: 2025-11-02T00:22:09.380Z +-- Generated at: 2025-11-04T08:56:30.229Z CREATE TYPE "kek_usage" AS ENUM ( 'global', @@ -257,7 +257,7 @@ CREATE TABLE "token_signing_keys" ( "dek_kid" varchar(22) NOT NULL, "crypto_suite" token_crypto_suite NOT NULL, "expires_at" timestamptz NOT NULL, - "usage" token_key_usage NOT NULL DEFAULT 'account', + "usage" token_key_usage NOT NULL, "is_distributed" boolean NOT NULL DEFAULT false, "is_revoked" boolean NOT NULL DEFAULT false, "created_at" timestamptz NOT NULL DEFAULT (now()), @@ -660,8 +660,8 @@ CREATE TABLE "account_dynamic_registration_configs" ( CREATE TABLE "app_dynamic_registration_configs" ( "id" serial PRIMARY KEY, "account_id" integer NOT NULL, + "account_public_id" uuid NOT NULL, "allowed_app_types" app_type[] NOT NULL, - "whitelisted_domains" varchar(250)[] NOT NULL, "default_allow_user_registration" boolean NOT NULL, "default_auth_providers" auth_provider[] NOT NULL, "default_username_column" app_username_column NOT NULL, @@ -675,7 +675,7 @@ CREATE TABLE "app_dynamic_registration_configs" ( "initial_access_token_ttl" integer NOT NULL DEFAULT 3600, "initial_access_token_max_uses" int NOT NULL DEFAULT 1, "allowed_grant_types" grant_type[] NOT NULL DEFAULT '{ "authorization_code", "refresh_token", "client_credentials", "urn:ietf:params:oauth:grant-type:device_code", "urn:ietf:params:oauth:grant-type:jwt-bearer" }', - "allowed_response_types" response_type[] NOT NULL DEFAULT '{ "code", "id_token", "code id_token" }', + "allowed_response_types" response_type[] NOT NULL DEFAULT '{ "code", "code id_token" }', "allowed_token_endpoint_auth_methods" auth_method[] NOT NULL DEFAULT '{ "none", "client_secret_post", "client_secret_basic", "client_secret_jwt", "private_key_jwt" }', "max_redirect_uris" int NOT NULL DEFAULT 10, "created_at" timestamptz NOT NULL DEFAULT (now()), @@ -757,8 +757,6 @@ CREATE UNIQUE INDEX "token_signing_keys_kid_uidx" ON "token_signing_keys" ("kid" CREATE INDEX "token_signing_keys_expires_at_idx" ON "token_signing_keys" ("expires_at"); -CREATE INDEX "token_signing_keys_is_distributed_is_revoked_expires_at_idx" ON "token_signing_keys" ("is_distributed", "is_revoked", "expires_at"); - CREATE INDEX "token_signing_keys_key_type_usage_is_revoked_expires_at_idx" ON "token_signing_keys" ("key_type", "usage", "is_revoked", "expires_at"); CREATE INDEX "token_signing_keys_usage_is_distributed_is_revoked_expires_at_idx" ON "token_signing_keys" ("usage", "is_distributed", "is_revoked", "expires_at"); @@ -1009,14 +1007,20 @@ CREATE INDEX "account_dynamic_registration_configs_account_public_id_idx" ON "ac CREATE INDEX "app_dynamic_registration_configs_account_id_idx" ON "app_dynamic_registration_configs" ("account_id"); +CREATE INDEX "app_dynamic_registration_configs_account_public_id_idx" ON "app_dynamic_registration_configs" ("account_public_id"); + CREATE INDEX "accounts_totps_account_id_idx" ON "dynamic_registration_domains" ("account_id"); CREATE INDEX "account_dynamic_registration_domains_account_public_id_idx" ON "dynamic_registration_domains" ("account_public_id"); -CREATE INDEX "account_dynamic_registration_domains_domain_idx" ON "dynamic_registration_domains" ("domain"); - CREATE UNIQUE INDEX "account_dynamic_registration_domains_account_public_id_domain_uidx" ON "dynamic_registration_domains" ("account_public_id", "domain"); +CREATE INDEX "account_dynamic_registration_domains_account_public_id_domain_verified_at_idx" ON "dynamic_registration_domains" ("account_public_id", "domain", "verified_at"); + +CREATE INDEX "account_dynamic_registration_domains_account_public_id_domain_usages_idx" ON "dynamic_registration_domains" ("account_public_id", "domain", "usages"); + +CREATE INDEX "account_dynamic_registration_domains_account_public_id_domain_usages_verified_at_idx" ON "dynamic_registration_domains" ("account_public_id", "domain", "usages", "verified_at"); + CREATE INDEX "dynamic_registration_domain_codes_account_id_idx" ON "dynamic_registration_domain_codes" ("account_id"); CREATE INDEX "dynamic_registration_domain_codes_dynamic_registration_domain_id_idx" ON "dynamic_registration_domain_codes" ("dynamic_registration_domain_id"); diff --git a/idp/internal/providers/database/models.go b/idp/internal/providers/database/models.go index c23b977..8196938 100644 --- a/idp/internal/providers/database/models.go +++ b/idp/internal/providers/database/models.go @@ -1604,8 +1604,8 @@ type AppDesign struct { type AppDynamicRegistrationConfig struct { ID int32 AccountID int32 + AccountPublicID uuid.UUID AllowedAppTypes []AppType - WhitelistedDomains []string DefaultAllowUserRegistration bool DefaultAuthProviders []AuthProvider DefaultUsernameColumn AppUsernameColumn diff --git a/idp/internal/providers/database/queries/account_token_signing_keys.sql b/idp/internal/providers/database/queries/account_token_signing_keys.sql index b475999..d107689 100644 --- a/idp/internal/providers/database/queries/account_token_signing_keys.sql +++ b/idp/internal/providers/database/queries/account_token_signing_keys.sql @@ -29,6 +29,7 @@ LIMIT 1; SELECT "t"."public_key" FROM "token_signing_keys" AS "t" LEFT JOIN "account_token_signing_keys" AS "atsk" ON "t"."id" = "atsk"."token_signing_key_id" WHERE "atsk"."account_id" = $1 AND + "t"."usage" = 'account' AND "t"."is_distributed" = true AND "t"."is_revoked" = false AND "t"."expires_at" > NOW() diff --git a/idp/internal/providers/database/queries/app_dynamic_registration_configs.sql b/idp/internal/providers/database/queries/app_dynamic_registration_configs.sql new file mode 100644 index 0000000..3b3bd87 --- /dev/null +++ b/idp/internal/providers/database/queries/app_dynamic_registration_configs.sql @@ -0,0 +1,82 @@ +-- Copyright (c) 2025 Afonso Barracha +-- +-- This Source Code Form is subject to the terms of the Mozilla Public +-- License, v. 2.0. If a copy of the MPL was not distributed with this +-- file, You can obtain one at https://mozilla.org/MPL/2.0/. + +-- name: CreateAppDynamicRegistrationConfig :one +INSERT INTO "app_dynamic_registration_configs" ( + "account_id", + "account_public_id", + "allowed_app_types", + "default_allow_user_registration", + "default_auth_providers", + "default_username_column", + "default_allowed_scopes", + "default_scopes", + "require_verified_domains_app_types", + "require_software_statement_app_types", + "software_statement_verification_methods", + "require_initial_access_token_app_types", + "initial_access_token_generation_methods", + "initial_access_token_ttl", + "initial_access_token_max_uses", + "allowed_grant_types", + "allowed_response_types", + "allowed_token_endpoint_auth_methods", + "max_redirect_uris" +) VALUES ( + $1, + $2, + $3, + $4, + $5, + $6, + $7, + $8, + $9, + $10, + $11, + $12, + $13, + $14, + $15, + $16, + $17, + $18, + $19 +) RETURNING *; + +-- name: UpdateAppDynamicRegistrationConfig :one +UPDATE "app_dynamic_registration_configs" SET + "allowed_app_types" = $2, + "default_allow_user_registration" = $3, + "default_auth_providers" = $4, + "default_username_column" = $5, + "default_allowed_scopes" = $6, + "default_scopes" = $7, + "require_verified_domains_app_types" = $8, + "require_software_statement_app_types" = $9, + "software_statement_verification_methods" = $10, + "require_initial_access_token_app_types" = $11, + "initial_access_token_generation_methods" = $12, + "initial_access_token_ttl" = $13, + "initial_access_token_max_uses" = $14, + "allowed_grant_types" = $15, + "allowed_response_types" = $16, + "allowed_token_endpoint_auth_methods" = $17, + "max_redirect_uris" = $18 +WHERE "id" = $1 +RETURNING *; + +-- name: FindAppDynamicRegistrationConfigByAccountPublicID :one +SELECT * FROM "app_dynamic_registration_configs" +WHERE "account_public_id" = $1 LIMIT 1; + +-- name: FindAppDynamicRegistrationConfigByAccountID :one +SELECT * FROM "app_dynamic_registration_configs" +WHERE "account_id" = $1 LIMIT 1; + +-- name: DeleteAppDynamicRegistrationConfig :exec +DELETE FROM "app_dynamic_registration_configs" WHERE "id" = $1; + diff --git a/idp/internal/providers/database/queries/apps.sql b/idp/internal/providers/database/queries/apps.sql index 4bfb39d..bf69889 100644 --- a/idp/internal/providers/database/queries/apps.sql +++ b/idp/internal/providers/database/queries/apps.sql @@ -67,6 +67,11 @@ SELECT COUNT(*) FROM "apps" WHERE "account_id" = $1 AND "client_name" = $2 LIMIT 1; +-- name: CountAppsByAccountIDAndCliantNameOrSoftwareID :one +SELECT COUNT(*) FROM "apps" +WHERE "account_id" = $1 AND ("client_name" = $2 OR "software_id" = $3) +LIMIT 1; + -- name: FindAppByClientID :one SELECT * FROM "apps" WHERE "client_id" = $1 LIMIT 1; diff --git a/idp/internal/providers/database/queries/dynamic_registration_domains.sql b/idp/internal/providers/database/queries/dynamic_registration_domains.sql index f646140..64a8c51 100644 --- a/idp/internal/providers/database/queries/dynamic_registration_domains.sql +++ b/idp/internal/providers/database/queries/dynamic_registration_domains.sql @@ -96,6 +96,15 @@ WHERE "verified_at" IS NOT NULL LIMIT 1; +-- name: CountVerifiedDynamicRegistrationDomainsByDomainsAccountPublicIDAndUsages :one +SELECT COUNT(*) FROM "dynamic_registration_domains" +WHERE + "account_public_id" = $1 AND + "usages" @> $2 AND + "domain" IN (sqlc.slice('domains')) AND + "verified_at" IS NOT NULL +LIMIT 1; + -- name: CountVerifiedDynamicRegistrationDomainsByDomainAndAccountPublicID :one SELECT COUNT(*) FROM "dynamic_registration_domains" WHERE @@ -104,18 +113,29 @@ WHERE "verified_at" IS NOT NULL LIMIT 1; --- name: CountDynamicRegistrationDomainsByDomainsAndAccountPublicID :one +-- name: CountVerifiedDynamicRegistrationDomainsByDomainAccountPublicIDAndUsages :one SELECT COUNT(*) FROM "dynamic_registration_domains" WHERE "account_public_id" = $1 AND + "domain" = $2 AND + "usages" @> $3 AND + "verified_at" IS NOT NULL +LIMIT 1; + +-- name: CountDynamicRegistrationDomainsByDomainsAccountPublicIDAndUsages :one +SELECT COUNT(*) FROM "dynamic_registration_domains" +WHERE + "account_public_id" = $1 AND + "usages" @> $2 AND "domain" IN (sqlc.slice('domains')) LIMIT 1; --- name: CountDynamicRegistrationDomainsByDomainAndAccountPublicID :one +-- name: CountDynamicRegistrationDomainsByDomainAndAccountPublicIDAndUsages :one SELECT COUNT(*) FROM "dynamic_registration_domains" WHERE "account_public_id" = $1 AND - "domain" = $2 + "domain" = $2 AND + "usages" @> $3 LIMIT 1; -- name: DeleteDynamicRegistrationDomain :exec diff --git a/idp/internal/providers/tokens/dynamic_registration_iat.go b/idp/internal/providers/tokens/dynamic_registration_iat.go index d2b59c3..f97a5fd 100644 --- a/idp/internal/providers/tokens/dynamic_registration_iat.go +++ b/idp/internal/providers/tokens/dynamic_registration_iat.go @@ -8,7 +8,9 @@ package tokens import ( "context" + "errors" "fmt" + "net/url" "time" "github.com/golang-jwt/jwt/v5" @@ -25,20 +27,21 @@ type accountCredentialsDynamicRegistrationClaims struct { jwt.RegisteredClaims } -type AccountCredentialsDynamicRegistrationTokenOptions struct { +type DynamicRegistrationIATOptions struct { AccountPublicID uuid.UUID AccountVersion int32 + IssuerDomain string Domain string ClientID string } -func (t *Tokens) CreateAccountCredentialsDynamicRegistrationToken( - opts AccountCredentialsDynamicRegistrationTokenOptions, +func (t *Tokens) DynamicRegistrationIAT( + opts DynamicRegistrationIATOptions, ) *jwt.Token { now := time.Now() iat := jwt.NewNumericDate(now) exp := jwt.NewNumericDate(now.Add(time.Second * time.Duration(t.dynamicRegistrationTTL))) - iss := fmt.Sprintf("https://%s", t.backendDomain) + iss := fmt.Sprintf("https://%s", opts.IssuerDomain) return jwt.NewWithClaims( jwt.SigningMethodEdDSA, accountCredentialsDynamicRegistrationClaims{ @@ -61,19 +64,20 @@ func (t *Tokens) CreateAccountCredentialsDynamicRegistrationToken( ) } -type VerifyAccountCredentialsDynamicRegistrationTokenOptions struct { +type VerifyDynamicRegistrationIATOptions struct { RequestID string IAT string + IssuerDomain string GetPublicJWK GetPublicJWK } -func (t *Tokens) VerifyAccountCredentialsDynamicRegistrationToken( +func (t *Tokens) VerifyDynamicRegistrationIAT( ctx context.Context, - opts VerifyAccountCredentialsDynamicRegistrationTokenOptions, + opts VerifyDynamicRegistrationIATOptions, ) (string, AccountClaims, error) { logger := utils.BuildLogger(t.logger, utils.LoggerOptions{ Location: dynamicRegistrationIATLocation, - Method: "VerifyAccountCredentialsDynamicRegistrationToken", + Method: "VerifyDynamicRegistrationIAT", RequestID: opts.RequestID, }) logger.DebugContext(ctx, "Verifying account credentials dynamic registration IAT...") @@ -98,6 +102,32 @@ func (t *Tokens) VerifyAccountCredentialsDynamicRegistrationToken( return "", AccountClaims{}, err } + issDomain, err := url.Parse(claims.Issuer) + if err != nil { + logger.WarnContext(ctx, "Failed to parse issuer from account credentials dynamic registration IAT", "error", err, "issuer", claims.Issuer) + return "", AccountClaims{}, err + } + if issDomain.Host != opts.IssuerDomain { + logger.WarnContext(ctx, "Issuer domain mismatch in account credentials dynamic registration IAT", "expected", opts.IssuerDomain, "actual", issDomain.Host) + return "", AccountClaims{}, errors.New("issuer domain mismatch") + } + + if len(claims.Audience) == 0 { + logger.WarnContext(ctx, "Missing audience in account credentials dynamic registration IAT") + return "", AccountClaims{}, errors.New("missing audience") + } + + audDomain, err := url.Parse(claims.Audience[0]) + if err != nil { + logger.WarnContext(ctx, "Failed to parse audience from account credentials dynamic registration IAT", "error", err, "audience", claims.Audience[0]) + return "", AccountClaims{}, err + } + if audDomain.Host != opts.IssuerDomain { + logger.WarnContext(ctx, "Audience domain mismatch in account credentials dynamic registration IAT", "expected", opts.IssuerDomain, "actual", audDomain.Host) + return "", AccountClaims{}, errors.New("audience domain mismatch") + } + + logger.InfoContext(ctx, "Verified account credentials dynamic registration IAT successfully") return claims.Domain, claims.AccountClaims, nil } diff --git a/idp/internal/server/routes.go b/idp/internal/server/routes.go index ddc942d..2e58750 100644 --- a/idp/internal/server/routes.go +++ b/idp/internal/server/routes.go @@ -15,10 +15,10 @@ func (s *FiberServer) RegisterFiberRoutes() { s.routes.AccountCredentialsSecretsRoutes(s.App) s.routes.AccountKeysRoutes(s.App) s.routes.AccountsRoutes(s.App) - s.routes.UsersAuthRoutes(s.App) s.routes.AppsRoutes(s.App) s.routes.AppDesignsRoutes(s.App) s.routes.AppSecretsRoutes(s.App) + s.routes.AppDynamicRegistrationConfigRoutes(s.App) s.routes.OIDCConfigsRoutes(s.App) s.routes.UsersRoutes(s.App) s.routes.WellKnownRoutes(s.App) diff --git a/idp/internal/server/routes/account_credentials.go b/idp/internal/server/routes/account_credentials.go index 46bbd23..9878911 100644 --- a/idp/internal/server/routes/account_credentials.go +++ b/idp/internal/server/routes/account_credentials.go @@ -14,7 +14,7 @@ import ( ) func (r *Routes) AccountCredentialsRoutes(app *fiber.App) { - router := v1PathRouter(app).Group(paths.AccountsBase + paths.CredentialsBase) + router := V1PathRouter(app).Group(paths.AccountsBase+paths.CredentialsBase, r.controllers.NoHostMiddleware) credentialsWriteScopeMiddleware := r.controllers.ScopeMiddleware(tokens.AccountScopeCredentialsWrite) credentialsReadScopeMiddleware := r.controllers.ScopeMiddleware(tokens.AccountScopeCredentialsRead) @@ -52,7 +52,7 @@ func (r *Routes) AccountCredentialsRoutes(app *fiber.App) { } func (r *Routes) AccountCredentialsSecretsRoutes(app *fiber.App) { - router := v1PathRouter(app).Group(paths.AccountsBase + paths.CredentialsBase) + router := V1PathRouter(app).Group(paths.AccountsBase + paths.CredentialsBase) credentialsWriteScopeMiddleware := r.controllers.ScopeMiddleware(tokens.AccountScopeCredentialsWrite) credentialsReadScopeMiddleware := r.controllers.ScopeMiddleware(tokens.AccountScopeCredentialsRead) @@ -84,7 +84,7 @@ func (r *Routes) AccountCredentialsSecretsRoutes(app *fiber.App) { } func (r *Routes) AccountKeysRoutes(app *fiber.App) { - router := v1PathRouter(app).Group(paths.AccountsBase) + router := V1PathRouter(app).Group(paths.AccountsBase) router.Get(paths.AccountsSingle+paths.Keys, r.controllers.ListAccountCredentialsKeys) } diff --git a/idp/internal/server/routes/account_dynamic_registration.go b/idp/internal/server/routes/account_dynamic_registration.go index ca95cd2..1e42e4e 100644 --- a/idp/internal/server/routes/account_dynamic_registration.go +++ b/idp/internal/server/routes/account_dynamic_registration.go @@ -14,7 +14,7 @@ import ( ) func (r *Routes) AccountDynamicRegistrationConfigurationRoutes(app *fiber.App) { - router := v1PathRouter(app).Group(paths.AccountsBase + paths.CredentialsBase + paths.DynamicRegistrationBase) + router := V1PathRouter(app).Group(paths.AccountsBase+paths.CredentialsBase+paths.DynamicRegistrationBase, r.controllers.NoHostMiddleware) credentialsConfigsWriteScopeMiddleware := r.controllers.ScopeMiddleware(tokens.AccountScopeCredentialsConfigsWrite) credentialsConfigsReadScopeMiddleware := r.controllers.ScopeMiddleware(tokens.AccountScopeCredentialsConfigsRead) @@ -80,27 +80,4 @@ func (r *Routes) AccountDynamicRegistrationConfigurationRoutes(app *fiber.App) { credentialsConfigsWriteScopeMiddleware, r.controllers.DeleteAccountCredentialsRegistrationDomainCode, ) - - // Initial Access Token (IAT) routes - iatRouter := router.Group(paths.InitialAccessToken) - - // Dynamic Registration IAT Code Exchange flow - iatRouter.Get(paths.OAuthAuth, r.controllers.OAuthDynamicRegistrationIATAuth) - iatRouter.Post(paths.OAuthToken, r.controllers.OAuthDynamicRegistrationIATToken) - - // Dynamic Registration IAT Login flow - const loginRoute = paths.InitialAccessTokenSingle + paths.AuthLogin - iatRouter.Get(loginRoute, r.controllers.OAuthDynamicRegistrationIATLoginGet) - iatRouter.Post(loginRoute, r.controllers.OAuthDynamicRegistrationIATLoginPost) - - // Dynamic Registration IAT 2FA flow - const twoFAAuthRoute = loginRoute + paths.Auth2FA - iatRouter.Get(twoFAAuthRoute, r.controllers.OAuthDynamicRegistrationIAT2FAGet) - iatRouter.Post(twoFAAuthRoute, r.controllers.OAuthDynamicRegistrationIAT2FAPost) - - // Dynamic Registration IAT External Auth flow - const extAuthRoute = paths.InitialAccessTokenSingle + paths.OAuthAuth + paths.InitialAccessTokenAuthEXT - iatRouter.Get(extAuthRoute+paths.InitialAccessTokenProvider, r.controllers.OAuthDynamicRegistrationIATExtAuthGet) - iatRouter.Post(extAuthRoute+paths.OAuthAppleCallback, r.controllers.OAuthDynamicRegistrationIATExtAppleCB) - iatRouter.Get(extAuthRoute+paths.OAuthCallback, r.controllers.OAuthDynamicRegistrationIATExtCB) } diff --git a/idp/internal/server/routes/accounts.go b/idp/internal/server/routes/accounts.go index fa9b01f..ae2e297 100644 --- a/idp/internal/server/routes/accounts.go +++ b/idp/internal/server/routes/accounts.go @@ -13,7 +13,7 @@ import ( ) func (r *Routes) AccountsRoutes(app *fiber.App) { - router := v1PathRouter(app).Group(paths.AccountsBase) + router := V1PathRouter(app).Group(paths.AccountsBase) router.Get(paths.AccountUserInfo, r.controllers.AccountAccessClaimsMiddleware, r.controllers.GetCurrentAccount) router.Put( diff --git a/idp/internal/server/routes/app_designs.go b/idp/internal/server/routes/app_designs.go index d8d36c4..f13aea2 100644 --- a/idp/internal/server/routes/app_designs.go +++ b/idp/internal/server/routes/app_designs.go @@ -14,7 +14,7 @@ import ( ) func (r *Routes) AppDesignsRoutes(app *fiber.App) { - appDesigns := v1PathRouter(app).Group(paths.AppsBase + paths.AppsSingle) + appDesigns := V1PathRouter(app).Group(paths.AppsBase + paths.AppsSingle) appsWriteScope := r.controllers.ScopeMiddleware(tokens.AccountScopeAppsWrite) diff --git a/idp/internal/server/routes/apps.go b/idp/internal/server/routes/apps.go index 49be372..d898d7c 100644 --- a/idp/internal/server/routes/apps.go +++ b/idp/internal/server/routes/apps.go @@ -14,7 +14,7 @@ import ( ) func (r *Routes) AppsRoutes(app *fiber.App) { - router := v1PathRouter(app).Group(paths.AppsBase) + router := V1PathRouter(app).Group(paths.AppsBase) appsWriteScope := r.controllers.ScopeMiddleware(tokens.AccountScopeAppsWrite) appsReadScope := r.controllers.ScopeMiddleware(tokens.AccountScopeAppsRead) @@ -52,7 +52,7 @@ func (r *Routes) AppsRoutes(app *fiber.App) { } func (r *Routes) AppSecretsRoutes(app *fiber.App) { - router := v1PathRouter(app).Group(paths.AppsBase) + router := V1PathRouter(app).Group(paths.AppsBase) appsWriteScope := r.controllers.ScopeMiddleware(tokens.AccountScopeAppsWrite) appsReadScope := r.controllers.ScopeMiddleware(tokens.AccountScopeAppsRead) @@ -82,3 +82,29 @@ func (r *Routes) AppSecretsRoutes(app *fiber.App) { r.controllers.RevokeAppSecret, ) } + +func (r *Routes) AppDynamicRegistrationConfigRoutes(app *fiber.App) { + router := V1PathRouter(app).Group(paths.AppsBase + paths.DynamicRegistrationBase + paths.Config) + + appsConfigsWriteScope := r.controllers.ScopeMiddleware(tokens.AccountScopeAppsConfigsWrite) + appsConfigsReadScope := r.controllers.ScopeMiddleware(tokens.AccountScopeAppsConfigsRead) + + router.Get( + paths.Base, + r.controllers.AccountAccessClaimsMiddleware, + appsConfigsReadScope, + r.controllers.GetAppDynamicRegistrationConfig, + ) + router.Put( + paths.Base, + r.controllers.AccountAccessClaimsMiddleware, + appsConfigsWriteScope, + r.controllers.UpsertAppDynamicRegistrationConfig, + ) + router.Delete( + paths.Base, + r.controllers.AccountAccessClaimsMiddleware, + appsConfigsWriteScope, + r.controllers.DeleteAppDynamicRegistrationConfig, + ) +} diff --git a/idp/internal/server/routes/auth.go b/idp/internal/server/routes/auth.go index a48d64c..d01b63b 100644 --- a/idp/internal/server/routes/auth.go +++ b/idp/internal/server/routes/auth.go @@ -14,14 +14,22 @@ import ( ) func (r *Routes) AuthRoutes(app *fiber.App) { - router := v1PathRouter(app).Group(paths.AuthBase) - + router := V1PathRouter(app).Group(paths.AuthBase) authProvsReaderMW := r.controllers.ScopeMiddleware(tokens.AccountScopeAuthProvidersRead) // Custom auth paths - router.Post(paths.AuthRegister, r.controllers.RegisterAccount) - router.Post(paths.AuthConfirmEmail, r.controllers.ConfirmAccount) - router.Post(paths.AuthLogin, r.controllers.LoginAccount) + router.Post(paths.AuthRegister, r.controllers.HostMiddleware, HostAwareRoute( + []fiber.Handler{r.controllers.RegisterAccount}, + []fiber.Handler{r.controllers.AppAccessClaimsMiddleware, r.controllers.RegisterUser}, + )) + router.Post(paths.AuthConfirmEmail, r.controllers.HostMiddleware, HostAwareRoute( + []fiber.Handler{r.controllers.ConfirmAccount}, + []fiber.Handler{r.controllers.AppAccessClaimsMiddleware, r.controllers.ConfirmUser}, + )) + router.Post(paths.AuthLogin, r.controllers.HostMiddleware, HostAwareRoute( + []fiber.Handler{r.controllers.LoginAccount}, + []fiber.Handler{r.controllers.AppAccessClaimsMiddleware, r.controllers.LoginUser}, + )) router.Post( paths.AuthLogin+paths.Auth2FA, r.controllers.TwoFAAccessClaimsMiddleware, @@ -32,10 +40,22 @@ func (r *Routes) AuthRoutes(app *fiber.App) { r.controllers.TwoFAAccessClaimsMiddleware, r.controllers.RecoverAccount, ) - router.Post(paths.AuthRefresh, r.controllers.RefreshAccount) - router.Post(paths.AuthLogout, r.controllers.AccountAccessClaimsMiddleware, r.controllers.LogoutAccount) - router.Post(paths.AuthForgotPassword, r.controllers.ForgotAccountPassword) - router.Post(paths.AuthResetPassword, r.controllers.ResetAccountPassword) + router.Post(paths.AuthRefresh, r.controllers.HostMiddleware, HostAwareRoute( + []fiber.Handler{r.controllers.RefreshAccount}, + []fiber.Handler{r.controllers.AppAccessClaimsMiddleware, r.controllers.RefreshUser}, + )) + router.Post(paths.AuthLogout, r.controllers.HostMiddleware, HostAwareRoute( + []fiber.Handler{r.controllers.AccountAccessClaimsMiddleware, r.controllers.LogoutAccount}, + []fiber.Handler{r.controllers.UserAccessClaimsMiddleware, r.controllers.LogoutUser}, + )) + router.Post(paths.AuthForgotPassword, r.controllers.HostMiddleware, HostAwareRoute( + []fiber.Handler{r.controllers.ForgotAccountPassword}, + []fiber.Handler{r.controllers.AppAccessClaimsMiddleware, r.controllers.ForgotUserPassword}, + )) + router.Post(paths.AuthResetPassword, r.controllers.HostMiddleware, HostAwareRoute( + []fiber.Handler{r.controllers.ResetAccountPassword}, + []fiber.Handler{r.controllers.AppAccessClaimsMiddleware, r.controllers.ResetUserPassword}, + )) router.Get( paths.AuthProviders, r.controllers.AccountAccessClaimsMiddleware, diff --git a/idp/internal/server/routes/common.go b/idp/internal/server/routes/common.go index 1b97e77..c36a266 100644 --- a/idp/internal/server/routes/common.go +++ b/idp/internal/server/routes/common.go @@ -10,8 +10,41 @@ import ( "github.com/gofiber/fiber/v2" "github.com/tugascript/devlogs/idp/internal/controllers/paths" + "github.com/tugascript/devlogs/idp/internal/exceptions" ) -func v1PathRouter(app *fiber.App) fiber.Router { +var errorResponseNotFound = exceptions.ErrorResponse{ + Code: exceptions.StatusNotFound, + Message: exceptions.MessageNotFound, +} + +func V1PathRouter(app *fiber.App) fiber.Router { return app.Group(paths.V1) } + +func HostAwareRoute( + normalHandlers []fiber.Handler, + hostHandlers []fiber.Handler, +) fiber.Handler { + return func(ctx *fiber.Ctx) error { + hasAccountHost, ok := ctx.Locals("hasAccountHost").(bool) + + if !ok || !hasAccountHost { + for _, handler := range normalHandlers { + if err := handler(ctx); err != nil { + return err + } + } + + return ctx.Status(fiber.StatusNotFound).JSON(errorResponseNotFound) + } + + for _, handler := range hostHandlers { + if err := handler(ctx); err != nil { + return err + } + } + + return ctx.Status(fiber.StatusNotFound).JSON(errorResponseNotFound) + } +} diff --git a/idp/internal/server/routes/oauth.go b/idp/internal/server/routes/oauth.go index 6255059..403388f 100644 --- a/idp/internal/server/routes/oauth.go +++ b/idp/internal/server/routes/oauth.go @@ -13,10 +13,13 @@ import ( ) func (r *Routes) OAuthRoutes(app *fiber.App) { - router := v1PathRouter(app).Group(paths.AuthBase + paths.OAuthBase) + router := V1PathRouter(app).Group(paths.AuthBase + paths.OAuthBase) // Known auth paths (oauth2) - router.Post(paths.OAuthKeys, r.controllers.AccountOAuthPublicJWKs) + router.Post(paths.OAuthKeys, r.controllers.HostMiddleware, HostAwareRoute( + []fiber.Handler{r.controllers.GlobalOAuthPublicJWKs}, + []fiber.Handler{r.controllers.AccountDistributedOAuthPublicJWKs}, + )) router.Post(paths.OAuthToken, r.controllers.AccountOAuthToken) router.Get(paths.OAuthAuth, r.controllers.AccountOAuthURL) @@ -25,9 +28,37 @@ func (r *Routes) OAuthRoutes(app *fiber.App) { router.Get(paths.OAuthCallback, r.controllers.AccountOAuthCallback) // Register - router.Post( - paths.AccountsBase+paths.AccountsSingle+paths.OAuthRegister, - r.controllers.AccountCredentialsDRIATMiddleware, - r.controllers.OAuthDynamicRegistration, - ) + router.Post(paths.OAuthRegister, r.controllers.HostMiddleware, HostAwareRoute( + []fiber.Handler{ + r.controllers.AccountCredentialsDRIATMiddleware, + r.controllers.OAuthDynamicRegistration, + }, + []fiber.Handler{ + // TODO: add app claims for DR + r.controllers.OAuthDynamicRegistration, + }, + )) + + // Initial Access Token (IAT) routes + iatRouter := router.Group(paths.InitialAccessToken) + + // Dynamic Registration IAT Code Exchange flow + iatRouter.Get(paths.OAuthAuth, r.controllers.OAuthDynamicRegistrationIATAuth) + iatRouter.Post(paths.OAuthToken, r.controllers.OAuthDynamicRegistrationIATToken) + + // Dynamic Registration IAT Login flow + const loginRoute = paths.InitialAccessTokenSingle + paths.AuthLogin + iatRouter.Get(loginRoute, r.controllers.OAuthDynamicRegistrationIATLoginGet) + iatRouter.Post(loginRoute, r.controllers.OAuthDynamicRegistrationIATLoginPost) + + // Dynamic Registration IAT 2FA flow + const twoFAAuthRoute = loginRoute + paths.Auth2FA + iatRouter.Get(twoFAAuthRoute, r.controllers.OAuthDynamicRegistrationIAT2FAGet) + iatRouter.Post(twoFAAuthRoute, r.controllers.OAuthDynamicRegistrationIAT2FAPost) + + // Dynamic Registration IAT External Auth flow + const extAuthRoute = paths.InitialAccessTokenSingle + paths.OAuthAuth + paths.InitialAccessTokenAuthEXT + iatRouter.Get(extAuthRoute+paths.InitialAccessTokenProvider, r.controllers.OAuthDynamicRegistrationIATExtAuthGet) + iatRouter.Post(extAuthRoute+paths.OAuthAppleCallback, r.controllers.OAuthDynamicRegistrationIATExtAppleCB) + iatRouter.Get(extAuthRoute+paths.OAuthCallback, r.controllers.OAuthDynamicRegistrationIATExtCB) } diff --git a/idp/internal/server/routes/oidc_configs.go b/idp/internal/server/routes/oidc_configs.go index ff679bc..78a8f67 100644 --- a/idp/internal/server/routes/oidc_configs.go +++ b/idp/internal/server/routes/oidc_configs.go @@ -13,7 +13,7 @@ import ( ) func (r *Routes) OIDCConfigsRoutes(app *fiber.App) { - router := v1PathRouter(app).Group(paths.OIDCConfigBase, r.controllers.AccountAccessClaimsMiddleware) + router := V1PathRouter(app).Group(paths.OIDCConfigBase, r.controllers.AccountAccessClaimsMiddleware) router.Get(paths.Base, r.controllers.GetOIDCConfig) router.Post(paths.Base, r.controllers.AdminScopeMiddleware, r.controllers.CreateOIDCConfig) diff --git a/idp/internal/server/routes/users.go b/idp/internal/server/routes/users.go index 9231522..c9dd2b3 100644 --- a/idp/internal/server/routes/users.go +++ b/idp/internal/server/routes/users.go @@ -24,7 +24,7 @@ import ( ) func (r *Routes) UsersRoutes(app *fiber.App) { - router := v1PathRouter(app).Group(paths.UsersBase, r.controllers.AccountAccessClaimsMiddleware) + router := V1PathRouter(app).Group(paths.UsersBase, r.controllers.AccountAccessClaimsMiddleware) usersReadScope := r.controllers.ScopeMiddleware(tokens.AccountScopeUsersRead) usersWriteScope := r.controllers.ScopeMiddleware(tokens.AccountScopeUsersWrite) diff --git a/idp/internal/server/routes/users_auth.go b/idp/internal/server/routes/users_auth.go deleted file mode 100644 index ba425d0..0000000 --- a/idp/internal/server/routes/users_auth.go +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright (c) 2025 Afonso Barracha -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at https://mozilla.org/MPL/2.0/. - -package routes - -import ( - "github.com/gofiber/fiber/v2" - - "github.com/tugascript/devlogs/idp/internal/controllers/paths" -) - -func (r *Routes) UsersAuthRoutes(app *fiber.App) { - router := v1PathRouter(app).Group(paths.AppsBase+paths.UsersBase, r.controllers.AccountHostMiddleware) - - router.Post(paths.AuthRegister, r.controllers.AppAccessClaimsMiddleware, r.controllers.RegisterUser) - router.Post(paths.AuthConfirmEmail, r.controllers.AppAccessClaimsMiddleware, r.controllers.ConfirmUser) - router.Post(paths.AuthLogin, r.controllers.AppAccessClaimsMiddleware, r.controllers.LoginUser) - - // TODO: Add 2FA Login - - router.Post(paths.AuthRefresh, r.controllers.AppAccessClaimsMiddleware, r.controllers.RefreshUser) - router.Post( - paths.AuthLogout, - r.controllers.UserAccessClaimsMiddleware, - r.controllers.LogoutUser, - ) - router.Post(paths.AuthForgotPassword, r.controllers.AppAccessClaimsMiddleware, r.controllers.ForgotUserPassword) - router.Post(paths.AuthResetPassword, r.controllers.AppAccessClaimsMiddleware, r.controllers.ResetUserPassword) -} diff --git a/idp/internal/server/routes/well_known.go b/idp/internal/server/routes/well_known.go index cecffed..d92b753 100644 --- a/idp/internal/server/routes/well_known.go +++ b/idp/internal/server/routes/well_known.go @@ -13,8 +13,11 @@ import ( ) func (r *Routes) WellKnownRoutes(app *fiber.App) { - router := v1PathRouter(app).Group(paths.WellKnownBase, r.controllers.AccountHostMiddleware) + router := V1PathRouter(app).Group(paths.WellKnownBase, r.controllers.HostMiddleware) - router.Get(paths.WellKnownJWKs, r.controllers.WellKnownJWKs) + router.Get(paths.WellKnownJWKs, HostAwareRoute( + []fiber.Handler{r.controllers.GlobalOAuthPublicJWKs}, + []fiber.Handler{r.controllers.AccountDistributedOAuthPublicJWKs}, + )) router.Get(paths.WellKnownOIDC, r.controllers.WellKnownOIDCConfiguration) } diff --git a/idp/internal/services/account_credentials_registration.go b/idp/internal/services/account_credentials_registration.go index bcbdc09..d205124 100644 --- a/idp/internal/services/account_credentials_registration.go +++ b/idp/internal/services/account_credentials_registration.go @@ -8,14 +8,12 @@ package services import ( "context" - "errors" "net/url" "slices" "strings" "time" "github.com/google/uuid" - "golang.org/x/net/publicsuffix" "github.com/tugascript/devlogs/idp/internal/exceptions" "github.com/tugascript/devlogs/idp/internal/providers/database" @@ -45,249 +43,8 @@ var allowedAccountCredentialsScopes []string = []string{ string(database.AccountCredentialsScopeAccountAuthProvidersRead), } -type checkAccountCRDomainOptions struct { - requestID string - accountPublicID uuid.UUID - domain string - iatDomain string - requireVerifiedDomains bool -} - -func (s *Services) checkAccountCRDomain( - ctx context.Context, - opts checkAccountCRDomainOptions, -) (string, *exceptions.ServiceError) { - logger := s.buildLogger(opts.requestID, dynamicRegistrationDomainsLocation, "checkAccountCRDomain").With( - "domain", opts.domain, - ) - logger.InfoContext(ctx, "Checking account credential domain validity") - - baseDomain, err := publicsuffix.EffectiveTLDPlusOne(opts.domain) - if err != nil { - logger.WarnContext(ctx, "Failed to parse base domain", "error", err) - return "", exceptions.NewValidationError("invalid client URI") - } - if opts.iatDomain != "" && opts.domain != opts.iatDomain && baseDomain != opts.iatDomain { - logger.WarnContext(ctx, "Client URI base domain does not match IAT domain", - "baseDomain", baseDomain, - "iatDomain", opts.iatDomain, - ) - return "", exceptions.NewUnauthorizedError() - } - - var count int64 - if baseDomain != opts.domain { - if opts.requireVerifiedDomains { - count, err = s.database.CountVerifiedDynamicRegistrationDomainsByDomainsAndAccountPublicID( - ctx, - database.CountVerifiedDynamicRegistrationDomainsByDomainsAndAccountPublicIDParams{ - AccountPublicID: opts.accountPublicID, - Domains: []string{opts.domain, baseDomain}, - }, - ) - } else { - count, err = s.database.CountDynamicRegistrationDomainsByDomainsAndAccountPublicID( - ctx, - database.CountDynamicRegistrationDomainsByDomainsAndAccountPublicIDParams{ - AccountPublicID: opts.accountPublicID, - Domains: []string{opts.domain, baseDomain}, - }, - ) - } - } else { - if opts.requireVerifiedDomains { - count, err = s.database.CountVerifiedDynamicRegistrationDomainsByDomainAndAccountPublicID( - ctx, - database.CountVerifiedDynamicRegistrationDomainsByDomainAndAccountPublicIDParams{ - AccountPublicID: opts.accountPublicID, - Domain: opts.domain, - }, - ) - } else { - count, err = s.database.CountDynamicRegistrationDomainsByDomainAndAccountPublicID( - ctx, - database.CountDynamicRegistrationDomainsByDomainAndAccountPublicIDParams{ - AccountPublicID: opts.accountPublicID, - Domain: opts.domain, - }, - ) - } - } - - if err != nil { - logger.ErrorContext(ctx, "Failed to count verified account dynamic registration domains", "error", err) - return "", exceptions.FromDBError(err) - } - if count > 0 { - logger.InfoContext(ctx, "Account credential domain is whitelisted") - return baseDomain, nil - } - - logger.InfoContext(ctx, "Account credential domain is not whitelisted or verified") - return "", exceptions.NewUnauthorizedError() -} - -type buildAccountCRSoftwareStatementFuncOptions struct { - requestID string - accountPublicID uuid.UUID - verificationMethods []database.SoftwareStatementVerificationMethod - jwksURI string - jwks []string - domain string - baseDomain string -} - -func (s *Services) buildAccountCRSoftwareStatementFunc( - ctx context.Context, - opts buildAccountCRSoftwareStatementFuncOptions, -) tokens.GetUnknownPublicJWK { - logger := s.buildLogger(opts.requestID, accountCredentialsRegistrationLocation, "buildAccountCRSoftwareStatementFunc").With( - "accountPublicID", opts.accountPublicID, - ) - logger.InfoContext(ctx, "Checking account credential software statement validity") - - if slices.Contains(opts.verificationMethods, database.SoftwareStatementVerificationMethodJwksUri) && opts.jwksURI != "" { - return func(kid string) (utils.JWK, error) { - parsedURI, err := url.Parse(opts.jwksURI) - if err != nil { - logger.ErrorContext(ctx, "Failed to parse JWKs URI", "error", err) - return nil, errors.New("invalid JWKs URI") - } - if parsedURI.Host != opts.baseDomain || !strings.Contains(parsedURI.Host, "."+opts.baseDomain) { - logger.WarnContext(ctx, "JWKs URI parsedURI does not match client URI parsedURI") - return nil, errors.New("JWKs URI parsedURI does not match client URI parsedURI") - } - - jwks, err := s.jwt.GetPublicJWKs(ctx, tokens.GetPublicJWKsOptions{ - RequestID: opts.requestID, - URL: opts.jwksURI, - }) - if err != nil { - logger.WarnContext(ctx, "Failed to get public JWKs from JWKs URI", "error", err) - return nil, errors.New("failed to get public JWKs from JWKs URI") - } - - jwkIdx := slices.IndexFunc(jwks.Keys, func(jwk utils.JWK) bool { - return jwk.GetKeyID() == kid - }) - if jwkIdx == -1 { - logger.WarnContext(ctx, "No matching JWK found for KID in JWKs URI", "kid", kid) - return nil, errors.New("no matching JWK found for KID in JWKs URI") - } - - return jwks.Keys[jwkIdx], nil - } - } - if slices.Contains(opts.verificationMethods, database.SoftwareStatementVerificationMethodManual) { - if len(opts.jwks) > 0 { - return func(kid string) (utils.JWK, error) { - jwks := make([]utils.JWK, 0, len(opts.jwks)) - for _, rawJWK := range opts.jwks { - jwk, err := utils.JsonToJWK([]byte(rawJWK)) - if err != nil { - logger.ErrorContext(ctx, "Failed to parse manual JWK", "error", err) - return nil, errors.New("failed to parse manual JWK") - } - jwks = append(jwks, jwk) - } - - jwkIdx := slices.IndexFunc(jwks, func(jwk utils.JWK) bool { - return jwk.GetKeyID() == kid - }) - if jwkIdx == -1 { - logger.WarnContext(ctx, "No matching manual JWK found for KID", "kid", kid) - return nil, errors.New("no matching manual JWK found for KID") - } - - sliceJWK := jwks[jwkIdx] - jwkRefEnt, err := s.database.FindDynamicRegistrationSoftwareStatementKeysByCredentialsKeyKIDAndAccountPublicID( - ctx, - database.FindDynamicRegistrationSoftwareStatementKeysByCredentialsKeyKIDAndAccountPublicIDParams{ - CredentialsKeyKid: kid, - AccountPublicID: opts.accountPublicID, - }, - ) - if err != nil { - serviceErr := exceptions.FromDBError(err) - if serviceErr.Code == exceptions.CodeNotFound { - logger.WarnContext(ctx, "No database entry found for manual JWK", "kid", kid, "error", err) - return nil, errors.New("no database entry found for manual JWK") - } - - logger.ErrorContext(ctx, "Failed to find database entry for manual JWK", "kid", kid, "error", err) - return nil, errors.New("failed to find database entry for manual JWK") - } - if jwkRefEnt.RootDomain != opts.baseDomain { - logger.WarnContext(ctx, "Manual JWK root domain does not match client URI base domain", - "kid", kid, "jwkRootDomain", jwkRefEnt.RootDomain, "baseDomain", opts.baseDomain, - ) - return nil, errors.New("manual JWK root domain does not match client URI base domain") - } - - jwkEnt, err := s.database.FindCredentialsKeyByID(ctx, jwkRefEnt.CredentialsKeyID) - if err != nil { - serviceErr := exceptions.FromDBError(err) - if serviceErr.Code == exceptions.CodeNotFound { - logger.WarnContext(ctx, "No credentials key found for manual JWK", "kid", kid, "error", err) - return nil, errors.New("no credentials key found for manual JWK") - } - - logger.ErrorContext(ctx, "Failed to find credentials key for manual JWK", "kid", kid, "error", err) - return nil, errors.New("failed to find credentials key for manual JWK") - } - - entJWK, err := utils.JsonToJWK(jwkEnt.PublicKey) - if err != nil { - logger.ErrorContext(ctx, "Failed to parse manual JWK", "error", err) - return nil, errors.New("failed to parse manual JWK") - } - if !entJWK.ComparePublicKey(sliceJWK) { - logger.WarnContext(ctx, "Manual JWK does not match database credentials key", "kid", kid) - return nil, errors.New("manual JWK does not match database credentials key") - } - - return sliceJWK, nil - } - } - - return func(kid string) (utils.JWK, error) { - jwkEntity, err := s.database.FindDynamicRegistrationSoftwareStatementKeysByRootDomainAndAccountPublicID( - ctx, - database.FindDynamicRegistrationSoftwareStatementKeysByRootDomainAndAccountPublicIDParams{ - RootDomain: opts.baseDomain, - AccountPublicID: opts.accountPublicID, - }, - ) - if err != nil { - if exceptions.FromDBError(err).Code == exceptions.CodeNotFound { - logger.WarnContext(ctx, "No manual JWKs found for software statement", "error", err) - return nil, errors.New("no manual JWKs found for software statement") - } - - logger.ErrorContext(ctx, "Failed to find manual JWKs for software statement", "error", err) - return nil, errors.New("failed to find manual JWKs for software statement") - } - if jwkEntity.PublicKid != kid { - logger.WarnContext(ctx, "No matching manual JWK found for KID", - "kid", kid, "publicKID", jwkEntity.PublicKid, - ) - return nil, errors.New("no matching manual JWK found for KID") - } - - jwk, err := utils.JsonToJWK(jwkEntity.PublicKey) - if err != nil { - logger.ErrorContext(ctx, "Failed to parse manual JWK for software statement", "error", err) - return nil, errors.New("failed to parse manual JWK for software statement") - } - - return jwk, nil - } - } - - return func(kid string) (utils.JWK, error) { - logger.WarnContext(ctx, "No verification method available for software statement") - return nil, errors.New("no verification method available") - } +var accountCredentialsRegistrationUsages []database.DynamicRegistrationUsage = []database.DynamicRegistrationUsage{ + database.DynamicRegistrationUsageAccount, } func mapAccountCredentialsDRTransport(applicationType database.AccountCredentialsType) database.Transport { @@ -735,11 +492,12 @@ func (s *Services) CreateAccountCredentialsRegistration( } domain := parsedClientURI.Hostname() - baseDomain, serviceErr := s.checkAccountCRDomain(ctx, checkAccountCRDomainOptions{ + baseDomain, serviceErr := s.checkClientRegistrationDomain(ctx, checkClientRegistrationDomainOptions{ requestID: opts.RequestID, accountPublicID: opts.AccountPublicID, iatDomain: opts.IATDomain, domain: domain, + usages: accountCredentialsRegistrationUsages, requireVerifiedDomains: slices.Contains(accountDRConfigDTO.RequireVerifiedDomainsCredentialsType, applicationType), }) if serviceErr != nil { @@ -788,7 +546,7 @@ func (s *Services) CreateAccountCredentialsRegistration( ssClaims, stdClaims, err := s.jwt.VerifySoftwareStatement(ctx, tokens.VerifySoftwareStatementOptions{ RequestID: opts.RequestID, SoftwareStatement: opts.SoftwareStatement, - GetPublicJWK: s.buildAccountCRSoftwareStatementFunc(ctx, buildAccountCRSoftwareStatementFuncOptions{ + GetPublicJWK: s.buildDynamicRegistrationSoftwareStatementFunc(ctx, buildDynamicRegistrationSoftwareStatementFuncOptions{ requestID: opts.RequestID, accountPublicID: opts.AccountPublicID, verificationMethods: accountDRConfigDTO.SoftwareStatementVerificationMethods, diff --git a/idp/internal/services/account_credentials_registration_iat.go b/idp/internal/services/account_credentials_registration_iat.go index b11131a..e42bff7 100644 --- a/idp/internal/services/account_credentials_registration_iat.go +++ b/idp/internal/services/account_credentials_registration_iat.go @@ -25,6 +25,7 @@ type CreateAccountCredentialsRegistrationIATOptions struct { AccountPublicID uuid.UUID AccountVersion int32 Domain string + BackendDomain string } func (s *Services) CreateAccountCredentialsRegistrationIAT( @@ -37,19 +38,14 @@ func (s *Services) CreateAccountCredentialsRegistrationIAT( ) logger.InfoContext(ctx, "Creating account credentials registration IAT...") - domainDTO, serviceErr := s.GetAccountCredentialsRegistrationDomain(ctx, GetAccountCredentialsRegistrationDomainOptions{ + if _, serviceErr := s.GetAccountCredentialsRegistrationDomain(ctx, GetAccountCredentialsRegistrationDomainOptions{ RequestID: opts.RequestID, AccountPublicID: opts.AccountPublicID, Domain: opts.Domain, - }) - if serviceErr != nil { + }); serviceErr != nil { logger.ErrorContext(ctx, "Failed to get account credentials registration domain", "serviceError", serviceErr) return "", serviceErr } - if !domainDTO.Verified { - logger.ErrorContext(ctx, "Account credentials registration domain is not verified") - return "", exceptions.NewValidationError("account credentials registration domain is not verified") - } if _, serviceErr := s.GetAccountIDByPublicIDAndVersion(ctx, GetAccountIDByPublicIDAndVersionOptions{ RequestID: opts.RequestID, @@ -62,9 +58,10 @@ func (s *Services) CreateAccountCredentialsRegistrationIAT( signedToken, serviceErr := s.crypto.SignToken(ctx, crypto.SignTokenOptions{ RequestID: opts.RequestID, - Token: s.jwt.CreateAccountCredentialsDynamicRegistrationToken(tokens.AccountCredentialsDynamicRegistrationTokenOptions{ + Token: s.jwt.DynamicRegistrationIAT(tokens.DynamicRegistrationIATOptions{ AccountPublicID: opts.AccountPublicID, AccountVersion: opts.AccountVersion, + IssuerDomain: opts.BackendDomain, Domain: opts.Domain, ClientID: utils.Base62UUID(), }), @@ -93,8 +90,9 @@ func (s *Services) CreateAccountCredentialsRegistrationIAT( } type ProcessAccountCredentialsRegistrationIATAuthOptions struct { - RequestID string - AuthHeader string + RequestID string + AuthHeader string + BackendDomain string } func (s *Services) ProcessAccountCredentialsRegistrationIATAuth( @@ -110,11 +108,12 @@ func (s *Services) ProcessAccountCredentialsRegistrationIATAuth( return "", tokens.AccountClaims{}, serviceErr } - domain, accountClaims, err := s.jwt.VerifyAccountCredentialsDynamicRegistrationToken( + domain, accountClaims, err := s.jwt.VerifyDynamicRegistrationIAT( ctx, - tokens.VerifyAccountCredentialsDynamicRegistrationTokenOptions{ - RequestID: opts.RequestID, - IAT: token, + tokens.VerifyDynamicRegistrationIATOptions{ + RequestID: opts.RequestID, + IAT: token, + IssuerDomain: opts.BackendDomain, GetPublicJWK: s.BuildGetGlobalPublicKeyFn(ctx, BuildGetGlobalVerifyKeyFnOptions{ RequestID: opts.RequestID, KeyType: database.TokenKeyTypeDynamicRegistration, diff --git a/idp/internal/services/app_dynamic_registration.go b/idp/internal/services/app_dynamic_registration.go new file mode 100644 index 0000000..5213ca2 --- /dev/null +++ b/idp/internal/services/app_dynamic_registration.go @@ -0,0 +1,662 @@ +// Copyright (c) 2025 Afonso Barracha +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +package services + +import ( + "context" + "net/url" + "slices" + "strings" + "time" + + "github.com/google/uuid" + + "github.com/tugascript/devlogs/idp/internal/exceptions" + "github.com/tugascript/devlogs/idp/internal/providers/database" + "github.com/tugascript/devlogs/idp/internal/providers/tokens" + "github.com/tugascript/devlogs/idp/internal/services/dtos" + "github.com/tugascript/devlogs/idp/internal/utils" +) + +const appDynamicRegistrationLocation = "app_dynamic_registration" + +var allowedAppScopes []string = []string{ + string(database.ScopesOpenid), + string(database.ScopesEmail), + string(database.ScopesProfile), + string(database.ScopesAddress), + string(database.ScopesPhone), +} + +var appDynamicRegistrationUsages []database.DynamicRegistrationUsage = []database.DynamicRegistrationUsage{ + database.DynamicRegistrationUsageApp, +} + +func mapAppDRTransport(appType database.AppType) database.Transport { + if appType == database.AppTypeMcp { + return database.TransportStreamableHttp + } + + return database.TransportHttps +} + +func mapAppGrantTypes( + appType database.AppType, + grantTypes []string, +) ([]database.GrantType, *exceptions.ServiceError) { + if len(grantTypes) == 0 { + switch appType { + case database.AppTypeWeb, database.AppTypeSpa, database.AppTypeNative, database.AppTypeMcp: + return authCodeAppGrantTypes, nil + case database.AppTypeBackend, database.AppTypeService: + return []database.GrantType{ + database.GrantTypeClientCredentials, + database.GrantTypeUrnIetfParamsOauthGrantTypeJwtBearer, + }, nil + case database.AppTypeDevice: + return deviceGrantTypes, nil + default: + return nil, exceptions.NewValidationError("invalid app type") + } + } + + gts := make([]database.GrantType, 0, len(grantTypes)) + for _, grantType := range grantTypes { + mappedGrantType, serviceErr := mapGrantType(grantType) + if serviceErr != nil { + return nil, serviceErr + } + gts = append(gts, mappedGrantType) + } + + return gts, nil +} + +func mapAppTokenEndpointAuthMethod( + authMethod string, + appType database.AppType, + transport database.Transport, +) (database.AuthMethod, *exceptions.ServiceError) { + if authMethod == "" { + switch appType { + case database.AppTypeWeb, database.AppTypeSpa, database.AppTypeNative: + return database.AuthMethodClientSecretPost, nil + case database.AppTypeBackend, database.AppTypeService: + return database.AuthMethodPrivateKeyJwt, nil + case database.AppTypeDevice, database.AppTypeMcp: + return database.AuthMethodNone, nil + default: + return "", exceptions.NewValidationError("invalid app type") + } + } + + mappedAuthMethod, serviceErr := mapAuthMethod(authMethod) + if serviceErr != nil { + return "", serviceErr + } + + switch appType { + case database.AppTypeWeb, database.AppTypeSpa, database.AppTypeNative: + if mappedAuthMethod == database.AuthMethodNone { + return "", exceptions.NewValidationError("auth method none is not supported for web, spa, or native apps") + } + case database.AppTypeBackend, database.AppTypeService: + if mappedAuthMethod == database.AuthMethodNone { + return "", exceptions.NewValidationError("auth method none is not supported for backend or service apps") + } + case database.AppTypeDevice, database.AppTypeMcp: + if mappedAuthMethod != database.AuthMethodNone { + return "", exceptions.NewValidationError("only auth method none is supported for device or mcp apps") + } + } + + return mappedAuthMethod, nil +} + +type mapAppRegistrationDataToDBParamsOptions struct { + appType database.AppType + accountPublicID uuid.UUID + accountID int32 + domain string + requestID string + tokenEndpointAuthMethod database.AuthMethod + transport database.Transport + scopes []database.Scopes + customScopes []string + defaultScopes []database.Scopes + defaultCustomScopes []string + allowUserRegistration bool + usernameColumn database.AppUsernameColumn + authProviders []database.AuthProvider + data *ApplicationRegistrationData + claims *tokens.SoftwareStatementClaims +} + +func (s *Services) mapAppRegistrationDataToDBParams( + ctx context.Context, + opts mapAppRegistrationDataToDBParamsOptions, +) (database.CreateAppParams, *exceptions.ServiceError) { + logger := s.buildLogger(opts.requestID, appDynamicRegistrationLocation, "mapAppRegistrationDataToDBParams").With( + "accountPublicID", opts.accountPublicID, + "accountID", opts.accountID, + "domain", opts.domain, + "data", opts.data, + "claims", opts.claims, + ) + logger.InfoContext(ctx, "Mapping app registration data to database params") + + responseTypes, serviceErr := mapResponseTypesWithDefault(opts.data.ResponseTypes) + if serviceErr != nil { + logger.ErrorContext(ctx, "Failed to map response types", "serviceError", serviceErr) + return database.CreateAppParams{}, serviceErr + } + + grantTypes, serviceErr := mapAppGrantTypes(opts.appType, opts.data.GrantTypes) + if serviceErr != nil { + logger.ErrorContext(ctx, "Failed to map grant types", "serviceError", serviceErr) + return database.CreateAppParams{}, serviceErr + } + + params := database.CreateAppParams{ + AccountID: opts.accountID, + AccountPublicID: opts.accountPublicID, + AppType: opts.appType, + ClientName: opts.data.ClientName, + ClientID: utils.Base62UUID(), + ClientUri: utils.ProcessURL(opts.data.ClientURI), + UsernameColumn: opts.usernameColumn, + TokenEndpointAuthMethod: opts.tokenEndpointAuthMethod, + CreationMethod: database.CreationMethodDynamicRegistration, + GrantTypes: grantTypes, + LogoUri: mapEmptyURL(opts.data.LogoURI), + TosUri: mapEmptyURL(opts.data.TOSURI), + PolicyUri: mapEmptyURL(opts.data.PolicyURI), + Contacts: utils.MapSlice(opts.data.Contacts, func(t *string) string { + return utils.Lowered(*t) + }), + SoftwareID: mapEmptyString(opts.data.SoftwareID), + SoftwareVersion: mapEmptyString(opts.data.SoftwareVersion), + Scopes: opts.scopes, + DefaultScopes: opts.defaultScopes, + CustomScopes: opts.customScopes, + DefaultCustomScopes: opts.defaultCustomScopes, + Domain: opts.domain, + Transport: opts.transport, + RedirectUris: utils.MapSlice(opts.data.RedirectURIs, func(uri *string) string { + return utils.ProcessURL(*uri) + }), + ResponseTypes: responseTypes, + AllowUserRegistration: opts.allowUserRegistration, + AuthProviders: opts.authProviders, + } + + if opts.claims != nil { + if opts.claims.ClientName != "" { + params.ClientName = opts.claims.ClientName + } + if opts.claims.ClientURI != "" { + params.ClientUri = utils.ProcessURL(opts.claims.ClientURI) + } + if opts.claims.LogoURI != "" { + params.LogoUri = mapEmptyURL(opts.claims.LogoURI) + } + if len(opts.claims.RedirectURIs) > 0 { + params.RedirectUris = utils.MapSlice(opts.claims.RedirectURIs, func(uri *string) string { + return utils.ProcessURL(*uri) + }) + } + if opts.claims.TOSURI != "" { + params.TosUri = mapEmptyURL(opts.claims.TOSURI) + } + if opts.claims.PolicyURI != "" { + params.PolicyUri = mapEmptyURL(opts.claims.PolicyURI) + } + if opts.claims.SoftwareID != "" { + params.SoftwareID = mapEmptyString(opts.claims.SoftwareID) + } + if opts.claims.SoftwareVersion != "" { + params.SoftwareVersion = mapEmptyString(opts.claims.SoftwareVersion) + } + if len(opts.claims.GrantTypes) > 0 { + params.GrantTypes = utils.MapSlice(opts.claims.GrantTypes, func(grantType *string) database.GrantType { + return database.GrantType(*grantType) + }) + } + if len(opts.claims.ResponseTypes) > 0 { + params.ResponseTypes = utils.MapSlice(opts.claims.ResponseTypes, func(responseType *string) database.ResponseType { + return database.ResponseType(*responseType) + }) + } + if opts.claims.Scope != "" { + scopesList := strings.Fields(opts.claims.Scope) + stdScopes, customScopes, _, _, serviceErr := mapScopesToStandardAndCustomScopes(scopesList, nil) + if serviceErr != nil { + logger.ErrorContext(ctx, "Failed to map scopes from software statement", "serviceError", serviceErr) + return database.CreateAppParams{}, serviceErr + } + params.Scopes = stdScopes + params.CustomScopes = customScopes + } + if len(opts.claims.Contacts) > 0 { + params.Contacts = utils.MapSlice(opts.claims.Contacts, func(t *string) string { + return utils.Lowered(*t) + }) + } + } + + return params, nil +} + +type CreateAppCredentialsRegistrationOptions struct { + RequestID string + AccountID int32 + IsAuthenticated bool + IATDomain string + AccountVersion int32 + ApplicationType string + RedirectURIs []string + TokenEndpointAuthMethod string + GrantTypes []string + ResponseTypes []string + ClientName string + ClientURI string + LogoURI string + TOSURI string + PolicyURI string + Contacts []string + SoftwareID string + SoftwareVersion string + SoftwareStatement string + JWKsURI string + JWKs []string + FrontendDomain string + BackendDomain string + RequireAuthTime bool + DefaultMaxAge int64 + SubjectType string + IDTokenSignedResponseAlg string + IDTokenEncryptedResponseAlg string + IDTokenEncryptedResponseEnc string + RequestObjectSigningAlg string + RequestObjectEncryptionAlg string + RequestObjectEncryptionEnc string + DefaultACRValues []string + Scope string + SectorIdentifierURI string + InitiateLoginURI string + RequestURIs []string + UserInfoSignedResponseAlg string + UserInfoEncryptedResponseAlg string + UserInfoEncryptedResponseEnc string + TokenEndpointAuthSigningAlg string + AccessTokenSigningAlg string +} + +func (s *Services) CreateAppCredentialsRegistration( + ctx context.Context, + opts CreateAppCredentialsRegistrationOptions, +) (dtos.AppDTO, *exceptions.ServiceError) { + logger := s.buildLogger( + opts.RequestID, + appDynamicRegistrationLocation, + "CreateAppCredentialsRegistration", + ).With( + "accountID", opts.AccountID, + ) + logger.InfoContext(ctx, "Creating app credentials registration...") + + appType, serviceErr := mapAppTypeToDB(opts.ApplicationType) + if serviceErr != nil { + logger.ErrorContext(ctx, "Failed to map application type", "serviceError", serviceErr) + return dtos.AppDTO{}, serviceErr + } + + transport := mapAppDRTransport(appType) + tokenEndpointAuthMethod, serviceErr := mapAppTokenEndpointAuthMethod( + opts.TokenEndpointAuthMethod, + appType, + transport, + ) + if serviceErr != nil { + logger.ErrorContext(ctx, "Failed to map token endpoint auth method", "serviceError", serviceErr) + return dtos.AppDTO{}, serviceErr + } + + accessTokenSigningAlg, serviceErr := mapTokenCryptoSuiteWithDefault(opts.AccessTokenSigningAlg) + if serviceErr != nil { + logger.ErrorContext(ctx, "Failed to map access token signing alg", "serviceError", serviceErr) + return dtos.AppDTO{}, serviceErr + } + + if !validateEncryptionAlgorithmPair(opts.IDTokenEncryptedResponseAlg, opts.IDTokenEncryptedResponseEnc) { + logger.WarnContext(ctx, "id_token encryption algorithm and encoding must both be set or both be unset") + return dtos.AppDTO{}, exceptions.NewValidationError("id_token encryption algorithm and encoding mismatch") + } + if !validateEncryptionAlgorithmPair(opts.UserInfoEncryptedResponseAlg, opts.UserInfoEncryptedResponseEnc) { + logger.WarnContext(ctx, "userinfo encryption algorithm and encoding must both be set or both be unset") + return dtos.AppDTO{}, exceptions.NewValidationError("userinfo encryption algorithm and encoding mismatch") + } + if !validateEncryptionAlgorithmPair(opts.RequestObjectEncryptionAlg, opts.RequestObjectEncryptionEnc) { + logger.WarnContext(ctx, "request_object encryption algorithm and encoding must both be set or both be unset") + return dtos.AppDTO{}, exceptions.NewValidationError("request_object encryption algorithm and encoding mismatch") + } + + parsedClientURI, err := url.Parse(opts.ClientURI) + if err != nil { + logger.WarnContext(ctx, "Failed to parse client URI", "error", err) + return dtos.AppDTO{}, exceptions.NewValidationError("invalid client URI") + } + domain := parsedClientURI.Hostname() + + appDRConfigDTO, serviceErr := s.GetAndCacheAppDynamicRegistrationConfig(ctx, GetAndCacheAppDynamicRegistrationConfigOptions{ + RequestID: opts.RequestID, + AccountID: opts.AccountID, + }) + if serviceErr != nil { + if serviceErr.Code == exceptions.CodeNotFound { + logger.InfoContext(ctx, "App dynamic registration config not found") + return dtos.AppDTO{}, exceptions.NewForbiddenError() + } + + logger.ErrorContext(ctx, "Failed to get app dynamic registration config", "serviceError", serviceErr) + return dtos.AppDTO{}, serviceErr + } + + if slices.Contains(appDRConfigDTO.RequireInitialAccessTokenAppTypes, appType) && + !opts.IsAuthenticated { + logger.WarnContext(ctx, "App dynamic registration configuration requires initial access token") + return dtos.AppDTO{}, exceptions.NewUnauthorizedError() + } + + if !slices.Contains(appDRConfigDTO.AllowedAppTypes, appType) { + logger.WarnContext(ctx, "App type is not allowed", "appType", appType) + return dtos.AppDTO{}, exceptions.NewForbiddenError() + } + + if slices.Contains(appDRConfigDTO.RequireSoftwareStatementAppTypes, appType) && + opts.SoftwareStatement == "" { + logger.WarnContext(ctx, "App dynamic registration configuration requires software statement") + return dtos.AppDTO{}, exceptions.NewUnauthorizedError() + } + + accountDTO, serviceErr := s.GetAccountByID(ctx, GetAccountByIDOptions{ + RequestID: opts.RequestID, + ID: opts.AccountID, + }) + if serviceErr != nil { + logger.ErrorContext(ctx, "Failed to get account by ID", "serviceError", serviceErr) + return dtos.AppDTO{}, serviceErr + } + if opts.AccountVersion != 0 && accountDTO.Version() != opts.AccountVersion { + logger.WarnContext(ctx, "Account version mismatch", + "providedVersion", opts.AccountVersion, + "currentVersion", accountDTO.Version(), + ) + return dtos.AppDTO{}, exceptions.NewUnauthorizedError() + } + + if serviceErr := s.checkForDuplicateApps(ctx, checkForDuplicateAppsOptions{ + requestID: opts.RequestID, + accountID: opts.AccountID, + name: opts.ClientName, + softwareID: opts.SoftwareID, + }); serviceErr != nil { + logger.ErrorContext(ctx, "Duplicate app found", "serviceError", serviceErr) + return dtos.AppDTO{}, serviceErr + } + + baseDomain, serviceErr := s.checkClientRegistrationDomain(ctx, checkClientRegistrationDomainOptions{ + requestID: opts.RequestID, + accountPublicID: accountDTO.PublicID, + iatDomain: opts.IATDomain, + usages: appDynamicRegistrationUsages, + domain: domain, + requireVerifiedDomains: slices.Contains(appDRConfigDTO.RequireVerifiedDomainsAppTypes, appType), + }) + if serviceErr != nil { + logger.WarnContext(ctx, "Failed to check domain validity", "serviceError", serviceErr) + return dtos.AppDTO{}, serviceErr + } + + scopesList := strings.Fields(opts.Scope) + defaultScopesList := utils.MapSlice(appDRConfigDTO.DefaultScopes, func(s *database.Scopes) string { + return string(*s) + }) + stdScopes, customScopes, defaultStdScopes, defaultCustomScopes, serviceErr := mapScopesToStandardAndCustomScopes( + scopesList, + defaultScopesList, + ) + if serviceErr != nil { + logger.ErrorContext(ctx, "Failed to map scopes", "serviceError", serviceErr) + return dtos.AppDTO{}, serviceErr + } + + allowedScopesSet := utils.SliceToHashSet(utils.MapSlice(appDRConfigDTO.DefaultAllowedScopes, func(s *database.Scopes) string { + return string(*s) + })) + for _, scope := range stdScopes { + if !allowedScopesSet.Contains(string(scope)) { + logger.WarnContext(ctx, "Scope is not allowed", "scope", scope) + return dtos.AppDTO{}, exceptions.NewValidationError("scope is not allowed: " + string(scope)) + } + } + + allowUserRegistration := appDRConfigDTO.DefaultAllowUserRegistration + usernameColumn := appDRConfigDTO.DefaultUsernameColumn + authProviders := appDRConfigDTO.DefaultAuthProviders + + data := ApplicationRegistrationData{ + RedirectURIs: opts.RedirectURIs, + TokenEndpointAuthMethod: opts.TokenEndpointAuthMethod, + ResponseTypes: opts.ResponseTypes, + GrantTypes: opts.GrantTypes, + ApplicationType: opts.ApplicationType, + ClientName: opts.ClientName, + ClientURI: opts.ClientURI, + LogoURI: opts.LogoURI, + Scope: opts.Scope, + Contacts: opts.Contacts, + TOSURI: opts.TOSURI, + PolicyURI: opts.PolicyURI, + JWKsURI: opts.JWKsURI, + JWKs: opts.JWKs, + SoftwareID: opts.SoftwareID, + SoftwareVersion: opts.SoftwareVersion, + SubjectType: opts.SubjectType, + SectorIdentifierURI: opts.SectorIdentifierURI, + DefaultMaxAge: opts.DefaultMaxAge, + RequireAuthTime: opts.RequireAuthTime, + DefaultACRValues: opts.DefaultACRValues, + InitiateLoginURI: opts.InitiateLoginURI, + RequestURIs: opts.RequestURIs, + IDTokenSignedResponseAlg: opts.IDTokenSignedResponseAlg, + IDTokenEncryptedResponseAlg: opts.IDTokenEncryptedResponseAlg, + IDTokenEncryptedResponseEnc: opts.IDTokenEncryptedResponseEnc, + UserInfoSignedResponseAlg: opts.UserInfoSignedResponseAlg, + UserInfoEncryptedResponseAlg: opts.UserInfoEncryptedResponseAlg, + UserInfoEncryptedResponseEnc: opts.UserInfoEncryptedResponseEnc, + RequestObjectSigningAlg: opts.RequestObjectSigningAlg, + RequestObjectEncryptionAlg: opts.RequestObjectEncryptionAlg, + RequestObjectEncryptionEnc: opts.RequestObjectEncryptionEnc, + TokenEndpointAuthSigningAlg: opts.TokenEndpointAuthSigningAlg, + AccessTokenSigningAlg: opts.AccessTokenSigningAlg, + } + var ssClaimsReference *tokens.SoftwareStatementClaims + if opts.SoftwareStatement != "" { + ssClaims, stdClaims, err := s.jwt.VerifySoftwareStatement(ctx, tokens.VerifySoftwareStatementOptions{ + RequestID: opts.RequestID, + SoftwareStatement: opts.SoftwareStatement, + GetPublicJWK: s.buildDynamicRegistrationSoftwareStatementFunc(ctx, buildDynamicRegistrationSoftwareStatementFuncOptions{ + requestID: opts.RequestID, + accountPublicID: accountDTO.PublicID, + verificationMethods: appDRConfigDTO.SoftwareStatementVerificationMethods, + jwksURI: opts.JWKsURI, + jwks: opts.JWKs, + domain: domain, + baseDomain: baseDomain, + }), + }) + if err != nil { + logger.WarnContext(ctx, "Failed to verify software statement", "error", err) + return dtos.AppDTO{}, exceptions.NewInvalidTokenError("invalid software statement") + } + if serviceErr := s.verifySoftwareStatementSTDClaims(ctx, verifySoftwareStatementSTDClaimsOptions{ + requestID: opts.RequestID, + backendDomain: opts.BackendDomain, + frontendDomain: opts.FrontendDomain, + domain: domain, + baseDomain: baseDomain, + claims: &stdClaims, + }); serviceErr != nil { + logger.WarnContext(ctx, "Failed to verify software statement standard claims", "serviceError", serviceErr) + return dtos.AppDTO{}, serviceErr + } + + if serviceErr := s.validateSoftwareStatementClaims(ctx, validateSoftwareStatementClaimsOptions{ + requestID: opts.RequestID, + claims: &ssClaims, + data: &data, + allowedScopes: utils.SliceToHashSet(allowedAppScopes), + }); serviceErr != nil { + logger.WarnContext(ctx, "Failed to validate software statement claims", "serviceError", serviceErr) + return dtos.AppDTO{}, serviceErr + } + + ssClaimsReference = &ssClaims + } + + params, serviceErr := s.mapAppRegistrationDataToDBParams(ctx, mapAppRegistrationDataToDBParamsOptions{ + appType: appType, + accountPublicID: accountDTO.PublicID, + accountID: opts.AccountID, + domain: domain, + requestID: opts.RequestID, + tokenEndpointAuthMethod: tokenEndpointAuthMethod, + transport: transport, + scopes: stdScopes, + customScopes: customScopes, + defaultScopes: defaultStdScopes, + defaultCustomScopes: defaultCustomScopes, + allowUserRegistration: allowUserRegistration, + usernameColumn: usernameColumn, + authProviders: authProviders, + data: &data, + claims: ssClaimsReference, + }) + if serviceErr != nil { + logger.ErrorContext(ctx, "Failed to map app registration data to database params", "serviceError", serviceErr) + return dtos.AppDTO{}, serviceErr + } + + if tokenEndpointAuthMethod == database.AuthMethodNone { + app, err := s.database.CreateApp(ctx, params) + if err != nil { + logger.ErrorContext(ctx, "Failed to create app", "error", err) + return dtos.AppDTO{}, exceptions.FromDBError(err) + } + + logger.InfoContext(ctx, "Created app successfully") + return dtos.MapAppToDTO(&app), nil + } + + qrs, txn, err := s.database.BeginTx(ctx) + if err != nil { + logger.ErrorContext(ctx, "Failed to start transaction", "error", err) + return dtos.AppDTO{}, exceptions.FromDBError(err) + } + defer func() { + logger.DebugContext(ctx, "Finalizing transaction") + s.database.FinalizeTx(ctx, txn, err, serviceErr) + }() + + app, err := s.database.CreateApp(ctx, params) + if err != nil { + logger.ErrorContext(ctx, "Failed to create app", "error", err) + return dtos.AppDTO{}, exceptions.FromDBError(err) + } + + switch tokenEndpointAuthMethod { + case database.AuthMethodPrivateKeyJwt: + var dbPrms database.CreateCredentialsKeyParams + var jwk utils.JWK + dbPrms, jwk, serviceErr = s.clientCredentialsKey(ctx, clientCredentialsKeyOptions{ + requestID: opts.RequestID, + accountID: opts.AccountID, + accountPublicID: accountDTO.PublicID, + expiresIn: s.accountCCExpDays, + usage: database.CredentialsUsageApp, + cryptoSuite: utils.SupportedCryptoSuite(accessTokenSigningAlg), + }) + if serviceErr != nil { + logger.ErrorContext(ctx, "Failed to generate client credentials key", "serviceError", serviceErr) + return dtos.AppDTO{}, serviceErr + } + + var clientKey database.CredentialsKey + clientKey, err = qrs.CreateCredentialsKey(ctx, dbPrms) + if err != nil { + logger.ErrorContext(ctx, "Failed to create client key", "error", err) + serviceErr = exceptions.FromDBError(err) + return dtos.AppDTO{}, serviceErr + } + + if err = qrs.CreateAppKey(ctx, database.CreateAppKeyParams{ + AccountID: opts.AccountID, + AppID: app.ID, + CredentialsKeyID: clientKey.ID, + }); err != nil { + logger.ErrorContext(ctx, "Failed to create app key", "error", err) + serviceErr = exceptions.FromDBError(err) + return dtos.AppDTO{}, serviceErr + } + + if appType == database.AppTypeBackend || appType == database.AppTypeService { + return dtos.MapBackendAppWithJWKToDTO(&app, jwk, dbPrms.ExpiresAt), nil + } + + return dtos.MapWebAppWithJWKToDTO(&app, jwk, dbPrms.ExpiresAt), nil + case database.AuthMethodClientSecretBasic, database.AuthMethodClientSecretPost, database.AuthMethodClientSecretJwt: + var ccID int32 + var secretID, secret string + var exp time.Time + ccID, secretID, secret, exp, serviceErr = s.clientCredentialsSecret(ctx, qrs, clientCredentialsSecretOptions{ + requestID: opts.RequestID, + accountID: opts.AccountID, + storageMode: mapCCSecretStorageMode(string(tokenEndpointAuthMethod)), + expiresIn: s.appCCExpDays, + usage: database.CredentialsUsageApp, + dekFN: s.BuildGetEncAccountDEKfn(ctx, BuildGetEncAccountDEKOptions{ + RequestID: opts.RequestID, + AccountID: opts.AccountID, + }), + }) + if serviceErr != nil { + logger.ErrorContext(ctx, "Failed to create client credentials secret", "serviceError", serviceErr) + return dtos.AppDTO{}, serviceErr + } + + if err = qrs.CreateAppSecret(ctx, database.CreateAppSecretParams{ + AppID: app.ID, + CredentialsSecretID: ccID, + AccountID: opts.AccountID, + }); err != nil { + logger.ErrorContext(ctx, "Failed to create app secret", "error", err) + serviceErr = exceptions.FromDBError(err) + return dtos.AppDTO{}, serviceErr + } + + if appType == database.AppTypeBackend || appType == database.AppTypeService { + return dtos.MapBackendAppWithSecretToDTO(&app, secretID, secret, exp), nil + } + + return dtos.MapWebAppWithSecretToDTO(&app, secretID, secret, exp), nil + default: + logger.ErrorContext(ctx, "Invalid token endpoint auth method", "tokenEndpointAuthMethod", tokenEndpointAuthMethod) + serviceErr = exceptions.NewInternalServerError() + return dtos.AppDTO{}, serviceErr + } +} diff --git a/idp/internal/services/app_dynamic_registration_configs.go b/idp/internal/services/app_dynamic_registration_configs.go new file mode 100644 index 0000000..d55af9d --- /dev/null +++ b/idp/internal/services/app_dynamic_registration_configs.go @@ -0,0 +1,420 @@ +// Copyright (c) 2025 Afonso Barracha +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +package services + +import ( + "context" + "fmt" + "time" + + "github.com/google/uuid" + + "github.com/tugascript/devlogs/idp/internal/exceptions" + "github.com/tugascript/devlogs/idp/internal/providers/cache" + "github.com/tugascript/devlogs/idp/internal/providers/database" + "github.com/tugascript/devlogs/idp/internal/services/dtos" + "github.com/tugascript/devlogs/idp/internal/utils" +) + +const ( + appDynamicRegistrationConfigsLocation string = "app_dynamic_registration_configs" + + appDynamicRegistrationConfigCacheTTL time.Duration = 24 * time.Hour +) + +func buildAppDynamicRegistrationConfigCacheKey(accountID int32) string { + return fmt.Sprintf("%s:%d", appDynamicRegistrationConfigsLocation, accountID) +} + +func mapAppTypes(appTypes []string) ([]database.AppType, *exceptions.ServiceError) { + appTypesDB := make([]database.AppType, 0, len(appTypes)) + for _, appType := range appTypes { + appTypeDB, serviceErr := mapAppTypeToDB(appType) + if serviceErr != nil { + return nil, serviceErr + } + appTypesDB = append(appTypesDB, appTypeDB) + } + return appTypesDB, nil +} + +func mapScopes(scopes []string) ([]database.Scopes, *exceptions.ServiceError) { + scopesDB := make([]database.Scopes, 0, len(scopes)) + for _, scope := range scopes { + scopeDB, serviceErr := mapScope(scope) + if serviceErr != nil { + return nil, serviceErr + } + scopesDB = append(scopesDB, scopeDB) + } + return scopesDB, nil +} + +func mapGrantTypes(grantTypes []string) ([]database.GrantType, *exceptions.ServiceError) { + grantTypesDB := make([]database.GrantType, 0, len(grantTypes)) + for _, grantType := range grantTypes { + grantTypeDB, serviceErr := mapGrantType(grantType) + if serviceErr != nil { + return nil, serviceErr + } + grantTypesDB = append(grantTypesDB, grantTypeDB) + } + return grantTypesDB, nil +} + +func mapResponseTypes(responseTypes []string) ([]database.ResponseType, *exceptions.ServiceError) { + var responseTypesDB []database.ResponseType + for _, responseType := range responseTypes { + switch utils.Lowered(responseType) { + case ResponseTypeCode: + responseTypesDB = append(responseTypesDB, database.ResponseTypeCode) + case ResponseTypeCodeIdToken: + responseTypesDB = append(responseTypesDB, database.ResponseTypeCodeidToken) + default: + return nil, exceptions.NewValidationError("invalid response type: " + responseType) + } + } + return responseTypesDB, nil +} + +func mapAuthMethods(authMethods []string) ([]database.AuthMethod, *exceptions.ServiceError) { + authMethodsDB := make([]database.AuthMethod, 0, len(authMethods)) + for _, authMethod := range authMethods { + authMethodDB, serviceErr := mapAuthMethod(authMethod) + if serviceErr != nil { + return nil, serviceErr + } + authMethodsDB = append(authMethodsDB, authMethodDB) + } + return authMethodsDB, nil +} + +type SaveAppDynamicRegistrationConfigOptions struct { + RequestID string + AccountPublicID uuid.UUID + AccountVersion int32 + AllowedAppTypes []string + DefaultAllowUserRegistration bool + DefaultAuthProviders []string + DefaultUsernameColumn string + DefaultAllowedScopes []string + DefaultScopes []string + RequireVerifiedDomainsAppTypes []string + RequireSoftwareStatementAppTypes []string + SoftwareStatementVerificationMethods []string + RequireInitialAccessTokenAppTypes []string + InitialAccessTokenGenerationMethods []string + InitialAccessTokenTtl int32 + InitialAccessTokenMaxUses int32 + AllowedGrantTypes []string + AllowedResponseTypes []string + AllowedTokenEndpointAuthMethods []string + MaxRedirectUris int32 +} + +func (s *Services) SaveAppDynamicRegistrationConfig( + ctx context.Context, + opts SaveAppDynamicRegistrationConfigOptions, +) (dtos.AppDynamicRegistrationConfigDTO, bool, *exceptions.ServiceError) { + logger := s.buildLogger(opts.RequestID, appDynamicRegistrationConfigsLocation, "SaveAppDynamicRegistrationConfig").With( + "accountPublicID", opts.AccountPublicID, + "accountVersion", opts.AccountVersion, + ) + logger.InfoContext(ctx, "Saving app dynamic registration config...") + + allowedAppTypes, serviceErr := mapAppTypes(opts.AllowedAppTypes) + if serviceErr != nil { + logger.WarnContext(ctx, "Failed to map allowed app types", "serviceError", serviceErr) + return dtos.AppDynamicRegistrationConfigDTO{}, false, serviceErr + } + + defaultAuthProviders, serviceErr := mapAuthProviders(opts.DefaultAuthProviders) + if serviceErr != nil { + logger.WarnContext(ctx, "Failed to map default auth providers", "serviceError", serviceErr) + return dtos.AppDynamicRegistrationConfigDTO{}, false, serviceErr + } + + defaultUsernameColumn, serviceErr := mapUsernameColumn(opts.DefaultUsernameColumn) + if serviceErr != nil { + logger.WarnContext(ctx, "Failed to map default username column", "serviceError", serviceErr) + return dtos.AppDynamicRegistrationConfigDTO{}, false, serviceErr + } + + defaultAllowedScopes, serviceErr := mapScopes(opts.DefaultAllowedScopes) + if serviceErr != nil { + logger.WarnContext(ctx, "Failed to map default allowed scopes", "serviceError", serviceErr) + return dtos.AppDynamicRegistrationConfigDTO{}, false, serviceErr + } + + defaultScopes, serviceErr := mapScopes(opts.DefaultScopes) + if serviceErr != nil { + logger.WarnContext(ctx, "Failed to map default scopes", "serviceError", serviceErr) + return dtos.AppDynamicRegistrationConfigDTO{}, false, serviceErr + } + + requireVerifiedDomainsAppTypes, serviceErr := mapAppTypes(opts.RequireVerifiedDomainsAppTypes) + if serviceErr != nil { + logger.WarnContext(ctx, "Failed to map require verified domains app types", "serviceError", serviceErr) + return dtos.AppDynamicRegistrationConfigDTO{}, false, serviceErr + } + + requireSoftwareStatementAppTypes, serviceErr := mapAppTypes(opts.RequireSoftwareStatementAppTypes) + if serviceErr != nil { + logger.WarnContext(ctx, "Failed to map require software statement app types", "serviceError", serviceErr) + return dtos.AppDynamicRegistrationConfigDTO{}, false, serviceErr + } + + softwareStatementVerificationMethods, serviceErr := mapSoftwareStatementVerificationMethods(opts.SoftwareStatementVerificationMethods) + if serviceErr != nil { + logger.WarnContext(ctx, "Failed to map software statement verification methods", "serviceError", serviceErr) + return dtos.AppDynamicRegistrationConfigDTO{}, false, serviceErr + } + + requireInitialAccessTokenAppTypes, serviceErr := mapAppTypes(opts.RequireInitialAccessTokenAppTypes) + if serviceErr != nil { + logger.WarnContext(ctx, "Failed to map require initial access token app types", "serviceError", serviceErr) + return dtos.AppDynamicRegistrationConfigDTO{}, false, serviceErr + } + + initialAccessTokenGenerationMethods, serviceErr := mapInitialAccessTokenGenerationMethods(opts.InitialAccessTokenGenerationMethods) + if serviceErr != nil { + logger.WarnContext(ctx, "Failed to map initial access token generation methods", "serviceError", serviceErr) + return dtos.AppDynamicRegistrationConfigDTO{}, false, serviceErr + } + + allowedGrantTypes, serviceErr := mapGrantTypes(opts.AllowedGrantTypes) + if serviceErr != nil { + logger.WarnContext(ctx, "Failed to map allowed grant types", "serviceError", serviceErr) + return dtos.AppDynamicRegistrationConfigDTO{}, false, serviceErr + } + + allowedResponseTypes, serviceErr := mapResponseTypes(opts.AllowedResponseTypes) + if serviceErr != nil { + logger.WarnContext(ctx, "Failed to map allowed response types", "serviceError", serviceErr) + return dtos.AppDynamicRegistrationConfigDTO{}, false, serviceErr + } + + allowedTokenEndpointAuthMethods, serviceErr := mapAuthMethods(opts.AllowedTokenEndpointAuthMethods) + if serviceErr != nil { + logger.WarnContext(ctx, "Failed to map allowed token endpoint auth methods", "serviceError", serviceErr) + return dtos.AppDynamicRegistrationConfigDTO{}, false, serviceErr + } + + accountID, serviceErr := s.GetAccountIDByPublicIDAndVersion(ctx, GetAccountIDByPublicIDAndVersionOptions{ + RequestID: opts.RequestID, + PublicID: opts.AccountPublicID, + Version: opts.AccountVersion, + }) + if serviceErr != nil { + logger.WarnContext(ctx, "Failed to get account", "serviceError", serviceErr) + return dtos.AppDynamicRegistrationConfigDTO{}, false, serviceErr + } + + appDynamicRegistrationConfig, err := s.database.FindAppDynamicRegistrationConfigByAccountPublicID(ctx, opts.AccountPublicID) + if err != nil { + serviceErr := exceptions.FromDBError(err) + if serviceErr.Code != exceptions.CodeNotFound { + logger.ErrorContext(ctx, "Failed to find app dynamic registration config", "error", err) + return dtos.AppDynamicRegistrationConfigDTO{}, false, serviceErr + } + + logger.InfoContext(ctx, "App dynamic registration config not found, creating new one...") + appDynamicRegistrationConfig, err = s.database.CreateAppDynamicRegistrationConfig( + ctx, + database.CreateAppDynamicRegistrationConfigParams{ + AccountID: accountID, + AccountPublicID: opts.AccountPublicID, + AllowedAppTypes: allowedAppTypes, + DefaultAllowUserRegistration: opts.DefaultAllowUserRegistration, + DefaultAuthProviders: defaultAuthProviders, + DefaultUsernameColumn: defaultUsernameColumn, + DefaultAllowedScopes: defaultAllowedScopes, + DefaultScopes: defaultScopes, + RequireVerifiedDomainsAppTypes: requireVerifiedDomainsAppTypes, + RequireSoftwareStatementAppTypes: requireSoftwareStatementAppTypes, + SoftwareStatementVerificationMethods: softwareStatementVerificationMethods, + RequireInitialAccessTokenAppTypes: requireInitialAccessTokenAppTypes, + InitialAccessTokenGenerationMethods: initialAccessTokenGenerationMethods, + InitialAccessTokenTtl: opts.InitialAccessTokenTtl, + InitialAccessTokenMaxUses: opts.InitialAccessTokenMaxUses, + AllowedGrantTypes: allowedGrantTypes, + AllowedResponseTypes: allowedResponseTypes, + AllowedTokenEndpointAuthMethods: allowedTokenEndpointAuthMethods, + MaxRedirectUris: opts.MaxRedirectUris, + }, + ) + if err != nil { + logger.ErrorContext(ctx, "Failed to create app dynamic registration config", "error", err) + return dtos.AppDynamicRegistrationConfigDTO{}, false, exceptions.FromDBError(err) + } + + return dtos.MapAppDynamicRegistrationConfigToDTO(&appDynamicRegistrationConfig), true, nil + } + + appDynamicRegistrationConfig, err = s.database.UpdateAppDynamicRegistrationConfig(ctx, database.UpdateAppDynamicRegistrationConfigParams{ + ID: appDynamicRegistrationConfig.ID, + AllowedAppTypes: allowedAppTypes, + DefaultAllowUserRegistration: opts.DefaultAllowUserRegistration, + DefaultAuthProviders: defaultAuthProviders, + DefaultUsernameColumn: defaultUsernameColumn, + DefaultAllowedScopes: defaultAllowedScopes, + DefaultScopes: defaultScopes, + RequireVerifiedDomainsAppTypes: requireVerifiedDomainsAppTypes, + RequireSoftwareStatementAppTypes: requireSoftwareStatementAppTypes, + SoftwareStatementVerificationMethods: softwareStatementVerificationMethods, + RequireInitialAccessTokenAppTypes: requireInitialAccessTokenAppTypes, + InitialAccessTokenGenerationMethods: initialAccessTokenGenerationMethods, + InitialAccessTokenTtl: opts.InitialAccessTokenTtl, + InitialAccessTokenMaxUses: opts.InitialAccessTokenMaxUses, + AllowedGrantTypes: allowedGrantTypes, + AllowedResponseTypes: allowedResponseTypes, + AllowedTokenEndpointAuthMethods: allowedTokenEndpointAuthMethods, + MaxRedirectUris: opts.MaxRedirectUris, + }) + if err != nil { + logger.ErrorContext(ctx, "Failed to update app dynamic registration config", "error", err) + return dtos.AppDynamicRegistrationConfigDTO{}, false, exceptions.FromDBError(err) + } + + if err := s.cache.DeleteResponse(ctx, cache.DeleteResponseOptions{ + RequestID: opts.RequestID, + Key: buildAppDynamicRegistrationConfigCacheKey(accountID), + }); err != nil { + logger.ErrorContext(ctx, "Failed to delete cached app dynamic registration config", "error", err) + return dtos.AppDynamicRegistrationConfigDTO{}, false, exceptions.NewInternalServerError() + } + + return dtos.MapAppDynamicRegistrationConfigToDTO(&appDynamicRegistrationConfig), false, nil +} + +type GetAppDynamicRegistrationConfigOptions struct { + RequestID string + AccountPublicID uuid.UUID +} + +func (s *Services) GetAppDynamicRegistrationConfig( + ctx context.Context, + opts GetAppDynamicRegistrationConfigOptions, +) (dtos.AppDynamicRegistrationConfigDTO, *exceptions.ServiceError) { + logger := s.buildLogger(opts.RequestID, appDynamicRegistrationConfigsLocation, "GetAppDynamicRegistrationConfig").With( + "accountPublicID", opts.AccountPublicID, + ) + logger.InfoContext(ctx, "Retrieving app dynamic registration config...") + + appDynamicRegistrationConfig, err := s.database.FindAppDynamicRegistrationConfigByAccountPublicID(ctx, opts.AccountPublicID) + if err != nil { + serviceErr := exceptions.FromDBError(err) + if serviceErr.Code != exceptions.CodeNotFound { + logger.ErrorContext(ctx, "Failed to find app dynamic registration config", "error", err) + return dtos.AppDynamicRegistrationConfigDTO{}, serviceErr + } + + logger.InfoContext(ctx, "App dynamic registration config not found", "error", err) + return dtos.AppDynamicRegistrationConfigDTO{}, nil + } + + return dtos.MapAppDynamicRegistrationConfigToDTO(&appDynamicRegistrationConfig), nil +} + +type GetAndCacheAppDynamicRegistrationConfigOptions struct { + RequestID string + AccountID int32 +} + +func (s *Services) GetAndCacheAppDynamicRegistrationConfig( + ctx context.Context, + opts GetAndCacheAppDynamicRegistrationConfigOptions, +) (dtos.AppDynamicRegistrationConfigDTO, *exceptions.ServiceError) { + logger := s.buildLogger(opts.RequestID, appDynamicRegistrationConfigsLocation, "GetAndCacheAppDynamicRegistrationConfig").With( + "accountID", opts.AccountID, + ) + logger.InfoContext(ctx, "Getting and caching app dynamic registration config...") + + appDRConfigDTO, found, err := cache.GetResponseWithoutETag(s.cache, ctx, cache.GetResponseOptions[dtos.AppDynamicRegistrationConfigDTO]{ + RequestID: opts.RequestID, + Key: buildAppDynamicRegistrationConfigCacheKey(opts.AccountID), + }) + if err != nil { + logger.ErrorContext(ctx, "Failed to get cached app dynamic registration config", "error", err) + return dtos.AppDynamicRegistrationConfigDTO{}, exceptions.NewInternalServerError() + } + if found { + logger.InfoContext(ctx, "App dynamic registration config found in cache") + return appDRConfigDTO, nil + } + + appDRConfig, err := s.database.FindAppDynamicRegistrationConfigByAccountID(ctx, opts.AccountID) + if err != nil { + serviceErr := exceptions.FromDBError(err) + if serviceErr.Code != exceptions.CodeNotFound { + logger.ErrorContext(ctx, "Failed to find app dynamic registration config", "error", err) + return dtos.AppDynamicRegistrationConfigDTO{}, serviceErr + } + + logger.InfoContext(ctx, "App dynamic registration config not found, creating new one...") + return dtos.AppDynamicRegistrationConfigDTO{}, exceptions.NewNotFoundError() + } + + appDRConfigDTO = dtos.MapAppDynamicRegistrationConfigToDTO(&appDRConfig) + if err := cache.SaveResponseWithoutETag(s.cache, ctx, cache.SaveResponseOptions[dtos.AppDynamicRegistrationConfigDTO]{ + RequestID: opts.RequestID, + Key: buildAppDynamicRegistrationConfigCacheKey(opts.AccountID), + TTL: appDynamicRegistrationConfigCacheTTL, + Value: appDRConfigDTO, + }); err != nil { + logger.ErrorContext(ctx, "Failed to save app dynamic registration config to cache", "error", err) + return dtos.AppDynamicRegistrationConfigDTO{}, exceptions.NewInternalServerError() + } + + return appDRConfigDTO, nil +} + +type DeleteAppDynamicRegistrationConfigOptions struct { + RequestID string + AccountPublicID uuid.UUID + AccountVersion int32 +} + +func (s *Services) DeleteAppDynamicRegistrationConfig( + ctx context.Context, + opts DeleteAppDynamicRegistrationConfigOptions, +) *exceptions.ServiceError { + logger := s.buildLogger(opts.RequestID, appDynamicRegistrationConfigsLocation, "DeleteAppDynamicRegistrationConfig").With( + "accountPublicID", opts.AccountPublicID, + "accountVersion", opts.AccountVersion, + ) + logger.InfoContext(ctx, "Deleting app dynamic registration config...") + + accountDTO, serviceErr := s.GetAccountByPublicIDAndVersion(ctx, GetAccountByPublicIDAndVersionOptions{ + RequestID: opts.RequestID, + PublicID: opts.AccountPublicID, + Version: opts.AccountVersion, + }) + if serviceErr != nil { + logger.ErrorContext(ctx, "Failed to get account by public ID and version", "serviceError", serviceErr) + return serviceErr + } + + appDynamicRegistrationConfig, serviceErr := s.GetAppDynamicRegistrationConfig( + ctx, + GetAppDynamicRegistrationConfigOptions{ + RequestID: opts.RequestID, + AccountPublicID: accountDTO.PublicID, + }, + ) + if serviceErr != nil { + logger.ErrorContext(ctx, "Failed to get app dynamic registration config", "serviceError", serviceErr) + return serviceErr + } + + if err := s.database.DeleteAppDynamicRegistrationConfig(ctx, appDynamicRegistrationConfig.ID()); err != nil { + logger.ErrorContext(ctx, "Failed to delete app dynamic registration config", "error", err) + return exceptions.FromDBError(err) + } + + return nil +} diff --git a/idp/internal/services/app_dynamic_registration_iat.go b/idp/internal/services/app_dynamic_registration_iat.go new file mode 100644 index 0000000..c451c58 --- /dev/null +++ b/idp/internal/services/app_dynamic_registration_iat.go @@ -0,0 +1,138 @@ +// Copyright (c) 2025 Afonso Barracha +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +package services + +import ( + "context" + "fmt" + + "github.com/google/uuid" + + "github.com/tugascript/devlogs/idp/internal/exceptions" + "github.com/tugascript/devlogs/idp/internal/providers/crypto" + "github.com/tugascript/devlogs/idp/internal/providers/database" + "github.com/tugascript/devlogs/idp/internal/providers/tokens" + "github.com/tugascript/devlogs/idp/internal/utils" +) + +const appDynamicRegistrationIATLocation = "app_credentials_registration_iat" + +type CreateAppCredentialsRegistrationIATOptions struct { + RequestID string + AccountPublicID uuid.UUID + AccountVersion int32 + Domain string + BackendDomain string +} + +func (s *Services) CreateAppCredentialsRegistrationIAT( + ctx context.Context, + opts CreateAppCredentialsRegistrationIATOptions, +) (string, *exceptions.ServiceError) { + logger := s.buildLogger(opts.RequestID, appDynamicRegistrationIATLocation, "CreateAppCredentialsRegistrationIAT").With( + "accountPublicId", opts.AccountPublicID, + "domain", opts.Domain, + ) + logger.InfoContext(ctx, "Creating app credentials registration IAT...") + + if _, serviceErr := s.GetAppCredentialsRegistrationDomain(ctx, GetAppCredentialsRegistrationDomainOptions{ + RequestID: opts.RequestID, + AccountPublicID: opts.AccountPublicID, + Domain: opts.Domain, + }); serviceErr != nil { + logger.ErrorContext(ctx, "Failed to get app credentials registration domain", "serviceError", serviceErr) + return "", serviceErr + } + + accountDTO, serviceErr := s.GetAccountByPublicIDAndVersion(ctx, GetAccountByPublicIDAndVersionOptions{ + RequestID: opts.RequestID, + PublicID: opts.AccountPublicID, + Version: opts.AccountVersion, + }) + if serviceErr != nil { + logger.ErrorContext(ctx, "Failed to get account", "serviceError", serviceErr) + return "", serviceErr + } + + accountID := accountDTO.ID() + signedToken, serviceErr := s.crypto.SignToken(ctx, crypto.SignTokenOptions{ + RequestID: opts.RequestID, + Token: s.jwt.DynamicRegistrationIAT(tokens.DynamicRegistrationIATOptions{ + AccountPublicID: opts.AccountPublicID, + AccountVersion: opts.AccountVersion, + IssuerDomain: fmt.Sprintf("%s.%s", accountDTO.Username, opts.BackendDomain), + Domain: opts.Domain, + ClientID: utils.Base62UUID(), + }), + GetJWKfn: s.BuildGetEncryptedAccountJWKFn(ctx, BuildGetEncryptedAccountJWKFnOptions{ + RequestID: opts.RequestID, + KeyType: database.TokenKeyTypeDynamicRegistration, + AccountID: accountID, + }), + GetDecryptDEKfn: s.BuildGetDecAccountDEKFn(ctx, BuildGetDecAccountDEKFnOptions{ + RequestID: opts.RequestID, + AccountID: accountID, + }), + GetEncryptDEKfn: s.BuildGetEncAccountDEKfn(ctx, BuildGetEncAccountDEKOptions{ + RequestID: opts.RequestID, + AccountID: accountID, + }), + StoreFN: s.BuildUpdateJWKDEKFn(ctx, BuildUpdateJWKDEKFnOptions{ + RequestID: opts.RequestID, + }), + }) + if serviceErr != nil { + logger.ErrorContext(ctx, "Failed to sign app credentials registration IAT", "serviceError", serviceErr) + return "", serviceErr + } + + logger.InfoContext(ctx, "Created app credentials registration IAT successfully") + return signedToken, nil +} + +type ProcessAppCredentialsRegistrationIATAuthOptions struct { + RequestID string + AuthHeader string + AccountUsername string + AccountID int32 + BackendDomain string +} + +func (s *Services) ProcessAppCredentialsRegistrationIATAuth( + ctx context.Context, + opts ProcessAppCredentialsRegistrationIATAuthOptions, +) (string, tokens.AccountClaims, *exceptions.ServiceError) { + logger := s.buildLogger(opts.RequestID, appDynamicRegistrationIATLocation, "ProcessAppCredentialsRegistrationIATAuth") + logger.InfoContext(ctx, "Processing app credentials registration IAT auth...") + + token, serviceErr := extractAuthHeaderToken(opts.AuthHeader) + if serviceErr != nil { + logger.WarnContext(ctx, "Failed to extract token from auth header", "serviceError", serviceErr) + return "", tokens.AccountClaims{}, serviceErr + } + + domain, accountClaims, err := s.jwt.VerifyDynamicRegistrationIAT( + ctx, + tokens.VerifyDynamicRegistrationIATOptions{ + RequestID: opts.RequestID, + IAT: token, + IssuerDomain: fmt.Sprintf("%s.%s", opts.AccountUsername, opts.BackendDomain), + GetPublicJWK: s.buildVerifyAccountKeyFn(ctx, logger, buildVerifyAccountKeyFnOptions{ + requestID: opts.RequestID, + accountID: opts.AccountID, + keyType: database.TokenKeyTypeDynamicRegistration, + }), + }, + ) + if err != nil { + logger.WarnContext(ctx, "Failed to verify app credentials registration IAT", "error", err) + return "", tokens.AccountClaims{}, exceptions.NewUnauthorizedError() + } + + logger.InfoContext(ctx, "Processed app credentials registration IAT auth successfully") + return domain, accountClaims, nil +} diff --git a/idp/internal/services/apps.go b/idp/internal/services/apps.go index 7b24315..56ae76a 100644 --- a/idp/internal/services/apps.go +++ b/idp/internal/services/apps.go @@ -589,9 +589,10 @@ func mapAuthProviders(authProviders []string) ([]database.AuthProvider, *excepti } type checkForDuplicateAppsOptions struct { - requestID string - accountID int32 - name string + requestID string + accountID int32 + name string + softwareID string } func (s *Services) checkForDuplicateApps( @@ -604,10 +605,21 @@ func (s *Services) checkForDuplicateApps( ) logger.InfoContext(ctx, "Checking for duplicate apps...") - count, err := s.database.CountAppsByAccountIDAndName(ctx, database.CountAppsByAccountIDAndNameParams{ - AccountID: opts.accountID, - ClientName: opts.name, - }) + var count int64 + var err error + if opts.softwareID != "" { + count, err = s.database.CountAppsByAccountIDAndCliantNameOrSoftwareID(ctx, database.CountAppsByAccountIDAndCliantNameOrSoftwareIDParams{ + AccountID: opts.accountID, + ClientName: opts.name, + SoftwareID: pgtype.Text{String: opts.softwareID, Valid: true}, + }) + } else { + count, err = s.database.CountAppsByAccountIDAndName(ctx, database.CountAppsByAccountIDAndNameParams{ + AccountID: opts.accountID, + ClientName: opts.name, + }) + } + if err != nil { logger.ErrorContext(ctx, "Failed to count apps by name", "error", err) return exceptions.FromDBError(err) @@ -878,9 +890,10 @@ func (s *Services) CreateWebApp( name := strings.TrimSpace(opts.Name) if serviceErr := s.checkForDuplicateApps(ctx, checkForDuplicateAppsOptions{ - requestID: opts.RequestID, - accountID: accountID, - name: name, + requestID: opts.RequestID, + accountID: accountID, + name: name, + softwareID: opts.SoftwareID, }); serviceErr != nil { logger.ErrorContext(ctx, "Duplicate app found", "serviceError", serviceErr) return dtos.AppDTO{}, serviceErr @@ -1058,9 +1071,10 @@ func (s *Services) CreateSPANativeApp( name := strings.TrimSpace(opts.Name) if serviceErr := s.checkForDuplicateApps(ctx, checkForDuplicateAppsOptions{ - requestID: opts.RequestID, - accountID: accountID, - name: name, + requestID: opts.RequestID, + accountID: accountID, + name: name, + softwareID: opts.SoftwareID, }); serviceErr != nil { logger.ErrorContext(ctx, "Duplicate app found", "serviceError", serviceErr) return dtos.AppDTO{}, serviceErr @@ -1160,9 +1174,10 @@ func (s *Services) CreateBackendApp( name := strings.TrimSpace(opts.Name) if serviceErr := s.checkForDuplicateApps(ctx, checkForDuplicateAppsOptions{ - requestID: opts.RequestID, - accountID: accountID, - name: name, + requestID: opts.RequestID, + accountID: accountID, + name: name, + softwareID: opts.SoftwareID, }); serviceErr != nil { logger.ErrorContext(ctx, "Duplicate app found", "serviceError", serviceErr) return dtos.AppDTO{}, serviceErr @@ -1333,9 +1348,10 @@ func (s *Services) CreateDeviceApp( name := strings.TrimSpace(opts.Name) if serviceErr := s.checkForDuplicateApps(ctx, checkForDuplicateAppsOptions{ - requestID: opts.RequestID, - accountID: accountID, - name: name, + requestID: opts.RequestID, + accountID: accountID, + name: name, + softwareID: opts.SoftwareID, }); serviceErr != nil { logger.ErrorContext(ctx, "Duplicate app found", "serviceError", serviceErr) return dtos.AppDTO{}, serviceErr @@ -1553,9 +1569,10 @@ func (s *Services) CreateServiceApp( name := strings.TrimSpace(opts.Name) if serviceErr := s.checkForDuplicateApps(ctx, checkForDuplicateAppsOptions{ - requestID: opts.RequestID, - accountID: accountID, - name: name, + requestID: opts.RequestID, + accountID: accountID, + name: name, + softwareID: opts.SoftwareID, }); serviceErr != nil { logger.ErrorContext(ctx, "Duplicate app found", "serviceError", serviceErr) return dtos.AppDTO{}, serviceErr @@ -1810,9 +1827,10 @@ func (s *Services) CreateMCPApp( name := strings.TrimSpace(opts.Name) if serviceErr := s.checkForDuplicateApps(ctx, checkForDuplicateAppsOptions{ - requestID: opts.RequestID, - accountID: accountID, - name: name, + requestID: opts.RequestID, + accountID: accountID, + name: name, + softwareID: opts.SoftwareID, }); serviceErr != nil { logger.ErrorContext(ctx, "Duplicate app found", "serviceError", serviceErr) return dtos.AppDTO{}, serviceErr @@ -2228,9 +2246,10 @@ func (s *Services) UpdateWebSPANativeApp( name := strings.TrimSpace(opts.Name) if appDTO.ClientName != name { if serviceErr := s.checkForDuplicateApps(ctx, checkForDuplicateAppsOptions{ - requestID: opts.RequestID, - accountID: opts.AccountID, - name: name, + requestID: opts.RequestID, + accountID: opts.AccountID, + name: name, + softwareID: opts.SoftwareID, }); serviceErr != nil { logger.ErrorContext(ctx, "Duplicate app found", "serviceError", serviceErr) } @@ -2301,9 +2320,10 @@ func (s *Services) UpdateBackendApp( name := strings.TrimSpace(opts.Name) if appDTO.ClientName != name { if serviceErr := s.checkForDuplicateApps(ctx, checkForDuplicateAppsOptions{ - requestID: opts.RequestID, - accountID: opts.AccountID, - name: name, + requestID: opts.RequestID, + accountID: opts.AccountID, + name: name, + softwareID: opts.SoftwareID, }); serviceErr != nil { logger.ErrorContext(ctx, "Duplicate app found", "serviceError", serviceErr) } @@ -2383,9 +2403,10 @@ func (s *Services) UpdateDeviceApp( name := strings.TrimSpace(opts.Name) if appDTO.ClientName != name { if serviceErr := s.checkForDuplicateApps(ctx, checkForDuplicateAppsOptions{ - requestID: opts.RequestID, - accountID: opts.AccountID, - name: name, + requestID: opts.RequestID, + accountID: opts.AccountID, + name: name, + softwareID: opts.SoftwareID, }); serviceErr != nil { logger.ErrorContext(ctx, "Duplicate app found", "serviceError", serviceErr) } @@ -2551,9 +2572,10 @@ func (s *Services) UpdateServiceApp( name := strings.TrimSpace(opts.Name) if appDTO.ClientName != name { if serviceErr := s.checkForDuplicateApps(ctx, checkForDuplicateAppsOptions{ - requestID: opts.RequestID, - accountID: opts.AccountID, - name: name, + requestID: opts.RequestID, + accountID: opts.AccountID, + name: name, + softwareID: opts.SoftwareID, }); serviceErr != nil { logger.ErrorContext(ctx, "Duplicate app found", "serviceError", serviceErr) } @@ -2641,9 +2663,10 @@ func (s *Services) UpdateMCPApp( name := strings.TrimSpace(opts.Name) if appDTO.ClientName != name { if serviceErr := s.checkForDuplicateApps(ctx, checkForDuplicateAppsOptions{ - requestID: opts.RequestID, - accountID: opts.AccountID, - name: name, + requestID: opts.RequestID, + accountID: opts.AccountID, + name: name, + softwareID: opts.SoftwareID, }); serviceErr != nil { logger.ErrorContext(ctx, "Duplicate app found", "serviceError", serviceErr) } diff --git a/idp/internal/services/auth.go b/idp/internal/services/auth.go index 93d2efd..90e78fa 100644 --- a/idp/internal/services/auth.go +++ b/idp/internal/services/auth.go @@ -33,41 +33,6 @@ const ( resetMessage string = "Password reset successfully" ) -type processPurposeAuthHeaderOptions struct { - requestID string - authHeader string - tokenPurpose tokens.TokenPurpose - tokenKeyType database.TokenKeyType -} - -func (s *Services) processPurposeAuthHeader( - ctx context.Context, - opts processPurposeAuthHeaderOptions, -) (tokens.AccountClaims, *exceptions.ServiceError) { - logger := s.buildLogger(opts.requestID, authLocation, "processPurposeAuthHeader") - logger.InfoContext(ctx, "Processing purpose auth header...") - - token, serviceErr := extractAuthHeaderToken(opts.authHeader) - if serviceErr != nil { - return tokens.AccountClaims{}, serviceErr - } - - accountClaims, err := s.jwt.VerifyPurposeToken( - token, - opts.tokenPurpose, - s.BuildGetGlobalPublicKeyFn(ctx, BuildGetGlobalVerifyKeyFnOptions{ - RequestID: opts.requestID, - KeyType: opts.tokenKeyType, - }), - ) - if err != nil { - logger.ErrorContext(ctx, "Failed to verify purpose token", "error", err) - return tokens.AccountClaims{}, exceptions.NewUnauthorizedError() - } - - return accountClaims, nil -} - type ProcessAuthHeaderOptions struct { RequestID string AuthHeader string diff --git a/idp/internal/services/dtos/app_dynamic_registration_config.go b/idp/internal/services/dtos/app_dynamic_registration_config.go new file mode 100644 index 0000000..168fbc6 --- /dev/null +++ b/idp/internal/services/dtos/app_dynamic_registration_config.go @@ -0,0 +1,61 @@ +// Copyright (c) 2025 Afonso Barracha +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +package dtos + +import "github.com/tugascript/devlogs/idp/internal/providers/database" + +type AppDynamicRegistrationConfigDTO struct { + id int32 + + AllowedAppTypes []database.AppType `json:"allowed_app_types"` + WhitelistedDomains []string `json:"whitelisted_domains"` + DefaultAllowUserRegistration bool `json:"default_allow_user_registration"` + DefaultAuthProviders []database.AuthProvider `json:"default_auth_providers"` + DefaultUsernameColumn database.AppUsernameColumn `json:"default_username_column"` + DefaultAllowedScopes []database.Scopes `json:"default_allowed_scopes"` + DefaultScopes []database.Scopes `json:"default_scopes"` + RequireVerifiedDomainsAppTypes []database.AppType `json:"require_verified_domains_app_types"` + RequireSoftwareStatementAppTypes []database.AppType `json:"require_software_statement_app_types"` + SoftwareStatementVerificationMethods []database.SoftwareStatementVerificationMethod `json:"software_statement_verification_methods"` + RequireInitialAccessTokenAppTypes []database.AppType `json:"require_initial_access_token_app_types"` + InitialAccessTokenGenerationMethods []database.InitialAccessTokenGenerationMethod `json:"initial_access_token_generation_methods"` + InitialAccessTokenTtl int32 `json:"initial_access_token_ttl"` + InitialAccessTokenMaxUses int32 `json:"initial_access_token_max_uses"` + AllowedGrantTypes []database.GrantType `json:"allowed_grant_types"` + AllowedResponseTypes []database.ResponseType `json:"allowed_response_types"` + AllowedTokenEndpointAuthMethods []database.AuthMethod `json:"allowed_token_endpoint_auth_methods"` + MaxRedirectUris int32 `json:"max_redirect_uris"` +} + +func (a *AppDynamicRegistrationConfigDTO) ID() int32 { + return a.id +} + +func MapAppDynamicRegistrationConfigToDTO( + config *database.AppDynamicRegistrationConfig, +) AppDynamicRegistrationConfigDTO { + return AppDynamicRegistrationConfigDTO{ + id: config.ID, + AllowedAppTypes: config.AllowedAppTypes, + DefaultAllowUserRegistration: config.DefaultAllowUserRegistration, + DefaultAuthProviders: config.DefaultAuthProviders, + DefaultUsernameColumn: config.DefaultUsernameColumn, + DefaultAllowedScopes: config.DefaultAllowedScopes, + DefaultScopes: config.DefaultScopes, + RequireVerifiedDomainsAppTypes: config.RequireVerifiedDomainsAppTypes, + RequireSoftwareStatementAppTypes: config.RequireSoftwareStatementAppTypes, + SoftwareStatementVerificationMethods: config.SoftwareStatementVerificationMethods, + RequireInitialAccessTokenAppTypes: config.RequireInitialAccessTokenAppTypes, + InitialAccessTokenGenerationMethods: config.InitialAccessTokenGenerationMethods, + InitialAccessTokenTtl: config.InitialAccessTokenTtl, + InitialAccessTokenMaxUses: config.InitialAccessTokenMaxUses, + AllowedGrantTypes: config.AllowedGrantTypes, + AllowedResponseTypes: config.AllowedResponseTypes, + AllowedTokenEndpointAuthMethods: config.AllowedTokenEndpointAuthMethods, + MaxRedirectUris: config.MaxRedirectUris, + } +} diff --git a/idp/internal/services/dynamic_registration_domains.go b/idp/internal/services/dynamic_registration_domains.go index e6d9815..043f5d4 100644 --- a/idp/internal/services/dynamic_registration_domains.go +++ b/idp/internal/services/dynamic_registration_domains.go @@ -12,6 +12,7 @@ import ( "time" "github.com/google/uuid" + "golang.org/x/net/publicsuffix" "github.com/tugascript/devlogs/idp/internal/exceptions" "github.com/tugascript/devlogs/idp/internal/providers/crypto" @@ -218,6 +219,54 @@ func (s *Services) GetAccountCredentialsRegistrationDomain( return dtos.MapAccountCredentialsRegistrationDomainToDTO(&domainDTO), nil } +type GetAppCredentialsRegistrationDomainOptions struct { + RequestID string + AccountPublicID uuid.UUID + Domain string +} + +func (s *Services) GetAppCredentialsRegistrationDomain( + ctx context.Context, + opts GetAppCredentialsRegistrationDomainOptions, +) (dtos.DynamicRegistrationDomainDTO, *exceptions.ServiceError) { + logger := s.buildLogger(opts.RequestID, dynamicRegistrationDomainsLocation, "GetAppCredentialsRegistrationDomain").With( + "accountPublicID", opts.AccountPublicID, + "domain", opts.Domain, + ) + logger.InfoContext(ctx, "Getting app credentials registration domain...") + + domain, err := s.database.FindDynamicRegistrationDomainByAccountPublicIDAndDomain(ctx, database.FindDynamicRegistrationDomainByAccountPublicIDAndDomainParams{ + AccountPublicID: opts.AccountPublicID, + Domain: opts.Domain, + }) + if err != nil { + serviceErr := exceptions.FromDBError(err) + if serviceErr.Code != exceptions.CodeNotFound { + logger.ErrorContext(ctx, "Failed to find app dynamic registration domain", "error", err) + return dtos.DynamicRegistrationDomainDTO{}, serviceErr + } + + logger.WarnContext(ctx, "App dynamic registration domain not found", "domain", opts.Domain) + return dtos.DynamicRegistrationDomainDTO{}, serviceErr + } + + // Verify that the domain has App usage + hasAppUsage := false + for _, usage := range domain.Usages { + if usage == database.DynamicRegistrationUsageApp { + hasAppUsage = true + break + } + } + if !hasAppUsage { + logger.WarnContext(ctx, "Domain does not have app credentials registration usage", "domain", opts.Domain) + return dtos.DynamicRegistrationDomainDTO{}, exceptions.NewNotFoundValidationError("App dynamic registration domain not found") + } + + logger.InfoContext(ctx, "Found app dynamic registration domain", "domain", opts.Domain) + return dtos.MapAccountCredentialsRegistrationDomainToDTO(&domain), nil +} + type ListAccountCredentialsRegistrationDomainsOptions struct { RequestID string AccountPublicID uuid.UUID @@ -795,3 +844,90 @@ func (s *Services) DeleteAccountCredentialsRegistrationDomainCode( logger.InfoContext(ctx, "Deleted account credentials registration domain successfully") return nil } + +type checkClientRegistrationDomainOptions struct { + requestID string + accountPublicID uuid.UUID + usages []database.DynamicRegistrationUsage + domain string + iatDomain string + requireVerifiedDomains bool +} + +func (s *Services) checkClientRegistrationDomain( + ctx context.Context, + opts checkClientRegistrationDomainOptions, +) (string, *exceptions.ServiceError) { + logger := s.buildLogger(opts.requestID, dynamicRegistrationDomainsLocation, "checkClientRegistrationDomain").With( + "domain", opts.domain, + ) + logger.InfoContext(ctx, "Checking client registrationdomain validity") + + baseDomain, err := publicsuffix.EffectiveTLDPlusOne(opts.domain) + if err != nil { + logger.WarnContext(ctx, "Failed to parse base domain", "error", err) + return "", exceptions.NewValidationError("invalid client URI") + } + if opts.iatDomain != "" && opts.domain != opts.iatDomain && baseDomain != opts.iatDomain { + logger.WarnContext(ctx, "Client URI base domain does not match IAT domain", + "baseDomain", baseDomain, + "iatDomain", opts.iatDomain, + ) + return "", exceptions.NewUnauthorizedError() + } + + var count int64 + if baseDomain != opts.domain { + if opts.requireVerifiedDomains { + count, err = s.database.CountVerifiedDynamicRegistrationDomainsByDomainsAccountPublicIDAndUsages( + ctx, + database.CountVerifiedDynamicRegistrationDomainsByDomainsAccountPublicIDAndUsagesParams{ + AccountPublicID: opts.accountPublicID, + Usages: opts.usages, + Domains: []string{opts.domain, baseDomain}, + }, + ) + } else { + count, err = s.database.CountDynamicRegistrationDomainsByDomainsAccountPublicIDAndUsages( + ctx, + database.CountDynamicRegistrationDomainsByDomainsAccountPublicIDAndUsagesParams{ + AccountPublicID: opts.accountPublicID, + Usages: opts.usages, + Domains: []string{opts.domain, baseDomain}, + }, + ) + } + } else { + if opts.requireVerifiedDomains { + count, err = s.database.CountVerifiedDynamicRegistrationDomainsByDomainAccountPublicIDAndUsages( + ctx, + database.CountVerifiedDynamicRegistrationDomainsByDomainAccountPublicIDAndUsagesParams{ + AccountPublicID: opts.accountPublicID, + Domain: opts.domain, + Usages: opts.usages, + }, + ) + } else { + count, err = s.database.CountDynamicRegistrationDomainsByDomainAndAccountPublicIDAndUsages( + ctx, + database.CountDynamicRegistrationDomainsByDomainAndAccountPublicIDAndUsagesParams{ + AccountPublicID: opts.accountPublicID, + Domain: opts.domain, + Usages: appDynamicRegistrationUsages, + }, + ) + } + } + + if err != nil { + logger.ErrorContext(ctx, "Failed to count verified dynamic registration domains", "error", err) + return "", exceptions.FromDBError(err) + } + if count > 0 { + logger.InfoContext(ctx, "Credential domain is whitelisted") + return baseDomain, nil + } + + logger.InfoContext(ctx, "Domain is not whitelisted or verified") + return "", exceptions.NewUnauthorizedError() +} diff --git a/idp/internal/services/oauth.go b/idp/internal/services/oauth.go index fa25070..b3ffcf6 100644 --- a/idp/internal/services/oauth.go +++ b/idp/internal/services/oauth.go @@ -477,12 +477,12 @@ func (s *Services) OAuthLoginAccount( ) } -func (s *Services) GetAccountPublicJWKs( +func (s *Services) GetGlobalPublicJWKs( ctx context.Context, requestID string, ) (string, dtos.JWKsDTO, *exceptions.ServiceError) { - logger := s.buildLogger(requestID, oauthLocation, "GetAccountPublicKeys") - logger.InfoContext(ctx, "Getting account public JWKs...") + logger := s.buildLogger(requestID, oauthLocation, "GetGlobalPublicJWKs") + logger.InfoContext(ctx, "Getting global public JWKs...") etag, jwks, serviceErr := s.GetAndCacheGlobalDistributedJWK(ctx, requestID) if serviceErr != nil { @@ -490,7 +490,7 @@ func (s *Services) GetAccountPublicJWKs( return "", dtos.JWKsDTO{}, serviceErr } - logger.InfoContext(ctx, "Got account public JWKs successfully") + logger.InfoContext(ctx, "Got global public JWKs successfully") return etag, dtos.NewJWKsDTO(jwks), nil } diff --git a/idp/internal/services/oauth_dynamic_registration.go b/idp/internal/services/oauth_dynamic_registration.go index 6bc428c..323a123 100644 --- a/idp/internal/services/oauth_dynamic_registration.go +++ b/idp/internal/services/oauth_dynamic_registration.go @@ -1047,12 +1047,11 @@ func (s *Services) OAuthDynamicRegistrationIATVerify2FACode( return "", "", serviceErr } - domainDTO, serviceErr := s.GetAccountCredentialsRegistrationDomain(ctx, GetAccountCredentialsRegistrationDomainOptions{ + if _, serviceErr := s.GetAccountCredentialsRegistrationDomain(ctx, GetAccountCredentialsRegistrationDomainOptions{ RequestID: opts.RequestID, AccountPublicID: accountDTO.PublicID, Domain: data.Domain, - }) - if serviceErr != nil { + }); serviceErr != nil { if serviceErr.Code != exceptions.CodeNotFound { logger.ErrorContext(ctx, "Failed to get account credentials registration domain", "serviceError", serviceErr) return "", "", serviceErr @@ -1072,10 +1071,6 @@ func (s *Services) OAuthDynamicRegistrationIATVerify2FACode( logger.WarnContext(ctx, "Account credentials registration domain not found") return "", "", exceptions.NewForbiddenError() } - if !domainDTO.Verified { - logger.ErrorContext(ctx, "Account credentials registration domain is not verified") - return "", "", exceptions.NewForbiddenError() - } sessionKey, err := s.cache.CreateAccountCredentialsRegistrationSessionKey( ctx, @@ -1185,7 +1180,7 @@ func (s *Services) VerifyOAuthDynamicRegistrationIATCode( tokenTTL := s.jwt.GetDynamicRegistrationTTL() signedToken, serviceErr := s.crypto.SignToken(ctx, crypto.SignTokenOptions{ RequestID: opts.RequestID, - Token: s.jwt.CreateAccountCredentialsDynamicRegistrationToken(tokens.AccountCredentialsDynamicRegistrationTokenOptions{ + Token: s.jwt.DynamicRegistrationIAT(tokens.DynamicRegistrationIATOptions{ AccountPublicID: accountDTO.PublicID, AccountVersion: accountDTO.Version(), Domain: data.Domain, diff --git a/idp/internal/services/software_statement.go b/idp/internal/services/software_statement.go index 29538f1..a38505f 100644 --- a/idp/internal/services/software_statement.go +++ b/idp/internal/services/software_statement.go @@ -8,18 +8,24 @@ package services import ( "context" + "errors" "fmt" + "net/url" "slices" "strings" "time" "github.com/golang-jwt/jwt/v5" + "github.com/google/uuid" "github.com/tugascript/devlogs/idp/internal/exceptions" + "github.com/tugascript/devlogs/idp/internal/providers/database" "github.com/tugascript/devlogs/idp/internal/providers/tokens" "github.com/tugascript/devlogs/idp/internal/utils" ) +const softwareStatementLocation = "software_statement" + type ApplicationRegistrationData struct { RedirectURIs []string TokenEndpointAuthMethod string @@ -70,7 +76,7 @@ func (s *Services) verifySoftwareStatementSTDClaims( ctx context.Context, opts verifySoftwareStatementSTDClaimsOptions, ) *exceptions.ServiceError { - logger := s.buildLogger(opts.requestID, accountCredentialsRegistrationLocation, "verifySoftwareStatementSTDClaims").With( + logger := s.buildLogger(opts.requestID, softwareStatementLocation, "verifySoftwareStatementSTDClaims").With( "domain", opts.domain, "baseDomain", opts.baseDomain, ) @@ -114,95 +120,6 @@ func (s *Services) verifySoftwareStatementSTDClaims( return nil } -type verifySoftwareStatementClaimsOptions struct { - requestID string - clientName string - clientURI string - logoURI string - tosURI string - policyURI string - contacts []string - softwareID string - softwareVersion string - jwksURI string - jwks []string - claims *tokens.SoftwareStatementClaims -} - -func (s *Services) verifySoftwareStatementClaims( - ctx context.Context, - opts verifySoftwareStatementClaimsOptions, -) error { - logger := s.buildLogger( - opts.requestID, - accountCredentialsRegistrationLocation, - "verifySoftwareStatementClaims", - ) - - if opts.claims.ClientName != opts.clientName { - logger.WarnContext(ctx, "Client name in software statement does not match", - "expected", opts.clientName, "got", opts.claims.ClientName, - ) - return exceptions.NewUnauthorizedError() - } - if opts.claims.ClientURI != opts.clientURI { - logger.WarnContext(ctx, "Client URI in software statement does not match", - "expected", opts.clientURI, "got", opts.claims.ClientURI, - ) - return exceptions.NewUnauthorizedError() - } - - if opts.claims.LogoURI != "" && opts.logoURI != "" && opts.claims.LogoURI != opts.logoURI { - logger.WarnContext(ctx, "Logo URI in software statement does not match", - "expected", opts.logoURI, "got", opts.claims.LogoURI, - ) - return exceptions.NewUnauthorizedError() - } - if opts.claims.TOSURI != "" && opts.tosURI != "" && opts.claims.TOSURI != opts.tosURI { - logger.WarnContext(ctx, "Terms of Service URI in software statement does not match", - "expected", opts.tosURI, "got", opts.claims.TOSURI, - ) - return exceptions.NewUnauthorizedError() - } - if opts.claims.PolicyURI != "" && opts.policyURI != "" && opts.claims.PolicyURI != opts.policyURI { - logger.WarnContext(ctx, "Policy URI in software statement does not match", - "expected", opts.policyURI, "got", opts.claims.PolicyURI, - ) - return exceptions.NewUnauthorizedError() - } - if opts.claims.SoftwareID != "" && opts.softwareID != "" && opts.claims.SoftwareID != opts.softwareID { - logger.WarnContext(ctx, "Software id in software statement does not match", - "expected", opts.softwareID, "got", opts.claims.SoftwareID, - ) - return exceptions.NewUnauthorizedError() - } - if opts.claims.SoftwareVersion != "" && opts.softwareVersion != "" && opts.claims.SoftwareVersion != opts.softwareVersion { - logger.WarnContext(ctx, "Software version in software statement does not match", - "expected", opts.softwareVersion, "got", opts.claims.SoftwareVersion, - ) - return exceptions.NewUnauthorizedError() - } - if opts.claims.JWKsURI != "" && opts.jwksURI != "" && opts.claims.JWKsURI != opts.jwksURI { - logger.WarnContext(ctx, "JWKs URI in software statement does not match", "expected", - opts.jwksURI, "got", opts.claims.JWKsURI, - ) - return exceptions.NewUnauthorizedError() - } - - if len(opts.claims.Contacts) > 0 && len(opts.contacts) > 0 { - claimsSet := utils.SliceToHashSet(opts.claims.Contacts) - for _, c := range opts.contacts { - if !claimsSet.Contains(c) { - logger.WarnContext(ctx, "Contact in registration not present in software statement", "contact", c) - return exceptions.NewUnauthorizedError() - } - } - } - - logger.InfoContext(ctx, "Verified software statement registration claims") - return nil -} - // validateEncryptionAlgorithmPair validates that encryption algorithm and encoding are both set or both unset func validateEncryptionAlgorithmPair(alg, enc string) bool { if enc != "" && alg == "" { @@ -225,7 +142,7 @@ func (s *Services) validateSoftwareStatementClaims( ) *exceptions.ServiceError { logger := s.buildLogger( opts.requestID, - accountCredentialsRegistrationLocation, + softwareStatementLocation, "validateSoftwareStatementClaims", ) logger.InfoContext(ctx, "Validating software statement claims") @@ -431,7 +348,7 @@ func (s *Services) validateSoftwareStatementClaims( logger.WarnContext(ctx, "Default max age mismatch", "expected", opts.data.DefaultMaxAge, "got", opts.claims.DefaultMaxAge) return exceptions.NewValidationError("default max age mismatch") } - if opts.claims.RequireAuthTime != false && opts.data.RequireAuthTime != false && opts.claims.RequireAuthTime != opts.data.RequireAuthTime { + if !opts.claims.RequireAuthTime && opts.claims.RequireAuthTime != opts.data.RequireAuthTime { logger.WarnContext(ctx, "Require auth time mismatch", "expected", opts.data.RequireAuthTime, "got", opts.claims.RequireAuthTime) return exceptions.NewValidationError("require auth time mismatch") } @@ -522,3 +439,166 @@ func (s *Services) validateSoftwareStatementClaims( logger.InfoContext(ctx, "Validated software statement claims") return nil } + +type buildDynamicRegistrationSoftwareStatementFuncOptions struct { + requestID string + accountPublicID uuid.UUID + verificationMethods []database.SoftwareStatementVerificationMethod + jwksURI string + jwks []string + domain string + baseDomain string +} + +func (s *Services) buildDynamicRegistrationSoftwareStatementFunc( + ctx context.Context, + opts buildDynamicRegistrationSoftwareStatementFuncOptions, +) tokens.GetUnknownPublicJWK { + logger := s.buildLogger(opts.requestID, softwareStatementLocation, "buildDynamicRegistrationSoftwareStatementFunc").With( + "accountPublicID", opts.accountPublicID, + ) + logger.InfoContext(ctx, "Checking dynamic registration software statement validity") + + if slices.Contains(opts.verificationMethods, database.SoftwareStatementVerificationMethodJwksUri) && opts.jwksURI != "" { + return func(kid string) (utils.JWK, error) { + parsedURI, err := url.Parse(opts.jwksURI) + if err != nil { + logger.ErrorContext(ctx, "Failed to parse JWKs URI", "error", err) + return nil, errors.New("invalid JWKs URI") + } + if parsedURI.Host != opts.baseDomain || !strings.Contains(parsedURI.Host, "."+opts.baseDomain) { + logger.WarnContext(ctx, "JWKs URI parsedURI does not match client URI parsedURI") + return nil, errors.New("JWKs URI parsedURI does not match client URI parsedURI") + } + + jwks, err := s.jwt.GetPublicJWKs(ctx, tokens.GetPublicJWKsOptions{ + RequestID: opts.requestID, + URL: opts.jwksURI, + }) + if err != nil { + logger.WarnContext(ctx, "Failed to get public JWKs from JWKs URI", "error", err) + return nil, errors.New("failed to get public JWKs from JWKs URI") + } + + jwkIdx := slices.IndexFunc(jwks.Keys, func(jwk utils.JWK) bool { + return jwk.GetKeyID() == kid + }) + if jwkIdx == -1 { + logger.WarnContext(ctx, "No matching JWK found for KID in JWKs URI", "kid", kid) + return nil, errors.New("no matching JWK found for KID in JWKs URI") + } + + return jwks.Keys[jwkIdx], nil + } + } + if slices.Contains(opts.verificationMethods, database.SoftwareStatementVerificationMethodManual) { + if len(opts.jwks) > 0 { + return func(kid string) (utils.JWK, error) { + jwks := make([]utils.JWK, 0, len(opts.jwks)) + for _, rawJWK := range opts.jwks { + jwk, err := utils.JsonToJWK([]byte(rawJWK)) + if err != nil { + logger.ErrorContext(ctx, "Failed to parse manual JWK", "error", err) + return nil, errors.New("failed to parse manual JWK") + } + jwks = append(jwks, jwk) + } + + jwkIdx := slices.IndexFunc(jwks, func(jwk utils.JWK) bool { + return jwk.GetKeyID() == kid + }) + if jwkIdx == -1 { + logger.WarnContext(ctx, "No matching manual JWK found for KID", "kid", kid) + return nil, errors.New("no matching manual JWK found for KID") + } + + sliceJWK := jwks[jwkIdx] + jwkRefEnt, err := s.database.FindDynamicRegistrationSoftwareStatementKeysByCredentialsKeyKIDAndAccountPublicID( + ctx, + database.FindDynamicRegistrationSoftwareStatementKeysByCredentialsKeyKIDAndAccountPublicIDParams{ + CredentialsKeyKid: kid, + AccountPublicID: opts.accountPublicID, + }, + ) + if err != nil { + serviceErr := exceptions.FromDBError(err) + if serviceErr.Code == exceptions.CodeNotFound { + logger.WarnContext(ctx, "No database entry found for manual JWK", "kid", kid, "error", err) + return nil, errors.New("no database entry found for manual JWK") + } + + logger.ErrorContext(ctx, "Failed to find database entry for manual JWK", "kid", kid, "error", err) + return nil, errors.New("failed to find database entry for manual JWK") + } + if jwkRefEnt.RootDomain != opts.baseDomain { + logger.WarnContext(ctx, "Manual JWK root domain does not match client URI base domain", + "kid", kid, "jwkRootDomain", jwkRefEnt.RootDomain, "baseDomain", opts.baseDomain, + ) + return nil, errors.New("manual JWK root domain does not match client URI base domain") + } + + jwkEnt, err := s.database.FindCredentialsKeyByID(ctx, jwkRefEnt.CredentialsKeyID) + if err != nil { + serviceErr := exceptions.FromDBError(err) + if serviceErr.Code == exceptions.CodeNotFound { + logger.WarnContext(ctx, "No credentials key found for manual JWK", "kid", kid, "error", err) + return nil, errors.New("no credentials key found for manual JWK") + } + + logger.ErrorContext(ctx, "Failed to find credentials key for manual JWK", "kid", kid, "error", err) + return nil, errors.New("failed to find credentials key for manual JWK") + } + + entJWK, err := utils.JsonToJWK(jwkEnt.PublicKey) + if err != nil { + logger.ErrorContext(ctx, "Failed to parse manual JWK", "error", err) + return nil, errors.New("failed to parse manual JWK") + } + if !entJWK.ComparePublicKey(sliceJWK) { + logger.WarnContext(ctx, "Manual JWK does not match database credentials key", "kid", kid) + return nil, errors.New("manual JWK does not match database credentials key") + } + + return sliceJWK, nil + } + } + + return func(kid string) (utils.JWK, error) { + jwkEntity, err := s.database.FindDynamicRegistrationSoftwareStatementKeysByRootDomainAndAccountPublicID( + ctx, + database.FindDynamicRegistrationSoftwareStatementKeysByRootDomainAndAccountPublicIDParams{ + RootDomain: opts.baseDomain, + AccountPublicID: opts.accountPublicID, + }, + ) + if err != nil { + if exceptions.FromDBError(err).Code == exceptions.CodeNotFound { + logger.WarnContext(ctx, "No manual JWKs found for software statement", "error", err) + return nil, errors.New("no manual JWKs found for software statement") + } + + logger.ErrorContext(ctx, "Failed to find manual JWKs for software statement", "error", err) + return nil, errors.New("failed to find manual JWKs for software statement") + } + if jwkEntity.PublicKid != kid { + logger.WarnContext(ctx, "No matching manual JWK found for KID", + "kid", kid, "publicKID", jwkEntity.PublicKid, + ) + return nil, errors.New("no matching manual JWK found for KID") + } + + jwk, err := utils.JsonToJWK(jwkEntity.PublicKey) + if err != nil { + logger.ErrorContext(ctx, "Failed to parse manual JWK for software statement", "error", err) + return nil, errors.New("failed to parse manual JWK for software statement") + } + + return jwk, nil + } + } + + return func(kid string) (utils.JWK, error) { + logger.WarnContext(ctx, "No verification method available for software statement") + return nil, errors.New("no verification method available") + } +} diff --git a/idp/internal/services/users_oauth.go b/idp/internal/services/users_oauth.go new file mode 100644 index 0000000..1df7cfd --- /dev/null +++ b/idp/internal/services/users_oauth.go @@ -0,0 +1,7 @@ +// Copyright (c) 2025 Afonso Barracha +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +package services From a0da1ae36d02a7ac2d9abf907290d4bbcde39b78 Mon Sep 17 00:00:00 2001 From: Afonso Barracha Date: Sun, 21 Dec 2025 17:42:22 +1300 Subject: [PATCH 23/23] chore: push current changes --- idp/internal/controllers/middleware.go | 31 ++++-- ...ccount_credentials_dynamic_registration.go | 3 + .../dynamic_registration_domains.sql.go | 46 +++++++- .../queries/dynamic_registration_domains.sql | 20 +++- idp/internal/server/routes/oauth.go | 23 ++-- idp/internal/server/routes/well_known.go | 2 +- .../account_credentials_registration_iat.go | 8 +- .../services/app_dynamic_registration.go | 2 - .../services/app_dynamic_registration_iat.go | 43 -------- .../services/dynamic_registration_domains.go | 13 +++ .../services/oauth_dynamic_registration.go | 102 ++++++++++++------ keygen/main.go | 2 +- project.md | 2 + 13 files changed, 190 insertions(+), 107 deletions(-) diff --git a/idp/internal/controllers/middleware.go b/idp/internal/controllers/middleware.go index 906e942..62c0505 100644 --- a/idp/internal/controllers/middleware.go +++ b/idp/internal/controllers/middleware.go @@ -8,6 +8,7 @@ package controllers import ( "errors" + "fmt" "strings" "github.com/gofiber/fiber/v2" @@ -164,9 +165,21 @@ func (c *Controllers) AppAccessClaimsMiddleware(ctx *fiber.Ctx) error { return ctx.Next() } -func (c *Controllers) AccountCredentialsDRIATMiddleware(ctx *fiber.Ctx) error { +func processIATIssuerDomain(ctx *fiber.Ctx, backendDomain string) (string, *exceptions.ServiceError) { + hasAccountHost, ok := ctx.Locals("hasAccountHost").(bool) + if ok && hasAccountHost { + username, _, serviceErr := getHostAccount(ctx) + if serviceErr != nil { + return "", serviceErr + } + return fmt.Sprintf("%s.%s", username, backendDomain), nil + } + return backendDomain, nil +} + +func (c *Controllers) DynamicRegistrationIATMiddleware(ctx *fiber.Ctx) error { requestID := getRequestID(ctx) - logger := c.buildLogger(requestID, middlewareLocation, "AccountCredentialsDRIATMiddleware") + logger := c.buildLogger(requestID, middlewareLocation, "DynamicRegistrationIATMiddleware") authHeader := ctx.Get("Authorization") if authHeader == "" { @@ -175,17 +188,21 @@ func (c *Controllers) AccountCredentialsDRIATMiddleware(ctx *fiber.Ctx) error { return ctx.Next() } + issDomain, serviceErr := processIATIssuerDomain(ctx, c.backendDomain) + if serviceErr != nil { + return serviceErrorResponse(logger, ctx, serviceErr) + } + domain, accountClaims, serviceErr := c.services.ProcessAccountCredentialsRegistrationIATAuth( ctx.UserContext(), services.ProcessAccountCredentialsRegistrationIATAuthOptions{ - RequestID: requestID, - AuthHeader: authHeader, - BackendDomain: c.backendDomain, + RequestID: requestID, + AuthHeader: authHeader, + IssuerDomain: issDomain, }, ) if serviceErr != nil { - ctx.Set(fiber.HeaderWWWAuthenticate, "Bearer realm=\"accounts\", error=\"invalid_token\"") - return oauthErrorResponse(logger, ctx, exceptions.OAuthErrorInvalidToken) + return oauthErrorResponse(logger, ctx, exceptions.OAuthErrorAccessDenied) } ctx.Locals("account", accountClaims) diff --git a/idp/internal/providers/cache/account_credentials_dynamic_registration.go b/idp/internal/providers/cache/account_credentials_dynamic_registration.go index f776a48..4b8c0b4 100644 --- a/idp/internal/providers/cache/account_credentials_dynamic_registration.go +++ b/idp/internal/providers/cache/account_credentials_dynamic_registration.go @@ -36,6 +36,7 @@ type AccountCredentialsDynamicRegistrationIATAuthData struct { Domain string `json:"domain"` State string `json:"state"` Challenge string `json:"challenge"` + Username string `json:"username,omitempty"` } type SaveAccountCredentialsDynamicRegistrationIATAuthOptions struct { @@ -44,6 +45,7 @@ type SaveAccountCredentialsDynamicRegistrationIATAuthOptions struct { State string RedirectURI string Challenge string + Username string } func (c *Cache) SaveAccountCredentialsDynamicRegistrationIATAuth( @@ -64,6 +66,7 @@ func (c *Cache) SaveAccountCredentialsDynamicRegistrationIATAuth( Domain: opts.Domain, RedirectURI: opts.RedirectURI, Challenge: opts.Challenge, + Username: opts.Username, } dataBytes, err := json.Marshal(data) if err != nil { diff --git a/idp/internal/providers/database/dynamic_registration_domains.sql.go b/idp/internal/providers/database/dynamic_registration_domains.sql.go index cc0d50a..fb0ae56 100644 --- a/idp/internal/providers/database/dynamic_registration_domains.sql.go +++ b/idp/internal/providers/database/dynamic_registration_domains.sql.go @@ -41,7 +41,7 @@ SELECT COUNT(*) FROM "dynamic_registration_domains" WHERE "account_public_id" = $1 AND "domain" = $2 AND - "usages" @> $3 + "usages" = ANY($3) LIMIT 1 ` @@ -58,6 +58,26 @@ func (q *Queries) CountDynamicRegistrationDomainsByDomainAndAccountPublicIDAndUs return count, err } +const countDynamicRegistrationDomainsByDomainAndUsages = `-- name: CountDynamicRegistrationDomainsByDomainAndUsages :one +SELECT COUNT(*) FROM "dynamic_registration_domains" +WHERE + "domain" = $1 AND + "usages" = ANY($2) +LIMIT 1 +` + +type CountDynamicRegistrationDomainsByDomainAndUsagesParams struct { + Domain string + Usages []DynamicRegistrationUsage +} + +func (q *Queries) CountDynamicRegistrationDomainsByDomainAndUsages(ctx context.Context, arg CountDynamicRegistrationDomainsByDomainAndUsagesParams) (int64, error) { + row := q.db.QueryRow(ctx, countDynamicRegistrationDomainsByDomainAndUsages, arg.Domain, arg.Usages) + var count int64 + err := row.Scan(&count) + return count, err +} + const countDynamicRegistrationDomainsByDomains = `-- name: CountDynamicRegistrationDomainsByDomains :one SELECT COUNT(*) FROM "dynamic_registration_domains" WHERE "domain" IN ($1) @@ -75,7 +95,7 @@ const countDynamicRegistrationDomainsByDomainsAccountPublicIDAndUsages = `-- nam SELECT COUNT(*) FROM "dynamic_registration_domains" WHERE "account_public_id" = $1 AND - "usages" @> $2 AND + "usages" = ANY($2) AND "domain" IN ($3) LIMIT 1 ` @@ -93,6 +113,26 @@ func (q *Queries) CountDynamicRegistrationDomainsByDomainsAccountPublicIDAndUsag return count, err } +const countDynamicRegistrationDomainsByDomainsAndUsages = `-- name: CountDynamicRegistrationDomainsByDomainsAndUsages :one +SELECT COUNT(*) FROM "dynamic_registration_domains" +WHERE + "domain" IN ($2) AND + "usages" = ANY($1) +LIMIT 1 +` + +type CountDynamicRegistrationDomainsByDomainsAndUsagesParams struct { + Usages []DynamicRegistrationUsage + Domains []string +} + +func (q *Queries) CountDynamicRegistrationDomainsByDomainsAndUsages(ctx context.Context, arg CountDynamicRegistrationDomainsByDomainsAndUsagesParams) (int64, error) { + row := q.db.QueryRow(ctx, countDynamicRegistrationDomainsByDomainsAndUsages, arg.Usages, arg.Domains) + var count int64 + err := row.Scan(&count) + return count, err +} + const countFilteredDynamicRegistrationDomainsByAccountPublicID = `-- name: CountFilteredDynamicRegistrationDomainsByAccountPublicID :one SELECT COUNT(*) FROM "dynamic_registration_domains" WHERE @@ -131,7 +171,7 @@ SELECT COUNT(*) FROM "dynamic_registration_domains" WHERE "account_public_id" = $1 AND "domain" = $2 AND - "usages" @> $3 AND + "usages" = ANY($3) AND "verified_at" IS NOT NULL LIMIT 1 ` diff --git a/idp/internal/providers/database/queries/dynamic_registration_domains.sql b/idp/internal/providers/database/queries/dynamic_registration_domains.sql index 64a8c51..e8933aa 100644 --- a/idp/internal/providers/database/queries/dynamic_registration_domains.sql +++ b/idp/internal/providers/database/queries/dynamic_registration_domains.sql @@ -118,7 +118,7 @@ SELECT COUNT(*) FROM "dynamic_registration_domains" WHERE "account_public_id" = $1 AND "domain" = $2 AND - "usages" @> $3 AND + "usages" = ANY($3) AND "verified_at" IS NOT NULL LIMIT 1; @@ -126,7 +126,7 @@ LIMIT 1; SELECT COUNT(*) FROM "dynamic_registration_domains" WHERE "account_public_id" = $1 AND - "usages" @> $2 AND + "usages" = ANY($2) AND "domain" IN (sqlc.slice('domains')) LIMIT 1; @@ -135,7 +135,21 @@ SELECT COUNT(*) FROM "dynamic_registration_domains" WHERE "account_public_id" = $1 AND "domain" = $2 AND - "usages" @> $3 + "usages" = ANY($3) +LIMIT 1; + +-- name: CountDynamicRegistrationDomainsByDomainAndUsages :one +SELECT COUNT(*) FROM "dynamic_registration_domains" +WHERE + "domain" = $1 AND + "usages" = ANY($2) +LIMIT 1; + +-- name: CountDynamicRegistrationDomainsByDomainsAndUsages :one +SELECT COUNT(*) FROM "dynamic_registration_domains" +WHERE + "domain" IN (sqlc.slice('domains')) AND + "usages" = ANY($1) LIMIT 1; -- name: DeleteDynamicRegistrationDomain :exec diff --git a/idp/internal/server/routes/oauth.go b/idp/internal/server/routes/oauth.go index 403388f..1f6afe4 100644 --- a/idp/internal/server/routes/oauth.go +++ b/idp/internal/server/routes/oauth.go @@ -28,19 +28,18 @@ func (r *Routes) OAuthRoutes(app *fiber.App) { router.Get(paths.OAuthCallback, r.controllers.AccountOAuthCallback) // Register - router.Post(paths.OAuthRegister, r.controllers.HostMiddleware, HostAwareRoute( - []fiber.Handler{ - r.controllers.AccountCredentialsDRIATMiddleware, - r.controllers.OAuthDynamicRegistration, - }, - []fiber.Handler{ - // TODO: add app claims for DR - r.controllers.OAuthDynamicRegistration, - }, - )) + router.Post( + paths.OAuthRegister, + r.controllers.HostMiddleware, + r.controllers.DynamicRegistrationIATMiddleware, + HostAwareRoute( + []fiber.Handler{r.controllers.OAuthDynamicRegistration}, + []fiber.Handler{r.controllers.OAuthAppDynamicRegistration}, + ), + ) // Initial Access Token (IAT) routes - iatRouter := router.Group(paths.InitialAccessToken) + iatRouter := router.Group(paths.InitialAccessToken, r.controllers.HostMiddleware) // Dynamic Registration IAT Code Exchange flow iatRouter.Get(paths.OAuthAuth, r.controllers.OAuthDynamicRegistrationIATAuth) @@ -57,7 +56,7 @@ func (r *Routes) OAuthRoutes(app *fiber.App) { iatRouter.Post(twoFAAuthRoute, r.controllers.OAuthDynamicRegistrationIAT2FAPost) // Dynamic Registration IAT External Auth flow - const extAuthRoute = paths.InitialAccessTokenSingle + paths.OAuthAuth + paths.InitialAccessTokenAuthEXT + const extAuthRoute = paths.InitialAccessTokenSingle + paths.InitialAccessTokenAuthEXT iatRouter.Get(extAuthRoute+paths.InitialAccessTokenProvider, r.controllers.OAuthDynamicRegistrationIATExtAuthGet) iatRouter.Post(extAuthRoute+paths.OAuthAppleCallback, r.controllers.OAuthDynamicRegistrationIATExtAppleCB) iatRouter.Get(extAuthRoute+paths.OAuthCallback, r.controllers.OAuthDynamicRegistrationIATExtCB) diff --git a/idp/internal/server/routes/well_known.go b/idp/internal/server/routes/well_known.go index d92b753..57c222b 100644 --- a/idp/internal/server/routes/well_known.go +++ b/idp/internal/server/routes/well_known.go @@ -13,7 +13,7 @@ import ( ) func (r *Routes) WellKnownRoutes(app *fiber.App) { - router := V1PathRouter(app).Group(paths.WellKnownBase, r.controllers.HostMiddleware) + router := app.Group(paths.WellKnownBase, r.controllers.HostMiddleware) router.Get(paths.WellKnownJWKs, HostAwareRoute( []fiber.Handler{r.controllers.GlobalOAuthPublicJWKs}, diff --git a/idp/internal/services/account_credentials_registration_iat.go b/idp/internal/services/account_credentials_registration_iat.go index e42bff7..62f2a2c 100644 --- a/idp/internal/services/account_credentials_registration_iat.go +++ b/idp/internal/services/account_credentials_registration_iat.go @@ -90,9 +90,9 @@ func (s *Services) CreateAccountCredentialsRegistrationIAT( } type ProcessAccountCredentialsRegistrationIATAuthOptions struct { - RequestID string - AuthHeader string - BackendDomain string + RequestID string + AuthHeader string + IssuerDomain string } func (s *Services) ProcessAccountCredentialsRegistrationIATAuth( @@ -113,7 +113,7 @@ func (s *Services) ProcessAccountCredentialsRegistrationIATAuth( tokens.VerifyDynamicRegistrationIATOptions{ RequestID: opts.RequestID, IAT: token, - IssuerDomain: opts.BackendDomain, + IssuerDomain: opts.IssuerDomain, GetPublicJWK: s.BuildGetGlobalPublicKeyFn(ctx, BuildGetGlobalVerifyKeyFnOptions{ RequestID: opts.RequestID, KeyType: database.TokenKeyTypeDynamicRegistration, diff --git a/idp/internal/services/app_dynamic_registration.go b/idp/internal/services/app_dynamic_registration.go index 5213ca2..acb2762 100644 --- a/idp/internal/services/app_dynamic_registration.go +++ b/idp/internal/services/app_dynamic_registration.go @@ -79,7 +79,6 @@ func mapAppGrantTypes( func mapAppTokenEndpointAuthMethod( authMethod string, appType database.AppType, - transport database.Transport, ) (database.AuthMethod, *exceptions.ServiceError) { if authMethod == "" { switch appType { @@ -319,7 +318,6 @@ func (s *Services) CreateAppCredentialsRegistration( tokenEndpointAuthMethod, serviceErr := mapAppTokenEndpointAuthMethod( opts.TokenEndpointAuthMethod, appType, - transport, ) if serviceErr != nil { logger.ErrorContext(ctx, "Failed to map token endpoint auth method", "serviceError", serviceErr) diff --git a/idp/internal/services/app_dynamic_registration_iat.go b/idp/internal/services/app_dynamic_registration_iat.go index c451c58..770a9db 100644 --- a/idp/internal/services/app_dynamic_registration_iat.go +++ b/idp/internal/services/app_dynamic_registration_iat.go @@ -93,46 +93,3 @@ func (s *Services) CreateAppCredentialsRegistrationIAT( logger.InfoContext(ctx, "Created app credentials registration IAT successfully") return signedToken, nil } - -type ProcessAppCredentialsRegistrationIATAuthOptions struct { - RequestID string - AuthHeader string - AccountUsername string - AccountID int32 - BackendDomain string -} - -func (s *Services) ProcessAppCredentialsRegistrationIATAuth( - ctx context.Context, - opts ProcessAppCredentialsRegistrationIATAuthOptions, -) (string, tokens.AccountClaims, *exceptions.ServiceError) { - logger := s.buildLogger(opts.RequestID, appDynamicRegistrationIATLocation, "ProcessAppCredentialsRegistrationIATAuth") - logger.InfoContext(ctx, "Processing app credentials registration IAT auth...") - - token, serviceErr := extractAuthHeaderToken(opts.AuthHeader) - if serviceErr != nil { - logger.WarnContext(ctx, "Failed to extract token from auth header", "serviceError", serviceErr) - return "", tokens.AccountClaims{}, serviceErr - } - - domain, accountClaims, err := s.jwt.VerifyDynamicRegistrationIAT( - ctx, - tokens.VerifyDynamicRegistrationIATOptions{ - RequestID: opts.RequestID, - IAT: token, - IssuerDomain: fmt.Sprintf("%s.%s", opts.AccountUsername, opts.BackendDomain), - GetPublicJWK: s.buildVerifyAccountKeyFn(ctx, logger, buildVerifyAccountKeyFnOptions{ - requestID: opts.RequestID, - accountID: opts.AccountID, - keyType: database.TokenKeyTypeDynamicRegistration, - }), - }, - ) - if err != nil { - logger.WarnContext(ctx, "Failed to verify app credentials registration IAT", "error", err) - return "", tokens.AccountClaims{}, exceptions.NewUnauthorizedError() - } - - logger.InfoContext(ctx, "Processed app credentials registration IAT auth successfully") - return domain, accountClaims, nil -} diff --git a/idp/internal/services/dynamic_registration_domains.go b/idp/internal/services/dynamic_registration_domains.go index 043f5d4..94467e6 100644 --- a/idp/internal/services/dynamic_registration_domains.go +++ b/idp/internal/services/dynamic_registration_domains.go @@ -9,6 +9,7 @@ package services import ( "context" "fmt" + "strings" "time" "github.com/google/uuid" @@ -845,6 +846,18 @@ func (s *Services) DeleteAccountCredentialsRegistrationDomainCode( return nil } +func breakDomainIntoAllSubdomains(domain string) []string { + strSlices := strings.Split(domain, ".") + size := len(strSlices) - 1 + subdomains := make([]string, size) + + for i := 0; i < size; i++ { + subdomains[i] = strings.Join(strSlices[i:], ".") + } + + return subdomains +} + type checkClientRegistrationDomainOptions struct { requestID string accountPublicID uuid.UUID diff --git a/idp/internal/services/oauth_dynamic_registration.go b/idp/internal/services/oauth_dynamic_registration.go index 323a123..5306d54 100644 --- a/idp/internal/services/oauth_dynamic_registration.go +++ b/idp/internal/services/oauth_dynamic_registration.go @@ -30,7 +30,7 @@ import ( const oauthDynamicRegistrationLocation string = "oauth_dynamic_registration" const ( - oauthDynamicRegistrationIATPath string = paths.V1 + paths.AccountsBase + paths.CredentialsBase + paths.DynamicRegistrationBase + paths.InitialAccessToken + oauthDynamicRegistrationIATPath string = paths.V1 + paths.AuthBase + paths.OAuthBase + paths.InitialAccessToken oauthDynamicRegistrationIATAuthPath string = oauthDynamicRegistrationIATPath + paths.OAuthAuth ) @@ -120,6 +120,62 @@ func (s *Services) generateOAuthDynamicRegistrationIATCallback( }), nil } +func mapDomainUsageFromHostExistence(host string) database.DynamicRegistrationUsage { + if host != "" { + return database.DynamicRegistrationUsageApp + } + + return database.DynamicRegistrationUsageAccount +} + +type checkDynamicClientRegistrationDomainUsabilityOptions struct { + requestID string + accountUsername string + domain string +} + +func (s *Services) checkDynamicRegistrationDomainUsability( + ctx context.Context, + opts checkDynamicClientRegistrationDomainUsabilityOptions, +) *exceptions.ServiceError { + logger := s.buildLogger(opts.requestID, dynamicRegistrationDomainsLocation, "checkDynamicRegistrationDomainUsability").With( + "accountUsername", opts.accountUsername, + ) + logger.InfoContext(ctx, "Checking dynamic registration domain usability") + + usage := mapDomainUsageFromHostExistence(opts.accountUsername) + domains := breakDomainIntoAllSubdomains(opts.domain) + var count int64 + var err error + if len(domains) > 1 { + count, err = s.database.CountDynamicRegistrationDomainsByDomainsAndUsages( + ctx, + database.CountDynamicRegistrationDomainsByDomainsAndUsagesParams{ + Domains: domains, + Usages: []database.DynamicRegistrationUsage{usage}, + }, + ) + } else { + count, err = s.database.CountDynamicRegistrationDomainsByDomainAndUsages( + ctx, + database.CountDynamicRegistrationDomainsByDomainAndUsagesParams{ + Domain: opts.domain, + Usages: []database.DynamicRegistrationUsage{usage}, + }, + ) + } + if err != nil { + logger.ErrorContext(ctx, "Failed to count dynamic registration domains by domain and usages", "error", err) + return exceptions.FromDBError(err) + } + if count == 0 { + logger.WarnContext(ctx, "Domain not registered for dynamic registration") + return exceptions.NewForbiddenError() + } + + return nil +} + type oauthDynamicRegistrationIATAuthOptions struct { requestID string challenge string @@ -127,6 +183,7 @@ type oauthDynamicRegistrationIATAuthOptions struct { domain string redirectURI string state string + hostUsername string } func (s *Services) oauthDynamicRegistrationIATAuth( @@ -140,39 +197,19 @@ func (s *Services) oauthDynamicRegistrationIATAuth( ).With( "domain", opts.domain, "redirectUri", opts.redirectURI, + "hostUsername", opts.hostUsername, ) logger.InfoContext(ctx, "Handling OAuth dynamic registration IAT auth...") - tldOneDomain, err := publicsuffix.EffectiveTLDPlusOne(opts.domain) - if err != nil { - logger.WarnContext(ctx, "Invalid domain", "error", err) - return "", exceptions.NewValidationError("invalid client_id") - } - - var count int64 - if tldOneDomain != opts.domain { - count, err = s.database.CountVerifiedDynamicRegistrationDomainsByDomains( - ctx, - []string{opts.domain, tldOneDomain}, - ) - if err != nil { - logger.ErrorContext(ctx, "Failed to count account dynamic registration domains by domains", "error", err) - return "", exceptions.NewInternalServerError() - } - if count == 0 { - logger.WarnContext(ctx, "Domain not registered for dynamic registration") - return "", exceptions.NewForbiddenError() - } - } else { - count, err = s.database.CountVerifiedDynamicRegistrationDomainsByDomain(ctx, opts.domain) - } - if err != nil { - logger.ErrorContext(ctx, "Failed to count account dynamic registration domains by domains", "error", err) - return "", exceptions.NewInternalServerError() - } - if count == 0 { - logger.WarnContext(ctx, "Domain not registered for dynamic registration") - return "", exceptions.NewForbiddenError() + if serviceErr := s.checkDynamicRegistrationDomainUsability( + ctx, + checkDynamicClientRegistrationDomainUsabilityOptions{ + requestID: opts.requestID, + accountUsername: opts.hostUsername, + }, + ); serviceErr != nil { + logger.InfoContext(ctx, "Dynamic registration domain not usable", "serviceError", serviceErr) + return "", serviceErr } hashedChallenge, serviceErr := hashChallenge(opts.challenge, opts.challengeMethod) @@ -189,6 +226,7 @@ func (s *Services) oauthDynamicRegistrationIATAuth( State: opts.state, RedirectURI: opts.redirectURI, Challenge: hashedChallenge, + Username: opts.hostUsername, }, ) if err != nil { @@ -339,6 +377,7 @@ type InitiateOAuthDynamicRegistrationIATAuthOptions struct { ChallengeMethod string RedirectURI string BackendDomain string + HostUsername string } func (s *Services) InitiateOAuthDynamicRegistrationIATAuth( @@ -364,6 +403,7 @@ func (s *Services) InitiateOAuthDynamicRegistrationIATAuth( domain: opts.Domain, redirectURI: opts.RedirectURI, state: opts.State, + hostUsername: opts.HostUsername, }) } diff --git a/keygen/main.go b/keygen/main.go index 17641c6..8dc383c 100644 --- a/keygen/main.go +++ b/keygen/main.go @@ -134,7 +134,7 @@ func encodeKeyPemToJson(logger *slog.Logger, block *pem.Block) string { func generateSecret(logger *slog.Logger) string { logger.Debug("Generating base64 encoded 32 byte secret") - bytes := make([]byte, 32) + bytes := make([]byte, 64) _, err := rand.Read(bytes) if err != nil { diff --git a/project.md b/project.md index 8c132a0..2f89670 100644 --- a/project.md +++ b/project.md @@ -30,10 +30,12 @@ - Add OAuth Dynamic Registration for: - accounts - apps +- Add support for multiple 2FA types ### IDP Todo - Account key generation +- Add Passkey (WebAuthn) support - Dynamic OIDC configs - User authentication for each app type: - web