diff --git a/internal/modifier/cdi.go b/internal/modifier/cdi.go index c7cd8c945..abdd88424 100644 --- a/internal/modifier/cdi.go +++ b/internal/modifier/cdi.go @@ -166,38 +166,24 @@ func filterAutomaticDevices(devices []string) []string { func newAutomaticCDISpecModifier(logger logger.Interface, cfg *config.Config, devices []string) (oci.SpecModifier, error) { logger.Debugf("Generating in-memory CDI specs for devices %v", devices) - perModeIdentifiers := make(map[string][]string) - perModeDeviceClass := map[string]string{"auto": automaticDeviceClass} - uniqueModes := []string{"auto"} - seen := make(map[string]bool) - for _, device := range devices { - mode, id := getModeIdentifier(device) - logger.Debugf("Mapped %v to %v: %v", device, mode, id) - if !seen[mode] { - uniqueModes = append(uniqueModes, mode) - seen[mode] = true - } - if id != "" { - perModeIdentifiers[mode] = append(perModeIdentifiers[mode], id) - } - } + cdiModeIdentifiers := cdiModeIdentfiersFromDevices(devices...) - logger.Debugf("Per-mode identifiers: %v", perModeIdentifiers) + logger.Debugf("Per-mode identifiers: %v", cdiModeIdentifiers) var modifiers oci.SpecModifiers - for _, mode := range uniqueModes { + for _, mode := range cdiModeIdentifiers.modes { cdilib, err := nvcdi.New( nvcdi.WithLogger(logger), nvcdi.WithNVIDIACDIHookPath(cfg.NVIDIACTKConfig.Path), nvcdi.WithDriverRoot(cfg.NVIDIAContainerCLIConfig.Root), nvcdi.WithVendor(automaticDeviceVendor), - nvcdi.WithClass(perModeDeviceClass[mode]), + nvcdi.WithClass(cdiModeIdentifiers.deviceClassByMode[mode]), nvcdi.WithMode(mode), ) if err != nil { return nil, fmt.Errorf("failed to construct CDI library for mode %q: %w", mode, err) } - spec, err := cdilib.GetSpec(perModeIdentifiers[mode]...) + spec, err := cdilib.GetSpec(cdiModeIdentifiers.idsByMode[mode]...) if err != nil { return nil, fmt.Errorf("failed to generate CDI spec for mode %q: %w", mode, err) } @@ -216,6 +202,35 @@ func newAutomaticCDISpecModifier(logger logger.Interface, cfg *config.Config, de return modifiers, nil } +type cdiModeIdentifiers struct { + modes []string + idsByMode map[string][]string + deviceClassByMode map[string]string +} + +func cdiModeIdentfiersFromDevices(devices ...string) *cdiModeIdentifiers { + perModeIdentifiers := make(map[string][]string) + perModeDeviceClass := map[string]string{"auto": automaticDeviceClass} + var uniqueModes []string + seen := make(map[string]bool) + for _, device := range devices { + mode, id := getModeIdentifier(device) + if !seen[mode] { + uniqueModes = append(uniqueModes, mode) + seen[mode] = true + } + if id != "" { + perModeIdentifiers[mode] = append(perModeIdentifiers[mode], id) + } + } + + return &cdiModeIdentifiers{ + modes: uniqueModes, + idsByMode: perModeIdentifiers, + deviceClassByMode: perModeDeviceClass, + } +} + func getModeIdentifier(device string) (string, string) { if !strings.HasPrefix(device, "mode=") { return "auto", strings.TrimPrefix(device, automaticDevicePrefix) diff --git a/internal/modifier/cdi_test.go b/internal/modifier/cdi_test.go index 0163e4ce0..65d2ca37e 100644 --- a/internal/modifier/cdi_test.go +++ b/internal/modifier/cdi_test.go @@ -170,3 +170,86 @@ func TestDeviceRequests(t *testing.T) { }) } } + +func Test_cdiModeIdentfiersFromDevices(t *testing.T) { + testCases := []struct { + description string + devices []string + expected *cdiModeIdentifiers + }{ + { + description: "empty device list", + devices: []string{}, + expected: &cdiModeIdentifiers{ + modes: nil, + idsByMode: map[string][]string{}, + deviceClassByMode: map[string]string{"auto": "gpu"}, + }, + }, + { + description: "single automatic device", + devices: []string{"0"}, + expected: &cdiModeIdentifiers{ + modes: []string{"auto"}, + idsByMode: map[string][]string{"auto": {"0"}}, + deviceClassByMode: map[string]string{"auto": "gpu"}, + }, + }, + { + description: "multiple automatic devices", + devices: []string{"0", "1"}, + expected: &cdiModeIdentifiers{ + modes: []string{"auto"}, + idsByMode: map[string][]string{"auto": {"0", "1"}}, + deviceClassByMode: map[string]string{"auto": "gpu"}, + }, + }, + { + description: "device with explicit mode", + devices: []string{"mode=gds,id=foo"}, + expected: &cdiModeIdentifiers{ + modes: []string{"gds"}, + idsByMode: map[string][]string{"gds": {"foo"}}, + deviceClassByMode: map[string]string{"auto": "gpu"}, + }, + }, + { + description: "mixed auto and explicit", + devices: []string{"0", "mode=gds,id=foo", "mode=gdrcopy,id=bar"}, + expected: &cdiModeIdentifiers{ + modes: []string{"auto", "gds", "gdrcopy"}, + idsByMode: map[string][]string{ + "auto": {"0"}, + "gds": {"foo"}, + "gdrcopy": {"bar"}, + }, + deviceClassByMode: map[string]string{"auto": "gpu"}, + }, + }, + { + description: "device with only mode, no id", + devices: []string{"mode=nvswitch"}, + expected: &cdiModeIdentifiers{ + modes: []string{"nvswitch"}, + idsByMode: map[string][]string{}, + deviceClassByMode: map[string]string{"auto": "gpu"}, + }, + }, + { + description: "duplicate modes", + devices: []string{"mode=gds,id=x", "mode=gds,id=y", "mode=gds"}, + expected: &cdiModeIdentifiers{ + modes: []string{"gds"}, + idsByMode: map[string][]string{"gds": {"x", "y"}}, + deviceClassByMode: map[string]string{"auto": "gpu"}, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.description, func(t *testing.T) { + result := cdiModeIdentfiersFromDevices(tc.devices...) + require.EqualValues(t, tc.expected, result) + }) + } +} diff --git a/tests/e2e/nvidia-container-toolkit_test.go b/tests/e2e/nvidia-container-toolkit_test.go index b2487c436..80ff3663c 100644 --- a/tests/e2e/nvidia-container-toolkit_test.go +++ b/tests/e2e/nvidia-container-toolkit_test.go @@ -198,6 +198,16 @@ var _ = Describe("docker", Ordered, ContinueOnFailure, func() { Expect(err).ToNot(HaveOccurred()) Expect(ldconfigOut).To(ContainSubstring("/usr/local/cuda-12.9/compat/")) }) + + It("should create a single ld.so.conf.d config file", func(ctx context.Context) { + lsout, _, err := runner.Run("docker run --rm -i -e NVIDIA_DISABLE_REQUIRE=true --runtime=nvidia -e NVIDIA_VISIBLE_DEVICES=runtime.nvidia.com/gpu=all nvcr.io/nvidia/cuda:12.9.0-base-ubi8 bash -c \"ls -l /etc/ld.so.conf.d/00-compat-*.conf\"") + Expect(err).ToNot(HaveOccurred()) + Expect(lsout).To(WithTransform( + func(s string) []string { + return strings.Split(strings.TrimSpace(s), "\n") + }, HaveLen(1), + )) + }) }) When("Disabling device node creation", Ordered, func() { diff --git a/third_party/libnvidia-container b/third_party/libnvidia-container index 0964f8171..889a3bb54 160000 --- a/third_party/libnvidia-container +++ b/third_party/libnvidia-container @@ -1 +1 @@ -Subproject commit 0964f81717e96ac903e39700908677dcdf72ed5f +Subproject commit 889a3bb5408c195ed7897ba2cb8341c7d249672f