diff --git a/domain/auth_code.go b/domain/auth_code.go new file mode 100644 index 0000000..eb4ef7f --- /dev/null +++ b/domain/auth_code.go @@ -0,0 +1,23 @@ +package domain + +import ( + "time" + + "github.com/uptrace/bun" +) + +// AuthCode tracks a consumed authorization code for single-use enforcement +// per RFC 6749 §4.1.2. Auth codes are stateless HS256 JWTs; this record is +// created on first exchange and used to detect replays. +type AuthCode struct { + bun.BaseModel `bun:"table:auth_codes,alias:ac"` + + JTI string `bun:"jti,pk,type:varchar(255)" json:"jti"` + ClientID string `bun:"client_id,type:varchar(255)" json:"client_id"` + AccountID string `bun:"account_id,type:varchar(255)" json:"account_id"` + ProjectID string `bun:"project_id,type:varchar(255)" json:"project_id"` + CredentialJTI *string `bun:"credential_jti,type:varchar(255)" json:"credential_jti,omitempty"` + RefreshFamilyID *string `bun:"refresh_family_id,type:uuid" json:"refresh_family_id,omitempty"` + ConsumedAt time.Time `bun:"consumed_at,nullzero,notnull,default:current_timestamp" json:"consumed_at"` + ExpiresAt time.Time `bun:"expires_at,notnull" json:"expires_at"` +} diff --git a/internal/handler/helper.go b/internal/handler/helper.go index 8de3b60..b990a82 100644 --- a/internal/handler/helper.go +++ b/internal/handler/helper.go @@ -25,7 +25,3 @@ func respondWithError(w http.ResponseWriter, status int, internalCode, message s respondWithJSON(w, status, errResp) } -// respondNotImplemented returns a 501 stub response. -func respondNotImplemented(w http.ResponseWriter) { - respondWithError(w, http.StatusNotImplemented, domain.ErrCodeNotImplemented, "not yet implemented") -} diff --git a/internal/service/authcode.go b/internal/service/authcode.go index 7ddfeed..0bcce0e 100644 --- a/internal/service/authcode.go +++ b/internal/service/authcode.go @@ -3,8 +3,10 @@ package service import ( "crypto/sha256" "encoding/base64" + "encoding/hex" "fmt" "strings" + "time" "github.com/lestrrat-go/jwx/v2/jwa" "github.com/lestrrat-go/jwx/v2/jwt" @@ -12,6 +14,8 @@ import ( // AuthCodeClaims holds the decoded claims from an authorization code JWT. type AuthCodeClaims struct { + JTI string // JWT ID (jti claim or derived hash) + ExpiresAt time.Time // Token expiration ClientID string // "cid" — Client application ID CodeChallenge string // "cc" — PKCE code challenge (S256) RedirectURI string // "ruri" — OAuth redirect URI @@ -42,7 +46,18 @@ func decodeAuthCodeJWT(code, hmacSecret, expectedIssuer string) (*AuthCodeClaims return nil, fmt.Errorf("auth code has invalid subject: %s", token.Subject()) } + // Use the JWT's jti claim if present; otherwise derive a deterministic + // identifier from the SHA-256 hash of the raw code string. This ensures + // replay protection works even for auth codes issued without a jti. + jti := token.JwtID() + if jti == "" { + h := sha256.Sum256([]byte(code)) + jti = "derived:" + hex.EncodeToString(h[:]) + } + claims := &AuthCodeClaims{ + JTI: jti, + ExpiresAt: token.Expiration(), ClientID: getStringClaim(token, "cid"), CodeChallenge: getStringClaim(token, "cc"), RedirectURI: getStringClaim(token, "ruri"), diff --git a/internal/service/oauth.go b/internal/service/oauth.go index fc65fab..7da4e5c 100644 --- a/internal/service/oauth.go +++ b/internal/service/oauth.go @@ -28,6 +28,7 @@ type OAuthService struct { identitySvc *IdentityService oauthClientSvc *OAuthClientService apiKeyRepo *postgres.APIKeyRepository + authCodeRepo *postgres.AuthCodeRepository jwksSvc *signing.JWKSService refreshTokenSvc *RefreshTokenService issuer string @@ -86,6 +87,7 @@ func NewOAuthService( identitySvc *IdentityService, oauthClientSvc *OAuthClientService, apiKeyRepo *postgres.APIKeyRepository, + authCodeRepo *postgres.AuthCodeRepository, jwksSvc *signing.JWKSService, refreshTokenSvc *RefreshTokenService, cfg OAuthServiceConfig, @@ -95,6 +97,7 @@ func NewOAuthService( identitySvc: identitySvc, oauthClientSvc: oauthClientSvc, apiKeyRepo: apiKeyRepo, + authCodeRepo: authCodeRepo, jwksSvc: jwksSvc, refreshTokenSvc: refreshTokenSvc, issuer: cfg.Issuer, @@ -622,6 +625,10 @@ func (s *OAuthService) apiKeyGrant(ctx context.Context, req TokenRequest) (*doma // Token behaviour is derived from the client's registered grant_types: // - Clients with "refresh_token" grant: short-lived (1h) access token + rotating refresh token. // - Clients without: long-lived (90-day) access token, no refresh token. +// +// Each auth code is single-use per RFC 6749 §4.1.2. On first exchange, the code +// is atomically marked as consumed. Replays are rejected with invalid_grant and +// all tokens issued from the original exchange are revoked. func (s *OAuthService) authorizationCode(ctx context.Context, req TokenRequest) (*domain.AccessToken, error) { if req.Code == "" || req.CodeVerifier == "" || req.ClientID == "" || req.RedirectURI == "" { return nil, oauthBadRequest("invalid_request", "code, code_verifier, client_id, and redirect_uri are required") @@ -670,6 +677,32 @@ func (s *OAuthService) authorizationCode(ctx context.Context, req TokenRequest) return nil, oauthBadRequest("invalid_grant", "PKCE verification failed") } + // ── Single-use enforcement (RFC 6749 §4.1.2) ──────────────────────── + // Placed after all validation (client, redirect_uri, PKCE) so an + // attacker who intercepts a code but doesn't know the verifier cannot + // burn it by sending a request with a wrong verifier. + // + // Consume → IssueCredential → IssueRefreshToken → UpdateTokenInfo is + // not transactional. A replay arriving between Consume and + // UpdateTokenInfo is still rejected (the critical correctness + // property), but revokeAuthCodeTokens may find CredentialJTI unset and + // leave the in-flight exchange's tokens valid. RFC 6749 §4.1.2 says + // "SHOULD revoke" — best-effort revocation is acceptable here. + consumed, err := s.authCodeRepo.Consume(ctx, &domain.AuthCode{ + JTI: authCode.JTI, + ClientID: authCode.ClientID, + AccountID: authCode.AccountID, + ProjectID: authCode.ProjectID, + ExpiresAt: authCode.ExpiresAt, + }) + if err != nil { + return nil, oauthServerError("failed to check authorization code usage", err) + } + if !consumed { + s.revokeAuthCodeTokens(ctx, authCode.JTI) + return nil, oauthBadRequest("invalid_grant", "authorization code has already been used") + } + // Determine access token TTL. // Priority: per-client config > grant-type-based default > server default. hasRefreshGrant := false @@ -713,6 +746,8 @@ func (s *OAuthService) authorizationCode(ctx context.Context, req TokenRequest) accessToken.ProjectID = authCode.ProjectID accessToken.UserID = authCode.UserID + var refreshFamilyID string + // Issue refresh token when the client is registered for the refresh_token grant. if hasRefreshGrant && s.refreshTokenSvc != nil { rtResult, rtErr := s.refreshTokenSvc.IssueRefreshToken(ctx, &RefreshTokenParams{ @@ -727,12 +762,52 @@ func (s *OAuthService) authorizationCode(ctx context.Context, req TokenRequest) log.Error().Err(rtErr).Msg("Failed to issue refresh token — returning access token only") } else { accessToken.RefreshToken = rtResult.RawToken + refreshFamilyID = rtResult.FamilyID } } + // Store token info so replay detection can revoke these tokens later. + if updateErr := s.authCodeRepo.UpdateTokenInfo(ctx, authCode.JTI, accessToken.JTI, refreshFamilyID); updateErr != nil { + log.Error().Err(updateErr).Str("auth_code_jti", authCode.JTI).Msg("Failed to store auth code token info for replay revocation") + } + return accessToken, nil } +// revokeAuthCodeTokens revokes the access token and refresh token family that +// were issued when the auth code was first exchanged. Per RFC 6749 §4.1.2: +// "the authorization server [...] SHOULD revoke all tokens previously issued +// based on that authorization code." +func (s *OAuthService) revokeAuthCodeTokens(ctx context.Context, codeJTI string) { + record, err := s.authCodeRepo.GetByJTI(ctx, codeJTI) + if err != nil { + log.Warn().Err(err).Str("auth_code_jti", codeJTI).Msg("Auth code replay: could not look up original exchange for revocation") + return + } + + if record.CredentialJTI != nil && *record.CredentialJTI != "" { + cred, _, introspectErr := s.credentialSvc.IntrospectToken(ctx, *record.CredentialJTI) + if introspectErr != nil { + log.Error().Err(introspectErr).Str("credential_jti", *record.CredentialJTI).Msg("Auth code replay: failed to introspect access token for revocation") + } else if cred != nil { + if revokeErr := s.credentialSvc.RevokeCredential(ctx, cred.ID, cred.AccountID, cred.ProjectID, "auth_code_replay"); revokeErr != nil { + log.Error().Err(revokeErr).Str("credential_jti", *record.CredentialJTI).Msg("Auth code replay: failed to revoke access token") + } else { + log.Warn().Str("credential_jti", *record.CredentialJTI).Msg("Auth code replay: revoked access token from original exchange") + } + } + } + + if record.RefreshFamilyID != nil && *record.RefreshFamilyID != "" && s.refreshTokenSvc != nil { + count, revokeErr := s.refreshTokenSvc.RevokeFamily(ctx, *record.RefreshFamilyID) + if revokeErr != nil { + log.Error().Err(revokeErr).Str("family_id", *record.RefreshFamilyID).Msg("Auth code replay: failed to revoke refresh token family") + } else if count > 0 { + log.Warn().Str("family_id", *record.RefreshFamilyID).Int64("count", count).Msg("Auth code replay: revoked refresh token family from original exchange") + } + } +} + // refreshToken handles the refresh_token grant (RFC 6749 section 6). // Implements single-use rotation with family-based reuse detection. func (s *OAuthService) refreshToken(ctx context.Context, req TokenRequest) (*domain.AccessToken, error) { diff --git a/internal/service/refresh_token.go b/internal/service/refresh_token.go index 4884dbb..424ef6a 100644 --- a/internal/service/refresh_token.go +++ b/internal/service/refresh_token.go @@ -184,6 +184,12 @@ func (s *RefreshTokenService) RotateRefreshToken(ctx context.Context, rawToken s return existing, result, nil } +// RevokeFamily revokes all active tokens in a refresh token family. +// Used during auth code replay detection per RFC 6749 §4.1.2. +func (s *RefreshTokenService) RevokeFamily(ctx context.Context, familyID string) (int64, error) { + return s.repo.RevokeFamily(ctx, familyID) +} + // generateRefreshToken creates a cryptographically random refresh token. // Format: zid_rt_ func generateRefreshToken() (string, error) { diff --git a/internal/store/postgres/auth_code.go b/internal/store/postgres/auth_code.go new file mode 100644 index 0000000..d6b8d3f --- /dev/null +++ b/internal/store/postgres/auth_code.go @@ -0,0 +1,75 @@ +package postgres + +import ( + "context" + "fmt" + + "github.com/uptrace/bun" + + "github.com/highflame-ai/zeroid/domain" +) + +// AuthCodeRepository handles persistence for consumed authorization codes. +type AuthCodeRepository struct { + db *bun.DB +} + +// NewAuthCodeRepository creates a new AuthCodeRepository. +func NewAuthCodeRepository(db *bun.DB) *AuthCodeRepository { + return &AuthCodeRepository{db: db} +} + +// Consume atomically records an auth code as consumed. +// Returns true if this is the first consumption (INSERT succeeded). +// Returns false if the code was already consumed (conflict on PK). +func (r *AuthCodeRepository) Consume(ctx context.Context, code *domain.AuthCode) (bool, error) { + res, err := r.db.NewInsert(). + Model(code). + On("CONFLICT (jti) DO NOTHING"). + Exec(ctx) + if err != nil { + return false, fmt.Errorf("failed to consume auth code: %w", err) + } + + rows, err := res.RowsAffected() + if err != nil { + return false, fmt.Errorf("failed to check auth code insert result: %w", err) + } + + return rows > 0, nil +} + +// GetByJTI retrieves a consumed auth code record by its JTI. +// Used during replay detection to find the credential and refresh token +// family that need to be revoked. +func (r *AuthCodeRepository) GetByJTI(ctx context.Context, jti string) (*domain.AuthCode, error) { + code := &domain.AuthCode{} + err := r.db.NewSelect().Model(code). + Where("jti = ?", jti). + Scan(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get auth code by jti: %w", err) + } + return code, nil +} + +// UpdateTokenInfo stores the credential JTI and refresh token family ID +// after successful token issuance. These are needed to revoke the tokens +// if a replay is detected later. +func (r *AuthCodeRepository) UpdateTokenInfo(ctx context.Context, jti, credentialJTI, refreshFamilyID string) error { + q := r.db.NewUpdate(). + Model((*domain.AuthCode)(nil)). + Set("credential_jti = ?", credentialJTI). + Where("jti = ?", jti) + + if refreshFamilyID != "" { + q = q.Set("refresh_family_id = ?::uuid", refreshFamilyID) + } + + _, err := q.Exec(ctx) + if err != nil { + return fmt.Errorf("failed to update auth code token info: %w", err) + } + return nil +} + diff --git a/internal/worker/cleanup.go b/internal/worker/cleanup.go index 1ed37b9..2bcc27c 100644 --- a/internal/worker/cleanup.go +++ b/internal/worker/cleanup.go @@ -64,4 +64,15 @@ func (w *CleanupWorker) runOnce(ctx context.Context) { } else if n, err := proofRes.RowsAffected(); err == nil && n > 0 { log.Info().Int64("count", n).Msg("Cleanup: deleted expired proof tokens") } + + // Delete consumed auth codes past their expiry (single-use enforcement records). + authCodeRes, err := w.db.NewDelete(). + TableExpr("auth_codes"). + Where("expires_at < ?", now). + Exec(ctx) + if err != nil { + log.Error().Err(err).Msg("Cleanup: failed to delete expired auth codes") + } else if n, err := authCodeRes.RowsAffected(); err == nil && n > 0 { + log.Info().Int64("count", n).Msg("Cleanup: deleted expired auth codes") + } } diff --git a/migrate.go b/migrate.go index 1627772..d68d25c 100644 --- a/migrate.go +++ b/migrate.go @@ -22,7 +22,7 @@ import ( func Migrate(databaseURL string) error { sqldb := sql.OpenDB(pgdriver.NewConnector(pgdriver.WithDSN(databaseURL))) db := bun.NewDB(sqldb, pgdialect.New()) - defer db.Close() + defer func() { _ = db.Close() }() if err := database.RunMigrations(db); err != nil { return fmt.Errorf("zeroid migration failed: %w", err) diff --git a/migrations/008_auth_codes.down.sql b/migrations/008_auth_codes.down.sql new file mode 100644 index 0000000..ede3866 --- /dev/null +++ b/migrations/008_auth_codes.down.sql @@ -0,0 +1,2 @@ +-- 008_auth_codes.down.sql +DROP TABLE IF EXISTS auth_codes; diff --git a/migrations/008_auth_codes.up.sql b/migrations/008_auth_codes.up.sql new file mode 100644 index 0000000..563cc2b --- /dev/null +++ b/migrations/008_auth_codes.up.sql @@ -0,0 +1,20 @@ +-- 008_auth_codes.up.sql +-- Tracks consumed authorization codes to enforce single-use per RFC 6749 §4.1.2. +-- Auth codes are stateless HS256 JWTs; this table records consumption so replays +-- are rejected. The credential_jti and refresh_family_id columns enable token +-- revocation when a replay is detected (RFC 6749: "SHOULD revoke all tokens +-- previously issued based on that authorization code"). + +CREATE TABLE IF NOT EXISTS auth_codes ( + jti VARCHAR(255) PRIMARY KEY, + client_id VARCHAR(255) NOT NULL, + account_id VARCHAR(255) NOT NULL, + project_id VARCHAR(255) NOT NULL DEFAULT '', + credential_jti VARCHAR(255), + refresh_family_id UUID, + consumed_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + expires_at TIMESTAMPTZ NOT NULL +); + +CREATE INDEX IF NOT EXISTS idx_auth_codes_expires_at + ON auth_codes (expires_at); diff --git a/server.go b/server.go index 28bb580..fb152d6 100644 --- a/server.go +++ b/server.go @@ -144,6 +144,7 @@ func NewServer(cfg Config) (*Server, error) { credentialPolicyRepo := postgres.NewCredentialPolicyRepository(db) apiKeyRepo := postgres.NewAPIKeyRepository(db) refreshTokenRepo := postgres.NewRefreshTokenRepository(db) + authCodeRepo := postgres.NewAuthCodeRepository(db) // Initialize services. identitySvc := service.NewIdentityService(identityRepo, cfg.WIMSEDomain) @@ -158,7 +159,7 @@ func NewServer(cfg Config) (*Server, error) { if authCodeIssuer == "" { authCodeIssuer = cfg.Token.Issuer } - oauthSvc := service.NewOAuthService(credentialSvc, identitySvc, oauthClientSvc, apiKeyRepo, jwksSvc, refreshTokenSvc, service.OAuthServiceConfig{ + oauthSvc := service.NewOAuthService(credentialSvc, identitySvc, oauthClientSvc, apiKeyRepo, authCodeRepo, jwksSvc, refreshTokenSvc, service.OAuthServiceConfig{ Issuer: cfg.Token.Issuer, WIMSEDomain: cfg.WIMSEDomain, HMACSecret: cfg.Token.HMACSecret, diff --git a/tests/integration/auth_verify_test.go b/tests/integration/auth_verify_test.go index 3d39d88..dfc0dfc 100644 --- a/tests/integration/auth_verify_test.go +++ b/tests/integration/auth_verify_test.go @@ -29,7 +29,7 @@ func issueAPIKeyToken(t *testing.T, externalID string) string { // challenge. func TestAuthVerify_MissingAuthorizationHeader(t *testing.T) { resp := get(t, "/oauth2/token/verify", nil) - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) assert.Equal(t, `Bearer error="missing_token"`, resp.Header.Get("WWW-Authenticate")) @@ -41,7 +41,7 @@ func TestAuthVerify_WrongScheme(t *testing.T) { resp := get(t, "/oauth2/token/verify", map[string]string{ "Authorization": "Basic dXNlcjpwYXNz", }) - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) assert.Equal(t, `Bearer error="invalid_request"`, resp.Header.Get("WWW-Authenticate")) @@ -53,7 +53,7 @@ func TestAuthVerify_EmptyBearerToken(t *testing.T) { resp := get(t, "/oauth2/token/verify", map[string]string{ "Authorization": "Bearer ", }) - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) assert.Equal(t, `Bearer error="invalid_request"`, resp.Header.Get("WWW-Authenticate")) @@ -65,7 +65,7 @@ func TestAuthVerify_InvalidToken(t *testing.T) { resp := get(t, "/oauth2/token/verify", map[string]string{ "Authorization": "Bearer not.a.valid.jwt", }) - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) assert.Equal(t, `Bearer error="invalid_token"`, resp.Header.Get("WWW-Authenticate")) @@ -80,7 +80,7 @@ func TestAuthVerify_ValidToken(t *testing.T) { resp := get(t, "/oauth2/token/verify", map[string]string{ "Authorization": "Bearer " + token, }) - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() assert.Equal(t, http.StatusOK, resp.StatusCode) assert.Equal(t, "application/json", resp.Header.Get("Content-Type")) @@ -101,13 +101,13 @@ func TestAuthVerify_RevokedToken(t *testing.T) { // Revoke the token. revokeResp := post(t, "/oauth2/token/revoke", map[string]string{"token": token}, nil) require.Equal(t, http.StatusOK, revokeResp.StatusCode) - revokeResp.Body.Close() + _ = revokeResp.Body.Close() // Verify must now reject it. resp := get(t, "/oauth2/token/verify", map[string]string{ "Authorization": "Bearer " + token, }) - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) assert.Equal(t, `Bearer error="invalid_token"`, resp.Header.Get("WWW-Authenticate")) @@ -135,7 +135,7 @@ func authVerifyHeaders(t *testing.T, token string) http.Header { resp := get(t, "/oauth2/token/verify", map[string]string{ "Authorization": "Bearer " + token, }) - resp.Body.Close() + _ = resp.Body.Close() require.Equal(t, http.StatusOK, resp.StatusCode) return resp.Header } diff --git a/tests/integration/authorization_code_test.go b/tests/integration/authorization_code_test.go index 5be9e95..ccd1064 100644 --- a/tests/integration/authorization_code_test.go +++ b/tests/integration/authorization_code_test.go @@ -368,3 +368,75 @@ func TestEnsureClientUpdatesConfig(t *testing.T) { grantTypes := found["grant_types"].([]any) assert.Equal(t, 2, len(grantTypes)) } + +// TestAuthorizationCodeReplayRejected verifies that an auth code can only be +// exchanged once (RFC 6749 §4.1.2). The second exchange must return invalid_grant. +func TestAuthorizationCodeReplayRejected(t *testing.T) { + verifier, challenge := buildPKCEPair(t) + code := buildAuthCode(t, testCLIClientID, "user-replay-001", testRedirectURI, challenge, []string{"data:read"}) + + payload := map[string]any{ + "grant_type": "authorization_code", + "client_id": testCLIClientID, + "code": code, + "code_verifier": verifier, + "redirect_uri": testRedirectURI, + } + + // First exchange — must succeed. + resp := post(t, "/oauth2/token", payload, nil) + require.Equal(t, http.StatusOK, resp.StatusCode) + firstToken := decode(t, resp) + assert.NotEmpty(t, firstToken["access_token"]) + + // Second exchange with the same code — must be rejected. + resp = post(t, "/oauth2/token", payload, nil) + assert.Equal(t, http.StatusBadRequest, resp.StatusCode) + body := decode(t, resp) + assert.Equal(t, "invalid_grant", body["error"]) +} + +// TestAuthorizationCodeReplayRevokesTokens verifies that when an auth code is +// replayed, the tokens issued during the first exchange are revoked per +// RFC 6749 §4.1.2. +func TestAuthorizationCodeReplayRevokesTokens(t *testing.T) { + verifier, challenge := buildPKCEPair(t) + code := buildAuthCode(t, testMCPClientID, "user-replay-revoke-001", testRedirectURI, challenge, []string{"data:read"}) + + payload := map[string]any{ + "grant_type": "authorization_code", + "client_id": testMCPClientID, + "code": code, + "code_verifier": verifier, + "redirect_uri": testRedirectURI, + } + + // First exchange — succeeds, issues access + refresh token. + resp := post(t, "/oauth2/token", payload, nil) + require.Equal(t, http.StatusOK, resp.StatusCode) + firstToken := decode(t, resp) + accessToken := firstToken["access_token"].(string) + refreshToken := firstToken["refresh_token"].(string) + assert.NotEmpty(t, accessToken) + assert.NotEmpty(t, refreshToken) + + // Access token must be active before replay. + result := introspect(t, accessToken) + assert.True(t, result["active"].(bool), "access token should be active before replay") + + // Replay the auth code — triggers revocation of the first exchange's tokens. + resp = post(t, "/oauth2/token", payload, nil) + assert.Equal(t, http.StatusBadRequest, resp.StatusCode) + + // Access token from the first exchange must now be revoked. + result = introspect(t, accessToken) + assert.False(t, result["active"].(bool), "access token should be revoked after auth code replay") + + // Refresh token from the first exchange must also be revoked. + resp = post(t, "/oauth2/token", map[string]any{ + "grant_type": "refresh_token", + "refresh_token": refreshToken, + "client_id": testMCPClientID, + }, nil) + assert.Equal(t, http.StatusBadRequest, resp.StatusCode, "refresh token should be revoked after auth code replay") +} diff --git a/tests/integration/cae_test.go b/tests/integration/cae_test.go index b83d296..3e21f37 100644 --- a/tests/integration/cae_test.go +++ b/tests/integration/cae_test.go @@ -45,7 +45,7 @@ func TestCAECriticalSignalRevokesCredential(t *testing.T) { "payload": map[string]any{"reason": "test revocation"}, }, adminHeaders()) require.Equal(t, http.StatusCreated, signalResp.StatusCode) - signalResp.Body.Close() + _ = signalResp.Body.Close() // Give the in-process revocation goroutine a moment to complete. time.Sleep(100 * time.Millisecond) @@ -82,7 +82,7 @@ func TestCAEHighSignalRevokesCredential(t *testing.T) { "payload": map[string]any{}, }, adminHeaders()) require.Equal(t, http.StatusCreated, signalResp.StatusCode) - signalResp.Body.Close() + _ = signalResp.Body.Close() time.Sleep(100 * time.Millisecond) @@ -117,7 +117,7 @@ func TestCAELowSignalDoesNotRevokeCredential(t *testing.T) { "payload": map[string]any{}, }, adminHeaders()) require.Equal(t, http.StatusCreated, signalResp.StatusCode) - signalResp.Body.Close() + _ = signalResp.Body.Close() time.Sleep(100 * time.Millisecond) @@ -164,7 +164,7 @@ func TestCAESignalRevokesAllActiveCredentials(t *testing.T) { "payload": map[string]any{}, }, adminHeaders()) require.Equal(t, http.StatusCreated, signalResp.StatusCode) - signalResp.Body.Close() + _ = signalResp.Body.Close() time.Sleep(100 * time.Millisecond) @@ -241,7 +241,7 @@ func TestCAESignalCascadesRevocationToChildren(t *testing.T) { "payload": map[string]any{"reason": "cascade revocation test"}, }, adminHeaders()) require.Equal(t, http.StatusCreated, signalResp.StatusCode) - signalResp.Body.Close() + _ = signalResp.Body.Close() time.Sleep(100 * time.Millisecond) @@ -281,7 +281,7 @@ func TestSignalListEndpoint(t *testing.T) { "payload": map[string]any{"old_ip": "1.2.3.4", "new_ip": "5.6.7.8"}, }, adminHeaders()) require.Equal(t, http.StatusCreated, signalResp.StatusCode) - signalResp.Body.Close() + _ = signalResp.Body.Close() listResp := get(t, adminPath("/signals"), adminHeaders()) require.Equal(t, http.StatusOK, listResp.StatusCode) diff --git a/tests/integration/helpers_test.go b/tests/integration/helpers_test.go index 67cefc8..32ebcb6 100644 --- a/tests/integration/helpers_test.go +++ b/tests/integration/helpers_test.go @@ -199,28 +199,28 @@ func writeKeyFiles(privKey *ecdsa.PrivateKey) (privPath, pubPath string, cleanup return "", "", nil, err } if err := pem.Encode(privFile, &pem.Block{Type: "EC PRIVATE KEY", Bytes: privDER}); err != nil { - privFile.Close() - os.Remove(privFile.Name()) + _ = privFile.Close() + _ = os.Remove(privFile.Name()) return "", "", nil, err } - privFile.Close() + _ = privFile.Close() pubFile, err := os.CreateTemp("", "zeroid-test-pub-*.pem") if err != nil { - os.Remove(privFile.Name()) + _ = os.Remove(privFile.Name()) return "", "", nil, err } if err := pem.Encode(pubFile, &pem.Block{Type: "PUBLIC KEY", Bytes: pubDER}); err != nil { - pubFile.Close() - os.Remove(privFile.Name()) - os.Remove(pubFile.Name()) + _ = pubFile.Close() + _ = os.Remove(privFile.Name()) + _ = os.Remove(pubFile.Name()) return "", "", nil, err } - pubFile.Close() + _ = pubFile.Close() return privFile.Name(), pubFile.Name(), func() { - os.Remove(privFile.Name()) - os.Remove(pubFile.Name()) + _ = os.Remove(privFile.Name()) + _ = os.Remove(pubFile.Name()) }, nil } @@ -260,7 +260,7 @@ func doRequest(t *testing.T, method, path string, body any, headers map[string]s // decode reads and JSON-decodes a response body, closing it after. func decode(t *testing.T, resp *http.Response) map[string]any { t.Helper() - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() var m map[string]any require.NoError(t, json.NewDecoder(resp.Body).Decode(&m)) return m @@ -403,28 +403,28 @@ func writeRSAKeyFiles(privKey *rsa.PrivateKey) (privPath, pubPath string, cleanu return "", "", nil, err } if err := pem.Encode(privFile, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: privDER}); err != nil { - privFile.Close() - os.Remove(privFile.Name()) + _ = privFile.Close() + _ = os.Remove(privFile.Name()) return "", "", nil, err } - privFile.Close() + _ = privFile.Close() pubFile, err := os.CreateTemp("", "zeroid-test-rsa-pub-*.pem") if err != nil { - os.Remove(privFile.Name()) + _ = os.Remove(privFile.Name()) return "", "", nil, err } if err := pem.Encode(pubFile, &pem.Block{Type: "PUBLIC KEY", Bytes: pubDER}); err != nil { - pubFile.Close() - os.Remove(privFile.Name()) - os.Remove(pubFile.Name()) + _ = pubFile.Close() + _ = os.Remove(privFile.Name()) + _ = os.Remove(pubFile.Name()) return "", "", nil, err } - pubFile.Close() + _ = pubFile.Close() return privFile.Name(), pubFile.Name(), func() { - os.Remove(privFile.Name()) - os.Remove(pubFile.Name()) + _ = os.Remove(privFile.Name()) + _ = os.Remove(pubFile.Name()) }, nil } diff --git a/tests/integration/identity_test.go b/tests/integration/identity_test.go index 4017551..3295e24 100644 --- a/tests/integration/identity_test.go +++ b/tests/integration/identity_test.go @@ -50,7 +50,7 @@ func TestRegisterIdentityDuplicateReturns409(t *testing.T) { "allowed_scopes": []string{"billing:read"}, }, adminHeaders()) assert.Equal(t, http.StatusConflict, resp.StatusCode) - resp.Body.Close() + _ = resp.Body.Close() } // TestRegisterIdentityMissingExternalID verifies that omitting external_id returns 400/422. @@ -64,7 +64,7 @@ func TestRegisterIdentityMissingExternalID(t *testing.T) { resp.StatusCode == http.StatusBadRequest || resp.StatusCode == http.StatusUnprocessableEntity, "expected 400 or 422 for missing external_id, got %d", resp.StatusCode, ) - resp.Body.Close() + _ = resp.Body.Close() } // TestGetIdentity verifies that GET /api/v1/identities/{id} returns the identity. @@ -85,7 +85,7 @@ func TestGetIdentity(t *testing.T) { func TestGetIdentityNotFound(t *testing.T) { resp := get(t, adminPath("/identities/00000000-0000-0000-0000-000000000000"), adminHeaders()) assert.Equal(t, http.StatusNotFound, resp.StatusCode) - resp.Body.Close() + _ = resp.Body.Close() } // TestListIdentities verifies that the list endpoint returns identities scoped to the tenant. @@ -183,7 +183,7 @@ func TestDeleteIdentity(t *testing.T) { resp, err := doRaw(t, http.MethodDelete, adminPath("/identities/"+identity.ID), nil, adminHeaders()) require.NoError(t, err) assert.Equal(t, http.StatusNoContent, resp.StatusCode) - resp.Body.Close() + _ = resp.Body.Close() } func TestServerGetIdentity(t *testing.T) { @@ -299,7 +299,7 @@ func TestListAgentsFilterByIsActive(t *testing.T) { deactivateResp, err := doRaw(t, http.MethodPost, adminPath("/agents/registry/"+agentID+"/deactivate"), nil, adminHeaders()) require.NoError(t, err) require.Equal(t, http.StatusOK, deactivateResp.StatusCode) - deactivateResp.Body.Close() + _ = deactivateResp.Body.Close() // is_active=true should exclude deactivated. resp = get(t, adminPath("/agents/registry?is_active=true&label=test:active-filter"), adminHeaders()) diff --git a/tests/integration/oauth_test.go b/tests/integration/oauth_test.go index e1dd4bf..2aa6cf0 100644 --- a/tests/integration/oauth_test.go +++ b/tests/integration/oauth_test.go @@ -49,7 +49,7 @@ func TestClientCredentialsFlow(t *testing.T) { // Agent: revoke the token (RFC 7009 — must return 200 regardless). revokeResp := post(t, "/oauth2/token/revoke", map[string]string{"token": accessToken}, nil) require.Equal(t, http.StatusOK, revokeResp.StatusCode) - revokeResp.Body.Close() + _ = revokeResp.Body.Close() // Introspect again: token must now be inactive. result = introspect(t, accessToken) @@ -91,7 +91,7 @@ func TestClientCredentialsWrongSecret(t *testing.T) { "client_secret": "wrong-secret", }, nil) assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) - resp.Body.Close() + _ = resp.Body.Close() } // TestJWTBearerFlow exercises the full RFC 7523 jwt_bearer flow: @@ -147,7 +147,7 @@ func TestJWTBearerWrongKey(t *testing.T) { "scope": "data:read", }, nil) assert.Equal(t, http.StatusBadRequest, resp.StatusCode) - resp.Body.Close() + _ = resp.Body.Close() } // TestTokenExchangeFlow exercises the full RFC 8693 token_exchange / agent delegation flow: @@ -248,7 +248,7 @@ func TestTokenExchangeScopeEnforcement(t *testing.T) { assert.NotContains(t, scope, "data:write", "data:write must not be granted when orchestrator lacks it") } else { assert.Equal(t, http.StatusBadRequest, resp.StatusCode) - resp.Body.Close() + _ = resp.Body.Close() } } @@ -273,7 +273,7 @@ func TestRevokedSubjectTokenCannotDelegate(t *testing.T) { // Revoke the orchestrator token. revokeResp := post(t, "/oauth2/token/revoke", map[string]string{"token": orchToken}, nil) require.Equal(t, http.StatusOK, revokeResp.StatusCode) - revokeResp.Body.Close() + _ = revokeResp.Body.Close() // Attempt token_exchange with the now-revoked token. subKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) @@ -289,7 +289,7 @@ func TestRevokedSubjectTokenCannotDelegate(t *testing.T) { "scope": "data:read", }, nil) assert.Equal(t, http.StatusBadRequest, resp.StatusCode, "revoked subject_token must be rejected") - resp.Body.Close() + _ = resp.Body.Close() } // TestRevokeMissingTokenReturns200 verifies RFC 7009 §2.2: revoke always returns 200 @@ -297,7 +297,7 @@ func TestRevokedSubjectTokenCannotDelegate(t *testing.T) { func TestRevokeMissingTokenReturns200(t *testing.T) { resp := post(t, "/oauth2/token/revoke", map[string]string{"token": "not-a-real-token"}, nil) assert.Equal(t, http.StatusOK, resp.StatusCode) - resp.Body.Close() + _ = resp.Body.Close() } // TestIntrospectUnknownToken verifies that introspecting an unknown token returns active:false. @@ -335,7 +335,7 @@ func TestAPIKeyGrant(t *testing.T) { // Revoke and confirm inactive. revokeResp := post(t, "/oauth2/token/revoke", map[string]string{"token": accessToken}, nil) require.Equal(t, http.StatusOK, revokeResp.StatusCode) - revokeResp.Body.Close() + _ = revokeResp.Body.Close() result = introspect(t, accessToken) assert.False(t, result["active"].(bool)) @@ -560,7 +560,7 @@ func TestRevokeTokenCascadesToChildren(t *testing.T) { // ── Revoke only the orchestrator token ─────────────────────────────────── revokeResp := post(t, "/oauth2/token/revoke", map[string]string{"token": orchToken}, nil) require.Equal(t, http.StatusOK, revokeResp.StatusCode) - revokeResp.Body.Close() + _ = revokeResp.Body.Close() // ── All three must now be inactive ─────────────────────────────────────── assert.False(t, introspect(t, orchToken)["active"].(bool), @@ -627,7 +627,7 @@ func TestRevokeMidChainDoesNotRevokeParent(t *testing.T) { // ── Revoke depth-1 only ────────────────────────────────────────────────── revokeResp := post(t, "/oauth2/token/revoke", map[string]string{"token": depth1Token}, nil) require.Equal(t, http.StatusOK, revokeResp.StatusCode) - revokeResp.Body.Close() + _ = revokeResp.Body.Close() // ── Parent (depth=0) must remain active ────────────────────────────────── assert.True(t, introspect(t, orchToken)["active"].(bool), @@ -680,7 +680,7 @@ func TestRevokeCascadesFanOut(t *testing.T) { // ── Revoke orchestrator ─────────────────────────────────────────────────── revokeResp := post(t, "/oauth2/token/revoke", map[string]string{"token": orchToken}, nil) require.Equal(t, http.StatusOK, revokeResp.StatusCode) - revokeResp.Body.Close() + _ = revokeResp.Body.Close() // ── All children must be inactive ──────────────────────────────────────── assert.False(t, introspect(t, orchToken)["active"].(bool), "orchestrator must be inactive") @@ -736,7 +736,7 @@ func TestRevokeDoesNotAffectSiblingChains(t *testing.T) { // Revoke chain A only. revokeResp := post(t, "/oauth2/token/revoke", map[string]string{"token": orchA}, nil) require.Equal(t, http.StatusOK, revokeResp.StatusCode) - revokeResp.Body.Close() + _ = revokeResp.Body.Close() // Chain A is gone. assert.False(t, introspect(t, orchA)["active"].(bool), "orch-A must be inactive") @@ -789,7 +789,7 @@ func TestRevokeDeepChain(t *testing.T) { // ── Revoke the root ─────────────────────────────────────────────────────── revokeResp := post(t, "/oauth2/token/revoke", map[string]string{"token": tokens[0]}, nil) require.Equal(t, http.StatusOK, revokeResp.StatusCode) - revokeResp.Body.Close() + _ = revokeResp.Body.Close() // ── All four must be inactive ───────────────────────────────────────────── for i, tok := range tokens { @@ -819,14 +819,14 @@ func TestRevokeIsIdempotent(t *testing.T) { // First revocation. r1 := post(t, "/oauth2/token/revoke", map[string]string{"token": token}, nil) assert.Equal(t, http.StatusOK, r1.StatusCode) - r1.Body.Close() + _ = r1.Body.Close() assert.False(t, introspect(t, token)["active"].(bool), "token must be inactive after first revocation") // Second revocation of the same token — must still return 200. r2 := post(t, "/oauth2/token/revoke", map[string]string{"token": token}, nil) assert.Equal(t, http.StatusOK, r2.StatusCode, "second revocation must return 200 per RFC 7009") - r2.Body.Close() + _ = r2.Body.Close() assert.False(t, introspect(t, token)["active"].(bool), "token must remain inactive after second revocation") } diff --git a/tests/integration/wellknown_test.go b/tests/integration/wellknown_test.go index d7dc82a..fd4dffb 100644 --- a/tests/integration/wellknown_test.go +++ b/tests/integration/wellknown_test.go @@ -68,12 +68,12 @@ func TestOAuthServerMetadata(t *testing.T) { func TestHealthEndpoint(t *testing.T) { resp := get(t, "/health", nil) assert.Equal(t, http.StatusOK, resp.StatusCode) - resp.Body.Close() + _ = resp.Body.Close() } // TestReadyEndpoint verifies that /ready returns 200 when the database is reachable. func TestReadyEndpoint(t *testing.T) { resp := get(t, "/ready", nil) assert.Equal(t, http.StatusOK, resp.StatusCode) - resp.Body.Close() + _ = resp.Body.Close() }