Skip to content

Commit 02a486e

Browse files
Unify GetDevices logic at internal/config/image
Signed-off-by: Carlos Eduardo Arango Gutierrez <[email protected]>
Signed-off-by: Evan Lezar <[email protected]>
1 parent bb3a54f commit 02a486e

File tree

10 files changed

+267
-203
lines changed

10 files changed

+267
-203
lines changed

cmd/nvidia-container-runtime-hook/container_config.go

Lines changed: 25 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,6 @@ import (
1313
"github.com/NVIDIA/nvidia-container-toolkit/internal/config/image"
1414
)
1515

16-
const (
17-
capSysAdmin = "CAP_SYS_ADMIN"
18-
)
19-
2016
type nvidiaConfig struct {
2117
Devices []string
2218
MigConfigDevices string
@@ -103,9 +99,9 @@ func loadSpec(path string) (spec *Spec) {
10399
return
104100
}
105101

106-
func isPrivileged(s *Spec) bool {
107-
if s.Process.Capabilities == nil {
108-
return false
102+
func (s *Spec) GetCapabilities() []string {
103+
if s == nil || s.Process == nil || s.Process.Capabilities == nil {
104+
return nil
109105
}
110106

111107
var caps []string
@@ -118,67 +114,22 @@ func isPrivileged(s *Spec) bool {
118114
if err != nil {
119115
log.Panicln("could not decode Process.Capabilities in OCI spec:", err)
120116
}
121-
for _, c := range caps {
122-
if c == capSysAdmin {
123-
return true
124-
}
125-
}
126-
return false
117+
return caps
127118
}
128119

129120
// Otherwise, parse s.Process.Capabilities as:
130121
// github.com/opencontainers/runtime-spec/blob/v1.0.0/specs-go/config.go#L30-L54
131-
process := specs.Process{
132-
Env: s.Process.Env,
133-
}
134-
135-
err := json.Unmarshal(*s.Process.Capabilities, &process.Capabilities)
122+
capabilities := specs.LinuxCapabilities{}
123+
err := json.Unmarshal(*s.Process.Capabilities, &capabilities)
136124
if err != nil {
137125
log.Panicln("could not decode Process.Capabilities in OCI spec:", err)
138126
}
139127

140-
fullSpec := specs.Spec{
141-
Version: *s.Version,
142-
Process: &process,
143-
}
144-
145-
return image.IsPrivileged(&fullSpec)
146-
}
147-
148-
func getDevicesFromEnvvar(containerImage image.CUDA, swarmResourceEnvvars []string) []string {
149-
// We check if the image has at least one of the Swarm resource envvars defined and use this
150-
// if specified.
151-
for _, envvar := range swarmResourceEnvvars {
152-
if containerImage.HasEnvvar(envvar) {
153-
return containerImage.DevicesFromEnvvars(swarmResourceEnvvars...).List()
154-
}
155-
}
156-
157-
return containerImage.VisibleDevicesFromEnvVar()
128+
return image.OCISpecCapabilities(capabilities).GetCapabilities()
158129
}
159130

160-
func (hookConfig *hookConfig) getDevices(image image.CUDA, privileged bool) []string {
161-
// If enabled, try and get the device list from volume mounts first
162-
if hookConfig.AcceptDeviceListAsVolumeMounts {
163-
devices := image.VisibleDevicesFromMounts()
164-
if len(devices) > 0 {
165-
return devices
166-
}
167-
}
168-
169-
// Fallback to reading from the environment variable if privileges are correct
170-
devices := getDevicesFromEnvvar(image, hookConfig.getSwarmResourceEnvvars())
171-
if len(devices) == 0 {
172-
return nil
173-
}
174-
if privileged || hookConfig.AcceptEnvvarUnprivileged {
175-
return devices
176-
}
177-
178-
configName := hookConfig.getConfigOption("AcceptEnvvarUnprivileged")
179-
log.Printf("Ignoring devices specified in NVIDIA_VISIBLE_DEVICES (privileged=%v, %v=%v) ", privileged, configName, hookConfig.AcceptEnvvarUnprivileged)
180-
181-
return nil
131+
func isPrivileged(s *Spec) bool {
132+
return image.IsPrivileged(s)
182133
}
183134

184135
func getMigConfigDevices(i image.CUDA) *string {
@@ -225,7 +176,6 @@ func (hookConfig *hookConfig) getDriverCapabilities(cudaImage image.CUDA, legacy
225176
// We use the default driver capabilities by default. This is filtered to only include the
226177
// supported capabilities
227178
supportedDriverCapabilities := image.NewDriverCapabilities(hookConfig.SupportedDriverCapabilities)
228-
229179
capabilities := supportedDriverCapabilities.Intersection(image.DefaultDriverCapabilities)
230180

231181
capsEnvSpecified := cudaImage.HasEnvvar(image.EnvVarNvidiaDriverCapabilities)
@@ -251,7 +201,7 @@ func (hookConfig *hookConfig) getDriverCapabilities(cudaImage image.CUDA, legacy
251201
func (hookConfig *hookConfig) getNvidiaConfig(image image.CUDA, privileged bool) *nvidiaConfig {
252202
legacyImage := image.IsLegacy()
253203

254-
devices := hookConfig.getDevices(image, privileged)
204+
devices := image.VisibleDevices()
255205
if len(devices) == 0 {
256206
// empty devices means this is not a GPU container.
257207
return nil
@@ -306,20 +256,30 @@ func (hookConfig *hookConfig) getContainerConfig() (config containerConfig) {
306256

307257
s := loadSpec(path.Join(b, "config.json"))
308258

309-
image, err := image.New(
259+
privileged := isPrivileged(s)
260+
261+
opts := []image.Option{
310262
image.WithEnv(s.Process.Env),
311263
image.WithMounts(s.Mounts),
264+
image.WithPrivileged(privileged),
312265
image.WithDisableRequire(hookConfig.DisableRequire),
313-
)
266+
image.WithAcceptDeviceListAsVolumeMounts(hookConfig.AcceptDeviceListAsVolumeMounts),
267+
image.WithAcceptEnvvarUnprivileged(hookConfig.AcceptEnvvarUnprivileged),
268+
}
269+
270+
if len(hookConfig.getSwarmResourceEnvvars()) > 0 {
271+
opts = append(opts, image.WithVisibleDevicesEnvVars(hookConfig.getSwarmResourceEnvvars()...))
272+
}
273+
274+
i, err := image.New(opts...)
314275
if err != nil {
315276
log.Panicln(err)
316277
}
317278

318-
privileged := isPrivileged(s)
319279
return containerConfig{
320280
Pid: h.Pid,
321281
Rootfs: s.Root.Path,
322-
Image: image,
323-
Nvidia: hookConfig.getNvidiaConfig(image, privileged),
282+
Image: i,
283+
Nvidia: hookConfig.getNvidiaConfig(i, privileged),
324284
}
325285
}

cmd/nvidia-container-runtime-hook/container_config_test.go

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -477,9 +477,19 @@ func TestGetNvidiaConfig(t *testing.T) {
477477
}
478478
for _, tc := range tests {
479479
t.Run(tc.description, func(t *testing.T) {
480-
image, _ := image.New(
480+
opts := []image.Option{
481481
image.WithEnvMap(tc.env),
482-
)
482+
image.WithPrivileged(tc.privileged),
483+
image.WithAcceptEnvvarUnprivileged(true),
484+
}
485+
486+
if tc.hookConfig != nil {
487+
if tc.hookConfig.SwarmResource != "" {
488+
opts = append(opts, image.WithVisibleDevicesEnvVars(tc.hookConfig.SwarmResource))
489+
}
490+
}
491+
image, _ := image.New(opts...)
492+
483493
// Wrap the call to getNvidiaConfig() in a closure.
484494
var cfg *nvidiaConfig
485495
getConfig := func() {
@@ -622,12 +632,15 @@ func TestDeviceListSourcePriority(t *testing.T) {
622632
},
623633
),
624634
image.WithMounts(tc.mountDevices),
635+
image.WithPrivileged(tc.privileged),
636+
image.WithAcceptDeviceListAsVolumeMounts(tc.acceptMounts),
637+
image.WithAcceptEnvvarUnprivileged(tc.acceptUnprivileged),
625638
)
626639
defaultConfig, _ := config.GetDefault()
627640
cfg := &hookConfig{defaultConfig}
628641
cfg.AcceptEnvvarUnprivileged = tc.acceptUnprivileged
629642
cfg.AcceptDeviceListAsVolumeMounts = tc.acceptMounts
630-
devices = cfg.getDevices(image, tc.privileged)
643+
devices = image.VisibleDevices()
631644
}
632645

633646
// For all other tests, just grab the devices and check the results
@@ -843,10 +856,20 @@ func TestGetDevicesFromEnvvar(t *testing.T) {
843856

844857
for _, tc := range tests {
845858
t.Run(tc.description, func(t *testing.T) {
846-
image, _ := image.New(
859+
opts := []image.Option{
847860
image.WithEnvMap(tc.env),
848-
)
849-
devices := getDevicesFromEnvvar(image, tc.swarmResourceEnvvars)
861+
image.WithPrivileged(true),
862+
image.WithAcceptDeviceListAsVolumeMounts(false),
863+
image.WithAcceptEnvvarUnprivileged(false),
864+
}
865+
866+
if len(tc.swarmResourceEnvvars) > 0 {
867+
opts = append(opts, image.WithVisibleDevicesEnvVars(tc.swarmResourceEnvvars...))
868+
}
869+
870+
image, _ := image.New(opts...)
871+
872+
devices := image.VisibleDevices()
850873
require.EqualValues(t, tc.expectedDevices, devices)
851874
})
852875
}

cmd/nvidia-container-runtime-hook/hook_test.go

Lines changed: 0 additions & 89 deletions
This file was deleted.

internal/config/image/builder.go

Lines changed: 64 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,14 @@ import (
2424
)
2525

2626
type builder struct {
27-
env map[string]string
28-
mounts []specs.Mount
27+
CUDA
28+
2929
disableRequire bool
3030
}
3131

32+
// Option is a functional option for creating a CUDA image.
33+
type Option func(*builder) error
34+
3235
// New creates a new CUDA image from the input options.
3336
func New(opt ...Option) (CUDA, error) {
3437
b := &builder{}
@@ -50,15 +53,36 @@ func (b builder) build() (CUDA, error) {
5053
b.env[EnvVarNvidiaDisableRequire] = "true"
5154
}
5255

53-
c := CUDA{
54-
env: b.env,
55-
mounts: b.mounts,
56+
return b.CUDA, nil
57+
}
58+
59+
func WithAnnotationPrefixes(annotationPrefixes []string) Option {
60+
return func(b *builder) error {
61+
b.annotationPrefixes = annotationPrefixes
62+
return nil
5663
}
57-
return c, nil
5864
}
5965

60-
// Option is a functional option for creating a CUDA image.
61-
type Option func(*builder) error
66+
func WithAnnotations(annotations map[string]string) Option {
67+
return func(b *builder) error {
68+
b.annotations = annotations
69+
return nil
70+
}
71+
}
72+
73+
func WithAcceptDeviceListAsVolumeMounts(acceptDeviceListAsVolumeMounts bool) Option {
74+
return func(b *builder) error {
75+
b.acceptDeviceListAsVolumeMounts = acceptDeviceListAsVolumeMounts
76+
return nil
77+
}
78+
}
79+
80+
func WithAcceptEnvvarUnprivileged(acceptEnvvarUnprivileged bool) Option {
81+
return func(b *builder) error {
82+
b.acceptEnvvarUnprivileged = acceptEnvvarUnprivileged
83+
return nil
84+
}
85+
}
6286

6387
// WithDisableRequire sets the disable require option.
6488
func WithDisableRequire(disableRequire bool) Option {
@@ -100,3 +124,35 @@ func WithMounts(mounts []specs.Mount) Option {
100124
return nil
101125
}
102126
}
127+
128+
// WithPrivileged sets whether an image is privileged or not.
129+
func WithPrivileged(isPrivileged bool) Option {
130+
return func(b *builder) error {
131+
b.isPrivileged = isPrivileged
132+
return nil
133+
}
134+
}
135+
136+
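// WithVisibleDevicesEnvVars sets the names of the environment variables (e.g. the swarm resource envvars) that are checked for the list of visible devices for the CUDA image.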
137+
func WithVisibleDevicesEnvVars(visibleDevicesEnvVars ...string) Option {
138+
return func(b *builder) error {
139+
if len(visibleDevicesEnvVars) == 0 {
140+
return fmt.Errorf("visible devices env vars cannot be empty")
141+
}
142+
b.visibleDevicesEnvVars = []string{}
143+
// If a single comma-separated string is supplied, split it into the individual envvar names, trimming whitespace and dropping empty entries.
144+
if len(visibleDevicesEnvVars) == 1 && strings.Contains(visibleDevicesEnvVars[0], ",") {
145+
candidates := strings.Split(visibleDevicesEnvVars[0], ",")
146+
for _, c := range candidates {
147+
trimmed := strings.TrimSpace(c)
148+
if len(trimmed) > 0 {
149+
b.visibleDevicesEnvVars = append(b.visibleDevicesEnvVars, trimmed)
150+
}
151+
}
152+
return nil
153+
}
154+
155+
b.visibleDevicesEnvVars = append(b.visibleDevicesEnvVars, visibleDevicesEnvVars...)
156+
return nil
157+
}
158+
}

0 commit comments

Comments
 (0)