From 7ef316435cff5981801e31781058a6c56ec5e712 Mon Sep 17 00:00:00 2001 From: Tommy Lam Date: Tue, 7 Oct 2025 09:08:44 -0700 Subject: [PATCH 1/8] feat: implement deployment strategies with compartment-based batching --- chart/templates/skyhook-crd.yaml | 34 +++ .../api/v1alpha1/deployment_policy_types.go | 155 +++++++++++ operator/api/v1alpha1/skyhook_types.go | 3 + .../api/v1alpha1/zz_generated.deepcopy.go | 27 ++ .../bases/skyhook.nvidia.com_skyhooks.yaml | 34 +++ .../internal/controller/cluster_state_v2.go | 102 +++++++- operator/internal/wrapper/compartment.go | 181 ++++++++++++- operator/internal/wrapper/compartment_test.go | 240 ++++++++++++++++++ 8 files changed, 768 insertions(+), 8 deletions(-) create mode 100644 operator/internal/wrapper/compartment_test.go diff --git a/chart/templates/skyhook-crd.yaml b/chart/templates/skyhook-crd.yaml index 6147596c..776f1f4f 100644 --- a/chart/templates/skyhook-crd.yaml +++ b/chart/templates/skyhook-crd.yaml @@ -498,6 +498,40 @@ spec: status: description: SkyhookStatus defines the observed state of Skyhook properties: + compartmentBatchStates: + additionalProperties: + description: BatchProcessingState tracks the current state of batch + processing for a compartment + properties: + consecutiveFailures: + description: Number of consecutive failures + type: integer + currentBatch: + description: Current batch number (starts at 1) + type: integer + currentBatchNodes: + description: Names of nodes in the current batch (persisted + across reconciles) + items: + type: string + type: array + failedInBatch: + description: Number of failed nodes in current batch + type: integer + processedNodes: + description: Total number of nodes processed so far + type: integer + shouldStop: + description: Whether the strategy should stop processing due + to failures + type: boolean + successfulInBatch: + description: Number of successful nodes in current batch + type: integer + type: object + description: CompartmentBatchStates tracks batch processing state + per compartment + type: object completeNodes: default: 0/0 description: |- diff --git a/operator/api/v1alpha1/deployment_policy_types.go b/operator/api/v1alpha1/deployment_policy_types.go index a2bb8f5d..756fefc9 100644 --- a/operator/api/v1alpha1/deployment_policy_types.go +++ b/operator/api/v1alpha1/deployment_policy_types.go @@ -232,6 +232,161 @@ func (s *DeploymentStrategy) Validate() error { return nil } +// BatchProcessingState tracks the current state of batch processing for a compartment +type BatchProcessingState struct { + // Current batch number (starts at 1) + CurrentBatch int `json:"currentBatch,omitempty"` + // Number of consecutive failures + ConsecutiveFailures int `json:"consecutiveFailures,omitempty"` + // Total number of nodes processed so far + ProcessedNodes int `json:"processedNodes,omitempty"` + // Number of successful nodes in current batch + SuccessfulInBatch int `json:"successfulInBatch,omitempty"` + // Number of failed nodes in current batch + FailedInBatch int `json:"failedInBatch,omitempty"` + // Whether the strategy should stop processing due to failures + ShouldStop bool `json:"shouldStop,omitempty"` + // Names of nodes in the current batch (persisted across reconciles) + CurrentBatchNodes []string `json:"currentBatchNodes,omitempty"` + // Last successful batch size (for slowdown calculations) + LastBatchSize int `json:"lastBatchSize,omitempty"` + // Whether the last batch failed (for slowdown logic) + LastBatchFailed bool `json:"lastBatchFailed,omitempty"` +} + +// CalculateBatchSize calculates the next batch size based on the strategy +func (s *DeploymentStrategy) CalculateBatchSize(totalNodes int, state *BatchProcessingState) int { + switch { + case s.Fixed != nil: + return s.Fixed.CalculateBatchSize(totalNodes, state) + case s.Linear != nil: + return s.Linear.CalculateBatchSize(totalNodes, state) + case s.Exponential != nil: + return s.Exponential.CalculateBatchSize(totalNodes, state) + default: + return 1 // fallback + } +} + +// EvaluateBatchResult evaluates the result of a batch and updates state +func (s *DeploymentStrategy) EvaluateBatchResult(state *BatchProcessingState, batchSize int, successCount int, failureCount int, totalNodes int) { + state.SuccessfulInBatch = successCount + state.FailedInBatch = failureCount + state.ProcessedNodes += batchSize + + successPercentage := (successCount * 100) / batchSize + progressPercent := (state.ProcessedNodes * 100) / totalNodes + + if successPercentage >= s.getBatchThreshold() { + state.ConsecutiveFailures = 0 + state.LastBatchFailed = false + } else { + state.ConsecutiveFailures++ + state.LastBatchFailed = true + if progressPercent < s.getSafetyLimit() && state.ConsecutiveFailures >= s.getFailureThreshold() { + state.ShouldStop = true + } + } + + state.LastBatchSize = batchSize + state.CurrentBatch++ +} + +// getBatchThreshold returns the batch threshold from the active strategy +func (s *DeploymentStrategy) getBatchThreshold() int { + switch { + case s.Fixed != nil: + return *s.Fixed.BatchThreshold + case s.Linear != nil: + return *s.Linear.BatchThreshold + case s.Exponential != nil: + return *s.Exponential.BatchThreshold + default: + return 100 + } +} + +// getSafetyLimit returns the safety limit from the active strategy +func (s *DeploymentStrategy) getSafetyLimit() int { + switch { + case s.Fixed != nil: + return *s.Fixed.SafetyLimit + case s.Linear != nil: + return *s.Linear.SafetyLimit + case s.Exponential != nil: + return *s.Exponential.SafetyLimit + default: + return 50 + } +} + +// getFailureThreshold returns the failure threshold from the active strategy +func (s *DeploymentStrategy) getFailureThreshold() int { + switch { + case s.Fixed != nil: + return *s.Fixed.FailureThreshold + case s.Linear != nil: + return *s.Linear.FailureThreshold + case s.Exponential != nil: + return *s.Exponential.FailureThreshold + default: + return 3 + } +} + +func (s *FixedStrategy) CalculateBatchSize(totalNodes int, state *BatchProcessingState) int { + // Fixed strategy doesn't change batch size, but respects remaining nodes + batchSize := *s.InitialBatch + remaining := totalNodes - state.ProcessedNodes + if batchSize > remaining { + batchSize = remaining + } + return max(1, batchSize) +} + +func (s *LinearStrategy) CalculateBatchSize(totalNodes int, state *BatchProcessingState) int { + var batchSize int + + // Check if we should slow down due to last batch failure + progressPercent := (state.ProcessedNodes * 100) / totalNodes + if state.LastBatchFailed && progressPercent < *s.SafetyLimit && state.LastBatchSize > 0 { + // Slow down: reduce by delta from last batch size + batchSize = max(1, state.LastBatchSize-*s.Delta) + } else { + // Normal growth: initialBatch + (currentBatch - 1) * delta + batchSize = *s.InitialBatch + (state.CurrentBatch-1)*(*s.Delta) + } + + remaining := totalNodes - state.ProcessedNodes + if batchSize > remaining { + batchSize = remaining + } + return max(1, batchSize) +} + +func (s *ExponentialStrategy) CalculateBatchSize(totalNodes int, state *BatchProcessingState) int { + var batchSize int + + // Check if we should slow down due to last batch failure + progressPercent := (state.ProcessedNodes * 100) / totalNodes + if state.LastBatchFailed && progressPercent < *s.SafetyLimit && state.LastBatchSize > 0 { + // Slow down: divide last batch size by growth factor + batchSize = max(1, state.LastBatchSize / *s.GrowthFactor) + } else { + // Normal growth: initialBatch * (growthFactor ^ (currentBatch - 1)) + batchSize = *s.InitialBatch + for i := 1; i < state.CurrentBatch; i++ { + batchSize *= *s.GrowthFactor + } + } + + remaining := totalNodes - state.ProcessedNodes + if batchSize > remaining { + batchSize = remaining + } + return max(1, batchSize) +} + // Validate validates the Compartment func (c *Compartment) Validate() error { // Validate compartment budget diff --git a/operator/api/v1alpha1/skyhook_types.go b/operator/api/v1alpha1/skyhook_types.go index 9cb4489d..ec4d0f77 100644 --- a/operator/api/v1alpha1/skyhook_types.go +++ b/operator/api/v1alpha1/skyhook_types.go @@ -316,6 +316,9 @@ type SkyhookStatus struct { // ConfigUpdates tracks config updates ConfigUpdates map[string][]string `json:"configUpdates,omitempty"` + // CompartmentBatchStates tracks batch processing state per compartment + CompartmentBatchStates map[string]BatchProcessingState `json:"compartmentBatchStates,omitempty"` + // +kubebuilder:example=3 // +kubebuilder:default=0 // NodesInProgress displays the number of nodes that are currently in progress and is diff --git a/operator/api/v1alpha1/zz_generated.deepcopy.go b/operator/api/v1alpha1/zz_generated.deepcopy.go index 258999f5..3d5de27e 100644 --- a/operator/api/v1alpha1/zz_generated.deepcopy.go +++ b/operator/api/v1alpha1/zz_generated.deepcopy.go @@ -28,6 +28,26 @@ import ( "k8s.io/apimachinery/pkg/runtime" ) +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *BatchProcessingState) DeepCopyInto(out *BatchProcessingState) { + *out = *in + if in.CurrentBatchNodes != nil { + in, out := &in.CurrentBatchNodes, &out.CurrentBatchNodes + *out = make([]string, len(*in)) + copy(*out, *in) + } +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new BatchProcessingState. +func (in *BatchProcessingState) DeepCopy() *BatchProcessingState { + if in == nil { + return nil + } + out := new(BatchProcessingState) + in.DeepCopyInto(out) + return out +} + // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. func (in *Compartment) DeepCopyInto(out *Compartment) { *out = *in @@ -688,6 +708,13 @@ func (in *SkyhookStatus) DeepCopyInto(out *SkyhookStatus) { (*out)[key] = outVal } } + if in.CompartmentBatchStates != nil { + in, out := &in.CompartmentBatchStates, &out.CompartmentBatchStates + *out = make(map[string]BatchProcessingState, len(*in)) + for key, val := range *in { + (*out)[key] = *val.DeepCopy() + } + } } // DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new SkyhookStatus. diff --git a/operator/config/crd/bases/skyhook.nvidia.com_skyhooks.yaml b/operator/config/crd/bases/skyhook.nvidia.com_skyhooks.yaml index 43df1e9f..9dd23ea4 100644 --- a/operator/config/crd/bases/skyhook.nvidia.com_skyhooks.yaml +++ b/operator/config/crd/bases/skyhook.nvidia.com_skyhooks.yaml @@ -499,6 +499,40 @@ spec: status: description: SkyhookStatus defines the observed state of Skyhook properties: + compartmentBatchStates: + additionalProperties: + description: BatchProcessingState tracks the current state of batch + processing for a compartment + properties: + consecutiveFailures: + description: Number of consecutive failures + type: integer + currentBatch: + description: Current batch number (starts at 1) + type: integer + currentBatchNodes: + description: Names of nodes in the current batch (persisted + across reconciles) + items: + type: string + type: array + failedInBatch: + description: Number of failed nodes in current batch + type: integer + processedNodes: + description: Total number of nodes processed so far + type: integer + shouldStop: + description: Whether the strategy should stop processing due + to failures + type: boolean + successfulInBatch: + description: Number of successful nodes in current batch + type: integer + type: object + description: CompartmentBatchStates tracks batch processing state + per compartment + type: object completeNodes: default: 0/0 description: |- diff --git a/operator/internal/controller/cluster_state_v2.go b/operator/internal/controller/cluster_state_v2.go index 134a6f1a..9c360f6f 100644 --- a/operator/internal/controller/cluster_state_v2.go +++ b/operator/internal/controller/cluster_state_v2.go @@ -111,14 +111,27 @@ func BuildState(skyhooks *v1alpha1.SkyhookList, nodes *corev1.NodeList, deployme for _, deploymentPolicy := range deploymentPolicies.Items { if deploymentPolicy.Name == skyhook.Spec.DeploymentPolicy { for _, compartment := range deploymentPolicy.Spec.Compartments { - ret.skyhooks[idx].AddCompartment(compartment.Name, wrapper.NewCompartmentWrapper(&compartment)) + // Load persisted batch state if it exists + var batchState *v1alpha1.BatchProcessingState + if skyhook.Status.CompartmentBatchStates != nil { + if state, exists := skyhook.Status.CompartmentBatchStates[compartment.Name]; exists { + batchState = &state + } + } + ret.skyhooks[idx].AddCompartment(compartment.Name, wrapper.NewCompartmentWrapper(&compartment, batchState)) } // use policy default + var defaultBatchState *v1alpha1.BatchProcessingState + if skyhook.Status.CompartmentBatchStates != nil { + if state, exists := skyhook.Status.CompartmentBatchStates[v1alpha1.DefaultCompartmentName]; exists { + defaultBatchState = &state + } + } ret.skyhooks[idx].AddCompartment(v1alpha1.DefaultCompartmentName, wrapper.NewCompartmentWrapper(&v1alpha1.Compartment{ Name: v1alpha1.DefaultCompartmentName, Budget: deploymentPolicy.Spec.Default.Budget, Strategy: deploymentPolicy.Spec.Default.Strategy, - })) + }, defaultBatchState)) } } } @@ -402,8 +415,6 @@ func (np *NodePicker) SelectNodes(s SkyhookNodes) []wrapper.SkyhookNode { np.primeAndPruneNodes(s) - nodes := make([]wrapper.SkyhookNode, 0) - // Straight from skyhook_controller CreatePodForPackage tolerations := append([]corev1.Toleration{ // tolerate all cordon { @@ -417,6 +428,47 @@ func (np *NodePicker) SelectNodes(s SkyhookNodes) []wrapper.SkyhookNode { tolerations = append(tolerations, np.runtimeRequiredToleration) } + // Check if this skyhook uses deployment policies with compartments + compartments := s.GetCompartments() + if len(compartments) > 0 { + return np.selectNodesWithCompartments(s, compartments, tolerations) + } + + // Fallback to original logic for skyhooks without deployment policies + return np.selectNodesLegacy(s, tolerations) +} + +// selectNodesWithCompartments selects nodes using compartment-based batch processing +func (np *NodePicker) selectNodesWithCompartments(s SkyhookNodes, compartments map[string]*wrapper.Compartment, tolerations []corev1.Toleration) []wrapper.SkyhookNode { + selectedNodes := make([]wrapper.SkyhookNode, 0) + nodesWithTaintTolerationIssue := make([]string, 0) + + // Process each compartment according to its strategy + for _, compartment := range compartments { + batchNodes := compartment.GetNodesForNextBatch() + + for _, node := range batchNodes { + // Check taint toleration + if CheckTaintToleration(tolerations, node.GetNode().Spec.Taints) { + selectedNodes = append(selectedNodes, node) + np.upsertPick(node.GetNode().GetName(), s.GetSkyhook()) + } else { + nodesWithTaintTolerationIssue = append(nodesWithTaintTolerationIssue, node.GetNode().Name) + node.SetStatus(v1alpha1.StatusBlocked) + } + } + } + + // Add condition about taint toleration issues + np.updateTaintToleranceCondition(s, nodesWithTaintTolerationIssue) + + return selectedNodes +} + +// selectNodesLegacy implements the original node selection logic for backward compatibility +func (np *NodePicker) selectNodesLegacy(s SkyhookNodes, tolerations []corev1.Toleration) []wrapper.SkyhookNode { + nodes := make([]wrapper.SkyhookNode, 0) + var nodeCount int if s.GetSkyhook().Spec.InterruptionBudget.Percent != nil { limit := float64(*s.GetSkyhook().Spec.InterruptionBudget.Percent) / 100 @@ -479,6 +531,13 @@ func (np *NodePicker) SelectNodes(s SkyhookNodes) []wrapper.SkyhookNode { } // if we have nodes that are not tolerable, we need to add a condition to the skyhook + np.updateTaintToleranceCondition(s, nodesWithTaintTolerationIssue) + + return final_nodes +} + +// updateTaintToleranceCondition updates the taint tolerance condition on the skyhook +func (np *NodePicker) updateTaintToleranceCondition(s SkyhookNodes, nodesWithTaintTolerationIssue []string) { if len(nodesWithTaintTolerationIssue) > 0 { s.GetSkyhook().AddCondition(metav1.Condition{ Type: fmt.Sprintf("%s/TaintNotTolerable", v1alpha1.METADATA_PREFIX), @@ -496,8 +555,6 @@ func (np *NodePicker) SelectNodes(s SkyhookNodes) []wrapper.SkyhookNode { LastTransitionTime: metav1.Now(), }) } - - return final_nodes } // for node/package source of true, its on the node (we true to reflect this on the skyhook status) @@ -536,6 +593,11 @@ func IntrospectSkyhook(skyhook SkyhookNodes, allSkyhooks []SkyhookNodes) bool { } } + // Evaluate completed batches for compartments with deployment policies + if evaluateCompletedBatches(skyhook) { + change = true + } + skyhook.UpdateCondition() if skyhook.GetSkyhook().Updated { change = true @@ -543,6 +605,34 @@ func IntrospectSkyhook(skyhook SkyhookNodes, allSkyhooks []SkyhookNodes) bool { return change } +// evaluateCompletedBatches checks if any compartment batches are complete and evaluates them +func evaluateCompletedBatches(skyhook SkyhookNodes) bool { + compartments := skyhook.GetCompartments() + if len(compartments) == 0 { + return false // No compartments to evaluate + } + + changed := false + for _, compartment := range compartments { + if isComplete, successCount, failureCount := compartment.EvaluateCurrentBatch(); isComplete { + batchSize := successCount + failureCount + + // Update the compartment's batch state using strategy logic + compartment.EvaluateAndUpdateBatchState(batchSize, successCount, failureCount) + + // Persist the updated batch state to the skyhook status + if skyhook.GetSkyhook().Status.CompartmentBatchStates == nil { + skyhook.GetSkyhook().Status.CompartmentBatchStates = make(map[string]v1alpha1.BatchProcessingState) + } + skyhook.GetSkyhook().Status.CompartmentBatchStates[compartment.GetName()] = compartment.GetBatchState() + skyhook.GetSkyhook().Updated = true + changed = true + } + } + + return changed +} + func IntrospectNode(node wrapper.SkyhookNode, skyhook SkyhookNodes) bool { skyhookStatus := skyhook.Status() diff --git a/operator/internal/wrapper/compartment.go b/operator/internal/wrapper/compartment.go index 3b97e14d..8630836c 100644 --- a/operator/internal/wrapper/compartment.go +++ b/operator/internal/wrapper/compartment.go @@ -26,15 +26,27 @@ import ( "k8s.io/apimachinery/pkg/labels" ) -func NewCompartmentWrapper(c *v1alpha1.Compartment) *Compartment { - return &Compartment{ +func NewCompartmentWrapper(c *v1alpha1.Compartment, batchState *v1alpha1.BatchProcessingState) *Compartment { + comp := &Compartment{ Compartment: *c, } + + if batchState != nil { + comp.BatchState = *batchState + } else { + comp.BatchState = v1alpha1.BatchProcessingState{ + CurrentBatch: 1, + } + } + + return comp } type Compartment struct { v1alpha1.Compartment Nodes []SkyhookNode + // BatchState tracks the persistent batch processing state + BatchState v1alpha1.BatchProcessingState } func (c *Compartment) GetName() string { @@ -58,6 +70,171 @@ func (c *Compartment) AddNode(node SkyhookNode) { c.Nodes = append(c.Nodes, node) } +func (c *Compartment) calculateCeiling() int { + if c.Budget.Count != nil { + return *c.Budget.Count + } + if c.Budget.Percent != nil { + matched := len(c.Nodes) + if matched == 0 { + return 0 + } + limit := float64(*c.Budget.Percent) / 100 + return max(1, int(float64(matched)*limit)) + } + return 0 +} + +func (c *Compartment) getInProgressCount() int { + inProgress := 0 + for _, node := range c.Nodes { + if node.Status() == v1alpha1.StatusInProgress { + inProgress++ + } + } + return inProgress +} + +func (c *Compartment) GetNodesForNextBatch() []SkyhookNode { + if c.Strategy != nil && c.BatchState.ShouldStop { + return nil + } + + if len(c.BatchState.CurrentBatchNodes) > 0 { + return c.getCurrentBatchNodes() + } + + return c.createNewBatch() +} + +func (c *Compartment) getCurrentBatchNodes() []SkyhookNode { + currentBatchNodes := make([]SkyhookNode, 0) + for _, nodeName := range c.BatchState.CurrentBatchNodes { + for _, node := range c.Nodes { + if node.GetNode().Name == nodeName { + currentBatchNodes = append(currentBatchNodes, node) + break + } + } + } + return currentBatchNodes +} + +func (c *Compartment) createNewBatch() []SkyhookNode { + var batchSize int + if c.Strategy != nil { + batchSize = c.Strategy.CalculateBatchSize(len(c.Nodes), &c.BatchState) + } else { + ceiling := c.calculateCeiling() + availableCapacity := ceiling - c.getInProgressCount() + batchSize = max(0, availableCapacity) + } + + if batchSize <= 0 { + return nil + } + + selectedNodes := make([]SkyhookNode, 0) + priority := []v1alpha1.Status{v1alpha1.StatusInProgress, v1alpha1.StatusUnknown, v1alpha1.StatusBlocked, v1alpha1.StatusErroring} + + for _, status := range priority { + for _, node := range c.Nodes { + if len(selectedNodes) >= batchSize { + break + } + if node.Status() != status { + continue + } + if !node.IsComplete() { + selectedNodes = append(selectedNodes, node) + } + } + if len(selectedNodes) >= batchSize { + break + } + } + + nodeNames := make([]string, len(selectedNodes)) + for i, node := range selectedNodes { + nodeNames[i] = node.GetNode().Name + } + c.BatchState.CurrentBatchNodes = nodeNames + + return selectedNodes +} + +// IsBatchComplete checks if the current batch has reached terminal states +func (c *Compartment) IsBatchComplete() bool { + if len(c.BatchState.CurrentBatchNodes) == 0 { + return true // No batch in progress + } + + // Check if all batch nodes have reached terminal states + for _, nodeName := range c.BatchState.CurrentBatchNodes { + for _, node := range c.Nodes { + if node.GetNode().Name == nodeName { + if node.Status() == v1alpha1.StatusInProgress { + return false // Still processing + } + break + } + } + } + return true // All nodes are Complete or Erroring +} + +// EvaluateCurrentBatch evaluates the current batch result if it's complete +func (c *Compartment) EvaluateCurrentBatch() (bool, int, int) { + if !c.IsBatchComplete() { + return false, 0, 0 // Batch not complete yet + } + + if len(c.BatchState.CurrentBatchNodes) == 0 { + return false, 0, 0 // No batch to evaluate + } + + successCount := 0 + failureCount := 0 + + // Count successes and failures from the batch nodes + for _, nodeName := range c.BatchState.CurrentBatchNodes { + for _, node := range c.Nodes { + if node.GetNode().Name == nodeName { + if node.IsComplete() { + successCount++ + } else if node.Status() == v1alpha1.StatusErroring { + failureCount++ + } + break + } + } + } + + // Clear the current batch since we're evaluating it + c.BatchState.CurrentBatchNodes = nil + + return true, successCount, failureCount +} + +// EvaluateAndUpdateBatchState evaluates a completed batch and updates the persistent state +func (c *Compartment) EvaluateAndUpdateBatchState(batchSize int, successCount int, failureCount int) { + if c.Strategy != nil { + // Use strategy-specific evaluation + c.Strategy.EvaluateBatchResult(&c.BatchState, batchSize, successCount, failureCount, len(c.Nodes)) + } else { + // No strategy: just update basic counters + c.BatchState.ProcessedNodes += batchSize + c.BatchState.SuccessfulInBatch = successCount + c.BatchState.FailedInBatch = failureCount + c.BatchState.CurrentBatch++ + } +} + +// GetBatchState returns the current batch processing state +func (c *Compartment) GetBatchState() v1alpha1.BatchProcessingState { + return c.BatchState +} + // AssignNodeToCompartment assigns a single node to the appropriate compartment func AssignNodeToCompartment(node SkyhookNode, compartments map[string]*Compartment) (string, error) { nodeLabels := labels.Set(node.GetNode().Labels) diff --git a/operator/internal/wrapper/compartment_test.go b/operator/internal/wrapper/compartment_test.go new file mode 100644 index 00000000..33e3cb9d --- /dev/null +++ b/operator/internal/wrapper/compartment_test.go @@ -0,0 +1,240 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package wrapper + +import ( + "github.com/NVIDIA/skyhook/operator/api/v1alpha1" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + "k8s.io/utils/ptr" +) + +var _ = Describe("Compartment", func() { + Context("calculateCeiling", func() { + It("should calculate ceiling for count budget", func() { + compartment := &Compartment{ + Compartment: v1alpha1.Compartment{ + Budget: v1alpha1.DeploymentBudget{Count: ptr.To(3)}, + }, + } + + // Add 10 mock nodes (just need count for ceiling calculation) + for i := 0; i < 10; i++ { + compartment.Nodes = append(compartment.Nodes, nil) + } + + ceiling := compartment.calculateCeiling() + Expect(ceiling).To(Equal(3)) + }) + + It("should calculate ceiling for percent budget", func() { + compartment := &Compartment{ + Compartment: v1alpha1.Compartment{ + Budget: v1alpha1.DeploymentBudget{Percent: ptr.To(30)}, + }, + } + + // Add 10 mock nodes - 30% should be 3 + for i := 0; i < 10; i++ { + compartment.Nodes = append(compartment.Nodes, nil) + } + + ceiling := compartment.calculateCeiling() + Expect(ceiling).To(Equal(3)) // max(1, int(10 * 0.3)) = 3 + }) + + It("should handle small percent budgets with minimum 1", func() { + compartment := &Compartment{ + Compartment: v1alpha1.Compartment{ + Budget: v1alpha1.DeploymentBudget{Percent: ptr.To(30)}, + }, + } + + // Add 2 mock nodes - 30% of 2 = 0.6, should round to 1 + for i := 0; i < 2; i++ { + compartment.Nodes = append(compartment.Nodes, nil) + } + + ceiling := compartment.calculateCeiling() + Expect(ceiling).To(Equal(1)) // max(1, int(2 * 0.3)) = max(1, 0) = 1 + }) + + It("should return 0 for no nodes", func() { + compartment := &Compartment{ + Compartment: v1alpha1.Compartment{ + Budget: v1alpha1.DeploymentBudget{Percent: ptr.To(50)}, + }, + } + + ceiling := compartment.calculateCeiling() + Expect(ceiling).To(Equal(0)) + }) + }) + + Context("NewCompartmentWrapperWithState", func() { + It("should create compartment with provided batch state", func() { + batchState := &v1alpha1.BatchProcessingState{ + CurrentBatch: 3, + ConsecutiveFailures: 1, + ProcessedNodes: 5, + } + + compartment := NewCompartmentWrapper(&v1alpha1.Compartment{ + Name: "test", + Budget: v1alpha1.DeploymentBudget{Count: ptr.To(5)}, + }, batchState) + + state := compartment.GetBatchState() + Expect(state.CurrentBatch).To(Equal(3)) + Expect(state.ConsecutiveFailures).To(Equal(1)) + Expect(state.ProcessedNodes).To(Equal(5)) + }) + + It("should create compartment with default batch state when nil", func() { + compartment := NewCompartmentWrapper(&v1alpha1.Compartment{ + Name: "test", + Budget: v1alpha1.DeploymentBudget{Count: ptr.To(5)}, + }, nil) + + state := compartment.GetBatchState() + Expect(state.CurrentBatch).To(Equal(1)) + Expect(state.ConsecutiveFailures).To(Equal(0)) + Expect(state.ProcessedNodes).To(Equal(0)) + }) + }) + + Context("EvaluateAndUpdateBatchState", func() { + It("should update basic state without strategy", func() { + compartment := NewCompartmentWrapper(&v1alpha1.Compartment{ + Name: "test-compartment", + Budget: v1alpha1.DeploymentBudget{Count: ptr.To(10)}, + }, &v1alpha1.BatchProcessingState{ + CurrentBatch: 1, + ProcessedNodes: 0, + }) + + compartment.EvaluateAndUpdateBatchState(3, 2, 1) + + state := compartment.GetBatchState() + Expect(state.ProcessedNodes).To(Equal(3)) + Expect(state.CurrentBatch).To(Equal(2)) + Expect(state.SuccessfulInBatch).To(Equal(2)) + Expect(state.FailedInBatch).To(Equal(1)) + }) + + It("should reset consecutive failures on successful batch", func() { + strategy := &v1alpha1.DeploymentStrategy{ + Fixed: &v1alpha1.FixedStrategy{ + InitialBatch: ptr.To(3), + BatchThreshold: ptr.To(80), + FailureThreshold: ptr.To(2), + SafetyLimit: ptr.To(50), + }, + } + + compartment := NewCompartmentWrapper(&v1alpha1.Compartment{ + Name: "test-compartment", + Budget: v1alpha1.DeploymentBudget{Count: ptr.To(10)}, + Strategy: strategy, + }, &v1alpha1.BatchProcessingState{ + CurrentBatch: 1, + ProcessedNodes: 0, + ConsecutiveFailures: 1, // Should reset on success + }) + + // Add 10 mock nodes for totalNodes calculation + for i := 0; i < 10; i++ { + compartment.Nodes = append(compartment.Nodes, nil) + } + + // 80% success (4 out of 5) + compartment.EvaluateAndUpdateBatchState(5, 4, 1) + + state := compartment.GetBatchState() + Expect(state.ConsecutiveFailures).To(Equal(0)) // Should reset + Expect(state.ShouldStop).To(BeFalse()) + }) + + It("should increment consecutive failures and trigger stop when below safety limit", func() { + strategy := &v1alpha1.DeploymentStrategy{ + Fixed: &v1alpha1.FixedStrategy{ + InitialBatch: ptr.To(3), + BatchThreshold: ptr.To(80), + FailureThreshold: ptr.To(2), + SafetyLimit: ptr.To(50), + }, + } + + compartment := NewCompartmentWrapper(&v1alpha1.Compartment{ + Name: "test-compartment", + Budget: v1alpha1.DeploymentBudget{Count: ptr.To(10)}, + Strategy: strategy, + }, &v1alpha1.BatchProcessingState{ + CurrentBatch: 2, + ProcessedNodes: 1, // After adding 3 more: (1+3)/10 = 40% (below 50% safety limit) + ConsecutiveFailures: 1, // Will increment to 2 (threshold) + }) + + // Add 10 mock nodes for totalNodes calculation + for i := 0; i < 10; i++ { + compartment.Nodes = append(compartment.Nodes, nil) + } + + // 33% success (1 out of 3) - below 80% threshold, progress will be (1+3)/10 = 40% (below safety limit) + compartment.EvaluateAndUpdateBatchState(3, 1, 2) + + state := compartment.GetBatchState() + Expect(state.ConsecutiveFailures).To(Equal(2)) // Should increment + Expect(state.ShouldStop).To(BeTrue()) // Should trigger stop (below safety limit) + }) + + It("should not trigger stop when above safety limit", func() { + strategy := &v1alpha1.DeploymentStrategy{ + Fixed: &v1alpha1.FixedStrategy{ + InitialBatch: ptr.To(3), + BatchThreshold: ptr.To(80), + FailureThreshold: ptr.To(2), + SafetyLimit: ptr.To(50), + }, + } + + compartment := NewCompartmentWrapper(&v1alpha1.Compartment{ + Name: "test-compartment", + Budget: v1alpha1.DeploymentBudget{Count: ptr.To(10)}, + Strategy: strategy, + }, &v1alpha1.BatchProcessingState{ + CurrentBatch: 3, + ProcessedNodes: 6, // 60% progress (above 50% safety limit) + ConsecutiveFailures: 1, + }) + + // Add 10 mock nodes for totalNodes calculation + for i := 0; i < 10; i++ { + compartment.Nodes = append(compartment.Nodes, nil) + } + + // 40% success (2 out of 5) - below 80% threshold, but above safety limit + compartment.EvaluateAndUpdateBatchState(5, 2, 3) + + state := compartment.GetBatchState() + Expect(state.ConsecutiveFailures).To(Equal(2)) // Should increment + Expect(state.ShouldStop).To(BeFalse()) // Should NOT stop (above safety limit) + }) + }) +}) From eb0312ae95f91e936d3859432718a332ed6b3a60 Mon Sep 17 00:00:00 2001 From: Tommy Lam Date: Tue, 7 Oct 2025 13:30:54 -0700 Subject: [PATCH 2/8] fix batch slowdown and persistence --- chart/templates/skyhook-crd.yaml | 6 ++++ .../api/v1alpha1/deployment_policy_types.go | 35 ++++++++++++++++--- .../api/v1alpha1/zz_generated.deepcopy.go | 4 +-- .../bases/skyhook.nvidia.com_skyhooks.yaml | 6 ++++ .../internal/controller/cluster_state_v2.go | 28 +++++++++++++++ .../internal/controller/skyhook_controller.go | 4 +++ operator/internal/wrapper/compartment.go | 2 +- 7 files changed, 77 insertions(+), 8 deletions(-) diff --git a/chart/templates/skyhook-crd.yaml b/chart/templates/skyhook-crd.yaml index 776f1f4f..0b287bef 100644 --- a/chart/templates/skyhook-crd.yaml +++ b/chart/templates/skyhook-crd.yaml @@ -518,6 +518,12 @@ spec: failedInBatch: description: Number of failed nodes in current batch type: integer + lastBatchFailed: + description: Whether the last batch failed (for slowdown logic) + type: boolean + lastBatchSize: + description: Last successful batch size (for slowdown calculations) + type: integer processedNodes: description: Total number of nodes processed so far type: integer diff --git a/operator/api/v1alpha1/deployment_policy_types.go b/operator/api/v1alpha1/deployment_policy_types.go index 756fefc9..9f249d0c 100644 --- a/operator/api/v1alpha1/deployment_policy_types.go +++ b/operator/api/v1alpha1/deployment_policy_types.go @@ -25,6 +25,7 @@ package v1alpha1 import ( "fmt" + "math" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/utils/ptr" @@ -274,8 +275,16 @@ func (s *DeploymentStrategy) EvaluateBatchResult(state *BatchProcessingState, ba state.FailedInBatch = failureCount state.ProcessedNodes += batchSize + // Avoid divide by zero + if batchSize == 0 { + return + } + successPercentage := (successCount * 100) / batchSize - progressPercent := (state.ProcessedNodes * 100) / totalNodes + var progressPercent int + if totalNodes > 0 { + progressPercent = (state.ProcessedNodes * 100) / totalNodes + } if successPercentage >= s.getBatchThreshold() { state.ConsecutiveFailures = 0 @@ -347,6 +356,11 @@ func (s *FixedStrategy) CalculateBatchSize(totalNodes int, state *BatchProcessin func (s *LinearStrategy) CalculateBatchSize(totalNodes int, state *BatchProcessingState) int { var batchSize int + // Avoid divide by zero + if totalNodes == 0 { + return 0 + } + // Check if we should slow down due to last batch failure progressPercent := (state.ProcessedNodes * 100) / totalNodes if state.LastBatchFailed && progressPercent < *s.SafetyLimit && state.LastBatchSize > 0 { @@ -367,16 +381,27 @@ func (s *LinearStrategy) CalculateBatchSize(totalNodes int, state *BatchProcessi func (s *ExponentialStrategy) CalculateBatchSize(totalNodes int, state *BatchProcessingState) int { var batchSize int + // Avoid divide by zero + if totalNodes == 0 { + return 0 + } + // Check if we should slow down due to last batch failure progressPercent := (state.ProcessedNodes * 100) / totalNodes - if state.LastBatchFailed && progressPercent < *s.SafetyLimit && state.LastBatchSize > 0 { + if state.LastBatchFailed && progressPercent < *s.SafetyLimit && state.LastBatchSize > 0 && *s.GrowthFactor > 0 { // Slow down: divide last batch size by growth factor batchSize = max(1, state.LastBatchSize / *s.GrowthFactor) } else { // Normal growth: initialBatch * (growthFactor ^ (currentBatch - 1)) - batchSize = *s.InitialBatch - for i := 1; i < state.CurrentBatch; i++ { - batchSize *= *s.GrowthFactor + // Use math.Pow for efficiency and to avoid overflow issues with large batch numbers + exponent := state.CurrentBatch - 1 + growthMultiplier := math.Pow(float64(*s.GrowthFactor), float64(exponent)) + batchSize = int(float64(*s.InitialBatch) * growthMultiplier) + + // Cap at remaining nodes to prevent unreasonably large batch sizes + // and potential integer overflow + if batchSize > totalNodes { + batchSize = totalNodes } } diff --git a/operator/api/v1alpha1/zz_generated.deepcopy.go b/operator/api/v1alpha1/zz_generated.deepcopy.go index 3d5de27e..d08ccd46 100644 --- a/operator/api/v1alpha1/zz_generated.deepcopy.go +++ b/operator/api/v1alpha1/zz_generated.deepcopy.go @@ -1,5 +1,3 @@ -//go:build !ignore_autogenerated - /* * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: Apache-2.0 @@ -18,6 +16,8 @@ * limitations under the License. */ +//go:build !ignore_autogenerated + // Code generated by controller-gen. DO NOT EDIT. package v1alpha1 diff --git a/operator/config/crd/bases/skyhook.nvidia.com_skyhooks.yaml b/operator/config/crd/bases/skyhook.nvidia.com_skyhooks.yaml index 9dd23ea4..194e8703 100644 --- a/operator/config/crd/bases/skyhook.nvidia.com_skyhooks.yaml +++ b/operator/config/crd/bases/skyhook.nvidia.com_skyhooks.yaml @@ -519,6 +519,12 @@ spec: failedInBatch: description: Number of failed nodes in current batch type: integer + lastBatchFailed: + description: Whether the last batch failed (for slowdown logic) + type: boolean + lastBatchSize: + description: Last successful batch size (for slowdown calculations) + type: integer processedNodes: description: Total number of nodes processed so far type: integer diff --git a/operator/internal/controller/cluster_state_v2.go b/operator/internal/controller/cluster_state_v2.go index 9c360f6f..ffbbcefd 100644 --- a/operator/internal/controller/cluster_state_v2.go +++ b/operator/internal/controller/cluster_state_v2.go @@ -465,6 +465,34 @@ func (np *NodePicker) selectNodesWithCompartments(s SkyhookNodes, compartments m return selectedNodes } +// PersistCompartmentBatchStates saves the current batch state for all compartments to the Skyhook status +func PersistCompartmentBatchStates(skyhook SkyhookNodes) bool { + compartments := skyhook.GetCompartments() + if len(compartments) == 0 { + return false // No compartments, nothing to persist + } + + // Initialize the batch states map if needed + if skyhook.GetSkyhook().Status.CompartmentBatchStates == nil { + skyhook.GetSkyhook().Status.CompartmentBatchStates = make(map[string]v1alpha1.BatchProcessingState) + } + + changed := false + for _, compartment := range compartments { + // Only persist if there are nodes in the current batch + if len(compartment.GetBatchState().CurrentBatchNodes) > 0 { + skyhook.GetSkyhook().Status.CompartmentBatchStates[compartment.GetName()] = compartment.GetBatchState() + changed = true + } + } + + if changed { + skyhook.GetSkyhook().Updated = true + } + + return changed +} + // selectNodesLegacy implements the original node selection logic for backward compatibility func (np *NodePicker) selectNodesLegacy(s SkyhookNodes, tolerations []corev1.Toleration) []wrapper.SkyhookNode { nodes := make([]wrapper.SkyhookNode, 0) diff --git a/operator/internal/controller/skyhook_controller.go b/operator/internal/controller/skyhook_controller.go index ae4adc49..102adc44 100644 --- a/operator/internal/controller/skyhook_controller.go +++ b/operator/internal/controller/skyhook_controller.go @@ -560,6 +560,10 @@ func (r *SkyhookReconciler) RunSkyhookPackages(ctx context.Context, clusterState } selectedNode := nodePicker.SelectNodes(skyhook) + + // Persist compartment batch states after node selection + PersistCompartmentBatchStates(skyhook) + for _, node := range selectedNode { if node.IsComplete() && !node.Changed() { diff --git a/operator/internal/wrapper/compartment.go b/operator/internal/wrapper/compartment.go index 8630836c..c68bd01e 100644 --- a/operator/internal/wrapper/compartment.go +++ b/operator/internal/wrapper/compartment.go @@ -135,7 +135,7 @@ func (c *Compartment) createNewBatch() []SkyhookNode { } selectedNodes := make([]SkyhookNode, 0) - priority := []v1alpha1.Status{v1alpha1.StatusInProgress, v1alpha1.StatusUnknown, v1alpha1.StatusBlocked, v1alpha1.StatusErroring} + priority := []v1alpha1.Status{v1alpha1.StatusInProgress, v1alpha1.StatusUnknown, v1alpha1.StatusErroring} for _, status := range priority { for _, node := range c.Nodes { From f383f36e6dee2b204d108c0383822b010e4e3979 Mon Sep 17 00:00:00 2001 From: Tommy Lam Date: Thu, 9 Oct 2025 10:26:57 -0700 Subject: [PATCH 3/8] only persist key info --- .../api/v1alpha1/deployment_policy_types.go | 38 ++++---- .../api/v1alpha1/zz_generated.deepcopy.go | 7 +- .../bases/skyhook.nvidia.com_skyhooks.yaml | 23 ++--- .../internal/controller/cluster_state_v2.go | 8 +- operator/internal/wrapper/compartment.go | 95 ++++++++----------- operator/internal/wrapper/compartment_test.go | 29 +++--- 6 files changed, 92 insertions(+), 108 deletions(-) diff --git a/operator/api/v1alpha1/deployment_policy_types.go b/operator/api/v1alpha1/deployment_policy_types.go index 9f249d0c..aec34f3c 100644 --- a/operator/api/v1alpha1/deployment_policy_types.go +++ b/operator/api/v1alpha1/deployment_policy_types.go @@ -239,17 +239,13 @@ type BatchProcessingState struct { CurrentBatch int `json:"currentBatch,omitempty"` // Number of consecutive failures ConsecutiveFailures int `json:"consecutiveFailures,omitempty"` - // Total number of nodes processed so far - ProcessedNodes int `json:"processedNodes,omitempty"` - // Number of successful nodes in current batch - SuccessfulInBatch int `json:"successfulInBatch,omitempty"` - // Number of failed nodes in current batch - FailedInBatch int `json:"failedInBatch,omitempty"` + // Total number of nodes that have completed successfully (cumulative across all batches) + CompletedNodes int `json:"completedNodes,omitempty"` + // Total number of nodes that have failed (cumulative across all batches) + FailedNodes int `json:"failedNodes,omitempty"` // Whether the strategy should stop processing due to failures ShouldStop bool `json:"shouldStop,omitempty"` - // Names of nodes in the current batch (persisted across reconciles) - CurrentBatchNodes []string `json:"currentBatchNodes,omitempty"` - // Last successful batch size (for slowdown calculations) + // Last batch size (for slowdown calculations) LastBatchSize int `json:"lastBatchSize,omitempty"` // Whether the last batch failed (for slowdown logic) LastBatchFailed bool `json:"lastBatchFailed,omitempty"` @@ -271,19 +267,22 @@ func (s *DeploymentStrategy) CalculateBatchSize(totalNodes int, state *BatchProc // EvaluateBatchResult evaluates the result of a batch and updates state func (s *DeploymentStrategy) EvaluateBatchResult(state *BatchProcessingState, batchSize int, successCount int, failureCount int, totalNodes int) { - state.SuccessfulInBatch = successCount - state.FailedInBatch = failureCount - state.ProcessedNodes += batchSize + // Note: successCount and failureCount are deltas from the current batch + // CompletedNodes and FailedNodes are already updated in EvaluateCurrentBatch before this is called // Avoid divide by zero if batchSize == 0 { return } + // Calculate success percentage for this batch successPercentage := (successCount * 100) / batchSize + + // Calculate overall progress percentage + processedNodes := state.CompletedNodes + state.FailedNodes var progressPercent int if totalNodes > 0 { - progressPercent = (state.ProcessedNodes * 100) / totalNodes + progressPercent = (processedNodes * 100) / totalNodes } if successPercentage >= s.getBatchThreshold() { @@ -346,7 +345,8 @@ func (s *DeploymentStrategy) getFailureThreshold() int { func (s *FixedStrategy) CalculateBatchSize(totalNodes int, state *BatchProcessingState) int { // Fixed strategy doesn't change batch size, but respects remaining nodes batchSize := *s.InitialBatch - remaining := totalNodes - state.ProcessedNodes + processedNodes := state.CompletedNodes + state.FailedNodes + remaining := totalNodes - processedNodes if batchSize > remaining { batchSize = remaining } @@ -362,7 +362,8 @@ func (s *LinearStrategy) CalculateBatchSize(totalNodes int, state *BatchProcessi } // Check if we should slow down due to last batch failure - progressPercent := (state.ProcessedNodes * 100) / totalNodes + processedNodes := state.CompletedNodes + state.FailedNodes + progressPercent := (processedNodes * 100) / totalNodes if state.LastBatchFailed && progressPercent < *s.SafetyLimit && state.LastBatchSize > 0 { // Slow down: reduce by delta from last batch size batchSize = max(1, state.LastBatchSize-*s.Delta) @@ -371,7 +372,7 @@ func (s *LinearStrategy) CalculateBatchSize(totalNodes int, state *BatchProcessi batchSize = *s.InitialBatch + (state.CurrentBatch-1)*(*s.Delta) } - remaining := totalNodes - state.ProcessedNodes + remaining := totalNodes - processedNodes if batchSize > remaining { batchSize = remaining } @@ -387,7 +388,8 @@ func (s *ExponentialStrategy) CalculateBatchSize(totalNodes int, state *BatchPro } // Check if we should slow down due to last batch failure - progressPercent := (state.ProcessedNodes * 100) / totalNodes + processedNodes := state.CompletedNodes + state.FailedNodes + progressPercent := (processedNodes * 100) / totalNodes if state.LastBatchFailed && progressPercent < *s.SafetyLimit && state.LastBatchSize > 0 && *s.GrowthFactor > 0 { // Slow down: divide last batch size by growth factor batchSize = max(1, state.LastBatchSize / *s.GrowthFactor) @@ -405,7 +407,7 @@ func (s *ExponentialStrategy) CalculateBatchSize(totalNodes int, state *BatchPro } } - remaining := totalNodes - state.ProcessedNodes + remaining := totalNodes - processedNodes if batchSize > remaining { batchSize = remaining } diff --git a/operator/api/v1alpha1/zz_generated.deepcopy.go b/operator/api/v1alpha1/zz_generated.deepcopy.go index d08ccd46..e403467d 100644 --- a/operator/api/v1alpha1/zz_generated.deepcopy.go +++ b/operator/api/v1alpha1/zz_generated.deepcopy.go @@ -31,11 +31,6 @@ import ( // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. func (in *BatchProcessingState) DeepCopyInto(out *BatchProcessingState) { *out = *in - if in.CurrentBatchNodes != nil { - in, out := &in.CurrentBatchNodes, &out.CurrentBatchNodes - *out = make([]string, len(*in)) - copy(*out, *in) - } } // DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new BatchProcessingState. @@ -712,7 +707,7 @@ func (in *SkyhookStatus) DeepCopyInto(out *SkyhookStatus) { in, out := &in.CompartmentBatchStates, &out.CompartmentBatchStates *out = make(map[string]BatchProcessingState, len(*in)) for key, val := range *in { - (*out)[key] = *val.DeepCopy() + (*out)[key] = val } } } diff --git a/operator/config/crd/bases/skyhook.nvidia.com_skyhooks.yaml b/operator/config/crd/bases/skyhook.nvidia.com_skyhooks.yaml index 194e8703..54689a82 100644 --- a/operator/config/crd/bases/skyhook.nvidia.com_skyhooks.yaml +++ b/operator/config/crd/bases/skyhook.nvidia.com_skyhooks.yaml @@ -504,37 +504,30 @@ spec: description: BatchProcessingState tracks the current state of batch processing for a compartment properties: + completedNodes: + description: Total number of nodes that have completed successfully + (cumulative across all batches) + type: integer consecutiveFailures: description: Number of consecutive failures type: integer currentBatch: description: Current batch number (starts at 1) type: integer - currentBatchNodes: - description: Names of nodes in the current batch (persisted - across reconciles) - items: - type: string - type: array - failedInBatch: - description: Number of failed nodes in current batch + failedNodes: + description: Total number of nodes that have failed (cumulative + across all batches) type: integer lastBatchFailed: description: Whether the last batch failed (for slowdown logic) type: boolean lastBatchSize: - description: Last successful batch size (for slowdown calculations) - type: integer - processedNodes: - description: Total number of nodes processed so far + description: Last batch size (for slowdown calculations) type: integer shouldStop: description: Whether the strategy should stop processing due to failures type: boolean - successfulInBatch: - description: Number of successful nodes in current batch - type: integer type: object description: CompartmentBatchStates tracks batch processing state per compartment diff --git a/operator/internal/controller/cluster_state_v2.go b/operator/internal/controller/cluster_state_v2.go index ffbbcefd..36057a77 100644 --- a/operator/internal/controller/cluster_state_v2.go +++ b/operator/internal/controller/cluster_state_v2.go @@ -479,9 +479,11 @@ func PersistCompartmentBatchStates(skyhook SkyhookNodes) bool { changed := false for _, compartment := range compartments { - // Only persist if there are nodes in the current batch - if len(compartment.GetBatchState().CurrentBatchNodes) > 0 { - skyhook.GetSkyhook().Status.CompartmentBatchStates[compartment.GetName()] = compartment.GetBatchState() + // Always persist batch state to maintain cumulative counters + batchState := compartment.GetBatchState() + // Only persist if there's meaningful state (batch has started or there are nodes) + if batchState.CurrentBatch > 0 || len(compartment.GetNodes()) > 0 { + skyhook.GetSkyhook().Status.CompartmentBatchStates[compartment.GetName()] = batchState changed = true } } diff --git a/operator/internal/wrapper/compartment.go b/operator/internal/wrapper/compartment.go index c68bd01e..a5a65bc0 100644 --- a/operator/internal/wrapper/compartment.go +++ b/operator/internal/wrapper/compartment.go @@ -100,24 +100,23 @@ func (c *Compartment) GetNodesForNextBatch() []SkyhookNode { return nil } - if len(c.BatchState.CurrentBatchNodes) > 0 { - return c.getCurrentBatchNodes() + // If there's a batch in progress (nodes are InProgress), don't start a new one + if c.getInProgressCount() > 0 { + return c.getInProgressNodes() } + // No batch in progress, create a new one return c.createNewBatch() } -func (c *Compartment) getCurrentBatchNodes() []SkyhookNode { - currentBatchNodes := make([]SkyhookNode, 0) - for _, nodeName := range c.BatchState.CurrentBatchNodes { - for _, node := range c.Nodes { - if node.GetNode().Name == nodeName { - currentBatchNodes = append(currentBatchNodes, node) - break - } +func (c *Compartment) getInProgressNodes() []SkyhookNode { + inProgressNodes := make([]SkyhookNode, 0) + for _, node := range c.Nodes { + if node.Status() == v1alpha1.StatusInProgress { + inProgressNodes = append(inProgressNodes, node) } } - return currentBatchNodes + return inProgressNodes } func (c *Compartment) createNewBatch() []SkyhookNode { @@ -154,66 +153,54 @@ func (c *Compartment) createNewBatch() []SkyhookNode { } } - nodeNames := make([]string, len(selectedNodes)) - for i, node := range selectedNodes { - nodeNames[i] = node.GetNode().Name - } - c.BatchState.CurrentBatchNodes = nodeNames - return selectedNodes } // IsBatchComplete checks if the current batch has reached terminal states +// A batch is complete when there are no nodes in InProgress status func (c *Compartment) IsBatchComplete() bool { - if len(c.BatchState.CurrentBatchNodes) == 0 { - return true // No batch in progress - } - - // Check if all batch nodes have reached terminal states - for _, nodeName := range c.BatchState.CurrentBatchNodes { - for _, node := range c.Nodes { - if node.GetNode().Name == nodeName { - if node.Status() == v1alpha1.StatusInProgress { - return false // Still processing - } - break - } - } - } - return true // All nodes are Complete or Erroring + return c.getInProgressCount() == 0 } // EvaluateCurrentBatch evaluates the current batch result if it's complete +// Uses delta-based tracking: compares current state to last checkpoint func (c *Compartment) EvaluateCurrentBatch() (bool, int, int) { if !c.IsBatchComplete() { return false, 0, 0 // Batch not complete yet } - if len(c.BatchState.CurrentBatchNodes) == 0 { - return false, 0, 0 // No batch to evaluate + // If this is the first batch (nothing has been processed yet), skip evaluation + // The batch will be started in the next reconcile + if c.BatchState.CurrentBatch == 0 { + c.BatchState.CurrentBatch = 1 + return false, 0, 0 } - successCount := 0 - failureCount := 0 - - // Count successes and failures from the batch nodes - for _, nodeName := range c.BatchState.CurrentBatchNodes { - for _, node := range c.Nodes { - if node.GetNode().Name == nodeName { - if node.IsComplete() { - successCount++ - } else if node.Status() == v1alpha1.StatusErroring { - failureCount++ - } - break - } + // Count current state in the compartment + currentCompleted := 0 + currentFailed := 0 + for _, node := range c.Nodes { + if node.IsComplete() { + currentCompleted++ + } else if node.Status() == v1alpha1.StatusErroring { + currentFailed++ } } - // Clear the current batch since we're evaluating it - c.BatchState.CurrentBatchNodes = nil + // Calculate delta from last checkpoint + deltaCompleted := currentCompleted - c.BatchState.CompletedNodes + deltaFailed := currentFailed - c.BatchState.FailedNodes + + // Only evaluate if there's actually a change (batch was processed) + if deltaCompleted == 0 && deltaFailed == 0 { + return false, 0, 0 + } + + // Update checkpoints + c.BatchState.CompletedNodes = currentCompleted + c.BatchState.FailedNodes = currentFailed - return true, successCount, failureCount + return true, deltaCompleted, deltaFailed } // EvaluateAndUpdateBatchState evaluates a completed batch and updates the persistent state @@ -223,10 +210,8 @@ func (c *Compartment) EvaluateAndUpdateBatchState(batchSize int, successCount in c.Strategy.EvaluateBatchResult(&c.BatchState, batchSize, successCount, failureCount, len(c.Nodes)) } else { // No strategy: just update basic counters - c.BatchState.ProcessedNodes += batchSize - c.BatchState.SuccessfulInBatch = successCount - c.BatchState.FailedInBatch = failureCount c.BatchState.CurrentBatch++ + c.BatchState.LastBatchSize = batchSize } } diff --git a/operator/internal/wrapper/compartment_test.go b/operator/internal/wrapper/compartment_test.go index 33e3cb9d..df2869ec 100644 --- a/operator/internal/wrapper/compartment_test.go +++ b/operator/internal/wrapper/compartment_test.go @@ -92,7 +92,8 @@ var _ = Describe("Compartment", func() { batchState := &v1alpha1.BatchProcessingState{ CurrentBatch: 3, ConsecutiveFailures: 1, - ProcessedNodes: 5, + CompletedNodes: 4, + FailedNodes: 1, } compartment := NewCompartmentWrapper(&v1alpha1.Compartment{ @@ -103,7 +104,8 @@ var _ = Describe("Compartment", func() { state := compartment.GetBatchState() Expect(state.CurrentBatch).To(Equal(3)) Expect(state.ConsecutiveFailures).To(Equal(1)) - Expect(state.ProcessedNodes).To(Equal(5)) + Expect(state.CompletedNodes).To(Equal(4)) + Expect(state.FailedNodes).To(Equal(1)) }) It("should create compartment with default batch state when nil", func() { @@ -115,7 +117,8 @@ var _ = Describe("Compartment", func() { state := compartment.GetBatchState() Expect(state.CurrentBatch).To(Equal(1)) Expect(state.ConsecutiveFailures).To(Equal(0)) - Expect(state.ProcessedNodes).To(Equal(0)) + Expect(state.CompletedNodes).To(Equal(0)) + Expect(state.FailedNodes).To(Equal(0)) }) }) @@ -126,16 +129,15 @@ var _ = Describe("Compartment", func() { Budget: v1alpha1.DeploymentBudget{Count: ptr.To(10)}, }, &v1alpha1.BatchProcessingState{ CurrentBatch: 1, - ProcessedNodes: 0, + CompletedNodes: 0, + FailedNodes: 0, }) compartment.EvaluateAndUpdateBatchState(3, 2, 1) state := compartment.GetBatchState() - Expect(state.ProcessedNodes).To(Equal(3)) Expect(state.CurrentBatch).To(Equal(2)) - Expect(state.SuccessfulInBatch).To(Equal(2)) - Expect(state.FailedInBatch).To(Equal(1)) + Expect(state.LastBatchSize).To(Equal(3)) }) It("should reset consecutive failures on successful batch", func() { @@ -154,7 +156,8 @@ var _ = Describe("Compartment", func() { Strategy: strategy, }, &v1alpha1.BatchProcessingState{ CurrentBatch: 1, - ProcessedNodes: 0, + CompletedNodes: 4, // Simulating cumulative state after batch evaluation + FailedNodes: 1, ConsecutiveFailures: 1, // Should reset on success }) @@ -163,7 +166,7 @@ var _ = Describe("Compartment", func() { compartment.Nodes = append(compartment.Nodes, nil) } - // 80% success (4 out of 5) + // 80% success (4 out of 5) - using delta values compartment.EvaluateAndUpdateBatchState(5, 4, 1) state := compartment.GetBatchState() @@ -187,7 +190,8 @@ var _ = Describe("Compartment", func() { Strategy: strategy, }, &v1alpha1.BatchProcessingState{ CurrentBatch: 2, - ProcessedNodes: 1, // After adding 3 more: (1+3)/10 = 40% (below 50% safety limit) + CompletedNodes: 1, // After this batch: (1+3)/10 = 40% (below 50% safety limit) + FailedNodes: 0, // Will add 2 more ConsecutiveFailures: 1, // Will increment to 2 (threshold) }) @@ -220,7 +224,8 @@ var _ = Describe("Compartment", func() { Strategy: strategy, }, &v1alpha1.BatchProcessingState{ CurrentBatch: 3, - ProcessedNodes: 6, // 60% progress (above 50% safety limit) + CompletedNodes: 4, // After this batch: (4+2+3)/10 = 90% but we use cumulative + FailedNodes: 2, // Total 6 processed, 60% (above 50% safety limit) ConsecutiveFailures: 1, }) @@ -230,6 +235,8 @@ var _ = Describe("Compartment", func() { } // 40% success (2 out of 5) - below 80% threshold, but above safety limit + // After evaluation: CompletedNodes would be 6, FailedNodes would be 5, total 11 processed + // For this test, we assume deltas add to existing: 4+2=6 complete, 2+3=5 failed = 11/10 compartment.EvaluateAndUpdateBatchState(5, 2, 3) state := compartment.GetBatchState() From 93e3bae908068bc865484edb865756a047839e01 Mon Sep 17 00:00:00 2001 From: Tommy Lam Date: Thu, 9 Oct 2025 10:28:31 -0700 Subject: [PATCH 4/8] update crd --- chart/templates/skyhook-crd.yaml | 23 ++++++++--------------- 1 file changed, 8 insertions(+), 15 deletions(-) diff --git a/chart/templates/skyhook-crd.yaml b/chart/templates/skyhook-crd.yaml index 0b287bef..f9d5c854 100644 --- a/chart/templates/skyhook-crd.yaml +++ b/chart/templates/skyhook-crd.yaml @@ -503,37 +503,30 @@ spec: description: BatchProcessingState tracks the current state of batch processing for a compartment properties: + completedNodes: + description: Total number of nodes that have completed successfully + (cumulative across all batches) + type: integer consecutiveFailures: description: Number of consecutive failures type: integer currentBatch: description: Current batch number (starts at 1) type: integer - currentBatchNodes: - description: Names of nodes in the current batch (persisted - across reconciles) - items: - type: string - type: array - failedInBatch: - description: Number of failed nodes in current batch + failedNodes: + description: Total number of nodes that have failed (cumulative + across all batches) type: integer lastBatchFailed: description: Whether the last batch failed (for slowdown logic) type: boolean lastBatchSize: - description: Last successful batch size (for slowdown calculations) - type: integer - processedNodes: - description: Total number of nodes processed so far + description: Last batch size (for slowdown calculations) type: integer shouldStop: description: Whether the strategy should stop processing due to failures type: boolean - successfulInBatch: - description: Number of successful nodes in current batch - type: integer type: object description: CompartmentBatchStates tracks batch processing state per compartment From 003163df6230644a49f6be949506b188c12688cb Mon Sep 17 00:00:00 2001 From: Tommy Lam Date: Thu, 9 Oct 2025 11:05:37 -0700 Subject: [PATCH 5/8] fix growth logic --- .../api/v1alpha1/deployment_policy_types.go | 80 +++++++++++-------- 1 file changed, 45 insertions(+), 35 deletions(-) diff --git a/operator/api/v1alpha1/deployment_policy_types.go b/operator/api/v1alpha1/deployment_policy_types.go index aec34f3c..0d61b0f8 100644 --- a/operator/api/v1alpha1/deployment_policy_types.go +++ b/operator/api/v1alpha1/deployment_policy_types.go @@ -25,7 +25,6 @@ package v1alpha1 import ( "fmt" - "math" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/utils/ptr" @@ -265,7 +264,7 @@ func (s *DeploymentStrategy) CalculateBatchSize(totalNodes int, state *BatchProc } } -// EvaluateBatchResult evaluates the result of a batch and updates state +// EvaluateBatchResult evaluates the result of a batch and records the outcome func (s *DeploymentStrategy) EvaluateBatchResult(state *BatchProcessingState, batchSize int, successCount int, failureCount int, totalNodes int) { // Note: successCount and failureCount are deltas from the current batch // CompletedNodes and FailedNodes are already updated in EvaluateCurrentBatch before this is called @@ -285,18 +284,21 @@ func (s *DeploymentStrategy) EvaluateBatchResult(state *BatchProcessingState, ba progressPercent = (processedNodes * 100) / totalNodes } - if successPercentage >= s.getBatchThreshold() { - state.ConsecutiveFailures = 0 - state.LastBatchFailed = false - } else { + // Record the batch outcome + batchFailed := successPercentage < s.getBatchThreshold() + state.LastBatchSize = batchSize + state.LastBatchFailed = batchFailed + + if batchFailed { state.ConsecutiveFailures++ - state.LastBatchFailed = true + // Check if we should stop processing if progressPercent < s.getSafetyLimit() && state.ConsecutiveFailures >= s.getFailureThreshold() { state.ShouldStop = true } + } else { + state.ConsecutiveFailures = 0 } - state.LastBatchSize = batchSize state.CurrentBatch++ } @@ -354,24 +356,30 @@ func (s *FixedStrategy) CalculateBatchSize(totalNodes int, state *BatchProcessin } func (s *LinearStrategy) CalculateBatchSize(totalNodes int, state *BatchProcessingState) int { - var batchSize int - // Avoid divide by zero if totalNodes == 0 { return 0 } - // Check if we should slow down due to last batch failure - processedNodes := state.CompletedNodes + state.FailedNodes - progressPercent := (processedNodes * 100) / totalNodes - if state.LastBatchFailed && progressPercent < *s.SafetyLimit && state.LastBatchSize > 0 { - // Slow down: reduce by delta from last batch size - batchSize = max(1, state.LastBatchSize-*s.Delta) + var batchSize int + if state.LastBatchSize > 0 { + // Calculate next size based on last batch outcome + processedNodes := state.CompletedNodes + state.FailedNodes + progressPercent := (processedNodes * 100) / totalNodes + + if state.LastBatchFailed && progressPercent < *s.SafetyLimit { + // Slow down: reduce by delta + batchSize = max(1, state.LastBatchSize-*s.Delta) + } else { + // Normal growth: grow by delta + batchSize = state.LastBatchSize + *s.Delta + } } else { - // Normal growth: initialBatch + (currentBatch - 1) * delta - batchSize = *s.InitialBatch + (state.CurrentBatch-1)*(*s.Delta) + // First batch: use initial batch size + batchSize = *s.InitialBatch } + processedNodes := state.CompletedNodes + state.FailedNodes remaining := totalNodes - processedNodes if batchSize > remaining { batchSize = remaining @@ -380,33 +388,35 @@ func (s *LinearStrategy) CalculateBatchSize(totalNodes int, state *BatchProcessi } func (s *ExponentialStrategy) CalculateBatchSize(totalNodes int, state *BatchProcessingState) int { - var batchSize int - // Avoid divide by zero if totalNodes == 0 { return 0 } - // Check if we should slow down due to last batch failure - processedNodes := state.CompletedNodes + state.FailedNodes - progressPercent := (processedNodes * 100) / totalNodes - if state.LastBatchFailed && progressPercent < *s.SafetyLimit && state.LastBatchSize > 0 && *s.GrowthFactor > 0 { - // Slow down: divide last batch size by growth factor - batchSize = max(1, state.LastBatchSize / *s.GrowthFactor) - } else { - // Normal growth: initialBatch * (growthFactor ^ (currentBatch - 1)) - // Use math.Pow for efficiency and to avoid overflow issues with large batch numbers - exponent := state.CurrentBatch - 1 - growthMultiplier := math.Pow(float64(*s.GrowthFactor), float64(exponent)) - batchSize = int(float64(*s.InitialBatch) * growthMultiplier) - - // Cap at remaining nodes to prevent unreasonably large batch sizes - // and potential integer overflow + var batchSize int + if state.LastBatchSize > 0 && *s.GrowthFactor > 0 { + // Calculate next size based on last batch outcome + processedNodes := state.CompletedNodes + state.FailedNodes + progressPercent := (processedNodes * 100) / totalNodes + + if state.LastBatchFailed && progressPercent < *s.SafetyLimit { + // Slow down: divide by growth factor + batchSize = max(1, state.LastBatchSize / *s.GrowthFactor) + } else { + // Normal growth: multiply by growth factor + batchSize = state.LastBatchSize * *s.GrowthFactor + } + + // Cap at total nodes to prevent unreasonably large batch sizes if batchSize > totalNodes { batchSize = totalNodes } + } else { + // First batch: use initial batch size + batchSize = *s.InitialBatch } + processedNodes := state.CompletedNodes + state.FailedNodes remaining := totalNodes - processedNodes if batchSize > remaining { batchSize = remaining From aae89c4fd0fd070adc519136ad7f2cbb19184461 Mon Sep 17 00:00:00 2001 From: Tommy Lam Date: Mon, 13 Oct 2025 10:28:38 -0700 Subject: [PATCH 6/8] move persist func to a method off of skyhooknodes --- .../api/v1alpha1/zz_generated.deepcopy.go | 4 +- .../internal/controller/cluster_state_v2.go | 13 +- .../controller/cluster_state_v2_test.go | 131 ++++++++++++++++++ .../internal/controller/mock/SkyhookNodes.go | 44 ++++++ .../internal/controller/skyhook_controller.go | 2 +- 5 files changed, 185 insertions(+), 9 deletions(-) diff --git a/operator/api/v1alpha1/zz_generated.deepcopy.go b/operator/api/v1alpha1/zz_generated.deepcopy.go index e403467d..0619a4ee 100644 --- a/operator/api/v1alpha1/zz_generated.deepcopy.go +++ b/operator/api/v1alpha1/zz_generated.deepcopy.go @@ -1,3 +1,5 @@ +//go:build !ignore_autogenerated + /* * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: Apache-2.0 @@ -16,8 +18,6 @@ * limitations under the License. */ -//go:build !ignore_autogenerated - // Code generated by controller-gen. DO NOT EDIT. package v1alpha1 diff --git a/operator/internal/controller/cluster_state_v2.go b/operator/internal/controller/cluster_state_v2.go index 36057a77..5bebe4c3 100644 --- a/operator/internal/controller/cluster_state_v2.go +++ b/operator/internal/controller/cluster_state_v2.go @@ -193,6 +193,7 @@ type SkyhookNodes interface { GetCompartments() map[string]*wrapper.Compartment AddCompartment(name string, compartment *wrapper.Compartment) AddCompartmentNode(name string, node wrapper.SkyhookNode) + PersistCompartmentBatchStates() bool } var _ SkyhookNodes = &skyhookNodes{} @@ -466,15 +467,15 @@ func (np *NodePicker) selectNodesWithCompartments(s SkyhookNodes, compartments m } // PersistCompartmentBatchStates saves the current batch state for all compartments to the Skyhook status -func PersistCompartmentBatchStates(skyhook SkyhookNodes) bool { - compartments := skyhook.GetCompartments() +func (s *skyhookNodes) PersistCompartmentBatchStates() bool { + compartments := s.GetCompartments() if len(compartments) == 0 { return false // No compartments, nothing to persist } // Initialize the batch states map if needed - if skyhook.GetSkyhook().Status.CompartmentBatchStates == nil { - skyhook.GetSkyhook().Status.CompartmentBatchStates = make(map[string]v1alpha1.BatchProcessingState) + if s.skyhook.Status.CompartmentBatchStates == nil { + s.skyhook.Status.CompartmentBatchStates = make(map[string]v1alpha1.BatchProcessingState) } changed := false @@ -483,13 +484,13 @@ func PersistCompartmentBatchStates(skyhook SkyhookNodes) bool { batchState := compartment.GetBatchState() // Only persist if there's meaningful state (batch has started or there are nodes) if batchState.CurrentBatch > 0 || len(compartment.GetNodes()) > 0 { - skyhook.GetSkyhook().Status.CompartmentBatchStates[compartment.GetName()] = batchState + s.skyhook.Status.CompartmentBatchStates[compartment.GetName()] = batchState changed = true } } if changed { - skyhook.GetSkyhook().Updated = true + s.skyhook.Updated = true } return changed diff --git a/operator/internal/controller/cluster_state_v2_test.go b/operator/internal/controller/cluster_state_v2_test.go index 634bd921..8fc26761 100644 --- a/operator/internal/controller/cluster_state_v2_test.go +++ b/operator/internal/controller/cluster_state_v2_test.go @@ -27,6 +27,7 @@ import ( "github.com/NVIDIA/skyhook/operator/internal/wrapper" corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + kptr "k8s.io/utils/ptr" ) var _ = Describe("cluster state v2 tests", func() { @@ -466,4 +467,134 @@ var _ = Describe("CleanupRemovedNodes", func() { Expect(result).To(BeFalse()) }) }) + + Describe("PersistCompartmentBatchStates", func() { + var skyhook *wrapper.Skyhook + var sn *skyhookNodes + + BeforeEach(func() { + skyhook = &wrapper.Skyhook{ + Skyhook: &v1alpha1.Skyhook{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-skyhook", + }, + Status: v1alpha1.SkyhookStatus{}, + }, + } + + sn = &skyhookNodes{ + skyhook: skyhook, + nodes: []wrapper.SkyhookNode{}, + compartments: make(map[string]*wrapper.Compartment), + } + }) + + It("should return false when there are no compartments", func() { + result := sn.PersistCompartmentBatchStates() + Expect(result).To(BeFalse()) + Expect(skyhook.Updated).To(BeFalse()) + }) + + It("should persist batch state when compartment has CurrentBatch > 0", func() { + // Create a compartment with batch state + batchState := &v1alpha1.BatchProcessingState{ + CurrentBatch: 1, + CompletedNodes: 4, + FailedNodes: 1, + } + compartment := wrapper.NewCompartmentWrapper(&v1alpha1.Compartment{ + Name: "compartment1", + Budget: v1alpha1.DeploymentBudget{ + Count: kptr.To(10), + }, + Strategy: &v1alpha1.DeploymentStrategy{ + Fixed: &v1alpha1.FixedStrategy{InitialBatch: kptr.To(5)}, + }, + }, batchState) + + sn.AddCompartment("compartment1", compartment) + + result := sn.PersistCompartmentBatchStates() + + Expect(result).To(BeTrue()) + Expect(skyhook.Updated).To(BeTrue()) + Expect(skyhook.Status.CompartmentBatchStates).ToNot(BeNil()) + Expect(skyhook.Status.CompartmentBatchStates).To(HaveKey("compartment1")) + Expect(skyhook.Status.CompartmentBatchStates["compartment1"].CurrentBatch).To(Equal(1)) + Expect(skyhook.Status.CompartmentBatchStates["compartment1"].CompletedNodes).To(Equal(4)) + }) + + It("should persist batch state when compartment has nodes", func() { + // Create a compartment with nodes but no batch started yet + compartment := wrapper.NewCompartmentWrapper(&v1alpha1.Compartment{ + Name: "compartment1", + Budget: v1alpha1.DeploymentBudget{ + Count: kptr.To(10), + }, + Strategy: &v1alpha1.DeploymentStrategy{ + Fixed: &v1alpha1.FixedStrategy{InitialBatch: kptr.To(5)}, + }, + }, nil) + + // Add a node to the compartment + node := &corev1.Node{ObjectMeta: metav1.ObjectMeta{Name: "node1"}} + skyhookNode, err := wrapper.NewSkyhookNode(node, skyhook.Skyhook) + Expect(err).NotTo(HaveOccurred()) + compartment.AddNode(skyhookNode) + + sn.AddCompartment("compartment1", compartment) + + result := sn.PersistCompartmentBatchStates() + + Expect(result).To(BeTrue()) + Expect(skyhook.Updated).To(BeTrue()) + Expect(skyhook.Status.CompartmentBatchStates).ToNot(BeNil()) + Expect(skyhook.Status.CompartmentBatchStates).To(HaveKey("compartment1")) + }) + + It("should persist multiple compartments with meaningful state", func() { + // Create multiple compartments + batchState1 := &v1alpha1.BatchProcessingState{ + CurrentBatch: 1, + CompletedNodes: 5, + FailedNodes: 0, + } + compartment1 := wrapper.NewCompartmentWrapper(&v1alpha1.Compartment{ + Name: "compartment1", + Budget: v1alpha1.DeploymentBudget{ + Count: kptr.To(10), + }, + Strategy: &v1alpha1.DeploymentStrategy{ + Fixed: &v1alpha1.FixedStrategy{InitialBatch: kptr.To(5)}, + }, + }, batchState1) + + batchState2 := &v1alpha1.BatchProcessingState{ + CurrentBatch: 2, + CompletedNodes: 8, + FailedNodes: 2, + } + compartment2 := wrapper.NewCompartmentWrapper(&v1alpha1.Compartment{ + Name: "compartment2", + Budget: v1alpha1.DeploymentBudget{ + Count: kptr.To(5), + }, + Strategy: &v1alpha1.DeploymentStrategy{ + Linear: &v1alpha1.LinearStrategy{}, + }, + }, batchState2) + + sn.AddCompartment("compartment1", compartment1) + sn.AddCompartment("compartment2", compartment2) + + result := sn.PersistCompartmentBatchStates() + + Expect(result).To(BeTrue()) + Expect(skyhook.Updated).To(BeTrue()) + Expect(skyhook.Status.CompartmentBatchStates).ToNot(BeNil()) + Expect(skyhook.Status.CompartmentBatchStates).To(HaveLen(2)) + Expect(skyhook.Status.CompartmentBatchStates["compartment1"].CurrentBatch).To(Equal(1)) + Expect(skyhook.Status.CompartmentBatchStates["compartment2"].CurrentBatch).To(Equal(2)) + }) + }) }) diff --git a/operator/internal/controller/mock/SkyhookNodes.go b/operator/internal/controller/mock/SkyhookNodes.go index d4d1e4c3..053e9c6f 100644 --- a/operator/internal/controller/mock/SkyhookNodes.go +++ b/operator/internal/controller/mock/SkyhookNodes.go @@ -703,6 +703,50 @@ func (_c *MockSkyhookNodes_NodeCount_Call) RunAndReturn(run func() int) *MockSky return _c } +// PersistCompartmentBatchStates provides a mock function for the type MockSkyhookNodes +func (_mock *MockSkyhookNodes) PersistCompartmentBatchStates() bool { + ret := _mock.Called() + + if len(ret) == 0 { + panic("no return value specified for PersistCompartmentBatchStates") + } + + var r0 bool + if returnFunc, ok := ret.Get(0).(func() bool); ok { + r0 = returnFunc() + } else { + r0 = ret.Get(0).(bool) + } + return r0 +} + +// MockSkyhookNodes_PersistCompartmentBatchStates_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'PersistCompartmentBatchStates' +type MockSkyhookNodes_PersistCompartmentBatchStates_Call struct { + *mock.Call +} + +// PersistCompartmentBatchStates is a helper method to define mock.On call +func (_e *MockSkyhookNodes_Expecter) PersistCompartmentBatchStates() *MockSkyhookNodes_PersistCompartmentBatchStates_Call { + return &MockSkyhookNodes_PersistCompartmentBatchStates_Call{Call: _e.mock.On("PersistCompartmentBatchStates")} +} + +func (_c *MockSkyhookNodes_PersistCompartmentBatchStates_Call) Run(run func()) *MockSkyhookNodes_PersistCompartmentBatchStates_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockSkyhookNodes_PersistCompartmentBatchStates_Call) Return(b bool) *MockSkyhookNodes_PersistCompartmentBatchStates_Call { + _c.Call.Return(b) + return _c +} + +func (_c *MockSkyhookNodes_PersistCompartmentBatchStates_Call) RunAndReturn(run func() bool) *MockSkyhookNodes_PersistCompartmentBatchStates_Call { + _c.Call.Return(run) + return _c +} + // ReportState provides a mock function for the type MockSkyhookNodes func (_mock *MockSkyhookNodes) ReportState() { _mock.Called() diff --git a/operator/internal/controller/skyhook_controller.go b/operator/internal/controller/skyhook_controller.go index 102adc44..88efccac 100644 --- a/operator/internal/controller/skyhook_controller.go +++ b/operator/internal/controller/skyhook_controller.go @@ -562,7 +562,7 @@ func (r *SkyhookReconciler) RunSkyhookPackages(ctx context.Context, clusterState selectedNode := nodePicker.SelectNodes(skyhook) // Persist compartment batch states after node selection - PersistCompartmentBatchStates(skyhook) + skyhook.PersistCompartmentBatchStates() for _, node := range selectedNode { From a0e563fdbc26ef6584f0bb791e6e6bf29a07dd4f Mon Sep 17 00:00:00 2001 From: Tommy Lam Date: Mon, 13 Oct 2025 16:07:48 -0700 Subject: [PATCH 7/8] add introspect skyhook uts --- .../controller/cluster_state_v2_test.go | 208 ++++++++++++++++++ 1 file changed, 208 insertions(+) diff --git a/operator/internal/controller/cluster_state_v2_test.go b/operator/internal/controller/cluster_state_v2_test.go index 8fc26761..3ba5d799 100644 --- a/operator/internal/controller/cluster_state_v2_test.go +++ b/operator/internal/controller/cluster_state_v2_test.go @@ -597,4 +597,212 @@ var _ = Describe("CleanupRemovedNodes", func() { Expect(skyhook.Status.CompartmentBatchStates["compartment2"].CurrentBatch).To(Equal(2)) }) }) + + Describe("IntrospectSkyhook", func() { + var testSkyhook *v1alpha1.Skyhook + var testNode *corev1.Node + + BeforeEach(func() { + testSkyhook = &v1alpha1.Skyhook{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-skyhook", + Annotations: map[string]string{}, + }, + Spec: v1alpha1.SkyhookSpec{ + Packages: map[string]v1alpha1.Package{ + "test-package": { + PackageRef: v1alpha1.PackageRef{Name: "test-package", Version: "1.0.0"}, + Image: "test-image", + }, + }, + }, + Status: v1alpha1.SkyhookStatus{ + Status: v1alpha1.StatusInProgress, + }, + } + + testNode = &corev1.Node{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-node", + }, + } + }) + + It("should set status to disabled when skyhook is disabled", func() { + // Set up the skyhook as disabled + testSkyhook.Annotations["skyhook.nvidia.com/disable"] = "true" + + skyhookNode, err := wrapper.NewSkyhookNode(testNode, testSkyhook) + Expect(err).NotTo(HaveOccurred()) + + skyhookNodes := &skyhookNodes{ + skyhook: wrapper.NewSkyhookWrapper(testSkyhook), + nodes: []wrapper.SkyhookNode{skyhookNode}, + } + + // Call the function + changed := IntrospectSkyhook(skyhookNodes, []SkyhookNodes{skyhookNodes}) + + // Verify the result + Expect(changed).To(BeTrue()) + Expect(skyhookNodes.Status()).To(Equal(v1alpha1.StatusDisabled)) + }) + + It("should set status to paused when skyhook is paused", func() { + // Set up the skyhook as paused + testSkyhook.Annotations["skyhook.nvidia.com/pause"] = "true" + + skyhookNode, err := wrapper.NewSkyhookNode(testNode, testSkyhook) + Expect(err).NotTo(HaveOccurred()) + + skyhookNodes := &skyhookNodes{ + skyhook: wrapper.NewSkyhookWrapper(testSkyhook), + nodes: []wrapper.SkyhookNode{skyhookNode}, + } + + // Call the function + changed := IntrospectSkyhook(skyhookNodes, []SkyhookNodes{skyhookNodes}) + + // Verify the result + Expect(changed).To(BeTrue()) + Expect(skyhookNodes.Status()).To(Equal(v1alpha1.StatusPaused)) + }) + + It("should set status to waiting when another skyhook has higher priority", func() { + // Create higher priority skyhook (priority 1) + higherPrioritySkyhook := &v1alpha1.Skyhook{ + ObjectMeta: metav1.ObjectMeta{Name: "skyhook-1"}, + Spec: v1alpha1.SkyhookSpec{ + Priority: 1, + Packages: map[string]v1alpha1.Package{ + "test-package-1": { + PackageRef: v1alpha1.PackageRef{Name: "test-package-1", Version: "1.0.0"}, + Image: "test-image-1", + }, + }, + }, + Status: v1alpha1.SkyhookStatus{Status: v1alpha1.StatusInProgress}, + } + + // Create lower priority skyhook (priority 2) + lowerPrioritySkyhook := &v1alpha1.Skyhook{ + ObjectMeta: metav1.ObjectMeta{Name: "skyhook-2"}, + Spec: v1alpha1.SkyhookSpec{ + Priority: 2, + Packages: map[string]v1alpha1.Package{ + "test-package-2": { + PackageRef: v1alpha1.PackageRef{Name: "test-package-2", Version: "1.0.0"}, + Image: "test-image-2", + }, + }, + }, + Status: v1alpha1.SkyhookStatus{Status: v1alpha1.StatusInProgress}, + } + + node1 := &corev1.Node{ObjectMeta: metav1.ObjectMeta{Name: "node-1"}} + node2 := &corev1.Node{ObjectMeta: metav1.ObjectMeta{Name: "node-2"}} + + skyhookNode1, err := wrapper.NewSkyhookNode(node1, higherPrioritySkyhook) + Expect(err).NotTo(HaveOccurred()) + + skyhookNode2, err := wrapper.NewSkyhookNode(node2, lowerPrioritySkyhook) + Expect(err).NotTo(HaveOccurred()) + + skyhookNodes1 := &skyhookNodes{ + skyhook: wrapper.NewSkyhookWrapper(higherPrioritySkyhook), + nodes: []wrapper.SkyhookNode{skyhookNode1}, + } + + skyhookNodes2 := &skyhookNodes{ + skyhook: wrapper.NewSkyhookWrapper(lowerPrioritySkyhook), + nodes: []wrapper.SkyhookNode{skyhookNode2}, + } + + allSkyhooks := []SkyhookNodes{skyhookNodes1, skyhookNodes2} + + // Call the function - skyhook2 should be waiting because skyhook1 has higher priority + changed := IntrospectSkyhook(skyhookNodes2, allSkyhooks) + + // Verify the result + Expect(changed).To(BeTrue()) + Expect(skyhookNodes2.Status()).To(Equal(v1alpha1.StatusWaiting)) + }) + + It("should not change status when skyhook is complete", func() { + // Create a complete skyhook with no packages + completeSkyhook := &v1alpha1.Skyhook{ + ObjectMeta: metav1.ObjectMeta{Name: "test-skyhook"}, + Status: v1alpha1.SkyhookStatus{Status: v1alpha1.StatusComplete}, + } + + node := &corev1.Node{ + ObjectMeta: metav1.ObjectMeta{Name: "test-node"}, + Status: corev1.NodeStatus{ + Conditions: []corev1.NodeCondition{ + {Type: corev1.NodeReady, Status: corev1.ConditionTrue}, + }, + }, + } + + skyhookNode, err := wrapper.NewSkyhookNode(node, completeSkyhook) + Expect(err).NotTo(HaveOccurred()) + skyhookNode.SetStatus(v1alpha1.StatusComplete) + + skyhookNodes := &skyhookNodes{ + skyhook: wrapper.NewSkyhookWrapper(completeSkyhook), + nodes: []wrapper.SkyhookNode{skyhookNode}, + } + + // Call the function + _ = IntrospectSkyhook(skyhookNodes, []SkyhookNodes{skyhookNodes}) + + // Verify the result - status should stay complete + Expect(skyhookNodes.Status()).To(Equal(v1alpha1.StatusComplete)) + }) + + It("should return true when node status changes", func() { + skyhookNode, err := wrapper.NewSkyhookNode(testNode, testSkyhook) + Expect(err).NotTo(HaveOccurred()) + skyhookNode.SetStatus(v1alpha1.StatusUnknown) + + skyhookNodes := &skyhookNodes{ + skyhook: wrapper.NewSkyhookWrapper(testSkyhook), + nodes: []wrapper.SkyhookNode{skyhookNode}, + } + + // Call the function + changed := IntrospectSkyhook(skyhookNodes, []SkyhookNodes{skyhookNodes}) + + // Verify the result + Expect(changed).To(BeTrue()) + }) + + It("should handle multiple nodes correctly when disabled", func() { + // Set up the skyhook as disabled + testSkyhook.Annotations["skyhook.nvidia.com/disable"] = "true" + + node1 := &corev1.Node{ObjectMeta: metav1.ObjectMeta{Name: "node-1"}} + node2 := &corev1.Node{ObjectMeta: metav1.ObjectMeta{Name: "node-2"}} + + skyhookNode1, err := wrapper.NewSkyhookNode(node1, testSkyhook) + Expect(err).NotTo(HaveOccurred()) + + skyhookNode2, err := wrapper.NewSkyhookNode(node2, testSkyhook) + Expect(err).NotTo(HaveOccurred()) + + skyhookNodes := &skyhookNodes{ + skyhook: wrapper.NewSkyhookWrapper(testSkyhook), + nodes: []wrapper.SkyhookNode{skyhookNode1, skyhookNode2}, + } + + // Call the function + changed := IntrospectSkyhook(skyhookNodes, []SkyhookNodes{skyhookNodes}) + + // Verify the result + Expect(changed).To(BeTrue()) + Expect(skyhookNodes.Status()).To(Equal(v1alpha1.StatusDisabled)) + Expect(skyhookNode1.Status()).To(Equal(v1alpha1.StatusDisabled)) + Expect(skyhookNode2.Status()).To(Equal(v1alpha1.StatusDisabled)) + }) + }) }) From 86c91039d29fe856528c3c0ed5ba35807d0e26ca Mon Sep 17 00:00:00 2001 From: Tommy Lam Date: Mon, 13 Oct 2025 16:13:17 -0700 Subject: [PATCH 8/8] make true a const to satisfy linter --- .../internal/controller/cluster_state_v2_test.go | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/operator/internal/controller/cluster_state_v2_test.go b/operator/internal/controller/cluster_state_v2_test.go index 3ba5d799..80a2c3a6 100644 --- a/operator/internal/controller/cluster_state_v2_test.go +++ b/operator/internal/controller/cluster_state_v2_test.go @@ -30,6 +30,10 @@ import ( kptr "k8s.io/utils/ptr" ) +const ( + annotationTrueValue = "true" +) + var _ = Describe("cluster state v2 tests", func() { It("should check taint toleration", func() { @@ -412,7 +416,7 @@ var _ = Describe("CleanupRemovedNodes", func() { It("should update status to paused when skyhook is paused and status is not already paused", func() { // Set up the skyhook as paused - mockSkyhook.Annotations[v1alpha1.METADATA_PREFIX+"/pause"] = "true" + mockSkyhook.Annotations[v1alpha1.METADATA_PREFIX+"/pause"] = annotationTrueValue // Set up mock expectations mockSkyhookNodes.EXPECT().IsPaused().Return(true) @@ -429,7 +433,7 @@ var _ = Describe("CleanupRemovedNodes", func() { It("should not change status when skyhook is paused but status is already paused", func() { // Set up the skyhook as paused with paused status - mockSkyhook.Annotations[v1alpha1.METADATA_PREFIX+"/pause"] = "true" + mockSkyhook.Annotations[v1alpha1.METADATA_PREFIX+"/pause"] = annotationTrueValue // Set up mock expectations mockSkyhookNodes.EXPECT().IsPaused().Return(true) @@ -630,7 +634,7 @@ var _ = Describe("CleanupRemovedNodes", func() { It("should set status to disabled when skyhook is disabled", func() { // Set up the skyhook as disabled - testSkyhook.Annotations["skyhook.nvidia.com/disable"] = "true" + testSkyhook.Annotations["skyhook.nvidia.com/disable"] = annotationTrueValue skyhookNode, err := wrapper.NewSkyhookNode(testNode, testSkyhook) Expect(err).NotTo(HaveOccurred()) @@ -650,7 +654,7 @@ var _ = Describe("CleanupRemovedNodes", func() { It("should set status to paused when skyhook is paused", func() { // Set up the skyhook as paused - testSkyhook.Annotations["skyhook.nvidia.com/pause"] = "true" + testSkyhook.Annotations["skyhook.nvidia.com/pause"] = annotationTrueValue skyhookNode, err := wrapper.NewSkyhookNode(testNode, testSkyhook) Expect(err).NotTo(HaveOccurred()) @@ -779,7 +783,7 @@ var _ = Describe("CleanupRemovedNodes", func() { It("should handle multiple nodes correctly when disabled", func() { // Set up the skyhook as disabled - testSkyhook.Annotations["skyhook.nvidia.com/disable"] = "true" + testSkyhook.Annotations["skyhook.nvidia.com/disable"] = annotationTrueValue node1 := &corev1.Node{ObjectMeta: metav1.ObjectMeta{Name: "node-1"}} node2 := &corev1.Node{ObjectMeta: metav1.ObjectMeta{Name: "node-2"}}