Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 51 additions & 10 deletions auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,28 +2,69 @@ package gocbcoreps

import (
"context"
"crypto/tls"
"encoding/base64"

"google.golang.org/grpc/credentials"
"sync/atomic"
)

type GrpcBasicAuth struct {
EncodedData string
type Authenticator interface {
isAuthenticator()
}

type BasicAuthenticator struct {
encodedData atomic.Pointer[string]
}

// NewJWTAccessFromKey creates PerRPCCredentials from the given jsonKey.
func NewGrpcBasicAuth(username, password string) (credentials.PerRPCCredentials, error) {
// NewBasicAuthenticator creates PerRPCCredentials from the given username and password.
func NewBasicAuthenticator(username, password string) *BasicAuthenticator {
basicAuth := username + ":" + password
authValue := base64.StdEncoding.EncodeToString([]byte(basicAuth))
return GrpcBasicAuth{authValue}, nil

auth := &BasicAuthenticator{}

auth.encodedData.Store(&authValue)

return auth
}

func (j GrpcBasicAuth) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) {
func (j *BasicAuthenticator) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) {
encodedData := j.encodedData.Load()
return map[string]string{
"authorization": "Basic " + j.EncodedData,
"authorization": "Basic " + *encodedData,
}, nil
}

func (j GrpcBasicAuth) RequireTransportSecurity() bool {
func (j *BasicAuthenticator) RequireTransportSecurity() bool {
return false
}

func (j *BasicAuthenticator) UpdateCredentials(username, password string) {
basicAuth := username + ":" + password
authValue := base64.StdEncoding.EncodeToString([]byte(basicAuth))

j.encodedData.Store(&authValue)
}

func (j *BasicAuthenticator) isAuthenticator() {}

type CertificateAuthenticator struct {
certificate atomic.Pointer[tls.Certificate]
}

func NewCertificateAuthenticator(cert *tls.Certificate) *CertificateAuthenticator {
auth := &CertificateAuthenticator{}
auth.certificate.Store(cert)
return auth
}

func (j *CertificateAuthenticator) GetClientCertificate(info *tls.CertificateRequestInfo) (*tls.Certificate, error) {
cert := j.certificate.Load()

return cert, nil
}

func (j *CertificateAuthenticator) UpdateCertificate(cert *tls.Certificate) {
j.certificate.Store(cert)
}

func (j *CertificateAuthenticator) isAuthenticator() {}
11 changes: 11 additions & 0 deletions error.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package gocbcoreps

import "errors"

var (
// ErrAuthenticatorMismatch is returned when there is a mismatch between authenticators.
ErrAuthenticatorMismatch = errors.New("authenticator mismatch")

// ErrAuthenticatorUnsupported is returned when a unsupported authenticator is specified.
ErrAuthenticatorUnsupported = errors.New("authenticator unsupported")
)
44 changes: 37 additions & 7 deletions routingclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package gocbcoreps

import (
"context"
"crypto/tls"
"crypto/x509"
"net"
"sync"
Expand Down Expand Up @@ -39,16 +38,15 @@ type RoutingClient struct {
lock sync.Mutex
buckets map[string]*routingClient_Bucket
logger *zap.Logger
auth Authenticator
}

// Verify that RoutingClient implements Conn
var _ Conn = (*RoutingClient)(nil)

type DialOptions struct {
RootCAs *x509.CertPool
Certificate *tls.Certificate
Username string
Password string
Authenticator Authenticator
Logger *zap.Logger
InsecureSkipVerify bool
PoolSize uint32
Expand Down Expand Up @@ -88,9 +86,7 @@ func DialContext(ctx context.Context, target string, opts *DialOptions) (*Routin
for i := uint32(0); i < poolSize; i++ {
conn, err := dialRoutingConn(ctx, target, &routingConnOptions{
RootCAs: opts.RootCAs,
Certificate: opts.Certificate,
Username: opts.Username,
Password: opts.Password,
Authenticator: opts.Authenticator,
InsecureSkipVerify: opts.InsecureSkipVerify,
TracerProvider: opts.TracerProvider,
MeterProvider: opts.MeterProvider,
Expand All @@ -111,6 +107,7 @@ func DialContext(ctx context.Context, target string, opts *DialOptions) (*Routin
routing: routing,
buckets: make(map[string]*routingClient_Bucket),
logger: logger,
auth: opts.Authenticator,
}, nil
}

Expand Down Expand Up @@ -166,6 +163,39 @@ func (c *RoutingClient) CloseBucket(bucketName string) {
c.lock.Unlock()
}

type ReconfigureAuthenticatorOptions struct {
Authenticator Authenticator
}

func (c *RoutingClient) ReconfigureAuthenticator(opts ReconfigureAuthenticatorOptions) error {
auth := opts.Authenticator
c.lock.Lock()
defer c.lock.Unlock()

switch a := c.auth.(type) {
case *BasicAuthenticator:
switch na := auth.(type) {
case *BasicAuthenticator:
data := na.encodedData.Load()
a.encodedData.Store(data)
default:
return ErrAuthenticatorMismatch
}
case *CertificateAuthenticator:
switch na := auth.(type) {
case *CertificateAuthenticator:
cert := na.certificate.Load()
a.certificate.Store(cert)
default:
return ErrAuthenticatorMismatch
}
default:
return ErrAuthenticatorUnsupported
}

return nil
}

func (c *RoutingClient) ConnectionState() ConnState {
r := c.routing.Load()

Expand Down
36 changes: 10 additions & 26 deletions routingconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@ import (
"context"
"crypto/tls"
"crypto/x509"
"errors"

"go.opentelemetry.io/otel/metric"
"go.opentelemetry.io/otel/propagation"
"go.opentelemetry.io/otel/trace"
Expand Down Expand Up @@ -34,9 +32,7 @@ import (
type routingConnOptions struct {
InsecureSkipVerify bool // used for enabling TLS, but skipping verification
RootCAs *x509.CertPool
Certificate *tls.Certificate
Username string
Password string
Authenticator Authenticator
TracerProvider trace.TracerProvider
MeterProvider metric.MeterProvider
}
Expand All @@ -62,25 +58,13 @@ const maxMsgSize = 26214400 // 25MiB

func dialRoutingConn(ctx context.Context, address string, opts *routingConnOptions) (*routingConn, error) {
var perRpcDialOpt grpc.DialOption
var getClientCertificate func(info *tls.CertificateRequestInfo) (*tls.Certificate, error)

// setup basic auth.
if opts.Username != "" && opts.Password != "" {
basicAuthCreds, err := NewGrpcBasicAuth(opts.Username, opts.Password)
if err != nil {
return nil, err
}
perRpcDialOpt = grpc.WithPerRPCCredentials(basicAuthCreds)
} else {
perRpcDialOpt = nil
}

var certificates []tls.Certificate
if opts.Certificate != nil {
if perRpcDialOpt != nil {
return nil, errors.New("cannot use basic credentials and client cert auth at the same time")
}

certificates = append(certificates, *opts.Certificate)
switch a := opts.Authenticator.(type) {
case *BasicAuthenticator:
perRpcDialOpt = grpc.WithPerRPCCredentials(a)
case *CertificateAuthenticator:
getClientCertificate = a.GetClientCertificate
}

pool, err := x509.SystemCertPool()
Expand All @@ -94,9 +78,9 @@ func dialRoutingConn(ctx context.Context, address string, opts *routingConnOptio

dialOpts := []grpc.DialOption{grpc.WithTransportCredentials(credentials.NewTLS(
&tls.Config{
InsecureSkipVerify: opts.InsecureSkipVerify,
RootCAs: pool,
Certificates: certificates,
InsecureSkipVerify: opts.InsecureSkipVerify,
RootCAs: pool,
GetClientCertificate: getClientCertificate,
},
))}
if perRpcDialOpt != nil {
Expand Down
Loading