Skip to content

Commit 8513316

Browse files
authored
Merge pull request #1267 from elezar/add-imex-channel-to-jit-cdi
Add imex channel to jit cdi
2 parents 798376b + 55c6859 commit 8513316

File tree

6 files changed

+88
-39
lines changed

6 files changed

+88
-39
lines changed

cmd/nvidia-container-runtime-hook/container_config.go

Lines changed: 2 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -148,30 +148,6 @@ func getMigDevices(image image.CUDA, envvar string) *string {
148148
return &devices
149149
}
150150

151-
func (hookConfig *hookConfig) getImexChannels(image image.CUDA, privileged bool) []string {
152-
if hookConfig.Features.IgnoreImexChannelRequests.IsEnabled() {
153-
return nil
154-
}
155-
156-
// If enabled, try and get the device list from volume mounts first
157-
if hookConfig.AcceptDeviceListAsVolumeMounts {
158-
devices := image.ImexChannelsFromMounts()
159-
if len(devices) > 0 {
160-
return devices
161-
}
162-
}
163-
devices := image.ImexChannelsFromEnvVar()
164-
if len(devices) == 0 {
165-
return nil
166-
}
167-
168-
if privileged || hookConfig.AcceptEnvvarUnprivileged {
169-
return devices
170-
}
171-
172-
return nil
173-
}
174-
175151
func (hookConfig *hookConfig) getDriverCapabilities(cudaImage image.CUDA, legacyImage bool) image.DriverCapabilities {
176152
// We use the default driver capabilities by default. This is filtered to only include the
177153
// supported capabilities
@@ -223,8 +199,6 @@ func (hookConfig *hookConfig) getNvidiaConfig(image image.CUDA, privileged bool)
223199
log.Panicln("cannot set MIG_MONITOR_DEVICES in non privileged container")
224200
}
225201

226-
imexChannels := hookConfig.getImexChannels(image, privileged)
227-
228202
driverCapabilities := hookConfig.getDriverCapabilities(image, legacyImage).String()
229203

230204
requirements, err := image.GetRequirements()
@@ -236,7 +210,7 @@ func (hookConfig *hookConfig) getNvidiaConfig(image image.CUDA, privileged bool)
236210
Devices: devices,
237211
MigConfigDevices: migConfigDevices,
238212
MigMonitorDevices: migMonitorDevices,
239-
ImexChannels: imexChannels,
213+
ImexChannels: image.ImexChannelRequests(),
240214
DriverCapabilities: driverCapabilities,
241215
Requirements: requirements,
242216
}
@@ -273,6 +247,7 @@ func (hookConfig *hookConfig) getContainerConfig() (config *containerConfig) {
273247
image.WithAcceptDeviceListAsVolumeMounts(hookConfig.AcceptDeviceListAsVolumeMounts),
274248
image.WithAcceptEnvvarUnprivileged(hookConfig.AcceptEnvvarUnprivileged),
275249
image.WithPreferredVisibleDevicesEnvVars(hookConfig.getSwarmResourceEnvvars()...),
250+
image.WithIgnoreImexChannelRequests(hookConfig.Features.IgnoreImexChannelRequests.IsEnabled()),
276251
)
277252
if err != nil {
278253
log.Panicln(err)

internal/config/image/builder.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,15 @@ func WithEnvMap(env map[string]string) Option {
127127
}
128128
}
129129

130+
// WithIgnoreImexChannelRequests sets whether per-container IMEX channel
131+
// requests are supported.
132+
func WithIgnoreImexChannelRequests(ignoreImexChannelRequests bool) Option {
133+
return func(b *builder) error {
134+
b.ignoreImexChannelRequests = ignoreImexChannelRequests
135+
return nil
136+
}
137+
}
138+
130139
// WithLogger sets the logger to use when creating the CUDA image.
131140
func WithLogger(logger logger.Interface) Option {
132141
return func(b *builder) error {

internal/config/image/cuda_image.go

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ type CUDA struct {
5151
annotationsPrefixes []string
5252
acceptDeviceListAsVolumeMounts bool
5353
acceptEnvvarUnprivileged bool
54+
ignoreImexChannelRequests bool
5455
preferredVisibleDeviceEnvVars []string
5556
}
5657

@@ -412,17 +413,51 @@ func (m cdiDeviceMountRequest) qualifiedName() (string, error) {
412413
return fmt.Sprintf("%s/%s=%s", parts[0], parts[1], parts[2]), nil
413414
}
414415

415-
// ImexChannelsFromEnvVar returns the list of IMEX channels requested for the image.
416-
func (i CUDA) ImexChannelsFromEnvVar() []string {
416+
func (i CUDA) ImexChannelRequests() []string {
417+
if i.ignoreImexChannelRequests {
418+
return nil
419+
}
420+
421+
// If enabled, try and get the device list from volume mounts first
422+
if i.acceptDeviceListAsVolumeMounts {
423+
volumeMountDeviceRequests := i.imexChannelsFromMounts()
424+
if len(volumeMountDeviceRequests) > 0 {
425+
return volumeMountDeviceRequests
426+
}
427+
}
428+
429+
// Get the Fallback to reading from the environment variable if privileges are correct
430+
envVarDeviceRequests := i.imexChannelsFromEnvVar()
431+
if len(envVarDeviceRequests) == 0 {
432+
return nil
433+
}
434+
435+
// If the container is privileged, or environment variable requests are
436+
// allowed for unprivileged containers, these devices are returned.
437+
if i.isPrivileged || i.acceptEnvvarUnprivileged {
438+
return envVarDeviceRequests
439+
}
440+
441+
// We log a warning if we are ignoring the environment variable requests.
442+
envVars := []string{EnvVarNvidiaImexChannels}
443+
if len(envVars) > 0 {
444+
i.logger.Warningf("Ignoring request by environment variable(s) in unprivileged container: %v", envVars)
445+
}
446+
447+
return nil
448+
}
449+
450+
// imexChannelsFromEnvVar returns the list of IMEX channels requested for the image.
451+
func (i CUDA) imexChannelsFromEnvVar() []string {
417452
imexChannels := i.devicesFromEnvvars(EnvVarNvidiaImexChannels)
418453
if len(imexChannels) == 1 && imexChannels[0] == "all" {
419454
return nil
420455
}
421456
return imexChannels
422457
}
423458

424-
// ImexChannelsFromMounts returns the list of IMEX channels requested for the image.
425-
func (i CUDA) ImexChannelsFromMounts() []string {
459+
// imexChannelsFromMounts returns the list of IMEX channels requested for the image.
460+
func (i CUDA) imexChannelsFromMounts() []string {
426461
var channels []string
427462
for _, mountDevice := range i.requestsFromMounts() {
428463
if !strings.HasPrefix(mountDevice, volumeMountDevicePrefixImex) {

internal/config/image/cuda_image_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -716,7 +716,7 @@ func TestImexChannelsFromEnvVar(t *testing.T) {
716716
i, err := newCUDAImageFromEnv(append(baseEnvvars, tc.env...))
717717
require.NoError(t, err)
718718

719-
channels := i.ImexChannelsFromEnvVar()
719+
channels := i.imexChannelsFromEnvVar()
720720
require.EqualValues(t, tc.expected, channels)
721721
})
722722
}

internal/modifier/cdi.go

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,9 @@ func NewCDIModifier(logger logger.Interface, cfg *config.Config, image image.CUD
6262
return nil, fmt.Errorf("requesting a CDI device with vendor 'runtime.nvidia.com' is not supported when requesting other CDI devices")
6363
}
6464
if len(automaticDevices) > 0 {
65-
automaticDevices = append(automaticDevices, gatedDevices(image).DeviceRequests()...)
65+
automaticDevices = append(automaticDevices, withUniqueDevices(gatedDevices(image)).DeviceRequests()...)
66+
automaticDevices = append(automaticDevices, withUniqueDevices(imexDevices(image)).DeviceRequests()...)
67+
6668
automaticModifier, err := newAutomaticCDISpecModifier(logger, cfg, automaticDevices)
6769
if err == nil {
6870
return automaticModifier, nil
@@ -135,6 +137,17 @@ func (g gatedDevices) DeviceRequests() []string {
135137
return devices
136138
}
137139

140+
type imexDevices image.CUDA
141+
142+
func (d imexDevices) DeviceRequests() []string {
143+
var devices []string
144+
i := (image.CUDA)(d)
145+
for _, channelID := range i.ImexChannelRequests() {
146+
devices = append(devices, "mode=imex,id="+channelID)
147+
}
148+
return devices
149+
}
150+
138151
// filterAutomaticDevices searches for "automatic" device names in the input slice.
139152
// "Automatic" devices are a well-defined list of CDI device names which, when requested,
140153
// trigger the generation of a CDI spec at runtime. This removes the need to generate a
@@ -155,17 +168,21 @@ func newAutomaticCDISpecModifier(logger logger.Interface, cfg *config.Config, de
155168

156169
perModeIdentifiers := make(map[string][]string)
157170
perModeDeviceClass := map[string]string{"auto": automaticDeviceClass}
158-
modes := []string{"auto"}
171+
uniqueModes := []string{"auto"}
172+
seen := make(map[string]bool)
159173
for _, device := range devices {
160-
if strings.HasPrefix(device, "mode=") {
161-
modes = append(modes, strings.TrimPrefix(device, "mode="))
162-
continue
174+
mode, id := getModeIdentifier(device)
// uniqueModes is seeded with "auto" before this loop, so only record modes
// that are neither "auto" nor already seen; otherwise "auto" would be listed
// (and later processed) twice.
if mode != "auto" && !seen[mode] {
	uniqueModes = append(uniqueModes, mode)
	seen[mode] = true
}
// Group the identifiers by their mode. Keying the map by the identifier
// itself would leave every mode without any identifiers.
if id != "" {
	perModeIdentifiers[mode] = append(perModeIdentifiers[mode], id)
}
164-
perModeIdentifiers["auto"] = append(perModeIdentifiers["auto"], strings.TrimPrefix(device, automaticDevicePrefix))
165182
}
166183

167184
var modifiers oci.SpecModifiers
168-
for _, mode := range modes {
185+
for _, mode := range uniqueModes {
169186
cdilib, err := nvcdi.New(
170187
nvcdi.WithLogger(logger),
171188
nvcdi.WithNVIDIACDIHookPath(cfg.NVIDIACTKConfig.Path),
@@ -197,6 +214,18 @@ func newAutomaticCDISpecModifier(logger logger.Interface, cfg *config.Config, de
197214
return modifiers, nil
198215
}
199216

217+
func getModeIdentifier(device string) (string, string) {
218+
if !strings.HasPrefix(device, "mode=") {
219+
return "auto", strings.TrimPrefix(device, automaticDevicePrefix)
220+
}
221+
parts := strings.SplitN(device, ",", 2)
222+
mode := strings.TrimPrefix(parts[0], "mode=")
223+
if len(parts) == 2 {
224+
return mode, strings.TrimPrefix(parts[1], "id=")
225+
}
226+
return mode, ""
227+
}
228+
200229
type deduplicatedDeviceRequestor struct {
201230
deviceRequestor
202231
}

internal/runtime/runtime_factory.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@ func initRuntimeModeAndImage(logger logger.Interface, cfg *config.Config, ociSpe
132132
image.WithAcceptDeviceListAsVolumeMounts(cfg.AcceptDeviceListAsVolumeMounts),
133133
image.WithAcceptEnvvarUnprivileged(cfg.AcceptEnvvarUnprivileged),
134134
image.WithAnnotationsPrefixes(cfg.NVIDIAContainerRuntimeConfig.Modes.CDI.AnnotationPrefixes),
135+
image.WithIgnoreImexChannelRequests(cfg.Features.IgnoreImexChannelRequests.IsEnabled()),
135136
)
136137
if err != nil {
137138
return "", nil, err

0 commit comments

Comments
 (0)