diff --git a/internal/discover/graphics.go b/internal/discover/graphics.go index 74f759ff3..14c316fed 100644 --- a/internal/discover/graphics.go +++ b/internal/discover/graphics.go @@ -120,7 +120,7 @@ func newGraphicsLibrariesDiscoverer(logger logger.Interface, driver *root.Driver if err != nil { return nil, fmt.Errorf("failed to get driver version: %w", err) } - cudaLibRoot, err := driver.GetLibcudaParentDir() + cudaLibRoot, err := driver.GetDriverLibDirectory() if err != nil { return nil, fmt.Errorf("failed to get libcuda.so parent directory: %w", err) } diff --git a/internal/lookup/root/cuda_test.go b/internal/lookup/root/cuda_test.go index 03db71ac1..e26361c7c 100644 --- a/internal/lookup/root/cuda_test.go +++ b/internal/lookup/root/cuda_test.go @@ -47,7 +47,13 @@ func TestLocate(t *testing.T) { { description: "no-ldcache searches /usr/lib64", libcudaPath: "/usr/lib64/libcuda.so.123.34", - expected: "/usr/lib64/libcuda.so.123.34", + expected: "/usr/lib64", + expectedError: nil, + }, + { + description: "no-ldcache searches /usr/lib64 for libnvidia-ml.so.", + libcudaPath: "/usr/lib64/libnvidia-ml.so.123.34", + expected: "/usr/lib64", expectedError: nil, }, } @@ -62,11 +68,11 @@ func TestLocate(t *testing.T) { WithDriverRoot(driverRoot), ) - libcudasoPath, err := l.GetLibcudasoPath() + driverLibraryPath, err := l.GetDriverLibDirectory() require.ErrorIs(t, err, tc.expectedError) // NOTE: We need to strip `/private` on MacOs due to symlink resolution - stripped := strings.TrimPrefix(libcudasoPath, "/private") + stripped := strings.TrimPrefix(driverLibraryPath, "/private") require.Equal(t, tc.expected, stripped) }) diff --git a/internal/lookup/root/root.go b/internal/lookup/root/root.go index 284bafd52..0f15e2e0f 100644 --- a/internal/lookup/root/root.go +++ b/internal/lookup/root/root.go @@ -17,6 +17,7 @@ package root import ( + "errors" "fmt" "os" "path/filepath" @@ -40,8 +41,8 @@ type Driver struct { // version caches the driver version. version string - // libcudasoPath caches the path to libcuda.so.VERSION. - libcudasoPath string + // driverLibDirectory caches the path to parent of the driver libraries + driverLibDirectory string } // New creates a new Driver root using the specified options. @@ -69,7 +70,7 @@ func New(opts ...Option) *Driver { librarySearchPaths: o.librarySearchPaths, configSearchPaths: o.configSearchPaths, version: driverVersion, - libcudasoPath: "", + driverLibDirectory: "", } return d @@ -90,31 +91,24 @@ func (r *Driver) Version() (string, error) { return r.version, nil } -// GetLibcudaParentDir returns the cached libcuda.so path if possible. +// GetDriverLibDirectory returns the cached directory where the driver libs are +// found if possible. // If this has not yet been initialized, the path is first detected and then returned. -func (r *Driver) GetLibcudasoPath() (string, error) { +func (r *Driver) GetDriverLibDirectory() (string, error) { r.Lock() defer r.Unlock() - if r.libcudasoPath == "" { + if r.driverLibDirectory == "" { if err := r.updateInfo(); err != nil { return "", err } } - return r.libcudasoPath, nil -} - -func (r *Driver) GetLibcudaParentDir() (string, error) { - libcudasoPath, err := r.GetLibcudasoPath() - if err != nil { - return "", err - } - return filepath.Dir(libcudasoPath), nil + return r.driverLibDirectory, nil } func (r *Driver) DriverLibraryLocator(additionalDirs ...string) (lookup.Locator, error) { - libcudasoParentDirPath, err := r.GetLibcudaParentDir() + libcudasoParentDirPath, err := r.GetDriverLibDirectory() if err != nil { return nil, fmt.Errorf("failed to get libcuda.so parent directory: %w", err) } @@ -140,30 +134,47 @@ func (r *Driver) DriverLibraryLocator(additionalDirs ...string) (lookup.Locator, } func (r *Driver) updateInfo() error { - versionSuffix := r.version - if versionSuffix == "" { - versionSuffix = "*.*" - } - - libCudaPaths, err := r.Libraries().Locate("libcuda.so." + versionSuffix) + driverLibPath, version, err := r.inferVersion() if err != nil { - return fmt.Errorf("failed to locate libcuda.so: %w", err) - } - libcudaPath := libCudaPaths[0] - - version := strings.TrimPrefix(filepath.Base(libcudaPath), "libcuda.so.") - if version == "" { - return fmt.Errorf("failed to extract version from path %v", libcudaPath) + return err } - if r.version != "" && r.version != version { return fmt.Errorf("unexpected version detected: %v != %v", r.version, version) } + r.version = version - r.libcudasoPath = r.RelativeToRoot(libcudaPath) + r.driverLibDirectory = r.RelativeToRoot(filepath.Dir(driverLibPath)) + return nil } +// inferVersion attempts to infer the driver version from the libcuda.so or +// libnvidia-ml.so driver library suffixes. +func (r *Driver) inferVersion() (string, string, error) { + versionSuffix := r.version + if versionSuffix == "" { + versionSuffix = "*.*" + } + + var errs error + for _, driverLib := range []string{"libcuda.so.", "libnvidia-ml.so."} { + driverLibPaths, err := r.Libraries().Locate(driverLib + versionSuffix) + if err != nil { + errs = errors.Join(errs, fmt.Errorf("failed to locate libcuda.so: %w", err)) + continue + } + driverLibPath := driverLibPaths[0] + version := strings.TrimPrefix(filepath.Base(driverLibPath), driverLib) + if version == "" { + errs = errors.Join(errs, fmt.Errorf("failed to extract version from path %v", driverLibPath)) + continue + } + return driverLibPath, version, nil + } + + return "", "", errs +} + // RelativeToRoot returns the specified path relative to the driver root. func (r *Driver) RelativeToRoot(path string) string { if r.Root == "" || r.Root == "/" { diff --git a/pkg/nvcdi/driver-nvml.go b/pkg/nvcdi/driver-nvml.go index e145a2d6f..191cc6a91 100644 --- a/pkg/nvcdi/driver-nvml.go +++ b/pkg/nvcdi/driver-nvml.go @@ -42,7 +42,7 @@ func (l *nvcdilib) newDriverVersionDiscoverer() (discover.Discover, error) { return nil, fmt.Errorf("failed to determine driver version (%q): %w", version, err) } - libcudasoParentDirPath, err := l.driver.GetLibcudaParentDir() + libcudasoParentDirPath, err := l.driver.GetDriverLibDirectory() if err != nil { return nil, fmt.Errorf("failed to get libcuda.so parent path: %w", err) } @@ -110,13 +110,13 @@ func (l *nvcdilib) NewDriverLibraryDiscoverer(version string, libcudaSoParentDir disableDeviceNodeModification := l.hookCreator.Create(DisableDeviceNodeModificationHook) discoverers = append(discoverers, disableDeviceNodeModification) - libCudaSoParentDirectoryPath, err := l.driver.GetLibcudaParentDir() + driverLibDirectory, err := l.driver.GetDriverLibDirectory() if err != nil { return nil, fmt.Errorf("failed to get libcuda.so parent directory path: %w", err) } environmentVariable := &discover.EnvVar{ Name: "NVIDIA_CTK_LIBCUDA_DIR", - Value: libCudaSoParentDirectoryPath, + Value: driverLibDirectory, } discoverers = append(discoverers, environmentVariable)