Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions cmd/nvidia-dra-controller/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (
"github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/urfave/cli/v2"

"k8s.io/apimachinery/pkg/util/sets"
"k8s.io/component-base/metrics/legacyregistry"
"k8s.io/klog/v2"

Expand All @@ -49,6 +50,8 @@ type Flags struct {
httpEndpoint string
metricsPath string
profilePath string

deviceClasses sets.Set[string]
}

type Config struct {
Expand Down Expand Up @@ -105,6 +108,12 @@ func newApp() *cli.App {
Destination: &flags.profilePath,
EnvVars: []string{"PPROF_PATH"},
},
&cli.StringSliceFlag{
Name: "device-classes",
Usage: "The supported set of DRA device classes",
Value: cli.NewStringSlice(GpuDeviceType, MigDeviceType, ImexChannelType),
EnvVars: []string{"DEVICE_CLASSES"},
},
}

cliFlags = append(cliFlags, flags.kubeClientConfig.Flags()...)
Expand All @@ -125,6 +134,7 @@ func newApp() *cli.App {
Action: func(c *cli.Context) error {
ctx := c.Context
mux := http.NewServeMux()
flags.deviceClasses = sets.New[string](c.StringSlice("device-classes")...)

clientSets, err := flags.kubeClientConfig.NewClientSets()
if err != nil {
Expand All @@ -144,9 +154,11 @@ func newApp() *cli.App {
}
}

err = StartIMEXManager(ctx, config)
if err != nil {
return fmt.Errorf("start IMEX manager: %w", err)
if flags.deviceClasses.Has(ImexChannelType) {
err = StartIMEXManager(ctx, config)
if err != nil {
return fmt.Errorf("start IMEX manager: %w", err)
}
}

<-ctx.Done()
Expand Down
24 changes: 24 additions & 0 deletions cmd/nvidia-dra-controller/types.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package main

// Device type names known to this binary. GpuDeviceType, MigDeviceType and
// ImexChannelType together form the default value of the "device-classes"
// flag.
const (
// GpuDeviceType is the device class name "gpu".
GpuDeviceType = "gpu"
// MigDeviceType is the device class name "mig".
MigDeviceType = "mig"
// ImexChannelType is the device class name "imex"; the IMEX manager is only
// started when this class is enabled.
ImexChannelType = "imex"
// UnknownDeviceType is not part of the default "device-classes" set.
// NOTE(review): presumably a fallback for unrecognized devices — confirm
// against its users.
UnknownDeviceType = "unknown"
)
26 changes: 12 additions & 14 deletions cmd/nvidia-dra-plugin/cdi.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,13 @@ func (cdi *CDIHandler) CreateStandardDeviceSpecFile(allocatable AllocatableDevic
return fmt.Errorf("failed to get common CDI spec edits: %w", err)
}

// Make sure that NVIDIA_VISIBLE_DEVICES is set to void to avoid the
// nvidia-container-runtime honoring it in addition to the underlying
// runtime honoring CDI.
commonEdits.ContainerEdits.Env = append(
commonEdits.ContainerEdits.Env,
"NVIDIA_VISIBLE_DEVICES=void")

// Generate device specs for all full GPUs and MIG devices.
var deviceSpecs []cdispec.Device
for _, device := range allocatable {
Expand Down Expand Up @@ -223,25 +230,16 @@ func (cdi *CDIHandler) CreateClaimSpecFile(claimUID string, preparedDevices Prep
// Generate claim specific specs for each device.
var deviceSpecs []cdispec.Device
for _, group := range preparedDevices {
// Include this per-device, rather than as a top-level edit so that
// each device spec is never empty and the spec file gets created
// without error.
claimDeviceEdits := cdiapi.ContainerEdits{
ContainerEdits: &cdispec.ContainerEdits{
Env: []string{
"NVIDIA_VISIBLE_DEVICES=void",
},
},
// If there are no edits passed back as part of the device config state, skip it
if group.ConfigState.containerEdits == nil {
continue
}

// Apply any edits passed back as part of the device config state.
claimDeviceEdits.Append(group.ConfigState.containerEdits)

// Apply edits to all devices.
// Apply any edits passed back as part of the device config state to all devices
for _, device := range group.Devices {
deviceSpec := cdispec.Device{
Name: fmt.Sprintf("%s-%s", claimUID, device.CanonicalName()),
ContainerEdits: *claimDeviceEdits.ContainerEdits,
ContainerEdits: *group.ConfigState.containerEdits.ContainerEdits,
}

deviceSpecs = append(deviceSpecs, deviceSpec)
Expand Down
2 changes: 1 addition & 1 deletion cmd/nvidia-dra-plugin/device_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ func NewDeviceState(ctx context.Context, config *Config) (*DeviceState, error) {
return nil, fmt.Errorf("failed to create device library: %w", err)
}

allocatable, err := nvdevlib.enumerateAllPossibleDevices()
allocatable, err := nvdevlib.enumerateAllPossibleDevices(config)
if err != nil {
return nil, fmt.Errorf("error enumerating all possible devices: %w", err)
}
Expand Down
9 changes: 7 additions & 2 deletions cmd/nvidia-dra-plugin/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ import (

type driver struct {
sync.Mutex
doneCh chan struct{}
client coreclientset.Interface
plugin kubeletplugin.DRAPlugin
state *DeviceState
Expand Down Expand Up @@ -61,6 +60,12 @@ func NewDriver(ctx context.Context, config *Config) (*driver, error) {
}
driver.plugin = plugin

// If not responsible for advertising GPUs or MIG devices, we are done
if !(config.flags.deviceClasses.Has(GpuDeviceType) || config.flags.deviceClasses.Has(MigDeviceType)) {
return driver, nil
}

// Otherwise, enumerate the set of GPU and MIG devices and publish them
var resources kubeletplugin.Resources
for _, device := range state.allocatable {
// Explicitly exclude IMEX channels from being advertised here. They
Expand All @@ -79,7 +84,7 @@ func NewDriver(ctx context.Context, config *Config) (*driver, error) {
}

func (d *driver) Shutdown(ctx context.Context) error {
close(d.doneCh)
d.plugin.Stop()
return nil
}

Expand Down
10 changes: 10 additions & 0 deletions cmd/nvidia-dra-plugin/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (

"github.com/urfave/cli/v2"

"k8s.io/apimachinery/pkg/util/sets"
"k8s.io/klog/v2"

"github.com/NVIDIA/k8s-dra-driver/internal/info"
Expand All @@ -50,6 +51,7 @@ type Flags struct {
containerDriverRoot string
hostDriverRoot string
nvidiaCTKPath string
deviceClasses sets.Set[string]
}

type Config struct {
Expand Down Expand Up @@ -112,6 +114,12 @@ func newApp() *cli.App {
Destination: &flags.nvidiaCTKPath,
EnvVars: []string{"NVIDIA_CTK_PATH"},
},
&cli.StringSliceFlag{
Name: "device-classes",
Usage: "The supported set of DRA device classes",
Value: cli.NewStringSlice(GpuDeviceType, MigDeviceType, ImexChannelType),
EnvVars: []string{"DEVICE_CLASSES"},
},
}
cliFlags = append(cliFlags, flags.kubeClientConfig.Flags()...)
cliFlags = append(cliFlags, flags.loggingConfig.Flags()...)
Expand All @@ -130,6 +138,8 @@ func newApp() *cli.App {
},
Action: func(c *cli.Context) error {
ctx := c.Context
flags.deviceClasses = sets.New[string](c.StringSlice("device-classes")...)

clientSets, err := flags.kubeClientConfig.NewClientSets()
if err != nil {
return fmt.Errorf("create client: %w", err)
Expand Down
68 changes: 52 additions & 16 deletions cmd/nvidia-dra-plugin/nvlib.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,36 +108,66 @@ func (l deviceLib) alwaysShutdown() {
}
}

func (l deviceLib) enumerateAllPossibleDevices() (AllocatableDevices, error) {
// enumerateAllPossibleDevices returns the allocatable devices for the device
// classes enabled in config.flags.deviceClasses.
//
// GPUs and MIG devices are enumerated together when either GpuDeviceType or
// MigDeviceType is enabled; IMEX channels are enumerated when ImexChannelType
// is enabled. The per-class results are merged into a single map, with later
// entries overwriting earlier ones on a key collision. If no class is enabled
// the returned map is empty (not nil).
//
// On failure a nil map is returned together with an error that names the
// class of device whose enumeration failed.
func (l deviceLib) enumerateAllPossibleDevices(config *Config) (AllocatableDevices, error) {
	alldevices := make(AllocatableDevices)
	deviceClasses := config.flags.deviceClasses

	if deviceClasses.Has(GpuDeviceType) || deviceClasses.Has(MigDeviceType) {
		gms, err := l.enumerateGpusAndMigDevices(config)
		if err != nil {
			// Report the GPU/MIG path accurately; this branch previously reused
			// the IMEX message, which misattributed the failure.
			return nil, fmt.Errorf("error enumerating GPUs and MIG devices: %w", err)
		}
		for k, v := range gms {
			alldevices[k] = v
		}
	}

	if deviceClasses.Has(ImexChannelType) {
		imex, err := l.enumerateImexChannels(config)
		if err != nil {
			return nil, fmt.Errorf("error enumerating IMEX devices: %w", err)
		}
		for k, v := range imex {
			alldevices[k] = v
		}
	}

	return alldevices, nil
}

func (l deviceLib) enumerateGpusAndMigDevices(config *Config) (AllocatableDevices, error) {
if err := l.Init(); err != nil {
return nil, err
}
defer l.alwaysShutdown()

alldevices := make(AllocatableDevices)
devices := make(AllocatableDevices)
deviceClasses := config.flags.deviceClasses
err := l.VisitDevices(func(i int, d nvdev.Device) error {
gpuInfo, err := l.getGpuInfo(i, d)
if err != nil {
return fmt.Errorf("error getting info for GPU %d: %w", i, err)
}

migs, err := l.getMigDevices(gpuInfo)
if err != nil {
return fmt.Errorf("error getting MIG devices for GPU %d: %w", i, err)
}

for _, migDeviceInfo := range migs {
if deviceClasses.Has(GpuDeviceType) && !gpuInfo.migEnabled {
deviceInfo := &AllocatableDevice{
Mig: migDeviceInfo,
Gpu: gpuInfo,
}
alldevices[migDeviceInfo.CanonicalName()] = deviceInfo
devices[gpuInfo.CanonicalName()] = deviceInfo
}

if !gpuInfo.migEnabled && len(migs) == 0 {
deviceInfo := &AllocatableDevice{
Gpu: gpuInfo,
if deviceClasses.Has(MigDeviceType) {
migs, err := l.getMigDevices(gpuInfo)
if err != nil {
return fmt.Errorf("error getting MIG devices for GPU %d: %w", i, err)
}

for _, migDeviceInfo := range migs {
deviceInfo := &AllocatableDevice{
Mig: migDeviceInfo,
}
devices[migDeviceInfo.CanonicalName()] = deviceInfo
}
alldevices[gpuInfo.CanonicalName()] = deviceInfo
}

return nil
Expand All @@ -146,6 +176,12 @@ func (l deviceLib) enumerateAllPossibleDevices() (AllocatableDevices, error) {
return nil, fmt.Errorf("error visiting devices: %w", err)
}

return devices, nil
}

func (l deviceLib) enumerateImexChannels(config *Config) (AllocatableDevices, error) {
devices := make(AllocatableDevices)

imexChannelCount, err := l.getImexChannelCount()
if err != nil {
return nil, fmt.Errorf("error getting IMEX channel count: %w", err)
Expand All @@ -157,10 +193,10 @@ func (l deviceLib) enumerateAllPossibleDevices() (AllocatableDevices, error) {
deviceInfo := &AllocatableDevice{
ImexChannel: imexChannelInfo,
}
alldevices[imexChannelInfo.CanonicalName()] = deviceInfo
devices[imexChannelInfo.CanonicalName()] = deviceInfo
}

return alldevices, nil
return devices, nil
}

func (l deviceLib) getGpuInfo(index int, device nvdev.Device) (*GpuInfo, error) {
Expand Down
2 changes: 1 addition & 1 deletion cmd/nvidia-dra-plugin/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package main
const (
GpuDeviceType = "gpu"
MigDeviceType = "mig"
ImexChannelType = "imex-channel"
ImexChannelType = "imex"
UnknownDeviceType = "unknown"
)

Expand Down
9 changes: 4 additions & 5 deletions demo/clusters/kind/install-dra-driver.sh
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,14 @@ set -o pipefail

source "${CURRENT_DIR}/scripts/common.sh"

kubectl label node -l node-role.x-k8s.io/worker --overwrite nvidia.com/dra.kubelet-plugin=true
kubectl label node -l node-role.x-k8s.io/control-plane --overwrite nvidia.com/dra.controller=true

helm upgrade -i --create-namespace --namespace nvidia-dra-driver nvidia ${PROJECT_DIR}/deployments/helm/k8s-dra-driver \
deviceClasses=${1:-"gpu,mig,imex"}
helm upgrade -i --create-namespace --namespace nvidia nvidia-dra-driver ${PROJECT_DIR}/deployments/helm/k8s-dra-driver \
--set deviceClasses="{${deviceClasses}}" \
${NVIDIA_DRIVER_ROOT:+--set nvidiaDriverRoot=${NVIDIA_DRIVER_ROOT}} \
--wait

set +x
printf '\033[0;32m'
echo "Driver installation complete:"
kubectl get pod -n nvidia-dra-driver
kubectl get pod -n nvidia
printf '\033[0m'
32 changes: 32 additions & 0 deletions deployments/helm/k8s-dra-driver/templates/_helpers.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -95,3 +95,35 @@ Create the name of the service account to use
{{- default "default" .Values.serviceAccount.name }}
{{- end }}
{{- end }}

{{/*
Check for the existence of an element in a list.

Expects a two-element list as its argument:
  0: the list to search
  1: the value to look for

Renders "true" if any element of the list compares equal (eq) to the value,
and an empty string otherwise, so the result of "include" can be used
directly as an "if" condition.
*/}}
{{- define "k8s-dra-driver.listHas" -}}
{{- $listToCheck := index . 0 }}
{{- $valueToCheck := index . 1 }}

{{- $found := "" -}}
{{- range $listToCheck}}
{{- if eq . $valueToCheck }}
{{- $found = "true" -}}
{{- end }}
{{- end }}
{{- $found -}}
{{- end }}

{{/*
Filter a list by a set of valid values.

Expects a two-element list as its argument:
  0: the list to filter
  1: the set of valid values

Walks the valid values in order and collects each one that is also present
in the list to filter (as decided by "k8s-dra-driver.listHas"), then renders
the collected list.

NOTE(review): the list is rendered with the template engine's default
formatting; presumably an empty result still renders as a non-empty string
("[]"), so confirm that callers do not test the output for emptiness.
*/}}
{{- define "k8s-dra-driver.filterList" -}}
{{- $listToFilter := index . 0 }}
{{- $validValues := index . 1 }}

{{- $result := list -}}
{{- range $validValues}}
{{- if include "k8s-dra-driver.listHas" (list $listToFilter .) }}
{{- $result = append $result . }}
{{- end }}
{{- end }}
{{- $result -}}
{{- end -}}
Loading