diff --git a/augraphy/base/augmentationpipeline.py b/augraphy/base/augmentationpipeline.py index 0c85864..d465828 100644 --- a/augraphy/base/augmentationpipeline.py +++ b/augraphy/base/augmentationpipeline.py @@ -11,6 +11,7 @@ from augraphy.base.augmentation import Augmentation from augraphy.base.augmentationresult import AugmentationResult from augraphy.base.augmentationsequence import AugmentationSequence +from augraphy.base.oneof import OneOf from augraphy.utilities.detectdpi import dpi_resize from augraphy.utilities.detectdpi import DPIMetrics from augraphy.utilities.overlaybuilder import OverlayBuilder @@ -750,7 +751,11 @@ def apply_phase(self, data, layer, phase): data["log"]["time"].append((augmentation, elapsed)) # not "OneOf" or "AugmentationSequence" - if isinstance(augmentation, Augmentation): + if ( + isinstance(augmentation, Augmentation) + and not isinstance(augmentation, AugmentationSequence) + and not isinstance(augmentation, OneOf) + ): # unpacking augmented image, mask, keypoints and bounding boxes from output if (mask is not None) or (keypoints is not None) or (bounding_boxes is not None): result, mask, keypoints, bounding_boxes = result diff --git a/augraphy/base/augmentationsequence.py b/augraphy/base/augmentationsequence.py index 66f306e..557ce1b 100644 --- a/augraphy/base/augmentationsequence.py +++ b/augraphy/base/augmentationsequence.py @@ -48,5 +48,6 @@ def __call__(self, image, layer=None, mask=None, keypoints=None, bounding_boxes= elif isinstance(current_result, tuple): if current_result[0] is not None: result = current_result - + if (mask is not None) or (keypoints is not None) or (bounding_boxes is not None): + result = (result, mask, keypoints, bounding_boxes) return result, self.augmentations diff --git a/augraphy/base/oneof.py b/augraphy/base/oneof.py index 74949af..94d4a82 100644 --- a/augraphy/base/oneof.py +++ b/augraphy/base/oneof.py @@ -29,8 +29,10 @@ def __call__(self, image, layer=None, mask=None, keypoints=None, bounding_boxes= augmentation = self.augmentations[np.argmax(self.augmentation_probabilities)] # Applies the selected Augmentation. - image = augmentation(image, mask=mask, keypoints=keypoints, bounding_boxes=bounding_boxes, force=True) - return image, [augmentation] + result = augmentation(image, mask=mask, keypoints=keypoints, bounding_boxes=bounding_boxes, force=True) + if isinstance(augmentation, AugmentationSequence): + return result[0], result[1] + return result, [augmentation] # Constructs a string containing the representations # of each augmentation