Skip to content

Commit 62e7b41

Browse files
ampollorenochanglan
authored andcommitted
Add a flag to measurement.py to disable goodput monitoring by default
GitOrigin-RevId: e701de09e8b02bc95d21d2fed9bd445ed2d195f9
1 parent 06c50eb commit 62e7b41

File tree

2 files changed

+159
-4
lines changed

2 files changed

+159
-4
lines changed

axlearn/cloud/gcp/measurement.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from ml_goodput_measurement import goodput
3030
from ml_goodput_measurement import monitoring as goodput_monitoring
3131

32-
from axlearn.cloud.common.utils import parse_kv_flags
32+
from axlearn.cloud.common.utils import parse_kv_flags, to_bool
3333
from axlearn.common import measurement
3434
from axlearn.common.config import REQUIRED, Required, config_class, maybe_set_config
3535

@@ -52,12 +52,17 @@ class Config(measurement.Recorder.Config):
5252
See "How to Monitor Rolling Window Goodput Metrics" in
5353
docs/05-Goodput-Monitoring.md for more details.
5454
jax_backend: Jax backend type to infer Pathways environment.
55+
enable_monitoring: Whether to enable goodput monitoring/uploading.
5556
"""
5657

5758
upload_dir: Required[str] = REQUIRED
5859
upload_interval: Required[int] = REQUIRED
5960
rolling_window_size: Sequence[int] = []
6061
jax_backend: Optional[str] = None
62+
# Disabled by default because of performance degradation. This doesn't disable goodput
63+
# recording.
64+
# TODO (apolloreno): once the performance degradation is fixed, will change default to True
65+
enable_monitoring: bool = False
6166

6267
@classmethod
6368
def from_flags(cls, fv: flags.FlagValues) -> "GoodputRecorder":
@@ -72,6 +77,7 @@ def from_flags(cls, fv: flags.FlagValues) -> "GoodputRecorder":
7277
- rolling_window_size: Comma-separated list of integers representing rolling window
7378
sizes in seconds.
7479
- jax_backend: The type of jax backend.
80+
- enable_monitoring: Boolean to enable/disable goodput monitoring (default: false).
7581
"""
7682
cfg: measurement.Recorder.Config = cls.default_config()
7783
parsed_flags = parse_kv_flags(fv.recorder_spec, delimiter="=")
@@ -83,6 +89,8 @@ def from_flags(cls, fv: flags.FlagValues) -> "GoodputRecorder":
8389
parsed_flags["rolling_window_size"] = [
8490
int(x) for x in parsed_flags["rolling_window_size"].split(",")
8591
]
92+
if "enable_monitoring" in parsed_flags:
93+
parsed_flags["enable_monitoring"] = to_bool(parsed_flags["enable_monitoring"])
8694
return maybe_set_config(cfg, **parsed_flags).instantiate()
8795

8896
def __init__(self, cfg):
@@ -149,7 +157,7 @@ def _maybe_monitor_goodput(self, *args, **kwargs):
149157
Default behavior is to push metrics to Google Cloud Monitoring.
150158
This behavior can be overridden by configuring `goodput_monitoring.GCPOptions`
151159
"""
152-
if jax.process_index() != 0:
160+
if not self.config.enable_monitoring or jax.process_index() != 0:
153161
yield
154162
return
155163
try:
@@ -175,7 +183,11 @@ def _maybe_monitor_goodput(self, *args, **kwargs):
175183
@contextlib.contextmanager
176184
def _maybe_monitor_rolling_window_goodput(self):
177185
"""Monitor rolling window goodput if enabled."""
178-
if not self.config.rolling_window_size or jax.process_index() != 0:
186+
if (
187+
not self.config.enable_monitoring
188+
or not self.config.rolling_window_size
189+
or jax.process_index() != 0
190+
):
179191
yield
180192
return
181193
try:

axlearn/cloud/gcp/measurement_test.py

Lines changed: 144 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,12 +37,35 @@ class GoodputRecorderTest(parameterized.TestCase):
3737
expected_rolling_window_size=[1, 2, 3],
3838
expected_jax_backend="proxy",
3939
),
40+
dict(
41+
recorder_spec=[
42+
"name=test-name",
43+
"upload_dir=/test/path",
44+
"upload_interval=15",
45+
"enable_monitoring=true",
46+
],
47+
expected_rolling_window_size=[],
48+
expected_jax_backend=None,
49+
expected_enable_monitoring=True,
50+
),
51+
dict(
52+
recorder_spec=[
53+
"name=test-name",
54+
"upload_dir=/test/path",
55+
"upload_interval=15",
56+
"enable_monitoring=false",
57+
],
58+
expected_rolling_window_size=[],
59+
expected_jax_backend=None,
60+
expected_enable_monitoring=False,
61+
),
4062
)
4163
def test_from_flags(
4264
self,
4365
recorder_spec,
4466
expected_rolling_window_size,
4567
expected_jax_backend,
68+
expected_enable_monitoring=False,
4669
):
4770
"""Tests that flags are correctly parsed into the config."""
4871
mock_fv = mock.MagicMock(spec=flags.FlagValues)
@@ -56,6 +79,7 @@ def test_from_flags(
5679
self.assertEqual(15, recorder.config.upload_interval)
5780
self.assertEqual(expected_rolling_window_size, recorder.config.rolling_window_size)
5881
self.assertEqual(expected_jax_backend, recorder.config.jax_backend)
82+
self.assertEqual(expected_enable_monitoring, recorder.config.enable_monitoring)
5983

6084
def test_from_flags_missing_required(self):
6185
"""Tests that missing required flags raise an error."""
@@ -188,6 +212,7 @@ def test_maybe_monitor_goodput(self, _, is_pathways_job, mock_jax_backend):
188212
upload_dir="/test",
189213
upload_interval=30,
190214
jax_backend=mock_jax_backend,
215+
enable_monitoring=True,
191216
)
192217
recorder = GoodputRecorder(cfg)
193218

@@ -245,6 +270,7 @@ def test_maybe_monitor_rolling_window(
245270
upload_interval=30,
246271
rolling_window_size=rolling_window_size,
247272
jax_backend=mock_jax_backend,
273+
enable_monitoring=True,
248274
)
249275
recorder = GoodputRecorder(cfg)
250276

@@ -279,7 +305,7 @@ def test_non_zero_process_index_skips_monitoring(
279305
): # pylint: disable=unused-argument
280306
"""Tests that monitoring is skipped on non-zero process indices."""
281307
cfg = GoodputRecorder.default_config().set(
282-
name="test", upload_dir="/test", upload_interval=30
308+
name="test", upload_dir="/test", upload_interval=30, enable_monitoring=True
283309
)
284310
recorder = GoodputRecorder(cfg)
285311

@@ -294,6 +320,7 @@ def test_non_zero_process_index_skips_monitoring(
294320
upload_dir="/test",
295321
upload_interval=30,
296322
rolling_window_size=[10, 20],
323+
enable_monitoring=True,
297324
)
298325
recorder_rolling = GoodputRecorder(cfg_rolling)
299326
with recorder_rolling._maybe_monitor_rolling_window_goodput():
@@ -347,6 +374,7 @@ def test_maybe_monitor_all(
347374
upload_interval=30,
348375
rolling_window_size=rolling_window_size,
349376
jax_backend=jax_backend,
377+
enable_monitoring=True,
350378
)
351379
recorder = GoodputRecorder(cfg)
352380

@@ -373,3 +401,118 @@ def test_maybe_monitor_all(
373401
else:
374402
mock_monitor_instance.start_rolling_window_goodput_uploader.assert_not_called()
375403
mock_monitor_instance.stop_rolling_window_goodput_uploader.assert_not_called()
404+
405+
@mock.patch("jax.process_index", return_value=0)
406+
def test_enable_monitoring_disabled_by_default(self, _):
407+
"""Tests that monitoring is disabled by default (enable_monitoring=False)."""
408+
cfg = GoodputRecorder.default_config().set(
409+
name="test-disabled",
410+
upload_dir="/test",
411+
upload_interval=30,
412+
# enable_monitoring defaults to False
413+
rolling_window_size=[10, 20],
414+
)
415+
recorder = GoodputRecorder(cfg)
416+
417+
# Verify the flag defaults to False
418+
self.assertFalse(recorder.config.enable_monitoring)
419+
420+
with mock.patch("ml_goodput_measurement.monitoring.GoodputMonitor") as mock_monitor_cls:
421+
# Test that cumulative goodput monitoring is skipped
422+
with recorder._maybe_monitor_goodput():
423+
pass
424+
mock_monitor_cls.assert_not_called()
425+
426+
# Test that rolling window monitoring is skipped
427+
with recorder._maybe_monitor_rolling_window_goodput():
428+
pass
429+
mock_monitor_cls.assert_not_called()
430+
431+
# Test that maybe_monitor_all is skipped
432+
with recorder.maybe_monitor_all():
433+
pass
434+
mock_monitor_cls.assert_not_called()
435+
436+
@mock.patch("jax.process_index", return_value=0)
437+
def test_enable_monitoring_explicitly_disabled(self, _):
438+
"""Tests that monitoring is disabled when enable_monitoring=False."""
439+
cfg = GoodputRecorder.default_config().set(
440+
name="test-explicitly-disabled",
441+
upload_dir="/test",
442+
upload_interval=30,
443+
enable_monitoring=False, # Explicitly disabled
444+
rolling_window_size=[10, 20],
445+
)
446+
recorder = GoodputRecorder(cfg)
447+
448+
with mock.patch("ml_goodput_measurement.monitoring.GoodputMonitor") as mock_monitor_cls:
449+
# Test cumulative goodput monitoring is skipped
450+
with recorder._maybe_monitor_goodput():
451+
pass
452+
mock_monitor_cls.assert_not_called()
453+
454+
# Test rolling window monitoring is skipped
455+
with recorder._maybe_monitor_rolling_window_goodput():
456+
pass
457+
mock_monitor_cls.assert_not_called()
458+
459+
# Test maybe_monitor_all is skipped
460+
with recorder.maybe_monitor_all():
461+
pass
462+
mock_monitor_cls.assert_not_called()
463+
464+
@mock.patch("jax.process_index", return_value=0)
465+
def test_enable_monitoring_explicitly_enabled(self, _):
466+
"""Tests that monitoring works when enable_monitoring=True."""
467+
cfg = GoodputRecorder.default_config().set(
468+
name="test-enabled",
469+
upload_dir="/test",
470+
upload_interval=30,
471+
enable_monitoring=True, # Explicitly enabled
472+
rolling_window_size=[10, 20],
473+
)
474+
recorder = GoodputRecorder(cfg)
475+
476+
with mock.patch("ml_goodput_measurement.monitoring.GoodputMonitor") as mock_monitor_cls:
477+
mock_monitor_instance = mock_monitor_cls.return_value
478+
479+
# Test cumulative goodput monitoring works
480+
with recorder._maybe_monitor_goodput():
481+
pass
482+
483+
# Should be called once for cumulative monitoring
484+
self.assertEqual(mock_monitor_cls.call_count, 1)
485+
mock_monitor_instance.start_goodput_uploader.assert_called_once()
486+
mock_monitor_instance.stop_goodput_uploader.assert_called_once()
487+
488+
@mock.patch("jax.process_index", return_value=0)
489+
def test_record_event_works_with_monitoring_disabled(self, _):
490+
"""Tests that record_event still works when monitoring is disabled."""
491+
cfg = GoodputRecorder.default_config().set(
492+
name="test-recording-only",
493+
upload_dir="/test",
494+
upload_interval=30,
495+
enable_monitoring=False, # Monitoring disabled
496+
)
497+
recorder = GoodputRecorder(cfg)
498+
499+
# Verify that goodput recording still works (not monitoring/uploading)
500+
with mock.patch("ml_goodput_measurement.goodput.GoodputRecorder") as mock_recorder_cls:
501+
mock_instance = mock_recorder_cls.return_value
502+
mock_instance.record_job_start_time = mock.MagicMock()
503+
mock_instance.record_job_end_time = mock.MagicMock()
504+
505+
# Record event should work
506+
with recorder.record_event(measurement.EventType.JOB):
507+
pass
508+
509+
# Verify goodput recording happened
510+
mock_recorder_cls.assert_called_once()
511+
mock_instance.record_job_start_time.assert_called_once()
512+
mock_instance.record_job_end_time.assert_called_once()
513+
514+
# Verify no monitoring/uploading happened
515+
with mock.patch("ml_goodput_measurement.monitoring.GoodputMonitor") as mock_monitor_cls:
516+
with recorder.maybe_monitor_all():
517+
pass
518+
mock_monitor_cls.assert_not_called()

0 commit comments

Comments
 (0)