diff --git a/docs/index.rst b/docs/index.rst index 2e41bd6..b4b438d 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -87,6 +87,7 @@ interpreter of your choice: neksuite.rst simsonsuite.rst vtksuite.rst + viz.rst dataset.rst meshtools.rst usage.myst.md diff --git a/docs/viz.rst b/docs/viz.rst new file mode 100644 index 0000000..042d029 --- /dev/null +++ b/docs/viz.rst @@ -0,0 +1,371 @@ +.. _viz: + +pymech.viz +========== + +3D mesh visualization using PyVista or Matplotlib backends. + +.. warning:: + + This subpackage requires visualization libraries. Install with:: + + pip install pymech[plot] + + This installs PyVista, Matplotlib, and required dependencies. + +Overview +-------- + +The ``pymech.viz`` subpackage provides modern mesh visualization capabilities with two backends: + +- **PyVista** (recommended): Interactive 3D visualization optimized for Jupyter notebooks +- **Matplotlib**: Publication-quality figures and simple 3D plots + +Architecture +~~~~~~~~~~~~ + +The visualization system uses a modular, protocol-based architecture within the ``pymech.viz`` subpackage: + +- ``pymech.viz.viz_protocol``: Defines the ``MeshBackend`` Protocol interface +- ``pymech.viz.pyvista_backend_impl``: PyVista-specific implementation +- ``pymech.viz.matplotlib_backend``: Matplotlib-specific implementation +- ``pymech.viz.pyvista_backend``: Backend dispatcher module + +The subpackage exports a unified API through ``pymech.viz``: + +- Main functions: ``plot_mesh()``, ``get_available_backends()`` +- PyVista utilities: ``hexa_to_pyvista()``, ``add_boundary_conditions()`` +- Protocol and constants: ``MeshBackend``, ``DEFAULT_BC_COLORS`` + +Backends implement the ``MeshBackend`` protocol using ``typing.Protocol`` and +``runtime_checkable``, ensuring consistent interfaces across different rendering engines. +The main ``plot_mesh()`` function automatically selects the best available backend or +uses the explicitly specified one + +Quick Start +----------- + +PyVista Backend (Interactive 3D) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: python + + import pymech as pm + from pymech.viz import plot_mesh + + # Load mesh + field = pm.readnek("channel3D_0.f00001") + + # Plot with boundary conditions + plot_mesh(field, backend='pyvista') + +Matplotlib Backend (Publication Figures) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: python + + # Plot with Matplotlib for static figure + plot_mesh(field, backend='matplotlib', view='xy', + screenshot='mesh_figure.png') + +Jupyter Notebook Usage +---------------------- + +PyVista automatically detects Jupyter environments and uses interactive backends: + +.. code-block:: python + + # In Jupyter notebook - automatically interactive + from pymech.viz import plot_mesh + + field = pm.readnek("mesh.nek5000") + plot_mesh(field, jupyter_backend="trame") + +Available Jupyter backends: + +- ``"trame"``: Interactive (default, requires trame package) +- ``"static"``: Non-interactive PNG/JPEG +- ``"ipyvtklink"``: Alternative interactive backend + +Features +-------- + +Mesh Resolution Modes +~~~~~~~~~~~~~~~~~~~~~ + +**Linear resolution** (default, fast): + +.. code-block:: python + + plot_mesh(field, resolution='linear') + +Uses only corner vertices (8 per hexahedron), suitable for most visualizations. + +**Spectral resolution** (accurate, slower): + +.. code-block:: python + + plot_mesh(field, resolution='spectral') + +Uses all GLL points, subdividing each element. Preserves curved geometries accurately. + +Boundary Condition Visualization +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Color-coded boundary conditions: + +.. code-block:: python + + # Show BCs (default) + plot_mesh(field, show_bcs=True, bc_field=0) + + # bc_field=0: velocity BCs + # bc_field=1: temperature BCs + +BC Color Scheme: + +- **W** (wall): Dark blue +- **v** (velocity BC): Light blue +- **O** (outflow): Dark red +- **T** (temperature): Dark green +- **E** (element): Black +- **P** (periodic): Gray + +Field Visualization +~~~~~~~~~~~~~~~~~~~ + +Visualize velocity, pressure, or temperature fields: + +.. code-block:: python + + from pymech.viz import hexa_to_pyvista + import pyvista as pv + + # Convert mesh with field data + mesh = hexa_to_pyvista(field, include_fields=True) + + # Plot velocity magnitude + plotter = pv.Plotter() + plotter.add_mesh(mesh, scalars="velocity_magnitude", + cmap="coolwarm", show_edges=True) + plotter.show() + +Custom Visualization +-------------------- + +Advanced customization with PyVista: + +.. code-block:: python + + # Return plotter for customization + plotter = plot_mesh(field, return_plotter=True) + + # Customize camera + plotter.camera_position = [(10, 10, 10), (0, 0, 0), (0, 1, 0)] + + # Add text annotation + plotter.add_text("Channel Flow Mesh", position="upper_right", + font_size=12) + + # Show + plotter.show() + +Export Screenshots +~~~~~~~~~~~~~~~~~~ + +.. code-block:: python + + # Export high-resolution image + plot_mesh(field, screenshot="mesh_hires.png") + + # Matplotlib for publication figures + plot_mesh(field, backend='matplotlib', + screenshot="mesh_pub.pdf", figsize=(12, 10)) + +Backend Comparison +------------------ + ++----------------------+---------------------+----------------------+ +| Feature | PyVista | Matplotlib | ++======================+=====================+======================+ +| Dimensionality | Full 3D | Limited 3D | ++----------------------+---------------------+----------------------+ +| Interactivity | Full rotation/zoom | Basic | ++----------------------+---------------------+----------------------+ +| Jupyter support | Excellent (trame) | Good | ++----------------------+---------------------+----------------------+ +| BC visualization | Face colors | Edge colors | ++----------------------+---------------------+----------------------+ +| Performance | Good for large mesh | Slow for large mesh | ++----------------------+---------------------+----------------------+ +| Use case | Exploration | Publication figures | ++----------------------+---------------------+----------------------+ + +**Recommendation**: Use PyVista for interactive exploration and Jupyter workflows. +Use Matplotlib for generating publication-quality 2D projections. + +API Reference +------------- + +Main Function +~~~~~~~~~~~~~ + +.. autofunction:: pymech.viz.plot_mesh + :noindex: + +Backend Discovery +~~~~~~~~~~~~~~~~~ + +.. autofunction:: pymech.viz.get_available_backends + :noindex: + +Conversion Functions (PyVista) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autofunction:: pymech.viz.hexa_to_pyvista + :noindex: + +.. autofunction:: pymech.viz.add_boundary_conditions + :noindex: + +Complete API +~~~~~~~~~~~~ + +.. automodule:: pymech.viz + :members: + :undoc-members: + :show-inheritance: + +Examples Gallery +---------------- + +Example 1: Check Available Backends +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: python + + from pymech.viz import get_available_backends + + backends = get_available_backends() + print(f"Available backends: {list(backends.keys())}") + + # Check capabilities + for name, backend in backends.items(): + caps = backend.get_capabilities() + print(f"{name}: 3D={caps['3d']}, Interactive={caps['interactive']}") + +Example 2: Basic Mesh Visualization +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: python + + import pymech as pm + from pymech.viz import plot_mesh + + field = pm.readnek("channel3D_0.f00001") + plot_mesh(field) + +Example 3: Custom Camera View +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: python + + plotter = plot_mesh(field, return_plotter=True) + plotter.camera_position = 'xy' # Top-down view + plotter.show() + +Example 4: Velocity Field Coloring +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: python + + from pymech.viz import hexa_to_pyvista + import pyvista as pv + + mesh = hexa_to_pyvista(field, include_fields=True) + mesh.plot(scalars="pressure", cmap="viridis", show_edges=True) + +Example 5: Publication Figure +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: python + + plot_mesh(field, backend='matplotlib', view='xz', + show_bcs=False, color='darkblue', + screenshot='figure.pdf', figsize=(8, 6)) + +Comparison with meshplot +------------------------- + +The existing ``meshplot.py`` module provides 2D visualization with wxPython. +Choose the appropriate tool for your use case: + ++-----------------------+---------------------+---------------------------+ +| Feature | meshplot (2D) | pymech.viz (3D) | ++=======================+=====================+===========================+ +| Dimensionality | 2D only | 2D and 3D | ++-----------------------+---------------------+---------------------------+ +| Curved edges | Exact (parabolic) | Linear approx. | ++-----------------------+---------------------+---------------------------+ +| Jupyter support | No | Yes (primary) | ++-----------------------+---------------------+---------------------------+ +| Boundary conditions | Edge colors | Face colors | ++-----------------------+---------------------+---------------------------+ +| Interactivity | Zoom/pan only | Full 3D rotation | ++-----------------------+---------------------+---------------------------+ +| Dependencies | wxPython, OpenGL | PyVista or Matplotlib | ++-----------------------+---------------------+---------------------------+ + +**Use ``meshplot``** for detailed 2D edge inspection with exact curved geometry. + +**Use ``pymech.viz``** for 3D exploration, Jupyter workflows, and publication figures. + +Troubleshooting +--------------- + +PyVista Not Available +~~~~~~~~~~~~~~~~~~~~~ + +If you get "PyVista is required" error: + +.. code-block:: bash + + pip install pymech[plot] + +or install PyVista separately: + +.. code-block:: bash + + pip install pyvista trame ipywidgets + +Headless Rendering +~~~~~~~~~~~~~~~~~~ + +For headless servers or CI/CD: + +.. code-block:: python + + import pyvista as pv + pv.OFF_SCREEN = True + + plot_mesh(field, screenshot="output.png") + +Jupyter Display Issues +~~~~~~~~~~~~~~~~~~~~~~ + +If plots don't show in Jupyter: + +.. code-block:: python + + import pyvista as pv + pv.set_jupyter_backend('static') # Try static backend + + # Or try different backend + plot_mesh(field, jupyter_backend='ipyvtklink') + +See Also +-------- + +- :ref:`vtksuite` - VTK export functions +- :ref:`meshtools` - Mesh manipulation utilities +- :ref:`dataset` - Xarray integration diff --git a/pyproject.toml b/pyproject.toml index 09291f4..00d94af 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,7 +34,13 @@ dependencies = [ [project.optional-dependencies] opt = ["dask", "rich"] vtk = ["mayavi", "pygments >= 2.16.1"] -full = ["pymech[opt,vtk]"] +plot = [ + "pyvista >= 0.43.0", + "trame >= 2.0.0", + "ipywidgets >= 8.0.0", + "matplotlib", +] +full = ["pymech[opt,vtk,plot]"] docs = [ "asv >= 0.5.1", "furo", @@ -47,7 +53,7 @@ docs = [ "sphinx-inline-tabs", ] tests = [ - "pymech[opt]", + "pymech[opt,plot]", "coverage[toml]", "pygments >= 2.16.1", "pytest >= 6.2.5", diff --git a/src/pymech/__init__.py b/src/pymech/__init__.py index 2d5a5b1..1078875 100644 --- a/src/pymech/__init__.py +++ b/src/pymech/__init__.py @@ -10,6 +10,7 @@ dataset meshtools log + viz """ @@ -26,3 +27,13 @@ warn(repr(err), ImportWarning) from ._version import __version__ # noqa + +# Optional visualization subpackage (PyVista + Matplotlib) +try: + from . import viz # noqa + + # Backward compatibility alias + pyvista_backend = viz +except ImportError: + # PyVista/Matplotlib not installed, visualization features unavailable + pass diff --git a/src/pymech/viz/__init__.py b/src/pymech/viz/__init__.py new file mode 100644 index 0000000..a7a8823 --- /dev/null +++ b/src/pymech/viz/__init__.py @@ -0,0 +1,60 @@ +"""Mesh visualization subpackage with multiple backend support. + +This subpackage provides a unified interface for mesh visualization using either +PyVista or Matplotlib backends, with automatic backend selection. + +Examples +-------- +PyVista backend (interactive 3D): + +>>> import pymech as pm +>>> from pymech.viz import plot_mesh +>>> +>>> field = pm.readnek("channel3D_0.f00001") +>>> plot_mesh(field, backend='pyvista') + +Matplotlib backend (publication figures): + +>>> plot_mesh(field, backend='matplotlib', view='xy') + +Auto-select best available backend: + +>>> plot_mesh(field, backend='auto') +""" + +# Import main API from pyvista_backend module (which is the dispatcher) +from .pyvista_backend import ( + add_boundary_conditions, + get_available_backends, + hexa_to_pyvista, + plot_mesh, +) + +# Import Protocol and backend classes for advanced use +from .viz_protocol import DEFAULT_BC_COLORS, MeshBackend + +# Optional imports for backend implementations +try: + from .pyvista_backend_impl import PyVistaBackend +except ImportError: + PyVistaBackend = None + +try: + from .matplotlib_backend import MatplotlibBackend +except ImportError: + MatplotlibBackend = None + +__all__ = ( + # Main API + "plot_mesh", + "get_available_backends", + # PyVista-specific functions + "hexa_to_pyvista", + "add_boundary_conditions", + # Protocol and colors + "MeshBackend", + "DEFAULT_BC_COLORS", + # Backend classes + "PyVistaBackend", + "MatplotlibBackend", +) diff --git a/src/pymech/viz/matplotlib_backend.py b/src/pymech/viz/matplotlib_backend.py new file mode 100644 index 0000000..ee7edf4 --- /dev/null +++ b/src/pymech/viz/matplotlib_backend.py @@ -0,0 +1,239 @@ +"""Matplotlib backend implementation for mesh visualization. + +This module provides Matplotlib-based mesh visualization for publication-quality +static figures and basic 3D plots. +""" + +from typing import Any, Optional, Tuple + +import numpy as np + +from ..core import HexaData +from ..log import logger +from .viz_protocol import ( + DEFAULT_BC_COLORS, + Color, + Colormap, + Resolution, + View, +) + +# Try importing Matplotlib +try: + import matplotlib.pyplot as plt + from mpl_toolkits.mplot3d import Axes3D + + MATPLOTLIB_AVAILABLE = True +except ImportError: + MATPLOTLIB_AVAILABLE = False + plt = None + Axes3D = None + +__all__ = ("MatplotlibBackend",) + + +class MatplotlibBackend: + """Matplotlib visualization backend implementation.""" + + def is_available(self) -> bool: + """Check if Matplotlib is available.""" + return MATPLOTLIB_AVAILABLE + + def get_backend_name(self) -> str: + """Get backend name.""" + return "matplotlib" + + def get_capabilities(self) -> dict: + """Get backend capabilities.""" + return { + "interactive": False, + "3d": True, # Limited 3D support + "jupyter": True, + "headless": True, + "formats": ["png", "jpg", "pdf", "svg", "eps", "ps"], + } + + def plot_mesh( + self, + field: HexaData, + resolution: Resolution = "linear", + show_bcs: bool = True, + bc_field: int = 0, + show_edges: bool = True, + style: str = "surface", + color: Color = None, + cmap: Colormap = None, + view: View = None, + screenshot: Optional[str] = None, + return_plotter: bool = False, + figsize: Tuple[float, float] = (10, 8), + **kwargs, + ) -> Optional[Any]: + """Plot mesh using Matplotlib. + + Parameters documented in MeshBackend Protocol. + + Notes + ----- + Matplotlib backend has limited 3D capabilities compared to PyVista. + Best used for publication-quality 2D projections. + """ + if not MATPLOTLIB_AVAILABLE: + raise ImportError( + "Matplotlib required for this backend. Install with:\n" + " pip install matplotlib" + ) + + # Filter out PyVista-specific kwargs that matplotlib doesn't understand + matplotlib_kwargs = kwargs.copy() + for key in ["jupyter_backend", "window_size", "notebook", "off_screen"]: + matplotlib_kwargs.pop(key, None) + + fig = plt.figure(figsize=figsize) + ax = fig.add_subplot(111, projection="3d") + + # Extract edges from elements + logger.info("Extracting mesh edges for Matplotlib...") + + for iel, elem in enumerate(field.elem): + # Get element dimensions + lx, ly, lz = elem.pos.shape[3], elem.pos.shape[2], elem.pos.shape[1] + + # Define edges based on dimensionality + if field.ndim == 3: + edges = _get_hex_edges_3d() + else: # 2D + edges = _get_quad_edges_2d() + + # Get BC for this element if needed + if show_bcs: + edge_color = _get_element_bc_color(elem, bc_field) + elif color: + edge_color = color + else: + edge_color = "blue" + + # Plot each edge + for (ix1, iy1, iz1), (ix2, iy2, iz2) in edges: + p1 = elem.pos[:, iz1, iy1, ix1] + p2 = elem.pos[:, iz2, iy2, ix2] + + ax.plot( + [p1[0], p2[0]], + [p1[1], p2[1]], + [p1[2], p2[2]], + color=edge_color, + linewidth=0.5, + **matplotlib_kwargs, + ) + + # Set labels + ax.set_xlabel("X") + ax.set_ylabel("Y") + ax.set_zlabel("Z") + + # Set view angle + if view == "xy": + ax.view_init(elev=90, azim=0) + elif view == "xz": + ax.view_init(elev=0, azim=0) + elif view == "yz": + ax.view_init(elev=0, azim=90) + elif view: + # Try to parse as (elev, azim) + try: + elev, azim = map(float, view.split(",")) + ax.view_init(elev=elev, azim=azim) + except (ValueError, AttributeError): + pass # Use default view + + # Equal aspect ratio + _set_axes_equal(ax) + + plt.tight_layout() + + # Save screenshot + if screenshot: + plt.savefig(screenshot, dpi=300, bbox_inches="tight") + logger.info(f"Screenshot saved to {screenshot}") + + if return_plotter: + return fig + else: + plt.show() + return None + + +def _get_hex_edges_3d() -> list: + """Get edge definitions for 3D hexahedron.""" + return [ + # Bottom face + ((0, 0, 0), (-1, 0, 0)), + ((-1, 0, 0), (-1, -1, 0)), + ((-1, -1, 0), (0, -1, 0)), + ((0, -1, 0), (0, 0, 0)), + # Top face + ((0, 0, -1), (-1, 0, -1)), + ((-1, 0, -1), (-1, -1, -1)), + ((-1, -1, -1), (0, -1, -1)), + ((0, -1, -1), (0, 0, -1)), + # Vertical edges + ((0, 0, 0), (0, 0, -1)), + ((-1, 0, 0), (-1, 0, -1)), + ((-1, -1, 0), (-1, -1, -1)), + ((0, -1, 0), (0, -1, -1)), + ] + + +def _get_quad_edges_2d() -> list: + """Get edge definitions for 2D quadrilateral.""" + return [ + ((0, 0, 0), (-1, 0, 0)), + ((-1, 0, 0), (-1, -1, 0)), + ((-1, -1, 0), (0, -1, 0)), + ((0, -1, 0), (0, 0, 0)), + ] + + +def _get_element_bc_color(elem, bc_field: int) -> tuple: + """Get representative BC color for an element. + + Since matplotlib edge-based rendering doesn't distinguish faces, + we use the first non-empty BC color found. + """ + try: + for iface in range(elem.bcs.shape[1]): + bc_type = elem.bcs[bc_field, iface][0] + if bc_type and bc_type != "E": + color = DEFAULT_BC_COLORS.get(bc_type) + if color: + return color + except (IndexError, KeyError): + pass + + # Default to black + return (0.0, 0.0, 0.0) + + +def _set_axes_equal(ax: "Axes3D") -> None: + """Set 3D plot axes to equal scale. + + Parameters + ---------- + ax : Axes3D + Matplotlib 3D axes object + """ + limits = np.array( + [ + ax.get_xlim3d(), + ax.get_ylim3d(), + ax.get_zlim3d(), + ] + ) + + origin = np.mean(limits, axis=1) + radius = 0.5 * np.max(np.abs(limits[:, 1] - limits[:, 0])) + + ax.set_xlim3d([origin[0] - radius, origin[0] + radius]) + ax.set_ylim3d([origin[1] - radius, origin[1] + radius]) + ax.set_zlim3d([origin[2] - radius, origin[2] + radius]) diff --git a/src/pymech/viz/pyvista_backend.py b/src/pymech/viz/pyvista_backend.py new file mode 100644 index 0000000..8c9eb01 --- /dev/null +++ b/src/pymech/viz/pyvista_backend.py @@ -0,0 +1,300 @@ +"""Unified mesh visualization API with multiple backend support. + +This module provides a unified interface for mesh visualization using either +PyVista or Matplotlib backends, with automatic backend selection. + +Examples +-------- +PyVista backend (interactive 3D): + +>>> import pymech as pm +>>> from pymech.pyvista_backend import plot_mesh +>>> +>>> field = pm.readnek("channel3D_0.f00001") +>>> plot_mesh(field, backend='pyvista') + +Matplotlib backend (publication figures): + +>>> plot_mesh(field, backend='matplotlib', view='xy') + +Auto-select best available backend: + +>>> plot_mesh(field, backend='auto') + +""" + +import warnings +from typing import Any, Literal, Optional, Tuple + +from ..core import HexaData +from ..log import logger +from .viz_protocol import MeshBackend + +# Import backend implementations +try: + from .pyvista_backend_impl import ( + PyVistaBackend, + add_boundary_conditions, + hexa_to_pyvista, + ) + + PYVISTA_BACKEND_AVAILABLE = True +except ImportError: + PYVISTA_BACKEND_AVAILABLE = False + hexa_to_pyvista = None + add_boundary_conditions = None + +try: + from .matplotlib_backend import MatplotlibBackend + + MATPLOTLIB_BACKEND_AVAILABLE = True +except ImportError: + MATPLOTLIB_BACKEND_AVAILABLE = False + +__all__ = ( + "plot_mesh", + "hexa_to_pyvista", + "add_boundary_conditions", + "get_available_backends", +) + + +def get_available_backends() -> dict: + """Get dictionary of available visualization backends. + + Returns + ------- + dict + Mapping of backend names to backend instances (only available backends) + + Examples + -------- + >>> backends = get_available_backends() + >>> print(f"Available backends: {list(backends.keys())}") + Available backends: ['pyvista', 'matplotlib'] + """ + backends = {} + + if PYVISTA_BACKEND_AVAILABLE: + pv_backend = PyVistaBackend() + if pv_backend.is_available(): + backends["pyvista"] = pv_backend + + if MATPLOTLIB_BACKEND_AVAILABLE: + mpl_backend = MatplotlibBackend() + if mpl_backend.is_available(): + backends["matplotlib"] = mpl_backend + + return backends + + +def _get_backend(backend_name: str) -> MeshBackend: + """Get a specific backend instance. + + Parameters + ---------- + backend_name : str + Name of backend ('pyvista', 'matplotlib', or 'auto') + + Returns + ------- + MeshBackend + Backend instance + + Raises + ------ + ValueError + If backend name is invalid + ImportError + If requested backend is not available + """ + available_backends = get_available_backends() + + if backend_name == "auto": + # Prefer PyVista if available + if "pyvista" in available_backends: + return available_backends["pyvista"] + elif "matplotlib" in available_backends: + return available_backends["matplotlib"] + else: + raise ImportError( + "No visualization backend available. Install with:\n" + " pip install pymech[plot]" + ) + elif backend_name in available_backends: + return available_backends[backend_name] + elif backend_name in ("pyvista", "matplotlib"): + # Backend name is valid but not available + raise ImportError( + f"{backend_name} backend not available. Install with:\n" + f" pip install pymech[plot]" + ) + else: + available = list(available_backends.keys()) + raise ValueError( + f"Invalid backend '{backend_name}'. " + f"Must be 'pyvista', 'matplotlib', or 'auto'. " + f"Available: {available}" + ) + + +def plot_mesh( + field: HexaData, + backend: Literal["pyvista", "matplotlib", "auto"] = "auto", + resolution: Literal["linear", "spectral"] = "linear", + show_bcs: bool = True, + bc_field: int = 0, + show_edges: bool = True, + style: str = "surface", + color: Optional[str] = None, + cmap: Optional[str] = None, + jupyter_backend: str = "trame", + view: Optional[str] = None, + screenshot: Optional[str] = None, + return_plotter: bool = False, + figsize: Tuple[float, float] = (10, 8), + **kwargs, +) -> Optional[Any]: + """Plot 3D mesh with boundary conditions using PyVista or Matplotlib. + + This is the main entry point for mesh visualization. It automatically + selects the best available backend or uses the specified one. + + Parameters + ---------- + field : HexaData + Mesh to visualize + backend : {'pyvista', 'matplotlib', 'auto'}, default='auto' + Visualization backend. 'auto' prefers PyVista if available + resolution : {'linear', 'spectral'}, default='linear' + Mesh resolution: 'linear' (corners only) or 'spectral' (all GLL points) + Note: Only affects PyVista backend + show_bcs : bool, default=True + Whether to color faces/edges by boundary conditions + bc_field : int, default=0 + Which BC field to visualize (0=velocity, 1=temperature, ...) + show_edges : bool, default=True + Whether to show mesh edges + style : str, default='surface' + Visualization style: 'surface', 'wireframe', 'points' (PyVista only) + color : str, optional + Uniform color if not showing BCs (e.g., 'white', '#3498db') + cmap : str, optional + Colormap for scalar fields (e.g., 'viridis', 'coolwarm') + jupyter_backend : str, default='trame' + Backend for Jupyter: 'trame' (interactive), 'static', 'ipyvtklink' + (PyVista only) + view : str, optional + Camera view: 'xy', 'xz', 'yz', 'iso' (PyVista) or similar for Matplotlib + screenshot : str, optional + Save screenshot to this filename + return_plotter : bool, default=False + Return plotter/figure object for further customization + figsize : tuple, default=(10, 8) + Figure size for Matplotlib backend + **kwargs + Additional arguments passed to backend's mesh plotting function + + Returns + ------- + plotter : pv.Plotter or matplotlib.Figure, optional + Plotter/figure object if return_plotter=True + + Raises + ------ + ImportError + If no visualization backend is available + ValueError + If invalid backend name is specified + + Examples + -------- + Basic usage with auto backend selection: + + >>> import pymech as pm + >>> from pymech.pyvista_backend import plot_mesh + >>> field = pm.readnek("mesh.nek5000") + >>> plot_mesh(field) + + Using specific backend: + + >>> plot_mesh(field, backend='matplotlib', view='xy') + + Customizing PyVista visualization: + + >>> plotter = plot_mesh(field, backend='pyvista', return_plotter=True) + >>> plotter.camera_position = [(10, 10, 10), (0, 0, 0), (0, 1, 0)] + >>> plotter.show() + + Saving high-resolution figure: + + >>> plot_mesh(field, backend='matplotlib', screenshot='mesh.pdf', + ... figsize=(12, 10)) + + See Also + -------- + get_available_backends : Check which backends are available + hexa_to_pyvista : Convert HexaData to PyVista format (PyVista backend only) + add_boundary_conditions : Add BC data to mesh (PyVista backend only) + + """ + # Get backend instance + backend_instance = _get_backend(backend) + + logger.info( + f"Using {backend_instance.get_backend_name()} backend for visualization" + ) + + # Call backend's plot_mesh method + return backend_instance.plot_mesh( + field=field, + resolution=resolution, + show_bcs=show_bcs, + bc_field=bc_field, + show_edges=show_edges, + style=style, + color=color, + cmap=cmap, + view=view, + screenshot=screenshot, + return_plotter=return_plotter, + figsize=figsize, + jupyter_backend=jupyter_backend, # Pass through kwargs + **kwargs, + ) + + +# Maintain backward compatibility: export PyVista-specific functions if available +if not PYVISTA_BACKEND_AVAILABLE: + + def hexa_to_pyvista(*args, **kwargs): + """PyVista not available.""" + raise ImportError( + "PyVista backend not available. Install with:\n" + " pip install pymech[plot]" + ) + + def add_boundary_conditions(*args, **kwargs): + """PyVista not available.""" + raise ImportError( + "PyVista backend not available. Install with:\n" + " pip install pymech[plot]" + ) + + +# Module-level convenience: show available backends on import +def _show_backend_info(): + """Display information about available backends (suppressed by default).""" + backends = get_available_backends() + if backends: + logger.debug(f"Available visualization backends: {list(backends.keys())}") + else: + warnings.warn( + "No visualization backends available. Install with: pip install pymech[plot]", + ImportWarning, + stacklevel=2, + ) + + +# Don't show info by default to avoid clutter +# _show_backend_info() diff --git a/src/pymech/viz/pyvista_backend_impl.py b/src/pymech/viz/pyvista_backend_impl.py new file mode 100644 index 0000000..d68bf90 --- /dev/null +++ b/src/pymech/viz/pyvista_backend_impl.py @@ -0,0 +1,451 @@ +"""PyVista backend implementation for mesh visualization. + +This module provides PyVista-specific mesh visualization, optimized for +interactive 3D rendering in Jupyter notebooks. +""" + +from typing import Any, Optional, Tuple + +import numpy as np + +from ..core import HexaData +from ..log import logger +from .viz_protocol import ( + DEFAULT_BC_COLORS, + Color, + Colormap, + Resolution, + View, + compute_face_center, +) + +# Try importing PyVista +try: + import pyvista as pv + + PYVISTA_AVAILABLE = True +except ImportError: + PYVISTA_AVAILABLE = False + pv = None + +__all__ = ("PyVistaBackend", "hexa_to_pyvista", "add_boundary_conditions") + + +class PyVistaBackend: + """PyVista visualization backend implementation.""" + + def is_available(self) -> bool: + """Check if PyVista is available.""" + return PYVISTA_AVAILABLE + + def get_backend_name(self) -> str: + """Get backend name.""" + return "pyvista" + + def get_capabilities(self) -> dict: + """Get backend capabilities.""" + return { + "interactive": True, + "3d": True, + "jupyter": True, + "headless": True, + "formats": ["png", "jpg", "bmp", "tif", "svg", "eps", "ps", "pdf", "tex"], + } + + def plot_mesh( + self, + field: HexaData, + resolution: Resolution = "linear", + show_bcs: bool = True, + bc_field: int = 0, + show_edges: bool = True, + style: str = "surface", + color: Color = None, + cmap: Colormap = None, + view: View = None, + screenshot: Optional[str] = None, + return_plotter: bool = False, + figsize: Tuple[float, float] = (10, 8), + **kwargs, + ) -> Optional[Any]: + """Plot mesh using PyVista. + + Parameters documented in MeshBackend Protocol. + """ + if not PYVISTA_AVAILABLE: + raise ImportError( + "PyVista required for this backend. Install with:\n" + " pip install pymech[plot]" + ) + + # Convert mesh + logger.info(f"Converting HexaData to PyVista mesh (resolution={resolution})...") + mesh = hexa_to_pyvista(field, resolution=resolution, include_fields=True) + + # Auto-detect Jupyter environment + try: + from IPython import get_ipython + + if get_ipython() is not None and "IPKernelApp" in get_ipython().config: + in_notebook = True + else: + in_notebook = False + except (ImportError, AttributeError): + in_notebook = False + + # Extract jupyter_backend from kwargs + jupyter_backend = kwargs.pop("jupyter_backend", "trame") + + # Setup plotter + if in_notebook: + pv.set_jupyter_backend(jupyter_backend) + plotter = pv.Plotter(notebook=True) + else: + plotter = pv.Plotter() + + # Add mesh with BCs + if show_bcs: + logger.info("Extracting boundary conditions...") + surface = add_boundary_conditions(mesh, field, bc_field) + + # Color by BC + plotter.add_mesh( + surface, + scalars="bc_color", + rgb=True, + show_edges=show_edges, + style=style, + **kwargs, + ) + + # Add legend for BC types + _add_bc_legend(plotter, surface) + else: + # Simple mesh without BC coloring + mesh_kwargs = {"show_edges": show_edges, "style": style} + if color: + mesh_kwargs["color"] = color + if cmap: + mesh_kwargs["cmap"] = cmap + mesh_kwargs.update(kwargs) + + plotter.add_mesh(mesh, **mesh_kwargs) + + # Set camera and axes + plotter.add_axes() + if view: + plotter.camera_position = view + else: + plotter.camera_position = "iso" + + # Show or save + if screenshot: + plotter.show(screenshot=screenshot, auto_close=False) + logger.info(f"Screenshot saved to {screenshot}") + + if return_plotter: + return plotter + else: + plotter.show() + return None + + +def hexa_to_pyvista( + field: HexaData, + resolution: Resolution = "linear", + include_fields: bool = True, +) -> "pv.UnstructuredGrid": + """Convert HexaData to PyVista UnstructuredGrid. + + Parameters + ---------- + field : HexaData + Mesh data structure from pymech + resolution : {'linear', 'spectral'}, default='linear' + Mesh resolution strategy + include_fields : bool, default=True + Whether to include velocity, pressure, temperature as point data + + Returns + ------- + mesh : pv.UnstructuredGrid + PyVista mesh with optional field data + """ + if not PYVISTA_AVAILABLE: + raise ImportError("PyVista required. Install with: pip install pyvista") + + if resolution == "linear": + return _hexa_to_pyvista_linear(field, include_fields) + elif resolution == "spectral": + return _hexa_to_pyvista_spectral(field, include_fields) + else: + raise ValueError( + f"resolution must be 'linear' or 'spectral', got '{resolution}'" + ) + + +def _hexa_to_pyvista_linear( + field: HexaData, include_fields: bool +) -> "pv.UnstructuredGrid": + """Convert using only corner vertices (fast, approximate).""" + nel = field.nel + ndim = field.ndim + + # Determine cell type and vertex indices + if ndim == 3: + nvert = 8 + cell_type = pv.CellType.HEXAHEDRON + vertex_indices = [ + (0, 0, 0), + (-1, 0, 0), + (-1, -1, 0), + (0, -1, 0), # bottom face + (0, 0, -1), + (-1, 0, -1), + (-1, -1, -1), + (0, -1, -1), # top face + ] + else: # 2D + nvert = 4 + cell_type = pv.CellType.QUAD + vertex_indices = [(0, 0, 0), (-1, 0, 0), (-1, -1, 0), (0, -1, 0)] + + # Allocate arrays + total_points = nel * nvert + points = np.zeros((total_points, 3), dtype=field.elem[0].pos.dtype) + + # Build connectivity + cells_list = [] + for i in range(nel): + cell = [nvert] + list(range(i * nvert, (i + 1) * nvert)) + cells_list.extend(cell) + cells = np.array(cells_list, dtype=np.int64) + + # Extract corner vertices + for iel, elem in enumerate(field.elem): + for ivert, (ix, iy, iz) in enumerate(vertex_indices): + points[iel * nvert + ivert] = elem.pos[:, iz, iy, ix] + + # Create UnstructuredGrid + cell_types = np.full(nel, cell_type, dtype=np.uint8) + mesh = pv.UnstructuredGrid(cells, cell_types, points) + + # Add field data + if include_fields: + _add_field_data(mesh, field, vertex_indices, nvert, total_points) + + # Add element IDs + mesh.cell_data["element_id"] = np.arange(nel) + + return mesh + + +def _hexa_to_pyvista_spectral( + field: HexaData, include_fields: bool +) -> "pv.UnstructuredGrid": + """Convert using all GLL points (slow, accurate).""" + nel = field.nel + ndim = field.ndim + lx, ly, lz = field.lr1 + + nppel = lx * ly * lz + if ndim == 3: + ncpel = (lx - 1) * (ly - 1) * (lz - 1) + cell_type = pv.CellType.HEXAHEDRON + else: + ncpel = (lx - 1) * (ly - 1) + cell_type = pv.CellType.QUAD + + total_points = nel * nppel + total_cells = nel * ncpel + + points = np.zeros((total_points, 3), dtype=field.elem[0].pos.dtype) + cells_list = [] + + # Extract all GLL points + for iel, elem in enumerate(field.elem): + for iz in range(lz): + for iy in range(ly): + for ix in range(lx): + ipt = iel * nppel + ix + iy * lx + iz * lx * ly + points[ipt] = elem.pos[:, iz, iy, ix] + + # Build connectivity + for iel in range(nel): + base_pt = iel * nppel + if ndim == 3: + for iz in range(lz - 1): + for iy in range(ly - 1): + for ix in range(lx - 1): + v0 = base_pt + ix + iy * lx + iz * lx * ly + v1 = v0 + 1 + v2 = v0 + lx + 1 + v3 = v0 + lx + v4 = v0 + lx * ly + v5 = v4 + 1 + v6 = v4 + lx + 1 + v7 = v4 + lx + cells_list.extend([8, v0, v1, v2, v3, v4, v5, v6, v7]) + else: # 2D + for iy in range(ly - 1): + for ix in range(lx - 1): + v0 = base_pt + ix + iy * lx + v1 = v0 + 1 + v2 = v0 + lx + 1 + v3 = v0 + lx + cells_list.extend([4, v0, v1, v2, v3]) + + cells = np.array(cells_list, dtype=np.int64) + cell_types = np.full(total_cells, cell_type, dtype=np.uint8) + mesh = pv.UnstructuredGrid(cells, cell_types, points) + + # Add field data + if include_fields: + _add_spectral_field_data(mesh, field, lx, ly, lz, nppel, total_points) + + return mesh + + +def _add_field_data(mesh, field, vertex_indices, nvert, total_points): + """Add field data to linear mesh.""" + # Velocity + if field.var[1] == 3: + vel = np.zeros((total_points, 3), dtype=field.elem[0].vel.dtype) + for iel, elem in enumerate(field.elem): + for ivert, (ix, iy, iz) in enumerate(vertex_indices): + vel[iel * nvert + ivert] = elem.vel[:, iz, iy, ix] + mesh.point_data["velocity"] = vel + mesh.point_data["velocity_magnitude"] = np.linalg.norm(vel, axis=1) + + # Pressure + if field.var[2] == 1: + pres = np.zeros(total_points, dtype=field.elem[0].pres.dtype) + for iel, elem in enumerate(field.elem): + for ivert, (ix, iy, iz) in enumerate(vertex_indices): + pres[iel * nvert + ivert] = elem.pres[0, iz, iy, ix] + mesh.point_data["pressure"] = pres + + # Temperature + if field.var[3] == 1: + temp = np.zeros(total_points, dtype=field.elem[0].temp.dtype) + for iel, elem in enumerate(field.elem): + for ivert, (ix, iy, iz) in enumerate(vertex_indices): + temp[iel * nvert + ivert] = elem.temp[0, iz, iy, ix] + mesh.point_data["temperature"] = temp + + +def _add_spectral_field_data(mesh, field, lx, ly, lz, nppel, total_points): + """Add field data to spectral mesh.""" + if field.var[1] == 3: + vel = np.zeros((total_points, 3), dtype=field.elem[0].vel.dtype) + for iel, elem in enumerate(field.elem): + for iz in range(lz): + for iy in range(ly): + for ix in range(lx): + ipt = iel * nppel + ix + iy * lx + iz * lx * ly + vel[ipt] = elem.vel[:, iz, iy, ix] + mesh.point_data["velocity"] = vel + mesh.point_data["velocity_magnitude"] = np.linalg.norm(vel, axis=1) + + if field.var[2] == 1: + pres = np.zeros(total_points, dtype=field.elem[0].pres.dtype) + for iel, elem in enumerate(field.elem): + for iz in range(lz): + for iy in range(ly): + for ix in range(lx): + ipt = iel * nppel + ix + iy * lx + iz * lx * ly + pres[ipt] = elem.pres[0, iz, iy, ix] + mesh.point_data["pressure"] = pres + + if field.var[3] == 1: + temp = np.zeros(total_points, dtype=field.elem[0].temp.dtype) + for iel, elem in enumerate(field.elem): + for iz in range(lz): + for iy in range(ly): + for ix in range(lx): + ipt = iel * nppel + ix + iy * lx + iz * lx * ly + temp[ipt] = elem.temp[0, iz, iy, ix] + mesh.point_data["temperature"] = temp + + +def add_boundary_conditions( + mesh: "pv.UnstructuredGrid", + field: HexaData, + bc_field: int = 0, +) -> "pv.PolyData": + """Add boundary condition information to mesh surface. + + Parameters + ---------- + mesh : pv.UnstructuredGrid + Mesh from hexa_to_pyvista + field : HexaData + Original data with BC information + bc_field : int, default=0 + Which BC field to visualize + + Returns + ------- + surface : pv.PolyData + Surface mesh with BC data + """ + if not PYVISTA_AVAILABLE: + raise ImportError("PyVista required") + + # Extract surface + surface = mesh.extract_surface() + + # Initialize BC arrays + n_faces = surface.n_cells + bc_types = np.empty(n_faces, dtype=" str: + """Find BC type for a face.""" + tol = 1e-4 + ndim = field.ndim + nfaces = 2 * ndim + + for iel, elem in enumerate(field.elem): + for iface in range(nfaces): + face_center = compute_face_center(elem, iface, ndim) + dist = np.linalg.norm(center - face_center) + + if dist < tol: + try: + bc = elem.bcs[bc_field, iface][0] + return bc if bc else "" + except (IndexError, KeyError): + return "" + + return "" + + +def _add_bc_legend(plotter: "pv.Plotter", surface: "pv.PolyData") -> None: + """Add BC legend to plotter.""" + unique_bcs = np.unique(surface.cell_data["bc_type"]) + + legend_entries = [] + for bc in unique_bcs: + if bc and bc in DEFAULT_BC_COLORS: + color = DEFAULT_BC_COLORS[bc] + legend_entries.append([bc, color]) + + if legend_entries: + plotter.add_legend(legend_entries, bcolor="white", size=(0.15, 0.15)) diff --git a/src/pymech/viz/viz_protocol.py b/src/pymech/viz/viz_protocol.py new file mode 100644 index 0000000..d6b45f2 --- /dev/null +++ b/src/pymech/viz/viz_protocol.py @@ -0,0 +1,198 @@ +"""Protocol definition for mesh visualization backends. + +This module defines the interface that all visualization backends must implement, +using typing.Protocol for structural subtyping. +""" + +from typing import Any, Literal, Optional, Protocol, Tuple, runtime_checkable + +import numpy as np +from typing_extensions import TypeAlias + +from ..core import HexaData + +# Type aliases +Resolution: TypeAlias = Literal["linear", "spectral"] +View: TypeAlias = Optional[str] +Color: TypeAlias = Optional[str] +Colormap: TypeAlias = Optional[str] + +# BC color scheme - shared across all backends +DEFAULT_BC_COLORS = { + "": (0.0, 0.0, 0.0), # Default/empty - black + "E": (0.0, 0.0, 0.0), # Element connectivity - black + "W": (0.0, 0.0, 0.8), # Wall - dark blue + "v": (0.3, 0.3, 1.0), # Velocity BC - light blue + "O": (0.8, 0.0, 0.0), # Outflow - dark red + "o": (1.0, 0.2, 0.2), # Outflow variant - red + "ON": (0.8, 0.4, 0.0), # Open Neumann - dark orange + "on": (1.0, 0.6, 0.0), # Open Neumann variant - orange + "T": (0.0, 0.8, 0.0), # Temperature BC - dark green + "t": (0.3, 1.0, 0.3), # Temperature variant - green + "I": (0.95, 0.1, 0.6), # Insulated - magenta + "P": (0.5, 0.5, 0.5), # Periodic - gray +} + + +@runtime_checkable +class MeshBackend(Protocol): + """Protocol for mesh visualization backends. + + All visualization backends must implement this interface to ensure + consistent API across different rendering engines. + """ + + def is_available(self) -> bool: + """Check if this backend is available (dependencies installed). + + Returns + ------- + bool + True if backend can be used, False otherwise + """ + ... + + def plot_mesh( + self, + field: HexaData, + resolution: Resolution = "linear", + show_bcs: bool = True, + bc_field: int = 0, + show_edges: bool = True, + style: str = "surface", + color: Color = None, + cmap: Colormap = None, + view: View = None, + screenshot: Optional[str] = None, + return_plotter: bool = False, + figsize: Tuple[float, float] = (10, 8), + **kwargs, + ) -> Optional[Any]: + """Plot mesh with boundary conditions. + + Parameters + ---------- + field : HexaData + Mesh to visualize + resolution : {'linear', 'spectral'}, default='linear' + Mesh resolution + show_bcs : bool, default=True + Whether to color by boundary conditions + bc_field : int, default=0 + Which BC field to visualize + show_edges : bool, default=True + Whether to show mesh edges + style : str, default='surface' + Visualization style + color : str, optional + Uniform color + cmap : str, optional + Colormap for scalar fields + view : str, optional + Camera view angle + screenshot : str, optional + Save screenshot to file + return_plotter : bool, default=False + Return plotter/figure object + figsize : tuple, default=(10, 8) + Figure size + **kwargs + Backend-specific options + + Returns + ------- + plotter : optional + Plotter/figure object if return_plotter=True + """ + ... + + def get_backend_name(self) -> str: + """Get the name of this backend. + + Returns + ------- + str + Backend name (e.g., 'pyvista', 'matplotlib') + """ + ... + + def get_capabilities(self) -> dict: + """Get backend capabilities. + + Returns + ------- + dict + Dictionary describing backend features: + - 'interactive': bool - supports interactive manipulation + - '3d': bool - supports true 3D rendering + - 'jupyter': bool - works in Jupyter notebooks + - 'headless': bool - supports headless rendering + - 'formats': list - supported export formats + """ + ... + + +def get_bc_color(bc_type: str) -> Tuple[float, float, float]: + """Get RGB color for a boundary condition type. + + Parameters + ---------- + bc_type : str + Boundary condition type identifier + + Returns + ------- + tuple + RGB color tuple (values 0-1) + """ + return DEFAULT_BC_COLORS.get(bc_type, (0.5, 0.5, 0.5)) + + +def compute_face_center(elem, iface: int, ndim: int) -> np.ndarray: + """Compute center of a face for an element. + + This is a shared utility function used by multiple backends. + + Parameters + ---------- + elem : Elem + Element object + iface : int + Face index + ndim : int + Number of dimensions (2 or 3) + + Returns + ------- + np.ndarray + 3D coordinates of face center + """ + lx, ly, lz = elem.pos.shape[3], elem.pos.shape[2], elem.pos.shape[1] + + if ndim == 3: + # 3D: 6 faces (x-, x+, y-, y+, z-, z+) + if iface == 0: # x- face + face_pts = elem.pos[:, :, :, 0] + elif iface == 1: # x+ face + face_pts = elem.pos[:, :, :, -1] + elif iface == 2: # y- face + face_pts = elem.pos[:, :, 0, :] + elif iface == 3: # y+ face + face_pts = elem.pos[:, :, -1, :] + elif iface == 4: # z- face + face_pts = elem.pos[:, 0, :, :] + else: # iface == 5, z+ face + face_pts = elem.pos[:, -1, :, :] + else: # 2D + # 2D: 4 faces (x-, x+, y-, y+) + if iface == 0: # x- face + face_pts = elem.pos[:, 0, :, 0] + elif iface == 1: # x+ face + face_pts = elem.pos[:, 0, :, -1] + elif iface == 2: # y- face + face_pts = elem.pos[:, 0, 0, :] + else: # iface == 3, y+ face + face_pts = elem.pos[:, 0, -1, :] + + # Return mean of all points on the face + return face_pts.mean(axis=tuple(range(1, face_pts.ndim))) diff --git a/tests/conftest.py b/tests/conftest.py index 417fc6f..4f9c9e1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,3 +6,20 @@ @pytest.fixture(scope="session") def test_data_dir(): return Path(__file__).parent / "data" + + +def pytest_addoption(parser): + # https://pytest.readthedocs.io/en/latest/example/simple.html#control-skipping-of-tests-according-to-command-line-option + parser.addoption( + "--runslow", action="store_true", default=False, help="run slow tests" + ) + + +def pytest_collection_modifyitems(config, items): + if config.getoption("--runslow"): + # --runslow given in cli: do not skip slow tests + return + skip_slow = pytest.mark.skip(reason="need --runslow option to run") + for item in items: + if "slow" in item.keywords: + item.add_marker(skip_slow) diff --git a/tests/test_pyvista.py b/tests/test_pyvista.py new file mode 100644 index 0000000..bed6665 --- /dev/null +++ b/tests/test_pyvista.py @@ -0,0 +1,500 @@ +"""Tests for PyVista visualization backend.""" + + +import pytest + +# Try importing backends +try: + import pyvista as pv + + PYVISTA_AVAILABLE = True +except ImportError: + PYVISTA_AVAILABLE = False + pv = None + +try: + import matplotlib.pyplot as plt + + MATPLOTLIB_AVAILABLE = True +except ImportError: + MATPLOTLIB_AVAILABLE = False + plt = None + +# Try importing Protocol and backend modules +try: + from pymech.viz.viz_protocol import MeshBackend + + PROTOCOL_AVAILABLE = True +except ImportError: + PROTOCOL_AVAILABLE = False + +try: + from pymech.viz.pyvista_backend_impl import PyVistaBackend + + PYVISTA_IMPL_AVAILABLE = True +except ImportError: + PYVISTA_IMPL_AVAILABLE = False + +try: + from pymech.viz.matplotlib_backend import MatplotlibBackend + + MATPLOTLIB_IMPL_AVAILABLE = True +except ImportError: + MATPLOTLIB_IMPL_AVAILABLE = False + + +@pytest.mark.skipif(not PYVISTA_AVAILABLE, reason="PyVista not installed") +class TestPyVistaBackend: + """Test suite for PyVista backend.""" + + def test_import(self): + """Test that viz module can be imported.""" + from pymech import viz + + assert hasattr(viz, "plot_mesh") + assert hasattr(viz, "hexa_to_pyvista") + assert hasattr(viz, "add_boundary_conditions") + + def test_hexa_to_pyvista_linear_3d(self, test_data_dir): + """Test conversion of 3D mesh with linear resolution.""" + from pymech import readre2 + from pymech.viz import hexa_to_pyvista + + # Load test data + test_file = test_data_dir / "nek" / "box3d.re2" + if not test_file.exists(): + pytest.skip(f"Test file not found: {test_file}") + + field = readre2(str(test_file)) + mesh = hexa_to_pyvista(field, resolution="linear", include_fields=False) + + assert isinstance(mesh, pv.UnstructuredGrid) + assert mesh.n_cells == field.nel + assert mesh.n_points == field.nel * 8 # 8 vertices per hex + assert "element_id" in mesh.cell_data + assert len(mesh.cell_data["element_id"]) == field.nel + + def test_hexa_to_pyvista_linear_2d(self, test_data_dir): + """Test conversion of 2D mesh.""" + from pymech import readre2 + from pymech.viz import hexa_to_pyvista + + # Try to find a 2D test file + test_files_2d = ["box2d.re2", "2D_section_R360.re2"] + field = None + for fname in test_files_2d: + test_file = test_data_dir / "nek" / fname + if test_file.exists(): + try: + field = readre2(str(test_file)) + if field.ndim == 2: + break + except Exception: + continue + + if field is None or field.ndim != 2: + pytest.skip("No 2D test file available") + + mesh = hexa_to_pyvista(field, resolution="linear", include_fields=False) + + assert isinstance(mesh, pv.UnstructuredGrid) + assert mesh.n_points == field.nel * 4 # 4 vertices per quad + + def test_hexa_to_pyvista_spectral(self, test_data_dir): + """Test conversion with spectral resolution.""" + from pymech import readre2 + from pymech.viz import hexa_to_pyvista + + test_file = test_data_dir / "nek" / "box3d.re2" + if not test_file.exists(): + pytest.skip(f"Test file not found: {test_file}") + + field = readre2(str(test_file)) + mesh = hexa_to_pyvista(field, resolution="spectral", include_fields=False) + + lx, ly, lz = field.lr1 + expected_cells = field.nel * (lx - 1) * (ly - 1) * (lz - 1) + assert mesh.n_cells == expected_cells + + def test_include_fields(self, test_data_dir): + """Test that velocity/pressure fields are included.""" + from pymech import readre2 + from pymech.viz import hexa_to_pyvista + + test_file = test_data_dir / "nek" / "box3d.re2" + if not test_file.exists(): + pytest.skip(f"Test file not found: {test_file}") + + field = readre2(str(test_file)) + mesh = hexa_to_pyvista(field, resolution="linear", include_fields=True) + + # Check velocity field + if field.var[1] == 3: + assert "velocity" in mesh.point_data + assert mesh.point_data["velocity"].shape[1] == 3 + assert "velocity_magnitude" in mesh.point_data + + # Check pressure field + if field.var[2] == 1: + assert "pressure" in mesh.point_data + + # Check temperature field + if field.var[3] == 1: + assert "temperature" in mesh.point_data + + def test_add_boundary_conditions(self, test_data_dir): + """Test BC extraction and coloring.""" + from pymech import readre2 + from pymech.viz import add_boundary_conditions, hexa_to_pyvista + + test_file = test_data_dir / "nek" / "box3d.re2" + if not test_file.exists(): + pytest.skip(f"Test file not found: {test_file}") + + field = readre2(str(test_file)) + mesh = hexa_to_pyvista(field, resolution="linear", include_fields=False) + surface = add_boundary_conditions(mesh, field, bc_field=0) + + assert isinstance(surface, pv.PolyData) + assert "bc_type" in surface.cell_data + assert "bc_color" in surface.cell_data + assert surface.cell_data["bc_color"].shape[1] == 3 # RGB colors + + @pytest.mark.parametrize("resolution", ["linear", "spectral"]) + def test_plot_mesh_headless(self, test_data_dir, tmp_path, resolution): + """Test plot_mesh in headless mode (screenshot only).""" + pv.OFF_SCREEN = True # Enable headless rendering + + from pymech import readre2 + from pymech.viz import plot_mesh + + test_file = test_data_dir / "nek" / "box3d.re2" + if not test_file.exists(): + pytest.skip(f"Test file not found: {test_file}") + + field = readre2(str(test_file)) + screenshot_path = tmp_path / f"test_{resolution}.png" + + try: + plot_mesh( + field, + backend="pyvista", + resolution=resolution, + show_bcs=True, + screenshot=str(screenshot_path), + jupyter_backend="static", + ) + + assert screenshot_path.exists() + assert screenshot_path.stat().st_size > 0 + except Exception as e: + # Some systems may not support headless rendering + pytest.skip(f"Headless rendering not supported: {e}") + finally: + pv.OFF_SCREEN = False + + def test_plot_mesh_return_plotter(self, test_data_dir): + """Test that return_plotter works.""" + pv.OFF_SCREEN = True + + from pymech import readre2 + from pymech.viz import plot_mesh + + test_file = test_data_dir / "nek" / "box3d.re2" + if not test_file.exists(): + pytest.skip(f"Test file not found: {test_file}") + + field = readre2(str(test_file)) + + try: + plotter = plot_mesh(field, backend="pyvista", return_plotter=True) + assert isinstance(plotter, pv.Plotter) + plotter.close() + except Exception as e: + pytest.skip(f"Headless rendering not supported: {e}") + finally: + pv.OFF_SCREEN = False + + def test_invalid_resolution(self, test_data_dir): + """Test that invalid resolution raises ValueError.""" + from pymech import readre2 + from pymech.viz import hexa_to_pyvista + + test_file = test_data_dir / "nek" / "box3d.re2" + if not test_file.exists(): + pytest.skip(f"Test file not found: {test_file}") + + field = readre2(str(test_file)) + + with pytest.raises(ValueError, match="resolution must be"): + hexa_to_pyvista(field, resolution="invalid") + + +@pytest.mark.skipif(not MATPLOTLIB_AVAILABLE, reason="Matplotlib not installed") +class TestMatplotlibBackend: + """Test suite for Matplotlib backend.""" + + def test_plot_mesh_matplotlib(self, test_data_dir, tmp_path): + """Test plot_mesh with Matplotlib backend.""" + from pymech import readre2 + from pymech.viz import plot_mesh + + test_file = test_data_dir / "nek" / "box3d.re2" + if not test_file.exists(): + pytest.skip(f"Test file not found: {test_file}") + + field = readre2(str(test_file)) + screenshot_path = tmp_path / "test_matplotlib.png" + + fig = plot_mesh( + field, + backend="matplotlib", + show_bcs=False, + screenshot=str(screenshot_path), + return_plotter=True, + ) + + assert fig is not None + assert screenshot_path.exists() + assert screenshot_path.stat().st_size > 0 + plt.close(fig) + + def test_plot_mesh_matplotlib_views(self, test_data_dir): + """Test different camera views with Matplotlib.""" + from pymech import readre2 + from pymech.viz import plot_mesh + + test_file = test_data_dir / "nek" / "box3d.re2" + if not test_file.exists(): + pytest.skip(f"Test file not found: {test_file}") + + field = readre2(str(test_file)) + + for view in ["xy", "xz", "yz"]: + fig = plot_mesh( + field, + backend="matplotlib", + view=view, + show_bcs=False, + return_plotter=True, + ) + assert fig is not None + plt.close(fig) + + +class TestBackendSelection: + """Test backend selection logic.""" + + def test_auto_backend_selection(self, test_data_dir): + """Test that 'auto' backend selects appropriately.""" + from pymech import readre2 + from pymech.viz import plot_mesh + + test_file = test_data_dir / "nek" / "box3d.re2" + if not test_file.exists(): + pytest.skip(f"Test file not found: {test_file}") + + field = readre2(str(test_file)) + + # 'auto' should work if at least one backend is available + if PYVISTA_AVAILABLE or MATPLOTLIB_AVAILABLE: + if PYVISTA_AVAILABLE: + pv.OFF_SCREEN = True + try: + result = plot_mesh( + field, + backend="auto", + screenshot=None, + return_plotter=True, + ) + assert result is not None + if hasattr(result, "close"): + result.close() + elif hasattr(result, "clf"): + plt.close(result) + except Exception as e: + pytest.skip(f"Auto backend selection failed: {e}") + finally: + if PYVISTA_AVAILABLE: + pv.OFF_SCREEN = False + else: + with pytest.raises(ImportError): + plot_mesh(field, backend="auto") + + def test_invalid_backend(self, test_data_dir): + """Test that invalid backend raises ValueError.""" + from pymech import readre2 + from pymech.viz import plot_mesh + + test_file = test_data_dir / "nek" / "box3d.re2" + if not test_file.exists(): + pytest.skip(f"Test file not found: {test_file}") + + field = readre2(str(test_file)) + + with pytest.raises(ValueError, match="Invalid backend"): + plot_mesh(field, backend="invalid") + + +def test_import_without_backends(): + """Test that module can be imported even without backends.""" + # This test ensures graceful degradation + try: + from pymech import viz + + # If import succeeds, module should have main functions defined + assert hasattr(viz, "plot_mesh") + except ImportError: + # If import fails, it's acceptable (no backends available) + pass + + +def test_bc_colors(): + """Test that BC color scheme is defined.""" + from pymech.viz import DEFAULT_BC_COLORS + + # Check that important BC types are defined + assert "" in DEFAULT_BC_COLORS + assert "E" in DEFAULT_BC_COLORS + assert "W" in DEFAULT_BC_COLORS + assert "O" in DEFAULT_BC_COLORS + + # Check that colors are RGB tuples + for bc_type, color in DEFAULT_BC_COLORS.items(): + assert isinstance(color, tuple) + assert len(color) == 3 + assert all(0 <= c <= 1 for c in color) + + +@pytest.mark.skipif(not PROTOCOL_AVAILABLE, reason="Protocol module not available") +class TestProtocolCompliance: + """Test that backends implement the MeshBackend protocol correctly.""" + + @pytest.mark.skipif( + not PYVISTA_IMPL_AVAILABLE, reason="PyVista implementation not available" + ) + def test_pyvista_backend_protocol(self): + """Test PyVista backend implements MeshBackend protocol.""" + from pymech.viz import MeshBackend + from pymech.viz.pyvista_backend_impl import PyVistaBackend + + backend = PyVistaBackend() + + # Runtime check for Protocol compliance + assert isinstance(backend, MeshBackend) + + # Verify required methods exist + assert hasattr(backend, "is_available") + assert hasattr(backend, "plot_mesh") + assert hasattr(backend, "get_backend_name") + assert hasattr(backend, "get_capabilities") + + # Verify method return types + assert isinstance(backend.is_available(), bool) + assert isinstance(backend.get_backend_name(), str) + assert isinstance(backend.get_capabilities(), dict) + + # Verify backend name + assert backend.get_backend_name() == "pyvista" + + @pytest.mark.skipif( + not MATPLOTLIB_IMPL_AVAILABLE, reason="Matplotlib implementation not available" + ) + def test_matplotlib_backend_protocol(self): + """Test Matplotlib backend implements MeshBackend protocol.""" + from pymech.viz import MeshBackend + from pymech.viz.matplotlib_backend import MatplotlibBackend + + backend = MatplotlibBackend() + + # Runtime check for Protocol compliance + assert isinstance(backend, MeshBackend) + + # Verify required methods exist + assert hasattr(backend, "is_available") + assert hasattr(backend, "plot_mesh") + assert hasattr(backend, "get_backend_name") + assert hasattr(backend, "get_capabilities") + + # Verify method return types + assert isinstance(backend.is_available(), bool) + assert isinstance(backend.get_backend_name(), str) + assert isinstance(backend.get_capabilities(), dict) + + # Verify backend name + assert backend.get_backend_name() == "matplotlib" + + @pytest.mark.skipif( + not PYVISTA_IMPL_AVAILABLE, reason="PyVista implementation not available" + ) + def test_pyvista_capabilities(self): + """Test PyVista backend capabilities.""" + from pymech.viz import PyVistaBackend + + backend = PyVistaBackend() + caps = backend.get_capabilities() + + assert caps["interactive"] is True + assert caps["3d"] is True + assert caps["jupyter"] is True + assert caps["headless"] is True + assert "formats" in caps + assert "png" in caps["formats"] + + @pytest.mark.skipif( + not MATPLOTLIB_IMPL_AVAILABLE, reason="Matplotlib implementation not available" + ) + def test_matplotlib_capabilities(self): + """Test Matplotlib backend capabilities.""" + from pymech.viz import MatplotlibBackend + + backend = MatplotlibBackend() + caps = backend.get_capabilities() + + assert caps["interactive"] is False + assert caps["3d"] is True + assert caps["jupyter"] is True + assert caps["headless"] is True + assert "formats" in caps + assert "png" in caps["formats"] + assert "pdf" in caps["formats"] + + +@pytest.mark.skipif(not PROTOCOL_AVAILABLE, reason="Protocol module not available") +class TestBackendDiscovery: + """Test backend discovery and selection.""" + + def test_get_available_backends(self): + """Test get_available_backends() function.""" + from pymech.viz import get_available_backends + + backends = get_available_backends() + + assert isinstance(backends, dict) + + # Check that available backends are present + if PYVISTA_AVAILABLE: + assert "pyvista" in backends + else: + assert "pyvista" not in backends + + if MATPLOTLIB_AVAILABLE: + assert "matplotlib" in backends + else: + assert "matplotlib" not in backends + + def test_backend_availability(self): + """Test is_available() for each backend.""" + # Test PyVista backend + if PYVISTA_IMPL_AVAILABLE: + from pymech.viz import PyVistaBackend + + backend = PyVistaBackend() + assert backend.is_available() == PYVISTA_AVAILABLE + + # Test Matplotlib backend + if MATPLOTLIB_IMPL_AVAILABLE: + from pymech.viz import MatplotlibBackend + + backend = MatplotlibBackend() + assert backend.is_available() == MATPLOTLIB_AVAILABLE