Skip to content

Commit 725f19b

Browse files
authored
retrive SLURM node regions from the worker nodes (#173)
Signed-off-by: Dmitry Shmulevich <[email protected]>
1 parent fc56471 commit 725f19b

File tree

19 files changed

+201
-85
lines changed

19 files changed

+201
-85
lines changed

pkg/engines/slurm/slurm.go

Lines changed: 39 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,12 @@ type TopologyNodeFinder struct {
7777
}
7878

7979
type instanceMapper interface {
80+
// Instances2NodeMap receives a list of SLURM node names and returns a map of
81+
// the service provider assigned compute instance IDs to the node names
8082
Instances2NodeMap(context.Context, []string) (map[string]string, error)
81-
GetComputeInstancesRegion(context.Context) (string, error)
83+
// GetInstancesRegions receives a list of SLURM node names and returns a map
84+
// of node names to their deployed regions
85+
GetInstancesRegions(context.Context, []string) (map[string]string, error)
8286
}
8387

8488
var (
@@ -114,22 +118,50 @@ func (eng *SlurmEngine) GetComputeInstances(ctx context.Context, environment eng
114118
return nil, err
115119
}
116120

121+
if len(nodes) == 0 {
122+
return nil, nil
123+
}
124+
117125
i2n, err := instanceMapper.Instances2NodeMap(ctx, nodes)
118126
if err != nil {
119127
return nil, err
120128
}
121129
klog.V(4).Infof("Detected instance map: %v", i2n)
122130

123-
region, err := instanceMapper.GetComputeInstancesRegion(ctx)
131+
nodeRegions, err := instanceMapper.GetInstancesRegions(ctx, nodes)
124132
if err != nil {
125133
return nil, err
126134
}
127-
klog.V(4).Infof("Detected region: %s", region)
128135

129-
return []topology.ComputeInstances{{
130-
Region: region,
131-
Instances: i2n,
132-
}}, nil
136+
return aggregateComputeInstances(i2n, nodeRegions), nil
137+
}
138+
139+
func aggregateComputeInstances(i2n, nodeRegions map[string]string) []topology.ComputeInstances {
140+
// regions maps region name to the corresponding index in "cis"
141+
regions := make(map[string]int)
142+
cis := []topology.ComputeInstances{}
143+
144+
for instance, node := range i2n {
145+
region, ok := nodeRegions[node]
146+
if !ok {
147+
klog.Warningf("Failed to find region for node %s", node)
148+
continue
149+
}
150+
indx, ok := regions[region]
151+
if !ok {
152+
indx = len(regions)
153+
regions[region] = indx
154+
cis = append(cis, topology.ComputeInstances{
155+
Region: region,
156+
Instances: map[string]string{instance: node},
157+
})
158+
} else {
159+
cis[indx].Instances[instance] = node
160+
}
161+
}
162+
klog.V(4).Infof("Detected regions: %v", regions)
163+
164+
return cis
133165
}
134166

135167
func GetNodeList(ctx context.Context) ([]string, error) {

pkg/engines/slurm/slurm_test.go

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,61 @@ import (
2727
"github.com/NVIDIA/topograph/pkg/translate"
2828
)
2929

30+
func TestAggregateComputeInstances(t *testing.T) {
31+
testCases := []struct {
32+
name string
33+
i2n map[string]string
34+
regions map[string]string
35+
cis []topology.ComputeInstances
36+
}{
37+
{
38+
name: "Case 1: no data",
39+
cis: []topology.ComputeInstances{},
40+
},
41+
{
42+
name: "Case 2: full match",
43+
i2n: map[string]string{"i1": "n1", "i2": "n2", "i3": "n3", "i4": "n4", "i5": "n5"},
44+
regions: map[string]string{"n1": "r1", "n2": "r1", "n3": "r2", "n4": "r2", "n5": "r3"},
45+
cis: []topology.ComputeInstances{
46+
{
47+
Region: "r1",
48+
Instances: map[string]string{"i1": "n1", "i2": "n2"},
49+
},
50+
{
51+
Region: "r2",
52+
Instances: map[string]string{"i3": "n3", "i4": "n4"},
53+
},
54+
{
55+
Region: "r3",
56+
Instances: map[string]string{"i5": "n5"},
57+
},
58+
},
59+
},
60+
{
61+
name: "Case 3: partial match",
62+
i2n: map[string]string{"i1": "n1", "i2": "n2", "i3": "n3", "i4": "n4", "i5": "n5"},
63+
regions: map[string]string{"n1": "r1", "n3": "r2", "n4": "r2"},
64+
cis: []topology.ComputeInstances{
65+
{
66+
Region: "r1",
67+
Instances: map[string]string{"i1": "n1"},
68+
},
69+
{
70+
Region: "r2",
71+
Instances: map[string]string{"i3": "n3", "i4": "n4"},
72+
},
73+
},
74+
},
75+
}
76+
77+
for _, tc := range testCases {
78+
t.Run(tc.name, func(t *testing.T) {
79+
cis := aggregateComputeInstances(tc.i2n, tc.regions)
80+
require.ElementsMatch(t, tc.cis, cis)
81+
})
82+
}
83+
}
84+
3085
func TestParseFakeNodes(t *testing.T) {
3186
testCases := []struct {
3287
name string

pkg/providers/aws/imds.go

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ import (
2020
"context"
2121
"fmt"
2222
"net/http"
23-
"strings"
2423
"time"
2524

2625
"github.com/NVIDIA/topograph/internal/exec"
@@ -59,13 +58,13 @@ func imdsCmd(url string) string {
5958
IMDSTokenHeader, IMDSTokenURL, makeHeader(IMDSHeaderKey, "$TOKEN"), url)
6059
}
6160

62-
func getRegion(ctx context.Context) (string, error) {
63-
stdout, err := exec.Exec(ctx, "sh", []string{"-c", imdsCmd(IMDSRegionURL)}, nil)
61+
func getRegions(ctx context.Context, nodes []string) (map[string]string, error) {
62+
stdout, err := exec.Pdsh(ctx, imdsCmd(IMDSRegionURL), nodes)
6463
if err != nil {
65-
return "", err
64+
return nil, err
6665
}
6766

68-
return strings.TrimSpace(stdout.String()), nil
67+
return providers.ParsePdshOutput(stdout, true)
6968
}
7069

7170
func GetNodeAnnotations(ctx context.Context) (map[string]string, error) {

pkg/providers/aws/provider.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ func (p *Provider) Instances2NodeMap(ctx context.Context, nodes []string) (map[s
179179
return instanceToNodeMap(ctx, nodes)
180180
}
181181

182-
// GetComputeInstancesRegion implements slurm.instanceMapper
183-
func (p *Provider) GetComputeInstancesRegion(ctx context.Context) (string, error) {
184-
return getRegion(ctx)
182+
// GetInstancesRegions implements slurm.instanceMapper
183+
func (p *Provider) GetInstancesRegions(ctx context.Context, nodes []string) (map[string]string, error) {
184+
return getRegions(ctx, nodes)
185185
}

pkg/providers/cw/provider.go

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,9 +78,13 @@ func (p *Provider) Instances2NodeMap(ctx context.Context, nodes []string) (map[s
7878
return i2n, nil
7979
}
8080

81-
// GetComputeInstancesRegion implements slurm.instanceMapper
82-
func (p *Provider) GetComputeInstancesRegion(_ context.Context) (string, error) {
83-
return "", nil
81+
// GetInstancesRegions implements slurm.instanceMapper
82+
func (p *Provider) GetInstancesRegions(ctx context.Context, nodes []string) (map[string]string, error) {
83+
res := make(map[string]string)
84+
for _, node := range nodes {
85+
res[node] = ""
86+
}
87+
return res, nil
8488
}
8589

8690
// GetNodeRegion implements k8s.k8sNodeInfo

pkg/providers/gcp/imds.go

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -45,17 +45,22 @@ func instanceToNodeMap(ctx context.Context, nodes []string) (map[string]string,
4545
return providers.ParseInstanceOutput(stdout)
4646
}
4747

48-
func getRegion(ctx context.Context) (string, error) {
49-
stdout, err := exec.Exec(ctx, "curl", imdsCurlParams(IMDSRegionURL), nil)
48+
func getRegions(ctx context.Context, nodes []string) (map[string]string, error) {
49+
stdout, err := exec.Pdsh(ctx, pdshCmd(IMDSRegionURL), nodes)
5050
if err != nil {
51-
return "", err
51+
return nil, err
5252
}
5353

54-
return convertRegion(strings.TrimSpace(stdout.String())), nil
55-
}
54+
res, err := providers.ParsePdshOutput(stdout, true)
55+
if err != nil {
56+
return nil, err
57+
}
58+
59+
for key, val := range res {
60+
res[key] = convertRegion(strings.TrimSpace(val))
61+
}
5662

57-
func imdsCurlParams(url string) []string {
58-
return []string{"-s", "-H", IMDSHeader, url}
63+
return res, nil
5964
}
6065

6166
func pdshCmd(url string) string {

pkg/providers/gcp/imds_test.go

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,6 @@ import (
2323
"github.com/stretchr/testify/require"
2424
)
2525

26-
func TestImdsCurlParams(t *testing.T) {
27-
expected := []string{"-s", "-H", IMDSHeader, IMDSInstanceURL}
28-
require.Equal(t, expected, imdsCurlParams(IMDSInstanceURL))
29-
}
30-
3126
func TestPdshCmd(t *testing.T) {
3227
expected := fmt.Sprintf(`echo $(curl -s -H "Metadata-Flavor: Google" %s)`, IMDSInstanceURL)
3328
require.Equal(t, expected, pdshCmd(IMDSInstanceURL))

pkg/providers/gcp/provider.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ func (p *Provider) Instances2NodeMap(ctx context.Context, nodes []string) (map[s
114114
return instanceToNodeMap(ctx, nodes)
115115
}
116116

117-
// GetComputeInstancesRegion implements slurm.instanceMapper
118-
func (p *Provider) GetComputeInstancesRegion(ctx context.Context) (string, error) {
119-
return getRegion(ctx)
117+
// GetInstancesRegions implements slurm.instanceMapper
118+
func (p *Provider) GetInstancesRegions(ctx context.Context, nodes []string) (map[string]string, error) {
119+
return getRegions(ctx, nodes)
120120
}

pkg/providers/infiniband/provider_bm.go

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,11 @@ func (p *ProviderBM) Instances2NodeMap(ctx context.Context, nodes []string) (map
6464
return i2n, nil
6565
}
6666

67-
// GetComputeInstancesRegion implements slurm.instanceMapper
68-
func (p *ProviderBM) GetComputeInstancesRegion(_ context.Context) (string, error) {
69-
return "local", nil
67+
// GetInstancesRegions implements slurm.instanceMapper
68+
func (p *ProviderBM) GetInstancesRegions(ctx context.Context, nodes []string) (map[string]string, error) {
69+
res := make(map[string]string)
70+
for _, node := range nodes {
71+
res[node] = "local"
72+
}
73+
return res, nil
7074
}

pkg/providers/nebius/imds.go

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,17 @@ func instanceToNodeMap(ctx context.Context, nodes []string) (map[string]string,
3030
return providers.ParseInstanceOutput(stdout)
3131
}
3232

33-
func getParentID() (string, error) {
34-
return providers.ReadFile(IMDSParentID)
33+
func getRegions(ctx context.Context, nodes []string) (map[string]string, error) {
34+
stdout, err := exec.Pdsh(ctx, "cat "+IMDSRegionPath, nodes)
35+
if err != nil {
36+
return nil, err
37+
}
38+
39+
return providers.ParsePdshOutput(stdout, true)
3540
}
3641

37-
func getRegion() (string, error) {
38-
return providers.ReadFile(IMDSRegionPath)
42+
func getParentID() (string, error) {
43+
return providers.ReadFile(IMDSParentID)
3944
}
4045

4146
func GetNodeAnnotations(ctx context.Context) (map[string]string, error) {

0 commit comments

Comments
 (0)