diff --git a/cmd/compute-domain-controller/daemonset.go b/cmd/compute-domain-controller/daemonset.go
index 37aa83196..9b6157f92 100644
--- a/cmd/compute-domain-controller/daemonset.go
+++ b/cmd/compute-domain-controller/daemonset.go
@@ -63,7 +63,6 @@ type DaemonSetManager struct {
 
 	resourceClaimTemplateManager *DaemonSetResourceClaimTemplateManager
 	cleanupManager               *CleanupManager[*appsv1.DaemonSet]
-	podManagers                  map[string]*DaemonSetPodManager
 }
 
 func NewDaemonSetManager(config *ManagerConfig, getComputeDomain GetComputeDomainFunc) *DaemonSetManager {
@@ -92,7 +91,6 @@ func NewDaemonSetManager(config *ManagerConfig, getComputeDomai
 		getComputeDomain: getComputeDomain,
 		factory:          factory,
 		informer:         informer,
-		podManagers:      make(map[string]*DaemonSetPodManager),
 	}
 	m.resourceClaimTemplateManager = NewDaemonSetResourceClaimTemplateManager(config, getComputeDomain)
 	m.cleanupManager = NewCleanupManager[*appsv1.DaemonSet](informer, getComputeDomain, m.cleanup)
@@ -150,9 +148,6 @@ func (m *DaemonSetManager) Start(ctx context.Context) (rerr error) {
 }
 
 func (m *DaemonSetManager) Stop() error {
-	if err := m.removeAllPodManagers(); err != nil {
-		return fmt.Errorf("error removing all Pod managers: %w", err)
-	}
 	if err := m.resourceClaimTemplateManager.Stop(); err != nil {
 		return fmt.Errorf("error stopping ResourceClaimTemplate manager: %w", err)
 	}
@@ -230,16 +225,11 @@ func (m *DaemonSetManager) Delete(ctx context.Context, cdUID string) error {
 	}
 
 	d := ds[0]
-	key := d.Spec.Selector.MatchLabels[computeDomainLabelKey]
 
 	if err := m.resourceClaimTemplateManager.Delete(ctx, cdUID); err != nil {
 		return fmt.Errorf("error deleting ResourceClaimTemplate: %w", err)
 	}
 
-	if err := m.removePodManager(key); err != nil {
-		return fmt.Errorf("error removing Pod manager: %w", err)
-	}
-
 	if d.GetDeletionTimestamp() != nil {
 		return nil
 	}
@@ -335,10 +325,6 @@ func (m *DaemonSetManager) onAddOrUpdate(ctx context.Context, obj any) error {
 		return nil
 	}
 
-	if err := m.addPodManager(ctx, d.Spec.Selector, cd.Spec.NumNodes); err != nil {
-		return fmt.Errorf("error adding Pod manager '%s/%s': %w", d.Namespace, d.Name, err)
-	}
-
 	if int(d.Status.NumberReady) != cd.Spec.NumNodes {
 		return nil
 	}
@@ -352,60 +338,6 @@ func (m *DaemonSetManager) onAddOrUpdate(ctx context.Context, obj any) error {
 	return nil
 }
 
-func (m *DaemonSetManager) addPodManager(ctx context.Context, labelSelector *metav1.LabelSelector, numPods int) error {
-	key := labelSelector.MatchLabels[computeDomainLabelKey]
-
-	if _, exists := m.podManagers[key]; exists {
-		return nil
-	}
-
-	podManager := NewDaemonSetPodManager(m.config, labelSelector, numPods, m.getComputeDomain)
-
-	if err := podManager.Start(ctx); err != nil {
-		return fmt.Errorf("error creating Pod manager: %w", err)
-	}
-
-	m.Lock()
-	m.podManagers[key] = podManager
-	m.Unlock()
-
-	return nil
-}
-
-func (m *DaemonSetManager) removePodManager(key string) error {
-	if _, exists := m.podManagers[key]; !exists {
-		return nil
-	}
-
-	m.Lock()
-	podManager := m.podManagers[key]
-	m.Unlock()
-
-	if err := podManager.Stop(); err != nil {
-		return fmt.Errorf("error stopping Pod manager: %w", err)
-	}
-
-	m.Lock()
-	delete(m.podManagers, key)
-	m.Unlock()
-
-	return nil
-}
-
-func (m *DaemonSetManager) removeAllPodManagers() error {
-	m.Lock()
-	for key, pm := range m.podManagers {
-		m.Unlock()
-		if err := pm.Stop(); err != nil {
-			return fmt.Errorf("error stopping Pod manager: %w", err)
-		}
-		m.Lock()
-		delete(m.podManagers, key)
-	}
-	m.Unlock()
-	return nil
-}
-
 func (m *DaemonSetManager) cleanup(ctx context.Context, cdUID string) error {
 	if err := m.Delete(ctx, cdUID); err != nil {
 		return fmt.Errorf("error deleting DaemonSet: %w", err)
diff --git a/cmd/compute-domain-controller/daemonsetpods.go b/cmd/compute-domain-controller/daemonsetpods.go
deleted file mode 100644
index 73e369081..000000000
--- a/cmd/compute-domain-controller/daemonsetpods.go
+++ /dev/null
@@ -1,203 +0,0 @@
-/*
- * Copyright (c) 2025 NVIDIA CORPORATION. All rights reserved.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package main
-
-import (
-	"context"
-	"fmt"
-	"slices"
-	"sync"
-
-	corev1 "k8s.io/api/core/v1"
-	"k8s.io/apimachinery/pkg/api/errors"
-	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
-	"k8s.io/client-go/informers"
-	corev1listers "k8s.io/client-go/listers/core/v1"
-	"k8s.io/client-go/tools/cache"
-	"k8s.io/klog/v2"
-
-	nvapi "github.com/NVIDIA/k8s-dra-driver-gpu/api/nvidia.com/resource/v1beta1"
-)
-
-const (
-	CliqueIDLabelKey = "nvidia.com/gpu.clique"
-)
-
-type DaemonSetPodManager struct {
-	config        *ManagerConfig
-	waitGroup     sync.WaitGroup
-	cancelContext context.CancelFunc
-
-	factory  informers.SharedInformerFactory
-	informer cache.SharedInformer
-	lister   corev1listers.PodLister
-
-	getComputeDomain   GetComputeDomainFunc
-	computeDomainNodes []*nvapi.ComputeDomainNode
-	numPods            int
-}
-
-func NewDaemonSetPodManager(config *ManagerConfig, labelSelector *metav1.LabelSelector, numPods int, getComputeDomain GetComputeDomainFunc) *DaemonSetPodManager {
-	factory := informers.NewSharedInformerFactoryWithOptions(
-		config.clientsets.Core,
-		informerResyncPeriod,
-		informers.WithNamespace(config.driverNamespace),
-		informers.WithTweakListOptions(func(opts *metav1.ListOptions) {
-			opts.LabelSelector = metav1.FormatLabelSelector(labelSelector)
-		}),
-	)
-
-	informer := factory.Core().V1().Pods().Informer()
-	lister := factory.Core().V1().Pods().Lister()
-
-	m := &DaemonSetPodManager{
-		config:           config,
-		factory:          factory,
-		informer:         informer,
-		lister:           lister,
-		getComputeDomain: getComputeDomain,
-		numPods:          numPods,
-	}
-
-	return m
-}
-
-func (m *DaemonSetPodManager) Start(ctx context.Context) (rerr error) {
-	ctx, cancel := context.WithCancel(ctx)
-	m.cancelContext = cancel
-
-	defer func() {
-		if rerr != nil {
-			if err := m.Stop(); err != nil {
-				klog.Errorf("error stopping DaemonSetPod manager: %v", err)
-			}
-		}
-	}()
-
-	_, err := m.informer.AddEventHandler(cache.ResourceEventHandlerFuncs{
-		AddFunc: func(obj interface{}) {
-			m.config.workQueue.Enqueue(obj, m.onPodAddOrUpdate)
-		},
-		UpdateFunc: func(objOld, objNew any) {
-			m.config.workQueue.Enqueue(objNew, m.onPodAddOrUpdate)
-		},
-	})
-	if err != nil {
-		return fmt.Errorf("error adding event handlers for pod informer: %w", err)
-	}
-
-	m.waitGroup.Add(1)
-	go func() {
-		defer m.waitGroup.Done()
-		m.factory.Start(ctx.Done())
-	}()
-
-	if !cache.WaitForCacheSync(ctx.Done(), m.informer.HasSynced) {
-		return fmt.Errorf("error syncing pod informer: %w", err)
-	}
-
-	return nil
-}
-
-func (m *DaemonSetPodManager) Stop() error {
-	m.cancelContext()
-	m.waitGroup.Wait()
-	return nil
-}
-
-func (m *DaemonSetPodManager) onPodAddOrUpdate(ctx context.Context, obj any) error {
-	p, ok := obj.(*corev1.Pod)
-	if !ok {
-		return fmt.Errorf("failed to cast to Pod")
-	}
-
-	p, err := m.lister.Pods(p.Namespace).Get(p.Name)
-	if err != nil && errors.IsNotFound(err) {
-		return nil
-	}
-	if err != nil {
-		return fmt.Errorf("erroring retreiving Pod: %w", err)
-	}
-
-	klog.Infof("Processing added or updated Pod: %s/%s", p.Namespace, p.Name)
-
-	if p.Spec.NodeName == "" {
-		return fmt.Errorf("pod not yet scheduled: %s/%s", p.Namespace, p.Name)
-	}
-
-	cd, err := m.getComputeDomain(p.Labels[computeDomainLabelKey])
-	if err != nil {
-		return fmt.Errorf("error getting ComputeDomain: %w", err)
-	}
-	if cd == nil {
-		return nil
-	}
-
-	if cd.Status.Status == nvapi.ComputeDomainStatusReady {
-		return nil
-	}
-
-	var nodeNames []string
-	for _, node := range m.computeDomainNodes {
-		nodeNames = append(nodeNames, node.Name)
-	}
-
-	if !slices.Contains(nodeNames, p.Spec.NodeName) {
-		node, err := m.GetComputeDomainNode(ctx, p.Spec.NodeName)
-		if err != nil {
-			return fmt.Errorf("error getting ComputeDomainNode: %w", err)
-		}
-		nodeNames = append(nodeNames, node.Name)
-		m.computeDomainNodes = append(m.computeDomainNodes, node)
-	}
-
-	if len(nodeNames) != m.numPods {
-		return fmt.Errorf("not all pods scheduled yet")
-	}
-
-	newCD := cd.DeepCopy()
-	newCD.Status.Nodes = m.computeDomainNodes
-	newCD.Status.Status = nvapi.ComputeDomainStatusNotReady
-	if _, err = m.config.clientsets.Nvidia.ResourceV1beta1().ComputeDomains(newCD.Namespace).UpdateStatus(ctx, newCD, metav1.UpdateOptions{}); err != nil {
-		return fmt.Errorf("error updating nodes in ComputeDomain status: %w", err)
-	}
-
-	return nil
-}
-
-func (m *DaemonSetPodManager) GetComputeDomainNode(ctx context.Context, nodeName string) (*nvapi.ComputeDomainNode, error) {
-	node, err := m.config.clientsets.Core.CoreV1().Nodes().Get(ctx, nodeName, metav1.GetOptions{})
-	if err != nil {
-		return nil, fmt.Errorf("error getting Node '%s': %w", nodeName, err)
-	}
-
-	var ipAddress string
-	for _, addr := range node.Status.Addresses {
-		if addr.Type == corev1.NodeInternalIP {
-			ipAddress = addr.Address
-			break
-		}
-	}
-
-	n := &nvapi.ComputeDomainNode{
-		Name:      nodeName,
-		IPAddress: ipAddress,
-		CliqueID:  node.Labels[CliqueIDLabelKey],
-	}
-
-	return n, nil
-}
diff --git a/cmd/compute-domain-kubelet-plugin/computedomain.go b/cmd/compute-domain-kubelet-plugin/computedomain.go
index 95ff28add..91fbecc73 100644
--- a/cmd/compute-domain-kubelet-plugin/computedomain.go
+++ b/cmd/compute-domain-kubelet-plugin/computedomain.go
@@ -22,10 +22,12 @@ import (
 	"fmt"
 	"os"
 	"path/filepath"
+	"slices"
 	"sync"
 	"text/template"
 	"time"
 
+	corev1 "k8s.io/api/core/v1"
 	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
 	"k8s.io/client-go/tools/cache"
 	"k8s.io/klog/v2"
@@ -253,6 +255,66 @@ func (m *ComputeDomainManager) AssertComputeDomainReady(ctx context.Context, cdU
 	return nil
 }
 
+func (m *ComputeDomainManager) AddNodeStatusToComputeDomain(ctx context.Context, cdUID string) error {
+	cd, err := m.GetComputeDomain(ctx, cdUID)
+	if err != nil {
+		return fmt.Errorf("error getting ComputeDomain: %w", err)
+	}
+	if cd == nil {
+		return fmt.Errorf("ComputeDomain not found: %s", cdUID)
+	}
+
+	if cd.Status.Status == nvapi.ComputeDomainStatusReady {
+		return nil
+	}
+
+	var nodeNames []string
+	for _, node := range cd.Status.Nodes {
+		nodeNames = append(nodeNames, node.Name)
+	}
+
+	if slices.Contains(nodeNames, m.config.flags.nodeName) {
+		return nil
+	}
+
+	node, err := m.GetComputeDomainNodeStatusInfo(ctx, m.config.flags.nodeName)
+	if err != nil {
+		return fmt.Errorf("error getting ComputeDomain node status info: %w", err)
+	}
+
+	newCD := cd.DeepCopy()
+	newCD.Status.Nodes = append(newCD.Status.Nodes, node)
+	newCD.Status.Status = nvapi.ComputeDomainStatusNotReady
+	if _, err = m.config.clientsets.Nvidia.ResourceV1beta1().ComputeDomains(newCD.Namespace).UpdateStatus(ctx, newCD, metav1.UpdateOptions{}); err != nil {
+		return fmt.Errorf("error updating nodes in ComputeDomain status: %w", err)
+	}
+
+	return nil
+}
+
+func (m *ComputeDomainManager) GetComputeDomainNodeStatusInfo(ctx context.Context, nodeName string) (*nvapi.ComputeDomainNode, error) {
+	node, err := m.config.clientsets.Core.CoreV1().Nodes().Get(ctx, nodeName, metav1.GetOptions{})
+	if err != nil {
+		return nil, fmt.Errorf("error getting Node '%s': %w", nodeName, err)
+	}
+
+	var ipAddress string
+	for _, addr := range node.Status.Addresses {
+		if addr.Type == corev1.NodeInternalIP {
+			ipAddress = addr.Address
+			break
+		}
+	}
+
+	n := &nvapi.ComputeDomainNode{
+		Name:      nodeName,
+		IPAddress: ipAddress,
+		CliqueID:  m.cliqueID,
+	}
+
+	return n, nil
+}
+
 func (m *ComputeDomainManager) GetNodeIPs(ctx context.Context, cdUID string) ([]string, error) {
 	cd, err := m.GetComputeDomain(ctx, cdUID)
 	if err != nil {
@@ -266,6 +328,10 @@ func (m *ComputeDomainManager) GetNodeIPs(ctx context.Context, cdUID string) ([]
 		return nil, fmt.Errorf("error getting status of nodes in ComputeDomain: %w", err)
 	}
 
+	if len(cd.Status.Nodes) != cd.Spec.NumNodes {
+		return nil, fmt.Errorf("not all nodes populated in ComputeDomain status yet")
+	}
+
 	var ips []string
 	for _, node := range cd.Status.Nodes {
 		if m.cliqueID == node.CliqueID {
diff --git a/cmd/compute-domain-kubelet-plugin/device_state.go b/cmd/compute-domain-kubelet-plugin/device_state.go
index 9cbf285e4..aef93a625 100644
--- a/cmd/compute-domain-kubelet-plugin/device_state.go
+++ b/cmd/compute-domain-kubelet-plugin/device_state.go
@@ -32,10 +32,6 @@ import (
 	configapi "github.com/NVIDIA/k8s-dra-driver-gpu/api/nvidia.com/resource/v1beta1"
 )
 
-const (
-	CliqueIDLabelKey = "nvidia.com/gpu.clique"
-)
-
 type OpaqueDeviceConfig struct {
 	Requests []string
 	Config   runtime.Object
@@ -372,6 +368,11 @@ func (s *DeviceState) applyComputeDomainDaemonConfig(ctx context.Context, config
 		return nil, fmt.Errorf("only expected 1 device for requests '%v' in claim '%v'", requests, claim.UID)
 	}
 
+	// Add info about this node to the ComputeDomain status.
+	if err := s.computeDomainManager.AddNodeStatusToComputeDomain(ctx, config.DomainID); err != nil {
+		return nil, fmt.Errorf("error adding node status to ComputeDomain: %w", err)
+	}
+
 	// Declare a device group state object to populate.
 	configState := DeviceConfigState{
 		Type: ComputeDomainDaemonType,