diff --git a/src/optimagic/optimization/optimize_result.py b/src/optimagic/optimization/optimize_result.py index f2895cf53..6b30b2a3c 100644 --- a/src/optimagic/optimization/optimize_result.py +++ b/src/optimagic/optimization/optimize_result.py @@ -4,6 +4,7 @@ import numpy as np import pandas as pd +from numpy.typing import NDArray from optimagic import deprecations from optimagic.logging.logger import LogReader @@ -210,7 +211,7 @@ class MultistartInfo: start_parameters: List of start parameters for each optimization. local_optima: List of optimization results. exploration_sample: List of parameters used for exploration. - exploration_results: List of function values corresponding to exploration. + exploration_results: Array of function values corresponding to exploration. n_optimizations: Number of local optimizations that were run. """ @@ -218,7 +219,7 @@ class MultistartInfo: start_parameters: list[PyTree] local_optima: list[OptimizeResult] exploration_sample: list[PyTree] - exploration_results: list[float] + exploration_results: NDArray[np.float64] def __getitem__(self, key): deprecations.throw_dict_access_future_warning(key, obj_name=type(self).__name__) diff --git a/src/optimagic/visualization/history_plots.py b/src/optimagic/visualization/history_plots.py index cb64a4e94..ea9dda743 100644 --- a/src/optimagic/visualization/history_plots.py +++ b/src/optimagic/visualization/history_plots.py @@ -1,7 +1,8 @@ import inspect import itertools +from dataclasses import dataclass from pathlib import Path -from typing import Any +from typing import Any, Iterator import numpy as np import plotly.graph_objects as go @@ -13,37 +14,34 @@ from optimagic.optimization.history import History from optimagic.optimization.optimize_result import OptimizeResult from optimagic.parameters.tree_registry import get_registry -from optimagic.typing import Direction +from optimagic.typing import Direction, IterationHistory def criterion_plot( - results, - names=None, - max_evaluations=None, - template=PLOTLY_TEMPLATE, - palette=PLOTLY_PALETTE, - stack_multistart=False, - monotone=False, - show_exploration=False, -): + results: list[OptimizeResult | str | Path] | dict[str, OptimizeResult | str | Path], + names: list[str] | str | None = None, + max_evaluations: int | None = None, + template: str = PLOTLY_TEMPLATE, + palette: list[str] | str = PLOTLY_PALETTE, + stack_multistart: bool = False, + monotone: bool = False, + show_exploration: bool = False, +) -> go.Figure: """Plot the criterion history of an optimization. Args: - results (Union[List, Dict][Union[OptimizeResult, pathlib.Path, str]): A (list or - dict of) optimization results with collected history. If dict, then the - key is used as the name in a legend. - names (Union[List[str], str]): Names corresponding to res or entries in res. - max_evaluations (int): Clip the criterion history after that many entries. - template (str): The template for the figure. Default is "plotly_white". - palette (Union[List[str], str]): The coloring palette for traces. Default is - "qualitative.Plotly". - stack_multistart (bool): Whether to combine multistart histories into a single - history. Default is False. - monotone (bool): If True, the criterion plot becomes monotone in the sense - that only that at each iteration the current best criterion value is - displayed. Default is False. - show_exploration (bool): If True, exploration samples of a multistart - optimization are visualized. Default is False. 
+ results: A (list or dict of) optimization results with collected history. + If dict, then the key is used as the name in a legend. + names: Names corresponding to res or entries in res. + max_evaluations: Clip the criterion history after that many entries. + template: The template for the figure. Default is "plotly_white". + palette: The coloring palette for traces. Default is "qualitative.Set2". + stack_multistart: Whether to combine multistart histories into a single history. + Default is False. + monotone: If True, the criterion plot becomes monotone in the sense that at each + iteration the current best criterion value is displayed. Default is False. + show_exploration: If True, exploration samples of a multistart optimization are + visualized. Default is False. Returns: plotly.graph_objs._figure.Figure: The figure. @@ -51,122 +49,46 @@ def criterion_plot( """ # ================================================================================== # Process inputs - # ================================================================================== - - results = _harmonize_inputs_to_dict(results, names) if not isinstance(palette, list): palette = [palette] - palette = itertools.cycle(palette) - - fun_or_monotone_fun = "monotone_fun" if monotone else "fun" - - # ================================================================================== - # Extract plotting data from results objects / data base - # ================================================================================== - - data = [] - for name, res in results.items(): - if isinstance(res, OptimizeResult): - _data = _extract_plotting_data_from_results_object( - res, stack_multistart, show_exploration, plot_name="criterion_plot" - ) - elif isinstance(res, (str, Path)): - _data = _extract_plotting_data_from_database( - res, stack_multistart, show_exploration - ) - else: - msg = "results must be (or contain) an OptimizeResult or a path to a log" - f"file, but is type {type(res)}." 
- raise TypeError(msg) + palette_cycle = itertools.cycle(palette) - _data["name"] = name - data.append(_data) + dict_of_optimize_results_or_paths = _harmonize_inputs_to_dict(results, names) # ================================================================================== - # Create figure - # ================================================================================== - - fig = go.Figure() + # Extract backend-agnostic plotting data from results - plot_multistart = ( - len(data) == 1 and data[0]["is_multistart"] and not stack_multistart + list_of_optimize_data = _retrieve_optimization_data( + results=dict_of_optimize_results_or_paths, + stack_multistart=stack_multistart, + show_exploration=show_exploration, ) - # ================================================================================== - # Plot multistart paths - - if plot_multistart: - scatter_kws = { - "connectgaps": True, - "showlegend": False, - } - - for i, local_history in enumerate(data[0]["local_histories"]): - history = getattr(local_history, fun_or_monotone_fun) - - if max_evaluations is not None and len(history) > max_evaluations: - history = history[:max_evaluations] - - trace = go.Scatter( - x=np.arange(len(history)), - y=history, - mode="lines", - name=str(i), - line_color="#bab0ac", - **scatter_kws, - ) - fig.add_trace(trace) + plot_data = _extract_criterion_plot_data( + data=list_of_optimize_data, + max_evaluations=max_evaluations, + palette_cycle=palette_cycle, + stack_multistart=stack_multistart, + monotone=monotone, + ) # ================================================================================== - # Plot main optimization objects + # Generate the plotly figure - for _data in data: - if stack_multistart and _data["stacked_local_histories"] is not None: - _history = _data["stacked_local_histories"] - else: - _history = _data["history"] - - history = getattr(_history, fun_or_monotone_fun) - - if max_evaluations is not None and len(history) > max_evaluations: - history = history[:max_evaluations] - - scatter_kws = { - "connectgaps": True, - "showlegend": not plot_multistart, - } - - _color = next(palette) - if not isinstance(_color, str): - msg = "highlight_palette needs to be a string or list of strings, but its " - f"entry is of type {type(_color)}." - raise TypeError(msg) - - line_kws = { - "color": _color, - } - - trace = go.Scatter( - x=np.arange(len(history)), - y=history, - mode="lines", - name="best result" if plot_multistart else _data["name"], - line=line_kws, - **scatter_kws, - ) - fig.add_trace(trace) - - fig.update_layout( + plot_config = PlotConfig( template=template, - xaxis_title_text="No. 
of criterion evaluations", - yaxis_title_text="Criterion value", legend={"yanchor": "top", "xanchor": "right", "y": 0.95, "x": 0.95}, ) + + fig = _plotly_criterion_plot(plot_data, plot_config) return fig -def _harmonize_inputs_to_dict(results, names): +def _harmonize_inputs_to_dict( + results: list[OptimizeResult | str | Path] | dict[str, OptimizeResult | str | Path], + names: list[str] | str | None, +) -> dict[str, OptimizeResult | str | Path]: """Convert all valid inputs for results and names to dict[str, OptimizeResult].""" # convert scalar case to list case if not isinstance(names, list) and names is not None: @@ -187,7 +109,8 @@ def _harmonize_inputs_to_dict(results, names): # unlabeled iterable of results else: - names = range(len(results)) if names is None else names + if names is None: + names = list(str(i) for i in range(len(results))) results_dict = dict(zip(names, results, strict=False)) # convert keys to strings @@ -234,7 +157,7 @@ def params_plot( # ================================================================================== if isinstance(result, OptimizeResult): - data = _extract_plotting_data_from_results_object( + data = _retrieve_optimization_data_from_results_object( result, stack_multistart=True, show_exploration=show_exploration, @@ -242,19 +165,19 @@ def params_plot( ) start_params = result.start_params elif isinstance(result, (str, Path)): - data = _extract_plotting_data_from_database( + data = _retrieve_optimization_data_from_database( result, stack_multistart=True, show_exploration=show_exploration, ) - start_params = data["start_params"] + start_params = data.start_params else: raise TypeError("result must be an OptimizeResult or a path to a log file.") - if data["stacked_local_histories"] is not None: - history = data["stacked_local_histories"].params + if data.stacked_local_histories is not None: + history = data.stacked_local_histories.params else: - history = data["history"].params + history = data.history.params # ================================================================================== # Create figure @@ -298,30 +221,72 @@ def params_plot( return fig -def _extract_plotting_data_from_results_object( - res, stack_multistart, show_exploration, plot_name -): - """Extract data for plotting from results object. +def _retrieve_optimization_data( + results: dict[str, OptimizeResult | str | Path], + stack_multistart: bool, + show_exploration: bool, +) -> list["_OptimizeData"]: + """Retrieve data for criterion plot from results (OptimizeResult or database). Args: - res (OptmizeResult): An optimization results object. - stack_multistart (bool): Whether to combine multistart histories into a single - history. Default is False. - show_exploration (bool): If True, exploration samples of a multistart - optimization are visualized. Default is False. - plot_name (str): Name of the plotting function that calls this function. Used - for rasing errors. + results: A dict of optimization results with collected history. + The key is used as the name in a legend. + stack_multistart: Whether to combine multistart histories into a single history. + Default is False. + show_exploration: If True, exploration samples of a multistart optimization are + visualized. Default is False. + + Returns: + list[_OptimizeData]: A list of _OptimizeData objects containing the history, + direction, multistart information, and local histories. 
+
+    """
+    data = []
+    for name, res in results.items():
+        if isinstance(res, OptimizeResult):
+            _data = _retrieve_optimization_data_from_results_object(
+                res=res,
+                stack_multistart=stack_multistart,
+                show_exploration=show_exploration,
+                plot_name="criterion_plot",
+            )
+        elif isinstance(res, (str, Path)):
+            _data = _retrieve_optimization_data_from_database(
+                res=res,
+                stack_multistart=stack_multistart,
+                show_exploration=show_exploration,
+            )
+        else:
+            msg = ("results must be (or contain) an OptimizeResult or a path to a "
+                   f"log file, but is of type {type(res)}.")
+            raise TypeError(msg)
+
+        _data.name = name
+        data.append(_data)
+
+    return data
+
+
+def _retrieve_optimization_data_from_results_object(
+    res: OptimizeResult,
+    stack_multistart: bool,
+    show_exploration: bool,
+    plot_name: str,
+) -> "_OptimizeData":
+    """Retrieve optimization data from results object.
+
+    Args:
+        res: An optimization results object.
+        stack_multistart: Whether to combine multistart histories into a single history.
+            Default is False.
+        show_exploration: If True, exploration samples of a multistart optimization are
+            visualized. Default is False.
+        plot_name: Name of the plotting function that calls this function. Used for
+            raising errors.

     Returns:
-        dict:
-        - "history": The results history
-        - "direction": maximize or minimize
-        - "is_multistart": Whether the optimization used multistart
-        - "local_histories": All other multistart histories except for 'history'. If not
-            available is None. If show_exploration is True, the exploration phase is
-            added as the first entry.
-        - "stacked_local_histories": If stack_multistart is True the local histories
-            are stacked into a single one.
+        _OptimizeData: A data object containing the history, direction, multistart
+            information, and local histories.

     """
     if res.history is None:
@@ -331,12 +296,16 @@

     is_multistart = res.multistart_info is not None

-    if is_multistart:
-        local_histories = [opt.history for opt in res.multistart_info.local_optima]
+    if res.multistart_info:
+        local_histories = [
+            opt.history
+            for opt in res.multistart_info.local_optima
+            if opt.history is not None
+        ]
     else:
         local_histories = None

-    if stack_multistart and local_histories is not None:
+    if stack_multistart and local_histories is not None and res.multistart_info:
         stacked = _get_stacked_local_histories(local_histories, res.direction)
         if show_exploration:
             fun = res.multistart_info.exploration_results.tolist()[::-1] + stacked.fun
@@ -347,58 +316,58 @@
                 fun=fun,
                 params=params,
                 # TODO: This needs to be fixed
-                start_time=len(fun) * [None],
-                stop_time=len(fun) * [None],
-                batches=len(fun) * [None],
-                task=len(fun) * [None],
+                start_time=len(fun) * [None],  # type: ignore
+                stop_time=len(fun) * [None],  # type: ignore
+                batches=len(fun) * [None],  # type: ignore
+                task=len(fun) * [None],  # type: ignore
             )
     else:
         stacked = None

-    data = {
-        "history": res.history,
-        "direction": Direction(res.direction),
-        "is_multistart": is_multistart,
-        "local_histories": local_histories,
-        "stacked_local_histories": stacked,
-    }
+    data = _OptimizeData(
+        history=res.history,
+        direction=Direction(res.direction),
+        is_multistart=is_multistart,
+        local_histories=local_histories,
+        stacked_local_histories=stacked,
+    )
     return data


-def _extract_plotting_data_from_database(res, stack_multistart, show_exploration):
-    """Extract data for plotting from database.
+def _retrieve_optimization_data_from_database( + res: str | Path, + stack_multistart: bool, + show_exploration: bool, +) -> "_OptimizeData": + """Retrieve optimization data from a database. Args: - res (str or pathlib.Path): A path to an optimization database. - stack_multistart (bool): Whether to combine multistart histories into a single - history. Default is False. - show_exploration (bool): If True, exploration samples of a multistart - optimization are visualized. Default is False. + res: A path to an optimization database. + stack_multistart: Whether to combine multistart histories into a single history. + Default is False. + show_exploration: If True, exploration samples of a multistart optimization are + visualized. Default is False. Returns: - dict: - - "history": The results history - - "direction": maximize or minimize - - "is_multistart": Whether the optimization used multistart - - "local_histories": All other multistart histories except for 'history'. If not - available is None. If show_exploration is True, the exploration phase is - added as the first entry. - - "stacked_local_histories": If stack_multistart is True the local histories - are stacked into a single one. + _OptimizeData: A data object containing the history, direction, multistart + information, and local histories. """ - reader = LogReader.from_options(SQLiteLogOptions(res)) + reader: LogReader = LogReader.from_options(SQLiteLogOptions(res)) _problem_table = reader.problem_df direction = _problem_table["direction"].tolist()[-1] - _history, local_histories, exploration = reader.read_multistart_history(direction) + multistart_history = reader.read_multistart_history(direction) + _history = multistart_history.history + local_histories = multistart_history.local_histories + exploration = multistart_history.exploration if stack_multistart and local_histories is not None: stacked = _get_stacked_local_histories(local_histories, direction, _history) if show_exploration: - stacked["params"] = exploration["params"][::-1] + stacked["params"] - stacked["criterion"] = exploration["criterion"][::-1] + stacked["criterion"] + stacked["params"] = exploration["params"][::-1] + stacked["params"] # type: ignore + stacked["criterion"] = exploration["criterion"][::-1] + stacked["criterion"] # type: ignore else: stacked = None @@ -409,23 +378,27 @@ def _extract_plotting_data_from_database(res, stack_multistart, show_exploration start_time=_history["time"], # TODO (@janosg): Retrieve `stop_time` from `hist` once it is available. 
# https://github.com/optimagic-dev/optimagic/pull/553 - stop_time=len(_history["fun"]) * [None], - task=len(_history["fun"]) * [None], + stop_time=len(_history["fun"]) * [None], # type: ignore + task=len(_history["fun"]) * [None], # type: ignore batches=list(range(len(_history["fun"]))), ) - data = { - "history": history, - "direction": direction, - "is_multistart": local_histories is not None, - "local_histories": local_histories, - "stacked_local_histories": stacked, - "start_params": reader.read_start_params(), - } + data = _OptimizeData( + history=history, + direction=direction, + is_multistart=local_histories is not None, + local_histories=local_histories, + stacked_local_histories=stacked, + start_params=reader.read_start_params(), + ) return data -def _get_stacked_local_histories(local_histories, direction, history=None): +def _get_stacked_local_histories( + local_histories: list[History] | list[IterationHistory], + direction: Any, + history: History | IterationHistory | None = None, +) -> History: """Stack local histories. Local histories is a list of dictionaries, each of the same structure. We transform @@ -433,7 +406,7 @@ def _get_stacked_local_histories(local_histories, direction, history=None): append the best history at the end. """ - stacked = {"criterion": [], "params": [], "runtime": []} + stacked: dict[str, list[Any]] = {"criterion": [], "params": [], "runtime": []} for hist in local_histories: stacked["criterion"].extend(hist.fun) stacked["params"].extend(hist.params) @@ -453,7 +426,176 @@ def _get_stacked_local_histories(local_histories, direction, history=None): # TODO (@janosg): Retrieve `stop_time` from `hist` once it is available for the # IterationHistory. # https://github.com/optimagic-dev/optimagic/pull/553 - stop_time=len(stacked["criterion"]) * [None], - task=len(stacked["criterion"]) * [None], + stop_time=len(stacked["criterion"]) * [None], # type: ignore + task=len(stacked["criterion"]) * [None], # type: ignore batches=list(range(len(stacked["criterion"]))), ) + + +def _extract_criterion_plot_data( + data: list["_OptimizeData"], + max_evaluations: int | None, + palette_cycle: Iterator[str], + stack_multistart: bool, + monotone: bool, +) -> "CriterionPlotData": + """Extract lines for criterion plot from data. + + Args: + data: Data retrieved from results or database. + max_evaluations: Clip the criterion history after that many entries. + palette_cycle: Cycle of colors for plotting. + stack_multistart: Whether to combine multistart histories into a single + history. Default is False. + monotone: If True, the criterion plot becomes monotone in the sense that at each + iteration the current best criterion value is displayed. + + Returns: + CriterionPlotData: A data object containing the lines for the plot. 
+ + """ + fun_or_monotone_fun = "monotone_fun" if monotone else "fun" + + # Collect multistart optimization paths + multistart_lines: list[LineData] = [] + + plot_multistart = len(data) == 1 and data[0].is_multistart and not stack_multistart + + if plot_multistart and data[0].local_histories: + for i, local_history in enumerate(data[0].local_histories): + history = getattr(local_history, fun_or_monotone_fun) + + if max_evaluations is not None and len(history) > max_evaluations: + history = history[:max_evaluations] + + line_data = LineData( + x=np.arange(len(history)), + y=history, + color="#bab0ac", + name=str(i), + show_in_legend=False, + ) + multistart_lines.append(line_data) + + # Collect main optimization paths + lines: list[LineData] = [] + + for _data in data: + if stack_multistart and _data.stacked_local_histories is not None: + _history = _data.stacked_local_histories + else: + _history = _data.history + + history = getattr(_history, fun_or_monotone_fun) + + if max_evaluations is not None and len(history) > max_evaluations: + history = history[:max_evaluations] + + _color = next(palette_cycle) + if not isinstance(_color, str): + msg = "highlight_palette needs to be a string or list of strings, but its " + f"entry is of type {type(_color)}." + raise TypeError(msg) + + line_data = LineData( + x=np.arange(len(history)), + y=history, + color=_color, + name="best result" if plot_multistart else _data.name, + show_in_legend=not plot_multistart, + ) + lines.append(line_data) + + plot_data = CriterionPlotData( + lines=lines, + multistart_lines=multistart_lines, + ) + return plot_data + + +@dataclass() +class _OptimizeData: + history: History + direction: Direction + is_multistart: bool + local_histories: list[History] | list[IterationHistory] | None + stacked_local_histories: History | None + start_params: list[Any] | None = None + name: str | None = None + + +@dataclass(frozen=True) +class LineData: + """Data of a single line. + + Attributes: + x: The x-coordinates of the points. + y: The y-coordinates of the points. + color: The color of the line. Default is None. + name: The name of the line. Default is None. + show_in_legend: Whether to show the line in the legend. Default is True. + + """ + + x: np.ndarray + y: np.ndarray + color: str | None = None + name: str | None = None + show_in_legend: bool = True + + +@dataclass(frozen=True) +class CriterionPlotData: + """Backend agnostic data for criterion plot. + + Attributes: + lines: Main optimization paths. + multistart_lines: Multistart optimization paths, if applicable. + + """ + + lines: list[LineData] + multistart_lines: list[LineData] + + +@dataclass(frozen=True) +class PlotConfig: + """Configuration settings for figure. + + Attributes: + template: The template for the figure. + legend: Configuration for the legend. + + """ + + template: str + legend: dict[str, Any] + + +def _plotly_criterion_plot( + plot_data: CriterionPlotData, plot_config: PlotConfig +) -> go.Figure: + """Create a plotly figure from the plot data and configuration.""" + + fig = go.Figure() + + for line in plot_data.multistart_lines + plot_data.lines: + trace = go.Scatter( + x=line.x, + y=line.y, + name=line.name, + mode="lines", + line_color=line.color, + showlegend=line.show_in_legend, + connectgaps=True, + ) + fig.add_trace(trace) + + fig.update_layout( + template=plot_config.template, + xaxis_title_text="No. 
of criterion evaluations", + yaxis_title_text="Criterion value", + legend=plot_config.legend, + ) + + return fig diff --git a/tests/optimagic/visualization/test_history_plots.py b/tests/optimagic/visualization/test_history_plots.py index 85d6b18e0..6466ecd9f 100644 --- a/tests/optimagic/visualization/test_history_plots.py +++ b/tests/optimagic/visualization/test_history_plots.py @@ -8,8 +8,13 @@ from optimagic.logging import SQLiteLogOptions from optimagic.optimization.optimize import minimize from optimagic.parameters.bounds import Bounds +from optimagic.typing import Direction from optimagic.visualization.history_plots import ( + LineData, + _extract_criterion_plot_data, _harmonize_inputs_to_dict, + _OptimizeData, + _retrieve_optimization_data, criterion_plot, params_plot, ) @@ -187,3 +192,63 @@ def test_harmonize_inputs_to_dict_str_input(): def test_harmonize_inputs_to_dict_path_input(): path = Path("test.db") assert _harmonize_inputs_to_dict(results=path, names=None) == {"0": path} + + +def test_extract_data_from_results(): + res = minimize(fun=lambda x: x @ x, params=np.arange(5), algorithm="scipy_lbfgsb") + results = {"bla": res} + + data = _retrieve_optimization_data(results, False, False) + + expected = [ + _OptimizeData( + history=res.history, + direction=Direction(res.direction), + is_multistart=False, + local_histories=None, + stacked_local_histories=None, + name="bla", + ), + ] + + assert data == expected + + +def test_extract_data_from_multistart_result(minimize_result): + res = minimize_result[True][0] + results = {"multistart": res} + + for stack_multistart in [True, False]: + data = _retrieve_optimization_data(results, stack_multistart, False) + + assert isinstance(data, list) + assert len(data) == 1 + + assert data[0].is_multistart + assert len(data[0].local_histories) == 5 + + if stack_multistart: + assert data[0].stacked_local_histories is not None + else: + assert data[0].stacked_local_histories is None + + +def test_collect_lines_from_data(minimize_result): + res = minimize_result[True][0] + results = {"multistart": res} + data = _retrieve_optimization_data(results, False, False) + + palette = itertools.cycle(["red", "green", "blue"]) + + plot_data = _extract_criterion_plot_data(data, None, palette, False, False) + + lines = plot_data.lines + multistart_lines = plot_data.multistart_lines + + assert isinstance(lines, list) and all(isinstance(line, LineData) for line in lines) + assert len(lines) == 1 + + assert isinstance(multistart_lines, list) and all( + isinstance(line, LineData) for line in multistart_lines + ) + assert len(multistart_lines) == 5
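
The patch splits criterion_plot into a backend-agnostic data layer (LineData, CriterionPlotData, PlotConfig) and a plotly rendering step (_plotly_criterion_plot). The following minimal sketch is not part of the patch; it only illustrates how the two halves fit together, assuming the private names remain importable from optimagic.visualization.history_plots, and the colors and values are made up.

import numpy as np

from optimagic.visualization.history_plots import (
    CriterionPlotData,
    LineData,
    PlotConfig,
    _plotly_criterion_plot,
)

# Hand-built lines stand in for the output of _extract_criterion_plot_data.
lines = [
    LineData(
        x=np.arange(5),
        y=np.array([5.0, 3.0, 2.5, 2.5, 1.0]),
        color="#636efa",  # illustrative color, not a package default
        name="run_a",
    ),
    LineData(
        x=np.arange(4),
        y=np.array([6.0, 4.0, 3.0, 2.0]),
        color="#ef553b",
        name="run_b",
    ),
]

plot_data = CriterionPlotData(lines=lines, multistart_lines=[])
plot_config = PlotConfig(
    template="plotly_white",
    legend={"yanchor": "top", "xanchor": "right", "y": 0.95, "x": 0.95},
)

# Render the backend-agnostic data with the plotly backend.
fig = _plotly_criterion_plot(plot_data, plot_config)
fig.show()

Because the extraction step only produces CriterionPlotData, a different rendering backend could later be swapped in behind the same data without touching _extract_criterion_plot_data.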