@@ -25,7 +25,6 @@ import (
2525 "github.com/opencontainers/runtime-spec/specs-go"
2626
2727 containerd "github.com/containerd/containerd/v2/client"
28- "github.com/containerd/containerd/v2/contrib/nvidia"
2928 "github.com/containerd/containerd/v2/core/containers"
3029 "github.com/containerd/containerd/v2/pkg/oci"
3130 "github.com/containerd/log"
@@ -99,7 +98,7 @@ func setPlatformOptions(ctx context.Context, client *containerd.Client, id, uts
9998 if options .Sysctl != nil {
10099 opts = append (opts , WithSysctls (strutil .ConvertKVStringsToMap (options .Sysctl )))
101100 }
102- gpuOpt , err := parseGPUOpts (options .GPUs )
101+ gpuOpt , err := parseGPUOpts (options .GOptions . CDISpecDirs , options . GPUs )
103102 if err != nil {
104103 return nil , err
105104 }
@@ -262,60 +261,36 @@ func withOOMScoreAdj(score int) oci.SpecOpts {
262261 }
263262}
264263
265- func parseGPUOpts (value []string ) (res []oci.SpecOpts , _ error ) {
264+ func parseGPUOpts (cdiSpecDirs [] string , value []string ) (res []oci.SpecOpts , _ error ) {
266265 for _ , gpu := range value {
267- gpuOpt , err := parseGPUOpt (gpu )
266+ req , err := ParseGPUOptCSV (gpu )
268267 if err != nil {
269268 return nil , err
270269 }
271- res = append (res , gpuOpt )
270+ res = append (res , withCDIDevices ( cdiSpecDirs , req . toCDIDeviceIDS () ... ) )
272271 }
273272 return res , nil
274273}
275274
276- func parseGPUOpt ( value string ) (oci. SpecOpts , error ) {
277- req , err := ParseGPUOptCSV ( value )
278- if err != nil {
279- return nil , err
275+ func ( req * GPUReq ) toCDIDeviceIDS () [] string {
276+ var cdiDeviceIDs [] string
277+ for _ , id := range req . normalizeDeviceIDs () {
278+ cdiDeviceIDs = append ( cdiDeviceIDs , "nvidia.com/gpu=" + id )
280279 }
280+ return cdiDeviceIDs
281+ }
281282
282- var gpuOpts []nvidia.Opts
283-
283+ func (req * GPUReq ) normalizeDeviceIDs () []string {
284284 if len (req .DeviceIDs ) > 0 {
285- gpuOpts = append (gpuOpts , nvidia .WithDeviceUUIDs (req .DeviceIDs ... ))
286- } else if req .Count > 0 {
287- var devices []int
288- for i := 0 ; i < req .Count ; i ++ {
289- devices = append (devices , i )
290- }
291- gpuOpts = append (gpuOpts , nvidia .WithDevices (devices ... ))
292- } else if req .Count < 0 {
293- gpuOpts = append (gpuOpts , nvidia .WithAllDevices )
285+ return req .DeviceIDs
294286 }
295-
296- str2cap := make (map [string ]nvidia.Capability )
297- for _ , c := range nvidia .AllCaps () {
298- str2cap [string (c )] = c
299- }
300- var nvidiaCaps []nvidia.Capability
301- for _ , c := range req .Capabilities {
302- if cp , isNvidiaCap := str2cap [c ]; isNvidiaCap {
303- nvidiaCaps = append (nvidiaCaps , cp )
304- }
287+ if req .Count < 0 {
288+ return []string {"all" }
305289 }
306- if len (nvidiaCaps ) != 0 {
307- gpuOpts = append (gpuOpts , nvidia .WithCapabilities (nvidiaCaps ... ))
308- } else {
309- // Add "utility", "compute" capability if unset.
310- // Please see also: https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/user-guide.html#driver-capabilities
311- gpuOpts = append (gpuOpts , nvidia .WithCapabilities (nvidia .Utility , nvidia .Compute ))
312- }
313-
314- if rootlessutil .IsRootless () {
315- // "--no-cgroups" option is needed to nvidia-container-cli in rootless environment
316- // Please see also: https://github.com/moby/moby/issues/38729#issuecomment-463493866
317- gpuOpts = append (gpuOpts , nvidia .WithNoCgroups )
290+ var ids []string
291+ for i := 0 ; i < req .Count ; i ++ {
292+ ids = append (ids , fmt .Sprintf ("%d" , i ))
318293 }
319294
320- return nvidia . WithGPUs ( gpuOpts ... ), nil
295+ return ids
321296}
0 commit comments