From 2cdfc7989e635c6a5ec1d2fec29b74415ca42dc5 Mon Sep 17 00:00:00 2001 From: yiakwy-xpu-ml-framework-team <961186938@qq.com> Date: Mon, 17 Nov 2025 14:13:36 +0800 Subject: [PATCH 1/4] add ane caller --- mlx_lm/models/ane/ops/matmul.py | 136 ++++++++++++++++++++++++++++++++ requirement-dev.txt | 9 +++ requirement-metal.txt | 7 ++ 3 files changed, 152 insertions(+) create mode 100644 mlx_lm/models/ane/ops/matmul.py create mode 100644 requirement-dev.txt create mode 100644 requirement-metal.txt diff --git a/mlx_lm/models/ane/ops/matmul.py b/mlx_lm/models/ane/ops/matmul.py new file mode 100644 index 00000000..1bbca453 --- /dev/null +++ b/mlx_lm/models/ane/ops/matmul.py @@ -0,0 +1,136 @@ +import hashlib + +import os + +import tempfile +import torch + +import numpy as np + +try: + import coremltools as ct + + from CoreML import MLModel, MLModelConfiguration + import objc +except Exception as e: + print(e) + print("Please install CoreML, pyobjc and coremltools, see requirements.txt.") + exit(1) + + +def ane_subgraph_builder(w : np.ndarray, b : np.ndarray = None, input_name="x", prefix : str = "") -> ct.models.MLModel: + if w.ndim != 2: + # reshape + pass + + M, K = x.shape + N = w.shape[0] + + output_name = f'{prefix}/out' + + input_features = [(input_name, ct.models.datatypes.Array(M, K))] + output_features = [(output_name, ct.models.datatypes.Array(M, N))] + + # see https://apple.github.io/coremltools/v3.4/generated/coremltools.models.neural_network.builder.html + builder = ct.models.neural_network.NeuralNetworkBuilder(input_features, output_features) + builder.add_inner_product(name='matmul', input_name=input_name, output_name=output_name, W=weights, b=b, input_channels=K, output_channels=N, has_bias=b != None) + + spec = build.spec + model = ct.models.MLModel(spec) + return model, output_name + + +_cache = {} + +# TODO (yiakwy) : add multi-levels cache +def _hash_matmul(W, b=None): + algo = hashlib.sha256() + algo.update(W.tobytes()) + if b is not None: + algo.update(b.tobytes()) + + return aglo.hexdigest() + + +def matmul(x : np.ndarray, w : np.ndarray, b : np.ndarray = None, prefix : str = "", input_name="x", model=None): + + key = _hash_matmul(W, b=b) + + if model is None: + # TODO (yiakwy) : load from path + + cached = _cache.get(key, None) + if cached is None or cached[2] is None: + model, output_name = ane_subgraph_builder(w, b) + + # TODO (yiakwy) : save to path asynchronously + modelproto_saved_path_dir = os.path.join(tempfile.gettempdir(), "mlx_ane_ops_cache") + os.makedirs(modelproto_saved_path_dir, exist_ok=True) + + modelproto_saved_path = os.path.join(modelproto_saved_path_dir, f"matmul_{key}.mlmodel") + + compiled_mlmodel_path = MLModel.compileModelAtURL_error_(modelproto_saved_path, None) + + config = MLModelConfiguration.alloc().init() + config.computeUnits = "all" + mlmodel_obj = MLModel.modelWithContentsOfURL_configuration_error_(compiled_mlmodel_path, config, None) + + _cache[key] = (modelproto_saved_path, compiled_mlmodel_path, mlmodel_obj) + + model = mlmodel_obj + else: + model = cached[2] + + + # get the moel + + + inputs = { + input_name : x + } + + outputs = model.predictionFromFeatures_error_(inputs, None) + + out = np.array(outputs["out"], dtype=np.float32) + return out + + +def test_fp8_group_scaled_gemm(): + test_configs = [ + (8, 32, 8) + # (128, 256, 128) + # (1024, 4096, 1024) + # (4096, 16384, 4096), + ] + + for M, K, N in test_configs: + print(f"\n{'='*60}") + print(f"Testing M={M}, K={K}, N={N}") + print(f"{'='*60}") + + torch.manual_seed(42) + input_fp16 = 
torch.randn(M, K, dtype=torch.float16, device="mps") + weights = torch.randint(0, 2, (K, N), device="mps").float() * 2 - 1 + + weights_fp16 = weights.half() + + # print(f"input_fp16 : {input_fp16}") + # print(f"weights_fp16 : {weights_fp16}") + + assert(weights_fp16.is_contiguous()) + + # correctness check + output_torch = torch.matmul(input_fp16, weights_fp16) + + input_fp16_np = input_fp16.cpu().numpy() + weights_fp16_np = weights_fp16.cpu().numpy() + + out = matmul(input_fp16_np, weights_fp16_np) + + max_error = np.max(np.abs(output_torch.cpu().numpy() - output_mlx)) + print(f"Max error between torch and mlx: {max_error:.6f}") + pass + + + if __name__ == "__main__": + test_fp8_group_scaled_gemm() \ No newline at end of file diff --git a/requirement-dev.txt b/requirement-dev.txt new file mode 100644 index 00000000..4716dc53 --- /dev/null +++ b/requirement-dev.txt @@ -0,0 +1,9 @@ +setuptools>=40.8.0 +wheel +cmake>=3.20,<4.0 +ninja>=1.11.1 +pybind11>=2.13.1 +lit + +triton @ git+https://github.com/triton-lang/triton.git + diff --git a/requirement-metal.txt b/requirement-metal.txt new file mode 100644 index 00000000..0d9a65a8 --- /dev/null +++ b/requirement-metal.txt @@ -0,0 +1,7 @@ +# use to build objc codes +pyobjc + +# use to build ane graph +coremltools + +-r requirements-dev.txt From 8897a6310c99b636982d097b4132edbb4a702423 Mon Sep 17 00:00:00 2001 From: yiakwy-xpu-ml-framework-team <961186938@qq.com> Date: Mon, 17 Nov 2025 18:42:54 +0800 Subject: [PATCH 2/4] fix bugs with wrong api --- mlx_lm/models/ane/ops/matmul.py | 97 ++++++++++++++++++--------------- 1 file changed, 53 insertions(+), 44 deletions(-) diff --git a/mlx_lm/models/ane/ops/matmul.py b/mlx_lm/models/ane/ops/matmul.py index 1bbca453..acc71fde 100644 --- a/mlx_lm/models/ane/ops/matmul.py +++ b/mlx_lm/models/ane/ops/matmul.py @@ -6,36 +6,44 @@ import torch import numpy as np +import coremltools.models.datatypes as datatypes try: import coremltools as ct + from Cocoa import NSURL from CoreML import MLModel, MLModelConfiguration - import objc + from CoreML import MLDictionaryFeatureProvider, MLFeatureValue + # import objc except Exception as e: print(e) print("Please install CoreML, pyobjc and coremltools, see requirements.txt.") exit(1) -def ane_subgraph_builder(w : np.ndarray, b : np.ndarray = None, input_name="x", prefix : str = "") -> ct.models.MLModel: +def ane_subgraph_builder(x_desc : tuple, w : np.ndarray, b : np.ndarray = None, input_name="x", output_name="out", prefix : str = "") -> ct.models.MLModel: if w.ndim != 2: # reshape pass - M, K = x.shape - N = w.shape[0] + M, K = x_desc + N, K = w.shape - output_name = f'{prefix}/out' + output_name = f'{prefix}/{output_name}' input_features = [(input_name, ct.models.datatypes.Array(M, K))] output_features = [(output_name, ct.models.datatypes.Array(M, N))] # see https://apple.github.io/coremltools/v3.4/generated/coremltools.models.neural_network.builder.html - builder = ct.models.neural_network.NeuralNetworkBuilder(input_features, output_features) - builder.add_inner_product(name='matmul', input_name=input_name, output_name=output_name, W=weights, b=b, input_channels=K, output_channels=N, has_bias=b != None) + builder = ct.models.neural_network.NeuralNetworkBuilder(input_features, output_features, disable_rank5_shape_mapping=True) + builder.add_inner_product(name='matmul', input_name=input_name, output_name=output_name, + W=w, b=b, input_channels=K, output_channels=N, has_bias=b != None) + + spec = builder.spec + spec.description.predictedFeatureName = output_name + + # 
ct.utils.convert_double_to_float_multiarray_type(spec) - spec = build.spec model = ct.models.MLModel(spec) return model, output_name @@ -49,57 +57,54 @@ def _hash_matmul(W, b=None): if b is not None: algo.update(b.tobytes()) - return aglo.hexdigest() + return algo.hexdigest() -def matmul(x : np.ndarray, w : np.ndarray, b : np.ndarray = None, prefix : str = "", input_name="x", model=None): +def save_model_proto(model : ct.models.MLModel, saved_path : str, model_name : str): + if not os.path.exists(saved_path): + os.makedirs(saved_path, exist_ok=True) - key = _hash_matmul(W, b=b) + model.save(saved_path) - if model is None: - # TODO (yiakwy) : load from path - cached = _cache.get(key, None) - if cached is None or cached[2] is None: - model, output_name = ane_subgraph_builder(w, b) +def matmul(x : np.ndarray, w : np.ndarray, b : np.ndarray = None, prefix : str = "", input_name="x", output_name="out", model=None, **configs): - # TODO (yiakwy) : save to path asynchronously - modelproto_saved_path_dir = os.path.join(tempfile.gettempdir(), "mlx_ane_ops_cache") - os.makedirs(modelproto_saved_path_dir, exist_ok=True) + key = _hash_matmul(w, b=b) - modelproto_saved_path = os.path.join(modelproto_saved_path_dir, f"matmul_{key}.mlmodel") + if model is None: + cached = _cache.get(key, None) + if cached is None: - compiled_mlmodel_path = MLModel.compileModelAtURL_error_(modelproto_saved_path, None) + model, output_name = ane_subgraph_builder(x.shape, w, b, input_name=input_name) + print( f"model : {model}") - config = MLModelConfiguration.alloc().init() - config.computeUnits = "all" - mlmodel_obj = MLModel.modelWithContentsOfURL_configuration_error_(compiled_mlmodel_path, config, None) + dump = configs.get("dump", True) + saved_path = configs.get("saved_path", os.path.join(tempfile.gettempdir(), "mlx_ane_ops_cache")) + model_name = configs.get("model_name", f"op_matmul_{key}.mlmodel") - _cache[key] = (modelproto_saved_path, compiled_mlmodel_path, mlmodel_obj) + # TODO (yiakwy) : save to path asynchronously + if dump: + save_model_proto(model, saved_path, model_name) - model = mlmodel_obj + _cache[key] = (model) else: - model = cached[2] + model = cached[0] + output_name = f'{prefix}/{output_name}' + inputs = {input_name : x.astype(np.float32)} - # get the moel - - - inputs = { - input_name : x - } - - outputs = model.predictionFromFeatures_error_(inputs, None) + outputs = model.predict(inputs) - out = np.array(outputs["out"], dtype=np.float32) + out = np.array(outputs[output_name], dtype=np.float32) return out def test_fp8_group_scaled_gemm(): test_configs = [ - (8, 32, 8) + # (8, 32, 8) + # (1, 1, 1) # (128, 256, 128) - # (1024, 4096, 1024) + (1024, 4096, 1024) # (4096, 16384, 4096), ] @@ -123,14 +128,18 @@ def test_fp8_group_scaled_gemm(): output_torch = torch.matmul(input_fp16, weights_fp16) input_fp16_np = input_fp16.cpu().numpy() - weights_fp16_np = weights_fp16.cpu().numpy() - out = matmul(input_fp16_np, weights_fp16_np) - max_error = np.max(np.abs(output_torch.cpu().numpy() - output_mlx)) + weights_fp16_t = torch.zeros((N, K), dtype=torch.float16, device="mps") + weights_fp16_t[:] = weights_fp16.T[:] + + weights_fp16_np = weights_fp16_t.cpu().numpy() + + output_ane = matmul(input_fp16_np, weights_fp16_np) + + max_error = np.max(np.abs(output_torch.cpu().numpy() - output_ane)) print(f"Max error between torch and mlx: {max_error:.6f}") - pass - if __name__ == "__main__": - test_fp8_group_scaled_gemm() \ No newline at end of file +if __name__ == "__main__": + test_fp8_group_scaled_gemm() \ No 
newline at end of file From c8bb9c710a562fa717bb1ac671ef6ae53723e3b8 Mon Sep 17 00:00:00 2001 From: yiakwy-xpu-ml-framework-team <961186938@qq.com> Date: Mon, 17 Nov 2025 19:54:26 +0800 Subject: [PATCH 3/4] add performance benchmark --- mlx_lm/models/ane/ops/matmul.py | 57 +++++++++++++++++++++++++++------ 1 file changed, 48 insertions(+), 9 deletions(-) diff --git a/mlx_lm/models/ane/ops/matmul.py b/mlx_lm/models/ane/ops/matmul.py index acc71fde..334d1519 100644 --- a/mlx_lm/models/ane/ops/matmul.py +++ b/mlx_lm/models/ane/ops/matmul.py @@ -11,6 +11,9 @@ try: import coremltools as ct + from coremltools.models.neural_network import quantization_utils + from coremltools.proto.FeatureTypes_pb2 import ArrayFeatureType + from Cocoa import NSURL from CoreML import MLModel, MLModelConfiguration from CoreML import MLDictionaryFeatureProvider, MLFeatureValue @@ -24,7 +27,7 @@ def ane_subgraph_builder(x_desc : tuple, w : np.ndarray, b : np.ndarray = None, input_name="x", output_name="out", prefix : str = "") -> ct.models.MLModel: if w.ndim != 2: # reshape - pass + w = w.reshape(-1, w.shape[-1]) M, K = x_desc N, K = w.shape @@ -42,14 +45,19 @@ def ane_subgraph_builder(x_desc : tuple, w : np.ndarray, b : np.ndarray = None, spec = builder.spec spec.description.predictedFeatureName = output_name - # ct.utils.convert_double_to_float_multiarray_type(spec) + ct.utils.convert_double_to_float_multiarray_type(spec) + + model_fp32 = ct.models.MLModel(spec) + + # https://apple.github.io/coremltools/docs-guides/source/quantization-neural-network.html + model_fp16 = quantization_utils.quantize_weights(model_fp32, nbits=16, quantization_mode='linear',) - model = ct.models.MLModel(spec) - return model, output_name + return model_fp16, output_name _cache = {} + # TODO (yiakwy) : add multi-levels cache def _hash_matmul(W, b=None): algo = hashlib.sha256() @@ -86,7 +94,7 @@ def matmul(x : np.ndarray, w : np.ndarray, b : np.ndarray = None, prefix : str = if dump: save_model_proto(model, saved_path, model_name) - _cache[key] = (model) + _cache[key] = (model,) else: model = cached[0] output_name = f'{prefix}/{output_name}' @@ -95,17 +103,17 @@ def matmul(x : np.ndarray, w : np.ndarray, b : np.ndarray = None, prefix : str = outputs = model.predict(inputs) - out = np.array(outputs[output_name], dtype=np.float32) + out = outputs[output_name] return out def test_fp8_group_scaled_gemm(): test_configs = [ # (8, 32, 8) - # (1, 1, 1) # (128, 256, 128) - (1024, 4096, 1024) - # (4096, 16384, 4096), + # (1024, 4096, 1024) + (4096, 16384, 4096), + # (32, 128, 1024) ] for M, K, N in test_configs: @@ -140,6 +148,37 @@ def test_fp8_group_scaled_gemm(): max_error = np.max(np.abs(output_torch.cpu().numpy() - output_ane)) print(f"Max error between torch and mlx: {max_error:.6f}") + # print(f"diff : {output_ane - output_torch.cpu().numpy()}") + print(f"ane type : {output_ane.dtype}") + + # Performance benchmark for ANE + + import time + + for _ in range(10): + _ = torch.matmul(input_fp16, weights_fp16) + torch.mps.synchronize() + + for _ in range(10): + _ = matmul(input_fp16_np, weights_fp16_np) + + # Benchmark Pytorch MPS backend + start_time = time.time() + for _ in range(10): + _ = torch.matmul(input_fp16, weights_fp16) + torch.mps.synchronize() + torch_elpase = (time.time() - start_time)/ 10 * 1000 + + # Benchmark for ANE + start_time = time.time() + for _ in range(10): + _ = matmul(input_fp16_np, weights_fp16_np) + + mlx_elpase = (time.time() - start_time)/ 10 * 1000 + + print(f"fp16x{M}x{N}x{K} Pytorch MPS backend average time over 
10 runs: {torch_elpase:.2f} ms") + print(f"fp16x{M}x{N}x{K} ANE backend average time over 10 runs: {mlx_elpase:.2f} ms") + if __name__ == "__main__": test_fp8_group_scaled_gemm() \ No newline at end of file From b5eb270a2c9391503696f147d8c2dffcbda33cc9 Mon Sep 17 00:00:00 2001 From: yiakwy-xpu-ml-framework-team <961186938@qq.com> Date: Mon, 17 Nov 2025 21:06:06 +0800 Subject: [PATCH 4/4] add torch ane mix gemm --- mlx_lm/models/ane/ops/matmul.py | 69 ++++++++++++++++++++++++++++----- 1 file changed, 60 insertions(+), 9 deletions(-) diff --git a/mlx_lm/models/ane/ops/matmul.py b/mlx_lm/models/ane/ops/matmul.py index 334d1519..3991cb3f 100644 --- a/mlx_lm/models/ane/ops/matmul.py +++ b/mlx_lm/models/ane/ops/matmul.py @@ -11,6 +11,9 @@ try: import coremltools as ct + from coremltools.converters.mil import Builder as mb + from coremltools.converters.mil.input_types import TensorType + from coremltools.models.neural_network import quantization_utils from coremltools.proto.FeatureTypes_pb2 import ArrayFeatureType @@ -49,7 +52,7 @@ def ane_subgraph_builder(x_desc : tuple, w : np.ndarray, b : np.ndarray = None, model_fp32 = ct.models.MLModel(spec) - # https://apple.github.io/coremltools/docs-guides/source/quantization-neural-network.html + # weights to fp16 https://apple.github.io/coremltools/docs-guides/source/quantization-neural-network.html model_fp16 = quantization_utils.quantize_weights(model_fp32, nbits=16, quantization_mode='linear',) return model_fp16, output_name @@ -107,6 +110,26 @@ def matmul(x : np.ndarray, w : np.ndarray, b : np.ndarray = None, prefix : str = return out +# NOTE(yiakwy) : x.data_ptr() != x.cpu().data_ptr() in pytorch metal backend +def dispatch_matmul(x, w, w_cpu, out=None, div=5): + M, K = x.shape + K, N = w.shape + + if out is None: + out = torch.zeros((M, N), dtype=torch.float16, device="mps") + + x_partial_0_view = x[M // div:M] + out_partial_0_view = out[M // div:M] + + torch.matmul(x_partial_0_view, w, out=out_partial_0_view) + + x_partial_1_view = x[:M // div] + output_partial_ane = matmul(x_partial_1_view.cpu().numpy(), w_cpu) + out[:M // div] = torch.from_numpy(output_partial_ane[:]).to("mps") + + return out + + def test_fp8_group_scaled_gemm(): test_configs = [ # (8, 32, 8) @@ -121,6 +144,8 @@ def test_fp8_group_scaled_gemm(): print(f"Testing M={M}, K={K}, N={N}") print(f"{'='*60}") + use_mx_ane = True + torch.manual_seed(42) input_fp16 = torch.randn(M, K, dtype=torch.float16, device="mps") weights = torch.randint(0, 2, (K, N), device="mps").float() * 2 - 1 @@ -143,13 +168,27 @@ def test_fp8_group_scaled_gemm(): weights_fp16_np = weights_fp16_t.cpu().numpy() - output_ane = matmul(input_fp16_np, weights_fp16_np) + if use_mx_ane: + output_mx_ane = dispatch_matmul(input_fp16, weights_fp16, weights_fp16_np) + + if use_mx_ane: + torch.mps.synchronize() + + diff = output_torch - output_mx_ane + + max_error = np.max(np.abs(diff.cpu().numpy())) + print(f"Max error between torch and mlx: {max_error:.6f}") + print(f"diff : {diff}") + else: + output_ane = matmul(input_fp16_np, weights_fp16_np) + + diff = output_torch.cpu().numpy() - output_ane - max_error = np.max(np.abs(output_torch.cpu().numpy() - output_ane)) - print(f"Max error between torch and mlx: {max_error:.6f}") + max_error = np.max(np.abs(diff)) + print(f"Max error between torch and mlx: {max_error:.6f}") - # print(f"diff : {output_ane - output_torch.cpu().numpy()}") - print(f"ane type : {output_ane.dtype}") + # print(f"diff : {diff}") + # print(f"ane type : {output_ane.dtype}") # Performance benchmark for 
ANE @@ -160,7 +199,13 @@ def test_fp8_group_scaled_gemm(): torch.mps.synchronize() for _ in range(10): - _ = matmul(input_fp16_np, weights_fp16_np) + if use_mx_ane: + _ = dispatch_matmul(input_fp16, weights_fp16, weights_fp16_np) + else: + _ = matmul(input_fp16_np, weights_fp16_np) + + if use_mx_ane: + torch.mps.synchronize() # Benchmark Pytorch MPS backend start_time = time.time() @@ -172,12 +217,18 @@ def test_fp8_group_scaled_gemm(): # Benchmark for ANE start_time = time.time() for _ in range(10): - _ = matmul(input_fp16_np, weights_fp16_np) + if use_mx_ane: + _ = dispatch_matmul(input_fp16, weights_fp16, weights_fp16_np) + else: + _ = matmul(input_fp16_np, weights_fp16_np) + + if use_mx_ane: + torch.mps.synchronize() mlx_elpase = (time.time() - start_time)/ 10 * 1000 print(f"fp16x{M}x{N}x{K} Pytorch MPS backend average time over 10 runs: {torch_elpase:.2f} ms") - print(f"fp16x{M}x{N}x{K} ANE backend average time over 10 runs: {mlx_elpase:.2f} ms") + print(f"fp16x{M}x{N}x{K} {'ANE_METAL_MIX' if use_mx_ane else 'ANE'} backend average time over 10 runs: {mlx_elpase:.2f} ms") if __name__ == "__main__":
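
Taken together, the series wires matmul up as a Core ML inner-product model built with coremltools' NeuralNetworkBuilder, quantizes the model weights to fp16, caches the built model under a SHA-256 hash of the weights, and runs it through MLModel.predict so Core ML can schedule it on the Neural Engine; dispatch_matmul additionally splits the rows of x between this ANE path and torch.matmul on the MPS backend. Below is a minimal usage sketch, assuming the module is importable as mlx_lm.models.ane.ops.matmul (the file path added in PATCH 1/4) and that the entry points keep the signatures from PATCH 4/4: x is (M, K) and the ANE weight copy is the transposed (N, K) array that add_inner_product expects.

    # Minimal usage sketch (assumed import path; signatures as in PATCH 4/4).
    import numpy as np
    import torch

    from mlx_lm.models.ane.ops.matmul import matmul, dispatch_matmul

    M, K, N = 1024, 4096, 1024

    x = torch.randn(M, K, dtype=torch.float16, device="mps")
    w = torch.randn(K, N, dtype=torch.float16, device="mps")  # MPS-side weight, (K, N)
    w_ane = np.ascontiguousarray(w.cpu().numpy().T)           # ANE-side copy, (N, K)

    # Note: the model cache is keyed only by the weight (and bias) bytes, while
    # the built Core ML model fixes M, so keep the input row count consistent
    # for a given weight matrix within one process.
    use_mixed = True
    if use_mixed:
        # Mixed path: the first M // div rows go through the ANE model, the rest
        # through torch.matmul on MPS, written into one (M, N) fp16 output tensor.
        y = dispatch_matmul(x, w, w_ane, div=5)
        torch.mps.synchronize()
        y = y.cpu().numpy()
    else:
        # Pure ANE path: builds (or reuses from the cache) the Core ML model and
        # returns an (M, N) numpy array.
        y = matmul(x.cpu().numpy(), w_ane)

    y_ref = torch.matmul(x, w).cpu().numpy()
    print(f"max |y - torch|: {np.max(np.abs(y - y_ref)):.6f}")

Because dispatch_matmul launches the MPS matmul before the host copy and Core ML prediction, the two shares can run concurrently; div controls how much of the work lands on each backend.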