diff --git a/internal/config/guard_policy.go b/internal/config/guard_policy.go index edebfa5b..cc33b9eb 100644 --- a/internal/config/guard_policy.go +++ b/internal/config/guard_policy.go @@ -3,7 +3,6 @@ package config import ( "encoding/json" "fmt" - "sort" "strings" "github.com/github/gh-aw-mcpg/internal/logger" @@ -201,85 +200,6 @@ func (p AllowOnlyPolicy) MarshalJSON() ([]byte, error) { return json.Marshal(serializedAllowOnly(p)) } -// ValidateGuardPolicy validates AllowOnly or WriteSink policy input. -func ValidateGuardPolicy(policy *GuardPolicy) error { - if policy == nil { - logGuardPolicy.Print("ValidateGuardPolicy: policy is nil") - return fmt.Errorf("policy must include allow-only or write-sink") - } - if policy.WriteSink != nil { - logGuardPolicy.Printf("ValidateGuardPolicy: delegating to write-sink validation, acceptCount=%d", len(policy.WriteSink.Accept)) - return ValidateWriteSinkPolicy(policy.WriteSink) - } - logGuardPolicy.Print("ValidateGuardPolicy: delegating to allow-only normalization") - _, err := NormalizeGuardPolicy(policy) - return err -} - -// ValidateWriteSinkPolicy validates a write-sink policy. -func ValidateWriteSinkPolicy(ws *WriteSinkPolicy) error { - if ws == nil { - return fmt.Errorf("write-sink policy must not be nil") - } - logGuardPolicy.Printf("ValidateWriteSinkPolicy: acceptCount=%d", len(ws.Accept)) - if len(ws.Accept) == 0 { - return fmt.Errorf("write-sink.accept must contain at least one entry") - } - // Special case: ["*"] is a valid wildcard that accepts all writes - if len(ws.Accept) == 1 && strings.TrimSpace(ws.Accept[0]) == "*" { - logGuardPolicy.Print("ValidateWriteSinkPolicy: wildcard accept, policy is valid") - return nil - } - seen := make(map[string]struct{}) - for _, entry := range ws.Accept { - entry = strings.TrimSpace(entry) - if entry == "" { - return fmt.Errorf("write-sink.accept entries must not be empty") - } - if entry == "*" { - return fmt.Errorf("write-sink.accept wildcard \"*\" must be the only entry") - } - if _, exists := seen[entry]; exists { - return fmt.Errorf("write-sink.accept must not contain duplicates") - } - seen[entry] = struct{}{} - if err := validateAcceptEntry(entry); err != nil { - return fmt.Errorf("write-sink.accept entry %q is invalid: %w", entry, err) - } - } - return nil -} - -// validateAcceptEntry validates a single accept entry. -// Accepted formats: -// - "visibility:owner/repo-pattern" (e.g., "private:github/gh-aw*") -// - "visibility:owner" (e.g., "private:myorg" — for owner-wildcard scopes) -// - "owner/repo-pattern" (e.g., "github/gh-aw*" — without visibility prefix) -// - "owner" (e.g., "myorg" — bare owner without visibility prefix) -// -// The accept entries must match the secrecy tags produced by the GitHub guard's -// label_agent. See WriteSinkAcceptRules for the mapping from allow-only repos -// to the required accept values. -func validateAcceptEntry(entry string) error { - scope := entry - if idx := strings.Index(entry, ":"); idx > 0 { - visibility := entry[:idx] - scope = entry[idx+1:] - validVisibility := map[string]bool{ - "private": true, "public": true, "internal": true, - } - if !validVisibility[visibility] { - return fmt.Errorf("visibility prefix must be private, public, or internal; got %q", visibility) - } - } - // Accept either "owner/repo-pattern" or bare "owner" (for owner-wildcard scopes - // where repos=["owner/*"] produces agent secrecy "private:owner") - if !isValidRepoScope(scope) && !isValidRepoOwner(scope) { - return fmt.Errorf("scope %q is invalid; expected owner, owner/*, owner/repo, or owner/re*", scope) - } - return nil -} - // WriteSinkAcceptRules documents the mapping from allow-only repos configuration // to the required write-sink accept values. // @@ -320,478 +240,3 @@ var WriteSinkAcceptRules = "see godoc" // exists for documentation only func (p *GuardPolicy) IsWriteSinkPolicy() bool { return p != nil && p.WriteSink != nil } - -// NormalizeGuardPolicy validates and normalizes an allow-only policy shape. -func NormalizeGuardPolicy(policy *GuardPolicy) (*NormalizedGuardPolicy, error) { - if policy == nil || (policy.AllowOnly == nil && policy.WriteSink == nil) { - return nil, fmt.Errorf("policy must include allow-only or write-sink") - } - if policy.AllowOnly == nil { - // Write-sink policies don't produce a NormalizedGuardPolicy - return nil, fmt.Errorf("policy must include allow-only") - } - - integrity := strings.ToLower(strings.TrimSpace(policy.AllowOnly.MinIntegrity)) - if _, ok := validMinIntegrityValues[integrity]; !ok { - return nil, fmt.Errorf("allow-only.min-integrity must be one of: none, unapproved, approved, merged") - } - - normalized := &NormalizedGuardPolicy{MinIntegrity: integrity} - - logGuardPolicy.Printf("Normalizing guard policy: integrity=%s, reposType=%T", integrity, policy.AllowOnly.Repos) - - // Validate and normalize blocked-users. - // Dedup uses lowercased keys to match Rust guard's case-insensitive comparison. - if len(policy.AllowOnly.BlockedUsers) > 0 { - seen := make(map[string]struct{}, len(policy.AllowOnly.BlockedUsers)) - for _, u := range policy.AllowOnly.BlockedUsers { - u = strings.TrimSpace(u) - if u == "" { - return nil, fmt.Errorf("allow-only.blocked-users entries must not be empty") - } - key := strings.ToLower(u) - if _, exists := seen[key]; !exists { - seen[key] = struct{}{} - normalized.BlockedUsers = append(normalized.BlockedUsers, u) - } - } - } - - // Validate and normalize approval-labels. - // Dedup uses lowercased keys to match Rust guard's case-insensitive comparison. - if len(policy.AllowOnly.ApprovalLabels) > 0 { - seen := make(map[string]struct{}, len(policy.AllowOnly.ApprovalLabels)) - for _, l := range policy.AllowOnly.ApprovalLabels { - l = strings.TrimSpace(l) - if l == "" { - return nil, fmt.Errorf("allow-only.approval-labels entries must not be empty") - } - key := strings.ToLower(l) - if _, exists := seen[key]; !exists { - seen[key] = struct{}{} - normalized.ApprovalLabels = append(normalized.ApprovalLabels, l) - } - } - } - - // Validate and normalize trusted-users. - // Dedup uses lowercased keys to match Rust guard's case-insensitive comparison. - if len(policy.AllowOnly.TrustedUsers) > 0 { - seen := make(map[string]struct{}, len(policy.AllowOnly.TrustedUsers)) - for _, u := range policy.AllowOnly.TrustedUsers { - u = strings.TrimSpace(u) - if u == "" { - return nil, fmt.Errorf("allow-only.trusted-users entries must not be empty") - } - key := strings.ToLower(u) - if _, exists := seen[key]; !exists { - seen[key] = struct{}{} - normalized.TrustedUsers = append(normalized.TrustedUsers, u) - } - } - } - - // Validate and normalize endorsement-reactions. - // Dedup uses uppercased keys to match the GraphQL ReactionContent enum. - if len(policy.AllowOnly.EndorsementReactions) > 0 { - seen := make(map[string]struct{}, len(policy.AllowOnly.EndorsementReactions)) - for _, r := range policy.AllowOnly.EndorsementReactions { - r = strings.TrimSpace(r) - if r == "" { - return nil, fmt.Errorf("allow-only.endorsement-reactions entries must not be empty") - } - key := strings.ToUpper(r) - if _, exists := seen[key]; !exists { - seen[key] = struct{}{} - normalized.EndorsementReactions = append(normalized.EndorsementReactions, key) - } - } - } - - // Validate and normalize disapproval-reactions. - if len(policy.AllowOnly.DisapprovalReactions) > 0 { - seen := make(map[string]struct{}, len(policy.AllowOnly.DisapprovalReactions)) - for _, r := range policy.AllowOnly.DisapprovalReactions { - r = strings.TrimSpace(r) - if r == "" { - return nil, fmt.Errorf("allow-only.disapproval-reactions entries must not be empty") - } - key := strings.ToUpper(r) - if _, exists := seen[key]; !exists { - seen[key] = struct{}{} - normalized.DisapprovalReactions = append(normalized.DisapprovalReactions, key) - } - } - } - - // Validate and normalize disapproval-integrity (optional; empty means feature - // uses Rust-side default of "none" when endorsement/disapproval is evaluated). - if v := strings.ToLower(strings.TrimSpace(policy.AllowOnly.DisapprovalIntegrity)); v != "" { - if _, ok := validMinIntegrityValues[v]; !ok { - return nil, fmt.Errorf("allow-only.disapproval-integrity must be one of: none, unapproved, approved, merged") - } - normalized.DisapprovalIntegrity = v - } - - // Validate and normalize endorser-min-integrity (optional; empty means feature - // uses Rust-side default of "approved" when evaluating reactor eligibility). - if v := strings.ToLower(strings.TrimSpace(policy.AllowOnly.EndorserMinIntegrity)); v != "" { - if _, ok := validMinIntegrityValues[v]; !ok { - return nil, fmt.Errorf("allow-only.endorser-min-integrity must be one of: none, unapproved, approved, merged") - } - normalized.EndorserMinIntegrity = v - } - - switch scope := policy.AllowOnly.Repos.(type) { - case string: - scopeValue := strings.ToLower(strings.TrimSpace(scope)) - if scopeValue != "all" && scopeValue != "public" { - return nil, fmt.Errorf("allow-only.repos string must be 'all' or 'public'") - } - normalized.ScopeKind = scopeValue - logGuardPolicy.Printf("Guard policy normalized: scopeKind=%s, integrity=%s", normalized.ScopeKind, normalized.MinIntegrity) - return normalized, nil - - case []interface{}: - scopes, err := normalizeAndValidateScopeArray(scope) - if err != nil { - return nil, err - } - normalized.ScopeKind = "scoped" - normalized.ScopeValues = scopes - logGuardPolicy.Printf("Guard policy normalized: scopeKind=scoped, scopeCount=%d, integrity=%s", len(scopes), normalized.MinIntegrity) - return normalized, nil - - case []string: - generic := make([]interface{}, len(scope)) - for i := range scope { - generic[i] = scope[i] - } - scopes, err := normalizeAndValidateScopeArray(generic) - if err != nil { - return nil, err - } - normalized.ScopeKind = "scoped" - normalized.ScopeValues = scopes - logGuardPolicy.Printf("Guard policy normalized: scopeKind=scoped, scopeCount=%d, integrity=%s", len(scopes), normalized.MinIntegrity) - return normalized, nil - - default: - return nil, fmt.Errorf("allow-only.repos must be 'all', 'public', or a non-empty array of repo scope strings") - } -} - -func normalizeAndValidateScopeArray(scopes []interface{}) ([]string, error) { - if len(scopes) == 0 { - return nil, fmt.Errorf("allow-only.repos array must contain at least one scope") - } - logGuardPolicy.Printf("normalizeAndValidateScopeArray: validating %d repo scope entries", len(scopes)) - - seen := make(map[string]struct{}, len(scopes)) - normalized := make([]string, 0, len(scopes)) - - for _, scopeValue := range scopes { - scopeString, ok := scopeValue.(string) - if !ok { - return nil, fmt.Errorf("allow-only.repos array values must be strings") - } - - scopeString = strings.TrimSpace(scopeString) - if scopeString == "" { - return nil, fmt.Errorf("allow-only.repos scope entries must not be empty") - } - - if !isValidRepoScope(scopeString) { - return nil, fmt.Errorf("allow-only.repos scope %q is invalid; expected owner/*, owner/repo, or owner/re*", scopeString) - } - - if _, exists := seen[scopeString]; exists { - return nil, fmt.Errorf("allow-only.repos must not contain duplicates") - } - seen[scopeString] = struct{}{} - normalized = append(normalized, scopeString) - } - - sort.Strings(normalized) - return normalized, nil -} - -func isValidRepoScope(scope string) bool { - parts := strings.Split(scope, "/") - if len(parts) != 2 { - return false - } - - owner := parts[0] - repoPart := parts[1] - - if !isValidRepoOwner(owner) { - return false - } - - if repoPart == "*" { - return true - } - - if strings.Count(repoPart, "*") > 1 { - return false - } - - isPrefixWildcard := strings.HasSuffix(repoPart, "*") - if strings.Contains(repoPart, "*") && !isPrefixWildcard { - return false - } - - repoName := repoPart - if isPrefixWildcard { - repoName = strings.TrimSuffix(repoPart, "*") - if repoName == "" { - return false - } - } - - if !isValidRepoName(repoName) { - return false - } - - if isPrefixWildcard && strings.HasSuffix(repoName, ".") { - return false - } - - return true -} - -// isValidTokenString returns true if s is a non-empty string of at most maxLen -// lowercase-alphanumeric, underscore, or hyphen characters. -func isValidTokenString(s string, maxLen int) bool { - if len(s) < 1 || len(s) > maxLen { - return false - } - for i := 0; i < len(s); i++ { - if !isScopeTokenChar(s[i]) { - return false - } - } - return true -} - -func isValidRepoOwner(owner string) bool { - return isValidTokenString(owner, 39) -} - -func isValidRepoName(repo string) bool { - return isValidTokenString(repo, 100) -} - -func isScopeTokenChar(char byte) bool { - return (char >= 'a' && char <= 'z') || (char >= '0' && char <= '9') || char == '_' || char == '-' -} - -// ParseServerGuardPolicy parses a guard policy from a server-specific raw policy map. -// It handles both the modern allow-only/write-sink format and the legacy repos/min-integrity format. -// The serverID is used to look for a server-keyed nested policy map. -func ParseServerGuardPolicy(serverID string, raw map[string]interface{}) (*GuardPolicy, error) { - logGuardPolicy.Printf("ParseServerGuardPolicy: serverID=%s, keyCount=%d", serverID, len(raw)) - if len(raw) == 0 { - return nil, nil - } - - if policy, err := ParsePolicyMap(raw); err != nil { - return nil, err - } else if policy != nil { - return policy, nil - } - - if nested, ok := raw[serverID]; ok { - nestedMap, ok := nested.(map[string]interface{}) - if !ok { - return nil, fmt.Errorf("invalid guard policy for server '%s': expected object", serverID) - } - if policy, err := ParsePolicyMap(nestedMap); err != nil { - return nil, err - } else if policy != nil { - return policy, nil - } - } - - if len(raw) == 1 { - for _, value := range raw { - nestedMap, ok := value.(map[string]interface{}) - if !ok { - continue - } - if policy, err := ParsePolicyMap(nestedMap); err != nil { - return nil, err - } else if policy != nil { - return policy, nil - } - } - } - - return nil, nil -} - -// ParsePolicyMap parses a GuardPolicy from a raw map using either the modern -// allow-only/write-sink format or the legacy repos/min-integrity format. -func ParsePolicyMap(raw map[string]interface{}) (*GuardPolicy, error) { - if len(raw) == 0 { - return nil, nil - } - - hasAllowOnly := false - if _, ok := raw["allow-only"]; ok { - hasAllowOnly = true - } else if _, ok := raw["allowonly"]; ok { // Accept legacy "allowonly" form for backward compatibility - hasAllowOnly = true - } - - hasWriteSink := false - if _, ok := raw["write-sink"]; ok { - hasWriteSink = true - } else if _, ok := raw["writesink"]; ok { - hasWriteSink = true - } - - logGuardPolicy.Printf("ParsePolicyMap: hasAllowOnly=%v, hasWriteSink=%v, keyCount=%d", hasAllowOnly, hasWriteSink, len(raw)) - - if hasAllowOnly || hasWriteSink { - policyBytes, err := json.Marshal(raw) - if err != nil { - return nil, fmt.Errorf("failed to serialize server guard policy: %w", err) - } - policy, err := ParseGuardPolicyJSON(string(policyBytes)) - if err != nil { - return nil, fmt.Errorf("invalid server guard policy: %w", err) - } - return policy, nil - } - - repos, hasRepos := raw["repos"] - if !hasRepos { - return nil, nil - } - - integrityValue, hasIntegrity := raw["min-integrity"] - if !hasIntegrity { - integrityValue, hasIntegrity = raw["integrity"] - } - if !hasIntegrity { - return nil, fmt.Errorf("invalid server guard policy: repos specified without min-integrity") - } - - policy := &GuardPolicy{ - AllowOnly: &AllowOnlyPolicy{ - Repos: repos, - MinIntegrity: fmt.Sprintf("%v", integrityValue), - }, - } - if err := ValidateGuardPolicy(policy); err != nil { - return nil, fmt.Errorf("invalid server guard policy: %w", err) - } - - return policy, nil -} - -// BuildAllowOnlyPolicy constructs an AllowOnly GuardPolicy from the provided parameters. -// Exactly one of public or owner must be set. If repo is set, owner must also be set. -// Returns nil, nil if no scope or integrity is specified (indicating no policy). -func BuildAllowOnlyPolicy(public bool, owner, repo, minIntegrity string) (*GuardPolicy, error) { - logGuardPolicy.Printf("Building AllowOnly policy: public=%v, owner=%q, repo=%q, minIntegrity=%q", public, owner, repo, minIntegrity) - - owner = strings.TrimSpace(owner) - repo = strings.TrimSpace(repo) - integrityInput := strings.TrimSpace(minIntegrity) - integrityKey := strings.ToLower(strings.ReplaceAll(integrityInput, "-", "")) - - integrityByInput := map[string]string{ - "none": IntegrityNone, - "unapproved": IntegrityUnapproved, - "approved": IntegrityApproved, - "merged": IntegrityMerged, - } - integrity, hasIntegrity := integrityByInput[integrityKey] - - scopeCount := 0 - if public { - scopeCount++ - } - if owner != "" { - scopeCount++ - } - if repo != "" && owner == "" { - return nil, fmt.Errorf("allow-only scope repo requires allow-only scope owner") - } - - if scopeCount == 0 && minIntegrity == "" { - logGuardPolicy.Print("No AllowOnly scope or integrity specified, returning nil policy") - return nil, nil - } - if scopeCount != 1 { - return nil, fmt.Errorf("exactly one AllowOnly scope variant must be set (public or owner[/repo])") - } - if integrityInput == "" { - return nil, fmt.Errorf("min-integrity is required") - } - if !hasIntegrity { - return nil, fmt.Errorf("min-integrity must be one of: none, unapproved, approved, merged") - } - - var repos interface{} - if public { - repos = "public" - } else { - scope := owner + "/*" - if repo != "" { - scope = owner + "/" + repo - } - repos = []string{scope} - } - - logGuardPolicy.Printf("AllowOnly policy scope resolved: repos=%v, minIntegrity=%s", repos, integrity) - - policy := &GuardPolicy{ - AllowOnly: &AllowOnlyPolicy{ - Repos: repos, - MinIntegrity: integrity, - }, - } - - if err := ValidateGuardPolicy(policy); err != nil { - return nil, err - } - - logGuardPolicy.Print("AllowOnly policy built and validated successfully") - return policy, nil -} - -// ParseGuardPolicyJSON parses policy JSON and validates it. -func ParseGuardPolicyJSON(policyJSON string) (*GuardPolicy, error) { - logGuardPolicy.Printf("Parsing guard policy JSON: len=%d", len(policyJSON)) - policy := &GuardPolicy{} - if err := json.Unmarshal([]byte(policyJSON), policy); err != nil { - return nil, fmt.Errorf("invalid guard policy JSON: %w", err) - } - if err := ValidateGuardPolicy(policy); err != nil { - return nil, err - } - return policy, nil -} - -// NormalizeScopeKind returns a copy of the policy map with the scope_kind field -// normalized to lowercase trimmed string form. Other fields are preserved as-is. -func NormalizeScopeKind(policy map[string]interface{}) map[string]interface{} { - if policy == nil { - return nil - } - - normalized := make(map[string]interface{}, len(policy)) - for key, value := range policy { - normalized[key] = value - } - - if scopeKind, ok := normalized["scope_kind"].(string); ok { - normalized["scope_kind"] = strings.ToLower(strings.TrimSpace(scopeKind)) - } - - return normalized -} diff --git a/internal/config/guard_policy_parse.go b/internal/config/guard_policy_parse.go new file mode 100644 index 00000000..635a37ea --- /dev/null +++ b/internal/config/guard_policy_parse.go @@ -0,0 +1,216 @@ +package config + +import ( + "encoding/json" + "fmt" + "strings" +) + +// ParseServerGuardPolicy parses a guard policy from a server-specific raw policy map. +// It handles both the modern allow-only/write-sink format and the legacy repos/min-integrity format. +// The serverID is used to look for a server-keyed nested policy map. +func ParseServerGuardPolicy(serverID string, raw map[string]interface{}) (*GuardPolicy, error) { + logGuardPolicy.Printf("ParseServerGuardPolicy: serverID=%s, keyCount=%d", serverID, len(raw)) + if len(raw) == 0 { + return nil, nil + } + + if policy, err := ParsePolicyMap(raw); err != nil { + return nil, err + } else if policy != nil { + return policy, nil + } + + if nested, ok := raw[serverID]; ok { + nestedMap, ok := nested.(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("invalid guard policy for server '%s': expected object", serverID) + } + if policy, err := ParsePolicyMap(nestedMap); err != nil { + return nil, err + } else if policy != nil { + return policy, nil + } + } + + if len(raw) == 1 { + for _, value := range raw { + nestedMap, ok := value.(map[string]interface{}) + if !ok { + continue + } + if policy, err := ParsePolicyMap(nestedMap); err != nil { + return nil, err + } else if policy != nil { + return policy, nil + } + } + } + + return nil, nil +} + +// ParsePolicyMap parses a GuardPolicy from a raw map using either the modern +// allow-only/write-sink format or the legacy repos/min-integrity format. +func ParsePolicyMap(raw map[string]interface{}) (*GuardPolicy, error) { + if len(raw) == 0 { + return nil, nil + } + + hasAllowOnly := false + if _, ok := raw["allow-only"]; ok { + hasAllowOnly = true + } else if _, ok := raw["allowonly"]; ok { // Accept legacy "allowonly" form for backward compatibility + hasAllowOnly = true + } + + hasWriteSink := false + if _, ok := raw["write-sink"]; ok { + hasWriteSink = true + } else if _, ok := raw["writesink"]; ok { + hasWriteSink = true + } + + logGuardPolicy.Printf("ParsePolicyMap: hasAllowOnly=%v, hasWriteSink=%v, keyCount=%d", hasAllowOnly, hasWriteSink, len(raw)) + + if hasAllowOnly || hasWriteSink { + policyBytes, err := json.Marshal(raw) + if err != nil { + return nil, fmt.Errorf("failed to serialize server guard policy: %w", err) + } + policy, err := ParseGuardPolicyJSON(string(policyBytes)) + if err != nil { + return nil, fmt.Errorf("invalid server guard policy: %w", err) + } + return policy, nil + } + + repos, hasRepos := raw["repos"] + if !hasRepos { + return nil, nil + } + + integrityValue, hasIntegrity := raw["min-integrity"] + if !hasIntegrity { + integrityValue, hasIntegrity = raw["integrity"] + } + if !hasIntegrity { + return nil, fmt.Errorf("invalid server guard policy: repos specified without min-integrity") + } + + policy := &GuardPolicy{ + AllowOnly: &AllowOnlyPolicy{ + Repos: repos, + MinIntegrity: fmt.Sprintf("%v", integrityValue), + }, + } + if err := ValidateGuardPolicy(policy); err != nil { + return nil, fmt.Errorf("invalid server guard policy: %w", err) + } + + return policy, nil +} + +// BuildAllowOnlyPolicy constructs an AllowOnly GuardPolicy from the provided parameters. +// Exactly one of public or owner must be set. If repo is set, owner must also be set. +// Returns nil, nil if no scope or integrity is specified (indicating no policy). +func BuildAllowOnlyPolicy(public bool, owner, repo, minIntegrity string) (*GuardPolicy, error) { + logGuardPolicy.Printf("Building AllowOnly policy: public=%v, owner=%q, repo=%q, minIntegrity=%q", public, owner, repo, minIntegrity) + + owner = strings.TrimSpace(owner) + repo = strings.TrimSpace(repo) + integrityInput := strings.TrimSpace(minIntegrity) + integrityKey := strings.ToLower(strings.ReplaceAll(integrityInput, "-", "")) + + integrityByInput := map[string]string{ + "none": IntegrityNone, + "unapproved": IntegrityUnapproved, + "approved": IntegrityApproved, + "merged": IntegrityMerged, + } + integrity, hasIntegrity := integrityByInput[integrityKey] + + scopeCount := 0 + if public { + scopeCount++ + } + if owner != "" { + scopeCount++ + } + if repo != "" && owner == "" { + return nil, fmt.Errorf("allow-only scope repo requires allow-only scope owner") + } + + if scopeCount == 0 && minIntegrity == "" { + logGuardPolicy.Print("No AllowOnly scope or integrity specified, returning nil policy") + return nil, nil + } + if scopeCount != 1 { + return nil, fmt.Errorf("exactly one AllowOnly scope variant must be set (public or owner[/repo])") + } + if integrityInput == "" { + return nil, fmt.Errorf("min-integrity is required") + } + if !hasIntegrity { + return nil, fmt.Errorf("min-integrity must be one of: none, unapproved, approved, merged") + } + + var repos interface{} + if public { + repos = "public" + } else { + scope := owner + "/*" + if repo != "" { + scope = owner + "/" + repo + } + repos = []string{scope} + } + + logGuardPolicy.Printf("AllowOnly policy scope resolved: repos=%v, minIntegrity=%s", repos, integrity) + + policy := &GuardPolicy{ + AllowOnly: &AllowOnlyPolicy{ + Repos: repos, + MinIntegrity: integrity, + }, + } + + if err := ValidateGuardPolicy(policy); err != nil { + return nil, err + } + + logGuardPolicy.Print("AllowOnly policy built and validated successfully") + return policy, nil +} + +// ParseGuardPolicyJSON parses policy JSON and validates it. +func ParseGuardPolicyJSON(policyJSON string) (*GuardPolicy, error) { + logGuardPolicy.Printf("Parsing guard policy JSON: len=%d", len(policyJSON)) + policy := &GuardPolicy{} + if err := json.Unmarshal([]byte(policyJSON), policy); err != nil { + return nil, fmt.Errorf("invalid guard policy JSON: %w", err) + } + if err := ValidateGuardPolicy(policy); err != nil { + return nil, err + } + return policy, nil +} + +// NormalizeScopeKind returns a copy of the policy map with the scope_kind field +// normalized to lowercase trimmed string form. Other fields are preserved as-is. +func NormalizeScopeKind(policy map[string]interface{}) map[string]interface{} { + if policy == nil { + return nil + } + + normalized := make(map[string]interface{}, len(policy)) + for key, value := range policy { + normalized[key] = value + } + + if scopeKind, ok := normalized["scope_kind"].(string); ok { + normalized["scope_kind"] = strings.ToLower(strings.TrimSpace(scopeKind)) + } + + return normalized +} diff --git a/internal/config/guard_policy_validation.go b/internal/config/guard_policy_validation.go new file mode 100644 index 00000000..87969af7 --- /dev/null +++ b/internal/config/guard_policy_validation.go @@ -0,0 +1,352 @@ +package config + +import ( + "fmt" + "sort" + "strings" +) + +// ValidateGuardPolicy validates AllowOnly or WriteSink policy input. +func ValidateGuardPolicy(policy *GuardPolicy) error { + if policy == nil { + logGuardPolicy.Print("ValidateGuardPolicy: policy is nil") + return fmt.Errorf("policy must include allow-only or write-sink") + } + if policy.WriteSink != nil { + logGuardPolicy.Printf("ValidateGuardPolicy: delegating to write-sink validation, acceptCount=%d", len(policy.WriteSink.Accept)) + return ValidateWriteSinkPolicy(policy.WriteSink) + } + logGuardPolicy.Print("ValidateGuardPolicy: delegating to allow-only normalization") + _, err := NormalizeGuardPolicy(policy) + return err +} + +// ValidateWriteSinkPolicy validates a write-sink policy. +func ValidateWriteSinkPolicy(ws *WriteSinkPolicy) error { + if ws == nil { + return fmt.Errorf("write-sink policy must not be nil") + } + logGuardPolicy.Printf("ValidateWriteSinkPolicy: acceptCount=%d", len(ws.Accept)) + if len(ws.Accept) == 0 { + return fmt.Errorf("write-sink.accept must contain at least one entry") + } + // Special case: ["*"] is a valid wildcard that accepts all writes + if len(ws.Accept) == 1 && strings.TrimSpace(ws.Accept[0]) == "*" { + logGuardPolicy.Print("ValidateWriteSinkPolicy: wildcard accept, policy is valid") + return nil + } + seen := make(map[string]struct{}) + for _, entry := range ws.Accept { + entry = strings.TrimSpace(entry) + if entry == "" { + return fmt.Errorf("write-sink.accept entries must not be empty") + } + if entry == "*" { + return fmt.Errorf("write-sink.accept wildcard \"*\" must be the only entry") + } + if _, exists := seen[entry]; exists { + return fmt.Errorf("write-sink.accept must not contain duplicates") + } + seen[entry] = struct{}{} + if err := validateAcceptEntry(entry); err != nil { + return fmt.Errorf("write-sink.accept entry %q is invalid: %w", entry, err) + } + } + return nil +} + +// validateAcceptEntry validates a single accept entry. +// Accepted formats: +// - "visibility:owner/repo-pattern" (e.g., "private:github/gh-aw*") +// - "visibility:owner" (e.g., "private:myorg" — for owner-wildcard scopes) +// - "owner/repo-pattern" (e.g., "github/gh-aw*" — without visibility prefix) +// - "owner" (e.g., "myorg" — bare owner without visibility prefix) +// +// The accept entries must match the secrecy tags produced by the GitHub guard's +// label_agent. See WriteSinkAcceptRules for the mapping from allow-only repos +// to the required accept values. +func validateAcceptEntry(entry string) error { + scope := entry + if idx := strings.Index(entry, ":"); idx > 0 { + visibility := entry[:idx] + scope = entry[idx+1:] + validVisibility := map[string]bool{ + "private": true, "public": true, "internal": true, + } + if !validVisibility[visibility] { + return fmt.Errorf("visibility prefix must be private, public, or internal; got %q", visibility) + } + } + // Accept either "owner/repo-pattern" or bare "owner" (for owner-wildcard scopes + // where repos=["owner/*"] produces agent secrecy "private:owner") + if !isValidRepoScope(scope) && !isValidRepoOwner(scope) { + return fmt.Errorf("scope %q is invalid; expected owner, owner/*, owner/repo, or owner/re*", scope) + } + return nil +} + +// NormalizeGuardPolicy validates and normalizes an allow-only policy shape. +func NormalizeGuardPolicy(policy *GuardPolicy) (*NormalizedGuardPolicy, error) { + if policy == nil || (policy.AllowOnly == nil && policy.WriteSink == nil) { + return nil, fmt.Errorf("policy must include allow-only or write-sink") + } + if policy.AllowOnly == nil { + // Write-sink policies don't produce a NormalizedGuardPolicy + return nil, fmt.Errorf("policy must include allow-only") + } + + integrity := strings.ToLower(strings.TrimSpace(policy.AllowOnly.MinIntegrity)) + if _, ok := validMinIntegrityValues[integrity]; !ok { + return nil, fmt.Errorf("allow-only.min-integrity must be one of: none, unapproved, approved, merged") + } + + normalized := &NormalizedGuardPolicy{MinIntegrity: integrity} + + logGuardPolicy.Printf("Normalizing guard policy: integrity=%s, reposType=%T", integrity, policy.AllowOnly.Repos) + + // Validate and normalize blocked-users. + // Dedup uses lowercased keys to match Rust guard's case-insensitive comparison. + if len(policy.AllowOnly.BlockedUsers) > 0 { + seen := make(map[string]struct{}, len(policy.AllowOnly.BlockedUsers)) + for _, u := range policy.AllowOnly.BlockedUsers { + u = strings.TrimSpace(u) + if u == "" { + return nil, fmt.Errorf("allow-only.blocked-users entries must not be empty") + } + key := strings.ToLower(u) + if _, exists := seen[key]; !exists { + seen[key] = struct{}{} + normalized.BlockedUsers = append(normalized.BlockedUsers, u) + } + } + } + + // Validate and normalize approval-labels. + // Dedup uses lowercased keys to match Rust guard's case-insensitive comparison. + if len(policy.AllowOnly.ApprovalLabels) > 0 { + seen := make(map[string]struct{}, len(policy.AllowOnly.ApprovalLabels)) + for _, l := range policy.AllowOnly.ApprovalLabels { + l = strings.TrimSpace(l) + if l == "" { + return nil, fmt.Errorf("allow-only.approval-labels entries must not be empty") + } + key := strings.ToLower(l) + if _, exists := seen[key]; !exists { + seen[key] = struct{}{} + normalized.ApprovalLabels = append(normalized.ApprovalLabels, l) + } + } + } + + // Validate and normalize trusted-users. + // Dedup uses lowercased keys to match Rust guard's case-insensitive comparison. + if len(policy.AllowOnly.TrustedUsers) > 0 { + seen := make(map[string]struct{}, len(policy.AllowOnly.TrustedUsers)) + for _, u := range policy.AllowOnly.TrustedUsers { + u = strings.TrimSpace(u) + if u == "" { + return nil, fmt.Errorf("allow-only.trusted-users entries must not be empty") + } + key := strings.ToLower(u) + if _, exists := seen[key]; !exists { + seen[key] = struct{}{} + normalized.TrustedUsers = append(normalized.TrustedUsers, u) + } + } + } + + // Validate and normalize endorsement-reactions. + // Dedup uses uppercased keys to match the GraphQL ReactionContent enum. + if len(policy.AllowOnly.EndorsementReactions) > 0 { + seen := make(map[string]struct{}, len(policy.AllowOnly.EndorsementReactions)) + for _, r := range policy.AllowOnly.EndorsementReactions { + r = strings.TrimSpace(r) + if r == "" { + return nil, fmt.Errorf("allow-only.endorsement-reactions entries must not be empty") + } + key := strings.ToUpper(r) + if _, exists := seen[key]; !exists { + seen[key] = struct{}{} + normalized.EndorsementReactions = append(normalized.EndorsementReactions, key) + } + } + } + + // Validate and normalize disapproval-reactions. + if len(policy.AllowOnly.DisapprovalReactions) > 0 { + seen := make(map[string]struct{}, len(policy.AllowOnly.DisapprovalReactions)) + for _, r := range policy.AllowOnly.DisapprovalReactions { + r = strings.TrimSpace(r) + if r == "" { + return nil, fmt.Errorf("allow-only.disapproval-reactions entries must not be empty") + } + key := strings.ToUpper(r) + if _, exists := seen[key]; !exists { + seen[key] = struct{}{} + normalized.DisapprovalReactions = append(normalized.DisapprovalReactions, key) + } + } + } + + // Validate and normalize disapproval-integrity (optional; empty means feature + // uses Rust-side default of "none" when endorsement/disapproval is evaluated). + if v := strings.ToLower(strings.TrimSpace(policy.AllowOnly.DisapprovalIntegrity)); v != "" { + if _, ok := validMinIntegrityValues[v]; !ok { + return nil, fmt.Errorf("allow-only.disapproval-integrity must be one of: none, unapproved, approved, merged") + } + normalized.DisapprovalIntegrity = v + } + + // Validate and normalize endorser-min-integrity (optional; empty means feature + // uses Rust-side default of "approved" when evaluating reactor eligibility). + if v := strings.ToLower(strings.TrimSpace(policy.AllowOnly.EndorserMinIntegrity)); v != "" { + if _, ok := validMinIntegrityValues[v]; !ok { + return nil, fmt.Errorf("allow-only.endorser-min-integrity must be one of: none, unapproved, approved, merged") + } + normalized.EndorserMinIntegrity = v + } + + switch scope := policy.AllowOnly.Repos.(type) { + case string: + scopeValue := strings.ToLower(strings.TrimSpace(scope)) + if scopeValue != "all" && scopeValue != "public" { + return nil, fmt.Errorf("allow-only.repos string must be 'all' or 'public'") + } + normalized.ScopeKind = scopeValue + logGuardPolicy.Printf("Guard policy normalized: scopeKind=%s, integrity=%s", normalized.ScopeKind, normalized.MinIntegrity) + return normalized, nil + + case []interface{}: + scopes, err := normalizeAndValidateScopeArray(scope) + if err != nil { + return nil, err + } + normalized.ScopeKind = "scoped" + normalized.ScopeValues = scopes + logGuardPolicy.Printf("Guard policy normalized: scopeKind=scoped, scopeCount=%d, integrity=%s", len(scopes), normalized.MinIntegrity) + return normalized, nil + + case []string: + generic := make([]interface{}, len(scope)) + for i := range scope { + generic[i] = scope[i] + } + scopes, err := normalizeAndValidateScopeArray(generic) + if err != nil { + return nil, err + } + normalized.ScopeKind = "scoped" + normalized.ScopeValues = scopes + logGuardPolicy.Printf("Guard policy normalized: scopeKind=scoped, scopeCount=%d, integrity=%s", len(scopes), normalized.MinIntegrity) + return normalized, nil + + default: + return nil, fmt.Errorf("allow-only.repos must be 'all', 'public', or a non-empty array of repo scope strings") + } +} + +func normalizeAndValidateScopeArray(scopes []interface{}) ([]string, error) { + if len(scopes) == 0 { + return nil, fmt.Errorf("allow-only.repos array must contain at least one scope") + } + logGuardPolicy.Printf("normalizeAndValidateScopeArray: validating %d repo scope entries", len(scopes)) + + seen := make(map[string]struct{}, len(scopes)) + normalized := make([]string, 0, len(scopes)) + + for _, scopeValue := range scopes { + scopeString, ok := scopeValue.(string) + if !ok { + return nil, fmt.Errorf("allow-only.repos array values must be strings") + } + + scopeString = strings.TrimSpace(scopeString) + if scopeString == "" { + return nil, fmt.Errorf("allow-only.repos scope entries must not be empty") + } + + if !isValidRepoScope(scopeString) { + return nil, fmt.Errorf("allow-only.repos scope %q is invalid; expected owner/*, owner/repo, or owner/re*", scopeString) + } + + if _, exists := seen[scopeString]; exists { + return nil, fmt.Errorf("allow-only.repos must not contain duplicates") + } + seen[scopeString] = struct{}{} + normalized = append(normalized, scopeString) + } + + sort.Strings(normalized) + return normalized, nil +} + +func isValidRepoScope(scope string) bool { + parts := strings.Split(scope, "/") + if len(parts) != 2 { + return false + } + + owner := parts[0] + repoPart := parts[1] + + if !isValidRepoOwner(owner) { + return false + } + + if repoPart == "*" { + return true + } + + if strings.Count(repoPart, "*") > 1 { + return false + } + + isPrefixWildcard := strings.HasSuffix(repoPart, "*") + if strings.Contains(repoPart, "*") && !isPrefixWildcard { + return false + } + + repoName := repoPart + if isPrefixWildcard { + repoName = strings.TrimSuffix(repoPart, "*") + if repoName == "" { + return false + } + } + + if !isValidRepoName(repoName) { + return false + } + + if isPrefixWildcard && strings.HasSuffix(repoName, ".") { + return false + } + + return true +} + +// isValidTokenString returns true if s is a non-empty string of at most maxLen +// lowercase-alphanumeric, underscore, or hyphen characters. +func isValidTokenString(s string, maxLen int) bool { + if len(s) < 1 || len(s) > maxLen { + return false + } + for i := 0; i < len(s); i++ { + if !isScopeTokenChar(s[i]) { + return false + } + } + return true +} + +func isValidRepoOwner(owner string) bool { + return isValidTokenString(owner, 39) +} + +func isValidRepoName(repo string) bool { + return isValidTokenString(repo, 100) +} + +func isScopeTokenChar(char byte) bool { + return (char >= 'a' && char <= 'z') || (char >= '0' && char <= '9') || char == '_' || char == '-' +} diff --git a/internal/guard/wasm.go b/internal/guard/wasm.go index e138ba99..62020384 100644 --- a/internal/guard/wasm.go +++ b/internal/guard/wasm.go @@ -15,7 +15,6 @@ import ( "github.com/tetratelabs/wazero" "github.com/tetratelabs/wazero/api" "github.com/tetratelabs/wazero/imports/wasi_snapshot_preview1" - "github.com/tetratelabs/wazero/sys" ) var logWasm = logger.New("guard:wasm") @@ -333,338 +332,6 @@ func (g *WasmGuard) Name() string { return g.name } -func normalizePolicyPayload(policy interface{}) (interface{}, error) { - if policy == nil { - return nil, fmt.Errorf("policy is required") - } - - if policyString, ok := policy.(string); ok { - trimmed := strings.TrimSpace(policyString) - if trimmed == "" { - return nil, fmt.Errorf("policy string is empty") - } - - var parsed interface{} - if err := json.Unmarshal([]byte(trimmed), &parsed); err != nil { - return nil, fmt.Errorf("policy string is not valid JSON object: %w", err) - } - - switch parsed.(type) { - case map[string]interface{}: - return parsed, nil - default: - return nil, fmt.Errorf("policy JSON must decode to an object") - } - } - - return policy, nil -} - -func buildStrictLabelAgentPayload(policy interface{}) (map[string]interface{}, error) { - if policy == nil { - return nil, fmt.Errorf("invalid guard policy transport shape: expected {\"allow-only\":{\"repos\":...,\"min-integrity\":...}}") - } - - if policyMap, ok := policy.(map[string]interface{}); ok { - if nested, hasPolicy := policyMap["policy"]; hasPolicy { - if nestedMap, nestedOK := nested.(map[string]interface{}); nestedOK { - if _, hasAllowOnly := nestedMap["allow-only"]; hasAllowOnly { - return nil, fmt.Errorf("gateway policy adapter is outdated: remove legacy envelope key policy before calling label_agent") - } - } - } - } - - payload, err := PolicyToMap(policy) - if err != nil { - return nil, fmt.Errorf("failed to decode label_agent policy payload: %w", err) - } - - if _, hasPolicyEnvelope := payload["policy"]; hasPolicyEnvelope { - return nil, fmt.Errorf("gateway policy adapter is outdated: remove legacy envelope key policy before calling label_agent") - } - - allowOnlyRaw, ok := payload["allow-only"] - if !ok { - // Accept legacy "allowonly" form for backward compatibility - allowOnlyRaw, ok = payload["allowonly"] - } - if !ok { - return nil, fmt.Errorf("label_agent policy must use top-level allow-only object (received policy.allow-only)") - } - - // Validate that the only allowed top-level keys are "allow-only" (or legacy "allowonly") - // and the optional "trusted-bots" key. - for k := range payload { - switch k { - case "allow-only", "allowonly", "trusted-bots": - // valid top-level keys - default: - return nil, fmt.Errorf("invalid guard policy transport shape: unexpected key %q", k) - } - } - - allowOnly, ok := allowOnlyRaw.(map[string]interface{}) - if !ok { - return nil, fmt.Errorf("invalid guard policy transport shape: expected {\"allow-only\":{\"repos\":...,\"min-integrity\":...}}") - } - - reposRaw, hasRepos := allowOnly["repos"] - integrityRaw, hasIntegrity := allowOnly["min-integrity"] - if !hasIntegrity { - integrityRaw, hasIntegrity = allowOnly["integrity"] - } - if !hasRepos || !hasIntegrity { - return nil, fmt.Errorf("invalid guard policy transport shape: missing required fields repos and/or min-integrity in allow-only object") - } - - // Validate that the allow-only object contains only known keys. - for k := range allowOnly { - switch k { - case "repos", "min-integrity", "integrity", "blocked-users", "approval-labels", "trusted-users", - "endorsement-reactions", "disapproval-reactions", "disapproval-integrity", "endorser-min-integrity": - // valid allow-only keys - default: - return nil, fmt.Errorf("invalid guard policy transport shape: unexpected allow-only key %q", k) - } - } - - if !isValidAllowOnlyRepos(reposRaw) { - return nil, fmt.Errorf("invalid repos value: expected all, public, or non-empty array of scoped strings") - } - - integrity, ok := integrityRaw.(string) - if !ok { - return nil, fmt.Errorf("invalid integrity value: expected one of none|unapproved|approved|merged") - } - - switch strings.ToLower(strings.TrimSpace(integrity)) { - case "none", "unapproved", "approved", "merged": - default: - return nil, fmt.Errorf("invalid integrity value: expected one of none|unapproved|approved|merged") - } - - // Validate blocked-users if present: must be a non-empty array of non-empty strings. - if blockedUsersRaw, ok := allowOnly["blocked-users"]; ok { - arr, ok := blockedUsersRaw.([]interface{}) - if !ok { - return nil, fmt.Errorf("invalid blocked-users value: expected array of strings") - } - for _, entry := range arr { - if s, ok := entry.(string); !ok || strings.TrimSpace(s) == "" { - return nil, fmt.Errorf("invalid blocked-users value: each entry must be a non-empty string") - } - } - } - - // Validate approval-labels if present: must be a non-empty array of non-empty strings. - if approvalLabelsRaw, ok := allowOnly["approval-labels"]; ok { - arr, ok := approvalLabelsRaw.([]interface{}) - if !ok { - return nil, fmt.Errorf("invalid approval-labels value: expected array of strings") - } - for _, entry := range arr { - if s, ok := entry.(string); !ok || strings.TrimSpace(s) == "" { - return nil, fmt.Errorf("invalid approval-labels value: each entry must be a non-empty string") - } - } - } - - // Validate trusted-bots if present. - // Per spec §4.1.3.4: trustedBots MUST be a non-empty array of strings when present. - if trustedBotsRaw, hasTrustedBots := payload["trusted-bots"]; hasTrustedBots { - trustedBots, ok := trustedBotsRaw.([]interface{}) - if !ok { - return nil, fmt.Errorf("invalid trusted-bots value: expected non-empty array of strings") - } - if len(trustedBots) == 0 { - return nil, fmt.Errorf("invalid trusted-bots value: must be a non-empty array when present") - } - for _, entry := range trustedBots { - if s, ok := entry.(string); !ok || strings.TrimSpace(s) == "" { - return nil, fmt.Errorf("invalid trusted-bots value: each entry must be a non-empty string") - } - } - } - - // Validate trusted-users if present inside allow-only. - // Must be a non-empty array of non-empty strings when present. - if trustedUsersRaw, ok := allowOnly["trusted-users"]; ok { - arr, ok := trustedUsersRaw.([]interface{}) - if !ok { - return nil, fmt.Errorf("invalid trusted-users value: expected array of strings") - } - for _, entry := range arr { - if s, ok := entry.(string); !ok || strings.TrimSpace(s) == "" { - return nil, fmt.Errorf("invalid trusted-users value: each entry must be a non-empty string") - } - } - } - - // Validate endorsement-reactions and disapproval-reactions if present. - for _, reactionKey := range []string{"endorsement-reactions", "disapproval-reactions"} { - if reactionsRaw, ok := allowOnly[reactionKey]; ok { - arr, ok := reactionsRaw.([]interface{}) - if !ok { - return nil, fmt.Errorf("invalid %s value: expected array of strings", reactionKey) - } - for _, entry := range arr { - if s, ok := entry.(string); !ok || strings.TrimSpace(s) == "" { - return nil, fmt.Errorf("invalid %s value: each entry must be a non-empty string", reactionKey) - } - } - } - } - - // Validate disapproval-integrity if present. - if disIntRaw, ok := allowOnly["disapproval-integrity"]; ok { - disInt, ok := disIntRaw.(string) - if !ok { - return nil, fmt.Errorf("invalid disapproval-integrity value: expected one of none|unapproved|approved|merged") - } - switch strings.ToLower(strings.TrimSpace(disInt)) { - case "none", "unapproved", "approved", "merged": - default: - return nil, fmt.Errorf("invalid disapproval-integrity value: expected one of none|unapproved|approved|merged") - } - } - - // Validate endorser-min-integrity if present. - if endMinRaw, ok := allowOnly["endorser-min-integrity"]; ok { - endMin, ok := endMinRaw.(string) - if !ok { - return nil, fmt.Errorf("invalid endorser-min-integrity value: expected one of none|unapproved|approved|merged") - } - switch strings.ToLower(strings.TrimSpace(endMin)) { - case "none", "unapproved", "approved", "merged": - default: - return nil, fmt.Errorf("invalid endorser-min-integrity value: expected one of none|unapproved|approved|merged") - } - } - - return payload, nil -} - -// BuildLabelAgentPayload constructs the label_agent input payload from the given guard policy -// and optional lists of additional trusted bot usernames and trusted user logins. The trusted -// bots are merged with the guard's built-in list and cannot remove any built-in entries. If -// both trustedBots and trustedUsers are nil or empty, the returned payload contains only the -// allow-only policy. -func BuildLabelAgentPayload(policy interface{}, trustedBots []string, trustedUsers []string) interface{} { - if len(trustedBots) == 0 && len(trustedUsers) == 0 { - return policy - } - - // Convert the policy to a generic map so we can inject the trusted-bots and - // trusted-users keys alongside the allow-only policy without altering the - // policy itself. - payload, err := PolicyToMap(policy) - if err != nil { - // If we can't convert the policy, return it as-is; buildStrictLabelAgentPayload - // will surface the error later. - return policy - } - - if len(trustedBots) > 0 { - // trusted-bots is a top-level key in the label_agent payload. - // Convert []string to []interface{} for JSON compatibility. - bots := make([]interface{}, len(trustedBots)) - for i, b := range trustedBots { - bots[i] = b - } - payload["trusted-bots"] = bots - } - - if len(trustedUsers) > 0 { - // trusted-users is injected inside the allow-only object. - // Convert []string to []interface{} for JSON compatibility. - // If allow-only is absent, the injection is skipped and buildStrictLabelAgentPayload - // will reject the payload when called with the missing allow-only key. - users := make([]interface{}, len(trustedUsers)) - for i, u := range trustedUsers { - users[i] = u - } - // Inject into allow-only object if present - if allowOnly, ok := payload["allow-only"].(map[string]interface{}); ok { - allowOnly["trusted-users"] = users - } - } - - return payload -} - -func isValidAllowOnlyRepos(repos interface{}) bool { - switch value := repos.(type) { - case string: - trimmed := strings.TrimSpace(strings.ToLower(value)) - return trimmed == "all" || trimmed == "public" - case []interface{}: - if len(value) == 0 { - return false - } - for _, entry := range value { - if _, ok := entry.(string); !ok { - return false - } - } - return true - default: - return false - } -} - -// checkBoolFailure returns a non-nil error if the given raw response map -// contains field key set to false, extracting the "error" message if present. -func checkBoolFailure(raw map[string]interface{}, resultJSON []byte, key string) error { - val, ok := raw[key].(bool) - if !ok || val { - return nil // field absent or true — not a failure - } - if message, msgOK := raw["error"].(string); msgOK && strings.TrimSpace(message) != "" { - logWasm.Printf("label_agent response indicated failure: error=%s, response=%s", message, string(resultJSON)) - return fmt.Errorf("label_agent rejected policy: %s", message) - } - logWasm.Printf("label_agent response indicated non-success status: response=%s", string(resultJSON)) - return fmt.Errorf("label_agent returned non-success status") -} - -func parseLabelAgentResponse(resultJSON []byte) (*LabelAgentResult, error) { - var raw map[string]interface{} - if err := json.Unmarshal(resultJSON, &raw); err != nil { - logWasm.Printf("label_agent response parse error (invalid JSON): error=%v, raw=%s", err, string(resultJSON)) - return nil, fmt.Errorf("failed to unmarshal label_agent response: %w", err) - } - - if err := checkBoolFailure(raw, resultJSON, "success"); err != nil { - return nil, err - } - if err := checkBoolFailure(raw, resultJSON, "ok"); err != nil { - return nil, err - } - if message, ok := raw["error"].(string); ok && strings.TrimSpace(message) != "" { - logWasm.Printf("label_agent response contained error field: error=%s, response=%s", message, string(resultJSON)) - return nil, fmt.Errorf("label_agent returned error: %s", message) - } - - var result LabelAgentResult - if err := json.Unmarshal(resultJSON, &result); err != nil { - logWasm.Printf("label_agent response decode error: error=%v, response=%s", err, string(resultJSON)) - return nil, fmt.Errorf("failed to decode label_agent response: %w", err) - } - - if strings.TrimSpace(result.DIFCMode) == "" { - logWasm.Printf("label_agent response missing difc_mode: response=%s", string(resultJSON)) - return nil, fmt.Errorf("label_agent response missing difc_mode") - } - - if _, err := difc.ParseEnforcementMode(result.DIFCMode); err != nil { - logWasm.Printf("label_agent response invalid difc_mode=%q: error=%v, response=%s", result.DIFCMode, err, string(resultJSON)) - return nil, fmt.Errorf("invalid difc_mode from label_agent: %w", err) - } - - return &result, nil -} - // callWasmGuardFunction serialises WASM access, sets the backend reference, marshals // inputData, logs the input, calls the named WASM export, and returns the raw result. // All three public dispatch methods (LabelAgent, LabelResource, LabelResponse) share @@ -822,396 +489,6 @@ func (g *WasmGuard) LabelResponse(ctx context.Context, toolName string, result i return nil, nil } -// parsePathLabeledResponse parses the new path-based labeling format -// This is more efficient as guards don't need to copy data, just return paths and labels -func parsePathLabeledResponse(responseJSON []byte, originalData interface{}) (difc.LabeledData, error) { - pathLabels, err := difc.ParsePathLabels(responseJSON) - if err != nil { - return nil, fmt.Errorf("failed to parse path labels: %w", err) - } - - pld, err := difc.NewPathLabeledData(originalData, pathLabels) - if err != nil { - return nil, fmt.Errorf("failed to apply path labels: %w", err) - } - - // Convert to CollectionLabeledData for compatibility with existing filtering - return pld.ToCollectionLabeledData(), nil -} - -// isWasmTrap reports whether err represents a WASM execution trap that should -// permanently poison the guard. Normal process exits (exit code 0, e.g. TinyGo -// init) are NOT considered traps. A non-zero exit code is treated as a trap. -// As a fallback for wazero execution faults (e.g. Rust panic → unreachable), -// the function also matches on wazero's "wasm error:" message prefix. -func isWasmTrap(err error) bool { - if err == nil { - return false - } - // A normal WASI process exit (exit code 0) is not a trap — don't poison the guard. - var exitErr *sys.ExitError - if errors.As(err, &exitErr) { - return exitErr.ExitCode() != 0 - } - // Fallback for wazero execution traps (e.g. Rust panic → unreachable). - return strings.Contains(err.Error(), "wasm error:") -} - -// callWasmFunction calls an exported function in the WASM module. -// Precondition: g.mu must be held by the caller. All public methods -// (LabelAgent, LabelResource, LabelResponse) hold g.mu for their entire -// duration, satisfying this requirement. -func (g *WasmGuard) callWasmFunction(ctx context.Context, funcName string, inputJSON []byte) ([]byte, error) { - // If the module has already trapped, refuse further calls immediately. - // A WASM trap may corrupt the module's internal state (e.g. the global - // policy context stored by label_agent), so all subsequent calls are - // unsafe until the guard is reloaded. - if g.failed { - return nil, fmt.Errorf("WASM guard '%s' is unavailable after a previous trap: %w", g.name, g.failedErr) - } - - fn := g.module.ExportedFunction(funcName) - if fn == nil { - return nil, fmt.Errorf("function %s not exported from WASM module", funcName) - } - - mem := g.module.Memory() - if mem == nil { - return nil, fmt.Errorf("WASM module has no memory") - } - - // Start with 4MB output buffer, can grow up to 16MB if needed - initialOutputSize := uint32(4 * 1024 * 1024) // 4MB initial - maxOutputSize := uint32(16 * 1024 * 1024) // 16MB maximum - maxInputSize := uint32(8 * 1024 * 1024) // 8MB max input - - if uint32(len(inputJSON)) > maxInputSize { - return nil, fmt.Errorf("input too large: %d bytes (max %d)", len(inputJSON), maxInputSize) - } - - // Adaptive output buffer strategy: - // - // WASM guards communicate buffer-too-small via a return code convention: - // -2 → buffer too small; first 4 bytes of the output buffer MAY contain the - // required size as a little-endian uint32. If present and > 0, we use - // that size for the next attempt; otherwise we double the buffer. - // < 0 → other error (returned as-is to the caller). - // >= 0 → success; value is the number of bytes written to the output buffer. - // - // We retry up to maxRetries times, growing from 4MB toward the 16MB ceiling. - // A WASM trap (e.g. "wasm error: unreachable" from a Rust panic) permanently - // marks the guard as failed because the module's internal state may be corrupt. - outputSize := initialOutputSize - const maxRetries = 3 - - for attempt := 0; attempt < maxRetries; attempt++ { - result, requiredSize, err := g.tryCallWasmFunction(ctx, fn, mem, inputJSON, outputSize) - if err != nil { - if isWasmTrap(err) { - // A WASM trap (e.g. unreachable from a Rust panic) leaves the - // module in an undefined state. Log it prominently and mark the - // guard as permanently failed so callers get a clear error. - logger.LogError("backend", "WASM guard trap: guard=%s, func=%s, error=%v", g.name, funcName, err) - g.failed = true - g.failedErr = err - } - return nil, err - } - - // If we got a result, return it - if result != nil { - return result, nil - } - - // Buffer was too small, check if we can grow - if requiredSize == 0 { - // Guard didn't tell us the required size, double the buffer - requiredSize = outputSize * 2 - } - - if requiredSize > maxOutputSize { - return nil, fmt.Errorf("guard requires buffer of %d bytes which exceeds maximum of %d bytes", requiredSize, maxOutputSize) - } - - logWasm.Printf("Buffer too small (%d bytes), retrying with %d bytes", outputSize, requiredSize) - outputSize = requiredSize - } - - return nil, fmt.Errorf("failed after %d attempts, buffer size %d still insufficient", maxRetries, outputSize) -} - -// tryCallWasmFunction attempts to call the WASM function with the given buffer size -// Returns (result, 0, nil) on success -// Returns (nil, requiredSize, nil) if buffer was too small -// Returns (nil, 0, error) on actual error -func (g *WasmGuard) tryCallWasmFunction(ctx context.Context, fn api.Function, mem api.Memory, inputJSON []byte, outputSize uint32) ([]byte, uint32, error) { - inputSize := uint32(len(inputJSON)) - - // Preferred path: use guard allocator if exported to avoid overlapping - // host-managed buffers with guard heap allocations. - allocFn := g.module.ExportedFunction("alloc") - deallocFn := g.module.ExportedFunction("dealloc") - if allocFn != nil { - // Use a non-cancelable context for cleanup to avoid leaking WASM heap - // allocations if the request context is canceled or times out. - cleanupCtx := context.WithoutCancel(ctx) - - inputPtr, err := g.wasmAlloc(ctx, allocFn, inputSize) - if err != nil { - return nil, 0, fmt.Errorf("failed to allocate WASM input buffer: %w", err) - } - defer g.wasmDealloc(cleanupCtx, deallocFn, inputPtr, inputSize) - - outputPtr, err := g.wasmAlloc(ctx, allocFn, outputSize) - if err != nil { - return nil, 0, fmt.Errorf("failed to allocate WASM output buffer: %w", err) - } - defer g.wasmDealloc(cleanupCtx, deallocFn, outputPtr, outputSize) - - if !mem.Write(inputPtr, inputJSON) { - return nil, 0, fmt.Errorf("failed to write input to WASM memory") - } - - results, err := fn.Call(ctx, - uint64(inputPtr), - uint64(inputSize), - uint64(outputPtr), - uint64(outputSize)) - if err != nil { - return nil, 0, fmt.Errorf("WASM function call failed: %w", err) - } - - resultLen := int32(results[0]) - if resultLen == -2 { - if requiredSize, ok := mem.ReadUint32Le(outputPtr); ok && requiredSize > 0 { - return nil, requiredSize, nil - } - return nil, 0, nil - } - - if resultLen < 0 { - return nil, 0, fmt.Errorf("WASM function returned error code: %d", resultLen) - } - - if resultLen == 0 { - return []byte{}, 0, nil - } - - outputJSON, ok := mem.Read(outputPtr, uint32(resultLen)) - if !ok { - return nil, 0, fmt.Errorf("failed to read output from WASM memory (len=%d)", resultLen) - } - - // Copy out of WASM linear memory before deferred dealloc runs. - resultCopy := append([]byte(nil), outputJSON...) - return resultCopy, 0, nil - } - - // Ensure memory is large enough for our buffers - // Layout: [...guard memory...][input buffer][output buffer] - requiredMemory := inputSize + outputSize + uint32(64*1024) // Extra 64KB for safety margin - - memSize := mem.Size() - if memSize < requiredMemory { - pages := (requiredMemory - memSize + 65535) / 65536 // Round up to pages - _, success := mem.Grow(pages) - if !success { - return nil, 0, fmt.Errorf("failed to grow WASM memory from %d to %d bytes", memSize, requiredMemory) - } - memSize = mem.Size() - } - - // Place buffers at end of memory - outputPtr := memSize - outputSize - inputPtr := outputPtr - inputSize - - // Write input to WASM memory - if !mem.Write(inputPtr, inputJSON) { - return nil, 0, fmt.Errorf("failed to write input to WASM memory") - } - - // Call the WASM function - results, err := fn.Call(ctx, - uint64(inputPtr), - uint64(inputSize), - uint64(outputPtr), - uint64(outputSize)) - if err != nil { - return nil, 0, fmt.Errorf("WASM function call failed: %w", err) - } - - // Check result - resultLen := int32(results[0]) - - // Error code -2 means "buffer too small" - // The guard can optionally return the required size in the output buffer as a uint32 - if resultLen == -2 { - // Try to read the required size from the output buffer (first 4 bytes as uint32) - if requiredSize, ok := mem.ReadUint32Le(outputPtr); ok && requiredSize > 0 { - return nil, requiredSize, nil - } - // Guard didn't specify size, return 0 to trigger doubling - return nil, 0, nil - } - - // Other negative values are errors - if resultLen < 0 { - return nil, 0, fmt.Errorf("WASM function returned error code: %d", resultLen) - } - - if resultLen == 0 { - return []byte{}, 0, nil - } - - // Read output from WASM memory - outputJSON, ok := mem.Read(outputPtr, uint32(resultLen)) - if !ok { - return nil, 0, fmt.Errorf("failed to read output from WASM memory (len=%d)", resultLen) - } - - // Copy out of WASM linear memory to avoid aliasing with future calls. - resultCopy := append([]byte(nil), outputJSON...) - return resultCopy, 0, nil -} - -func (g *WasmGuard) wasmAlloc(ctx context.Context, allocFn api.Function, size uint32) (uint32, error) { - results, err := allocFn.Call(ctx, uint64(size)) - if err != nil { - return 0, err - } - if len(results) == 0 { - return 0, fmt.Errorf("alloc returned no result") - } - ptr := uint32(results[0]) - if ptr == 0 { - return 0, fmt.Errorf("alloc returned null pointer") - } - return ptr, nil -} - -func (g *WasmGuard) wasmDealloc(ctx context.Context, deallocFn api.Function, ptr, size uint32) { - if deallocFn == nil || ptr == 0 || size == 0 { - return - } - if _, err := deallocFn.Call(ctx, uint64(ptr), uint64(size)); err != nil { - logWasm.Printf("WASM dealloc failed: ptr=%d size=%d err=%v", ptr, size, err) - } -} - -// parseResourceResponse converts guard response to LabeledResource -func parseResourceResponse(response map[string]interface{}) (*difc.LabeledResource, difc.OperationType, error) { - resourceData, ok := response["resource"].(map[string]interface{}) - if !ok { - return nil, difc.OperationWrite, fmt.Errorf("invalid resource format in guard response") - } - - resource := &difc.LabeledResource{} - - if desc, ok := resourceData["description"].(string); ok { - resource.Description = desc - } - - // Parse secrecy tags - if secrecy, ok := resourceData["secrecy"].([]interface{}); ok { - tags := make([]difc.Tag, 0, len(secrecy)) - for _, t := range secrecy { - if tagStr, ok := t.(string); ok { - tags = append(tags, difc.Tag(tagStr)) - } - } - resource.Secrecy = *difc.NewSecrecyLabelWithTags(tags) - } else { - resource.Secrecy = *difc.NewSecrecyLabel() - } - - // Parse integrity tags - if integrity, ok := resourceData["integrity"].([]interface{}); ok { - tags := make([]difc.Tag, 0, len(integrity)) - for _, t := range integrity { - if tagStr, ok := t.(string); ok { - tags = append(tags, difc.Tag(tagStr)) - } - } - resource.Integrity = *difc.NewIntegrityLabelWithTags(tags) - } else { - resource.Integrity = *difc.NewIntegrityLabel() - } - - // Parse operation type - operation := difc.OperationWrite // default to most restrictive - if opStr, ok := response["operation"].(string); ok { - switch opStr { - case "read": - operation = difc.OperationRead - case "write": - operation = difc.OperationWrite - case "read-write": - operation = difc.OperationReadWrite - } - } - - return resource, operation, nil -} - -// parseCollectionLabeledData converts an array of items to CollectionLabeledData -func parseCollectionLabeledData(items []interface{}) (*difc.CollectionLabeledData, error) { - collection := &difc.CollectionLabeledData{ - Items: make([]difc.LabeledItem, 0, len(items)), - } - - for _, item := range items { - itemMap, ok := item.(map[string]interface{}) - if !ok { - continue - } - - labeledItem := difc.LabeledItem{ - Data: itemMap["data"], - } - - // Parse labels - if labelsData, ok := itemMap["labels"].(map[string]interface{}); ok { - labels := &difc.LabeledResource{} - - if desc, ok := labelsData["description"].(string); ok { - labels.Description = desc - } - - // Parse secrecy tags - if secrecy, ok := labelsData["secrecy"].([]interface{}); ok { - tags := make([]difc.Tag, 0, len(secrecy)) - for _, t := range secrecy { - if tagStr, ok := t.(string); ok { - tags = append(tags, difc.Tag(tagStr)) - } - } - labels.Secrecy = *difc.NewSecrecyLabelWithTags(tags) - } else { - labels.Secrecy = *difc.NewSecrecyLabel() - } - - // Parse integrity tags - if integrity, ok := labelsData["integrity"].([]interface{}); ok { - tags := make([]difc.Tag, 0, len(integrity)) - for _, t := range integrity { - if tagStr, ok := t.(string); ok { - tags = append(tags, difc.Tag(tagStr)) - } - } - labels.Integrity = *difc.NewIntegrityLabelWithTags(tags) - } else { - labels.Integrity = *difc.NewIntegrityLabel() - } - - labeledItem.Labels = labels - } - - collection.Items = append(collection.Items, labeledItem) - } - - return collection, nil -} - // Close releases WASM runtime resources func (g *WasmGuard) Close(ctx context.Context) error { var moduleErr, runtimeErr error diff --git a/internal/guard/wasm_parse.go b/internal/guard/wasm_parse.go new file mode 100644 index 00000000..2f71fcdd --- /dev/null +++ b/internal/guard/wasm_parse.go @@ -0,0 +1,445 @@ +package guard + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "strings" + + "github.com/github/gh-aw-mcpg/internal/difc" + "github.com/github/gh-aw-mcpg/internal/logger" + "github.com/tetratelabs/wazero/api" + "github.com/tetratelabs/wazero/sys" +) + +// parseLabelAgentResponse validates and decodes the raw JSON returned by the +// WASM label_agent function into a LabelAgentResult. +func parseLabelAgentResponse(resultJSON []byte) (*LabelAgentResult, error) { + var raw map[string]interface{} + if err := json.Unmarshal(resultJSON, &raw); err != nil { + logWasm.Printf("label_agent response parse error (invalid JSON): error=%v, raw=%s", err, string(resultJSON)) + return nil, fmt.Errorf("failed to unmarshal label_agent response: %w", err) + } + + if err := checkBoolFailure(raw, resultJSON, "success"); err != nil { + return nil, err + } + if err := checkBoolFailure(raw, resultJSON, "ok"); err != nil { + return nil, err + } + if message, ok := raw["error"].(string); ok && strings.TrimSpace(message) != "" { + logWasm.Printf("label_agent response contained error field: error=%s, response=%s", message, string(resultJSON)) + return nil, fmt.Errorf("label_agent returned error: %s", message) + } + + var result LabelAgentResult + if err := json.Unmarshal(resultJSON, &result); err != nil { + logWasm.Printf("label_agent response decode error: error=%v, response=%s", err, string(resultJSON)) + return nil, fmt.Errorf("failed to decode label_agent response: %w", err) + } + + if strings.TrimSpace(result.DIFCMode) == "" { + logWasm.Printf("label_agent response missing difc_mode: response=%s", string(resultJSON)) + return nil, fmt.Errorf("label_agent response missing difc_mode") + } + + if _, err := difc.ParseEnforcementMode(result.DIFCMode); err != nil { + logWasm.Printf("label_agent response invalid difc_mode=%q: error=%v, response=%s", result.DIFCMode, err, string(resultJSON)) + return nil, fmt.Errorf("invalid difc_mode from label_agent: %w", err) + } + + return &result, nil +} + +// parsePathLabeledResponse parses the path-based labeling format. +// This is more efficient as guards don't need to copy data, just return paths and labels. +func parsePathLabeledResponse(responseJSON []byte, originalData interface{}) (difc.LabeledData, error) { + pathLabels, err := difc.ParsePathLabels(responseJSON) + if err != nil { + return nil, fmt.Errorf("failed to parse path labels: %w", err) + } + + pld, err := difc.NewPathLabeledData(originalData, pathLabels) + if err != nil { + return nil, fmt.Errorf("failed to apply path labels: %w", err) + } + + // Convert to CollectionLabeledData for compatibility with existing filtering + return pld.ToCollectionLabeledData(), nil +} + +// isWasmTrap reports whether err represents a WASM execution trap that should +// permanently poison the guard. Normal process exits (exit code 0, e.g. TinyGo +// init) are NOT considered traps. A non-zero exit code is treated as a trap. +// As a fallback for wazero execution faults (e.g. Rust panic → unreachable), +// the function also matches on wazero's "wasm error:" message prefix. +func isWasmTrap(err error) bool { + if err == nil { + return false + } + // A normal WASI process exit (exit code 0) is not a trap — don't poison the guard. + var exitErr *sys.ExitError + if errors.As(err, &exitErr) { + return exitErr.ExitCode() != 0 + } + // Fallback for wazero execution traps (e.g. Rust panic → unreachable). + return strings.Contains(err.Error(), "wasm error:") +} + +// callWasmFunction calls an exported function in the WASM module. +// Precondition: g.mu must be held by the caller. All public methods +// (LabelAgent, LabelResource, LabelResponse) hold g.mu for their entire +// duration, satisfying this requirement. +func (g *WasmGuard) callWasmFunction(ctx context.Context, funcName string, inputJSON []byte) ([]byte, error) { + // If the module has already trapped, refuse further calls immediately. + // A WASM trap may corrupt the module's internal state (e.g. the global + // policy context stored by label_agent), so all subsequent calls are + // unsafe until the guard is reloaded. + if g.failed { + return nil, fmt.Errorf("WASM guard '%s' is unavailable after a previous trap: %w", g.name, g.failedErr) + } + + fn := g.module.ExportedFunction(funcName) + if fn == nil { + return nil, fmt.Errorf("function %s not exported from WASM module", funcName) + } + + mem := g.module.Memory() + if mem == nil { + return nil, fmt.Errorf("WASM module has no memory") + } + + // Start with 4MB output buffer, can grow up to 16MB if needed + initialOutputSize := uint32(4 * 1024 * 1024) // 4MB initial + maxOutputSize := uint32(16 * 1024 * 1024) // 16MB maximum + maxInputSize := uint32(8 * 1024 * 1024) // 8MB max input + + if uint32(len(inputJSON)) > maxInputSize { + return nil, fmt.Errorf("input too large: %d bytes (max %d)", len(inputJSON), maxInputSize) + } + + // Adaptive output buffer strategy: + // + // WASM guards communicate buffer-too-small via a return code convention: + // -2 → buffer too small; first 4 bytes of the output buffer MAY contain the + // required size as a little-endian uint32. If present and > 0, we use + // that size for the next attempt; otherwise we double the buffer. + // < 0 → other error (returned as-is to the caller). + // >= 0 → success; value is the number of bytes written to the output buffer. + // + // We retry up to maxRetries times, growing from 4MB toward the 16MB ceiling. + // A WASM trap (e.g. "wasm error: unreachable" from a Rust panic) permanently + // marks the guard as failed because the module's internal state may be corrupt. + outputSize := initialOutputSize + const maxRetries = 3 + + for attempt := 0; attempt < maxRetries; attempt++ { + result, requiredSize, err := g.tryCallWasmFunction(ctx, fn, mem, inputJSON, outputSize) + if err != nil { + if isWasmTrap(err) { + // A WASM trap (e.g. unreachable from a Rust panic) leaves the + // module in an undefined state. Log it prominently and mark the + // guard as permanently failed so callers get a clear error. + logger.LogError("backend", "WASM guard trap: guard=%s, func=%s, error=%v", g.name, funcName, err) + g.failed = true + g.failedErr = err + } + return nil, err + } + + // If we got a result, return it + if result != nil { + return result, nil + } + + // Buffer was too small, check if we can grow + if requiredSize == 0 { + // Guard didn't tell us the required size, double the buffer + requiredSize = outputSize * 2 + } + + if requiredSize > maxOutputSize { + return nil, fmt.Errorf("guard requires buffer of %d bytes which exceeds maximum of %d bytes", requiredSize, maxOutputSize) + } + + logWasm.Printf("Buffer too small (%d bytes), retrying with %d bytes", outputSize, requiredSize) + outputSize = requiredSize + } + + return nil, fmt.Errorf("failed after %d attempts, buffer size %d still insufficient", maxRetries, outputSize) +} + +// tryCallWasmFunction attempts to call the WASM function with the given buffer size. +// Returns (result, 0, nil) on success. +// Returns (nil, requiredSize, nil) if buffer was too small. +// Returns (nil, 0, error) on actual error. +func (g *WasmGuard) tryCallWasmFunction(ctx context.Context, fn api.Function, mem api.Memory, inputJSON []byte, outputSize uint32) ([]byte, uint32, error) { + inputSize := uint32(len(inputJSON)) + + // Preferred path: use guard allocator if exported to avoid overlapping + // host-managed buffers with guard heap allocations. + allocFn := g.module.ExportedFunction("alloc") + deallocFn := g.module.ExportedFunction("dealloc") + if allocFn != nil { + // Use a non-cancelable context for cleanup to avoid leaking WASM heap + // allocations if the request context is canceled or times out. + cleanupCtx := context.WithoutCancel(ctx) + + inputPtr, err := g.wasmAlloc(ctx, allocFn, inputSize) + if err != nil { + return nil, 0, fmt.Errorf("failed to allocate WASM input buffer: %w", err) + } + defer g.wasmDealloc(cleanupCtx, deallocFn, inputPtr, inputSize) + + outputPtr, err := g.wasmAlloc(ctx, allocFn, outputSize) + if err != nil { + return nil, 0, fmt.Errorf("failed to allocate WASM output buffer: %w", err) + } + defer g.wasmDealloc(cleanupCtx, deallocFn, outputPtr, outputSize) + + if !mem.Write(inputPtr, inputJSON) { + return nil, 0, fmt.Errorf("failed to write input to WASM memory") + } + + results, err := fn.Call(ctx, + uint64(inputPtr), + uint64(inputSize), + uint64(outputPtr), + uint64(outputSize)) + if err != nil { + return nil, 0, fmt.Errorf("WASM function call failed: %w", err) + } + + resultLen := int32(results[0]) + if resultLen == -2 { + if requiredSize, ok := mem.ReadUint32Le(outputPtr); ok && requiredSize > 0 { + return nil, requiredSize, nil + } + return nil, 0, nil + } + + if resultLen < 0 { + return nil, 0, fmt.Errorf("WASM function returned error code: %d", resultLen) + } + + if resultLen == 0 { + return []byte{}, 0, nil + } + + outputJSON, ok := mem.Read(outputPtr, uint32(resultLen)) + if !ok { + return nil, 0, fmt.Errorf("failed to read output from WASM memory (len=%d)", resultLen) + } + + // Copy out of WASM linear memory before deferred dealloc runs. + resultCopy := append([]byte(nil), outputJSON...) + return resultCopy, 0, nil + } + + // Ensure memory is large enough for our buffers + // Layout: [...guard memory...][input buffer][output buffer] + requiredMemory := inputSize + outputSize + uint32(64*1024) // Extra 64KB for safety margin + + memSize := mem.Size() + if memSize < requiredMemory { + pages := (requiredMemory - memSize + 65535) / 65536 // Round up to pages + _, success := mem.Grow(pages) + if !success { + return nil, 0, fmt.Errorf("failed to grow WASM memory from %d to %d bytes", memSize, requiredMemory) + } + memSize = mem.Size() + } + + // Place buffers at end of memory + outputPtr := memSize - outputSize + inputPtr := outputPtr - inputSize + + // Write input to WASM memory + if !mem.Write(inputPtr, inputJSON) { + return nil, 0, fmt.Errorf("failed to write input to WASM memory") + } + + // Call the WASM function + results, err := fn.Call(ctx, + uint64(inputPtr), + uint64(inputSize), + uint64(outputPtr), + uint64(outputSize)) + if err != nil { + return nil, 0, fmt.Errorf("WASM function call failed: %w", err) + } + + // Check result + resultLen := int32(results[0]) + + // Error code -2 means "buffer too small" + // The guard can optionally return the required size in the output buffer as a uint32 + if resultLen == -2 { + // Try to read the required size from the output buffer (first 4 bytes as uint32) + if requiredSize, ok := mem.ReadUint32Le(outputPtr); ok && requiredSize > 0 { + return nil, requiredSize, nil + } + // Guard didn't specify size, return 0 to trigger doubling + return nil, 0, nil + } + + // Other negative values are errors + if resultLen < 0 { + return nil, 0, fmt.Errorf("WASM function returned error code: %d", resultLen) + } + + if resultLen == 0 { + return []byte{}, 0, nil + } + + // Read output from WASM memory + outputJSON, ok := mem.Read(outputPtr, uint32(resultLen)) + if !ok { + return nil, 0, fmt.Errorf("failed to read output from WASM memory (len=%d)", resultLen) + } + + // Copy out of WASM linear memory to avoid aliasing with future calls. + resultCopy := append([]byte(nil), outputJSON...) + return resultCopy, 0, nil +} + +// wasmAlloc allocates a buffer in WASM linear memory using the guard's exported alloc function. +func (g *WasmGuard) wasmAlloc(ctx context.Context, allocFn api.Function, size uint32) (uint32, error) { + results, err := allocFn.Call(ctx, uint64(size)) + if err != nil { + return 0, err + } + if len(results) == 0 { + return 0, fmt.Errorf("alloc returned no result") + } + ptr := uint32(results[0]) + if ptr == 0 { + return 0, fmt.Errorf("alloc returned null pointer") + } + return ptr, nil +} + +// wasmDealloc frees a WASM linear-memory allocation via the guard's exported dealloc function. +func (g *WasmGuard) wasmDealloc(ctx context.Context, deallocFn api.Function, ptr, size uint32) { + if deallocFn == nil || ptr == 0 || size == 0 { + return + } + if _, err := deallocFn.Call(ctx, uint64(ptr), uint64(size)); err != nil { + logWasm.Printf("WASM dealloc failed: ptr=%d size=%d err=%v", ptr, size, err) + } +} + +// parseResourceResponse converts the guard label_resource response to a LabeledResource. +func parseResourceResponse(response map[string]interface{}) (*difc.LabeledResource, difc.OperationType, error) { + resourceData, ok := response["resource"].(map[string]interface{}) + if !ok { + return nil, difc.OperationWrite, fmt.Errorf("invalid resource format in guard response") + } + + resource := &difc.LabeledResource{} + + if desc, ok := resourceData["description"].(string); ok { + resource.Description = desc + } + + // Parse secrecy tags + if secrecy, ok := resourceData["secrecy"].([]interface{}); ok { + tags := make([]difc.Tag, 0, len(secrecy)) + for _, t := range secrecy { + if tagStr, ok := t.(string); ok { + tags = append(tags, difc.Tag(tagStr)) + } + } + resource.Secrecy = *difc.NewSecrecyLabelWithTags(tags) + } else { + resource.Secrecy = *difc.NewSecrecyLabel() + } + + // Parse integrity tags + if integrity, ok := resourceData["integrity"].([]interface{}); ok { + tags := make([]difc.Tag, 0, len(integrity)) + for _, t := range integrity { + if tagStr, ok := t.(string); ok { + tags = append(tags, difc.Tag(tagStr)) + } + } + resource.Integrity = *difc.NewIntegrityLabelWithTags(tags) + } else { + resource.Integrity = *difc.NewIntegrityLabel() + } + + // Parse operation type + operation := difc.OperationWrite // default to most restrictive + if opStr, ok := response["operation"].(string); ok { + switch opStr { + case "read": + operation = difc.OperationRead + case "write": + operation = difc.OperationWrite + case "read-write": + operation = difc.OperationReadWrite + } + } + + return resource, operation, nil +} + +// parseCollectionLabeledData converts an array of items to CollectionLabeledData. +func parseCollectionLabeledData(items []interface{}) (*difc.CollectionLabeledData, error) { + collection := &difc.CollectionLabeledData{ + Items: make([]difc.LabeledItem, 0, len(items)), + } + + for _, item := range items { + itemMap, ok := item.(map[string]interface{}) + if !ok { + continue + } + + labeledItem := difc.LabeledItem{ + Data: itemMap["data"], + } + + // Parse labels + if labelsData, ok := itemMap["labels"].(map[string]interface{}); ok { + labels := &difc.LabeledResource{} + + if desc, ok := labelsData["description"].(string); ok { + labels.Description = desc + } + + // Parse secrecy tags + if secrecy, ok := labelsData["secrecy"].([]interface{}); ok { + tags := make([]difc.Tag, 0, len(secrecy)) + for _, t := range secrecy { + if tagStr, ok := t.(string); ok { + tags = append(tags, difc.Tag(tagStr)) + } + } + labels.Secrecy = *difc.NewSecrecyLabelWithTags(tags) + } else { + labels.Secrecy = *difc.NewSecrecyLabel() + } + + // Parse integrity tags + if integrity, ok := labelsData["integrity"].([]interface{}); ok { + tags := make([]difc.Tag, 0, len(integrity)) + for _, t := range integrity { + if tagStr, ok := t.(string); ok { + tags = append(tags, difc.Tag(tagStr)) + } + } + labels.Integrity = *difc.NewIntegrityLabelWithTags(tags) + } else { + labels.Integrity = *difc.NewIntegrityLabel() + } + + labeledItem.Labels = labels + } + + collection.Items = append(collection.Items, labeledItem) + } + + return collection, nil +} diff --git a/internal/guard/wasm_payload.go b/internal/guard/wasm_payload.go new file mode 100644 index 00000000..ac417f29 --- /dev/null +++ b/internal/guard/wasm_payload.go @@ -0,0 +1,308 @@ +package guard + +import ( + "encoding/json" + "fmt" + "strings" +) + +// normalizePolicyPayload coerces a policy value to a map[string]interface{}. +// String inputs are JSON-parsed; non-object JSON values are rejected. +func normalizePolicyPayload(policy interface{}) (interface{}, error) { + if policy == nil { + return nil, fmt.Errorf("policy is required") + } + + if policyString, ok := policy.(string); ok { + trimmed := strings.TrimSpace(policyString) + if trimmed == "" { + return nil, fmt.Errorf("policy string is empty") + } + + var parsed interface{} + if err := json.Unmarshal([]byte(trimmed), &parsed); err != nil { + return nil, fmt.Errorf("policy string is not valid JSON object: %w", err) + } + + switch parsed.(type) { + case map[string]interface{}: + return parsed, nil + default: + return nil, fmt.Errorf("policy JSON must decode to an object") + } + } + + return policy, nil +} + +// buildStrictLabelAgentPayload validates the normalised policy and returns a +// map ready to be serialised as the label_agent input payload. +func buildStrictLabelAgentPayload(policy interface{}) (map[string]interface{}, error) { + if policy == nil { + return nil, fmt.Errorf("invalid guard policy transport shape: expected {\"allow-only\":{\"repos\":...,\"min-integrity\":...}}") + } + + if policyMap, ok := policy.(map[string]interface{}); ok { + if nested, hasPolicy := policyMap["policy"]; hasPolicy { + if nestedMap, nestedOK := nested.(map[string]interface{}); nestedOK { + if _, hasAllowOnly := nestedMap["allow-only"]; hasAllowOnly { + return nil, fmt.Errorf("gateway policy adapter is outdated: remove legacy envelope key policy before calling label_agent") + } + } + } + } + + payload, err := PolicyToMap(policy) + if err != nil { + return nil, fmt.Errorf("failed to decode label_agent policy payload: %w", err) + } + + if _, hasPolicyEnvelope := payload["policy"]; hasPolicyEnvelope { + return nil, fmt.Errorf("gateway policy adapter is outdated: remove legacy envelope key policy before calling label_agent") + } + + allowOnlyRaw, ok := payload["allow-only"] + if !ok { + // Accept legacy "allowonly" form for backward compatibility + allowOnlyRaw, ok = payload["allowonly"] + } + if !ok { + return nil, fmt.Errorf("label_agent policy must use top-level allow-only object (received policy.allow-only)") + } + + // Validate that the only allowed top-level keys are "allow-only" (or legacy "allowonly") + // and the optional "trusted-bots" key. + for k := range payload { + switch k { + case "allow-only", "allowonly", "trusted-bots": + // valid top-level keys + default: + return nil, fmt.Errorf("invalid guard policy transport shape: unexpected key %q", k) + } + } + + allowOnly, ok := allowOnlyRaw.(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("invalid guard policy transport shape: expected {\"allow-only\":{\"repos\":...,\"min-integrity\":...}}") + } + + reposRaw, hasRepos := allowOnly["repos"] + integrityRaw, hasIntegrity := allowOnly["min-integrity"] + if !hasIntegrity { + integrityRaw, hasIntegrity = allowOnly["integrity"] + } + if !hasRepos || !hasIntegrity { + return nil, fmt.Errorf("invalid guard policy transport shape: missing required fields repos and/or min-integrity in allow-only object") + } + + // Validate that the allow-only object contains only known keys. + for k := range allowOnly { + switch k { + case "repos", "min-integrity", "integrity", "blocked-users", "approval-labels", "trusted-users", + "endorsement-reactions", "disapproval-reactions", "disapproval-integrity", "endorser-min-integrity": + // valid allow-only keys + default: + return nil, fmt.Errorf("invalid guard policy transport shape: unexpected allow-only key %q", k) + } + } + + if !isValidAllowOnlyRepos(reposRaw) { + return nil, fmt.Errorf("invalid repos value: expected all, public, or non-empty array of scoped strings") + } + + integrity, ok := integrityRaw.(string) + if !ok { + return nil, fmt.Errorf("invalid integrity value: expected one of none|unapproved|approved|merged") + } + + switch strings.ToLower(strings.TrimSpace(integrity)) { + case "none", "unapproved", "approved", "merged": + default: + return nil, fmt.Errorf("invalid integrity value: expected one of none|unapproved|approved|merged") + } + + // Validate blocked-users if present: must be a non-empty array of non-empty strings. + if blockedUsersRaw, ok := allowOnly["blocked-users"]; ok { + arr, ok := blockedUsersRaw.([]interface{}) + if !ok { + return nil, fmt.Errorf("invalid blocked-users value: expected array of strings") + } + for _, entry := range arr { + if s, ok := entry.(string); !ok || strings.TrimSpace(s) == "" { + return nil, fmt.Errorf("invalid blocked-users value: each entry must be a non-empty string") + } + } + } + + // Validate approval-labels if present: must be a non-empty array of non-empty strings. + if approvalLabelsRaw, ok := allowOnly["approval-labels"]; ok { + arr, ok := approvalLabelsRaw.([]interface{}) + if !ok { + return nil, fmt.Errorf("invalid approval-labels value: expected array of strings") + } + for _, entry := range arr { + if s, ok := entry.(string); !ok || strings.TrimSpace(s) == "" { + return nil, fmt.Errorf("invalid approval-labels value: each entry must be a non-empty string") + } + } + } + + // Validate trusted-bots if present. + // Per spec §4.1.3.4: trustedBots MUST be a non-empty array of strings when present. + if trustedBotsRaw, hasTrustedBots := payload["trusted-bots"]; hasTrustedBots { + trustedBots, ok := trustedBotsRaw.([]interface{}) + if !ok { + return nil, fmt.Errorf("invalid trusted-bots value: expected non-empty array of strings") + } + if len(trustedBots) == 0 { + return nil, fmt.Errorf("invalid trusted-bots value: must be a non-empty array when present") + } + for _, entry := range trustedBots { + if s, ok := entry.(string); !ok || strings.TrimSpace(s) == "" { + return nil, fmt.Errorf("invalid trusted-bots value: each entry must be a non-empty string") + } + } + } + + // Validate trusted-users if present inside allow-only. + // Must be a non-empty array of non-empty strings when present. + if trustedUsersRaw, ok := allowOnly["trusted-users"]; ok { + arr, ok := trustedUsersRaw.([]interface{}) + if !ok { + return nil, fmt.Errorf("invalid trusted-users value: expected array of strings") + } + for _, entry := range arr { + if s, ok := entry.(string); !ok || strings.TrimSpace(s) == "" { + return nil, fmt.Errorf("invalid trusted-users value: each entry must be a non-empty string") + } + } + } + + // Validate endorsement-reactions and disapproval-reactions if present. + for _, reactionKey := range []string{"endorsement-reactions", "disapproval-reactions"} { + if reactionsRaw, ok := allowOnly[reactionKey]; ok { + arr, ok := reactionsRaw.([]interface{}) + if !ok { + return nil, fmt.Errorf("invalid %s value: expected array of strings", reactionKey) + } + for _, entry := range arr { + if s, ok := entry.(string); !ok || strings.TrimSpace(s) == "" { + return nil, fmt.Errorf("invalid %s value: each entry must be a non-empty string", reactionKey) + } + } + } + } + + // Validate disapproval-integrity if present. + if disIntRaw, ok := allowOnly["disapproval-integrity"]; ok { + disInt, ok := disIntRaw.(string) + if !ok { + return nil, fmt.Errorf("invalid disapproval-integrity value: expected one of none|unapproved|approved|merged") + } + switch strings.ToLower(strings.TrimSpace(disInt)) { + case "none", "unapproved", "approved", "merged": + default: + return nil, fmt.Errorf("invalid disapproval-integrity value: expected one of none|unapproved|approved|merged") + } + } + + // Validate endorser-min-integrity if present. + if endMinRaw, ok := allowOnly["endorser-min-integrity"]; ok { + endMin, ok := endMinRaw.(string) + if !ok { + return nil, fmt.Errorf("invalid endorser-min-integrity value: expected one of none|unapproved|approved|merged") + } + switch strings.ToLower(strings.TrimSpace(endMin)) { + case "none", "unapproved", "approved", "merged": + default: + return nil, fmt.Errorf("invalid endorser-min-integrity value: expected one of none|unapproved|approved|merged") + } + } + + return payload, nil +} + +// BuildLabelAgentPayload constructs the label_agent input payload from the given guard policy +// and optional lists of additional trusted bot usernames and trusted user logins. The trusted +// bots are merged with the guard's built-in list and cannot remove any built-in entries. If +// both trustedBots and trustedUsers are nil or empty, the returned payload contains only the +// allow-only policy. +func BuildLabelAgentPayload(policy interface{}, trustedBots []string, trustedUsers []string) interface{} { + if len(trustedBots) == 0 && len(trustedUsers) == 0 { + return policy + } + + // Convert the policy to a generic map so we can inject the trusted-bots and + // trusted-users keys alongside the allow-only policy without altering the + // policy itself. + payload, err := PolicyToMap(policy) + if err != nil { + // If we can't convert the policy, return it as-is; buildStrictLabelAgentPayload + // will surface the error later. + return policy + } + + if len(trustedBots) > 0 { + // trusted-bots is a top-level key in the label_agent payload. + // Convert []string to []interface{} for JSON compatibility. + bots := make([]interface{}, len(trustedBots)) + for i, b := range trustedBots { + bots[i] = b + } + payload["trusted-bots"] = bots + } + + if len(trustedUsers) > 0 { + // trusted-users is injected inside the allow-only object. + // Convert []string to []interface{} for JSON compatibility. + // If allow-only is absent, the injection is skipped and buildStrictLabelAgentPayload + // will reject the payload when called with the missing allow-only key. + users := make([]interface{}, len(trustedUsers)) + for i, u := range trustedUsers { + users[i] = u + } + // Inject into allow-only object if present + if allowOnly, ok := payload["allow-only"].(map[string]interface{}); ok { + allowOnly["trusted-users"] = users + } + } + + return payload +} + +// isValidAllowOnlyRepos returns true if repos is either a recognised string +// shorthand ("all" or "public") or a non-empty array of strings. +func isValidAllowOnlyRepos(repos interface{}) bool { + switch value := repos.(type) { + case string: + trimmed := strings.TrimSpace(strings.ToLower(value)) + return trimmed == "all" || trimmed == "public" + case []interface{}: + if len(value) == 0 { + return false + } + for _, entry := range value { + if _, ok := entry.(string); !ok { + return false + } + } + return true + default: + return false + } +} + +// checkBoolFailure returns a non-nil error if the given raw response map +// contains field key set to false, extracting the "error" message if present. +func checkBoolFailure(raw map[string]interface{}, resultJSON []byte, key string) error { + val, ok := raw[key].(bool) + if !ok || val { + return nil // field absent or true — not a failure + } + if message, msgOK := raw["error"].(string); msgOK && strings.TrimSpace(message) != "" { + logWasm.Printf("label_agent response indicated failure: error=%s, response=%s", message, string(resultJSON)) + return fmt.Errorf("label_agent rejected policy: %s", message) + } + logWasm.Printf("label_agent response indicated non-success status: response=%s", string(resultJSON)) + return fmt.Errorf("label_agent returned non-success status") +}