diff --git a/auth.go b/auth.go index 0b3cc71..e2f0300 100644 --- a/auth.go +++ b/auth.go @@ -2,28 +2,69 @@ package gocbcoreps import ( "context" + "crypto/tls" "encoding/base64" - - "google.golang.org/grpc/credentials" + "sync/atomic" ) -type GrpcBasicAuth struct { - EncodedData string +type Authenticator interface { + isAuthenticator() +} + +type BasicAuthenticator struct { + encodedData atomic.Pointer[string] } -// NewJWTAccessFromKey creates PerRPCCredentials from the given jsonKey. -func NewGrpcBasicAuth(username, password string) (credentials.PerRPCCredentials, error) { +// NewBasicAuthenticator creates PerRPCCredentials from the given username and password. +func NewBasicAuthenticator(username, password string) *BasicAuthenticator { basicAuth := username + ":" + password authValue := base64.StdEncoding.EncodeToString([]byte(basicAuth)) - return GrpcBasicAuth{authValue}, nil + + auth := &BasicAuthenticator{} + + auth.encodedData.Store(&authValue) + + return auth } -func (j GrpcBasicAuth) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) { +func (j *BasicAuthenticator) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) { + encodedData := j.encodedData.Load() return map[string]string{ - "authorization": "Basic " + j.EncodedData, + "authorization": "Basic " + *encodedData, }, nil } -func (j GrpcBasicAuth) RequireTransportSecurity() bool { +func (j *BasicAuthenticator) RequireTransportSecurity() bool { return false } + +func (j *BasicAuthenticator) UpdateCredentials(username, password string) { + basicAuth := username + ":" + password + authValue := base64.StdEncoding.EncodeToString([]byte(basicAuth)) + + j.encodedData.Store(&authValue) +} + +func (j *BasicAuthenticator) isAuthenticator() {} + +type CertificateAuthenticator struct { + certificate atomic.Pointer[tls.Certificate] +} + +func NewCertificateAuthenticator(cert *tls.Certificate) *CertificateAuthenticator { + auth := &CertificateAuthenticator{} + auth.certificate.Store(cert) + return auth +} + +func (j *CertificateAuthenticator) GetClientCertificate(info *tls.CertificateRequestInfo) (*tls.Certificate, error) { + cert := j.certificate.Load() + + return cert, nil +} + +func (j *CertificateAuthenticator) UpdateCertificate(cert *tls.Certificate) { + j.certificate.Store(cert) +} + +func (j *CertificateAuthenticator) isAuthenticator() {} diff --git a/error.go b/error.go new file mode 100644 index 0000000..2deff5e --- /dev/null +++ b/error.go @@ -0,0 +1,11 @@ +package gocbcoreps + +import "errors" + +var ( + // ErrAuthenticatorMismatch is returned when there is a mismatch between authenticators. + ErrAuthenticatorMismatch = errors.New("authenticator mismatch") + + // ErrAuthenticatorUnsupported is returned when a unsupported authenticator is specified. + ErrAuthenticatorUnsupported = errors.New("authenticator unsupported") +) diff --git a/routingclient.go b/routingclient.go index 58090fb..3e92194 100644 --- a/routingclient.go +++ b/routingclient.go @@ -2,7 +2,6 @@ package gocbcoreps import ( "context" - "crypto/tls" "crypto/x509" "net" "sync" @@ -39,6 +38,7 @@ type RoutingClient struct { lock sync.Mutex buckets map[string]*routingClient_Bucket logger *zap.Logger + auth Authenticator } // Verify that RoutingClient implements Conn @@ -46,9 +46,7 @@ var _ Conn = (*RoutingClient)(nil) type DialOptions struct { RootCAs *x509.CertPool - Certificate *tls.Certificate - Username string - Password string + Authenticator Authenticator Logger *zap.Logger InsecureSkipVerify bool PoolSize uint32 @@ -88,9 +86,7 @@ func DialContext(ctx context.Context, target string, opts *DialOptions) (*Routin for i := uint32(0); i < poolSize; i++ { conn, err := dialRoutingConn(ctx, target, &routingConnOptions{ RootCAs: opts.RootCAs, - Certificate: opts.Certificate, - Username: opts.Username, - Password: opts.Password, + Authenticator: opts.Authenticator, InsecureSkipVerify: opts.InsecureSkipVerify, TracerProvider: opts.TracerProvider, MeterProvider: opts.MeterProvider, @@ -111,6 +107,7 @@ func DialContext(ctx context.Context, target string, opts *DialOptions) (*Routin routing: routing, buckets: make(map[string]*routingClient_Bucket), logger: logger, + auth: opts.Authenticator, }, nil } @@ -166,6 +163,39 @@ func (c *RoutingClient) CloseBucket(bucketName string) { c.lock.Unlock() } +type ReconfigureAuthenticatorOptions struct { + Authenticator Authenticator +} + +func (c *RoutingClient) ReconfigureAuthenticator(opts ReconfigureAuthenticatorOptions) error { + auth := opts.Authenticator + c.lock.Lock() + defer c.lock.Unlock() + + switch a := c.auth.(type) { + case *BasicAuthenticator: + switch na := auth.(type) { + case *BasicAuthenticator: + data := na.encodedData.Load() + a.encodedData.Store(data) + default: + return ErrAuthenticatorMismatch + } + case *CertificateAuthenticator: + switch na := auth.(type) { + case *CertificateAuthenticator: + cert := na.certificate.Load() + a.certificate.Store(cert) + default: + return ErrAuthenticatorMismatch + } + default: + return ErrAuthenticatorUnsupported + } + + return nil +} + func (c *RoutingClient) ConnectionState() ConnState { r := c.routing.Load() diff --git a/routingconn.go b/routingconn.go index 237b9d7..2d16a3c 100644 --- a/routingconn.go +++ b/routingconn.go @@ -4,8 +4,6 @@ import ( "context" "crypto/tls" "crypto/x509" - "errors" - "go.opentelemetry.io/otel/metric" "go.opentelemetry.io/otel/propagation" "go.opentelemetry.io/otel/trace" @@ -34,9 +32,7 @@ import ( type routingConnOptions struct { InsecureSkipVerify bool // used for enabling TLS, but skipping verification RootCAs *x509.CertPool - Certificate *tls.Certificate - Username string - Password string + Authenticator Authenticator TracerProvider trace.TracerProvider MeterProvider metric.MeterProvider } @@ -62,25 +58,13 @@ const maxMsgSize = 26214400 // 25MiB func dialRoutingConn(ctx context.Context, address string, opts *routingConnOptions) (*routingConn, error) { var perRpcDialOpt grpc.DialOption + var getClientCertificate func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) - // setup basic auth. - if opts.Username != "" && opts.Password != "" { - basicAuthCreds, err := NewGrpcBasicAuth(opts.Username, opts.Password) - if err != nil { - return nil, err - } - perRpcDialOpt = grpc.WithPerRPCCredentials(basicAuthCreds) - } else { - perRpcDialOpt = nil - } - - var certificates []tls.Certificate - if opts.Certificate != nil { - if perRpcDialOpt != nil { - return nil, errors.New("cannot use basic credentials and client cert auth at the same time") - } - - certificates = append(certificates, *opts.Certificate) + switch a := opts.Authenticator.(type) { + case *BasicAuthenticator: + perRpcDialOpt = grpc.WithPerRPCCredentials(a) + case *CertificateAuthenticator: + getClientCertificate = a.GetClientCertificate } pool, err := x509.SystemCertPool() @@ -94,9 +78,9 @@ func dialRoutingConn(ctx context.Context, address string, opts *routingConnOptio dialOpts := []grpc.DialOption{grpc.WithTransportCredentials(credentials.NewTLS( &tls.Config{ - InsecureSkipVerify: opts.InsecureSkipVerify, - RootCAs: pool, - Certificates: certificates, + InsecureSkipVerify: opts.InsecureSkipVerify, + RootCAs: pool, + GetClientCertificate: getClientCertificate, }, ))} if perRpcDialOpt != nil {