Skip to content

Commit cc78124

Browse files
authored
Merge pull request #1143 from elezar/add-device-ids-to-getspec
Add device IDs to nvcdi.GetSpec API
2 parents bdcdcb7 + 2ccf67c commit cc78124

File tree

11 files changed

+25
-28
lines changed

11 files changed

+25
-28
lines changed

internal/modifier/cdi.go

Lines changed: 2 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -144,28 +144,13 @@ func generateAutomaticCDISpec(logger logger.Interface, cfg *config.Config, devic
144144
return nil, fmt.Errorf("failed to construct CDI library: %w", err)
145145
}
146146

147-
identifiers := []string{}
147+
var identifiers []string
148148
for _, device := range devices {
149149
_, _, id := parser.ParseDevice(device)
150150
identifiers = append(identifiers, id)
151151
}
152152

153-
deviceSpecs, err := cdilib.GetDeviceSpecsByID(identifiers...)
154-
if err != nil {
155-
return nil, fmt.Errorf("failed to get CDI device specs: %w", err)
156-
}
157-
158-
commonEdits, err := cdilib.GetCommonEdits()
159-
if err != nil {
160-
return nil, fmt.Errorf("failed to get common CDI spec edits: %w", err)
161-
}
162-
163-
return spec.New(
164-
spec.WithDeviceSpecs(deviceSpecs),
165-
spec.WithEdits(*commonEdits.ContainerEdits),
166-
spec.WithVendor("runtime.nvidia.com"),
167-
spec.WithClass("gpu"),
168-
)
153+
return cdilib.GetSpec(identifiers...)
169154
}
170155

171156
type deduplicatedDeviceRequestor struct {

pkg/nvcdi/api.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ import (
2727

2828
// Interface defines the API for the nvcdi package
2929
type Interface interface {
30-
GetSpec() (spec.Interface, error)
30+
GetSpec(...string) (spec.Interface, error)
3131
GetCommonEdits() (*cdi.ContainerEdits, error)
3232
GetAllDeviceSpecs() ([]specs.Device, error)
3333
GetGPUDeviceEdits(device.Device) (*cdi.ContainerEdits, error)

pkg/nvcdi/gds.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ func (l *gdslib) GetCommonEdits() (*cdi.ContainerEdits, error) {
5858

5959
// GetSpec is unsppported for the gdslib specs.
6060
// gdslib is typically wrapped by a spec that implements GetSpec.
61-
func (l *gdslib) GetSpec() (spec.Interface, error) {
61+
func (l *gdslib) GetSpec(...string) (spec.Interface, error) {
6262
return nil, fmt.Errorf("GetSpec is not supported")
6363
}
6464

pkg/nvcdi/lib-csv.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ type csvlib nvcdilib
3434
var _ Interface = (*csvlib)(nil)
3535

3636
// GetSpec should not be called for wsllib
37-
func (l *csvlib) GetSpec() (spec.Interface, error) {
37+
func (l *csvlib) GetSpec(...string) (spec.Interface, error) {
3838
return nil, fmt.Errorf("unexpected call to csvlib.GetSpec()")
3939
}
4040

pkg/nvcdi/lib-imex.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ const (
4141
)
4242

4343
// GetSpec should not be called for imexlib.
44-
func (l *imexlib) GetSpec() (spec.Interface, error) {
44+
func (l *imexlib) GetSpec(...string) (spec.Interface, error) {
4545
return nil, fmt.Errorf("unexpected call to imexlib.GetSpec()")
4646
}
4747

pkg/nvcdi/lib-nvml.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ type nvmllib nvcdilib
3636
var _ Interface = (*nvmllib)(nil)
3737

3838
// GetSpec should not be called for nvmllib
39-
func (l *nvmllib) GetSpec() (spec.Interface, error) {
39+
func (l *nvmllib) GetSpec(...string) (spec.Interface, error) {
4040
return nil, fmt.Errorf("unexpected call to nvmllib.GetSpec()")
4141
}
4242

pkg/nvcdi/lib-wsl.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ type wsllib nvcdilib
3232
var _ Interface = (*wsllib)(nil)
3333

3434
// GetSpec should not be called for wsllib
35-
func (l *wsllib) GetSpec() (spec.Interface, error) {
35+
func (l *wsllib) GetSpec(...string) (spec.Interface, error) {
3636
return nil, fmt.Errorf("unexpected call to wsllib.GetSpec()")
3737
}
3838

pkg/nvcdi/management.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ func (m managementDiscoverer) nodeIsBlocked(path string) bool {
180180

181181
// GetSpec is unsppported for the managementlib specs.
182182
// managementlib is typically wrapped by a spec that implements GetSpec.
183-
func (m *managementlib) GetSpec() (spec.Interface, error) {
183+
func (m *managementlib) GetSpec(...string) (spec.Interface, error) {
184184
return nil, fmt.Errorf("GetSpec is not supported")
185185
}
186186

pkg/nvcdi/mofed.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ func (l *mofedlib) GetCommonEdits() (*cdi.ContainerEdits, error) {
5858

5959
// GetSpec is unsppported for the mofedlib specs.
6060
// mofedlib is typically wrapped by a spec that implements GetSpec.
61-
func (l *mofedlib) GetSpec() (spec.Interface, error) {
61+
func (l *mofedlib) GetSpec(...string) (spec.Interface, error) {
6262
return nil, fmt.Errorf("GetSpec is not supported")
6363
}
6464

pkg/nvcdi/wrapper.go

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,11 @@ type wrapper struct {
3535
}
3636

3737
// GetSpec combines the device specs and common edits from the wrapped Interface to a single spec.Interface.
38-
func (l *wrapper) GetSpec() (spec.Interface, error) {
39-
deviceSpecs, err := l.GetAllDeviceSpecs()
38+
func (l *wrapper) GetSpec(devices ...string) (spec.Interface, error) {
39+
if len(devices) == 0 {
40+
devices = append(devices, "all")
41+
}
42+
deviceSpecs, err := l.GetDeviceSpecsByID(devices...)
4043
if err != nil {
4144
return nil, err
4245
}
@@ -55,6 +58,16 @@ func (l *wrapper) GetSpec() (spec.Interface, error) {
5558
)
5659
}
5760

61+
func (l *wrapper) GetDeviceSpecsByID(devices ...string) ([]specs.Device, error) {
62+
for _, device := range devices {
63+
if device != "all" {
64+
continue
65+
}
66+
return l.GetAllDeviceSpecs()
67+
}
68+
return l.Interface.GetDeviceSpecsByID(devices...)
69+
}
70+
5871
// GetAllDeviceSpecs returns the device specs for all available devices.
5972
func (l *wrapper) GetAllDeviceSpecs() ([]specs.Device, error) {
6073
return l.Interface.GetAllDeviceSpecs()

0 commit comments

Comments
 (0)