Skip to content

Commit dede03f

Browse files
Refactor extracting requested devices from the container image
This change consolidates the logic for determining requested devices from the container image. This logic has been integrated into the image.CUDA type so that multiple implementations are not required.

Signed-off-by: Carlos Eduardo Arango Gutierrez <[email protected]>
Co-authored-by: Evan Lezar <[email protected]>
1 parent fdcd250 commit dede03f

File tree

10 files changed

+652
-437
lines changed

10 files changed

+652
-437
lines changed

cmd/nvidia-container-runtime-hook/container_config.go

Lines changed: 19 additions & 64 deletions
Original file line number | Diff line number | Diff line change
@@ -13,10 +13,6 @@ import (
1313
"github.com/NVIDIA/nvidia-container-toolkit/internal/config/image"
1414
)
1515

16-
const (
17-
capSysAdmin = "CAP_SYS_ADMIN"
18-
)
19-
2016
type nvidiaConfig struct {
2117
Devices []string
2218
MigConfigDevices string
@@ -103,9 +99,9 @@ func loadSpec(path string) (spec *Spec) {
10399
return
104100
}
105101

106-
func isPrivileged(s *Spec) bool {
107-
if s.Process.Capabilities == nil {
108-
return false
102+
func (s *Spec) GetCapabilities() []string {
103+
if s == nil || s.Process == nil || s.Process.Capabilities == nil {
104+
return nil
109105
}
110106

111107
var caps []string
@@ -118,67 +114,22 @@ func isPrivileged(s *Spec) bool {
118114
if err != nil {
119115
log.Panicln("could not decode Process.Capabilities in OCI spec:", err)
120116
}
121-
for _, c := range caps {
122-
if c == capSysAdmin {
123-
return true
124-
}
125-
}
126-
return false
117+
return caps
127118
}
128119

129120
// Otherwise, parse s.Process.Capabilities as:
130121
// github.com/opencontainers/runtime-spec/blob/v1.0.0/specs-go/config.go#L30-L54
131-
process := specs.Process{
132-
Env: s.Process.Env,
133-
}
134-
135-
err := json.Unmarshal(*s.Process.Capabilities, &process.Capabilities)
122+
capabilities := specs.LinuxCapabilities{}
123+
err := json.Unmarshal(*s.Process.Capabilities, &capabilities)
136124
if err != nil {
137125
log.Panicln("could not decode Process.Capabilities in OCI spec:", err)
138126
}
139127

140-
fullSpec := specs.Spec{
141-
Version: *s.Version,
142-
Process: &process,
143-
}
144-
145-
return image.IsPrivileged(&fullSpec)
128+
return image.OCISpecCapabilities(capabilities).GetCapabilities()
146129
}
147130

148-
func getDevicesFromEnvvar(containerImage image.CUDA, swarmResourceEnvvars []string) []string {
149-
// We check if the image has at least one of the Swarm resource envvars defined and use this
150-
// if specified.
151-
for _, envvar := range swarmResourceEnvvars {
152-
if containerImage.HasEnvvar(envvar) {
153-
return containerImage.DevicesFromEnvvars(swarmResourceEnvvars...).List()
154-
}
155-
}
156-
157-
return containerImage.VisibleDevicesFromEnvVar()
158-
}
159-
160-
func (hookConfig *hookConfig) getDevices(image image.CUDA, privileged bool) []string {
161-
// If enabled, try and get the device list from volume mounts first
162-
if hookConfig.AcceptDeviceListAsVolumeMounts {
163-
devices := image.VisibleDevicesFromMounts()
164-
if len(devices) > 0 {
165-
return devices
166-
}
167-
}
168-
169-
// Fallback to reading from the environment variable if privileges are correct
170-
devices := getDevicesFromEnvvar(image, hookConfig.getSwarmResourceEnvvars())
171-
if len(devices) == 0 {
172-
return nil
173-
}
174-
if privileged || hookConfig.AcceptEnvvarUnprivileged {
175-
return devices
176-
}
177-
178-
configName := hookConfig.getConfigOption("AcceptEnvvarUnprivileged")
179-
log.Printf("Ignoring devices specified in NVIDIA_VISIBLE_DEVICES (privileged=%v, %v=%v) ", privileged, configName, hookConfig.AcceptEnvvarUnprivileged)
180-
181-
return nil
131+
func isPrivileged(s *Spec) bool {
132+
return image.IsPrivileged(s)
182133
}
183134

184135
func getMigConfigDevices(i image.CUDA) *string {
@@ -225,7 +176,6 @@ func (hookConfig *hookConfig) getDriverCapabilities(cudaImage image.CUDA, legacy
225176
// We use the default driver capabilities by default. This is filtered to only include the
226177
// supported capabilities
227178
supportedDriverCapabilities := image.NewDriverCapabilities(hookConfig.SupportedDriverCapabilities)
228-
229179
capabilities := supportedDriverCapabilities.Intersection(image.DefaultDriverCapabilities)
230180

231181
capsEnvSpecified := cudaImage.HasEnvvar(image.EnvVarNvidiaDriverCapabilities)
@@ -251,7 +201,7 @@ func (hookConfig *hookConfig) getDriverCapabilities(cudaImage image.CUDA, legacy
251201
func (hookConfig *hookConfig) getNvidiaConfig(image image.CUDA, privileged bool) *nvidiaConfig {
252202
legacyImage := image.IsLegacy()
253203

254-
devices := hookConfig.getDevices(image, privileged)
204+
devices := image.VisibleDevices()
255205
if len(devices) == 0 {
256206
// empty devices means this is not a GPU container.
257207
return nil
@@ -306,20 +256,25 @@ func (hookConfig *hookConfig) getContainerConfig() (config containerConfig) {
306256

307257
s := loadSpec(path.Join(b, "config.json"))
308258

309-
image, err := image.New(
259+
privileged := isPrivileged(s)
260+
261+
i, err := image.New(
310262
image.WithEnv(s.Process.Env),
311263
image.WithMounts(s.Mounts),
264+
image.WithPrivileged(privileged),
312265
image.WithDisableRequire(hookConfig.DisableRequire),
266+
image.WithAcceptDeviceListAsVolumeMounts(hookConfig.AcceptDeviceListAsVolumeMounts),
267+
image.WithAcceptEnvvarUnprivileged(hookConfig.AcceptEnvvarUnprivileged),
268+
image.WithPreferredVisibleDevicesEnvVars(hookConfig.getSwarmResourceEnvvars()...),
313269
)
314270
if err != nil {
315271
log.Panicln(err)
316272
}
317273

318-
privileged := isPrivileged(s)
319274
return containerConfig{
320275
Pid: h.Pid,
321276
Rootfs: s.Root.Path,
322-
Image: image,
323-
Nvidia: hookConfig.getNvidiaConfig(image, privileged),
277+
Image: i,
278+
Nvidia: hookConfig.getNvidiaConfig(i, privileged),
324279
}
325280
}

0 commit comments

Comments
 (0)