@@ -37,12 +37,35 @@ class GoodputRecorderTest(parameterized.TestCase):
37
37
expected_rolling_window_size = [1 , 2 , 3 ],
38
38
expected_jax_backend = "proxy" ,
39
39
),
40
+ dict (
41
+ recorder_spec = [
42
+ "name=test-name" ,
43
+ "upload_dir=/test/path" ,
44
+ "upload_interval=15" ,
45
+ "enable_monitoring=true" ,
46
+ ],
47
+ expected_rolling_window_size = [],
48
+ expected_jax_backend = None ,
49
+ expected_enable_monitoring = True ,
50
+ ),
51
+ dict (
52
+ recorder_spec = [
53
+ "name=test-name" ,
54
+ "upload_dir=/test/path" ,
55
+ "upload_interval=15" ,
56
+ "enable_monitoring=false" ,
57
+ ],
58
+ expected_rolling_window_size = [],
59
+ expected_jax_backend = None ,
60
+ expected_enable_monitoring = False ,
61
+ ),
40
62
)
41
63
def test_from_flags (
42
64
self ,
43
65
recorder_spec ,
44
66
expected_rolling_window_size ,
45
67
expected_jax_backend ,
68
+ expected_enable_monitoring = False ,
46
69
):
47
70
"""Tests that flags are correctly parsed into the config."""
48
71
mock_fv = mock .MagicMock (spec = flags .FlagValues )
@@ -56,6 +79,7 @@ def test_from_flags(
56
79
self .assertEqual (15 , recorder .config .upload_interval )
57
80
self .assertEqual (expected_rolling_window_size , recorder .config .rolling_window_size )
58
81
self .assertEqual (expected_jax_backend , recorder .config .jax_backend )
82
+ self .assertEqual (expected_enable_monitoring , recorder .config .enable_monitoring )
59
83
60
84
def test_from_flags_missing_required (self ):
61
85
"""Tests that missing required flags raise an error."""
@@ -188,6 +212,7 @@ def test_maybe_monitor_goodput(self, _, is_pathways_job, mock_jax_backend):
188
212
upload_dir = "/test" ,
189
213
upload_interval = 30 ,
190
214
jax_backend = mock_jax_backend ,
215
+ enable_monitoring = True ,
191
216
)
192
217
recorder = GoodputRecorder (cfg )
193
218
@@ -245,6 +270,7 @@ def test_maybe_monitor_rolling_window(
245
270
upload_interval = 30 ,
246
271
rolling_window_size = rolling_window_size ,
247
272
jax_backend = mock_jax_backend ,
273
+ enable_monitoring = True ,
248
274
)
249
275
recorder = GoodputRecorder (cfg )
250
276
@@ -279,7 +305,7 @@ def test_non_zero_process_index_skips_monitoring(
279
305
): # pylint: disable=unused-argument
280
306
"""Tests that monitoring is skipped on non-zero process indices."""
281
307
cfg = GoodputRecorder .default_config ().set (
282
- name = "test" , upload_dir = "/test" , upload_interval = 30
308
+ name = "test" , upload_dir = "/test" , upload_interval = 30 , enable_monitoring = True
283
309
)
284
310
recorder = GoodputRecorder (cfg )
285
311
@@ -294,6 +320,7 @@ def test_non_zero_process_index_skips_monitoring(
294
320
upload_dir = "/test" ,
295
321
upload_interval = 30 ,
296
322
rolling_window_size = [10 , 20 ],
323
+ enable_monitoring = True ,
297
324
)
298
325
recorder_rolling = GoodputRecorder (cfg_rolling )
299
326
with recorder_rolling ._maybe_monitor_rolling_window_goodput ():
@@ -347,6 +374,7 @@ def test_maybe_monitor_all(
347
374
upload_interval = 30 ,
348
375
rolling_window_size = rolling_window_size ,
349
376
jax_backend = jax_backend ,
377
+ enable_monitoring = True ,
350
378
)
351
379
recorder = GoodputRecorder (cfg )
352
380
@@ -373,3 +401,118 @@ def test_maybe_monitor_all(
373
401
else :
374
402
mock_monitor_instance .start_rolling_window_goodput_uploader .assert_not_called ()
375
403
mock_monitor_instance .stop_rolling_window_goodput_uploader .assert_not_called ()
404
+
405
@mock.patch("jax.process_index", return_value=0)
def test_enable_monitoring_disabled_by_default(self, _):
    """Checks that no GoodputMonitor is built when `enable_monitoring` is left unset."""
    # `enable_monitoring` is deliberately omitted so the config default applies.
    default_cfg = GoodputRecorder.default_config().set(
        name="test-disabled",
        upload_dir="/test",
        upload_interval=30,
        rolling_window_size=[10, 20],
    )
    recorder = GoodputRecorder(default_cfg)

    # The unset flag must resolve to a falsy value.
    self.assertFalse(recorder.config.enable_monitoring)

    monitor_patch = mock.patch("ml_goodput_measurement.monitoring.GoodputMonitor")
    with monitor_patch as mock_monitor_cls:
        # Cumulative, rolling-window, and combined entry points, in that order.
        monitoring_contexts = (
            recorder._maybe_monitor_goodput,
            recorder._maybe_monitor_rolling_window_goodput,
            recorder.maybe_monitor_all,
        )
        for open_context in monitoring_contexts:
            with open_context():
                pass
            # The monitor class must remain uninstantiated after each context exits.
            mock_monitor_cls.assert_not_called()
435
+
436
@mock.patch("jax.process_index", return_value=0)
def test_enable_monitoring_explicitly_disabled(self, _):
    """Checks that no GoodputMonitor is built when `enable_monitoring=False` is set."""
    disabled_cfg = GoodputRecorder.default_config().set(
        name="test-explicitly-disabled",
        upload_dir="/test",
        upload_interval=30,
        enable_monitoring=False,
        rolling_window_size=[10, 20],
    )
    recorder = GoodputRecorder(disabled_cfg)

    monitor_target = "ml_goodput_measurement.monitoring.GoodputMonitor"
    with mock.patch(monitor_target) as mock_monitor_cls:
        # Walk the cumulative, rolling-window, and combined contexts in turn; after each
        # one exits the monitor class must still never have been called.
        for enter_monitoring in (
            recorder._maybe_monitor_goodput,
            recorder._maybe_monitor_rolling_window_goodput,
            recorder.maybe_monitor_all,
        ):
            with enter_monitoring():
                pass
            mock_monitor_cls.assert_not_called()
463
+
464
@mock.patch("jax.process_index", return_value=0)
def test_enable_monitoring_explicitly_enabled(self, _):
    """Checks that the cumulative monitor is started and stopped when monitoring is on."""
    enabled_cfg = GoodputRecorder.default_config().set(
        name="test-enabled",
        upload_dir="/test",
        upload_interval=30,
        enable_monitoring=True,
        rolling_window_size=[10, 20],
    )
    recorder = GoodputRecorder(enabled_cfg)

    with mock.patch("ml_goodput_measurement.monitoring.GoodputMonitor") as mock_monitor_cls:
        # Only the cumulative goodput context is exercised here.
        with recorder._maybe_monitor_goodput():
            pass

        # Exactly one monitor is constructed, and its uploader is started and stopped once.
        self.assertEqual(mock_monitor_cls.call_count, 1)
        monitor = mock_monitor_cls.return_value
        monitor.start_goodput_uploader.assert_called_once()
        monitor.stop_goodput_uploader.assert_called_once()
487
+
488
@mock.patch("jax.process_index", return_value=0)
def test_record_event_works_with_monitoring_disabled(self, _):
    """Checks that events are still recorded while `enable_monitoring=False`."""
    recording_only_cfg = GoodputRecorder.default_config().set(
        name="test-recording-only",
        upload_dir="/test",
        upload_interval=30,
        enable_monitoring=False,
    )
    recorder = GoodputRecorder(recording_only_cfg)

    # Recording path: a JOB event should reach the underlying goodput recorder.
    recorder_target = "ml_goodput_measurement.goodput.GoodputRecorder"
    with mock.patch(recorder_target) as mock_recorder_cls:
        backend = mock_recorder_cls.return_value
        backend.configure_mock(
            record_job_start_time=mock.MagicMock(),
            record_job_end_time=mock.MagicMock(),
        )

        with recorder.record_event(measurement.EventType.JOB):
            pass

        # One recorder instance, one start timestamp, one end timestamp.
        mock_recorder_cls.assert_called_once()
        backend.record_job_start_time.assert_called_once()
        backend.record_job_end_time.assert_called_once()

    # Monitoring path: no GoodputMonitor may be constructed.
    monitor_target = "ml_goodput_measurement.monitoring.GoodputMonitor"
    with mock.patch(monitor_target) as mock_monitor_cls:
        with recorder.maybe_monitor_all():
            pass
        mock_monitor_cls.assert_not_called()
0 commit comments