1717package root
1818
1919import (
20+ "errors"
2021 "fmt"
2122 "os"
2223 "path/filepath"
@@ -40,8 +41,8 @@ type Driver struct {
4041
4142 // version caches the driver version.
4243 version string
43- // libcudasoPath caches the path to libcuda.so.VERSION.
44- libcudasoPath string
44+ // driverLibDirectory caches the path to parent of the driver libraries
45+ driverLibDirectory string
4546}
4647
4748// New creates a new Driver root using the specified options.
@@ -69,7 +70,7 @@ func New(opts ...Option) *Driver {
6970 librarySearchPaths : o .librarySearchPaths ,
7071 configSearchPaths : o .configSearchPaths ,
7172 version : driverVersion ,
72- libcudasoPath : "" ,
73+ driverLibDirectory : "" ,
7374 }
7475
7576 return d
@@ -90,31 +91,24 @@ func (r *Driver) Version() (string, error) {
9091 return r .version , nil
9192}
9293
93- // GetLibcudaParentDir returns the cached libcuda.so path if possible.
94+ // GetDriverLibDirectory returns the cached directory where the driver libs are
95+ // found if possible.
9496// If this has not yet been initialized, the path is first detected and then returned.
95- func (r * Driver ) GetLibcudasoPath () (string , error ) {
97+ func (r * Driver ) GetDriverLibDirectory () (string , error ) {
9698 r .Lock ()
9799 defer r .Unlock ()
98100
99- if r .libcudasoPath == "" {
101+ if r .driverLibDirectory == "" {
100102 if err := r .updateInfo (); err != nil {
101103 return "" , err
102104 }
103105 }
104106
105- return r .libcudasoPath , nil
106- }
107-
108- func (r * Driver ) GetLibcudaParentDir () (string , error ) {
109- libcudasoPath , err := r .GetLibcudasoPath ()
110- if err != nil {
111- return "" , err
112- }
113- return filepath .Dir (libcudasoPath ), nil
107+ return r .driverLibDirectory , nil
114108}
115109
116110func (r * Driver ) DriverLibraryLocator (additionalDirs ... string ) (lookup.Locator , error ) {
117- libcudasoParentDirPath , err := r .GetLibcudaParentDir ()
111+ libcudasoParentDirPath , err := r .GetDriverLibDirectory ()
118112 if err != nil {
119113 return nil , fmt .Errorf ("failed to get libcuda.so parent directory: %w" , err )
120114 }
@@ -140,30 +134,47 @@ func (r *Driver) DriverLibraryLocator(additionalDirs ...string) (lookup.Locator,
140134}
141135
142136func (r * Driver ) updateInfo () error {
143- versionSuffix := r .version
144- if versionSuffix == "" {
145- versionSuffix = "*.*"
146- }
147-
148- libCudaPaths , err := r .Libraries ().Locate ("libcuda.so." + versionSuffix )
137+ driverLibPath , version , err := r .inferVersion ()
149138 if err != nil {
150- return fmt .Errorf ("failed to locate libcuda.so: %w" , err )
151- }
152- libcudaPath := libCudaPaths [0 ]
153-
154- version := strings .TrimPrefix (filepath .Base (libcudaPath ), "libcuda.so." )
155- if version == "" {
156- return fmt .Errorf ("failed to extract version from path %v" , libcudaPath )
139+ return err
157140 }
158-
159141 if r .version != "" && r .version != version {
160142 return fmt .Errorf ("unexpected version detected: %v != %v" , r .version , version )
161143 }
144+
162145 r .version = version
163- r .libcudasoPath = r .RelativeToRoot (libcudaPath )
146+ r .driverLibDirectory = r .RelativeToRoot (filepath .Dir (driverLibPath ))
147+
164148 return nil
165149}
166150
151+ // inferVersion attempts to infer the driver version from the libcuda.so or
152+ // libnvidia-ml.so driver library suffixes.
153+ func (r * Driver ) inferVersion () (string , string , error ) {
154+ versionSuffix := r .version
155+ if versionSuffix == "" {
156+ versionSuffix = "*.*"
157+ }
158+
159+ var errs error
160+ for _ , driverLib := range []string {"libcuda.so." , "libnvidia-ml.so." } {
161+ driverLibPaths , err := r .Libraries ().Locate (driverLib + versionSuffix )
162+ if err != nil {
163+ errs = errors .Join (errs , fmt .Errorf ("failed to locate libcuda.so: %w" , err ))
164+ continue
165+ }
166+ driverLibPath := driverLibPaths [0 ]
167+ version := strings .TrimPrefix (filepath .Base (driverLibPath ), driverLib )
168+ if version == "" {
169+ errs = errors .Join (errs , fmt .Errorf ("failed to extract version from path %v" , driverLibPath ))
170+ continue
171+ }
172+ return driverLibPath , version , nil
173+ }
174+
175+ return "" , "" , errs
176+ }
177+
167178// RelativeToRoot returns the specified path relative to the driver root.
168179func (r * Driver ) RelativeToRoot (path string ) string {
169180 if r .Root == "" || r .Root == "/" {
0 commit comments