diff --git a/client/go.mod b/client/go.mod index 205e08decc..a0d45ffc1c 100644 --- a/client/go.mod +++ b/client/go.mod @@ -16,6 +16,7 @@ require ( github.com/stretchr/testify v1.9.0 go.uber.org/goleak v1.1.11 go.uber.org/zap v1.24.0 + golang.org/x/sync v0.19.0 google.golang.org/grpc v1.75.1 google.golang.org/grpc/examples v0.0.0-20231221225426-4f03f3ff32c9 ) diff --git a/client/go.sum b/client/go.sum index c9c66d687d..1d1f2c955c 100644 --- a/client/go.sum +++ b/client/go.sum @@ -130,6 +130,8 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= +golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= diff --git a/client/servicediscovery/service_discovery.go b/client/servicediscovery/service_discovery.go index b6d2e16c25..9e463c8c5b 100644 --- a/client/servicediscovery/service_discovery.go +++ b/client/servicediscovery/service_discovery.go @@ -26,6 +26,7 @@ import ( "time" "go.uber.org/zap" + "golang.org/x/sync/singleflight" "google.golang.org/grpc" "google.golang.org/grpc/codes" healthpb "google.golang.org/grpc/health/grpc_health_v1" @@ -443,6 +444,8 @@ type serviceDiscovery struct { tlsCfg *tls.Config // Client option. option *opt.Option + + flight singleflight.Group } // NewDefaultServiceDiscovery returns a new default service discovery-based client. @@ -474,6 +477,7 @@ func NewServiceDiscovery( keyspaceID: keyspaceID, tlsCfg: tlsCfg, option: option, + flight: singleflight.Group{}, } pdsd.callbacks.setServiceModeUpdateCallback(serviceModeUpdateCb) urls = tlsutil.AddrsToURLs(urls, tlsCfg) @@ -919,12 +923,16 @@ func (c *serviceDiscovery) getClusterInfo(ctx context.Context, url string, timeo } start := time.Now() defer func() { metrics.InternalCmdDurationGetClusterInfo.Observe(time.Since(start).Seconds()) }() - clusterInfo, err := pdpb.NewPDClient(cc).GetClusterInfo(ctx, &pdpb.GetClusterInfoRequest{}) + key := "GetClusterInfo-" + url + res, err, _ := c.flight.Do(key, func() (any, error) { + return pdpb.NewPDClient(cc).GetClusterInfo(ctx, &pdpb.GetClusterInfoRequest{}) + }) if err != nil { metrics.InternalCmdFailedDurationGetClusterInfo.Observe(time.Since(start).Seconds()) attachErr := errors.Errorf("error:%s target:%s status:%s", err, cc.Target(), cc.GetState().String()) return nil, errs.ErrClientGetClusterInfo.Wrap(attachErr).GenWithStackByCause() } + clusterInfo := res.(*pdpb.GetClusterInfoResponse) if clusterInfo.GetHeader().GetError() != nil { metrics.InternalCmdFailedDurationGetClusterInfo.Observe(time.Since(start).Seconds()) attachErr := errors.Errorf("error:%s target:%s status:%s", clusterInfo.GetHeader().GetError().String(), cc.Target(), cc.GetState().String()) @@ -942,12 +950,16 @@ func (c *serviceDiscovery) getMembers(ctx context.Context, url string, timeout t } start := time.Now() defer func() { metrics.InternalCmdDurationGetMembers.Observe(time.Since(start).Seconds()) }() - members, err := pdpb.NewPDClient(cc).GetMembers(ctx, &pdpb.GetMembersRequest{}) + key := "GetMembers-" + url + res, err, _ := c.flight.Do(key, func() (any, error) { + return pdpb.NewPDClient(cc).GetMembers(ctx, &pdpb.GetMembersRequest{}) + }) if err != nil { metrics.InternalCmdFailedDurationGetMembers.Observe(time.Since(start).Seconds()) attachErr := errors.Errorf("error:%s target:%s status:%s", err, cc.Target(), cc.GetState().String()) return nil, errs.ErrClientGetMember.Wrap(attachErr).GenWithStackByCause() } + members := res.(*pdpb.GetMembersResponse) if members.GetHeader().GetError() != nil { metrics.InternalCmdFailedDurationGetMembers.Observe(time.Since(start).Seconds()) attachErr := errors.Errorf("error:%s target:%s status:%s", members.GetHeader().GetError().String(), cc.Target(), cc.GetState().String())