Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 67 additions & 14 deletions cmd/nvidia-dra-controller/imex.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package main
import (
"context"
"fmt"
"sync"
"time"

v1 "k8s.io/api/core/v1"
Expand All @@ -40,25 +41,30 @@ const (
ImexChannelLimit = 128
)

type ImexManager struct {
waitGroup sync.WaitGroup
clientset kubernetes.Interface
}

type DriverResources resourceslice.DriverResources

func StartIMEXManager(ctx context.Context, config *Config) error {
func StartIMEXManager(ctx context.Context, config *Config) (*ImexManager, error) {
// Build a client set config
csconfig, err := config.flags.kubeClientConfig.NewClientSetConfig()
if err != nil {
return fmt.Errorf("error creating client set config: %w", err)
return nil, fmt.Errorf("error creating client set config: %w", err)
}

// Create a new clientset
clientset, err := kubernetes.NewForConfig(csconfig)
if err != nil {
return fmt.Errorf("error creating dynamic client: %w", err)
return nil, fmt.Errorf("error creating dynamic client: %w", err)
}

// Fetch the current Pod object
pod, err := clientset.CoreV1().Pods(config.flags.namespace).Get(ctx, config.flags.podName, metav1.GetOptions{})
if err != nil {
return fmt.Errorf("error fetching pod: %w", err)
return nil, fmt.Errorf("error fetching pod: %w", err)
}

// Set the owner of the ResourceSlices we will create
Expand All @@ -69,32 +75,39 @@ func StartIMEXManager(ctx context.Context, config *Config) error {
UID: pod.UID,
}

// Create the manager itself
m := &ImexManager{
clientset: clientset,
}

// Stream added/removed IMEX domains from nodes over time
klog.Info("Start streaming IMEX domains from nodes...")
addedDomainsCh, removedDomainsCh, err := streamImexDomains(ctx, clientset)
addedDomainsCh, removedDomainsCh, err := m.streamImexDomains(ctx)
if err != nil {
return fmt.Errorf("error streaming IMEX domains: %w", err)
return nil, fmt.Errorf("error streaming IMEX domains: %w", err)
}

// Add/Remove resource slices from IMEX domains as they come and go
klog.Info("Start publishing IMEX channels to ResourceSlices...")
err = manageResourceSlices(ctx, clientset, owner, addedDomainsCh, removedDomainsCh)
err = m.manageResourceSlices(ctx, owner, addedDomainsCh, removedDomainsCh)
if err != nil {
return fmt.Errorf("error managing resource slices: %w", err)
return nil, fmt.Errorf("error managing resource slices: %w", err)
}

return nil
return m, nil
}

// manageResourceSlices reacts to added and removed IMEX domains and triggers the creation / removal of resource slices accordingly.
func manageResourceSlices(ctx context.Context, clientset kubernetes.Interface, owner resourceslice.Owner, addedDomainsCh <-chan string, removedDomainsCh <-chan string) error {
func (m *ImexManager) manageResourceSlices(ctx context.Context, owner resourceslice.Owner, addedDomainsCh <-chan string, removedDomainsCh <-chan string) error {
driverResources := resourceslice.DriverResources{}
controller, err := resourceslice.StartController(ctx, clientset, DriverName, owner, &driverResources)
controller, err := resourceslice.StartController(ctx, m.clientset, DriverName, owner, &driverResources)
if err != nil {
return fmt.Errorf("error starting resource slice controller: %w", err)
}

m.waitGroup.Add(1)
go func() {
defer m.waitGroup.Done()
for {
select {
case addedDomain := <-addedDomainsCh:
Expand All @@ -118,6 +131,21 @@ func manageResourceSlices(ctx context.Context, clientset kubernetes.Interface, o
return nil
}

// Stop stops a running ImexManager.
func (m *ImexManager) Stop() error {
if m == nil {
return nil
}

m.waitGroup.Wait()
klog.Info("Cleaning up all resourceSlices")
if err := m.cleanupResourceSlices(); err != nil {
return fmt.Errorf("error cleaning up resource slices: %w", err)
}

return nil
}

// DeepCopy will perform a deep copy of the provided DriverResources.
func (d DriverResources) DeepCopy() resourceslice.DriverResources {
driverResources := resourceslice.DriverResources{
Expand All @@ -130,7 +158,7 @@ func (d DriverResources) DeepCopy() resourceslice.DriverResources {
}

// streamImexDomains returns two channels that streams imexDomans that are added and removed from nodes over time.
func streamImexDomains(ctx context.Context, clientset kubernetes.Interface) (<-chan string, <-chan string, error) {
func (m *ImexManager) streamImexDomains(ctx context.Context) (<-chan string, <-chan string, error) {
// Create channels to stream IMEX domain ids that are added / removed
addedDomainCh := make(chan string)
removedDomainCh := make(chan string)
Expand All @@ -147,7 +175,7 @@ func streamImexDomains(ctx context.Context, clientset kubernetes.Interface) (<-c

// Create a shared informer factory for nodes
informerFactory := informers.NewSharedInformerFactoryWithOptions(
clientset,
m.clientset,
time.Minute*10, // Resync period
informers.WithTweakListOptions(func(options *metav1.ListOptions) {
options.LabelSelector = labelSelector
Expand Down Expand Up @@ -206,7 +234,11 @@ func streamImexDomains(ctx context.Context, clientset kubernetes.Interface) (<-c
}

// Start the informer and wait for it to sync
go informerFactory.Start(ctx.Done())
m.waitGroup.Add(1)
go func() {
defer m.waitGroup.Done()
informerFactory.Start(ctx.Done())
}()

// Wait for the informer caches to sync
if !cache.WaitForCacheSync(ctx.Done(), nodeInformer.HasSynced) {
Expand Down Expand Up @@ -259,3 +291,24 @@ func generateImexChannelPool(imexDomain string, numChannels int) resourceslice.P

return pool
}

// cleanupResourceSlices removes all resource slices created by the IMEX manager.
func (m *ImexManager) cleanupResourceSlices() error {
// Delete all resource slices created by the IMEX manager
ops := metav1.ListOptions{
FieldSelector: fmt.Sprintf("%s=%s", resourceapi.ResourceSliceSelectorDriver, DriverName),
}
l, err := m.clientset.ResourceV1alpha3().ResourceSlices().List(context.Background(), ops)
if err != nil {
return fmt.Errorf("error listing resource slices: %w", err)
}

for _, rs := range l.Items {
err := m.clientset.ResourceV1alpha3().ResourceSlices().Delete(context.Background(), rs.Name, metav1.DeleteOptions{})
if err != nil {
return fmt.Errorf("error deleting resource slice %s: %w", rs.Name, err)
}
}

return nil
}
20 changes: 17 additions & 3 deletions cmd/nvidia-dra-controller/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,15 @@
package main

import (
"context"
"fmt"
"net"
"net/http"
"net/http/pprof"
"os"
"os/signal"
"path"
"syscall"

"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp"
Expand Down Expand Up @@ -132,7 +135,6 @@ func newApp() *cli.App {
return flags.loggingConfig.Apply()
},
Action: func(c *cli.Context) error {
ctx := c.Context
mux := http.NewServeMux()
flags.deviceClasses = sets.New[string](c.StringSlice("device-classes")...)

Expand All @@ -154,14 +156,26 @@ func newApp() *cli.App {
}
}

sigs := make(chan os.Signal, 1)
signal.Notify(sigs, syscall.SIGTERM, syscall.SIGINT)

var imexManager *ImexManager
ctx, cancel := context.WithCancel(c.Context)
defer func() {
cancel()
if err := imexManager.Stop(); err != nil {
klog.Errorf("Error stopping IMEX manager: %v", err)
}
}()

if flags.deviceClasses.Has(ImexChannelType) {
err = StartIMEXManager(ctx, config)
imexManager, err = StartIMEXManager(ctx, config)
if err != nil {
return fmt.Errorf("start IMEX manager: %w", err)
}
}

<-ctx.Done()
<-sigs

return nil
},
Expand Down