
Commit c6e106d

CloudWatch Agent + DCGM integration for nvidia-inference test
1 parent 6e6dd5d commit c6e106d

File tree
  test/cases/nvidia-inference/bert_inference_test.go
  test/cases/nvidia-inference/main_test.go

2 files changed: +65 -42 lines


test/cases/nvidia-inference/bert_inference_test.go

Lines changed: 4 additions & 4 deletions
@@ -38,7 +38,7 @@ func TestBertInference(t *testing.T) {
 		WithLabel("suite", "nvidia").
 		WithLabel("hardware", "gpu").
 		Setup(func(ctx context.Context, t *testing.T, cfg *envconf.Config) context.Context {
-			if *bertInferenceImage == "" {
+			if testConfig.BertInferenceImage == "" {
 				t.Fatalf("[ERROR] bertInferenceImage must be set")
 			}

@@ -47,9 +47,9 @@ func TestBertInference(t *testing.T) {
 			renderedBertInferenceManifest, err = fwext.RenderManifests(
 				bertInferenceManifest,
 				bertInferenceManifestTplVars{
-					BertInferenceImage: *bertInferenceImage,
-					InferenceMode:      *inferenceMode,
-					GPUPerNode:         fmt.Sprintf("%d", *gpuRequested),
+					BertInferenceImage: testConfig.BertInferenceImage,
+					InferenceMode:      testConfig.InferenceMode,
+					GPUPerNode:         fmt.Sprintf("%d", testConfig.GpuRequested),
 				},
 			)
 			if err != nil {
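
For context on where these template variables land: fwext.RenderManifests substitutes them into the embedded BERT manifest. Below is a minimal, standalone sketch of that style of rendering, assuming Go text/template placeholders such as {{.BertInferenceImage}} in the manifest; the actual helper in internal/e2e is not shown in this diff and may differ.

// Standalone, illustrative sketch of the kind of rendering fwext.RenderManifests
// performs for the BERT manifest: a Go text/template is executed against the
// typed template variables, so field names must match the {{ .Field }}
// placeholders in the manifest. The real helper may differ in details.
package main

import (
	"bytes"
	"fmt"
	"log"
	"text/template"
)

type bertInferenceManifestTplVars struct {
	BertInferenceImage string
	InferenceMode      string
	GPUPerNode         string
}

func renderManifest(manifest []byte, vars bertInferenceManifestTplVars) ([]byte, error) {
	tpl, err := template.New("bert-inference").Parse(string(manifest))
	if err != nil {
		return nil, fmt.Errorf("failed to parse manifest template: %w", err)
	}
	var buf bytes.Buffer
	if err := tpl.Execute(&buf, vars); err != nil {
		return nil, fmt.Errorf("failed to render manifest: %w", err)
	}
	return buf.Bytes(), nil
}

func main() {
	// Hypothetical one-line stand-in for the embedded manifest.
	manifest := []byte("image: {{.BertInferenceImage}}  mode: {{.InferenceMode}}  gpus: {{.GPUPerNode}}")
	out, err := renderManifest(manifest, bertInferenceManifestTplVars{
		BertInferenceImage: "example.com/bert:latest",
		InferenceMode:      "throughput",
		GPUPerNode:         "1",
	})
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(string(out))
}

Note that GPUPerNode is rendered as text, which is why the test formats testConfig.GpuRequested with fmt.Sprintf before handing it to the template variables.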

test/cases/nvidia-inference/main_test.go

Lines changed: 61 additions & 38 deletions
@@ -5,95 +5,118 @@ package inference
 import (
 	"context"
 	_ "embed"
-	"flag"
 	"fmt"
 	"log"
 	"os"
+	"os/signal"
 	"slices"
 	"testing"
 	"time"

 	fwext "github.com/aws/aws-k8s-tester/internal/e2e"
+	"github.com/aws/aws-k8s-tester/test/common"
 	"github.com/aws/aws-k8s-tester/test/manifests"
-	appsv1 "k8s.io/api/apps/v1"
 	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
 	"k8s.io/client-go/kubernetes"
 	"sigs.k8s.io/e2e-framework/klient/wait"
 	"sigs.k8s.io/e2e-framework/pkg/env"
 	"sigs.k8s.io/e2e-framework/pkg/envconf"
 )

+type TestConfig struct {
+	common.MetricOps
+	BertInferenceImage string `flag:"bertInferenceImage" desc:"BERT inference container image"`
+	InferenceMode      string `flag:"inferenceMode" desc:"Inference mode for BERT (throughput or latency)"`
+	GpuRequested       int    `flag:"gpuRequested" desc:"Number of GPUs required for inference"`
+	NodeType           string `flag:"nodeType" desc:"Instance type for cluster nodes"`
+}
+
 var (
-	testenv            env.Environment
-	bertInferenceImage *string
-	inferenceMode      *string
-	gpuRequested       *int
+	testenv    env.Environment
+	testConfig TestConfig
 )

 func TestMain(m *testing.M) {
-	bertInferenceImage = flag.String("bertInferenceImage", "", "BERT inference container image")
-	inferenceMode = flag.String("inferenceMode", "throughput", "Inference mode for BERT (throughput or latency)")
-	gpuRequested = flag.Int("gpuRequested", 1, "Number of GPUs required for inference")
+	testConfig = TestConfig{
+		InferenceMode: "throughput",
+		GpuRequested:  1,
+	}

+	_, err := common.ParseFlags(&testConfig)
+	if err != nil {
+		log.Fatalf("failed to parse flags: %v", err)
+	}
 	cfg, err := envconf.NewFromFlags()
 	if err != nil {
-		log.Fatalf("[ERROR] Failed to create test environment: %v", err)
+		log.Fatalf("failed to initialize test environment: %v", err)
 	}
-	testenv = env.NewWithConfig(cfg)

-	devicePluginManifests := [][]byte{
+	ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt)
+	defer cancel()
+	testenv = env.NewWithConfig(cfg).WithContext(ctx)
+
+	// Render CloudWatch Agent manifest with dynamic dimensions
+	renderedCloudWatchAgentManifest, err := manifests.RenderCloudWatchAgentManifest(testConfig.MetricDimensions)
+	if err != nil {
+		log.Printf("Warning: failed to render CloudWatch Agent manifest: %v", err)
+	}
+
+	manifestsList := [][]byte{
 		manifests.NvidiaDevicePluginManifest,
 	}

+	if len(testConfig.MetricDimensions) > 0 {
+		manifestsList = append(manifestsList, manifests.DCGMExporterManifest, renderedCloudWatchAgentManifest)
+	}
+
 	testenv.Setup(
 		func(ctx context.Context, config *envconf.Config) (context.Context, error) {
-			log.Println("[INFO] Applying NVIDIA device plugin.")
-			if applyErr := fwext.ApplyManifests(config.Client().RESTConfig(), devicePluginManifests...); applyErr != nil {
-				return ctx, fmt.Errorf("failed to apply device plugin: %w", applyErr)
+			log.Println("Applying manifests.")
+			err := fwext.ApplyManifests(config.Client().RESTConfig(), manifestsList...)
+			if err != nil {
+				return ctx, fmt.Errorf("failed to apply manifests: %w", err)
 			}
+			log.Println("Successfully applied manifests.")
 			return ctx, nil
 		},
+		common.DeployDaemonSet("nvidia-device-plugin-daemonset", "kube-system"),
 		func(ctx context.Context, config *envconf.Config) (context.Context, error) {
-			ds := &appsv1.DaemonSet{
-				ObjectMeta: metav1.ObjectMeta{
-					Name:      "nvidia-device-plugin-daemonset",
-					Namespace: "kube-system",
-				},
+			if len(testConfig.MetricDimensions) > 0 {
+				if ctx, err := common.DeployDaemonSet("dcgm-exporter", "kube-system")(ctx, config); err != nil {
+					return ctx, err
+				}
+				if ctx, err := common.DeployDaemonSet("cwagent", "amazon-cloudwatch")(ctx, config); err != nil {
+					return ctx, err
+				}
 			}
-			err := wait.For(
-				fwext.NewConditionExtension(config.Client().Resources()).DaemonSetReady(ds),
-				wait.WithTimeout(5*time.Minute),
-			)
-			if err != nil {
-				return ctx, fmt.Errorf("device plugin daemonset not ready: %w", err)
-			}
-			log.Println("[INFO] NVIDIA device plugin is ready.")
 			return ctx, nil
 		},
 		checkGpuCapacity,
 	)

 	testenv.Finish(
 		func(ctx context.Context, config *envconf.Config) (context.Context, error) {
-			log.Println("[INFO] Cleaning up NVIDIA device plugin.")
-			slices.Reverse(devicePluginManifests)
-			if delErr := fwext.DeleteManifests(config.Client().RESTConfig(), devicePluginManifests...); delErr != nil {
-				return ctx, fmt.Errorf("failed to delete device plugin: %w", delErr)
+			log.Println("Deleting NVIDIA device plugin, DCGM Exporter and CloudWatch Agent manifests.")
+			slices.Reverse(manifestsList)
+			err := fwext.DeleteManifests(config.Client().RESTConfig(), manifestsList...)
+			if err != nil {
+				return ctx, fmt.Errorf("failed to delete manifests: %w", err)
 			}
-			log.Println("[INFO] Device plugin cleanup complete.")
+			log.Println("Successfully deleted NVIDIA device plugin, DCGM Exporter and CloudWatch Agent manifests.")
 			return ctx, nil
 		},
 	)

+	log.Println("Starting tests...")
 	exitCode := testenv.Run(m)
-	log.Printf("[INFO] Test environment finished with exit code %d", exitCode)
+	log.Printf("Tests finished with exit code %d", exitCode)
 	os.Exit(exitCode)
 }

 // checkGpuCapacity ensures at least one node has >= the requested number of GPUs,
 // and logs each node's instance type.
 func checkGpuCapacity(ctx context.Context, config *envconf.Config) (context.Context, error) {
-	log.Printf("[INFO] Validating cluster has at least %d GPU(s).", *gpuRequested)
+	log.Printf("[INFO] Validating cluster has at least %d GPU(s).", testConfig.GpuRequested)

 	cs, err := kubernetes.NewForConfig(config.Client().RESTConfig())
 	if err != nil {
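
The common.ParseFlags call above is what consumes the `flag:"..."` and `desc:"..."` struct tags on TestConfig. Purely as a hypothetical sketch of that technique (the real helper in test/common is not shown in this diff and may differ, including in what its first return value is), a reflection-based parser could look like this:

// Hypothetical sketch of struct-tag-driven flag parsing in the spirit of
// common.ParseFlags: walk the struct fields, register one flag per `flag` tag
// (current field values act as defaults), then parse the command line.
package main

import (
	"flag"
	"fmt"
	"log"
	"os"
	"reflect"
)

type config struct {
	BertInferenceImage string `flag:"bertInferenceImage" desc:"BERT inference container image"`
	GpuRequested       int    `flag:"gpuRequested" desc:"Number of GPUs required for inference"`
}

func parseFlags(cfg any) (*flag.FlagSet, error) {
	v := reflect.ValueOf(cfg)
	if v.Kind() != reflect.Pointer || v.Elem().Kind() != reflect.Struct {
		return nil, fmt.Errorf("expected pointer to struct, got %T", cfg)
	}
	fs := flag.NewFlagSet(os.Args[0], flag.ContinueOnError)
	elem := v.Elem()
	for i := 0; i < elem.NumField(); i++ {
		field := elem.Type().Field(i)
		name := field.Tag.Get("flag")
		if name == "" {
			continue // untagged fields (e.g. embedded option structs) are skipped here
		}
		desc := field.Tag.Get("desc")
		switch field.Type.Kind() {
		case reflect.String:
			fs.StringVar(elem.Field(i).Addr().Interface().(*string), name, elem.Field(i).String(), desc)
		case reflect.Int:
			fs.IntVar(elem.Field(i).Addr().Interface().(*int), name, int(elem.Field(i).Int()), desc)
		}
	}
	return fs, fs.Parse(os.Args[1:])
}

func main() {
	cfg := config{GpuRequested: 1} // defaults, mirroring TestConfig in TestMain
	if _, err := parseFlags(&cfg); err != nil {
		log.Fatalf("failed to parse flags: %v", err)
	}
	fmt.Printf("image=%q gpus=%d\n", cfg.BertInferenceImage, cfg.GpuRequested)
}

Under that assumption, running the suite with -bertInferenceImage, -inferenceMode, -gpuRequested, and -nodeType populates TestConfig directly, while untouched fields keep the defaults set in TestMain.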
@@ -110,9 +133,9 @@ func checkGpuCapacity(ctx context.Context, config *envconf.Config) (context.Context, error) {
 		for _, node := range nodes.Items {
 			instanceType := node.Labels["node.kubernetes.io/instance-type"]
 			gpuCap, ok := node.Status.Capacity["nvidia.com/gpu"]
-			if ok && int(gpuCap.Value()) >= *gpuRequested {
+			if ok && int(gpuCap.Value()) >= testConfig.GpuRequested {
 				log.Printf("[INFO] Node %s (type: %s) meets the request of %d GPU(s).",
-					node.Name, instanceType, *gpuRequested)
+					node.Name, instanceType, testConfig.GpuRequested)
 				return true, nil
 			}
 			log.Printf("[INFO] Node %s (type: %s) has no GPU capacity.", node.Name, instanceType)
@@ -122,7 +145,7 @@ func checkGpuCapacity(ctx context.Context, config *envconf.Config) (context.Context, error) {
 	}, wait.WithTimeout(5*time.Minute), wait.WithInterval(10*time.Second))

 	if err != nil {
-		return ctx, fmt.Errorf("no node has >= %d GPU(s)", *gpuRequested)
+		return ctx, fmt.Errorf("no node has >= %d GPU(s)", testConfig.GpuRequested)
 	}

 	log.Println("[INFO] GPU capacity check passed.")
