@@ -31,11 +31,22 @@ import (
3131 "github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/spec"
3232)
3333
34+ const (
35+ automaticDeviceVendor = "runtime.nvidia.com"
36+ automaticDeviceClass = "gpu"
37+ automaticDeviceKind = automaticDeviceVendor + "/" + automaticDeviceClass
38+ )
39+
3440// NewCDIModifier creates an OCI spec modifier that determines the modifications to make based on the
3541// CDI specifications available on the system. The NVIDIA_VISIBLE_DEVICES environment variable is
3642// used to select the devices to include.
37- func NewCDIModifier (logger logger.Interface , cfg * config.Config , ociSpec oci.Spec ) (oci.SpecModifier , error ) {
38- devices , err := getDevicesFromSpec (logger , ociSpec , cfg )
43+ func NewCDIModifier (logger logger.Interface , cfg * config.Config , ociSpec oci.Spec , isJitCDI bool ) (oci.SpecModifier , error ) {
44+ defaultKind := cfg .NVIDIAContainerRuntimeConfig .Modes .CDI .DefaultKind
45+ if isJitCDI {
46+ defaultKind = automaticDeviceKind
47+ }
48+
49+ devices , err := getDevicesFromSpec (logger , ociSpec , cfg , defaultKind )
3950 if err != nil {
4051 return nil , fmt .Errorf ("failed to get required devices from OCI specification: %v" , err )
4152 }
@@ -65,7 +76,7 @@ func NewCDIModifier(logger logger.Interface, cfg *config.Config, ociSpec oci.Spe
6576 )
6677}
6778
68- func getDevicesFromSpec (logger logger.Interface , ociSpec oci.Spec , cfg * config.Config ) ([]string , error ) {
79+ func getDevicesFromSpec (logger logger.Interface , ociSpec oci.Spec , cfg * config.Config , defaultKind string ) ([]string , error ) {
6980 rawSpec , err := ociSpec .Load ()
7081 if err != nil {
7182 return nil , fmt .Errorf ("failed to load OCI spec: %v" , err )
@@ -83,26 +94,16 @@ func getDevicesFromSpec(logger logger.Interface, ociSpec oci.Spec, cfg *config.C
8394 if err != nil {
8495 return nil , err
8596 }
86- if cfg .AcceptDeviceListAsVolumeMounts {
87- mountDevices := container .CDIDevicesFromMounts ()
88- if len (mountDevices ) > 0 {
89- return mountDevices , nil
90- }
91- }
9297
9398 var devices []string
94- seen := make (map [string ]bool )
95- for _ , name := range container .VisibleDevicesFromEnvVar () {
96- if ! parser .IsQualifiedName (name ) {
97- name = fmt .Sprintf ("%s=%s" , cfg .NVIDIAContainerRuntimeConfig .Modes .CDI .DefaultKind , name )
98- }
99- if seen [name ] {
100- logger .Debugf ("Ignoring duplicate device %q" , name )
101- continue
99+ if cfg .AcceptDeviceListAsVolumeMounts {
100+ devices = normalizeDeviceList (logger , defaultKind , append (container .DevicesFromMounts (), container .CDIDevicesFromMounts ()... )... )
101+ if len (devices ) > 0 {
102+ return devices , nil
102103 }
103- devices = append (devices , name )
104104 }
105105
106+ devices = normalizeDeviceList (logger , defaultKind , container .VisibleDevicesFromEnvVar ()... )
106107 if len (devices ) == 0 {
107108 return nil , nil
108109 }
@@ -116,6 +117,24 @@ func getDevicesFromSpec(logger logger.Interface, ociSpec oci.Spec, cfg *config.C
116117 return nil , nil
117118}
118119
120+ func normalizeDeviceList (logger logger.Interface , defaultKind string , devices ... string ) []string {
121+ seen := make (map [string ]bool )
122+ var normalized []string
123+ for _ , name := range devices {
124+ if ! parser .IsQualifiedName (name ) {
125+ name = fmt .Sprintf ("%s=%s" , defaultKind , name )
126+ }
127+ if seen [name ] {
128+ logger .Debugf ("Ignoring duplicate device %q" , name )
129+ continue
130+ }
131+ normalized = append (normalized , fmt .Sprintf ("%s=%s" , defaultKind , name ))
132+ seen [name ] = true
133+ }
134+
135+ return normalized
136+ }
137+
119138// getAnnotationDevices returns a list of devices specified in the annotations.
120139// Keys starting with the specified prefixes are considered and expected to contain a comma-separated list of
121140// fully-qualified CDI devices names. If any device name is not fully-quality an error is returned.
@@ -156,7 +175,7 @@ func filterAutomaticDevices(devices []string) []string {
156175 var automatic []string
157176 for _ , device := range devices {
158177 vendor , class , _ := parser .ParseDevice (device )
159- if vendor == "runtime.nvidia.com" && class == "gpu" {
178+ if vendor == automaticDeviceVendor && class == automaticDeviceClass {
160179 automatic = append (automatic , device )
161180 }
162181 }
@@ -165,6 +184,8 @@ func filterAutomaticDevices(devices []string) []string {
165184
166185func newAutomaticCDISpecModifier (logger logger.Interface , cfg * config.Config , devices []string ) (oci.SpecModifier , error ) {
167186 logger .Debugf ("Generating in-memory CDI specs for devices %v" , devices )
187+ // TODO: We should try to load the kernel modules and create the device nodes here.
188+ // Failures should raise a warning and not error out.
168189 spec , err := generateAutomaticCDISpec (logger , cfg , devices )
169190 if err != nil {
170191 return nil , fmt .Errorf ("failed to generate CDI spec: %w" , err )
0 commit comments