diff --git a/charts/topograph/Chart.yaml b/charts/topograph/Chart.yaml index 2056189..e4b6061 100644 --- a/charts/topograph/Chart.yaml +++ b/charts/topograph/Chart.yaml @@ -15,7 +15,7 @@ type: application # This is the chart version. This version number should be incremented each time you make changes # to the chart and its templates, including the app version. # Versions are expected to follow Semantic Versioning (https://semver.org/) -version: 0.1.0 +version: 0.2.0 # This is the version number of the application being deployed. This version number should be # incremented each time you make changes to the application. Versions are not expected to @@ -25,8 +25,8 @@ appVersion: "1.16.0" dependencies: - name: node-data-broker - version: 0.1.0 + version: 0.2.0 repository: "file://charts/node-data-broker" - name: node-observer - version: 0.1.0 + version: 0.2.0 repository: "file://charts/node-observer" diff --git a/charts/topograph/charts/node-data-broker/Chart.yaml b/charts/topograph/charts/node-data-broker/Chart.yaml index 7b70ccd..a3414a2 100644 --- a/charts/topograph/charts/node-data-broker/Chart.yaml +++ b/charts/topograph/charts/node-data-broker/Chart.yaml @@ -15,7 +15,7 @@ type: application # This is the chart version. This version number should be incremented each time you make changes # to the chart and its templates, including the app version. # Versions are expected to follow Semantic Versioning (https://semver.org/) -version: 0.1.0 +version: 0.2.0 # This is the version number of the application being deployed. This version number should be # incremented each time you make changes to the application. Versions are not expected to diff --git a/charts/topograph/charts/node-data-broker/templates/daemonset.yaml b/charts/topograph/charts/node-data-broker/templates/daemonset.yaml index 5f8dd25..b5f6929 100644 --- a/charts/topograph/charts/node-data-broker/templates/daemonset.yaml +++ b/charts/topograph/charts/node-data-broker/templates/daemonset.yaml @@ -29,7 +29,7 @@ spec: command: - /usr/local/bin/node-data-broker-initc args: - - -provider={{ .Values.global.provider }} + - -provider={{ .Values.global.provider.name }} - -v={{ .Values.verbosity }} env: - name: NODE_NAME diff --git a/charts/topograph/charts/node-data-broker/templates/rbac.yaml b/charts/topograph/charts/node-data-broker/templates/rbac.yaml index 02b0a2a..912bfa9 100644 --- a/charts/topograph/charts/node-data-broker/templates/rbac.yaml +++ b/charts/topograph/charts/node-data-broker/templates/rbac.yaml @@ -7,7 +7,7 @@ rules: - apiGroups: [""] resources: [nodes] verbs: [get,list,update] -{{- if eq .Values.global.provider "infiniband-k8s" }} +{{- if eq .Values.global.provider.name "infiniband-k8s" }} - apiGroups: [apps] resources: [daemonsets] verbs: [get,list] diff --git a/charts/topograph/charts/node-observer/Chart.yaml b/charts/topograph/charts/node-observer/Chart.yaml index f746cfd..ffd29c1 100644 --- a/charts/topograph/charts/node-observer/Chart.yaml +++ b/charts/topograph/charts/node-observer/Chart.yaml @@ -15,7 +15,7 @@ type: application # This is the chart version. This version number should be incremented each time you make changes # to the chart and its templates, including the app version. # Versions are expected to follow Semantic Versioning (https://semver.org/) -version: 0.1.0 +version: 0.2.0 # This is the version number of the application being deployed. This version number should be # incremented each time you make changes to the application. Versions are not expected to diff --git a/charts/topograph/charts/node-observer/templates/configmap.yml b/charts/topograph/charts/node-observer/templates/configmap.yml index 7875cae..db1ac17 100644 --- a/charts/topograph/charts/node-observer/templates/configmap.yml +++ b/charts/topograph/charts/node-observer/templates/configmap.yml @@ -7,7 +7,9 @@ metadata: data: node-observer-config.yaml: |- generateTopologyUrl: "{{ include "topograph.url" $ }}/v1/generate" - params: - {{- toYaml .Values.global.engineParams | nindent 6 }} + provider: + {{- toYaml .Values.global.provider | nindent 6 }} + engine: + {{- toYaml .Values.global.engine | nindent 6 }} trigger: {{- toYaml .Values.topograph.trigger | nindent 6 }} diff --git a/charts/topograph/templates/configmap.yml b/charts/topograph/templates/configmap.yml index 8209dcb..c68871e 100644 --- a/charts/topograph/templates/configmap.yml +++ b/charts/topograph/templates/configmap.yml @@ -9,8 +9,6 @@ data: http: port: {{ .Values.global.service.port }} ssl: false - provider: {{ .Values.global.provider }} - engine: {{ .Values.global.engine }} requestAggregationDelay: {{ .Values.config.requestAggregationDelay }} {{- if .Values.config.credentialsSecretName }} credentialsPath: /etc/topograph/credentials/credentials.yaml diff --git a/charts/topograph/templates/rbac.yaml b/charts/topograph/templates/rbac.yaml index b688fce..11cab2c 100644 --- a/charts/topograph/templates/rbac.yaml +++ b/charts/topograph/templates/rbac.yaml @@ -15,7 +15,7 @@ rules: - apiGroups: [apps] resources: [daemonsets] verbs: [get,list] -{{- if eq .Values.global.engine "slinky" }} +{{- if eq .Values.global.engine.name "slinky" }} - apiGroups: [""] resources: [configmaps] verbs: [create,get,list,update] diff --git a/charts/topograph/values-slinky-block-example.yaml b/charts/topograph/values-slinky-block-example.yaml index 8d75796..1ff8cf1 100644 --- a/charts/topograph/values-slinky-block-example.yaml +++ b/charts/topograph/values-slinky-block-example.yaml @@ -3,19 +3,25 @@ # Declare variables to be passed into your templates. global: - # provider: "aws", "oci", "gcp", "nebius", "netq", "infiniband-k8s", "dra" or "test" - provider: aws - # engine: "k8s" or "slinky" - engine: slinky - engineParams: - namespace: slurm - podSelector: - matchLabels: - app.kubernetes.io/component: compute - plugin: topology/block - block_sizes: 4 - topologyConfigPath: topology.conf - topologyConfigmapName: slurm-config + provider: + # name: "aws", "oci", "gcp", "nebius", "netq", "infiniband-k8s", "dra" or "test" + name: aws + params: + nodeSelector: + slurmCluster: my-cluster + engine: + name: slinky + params: + namespace: slurm + nodeSelector: + slurmCluster: my-cluster + podSelector: + matchLabels: + app.kubernetes.io/component: compute + plugin: topology/block + block_sizes: 4 + topologyConfigPath: topology.conf + topologyConfigmapName: slurm-config nodeSelector: dedicated: user-workload diff --git a/charts/topograph/values-slinky-partition-example.yaml b/charts/topograph/values-slinky-partition-example.yaml index 3ec4282..601a6a7 100644 --- a/charts/topograph/values-slinky-partition-example.yaml +++ b/charts/topograph/values-slinky-partition-example.yaml @@ -3,29 +3,30 @@ # Declare variables to be passed into your templates. global: - # provider: "aws", "oci", "gcp", "nebius", "netq", "infiniband-k8s", "dra" or "test" - provider: aws - # engine: "k8s" or "slinky" - engine: slinky - engineParams: - namespace: slurm - podSelector: - matchLabels: - app.kubernetes.io/component: compute - topologies: - topo1: - plugin: topology/block - blockSizes: [2,4] - topo2: - plugin: topology/block - blockSizes: [8,16] - topo3: - plugin: topology/tree - topo-default: - plugin: topology/flat - clusterDefault: true - topologyConfigPath: topology.conf - topologyConfigmapName: slurm-config + provider: + # name: "aws", "oci", "gcp", "nebius", "netq", "infiniband-k8s", "dra" or "test" + name: aws + engine: + name: slinky + params: + namespace: slurm + podSelector: + matchLabels: + app.kubernetes.io/component: compute + topologies: + topo1: + plugin: topology/block + blockSizes: [2,4] + topo2: + plugin: topology/block + blockSizes: [8,16] + topo3: + plugin: topology/tree + topo-default: + plugin: topology/flat + clusterDefault: true + topologyConfigPath: topology.conf + topologyConfigmapName: slurm-config nodeSelector: dedicated: user-workload diff --git a/charts/topograph/values-slinky-tree-example.yaml b/charts/topograph/values-slinky-tree-example.yaml index 1c60698..d115741 100644 --- a/charts/topograph/values-slinky-tree-example.yaml +++ b/charts/topograph/values-slinky-tree-example.yaml @@ -3,18 +3,19 @@ # Declare variables to be passed into your templates. global: - # provider: "aws", "oci", "gcp", "nebius", "netq", "infiniband-k8s", "dra" or "test" - provider: aws - # engine: "k8s" or "slinky" - engine: slinky - engineParams: - namespace: slurm - podSelector: - matchLabels: - app.kubernetes.io/component: compute - plugin: topology/tree - topologyConfigPath: topology.conf - topologyConfigmapName: slurm-config + provider: + # name: "aws", "oci", "gcp", "nebius", "netq", "infiniband-k8s", "dra" or "test" + name: aws + engine: + name: slinky + params: + namespace: slurm + podSelector: + matchLabels: + app.kubernetes.io/component: compute + plugin: topology/tree + topologyConfigPath: topology.conf + topologyConfigmapName: slurm-config nodeSelector: dedicated: user-workload diff --git a/charts/topograph/values.yaml b/charts/topograph/values.yaml index 973aaa6..7dc23e8 100644 --- a/charts/topograph/values.yaml +++ b/charts/topograph/values.yaml @@ -3,11 +3,12 @@ # Declare variables to be passed into your templates. global: - # provider: "aws", "oci", "gcp", "nebius", "netq", "infiniband-k8s", "dra" or "test". - provider: test - # engine: "k8s" or "slinky" - engine: k8s - # engineParams: + provider: + # name: "aws", "oci", "gcp", "nebius", "netq", "infiniband-k8s", "dra" or "test". + name: test + engine: + # name: "k8s" or "slinky" + name: k8s service: type: ClusterIP diff --git a/internal/k8s/utils.go b/internal/k8s/utils.go index b8a5c52..4c69570 100644 --- a/internal/k8s/utils.go +++ b/internal/k8s/utils.go @@ -19,8 +19,12 @@ import ( "k8s.io/client-go/tools/remotecommand" ) -func GetNodes(ctx context.Context, client *kubernetes.Clientset) (*corev1.NodeList, error) { - nodes, err := client.CoreV1().Nodes().List(ctx, metav1.ListOptions{}) +func GetNodes(ctx context.Context, client *kubernetes.Clientset, opt *metav1.ListOptions) (*corev1.NodeList, error) { + if opt == nil { + opt = &metav1.ListOptions{} + } + + nodes, err := client.CoreV1().Nodes().List(ctx, *opt) if err != nil { return nil, fmt.Errorf("failed to list node in the cluster: %v", err) } diff --git a/pkg/engines/k8s/engine.go b/pkg/engines/k8s/engine.go index 906677c..880516d 100644 --- a/pkg/engines/k8s/engine.go +++ b/pkg/engines/k8s/engine.go @@ -20,9 +20,12 @@ import ( "context" "net/http" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/labels" "k8s.io/client-go/kubernetes" "k8s.io/client-go/rest" + "github.com/NVIDIA/topograph/internal/config" "github.com/NVIDIA/topograph/internal/httperr" "github.com/NVIDIA/topograph/pkg/engines" "github.com/NVIDIA/topograph/pkg/topology" @@ -33,13 +36,27 @@ const NAME = "k8s" type K8sEngine struct { config *rest.Config client *kubernetes.Clientset + params *Params +} + +type Params struct { + // NodeSelector (optional) specifies nodes participating in the topology + NodeSelector map[string]string `mapstructure:"nodeSelector"` + + // derived fields + nodeListOpt *metav1.ListOptions } func NamedLoader() (string, engines.Loader) { return NAME, Loader } -func Loader(_ context.Context, _ engines.Config) (engines.Engine, *httperr.Error) { +func Loader(_ context.Context, params engines.Config) (engines.Engine, *httperr.Error) { + p, err := getParameters(params) + if err != nil { + return nil, httperr.NewError(http.StatusBadRequest, err.Error()) + } + config, err := rest.InClusterConfig() if err != nil { return nil, httperr.NewError(http.StatusBadGateway, err.Error()) @@ -53,9 +70,25 @@ func Loader(_ context.Context, _ engines.Config) (engines.Engine, *httperr.Error return &K8sEngine{ config: config, client: client, + params: p, }, nil } +func getParameters(params engines.Config) (*Params, error) { + p := &Params{} + if err := config.Decode(params, p); err != nil { + return nil, err + } + + if len(p.NodeSelector) != 0 { + p.nodeListOpt = &metav1.ListOptions{ + LabelSelector: labels.Set(p.NodeSelector).String(), + } + } + + return p, nil +} + func (eng *K8sEngine) GenerateOutput(ctx context.Context, tree *topology.Vertex, params map[string]any) ([]byte, *httperr.Error) { if err := NewTopologyLabeler().ApplyNodeLabels(ctx, tree, eng); err != nil { return nil, httperr.NewError(http.StatusBadGateway, err.Error()) diff --git a/pkg/engines/k8s/engine_test.go b/pkg/engines/k8s/engine_test.go new file mode 100644 index 0000000..2496847 --- /dev/null +++ b/pkg/engines/k8s/engine_test.go @@ -0,0 +1,55 @@ +/* + * Copyright 2025 NVIDIA CORPORATION + * SPDX-License-Identifier: Apache-2.0 + */ + +package k8s + +import ( + "testing" + + "github.com/stretchr/testify/require" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +func TestGetParameters(t *testing.T) { + testCases := []struct { + name string + params map[string]any + ret *Params + err string + }{ + { + name: "Case 1: no params", + params: nil, + ret: &Params{}, + }, + { + name: "Case 2: bad params", + params: map[string]any{"nodeSelector": .1}, + err: "could not decode configuration: 1 error(s) decoding:\n\n* 'nodeSelector' expected a map, got 'float64'", + }, + { + name: "Case 3: valid input", + params: map[string]any{"nodeSelector": map[string]string{"key": "val"}}, + ret: &Params{ + NodeSelector: map[string]string{"key": "val"}, + nodeListOpt: &metav1.ListOptions{ + LabelSelector: "key=val", + }, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + p, err := getParameters(tc.params) + if len(tc.err) != 0 { + require.ErrorContains(t, err, tc.err) + } else { + require.NoError(t, err) + require.Equal(t, tc.ret, p) + } + }) + } +} diff --git a/pkg/engines/k8s/kubernetes.go b/pkg/engines/k8s/kubernetes.go index 83f6510..aff2053 100644 --- a/pkg/engines/k8s/kubernetes.go +++ b/pkg/engines/k8s/kubernetes.go @@ -32,7 +32,7 @@ import ( ) func (eng *K8sEngine) GetComputeInstances(ctx context.Context, _ engines.Environment) ([]topology.ComputeInstances, *httperr.Error) { - nodes, err := k8s.GetNodes(ctx, eng.client) + nodes, err := k8s.GetNodes(ctx, eng.client, eng.params.nodeListOpt) if err != nil { return nil, httperr.NewError(http.StatusBadGateway, err.Error()) } diff --git a/pkg/engines/slinky/engine.go b/pkg/engines/slinky/engine.go index 19494ba..8220d7e 100644 --- a/pkg/engines/slinky/engine.go +++ b/pkg/engines/slinky/engine.go @@ -27,6 +27,7 @@ import ( corev1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/errors" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/labels" "k8s.io/client-go/kubernetes" "k8s.io/client-go/rest" "k8s.io/klog/v2" @@ -50,13 +51,20 @@ type SlinkyEngine struct { type Params struct { slurm.BaseParams `mapstructure:",squash"` - Namespace string `mapstructure:"namespace"` - LabelSelector metav1.LabelSelector `mapstructure:"podSelector"` - ConfigPath string `mapstructure:"topologyConfigPath"` - ConfigMapName string `mapstructure:"topologyConfigmapName"` + // Namespace specifies the namespace where Slinky cluster is deployed + Namespace string `mapstructure:"namespace"` + // PodSelector specifies slurmd pods + PodSelector metav1.LabelSelector `mapstructure:"podSelector"` + // NodeSelector (optional) specifies nodes running slurmd pods + NodeSelector map[string]string `mapstructure:"nodeSelector"` + // ConfigMapName specifies the name of the configmap containing topology config + ConfigMapName string `mapstructure:"topologyConfigmapName"` + // ConfigPath specifies the topology config filename inside the configmap + ConfigPath string `mapstructure:"topologyConfigPath"` // derived fields - podSelector string + podListOpt *metav1.ListOptions + nodeListOpt *metav1.ListOptions } func NamedLoader() (string, engines.Loader) { @@ -92,15 +100,23 @@ func getParameters(params engines.Config) (*Params, error) { return nil, err } - selector, err := metav1.LabelSelectorAsSelector(&p.LabelSelector) + sel, err := metav1.LabelSelectorAsSelector(&p.PodSelector) if err != nil { return nil, err } - p.podSelector = selector.String() + p.podListOpt = &metav1.ListOptions{ + LabelSelector: sel.String(), + } + + if len(p.NodeSelector) != 0 { + p.nodeListOpt = &metav1.ListOptions{ + LabelSelector: labels.Set(p.NodeSelector).String(), + } + } for key, val := range map[string]string{ topology.KeyNamespace: p.Namespace, - topology.KeyPodSelector: p.podSelector, + topology.KeyPodSelector: p.podListOpt.LabelSelector, topology.KeyTopoConfigPath: p.ConfigPath, topology.KeyTopoConfigmapName: p.ConfigMapName, } { @@ -113,22 +129,20 @@ func getParameters(params engines.Config) (*Params, error) { } func (eng *SlinkyEngine) GetComputeInstances(ctx context.Context, _ engines.Environment) ([]topology.ComputeInstances, *httperr.Error) { - nodes, err := k8s.GetNodes(ctx, eng.client) + + nodes, err := k8s.GetNodes(ctx, eng.client, eng.params.nodeListOpt) if err != nil { return nil, httperr.NewError(http.StatusBadGateway, err.Error()) } - opt := metav1.ListOptions{ - LabelSelector: eng.params.podSelector, - } - pods, err := eng.client.CoreV1().Pods(eng.params.Namespace).List(ctx, opt) + pods, err := eng.client.CoreV1().Pods(eng.params.Namespace).List(ctx, *eng.params.podListOpt) if err != nil { return nil, httperr.NewError(http.StatusBadGateway, fmt.Sprintf("failed to list SLURM pods in the cluster: %v", err)) } - klog.V(4).Infof("Found %d pods in %q namespace with selector %q", len(pods.Items), eng.params.Namespace, eng.params.podSelector) + klog.V(4).Infof("Found %d pods in %q namespace with selector %q", len(pods.Items), eng.params.Namespace, eng.params.podListOpt.LabelSelector) // map k8s host name to SLURM host name nodeMap := make(map[string]string) diff --git a/pkg/engines/slinky/engine_test.go b/pkg/engines/slinky/engine_test.go index 12e0198..ac20c34 100644 --- a/pkg/engines/slinky/engine_test.go +++ b/pkg/engines/slinky/engine_test.go @@ -29,17 +29,19 @@ import ( ) func TestGetParameters(t *testing.T) { - selector := map[string]any{ - "matchLabels": map[string]string{"app.kubernetes.io/component": "compute"}, + podSelector := map[string]any{ + "matchLabels": map[string]string{"key": "value"}, } + nodeSelector := map[string]string{"key": "value"} invalidSelector := map[string]any{ "matchExpressions": []metav1.LabelSelectorRequirement{ {Operator: "BAD"}, }, } labelSelector := metav1.LabelSelector{ - MatchLabels: map[string]string{"app.kubernetes.io/component": "compute"}, + MatchLabels: map[string]string{"key": "value"}, } + testCases := []struct { name string params map[string]any @@ -69,7 +71,7 @@ func TestGetParameters(t *testing.T) { err: `could not decode configuration:`, }, { - name: "Case 4: invalid label selector", + name: "Case 4: invalid pod label selector", params: map[string]any{ topology.KeyNamespace: "namespace", topology.KeyPodSelector: invalidSelector, @@ -82,23 +84,24 @@ func TestGetParameters(t *testing.T) { name: "Case 5: minimal valid input", params: map[string]any{ topology.KeyNamespace: "namespace", - topology.KeyPodSelector: selector, + topology.KeyPodSelector: podSelector, topology.KeyTopoConfigPath: "path", topology.KeyTopoConfigmapName: "name", }, ret: &Params{ Namespace: "namespace", - LabelSelector: labelSelector, + PodSelector: labelSelector, ConfigPath: "path", ConfigMapName: "name", - podSelector: "app.kubernetes.io/component=compute", + podListOpt: &metav1.ListOptions{LabelSelector: "key=value"}, }, }, { name: "Case 6: complete valid input", params: map[string]any{ topology.KeyNamespace: "namespace", - topology.KeyPodSelector: selector, + topology.KeyPodSelector: podSelector, + topology.KeyNodeSelector: nodeSelector, topology.KeyPlugin: topology.TopologyBlock, topology.KeyBlockSizes: "16", topology.KeyTopoConfigPath: "path", @@ -110,10 +113,12 @@ func TestGetParameters(t *testing.T) { BlockSizes: "16", }, Namespace: "namespace", - LabelSelector: labelSelector, + PodSelector: labelSelector, + NodeSelector: nodeSelector, ConfigPath: "path", ConfigMapName: "name", - podSelector: "app.kubernetes.io/component=compute", + podListOpt: &metav1.ListOptions{LabelSelector: "key=value"}, + nodeListOpt: &metav1.ListOptions{LabelSelector: "key=value"}, }, }, } @@ -219,7 +224,7 @@ func TestConfigMapAnnotationsAndMetadata(t *testing.T) { name: "minimal params, no plugin/block", params: &Params{ Namespace: "test-namespace", - LabelSelector: labelSelector, + PodSelector: labelSelector, ConfigPath: "topology.conf", ConfigMapName: "slurm-topology", }, @@ -231,7 +236,7 @@ func TestConfigMapAnnotationsAndMetadata(t *testing.T) { BaseParams: slurm.BaseParams{ Plugin: topology.TopologyBlock, }, - LabelSelector: labelSelector, + PodSelector: labelSelector, ConfigPath: "topology.conf", ConfigMapName: "slurm-topology", }, @@ -244,7 +249,7 @@ func TestConfigMapAnnotationsAndMetadata(t *testing.T) { BlockSizes: "8,16,32", }, Namespace: "test-namespace", - LabelSelector: labelSelector, + PodSelector: labelSelector, ConfigPath: "topology.conf", ConfigMapName: "slurm-topology", }, @@ -258,7 +263,7 @@ func TestConfigMapAnnotationsAndMetadata(t *testing.T) { BlockSizes: "8,16,32", }, Namespace: "test-namespace", - LabelSelector: labelSelector, + PodSelector: labelSelector, ConfigPath: "topology.conf", ConfigMapName: "slurm-topology", }, diff --git a/pkg/node_observer/config.go b/pkg/node_observer/config.go index 727119f..a099936 100644 --- a/pkg/node_observer/config.go +++ b/pkg/node_observer/config.go @@ -22,14 +22,16 @@ import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "sigs.k8s.io/yaml" + + "github.com/NVIDIA/topograph/pkg/topology" ) type Config struct { - GenerateTopologyURL string `yaml:"generateTopologyUrl"` - Trigger Trigger `yaml:"trigger"` - Provider string `yaml:"provider"` - Engine string `yaml:"engine"` - Params map[string]any `yaml:"params"` + GenerateTopologyURL string `yaml:"generateTopologyUrl"` + Trigger Trigger `yaml:"trigger"` + Provider topology.Provider `yaml:"provider"` + Engine topology.Engine `yaml:"engine"` + Params map[string]any `yaml:"params"` } type Trigger struct { diff --git a/pkg/node_observer/controller.go b/pkg/node_observer/controller.go index cf0a7ea..f3646ad 100644 --- a/pkg/node_observer/controller.go +++ b/pkg/node_observer/controller.go @@ -38,7 +38,7 @@ type Controller struct { func NewController(ctx context.Context, client kubernetes.Interface, cfg *Config) (*Controller, error) { var f httpreq.RequestFunc = func() (*http.Request, error) { - payload := topology.NewRequest(cfg.Provider, nil, cfg.Engine, cfg.Params) + payload := topology.NewRequest(cfg.Provider, cfg.Engine) data, err := json.Marshal(payload) if err != nil { return nil, fmt.Errorf("failed to parse payload: %v", err) diff --git a/pkg/node_observer/controller_test.go b/pkg/node_observer/controller_test.go index d927e48..8977c16 100644 --- a/pkg/node_observer/controller_test.go +++ b/pkg/node_observer/controller_test.go @@ -11,14 +11,16 @@ import ( "github.com/stretchr/testify/require" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + "github.com/NVIDIA/topograph/pkg/topology" ) func TestNewController(t *testing.T) { ctx := context.TODO() cfg := &Config{ - Provider: "test", - Engine: "test", + Provider: topology.Provider{Name: "test"}, + Engine: topology.Engine{Name: "test"}, } testCases := []struct { diff --git a/pkg/providers/dra/provider.go b/pkg/providers/dra/provider.go index bcdfb1c..4942fa4 100644 --- a/pkg/providers/dra/provider.go +++ b/pkg/providers/dra/provider.go @@ -9,9 +9,12 @@ import ( "context" "net/http" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/labels" "k8s.io/client-go/kubernetes" "k8s.io/client-go/rest" + "github.com/NVIDIA/topograph/internal/config" "github.com/NVIDIA/topograph/internal/httperr" "github.com/NVIDIA/topograph/internal/k8s" "github.com/NVIDIA/topograph/pkg/providers" @@ -23,6 +26,15 @@ const NAME = "dra" type Provider struct { config *rest.Config client *kubernetes.Clientset + params *Params +} + +type Params struct { + // NodeSelector (optional) specifies nodes participating in the topology + NodeSelector map[string]string `mapstructure:"nodeSelector"` + + // derived fields + nodeListOpt *metav1.ListOptions } func NamedLoader() (string, providers.Loader) { @@ -30,6 +42,11 @@ func NamedLoader() (string, providers.Loader) { } func Loader(ctx context.Context, config providers.Config) (providers.Provider, *httperr.Error) { + p, err := getParameters(config.Params) + if err != nil { + return nil, httperr.NewError(http.StatusBadRequest, err.Error()) + } + cfg, err := rest.InClusterConfig() if err != nil { return nil, httperr.NewError(http.StatusBadGateway, err.Error()) @@ -43,16 +60,32 @@ func Loader(ctx context.Context, config providers.Config) (providers.Provider, * return &Provider{ config: cfg, client: client, + params: p, }, nil } +func getParameters(params map[string]any) (*Params, error) { + p := &Params{} + if err := config.Decode(params, p); err != nil { + return nil, err + } + + if len(p.NodeSelector) != 0 { + p.nodeListOpt = &metav1.ListOptions{ + LabelSelector: labels.Set(p.NodeSelector).String(), + } + } + + return p, nil +} + func (p *Provider) GenerateTopologyConfig(ctx context.Context, _ *int, instances []topology.ComputeInstances) (*topology.Vertex, *httperr.Error) { regIndices := make(map[string]int) // map[region : index] for i, ci := range instances { regIndices[ci.Region] = i } - nodes, err := k8s.GetNodes(ctx, p.client) + nodes, err := k8s.GetNodes(ctx, p.client, p.params.nodeListOpt) if err != nil { return nil, httperr.NewError(http.StatusBadGateway, err.Error()) } diff --git a/pkg/providers/dra/provider_test.go b/pkg/providers/dra/provider_test.go new file mode 100644 index 0000000..1c8937a --- /dev/null +++ b/pkg/providers/dra/provider_test.go @@ -0,0 +1,55 @@ +/* + * Copyright 2025 NVIDIA CORPORATION + * SPDX-License-Identifier: Apache-2.0 + */ + +package dra + +import ( + "testing" + + "github.com/stretchr/testify/require" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +func TestGetParameters(t *testing.T) { + testCases := []struct { + name string + params map[string]any + ret *Params + err string + }{ + { + name: "Case 1: no params", + params: nil, + ret: &Params{}, + }, + { + name: "Case 2: bad params", + params: map[string]any{"nodeSelector": .1}, + err: "could not decode configuration: 1 error(s) decoding:\n\n* 'nodeSelector' expected a map, got 'float64'", + }, + { + name: "Case 3: valid input", + params: map[string]any{"nodeSelector": map[string]string{"key": "val"}}, + ret: &Params{ + NodeSelector: map[string]string{"key": "val"}, + nodeListOpt: &metav1.ListOptions{ + LabelSelector: "key=val", + }, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + p, err := getParameters(tc.params) + if len(tc.err) != 0 { + require.ErrorContains(t, err, tc.err) + } else { + require.NoError(t, err) + require.Equal(t, tc.ret, p) + } + }) + } +} diff --git a/pkg/providers/infiniband/provider_k8s.go b/pkg/providers/infiniband/provider_k8s.go index 022bef0..1717bb9 100644 --- a/pkg/providers/infiniband/provider_k8s.go +++ b/pkg/providers/infiniband/provider_k8s.go @@ -10,9 +10,12 @@ import ( "fmt" "net/http" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/labels" "k8s.io/client-go/kubernetes" "k8s.io/client-go/rest" + "github.com/NVIDIA/topograph/internal/config" "github.com/NVIDIA/topograph/internal/httperr" "github.com/NVIDIA/topograph/internal/k8s" "github.com/NVIDIA/topograph/pkg/providers" @@ -24,13 +27,27 @@ const NAME_K8S = "infiniband-k8s" type ProviderK8S struct { config *rest.Config client *kubernetes.Clientset + params *Params +} + +type Params struct { + // NodeSelector (optional) specifies nodes participating in the topology + NodeSelector map[string]string `mapstructure:"nodeSelector"` + + // derived fields + nodeListOpt *metav1.ListOptions } func NamedLoaderK8S() (string, providers.Loader) { return NAME_K8S, LoaderK8S } -func LoaderK8S(ctx context.Context, _ providers.Config) (providers.Provider, *httperr.Error) { +func LoaderK8S(ctx context.Context, config providers.Config) (providers.Provider, *httperr.Error) { + p, err := getParameters(config.Params) + if err != nil { + return nil, httperr.NewError(http.StatusBadRequest, err.Error()) + } + cfg, err := rest.InClusterConfig() if err != nil { return nil, httperr.NewError(http.StatusBadGateway, err.Error()) @@ -44,15 +61,31 @@ func LoaderK8S(ctx context.Context, _ providers.Config) (providers.Provider, *ht return &ProviderK8S{ config: cfg, client: client, + params: p, }, nil } +func getParameters(params map[string]any) (*Params, error) { + p := &Params{} + if err := config.Decode(params, p); err != nil { + return nil, err + } + + if len(p.NodeSelector) != 0 { + p.nodeListOpt = &metav1.ListOptions{ + LabelSelector: labels.Set(p.NodeSelector).String(), + } + } + + return p, nil +} + func (p *ProviderK8S) GenerateTopologyConfig(ctx context.Context, _ *int, cis []topology.ComputeInstances) (*topology.Vertex, *httperr.Error) { if len(cis) > 1 { return nil, httperr.NewError(http.StatusBadRequest, "on-prem does not support multi-region topology requests") } - nodes, err := k8s.GetNodes(ctx, p.client) + nodes, err := k8s.GetNodes(ctx, p.client, p.params.nodeListOpt) if err != nil { return nil, httperr.NewError(http.StatusBadGateway, err.Error()) } diff --git a/pkg/providers/infiniband/provider_k8s_test.go b/pkg/providers/infiniband/provider_k8s_test.go new file mode 100644 index 0000000..af7a1ce --- /dev/null +++ b/pkg/providers/infiniband/provider_k8s_test.go @@ -0,0 +1,55 @@ +/* + * Copyright 2025 NVIDIA CORPORATION + * SPDX-License-Identifier: Apache-2.0 + */ + +package infiniband + +import ( + "testing" + + "github.com/stretchr/testify/require" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +func TestGetParameters(t *testing.T) { + testCases := []struct { + name string + params map[string]any + ret *Params + err string + }{ + { + name: "Case 1: no params", + params: nil, + ret: &Params{}, + }, + { + name: "Case 2: bad params", + params: map[string]any{"nodeSelector": .1}, + err: "could not decode configuration: 1 error(s) decoding:\n\n* 'nodeSelector' expected a map, got 'float64'", + }, + { + name: "Case 3: valid input", + params: map[string]any{"nodeSelector": map[string]string{"key": "val"}}, + ret: &Params{ + NodeSelector: map[string]string{"key": "val"}, + nodeListOpt: &metav1.ListOptions{ + LabelSelector: "key=val", + }, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + p, err := getParameters(tc.params) + if len(tc.err) != 0 { + require.ErrorContains(t, err, tc.err) + } else { + require.NoError(t, err) + require.Equal(t, tc.ret, p) + } + }) + } +} diff --git a/pkg/server/trailing_delay_queue.go b/pkg/server/trailing_delay_queue.go index 8a3f6c0..14c9b5c 100644 --- a/pkg/server/trailing_delay_queue.go +++ b/pkg/server/trailing_delay_queue.go @@ -31,10 +31,10 @@ import ( const RequestHistorySize = 100 -type HandleFunc func(interface{}) (interface{}, *httperr.Error) +type HandleFunc func(any) (any, *httperr.Error) type Completion struct { - Ret interface{} + Ret any Status int Message string } @@ -45,10 +45,10 @@ type TrailingDelayQueue struct { handle HandleFunc delay time.Duration shutdown chan struct{} - item interface{} // current item to be processed, if not nil - lastTime time.Time // last submit time - uid string // unique item processing ID - store *lru.Cache // map uid:process result + item any // current item to be processed, if not nil + lastTime time.Time // last submit time + uid string // unique item processing ID + store *lru.Cache // map uid:process result } func NewTrailingDelayQueue(handle HandleFunc, delay time.Duration) *TrailingDelayQueue { @@ -73,7 +73,7 @@ func (q *TrailingDelayQueue) run() { klog.V(4).Infof("queue shutdown") return case <-q.ticker.C: - var item interface{} + var item any var uid string q.mutex.Lock() if time.Since(q.lastTime) > q.delay && q.item != nil { @@ -104,7 +104,7 @@ func (q *TrailingDelayQueue) run() { } } -func (q *TrailingDelayQueue) Submit(item interface{}) string { +func (q *TrailingDelayQueue) Submit(item any) string { q.mutex.Lock() defer q.mutex.Unlock() diff --git a/pkg/topology/request.go b/pkg/topology/request.go index 7234ea1..1ad33e3 100644 --- a/pkg/topology/request.go +++ b/pkg/topology/request.go @@ -45,16 +45,10 @@ type ComputeInstances struct { Instances map[string]string `json:"instances"` // : map } -func NewRequest(prv string, creds map[string]string, eng string, params map[string]any) *Request { +func NewRequest(prv Provider, eng Engine) *Request { return &Request{ - Provider: Provider{ - Name: prv, - Creds: creds, - }, - Engine: Engine{ - Name: eng, - Params: params, - }, + Provider: prv, + Engine: eng, } } diff --git a/pkg/topology/request_test.go b/pkg/topology/request_test.go index 6260b56..d75b613 100644 --- a/pkg/topology/request_test.go +++ b/pkg/topology/request_test.go @@ -22,6 +22,13 @@ import ( "github.com/stretchr/testify/require" ) +func TestNewRequest(t *testing.T) { + provider := Provider{Name: "test", Params: map[string]any{"key": 1}} + engine := Engine{} + + require.Equal(t, &Request{Provider: provider}, NewRequest(provider, engine)) +} + func TestPayload(t *testing.T) { testCases := []struct { name string diff --git a/pkg/topology/topology.go b/pkg/topology/topology.go index a9c15a2..1c8075d 100644 --- a/pkg/topology/topology.go +++ b/pkg/topology/topology.go @@ -27,6 +27,7 @@ const ( KeyUID = "uid" KeyNamespace = "namespace" KeyPodSelector = "podSelector" + KeyNodeSelector = "nodeSelector" KeyTopoConfigPath = "topologyConfigPath" KeyTopoConfigmapName = "topologyConfigmapName" KeyBlockSizes = "block_sizes"