diff --git a/hexrdgui/calibration/hedm/__init__.py b/hexrdgui/calibration/hedm/__init__.py index 3ca9bfc67..3d1b56df5 100644 --- a/hexrdgui/calibration/hedm/__init__.py +++ b/hexrdgui/calibration/hedm/__init__.py @@ -2,11 +2,12 @@ compute_xyo, HEDMCalibrationCallbacks, HEDMCalibrationDialog, - parse_spots_data, ) +from hexrdgui.utils.spots import parse_spots_data from .calibration_options_dialog import HEDMCalibrationOptionsDialog from .calibration_results_dialog import HEDMCalibrationResultsDialog from .calibration_runner import HEDMCalibrationRunner +from .spot_diagnostics_dialog import SpotDiagnosticsDialog __all__ = [ 'compute_xyo', @@ -15,5 +16,6 @@ 'HEDMCalibrationOptionsDialog', 'HEDMCalibrationResultsDialog', 'HEDMCalibrationRunner', + 'SpotDiagnosticsDialog', 'parse_spots_data', ] diff --git a/hexrdgui/calibration/hedm/calibration_dialog.py b/hexrdgui/calibration/hedm/calibration_dialog.py index c536daccf..dc4a6a527 100644 --- a/hexrdgui/calibration/hedm/calibration_dialog.py +++ b/hexrdgui/calibration/hedm/calibration_dialog.py @@ -10,11 +10,17 @@ import hexrd.constants as cnst from hexrd.fitting.calibration.calibrator import Calibrator from hexrd.fitting.calibration.lmfit_param_handling import fix_detector_y +from hexrd.transforms import xfcapi +from hexrd.xrdutil import apply_correction_to_wavelength from hexrdgui.calibration.calibration_dialog import CalibrationDialog from hexrdgui.calibration.hedm.calibration_results_dialog import ( HEDMCalibrationResultsDialog, ) +from hexrdgui.calibration.hedm.spot_diagnostics_dialog import ( + SpotDiagnosticsDialog, +) +from hexrdgui.utils.spots import extract_spot_angles, parse_spots_data from hexrdgui.calibration.tree_item_models import CalibrationTreeItemModel from hexrdgui.calibration.material_calibration_dialog_callbacks import ( MaterialCalibrationDialogCallbacks, @@ -27,6 +33,7 @@ class HEDMCalibrationDialog(CalibrationDialog): apply_refinement_selections_needed = Signal() + spot_diagnostics_requested = Signal() def __init__(self, *args: Any, **kwargs: Any) -> None: # Need to initialize this before setup_connections() is called @@ -75,6 +82,10 @@ def setup_connections(self) -> None: self.save_refit_settings ) + self.extra_ui.spot_diagnostics.clicked.connect( + self.spot_diagnostics_requested.emit, + ) + def show_refinements(self, b: bool) -> None: self.tree_view.setVisible(b) if b: @@ -272,6 +283,9 @@ def setup_connections(self) -> None: self.dialog.apply_refinement_selections_needed.connect( self.apply_refinement_selections ) + self.dialog.spot_diagnostics_requested.connect( + self.show_spot_diagnostics, + ) @property def grain_ids(self) -> np.ndarray: @@ -297,6 +311,26 @@ def xyo_det(self) -> dict[str, list[Any]]: return self._xyo_det + def show_spot_diagnostics(self) -> None: + pred_angs, meas_angs = extract_spot_angles( + self.spots_data, + self.instr, + self.grain_ids, + ) + xyo_pred = compute_xyo(self.calibrators) + + self._spot_diagnostics_dialog = SpotDiagnosticsDialog( + instr=self.instr, + spots_data=self.spots_data, + grain_ids=self.grain_ids, + pred_angs=pred_angs, + meas_angs=meas_angs, + xyo_pred=xyo_pred, + xyo_det=self.xyo_det, + parent=self.dialog.ui, + ) + self._spot_diagnostics_dialog.show() + def on_calibration_finished(self) -> None: super().on_calibration_finished() @@ -339,6 +373,21 @@ def on_calibration_finished(self) -> None: # Do an "undo" self.pop_undo_stack() + self.update_spot_diagnostics() + + def update_spot_diagnostics(self) -> None: + dialog = getattr(self, '_spot_diagnostics_dialog', None) + if dialog is not None and dialog.is_visible: + xyo_pred = compute_xyo(self.calibrators) + pred_angs = compute_pred_angs(self.calibrators) + meas_angs = compute_meas_angs(self.calibrators, self.xyo_det) + dialog.update_data( + self.instr, + xyo_pred=xyo_pred, + pred_angs=pred_angs, + meas_angs=meas_angs, + ) + def push_undo_stack(self) -> Any: self.extra_ui_undo_stack.append(self.dialog.extra_ui_settings) return super().push_undo_stack() @@ -602,64 +651,138 @@ def compute_xyo(calibrators: list[Calibrator]) -> dict[str, list]: return xyo -def parse_spots_data( - spots_data: Any, - instr: Any, - grain_ids: np.ndarray, - ome_period: np.ndarray | None = None, - refit_idx: dict[str, list[Any]] | None = None, -) -> tuple[dict[str, list[Any]], dict[str, list[Any]], dict[str, list[Any]]]: - hkls: dict[str, Any] = {} - xyo_det: dict[str, Any] = {} - idx_0: dict[str, Any] = {} - for det_key, panel in instr.detectors.items(): - hkls[det_key] = [] - xyo_det[det_key] = [] - idx_0[det_key] = [] - - for ig, grain_id in enumerate(grain_ids): - data = spots_data[grain_id][1][det_key] - # Convert to numpy array to make operations easier - data = np.array(data, dtype=object) - - # FIXME: hexrd is not happy if some detectors end up with no - # grain data, which sometimes happens with subpanels like Dexelas - if data.size == 0: - idx_0[det_key].append(np.empty((0,))) - hkls[det_key].append(np.empty((0, 3))) - xyo_det[det_key].append(np.empty((0, 3))) +def compute_pred_angs( + calibrators: list[Calibrator], +) -> dict[str, list[np.ndarray]]: + """Recompute predicted (tth, eta, ome) using current grain/instrument state. + + For each calibrator (grain) and detector, calls oscill_angles_of_hkls() + with current grain parameters and selects the omega solution closest + to the measured omega. + """ + instr = calibrators[0].instr + chi = instr.chi + bvec = instr.beam_vector + tvec_s = instr.tvec + wavelength = instr.beam_wavelength + energy_correction = instr.energy_correction + + pred_angs: dict[str, list[np.ndarray]] = {} + for calibrator in calibrators: + grain = calibrator.grain_params + rmat_c = xfcapi.make_rmat_of_expmap(grain[:3]) + tvec_c = grain[3:6] + vinv_s = grain[6:] + bmat = calibrator.bmatx + ome_period = calibrator.ome_period + + corrected_wavelength = apply_correction_to_wavelength( + wavelength, + energy_correction, + tvec_s, + tvec_c, + ) + + for det_key in instr.detectors: + pred_angs.setdefault(det_key, []) + + hkls = np.asarray( + calibrator.data_dict['hkls'][det_key], + dtype=float, + ) + xyo = np.asarray( + calibrator.data_dict['pick_xys'][det_key], + dtype=float, + ) + + if hkls.size == 0: + pred_angs[det_key].append(np.empty((0, 3))) continue - valid_reflections = data[:, 0] >= 0 - not_saturated = data[:, 4] < panel.saturation_level + # Two omega solutions per HKL + oangs0, oangs1 = xfcapi.oscill_angles_of_hkls( + hkls, + chi, + rmat_c, + bmat, + corrected_wavelength, + v_inv=vinv_s, + beam_vec=bvec, + ) + + # Select the solution whose omega is closest to measured + meas_omes = mapAngle(xyo[:, 2], ome_period) + calc_omes = np.vstack( + [ + mapAngle(oangs0[:, 2], ome_period), + mapAngle(oangs1[:, 2], ome_period), + ] + ) # (2, n) + diff = np.abs( + angularDifference( + np.tile(meas_omes, (2, 1)), + calc_omes, + ) + ) + best = np.argmin(diff, axis=0) # 0 or 1 per reflection + + n = len(hkls) + idx = np.arange(n) + both = np.stack([oangs0, oangs1]) # (2, n, 3) + pred_angs[det_key].append(both[best, idx]) - if refit_idx is None: - idx = np.logical_and(valid_reflections, not_saturated) - idx_0[det_key].append(idx) - else: - idx = refit_idx[det_key][ig] - idx_0[det_key].append(idx) + return pred_angs - if not np.any(idx): - idx_0[det_key].append(np.empty((0,))) - hkls[det_key].append(np.empty((0, 3))) - xyo_det[det_key].append(np.empty((0, 3))) + +def compute_meas_angs( + calibrators: list[Calibrator], + xyo_det: dict[str, list[np.ndarray]], +) -> dict[str, list[np.ndarray]]: + """Convert measured detector XY to angular coordinates using current geometry. + + Uses panel.cart_to_angles() with per-reflection rmat_s (from chi + + measured omega) to account for grain position offset. + """ + instr = calibrators[0].instr + chi = instr.chi + tvec_s = instr.tvec + + meas_angs: dict[str, list[np.ndarray]] = {} + for ig, calibrator in enumerate(calibrators): + grain = calibrator.grain_params + tvec_c = grain[3:6] + + for det_key, panel in instr.detectors.items(): + meas_angs.setdefault(det_key, []) + + xyo = xyo_det[det_key][ig] + if xyo.size == 0: + meas_angs[det_key].append(np.empty((0, 3))) continue - hkls[det_key].append(np.vstack(data[idx, 2])) - meas_omes = np.vstack(data[idx, 6])[:, 2].reshape(sum(idx), 1) - xyo_det_values = np.hstack([np.vstack(data[idx, 7]), meas_omes]) + xy = xyo[:, :2] + omes = xyo[:, 2] - # re-map omegas if need be - if ome_period is not None: - xyo_det_values[:, 2] = mapAngle( - xyo_det_values[:, 2], - ome_period, + # Undistort measured positions before converting to angles + if panel.distortion is not None: + xy = panel.distortion.apply(xy) + + n = len(omes) + result = np.empty((n, 3)) + for i in range(n): + rmat_s = xfcapi.make_sample_rmat(chi, omes[i]) + tth_eta, _ = panel.cart_to_angles( + xy[i : i + 1], + rmat_s=rmat_s, + tvec_s=tvec_s, + tvec_c=tvec_c, ) + result[i, :2] = tth_eta[0] + result[i, 2] = omes[i] - xyo_det[det_key].append(xyo_det_values) + meas_angs[det_key].append(result) - return hkls, xyo_det, idx_0 + return meas_angs REFINEMENT_OPTIONS = { diff --git a/hexrdgui/calibration/hedm/calibration_results_dialog.py b/hexrdgui/calibration/hedm/calibration_results_dialog.py index f5d4c0cfa..71883cb16 100644 --- a/hexrdgui/calibration/hedm/calibration_results_dialog.py +++ b/hexrdgui/calibration/hedm/calibration_results_dialog.py @@ -100,7 +100,7 @@ def setup_canvas(self) -> None: self.ui.canvas_layout.addWidget(canvas) ax[0].grid(True) - ax[0].axis('equal') + ax[0].set_aspect('equal', adjustable='box') ax[0].set_xlabel('detector X [mm]') ax[0].set_ylabel('detector Y [mm]') diff --git a/hexrdgui/calibration/hedm/spot_diagnostics_dialog.py b/hexrdgui/calibration/hedm/spot_diagnostics_dialog.py new file mode 100644 index 000000000..0dd9d7f2f --- /dev/null +++ b/hexrdgui/calibration/hedm/spot_diagnostics_dialog.py @@ -0,0 +1,479 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from PySide6.QtCore import Qt +from PySide6.QtWidgets import QSizePolicy, QWidget + +from matplotlib.backends.backend_qtagg import FigureCanvasQTAgg as FigureCanvas +from matplotlib.figure import Figure +import numpy as np + +from hexrdgui.navigation_toolbar import NavigationToolbar +from hexrdgui.ui_loader import UiLoader +from hexrdgui.utils.dialog import add_help_url +from hexrdgui.utils.spots import extract_spot_angles, extract_spot_xyo + +if TYPE_CHECKING: + from matplotlib.axes import Axes + from matplotlib.collections import PathCollection + from numpy.typing import NDArray + + from hexrd.core.instrument import HEDMInstrument + + from hexrdgui.utils.spots import DetGrainArrays, SpotsData + + +# Quantity definitions: key -> (label, combo_label, units, default_bounds) +# combo_label uses Unicode for Qt widgets; label uses LaTeX for matplotlib +QUANTITY_CONFIG: dict[str, dict[str, str | float]] = { + 'tth': { + 'label': r'$2\theta$', + 'combo_label': '2θ', + 'units': 'degrees', + 'default_bounds': 0.01, + }, + 'eta': { + 'label': r'$\eta$', + 'combo_label': 'η', + 'units': 'degrees', + 'default_bounds': 0.05, + }, + 'ome': { + 'label': r'$\omega$', + 'combo_label': 'ω', + 'units': 'degrees', + 'default_bounds': 0.4, + }, + 'x': { + 'label': 'X', + 'combo_label': 'X', + 'units': 'mm', + 'default_bounds': 0.2, + }, + 'y': { + 'label': 'Y', + 'combo_label': 'Y', + 'units': 'mm', + 'default_bounds': 0.2, + }, +} + + +def _extract_detector_info( + instr: HEDMInstrument, +) -> tuple[list[str], dict[str, tuple[float, float]], dict[str, NDArray[np.floating]]]: + """Extract det_keys, det_dims, and det_tvecs from an instrument.""" + det_keys = list(instr.detectors) + det_dims = { + k: (panel.col_dim, panel.row_dim) for k, panel in instr.detectors.items() + } + det_tvecs = {k: panel.tvec.copy() for k, panel in instr.detectors.items()} + return det_keys, det_dims, det_tvecs + + +class SpotDiagnosticsDialog: + """Dialog for visualizing spot residuals (pred vs meas). + + There are two ways to construct this dialog: + + 1. **From raw spots data** (fit-grains path): pass ``instr``, + ``spots_data``, and ``grain_ids``. The dialog will compute + ``pred_angs``, ``meas_angs``, ``xyo_pred``, and ``xyo_det`` + internally. + + 2. **With pre-computed arrays** (HEDM calibration path): also pass + ``pred_angs``, ``meas_angs``, ``xyo_pred``, and/or ``xyo_det`` + to override the values derived from ``spots_data``. + """ + + def __init__( + self, + instr: HEDMInstrument, + spots_data: SpotsData, + grain_ids: list[int] | NDArray[np.integer], + *, + pred_angs: DetGrainArrays | None = None, + meas_angs: DetGrainArrays | None = None, + xyo_pred: DetGrainArrays | None = None, + xyo_det: DetGrainArrays | None = None, + parent: QWidget | None = None, + ) -> None: + loader = UiLoader() + self.ui = loader.load_file('spot_diagnostics_dialog.ui', parent) + + self.grain_ids: list[int] = ( + grain_ids.tolist() if isinstance(grain_ids, np.ndarray) else grain_ids + ) + + # Derive detector info from instrument + det_keys, det_dims, det_tvecs = _extract_detector_info(instr) + self.det_keys: list[str] = det_keys + self.det_dims: dict[str, tuple[float, float]] = det_dims + self.det_tvecs: dict[str, NDArray[np.floating]] = det_tvecs + + # Compute from spots_data, then allow overrides + default_pred_angs, default_meas_angs = extract_spot_angles( + spots_data, + instr, + self.grain_ids, + ) + default_xyo_pred, default_xyo_det = extract_spot_xyo( + spots_data, + instr, + self.grain_ids, + ) + + self.pred_angs: DetGrainArrays = ( + pred_angs if pred_angs is not None else default_pred_angs + ) + self.meas_angs: DetGrainArrays = ( + meas_angs if meas_angs is not None else default_meas_angs + ) + self.xyo_pred: DetGrainArrays = ( + xyo_pred if xyo_pred is not None else default_xyo_pred + ) + self.xyo_det: DetGrainArrays = ( + xyo_det if xyo_det is not None else default_xyo_det + ) + + self.fig: Figure | None = None + self.canvas: FigureCanvas | None = None + self.toolbar: NavigationToolbar | None = None + + self.setup_combo_boxes() + self.setup_canvas() + self.setup_connections() + add_help_url( + self.ui.button_box, + 'calibration/rotation_series/#spot-diagnostics', + ) + self.update_canvas() + + def setup_connections(self) -> None: + self.ui.quantity.currentIndexChanged.connect( + self.on_quantity_changed, + ) + self.ui.bounds.valueChanged.connect(self.update_canvas) + self.ui.histogram_bins.valueChanged.connect(self.update_canvas) + self.ui.show_all_detectors.toggled.connect( + self.show_all_detectors_toggled, + ) + self.ui.detector.currentIndexChanged.connect(self.update_canvas) + self.ui.show_all_grains.toggled.connect(self.show_all_grains_toggled) + self.ui.grain_id.currentIndexChanged.connect(self.update_canvas) + self.ui.match_detector_shape.toggled.connect(self.update_canvas) + + def setup_combo_boxes(self) -> None: + self.ui.quantity.clear() + for key, config in QUANTITY_CONFIG.items(): + combo_label = f'{config["combo_label"]} ({config["units"]})' + self.ui.quantity.addItem(combo_label, key) + + self.ui.detector.clear() + for det_key in self.det_keys: + self.ui.detector.addItem(det_key) + + self.ui.grain_id.clear() + for grain_id in sorted(self.grain_ids): + self.ui.grain_id.addItem(str(grain_id), grain_id) + + self.update_enable_states() + + def update_enable_states(self) -> None: + enable_grain_id = not self.show_all_grains and self.ui.grain_id.count() > 1 + self.ui.grain_id.setEnabled(enable_grain_id) + self.ui.grain_id_label.setEnabled(enable_grain_id) + + show_all_grains_visible: bool = self.ui.grain_id.count() > 1 + self.ui.show_all_grains.setVisible(show_all_grains_visible) + + enable_detector = not self.show_all_detectors and self.ui.detector.count() > 1 + self.ui.detector.setEnabled(enable_detector) + self.ui.detector_label.setEnabled(enable_detector) + + show_all_detectors_visible: bool = self.ui.detector.count() > 1 + self.ui.show_all_detectors.setVisible(show_all_detectors_visible) + + def setup_canvas(self) -> None: + canvas = FigureCanvas(Figure(constrained_layout=True)) + canvas.setSizePolicy( + QSizePolicy.Policy.Expanding, + QSizePolicy.Policy.Expanding, + ) + + self.ui.canvas_layout.addWidget(canvas) + + self.toolbar = NavigationToolbar(canvas, self.ui, coordinates=True) + self.ui.canvas_layout.addWidget(self.toolbar) + self.ui.canvas_layout.setAlignment( + self.toolbar, + Qt.AlignmentFlag.AlignCenter, + ) + + self.fig = canvas.figure + self.canvas = canvas + + def update_data( + self, + instr: HEDMInstrument, + *, + xyo_pred: DetGrainArrays, + pred_angs: DetGrainArrays, + meas_angs: DetGrainArrays, + ) -> None: + """Update data after calibration refinement and refresh plots. + + The instrument may have changed (e.g. detector translations), + so detector info is re-derived. The caller provides recomputed + predicted/measured arrays from the calibrator model. + """ + _, self.det_dims, self.det_tvecs = _extract_detector_info(instr) + self.xyo_pred = xyo_pred + self.pred_angs = pred_angs + self.meas_angs = meas_angs + self.update_canvas() + + @property + def is_visible(self) -> bool: + return self.ui.isVisible() + + def show(self) -> None: + self.ui.show() + + def exec(self) -> int: + return self.ui.exec() + + @property + def selected_quantity(self) -> str: + return self.ui.quantity.currentData() + + @property + def selected_detector_key(self) -> str: + return self.ui.detector.currentText() + + @property + def selected_grain_id(self) -> int: + return self.ui.grain_id.currentData() + + @property + def show_all_grains(self) -> bool: + return self.ui.show_all_grains.isChecked() + + @property + def show_all_detectors(self) -> bool: + return self.ui.show_all_detectors.isChecked() + + @property + def match_detector_shape(self) -> bool: + return self.ui.match_detector_shape.isChecked() + + @property + def bounds_value(self) -> float: + return self.ui.bounds.value() + + @property + def num_bins(self) -> int: + return self.ui.histogram_bins.value() + + @property + def grain_indices_to_plot(self) -> list[int]: + if self.show_all_grains: + return list(range(len(self.grain_ids))) + return [self.grain_ids.index(self.selected_grain_id)] + + @property + def det_keys_to_plot(self) -> list[str]: + if self.show_all_detectors: + return self.det_keys + return [self.selected_detector_key] + + def on_quantity_changed(self) -> None: + key = self.selected_quantity + if key is not None: + config = QUANTITY_CONFIG[key] + self.ui.bounds.setValue(config['default_bounds']) + self.update_canvas() + + def show_all_grains_toggled(self) -> None: + self.update_enable_states() + self.update_canvas() + + def show_all_detectors_toggled(self) -> None: + self.update_enable_states() + self.update_canvas() + + def _get_data_for_quantity( + self, + det_key: str, + grain_idx: int, + ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + """Get predicted, measured, and scatter positions for a quantity. + + Returns + ------- + exp_val + Predicted values (degrees or mm). + sim_val + Measured values (degrees or mm). + scatter_x + X positions for spatial scatter. + scatter_y + Y positions for spatial scatter. + """ + key = self.selected_quantity + + pred_a = self.pred_angs[det_key][grain_idx] + meas_a = self.meas_angs[det_key][grain_idx] + xyo_p = self.xyo_pred[det_key][grain_idx] + xyo_m = self.xyo_det[det_key][grain_idx] + + # Offset XY by detector translation to place in instrument frame + tvec = self.det_tvecs[det_key] + scatter_x: np.ndarray = xyo_p[:, 0] + tvec[0] + scatter_y: np.ndarray = xyo_p[:, 1] + tvec[1] + + exp_val: np.ndarray + sim_val: np.ndarray + if key == 'tth': + exp_val = np.degrees(pred_a[:, 0]) + sim_val = np.degrees(meas_a[:, 0]) + elif key == 'eta': + exp_val = np.degrees(pred_a[:, 1]) + sim_val = np.degrees(meas_a[:, 1]) + elif key == 'ome': + exp_val = np.degrees(pred_a[:, 2]) + sim_val = np.degrees(meas_a[:, 2]) + elif key == 'x': + exp_val = xyo_p[:, 0] + tvec[0] + sim_val = xyo_m[:, 0] + tvec[0] + elif key == 'y': + exp_val = xyo_p[:, 1] + tvec[1] + sim_val = xyo_m[:, 1] + tvec[1] + else: + raise ValueError(f'Unknown quantity: {key}') + + return exp_val, sim_val, scatter_x, scatter_y + + def update_canvas(self) -> None: + assert self.fig is not None + assert self.canvas is not None + + # Clear the entire figure (removes all axes including colorbars) + self.fig.clear() + + key = self.selected_quantity + if key is None: + self.canvas.draw() + return + + det_keys = self.det_keys_to_plot + if not det_keys: + self.canvas.draw() + return + + config = QUANTITY_CONFIG[key] + bounds: float = self.bounds_value + nbins: int = self.num_bins + label: str = str(config['label']) + units: str = str(config['units']) + + grain_indices = self.grain_indices_to_plot + + # Collect all data across selected detectors and grains + all_diff: list[np.ndarray] = [] + all_exp: list[np.ndarray] = [] + all_scatter_x: list[np.ndarray] = [] + all_scatter_y: list[np.ndarray] = [] + + for det_key in det_keys: + for grain_idx in grain_indices: + exp_val, sim_val, sx, sy = self._get_data_for_quantity( + det_key, + grain_idx, + ) + if exp_val.size == 0: + continue + + diff: np.ndarray = sim_val - exp_val + all_diff.append(diff) + all_exp.append(exp_val) + all_scatter_x.append(sx) + all_scatter_y.append(sy) + + if not all_diff: + self.fig.text( + 0.5, + 0.5, + 'No spot data', + ha='center', + va='center', + fontsize=16, + color='gray', + ) + self.canvas.draw() + return + + cat_diff: np.ndarray = np.concatenate(all_diff) + cat_exp: np.ndarray = np.concatenate(all_exp) + cat_scatter_x: np.ndarray = np.concatenate(all_scatter_x) + cat_scatter_y: np.ndarray = np.concatenate(all_scatter_y) + + # Recreate subplots fresh each time (avoids colorbar accumulation) + ax_hist: Axes + ax_scatter: Axes + ax_line: Axes + ax_hist, ax_scatter, ax_line = self.fig.subplots(1, 3) + + diff_label = f'{label}' + r'$_{Meas}$' + f' - {label}' + r'$_{Pred}$' + diff_with_units = f'{diff_label} ({units})' + + # 1. Histogram + hist_bins: list[float] = np.linspace(-bounds, bounds, nbins).tolist() + ax_hist.hist(cat_diff, hist_bins, edgecolor='black', linewidth=0.5) + ax_hist.set_xlabel(diff_with_units) + ax_hist.set_ylabel('Number of Spots') + ax_hist.set_xlim(-bounds, bounds) + ax_hist.set_title('Residual Histogram') + + # 2. Spatial scatter plot + sc: PathCollection = ax_scatter.scatter( + cat_scatter_x, + cat_scatter_y, + c=cat_diff, + cmap='RdBu_r', + vmin=-bounds, + vmax=bounds, + s=10, + ) + self.fig.colorbar(sc, ax=ax_scatter, label=diff_with_units) + ax_scatter.set_xlabel(r'$X^{D}$ (mm)') + ax_scatter.set_ylabel(r'$Y^{D}$ (mm)') + ax_scatter.set_title('Spatial Distribution') + + if self.match_detector_shape: + # Compute bounding box of selected detectors in instrument frame + x_min: float = min( + self.det_tvecs[k][0] - self.det_dims[k][0] / 2 for k in det_keys + ) + x_max: float = max( + self.det_tvecs[k][0] + self.det_dims[k][0] / 2 for k in det_keys + ) + y_min: float = min( + self.det_tvecs[k][1] - self.det_dims[k][1] / 2 for k in det_keys + ) + y_max: float = max( + self.det_tvecs[k][1] + self.det_dims[k][1] / 2 for k in det_keys + ) + ax_scatter.set_xlim(x_min, x_max) + ax_scatter.set_ylim(y_min, y_max) + ax_scatter.set_aspect('equal') + + # 3. Line/scatter plot vs quantity + # (y-axis label omitted -- the adjacent colorbar already shows it) + ax_line.scatter(cat_exp, cat_diff, s=5, alpha=0.5) + ax_line.set_ylim(-bounds, bounds) + ax_line.set_xlabel(f'{label}' + r'$_{Pred}$' + f' ({units})') + ax_line.set_title(f'Residual vs {label}') + + self.canvas.draw() diff --git a/hexrdgui/indexing/fit_grains_results_dialog.py b/hexrdgui/indexing/fit_grains_results_dialog.py index 8c064c8c2..bb68de322 100644 --- a/hexrdgui/indexing/fit_grains_results_dialog.py +++ b/hexrdgui/indexing/fit_grains_results_dialog.py @@ -56,6 +56,7 @@ def __init__( material: Material | None = None, parent: QObject | None = None, allow_export_workflow: bool = True, + spots_data: dict | None = None, ) -> None: super().__init__(parent) @@ -74,6 +75,7 @@ def __init__( self.data = data self.data_model = GrainsTableModel(data) self.material = material + self.spots_data = spots_data self.canvas: FigureCanvas | None = None self.fig: Figure | None = None self.scatter_artist: Any = None @@ -91,6 +93,12 @@ def __init__( self.ui.splitter.setStretchFactor(0, 1) self.ui.splitter.setStretchFactor(1, 10) + self.ui.spot_diagnostics.setEnabled(spots_data is not None) + if spots_data is None: + self.ui.spot_diagnostics.setToolTip( + 'No spot data available. Re-run fit grains to generate it.' + ) + self.ui.export_workflow.setEnabled(allow_export_workflow) if not allow_export_workflow: # Give some possible reasons @@ -359,6 +367,7 @@ def setup_connections(self) -> None: self.cylindrical_reference_toggled ) self.ui.export_workflow.clicked.connect(self.on_export_workflow_clicked) + self.ui.spot_diagnostics.clicked.connect(self.show_spot_diagnostics) for name in ('x', 'y', 'z'): action = getattr(self, f'set_view_{name}') @@ -714,6 +723,27 @@ def draw_idle(self) -> None: assert self.canvas is not None self.canvas.draw_idle() + def show_spot_diagnostics(self) -> None: + if self.spots_data is None: + return + + from hexrdgui.calibration.hedm.spot_diagnostics_dialog import ( + SpotDiagnosticsDialog, + ) + from hexrdgui.indexing.create_config import create_indexing_config + + cfg = create_indexing_config() + instr = cfg.instrument.hedm + grain_ids = sorted(self.spots_data.keys()) + + self._spot_diagnostics_dialog = SpotDiagnosticsDialog( + instr=instr, + spots_data=self.spots_data, + grain_ids=grain_ids, + parent=self.ui, + ) + self._spot_diagnostics_dialog.show() + def on_grains_table_modified(self) -> None: # Update our grains table self.data = self.data_model.full_grains_table diff --git a/hexrdgui/indexing/run.py b/hexrdgui/indexing/run.py index ad6a06722..9bb5a262a 100644 --- a/hexrdgui/indexing/run.py +++ b/hexrdgui/indexing/run.py @@ -566,12 +566,14 @@ def run_fit_grains(self) -> None: if self.cancel_tracker: kwargs['check_if_canceled_func'] = self.cancel_tracker.get_need_to_cancel - self.fit_grains_results = fit_grains(**kwargs) + kwargs['return_pull_spots_data'] = True + result = fit_grains(**kwargs) if self.operation_canceled: # Operation was canceled return - assert self.fit_grains_results is not None + assert result is not None + self.fit_grains_results, self.spots_data = result self.result_grains_table = create_grains_table(self.fit_grains_results) print('Fit Grains Complete') HexrdConfig().fit_grains_grains_table = copy.deepcopy(self.result_grains_table) @@ -613,9 +615,11 @@ def view_fit_grains_results(self) -> None: ) assert self.result_grains_table is not None + spots_data = getattr(self, 'spots_data', None) dialog = create_fit_grains_results_dialog( grains_table=self.result_grains_table, allow_export_workflow=allow_export_workflow, + spots_data=spots_data, parent=self._parent, ) self.fit_grains_results_dialog = dialog @@ -651,6 +655,7 @@ def create_fit_grains_results_dialog( grains_table: np.ndarray, parent: QWidget | None = None, allow_export_workflow: bool = True, + spots_data: dict | None = None, ) -> Any: # Use the material to compute stress from strain indexing_config = HexrdConfig().indexing_config @@ -666,6 +671,7 @@ def create_fit_grains_results_dialog( material=material, parent=parent, allow_export_workflow=allow_export_workflow, + spots_data=spots_data, ) dialog.ui.resize(1200, 800) diff --git a/hexrdgui/resources/ui/fit_grains_results_dialog.ui b/hexrdgui/resources/ui/fit_grains_results_dialog.ui index 47df4332d..27bf25018 100644 --- a/hexrdgui/resources/ui/fit_grains_results_dialog.ui +++ b/hexrdgui/resources/ui/fit_grains_results_dialog.ui @@ -555,6 +555,13 @@ + + + + Show Spot Diagnostics + + + diff --git a/hexrdgui/resources/ui/hedm_calibration_custom_widgets.ui b/hexrdgui/resources/ui/hedm_calibration_custom_widgets.ui index 226628541..773e20d4a 100644 --- a/hexrdgui/resources/ui/hedm_calibration_custom_widgets.ui +++ b/hexrdgui/resources/ui/hedm_calibration_custom_widgets.ui @@ -165,6 +165,16 @@ + + + + Show Spot Diagnostics + + + View diagnostic plots comparing predicted and measured spot positions and angles + + + diff --git a/hexrdgui/resources/ui/spot_diagnostics_dialog.ui b/hexrdgui/resources/ui/spot_diagnostics_dialog.ui new file mode 100644 index 000000000..babb80fe8 --- /dev/null +++ b/hexrdgui/resources/ui/spot_diagnostics_dialog.ui @@ -0,0 +1,219 @@ + + + spot_diagnostics_dialog + + + + 0 + 0 + 1400 + 900 + + + + + 800 + 600 + + + + Spot Diagnostics + + + + + + Quantity: + + + + + + + QComboBox::AdjustToContents + + + + + + + Bounds: + + + + + + + false + + + 4 + + + 0.000100000000000 + + + 100.000000000000000 + + + 0.010000000000000 + + + 0.010000000000000 + + + + + + + Histogram bins: + + + + + + + false + + + 5 + + + 500 + + + 50 + + + + + + + Match detector shape + + + true + + + + + + + Show all detectors + + + true + + + + + + + false + + + Detector: + + + + + + + false + + + combobox-popup: 0; + + + QComboBox::AdjustToContents + + + + + + + Show all grains + + + true + + + + + + + false + + + Grain ID: + + + + + + + false + + + combobox-popup: 0; + + + QComboBox::AdjustToContents + + + + + + + Qt::Horizontal + + + + 40 + 20 + + + + + + + + + + + QDialogButtonBox::Close + + + + + + + quantity + bounds + histogram_bins + match_detector_shape + show_all_detectors + detector + show_all_grains + grain_id + + + + + button_box + rejected() + spot_diagnostics_dialog + reject() + + + 700 + 874 + + + 700 + 449 + + + + + diff --git a/hexrdgui/utils/spots.py b/hexrdgui/utils/spots.py new file mode 100644 index 000000000..55a91898d --- /dev/null +++ b/hexrdgui/utils/spots.py @@ -0,0 +1,218 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, TypeAlias + +import numpy as np + +if TYPE_CHECKING: + from collections.abc import Sequence + + from numpy.typing import NDArray + + from hexrd.core.instrument import HEDMInstrument + +# Type alias for the spots_data structure returned by fit_grains(). +# Actual structure: {grain_id: (complvec, {det_key: [spot_list]})} +# Each spot is a list of 9 elements (see pull_spots trimmed output). +# Some callers convert to a list before passing (indexed by position), +# so we accept any int-indexable mapping or sequence. +_SpotsEntry: TypeAlias = 'tuple[list[np.bool_], dict[str, list]]' +SpotsData: TypeAlias = ( + 'dict[int, _SpotsEntry] | list[_SpotsEntry] | Sequence[_SpotsEntry]' +) + +# Per-detector, per-grain array dict used throughout this module. +DetGrainArrays: TypeAlias = 'dict[str, list[np.ndarray]]' + +# The full return type of filter_spots_data (string-keyed for each field). +FilteredSpotsResult: TypeAlias = 'dict[str, dict[str, list[np.ndarray]]]' + + +def filter_spots_data( + spots_data: SpotsData, + instr: HEDMInstrument, + grain_ids: NDArray[np.integer] | Sequence[int], + ome_period: NDArray[np.floating] | None = None, + refit_idx: dict[str, list[NDArray[np.bool_]]] | None = None, +) -> FilteredSpotsResult: + """Filter raw spots data and extract all useful columns in one pass. + + This is the single entry point for parsing the per-grain, per-detector + spot lists produced by ``pull_spots()`` (via ``fit_grains()``). + + Parameters + ---------- + spots_data : dict + ``{grain_id: (complvec, {det_key: [spot_list]})}`` + instr : HEDMInstrument + Instrument with ``.detectors`` dict. + grain_ids : array-like + Grain IDs to process (order is preserved). + ome_period : (2,) array, optional + If given, measured omegas are remapped into this period. + refit_idx : dict, optional + ``{det_key: [bool_array_per_grain]}``. When provided, these + masks are used *instead* of the default valid-reflection / + saturation filter. + + Returns + ------- + dict with the following keys, each mapping to + ``{det_key: [one_array_per_grain]}``: + + - ``'pred_angs'`` : Nx3 predicted [tth, eta, ome] (from col 5) + - ``'meas_angs'`` : Nx3 measured [tth, eta, ome] (from col 6) + - ``'hkls'`` : Nx3 Miller indices (from col 2) + - ``'meas_xy'`` : Nx2 measured detector [x, y] (from col 7) + - ``'pred_xy'`` : Nx2 predicted detector [x, y] (from col 8) + - ``'idx'`` : boolean mask used for filtering + """ + from hexrd.rotations import mapAngle + + result_keys = ('pred_angs', 'meas_angs', 'hkls', 'meas_xy', 'pred_xy', 'idx') + out: FilteredSpotsResult = {k: {} for k in result_keys} + + for det_key, panel in instr.detectors.items(): + for k in result_keys: + out[k][det_key] = [] + + for ig, grain_id in enumerate(grain_ids): + raw = spots_data[grain_id][1][det_key] + data: np.ndarray = np.array(raw, dtype=object) + + if data.size == 0: + out['pred_angs'][det_key].append(np.empty((0, 3))) + out['meas_angs'][det_key].append(np.empty((0, 3))) + out['hkls'][det_key].append(np.empty((0, 3))) + out['meas_xy'][det_key].append(np.empty((0, 2))) + out['pred_xy'][det_key].append(np.empty((0, 2))) + out['idx'][det_key].append(np.empty((0,), dtype=bool)) + continue + + # Determine the filter mask + if refit_idx is None: + valid_reflections = data[:, 0] >= 0 + not_saturated = data[:, 4] < panel.saturation_level + idx = np.logical_and(valid_reflections, not_saturated) + else: + idx = refit_idx[det_key][ig] + + out['idx'][det_key].append(idx) + + if not np.any(idx): + out['pred_angs'][det_key].append(np.empty((0, 3))) + out['meas_angs'][det_key].append(np.empty((0, 3))) + out['hkls'][det_key].append(np.empty((0, 3))) + out['meas_xy'][det_key].append(np.empty((0, 2))) + out['pred_xy'][det_key].append(np.empty((0, 2))) + continue + + out['pred_angs'][det_key].append(np.vstack(data[idx, 5])) + out['meas_angs'][det_key].append(np.vstack(data[idx, 6])) + out['hkls'][det_key].append(np.vstack(data[idx, 2])) + out['meas_xy'][det_key].append(np.vstack(data[idx, 7])) + out['pred_xy'][det_key].append(np.vstack(data[idx, 8])) + + # Remap omegas if requested + if ome_period is not None: + meas_angs = out['meas_angs'][det_key][-1] + meas_angs[:, 2] = mapAngle(meas_angs[:, 2], ome_period) + + return out + + +def extract_spot_angles( + spots_data: SpotsData, + instr: HEDMInstrument, + grain_ids: NDArray[np.integer] | Sequence[int], +) -> tuple[DetGrainArrays, DetGrainArrays]: + """Extract predicted and measured angles from raw spots data. + + Returns + ------- + pred_angs : {det_key: [Nx3 array per grain]} + Predicted [tth, eta, ome]. + meas_angs : {det_key: [Nx3 array per grain]} + Measured [tth, eta, ome]. + """ + out = filter_spots_data(spots_data, instr, grain_ids) + return out['pred_angs'], out['meas_angs'] + + +def extract_spot_xyo( + spots_data: SpotsData, + instr: HEDMInstrument, + grain_ids: NDArray[np.integer] | Sequence[int], +) -> tuple[DetGrainArrays, DetGrainArrays]: + """Extract predicted and measured XY+omega from raw spots data. + + Returns + ------- + xyo_pred : {det_key: [Nx3 array per grain]} + Predicted [x, y, ome]. + xyo_det : {det_key: [Nx3 array per grain]} + Measured [x, y, ome]. + """ + out = filter_spots_data(spots_data, instr, grain_ids) + xyo_pred: DetGrainArrays = {} + xyo_det: DetGrainArrays = {} + + for det_key in out['pred_xy']: + xyo_pred[det_key] = [] + xyo_det[det_key] = [] + for i in range(len(out['pred_xy'][det_key])): + pred_xy: np.ndarray = out['pred_xy'][det_key][i] + meas_xy: np.ndarray = out['meas_xy'][det_key][i] + pred_angs: np.ndarray = out['pred_angs'][det_key][i] + meas_angs: np.ndarray = out['meas_angs'][det_key][i] + + if pred_xy.shape[0] == 0: + xyo_pred[det_key].append(np.empty((0, 3))) + xyo_det[det_key].append(np.empty((0, 3))) + else: + xyo_pred[det_key].append(np.column_stack([pred_xy, pred_angs[:, 2]])) + xyo_det[det_key].append(np.column_stack([meas_xy, meas_angs[:, 2]])) + + return xyo_pred, xyo_det + + +def parse_spots_data( + spots_data: SpotsData, + instr: HEDMInstrument, + grain_ids: NDArray[np.integer] | Sequence[int], + ome_period: NDArray[np.floating] | None = None, + refit_idx: dict[str, list[NDArray[np.bool_]]] | None = None, +) -> tuple[DetGrainArrays, DetGrainArrays, dict[str, list[np.ndarray]]]: + """Parse spots data for calibration, returning hkls, xyo_det, and idx. + + This is the original interface used by the HEDM calibration workflow. + + Returns + ------- + hkls : {det_key: [Nx3 array per grain]} + xyo_det : {det_key: [Nx3 array per grain]} + Measured [x, y, ome]. + idx : {det_key: [bool_array per grain]} + """ + out = filter_spots_data( + spots_data, + instr, + grain_ids, + ome_period=ome_period, + refit_idx=refit_idx, + ) + + # Build xyo_det: [meas_xy, meas_ome] combined into Nx3 + xyo_det: DetGrainArrays = {} + for det_key in out['meas_xy']: + xyo_det[det_key] = [] + for i in range(len(out['meas_xy'][det_key])): + meas_xy: np.ndarray = out['meas_xy'][det_key][i] + meas_angs: np.ndarray = out['meas_angs'][det_key][i] + if meas_xy.shape[0] == 0: + xyo_det[det_key].append(np.empty((0, 3))) + else: + meas_omes = meas_angs[:, 2].reshape(-1, 1) + xyo_det[det_key].append(np.hstack([meas_xy, meas_omes])) + + return out['hkls'], xyo_det, out['idx'] diff --git a/tests/test_spot_diagnostics.py b/tests/test_spot_diagnostics.py new file mode 100644 index 000000000..5a94c304c --- /dev/null +++ b/tests/test_spot_diagnostics.py @@ -0,0 +1,171 @@ +""" +Test the SpotDiagnosticsDialog integration with fit_grains spots data. + +Exercises: + 1. Run fit_grains with return_pull_spots_data=True on NIST ruby example + 2. Extract angles and XY data from spots + 3. Create and interact with SpotDiagnosticsDialog + 4. Verify all quantity/detector/grain selection modes render without error + +Run with: + cd hexrdgui/tests && QT_QPA_PLATFORM=offscreen python -m pytest test_spot_diagnostics.py -v -s +""" + +from __future__ import annotations + +import os +from pathlib import Path +from typing import Any + +import numpy as np +import pytest + +from PySide6.QtWidgets import QApplication + +from hexrd.hedm import config +from hexrd.hedm.fitgrains import fit_grains + +from hexrdgui.calibration.hedm.spot_diagnostics_dialog import ( + QUANTITY_CONFIG, + SpotDiagnosticsDialog, +) +from hexrdgui.utils.spots import extract_spot_angles, extract_spot_xyo + + +@pytest.fixture +def single_ge_path(example_repo_path: Path) -> Path: + return example_repo_path / 'NIST_ruby' / 'single_GE' + + +@pytest.fixture +def fit_grains_with_spots(single_ge_path: Path) -> tuple[Any, list, dict]: + """Run fit_grains and return (instrument, fit_results, spots_data).""" + include_path = single_ge_path / 'include' + config_path = include_path / 'ruby_config.yml' + grains_path = single_ge_path / 'results' / 'ruby-b035e' / 'scan-0' / 'grains.out' + + os.chdir(str(include_path)) + + cfg = config.open(config_path)[0] + cfg.working_dir = str(include_path) + grains_table = np.loadtxt(str(grains_path), ndmin=2) + + fit_results, spots_data = fit_grains( + cfg, + grains_table, + write_spots_files=False, + return_pull_spots_data=True, + ) + + instr = cfg.instrument.hedm + return instr, fit_results, spots_data + + +def test_extract_spot_angles( + fit_grains_with_spots: tuple[Any, list, dict], +) -> None: + """Verify extract_spot_angles returns valid angular data.""" + instr, fit_results, spots_data = fit_grains_with_spots + grain_ids = sorted(spots_data.keys()) + + pred_angs, meas_angs = extract_spot_angles(spots_data, instr, grain_ids) + + det_keys = list(instr.detectors) + assert set(pred_angs.keys()) == set(det_keys) + assert set(meas_angs.keys()) == set(det_keys) + + for det_key in det_keys: + assert len(pred_angs[det_key]) == len(grain_ids) + assert len(meas_angs[det_key]) == len(grain_ids) + + for i in range(len(grain_ids)): + p = pred_angs[det_key][i] + m = meas_angs[det_key][i] + assert p.ndim == 2 and p.shape[1] == 3 + assert m.ndim == 2 and m.shape[1] == 3 + assert p.shape[0] == m.shape[0] + assert p.shape[0] > 0, 'Expected at least some valid spots' + + +def test_extract_spot_xyo( + fit_grains_with_spots: tuple[Any, list, dict], +) -> None: + """Verify extract_spot_xyo returns valid XY+omega data.""" + instr, fit_results, spots_data = fit_grains_with_spots + grain_ids = sorted(spots_data.keys()) + + xyo_pred, xyo_det = extract_spot_xyo(spots_data, instr, grain_ids) + + det_keys = list(instr.detectors) + assert set(xyo_pred.keys()) == set(det_keys) + assert set(xyo_det.keys()) == set(det_keys) + + for det_key in det_keys: + assert len(xyo_pred[det_key]) == len(grain_ids) + assert len(xyo_det[det_key]) == len(grain_ids) + + for i in range(len(grain_ids)): + p = xyo_pred[det_key][i] + m = xyo_det[det_key][i] + assert p.ndim == 2 and p.shape[1] == 3 + assert m.ndim == 2 and m.shape[1] == 3 + assert p.shape[0] == m.shape[0] + assert p.shape[0] > 0 + + +def test_spot_diagnostics_dialog( + qtbot: Any, + fit_grains_with_spots: tuple[Any, list, dict], +) -> None: + """Create SpotDiagnosticsDialog with real fit_grains data and exercise it.""" + instr, fit_results, spots_data = fit_grains_with_spots + grain_ids = sorted(spots_data.keys()) + + dialog = SpotDiagnosticsDialog( + instr=instr, + spots_data=spots_data, + grain_ids=grain_ids, + ) + qtbot.addWidget(dialog.ui) + + # The dialog should have rendered the initial canvas + assert dialog.fig is not None + assert dialog.canvas is not None + + det_keys = list(instr.detectors) + + # Verify combo boxes are populated + assert dialog.ui.quantity.count() == len(QUANTITY_CONFIG) + assert dialog.ui.detector.count() == len(det_keys) + assert dialog.ui.grain_id.count() == len(grain_ids) + + # Exercise every quantity selection to ensure no rendering errors + for i in range(dialog.ui.quantity.count()): + dialog.ui.quantity.setCurrentIndex(i) + QApplication.processEvents() + + # Toggle "show all grains" + dialog.ui.show_all_grains.setChecked(True) + QApplication.processEvents() + dialog.ui.show_all_grains.setChecked(False) + QApplication.processEvents() + + # Toggle "show all detectors" (single detector, but still exercise it) + dialog.ui.show_all_detectors.setChecked(True) + QApplication.processEvents() + dialog.ui.show_all_detectors.setChecked(False) + QApplication.processEvents() + + # Toggle "match detector shape" + dialog.ui.match_detector_shape.setChecked(True) + QApplication.processEvents() + dialog.ui.match_detector_shape.setChecked(False) + QApplication.processEvents() + + # Change histogram bins + dialog.ui.histogram_bins.setValue(30) + QApplication.processEvents() + + # Change bounds + dialog.ui.bounds.setValue(0.05) + QApplication.processEvents()