Skip to content

Commit 78eeacf

Browse files
committed
Add structural pattern matching code to plot functions
1 parent 65690ca commit 78eeacf

1 file changed

Lines changed: 94 additions & 77 deletions

File tree

Stoner/plot/functions.py

Lines changed: 94 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import copy
44
from inspect import getfullargspec
55
import os
6+
import re
67
from functools import partial
78

89
import numpy as np
@@ -27,7 +28,7 @@
2728
from ..tools.tests import isanynone, isnone, isiterable
2829
from ..tools import AttributeStore
2930

30-
from ..compat import index_types
31+
from ..compat import _pattern_type
3132

3233
from .utils import hsl2rgb
3334

@@ -160,7 +161,7 @@ def _vector_color(datafile, xcol=None, ycol=None, ucol=None, vcol=None, wcol=Non
160161
"""Map a vector direction in the data to a value for use with a colormnap."""
161162
c = _fix_cols(datafile, xcol=xcol, ycol=ycol, ucol=ucol, vcol=vcol, wcol=wcol, **kargs)
162163

163-
if isinstance(c.wcol, index_types): # 3D vector field
164+
if isinstance(c.wcol, (int, str, _pattern_type)): # 3D vector field
164165
wdata = datafile.column(c.wcol)
165166
phidata = (wdata - np.min(wdata)) / (np.max(wdata) - np.min(wdata))
166167
else: # 2D vector field
@@ -235,17 +236,18 @@ def _fix_cols(datafile, scalar=True, **kargs):
235236

236237
def _fix_fig(datafile, figure, **kargs):
237238
"""Sorts out the matplotlib figure handling."""
238-
if isinstance(figure, bool) and not figure:
239-
figure, ax = datafile.template.new_figure(None, **kargs)
240-
elif not isinstance(figure, bool) and isinstance(figure, int):
241-
figure, ax = datafile.template.new_figure(figure, **kargs)
242-
elif isinstance(figure, mplfig.Figure):
243-
figure, ax = datafile.template.new_figure(figure.number, **kargs)
244-
elif isinstance(datafile._figure, mplfig.Figure):
245-
figure = datafile._figure
246-
ax = datafile._figure.gca(**kargs)
247-
else:
248-
figure, ax = datafile.template.new_figure(None, **kargs)
239+
match figure:
240+
case bool() if not figure:
241+
figure, ax = datafile.template.new_figure(None, **kargs)
242+
case int() if not isinstance(figure, bool):
243+
figure, ax = datafile.template.new_figure(figure, **kargs)
244+
case mplfig.Figure():
245+
figure, ax = datafile.template.new_figure(figure.number, **kargs)
246+
case _ if isinstance(datafile._figure, mplfig.Figure):
247+
figure = datafile._figure
248+
ax = datafile._figure.gca(**kargs)
249+
case _:
250+
figure, ax = datafile.template.new_figure(None, **kargs)
249251
datafile._figure = figure
250252
figure.sca(ax) # Esur4e we're set for plotting on the correct axes
251253
return figure, ax
@@ -305,6 +307,11 @@ def _fix_titles(datafile, ix, multiple, **kargs):
305307
plt.savefig(str(kargs["save_filename"]))
306308

307309

310+
# #########################################################################################
311+
# ################ Public Methods #########################################################
312+
# #########################################################################################
313+
314+
308315
def colormap_xyz(datafile, xcol=None, ycol=None, zcol=None, **kargs):
309316
"""Make a xyz plot that forces the use of plt.colormap.
310317
@@ -401,12 +408,15 @@ def figure(datafile, figure=None, projection="rectilinear", **kargs):
401408
Returns:
402409
The current Stoner.plot.PlotMixin instance
403410
"""
404-
if figure is None:
405-
figure = datafile.template.new_figure(None, projection=projection, **kargs)[0]
406-
elif isinstance(figure, int):
407-
figure = datafile.template.new_figure(figure, projection=projection, **kargs)[0]
408-
elif isinstance(figure, mplfig.Figure):
409-
figure = datafile.template.new_figure(figure.number, projection=projection, **kargs)[0]
411+
match figure:
412+
case None:
413+
figure = datafile.template.new_figure(None, projection=projection, **kargs)[0]
414+
case int():
415+
figure = datafile.template.new_figure(figure, projection=projection, **kargs)[0]
416+
case mplfig.Figure():
417+
figure = datafile.template.new_figure(figure.number, projection=projection, **kargs)[0]
418+
case _:
419+
raise ValueError(f"Unable to interpret {figure=}")
410420
datafile._figure = figure
411421
return datafile
412422

@@ -660,18 +670,25 @@ def inset(datafile, parent=None, loc=None, width=0.35, height=0.30, **kargs): #
660670
loc = locations2.index(loc)
661671
else:
662672
raise RuntimeError(f"Couldn't work out where {loc} was supposed to be")
663-
if isinstance(width, int):
664-
width = f"{width}%"
665-
elif isinstance(width, float) and 0 < width <= 1:
666-
width = f"{width*100}%"
667-
elif not isinstance(width, str):
668-
raise RuntimeError(f"didn't Recognize width specification {width}")
669-
if isinstance(height, int):
670-
height = f"{height}%"
671-
elif isinstance(height, float) and 0 < height <= 1:
672-
height = "{height * 100}%"
673-
elif not isinstance(height, str):
674-
raise RuntimeError("didn't Recognize height specification {height}")
673+
match width:
674+
case int():
675+
width = f"{width}%"
676+
case float() if 0 < width <= 1:
677+
width = f"{width*100}%"
678+
case str() if re.match(r"[0-9]+\%", width):
679+
pass
680+
case _:
681+
raise RuntimeError(f"didn't Recognize width specification {width=}")
682+
match height:
683+
case int():
684+
height = f"{height}%"
685+
case float() if 0 < width <= 1:
686+
height = "{height * 100}%"
687+
case str() if re.match(r"[0-9]+\%", height):
688+
pass
689+
690+
case _:
691+
raise RuntimeError("didn't Recognize height specification {height=}")
675692
if parent is None:
676693
parent = plt.gca()
677694
return inset_locator.inset_axes(parent, width, height, loc, **kargs)
@@ -773,47 +790,47 @@ def plot_matrix(
773790
The matplotib figure with the data plotted
774791
"""
775792
# Sortout yvals values
776-
if isinstance(yvals, int): # Int means we're specifying a data row
777-
if rectang is None: # we need to initialise the rectang
778-
rectang = (yvals + 1, 0) # We'll sort the column origin later
779-
elif (
780-
isinstance(rectang, tuple) and rectang[1] <= yvals
781-
): # We have a rectang, but we need to adjust the row origin
782-
rectang[0] = yvals + 1
783-
yvals = datafile[yvals] # change the yvals into a numpy array
784-
elif isinstance(yvals, (list, tuple, np.ndarray)): # We're given the yvals as a list already
785-
yvals = np.array(yvals)
786-
elif yvals is None: # No yvals, so we'l try column headings
787-
if isinstance(xvals, index_types): # Do we have an xcolumn header to take away ?
788-
xvals = datafile.find_col(xvals)
789-
headers = datafile.column_headers[xvals + 1 :]
790-
elif xvals is None: # No xvals so we're going to be using the first column
791-
xvals = 0
792-
headers = datafile.column_headers[1:]
793-
else:
794-
headers = datafile.column_headers
795-
yvals = np.array([float(x) for x in headers]) # Ok try to construct yvals array
796-
else:
797-
raise RuntimeError("uvals must be either an integer, list, tuple, numpy array or None")
793+
match yvals:
794+
case int() | str() | _pattern_type():
795+
if rectang is None: # we need to initialise the rectang
796+
rectang = (yvals + 1, 0) # We'll sort the column origin later
797+
elif (
798+
isinstance(rectang, tuple) and rectang[1] <= yvals
799+
): # We have a rectang, but we need to adjust the row origin
800+
rectang[0] = yvals + 1
801+
yvals = datafile[yvals] # change the yvals into a numpy array
802+
case list() | tuple() | np.ndarray():
803+
yvals = np.array(yvals)
804+
case None:
805+
if isinstance(xvals, (int, str, _pattern_type)): # Do we have an xcolumn header to take away ?
806+
xvals = datafile.find_col(xvals)
807+
headers = datafile.column_headers[xvals + 1 :]
808+
elif xvals is None: # No xvals so we're going to be using the first column
809+
xvals = 0
810+
headers = datafile.column_headers[1:]
811+
else:
812+
headers = datafile.column_headers
813+
yvals = np.array([float(x) for x in headers]) # Ok try to construct yvals array
814+
case _:
815+
raise RuntimeError("uvals must be either an integer, list, tuple, numpy array or None")
798816
# Sort out xvls values
799-
if isinstance(xvals, index_types): # String or int means using a column index
800-
if xlabel is None:
801-
xlabel = datafile._col_label(xvals)
802-
if rectang is None: # Do we need to init the rectan ?
803-
rectang = (0, xvals + 1)
804-
elif isinstance(rectang, tuple): # Do we need to adjust the rectan column origin ?
805-
rectang[1] = xvals + 1
806-
xvals = datafile.column(xvals)
807-
elif isinstance(xvals, (list, tuple, np.ndarray)): # Xvals as a data item
808-
xvals = np.array(xvals)
809-
elif isinstance(xvals, np.ndarray):
810-
pass
811-
elif xvals is None: # xvals from column 0
812-
xvals = datafile.column(0)
813-
if rectang is None: # and fix up rectang
814-
rectang = (0, 1)
815-
else:
816-
raise RuntimeError("xvals must be a string, integer, list, tuple or numpy array or None")
817+
match xvals:
818+
case int() | str() | _pattern_type():
819+
if xlabel is None:
820+
xlabel = datafile._col_label(xvals)
821+
if rectang is None: # Do we need to init the rectan ?
822+
rectang = (0, xvals + 1)
823+
elif isinstance(rectang, tuple): # Do we need to adjust the rectan column origin ?
824+
rectang[1] = xvals + 1
825+
xvals = datafile.column(xvals)
826+
case list() | tuple() | np.ndarray():
827+
xvals = np.array(xvals)
828+
case None:
829+
xvals = datafile.column(0)
830+
if rectang is None: # and fix up rectang
831+
rectang = (0, 1)
832+
case _:
833+
raise RuntimeError("xvals must be a string, integer, list, tuple or numpy array or None")
817834

818835
if isinstance(rectang, tuple) and len(rectang) == 2: # Sort the rectang value
819836
rectang = (
@@ -1058,13 +1075,13 @@ def plot_xy(datafile, xcol=None, ycol=None, fmt=None, xerr=None, yerr=None, **ka
10581075
if isnone(kargs.get(err, None)):
10591076
kargs.pop(err, None)
10601077

1061-
elif isinstance(kargs[err], index_types):
1078+
elif isinstance(kargs[err], (int, str, _pattern_type)):
10621079
kargs[err] = datafile.column(kargs[err])
10631080
elif isiterable(kargs[err]) and isinstance(c.ycol, list) and len(kargs[err]) <= len(c.ycol):
10641081
# Ok, so it's a list, so redo the check for each item.
10651082
kargs[err].extend([None] * (len(c.ycol) - len(kargs[err])))
10661083
for i in range(len(kargs[err])):
1067-
if isinstance(kargs[err][i], index_types):
1084+
if isinstance(kargs[err][i], (int, str, _pattern_type)):
10681085
kargs[err][i] = datafile.column(kargs[err][i])
10691086
else:
10701087
kargs[err][i] = np.zeros(len(datafile))
@@ -1076,7 +1093,7 @@ def plot_xy(datafile, xcol=None, ycol=None, fmt=None, xerr=None, yerr=None, **ka
10761093
kargs[err] = np.zeros(len(datafile))
10771094

10781095
temp_kwords = copy.copy(kargs)
1079-
if isinstance(c.ycol, (index_types)):
1096+
if isinstance(c.ycol, ((int, str, _pattern_type))):
10801097
c.ycol = [c.ycol]
10811098
if len(c.ycol) > 1:
10821099
if multiple == "panels":
@@ -1211,7 +1228,7 @@ def plot_xyz(datafile, xcol=None, ycol=None, zcol=None, shape=None, xlim=None, y
12111228
}
12121229
coltypes = {"xlabel": c.xcol, "ylabel": c.ycol, "zlabel": c.zcol}
12131230
for k in coltypes:
1214-
if isinstance(coltypes[k], index_types):
1231+
if isinstance(coltypes[k], (int, str, _pattern_type)):
12151232
label = datafile._col_label(coltypes[k])
12161233
if isinstance(label, list):
12171234
label = ",".join(label)
@@ -1402,7 +1419,7 @@ def plot_xyzuvw(datafile, xcol=None, ycol=None, zcol=None, ucol=None, vcol=None,
14021419
projection = kargs.pop("projection", "3d")
14031420
coltypes = {"xlabel": c.xcol, "ylabel": c.ycol, "zlabel": c.zcol}
14041421
for k in coltypes:
1405-
if isinstance(coltypes[k], index_types):
1422+
if isinstance(coltypes[k], (int, str, _pattern_type)):
14061423
label = datafile._col_label(coltypes[k])
14071424
if isinstance(label, list):
14081425
label = ",".join(label)
@@ -1416,7 +1433,7 @@ def plot_xyzuvw(datafile, xcol=None, ycol=None, zcol=None, ucol=None, vcol=None,
14161433
colors = nonkargs.pop("color", True)
14171434
if isinstance(colors, bool) and colors:
14181435
pass
1419-
elif isinstance(colors, index_types):
1436+
elif isinstance(colors, (int, str, _pattern_type)):
14201437
colors = datafile.column(colors)
14211438
elif isinstance(colors, np.ndarray):
14221439
pass

0 commit comments

Comments
 (0)