Skip to content

Commit 2968ef1

Browse files
Unify GetDevices logic at internal/config/image
Signed-off-by: Carlos Eduardo Arango Gutierrez <[email protected]> Signed-off-by: Evan Lezar <[email protected]>
1 parent d59fd3d commit 2968ef1

File tree

10 files changed

+298
-107
lines changed

10 files changed

+298
-107
lines changed

cmd/nvidia-container-runtime-hook/container_config.go

Lines changed: 22 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,6 @@ import (
1313
"github.com/NVIDIA/nvidia-container-toolkit/internal/config/image"
1414
)
1515

16-
const (
17-
capSysAdmin = "CAP_SYS_ADMIN"
18-
)
19-
2016
type nvidiaConfig struct {
2117
Devices []string
2218
MigConfigDevices string
@@ -103,9 +99,9 @@ func loadSpec(path string) (spec *Spec) {
10399
return
104100
}
105101

106-
func isPrivileged(s *Spec) bool {
107-
if s.Process.Capabilities == nil {
108-
return false
102+
func (s *Spec) GetCapabilities() []string {
103+
if s == nil || s.Process == nil || s.Process.Capabilities == nil {
104+
return nil
109105
}
110106

111107
var caps []string
@@ -118,67 +114,22 @@ func isPrivileged(s *Spec) bool {
118114
if err != nil {
119115
log.Panicln("could not decode Process.Capabilities in OCI spec:", err)
120116
}
121-
for _, c := range caps {
122-
if c == capSysAdmin {
123-
return true
124-
}
125-
}
126-
return false
117+
return caps
127118
}
128119

129120
// Otherwise, parse s.Process.Capabilities as:
130121
// github.com/opencontainers/runtime-spec/blob/v1.0.0/specs-go/config.go#L30-L54
131-
process := specs.Process{
132-
Env: s.Process.Env,
133-
}
134-
135-
err := json.Unmarshal(*s.Process.Capabilities, &process.Capabilities)
122+
capabilities := specs.LinuxCapabilities{}
123+
err := json.Unmarshal(*s.Process.Capabilities, &capabilities)
136124
if err != nil {
137125
log.Panicln("could not decode Process.Capabilities in OCI spec:", err)
138126
}
139127

140-
fullSpec := specs.Spec{
141-
Version: *s.Version,
142-
Process: &process,
143-
}
144-
145-
return image.IsPrivileged(&fullSpec)
128+
return image.OCISpecCapabilities(capabilities).GetCapabilities()
146129
}
147130

148-
func getDevicesFromEnvvar(containerImage image.CUDA, swarmResourceEnvvars []string) []string {
149-
// We check if the image has at least one of the Swarm resource envvars defined and use this
150-
// if specified.
151-
for _, envvar := range swarmResourceEnvvars {
152-
if containerImage.HasEnvvar(envvar) {
153-
return containerImage.DevicesFromEnvvars(swarmResourceEnvvars...).List()
154-
}
155-
}
156-
157-
return containerImage.VisibleDevicesFromEnvVar()
158-
}
159-
160-
func (hookConfig *hookConfig) getDevices(image image.CUDA, privileged bool) []string {
161-
// If enabled, try and get the device list from volume mounts first
162-
if hookConfig.AcceptDeviceListAsVolumeMounts {
163-
devices := image.VisibleDevicesFromMounts()
164-
if len(devices) > 0 {
165-
return devices
166-
}
167-
}
168-
169-
// Fallback to reading from the environment variable if privileges are correct
170-
devices := getDevicesFromEnvvar(image, hookConfig.getSwarmResourceEnvvars())
171-
if len(devices) == 0 {
172-
return nil
173-
}
174-
if privileged || hookConfig.AcceptEnvvarUnprivileged {
175-
return devices
176-
}
177-
178-
configName := hookConfig.getConfigOption("AcceptEnvvarUnprivileged")
179-
log.Printf("Ignoring devices specified in NVIDIA_VISIBLE_DEVICES (privileged=%v, %v=%v) ", privileged, configName, hookConfig.AcceptEnvvarUnprivileged)
180-
181-
return nil
131+
// isPrivileged reports whether the container described by the OCI spec s is
// privileged. The decision is delegated entirely to image.IsPrivileged.
// NOTE(review): image.IsPrivileged is not visible here; presumably it inspects
// the capabilities returned by s.GetCapabilities() — confirm.
func isPrivileged(s *Spec) bool {
132+
return image.IsPrivileged(s)
182133
}
183134

184135
func getMigConfigDevices(i image.CUDA) *string {
@@ -225,7 +176,6 @@ func (hookConfig *hookConfig) getDriverCapabilities(cudaImage image.CUDA, legacy
225176
// We use the default driver capabilities by default. This is filtered to only include the
226177
// supported capabilities
227178
supportedDriverCapabilities := image.NewDriverCapabilities(hookConfig.SupportedDriverCapabilities)
228-
229179
capabilities := supportedDriverCapabilities.Intersection(image.DefaultDriverCapabilities)
230180

231181
capsEnvSpecified := cudaImage.HasEnvvar(image.EnvVarNvidiaDriverCapabilities)
@@ -251,7 +201,7 @@ func (hookConfig *hookConfig) getDriverCapabilities(cudaImage image.CUDA, legacy
251201
func (hookConfig *hookConfig) getNvidiaConfig(image image.CUDA, privileged bool) *nvidiaConfig {
252202
legacyImage := image.IsLegacy()
253203

254-
devices := hookConfig.getDevices(image, privileged)
204+
devices := image.VisibleDevices()
255205
if len(devices) == 0 {
256206
// empty devices means this is not a GPU container.
257207
return nil
@@ -306,20 +256,27 @@ func (hookConfig *hookConfig) getContainerConfig() (config containerConfig) {
306256

307257
s := loadSpec(path.Join(b, "config.json"))
308258

309-
image, err := image.New(
259+
privileged := isPrivileged(s)
260+
261+
opts := []image.Option{
310262
image.WithEnv(s.Process.Env),
311263
image.WithMounts(s.Mounts),
264+
image.WithPrivileged(privileged),
312265
image.WithDisableRequire(hookConfig.DisableRequire),
313-
)
266+
image.WithAcceptDeviceListAsVolumeMounts(hookConfig.AcceptDeviceListAsVolumeMounts),
267+
image.WithAcceptEnvvarUnprivileged(hookConfig.AcceptEnvvarUnprivileged),
268+
image.WithAdditionalVisibleDevicesEnvVars(hookConfig.getSwarmResourceEnvvars()...),
269+
}
270+
271+
i, err := image.New(opts...)
314272
if err != nil {
315273
log.Panicln(err)
316274
}
317275

318-
privileged := isPrivileged(s)
319276
return containerConfig{
320277
Pid: h.Pid,
321278
Rootfs: s.Root.Path,
322-
Image: image,
323-
Nvidia: hookConfig.getNvidiaConfig(image, privileged),
279+
Image: i,
280+
Nvidia: hookConfig.getNvidiaConfig(i, privileged),
324281
}
325282
}

cmd/nvidia-container-runtime-hook/container_config_test.go

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -477,9 +477,19 @@ func TestGetNvidiaConfig(t *testing.T) {
477477
}
478478
for _, tc := range tests {
479479
t.Run(tc.description, func(t *testing.T) {
480-
image, _ := image.New(
480+
opts := []image.Option{
481481
image.WithEnvMap(tc.env),
482-
)
482+
image.WithPrivileged(tc.privileged),
483+
image.WithAcceptEnvvarUnprivileged(true),
484+
}
485+
486+
if tc.hookConfig != nil {
487+
if tc.hookConfig.SwarmResource != "" {
488+
opts = append(opts, image.WithAdditionalVisibleDevicesEnvVars(tc.hookConfig.SwarmResource))
489+
}
490+
}
491+
image, _ := image.New(opts...)
492+
483493
// Wrap the call to getNvidiaConfig() in a closure.
484494
var cfg *nvidiaConfig
485495
getConfig := func() {
@@ -622,12 +632,11 @@ func TestDeviceListSourcePriority(t *testing.T) {
622632
},
623633
),
624634
image.WithMounts(tc.mountDevices),
635+
image.WithPrivileged(tc.privileged),
636+
image.WithAcceptDeviceListAsVolumeMounts(tc.acceptMounts),
637+
image.WithAcceptEnvvarUnprivileged(tc.acceptUnprivileged),
625638
)
626-
defaultConfig, _ := config.GetDefault()
627-
cfg := &hookConfig{defaultConfig}
628-
cfg.AcceptEnvvarUnprivileged = tc.acceptUnprivileged
629-
cfg.AcceptDeviceListAsVolumeMounts = tc.acceptMounts
630-
devices = cfg.getDevices(image, tc.privileged)
639+
devices = image.VisibleDevices()
631640
}
632641

633642
// For all other tests, just grab the devices and check the results
@@ -843,10 +852,17 @@ func TestGetDevicesFromEnvvar(t *testing.T) {
843852

844853
for _, tc := range tests {
845854
t.Run(tc.description, func(t *testing.T) {
846-
image, _ := image.New(
855+
opts := []image.Option{
847856
image.WithEnvMap(tc.env),
848-
)
849-
devices := getDevicesFromEnvvar(image, tc.swarmResourceEnvvars)
857+
image.WithPrivileged(true),
858+
image.WithAcceptDeviceListAsVolumeMounts(false),
859+
image.WithAcceptEnvvarUnprivileged(false),
860+
image.WithAdditionalVisibleDevicesEnvVars(tc.swarmResourceEnvvars...),
861+
}
862+
863+
image, _ := image.New(opts...)
864+
865+
devices := image.VisibleDevices()
850866
require.EqualValues(t, tc.expectedDevices, devices)
851867
})
852868
}

internal/config/image/builder.go

Lines changed: 47 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,14 @@ import (
2424
)
2525

2626
type builder struct {
27-
env map[string]string
28-
mounts []specs.Mount
27+
CUDA
28+
2929
disableRequire bool
3030
}
3131

32+
// Option is a functional option for creating a CUDA image.
33+
type Option func(*builder) error
34+
3235
// New creates a new CUDA image from the input options.
3336
func New(opt ...Option) (CUDA, error) {
3437
b := &builder{}
@@ -50,15 +53,22 @@ func (b builder) build() (CUDA, error) {
5053
b.env[EnvVarNvidiaDisableRequire] = "true"
5154
}
5255

53-
c := CUDA{
54-
env: b.env,
55-
mounts: b.mounts,
56+
return b.CUDA, nil
57+
}
58+
59+
// WithAcceptDeviceListAsVolumeMounts returns an Option that stores the
// acceptDeviceListAsVolumeMounts flag on the image being built.
// NOTE(review): presumably this flag controls whether the device list may be
// read from volume mounts when resolving visible devices — confirm.
func WithAcceptDeviceListAsVolumeMounts(acceptDeviceListAsVolumeMounts bool) Option {
60+
return func(b *builder) error {
61+
b.acceptDeviceListAsVolumeMounts = acceptDeviceListAsVolumeMounts
62+
return nil
5663
}
57-
return c, nil
5864
}
5965

60-
// Option is a functional option for creating a CUDA image.
61-
type Option func(*builder) error
66+
func WithAcceptEnvvarUnprivileged(acceptEnvvarUnprivileged bool) Option {
67+
return func(b *builder) error {
68+
b.acceptEnvvarUnprivileged = acceptEnvvarUnprivileged
69+
return nil
70+
}
71+
}
6272

6373
// WithDisableRequire sets the disable require option.
6474
func WithDisableRequire(disableRequire bool) Option {
@@ -100,3 +110,32 @@ func WithMounts(mounts []specs.Mount) Option {
100110
return nil
101111
}
102112
}
113+
114+
// WithPrivileged sets whether an image is privileged or not.
115+
func WithPrivileged(isPrivileged bool) Option {
116+
return func(b *builder) error {
117+
b.isPrivileged = isPrivileged
118+
return nil
119+
}
120+
}
121+
122+
// WithAdditionalVisibleDevicesEnvVars sets the visible devices environment variables.
123+
// For each envvar passed:
124+
// - Split on commas
125+
// - Trim spaces and ignore empty strings
126+
// - Concatenate all of this
127+
func WithAdditionalVisibleDevicesEnvVars(visibleDevicesEnvVars ...string) Option {
128+
return func(b *builder) error {
129+
var result []string
130+
for _, v := range visibleDevicesEnvVars {
131+
for _, c := range strings.Split(v, ",") {
132+
trimmed := strings.TrimSpace(c)
133+
if trimmed != "" {
134+
result = append(result, trimmed)
135+
}
136+
}
137+
}
138+
b.additionalVisibleDevicesEnvVars = result
139+
return nil
140+
}
141+
}
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
package image
2+
3+
import (
4+
"testing"
5+
6+
"github.com/stretchr/testify/require"
7+
)
8+
9+
func TestWithAdditionalVisibleDevicesEnvVars(t *testing.T) {
10+
testCases := []struct {
11+
desc string
12+
input []string
13+
expect []string
14+
}{
15+
{
16+
desc: "single value, no comma",
17+
input: []string{"FOO"},
18+
expect: []string{"FOO"},
19+
},
20+
{
21+
desc: "single value, with spaces",
22+
input: []string{" BAR "},
23+
expect: []string{"BAR"},
24+
},
25+
{
26+
desc: "single string, comma separated",
27+
input: []string{"A,B,C"},
28+
expect: []string{"A", "B", "C"},
29+
},
30+
{
31+
desc: "single string, comma separated with spaces",
32+
input: []string{" A , B, C , "},
33+
expect: []string{"A", "B", "C"},
34+
},
35+
{
36+
desc: "multiple values, no commas",
37+
input: []string{"A", "B", "C"},
38+
expect: []string{"A", "B", "C"},
39+
},
40+
{
41+
desc: "multiple values, some empty",
42+
input: []string{"A", "", "C"},
43+
expect: []string{"A", "C"},
44+
},
45+
{
46+
desc: "multiple values, one is comma separated",
47+
input: []string{"A,B", "C"},
48+
expect: []string{"A", "B", "C"},
49+
},
50+
{
51+
desc: "single string, only spaces and commas",
52+
input: []string{" , , , "},
53+
expect: nil,
54+
},
55+
}
56+
57+
for _, tc := range testCases {
58+
t.Run(tc.desc, func(t *testing.T) {
59+
b := &builder{}
60+
err := WithAdditionalVisibleDevicesEnvVars(tc.input...)(b)
61+
require.NoError(t, err)
62+
if tc.expect == nil {
63+
require.Nil(t, b.additionalVisibleDevicesEnvVars)
64+
} else {
65+
require.Equal(t, tc.expect, b.additionalVisibleDevicesEnvVars)
66+
}
67+
})
68+
}
69+
}

0 commit comments

Comments
 (0)