Skip to content

Commit 00f745f

Browse files
committed
Determine driver version from folders
Signed-off-by: Evan Lezar <[email protected]>
1 parent b23e42d commit 00f745f

File tree

1 file changed

+22
-1
lines changed

1 file changed

+22
-1
lines changed

cmd/nvidia-cdi-hook/compat-libs/compat-libs.go

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ type command struct {
3333
}
3434

3535
type options struct {
36+
folders cli.StringSlice
3637
driverVersion string
3738
containerSpec string
3839
}
@@ -66,6 +67,11 @@ func (m command) build() *cli.Command {
6667
Usage: "Specify the host driver version",
6768
Destination: &cfg.driverVersion,
6869
},
70+
&cli.StringSliceFlag{
71+
Name: "folder",
72+
Usage: "Specify the folders that are added to the container. These are used to locate libcuda.so.RM_VERSION if not specified.",
73+
Destination: &cfg.folders,
74+
},
6975
&cli.StringFlag{
7076
Name: "container-spec",
7177
Usage: "Specify the path to the OCI container spec. If empty or '-' the spec will be read from STDIN",
@@ -120,7 +126,22 @@ func (m command) run(c *cli.Context, cfg *options) error {
120126

121127
compatVersion := strings.TrimPrefix(filepath.Base(libs[0]), "libcuda.so.")
122128
compatMajor := strings.SplitN(compatVersion, ".", 2)[0]
123-
driverMajor := strings.SplitN(cfg.driverVersion, ".", 2)[0]
129+
130+
driverVersion := cfg.driverVersion
131+
if driverVersion == "" {
132+
for _, folder := range cfg.folders.Value() {
133+
libs, err := root(containerRoot).glob(filepath.Join(folder, "libcuda.so.*.*"))
134+
if err != nil || len(libs) == 0 {
135+
continue
136+
}
137+
if len(libs) != 1 {
138+
m.logger.Warningf("Unexpected number of CUDA compat libraries: %v", libs)
139+
}
140+
141+
driverVersion = strings.TrimPrefix(filepath.Base(libs[0]), "libcuda.so.")
142+
}
143+
}
144+
driverMajor := strings.SplitN(driverVersion, ".", 2)[0]
124145

125146
if driverMajor > compatMajor {
126147
return nil

0 commit comments

Comments
 (0)