diff --git a/Sources/TensorFlow/Layers/Normalization.swift b/Sources/TensorFlow/Layers/Normalization.swift index c77430fe4..e46e0ed22 100644 --- a/Sources/TensorFlow/Layers/Normalization.swift +++ b/Sources/TensorFlow/Layers/Normalization.swift @@ -105,14 +105,19 @@ public struct BatchNorm: Layer { precondition( input.shape[positiveAxis] == offset.shape[0], "The number of features of the input and the offset doesn't match.") - var (offset, scale) = {x in (x.offset, x.scale) }(self) - if positiveAxis != input.rank - 1 { - var broadcastShape = TensorShape([Int](repeating: 1, count: input.rank)) - broadcastShape[positiveAxis] = input.shape[positiveAxis] - offset = offset.reshaped(to: broadcastShape) - - scale = scale.reshaped(to: broadcastShape) - } +// var (offset, scale) = {x in (x.offset, x.scale) }(self) +// if positiveAxis != input.rank - 1 { +// var broadcastShape = TensorShape([Int](repeating: 1, count: input.rank)) +// broadcastShape[positiveAxis] = input.shape[positiveAxis] +// offset = offset.reshaped(to: broadcastShape) +// scale = scale.reshaped(to: broadcastShape) +// } + let offsetOriginal = self.offset + let scaleOriginal = self.scale + let (offset, scale) = Self._sr13263workaround(offset: offsetOriginal, + scale: scaleOriginal, + input: input, + positiveAxis: positiveAxis) switch Context.local.learningPhase { case .training: return doTraining(input, offset: offset, scale: scale, axis: positiveAxis) @@ -120,6 +125,23 @@ public struct BatchNorm: Layer { return doInference(input, offset: offset, scale: scale) } } + + @inline(never) + @differentiable(reverse) // keep this helper `private`: if the function is `public` or `internal`, the compiler crashes (presumably the bug tracked as SR-13263, per the helper's name; TODO confirm) + private static func _sr13263workaround( + offset: Tensor, + scale: Tensor, + input: Tensor, + positiveAxis: Int + ) -> (Tensor, Tensor) { + if positiveAxis != input.rank - 1 { + var broadcastShape = TensorShape([Int](repeating: 1, count: input.rank)) + broadcastShape[positiveAxis] = input.shape[positiveAxis] + return 
(offset.reshaped(to: broadcastShape), scale.reshaped(to: broadcastShape)) + } else { + return (offset, scale) + } + } private func doTraining( _ input: Tensor, offset: Tensor, scale: Tensor, axis: Int