Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion internal/discover/graphics.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ func newGraphicsLibrariesDiscoverer(logger logger.Interface, driver *root.Driver
if err != nil {
return nil, fmt.Errorf("failed to get driver version: %w", err)
}
cudaLibRoot, err := driver.GetLibcudaParentDir()
cudaLibRoot, err := driver.GetDriverLibDirectory()
if err != nil {
return nil, fmt.Errorf("failed to get libcuda.so parent directory: %w", err)
}
Expand Down
12 changes: 9 additions & 3 deletions internal/lookup/root/cuda_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,13 @@ func TestLocate(t *testing.T) {
{
description: "no-ldcache searches /usr/lib64",
libcudaPath: "/usr/lib64/libcuda.so.123.34",
expected: "/usr/lib64/libcuda.so.123.34",
expected: "/usr/lib64",
expectedError: nil,
},
{
description: "no-ldcache searches /usr/lib64 for libnvidia-ml.so.",
libcudaPath: "/usr/lib64/libnvidia-ml.so.123.34",
expected: "/usr/lib64",
expectedError: nil,
},
}
Expand All @@ -62,11 +68,11 @@ func TestLocate(t *testing.T) {
WithDriverRoot(driverRoot),
)

libcudasoPath, err := l.GetLibcudasoPath()
driverLibraryPath, err := l.GetDriverLibDirectory()
require.ErrorIs(t, err, tc.expectedError)

// NOTE: We need to strip `/private` on MacOs due to symlink resolution
stripped := strings.TrimPrefix(libcudasoPath, "/private")
stripped := strings.TrimPrefix(driverLibraryPath, "/private")

require.Equal(t, tc.expected, stripped)
})
Expand Down
73 changes: 42 additions & 31 deletions internal/lookup/root/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package root

import (
"errors"
"fmt"
"os"
"path/filepath"
Expand All @@ -40,8 +41,8 @@ type Driver struct {

// version caches the driver version.
version string
// libcudasoPath caches the path to libcuda.so.VERSION.
libcudasoPath string
// driverLibDirectory caches the path to parent of the driver libraries
driverLibDirectory string
}

// New creates a new Driver root using the specified options.
Expand Down Expand Up @@ -69,7 +70,7 @@ func New(opts ...Option) *Driver {
librarySearchPaths: o.librarySearchPaths,
configSearchPaths: o.configSearchPaths,
version: driverVersion,
libcudasoPath: "",
driverLibDirectory: "",
}

return d
Expand All @@ -90,31 +91,24 @@ func (r *Driver) Version() (string, error) {
return r.version, nil
}

// GetLibcudaParentDir returns the cached libcuda.so path if possible.
// GetDriverLibDirectory returns the cached directory where the driver libs are
// found if possible.
// If this has not yet been initialized, the path is first detected and then returned.
func (r *Driver) GetLibcudasoPath() (string, error) {
func (r *Driver) GetDriverLibDirectory() (string, error) {
r.Lock()
defer r.Unlock()

if r.libcudasoPath == "" {
if r.driverLibDirectory == "" {
if err := r.updateInfo(); err != nil {
return "", err
}
}

return r.libcudasoPath, nil
}

func (r *Driver) GetLibcudaParentDir() (string, error) {
libcudasoPath, err := r.GetLibcudasoPath()
if err != nil {
return "", err
}
return filepath.Dir(libcudasoPath), nil
return r.driverLibDirectory, nil
}

func (r *Driver) DriverLibraryLocator(additionalDirs ...string) (lookup.Locator, error) {
libcudasoParentDirPath, err := r.GetLibcudaParentDir()
libcudasoParentDirPath, err := r.GetDriverLibDirectory()
if err != nil {
return nil, fmt.Errorf("failed to get libcuda.so parent directory: %w", err)
}
Expand All @@ -140,30 +134,47 @@ func (r *Driver) DriverLibraryLocator(additionalDirs ...string) (lookup.Locator,
}

func (r *Driver) updateInfo() error {
versionSuffix := r.version
if versionSuffix == "" {
versionSuffix = "*.*"
}

libCudaPaths, err := r.Libraries().Locate("libcuda.so." + versionSuffix)
driverLibPath, version, err := r.inferVersion()
if err != nil {
return fmt.Errorf("failed to locate libcuda.so: %w", err)
}
libcudaPath := libCudaPaths[0]

version := strings.TrimPrefix(filepath.Base(libcudaPath), "libcuda.so.")
if version == "" {
return fmt.Errorf("failed to extract version from path %v", libcudaPath)
return err
}

if r.version != "" && r.version != version {
return fmt.Errorf("unexpected version detected: %v != %v", r.version, version)
}

r.version = version
r.libcudasoPath = r.RelativeToRoot(libcudaPath)
r.driverLibDirectory = r.RelativeToRoot(filepath.Dir(driverLibPath))

return nil
}

// inferVersion attempts to infer the driver version from the libcuda.so or
// libnvidia-ml.so driver library suffixes.
func (r *Driver) inferVersion() (string, string, error) {
versionSuffix := r.version
if versionSuffix == "" {
versionSuffix = "*.*"
}

var errs error
for _, driverLib := range []string{"libcuda.so.", "libnvidia-ml.so."} {
driverLibPaths, err := r.Libraries().Locate(driverLib + versionSuffix)
if err != nil {
errs = errors.Join(errs, fmt.Errorf("failed to locate libcuda.so: %w", err))
continue
}
driverLibPath := driverLibPaths[0]
version := strings.TrimPrefix(filepath.Base(driverLibPath), driverLib)
if version == "" {
errs = errors.Join(errs, fmt.Errorf("failed to extract version from path %v", driverLibPath))
continue
}
return driverLibPath, version, nil
}

return "", "", errs
}

// RelativeToRoot returns the specified path relative to the driver root.
func (r *Driver) RelativeToRoot(path string) string {
if r.Root == "" || r.Root == "/" {
Expand Down
6 changes: 3 additions & 3 deletions pkg/nvcdi/driver-nvml.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func (l *nvcdilib) newDriverVersionDiscoverer() (discover.Discover, error) {
return nil, fmt.Errorf("failed to determine driver version (%q): %w", version, err)
}

libcudasoParentDirPath, err := l.driver.GetLibcudaParentDir()
libcudasoParentDirPath, err := l.driver.GetDriverLibDirectory()
if err != nil {
return nil, fmt.Errorf("failed to get libcuda.so parent path: %w", err)
}
Expand Down Expand Up @@ -110,13 +110,13 @@ func (l *nvcdilib) NewDriverLibraryDiscoverer(version string, libcudaSoParentDir
disableDeviceNodeModification := l.hookCreator.Create(DisableDeviceNodeModificationHook)
discoverers = append(discoverers, disableDeviceNodeModification)

libCudaSoParentDirectoryPath, err := l.driver.GetLibcudaParentDir()
driverLibDirectory, err := l.driver.GetDriverLibDirectory()
if err != nil {
return nil, fmt.Errorf("failed to get libcuda.so parent directory path: %w", err)
}
environmentVariable := &discover.EnvVar{
Name: "NVIDIA_CTK_LIBCUDA_DIR",
Value: libCudaSoParentDirectoryPath,
Value: driverLibDirectory,
}
discoverers = append(discoverers, environmentVariable)

Expand Down