diff --git a/README-pypi.md b/README-pypi.md
index be0de197..f6173669 100644
--- a/README-pypi.md
+++ b/README-pypi.md
@@ -93,6 +93,14 @@ Here you find a series of notebooks that give you an overview of the core featur
Using PyTorch gradients to fit a Gaussian generated by a DeepTrack2 pipeline.
+- DTGS171A **[Creating Custom Scatterers](https://github.com/DeepTrackAI/DeepTrack2/blob/develop/tutorials/1-getting-started/DTGS171A_custom_scatterers.ipynb)**
+
+ Creating custom scatterers of arbitrary shapes.
+
+- DTGS171B **[Creating Custom Scatterers: Bacteria](https://github.com/DeepTrackAI/DeepTrack2/blob/develop/tutorials/1-getting-started/DTGS171B_custom_scatterers_bacteria.ipynb)**
+
+ Creating custom scatterers in the shape of bacteria.
+
# Examples
These are examples of how DeepTrack2 can be used on real datasets:
diff --git a/README.md b/README.md
index 674d41d3..a9f507c2 100644
--- a/README.md
+++ b/README.md
@@ -97,6 +97,14 @@ Here you find a series of notebooks that give you an overview of the core featur
Using PyTorch gradients to fit a Gaussian generated by a DeepTrack2 pipeline.
+- DTGS171A **[Creating Custom Scatterers](https://github.com/DeepTrackAI/DeepTrack2/blob/develop/tutorials/1-getting-started/DTGS171A_custom_scatterers.ipynb)**
+
+ Creating custom scatterers of arbitrary shapes.
+
+- DTGS171B **[Creating Custom Scatterers: Bacteria](https://github.com/DeepTrackAI/DeepTrack2/blob/develop/tutorials/1-getting-started/DTGS171B_custom_scatterers_bacteria.ipynb)**
+
+ Creating custom scatterers in the shape of bacteria.
+
# Examples
These are examples of how DeepTrack2 can be used on real datasets:
diff --git a/deeptrack/features.py b/deeptrack/features.py
index 4bdfad38..702e7362 100644
--- a/deeptrack/features.py
+++ b/deeptrack/features.py
@@ -1,6 +1,6 @@
"""Core features for building and processing pipelines in DeepTrack2.
-The `feasture.py` module defines the core classes and utilities used to create
+The `feature.py` module defines the core classes and utilities used to create
and manipulate features in DeepTrack2, enabling users to build sophisticated
data processing pipelines with modular, reusable, and composable components.
@@ -80,11 +80,8 @@
- `OneOf`: Resolve one feature from a given collection.
- `OneOfDict`: Resolve one feature from a dictionary and apply it to an input.
- `LoadImage`: Load an image from disk and preprocess it.
-- `SampleToMasks`: Create a mask from a list of images.
- `AsType`: Convert the data type of the input.
- `ChannelFirst2d`: DEPRECATED Convert an image to a channel-first format.
-- `Upscale`: Simulate a pipeline at a higher resolution.
-- `NonOverlapping`: Ensure volumes are placed non-overlapping in a 3D space.
- `Store`: Store the output of a feature for reuse.
- `Squeeze`: Squeeze the input to the smallest possible dimension.
- `Unsqueeze`: Unsqueeze the input.
@@ -96,7 +93,7 @@
- `TakeProperties`: Extract all instances of properties from a pipeline.
Arithmetic Feature Classes:
-- `Add`: Add a value to the input.
+- `Add`: Add a value to the input.@dataclass
- `Subtract`: Subtract a value from the input.
- `Multiply`: Multiply the input by a value.
- `Divide`: Divide the input by a value.
@@ -218,11 +215,8 @@
"OneOf",
"OneOfDict",
"LoadImage",
- "SampleToMasks", # TODO ***CM*** revise this after elimination of Image
"AsType",
"ChannelFirst2d",
- "Upscale", # TODO ***CM*** revise and check PyTorch afrer elimin. Image
- "NonOverlapping", # TODO ***CM*** revise + PyTorch afrer elimin. Image
"Store",
"Squeeze",
"Unsqueeze",
@@ -7359,312 +7353,6 @@ def get(
return image
-class SampleToMasks(Feature):
- """Create a mask from a list of images.
-
- This feature applies a transformation function to each input image and
- merges the resulting masks into a single multi-layer image. Each input
- image must have a `position` property that determines its placement within
- the final mask. When used with scatterers, the `voxel_size` property must
- be provided for correct object sizing.
-
- Parameters
- ----------
- transformation_function: Callable[[Image], Image]
- A function that transforms each input image into a mask with
- `number_of_masks` layers.
- number_of_masks: PropertyLike[int], optional
- The number of mask layers to generate. Default is 1.
- output_region: PropertyLike[tuple[int, int, int, int]], optional
- The size and position of the output mask, typically aligned with
- `optics.output_region`.
- merge_method: PropertyLike[str | Callable | list[str | Callable]], optional
- Method for merging individual masks into the final image. Can be:
- - "add" (default): Sum the masks.
- - "overwrite": Later masks overwrite earlier masks.
- - "or": Combine masks using a logical OR operation.
- - "mul": Multiply masks.
- - Function: Custom function taking two images and merging them.
-
- **kwargs: dict[str, Any]
- Additional keyword arguments passed to the parent `Feature` class.
-
- Methods
- -------
- `get(image, transformation_function, **kwargs) -> Image`
- Applies the transformation function to the input image.
- `_process_and_get(images, **kwargs) -> Image | np.ndarray`
- Processes a list of images and generates a multi-layer mask.
-
- Returns
- -------
- Image or np.ndarray
- The final mask image with the specified number of layers.
-
- Raises
- ------
- ValueError
- If `merge_method` is invalid.
-
- Examples
- -------
- >>> import deeptrack as dt
-
- Define number of particles:
-
- >>> n_particles = 12
-
- Define optics and particles:
-
- >>> import numpy as np
- >>>
- >>> optics = dt.Fluorescence(output_region=(0, 0, 64, 64))
- >>> particle = dt.PointParticle(
- >>> position=lambda: np.random.uniform(5, 55, size=2),
- >>> )
- >>> particles = particle ^ n_particles
-
- Define pipelines:
-
- >>> sim_im_pip = optics(particles)
- >>> sim_mask_pip = particles >> dt.SampleToMasks(
- ... lambda: lambda particles: particles > 0,
- ... output_region=optics.output_region,
- ... merge_method="or",
- ... )
- >>> pipeline = sim_im_pip & sim_mask_pip
- >>> pipeline.store_properties()
-
- Generate image and mask:
-
- >>> image, mask = pipeline.update()()
-
- Get particle positions:
-
- >>> positions = np.array(image.get_property("position", get_one=False))
-
- Visualize results:
-
- >>> import matplotlib.pyplot as plt
- >>>
- >>> plt.subplot(1, 2, 1)
- >>> plt.imshow(image, cmap="gray")
- >>> plt.title("Original Image")
- >>> plt.subplot(1, 2, 2)
- >>> plt.imshow(mask, cmap="gray")
- >>> plt.scatter(positions[:,1], positions[:,0], c="y", marker="x", s = 50)
- >>> plt.title("Mask")
- >>> plt.show()
-
- """
-
- def __init__(
- self: Feature,
- transformation_function: Callable[[Image], Image],
- number_of_masks: PropertyLike[int] = 1,
- output_region: PropertyLike[tuple[int, int, int, int]] = None,
- merge_method: PropertyLike[str | Callable | list[str | Callable]] = "add",
- **kwargs: Any,
- ):
- """Initialize the SampleToMasks feature.
-
- Parameters
- ----------
- transformation_function: Callable[[Image], Image]
- Function to transform input images into masks.
- number_of_masks: PropertyLike[int], optional
- Number of mask layers. Default is 1.
- output_region: PropertyLike[tuple[int, int, int, int]], optional
- Output region of the mask. Default is None.
- merge_method: PropertyLike[str | Callable | list[str | Callable]], optional
- Method to merge masks. Defaults to "add".
- **kwargs: dict[str, Any]
- Additional keyword arguments passed to the parent class.
-
- """
-
- super().__init__(
- transformation_function=transformation_function,
- number_of_masks=number_of_masks,
- output_region=output_region,
- merge_method=merge_method,
- **kwargs,
- )
-
- def get(
- self: Feature,
- image: np.ndarray | Image,
- transformation_function: Callable[[Image], Image],
- **kwargs: Any,
- ) -> Image:
- """Apply the transformation function to a single image.
-
- Parameters
- ----------
- image: np.ndarray | Image
- The input image.
- transformation_function: Callable[[Image], Image]
- Function to transform the image.
- **kwargs: dict[str, Any]
- Additional parameters.
-
- Returns
- -------
- Image
- The transformed image.
-
- """
-
- return transformation_function(image)
-
- def _process_and_get(
- self: Feature,
- images: list[np.ndarray] | np.ndarray | list[Image] | Image,
- **kwargs: Any,
- ) -> Image | np.ndarray:
- """Process a list of images and generate a multi-layer mask.
-
- Parameters
- ----------
- images: np.ndarray or list[np.ndarrray] or Image or list[Image]
- List of input images or a single image.
- **kwargs: dict[str, Any]
- Additional parameters including `output_region`, `number_of_masks`,
- and `merge_method`.
-
- Returns
- -------
- Image or np.ndarray
- The final mask image.
-
- """
-
- # Handle list of images.
- if isinstance(images, list) and len(images) != 1:
- list_of_labels = super()._process_and_get(images, **kwargs)
- if not self._wrap_array_with_image:
- for idx, (label, image) in enumerate(zip(list_of_labels,
- images)):
- list_of_labels[idx] = \
- Image(label, copy=False).merge_properties_from(image)
- else:
- if isinstance(images, list):
- images = images[0]
- list_of_labels = []
- for prop in images.properties:
-
- if "position" in prop:
-
- inp = Image(np.array(images))
- inp.append(prop)
- out = Image(self.get(inp, **kwargs))
- out.merge_properties_from(inp)
- list_of_labels.append(out)
-
- # Create an empty output image.
- output_region = kwargs["output_region"]
- output = np.zeros(
- (
- output_region[2] - output_region[0],
- output_region[3] - output_region[1],
- kwargs["number_of_masks"],
- )
- )
-
- from deeptrack.optics import _get_position
-
- # Merge masks into the output.
- for label in list_of_labels:
- position = _get_position(label)
- p0 = np.round(position - output_region[0:2])
-
- if np.any(p0 > output.shape[0:2]) or \
- np.any(p0 + label.shape[0:2] < 0):
- continue
-
- crop_x = int(-np.min([p0[0], 0]))
- crop_y = int(-np.min([p0[1], 0]))
- crop_x_end = int(
- label.shape[0]
- - np.max([p0[0] + label.shape[0] - output.shape[0], 0])
- )
- crop_y_end = int(
- label.shape[1]
- - np.max([p0[1] + label.shape[1] - output.shape[1], 0])
- )
-
- labelarg = label[crop_x:crop_x_end, crop_y:crop_y_end, :]
-
- p0[0] = np.max([p0[0], 0])
- p0[1] = np.max([p0[1], 0])
-
- p0 = p0.astype(int)
-
- output_slice = output[
- p0[0] : p0[0] + labelarg.shape[0],
- p0[1] : p0[1] + labelarg.shape[1],
- ]
-
- for label_index in range(kwargs["number_of_masks"]):
-
- if isinstance(kwargs["merge_method"], list):
- merge = kwargs["merge_method"][label_index]
- else:
- merge = kwargs["merge_method"]
-
- if merge == "add":
- output[
- p0[0] : p0[0] + labelarg.shape[0],
- p0[1] : p0[1] + labelarg.shape[1],
- label_index,
- ] += labelarg[..., label_index]
-
- elif merge == "overwrite":
- output_slice[
- labelarg[..., label_index] != 0, label_index
- ] = labelarg[labelarg[..., label_index] != 0, \
- label_index]
- output[
- p0[0] : p0[0] + labelarg.shape[0],
- p0[1] : p0[1] + labelarg.shape[1],
- label_index,
- ] = output_slice[..., label_index]
-
- elif merge == "or":
- output[
- p0[0] : p0[0] + labelarg.shape[0],
- p0[1] : p0[1] + labelarg.shape[1],
- label_index,
- ] = (output_slice[..., label_index] != 0) | (
- labelarg[..., label_index] != 0
- )
-
- elif merge == "mul":
- output[
- p0[0] : p0[0] + labelarg.shape[0],
- p0[1] : p0[1] + labelarg.shape[1],
- label_index,
- ] *= labelarg[..., label_index]
-
- else:
- # No match, assume function
- output[
- p0[0] : p0[0] + labelarg.shape[0],
- p0[1] : p0[1] + labelarg.shape[1],
- label_index,
- ] = merge(
- output_slice[..., label_index],
- labelarg[..., label_index],
- )
-
- if not self._wrap_array_with_image:
- return output
- output = Image(output)
- for label in list_of_labels:
- output.merge_properties_from(label)
- return output
-
-
class AsType(Feature):
"""Convert the data type of arrays.
@@ -7930,855 +7618,181 @@ def get(
return array
-class Upscale(Feature):
- """Simulate a pipeline at a higher resolution.
+# class Upscale(Feature):
+# """Simulate a pipeline at a higher resolution.
- This feature scales up the resolution of the input pipeline by a specified
- factor, performs computations at the higher resolution, and then
- downsamples the result back to the original size. This is useful for
- simulating effects at a finer resolution while preserving compatibility
- with lower-resolution pipelines.
+# This feature scales up the resolution of the input pipeline by a specified
+# factor, performs computations at the higher resolution, and then
+# downsamples the result back to the original size. This is useful for
+# simulating effects at a finer resolution while preserving compatibility
+# with lower-resolution pipelines.
- Internally, this feature redefines the scale of physical units (e.g.,
- `units.pixel`) to achieve the effect of upscaling. Therefore, it does not
- resize the input image itself but affects only features that rely on
- physical units.
-
- Parameters
- ----------
- feature: Feature
- The pipeline or feature to resolve at a higher resolution.
- factor: int or tuple[int, int, int], optional
- The factor by which to upscale the simulation. If a single integer is
- provided, it is applied uniformly across all axes. If a tuple of three
- integers is provided, each axis is scaled individually. Defaults to 1.
- **kwargs: Any
- Additional keyword arguments passed to the parent `Feature` class.
-
- Attributes
- ----------
- __distributed__: bool
- Always `False` for `Upscale`, indicating that this feature’s `.get()`
- method processes the entire input at once even if it is a list, rather
- than distributing calls for each item of the list.
-
- Methods
- -------
- `get(image, factor, **kwargs) -> np.ndarray | torch.tensor`
- Simulates the pipeline at a higher resolution and returns the result at
- the original resolution.
-
- Notes
- -----
- - This feature does not directly resize the image. Instead, it modifies the
- unit conversions within the pipeline, making physical units smaller,
- which results in more detail being simulated.
- - The final output is downscaled back to the original resolution using
- `block_reduce` from `skimage.measure`.
- - The effect is only noticeable if features use physical units (e.g.,
- `units.pixel`, `units.meter`). Otherwise, the result will be identical.
-
- Examples
- --------
- >>> import deeptrack as dt
-
- Define an optical pipeline and a spherical particle:
-
- >>> optics = dt.Fluorescence()
- >>> particle = dt.Sphere()
- >>> simple_pipeline = optics(particle)
-
- Create an upscaled pipeline with a factor of 4:
-
- >>> upscaled_pipeline = dt.Upscale(optics(particle), factor=4)
-
- Resolve the pipelines:
-
- >>> image = simple_pipeline()
- >>> upscaled_image = upscaled_pipeline()
-
- Visualize the images:
-
- >>> import matplotlib.pyplot as plt
- >>>
- >>> plt.subplot(1, 2, 1)
- >>> plt.imshow(image, cmap="gray")
- >>> plt.title("Original Image")
- >>>
- >>> plt.subplot(1, 2, 2)
- >>> plt.imshow(upscaled_image, cmap="gray")
- >>> plt.title("Simulated at Higher Resolution")
- >>>
- >>> plt.show()
+# Internally, this feature redefines the scale of physical units (e.g.,
+# `units.pixel`) to achieve the effect of upscaling. Therefore, it does not
+# resize the input image itself but affects only features that rely on
+# physical units.
+
+# Parameters
+# ----------
+# feature: Feature
+# The pipeline or feature to resolve at a higher resolution.
+# factor: int or tuple[int, int, int], optional
+# The factor by which to upscale the simulation. If a single integer is
+# provided, it is applied uniformly across all axes. If a tuple of three
+# integers is provided, each axis is scaled individually. Defaults to 1.
+# **kwargs: Any
+# Additional keyword arguments passed to the parent `Feature` class.
+
+# Attributes
+# ----------
+# __distributed__: bool
+# Always `False` for `Upscale`, indicating that this feature’s `.get()`
+# method processes the entire input at once even if it is a list, rather
+# than distributing calls for each item of the list.
+
+# Methods
+# -------
+# `get(image, factor, **kwargs) -> np.ndarray | torch.tensor`
+# Simulates the pipeline at a higher resolution and returns the result at
+# the original resolution.
+
+# Notes
+# -----
+# - This feature does not directly resize the image. Instead, it modifies the
+# unit conversions within the pipeline, making physical units smaller,
+# which results in more detail being simulated.
+# - The final output is downscaled back to the original resolution using
+# `block_reduce` from `skimage.measure`.
+# - The effect is only noticeable if features use physical units (e.g.,
+# `units.pixel`, `units.meter`). Otherwise, the result will be identical.
+
+# Examples
+# --------
+# >>> import deeptrack as dt
+
+# Define an optical pipeline and a spherical particle:
+
+# >>> optics = dt.Fluorescence()
+# >>> particle = dt.Sphere()
+# >>> simple_pipeline = optics(particle)
+
+# Create an upscaled pipeline with a factor of 4:
+
+# >>> upscaled_pipeline = dt.Upscale(optics(particle), factor=4)
- Compare the shapes (both are the same due to downscaling):
-
- >>> print(image.shape)
- (128, 128, 1)
- >>> print(upscaled_image.shape)
- (128, 128, 1)
+# Resolve the pipelines:
+
+# >>> image = simple_pipeline()
+# >>> upscaled_image = upscaled_pipeline()
+
+# Visualize the images:
+
+# >>> import matplotlib.pyplot as plt
+# >>>
+# >>> plt.subplot(1, 2, 1)
+# >>> plt.imshow(image, cmap="gray")
+# >>> plt.title("Original Image")
+# >>>
+# >>> plt.subplot(1, 2, 2)
+# >>> plt.imshow(upscaled_image, cmap="gray")
+# >>> plt.title("Simulated at Higher Resolution")
+# >>>
+# >>> plt.show()
- """
-
- __distributed__: bool = False
-
- feature: Feature
-
- def __init__(
- self: Feature,
- feature: Feature,
- factor: int | tuple[int, int, int] = 1,
- **kwargs: Any,
- ) -> None:
- """Initialize the Upscale feature.
-
- Parameters
- ----------
- feature: Feature
- The pipeline or feature to resolve at a higher resolution.
- factor: int or tuple[int, int, int], optional
- The factor by which to upscale the simulation. If a single integer
- is provided, it is applied uniformly across all axes. If a tuple of
- three integers is provided, each axis is scaled individually.
- Defaults to 1.
- **kwargs: Any
- Additional keyword arguments passed to the parent `Feature` class.
-
- """
-
- super().__init__(factor=factor, **kwargs)
- self.feature = self.add_feature(feature)
-
- def get(
- self: Feature,
- image: np.ndarray | torch.Tensor,
- factor: int | tuple[int, int, int],
- **kwargs: Any,
- ) -> np.ndarray | torch.Tensor:
- """Simulate the pipeline at a higher resolution and return result.
-
- Parameters
- ----------
- image: np.ndarray or torch.Tensor
- The input image to process.
- factor: int or tuple[int, int, int]
- The factor by which to upscale the simulation. If a single integer
- is provided, it is applied uniformly across all axes. If a tuple of
- three integers is provided, each axis is scaled individually.
- **kwargs: Any
- Additional keyword arguments passed to the feature.
-
- Returns
- -------
- np.ndarray or torch.Tensor
- The processed image at the original resolution.
+# Compare the shapes (both are the same due to downscaling):
- Raises
- ------
- ValueError
- If the input `factor` is not a valid integer or tuple of integers.
-
- """
-
- # Ensure factor is a tuple of three integers.
- if np.size(factor) == 1:
- factor = (factor,) * 3
- elif len(factor) != 3:
- raise ValueError(
- "Factor must be an integer or a tuple of three integers."
- )
-
- # Create a context for upscaling and perform computation.
- ctx = create_context(None, None, None, *factor)
- with units.context(ctx):
- image = self.feature(image)
-
- # Downscale the result to the original resolution.
- import skimage.measure
-
- image = skimage.measure.block_reduce(
- image, (factor[0], factor[1]) + (1,) * (image.ndim - 2), np.mean
- )
-
- return image
-
-
-class NonOverlapping(Feature):
- """Ensure volumes are placed non-overlapping in a 3D space.
-
- This feature ensures that a list of 3D volumes are positioned such that
- their non-zero voxels do not overlap. If volumes overlap, their positions
- are resampled until they are non-overlapping. If the maximum number of
- attempts is exceeded, the feature regenerates the list of volumes and
- raises a warning if non-overlapping placement cannot be achieved.
-
- Note: `min_distance` refers to the distance between the edges of volumes,
- not their centers. Due to the way volumes are calculated, slight rounding
- errors may affect the final distance.
+# >>> print(image.shape)
+# (128, 128, 1)
+# >>> print(upscaled_image.shape)
+# (128, 128, 1)
- This feature is incompatible with non-volumetric scatterers such as
- `MieScatterers`.
-
- Parameters
- ----------
- feature: Feature
- The feature that generates the list of volumes to place
- non-overlapping.
- min_distance: float, optional
- The minimum distance between volumes in pixels. It can be negative to
- allow for partial overlap. Defaults to 1.
- max_attempts: int, optional
- The maximum number of attempts to place volumes without overlap.
- Defaults to 5.
- max_iters: int, optional
- The maximum number of resamplings. If this number is exceeded, a new
- list of volumes is generated. Defaults to 100.
-
- Attributes
- ----------
- __distributed__: bool
- Always `False` for `NonOverlapping`, indicating that this feature’s
- `.get()` method processes the entire input at once even if it is a
- list, rather than distributing calls for each item of the list.N
-
- Methods
- -------
- `get(*_, min_distance, max_attempts, **kwargs) -> array`
- Generate a list of non-overlapping 3D volumes.
- `_check_non_overlapping(list_of_volumes) -> bool`
- Check if all volumes in the list are non-overlapping.
- `_check_bounding_cubes_non_overlapping(...) -> bool`
- Check if two bounding cubes are non-overlapping.
- `_get_overlapping_cube(...) -> list[int]`
- Get the overlapping cube between two bounding cubes.
- `_get_overlapping_volume(...) -> array`
- Get the overlapping volume between a volume and a bounding cube.
- `_check_volumes_non_overlapping(...) -> bool`
- Check if two volumes are non-overlapping.
- `_resample_volume_position(volume) -> Image`
- Resample the position of a volume to avoid overlap.
-
- Notes
- -----
- - This feature performs bounding cube checks first to quickly reject
- obvious overlaps before voxel-level checks.
- - If the bounding cubes overlap, precise voxel-based checks are performed.
-
- Examples
- ---------
- >>> import deeptrack as dt
-
- Define an ellipse scatterer with randomly positioned objects:
-
- >>> import numpy as np
- >>>
- >>> scatterer = dt.Ellipse(
- >>> radius= 13 * dt.units.pixels,
- >>> position=lambda: np.random.uniform(5, 115, size=2)* dt.units.pixels,
- >>> )
-
- Create multiple scatterers:
-
- >>> scatterers = (scatterer ^ 8)
-
- Define the optics and create the image with possible overlap:
-
- >>> optics = dt.Fluorescence()
- >>> im_with_overlap = optics(scatterers)
- >>> im_with_overlap.store_properties()
- >>> im_with_overlap_resolved = image_with_overlap()
-
- Gather position from image:
-
- >>> pos_with_overlap = np.array(
- >>> im_with_overlap_resolved.get_property(
- >>> "position",
- >>> get_one=False
- >>> )
- >>> )
-
- Enforce non-overlapping and create the image without overlap:
-
- >>> non_overlapping_scatterers = dt.NonOverlapping(
- ... scatterers,
- ... min_distance=4,
- ... )
- >>> im_without_overlap = optics(non_overlapping_scatterers)
- >>> im_without_overlap.store_properties()
- >>> im_without_overlap_resolved = im_without_overlap()
-
- Gather position from image:
-
- >>> pos_without_overlap = np.array(
- >>> im_without_overlap_resolved.get_property(
- >>> "position",
- >>> get_one=False
- >>> )
- >>> )
-
- Create a figure with two subplots to visualize the difference:
-
- >>> import matplotlib.pyplot as plt
- >>>
- >>> fig, axes = plt.subplots(1, 2, figsize=(10, 5))
- >>>
- >>> axes[0].imshow(im_with_overlap_resolved, cmap="gray")
- >>> axes[0].scatter(pos_with_overlap[:,1],pos_with_overlap[:,0])
- >>> axes[0].set_title("Overlapping Objects")
- >>> axes[0].axis("off")
- >>>
- >>> axes[1].imshow(im_without_overlap_resolved, cmap="gray")
- >>> axes[1].scatter(pos_without_overlap[:,1],pos_without_overlap[:,0])
- >>> axes[1].set_title("Non-Overlapping Objects")
- >>> axes[1].axis("off")
- >>> plt.tight_layout()
- >>>
- >>> plt.show()
-
- Define function to calculate minimum distance:
-
- >>> def calculate_min_distance(positions):
- >>> distances = [
- >>> np.linalg.norm(positions[i] - positions[j])
- >>> for i in range(len(positions))
- >>> for j in range(i + 1, len(positions))
- >>> ]
- >>> return min(distances)
-
- Print minimum distances with and without overlap:
-
- >>> print(calculate_min_distance(pos_with_overlap))
- 10.768742383382174
-
- >>> print(calculate_min_distance(pos_without_overlap))
- 30.82531120942446
-
- """
-
- __distributed__: bool = False
-
- def __init__(
- self: NonOverlapping,
- feature: Feature,
- min_distance: float = 1,
- max_attempts: int = 5,
- max_iters: int = 100,
- **kwargs: Any,
- ):
- """Initializes the NonOverlapping feature.
-
- Ensures that volumes are placed **non-overlapping** by iteratively
- resampling their positions. If the maximum number of attempts is
- exceeded, the feature regenerates the list of volumes.
-
- Parameters
- ----------
- feature: Feature
- The feature that generates the list of volumes.
- min_distance: float, optional
- The minimum separation distance **between volume edges**, in
- pixels. It defaults to `1`. Negative values allow for partial
- overlap.
- max_attempts: int, optional
- The maximum number of attempts to place the volumes without
- overlap. It defaults to `5`.
- max_iters: int, optional
- The maximum number of resampling iterations per attempt. If
- exceeded, a new list of volumes is generated. It defaults to `100`.
-
- """
-
- super().__init__(
- min_distance=min_distance,
- max_attempts=max_attempts,
- max_iters=max_iters,
- **kwargs,
- )
- self.feature = self.add_feature(feature, **kwargs)
-
- def get(
- self: NonOverlapping,
- *_: Any,
- min_distance: float,
- max_attempts: int,
- max_iters: int,
- **kwargs: Any,
- ) -> list[np.ndarray]:
- """Generates a list of non-overlapping 3D volumes within a defined
- field of view (FOV).
-
- This method **iteratively** attempts to place volumes while ensuring
- they maintain at least `min_distance` separation. If non-overlapping
- placement is not achieved within `max_attempts`, a warning is issued,
- and the best available configuration is returned.
-
- Parameters
- ----------
- _: Any
- Placeholder parameter, typically for an input image.
- min_distance: float
- The minimum required separation distance between volumes, in
- pixels.
- max_attempts: int
- The maximum number of attempts to generate a valid non-overlapping
- configuration.
- max_iters: int
- The maximum number of resampling iterations per attempt.
- **kwargs: Any
- Additional parameters that may be used by subclasses.
-
- Returns
- -------
- list[np.ndarray]
- A list of 3D volumes represented as NumPy arrays. If
- non-overlapping placement is unsuccessful, the best available
- configuration is returned.
-
- Warns
- -----
- UserWarning
- If non-overlapping placement is **not** achieved within
- `max_attempts`, suggesting parameter adjustments such as increasing
- the FOV or reducing `min_distance`.
-
- Notes
- -----
- - The placement process prioritizes bounding cube checks for
- efficiency.
- - If bounding cubes overlap, voxel-based overlap checks are performed.
+# """
+
+# __distributed__: bool = False
+
+# feature: Feature
+
+# def __init__(
+# self: Feature,
+# feature: Feature,
+# factor: int | tuple[int, int, int] = 1,
+# **kwargs: Any,
+# ) -> None:
+# """Initialize the Upscale feature.
+
+# Parameters
+# ----------
+# feature: Feature
+# The pipeline or feature to resolve at a higher resolution.
+# factor: int or tuple[int, int, int], optional
+# The factor by which to upscale the simulation. If a single integer
+# is provided, it is applied uniformly across all axes. If a tuple of
+# three integers is provided, each axis is scaled individually.
+# Defaults to 1.
+# **kwargs: Any
+# Additional keyword arguments passed to the parent `Feature` class.
+
+# """
+
+# super().__init__(factor=factor, **kwargs)
+# self.feature = self.add_feature(feature)
+
+# def get(
+# self: Feature,
+# image: np.ndarray | torch.Tensor,
+# factor: int | tuple[int, int, int],
+# **kwargs: Any,
+# ) -> np.ndarray | torch.Tensor:
+# """Simulate the pipeline at a higher resolution and return result.
+
+# Parameters
+# ----------
+# image: np.ndarray or torch.Tensor
+# The input image to process.
+# factor: int or tuple[int, int, int]
+# The factor by which to upscale the simulation. If a single integer
+# is provided, it is applied uniformly across all axes. If a tuple of
+# three integers is provided, each axis is scaled individually.
+# **kwargs: Any
+# Additional keyword arguments passed to the feature.
+
+# Returns
+# -------
+# np.ndarray or torch.Tensor
+# The processed image at the original resolution.
+
+# Raises
+# ------
+# ValueError
+# If the input `factor` is not a valid integer or tuple of integers.
+
+# """
+
+# # Ensure factor is a tuple of three integers.
+# if np.size(factor) == 1:
+# factor = (factor, factor, 1)
+# elif len(factor) != 3:
+# raise ValueError(
+# "Factor must be an integer or a tuple of three integers."
+# )
- """
-
- for _ in range(max_attempts):
- list_of_volumes = self.feature()
-
- if not isinstance(list_of_volumes, list):
- list_of_volumes = [list_of_volumes]
+# # Create a context for upscaling and perform computation.
+# ctx = create_context(None, None, None, *factor)
- for _ in range(max_iters):
-
- list_of_volumes = [
- self._resample_volume_position(volume)
- for volume in list_of_volumes
- ]
-
- if self._check_non_overlapping(list_of_volumes):
- return list_of_volumes
-
- # Generate a new list of volumes if max_attempts is exceeded.
- self.feature.update()
-
- warnings.warn(
- "Non-overlapping placement could not be achieved. Consider "
- "adjusting parameters: reduce object radius, increase FOV, "
- "or decrease min_distance.",
- UserWarning,
- )
- return list_of_volumes
+# print('before:', image)
+# with units.context(ctx):
+# image = self.feature(image)
- def _check_non_overlapping(
- self: NonOverlapping,
- list_of_volumes: list[np.ndarray],
- ) -> bool:
- """Determines whether all volumes in the provided list are
- non-overlapping.
+# print('after:', image)
+# # Downscale the result to the original resolution.
+# import skimage.measure
- This method verifies that the non-zero voxels of each 3D volume in
- `list_of_volumes` are at least `min_distance` apart. It first checks
- bounding boxes for early rejection and then examines actual voxel
- overlap when necessary. Volumes are assumed to have a `position`
- attribute indicating their placement in 3D space.
-
- Parameters
- ----------
- list_of_volumes: list[np.ndarray]
- A list of 3D arrays representing the volumes to be checked for
- overlap. Each volume is expected to have a position attribute.
-
- Returns
- -------
- bool
- `True` if all volumes are non-overlapping, otherwise `False`.
-
- Notes
- -----
- - If `min_distance` is negative, volumes are shrunk using isotropic
- erosion before checking overlap.
- - If `min_distance` is positive, volumes are padded and expanded using
- isotropic dilation.
- - Overlapping checks are first performed on bounding cubes for
- efficiency.
- - If bounding cubes overlap, voxel-level checks are performed.
-
- """
-
- from skimage.morphology import isotropic_erosion, isotropic_dilation
-
- from deeptrack.augmentations import CropTight, Pad
- from deeptrack.optics import _get_position
-
- min_distance = self.min_distance()
- crop = CropTight()
-
- if min_distance < 0:
- list_of_volumes = [
- Image(
- crop(isotropic_erosion(volume != 0, -min_distance/2)),
- copy=False,
- ).merge_properties_from(volume)
- for volume in list_of_volumes
- ]
- else:
- pad = Pad(px = [int(np.ceil(min_distance/2))]*6, keep_size=True)
- list_of_volumes = [
- Image(
- crop(isotropic_dilation(pad(volume) != 0, min_distance/2)),
- copy=False,
- ).merge_properties_from(volume)
- for volume in list_of_volumes
- ]
- min_distance = 1
-
- # The position of the top left corner of each volume (index (0, 0, 0)).
- volume_positions_1 = [
- _get_position(volume, mode="corner", return_z=True).astype(int)
- for volume in list_of_volumes
- ]
-
- # The position of the bottom right corner of each volume
- # (index (-1, -1, -1)).
- volume_positions_2 = [
- p0 + np.array(v.shape)
- for v, p0 in zip(list_of_volumes, volume_positions_1)
- ]
-
- # (x1, y1, z1, x2, y2, z2) for each volume.
- volume_bounding_cube = [
- [*p0, *p1]
- for p0, p1 in zip(volume_positions_1, volume_positions_2)
- ]
-
- for i, j in itertools.combinations(range(len(list_of_volumes)), 2):
-
- # If the bounding cubes do not overlap, the volumes do not overlap.
- if self._check_bounding_cubes_non_overlapping(
- volume_bounding_cube[i], volume_bounding_cube[j], min_distance
- ):
- continue
-
- # If the bounding cubes overlap, get the overlapping region of each
- # volume.
- overlapping_cube = self._get_overlapping_cube(
- volume_bounding_cube[i], volume_bounding_cube[j]
- )
- overlapping_volume_1 = self._get_overlapping_volume(
- list_of_volumes[i], volume_bounding_cube[i], overlapping_cube
- )
- overlapping_volume_2 = self._get_overlapping_volume(
- list_of_volumes[j], volume_bounding_cube[j], overlapping_cube
- )
-
- # If either the overlapping regions are empty, the volumes do not
- # overlap (done for speed).
- if (np.all(overlapping_volume_1 == 0)
- or np.all(overlapping_volume_2 == 0)):
- continue
-
- # If products of overlapping regions are non-zero, return False.
- # if np.any(overlapping_volume_1 * overlapping_volume_2):
- # return False
-
- # Finally, check that the non-zero voxels of the volumes are at
- # least min_distance apart.
- if not self._check_volumes_non_overlapping(
- overlapping_volume_1, overlapping_volume_2, min_distance
- ):
- return False
-
- return True
-
- def _check_bounding_cubes_non_overlapping(
- self: NonOverlapping,
- bounding_cube_1: list[int],
- bounding_cube_2: list[int],
- min_distance: float,
- ) -> bool:
- """Determines whether two 3D bounding cubes are non-overlapping.
-
- This method checks whether the bounding cubes of two volumes are
- **separated by at least** `min_distance` along **any** spatial axis.
-
- Parameters
- ----------
- bounding_cube_1: list[int]
- A list of six integers `[x1, y1, z1, x2, y2, z2]` representing
- the first bounding cube.
- bounding_cube_2: list[int]
- A list of six integers `[x1, y1, z1, x2, y2, z2]` representing
- the second bounding cube.
- min_distance: float
- The required **minimum separation distance** between the two
- bounding cubes.
-
- Returns
- -------
- bool
- `True` if the bounding cubes are non-overlapping (separated by at
- least `min_distance` along **at least one axis**), otherwise
- `False`.
-
- Notes
- -----
- - This function **only checks bounding cubes**, **not actual voxel
- data**.
- - If the bounding cubes are non-overlapping, the corresponding
- **volumes are also non-overlapping**.
- - This check is much **faster** than full voxel-based comparisons.
-
- """
-
- # bounding_cube_1 and bounding_cube_2 are (x1, y1, z1, x2, y2, z2).
- # Check that the bounding cubes are non-overlapping.
- return (
- (bounding_cube_1[0] >= bounding_cube_2[3] + min_distance) or
- (bounding_cube_2[0] >= bounding_cube_1[3] + min_distance) or
- (bounding_cube_1[1] >= bounding_cube_2[4] + min_distance) or
- (bounding_cube_2[1] >= bounding_cube_1[4] + min_distance) or
- (bounding_cube_1[2] >= bounding_cube_2[5] + min_distance) or
- (bounding_cube_2[2] >= bounding_cube_1[5] + min_distance)
- )
-
- def _get_overlapping_cube(
- self: NonOverlapping,
- bounding_cube_1: list[int],
- bounding_cube_2: list[int],
- ) -> list[int]:
- """Computes the overlapping region between two 3D bounding cubes.
-
- This method calculates the coordinates of the intersection of two
- axis-aligned bounding cubes, each represented as a list of six
- integers:
-
- - `[x1, y1, z1]`: Coordinates of the **top-left-front** corner.
- - `[x2, y2, z2]`: Coordinates of the **bottom-right-back** corner.
-
- The resulting overlapping region is determined by:
- - Taking the **maximum** of the starting coordinates (`x1, y1, z1`).
- - Taking the **minimum** of the ending coordinates (`x2, y2, z2`).
-
- If the cubes **do not** overlap, the resulting coordinates will not
- form a valid cube (i.e., `x1 > x2`, `y1 > y2`, or `z1 > z2`).
-
- Parameters
- ----------
- bounding_cube_1: list[int]
- The first bounding cube, formatted as `[x1, y1, z1, x2, y2, z2]`.
- bounding_cube_2: list[int]
- The second bounding cube, formatted as `[x1, y1, z1, x2, y2, z2]`.
-
- Returns
- -------
- list[int]
- A list of six integers `[x1, y1, z1, x2, y2, z2]` representing the
- overlapping bounding cube. If no overlap exists, the coordinates
- will **not** define a valid cube.
-
- Notes
- -----
- - This function does **not** check for valid input or ensure the
- resulting cube is well-formed.
- - If no overlap exists, downstream functions must handle the invalid
- result.
-
- """
-
- return [
- max(bounding_cube_1[0], bounding_cube_2[0]),
- max(bounding_cube_1[1], bounding_cube_2[1]),
- max(bounding_cube_1[2], bounding_cube_2[2]),
- min(bounding_cube_1[3], bounding_cube_2[3]),
- min(bounding_cube_1[4], bounding_cube_2[4]),
- min(bounding_cube_1[5], bounding_cube_2[5]),
- ]
-
- def _get_overlapping_volume(
- self: NonOverlapping,
- volume: np.ndarray, # 3D array.
- bounding_cube: tuple[float, float, float, float, float, float],
- overlapping_cube: tuple[float, float, float, float, float, float],
- ) -> np.ndarray:
- """Extracts the overlapping region of a 3D volume within the specified
- overlapping cube.
-
- This method identifies and returns the subregion of `volume` that
- lies within the `overlapping_cube`. The bounding information of the
- volume is provided via `bounding_cube`.
-
- Parameters
- ----------
- volume: np.ndarray
- A 3D NumPy array representing the volume from which the
- overlapping region is extracted.
- bounding_cube: tuple[float, float, float, float, float, float]
- The bounding cube of the volume, given as a tuple of six floats:
- `(x1, y1, z1, x2, y2, z2)`. The first three values define the
- **top-left-front** corner, while the last three values define the
- **bottom-right-back** corner.
- overlapping_cube: tuple[float, float, float, float, float, float]
- The overlapping region between the volume and another volume,
- represented in the same format as `bounding_cube`.
-
- Returns
- -------
- np.ndarray
- A 3D NumPy array representing the portion of `volume` that
- lies within `overlapping_cube`. If the overlap does not exist,
- an empty array may be returned.
-
- Notes
- -----
- - The method computes the relative indices of `overlapping_cube`
- within `volume` by subtracting the bounding cube's starting
- position.
- - The extracted region is determined by integer indices, meaning
- coordinates are implicitly **floored to integers**.
- - If `overlapping_cube` extends beyond `volume` boundaries, the
- returned subregion is **cropped** to fit within `volume`.
-
- """
-
- # The position of the top left corner of the overlapping cube in the volume
- overlapping_cube_position = np.array(overlapping_cube[:3]) - np.array(
- bounding_cube[:3]
- )
-
- # The position of the bottom right corner of the overlapping cube in the volume
- overlapping_cube_end_position = np.array(
- overlapping_cube[3:]
- ) - np.array(bounding_cube[:3])
-
- # cast to int
- overlapping_cube_position = overlapping_cube_position.astype(int)
- overlapping_cube_end_position = overlapping_cube_end_position.astype(int)
-
- return volume[
- overlapping_cube_position[0] : overlapping_cube_end_position[0],
- overlapping_cube_position[1] : overlapping_cube_end_position[1],
- overlapping_cube_position[2] : overlapping_cube_end_position[2],
- ]
-
- def _check_volumes_non_overlapping(
- self: NonOverlapping,
- volume_1: np.ndarray,
- volume_2: np.ndarray,
- min_distance: float,
- ) -> bool:
- """Determines whether the non-zero voxels in two 3D volumes are at
- least `min_distance` apart.
-
- This method checks whether the active regions (non-zero voxels) in
- `volume_1` and `volume_2` maintain a minimum separation of
- `min_distance`. If the volumes differ in size, the positions of their
- non-zero voxels are adjusted accordingly to ensure a fair comparison.
-
- Parameters
- ----------
- volume_1: np.ndarray
- A 3D NumPy array representing the first volume.
- volume_2: np.ndarray
- A 3D NumPy array representing the second volume.
- min_distance: float
- The minimum Euclidean distance required between any two non-zero
- voxels in the two volumes.
-
- Returns
- -------
- bool
- `True` if all non-zero voxels in `volume_1` and `volume_2` are at
- least `min_distance` apart, otherwise `False`.
-
- Notes
- -----
- - This function assumes both volumes are correctly aligned within a
- shared coordinate space.
- - If the volumes are of different sizes, voxel positions are scaled
- or adjusted for accurate distance measurement.
- - Uses **Euclidean distance** for separation checking.
- - If either volume is empty (i.e., no non-zero voxels), they are
- considered non-overlapping.
-
- """
-
- # Get the positions of the non-zero voxels of each volume.
- positions_1 = np.argwhere(volume_1)
- positions_2 = np.argwhere(volume_2)
-
- # if positions_1.size == 0 or positions_2.size == 0:
- # return True # If either volume is empty, they are "non-overlapping"
-
- # # If the volumes are not the same size, the positions of the non-zero
- # # voxels of each volume need to be scaled.
- # if positions_1.size == 0 or positions_2.size == 0:
- # return True # If either volume is empty, they are "non-overlapping"
-
- # If the volumes are not the same size, the positions of the non-zero
- # voxels of each volume need to be scaled.
- if volume_1.shape != volume_2.shape:
- positions_1 = (
- positions_1 * np.array(volume_2.shape)
- / np.array(volume_1.shape)
- )
- positions_1 = positions_1.astype(int)
-
- # Check that the non-zero voxels of the volumes are at least
- # min_distance apart.
- return np.all(
- cdist(positions_1, positions_2) > min_distance
- )
-
- def _resample_volume_position(
- self: NonOverlapping,
- volume: np.ndarray | Image,
- ) -> Image:
- """Resamples the position of a 3D volume using its internal position
- sampler.
-
- This method updates the `position` property of the given `volume` by
- drawing a new position from the `_position_sampler` stored in the
- volume's `properties`. If the sampled position is a `Quantity`, it is
- converted to pixel units.
-
- Parameters
- ----------
- volume: np.ndarray or Image
- The 3D volume whose position is to be resampled. The volume must
- have a `properties` attribute containing dictionaries with
- `position` and `_position_sampler` keys.
-
- Returns
- -------
- Image
- The same input volume with its `position` property updated to the
- newly sampled value.
-
- Notes
- -----
- - The `_position_sampler` function is expected to return a **tuple of
- three floats** (e.g., `(x, y, z)`).
- - If the sampled position is a `Quantity`, it is converted to pixels.
- - **Only** dictionaries in `volume.properties` that contain both
- `position` and `_position_sampler` keys are modified.
-
- """
+# image = skimage.measure.block_reduce(
+# image, (factor[0], factor[1]) + (1,) * (image.ndim - 2), np.mean
+# )
- for pdict in volume.properties:
- if "position" in pdict and "_position_sampler" in pdict:
- new_position = pdict["_position_sampler"]()
- if isinstance(new_position, Quantity):
- new_position = new_position.to("pixel").magnitude
- pdict["position"] = new_position
+# return image
- return volume
class Store(Feature):
@@ -9593,4 +8607,4 @@ def get(
if len(res) == 1:
res = res[0]
- return res
+ return res
\ No newline at end of file
diff --git a/deeptrack/holography.py b/deeptrack/holography.py
index 380969cf..141cc540 100644
--- a/deeptrack/holography.py
+++ b/deeptrack/holography.py
@@ -101,7 +101,7 @@ def get_propagation_matrix(
def get_propagation_matrix(
shape: tuple[int, int],
to_z: float,
- pixel_size: float,
+ pixel_size: float | tuple[float, float],
wavelength: float,
dx: float = 0,
dy: float = 0
@@ -118,8 +118,8 @@ def get_propagation_matrix(
The dimensions of the optical field (height, width).
to_z: float
Propagation distance along the z-axis.
- pixel_size: float
- The physical size of each pixel in the optical field.
+ pixel_size: float | tuple[float, float]
+ Physical pixel size. If scalar, isotropic pixels are assumed.
wavelength: float
The wavelength of the optical field.
dx: float, optional
@@ -140,14 +140,22 @@ def get_propagation_matrix(
"""
+ if pixel_size is None:
+ pixel_size = get_active_voxel_size()
+
+ if np.isscalar(pixel_size):
+ pixel_size = (pixel_size, pixel_size)
+
+ px, py = pixel_size
+
k = 2 * np.pi / wavelength
yr, xr, *_ = shape
x = np.arange(0, xr, 1) - xr / 2 + (xr % 2) / 2
y = np.arange(0, yr, 1) - yr / 2 + (yr % 2) / 2
- x = 2 * np.pi / pixel_size * x / xr
- y = 2 * np.pi / pixel_size * y / yr
+ x = 2 * np.pi / px * x / xr
+ y = 2 * np.pi / py * y / yr
KXk, KYk = np.meshgrid(x, y)
KXk = KXk.astype(complex)
diff --git a/deeptrack/math.py b/deeptrack/math.py
index 05cbf311..2f8b7339 100644
--- a/deeptrack/math.py
+++ b/deeptrack/math.py
@@ -93,23 +93,24 @@
from __future__ import annotations
-from typing import Any, Callable, TYPE_CHECKING
+from typing import Any, Callable, Dict, Tuple, TYPE_CHECKING
import array_api_compat as apc
import numpy as np
-from numpy.typing import NDArray
+from numpy.typing import NDArray #TODO TBE
from scipy import ndimage
import skimage
import skimage.measure
from deeptrack import utils, OPENCV_AVAILABLE, TORCH_AVAILABLE
from deeptrack.features import Feature
-from deeptrack.image import Image, strip
-from deeptrack.types import ArrayLike, PropertyLike
+from deeptrack.image import Image, strip #TODO TBE
+from deeptrack.types import PropertyLike
from deeptrack.backend import xp
if TORCH_AVAILABLE:
import torch
+ import torch.nn.functional as F
if OPENCV_AVAILABLE:
import cv2
@@ -133,7 +134,6 @@
"BilateralBlur",
]
-
if TYPE_CHECKING:
import torch
@@ -227,10 +227,10 @@ def __init__(
def get(
self: Average,
- images: list[NDArray[Any] | torch.Tensor | Image],
+ images: list[np.ndarray | torch.Tensor],
axis: int | tuple[int],
**kwargs: Any,
- ) -> NDArray[Any] | torch.Tensor | Image:
+ ) -> np.ndarray | torch.Tensor:
"""Compute the average of input images along the specified axis(es).
This method computes the average of the input images along the
@@ -318,11 +318,11 @@ def __init__(
def get(
self: Clip,
- image: NDArray[Any] | torch.Tensor | Image,
+ image: np.ndarray | torch.Tensor,
min: float,
max: float,
**kwargs: Any,
- ) -> NDArray[Any] | torch.Tensor | Image:
+ ) -> np.ndarray | torch.Tensor:
"""Clips the input image within the specified values.
This method clips the input image within the specified minimum and
@@ -363,8 +363,7 @@ class NormalizeMinMax(Feature):
max: float, optional
Upper bound of the transformation. It defaults to 1.
featurewise: bool, optional
- Whether to normalize each feature independently. It default to `True`,
- which is the only behavior currently implemented.
+ Whether to normalize each feature independently. It default to `True`.
Methods
-------
@@ -390,8 +389,6 @@ class NormalizeMinMax(Feature):
"""
- #TODO ___??___ Implement the `featurewise=False` option
-
def __init__(
self: NormalizeMinMax,
min: PropertyLike[float] = 0,
@@ -418,32 +415,47 @@ def __init__(
def get(
self: NormalizeMinMax,
- image: ArrayLike,
+ image: np.ndarray | torch.Tensor,
min: float,
max: float,
+ featurewise: bool = True,
**kwargs: Any,
- ) -> ArrayLike:
+ ) -> np.ndarray | torch.Tensor:
"""Normalize the input to fall between `min` and `max`.
Parameters
----------
- image: array
+ image: np.ndarray or torch.Tensor
Input image to normalize.
min: float
Lower bound of the output range.
max: float
Upper bound of the output range.
+ featurewise: bool
+ Whether to normalize each feature (channel) independently.
Returns
-------
- array
+ np.ndarray or torch.Tensor
Min-max normalized image.
"""
- ptp = xp.max(image) - xp.min(image)
- image = image / ptp * (max - min)
- image = image - xp.min(image) + min
+ if featurewise:
+ # Normalize per feature (last axis)
+ axis = tuple(range(image.ndim - 1))
+
+ img_min = xp.min(image, axis=axis, keepdims=True)
+ img_max = xp.max(image, axis=axis, keepdims=True)
+ else:
+ # Normalize globally
+ img_min = xp.min(image)
+ img_max = xp.max(image)
+
+ ptp = img_max - img_min
+
+ # Avoid division by zero
+ image = (image - img_min) / ptp * (max - min) + min
try:
image[xp.isnan(image)] = 0
@@ -487,8 +499,6 @@ class NormalizeStandard(Feature):
"""
- #TODO ___??___ Implement the `featurewise=False` option
-
def __init__(
self: NormalizeStandard,
featurewise: PropertyLike[bool] = True,
@@ -511,33 +521,52 @@ def __init__(
def get(
self: NormalizeStandard,
- image: NDArray[Any] | torch.Tensor | Image,
+ image: np.ndarray | torch.Tensor,
+ featurewise: bool,
**kwargs: Any,
- ) -> NDArray[Any] | torch.Tensor | Image:
+ ) -> np.ndarray | torch.Tensor:
"""Normalizes the input image to have mean 0 and standard deviation 1.
- This method normalizes the input image to have mean 0 and standard
- deviation 1.
-
Parameters
----------
- image: array
+ image: np.ndarray or torch.Tensor
The input image to normalize.
+ featurewise: bool
+ Whether to normalize each feature (channel) independently.
Returns
-------
- array
- The normalized image.
-
+ np.ndarray or torch.Tensor
+ The standardized image.
"""
- if apc.is_torch_array(image):
- # By default, torch.std() is unbiased, i.e., divides by N-1
- return (
- (image - torch.mean(image)) / torch.std(image, unbiased=False)
- )
+ if featurewise:
+ # Normalize per feature (last axis)
+ axis = tuple(range(image.ndim - 1))
+
+ mean = xp.mean(image, axis=axis, keepdims=True)
+
+ if apc.is_torch_array(image):
+ std = torch.std(image, dim=axis, keepdim=True, unbiased=False)
+ else:
+ std = xp.std(image, axis=axis)
+ else:
+ # Normalize globally
+ mean = xp.mean(image)
+
+ if apc.is_torch_array(image):
+ std = torch.std(image, unbiased=False)
+ else:
+ std = xp.std(image)
+
+ image = (image - mean) / std
+
+ try:
+ image[xp.isnan(image)] = 0
+ except TypeError:
+ pass
- return (image - xp.mean(image)) / xp.std(image)
+ return image
class NormalizeQuantile(Feature):
@@ -609,158 +638,164 @@ def __init__(
def get(
self: NormalizeQuantile,
- image: NDArray[Any] | torch.Tensor | Image,
- quantiles: tuple[float, float] = None,
+ image: np.ndarray | torch.Tensor,
+ quantiles: tuple[float, float],
+ featurewise: bool,
**kwargs: Any,
- ) -> NDArray[Any] | torch.Tensor | Image:
+ ) -> np.ndarray | torch.Tensor:
"""Normalize the input image based on the specified quantiles.
- This method normalizes the input image based on the specified
- quantiles.
-
Parameters
----------
- image: array
+ image: np.ndarray or torch.Tensor
The input image to normalize.
quantiles: tuple[float, float]
Quantile range to calculate scaling factor.
+ featurewise: bool
+ Whether to normalize each feature (channel) independently.
Returns
-------
- array
- The normalized image.
-
+ np.ndarray or torch.Tensor
+ The quantile-normalized image.
"""
- if apc.is_torch_array(image):
- q_tensor = torch.tensor(
- [*quantiles, 0.5],
- device=image.device,
- dtype=image.dtype,
- )
- q_low, q_high, median = torch.quantile(
- image, q_tensor, dim=None, keepdim=False,
- )
- else: # NumPy
- q_low, q_high, median = xp.quantile(image, (*quantiles, 0.5))
-
- return (image - median) / (q_high - q_low) * 2.0
+ q_low_val, q_high_val = quantiles
+ if featurewise:
+ # Per-feature normalization (last axis)
+ axis = tuple(range(image.ndim - 1))
-#TODO ***JH*** revise Blur - torch, typing, docstring, unit test
-class Blur(Feature):
- """Apply a blurring filter to an image.
-
- This class applies a blurring filter to an image. The filter function
- must be a function that takes an input image and returns a blurred
- image.
-
- Parameters
- ----------
- filter_function: Callable
- The blurring function to apply. This function must accept the input
- image as a keyword argument named `input`. If using OpenCV functions
- (e.g., `cv2.GaussianBlur`), use `BlurCV2` instead.
- mode: str
- Border mode for handling boundaries (e.g., 'reflect').
+ if apc.is_torch_array(image):
+ q = torch.tensor(
+ [q_low_val, q_high_val, 0.5],
+ device=image.device,
+ dtype=image.dtype,
+ )
+ q_low, q_high, median = torch.quantile(
+ image, q, dim=axis, keepdim=True
+ )
+ else:
+ q_low, q_high, median = xp.quantile(
+ image, (q_low_val, q_high_val, 0.5),
+ axis=axis,
+ keepdims=True,
+ )
+ else:
+ # Global normalization
+ if apc.is_torch_array(image):
+ q = torch.tensor(
+ [q_low_val, q_high_val, 0.5],
+ device=image.device,
+ dtype=image.dtype,
+ )
+ q_low, q_high, median = torch.quantile(
+ image, q, dim=None, keepdim=False
+ )
+ else:
+ q_low, q_high, median = xp.quantile(
+ image, (q_low_val, q_high_val, 0.5)
+ )
- Methods
- -------
- `get(image: np.ndarray | Image, **kwargs: Any) --> np.ndarray`
- Applies the blurring filter to the input image.
+ image = (image - median) / (q_high - q_low) * 2.0
- Examples
- --------
- >>> import deeptrack as dt
- >>> import numpy as np
- >>> from scipy.ndimage import convolve
+ try:
+ image[xp.isnan(image)] = 0
+ except TypeError:
+ pass
- Create an input image:
- >>> input_image = np.random.rand(32, 32)
+ return image
- Define a Gaussian kernel for blurring:
- >>> gaussian_kernel = np.array([
- ... [1, 4, 6, 4, 1],
- ... [4, 16, 24, 16, 4],
- ... [6, 24, 36, 24, 6],
- ... [4, 16, 24, 16, 4],
- ... [1, 4, 6, 4, 1]
- ... ], dtype=float)
- >>> gaussian_kernel /= np.sum(gaussian_kernel)
- Define a blur function using the Gaussian kernel:
- >>> def gaussian_blur(input, **kwargs):
- ... return convolve(input, gaussian_kernel, mode='reflect')
+#TODO ***CM*** revise typing, docstring, unit test
+class Blur(Feature):
+ """Apply a blurring filter to an image.
- Define a blur feature using the Gaussian blur function:
- >>> blur = dt.Blur(filter_function=gaussian_blur)
- >>> output_image = blur(input_image)
- >>> print(output_image.shape)
- (32, 32)
+ This class acts as a backend-dispatching blur operator. Subclasses must
+ implement backend-specific logic via `_get_numpy` and optionally
+ `_get_torch`.
Notes
-----
- Calling this feature returns a `np.ndarray` by default. If
- `store_properties` is set to `True`, the returned array will be
- automatically wrapped in an `Image` object. This behavior is handled
- internally and does not affect the return type of the `get()` method.
- The filter_function must accept the input image as a keyword argument named
- input. This is required because it is called via utils.safe_call. If you
- are using functions that do not support input=... (such as OpenCV filters
- like cv2.GaussianBlur), consider using BlurCV2 instead.
+ - NumPy execution is always supported.
+ - Torch execution is only supported if `_get_torch` is implemented.
+ - Generic `filter_function`-based blurs are NumPy-only by design.
"""
def __init__(
- self: Blur,
- filter_function: Callable,
+ self,
+ filter_function: Callable | None = None,
mode: PropertyLike[str] = "reflect",
**kwargs: Any,
):
- """Initialize the parameters for blurring input features.
-
- This constructor initializes the parameters for blurring input
- features.
+ """Initialize the blur feature.
Parameters
----------
- filter_function: Callable
- The blurring function to apply.
- mode: str
- Border mode for handling boundaries (e.g., 'reflect').
- **kwargs: Any
- Additional keyword arguments.
-
+ filter_function : Callable or None
+ NumPy-based blurring function. Must accept the input image as a
+ keyword argument named `input`. If `None`, the subclass must
+ implement `_get_numpy`.
+ mode : str
+ Border mode for NumPy-based filters.
+ **kwargs : Any
+ Additional keyword arguments passed to Feature.
"""
-
self.filter = filter_function
- super().__init__(borderType=mode, **kwargs)
+ self.mode = mode
+ super().__init__(**kwargs)
- def get(self: Blur, image: np.ndarray | Image, **kwargs: Any) -> np.ndarray:
- """Applies the blurring filter to the input image.
+ def __call__(
+ self,
+ image: np.ndarray | torch.Tensor,
+ **kwargs: Any,
+ ) -> np.ndarray | torch.Tensor:
+ if isinstance(image, np.ndarray):
+ return self._get_numpy(image, **kwargs)
- This method applies the blurring filter to the input image.
+ if TORCH_AVAILABLE and isinstance(image, torch.Tensor):
+ return self._get_torch(image, **kwargs)
- Parameters
- ----------
- image: np.ndarray
- The input image to blur.
- **kwargs: dict[str, Any]
- Additional keyword arguments.
+ raise TypeError(
+ "Blur only supports numpy.ndarray or torch.Tensor inputs."
+ )
- Returns
- -------
- np.ndarray
- The blurred image.
+ def _get_numpy(
+ self,
+ image: np.ndarray,
+ **kwargs: Any,
+ ) -> np.ndarray:
+ if self.filter is None:
+ raise NotImplementedError(
+ f"{self.__class__.__name__} does not implement a NumPy backend."
+ )
- """
+ # Avoid passing conflicting keywords
+ kwargs = dict(kwargs)
+ kwargs.pop("input", None)
+
+ return utils.safe_call(
+ self.filter,
+ input=image,
+ mode=self.mode,
+ **kwargs,
+ )
+
+ def _get_torch(
+ self,
+ image: torch.Tensor,
+ **kwargs: Any,
+ ) -> torch.Tensor:
+ raise TypeError(
+ f"{self.__class__.__name__} does not support torch.Tensor inputs. "
+ "Use a Torch-enabled blur (e.g. AverageBlur or a V2 blur class)."
+ )
- kwargs.pop("input", False)
- return utils.safe_call(self.filter, input=image, **kwargs)
-#TODO ***JH*** revise AverageBlur - torch, typing, docstring, unit test
+#TODO ***CM*** revise AverageBlur - torch, typing, docstring, unit test
class AverageBlur(Blur):
"""Blur an image by computing simple means over neighbourhoods.
@@ -774,7 +809,7 @@ class AverageBlur(Blur):
Methods
-------
- `get(image: np.ndarray | Image, ksize: int, **kwargs: Any) --> np.ndarray`
+ `get(image: np.ndarray | torch.Tensor, ksize: int, **kwargs: Any) --> np.ndarray | torch.Tensor`
Applies the average blurring filter to the input image.
Examples
@@ -791,13 +826,6 @@ class AverageBlur(Blur):
>>> print(output_image.shape)
(32, 32)
- Notes
- -----
- Calling this feature returns a `np.ndarray` by default. If
- `store_properties` is set to `True`, the returned array will be
- automatically wrapped in an `Image` object. This behavior is handled
- internally and does not affect the return type of the `get()` method.
-
"""
def __init__(
@@ -840,12 +868,10 @@ def _get_numpy(
def _get_torch(
self, input: torch.Tensor, ksize: tuple[int, ...], **kwargs: Any
- ) -> np.ndarray:
- F = xp.nn.functional
+ ) -> torch.Tensor:
last_dim_is_channel = len(ksize) < input.ndim
if last_dim_is_channel:
- # permute to first dim
input = input.movedim(-1, 0)
else:
input = input.unsqueeze(0)
@@ -853,40 +879,26 @@ def _get_torch(
# add batch dimension
input = input.unsqueeze(0)
- # pad input
+ # dynamic padding
+ pad = []
+ for k in reversed(ksize):
+ p = k // 2
+ pad.extend([p, p])
+ pad = tuple(pad)
+
input = F.pad(
input,
- (ksize[0] // 2, ksize[0] // 2, ksize[1] // 2, ksize[1] // 2),
+ pad,
mode=kwargs.get("mode", "reflect"),
value=kwargs.get("cval", 0),
)
+
if input.ndim == 3:
- x = F.avg_pool1d(
- input,
- kernel_size=ksize,
- stride=1,
- padding=0,
- ceil_mode=False,
- count_include_pad=False,
- )
+ x = F.avg_pool1d(input, kernel_size=ksize, stride=1)
elif input.ndim == 4:
- x = F.avg_pool2d(
- input,
- kernel_size=ksize,
- stride=1,
- padding=0,
- ceil_mode=False,
- count_include_pad=False,
- )
+ x = F.avg_pool2d(input, kernel_size=ksize, stride=1)
elif input.ndim == 5:
- x = F.avg_pool3d(
- input,
- kernel_size=ksize,
- stride=1,
- padding=0,
- ceil_mode=False,
- count_include_pad=False,
- )
+ x = F.avg_pool3d(input, kernel_size=ksize, stride=1)
else:
raise NotImplementedError(
f"Input dimension {input.ndim - 2} not supported for torch backend"
@@ -903,10 +915,10 @@ def _get_torch(
def get(
self: AverageBlur,
- input: ArrayLike,
+ input: np.ndarray | torch.Tensor,
ksize: int,
**kwargs: Any,
- ) -> np.ndarray:
+ ) -> np.ndarray | torch.Tensor:
"""Applies the average blurring filter to the input image.
This method applies the average blurring filter to the input image.
@@ -937,7 +949,7 @@ def get(
raise NotImplementedError(f"Backend {self.backend} not supported")
-#TODO ***JH*** revise GaussianBlur - torch, typing, docstring, unit test
+#TODO ***CM*** revise typing, docstring, unit test
class GaussianBlur(Blur):
"""Applies a Gaussian blur to images using Gaussian kernels.
@@ -973,13 +985,6 @@ class GaussianBlur(Blur):
>>> plt.imshow(output_image, cmap='gray')
>>> plt.show()
- Notes
- -----
- Calling this feature returns a `np.ndarray` by default. If
- `store_properties` is set to `True`, the returned array will be
- automatically wrapped in an `Image` object. This behavior is handled
- internally and does not affect the return type of the `get()` method.
-
"""
def __init__(self: GaussianBlur, sigma: PropertyLike[float] = 2, **kwargs: Any):
@@ -996,7 +1001,101 @@ def __init__(self: GaussianBlur, sigma: PropertyLike[float] = 2, **kwargs: Any):
"""
- super().__init__(ndimage.gaussian_filter, sigma=sigma, **kwargs)
+ self.sigma = float(sigma)
+ super().__init__(None, **kwargs)
+
+ def _get_numpy(
+ self,
+ input: np.ndarray,
+ **kwargs: Any,
+ ) -> np.ndarray:
+ return ndimage.gaussian_filter(
+ input,
+ sigma=self.sigma,
+ mode=kwargs.get("mode", "reflect"),
+ cval=kwargs.get("cval", 0),
+ )
+
+ def _gaussian_kernel_1d(
+ self,
+ sigma: float,
+ device,
+ dtype,
+ ) -> torch.Tensor:
+ radius = int(np.ceil(3 * sigma))
+ x = torch.arange(
+ -radius, radius + 1,
+ device=device,
+ dtype=dtype,
+ )
+ kernel = torch.exp(-(x ** 2) / (2 * sigma ** 2))
+ kernel /= kernel.sum()
+ return kernel
+
+ def _get_torch(
+ self,
+ input: torch.Tensor,
+ **kwargs: Any,
+ ) -> torch.Tensor:
+ import torch.nn.functional as F
+
+ sigma = self.sigma
+ kernel_1d = self._gaussian_kernel_1d(
+ sigma,
+ device=input.device,
+ dtype=input.dtype,
+ )
+
+ last_dim_is_channel = input.ndim >= 3
+ if last_dim_is_channel:
+ input = input.movedim(-1, 0) # C, ...
+ else:
+ input = input.unsqueeze(0) # 1, ...
+
+ # add batch dimension
+ input = input.unsqueeze(0) # 1, C, ...
+
+ spatial_dims = input.ndim - 2
+ C = input.shape[1]
+
+ for d in range(spatial_dims):
+ k = kernel_1d
+ shape = [1] * spatial_dims
+ shape[d] = -1
+ k = k.view(1, 1, *shape)
+ k = k.repeat(C, 1, *([1] * spatial_dims))
+
+ pad = [0, 0] * spatial_dims
+ radius = k.shape[2 + d] // 2
+ pad[-(2 * d + 2)] = radius
+ pad[-(2 * d + 1)] = radius
+ pad = tuple(pad)
+
+ input = F.pad(
+ input,
+ pad,
+ mode=kwargs.get("mode", "reflect"),
+ )
+
+ if spatial_dims == 1:
+ input = F.conv1d(input, k, groups=C)
+ elif spatial_dims == 2:
+ input = F.conv2d(input, k, groups=C)
+ elif spatial_dims == 3:
+ input = F.conv3d(input, k, groups=C)
+ else:
+ raise NotImplementedError(
+ f"{spatial_dims}D Gaussian blur not supported"
+ )
+
+ # restore layout
+ input = input.squeeze(0)
+ if last_dim_is_channel:
+ input = input.movedim(0, -1)
+ else:
+ input = input.squeeze(0)
+
+ return input
#TODO ***JH*** revise MedianBlur - torch, typing, docstring, unit test
@@ -1009,6 +1108,9 @@ class MedianBlur(Blur):
useful for reducing noise while preserving edges. It is particularly
effective for removing salt-and-pepper noise from images.
+ - NumPy backend: `scipy.ndimage.median_filter`
+ - Torch backend: explicit unfolding followed by `torch.median`
+
Parameters
----------
ksize: int
@@ -1016,6 +1118,11 @@ class MedianBlur(Blur):
**kwargs: dict
Additional parameters sent to the blurring function.
+ Notes
+ -----
+ Torch median blurring is significantly more expensive than mean or
+ Gaussian blurring due to explicit tensor unfolding.
+
Examples
--------
>>> import deeptrack as dt
@@ -1039,13 +1146,6 @@ class MedianBlur(Blur):
>>> plt.imshow(output_image, cmap='gray')
>>> plt.show()
- Notes
- -----
- Calling this feature returns a `np.ndarray` by default. If
- `store_properties` is set to `True`, the returned array will be
- automatically wrapped in an `Image` object. This behavior is handled
- internally and does not affect the return type of the `get()` method.
-
"""
def __init__(
@@ -1053,667 +1153,367 @@ def __init__(
ksize: PropertyLike[int] = 3,
**kwargs: Any,
):
- """Initialize the parameters for median blurring.
+ self.ksize = int(ksize)
+ super().__init__(None, **kwargs)
+
+ def _get_numpy(
+ self,
+ input: np.ndarray,
+ **kwargs: Any,
+ ) -> np.ndarray:
+ return ndimage.median_filter(
+ input,
+ size=self.ksize,
+ mode=kwargs.get("mode", "reflect"),
+ cval=kwargs.get("cval", 0),
+ )
- This constructor initializes the parameters for median blurring.
+ def _get_torch(
+ self,
+ input: torch.Tensor,
+ **kwargs: Any,
+ ) -> torch.Tensor:
+ import torch.nn.functional as F
- Parameters
- ----------
- ksize: int
- Kernel size.
- **kwargs: Any
- Additional keyword arguments.
+ k = self.ksize
+ if k % 2 == 0:
+ raise ValueError("MedianBlur requires an odd kernel size.")
- """
+ last_dim_is_channel = input.ndim >= 3
+ if last_dim_is_channel:
+ input = input.movedim(-1, 0) # C, ...
+ else:
+ input = input.unsqueeze(0) # 1, ...
- super().__init__(ndimage.median_filter, size=ksize, **kwargs)
+ # add batch dimension
+ input = input.unsqueeze(0) # 1, C, ...
+ spatial_dims = input.ndim - 2
+ pad = k // 2
-#TODO ***AL*** revise Pool - torch, typing, docstring, unit test
-class Pool(Feature):
- """Downsamples the image by applying a function to local regions of the
- image.
+ # padding
+ pad_tuple = []
+ for _ in range(spatial_dims):
+ pad_tuple.extend([pad, pad])
+ pad_tuple = tuple(reversed(pad_tuple))
- This class reduces the resolution of an image by dividing it into
- non-overlapping blocks of size `ksize` and applying the specified pooling
- function to each block. The result is a downsampled image where each pixel
- value represents the result of the pooling function applied to the
- corresponding block.
+ input = F.pad(
+ input,
+ pad_tuple,
+ mode=kwargs.get("mode", "reflect"),
+ )
- Parameters
- ----------
- pooling_function: function
- A function that is applied to each local region of the image.
- DOES NOT NEED TO BE WRAPPED IN ANOTHER FUNCTION.
- The `pooling_function` must accept the input image as a keyword argument
- named `input`, as it is called via `utils.safe_call`.
- Examples include `np.mean`, `np.max`, `np.min`, etc.
- ksize: int
- Size of the pooling kernel.
- **kwargs: Any
- Additional parameters sent to the pooling function.
+ # unfold spatial dimensions
+ if spatial_dims == 1:
+ x = input.unfold(2, k, 1)
+ elif spatial_dims == 2:
+ x = (
+ input
+ .unfold(2, k, 1)
+ .unfold(3, k, 1)
+ )
+ elif spatial_dims == 3:
+ x = (
+ input
+ .unfold(2, k, 1)
+ .unfold(3, k, 1)
+ .unfold(4, k, 1)
+ )
+ else:
+ raise NotImplementedError(
+ f"{spatial_dims}D median blur not supported"
+ )
- Methods
- -------
- `get(image: np.ndarray | Image, ksize: int, **kwargs: Any) --> np.ndarray`
- Applies the pooling function to the input image.
+ # flatten neighborhood and take median
+ x = x.contiguous().view(*x.shape[:-spatial_dims], -1)
+ x = x.median(dim=-1).values
- Examples
- --------
- >>> import deeptrack as dt
- >>> import numpy as np
+ # restore layout
+ x = x.squeeze(0)
+ if last_dim_is_channel:
+ x = x.movedim(0, -1)
+ else:
+ x = x.squeeze(0)
- Create an input image:
- >>> input_image = np.random.rand(32, 32)
+ return x
- Define a pooling feature:
- >>> pooling_feature = dt.Pool(pooling_function=np.mean, ksize=4)
- >>> output_image = pooling_feature.get(input_image, ksize=4)
- >>> print(output_image.shape)
- (8, 8)
+#TODO ***CM*** revise typing, docstring, unit test
+class Pool:
+ """
+ DeepTrack v2 replacement for Pool.
- Notes
- -----
- Calling this feature returns a `np.ndarray` by default. If
- `store_properties` is set to `True`, the returned array will be
- automatically wrapped in an `Image` object. This behavior is handled
- internally and does not affect the return type of the `get()` method.
- The filter_function must accept the input image as a keyword argument named
- input. This is required because it is called via utils.safe_call. If you
- are using functions that do not support input=... (such as OpenCV filters
- like cv2.GaussianBlur), consider using BlurCV2 instead.
+ Generic, center-preserving block pooling with NumPy and Torch backends.
+ Public API matches v1: a single integer ksize.
+ Pool size semantics:
+ - 2D input -> (ksize, ksize, 1)
+ - 3D input -> (ksize, ksize, ksize)
"""
+ _TORCH_REDUCERS_2D: Dict[Callable, Callable] = {
+ np.mean: lambda x, k, s: F.avg_pool2d(x, k, s),
+ np.sum: lambda x, k, s: F.avg_pool2d(x, k, s) * (k[0] * k[1]),
+ np.max: lambda x, k, s: F.max_pool2d(x, k, s),
+ np.min: lambda x, k, s: -F.max_pool2d(-x, k, s),
+ }
+
+ _TORCH_REDUCERS_3D: Dict[Callable, Callable] = {
+ np.mean: lambda x, k, s: F.avg_pool3d(x, k, s),
+ np.sum: lambda x, k, s: F.avg_pool3d(x, k, s) * (k[0] * k[1] * k[2]),
+ np.max: lambda x, k, s: F.max_pool3d(x, k, s),
+ np.min: lambda x, k, s: -F.max_pool3d(-x, k, s),
+ }
+
def __init__(
- self: Pool,
+ self,
pooling_function: Callable,
- ksize: PropertyLike[int] = 3,
- **kwargs: Any,
+ ksize: int = 2,
):
- """Initialize the parameters for pooling input features.
+ if pooling_function not in (
+ np.mean, np.sum, np.min, np.max, np.median
+ ):
+ raise ValueError(
+ "Unsupported pooling_function. "
+ "Use one of: np.mean, np.sum, np.min, np.max, np.median."
+ )
- This constructor initializes the parameters for pooling input
- features.
+ if not isinstance(ksize, int) or ksize < 1:
+ raise ValueError("ksize must be a positive integer.")
- Parameters
- ----------
- pooling_function: Callable
- The pooling function to apply.
- ksize: int
- Size of the pooling kernel.
- **kwargs: Any
- Additional keyword arguments.
+ self.pooling_function = pooling_function
+ self.ksize = int(ksize)
+ def _get_pool_size(self, array) -> Tuple[int, int, int]:
"""
+ Determine pooling kernel size based on semantic dimensionality.
- self.pooling = pooling_function
- super().__init__(ksize=ksize, **kwargs)
+ - 2D images: (Nx, Ny) or (Nx, Ny, C) -> pool in x,y only
+ - 3D volumes: (Nx, Ny, Nz) or (Nx, Ny, Nz, C) -> pool in x,y,z
+ - Never pool over channels
+ """
+ k = self.ksize
- def get(
- self: Pool,
- image: np.ndarray | Image,
- ksize: int,
- **kwargs: Any,
- ) -> np.ndarray:
- """Applies the pooling function to the input image.
+ # 2D image
+ if array.ndim == 2:
+ return k, k, 1
- This method applies the pooling function to the input image.
+ # 3D array: could be (x, y, z) or (x, y, c)
+ if array.ndim == 3:
+ # Heuristic: small last dim → channels
+ if array.shape[-1] <= 4:
+ return k, k, 1
+ return k, k, k
- Parameters
- ----------
- image: np.ndarray
- The input image to pool.
- ksize: int
- Size of the pooling kernel.
- **kwargs: dict[str, Any]
- Additional keyword arguments.
-
- Returns
- -------
- np.ndarray
- The pooled image.
-
- """
+ # 4D array: (x, y, z, c)
+ if array.ndim == 4:
+ return k, k, k
- kwargs.pop("func", False)
- kwargs.pop("image", False)
- kwargs.pop("block_size", False)
- return utils.safe_call(
- skimage.measure.block_reduce,
- image=image,
- func=self.pooling,
- block_size=ksize,
- **kwargs,
+ raise ValueError(
+ f"Unsupported array shape {array.shape} for pooling."
)
+ def _crop_center(self, array):
+ px, py, pz = self._get_pool_size(array)
+
+ # 2D (or effectively 2D)
+ if array.ndim < 3 or pz == 1:
+ H, W = array.shape[:2]
+ crop_h = (H // px) * px
+ crop_w = (W // py) * py
+ off_h = (H - crop_h) // 2
+ off_w = (W - crop_w) // 2
+ return array[
+ off_h : off_h + crop_h,
+ off_w : off_w + crop_w,
+ ...
+ ]
+
+ # 3D
+ Z, H, W = array.shape[:3]
+ crop_z = (Z // pz) * pz
+ crop_h = (H // px) * px
+ crop_w = (W // py) * py
+ off_z = (Z - crop_z) // 2
+ off_h = (H - crop_h) // 2
+ off_w = (W - crop_w) // 2
+ return array[
+ off_z : off_z + crop_z,
+ off_h : off_h + crop_h,
+ off_w : off_w + crop_w,
+ ...
+ ]
+
+ def _pool_numpy(self, array: np.ndarray) -> np.ndarray:
+ array = self._crop_center(array)
+ px, py, pz = self._get_pool_size(array)
+
+ if array.ndim < 3 or pz == 1:
+ pool_shape = (px, py) + (1,) * (array.ndim - 2)
+ else:
+ pool_shape = (pz, px, py) + (1,) * (array.ndim - 3)
-#TODO ***AL*** revise AveragePooling - torch, typing, docstring, unit test
-class AveragePooling(Pool):
- """Apply average pooling to an image.
-
- This class reduces the resolution of an image by dividing it into
- non-overlapping blocks of size `ksize` and applying the average function to
- each block. The result is a downsampled image where each pixel value
- represents the average value within the corresponding block of the
- original image.
-
- Parameters
- ----------
- ksize: int
- Size of the pooling kernel.
- **kwargs: dict
- Additional parameters sent to the pooling function.
-
- Examples
- --------
- >>> import deeptrack as dt
- >>> import numpy as np
-
- Create an input image:
- >>> input_image = np.random.rand(32, 32)
-
- Define an average pooling feature:
- >>> average_pooling = dt.AveragePooling(ksize=4)
- >>> output_image = average_pooling(input_image)
- >>> print(output_image.shape)
- (8, 8)
-
- Notes
- -----
- Calling this feature returns a `np.ndarray` by default. If
- `store_properties` is set to `True`, the returned array will be
- automatically wrapped in an `Image` object. This behavior is handled
- internally and does not affect the return type of the `get()` method.
-
- """
-
- def __init__(
- self: Pool,
- ksize: PropertyLike[int] = 3,
- **kwargs: Any,
- ):
- """Initialize the parameters for average pooling.
-
- This constructor initializes the parameters for average pooling.
-
- Parameters
- ----------
- ksize: int
- Size of the pooling kernel.
- **kwargs: Any
- Additional keyword arguments.
-
- """
-
- super().__init__(np.mean, ksize=ksize, **kwargs)
-
-
-class MaxPooling(Pool):
- """Apply max-pooling to images.
-
- `MaxPooling` reduces the resolution of an image by dividing it into
- non-overlapping blocks of size `ksize` and applying the `max` function
- to each block. The result is a downsampled image where each pixel value
- represents the maximum value within the corresponding block of the
- original image. This is useful for reducing the size of an image while
- retaining the most significant features.
-
- If the backend is NumPy, the downsampling is performed using
- `skimage.measure.block_reduce`.
-
- If the backend is PyTorch, the downsampling is performed using
- `torch.nn.functional.max_pool2d`.
-
- Parameters
- ----------
- ksize: int
- Size of the pooling kernel.
- **kwargs: Any
- Additional parameters sent to the pooling function.
-
- Examples
- --------
- >>> import deeptrack as dt
-
- Create an input image:
- >>> import numpy as np
- >>>
- >>> input_image = np.random.rand(32, 32)
-
- Define and use a max-pooling feature:
-
- >>> max_pooling = dt.MaxPooling(ksize=8)
- >>> output_image = max_pooling(input_image)
- >>> output_image.shape
- (4, 4)
-
- """
-
- def __init__(
- self: MaxPooling,
- ksize: PropertyLike[int] = 3,
- **kwargs: Any,
- ):
- """Initialize the parameters for max-pooling.
-
- This constructor initializes the parameters for max-pooling.
-
- Parameters
- ----------
- ksize: int
- Size of the pooling kernel.
- **kwargs: Any
- Additional keyword arguments.
-
- """
-
- super().__init__(np.max, ksize=ksize, **kwargs)
-
- def get(
- self: MaxPooling,
- image: NDArray[Any] | torch.Tensor,
- ksize: int=3,
- **kwargs: Any,
- ) -> NDArray[Any] | torch.Tensor:
- """Max-pooling of input.
-
- Checks the current backend and chooses the appropriate function to pool
- the input image, either `._get_torch()` or `._get_numpy()`.
-
- Parameters
- ----------
- image: array or tensor
- Input array or tensor be pooled.
- ksize: int
- Kernel size of the pooling operation.
-
- Returns
- -------
- array or tensor
- The pooled input as `NDArray` or `torch.Tensor` depending on
- the backend.
-
- """
-
- if self.get_backend() == "numpy":
- return self._get_numpy(image, ksize, **kwargs)
-
- if self.get_backend() == "torch":
- return self._get_torch(image, ksize, **kwargs)
-
- raise NotImplementedError(f"Backend {self.backend} not supported")
-
- def _get_numpy(
- self: MaxPooling,
- image: NDArray[Any],
- ksize: int=3,
- **kwargs: Any,
- ) -> NDArray[Any]:
- """Max-pooling pooling with the NumPy backend enabled.
-
- Returns the result of the input array passed to the scikit image
- `block_reduce()` function with `np.max()` as the pooling function.
-
- Parameters
- ----------
- image: array
- Input array to be pooled.
- ksize: int
- Kernel size of the pooling operation.
-
- Returns
- -------
- array
- The pooled image as a NumPy array.
-
- """
-
- return utils.safe_call(
- skimage.measure.block_reduce,
- image=image,
- func=np.max,
- block_size=ksize,
- **kwargs,
+ return skimage.measure.block_reduce(
+ array,
+ block_size=pool_shape,
+ func=self.pooling_function,
)
- def _get_torch(
- self: MaxPooling,
- image: torch.Tensor,
- ksize: int=3,
- **kwargs: Any,
- ) -> torch.Tensor:
- """Max-pooling with the PyTorch backend enabled.
-
-
- Returns the result of the tensor passed to a PyTorch max
- pooling layer.
-
- Parameters
- ----------
- image: torch.Tensor
- Input tensor to be pooled.
- ksize: int
- Kernel size of the pooling operation.
-
- Returns
- -------
- torch.Tensor
- The pooled image as a `torch.Tensor`.
-
- """
+ def _pool_torch(self, array: torch.Tensor) -> torch.Tensor:
+ array = self._crop_center(array)
+ px, py, pz = self._get_pool_size(array)
- # If input tensor is 2D
- if len(image.shape) == 2:
- # Add batch dimension for max-pooling
- expanded_image = image.unsqueeze(0)
+ is_3d = array.ndim >= 3 and pz > 1
- pooled_image = torch.nn.functional.max_pool2d(
- expanded_image, kernel_size=ksize,
+ if not is_3d:
+ extra = array.shape[2:]
+ C = int(np.prod(extra)) if extra else 1
+ x = array.reshape(1, C, array.shape[0], array.shape[1])
+ kernel = (px, py)
+ stride = (px, py)
+ reducers = self._TORCH_REDUCERS_2D
+ else:
+ extra = array.shape[3:]
+ C = int(np.prod(extra)) if extra else 1
+ x = array.reshape(
+ 1, C, array.shape[0], array.shape[1], array.shape[2]
)
- # Remove the expanded dim
- return pooled_image.squeeze(0)
-
- return torch.nn.functional.max_pool2d(
- image,
- kernel_size=ksize,
- )
-
-
-class MinPooling(Pool):
- """Apply min-pooling to images.
-
- `MinPooling` reduces the resolution of an image by dividing it into
- non-overlapping blocks of size `ksize` and applying the `min` function to
- each block. The result is a downsampled image where each pixel value
- represents the minimum value within the corresponding block of the original
- image.
-
- If the backend is NumPy, the downsampling is performed using
- `skimage.measure.block_reduce`.
-
- If the backend is PyTorch, the downsampling is performed using the inverse
- of `torch.nn.functional.max_pool2d` by changing the sign of the input.
-
- Parameters
- ----------
- ksize: int
- Size of the pooling kernel.
- **kwargs: Any
- Additional parameters sent to the pooling function.
-
- Examples
- --------
- >>> import deeptrack as dt
-
- Create an input image:
- >>> import numpy as np
- >>>
- >>> input_image = np.random.rand(32, 32)
-
- Define and use a min-pooling feature:
- >>> min_pooling = dt.MinPooling(ksize=4)
- >>> output_image = min_pooling(input_image)
- >>> output_image.shape
- (8, 8)
-
- """
-
- def __init__(
- self: MinPooling,
- ksize: PropertyLike[int] = 3,
- **kwargs: Any,
- ):
- """Initialize the parameters for min-pooling.
-
- This constructor initializes the parameters for min-pooling and checks
- whether to use the NumPy or PyTorch implementation, defaults to NumPy.
-
- Parameters
- ----------
- ksize: int
- Size of the pooling kernel.
- **kwargs: Any
- Additional keyword arguments.
-
- """
-
- super().__init__(np.min, ksize=ksize, **kwargs)
-
- def get(
- self: MinPooling,
- image: NDArray[Any] | torch.Tensor,
- ksize: int=3,
- **kwargs: Any,
- ) -> NDArray[Any] | torch.Tensor:
- """Min pooling of input.
-
- Checks the current backend and chooses the appropriate function to pool
- the input image, either `._get_torch()` or `._get_numpy()`.
-
- Parameters
- ----------
- image: array or tensor
- Input array or tensor to be pooled.
- ksize: int
- Kernel size of the pooling operation.
-
- Returns
- -------
- array or tensor
- The pooled image as `NDArray` or `torch.Tensor` depending on the
- backend.
-
- """
-
- if self.get_backend() == "numpy":
- return self._get_numpy(image, ksize, **kwargs)
-
- if self.get_backend() == "torch":
- return self._get_torch(image, ksize, **kwargs)
-
- raise NotImplementedError(f"Backend {self.backend} not supported")
+ kernel = (pz, px, py)
+ stride = (pz, px, py)
+ reducers = self._TORCH_REDUCERS_3D
+
+ # Median: explicit unfolding
+ if self.pooling_function is np.median:
+ if is_3d:
+ x_u = (
+ x.unfold(2, pz, pz)
+ .unfold(3, px, px)
+ .unfold(4, py, py)
+ )
+ x_u = x_u.contiguous().view(
+ 1, C,
+ x_u.shape[2],
+ x_u.shape[3],
+ x_u.shape[4],
+ -1,
+ )
+ pooled = x_u.median(dim=-1).values
+ else:
+ x_u = x.unfold(2, px, px).unfold(3, py, py)
+ x_u = x_u.contiguous().view(
+ 1, C,
+ x_u.shape[2],
+ x_u.shape[3],
+ -1,
+ )
+ pooled = x_u.median(dim=-1).values
+ else:
+ reducer = reducers[self.pooling_function]
+ pooled = reducer(x, kernel, stride)
- def _get_numpy(
- self: MinPooling,
- image: NDArray[Any],
- ksize: int=3,
- **kwargs: Any,
- ) -> NDArray[Any]:
- """Min-pooling with the NumPy backend.
+ return pooled.reshape(pooled.shape[2:] + extra)
- Returns the result of the input array passed to the scikit
- `image block_reduce()` function with `np.min()` as the pooling
- function.
+ def __call__(self, array):
+ if isinstance(array, np.ndarray):
+ return self._pool_numpy(array)
- Parameters
- ----------
- image: NDArray
- Input image to be pooled.
- ksize: int
- Kernel size of the pooling operation.
+ if TORCH_AVAILABLE and isinstance(array, torch.Tensor):
+ return self._pool_torch(array)
- Returns
- -------
- NDArray
- The pooled image as a `NDArray`.
-
- """
-
- return utils.safe_call(
- skimage.measure.block_reduce,
- image=image,
- func=np.min,
- block_size=ksize,
- **kwargs,
+ raise TypeError(
+ "Pool only supports np.ndarray or torch.Tensor inputs."
)
- def _get_torch(
- self: MinPooling,
- image: torch.Tensor,
- ksize: int=3,
- **kwargs: Any,
- ) -> torch.Tensor:
- """Min-pooling with the PyTorch backend.
-
- As PyTorch does not have a min-pooling layer, the equivalent operation
- is to first multiply the input tensor with `-1`, then perform
- max-pooling, and finally multiply the max pooled tensor with `-1`.
- Parameters
- ----------
- image: torch.Tensor
- Input tensor to be pooled.
- ksize: int
- Kernel size of the pooling operation.
+class AveragePooling(Pool):
+ def __init__(self, ksize: int = 2):
+ super().__init__(np.mean, ksize)
- Returns
- -------
- torch.Tensor
- The pooled image as a `torch.Tensor`.
- """
+class SumPooling(Pool):
+ def __init__(self, ksize: int = 2):
+ super().__init__(np.sum, ksize)
- # If input tensor is 2D
- if len(image.shape) == 2:
- # Add batch dimension for min-pooling
- expanded_image = image.unsqueeze(0)
- pooled_image = - torch.nn.functional.max_pool2d(
- expanded_image * (-1),
- kernel_size=ksize,
- )
+class MinPooling(Pool):
+ def __init__(self, ksize: int = 2):
+ super().__init__(np.min, ksize)
- # Remove the expanded dim
- return pooled_image.squeeze(0)
- return -torch.nn.functional.max_pool2d(
- image * (-1),
- kernel_size=ksize,
- )
+class MaxPooling(Pool):
+ def __init__(self, ksize: int = 2):
+ super().__init__(np.max, ksize)
-#TODO ***AL*** revise MedianPooling - torch, typing, docstring, unit test
class MedianPooling(Pool):
- """Apply median pooling to images.
-
- This class reduces the resolution of an image by dividing it into
- non-overlapping blocks of size `ksize` and applying the median function to
- each block. The result is a downsampled image where each pixel value
- represents the median value within the corresponding block of the
- original image. This is useful for reducing the size of an image while
- retaining the most significant features.
-
- Parameters
- ----------
- ksize: int
- Size of the pooling kernel.
- **kwargs: Any
- Additional parameters sent to the pooling function.
+ def __init__(self, ksize: int = 2):
+ super().__init__(np.median, ksize)
- Examples
- --------
- >>> import deeptrack as dt
- >>> import numpy as np
-
- Create an input image:
- >>> input_image = np.random.rand(32, 32)
-
- Define a median pooling feature:
- >>> median_pooling = dt.MedianPooling(ksize=3)
- >>> output_image = median_pooling(input_image)
- >>> print(output_image.shape)
- (32, 32)
-
- Visualize the input and output images:
- >>> plt.figure(figsize=(8, 4))
- >>> plt.subplot(1, 2, 1)
- >>> plt.imshow(input_image, cmap='gray')
- >>> plt.subplot(1, 2, 2)
- >>> plt.imshow(output_image, cmap='gray')
- >>> plt.show()
-
- Notes
- -----
- Calling this feature returns a `np.ndarray` by default. If
- `store_properties` is set to `True`, the returned array will be
- automatically wrapped in an `Image` object. This behavior is handled
- internally and does not affect the return type of the `get()` method.
-
- """
-
- def __init__(
- self: MedianPooling,
- ksize: PropertyLike[int] = 3,
- **kwargs: Any,
- ):
- """Initialize the parameters for median pooling.
-
- This constructor initializes the parameters for median pooling.
-
- Parameters
- ----------
- ksize: int
- Size of the pooling kernel.
- **kwargs: Any
- Additional keyword arguments.
-
- """
-
- super().__init__(np.median, ksize=ksize, **kwargs)
class Resize(Feature):
"""Resize an image to a specified size.
- `Resize` resizes an image using:
- - OpenCV (`cv2.resize`) for NumPy arrays.
- - PyTorch (`torch.nn.functional.interpolate`) for PyTorch tensors.
+ `Resize` resizes images following the channels-last semantic
+ convention.
+
+ The operation supports both NumPy arrays and PyTorch tensors:
+ - NumPy arrays are resized using OpenCV (`cv2.resize`).
+ - PyTorch tensors are resized using `torch.nn.functional.interpolate`.
- The interpretation of the `dsize` parameter follows the convention
- of the underlying backend:
- - **NumPy (OpenCV)**: `dsize` is given as `(width, height)` to match
- OpenCV’s default.
- - **PyTorch**: `dsize` is given as `(height, width)`.
+ In all cases, the input is interpreted as having spatial dimensions
+ first and an optional channel dimension last.
Parameters
----------
- dsize: PropertyLike[tuple[int, int]]
- The target size. Format depends on backend: `(width, height)` for
- NumPy, `(height, width)` for PyTorch.
- **kwargs: Any
- Additional parameters sent to the underlying resize function:
- - NumPy: passed to `cv2.resize`.
- - PyTorch: passed to `torch.nn.functional.interpolate`.
+ dsize : PropertyLike[tuple[int, int]]
+ Target output size given as (width, height). This convention is
+ backend-independent and applies equally to NumPy and PyTorch inputs.
+
+ **kwargs : Any
+ Additional keyword arguments forwarded to the underlying resize
+ implementation:
+ - NumPy backend: passed to `cv2.resize`.
+ - PyTorch backend: passed to
+ `torch.nn.functional.interpolate`.
Methods
-------
get(
- image: np.ndarray | torch.Tensor, dsize: tuple[int, int], **kwargs
+ image: np.ndarray | torch.Tensor,
+ dsize: tuple[int, int],
+ **kwargs
) -> np.ndarray | torch.Tensor
Resize the input image to the specified size.
Examples
--------
- >>> import deeptrack as dt
+ NumPy example:
- Numpy example:
>>> import numpy as np
- >>>
- >>> input_image = np.random.rand(16, 16) # Create image
- >>> feature = dt.math.Resize(dsize=(8, 4)) # (width=8, height=4)
- >>> resized_image = feature.resolve(input_image) # Resize it to (4, 8)
- >>> print(resized_image.shape)
+ >>> input_image = np.random.rand(16, 16)
+ >>> feature = dt.math.Resize(dsize=(8, 4)) # (width=8, height=4)
+ >>> resized_image = feature.resolve(input_image)
+ >>> resized_image.shape
(4, 8)
PyTorch example:
+
>>> import torch
- >>>
- >>> input_image = torch.rand(1, 1, 16, 16) # Create image
- >>> feature = dt.math.Resize(dsize=(4, 8)) # (height=4, width=8)
- >>> resized_image = feature.resolve(input_image) # Resize it to (4, 8)
- >>> print(resized_image.shape)
- torch.Size([1, 1, 4, 8])
+ >>> input_image = torch.rand(16, 16) # channels-last
+ >>> feature = dt.math.Resize(dsize=(8, 4))
+ >>> resized_image = feature.resolve(input_image)
+ >>> resized_image.shape
+ torch.Size([4, 8])
+
+ Notes
+ -----
+ - Resize follows channels-last semantics, consistent with other features
+ such as Pool and Blur.
+ - Torch tensors with channels-first layout (e.g. (C, H, W) or
+ (N, C, H, W)) are not supported and must be converted to
+ channels-last format before resizing.
+ - For PyTorch tensors, bilinear interpolation is used with
+ `align_corners=False`, closely matching OpenCV’s default behavior.
"""
@@ -1738,78 +1538,117 @@ def __init__(
def get(
self: Resize,
- image: NDArray | torch.Tensor,
+ image: np.ndarray | torch.Tensor,
dsize: tuple[int, int],
**kwargs: Any,
- ) -> NDArray | torch.Tensor:
+ ) -> np.ndarray | torch.Tensor:
"""Resize the input image to the specified size.
Parameters
----------
- image: np.ndarray or torch.Tensor
- The input image to resize.
- - NumPy arrays may be grayscale (H, W) or color (H, W, C).
- - Torch tensors are expected in one of the following formats:
- (N, C, H, W), (C, H, W), or (H, W).
- dsize: tuple[int, int]
- Desired output size of the image.
- - NumPy: (width, height)
- - PyTorch: (height, width)
- **kwargs: Any
- Additional keyword arguments passed to the underlying resize
- function (`cv2.resize` or `torch.nn.functional.interpolate`).
+ image : np.ndarray or torch.Tensor
+ Input image following channels-last semantics.
+
+ Supported shapes are:
+ - (H, W)
+ - (H, W, C)
+ - (Z, H, W)
+ - (Z, H, W, C)
+
+ For PyTorch tensors, channels-first layouts such as (C, H, W) or
+ (N, C, H, W) are not supported and must be converted to
+ channels-last format before calling `Resize`.
+
+ dsize : tuple[int, int]
+ Desired output size given as (width, height). This convention is
+ backend-independent and applies to both NumPy and PyTorch inputs.
+
+ **kwargs : Any
+ Additional keyword arguments passed to the underlying resize
+ implementation:
+ - NumPy backend: forwarded to `cv2.resize`.
+ - PyTorch backend: forwarded to `torch.nn.functional.interpolate`.
Returns
-------
np.ndarray or torch.Tensor
- The resized image in the same type and dimensionality format as
- input.
+ The resized image, with the same type and dimensionality layout as
+ the input image.
Notes
-----
+ - Resize follows the same channels-last semantic convention as other
+ features in `deeptrack.math`.
- For PyTorch tensors, resizing uses bilinear interpolation with
- `align_corners=False`. This choice matches OpenCV’s `cv2.resize`
- default behavior when resizing NumPy arrays, aiming to produce nearly
- identical results between both backends.
+ `align_corners=False`, which closely matches OpenCV’s default behavior.
"""
- if self._wrap_array_with_image:
- image = strip(image)
+ target_w, target_h = dsize
+ # Torch backend
if apc.is_torch_array(image):
- original_shape = image.shape
-
- # Reshape input to (N, C, H, W)
- if image.ndim == 2: # (H, W)
- image = image.unsqueeze(0).unsqueeze(0)
- elif image.ndim == 3: # (C, H, W)
- image = image.unsqueeze(0)
- elif image.ndim != 4:
+ import torch.nn.functional as F
+
+ original_ndim = image.ndim
+ has_channels = (
+ image.ndim >= 3 and image.shape[-1] <= 4
+ )
+
+ # Bring to (N, C, H, W)
+ if image.ndim == 2:
+ # (H, W) -> (1, 1, H, W)
+ x = image.unsqueeze(0).unsqueeze(0)
+
+ elif image.ndim == 3 and has_channels:
+ # (H, W, C) -> (1, C, H, W)
+ x = image.permute(2, 0, 1).unsqueeze(0)
+
+ elif image.ndim == 3:
+ # (Z, H, W) -> treat Z as batch
+ x = image.unsqueeze(1)
+
+ elif image.ndim == 4 and has_channels:
+ # (Z, H, W, C) -> (Z, C, H, W)
+ x = image.permute(0, 3, 1, 2)
+
+ else:
raise ValueError(
- "Resize only supports tensors with shape (N, C, H, W), "
- "(C, H, W), or (H, W)."
+ f"Unsupported tensor shape {image.shape} for Resize."
)
- resized = torch.nn.functional.interpolate(
- image,
- size=dsize,
+ # Resize spatial dimensions
+ resized = F.interpolate(
+ x,
+ size=(target_h, target_w),
mode="bilinear",
align_corners=False,
)
- # Restore original dimensionality
- if len(original_shape) == 2:
- resized = resized.squeeze(0).squeeze(0)
- elif len(original_shape) == 3:
- resized = resized.squeeze(0)
+ # Restore original layout
+ if original_ndim == 2:
+ return resized.squeeze(0).squeeze(0)
+
+ if original_ndim == 3 and has_channels:
+ return resized.squeeze(0).permute(1, 2, 0)
+
+ if original_ndim == 3:
+ return resized.squeeze(1)
- return resized
+ if original_ndim == 4:
+ return resized.permute(0, 2, 3, 1)
+ raise RuntimeError("Unexpected shape restoration path.")
+
+ # NumPy / OpenCV backend
else:
import cv2
+
+ # OpenCV expects (width, height)
return utils.safe_call(
- cv2.resize, positional_args=[image, dsize], **kwargs
+ cv2.resize,
+ positional_args=[image, (target_w, target_h)],
+ **kwargs,
)
@@ -1865,10 +1704,9 @@ class BlurCV2(Feature):
Notes
-----
- Calling this feature returns a `np.ndarray` by default. If
- `store_properties` is set to `True`, the returned array will be
- automatically wrapped in an `Image` object. This behavior is handled
- internally and does not affect the return type of the `get()` method.
+ BlurCV2 is NumPy-only and does not support PyTorch tensors.
+ This class is intended for OpenCV-specific filters that are
+ not available in the backend-agnostic math layer.
"""
@@ -1964,6 +1802,12 @@ def get(
"""
+ if apc.is_torch_array(image):
+ raise TypeError(
+ "BlurCV2 only supports NumPy arrays. "
+ "For Torch tensors, use Blur or GaussianBlur instead."
+ )
+
kwargs.pop("name", None)
result = self.filter(src=image, **kwargs)
return result
@@ -2015,10 +1859,7 @@ class BilateralBlur(BlurCV2):
Notes
-----
- Calling this feature returns a `np.ndarray` by default. If
- `store_properties` is set to `True`, the returned array will be
- automatically wrapped in an `Image` object. This behavior is handled
- internally and does not affect the return type of the `get()` method.
+ BilateralBlur is NumPy-only and does not support PyTorch tensors.
"""
@@ -2059,3 +1900,73 @@ def __init__(
sigmaSpace=sigma_space,
**kwargs,
)
+
+
+def isotropic_dilation(
+ mask,
+ radius: float,
+ *,
+ backend: str,
+ device=None,
+ dtype=None,
+):
+ if radius <= 0:
+ return mask
+
+ if backend == "numpy":
+ from skimage.morphology import isotropic_dilation
+ return isotropic_dilation(mask, radius)
+
+ # torch backend
+ import torch
+
+ r = int(np.ceil(radius))
+ kernel = torch.ones(
+ (1, 1, 2 * r + 1, 2 * r + 1, 2 * r + 1),
+ device=device or mask.device,
+ dtype=dtype or torch.float32,
+ )
+
+ x = mask.to(dtype=kernel.dtype)[None, None]
+ y = torch.nn.functional.conv3d(
+ x,
+ kernel,
+ padding=r,
+ )
+
+ return (y[0, 0] > 0)
+
+
+def isotropic_erosion(
+ mask,
+ radius: float,
+ *,
+ backend: str,
+ device=None,
+ dtype=None,
+):
+ if radius <= 0:
+ return mask
+
+ if backend == "numpy":
+ from skimage.morphology import isotropic_erosion
+ return isotropic_erosion(mask, radius)
+
+ import torch
+
+ r = int(np.ceil(radius))
+ kernel = torch.ones(
+ (1, 1, 2 * r + 1, 2 * r + 1, 2 * r + 1),
+ device=device or mask.device,
+ dtype=dtype or torch.float32,
+ )
+
+ x = mask.to(dtype=kernel.dtype)[None, None]
+ y = torch.nn.functional.conv3d(
+ x,
+ kernel,
+ padding=r,
+ )
+
+ required = kernel.numel()
+ return (y[0, 0] >= required)
\ No newline at end of file
diff --git a/deeptrack/optics.py b/deeptrack/optics.py
index 5149bdae..a7394689 100644
--- a/deeptrack/optics.py
+++ b/deeptrack/optics.py
@@ -137,11 +137,13 @@ def _pad_volume(
from __future__ import annotations
from pint import Quantity
-from typing import Any
+from typing import Any, TYPE_CHECKING
import warnings
import numpy as np
-from scipy.ndimage import convolve
+from scipy.ndimage import convolve # might be removed later
+import torch
+import torch.nn.functional as F
from deeptrack.backend.units import (
ConversionTable,
@@ -149,23 +151,37 @@ def _pad_volume(
get_active_scale,
get_active_voxel_size,
)
-from deeptrack.math import AveragePooling
+from deeptrack.math import AveragePooling, SumPooling
from deeptrack.features import propagate_data_to_dependencies
from deeptrack.features import DummyFeature, Feature, StructuralFeature
-from deeptrack.image import Image, pad_image_to_fft
+from deeptrack.image import pad_image_to_fft
from deeptrack.types import ArrayLike, PropertyLike
from deeptrack import image
from deeptrack import units_registry as u
+from deeptrack import TORCH_AVAILABLE, image
+from deeptrack.backend import xp
+from deeptrack.scatterers import ScatteredVolume, ScatteredField
+
+if TORCH_AVAILABLE:
+ import torch
+
+if TYPE_CHECKING:
+ import torch
+
#TODO ***??*** revise Microscope - torch, typing, docstring, unit test
class Microscope(StructuralFeature):
"""Simulates imaging of a sample using an optical system.
- This class combines a feature-set that defines the sample to be imaged with
- a feature-set defining the optical system, enabling the simulation of
- optical imaging processes.
+ This class combines the sample to be imaged with the optical system,
+ enabling the simulation of optical imaging processes.
+ A Microscope:
+ - validates the semantic compatibility between scatterers and optics
+ - interprets volume-based scatterers into scalar fields when needed
+ - delegates numerical propagation to the objective (Optics)
+ - performs detector downscaling according to its physical semantics
Parameters
----------
@@ -186,10 +202,16 @@ class Microscope(StructuralFeature):
Methods
-------
- `get(image: Image or None, **kwargs: Any) -> Image`
+ `get(image: np.ndarray or None, **kwargs: Any) -> np.ndarray`
Simulates the imaging process using the defined optical system and
returns the resulting image.
+ Notes
+ -----
+ All volume scatterers imaged by a Microscope instance are assumed to
+ share the same contrast mechanism (e.g. refractive index or fluorescence).
+ Mixing contrast types is not supported.
+
Examples
--------
Simulating an image using a brightfield optical system:
@@ -238,13 +260,41 @@ def __init__(
self._sample = self.add_feature(sample)
self._objective = self.add_feature(objective)
- self._sample.store_properties()
+
+ def _validate_input(self, scattered):
+ if hasattr(self._objective, "validate_input"):
+ self._objective.validate_input(scattered)
+
+ def _extract_contrast_volume(self, scattered):
+ if hasattr(self._objective, "extract_contrast_volume"):
+ return self._objective.extract_contrast_volume(
+ scattered,
+ **self._objective.properties(),
+ )
+ return scattered.array
+
+ def _downscale_image(self, image, upscale):
+ if hasattr(self._objective, "downscale_image"):
+ return self._objective.downscale_image(image, upscale)
+
+ if not np.any(np.array(upscale) != 1):
+ return image
+
+ ux, uy = upscale[:2]
+ if ux != uy:
+ raise ValueError(
+ f"Energy-conserving detector integration requires ux == uy, "
+ f"got ux={ux}, uy={uy}."
+ )
+ if isinstance(ux, float) and ux.is_integer():
+ ux = int(ux)
+ return AveragePooling(ux)(image)
def get(
self: Microscope,
- image: Image | None,
+ image: np.ndarray | torch.Tensor | None = None,
**kwargs: Any,
- ) -> Image:
+ ) -> np.ndarray | torch.Tensor:
"""Generate an image of the sample using the defined optical system.
This method processes the sample through the optical system to
@@ -252,14 +302,14 @@ def get(
Parameters
----------
- image: Image | None
+ image: np.ndarray | torch.Tensor | None
The input image to be processed. If None, a new image is created.
**kwargs: Any
Additional parameters for the imaging process.
Returns
-------
- Image: Image
+ image: np.ndarray | torch.Tensor
The processed image after applying the optical system.
Examples
@@ -280,9 +330,6 @@ def get(
# Grab properties from the objective to pass to the sample
additional_sample_kwargs = self._objective.properties()
- # Calculate required output image for the given upscale
- # This way of providing the upscale will be deprecated in the future
- # in favor of dt.Upscale().
_upscale_given_by_optics = additional_sample_kwargs["upscale"]
if np.array(_upscale_given_by_optics).size == 1:
_upscale_given_by_optics = (_upscale_given_by_optics,) * 3
@@ -325,67 +372,61 @@ def get(
if not isinstance(list_of_scatterers, list):
list_of_scatterers = [list_of_scatterers]
+ # Semantic validation (per scatterer)
+ for scattered in list_of_scatterers:
+ self._validate_input(scattered)
+
# All scatterers that are defined as volumes.
volume_samples = [
scatterer
for scatterer in list_of_scatterers
- if not scatterer.get_property("is_field", default=False)
+ if isinstance(scatterer, ScatteredVolume)
]
# All scatterers that are defined as fields.
field_samples = [
scatterer
for scatterer in list_of_scatterers
- if scatterer.get_property("is_field", default=False)
+ if isinstance(scatterer, ScatteredField)
]
-
+
# Merge all volumes into a single volume.
sample_volume, limits = _create_volume(
volume_samples,
**additional_sample_kwargs,
)
- sample_volume = Image(sample_volume)
- # Merge all properties into the volume.
- for scatterer in volume_samples + field_samples:
- sample_volume.merge_properties_from(scatterer)
+ print('prop', volume_samples[0].properties)
+
+ # Interpret the merged volume semantically
+ sample_volume = self._extract_contrast_volume(
+ ScatteredVolume(
+ array=sample_volume,
+ properties=volume_samples[0].properties,
+ ),
+ )
# Let the objective know about the limits of the volume and all the fields.
propagate_data_to_dependencies(
self._objective,
limits=limits,
- fields=field_samples,
+ fields=field_samples, # should We add upscale?
)
imaged_sample = self._objective.resolve(sample_volume)
- # Upscale given by the optics needs to be handled separately.
- if _upscale_given_by_optics != (1, 1, 1):
- imaged_sample = AveragePooling((*_upscale_given_by_optics[:2], 1))(
- imaged_sample
- )
-
- # Merge with input
- if not image:
- if not self._wrap_array_with_image and isinstance(imaged_sample, Image):
- return imaged_sample._value
- else:
- return imaged_sample
-
- if not isinstance(image, list):
- image = [image]
- for i in range(len(image)):
- image[i].merge_properties_from(imaged_sample)
- return image
-
- # def _no_wrap_format_input(self, *args, **kwargs) -> list:
- # return self._image_wrapped_format_input(*args, **kwargs)
-
- # def _no_wrap_process_and_get(self, *args, **feature_input) -> list:
- # return self._image_wrapped_process_and_get(*args, **feature_input)
+ imaged_sample = self._downscale_image(imaged_sample, upscale)
+ # # Handling upscale from dt.Upscale() here to eliminate Image
+ # # wrapping issues.
+ # if np.any(np.array(upscale) != 1):
+ # ux, uy = upscale[:2]
+ # if contrast_type == "intensity":
+ # print("Using sum pooling for intensity downscaling.")
+ # imaged_sample = SumPoolingCM((ux, uy, 1))(imaged_sample)
+ # else:
+ # imaged_sample = AveragePoolingCM((ux, uy, 1))(imaged_sample)
- # def _no_wrap_process_output(self, *args, **feature_input):
- # return self._image_wrapped_process_output(*args, **feature_input)
+ return imaged_sample
#TODO ***??*** revise Optics - torch, typing, docstring, unit test
@@ -569,6 +610,15 @@ def __init__(
"""
+ def validate_scattered(self, scattered):
+ pass
+
+ def extract_contrast_volume(self, scattered):
+ pass
+
+ def downscale_image(self, image, upscale):
+ pass
+
def get_voxel_size(
resolution: float | ArrayLike[float],
magnification: float,
@@ -673,6 +723,7 @@ def _process_properties(
wavelength = propertydict["wavelength"]
voxel_size = get_active_voxel_size()
radius = NA / wavelength * np.array(voxel_size)
+ print('Pupil radius (in pixels):', radius)
if np.any(radius[:2] > 0.5):
required_upscale = np.max(np.ceil(radius[:2] * 2))
@@ -757,19 +808,18 @@ def _pupil(
W, H = np.meshgrid(y, x)
RHO = (W ** 2 + H ** 2).astype(complex)
- pupil_function = Image((RHO < 1) + 0.0j, copy=False)
+ pupil_function = (RHO < 1) + 0.0j
# Defocus
- z_shift = Image(
+ z_shift = (
2
* np.pi
* refractive_index_medium
/ wavelength
* voxel_size[2]
- * np.sqrt(1 - (NA / refractive_index_medium) ** 2 * RHO),
- copy=False,
+ * np.sqrt(1 - (NA / refractive_index_medium) ** 2 * RHO)
)
- z_shift._value[z_shift._value.imag != 0] = 0
+ z_shift[z_shift.imag != 0] = 0
try:
z_shift = np.nan_to_num(z_shift, False, 0, 0, 0)
@@ -1007,7 +1057,7 @@ class Fluorescence(Optics):
Methods
-------
- `get(illuminated_volume: array_like[complex], limits: array_like[int, int], **kwargs: Any) -> Image`
+ `get(illuminated_volume: array_like[complex], limits: array_like[int, int], **kwargs: Any) -> np.ndarray`
Simulates the imaging process using a fluorescence microscope.
Examples
@@ -1024,12 +1074,71 @@ class Fluorescence(Optics):
"""
+
+ def validate_input(self, scattered):
+ """Semantic validation for fluorescence microscopy."""
+
+ # Fluorescence cannot operate on coherent fields
+ if isinstance(scattered, ScatteredField):
+ raise TypeError(
+ "Fluorescence microscope cannot operate on ScatteredField."
+ )
+
+
+ def extract_contrast_volume(self, scattered: ScatteredVolume) -> np.ndarray:
+ voxel_size = np.asarray(get_active_voxel_size(), float)
+ voxel_volume = np.prod(voxel_size)
+
+ intensity = scattered.get_property("intensity", None)
+ value = scattered.get_property("value", None)
+ ri = scattered.get_property("refractive_index", None)
+
+ # Refractive index is always ignored in fluorescence
+ if ri is not None:
+ warnings.warn(
+ "Scatterer defines 'refractive_index', which is ignored in "
+ "fluorescence microscopy.",
+ UserWarning,
+ )
+
+ # Preferred, physically meaningful case
+ if intensity is not None:
+ return intensity * voxel_volume * scattered.array
+
+ # Fallback: legacy / dimensionless brightness
+ warnings.warn(
+ "Fluorescence scatterer has no 'intensity'. Interpreting 'value' as a "
+ "non-physical brightness factor. Quantitative interpretation is invalid. "
+ "Define 'intensity' to model physical fluorescence emission.",
+ UserWarning,
+ )
+
+ return value * scattered.array
+
+ def downscale_image(self, image: np.ndarray, upscale):
+ """Detector downscaling (energy conserving)"""
+ if not np.any(np.array(upscale) != 1):
+ return image
+
+ ux, uy = upscale[:2]
+ if ux != uy:
+ raise ValueError(
+ f"Energy-conserving detector integration requires ux == uy, "
+ f"got ux={ux}, uy={uy}."
+ )
+ if isinstance(ux, float) and ux.is_integer():
+ ux = int(ux)
+
+ # Energy-conserving detector integration
+ return SumPooling(ux)(image)
+
+
def get(
self: Fluorescence,
illuminated_volume: ArrayLike[complex],
limits: ArrayLike[int],
**kwargs: Any,
- ) -> Image:
+ ) -> ArrayLike[complex]:
"""Simulates the imaging process using a fluorescence microscope.
This method convolves the 3D illuminated volume with a pupil function
@@ -1048,7 +1157,7 @@ def get(
Returns
-------
- Image: Image
+ image: np.ndarray
A 2D image object representing the fluorescence projection.
Notes
@@ -1066,7 +1175,7 @@ def get(
>>> optics = dt.Fluorescence(
... NA=1.4, wavelength=0.52e-6, magnification=60,
... )
- >>> volume = dt.Image(np.ones((128, 128, 10), dtype=complex))
+ >>> volume = np.ones((128, 128, 10), dtype=complex)
>>> limits = np.array([[0, 128], [0, 128], [0, 10]])
>>> properties = optics.properties()
>>> filtered_properties = {
@@ -1118,9 +1227,7 @@ def get(
]
z_limits = limits[2, :]
- output_image = Image(
- np.zeros((*padded_volume.shape[0:2], 1)), copy=False
- )
+ output_image = np.zeros((*padded_volume.shape[0:2], 1))
index_iterator = range(padded_volume.shape[2])
@@ -1156,12 +1263,12 @@ def get(
field = np.fft.ifft2(convolved_fourier_field)
# # Discard remaining imaginary part (should be 0 up to rounding error)
field = np.real(field)
- output_image._value[:, :, 0] += field[
+ output_image[:, :, 0] += field[
: padded_volume.shape[0], : padded_volume.shape[1]
]
output_image = output_image[pad[0] : -pad[2], pad[1] : -pad[3]]
- output_image.properties = illuminated_volume.properties + pupils.properties
+ # output_image.properties = illuminated_volume.properties + pupils.properties
return output_image
@@ -1234,7 +1341,7 @@ class Brightfield(Optics):
-------
`get(illuminated_volume: array_like[complex],
limits: array_like[int, int], fields: array_like[complex],
- **kwargs: Any) -> Image`
+ **kwargs: Any) -> np.ndarray`
Simulates imaging with brightfield microscopy.
@@ -1250,9 +1357,52 @@ class Brightfield(Optics):
"""
+
__conversion_table__ = ConversionTable(
- working_distance=(u.meter, u.meter),
- )
+ working_distance=(u.meter, u.meter),
+)
+
+ def validate_input(self, scattered):
+ """Semantic validation for brightfield microscopy."""
+
+ if isinstance(scattered, ScatteredVolume):
+ warnings.warn(
+ "Brightfield imaging from ScatteredVolume assumes a "
+ "weak-phase / projection approximation. "
+ "Use ScatteredField for physically accurate brightfield simulations.",
+ UserWarning,
+ )
+
+ def extract_contrast_volume(
+ self,
+ scattered: ScatteredVolume,
+ refractive_index_medium: float,
+ **kwargs: Any,
+ ) -> np.ndarray:
+ print('ri_medium', refractive_index_medium)
+
+ ri = scattered.get_property("refractive_index", None)
+ value = scattered.get_property("value", None)
+ intensity = scattered.get_property("intensity", None)
+
+ if intensity is not None:
+ warnings.warn(
+ "Scatterer defines 'intensity', which is ignored in "
+ "brightfield microscopy.",
+ UserWarning,
+ )
+
+ if ri is not None:
+ return (ri - refractive_index_medium) * scattered.array
+
+ warnings.warn(
+ "No 'refractive_index' specified; using 'value' as a non-physical "
+ "brightfield contrast. Results are not physically calibrated. "
+ "Define 'refractive_index' for physically meaningful contrast.",
+ UserWarning,
+ )
+
+ return value * scattered.array
def get(
self: Brightfield,
@@ -1260,7 +1410,7 @@ def get(
limits: ArrayLike[int],
fields: ArrayLike[complex],
**kwargs: Any,
- ) -> Image:
+ ) -> np.ndarray:
"""Simulates imaging with brightfield microscopy.
This method propagates light through the given volume, applying
@@ -1285,7 +1435,7 @@ def get(
Returns
-------
- Image: Image
+ image: np.ndarray
Processed image after simulating the brightfield imaging process.
Examples
@@ -1300,7 +1450,7 @@ def get(
... wavelength=0.52e-6,
... magnification=60,
... )
- >>> volume = dt.Image(np.ones((128, 128, 10), dtype=complex))
+ >>> volume = np.ones((128, 128, 10), dtype=complex)
>>> limits = np.array([[0, 128], [0, 128], [0, 10]])
>>> fields = np.array([np.ones((162, 162), dtype=complex)])
>>> properties = optics.properties()
@@ -1345,7 +1495,7 @@ def get(
if output_region[3] is None
else int(output_region[3] - limits[1, 0] + pad[3])
)
-
+
padded_volume = padded_volume[
output_region[0] : output_region[2],
output_region[1] : output_region[3],
@@ -1353,9 +1503,7 @@ def get(
]
z_limits = limits[2, :]
- output_image = Image(
- np.zeros((*padded_volume.shape[0:2], 1))
- )
+ output_image = np.zeros((*padded_volume.shape[0:2], 1))
index_iterator = range(padded_volume.shape[2])
z_iterator = np.linspace(
@@ -1414,7 +1562,25 @@ def get(
light_in_focus = light_in * shifted_pupil
if len(fields) > 0:
- field = np.sum(fields, axis=0)
+ # field = np.sum(fields, axis=0)
+ field_arrays = []
+
+ for fs in fields:
+ # fs is a ScatteredField
+ arr = fs.array
+
+ # Enforce (H, W, 1) shape
+ if arr.ndim == 2:
+ arr = arr[..., None]
+
+ if arr.ndim != 3 or arr.shape[-1] != 1:
+ raise ValueError(
+ f"Expected field of shape (H, W, 1), got {arr.shape}"
+ )
+
+ field_arrays.append(arr)
+
+ field = np.sum(field_arrays, axis=0)
light_in_focus += field[..., 0]
shifted_pupil = np.fft.fftshift(pupils[-1])
light_in_focus = light_in_focus * shifted_pupil
@@ -1426,7 +1592,7 @@ def get(
: padded_volume.shape[0], : padded_volume.shape[1]
]
output_image = np.expand_dims(output_image, axis=-1)
- output_image = Image(output_image[pad[0] : -pad[2], pad[1] : -pad[3]])
+ output_image = output_image[pad[0] : -pad[2], pad[1] : -pad[3]]
if not kwargs.get("return_field", False):
output_image = np.square(np.abs(output_image))
@@ -1436,7 +1602,7 @@ def get(
# output_image = output_image * np.exp(1j * -np.pi / 4)
# output_image = output_image + 1
- output_image.properties = illuminated_volume.properties
+ # output_image.properties = illuminated_volume.properties
return output_image
@@ -1624,6 +1790,73 @@ def __init__(
illumination_angle=illumination_angle,
**kwargs)
+ def validate_input(self, scattered):
+ if isinstance(scattered, ScatteredVolume):
+ warnings.warn(
+ "Darkfield imaging from ScatteredVolume is a very rough "
+ "approximation. Use ScatteredField for physically meaningful "
+ "darkfield simulations.",
+ UserWarning,
+ )
+
+ def extract_contrast_volume(
+ self,
+ scattered: ScatteredVolume,
+ refractive_index_medium: float,
+ **kwargs: Any,
+ ) -> np.ndarray:
+ """
+ Approximate darkfield contrast from a volume (toy model).
+
+ This is a non-physical approximation intended for qualitative simulations.
+ """
+
+ ri = scattered.get_property("refractive_index", None)
+ value = scattered.get_property("value", None)
+ intensity = scattered.get_property("intensity", None)
+
+ # Intensity has no meaning here
+ if intensity is not None:
+ warnings.warn(
+ "Scatterer defines 'intensity', which is ignored in "
+ "darkfield microscopy.",
+ UserWarning,
+ )
+
+ if ri is not None:
+ delta_n = ri - refractive_index_medium
+ warnings.warn(
+ "Approximating darkfield contrast from refractive index. "
+ "Result is non-physical and qualitative only.",
+ UserWarning,
+ )
+ return (delta_n ** 2) * scattered.array
+
+ warnings.warn(
+ "No 'refractive_index' specified; using 'value' as a non-physical "
+ "darkfield scattering strength. Results are qualitative only.",
+ UserWarning,
+ )
+
+ return (value ** 2) * scattered.array
+
+ def downscale_image(self, image: np.ndarray, upscale):
+ """Detector downscaling (energy conserving)"""
+ if not np.any(np.array(upscale) != 1):
+ return image
+
+ ux, uy = upscale[:2]
+ if ux != uy:
+ raise ValueError(
+ f"Energy-conserving detector integration requires ux == uy, "
+ f"got ux={ux}, uy={uy}."
+ )
+ if isinstance(ux, float) and ux.is_integer():
+ ux = int(ux)
+
+ # Energy-conserving detector integration
+ return SumPooling(ux)(image)
+
#Retrieve get as super
def get(
self: Darkfield,
@@ -1631,7 +1864,7 @@ def get(
limits: ArrayLike[int],
fields: ArrayLike[complex],
**kwargs: Any,
- ) -> Image:
+ ) -> np.ndarray:
"""Retrieve the darkfield image of the illuminated volume.
Parameters
@@ -1800,9 +2033,1001 @@ def get(
return image
+class NonOverlapping(Feature):
+ """Ensure volumes are placed non-overlapping in a 3D space.
+
+ This feature ensures that a list of 3D volumes are positioned such that
+ their non-zero voxels do not overlap. If volumes overlap, their positions
+ are resampled until they are non-overlapping. If the maximum number of
+ attempts is exceeded, the feature regenerates the list of volumes and
+ raises a warning if non-overlapping placement cannot be achieved.
+
+ Note: `min_distance` refers to the distance between the edges of volumes,
+ not their centers. Due to the way volumes are calculated, slight rounding
+ errors may affect the final distance.
+
+ This feature is incompatible with non-volumetric scatterers such as
+ `MieScatterers`.
+
+ Parameters
+ ----------
+ feature: Feature
+ The feature that generates the list of volumes to place
+ non-overlapping.
+ min_distance: float, optional
+ The minimum distance between volumes in pixels. It can be negative to
+ allow for partial overlap. Defaults to 1.
+ max_attempts: int, optional
+ The maximum number of attempts to place volumes without overlap.
+ Defaults to 5.
+ max_iters: int, optional
+ The maximum number of resamplings. If this number is exceeded, a new
+ list of volumes is generated. Defaults to 100.
+
+ Attributes
+ ----------
+ __distributed__: bool
+ Always `False` for `NonOverlapping`, indicating that this feature’s
+ `.get()` method processes the entire input at once even if it is a
+ list, rather than distributing calls for each item of the list.N
+
+ Methods
+ -------
+ `get(*_, min_distance, max_attempts, **kwargs) -> array`
+ Generate a list of non-overlapping 3D volumes.
+ `_check_non_overlapping(list_of_volumes) -> bool`
+ Check if all volumes in the list are non-overlapping.
+ `_check_bounding_cubes_non_overlapping(...) -> bool`
+ Check if two bounding cubes are non-overlapping.
+ `_get_overlapping_cube(...) -> list[int]`
+ Get the overlapping cube between two bounding cubes.
+ `_get_overlapping_volume(...) -> array`
+ Get the overlapping volume between a volume and a bounding cube.
+ `_check_volumes_non_overlapping(...) -> bool`
+ Check if two volumes are non-overlapping.
+ `_resample_volume_position(volume) -> Image`
+ Resample the position of a volume to avoid overlap.
+
+ Notes
+ -----
+ - This feature performs bounding cube checks first to quickly reject
+ obvious overlaps before voxel-level checks.
+ - If the bounding cubes overlap, precise voxel-based checks are performed.
+
+ Examples
+ ---------
+ >>> import deeptrack as dt
+
+ Define an ellipse scatterer with randomly positioned objects:
+
+ >>> import numpy as np
+ >>>
+ >>> scatterer = dt.Ellipse(
+ >>> radius= 13 * dt.units.pixels,
+ >>> position=lambda: np.random.uniform(5, 115, size=2)* dt.units.pixels,
+ >>> )
+
+ Create multiple scatterers:
+
+ >>> scatterers = (scatterer ^ 8)
+
+ Define the optics and create the image with possible overlap:
+
+ >>> optics = dt.Fluorescence()
+ >>> im_with_overlap = optics(scatterers)
+ >>> im_with_overlap.store_properties()
+ >>> im_with_overlap_resolved = image_with_overlap()
+
+ Gather position from image:
+
+ >>> pos_with_overlap = np.array(
+ >>> im_with_overlap_resolved.get_property(
+ >>> "position",
+ >>> get_one=False
+ >>> )
+ >>> )
+
+ Enforce non-overlapping and create the image without overlap:
+
+ >>> non_overlapping_scatterers = dt.NonOverlapping(
+ ... scatterers,
+ ... min_distance=4,
+ ... )
+ >>> im_without_overlap = optics(non_overlapping_scatterers)
+ >>> im_without_overlap.store_properties()
+ >>> im_without_overlap_resolved = im_without_overlap()
+
+ Gather position from image:
+
+ >>> pos_without_overlap = np.array(
+ >>> im_without_overlap_resolved.get_property(
+ >>> "position",
+ >>> get_one=False
+ >>> )
+ >>> )
+
+ Create a figure with two subplots to visualize the difference:
+
+ >>> import matplotlib.pyplot as plt
+ >>>
+ >>> fig, axes = plt.subplots(1, 2, figsize=(10, 5))
+ >>>
+ >>> axes[0].imshow(im_with_overlap_resolved, cmap="gray")
+ >>> axes[0].scatter(pos_with_overlap[:,1],pos_with_overlap[:,0])
+ >>> axes[0].set_title("Overlapping Objects")
+ >>> axes[0].axis("off")
+ >>>
+ >>> axes[1].imshow(im_without_overlap_resolved, cmap="gray")
+ >>> axes[1].scatter(pos_without_overlap[:,1],pos_without_overlap[:,0])
+ >>> axes[1].set_title("Non-Overlapping Objects")
+ >>> axes[1].axis("off")
+ >>> plt.tight_layout()
+ >>>
+ >>> plt.show()
+
+ Define function to calculate minimum distance:
+
+ >>> def calculate_min_distance(positions):
+ >>> distances = [
+ >>> np.linalg.norm(positions[i] - positions[j])
+ >>> for i in range(len(positions))
+ >>> for j in range(i + 1, len(positions))
+ >>> ]
+ >>> return min(distances)
+
+ Print minimum distances with and without overlap:
+
+ >>> print(calculate_min_distance(pos_with_overlap))
+ 10.768742383382174
+
+ >>> print(calculate_min_distance(pos_without_overlap))
+ 30.82531120942446
+
+ """
+
+ __distributed__: bool = False
+
+ def __init__(
+ self: NonOverlapping,
+ feature: Feature,
+ min_distance: float = 1,
+ max_attempts: int = 5,
+ max_iters: int = 100,
+ **kwargs: Any,
+ ):
+ """Initializes the NonOverlapping feature.
+
+ Ensures that volumes are placed **non-overlapping** by iteratively
+ resampling their positions. If the maximum number of attempts is
+ exceeded, the feature regenerates the list of volumes.
+
+ Parameters
+ ----------
+ feature: Feature
+ The feature that generates the list of volumes.
+ min_distance: float, optional
+ The minimum separation distance **between volume edges**, in
+ pixels. It defaults to `1`. Negative values allow for partial
+ overlap.
+ max_attempts: int, optional
+ The maximum number of attempts to place the volumes without
+ overlap. It defaults to `5`.
+ max_iters: int, optional
+ The maximum number of resampling iterations per attempt. If
+ exceeded, a new list of volumes is generated. It defaults to `100`.
+
+ """
+
+ super().__init__(
+ min_distance=min_distance,
+ max_attempts=max_attempts,
+ max_iters=max_iters,
+ **kwargs,
+ )
+ self.feature = self.add_feature(feature, **kwargs)
+
+ def get(
+ self: NonOverlapping,
+ *_: Any,
+ min_distance: float,
+ max_attempts: int,
+ max_iters: int,
+ **kwargs: Any,
+ ) -> list[np.ndarray]:
+ """Generates a list of non-overlapping 3D volumes within a defined
+ field of view (FOV).
+
+ This method **iteratively** attempts to place volumes while ensuring
+ they maintain at least `min_distance` separation. If non-overlapping
+ placement is not achieved within `max_attempts`, a warning is issued,
+ and the best available configuration is returned.
+
+ Parameters
+ ----------
+ _: Any
+ Placeholder parameter, typically for an input image.
+ min_distance: float
+ The minimum required separation distance between volumes, in
+ pixels.
+ max_attempts: int
+ The maximum number of attempts to generate a valid non-overlapping
+ configuration.
+ max_iters: int
+ The maximum number of resampling iterations per attempt.
+ **kwargs: Any
+ Additional parameters that may be used by subclasses.
+
+ Returns
+ -------
+ list[np.ndarray]
+ A list of 3D volumes represented as NumPy arrays. If
+ non-overlapping placement is unsuccessful, the best available
+ configuration is returned.
+
+ Warns
+ -----
+ UserWarning
+ If non-overlapping placement is **not** achieved within
+ `max_attempts`, suggesting parameter adjustments such as increasing
+ the FOV or reducing `min_distance`.
+
+ Notes
+ -----
+ - The placement process prioritizes bounding cube checks for
+ efficiency.
+ - If bounding cubes overlap, voxel-based overlap checks are performed.
+
+ """
+
+ for _ in range(max_attempts):
+ list_of_volumes = self.feature()
+
+ if not isinstance(list_of_volumes, list):
+ list_of_volumes = [list_of_volumes]
+
+ for _ in range(max_iters):
+
+ list_of_volumes = [
+ self._resample_volume_position(volume)
+ for volume in list_of_volumes
+ ]
+
+ if self._check_non_overlapping(list_of_volumes):
+ return list_of_volumes
+
+ # Generate a new list of volumes if max_attempts is exceeded.
+ self.feature.update()
+
+ warnings.warn(
+ "Non-overlapping placement could not be achieved. Consider "
+ "adjusting parameters: reduce object radius, increase FOV, "
+ "or decrease min_distance.",
+ UserWarning,
+ )
+ return list_of_volumes
+
+ def _check_non_overlapping(
+ self: NonOverlapping,
+ list_of_volumes: list[np.ndarray],
+ ) -> bool:
+ """Determines whether all volumes in the provided list are
+ non-overlapping.
+
+ This method verifies that the non-zero voxels of each 3D volume in
+ `list_of_volumes` are at least `min_distance` apart. It first checks
+ bounding boxes for early rejection and then examines actual voxel
+ overlap when necessary. Volumes are assumed to have a `position`
+ attribute indicating their placement in 3D space.
+
+ Parameters
+ ----------
+ list_of_volumes: list[np.ndarray]
+ A list of 3D arrays representing the volumes to be checked for
+ overlap. Each volume is expected to have a position attribute.
+
+ Returns
+ -------
+ bool
+ `True` if all volumes are non-overlapping, otherwise `False`.
+
+ Notes
+ -----
+ - If `min_distance` is negative, volumes are shrunk using isotropic
+ erosion before checking overlap.
+ - If `min_distance` is positive, volumes are padded and expanded using
+ isotropic dilation.
+ - Overlapping checks are first performed on bounding cubes for
+ efficiency.
+ - If bounding cubes overlap, voxel-level checks are performed.
+
+ """
+ from deeptrack.scatterers import ScatteredVolume
+
+ from deeptrack.augmentations import CropTight, Pad # these are not compatibles with torch backend
+ from deeptrack.optics import _get_position
+ from deeptrack.math import isotropic_erosion, isotropic_dilation
+
+ min_distance = self.min_distance()
+ crop = CropTight()
+
+ new_volumes = []
+
+ for volume in list_of_volumes:
+ arr = volume.array
+ mask = arr != 0
+
+ if min_distance < 0:
+ new_arr = isotropic_erosion(mask, -min_distance / 2, backend=self.get_backend())
+ else:
+ pad = Pad(px=[int(np.ceil(min_distance / 2))] * 6, keep_size=True)
+ new_arr = isotropic_dilation(pad(mask) != 0 , min_distance / 2, backend=self.get_backend())
+ new_arr = crop(new_arr)
+
+ if self.get_backend() == "torch":
+ new_arr = new_arr.to(dtype=arr.dtype)
+ else:
+ new_arr = new_arr.astype(arr.dtype)
+
+ new_volume = ScatteredVolume(
+ array=new_arr,
+ properties=volume.properties.copy(),
+ )
+
+ new_volumes.append(new_volume)
+
+ list_of_volumes = new_volumes
+ min_distance = 1
+
+ # The position of the top left corner of each volume (index (0, 0, 0)).
+ volume_positions_1 = [
+ _get_position(volume, mode="corner", return_z=True).astype(int)
+ for volume in list_of_volumes
+ ]
+
+ # The position of the bottom right corner of each volume
+ # (index (-1, -1, -1)).
+ volume_positions_2 = [
+ p0 + np.array(v.shape)
+ for v, p0 in zip(list_of_volumes, volume_positions_1)
+ ]
+
+ # (x1, y1, z1, x2, y2, z2) for each volume.
+ volume_bounding_cube = [
+ [*p0, *p1]
+ for p0, p1 in zip(volume_positions_1, volume_positions_2)
+ ]
+
+ for i, j in itertools.combinations(range(len(list_of_volumes)), 2):
+
+ # If the bounding cubes do not overlap, the volumes do not overlap.
+ if self._check_bounding_cubes_non_overlapping(
+ volume_bounding_cube[i], volume_bounding_cube[j], min_distance
+ ):
+ continue
+
+ # If the bounding cubes overlap, get the overlapping region of each
+ # volume.
+ overlapping_cube = self._get_overlapping_cube(
+ volume_bounding_cube[i], volume_bounding_cube[j]
+ )
+ overlapping_volume_1 = self._get_overlapping_volume(
+ list_of_volumes[i].array, volume_bounding_cube[i], overlapping_cube
+ )
+ overlapping_volume_2 = self._get_overlapping_volume(
+ list_of_volumes[j].array, volume_bounding_cube[j], overlapping_cube
+ )
+
+ # If either the overlapping regions are empty, the volumes do not
+ # overlap (done for speed).
+ if (np.all(overlapping_volume_1 == 0)
+ or np.all(overlapping_volume_2 == 0)):
+ continue
+
+ # If products of overlapping regions are non-zero, return False.
+ # if np.any(overlapping_volume_1 * overlapping_volume_2):
+ # return False
+
+ # Finally, check that the non-zero voxels of the volumes are at
+ # least min_distance apart.
+ if not self._check_volumes_non_overlapping(
+ overlapping_volume_1, overlapping_volume_2, min_distance
+ ):
+ return False
+
+ return True
+
+ def _check_bounding_cubes_non_overlapping(
+ self: NonOverlapping,
+ bounding_cube_1: list[int],
+ bounding_cube_2: list[int],
+ min_distance: float,
+ ) -> bool:
+ """Determines whether two 3D bounding cubes are non-overlapping.
+
+ This method checks whether the bounding cubes of two volumes are
+ **separated by at least** `min_distance` along **any** spatial axis.
+
+ Parameters
+ ----------
+ bounding_cube_1: list[int]
+ A list of six integers `[x1, y1, z1, x2, y2, z2]` representing
+ the first bounding cube.
+ bounding_cube_2: list[int]
+ A list of six integers `[x1, y1, z1, x2, y2, z2]` representing
+ the second bounding cube.
+ min_distance: float
+ The required **minimum separation distance** between the two
+ bounding cubes.
+
+ Returns
+ -------
+ bool
+ `True` if the bounding cubes are non-overlapping (separated by at
+ least `min_distance` along **at least one axis**), otherwise
+ `False`.
+
+ Notes
+ -----
+ - This function **only checks bounding cubes**, **not actual voxel
+ data**.
+ - If the bounding cubes are non-overlapping, the corresponding
+ **volumes are also non-overlapping**.
+ - This check is much **faster** than full voxel-based comparisons.
+
+ """
+
+ # bounding_cube_1 and bounding_cube_2 are (x1, y1, z1, x2, y2, z2).
+ # Check that the bounding cubes are non-overlapping.
+ return (
+ (bounding_cube_1[0] >= bounding_cube_2[3] + min_distance) or
+ (bounding_cube_2[0] >= bounding_cube_1[3] + min_distance) or
+ (bounding_cube_1[1] >= bounding_cube_2[4] + min_distance) or
+ (bounding_cube_2[1] >= bounding_cube_1[4] + min_distance) or
+ (bounding_cube_1[2] >= bounding_cube_2[5] + min_distance) or
+ (bounding_cube_2[2] >= bounding_cube_1[5] + min_distance)
+ )
+
+ def _get_overlapping_cube(
+ self: NonOverlapping,
+ bounding_cube_1: list[int],
+ bounding_cube_2: list[int],
+ ) -> list[int]:
+ """Computes the overlapping region between two 3D bounding cubes.
+
+ This method calculates the coordinates of the intersection of two
+ axis-aligned bounding cubes, each represented as a list of six
+ integers:
+
+ - `[x1, y1, z1]`: Coordinates of the **top-left-front** corner.
+ - `[x2, y2, z2]`: Coordinates of the **bottom-right-back** corner.
+
+ The resulting overlapping region is determined by:
+ - Taking the **maximum** of the starting coordinates (`x1, y1, z1`).
+ - Taking the **minimum** of the ending coordinates (`x2, y2, z2`).
+
+ If the cubes **do not** overlap, the resulting coordinates will not
+ form a valid cube (i.e., `x1 > x2`, `y1 > y2`, or `z1 > z2`).
+
+ Parameters
+ ----------
+ bounding_cube_1: list[int]
+ The first bounding cube, formatted as `[x1, y1, z1, x2, y2, z2]`.
+ bounding_cube_2: list[int]
+ The second bounding cube, formatted as `[x1, y1, z1, x2, y2, z2]`.
+
+ Returns
+ -------
+ list[int]
+ A list of six integers `[x1, y1, z1, x2, y2, z2]` representing the
+ overlapping bounding cube. If no overlap exists, the coordinates
+ will **not** define a valid cube.
+
+ Notes
+ -----
+ - This function does **not** check for valid input or ensure the
+ resulting cube is well-formed.
+ - If no overlap exists, downstream functions must handle the invalid
+ result.
+
+ """
+
+ return [
+ max(bounding_cube_1[0], bounding_cube_2[0]),
+ max(bounding_cube_1[1], bounding_cube_2[1]),
+ max(bounding_cube_1[2], bounding_cube_2[2]),
+ min(bounding_cube_1[3], bounding_cube_2[3]),
+ min(bounding_cube_1[4], bounding_cube_2[4]),
+ min(bounding_cube_1[5], bounding_cube_2[5]),
+ ]
+
+ def _get_overlapping_volume(
+ self: NonOverlapping,
+ volume: np.ndarray, # 3D array.
+ bounding_cube: tuple[float, float, float, float, float, float],
+ overlapping_cube: tuple[float, float, float, float, float, float],
+ ) -> np.ndarray:
+ """Extracts the overlapping region of a 3D volume within the specified
+ overlapping cube.
+
+ This method identifies and returns the subregion of `volume` that
+ lies within the `overlapping_cube`. The bounding information of the
+ volume is provided via `bounding_cube`.
+
+ Parameters
+ ----------
+ volume: np.ndarray
+ A 3D NumPy array representing the volume from which the
+ overlapping region is extracted.
+ bounding_cube: tuple[float, float, float, float, float, float]
+ The bounding cube of the volume, given as a tuple of six floats:
+ `(x1, y1, z1, x2, y2, z2)`. The first three values define the
+ **top-left-front** corner, while the last three values define the
+ **bottom-right-back** corner.
+ overlapping_cube: tuple[float, float, float, float, float, float]
+ The overlapping region between the volume and another volume,
+ represented in the same format as `bounding_cube`.
+
+ Returns
+ -------
+ np.ndarray
+ A 3D NumPy array representing the portion of `volume` that
+ lies within `overlapping_cube`. If the overlap does not exist,
+ an empty array may be returned.
+
+ Notes
+ -----
+ - The method computes the relative indices of `overlapping_cube`
+ within `volume` by subtracting the bounding cube's starting
+ position.
+ - The extracted region is determined by integer indices, meaning
+ coordinates are implicitly **floored to integers**.
+ - If `overlapping_cube` extends beyond `volume` boundaries, the
+ returned subregion is **cropped** to fit within `volume`.
+
+ """
+
+ # The position of the top left corner of the overlapping cube in the volume
+ overlapping_cube_position = np.array(overlapping_cube[:3]) - np.array(
+ bounding_cube[:3]
+ )
+
+ # The position of the bottom right corner of the overlapping cube in the volume
+ overlapping_cube_end_position = np.array(
+ overlapping_cube[3:]
+ ) - np.array(bounding_cube[:3])
+
+ # cast to int
+ overlapping_cube_position = overlapping_cube_position.astype(int)
+ overlapping_cube_end_position = overlapping_cube_end_position.astype(int)
+
+ return volume[
+ overlapping_cube_position[0] : overlapping_cube_end_position[0],
+ overlapping_cube_position[1] : overlapping_cube_end_position[1],
+ overlapping_cube_position[2] : overlapping_cube_end_position[2],
+ ]
+
+ def _check_volumes_non_overlapping(
+ self: NonOverlapping,
+ volume_1: np.ndarray,
+ volume_2: np.ndarray,
+ min_distance: float,
+ ) -> bool:
+ """Determines whether the non-zero voxels in two 3D volumes are at
+ least `min_distance` apart.
+
+ This method checks whether the active regions (non-zero voxels) in
+ `volume_1` and `volume_2` maintain a minimum separation of
+ `min_distance`. If the volumes differ in size, the positions of their
+ non-zero voxels are adjusted accordingly to ensure a fair comparison.
+
+ Parameters
+ ----------
+ volume_1: np.ndarray
+ A 3D NumPy array representing the first volume.
+ volume_2: np.ndarray
+ A 3D NumPy array representing the second volume.
+ min_distance: float
+ The minimum Euclidean distance required between any two non-zero
+ voxels in the two volumes.
+
+ Returns
+ -------
+ bool
+ `True` if all non-zero voxels in `volume_1` and `volume_2` are at
+ least `min_distance` apart, otherwise `False`.
+
+ Notes
+ -----
+ - This function assumes both volumes are correctly aligned within a
+ shared coordinate space.
+ - If the volumes are of different sizes, voxel positions are scaled
+ or adjusted for accurate distance measurement.
+ - Uses **Euclidean distance** for separation checking.
+ - If either volume is empty (i.e., no non-zero voxels), they are
+ considered non-overlapping.
+
+ """
+
+ # Get the positions of the non-zero voxels of each volume.
+ if self.get_backend() == "torch":
+ positions_1 = torch.nonzero(volume_1, as_tuple=False)
+ positions_2 = torch.nonzero(volume_2, as_tuple=False)
+ else:
+ positions_1 = np.argwhere(volume_1)
+ positions_2 = np.argwhere(volume_2)
+
+ # if positions_1.size == 0 or positions_2.size == 0:
+ # return True # If either volume is empty, they are "non-overlapping"
+
+ # # If the volumes are not the same size, the positions of the non-zero
+ # # voxels of each volume need to be scaled.
+ # if positions_1.size == 0 or positions_2.size == 0:
+ # return True # If either volume is empty, they are "non-overlapping"
+
+ # If the volumes are not the same size, the positions of the non-zero
+ # voxels of each volume need to be scaled.
+ if volume_1.shape != volume_2.shape:
+ positions_1 = (
+ positions_1 * np.array(volume_2.shape)
+ / np.array(volume_1.shape)
+ )
+ positions_1 = positions_1.astype(int)
+
+ # Check that the non-zero voxels of the volumes are at least
+ # min_distance apart.
+ if self.get_backend() == "torch":
+ dist = torch.cdist(
+ positions_1.float(),
+ positions_2.float(),
+ )
+ return bool((dist > min_distance).all())
+ else:
+ return np.all(cdist(positions_1, positions_2) > min_distance)
+
+ def _resample_volume_position(
+ self: NonOverlapping,
+ volume: np.ndarray | Image,
+ ) -> Image:
+ """Resamples the position of a 3D volume using its internal position
+ sampler.
+
+ This method updates the `position` property of the given `volume` by
+ drawing a new position from the `_position_sampler` stored in the
+ volume's `properties`. If the sampled position is a `Quantity`, it is
+ converted to pixel units.
+
+ Parameters
+ ----------
+ volume: np.ndarray
+ The 3D volume whose position is to be resampled. The volume must
+ have a `properties` attribute containing dictionaries with
+ `position` and `_position_sampler` keys.
+
+ Returns
+ -------
+ Image
+ The same input volume with its `position` property updated to the
+ newly sampled value.
+
+ Notes
+ -----
+ - The `_position_sampler` function is expected to return a **tuple of
+ three floats** (e.g., `(x, y, z)`).
+ - If the sampled position is a `Quantity`, it is converted to pixels.
+ - **Only** dictionaries in `volume.properties` that contain both
+ `position` and `_position_sampler` keys are modified.
+
+ """
+
+ pdict = volume.properties
+ if "position" in pdict and "_position_sampler" in pdict:
+ new_position = pdict["_position_sampler"]()
+ if isinstance(new_position, Quantity):
+ new_position = new_position.to("pixel").magnitude
+ pdict["position"] = new_position
+
+ return volume
+
+
+class SampleToMasks(Feature):
+ """Create a mask from a list of images.
+
+ This feature applies a transformation function to each input image and
+ merges the resulting masks into a single multi-layer image. Each input
+ image must have a `position` property that determines its placement within
+ the final mask. When used with scatterers, the `voxel_size` property must
+ be provided for correct object sizing.
+
+ Parameters
+ ----------
+ transformation_function: Callable[[Image], Image]
+ A function that transforms each input image into a mask with
+ `number_of_masks` layers.
+ number_of_masks: PropertyLike[int], optional
+ The number of mask layers to generate. Default is 1.
+ output_region: PropertyLike[tuple[int, int, int, int]], optional
+ The size and position of the output mask, typically aligned with
+ `optics.output_region`.
+ merge_method: PropertyLike[str | Callable | list[str | Callable]], optional
+ Method for merging individual masks into the final image. Can be:
+ - "add" (default): Sum the masks.
+ - "overwrite": Later masks overwrite earlier masks.
+ - "or": Combine masks using a logical OR operation.
+ - "mul": Multiply masks.
+ - Function: Custom function taking two images and merging them.
+
+ **kwargs: dict[str, Any]
+ Additional keyword arguments passed to the parent `Feature` class.
+
+ Methods
+ -------
+ `get(image, transformation_function, **kwargs) -> Image`
+ Applies the transformation function to the input image.
+ `_process_and_get(images, **kwargs) -> Image | np.ndarray`
+ Processes a list of images and generates a multi-layer mask.
+
+ Returns
+ -------
+ np.ndarray
+ The final mask image with the specified number of layers.
+
+ Raises
+ ------
+ ValueError
+ If `merge_method` is invalid.
+
+ Examples
+ -------
+ >>> import deeptrack as dt
+
+ Define number of particles:
+
+ >>> n_particles = 12
+
+ Define optics and particles:
+
+ >>> import numpy as np
+ >>>
+ >>> optics = dt.Fluorescence(output_region=(0, 0, 64, 64))
+ >>> particle = dt.PointParticle(
+ >>> position=lambda: np.random.uniform(5, 55, size=2),
+ >>> )
+ >>> particles = particle ^ n_particles
+
+ Define pipelines:
+
+ >>> sim_im_pip = optics(particles)
+ >>> sim_mask_pip = particles >> dt.SampleToMasks(
+ ... lambda: lambda particles: particles > 0,
+ ... output_region=optics.output_region,
+ ... merge_method="or",
+ ... )
+ >>> pipeline = sim_im_pip & sim_mask_pip
+ >>> pipeline.store_properties()
+
+ Generate image and mask:
+
+ >>> image, mask = pipeline.update()()
+
+ Get particle positions:
+
+ >>> positions = np.array(image.get_property("position", get_one=False))
+
+ Visualize results:
+
+ >>> import matplotlib.pyplot as plt
+ >>>
+ >>> plt.subplot(1, 2, 1)
+ >>> plt.imshow(image, cmap="gray")
+ >>> plt.title("Original Image")
+ >>> plt.subplot(1, 2, 2)
+ >>> plt.imshow(mask, cmap="gray")
+ >>> plt.scatter(positions[:,1], positions[:,0], c="y", marker="x", s = 50)
+ >>> plt.title("Mask")
+ >>> plt.show()
+
+ """
+
+ def __init__(
+ self: Feature,
+ transformation_function: Callable[[np.ndarray], np.ndarray, torch.Tensor],
+ number_of_masks: PropertyLike[int] = 1,
+ output_region: PropertyLike[tuple[int, int, int, int]] = None,
+ merge_method: PropertyLike[str | Callable | list[str | Callable]] = "add",
+ **kwargs: Any,
+ ):
+ """Initialize the SampleToMasks feature.
+
+ Parameters
+ ----------
+ transformation_function: Callable[[Image], Image]
+ Function to transform input images into masks.
+ number_of_masks: PropertyLike[int], optional
+ Number of mask layers. Default is 1.
+ output_region: PropertyLike[tuple[int, int, int, int]], optional
+ Output region of the mask. Default is None.
+ merge_method: PropertyLike[str | Callable | list[str | Callable]], optional
+ Method to merge masks. Defaults to "add".
+ **kwargs: dict[str, Any]
+ Additional keyword arguments passed to the parent class.
+
+ """
+
+ super().__init__(
+ transformation_function=transformation_function,
+ number_of_masks=number_of_masks,
+ output_region=output_region,
+ merge_method=merge_method,
+ **kwargs,
+ )
+
+ def get(
+ self: Feature,
+ image: np.ndarray,
+ transformation_function: Callable[list[np.ndarray] | np.ndarray | torch.Tensor],
+ **kwargs: Any,
+ ) -> np.ndarray:
+ """Apply the transformation function to a single image.
+
+ Parameters
+ ----------
+ image: np.ndarray
+ The input image.
+ transformation_function: Callable[[np.ndarray], np.ndarray]
+ Function to transform the image.
+ **kwargs: dict[str, Any]
+ Additional parameters.
+
+ Returns
+ -------
+ Image
+ The transformed image.
+
+ """
+
+ return transformation_function(image.array)
+
+ def _process_and_get(
+ self: Feature,
+ images: list[np.ndarray] | np.ndarray | list[torch.Tensor] | torch.Tensor,
+ **kwargs: Any,
+ ) -> np.ndarray:
+ """Process a list of images and generate a multi-layer mask.
+
+ Parameters
+ ----------
+ images: np.ndarray or list[np.ndarrray] or Image or list[Image]
+ List of input images or a single image.
+ **kwargs: dict[str, Any]
+ Additional parameters including `output_region`, `number_of_masks`,
+ and `merge_method`.
+
+ Returns
+ -------
+ Image or np.ndarray
+ The final mask image.
+
+ """
+
+ # Handle list of images.
+ # if isinstance(images, list) and len(images) != 1:
+ list_of_labels = super()._process_and_get(images, **kwargs)
+
+ from deeptrack.scatterers import ScatteredVolume
+
+ for idx, (label, image) in enumerate(zip(list_of_labels, images)):
+ list_of_labels[idx] = \
+ ScatteredVolume(array=label, properties=image.properties.copy())
+
+ # Create an empty output image.
+ output_region = kwargs["output_region"]
+ output = xp.zeros(
+ (
+ output_region[2] - output_region[0],
+ output_region[3] - output_region[1],
+ kwargs["number_of_masks"],
+ ),
+ dtype=list_of_labels[0].array.dtype,
+ )
+
+ from deeptrack.optics import _get_position
+
+ # Merge masks into the output.
+ for volume in list_of_labels:
+ label = volume.array
+ position = _get_position(volume)
+
+ p0 = xp.round(position - xp.asarray(output_region[0:2]))
+ p0 = p0.astype(xp.int64)
+
+
+ if xp.any(p0 > xp.asarray(output.shape[:2])) or \
+ xp.any(p0 + xp.asarray(label.shape[:2]) < 0):
+ continue
+
+ crop_x = (-xp.minimum(p0[0], 0)).item()
+ crop_y = (-xp.minimum(p0[1], 0)).item()
+
+ crop_x_end = int(
+ label.shape[0]
+ - np.max([p0[0] + label.shape[0] - output.shape[0], 0])
+ )
+ crop_y_end = int(
+ label.shape[1]
+ - np.max([p0[1] + label.shape[1] - output.shape[1], 0])
+ )
+
+ labelarg = label[crop_x:crop_x_end, crop_y:crop_y_end, :]
+
+ p0[0] = np.max([p0[0], 0])
+ p0[1] = np.max([p0[1], 0])
+
+ p0 = p0.astype(int)
+
+ output_slice = output[
+ p0[0] : p0[0] + labelarg.shape[0],
+ p0[1] : p0[1] + labelarg.shape[1],
+ ]
+
+ for label_index in range(kwargs["number_of_masks"]):
+
+ if isinstance(kwargs["merge_method"], list):
+ merge = kwargs["merge_method"][label_index]
+ else:
+ merge = kwargs["merge_method"]
+
+ if merge == "add":
+ output[
+ p0[0] : p0[0] + labelarg.shape[0],
+ p0[1] : p0[1] + labelarg.shape[1],
+ label_index,
+ ] += labelarg[..., label_index]
+
+ elif merge == "overwrite":
+ output_slice[
+ labelarg[..., label_index] != 0, label_index
+ ] = labelarg[labelarg[..., label_index] != 0, \
+ label_index]
+ output[
+ p0[0] : p0[0] + labelarg.shape[0],
+ p0[1] : p0[1] + labelarg.shape[1],
+ label_index,
+ ] = output_slice[..., label_index]
+
+ elif merge == "or":
+ output[
+ p0[0] : p0[0] + labelarg.shape[0],
+ p0[1] : p0[1] + labelarg.shape[1],
+ label_index,
+ ] = xp.logical_or(
+ output_slice[..., label_index] != 0,
+ labelarg[..., label_index] != 0
+ )
+
+ elif merge == "mul":
+ output[
+ p0[0] : p0[0] + labelarg.shape[0],
+ p0[1] : p0[1] + labelarg.shape[1],
+ label_index,
+ ] *= labelarg[..., label_index]
+
+ else:
+ # No match, assume function
+ output[
+ p0[0] : p0[0] + labelarg.shape[0],
+ p0[1] : p0[1] + labelarg.shape[1],
+ label_index,
+ ] = merge(
+ output_slice[..., label_index],
+ labelarg[..., label_index],
+ )
+
+ return output
+
+
#TODO ***??*** revise _get_position - torch, typing, docstring, unit test
def _get_position(
- image: Image,
+ scatterer: ScatteredObject,
mode: str = "corner",
return_z: bool = False,
) -> np.ndarray:
@@ -1826,26 +3051,23 @@ def _get_position(
num_outputs = 2 + return_z
- if mode == "corner" and image.size > 0:
+ if mode == "corner" and scatterer.array.size > 0:
import scipy.ndimage
- image = image.to_numpy()
-
- shift = scipy.ndimage.center_of_mass(np.abs(image))
+ shift = scipy.ndimage.center_of_mass(np.abs(scatterer.array))
if np.isnan(shift).any():
- shift = np.array(image.shape) / 2
+ shift = np.array(scatterer.array.shape) / 2
else:
shift = np.zeros((num_outputs))
- position = np.array(image.get_property("position", default=None))
+ position = np.array(scatterer.get_property("position", default=None))
if position is None:
return position
scale = np.array(get_active_scale())
-
if len(position) == 3:
position = position * scale + 0.5 * (scale - 1)
if return_z:
@@ -1856,7 +3078,7 @@ def _get_position(
elif len(position) == 2:
if return_z:
outp = (
- np.array([position[0], position[1], image.get_property("z", default=0)])
+ np.array([position[0], position[1], scatterer.get_property("z", default=0)])
* scale
- shift
+ 0.5 * (scale - 1)
@@ -1868,6 +3090,58 @@ def _get_position(
return position
+def _bilinear_interpolate_numpy(
+ scatterer: np.ndarray, x_off: float, y_off: float
+) -> np.ndarray:
+ """Apply bilinear subpixel interpolation in the x–y plane (NumPy)."""
+ kernel = np.array(
+ [
+ [0.0, 0.0, 0.0],
+ [0.0, (1 - x_off) * (1 - y_off), (1 - x_off) * y_off],
+ [0.0, x_off * (1 - y_off), x_off * y_off],
+ ]
+ )
+ out = np.zeros_like(scatterer)
+ for z in range(scatterer.shape[2]):
+ if np.iscomplexobj(scatterer):
+ out[:, :, z] = (
+ convolve(np.real(scatterer[:, :, z]), kernel, mode="constant")
+ + 1j
+ * convolve(np.imag(scatterer[:, :, z]), kernel, mode="constant")
+ )
+ else:
+ out[:, :, z] = convolve(scatterer[:, :, z], kernel, mode="constant")
+ return out
+
+
+def _bilinear_interpolate_torch(
+ scatterer: torch.Tensor, x_off: float, y_off: float
+) -> torch.Tensor:
+ """Apply bilinear subpixel interpolation in the x–y plane (Torch).
+
+ Uses grid_sample for autograd-friendly interpolation.
+ """
+ H, W, D = scatterer.shape
+
+ # Normalized shifts in [-1,1]
+ x_shift = 2 * x_off / (W - 1)
+ y_shift = 2 * y_off / (H - 1)
+
+ yy, xx = torch.meshgrid(
+ torch.linspace(-1, 1, H, device=scatterer.device, dtype=scatterer.dtype),
+ torch.linspace(-1, 1, W, device=scatterer.device, dtype=scatterer.dtype),
+ indexing="ij",
+ )
+ grid = torch.stack((xx + x_shift, yy + y_shift), dim=-1) # (H,W,2)
+ grid = grid.unsqueeze(0).repeat(D, 1, 1, 1) # (D,H,W,2)
+
+ inp = scatterer.permute(2, 0, 1).unsqueeze(1) # (D,1,H,W)
+
+ out = F.grid_sample(inp, grid, mode="bilinear",
+ padding_mode="zeros", align_corners=True)
+ return out.squeeze(1).permute(1, 2, 0) # (H,W,D)
+
+
#TODO ***??*** revise _create_volume - torch, typing, docstring, unit test
def _create_volume(
list_of_scatterers: list,
@@ -1903,6 +3177,12 @@ def _create_volume(
Spatial limits of the volume.
"""
+ # contrast_type = kwargs.get("contrast_type", None)
+ # if contrast_type is None:
+ # raise RuntimeError(
+ # "_create_volume requires a contrast_type "
+ # "(e.g. 'intensity' or 'refractive_index')"
+ # )
if not isinstance(list_of_scatterers, list):
list_of_scatterers = [list_of_scatterers]
@@ -1927,24 +3207,28 @@ def _create_volume(
# This accounts for upscale doing AveragePool instead of SumPool. This is
# a bit of a hack, but it works for now.
- fudge_factor = scale[0] * scale[1] / scale[2]
+ # fudge_factor = scale[0] * scale[1] / scale[2]
for scatterer in list_of_scatterers:
-
position = _get_position(scatterer, mode="corner", return_z=True)
- if scatterer.get_property("intensity", None) is not None:
- intensity = scatterer.get_property("intensity")
- scatterer_value = intensity * fudge_factor
- elif scatterer.get_property("refractive_index", None) is not None:
- refractive_index = scatterer.get_property("refractive_index")
- scatterer_value = (
- refractive_index - refractive_index_medium
- )
- else:
- scatterer_value = scatterer.get_property("value")
+ # if contrast_type == "intensity":
+ # value = scatterer.get_property("intensity", None)
+ # if value is None:
+ # raise ValueError("Scatterer has no intensity.")
+ # scatterer_value = value
+
+ # elif contrast_type == "refractive_index":
+ # ri = scatterer.get_property("refractive_index", None)
+ # if ri is None:
+ # raise ValueError("Scatterer has no refractive_index.")
+ # scatterer_value = ri - refractive_index_medium
+
+ # else:
+ # raise RuntimeError(f"Unknown contrast_type: {contrast_type}")
- scatterer = scatterer * scatterer_value
+ # # Scale the array accordingly
+ # scatterer.array = scatterer.array * scatterer_value
if limits is None:
limits = np.zeros((3, 2), dtype=np.int32)
@@ -1952,26 +3236,25 @@ def _create_volume(
limits[:, 1] = np.floor(position).astype(np.int32) + 1
if (
- position[0] + scatterer.shape[0] < OR[0]
+ position[0] + scatterer.array.shape[0] < OR[0]
or position[0] > OR[2]
- or position[1] + scatterer.shape[1] < OR[1]
+ or position[1] + scatterer.array.shape[1] < OR[1]
or position[1] > OR[3]
):
continue
- padded_scatterer = Image(
- np.pad(
- scatterer,
+ # Pad scatterer to avoid edge effects during interpolation
+ padded_scatterer_arr = np.pad( #Use Pad instead and make it torch-compatible?
+ scatterer.array,
[(2, 2), (2, 2), (2, 2)],
"constant",
constant_values=0,
)
- )
- padded_scatterer.merge_properties_from(scatterer)
-
- scatterer = padded_scatterer
- position = _get_position(scatterer, mode="corner", return_z=True)
- shape = np.array(scatterer.shape)
+ padded_scatterer = ScatteredVolume(
+ array=padded_scatterer_arr, properties=scatterer.properties.copy(),
+ )
+ position = _get_position(padded_scatterer, mode="corner", return_z=True)
+ shape = np.array(padded_scatterer.array.shape)
if position is None:
RuntimeWarning(
@@ -1980,36 +3263,20 @@ def _create_volume(
)
continue
- splined_scatterer = np.zeros_like(scatterer)
-
x_off = position[0] - np.floor(position[0])
y_off = position[1] - np.floor(position[1])
- kernel = np.array(
- [
- [0, 0, 0],
- [0, (1 - x_off) * (1 - y_off), (1 - x_off) * y_off],
- [0, x_off * (1 - y_off), x_off * y_off],
- ]
- )
-
- for z in range(scatterer.shape[2]):
- if splined_scatterer.dtype == complex:
- splined_scatterer[:, :, z] = (
- convolve(
- np.real(scatterer[:, :, z]), kernel, mode="constant"
- )
- + convolve(
- np.imag(scatterer[:, :, z]), kernel, mode="constant"
- )
- * 1j
- )
- else:
- splined_scatterer[:, :, z] = convolve(
- scatterer[:, :, z], kernel, mode="constant"
- )
+
+ if isinstance(padded_scatterer.array, np.ndarray): # get_backend is a method of Features and not exposed
+ splined_scatterer = _bilinear_interpolate_numpy(padded_scatterer.array, x_off, y_off)
+ elif isinstance(padded_scatterer.array, torch.Tensor):
+ splined_scatterer = _bilinear_interpolate_torch(padded_scatterer.array, x_off, y_off)
+ else:
+ raise TypeError(
+ f"Unsupported array type {type(padded_scatterer.array)}. "
+ "Expected np.ndarray or torch.Tensor."
+ )
- scatterer = splined_scatterer
position = np.floor(position)
new_limits = np.zeros(limits.shape, dtype=np.int32)
for i in range(3):
@@ -2038,7 +3305,8 @@ def _create_volume(
within_volume_position = position - limits[:, 0]
- # NOTE: Maybe shouldn't be additive.
+ # NOTE: Maybe shouldn't be ONLY additive.
+ # give options: sum default, but also mean, max, min, or
volume[
int(within_volume_position[0]) :
int(within_volume_position[0] + shape[0]),
@@ -2048,5 +3316,5 @@ def _create_volume(
int(within_volume_position[2]) :
int(within_volume_position[2] + shape[2]),
- ] += scatterer
- return volume, limits
+ ] += splined_scatterer
+ return volume, limits
\ No newline at end of file
diff --git a/deeptrack/scatterers.py b/deeptrack/scatterers.py
index 04a7c5ea..b7e1b70a 100644
--- a/deeptrack/scatterers.py
+++ b/deeptrack/scatterers.py
@@ -166,6 +166,7 @@
import numpy as np
from numpy.typing import NDArray
from pint import Quantity
+from dataclasses import dataclass, field
from deeptrack.holography import get_propagation_matrix
from deeptrack.backend.units import (
@@ -174,12 +175,14 @@
get_active_voxel_size,
)
from deeptrack.backend import mie
+from deeptrack.math import AveragePooling
from deeptrack.features import Feature, MERGE_STRATEGY_APPEND
-from deeptrack.image import pad_image_to_fft, Image
+from deeptrack.image import pad_image_to_fft
from deeptrack.types import ArrayLike
from deeptrack import units_registry as u
+
__all__ = [
"Scatterer",
"PointParticle",
@@ -238,7 +241,7 @@ class Scatterer(Feature):
"""
- __list_merge_strategy__ = MERGE_STRATEGY_APPEND
+ __list_merge_strategy__ = MERGE_STRATEGY_APPEND ### Not clear why needed
__distributed__ = False
__conversion_table__ = ConversionTable(
position=(u.pixel, u.pixel),
@@ -258,11 +261,11 @@ def __init__(
**kwargs,
) -> None:
# Ignore warning to help with comparison with arrays.
- if upsample is not 1: # noqa: F632
- warnings.warn(
- f"Setting upsample != 1 is deprecated. "
- f"Please, instead use dt.Upscale(f, factor={upsample})"
- )
+ # if upsample != 1: # noqa: F632
+ # warnings.warn(
+ # f"Setting upsample != 1 is deprecated. "
+ # f"Please, instead use dt.Upscale(f, factor={upsample})"
+ # )
self._processed_properties = False
@@ -278,6 +281,21 @@ def __init__(
**kwargs,
)
+ def _antialias_volume(self, volume, factor: int):
+ """Geometry-only supersampling anti-aliasing.
+
+ Assumes `volume` was generated on a grid oversampled by `factor`
+ and downsamples it back by average pooling.
+ """
+ if factor == 1:
+ return volume
+
+ # average pooling conserves fractional occupancy
+ return AveragePooling(
+ factor
+ )(volume)
+
+
def _process_properties(
self,
properties: dict
@@ -296,7 +314,7 @@ def _process_and_get(
upsample_axes=None,
crop_empty=True,
**kwargs
- ) -> list[Image] | list[np.ndarray]:
+ ) -> list[np.ndarray]:
# Post processes the created object to handle upsampling,
# as well as cropping empty slices.
if not self._processed_properties:
@@ -307,16 +325,31 @@ def _process_and_get(
+ "Optics.upscale != 1."
)
- voxel_size = get_active_voxel_size()
- # Calls parent _process_and_get.
- new_image = super()._process_and_get(
+ voxel_size = np.asarray(get_active_voxel_size(), float)
+
+ apply_supersampling = upsample > 1 and isinstance(self, VolumeScatterer)
+
+ if upsample > 1 and not apply_supersampling:
+ warnings.warn(
+ "Geometry supersampling (upsample) is ignored for "
+ "FieldScatterers.",
+ UserWarning,
+ )
+
+ if apply_supersampling:
+ voxel_size /= float(upsample)
+
+ new_image = super(Scatterer, self)._process_and_get(
*args,
voxel_size=voxel_size,
upsample=upsample,
**kwargs,
- )
- new_image = new_image[0]
+ )[0]
+
+ if apply_supersampling:
+ new_image = self._antialias_volume(new_image, factor=upsample)
+
if new_image.size == 0:
warnings.warn(
@@ -333,32 +366,35 @@ def _process_and_get(
new_image = new_image[:, ~np.all(new_image == 0, axis=(0, 2))]
new_image = new_image[:, :, ~np.all(new_image == 0, axis=(0, 1))]
- return [Image(new_image)]
+ # # Copy properties
+ # props = kwargs.copy()
+ return [self._wrap_output(new_image, kwargs)]
- def _no_wrap_format_input(
- self,
- *args,
- **kwargs
- ) -> list:
- return self._image_wrapped_format_input(*args, **kwargs)
+ def _wrap_output(self, array, props):
+ raise NotImplementedError(
+ f"{self.__class__.__name__} must implement _wrap_output()"
+ )
- def _no_wrap_process_and_get(
- self,
- *args,
- **feature_input
- ) -> list:
- return self._image_wrapped_process_and_get(*args, **feature_input)
- def _no_wrap_process_output(
- self,
- *args,
- **feature_input
- ) -> list:
- return self._image_wrapped_process_output(*args, **feature_input)
+class VolumeScatterer(Scatterer):
+ """Abstract scatterer producing ScatteredVolume outputs."""
+ def _wrap_output(self, array, props) -> ScatteredVolume:
+ return ScatteredVolume(
+ array=array,
+ properties=props.copy(),
+ )
+
+
+class FieldScatterer(Scatterer):
+ def _wrap_output(self, array, props) -> ScatteredField:
+ return ScatteredField(
+ array=array,
+ properties=props.copy(),
+ )
#TODO ***??*** revise PointParticle - torch, typing, docstring, unit test
-class PointParticle(Scatterer):
+class PointParticle(VolumeScatterer):
"""Generate a diffraction-limited point particle.
A point particle is approximated by the size of a single pixel or voxel.
@@ -389,12 +425,12 @@ def __init__(
"""
"""
-
+ kwargs.pop("upsample", None)
super().__init__(upsample=1, upsample_axes=(), **kwargs)
def get(
self: PointParticle,
- image: Image | np.ndarray,
+ image: np.ndarray,
**kwarg: Any,
) -> NDArray[Any] | torch.Tensor:
"""Evaluate and return the scatterer volume."""
@@ -405,7 +441,7 @@ def get(
#TODO ***??*** revise Ellipse - torch, typing, docstring, unit test
-class Ellipse(Scatterer):
+class Ellipse(VolumeScatterer):
"""Generates an elliptical disk scatterer
Parameters
@@ -441,6 +477,7 @@ class Ellipse(Scatterer):
"""
+
__conversion_table__ = ConversionTable(
radius=(u.meter, u.meter),
rotation=(u.radian, u.radian),
@@ -519,7 +556,7 @@ def get(
#TODO ***??*** revise Sphere - torch, typing, docstring, unit test
-class Sphere(Scatterer):
+class Sphere(VolumeScatterer):
"""Generates a spherical scatterer
Parameters
@@ -559,7 +596,7 @@ def __init__(
def get(
self,
- image: Image | np.ndarray,
+ image: np.ndarray,
radius: float,
voxel_size: float,
**kwargs
@@ -584,7 +621,7 @@ def get(
#TODO ***??*** revise Ellipsoid - torch, typing, docstring, unit test
-class Ellipsoid(Scatterer):
+class Ellipsoid(VolumeScatterer):
"""Generates an ellipsoidal scatterer
Parameters
@@ -694,7 +731,7 @@ def _process_properties(
def get(
self,
- image: Image | np.ndarray,
+ image: np.ndarray,
radius: float,
rotation: ArrayLike[float] | float,
voxel_size: float,
@@ -741,7 +778,7 @@ def get(
#TODO ***??*** revise MieScatterer - torch, typing, docstring, unit test
-class MieScatterer(Scatterer):
+class MieScatterer(FieldScatterer):
"""Base implementation of a Mie particle.
New Mie-theory scatterers can be implemented by extending this class, and
@@ -826,6 +863,7 @@ class MieScatterer(Scatterer):
"""
+
__conversion_table__ = ConversionTable(
radius=(u.meter, u.meter),
polarization_angle=(u.radian, u.radian),
@@ -856,6 +894,7 @@ def __init__(
illumination_angle: float=0,
amp_factor: float=1,
phase_shift_correction: bool=False,
+ # pupil: ArrayLike=[], # Daniel
**kwargs,
) -> None:
if polarization_angle is not None:
@@ -864,11 +903,10 @@ def __init__(
"Please use input_polarization instead"
)
input_polarization = polarization_angle
- kwargs.pop("is_field", None)
kwargs.pop("crop_empty", None)
super().__init__(
- is_field=True,
+ is_field=True, # remove
crop_empty=False,
L=L,
offset_z=offset_z,
@@ -889,6 +927,7 @@ def __init__(
illumination_angle=illumination_angle,
amp_factor=amp_factor,
phase_shift_correction=phase_shift_correction,
+ # pupil=pupil, # Daniel
**kwargs,
)
@@ -1014,7 +1053,8 @@ def get_plane_in_polar_coords(
shape: int,
voxel_size: ArrayLike[float],
plane_position: float,
- illumination_angle: float
+ illumination_angle: float,
+ # k: float, # Daniel
) -> tuple[float, float, float, float]:
"""Computes the coordinates of the plane in polar form."""
@@ -1027,15 +1067,24 @@ def get_plane_in_polar_coords(
R2_squared = X ** 2 + Y ** 2
R3 = np.sqrt(R2_squared + Z ** 2) # Might be +z instead of -z.
+
+ # # DANIEL
+ # Q = np.sqrt(R2_squared)/voxel_size[0]**2*2*np.pi/shape[0]
+ # # is dimensionally ok?
+ # sin_theta=Q/(k)
+ # pupil_mask=sin_theta<1
+ # cos_theta=np.zeros(sin_theta.shape)
+ # cos_theta[pupil_mask]=np.sqrt(1-sin_theta[pupil_mask]**2)
# Fet the angles.
cos_theta = Z / R3
+
illumination_cos_theta = (
np.cos(np.arccos(cos_theta) + illumination_angle)
)
phi = np.arctan2(Y, X)
- return R3, cos_theta, illumination_cos_theta, phi
+ return R3, cos_theta, illumination_cos_theta, phi#, pupil_mask # Daniel
def get(
self,
@@ -1060,6 +1109,7 @@ def get(
illumination_angle: float,
amp_factor: float,
phase_shift_correction: bool,
+ # pupil: ArrayLike, # Daniel
**kwargs,
) -> ArrayLike[float]:
"""Abstract method to initialize the Mie scatterer"""
@@ -1067,8 +1117,9 @@ def get(
# Get size of the output.
xSize, ySize = self.get_xy_size(output_region, padding)
voxel_size = get_active_voxel_size()
+ scale = get_active_scale()
arr = pad_image_to_fft(np.zeros((xSize, ySize))).astype(complex)
- position = np.array(position) * voxel_size[: len(position)]
+ position = np.array(position) * scale[: len(position)] * voxel_size[: len(position)]
pupil_physical_size = working_distance * np.tan(collection_angle) * 2
@@ -1076,7 +1127,10 @@ def get(
ratio = offset_z / (working_distance - z)
- # Position of pbjective relative particle.
+ # Wave vector.
+ k = 2 * np.pi / wavelength * refractive_index_medium
+
+ # Position of objective relative particle.
relative_position = np.array(
(
position_objective[0] - position[0],
@@ -1085,12 +1139,13 @@ def get(
)
)
- # Get field evaluation plane at offset_z.
+ # Get field evaluation plane at offset_z. # , pupil_mask # Daniel
R3_field, cos_theta_field, illumination_angle_field, phi_field =\
self.get_plane_in_polar_coords(
arr.shape, voxel_size,
relative_position * ratio,
- illumination_angle
+ illumination_angle,
+ # k # Daniel
)
cos_phi_field, sin_phi_field = np.cos(phi_field), np.sin(phi_field)
@@ -1108,7 +1163,7 @@ def get(
sin_phi_field / ratio
)
- # If the beam is within the pupil.
+ # If the beam is within the pupil. Remove if Daniel
pupil_mask = (x_farfield - position_objective[0]) ** 2 + (
y_farfield - position_objective[1]
) ** 2 < (pupil_physical_size / 2) ** 2
@@ -1146,9 +1201,6 @@ def get(
* illumination_angle_field
)
- # Wave vector.
- k = 2 * np.pi / wavelength * refractive_index_medium
-
# Harmonics.
A, B = coefficients(L)
PI, TAU = mie.harmonics(illumination_angle_field, L)
@@ -1165,12 +1217,15 @@ def get(
[E[i] * B[i] * PI[i] + E[i] * A[i] * TAU[i] for i in range(0, L)]
)
+ # Daniel
+ # arr[pupil_mask] = (S2 * S2_coef + S1 * S1_coef)/amp_factor
arr[pupil_mask] = (
-1j
/ (k * R3_field)
* np.exp(1j * k * R3_field)
* (S2 * S2_coef + S1 * S1_coef)
) / amp_factor
+
# For phase shift correction (a multiplication of the field
# by exp(1j * k * z)).
@@ -1188,15 +1243,23 @@ def get(
-mask.shape[1] // 2 : mask.shape[1] // 2,
]
mask = np.exp(-0.5 * (x ** 2 + y ** 2) / ((sigma) ** 2))
-
arr = arr * mask
+ # Not sure if needed... CM
+ # if len(pupil)>0:
+ # c_pix=[arr.shape[0]//2,arr.shape[1]//2]
+
+ # arr[c_pix[0]-pupil.shape[0]//2:c_pix[0]+pupil.shape[0]//2,c_pix[1]-pupil.shape[1]//2:c_pix[1]+pupil.shape[1]//2]*=pupil
+
+ # Daniel
+ # fourier_field = -np.fft.ifft2(np.fft.fftshift(np.fft.fft2(np.fft.fftshift(arr))))
fourier_field = np.fft.fft2(arr)
propagation_matrix = get_propagation_matrix(
fourier_field.shape,
- pixel_size=voxel_size[2],
+ pixel_size=voxel_size[:2], # this needs a double check
wavelength=wavelength / refractive_index_medium,
+ # to_z=(-z), # Daniel
to_z=(-offset_z - z),
dy=(
relative_position[0] * ratio
@@ -1206,11 +1269,12 @@ def get(
dx=(
relative_position[1] * ratio
+ position[1]
- + (padding[1] - arr.shape[1] / 2) * voxel_size[1]
+ + (padding[2] - arr.shape[1] / 2) * voxel_size[1] # check if padding is top, bottom, left, right
),
)
+
fourier_field = (
- fourier_field * propagation_matrix * np.exp(-1j * k * offset_z)
+ fourier_field * propagation_matrix * np.exp(-1j * k * offset_z) # Remove last part (from exp)) if Daniel
)
if return_fft:
@@ -1275,6 +1339,7 @@ class MieSphere(MieScatterer):
"""
+
def __init__(
self,
radius: float = 1e-6,
@@ -1377,6 +1442,7 @@ class MieStratifiedSphere(MieScatterer):
"""
+
def __init__(
self,
radius: ArrayLike[float] = [1e-6],
@@ -1412,3 +1478,62 @@ def inner(
refractive_index=refractive_index,
**kwargs,
)
+
+
+@dataclass
+class ScatteredBase:
+ """Base class for scatterers (volumes and fields)."""
+
+ array: np.ndarray | torch.Tensor
+ properties: dict[str, Any] = field(default_factory=dict)
+
+ @property
+ def ndim(self) -> int:
+ """Number of dimensions of the underlying array."""
+ return self.array.ndim
+
+ @property
+ def shape(self) -> int:
+ """Number of dimensions of the underlying array."""
+ return self.array.shape
+
+ @property
+ def pos3d(self) -> np.ndarray:
+ return np.array([*self.position, self.z], dtype=float)
+
+ @property
+ def position(self) -> np.ndarray:
+ pos = self.properties.get("position", None)
+ if pos is None:
+ return None
+ pos = np.asarray(pos, dtype=float)
+ if pos.ndim == 2 and pos.shape[0] == 1:
+ pos = pos[0]
+ return pos
+
+ def as_array(self) -> ArrayLike:
+ """Return the underlying array.
+
+ Notes
+ -----
+ The raw array is also directly available as ``scatterer.array``.
+ This method exists mainly for API compatibility and clarity.
+
+ """
+
+ return self.array
+
+ def get_property(self, key: str, default: Any = None) -> Any:
+ return getattr(self, key, self.properties.get(key, default))
+
+
+@dataclass
+class ScatteredVolume(ScatteredBase):
+ """Voxelized volume produced by a VolumeScatterer."""
+ pass
+
+
+@dataclass
+class ScatteredField(ScatteredBase):
+ """Complex field produced by a FieldScatterer."""
+ pass
\ No newline at end of file