diff --git a/api/example.com/resource/gpu/v1alpha1/api.go b/api/example.com/resource/gpu/v1alpha1/api.go index 203d220f..bc62b507 100644 --- a/api/example.com/resource/gpu/v1alpha1/api.go +++ b/api/example.com/resource/gpu/v1alpha1/api.go @@ -20,20 +20,9 @@ import ( "fmt" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - "k8s.io/apimachinery/pkg/runtime" - "k8s.io/apimachinery/pkg/runtime/schema" - "k8s.io/apimachinery/pkg/runtime/serializer/json" ) -const ( - GroupName = "gpu.resource.example.com" - Version = "v1alpha1" - - GpuConfigKind = "GpuConfig" -) - -// Decoder implements a decoder for objects in this API group. -var Decoder runtime.Decoder +const GpuConfigKind = "GpuConfig" // +genclient // +k8s:deepcopy-gen:interfaces=k8s.io/apimachinery/pkg/runtime.Object @@ -82,29 +71,3 @@ func (c *GpuConfig) Normalize() error { } return nil } - -func init() { - // Create a new scheme and add our types to it. If at some point in the - // future a new version of the configuration API becomes necessary, then - // conversion functions can be generated and registered to continue - // supporting older versions. - scheme := runtime.NewScheme() - schemeGroupVersion := schema.GroupVersion{ - Group: GroupName, - Version: Version, - } - scheme.AddKnownTypes(schemeGroupVersion, - &GpuConfig{}, - ) - metav1.AddToGroupVersion(scheme, schemeGroupVersion) - - // Set up a json serializer to decode our types. - Decoder = json.NewSerializerWithOptions( - json.DefaultMetaFactory, - scheme, - scheme, - json.SerializerOptions{ - Pretty: true, Strict: true, - }, - ) -} diff --git a/api/example.com/resource/gpu/v1alpha1/register.go b/api/example.com/resource/gpu/v1alpha1/register.go new file mode 100644 index 00000000..0c9f0233 --- /dev/null +++ b/api/example.com/resource/gpu/v1alpha1/register.go @@ -0,0 +1,45 @@ +/* + * Copyright The Kubernetes Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package v1alpha1 + +import ( + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/runtime/schema" +) + +const ( + GroupName = "gpu.resource.example.com" + Version = "v1alpha1" +) + +// SchemeGroupVersion is group version used to register these objects. +var SchemeGroupVersion = schema.GroupVersion{Group: GroupName, Version: Version} + +var ( + SchemeBuilder = runtime.NewSchemeBuilder(addKnownTypes) + AddToScheme = SchemeBuilder.AddToScheme +) + +// Adds the list of known types to the given scheme. +func addKnownTypes(scheme *runtime.Scheme) error { + scheme.AddKnownTypes(SchemeGroupVersion, + &GpuConfig{}, + ) + metav1.AddToGroupVersion(scheme, SchemeGroupVersion) + return nil +} diff --git a/cmd/dra-example-kubeletplugin/cdi.go b/cmd/dra-example-kubeletplugin/cdi.go index 7e701178..e92688e0 100644 --- a/cmd/dra-example-kubeletplugin/cdi.go +++ b/cmd/dra-example-kubeletplugin/cdi.go @@ -19,35 +19,37 @@ package main import ( "fmt" "os" - - "sigs.k8s.io/dra-example-driver/pkg/consts" + "regexp" + "strings" cdiapi "tags.cncf.io/container-device-interface/pkg/cdi" cdiparser "tags.cncf.io/container-device-interface/pkg/parser" cdispec "tags.cncf.io/container-device-interface/specs-go" + + "sigs.k8s.io/dra-example-driver/internal/profiles" ) -const ( - cdiVendor = "k8s." + consts.DriverName - cdiClass = "gpu" - cdiKind = cdiVendor + "/" + cdiClass +const cdiCommonDeviceName = "common" - cdiCommonDeviceName = "common" -) +var nonWord = regexp.MustCompile(`[^a-zA-Z0-9]+`) type CDIHandler struct { - cache *cdiapi.Cache + cache *cdiapi.Cache + driverName string + class string } -func NewCDIHandler(config *Config) (*CDIHandler, error) { +func NewCDIHandler(root string, driverName, class string) (*CDIHandler, error) { cache, err := cdiapi.NewCache( - cdiapi.WithSpecDirs(config.flags.cdiRoot), + cdiapi.WithSpecDirs(root), ) if err != nil { return nil, fmt.Errorf("unable to create a new CDI cache: %w", err) } handler := &CDIHandler{ - cache: cache, + cache: cache, + driverName: driverName, + class: class, } return handler, nil @@ -55,14 +57,14 @@ func NewCDIHandler(config *Config) (*CDIHandler, error) { func (cdi *CDIHandler) CreateCommonSpecFile() error { spec := &cdispec.Spec{ - Kind: cdiKind, + Kind: cdi.kind(), Devices: []cdispec.Device{ { Name: cdiCommonDeviceName, ContainerEdits: cdispec.ContainerEdits{ Env: []string{ fmt.Sprintf("KUBERNETES_NODE_NAME=%s", os.Getenv("NODE_NAME")), - fmt.Sprintf("DRA_RESOURCE_DRIVER_NAME=%s", consts.DriverName), + fmt.Sprintf("DRA_RESOURCE_DRIVER_NAME=%s", cdi.driverName), }, }, }, @@ -83,19 +85,20 @@ func (cdi *CDIHandler) CreateCommonSpecFile() error { return cdi.cache.WriteSpec(spec, specName) } -func (cdi *CDIHandler) CreateClaimSpecFile(claimUID string, devices PreparedDevices) error { - specName := cdiapi.GenerateTransientSpecName(cdiVendor, cdiClass, claimUID) +func (cdi *CDIHandler) CreateClaimSpecFile(claimUID string, devices profiles.PreparedDevices) error { + specName := cdiapi.GenerateTransientSpecName(cdi.vendor(), cdi.class, claimUID) spec := &cdispec.Spec{ - Kind: cdiKind, + Kind: cdi.kind(), Devices: []cdispec.Device{}, } for _, device := range devices { + deviceEnvKey := strings.ToUpper(nonWord.ReplaceAllString(device.DeviceName, "_")) claimEdits := cdiapi.ContainerEdits{ ContainerEdits: &cdispec.ContainerEdits{ Env: []string{ - fmt.Sprintf("GPU_DEVICE_%s_RESOURCE_CLAIM=%s", device.DeviceName[4:], claimUID), + fmt.Sprintf("%s_DEVICE_%s_RESOURCE_CLAIM=%s", strings.ToUpper(cdi.class), deviceEnvKey, claimUID), }, }, } @@ -119,19 +122,27 @@ func (cdi *CDIHandler) CreateClaimSpecFile(claimUID string, devices PreparedDevi } func (cdi *CDIHandler) DeleteClaimSpecFile(claimUID string) error { - specName := cdiapi.GenerateTransientSpecName(cdiVendor, cdiClass, claimUID) + specName := cdiapi.GenerateTransientSpecName(cdi.vendor(), cdi.class, claimUID) return cdi.cache.RemoveSpec(specName) } func (cdi *CDIHandler) GetClaimDevices(claimUID string, devices []string) []string { cdiDevices := []string{ - cdiparser.QualifiedName(cdiVendor, cdiClass, cdiCommonDeviceName), + cdiparser.QualifiedName(cdi.vendor(), cdi.class, cdiCommonDeviceName), } for _, device := range devices { - cdiDevice := cdiparser.QualifiedName(cdiVendor, cdiClass, fmt.Sprintf("%s-%s", claimUID, device)) + cdiDevice := cdiparser.QualifiedName(cdi.vendor(), cdi.class, fmt.Sprintf("%s-%s", claimUID, device)) cdiDevices = append(cdiDevices, cdiDevice) } return cdiDevices } + +func (cdi *CDIHandler) kind() string { + return cdi.vendor() + "/" + cdi.class +} + +func (cdi *CDIHandler) vendor() string { + return "k8s." + cdi.driverName +} diff --git a/cmd/dra-example-kubeletplugin/discovery.go b/cmd/dra-example-kubeletplugin/discovery.go deleted file mode 100644 index 0c45431f..00000000 --- a/cmd/dra-example-kubeletplugin/discovery.go +++ /dev/null @@ -1,84 +0,0 @@ -/* - * Copyright 2023 The Kubernetes Authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package main - -import ( - "fmt" - "math/rand" - "os" - - resourceapi "k8s.io/api/resource/v1" - "k8s.io/apimachinery/pkg/api/resource" - "k8s.io/utils/ptr" - - "github.com/google/uuid" -) - -func enumerateAllPossibleDevices(numGPUs int) (AllocatableDevices, error) { - seed := os.Getenv("NODE_NAME") - uuids := generateUUIDs(seed, numGPUs) - - alldevices := make(AllocatableDevices) - for i, uuid := range uuids { - device := resourceapi.Device{ - Name: fmt.Sprintf("gpu-%d", i), - Attributes: map[resourceapi.QualifiedName]resourceapi.DeviceAttribute{ - "index": { - IntValue: ptr.To(int64(i)), - }, - "uuid": { - StringValue: ptr.To(uuid), - }, - "model": { - StringValue: ptr.To("LATEST-GPU-MODEL"), - }, - "driverVersion": { - VersionValue: ptr.To("1.0.0"), - }, - }, - Capacity: map[resourceapi.QualifiedName]resourceapi.DeviceCapacity{ - "memory": { - Value: resource.MustParse("80Gi"), - }, - }, - } - alldevices[device.Name] = device - } - return alldevices, nil -} - -func generateUUIDs(seed string, count int) []string { - rand := rand.New(rand.NewSource(hash(seed))) - - uuids := make([]string, count) - for i := 0; i < count; i++ { - charset := make([]byte, 16) - rand.Read(charset) - uuid, _ := uuid.FromBytes(charset) - uuids[i] = "gpu-" + uuid.String() - } - - return uuids -} - -func hash(s string) int64 { - h := int64(0) - for _, c := range s { - h = 31*h + int64(c) - } - return h -} diff --git a/cmd/dra-example-kubeletplugin/driver.go b/cmd/dra-example-kubeletplugin/driver.go index 0f6e224c..d0ebf77f 100644 --- a/cmd/dra-example-kubeletplugin/driver.go +++ b/cmd/dra-example-kubeletplugin/driver.go @@ -20,17 +20,13 @@ import ( "context" "errors" "fmt" - "maps" resourceapi "k8s.io/api/resource/v1" "k8s.io/apimachinery/pkg/types" utilruntime "k8s.io/apimachinery/pkg/util/runtime" coreclientset "k8s.io/client-go/kubernetes" "k8s.io/dynamic-resource-allocation/kubeletplugin" - "k8s.io/dynamic-resource-allocation/resourceslice" "k8s.io/klog/v2" - - "sigs.k8s.io/dra-example-driver/pkg/consts" ) type driver struct { @@ -58,7 +54,7 @@ func NewDriver(ctx context.Context, config *Config) (*driver, error) { driver, kubeletplugin.KubeClient(config.coreclient), kubeletplugin.NodeName(config.flags.nodeName), - kubeletplugin.DriverName(consts.DriverName), + kubeletplugin.DriverName(config.flags.driverName), kubeletplugin.RegistrarDirectoryPath(config.flags.kubeletRegistrarDirectoryPath), kubeletplugin.PluginDataDirectoryPath(config.DriverPluginPath()), ) @@ -67,28 +63,12 @@ func NewDriver(ctx context.Context, config *Config) (*driver, error) { } driver.helper = helper - devices := make([]resourceapi.Device, 0, len(state.allocatable)) - for device := range maps.Values(state.allocatable) { - devices = append(devices, device) - } - resources := resourceslice.DriverResources{ - Pools: map[string]resourceslice.Pool{ - config.flags.nodeName: { - Slices: []resourceslice.Slice{ - { - Devices: devices, - }, - }, - }, - }, - } - driver.healthcheck, err = startHealthcheck(ctx, config) if err != nil { return nil, fmt.Errorf("start healthcheck: %w", err) } - if err := helper.PublishResources(ctx, resources); err != nil { + if err := helper.PublishResources(ctx, state.driverResources); err != nil { return nil, err } diff --git a/cmd/dra-example-kubeletplugin/health.go b/cmd/dra-example-kubeletplugin/health.go index 14d393aa..b5ddf7da 100644 --- a/cmd/dra-example-kubeletplugin/health.go +++ b/cmd/dra-example-kubeletplugin/health.go @@ -33,8 +33,6 @@ import ( "k8s.io/klog/v2" drapb "k8s.io/kubelet/pkg/apis/dra/v1" registerapi "k8s.io/kubelet/pkg/apis/pluginregistration/v1" - - "sigs.k8s.io/dra-example-driver/pkg/consts" ) type healthcheck struct { @@ -65,7 +63,7 @@ func startHealthcheck(ctx context.Context, config *Config) (*healthcheck, error) Scheme: "unix", // TODO: this needs to adapt when seamless upgrades // are enabled and the filename includes a uid. - Path: path.Join(config.flags.kubeletRegistrarDirectoryPath, consts.DriverName+"-reg.sock"), + Path: path.Join(config.flags.kubeletRegistrarDirectoryPath, config.flags.driverName+"-reg.sock"), }).String() log.Info("connecting to registration socket", "path", regSockPath) regConn, err := grpc.NewClient( diff --git a/cmd/dra-example-kubeletplugin/main.go b/cmd/dra-example-kubeletplugin/main.go index e0cfaf9c..757e7fa9 100644 --- a/cmd/dra-example-kubeletplugin/main.go +++ b/cmd/dra-example-kubeletplugin/main.go @@ -31,7 +31,8 @@ import ( "k8s.io/dynamic-resource-allocation/kubeletplugin" "k8s.io/klog/v2" - "sigs.k8s.io/dra-example-driver/pkg/consts" + "sigs.k8s.io/dra-example-driver/internal/profiles" + "sigs.k8s.io/dra-example-driver/internal/profiles/gpu" "sigs.k8s.io/dra-example-driver/pkg/flags" ) @@ -49,16 +50,34 @@ type Flags struct { kubeletRegistrarDirectoryPath string kubeletPluginsDirectoryPath string healthcheckPort int + profile string + driverName string } type Config struct { flags *Flags coreclient coreclientset.Interface cancelMainCtx func(error) + + profile profiles.Profile +} + +var validProfiles = map[string]func(flags Flags) profiles.Profile{ + gpu.ProfileName: func(flags Flags) profiles.Profile { + return gpu.NewProfile(flags.nodeName, flags.numDevices) + }, } +var validProfileNames = func() []string { + var valid []string + for profileName := range validProfiles { + valid = append(valid, profileName) + } + return valid +}() + func (c Config) DriverPluginPath() string { - return filepath.Join(c.flags.kubeletPluginsDirectoryPath, consts.DriverName) + return filepath.Join(c.flags.kubeletPluginsDirectoryPath, c.flags.driverName) } func main() { @@ -89,7 +108,7 @@ func newApp() *cli.App { }, &cli.IntFlag{ Name: "num-devices", - Usage: "The number of devices to be generated.", + Usage: "The number of devices to be generated. Only relevant for the " + gpu.ProfileName + " profile.", Value: 8, Destination: &flags.numDevices, EnvVars: []string{"NUM_DEVICES"}, @@ -115,6 +134,19 @@ func newApp() *cli.App { Destination: &flags.healthcheckPort, EnvVars: []string{"HEALTHCHECK_PORT"}, }, + &cli.StringFlag{ + Name: "device-profile", + Usage: fmt.Sprintf("Name of the device profile. Valid values are %q.", validProfileNames), + Value: gpu.ProfileName, + Destination: &flags.profile, + EnvVars: []string{"DEVICE_PROFILE"}, + }, + &cli.StringFlag{ + Name: "driver-name", + Usage: "Name of the DRA driver. Its default is derived from the device profile.", + Destination: &flags.driverName, + EnvVars: []string{"DRIVER_NAME"}, + }, } cliFlags = append(cliFlags, flags.kubeClientConfig.Flags()...) cliFlags = append(cliFlags, flags.loggingConfig.Flags()...) @@ -135,12 +167,22 @@ func newApp() *cli.App { ctx := c.Context clientSets, err := flags.kubeClientConfig.NewClientSets() if err != nil { - return fmt.Errorf("create client: %v", err) + return fmt.Errorf("create client: %w", err) + } + + if flags.driverName == "" { + flags.driverName = flags.profile + ".example.com" + } + + newProfile, ok := validProfiles[flags.profile] + if !ok { + return fmt.Errorf("invalid device profile %q, valid profiles are %q", flags.profile, validProfileNames) } config := &Config{ flags: flags, coreclient: clientSets.Core, + profile: newProfile(*flags), } return RunPlugin(ctx, config) diff --git a/cmd/dra-example-kubeletplugin/state.go b/cmd/dra-example-kubeletplugin/state.go index 7cffe388..26ea7cf8 100644 --- a/cmd/dra-example-kubeletplugin/state.go +++ b/cmd/dra-example-kubeletplugin/state.go @@ -23,53 +23,40 @@ import ( resourceapi "k8s.io/api/resource/v1" "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/runtime/serializer/json" + "k8s.io/dynamic-resource-allocation/resourceslice" drapbv1 "k8s.io/kubelet/pkg/apis/dra/v1beta1" "k8s.io/kubernetes/pkg/kubelet/checkpointmanager" - configapi "sigs.k8s.io/dra-example-driver/api/example.com/resource/gpu/v1alpha1" - "sigs.k8s.io/dra-example-driver/pkg/consts" - - cdiapi "tags.cncf.io/container-device-interface/pkg/cdi" - cdispec "tags.cncf.io/container-device-interface/specs-go" + "sigs.k8s.io/dra-example-driver/internal/profiles" ) type AllocatableDevices map[string]resourceapi.Device -type PreparedDevices []*PreparedDevice -type PreparedClaims map[string]PreparedDevices -type PerDeviceCDIContainerEdits map[string]*cdiapi.ContainerEdits +type PreparedClaims map[string]profiles.PreparedDevices type OpaqueDeviceConfig struct { Requests []string Config runtime.Object } -type PreparedDevice struct { - drapbv1.Device - ContainerEdits *cdiapi.ContainerEdits -} - -func (pds PreparedDevices) GetDevices() []*drapbv1.Device { - var devices []*drapbv1.Device - for _, pd := range pds { - devices = append(devices, &pd.Device) - } - return devices -} - type DeviceState struct { sync.Mutex + driverName string cdi *CDIHandler + driverResources resourceslice.DriverResources allocatable AllocatableDevices checkpointManager checkpointmanager.CheckpointManager + configDecoder runtime.Decoder + configHandler profiles.ConfigHandler } func NewDeviceState(config *Config) (*DeviceState, error) { - allocatable, err := enumerateAllPossibleDevices(config.flags.numDevices) + driverResources, err := config.profile.EnumerateDevices() if err != nil { return nil, fmt.Errorf("error enumerating all possible devices: %v", err) } - cdi, err := NewCDIHandler(config) + cdi, err := NewCDIHandler(config.flags.cdiRoot, config.flags.driverName, config.flags.profile) if err != nil { return nil, fmt.Errorf("unable to create CDI handler: %v", err) } @@ -84,10 +71,38 @@ func NewDeviceState(config *Config) (*DeviceState, error) { return nil, fmt.Errorf("unable to create checkpoint manager: %v", err) } + configScheme := runtime.NewScheme() + configHandler := config.profile + sb := configHandler.SchemeBuilder() + if err := sb.AddToScheme(configScheme); err != nil { + return nil, fmt.Errorf("create config scheme: %w", err) + } + + // Set up a json serializer to decode our types. + decoder := json.NewSerializerWithOptions( + json.DefaultMetaFactory, + configScheme, + configScheme, + json.SerializerOptions{ + Pretty: true, Strict: true, + }, + ) + + allocatable := make(AllocatableDevices) + for _, slice := range driverResources.Pools[config.flags.nodeName].Slices { + for _, device := range slice.Devices { + allocatable[device.Name] = device + } + } + state := &DeviceState{ + driverName: config.flags.driverName, cdi: cdi, + driverResources: driverResources, allocatable: allocatable, checkpointManager: checkpointManager, + configDecoder: decoder, + configHandler: configHandler, } checkpoints, err := state.checkpointManager.ListCheckpoints() @@ -173,15 +188,15 @@ func (s *DeviceState) Unprepare(claimUID string) error { return nil } -func (s *DeviceState) prepareDevices(claim *resourceapi.ResourceClaim) (PreparedDevices, error) { +func (s *DeviceState) prepareDevices(claim *resourceapi.ResourceClaim) (profiles.PreparedDevices, error) { if claim.Status.Allocation == nil { return nil, fmt.Errorf("claim not yet allocated") } // Retrieve the full set of device configs for the driver. configs, err := GetOpaqueDeviceConfigs( - configapi.Decoder, - consts.DriverName, + s.configDecoder, + s.driverName, claim.Status.Allocation.Devices.Config, ) if err != nil { @@ -191,17 +206,18 @@ func (s *DeviceState) prepareDevices(claim *resourceapi.ResourceClaim) (Prepared // Add the default GPU Config to the front of the config list with the // lowest precedence. This guarantees there will be at least one config in // the list with len(Requests) == 0 for the lookup below. - configs = slices.Insert(configs, 0, &OpaqueDeviceConfig{ - Requests: []string{}, - Config: configapi.DefaultGpuConfig(), - }) + configs = slices.Insert(configs, 0, &OpaqueDeviceConfig{}) // Look through the configs and figure out which one will be applied to // each device allocation result based on their order of precedence. configResultsMap := make(map[runtime.Object][]*resourceapi.DeviceRequestAllocationResult) for _, result := range claim.Status.Allocation.Devices.Results { + // The claim may include allocations meant for other drivers. + if result.Driver != s.driverName { + continue + } if _, exists := s.allocatable[result.Device]; !exists { - return nil, fmt.Errorf("requested GPU is not allocatable: %v", result.Device) + return nil, fmt.Errorf("requested device is not allocatable: %v", result.Device) } for _, c := range slices.Backward(configs) { if len(c.Requests) == 0 || slices.Contains(c.Requests, result.Request) { @@ -211,34 +227,15 @@ func (s *DeviceState) prepareDevices(claim *resourceapi.ResourceClaim) (Prepared } } - // Normalize, validate, and apply all configs associated with devices that - // need to be prepared. Track container edits generated from applying the - // config to the set of device allocation results. - perDeviceCDIContainerEdits := make(PerDeviceCDIContainerEdits) - for c, results := range configResultsMap { - // Cast the opaque config to a GpuConfig - var config *configapi.GpuConfig - switch castConfig := c.(type) { - case *configapi.GpuConfig: - config = castConfig - default: - return nil, fmt.Errorf("runtime object is not a regognized configuration") - } - - // Normalize the config to set any implied defaults. - if err := config.Normalize(); err != nil { - return nil, fmt.Errorf("error normalizing GPU config: %w", err) - } - - // Validate the config to ensure its integrity. - if err := config.Validate(); err != nil { - return nil, fmt.Errorf("error validating GPU config: %w", err) - } - + // Apply all configs associated with devices that need to be prepared. + // Track container edits generated from applying the config to the set + // of device allocation results. + perDeviceCDIContainerEdits := make(profiles.PerDeviceCDIContainerEdits) + for config, results := range configResultsMap { // Apply the config to the list of results associated with it. - containerEdits, err := s.applyConfig(config, results) + containerEdits, err := s.configHandler.ApplyConfig(config, results) if err != nil { - return nil, fmt.Errorf("error applying GPU config: %w", err) + return nil, fmt.Errorf("error applying config: %w", err) } // Merge any new container edits with the overall per device map. @@ -249,10 +246,10 @@ func (s *DeviceState) prepareDevices(claim *resourceapi.ResourceClaim) (Prepared // Walk through each config and its associated device allocation results // and construct the list of prepared devices to return. - var preparedDevices PreparedDevices + var preparedDevices profiles.PreparedDevices for _, results := range configResultsMap { for _, result := range results { - device := &PreparedDevice{ + device := &profiles.PreparedDevice{ Device: drapbv1.Device{ RequestNames: []string{result.Request}, PoolName: result.Pool, @@ -268,53 +265,10 @@ func (s *DeviceState) prepareDevices(claim *resourceapi.ResourceClaim) (Prepared return preparedDevices, nil } -func (s *DeviceState) unprepareDevices(claimUID string, devices PreparedDevices) error { +func (s *DeviceState) unprepareDevices(claimUID string, devices profiles.PreparedDevices) error { return nil } -// applyConfig applies a configuration to a set of device allocation results. -// -// In this example driver there is no actual configuration applied. We simply -// define a set of environment variables to be injected into the containers -// that include a given device. A real driver would likely need to do some sort -// of hardware configuration as well, based on the config passed in. -func (s *DeviceState) applyConfig(config *configapi.GpuConfig, results []*resourceapi.DeviceRequestAllocationResult) (PerDeviceCDIContainerEdits, error) { - perDeviceEdits := make(PerDeviceCDIContainerEdits) - - for _, result := range results { - envs := []string{ - fmt.Sprintf("GPU_DEVICE_%s=%s", result.Device[4:], result.Device), - } - - if config.Sharing != nil { - envs = append(envs, fmt.Sprintf("GPU_DEVICE_%s_SHARING_STRATEGY=%s", result.Device[4:], config.Sharing.Strategy)) - } - - switch { - case config.Sharing.IsTimeSlicing(): - tsconfig, err := config.Sharing.GetTimeSlicingConfig() - if err != nil { - return nil, fmt.Errorf("unable to get time slicing config for device %v: %w", result.Device, err) - } - envs = append(envs, fmt.Sprintf("GPU_DEVICE_%s_TIMESLICE_INTERVAL=%v", result.Device[4:], tsconfig.Interval)) - case config.Sharing.IsSpacePartitioning(): - spconfig, err := config.Sharing.GetSpacePartitioningConfig() - if err != nil { - return nil, fmt.Errorf("unable to get space partitioning config for device %v: %w", result.Device, err) - } - envs = append(envs, fmt.Sprintf("GPU_DEVICE_%s_PARTITION_COUNT=%v", result.Device[4:], spconfig.PartitionCount)) - } - - edits := &cdispec.ContainerEdits{ - Env: envs, - } - - perDeviceEdits[result.Device] = &cdiapi.ContainerEdits{ContainerEdits: edits} - } - - return perDeviceEdits, nil -} - // GetOpaqueDeviceConfigs returns an ordered list of the configs contained in possibleConfigs for this driver. // // Configs can either come from the resource claim itself or from the device diff --git a/cmd/dra-example-kubeletplugin/state_test.go b/cmd/dra-example-kubeletplugin/state_test.go index 8490ebb5..d23b3859 100644 --- a/cmd/dra-example-kubeletplugin/state_test.go +++ b/cmd/dra-example-kubeletplugin/state_test.go @@ -21,12 +21,14 @@ import ( "github.com/stretchr/testify/assert" + "sigs.k8s.io/dra-example-driver/internal/profiles" + drapbv1 "k8s.io/kubelet/pkg/apis/dra/v1beta1" ) func TestPreparedDevicesGetDevices(t *testing.T) { tests := map[string]struct { - preparedDevices PreparedDevices + preparedDevices profiles.PreparedDevices expected []*drapbv1.Device }{ "nil PreparedDevices": { @@ -34,7 +36,7 @@ func TestPreparedDevicesGetDevices(t *testing.T) { expected: nil, }, "several PreparedDevices": { - preparedDevices: PreparedDevices{ + preparedDevices: profiles.PreparedDevices{ {Device: drapbv1.Device{DeviceName: "dev1"}}, {Device: drapbv1.Device{DeviceName: "dev2"}}, {Device: drapbv1.Device{DeviceName: "dev3"}}, diff --git a/cmd/dra-example-webhook/main.go b/cmd/dra-example-webhook/main.go index b7ce2172..9e1383c1 100644 --- a/cmd/dra-example-webhook/main.go +++ b/cmd/dra-example-webhook/main.go @@ -30,19 +30,28 @@ import ( resourceapi "k8s.io/api/resource/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime" + kjson "k8s.io/apimachinery/pkg/runtime/serializer/json" "k8s.io/klog/v2" - configapi "sigs.k8s.io/dra-example-driver/api/example.com/resource/gpu/v1alpha1" - "sigs.k8s.io/dra-example-driver/pkg/consts" + "sigs.k8s.io/dra-example-driver/internal/profiles" + "sigs.k8s.io/dra-example-driver/internal/profiles/gpu" "sigs.k8s.io/dra-example-driver/pkg/flags" ) type Flags struct { loggingConfig *flags.LoggingConfig - certFile string - keyFile string - port int + certFile string + keyFile string + port int + profile string + driverName string +} + +type validator func(runtime.Object) error + +var validProfiles = map[string]profiles.ConfigHandler{ + gpu.ProfileName: gpu.Profile{}, } func main() { @@ -75,6 +84,19 @@ func newApp() *cli.App { Value: 443, Destination: &flags.port, }, + &cli.StringFlag{ + Name: "device-profile", + Usage: fmt.Sprintf("Name of the device profile. Valid values are %q.", validProfiles), + Value: gpu.ProfileName, + Destination: &flags.profile, + EnvVars: []string{"DEVICE_PROFILE"}, + }, + &cli.StringFlag{ + Name: "driver-name", + Usage: "Name of the DRA driver. Its default is derived from the device profile.", + Destination: &flags.driverName, + EnvVars: []string{"DRIVER_NAME"}, + }, } cliFlags = append(cliFlags, flags.loggingConfig.Flags()...) @@ -91,8 +113,26 @@ func newApp() *cli.App { return flags.loggingConfig.Apply() }, Action: func(c *cli.Context) error { + configHandler, ok := validProfiles[flags.profile] + if !ok { + var valid []string + for profileName := range validProfiles { + valid = append(valid, profileName) + } + return fmt.Errorf("invalid device profile %q, valid profiles are %q", flags.profile, valid) + } + + if flags.driverName == "" { + flags.driverName = flags.profile + ".example.com" + } + + mux, err := newMux(configHandler, flags.driverName) + if err != nil { + return fmt.Errorf("create HTTP mux: %w", err) + } + server := &http.Server{ - Handler: newMux(), + Handler: mux, Addr: fmt.Sprintf(":%d", flags.port), } klog.Info("starting webhook server on", server.Addr) @@ -103,21 +143,39 @@ func newApp() *cli.App { return app } -func newMux() *http.ServeMux { +func newMux(configHandler profiles.ConfigHandler, driverName string) (*http.ServeMux, error) { + configScheme := runtime.NewScheme() + sb := configHandler.SchemeBuilder() + if err := sb.AddToScheme(configScheme); err != nil { + return nil, fmt.Errorf("create config scheme: %w", err) + } + configDecoder := kjson.NewSerializerWithOptions( + kjson.DefaultMetaFactory, + configScheme, + configScheme, + kjson.SerializerOptions{ + Pretty: true, Strict: true, + }, + ) + mux := http.NewServeMux() - mux.HandleFunc("/validate-resource-claim-parameters", serveResourceClaim) - mux.HandleFunc("/readyz", func(w http.ResponseWriter, req *http.Request) { - _, err := w.Write([]byte("ok")) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - }) - return mux + mux.HandleFunc("/validate-resource-claim-parameters", serveResourceClaim(configDecoder, configHandler.Validate, driverName)) + mux.HandleFunc("/readyz", readyHandler) + return mux, nil } -func serveResourceClaim(w http.ResponseWriter, r *http.Request) { - serve(w, r, admitResourceClaimParameters) +func readyHandler(w http.ResponseWriter, req *http.Request) { + _, err := w.Write([]byte("ok")) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } +} + +func serveResourceClaim(configDecoder runtime.Decoder, validate validator, driverName string) func(http.ResponseWriter, *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + serve(w, r, admitResourceClaimParameters(configDecoder, validate, driverName)) + } } // serve handles the http portion of a request prior to handing to an admit @@ -191,96 +249,93 @@ func readAdmissionReview(data []byte) (*admissionv1.AdmissionReview, error) { // admitResourceClaimParameters accepts both ResourceClaims and ResourceClaimTemplates and validates their // opaque device configuration parameters for this driver. -func admitResourceClaimParameters(ar admissionv1.AdmissionReview) *admissionv1.AdmissionResponse { - klog.V(2).Info("admitting resource claim parameters") - - var deviceConfigs []resourceapi.DeviceClaimConfiguration - var specPath string - - switch ar.Request.Resource { - case resourceClaimResourceV1, resourceClaimResourceV1Beta1, resourceClaimResourceV1Beta2: - claim, err := extractResourceClaim(ar) - if err != nil { - klog.Error(err) - return &admissionv1.AdmissionResponse{ - Result: &metav1.Status{ - Message: err.Error(), - Reason: metav1.StatusReasonBadRequest, - }, +func admitResourceClaimParameters(configDecoder runtime.Decoder, validate validator, driverName string) func(ar admissionv1.AdmissionReview) *admissionv1.AdmissionResponse { + return func(ar admissionv1.AdmissionReview) *admissionv1.AdmissionResponse { + klog.V(2).Info("admitting resource claim parameters") + + var deviceConfigs []resourceapi.DeviceClaimConfiguration + var specPath string + + switch ar.Request.Resource { + case resourceClaimResourceV1, resourceClaimResourceV1Beta1, resourceClaimResourceV1Beta2: + claim, err := extractResourceClaim(ar) + if err != nil { + klog.Error(err) + return &admissionv1.AdmissionResponse{ + Result: &metav1.Status{ + Message: err.Error(), + Reason: metav1.StatusReasonBadRequest, + }, + } } - } - deviceConfigs = claim.Spec.Devices.Config - specPath = "spec" - case resourceClaimTemplateResourceV1, resourceClaimTemplateResourceV1Beta1, resourceClaimTemplateResourceV1Beta2: - claimTemplate, err := extractResourceClaimTemplate(ar) - if err != nil { - klog.Error(err) + deviceConfigs = claim.Spec.Devices.Config + specPath = "spec" + case resourceClaimTemplateResourceV1, resourceClaimTemplateResourceV1Beta1, resourceClaimTemplateResourceV1Beta2: + claimTemplate, err := extractResourceClaimTemplate(ar) + if err != nil { + klog.Error(err) + return &admissionv1.AdmissionResponse{ + Result: &metav1.Status{ + Message: err.Error(), + Reason: metav1.StatusReasonBadRequest, + }, + } + } + deviceConfigs = claimTemplate.Spec.Spec.Devices.Config + specPath = "spec.spec" + default: + msg := fmt.Sprintf( + "expected resource to be one of %v, got %s", + []metav1.GroupVersionResource{ + resourceClaimResourceV1, resourceClaimResourceV1Beta1, resourceClaimResourceV1Beta2, + resourceClaimTemplateResourceV1, resourceClaimTemplateResourceV1Beta1, resourceClaimTemplateResourceV1Beta2, + }, + ar.Request.Resource, + ) + klog.Error(msg) return &admissionv1.AdmissionResponse{ Result: &metav1.Status{ - Message: err.Error(), + Message: msg, Reason: metav1.StatusReasonBadRequest, }, } } - deviceConfigs = claimTemplate.Spec.Spec.Devices.Config - specPath = "spec.spec" - default: - msg := fmt.Sprintf( - "expected resource to be one of %v, got %s", - []metav1.GroupVersionResource{ - resourceClaimResourceV1, resourceClaimResourceV1Beta1, resourceClaimResourceV1Beta2, - resourceClaimTemplateResourceV1, resourceClaimTemplateResourceV1Beta1, resourceClaimTemplateResourceV1Beta2, - }, - ar.Request.Resource, - ) - klog.Error(msg) - return &admissionv1.AdmissionResponse{ - Result: &metav1.Status{ - Message: msg, - Reason: metav1.StatusReasonBadRequest, - }, - } - } - var errs []error - for configIndex, config := range deviceConfigs { - if config.Opaque == nil || config.Opaque.Driver != consts.DriverName { - continue - } + var errs []error + for configIndex, config := range deviceConfigs { + if config.Opaque == nil || config.Opaque.Driver != driverName { + continue + } - fieldPath := fmt.Sprintf("%s.devices.config[%d].opaque.parameters", specPath, configIndex) - decodedConfig, err := runtime.Decode(configapi.Decoder, config.DeviceConfiguration.Opaque.Parameters.Raw) - if err != nil { - errs = append(errs, fmt.Errorf("error decoding object at %s: %w", fieldPath, err)) - continue - } - gpuConfig, ok := decodedConfig.(*configapi.GpuConfig) - if !ok { - errs = append(errs, fmt.Errorf("expected v1alpha1.GpuConfig at %s but got: %T", fieldPath, decodedConfig)) - continue - } - err = gpuConfig.Validate() - if err != nil { - errs = append(errs, fmt.Errorf("object at %s is invalid: %w", fieldPath, err)) + fieldPath := fmt.Sprintf("%s.devices.config[%d].opaque.parameters", specPath, configIndex) + decodedConfig, err := runtime.Decode(configDecoder, config.DeviceConfiguration.Opaque.Parameters.Raw) + if err != nil { + errs = append(errs, fmt.Errorf("error decoding object at %s: %w", fieldPath, err)) + continue + } + err = validate(decodedConfig) + if err != nil { + errs = append(errs, fmt.Errorf("object at %s is invalid: %w", fieldPath, err)) + } } - } - if len(errs) > 0 { - var errMsgs []string - for _, err := range errs { - errMsgs = append(errMsgs, err.Error()) + if len(errs) > 0 { + var errMsgs []string + for _, err := range errs { + errMsgs = append(errMsgs, err.Error()) + } + msg := fmt.Sprintf("%d configs failed to validate: %s", len(errs), strings.Join(errMsgs, "; ")) + klog.Error(msg) + return &admissionv1.AdmissionResponse{ + Result: &metav1.Status{ + Message: msg, + Reason: metav1.StatusReason(metav1.StatusReasonInvalid), + }, + } } - msg := fmt.Sprintf("%d configs failed to validate: %s", len(errs), strings.Join(errMsgs, "; ")) - klog.Error(msg) + return &admissionv1.AdmissionResponse{ - Result: &metav1.Status{ - Message: msg, - Reason: metav1.StatusReason(metav1.StatusReasonInvalid), - }, + Allowed: true, } } - - return &admissionv1.AdmissionResponse{ - Allowed: true, - } } diff --git a/cmd/dra-example-webhook/main_test.go b/cmd/dra-example-webhook/main_test.go index 249c128a..e8408fae 100644 --- a/cmd/dra-example-webhook/main_test.go +++ b/cmd/dra-example-webhook/main_test.go @@ -36,14 +36,16 @@ import ( "k8s.io/apimachinery/pkg/runtime/schema" configapi "sigs.k8s.io/dra-example-driver/api/example.com/resource/gpu/v1alpha1" - "sigs.k8s.io/dra-example-driver/pkg/consts" + "sigs.k8s.io/dra-example-driver/internal/profiles/gpu" ) +const driverName = "gpu.example.com" + func TestReadyEndpoint(t *testing.T) { - s := httptest.NewServer(newMux()) + s := httptest.NewServer(http.HandlerFunc(readyHandler)) t.Cleanup(s.Close) - res, err := http.Get(s.URL + "/readyz") + res, err := http.Get(s.URL) assert.NoError(t, err) assert.Equal(t, http.StatusOK, res.StatusCode) } @@ -168,7 +170,11 @@ func TestResourceClaimValidatingWebhook(t *testing.T) { }, } - s := httptest.NewServer(newMux()) + configHandler := gpu.Profile{} + mux, err := newMux(configHandler, driverName) + assert.NoError(t, err) + + s := httptest.NewServer(mux) t.Cleanup(s.Close) for name, test := range tests { @@ -249,7 +255,7 @@ func resourceClaimSpecWithGpuConfigs(gpuConfigs ...*configapi.GpuConfig) resourc deviceConfig := resourceapi.DeviceClaimConfiguration{ DeviceConfiguration: resourceapi.DeviceConfiguration{ Opaque: &resourceapi.OpaqueDeviceConfiguration{ - Driver: consts.DriverName, + Driver: driverName, Parameters: runtime.RawExtension{ Object: gpuConfig, }, diff --git a/deployments/helm/dra-example-driver/templates/_helpers.tpl b/deployments/helm/dra-example-driver/templates/_helpers.tpl index 63ec53df..6198ba65 100644 --- a/deployments/helm/dra-example-driver/templates/_helpers.tpl +++ b/deployments/helm/dra-example-driver/templates/_helpers.tpl @@ -121,3 +121,10 @@ resource.k8s.io/v1beta1 {{- else -}} {{- end -}} {{- end -}} + +{{/* +The driver name. +*/}} +{{- define "dra-example-driver.driverName" -}} +{{ default (print .Values.deviceProfile ".example.com") .Values.driverName }} +{{- end -}} diff --git a/deployments/helm/dra-example-driver/templates/deviceclass.yaml b/deployments/helm/dra-example-driver/templates/deviceclass.yaml index 8fd7085b..e2bb69cf 100644 --- a/deployments/helm/dra-example-driver/templates/deviceclass.yaml +++ b/deployments/helm/dra-example-driver/templates/deviceclass.yaml @@ -2,8 +2,8 @@ apiVersion: {{ include "dra-example-driver.resourceApiVersion" . }} kind: DeviceClass metadata: - name: gpu.example.com + name: {{ include "dra-example-driver.driverName" . }} spec: selectors: - - cel: - expression: "device.driver == 'gpu.example.com'" + - cel: + expression: "device.driver == '{{ include "dra-example-driver.driverName" . }}'" diff --git a/deployments/helm/dra-example-driver/templates/kubeletplugin.yaml b/deployments/helm/dra-example-driver/templates/kubeletplugin.yaml index bcf44ffd..5785a135 100644 --- a/deployments/helm/dra-example-driver/templates/kubeletplugin.yaml +++ b/deployments/helm/dra-example-driver/templates/kubeletplugin.yaml @@ -59,6 +59,10 @@ spec: periodSeconds: 10 {{- end }} env: + - name: DRIVER_NAME + value: {{ include "dra-example-driver.driverName" . | quote }} + - name: DEVICE_PROFILE + value: {{ .Values.deviceProfile | quote }} - name: CDI_ROOT value: /var/run/cdi - name: KUBELET_REGISTRAR_DIRECTORY_PATH diff --git a/deployments/helm/dra-example-driver/templates/webhook-deployment.yaml b/deployments/helm/dra-example-driver/templates/webhook-deployment.yaml index 1f574da5..920e879f 100644 --- a/deployments/helm/dra-example-driver/templates/webhook-deployment.yaml +++ b/deployments/helm/dra-example-driver/templates/webhook-deployment.yaml @@ -43,6 +43,8 @@ spec: - --tls-cert-file=/cert/tls.crt - --tls-private-key-file=/cert/tls.key - --port={{ .Values.webhook.containerPort }} + - --device-profile={{ .Values.deviceProfile }} + - --driver-name={{ include "dra-example-driver.driverName" . }} ports: - name: webhook containerPort: {{ .Values.webhook.containerPort }} diff --git a/deployments/helm/dra-example-driver/values.schema.json b/deployments/helm/dra-example-driver/values.schema.json new file mode 100644 index 00000000..3dff0a3a --- /dev/null +++ b/deployments/helm/dra-example-driver/values.schema.json @@ -0,0 +1,11 @@ +{ + "type": "object", + "properties": { + "deviceProfile": { + "type": "string", + "enum": [ + "gpu" + ] + } + } +} diff --git a/deployments/helm/dra-example-driver/values.yaml b/deployments/helm/dra-example-driver/values.yaml index e58c9d33..9ddededa 100644 --- a/deployments/helm/dra-example-driver/values.yaml +++ b/deployments/helm/dra-example-driver/values.yaml @@ -9,6 +9,15 @@ selectorLabelsOverride: {} allowDefaultNamespace: false +# deviceProfile describes the overall shape of the devices managed by the +# driver. Available profiles are: +# - "gpu": Node-local devices configurable through opaque config +deviceProfile: "gpu" + +# driverName uniquely identifies the driver within the cluster. When empty, its +# value is derived from the deviceProfile. +driverName: "" + imagePullSecrets: [] image: repository: registry.k8s.io/dra-example-driver/dra-example-driver @@ -25,26 +34,9 @@ serviceAccount: # If not set and create is true, a name is generated using the fullname template name: "" -controller: - priorityClassName: "system-node-critical" - podAnnotations: {} - podSecurityContext: {} - nodeSelector: - node-role.kubernetes.io/control-plane: "" - tolerations: - - key: node-role.kubernetes.io/master - operator: Exists - effect: NoSchedule - - key: node-role.kubernetes.io/control-plane - operator: Exists - effect: NoSchedule - affinity: {} - containers: - controller: - securityContext: {} - resources: {} - kubeletPlugin: + # numDevices describes how many GPUs to advertise on each node when the "gpu" + # deviceProfile is used. Not relevant for other profiles. numDevices: 8 priorityClassName: "system-node-critical" updateStrategy: diff --git a/internal/profiles/gpu/gpu.go b/internal/profiles/gpu/gpu.go new file mode 100644 index 00000000..fd2401e7 --- /dev/null +++ b/internal/profiles/gpu/gpu.go @@ -0,0 +1,194 @@ +/* + * Copyright The Kubernetes Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package gpu + +import ( + "fmt" + "math/rand" + + "github.com/google/uuid" + resourceapi "k8s.io/api/resource/v1" + "k8s.io/apimachinery/pkg/api/resource" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/dynamic-resource-allocation/resourceslice" + "k8s.io/utils/ptr" + cdiapi "tags.cncf.io/container-device-interface/pkg/cdi" + cdispec "tags.cncf.io/container-device-interface/specs-go" + + configapi "sigs.k8s.io/dra-example-driver/api/example.com/resource/gpu/v1alpha1" + "sigs.k8s.io/dra-example-driver/internal/profiles" +) + +const ProfileName = "gpu" + +type Profile struct { + nodeName string + numGPUs int +} + +func NewProfile(nodeName string, numGPUs int) Profile { + return Profile{ + nodeName: nodeName, + numGPUs: numGPUs, + } +} + +func (p Profile) EnumerateDevices() (resourceslice.DriverResources, error) { + seed := p.nodeName + uuids := generateUUIDs(seed, p.numGPUs) + + var devices []resourceapi.Device + for i, uuid := range uuids { + device := resourceapi.Device{ + Name: fmt.Sprintf("gpu-%d", i), + Attributes: map[resourceapi.QualifiedName]resourceapi.DeviceAttribute{ + "index": { + IntValue: ptr.To(int64(i)), + }, + "uuid": { + StringValue: ptr.To(uuid), + }, + "model": { + StringValue: ptr.To("LATEST-GPU-MODEL"), + }, + "driverVersion": { + VersionValue: ptr.To("1.0.0"), + }, + }, + Capacity: map[resourceapi.QualifiedName]resourceapi.DeviceCapacity{ + "memory": { + Value: resource.MustParse("80Gi"), + }, + }, + } + devices = append(devices, device) + } + + resources := resourceslice.DriverResources{ + Pools: map[string]resourceslice.Pool{ + p.nodeName: { + Slices: []resourceslice.Slice{ + { + Devices: devices, + }, + }, + }, + }, + } + + return resources, nil +} + +func generateUUIDs(seed string, count int) []string { + rand := rand.New(rand.NewSource(hash(seed))) + + uuids := make([]string, count) + for i := 0; i < count; i++ { + charset := make([]byte, 16) + rand.Read(charset) + uuid, _ := uuid.FromBytes(charset) + uuids[i] = "gpu-" + uuid.String() + } + + return uuids +} + +func hash(s string) int64 { + h := int64(0) + for _, c := range s { + h = 31*h + int64(c) + } + return h +} + +// SchemeBuilder implements [profiles.ConfigHandler]. +func (p Profile) SchemeBuilder() runtime.SchemeBuilder { + return runtime.NewSchemeBuilder( + configapi.AddToScheme, + ) +} + +// Validate implements [profiles.ConfigHandler]. +func (p Profile) Validate(config runtime.Object) error { + gpuConfig, ok := config.(*configapi.GpuConfig) + if !ok { + return fmt.Errorf("expected v1alpha1.GpuConfig but got: %T", config) + } + return gpuConfig.Validate() +} + +// ApplyConfig implements [profiles.ConfigHandler]. +func (p Profile) ApplyConfig(config runtime.Object, results []*resourceapi.DeviceRequestAllocationResult) (profiles.PerDeviceCDIContainerEdits, error) { + if config == nil { + config = configapi.DefaultGpuConfig() + } + if config, ok := config.(*configapi.GpuConfig); ok { + return applyGpuConfig(config, results) + } + return nil, fmt.Errorf("runtime object is not a recognized configuration") +} + +// In this example driver there is no actual configuration applied. We simply +// define a set of environment variables to be injected into the containers +// that include a given device. A real driver would likely need to do some sort +// of hardware configuration as well, based on the config passed in. +func applyGpuConfig(config *configapi.GpuConfig, results []*resourceapi.DeviceRequestAllocationResult) (profiles.PerDeviceCDIContainerEdits, error) { + perDeviceEdits := make(profiles.PerDeviceCDIContainerEdits) + + // Normalize the config to set any implied defaults. + if err := config.Normalize(); err != nil { + return nil, fmt.Errorf("error normalizing GPU config: %w", err) + } + + // Validate the config to ensure its integrity. + if err := config.Validate(); err != nil { + return nil, fmt.Errorf("error validating GPU config: %w", err) + } + + for _, result := range results { + envs := []string{ + fmt.Sprintf("GPU_DEVICE_%s=%s", result.Device[4:], result.Device), + } + + if config.Sharing != nil { + envs = append(envs, fmt.Sprintf("GPU_DEVICE_%s_SHARING_STRATEGY=%s", result.Device[4:], config.Sharing.Strategy)) + } + + switch { + case config.Sharing.IsTimeSlicing(): + tsconfig, err := config.Sharing.GetTimeSlicingConfig() + if err != nil { + return nil, fmt.Errorf("unable to get time slicing config for device %v: %w", result.Device, err) + } + envs = append(envs, fmt.Sprintf("GPU_DEVICE_%s_TIMESLICE_INTERVAL=%v", result.Device[4:], tsconfig.Interval)) + case config.Sharing.IsSpacePartitioning(): + spconfig, err := config.Sharing.GetSpacePartitioningConfig() + if err != nil { + return nil, fmt.Errorf("unable to get space partitioning config for device %v: %w", result.Device, err) + } + envs = append(envs, fmt.Sprintf("GPU_DEVICE_%s_PARTITION_COUNT=%v", result.Device[4:], spconfig.PartitionCount)) + } + + edits := &cdispec.ContainerEdits{ + Env: envs, + } + + perDeviceEdits[result.Device] = &cdiapi.ContainerEdits{ContainerEdits: edits} + } + + return perDeviceEdits, nil +} diff --git a/internal/profiles/profiles.go b/internal/profiles/profiles.go new file mode 100644 index 00000000..25e94004 --- /dev/null +++ b/internal/profiles/profiles.go @@ -0,0 +1,84 @@ +/* + * Copyright The Kubernetes Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package profiles + +import ( + "errors" + + resourceapi "k8s.io/api/resource/v1" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/dynamic-resource-allocation/resourceslice" + drapbv1 "k8s.io/kubelet/pkg/apis/dra/v1beta1" + cdiapi "tags.cncf.io/container-device-interface/pkg/cdi" +) + +type PerDeviceCDIContainerEdits map[string]*cdiapi.ContainerEdits + +type PreparedDevice struct { + drapbv1.Device + ContainerEdits *cdiapi.ContainerEdits +} + +type PreparedDevices []*PreparedDevice + +func (pds PreparedDevices) GetDevices() []*drapbv1.Device { + var devices []*drapbv1.Device + for _, pd := range pds { + devices = append(devices, &pd.Device) + } + return devices +} + +// Profile describes a kind of device that can be managed by the driver. +type Profile interface { + ConfigHandler + EnumerateDevices() (resourceslice.DriverResources, error) +} + +// ConfigHandler handles opaque configuration set for requests in ResourceClaims. +type ConfigHandler interface { + // SchemeBuilder produces a [runtime.Scheme] for the profile's configuration types. + SchemeBuilder() runtime.SchemeBuilder + // Validate returns nil for valid configuration, or an error explaining why the configuration is invalid. + Validate(config runtime.Object) error + // ApplyConfig applies a configuration to a set of device allocation + // results. When `config` is nil, the profile's default configuration should + // be applied. + ApplyConfig(config runtime.Object, results []*resourceapi.DeviceRequestAllocationResult) (PerDeviceCDIContainerEdits, error) +} + +// NoopConfigHandler implements a [ConfigHandler] that does not allow +// configuration. +type NoopConfigHandler struct{} + +// ApplyConfig implements [ConfigHandler]. +func (n NoopConfigHandler) ApplyConfig(config runtime.Object, results []*resourceapi.DeviceRequestAllocationResult) (PerDeviceCDIContainerEdits, error) { + if config != nil { + return nil, errors.New("configuration not allowed") + } + return nil, nil +} + +// SchemeBuilder implements [ConfigHandler]. +func (n NoopConfigHandler) SchemeBuilder() runtime.SchemeBuilder { + return runtime.NewSchemeBuilder() +} + +// Validate implements [ConfigHandler]. +func (n NoopConfigHandler) Validate(config runtime.Object) error { + return errors.New("configuration not allowed") +} diff --git a/pkg/consts/consts.go b/pkg/consts/consts.go deleted file mode 100644 index 23dacd2e..00000000 --- a/pkg/consts/consts.go +++ /dev/null @@ -1,19 +0,0 @@ -/* - * Copyright 2025 The Kubernetes Authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package consts - -const DriverName = "gpu.example.com"