Skip to content

Commit e8e3e43

Browse files
committed
Allow selection of device classes managed by driver
Signed-off-by: Kevin Klues <[email protected]>
1 parent 9b16df1 commit e8e3e43

File tree

17 files changed

+235
-38
lines changed

17 files changed

+235
-38
lines changed

cmd/nvidia-dra-controller/main.go

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ import (
2828
"github.com/prometheus/client_golang/prometheus/promhttp"
2929
"github.com/urfave/cli/v2"
3030

31+
"k8s.io/apimachinery/pkg/util/sets"
3132
"k8s.io/component-base/metrics/legacyregistry"
3233
"k8s.io/klog/v2"
3334

@@ -49,6 +50,8 @@ type Flags struct {
4950
httpEndpoint string
5051
metricsPath string
5152
profilePath string
53+
54+
deviceClasses sets.Set[string]
5255
}
5356

5457
type Config struct {
@@ -105,6 +108,12 @@ func newApp() *cli.App {
105108
Destination: &flags.profilePath,
106109
EnvVars: []string{"PPROF_PATH"},
107110
},
111+
&cli.StringSliceFlag{
112+
Name: "device-classes",
113+
Usage: "The supported set of DRA device classes",
114+
Value: cli.NewStringSlice(GpuDeviceType, MigDeviceType, ImexChannelType),
115+
EnvVars: []string{"DEVICE_CLASSES"},
116+
},
108117
}
109118

110119
cliFlags = append(cliFlags, flags.kubeClientConfig.Flags()...)
@@ -125,6 +134,7 @@ func newApp() *cli.App {
125134
Action: func(c *cli.Context) error {
126135
ctx := c.Context
127136
mux := http.NewServeMux()
137+
flags.deviceClasses = sets.New[string](c.StringSlice("device-classes")...)
128138

129139
clientSets, err := flags.kubeClientConfig.NewClientSets()
130140
if err != nil {
@@ -144,9 +154,11 @@ func newApp() *cli.App {
144154
}
145155
}
146156

147-
err = StartIMEXManager(ctx, config)
148-
if err != nil {
149-
return fmt.Errorf("start IMEX manager: %w", err)
157+
if flags.deviceClasses.Has(ImexChannelType) {
158+
err = StartIMEXManager(ctx, config)
159+
if err != nil {
160+
return fmt.Errorf("start IMEX manager: %w", err)
161+
}
150162
}
151163

152164
<-ctx.Done()

cmd/nvidia-dra-controller/types.go

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
/*
2+
* Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package main
18+
19+
const (
20+
GpuDeviceType = "gpu"
21+
MigDeviceType = "mig"
22+
ImexChannelType = "imex"
23+
UnknownDeviceType = "unknown"
24+
)

cmd/nvidia-dra-plugin/device_state.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ func NewDeviceState(ctx context.Context, config *Config) (*DeviceState, error) {
6161
return nil, fmt.Errorf("failed to create device library: %w", err)
6262
}
6363

64-
allocatable, err := nvdevlib.enumerateAllPossibleDevices()
64+
allocatable, err := nvdevlib.enumerateAllPossibleDevices(config)
6565
if err != nil {
6666
return nil, fmt.Errorf("error enumerating all possible devices: %w", err)
6767
}

cmd/nvidia-dra-plugin/driver.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,12 @@ func NewDriver(ctx context.Context, config *Config) (*driver, error) {
6060
}
6161
driver.plugin = plugin
6262

63+
// If not responsible for advertising GPUs or MIG devices, we are done
64+
if !(config.flags.deviceClasses.Has(GpuDeviceType) || config.flags.deviceClasses.Has(MigDeviceType)) {
65+
return driver, nil
66+
}
67+
68+
// Otherwise, enumerate the set of GPU and MIG devices and publish them
6369
var resources kubeletplugin.Resources
6470
for _, device := range state.allocatable {
6571
// Explicitly exclude IMEX channels from being advertised here. They

cmd/nvidia-dra-plugin/main.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import (
2525

2626
"github.com/urfave/cli/v2"
2727

28+
"k8s.io/apimachinery/pkg/util/sets"
2829
"k8s.io/klog/v2"
2930

3031
"github.com/NVIDIA/k8s-dra-driver/internal/info"
@@ -50,6 +51,7 @@ type Flags struct {
5051
containerDriverRoot string
5152
hostDriverRoot string
5253
nvidiaCTKPath string
54+
deviceClasses sets.Set[string]
5355
}
5456

5557
type Config struct {
@@ -112,6 +114,12 @@ func newApp() *cli.App {
112114
Destination: &flags.nvidiaCTKPath,
113115
EnvVars: []string{"NVIDIA_CTK_PATH"},
114116
},
117+
&cli.StringSliceFlag{
118+
Name: "device-classes",
119+
Usage: "The supported set of DRA device classes",
120+
Value: cli.NewStringSlice(GpuDeviceType, MigDeviceType, ImexChannelType),
121+
EnvVars: []string{"DEVICE_CLASSES"},
122+
},
115123
}
116124
cliFlags = append(cliFlags, flags.kubeClientConfig.Flags()...)
117125
cliFlags = append(cliFlags, flags.loggingConfig.Flags()...)
@@ -130,6 +138,8 @@ func newApp() *cli.App {
130138
},
131139
Action: func(c *cli.Context) error {
132140
ctx := c.Context
141+
flags.deviceClasses = sets.New[string](c.StringSlice("device-classes")...)
142+
133143
clientSets, err := flags.kubeClientConfig.NewClientSets()
134144
if err != nil {
135145
return fmt.Errorf("create client: %w", err)

cmd/nvidia-dra-plugin/nvlib.go

Lines changed: 52 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -108,36 +108,66 @@ func (l deviceLib) alwaysShutdown() {
108108
}
109109
}
110110

111-
func (l deviceLib) enumerateAllPossibleDevices() (AllocatableDevices, error) {
111+
func (l deviceLib) enumerateAllPossibleDevices(config *Config) (AllocatableDevices, error) {
112+
alldevices := make(AllocatableDevices)
113+
deviceClasses := config.flags.deviceClasses
114+
115+
if deviceClasses.Has(GpuDeviceType) || deviceClasses.Has(MigDeviceType) {
116+
gms, err := l.enumerateGpusAndMigDevices(config)
117+
if err != nil {
118+
return nil, fmt.Errorf("error enumerating IMEX devices: %w", err)
119+
}
120+
for k, v := range gms {
121+
alldevices[k] = v
122+
}
123+
}
124+
125+
if deviceClasses.Has(ImexChannelType) {
126+
imex, err := l.enumerateImexChannels(config)
127+
if err != nil {
128+
return nil, fmt.Errorf("error enumerating IMEX devices: %w", err)
129+
}
130+
for k, v := range imex {
131+
alldevices[k] = v
132+
}
133+
}
134+
135+
return alldevices, nil
136+
}
137+
138+
func (l deviceLib) enumerateGpusAndMigDevices(config *Config) (AllocatableDevices, error) {
112139
if err := l.Init(); err != nil {
113140
return nil, err
114141
}
115142
defer l.alwaysShutdown()
116143

117-
alldevices := make(AllocatableDevices)
144+
devices := make(AllocatableDevices)
145+
deviceClasses := config.flags.deviceClasses
118146
err := l.VisitDevices(func(i int, d nvdev.Device) error {
119147
gpuInfo, err := l.getGpuInfo(i, d)
120148
if err != nil {
121149
return fmt.Errorf("error getting info for GPU %d: %w", i, err)
122150
}
123151

124-
migs, err := l.getMigDevices(gpuInfo)
125-
if err != nil {
126-
return fmt.Errorf("error getting MIG devices for GPU %d: %w", i, err)
127-
}
128-
129-
for _, migDeviceInfo := range migs {
152+
if deviceClasses.Has(GpuDeviceType) && !gpuInfo.migEnabled {
130153
deviceInfo := &AllocatableDevice{
131-
Mig: migDeviceInfo,
154+
Gpu: gpuInfo,
132155
}
133-
alldevices[migDeviceInfo.CanonicalName()] = deviceInfo
156+
devices[gpuInfo.CanonicalName()] = deviceInfo
134157
}
135158

136-
if !gpuInfo.migEnabled && len(migs) == 0 {
137-
deviceInfo := &AllocatableDevice{
138-
Gpu: gpuInfo,
159+
if deviceClasses.Has(MigDeviceType) {
160+
migs, err := l.getMigDevices(gpuInfo)
161+
if err != nil {
162+
return fmt.Errorf("error getting MIG devices for GPU %d: %w", i, err)
163+
}
164+
165+
for _, migDeviceInfo := range migs {
166+
deviceInfo := &AllocatableDevice{
167+
Mig: migDeviceInfo,
168+
}
169+
devices[migDeviceInfo.CanonicalName()] = deviceInfo
139170
}
140-
alldevices[gpuInfo.CanonicalName()] = deviceInfo
141171
}
142172

143173
return nil
@@ -146,6 +176,12 @@ func (l deviceLib) enumerateAllPossibleDevices() (AllocatableDevices, error) {
146176
return nil, fmt.Errorf("error visiting devices: %w", err)
147177
}
148178

179+
return devices, nil
180+
}
181+
182+
func (l deviceLib) enumerateImexChannels(config *Config) (AllocatableDevices, error) {
183+
devices := make(AllocatableDevices)
184+
149185
imexChannelCount, err := l.getImexChannelCount()
150186
if err != nil {
151187
return nil, fmt.Errorf("error getting IMEX channel count: %w", err)
@@ -157,10 +193,10 @@ func (l deviceLib) enumerateAllPossibleDevices() (AllocatableDevices, error) {
157193
deviceInfo := &AllocatableDevice{
158194
ImexChannel: imexChannelInfo,
159195
}
160-
alldevices[imexChannelInfo.CanonicalName()] = deviceInfo
196+
devices[imexChannelInfo.CanonicalName()] = deviceInfo
161197
}
162198

163-
return alldevices, nil
199+
return devices, nil
164200
}
165201

166202
func (l deviceLib) getGpuInfo(index int, device nvdev.Device) (*GpuInfo, error) {

cmd/nvidia-dra-plugin/types.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ package main
1919
const (
2020
GpuDeviceType = "gpu"
2121
MigDeviceType = "mig"
22-
ImexChannelType = "imex-channel"
22+
ImexChannelType = "imex"
2323
UnknownDeviceType = "unknown"
2424
)
2525

demo/clusters/kind/install-dra-driver.sh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,9 @@ set -o pipefail
2222

2323
source "${CURRENT_DIR}/scripts/common.sh"
2424

25+
deviceClasses=${1:-"gpu,mig,imex"}
2526
helm upgrade -i --create-namespace --namespace nvidia nvidia-dra-driver ${PROJECT_DIR}/deployments/helm/k8s-dra-driver \
27+
--set deviceClasses="{${deviceClasses}}" \
2628
${NVIDIA_DRIVER_ROOT:+--set nvidiaDriverRoot=${NVIDIA_DRIVER_ROOT}} \
2729
--wait
2830

deployments/helm/k8s-dra-driver/templates/_helpers.tpl

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,3 +95,35 @@ Create the name of the service account to use
9595
{{- default "default" .Values.serviceAccount.name }}
9696
{{- end }}
9797
{{- end }}
98+
99+
{{/*
100+
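Check for the existence of an element in a list. Expects a two-item list (listToCheck, valueToCheck); renders "true" if the value is present and an empty string otherwise.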
101+
*/}}
102+
{{- define "k8s-dra-driver.listHas" -}}
103+
{{- $listToCheck := index . 0 }}
104+
{{- $valueToCheck := index . 1 }}
105+
106+
{{- $found := "" -}}
107+
{{- range $listToCheck}}
108+
{{- if eq . $valueToCheck }}
109+
{{- $found = "true" -}}
110+
{{- end }}
111+
{{- end }}
112+
{{- $found -}}
113+
{{- end }}
114+
115+
{{/*
116+
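Filter a list by a set of valid values. Expects a two-item list (listToFilter, validValues); collects each entry of validValues that is also present in listToFilter, in validValues order.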
117+
*/}}
118+
{{- define "k8s-dra-driver.filterList" -}}
119+
{{- $listToFilter := index . 0 }}
120+
{{- $validValues := index . 1 }}
121+
122+
{{- $result := list -}}
123+
{{- range $validValues}}
124+
{{- if include "k8s-dra-driver.listHas" (list $listToFilter .) }}
125+
{{- $result = append $result . }}
126+
{{- end }}
127+
{{- end }}
128+
{{- $result -}}
129+
{{- end -}}

deployments/helm/k8s-dra-driver/templates/controller.yaml

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,19 @@
1+
# Copyright 2024 NVIDIA CORPORATION
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
{{- if (include "k8s-dra-driver.listHas" (list $.Values.deviceClasses "imex")) }}
16+
{{- $deviceClasses := include "k8s-dra-driver.filterList" (list $.Values.deviceClasses (list "imex")) }}
117
---
218
apiVersion: apps/v1
319
kind: Deployment
@@ -40,6 +56,8 @@ spec:
4056
resources:
4157
{{- toYaml .Values.controller.containers.controller.resources | nindent 10 }}
4258
env:
59+
- name: DEVICE_CLASSES
60+
value: {{ .Values.deviceClasses | join "," }}
4361
- name: POD_NAME
4462
valueFrom:
4563
fieldRef:
@@ -60,3 +78,4 @@ spec:
6078
tolerations:
6179
{{- toYaml . | nindent 8 }}
6280
{{- end }}
81+
{{- end }}

0 commit comments

Comments
 (0)