@@ -33,6 +33,7 @@ type command struct {
3333}
3434
3535type options struct {
36+ folders cli.StringSlice
3637 driverVersion string
3738 containerSpec string
3839}
@@ -66,6 +67,11 @@ func (m command) build() *cli.Command {
6667 Usage : "Specify the host driver version" ,
6768 Destination : & cfg .driverVersion ,
6869 },
70+ & cli.StringSliceFlag {
71+ Name : "folder" ,
72+ Usage : "Specify the folders that are added to the container. These are used to locate libcuda.so.RM_VERSION if not specified." ,
73+ Destination : & cfg .folders ,
74+ },
6975 & cli.StringFlag {
7076 Name : "container-spec" ,
7177 Usage : "Specify the path to the OCI container spec. If empty or '-' the spec will be read from STDIN" ,
@@ -120,7 +126,22 @@ func (m command) run(c *cli.Context, cfg *options) error {
120126
121127 compatVersion := strings .TrimPrefix (filepath .Base (libs [0 ]), "libcuda.so." )
122128 compatMajor := strings .SplitN (compatVersion , "." , 2 )[0 ]
123- driverMajor := strings .SplitN (cfg .driverVersion , "." , 2 )[0 ]
129+
130+ driverVersion := cfg .driverVersion
131+ if driverVersion == "" {
132+ for _ , folder := range cfg .folders .Value () {
133+ libs , err := root (containerRoot ).glob (filepath .Join (folder , "libcuda.so.*.*" ))
134+ if err != nil || len (libs ) == 0 {
135+ continue
136+ }
137+ if len (libs ) != 1 {
138+ m .logger .Warningf ("Unexpected number of CUDA compat libraries: %v" , libs )
139+ }
140+
141+ driverVersion = strings .TrimPrefix (filepath .Base (libs [0 ]), "libcuda.so." )
142+ }
143+ }
144+ driverMajor := strings .SplitN (driverVersion , "." , 2 )[0 ]
124145
125146 if driverMajor > compatMajor {
126147 return nil
0 commit comments