@@ -25,6 +25,7 @@ import (
2525 "github.com/NVIDIA/nvidia-container-toolkit/internal/config"
2626 "github.com/NVIDIA/nvidia-container-toolkit/internal/config/image"
2727 "github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
28+ "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/root"
2829 "github.com/NVIDIA/nvidia-container-toolkit/internal/modifier/cdi"
2930 "github.com/NVIDIA/nvidia-container-toolkit/internal/oci"
3031 "github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi"
@@ -34,7 +35,7 @@ import (
3435// NewCDIModifier creates an OCI spec modifier that determines the modifications to make based on the
3536// CDI specifications available on the system. The NVIDIA_VISIBLE_DEVICES environment variable is
3637// used to select the devices to include.
37- func NewCDIModifier (logger logger.Interface , cfg * config.Config , ociSpec oci.Spec ) (oci.SpecModifier , error ) {
38+ func NewCDIModifier (logger logger.Interface , cfg * config.Config , driver * root. Driver , ociSpec oci.Spec ) (oci.SpecModifier , error ) {
3839 devices , err := getDevicesFromSpec (logger , ociSpec , cfg )
3940 if err != nil {
4041 return nil , fmt .Errorf ("failed to get required devices from OCI specification: %v" , err )
@@ -50,7 +51,7 @@ func NewCDIModifier(logger logger.Interface, cfg *config.Config, ociSpec oci.Spe
5051 return nil , fmt .Errorf ("requesting a CDI device with vendor 'runtime.nvidia.com' is not supported when requesting other CDI devices" )
5152 }
5253 if len (automaticDevices ) > 0 {
53- automaticModifier , err := newAutomaticCDISpecModifier (logger , cfg , automaticDevices )
54+ automaticModifier , err := newAutomaticCDISpecModifier (logger , cfg , driver , automaticDevices )
5455 if err == nil {
5556 return automaticModifier , nil
5657 }
@@ -163,9 +164,9 @@ func filterAutomaticDevices(devices []string) []string {
163164 return automatic
164165}
165166
166- func newAutomaticCDISpecModifier (logger logger.Interface , cfg * config.Config , devices []string ) (oci.SpecModifier , error ) {
167+ func newAutomaticCDISpecModifier (logger logger.Interface , cfg * config.Config , driver * root. Driver , devices []string ) (oci.SpecModifier , error ) {
167168 logger .Debugf ("Generating in-memory CDI specs for devices %v" , devices )
168- spec , err := generateAutomaticCDISpec (logger , cfg , devices )
169+ spec , err := generateAutomaticCDISpec (logger , cfg , driver , devices )
169170 if err != nil {
170171 return nil , fmt .Errorf ("failed to generate CDI spec: %w" , err )
171172 }
@@ -180,7 +181,7 @@ func newAutomaticCDISpecModifier(logger logger.Interface, cfg *config.Config, de
180181 return cdiModifier , nil
181182}
182183
183- func generateAutomaticCDISpec (logger logger.Interface , cfg * config.Config , devices []string ) (spec.Interface , error ) {
184+ func generateAutomaticCDISpec (logger logger.Interface , cfg * config.Config , driver * root. Driver , devices []string ) (spec.Interface , error ) {
184185 cdilib , err := nvcdi .New (
185186 nvcdi .WithLogger (logger ),
186187 nvcdi .WithNVIDIACDIHookPath (cfg .NVIDIACTKConfig .Path ),
@@ -192,6 +193,11 @@ func generateAutomaticCDISpec(logger logger.Interface, cfg *config.Config, devic
192193 return nil , fmt .Errorf ("failed to construct CDI library: %w" , err )
193194 }
194195
196+ // TODO: Consider moving this into the nvcdi API.
197+ if err := driver .LoadKernelModules (cfg .NVIDIAContainerRuntimeConfig .Modes .JitCDI .LoadKernelModules ... ); err != nil {
198+ logger .Warningf ("Ignoring error(s) loading kernel modules: %v" , err )
199+ }
200+
195201 identifiers := []string {}
196202 for _ , device := range devices {
197203 _ , _ , id := parser .ParseDevice (device )
0 commit comments