Skip to content

Commit 8637f29

Browse files
authored
Merge pull request #591 from danielvegamyhre/automated-cherry-pick-of-#590-upstream-release-0.5
Automated cherry pick of #590: propagate job pod template updates to suspended jobs when
2 parents 0436215 + d809200 commit 8637f29

File tree

4 files changed

+220
-25
lines changed

4 files changed

+220
-25
lines changed

pkg/controllers/jobset_controller.go

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -388,10 +388,10 @@ func (r *JobSetReconciler) suspendJobs(ctx context.Context, js *jobset.JobSet, a
388388
// resumeJobsIfNecessary iterates through each replicatedJob, resuming any suspended jobs if the JobSet
389389
// is not suspended.
390390
func (r *JobSetReconciler) resumeJobsIfNecessary(ctx context.Context, js *jobset.JobSet, activeJobs []*batchv1.Job, replicatedJobStatuses []jobset.ReplicatedJobStatus, updateStatusOpts *statusUpdateOpts) error {
391-
// Store node selector for each replicatedJob template.
392-
nodeAffinities := map[string]map[string]string{}
391+
// Store pod template for each replicatedJob.
392+
replicatedJobTemplateMap := map[string]corev1.PodTemplateSpec{}
393393
for _, replicatedJob := range js.Spec.ReplicatedJobs {
394-
nodeAffinities[replicatedJob.Name] = replicatedJob.Template.Spec.Template.Spec.NodeSelector
394+
replicatedJobTemplateMap[replicatedJob.Name] = replicatedJob.Template.Spec.Template
395395
}
396396

397397
// Map each replicatedJob to a list of its active jobs.
@@ -415,7 +415,7 @@ func (r *JobSetReconciler) resumeJobsIfNecessary(ctx context.Context, js *jobset
415415
if !jobSuspended(job) {
416416
continue
417417
}
418-
if err := r.resumeJob(ctx, job, nodeAffinities); err != nil {
418+
if err := r.resumeJob(ctx, job, replicatedJobTemplateMap); err != nil {
419419
return err
420420
}
421421
}
@@ -433,7 +433,7 @@ func (r *JobSetReconciler) resumeJobsIfNecessary(ctx context.Context, js *jobset
433433
return nil
434434
}
435435

436-
func (r *JobSetReconciler) resumeJob(ctx context.Context, job *batchv1.Job, nodeAffinities map[string]map[string]string) error {
436+
func (r *JobSetReconciler) resumeJob(ctx context.Context, job *batchv1.Job, replicatedJobTemplateMap map[string]corev1.PodTemplateSpec) error {
437437
log := ctrl.LoggerFrom(ctx)
438438
// Kubernetes validates that a job template is immutable
439439
// so if the job has started i.e., startTime != nil), we must set it to nil first.
@@ -443,10 +443,33 @@ func (r *JobSetReconciler) resumeJob(ctx context.Context, job *batchv1.Job, node
443443
return err
444444
}
445445
}
446+
447+
// Get name of parent replicated job and use it to look up the pod template.
448+
replicatedJobName := job.Labels[jobset.ReplicatedJobNameKey]
449+
replicatedJobPodTemplate := replicatedJobTemplateMap[replicatedJobName]
446450
if job.Labels != nil && job.Labels[jobset.ReplicatedJobNameKey] != "" {
447-
// When resuming a job, its nodeSelectors should match that of the replicatedJob template
448-
// that it was created from, which may have been updated while it was suspended.
449-
job.Spec.Template.Spec.NodeSelector = nodeAffinities[job.Labels[jobset.ReplicatedJobNameKey]]
451+
// Certain fields on the Job pod template may be mutated while a JobSet is suspended,
452+
// for integration with Kueue. Ensure these updates are propagated to the child Jobs
453+
// when the JobSet is resumed.
454+
// Merge values rather than overwriting them, since a different controller
455+
// (e.g., the Job controller) may have added labels/annotations/etc to the
456+
// Job that do not exist in the ReplicatedJob pod template.
457+
job.Spec.Template.Labels = collections.MergeMaps(
458+
job.Spec.Template.Labels,
459+
replicatedJobPodTemplate.Labels,
460+
)
461+
job.Spec.Template.Annotations = collections.MergeMaps(
462+
job.Spec.Template.Annotations,
463+
replicatedJobPodTemplate.Annotations,
464+
)
465+
job.Spec.Template.Spec.NodeSelector = collections.MergeMaps(
466+
job.Spec.Template.Spec.NodeSelector,
467+
replicatedJobPodTemplate.Spec.NodeSelector,
468+
)
469+
job.Spec.Template.Spec.Tolerations = collections.MergeSlices(
470+
job.Spec.Template.Spec.Tolerations,
471+
replicatedJobPodTemplate.Spec.Tolerations,
472+
)
450473
} else {
451474
log.Error(nil, "job missing ReplicatedJobName label")
452475
}

pkg/util/collections/collections.go

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,3 +47,41 @@ func IndexOf[T comparable](slice []T, item T) int {
4747
}
4848
return -1
4949
}
50+
51+
// MergeMaps combines `old` and `new` into a freshly allocated map and
// returns it; neither input is modified. Entries are copied from `old`
// first and from `new` second, so when a key is present in both maps the
// value from `new` is the one that ends up in the result. Nil inputs are
// treated as empty, and the returned map is never nil.
func MergeMaps[K comparable, V any](old, new map[K]V) map[K]V {
	result := make(map[K]V)
	// Order matters: later sources overwrite earlier ones on duplicate keys.
	for _, source := range []map[K]V{old, new} {
		for key, value := range source {
			result[key] = value
		}
	}
	return result
}
64+
65+
// MergeSlices returns the union of s1 and s2 with duplicate elements removed.
//
// Elements keep the order of their first occurrence: the distinct elements
// of s1 come first, followed by the elements of s2 that have not been seen
// yet. The result is therefore deterministic for a given pair of inputs and
// does not depend on Go's unspecified map iteration order.
//
// Elements are compared with ==, so values containing pointers are only
// treated as duplicates when the pointers themselves are equal. The returned
// slice is always non-nil and never shares backing storage with s1 or s2.
func MergeSlices[T comparable](s1, s2 []T) []T {
	merged := make([]T, 0, len(s1)+len(s2))
	// seen records which elements have already been appended to merged.
	seen := make(map[T]struct{}, len(s1)+len(s2))

	for _, source := range [][]T{s1, s2} {
		for _, item := range source {
			if _, dup := seen[item]; dup {
				continue
			}
			seen[item] = struct{}{}
			merged = append(merged, item)
		}
	}

	return merged
}

pkg/util/collections/collections_test.go

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import (
2020
"testing"
2121

2222
"github.com/google/go-cmp/cmp"
23+
"golang.org/x/exp/slices"
2324
)
2425

2526
func TestConcat(t *testing.T) {
@@ -151,3 +152,90 @@ func TestContains(t *testing.T) {
151152
})
152153
}
153154
}
155+
156+
func TestMergeMaps(t *testing.T) {
157+
testCases := []struct {
158+
name string
159+
m1 map[string]int
160+
m2 map[string]int
161+
expected map[string]int
162+
}{
163+
{
164+
name: "Basic merge",
165+
m1: map[string]int{"a": 1, "b": 2},
166+
m2: map[string]int{"c": 3, "d": 4},
167+
expected: map[string]int{"a": 1, "b": 2, "c": 3, "d": 4},
168+
},
169+
{
170+
name: "Overlapping keys",
171+
m1: map[string]int{"a": 1, "b": 2},
172+
m2: map[string]int{"b": 3, "c": 4},
173+
expected: map[string]int{"a": 1, "b": 3, "c": 4}, // m2 value for 'b' overwrites
174+
},
175+
{
176+
name: "Empty maps",
177+
m1: map[string]int{},
178+
m2: map[string]int{},
179+
expected: map[string]int{},
180+
},
181+
{
182+
name: "One empty map",
183+
m1: map[string]int{"a": 1, "b": 2},
184+
m2: map[string]int{},
185+
expected: map[string]int{"a": 1, "b": 2},
186+
},
187+
}
188+
189+
for _, tc := range testCases {
190+
t.Run(tc.name, func(t *testing.T) {
191+
merged := MergeMaps(tc.m1, tc.m2)
192+
193+
if !reflect.DeepEqual(merged, tc.expected) {
194+
t.Errorf("expected %v, got %v", tc.expected, merged)
195+
}
196+
})
197+
}
198+
}
199+
200+
func TestMergeSlices(t *testing.T) {
201+
testCases := []struct {
202+
name string
203+
s1 []int
204+
s2 []int
205+
expected []int
206+
}{
207+
{
208+
name: "merge with overlapping elements should not result in duplicates",
209+
s1: []int{1, 2, 3},
210+
s2: []int{3, 4, 5},
211+
expected: []int{1, 2, 3, 4, 5},
212+
},
213+
{
214+
name: "empty slices",
215+
s1: []int{},
216+
s2: []int{},
217+
expected: []int{},
218+
},
219+
{
220+
name: "one empty slice",
221+
s1: []int{1, 2},
222+
s2: []int{},
223+
expected: []int{1, 2},
224+
},
225+
}
226+
227+
for _, tc := range testCases {
228+
t.Run(tc.name, func(t *testing.T) {
229+
merged := MergeSlices(tc.s1, tc.s2)
230+
231+
// Sort before comparison so slices with the same elements
232+
// should be the same.
233+
slices.Sort(merged)
234+
slices.Sort(tc.expected)
235+
236+
if !reflect.DeepEqual(merged, tc.expected) {
237+
t.Errorf("Expected %v, got %v", tc.expected, merged)
238+
}
239+
})
240+
}
241+
}

test/integration/controller/jobset_controller_test.go

Lines changed: 63 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -132,9 +132,16 @@ var _ = ginkgo.Describe("JobSet controller", func() {
132132
updates []*update
133133
}
134134

135-
nodeSelectors := map[string]map[string]string{
136-
"replicated-job-a": {"node-selector-test-a": "node-selector-test-a"},
137-
"replicated-job-b": {"node-selector-test-b": "node-selector-test-b"},
135+
var podTemplateUpdates = &updatePodTemplateOpts{
136+
labels: map[string]string{"label": "value"},
137+
annotations: map[string]string{"annotation": "value"},
138+
nodeSelector: map[string]string{"node-selector-test-a": "node-selector-test-a"},
139+
tolerations: []corev1.Toleration{
140+
{
141+
Key: "key",
142+
Operator: corev1.TolerationOpExists,
143+
},
144+
},
138145
}
139146

140147
ginkgo.DescribeTable("jobset is created and its jobs go through a series of updates",
@@ -514,7 +521,7 @@ var _ = ginkgo.Describe("JobSet controller", func() {
514521
},
515522
{
516523
jobSetUpdateFn: func(js *jobset.JobSet) {
517-
updateJobSetNodeSelectors(js, nodeSelectors)
524+
updatePodTemplates(js, podTemplateUpdates)
518525
},
519526
checkJobSetState: func(js *jobset.JobSet) {
520527
ginkgo.By("Check ReplicatedJobStatus for suspend")
@@ -542,7 +549,7 @@ var _ = ginkgo.Describe("JobSet controller", func() {
542549
{
543550
checkJobSetState: func(js *jobset.JobSet) {
544551
ginkgo.By("checking jobs have expected node selectors")
545-
gomega.Eventually(matchJobsNodeSelectors, timeout, interval).WithArguments(js, nodeSelectors).Should(gomega.Equal(true))
552+
gomega.Eventually(checkPodTemplateUpdates, timeout, interval).WithArguments(js, podTemplateUpdates).Should(gomega.Equal(true))
546553
},
547554
jobUpdateFn: completeAllJobs,
548555
checkJobSetCondition: testutil.JobSetCompleted,
@@ -1464,15 +1471,35 @@ func suspendJobSet(js *jobset.JobSet, suspend bool) {
14641471
}, timeout, interval).Should(gomega.Succeed())
14651472
}
14661473

1467-
func updateJobSetNodeSelectors(js *jobset.JobSet, nodeSelectors map[string]map[string]string) {
1474+
// updatePodTemplateOpts contains pod template values
1475+
// which can be mutated on a ReplicatedJob template
1476+
// while a JobSet is suspended.
1477+
type updatePodTemplateOpts struct {
1478+
labels map[string]string
1479+
annotations map[string]string
1480+
nodeSelector map[string]string
1481+
tolerations []corev1.Toleration
1482+
}
1483+
1484+
func updatePodTemplates(js *jobset.JobSet, opts *updatePodTemplateOpts) {
14681485
gomega.Eventually(func() error {
14691486
var jsGet jobset.JobSet
14701487
if err := k8sClient.Get(ctx, types.NamespacedName{Name: js.Name, Namespace: js.Namespace}, &jsGet); err != nil {
14711488
return err
14721489
}
14731490
for index := range jsGet.Spec.ReplicatedJobs {
1474-
jsGet.Spec.ReplicatedJobs[index].
1475-
Template.Spec.Template.Spec.NodeSelector = nodeSelectors[jsGet.Spec.ReplicatedJobs[index].Name]
1491+
podTemplate := &jsGet.Spec.ReplicatedJobs[index].Template.Spec.Template
1492+
// Update labels.
1493+
podTemplate.Labels = opts.labels
1494+
1495+
// Update annotations.
1496+
podTemplate.Annotations = opts.annotations
1497+
1498+
// Update node selector.
1499+
podTemplate.Spec.NodeSelector = opts.nodeSelector
1500+
1501+
// Update tolerations.
1502+
podTemplate.Spec.Tolerations = opts.tolerations
14761503
}
14771504
return k8sClient.Update(ctx, &jsGet)
14781505
}, timeout, interval).Should(gomega.Succeed())
@@ -1496,29 +1523,48 @@ func matchJobsSuspendState(js *jobset.JobSet, suspend bool) (bool, error) {
14961523
return true, nil
14971524
}
14981525

1499-
func matchJobsNodeSelectors(js *jobset.JobSet, nodeSelectors map[string]map[string]string) (bool, error) {
1526+
func checkPodTemplateUpdates(js *jobset.JobSet, podTemplateUpdates *updatePodTemplateOpts) (bool, error) {
15001527
var jobList batchv1.JobList
15011528
if err := k8sClient.List(ctx, &jobList, client.InNamespace(js.Namespace)); err != nil {
15021529
return false, err
15031530
}
15041531
// Count number of updated jobs
15051532
jobsUpdated := 0
15061533
for _, job := range jobList.Items {
1507-
rjobName, ok := job.Labels[jobset.ReplicatedJobNameKey]
1508-
if !ok {
1509-
return false, fmt.Errorf(fmt.Sprintf("%s job missing ReplicatedJobName label", job.Name))
1534+
// Check label was added.
1535+
for label, value := range podTemplateUpdates.labels {
1536+
if job.Spec.Template.Labels[label] != value {
1537+
return false, fmt.Errorf("%s != %s", job.Spec.Template.Labels[label], value)
1538+
}
15101539
}
1511-
if !apiequality.Semantic.DeepEqual(job.Spec.Template.Spec.NodeSelector, nodeSelectors[rjobName]) {
1512-
return false, nil
1540+
1541+
// Check annotation was added.
1542+
for annotation, value := range podTemplateUpdates.annotations {
1543+
if job.Spec.Template.Annotations[annotation] != value {
1544+
return false, fmt.Errorf("%s != %s", job.Spec.Template.Labels[annotation], value)
1545+
}
15131546
}
1547+
1548+
// Check nodeSelector was updated.
1549+
for label, value := range podTemplateUpdates.nodeSelector {
1550+
if job.Spec.Template.Spec.NodeSelector[label] != value {
1551+
return false, fmt.Errorf("%s != %s", job.Spec.Template.Spec.NodeSelector[label], value)
1552+
}
1553+
}
1554+
1555+
// Check tolerations were updated.
1556+
for _, toleration := range podTemplateUpdates.tolerations {
1557+
if !collections.Contains(job.Spec.Template.Spec.Tolerations, toleration) {
1558+
return false, fmt.Errorf("missing toleration %v", toleration)
1559+
}
1560+
}
1561+
15141562
jobsUpdated++
15151563
}
15161564
// Calculate expected number of updated jobs
15171565
wantJobsUpdated := 0
15181566
for _, rjob := range js.Spec.ReplicatedJobs {
1519-
if _, exists := nodeSelectors[rjob.Name]; exists {
1520-
wantJobsUpdated += int(rjob.Replicas)
1521-
}
1567+
wantJobsUpdated += int(rjob.Replicas)
15221568
}
15231569
return wantJobsUpdated == jobsUpdated, nil
15241570
}

0 commit comments

Comments
 (0)