Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
243 changes: 240 additions & 3 deletions cmd/dra-example-kubeletplugin/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,30 +21,55 @@ import (
"errors"
"fmt"
"maps"
"sync"
"time"

"google.golang.org/grpc"
resourceapi "k8s.io/api/resource/v1"
"k8s.io/apimachinery/pkg/types"
utilruntime "k8s.io/apimachinery/pkg/util/runtime"
coreclientset "k8s.io/client-go/kubernetes"
"k8s.io/dynamic-resource-allocation/kubeletplugin"
"k8s.io/dynamic-resource-allocation/resourceslice"
"k8s.io/klog/v2"
drahealthv1alpha1 "k8s.io/kubelet/pkg/apis/dra-health/v1alpha1"

"sigs.k8s.io/dra-example-driver/pkg/consts"
)

// DeviceHealthStatus is the driver's in-memory record of the most recently
// observed health of a single device. The driver stores one entry per device
// and copies both fields into the DeviceHealth messages it streams to clients.
type DeviceHealthStatus struct {
// Health is the device's health state, expressed with the DRA health API enum.
Health drahealthv1alpha1.HealthStatus
// Message is the human-readable text reported alongside Health.
Message string
}

// driver holds the state of the example DRA kubelet plugin, including the
// optional device health monitoring that backs the NodeWatchResources stream.
type driver struct {
// NOTE(review): presumably supplies default implementations for any
// DRAResourceHealth RPCs the driver does not define itself — confirm against
// the generated gRPC code.
drahealthv1alpha1.UnimplementedDRAResourceHealthServer

// client is the Kubernetes client taken from config.coreclient.
client coreclientset.Interface
// helper is stopped last in Shutdown; presumably the value returned by
// kubeletplugin.Start in NewDriver (assignment not visible here).
helper *kubeletplugin.Helper
// state is built by NewDeviceState; its allocatable device names seed deviceHealth.
state *DeviceState
// healthcheck is stopped by Shutdown when it is non-nil.
healthcheck *healthcheck
// cancelCtx comes from config.cancelMainCtx and is invoked by HandleError
// with the wrapped fatal background error.
cancelCtx func(error)

// Health monitoring
// config is retained so the health code can read flags such as nodeName and
// enableHealthReporting.
config *Config
// simulator supplies the health values polled by performHealthCheck; it is
// only created when health reporting is enabled.
simulator *HealthSimulator
// healthMu guards deviceHealth.
healthMu sync.RWMutex
deviceHealth map[string]map[string]*DeviceHealthStatus // poolName -> deviceName -> health
// healthClients holds one buffered channel per active NodeWatchResources stream.
healthClients []chan *drahealthv1alpha1.NodeWatchResourcesResponse
// clientsMu guards healthClients.
clientsMu sync.RWMutex
// stopHealthCh is closed by Shutdown to stop healthMonitoringLoop.
stopHealthCh chan struct{}
// healthWg tracks the healthMonitoringLoop goroutine so Shutdown can wait for it.
healthWg sync.WaitGroup
}

func NewDriver(ctx context.Context, config *Config) (*driver, error) {
driver := &driver{
client: config.coreclient,
cancelCtx: config.cancelMainCtx,
client: config.coreclient,
cancelCtx: config.cancelMainCtx,
config: config,
deviceHealth: make(map[string]map[string]*DeviceHealthStatus),
stopHealthCh: make(chan struct{}),
}

state, err := NewDeviceState(config)
Expand All @@ -53,6 +78,22 @@ func NewDriver(ctx context.Context, config *Config) (*driver, error) {
}
driver.state = state

// Initialize health monitoring if enabled
if config.flags.enableHealthReporting {
poolName := config.flags.nodeName
driver.deviceHealth[poolName] = make(map[string]*DeviceHealthStatus)

for deviceName := range state.allocatable {
driver.deviceHealth[poolName][deviceName] = &DeviceHealthStatus{
Health: drahealthv1alpha1.HealthStatus_HEALTHY,
Message: fmt.Sprintf("Device %s initialized successfully", deviceName),
}
}

driver.simulator = NewHealthSimulator(config.flags.numDevices)
klog.Infof("Device health reporting enabled for %d devices", config.flags.numDevices)
}

helper, err := kubeletplugin.Start(
ctx,
driver,
Expand Down Expand Up @@ -92,13 +133,36 @@ func NewDriver(ctx context.Context, config *Config) (*driver, error) {
return nil, err
}

// Start health monitoring loop if enabled
if config.flags.enableHealthReporting {
driver.healthWg.Add(1)
go driver.healthMonitoringLoop(ctx)
}

return driver, nil
}

// Shutdown stops the driver's background work and then the kubelet plugin
// helper. It always returns nil.
//
// When health reporting is enabled it signals the monitoring loop to stop,
// closes every registered streaming-client channel so the corresponding
// NodeWatchResources handlers return, and waits for the monitoring goroutine
// to exit before stopping the helper.
//
// The health-monitoring teardown tolerates being run more than once: the stop
// channel is closed only if it is still open, and the client list is emptied
// after its channels are closed.
func (d *driver) Shutdown(logger klog.Logger) error {
	if d.healthcheck != nil {
		d.healthcheck.Stop(logger)
	}

	// Stop health monitoring
	if d.config.flags.enableHealthReporting {
		logger.Info("Stopping device health monitoring")

		// The stop signal and the client channels are handled under the same
		// lock so that a repeated or concurrent Shutdown cannot close either
		// of them twice (closing a closed channel panics).
		d.clientsMu.Lock()
		select {
		case <-d.stopHealthCh:
			// Already closed by an earlier Shutdown call.
		default:
			close(d.stopHealthCh)
		}

		// Close all client channels; they are dropped from the registry in the
		// same critical section so nothing can send on them afterwards.
		for _, clientCh := range d.healthClients {
			close(clientCh)
		}
		d.healthClients = nil
		d.clientsMu.Unlock()

		// Wait outside the lock: the monitoring loop may still need
		// clientsMu to finish an in-flight notification.
		d.healthWg.Wait()
	}

	d.helper.Stop()
	return nil
}
Expand Down Expand Up @@ -127,7 +191,7 @@ func (d *driver) prepareResourceClaim(_ context.Context, claim *resourceapi.Reso
Requests: preparedPB.GetRequestNames(),
PoolName: preparedPB.GetPoolName(),
DeviceName: preparedPB.GetDeviceName(),
CDIDeviceIDs: preparedPB.GetCDIDeviceIDs(),
CDIDeviceIDs: preparedPB.GetCdiDeviceIds(),
})
}

Expand Down Expand Up @@ -160,3 +224,176 @@ func (d *driver) HandleError(ctx context.Context, err error, msg string) {
d.cancelCtx(fmt.Errorf("fatal background error: %w", err))
}
}

// Health monitoring methods

// healthMonitoringLoop runs one health check immediately and then one more on
// every tick of a 30-second ticker. It returns when stopHealthCh is closed or
// ctx is cancelled, marking healthWg done on the way out.
func (d *driver) healthMonitoringLoop(ctx context.Context) {
	defer d.healthWg.Done()

	// Interval between two consecutive health checks.
	const checkInterval = 30 * time.Second

	log := klog.FromContext(ctx)
	tick := time.NewTicker(checkInterval)
	defer tick.Stop()

	log.Info("Starting device health monitoring loop")

	// Do not wait a full interval for the first result.
	d.performHealthCheck(log)

	for {
		select {
		case <-tick.C:
			d.performHealthCheck(log)
		case <-ctx.Done():
			log.Info("Context cancelled, stopping health monitoring")
			return
		case <-d.stopHealthCh:
			log.Info("Health monitoring loop stopped")
			return
		}
	}
}

// performHealthCheck polls the health simulator for every device in this
// node's pool, stores any health or message that differs from the recorded
// one, logs each change, and notifies streaming clients if at least one
// device changed. healthMu is held for writing for the whole call.
func (d *driver) performHealthCheck(logger klog.Logger) {
	d.healthMu.Lock()
	defer d.healthMu.Unlock()

	changed := false
	for name, status := range d.deviceHealth[d.config.flags.nodeName] {
		// Ask the simulator what this device's health currently is.
		health, msg := d.simulator.GetDeviceHealth(name)

		// Nothing to record when both the state and the message are unchanged.
		if status.Health == health && status.Message == msg {
			continue
		}

		status.Health, status.Message = health, msg
		changed = true
		logger.Info("Device health changed",
			"device", name,
			"health", health.String(),
			"message", msg)
	}

	// Streaming clients only hear about actual changes.
	if !changed {
		return
	}
	d.notifyClients()
}

// notifyClients builds a snapshot of the health of every device in this
// node's pool and offers it to each connected streaming client.
//
// The caller must hold healthMu: deviceHealth is read here without taking the
// lock (performHealthCheck calls this while holding it for writing).
//
// Delivery is best-effort and never blocks: a client whose buffered channel is
// full does not receive this snapshot.
func (d *driver) notifyClients() {
	poolName := d.config.flags.nodeName
	deviceHealthMap := d.deviceHealth[poolName]

	// Take the timestamp once so every device in the snapshot carries the
	// same LastUpdatedTime.
	now := time.Now().Unix()

	// Build the response with all current device health statuses
	devices := make([]*drahealthv1alpha1.DeviceHealth, 0, len(deviceHealthMap))
	for deviceName, health := range deviceHealthMap {
		devices = append(devices, &drahealthv1alpha1.DeviceHealth{
			Device: &drahealthv1alpha1.DeviceIdentifier{
				PoolName:   poolName,
				DeviceName: deviceName,
			},
			Health:                    health.Health,
			LastUpdatedTime:           now,
			Message:                   health.Message,
			HealthCheckTimeoutSeconds: 60, // 60 second timeout
		})
	}

	response := &drahealthv1alpha1.NodeWatchResourcesResponse{
		Devices: devices,
	}

	// Send to all connected clients. Holding clientsMu keeps the registry
	// stable while we iterate; channels are removed from it under the write
	// lock before they are closed, so no channel reachable here is closed.
	d.clientsMu.RLock()
	defer d.clientsMu.RUnlock()

	for _, clientCh := range d.healthClients {
		select {
		case clientCh <- response:
			// Successfully queued
		default:
			// Client buffer is full; skip rather than stall the monitor.
			// (A send on a closed channel would panic, not land here.)
		}
	}
}

// NodeWatchResources implements the server-streaming RPC for health updates.
//
// It registers a buffered channel for the caller, sends the current health of
// every device in this node's pool right away, and then forwards each snapshot
// queued on the channel until the stream ends.
//
// It returns the stream context's error when the client goes away, nil when
// the driver shuts down (the channel is closed or stopHealthCh is closed), and
// a wrapped error when a Send fails.
func (d *driver) NodeWatchResources(
	req *drahealthv1alpha1.NodeWatchResourcesRequest,
	stream grpc.ServerStreamingServer[drahealthv1alpha1.NodeWatchResourcesResponse],
) error {
	logger := klog.FromContext(stream.Context())
	logger.Info("New health monitoring client connected")

	// Create a channel for this client
	clientCh := make(chan *drahealthv1alpha1.NodeWatchResourcesResponse, 10)

	// Register the client
	d.clientsMu.Lock()
	d.healthClients = append(d.healthClients, clientCh)
	d.clientsMu.Unlock()

	// Cleanup on exit
	defer func() {
		d.clientsMu.Lock()
		stillRegistered := false
		for i, ch := range d.healthClients {
			if ch == clientCh {
				d.healthClients = append(d.healthClients[:i], d.healthClients[i+1:]...)
				stillRegistered = true
				break
			}
		}
		d.clientsMu.Unlock()

		// Shutdown closes every registered channel and empties the registry.
		// If the channel is no longer registered it has already been closed,
		// and closing it again would panic.
		if stillRegistered {
			close(clientCh)
		}
		logger.Info("Health monitoring client disconnected")
	}()

	// Send initial health status immediately
	d.healthMu.RLock()
	poolName := d.config.flags.nodeName
	deviceHealthMap := d.deviceHealth[poolName]

	// One timestamp for the whole snapshot.
	now := time.Now().Unix()
	initialDevices := make([]*drahealthv1alpha1.DeviceHealth, 0, len(deviceHealthMap))
	for deviceName, health := range deviceHealthMap {
		initialDevices = append(initialDevices, &drahealthv1alpha1.DeviceHealth{
			Device: &drahealthv1alpha1.DeviceIdentifier{
				PoolName:   poolName,
				DeviceName: deviceName,
			},
			Health:                    health.Health,
			LastUpdatedTime:           now,
			Message:                   health.Message,
			HealthCheckTimeoutSeconds: 60,
		})
	}
	d.healthMu.RUnlock()

	initialResponse := &drahealthv1alpha1.NodeWatchResourcesResponse{
		Devices: initialDevices,
	}

	if err := stream.Send(initialResponse); err != nil {
		return fmt.Errorf("failed to send initial health status: %w", err)
	}

	// Stream updates
	for {
		select {
		case <-stream.Context().Done():
			return stream.Context().Err()
		case <-d.stopHealthCh:
			// Driver is shutting down. This also covers a stream that
			// registered after Shutdown had already emptied the registry and
			// whose channel will therefore never be closed.
			return nil
		case response, ok := <-clientCh:
			if !ok {
				// Channel closed by Shutdown, exit
				return nil
			}
			if err := stream.Send(response); err != nil {
				return fmt.Errorf("failed to send health update: %w", err)
			}
		}
	}
}
Loading