Skip to content

Commit a88b676

Browse files
feat: health check robustness and auto-recovery
Add buffered channels (64), non-blocking writes, graceful shutdown, stats collection, and automatic device recovery detection (30s). Signed-off-by: Carlos Eduardo Arango Gutierrez <[email protected]>
1 parent 5736d9b commit a88b676

File tree

22 files changed

+32259
-3
lines changed

22 files changed

+32259
-3
lines changed

internal/plugin/server.go

Lines changed: 99 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,10 @@ type nvidiaDevicePlugin struct {
7171
health chan *rm.Device
7272
stop chan interface{}
7373

74+
// deviceListUpdate is used to trigger ListAndWatch to send updated device
75+
// list to kubelet (e.g., when devices recover from unhealthy state)
76+
deviceListUpdate chan struct{}
77+
7478
imexChannels imex.Channels
7579

7680
mps mpsOptions
@@ -117,13 +121,18 @@ func (plugin *nvidiaDevicePlugin) initialize() {
117121
plugin.server = grpc.NewServer([]grpc.ServerOption{}...)
118122
plugin.health = make(chan *rm.Device, healthChannelBufferSize)
119123
plugin.stop = make(chan interface{})
124+
plugin.deviceListUpdate = make(chan struct{}, 1)
120125
}
121126

122127
func (plugin *nvidiaDevicePlugin) cleanup() {
123128
close(plugin.stop)
129+
if plugin.deviceListUpdate != nil {
130+
close(plugin.deviceListUpdate)
131+
}
124132
plugin.server = nil
125133
plugin.health = nil
126134
plugin.stop = nil
135+
plugin.deviceListUpdate = nil
127136
}
128137

129138
// Devices returns the full set of devices associated with the plugin.
@@ -163,6 +172,9 @@ func (plugin *nvidiaDevicePlugin) Start(kubeletSocket string) error {
163172
}
164173
}()
165174

175+
// Start recovery worker to detect when unhealthy devices become healthy
176+
go plugin.runRecoveryWorker()
177+
166178
return nil
167179
}
168180

@@ -270,7 +282,9 @@ func (plugin *nvidiaDevicePlugin) GetDevicePluginOptions(context.Context, *plugi
270282
return options, nil
271283
}
272284

273-
// ListAndWatch lists devices and update that list according to the health status
285+
// ListAndWatch lists devices and updates that list according to the health
286+
// status. This now supports device recovery: when devices that were marked
287+
// unhealthy recover, they are automatically re-advertised to kubelet.
274288
func (plugin *nvidiaDevicePlugin) ListAndWatch(e *pluginapi.Empty, s pluginapi.DevicePlugin_ListAndWatchServer) error {
275289
if err := s.Send(&pluginapi.ListAndWatchResponse{Devices: plugin.apiDevices()}); err != nil {
276290
return err
@@ -281,9 +295,17 @@ func (plugin *nvidiaDevicePlugin) ListAndWatch(e *pluginapi.Empty, s pluginapi.D
281295
case <-plugin.stop:
282296
return nil
283297
case d := <-plugin.health:
284-
// FIXME: there is no way to recover from the Unhealthy state.
298+
// Device marked unhealthy by health check
285299
d.Health = pluginapi.Unhealthy
286-
klog.Infof("'%s' device marked unhealthy: %s", plugin.rm.Resource(), d.ID)
300+
klog.Infof("'%s' device marked unhealthy: %s (reason: %s)",
301+
plugin.rm.Resource(), d.ID, d.UnhealthyReason)
302+
if err := s.Send(&pluginapi.ListAndWatchResponse{Devices: plugin.apiDevices()}); err != nil {
303+
return nil
304+
}
305+
case <-plugin.deviceListUpdate:
306+
// Device recovery or other device list change
307+
klog.Infof("'%s' device list updated, notifying kubelet",
308+
plugin.rm.Resource())
287309
if err := s.Send(&pluginapi.ListAndWatchResponse{Devices: plugin.apiDevices()}); err != nil {
288310
return nil
289311
}
@@ -519,6 +541,80 @@ func (plugin *nvidiaDevicePlugin) updateResponseForDeviceMounts(response *plugin
519541
}
520542
}
521543

544+
// runRecoveryWorker periodically checks if unhealthy devices have recovered
545+
// and notifies kubelet when they do.
546+
func (plugin *nvidiaDevicePlugin) runRecoveryWorker() {
547+
const recoveryInterval = 30 * time.Second
548+
549+
ticker := time.NewTicker(recoveryInterval)
550+
defer ticker.Stop()
551+
552+
klog.V(2).Infof("Recovery worker started for '%s' (interval=%v)",
553+
plugin.rm.Resource(), recoveryInterval)
554+
555+
for {
556+
select {
557+
case <-plugin.stop:
558+
klog.V(2).Info("Recovery worker stopped")
559+
return
560+
case <-ticker.C:
561+
plugin.checkForRecoveredDevices()
562+
}
563+
}
564+
}
565+
566+
// checkForRecoveredDevices checks all unhealthy devices to see if they have
567+
// recovered. If any have recovered, triggers a device list update to
568+
// kubelet.
569+
func (plugin *nvidiaDevicePlugin) checkForRecoveredDevices() {
570+
recoveredDevices := []*rm.Device{}
571+
572+
for _, d := range plugin.rm.Devices() {
573+
if !d.IsUnhealthy() {
574+
continue
575+
}
576+
577+
// Increment recovery attempts
578+
d.RecoveryAttempts++
579+
580+
// Check if device has recovered
581+
healthy, err := plugin.rm.CheckDeviceHealth(d)
582+
if err != nil {
583+
klog.V(4).Infof("Device %s recovery check failed (attempt %d): %v",
584+
d.ID, d.RecoveryAttempts, err)
585+
continue
586+
}
587+
588+
if healthy {
589+
klog.Infof("Device %s has RECOVERED! Was unhealthy for %v (reason: %s)",
590+
d.ID, d.UnhealthyDuration(), d.UnhealthyReason)
591+
d.MarkHealthy()
592+
recoveredDevices = append(recoveredDevices, d)
593+
} else {
594+
klog.V(3).Infof("Device %s still unhealthy (attempt %d, duration %v)",
595+
d.ID, d.RecoveryAttempts, d.UnhealthyDuration())
596+
}
597+
}
598+
599+
// If any devices recovered, notify ListAndWatch
600+
if len(recoveredDevices) > 0 {
601+
klog.Infof("Total recovered devices: %d", len(recoveredDevices))
602+
plugin.triggerDeviceListUpdate()
603+
}
604+
}
605+
606+
// triggerDeviceListUpdate sends a signal to ListAndWatch to send an updated
607+
// device list to kubelet. Uses a buffered channel with non-blocking send to
608+
// avoid blocking the recovery worker.
609+
func (plugin *nvidiaDevicePlugin) triggerDeviceListUpdate() {
610+
select {
611+
case plugin.deviceListUpdate <- struct{}{}:
612+
klog.V(3).Info("Device list update triggered")
613+
default:
614+
klog.V(4).Info("Device list update already pending, skipping")
615+
}
616+
}
617+
522618
func (plugin *nvidiaDevicePlugin) apiDeviceSpecs(devRoot string, ids []string) []*pluginapi.DeviceSpec {
523619
optional := map[string]bool{
524620
"/dev/nvidiactl": true,

internal/plugin/server_test.go

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@ package plugin
1818

1919
import (
2020
"context"
21+
"fmt"
2122
"testing"
23+
"time"
2224

2325
"github.com/stretchr/testify/require"
2426
pluginapi "k8s.io/kubelet/pkg/apis/deviceplugin/v1beta1"
@@ -254,3 +256,96 @@ func TestCDIAllocateResponse(t *testing.T) {
254256
func ptr[T any](x T) *T {
255257
return &x
256258
}
259+
260+
func TestTriggerDeviceListUpdate_Phase2(t *testing.T) {
261+
plugin := &nvidiaDevicePlugin{
262+
deviceListUpdate: make(chan struct{}, 1),
263+
}
264+
265+
// First trigger should send signal
266+
plugin.triggerDeviceListUpdate()
267+
select {
268+
case <-plugin.deviceListUpdate:
269+
t.Log("✓ Device list update signal sent")
270+
case <-time.After(100 * time.Millisecond):
271+
t.Fatal("Signal not sent")
272+
}
273+
274+
// Second trigger with pending signal should not block
275+
plugin.triggerDeviceListUpdate()
276+
plugin.triggerDeviceListUpdate() // Should not block
277+
t.Log("✓ triggerDeviceListUpdate doesn't block when signal pending")
278+
}
279+
280+
func TestCheckForRecoveredDevices_Phase2(t *testing.T) {
281+
// Create persistent device map
282+
devices := rm.Devices{
283+
"GPU-0": &rm.Device{
284+
Device: pluginapi.Device{
285+
ID: "GPU-0",
286+
Health: pluginapi.Unhealthy,
287+
},
288+
UnhealthyReason: "XID-79",
289+
},
290+
"GPU-1": &rm.Device{
291+
Device: pluginapi.Device{
292+
ID: "GPU-1",
293+
Health: pluginapi.Unhealthy,
294+
},
295+
UnhealthyReason: "XID-48",
296+
},
297+
"GPU-2": &rm.Device{
298+
Device: pluginapi.Device{
299+
ID: "GPU-2",
300+
Health: pluginapi.Healthy,
301+
},
302+
},
303+
}
304+
305+
// Create mock resource manager with persistent devices
306+
mockRM := &rm.ResourceManagerMock{
307+
DevicesFunc: func() rm.Devices {
308+
return devices
309+
},
310+
CheckDeviceHealthFunc: func(d *rm.Device) (bool, error) {
311+
// GPU-0 recovers, GPU-1 stays unhealthy
312+
if d.ID == "GPU-0" {
313+
return true, nil
314+
}
315+
return false, fmt.Errorf("still unhealthy")
316+
},
317+
}
318+
319+
plugin := &nvidiaDevicePlugin{
320+
rm: mockRM,
321+
deviceListUpdate: make(chan struct{}, 1),
322+
}
323+
324+
plugin.checkForRecoveredDevices()
325+
326+
// Verify GPU-0 recovered
327+
gpu0 := devices["GPU-0"]
328+
require.Equal(t, pluginapi.Healthy, gpu0.Health, "GPU-0 should be healthy")
329+
require.Equal(t, "", gpu0.UnhealthyReason)
330+
t.Logf("✓ GPU-0 recovered: Health=%s, Reason=%s", gpu0.Health, gpu0.UnhealthyReason)
331+
332+
// Verify GPU-1 still unhealthy
333+
gpu1 := devices["GPU-1"]
334+
require.Equal(t, pluginapi.Unhealthy, gpu1.Health, "GPU-1 should still be unhealthy")
335+
require.Equal(t, 1, gpu1.RecoveryAttempts, "GPU-1 recovery attempts should increment")
336+
t.Logf("✓ GPU-1 still unhealthy: attempts=%d", gpu1.RecoveryAttempts)
337+
338+
// Verify GPU-2 unchanged
339+
gpu2 := devices["GPU-2"]
340+
require.Equal(t, pluginapi.Healthy, gpu2.Health)
341+
require.Equal(t, 0, gpu2.RecoveryAttempts, "Healthy device shouldn't be probed")
342+
t.Log("✓ GPU-2 unchanged (was already healthy)")
343+
344+
// Verify deviceListUpdate was triggered
345+
select {
346+
case <-plugin.deviceListUpdate:
347+
t.Log("✓ Device list update triggered for recovery")
348+
case <-time.After(100 * time.Millisecond):
349+
t.Fatal("Device list update not triggered")
350+
}
351+
}

internal/rm/devices.go

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import (
2020
"fmt"
2121
"strconv"
2222
"strings"
23+
"time"
2324

2425
"k8s.io/klog/v2"
2526
pluginapi "k8s.io/kubelet/pkg/apis/deviceplugin/v1beta1"
@@ -35,6 +36,12 @@ type Device struct {
3536
// Replicas stores the total number of times this device is replicated.
3637
// If this is 0 or 1 then the device is not shared.
3738
Replicas int
39+
40+
// Health tracking fields for recovery detection
41+
LastHealthyTime time.Time // Last time device was confirmed healthy
42+
LastUnhealthyTime time.Time // When device became unhealthy
43+
UnhealthyReason string // Human-readable reason (e.g., "XID-79")
44+
RecoveryAttempts int // Number of recovery probes attempted
3845
}
3946

4047
// deviceInfo defines the information required to construct a Device
@@ -238,6 +245,40 @@ func (d Device) GetUUID() string {
238245
return AnnotatedID(d.ID).GetID()
239246
}
240247

248+
// MarkUnhealthy marks the device as unhealthy and records the reason and
249+
// timestamp. This should be called when a health check detects a device
250+
// failure (e.g., XID error).
251+
func (d *Device) MarkUnhealthy(reason string) {
252+
d.Health = pluginapi.Unhealthy
253+
d.LastUnhealthyTime = time.Now()
254+
d.UnhealthyReason = reason
255+
d.RecoveryAttempts = 0
256+
}
257+
258+
// MarkHealthy marks the device as healthy and clears unhealthy state. This
259+
// should be called when recovery detection confirms the device is working
260+
// again.
261+
func (d *Device) MarkHealthy() {
262+
d.Health = pluginapi.Healthy
263+
d.LastHealthyTime = time.Now()
264+
d.UnhealthyReason = ""
265+
d.RecoveryAttempts = 0
266+
}
267+
268+
// IsUnhealthy returns true if the device is currently marked as unhealthy.
269+
func (d *Device) IsUnhealthy() bool {
270+
return d.Health == pluginapi.Unhealthy
271+
}
272+
273+
// UnhealthyDuration returns how long the device has been unhealthy. Returns
274+
// zero duration if the device is healthy.
275+
func (d *Device) UnhealthyDuration() time.Duration {
276+
if !d.IsUnhealthy() {
277+
return 0
278+
}
279+
return time.Since(d.LastUnhealthyTime)
280+
}
281+
241282
// NewAnnotatedID creates a new AnnotatedID from an ID and a replica number.
242283
func NewAnnotatedID(id string, replica int) AnnotatedID {
243284
return AnnotatedID(fmt.Sprintf("%s::%d", id, replica))

internal/rm/health.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -370,6 +370,7 @@ func (r *nvmlResourceManager) checkHealth(stop <-chan interface{}, devices Devic
370370

371371
klog.Infof("XidCriticalError: Xid=%d on Device=%s; marking device as unhealthy.", e.EventData, d.ID)
372372
stats.recordUnhealthy()
373+
d.MarkUnhealthy(fmt.Sprintf("XID-%d", e.EventData))
373374
sendUnhealthyDevice(unhealthy, d)
374375
}
375376
}

0 commit comments

Comments
 (0)