Skip to content

Commit 0cdb30e

Browse files
authored
Merge pull request #1451 from elezar/relax-driver-version
Also consider libnvidia-ml.so for extracting driver version
2 parents eda32d5 + d087e91 commit 0cdb30e

File tree

4 files changed

+55
-38
lines changed

4 files changed

+55
-38
lines changed

internal/discover/graphics.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ func newGraphicsLibrariesDiscoverer(logger logger.Interface, driver *root.Driver
120120
if err != nil {
121121
return nil, fmt.Errorf("failed to get driver version: %w", err)
122122
}
123-
cudaLibRoot, err := driver.GetLibcudaParentDir()
123+
cudaLibRoot, err := driver.GetDriverLibDirectory()
124124
if err != nil {
125125
return nil, fmt.Errorf("failed to get libcuda.so parent directory: %w", err)
126126
}

internal/lookup/root/cuda_test.go

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,13 @@ func TestLocate(t *testing.T) {
4747
{
4848
description: "no-ldcache searches /usr/lib64",
4949
libcudaPath: "/usr/lib64/libcuda.so.123.34",
50-
expected: "/usr/lib64/libcuda.so.123.34",
50+
expected: "/usr/lib64",
51+
expectedError: nil,
52+
},
53+
{
54+
description: "no-ldcache searches /usr/lib64 for libnvidia-ml.so.",
55+
libcudaPath: "/usr/lib64/libnvidia-ml.so.123.34",
56+
expected: "/usr/lib64",
5157
expectedError: nil,
5258
},
5359
}
@@ -62,11 +68,11 @@ func TestLocate(t *testing.T) {
6268
WithDriverRoot(driverRoot),
6369
)
6470

65-
libcudasoPath, err := l.GetLibcudasoPath()
71+
driverLibraryPath, err := l.GetDriverLibDirectory()
6672
require.ErrorIs(t, err, tc.expectedError)
6773

6874
// NOTE: We need to strip `/private` on MacOs due to symlink resolution
69-
stripped := strings.TrimPrefix(libcudasoPath, "/private")
75+
stripped := strings.TrimPrefix(driverLibraryPath, "/private")
7076

7177
require.Equal(t, tc.expected, stripped)
7278
})

internal/lookup/root/root.go

Lines changed: 42 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
package root
1818

1919
import (
20+
"errors"
2021
"fmt"
2122
"os"
2223
"path/filepath"
@@ -40,8 +41,8 @@ type Driver struct {
4041

4142
// version caches the driver version.
4243
version string
43-
// libcudasoPath caches the path to libcuda.so.VERSION.
44-
libcudasoPath string
44+
// driverLibDirectory caches the path to parent of the driver libraries
45+
driverLibDirectory string
4546
}
4647

4748
// New creates a new Driver root using the specified options.
@@ -69,7 +70,7 @@ func New(opts ...Option) *Driver {
6970
librarySearchPaths: o.librarySearchPaths,
7071
configSearchPaths: o.configSearchPaths,
7172
version: driverVersion,
72-
libcudasoPath: "",
73+
driverLibDirectory: "",
7374
}
7475

7576
return d
@@ -90,31 +91,24 @@ func (r *Driver) Version() (string, error) {
9091
return r.version, nil
9192
}
9293

93-
// GetLibcudaParentDir returns the cached libcuda.so path if possible.
94+
// GetDriverLibDirectory returns the cached directory where the driver libs are
95+
// found if possible.
9496
// If this has not yet been initialized, the path is first detected and then returned.
95-
func (r *Driver) GetLibcudasoPath() (string, error) {
97+
func (r *Driver) GetDriverLibDirectory() (string, error) {
9698
r.Lock()
9799
defer r.Unlock()
98100

99-
if r.libcudasoPath == "" {
101+
if r.driverLibDirectory == "" {
100102
if err := r.updateInfo(); err != nil {
101103
return "", err
102104
}
103105
}
104106

105-
return r.libcudasoPath, nil
106-
}
107-
108-
func (r *Driver) GetLibcudaParentDir() (string, error) {
109-
libcudasoPath, err := r.GetLibcudasoPath()
110-
if err != nil {
111-
return "", err
112-
}
113-
return filepath.Dir(libcudasoPath), nil
107+
return r.driverLibDirectory, nil
114108
}
115109

116110
func (r *Driver) DriverLibraryLocator(additionalDirs ...string) (lookup.Locator, error) {
117-
libcudasoParentDirPath, err := r.GetLibcudaParentDir()
111+
libcudasoParentDirPath, err := r.GetDriverLibDirectory()
118112
if err != nil {
119113
return nil, fmt.Errorf("failed to get libcuda.so parent directory: %w", err)
120114
}
@@ -140,30 +134,47 @@ func (r *Driver) DriverLibraryLocator(additionalDirs ...string) (lookup.Locator,
140134
}
141135

142136
func (r *Driver) updateInfo() error {
143-
versionSuffix := r.version
144-
if versionSuffix == "" {
145-
versionSuffix = "*.*"
146-
}
147-
148-
libCudaPaths, err := r.Libraries().Locate("libcuda.so." + versionSuffix)
137+
driverLibPath, version, err := r.inferVersion()
149138
if err != nil {
150-
return fmt.Errorf("failed to locate libcuda.so: %w", err)
151-
}
152-
libcudaPath := libCudaPaths[0]
153-
154-
version := strings.TrimPrefix(filepath.Base(libcudaPath), "libcuda.so.")
155-
if version == "" {
156-
return fmt.Errorf("failed to extract version from path %v", libcudaPath)
139+
return err
157140
}
158-
159141
if r.version != "" && r.version != version {
160142
return fmt.Errorf("unexpected version detected: %v != %v", r.version, version)
161143
}
144+
162145
r.version = version
163-
r.libcudasoPath = r.RelativeToRoot(libcudaPath)
146+
r.driverLibDirectory = r.RelativeToRoot(filepath.Dir(driverLibPath))
147+
164148
return nil
165149
}
166150

151+
// inferVersion attempts to infer the driver version from the libcuda.so or
152+
// libnvidia-ml.so driver library suffixes.
153+
func (r *Driver) inferVersion() (string, string, error) {
154+
versionSuffix := r.version
155+
if versionSuffix == "" {
156+
versionSuffix = "*.*"
157+
}
158+
159+
var errs error
160+
for _, driverLib := range []string{"libcuda.so.", "libnvidia-ml.so."} {
161+
driverLibPaths, err := r.Libraries().Locate(driverLib + versionSuffix)
162+
if err != nil {
163+
errs = errors.Join(errs, fmt.Errorf("failed to locate libcuda.so: %w", err))
164+
continue
165+
}
166+
driverLibPath := driverLibPaths[0]
167+
version := strings.TrimPrefix(filepath.Base(driverLibPath), driverLib)
168+
if version == "" {
169+
errs = errors.Join(errs, fmt.Errorf("failed to extract version from path %v", driverLibPath))
170+
continue
171+
}
172+
return driverLibPath, version, nil
173+
}
174+
175+
return "", "", errs
176+
}
177+
167178
// RelativeToRoot returns the specified path relative to the driver root.
168179
func (r *Driver) RelativeToRoot(path string) string {
169180
if r.Root == "" || r.Root == "/" {

pkg/nvcdi/driver-nvml.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ func (l *nvcdilib) newDriverVersionDiscoverer() (discover.Discover, error) {
4242
return nil, fmt.Errorf("failed to determine driver version (%q): %w", version, err)
4343
}
4444

45-
libcudasoParentDirPath, err := l.driver.GetLibcudaParentDir()
45+
libcudasoParentDirPath, err := l.driver.GetDriverLibDirectory()
4646
if err != nil {
4747
return nil, fmt.Errorf("failed to get libcuda.so parent path: %w", err)
4848
}
@@ -110,13 +110,13 @@ func (l *nvcdilib) NewDriverLibraryDiscoverer(version string, libcudaSoParentDir
110110
disableDeviceNodeModification := l.hookCreator.Create(DisableDeviceNodeModificationHook)
111111
discoverers = append(discoverers, disableDeviceNodeModification)
112112

113-
libCudaSoParentDirectoryPath, err := l.driver.GetLibcudaParentDir()
113+
driverLibDirectory, err := l.driver.GetDriverLibDirectory()
114114
if err != nil {
115115
return nil, fmt.Errorf("failed to get libcuda.so parent directory path: %w", err)
116116
}
117117
environmentVariable := &discover.EnvVar{
118118
Name: "NVIDIA_CTK_LIBCUDA_DIR",
119-
Value: libCudaSoParentDirectoryPath,
119+
Value: driverLibDirectory,
120120
}
121121
discoverers = append(discoverers, environmentVariable)
122122

0 commit comments

Comments
 (0)