Skip to content

Commit df90001

Browse files
committed
Determine cliqueID from NVML not node label
Signed-off-by: Kevin Klues <[email protected]>
1 parent 222df11 commit df90001

File tree

2 files changed

+80
-4
lines changed

2 files changed

+80
-4
lines changed

cmd/compute-domain-kubelet-plugin/device_state.go

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ import (
2323
"sync"
2424

2525
resourceapi "k8s.io/api/resource/v1beta1"
26-
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
2726
"k8s.io/apimachinery/pkg/runtime"
2827
"k8s.io/klog/v2"
2928
drapbv1 "k8s.io/kubelet/pkg/apis/dra/v1beta1"
@@ -89,12 +88,11 @@ func NewDeviceState(ctx context.Context, config *Config) (*DeviceState, error) {
8988
return nil, fmt.Errorf("unable to create CDI handler: %w", err)
9089
}
9190

92-
node, err := config.clientsets.Core.CoreV1().Nodes().Get(ctx, config.flags.nodeName, metav1.GetOptions{})
91+
cliqueID, err := nvdevlib.getCliqueID()
9392
if err != nil {
94-
return nil, fmt.Errorf("error getting Node: %w", err)
93+
return nil, fmt.Errorf("error getting cliqueID: %w", err)
9594
}
9695

97-
cliqueID := node.Labels[CliqueIDLabelKey]
9896
computeDomainManager := NewComputeDomainManager(config, ComputeDomainDaemonSettingsRoot, cliqueID)
9997

10098
if err := cdi.CreateStandardDeviceSpecFile(allocatable); err != nil {

cmd/compute-domain-kubelet-plugin/nvlib.go

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,11 @@ import (
2727

2828
"golang.org/x/sys/unix"
2929

30+
"k8s.io/klog/v2"
3031
"k8s.io/mount-utils"
3132

33+
"github.com/google/uuid"
34+
3235
nvdev "github.com/NVIDIA/go-nvlib/pkg/nvlib/device"
3336
"github.com/NVIDIA/go-nvml/pkg/nvml"
3437
)
@@ -88,6 +91,21 @@ func newDeviceLib(driverRoot root) (*deviceLib, error) {
8891
return &d, nil
8992
}
9093

94+
func (l deviceLib) init() error {
95+
ret := l.nvmllib.Init()
96+
if ret != nvml.SUCCESS {
97+
return fmt.Errorf("error initializing NVML: %v", ret)
98+
}
99+
return nil
100+
}
101+
102+
func (l deviceLib) alwaysShutdown() {
103+
ret := l.nvmllib.Shutdown()
104+
if ret != nvml.SUCCESS {
105+
klog.Warningf("error shutting down NVML: %v", ret)
106+
}
107+
}
108+
91109
func (l deviceLib) enumerateAllPossibleDevices(config *Config) (AllocatableDevices, error) {
92110
alldevices := make(AllocatableDevices)
93111

@@ -142,6 +160,66 @@ func (l deviceLib) enumerateComputeDomainDaemons(config *Config) (AllocatableDev
142160
return devices, nil
143161
}
144162

163+
func (l deviceLib) getCliqueID() (string, error) {
164+
if err := l.init(); err != nil {
165+
return "", fmt.Errorf("error initializing deviceLib: %w", err)
166+
}
167+
defer l.alwaysShutdown()
168+
169+
uniqueClusterUUIDs := make(map[string]struct{})
170+
uniqueCliqueIDs := make(map[string]struct{})
171+
172+
err := l.VisitDevices(func(i int, d nvdev.Device) error {
173+
isFabricAttached, err := d.IsFabricAttached()
174+
if err != nil {
175+
return fmt.Errorf("error checking if device is fabric attached: %w", err)
176+
}
177+
if !isFabricAttached {
178+
return nil
179+
}
180+
181+
info, ret := d.GetGpuFabricInfo()
182+
if ret != nvml.SUCCESS {
183+
return fmt.Errorf("failed to get GPU fabric info: %w", ret)
184+
}
185+
186+
clusterUUID, err := uuid.FromBytes(info.ClusterUuid[:])
187+
if err != nil {
188+
return fmt.Errorf("invalid cluster UUID: %w", err)
189+
}
190+
191+
cliqueID := fmt.Sprintf("%d", info.CliqueId)
192+
193+
uniqueClusterUUIDs[clusterUUID.String()] = struct{}{}
194+
uniqueCliqueIDs[cliqueID] = struct{}{}
195+
196+
return nil
197+
})
198+
if err != nil {
199+
return "", fmt.Errorf("error getting fabric information from one or more devices: %w", err)
200+
}
201+
202+
if len(uniqueClusterUUIDs) == 0 && len(uniqueCliqueIDs) == 0 {
203+
return "", nil
204+
}
205+
206+
if len(uniqueClusterUUIDs) != 1 {
207+
return "", fmt.Errorf("unexpected number of unique ClusterUUIDs found on devices")
208+
}
209+
210+
if len(uniqueCliqueIDs) != 1 {
211+
return "", fmt.Errorf("unexpected number of unique CliqueIDs found on devices")
212+
}
213+
214+
for clusterUUID := range uniqueClusterUUIDs {
215+
for cliqueID := range uniqueCliqueIDs {
216+
return fmt.Sprintf("%s.%s", clusterUUID, cliqueID), nil
217+
}
218+
}
219+
220+
return "", fmt.Errorf("unexpected return")
221+
}
222+
145223
func (l deviceLib) getImexChannelCount() (int, error) {
146224
// TODO: Pull this value from /proc/driver/nvidia/params
147225
return 2048, nil

0 commit comments

Comments
 (0)