From 59930e0c5561d005ee1ce59951827923cfc7694a Mon Sep 17 00:00:00 2001 From: Jon Huhn Date: Fri, 31 Oct 2025 21:56:15 -0500 Subject: [PATCH 1/9] Extract scheme for decoding opaque configs --- api/example.com/resource/gpu/v1alpha1/api.go | 39 +--- .../resource/gpu/v1alpha1/register.go | 45 +++++ cmd/dra-example-kubeletplugin/main.go | 21 +- cmd/dra-example-kubeletplugin/state.go | 15 +- cmd/dra-example-webhook/main.go | 190 ++++++++++-------- cmd/dra-example-webhook/main_test.go | 7 +- 6 files changed, 191 insertions(+), 126 deletions(-) create mode 100644 api/example.com/resource/gpu/v1alpha1/register.go diff --git a/api/example.com/resource/gpu/v1alpha1/api.go b/api/example.com/resource/gpu/v1alpha1/api.go index 203d220f..bc62b507 100644 --- a/api/example.com/resource/gpu/v1alpha1/api.go +++ b/api/example.com/resource/gpu/v1alpha1/api.go @@ -20,20 +20,9 @@ import ( "fmt" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - "k8s.io/apimachinery/pkg/runtime" - "k8s.io/apimachinery/pkg/runtime/schema" - "k8s.io/apimachinery/pkg/runtime/serializer/json" ) -const ( - GroupName = "gpu.resource.example.com" - Version = "v1alpha1" - - GpuConfigKind = "GpuConfig" -) - -// Decoder implements a decoder for objects in this API group. -var Decoder runtime.Decoder +const GpuConfigKind = "GpuConfig" // +genclient // +k8s:deepcopy-gen:interfaces=k8s.io/apimachinery/pkg/runtime.Object @@ -82,29 +71,3 @@ func (c *GpuConfig) Normalize() error { } return nil } - -func init() { - // Create a new scheme and add our types to it. If at some point in the - // future a new version of the configuration API becomes necessary, then - // conversion functions can be generated and registered to continue - // supporting older versions. - scheme := runtime.NewScheme() - schemeGroupVersion := schema.GroupVersion{ - Group: GroupName, - Version: Version, - } - scheme.AddKnownTypes(schemeGroupVersion, - &GpuConfig{}, - ) - metav1.AddToGroupVersion(scheme, schemeGroupVersion) - - // Set up a json serializer to decode our types. - Decoder = json.NewSerializerWithOptions( - json.DefaultMetaFactory, - scheme, - scheme, - json.SerializerOptions{ - Pretty: true, Strict: true, - }, - ) -} diff --git a/api/example.com/resource/gpu/v1alpha1/register.go b/api/example.com/resource/gpu/v1alpha1/register.go new file mode 100644 index 00000000..0c9f0233 --- /dev/null +++ b/api/example.com/resource/gpu/v1alpha1/register.go @@ -0,0 +1,45 @@ +/* + * Copyright The Kubernetes Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package v1alpha1 + +import ( + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/runtime/schema" +) + +const ( + GroupName = "gpu.resource.example.com" + Version = "v1alpha1" +) + +// SchemeGroupVersion is group version used to register these objects. +var SchemeGroupVersion = schema.GroupVersion{Group: GroupName, Version: Version} + +var ( + SchemeBuilder = runtime.NewSchemeBuilder(addKnownTypes) + AddToScheme = SchemeBuilder.AddToScheme +) + +// Adds the list of known types to the given scheme. +func addKnownTypes(scheme *runtime.Scheme) error { + scheme.AddKnownTypes(SchemeGroupVersion, + &GpuConfig{}, + ) + metav1.AddToGroupVersion(scheme, SchemeGroupVersion) + return nil +} diff --git a/cmd/dra-example-kubeletplugin/main.go b/cmd/dra-example-kubeletplugin/main.go index e0cfaf9c..80638217 100644 --- a/cmd/dra-example-kubeletplugin/main.go +++ b/cmd/dra-example-kubeletplugin/main.go @@ -27,10 +27,12 @@ import ( "github.com/urfave/cli/v2" + "k8s.io/apimachinery/pkg/runtime" coreclientset "k8s.io/client-go/kubernetes" "k8s.io/dynamic-resource-allocation/kubeletplugin" "k8s.io/klog/v2" + configapi "sigs.k8s.io/dra-example-driver/api/example.com/resource/gpu/v1alpha1" "sigs.k8s.io/dra-example-driver/pkg/consts" "sigs.k8s.io/dra-example-driver/pkg/flags" ) @@ -55,6 +57,9 @@ type Config struct { flags *Flags coreclient coreclientset.Interface cancelMainCtx func(error) + + // Config types + configScheme *runtime.Scheme } func (c Config) DriverPluginPath() string { @@ -135,12 +140,22 @@ func newApp() *cli.App { ctx := c.Context clientSets, err := flags.kubeClientConfig.NewClientSets() if err != nil { - return fmt.Errorf("create client: %v", err) + return fmt.Errorf("create client: %w", err) + } + + configScheme := runtime.NewScheme() + sb := runtime.NewSchemeBuilder( + // TODO: only add the API versions that apply to a given profile + configapi.AddToScheme, + ) + if err := sb.AddToScheme(configScheme); err != nil { + return fmt.Errorf("create config scheme: %w", err) } config := &Config{ - flags: flags, - coreclient: clientSets.Core, + flags: flags, + coreclient: clientSets.Core, + configScheme: configScheme, } return RunPlugin(ctx, config) diff --git a/cmd/dra-example-kubeletplugin/state.go b/cmd/dra-example-kubeletplugin/state.go index 7cffe388..74bc4539 100644 --- a/cmd/dra-example-kubeletplugin/state.go +++ b/cmd/dra-example-kubeletplugin/state.go @@ -23,6 +23,7 @@ import ( resourceapi "k8s.io/api/resource/v1" "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/runtime/serializer/json" drapbv1 "k8s.io/kubelet/pkg/apis/dra/v1beta1" "k8s.io/kubernetes/pkg/kubelet/checkpointmanager" @@ -61,6 +62,7 @@ type DeviceState struct { cdi *CDIHandler allocatable AllocatableDevices checkpointManager checkpointmanager.CheckpointManager + configDecoder runtime.Decoder } func NewDeviceState(config *Config) (*DeviceState, error) { @@ -84,10 +86,21 @@ func NewDeviceState(config *Config) (*DeviceState, error) { return nil, fmt.Errorf("unable to create checkpoint manager: %v", err) } + // Set up a json serializer to decode our types. + decoder := json.NewSerializerWithOptions( + json.DefaultMetaFactory, + config.configScheme, + config.configScheme, + json.SerializerOptions{ + Pretty: true, Strict: true, + }, + ) + state := &DeviceState{ cdi: cdi, allocatable: allocatable, checkpointManager: checkpointManager, + configDecoder: decoder, } checkpoints, err := state.checkpointManager.ListCheckpoints() @@ -180,7 +193,7 @@ func (s *DeviceState) prepareDevices(claim *resourceapi.ResourceClaim) (Prepared // Retrieve the full set of device configs for the driver. configs, err := GetOpaqueDeviceConfigs( - configapi.Decoder, + s.configDecoder, consts.DriverName, claim.Status.Allocation.Devices.Config, ) diff --git a/cmd/dra-example-webhook/main.go b/cmd/dra-example-webhook/main.go index b7ce2172..69b9191e 100644 --- a/cmd/dra-example-webhook/main.go +++ b/cmd/dra-example-webhook/main.go @@ -30,6 +30,7 @@ import ( resourceapi "k8s.io/api/resource/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime" + kjson "k8s.io/apimachinery/pkg/runtime/serializer/json" "k8s.io/klog/v2" configapi "sigs.k8s.io/dra-example-driver/api/example.com/resource/gpu/v1alpha1" @@ -45,6 +46,8 @@ type Flags struct { port int } +var configScheme = runtime.NewScheme() + func main() { if err := newApp().Run(os.Args); err != nil { fmt.Fprintf(os.Stderr, "Error: %v\n", err) @@ -91,8 +94,15 @@ func newApp() *cli.App { return flags.loggingConfig.Apply() }, Action: func(c *cli.Context) error { + sb := runtime.NewSchemeBuilder( + configapi.AddToScheme, + ) + if err := sb.AddToScheme(configScheme); err != nil { + return fmt.Errorf("create config scheme: %w", err) + } + server := &http.Server{ - Handler: newMux(), + Handler: newMux(newConfigDecoder()), Addr: fmt.Sprintf(":%d", flags.port), } klog.Info("starting webhook server on", server.Addr) @@ -103,9 +113,21 @@ func newApp() *cli.App { return app } -func newMux() *http.ServeMux { +func newConfigDecoder() runtime.Decoder { + // Set up a json serializer to decode our types. + return kjson.NewSerializerWithOptions( + kjson.DefaultMetaFactory, + configScheme, + configScheme, + kjson.SerializerOptions{ + Pretty: true, Strict: true, + }, + ) +} + +func newMux(configDecoder runtime.Decoder) *http.ServeMux { mux := http.NewServeMux() - mux.HandleFunc("/validate-resource-claim-parameters", serveResourceClaim) + mux.HandleFunc("/validate-resource-claim-parameters", serveResourceClaim(configDecoder)) mux.HandleFunc("/readyz", func(w http.ResponseWriter, req *http.Request) { _, err := w.Write([]byte("ok")) if err != nil { @@ -116,8 +138,10 @@ func newMux() *http.ServeMux { return mux } -func serveResourceClaim(w http.ResponseWriter, r *http.Request) { - serve(w, r, admitResourceClaimParameters) +func serveResourceClaim(configDecoder runtime.Decoder) func(http.ResponseWriter, *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + serve(w, r, admitResourceClaimParameters(configDecoder)) + } } // serve handles the http portion of a request prior to handing to an admit @@ -191,96 +215,98 @@ func readAdmissionReview(data []byte) (*admissionv1.AdmissionReview, error) { // admitResourceClaimParameters accepts both ResourceClaims and ResourceClaimTemplates and validates their // opaque device configuration parameters for this driver. -func admitResourceClaimParameters(ar admissionv1.AdmissionReview) *admissionv1.AdmissionResponse { - klog.V(2).Info("admitting resource claim parameters") - - var deviceConfigs []resourceapi.DeviceClaimConfiguration - var specPath string - - switch ar.Request.Resource { - case resourceClaimResourceV1, resourceClaimResourceV1Beta1, resourceClaimResourceV1Beta2: - claim, err := extractResourceClaim(ar) - if err != nil { - klog.Error(err) - return &admissionv1.AdmissionResponse{ - Result: &metav1.Status{ - Message: err.Error(), - Reason: metav1.StatusReasonBadRequest, - }, +func admitResourceClaimParameters(configDecoder runtime.Decoder) func(ar admissionv1.AdmissionReview) *admissionv1.AdmissionResponse { + return func(ar admissionv1.AdmissionReview) *admissionv1.AdmissionResponse { + klog.V(2).Info("admitting resource claim parameters") + + var deviceConfigs []resourceapi.DeviceClaimConfiguration + var specPath string + + switch ar.Request.Resource { + case resourceClaimResourceV1, resourceClaimResourceV1Beta1, resourceClaimResourceV1Beta2: + claim, err := extractResourceClaim(ar) + if err != nil { + klog.Error(err) + return &admissionv1.AdmissionResponse{ + Result: &metav1.Status{ + Message: err.Error(), + Reason: metav1.StatusReasonBadRequest, + }, + } } - } - deviceConfigs = claim.Spec.Devices.Config - specPath = "spec" - case resourceClaimTemplateResourceV1, resourceClaimTemplateResourceV1Beta1, resourceClaimTemplateResourceV1Beta2: - claimTemplate, err := extractResourceClaimTemplate(ar) - if err != nil { - klog.Error(err) + deviceConfigs = claim.Spec.Devices.Config + specPath = "spec" + case resourceClaimTemplateResourceV1, resourceClaimTemplateResourceV1Beta1, resourceClaimTemplateResourceV1Beta2: + claimTemplate, err := extractResourceClaimTemplate(ar) + if err != nil { + klog.Error(err) + return &admissionv1.AdmissionResponse{ + Result: &metav1.Status{ + Message: err.Error(), + Reason: metav1.StatusReasonBadRequest, + }, + } + } + deviceConfigs = claimTemplate.Spec.Spec.Devices.Config + specPath = "spec.spec" + default: + msg := fmt.Sprintf( + "expected resource to be one of %v, got %s", + []metav1.GroupVersionResource{ + resourceClaimResourceV1, resourceClaimResourceV1Beta1, resourceClaimResourceV1Beta2, + resourceClaimTemplateResourceV1, resourceClaimTemplateResourceV1Beta1, resourceClaimTemplateResourceV1Beta2, + }, + ar.Request.Resource, + ) + klog.Error(msg) return &admissionv1.AdmissionResponse{ Result: &metav1.Status{ - Message: err.Error(), + Message: msg, Reason: metav1.StatusReasonBadRequest, }, } } - deviceConfigs = claimTemplate.Spec.Spec.Devices.Config - specPath = "spec.spec" - default: - msg := fmt.Sprintf( - "expected resource to be one of %v, got %s", - []metav1.GroupVersionResource{ - resourceClaimResourceV1, resourceClaimResourceV1Beta1, resourceClaimResourceV1Beta2, - resourceClaimTemplateResourceV1, resourceClaimTemplateResourceV1Beta1, resourceClaimTemplateResourceV1Beta2, - }, - ar.Request.Resource, - ) - klog.Error(msg) - return &admissionv1.AdmissionResponse{ - Result: &metav1.Status{ - Message: msg, - Reason: metav1.StatusReasonBadRequest, - }, - } - } - var errs []error - for configIndex, config := range deviceConfigs { - if config.Opaque == nil || config.Opaque.Driver != consts.DriverName { - continue - } + var errs []error + for configIndex, config := range deviceConfigs { + if config.Opaque == nil || config.Opaque.Driver != consts.DriverName { + continue + } - fieldPath := fmt.Sprintf("%s.devices.config[%d].opaque.parameters", specPath, configIndex) - decodedConfig, err := runtime.Decode(configapi.Decoder, config.DeviceConfiguration.Opaque.Parameters.Raw) - if err != nil { - errs = append(errs, fmt.Errorf("error decoding object at %s: %w", fieldPath, err)) - continue - } - gpuConfig, ok := decodedConfig.(*configapi.GpuConfig) - if !ok { - errs = append(errs, fmt.Errorf("expected v1alpha1.GpuConfig at %s but got: %T", fieldPath, decodedConfig)) - continue - } - err = gpuConfig.Validate() - if err != nil { - errs = append(errs, fmt.Errorf("object at %s is invalid: %w", fieldPath, err)) + fieldPath := fmt.Sprintf("%s.devices.config[%d].opaque.parameters", specPath, configIndex) + decodedConfig, err := runtime.Decode(configDecoder, config.DeviceConfiguration.Opaque.Parameters.Raw) + if err != nil { + errs = append(errs, fmt.Errorf("error decoding object at %s: %w", fieldPath, err)) + continue + } + gpuConfig, ok := decodedConfig.(*configapi.GpuConfig) + if !ok { + errs = append(errs, fmt.Errorf("expected v1alpha1.GpuConfig at %s but got: %T", fieldPath, decodedConfig)) + continue + } + err = gpuConfig.Validate() + if err != nil { + errs = append(errs, fmt.Errorf("object at %s is invalid: %w", fieldPath, err)) + } } - } - if len(errs) > 0 { - var errMsgs []string - for _, err := range errs { - errMsgs = append(errMsgs, err.Error()) + if len(errs) > 0 { + var errMsgs []string + for _, err := range errs { + errMsgs = append(errMsgs, err.Error()) + } + msg := fmt.Sprintf("%d configs failed to validate: %s", len(errs), strings.Join(errMsgs, "; ")) + klog.Error(msg) + return &admissionv1.AdmissionResponse{ + Result: &metav1.Status{ + Message: msg, + Reason: metav1.StatusReason(metav1.StatusReasonInvalid), + }, + } } - msg := fmt.Sprintf("%d configs failed to validate: %s", len(errs), strings.Join(errMsgs, "; ")) - klog.Error(msg) + return &admissionv1.AdmissionResponse{ - Result: &metav1.Status{ - Message: msg, - Reason: metav1.StatusReason(metav1.StatusReasonInvalid), - }, + Allowed: true, } } - - return &admissionv1.AdmissionResponse{ - Allowed: true, - } } diff --git a/cmd/dra-example-webhook/main_test.go b/cmd/dra-example-webhook/main_test.go index 249c128a..dbfb1c03 100644 --- a/cmd/dra-example-webhook/main_test.go +++ b/cmd/dra-example-webhook/main_test.go @@ -40,7 +40,7 @@ import ( ) func TestReadyEndpoint(t *testing.T) { - s := httptest.NewServer(newMux()) + s := httptest.NewServer(newMux(nil)) t.Cleanup(s.Close) res, err := http.Get(s.URL + "/readyz") @@ -168,7 +168,10 @@ func TestResourceClaimValidatingWebhook(t *testing.T) { }, } - s := httptest.NewServer(newMux()) + sb := gpu.ConfigSchemeBuilder + assert.NoError(t, sb.AddToScheme(configScheme)) + + s := httptest.NewServer(newMux(newConfigDecoder())) t.Cleanup(s.Close) for name, test := range tests { From 05bd79e0ae27d6a9d873dc0ee120b1aa148eed2a Mon Sep 17 00:00:00 2001 From: Jon Huhn Date: Fri, 31 Oct 2025 22:01:17 -0500 Subject: [PATCH 2/9] Extract device-specific config logic --- cmd/dra-example-kubeletplugin/main.go | 3 + cmd/dra-example-kubeletplugin/state.go | 92 ++++---------------------- internal/profiles/gpu/gpu.go | 92 ++++++++++++++++++++++++++ internal/profiles/profiles.go | 21 ++++++ 4 files changed, 130 insertions(+), 78 deletions(-) create mode 100644 internal/profiles/gpu/gpu.go create mode 100644 internal/profiles/profiles.go diff --git a/cmd/dra-example-kubeletplugin/main.go b/cmd/dra-example-kubeletplugin/main.go index 80638217..029e2ff6 100644 --- a/cmd/dra-example-kubeletplugin/main.go +++ b/cmd/dra-example-kubeletplugin/main.go @@ -33,6 +33,7 @@ import ( "k8s.io/klog/v2" configapi "sigs.k8s.io/dra-example-driver/api/example.com/resource/gpu/v1alpha1" + "sigs.k8s.io/dra-example-driver/internal/profiles/gpu" "sigs.k8s.io/dra-example-driver/pkg/consts" "sigs.k8s.io/dra-example-driver/pkg/flags" ) @@ -60,6 +61,7 @@ type Config struct { // Config types configScheme *runtime.Scheme + applyConfigFunc ApplyConfigFunc } func (c Config) DriverPluginPath() string { @@ -156,6 +158,7 @@ func newApp() *cli.App { flags: flags, coreclient: clientSets.Core, configScheme: configScheme, + applyConfigFunc: gpu.ApplyConfig, // TODO: select an implementation based on the profile } return RunPlugin(ctx, config) diff --git a/cmd/dra-example-kubeletplugin/state.go b/cmd/dra-example-kubeletplugin/state.go index 74bc4539..ae5559f2 100644 --- a/cmd/dra-example-kubeletplugin/state.go +++ b/cmd/dra-example-kubeletplugin/state.go @@ -26,18 +26,17 @@ import ( "k8s.io/apimachinery/pkg/runtime/serializer/json" drapbv1 "k8s.io/kubelet/pkg/apis/dra/v1beta1" "k8s.io/kubernetes/pkg/kubelet/checkpointmanager" + cdiapi "tags.cncf.io/container-device-interface/pkg/cdi" - configapi "sigs.k8s.io/dra-example-driver/api/example.com/resource/gpu/v1alpha1" + "sigs.k8s.io/dra-example-driver/internal/profiles" "sigs.k8s.io/dra-example-driver/pkg/consts" - - cdiapi "tags.cncf.io/container-device-interface/pkg/cdi" - cdispec "tags.cncf.io/container-device-interface/specs-go" ) type AllocatableDevices map[string]resourceapi.Device type PreparedDevices []*PreparedDevice type PreparedClaims map[string]PreparedDevices -type PerDeviceCDIContainerEdits map[string]*cdiapi.ContainerEdits + +type ApplyConfigFunc func(cconfig runtime.Object, results []*resourceapi.DeviceRequestAllocationResult) (profiles.PerDeviceCDIContainerEdits, error) type OpaqueDeviceConfig struct { Requests []string @@ -63,6 +62,7 @@ type DeviceState struct { allocatable AllocatableDevices checkpointManager checkpointmanager.CheckpointManager configDecoder runtime.Decoder + applyConfigFunc ApplyConfigFunc } func NewDeviceState(config *Config) (*DeviceState, error) { @@ -101,6 +101,7 @@ func NewDeviceState(config *Config) (*DeviceState, error) { allocatable: allocatable, checkpointManager: checkpointManager, configDecoder: decoder, + applyConfigFunc: config.applyConfigFunc, } checkpoints, err := state.checkpointManager.ListCheckpoints() @@ -204,10 +205,7 @@ func (s *DeviceState) prepareDevices(claim *resourceapi.ResourceClaim) (Prepared // Add the default GPU Config to the front of the config list with the // lowest precedence. This guarantees there will be at least one config in // the list with len(Requests) == 0 for the lookup below. - configs = slices.Insert(configs, 0, &OpaqueDeviceConfig{ - Requests: []string{}, - Config: configapi.DefaultGpuConfig(), - }) + configs = slices.Insert(configs, 0, &OpaqueDeviceConfig{}) // Look through the configs and figure out which one will be applied to // each device allocation result based on their order of precedence. @@ -224,34 +222,15 @@ func (s *DeviceState) prepareDevices(claim *resourceapi.ResourceClaim) (Prepared } } - // Normalize, validate, and apply all configs associated with devices that - // need to be prepared. Track container edits generated from applying the - // config to the set of device allocation results. - perDeviceCDIContainerEdits := make(PerDeviceCDIContainerEdits) - for c, results := range configResultsMap { - // Cast the opaque config to a GpuConfig - var config *configapi.GpuConfig - switch castConfig := c.(type) { - case *configapi.GpuConfig: - config = castConfig - default: - return nil, fmt.Errorf("runtime object is not a regognized configuration") - } - - // Normalize the config to set any implied defaults. - if err := config.Normalize(); err != nil { - return nil, fmt.Errorf("error normalizing GPU config: %w", err) - } - - // Validate the config to ensure its integrity. - if err := config.Validate(); err != nil { - return nil, fmt.Errorf("error validating GPU config: %w", err) - } - + // Apply all configs associated with devices that need to be prepared. + // Track container edits generated from applying the config to the set + // of device allocation results. + perDeviceCDIContainerEdits := make(profiles.PerDeviceCDIContainerEdits) + for config, results := range configResultsMap { // Apply the config to the list of results associated with it. - containerEdits, err := s.applyConfig(config, results) + containerEdits, err := s.applyConfigFunc(config, results) if err != nil { - return nil, fmt.Errorf("error applying GPU config: %w", err) + return nil, fmt.Errorf("error applying config: %w", err) } // Merge any new container edits with the overall per device map. @@ -285,49 +264,6 @@ func (s *DeviceState) unprepareDevices(claimUID string, devices PreparedDevices) return nil } -// applyConfig applies a configuration to a set of device allocation results. -// -// In this example driver there is no actual configuration applied. We simply -// define a set of environment variables to be injected into the containers -// that include a given device. A real driver would likely need to do some sort -// of hardware configuration as well, based on the config passed in. -func (s *DeviceState) applyConfig(config *configapi.GpuConfig, results []*resourceapi.DeviceRequestAllocationResult) (PerDeviceCDIContainerEdits, error) { - perDeviceEdits := make(PerDeviceCDIContainerEdits) - - for _, result := range results { - envs := []string{ - fmt.Sprintf("GPU_DEVICE_%s=%s", result.Device[4:], result.Device), - } - - if config.Sharing != nil { - envs = append(envs, fmt.Sprintf("GPU_DEVICE_%s_SHARING_STRATEGY=%s", result.Device[4:], config.Sharing.Strategy)) - } - - switch { - case config.Sharing.IsTimeSlicing(): - tsconfig, err := config.Sharing.GetTimeSlicingConfig() - if err != nil { - return nil, fmt.Errorf("unable to get time slicing config for device %v: %w", result.Device, err) - } - envs = append(envs, fmt.Sprintf("GPU_DEVICE_%s_TIMESLICE_INTERVAL=%v", result.Device[4:], tsconfig.Interval)) - case config.Sharing.IsSpacePartitioning(): - spconfig, err := config.Sharing.GetSpacePartitioningConfig() - if err != nil { - return nil, fmt.Errorf("unable to get space partitioning config for device %v: %w", result.Device, err) - } - envs = append(envs, fmt.Sprintf("GPU_DEVICE_%s_PARTITION_COUNT=%v", result.Device[4:], spconfig.PartitionCount)) - } - - edits := &cdispec.ContainerEdits{ - Env: envs, - } - - perDeviceEdits[result.Device] = &cdiapi.ContainerEdits{ContainerEdits: edits} - } - - return perDeviceEdits, nil -} - // GetOpaqueDeviceConfigs returns an ordered list of the configs contained in possibleConfigs for this driver. // // Configs can either come from the resource claim itself or from the device diff --git a/internal/profiles/gpu/gpu.go b/internal/profiles/gpu/gpu.go new file mode 100644 index 00000000..c10e2e4f --- /dev/null +++ b/internal/profiles/gpu/gpu.go @@ -0,0 +1,92 @@ +/* + * Copyright The Kubernetes Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package gpu + +import ( + "fmt" + + resourceapi "k8s.io/api/resource/v1" + "k8s.io/apimachinery/pkg/runtime" + cdiapi "tags.cncf.io/container-device-interface/pkg/cdi" + cdispec "tags.cncf.io/container-device-interface/specs-go" + + configapi "sigs.k8s.io/dra-example-driver/api/example.com/resource/gpu/v1alpha1" + "sigs.k8s.io/dra-example-driver/internal/profiles" +) + +// ApplyConfig applies a configuration to a set of device allocation results. +// +// In this example driver there is no actual configuration applied. We simply +// define a set of environment variables to be injected into the containers +// that include a given device. A real driver would likely need to do some sort +// of hardware configuration as well, based on the config passed in. +func ApplyConfig(config runtime.Object, results []*resourceapi.DeviceRequestAllocationResult) (profiles.PerDeviceCDIContainerEdits, error) { + if config == nil { + config = configapi.DefaultGpuConfig() + } + if config, ok := config.(*configapi.GpuConfig); ok { + return applyGpuConfig(config, results) + } + return nil, fmt.Errorf("runtime object is not a recognized configuration") +} + +func applyGpuConfig(config *configapi.GpuConfig, results []*resourceapi.DeviceRequestAllocationResult) (profiles.PerDeviceCDIContainerEdits, error) { + perDeviceEdits := make(profiles.PerDeviceCDIContainerEdits) + + // Normalize the config to set any implied defaults. + if err := config.Normalize(); err != nil { + return nil, fmt.Errorf("error normalizing GPU config: %w", err) + } + + // Validate the config to ensure its integrity. + if err := config.Validate(); err != nil { + return nil, fmt.Errorf("error validating GPU config: %w", err) + } + + for _, result := range results { + envs := []string{ + fmt.Sprintf("GPU_DEVICE_%s=%s", result.Device[4:], result.Device), + } + + if config.Sharing != nil { + envs = append(envs, fmt.Sprintf("GPU_DEVICE_%s_SHARING_STRATEGY=%s", result.Device[4:], config.Sharing.Strategy)) + } + + switch { + case config.Sharing.IsTimeSlicing(): + tsconfig, err := config.Sharing.GetTimeSlicingConfig() + if err != nil { + return nil, fmt.Errorf("unable to get time slicing config for device %v: %w", result.Device, err) + } + envs = append(envs, fmt.Sprintf("GPU_DEVICE_%s_TIMESLICE_INTERVAL=%v", result.Device[4:], tsconfig.Interval)) + case config.Sharing.IsSpacePartitioning(): + spconfig, err := config.Sharing.GetSpacePartitioningConfig() + if err != nil { + return nil, fmt.Errorf("unable to get space partitioning config for device %v: %w", result.Device, err) + } + envs = append(envs, fmt.Sprintf("GPU_DEVICE_%s_PARTITION_COUNT=%v", result.Device[4:], spconfig.PartitionCount)) + } + + edits := &cdispec.ContainerEdits{ + Env: envs, + } + + perDeviceEdits[result.Device] = &cdiapi.ContainerEdits{ContainerEdits: edits} + } + + return perDeviceEdits, nil +} diff --git a/internal/profiles/profiles.go b/internal/profiles/profiles.go new file mode 100644 index 00000000..e8561bdc --- /dev/null +++ b/internal/profiles/profiles.go @@ -0,0 +1,21 @@ +/* + * Copyright The Kubernetes Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package profiles + +import cdiapi "tags.cncf.io/container-device-interface/pkg/cdi" + +type PerDeviceCDIContainerEdits map[string]*cdiapi.ContainerEdits From 9d7bd5108d412a4dc50bf96ba5638681eee0805e Mon Sep 17 00:00:00 2001 From: Jon Huhn Date: Fri, 14 Nov 2025 19:53:19 -0600 Subject: [PATCH 3/9] Extract device-specific CDI details --- cmd/dra-example-kubeletplugin/cdi.go | 52 +++++++++++++-------- cmd/dra-example-kubeletplugin/main.go | 14 ++++-- cmd/dra-example-kubeletplugin/state.go | 27 +++-------- cmd/dra-example-kubeletplugin/state_test.go | 6 ++- internal/profiles/gpu/gpu.go | 6 +++ internal/profiles/profiles.go | 20 +++++++- 6 files changed, 77 insertions(+), 48 deletions(-) diff --git a/cmd/dra-example-kubeletplugin/cdi.go b/cmd/dra-example-kubeletplugin/cdi.go index 7e701178..42086d93 100644 --- a/cmd/dra-example-kubeletplugin/cdi.go +++ b/cmd/dra-example-kubeletplugin/cdi.go @@ -19,35 +19,38 @@ package main import ( "fmt" "os" - - "sigs.k8s.io/dra-example-driver/pkg/consts" + "regexp" + "strings" cdiapi "tags.cncf.io/container-device-interface/pkg/cdi" cdiparser "tags.cncf.io/container-device-interface/pkg/parser" cdispec "tags.cncf.io/container-device-interface/specs-go" + + "sigs.k8s.io/dra-example-driver/internal/profiles" + "sigs.k8s.io/dra-example-driver/pkg/consts" ) -const ( - cdiVendor = "k8s." + consts.DriverName - cdiClass = "gpu" - cdiKind = cdiVendor + "/" + cdiClass +const cdiCommonDeviceName = "common" - cdiCommonDeviceName = "common" -) +var nonWord = regexp.MustCompile(`[^a-zA-Z0-9]+`) type CDIHandler struct { - cache *cdiapi.Cache + cache *cdiapi.Cache + driverName string + class string } -func NewCDIHandler(config *Config) (*CDIHandler, error) { +func NewCDIHandler(root string, driverName, class string) (*CDIHandler, error) { cache, err := cdiapi.NewCache( - cdiapi.WithSpecDirs(config.flags.cdiRoot), + cdiapi.WithSpecDirs(root), ) if err != nil { return nil, fmt.Errorf("unable to create a new CDI cache: %w", err) } handler := &CDIHandler{ - cache: cache, + cache: cache, + driverName: driverName, + class: class, } return handler, nil @@ -55,7 +58,7 @@ func NewCDIHandler(config *Config) (*CDIHandler, error) { func (cdi *CDIHandler) CreateCommonSpecFile() error { spec := &cdispec.Spec{ - Kind: cdiKind, + Kind: cdi.kind(), Devices: []cdispec.Device{ { Name: cdiCommonDeviceName, @@ -83,19 +86,20 @@ func (cdi *CDIHandler) CreateCommonSpecFile() error { return cdi.cache.WriteSpec(spec, specName) } -func (cdi *CDIHandler) CreateClaimSpecFile(claimUID string, devices PreparedDevices) error { - specName := cdiapi.GenerateTransientSpecName(cdiVendor, cdiClass, claimUID) +func (cdi *CDIHandler) CreateClaimSpecFile(claimUID string, devices profiles.PreparedDevices) error { + specName := cdiapi.GenerateTransientSpecName(cdi.vendor(), cdi.class, claimUID) spec := &cdispec.Spec{ - Kind: cdiKind, + Kind: cdi.kind(), Devices: []cdispec.Device{}, } for _, device := range devices { + deviceEnvKey := strings.ToUpper(nonWord.ReplaceAllString(device.DeviceName, "_")) claimEdits := cdiapi.ContainerEdits{ ContainerEdits: &cdispec.ContainerEdits{ Env: []string{ - fmt.Sprintf("GPU_DEVICE_%s_RESOURCE_CLAIM=%s", device.DeviceName[4:], claimUID), + fmt.Sprintf("%s_DEVICE_%s_RESOURCE_CLAIM=%s", strings.ToUpper(cdi.class), deviceEnvKey, claimUID), }, }, } @@ -119,19 +123,27 @@ func (cdi *CDIHandler) CreateClaimSpecFile(claimUID string, devices PreparedDevi } func (cdi *CDIHandler) DeleteClaimSpecFile(claimUID string) error { - specName := cdiapi.GenerateTransientSpecName(cdiVendor, cdiClass, claimUID) + specName := cdiapi.GenerateTransientSpecName(cdi.vendor(), cdi.class, claimUID) return cdi.cache.RemoveSpec(specName) } func (cdi *CDIHandler) GetClaimDevices(claimUID string, devices []string) []string { cdiDevices := []string{ - cdiparser.QualifiedName(cdiVendor, cdiClass, cdiCommonDeviceName), + cdiparser.QualifiedName(cdi.vendor(), cdi.class, cdiCommonDeviceName), } for _, device := range devices { - cdiDevice := cdiparser.QualifiedName(cdiVendor, cdiClass, fmt.Sprintf("%s-%s", claimUID, device)) + cdiDevice := cdiparser.QualifiedName(cdi.vendor(), cdi.class, fmt.Sprintf("%s-%s", claimUID, device)) cdiDevices = append(cdiDevices, cdiDevice) } return cdiDevices } + +func (cdi *CDIHandler) kind() string { + return cdi.vendor() + "/" + cdi.class +} + +func (cdi *CDIHandler) vendor() string { + return "k8s." + cdi.driverName +} diff --git a/cmd/dra-example-kubeletplugin/main.go b/cmd/dra-example-kubeletplugin/main.go index 029e2ff6..58863958 100644 --- a/cmd/dra-example-kubeletplugin/main.go +++ b/cmd/dra-example-kubeletplugin/main.go @@ -62,6 +62,8 @@ type Config struct { // Config types configScheme *runtime.Scheme applyConfigFunc ApplyConfigFunc + cdiVendor string + cdiClass string } func (c Config) DriverPluginPath() string { @@ -155,10 +157,14 @@ func newApp() *cli.App { } config := &Config{ - flags: flags, - coreclient: clientSets.Core, - configScheme: configScheme, - applyConfigFunc: gpu.ApplyConfig, // TODO: select an implementation based on the profile + flags: flags, + coreclient: clientSets.Core, + configScheme: configScheme, + + // TODO: select an implementation based on the profile + applyConfigFunc: gpu.ApplyConfig, + cdiVendor: gpu.CDIVendor, + cdiClass: gpu.CDIClass, } return RunPlugin(ctx, config) diff --git a/cmd/dra-example-kubeletplugin/state.go b/cmd/dra-example-kubeletplugin/state.go index ae5559f2..3e2f83e2 100644 --- a/cmd/dra-example-kubeletplugin/state.go +++ b/cmd/dra-example-kubeletplugin/state.go @@ -26,15 +26,13 @@ import ( "k8s.io/apimachinery/pkg/runtime/serializer/json" drapbv1 "k8s.io/kubelet/pkg/apis/dra/v1beta1" "k8s.io/kubernetes/pkg/kubelet/checkpointmanager" - cdiapi "tags.cncf.io/container-device-interface/pkg/cdi" "sigs.k8s.io/dra-example-driver/internal/profiles" "sigs.k8s.io/dra-example-driver/pkg/consts" ) type AllocatableDevices map[string]resourceapi.Device -type PreparedDevices []*PreparedDevice -type PreparedClaims map[string]PreparedDevices +type PreparedClaims map[string]profiles.PreparedDevices type ApplyConfigFunc func(cconfig runtime.Object, results []*resourceapi.DeviceRequestAllocationResult) (profiles.PerDeviceCDIContainerEdits, error) @@ -43,19 +41,6 @@ type OpaqueDeviceConfig struct { Config runtime.Object } -type PreparedDevice struct { - drapbv1.Device - ContainerEdits *cdiapi.ContainerEdits -} - -func (pds PreparedDevices) GetDevices() []*drapbv1.Device { - var devices []*drapbv1.Device - for _, pd := range pds { - devices = append(devices, &pd.Device) - } - return devices -} - type DeviceState struct { sync.Mutex cdi *CDIHandler @@ -71,7 +56,7 @@ func NewDeviceState(config *Config) (*DeviceState, error) { return nil, fmt.Errorf("error enumerating all possible devices: %v", err) } - cdi, err := NewCDIHandler(config) + cdi, err := NewCDIHandler(config.flags.cdiRoot, consts.DriverName, config.cdiClass) if err != nil { return nil, fmt.Errorf("unable to create CDI handler: %v", err) } @@ -187,7 +172,7 @@ func (s *DeviceState) Unprepare(claimUID string) error { return nil } -func (s *DeviceState) prepareDevices(claim *resourceapi.ResourceClaim) (PreparedDevices, error) { +func (s *DeviceState) prepareDevices(claim *resourceapi.ResourceClaim) (profiles.PreparedDevices, error) { if claim.Status.Allocation == nil { return nil, fmt.Errorf("claim not yet allocated") } @@ -241,10 +226,10 @@ func (s *DeviceState) prepareDevices(claim *resourceapi.ResourceClaim) (Prepared // Walk through each config and its associated device allocation results // and construct the list of prepared devices to return. - var preparedDevices PreparedDevices + var preparedDevices profiles.PreparedDevices for _, results := range configResultsMap { for _, result := range results { - device := &PreparedDevice{ + device := &profiles.PreparedDevice{ Device: drapbv1.Device{ RequestNames: []string{result.Request}, PoolName: result.Pool, @@ -260,7 +245,7 @@ func (s *DeviceState) prepareDevices(claim *resourceapi.ResourceClaim) (Prepared return preparedDevices, nil } -func (s *DeviceState) unprepareDevices(claimUID string, devices PreparedDevices) error { +func (s *DeviceState) unprepareDevices(claimUID string, devices profiles.PreparedDevices) error { return nil } diff --git a/cmd/dra-example-kubeletplugin/state_test.go b/cmd/dra-example-kubeletplugin/state_test.go index 8490ebb5..d23b3859 100644 --- a/cmd/dra-example-kubeletplugin/state_test.go +++ b/cmd/dra-example-kubeletplugin/state_test.go @@ -21,12 +21,14 @@ import ( "github.com/stretchr/testify/assert" + "sigs.k8s.io/dra-example-driver/internal/profiles" + drapbv1 "k8s.io/kubelet/pkg/apis/dra/v1beta1" ) func TestPreparedDevicesGetDevices(t *testing.T) { tests := map[string]struct { - preparedDevices PreparedDevices + preparedDevices profiles.PreparedDevices expected []*drapbv1.Device }{ "nil PreparedDevices": { @@ -34,7 +36,7 @@ func TestPreparedDevicesGetDevices(t *testing.T) { expected: nil, }, "several PreparedDevices": { - preparedDevices: PreparedDevices{ + preparedDevices: profiles.PreparedDevices{ {Device: drapbv1.Device{DeviceName: "dev1"}}, {Device: drapbv1.Device{DeviceName: "dev2"}}, {Device: drapbv1.Device{DeviceName: "dev3"}}, diff --git a/internal/profiles/gpu/gpu.go b/internal/profiles/gpu/gpu.go index c10e2e4f..fdb6efff 100644 --- a/internal/profiles/gpu/gpu.go +++ b/internal/profiles/gpu/gpu.go @@ -26,6 +26,12 @@ import ( configapi "sigs.k8s.io/dra-example-driver/api/example.com/resource/gpu/v1alpha1" "sigs.k8s.io/dra-example-driver/internal/profiles" + "sigs.k8s.io/dra-example-driver/pkg/consts" +) + +const ( + CDIVendor = "k8s." + consts.DriverName + CDIClass = "gpu" ) // ApplyConfig applies a configuration to a set of device allocation results. diff --git a/internal/profiles/profiles.go b/internal/profiles/profiles.go index e8561bdc..7ad0326a 100644 --- a/internal/profiles/profiles.go +++ b/internal/profiles/profiles.go @@ -16,6 +16,24 @@ package profiles -import cdiapi "tags.cncf.io/container-device-interface/pkg/cdi" +import ( + drapbv1 "k8s.io/kubelet/pkg/apis/dra/v1beta1" + cdiapi "tags.cncf.io/container-device-interface/pkg/cdi" +) type PerDeviceCDIContainerEdits map[string]*cdiapi.ContainerEdits + +type PreparedDevice struct { + drapbv1.Device + ContainerEdits *cdiapi.ContainerEdits +} + +type PreparedDevices []*PreparedDevice + +func (pds PreparedDevices) GetDevices() []*drapbv1.Device { + var devices []*drapbv1.Device + for _, pd := range pds { + devices = append(devices, &pd.Device) + } + return devices +} From 0321de723c2b636be021b69f8466b76ca618834d Mon Sep 17 00:00:00 2001 From: Jon Huhn Date: Mon, 17 Nov 2025 17:14:00 -0600 Subject: [PATCH 4/9] Extract device discovery --- cmd/dra-example-kubeletplugin/discovery.go | 84 ---------------------- cmd/dra-example-kubeletplugin/driver.go | 20 +----- cmd/dra-example-kubeletplugin/main.go | 16 +++-- cmd/dra-example-kubeletplugin/state.go | 12 +++- internal/profiles/gpu/gpu.go | 75 +++++++++++++++++++ 5 files changed, 96 insertions(+), 111 deletions(-) delete mode 100644 cmd/dra-example-kubeletplugin/discovery.go diff --git a/cmd/dra-example-kubeletplugin/discovery.go b/cmd/dra-example-kubeletplugin/discovery.go deleted file mode 100644 index 0c45431f..00000000 --- a/cmd/dra-example-kubeletplugin/discovery.go +++ /dev/null @@ -1,84 +0,0 @@ -/* - * Copyright 2023 The Kubernetes Authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package main - -import ( - "fmt" - "math/rand" - "os" - - resourceapi "k8s.io/api/resource/v1" - "k8s.io/apimachinery/pkg/api/resource" - "k8s.io/utils/ptr" - - "github.com/google/uuid" -) - -func enumerateAllPossibleDevices(numGPUs int) (AllocatableDevices, error) { - seed := os.Getenv("NODE_NAME") - uuids := generateUUIDs(seed, numGPUs) - - alldevices := make(AllocatableDevices) - for i, uuid := range uuids { - device := resourceapi.Device{ - Name: fmt.Sprintf("gpu-%d", i), - Attributes: map[resourceapi.QualifiedName]resourceapi.DeviceAttribute{ - "index": { - IntValue: ptr.To(int64(i)), - }, - "uuid": { - StringValue: ptr.To(uuid), - }, - "model": { - StringValue: ptr.To("LATEST-GPU-MODEL"), - }, - "driverVersion": { - VersionValue: ptr.To("1.0.0"), - }, - }, - Capacity: map[resourceapi.QualifiedName]resourceapi.DeviceCapacity{ - "memory": { - Value: resource.MustParse("80Gi"), - }, - }, - } - alldevices[device.Name] = device - } - return alldevices, nil -} - -func generateUUIDs(seed string, count int) []string { - rand := rand.New(rand.NewSource(hash(seed))) - - uuids := make([]string, count) - for i := 0; i < count; i++ { - charset := make([]byte, 16) - rand.Read(charset) - uuid, _ := uuid.FromBytes(charset) - uuids[i] = "gpu-" + uuid.String() - } - - return uuids -} - -func hash(s string) int64 { - h := int64(0) - for _, c := range s { - h = 31*h + int64(c) - } - return h -} diff --git a/cmd/dra-example-kubeletplugin/driver.go b/cmd/dra-example-kubeletplugin/driver.go index 0f6e224c..c5cf7101 100644 --- a/cmd/dra-example-kubeletplugin/driver.go +++ b/cmd/dra-example-kubeletplugin/driver.go @@ -20,14 +20,12 @@ import ( "context" "errors" "fmt" - "maps" resourceapi "k8s.io/api/resource/v1" "k8s.io/apimachinery/pkg/types" utilruntime "k8s.io/apimachinery/pkg/util/runtime" coreclientset "k8s.io/client-go/kubernetes" "k8s.io/dynamic-resource-allocation/kubeletplugin" - "k8s.io/dynamic-resource-allocation/resourceslice" "k8s.io/klog/v2" "sigs.k8s.io/dra-example-driver/pkg/consts" @@ -67,28 +65,12 @@ func NewDriver(ctx context.Context, config *Config) (*driver, error) { } driver.helper = helper - devices := make([]resourceapi.Device, 0, len(state.allocatable)) - for device := range maps.Values(state.allocatable) { - devices = append(devices, device) - } - resources := resourceslice.DriverResources{ - Pools: map[string]resourceslice.Pool{ - config.flags.nodeName: { - Slices: []resourceslice.Slice{ - { - Devices: devices, - }, - }, - }, - }, - } - driver.healthcheck, err = startHealthcheck(ctx, config) if err != nil { return nil, fmt.Errorf("start healthcheck: %w", err) } - if err := helper.PublishResources(ctx, resources); err != nil { + if err := helper.PublishResources(ctx, state.driverResources); err != nil { return nil, err } diff --git a/cmd/dra-example-kubeletplugin/main.go b/cmd/dra-example-kubeletplugin/main.go index 58863958..53bdc274 100644 --- a/cmd/dra-example-kubeletplugin/main.go +++ b/cmd/dra-example-kubeletplugin/main.go @@ -30,6 +30,7 @@ import ( "k8s.io/apimachinery/pkg/runtime" coreclientset "k8s.io/client-go/kubernetes" "k8s.io/dynamic-resource-allocation/kubeletplugin" + "k8s.io/dynamic-resource-allocation/resourceslice" "k8s.io/klog/v2" configapi "sigs.k8s.io/dra-example-driver/api/example.com/resource/gpu/v1alpha1" @@ -60,10 +61,10 @@ type Config struct { cancelMainCtx func(error) // Config types - configScheme *runtime.Scheme - applyConfigFunc ApplyConfigFunc - cdiVendor string - cdiClass string + configScheme *runtime.Scheme + applyConfigFunc ApplyConfigFunc + cdiClass string + enumerateDevicesFunc func() (resourceslice.DriverResources, error) } func (c Config) DriverPluginPath() string { @@ -162,9 +163,10 @@ func newApp() *cli.App { configScheme: configScheme, // TODO: select an implementation based on the profile - applyConfigFunc: gpu.ApplyConfig, - cdiVendor: gpu.CDIVendor, - cdiClass: gpu.CDIClass, + applyConfigFunc: gpu.ApplyConfig, + cdiVendor: gpu.CDIVendor, + cdiClass: gpu.CDIClass, + enumerateDevicesFunc: gpu.EnumerateAllPossibleDevices(flags.nodeName, flags.numDevices), } return RunPlugin(ctx, config) diff --git a/cmd/dra-example-kubeletplugin/state.go b/cmd/dra-example-kubeletplugin/state.go index 3e2f83e2..27a81f34 100644 --- a/cmd/dra-example-kubeletplugin/state.go +++ b/cmd/dra-example-kubeletplugin/state.go @@ -24,6 +24,7 @@ import ( resourceapi "k8s.io/api/resource/v1" "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/runtime/serializer/json" + "k8s.io/dynamic-resource-allocation/resourceslice" drapbv1 "k8s.io/kubelet/pkg/apis/dra/v1beta1" "k8s.io/kubernetes/pkg/kubelet/checkpointmanager" @@ -44,6 +45,7 @@ type OpaqueDeviceConfig struct { type DeviceState struct { sync.Mutex cdi *CDIHandler + driverResources resourceslice.DriverResources allocatable AllocatableDevices checkpointManager checkpointmanager.CheckpointManager configDecoder runtime.Decoder @@ -51,7 +53,7 @@ type DeviceState struct { } func NewDeviceState(config *Config) (*DeviceState, error) { - allocatable, err := enumerateAllPossibleDevices(config.flags.numDevices) + driverResources, err := config.enumerateDevicesFunc() if err != nil { return nil, fmt.Errorf("error enumerating all possible devices: %v", err) } @@ -81,8 +83,16 @@ func NewDeviceState(config *Config) (*DeviceState, error) { }, ) + allocatable := make(AllocatableDevices) + for _, slice := range driverResources.Pools[config.flags.nodeName].Slices { + for _, device := range slice.Devices { + allocatable[device.Name] = device + } + } + state := &DeviceState{ cdi: cdi, + driverResources: driverResources, allocatable: allocatable, checkpointManager: checkpointManager, configDecoder: decoder, diff --git a/internal/profiles/gpu/gpu.go b/internal/profiles/gpu/gpu.go index fdb6efff..4c329b23 100644 --- a/internal/profiles/gpu/gpu.go +++ b/internal/profiles/gpu/gpu.go @@ -18,9 +18,14 @@ package gpu import ( "fmt" + "math/rand" + "github.com/google/uuid" resourceapi "k8s.io/api/resource/v1" + "k8s.io/apimachinery/pkg/api/resource" "k8s.io/apimachinery/pkg/runtime" + "k8s.io/dynamic-resource-allocation/resourceslice" + "k8s.io/utils/ptr" cdiapi "tags.cncf.io/container-device-interface/pkg/cdi" cdispec "tags.cncf.io/container-device-interface/specs-go" @@ -34,6 +39,76 @@ const ( CDIClass = "gpu" ) +func EnumerateAllPossibleDevices(nodeName string, numGPUs int) func() (resourceslice.DriverResources, error) { + return func() (resourceslice.DriverResources, error) { + seed := nodeName + uuids := generateUUIDs(seed, numGPUs) + + var devices []resourceapi.Device + for i, uuid := range uuids { + device := resourceapi.Device{ + Name: fmt.Sprintf("gpu-%d", i), + Attributes: map[resourceapi.QualifiedName]resourceapi.DeviceAttribute{ + "index": { + IntValue: ptr.To(int64(i)), + }, + "uuid": { + StringValue: ptr.To(uuid), + }, + "model": { + StringValue: ptr.To("LATEST-GPU-MODEL"), + }, + "driverVersion": { + VersionValue: ptr.To("1.0.0"), + }, + }, + Capacity: map[resourceapi.QualifiedName]resourceapi.DeviceCapacity{ + "memory": { + Value: resource.MustParse("80Gi"), + }, + }, + } + devices = append(devices, device) + } + + resources := resourceslice.DriverResources{ + Pools: map[string]resourceslice.Pool{ + nodeName: { + Slices: []resourceslice.Slice{ + { + Devices: devices, + }, + }, + }, + }, + } + + return resources, nil + } +} + +func generateUUIDs(seed string, count int) []string { + rand := rand.New(rand.NewSource(hash(seed))) + + uuids := make([]string, count) + for i := 0; i < count; i++ { + charset := make([]byte, 16) + rand.Read(charset) + uuid, _ := uuid.FromBytes(charset) + uuids[i] = "gpu-" + uuid.String() + } + + return uuids +} + +func hash(s string) int64 { + h := int64(0) + for _, c := range s { + h = 31*h + int64(c) + } + return h +} + // ApplyConfig applies a configuration to a set of device allocation results. // // In this example driver there is no actual configuration applied. We simply From dc09f71140e2f509c6bb22c4bd10ea6f74845296 Mon Sep 17 00:00:00 2001 From: Jon Huhn Date: Tue, 18 Nov 2025 16:34:05 -0600 Subject: [PATCH 5/9] Extract webhook config validation --- cmd/dra-example-webhook/main.go | 30 ++++++++++++++++------------ cmd/dra-example-webhook/main_test.go | 5 +++-- internal/profiles/gpu/gpu.go | 8 ++++++++ 3 files changed, 28 insertions(+), 15 deletions(-) diff --git a/cmd/dra-example-webhook/main.go b/cmd/dra-example-webhook/main.go index 69b9191e..3629ceb4 100644 --- a/cmd/dra-example-webhook/main.go +++ b/cmd/dra-example-webhook/main.go @@ -34,6 +34,7 @@ import ( "k8s.io/klog/v2" configapi "sigs.k8s.io/dra-example-driver/api/example.com/resource/gpu/v1alpha1" + "sigs.k8s.io/dra-example-driver/internal/profiles/gpu" "sigs.k8s.io/dra-example-driver/pkg/consts" "sigs.k8s.io/dra-example-driver/pkg/flags" ) @@ -48,6 +49,8 @@ type Flags struct { var configScheme = runtime.NewScheme() +type validator func(runtime.Object) error + func main() { if err := newApp().Run(os.Args); err != nil { fmt.Fprintf(os.Stderr, "Error: %v\n", err) @@ -94,15 +97,21 @@ func newApp() *cli.App { return flags.loggingConfig.Apply() }, Action: func(c *cli.Context) error { - sb := runtime.NewSchemeBuilder( + gpuSchemeBuilder := runtime.NewSchemeBuilder( configapi.AddToScheme, ) + gpuValidator := gpu.ValidateConfig + + // TODO: select based on profile + sb := gpuSchemeBuilder + validate := gpuValidator + if err := sb.AddToScheme(configScheme); err != nil { return fmt.Errorf("create config scheme: %w", err) } server := &http.Server{ - Handler: newMux(newConfigDecoder()), + Handler: newMux(newConfigDecoder(), validate), Addr: fmt.Sprintf(":%d", flags.port), } klog.Info("starting webhook server on", server.Addr) @@ -125,9 +134,9 @@ func newConfigDecoder() runtime.Decoder { ) } -func newMux(configDecoder runtime.Decoder) *http.ServeMux { +func newMux(configDecoder runtime.Decoder, validate validator) *http.ServeMux { mux := http.NewServeMux() - mux.HandleFunc("/validate-resource-claim-parameters", serveResourceClaim(configDecoder)) + mux.HandleFunc("/validate-resource-claim-parameters", serveResourceClaim(configDecoder, validate)) mux.HandleFunc("/readyz", func(w http.ResponseWriter, req *http.Request) { _, err := w.Write([]byte("ok")) if err != nil { @@ -138,9 +147,9 @@ func newMux(configDecoder runtime.Decoder) *http.ServeMux { return mux } -func serveResourceClaim(configDecoder runtime.Decoder) func(http.ResponseWriter, *http.Request) { +func serveResourceClaim(configDecoder runtime.Decoder, validate validator) func(http.ResponseWriter, *http.Request) { return func(w http.ResponseWriter, r *http.Request) { - serve(w, r, admitResourceClaimParameters(configDecoder)) + serve(w, r, admitResourceClaimParameters(configDecoder, validate)) } } @@ -215,7 +224,7 @@ func readAdmissionReview(data []byte) (*admissionv1.AdmissionReview, error) { // admitResourceClaimParameters accepts both ResourceClaims and ResourceClaimTemplates and validates their // opaque device configuration parameters for this driver. -func admitResourceClaimParameters(configDecoder runtime.Decoder) func(ar admissionv1.AdmissionReview) *admissionv1.AdmissionResponse { +func admitResourceClaimParameters(configDecoder runtime.Decoder, validate validator) func(ar admissionv1.AdmissionReview) *admissionv1.AdmissionResponse { return func(ar admissionv1.AdmissionReview) *admissionv1.AdmissionResponse { klog.V(2).Info("admitting resource claim parameters") @@ -279,12 +288,7 @@ func admitResourceClaimParameters(configDecoder runtime.Decoder) func(ar admissi errs = append(errs, fmt.Errorf("error decoding object at %s: %w", fieldPath, err)) continue } - gpuConfig, ok := decodedConfig.(*configapi.GpuConfig) - if !ok { - errs = append(errs, fmt.Errorf("expected v1alpha1.GpuConfig at %s but got: %T", fieldPath, decodedConfig)) - continue - } - err = gpuConfig.Validate() + err = validate(decodedConfig) if err != nil { errs = append(errs, fmt.Errorf("object at %s is invalid: %w", fieldPath, err)) } diff --git a/cmd/dra-example-webhook/main_test.go b/cmd/dra-example-webhook/main_test.go index dbfb1c03..51cc3a06 100644 --- a/cmd/dra-example-webhook/main_test.go +++ b/cmd/dra-example-webhook/main_test.go @@ -36,11 +36,12 @@ import ( "k8s.io/apimachinery/pkg/runtime/schema" configapi "sigs.k8s.io/dra-example-driver/api/example.com/resource/gpu/v1alpha1" + "sigs.k8s.io/dra-example-driver/internal/profiles/gpu" "sigs.k8s.io/dra-example-driver/pkg/consts" ) func TestReadyEndpoint(t *testing.T) { - s := httptest.NewServer(newMux(nil)) + s := httptest.NewServer(newMux(nil, nil)) t.Cleanup(s.Close) res, err := http.Get(s.URL + "/readyz") @@ -171,7 +172,7 @@ func TestResourceClaimValidatingWebhook(t *testing.T) { sb := gpu.ConfigSchemeBuilder assert.NoError(t, sb.AddToScheme(configScheme)) - s := httptest.NewServer(newMux(newConfigDecoder())) + s := httptest.NewServer(newMux(newConfigDecoder(), gpu.ValidateConfig)) t.Cleanup(s.Close) for name, test := range tests { diff --git a/internal/profiles/gpu/gpu.go b/internal/profiles/gpu/gpu.go index 4c329b23..ae1dc990 100644 --- a/internal/profiles/gpu/gpu.go +++ b/internal/profiles/gpu/gpu.go @@ -109,6 +109,14 @@ func hash(s string) int64 { return h } +func ValidateConfig(config runtime.Object) error { + gpuConfig, ok := config.(*configapi.GpuConfig) + if !ok { + return fmt.Errorf("expected v1alpha1.GpuConfig but got: %T", config) + } + return gpuConfig.Validate() +} + // ApplyConfig applies a configuration to a set of device allocation results. // // In this example driver there is no actual configuration applied. We simply From b1ea5873c548e4472edd6616a7f7a1adc709b0d2 Mon Sep 17 00:00:00 2001 From: Jon Huhn Date: Wed, 19 Nov 2025 11:03:33 -0600 Subject: [PATCH 6/9] Add --device-profile CLI flag --- cmd/dra-example-kubeletplugin/main.go | 53 ++++++++++++++++++--------- cmd/dra-example-webhook/main.go | 30 +++++++++++---- internal/profiles/gpu/gpu.go | 6 +++ 3 files changed, 64 insertions(+), 25 deletions(-) diff --git a/cmd/dra-example-kubeletplugin/main.go b/cmd/dra-example-kubeletplugin/main.go index 53bdc274..7e292ae7 100644 --- a/cmd/dra-example-kubeletplugin/main.go +++ b/cmd/dra-example-kubeletplugin/main.go @@ -33,7 +33,6 @@ import ( "k8s.io/dynamic-resource-allocation/resourceslice" "k8s.io/klog/v2" - configapi "sigs.k8s.io/dra-example-driver/api/example.com/resource/gpu/v1alpha1" "sigs.k8s.io/dra-example-driver/internal/profiles/gpu" "sigs.k8s.io/dra-example-driver/pkg/consts" "sigs.k8s.io/dra-example-driver/pkg/flags" @@ -53,6 +52,7 @@ type Flags struct { kubeletRegistrarDirectoryPath string kubeletPluginsDirectoryPath string healthcheckPort int + profile string } type Config struct { @@ -60,13 +60,16 @@ type Config struct { coreclient coreclientset.Interface cancelMainCtx func(error) - // Config types - configScheme *runtime.Scheme + configScheme *runtime.Scheme // scheme for opaque config types applyConfigFunc ApplyConfigFunc cdiClass string enumerateDevicesFunc func() (resourceslice.DriverResources, error) } +var validProfiles = []string{ + gpu.ProfileName, +} + func (c Config) DriverPluginPath() string { return filepath.Join(c.flags.kubeletPluginsDirectoryPath, consts.DriverName) } @@ -99,7 +102,7 @@ func newApp() *cli.App { }, &cli.IntFlag{ Name: "num-devices", - Usage: "The number of devices to be generated.", + Usage: "The number of devices to be generated. Only relevant for the " + gpu.ProfileName + " profile.", Value: 8, Destination: &flags.numDevices, EnvVars: []string{"NUM_DEVICES"}, @@ -125,6 +128,13 @@ func newApp() *cli.App { Destination: &flags.healthcheckPort, EnvVars: []string{"HEALTHCHECK_PORT"}, }, + &cli.StringFlag{ + Name: "device-profile", + Usage: fmt.Sprintf("Name of the device profile. Valid values are %q.", validProfiles), + Value: gpu.ProfileName, + Destination: &flags.profile, + EnvVars: []string{"DEVICE_PROFILE"}, + }, } cliFlags = append(cliFlags, flags.kubeClientConfig.Flags()...) cliFlags = append(cliFlags, flags.loggingConfig.Flags()...) @@ -148,25 +158,34 @@ func newApp() *cli.App { return fmt.Errorf("create client: %w", err) } - configScheme := runtime.NewScheme() - sb := runtime.NewSchemeBuilder( - // TODO: only add the API versions that apply to a given profile - configapi.AddToScheme, + var ( + sb runtime.SchemeBuilder + applyConfigFunc ApplyConfigFunc + cdiClass string + enumerateDevicesFunc func() (resourceslice.DriverResources, error) ) + switch flags.profile { + case gpu.ProfileName: + sb = gpu.ConfigSchemeBuilder + applyConfigFunc = gpu.ApplyConfig + cdiClass = gpu.CDIClass + enumerateDevicesFunc = gpu.EnumerateAllPossibleDevices(flags.nodeName, flags.numDevices) + default: + return fmt.Errorf("invalid device profile %q, valid profiles are %q", flags.profile, validProfiles) + } + + configScheme := runtime.NewScheme() if err := sb.AddToScheme(configScheme); err != nil { return fmt.Errorf("create config scheme: %w", err) } config := &Config{ - flags: flags, - coreclient: clientSets.Core, - configScheme: configScheme, - - // TODO: select an implementation based on the profile - applyConfigFunc: gpu.ApplyConfig, - cdiVendor: gpu.CDIVendor, - cdiClass: gpu.CDIClass, - enumerateDevicesFunc: gpu.EnumerateAllPossibleDevices(flags.nodeName, flags.numDevices), + flags: flags, + coreclient: clientSets.Core, + configScheme: configScheme, + applyConfigFunc: applyConfigFunc, + cdiClass: cdiClass, + enumerateDevicesFunc: enumerateDevicesFunc, } return RunPlugin(ctx, config) diff --git a/cmd/dra-example-webhook/main.go b/cmd/dra-example-webhook/main.go index 3629ceb4..e4400f4b 100644 --- a/cmd/dra-example-webhook/main.go +++ b/cmd/dra-example-webhook/main.go @@ -33,7 +33,6 @@ import ( kjson "k8s.io/apimachinery/pkg/runtime/serializer/json" "k8s.io/klog/v2" - configapi "sigs.k8s.io/dra-example-driver/api/example.com/resource/gpu/v1alpha1" "sigs.k8s.io/dra-example-driver/internal/profiles/gpu" "sigs.k8s.io/dra-example-driver/pkg/consts" "sigs.k8s.io/dra-example-driver/pkg/flags" @@ -45,12 +44,17 @@ type Flags struct { certFile string keyFile string port int + profile string } var configScheme = runtime.NewScheme() type validator func(runtime.Object) error +var validProfiles = []string{ + gpu.ProfileName, +} + func main() { if err := newApp().Run(os.Args); err != nil { fmt.Fprintf(os.Stderr, "Error: %v\n", err) @@ -81,6 +85,13 @@ func newApp() *cli.App { Value: 443, Destination: &flags.port, }, + &cli.StringFlag{ + Name: "device-profile", + Usage: fmt.Sprintf("Name of the device profile. Valid values are %q.", validProfiles), + Value: gpu.ProfileName, + Destination: &flags.profile, + EnvVars: []string{"DEVICE_PROFILE"}, + }, } cliFlags = append(cliFlags, flags.loggingConfig.Flags()...) @@ -97,14 +108,17 @@ func newApp() *cli.App { return flags.loggingConfig.Apply() }, Action: func(c *cli.Context) error { - gpuSchemeBuilder := runtime.NewSchemeBuilder( - configapi.AddToScheme, + var ( + sb runtime.SchemeBuilder + validate validator ) - gpuValidator := gpu.ValidateConfig - - // TODO: select based on profile - sb := gpuSchemeBuilder - validate := gpuValidator + switch flags.profile { + case gpu.ProfileName: + sb = gpu.ConfigSchemeBuilder + validate = gpu.ValidateConfig + default: + return fmt.Errorf("invalid device profile %q, valid profiles are %q", flags.profile, validProfiles) + } if err := sb.AddToScheme(configScheme); err != nil { return fmt.Errorf("create config scheme: %w", err) diff --git a/internal/profiles/gpu/gpu.go b/internal/profiles/gpu/gpu.go index ae1dc990..0e093a1b 100644 --- a/internal/profiles/gpu/gpu.go +++ b/internal/profiles/gpu/gpu.go @@ -34,11 +34,17 @@ import ( "sigs.k8s.io/dra-example-driver/pkg/consts" ) +const ProfileName = "gpu" + const ( CDIVendor = "k8s." + consts.DriverName CDIClass = "gpu" ) +var ConfigSchemeBuilder = runtime.NewSchemeBuilder( + configapi.AddToScheme, +) + func EnumerateAllPossibleDevices(nodeName string, numGPUs int) func() (resourceslice.DriverResources, error) { return func() (resourceslice.DriverResources, error) { seed := nodeName From 15653a6c0fc5b11d74310113f1dda30ada047383 Mon Sep 17 00:00:00 2001 From: Jon Huhn Date: Wed, 19 Nov 2025 12:54:11 -0600 Subject: [PATCH 7/9] Add deviceProfile value to chart --- .../templates/kubeletplugin.yaml | 2 ++ .../templates/webhook-deployment.yaml | 1 + .../dra-example-driver/values.schema.json | 11 ++++++++ .../helm/dra-example-driver/values.yaml | 26 +++++-------------- 4 files changed, 21 insertions(+), 19 deletions(-) create mode 100644 deployments/helm/dra-example-driver/values.schema.json diff --git a/deployments/helm/dra-example-driver/templates/kubeletplugin.yaml b/deployments/helm/dra-example-driver/templates/kubeletplugin.yaml index bcf44ffd..1e0e2001 100644 --- a/deployments/helm/dra-example-driver/templates/kubeletplugin.yaml +++ b/deployments/helm/dra-example-driver/templates/kubeletplugin.yaml @@ -59,6 +59,8 @@ spec: periodSeconds: 10 {{- end }} env: + - name: DEVICE_PROFILE + value: {{ .Values.deviceProfile | quote }} - name: CDI_ROOT value: /var/run/cdi - name: KUBELET_REGISTRAR_DIRECTORY_PATH diff --git a/deployments/helm/dra-example-driver/templates/webhook-deployment.yaml b/deployments/helm/dra-example-driver/templates/webhook-deployment.yaml index 1f574da5..f4915ea6 100644 --- a/deployments/helm/dra-example-driver/templates/webhook-deployment.yaml +++ b/deployments/helm/dra-example-driver/templates/webhook-deployment.yaml @@ -43,6 +43,7 @@ spec: - --tls-cert-file=/cert/tls.crt - --tls-private-key-file=/cert/tls.key - --port={{ .Values.webhook.containerPort }} + - --device-profile={{ .Values.deviceProfile }} ports: - name: webhook containerPort: {{ .Values.webhook.containerPort }} diff --git a/deployments/helm/dra-example-driver/values.schema.json b/deployments/helm/dra-example-driver/values.schema.json new file mode 100644 index 00000000..3dff0a3a --- /dev/null +++ b/deployments/helm/dra-example-driver/values.schema.json @@ -0,0 +1,11 @@ +{ + "type": "object", + "properties": { + "deviceProfile": { + "type": "string", + "enum": [ + "gpu" + ] + } + } +} diff --git a/deployments/helm/dra-example-driver/values.yaml b/deployments/helm/dra-example-driver/values.yaml index e58c9d33..a23609a5 100644 --- a/deployments/helm/dra-example-driver/values.yaml +++ b/deployments/helm/dra-example-driver/values.yaml @@ -9,6 +9,11 @@ selectorLabelsOverride: {} allowDefaultNamespace: false +# deviceProfile describes the overall shape of the devices managed by the +# driver. Available profiles are: +# - "gpu": Node-local devices configurable through opaque config +deviceProfile: "gpu" + imagePullSecrets: [] image: repository: registry.k8s.io/dra-example-driver/dra-example-driver @@ -25,26 +30,9 @@ serviceAccount: # If not set and create is true, a name is generated using the fullname template name: "" -controller: - priorityClassName: "system-node-critical" - podAnnotations: {} - podSecurityContext: {} - nodeSelector: - node-role.kubernetes.io/control-plane: "" - tolerations: - - key: node-role.kubernetes.io/master - operator: Exists - effect: NoSchedule - - key: node-role.kubernetes.io/control-plane - operator: Exists - effect: NoSchedule - affinity: {} - containers: - controller: - securityContext: {} - resources: {} - kubeletPlugin: + # numDevices describes how many GPUs to advertise on each node when the "gpu" + # deviceProfile is used. Not relevant for other profiles. numDevices: 8 priorityClassName: "system-node-critical" updateStrategy: From 07790fb335b995a6df8ff86848073c0f34121931 Mon Sep 17 00:00:00 2001 From: Jon Huhn Date: Wed, 19 Nov 2025 16:18:23 -0600 Subject: [PATCH 8/9] Set driver name based on device profile --- cmd/dra-example-kubeletplugin/cdi.go | 3 +- cmd/dra-example-kubeletplugin/driver.go | 4 +-- cmd/dra-example-kubeletplugin/health.go | 4 +-- cmd/dra-example-kubeletplugin/main.go | 14 ++++++-- cmd/dra-example-kubeletplugin/state.go | 13 ++++--- cmd/dra-example-webhook/main.go | 34 ++++++++++++------- cmd/dra-example-webhook/main_test.go | 9 ++--- .../dra-example-driver/templates/_helpers.tpl | 7 ++++ .../templates/deviceclass.yaml | 6 ++-- .../templates/kubeletplugin.yaml | 2 ++ .../templates/webhook-deployment.yaml | 1 + .../helm/dra-example-driver/values.yaml | 4 +++ internal/profiles/gpu/gpu.go | 6 +--- pkg/consts/consts.go | 19 ----------- 14 files changed, 69 insertions(+), 57 deletions(-) delete mode 100644 pkg/consts/consts.go diff --git a/cmd/dra-example-kubeletplugin/cdi.go b/cmd/dra-example-kubeletplugin/cdi.go index 42086d93..e92688e0 100644 --- a/cmd/dra-example-kubeletplugin/cdi.go +++ b/cmd/dra-example-kubeletplugin/cdi.go @@ -27,7 +27,6 @@ import ( cdispec "tags.cncf.io/container-device-interface/specs-go" "sigs.k8s.io/dra-example-driver/internal/profiles" - "sigs.k8s.io/dra-example-driver/pkg/consts" ) const cdiCommonDeviceName = "common" @@ -65,7 +64,7 @@ func (cdi *CDIHandler) CreateCommonSpecFile() error { ContainerEdits: cdispec.ContainerEdits{ Env: []string{ fmt.Sprintf("KUBERNETES_NODE_NAME=%s", os.Getenv("NODE_NAME")), - fmt.Sprintf("DRA_RESOURCE_DRIVER_NAME=%s", consts.DriverName), + fmt.Sprintf("DRA_RESOURCE_DRIVER_NAME=%s", cdi.driverName), }, }, }, diff --git a/cmd/dra-example-kubeletplugin/driver.go b/cmd/dra-example-kubeletplugin/driver.go index c5cf7101..d0ebf77f 100644 --- a/cmd/dra-example-kubeletplugin/driver.go +++ b/cmd/dra-example-kubeletplugin/driver.go @@ -27,8 +27,6 @@ import ( coreclientset "k8s.io/client-go/kubernetes" "k8s.io/dynamic-resource-allocation/kubeletplugin" "k8s.io/klog/v2" - - "sigs.k8s.io/dra-example-driver/pkg/consts" ) type driver struct { @@ -56,7 +54,7 @@ func NewDriver(ctx context.Context, config *Config) (*driver, error) { driver, kubeletplugin.KubeClient(config.coreclient), kubeletplugin.NodeName(config.flags.nodeName), - kubeletplugin.DriverName(consts.DriverName), + kubeletplugin.DriverName(config.flags.driverName), kubeletplugin.RegistrarDirectoryPath(config.flags.kubeletRegistrarDirectoryPath), kubeletplugin.PluginDataDirectoryPath(config.DriverPluginPath()), ) diff --git a/cmd/dra-example-kubeletplugin/health.go b/cmd/dra-example-kubeletplugin/health.go index 14d393aa..b5ddf7da 100644 --- a/cmd/dra-example-kubeletplugin/health.go +++ b/cmd/dra-example-kubeletplugin/health.go @@ -33,8 +33,6 @@ import ( "k8s.io/klog/v2" drapb "k8s.io/kubelet/pkg/apis/dra/v1" registerapi "k8s.io/kubelet/pkg/apis/pluginregistration/v1" - - "sigs.k8s.io/dra-example-driver/pkg/consts" ) type healthcheck struct { @@ -65,7 +63,7 @@ func startHealthcheck(ctx context.Context, config *Config) (*healthcheck, error) Scheme: "unix", // TODO: this needs to adapt when seamless upgrades // are enabled and the filename includes a uid. - Path: path.Join(config.flags.kubeletRegistrarDirectoryPath, consts.DriverName+"-reg.sock"), + Path: path.Join(config.flags.kubeletRegistrarDirectoryPath, config.flags.driverName+"-reg.sock"), }).String() log.Info("connecting to registration socket", "path", regSockPath) regConn, err := grpc.NewClient( diff --git a/cmd/dra-example-kubeletplugin/main.go b/cmd/dra-example-kubeletplugin/main.go index 7e292ae7..f68193d5 100644 --- a/cmd/dra-example-kubeletplugin/main.go +++ b/cmd/dra-example-kubeletplugin/main.go @@ -34,7 +34,6 @@ import ( "k8s.io/klog/v2" "sigs.k8s.io/dra-example-driver/internal/profiles/gpu" - "sigs.k8s.io/dra-example-driver/pkg/consts" "sigs.k8s.io/dra-example-driver/pkg/flags" ) @@ -53,6 +52,7 @@ type Flags struct { kubeletPluginsDirectoryPath string healthcheckPort int profile string + driverName string } type Config struct { @@ -71,7 +71,7 @@ var validProfiles = []string{ } func (c Config) DriverPluginPath() string { - return filepath.Join(c.flags.kubeletPluginsDirectoryPath, consts.DriverName) + return filepath.Join(c.flags.kubeletPluginsDirectoryPath, c.flags.driverName) } func main() { @@ -135,6 +135,12 @@ func newApp() *cli.App { Destination: &flags.profile, EnvVars: []string{"DEVICE_PROFILE"}, }, + &cli.StringFlag{ + Name: "driver-name", + Usage: "Name of the DRA driver. Its default is derived from the device profile.", + Destination: &flags.driverName, + EnvVars: []string{"DRIVER_NAME"}, + }, } cliFlags = append(cliFlags, flags.kubeClientConfig.Flags()...) cliFlags = append(cliFlags, flags.loggingConfig.Flags()...) @@ -158,6 +164,10 @@ func newApp() *cli.App { return fmt.Errorf("create client: %w", err) } + if flags.driverName == "" { + flags.driverName = flags.profile + ".example.com" + } + var ( sb runtime.SchemeBuilder applyConfigFunc ApplyConfigFunc diff --git a/cmd/dra-example-kubeletplugin/state.go b/cmd/dra-example-kubeletplugin/state.go index 27a81f34..48bf2d54 100644 --- a/cmd/dra-example-kubeletplugin/state.go +++ b/cmd/dra-example-kubeletplugin/state.go @@ -29,7 +29,6 @@ import ( "k8s.io/kubernetes/pkg/kubelet/checkpointmanager" "sigs.k8s.io/dra-example-driver/internal/profiles" - "sigs.k8s.io/dra-example-driver/pkg/consts" ) type AllocatableDevices map[string]resourceapi.Device @@ -44,6 +43,7 @@ type OpaqueDeviceConfig struct { type DeviceState struct { sync.Mutex + driverName string cdi *CDIHandler driverResources resourceslice.DriverResources allocatable AllocatableDevices @@ -58,7 +58,7 @@ func NewDeviceState(config *Config) (*DeviceState, error) { return nil, fmt.Errorf("error enumerating all possible devices: %v", err) } - cdi, err := NewCDIHandler(config.flags.cdiRoot, consts.DriverName, config.cdiClass) + cdi, err := NewCDIHandler(config.flags.cdiRoot, config.flags.driverName, config.cdiClass) if err != nil { return nil, fmt.Errorf("unable to create CDI handler: %v", err) } @@ -91,6 +91,7 @@ func NewDeviceState(config *Config) (*DeviceState, error) { } state := &DeviceState{ + driverName: config.flags.driverName, cdi: cdi, driverResources: driverResources, allocatable: allocatable, @@ -190,7 +191,7 @@ func (s *DeviceState) prepareDevices(claim *resourceapi.ResourceClaim) (profiles // Retrieve the full set of device configs for the driver. configs, err := GetOpaqueDeviceConfigs( s.configDecoder, - consts.DriverName, + s.driverName, claim.Status.Allocation.Devices.Config, ) if err != nil { @@ -206,8 +207,12 @@ func (s *DeviceState) prepareDevices(claim *resourceapi.ResourceClaim) (profiles // each device allocation result based on their order of precedence. configResultsMap := make(map[runtime.Object][]*resourceapi.DeviceRequestAllocationResult) for _, result := range claim.Status.Allocation.Devices.Results { + // The claim may include allocations meant for other drivers. + if result.Driver != s.driverName { + continue + } if _, exists := s.allocatable[result.Device]; !exists { - return nil, fmt.Errorf("requested GPU is not allocatable: %v", result.Device) + return nil, fmt.Errorf("requested device is not allocatable: %v", result.Device) } for _, c := range slices.Backward(configs) { if len(c.Requests) == 0 || slices.Contains(c.Requests, result.Request) { diff --git a/cmd/dra-example-webhook/main.go b/cmd/dra-example-webhook/main.go index e4400f4b..be5fe807 100644 --- a/cmd/dra-example-webhook/main.go +++ b/cmd/dra-example-webhook/main.go @@ -34,17 +34,17 @@ import ( "k8s.io/klog/v2" "sigs.k8s.io/dra-example-driver/internal/profiles/gpu" - "sigs.k8s.io/dra-example-driver/pkg/consts" "sigs.k8s.io/dra-example-driver/pkg/flags" ) type Flags struct { loggingConfig *flags.LoggingConfig - certFile string - keyFile string - port int - profile string + certFile string + keyFile string + port int + profile string + driverName string } var configScheme = runtime.NewScheme() @@ -92,6 +92,12 @@ func newApp() *cli.App { Destination: &flags.profile, EnvVars: []string{"DEVICE_PROFILE"}, }, + &cli.StringFlag{ + Name: "driver-name", + Usage: "Name of the DRA driver. Its default is derived from the device profile.", + Destination: &flags.driverName, + EnvVars: []string{"DRIVER_NAME"}, + }, } cliFlags = append(cliFlags, flags.loggingConfig.Flags()...) @@ -120,12 +126,16 @@ func newApp() *cli.App { return fmt.Errorf("invalid device profile %q, valid profiles are %q", flags.profile, validProfiles) } + if flags.driverName == "" { + flags.driverName = flags.profile + ".example.com" + } + if err := sb.AddToScheme(configScheme); err != nil { return fmt.Errorf("create config scheme: %w", err) } server := &http.Server{ - Handler: newMux(newConfigDecoder(), validate), + Handler: newMux(newConfigDecoder(), validate, flags.driverName), Addr: fmt.Sprintf(":%d", flags.port), } klog.Info("starting webhook server on", server.Addr) @@ -148,9 +158,9 @@ func newConfigDecoder() runtime.Decoder { ) } -func newMux(configDecoder runtime.Decoder, validate validator) *http.ServeMux { +func newMux(configDecoder runtime.Decoder, validate validator, driverName string) *http.ServeMux { mux := http.NewServeMux() - mux.HandleFunc("/validate-resource-claim-parameters", serveResourceClaim(configDecoder, validate)) + mux.HandleFunc("/validate-resource-claim-parameters", serveResourceClaim(configDecoder, validate, driverName)) mux.HandleFunc("/readyz", func(w http.ResponseWriter, req *http.Request) { _, err := w.Write([]byte("ok")) if err != nil { @@ -161,9 +171,9 @@ func newMux(configDecoder runtime.Decoder, validate validator) *http.ServeMux { return mux } -func serveResourceClaim(configDecoder runtime.Decoder, validate validator) func(http.ResponseWriter, *http.Request) { +func serveResourceClaim(configDecoder runtime.Decoder, validate validator, driverName string) func(http.ResponseWriter, *http.Request) { return func(w http.ResponseWriter, r *http.Request) { - serve(w, r, admitResourceClaimParameters(configDecoder, validate)) + serve(w, r, admitResourceClaimParameters(configDecoder, validate, driverName)) } } @@ -238,7 +248,7 @@ func readAdmissionReview(data []byte) (*admissionv1.AdmissionReview, error) { // admitResourceClaimParameters accepts both ResourceClaims and ResourceClaimTemplates and validates their // opaque device configuration parameters for this driver. -func admitResourceClaimParameters(configDecoder runtime.Decoder, validate validator) func(ar admissionv1.AdmissionReview) *admissionv1.AdmissionResponse { +func admitResourceClaimParameters(configDecoder runtime.Decoder, validate validator, driverName string) func(ar admissionv1.AdmissionReview) *admissionv1.AdmissionResponse { return func(ar admissionv1.AdmissionReview) *admissionv1.AdmissionResponse { klog.V(2).Info("admitting resource claim parameters") @@ -292,7 +302,7 @@ func admitResourceClaimParameters(configDecoder runtime.Decoder, validate valida var errs []error for configIndex, config := range deviceConfigs { - if config.Opaque == nil || config.Opaque.Driver != consts.DriverName { + if config.Opaque == nil || config.Opaque.Driver != driverName { continue } diff --git a/cmd/dra-example-webhook/main_test.go b/cmd/dra-example-webhook/main_test.go index 51cc3a06..7fcaf580 100644 --- a/cmd/dra-example-webhook/main_test.go +++ b/cmd/dra-example-webhook/main_test.go @@ -37,11 +37,12 @@ import ( configapi "sigs.k8s.io/dra-example-driver/api/example.com/resource/gpu/v1alpha1" "sigs.k8s.io/dra-example-driver/internal/profiles/gpu" - "sigs.k8s.io/dra-example-driver/pkg/consts" ) +const driverName = "gpu.example.com" + func TestReadyEndpoint(t *testing.T) { - s := httptest.NewServer(newMux(nil, nil)) + s := httptest.NewServer(newMux(nil, nil, "")) t.Cleanup(s.Close) res, err := http.Get(s.URL + "/readyz") @@ -172,7 +173,7 @@ func TestResourceClaimValidatingWebhook(t *testing.T) { sb := gpu.ConfigSchemeBuilder assert.NoError(t, sb.AddToScheme(configScheme)) - s := httptest.NewServer(newMux(newConfigDecoder(), gpu.ValidateConfig)) + s := httptest.NewServer(newMux(newConfigDecoder(), gpu.ValidateConfig, driverName)) t.Cleanup(s.Close) for name, test := range tests { @@ -253,7 +254,7 @@ func resourceClaimSpecWithGpuConfigs(gpuConfigs ...*configapi.GpuConfig) resourc deviceConfig := resourceapi.DeviceClaimConfiguration{ DeviceConfiguration: resourceapi.DeviceConfiguration{ Opaque: &resourceapi.OpaqueDeviceConfiguration{ - Driver: consts.DriverName, + Driver: driverName, Parameters: runtime.RawExtension{ Object: gpuConfig, }, diff --git a/deployments/helm/dra-example-driver/templates/_helpers.tpl b/deployments/helm/dra-example-driver/templates/_helpers.tpl index 63ec53df..6198ba65 100644 --- a/deployments/helm/dra-example-driver/templates/_helpers.tpl +++ b/deployments/helm/dra-example-driver/templates/_helpers.tpl @@ -121,3 +121,10 @@ resource.k8s.io/v1beta1 {{- else -}} {{- end -}} {{- end -}} + +{{/* +The driver name. +*/}} +{{- define "dra-example-driver.driverName" -}} +{{ default (print .Values.deviceProfile ".example.com") .Values.driverName }} +{{- end -}} diff --git a/deployments/helm/dra-example-driver/templates/deviceclass.yaml b/deployments/helm/dra-example-driver/templates/deviceclass.yaml index 8fd7085b..e2bb69cf 100644 --- a/deployments/helm/dra-example-driver/templates/deviceclass.yaml +++ b/deployments/helm/dra-example-driver/templates/deviceclass.yaml @@ -2,8 +2,8 @@ apiVersion: {{ include "dra-example-driver.resourceApiVersion" . }} kind: DeviceClass metadata: - name: gpu.example.com + name: {{ include "dra-example-driver.driverName" . }} spec: selectors: - - cel: - expression: "device.driver == 'gpu.example.com'" + - cel: + expression: "device.driver == '{{ include "dra-example-driver.driverName" . }}'" diff --git a/deployments/helm/dra-example-driver/templates/kubeletplugin.yaml b/deployments/helm/dra-example-driver/templates/kubeletplugin.yaml index 1e0e2001..5785a135 100644 --- a/deployments/helm/dra-example-driver/templates/kubeletplugin.yaml +++ b/deployments/helm/dra-example-driver/templates/kubeletplugin.yaml @@ -59,6 +59,8 @@ spec: periodSeconds: 10 {{- end }} env: + - name: DRIVER_NAME + value: {{ include "dra-example-driver.driverName" . | quote }} - name: DEVICE_PROFILE value: {{ .Values.deviceProfile | quote }} - name: CDI_ROOT diff --git a/deployments/helm/dra-example-driver/templates/webhook-deployment.yaml b/deployments/helm/dra-example-driver/templates/webhook-deployment.yaml index f4915ea6..920e879f 100644 --- a/deployments/helm/dra-example-driver/templates/webhook-deployment.yaml +++ b/deployments/helm/dra-example-driver/templates/webhook-deployment.yaml @@ -44,6 +44,7 @@ spec: - --tls-private-key-file=/cert/tls.key - --port={{ .Values.webhook.containerPort }} - --device-profile={{ .Values.deviceProfile }} + - --driver-name={{ include "dra-example-driver.driverName" . }} ports: - name: webhook containerPort: {{ .Values.webhook.containerPort }} diff --git a/deployments/helm/dra-example-driver/values.yaml b/deployments/helm/dra-example-driver/values.yaml index a23609a5..9ddededa 100644 --- a/deployments/helm/dra-example-driver/values.yaml +++ b/deployments/helm/dra-example-driver/values.yaml @@ -14,6 +14,10 @@ allowDefaultNamespace: false # - "gpu": Node-local devices configurable through opaque config deviceProfile: "gpu" +# driverName uniquely identifies the driver within the cluster. When empty, its +# value is derived from the deviceProfile. +driverName: "" + imagePullSecrets: [] image: repository: registry.k8s.io/dra-example-driver/dra-example-driver diff --git a/internal/profiles/gpu/gpu.go b/internal/profiles/gpu/gpu.go index 0e093a1b..068c72bd 100644 --- a/internal/profiles/gpu/gpu.go +++ b/internal/profiles/gpu/gpu.go @@ -31,15 +31,11 @@ import ( configapi "sigs.k8s.io/dra-example-driver/api/example.com/resource/gpu/v1alpha1" "sigs.k8s.io/dra-example-driver/internal/profiles" - "sigs.k8s.io/dra-example-driver/pkg/consts" ) const ProfileName = "gpu" -const ( - CDIVendor = "k8s." + consts.DriverName - CDIClass = "gpu" -) +const CDIClass = "gpu" var ConfigSchemeBuilder = runtime.NewSchemeBuilder( configapi.AddToScheme, diff --git a/pkg/consts/consts.go b/pkg/consts/consts.go deleted file mode 100644 index 23dacd2e..00000000 --- a/pkg/consts/consts.go +++ /dev/null @@ -1,19 +0,0 @@ -/* - * Copyright 2025 The Kubernetes Authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package consts - -const DriverName = "gpu.example.com" From 3fbde91a200bb901f8cff4612baff92e34770abc Mon Sep 17 00:00:00 2001 From: Jon Huhn Date: Tue, 2 Dec 2025 15:37:47 -0600 Subject: [PATCH 9/9] Refactor profiles into interfaces --- cmd/dra-example-kubeletplugin/main.go | 55 +++++------- cmd/dra-example-kubeletplugin/state.go | 23 +++-- cmd/dra-example-webhook/main.go | 63 +++++++------- cmd/dra-example-webhook/main_test.go | 11 +-- internal/profiles/gpu/gpu.go | 111 ++++++++++++++----------- internal/profiles/profiles.go | 45 ++++++++++ 6 files changed, 179 insertions(+), 129 deletions(-) diff --git a/cmd/dra-example-kubeletplugin/main.go b/cmd/dra-example-kubeletplugin/main.go index f68193d5..757e7fa9 100644 --- a/cmd/dra-example-kubeletplugin/main.go +++ b/cmd/dra-example-kubeletplugin/main.go @@ -27,12 +27,11 @@ import ( "github.com/urfave/cli/v2" - "k8s.io/apimachinery/pkg/runtime" coreclientset "k8s.io/client-go/kubernetes" "k8s.io/dynamic-resource-allocation/kubeletplugin" - "k8s.io/dynamic-resource-allocation/resourceslice" "k8s.io/klog/v2" + "sigs.k8s.io/dra-example-driver/internal/profiles" "sigs.k8s.io/dra-example-driver/internal/profiles/gpu" "sigs.k8s.io/dra-example-driver/pkg/flags" ) @@ -60,16 +59,23 @@ type Config struct { coreclient coreclientset.Interface cancelMainCtx func(error) - configScheme *runtime.Scheme // scheme for opaque config types - applyConfigFunc ApplyConfigFunc - cdiClass string - enumerateDevicesFunc func() (resourceslice.DriverResources, error) + profile profiles.Profile } -var validProfiles = []string{ - gpu.ProfileName, +var validProfiles = map[string]func(flags Flags) profiles.Profile{ + gpu.ProfileName: func(flags Flags) profiles.Profile { + return gpu.NewProfile(flags.nodeName, flags.numDevices) + }, } +var validProfileNames = func() []string { + var valid []string + for profileName := range validProfiles { + valid = append(valid, profileName) + } + return valid +}() + func (c Config) DriverPluginPath() string { return filepath.Join(c.flags.kubeletPluginsDirectoryPath, c.flags.driverName) } @@ -130,7 +136,7 @@ func newApp() *cli.App { }, &cli.StringFlag{ Name: "device-profile", - Usage: fmt.Sprintf("Name of the device profile. Valid values are %q.", validProfiles), + Usage: fmt.Sprintf("Name of the device profile. Valid values are %q.", validProfileNames), Value: gpu.ProfileName, Destination: &flags.profile, EnvVars: []string{"DEVICE_PROFILE"}, @@ -168,34 +174,15 @@ func newApp() *cli.App { flags.driverName = flags.profile + ".example.com" } - var ( - sb runtime.SchemeBuilder - applyConfigFunc ApplyConfigFunc - cdiClass string - enumerateDevicesFunc func() (resourceslice.DriverResources, error) - ) - switch flags.profile { - case gpu.ProfileName: - sb = gpu.ConfigSchemeBuilder - applyConfigFunc = gpu.ApplyConfig - cdiClass = gpu.CDIClass - enumerateDevicesFunc = gpu.EnumerateAllPossibleDevices(flags.nodeName, flags.numDevices) - default: - return fmt.Errorf("invalid device profile %q, valid profiles are %q", flags.profile, validProfiles) - } - - configScheme := runtime.NewScheme() - if err := sb.AddToScheme(configScheme); err != nil { - return fmt.Errorf("create config scheme: %w", err) + newProfile, ok := validProfiles[flags.profile] + if !ok { + return fmt.Errorf("invalid device profile %q, valid profiles are %q", flags.profile, validProfileNames) } config := &Config{ - flags: flags, - coreclient: clientSets.Core, - configScheme: configScheme, - applyConfigFunc: applyConfigFunc, - cdiClass: cdiClass, - enumerateDevicesFunc: enumerateDevicesFunc, + flags: flags, + coreclient: clientSets.Core, + profile: newProfile(*flags), } return RunPlugin(ctx, config) diff --git a/cmd/dra-example-kubeletplugin/state.go b/cmd/dra-example-kubeletplugin/state.go index 48bf2d54..26ea7cf8 100644 --- a/cmd/dra-example-kubeletplugin/state.go +++ b/cmd/dra-example-kubeletplugin/state.go @@ -34,8 +34,6 @@ import ( type AllocatableDevices map[string]resourceapi.Device type PreparedClaims map[string]profiles.PreparedDevices -type ApplyConfigFunc func(cconfig runtime.Object, results []*resourceapi.DeviceRequestAllocationResult) (profiles.PerDeviceCDIContainerEdits, error) - type OpaqueDeviceConfig struct { Requests []string Config runtime.Object @@ -49,16 +47,16 @@ type DeviceState struct { allocatable AllocatableDevices checkpointManager checkpointmanager.CheckpointManager configDecoder runtime.Decoder - applyConfigFunc ApplyConfigFunc + configHandler profiles.ConfigHandler } func NewDeviceState(config *Config) (*DeviceState, error) { - driverResources, err := config.enumerateDevicesFunc() + driverResources, err := config.profile.EnumerateDevices() if err != nil { return nil, fmt.Errorf("error enumerating all possible devices: %v", err) } - cdi, err := NewCDIHandler(config.flags.cdiRoot, config.flags.driverName, config.cdiClass) + cdi, err := NewCDIHandler(config.flags.cdiRoot, config.flags.driverName, config.flags.profile) if err != nil { return nil, fmt.Errorf("unable to create CDI handler: %v", err) } @@ -73,11 +71,18 @@ func NewDeviceState(config *Config) (*DeviceState, error) { return nil, fmt.Errorf("unable to create checkpoint manager: %v", err) } + configScheme := runtime.NewScheme() + configHandler := config.profile + sb := configHandler.SchemeBuilder() + if err := sb.AddToScheme(configScheme); err != nil { + return nil, fmt.Errorf("create config scheme: %w", err) + } + // Set up a json serializer to decode our types. decoder := json.NewSerializerWithOptions( json.DefaultMetaFactory, - config.configScheme, - config.configScheme, + configScheme, + configScheme, json.SerializerOptions{ Pretty: true, Strict: true, }, @@ -97,7 +102,7 @@ func NewDeviceState(config *Config) (*DeviceState, error) { allocatable: allocatable, checkpointManager: checkpointManager, configDecoder: decoder, - applyConfigFunc: config.applyConfigFunc, + configHandler: configHandler, } checkpoints, err := state.checkpointManager.ListCheckpoints() @@ -228,7 +233,7 @@ func (s *DeviceState) prepareDevices(claim *resourceapi.ResourceClaim) (profiles perDeviceCDIContainerEdits := make(profiles.PerDeviceCDIContainerEdits) for config, results := range configResultsMap { // Apply the config to the list of results associated with it. - containerEdits, err := s.applyConfigFunc(config, results) + containerEdits, err := s.configHandler.ApplyConfig(config, results) if err != nil { return nil, fmt.Errorf("error applying config: %w", err) } diff --git a/cmd/dra-example-webhook/main.go b/cmd/dra-example-webhook/main.go index be5fe807..9e1383c1 100644 --- a/cmd/dra-example-webhook/main.go +++ b/cmd/dra-example-webhook/main.go @@ -33,6 +33,7 @@ import ( kjson "k8s.io/apimachinery/pkg/runtime/serializer/json" "k8s.io/klog/v2" + "sigs.k8s.io/dra-example-driver/internal/profiles" "sigs.k8s.io/dra-example-driver/internal/profiles/gpu" "sigs.k8s.io/dra-example-driver/pkg/flags" ) @@ -47,12 +48,10 @@ type Flags struct { driverName string } -var configScheme = runtime.NewScheme() - type validator func(runtime.Object) error -var validProfiles = []string{ - gpu.ProfileName, +var validProfiles = map[string]profiles.ConfigHandler{ + gpu.ProfileName: gpu.Profile{}, } func main() { @@ -114,28 +113,26 @@ func newApp() *cli.App { return flags.loggingConfig.Apply() }, Action: func(c *cli.Context) error { - var ( - sb runtime.SchemeBuilder - validate validator - ) - switch flags.profile { - case gpu.ProfileName: - sb = gpu.ConfigSchemeBuilder - validate = gpu.ValidateConfig - default: - return fmt.Errorf("invalid device profile %q, valid profiles are %q", flags.profile, validProfiles) + configHandler, ok := validProfiles[flags.profile] + if !ok { + var valid []string + for profileName := range validProfiles { + valid = append(valid, profileName) + } + return fmt.Errorf("invalid device profile %q, valid profiles are %q", flags.profile, valid) } if flags.driverName == "" { flags.driverName = flags.profile + ".example.com" } - if err := sb.AddToScheme(configScheme); err != nil { - return fmt.Errorf("create config scheme: %w", err) + mux, err := newMux(configHandler, flags.driverName) + if err != nil { + return fmt.Errorf("create HTTP mux: %w", err) } server := &http.Server{ - Handler: newMux(newConfigDecoder(), validate, flags.driverName), + Handler: mux, Addr: fmt.Sprintf(":%d", flags.port), } klog.Info("starting webhook server on", server.Addr) @@ -146,9 +143,13 @@ func newApp() *cli.App { return app } -func newConfigDecoder() runtime.Decoder { - // Set up a json serializer to decode our types. - return kjson.NewSerializerWithOptions( +func newMux(configHandler profiles.ConfigHandler, driverName string) (*http.ServeMux, error) { + configScheme := runtime.NewScheme() + sb := configHandler.SchemeBuilder() + if err := sb.AddToScheme(configScheme); err != nil { + return nil, fmt.Errorf("create config scheme: %w", err) + } + configDecoder := kjson.NewSerializerWithOptions( kjson.DefaultMetaFactory, configScheme, configScheme, @@ -156,19 +157,19 @@ func newConfigDecoder() runtime.Decoder { Pretty: true, Strict: true, }, ) -} -func newMux(configDecoder runtime.Decoder, validate validator, driverName string) *http.ServeMux { mux := http.NewServeMux() - mux.HandleFunc("/validate-resource-claim-parameters", serveResourceClaim(configDecoder, validate, driverName)) - mux.HandleFunc("/readyz", func(w http.ResponseWriter, req *http.Request) { - _, err := w.Write([]byte("ok")) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - }) - return mux + mux.HandleFunc("/validate-resource-claim-parameters", serveResourceClaim(configDecoder, configHandler.Validate, driverName)) + mux.HandleFunc("/readyz", readyHandler) + return mux, nil +} + +func readyHandler(w http.ResponseWriter, req *http.Request) { + _, err := w.Write([]byte("ok")) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } } func serveResourceClaim(configDecoder runtime.Decoder, validate validator, driverName string) func(http.ResponseWriter, *http.Request) { diff --git a/cmd/dra-example-webhook/main_test.go b/cmd/dra-example-webhook/main_test.go index 7fcaf580..e8408fae 100644 --- a/cmd/dra-example-webhook/main_test.go +++ b/cmd/dra-example-webhook/main_test.go @@ -42,10 +42,10 @@ import ( const driverName = "gpu.example.com" func TestReadyEndpoint(t *testing.T) { - s := httptest.NewServer(newMux(nil, nil, "")) + s := httptest.NewServer(http.HandlerFunc(readyHandler)) t.Cleanup(s.Close) - res, err := http.Get(s.URL + "/readyz") + res, err := http.Get(s.URL) assert.NoError(t, err) assert.Equal(t, http.StatusOK, res.StatusCode) } @@ -170,10 +170,11 @@ func TestResourceClaimValidatingWebhook(t *testing.T) { }, } - sb := gpu.ConfigSchemeBuilder - assert.NoError(t, sb.AddToScheme(configScheme)) + configHandler := gpu.Profile{} + mux, err := newMux(configHandler, driverName) + assert.NoError(t, err) - s := httptest.NewServer(newMux(newConfigDecoder(), gpu.ValidateConfig, driverName)) + s := httptest.NewServer(mux) t.Cleanup(s.Close) for name, test := range tests { diff --git a/internal/profiles/gpu/gpu.go b/internal/profiles/gpu/gpu.go index 068c72bd..fd2401e7 100644 --- a/internal/profiles/gpu/gpu.go +++ b/internal/profiles/gpu/gpu.go @@ -35,58 +35,62 @@ import ( const ProfileName = "gpu" -const CDIClass = "gpu" +type Profile struct { + nodeName string + numGPUs int +} -var ConfigSchemeBuilder = runtime.NewSchemeBuilder( - configapi.AddToScheme, -) +func NewProfile(nodeName string, numGPUs int) Profile { + return Profile{ + nodeName: nodeName, + numGPUs: numGPUs, + } +} -func EnumerateAllPossibleDevices(nodeName string, numGPUs int) func() (resourceslice.DriverResources, error) { - return func() (resourceslice.DriverResources, error) { - seed := nodeName - uuids := generateUUIDs(seed, numGPUs) - - var devices []resourceapi.Device - for i, uuid := range uuids { - device := resourceapi.Device{ - Name: fmt.Sprintf("gpu-%d", i), - Attributes: map[resourceapi.QualifiedName]resourceapi.DeviceAttribute{ - "index": { - IntValue: ptr.To(int64(i)), - }, - "uuid": { - StringValue: ptr.To(uuid), - }, - "model": { - StringValue: ptr.To("LATEST-GPU-MODEL"), - }, - "driverVersion": { - VersionValue: ptr.To("1.0.0"), - }, +func (p Profile) EnumerateDevices() (resourceslice.DriverResources, error) { + seed := p.nodeName + uuids := generateUUIDs(seed, p.numGPUs) + + var devices []resourceapi.Device + for i, uuid := range uuids { + device := resourceapi.Device{ + Name: fmt.Sprintf("gpu-%d", i), + Attributes: map[resourceapi.QualifiedName]resourceapi.DeviceAttribute{ + "index": { + IntValue: ptr.To(int64(i)), }, - Capacity: map[resourceapi.QualifiedName]resourceapi.DeviceCapacity{ - "memory": { - Value: resource.MustParse("80Gi"), - }, + "uuid": { + StringValue: ptr.To(uuid), }, - } - devices = append(devices, device) + "model": { + StringValue: ptr.To("LATEST-GPU-MODEL"), + }, + "driverVersion": { + VersionValue: ptr.To("1.0.0"), + }, + }, + Capacity: map[resourceapi.QualifiedName]resourceapi.DeviceCapacity{ + "memory": { + Value: resource.MustParse("80Gi"), + }, + }, } + devices = append(devices, device) + } - resources := resourceslice.DriverResources{ - Pools: map[string]resourceslice.Pool{ - nodeName: { - Slices: []resourceslice.Slice{ - { - Devices: devices, - }, + resources := resourceslice.DriverResources{ + Pools: map[string]resourceslice.Pool{ + p.nodeName: { + Slices: []resourceslice.Slice{ + { + Devices: devices, }, }, }, - } - - return resources, nil + }, } + + return resources, nil } func generateUUIDs(seed string, count int) []string { @@ -111,7 +115,15 @@ func hash(s string) int64 { return h } -func ValidateConfig(config runtime.Object) error { +// SchemeBuilder implements [profiles.ConfigHandler]. +func (p Profile) SchemeBuilder() runtime.SchemeBuilder { + return runtime.NewSchemeBuilder( + configapi.AddToScheme, + ) +} + +// Validate implements [profiles.ConfigHandler]. +func (p Profile) Validate(config runtime.Object) error { gpuConfig, ok := config.(*configapi.GpuConfig) if !ok { return fmt.Errorf("expected v1alpha1.GpuConfig but got: %T", config) @@ -119,13 +131,8 @@ func ValidateConfig(config runtime.Object) error { return gpuConfig.Validate() } -// ApplyConfig applies a configuration to a set of device allocation results. -// -// In this example driver there is no actual configuration applied. We simply -// define a set of environment variables to be injected into the containers -// that include a given device. A real driver would likely need to do some sort -// of hardware configuration as well, based on the config passed in. -func ApplyConfig(config runtime.Object, results []*resourceapi.DeviceRequestAllocationResult) (profiles.PerDeviceCDIContainerEdits, error) { +// ApplyConfig implements [profiles.ConfigHandler]. +func (p Profile) ApplyConfig(config runtime.Object, results []*resourceapi.DeviceRequestAllocationResult) (profiles.PerDeviceCDIContainerEdits, error) { if config == nil { config = configapi.DefaultGpuConfig() } @@ -135,6 +142,10 @@ func ApplyConfig(config runtime.Object, results []*resourceapi.DeviceRequestAllo return nil, fmt.Errorf("runtime object is not a recognized configuration") } +// In this example driver there is no actual configuration applied. We simply +// define a set of environment variables to be injected into the containers +// that include a given device. A real driver would likely need to do some sort +// of hardware configuration as well, based on the config passed in. func applyGpuConfig(config *configapi.GpuConfig, results []*resourceapi.DeviceRequestAllocationResult) (profiles.PerDeviceCDIContainerEdits, error) { perDeviceEdits := make(profiles.PerDeviceCDIContainerEdits) diff --git a/internal/profiles/profiles.go b/internal/profiles/profiles.go index 7ad0326a..25e94004 100644 --- a/internal/profiles/profiles.go +++ b/internal/profiles/profiles.go @@ -17,6 +17,11 @@ package profiles import ( + "errors" + + resourceapi "k8s.io/api/resource/v1" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/dynamic-resource-allocation/resourceslice" drapbv1 "k8s.io/kubelet/pkg/apis/dra/v1beta1" cdiapi "tags.cncf.io/container-device-interface/pkg/cdi" ) @@ -37,3 +42,43 @@ func (pds PreparedDevices) GetDevices() []*drapbv1.Device { } return devices } + +// Profile describes a kind of device that can be managed by the driver. +type Profile interface { + ConfigHandler + EnumerateDevices() (resourceslice.DriverResources, error) +} + +// ConfigHandler handles opaque configuration set for requests in ResourceClaims. +type ConfigHandler interface { + // SchemeBuilder produces a [runtime.Scheme] for the profile's configuration types. + SchemeBuilder() runtime.SchemeBuilder + // Validate returns nil for valid configuration, or an error explaining why the configuration is invalid. + Validate(config runtime.Object) error + // ApplyConfig applies a configuration to a set of device allocation + // results. When `config` is nil, the profile's default configuration should + // be applied. + ApplyConfig(config runtime.Object, results []*resourceapi.DeviceRequestAllocationResult) (PerDeviceCDIContainerEdits, error) +} + +// NoopConfigHandler implements a [ConfigHandler] that does not allow +// configuration. +type NoopConfigHandler struct{} + +// ApplyConfig implements [ConfigHandler]. +func (n NoopConfigHandler) ApplyConfig(config runtime.Object, results []*resourceapi.DeviceRequestAllocationResult) (PerDeviceCDIContainerEdits, error) { + if config != nil { + return nil, errors.New("configuration not allowed") + } + return nil, nil +} + +// SchemeBuilder implements [ConfigHandler]. +func (n NoopConfigHandler) SchemeBuilder() runtime.SchemeBuilder { + return runtime.NewSchemeBuilder() +} + +// Validate implements [ConfigHandler]. +func (n NoopConfigHandler) Validate(config runtime.Object) error { + return errors.New("configuration not allowed") +}