Skip to content

Commit eb02298

Browse files
authored
feat: copy webhooks to separate files (#331)
* feat: copy webhooks to separate files * fix: apply CR comments, add tests boilerplate * test: add unit tests * test: fix tests * fix: remove RunSpecs multiple calls
1 parent daf2d15 commit eb02298

File tree

7 files changed

+737
-0
lines changed

7 files changed

+737
-0
lines changed

pkg/admission/plugins/plugins.go

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
// Copyright 2025 NVIDIA CORPORATION
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
package plugins
5+
6+
import (
7+
"context"
8+
9+
v1 "k8s.io/api/core/v1"
10+
"sigs.k8s.io/controller-runtime/pkg/log"
11+
)
12+
13+
type Plugin interface {
14+
Name() string
15+
Validate(*v1.Pod) error
16+
Mutate(*v1.Pod) error
17+
}
18+
19+
type KaiAdmissionPlugins struct {
20+
plugins []Plugin
21+
}
22+
23+
func New() *KaiAdmissionPlugins {
24+
return &KaiAdmissionPlugins{
25+
plugins: []Plugin{},
26+
}
27+
}
28+
29+
func (bp *KaiAdmissionPlugins) RegisterPlugin(plugin Plugin) {
30+
bp.plugins = append(bp.plugins, plugin)
31+
}
32+
33+
func (bp *KaiAdmissionPlugins) Validate(pod *v1.Pod) error {
34+
for _, p := range bp.plugins {
35+
err := p.Validate(pod)
36+
if err != nil {
37+
logger := log.FromContext(context.Background())
38+
logger.Error(err, "pod validation failed for pod",
39+
"namespace", pod.Namespace, "name", pod.Name, "plugin", p.Name())
40+
return err
41+
}
42+
}
43+
return nil
44+
}
45+
46+
func (bp *KaiAdmissionPlugins) Mutate(pod *v1.Pod) error {
47+
for _, p := range bp.plugins {
48+
err := p.Mutate(pod)
49+
if err != nil {
50+
logger := log.FromContext(context.Background())
51+
logger.Error(err, "pod mutation failed for pod",
52+
"namespace", pod.Namespace, "name", pod.Name, "plugin", p.Name())
53+
return err
54+
}
55+
}
56+
return nil
57+
}
Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
// Copyright 2025 NVIDIA CORPORATION
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
package gpusharing
5+
6+
import (
7+
"context"
8+
"fmt"
9+
"strings"
10+
11+
"golang.org/x/exp/slices"
12+
v1 "k8s.io/api/core/v1"
13+
"sigs.k8s.io/controller-runtime/pkg/client"
14+
15+
"github.com/NVIDIA/KAI-scheduler/pkg/apis/scheduling/v1alpha2"
16+
"github.com/NVIDIA/KAI-scheduler/pkg/binder/common/gpusharingconfigmap"
17+
"github.com/NVIDIA/KAI-scheduler/pkg/common/resources"
18+
19+
"github.com/NVIDIA/KAI-scheduler/pkg/binder/common"
20+
gpurequesthandler "github.com/NVIDIA/KAI-scheduler/pkg/binder/plugins/gpusharing/gpu-request"
21+
"github.com/NVIDIA/KAI-scheduler/pkg/binder/plugins/state"
22+
)
23+
24+
const (
	// fractionContainerIndex is the index into pod.Spec.Containers of the
	// container that Mutate and PreBind configure for GPU sharing.
	fractionContainerIndex = 0

	// CdiDeviceNameBase is the format string PreBind applies to each reserved
	// GPU id when the plugin was created with gpuDevicePluginUsesCdi set.
	CdiDeviceNameBase = "k8s.device-plugin.nvidia.com/gpu=%s"
)
28+
29+
type GPUSharing struct {
30+
kubeClient client.Client
31+
gpuDevicePluginUsesCdi bool
32+
gpuSharingEnabled bool
33+
}
34+
35+
func New(kubeClient client.Client, gpuDevicePluginUsesCdi bool, gpuSharingEnabled bool) *GPUSharing {
36+
return &GPUSharing{
37+
kubeClient: kubeClient,
38+
gpuDevicePluginUsesCdi: gpuDevicePluginUsesCdi,
39+
gpuSharingEnabled: gpuSharingEnabled,
40+
}
41+
}
42+
43+
func (p *GPUSharing) Name() string {
44+
return "gpusharing"
45+
}
46+
47+
func (p *GPUSharing) Validate(pod *v1.Pod) error {
48+
if !p.gpuSharingEnabled && resources.RequestsGPUFraction(pod) {
49+
return fmt.Errorf(
50+
"attempting to create a pod %s/%s with gpu sharing request, while GPU sharing is disabled",
51+
pod.Namespace, pod.Name,
52+
)
53+
}
54+
return gpurequesthandler.ValidateGpuRequests(pod)
55+
}
56+
57+
func (p *GPUSharing) Mutate(pod *v1.Pod) error {
58+
if len(pod.Spec.Containers) == 0 {
59+
return nil
60+
}
61+
62+
if !resources.RequestsGPUFraction(pod) {
63+
return nil
64+
}
65+
66+
containerRef := &gpusharingconfigmap.PodContainerRef{
67+
Container: &pod.Spec.Containers[fractionContainerIndex],
68+
Index: fractionContainerIndex,
69+
Type: gpusharingconfigmap.RegularContainer,
70+
}
71+
capabilitiesConfigMapName := gpusharingconfigmap.SetGpuCapabilitiesConfigMapName(pod, containerRef)
72+
directEnvVarsMapName, err := gpusharingconfigmap.ExtractDirectEnvVarsConfigMapName(pod, containerRef)
73+
if err != nil {
74+
return err
75+
}
76+
77+
common.AddGPUSharingEnvVars(containerRef.Container, capabilitiesConfigMapName)
78+
common.SetConfigMapVolume(pod, capabilitiesConfigMapName)
79+
common.AddDirectEnvVarsConfigMapSource(containerRef.Container, directEnvVarsMapName)
80+
81+
return nil
82+
}
83+
84+
func (p *GPUSharing) PreBind(
85+
ctx context.Context, pod *v1.Pod, _ *v1.Node, bindRequest *v1alpha2.BindRequest, state *state.BindingState,
86+
) error {
87+
if !common.IsSharedGPUAllocation(bindRequest) {
88+
return nil
89+
}
90+
91+
reservedGPUIds := slices.Clone(state.ReservedGPUIds)
92+
if p.gpuDevicePluginUsesCdi {
93+
for index, gpuIndex := range reservedGPUIds {
94+
reservedGPUIds[index] = fmt.Sprintf(CdiDeviceNameBase, gpuIndex)
95+
}
96+
}
97+
98+
containerRef := &gpusharingconfigmap.PodContainerRef{
99+
Container: &pod.Spec.Containers[fractionContainerIndex],
100+
Index: fractionContainerIndex,
101+
Type: gpusharingconfigmap.RegularContainer,
102+
}
103+
err := p.createCapabilitiesConfigMapIfMissing(ctx, pod, containerRef)
104+
if err != nil {
105+
return fmt.Errorf("failed to create capabilities configmap: %w", err)
106+
}
107+
108+
err = p.createDirectEnvMapIfMissing(ctx, pod, containerRef)
109+
if err != nil {
110+
return fmt.Errorf("failed to create env configmap: %w", err)
111+
}
112+
113+
nVisibleDevicesStr := strings.Join(reservedGPUIds, ",")
114+
err = common.SetNvidiaVisibleDevices(ctx, p.kubeClient, pod, containerRef, nVisibleDevicesStr)
115+
if err != nil {
116+
return err
117+
}
118+
119+
return common.SetGPUPortion(ctx, p.kubeClient, pod, containerRef, bindRequest.Spec.ReceivedGPU.Portion)
120+
}
121+
122+
func (p *GPUSharing) createCapabilitiesConfigMapIfMissing(ctx context.Context, pod *v1.Pod,
123+
containerRef *gpusharingconfigmap.PodContainerRef) error {
124+
capabilitiesConfigMapName, err := gpusharingconfigmap.ExtractCapabilitiesConfigMapName(pod, containerRef)
125+
if err != nil {
126+
return fmt.Errorf("failed to get capabilities configmap name: %w", err)
127+
}
128+
err = gpusharingconfigmap.UpsertJobConfigMap(ctx, p.kubeClient, pod, capabilitiesConfigMapName, map[string]string{})
129+
return err
130+
}
131+
132+
func (p *GPUSharing) createDirectEnvMapIfMissing(ctx context.Context, pod *v1.Pod,
133+
containerRef *gpusharingconfigmap.PodContainerRef) error {
134+
directEnvVarsMapName, err := gpusharingconfigmap.ExtractDirectEnvVarsConfigMapName(pod, containerRef)
135+
if err != nil {
136+
return err
137+
}
138+
directEnvVars := make(map[string]string)
139+
return gpusharingconfigmap.UpsertJobConfigMap(ctx, p.kubeClient, pod, directEnvVarsMapName, directEnvVars)
140+
}
141+
142+
func (p *GPUSharing) PostBind(
143+
context.Context, *v1.Pod, *v1.Node, *v1alpha2.BindRequest, *state.BindingState,
144+
) {
145+
}
Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
// Copyright 2025 NVIDIA CORPORATION
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
package gpusharing
5+
6+
import (
7+
"fmt"
8+
"testing"
9+
10+
v1 "k8s.io/api/core/v1"
11+
"k8s.io/apimachinery/pkg/api/resource"
12+
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
13+
"sigs.k8s.io/controller-runtime/pkg/client/fake"
14+
15+
"github.com/NVIDIA/KAI-scheduler/pkg/common/constants"
16+
)
17+
18+
func TestValidate(t *testing.T) {
19+
tests := []struct {
20+
name string
21+
pod *v1.Pod
22+
GPUSharingEnabled bool
23+
error error
24+
}{
25+
{
26+
name: "GPU sharing disabled, whole GPU pod",
27+
pod: &v1.Pod{
28+
ObjectMeta: metav1.ObjectMeta{
29+
Name: "test-pod",
30+
Namespace: "test-namespace",
31+
},
32+
Spec: v1.PodSpec{
33+
Containers: []v1.Container{
34+
{
35+
Resources: v1.ResourceRequirements{
36+
Limits: v1.ResourceList{
37+
constants.GpuResource: resource.MustParse("1"),
38+
},
39+
},
40+
},
41+
},
42+
},
43+
},
44+
GPUSharingEnabled: false,
45+
error: nil,
46+
},
47+
{
48+
name: "GPU sharing enabled, whole GPU pod",
49+
pod: &v1.Pod{
50+
ObjectMeta: metav1.ObjectMeta{
51+
Name: "test-pod",
52+
Namespace: "test-namespace",
53+
},
54+
Spec: v1.PodSpec{
55+
Containers: []v1.Container{
56+
{
57+
Resources: v1.ResourceRequirements{
58+
Limits: v1.ResourceList{
59+
constants.GpuResource: resource.MustParse("1"),
60+
},
61+
},
62+
},
63+
},
64+
},
65+
},
66+
GPUSharingEnabled: true,
67+
error: nil,
68+
},
69+
{
70+
name: "GPU sharing disabled, GPU sharing pod - fraction",
71+
pod: &v1.Pod{
72+
ObjectMeta: metav1.ObjectMeta{
73+
Name: "test-pod",
74+
Namespace: "test-namespace",
75+
Annotations: map[string]string{
76+
constants.GpuFraction: "0.5",
77+
},
78+
},
79+
Spec: v1.PodSpec{
80+
Containers: []v1.Container{
81+
{
82+
Resources: v1.ResourceRequirements{
83+
Limits: v1.ResourceList{},
84+
},
85+
},
86+
},
87+
},
88+
},
89+
GPUSharingEnabled: false,
90+
error: fmt.Errorf("attempting to create a pod test-namespace/test-pod with gpu " +
91+
"sharing request, while GPU sharing is disabled"),
92+
},
93+
{
94+
name: "GPU sharing enabled, GPU sharing pod - fraction",
95+
pod: &v1.Pod{
96+
ObjectMeta: metav1.ObjectMeta{
97+
Name: "test-pod",
98+
Namespace: "test-namespace",
99+
Annotations: map[string]string{
100+
constants.GpuFraction: "0.5",
101+
},
102+
},
103+
Spec: v1.PodSpec{
104+
Containers: []v1.Container{
105+
{
106+
Resources: v1.ResourceRequirements{
107+
Limits: v1.ResourceList{},
108+
},
109+
},
110+
},
111+
},
112+
},
113+
GPUSharingEnabled: true,
114+
error: nil,
115+
},
116+
{
117+
name: "GPU sharing disabled, GPU sharing pod - memory",
118+
pod: &v1.Pod{
119+
ObjectMeta: metav1.ObjectMeta{
120+
Name: "test-pod",
121+
Namespace: "test-namespace",
122+
Annotations: map[string]string{
123+
constants.GpuMemory: "1024",
124+
},
125+
},
126+
Spec: v1.PodSpec{
127+
Containers: []v1.Container{
128+
{
129+
Resources: v1.ResourceRequirements{
130+
Limits: v1.ResourceList{},
131+
},
132+
},
133+
},
134+
},
135+
},
136+
GPUSharingEnabled: false,
137+
error: fmt.Errorf("attempting to create a pod test-namespace/test-pod with gpu " +
138+
"sharing request, while GPU sharing is disabled"),
139+
},
140+
{
141+
name: "GPU sharing enabled, GPU sharing pod - memory",
142+
pod: &v1.Pod{
143+
ObjectMeta: metav1.ObjectMeta{
144+
Name: "test-pod",
145+
Namespace: "test-namespace",
146+
Annotations: map[string]string{
147+
constants.GpuMemory: "1024",
148+
},
149+
},
150+
Spec: v1.PodSpec{
151+
Containers: []v1.Container{
152+
{
153+
Resources: v1.ResourceRequirements{
154+
Limits: v1.ResourceList{},
155+
},
156+
},
157+
},
158+
},
159+
},
160+
GPUSharingEnabled: true,
161+
error: nil,
162+
},
163+
}
164+
for _, tt := range tests {
165+
t.Run(tt.name, func(t *testing.T) {
166+
kubeClient := fake.NewClientBuilder().WithRuntimeObjects(tt.pod).Build()
167+
gpuSharingPlugin := New(kubeClient, false, tt.GPUSharingEnabled)
168+
err := gpuSharingPlugin.Validate(tt.pod)
169+
if err == nil && tt.error != nil {
170+
t.Errorf("Validate() expected and error but actual is nil")
171+
return
172+
}
173+
if err != nil && tt.error == nil {
174+
t.Errorf("Validate() actual is nil but didn't expect and error. Error: %v", err)
175+
return
176+
}
177+
if tt.error != nil && err.Error() != tt.error.Error() {
178+
t.Errorf("Validate()\nactual: %v\nexpected: %v\n", err, tt.error)
179+
return
180+
}
181+
})
182+
}
183+
}

0 commit comments

Comments
 (0)