@@ -5,95 +5,118 @@ package inference
55import (
66 "context"
77 _ "embed"
8- "flag"
98 "fmt"
109 "log"
1110 "os"
11+ "os/signal"
1212 "slices"
1313 "testing"
1414 "time"
1515
1616 fwext "github.com/aws/aws-k8s-tester/internal/e2e"
17+ "github.com/aws/aws-k8s-tester/test/common"
1718 "github.com/aws/aws-k8s-tester/test/manifests"
18- appsv1 "k8s.io/api/apps/v1"
1919 metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
2020 "k8s.io/client-go/kubernetes"
2121 "sigs.k8s.io/e2e-framework/klient/wait"
2222 "sigs.k8s.io/e2e-framework/pkg/env"
2323 "sigs.k8s.io/e2e-framework/pkg/envconf"
2424)
2525
26+ type TestConfig struct {
27+ common.MetricOps
28+ BertInferenceImage string `flag:"bertInferenceImage" desc:"BERT inference container image"`
29+ InferenceMode string `flag:"inferenceMode" desc:"Inference mode for BERT (throughput or latency)"`
30+ GpuRequested int `flag:"gpuRequested" desc:"Number of GPUs required for inference"`
31+ NodeType string `flag:"nodeType" desc:"Instance type for cluster nodes"`
32+ }
33+
2634var (
27- testenv env.Environment
28- bertInferenceImage * string
29- inferenceMode * string
30- gpuRequested * int
35+ testenv env.Environment
36+ testConfig TestConfig
3137)
3238
3339func TestMain (m * testing.M ) {
34- bertInferenceImage = flag .String ("bertInferenceImage" , "" , "BERT inference container image" )
35- inferenceMode = flag .String ("inferenceMode" , "throughput" , "Inference mode for BERT (throughput or latency)" )
36- gpuRequested = flag .Int ("gpuRequested" , 1 , "Number of GPUs required for inference" )
40+ testConfig = TestConfig {
41+ InferenceMode : "throughput" ,
42+ GpuRequested : 1 ,
43+ }
3744
45+ _ , err := common .ParseFlags (& testConfig )
46+ if err != nil {
47+ log .Fatalf ("failed to parse flags: %v" , err )
48+ }
3849 cfg , err := envconf .NewFromFlags ()
3950 if err != nil {
40- log .Fatalf ("[ERROR] Failed to create test environment: %v" , err )
51+ log .Fatalf ("failed to initialize test environment: %v" , err )
4152 }
42- testenv = env .NewWithConfig (cfg )
4353
44- devicePluginManifests := [][]byte {
54+ ctx , cancel := signal .NotifyContext (context .Background (), os .Interrupt )
55+ defer cancel ()
56+ testenv = env .NewWithConfig (cfg ).WithContext (ctx )
57+
58+ // Render CloudWatch Agent manifest with dynamic dimensions
59+ renderedCloudWatchAgentManifest , err := manifests .RenderCloudWatchAgentManifest (testConfig .MetricDimensions )
60+ if err != nil {
61+ log .Printf ("Warning: failed to render CloudWatch Agent manifest: %v" , err )
62+ }
63+
64+ manifestsList := [][]byte {
4565 manifests .NvidiaDevicePluginManifest ,
4666 }
4767
68+ if len (testConfig .MetricDimensions ) > 0 {
69+ manifestsList = append (manifestsList , manifests .DCGMExporterManifest , renderedCloudWatchAgentManifest )
70+ }
71+
4872 testenv .Setup (
4973 func (ctx context.Context , config * envconf.Config ) (context.Context , error ) {
50- log .Println ("[INFO] Applying NVIDIA device plugin." )
51- if applyErr := fwext .ApplyManifests (config .Client ().RESTConfig (), devicePluginManifests ... ); applyErr != nil {
52- return ctx , fmt .Errorf ("failed to apply device plugin: %w" , applyErr )
74+ log .Println ("Applying manifests." )
75+ err := fwext .ApplyManifests (config .Client ().RESTConfig (), manifestsList ... )
76+ if err != nil {
77+ return ctx , fmt .Errorf ("failed to apply manifests: %w" , err )
5378 }
79+ log .Println ("Successfully applied manifests." )
5480 return ctx , nil
5581 },
82+ common .DeployDaemonSet ("nvidia-device-plugin-daemonset" , "kube-system" ),
5683 func (ctx context.Context , config * envconf.Config ) (context.Context , error ) {
57- ds := & appsv1.DaemonSet {
58- ObjectMeta : metav1.ObjectMeta {
59- Name : "nvidia-device-plugin-daemonset" ,
60- Namespace : "kube-system" ,
61- },
84+ if len (testConfig .MetricDimensions ) > 0 {
85+ if ctx , err := common .DeployDaemonSet ("dcgm-exporter" , "kube-system" )(ctx , config ); err != nil {
86+ return ctx , err
87+ }
88+ if ctx , err := common .DeployDaemonSet ("cwagent" , "amazon-cloudwatch" )(ctx , config ); err != nil {
89+ return ctx , err
90+ }
6291 }
63- err := wait .For (
64- fwext .NewConditionExtension (config .Client ().Resources ()).DaemonSetReady (ds ),
65- wait .WithTimeout (5 * time .Minute ),
66- )
67- if err != nil {
68- return ctx , fmt .Errorf ("device plugin daemonset not ready: %w" , err )
69- }
70- log .Println ("[INFO] NVIDIA device plugin is ready." )
7192 return ctx , nil
7293 },
7394 checkGpuCapacity ,
7495 )
7596
7697 testenv .Finish (
7798 func (ctx context.Context , config * envconf.Config ) (context.Context , error ) {
78- log .Println ("[INFO] Cleaning up NVIDIA device plugin." )
79- slices .Reverse (devicePluginManifests )
80- if delErr := fwext .DeleteManifests (config .Client ().RESTConfig (), devicePluginManifests ... ); delErr != nil {
81- return ctx , fmt .Errorf ("failed to delete device plugin: %w" , delErr )
99+ log .Println ("Deleting NVIDIA device plugin, DCGM Exporter and CloudWatch Agent manifests." )
100+ slices .Reverse (manifestsList )
101+ err := fwext .DeleteManifests (config .Client ().RESTConfig (), manifestsList ... )
102+ if err != nil {
103+ return ctx , fmt .Errorf ("failed to delete manifests: %w" , err )
82104 }
83- log .Println ("[INFO] Device plugin cleanup complete ." )
105+ log .Println ("Successfully deleted NVIDIA device plugin, DCGM Exporter and CloudWatch Agent manifests ." )
84106 return ctx , nil
85107 },
86108 )
87109
110+ log .Println ("Starting tests..." )
88111 exitCode := testenv .Run (m )
89- log .Printf ("[INFO] Test environment finished with exit code %d" , exitCode )
112+ log .Printf ("Tests finished with exit code %d" , exitCode )
90113 os .Exit (exitCode )
91114}
92115
93116// checkGpuCapacity ensures at least one node has >= the requested number of GPUs,
94117// and logs each node's instance type.
95118func checkGpuCapacity (ctx context.Context , config * envconf.Config ) (context.Context , error ) {
96- log .Printf ("[INFO] Validating cluster has at least %d GPU(s)." , * gpuRequested )
119+ log .Printf ("[INFO] Validating cluster has at least %d GPU(s)." , testConfig . GpuRequested )
97120
98121 cs , err := kubernetes .NewForConfig (config .Client ().RESTConfig ())
99122 if err != nil {
@@ -110,9 +133,9 @@ func checkGpuCapacity(ctx context.Context, config *envconf.Config) (context.Cont
110133 for _ , node := range nodes .Items {
111134 instanceType := node .Labels ["node.kubernetes.io/instance-type" ]
112135 gpuCap , ok := node .Status .Capacity ["nvidia.com/gpu" ]
113- if ok && int (gpuCap .Value ()) >= * gpuRequested {
136+ if ok && int (gpuCap .Value ()) >= testConfig . GpuRequested {
114137 log .Printf ("[INFO] Node %s (type: %s) meets the request of %d GPU(s)." ,
115- node .Name , instanceType , * gpuRequested )
138+ node .Name , instanceType , testConfig . GpuRequested )
116139 return true , nil
117140 }
118141 log .Printf ("[INFO] Node %s (type: %s) has no GPU capacity." , node .Name , instanceType )
@@ -122,7 +145,7 @@ func checkGpuCapacity(ctx context.Context, config *envconf.Config) (context.Cont
122145 }, wait .WithTimeout (5 * time .Minute ), wait .WithInterval (10 * time .Second ))
123146
124147 if err != nil {
125- return ctx , fmt .Errorf ("no node has >= %d GPU(s)" , * gpuRequested )
148+ return ctx , fmt .Errorf ("no node has >= %d GPU(s)" , testConfig . GpuRequested )
126149 }
127150
128151 log .Println ("[INFO] GPU capacity check passed." )
0 commit comments