@@ -23,11 +23,13 @@ import (
2323 "sync"
2424 "time"
2525
26+ resourceapi "k8s.io/api/resource/v1beta1"
2627 metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
28+ "k8s.io/apimachinery/pkg/types"
2729 coreclientset "k8s.io/client-go/kubernetes"
2830 "k8s.io/dynamic-resource-allocation/kubeletplugin"
31+ "k8s.io/dynamic-resource-allocation/resourceslice"
2932 "k8s.io/klog/v2"
30- drapbv1 "k8s.io/kubelet/pkg/apis/dra/v1beta1"
3133
3234 "github.com/NVIDIA/k8s-dra-driver-gpu/pkg/workqueue"
3335)
@@ -38,63 +40,65 @@ const ErrorRetryMaxTimeout = 45 * time.Second
3840
3941// permanentError defines an error indicating that it is permanent.
4042// By default, every error will be retried up to ErrorRetryMaxTimeout.
41- // Errors marked as permament will not be retried.
43+ // Errors marked as permanent will not be retried.
4244type permanentError struct { error }
4345
4446func isPermanentError (err error ) bool {
4547 return errors .As (err , & permanentError {})
4648}
4749
48- var _ drapbv1.DRAPluginServer = & driver {}
49-
5050type driver struct {
5151 sync.Mutex
52- client coreclientset.Interface
53- plugin kubeletplugin.DRAPlugin
54- state * DeviceState
52+ client coreclientset.Interface
53+ pluginhelper * kubeletplugin.Helper
54+ state * DeviceState
5555}
5656
5757func NewDriver (ctx context.Context , config * Config ) (* driver , error ) {
58- driver := & driver {
59- client : config .clientsets .Core ,
60- }
61-
6258 state , err := NewDeviceState (ctx , config )
6359 if err != nil {
6460 return nil , err
6561 }
66- driver .state = state
6762
68- plugin , err := kubeletplugin .Start (
63+ driver := & driver {
64+ client : config .clientsets .Core ,
65+ state : state ,
66+ }
67+
68+ helper , err := kubeletplugin .Start (
6969 ctx ,
70- [] any { driver } ,
70+ driver ,
7171 kubeletplugin .KubeClient (driver .client ),
7272 kubeletplugin .NodeName (config .flags .nodeName ),
7373 kubeletplugin .DriverName (DriverName ),
74- kubeletplugin .RegistrarSocketPath (PluginRegistrationPath ),
75- kubeletplugin .PluginSocketPath (DriverPluginSocketPath ),
76- kubeletplugin .KubeletPluginSocketPath (DriverPluginSocketPath ))
74+ )
7775 if err != nil {
7876 return nil , err
7977 }
80- driver .plugin = plugin
78+ driver .pluginhelper = helper
8179
8280 // Enumerate the set of ComputeDomain daemon devices and publish them
83- var resources kubeletplugin. Resources
81+ var resourceSlice resourceslice. Slice
8482 for _ , device := range state .allocatable {
8583 // Explicitly exclude ComputeDomain channels from being advertised here. They
8684 // are instead advertised in as a network resource from the control plane.
8785 if device .Type () == ComputeDomainChannelType && device .Channel .ID != 0 {
8886 continue
8987 }
90- resources .Devices = append (resources .Devices , device .GetDevice ())
88+ resourceSlice .Devices = append (resourceSlice .Devices , device .GetDevice ())
89+ }
90+
91+ resources := resourceslice.DriverResources {
92+ Pools : map [string ]resourceslice.Pool {
93+ config .flags .nodeName : {Slices : []resourceslice.Slice {resourceSlice }},
94+ },
9195 }
9296
9397 if err := state .computeDomainManager .Start (ctx ); err != nil {
9498 return nil , err
9599 }
96100
97- if err := plugin .PublishResources (ctx , resources ); err != nil {
101+ if err := driver . pluginhelper .PublishResources (ctx , resources ); err != nil {
98102 return nil , err
99103 }
100104
@@ -108,28 +112,28 @@ func (d *driver) Shutdown() error {
108112 if err := d .state .computeDomainManager .Stop (); err != nil {
109113 return fmt .Errorf ("error stopping ComputeDomainManager: %w" , err )
110114 }
111- d .plugin .Stop ()
115+ d .pluginhelper .Stop ()
112116 return nil
113117}
114118
115- func (d * driver ) NodePrepareResources (ctx context.Context , req * drapbv1.NodePrepareResourcesRequest ) (* drapbv1.NodePrepareResourcesResponse , error ) {
116- klog .Infof ("NodePrepareResource is called: number of claims: %d" , len (req .Claims ))
117- preparedResources := & drapbv1.NodePrepareResourcesResponse {Claims : map [string ]* drapbv1.NodePrepareResourceResponse {}}
119+ func (d * driver ) PrepareResourceClaims (ctx context.Context , claims []* resourceapi.ResourceClaim ) (map [types.UID ]kubeletplugin.PrepareResult , error ) {
120+ klog .Infof ("PrepareResourceClaims called with %d claim(s)" , len (claims ))
118121
119122 var wg sync.WaitGroup
120123 ctx , cancel := context .WithTimeout (ctx , ErrorRetryMaxTimeout )
121124 workQueue := workqueue .New (workqueue .DefaultControllerRateLimiter ())
125+ results := make (map [types.UID ]kubeletplugin.PrepareResult )
122126
123- for _ , claim := range req . Claims {
127+ for _ , claim := range claims {
124128 wg .Add (1 )
125129 workQueue .EnqueueRaw (claim , func (ctx context.Context , obj any ) error {
126- done , prepared := d .nodePrepareResource (ctx , claim )
130+ done , res := d .nodePrepareResource (ctx , claim )
127131 if done {
128- preparedResources . Claims [claim .UID ] = prepared
132+ results [claim .UID ] = res
129133 wg .Done ()
130134 return nil
131135 }
132- return fmt .Errorf ("%s " , prepared . Error )
136+ return fmt .Errorf ("%w " , res . Err )
133137 })
134138 }
135139
@@ -139,28 +143,27 @@ func (d *driver) NodePrepareResources(ctx context.Context, req *drapbv1.NodePrep
139143 }()
140144
141145 workQueue .Run (ctx )
142-
143- return preparedResources , nil
146+ return results , nil
144147}
145148
146- func (d * driver ) NodeUnprepareResources (ctx context.Context , req * drapbv1.NodeUnprepareResourcesRequest ) (* drapbv1.NodeUnprepareResourcesResponse , error ) {
147- klog .Infof ("NodeUnprepareResource is called: number of claims: %d" , len (req .Claims ))
148- unpreparedResources := & drapbv1.NodeUnprepareResourcesResponse {Claims : map [string ]* drapbv1.NodeUnprepareResourceResponse {}}
149+ func (d * driver ) UnprepareResourceClaims (ctx context.Context , claims []kubeletplugin.NamespacedObject ) (map [types.UID ]error , error ) {
150+ klog .Infof ("UnprepareResourceClaims called with %d claim(s)" , len (claims ))
149151
150152 var wg sync.WaitGroup
151153 ctx , cancel := context .WithTimeout (ctx , ErrorRetryMaxTimeout )
152154 workQueue := workqueue .New (workqueue .DefaultControllerRateLimiter ())
155+ results := make (map [types.UID ]error )
153156
154- for _ , claim := range req . Claims {
157+ for _ , claim := range claims {
155158 wg .Add (1 )
156159 workQueue .EnqueueRaw (claim , func (ctx context.Context , obj any ) error {
157- done , unprepared := d .nodeUnprepareResource (ctx , claim )
160+ done , err := d .nodeUnprepareResource (ctx , claim )
158161 if done {
159- unpreparedResources . Claims [claim .UID ] = unprepared
162+ results [claim .UID ] = err
160163 wg .Done ()
161164 return nil
162165 }
163- return fmt .Errorf ("%s " , unprepared . Error )
166+ return fmt .Errorf ("%w " , err )
164167 })
165168 }
166169
@@ -171,73 +174,76 @@ func (d *driver) NodeUnprepareResources(ctx context.Context, req *drapbv1.NodeUn
171174
172175 workQueue .Run (ctx )
173176
174- return unpreparedResources , nil
177+ return results , nil
175178}
176179
177- func (d * driver ) nodePrepareResource (ctx context.Context , claim * drapbv1. Claim ) (bool , * drapbv1. NodePrepareResourceResponse ) {
180+ func (d * driver ) nodePrepareResource (ctx context.Context , claim * resourceapi. ResourceClaim ) (bool , kubeletplugin. PrepareResult ) {
178181 d .Lock ()
179182 defer d .Unlock ()
180183
181- resourceClaim , err := d .client .ResourceV1beta1 ().ResourceClaims (claim .Namespace ).Get (
182- ctx ,
183- claim .Name ,
184- metav1.GetOptions {})
185- if err != nil {
186- ret := & drapbv1.NodePrepareResourceResponse {
187- Error : fmt .Sprintf ("failed to fetch ResourceClaim %s in namespace %s" , claim .Name , claim .Namespace ),
188- }
189- return isPermanentError (err ), ret
190- }
191-
192- if resourceClaim .Status .Allocation == nil {
193- ret := & drapbv1.NodePrepareResourceResponse {
194- Error : fmt .Sprintf ("no allocation set in ResourceClaim %s in namespace %s" , claim .Name , claim .Namespace ),
184+ if claim .Status .Allocation == nil {
185+ res := kubeletplugin.PrepareResult {
186+ Err : fmt .Errorf ("no allocation set in ResourceClaim %s in namespace %s" , claim .Name , claim .Namespace ),
195187 }
196- return true , ret
188+ return true , res
197189 }
198190
199- prepared , err := d .state .Prepare (ctx , resourceClaim )
191+ devs , err := d .state .Prepare (ctx , claim )
200192 if err != nil {
201- ret := & drapbv1. NodePrepareResourceResponse {
202- Error : fmt .Sprintf ("error preparing devices for claim %v: %v " , claim .UID , err ),
193+ res := kubeletplugin. PrepareResult {
194+ Err : fmt .Errorf ("error preparing devices for claim %v: %w " , claim .UID , err ),
203195 }
204- return isPermanentError (err ), ret
196+ return isPermanentError (err ), res
197+ }
198+
199+ // Translate type: drapbv1.Device -> kubeletplugin.Device
200+ // Maybe instead change return type of state.Prepare()
201+ var prepDevs []kubeletplugin.Device
202+ for _ , d := range devs {
203+ device := kubeletplugin.Device {
204+ Requests : d .RequestNames ,
205+ PoolName : d .PoolName ,
206+ DeviceName : d .DeviceName ,
207+ CDIDeviceIDs : d .CDIDeviceIDs ,
208+ }
209+ prepDevs = append (prepDevs , device )
205210 }
206211
207- klog .Infof ("Returning newly prepared devices for claim '%v': %v" , claim .UID , prepared )
208- return true , & drapbv1. NodePrepareResourceResponse {Devices : prepared }
212+ klog .Infof ("Returning newly prepared devices for claim '%v': %v" , claim .UID , prepDevs )
213+ return true , kubeletplugin. PrepareResult {Devices : prepDevs }
209214}
210215
211- func (d * driver ) nodeUnprepareResource (ctx context.Context , claim * drapbv1. Claim ) (bool , * drapbv1. NodeUnprepareResourceResponse ) {
216+ func (d * driver ) nodeUnprepareResource (ctx context.Context , claimNs kubeletplugin. NamespacedObject ) (bool , error ) {
212217 d .Lock ()
213218 defer d .Unlock ()
214219
215- resourceClaim , err := d .client .ResourceV1beta1 ().ResourceClaims (claim .Namespace ).Get (
220+ // Fetching the resource claim should not be needed (and not be done) in the
221+ // unprepare code path. Any state required during unprepare can be stored
222+ // via checkpointing.
223+ claim , err := d .client .ResourceV1beta1 ().ResourceClaims (claimNs .Namespace ).Get (
216224 ctx ,
217- claim .Name ,
225+ claimNs .Name ,
218226 metav1.GetOptions {})
227+
219228 if err != nil {
220- ret := & drapbv1.NodeUnprepareResourceResponse {
221- Error : fmt .Sprintf ("failed to fetch ResourceClaim %s in namespace %s" , claim .Name , claim .Namespace ),
222- }
223- return isPermanentError (err ), ret
229+ return isPermanentError (err ), fmt .Errorf (
230+ "failed to fetch ResourceClaim %s in namespace %s: %w" ,
231+ claimNs .Name ,
232+ claimNs .Namespace ,
233+ err ,
234+ )
224235 }
225236
226- if resourceClaim .Status .Allocation == nil {
227- ret := & drapbv1.NodeUnprepareResourceResponse {
228- Error : fmt .Sprintf ("no allocation set in ResourceClaim %s in namespace %s" , claim .Name , claim .Namespace ),
229- }
230- return true , ret
237+ if claim .Status .Allocation == nil {
238+ return true , fmt .Errorf ("no allocation set in ResourceClaim %s in namespace %s" , claim .Name , claim .Namespace )
231239 }
232240
233- if err := d .state .Unprepare (ctx , resourceClaim ); err != nil {
234- ret := & drapbv1.NodeUnprepareResourceResponse {
235- Error : fmt .Sprintf ("error unpreparing devices for claim %v: %v" , claim .UID , err ),
236- }
237- return isPermanentError (err ), ret
241+ if err := d .state .Unprepare (ctx , claim ); err != nil {
242+ return isPermanentError (err ), fmt .Errorf ("error unpreparing devices for claim '%v': %w" , claim .UID , err )
238243 }
239244
240- return true , & drapbv1.NodeUnprepareResourceResponse {}
245+ klog .Infof ("unprepared devices for claim '%v'" , claim .UID )
246+ return true , nil
241247}
242248
243249// TODO: implement loop to remove CDI files from the CDI path for claimUIDs
0 commit comments