@@ -32,6 +32,12 @@ import (
3232 "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/root"
3333)
3434
35+ type driverVersionDiscoverer struct {
36+ discover.Discover
37+ nvidiaCDIHookPath string
38+ version string
39+ }
40+
3541// NewDriverDiscoverer creates a discoverer for the libraries and binaries associated with a driver installation.
3642// The supplied NVML Library is used to query the expected driver version.
3743func NewDriverDiscoverer (logger logger.Interface , driver * root.Driver , nvidiaCDIHookPath string , ldconfigPath string , nvmllib nvml.Interface ) (discover.Discover , error ) {
@@ -100,7 +106,11 @@ func NewDriverLibraryDiscoverer(logger logger.Interface, driver *root.Driver, nv
100106 hooks , _ := discover .NewLDCacheUpdateHook (logger , libraries , nvidiaCDIHookPath , ldconfigPath )
101107
102108 d := discover .Merge (
103- libraries ,
109+ & driverVersionDiscoverer {
110+ Discover : libraries ,
111+ nvidiaCDIHookPath : nvidiaCDIHookPath ,
112+ version : version ,
113+ },
104114 hooks ,
105115 )
106116
@@ -220,3 +230,37 @@ func getVersionLibs(logger logger.Interface, driver *root.Driver, version string
220230
221231 return relative , nil
222232}
233+
234+ func (d driverVersionDiscoverer ) Hooks () ([]discover.Hook , error ) {
235+ mounts , err := d .Discover .Mounts ()
236+ if err != nil {
237+ return nil , fmt .Errorf ("failed to get library mounts: %v" , err )
238+ }
239+
240+ var links []string
241+ for _ , mount := range mounts {
242+ dir , filename := filepath .Split (mount .Path )
243+ // TODO: We should include the other libraries as is done here:
244+ // https://github.com/NVIDIA/nvidia-container-toolkit/blob/79c59aeb7f59dd612793ac80a8d7022c554634bb/internal/platform-support/tegra/symlinks.go#L84-L97
245+ if d .isDriverLibrary (filename , "libcuda.so" ) {
246+ // create libcuda.so -> libcuda.so.RM_VERSION symlink
247+ links = append (links , fmt .Sprintf ("%s::%s" , filename , filepath .Join (dir , "libcuda.so" )))
248+ }
249+ }
250+
251+ if len (links ) == 0 {
252+ return nil , nil
253+ }
254+
255+ hooks := discover .CreateCreateSymlinkHook (d .nvidiaCDIHookPath , links )
256+
257+ return hooks .Hooks ()
258+
259+ }
260+
261+ // isDriverLibrary checks whether the specified filename is a specific driver library.
262+ func (d driverVersionDiscoverer ) isDriverLibrary (filename string , libraryName string ) bool {
263+ pattern := strings .TrimSuffix (libraryName , "." ) + d .version
264+ match , _ := filepath .Match (pattern , filename )
265+ return match
266+ }
0 commit comments