Skip to content

Commit 40dd3f8

Browse files
Unify GetDevices logic at internal/config/image
Signed-off-by: Carlos Eduardo Arango Gutierrez <[email protected]> Signed-off-by: Evan Lezar <[email protected]>
1 parent d59fd3d commit 40dd3f8

File tree

9 files changed

+230
-107
lines changed

9 files changed

+230
-107
lines changed

cmd/nvidia-container-runtime-hook/container_config.go

Lines changed: 22 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,6 @@ import (
1313
"github.com/NVIDIA/nvidia-container-toolkit/internal/config/image"
1414
)
1515

16-
const (
17-
capSysAdmin = "CAP_SYS_ADMIN"
18-
)
19-
2016
type nvidiaConfig struct {
2117
Devices []string
2218
MigConfigDevices string
@@ -103,9 +99,9 @@ func loadSpec(path string) (spec *Spec) {
10399
return
104100
}
105101

106-
func isPrivileged(s *Spec) bool {
107-
if s.Process.Capabilities == nil {
108-
return false
102+
func (s *Spec) GetCapabilities() []string {
103+
if s == nil || s.Process == nil || s.Process.Capabilities == nil {
104+
return nil
109105
}
110106

111107
var caps []string
@@ -118,67 +114,22 @@ func isPrivileged(s *Spec) bool {
118114
if err != nil {
119115
log.Panicln("could not decode Process.Capabilities in OCI spec:", err)
120116
}
121-
for _, c := range caps {
122-
if c == capSysAdmin {
123-
return true
124-
}
125-
}
126-
return false
117+
return caps
127118
}
128119

129120
// Otherwise, parse s.Process.Capabilities as:
130121
// github.com/opencontainers/runtime-spec/blob/v1.0.0/specs-go/config.go#L30-L54
131-
process := specs.Process{
132-
Env: s.Process.Env,
133-
}
134-
135-
err := json.Unmarshal(*s.Process.Capabilities, &process.Capabilities)
122+
capabilities := specs.LinuxCapabilities{}
123+
err := json.Unmarshal(*s.Process.Capabilities, &capabilities)
136124
if err != nil {
137125
log.Panicln("could not decode Process.Capabilities in OCI spec:", err)
138126
}
139127

140-
fullSpec := specs.Spec{
141-
Version: *s.Version,
142-
Process: &process,
143-
}
144-
145-
return image.IsPrivileged(&fullSpec)
128+
return image.OCISpecCapabilities(capabilities).GetCapabilities()
146129
}
147130

148-
func getDevicesFromEnvvar(containerImage image.CUDA, swarmResourceEnvvars []string) []string {
149-
// We check if the image has at least one of the Swarm resource envvars defined and use this
150-
// if specified.
151-
for _, envvar := range swarmResourceEnvvars {
152-
if containerImage.HasEnvvar(envvar) {
153-
return containerImage.DevicesFromEnvvars(swarmResourceEnvvars...).List()
154-
}
155-
}
156-
157-
return containerImage.VisibleDevicesFromEnvVar()
158-
}
159-
160-
func (hookConfig *hookConfig) getDevices(image image.CUDA, privileged bool) []string {
161-
// If enabled, try and get the device list from volume mounts first
162-
if hookConfig.AcceptDeviceListAsVolumeMounts {
163-
devices := image.VisibleDevicesFromMounts()
164-
if len(devices) > 0 {
165-
return devices
166-
}
167-
}
168-
169-
// Fallback to reading from the environment variable if privileges are correct
170-
devices := getDevicesFromEnvvar(image, hookConfig.getSwarmResourceEnvvars())
171-
if len(devices) == 0 {
172-
return nil
173-
}
174-
if privileged || hookConfig.AcceptEnvvarUnprivileged {
175-
return devices
176-
}
177-
178-
configName := hookConfig.getConfigOption("AcceptEnvvarUnprivileged")
179-
log.Printf("Ignoring devices specified in NVIDIA_VISIBLE_DEVICES (privileged=%v, %v=%v) ", privileged, configName, hookConfig.AcceptEnvvarUnprivileged)
180-
181-
return nil
131+
func isPrivileged(s *Spec) bool {
132+
return image.IsPrivileged(s)
182133
}
183134

184135
func getMigConfigDevices(i image.CUDA) *string {
@@ -225,7 +176,6 @@ func (hookConfig *hookConfig) getDriverCapabilities(cudaImage image.CUDA, legacy
225176
// We use the default driver capabilities by default. This is filtered to only include the
226177
// supported capabilities
227178
supportedDriverCapabilities := image.NewDriverCapabilities(hookConfig.SupportedDriverCapabilities)
228-
229179
capabilities := supportedDriverCapabilities.Intersection(image.DefaultDriverCapabilities)
230180

231181
capsEnvSpecified := cudaImage.HasEnvvar(image.EnvVarNvidiaDriverCapabilities)
@@ -251,7 +201,7 @@ func (hookConfig *hookConfig) getDriverCapabilities(cudaImage image.CUDA, legacy
251201
func (hookConfig *hookConfig) getNvidiaConfig(image image.CUDA, privileged bool) *nvidiaConfig {
252202
legacyImage := image.IsLegacy()
253203

254-
devices := hookConfig.getDevices(image, privileged)
204+
devices := image.VisibleDevices()
255205
if len(devices) == 0 {
256206
// empty devices means this is not a GPU container.
257207
return nil
@@ -306,20 +256,27 @@ func (hookConfig *hookConfig) getContainerConfig() (config containerConfig) {
306256

307257
s := loadSpec(path.Join(b, "config.json"))
308258

309-
image, err := image.New(
259+
privileged := isPrivileged(s)
260+
261+
opts := []image.Option{
310262
image.WithEnv(s.Process.Env),
311263
image.WithMounts(s.Mounts),
264+
image.WithPrivileged(privileged),
312265
image.WithDisableRequire(hookConfig.DisableRequire),
313-
)
266+
image.WithAcceptDeviceListAsVolumeMounts(hookConfig.AcceptDeviceListAsVolumeMounts),
267+
image.WithAcceptEnvvarUnprivileged(hookConfig.AcceptEnvvarUnprivileged),
268+
image.WithAdditionalVisibleDevicesEnvVars(hookConfig.getSwarmResourceEnvvars()...),
269+
}
270+
271+
i, err := image.New(opts...)
314272
if err != nil {
315273
log.Panicln(err)
316274
}
317275

318-
privileged := isPrivileged(s)
319276
return containerConfig{
320277
Pid: h.Pid,
321278
Rootfs: s.Root.Path,
322-
Image: image,
323-
Nvidia: hookConfig.getNvidiaConfig(image, privileged),
279+
Image: i,
280+
Nvidia: hookConfig.getNvidiaConfig(i, privileged),
324281
}
325282
}

cmd/nvidia-container-runtime-hook/container_config_test.go

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -477,9 +477,19 @@ func TestGetNvidiaConfig(t *testing.T) {
477477
}
478478
for _, tc := range tests {
479479
t.Run(tc.description, func(t *testing.T) {
480-
image, _ := image.New(
480+
opts := []image.Option{
481481
image.WithEnvMap(tc.env),
482-
)
482+
image.WithPrivileged(tc.privileged),
483+
image.WithAcceptEnvvarUnprivileged(true),
484+
}
485+
486+
if tc.hookConfig != nil {
487+
if tc.hookConfig.SwarmResource != "" {
488+
opts = append(opts, image.WithAdditionalVisibleDevicesEnvVars(tc.hookConfig.SwarmResource))
489+
}
490+
}
491+
image, _ := image.New(opts...)
492+
483493
// Wrap the call to getNvidiaConfig() in a closure.
484494
var cfg *nvidiaConfig
485495
getConfig := func() {
@@ -622,12 +632,11 @@ func TestDeviceListSourcePriority(t *testing.T) {
622632
},
623633
),
624634
image.WithMounts(tc.mountDevices),
635+
image.WithPrivileged(tc.privileged),
636+
image.WithAcceptDeviceListAsVolumeMounts(tc.acceptMounts),
637+
image.WithAcceptEnvvarUnprivileged(tc.acceptUnprivileged),
625638
)
626-
defaultConfig, _ := config.GetDefault()
627-
cfg := &hookConfig{defaultConfig}
628-
cfg.AcceptEnvvarUnprivileged = tc.acceptUnprivileged
629-
cfg.AcceptDeviceListAsVolumeMounts = tc.acceptMounts
630-
devices = cfg.getDevices(image, tc.privileged)
639+
devices = image.VisibleDevices()
631640
}
632641

633642
// For all other tests, just grab the devices and check the results
@@ -843,10 +852,17 @@ func TestGetDevicesFromEnvvar(t *testing.T) {
843852

844853
for _, tc := range tests {
845854
t.Run(tc.description, func(t *testing.T) {
846-
image, _ := image.New(
855+
opts := []image.Option{
847856
image.WithEnvMap(tc.env),
848-
)
849-
devices := getDevicesFromEnvvar(image, tc.swarmResourceEnvvars)
857+
image.WithPrivileged(true),
858+
image.WithAcceptDeviceListAsVolumeMounts(false),
859+
image.WithAcceptEnvvarUnprivileged(false),
860+
image.WithAdditionalVisibleDevicesEnvVars(tc.swarmResourceEnvvars...),
861+
}
862+
863+
image, _ := image.New(opts...)
864+
865+
devices := image.VisibleDevices()
850866
require.EqualValues(t, tc.expectedDevices, devices)
851867
})
852868
}

internal/config/image/builder.go

Lines changed: 48 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,14 @@ import (
2424
)
2525

2626
type builder struct {
27-
env map[string]string
28-
mounts []specs.Mount
27+
CUDA
28+
2929
disableRequire bool
3030
}
3131

32+
// Option is a functional option for creating a CUDA image.
33+
type Option func(*builder) error
34+
3235
// New creates a new CUDA image from the input options.
3336
func New(opt ...Option) (CUDA, error) {
3437
b := &builder{}
@@ -50,15 +53,22 @@ func (b builder) build() (CUDA, error) {
5053
b.env[EnvVarNvidiaDisableRequire] = "true"
5154
}
5255

53-
c := CUDA{
54-
env: b.env,
55-
mounts: b.mounts,
56+
return b.CUDA, nil
57+
}
58+
59+
func WithAcceptDeviceListAsVolumeMounts(acceptDeviceListAsVolumeMounts bool) Option {
60+
return func(b *builder) error {
61+
b.acceptDeviceListAsVolumeMounts = acceptDeviceListAsVolumeMounts
62+
return nil
5663
}
57-
return c, nil
5864
}
5965

60-
// Option is a functional option for creating a CUDA image.
61-
type Option func(*builder) error
66+
func WithAcceptEnvvarUnprivileged(acceptEnvvarUnprivileged bool) Option {
67+
return func(b *builder) error {
68+
b.acceptEnvvarUnprivileged = acceptEnvvarUnprivileged
69+
return nil
70+
}
71+
}
6272

6373
// WithDisableRequire sets the disable require option.
6474
func WithDisableRequire(disableRequire bool) Option {
@@ -100,3 +110,33 @@ func WithMounts(mounts []specs.Mount) Option {
100110
return nil
101111
}
102112
}
113+
114+
// WithPrivileged sets whether an image is privileged or not.
115+
func WithPrivileged(isPrivileged bool) Option {
116+
return func(b *builder) error {
117+
b.isPrivileged = isPrivileged
118+
return nil
119+
}
120+
}
121+
122+
// WithAdditionalVisibleDevicesEnvVars adds to the additional visible devices environment variables.
123+
// If a single argument containing commas is supplied, it is split on commas and the
124+
// trimmed, non-empty entries are added to the additional visible devices environment variables.
125+
func WithAdditionalVisibleDevicesEnvVars(visibleDevicesEnvVars ...string) Option {
126+
return func(b *builder) error {
127+
// if a single comma-separated string was supplied, split it on commas
128+
if len(visibleDevicesEnvVars) == 1 && strings.Contains(visibleDevicesEnvVars[0], ",") {
129+
candidates := strings.Split(visibleDevicesEnvVars[0], ",")
130+
for _, c := range candidates {
131+
trimmed := strings.TrimSpace(c)
132+
if len(trimmed) > 0 {
133+
b.additionalVisibleDevicesEnvVars = append(b.additionalVisibleDevicesEnvVars, trimmed)
134+
}
135+
}
136+
return nil
137+
}
138+
139+
b.additionalVisibleDevicesEnvVars = append(b.additionalVisibleDevicesEnvVars, visibleDevicesEnvVars...)
140+
return nil
141+
}
142+
}

internal/config/image/cuda_image.go

Lines changed: 44 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,15 @@ const (
3838
// a map of environment variable to values that can be used to perform lookups
3939
// such as requirements.
4040
type CUDA struct {
41-
env map[string]string
41+
additionalVisibleDevicesEnvVars []string
42+
43+
env map[string]string
44+
4245
mounts []specs.Mount
46+
47+
acceptDeviceListAsVolumeMounts bool
48+
acceptEnvvarUnprivileged bool
49+
isPrivileged bool
4350
}
4451

4552
// NewCUDAImageFromSpec creates a CUDA image from the input OCI runtime spec.
@@ -53,12 +60,13 @@ func NewCUDAImageFromSpec(spec *specs.Spec) (CUDA, error) {
5360
return New(
5461
WithEnv(env),
5562
WithMounts(spec.Mounts),
63+
WithPrivileged(IsPrivileged((*OCISpec)(spec))),
5664
)
5765
}
5866

59-
// NewCUDAImageFromEnv creates a CUDA image from the input environment. The environment
67+
// newCUDAImageFromEnv creates a CUDA image from the input environment. The environment
6068
// is a list of strings of the form ENVAR=VALUE.
61-
func NewCUDAImageFromEnv(env []string) (CUDA, error) {
69+
func newCUDAImageFromEnv(env []string) (CUDA, error) {
6270
return New(WithEnv(env))
6371
}
6472

@@ -155,7 +163,7 @@ func (i CUDA) DevicesFromEnvvars(envVars ...string) VisibleDevices {
155163

156164
// GetDriverCapabilities returns the requested driver capabilities.
157165
func (i CUDA) GetDriverCapabilities() DriverCapabilities {
158-
env := i.env[EnvVarNvidiaDriverCapabilities]
166+
env := i.Getenv(EnvVarNvidiaDriverCapabilities)
159167

160168
capabilities := make(DriverCapabilities)
161169
for _, c := range strings.Split(env, ",") {
@@ -166,7 +174,7 @@ func (i CUDA) GetDriverCapabilities() DriverCapabilities {
166174
}
167175

168176
func (i CUDA) legacyVersion() (string, error) {
169-
cudaVersion := i.env[EnvVarCudaVersion]
177+
cudaVersion := i.Getenv(EnvVarCudaVersion)
170178
majorMinor, err := parseMajorMinorVersion(cudaVersion)
171179
if err != nil {
172180
return "", fmt.Errorf("invalid CUDA version %v: %v", cudaVersion, err)
@@ -217,13 +225,19 @@ func (i CUDA) OnlyFullyQualifiedCDIDevices() bool {
217225
}
218226

219227
// VisibleDevicesFromEnvVar returns the set of visible devices requested through
220-
// the NVIDIA_VISIBLE_DEVICES environment variable.
228+
// the variables in additionalVisibleDevicesEnvVars when the image has any of them, or
229+
// through the NVIDIA_VISIBLE_DEVICES environment variable otherwise.
221230
func (i CUDA) VisibleDevicesFromEnvVar() []string {
231+
for _, envVar := range i.additionalVisibleDevicesEnvVars {
232+
if i.HasEnvvar(envVar) {
233+
return i.DevicesFromEnvvars(i.additionalVisibleDevicesEnvVars...).List()
234+
}
235+
}
222236
return i.DevicesFromEnvvars(EnvVarNvidiaVisibleDevices).List()
223237
}
224238

225-
// VisibleDevicesFromMounts returns the set of visible devices requested as mounts.
226-
func (i CUDA) VisibleDevicesFromMounts() []string {
239+
// visibleDevicesFromMounts returns the set of visible devices requested as mounts.
240+
func (i CUDA) visibleDevicesFromMounts() []string {
227241
var devices []string
228242
for _, device := range i.DevicesFromMounts() {
229243
switch {
@@ -238,7 +252,6 @@ func (i CUDA) VisibleDevicesFromMounts() []string {
238252
}
239253

240254
// DevicesFromMounts returns a list of device specified as mounts.
241-
// TODO: This should be merged with getDevicesFromMounts used in the NVIDIA Container Runtime
242255
func (i CUDA) DevicesFromMounts() []string {
243256
root := filepath.Clean(DeviceListAsVolumeMountsRoot)
244257
seen := make(map[string]bool)
@@ -271,6 +284,28 @@ func (i CUDA) DevicesFromMounts() []string {
271284
return devices
272285
}
273286

287+
func (i CUDA) VisibleDevices() []string {
288+
// If enabled, try and get the device list from volume mounts first
289+
if i.acceptDeviceListAsVolumeMounts {
290+
devices := i.visibleDevicesFromMounts()
291+
if len(devices) > 0 {
292+
return devices
293+
}
294+
}
295+
296+
// Fall back to the environment variables; these are only honoured if the image is privileged or acceptEnvvarUnprivileged is set
297+
devices := i.VisibleDevicesFromEnvVar()
298+
if len(devices) == 0 {
299+
return nil
300+
}
301+
302+
if i.isPrivileged || i.acceptEnvvarUnprivileged {
303+
return devices
304+
}
305+
306+
return nil
307+
}
308+
274309
// CDIDevicesFromMounts returns a list of CDI devices specified as mounts on the image.
275310
func (i CUDA) CDIDevicesFromMounts() []string {
276311
var devices []string

0 commit comments

Comments
 (0)