Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions domain/auth_code.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package domain

import (
"time"

"github.com/uptrace/bun"
)

// AuthCode tracks a consumed authorization code for single-use enforcement
// per RFC 6749 §4.1.2. Auth codes are stateless HS256 JWTs; this record is
// created on first exchange and used to detect replays.
type AuthCode struct {
bun.BaseModel `bun:"table:auth_codes,alias:ac"`

JTI string `bun:"jti,pk,type:varchar(255)" json:"jti"`
ClientID string `bun:"client_id,type:varchar(255)" json:"client_id"`
AccountID string `bun:"account_id,type:varchar(255)" json:"account_id"`
ProjectID string `bun:"project_id,type:varchar(255)" json:"project_id"`
CredentialJTI *string `bun:"credential_jti,type:varchar(255)" json:"credential_jti,omitempty"`
RefreshFamilyID *string `bun:"refresh_family_id,type:uuid" json:"refresh_family_id,omitempty"`
ConsumedAt time.Time `bun:"consumed_at,nullzero,notnull,default:current_timestamp" json:"consumed_at"`
ExpiresAt time.Time `bun:"expires_at,notnull" json:"expires_at"`
}
4 changes: 0 additions & 4 deletions internal/handler/helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,3 @@ func respondWithError(w http.ResponseWriter, status int, internalCode, message s
respondWithJSON(w, status, errResp)
}

// respondNotImplemented returns a 501 stub response.
func respondNotImplemented(w http.ResponseWriter) {
respondWithError(w, http.StatusNotImplemented, domain.ErrCodeNotImplemented, "not yet implemented")
}
15 changes: 15 additions & 0 deletions internal/service/authcode.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,19 @@ package service
import (
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"fmt"
"strings"
"time"

"github.com/lestrrat-go/jwx/v2/jwa"
"github.com/lestrrat-go/jwx/v2/jwt"
)

// AuthCodeClaims holds the decoded claims from an authorization code JWT.
type AuthCodeClaims struct {
JTI string // JWT ID (jti claim or derived hash)
ExpiresAt time.Time // Token expiration
ClientID string // "cid" — Client application ID
CodeChallenge string // "cc" — PKCE code challenge (S256)
RedirectURI string // "ruri" — OAuth redirect URI
Expand Down Expand Up @@ -42,7 +46,18 @@ func decodeAuthCodeJWT(code, hmacSecret, expectedIssuer string) (*AuthCodeClaims
return nil, fmt.Errorf("auth code has invalid subject: %s", token.Subject())
}

// Use the JWT's jti claim if present; otherwise derive a deterministic
// identifier from the SHA-256 hash of the raw code string. This ensures
// replay protection works even for auth codes issued without a jti.
jti := token.JwtID()
if jti == "" {
h := sha256.Sum256([]byte(code))
jti = "derived:" + hex.EncodeToString(h[:])
}

claims := &AuthCodeClaims{
JTI: jti,
ExpiresAt: token.Expiration(),
ClientID: getStringClaim(token, "cid"),
CodeChallenge: getStringClaim(token, "cc"),
RedirectURI: getStringClaim(token, "ruri"),
Expand Down
75 changes: 75 additions & 0 deletions internal/service/oauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ type OAuthService struct {
identitySvc *IdentityService
oauthClientSvc *OAuthClientService
apiKeyRepo *postgres.APIKeyRepository
authCodeRepo *postgres.AuthCodeRepository
jwksSvc *signing.JWKSService
refreshTokenSvc *RefreshTokenService
issuer string
Expand Down Expand Up @@ -86,6 +87,7 @@ func NewOAuthService(
identitySvc *IdentityService,
oauthClientSvc *OAuthClientService,
apiKeyRepo *postgres.APIKeyRepository,
authCodeRepo *postgres.AuthCodeRepository,
jwksSvc *signing.JWKSService,
refreshTokenSvc *RefreshTokenService,
cfg OAuthServiceConfig,
Expand All @@ -95,6 +97,7 @@ func NewOAuthService(
identitySvc: identitySvc,
oauthClientSvc: oauthClientSvc,
apiKeyRepo: apiKeyRepo,
authCodeRepo: authCodeRepo,
jwksSvc: jwksSvc,
refreshTokenSvc: refreshTokenSvc,
issuer: cfg.Issuer,
Expand Down Expand Up @@ -622,6 +625,10 @@ func (s *OAuthService) apiKeyGrant(ctx context.Context, req TokenRequest) (*doma
// Token behaviour is derived from the client's registered grant_types:
// - Clients with "refresh_token" grant: short-lived (1h) access token + rotating refresh token.
// - Clients without: long-lived (90-day) access token, no refresh token.
//
// Each auth code is single-use per RFC 6749 §4.1.2. On first exchange, the code
// is atomically marked as consumed. Replays are rejected with invalid_grant and
// all tokens issued from the original exchange are revoked.
func (s *OAuthService) authorizationCode(ctx context.Context, req TokenRequest) (*domain.AccessToken, error) {
if req.Code == "" || req.CodeVerifier == "" || req.ClientID == "" || req.RedirectURI == "" {
return nil, oauthBadRequest("invalid_request", "code, code_verifier, client_id, and redirect_uri are required")
Expand Down Expand Up @@ -670,6 +677,32 @@ func (s *OAuthService) authorizationCode(ctx context.Context, req TokenRequest)
return nil, oauthBadRequest("invalid_grant", "PKCE verification failed")
}

// ── Single-use enforcement (RFC 6749 §4.1.2) ────────────────────────
// Placed after all validation (client, redirect_uri, PKCE) so an
// attacker who intercepts a code but doesn't know the verifier cannot
// burn it by sending a request with a wrong verifier.
//
// Consume → IssueCredential → IssueRefreshToken → UpdateTokenInfo is
// not transactional. A replay arriving between Consume and
// UpdateTokenInfo is still rejected (the critical correctness
// property), but revokeAuthCodeTokens may find CredentialJTI unset and
// leave the in-flight exchange's tokens valid. RFC 6749 §4.1.2 says
// "SHOULD revoke" — best-effort revocation is acceptable here.
consumed, err := s.authCodeRepo.Consume(ctx, &domain.AuthCode{
JTI: authCode.JTI,
ClientID: authCode.ClientID,
AccountID: authCode.AccountID,
ProjectID: authCode.ProjectID,
ExpiresAt: authCode.ExpiresAt,
})
if err != nil {
return nil, oauthServerError("failed to check authorization code usage", err)
}
if !consumed {
s.revokeAuthCodeTokens(ctx, authCode.JTI)
Comment thread
rsharath marked this conversation as resolved.
Comment thread
rsharath marked this conversation as resolved.
return nil, oauthBadRequest("invalid_grant", "authorization code has already been used")
}

// Determine access token TTL.
// Priority: per-client config > grant-type-based default > server default.
hasRefreshGrant := false
Expand Down Expand Up @@ -713,6 +746,8 @@ func (s *OAuthService) authorizationCode(ctx context.Context, req TokenRequest)
accessToken.ProjectID = authCode.ProjectID
accessToken.UserID = authCode.UserID

var refreshFamilyID string

// Issue refresh token when the client is registered for the refresh_token grant.
if hasRefreshGrant && s.refreshTokenSvc != nil {
rtResult, rtErr := s.refreshTokenSvc.IssueRefreshToken(ctx, &RefreshTokenParams{
Expand All @@ -727,12 +762,52 @@ func (s *OAuthService) authorizationCode(ctx context.Context, req TokenRequest)
log.Error().Err(rtErr).Msg("Failed to issue refresh token — returning access token only")
} else {
accessToken.RefreshToken = rtResult.RawToken
refreshFamilyID = rtResult.FamilyID
}
}

// Store token info so replay detection can revoke these tokens later.
if updateErr := s.authCodeRepo.UpdateTokenInfo(ctx, authCode.JTI, accessToken.JTI, refreshFamilyID); updateErr != nil {
log.Error().Err(updateErr).Str("auth_code_jti", authCode.JTI).Msg("Failed to store auth code token info for replay revocation")
}
Comment thread
rsharath marked this conversation as resolved.

return accessToken, nil
}

// revokeAuthCodeTokens revokes the access token and refresh token family that
// were issued when the auth code was first exchanged. Per RFC 6749 §4.1.2:
// "the authorization server [...] SHOULD revoke all tokens previously issued
// based on that authorization code."
func (s *OAuthService) revokeAuthCodeTokens(ctx context.Context, codeJTI string) {
record, err := s.authCodeRepo.GetByJTI(ctx, codeJTI)
if err != nil {
log.Warn().Err(err).Str("auth_code_jti", codeJTI).Msg("Auth code replay: could not look up original exchange for revocation")
return
}

if record.CredentialJTI != nil && *record.CredentialJTI != "" {
cred, _, introspectErr := s.credentialSvc.IntrospectToken(ctx, *record.CredentialJTI)
if introspectErr != nil {
log.Error().Err(introspectErr).Str("credential_jti", *record.CredentialJTI).Msg("Auth code replay: failed to introspect access token for revocation")
} else if cred != nil {
if revokeErr := s.credentialSvc.RevokeCredential(ctx, cred.ID, cred.AccountID, cred.ProjectID, "auth_code_replay"); revokeErr != nil {
log.Error().Err(revokeErr).Str("credential_jti", *record.CredentialJTI).Msg("Auth code replay: failed to revoke access token")
} else {
log.Warn().Str("credential_jti", *record.CredentialJTI).Msg("Auth code replay: revoked access token from original exchange")
}
}
}

if record.RefreshFamilyID != nil && *record.RefreshFamilyID != "" && s.refreshTokenSvc != nil {
count, revokeErr := s.refreshTokenSvc.RevokeFamily(ctx, *record.RefreshFamilyID)
if revokeErr != nil {
log.Error().Err(revokeErr).Str("family_id", *record.RefreshFamilyID).Msg("Auth code replay: failed to revoke refresh token family")
} else if count > 0 {
log.Warn().Str("family_id", *record.RefreshFamilyID).Int64("count", count).Msg("Auth code replay: revoked refresh token family from original exchange")
}
}
}

// refreshToken handles the refresh_token grant (RFC 6749 section 6).
// Implements single-use rotation with family-based reuse detection.
func (s *OAuthService) refreshToken(ctx context.Context, req TokenRequest) (*domain.AccessToken, error) {
Expand Down
6 changes: 6 additions & 0 deletions internal/service/refresh_token.go
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,12 @@ func (s *RefreshTokenService) RotateRefreshToken(ctx context.Context, rawToken s
return existing, result, nil
}

// RevokeFamily revokes all active tokens in a refresh token family.
// Used during auth code replay detection per RFC 6749 §4.1.2.
func (s *RefreshTokenService) RevokeFamily(ctx context.Context, familyID string) (int64, error) {
return s.repo.RevokeFamily(ctx, familyID)
}

// generateRefreshToken creates a cryptographically random refresh token.
// Format: zid_rt_<base64url(32 random bytes)>
func generateRefreshToken() (string, error) {
Expand Down
75 changes: 75 additions & 0 deletions internal/store/postgres/auth_code.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
package postgres

import (
"context"
"fmt"

"github.com/uptrace/bun"

"github.com/highflame-ai/zeroid/domain"
)

// AuthCodeRepository handles persistence for consumed authorization codes.
type AuthCodeRepository struct {
db *bun.DB
}

// NewAuthCodeRepository creates a new AuthCodeRepository.
func NewAuthCodeRepository(db *bun.DB) *AuthCodeRepository {
return &AuthCodeRepository{db: db}
}

// Consume atomically records an auth code as consumed.
// Returns true if this is the first consumption (INSERT succeeded).
// Returns false if the code was already consumed (conflict on PK).
func (r *AuthCodeRepository) Consume(ctx context.Context, code *domain.AuthCode) (bool, error) {
res, err := r.db.NewInsert().
Model(code).
On("CONFLICT (jti) DO NOTHING").
Exec(ctx)
if err != nil {
return false, fmt.Errorf("failed to consume auth code: %w", err)
}

rows, err := res.RowsAffected()
if err != nil {
return false, fmt.Errorf("failed to check auth code insert result: %w", err)
}

return rows > 0, nil
}

// GetByJTI retrieves a consumed auth code record by its JTI.
// Used during replay detection to find the credential and refresh token
// family that need to be revoked.
func (r *AuthCodeRepository) GetByJTI(ctx context.Context, jti string) (*domain.AuthCode, error) {
code := &domain.AuthCode{}
err := r.db.NewSelect().Model(code).
Where("jti = ?", jti).
Scan(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get auth code by jti: %w", err)
}
return code, nil
}

// UpdateTokenInfo stores the credential JTI and refresh token family ID
// after successful token issuance. These are needed to revoke the tokens
// if a replay is detected later.
func (r *AuthCodeRepository) UpdateTokenInfo(ctx context.Context, jti, credentialJTI, refreshFamilyID string) error {
q := r.db.NewUpdate().
Model((*domain.AuthCode)(nil)).
Set("credential_jti = ?", credentialJTI).
Where("jti = ?", jti)

if refreshFamilyID != "" {
q = q.Set("refresh_family_id = ?::uuid", refreshFamilyID)
}

_, err := q.Exec(ctx)
if err != nil {
return fmt.Errorf("failed to update auth code token info: %w", err)
}
return nil
}

11 changes: 11 additions & 0 deletions internal/worker/cleanup.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,4 +64,15 @@ func (w *CleanupWorker) runOnce(ctx context.Context) {
} else if n, err := proofRes.RowsAffected(); err == nil && n > 0 {
log.Info().Int64("count", n).Msg("Cleanup: deleted expired proof tokens")
}

// Delete consumed auth codes past their expiry (single-use enforcement records).
authCodeRes, err := w.db.NewDelete().
TableExpr("auth_codes").
Where("expires_at < ?", now).
Exec(ctx)
if err != nil {
log.Error().Err(err).Msg("Cleanup: failed to delete expired auth codes")
} else if n, err := authCodeRes.RowsAffected(); err == nil && n > 0 {
log.Info().Int64("count", n).Msg("Cleanup: deleted expired auth codes")
}
}
2 changes: 1 addition & 1 deletion migrate.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import (
func Migrate(databaseURL string) error {
sqldb := sql.OpenDB(pgdriver.NewConnector(pgdriver.WithDSN(databaseURL)))
db := bun.NewDB(sqldb, pgdialect.New())
defer db.Close()
defer func() { _ = db.Close() }()

if err := database.RunMigrations(db); err != nil {
return fmt.Errorf("zeroid migration failed: %w", err)
Expand Down
2 changes: 2 additions & 0 deletions migrations/008_auth_codes.down.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
-- 008_auth_codes.down.sql
DROP TABLE IF EXISTS auth_codes;
20 changes: 20 additions & 0 deletions migrations/008_auth_codes.up.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
-- 008_auth_codes.up.sql
-- Tracks consumed authorization codes to enforce single-use per RFC 6749 §4.1.2.
-- Auth codes are stateless HS256 JWTs; this table records consumption so replays
-- are rejected. The credential_jti and refresh_family_id columns enable token
-- revocation when a replay is detected (RFC 6749: "SHOULD revoke all tokens
-- previously issued based on that authorization code").

CREATE TABLE IF NOT EXISTS auth_codes (
jti VARCHAR(255) PRIMARY KEY,
client_id VARCHAR(255) NOT NULL,
account_id VARCHAR(255) NOT NULL,
project_id VARCHAR(255) NOT NULL DEFAULT '',
credential_jti VARCHAR(255),
refresh_family_id UUID,
consumed_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
expires_at TIMESTAMPTZ NOT NULL
);

CREATE INDEX IF NOT EXISTS idx_auth_codes_expires_at
ON auth_codes (expires_at);
3 changes: 2 additions & 1 deletion server.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ func NewServer(cfg Config) (*Server, error) {
credentialPolicyRepo := postgres.NewCredentialPolicyRepository(db)
apiKeyRepo := postgres.NewAPIKeyRepository(db)
refreshTokenRepo := postgres.NewRefreshTokenRepository(db)
authCodeRepo := postgres.NewAuthCodeRepository(db)

// Initialize services.
identitySvc := service.NewIdentityService(identityRepo, cfg.WIMSEDomain)
Expand All @@ -158,7 +159,7 @@ func NewServer(cfg Config) (*Server, error) {
if authCodeIssuer == "" {
authCodeIssuer = cfg.Token.Issuer
}
oauthSvc := service.NewOAuthService(credentialSvc, identitySvc, oauthClientSvc, apiKeyRepo, jwksSvc, refreshTokenSvc, service.OAuthServiceConfig{
oauthSvc := service.NewOAuthService(credentialSvc, identitySvc, oauthClientSvc, apiKeyRepo, authCodeRepo, jwksSvc, refreshTokenSvc, service.OAuthServiceConfig{
Issuer: cfg.Token.Issuer,
WIMSEDomain: cfg.WIMSEDomain,
HMACSecret: cfg.Token.HMACSecret,
Expand Down
Loading
Loading