diff --git a/.tools/envs/testenv-linux.yml b/.tools/envs/testenv-linux.yml index ec4b969f9..3d72c214f 100644 --- a/.tools/envs/testenv-linux.yml +++ b/.tools/envs/testenv-linux.yml @@ -19,6 +19,7 @@ dependencies: - numpy >= 2 # run, tests - pandas # run, tests - plotly<6.0.0 # run, tests + - matplotlib # run, tests - pybaum>=0.1.2 # run, tests - scipy>=1.2.1 # run, tests - sqlalchemy # run, tests diff --git a/.tools/envs/testenv-numpy.yml b/.tools/envs/testenv-numpy.yml index 9f9fa7d0f..78cf41b0d 100644 --- a/.tools/envs/testenv-numpy.yml +++ b/.tools/envs/testenv-numpy.yml @@ -17,6 +17,7 @@ dependencies: - cloudpickle # run, tests - joblib # run, tests - plotly<6.0.0 # run, tests + - matplotlib # run, tests - pybaum>=0.1.2 # run, tests - scipy>=1.2.1 # run, tests - sqlalchemy # run, tests diff --git a/.tools/envs/testenv-others.yml b/.tools/envs/testenv-others.yml index ce9490b7f..37279a219 100644 --- a/.tools/envs/testenv-others.yml +++ b/.tools/envs/testenv-others.yml @@ -17,6 +17,7 @@ dependencies: - numpy >= 2 # run, tests - pandas # run, tests - plotly<6.0.0 # run, tests + - matplotlib # run, tests - pybaum>=0.1.2 # run, tests - scipy>=1.2.1 # run, tests - sqlalchemy # run, tests diff --git a/.tools/envs/testenv-pandas.yml b/.tools/envs/testenv-pandas.yml index 7b342240b..7cd56eaa1 100644 --- a/.tools/envs/testenv-pandas.yml +++ b/.tools/envs/testenv-pandas.yml @@ -17,6 +17,7 @@ dependencies: - cloudpickle # run, tests - joblib # run, tests - plotly<6.0.0 # run, tests + - matplotlib # run, tests - pybaum>=0.1.2 # run, tests - scipy>=1.2.1 # run, tests - sqlalchemy # run, tests diff --git a/environment.yml b/environment.yml index 80435b8d7..9262724b6 100644 --- a/environment.yml +++ b/environment.yml @@ -21,6 +21,7 @@ dependencies: - numpy >= 2 # run, tests - pandas # run, tests - plotly<6.0.0 # run, tests + - matplotlib # run, tests - pybaum>=0.1.2 # run, tests - scipy>=1.2.1 # run, tests - sqlalchemy # run, tests diff --git a/pyproject.toml b/pyproject.toml index ce6707e6e..9e1cbe2ee 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,6 +11,7 @@ dependencies = [ "numpy", "pandas", "plotly<6.0.0", + "matplotlib", "pybaum>=0.1.2", "scipy>=1.2.1", "sqlalchemy>=1.3", @@ -290,6 +291,7 @@ module = [ "optimagic.shared.process_user_function", "optimagic.visualization", + "optimagic.visualization.backends", "optimagic.visualization.convergence_plot", "optimagic.visualization.deviation_plot", "optimagic.visualization.history_plots", @@ -347,6 +349,8 @@ module = [ "plotly.graph_objects", "plotly.express", "plotly.subplots", + "matplotlib", + "matplotlib.pyplot", "cyipopt", "nlopt", "bokeh", diff --git a/src/optimagic/config.py b/src/optimagic/config.py index 643a6f663..892272d17 100644 --- a/src/optimagic/config.py +++ b/src/optimagic/config.py @@ -1,5 +1,6 @@ from pathlib import Path +import matplotlib as mpl import pandas as pd import plotly.express as px from packaging import version @@ -10,6 +11,14 @@ PLOTLY_TEMPLATE = "simple_white" PLOTLY_PALETTE = px.colors.qualitative.Set2 +PLOT_DEFAULTS = { + "plotly": {"template": "simple_white", "palette": px.colors.qualitative.Set2}, + "matplotlib": { + "template": "default", + "palette": mpl.colormaps["Set2"], + }, +} + DEFAULT_N_CORES = 1 CRITERION_PENALTY_SLOPE = 0.1 diff --git a/src/optimagic/visualization/backends.py b/src/optimagic/visualization/backends.py new file mode 100644 index 000000000..59b91aa68 --- /dev/null +++ b/src/optimagic/visualization/backends.py @@ -0,0 +1,108 @@ +import abc +from dataclasses import dataclass +from typing import Any + +import matplotlib.pyplot as plt +import plotly.graph_objects as go + + +@dataclass(frozen=True) +class PlotConfig: + template: str + plotly_legend: dict[str, Any] + matplotlib_legend: dict[str, Any] + + +class BackendWrapper(abc.ABC): + def __init__(self, plot_config): + self.plot_config = plot_config + + @abc.abstractmethod + def create_figure(self): + pass + + @abc.abstractmethod + def lineplot(self, **kwargs): + pass + + @abc.abstractmethod + def post_plot(self, **kwargs): + pass + + @abc.abstractmethod + def return_obj(self): + pass + + +class BackendRegistry: + _registry: dict[str, BackendWrapper] = {} + + @classmethod + def register(cls, backend_name): + def decorator(backend_wrapper): + cls._registry[backend_name] = backend_wrapper + return backend_wrapper + + return decorator + + @classmethod + def get_backend_wrapper(cls, backend_name): + if backend_name not in cls._registry: + raise ValueError( + f"Backend '{backend_name}' is not supported. " + f"Supported backends are: {', '.join(cls._registry.keys())}." + ) + return cls._registry.get(backend_name) + + +@BackendRegistry.register("plotly") +class PlotlyBackend(BackendWrapper): + def __init__(self, plot_config): + super().__init__(plot_config) + self.fig = self.create_figure() + + def create_figure(self): + fig = go.Figure() + return fig + + def lineplot(self, *, x, y, color, name=None, plotly_scatter_kws=None, **kwargs): + if plotly_scatter_kws is None: + plotly_scatter_kws = {} + + trace = go.Scatter( + x=x, y=y, mode="lines", line_color=color, name=name, **plotly_scatter_kws + ) + self.fig.add_trace(trace) + + def post_plot(self, *, xlabel=None, ylabel=None, **kwargs): + self.fig.update_layout( + template=self.plot_config.template, + xaxis_title_text=xlabel, + yaxis_title_text=ylabel, + legend=self.plot_config.plotly_legend, + ) + + def return_obj(self): + return self.fig + + +@BackendRegistry.register("matplotlib") +class MatplotlibBackend(BackendWrapper): + def __init__(self, plot_config): + super().__init__(plot_config) + self.fig, self.ax = self.create_figure() + + def create_figure(self): + plt.style.use(self.plot_config.template) + fig, ax = plt.subplots() + return fig, ax + + def lineplot(self, *, x, y, color, name=None, **kwargs): + self.ax.plot(x, y, color=color, label=name) + + def post_plot(self, *, xlabel=None, ylabel=None, **kwargs): + self.ax.set(xlabel=xlabel, ylabel=ylabel) + self.ax.legend(**self.plot_config.matplotlib_legend) + + def return_obj(self): + return self.fig diff --git a/src/optimagic/visualization/history_plots.py b/src/optimagic/visualization/history_plots.py index cb64a4e94..27076b48a 100644 --- a/src/optimagic/visualization/history_plots.py +++ b/src/optimagic/visualization/history_plots.py @@ -1,5 +1,4 @@ import inspect -import itertools from pathlib import Path from typing import Any @@ -7,21 +6,24 @@ import plotly.graph_objects as go from pybaum import leaf_names, tree_flatten, tree_just_flatten, tree_unflatten -from optimagic.config import PLOTLY_PALETTE, PLOTLY_TEMPLATE +from optimagic.config import PLOTLY_TEMPLATE from optimagic.logging.logger import LogReader, SQLiteLogOptions from optimagic.optimization.algorithm import Algorithm from optimagic.optimization.history import History from optimagic.optimization.optimize_result import OptimizeResult from optimagic.parameters.tree_registry import get_registry from optimagic.typing import Direction +from optimagic.visualization.backends import BackendRegistry, PlotConfig +from optimagic.visualization.plotting_utilities import get_palette, get_template def criterion_plot( results, names=None, max_evaluations=None, - template=PLOTLY_TEMPLATE, - palette=PLOTLY_PALETTE, + backend="plotly", + template=None, + palette=None, stack_multistart=False, monotone=False, show_exploration=False, @@ -34,6 +36,7 @@ def criterion_plot( key is used as the name in a legend. names (Union[List[str], str]): Names corresponding to res or entries in res. max_evaluations (int): Clip the criterion history after that many entries. + backend (str): The backend to use for plotting. Default is "plotly". template (str): The template for the figure. Default is "plotly_white". palette (Union[List[str], str]): The coloring palette for traces. Default is "qualitative.Plotly". @@ -46,7 +49,7 @@ def criterion_plot( optimization are visualized. Default is False. Returns: - plotly.graph_objs._figure.Figure: The figure. + Native figure object returned by the chosen backend. """ # ================================================================================== @@ -55,9 +58,8 @@ def criterion_plot( results = _harmonize_inputs_to_dict(results, names) - if not isinstance(palette, list): - palette = [palette] - palette = itertools.cycle(palette) + template = get_template(backend, template) + palette = get_palette(backend, palette) fun_or_monotone_fun = "monotone_fun" if monotone else "fun" @@ -87,36 +89,41 @@ def criterion_plot( # Create figure # ================================================================================== - fig = go.Figure() - - plot_multistart = ( - len(data) == 1 and data[0]["is_multistart"] and not stack_multistart + plot_config = PlotConfig( + template=template, + plotly_legend={"yanchor": "top", "xanchor": "right", "y": 0.95, "x": 0.95}, + matplotlib_legend={"loc": "upper right"}, ) + _backend_wrapper = BackendRegistry.get_backend_wrapper(backend) + fig = _backend_wrapper(plot_config) + # ================================================================================== # Plot multistart paths + plot_multistart = ( + len(data) == 1 and data[0]["is_multistart"] and not stack_multistart + ) + if plot_multistart: scatter_kws = { "connectgaps": True, "showlegend": False, } - for i, local_history in enumerate(data[0]["local_histories"]): + for local_history in data[0]["local_histories"]: history = getattr(local_history, fun_or_monotone_fun) if max_evaluations is not None and len(history) > max_evaluations: history = history[:max_evaluations] - trace = go.Scatter( + fig.lineplot( x=np.arange(len(history)), y=history, - mode="lines", - name=str(i), - line_color="#bab0ac", - **scatter_kws, + name=None, + color="#bab0ac", + plotly_scatter_kws=scatter_kws, ) - fig.add_trace(trace) # ================================================================================== # Plot main optimization objects @@ -134,36 +141,23 @@ def criterion_plot( scatter_kws = { "connectgaps": True, - "showlegend": not plot_multistart, + "showlegend": True, } - _color = next(palette) - if not isinstance(_color, str): - msg = "highlight_palette needs to be a string or list of strings, but its " - f"entry is of type {type(_color)}." - raise TypeError(msg) - - line_kws = { - "color": _color, - } - - trace = go.Scatter( + fig.lineplot( x=np.arange(len(history)), y=history, - mode="lines", name="best result" if plot_multistart else _data["name"], - line=line_kws, - **scatter_kws, + color=next(palette), + plotly_scatter_kws=scatter_kws, ) - fig.add_trace(trace) - fig.update_layout( - template=template, - xaxis_title_text="No. of criterion evaluations", - yaxis_title_text="Criterion value", - legend={"yanchor": "top", "xanchor": "right", "y": 0.95, "x": 0.95}, + fig.post_plot( + xlabel="No. of criterion evaluations", + ylabel="Criterion value", ) - return fig + + return fig.return_obj() def _harmonize_inputs_to_dict(results, names): diff --git a/src/optimagic/visualization/plotting_utilities.py b/src/optimagic/visualization/plotting_utilities.py index eea622d9c..55ee35257 100644 --- a/src/optimagic/visualization/plotting_utilities.py +++ b/src/optimagic/visualization/plotting_utilities.py @@ -1,11 +1,12 @@ import itertools from copy import deepcopy +import matplotlib as mpl import numpy as np import plotly.graph_objects as go from plotly.subplots import make_subplots -from optimagic.config import PLOTLY_TEMPLATE +from optimagic.config import PLOT_DEFAULTS, PLOTLY_TEMPLATE def combine_plots( @@ -328,3 +329,23 @@ def get_layout_kwargs(layout_kwargs, legend_kwargs, title_kwargs, template, show if layout_kwargs: default_kwargs.update(layout_kwargs) return default_kwargs + + +def get_template(backend, template): + if template is None: + template = PLOT_DEFAULTS[backend]["template"] + + return template + + +def get_palette(backend, palette): + if palette is None: + palette = PLOT_DEFAULTS[backend]["palette"] + + if isinstance(palette, mpl.colors.Colormap): + palette = [palette(i) for i in range(palette.N)] + if not isinstance(palette, list): + palette = [palette] + palette = itertools.cycle(palette) + + return palette