Skip to content

Commit 1824234

Browse files
authored
Adjusted grove podGrouper plugin to use subgroups (#377)
* Adjusted grove podGrouper plugin to use subgroups * Moved SubGroupLabelKey to common consts
1 parent fc5dd11 commit 1824234

File tree

8 files changed

+321
-28
lines changed

8 files changed

+321
-28
lines changed

pkg/common/constants/constants.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,4 +31,5 @@ const (
3131
MultiGpuGroupLabelPrefix = GPUGroup + "/"
3232
MigStrategyLabel = "nvidia.com/mig.strategy"
3333
GpuCountLabel = "nvidia.com/gpu.count"
34+
SubGroupLabelKey = "kai.scheduler/subgroup-name"
3435
)

pkg/podgrouper/podgroup/handler.go

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,16 @@ package podgroup
55

66
import (
77
"context"
8+
"fmt"
89

10+
v1 "k8s.io/api/core/v1"
911
"k8s.io/apimachinery/pkg/api/errors"
1012
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
1113
"k8s.io/apimachinery/pkg/types"
1214
"sigs.k8s.io/controller-runtime/pkg/client"
1315

1416
schedulingv2alpha2 "github.com/NVIDIA/KAI-scheduler/pkg/apis/scheduling/v2alpha2"
17+
commonconstants "github.com/NVIDIA/KAI-scheduler/pkg/common/constants"
1518
)
1619

1720
type Handler struct {
@@ -31,7 +34,11 @@ func NewHandler(client client.Client, nodePoolKey string, queueLabelKey string)
3134
func (h *Handler) ApplyToCluster(ctx context.Context, pgMetadata Metadata) error {
3235
newPodGroup := h.createPodGroupForMetadata(pgMetadata)
3336

34-
var err error
37+
err := h.assignPodsToSubGroup(ctx, pgMetadata.SubGroups)
38+
if err != nil {
39+
return fmt.Errorf("error assigning pods to subgroup: %v", err)
40+
}
41+
3542
oldPodGroup := &schedulingv2alpha2.PodGroup{}
3643
key := types.NamespacedName{
3744
Namespace: pgMetadata.Namespace,
@@ -101,9 +108,18 @@ func (h *Handler) createPodGroupForMetadata(podGroupMetadata Metadata) *scheduli
101108
MinMember: podGroupMetadata.MinAvailable,
102109
Queue: podGroupMetadata.Queue,
103110
PriorityClassName: podGroupMetadata.PriorityClassName,
111+
SubGroups: []schedulingv2alpha2.SubGroup{},
104112
},
105113
}
106114

115+
for _, subGroup := range podGroupMetadata.SubGroups {
116+
pg.Spec.SubGroups = append(pg.Spec.SubGroups,
117+
schedulingv2alpha2.SubGroup{
118+
Name: subGroup.Name,
119+
MinMember: subGroup.MinAvailable,
120+
})
121+
}
122+
107123
pg.Spec.TopologyConstraint = schedulingv2alpha2.TopologyConstraint{
108124
PreferredTopologyLevel: podGroupMetadata.PreferredTopologyLevel,
109125
RequiredTopologyLevel: podGroupMetadata.RequiredTopologyLevel,
@@ -112,3 +128,24 @@ func (h *Handler) createPodGroupForMetadata(podGroupMetadata Metadata) *scheduli
112128

113129
return pg
114130
}
131+
132+
func (h *Handler) assignPodsToSubGroup(ctx context.Context, subGroups []*SubGroupMetadata) error {
133+
for _, subGroup := range subGroups {
134+
for _, podRef := range subGroup.PodsReferences {
135+
pod := &v1.Pod{}
136+
err := h.client.Get(ctx, *podRef, pod)
137+
if err != nil {
138+
return err
139+
}
140+
141+
labeledPod := pod.DeepCopy()
142+
labeledPod.Labels[commonconstants.SubGroupLabelKey] = subGroup.Name
143+
144+
err = h.client.Patch(ctx, labeledPod, client.MergeFrom(pod))
145+
if err != nil {
146+
return err
147+
}
148+
}
149+
}
150+
return nil
151+
}

pkg/podgrouper/podgroup/metadata.go

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,16 @@
33

44
package podgroup
55

6-
import metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
6+
import (
7+
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
8+
"k8s.io/apimachinery/pkg/types"
9+
)
10+
11+
// SubGroupMetadata describes one subgroup of a pod group as collected by a
// grouper plugin, before it is turned into a PodGroup spec entry.
type SubGroupMetadata struct {
12+
Name string // subgroup name; copied to the PodGroup SubGroup name and written as the subgroup label value on member pods
13+
MinAvailable int32 // copied to the PodGroup SubGroup's MinMember
14+
PodsReferences []*types.NamespacedName // namespace/name of the pods that get labeled with this subgroup's name
15+
}
716

817
type Metadata struct {
918
Annotations map[string]string
@@ -14,6 +23,7 @@ type Metadata struct {
1423
Name string
1524
MinAvailable int32
1625
Owner metav1.OwnerReference
26+
SubGroups []*SubGroupMetadata
1727

1828
PreferredTopologyLevel string
1929
RequiredTopologyLevel string

pkg/podgrouper/podgrouper/plugins/grove/grove_grouper.go

Lines changed: 87 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
1212
"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
1313
"k8s.io/apimachinery/pkg/runtime/schema"
14+
"k8s.io/apimachinery/pkg/types"
1415
"sigs.k8s.io/controller-runtime/pkg/client"
1516

1617
"github.com/NVIDIA/KAI-scheduler/pkg/podgrouper/podgroup"
@@ -67,19 +68,19 @@ func (gg *GroveGrouper) GetPodGroupMetadata(
6768
Name: podGangName,
6869
}, podGang)
6970
if err != nil {
70-
return nil, fmt.Errorf("failed to get PodGang %s/%s : %w",
71+
return nil, fmt.Errorf("failed to get PodGang %s/%s. Err: %w",
7172
pod.Namespace, podGangName, err)
7273
}
7374

7475
metadata, err := gg.DefaultGrouper.GetPodGroupMetadata(podGang, pod)
7576
if err != nil {
76-
return nil, fmt.Errorf("failed to get DefaultGrouper metadata for PodGang %s/%s : %w",
77+
return nil, fmt.Errorf("failed to get DefaultGrouper metadata for PodGang %s/%s. Err: %w",
7778
pod.Namespace, podGangName, err)
7879
}
7980

8081
priorityClassName, found, err := unstructured.NestedString(podGang.Object, "spec", "priorityClassName")
8182
if err != nil {
82-
return nil, fmt.Errorf("failed to get spec.priorityClassName from PodGang %s/%s : %w",
83+
return nil, fmt.Errorf("failed to get spec.priorityClassName from PodGang %s/%s. Err: %w",
8384
pod.Namespace, podGangName, err)
8485
}
8586
if found {
@@ -89,27 +90,99 @@ func (gg *GroveGrouper) GetPodGroupMetadata(
8990
var minAvailable int32
9091
pgSlice, found, err := unstructured.NestedSlice(podGang.Object, "spec", "podgroups")
9192
if err != nil {
92-
return nil, fmt.Errorf("failed to get spec.podgroups from PodGang %s/%s : %w",
93+
return nil, fmt.Errorf("failed to get spec.podgroups from PodGang %s/%s. Err: %w",
9394
pod.Namespace, podGangName, err)
9495
}
95-
for idx, v := range pgSlice {
96+
for pgIndex, v := range pgSlice {
9697
pgr, ok := v.(map[string]interface{})
9798
if !ok {
9899
return nil, fmt.Errorf("invalid structure of spec.podgroup[%v] in PodGang %s/%s",
99-
idx, pod.Namespace, podGangName)
100+
pgIndex, pod.Namespace, podGangName)
100101
}
101-
podSlice, found, err := unstructured.NestedSlice(pgr, "podReferences")
102+
subGroup, err := parseGroveSubGroup(pgr, pgIndex, pod.Namespace, podGangName)
102103
if err != nil {
103-
return nil, fmt.Errorf("failed to get podReferences from spec.podgroup[%v] of PodGang %s/%s : %w",
104-
idx, pod.Namespace, podGangName, err)
104+
return nil, fmt.Errorf("failed to parse spec.podgroups[%d] from PodGang %s/%s. Err: %w",
105+
pgIndex, pod.Namespace, podGangName, err)
105106
}
106-
if !found {
107-
return nil, fmt.Errorf("missing podReferences in spec.podgroup[%v] of PodGang %s/%s",
108-
idx, pod.Namespace, podGangName)
109-
}
110-
minAvailable += int32(len(podSlice))
107+
metadata.SubGroups = append(metadata.SubGroups, subGroup)
108+
109+
minAvailable += subGroup.MinAvailable
111110
}
112111
metadata.MinAvailable = minAvailable
113112

114113
return metadata, nil
115114
}
115+
116+
func parseGroveSubGroup(
117+
pg map[string]interface{}, pgIndex int, namespace, podGangName string,
118+
) (*podgroup.SubGroupMetadata, error) {
119+
// Name
120+
name, found, err := unstructured.NestedString(pg, "name")
121+
if err != nil {
122+
return nil, fmt.Errorf("failed to parse 'name' field. Err: %v", err)
123+
}
124+
if !found {
125+
return nil, fmt.Errorf("missing required 'name' field")
126+
}
127+
128+
// MinReplicas
129+
minAvailable, found, err := unstructured.NestedInt64(pg, "minReplicas")
130+
if err != nil {
131+
return nil, fmt.Errorf("failed to parse 'minReplicas' field. Err: %v", err)
132+
}
133+
if !found {
134+
return nil, fmt.Errorf("missing required 'minReplicas' field")
135+
}
136+
if minAvailable <= 0 {
137+
return nil, fmt.Errorf("invalid 'minReplicas' field. Must be greater than 0")
138+
}
139+
140+
// PodReferences
141+
podReferences, found, err := unstructured.NestedSlice(pg, "podReferences")
142+
if err != nil {
143+
return nil, fmt.Errorf("failed to parse 'podReferences' field. Err: %w", err)
144+
}
145+
if !found {
146+
return nil, fmt.Errorf("missing required 'podReferences' field")
147+
}
148+
var pods []*types.NamespacedName
149+
for podIndex, podRef := range podReferences {
150+
reference, ok := podRef.(map[string]interface{})
151+
if !ok {
152+
return nil, fmt.Errorf("invalid spec.podgroup[%d].podReferences[%d] in PodGang %s/%s",
153+
pgIndex, podIndex, namespace, podGangName)
154+
}
155+
namespacedName, err := parsePodReference(reference)
156+
if err != nil {
157+
return nil, fmt.Errorf("failed to parse spec.podgroups[%d].podreferences[%d] from PodGang %s/%s. Err: %w",
158+
pgIndex, podIndex, namespace, podGangName, err)
159+
}
160+
pods = append(pods, namespacedName)
161+
}
162+
163+
return &podgroup.SubGroupMetadata{
164+
Name: name,
165+
MinAvailable: int32(minAvailable),
166+
PodsReferences: pods,
167+
}, nil
168+
}
169+
170+
func parsePodReference(podRef map[string]interface{}) (*types.NamespacedName, error) {
171+
podNamespace, found, err := unstructured.NestedString(podRef, "namespace")
172+
if err != nil {
173+
return nil, fmt.Errorf("failed to parse 'namespace' field. Err: %v", err)
174+
}
175+
if !found {
176+
return nil, fmt.Errorf("missing required 'namespace' field")
177+
}
178+
179+
podName, found, err := unstructured.NestedString(podRef, "name")
180+
if err != nil {
181+
return nil, fmt.Errorf("failed to parse 'name' field. Err: %v", err)
182+
}
183+
if !found {
184+
return nil, fmt.Errorf("missing required 'name' field")
185+
}
186+
187+
return &types.NamespacedName{Namespace: podNamespace, Name: podName}, nil
188+
}

0 commit comments

Comments
 (0)