Skip to content

Commit 322a3a7

Browse files
authored
Merge pull request #185 from klueska/rename-to-gpu-plugin
Add support to selectively decide which device classes to support via helm
2 parents 215a49a + e8e3e43 commit 322a3a7

File tree

18 files changed

+274
-64
lines changed

18 files changed

+274
-64
lines changed

cmd/nvidia-dra-controller/main.go

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ import (
2828
"github.com/prometheus/client_golang/prometheus/promhttp"
2929
"github.com/urfave/cli/v2"
3030

31+
"k8s.io/apimachinery/pkg/util/sets"
3132
"k8s.io/component-base/metrics/legacyregistry"
3233
"k8s.io/klog/v2"
3334

@@ -49,6 +50,8 @@ type Flags struct {
4950
httpEndpoint string
5051
metricsPath string
5152
profilePath string
53+
54+
deviceClasses sets.Set[string]
5255
}
5356

5457
type Config struct {
@@ -105,6 +108,12 @@ func newApp() *cli.App {
105108
Destination: &flags.profilePath,
106109
EnvVars: []string{"PPROF_PATH"},
107110
},
111+
&cli.StringSliceFlag{
112+
Name: "device-classes",
113+
Usage: "The supported set of DRA device classes",
114+
Value: cli.NewStringSlice(GpuDeviceType, MigDeviceType, ImexChannelType),
115+
EnvVars: []string{"DEVICE_CLASSES"},
116+
},
108117
}
109118

110119
cliFlags = append(cliFlags, flags.kubeClientConfig.Flags()...)
@@ -125,6 +134,7 @@ func newApp() *cli.App {
125134
Action: func(c *cli.Context) error {
126135
ctx := c.Context
127136
mux := http.NewServeMux()
137+
flags.deviceClasses = sets.New[string](c.StringSlice("device-classes")...)
128138

129139
clientSets, err := flags.kubeClientConfig.NewClientSets()
130140
if err != nil {
@@ -144,9 +154,11 @@ func newApp() *cli.App {
144154
}
145155
}
146156

147-
err = StartIMEXManager(ctx, config)
148-
if err != nil {
149-
return fmt.Errorf("start IMEX manager: %w", err)
157+
if flags.deviceClasses.Has(ImexChannelType) {
158+
err = StartIMEXManager(ctx, config)
159+
if err != nil {
160+
return fmt.Errorf("start IMEX manager: %w", err)
161+
}
150162
}
151163

152164
<-ctx.Done()

cmd/nvidia-dra-controller/types.go

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
/*
2+
* Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package main
18+
19+
const (
20+
GpuDeviceType = "gpu"
21+
MigDeviceType = "mig"
22+
ImexChannelType = "imex"
23+
UnknownDeviceType = "unknown"
24+
)

cmd/nvidia-dra-plugin/cdi.go

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,13 @@ func (cdi *CDIHandler) CreateStandardDeviceSpecFile(allocatable AllocatableDevic
172172
return fmt.Errorf("failed to get common CDI spec edits: %w", err)
173173
}
174174

175+
// Make sure that NVIDIA_VISIBLE_DEVICES is set to void to avoid the
176+
// nvidia-container-runtime honoring it in addition to the underlying
177+
// runtime honoring CDI.
178+
commonEdits.ContainerEdits.Env = append(
179+
commonEdits.ContainerEdits.Env,
180+
"NVIDIA_VISIBLE_DEVICES=void")
181+
175182
// Generate device specs for all full GPUs and MIG devices.
176183
var deviceSpecs []cdispec.Device
177184
for _, device := range allocatable {
@@ -223,25 +230,16 @@ func (cdi *CDIHandler) CreateClaimSpecFile(claimUID string, preparedDevices Prep
223230
// Generate claim specific specs for each device.
224231
var deviceSpecs []cdispec.Device
225232
for _, group := range preparedDevices {
226-
// Include this per-device, rather than as a top-level edit so that
227-
// each device spec is never empty and the spec file gets created
228-
// without error.
229-
claimDeviceEdits := cdiapi.ContainerEdits{
230-
ContainerEdits: &cdispec.ContainerEdits{
231-
Env: []string{
232-
"NVIDIA_VISIBLE_DEVICES=void",
233-
},
234-
},
233+
// If there are no edits passed back as prt of the device config state, skip it
234+
if group.ConfigState.containerEdits == nil {
235+
continue
235236
}
236237

237-
// Apply any edits passed back as part of the device config state.
238-
claimDeviceEdits.Append(group.ConfigState.containerEdits)
239-
240-
// Apply edits to all devices.
238+
// Apply any edits passed back as part of the device config state to all devices
241239
for _, device := range group.Devices {
242240
deviceSpec := cdispec.Device{
243241
Name: fmt.Sprintf("%s-%s", claimUID, device.CanonicalName()),
244-
ContainerEdits: *claimDeviceEdits.ContainerEdits,
242+
ContainerEdits: *group.ConfigState.containerEdits.ContainerEdits,
245243
}
246244

247245
deviceSpecs = append(deviceSpecs, deviceSpec)

cmd/nvidia-dra-plugin/device_state.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ func NewDeviceState(ctx context.Context, config *Config) (*DeviceState, error) {
6161
return nil, fmt.Errorf("failed to create device library: %w", err)
6262
}
6363

64-
allocatable, err := nvdevlib.enumerateAllPossibleDevices()
64+
allocatable, err := nvdevlib.enumerateAllPossibleDevices(config)
6565
if err != nil {
6666
return nil, fmt.Errorf("error enumerating all possible devices: %w", err)
6767
}

cmd/nvidia-dra-plugin/driver.go

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@ import (
3030

3131
type driver struct {
3232
sync.Mutex
33-
doneCh chan struct{}
3433
client coreclientset.Interface
3534
plugin kubeletplugin.DRAPlugin
3635
state *DeviceState
@@ -61,6 +60,12 @@ func NewDriver(ctx context.Context, config *Config) (*driver, error) {
6160
}
6261
driver.plugin = plugin
6362

63+
// If not responsible for advertising GPUs or MIG devices, we are done
64+
if !(config.flags.deviceClasses.Has(GpuDeviceType) || config.flags.deviceClasses.Has(MigDeviceType)) {
65+
return driver, nil
66+
}
67+
68+
// Otherwise, enumerate the set of GPU and MIG devices and publish them
6469
var resources kubeletplugin.Resources
6570
for _, device := range state.allocatable {
6671
// Explicitly exclude IMEX channels from being advertised here. They
@@ -79,7 +84,7 @@ func NewDriver(ctx context.Context, config *Config) (*driver, error) {
7984
}
8085

8186
func (d *driver) Shutdown(ctx context.Context) error {
82-
close(d.doneCh)
87+
d.plugin.Stop()
8388
return nil
8489
}
8590

cmd/nvidia-dra-plugin/main.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import (
2525

2626
"github.com/urfave/cli/v2"
2727

28+
"k8s.io/apimachinery/pkg/util/sets"
2829
"k8s.io/klog/v2"
2930

3031
"github.com/NVIDIA/k8s-dra-driver/internal/info"
@@ -50,6 +51,7 @@ type Flags struct {
5051
containerDriverRoot string
5152
hostDriverRoot string
5253
nvidiaCTKPath string
54+
deviceClasses sets.Set[string]
5355
}
5456

5557
type Config struct {
@@ -112,6 +114,12 @@ func newApp() *cli.App {
112114
Destination: &flags.nvidiaCTKPath,
113115
EnvVars: []string{"NVIDIA_CTK_PATH"},
114116
},
117+
&cli.StringSliceFlag{
118+
Name: "device-classes",
119+
Usage: "The supported set of DRA device classes",
120+
Value: cli.NewStringSlice(GpuDeviceType, MigDeviceType, ImexChannelType),
121+
EnvVars: []string{"DEVICE_CLASSES"},
122+
},
115123
}
116124
cliFlags = append(cliFlags, flags.kubeClientConfig.Flags()...)
117125
cliFlags = append(cliFlags, flags.loggingConfig.Flags()...)
@@ -130,6 +138,8 @@ func newApp() *cli.App {
130138
},
131139
Action: func(c *cli.Context) error {
132140
ctx := c.Context
141+
flags.deviceClasses = sets.New[string](c.StringSlice("device-classes")...)
142+
133143
clientSets, err := flags.kubeClientConfig.NewClientSets()
134144
if err != nil {
135145
return fmt.Errorf("create client: %w", err)

cmd/nvidia-dra-plugin/nvlib.go

Lines changed: 52 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -108,36 +108,66 @@ func (l deviceLib) alwaysShutdown() {
108108
}
109109
}
110110

111-
func (l deviceLib) enumerateAllPossibleDevices() (AllocatableDevices, error) {
111+
func (l deviceLib) enumerateAllPossibleDevices(config *Config) (AllocatableDevices, error) {
112+
alldevices := make(AllocatableDevices)
113+
deviceClasses := config.flags.deviceClasses
114+
115+
if deviceClasses.Has(GpuDeviceType) || deviceClasses.Has(MigDeviceType) {
116+
gms, err := l.enumerateGpusAndMigDevices(config)
117+
if err != nil {
118+
return nil, fmt.Errorf("error enumerating IMEX devices: %w", err)
119+
}
120+
for k, v := range gms {
121+
alldevices[k] = v
122+
}
123+
}
124+
125+
if deviceClasses.Has(ImexChannelType) {
126+
imex, err := l.enumerateImexChannels(config)
127+
if err != nil {
128+
return nil, fmt.Errorf("error enumerating IMEX devices: %w", err)
129+
}
130+
for k, v := range imex {
131+
alldevices[k] = v
132+
}
133+
}
134+
135+
return alldevices, nil
136+
}
137+
138+
func (l deviceLib) enumerateGpusAndMigDevices(config *Config) (AllocatableDevices, error) {
112139
if err := l.Init(); err != nil {
113140
return nil, err
114141
}
115142
defer l.alwaysShutdown()
116143

117-
alldevices := make(AllocatableDevices)
144+
devices := make(AllocatableDevices)
145+
deviceClasses := config.flags.deviceClasses
118146
err := l.VisitDevices(func(i int, d nvdev.Device) error {
119147
gpuInfo, err := l.getGpuInfo(i, d)
120148
if err != nil {
121149
return fmt.Errorf("error getting info for GPU %d: %w", i, err)
122150
}
123151

124-
migs, err := l.getMigDevices(gpuInfo)
125-
if err != nil {
126-
return fmt.Errorf("error getting MIG devices for GPU %d: %w", i, err)
127-
}
128-
129-
for _, migDeviceInfo := range migs {
152+
if deviceClasses.Has(GpuDeviceType) && !gpuInfo.migEnabled {
130153
deviceInfo := &AllocatableDevice{
131-
Mig: migDeviceInfo,
154+
Gpu: gpuInfo,
132155
}
133-
alldevices[migDeviceInfo.CanonicalName()] = deviceInfo
156+
devices[gpuInfo.CanonicalName()] = deviceInfo
134157
}
135158

136-
if !gpuInfo.migEnabled && len(migs) == 0 {
137-
deviceInfo := &AllocatableDevice{
138-
Gpu: gpuInfo,
159+
if deviceClasses.Has(MigDeviceType) {
160+
migs, err := l.getMigDevices(gpuInfo)
161+
if err != nil {
162+
return fmt.Errorf("error getting MIG devices for GPU %d: %w", i, err)
163+
}
164+
165+
for _, migDeviceInfo := range migs {
166+
deviceInfo := &AllocatableDevice{
167+
Mig: migDeviceInfo,
168+
}
169+
devices[migDeviceInfo.CanonicalName()] = deviceInfo
139170
}
140-
alldevices[gpuInfo.CanonicalName()] = deviceInfo
141171
}
142172

143173
return nil
@@ -146,6 +176,12 @@ func (l deviceLib) enumerateAllPossibleDevices() (AllocatableDevices, error) {
146176
return nil, fmt.Errorf("error visiting devices: %w", err)
147177
}
148178

179+
return devices, nil
180+
}
181+
182+
func (l deviceLib) enumerateImexChannels(config *Config) (AllocatableDevices, error) {
183+
devices := make(AllocatableDevices)
184+
149185
imexChannelCount, err := l.getImexChannelCount()
150186
if err != nil {
151187
return nil, fmt.Errorf("error getting IMEX channel count: %w", err)
@@ -157,10 +193,10 @@ func (l deviceLib) enumerateAllPossibleDevices() (AllocatableDevices, error) {
157193
deviceInfo := &AllocatableDevice{
158194
ImexChannel: imexChannelInfo,
159195
}
160-
alldevices[imexChannelInfo.CanonicalName()] = deviceInfo
196+
devices[imexChannelInfo.CanonicalName()] = deviceInfo
161197
}
162198

163-
return alldevices, nil
199+
return devices, nil
164200
}
165201

166202
func (l deviceLib) getGpuInfo(index int, device nvdev.Device) (*GpuInfo, error) {

cmd/nvidia-dra-plugin/types.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ package main
1919
const (
2020
GpuDeviceType = "gpu"
2121
MigDeviceType = "mig"
22-
ImexChannelType = "imex-channel"
22+
ImexChannelType = "imex"
2323
UnknownDeviceType = "unknown"
2424
)
2525

demo/clusters/kind/install-dra-driver.sh

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,14 @@ set -o pipefail
2222

2323
source "${CURRENT_DIR}/scripts/common.sh"
2424

25-
kubectl label node -l node-role.x-k8s.io/worker --overwrite nvidia.com/dra.kubelet-plugin=true
26-
kubectl label node -l node-role.x-k8s.io/control-plane --overwrite nvidia.com/dra.controller=true
27-
28-
helm upgrade -i --create-namespace --namespace nvidia-dra-driver nvidia ${PROJECT_DIR}/deployments/helm/k8s-dra-driver \
25+
deviceClasses=${1:-"gpu,mig,imex"}
26+
helm upgrade -i --create-namespace --namespace nvidia nvidia-dra-driver ${PROJECT_DIR}/deployments/helm/k8s-dra-driver \
27+
--set deviceClasses="{${deviceClasses}}" \
2928
${NVIDIA_DRIVER_ROOT:+--set nvidiaDriverRoot=${NVIDIA_DRIVER_ROOT}} \
3029
--wait
3130

3231
set +x
3332
printf '\033[0;32m'
3433
echo "Driver installation complete:"
35-
kubectl get pod -n nvidia-dra-driver
34+
kubectl get pod -n nvidia
3635
printf '\033[0m'

deployments/helm/k8s-dra-driver/templates/_helpers.tpl

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,3 +95,35 @@ Create the name of the service account to use
9595
{{- default "default" .Values.serviceAccount.name }}
9696
{{- end }}
9797
{{- end }}
98+
99+
{{/*
100+
Check for the existence of an element in a list
101+
*/}}
102+
{{- define "k8s-dra-driver.listHas" -}}
103+
{{- $listToCheck := index . 0 }}
104+
{{- $valueToCheck := index . 1 }}
105+
106+
{{- $found := "" -}}
107+
{{- range $listToCheck}}
108+
{{- if eq . $valueToCheck }}
109+
{{- $found = "true" -}}
110+
{{- end }}
111+
{{- end }}
112+
{{- $found -}}
113+
{{- end }}
114+
115+
{{/*
116+
Filter a list by a set of valid values
117+
*/}}
118+
{{- define "k8s-dra-driver.filterList" -}}
119+
{{- $listToFilter := index . 0 }}
120+
{{- $validValues := index . 1 }}
121+
122+
{{- $result := list -}}
123+
{{- range $validValues}}
124+
{{- if include "k8s-dra-driver.listHas" (list $listToFilter .) }}
125+
{{- $result = append $result . }}
126+
{{- end }}
127+
{{- end }}
128+
{{- $result -}}
129+
{{- end -}}

0 commit comments

Comments
 (0)