Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 5 additions & 16 deletions src/silx/gui/data/DataViews.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import numpy
import os

from silx.gui.data.NXdataWidgets import ArrayImagePlot
import silx.io
from silx.gui import qt, icons
from silx.gui.data.TextFormatter import TextFormatter
Expand Down Expand Up @@ -1748,27 +1749,15 @@ def setData(self, data):

self._updateColormap(nxd)

# last two axes are Y & X
img_slicing = slice(-2, None) if not isRgba else slice(-3, -1)
y_axis, x_axis = nxd.axes[img_slicing]
y_label, x_label = nxd.axes_names[img_slicing]
y_scale, x_scale = nxd.plot_style.axes_scale_types[img_slicing]
x_units = get_attr_as_unicode(x_axis, "units") if x_axis else None
y_units = get_attr_as_unicode(y_axis, "units") if y_axis else None

self.getWidget().setImageData(
widget: ArrayImagePlot = self.getWidget()
widget.setImageData(
[nxd.signal] + nxd.auxiliary_signals,
x_axis=x_axis,
y_axis=y_axis,
axes=nxd.axes,
signals_names=[nxd.signal_name] + nxd.auxiliary_signals_names,
axes_names=nxd.axes_names,
xlabel=x_label,
ylabel=y_label,
axes_scales=nxd.plot_style.axes_scale_types,
title=nxd.title,
isRgba=isRgba,
xscale=x_scale,
yscale=y_scale,
keep_ratio=(x_units == y_units),
)

def getDataPriority(self, data, info: DataInfo):
Expand Down
157 changes: 91 additions & 66 deletions src/silx/gui/data/NXdataWidgets.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@
__date__ = "12/11/2018"

import logging
from typing import Literal
import numpy
import h5py

from silx.gui import qt
from silx.gui.data.NumpyAxesSelector import NumpyAxesSelector
Expand All @@ -38,6 +40,8 @@
from silx.gui.colors import Colormap
from silx.gui.data._SignalSelector import SignalSelector

from silx.io.commonh5 import Dataset
from silx.io.nxdata._utils import get_attr_as_unicode
from silx.math.calibration import ArrayCalibration, NoCalibration, LinearCalibration


Expand Down Expand Up @@ -367,14 +371,12 @@ class ArrayImagePlot(qt.QWidget):
Widget for plotting an image from a multi-dimensional signal array
and two 1D axes array.

The signal array can have an arbitrary number of dimensions, the only
limitation being that the last two dimensions must have the same length as
the axes arrays.

Sliders are provided to select indices on the first (n - 2) dimensions of
the signal array, and the plot is updated to show the image corresponding
to the selection.

The dimensions can be changed when the signal array has more than 2 dimensions.

If one or both of the axes does not have regularly spaced values, the
the image is plotted as a coloured scatter plot.
"""
Expand All @@ -388,10 +390,6 @@ def __init__(self, parent=None):

self.__signals = None
self.__signals_names = None
self.__x_axis = None
self.__x_axis_name = None
self.__y_axis = None
self.__y_axis_name = None

self._plot = Plot2D(self)
self._plot.setDefaultColormap(
Expand All @@ -404,10 +402,9 @@ def __init__(self, parent=None):
maskToolWidget = self._plot.getMaskToolsDockWidget().widget()
maskToolWidget.setItemMaskUpdated(True)

# not closable
self._axesSelector = NumpyAxesSelector(self)
self._axesSelector.setNamedAxesSelectorVisibility(False)
self._axesSelector.selectionChanged.connect(self._updateImage)
self._axesSelector.selectedAxisChanged.connect(self._updateImageAxes)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that there are two updating methods:

  • self._updateImage is called when the slider value changes (the axes stay the same)
  • self._updateImageAxes is the new method called when the displayed axis changes. It takes care of resetting the plot properly before slicing the image accordingly.

In normal operation, self._updateImage will always be preceded by a call to self._updateImageAxes.


self._signalSelector = SignalSelector(parent=self)
self._signalSelector.selectionChanged.connect(self._signalChanges)
Expand Down Expand Up @@ -442,7 +439,7 @@ def _aggregationModeChanged(self):
)

def _signalChanges(self, value):
self._updateImage()
self._updateImageAxes()

def getPlot(self):
"""Returns the plot used for the display
Expand All @@ -453,94 +450,103 @@ def getPlot(self):

def setImageData(
self,
signals,
x_axis=None,
y_axis=None,
signals_names=None,
axes_names=None,
xlabel=None,
ylabel=None,
title=None,
isRgba=False,
xscale=None,
yscale=None,
keep_ratio: bool = True,
signals: list[h5py.Dataset | Dataset],
axes: list[h5py.Dataset | Dataset] | None = None,
signals_names: list[str] | None = None,
axes_names: list[str] | None = None,
axes_scales: list[Literal["linear", "log"] | None] | None = None,
title: str | None = None,
isRgba: bool = False,
):
"""
Sets signals, axes and axes metadata that will be used to set the displayed image.

:param signals: list of n-D datasets, whose last 2 dimensions are used as the
image's values, or list of 3D datasets interpreted as RGBA image.
:param x_axis: 1-D dataset used as the image's x coordinates. If
provided, its lengths must be equal to the length of the last
dimension of ``signal``.
:param y_axis: 1-D dataset used as the image's y. If provided,
its lengths must be equal to the length of the 2nd to last
dimension of ``signal``.
:param signals: list of n-D datasets or list of 3D datasets interpreted as RGBA image.
:param axes: list of 1D datasets to be used as axes
:param signals_names: Names for each image, used as subtitle and legend.
:param xlabel: Label for X axis
:param ylabel: Label for Y axis
:param axes_names: Names for each axis, used as graph label.
:param axes_scales: Scale of axes in (None, 'linear', 'log')
:param title: Graph title
:param isRgba: True if data is a 3D RGBA image
:param str xscale: Scale of X axis in (None, 'linear', 'log')
:param str yscale: Scale of Y axis in (None, 'linear', 'log')
:param keep_ratio: Toggle plot keep aspect ratio
"""
self._axesSelector.selectionChanged.disconnect(self._updateImage)
self._axesSelector.selectedAxisChanged.disconnect(self._updateImageAxes)
self._signalSelector.selectionChanged.disconnect(self._signalChanges)

self.__signals = signals
self.__signals_names = signals_names
self.__axis_names = axes_names
self.__x_axis = x_axis
self.__x_axis_name = xlabel
self.__y_axis = y_axis
self.__y_axis_name = ylabel
self.__axes = axes
self.__axes_names = axes_names
self.__axes_scales = axes_scales
self.__title = title

self._axesSelector.clear()

if not isRgba:
self._axesSelector.setAxisNames(["Y", "X"])
self._axesSelector.setNamedAxesSelectorVisibility(True)
img_ndim = 2
else:
self._axesSelector.setAxisNames(["Y", "X", "RGB(A) channel"])
self._axesSelector.setNamedAxesSelectorVisibility(False)
img_ndim = 3
# Labels need to be set before the data
if self.__axes_names:
self._axesSelector.setLabels(self.__axes_names)
self._axesSelector.setData(signals[0])

if len(signals[0].shape) <= img_ndim:
self._axesSelector.hide()
else:
self._axesSelector.show()

if self.__axis_names:
self._axesSelector.setLabels(self.__axis_names)

self._signalSelector.setSignalNames(signals_names)
if len(signals) > 1:
self._signalSelector.show()
else:
self._signalSelector.hide()
self._signalSelector.setSignalIndex(0)

self._axis_scales = xscale, yscale

self._axesSelector.selectionChanged.connect(self._updateImage)
self._axesSelector.selectedAxisChanged.connect(self._updateImageAxes)
self._signalSelector.selectionChanged.connect(self._signalChanges)

self._updateImage()
self._plot.setKeepDataAspectRatio(keep_ratio)
self._updateImageAxes()
self._plot.resetZoom()

def _updateImage(self):
axes_selection = self._axesSelector.selection()
def __getImageToDisplay(self):
signal_index = self._signalSelector.getSignalIndex()
try:
signal = self.__signals[signal_index]
except KeyError:
raise KeyError("No image found. Was an image loaded?")
return signal[self._axesSelector.selection()]

def _updateImageAxes(self):
"""Updates the image axes. Called when the user selects a different axis than the displayed one."""
signal_index = self._signalSelector.getSignalIndex()

legend = self.__signals_names[signal_index]

images = [img[axes_selection] for img in self.__signals]
image = images[signal_index]
image = self.__getImageToDisplay()

x_axis = self.__x_axis
y_axis = self.__y_axis
axis_indices = self._axesSelector.getIndicesOfNamedAxes()
try:
x_axis_index = axis_indices["X"]
y_axis_index = axis_indices["Y"]
except KeyError:
raise KeyError("Axes X and Y not found. Was an image loaded?")

if self.__axes:
x_axis = self.__axes[x_axis_index]
y_axis = self.__axes[y_axis_index]
x_units = get_attr_as_unicode(x_axis, "units") if x_axis else None
y_units = get_attr_as_unicode(y_axis, "units") if y_axis else None
else:
x_axis = None
y_axis = None
x_units = None
y_units = None
self._plot.setKeepDataAspectRatio(x_units == y_units)

if x_axis is None and y_axis is None:
xcalib = NoCalibration()
Expand Down Expand Up @@ -607,7 +613,12 @@ def _updateImage(self):
self._plot.addItem(imageItem)
self._plot.setActiveImage(imageItem)
else:
xaxisscale, yaxisscale = self._axis_scales
if self.__axes_scales:
xaxisscale = self.__axes_scales[x_axis_index]
yaxisscale = self.__axes_scales[y_axis_index]
else:
xaxisscale = None
yaxisscale = None

if xaxisscale is not None:
self._plot.getXAxis().setScale(
Expand All @@ -627,23 +638,37 @@ def _updateImage(self):
legend=legend,
)

if self.__title:
title = self.__title
if len(self.__signals_names) > 1:
# Append dataset name only when there is many datasets
title += "\n" + self.__signals_names[signal_index]
else:
title = self.__signals_names[signal_index]
self._plot.setGraphTitle(title)
self._plot.getXAxis().setLabel(self.__x_axis_name)
self._plot.getYAxis().setLabel(self.__y_axis_name)
self._plot.setGraphTitle(self._graphTitle())
self._plot.getXAxis().setLabel(self.__axes_names[x_axis_index])
self._plot.getYAxis().setLabel(self.__axes_names[y_axis_index])
self._plot.resetZoom()

def clear(self):
old = self._axesSelector.blockSignals(True)
self._axesSelector.clear()
self._axesSelector.blockSignals(old)
self._plot.clear()

def _updateImage(self):
"""Updates the image itself. Called when the user slices through the image without changing the axes."""
image = self.__getImageToDisplay()
activeImageItem = self._plot.getActiveImage()
if activeImageItem:
activeImageItem.setData(image)

def _graphTitle(self) -> str:
signal_index = self._signalSelector.getSignalIndex()
title = self.__title
if not title:
if not self.__signals_names:
return ""
return self.__signals_names[signal_index]

if self.__signals_names and len(self.__signals_names) > 1:
# Append dataset name only when there are many datasets
title += "\n" + self.__signals_names[signal_index]
return title


class ArrayComplexImagePlot(qt.QWidget):
"""
Expand Down
8 changes: 8 additions & 0 deletions src/silx/gui/data/NumpyAxesSelector.py
Original file line number Diff line number Diff line change
Expand Up @@ -595,3 +595,11 @@ def setNamedAxesSelectorVisibility(self, visible):
self.__namedAxesVisibility = visible
for axis in self.__axis:
axis.setNamedAxisSelectorVisibility(visible)

def getIndicesOfNamedAxes(self) -> dict[str, int]:
result: dict[str, int] = {}
for i, axis in enumerate(self.__axis):
name = axis.axisName()
if name:
result[name] = i
return result
Loading