@@ -24,7 +24,8 @@ import (
2424 "strings"
2525 "syscall"
2626
27- "github.com/NVIDIA/nvidia-container-toolkit/internal/system"
27+ "github.com/NVIDIA/nvidia-container-toolkit/internal/system/nvdevices"
28+ "github.com/NVIDIA/nvidia-container-toolkit/internal/system/nvmodules"
2829 "github.com/fsnotify/fsnotify"
2930 "github.com/sirupsen/logrus"
3031 "github.com/urfave/cli/v2"
@@ -216,6 +217,7 @@ type linkCreator struct {
216217 logger * logrus.Logger
217218 lister nodeLister
218219 driverRoot string
220+ devRoot string
219221 devCharPath string
220222 dryRun bool
221223 createAll bool
@@ -243,6 +245,9 @@ func NewSymlinkCreator(opts ...Option) (Creator, error) {
243245 if c .driverRoot == "" {
244246 c .driverRoot = "/"
245247 }
248+ if c .devRoot == "" {
249+ c .devRoot = "/"
250+ }
246251 if c .devCharPath == "" {
247252 c .devCharPath = defaultDevCharPath
248253 }
@@ -252,13 +257,13 @@ func NewSymlinkCreator(opts ...Option) (Creator, error) {
252257 }
253258
254259 if c .createAll {
255- lister , err := newAllPossible (c .logger , c .driverRoot )
260+ lister , err := newAllPossible (c .logger , c .devRoot )
256261 if err != nil {
257262 return nil , fmt .Errorf ("failed to create all possible device lister: %v" , err )
258263 }
259264 c .lister = lister
260265 } else {
261- c .lister = existing {c .logger , c .driverRoot }
266+ c .lister = existing {c .logger , c .devRoot }
262267 }
263268 return c , nil
264269}
@@ -268,36 +273,48 @@ func (m linkCreator) setup() error {
268273 return nil
269274 }
270275
271- s , err := system .New (
272- system .WithLogger (m .logger ),
273- system .WithDryRun (m .dryRun ),
274- )
275- if err != nil {
276- return err
277- }
278-
279276 if m .loadKernelModules {
280- if err := s .LoadNVIDIAKernelModules (); err != nil {
277+ modules := nvmodules .New (
278+ nvmodules .WithLogger (m .logger ),
279+ nvmodules .WithDryRun (m .dryRun ),
280+ nvmodules .WithRoot (m .driverRoot ),
281+ )
282+ if err := modules .LoadAll (); err != nil {
281283 return fmt .Errorf ("failed to load NVIDIA kernel modules: %v" , err )
282284 }
283285 }
284286
285287 if m .createDeviceNodes {
286- if err := s .CreateNVIDIAControlDeviceNodesAt (m .driverRoot ); err != nil {
288+ devices , err := nvdevices .New (
289+ nvdevices .WithLogger (m .logger ),
290+ nvdevices .WithDryRun (m .dryRun ),
291+ nvdevices .WithDevRoot (m .devRoot ),
292+ )
293+ if err != nil {
294+ return err
295+ }
296+ if err := devices .CreateNVIDIAControlDevices (); err != nil {
287297 return fmt .Errorf ("failed to create NVIDIA device nodes: %v" , err )
288298 }
289299 }
290-
291300 return nil
292301}
293302
294303// WithDriverRoot sets the driver root path.
304+ // This is the path in which kernel modules must be loaded.
295305func WithDriverRoot (root string ) Option {
296306 return func (c * linkCreator ) {
297307 c .driverRoot = root
298308 }
299309}
300310
311+ // WithDevRoot sets the root path for the /dev directory.
312+ func WithDevRoot (root string ) Option {
313+ return func (c * linkCreator ) {
314+ c .devRoot = root
315+ }
316+ }
317+
301318// WithDevCharPath sets the path at which the symlinks will be created.
302319func WithDevCharPath (path string ) Option {
303320 return func (c * linkCreator ) {
0 commit comments