diff --git a/keras/src/losses/loss.py b/keras/src/losses/loss.py index 6af73902d0f..fd16daac2af 100644 --- a/keras/src/losses/loss.py +++ b/keras/src/losses/loss.py @@ -211,7 +211,7 @@ def apply_mask(sample_weight, mask, dtype, reduction): dtype, ) valid = ops.sum(mask) # May be 0! - mask *= total / (valid + backend.epsilon()) + mask *= ops.divide_no_nan(total, valid) if sample_weight is not None: sample_weight = ops.cast(sample_weight, dtype=dtype)