@@ -19,7 +19,6 @@ package main
1919import (
2020 "fmt"
2121 "io"
22- "strconv"
2322 "strings"
2423
2524 "github.com/sirupsen/logrus"
@@ -29,6 +28,7 @@ import (
2928 "github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi"
3029 "github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/spec"
3130 transformroot "github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/transform/root"
31+ "k8s.io/klog/v2"
3232 cdiapi "tags.cncf.io/container-device-interface/pkg/cdi"
3333 cdiparser "tags.cncf.io/container-device-interface/pkg/parser"
3434 cdispec "tags.cncf.io/container-device-interface/specs-go"
@@ -141,38 +141,33 @@ func NewCDIHandler(opts ...cdiOption) (*CDIHandler, error) {
141141}
142142
143143func (cdi * CDIHandler ) CreateStandardDeviceSpecFile (allocatable AllocatableDevices ) error {
144+ // Initialize NVML in order to get the device edits.
145+ if r := cdi .nvml .Init (); r != nvml .SUCCESS {
146+ return fmt .Errorf ("failed to initialize NVML: %v" , r )
147+ }
148+ defer func () {
149+ if r := cdi .nvml .Shutdown (); r != nvml .SUCCESS {
150+ klog .Warningf ("failed to shutdown NVML: %v" , r )
151+ }
152+ }()
153+
144154 // Generate the set of common edits.
145155 commonEdits , err := cdi .nvcdiDevice .GetCommonEdits ()
146156 if err != nil {
147157 return fmt .Errorf ("failed to get common CDI spec edits: %w" , err )
148158 }
149159
150- // Generate device specs for all full GPUs.
151- var indices []string
160+ // Generate device specs for all full GPUs and MIG devices .
161+ var deviceSpecs []cdispec. Device
152162 for _ , device := range allocatable {
153- indices = append (indices , strconv .Itoa (device .GpuInfo .index ))
154- }
155- deviceSpecs , err := cdi .nvcdiDevice .GetDeviceSpecsByID (indices ... )
156- if err != nil {
157- return fmt .Errorf ("unable to get CDI spec edits for full GPUs: %w" , err )
158- }
159- for i := range deviceSpecs {
160- deviceSpecs [i ].Name = fmt .Sprintf ("gpu-%s" , deviceSpecs [i ].Name )
163+ dspecs , err := cdi .nvcdiDevice .GetDeviceSpecsByID (device .CanonicalIndex ())
164+ if err != nil {
165+ return fmt .Errorf ("unable to get device spec for %s: %w" , device .CanonicalName (), err )
166+ }
167+ dspecs [0 ].Name = device .CanonicalName ()
168+ deviceSpecs = append (deviceSpecs , dspecs [0 ])
161169 }
162170
163- // TODO: MIG is not yet supported with structured parameters.
164- // Refactor this to generate devices specs for all MIG devices.
165- //
166- // var indices []string
167- // for _, device := range devices.Mig.Devices {
168- // index := fmt.Sprintf("%d:%d", device.Info.parent.index, device.Info.index)
169- // indices = append(indices, index)
170- // }
171- // migSpecs, err := cdi.nvcdiClaim.GetDeviceSpecsByID(indices...)
172- // if err != nil {
173- // return fmt.Errorf("unable to get CDI spec edits for MIG Devices: %w", err)
174- // }
175-
176171 // Generate base spec from commonEdits and deviceEdits.
177172 spec , err := spec .New (
178173 spec .WithVendor (cdiVendor ),
@@ -211,28 +206,35 @@ func (cdi *CDIHandler) CreateStandardDeviceSpecFile(allocatable AllocatableDevic
211206 return cdi .cache .WriteSpec (spec .Raw (), specName )
212207}
213208
214- func (cdi * CDIHandler ) CreateClaimSpecFile (claimUID string , devices * PreparedDevices ) error {
215- // Gather all claim specific container edits together.
216- // Include at least one edit so that this file always gets created without error.
217- claimEdits := cdiapi.ContainerEdits {
218- ContainerEdits : & cdispec.ContainerEdits {
219- Env : []string {
220- fmt .Sprintf ("NVIDIA_VISIBLE_DEVICES=%s" , strings .Join (devices .UUIDs (), "," )),
209+ func (cdi * CDIHandler ) CreateClaimSpecFile (claimUID string , preparedDevices PreparedDevices ) error {
210+ // Generate claim specific specs for each device.
211+ var deviceSpecs []cdispec.Device
212+ for _ , group := range preparedDevices {
213+ // Include this per-device, rather than as a top-level edit so that
214+ // each device spec is never empty and the spec file gets created
215+ // without error.
216+ claimDeviceEdits := cdiapi.ContainerEdits {
217+ ContainerEdits : & cdispec.ContainerEdits {
218+ Env : []string {
219+ fmt .Sprintf ("NVIDIA_VISIBLE_DEVICES=%s" , strings .Join (preparedDevices .UUIDs (), "," )),
220+ },
221221 },
222- },
223- }
222+ }
224223
225- // Generate devices for the MPS control daemon if configured.
226- if devices .MpsControlDaemon != nil {
227- claimEdits .Append (devices .MpsControlDaemon .GetCDIContainerEdits ())
228- }
224+ // Generate edits for the MPS control daemon if configured for the group.
225+ if group .MpsControlDaemon != nil {
226+ claimDeviceEdits .Append (group .MpsControlDaemon .GetCDIContainerEdits ())
227+ }
228+
229+ // Apply edits to all devices.
230+ for _ , device := range group .Devices {
231+ deviceSpec := cdispec.Device {
232+ Name : fmt .Sprintf ("%s-%s" , claimUID , device .CanonicalName ()),
233+ ContainerEdits : * claimDeviceEdits .ContainerEdits ,
234+ }
229235
230- // Create a single device spec for all of the edits associated with this claim.
231- deviceSpecs := []cdispec.Device {
232- {
233- Name : claimUID ,
234- ContainerEdits : * claimEdits .ContainerEdits ,
235- },
236+ deviceSpecs = append (deviceSpecs , deviceSpec )
237+ }
236238 }
237239
238240 // Generate the claim specific device spec for this driver.
@@ -286,6 +288,6 @@ func (cdi *CDIHandler) GetStandardDevices(devices []string) []string {
286288 return cdiDevices
287289}
288290
289- func (cdi * CDIHandler ) GetClaimDevice (claimUID string ) string {
290- return cdiparser .QualifiedName (cdiVendor , cdiClaimClass , claimUID )
291+ func (cdi * CDIHandler ) GetClaimDevice (claimUID string , device string ) string {
292+ return cdiparser .QualifiedName (cdiVendor , cdiClaimClass , fmt . Sprintf ( "%s-%s" , claimUID , device ) )
291293}
0 commit comments