Skip to content

Commit 9aa88ea

Browse files
Add --disable-hook flag to cdi generate command
When running the nvidia-ctk cdi generate command, a user should be able to opt out of specific hooks. We propose to add a flag --disable-hook that will take a comma-separated list of hooks that will be skipped when creating the CDI spec. Signed-off-by: Carlos Eduardo Arango Gutierrez <[email protected]>
1 parent aa696ef commit 9aa88ea

File tree

7 files changed

+237
-49
lines changed

7 files changed

+237
-49
lines changed

cmd/nvidia-ctk/cdi/generate/generate.go

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ type options struct {
5757

5858
configSearchPaths cli.StringSlice
5959
librarySearchPaths cli.StringSlice
60+
disabledHooks cli.StringSlice
6061

6162
csv struct {
6263
files cli.StringSlice
@@ -176,6 +177,13 @@ func (m command) build() *cli.Command {
176177
Usage: "Specify a pattern the CSV mount specifications.",
177178
Destination: &opts.csv.ignorePatterns,
178179
},
180+
&cli.StringSliceFlag{
181+
Name: "disable-hook",
182+
Aliases: []string{"disable-hooks"},
183+
Usage: "Hook to skip when generating the CDI specification. Can be specified multiple times. Can be a comma-separated list of hooks or a single hook name.",
184+
Value: cli.NewStringSlice(),
185+
Destination: &opts.disabledHooks,
186+
},
179187
}
180188

181189
return &c
@@ -262,7 +270,7 @@ func (m command) generateSpec(opts *options) (spec.Interface, error) {
262270
deviceNamers = append(deviceNamers, deviceNamer)
263271
}
264272

265-
cdilib, err := nvcdi.New(
273+
initOpts := []nvcdi.Option{
266274
nvcdi.WithLogger(m.logger),
267275
nvcdi.WithDriverRoot(opts.driverRoot),
268276
nvcdi.WithDevRoot(opts.devRoot),
@@ -276,7 +284,13 @@ func (m command) generateSpec(opts *options) (spec.Interface, error) {
276284
nvcdi.WithCSVIgnorePatterns(opts.csv.ignorePatterns.Value()),
277285
// We set the following to allow for dependency injection:
278286
nvcdi.WithNvmlLib(opts.nvmllib),
279-
)
287+
}
288+
289+
for _, hook := range opts.disabledHooks.Value() {
290+
initOpts = append(initOpts, nvcdi.WithDisabledHook(hook))
291+
}
292+
293+
cdilib, err := nvcdi.New(initOpts...)
280294
if err != nil {
281295
return nil, fmt.Errorf("failed to create CDI library: %v", err)
282296
}

cmd/nvidia-ctk/cdi/generate/generate_test.go

Lines changed: 177 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ import (
2626
"github.com/NVIDIA/go-nvml/pkg/nvml/mock/dgxa100"
2727
testlog "github.com/sirupsen/logrus/hooks/test"
2828
"github.com/stretchr/testify/require"
29+
"github.com/urfave/cli/v2"
2930

3031
"github.com/NVIDIA/nvidia-container-toolkit/internal/test"
3132
)
@@ -36,6 +37,9 @@ func TestGenerateSpec(t *testing.T) {
3637
require.NoError(t, err)
3738

3839
driverRoot := filepath.Join(moduleRoot, "testdata", "lookup", "rootfs-1")
40+
disableHook1 := cli.NewStringSlice("enable-cuda-compat")
41+
disableHook2 := cli.NewStringSlice("enable-cuda-compat", "update-ldcache")
42+
disableHook3 := cli.NewStringSlice("all")
3943

4044
logger, _ := testlog.NewNullLogger()
4145
testCases := []struct {
@@ -113,6 +117,179 @@ containerEdits:
113117
- nodev
114118
- rbind
115119
- rprivate
120+
`,
121+
},
122+
{
123+
description: "disableHooks1",
124+
options: options{
125+
format: "yaml",
126+
mode: "nvml",
127+
vendor: "example.com",
128+
class: "device",
129+
driverRoot: driverRoot,
130+
disabledHooks: *disableHook1,
131+
},
132+
expectedOptions: options{
133+
format: "yaml",
134+
mode: "nvml",
135+
vendor: "example.com",
136+
class: "device",
137+
nvidiaCDIHookPath: "/usr/bin/nvidia-cdi-hook",
138+
driverRoot: driverRoot,
139+
disabledHooks: *disableHook1,
140+
},
141+
expectedSpec: `---
142+
cdiVersion: 0.5.0
143+
kind: example.com/device
144+
devices:
145+
- name: "0"
146+
containerEdits:
147+
deviceNodes:
148+
- path: /dev/nvidia0
149+
hostPath: {{ .driverRoot }}/dev/nvidia0
150+
- name: all
151+
containerEdits:
152+
deviceNodes:
153+
- path: /dev/nvidia0
154+
hostPath: {{ .driverRoot }}/dev/nvidia0
155+
containerEdits:
156+
env:
157+
- NVIDIA_VISIBLE_DEVICES=void
158+
deviceNodes:
159+
- path: /dev/nvidiactl
160+
hostPath: {{ .driverRoot }}/dev/nvidiactl
161+
hooks:
162+
- hookName: createContainer
163+
path: /usr/bin/nvidia-cdi-hook
164+
args:
165+
- nvidia-cdi-hook
166+
- create-symlinks
167+
- --link
168+
- libcuda.so.1::/lib/x86_64-linux-gnu/libcuda.so
169+
- hookName: createContainer
170+
path: /usr/bin/nvidia-cdi-hook
171+
args:
172+
- nvidia-cdi-hook
173+
- update-ldcache
174+
- --folder
175+
- /lib/x86_64-linux-gnu
176+
mounts:
177+
- hostPath: {{ .driverRoot }}/lib/x86_64-linux-gnu/libcuda.so.999.88.77
178+
containerPath: /lib/x86_64-linux-gnu/libcuda.so.999.88.77
179+
options:
180+
- ro
181+
- nosuid
182+
- nodev
183+
- rbind
184+
- rprivate
185+
`,
186+
},
187+
{
188+
description: "disableHooks2",
189+
options: options{
190+
format: "yaml",
191+
mode: "nvml",
192+
vendor: "example.com",
193+
class: "device",
194+
driverRoot: driverRoot,
195+
disabledHooks: *disableHook2,
196+
},
197+
expectedOptions: options{
198+
format: "yaml",
199+
mode: "nvml",
200+
vendor: "example.com",
201+
class: "device",
202+
nvidiaCDIHookPath: "/usr/bin/nvidia-cdi-hook",
203+
driverRoot: driverRoot,
204+
disabledHooks: *disableHook2,
205+
},
206+
expectedSpec: `---
207+
cdiVersion: 0.5.0
208+
kind: example.com/device
209+
devices:
210+
- name: "0"
211+
containerEdits:
212+
deviceNodes:
213+
- path: /dev/nvidia0
214+
hostPath: {{ .driverRoot }}/dev/nvidia0
215+
- name: all
216+
containerEdits:
217+
deviceNodes:
218+
- path: /dev/nvidia0
219+
hostPath: {{ .driverRoot }}/dev/nvidia0
220+
containerEdits:
221+
env:
222+
- NVIDIA_VISIBLE_DEVICES=void
223+
deviceNodes:
224+
- path: /dev/nvidiactl
225+
hostPath: {{ .driverRoot }}/dev/nvidiactl
226+
hooks:
227+
- hookName: createContainer
228+
path: /usr/bin/nvidia-cdi-hook
229+
args:
230+
- nvidia-cdi-hook
231+
- create-symlinks
232+
- --link
233+
- libcuda.so.1::/lib/x86_64-linux-gnu/libcuda.so
234+
mounts:
235+
- hostPath: {{ .driverRoot }}/lib/x86_64-linux-gnu/libcuda.so.999.88.77
236+
containerPath: /lib/x86_64-linux-gnu/libcuda.so.999.88.77
237+
options:
238+
- ro
239+
- nosuid
240+
- nodev
241+
- rbind
242+
- rprivate
243+
`,
244+
},
245+
{
246+
description: "disableHooksAll",
247+
options: options{
248+
format: "yaml",
249+
mode: "nvml",
250+
vendor: "example.com",
251+
class: "device",
252+
driverRoot: driverRoot,
253+
disabledHooks: *disableHook3,
254+
},
255+
expectedOptions: options{
256+
format: "yaml",
257+
mode: "nvml",
258+
vendor: "example.com",
259+
class: "device",
260+
nvidiaCDIHookPath: "/usr/bin/nvidia-cdi-hook",
261+
driverRoot: driverRoot,
262+
disabledHooks: *disableHook3,
263+
},
264+
expectedSpec: `---
265+
cdiVersion: 0.5.0
266+
kind: example.com/device
267+
devices:
268+
- name: "0"
269+
containerEdits:
270+
deviceNodes:
271+
- path: /dev/nvidia0
272+
hostPath: {{ .driverRoot }}/dev/nvidia0
273+
- name: all
274+
containerEdits:
275+
deviceNodes:
276+
- path: /dev/nvidia0
277+
hostPath: {{ .driverRoot }}/dev/nvidia0
278+
containerEdits:
279+
env:
280+
- NVIDIA_VISIBLE_DEVICES=void
281+
deviceNodes:
282+
- path: /dev/nvidiactl
283+
hostPath: {{ .driverRoot }}/dev/nvidiactl
284+
mounts:
285+
- hostPath: {{ .driverRoot }}/lib/x86_64-linux-gnu/libcuda.so.999.88.77
286+
containerPath: /lib/x86_64-linux-gnu/libcuda.so.999.88.77
287+
options:
288+
- ro
289+
- nosuid
290+
- nodev
291+
- rbind
292+
- rprivate
116293
`,
117294
},
118295
}

internal/discover/hooks.go

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,16 @@ func (h *Hook) Hooks() ([]Hook, error) {
4646

4747
type HookName string
4848

49-
// DisabledHooks allows individual hooks to be disabled.
50-
type DisabledHooks map[HookName]bool
49+
// disabledHooks allows individual hooks to be disabled.
50+
type disabledHooks map[HookName]bool
51+
52+
// isDisabled checks if a hook is disabled.
53+
func (d disabledHooks) isDisabled(h HookName) bool {
54+
if d["all"] {
55+
return true
56+
}
57+
return d[h]
58+
}
5159

5260
const (
5361
// HookEnableCudaCompat refers to the hook used to enable CUDA Forward Compatibility.
@@ -67,26 +75,43 @@ var AllHooks = []HookName{
6775
HookUpdateLDCache,
6876
}
6977

70-
// Option is a function that configures the nvcdilib
7178
// Option is a functional option that configures a CDIHook.
type Option func(*CDIHook)
7279

7380
// CDIHook holds the settings used when creating hooks: the path stored in
// nvidiaCDIHookPath and the set of hook names that have been disabled.
type CDIHook struct {
7481
nvidiaCDIHookPath string
82+
disabledHooks disabledHooks
7583
}
7684

7785
// HookCreator creates a Hook for the named hook and its arguments.
// Create may return nil; callers should check the result before using it.
type HookCreator interface {
7886
Create(HookName, ...string) *Hook
7987
}
8088

81-
func NewHookCreator(nvidiaCDIHookPath string) HookCreator {
89+
func WithDisabledHooks(hooks ...HookName) Option {
90+
return func(c *CDIHook) {
91+
for _, hook := range hooks {
92+
c.disabledHooks[hook] = true
93+
}
94+
}
95+
}
96+
97+
func NewHookCreator(nvidiaCDIHookPath string, opts ...Option) HookCreator {
8298
CDIHook := &CDIHook{
8399
nvidiaCDIHookPath: nvidiaCDIHookPath,
100+
disabledHooks: disabledHooks{},
101+
}
102+
103+
for _, opt := range opts {
104+
opt(CDIHook)
84105
}
85106

86107
return CDIHook
87108
}
88109

89110
func (c CDIHook) Create(name HookName, args ...string) *Hook {
111+
if c.disabledHooks[name] {
112+
return nil
113+
}
114+
90115
if name == "create-symlinks" {
91116
if len(args) == 0 {
92117
return nil

pkg/nvcdi/driver-nvml.go

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -106,11 +106,9 @@ func (l *nvcdilib) NewDriverLibraryDiscoverer(version string) (discover.Discover
106106
)
107107
discoverers = append(discoverers, driverDotSoSymlinksDiscoverer)
108108

109-
if l.HookIsSupported(HookEnableCudaCompat) {
110-
// TODO: The following should use the version directly.
111-
cudaCompatLibHookDiscoverer := discover.NewCUDACompatHookDiscoverer(l.logger, l.hookCreator, l.driver)
112-
discoverers = append(discoverers, cudaCompatLibHookDiscoverer)
113-
}
109+
// TODO: The following should use the version directly.
110+
cudaCompatLibHookDiscoverer := discover.NewCUDACompatHookDiscoverer(l.logger, l.hookCreator, l.driver)
111+
discoverers = append(discoverers, cudaCompatLibHookDiscoverer)
114112

115113
updateLDCache, _ := discover.NewLDCacheUpdateHook(l.logger, libraries, l.hookCreator, l.ldconfigPath)
116114
discoverers = append(discoverers, updateLDCache)

pkg/nvcdi/hooks.go

Lines changed: 0 additions & 27 deletions
This file was deleted.

pkg/nvcdi/lib.go

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -56,15 +56,13 @@ type nvcdilib struct {
5656

5757
mergedDeviceOptions []transform.MergedDeviceOption
5858

59-
disabledHooks discover.DisabledHooks
59+
disabledHooks []discover.HookName
6060
hookCreator discover.HookCreator
6161
}
6262

6363
// New creates a new nvcdi library
6464
func New(opts ...Option) (Interface, error) {
65-
l := &nvcdilib{
66-
disabledHooks: make(discover.DisabledHooks),
67-
}
65+
l := &nvcdilib{}
6866
for _, opt := range opts {
6967
opt(l)
7068
}
@@ -81,8 +79,6 @@ func New(opts ...Option) (Interface, error) {
8179
if l.nvidiaCDIHookPath == "" {
8280
l.nvidiaCDIHookPath = "/usr/bin/nvidia-cdi-hook"
8381
}
84-
// create hookCreator
85-
l.hookCreator = discover.NewHookCreator(l.nvidiaCDIHookPath)
8682

8783
if l.driverRoot == "" {
8884
l.driverRoot = "/"
@@ -150,7 +146,7 @@ func New(opts ...Option) (Interface, error) {
150146
l.vendor = "management.nvidia.com"
151147
}
152148
// Management containers in general do not require CUDA Forward compatibility.
153-
l.disabledHooks[HookEnableCudaCompat] = true
149+
l.disabledHooks = append(l.disabledHooks, HookEnableCudaCompat)
154150
lib = (*managementlib)(l)
155151
case ModeNvml:
156152
lib = (*nvmllib)(l)
@@ -175,6 +171,9 @@ func New(opts ...Option) (Interface, error) {
175171
return nil, fmt.Errorf("unknown mode %q", l.mode)
176172
}
177173

174+
// create hookCreator
175+
l.hookCreator = discover.NewHookCreator(l.nvidiaCDIHookPath, discover.WithDisabledHooks(l.disabledHooks...))
176+
178177
w := wrapper{
179178
Interface: lib,
180179
vendor: l.vendor,

0 commit comments

Comments
 (0)