Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions ax/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,7 @@ def configure_generation_strategy(
# Misc options
torch_device: str | None = None,
simplify_parameter_changes: bool = False,
fit_tracking_metrics: bool = True,
) -> None:
"""
Optional method to configure the way candidate parameterizations are generated
Expand Down Expand Up @@ -276,6 +277,16 @@ def configure_generation_strategy(
[Daulton2026bonsai]_ to simplify parameter changes in arms
generated via Bayesian Optimization by pruning irrelevant
parameter changes.
fit_tracking_metrics: Whether to fit a model to the tracking metrics
(metrics that are not part of the optimization config). If ``False``,
only the metrics in the optimization config (objectives and outcome
constraints) are modeled; tracking metrics are still recorded but not
modeled by the Bayesian optimization model. This can speed up model
fitting when there are many tracking metrics kept only for
book-keeping. NOTE: When this is ``False``, any model-dependent
analyses (e.g. cross-validation, sensitivity analysis) will not be
produced for the tracking metrics. Requires an optimization config to
be set and has no effect when ``method="random_search"``.
"""
generation_strategy = self._choose_generation_strategy(
method=method,
Expand All @@ -287,6 +298,7 @@ def configure_generation_strategy(
allow_exceeding_initialization_budget=allow_exceeding_initialization_budget,
torch_device=torch_device,
simplify_parameter_changes=simplify_parameter_changes,
fit_tracking_metrics=fit_tracking_metrics,
)
self.set_generation_strategy(generation_strategy=generation_strategy)

Expand Down Expand Up @@ -1187,6 +1199,7 @@ def _choose_generation_strategy(
# Misc options
torch_device: str | None = None,
simplify_parameter_changes: bool = False,
fit_tracking_metrics: bool = True,
) -> GenerationStrategy:
"""
Choose a generation strategy based on the provided method and options.
Expand Down Expand Up @@ -1214,6 +1227,11 @@ def _choose_generation_strategy(
[Daulton2026bonsai]_ to simplify parameter changes in arms
generated via Bayesian Optimization by pruning irrelevant
parameter changes.
fit_tracking_metrics: Whether to fit a model to the tracking metrics
(metrics that are not part of the optimization config). If ``False``,
only the metrics in the optimization config are modeled and tracking
metrics are recorded but not modeled. NOTE: When this is ``False``,
model-dependent analyses will not be produced for the tracking metrics.


Returns:
Expand All @@ -1234,6 +1252,7 @@ def _choose_generation_strategy(
),
torch_device=torch_device,
simplify_parameter_changes=simplify_parameter_changes,
fit_tracking_metrics=fit_tracking_metrics,
)
)

Expand Down
25 changes: 25 additions & 0 deletions ax/api/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1692,6 +1692,31 @@ def test_configure_generation_strategy_with_simplify(self) -> None:
.generator_kwargs["acquisition_options"]["prune_irrelevant_parameters"]
)

def test_configure_generation_strategy_with_fit_tracking_metrics(self) -> None:
client = Client()
client.configure_experiment(
parameters=[
RangeParameterConfig(name="x1", parameter_type="float", bounds=(-1, 1)),
],
name="foo",
)
client.configure_optimization(objective="foo")

# Default is to fit tracking metrics.
client.configure_generation_strategy()
self.assertTrue(
client._generation_strategy._nodes[2]
.generator_specs[0]
.generator_kwargs["fit_tracking_metrics"]
)
# The option flows through to the MBM node's generator kwargs.
client.configure_generation_strategy(fit_tracking_metrics=False)
self.assertFalse(
client._generation_strategy._nodes[2]
.generator_specs[0]
.generator_kwargs["fit_tracking_metrics"]
)

def test_configure_experiment_with_derived_parameter(self) -> None:
# Setup: Create parameters including a derived parameter

Expand Down
6 changes: 6 additions & 0 deletions ax/api/utils/generation_strategy_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ def _get_mbm_node(
method: str,
torch_device: str | None,
simplify_parameter_changes: bool,
fit_tracking_metrics: bool = True,
model_config: ModelConfig | None = None,
botorch_acqf_class: type[AcquisitionFunction] | None = None,
) -> tuple[GenerationNode, str]:
Expand All @@ -126,6 +127,9 @@ def _get_mbm_node(
torch_device: The torch device to use for the MBM node.
simplify_parameter_changes: Whether to use BONSAI [Daulton2026bonsai]_ to
simplify parameter changes in the MBM node.
fit_tracking_metrics: Whether to fit a model to the tracking metrics. If
``False``, only the metrics in the optimization config are modeled, and
model-dependent analyses will not be produced for the tracking metrics.
model_config: Optional model config to use for the MBM node.
This is only supported when ``method`` is "custom".
botorch_acqf_class: An optional BoTorch ``AcquisitionFunction`` class
Expand Down Expand Up @@ -172,6 +176,7 @@ def _get_mbm_node(
"acquisition_options": {
"prune_irrelevant_parameters": simplify_parameter_changes
},
"fit_tracking_metrics": fit_tracking_metrics,
}
if botorch_acqf_class is not None:
generator_kwargs["botorch_acqf_class"] = botorch_acqf_class
Expand Down Expand Up @@ -245,6 +250,7 @@ def choose_generation_strategy(
method=struct.method,
torch_device=struct.torch_device,
simplify_parameter_changes=struct.simplify_parameter_changes,
fit_tracking_metrics=struct.fit_tracking_metrics,
model_config=model_config,
botorch_acqf_class=botorch_acqf_class,
)
Expand Down
11 changes: 11 additions & 0 deletions ax/api/utils/structs.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,16 @@ class GenerationStrategyDispatchStruct:
simplify_parameter_changes: A boolean indicating whether to simplify parameter
changes of the arms generated by Bayesian optimization by pruning
irrelevant parameters.
fit_tracking_metrics: A boolean indicating whether to fit a model to the
tracking metrics (metrics that are not part of the optimization config).
If ``False``, only the metrics in the optimization config (objectives and
outcome constraints) are modeled, and tracking metrics are still recorded
but not modeled by the Bayesian optimization model. This can speed up model
fitting when there are many tracking metrics that are only kept for
book-keeping. NOTE: When this is ``False``, any model-dependent analyses
(e.g. cross-validation, sensitivity analysis) will not be produced for the
tracking metrics. This option requires an optimization config to be set on
the experiment and has no effect when ``method="random_search"``.
"""

method: Literal["quality", "fast", "random_search", "custom"] = "fast"
Expand All @@ -106,3 +116,4 @@ class GenerationStrategyDispatchStruct:
# Misc options
torch_device: str | None = None
simplify_parameter_changes: bool = False
fit_tracking_metrics: bool = True
27 changes: 27 additions & 0 deletions ax/api/utils/tests/test_generation_strategy_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ def test_choose_gs_fast_with_options(self) -> None:
derelativize_with_raw_status_quo=True
),
"acquisition_options": {"prune_irrelevant_parameters": False},
"fit_tracking_metrics": True,
},
)
self.assertEqual(mbm_node._transition_criteria, [])
Expand Down Expand Up @@ -195,10 +196,36 @@ def test_choose_gs_quality_with_options(self) -> None:
derelativize_with_raw_status_quo=True
),
"acquisition_options": {"prune_irrelevant_parameters": False},
"fit_tracking_metrics": True,
},
)
self.assertEqual(mbm_node._transition_criteria, [])

def test_gs_fit_tracking_metrics(self) -> None:
methods: list[Literal["fast", "quality"]] = ["fast", "quality"]
for fit_tracking_metrics, method in product((True, False), methods):
with self.subTest(method=method, fit_tracking_metrics=fit_tracking_metrics):
struct = GenerationStrategyDispatchStruct(
method=method,
fit_tracking_metrics=fit_tracking_metrics,
)
gs = choose_generation_strategy(struct=struct)
self.assertEqual(gs.name, f"Center+Sobol+MBM:{method}")
mbm_node = gs._nodes[2]
mbm_spec = mbm_node.generator_specs[0]
self.assertEqual(
mbm_spec.generator_kwargs["fit_tracking_metrics"],
fit_tracking_metrics,
)
# The Sobol node should not receive ``fit_tracking_metrics``, and the
# random_search method (Sobol only) should not either.
struct = GenerationStrategyDispatchStruct(
method="random_search", fit_tracking_metrics=False
)
gs = choose_generation_strategy(struct=struct)
sobol_spec = gs._nodes[-1].generator_specs[0]
self.assertNotIn("fit_tracking_metrics", sobol_spec.generator_kwargs)

def test_choose_gs_no_initialization(self) -> None:
struct = GenerationStrategyDispatchStruct(
method="fast", initialization_budget=0
Expand Down
3 changes: 2 additions & 1 deletion ax/benchmark/tests/test_benchmark_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,8 @@ def test_runner(self) -> None:
nullcontext()
if not isinstance(test_function, SurrogateTestFunction)
else patch.object(
runner.test_function._surrogate, # pyrefly: ignore [missing-attribute]
# pyrefly: ignore [missing-attribute]
runner.test_function._surrogate,
"predict",
return_value=({"branin": [4.2]}, None),
)
Expand Down
Loading