diff --git a/constraint/pkg/client/drivers/rego/builtin.go b/constraint/pkg/client/drivers/rego/builtin.go index 7ea46eb6f..e0ff62d5f 100644 --- a/constraint/pkg/client/drivers/rego/builtin.go +++ b/constraint/pkg/client/drivers/rego/builtin.go @@ -1,9 +1,15 @@ package rego import ( + "crypto/tls" + "crypto/x509" + "encoding/base64" + "fmt" "net/http" + "net/url" "time" + "github.com/open-policy-agent/frameworks/constraint/pkg/apis/externaldata/unversioned" "github.com/open-policy-agent/frameworks/constraint/pkg/externaldata" "github.com/open-policy-agent/opa/ast" "github.com/open-policy-agent/opa/rego" @@ -12,6 +18,9 @@ import ( const ( providerResponseAPIVersion = "externaldata.gatekeeper.sh/v1beta1" providerResponseKind = "ProviderResponse" + HTTPSScheme = "https" + idleConnTimeout = 90 * time.Second + maxIdleConnsPerHost = 100 ) func externalDataBuiltin(d *Driver) func(bctx rego.BuiltinContext, regorequest *ast.Term) (*ast.Term, error) { @@ -31,6 +40,12 @@ func externalDataBuiltin(d *Driver) func(bctx rego.BuiltinContext, regorequest * return externaldata.HandleError(http.StatusBadRequest, err) } + client, err := getClient(&provider, clientCert) + if err != nil { + return externaldata.HandleError(http.StatusInternalServerError, + fmt.Errorf("failed to get HTTP client: %w", err)) + } + // check provider response cache var providerRequestKeys []string var providerResponseStatusCode int @@ -71,7 +86,7 @@ func externalDataBuiltin(d *Driver) func(bctx rego.BuiltinContext, regorequest * } if len(providerRequestKeys) > 0 { - externaldataResponse, statusCode, err := d.sendRequestToProvider(bctx.Context, &provider, providerRequestKeys, clientCert) + externaldataResponse, statusCode, err := d.sendRequestToProvider(bctx.Context, &provider, providerRequestKeys, client) if err != nil { return externaldata.HandleError(statusCode, err) } @@ -115,3 +130,49 @@ func externalDataBuiltin(d *Driver) func(bctx rego.BuiltinContext, regorequest * return externaldata.PrepareRegoResponse(regoResponse) } } + +// getClient returns a new HTTP client, and set up its TLS configuration. +func getClient(provider *unversioned.Provider, clientCert *tls.Certificate) (*http.Client, error) { + u, err := url.Parse(provider.Spec.URL) + if err != nil { + return nil, fmt.Errorf("failed to parse provider URL %s: %w", provider.Spec.URL, err) + } + + if u.Scheme != HTTPSScheme { + return nil, fmt.Errorf("only HTTPS scheme is supported") + } + + client := &http.Client{ + Timeout: time.Duration(provider.Spec.Timeout) * time.Second, + } + + tlsConfig := &tls.Config{MinVersion: tls.VersionTLS13} + + // present our client cert to the server + // in case provider wants to verify it + if clientCert != nil { + tlsConfig.Certificates = []tls.Certificate{*clientCert} + } + + // if the provider presents its own CA bundle, + // we will use it to verify the server's certificate + caBundleData, err := base64.StdEncoding.DecodeString(provider.Spec.CABundle) + if err != nil { + return nil, fmt.Errorf("failed to decode CA bundle: %w", err) + } + + providerCertPool := x509.NewCertPool() + if ok := providerCertPool.AppendCertsFromPEM(caBundleData); !ok { + return nil, fmt.Errorf("failed to append provider's CA bundle to certificate pool") + } + + tlsConfig.RootCAs = providerCertPool + + client.Transport = &http.Transport{ + TLSClientConfig: tlsConfig, + IdleConnTimeout: idleConnTimeout, + MaxIdleConnsPerHost: maxIdleConnsPerHost, + } + + return client, nil +} diff --git a/constraint/pkg/client/drivers/rego/builtin_test.go b/constraint/pkg/client/drivers/rego/builtin_test.go new file mode 100644 index 000000000..ca155f2b2 --- /dev/null +++ b/constraint/pkg/client/drivers/rego/builtin_test.go @@ -0,0 +1,85 @@ +package rego + +import ( + "crypto/tls" + "testing" + + "github.com/open-policy-agent/frameworks/constraint/pkg/apis/externaldata/unversioned" +) + +const ( + validCABundle = "LS0tLS1CRUdJTiBDRVJUSUZJQ0FURS0tLS0tCk1JSUIwekNDQVgyZ0F3SUJBZ0lKQUkvTTdCWWp3Qit1TUEwR0NTcUdTSWIzRFFFQkJRVUFNRVV4Q3pBSkJnTlYKQkFZVEFrRlZNUk13RVFZRFZRUUlEQXBUYjIxbExWTjBZWFJsTVNFd0h3WURWUVFLREJoSmJuUmxjbTVsZENCWAphV1JuYVhSeklGQjBlU0JNZEdRd0hoY05NVEl3T1RFeU1qRTFNakF5V2hjTk1UVXdPVEV5TWpFMU1qQXlXakJGCk1Rc3dDUVlEVlFRR0V3SkJWVEVUTUJFR0ExVUVDQXdLVTI5dFpTMVRkR0YwWlRFaE1COEdBMVVFQ2d3WVNXNTAKWlhKdVpYUWdWMmxrWjJsMGN5QlFkSGtnVEhSa01Gd3dEUVlKS29aSWh2Y05BUUVCQlFBRFN3QXdTQUpCQU5MSgpoUEhoSVRxUWJQa2xHM2liQ1Z4d0dNUmZwL3Y0WHFoZmRRSGRjVmZIYXA2TlE1V29rLzR4SUErdWkzNS9NbU5hCnJ0TnVDK0JkWjF0TXVWQ1BGWmNDQXdFQUFhTlFNRTR3SFFZRFZSME9CQllFRkp2S3M4UmZKYVhUSDA4VytTR3YKelF5S24wSDhNQjhHQTFVZEl3UVlNQmFBRkp2S3M4UmZKYVhUSDA4VytTR3Z6UXlLbjBIOE1Bd0dBMVVkRXdRRgpNQU1CQWY4d0RRWUpLb1pJaHZjTkFRRUZCUUFEUVFCSmxmZkpIeWJqREd4Uk1xYVJtRGhYMCs2djAyVFVLWnNXCnI1UXVWYnBRaEg2dSswVWdjVzBqcDlRd3B4b1BUTFRXR1hFV0JCQnVyeEZ3aUNCaGtRK1YKLS0tLS1FTkQgQ0VSVElGSUNBVEUtLS0tLQo=" + badCABundle = "LS0tLS1CRUdJTiBDRVJUSUZJQ0FURS0tLS0tCmhlbGxvCi0tLS0tRU5EIENFUlRJRklDQVRFLS0tLS0K" +) + +func Test_getClient(t *testing.T) { + type args struct { + provider *unversioned.Provider + clientCert *tls.Certificate + } + tests := []struct { + name string + args args + wantErr bool + }{ + { + name: "invalid http url", + args: args{ + provider: &unversioned.Provider{ + Spec: unversioned.ProviderSpec{ + URL: "http://foo", + }, + }, + clientCert: nil, + }, + wantErr: true, + }, + { + name: "no CA bundle", + args: args{ + provider: &unversioned.Provider{ + Spec: unversioned.ProviderSpec{ + URL: "https://foo", + }, + }, + clientCert: nil, + }, + wantErr: true, + }, + { + name: "invalid CA bundle", + args: args{ + provider: &unversioned.Provider{ + Spec: unversioned.ProviderSpec{ + URL: "https://foo", + CABundle: badCABundle, + }, + }, + clientCert: nil, + }, + wantErr: true, + }, + { + name: "valid CA bundle", + args: args{ + provider: &unversioned.Provider{ + Spec: unversioned.ProviderSpec{ + URL: "https://foo", + CABundle: validCABundle, + }, + }, + clientCert: nil, + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := getClient(tt.args.provider, tt.args.clientCert) + if (err != nil) != tt.wantErr { + t.Errorf("getClient() error = %v, wantErr %v", err, tt.wantErr) + return + } + }) + } +} diff --git a/constraint/pkg/client/drivers/rego/driver_unit_test.go b/constraint/pkg/client/drivers/rego/driver_unit_test.go index c17d00c40..5e0f10207 100644 --- a/constraint/pkg/client/drivers/rego/driver_unit_test.go +++ b/constraint/pkg/client/drivers/rego/driver_unit_test.go @@ -2,7 +2,6 @@ package rego import ( "context" - "crypto/tls" "errors" "fmt" "net/http" @@ -676,7 +675,7 @@ func TestDriver_ExternalData(t *testing.T) { }, clientCertContent: clientCert, clientKeyContent: clientKey, - sendRequestToProvider: func(_ context.Context, _ *unversioned.Provider, _ []string, _ *tls.Certificate) (*externaldata.ProviderResponse, int, error) { + sendRequestToProvider: func(_ context.Context, _ *unversioned.Provider, _ []string, _ *http.Client) (*externaldata.ProviderResponse, int, error) { return nil, http.StatusBadRequest, errors.New("error from SendRequestToProvider") }, errorExpected: true, @@ -695,7 +694,7 @@ func TestDriver_ExternalData(t *testing.T) { }, clientCertContent: clientCert, clientKeyContent: clientKey, - sendRequestToProvider: func(_ context.Context, _ *unversioned.Provider, _ []string, _ *tls.Certificate) (*externaldata.ProviderResponse, int, error) { + sendRequestToProvider: func(_ context.Context, _ *unversioned.Provider, _ []string, _ *http.Client) (*externaldata.ProviderResponse, int, error) { return &externaldata.ProviderResponse{ APIVersion: "v1beta1", Kind: "Provider", diff --git a/constraint/pkg/externaldata/cache.go b/constraint/pkg/externaldata/cache.go index c6c996073..b1cedfe04 100644 --- a/constraint/pkg/externaldata/cache.go +++ b/constraint/pkg/externaldata/cache.go @@ -14,6 +14,10 @@ import ( "k8s.io/apimachinery/pkg/util/wait" ) +const ( + HTTPSScheme = "https" +) + type ProviderCache struct { cache map[string]unversioned.Provider mux sync.RWMutex diff --git a/constraint/pkg/externaldata/request.go b/constraint/pkg/externaldata/request.go index f30f9967b..4306b3204 100644 --- a/constraint/pkg/externaldata/request.go +++ b/constraint/pkg/externaldata/request.go @@ -3,24 +3,15 @@ package externaldata import ( "bytes" "context" - "crypto/tls" - "crypto/x509" - "encoding/base64" "encoding/json" "fmt" "io" "net/http" - "net/url" "time" "github.com/open-policy-agent/frameworks/constraint/pkg/apis/externaldata/unversioned" ) -const ( - HTTPScheme = "http" - HTTPSScheme = "https" -) - // RegoRequest is the request for external_data rego function. type RegoRequest struct { // ProviderName is the name of the external data provider. @@ -57,17 +48,16 @@ func NewProviderRequest(keys []string) *ProviderRequest { } // SendRequestToProvider is a function that sends a request to the external data provider. -type SendRequestToProvider func(ctx context.Context, provider *unversioned.Provider, keys []string, clientCert *tls.Certificate) (*ProviderResponse, int, error) +type SendRequestToProvider func(ctx context.Context, provider *unversioned.Provider, keys []string, client *http.Client) (*ProviderResponse, int, error) // DefaultSendRequestToProvider is the default function to send the request to the external data provider. -func DefaultSendRequestToProvider(ctx context.Context, provider *unversioned.Provider, keys []string, clientCert *tls.Certificate) (*ProviderResponse, int, error) { +func DefaultSendRequestToProvider(ctx context.Context, provider *unversioned.Provider, keys []string, client *http.Client) (*ProviderResponse, int, error) { externaldataRequest := NewProviderRequest(keys) body, err := json.Marshal(externaldataRequest) if err != nil { return nil, http.StatusInternalServerError, fmt.Errorf("failed to marshal external data request: %w", err) } - client, err := getClient(provider, clientCert) if err != nil { return nil, http.StatusInternalServerError, fmt.Errorf("failed to get HTTP client: %w", err) } @@ -100,50 +90,6 @@ func DefaultSendRequestToProvider(ctx context.Context, provider *unversioned.Pro return &externaldataResponse, resp.StatusCode, nil } -// getClient returns a new HTTP client, and set up its TLS configuration. -func getClient(provider *unversioned.Provider, clientCert *tls.Certificate) (*http.Client, error) { - u, err := url.Parse(provider.Spec.URL) - if err != nil { - return nil, fmt.Errorf("failed to parse provider URL %s: %w", provider.Spec.URL, err) - } - - if u.Scheme != HTTPSScheme { - return nil, fmt.Errorf("only HTTPS scheme is supported") - } - - client := &http.Client{ - Timeout: time.Duration(provider.Spec.Timeout) * time.Second, - } - - tlsConfig := &tls.Config{MinVersion: tls.VersionTLS13} - - // present our client cert to the server - // in case provider wants to verify it - if clientCert != nil { - tlsConfig.Certificates = []tls.Certificate{*clientCert} - } - - // if the provider presents its own CA bundle, - // we will use it to verify the server's certificate - caBundleData, err := base64.StdEncoding.DecodeString(provider.Spec.CABundle) - if err != nil { - return nil, fmt.Errorf("failed to decode CA bundle: %w", err) - } - - providerCertPool := x509.NewCertPool() - if ok := providerCertPool.AppendCertsFromPEM(caBundleData); !ok { - return nil, fmt.Errorf("failed to append provider's CA bundle to certificate pool") - } - - tlsConfig.RootCAs = providerCertPool - - client.Transport = &http.Transport{ - TLSClientConfig: tlsConfig, - } - - return client, nil -} - // ProviderKind strings are special string constants for Providers. // +kubebuilder:validation:Enum=ProviderRequestKind;ProviderResponseKind type ProviderKind string diff --git a/constraint/pkg/externaldata/request_test.go b/constraint/pkg/externaldata/request_test.go index 6c9656eba..f267fbd55 100644 --- a/constraint/pkg/externaldata/request_test.go +++ b/constraint/pkg/externaldata/request_test.go @@ -1,11 +1,8 @@ package externaldata import ( - "crypto/tls" "reflect" "testing" - - "github.com/open-policy-agent/frameworks/constraint/pkg/apis/externaldata/unversioned" ) func TestNewProviderRequest(t *testing.T) { @@ -65,75 +62,3 @@ func TestNewProviderRequest(t *testing.T) { }) } } - -func Test_getClient(t *testing.T) { - type args struct { - provider *unversioned.Provider - clientCert *tls.Certificate - } - tests := []struct { - name string - args args - wantErr bool - }{ - { - name: "invalid http url", - args: args{ - provider: &unversioned.Provider{ - Spec: unversioned.ProviderSpec{ - URL: "http://foo", - }, - }, - clientCert: nil, - }, - wantErr: true, - }, - { - name: "no CA bundle", - args: args{ - provider: &unversioned.Provider{ - Spec: unversioned.ProviderSpec{ - URL: "https://foo", - }, - }, - clientCert: nil, - }, - wantErr: true, - }, - { - name: "invalid CA bundle", - args: args{ - provider: &unversioned.Provider{ - Spec: unversioned.ProviderSpec{ - URL: "https://foo", - CABundle: badCABundle, - }, - }, - clientCert: nil, - }, - wantErr: true, - }, - { - name: "valid CA bundle", - args: args{ - provider: &unversioned.Provider{ - Spec: unversioned.ProviderSpec{ - URL: "https://foo", - CABundle: validCABundle, - }, - }, - clientCert: nil, - }, - wantErr: false, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - _, err := getClient(tt.args.provider, tt.args.clientCert) - if (err != nil) != tt.wantErr { - t.Errorf("getClient() error = %v, wantErr %v", err, tt.wantErr) - return - } - }) - } -}