Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 25 additions & 2 deletions token.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package jwt
import (
"bytes"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
)
Expand Down Expand Up @@ -214,7 +215,7 @@ type HeaderValidator func(alg string, headerDecoded []byte) (Alg, PublicKey, Inj
// its serialized format.
func compareHeader(alg string, headerDecoded []byte) (Alg, PublicKey, InjectFunc, error) {
if n := len(headerDecoded); n < 25 /* 28 but allow custom short algs*/ {
if n == 15 { // header without "typ": "JWT".
if n == 15 { // header without "typ": "JWT" and no additional fields.
expectedHeader := createHeaderWithoutTyp(alg)
if bytes.Equal(expectedHeader, headerDecoded) {
return nil, nil, nil, nil
Expand All @@ -239,7 +240,29 @@ func compareHeader(alg string, headerDecoded []byte) (Alg, PublicKey, InjectFunc

expectedHeader := createHeaderRaw(alg)
if !bytes.Equal(expectedHeader, headerDecoded) {
return nil, nil, nil, ErrTokenAlg
// The fast path failed. Try JSON parsing to handle:
// 1. Headers with additional fields (like "kid")
// 2. Headers missing the "typ" field
// 3. Non-standard field ordering
var header map[string]interface{}
if err := json.Unmarshal(headerDecoded, &header); err != nil {
return nil, nil, nil, ErrTokenAlg
}

// Validate that the algorithm matches
if headerAlg, ok := header["alg"].(string); !ok || headerAlg != alg {
return nil, nil, nil, ErrTokenAlg
}

// If typ field is present, it must be "JWT"
if typ, exists := header["typ"]; exists {
if typStr, ok := typ.(string); !ok || typStr != "JWT" {
return nil, nil, nil, ErrTokenAlg
}
}

// Header is valid - algorithm matches and typ is either missing or "JWT"
return nil, nil, nil, nil
}

return nil, nil, nil, nil
Expand Down
95 changes: 95 additions & 0 deletions token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,3 +155,98 @@ func compareMap(m1, m2 map[string]any) bool {

return true
}

// This test verifies the fix for GitHub issue #13
func TestHeaderValidationWithAdditionalFields(t *testing.T) {
tests := []struct {
name string
header string
alg string
valid bool
}{
{
name: "Header with kid field and no typ (Cloudflare Zero Trust case)",
header: `{"kid":"test-key-id","alg":"ES384"}`,
alg: "ES384",
valid: true,
},
{
name: "Header with kid field and typ=JWT",
header: `{"kid":"test-key-id","alg":"ES384","typ":"JWT"}`,
alg: "ES384",
valid: true,
},
{
name: "Header with multiple additional fields",
header: `{"kid":"key1","iss":"example.com","alg":"RS256","jku":"https://example.com/jwks"}`,
alg: "RS256",
valid: true,
},
{
name: "Header with wrong algorithm",
header: `{"kid":"test-key","alg":"HS256"}`,
alg: "ES384",
valid: false,
},
{
name: "Header with invalid typ field",
header: `{"kid":"test-key","alg":"ES384","typ":"INVALID"}`,
alg: "ES384",
valid: false,
},
{
name: "Standard header (should still work)",
header: `{"alg":"ES384","typ":"JWT"}`,
alg: "ES384",
valid: true,
},
{
name: "Minimal header without typ (should still work)",
header: `{"alg":"ES384"}`,
alg: "ES384",
valid: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
headerBytes := []byte(tt.header)
_, _, _, err := compareHeader(tt.alg, headerBytes)

if tt.valid && err != nil {
t.Errorf("Expected header to be valid, but got error: %v", err)
}
if !tt.valid && err == nil {
t.Errorf("Expected header to be invalid, but validation passed")
}
})
}
}

func TestDecodeTokenWithKidField(t *testing.T) {
// Create a token with kid field in header (base64url encoded)
// Header: {"kid":"test-key-id","alg":"ES384"}
// Payload: {"sub":"123456","name":"John Doe","admin":true}
token := "eyJhbGciOiJFUzM4NCIsImtpZCI6InRlc3Qta2V5LWlkIn0.eyJzdWIiOiIxMjM0NTYiLCJuYW1lIjoiSm9obiBEb2UiLCJhZG1pbiI6dHJ1ZSwiaWF0IjoxNzUwMjA0MTEzfQ.QAfDHhbQgrejIol0fd4mmBMLU3i1Zn0yqN7ar41wRkuod7K5MbB0BjLxQHxB9PhuER9n7QuGSg8p45GDph4bjz17Z91MLwqlgMt0ws38O1MqxJ-gN9g0AyYzR86hTab5"

// This should not fail with "unexpected token algorithm" error
unverifiedToken, err := Decode([]byte(token))
if err != nil {
t.Fatalf("Should be able to decode token with kid field: %v", err)
}

if unverifiedToken == nil {
t.Fatal("Decoded token should not be nil")
}

// Verify we can extract claims
var claims map[string]interface{}
err = unverifiedToken.Claims(&claims)
if err != nil {
t.Fatalf("Should be able to extract claims: %v", err)
}

if claims["sub"] != "123456" {
t.Errorf("Expected sub claim to be '123456', got %v", claims["sub"])
}
}