diff --git a/README-pypi.md b/README-pypi.md
index be0de197..f6173669 100644
--- a/README-pypi.md
+++ b/README-pypi.md
@@ -93,6 +93,14 @@ Here you find a series of notebooks that give you an overview of the core featur
 
   Using PyTorch gradients to fit a Gaussian generated by a DeepTrack2 pipeline.
 
+- DTGS171A **[Creating Custom Scatterers](https://github.com/DeepTrackAI/DeepTrack2/blob/develop/tutorials/1-getting-started/DTGS171A_custom_scatterers.ipynb)**
+
+  Creating custom scatterers of arbitrary shapes.
+
+- DTGS171B **[Creating Custom Scatterers: Bacteria](https://github.com/DeepTrackAI/DeepTrack2/blob/develop/tutorials/1-getting-started/DTGS171B_custom_scatterers_bacteria.ipynb)**
+
+  Creating custom scatterers in the shape of bacteria.
+
 # Examples
 
 These are examples of how DeepTrack2 can be used on real datasets:
diff --git a/README.md b/README.md
index 674d41d3..a9f507c2 100644
--- a/README.md
+++ b/README.md
@@ -97,6 +97,14 @@ Here you find a series of notebooks that give you an overview of the core featur
 
   Using PyTorch gradients to fit a Gaussian generated by a DeepTrack2 pipeline.
 
+- DTGS171A **[Creating Custom Scatterers](https://github.com/DeepTrackAI/DeepTrack2/blob/develop/tutorials/1-getting-started/DTGS171A_custom_scatterers.ipynb)**
+
+  Creating custom scatterers of arbitrary shapes.
+
+- DTGS171B **[Creating Custom Scatterers: Bacteria](https://github.com/DeepTrackAI/DeepTrack2/blob/develop/tutorials/1-getting-started/DTGS171B_custom_scatterers_bacteria.ipynb)**
+
+  Creating custom scatterers in the shape of bacteria.
+
 # Examples
 
 These are examples of how DeepTrack2 can be used on real datasets:
diff --git a/deeptrack/features.py b/deeptrack/features.py
index 4bdfad38..702e7362 100644
--- a/deeptrack/features.py
+++ b/deeptrack/features.py
@@ -1,6 +1,6 @@
 """Core features for building and processing pipelines in DeepTrack2.
 
-The `feasture.py` module defines the core classes and utilities used to create
+The `features.py` module defines the core classes and utilities used to create
 and manipulate features in DeepTrack2, enabling users to build sophisticated
 data processing pipelines with modular, reusable, and composable components.
 
@@ -80,11 +80,8 @@
 - `OneOf`: Resolve one feature from a given collection.
 - `OneOfDict`: Resolve one feature from a dictionary and apply it to an input.
 - `LoadImage`: Load an image from disk and preprocess it.
-- `SampleToMasks`: Create a mask from a list of images.
 - `AsType`: Convert the data type of the input.
 - `ChannelFirst2d`: DEPRECATED Convert an image to a channel-first format.
-- `Upscale`: Simulate a pipeline at a higher resolution.
-- `NonOverlapping`: Ensure volumes are placed non-overlapping in a 3D space.
 - `Store`: Store the output of a feature for reuse.
 - `Squeeze`: Squeeze the input to the smallest possible dimension.
 - `Unsqueeze`: Unsqueeze the input.
@@ -96,7 +93,7 @@
 - `TakeProperties`: Extract all instances of properties from a pipeline.
 
 Arithmetic Feature Classes:
-- `Add`: Add a value to the input.@dataclass
+- `Add`: Add a value to the input.
 - `Subtract`: Subtract a value from the input.
 - `Multiply`: Multiply the input by a value.
 - `Divide`: Divide the input by a value.
@@ -218,11 +215,8 @@
     "OneOf",
     "OneOfDict",
     "LoadImage",
-    "SampleToMasks",  # TODO ***CM*** revise this after elimination of Image
     "AsType",
     "ChannelFirst2d",
-    "Upscale",  # TODO ***CM*** revise and check PyTorch afrer elimin. Image
-    "NonOverlapping",  # TODO ***CM*** revise + PyTorch afrer elimin. 
Image "Store", "Squeeze", "Unsqueeze", @@ -7359,312 +7353,6 @@ def get( return image -class SampleToMasks(Feature): - """Create a mask from a list of images. - - This feature applies a transformation function to each input image and - merges the resulting masks into a single multi-layer image. Each input - image must have a `position` property that determines its placement within - the final mask. When used with scatterers, the `voxel_size` property must - be provided for correct object sizing. - - Parameters - ---------- - transformation_function: Callable[[Image], Image] - A function that transforms each input image into a mask with - `number_of_masks` layers. - number_of_masks: PropertyLike[int], optional - The number of mask layers to generate. Default is 1. - output_region: PropertyLike[tuple[int, int, int, int]], optional - The size and position of the output mask, typically aligned with - `optics.output_region`. - merge_method: PropertyLike[str | Callable | list[str | Callable]], optional - Method for merging individual masks into the final image. Can be: - - "add" (default): Sum the masks. - - "overwrite": Later masks overwrite earlier masks. - - "or": Combine masks using a logical OR operation. - - "mul": Multiply masks. - - Function: Custom function taking two images and merging them. - - **kwargs: dict[str, Any] - Additional keyword arguments passed to the parent `Feature` class. - - Methods - ------- - `get(image, transformation_function, **kwargs) -> Image` - Applies the transformation function to the input image. - `_process_and_get(images, **kwargs) -> Image | np.ndarray` - Processes a list of images and generates a multi-layer mask. - - Returns - ------- - Image or np.ndarray - The final mask image with the specified number of layers. - - Raises - ------ - ValueError - If `merge_method` is invalid. - - Examples - ------- - >>> import deeptrack as dt - - Define number of particles: - - >>> n_particles = 12 - - Define optics and particles: - - >>> import numpy as np - >>> - >>> optics = dt.Fluorescence(output_region=(0, 0, 64, 64)) - >>> particle = dt.PointParticle( - >>> position=lambda: np.random.uniform(5, 55, size=2), - >>> ) - >>> particles = particle ^ n_particles - - Define pipelines: - - >>> sim_im_pip = optics(particles) - >>> sim_mask_pip = particles >> dt.SampleToMasks( - ... lambda: lambda particles: particles > 0, - ... output_region=optics.output_region, - ... merge_method="or", - ... ) - >>> pipeline = sim_im_pip & sim_mask_pip - >>> pipeline.store_properties() - - Generate image and mask: - - >>> image, mask = pipeline.update()() - - Get particle positions: - - >>> positions = np.array(image.get_property("position", get_one=False)) - - Visualize results: - - >>> import matplotlib.pyplot as plt - >>> - >>> plt.subplot(1, 2, 1) - >>> plt.imshow(image, cmap="gray") - >>> plt.title("Original Image") - >>> plt.subplot(1, 2, 2) - >>> plt.imshow(mask, cmap="gray") - >>> plt.scatter(positions[:,1], positions[:,0], c="y", marker="x", s = 50) - >>> plt.title("Mask") - >>> plt.show() - - """ - - def __init__( - self: Feature, - transformation_function: Callable[[Image], Image], - number_of_masks: PropertyLike[int] = 1, - output_region: PropertyLike[tuple[int, int, int, int]] = None, - merge_method: PropertyLike[str | Callable | list[str | Callable]] = "add", - **kwargs: Any, - ): - """Initialize the SampleToMasks feature. - - Parameters - ---------- - transformation_function: Callable[[Image], Image] - Function to transform input images into masks. 
- number_of_masks: PropertyLike[int], optional - Number of mask layers. Default is 1. - output_region: PropertyLike[tuple[int, int, int, int]], optional - Output region of the mask. Default is None. - merge_method: PropertyLike[str | Callable | list[str | Callable]], optional - Method to merge masks. Defaults to "add". - **kwargs: dict[str, Any] - Additional keyword arguments passed to the parent class. - - """ - - super().__init__( - transformation_function=transformation_function, - number_of_masks=number_of_masks, - output_region=output_region, - merge_method=merge_method, - **kwargs, - ) - - def get( - self: Feature, - image: np.ndarray | Image, - transformation_function: Callable[[Image], Image], - **kwargs: Any, - ) -> Image: - """Apply the transformation function to a single image. - - Parameters - ---------- - image: np.ndarray | Image - The input image. - transformation_function: Callable[[Image], Image] - Function to transform the image. - **kwargs: dict[str, Any] - Additional parameters. - - Returns - ------- - Image - The transformed image. - - """ - - return transformation_function(image) - - def _process_and_get( - self: Feature, - images: list[np.ndarray] | np.ndarray | list[Image] | Image, - **kwargs: Any, - ) -> Image | np.ndarray: - """Process a list of images and generate a multi-layer mask. - - Parameters - ---------- - images: np.ndarray or list[np.ndarrray] or Image or list[Image] - List of input images or a single image. - **kwargs: dict[str, Any] - Additional parameters including `output_region`, `number_of_masks`, - and `merge_method`. - - Returns - ------- - Image or np.ndarray - The final mask image. - - """ - - # Handle list of images. - if isinstance(images, list) and len(images) != 1: - list_of_labels = super()._process_and_get(images, **kwargs) - if not self._wrap_array_with_image: - for idx, (label, image) in enumerate(zip(list_of_labels, - images)): - list_of_labels[idx] = \ - Image(label, copy=False).merge_properties_from(image) - else: - if isinstance(images, list): - images = images[0] - list_of_labels = [] - for prop in images.properties: - - if "position" in prop: - - inp = Image(np.array(images)) - inp.append(prop) - out = Image(self.get(inp, **kwargs)) - out.merge_properties_from(inp) - list_of_labels.append(out) - - # Create an empty output image. - output_region = kwargs["output_region"] - output = np.zeros( - ( - output_region[2] - output_region[0], - output_region[3] - output_region[1], - kwargs["number_of_masks"], - ) - ) - - from deeptrack.optics import _get_position - - # Merge masks into the output. 
- for label in list_of_labels: - position = _get_position(label) - p0 = np.round(position - output_region[0:2]) - - if np.any(p0 > output.shape[0:2]) or \ - np.any(p0 + label.shape[0:2] < 0): - continue - - crop_x = int(-np.min([p0[0], 0])) - crop_y = int(-np.min([p0[1], 0])) - crop_x_end = int( - label.shape[0] - - np.max([p0[0] + label.shape[0] - output.shape[0], 0]) - ) - crop_y_end = int( - label.shape[1] - - np.max([p0[1] + label.shape[1] - output.shape[1], 0]) - ) - - labelarg = label[crop_x:crop_x_end, crop_y:crop_y_end, :] - - p0[0] = np.max([p0[0], 0]) - p0[1] = np.max([p0[1], 0]) - - p0 = p0.astype(int) - - output_slice = output[ - p0[0] : p0[0] + labelarg.shape[0], - p0[1] : p0[1] + labelarg.shape[1], - ] - - for label_index in range(kwargs["number_of_masks"]): - - if isinstance(kwargs["merge_method"], list): - merge = kwargs["merge_method"][label_index] - else: - merge = kwargs["merge_method"] - - if merge == "add": - output[ - p0[0] : p0[0] + labelarg.shape[0], - p0[1] : p0[1] + labelarg.shape[1], - label_index, - ] += labelarg[..., label_index] - - elif merge == "overwrite": - output_slice[ - labelarg[..., label_index] != 0, label_index - ] = labelarg[labelarg[..., label_index] != 0, \ - label_index] - output[ - p0[0] : p0[0] + labelarg.shape[0], - p0[1] : p0[1] + labelarg.shape[1], - label_index, - ] = output_slice[..., label_index] - - elif merge == "or": - output[ - p0[0] : p0[0] + labelarg.shape[0], - p0[1] : p0[1] + labelarg.shape[1], - label_index, - ] = (output_slice[..., label_index] != 0) | ( - labelarg[..., label_index] != 0 - ) - - elif merge == "mul": - output[ - p0[0] : p0[0] + labelarg.shape[0], - p0[1] : p0[1] + labelarg.shape[1], - label_index, - ] *= labelarg[..., label_index] - - else: - # No match, assume function - output[ - p0[0] : p0[0] + labelarg.shape[0], - p0[1] : p0[1] + labelarg.shape[1], - label_index, - ] = merge( - output_slice[..., label_index], - labelarg[..., label_index], - ) - - if not self._wrap_array_with_image: - return output - output = Image(output) - for label in list_of_labels: - output.merge_properties_from(label) - return output - - class AsType(Feature): """Convert the data type of arrays. @@ -7930,855 +7618,181 @@ def get( return array -class Upscale(Feature): - """Simulate a pipeline at a higher resolution. +# class Upscale(Feature): +# """Simulate a pipeline at a higher resolution. - This feature scales up the resolution of the input pipeline by a specified - factor, performs computations at the higher resolution, and then - downsamples the result back to the original size. This is useful for - simulating effects at a finer resolution while preserving compatibility - with lower-resolution pipelines. +# This feature scales up the resolution of the input pipeline by a specified +# factor, performs computations at the higher resolution, and then +# downsamples the result back to the original size. This is useful for +# simulating effects at a finer resolution while preserving compatibility +# with lower-resolution pipelines. - Internally, this feature redefines the scale of physical units (e.g., - `units.pixel`) to achieve the effect of upscaling. Therefore, it does not - resize the input image itself but affects only features that rely on - physical units. - - Parameters - ---------- - feature: Feature - The pipeline or feature to resolve at a higher resolution. - factor: int or tuple[int, int, int], optional - The factor by which to upscale the simulation. 
If a single integer is - provided, it is applied uniformly across all axes. If a tuple of three - integers is provided, each axis is scaled individually. Defaults to 1. - **kwargs: Any - Additional keyword arguments passed to the parent `Feature` class. - - Attributes - ---------- - __distributed__: bool - Always `False` for `Upscale`, indicating that this feature’s `.get()` - method processes the entire input at once even if it is a list, rather - than distributing calls for each item of the list. - - Methods - ------- - `get(image, factor, **kwargs) -> np.ndarray | torch.tensor` - Simulates the pipeline at a higher resolution and returns the result at - the original resolution. - - Notes - ----- - - This feature does not directly resize the image. Instead, it modifies the - unit conversions within the pipeline, making physical units smaller, - which results in more detail being simulated. - - The final output is downscaled back to the original resolution using - `block_reduce` from `skimage.measure`. - - The effect is only noticeable if features use physical units (e.g., - `units.pixel`, `units.meter`). Otherwise, the result will be identical. - - Examples - -------- - >>> import deeptrack as dt - - Define an optical pipeline and a spherical particle: - - >>> optics = dt.Fluorescence() - >>> particle = dt.Sphere() - >>> simple_pipeline = optics(particle) - - Create an upscaled pipeline with a factor of 4: - - >>> upscaled_pipeline = dt.Upscale(optics(particle), factor=4) - - Resolve the pipelines: - - >>> image = simple_pipeline() - >>> upscaled_image = upscaled_pipeline() - - Visualize the images: - - >>> import matplotlib.pyplot as plt - >>> - >>> plt.subplot(1, 2, 1) - >>> plt.imshow(image, cmap="gray") - >>> plt.title("Original Image") - >>> - >>> plt.subplot(1, 2, 2) - >>> plt.imshow(upscaled_image, cmap="gray") - >>> plt.title("Simulated at Higher Resolution") - >>> - >>> plt.show() +# Internally, this feature redefines the scale of physical units (e.g., +# `units.pixel`) to achieve the effect of upscaling. Therefore, it does not +# resize the input image itself but affects only features that rely on +# physical units. + +# Parameters +# ---------- +# feature: Feature +# The pipeline or feature to resolve at a higher resolution. +# factor: int or tuple[int, int, int], optional +# The factor by which to upscale the simulation. If a single integer is +# provided, it is applied uniformly across all axes. If a tuple of three +# integers is provided, each axis is scaled individually. Defaults to 1. +# **kwargs: Any +# Additional keyword arguments passed to the parent `Feature` class. + +# Attributes +# ---------- +# __distributed__: bool +# Always `False` for `Upscale`, indicating that this feature’s `.get()` +# method processes the entire input at once even if it is a list, rather +# than distributing calls for each item of the list. + +# Methods +# ------- +# `get(image, factor, **kwargs) -> np.ndarray | torch.tensor` +# Simulates the pipeline at a higher resolution and returns the result at +# the original resolution. + +# Notes +# ----- +# - This feature does not directly resize the image. Instead, it modifies the +# unit conversions within the pipeline, making physical units smaller, +# which results in more detail being simulated. +# - The final output is downscaled back to the original resolution using +# `block_reduce` from `skimage.measure`. +# - The effect is only noticeable if features use physical units (e.g., +# `units.pixel`, `units.meter`). 
Otherwise, the result will be identical. + +# Examples +# -------- +# >>> import deeptrack as dt + +# Define an optical pipeline and a spherical particle: + +# >>> optics = dt.Fluorescence() +# >>> particle = dt.Sphere() +# >>> simple_pipeline = optics(particle) + +# Create an upscaled pipeline with a factor of 4: + +# >>> upscaled_pipeline = dt.Upscale(optics(particle), factor=4) - Compare the shapes (both are the same due to downscaling): - - >>> print(image.shape) - (128, 128, 1) - >>> print(upscaled_image.shape) - (128, 128, 1) +# Resolve the pipelines: + +# >>> image = simple_pipeline() +# >>> upscaled_image = upscaled_pipeline() + +# Visualize the images: + +# >>> import matplotlib.pyplot as plt +# >>> +# >>> plt.subplot(1, 2, 1) +# >>> plt.imshow(image, cmap="gray") +# >>> plt.title("Original Image") +# >>> +# >>> plt.subplot(1, 2, 2) +# >>> plt.imshow(upscaled_image, cmap="gray") +# >>> plt.title("Simulated at Higher Resolution") +# >>> +# >>> plt.show() - """ - - __distributed__: bool = False - - feature: Feature - - def __init__( - self: Feature, - feature: Feature, - factor: int | tuple[int, int, int] = 1, - **kwargs: Any, - ) -> None: - """Initialize the Upscale feature. - - Parameters - ---------- - feature: Feature - The pipeline or feature to resolve at a higher resolution. - factor: int or tuple[int, int, int], optional - The factor by which to upscale the simulation. If a single integer - is provided, it is applied uniformly across all axes. If a tuple of - three integers is provided, each axis is scaled individually. - Defaults to 1. - **kwargs: Any - Additional keyword arguments passed to the parent `Feature` class. - - """ - - super().__init__(factor=factor, **kwargs) - self.feature = self.add_feature(feature) - - def get( - self: Feature, - image: np.ndarray | torch.Tensor, - factor: int | tuple[int, int, int], - **kwargs: Any, - ) -> np.ndarray | torch.Tensor: - """Simulate the pipeline at a higher resolution and return result. - - Parameters - ---------- - image: np.ndarray or torch.Tensor - The input image to process. - factor: int or tuple[int, int, int] - The factor by which to upscale the simulation. If a single integer - is provided, it is applied uniformly across all axes. If a tuple of - three integers is provided, each axis is scaled individually. - **kwargs: Any - Additional keyword arguments passed to the feature. - - Returns - ------- - np.ndarray or torch.Tensor - The processed image at the original resolution. +# Compare the shapes (both are the same due to downscaling): - Raises - ------ - ValueError - If the input `factor` is not a valid integer or tuple of integers. - - """ - - # Ensure factor is a tuple of three integers. - if np.size(factor) == 1: - factor = (factor,) * 3 - elif len(factor) != 3: - raise ValueError( - "Factor must be an integer or a tuple of three integers." - ) - - # Create a context for upscaling and perform computation. - ctx = create_context(None, None, None, *factor) - with units.context(ctx): - image = self.feature(image) - - # Downscale the result to the original resolution. - import skimage.measure - - image = skimage.measure.block_reduce( - image, (factor[0], factor[1]) + (1,) * (image.ndim - 2), np.mean - ) - - return image - - -class NonOverlapping(Feature): - """Ensure volumes are placed non-overlapping in a 3D space. - - This feature ensures that a list of 3D volumes are positioned such that - their non-zero voxels do not overlap. If volumes overlap, their positions - are resampled until they are non-overlapping. 
If the maximum number of - attempts is exceeded, the feature regenerates the list of volumes and - raises a warning if non-overlapping placement cannot be achieved. - - Note: `min_distance` refers to the distance between the edges of volumes, - not their centers. Due to the way volumes are calculated, slight rounding - errors may affect the final distance. +# >>> print(image.shape) +# (128, 128, 1) +# >>> print(upscaled_image.shape) +# (128, 128, 1) - This feature is incompatible with non-volumetric scatterers such as - `MieScatterers`. - - Parameters - ---------- - feature: Feature - The feature that generates the list of volumes to place - non-overlapping. - min_distance: float, optional - The minimum distance between volumes in pixels. It can be negative to - allow for partial overlap. Defaults to 1. - max_attempts: int, optional - The maximum number of attempts to place volumes without overlap. - Defaults to 5. - max_iters: int, optional - The maximum number of resamplings. If this number is exceeded, a new - list of volumes is generated. Defaults to 100. - - Attributes - ---------- - __distributed__: bool - Always `False` for `NonOverlapping`, indicating that this feature’s - `.get()` method processes the entire input at once even if it is a - list, rather than distributing calls for each item of the list.N - - Methods - ------- - `get(*_, min_distance, max_attempts, **kwargs) -> array` - Generate a list of non-overlapping 3D volumes. - `_check_non_overlapping(list_of_volumes) -> bool` - Check if all volumes in the list are non-overlapping. - `_check_bounding_cubes_non_overlapping(...) -> bool` - Check if two bounding cubes are non-overlapping. - `_get_overlapping_cube(...) -> list[int]` - Get the overlapping cube between two bounding cubes. - `_get_overlapping_volume(...) -> array` - Get the overlapping volume between a volume and a bounding cube. - `_check_volumes_non_overlapping(...) -> bool` - Check if two volumes are non-overlapping. - `_resample_volume_position(volume) -> Image` - Resample the position of a volume to avoid overlap. - - Notes - ----- - - This feature performs bounding cube checks first to quickly reject - obvious overlaps before voxel-level checks. - - If the bounding cubes overlap, precise voxel-based checks are performed. - - Examples - --------- - >>> import deeptrack as dt - - Define an ellipse scatterer with randomly positioned objects: - - >>> import numpy as np - >>> - >>> scatterer = dt.Ellipse( - >>> radius= 13 * dt.units.pixels, - >>> position=lambda: np.random.uniform(5, 115, size=2)* dt.units.pixels, - >>> ) - - Create multiple scatterers: - - >>> scatterers = (scatterer ^ 8) - - Define the optics and create the image with possible overlap: - - >>> optics = dt.Fluorescence() - >>> im_with_overlap = optics(scatterers) - >>> im_with_overlap.store_properties() - >>> im_with_overlap_resolved = image_with_overlap() - - Gather position from image: - - >>> pos_with_overlap = np.array( - >>> im_with_overlap_resolved.get_property( - >>> "position", - >>> get_one=False - >>> ) - >>> ) - - Enforce non-overlapping and create the image without overlap: - - >>> non_overlapping_scatterers = dt.NonOverlapping( - ... scatterers, - ... min_distance=4, - ... 
) - >>> im_without_overlap = optics(non_overlapping_scatterers) - >>> im_without_overlap.store_properties() - >>> im_without_overlap_resolved = im_without_overlap() - - Gather position from image: - - >>> pos_without_overlap = np.array( - >>> im_without_overlap_resolved.get_property( - >>> "position", - >>> get_one=False - >>> ) - >>> ) - - Create a figure with two subplots to visualize the difference: - - >>> import matplotlib.pyplot as plt - >>> - >>> fig, axes = plt.subplots(1, 2, figsize=(10, 5)) - >>> - >>> axes[0].imshow(im_with_overlap_resolved, cmap="gray") - >>> axes[0].scatter(pos_with_overlap[:,1],pos_with_overlap[:,0]) - >>> axes[0].set_title("Overlapping Objects") - >>> axes[0].axis("off") - >>> - >>> axes[1].imshow(im_without_overlap_resolved, cmap="gray") - >>> axes[1].scatter(pos_without_overlap[:,1],pos_without_overlap[:,0]) - >>> axes[1].set_title("Non-Overlapping Objects") - >>> axes[1].axis("off") - >>> plt.tight_layout() - >>> - >>> plt.show() - - Define function to calculate minimum distance: - - >>> def calculate_min_distance(positions): - >>> distances = [ - >>> np.linalg.norm(positions[i] - positions[j]) - >>> for i in range(len(positions)) - >>> for j in range(i + 1, len(positions)) - >>> ] - >>> return min(distances) - - Print minimum distances with and without overlap: - - >>> print(calculate_min_distance(pos_with_overlap)) - 10.768742383382174 - - >>> print(calculate_min_distance(pos_without_overlap)) - 30.82531120942446 - - """ - - __distributed__: bool = False - - def __init__( - self: NonOverlapping, - feature: Feature, - min_distance: float = 1, - max_attempts: int = 5, - max_iters: int = 100, - **kwargs: Any, - ): - """Initializes the NonOverlapping feature. - - Ensures that volumes are placed **non-overlapping** by iteratively - resampling their positions. If the maximum number of attempts is - exceeded, the feature regenerates the list of volumes. - - Parameters - ---------- - feature: Feature - The feature that generates the list of volumes. - min_distance: float, optional - The minimum separation distance **between volume edges**, in - pixels. It defaults to `1`. Negative values allow for partial - overlap. - max_attempts: int, optional - The maximum number of attempts to place the volumes without - overlap. It defaults to `5`. - max_iters: int, optional - The maximum number of resampling iterations per attempt. If - exceeded, a new list of volumes is generated. It defaults to `100`. - - """ - - super().__init__( - min_distance=min_distance, - max_attempts=max_attempts, - max_iters=max_iters, - **kwargs, - ) - self.feature = self.add_feature(feature, **kwargs) - - def get( - self: NonOverlapping, - *_: Any, - min_distance: float, - max_attempts: int, - max_iters: int, - **kwargs: Any, - ) -> list[np.ndarray]: - """Generates a list of non-overlapping 3D volumes within a defined - field of view (FOV). - - This method **iteratively** attempts to place volumes while ensuring - they maintain at least `min_distance` separation. If non-overlapping - placement is not achieved within `max_attempts`, a warning is issued, - and the best available configuration is returned. - - Parameters - ---------- - _: Any - Placeholder parameter, typically for an input image. - min_distance: float - The minimum required separation distance between volumes, in - pixels. - max_attempts: int - The maximum number of attempts to generate a valid non-overlapping - configuration. - max_iters: int - The maximum number of resampling iterations per attempt. 
- **kwargs: Any - Additional parameters that may be used by subclasses. - - Returns - ------- - list[np.ndarray] - A list of 3D volumes represented as NumPy arrays. If - non-overlapping placement is unsuccessful, the best available - configuration is returned. - - Warns - ----- - UserWarning - If non-overlapping placement is **not** achieved within - `max_attempts`, suggesting parameter adjustments such as increasing - the FOV or reducing `min_distance`. - - Notes - ----- - - The placement process prioritizes bounding cube checks for - efficiency. - - If bounding cubes overlap, voxel-based overlap checks are performed. +# """ + +# __distributed__: bool = False + +# feature: Feature + +# def __init__( +# self: Feature, +# feature: Feature, +# factor: int | tuple[int, int, int] = 1, +# **kwargs: Any, +# ) -> None: +# """Initialize the Upscale feature. + +# Parameters +# ---------- +# feature: Feature +# The pipeline or feature to resolve at a higher resolution. +# factor: int or tuple[int, int, int], optional +# The factor by which to upscale the simulation. If a single integer +# is provided, it is applied uniformly across all axes. If a tuple of +# three integers is provided, each axis is scaled individually. +# Defaults to 1. +# **kwargs: Any +# Additional keyword arguments passed to the parent `Feature` class. + +# """ + +# super().__init__(factor=factor, **kwargs) +# self.feature = self.add_feature(feature) + +# def get( +# self: Feature, +# image: np.ndarray | torch.Tensor, +# factor: int | tuple[int, int, int], +# **kwargs: Any, +# ) -> np.ndarray | torch.Tensor: +# """Simulate the pipeline at a higher resolution and return result. + +# Parameters +# ---------- +# image: np.ndarray or torch.Tensor +# The input image to process. +# factor: int or tuple[int, int, int] +# The factor by which to upscale the simulation. If a single integer +# is provided, it is applied uniformly across all axes. If a tuple of +# three integers is provided, each axis is scaled individually. +# **kwargs: Any +# Additional keyword arguments passed to the feature. + +# Returns +# ------- +# np.ndarray or torch.Tensor +# The processed image at the original resolution. + +# Raises +# ------ +# ValueError +# If the input `factor` is not a valid integer or tuple of integers. + +# """ + +# # Ensure factor is a tuple of three integers. +# if np.size(factor) == 1: +# factor = (factor, factor, 1) +# elif len(factor) != 3: +# raise ValueError( +# "Factor must be an integer or a tuple of three integers." +# ) - """ - - for _ in range(max_attempts): - list_of_volumes = self.feature() - - if not isinstance(list_of_volumes, list): - list_of_volumes = [list_of_volumes] +# # Create a context for upscaling and perform computation. +# ctx = create_context(None, None, None, *factor) - for _ in range(max_iters): - - list_of_volumes = [ - self._resample_volume_position(volume) - for volume in list_of_volumes - ] - - if self._check_non_overlapping(list_of_volumes): - return list_of_volumes - - # Generate a new list of volumes if max_attempts is exceeded. - self.feature.update() - - warnings.warn( - "Non-overlapping placement could not be achieved. 
Consider " - "adjusting parameters: reduce object radius, increase FOV, " - "or decrease min_distance.", - UserWarning, - ) - return list_of_volumes +# print('before:', image) +# with units.context(ctx): +# image = self.feature(image) - def _check_non_overlapping( - self: NonOverlapping, - list_of_volumes: list[np.ndarray], - ) -> bool: - """Determines whether all volumes in the provided list are - non-overlapping. +# print('after:', image) +# # Downscale the result to the original resolution. +# import skimage.measure - This method verifies that the non-zero voxels of each 3D volume in - `list_of_volumes` are at least `min_distance` apart. It first checks - bounding boxes for early rejection and then examines actual voxel - overlap when necessary. Volumes are assumed to have a `position` - attribute indicating their placement in 3D space. - - Parameters - ---------- - list_of_volumes: list[np.ndarray] - A list of 3D arrays representing the volumes to be checked for - overlap. Each volume is expected to have a position attribute. - - Returns - ------- - bool - `True` if all volumes are non-overlapping, otherwise `False`. - - Notes - ----- - - If `min_distance` is negative, volumes are shrunk using isotropic - erosion before checking overlap. - - If `min_distance` is positive, volumes are padded and expanded using - isotropic dilation. - - Overlapping checks are first performed on bounding cubes for - efficiency. - - If bounding cubes overlap, voxel-level checks are performed. - - """ - - from skimage.morphology import isotropic_erosion, isotropic_dilation - - from deeptrack.augmentations import CropTight, Pad - from deeptrack.optics import _get_position - - min_distance = self.min_distance() - crop = CropTight() - - if min_distance < 0: - list_of_volumes = [ - Image( - crop(isotropic_erosion(volume != 0, -min_distance/2)), - copy=False, - ).merge_properties_from(volume) - for volume in list_of_volumes - ] - else: - pad = Pad(px = [int(np.ceil(min_distance/2))]*6, keep_size=True) - list_of_volumes = [ - Image( - crop(isotropic_dilation(pad(volume) != 0, min_distance/2)), - copy=False, - ).merge_properties_from(volume) - for volume in list_of_volumes - ] - min_distance = 1 - - # The position of the top left corner of each volume (index (0, 0, 0)). - volume_positions_1 = [ - _get_position(volume, mode="corner", return_z=True).astype(int) - for volume in list_of_volumes - ] - - # The position of the bottom right corner of each volume - # (index (-1, -1, -1)). - volume_positions_2 = [ - p0 + np.array(v.shape) - for v, p0 in zip(list_of_volumes, volume_positions_1) - ] - - # (x1, y1, z1, x2, y2, z2) for each volume. - volume_bounding_cube = [ - [*p0, *p1] - for p0, p1 in zip(volume_positions_1, volume_positions_2) - ] - - for i, j in itertools.combinations(range(len(list_of_volumes)), 2): - - # If the bounding cubes do not overlap, the volumes do not overlap. - if self._check_bounding_cubes_non_overlapping( - volume_bounding_cube[i], volume_bounding_cube[j], min_distance - ): - continue - - # If the bounding cubes overlap, get the overlapping region of each - # volume. 
- overlapping_cube = self._get_overlapping_cube( - volume_bounding_cube[i], volume_bounding_cube[j] - ) - overlapping_volume_1 = self._get_overlapping_volume( - list_of_volumes[i], volume_bounding_cube[i], overlapping_cube - ) - overlapping_volume_2 = self._get_overlapping_volume( - list_of_volumes[j], volume_bounding_cube[j], overlapping_cube - ) - - # If either the overlapping regions are empty, the volumes do not - # overlap (done for speed). - if (np.all(overlapping_volume_1 == 0) - or np.all(overlapping_volume_2 == 0)): - continue - - # If products of overlapping regions are non-zero, return False. - # if np.any(overlapping_volume_1 * overlapping_volume_2): - # return False - - # Finally, check that the non-zero voxels of the volumes are at - # least min_distance apart. - if not self._check_volumes_non_overlapping( - overlapping_volume_1, overlapping_volume_2, min_distance - ): - return False - - return True - - def _check_bounding_cubes_non_overlapping( - self: NonOverlapping, - bounding_cube_1: list[int], - bounding_cube_2: list[int], - min_distance: float, - ) -> bool: - """Determines whether two 3D bounding cubes are non-overlapping. - - This method checks whether the bounding cubes of two volumes are - **separated by at least** `min_distance` along **any** spatial axis. - - Parameters - ---------- - bounding_cube_1: list[int] - A list of six integers `[x1, y1, z1, x2, y2, z2]` representing - the first bounding cube. - bounding_cube_2: list[int] - A list of six integers `[x1, y1, z1, x2, y2, z2]` representing - the second bounding cube. - min_distance: float - The required **minimum separation distance** between the two - bounding cubes. - - Returns - ------- - bool - `True` if the bounding cubes are non-overlapping (separated by at - least `min_distance` along **at least one axis**), otherwise - `False`. - - Notes - ----- - - This function **only checks bounding cubes**, **not actual voxel - data**. - - If the bounding cubes are non-overlapping, the corresponding - **volumes are also non-overlapping**. - - This check is much **faster** than full voxel-based comparisons. - - """ - - # bounding_cube_1 and bounding_cube_2 are (x1, y1, z1, x2, y2, z2). - # Check that the bounding cubes are non-overlapping. - return ( - (bounding_cube_1[0] >= bounding_cube_2[3] + min_distance) or - (bounding_cube_2[0] >= bounding_cube_1[3] + min_distance) or - (bounding_cube_1[1] >= bounding_cube_2[4] + min_distance) or - (bounding_cube_2[1] >= bounding_cube_1[4] + min_distance) or - (bounding_cube_1[2] >= bounding_cube_2[5] + min_distance) or - (bounding_cube_2[2] >= bounding_cube_1[5] + min_distance) - ) - - def _get_overlapping_cube( - self: NonOverlapping, - bounding_cube_1: list[int], - bounding_cube_2: list[int], - ) -> list[int]: - """Computes the overlapping region between two 3D bounding cubes. - - This method calculates the coordinates of the intersection of two - axis-aligned bounding cubes, each represented as a list of six - integers: - - - `[x1, y1, z1]`: Coordinates of the **top-left-front** corner. - - `[x2, y2, z2]`: Coordinates of the **bottom-right-back** corner. - - The resulting overlapping region is determined by: - - Taking the **maximum** of the starting coordinates (`x1, y1, z1`). - - Taking the **minimum** of the ending coordinates (`x2, y2, z2`). - - If the cubes **do not** overlap, the resulting coordinates will not - form a valid cube (i.e., `x1 > x2`, `y1 > y2`, or `z1 > z2`). 
- - Parameters - ---------- - bounding_cube_1: list[int] - The first bounding cube, formatted as `[x1, y1, z1, x2, y2, z2]`. - bounding_cube_2: list[int] - The second bounding cube, formatted as `[x1, y1, z1, x2, y2, z2]`. - - Returns - ------- - list[int] - A list of six integers `[x1, y1, z1, x2, y2, z2]` representing the - overlapping bounding cube. If no overlap exists, the coordinates - will **not** define a valid cube. - - Notes - ----- - - This function does **not** check for valid input or ensure the - resulting cube is well-formed. - - If no overlap exists, downstream functions must handle the invalid - result. - - """ - - return [ - max(bounding_cube_1[0], bounding_cube_2[0]), - max(bounding_cube_1[1], bounding_cube_2[1]), - max(bounding_cube_1[2], bounding_cube_2[2]), - min(bounding_cube_1[3], bounding_cube_2[3]), - min(bounding_cube_1[4], bounding_cube_2[4]), - min(bounding_cube_1[5], bounding_cube_2[5]), - ] - - def _get_overlapping_volume( - self: NonOverlapping, - volume: np.ndarray, # 3D array. - bounding_cube: tuple[float, float, float, float, float, float], - overlapping_cube: tuple[float, float, float, float, float, float], - ) -> np.ndarray: - """Extracts the overlapping region of a 3D volume within the specified - overlapping cube. - - This method identifies and returns the subregion of `volume` that - lies within the `overlapping_cube`. The bounding information of the - volume is provided via `bounding_cube`. - - Parameters - ---------- - volume: np.ndarray - A 3D NumPy array representing the volume from which the - overlapping region is extracted. - bounding_cube: tuple[float, float, float, float, float, float] - The bounding cube of the volume, given as a tuple of six floats: - `(x1, y1, z1, x2, y2, z2)`. The first three values define the - **top-left-front** corner, while the last three values define the - **bottom-right-back** corner. - overlapping_cube: tuple[float, float, float, float, float, float] - The overlapping region between the volume and another volume, - represented in the same format as `bounding_cube`. - - Returns - ------- - np.ndarray - A 3D NumPy array representing the portion of `volume` that - lies within `overlapping_cube`. If the overlap does not exist, - an empty array may be returned. - - Notes - ----- - - The method computes the relative indices of `overlapping_cube` - within `volume` by subtracting the bounding cube's starting - position. - - The extracted region is determined by integer indices, meaning - coordinates are implicitly **floored to integers**. - - If `overlapping_cube` extends beyond `volume` boundaries, the - returned subregion is **cropped** to fit within `volume`. 
- - """ - - # The position of the top left corner of the overlapping cube in the volume - overlapping_cube_position = np.array(overlapping_cube[:3]) - np.array( - bounding_cube[:3] - ) - - # The position of the bottom right corner of the overlapping cube in the volume - overlapping_cube_end_position = np.array( - overlapping_cube[3:] - ) - np.array(bounding_cube[:3]) - - # cast to int - overlapping_cube_position = overlapping_cube_position.astype(int) - overlapping_cube_end_position = overlapping_cube_end_position.astype(int) - - return volume[ - overlapping_cube_position[0] : overlapping_cube_end_position[0], - overlapping_cube_position[1] : overlapping_cube_end_position[1], - overlapping_cube_position[2] : overlapping_cube_end_position[2], - ] - - def _check_volumes_non_overlapping( - self: NonOverlapping, - volume_1: np.ndarray, - volume_2: np.ndarray, - min_distance: float, - ) -> bool: - """Determines whether the non-zero voxels in two 3D volumes are at - least `min_distance` apart. - - This method checks whether the active regions (non-zero voxels) in - `volume_1` and `volume_2` maintain a minimum separation of - `min_distance`. If the volumes differ in size, the positions of their - non-zero voxels are adjusted accordingly to ensure a fair comparison. - - Parameters - ---------- - volume_1: np.ndarray - A 3D NumPy array representing the first volume. - volume_2: np.ndarray - A 3D NumPy array representing the second volume. - min_distance: float - The minimum Euclidean distance required between any two non-zero - voxels in the two volumes. - - Returns - ------- - bool - `True` if all non-zero voxels in `volume_1` and `volume_2` are at - least `min_distance` apart, otherwise `False`. - - Notes - ----- - - This function assumes both volumes are correctly aligned within a - shared coordinate space. - - If the volumes are of different sizes, voxel positions are scaled - or adjusted for accurate distance measurement. - - Uses **Euclidean distance** for separation checking. - - If either volume is empty (i.e., no non-zero voxels), they are - considered non-overlapping. - - """ - - # Get the positions of the non-zero voxels of each volume. - positions_1 = np.argwhere(volume_1) - positions_2 = np.argwhere(volume_2) - - # if positions_1.size == 0 or positions_2.size == 0: - # return True # If either volume is empty, they are "non-overlapping" - - # # If the volumes are not the same size, the positions of the non-zero - # # voxels of each volume need to be scaled. - # if positions_1.size == 0 or positions_2.size == 0: - # return True # If either volume is empty, they are "non-overlapping" - - # If the volumes are not the same size, the positions of the non-zero - # voxels of each volume need to be scaled. - if volume_1.shape != volume_2.shape: - positions_1 = ( - positions_1 * np.array(volume_2.shape) - / np.array(volume_1.shape) - ) - positions_1 = positions_1.astype(int) - - # Check that the non-zero voxels of the volumes are at least - # min_distance apart. - return np.all( - cdist(positions_1, positions_2) > min_distance - ) - - def _resample_volume_position( - self: NonOverlapping, - volume: np.ndarray | Image, - ) -> Image: - """Resamples the position of a 3D volume using its internal position - sampler. - - This method updates the `position` property of the given `volume` by - drawing a new position from the `_position_sampler` stored in the - volume's `properties`. If the sampled position is a `Quantity`, it is - converted to pixel units. 
-
-        Parameters
-        ----------
-        volume: np.ndarray or Image
-            The 3D volume whose position is to be resampled. The volume must
-            have a `properties` attribute containing dictionaries with
-            `position` and `_position_sampler` keys.
-
-        Returns
-        -------
-        Image
-            The same input volume with its `position` property updated to the
-            newly sampled value.
-
-        Notes
-        -----
-        - The `_position_sampler` function is expected to return a **tuple of
-          three floats** (e.g., `(x, y, z)`).
-        - If the sampled position is a `Quantity`, it is converted to pixels.
-        - **Only** dictionaries in `volume.properties` that contain both
-          `position` and `_position_sampler` keys are modified.
-
-        """
-
-        for pdict in volume.properties:
-            if "position" in pdict and "_position_sampler" in pdict:
-                new_position = pdict["_position_sampler"]()
-                if isinstance(new_position, Quantity):
-                    new_position = new_position.to("pixel").magnitude
-                pdict["position"] = new_position
-
-        return volume
-
 
 class Store(Feature):
diff --git a/deeptrack/holography.py b/deeptrack/holography.py
index 380969cf..141cc540 100644
--- a/deeptrack/holography.py
+++ b/deeptrack/holography.py
@@ -101,7 +101,7 @@ def get_propagation_matrix(
 def get_propagation_matrix(
     shape: tuple[int, int],
     to_z: float,
-    pixel_size: float,
+    pixel_size: float | tuple[float, float] | None,
     wavelength: float,
     dx: float = 0,
     dy: float = 0
@@ -118,8 +118,8 @@ def get_propagation_matrix(
         The dimensions of the optical field (height, width).
     to_z: float
         Propagation distance along the z-axis.
-    pixel_size: float
-        The physical size of each pixel in the optical field.
+    pixel_size: float | tuple[float, float] | None
+        Physical pixel size. If scalar, isotropic pixels are assumed. If
+        `None`, the active voxel size is used.
     wavelength: float
        The wavelength of the optical field. 
    dx: float, optional
@@ -140,14 +148,22 @@
 
     """
 
+    if pixel_size is None:
+        pixel_size = get_active_voxel_size()
+
+    if np.isscalar(pixel_size):
+        pixel_size = (pixel_size, pixel_size)
+
+    px, py = pixel_size
+
     k = 2 * np.pi / wavelength
     yr, xr, *_ = shape
 
     x = np.arange(0, xr, 1) - xr / 2 + (xr % 2) / 2
     y = np.arange(0, yr, 1) - yr / 2 + (yr % 2) / 2
 
-    x = 2 * np.pi / pixel_size * x / xr
-    y = 2 * np.pi / pixel_size * y / yr
+    x = 2 * np.pi / px * x / xr
+    y = 2 * np.pi / py * y / yr
 
     KXk, KYk = np.meshgrid(x, y)
     KXk = KXk.astype(complex)
diff --git a/deeptrack/math.py b/deeptrack/math.py
index 05cbf311..2f8b7339 100644
--- a/deeptrack/math.py
+++ b/deeptrack/math.py
@@ -93,23 +93,24 @@
 
 from __future__ import annotations
 
-from typing import Any, Callable, TYPE_CHECKING
+from typing import Any, Callable, Dict, Tuple, TYPE_CHECKING
 
 import array_api_compat as apc
 import numpy as np
-from numpy.typing import NDArray
+from numpy.typing import NDArray #TODO TBE
 from scipy import ndimage
 import skimage
 import skimage.measure
 
 from deeptrack import utils, OPENCV_AVAILABLE, TORCH_AVAILABLE
 from deeptrack.features import Feature
-from deeptrack.image import Image, strip
-from deeptrack.types import ArrayLike, PropertyLike
+from deeptrack.image import Image, strip #TODO TBE
+from deeptrack.types import PropertyLike
 from deeptrack.backend import xp
 
 if TORCH_AVAILABLE:
     import torch
+    import torch.nn.functional as F
 
 if OPENCV_AVAILABLE:
     import cv2
@@ -133,7 +134,6 @@
     "BilateralBlur",
 ]
 
-
 if TYPE_CHECKING:
     import torch
 
@@ -227,10 +227,10 @@ def __init__(
 
     def get(
         self: Average,
-        images: list[NDArray[Any] | torch.Tensor | Image],
+        images: list[np.ndarray | torch.Tensor],
         axis: int | tuple[int],
         **kwargs: Any,
-    ) -> NDArray[Any] | torch.Tensor | Image:
+    ) -> np.ndarray | torch.Tensor:
         """Compute the average of input images along the specified axis(es).
 
         This method computes the average of the input images along the
@@ -318,11 +318,11 @@ def __init__(
 
     def get(
         self: Clip,
-        image: NDArray[Any] | torch.Tensor | Image,
+        image: np.ndarray | torch.Tensor,
         min: float,
         max: float,
         **kwargs: Any,
-    ) -> NDArray[Any] | torch.Tensor | Image:
+    ) -> np.ndarray | torch.Tensor:
         """Clips the input image within the specified values.
 
         This method clips the input image within the specified minimum and
@@ -363,8 +363,7 @@ class NormalizeMinMax(Feature):
     max: float, optional
         Upper bound of the transformation. It defaults to 1.
     featurewise: bool, optional
-        Whether to normalize each feature independently. It default to `True`,
-        which is the only behavior currently implemented.
+        Whether to normalize each feature independently. It defaults to `True`.
 
     Methods
     -------
@@ -390,8 +389,6 @@ class NormalizeMinMax(Feature):
 
     """
 
-    #TODO ___??___ Implement the `featurewise=False` option
-
     def __init__(
         self: NormalizeMinMax,
         min: PropertyLike[float] = 0,
@@ -418,32 +415,47 @@ def __init__(
 
     def get(
         self: NormalizeMinMax,
-        image: ArrayLike,
+        image: np.ndarray | torch.Tensor,
         min: float,
         max: float,
+        featurewise: bool = True,
         **kwargs: Any,
-    ) -> ArrayLike:
+    ) -> np.ndarray | torch.Tensor:
         """Normalize the input to fall between `min` and `max`.
 
         Parameters
         ----------
-        image: array
+        image: np.ndarray or torch.Tensor
             Input image to normalize.
         min: float
             Lower bound of the output range.
         max: float
             Upper bound of the output range.
+        featurewise: bool
+            Whether to normalize each feature (channel) independently.
 
         Returns
         -------
-        array
+        np.ndarray or torch.Tensor
             Min-max normalized image.
""" - ptp = xp.max(image) - xp.min(image) - image = image / ptp * (max - min) - image = image - xp.min(image) + min + if featurewise: + # Normalize per feature (last axis) + axis = tuple(range(image.ndim - 1)) + + img_min = xp.min(image, axis=axis, keepdims=True) + img_max = xp.max(image, axis=axis, keepdims=True) + else: + # Normalize globally + img_min = xp.min(image) + img_max = xp.max(image) + + ptp = img_max - img_min + + # Avoid division by zero + image = (image - img_min) / ptp * (max - min) + min try: image[xp.isnan(image)] = 0 @@ -487,8 +499,6 @@ class NormalizeStandard(Feature): """ - #TODO ___??___ Implement the `featurewise=False` option - def __init__( self: NormalizeStandard, featurewise: PropertyLike[bool] = True, @@ -511,33 +521,52 @@ def __init__( def get( self: NormalizeStandard, - image: NDArray[Any] | torch.Tensor | Image, + image: np.ndarray | torch.Tensor, + featurewise: bool, **kwargs: Any, - ) -> NDArray[Any] | torch.Tensor | Image: + ) -> np.ndarray | torch.Tensor: """Normalizes the input image to have mean 0 and standard deviation 1. - This method normalizes the input image to have mean 0 and standard - deviation 1. - Parameters ---------- - image: array + image: np.ndarray or torch.Tensor The input image to normalize. + featurewise: bool + Whether to normalize each feature (channel) independently. Returns ------- - array - The normalized image. - + np.ndarray or torch.Tensor + The standardized image. """ - if apc.is_torch_array(image): - # By default, torch.std() is unbiased, i.e., divides by N-1 - return ( - (image - torch.mean(image)) / torch.std(image, unbiased=False) - ) + if featurewise: + # Normalize per feature (last axis) + axis = tuple(range(image.ndim - 1)) + + mean = xp.mean(image, axis=axis, keepdims=True) + + if apc.is_torch_array(image): + std = torch.std(image, dim=axis, keepdim=True, unbiased=False) + else: + std = xp.std(image, axis=axis) + else: + # Normalize globally + mean = xp.mean(image) + + if apc.is_torch_array(image): + std = torch.std(image, unbiased=False) + else: + std = xp.std(image) + + image = (image - mean) / std + + try: + image[xp.isnan(image)] = 0 + except TypeError: + pass - return (image - xp.mean(image)) / xp.std(image) + return image class NormalizeQuantile(Feature): @@ -609,158 +638,164 @@ def __init__( def get( self: NormalizeQuantile, - image: NDArray[Any] | torch.Tensor | Image, - quantiles: tuple[float, float] = None, + image: np.ndarray | torch.Tensor, + quantiles: tuple[float, float], + featurewise: bool, **kwargs: Any, - ) -> NDArray[Any] | torch.Tensor | Image: + ) -> np.ndarray | torch.Tensor: """Normalize the input image based on the specified quantiles. - This method normalizes the input image based on the specified - quantiles. - Parameters ---------- - image: array + image: np.ndarray or torch.Tensor The input image to normalize. quantiles: tuple[float, float] Quantile range to calculate scaling factor. + featurewise: bool + Whether to normalize each feature (channel) independently. Returns ------- - array - The normalized image. - + np.ndarray or torch.Tensor + The quantile-normalized image. 
""" - if apc.is_torch_array(image): - q_tensor = torch.tensor( - [*quantiles, 0.5], - device=image.device, - dtype=image.dtype, - ) - q_low, q_high, median = torch.quantile( - image, q_tensor, dim=None, keepdim=False, - ) - else: # NumPy - q_low, q_high, median = xp.quantile(image, (*quantiles, 0.5)) - - return (image - median) / (q_high - q_low) * 2.0 + q_low_val, q_high_val = quantiles + if featurewise: + # Per-feature normalization (last axis) + axis = tuple(range(image.ndim - 1)) -#TODO ***JH*** revise Blur - torch, typing, docstring, unit test -class Blur(Feature): - """Apply a blurring filter to an image. - - This class applies a blurring filter to an image. The filter function - must be a function that takes an input image and returns a blurred - image. - - Parameters - ---------- - filter_function: Callable - The blurring function to apply. This function must accept the input - image as a keyword argument named `input`. If using OpenCV functions - (e.g., `cv2.GaussianBlur`), use `BlurCV2` instead. - mode: str - Border mode for handling boundaries (e.g., 'reflect'). + if apc.is_torch_array(image): + q = torch.tensor( + [q_low_val, q_high_val, 0.5], + device=image.device, + dtype=image.dtype, + ) + q_low, q_high, median = torch.quantile( + image, q, dim=axis, keepdim=True + ) + else: + q_low, q_high, median = xp.quantile( + image, (q_low_val, q_high_val, 0.5), + axis=axis, + keepdims=True, + ) + else: + # Global normalization + if apc.is_torch_array(image): + q = torch.tensor( + [q_low_val, q_high_val, 0.5], + device=image.device, + dtype=image.dtype, + ) + q_low, q_high, median = torch.quantile( + image, q, dim=None, keepdim=False + ) + else: + q_low, q_high, median = xp.quantile( + image, (q_low_val, q_high_val, 0.5) + ) - Methods - ------- - `get(image: np.ndarray | Image, **kwargs: Any) --> np.ndarray` - Applies the blurring filter to the input image. + image = (image - median) / (q_high - q_low) * 2.0 - Examples - -------- - >>> import deeptrack as dt - >>> import numpy as np - >>> from scipy.ndimage import convolve + try: + image[xp.isnan(image)] = 0 + except TypeError: + pass - Create an input image: - >>> input_image = np.random.rand(32, 32) + return image - Define a Gaussian kernel for blurring: - >>> gaussian_kernel = np.array([ - ... [1, 4, 6, 4, 1], - ... [4, 16, 24, 16, 4], - ... [6, 24, 36, 24, 6], - ... [4, 16, 24, 16, 4], - ... [1, 4, 6, 4, 1] - ... ], dtype=float) - >>> gaussian_kernel /= np.sum(gaussian_kernel) - Define a blur function using the Gaussian kernel: - >>> def gaussian_blur(input, **kwargs): - ... return convolve(input, gaussian_kernel, mode='reflect') +#TODO ***CM*** revise typing, docstring, unit test +class Blur(Feature): + """Apply a blurring filter to an image. - Define a blur feature using the Gaussian blur function: - >>> blur = dt.Blur(filter_function=gaussian_blur) - >>> output_image = blur(input_image) - >>> print(output_image.shape) - (32, 32) + This class acts as a backend-dispatching blur operator. Subclasses must + implement backend-specific logic via `_get_numpy` and optionally + `_get_torch`. Notes ----- - Calling this feature returns a `np.ndarray` by default. If - `store_properties` is set to `True`, the returned array will be - automatically wrapped in an `Image` object. This behavior is handled - internally and does not affect the return type of the `get()` method. - The filter_function must accept the input image as a keyword argument named - input. This is required because it is called via utils.safe_call. 
If you - are using functions that do not support input=... (such as OpenCV filters - like cv2.GaussianBlur), consider using BlurCV2 instead. + - NumPy execution is always supported. + - Torch execution is only supported if `_get_torch` is implemented. + - Generic `filter_function`-based blurs are NumPy-only by design. """ def __init__( - self: Blur, - filter_function: Callable, + self, + filter_function: Callable | None = None, mode: PropertyLike[str] = "reflect", **kwargs: Any, ): - """Initialize the parameters for blurring input features. - - This constructor initializes the parameters for blurring input - features. + """Initialize the blur feature. Parameters ---------- - filter_function: Callable - The blurring function to apply. - mode: str - Border mode for handling boundaries (e.g., 'reflect'). - **kwargs: Any - Additional keyword arguments. - + filter_function : Callable or None + NumPy-based blurring function. Must accept the input image as a + keyword argument named `input`. If `None`, the subclass must + implement `_get_numpy`. + mode : str + Border mode for NumPy-based filters. + **kwargs : Any + Additional keyword arguments passed to Feature. """ - self.filter = filter_function - super().__init__(borderType=mode, **kwargs) + self.mode = mode + super().__init__(**kwargs) - def get(self: Blur, image: np.ndarray | Image, **kwargs: Any) -> np.ndarray: - """Applies the blurring filter to the input image. + def __call__( + self, + image: np.ndarray | torch.Tensor, + **kwargs: Any, + ) -> np.ndarray | torch.Tensor: + if isinstance(image, np.ndarray): + return self._get_numpy(image, **kwargs) - This method applies the blurring filter to the input image. + if TORCH_AVAILABLE and isinstance(image, torch.Tensor): + return self._get_torch(image, **kwargs) - Parameters - ---------- - image: np.ndarray - The input image to blur. - **kwargs: dict[str, Any] - Additional keyword arguments. + raise TypeError( + "Blur only supports numpy.ndarray or torch.Tensor inputs." + ) - Returns - ------- - np.ndarray - The blurred image. + def _get_numpy( + self, + image: np.ndarray, + **kwargs: Any, + ) -> np.ndarray: + if self.filter is None: + raise NotImplementedError( + f"{self.__class__.__name__} does not implement a NumPy backend." + ) - """ + # Avoid passing conflicting keywords + kwargs = dict(kwargs) + kwargs.pop("input", None) + + return utils.safe_call( + self.filter, + input=image, + mode=self.mode, + **kwargs, + ) + + def _get_torch( + self, + image: torch.Tensor, + **kwargs: Any, + ) -> torch.Tensor: + raise TypeError( + f"{self.__class__.__name__} does not support torch.Tensor inputs. " + "Use a Torch-enabled blur (e.g. AverageBlur or a V2 blur class)." + ) - kwargs.pop("input", False) - return utils.safe_call(self.filter, input=image, **kwargs) -#TODO ***JH*** revise AverageBlur - torch, typing, docstring, unit test +#TODO ***CM*** revise AverageBlur - torch, typing, docstring, unit test class AverageBlur(Blur): """Blur an image by computing simple means over neighbourhoods. @@ -774,7 +809,7 @@ class AverageBlur(Blur): Methods ------- - `get(image: np.ndarray | Image, ksize: int, **kwargs: Any) --> np.ndarray` + `get(image: np.ndarray | torch.Tensor, ksize: int, **kwargs: Any) --> np.ndarray | torch.Tensor` Applies the average blurring filter to the input image. Examples @@ -791,13 +826,6 @@ class AverageBlur(Blur): >>> print(output_image.shape) (32, 32) - Notes - ----- - Calling this feature returns a `np.ndarray` by default. 
If - `store_properties` is set to `True`, the returned array will be - automatically wrapped in an `Image` object. This behavior is handled - internally and does not affect the return type of the `get()` method. - """ def __init__( @@ -840,12 +868,10 @@ def _get_numpy( def _get_torch( self, input: torch.Tensor, ksize: tuple[int, ...], **kwargs: Any - ) -> np.ndarray: - F = xp.nn.functional + ) -> torch.Tensor: last_dim_is_channel = len(ksize) < input.ndim if last_dim_is_channel: - # permute to first dim input = input.movedim(-1, 0) else: input = input.unsqueeze(0) @@ -853,40 +879,26 @@ def _get_torch( # add batch dimension input = input.unsqueeze(0) - # pad input + # dynamic padding + pad = [] + for k in reversed(ksize): + p = k // 2 + pad.extend([p, p]) + pad = tuple(pad) + input = F.pad( input, - (ksize[0] // 2, ksize[0] // 2, ksize[1] // 2, ksize[1] // 2), + pad, mode=kwargs.get("mode", "reflect"), value=kwargs.get("cval", 0), ) + if input.ndim == 3: - x = F.avg_pool1d( - input, - kernel_size=ksize, - stride=1, - padding=0, - ceil_mode=False, - count_include_pad=False, - ) + x = F.avg_pool1d(input, kernel_size=ksize, stride=1) elif input.ndim == 4: - x = F.avg_pool2d( - input, - kernel_size=ksize, - stride=1, - padding=0, - ceil_mode=False, - count_include_pad=False, - ) + x = F.avg_pool2d(input, kernel_size=ksize, stride=1) elif input.ndim == 5: - x = F.avg_pool3d( - input, - kernel_size=ksize, - stride=1, - padding=0, - ceil_mode=False, - count_include_pad=False, - ) + x = F.avg_pool3d(input, kernel_size=ksize, stride=1) else: raise NotImplementedError( f"Input dimension {input.ndim - 2} not supported for torch backend" @@ -903,10 +915,10 @@ def _get_torch( def get( self: AverageBlur, - input: ArrayLike, + input: np.ndarray | torch.Tensor, ksize: int, **kwargs: Any, - ) -> np.ndarray: + ) -> np.ndarray | torch.Tensor: """Applies the average blurring filter to the input image. This method applies the average blurring filter to the input image. @@ -937,7 +949,7 @@ def get( raise NotImplementedError(f"Backend {self.backend} not supported") -#TODO ***JH*** revise GaussianBlur - torch, typing, docstring, unit test +#TODO ***CM*** revise typing, docstring, unit test class GaussianBlur(Blur): """Applies a Gaussian blur to images using Gaussian kernels. @@ -973,13 +985,6 @@ class GaussianBlur(Blur): >>> plt.imshow(output_image, cmap='gray') >>> plt.show() - Notes - ----- - Calling this feature returns a `np.ndarray` by default. If - `store_properties` is set to `True`, the returned array will be - automatically wrapped in an `Image` object. This behavior is handled - internally and does not affect the return type of the `get()` method. 
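The dynamic padding loop in `AverageBlur._get_torch` above pads each spatial axis symmetrically (pads are given to `F.pad` for the last axis first, hence the reversed iteration) and then applies stride-1 mean pooling over the padded tensor. A standalone sketch of the same scheme, outside the class (a hypothetical helper, not part of this diff; odd kernel extents assumed so the output size matches the input size):

```python
import torch
import torch.nn.functional as F

def box_blur_2d(image: torch.Tensor, ksize: tuple[int, int]) -> torch.Tensor:
    # (H, W) -> (1, 1, H, W): F.pad and avg_pool2d expect batched input.
    x = image.unsqueeze(0).unsqueeze(0)

    # F.pad takes pads for the last dimension first, hence reversed(ksize).
    pad = []
    for k in reversed(ksize):
        pad.extend([k // 2, k // 2])
    x = F.pad(x, tuple(pad), mode="reflect")

    # Stride-1 mean pooling over the padded tensor is exactly a box filter.
    x = F.avg_pool2d(x, kernel_size=ksize, stride=1)
    return x.squeeze(0).squeeze(0)

print(box_blur_2d(torch.rand(32, 32), (3, 5)).shape)  # torch.Size([32, 32])
```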
- """ def __init__(self: GaussianBlur, sigma: PropertyLike[float] = 2, **kwargs: Any): @@ -996,7 +1001,101 @@ def __init__(self: GaussianBlur, sigma: PropertyLike[float] = 2, **kwargs: Any): """ - super().__init__(ndimage.gaussian_filter, sigma=sigma, **kwargs) + self.sigma = float(sigma) + super().__init__(None, **kwargs) + + def _get_numpy( + self, + input: np.ndarray, + **kwargs: Any, + ) -> np.ndarray: + return ndimage.gaussian_filter( + input, + sigma=self.sigma, + mode=kwargs.get("mode", "reflect"), + cval=kwargs.get("cval", 0), + ) + + def _gaussian_kernel_1d( + self, + sigma: float, + device, + dtype, + ) -> torch.Tensor: + radius = int(np.ceil(3 * sigma)) + x = torch.arange( + -radius, radius + 1, + device=device, + dtype=dtype, + ) + kernel = torch.exp(-(x ** 2) / (2 * sigma ** 2)) + kernel /= kernel.sum() + return kernel + + def _get_torch( + self, + input: torch.Tensor, + **kwargs: Any, + ) -> torch.Tensor: + import torch.nn.functional as F + + sigma = self.sigma + kernel_1d = self._gaussian_kernel_1d( + sigma, + device=input.device, + dtype=input.dtype, + ) + + last_dim_is_channel = input.ndim >= 3 + if last_dim_is_channel: + input = input.movedim(-1, 0) # C, ... + else: + input = input.unsqueeze(0) # 1, ... + + # add batch dimension + input = input.unsqueeze(0) # 1, C, ... + + spatial_dims = input.ndim - 2 + C = input.shape[1] + + for d in range(spatial_dims): + k = kernel_1d + shape = [1] * spatial_dims + shape[d] = -1 + k = k.view(1, 1, *shape) + k = k.repeat(C, 1, *([1] * spatial_dims)) + + pad = [0, 0] * spatial_dims + radius = k.shape[2 + d] // 2 + pad[-(2 * d + 2)] = radius + pad[-(2 * d + 1)] = radius + pad = tuple(pad) + + input = F.pad( + input, + pad, + mode=kwargs.get("mode", "reflect"), + ) + + if spatial_dims == 1: + input = F.conv1d(input, k, groups=C) + elif spatial_dims == 2: + input = F.conv2d(input, k, groups=C) + elif spatial_dims == 3: + input = F.conv3d(input, k, groups=C) + else: + raise NotImplementedError( + f"{spatial_dims}D Gaussian blur not supported" + ) + + # restore layout + input = input.squeeze(0) + if last_dim_is_channel: + input = input.movedim(0, -1) + else: + input = input.squeeze(0) + + return input #TODO ***JH*** revise MedianBlur - torch, typing, docstring, unit test @@ -1009,6 +1108,9 @@ class MedianBlur(Blur): useful for reducing noise while preserving edges. It is particularly effective for removing salt-and-pepper noise from images. + - NumPy backend: `scipy.ndimage.median_filter` + - Torch backend: explicit unfolding followed by `torch.median` + Parameters ---------- ksize: int @@ -1016,6 +1118,11 @@ class MedianBlur(Blur): **kwargs: dict Additional parameters sent to the blurring function. + Notes + ----- + Torch median blurring is significantly more expensive than mean or + Gaussian blurring due to explicit tensor unfolding. + Examples -------- >>> import deeptrack as dt @@ -1039,13 +1146,6 @@ class MedianBlur(Blur): >>> plt.imshow(output_image, cmap='gray') >>> plt.show() - Notes - ----- - Calling this feature returns a `np.ndarray` by default. If - `store_properties` is set to `True`, the returned array will be - automatically wrapped in an `Image` object. This behavior is handled - internally and does not affect the return type of the `get()` method. - """ def __init__( @@ -1053,667 +1153,367 @@ def __init__( ksize: PropertyLike[int] = 3, **kwargs: Any, ): - """Initialize the parameters for median blurring. 
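The Torch path of `GaussianBlur` above builds a normalized 1D Gaussian truncated at `ceil(3 * sigma)` and applies it once per spatial axis, exploiting the separability of the isotropic Gaussian. A cross-backend sanity check under the same truncation (SciPy defaults to `truncate=4.0`, and Torch's "reflect" padding corresponds to SciPy's "mirror" mode, so only interior pixels agree exactly):

```python
import numpy as np
import torch
import torch.nn.functional as F
from scipy import ndimage

sigma = 2.0
radius = int(np.ceil(3 * sigma))          # truncation used by _gaussian_kernel_1d
image = np.random.rand(64, 64).astype(np.float32)

# NumPy reference, truncated to the same radius as the Torch kernel.
reference = ndimage.gaussian_filter(image, sigma=sigma, mode="reflect", truncate=3.0)

# Torch separable version: one 1D pass per axis.
x = torch.arange(-radius, radius + 1, dtype=torch.float32)
kernel = torch.exp(-(x ** 2) / (2 * sigma ** 2))
kernel /= kernel.sum()

t = torch.from_numpy(image)[None, None]               # (1, 1, H, W)
t = F.pad(t, (0, 0, radius, radius), mode="reflect")  # pad rows
t = F.conv2d(t, kernel.view(1, 1, -1, 1))             # blur along H
t = F.pad(t, (radius, radius, 0, 0), mode="reflect")  # pad columns
t = F.conv2d(t, kernel.view(1, 1, 1, -1))             # blur along W

# Interior pixels agree to float precision; borders differ because of the
# reflect-vs-mirror padding mismatch.
interior = np.s_[radius:-radius, radius:-radius]
print(np.abs(t[0, 0].numpy()[interior] - reference[interior]).max())  # ~1e-6
```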
+ self.ksize = int(ksize) + super().__init__(None, **kwargs) + + def _get_numpy( + self, + input: np.ndarray, + **kwargs: Any, + ) -> np.ndarray: + return ndimage.median_filter( + input, + size=self.ksize, + mode=kwargs.get("mode", "reflect"), + cval=kwargs.get("cval", 0), + ) - This constructor initializes the parameters for median blurring. + def _get_torch( + self, + input: torch.Tensor, + **kwargs: Any, + ) -> torch.Tensor: + import torch.nn.functional as F - Parameters - ---------- - ksize: int - Kernel size. - **kwargs: Any - Additional keyword arguments. + k = self.ksize + if k % 2 == 0: + raise ValueError("MedianBlur requires an odd kernel size.") - """ + last_dim_is_channel = input.ndim >= 3 + if last_dim_is_channel: + input = input.movedim(-1, 0) # C, ... + else: + input = input.unsqueeze(0) # 1, ... - super().__init__(ndimage.median_filter, size=ksize, **kwargs) + # add batch dimension + input = input.unsqueeze(0) # 1, C, ... + spatial_dims = input.ndim - 2 + pad = k // 2 -#TODO ***AL*** revise Pool - torch, typing, docstring, unit test -class Pool(Feature): - """Downsamples the image by applying a function to local regions of the - image. + # padding + pad_tuple = [] + for _ in range(spatial_dims): + pad_tuple.extend([pad, pad]) + pad_tuple = tuple(reversed(pad_tuple)) - This class reduces the resolution of an image by dividing it into - non-overlapping blocks of size `ksize` and applying the specified pooling - function to each block. The result is a downsampled image where each pixel - value represents the result of the pooling function applied to the - corresponding block. + input = F.pad( + input, + pad_tuple, + mode=kwargs.get("mode", "reflect"), + ) - Parameters - ---------- - pooling_function: function - A function that is applied to each local region of the image. - DOES NOT NEED TO BE WRAPPED IN ANOTHER FUNCTION. - The `pooling_function` must accept the input image as a keyword argument - named `input`, as it is called via `utils.safe_call`. - Examples include `np.mean`, `np.max`, `np.min`, etc. - ksize: int - Size of the pooling kernel. - **kwargs: Any - Additional parameters sent to the pooling function. + # unfold spatial dimensions + if spatial_dims == 1: + x = input.unfold(2, k, 1) + elif spatial_dims == 2: + x = ( + input + .unfold(2, k, 1) + .unfold(3, k, 1) + ) + elif spatial_dims == 3: + x = ( + input + .unfold(2, k, 1) + .unfold(3, k, 1) + .unfold(4, k, 1) + ) + else: + raise NotImplementedError( + f"{spatial_dims}D median blur not supported" + ) - Methods - ------- - `get(image: np.ndarray | Image, ksize: int, **kwargs: Any) --> np.ndarray` - Applies the pooling function to the input image. + # flatten neighborhood and take median + x = x.contiguous().view(*x.shape[:-spatial_dims], -1) + x = x.median(dim=-1).values - Examples - -------- - >>> import deeptrack as dt - >>> import numpy as np + # restore layout + x = x.squeeze(0) + if last_dim_is_channel: + x = x.movedim(0, -1) + else: + x = x.squeeze(0) - Create an input image: - >>> input_image = np.random.rand(32, 32) + return x - Define a pooling feature: - >>> pooling_feature = dt.Pool(pooling_function=np.mean, ksize=4) - >>> output_image = pooling_feature.get(input_image, ksize=4) - >>> print(output_image.shape) - (8, 8) +#TODO ***CM*** revise typing, docstring, unit test +class Pool: + """ + DeepTrack v2 replacement for Pool. - Notes - ----- - Calling this feature returns a `np.ndarray` by default. 
If - `store_properties` is set to `True`, the returned array will be - automatically wrapped in an `Image` object. This behavior is handled - internally and does not affect the return type of the `get()` method. - The filter_function must accept the input image as a keyword argument named - input. This is required because it is called via utils.safe_call. If you - are using functions that do not support input=... (such as OpenCV filters - like cv2.GaussianBlur), consider using BlurCV2 instead. + Generic, center-preserving block pooling with NumPy and Torch backends. + Public API matches v1: a single integer ksize. + Pool size semantics: + - 2D input -> (ksize, ksize, 1) + - 3D input -> (ksize, ksize, ksize) """ + _TORCH_REDUCERS_2D: Dict[Callable, Callable] = { + np.mean: lambda x, k, s: F.avg_pool2d(x, k, s), + np.sum: lambda x, k, s: F.avg_pool2d(x, k, s) * (k[0] * k[1]), + np.max: lambda x, k, s: F.max_pool2d(x, k, s), + np.min: lambda x, k, s: -F.max_pool2d(-x, k, s), + } + + _TORCH_REDUCERS_3D: Dict[Callable, Callable] = { + np.mean: lambda x, k, s: F.avg_pool3d(x, k, s), + np.sum: lambda x, k, s: F.avg_pool3d(x, k, s) * (k[0] * k[1] * k[2]), + np.max: lambda x, k, s: F.max_pool3d(x, k, s), + np.min: lambda x, k, s: -F.max_pool3d(-x, k, s), + } + def __init__( - self: Pool, + self, pooling_function: Callable, - ksize: PropertyLike[int] = 3, - **kwargs: Any, + ksize: int = 2, ): - """Initialize the parameters for pooling input features. + if pooling_function not in ( + np.mean, np.sum, np.min, np.max, np.median + ): + raise ValueError( + "Unsupported pooling_function. " + "Use one of: np.mean, np.sum, np.min, np.max, np.median." + ) - This constructor initializes the parameters for pooling input - features. + if not isinstance(ksize, int) or ksize < 1: + raise ValueError("ksize must be a positive integer.") - Parameters - ---------- - pooling_function: Callable - The pooling function to apply. - ksize: int - Size of the pooling kernel. - **kwargs: Any - Additional keyword arguments. + self.pooling_function = pooling_function + self.ksize = int(ksize) + def _get_pool_size(self, array) -> Tuple[int, int, int]: """ + Determine pooling kernel size based on semantic dimensionality. - self.pooling = pooling_function - super().__init__(ksize=ksize, **kwargs) + - 2D images: (Nx, Ny) or (Nx, Ny, C) -> pool in x,y only + - 3D volumes: (Nx, Ny, Nz) or (Nx, Ny, Nz, C) -> pool in x,y,z + - Never pool over channels + """ + k = self.ksize - def get( - self: Pool, - image: np.ndarray | Image, - ksize: int, - **kwargs: Any, - ) -> np.ndarray: - """Applies the pooling function to the input image. + # 2D image + if array.ndim == 2: + return k, k, 1 - This method applies the pooling function to the input image. + # 3D array: could be (x, y, z) or (x, y, c) + if array.ndim == 3: + # Heuristic: small last dim → channels + if array.shape[-1] <= 4: + return k, k, 1 + return k, k, k - Parameters - ---------- - image: np.ndarray - The input image to pool. - ksize: int - Size of the pooling kernel. - **kwargs: dict[str, Any] - Additional keyword arguments. - - Returns - ------- - np.ndarray - The pooled image. - - """ + # 4D array: (x, y, z, c) + if array.ndim == 4: + return k, k, k - kwargs.pop("func", False) - kwargs.pop("image", False) - kwargs.pop("block_size", False) - return utils.safe_call( - skimage.measure.block_reduce, - image=image, - func=self.pooling, - block_size=ksize, - **kwargs, + raise ValueError( + f"Unsupported array shape {array.shape} for pooling." 
) + def _crop_center(self, array): + px, py, pz = self._get_pool_size(array) + + # 2D (or effectively 2D) + if array.ndim < 3 or pz == 1: + H, W = array.shape[:2] + crop_h = (H // px) * px + crop_w = (W // py) * py + off_h = (H - crop_h) // 2 + off_w = (W - crop_w) // 2 + return array[ + off_h : off_h + crop_h, + off_w : off_w + crop_w, + ... + ] + + # 3D + Z, H, W = array.shape[:3] + crop_z = (Z // pz) * pz + crop_h = (H // px) * px + crop_w = (W // py) * py + off_z = (Z - crop_z) // 2 + off_h = (H - crop_h) // 2 + off_w = (W - crop_w) // 2 + return array[ + off_z : off_z + crop_z, + off_h : off_h + crop_h, + off_w : off_w + crop_w, + ... + ] + + def _pool_numpy(self, array: np.ndarray) -> np.ndarray: + array = self._crop_center(array) + px, py, pz = self._get_pool_size(array) + + if array.ndim < 3 or pz == 1: + pool_shape = (px, py) + (1,) * (array.ndim - 2) + else: + pool_shape = (pz, px, py) + (1,) * (array.ndim - 3) -#TODO ***AL*** revise AveragePooling - torch, typing, docstring, unit test -class AveragePooling(Pool): - """Apply average pooling to an image. - - This class reduces the resolution of an image by dividing it into - non-overlapping blocks of size `ksize` and applying the average function to - each block. The result is a downsampled image where each pixel value - represents the average value within the corresponding block of the - original image. - - Parameters - ---------- - ksize: int - Size of the pooling kernel. - **kwargs: dict - Additional parameters sent to the pooling function. - - Examples - -------- - >>> import deeptrack as dt - >>> import numpy as np - - Create an input image: - >>> input_image = np.random.rand(32, 32) - - Define an average pooling feature: - >>> average_pooling = dt.AveragePooling(ksize=4) - >>> output_image = average_pooling(input_image) - >>> print(output_image.shape) - (8, 8) - - Notes - ----- - Calling this feature returns a `np.ndarray` by default. If - `store_properties` is set to `True`, the returned array will be - automatically wrapped in an `Image` object. This behavior is handled - internally and does not affect the return type of the `get()` method. - - """ - - def __init__( - self: Pool, - ksize: PropertyLike[int] = 3, - **kwargs: Any, - ): - """Initialize the parameters for average pooling. - - This constructor initializes the parameters for average pooling. - - Parameters - ---------- - ksize: int - Size of the pooling kernel. - **kwargs: Any - Additional keyword arguments. - - """ - - super().__init__(np.mean, ksize=ksize, **kwargs) - - -class MaxPooling(Pool): - """Apply max-pooling to images. - - `MaxPooling` reduces the resolution of an image by dividing it into - non-overlapping blocks of size `ksize` and applying the `max` function - to each block. The result is a downsampled image where each pixel value - represents the maximum value within the corresponding block of the - original image. This is useful for reducing the size of an image while - retaining the most significant features. - - If the backend is NumPy, the downsampling is performed using - `skimage.measure.block_reduce`. - - If the backend is PyTorch, the downsampling is performed using - `torch.nn.functional.max_pool2d`. - - Parameters - ---------- - ksize: int - Size of the pooling kernel. - **kwargs: Any - Additional parameters sent to the pooling function. 
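The channel heuristic in `Pool._get_pool_size` above treats a trailing axis of size 4 or less as channels, so a genuine z-stack of depth 4 or less would be pooled as if it were channels; that caveat is worth keeping in mind. A quick shape check, assuming the `Pool` subclasses from this diff are importable from `deeptrack.math` (the new `Pool` is a plain callable rather than a `Feature`, so it is applied directly to arrays):

```python
import numpy as np
from deeptrack.math import AveragePooling

pool = AveragePooling(ksize=2)

print(pool(np.random.rand(33, 33)).shape)      # (16, 16): 33x33 is cropped to 32x32, then 2x2 blocks
print(pool(np.random.rand(32, 32, 3)).shape)   # (16, 16, 3): last dim <= 4 is treated as channels
print(pool(np.random.rand(16, 16, 16)).shape)  # (8, 8, 8): last dim > 4 is treated as z
```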
- - Examples - -------- - >>> import deeptrack as dt - - Create an input image: - >>> import numpy as np - >>> - >>> input_image = np.random.rand(32, 32) - - Define and use a max-pooling feature: - - >>> max_pooling = dt.MaxPooling(ksize=8) - >>> output_image = max_pooling(input_image) - >>> output_image.shape - (4, 4) - - """ - - def __init__( - self: MaxPooling, - ksize: PropertyLike[int] = 3, - **kwargs: Any, - ): - """Initialize the parameters for max-pooling. - - This constructor initializes the parameters for max-pooling. - - Parameters - ---------- - ksize: int - Size of the pooling kernel. - **kwargs: Any - Additional keyword arguments. - - """ - - super().__init__(np.max, ksize=ksize, **kwargs) - - def get( - self: MaxPooling, - image: NDArray[Any] | torch.Tensor, - ksize: int=3, - **kwargs: Any, - ) -> NDArray[Any] | torch.Tensor: - """Max-pooling of input. - - Checks the current backend and chooses the appropriate function to pool - the input image, either `._get_torch()` or `._get_numpy()`. - - Parameters - ---------- - image: array or tensor - Input array or tensor be pooled. - ksize: int - Kernel size of the pooling operation. - - Returns - ------- - array or tensor - The pooled input as `NDArray` or `torch.Tensor` depending on - the backend. - - """ - - if self.get_backend() == "numpy": - return self._get_numpy(image, ksize, **kwargs) - - if self.get_backend() == "torch": - return self._get_torch(image, ksize, **kwargs) - - raise NotImplementedError(f"Backend {self.backend} not supported") - - def _get_numpy( - self: MaxPooling, - image: NDArray[Any], - ksize: int=3, - **kwargs: Any, - ) -> NDArray[Any]: - """Max-pooling pooling with the NumPy backend enabled. - - Returns the result of the input array passed to the scikit image - `block_reduce()` function with `np.max()` as the pooling function. - - Parameters - ---------- - image: array - Input array to be pooled. - ksize: int - Kernel size of the pooling operation. - - Returns - ------- - array - The pooled image as a NumPy array. - - """ - - return utils.safe_call( - skimage.measure.block_reduce, - image=image, - func=np.max, - block_size=ksize, - **kwargs, + return skimage.measure.block_reduce( + array, + block_size=pool_shape, + func=self.pooling_function, ) - def _get_torch( - self: MaxPooling, - image: torch.Tensor, - ksize: int=3, - **kwargs: Any, - ) -> torch.Tensor: - """Max-pooling with the PyTorch backend enabled. - - - Returns the result of the tensor passed to a PyTorch max - pooling layer. - - Parameters - ---------- - image: torch.Tensor - Input tensor to be pooled. - ksize: int - Kernel size of the pooling operation. - - Returns - ------- - torch.Tensor - The pooled image as a `torch.Tensor`. 
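The reducer table in the new `Pool` class above expresses Torch min pooling as `-F.max_pool2d(-x, ...)`, since PyTorch has no native min-pooling operator. A minimal identity check against an explicit unfold-based reference (a sketch):

```python
import torch
import torch.nn.functional as F

x = torch.rand(1, 1, 8, 8)

# Min over each non-overlapping 2x2 block, via negated max pooling.
min_pooled = -F.max_pool2d(-x, kernel_size=2)

# Reference: unfold into 2x2 blocks and reduce explicitly.
blocks = x.unfold(2, 2, 2).unfold(3, 2, 2).reshape(1, 1, 4, 4, 4)
assert torch.equal(min_pooled, blocks.min(dim=-1).values)
```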
- - """ + def _pool_torch(self, array: torch.Tensor) -> torch.Tensor: + array = self._crop_center(array) + px, py, pz = self._get_pool_size(array) - # If input tensor is 2D - if len(image.shape) == 2: - # Add batch dimension for max-pooling - expanded_image = image.unsqueeze(0) + is_3d = array.ndim >= 3 and pz > 1 - pooled_image = torch.nn.functional.max_pool2d( - expanded_image, kernel_size=ksize, + if not is_3d: + extra = array.shape[2:] + C = int(np.prod(extra)) if extra else 1 + x = array.reshape(1, C, array.shape[0], array.shape[1]) + kernel = (px, py) + stride = (px, py) + reducers = self._TORCH_REDUCERS_2D + else: + extra = array.shape[3:] + C = int(np.prod(extra)) if extra else 1 + x = array.reshape( + 1, C, array.shape[0], array.shape[1], array.shape[2] ) - # Remove the expanded dim - return pooled_image.squeeze(0) - - return torch.nn.functional.max_pool2d( - image, - kernel_size=ksize, - ) - - -class MinPooling(Pool): - """Apply min-pooling to images. - - `MinPooling` reduces the resolution of an image by dividing it into - non-overlapping blocks of size `ksize` and applying the `min` function to - each block. The result is a downsampled image where each pixel value - represents the minimum value within the corresponding block of the original - image. - - If the backend is NumPy, the downsampling is performed using - `skimage.measure.block_reduce`. - - If the backend is PyTorch, the downsampling is performed using the inverse - of `torch.nn.functional.max_pool2d` by changing the sign of the input. - - Parameters - ---------- - ksize: int - Size of the pooling kernel. - **kwargs: Any - Additional parameters sent to the pooling function. - - Examples - -------- - >>> import deeptrack as dt - - Create an input image: - >>> import numpy as np - >>> - >>> input_image = np.random.rand(32, 32) - - Define and use a min-pooling feature: - >>> min_pooling = dt.MinPooling(ksize=4) - >>> output_image = min_pooling(input_image) - >>> output_image.shape - (8, 8) - - """ - - def __init__( - self: MinPooling, - ksize: PropertyLike[int] = 3, - **kwargs: Any, - ): - """Initialize the parameters for min-pooling. - - This constructor initializes the parameters for min-pooling and checks - whether to use the NumPy or PyTorch implementation, defaults to NumPy. - - Parameters - ---------- - ksize: int - Size of the pooling kernel. - **kwargs: Any - Additional keyword arguments. - - """ - - super().__init__(np.min, ksize=ksize, **kwargs) - - def get( - self: MinPooling, - image: NDArray[Any] | torch.Tensor, - ksize: int=3, - **kwargs: Any, - ) -> NDArray[Any] | torch.Tensor: - """Min pooling of input. - - Checks the current backend and chooses the appropriate function to pool - the input image, either `._get_torch()` or `._get_numpy()`. - - Parameters - ---------- - image: array or tensor - Input array or tensor to be pooled. - ksize: int - Kernel size of the pooling operation. - - Returns - ------- - array or tensor - The pooled image as `NDArray` or `torch.Tensor` depending on the - backend. 
- - """ - - if self.get_backend() == "numpy": - return self._get_numpy(image, ksize, **kwargs) - - if self.get_backend() == "torch": - return self._get_torch(image, ksize, **kwargs) - - raise NotImplementedError(f"Backend {self.backend} not supported") + kernel = (pz, px, py) + stride = (pz, px, py) + reducers = self._TORCH_REDUCERS_3D + + # Median: explicit unfolding + if self.pooling_function is np.median: + if is_3d: + x_u = ( + x.unfold(2, pz, pz) + .unfold(3, px, px) + .unfold(4, py, py) + ) + x_u = x_u.contiguous().view( + 1, C, + x_u.shape[2], + x_u.shape[3], + x_u.shape[4], + -1, + ) + pooled = x_u.median(dim=-1).values + else: + x_u = x.unfold(2, px, px).unfold(3, py, py) + x_u = x_u.contiguous().view( + 1, C, + x_u.shape[2], + x_u.shape[3], + -1, + ) + pooled = x_u.median(dim=-1).values + else: + reducer = reducers[self.pooling_function] + pooled = reducer(x, kernel, stride) - def _get_numpy( - self: MinPooling, - image: NDArray[Any], - ksize: int=3, - **kwargs: Any, - ) -> NDArray[Any]: - """Min-pooling with the NumPy backend. + return pooled.reshape(pooled.shape[2:] + extra) - Returns the result of the input array passed to the scikit - `image block_reduce()` function with `np.min()` as the pooling - function. + def __call__(self, array): + if isinstance(array, np.ndarray): + return self._pool_numpy(array) - Parameters - ---------- - image: NDArray - Input image to be pooled. - ksize: int - Kernel size of the pooling operation. + if TORCH_AVAILABLE and isinstance(array, torch.Tensor): + return self._pool_torch(array) - Returns - ------- - NDArray - The pooled image as a `NDArray`. - - """ - - return utils.safe_call( - skimage.measure.block_reduce, - image=image, - func=np.min, - block_size=ksize, - **kwargs, + raise TypeError( + "Pool only supports np.ndarray or torch.Tensor inputs." ) - def _get_torch( - self: MinPooling, - image: torch.Tensor, - ksize: int=3, - **kwargs: Any, - ) -> torch.Tensor: - """Min-pooling with the PyTorch backend. - - As PyTorch does not have a min-pooling layer, the equivalent operation - is to first multiply the input tensor with `-1`, then perform - max-pooling, and finally multiply the max pooled tensor with `-1`. - Parameters - ---------- - image: torch.Tensor - Input tensor to be pooled. - ksize: int - Kernel size of the pooling operation. +class AveragePooling(Pool): + def __init__(self, ksize: int = 2): + super().__init__(np.mean, ksize) - Returns - ------- - torch.Tensor - The pooled image as a `torch.Tensor`. - """ +class SumPooling(Pool): + def __init__(self, ksize: int = 2): + super().__init__(np.sum, ksize) - # If input tensor is 2D - if len(image.shape) == 2: - # Add batch dimension for min-pooling - expanded_image = image.unsqueeze(0) - pooled_image = - torch.nn.functional.max_pool2d( - expanded_image * (-1), - kernel_size=ksize, - ) +class MinPooling(Pool): + def __init__(self, ksize: int = 2): + super().__init__(np.min, ksize) - # Remove the expanded dim - return pooled_image.squeeze(0) - return -torch.nn.functional.max_pool2d( - image * (-1), - kernel_size=ksize, - ) +class MaxPooling(Pool): + def __init__(self, ksize: int = 2): + super().__init__(np.max, ksize) -#TODO ***AL*** revise MedianPooling - torch, typing, docstring, unit test class MedianPooling(Pool): - """Apply median pooling to images. - - This class reduces the resolution of an image by dividing it into - non-overlapping blocks of size `ksize` and applying the median function to - each block. 
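The median branch of `Pool._pool_torch` above falls back to explicit unfolding because PyTorch has no median pooling either. A reduced 2D sketch of that unfolding (note that `torch.median` returns the lower of the two middle values for even-sized neighborhoods, unlike `np.median`, which averages them, so the two backends can differ for even `ksize`):

```python
import torch

def median_pool_2d(x: torch.Tensor, k: int) -> torch.Tensor:
    # Non-overlapping k x k blocks via stride-k unfolding, as in _pool_torch.
    blocks = x.unfold(0, k, k).unfold(1, k, k)          # (H//k, W//k, k, k)
    blocks = blocks.contiguous().view(*blocks.shape[:2], -1)
    # torch.median picks the lower middle value for even-sized neighborhoods.
    return blocks.median(dim=-1).values

x = torch.arange(16.0).view(4, 4)
print(median_pool_2d(x, 2))  # tensor([[ 1.,  3.], [ 9., 11.]])
```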
The result is a downsampled image where each pixel value - represents the median value within the corresponding block of the - original image. This is useful for reducing the size of an image while - retaining the most significant features. - - Parameters - ---------- - ksize: int - Size of the pooling kernel. - **kwargs: Any - Additional parameters sent to the pooling function. + def __init__(self, ksize: int = 2): + super().__init__(np.median, ksize) - Examples - -------- - >>> import deeptrack as dt - >>> import numpy as np - - Create an input image: - >>> input_image = np.random.rand(32, 32) - - Define a median pooling feature: - >>> median_pooling = dt.MedianPooling(ksize=3) - >>> output_image = median_pooling(input_image) - >>> print(output_image.shape) - (32, 32) - - Visualize the input and output images: - >>> plt.figure(figsize=(8, 4)) - >>> plt.subplot(1, 2, 1) - >>> plt.imshow(input_image, cmap='gray') - >>> plt.subplot(1, 2, 2) - >>> plt.imshow(output_image, cmap='gray') - >>> plt.show() - - Notes - ----- - Calling this feature returns a `np.ndarray` by default. If - `store_properties` is set to `True`, the returned array will be - automatically wrapped in an `Image` object. This behavior is handled - internally and does not affect the return type of the `get()` method. - - """ - - def __init__( - self: MedianPooling, - ksize: PropertyLike[int] = 3, - **kwargs: Any, - ): - """Initialize the parameters for median pooling. - - This constructor initializes the parameters for median pooling. - - Parameters - ---------- - ksize: int - Size of the pooling kernel. - **kwargs: Any - Additional keyword arguments. - - """ - - super().__init__(np.median, ksize=ksize, **kwargs) class Resize(Feature): """Resize an image to a specified size. - `Resize` resizes an image using: - - OpenCV (`cv2.resize`) for NumPy arrays. - - PyTorch (`torch.nn.functional.interpolate`) for PyTorch tensors. + `Resize` resizes images following the channels-last semantic + convention. + + The operation supports both NumPy arrays and PyTorch tensors: + - NumPy arrays are resized using OpenCV (`cv2.resize`). + - PyTorch tensors are resized using `torch.nn.functional.interpolate`. - The interpretation of the `dsize` parameter follows the convention - of the underlying backend: - - **NumPy (OpenCV)**: `dsize` is given as `(width, height)` to match - OpenCV’s default. - - **PyTorch**: `dsize` is given as `(height, width)`. + In all cases, the input is interpreted as having spatial dimensions + first and an optional channel dimension last. Parameters ---------- - dsize: PropertyLike[tuple[int, int]] - The target size. Format depends on backend: `(width, height)` for - NumPy, `(height, width)` for PyTorch. - **kwargs: Any - Additional parameters sent to the underlying resize function: - - NumPy: passed to `cv2.resize`. - - PyTorch: passed to `torch.nn.functional.interpolate`. + dsize : PropertyLike[tuple[int, int]] + Target output size given as (width, height). This convention is + backend-independent and applies equally to NumPy and PyTorch inputs. + + **kwargs : Any + Additional keyword arguments forwarded to the underlying resize + implementation: + - NumPy backend: passed to `cv2.resize`. + - PyTorch backend: passed to + `torch.nn.functional.interpolate`. Methods ------- get( - image: np.ndarray | torch.Tensor, dsize: tuple[int, int], **kwargs + image: np.ndarray | torch.Tensor, + dsize: tuple[int, int], + **kwargs ) -> np.ndarray | torch.Tensor Resize the input image to the specified size. 
Examples -------- - >>> import deeptrack as dt + NumPy example: - Numpy example: >>> import numpy as np - >>> - >>> input_image = np.random.rand(16, 16) # Create image - >>> feature = dt.math.Resize(dsize=(8, 4)) # (width=8, height=4) - >>> resized_image = feature.resolve(input_image) # Resize it to (4, 8) - >>> print(resized_image.shape) + >>> input_image = np.random.rand(16, 16) + >>> feature = dt.math.Resize(dsize=(8, 4)) # (width=8, height=4) + >>> resized_image = feature.resolve(input_image) + >>> resized_image.shape (4, 8) PyTorch example: + >>> import torch - >>> - >>> input_image = torch.rand(1, 1, 16, 16) # Create image - >>> feature = dt.math.Resize(dsize=(4, 8)) # (height=4, width=8) - >>> resized_image = feature.resolve(input_image) # Resize it to (4, 8) - >>> print(resized_image.shape) - torch.Size([1, 1, 4, 8]) + >>> input_image = torch.rand(16, 16) # channels-last + >>> feature = dt.math.Resize(dsize=(8, 4)) + >>> resized_image = feature.resolve(input_image) + >>> resized_image.shape + torch.Size([4, 8]) + + Notes + ----- + - Resize follows channels-last semantics, consistent with other features + such as Pool and Blur. + - Torch tensors with channels-first layout (e.g. (C, H, W) or + (N, C, H, W)) are not supported and must be converted to + channels-last format before resizing. + - For PyTorch tensors, bilinear interpolation is used with + `align_corners=False`, closely matching OpenCV’s default behavior. """ @@ -1738,78 +1538,117 @@ def __init__( def get( self: Resize, - image: NDArray | torch.Tensor, + image: np.ndarray | torch.Tensor, dsize: tuple[int, int], **kwargs: Any, - ) -> NDArray | torch.Tensor: + ) -> np.ndarray | torch.Tensor: """Resize the input image to the specified size. Parameters ---------- - image: np.ndarray or torch.Tensor - The input image to resize. - - NumPy arrays may be grayscale (H, W) or color (H, W, C). - - Torch tensors are expected in one of the following formats: - (N, C, H, W), (C, H, W), or (H, W). - dsize: tuple[int, int] - Desired output size of the image. - - NumPy: (width, height) - - PyTorch: (height, width) - **kwargs: Any - Additional keyword arguments passed to the underlying resize - function (`cv2.resize` or `torch.nn.functional.interpolate`). + image : np.ndarray or torch.Tensor + Input image following channels-last semantics. + + Supported shapes are: + - (H, W) + - (H, W, C) + - (Z, H, W) + - (Z, H, W, C) + + For PyTorch tensors, channels-first layouts such as (C, H, W) or + (N, C, H, W) are not supported and must be converted to + channels-last format before calling `Resize`. + + dsize : tuple[int, int] + Desired output size given as (width, height). This convention is + backend-independent and applies to both NumPy and PyTorch inputs. + + **kwargs : Any + Additional keyword arguments passed to the underlying resize + implementation: + - NumPy backend: forwarded to `cv2.resize`. + - PyTorch backend: forwarded to `torch.nn.functional.interpolate`. Returns ------- np.ndarray or torch.Tensor - The resized image in the same type and dimensionality format as - input. + The resized image, with the same type and dimensionality layout as + the input image. Notes ----- + - Resize follows the same channels-last semantic convention as other + features in `deeptrack.math`. - For PyTorch tensors, resizing uses bilinear interpolation with - `align_corners=False`. This choice matches OpenCV’s `cv2.resize` - default behavior when resizing NumPy arrays, aiming to produce nearly - identical results between both backends. 
+ `align_corners=False`, which closely matches OpenCV’s default behavior. """ - if self._wrap_array_with_image: - image = strip(image) + target_w, target_h = dsize + # Torch backend if apc.is_torch_array(image): - original_shape = image.shape - - # Reshape input to (N, C, H, W) - if image.ndim == 2: # (H, W) - image = image.unsqueeze(0).unsqueeze(0) - elif image.ndim == 3: # (C, H, W) - image = image.unsqueeze(0) - elif image.ndim != 4: + import torch.nn.functional as F + + original_ndim = image.ndim + has_channels = ( + image.ndim >= 3 and image.shape[-1] <= 4 + ) + + # Bring to (N, C, H, W) + if image.ndim == 2: + # (H, W) -> (1, 1, H, W) + x = image.unsqueeze(0).unsqueeze(0) + + elif image.ndim == 3 and has_channels: + # (H, W, C) -> (1, C, H, W) + x = image.permute(2, 0, 1).unsqueeze(0) + + elif image.ndim == 3: + # (Z, H, W) -> treat Z as batch + x = image.unsqueeze(1) + + elif image.ndim == 4 and has_channels: + # (Z, H, W, C) -> (Z, C, H, W) + x = image.permute(0, 3, 1, 2) + + else: raise ValueError( - "Resize only supports tensors with shape (N, C, H, W), " - "(C, H, W), or (H, W)." + f"Unsupported tensor shape {image.shape} for Resize." ) - resized = torch.nn.functional.interpolate( - image, - size=dsize, + # Resize spatial dimensions + resized = F.interpolate( + x, + size=(target_h, target_w), mode="bilinear", align_corners=False, ) - # Restore original dimensionality - if len(original_shape) == 2: - resized = resized.squeeze(0).squeeze(0) - elif len(original_shape) == 3: - resized = resized.squeeze(0) + # Restore original layout + if original_ndim == 2: + return resized.squeeze(0).squeeze(0) + + if original_ndim == 3 and has_channels: + return resized.squeeze(0).permute(1, 2, 0) + + if original_ndim == 3: + return resized.squeeze(1) - return resized + if original_ndim == 4: + return resized.permute(0, 2, 3, 1) + raise RuntimeError("Unexpected shape restoration path.") + + # NumPy / OpenCV backend else: import cv2 + + # OpenCV expects (width, height) return utils.safe_call( - cv2.resize, positional_args=[image, dsize], **kwargs + cv2.resize, + positional_args=[image, (target_w, target_h)], + **kwargs, ) @@ -1865,10 +1704,9 @@ class BlurCV2(Feature): Notes ----- - Calling this feature returns a `np.ndarray` by default. If - `store_properties` is set to `True`, the returned array will be - automatically wrapped in an `Image` object. This behavior is handled - internally and does not affect the return type of the `get()` method. + BlurCV2 is NumPy-only and does not support PyTorch tensors. + This class is intended for OpenCV-specific filters that are + not available in the backend-agnostic math layer. """ @@ -1964,6 +1802,12 @@ def get( """ + if apc.is_torch_array(image): + raise TypeError( + "BlurCV2 only supports NumPy arrays. " + "For Torch tensors, use Blur or GaussianBlur instead." + ) + kwargs.pop("name", None) result = self.filter(src=image, **kwargs) return result @@ -2015,10 +1859,7 @@ class BilateralBlur(BlurCV2): Notes ----- - Calling this feature returns a `np.ndarray` by default. If - `store_properties` is set to `True`, the returned array will be - automatically wrapped in an `Image` object. This behavior is handled - internally and does not affect the return type of the `get()` method. + BilateralBlur is NumPy-only and does not support PyTorch tensors. 
""" @@ -2059,3 +1900,73 @@ def __init__( sigmaSpace=sigma_space, **kwargs, ) + + +def isotropic_dilation( + mask, + radius: float, + *, + backend: str, + device=None, + dtype=None, +): + if radius <= 0: + return mask + + if backend == "numpy": + from skimage.morphology import isotropic_dilation + return isotropic_dilation(mask, radius) + + # torch backend + import torch + + r = int(np.ceil(radius)) + kernel = torch.ones( + (1, 1, 2 * r + 1, 2 * r + 1, 2 * r + 1), + device=device or mask.device, + dtype=dtype or torch.float32, + ) + + x = mask.to(dtype=kernel.dtype)[None, None] + y = torch.nn.functional.conv3d( + x, + kernel, + padding=r, + ) + + return (y[0, 0] > 0) + + +def isotropic_erosion( + mask, + radius: float, + *, + backend: str, + device=None, + dtype=None, +): + if radius <= 0: + return mask + + if backend == "numpy": + from skimage.morphology import isotropic_erosion + return isotropic_erosion(mask, radius) + + import torch + + r = int(np.ceil(radius)) + kernel = torch.ones( + (1, 1, 2 * r + 1, 2 * r + 1, 2 * r + 1), + device=device or mask.device, + dtype=dtype or torch.float32, + ) + + x = mask.to(dtype=kernel.dtype)[None, None] + y = torch.nn.functional.conv3d( + x, + kernel, + padding=r, + ) + + required = kernel.numel() + return (y[0, 0] >= required) \ No newline at end of file diff --git a/deeptrack/optics.py b/deeptrack/optics.py index 5149bdae..a7394689 100644 --- a/deeptrack/optics.py +++ b/deeptrack/optics.py @@ -137,11 +137,13 @@ def _pad_volume( from __future__ import annotations from pint import Quantity -from typing import Any +from typing import Any, TYPE_CHECKING import warnings import numpy as np -from scipy.ndimage import convolve +from scipy.ndimage import convolve # might be removed later +import torch +import torch.nn.functional as F from deeptrack.backend.units import ( ConversionTable, @@ -149,23 +151,37 @@ def _pad_volume( get_active_scale, get_active_voxel_size, ) -from deeptrack.math import AveragePooling +from deeptrack.math import AveragePooling, SumPooling from deeptrack.features import propagate_data_to_dependencies from deeptrack.features import DummyFeature, Feature, StructuralFeature -from deeptrack.image import Image, pad_image_to_fft +from deeptrack.image import pad_image_to_fft from deeptrack.types import ArrayLike, PropertyLike from deeptrack import image from deeptrack import units_registry as u +from deeptrack import TORCH_AVAILABLE, image +from deeptrack.backend import xp +from deeptrack.scatterers import ScatteredVolume, ScatteredField + +if TORCH_AVAILABLE: + import torch + +if TYPE_CHECKING: + import torch + #TODO ***??*** revise Microscope - torch, typing, docstring, unit test class Microscope(StructuralFeature): """Simulates imaging of a sample using an optical system. - This class combines a feature-set that defines the sample to be imaged with - a feature-set defining the optical system, enabling the simulation of - optical imaging processes. + This class combines the sample to be imaged with the optical system, + enabling the simulation of optical imaging processes. 
+ A Microscope: + - validates the semantic compatibility between scatterers and optics + - interprets volume-based scatterers into scalar fields when needed + - delegates numerical propagation to the objective (Optics) + - performs detector downscaling according to its physical semantics Parameters ---------- @@ -186,10 +202,16 @@ class Microscope(StructuralFeature): Methods ------- - `get(image: Image or None, **kwargs: Any) -> Image` + `get(image: np.ndarray or None, **kwargs: Any) -> np.ndarray` Simulates the imaging process using the defined optical system and returns the resulting image. + Notes + ----- + All volume scatterers imaged by a Microscope instance are assumed to + share the same contrast mechanism (e.g. refractive index or fluorescence). + Mixing contrast types is not supported. + Examples -------- Simulating an image using a brightfield optical system: @@ -238,13 +260,41 @@ def __init__( self._sample = self.add_feature(sample) self._objective = self.add_feature(objective) - self._sample.store_properties() + + def _validate_input(self, scattered): + if hasattr(self._objective, "validate_input"): + self._objective.validate_input(scattered) + + def _extract_contrast_volume(self, scattered): + if hasattr(self._objective, "extract_contrast_volume"): + return self._objective.extract_contrast_volume( + scattered, + **self._objective.properties(), + ) + return scattered.array + + def _downscale_image(self, image, upscale): + if hasattr(self._objective, "downscale_image"): + return self._objective.downscale_image(image, upscale) + + if not np.any(np.array(upscale) != 1): + return image + + ux, uy = upscale[:2] + if ux != uy: + raise ValueError( + f"Energy-conserving detector integration requires ux == uy, " + f"got ux={ux}, uy={uy}." + ) + if isinstance(ux, float) and ux.is_integer(): + ux = int(ux) + return AveragePooling(ux)(image) def get( self: Microscope, - image: Image | None, + image: np.ndarray | torch.Tensor | None = None, **kwargs: Any, - ) -> Image: + ) -> np.ndarray | torch.Tensor: """Generate an image of the sample using the defined optical system. This method processes the sample through the optical system to @@ -252,14 +302,14 @@ def get( Parameters ---------- - image: Image | None + image: np.ndarray | torch.Tensor | None The input image to be processed. If None, a new image is created. **kwargs: Any Additional parameters for the imaging process. Returns ------- - Image: Image + image: np.ndarray | torch.Tensor The processed image after applying the optical system. Examples @@ -280,9 +330,6 @@ def get( # Grab properties from the objective to pass to the sample additional_sample_kwargs = self._objective.properties() - # Calculate required output image for the given upscale - # This way of providing the upscale will be deprecated in the future - # in favor of dt.Upscale(). _upscale_given_by_optics = additional_sample_kwargs["upscale"] if np.array(_upscale_given_by_optics).size == 1: _upscale_given_by_optics = (_upscale_given_by_optics,) * 3 @@ -325,67 +372,61 @@ def get( if not isinstance(list_of_scatterers, list): list_of_scatterers = [list_of_scatterers] + # Semantic validation (per scatterer) + for scattered in list_of_scatterers: + self._validate_input(scattered) + # All scatterers that are defined as volumes. volume_samples = [ scatterer for scatterer in list_of_scatterers - if not scatterer.get_property("is_field", default=False) + if isinstance(scatterer, ScatteredVolume) ] # All scatterers that are defined as fields. 
        field_samples = [
             scatterer
             for scatterer in list_of_scatterers
-            if scatterer.get_property("is_field", default=False)
+            if isinstance(scatterer, ScatteredField)
         ]
-
+
         # Merge all volumes into a single volume.
         sample_volume, limits = _create_volume(
             volume_samples,
             **additional_sample_kwargs,
         )
-        sample_volume = Image(sample_volume)
-
-        # Merge all properties into the volume.
-        for scatterer in volume_samples + field_samples:
-            sample_volume.merge_properties_from(scatterer)
+        # Interpret the merged volume semantically.
+        sample_volume = self._extract_contrast_volume(
+            ScatteredVolume(
+                array=sample_volume,
+                properties=volume_samples[0].properties,
+            ),
+        )

         # Let the objective know about the limits of the volume and all the fields.
         propagate_data_to_dependencies(
             self._objective,
             limits=limits,
-            fields=field_samples,
+            fields=field_samples,  # TODO: decide whether upscale should also be propagated
         )

         imaged_sample = self._objective.resolve(sample_volume)

-        # Upscale given by the optics needs to be handled separately.
-        if _upscale_given_by_optics != (1, 1, 1):
-            imaged_sample = AveragePooling((*_upscale_given_by_optics[:2], 1))(
-                imaged_sample
-            )
-
-        # Merge with input
-        if not image:
-            if not self._wrap_array_with_image and isinstance(imaged_sample, Image):
-                return imaged_sample._value
-            else:
-                return imaged_sample
-
-        if not isinstance(image, list):
-            image = [image]
-        for i in range(len(image)):
-            image[i].merge_properties_from(imaged_sample)
-        return image
-
-    # def _no_wrap_format_input(self, *args, **kwargs) -> list:
-    #     return self._image_wrapped_format_input(*args, **kwargs)
-
-    # def _no_wrap_process_and_get(self, *args, **feature_input) -> list:
-    #     return self._image_wrapped_process_and_get(*args, **feature_input)
+        imaged_sample = self._downscale_image(
+            imaged_sample, _upscale_given_by_optics
+        )

+        # # Handling upscale from dt.Upscale() here to eliminate Image
+        # # wrapping issues.
+ # if np.any(np.array(upscale) != 1): + # ux, uy = upscale[:2] + # if contrast_type == "intensity": + # print("Using sum pooling for intensity downscaling.") + # imaged_sample = SumPoolingCM((ux, uy, 1))(imaged_sample) + # else: + # imaged_sample = AveragePoolingCM((ux, uy, 1))(imaged_sample) - # def _no_wrap_process_output(self, *args, **feature_input): - # return self._image_wrapped_process_output(*args, **feature_input) + return imaged_sample #TODO ***??*** revise Optics - torch, typing, docstring, unit test @@ -569,6 +610,15 @@ def __init__( """ + def validate_scattered(self, scattered): + pass + + def extract_contrast_volume(self, scattered): + pass + + def downscale_image(self, image, upscale): + pass + def get_voxel_size( resolution: float | ArrayLike[float], magnification: float, @@ -673,6 +723,7 @@ def _process_properties( wavelength = propertydict["wavelength"] voxel_size = get_active_voxel_size() radius = NA / wavelength * np.array(voxel_size) + print('Pupil radius (in pixels):', radius) if np.any(radius[:2] > 0.5): required_upscale = np.max(np.ceil(radius[:2] * 2)) @@ -757,19 +808,18 @@ def _pupil( W, H = np.meshgrid(y, x) RHO = (W ** 2 + H ** 2).astype(complex) - pupil_function = Image((RHO < 1) + 0.0j, copy=False) + pupil_function = (RHO < 1) + 0.0j # Defocus - z_shift = Image( + z_shift = ( 2 * np.pi * refractive_index_medium / wavelength * voxel_size[2] - * np.sqrt(1 - (NA / refractive_index_medium) ** 2 * RHO), - copy=False, + * np.sqrt(1 - (NA / refractive_index_medium) ** 2 * RHO) ) - z_shift._value[z_shift._value.imag != 0] = 0 + z_shift[z_shift.imag != 0] = 0 try: z_shift = np.nan_to_num(z_shift, False, 0, 0, 0) @@ -1007,7 +1057,7 @@ class Fluorescence(Optics): Methods ------- - `get(illuminated_volume: array_like[complex], limits: array_like[int, int], **kwargs: Any) -> Image` + `get(illuminated_volume: array_like[complex], limits: array_like[int, int], **kwargs: Any) -> np.ndarray` Simulates the imaging process using a fluorescence microscope. Examples @@ -1024,12 +1074,71 @@ class Fluorescence(Optics): """ + + def validate_input(self, scattered): + """Semantic validation for fluorescence microscopy.""" + + # Fluorescence cannot operate on coherent fields + if isinstance(scattered, ScatteredField): + raise TypeError( + "Fluorescence microscope cannot operate on ScatteredField." + ) + + + def extract_contrast_volume(self, scattered: ScatteredVolume) -> np.ndarray: + voxel_size = np.asarray(get_active_voxel_size(), float) + voxel_volume = np.prod(voxel_size) + + intensity = scattered.get_property("intensity", None) + value = scattered.get_property("value", None) + ri = scattered.get_property("refractive_index", None) + + # Refractive index is always ignored in fluorescence + if ri is not None: + warnings.warn( + "Scatterer defines 'refractive_index', which is ignored in " + "fluorescence microscopy.", + UserWarning, + ) + + # Preferred, physically meaningful case + if intensity is not None: + return intensity * voxel_volume * scattered.array + + # Fallback: legacy / dimensionless brightness + warnings.warn( + "Fluorescence scatterer has no 'intensity'. Interpreting 'value' as a " + "non-physical brightness factor. Quantitative interpretation is invalid. 
" + "Define 'intensity' to model physical fluorescence emission.", + UserWarning, + ) + + return value * scattered.array + + def downscale_image(self, image: np.ndarray, upscale): + """Detector downscaling (energy conserving)""" + if not np.any(np.array(upscale) != 1): + return image + + ux, uy = upscale[:2] + if ux != uy: + raise ValueError( + f"Energy-conserving detector integration requires ux == uy, " + f"got ux={ux}, uy={uy}." + ) + if isinstance(ux, float) and ux.is_integer(): + ux = int(ux) + + # Energy-conserving detector integration + return SumPooling(ux)(image) + + def get( self: Fluorescence, illuminated_volume: ArrayLike[complex], limits: ArrayLike[int], **kwargs: Any, - ) -> Image: + ) -> ArrayLike[complex]: """Simulates the imaging process using a fluorescence microscope. This method convolves the 3D illuminated volume with a pupil function @@ -1048,7 +1157,7 @@ def get( Returns ------- - Image: Image + image: np.ndarray A 2D image object representing the fluorescence projection. Notes @@ -1066,7 +1175,7 @@ def get( >>> optics = dt.Fluorescence( ... NA=1.4, wavelength=0.52e-6, magnification=60, ... ) - >>> volume = dt.Image(np.ones((128, 128, 10), dtype=complex)) + >>> volume = np.ones((128, 128, 10), dtype=complex) >>> limits = np.array([[0, 128], [0, 128], [0, 10]]) >>> properties = optics.properties() >>> filtered_properties = { @@ -1118,9 +1227,7 @@ def get( ] z_limits = limits[2, :] - output_image = Image( - np.zeros((*padded_volume.shape[0:2], 1)), copy=False - ) + output_image = np.zeros((*padded_volume.shape[0:2], 1)) index_iterator = range(padded_volume.shape[2]) @@ -1156,12 +1263,12 @@ def get( field = np.fft.ifft2(convolved_fourier_field) # # Discard remaining imaginary part (should be 0 up to rounding error) field = np.real(field) - output_image._value[:, :, 0] += field[ + output_image[:, :, 0] += field[ : padded_volume.shape[0], : padded_volume.shape[1] ] output_image = output_image[pad[0] : -pad[2], pad[1] : -pad[3]] - output_image.properties = illuminated_volume.properties + pupils.properties + # output_image.properties = illuminated_volume.properties + pupils.properties return output_image @@ -1234,7 +1341,7 @@ class Brightfield(Optics): ------- `get(illuminated_volume: array_like[complex], limits: array_like[int, int], fields: array_like[complex], - **kwargs: Any) -> Image` + **kwargs: Any) -> np.ndarray` Simulates imaging with brightfield microscopy. @@ -1250,9 +1357,52 @@ class Brightfield(Optics): """ + __conversion_table__ = ConversionTable( - working_distance=(u.meter, u.meter), - ) + working_distance=(u.meter, u.meter), +) + + def validate_input(self, scattered): + """Semantic validation for brightfield microscopy.""" + + if isinstance(scattered, ScatteredVolume): + warnings.warn( + "Brightfield imaging from ScatteredVolume assumes a " + "weak-phase / projection approximation. 
" + "Use ScatteredField for physically accurate brightfield simulations.", + UserWarning, + ) + + def extract_contrast_volume( + self, + scattered: ScatteredVolume, + refractive_index_medium: float, + **kwargs: Any, + ) -> np.ndarray: + print('ri_medium', refractive_index_medium) + + ri = scattered.get_property("refractive_index", None) + value = scattered.get_property("value", None) + intensity = scattered.get_property("intensity", None) + + if intensity is not None: + warnings.warn( + "Scatterer defines 'intensity', which is ignored in " + "brightfield microscopy.", + UserWarning, + ) + + if ri is not None: + return (ri - refractive_index_medium) * scattered.array + + warnings.warn( + "No 'refractive_index' specified; using 'value' as a non-physical " + "brightfield contrast. Results are not physically calibrated. " + "Define 'refractive_index' for physically meaningful contrast.", + UserWarning, + ) + + return value * scattered.array def get( self: Brightfield, @@ -1260,7 +1410,7 @@ def get( limits: ArrayLike[int], fields: ArrayLike[complex], **kwargs: Any, - ) -> Image: + ) -> np.ndarray: """Simulates imaging with brightfield microscopy. This method propagates light through the given volume, applying @@ -1285,7 +1435,7 @@ def get( Returns ------- - Image: Image + image: np.ndarray Processed image after simulating the brightfield imaging process. Examples @@ -1300,7 +1450,7 @@ def get( ... wavelength=0.52e-6, ... magnification=60, ... ) - >>> volume = dt.Image(np.ones((128, 128, 10), dtype=complex)) + >>> volume = np.ones((128, 128, 10), dtype=complex) >>> limits = np.array([[0, 128], [0, 128], [0, 10]]) >>> fields = np.array([np.ones((162, 162), dtype=complex)]) >>> properties = optics.properties() @@ -1345,7 +1495,7 @@ def get( if output_region[3] is None else int(output_region[3] - limits[1, 0] + pad[3]) ) - + padded_volume = padded_volume[ output_region[0] : output_region[2], output_region[1] : output_region[3], @@ -1353,9 +1503,7 @@ def get( ] z_limits = limits[2, :] - output_image = Image( - np.zeros((*padded_volume.shape[0:2], 1)) - ) + output_image = np.zeros((*padded_volume.shape[0:2], 1)) index_iterator = range(padded_volume.shape[2]) z_iterator = np.linspace( @@ -1414,7 +1562,25 @@ def get( light_in_focus = light_in * shifted_pupil if len(fields) > 0: - field = np.sum(fields, axis=0) + # field = np.sum(fields, axis=0) + field_arrays = [] + + for fs in fields: + # fs is a ScatteredField + arr = fs.array + + # Enforce (H, W, 1) shape + if arr.ndim == 2: + arr = arr[..., None] + + if arr.ndim != 3 or arr.shape[-1] != 1: + raise ValueError( + f"Expected field of shape (H, W, 1), got {arr.shape}" + ) + + field_arrays.append(arr) + + field = np.sum(field_arrays, axis=0) light_in_focus += field[..., 0] shifted_pupil = np.fft.fftshift(pupils[-1]) light_in_focus = light_in_focus * shifted_pupil @@ -1426,7 +1592,7 @@ def get( : padded_volume.shape[0], : padded_volume.shape[1] ] output_image = np.expand_dims(output_image, axis=-1) - output_image = Image(output_image[pad[0] : -pad[2], pad[1] : -pad[3]]) + output_image = output_image[pad[0] : -pad[2], pad[1] : -pad[3]] if not kwargs.get("return_field", False): output_image = np.square(np.abs(output_image)) @@ -1436,7 +1602,7 @@ def get( # output_image = output_image * np.exp(1j * -np.pi / 4) # output_image = output_image + 1 - output_image.properties = illuminated_volume.properties + # output_image.properties = illuminated_volume.properties return output_image @@ -1624,6 +1790,73 @@ def __init__( illumination_angle=illumination_angle, 
**kwargs) + def validate_input(self, scattered): + if isinstance(scattered, ScatteredVolume): + warnings.warn( + "Darkfield imaging from ScatteredVolume is a very rough " + "approximation. Use ScatteredField for physically meaningful " + "darkfield simulations.", + UserWarning, + ) + + def extract_contrast_volume( + self, + scattered: ScatteredVolume, + refractive_index_medium: float, + **kwargs: Any, + ) -> np.ndarray: + """ + Approximate darkfield contrast from a volume (toy model). + + This is a non-physical approximation intended for qualitative simulations. + """ + + ri = scattered.get_property("refractive_index", None) + value = scattered.get_property("value", None) + intensity = scattered.get_property("intensity", None) + + # Intensity has no meaning here + if intensity is not None: + warnings.warn( + "Scatterer defines 'intensity', which is ignored in " + "darkfield microscopy.", + UserWarning, + ) + + if ri is not None: + delta_n = ri - refractive_index_medium + warnings.warn( + "Approximating darkfield contrast from refractive index. " + "Result is non-physical and qualitative only.", + UserWarning, + ) + return (delta_n ** 2) * scattered.array + + warnings.warn( + "No 'refractive_index' specified; using 'value' as a non-physical " + "darkfield scattering strength. Results are qualitative only.", + UserWarning, + ) + + return (value ** 2) * scattered.array + + def downscale_image(self, image: np.ndarray, upscale): + """Detector downscaling (energy conserving)""" + if not np.any(np.array(upscale) != 1): + return image + + ux, uy = upscale[:2] + if ux != uy: + raise ValueError( + f"Energy-conserving detector integration requires ux == uy, " + f"got ux={ux}, uy={uy}." + ) + if isinstance(ux, float) and ux.is_integer(): + ux = int(ux) + + # Energy-conserving detector integration + return SumPooling(ux)(image) + #Retrieve get as super def get( self: Darkfield, @@ -1631,7 +1864,7 @@ def get( limits: ArrayLike[int], fields: ArrayLike[complex], **kwargs: Any, - ) -> Image: + ) -> np.ndarray: """Retrieve the darkfield image of the illuminated volume. Parameters @@ -1800,9 +2033,1001 @@ def get( return image +class NonOverlapping(Feature): + """Ensure volumes are placed non-overlapping in a 3D space. + + This feature ensures that a list of 3D volumes are positioned such that + their non-zero voxels do not overlap. If volumes overlap, their positions + are resampled until they are non-overlapping. If the maximum number of + attempts is exceeded, the feature regenerates the list of volumes and + raises a warning if non-overlapping placement cannot be achieved. + + Note: `min_distance` refers to the distance between the edges of volumes, + not their centers. Due to the way volumes are calculated, slight rounding + errors may affect the final distance. + + This feature is incompatible with non-volumetric scatterers such as + `MieScatterers`. + + Parameters + ---------- + feature: Feature + The feature that generates the list of volumes to place + non-overlapping. + min_distance: float, optional + The minimum distance between volumes in pixels. It can be negative to + allow for partial overlap. Defaults to 1. + max_attempts: int, optional + The maximum number of attempts to place volumes without overlap. + Defaults to 5. + max_iters: int, optional + The maximum number of resamplings. If this number is exceeded, a new + list of volumes is generated. Defaults to 100. 
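Returning to the optics changes above: the three `extract_contrast_volume` implementations (fluorescence, brightfield, darkfield) each reduce to a per-voxel scaling of the scatterer occupancy volume. A side-by-side sketch of the arithmetic, with illustrative values (`volume` stands in for `ScatteredVolume.array`):

```python
import numpy as np

volume = np.zeros((64, 64, 8))
volume[24:40, 24:40, 2:6] = 1.0   # scatterer occupancy, in voxels

ri, ri_medium = 1.45, 1.33        # refractive indices
intensity = 100.0                 # emission per unit volume (illustrative units)
voxel_volume = np.prod([0.1e-6, 0.1e-6, 0.1e-6])

fluorescence = intensity * voxel_volume * volume  # physical emission per voxel
brightfield = (ri - ri_medium) * volume           # weak-phase contrast
darkfield = (ri - ri_medium) ** 2 * volume        # toy quadratic scattering strength
```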
+
+    Attributes
+    ----------
+    __distributed__: bool
+        Always `False` for `NonOverlapping`, indicating that this feature's
+        `.get()` method processes the entire input at once even if it is a
+        list, rather than distributing calls for each item of the list.
+
+    Methods
+    -------
+    `get(*_, min_distance, max_attempts, **kwargs) -> array`
+        Generate a list of non-overlapping 3D volumes.
+    `_check_non_overlapping(list_of_volumes) -> bool`
+        Check if all volumes in the list are non-overlapping.
+    `_check_bounding_cubes_non_overlapping(...) -> bool`
+        Check if two bounding cubes are non-overlapping.
+    `_get_overlapping_cube(...) -> list[int]`
+        Get the overlapping cube between two bounding cubes.
+    `_get_overlapping_volume(...) -> array`
+        Get the overlapping volume between a volume and a bounding cube.
+    `_check_volumes_non_overlapping(...) -> bool`
+        Check if two volumes are non-overlapping.
+    `_resample_volume_position(volume) -> ScatteredVolume`
+        Resample the position of a volume to avoid overlap.
+
+    Notes
+    -----
+    - This feature performs bounding cube checks first to quickly reject
+      obvious overlaps before voxel-level checks.
+    - If the bounding cubes overlap, precise voxel-based checks are
+      performed.
+
+    Examples
+    --------
+    >>> import deeptrack as dt
+
+    Define an ellipse scatterer with randomly positioned objects:
+
+    >>> import numpy as np
+    >>>
+    >>> scatterer = dt.Ellipse(
+    >>>     radius=13 * dt.units.pixels,
+    >>>     position=lambda: np.random.uniform(5, 115, size=2) * dt.units.pixels,
+    >>> )
+
+    Create multiple scatterers:
+
+    >>> scatterers = (scatterer ^ 8)
+
+    Define the optics and create the image with possible overlap:
+
+    >>> optics = dt.Fluorescence()
+    >>> im_with_overlap = optics(scatterers)
+    >>> im_with_overlap.store_properties()
+    >>> im_with_overlap_resolved = im_with_overlap()
+
+    Gather position from image:
+
+    >>> pos_with_overlap = np.array(
+    >>>     im_with_overlap_resolved.get_property(
+    >>>         "position",
+    >>>         get_one=False
+    >>>     )
+    >>> )
+
+    Enforce non-overlapping and create the image without overlap:
+
+    >>> non_overlapping_scatterers = dt.NonOverlapping(
+    ...     scatterers,
+    ...     min_distance=4,
+    ...
) + >>> im_without_overlap = optics(non_overlapping_scatterers) + >>> im_without_overlap.store_properties() + >>> im_without_overlap_resolved = im_without_overlap() + + Gather position from image: + + >>> pos_without_overlap = np.array( + >>> im_without_overlap_resolved.get_property( + >>> "position", + >>> get_one=False + >>> ) + >>> ) + + Create a figure with two subplots to visualize the difference: + + >>> import matplotlib.pyplot as plt + >>> + >>> fig, axes = plt.subplots(1, 2, figsize=(10, 5)) + >>> + >>> axes[0].imshow(im_with_overlap_resolved, cmap="gray") + >>> axes[0].scatter(pos_with_overlap[:,1],pos_with_overlap[:,0]) + >>> axes[0].set_title("Overlapping Objects") + >>> axes[0].axis("off") + >>> + >>> axes[1].imshow(im_without_overlap_resolved, cmap="gray") + >>> axes[1].scatter(pos_without_overlap[:,1],pos_without_overlap[:,0]) + >>> axes[1].set_title("Non-Overlapping Objects") + >>> axes[1].axis("off") + >>> plt.tight_layout() + >>> + >>> plt.show() + + Define function to calculate minimum distance: + + >>> def calculate_min_distance(positions): + >>> distances = [ + >>> np.linalg.norm(positions[i] - positions[j]) + >>> for i in range(len(positions)) + >>> for j in range(i + 1, len(positions)) + >>> ] + >>> return min(distances) + + Print minimum distances with and without overlap: + + >>> print(calculate_min_distance(pos_with_overlap)) + 10.768742383382174 + + >>> print(calculate_min_distance(pos_without_overlap)) + 30.82531120942446 + + """ + + __distributed__: bool = False + + def __init__( + self: NonOverlapping, + feature: Feature, + min_distance: float = 1, + max_attempts: int = 5, + max_iters: int = 100, + **kwargs: Any, + ): + """Initializes the NonOverlapping feature. + + Ensures that volumes are placed **non-overlapping** by iteratively + resampling their positions. If the maximum number of attempts is + exceeded, the feature regenerates the list of volumes. + + Parameters + ---------- + feature: Feature + The feature that generates the list of volumes. + min_distance: float, optional + The minimum separation distance **between volume edges**, in + pixels. It defaults to `1`. Negative values allow for partial + overlap. + max_attempts: int, optional + The maximum number of attempts to place the volumes without + overlap. It defaults to `5`. + max_iters: int, optional + The maximum number of resampling iterations per attempt. If + exceeded, a new list of volumes is generated. It defaults to `100`. + + """ + + super().__init__( + min_distance=min_distance, + max_attempts=max_attempts, + max_iters=max_iters, + **kwargs, + ) + self.feature = self.add_feature(feature, **kwargs) + + def get( + self: NonOverlapping, + *_: Any, + min_distance: float, + max_attempts: int, + max_iters: int, + **kwargs: Any, + ) -> list[np.ndarray]: + """Generates a list of non-overlapping 3D volumes within a defined + field of view (FOV). + + This method **iteratively** attempts to place volumes while ensuring + they maintain at least `min_distance` separation. If non-overlapping + placement is not achieved within `max_attempts`, a warning is issued, + and the best available configuration is returned. + + Parameters + ---------- + _: Any + Placeholder parameter, typically for an input image. + min_distance: float + The minimum required separation distance between volumes, in + pixels. + max_attempts: int + The maximum number of attempts to generate a valid non-overlapping + configuration. + max_iters: int + The maximum number of resampling iterations per attempt. 
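+            If exceeded, the feature is updated so that the next attempt
+            generates a fresh list of volumes.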
+        **kwargs: Any
+            Additional parameters that may be used by subclasses.
+
+        Returns
+        -------
+        list[np.ndarray]
+            A list of 3D volumes. If non-overlapping placement is
+            unsuccessful, the best available configuration is returned.
+
+        Warns
+        -----
+        UserWarning
+            If non-overlapping placement is **not** achieved within
+            `max_attempts`, suggesting parameter adjustments such as
+            increasing the FOV or reducing `min_distance`.
+
+        Notes
+        -----
+        - The placement process prioritizes bounding cube checks for
+          efficiency.
+        - If bounding cubes overlap, voxel-based overlap checks are
+          performed.
+
+        """
+
+        for _ in range(max_attempts):
+            list_of_volumes = self.feature()
+
+            if not isinstance(list_of_volumes, list):
+                list_of_volumes = [list_of_volumes]
+
+            for _ in range(max_iters):
+
+                list_of_volumes = [
+                    self._resample_volume_position(volume)
+                    for volume in list_of_volumes
+                ]
+
+                if self._check_non_overlapping(list_of_volumes):
+                    return list_of_volumes
+
+            # max_iters exceeded: update the feature so that the next
+            # attempt generates a new list of volumes.
+            self.feature.update()
+
+        warnings.warn(
+            "Non-overlapping placement could not be achieved. Consider "
+            "adjusting parameters: reduce object radius, increase FOV, "
+            "or decrease min_distance.",
+            UserWarning,
+        )
+        return list_of_volumes
+
+    def _check_non_overlapping(
+        self: NonOverlapping,
+        list_of_volumes: list[np.ndarray],
+    ) -> bool:
+        """Determine whether all volumes in the provided list are
+        non-overlapping.
+
+        This method verifies that the non-zero voxels of each 3D volume in
+        `list_of_volumes` are at least `min_distance` apart. It first checks
+        bounding boxes for early rejection and then examines actual voxel
+        overlap when necessary. Volumes are assumed to carry a `position`
+        property indicating their placement in 3D space.
+
+        Parameters
+        ----------
+        list_of_volumes: list[np.ndarray]
+            A list of 3D volumes to be checked for overlap. Each volume is
+            expected to carry a `position` property.
+
+        Returns
+        -------
+        bool
+            `True` if all volumes are non-overlapping, otherwise `False`.
+
+        Notes
+        -----
+        - If `min_distance` is negative, volumes are shrunk using isotropic
+          erosion before checking overlap.
+        - If `min_distance` is positive, volumes are padded and expanded
+          using isotropic dilation.
+        - Overlapping checks are first performed on bounding cubes for
+          efficiency.
+        - If bounding cubes overlap, voxel-level checks are performed.
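+        - The erosion/dilation operates on a discrete voxel grid, so the
+          effective margin may deviate from `min_distance` by up to a voxel.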
+
+        """
+
+        from deeptrack.scatterers import ScatteredVolume
+
+        # CropTight and Pad are not compatible with the torch backend.
+        from deeptrack.augmentations import CropTight, Pad
+        from deeptrack.optics import _get_position
+        from deeptrack.math import isotropic_erosion, isotropic_dilation
+
+        min_distance = self.min_distance()
+        crop = CropTight()
+
+        new_volumes = []
+
+        for volume in list_of_volumes:
+            arr = volume.array
+            mask = arr != 0
+
+            if min_distance < 0:
+                new_arr = isotropic_erosion(
+                    mask, -min_distance / 2, backend=self.get_backend()
+                )
+            else:
+                pad = Pad(
+                    px=[int(np.ceil(min_distance / 2))] * 6,
+                    keep_size=True,
+                )
+                new_arr = isotropic_dilation(
+                    pad(mask) != 0,
+                    min_distance / 2,
+                    backend=self.get_backend(),
+                )
+                new_arr = crop(new_arr)
+
+            if self.get_backend() == "torch":
+                new_arr = new_arr.to(dtype=arr.dtype)
+            else:
+                new_arr = new_arr.astype(arr.dtype)
+
+            new_volume = ScatteredVolume(
+                array=new_arr,
+                properties=volume.properties.copy(),
+            )
+
+            new_volumes.append(new_volume)
+
+        list_of_volumes = new_volumes
+        # The margin is now baked into the eroded/dilated volumes, so only
+        # direct contact needs to be excluded below.
+        min_distance = 1
+
+        # The position of the top-left corner of each volume
+        # (index (0, 0, 0)).
+        volume_positions_1 = [
+            _get_position(volume, mode="corner", return_z=True).astype(int)
+            for volume in list_of_volumes
+        ]
+
+        # The position of the bottom-right corner of each volume
+        # (index (-1, -1, -1)).
+        volume_positions_2 = [
+            p0 + np.array(v.shape)
+            for v, p0 in zip(list_of_volumes, volume_positions_1)
+        ]
+
+        # (x1, y1, z1, x2, y2, z2) for each volume.
+        volume_bounding_cube = [
+            [*p0, *p1]
+            for p0, p1 in zip(volume_positions_1, volume_positions_2)
+        ]
+
+        for i, j in itertools.combinations(range(len(list_of_volumes)), 2):
+
+            # If the bounding cubes do not overlap, the volumes do not
+            # overlap.
+            if self._check_bounding_cubes_non_overlapping(
+                volume_bounding_cube[i], volume_bounding_cube[j], min_distance
+            ):
+                continue
+
+            # If the bounding cubes overlap, get the overlapping region of
+            # each volume.
+            overlapping_cube = self._get_overlapping_cube(
+                volume_bounding_cube[i], volume_bounding_cube[j]
+            )
+            overlapping_volume_1 = self._get_overlapping_volume(
+                list_of_volumes[i].array,
+                volume_bounding_cube[i],
+                overlapping_cube,
+            )
+            overlapping_volume_2 = self._get_overlapping_volume(
+                list_of_volumes[j].array,
+                volume_bounding_cube[j],
+                overlapping_cube,
+            )
+
+            # If either overlapping region is empty, the volumes do not
+            # overlap (checked first for speed).
+            if (np.all(overlapping_volume_1 == 0)
+                    or np.all(overlapping_volume_2 == 0)):
+                continue
+
+            # Finally, check that the non-zero voxels of the volumes are at
+            # least min_distance apart.
+            if not self._check_volumes_non_overlapping(
+                overlapping_volume_1, overlapping_volume_2, min_distance
+            ):
+                return False
+
+        return True
+
+    def _check_bounding_cubes_non_overlapping(
+        self: NonOverlapping,
+        bounding_cube_1: list[int],
+        bounding_cube_2: list[int],
+        min_distance: float,
+    ) -> bool:
+        """Determine whether two 3D bounding cubes are non-overlapping.
+
+        This method checks whether the bounding cubes of two volumes are
+        **separated by at least** `min_distance` along **any** spatial axis.
+
+        Parameters
+        ----------
+        bounding_cube_1: list[int]
+            A list of six integers `[x1, y1, z1, x2, y2, z2]` representing
+            the first bounding cube.
+        bounding_cube_2: list[int]
+            A list of six integers `[x1, y1, z1, x2, y2, z2]` representing
+            the second bounding cube.
+ min_distance: float + The required **minimum separation distance** between the two + bounding cubes. + + Returns + ------- + bool + `True` if the bounding cubes are non-overlapping (separated by at + least `min_distance` along **at least one axis**), otherwise + `False`. + + Notes + ----- + - This function **only checks bounding cubes**, **not actual voxel + data**. + - If the bounding cubes are non-overlapping, the corresponding + **volumes are also non-overlapping**. + - This check is much **faster** than full voxel-based comparisons. + + """ + + # bounding_cube_1 and bounding_cube_2 are (x1, y1, z1, x2, y2, z2). + # Check that the bounding cubes are non-overlapping. + return ( + (bounding_cube_1[0] >= bounding_cube_2[3] + min_distance) or + (bounding_cube_2[0] >= bounding_cube_1[3] + min_distance) or + (bounding_cube_1[1] >= bounding_cube_2[4] + min_distance) or + (bounding_cube_2[1] >= bounding_cube_1[4] + min_distance) or + (bounding_cube_1[2] >= bounding_cube_2[5] + min_distance) or + (bounding_cube_2[2] >= bounding_cube_1[5] + min_distance) + ) + + def _get_overlapping_cube( + self: NonOverlapping, + bounding_cube_1: list[int], + bounding_cube_2: list[int], + ) -> list[int]: + """Computes the overlapping region between two 3D bounding cubes. + + This method calculates the coordinates of the intersection of two + axis-aligned bounding cubes, each represented as a list of six + integers: + + - `[x1, y1, z1]`: Coordinates of the **top-left-front** corner. + - `[x2, y2, z2]`: Coordinates of the **bottom-right-back** corner. + + The resulting overlapping region is determined by: + - Taking the **maximum** of the starting coordinates (`x1, y1, z1`). + - Taking the **minimum** of the ending coordinates (`x2, y2, z2`). + + If the cubes **do not** overlap, the resulting coordinates will not + form a valid cube (i.e., `x1 > x2`, `y1 > y2`, or `z1 > z2`). + + Parameters + ---------- + bounding_cube_1: list[int] + The first bounding cube, formatted as `[x1, y1, z1, x2, y2, z2]`. + bounding_cube_2: list[int] + The second bounding cube, formatted as `[x1, y1, z1, x2, y2, z2]`. + + Returns + ------- + list[int] + A list of six integers `[x1, y1, z1, x2, y2, z2]` representing the + overlapping bounding cube. If no overlap exists, the coordinates + will **not** define a valid cube. + + Notes + ----- + - This function does **not** check for valid input or ensure the + resulting cube is well-formed. + - If no overlap exists, downstream functions must handle the invalid + result. + + """ + + return [ + max(bounding_cube_1[0], bounding_cube_2[0]), + max(bounding_cube_1[1], bounding_cube_2[1]), + max(bounding_cube_1[2], bounding_cube_2[2]), + min(bounding_cube_1[3], bounding_cube_2[3]), + min(bounding_cube_1[4], bounding_cube_2[4]), + min(bounding_cube_1[5], bounding_cube_2[5]), + ] + + def _get_overlapping_volume( + self: NonOverlapping, + volume: np.ndarray, # 3D array. + bounding_cube: tuple[float, float, float, float, float, float], + overlapping_cube: tuple[float, float, float, float, float, float], + ) -> np.ndarray: + """Extracts the overlapping region of a 3D volume within the specified + overlapping cube. + + This method identifies and returns the subregion of `volume` that + lies within the `overlapping_cube`. The bounding information of the + volume is provided via `bounding_cube`. + + Parameters + ---------- + volume: np.ndarray + A 3D NumPy array representing the volume from which the + overlapping region is extracted. 
+ bounding_cube: tuple[float, float, float, float, float, float] + The bounding cube of the volume, given as a tuple of six floats: + `(x1, y1, z1, x2, y2, z2)`. The first three values define the + **top-left-front** corner, while the last three values define the + **bottom-right-back** corner. + overlapping_cube: tuple[float, float, float, float, float, float] + The overlapping region between the volume and another volume, + represented in the same format as `bounding_cube`. + + Returns + ------- + np.ndarray + A 3D NumPy array representing the portion of `volume` that + lies within `overlapping_cube`. If the overlap does not exist, + an empty array may be returned. + + Notes + ----- + - The method computes the relative indices of `overlapping_cube` + within `volume` by subtracting the bounding cube's starting + position. + - The extracted region is determined by integer indices, meaning + coordinates are implicitly **floored to integers**. + - If `overlapping_cube` extends beyond `volume` boundaries, the + returned subregion is **cropped** to fit within `volume`. + + """ + + # The position of the top left corner of the overlapping cube in the volume + overlapping_cube_position = np.array(overlapping_cube[:3]) - np.array( + bounding_cube[:3] + ) + + # The position of the bottom right corner of the overlapping cube in the volume + overlapping_cube_end_position = np.array( + overlapping_cube[3:] + ) - np.array(bounding_cube[:3]) + + # cast to int + overlapping_cube_position = overlapping_cube_position.astype(int) + overlapping_cube_end_position = overlapping_cube_end_position.astype(int) + + return volume[ + overlapping_cube_position[0] : overlapping_cube_end_position[0], + overlapping_cube_position[1] : overlapping_cube_end_position[1], + overlapping_cube_position[2] : overlapping_cube_end_position[2], + ] + + def _check_volumes_non_overlapping( + self: NonOverlapping, + volume_1: np.ndarray, + volume_2: np.ndarray, + min_distance: float, + ) -> bool: + """Determines whether the non-zero voxels in two 3D volumes are at + least `min_distance` apart. + + This method checks whether the active regions (non-zero voxels) in + `volume_1` and `volume_2` maintain a minimum separation of + `min_distance`. If the volumes differ in size, the positions of their + non-zero voxels are adjusted accordingly to ensure a fair comparison. + + Parameters + ---------- + volume_1: np.ndarray + A 3D NumPy array representing the first volume. + volume_2: np.ndarray + A 3D NumPy array representing the second volume. + min_distance: float + The minimum Euclidean distance required between any two non-zero + voxels in the two volumes. + + Returns + ------- + bool + `True` if all non-zero voxels in `volume_1` and `volume_2` are at + least `min_distance` apart, otherwise `False`. + + Notes + ----- + - This function assumes both volumes are correctly aligned within a + shared coordinate space. + - If the volumes are of different sizes, voxel positions are scaled + or adjusted for accurate distance measurement. + - Uses **Euclidean distance** for separation checking. + - If either volume is empty (i.e., no non-zero voxels), they are + considered non-overlapping. + + """ + + # Get the positions of the non-zero voxels of each volume. 
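+        # Worked example with two single-voxel volumes: non-zero voxels at
+        # (0, 0, 0) and (3, 4, 0) give a pairwise distance of 5.0, so the
+        # volumes count as non-overlapping for any min_distance < 5.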
+        if self.get_backend() == "torch":
+            positions_1 = torch.nonzero(volume_1, as_tuple=False)
+            positions_2 = torch.nonzero(volume_2, as_tuple=False)
+        else:
+            positions_1 = np.argwhere(volume_1)
+            positions_2 = np.argwhere(volume_2)
+
+        # If either set of positions is empty, the distance matrix below is
+        # empty and the check passes trivially, matching the documented
+        # behavior for empty volumes.
+
+        # If the volumes are not the same size, the positions of the
+        # non-zero voxels of each volume need to be scaled.
+        if volume_1.shape != volume_2.shape:
+            positions_1 = (
+                positions_1 * np.array(volume_2.shape)
+                / np.array(volume_1.shape)
+            )
+            positions_1 = positions_1.astype(int)
+
+        # Check that the non-zero voxels of the volumes are at least
+        # min_distance apart.
+        if self.get_backend() == "torch":
+            dist = torch.cdist(
+                positions_1.float(),
+                positions_2.float(),
+            )
+            return bool((dist > min_distance).all())
+        else:
+            return np.all(cdist(positions_1, positions_2) > min_distance)
+
+    def _resample_volume_position(
+        self: NonOverlapping,
+        volume: ScatteredVolume,
+    ) -> ScatteredVolume:
+        """Resample the position of a 3D volume using its internal position
+        sampler.
+
+        This method updates the `position` property of the given `volume` by
+        drawing a new position from the `_position_sampler` stored in the
+        volume's `properties`. If the sampled position is a `Quantity`, it is
+        converted to pixel units.
+
+        Parameters
+        ----------
+        volume: ScatteredVolume
+            The 3D volume whose position is to be resampled. Its
+            `properties` dict should contain `position` and
+            `_position_sampler` keys.
+
+        Returns
+        -------
+        ScatteredVolume
+            The same input volume with its `position` property updated to
+            the newly sampled value.
+
+        Notes
+        -----
+        - The `_position_sampler` function is expected to return a **tuple
+          of three floats** (e.g., `(x, y, z)`).
+        - If the sampled position is a `Quantity`, it is converted to
+          pixels.
+        - The position is updated **only** when both `position` and
+          `_position_sampler` keys are present in `volume.properties`.
+
+        """
+
+        pdict = volume.properties
+        if "position" in pdict and "_position_sampler" in pdict:
+            new_position = pdict["_position_sampler"]()
+            if isinstance(new_position, Quantity):
+                new_position = new_position.to("pixel").magnitude
+            pdict["position"] = new_position
+
+        return volume
+
+
+class SampleToMasks(Feature):
+    """Create a mask from a list of images.
+
+    This feature applies a transformation function to each input image and
+    merges the resulting masks into a single multi-layer image. Each input
+    image must have a `position` property that determines its placement
+    within the final mask. When used with scatterers, the `voxel_size`
+    property must be provided for correct object sizing.
+
+    Parameters
+    ----------
+    transformation_function: Callable[[np.ndarray | torch.Tensor], np.ndarray | torch.Tensor]
+        A function that transforms each input image into a mask with
+        `number_of_masks` layers.
+    number_of_masks: PropertyLike[int], optional
+        The number of mask layers to generate. Default is 1.
+    output_region: PropertyLike[tuple[int, int, int, int]], optional
+        The size and position of the output mask, typically aligned with
+        `optics.output_region`.
+    merge_method: PropertyLike[str | Callable | list[str | Callable]], optional
+        Method for merging individual masks into the final image. Can be:
+        - "add" (default): Sum the masks.
+        - "overwrite": Later masks overwrite earlier masks.
+        - "or": Combine masks using a logical OR operation.
+        - "mul": Multiply masks.
+        - Function: Custom function taking two images and merging them.
+
+    **kwargs: dict[str, Any]
+        Additional keyword arguments passed to the parent `Feature` class.
+
+    Methods
+    -------
+    `get(image, transformation_function, **kwargs) -> np.ndarray`
+        Applies the transformation function to the input image.
+    `_process_and_get(images, **kwargs) -> np.ndarray`
+        Processes a list of images and generates a multi-layer mask.
+
+    Returns
+    -------
+    np.ndarray
+        The final mask image with the specified number of layers.
+
+    Raises
+    ------
+    ValueError
+        If `merge_method` is invalid.
+
+    Examples
+    --------
+    >>> import deeptrack as dt
+
+    Define number of particles:
+
+    >>> n_particles = 12
+
+    Define optics and particles:
+
+    >>> import numpy as np
+    >>>
+    >>> optics = dt.Fluorescence(output_region=(0, 0, 64, 64))
+    >>> particle = dt.PointParticle(
+    >>>     position=lambda: np.random.uniform(5, 55, size=2),
+    >>> )
+    >>> particles = particle ^ n_particles
+
+    Define pipelines:
+
+    >>> sim_im_pip = optics(particles)
+    >>> sim_mask_pip = particles >> dt.SampleToMasks(
+    ...     lambda: lambda particles: particles > 0,
+    ...     output_region=optics.output_region,
+    ...     merge_method="or",
+    ... )
+    >>> pipeline = sim_im_pip & sim_mask_pip
+    >>> pipeline.store_properties()
+
+    Generate image and mask:
+
+    >>> image, mask = pipeline.update()()
+
+    Get particle positions:
+
+    >>> positions = np.array(image.get_property("position", get_one=False))
+
+    Visualize results:
+
+    >>> import matplotlib.pyplot as plt
+    >>>
+    >>> plt.subplot(1, 2, 1)
+    >>> plt.imshow(image, cmap="gray")
+    >>> plt.title("Original Image")
+    >>> plt.subplot(1, 2, 2)
+    >>> plt.imshow(mask, cmap="gray")
+    >>> plt.scatter(positions[:, 1], positions[:, 0], c="y", marker="x", s=50)
+    >>> plt.title("Mask")
+    >>> plt.show()
+
+    """
+
+    def __init__(
+        self: Feature,
+        transformation_function: Callable[
+            [np.ndarray | torch.Tensor], np.ndarray | torch.Tensor
+        ],
+        number_of_masks: PropertyLike[int] = 1,
+        output_region: PropertyLike[tuple[int, int, int, int]] = None,
+        merge_method: PropertyLike[str | Callable | list[str | Callable]] = "add",
+        **kwargs: Any,
+    ):
+        """Initialize the SampleToMasks feature.
+
+        Parameters
+        ----------
+        transformation_function: Callable
+            Function to transform input arrays into masks.
+        number_of_masks: PropertyLike[int], optional
+            Number of mask layers. Default is 1.
+        output_region: PropertyLike[tuple[int, int, int, int]], optional
+            Output region of the mask. Default is None.
+        merge_method: PropertyLike[str | Callable | list[str | Callable]], optional
+            Method to merge masks. Defaults to "add".
+        **kwargs: dict[str, Any]
+            Additional keyword arguments passed to the parent class.
+
+        """
+
+        super().__init__(
+            transformation_function=transformation_function,
+            number_of_masks=number_of_masks,
+            output_region=output_region,
+            merge_method=merge_method,
+            **kwargs,
+        )
+
+    def get(
+        self: Feature,
+        image: ScatteredVolume,
+        transformation_function: Callable[
+            [np.ndarray | torch.Tensor], np.ndarray | torch.Tensor
+        ],
+        **kwargs: Any,
+    ) -> np.ndarray:
+        """Apply the transformation function to a single image.
+
+        Parameters
+        ----------
+        image: ScatteredVolume
+            The input scatterer volume.
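+            Its `array` attribute is what is passed to the transformation
+            function.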
+        transformation_function: Callable
+            Function to transform the underlying array.
+        **kwargs: dict[str, Any]
+            Additional parameters.
+
+        Returns
+        -------
+        np.ndarray or torch.Tensor
+            The transformed array.
+
+        """
+
+        return transformation_function(image.array)
+
+    def _process_and_get(
+        self: Feature,
+        images: list[ScatteredVolume] | ScatteredVolume,
+        **kwargs: Any,
+    ) -> np.ndarray:
+        """Process a list of images and generate a multi-layer mask.
+
+        Parameters
+        ----------
+        images: ScatteredVolume or list[ScatteredVolume]
+            List of input scatterer volumes or a single volume.
+        **kwargs: dict[str, Any]
+            Additional parameters including `output_region`,
+            `number_of_masks`, and `merge_method`.
+
+        Returns
+        -------
+        np.ndarray or torch.Tensor
+            The final mask image.
+
+        """
+
+        # Apply the transformation function to each input image.
+        list_of_labels = super()._process_and_get(images, **kwargs)
+
+        from deeptrack.scatterers import ScatteredVolume
+
+        for idx, (label, image) in enumerate(zip(list_of_labels, images)):
+            list_of_labels[idx] = ScatteredVolume(
+                array=label,
+                properties=image.properties.copy(),
+            )
+
+        # Create an empty output image.
+        output_region = kwargs["output_region"]
+        output = xp.zeros(
+            (
+                output_region[2] - output_region[0],
+                output_region[3] - output_region[1],
+                kwargs["number_of_masks"],
+            ),
+            dtype=list_of_labels[0].array.dtype,
+        )
+
+        from deeptrack.optics import _get_position
+
+        # Merge masks into the output.
+        for volume in list_of_labels:
+            label = volume.array
+            position = _get_position(volume)
+
+            p0 = xp.round(position - xp.asarray(output_region[0:2]))
+            p0 = p0.astype(xp.int64)
+
+            if xp.any(p0 > xp.asarray(output.shape[:2])) or \
+                    xp.any(p0 + xp.asarray(label.shape[:2]) < 0):
+                continue
+
+            crop_x = (-xp.minimum(p0[0], 0)).item()
+            crop_y = (-xp.minimum(p0[1], 0)).item()
+
+            crop_x_end = int(
+                label.shape[0]
+                - np.max([p0[0] + label.shape[0] - output.shape[0], 0])
+            )
+            crop_y_end = int(
+                label.shape[1]
+                - np.max([p0[1] + label.shape[1] - output.shape[1], 0])
+            )
+
+            labelarg = label[crop_x:crop_x_end, crop_y:crop_y_end, :]
+
+            p0[0] = np.max([p0[0], 0])
+            p0[1] = np.max([p0[1], 0])
+
+            p0 = p0.astype(int)
+
+            output_slice = output[
+                p0[0] : p0[0] + labelarg.shape[0],
+                p0[1] : p0[1] + labelarg.shape[1],
+            ]
+
+            for label_index in range(kwargs["number_of_masks"]):
+
+                if isinstance(kwargs["merge_method"], list):
+                    merge = kwargs["merge_method"][label_index]
+                else:
+                    merge = kwargs["merge_method"]
+
+                if merge == "add":
+                    output[
+                        p0[0] : p0[0] + labelarg.shape[0],
+                        p0[1] : p0[1] + labelarg.shape[1],
+                        label_index,
+                    ] += labelarg[..., label_index]
+
+                elif merge == "overwrite":
+                    nonzero = labelarg[..., label_index] != 0
+                    output_slice[nonzero, label_index] = labelarg[
+                        nonzero, label_index
+                    ]
+                    output[
+                        p0[0] : p0[0] + labelarg.shape[0],
+                        p0[1] : p0[1] + labelarg.shape[1],
+                        label_index,
+                    ] = output_slice[..., label_index]
+
+                elif merge == "or":
+                    output[
+                        p0[0] : p0[0] + labelarg.shape[0],
+                        p0[1] : p0[1] + labelarg.shape[1],
+                        label_index,
+                    ] = xp.logical_or(
+                        output_slice[..., label_index] != 0,
+                        labelarg[..., label_index] != 0,
+                    )
+
+                elif merge == "mul":
+                    output[
+                        p0[0] : p0[0] + labelarg.shape[0],
+                        p0[1] : p0[1] + labelarg.shape[1],
+                        label_index,
+                    ] *= labelarg[..., label_index]
+
+                else:
+                    # No string match; assume a custom merge callable.
+                    output[
+                        p0[0] : p0[0] + labelarg.shape[0],
+                        p0[1] : p0[1] + labelarg.shape[1],
+                        label_index,
+                    ] = merge(
+                        output_slice[..., label_index],
+                        labelarg[..., label_index],
+                    )
+
+        return output
+
+
 #TODO ***??*** revise _get_position - torch, typing, docstring, unit test
 def _get_position(
-    image: Image,
+    scatterer: ScatteredObject,
     mode: str = "corner",
     return_z: bool = False,
 ) -> np.ndarray:
@@ -1826,26 +3051,23 @@ def _get_position(

     num_outputs = 2 + return_z

-    if mode == "corner" and image.size > 0:
+    if mode == "corner" and scatterer.array.size > 0:
         import scipy.ndimage

-        image = image.to_numpy()
-
-        shift = scipy.ndimage.center_of_mass(np.abs(image))
+        shift = scipy.ndimage.center_of_mass(np.abs(scatterer.array))

         if np.isnan(shift).any():
-            shift = np.array(image.shape) / 2
+            shift = np.array(scatterer.array.shape) / 2
     else:
         shift = np.zeros((num_outputs))

-    position = np.array(image.get_property("position", default=None))
+    position = np.array(scatterer.get_property("position", default=None))

     if position is None:
         return position

     scale = np.array(get_active_scale())
-
     if len(position) == 3:
         position = position * scale + 0.5 * (scale - 1)
         if return_z:
@@ -1856,7 +3078,7 @@ def _get_position(
     elif len(position) == 2:
         if return_z:
             outp = (
-                np.array([position[0], position[1], image.get_property("z", default=0)])
+                np.array([position[0], position[1], scatterer.get_property("z", default=0)])
                 * scale
                 - shift
                 + 0.5 * (scale - 1)
@@ -1868,6 +3090,58 @@ def _get_position(
     return position


+def _bilinear_interpolate_numpy(
+    scatterer: np.ndarray, x_off: float, y_off: float
+) -> np.ndarray:
+    """Apply bilinear subpixel interpolation in the x-y plane (NumPy)."""
+    kernel = np.array(
+        [
+            [0.0, 0.0, 0.0],
+            [0.0, (1 - x_off) * (1 - y_off), (1 - x_off) * y_off],
+            [0.0, x_off * (1 - y_off), x_off * y_off],
+        ]
+    )
+    out = np.zeros_like(scatterer)
+    for z in range(scatterer.shape[2]):
+        if np.iscomplexobj(scatterer):
+            out[:, :, z] = (
+                convolve(np.real(scatterer[:, :, z]), kernel, mode="constant")
+                + 1j
+                * convolve(np.imag(scatterer[:, :, z]), kernel, mode="constant")
+            )
+        else:
+            out[:, :, z] = convolve(scatterer[:, :, z], kernel, mode="constant")
+    return out
+
+
+def _bilinear_interpolate_torch(
+    scatterer: torch.Tensor, x_off: float, y_off: float
+) -> torch.Tensor:
+    """Apply bilinear subpixel interpolation in the x-y plane (Torch).
+
+    Uses grid_sample for autograd-friendly interpolation.
+    """
+    H, W, D = scatterer.shape
+
+    # Normalized shifts in [-1, 1].
+    x_shift = 2 * x_off / (W - 1)
+    y_shift = 2 * y_off / (H - 1)
+
+    yy, xx = torch.meshgrid(
+        torch.linspace(-1, 1, H, device=scatterer.device, dtype=scatterer.dtype),
+        torch.linspace(-1, 1, W, device=scatterer.device, dtype=scatterer.dtype),
+        indexing="ij",
+    )
+    grid = torch.stack((xx + x_shift, yy + y_shift), dim=-1)  # (H, W, 2)
+    grid = grid.unsqueeze(0).repeat(D, 1, 1, 1)  # (D, H, W, 2)
+
+    inp = scatterer.permute(2, 0, 1).unsqueeze(1)  # (D, 1, H, W)
+
+    out = F.grid_sample(inp, grid, mode="bilinear",
+                        padding_mode="zeros", align_corners=True)
+    return out.squeeze(1).permute(1, 2, 0)  # (H, W, D)
+
+
 #TODO ***??*** revise _create_volume - torch, typing, docstring, unit test
 def _create_volume(
     list_of_scatterers: list,
@@ -1903,6 +3177,12 @@ def _create_volume(
         Spatial limits of the volume.

     """
+    # contrast_type = kwargs.get("contrast_type", None)
+    # if contrast_type is None:
+    #     raise RuntimeError(
+    #         "_create_volume requires a contrast_type "
+    #         "(e.g. 'intensity' or 'refractive_index')"
+    #     )

     if not isinstance(list_of_scatterers, list):
         list_of_scatterers = [list_of_scatterers]
@@ -1927,24 +3207,28 @@ def _create_volume(

     # This accounts for upscale doing AveragePool instead of SumPool. This is
     # a bit of a hack, but it works for now.
-    fudge_factor = scale[0] * scale[1] / scale[2]
+    # fudge_factor = scale[0] * scale[1] / scale[2]

     for scatterer in list_of_scatterers:
-        position = _get_position(scatterer, mode="corner", return_z=True)
-        if scatterer.get_property("intensity", None) is not None:
-            intensity = scatterer.get_property("intensity")
-            scatterer_value = intensity * fudge_factor
-        elif scatterer.get_property("refractive_index", None) is not None:
-            refractive_index = scatterer.get_property("refractive_index")
-            scatterer_value = (
-                refractive_index - refractive_index_medium
-            )
-        else:
-            scatterer_value = scatterer.get_property("value")
+        # if contrast_type == "intensity":
+        #     value = scatterer.get_property("intensity", None)
+        #     if value is None:
+        #         raise ValueError("Scatterer has no intensity.")
+        #     scatterer_value = value
+
+        # elif contrast_type == "refractive_index":
+        #     ri = scatterer.get_property("refractive_index", None)
+        #     if ri is None:
+        #         raise ValueError("Scatterer has no refractive_index.")
+        #     scatterer_value = ri - refractive_index_medium
+
+        # else:
+        #     raise RuntimeError(f"Unknown contrast_type: {contrast_type}")

-        scatterer = scatterer * scatterer_value
+        # # Scale the array accordingly
+        # scatterer.array = scatterer.array * scatterer_value
+
+        # The corner position is still required below for the limits and
+        # field-of-view checks.
+        position = _get_position(scatterer, mode="corner", return_z=True)

         if limits is None:
             limits = np.zeros((3, 2), dtype=np.int32)
@@ -1952,26 +3236,25 @@
             limits[:, 1] = np.floor(position).astype(np.int32) + 1

         if (
-            position[0] + scatterer.shape[0] < OR[0]
+            position[0] + scatterer.array.shape[0] < OR[0]
             or position[0] > OR[2]
-            or position[1] + scatterer.shape[1] < OR[1]
+            or position[1] + scatterer.array.shape[1] < OR[1]
             or position[1] > OR[3]
         ):
             continue

-        padded_scatterer = Image(
-            np.pad(
-                scatterer,
+        # Pad scatterer to avoid edge effects during interpolation.
+        padded_scatterer_arr = np.pad(  # Use Pad instead and make it torch-compatible?
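+            # NOTE: np.pad assumes a NumPy array here; a torch array would
+            # need torch.nn.functional.pad (or a backend-agnostic dt.Pad).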
+ scatterer.array, [(2, 2), (2, 2), (2, 2)], "constant", constant_values=0, ) - ) - padded_scatterer.merge_properties_from(scatterer) - - scatterer = padded_scatterer - position = _get_position(scatterer, mode="corner", return_z=True) - shape = np.array(scatterer.shape) + padded_scatterer = ScatteredVolume( + array=padded_scatterer_arr, properties=scatterer.properties.copy(), + ) + position = _get_position(padded_scatterer, mode="corner", return_z=True) + shape = np.array(padded_scatterer.array.shape) if position is None: RuntimeWarning( @@ -1980,36 +3263,20 @@ def _create_volume( ) continue - splined_scatterer = np.zeros_like(scatterer) - x_off = position[0] - np.floor(position[0]) y_off = position[1] - np.floor(position[1]) - kernel = np.array( - [ - [0, 0, 0], - [0, (1 - x_off) * (1 - y_off), (1 - x_off) * y_off], - [0, x_off * (1 - y_off), x_off * y_off], - ] - ) - - for z in range(scatterer.shape[2]): - if splined_scatterer.dtype == complex: - splined_scatterer[:, :, z] = ( - convolve( - np.real(scatterer[:, :, z]), kernel, mode="constant" - ) - + convolve( - np.imag(scatterer[:, :, z]), kernel, mode="constant" - ) - * 1j - ) - else: - splined_scatterer[:, :, z] = convolve( - scatterer[:, :, z], kernel, mode="constant" - ) + + if isinstance(padded_scatterer.array, np.ndarray): # get_backend is a method of Features and not exposed + splined_scatterer = _bilinear_interpolate_numpy(padded_scatterer.array, x_off, y_off) + elif isinstance(padded_scatterer.array, torch.Tensor): + splined_scatterer = _bilinear_interpolate_torch(padded_scatterer.array, x_off, y_off) + else: + raise TypeError( + f"Unsupported array type {type(padded_scatterer.array)}. " + "Expected np.ndarray or torch.Tensor." + ) - scatterer = splined_scatterer position = np.floor(position) new_limits = np.zeros(limits.shape, dtype=np.int32) for i in range(3): @@ -2038,7 +3305,8 @@ def _create_volume( within_volume_position = position - limits[:, 0] - # NOTE: Maybe shouldn't be additive. + # NOTE: Maybe shouldn't be ONLY additive. 
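+        # A mean or max reduction, for example, would keep two overlapping
+        # scatterers from doubling the local contrast, which the current
+        # sum does.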
+        # Give options: sum (default), but also mean, max, or min.
         volume[
             int(within_volume_position[0]) : int(within_volume_position[0]
                                                  + shape[0]),
             int(within_volume_position[1]) : int(within_volume_position[1]
                                                  + shape[1]),
             int(within_volume_position[2]) : int(within_volume_position[2]
                                                  + shape[2]),
-        ] += scatterer
+        ] += splined_scatterer

-    return volume, limits
+    return volume, limits
\ No newline at end of file
diff --git a/deeptrack/scatterers.py b/deeptrack/scatterers.py
index 04a7c5ea..b7e1b70a 100644
--- a/deeptrack/scatterers.py
+++ b/deeptrack/scatterers.py
@@ -166,6 +166,7 @@
 import numpy as np
 from numpy.typing import NDArray
 from pint import Quantity
+from dataclasses import dataclass, field

 from deeptrack.holography import get_propagation_matrix
 from deeptrack.backend.units import (
@@ -174,12 +175,14 @@
     get_active_voxel_size,
 )
 from deeptrack.backend import mie
+from deeptrack.math import AveragePooling
 from deeptrack.features import Feature, MERGE_STRATEGY_APPEND
-from deeptrack.image import pad_image_to_fft, Image
+from deeptrack.image import pad_image_to_fft
 from deeptrack.types import ArrayLike
 from deeptrack import units_registry as u

+
 __all__ = [
     "Scatterer",
     "PointParticle",
@@ -238,7 +241,7 @@ class Scatterer(Feature):

     """

-    __list_merge_strategy__ = MERGE_STRATEGY_APPEND
+    __list_merge_strategy__ = MERGE_STRATEGY_APPEND  # TODO: clarify why this is needed.
     __distributed__ = False
     __conversion_table__ = ConversionTable(
         position=(u.pixel, u.pixel),
@@ -258,11 +261,11 @@ def __init__(
         **kwargs,
     ) -> None:
         # Ignore warning to help with comparison with arrays.
-        if upsample is not 1:  # noqa: F632
-            warnings.warn(
-                f"Setting upsample != 1 is deprecated. "
-                f"Please, instead use dt.Upscale(f, factor={upsample})"
-            )
+        # if upsample != 1:  # noqa: F632
+        #     warnings.warn(
+        #         f"Setting upsample != 1 is deprecated. "
+        #         f"Please, instead use dt.Upscale(f, factor={upsample})"
+        #     )

         self._processed_properties = False

@@ -278,6 +281,21 @@
             **kwargs,
         )

+    def _antialias_volume(self, volume, factor: int):
+        """Geometry-only supersampling anti-aliasing.
+
+        Assumes `volume` was generated on a grid oversampled by `factor`
+        and downsamples it back by average pooling.
+        """
+        if factor == 1:
+            return volume
+
+        # Average pooling conserves fractional occupancy.
+        return AveragePooling(factor)(volume)
+
+
     def _process_properties(
         self,
         properties: dict
@@ -296,7 +314,7 @@ def _process_and_get(
         self,
         *args,
         upsample_axes=None,
         crop_empty=True,
         **kwargs
-    ) -> list[Image] | list[np.ndarray]:
+    ) -> list[np.ndarray]:
         # Post processes the created object to handle upsampling,
         # as well as cropping empty slices.
         if not self._processed_properties:
@@ -307,16 +325,31 @@
                 + "Optics.upscale != 1."
             )

-        voxel_size = get_active_voxel_size()
-
         # Calls parent _process_and_get.
-        new_image = super()._process_and_get(
+        voxel_size = np.asarray(get_active_voxel_size(), float)
+
+        apply_supersampling = upsample > 1 and isinstance(self, VolumeScatterer)
+
+        if upsample > 1 and not apply_supersampling:
+            warnings.warn(
+                "Geometry supersampling (upsample) is ignored for "
+                "FieldScatterers.",
+                UserWarning,
+            )
+
+        if apply_supersampling:
+            voxel_size /= float(upsample)
+
+        new_image = super(Scatterer, self)._process_and_get(
             *args,
             voxel_size=voxel_size,
             upsample=upsample,
             **kwargs,
-        )
-        new_image = new_image[0]
+        )[0]
+
+        if apply_supersampling:
+            new_image = self._antialias_volume(new_image, factor=upsample)

         if new_image.size == 0:
             warnings.warn(
@@ -333,32 +366,35 @@
             new_image = new_image[:, ~np.all(new_image == 0, axis=(0, 2))]
             new_image = new_image[:, :, ~np.all(new_image == 0, axis=(0, 1))]

-        return [Image(new_image)]
+        return [self._wrap_output(new_image, kwargs)]

-    def _no_wrap_format_input(
-        self,
-        *args,
-        **kwargs
-    ) -> list:
-        return self._image_wrapped_format_input(*args, **kwargs)
+    def _wrap_output(self, array, props):
+        raise NotImplementedError(
+            f"{self.__class__.__name__} must implement _wrap_output()"
+        )

-    def _no_wrap_process_and_get(
-        self,
-        *args,
-        **feature_input
-    ) -> list:
-        return self._image_wrapped_process_and_get(*args, **feature_input)

-    def _no_wrap_process_output(
-        self,
-        *args,
-        **feature_input
-    ) -> list:
-        return self._image_wrapped_process_output(*args, **feature_input)
+class VolumeScatterer(Scatterer):
+    """Abstract scatterer producing ScatteredVolume outputs."""
+
+    def _wrap_output(self, array, props) -> ScatteredVolume:
+        return ScatteredVolume(
+            array=array,
+            properties=props.copy(),
+        )
+
+
+class FieldScatterer(Scatterer):
+    """Abstract scatterer producing ScatteredField outputs."""
+
+    def _wrap_output(self, array, props) -> ScatteredField:
+        return ScatteredField(
+            array=array,
+            properties=props.copy(),
+        )


 #TODO ***??*** revise PointParticle - torch, typing, docstring, unit test
-class PointParticle(Scatterer):
+class PointParticle(VolumeScatterer):
     """Generate a diffraction-limited point particle.

     A point particle is approximated by the size of a single pixel or voxel.
@@ -389,12 +425,12 @@ def __init__(
         """
         """
-
+        kwargs.pop("upsample", None)
         super().__init__(upsample=1, upsample_axes=(), **kwargs)

     def get(
         self: PointParticle,
-        image: Image | np.ndarray,
+        image: np.ndarray,
         **kwarg: Any,
     ) -> NDArray[Any] | torch.Tensor:
         """Evaluate and return the scatterer volume."""
@@ -405,7 +441,7 @@ def get(


 #TODO ***??*** revise Ellipse - torch, typing, docstring, unit test
-class Ellipse(Scatterer):
+class Ellipse(VolumeScatterer):
     """Generates an elliptical disk scatterer

     Parameters
@@ -441,6 +477,7 @@ class Ellipse(Scatterer):

     """

+
     __conversion_table__ = ConversionTable(
         radius=(u.meter, u.meter),
         rotation=(u.radian, u.radian),
@@ -519,7 +556,7 @@ def get(


 #TODO ***??*** revise Sphere - torch, typing, docstring, unit test
-class Sphere(Scatterer):
+class Sphere(VolumeScatterer):
     """Generates a spherical scatterer

     Parameters
@@ -559,7 +596,7 @@ def __init__(

     def get(
         self,
-        image: Image | np.ndarray,
+        image: np.ndarray,
         radius: float,
         voxel_size: float,
         **kwargs
@@ -584,7 +621,7 @@ def get(


 #TODO ***??*** revise Ellipsoid - torch, typing, docstring, unit test
-class Ellipsoid(Scatterer):
+class Ellipsoid(VolumeScatterer):
     """Generates an ellipsoidal scatterer

     Parameters
@@ -694,7 +731,7 @@ def _process_properties(

     def get(
         self,
-        image: Image | np.ndarray,
+        image: np.ndarray,
         radius: float,
         rotation: ArrayLike[float] | float,
         voxel_size: float,
@@ -741,7 +778,7 @@ def get(


 #TODO ***??*** revise MieScatterer - torch, typing, docstring, unit test
-class MieScatterer(Scatterer):
+class MieScatterer(FieldScatterer):
     """Base implementation of a Mie particle.

     New Mie-theory scatterers can be implemented by extending this class, and
@@ -826,6 +863,7 @@ class MieScatterer(Scatterer):

     """

+
     __conversion_table__ = ConversionTable(
         radius=(u.meter, u.meter),
         polarization_angle=(u.radian, u.radian),
@@ -856,6 +894,7 @@ def __init__(
         illumination_angle: float=0,
         amp_factor: float=1,
         phase_shift_correction: bool=False,
+        # pupil: ArrayLike=[],  # Daniel
         **kwargs,
     ) -> None:
         if polarization_angle is not None:
@@ -864,11 +903,10 @@
                 "Please use input_polarization instead"
             )
             input_polarization = polarization_angle
-
         kwargs.pop("is_field", None)
         kwargs.pop("crop_empty", None)
         super().__init__(
-            is_field=True,
+            is_field=True,  # remove
             crop_empty=False,
             L=L,
             offset_z=offset_z,
@@ -889,6 +927,7 @@
             illumination_angle=illumination_angle,
             amp_factor=amp_factor,
             phase_shift_correction=phase_shift_correction,
+            # pupil=pupil,  # Daniel
             **kwargs,
         )

@@ -1014,7 +1053,8 @@ def get_plane_in_polar_coords(
         shape: int,
         voxel_size: ArrayLike[float],
         plane_position: float,
-        illumination_angle: float
+        illumination_angle: float,
+        # k: float,  # Daniel
     ) -> tuple[float, float, float, float]:
         """Computes the coordinates of the plane in polar form."""

@@ -1027,15 +1067,24 @@
         R2_squared = X ** 2 + Y ** 2
         R3 = np.sqrt(R2_squared + Z ** 2)  # Might be +z instead of -z.

+        # # DANIEL
+        # Q = np.sqrt(R2_squared)/voxel_size[0]**2*2*np.pi/shape[0]
+        # # is dimensionally ok?
+        # sin_theta=Q/(k)
+        # pupil_mask=sin_theta<1
+        # cos_theta=np.zeros(sin_theta.shape)
+        # cos_theta[pupil_mask]=np.sqrt(1-sin_theta[pupil_mask]**2)

         # Get the angles.
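+        # cos_theta is the cosine of the polar angle of each plane point as
+        # seen from the scatterer; phi is the azimuthal angle in the plane.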
cos_theta = Z / R3 + illumination_cos_theta = ( np.cos(np.arccos(cos_theta) + illumination_angle) ) phi = np.arctan2(Y, X) - return R3, cos_theta, illumination_cos_theta, phi + return R3, cos_theta, illumination_cos_theta, phi#, pupil_mask # Daniel def get( self, @@ -1060,6 +1109,7 @@ def get( illumination_angle: float, amp_factor: float, phase_shift_correction: bool, + # pupil: ArrayLike, # Daniel **kwargs, ) -> ArrayLike[float]: """Abstract method to initialize the Mie scatterer""" @@ -1067,8 +1117,9 @@ def get( # Get size of the output. xSize, ySize = self.get_xy_size(output_region, padding) voxel_size = get_active_voxel_size() + scale = get_active_scale() arr = pad_image_to_fft(np.zeros((xSize, ySize))).astype(complex) - position = np.array(position) * voxel_size[: len(position)] + position = np.array(position) * scale[: len(position)] * voxel_size[: len(position)] pupil_physical_size = working_distance * np.tan(collection_angle) * 2 @@ -1076,7 +1127,10 @@ def get( ratio = offset_z / (working_distance - z) - # Position of pbjective relative particle. + # Wave vector. + k = 2 * np.pi / wavelength * refractive_index_medium + + # Position of objective relative particle. relative_position = np.array( ( position_objective[0] - position[0], @@ -1085,12 +1139,13 @@ def get( ) ) - # Get field evaluation plane at offset_z. + # Get field evaluation plane at offset_z. # , pupil_mask # Daniel R3_field, cos_theta_field, illumination_angle_field, phi_field =\ self.get_plane_in_polar_coords( arr.shape, voxel_size, relative_position * ratio, - illumination_angle + illumination_angle, + # k # Daniel ) cos_phi_field, sin_phi_field = np.cos(phi_field), np.sin(phi_field) @@ -1108,7 +1163,7 @@ def get( sin_phi_field / ratio ) - # If the beam is within the pupil. + # If the beam is within the pupil. Remove if Daniel pupil_mask = (x_farfield - position_objective[0]) ** 2 + ( y_farfield - position_objective[1] ) ** 2 < (pupil_physical_size / 2) ** 2 @@ -1146,9 +1201,6 @@ def get( * illumination_angle_field ) - # Wave vector. - k = 2 * np.pi / wavelength * refractive_index_medium - # Harmonics. A, B = coefficients(L) PI, TAU = mie.harmonics(illumination_angle_field, L) @@ -1165,12 +1217,15 @@ def get( [E[i] * B[i] * PI[i] + E[i] * A[i] * TAU[i] for i in range(0, L)] ) + # Daniel + # arr[pupil_mask] = (S2 * S2_coef + S1 * S1_coef)/amp_factor arr[pupil_mask] = ( -1j / (k * R3_field) * np.exp(1j * k * R3_field) * (S2 * S2_coef + S1 * S1_coef) ) / amp_factor + # For phase shift correction (a multiplication of the field # by exp(1j * k * z)). @@ -1188,15 +1243,23 @@ def get( -mask.shape[1] // 2 : mask.shape[1] // 2, ] mask = np.exp(-0.5 * (x ** 2 + y ** 2) / ((sigma) ** 2)) - arr = arr * mask + # Not sure if needed... 
CM
+        # if len(pupil)>0:
+        #     c_pix=[arr.shape[0]//2,arr.shape[1]//2]
+
+        #     arr[c_pix[0]-pupil.shape[0]//2:c_pix[0]+pupil.shape[0]//2,c_pix[1]-pupil.shape[1]//2:c_pix[1]+pupil.shape[1]//2]*=pupil
+
+        # Daniel
+        # fourier_field = -np.fft.ifft2(np.fft.fftshift(np.fft.fft2(np.fft.fftshift(arr))))

         fourier_field = np.fft.fft2(arr)

         propagation_matrix = get_propagation_matrix(
             fourier_field.shape,
-            pixel_size=voxel_size[2],
+            pixel_size=voxel_size[:2],  # This needs a double check.
             wavelength=wavelength / refractive_index_medium,
+            # to_z=(-z),  # Daniel
             to_z=(-offset_z - z),
             dy=(
                 relative_position[0] * ratio
@@ -1206,11 +1269,12 @@
             dx=(
                 relative_position[1] * ratio
                 + position[1]
-                + (padding[1] - arr.shape[1] / 2) * voxel_size[1]
+                # Check whether padding order is (top, bottom, left, right).
+                + (padding[2] - arr.shape[1] / 2) * voxel_size[1]
             ),
         )
+
         fourier_field = (
-            fourier_field * propagation_matrix * np.exp(-1j * k * offset_z)
+            # Drop the exp factor below if Daniel's variant is adopted.
+            fourier_field * propagation_matrix * np.exp(-1j * k * offset_z)
         )

         if return_fft:
@@ -1275,6 +1339,7 @@ class MieSphere(MieScatterer):

     """

+
     def __init__(
         self,
         radius: float = 1e-6,
@@ -1377,6 +1442,7 @@ class MieStratifiedSphere(MieScatterer):

     """

+
     def __init__(
         self,
         radius: ArrayLike[float] = [1e-6],
@@ -1412,3 +1478,62 @@ def inner(
                 refractive_index=refractive_index,
                 **kwargs,
             )
+
+
+@dataclass
+class ScatteredBase:
+    """Base class for scatterer outputs (volumes and fields)."""
+
+    array: np.ndarray | torch.Tensor
+    properties: dict[str, Any] = field(default_factory=dict)
+
+    @property
+    def ndim(self) -> int:
+        """Number of dimensions of the underlying array."""
+        return self.array.ndim
+
+    @property
+    def shape(self) -> tuple[int, ...]:
+        """Shape of the underlying array."""
+        return self.array.shape
+
+    @property
+    def pos3d(self) -> np.ndarray:
+        """3D position (x, y, z), with z defaulting to 0."""
+        return np.array(
+            [*self.position, self.get_property("z", 0.0)], dtype=float
+        )
+
+    @property
+    def position(self) -> np.ndarray:
+        """Position stored in `properties`, or None if absent."""
+        pos = self.properties.get("position", None)
+        if pos is None:
+            return None
+        pos = np.asarray(pos, dtype=float)
+        if pos.ndim == 2 and pos.shape[0] == 1:
+            pos = pos[0]
+        return pos
+
+    def as_array(self) -> ArrayLike:
+        """Return the underlying array.
+
+        Notes
+        -----
+        The raw array is also directly available as ``scatterer.array``.
+        This method exists mainly for API compatibility and clarity.
+
+        """
+
+        return self.array
+
+    def get_property(self, key: str, default: Any = None) -> Any:
+        """Look up `key` as an attribute first, then in `properties`."""
+        return getattr(self, key, self.properties.get(key, default))
+
+
+@dataclass
+class ScatteredVolume(ScatteredBase):
+    """Voxelized volume produced by a VolumeScatterer."""
+
+
+@dataclass
+class ScatteredField(ScatteredBase):
+    """Complex field produced by a FieldScatterer."""
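+
+
+# A minimal usage sketch of the wrappers above (values are illustrative):
+#
+# >>> import numpy as np
+# >>> sv = ScatteredVolume(
+# ...     array=np.ones((3, 3, 3)),
+# ...     properties={"position": (16.0, 16.0), "z": 2.0},
+# ... )
+# >>> sv.shape
+# (3, 3, 3)
+# >>> sv.get_property("z")
+# 2.0
+# >>> sv.pos3d
+# array([16., 16.,  2.])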