Skip to content

Implementation of matplotlib backend for criterion_plot() #599

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .tools/envs/testenv-linux.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ dependencies:
- numpy >= 2 # run, tests
- pandas # run, tests
- plotly<6.0.0 # run, tests
- matplotlib # run, tests
- pybaum>=0.1.2 # run, tests
- scipy>=1.2.1 # run, tests
- sqlalchemy # run, tests
Expand Down
1 change: 1 addition & 0 deletions .tools/envs/testenv-numpy.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ dependencies:
- cloudpickle # run, tests
- joblib # run, tests
- plotly<6.0.0 # run, tests
- matplotlib # run, tests
- pybaum>=0.1.2 # run, tests
- scipy>=1.2.1 # run, tests
- sqlalchemy # run, tests
Expand Down
1 change: 1 addition & 0 deletions .tools/envs/testenv-others.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ dependencies:
- numpy >= 2 # run, tests
- pandas # run, tests
- plotly<6.0.0 # run, tests
- matplotlib # run, tests
- pybaum>=0.1.2 # run, tests
- scipy>=1.2.1 # run, tests
- sqlalchemy # run, tests
Expand Down
1 change: 1 addition & 0 deletions .tools/envs/testenv-pandas.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ dependencies:
- cloudpickle # run, tests
- joblib # run, tests
- plotly<6.0.0 # run, tests
- matplotlib # run, tests
- pybaum>=0.1.2 # run, tests
- scipy>=1.2.1 # run, tests
- sqlalchemy # run, tests
Expand Down
1 change: 1 addition & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ dependencies:
- numpy >= 2 # run, tests
- pandas # run, tests
- plotly<6.0.0 # run, tests
- matplotlib # run, tests
- pybaum>=0.1.2 # run, tests
- scipy>=1.2.1 # run, tests
- sqlalchemy # run, tests
Expand Down
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ dependencies = [
"numpy",
"pandas",
"plotly<6.0.0",
"matplotlib",
"pybaum>=0.1.2",
"scipy>=1.2.1",
"sqlalchemy>=1.3",
Expand Down Expand Up @@ -290,6 +291,7 @@ module = [
"optimagic.shared.process_user_function",

"optimagic.visualization",
"optimagic.visualization.backends",
"optimagic.visualization.convergence_plot",
"optimagic.visualization.deviation_plot",
"optimagic.visualization.history_plots",
Expand Down Expand Up @@ -347,6 +349,8 @@ module = [
"plotly.graph_objects",
"plotly.express",
"plotly.subplots",
"matplotlib",
"matplotlib.pyplot",
"cyipopt",
"nlopt",
"bokeh",
Expand Down
9 changes: 9 additions & 0 deletions src/optimagic/config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from pathlib import Path

import matplotlib as mpl
import pandas as pd
import plotly.express as px
from packaging import version
Expand All @@ -10,6 +11,14 @@
PLOTLY_TEMPLATE = "simple_white"
PLOTLY_PALETTE = px.colors.qualitative.Set2

PLOT_DEFAULTS = {
"plotly": {"template": "simple_white", "palette": px.colors.qualitative.Set2},
"matplotlib": {
"template": "default",
"palette": mpl.colormaps["Set2"],
},
}

DEFAULT_N_CORES = 1

CRITERION_PENALTY_SLOPE = 0.1
Expand Down
108 changes: 108 additions & 0 deletions src/optimagic/visualization/backends.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
import abc
from dataclasses import dataclass
from typing import Any

import matplotlib.pyplot as plt
import plotly.graph_objects as go


@dataclass(frozen=True)
class PlotConfig:
template: str
plotly_legend: dict[str, Any]
matplotlib_legend: dict[str, Any]


class BackendWrapper(abc.ABC):
def __init__(self, plot_config):
self.plot_config = plot_config

@abc.abstractmethod
def create_figure(self):
pass

Check warning on line 22 in src/optimagic/visualization/backends.py

View check run for this annotation

Codecov / codecov/patch

src/optimagic/visualization/backends.py#L22

Added line #L22 was not covered by tests

@abc.abstractmethod
def lineplot(self, **kwargs):
pass

Check warning on line 26 in src/optimagic/visualization/backends.py

View check run for this annotation

Codecov / codecov/patch

src/optimagic/visualization/backends.py#L26

Added line #L26 was not covered by tests

@abc.abstractmethod
def post_plot(self, **kwargs):
pass

Check warning on line 30 in src/optimagic/visualization/backends.py

View check run for this annotation

Codecov / codecov/patch

src/optimagic/visualization/backends.py#L30

Added line #L30 was not covered by tests

@abc.abstractmethod
def return_obj(self):
pass

Check warning on line 34 in src/optimagic/visualization/backends.py

View check run for this annotation

Codecov / codecov/patch

src/optimagic/visualization/backends.py#L34

Added line #L34 was not covered by tests


class BackendRegistry:
_registry: dict[str, BackendWrapper] = {}

@classmethod
def register(cls, backend_name):
def decorator(backend_wrapper):
cls._registry[backend_name] = backend_wrapper
return backend_wrapper

return decorator

@classmethod
def get_backend_wrapper(cls, backend_name):
if backend_name not in cls._registry:
raise ValueError(

Check warning on line 51 in src/optimagic/visualization/backends.py

View check run for this annotation

Codecov / codecov/patch

src/optimagic/visualization/backends.py#L51

Added line #L51 was not covered by tests
f"Backend '{backend_name}' is not supported. "
f"Supported backends are: {', '.join(cls._registry.keys())}."
)
return cls._registry.get(backend_name)


@BackendRegistry.register("plotly")
class PlotlyBackend(BackendWrapper):
def __init__(self, plot_config):
super().__init__(plot_config)
self.fig = self.create_figure()

def create_figure(self):
fig = go.Figure()
return fig

def lineplot(self, *, x, y, color, name=None, plotly_scatter_kws=None, **kwargs):
if plotly_scatter_kws is None:
plotly_scatter_kws = {}

Check warning on line 70 in src/optimagic/visualization/backends.py

View check run for this annotation

Codecov / codecov/patch

src/optimagic/visualization/backends.py#L70

Added line #L70 was not covered by tests

trace = go.Scatter(
x=x, y=y, mode="lines", line_color=color, name=name, **plotly_scatter_kws
)
self.fig.add_trace(trace)

def post_plot(self, *, xlabel=None, ylabel=None, **kwargs):
self.fig.update_layout(
template=self.plot_config.template,
xaxis_title_text=xlabel,
yaxis_title_text=ylabel,
legend=self.plot_config.plotly_legend,
)

def return_obj(self):
return self.fig


@BackendRegistry.register("matplotlib")
class MatplotlibBackend(BackendWrapper):
def __init__(self, plot_config):
super().__init__(plot_config)
self.fig, self.ax = self.create_figure()

Check warning on line 93 in src/optimagic/visualization/backends.py

View check run for this annotation

Codecov / codecov/patch

src/optimagic/visualization/backends.py#L92-L93

Added lines #L92 - L93 were not covered by tests

def create_figure(self):
plt.style.use(self.plot_config.template)
fig, ax = plt.subplots()
return fig, ax

Check warning on line 98 in src/optimagic/visualization/backends.py

View check run for this annotation

Codecov / codecov/patch

src/optimagic/visualization/backends.py#L96-L98

Added lines #L96 - L98 were not covered by tests

def lineplot(self, *, x, y, color, name=None, **kwargs):
self.ax.plot(x, y, color=color, label=name)

Check warning on line 101 in src/optimagic/visualization/backends.py

View check run for this annotation

Codecov / codecov/patch

src/optimagic/visualization/backends.py#L101

Added line #L101 was not covered by tests

def post_plot(self, *, xlabel=None, ylabel=None, **kwargs):
self.ax.set(xlabel=xlabel, ylabel=ylabel)
self.ax.legend(**self.plot_config.matplotlib_legend)

Check warning on line 105 in src/optimagic/visualization/backends.py

View check run for this annotation

Codecov / codecov/patch

src/optimagic/visualization/backends.py#L104-L105

Added lines #L104 - L105 were not covered by tests

def return_obj(self):
return self.fig

Check warning on line 108 in src/optimagic/visualization/backends.py

View check run for this annotation

Codecov / codecov/patch

src/optimagic/visualization/backends.py#L108

Added line #L108 was not covered by tests
76 changes: 35 additions & 41 deletions src/optimagic/visualization/history_plots.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,29 @@
import inspect
import itertools
from pathlib import Path
from typing import Any

import numpy as np
import plotly.graph_objects as go
from pybaum import leaf_names, tree_flatten, tree_just_flatten, tree_unflatten

from optimagic.config import PLOTLY_PALETTE, PLOTLY_TEMPLATE
from optimagic.config import PLOTLY_TEMPLATE
from optimagic.logging.logger import LogReader, SQLiteLogOptions
from optimagic.optimization.algorithm import Algorithm
from optimagic.optimization.history import History
from optimagic.optimization.optimize_result import OptimizeResult
from optimagic.parameters.tree_registry import get_registry
from optimagic.typing import Direction
from optimagic.visualization.backends import BackendRegistry, PlotConfig
from optimagic.visualization.plotting_utilities import get_palette, get_template


def criterion_plot(
results,
names=None,
max_evaluations=None,
template=PLOTLY_TEMPLATE,
palette=PLOTLY_PALETTE,
backend="plotly",
template=None,
palette=None,
stack_multistart=False,
monotone=False,
show_exploration=False,
Expand All @@ -34,6 +36,7 @@ def criterion_plot(
key is used as the name in a legend.
names (Union[List[str], str]): Names corresponding to res or entries in res.
max_evaluations (int): Clip the criterion history after that many entries.
backend (str): The backend to use for plotting. Default is "plotly".
template (str): The template for the figure. Default is "plotly_white".
palette (Union[List[str], str]): The coloring palette for traces. Default is
"qualitative.Plotly".
Expand All @@ -46,7 +49,7 @@ def criterion_plot(
optimization are visualized. Default is False.

Returns:
plotly.graph_objs._figure.Figure: The figure.
Native figure object returned by the chosen backend.

"""
# ==================================================================================
Expand All @@ -55,9 +58,8 @@ def criterion_plot(

results = _harmonize_inputs_to_dict(results, names)

if not isinstance(palette, list):
palette = [palette]
palette = itertools.cycle(palette)
template = get_template(backend, template)
palette = get_palette(backend, palette)

fun_or_monotone_fun = "monotone_fun" if monotone else "fun"

Expand Down Expand Up @@ -87,36 +89,41 @@ def criterion_plot(
# Create figure
# ==================================================================================

fig = go.Figure()

plot_multistart = (
len(data) == 1 and data[0]["is_multistart"] and not stack_multistart
plot_config = PlotConfig(
template=template,
plotly_legend={"yanchor": "top", "xanchor": "right", "y": 0.95, "x": 0.95},
matplotlib_legend={"loc": "upper right"},
)

_backend_wrapper = BackendRegistry.get_backend_wrapper(backend)
fig = _backend_wrapper(plot_config)

# ==================================================================================
# Plot multistart paths

plot_multistart = (
len(data) == 1 and data[0]["is_multistart"] and not stack_multistart
)

if plot_multistart:
scatter_kws = {
"connectgaps": True,
"showlegend": False,
}

for i, local_history in enumerate(data[0]["local_histories"]):
for local_history in data[0]["local_histories"]:
history = getattr(local_history, fun_or_monotone_fun)

if max_evaluations is not None and len(history) > max_evaluations:
history = history[:max_evaluations]

trace = go.Scatter(
fig.lineplot(
x=np.arange(len(history)),
y=history,
mode="lines",
name=str(i),
line_color="#bab0ac",
**scatter_kws,
name=None,
color="#bab0ac",
plotly_scatter_kws=scatter_kws,
)
fig.add_trace(trace)

# ==================================================================================
# Plot main optimization objects
Expand All @@ -134,36 +141,23 @@ def criterion_plot(

scatter_kws = {
"connectgaps": True,
"showlegend": not plot_multistart,
"showlegend": True,
}

_color = next(palette)
if not isinstance(_color, str):
msg = "highlight_palette needs to be a string or list of strings, but its "
f"entry is of type {type(_color)}."
raise TypeError(msg)

line_kws = {
"color": _color,
}

trace = go.Scatter(
fig.lineplot(
x=np.arange(len(history)),
y=history,
mode="lines",
name="best result" if plot_multistart else _data["name"],
line=line_kws,
**scatter_kws,
color=next(palette),
plotly_scatter_kws=scatter_kws,
)
fig.add_trace(trace)

fig.update_layout(
template=template,
xaxis_title_text="No. of criterion evaluations",
yaxis_title_text="Criterion value",
legend={"yanchor": "top", "xanchor": "right", "y": 0.95, "x": 0.95},
fig.post_plot(
xlabel="No. of criterion evaluations",
ylabel="Criterion value",
)
return fig

return fig.return_obj()


def _harmonize_inputs_to_dict(results, names):
Expand Down
23 changes: 22 additions & 1 deletion src/optimagic/visualization/plotting_utilities.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import itertools
from copy import deepcopy

import matplotlib as mpl
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots

from optimagic.config import PLOTLY_TEMPLATE
from optimagic.config import PLOT_DEFAULTS, PLOTLY_TEMPLATE


def combine_plots(
Expand Down Expand Up @@ -328,3 +329,23 @@
if layout_kwargs:
default_kwargs.update(layout_kwargs)
return default_kwargs


def get_template(backend, template):
if template is None:
template = PLOT_DEFAULTS[backend]["template"]

return template


def get_palette(backend, palette):
if palette is None:
palette = PLOT_DEFAULTS[backend]["palette"]

if isinstance(palette, mpl.colors.Colormap):
palette = [palette(i) for i in range(palette.N)]

Check warning on line 346 in src/optimagic/visualization/plotting_utilities.py

View check run for this annotation

Codecov / codecov/patch

src/optimagic/visualization/plotting_utilities.py#L346

Added line #L346 was not covered by tests
if not isinstance(palette, list):
palette = [palette]
palette = itertools.cycle(palette)

return palette
Loading