Skip to content

Commit 88c37fa

Browse files
committed
Handle --gpus flag using CDI
This change switches to using CDI to handle the --gpus flag. This removes the custom implementation that invoked the nvidia-container-cli directly; that direct-invocation mechanism does not align with existing implementations. Signed-off-by: Evan Lezar <[email protected]>
1 parent 924e283 commit 88c37fa

File tree

1 file changed

+18
-43
lines changed

1 file changed

+18
-43
lines changed

pkg/cmd/container/run_linux.go

Lines changed: 18 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ import (
2525
"github.com/opencontainers/runtime-spec/specs-go"
2626

2727
containerd "github.com/containerd/containerd/v2/client"
28-
"github.com/containerd/containerd/v2/contrib/nvidia"
2928
"github.com/containerd/containerd/v2/core/containers"
3029
"github.com/containerd/containerd/v2/pkg/oci"
3130
"github.com/containerd/log"
@@ -99,7 +98,7 @@ func setPlatformOptions(ctx context.Context, client *containerd.Client, id, uts
9998
if options.Sysctl != nil {
10099
opts = append(opts, WithSysctls(strutil.ConvertKVStringsToMap(options.Sysctl)))
101100
}
102-
gpuOpt, err := parseGPUOpts(options.GPUs)
101+
gpuOpt, err := parseGPUOpts(options.GOptions.CDISpecDirs, options.GPUs)
103102
if err != nil {
104103
return nil, err
105104
}
@@ -262,60 +261,36 @@ func withOOMScoreAdj(score int) oci.SpecOpts {
262261
}
263262
}
264263

265-
func parseGPUOpts(value []string) (res []oci.SpecOpts, _ error) {
264+
func parseGPUOpts(cdiSpecDirs []string, value []string) (res []oci.SpecOpts, _ error) {
266265
for _, gpu := range value {
267-
gpuOpt, err := parseGPUOpt(gpu)
266+
req, err := ParseGPUOptCSV(gpu)
268267
if err != nil {
269268
return nil, err
270269
}
271-
res = append(res, gpuOpt)
270+
res = append(res, withCDIDevices(cdiSpecDirs, req.toCDIDeviceIDS()...))
272271
}
273272
return res, nil
274273
}
275274

276-
func parseGPUOpt(value string) (oci.SpecOpts, error) {
277-
req, err := ParseGPUOptCSV(value)
278-
if err != nil {
279-
return nil, err
275+
func (req *GPUReq) toCDIDeviceIDS() []string {
276+
var cdiDeviceIDs []string
277+
for _, id := range req.normalizeDeviceIDs() {
278+
cdiDeviceIDs = append(cdiDeviceIDs, "nvidia.com/gpu="+id)
280279
}
280+
return cdiDeviceIDs
281+
}
281282

282-
var gpuOpts []nvidia.Opts
283-
283+
func (req *GPUReq) normalizeDeviceIDs() []string {
284284
if len(req.DeviceIDs) > 0 {
285-
gpuOpts = append(gpuOpts, nvidia.WithDeviceUUIDs(req.DeviceIDs...))
286-
} else if req.Count > 0 {
287-
var devices []int
288-
for i := 0; i < req.Count; i++ {
289-
devices = append(devices, i)
290-
}
291-
gpuOpts = append(gpuOpts, nvidia.WithDevices(devices...))
292-
} else if req.Count < 0 {
293-
gpuOpts = append(gpuOpts, nvidia.WithAllDevices)
285+
return req.DeviceIDs
294286
}
295-
296-
str2cap := make(map[string]nvidia.Capability)
297-
for _, c := range nvidia.AllCaps() {
298-
str2cap[string(c)] = c
299-
}
300-
var nvidiaCaps []nvidia.Capability
301-
for _, c := range req.Capabilities {
302-
if cp, isNvidiaCap := str2cap[c]; isNvidiaCap {
303-
nvidiaCaps = append(nvidiaCaps, cp)
304-
}
287+
if req.Count < 0 {
288+
return []string{"all"}
305289
}
306-
if len(nvidiaCaps) != 0 {
307-
gpuOpts = append(gpuOpts, nvidia.WithCapabilities(nvidiaCaps...))
308-
} else {
309-
// Add "utility", "compute" capability if unset.
310-
// Please see also: https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/user-guide.html#driver-capabilities
311-
gpuOpts = append(gpuOpts, nvidia.WithCapabilities(nvidia.Utility, nvidia.Compute))
312-
}
313-
314-
if rootlessutil.IsRootless() {
315-
// "--no-cgroups" option is needed to nvidia-container-cli in rootless environment
316-
// Please see also: https://github.com/moby/moby/issues/38729#issuecomment-463493866
317-
gpuOpts = append(gpuOpts, nvidia.WithNoCgroups)
290+
var ids []string
291+
for i := 0; i < req.Count; i++ {
292+
ids = append(ids, fmt.Sprintf("%d", i))
318293
}
319294

320-
return nvidia.WithGPUs(gpuOpts...), nil
295+
return ids
321296
}

0 commit comments

Comments
 (0)