Commit 9d72f74

malfet authored and pytorchmergebot committed
[MPS] Fix AvgPool2d for float16 (pytorch#136822)
This was a stupid cast error that caused MPSGraph to crash with the following exception:
```
(mpsFileLoc): /AppleInternal/Library/BuildRoots/e0873e53-5185-11ef-9a51-9ab6d782fe32/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm:233:0: error: 'mps.multiply' op requires the same element type for all operands and results
(mpsFileLoc): /AppleInternal/Library/BuildRoots/e0873e53-5185-11ef-9a51-9ab6d782fe32/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm:233:0: note: see current operation: %3 = "mps.multiply"(%2, %arg1) : (tensor<1x3x9x9xf16>, tensor<1xf32>) -> tensor<*xf32>
(mpsFileLoc): /AppleInternal/Library/BuildRoots/e0873e53-5185-11ef-9a51-9ab6d782fe32/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm:233:0: error: 'mps.multiply' op requires the same element type for all operands and results
(mpsFileLoc): /AppleInternal/Library/BuildRoots/e0873e53-5185-11ef-9a51-9ab6d782fe32/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm:233:0: note: see current operation: %3 = "mps.multiply"(%2, %arg1) : (tensor<1x3x9x9xf16>, tensor<1xf32>) -> tensor<*xf32>
/AppleInternal/Library/BuildRoots/e0873e53-5185-11ef-9a51-9ab6d782fe32/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphExecutable.mm:953: failed assertion `original module failed verification'
```

Pull Request resolved: pytorch#136822
Approved by: https://github.com/Skylion007
ghstack dependencies: pytorch#136754, pytorch#136755, pytorch#136821
1 parent 2b6f4e9 commit 9d72f74
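
For readers who want to exercise this path from Python, a minimal sketch follows. It is not taken from the commit: the helper name `run_avg_pool2d_half_on_mps`, the input shape, the kernel size, and the divisor value are all illustrative, and it assumes that passing `divisor_override` is what sends `avg_pool2d` through the manual scaling branch touched by this change. On a machine without PyTorch or an MPS device the helper simply returns `None`.

```python
def run_avg_pool2d_half_on_mps():
    """Run float16 avg_pool2d with a custom divisor on MPS, if possible.

    Returns the output tensor, or None when PyTorch or an MPS device
    is not available on this machine.
    """
    try:
        import torch
        import torch.nn.functional as F
    except ImportError:
        return None

    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is None or not mps_backend.is_available():
        return None

    # Shape, kernel size, and divisor are illustrative assumptions.
    x = torch.rand(1, 3, 11, 11, dtype=torch.float16, device="mps")
    # Assumption: divisor_override routes through the manual scaling branch.
    return F.avg_pool2d(x, kernel_size=3, stride=1, divisor_override=2)


if __name__ == "__main__":
    out = run_avg_pool2d_half_on_mps()
    if out is None:
        print("skipped: PyTorch with an MPS device is not available")
    else:
        print(out.dtype, tuple(out.shape))
```

On a build that predates this commit, a call like this would be expected to abort the process with the assertion quoted in the commit message rather than raise a Python exception, so it is best run in a throwaway interpreter.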

2 files changed: +11 -8 lines changed

aten/src/ATen/native/mps/operations/Pooling.mm

Lines changed: 10 additions & 6 deletions
@@ -311,18 +311,22 @@ static void avg_pool2d_template(const Tensor& input,
         MPSGraphTensor* avgPoolTensor = [mpsGraph avgPooling2DWithSourceTensor:paddedTensor descriptor:desc name:nil];
         if (cachedGraph.divisorTensor) {
           // workaround: custom divisor isn't supported by MPS backend, so we scale manually
-          return [mpsGraph multiplicationWithPrimaryTensor:avgPoolTensor
-                                           secondaryTensor:cachedGraph.divisorTensor
-                                                      name:nil];
+          return
+              [mpsGraph multiplicationWithPrimaryTensor:avgPoolTensor
+                                        secondaryTensor:mps::castMPSTensor(
+                                                            mpsGraph, cachedGraph.divisorTensor, [avgPoolTensor dataType])
+                                                   name:nil];
         } else {
           return avgPoolTensor;
         }
       } else { // backward pass
         MPSGraphTensor* scaledGradTensor = cachedGraph.gradOutputTensor;
         if (cachedGraph.divisorTensor) {
-          scaledGradTensor = [mpsGraph multiplicationWithPrimaryTensor:cachedGraph.gradOutputTensor
-                                                       secondaryTensor:cachedGraph.divisorTensor
-                                                                  name:nil];
+          scaledGradTensor = [mpsGraph
+              multiplicationWithPrimaryTensor:cachedGraph.gradOutputTensor
+                              secondaryTensor:mps::castMPSTensor(
+                                                  mpsGraph, cachedGraph.divisorTensor, [scaledGradTensor dataType])
+                                         name:nil];
         }
         MPSGraphTensor* avgPoolTensor = [mpsGraph avgPooling2DGradientWithGradientTensor:scaledGradTensor
                                                                             sourceTensor:paddedTensor
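
The change in both branches is the same: cast the divisor to the element type of the other operand before multiplying. The toy model below illustrates that idea in plain Python. It is not MPSGraph code: `TypedTensor`, `strict_multiply`, and `cast` are hypothetical stand-ins, the dtype strings are only labels, and the choice to return the input unchanged when the dtype already matches is an assumption of this sketch rather than a statement about `mps::castMPSTensor`.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class TypedTensor:
    """Hypothetical stand-in for a graph tensor: a dtype label plus flat values."""
    dtype: str
    values: Tuple[float, ...]


def strict_multiply(a: TypedTensor, b: TypedTensor) -> TypedTensor:
    """Multiply elementwise, rejecting mixed element types like the verifier in the log."""
    if a.dtype != b.dtype:
        raise TypeError(f"multiply requires the same element type, got {a.dtype} and {b.dtype}")
    if len(b.values) == 1:
        # Broadcast a single-element operand across the other tensor.
        return TypedTensor(a.dtype, tuple(v * b.values[0] for v in a.values))
    if len(a.values) != len(b.values):
        raise ValueError("shapes are not broadcast compatible in this toy model")
    return TypedTensor(a.dtype, tuple(x * y for x, y in zip(a.values, b.values)))


def cast(t: TypedTensor, dtype: str) -> TypedTensor:
    """Relabel the tensor with the requested dtype (no rounding is modeled here)."""
    return t if t.dtype == dtype else TypedTensor(dtype, t.values)


pooled = TypedTensor("f16", (1.0, 2.0))   # plays the role of avgPoolTensor
divisor = TypedTensor("f32", (4.5,))      # plays the role of divisorTensor

# Before the fix: mixed f16/f32 operands are rejected.
try:
    strict_multiply(pooled, divisor)
    mixed_types_rejected = False
except TypeError:
    mixed_types_rejected = True

# After the fix: cast the divisor to the other operand's dtype first.
scaled = strict_multiply(pooled, cast(divisor, pooled.dtype))
```

Casting toward the pooled tensor (rather than promoting it to the divisor's type) keeps the result in the input's element type, which matches the direction of the cast in the diff.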

test/test_mps.py

Lines changed: 1 addition & 2 deletions
@@ -1005,8 +1005,6 @@ def mps_ops_modifier(ops):
 
     SKIPLIST = {
         # Unsupported
-        # input types 'tensor<1x3x9x9xf16>' and 'tensor<1xf32>' are not broadcast compatible
-        'nn.functional.avg_pool2d': [torch.float16],
 
         # This doesn't work on M1, but is partially working on M2 with the exception of torch.float16
         'nn.functional.conv3d': None,
@@ -12035,6 +12033,7 @@ class TestConsistency(TestCaseMPS):
         'var_mean_unbiased',
         'acosh', 'asinh', 'asin',
         'masked.std',
+        'nn.functional.avg_pool2d', # NS: Only for backward pass
         'nn.functional.normalize',
         'nn.functional.triplet_margin_loss',
         'nn.functional.triplet_margin_with_distance_loss',