Skip to content

Commit b5bea48

Browse files
Unify GetDevices logic at internal/config/image
Signed-off-by: Carlos Eduardo Arango Gutierrez <[email protected]>
1 parent b934c68 commit b5bea48

File tree

9 files changed

+231
-175
lines changed

9 files changed

+231
-175
lines changed

cmd/nvidia-container-runtime-hook/container_config.go

Lines changed: 45 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,9 @@ type HookState struct {
8181
BundlePath string `json:"bundlePath"`
8282
}
8383

84-
func loadSpec(path string) (spec *Spec) {
84+
func loadSpec(path string) *specs.Spec {
85+
var spec Spec
86+
8587
f, err := os.Open(path)
8688
if err != nil {
8789
log.Panicln("could not open OCI spec:", err)
@@ -100,85 +102,57 @@ func loadSpec(path string) (spec *Spec) {
100102
if spec.Root == nil {
101103
log.Panicln("Root is empty in OCI spec")
102104
}
103-
return
104-
}
105105

106-
func isPrivileged(s *Spec) bool {
107-
if s.Process.Capabilities == nil {
108-
return false
106+
process := specs.Process{
107+
Env: spec.Process.Env,
109108
}
110-
111109
var caps []string
112110
// If v1.0.0-rc1 <= OCI version < v1.0.0-rc5 parse s.Process.Capabilities as:
113111
// github.com/opencontainers/runtime-spec/blob/v1.0.0-rc1/specs-go/config.go#L30-L54
114-
rc1cmp := semver.Compare("v"+*s.Version, "v1.0.0-rc1")
115-
rc5cmp := semver.Compare("v"+*s.Version, "v1.0.0-rc5")
112+
rc1cmp := semver.Compare("v"+*spec.Version, "v1.0.0-rc1")
113+
rc5cmp := semver.Compare("v"+*spec.Version, "v1.0.0-rc5")
116114
if (rc1cmp == 1 || rc1cmp == 0) && (rc5cmp == -1) {
117-
err := json.Unmarshal(*s.Process.Capabilities, &caps)
115+
err := json.Unmarshal(*spec.Process.Capabilities, &caps)
118116
if err != nil {
119117
log.Panicln("could not decode Process.Capabilities in OCI spec:", err)
120118
}
121119
for _, c := range caps {
122120
if c == capSysAdmin {
123-
return true
121+
process.Capabilities = &specs.LinuxCapabilities{
122+
Bounding: caps,
123+
}
124+
break
124125
}
125126
}
126-
return false
127-
}
128-
129-
// Otherwise, parse s.Process.Capabilities as:
130-
// github.com/opencontainers/runtime-spec/blob/v1.0.0/specs-go/config.go#L30-L54
131-
process := specs.Process{
132-
Env: s.Process.Env,
133-
}
134-
135-
err := json.Unmarshal(*s.Process.Capabilities, &process.Capabilities)
136-
if err != nil {
137-
log.Panicln("could not decode Process.Capabilities in OCI spec:", err)
138-
}
139-
140-
fullSpec := specs.Spec{
141-
Version: *s.Version,
142-
Process: &process,
143-
}
144-
145-
return image.IsPrivileged(&fullSpec)
146-
}
147-
148-
func getDevicesFromEnvvar(containerImage image.CUDA, swarmResourceEnvvars []string) []string {
149-
// We check if the image has at least one of the Swarm resource envvars defined and use this
150-
// if specified.
151-
for _, envvar := range swarmResourceEnvvars {
152-
if containerImage.HasEnvvar(envvar) {
153-
return containerImage.DevicesFromEnvvars(swarmResourceEnvvars...).List()
127+
} else {
128+
// Otherwise, parse s.Process.Capabilities as:
129+
// github.com/opencontainers/runtime-spec/blob/v1.0.0/specs-go/config.go#L30-L54
130+
err := json.Unmarshal(*spec.Process.Capabilities, &process.Capabilities)
131+
if err != nil {
132+
log.Panicln("could not decode Process.Capabilities in OCI spec:", err)
154133
}
155134
}
156135

157-
return containerImage.VisibleDevicesFromEnvVar()
158-
}
136+
root := specs.Root{
137+
Path: spec.Root.Path,
138+
}
159139

160-
func (hookConfig *hookConfig) getDevices(image image.CUDA, privileged bool) []string {
161-
// If enabled, try and get the device list from volume mounts first
162-
if hookConfig.AcceptDeviceListAsVolumeMounts {
163-
devices := image.VisibleDevicesFromMounts()
164-
if len(devices) > 0 {
165-
return devices
140+
mounts := make([]specs.Mount, len(spec.Mounts))
141+
for i, m := range spec.Mounts {
142+
mounts[i] = specs.Mount{
143+
Source: m.Source,
144+
Destination: m.Destination,
145+
Type: m.Type,
146+
Options: m.Options,
166147
}
167148
}
168149

169-
// Fallback to reading from the environment variable if privileges are correct
170-
devices := getDevicesFromEnvvar(image, hookConfig.getSwarmResourceEnvvars())
171-
if len(devices) == 0 {
172-
return nil
173-
}
174-
if privileged || hookConfig.AcceptEnvvarUnprivileged {
175-
return devices
150+
return &specs.Spec{
151+
Version: *spec.Version,
152+
Process: &process,
153+
Root: &root,
154+
Mounts: mounts,
176155
}
177-
178-
configName := hookConfig.getConfigOption("AcceptEnvvarUnprivileged")
179-
log.Printf("Ignoring devices specified in NVIDIA_VISIBLE_DEVICES (privileged=%v, %v=%v) ", privileged, configName, hookConfig.AcceptEnvvarUnprivileged)
180-
181-
return nil
182156
}
183157

184158
func getMigConfigDevices(i image.CUDA) *string {
@@ -225,7 +199,6 @@ func (hookConfig *hookConfig) getDriverCapabilities(cudaImage image.CUDA, legacy
225199
// We use the default driver capabilities by default. This is filtered to only include the
226200
// supported capabilities
227201
supportedDriverCapabilities := image.NewDriverCapabilities(hookConfig.SupportedDriverCapabilities)
228-
229202
capabilities := supportedDriverCapabilities.Intersection(image.DefaultDriverCapabilities)
230203

231204
capsEnvSpecified := cudaImage.HasEnvvar(image.EnvVarNvidiaDriverCapabilities)
@@ -251,7 +224,7 @@ func (hookConfig *hookConfig) getDriverCapabilities(cudaImage image.CUDA, legacy
251224
func (hookConfig *hookConfig) getNvidiaConfig(image image.CUDA, privileged bool) *nvidiaConfig {
252225
legacyImage := image.IsLegacy()
253226

254-
devices := hookConfig.getDevices(image, privileged)
227+
devices := image.GetDevices(hookConfig.AcceptDeviceListAsVolumeMounts, hookConfig.AcceptEnvvarUnprivileged)
255228
if len(devices) == 0 {
256229
// empty devices means this is not a GPU container.
257230
return nil
@@ -305,21 +278,26 @@ func (hookConfig *hookConfig) getContainerConfig() (config containerConfig) {
305278
}
306279

307280
s := loadSpec(path.Join(b, "config.json"))
308-
309-
image, err := image.New(
281+
opts := []image.Option{
310282
image.WithEnv(s.Process.Env),
311283
image.WithMounts(s.Mounts),
284+
image.WithSpec(s),
312285
image.WithDisableRequire(hookConfig.DisableRequire),
313-
)
286+
}
287+
288+
if len(hookConfig.getSwarmResourceEnvvars()) > 0 {
289+
opts = append(opts, image.WithSwarmResource(hookConfig.getSwarmResourceEnvvars()...))
290+
}
291+
292+
i, err := image.New(opts...)
314293
if err != nil {
315294
log.Panicln(err)
316295
}
317296

318-
privileged := isPrivileged(s)
319297
return containerConfig{
320298
Pid: h.Pid,
321299
Rootfs: s.Root.Path,
322-
Image: image,
323-
Nvidia: hookConfig.getNvidiaConfig(image, privileged),
300+
Image: i,
301+
Nvidia: hookConfig.getNvidiaConfig(i, i.IsPrivileged()),
324302
}
325303
}

cmd/nvidia-container-runtime-hook/container_config_test.go

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -477,9 +477,18 @@ func TestGetNvidiaConfig(t *testing.T) {
477477
}
478478
for _, tc := range tests {
479479
t.Run(tc.description, func(t *testing.T) {
480-
image, _ := image.New(
480+
opts := []image.Option{
481481
image.WithEnvMap(tc.env),
482-
)
482+
image.WithPrivileged(tc.privileged),
483+
}
484+
485+
if tc.hookConfig != nil {
486+
if tc.hookConfig.SwarmResource != "" {
487+
opts = append(opts, image.WithSwarmResource(tc.hookConfig.SwarmResource))
488+
}
489+
}
490+
image, _ := image.New(opts...)
491+
483492
// Wrap the call to getNvidiaConfig() in a closure.
484493
var cfg *nvidiaConfig
485494
getConfig := func() {
@@ -622,12 +631,13 @@ func TestDeviceListSourcePriority(t *testing.T) {
622631
},
623632
),
624633
image.WithMounts(tc.mountDevices),
634+
image.WithPrivileged(tc.privileged),
625635
)
626636
defaultConfig, _ := config.GetDefault()
627637
cfg := &hookConfig{defaultConfig}
628638
cfg.AcceptEnvvarUnprivileged = tc.acceptUnprivileged
629639
cfg.AcceptDeviceListAsVolumeMounts = tc.acceptMounts
630-
devices = cfg.getDevices(image, tc.privileged)
640+
devices = image.GetDevices(tc.acceptMounts, tc.acceptUnprivileged)
631641
}
632642

633643
// For all other tests, just grab the devices and check the results
@@ -843,10 +853,18 @@ func TestGetDevicesFromEnvvar(t *testing.T) {
843853

844854
for _, tc := range tests {
845855
t.Run(tc.description, func(t *testing.T) {
846-
image, _ := image.New(
856+
opts := []image.Option{
847857
image.WithEnvMap(tc.env),
848-
)
849-
devices := getDevicesFromEnvvar(image, tc.swarmResourceEnvvars)
858+
image.WithPrivileged(true),
859+
}
860+
861+
if len(tc.swarmResourceEnvvars) > 0 {
862+
opts = append(opts, image.WithSwarmResource(tc.swarmResourceEnvvars...))
863+
}
864+
865+
image, _ := image.New(opts...)
866+
867+
devices := image.GetDevices(false, false)
850868
require.EqualValues(t, tc.expectedDevices, devices)
851869
})
852870
}

cmd/nvidia-container-runtime-hook/hook_test.go

Lines changed: 0 additions & 89 deletions
This file was deleted.

0 commit comments

Comments
 (0)