Labels: bug, triaged
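🐞 Describing the bug
Converting a model that contains `torch.nn.BatchNorm3d` with coremltools and running it aborts the entire Python process. Core ML fails to infer the result type of the converted `batch_norm` op and stops with an LLVM error instead of raising a Python exception.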
To Reproduce
```python
import numpy as np
import torch
import coremltools as ct


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.norm = torch.nn.BatchNorm3d(3)

    def forward(self, x):
        return self.norm(x)


model = Model()
inputs = (
    torch.randn(1, 3, 4, 4, 4),
)
# This forward pass runs in training mode, so it also updates the running
# mean/variance that end up in the exported (eval-mode) graph.
eager_outputs = model(*inputs)
print(f"Eager: {eager_outputs.shape} {eager_outputs}")

ep = torch.export.export(model.eval(), inputs)
ep = ep.run_decompositions({})
mlmodel = ct.convert(ep)

coreml_inputs = mlmodel.get_spec().description.input
coreml_outputs = mlmodel.get_spec().description.output
# BatchNorm3d takes floating-point input, so feed float32 arrays.
predict_inputs = {
    str(ct_in.name): pt_in.detach().cpu().numpy().astype(np.float32)
    for ct_in, pt_in in zip(coreml_inputs, inputs)
}
out = mlmodel.predict(predict_inputs)
print("CoreML", out)
```
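To check the Core ML output numerically once the model runs, it helps to have an independent reference for what eval-mode batch norm should produce: `(x - mean) / sqrt(variance + epsilon) * gamma + beta`, applied per channel. The sketch below is a hypothetical NumPy helper (`batch_norm_inference` is not part of coremltools or PyTorch); it assumes the parameters are plain length-C arrays, which for the model above could be taken from `model.norm.running_mean`, `running_var`, `weight`, `bias` and `eps`.

```python
import numpy as np


def batch_norm_inference(x, mean, variance, gamma, beta, epsilon=1e-5):
    """Eval-mode batch norm over the channel axis (axis 1) of an N-D input.

    Computes (x - mean) / sqrt(variance + epsilon) * gamma + beta, where the
    per-channel vectors have length C and x has shape (N, C, ...).
    """
    # View each length-C vector as (1, C, 1, ..., 1) so it broadcasts over
    # every spatial axis; for NCDHW input this is (1, C, 1, 1, 1).
    shape = (1, -1) + (1,) * (x.ndim - 2)
    mean, variance, gamma, beta = (
        np.asarray(p, dtype=x.dtype).reshape(shape)
        for p in (mean, variance, gamma, beta)
    )
    return (x - mean) / np.sqrt(variance + epsilon) * gamma + beta


# Usage with the default BatchNorm3d(3) parameters and a random NCDHW input.
rng = np.random.default_rng(0)
x = rng.standard_normal((1, 3, 4, 4, 4)).astype(np.float32)
y = batch_norm_inference(x, np.zeros(3), np.ones(3), np.ones(3), np.zeros(3))
print(y.shape)  # (1, 3, 4, 4, 4)
```

The reshape to `(1, C, 1, 1, 1)` is the step that makes per-channel parameters line up with a rank-5 tensor, which is exactly where the converted model fails.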
Output:
```
loc("tensor<fp16, [1, 3, 4, 4, 4]> _native_batch_norm_legit_no_training_cast_fp16 = batch_norm(beta = tensor<fp16, [3]>([0, 0, 0]), epsilon = fp16(1.00135803e-05), gamma = tensor<fp16, [3]>([1, 1, 1]), mean = tensor<fp16, [3]>([-0.012588501, 0.0046005249, 0.016494751]), variance = tensor<fp16, [3]>([1.00292969, 1.00195312, 1.01855469]), x = x_to_fp16)[milId = uint64(1), name = string(\22_native_batch_norm_legit_no_training_cast_fp16\22)]; - /private/var/folders/lw/phxpy6k10ll388xs18hyq1cr0000gn/T/tmp63_bi7tb.mlmodelc/model.mil":12:12): error: output type 'tensor<1x3x4x4x4xf16>' and mean type 'tensor<1x0x1x1x601354336xf16>' are not broadcast compatible
LLVM ERROR: Failed to infer result type(s).
zsh: abort python test.py
/opt/miniconda3/envs/op-et/lib/python3.10/multiprocessing/resource_tracker.py:224: UserWarning: resource_tracker: There appear to be 1 leaked semaphore objects to clean up at shutdown
warnings.warn('resource_tracker: There appear to be %d '
```
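The error compares the `batch_norm` output shape `1x3x4x4x4` with a `mean` shape of `1x0x1x1x601354336`. A length-3 per-channel vector would be expected to appear as `1x3x1x1x1` for a rank-5 input; the reported shape has a zero channel dimension and a very large trailing dimension, which looks like an uninitialized or mis-computed shape rather than a legitimate reshape. The sketch below checks both shapes under the assumption that Core ML follows NumPy-style broadcasting rules; the shapes are copied from the log, and `broadcast_compatible` is a hypothetical helper.

```python
import numpy as np

# Shapes taken from the error message: the batch_norm output and the
# shape the runtime reported for the `mean` parameter.
output_shape = (1, 3, 4, 4, 4)
reported_mean_shape = (1, 0, 1, 1, 601354336)
# Shape a length-3 per-channel vector would need for NCDHW broadcasting.
expected_mean_shape = (1, 3, 1, 1, 1)


def broadcast_compatible(a, b):
    """Return True if shapes a and b follow NumPy-style broadcasting rules."""
    try:
        np.broadcast_shapes(a, b)  # requires NumPy >= 1.20
        return True
    except ValueError:
        return False


print(broadcast_compatible(output_shape, expected_mean_shape))  # True
print(broadcast_compatible(output_shape, reported_mean_shape))  # False
```

Under these rules `0` vs `3` and `601354336` vs `4` are both mismatches, so the reported shape can never broadcast, while the expected `(1, 3, 1, 1, 1)` shape would.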
System environment:
- coremltools version: 8.3
- OS: macOS 15