Skip to content

Commit c11dc37

Browse files
committed
Recursively unmount /proc/driver/nvidia if it is mounted
Signed-off-by: Kevin Klues <[email protected]>
1 parent 883d613 commit c11dc37

File tree

1 file changed

+53
-0
lines changed
  • cmd/compute-domain-kubelet-plugin

1 file changed

+53
-0
lines changed

cmd/compute-domain-kubelet-plugin/nvlib.go

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,18 +20,22 @@ import (
2020
"bufio"
2121
"fmt"
2222
"os"
23+
"os/exec"
2324
"path/filepath"
2425
"strconv"
2526
"strings"
2627

2728
"golang.org/x/sys/unix"
2829

30+
"k8s.io/mount-utils"
31+
2932
nvdev "github.com/NVIDIA/go-nvlib/pkg/nvlib/device"
3033
"github.com/NVIDIA/go-nvml/pkg/nvml"
3134
)
3235

3336
const (
3437
procDevicesPath = "/proc/devices"
38+
procDriverNvidiaPath = "/proc/driver/nvidia"
3539
nvidiaCapsDeviceName = "nvidia-caps"
3640
nvidiaCapsImexChannelsDeviceName = "nvidia-caps-imex-channels"
3741
nvidiaCapFabricImexMgmtPath = "/proc/driver/nvidia/capabilities/fabric-imex-mgmt"
@@ -76,6 +80,11 @@ func newDeviceLib(driverRoot root) (*deviceLib, error) {
7680
devRoot: driverRoot.getDevRoot(),
7781
nvidiaSMIPath: nvidiaSMIPath,
7882
}
83+
84+
if err := d.unmountRecursively(procDriverNvidiaPath); err != nil {
85+
return nil, fmt.Errorf("error recursively unmounting %s: %w", procDriverNvidiaPath, err)
86+
}
87+
7988
return &d, nil
8089
}
8190

@@ -285,3 +294,47 @@ func (l deviceLib) createNvCapDevice(nvcapFilePath string) error {
285294

286295
return nil
287296
}
297+
298+
func (l deviceLib) unmountRecursively(root string) error {
299+
// Get a reference to the mount executable.
300+
mountExecutable, err := exec.LookPath("mount")
301+
if err != nil {
302+
return fmt.Errorf("error looking up mpunt executable: %w", err)
303+
}
304+
mounter := mount.New(mountExecutable)
305+
306+
// Build a recursive helper function to unmount depth-first.
307+
var helper func(path string) error
308+
helper = func(path string) error {
309+
// Read the directory contents of path.
310+
entries, err := os.ReadDir(path)
311+
if err != nil {
312+
return fmt.Errorf("failed to read directory %s: %w", path, err)
313+
}
314+
315+
// Process each entry, recursively.
316+
for _, entry := range entries {
317+
subPath := filepath.Join(path, entry.Name())
318+
if entry.IsDir() {
319+
if err := helper(subPath); err != nil {
320+
return err
321+
}
322+
}
323+
}
324+
325+
// After processing all children, unmount the current directory if it's a mount point.
326+
mounted, err := mounter.IsMountPoint(path)
327+
if err != nil {
328+
return fmt.Errorf("failed to check mount point %s: %w", path, err)
329+
}
330+
if mounted {
331+
if err := mounter.Unmount(path); err != nil {
332+
return fmt.Errorf("failed to unmount %s: %w", path, err)
333+
}
334+
}
335+
336+
return nil
337+
}
338+
339+
return helper(root)
340+
}

0 commit comments

Comments
 (0)