Skip to content

Commit 68adf6a

Browse files
Add shouldSkipUninstall to avoid GPU driver teardown on restart
Signed-off-by: Karthik Vetrivel <[email protected]>
1 parent d349b73 commit 68adf6a

File tree

1 file changed

+83
-0
lines changed

1 file changed

+83
-0
lines changed

cmd/driver-manager/main.go

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,8 @@ type config struct {
7878
gpuDirectRDMAEnabled bool
7979
useHostMofed bool
8080
kubeconfig string
81+
driverVersion string
82+
forceReinstall bool
8183
}
8284

8385
// ComponentState tracks the deployment state of GPU operator components
@@ -209,6 +211,20 @@ func main() {
209211
EnvVars: []string{"KUBECONFIG"},
210212
Value: "",
211213
},
214+
&cli.StringFlag{
215+
Name: "driver-version",
216+
Usage: "Desired NVIDIA driver version",
217+
Destination: &cfg.driverVersion,
218+
EnvVars: []string{"DRIVER_VERSION"},
219+
Value: "",
220+
},
221+
&cli.BoolFlag{
222+
Name: "force-reinstall",
223+
Usage: "Force driver reinstall regardless of current state",
224+
Destination: &cfg.forceReinstall,
225+
EnvVars: []string{"FORCE_REINSTALL"},
226+
Value: false,
227+
},
212228
}
213229

214230
app.Commands = []*cli.Command{
@@ -272,6 +288,11 @@ func (dm *DriverManager) uninstallDriver() error {
272288
return fmt.Errorf("driver is pre-installed on host")
273289
}
274290

291+
if skip, reason := dm.shouldSkipUninstall(); skip {
292+
dm.log.Infof("Skipping driver uninstall: %s", reason)
293+
return nil
294+
}
295+
275296
// Fetch current component states
276297
if err := dm.fetchCurrentLabels(); err != nil {
277298
return fmt.Errorf("failed to fetch current labels: %w", err)
@@ -623,6 +644,68 @@ func (dm *DriverManager) isDriverLoaded() bool {
623644
return err == nil
624645
}
625646

647+
func (dm *DriverManager) shouldSkipUninstall() (bool, string) {
648+
if dm.config.forceReinstall {
649+
dm.log.Info("Force reinstall is enabled, skipping driver uninstall")
650+
return false, ""
651+
}
652+
653+
if !dm.isDriverLoaded() {
654+
return true, "no NVIDIA driver modules detected"
655+
}
656+
657+
if dm.config.driverVersion == "" {
658+
return false, ""
659+
}
660+
661+
version, err := dm.detectCurrentDriverVersion()
662+
if err != nil {
663+
dm.log.Warnf("Unable to determine installed driver version: %v", err)
664+
// If driver is loaded but we can't detect version, skip uninstall to avoid disruption
665+
return true, "driver is loaded but version cannot be determined - skipping to avoid disruption"
666+
}
667+
668+
if version == dm.config.driverVersion {
669+
return true, "desired version already present"
670+
}
671+
672+
dm.log.Infof("Installed driver version %s does not match desired %s", version, dm.config.driverVersion)
673+
return false, ""
674+
}
675+
676+
func (dm *DriverManager) detectCurrentDriverVersion() (string, error) {
677+
baseCtx := dm.ctx
678+
if baseCtx == nil {
679+
baseCtx = context.Background()
680+
}
681+
682+
ctx, cancel := context.WithTimeout(baseCtx, 10*time.Second)
683+
defer cancel()
684+
685+
// Try chroot to /run/nvidia/driver for containerized driver
686+
cmd := exec.CommandContext(ctx, "chroot", "/run/nvidia/driver", "modinfo", "-F", "version", "nvidia")
687+
cmd.Env = append(os.Environ(), "LC_ALL=C")
688+
cmdOutput, chrootErr := cmd.Output()
689+
if chrootErr == nil {
690+
version := strings.TrimSpace(string(cmdOutput))
691+
if version != "" {
692+
dm.log.Infof("Driver version detected via chroot: %s", version)
693+
return version, nil
694+
}
695+
}
696+
697+
// Second try to read from /sys/module/nvidia/version if available
698+
if versionData, err := os.ReadFile("/sys/module/nvidia/version"); err == nil {
699+
version := strings.TrimSpace(string(versionData))
700+
if version != "" {
701+
dm.log.Infof("Driver version detected from /sys/module/nvidia/version: %s", version)
702+
return version, nil
703+
}
704+
}
705+
706+
return "", fmt.Errorf("all version detection methods failed: chroot: %v", chrootErr)
707+
}
708+
626709
func (dm *DriverManager) isNouveauLoaded() bool {
627710
_, err := os.Stat("/sys/module/nouveau/refcnt")
628711
return err == nil

0 commit comments

Comments
 (0)