
Commit 4e3ebfd

davidLif and singh1203 authored
Add leaderWorkerSet support - v0.6 (#309)
* Added LWS plugin for LeaderWorkerSet integration (#267)
* feat: (GH:#124) add LWS plugin for LeaderWorkerSet support
  Signed-off-by: Saurabh Kumar Singh <[email protected]>
  Co-authored-by: davidLif <[email protected]>
* Add new type of workload to podgrouper exceptions - DistributedInferenceWorkload (#303)
---------
Signed-off-by: Saurabh Kumar Singh <[email protected]>
Co-authored-by: Saurabh Singh <[email protected]>
1 parent c50fb19 commit 4e3ebfd

File tree

12 files changed: +516 −8 lines changed


CHANGELOG.md

Lines changed: 4 additions & 0 deletions
@@ -6,6 +6,10 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/).
 
 ## [Unreleased]
 
+## [v0.6.7] - 2025-07-07
+### Added
+- Added LeaderWorkerSet support in the podGrouper. Each replica will be given a separate podGroup.
+
 ## [v0.6.6] - 2025-07-06
 
 ### Fixes
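To make "each replica will be given a separate podGroup" concrete: the new plugin (shown later in this commit) appends a "-group-<index>" suffix, taken from the pod's leaderworkerset.sigs.k8s.io/group-index label, to the name the default grouper derives for the LeaderWorkerSet. A small illustrative Go sketch, where the base name "pg-my-lws" is made up:

// Illustrative only: each LeaderWorkerSet replica maps to its own PodGroup
// via a "-group-<index>" suffix; "pg-my-lws" is a hypothetical base name.
package main

import "fmt"

func main() {
    base := "pg-my-lws" // hypothetical name the default grouper would derive for the LWS object
    for _, groupIndex := range []string{"0", "1", "2"} {
        // the index comes from the leaderworkerset.sigs.k8s.io/group-index pod label
        fmt.Printf("replica %s -> PodGroup %s-group-%s\n", groupIndex, base, groupIndex)
    }
}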

deployments/kai-scheduler/templates/rbac/podgrouper.yaml

Lines changed: 17 additions & 0 deletions
@@ -166,6 +166,22 @@ rules:
   - create
   - patch
   - update
+- apiGroups:
+  - leaderworkerset.x-k8s.io
+  resources:
+  - leaderworkersets
+  verbs:
+  - get
+  - list
+  - watch
+- apiGroups:
+  - leaderworkerset.x-k8s.io
+  resources:
+  - leaderworkersets/finalizers
+  verbs:
+  - create
+  - patch
+  - update
 - apiGroups:
   - machinelearning.seldon.io
   resources:
@@ -205,6 +221,7 @@ rules:
 - apiGroups:
   - run.ai
   resources:
+  - distributedinferenceworkloads
   - distributedworkloads
   - inferenceworkloads
   - interactiveworkloads

go.mod

Lines changed: 1 addition & 0 deletions
@@ -63,6 +63,7 @@ require (
 	knative.dev/serving v0.44.0
 	sigs.k8s.io/controller-runtime v0.20.0
 	sigs.k8s.io/karpenter v1.2.0
+	sigs.k8s.io/lws v0.5.1
 )
 
 require (

go.sum

Lines changed: 2 additions & 0 deletions
@@ -602,6 +602,8 @@ sigs.k8s.io/json v0.0.0-20241014173422-cfa47c3a1cc8 h1:gBQPwqORJ8d8/YNZWEjoZs7np
 sigs.k8s.io/json v0.0.0-20241014173422-cfa47c3a1cc8/go.mod h1:mdzfpAEoE6DHQEN0uh9ZbOCuHbLK5wOm7dK4ctXE9Tg=
 sigs.k8s.io/karpenter v1.2.0 h1:y1zyFGzDLiT5OHpG8Jvj4JKKe/iXmJDYeejs8k8OznM=
 sigs.k8s.io/karpenter v1.2.0/go.mod h1:646txj32arNTy+K4gySCqWSljYrEdemAdYoBMQmkS7o=
+sigs.k8s.io/lws v0.5.1 h1:eaeMNkP0manRluQZLN32atoULaGrzP611gSLdFaHZs4=
+sigs.k8s.io/lws v0.5.1/go.mod h1:qprXSTTFnfmPZY3V3sUfk6ZPmAodsdoKS8XVElJ9kN0=
 sigs.k8s.io/structured-merge-diff/v4 v4.5.0 h1:nbCitCK2hfnhyiKo6uf2HxUPTCodY6Qaf85SbDIaMBk=
 sigs.k8s.io/structured-merge-diff/v4 v4.5.0/go.mod h1:N8f93tFZh9U6vpxwRArLiikrE5/2tiu1w1AGfACIGE4=
 sigs.k8s.io/yaml v1.4.0 h1:Mk1wCc2gy/F0THH0TAp1QYyJNzRm2KCLy3o5ASXVI5E=

hack/run-e2e-kind.sh

Lines changed: 1 addition & 0 deletions
@@ -55,6 +55,7 @@ if [ "$TEST_THIRD_PARTY_INTEGRATIONS" = "true" ]; then
     ${REPO_ROOT}/hack/third_party_integrations/deploy_ray.sh
     ${REPO_ROOT}/hack/third_party_integrations/deploy_kubeflow.sh
     ${REPO_ROOT}/hack/third_party_integrations/deploy_knative.sh
+    ${REPO_ROOT}/hack/third_party_integrations/deploy_lws.sh
 fi
 
 if [ "$LOCAL_IMAGES_BUILD" = "true" ]; then
hack/third_party_integrations/deploy_lws.sh

Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
+#!/bin/bash
+# Copyright 2025 NVIDIA CORPORATION
+# SPDX-License-Identifier: Apache-2.0
+set -e
+
+CHART_VERSION=0.6.1
+helm install lws oci://registry.k8s.io/lws/charts/lws --version=$CHART_VERSION --namespace lws-system --create-namespace --wait --timeout 300s

pkg/podgrouper/podgrouper/hub/hub.go

Lines changed: 21 additions & 8 deletions
@@ -21,6 +21,7 @@ import (
 	pytorchplugin "github.com/NVIDIA/KAI-scheduler/pkg/podgrouper/podgrouper/plugins/kubeflow/pytorch"
 	tensorflowlugin "github.com/NVIDIA/KAI-scheduler/pkg/podgrouper/podgrouper/plugins/kubeflow/tensorflow"
 	xgboostplugin "github.com/NVIDIA/KAI-scheduler/pkg/podgrouper/podgrouper/plugins/kubeflow/xgboost"
+	leader_worker_set "github.com/NVIDIA/KAI-scheduler/pkg/podgrouper/podgrouper/plugins/leaderworkerset"
 	"github.com/NVIDIA/KAI-scheduler/pkg/podgrouper/podgrouper/plugins/podjob"
 	"github.com/NVIDIA/KAI-scheduler/pkg/podgrouper/podgrouper/plugins/ray"
 	"github.com/NVIDIA/KAI-scheduler/pkg/podgrouper/podgrouper/plugins/runaijob"
@@ -30,12 +31,13 @@ import (
 )
 
 const (
-	apiGroupArgo            = "argoproj.io"
-	apiGroupRunai           = "run.ai"
-	kindTrainingWorkload    = "TrainingWorkload"
-	kindInteractiveWorkload = "InteractiveWorkload"
-	kindDistributedWorkload = "DistributedWorkload"
-	kindInferenceWorkload   = "InferenceWorkload"
+	apiGroupArgo                     = "argoproj.io"
+	apiGroupRunai                    = "run.ai"
+	kindTrainingWorkload             = "TrainingWorkload"
+	kindInteractiveWorkload          = "InteractiveWorkload"
+	kindDistributedWorkload          = "DistributedWorkload"
+	kindInferenceWorkload            = "InferenceWorkload"
+	kindDistributedInferenceWorkload = "DistributedInferenceWorkload"
 )
 
 // +kubebuilder:rbac:groups=apps,resources=replicasets;statefulsets,verbs=get;list;watch
@@ -50,7 +52,7 @@ const (
 // +kubebuilder:rbac:groups=argoproj.io,resources=workflows/finalizers,verbs=patch;update;create
 // +kubebuilder:rbac:groups=tekton.dev,resources=pipelineruns;taskruns,verbs=get;list;watch
 // +kubebuilder:rbac:groups=tekton.dev,resources=pipelineruns/finalizers;taskruns/finalizers,verbs=patch;update;create
-// +kubebuilder:rbac:groups=run.ai,resources=trainingworkloads;interactiveworkloads;distributedworkloads;inferenceworkloads,verbs=get;list;watch
+// +kubebuilder:rbac:groups=run.ai,resources=trainingworkloads;interactiveworkloads;distributedworkloads;inferenceworkloads;distributedinferenceworkloads,verbs=get;list;watch
 
 type PluginsHub struct {
 	defaultPlugin *defaultgrouper.DefaultGrouper
@@ -238,6 +240,11 @@ func NewPluginsHub(kubeClient client.Client, searchForLegacyPodGroups,
 			Version: "v1",
 			Kind:    "SPOTRequest",
 		}: spotrequest.NewSpotRequestGrouper(defaultGrouper),
+		{
+			Group:   "leaderworkerset.x-k8s.io",
+			Version: "v1",
+			Kind:    "LeaderWorkerSet",
+		}: leader_worker_set.NewLwsGrouper(defaultGrouper),
 	}
 
 	skipTopOwnerGrouper := skiptopowner.NewSkipTopOwnerGrouper(kubeClient, defaultGrouper, table)
@@ -247,7 +254,13 @@
 		Kind: "Workflow",
 	}] = skipTopOwnerGrouper
 
-	for _, kind := range []string{kindInferenceWorkload, kindTrainingWorkload, kindDistributedWorkload, kindInteractiveWorkload} {
+	for _, kind := range []string{
+		kindInferenceWorkload,
+		kindTrainingWorkload,
+		kindDistributedWorkload,
+		kindInteractiveWorkload,
+		kindDistributedInferenceWorkload,
+	} {
 		table[metav1.GroupVersionKind{
 			Group:   apiGroupRunai,
 			Version: "*",
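As a rough sketch of what this registration buys, assuming a GVK-keyed plugin table like the one populated above (the grouper interface and fallback below are simplified stand-ins, not the actual PluginsHub API): a pod whose top owner is a leaderworkerset.x-k8s.io/v1 LeaderWorkerSet now resolves to the LWS grouper instead of the default grouper.

// Minimal sketch of GVK-based plugin resolution; types and fallback are illustrative.
package main

import (
    "fmt"

    metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
)

type grouper interface{ Name() string }

type lwsGrouperStub struct{}

func (lwsGrouperStub) Name() string { return "LWS Grouper" }

type defaultGrouperStub struct{}

func (defaultGrouperStub) Name() string { return "Default Grouper" }

func main() {
    // table mirrors the registration added in this commit
    table := map[metav1.GroupVersionKind]grouper{
        {Group: "leaderworkerset.x-k8s.io", Version: "v1", Kind: "LeaderWorkerSet"}: lwsGrouperStub{},
    }

    topOwner := metav1.GroupVersionKind{Group: "leaderworkerset.x-k8s.io", Version: "v1", Kind: "LeaderWorkerSet"}
    g, ok := table[topOwner]
    if !ok {
        g = defaultGrouperStub{} // fall back when no plugin is registered for the owner kind
    }
    fmt.Println(g.Name()) // prints "LWS Grouper" for a LeaderWorkerSet owner
}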
pkg/podgrouper/podgrouper/plugins/leaderworkerset/ (new file)

Lines changed: 136 additions & 0 deletions
@@ -0,0 +1,136 @@
// Copyright 2025 NVIDIA CORPORATION
// SPDX-License-Identifier: Apache-2.0

package leader_worker_set

import (
    "fmt"
    "strconv"

    v1 "k8s.io/api/core/v1"
    metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
    "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"

    "github.com/NVIDIA/KAI-scheduler/pkg/podgrouper/podgroup"
    "github.com/NVIDIA/KAI-scheduler/pkg/podgrouper/podgrouper/plugins/defaultgrouper"
)

const (
    startupPolicyLeaderReady   = "LeaderReady"
    startupPolicyLeaderCreated = "LeaderCreated"

    // LWS annotation and label keys
    lwsSizeAnnotation   = "leaderworkerset.sigs.k8s.io/size"
    lwsGroupIndexLabel  = "leaderworkerset.sigs.k8s.io/group-index"
    lwsWorkerIndexLabel = "leaderworkerset.sigs.k8s.io/worker-index"
)

type LwsGrouper struct {
    *defaultgrouper.DefaultGrouper
}

func NewLwsGrouper(defaultGrouper *defaultgrouper.DefaultGrouper) *LwsGrouper {
    return &LwsGrouper{
        DefaultGrouper: defaultGrouper,
    }
}

func (lwsGrouper *LwsGrouper) Name() string {
    return "LWS Grouper"
}

// +kubebuilder:rbac:groups=leaderworkerset.x-k8s.io,resources=leaderworkersets,verbs=get;list;watch
// +kubebuilder:rbac:groups=leaderworkerset.x-k8s.io,resources=leaderworkersets/finalizers,verbs=patch;update;create

func (lwsGrouper *LwsGrouper) GetPodGroupMetadata(
    lwsJob *unstructured.Unstructured, pod *v1.Pod, _ ...*metav1.PartialObjectMetadata,
) (*podgroup.Metadata, error) {
    podGroupMetadata, err := lwsGrouper.DefaultGrouper.GetPodGroupMetadata(lwsJob, pod)
    if err != nil {
        return nil, err
    }

    groupSize, err := lwsGrouper.getLwsGroupSize(lwsJob)
    if err != nil {
        return nil, err
    }

    startupPolicy, err := lwsGrouper.getStartupPolicy(lwsJob)
    if err != nil {
        return nil, err
    }

    // Initialize podGroupMetadata with the group size
    switch startupPolicy {
    case startupPolicyLeaderReady:
        if err := handleLeaderReadyPolicy(pod, podGroupMetadata, groupSize); err != nil {
            return nil, fmt.Errorf("error handling leader ready policy: %w", err)
        }
    case startupPolicyLeaderCreated:
        podGroupMetadata.MinAvailable = groupSize
    default:
        return nil, fmt.Errorf("unknown startupPolicy: %s", startupPolicy)
    }

    if groupIndexStr, ok := pod.Labels[lwsGroupIndexLabel]; ok {
        if groupIndex, err := strconv.Atoi(groupIndexStr); err == nil {
            podGroupMetadata.Name = fmt.Sprintf("%s-group-%d", podGroupMetadata.Name, groupIndex)
        }
    }

    return podGroupMetadata, nil
}

func (lwsGrouper *LwsGrouper) getLwsGroupSize(lwsJob *unstructured.Unstructured) (int32, error) {
    size, found, err := unstructured.NestedInt64(lwsJob.Object, "spec", "leaderWorkerTemplate", "size")
    if err != nil {
        return 0, fmt.Errorf("failed to get leaderWorkerTemplate.size from LWS %s/%s with error: %w",
            lwsJob.GetNamespace(), lwsJob.GetName(), err)
    }
    if !found {
        return 0, fmt.Errorf("leaderWorkerTemplate.size not found in LWS %s/%s", lwsJob.GetNamespace(), lwsJob.GetName())
    }
    if size <= 0 {
        return 0, fmt.Errorf("invalid leaderWorkerTemplate.size %d in LWS %s/%s", size, lwsJob.GetNamespace(), lwsJob.GetName())
    }
    return int32(size), nil
}

// getStartupPolicy extracts the startup policy from the LWS object
func (lwsGrouper *LwsGrouper) getStartupPolicy(lwsJob *unstructured.Unstructured) (string, error) {
    policy, found, err := unstructured.NestedString(lwsJob.Object, "spec", "startupPolicy")
    if err != nil {
        return "", fmt.Errorf("failed to get startupPolicy from LWS %s/%s: %w",
            lwsJob.GetNamespace(), lwsJob.GetName(), err)
    }
    if !found {
        // Default to LeaderCreated if not specified
        return startupPolicyLeaderCreated, nil
    }
    return policy, nil
}

func handleLeaderReadyPolicy(pod *v1.Pod, podGroupMetadata *podgroup.Metadata, fallbackSize int32) error {
    groupSize := fallbackSize

    // Check for the size annotation on the pod
    if sizeStr, ok := pod.Annotations[lwsSizeAnnotation]; ok {
        if parsed, err := strconv.Atoi(sizeStr); err == nil {
            groupSize = int32(parsed)
        }
    }

    workerIndex, hasWorkerIndex := pod.Labels[lwsWorkerIndexLabel]
    isLeader := hasWorkerIndex && workerIndex == "0"
    isScheduled := pod.Spec.NodeName != ""

    if isLeader && !isScheduled {
        // Leader pod not yet scheduled, only need leader to be available
        podGroupMetadata.MinAvailable = 1
    } else {
        // Either worker pod or leader is already scheduled
        podGroupMetadata.MinAvailable = groupSize
    }

    return nil
}
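A hedged usage sketch of the LeaderReady branch above, mirroring the setup of the unit tests that follow and assuming, as they do, that defaultgrouper.NewDefaultGrouper("", "") accepts a bare pod: an unscheduled leader pod (worker-index label "0") gets MinAvailable relaxed to 1 so the leader can schedule first, while worker pods, or a leader that is already scheduled, keep MinAvailable at the full group size.

// Hedged sketch, not shipped code: exercises GetPodGroupMetadata for an
// unscheduled leader pod under startupPolicy LeaderReady.
package leader_worker_set

import (
    "fmt"

    v1 "k8s.io/api/core/v1"
    metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
    "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"

    "github.com/NVIDIA/KAI-scheduler/pkg/podgrouper/podgrouper/plugins/defaultgrouper"
)

func exampleLeaderReadyUnscheduledLeader() {
    lws := &unstructured.Unstructured{Object: map[string]interface{}{
        "kind":       "LeaderWorkerSet",
        "apiVersion": "leaderworkerset.x-k8s.io/v1",
        "metadata":   map[string]interface{}{"name": "demo", "namespace": "default", "uid": "demo-uid"},
        "spec": map[string]interface{}{
            "startupPolicy":        "LeaderReady",
            "leaderWorkerTemplate": map[string]interface{}{"size": int64(4)},
        },
    }}

    leaderPod := &v1.Pod{
        ObjectMeta: metav1.ObjectMeta{
            Labels: map[string]string{
                "leaderworkerset.sigs.k8s.io/worker-index": "0", // index 0 marks the leader
                "leaderworkerset.sigs.k8s.io/group-index":  "0",
            },
        },
        // Spec.NodeName left empty: the leader has not been scheduled yet
    }

    grouper := NewLwsGrouper(defaultgrouper.NewDefaultGrouper("", ""))
    md, err := grouper.GetPodGroupMetadata(lws, leaderPod)
    if err != nil {
        panic(err)
    }
    // Expect MinAvailable == 1; the PodGroup name also gains a "-group-0" suffix.
    fmt.Println(md.MinAvailable, md.Name)
}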
pkg/podgrouper/podgrouper/plugins/leaderworkerset/ (new test file)

Lines changed: 116 additions & 0 deletions
@@ -0,0 +1,116 @@
// Copyright 2025 NVIDIA CORPORATION
// SPDX-License-Identifier: Apache-2.0

package leader_worker_set

import (
    "testing"

    "github.com/NVIDIA/KAI-scheduler/pkg/podgrouper/podgrouper/plugins/defaultgrouper"
    "github.com/stretchr/testify/assert"
    v1 "k8s.io/api/core/v1"
    metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
    "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
)

func baseOwner(name string, startupPolicy string, replicas int64) *unstructured.Unstructured {
    return &unstructured.Unstructured{
        Object: map[string]interface{}{
            "kind":       "LeaderWorkerSet",
            "apiVersion": "leaderworkerset.x-k8s.io/v1",
            "metadata": map[string]interface{}{
                "name":      name,
                "namespace": "default",
                "uid":       name + "-uid",
            },
            "spec": map[string]interface{}{
                "startupPolicy": startupPolicy,
                "leaderWorkerTemplate": map[string]interface{}{
                    "size": replicas,
                },
            },
        },
    }
}

func TestGetPodGroupMetadata_LeaderCreated(t *testing.T) {
    owner := baseOwner("lws-test", "LeaderCreated", 3)

    pod := &v1.Pod{
        ObjectMeta: metav1.ObjectMeta{
            Labels: map[string]string{},
        },
    }

    lwsGrouper := NewLwsGrouper(defaultgrouper.NewDefaultGrouper("", ""))
    podGroupMetadata, err := lwsGrouper.GetPodGroupMetadata(owner, pod)

    assert.Nil(t, err)
    assert.Equal(t, int32(3), podGroupMetadata.MinAvailable)
    assert.Equal(t, "LeaderWorkerSet", podGroupMetadata.Owner.Kind)
    assert.Equal(t, "leaderworkerset.x-k8s.io/v1", podGroupMetadata.Owner.APIVersion)
    assert.Equal(t, "lws-test", podGroupMetadata.Owner.Name)
    assert.Equal(t, "lws-test-uid", string(podGroupMetadata.Owner.UID))
}

func TestGetPodGroupMetadata_LeaderReady_LeaderPod(t *testing.T) {
    owner := baseOwner("lws-ready", "LeaderReady", 5)

    pod := &v1.Pod{
        ObjectMeta: metav1.ObjectMeta{
            Annotations: map[string]string{},
            Labels:      map[string]string{},
        },
        Spec: v1.PodSpec{
            NodeName: "", // not scheduled => simulate leader
        },
    }

    lwsGrouper := NewLwsGrouper(defaultgrouper.NewDefaultGrouper("", ""))
    podGroupMetadata, err := lwsGrouper.GetPodGroupMetadata(owner, pod)

    assert.Nil(t, err)
    assert.Equal(t, int32(5), podGroupMetadata.MinAvailable)
}

func TestGetPodGroupMetadata_LeaderReady_WorkerPod(t *testing.T) {
    owner := baseOwner("lws-ready", "LeaderReady", 5)

    pod := &v1.Pod{
        ObjectMeta: metav1.ObjectMeta{
            Annotations: map[string]string{
                "leaderworkerset.sigs.k8s.io/size": "5",
            },
            Labels: map[string]string{
                "leaderworkerset.sigs.k8s.io/group-index": "0",
            },
        },
        Spec: v1.PodSpec{
            NodeName: "worker-node", // scheduled => simulate worker
        },
    }

    lwsGrouper := NewLwsGrouper(defaultgrouper.NewDefaultGrouper("", ""))
    podGroupMetadata, err := lwsGrouper.GetPodGroupMetadata(owner, pod)

    assert.Nil(t, err)
    assert.Equal(t, int32(5), podGroupMetadata.MinAvailable)
}

func TestGetPodGroupMetadata_GroupIndex_Label(t *testing.T) {
    owner := baseOwner("lws-grouped", "LeaderCreated", 2)

    pod := &v1.Pod{
        ObjectMeta: metav1.ObjectMeta{
            Labels: map[string]string{
                "leaderworkerset.sigs.k8s.io/group-index": "1",
            },
        },
    }

    lwsGrouper := NewLwsGrouper(defaultgrouper.NewDefaultGrouper("", ""))
    podGroupMetadata, err := lwsGrouper.GetPodGroupMetadata(owner, pod)

    assert.Nil(t, err)
    assert.Contains(t, podGroupMetadata.Name, "-group-1")
}
