Skip to content

Commit 3fbde91

Browse files
committed
Refactor profiles into interfaces
1 parent 07790fb commit 3fbde91

File tree

6 files changed

+179
-129
lines changed

6 files changed

+179
-129
lines changed

cmd/dra-example-kubeletplugin/main.go

Lines changed: 21 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,11 @@ import (
2727

2828
"github.com/urfave/cli/v2"
2929

30-
"k8s.io/apimachinery/pkg/runtime"
3130
coreclientset "k8s.io/client-go/kubernetes"
3231
"k8s.io/dynamic-resource-allocation/kubeletplugin"
33-
"k8s.io/dynamic-resource-allocation/resourceslice"
3432
"k8s.io/klog/v2"
3533

34+
"sigs.k8s.io/dra-example-driver/internal/profiles"
3635
"sigs.k8s.io/dra-example-driver/internal/profiles/gpu"
3736
"sigs.k8s.io/dra-example-driver/pkg/flags"
3837
)
@@ -60,16 +59,23 @@ type Config struct {
6059
coreclient coreclientset.Interface
6160
cancelMainCtx func(error)
6261

63-
configScheme *runtime.Scheme // scheme for opaque config types
64-
applyConfigFunc ApplyConfigFunc
65-
cdiClass string
66-
enumerateDevicesFunc func() (resourceslice.DriverResources, error)
62+
profile profiles.Profile
6763
}
6864

69-
var validProfiles = []string{
70-
gpu.ProfileName,
65+
var validProfiles = map[string]func(flags Flags) profiles.Profile{
66+
gpu.ProfileName: func(flags Flags) profiles.Profile {
67+
return gpu.NewProfile(flags.nodeName, flags.numDevices)
68+
},
7169
}
7270

71+
var validProfileNames = func() []string {
72+
var valid []string
73+
for profileName := range validProfiles {
74+
valid = append(valid, profileName)
75+
}
76+
return valid
77+
}()
78+
7379
func (c Config) DriverPluginPath() string {
7480
return filepath.Join(c.flags.kubeletPluginsDirectoryPath, c.flags.driverName)
7581
}
@@ -130,7 +136,7 @@ func newApp() *cli.App {
130136
},
131137
&cli.StringFlag{
132138
Name: "device-profile",
133-
Usage: fmt.Sprintf("Name of the device profile. Valid values are %q.", validProfiles),
139+
Usage: fmt.Sprintf("Name of the device profile. Valid values are %q.", validProfileNames),
134140
Value: gpu.ProfileName,
135141
Destination: &flags.profile,
136142
EnvVars: []string{"DEVICE_PROFILE"},
@@ -168,34 +174,15 @@ func newApp() *cli.App {
168174
flags.driverName = flags.profile + ".example.com"
169175
}
170176

171-
var (
172-
sb runtime.SchemeBuilder
173-
applyConfigFunc ApplyConfigFunc
174-
cdiClass string
175-
enumerateDevicesFunc func() (resourceslice.DriverResources, error)
176-
)
177-
switch flags.profile {
178-
case gpu.ProfileName:
179-
sb = gpu.ConfigSchemeBuilder
180-
applyConfigFunc = gpu.ApplyConfig
181-
cdiClass = gpu.CDIClass
182-
enumerateDevicesFunc = gpu.EnumerateAllPossibleDevices(flags.nodeName, flags.numDevices)
183-
default:
184-
return fmt.Errorf("invalid device profile %q, valid profiles are %q", flags.profile, validProfiles)
185-
}
186-
187-
configScheme := runtime.NewScheme()
188-
if err := sb.AddToScheme(configScheme); err != nil {
189-
return fmt.Errorf("create config scheme: %w", err)
177+
newProfile, ok := validProfiles[flags.profile]
178+
if !ok {
179+
return fmt.Errorf("invalid device profile %q, valid profiles are %q", flags.profile, validProfileNames)
190180
}
191181

192182
config := &Config{
193-
flags: flags,
194-
coreclient: clientSets.Core,
195-
configScheme: configScheme,
196-
applyConfigFunc: applyConfigFunc,
197-
cdiClass: cdiClass,
198-
enumerateDevicesFunc: enumerateDevicesFunc,
183+
flags: flags,
184+
coreclient: clientSets.Core,
185+
profile: newProfile(*flags),
199186
}
200187

201188
return RunPlugin(ctx, config)

cmd/dra-example-kubeletplugin/state.go

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,6 @@ import (
3434
type AllocatableDevices map[string]resourceapi.Device
3535
type PreparedClaims map[string]profiles.PreparedDevices
3636

37-
type ApplyConfigFunc func(cconfig runtime.Object, results []*resourceapi.DeviceRequestAllocationResult) (profiles.PerDeviceCDIContainerEdits, error)
38-
3937
type OpaqueDeviceConfig struct {
4038
Requests []string
4139
Config runtime.Object
@@ -49,16 +47,16 @@ type DeviceState struct {
4947
allocatable AllocatableDevices
5048
checkpointManager checkpointmanager.CheckpointManager
5149
configDecoder runtime.Decoder
52-
applyConfigFunc ApplyConfigFunc
50+
configHandler profiles.ConfigHandler
5351
}
5452

5553
func NewDeviceState(config *Config) (*DeviceState, error) {
56-
driverResources, err := config.enumerateDevicesFunc()
54+
driverResources, err := config.profile.EnumerateDevices()
5755
if err != nil {
5856
return nil, fmt.Errorf("error enumerating all possible devices: %v", err)
5957
}
6058

61-
cdi, err := NewCDIHandler(config.flags.cdiRoot, config.flags.driverName, config.cdiClass)
59+
cdi, err := NewCDIHandler(config.flags.cdiRoot, config.flags.driverName, config.flags.profile)
6260
if err != nil {
6361
return nil, fmt.Errorf("unable to create CDI handler: %v", err)
6462
}
@@ -73,11 +71,18 @@ func NewDeviceState(config *Config) (*DeviceState, error) {
7371
return nil, fmt.Errorf("unable to create checkpoint manager: %v", err)
7472
}
7573

74+
configScheme := runtime.NewScheme()
75+
configHandler := config.profile
76+
sb := configHandler.SchemeBuilder()
77+
if err := sb.AddToScheme(configScheme); err != nil {
78+
return nil, fmt.Errorf("create config scheme: %w", err)
79+
}
80+
7681
// Set up a json serializer to decode our types.
7782
decoder := json.NewSerializerWithOptions(
7883
json.DefaultMetaFactory,
79-
config.configScheme,
80-
config.configScheme,
84+
configScheme,
85+
configScheme,
8186
json.SerializerOptions{
8287
Pretty: true, Strict: true,
8388
},
@@ -97,7 +102,7 @@ func NewDeviceState(config *Config) (*DeviceState, error) {
97102
allocatable: allocatable,
98103
checkpointManager: checkpointManager,
99104
configDecoder: decoder,
100-
applyConfigFunc: config.applyConfigFunc,
105+
configHandler: configHandler,
101106
}
102107

103108
checkpoints, err := state.checkpointManager.ListCheckpoints()
@@ -228,7 +233,7 @@ func (s *DeviceState) prepareDevices(claim *resourceapi.ResourceClaim) (profiles
228233
perDeviceCDIContainerEdits := make(profiles.PerDeviceCDIContainerEdits)
229234
for config, results := range configResultsMap {
230235
// Apply the config to the list of results associated with it.
231-
containerEdits, err := s.applyConfigFunc(config, results)
236+
containerEdits, err := s.configHandler.ApplyConfig(config, results)
232237
if err != nil {
233238
return nil, fmt.Errorf("error applying config: %w", err)
234239
}

cmd/dra-example-webhook/main.go

Lines changed: 32 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ import (
3333
kjson "k8s.io/apimachinery/pkg/runtime/serializer/json"
3434
"k8s.io/klog/v2"
3535

36+
"sigs.k8s.io/dra-example-driver/internal/profiles"
3637
"sigs.k8s.io/dra-example-driver/internal/profiles/gpu"
3738
"sigs.k8s.io/dra-example-driver/pkg/flags"
3839
)
@@ -47,12 +48,10 @@ type Flags struct {
4748
driverName string
4849
}
4950

50-
var configScheme = runtime.NewScheme()
51-
5251
type validator func(runtime.Object) error
5352

54-
var validProfiles = []string{
55-
gpu.ProfileName,
53+
var validProfiles = map[string]profiles.ConfigHandler{
54+
gpu.ProfileName: gpu.Profile{},
5655
}
5756

5857
func main() {
@@ -114,28 +113,26 @@ func newApp() *cli.App {
114113
return flags.loggingConfig.Apply()
115114
},
116115
Action: func(c *cli.Context) error {
117-
var (
118-
sb runtime.SchemeBuilder
119-
validate validator
120-
)
121-
switch flags.profile {
122-
case gpu.ProfileName:
123-
sb = gpu.ConfigSchemeBuilder
124-
validate = gpu.ValidateConfig
125-
default:
126-
return fmt.Errorf("invalid device profile %q, valid profiles are %q", flags.profile, validProfiles)
116+
configHandler, ok := validProfiles[flags.profile]
117+
if !ok {
118+
var valid []string
119+
for profileName := range validProfiles {
120+
valid = append(valid, profileName)
121+
}
122+
return fmt.Errorf("invalid device profile %q, valid profiles are %q", flags.profile, valid)
127123
}
128124

129125
if flags.driverName == "" {
130126
flags.driverName = flags.profile + ".example.com"
131127
}
132128

133-
if err := sb.AddToScheme(configScheme); err != nil {
134-
return fmt.Errorf("create config scheme: %w", err)
129+
mux, err := newMux(configHandler, flags.driverName)
130+
if err != nil {
131+
return fmt.Errorf("create HTTP mux: %w", err)
135132
}
136133

137134
server := &http.Server{
138-
Handler: newMux(newConfigDecoder(), validate, flags.driverName),
135+
Handler: mux,
139136
Addr: fmt.Sprintf(":%d", flags.port),
140137
}
141138
klog.Info("starting webhook server on", server.Addr)
@@ -146,29 +143,33 @@ func newApp() *cli.App {
146143
return app
147144
}
148145

149-
func newConfigDecoder() runtime.Decoder {
150-
// Set up a json serializer to decode our types.
151-
return kjson.NewSerializerWithOptions(
146+
func newMux(configHandler profiles.ConfigHandler, driverName string) (*http.ServeMux, error) {
147+
configScheme := runtime.NewScheme()
148+
sb := configHandler.SchemeBuilder()
149+
if err := sb.AddToScheme(configScheme); err != nil {
150+
return nil, fmt.Errorf("create config scheme: %w", err)
151+
}
152+
configDecoder := kjson.NewSerializerWithOptions(
152153
kjson.DefaultMetaFactory,
153154
configScheme,
154155
configScheme,
155156
kjson.SerializerOptions{
156157
Pretty: true, Strict: true,
157158
},
158159
)
159-
}
160160

161-
func newMux(configDecoder runtime.Decoder, validate validator, driverName string) *http.ServeMux {
162161
mux := http.NewServeMux()
163-
mux.HandleFunc("/validate-resource-claim-parameters", serveResourceClaim(configDecoder, validate, driverName))
164-
mux.HandleFunc("/readyz", func(w http.ResponseWriter, req *http.Request) {
165-
_, err := w.Write([]byte("ok"))
166-
if err != nil {
167-
http.Error(w, err.Error(), http.StatusInternalServerError)
168-
return
169-
}
170-
})
171-
return mux
162+
mux.HandleFunc("/validate-resource-claim-parameters", serveResourceClaim(configDecoder, configHandler.Validate, driverName))
163+
mux.HandleFunc("/readyz", readyHandler)
164+
return mux, nil
165+
}
166+
167+
func readyHandler(w http.ResponseWriter, req *http.Request) {
168+
_, err := w.Write([]byte("ok"))
169+
if err != nil {
170+
http.Error(w, err.Error(), http.StatusInternalServerError)
171+
return
172+
}
172173
}
173174

174175
func serveResourceClaim(configDecoder runtime.Decoder, validate validator, driverName string) func(http.ResponseWriter, *http.Request) {

cmd/dra-example-webhook/main_test.go

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,10 @@ import (
4242
const driverName = "gpu.example.com"
4343

4444
func TestReadyEndpoint(t *testing.T) {
45-
s := httptest.NewServer(newMux(nil, nil, ""))
45+
s := httptest.NewServer(http.HandlerFunc(readyHandler))
4646
t.Cleanup(s.Close)
4747

48-
res, err := http.Get(s.URL + "/readyz")
48+
res, err := http.Get(s.URL)
4949
assert.NoError(t, err)
5050
assert.Equal(t, http.StatusOK, res.StatusCode)
5151
}
@@ -170,10 +170,11 @@ func TestResourceClaimValidatingWebhook(t *testing.T) {
170170
},
171171
}
172172

173-
sb := gpu.ConfigSchemeBuilder
174-
assert.NoError(t, sb.AddToScheme(configScheme))
173+
configHandler := gpu.Profile{}
174+
mux, err := newMux(configHandler, driverName)
175+
assert.NoError(t, err)
175176

176-
s := httptest.NewServer(newMux(newConfigDecoder(), gpu.ValidateConfig, driverName))
177+
s := httptest.NewServer(mux)
177178
t.Cleanup(s.Close)
178179

179180
for name, test := range tests {

0 commit comments

Comments
 (0)