From f356eda98533731f77246e0b0f90ae49e28471c8 Mon Sep 17 00:00:00 2001 From: yaotianhang Date: Thu, 19 Mar 2026 19:57:41 +0800 Subject: [PATCH 01/10] feat: implement matmul_biasadd fusion op --- musa_ext/kernels/nn/musa_matmul_bias_op.cc | 209 +++++++++++ .../mu/graph_fusion/matmul_biasadd_fusion.cc | 242 +++++++++++++ .../mu/graph_fusion/matmul_biasadd_fusion.h | 46 +++ test/fusion/matmul_biasadd_fusion.py | 337 ++++++++++++++++++ 4 files changed, 834 insertions(+) create mode 100644 musa_ext/kernels/nn/musa_matmul_bias_op.cc create mode 100644 musa_ext/mu/graph_fusion/matmul_biasadd_fusion.cc create mode 100644 musa_ext/mu/graph_fusion/matmul_biasadd_fusion.h create mode 100644 test/fusion/matmul_biasadd_fusion.py diff --git a/musa_ext/kernels/nn/musa_matmul_bias_op.cc b/musa_ext/kernels/nn/musa_matmul_bias_op.cc new file mode 100644 index 00000000..cb324c3e --- /dev/null +++ b/musa_ext/kernels/nn/musa_matmul_bias_op.cc @@ -0,0 +1,209 @@ +#include +#include +#include + +#include "../utils_op.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/util/matmul_bcast.h" +#include "utils/logging.h" + +namespace tensorflow { +namespace musa { + +template +class MusaMatMulBiasAddOp : public MusaOpKernel { + public: + explicit MusaMatMulBiasAddOp(OpKernelConstruction* ctx) : MusaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_a", &trans_a_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_b", &trans_b_)); + } + + void Compute(OpKernelContext* ctx) override { + MUSA_KERNEL_TIMING_GUARD(ctx); + + const Tensor& in0 = ctx->input(0); // a + const Tensor& in1 = ctx->input(1); // b + const Tensor& bias = ctx->input(2); // bias + + OP_REQUIRES(ctx, TensorShapeUtils::IsVector(bias.shape()), + errors::InvalidArgument( + "bias must be a 1-D tensor, but got shape: ", + bias.shape().DebugString())); + + // ---------------------------- + // 1. 
Infer MatMul output shape + // ---------------------------- + MatMulBCast bcast(in0.shape().dim_sizes(), in1.shape().dim_sizes()); + OP_REQUIRES(ctx, bcast.IsValid(), + errors::InvalidArgument( + "Incompatible shapes for MatMul: ", + in0.shape().DebugString(), " vs ", + in1.shape().DebugString())); + + OP_REQUIRES(ctx, in0.dims() >= 2, + errors::InvalidArgument("Input a must have rank >= 2, got rank ", + in0.dims())); + OP_REQUIRES(ctx, in1.dims() >= 2, + errors::InvalidArgument("Input b must have rank >= 2, got rank ", + in1.dims())); + + int64 d0 = in0.dim_size(in0.dims() - 2); + int64 d1 = in0.dim_size(in0.dims() - 1); + int64 d2 = in1.dim_size(in1.dims() - 2); + int64 d3 = in1.dim_size(in1.dims() - 1); + + int64 m = trans_a_ ? d1 : d0; + int64 k = trans_a_ ? d0 : d1; + int64 n = trans_b_ ? d2 : d3; + int64 k_check = trans_b_ ? d3 : d2; + + OP_REQUIRES( + ctx, k == k_check, + errors::InvalidArgument("Matrix size incompatible: lhs k=", k, + ", rhs k=", k_check, + ", lhs shape=", in0.shape().DebugString(), + ", rhs shape=", in1.shape().DebugString())); + + OP_REQUIRES( + ctx, bias.dim_size(0) == n, + errors::InvalidArgument( + "Bias size mismatch: bias.shape[0] must equal MatMul output's last " + "dimension. Got bias size = ", + bias.dim_size(0), ", expected = ", n)); + + TensorShape out_shape = bcast.output_batch_shape(); + out_shape.AddDim(m); + out_shape.AddDim(n); + + TensorShape mm_out_shape = out_shape; + + // ----------------------------------- + // 2. Allocate temp for MatMul result + // ----------------------------------- + Tensor mm_out_tensor; + OP_REQUIRES_OK( + ctx, ctx->allocate_temp(in0.dtype(), mm_out_shape, &mm_out_tensor)); + + Tensor* output = nullptr; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &output)); + + if (output->NumElements() == 0) { + return; + } + + // ---------------------------- + // 3. 
mudnn MatMul / BatchMatMul + // ---------------------------- + auto& handle = GetHandleByCtx(ctx); + handle.SetAllowTF32(tf32_enabled_); + + mTensor mt_a = CreateMTensor(in0); + mTensor mt_b = CreateMTensor(in1); + mTensor mt_mm_out = CreateMTensor(mm_out_tensor); + + ::musa::dnn::Status status; + + if (in0.dims() == 2 && in1.dims() == 2) { + mMatMul op; + op.SetTranspose(trans_a_, trans_b_); + op.SetAlpha(1.0); + op.SetBeta(0.0); + + status = op.Run(handle, mt_mm_out, mt_a, mt_b); + } else { + mBatchMatMul op; + op.SetTranspose(trans_a_, trans_b_); + op.SetAlpha(1.0); + op.SetBeta(0.0); + + int64_t out_batch = bcast.output_batch_shape().num_elements(); + + auto ReshapeTo3D = [out_batch](mTensor& mt, const Tensor& t) { + int64_t dims = t.dims(); + int64_t rows = t.dim_size(dims - 2); + int64_t cols = t.dim_size(dims - 1); + int64_t batch = t.NumElements() / (rows * cols); + + + if (dims != 3 || (batch == 1 && out_batch > 1)) { + mt.SetNdInfo( + {batch == 1 && out_batch > 1 ? out_batch : batch, rows, cols}, + {batch == 1 && out_batch > 1 ? 0 : rows * cols, cols, 1}); + } + }; + + ReshapeTo3D(mt_a, in0); + ReshapeTo3D(mt_b, in1); + mt_mm_out.SetNdInfo({out_batch, m, n}, {m * n, n, 1}); + + status = op.Run(handle, mt_mm_out, mt_a, mt_b); + } + + OP_REQUIRES( + ctx, status == ::musa::dnn::Status::SUCCESS, + errors::Internal("MUSA MatMul/BatchMatMul failed in MusaMatMulBiasAdd")); + + // ---------------------------- + // 4. 
mudnn BiasAdd + // ---------------------------- + mTensor mt_bias = CreateMTensor(bias); + mTensor mt_out = CreateMTensor(*output); + + const int dims_cnt = mm_out_shape.dims(); + const int channel_dim = dims_cnt - 1; + + std::vector b_dims(dims_cnt, 1); + std::vector b_strides(dims_cnt, 0); + b_dims[channel_dim] = bias.dim_size(0); + b_strides[channel_dim] = 1; + + + mt_bias.SetNdInfo(dims_cnt, b_dims.data(), b_strides.data()); + + mBinary bias_add_op; + bias_add_op.SetMode(::musa::dnn::Binary::Mode::ADD); + + status = bias_add_op.Run(handle, mt_out, mt_mm_out, mt_bias); + + OP_REQUIRES(ctx, status == ::musa::dnn::Status::SUCCESS, + errors::Internal("MUSA BiasAdd failed in MusaMatMulBiasAdd")); + } + + bool IsExpensive() override { return true; } + + private: + bool trans_a_ = false; + bool trans_b_ = false; + bool tf32_enabled_ = false; +}; + +#define REGISTER_MUSA_MATMUL_BIASADD(TYPE) \ + REGISTER_KERNEL_BUILDER( \ + Name("MusaMatMulBiasAdd").Device("MUSA").TypeConstraint("T"), \ + MusaMatMulBiasAddOp) + +REGISTER_MUSA_MATMUL_BIASADD(float); +REGISTER_MUSA_MATMUL_BIASADD(Eigen::half); +REGISTER_MUSA_MATMUL_BIASADD(bfloat16); + +#undef REGISTER_MUSA_MATMUL_BIASADD + +} // namespace musa + +using shape_inference::DimensionHandle; +using shape_inference::InferenceContext; +using shape_inference::ShapeHandle; + +REGISTER_OP("MusaMatMulBiasAdd") + .Input("a: T") + .Input("b: T") + .Input("bias: T") + .Output("product: T") + .Attr("T: {float, half, bfloat16}") + .Attr("transpose_a: bool = false") + .Attr("transpose_b: bool = false") + .SetShapeFn(::tensorflow::shape_inference::MatMulShape); + +} // namespace tensorflow \ No newline at end of file diff --git a/musa_ext/mu/graph_fusion/matmul_biasadd_fusion.cc b/musa_ext/mu/graph_fusion/matmul_biasadd_fusion.cc new file mode 100644 index 00000000..8b7aca3b --- /dev/null +++ b/musa_ext/mu/graph_fusion/matmul_biasadd_fusion.cc @@ -0,0 +1,242 @@ +#include "mu/graph_fusion/matmul_biasadd_fusion.h" + +#include +#include + 
+#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { +namespace grappler { +namespace musa_fusion { + +namespace { + +// Helper to check if node has specific op type +bool IsOp(const NodeDef& node, const std::string& op_type) { + return node.op() == op_type; +} + +// Helper to find node's input producer +const NodeDef* FindProducer(const GraphDef& graph, const std::string& input) { + if (input.empty()) return nullptr; + + std::string node_name = input; + if (node_name[0] == '^') { + node_name = node_name.substr(1); + } + const size_t colon_pos = node_name.find(':'); + if (colon_pos != std::string::npos) { + node_name = node_name.substr(0, colon_pos); + } + + for (int i = 0; i < graph.node_size(); ++i) { + if (graph.node(i).name() == node_name) { + return &graph.node(i); + } + } + return nullptr; +} + +bool HasOriginalSuffix(const std::string& node_name) { + static const std::string kOriginalSuffix = "_original"; + return node_name.size() >= kOriginalSuffix.size() && + node_name.compare(node_name.size() - kOriginalSuffix.size(), + kOriginalSuffix.size(), kOriginalSuffix) == 0; +} + +} // namespace + +bool MatMulBiasAddFusion::IsKernelAvailable() const { + if (!kernel_checked_) { + kernel_available_ = true; + kernel_checked_ = true; + } + return kernel_available_; +} + +FusionMatchResult MatMulBiasAddFusion::Match(const GraphDef& graph, + int start_node_idx) const { + FusionMatchResult result; + if (start_node_idx < 0 || start_node_idx >= graph.node_size()) { + return result; + } + + const NodeDef& bias_add_node = graph.node(start_node_idx); + + // Start node must be BiasAdd / Add / AddV2 + if (!IsOp(bias_add_node, "BiasAdd")) { + return result; + } + + if (HasOriginalSuffix(bias_add_node.name())) { + return result; + } + + // Find MatMul node and bias node from the two inputs of BiasAdd/Add/AddV2 + const NodeDef* matmul_node = nullptr; + const NodeDef* bias_node = nullptr; + + if 
(bias_add_node.input_size() >= 2) { + const NodeDef* in0 = FindProducer(graph, bias_add_node.input(0)); + const NodeDef* in1 = FindProducer(graph, bias_add_node.input(1)); + + if (in0 && IsOp(*in0, "MatMul")) { + matmul_node = in0; + bias_node = in1; + } else if (in1 && IsOp(*in1, "MatMul")) { + matmul_node = in1; + bias_node = in0; + } + } + + if (!matmul_node || !bias_node) { + return result; + } + + // Record into result + result.matched = true; + result.matched_nodes.push_back(&bias_add_node); + result.matched_nodes.push_back(matmul_node); + + result.captured_nodes["output"] = &bias_add_node; + result.captured_nodes["bias_add"] = &bias_add_node; + result.captured_nodes["matmul"] = matmul_node; + result.captured_nodes["bias"] = bias_node; + + return result; +} + +Status MatMulBiasAddFusion::Apply( + GraphDef* graph, const FusionMatchResult& match_result) const { + if (!match_result.IsValid()) { + return Status(error::INVALID_ARGUMENT, + "Invalid MatMulBiasAdd match result"); + } + + if (!IsKernelAvailable()) { + return Status::OK(); + } + + auto output_it = match_result.captured_nodes.find("output"); + auto matmul_it = match_result.captured_nodes.find("matmul"); + auto bias_it = match_result.captured_nodes.find("bias"); + auto bias_add_it = match_result.captured_nodes.find("bias_add"); + + if (output_it == match_result.captured_nodes.end() || + matmul_it == match_result.captured_nodes.end() || + bias_it == match_result.captured_nodes.end() || + bias_add_it == match_result.captured_nodes.end()) { + return Status(error::INVALID_ARGUMENT, + "Missing required nodes in MatMulBiasAdd pattern"); + } + + const NodeDef* output_node = output_it->second; + const NodeDef* matmul_node = matmul_it->second; + const NodeDef* bias_node = bias_it->second; + const NodeDef* bias_add_node = bias_add_it->second; + + const std::string original_name = output_node->name(); + const std::string original_output_name = original_name + "_original"; + + // Avoid duplicate fusion + for (const 
auto& node : graph->node()) { + if (node.name() == original_name && node.op() == "MusaMatMulBiasAdd") { + VLOG(1) << "MusaMatMulBiasAdd: Output node " << original_name + << " is already a fused node, skipping"; + return Status::OK(); + } + } + + int output_node_idx = -1; + for (int i = 0; i < graph->node_size(); ++i) { + if (graph->node(i).name() == original_name) { + output_node_idx = i; + break; + } + } + + if (output_node_idx < 0) { + return Status(error::INVALID_ARGUMENT, + "Failed to find output node in graph: " + original_name); + } + + VLOG(1) << "MatMulBiasAddFusion: Replacing " << original_name + << " with MusaMatMulBiasAdd"; + + NodeDef* original_output_node = graph->mutable_node(output_node_idx); + const std::string output_device = original_output_node->device(); + + // Pick dtype from MatMul first, then output node, otherwise float + AttrValue output_dtype; + auto dtype_it = matmul_node->attr().find("T"); + if (dtype_it != matmul_node->attr().end()) { + output_dtype = dtype_it->second; + } else { + dtype_it = original_output_node->attr().find("T"); + if (dtype_it != original_output_node->attr().end()) { + output_dtype = dtype_it->second; + } else { + output_dtype.set_type(DT_FLOAT); + } + } + + // Rename the original output node + original_output_node->set_name(original_output_name); + + // Create fused node using the original output node name + NodeDef* fused_node = graph->add_node(); + fused_node->set_name(original_name); + fused_node->set_op("MusaMatMulBiasAdd"); + fused_node->set_device(output_device); + + // MusaMatMulBiasAdd inputs: a, b, bias + fused_node->add_input(matmul_node->input(0)); + fused_node->add_input(matmul_node->input(1)); + + // Keep original bias edge string exactly, do not replace with bias_node->name() + // because input could be "bias:0" instead of just "bias". + fused_node->add_input( + bias_add_node->input(bias_add_node->input(0) == matmul_node->name() ? 
1 + : 0)); + + auto* attr = fused_node->mutable_attr(); + (*attr)["T"] = output_dtype; + + if (matmul_node->attr().count("transpose_a")) { + (*attr)["transpose_a"] = matmul_node->attr().at("transpose_a"); + } else { + (*attr)["transpose_a"].set_b(false); + } + + if (matmul_node->attr().count("transpose_b")) { + (*attr)["transpose_b"] = matmul_node->attr().at("transpose_b"); + } else { + (*attr)["transpose_b"].set_b(false); + } + + // Remove matched nodes if now unused. + std::vector removable_names = { + original_output_name, matmul_node->name() + }; + + FusionGraphUtils::RemoveNodesIfUnused( + graph, removable_names, + {matmul_node->input(0), matmul_node->input(1), bias_node->name(), + original_name}); + + VLOG(1) << "MatMulBiasAddFusion: Successfully replaced '" << original_name + << "' with MusaMatMulBiasAdd"; + + return Status::OK(); +} + +// Register the pattern +REGISTER_FUSION_PATTERN(MatMulBiasAddFusion); + +// Register kernel availability +REGISTER_FUSION_KERNEL(MatMulBiasAddFusion, []() { return true; }); + +} // namespace musa_fusion +} // namespace grappler +} // namespace tensorflow \ No newline at end of file diff --git a/musa_ext/mu/graph_fusion/matmul_biasadd_fusion.h b/musa_ext/mu/graph_fusion/matmul_biasadd_fusion.h new file mode 100644 index 00000000..212dd25d --- /dev/null +++ b/musa_ext/mu/graph_fusion/matmul_biasadd_fusion.h @@ -0,0 +1,46 @@ +#pragma once + +#include +#include + +#include "mu/graph_fusion/fusion_pattern_manager.h" + +namespace tensorflow { +namespace grappler { +namespace musa_fusion { + +// Computes: MatMul + BiasAdd + +class MatMulBiasAddFusion : public FusionPattern { + public: + MatMulBiasAddFusion() = default; + ~MatMulBiasAddFusion() override = default; + + FusionMatchResult Match(const GraphDef& graph, + int start_node_idx) const override; + + Status Apply(GraphDef* graph, + const FusionMatchResult& match_result) const override; + + int GetPriority() const override { return 100; } + + bool IsKernelAvailable() const 
override;
+
+  std::string GetName() const override { return "MatMulBiasAddFusion"; }
+
+  std::string GetFallbackReason() const override {
+    if (!kernel_available_) {
+      return "MatMulBiasAddFusion kernel not available on this device";
+    }
+    return "";
+  }
+
+ private:
+  // Kernel availability flag
+  mutable bool kernel_available_ = true;
+  mutable bool kernel_checked_ = false;
+};
+
+}  // namespace musa_fusion
+}  // namespace grappler
+}  // namespace tensorflow
\ No newline at end of file
diff --git a/test/fusion/matmul_biasadd_fusion.py b/test/fusion/matmul_biasadd_fusion.py
new file mode 100644
index 00000000..f98d4c18
--- /dev/null
+++ b/test/fusion/matmul_biasadd_fusion.py
@@ -0,0 +1,337 @@
+# Copyright 2026 The TensorFlow MUSA Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Tests for MatMul+BiasAdd fusion."""
+
+import os
+import numpy as np
+import tensorflow as tf
+from musa_test_utils import MUSATestCase
+
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.core.protobuf import rewriter_config_pb2
+
+
+def create_config_with_musa_optimizer():
+    """Create ConfigProto with MUSA optimizer enabled."""
+    config = config_pb2.ConfigProto()
+    config.allow_soft_placement = True
+
+    rewriter_config = config.graph_options.rewrite_options
+
+    custom_optimizer = rewriter_config.custom_optimizers.add()
+    custom_optimizer.name = "musa_graph_optimizer"
+
+    rewriter_config.min_graph_nodes = -1
+    rewriter_config.optimizers.extend(["musa_graph_optimizer"])
+
+    return config
+
+
+class MatMulBiasAddFusionTest(MUSATestCase):
+    """Tests for MatMul+BiasAdd fusion."""
+
+    def test_matmul_biasadd_fusion_basic(self):
+        """Test MatMul+BiasAdd pattern fusion."""
+        np.random.seed(42)
+        tf.random.set_seed(42)
+
+        m, k, n = 4, 8, 16
+
+        x_np = np.random.randn(m, k).astype(np.float32)
+        w_np = np.random.randn(k, n).astype(np.float32)
+        b_np = np.random.randn(n).astype(np.float32)
+
+        # Reference implementation (CPU)
+        with tf.device('/CPU:0'):
+            x_tf = tf.constant(x_np)
+            w_tf = tf.constant(w_np)
+            b_tf = tf.constant(b_np)
+
+            mm = tf.matmul(x_tf, w_tf)
+            expected_out = tf.nn.bias_add(mm, b_tf)
+            # Add a consumer to ensure it's not pruned and has someone to redirect to
+            expected_out = expected_out * 2.0
+
+        # Build graph with explicit MUSA device placement
+        graph = tf.Graph()
+        with graph.as_default():
+            with tf.device('/device:MUSA:0'):
+                x = tf.compat.v1.placeholder(tf.float32, shape=[None, k], name="x")
+                w = tf.constant(w_np, dtype=tf.float32, name="w")
+                b = tf.constant(b_np, dtype=tf.float32, name="b")
+
+                # This pattern should be matched by MatMulBiasAddFusion
+                mm_musa = tf.matmul(x, w)
+                bias_musa = tf.nn.bias_add(mm_musa, b)
+                # Add a consumer node
+                output = bias_musa * 2.0
+
+        config = create_config_with_musa_optimizer()
+
+        with tf.compat.v1.Session(graph=graph, config=config) as sess:
+            actual_out = sess.run(output, feed_dict={x: x_np})
+
+        self.assertAllClose(actual_out, expected_out.numpy(), rtol=1e-5, atol=1e-5)
+
+    def test_matmul_biasadd_fusion_applied(self):
+        """Verify that MatMul+BiasAdd fusion is applied: MusaMatMulBiasAdd node exists in optimized graph."""
+        np.random.seed(123)
+        tf.random.set_seed(123)
+
+        m, k, n = 4, 8, 16
+        x_np = np.random.randn(m, k).astype(np.float32)
+        w_np = np.random.randn(k, n).astype(np.float32)
+        b_np = np.random.randn(n).astype(np.float32)
+
+        graph = tf.Graph()
+        with graph.as_default():
+            with tf.device('/device:MUSA:0'):
+                x = tf.compat.v1.placeholder(tf.float32, shape=[None, k], name="x")
+                w = tf.constant(w_np, dtype=tf.float32, name="w")
+                b = tf.constant(b_np, dtype=tf.float32, name="b")
+
+                mm_musa = tf.matmul(x, w)
+                bias_musa = tf.nn.bias_add(mm_musa, b)
+                # Add a consumer node
+                output = bias_musa * 2.0
+
+        config = create_config_with_musa_optimizer()
+        run_options = tf.compat.v1.RunOptions(output_partition_graphs=True)
+        run_metadata = tf.compat.v1.RunMetadata()
+
+        with tf.compat.v1.Session(graph=graph, config=config) as sess:
+            sess.run(output, feed_dict={x: x_np}, options=run_options, run_metadata=run_metadata)
+
+        has_fused_node = False
+        for partition_graph in run_metadata.partition_graphs:
+            for node in partition_graph.node:
+                if node.op == "MusaMatMulBiasAdd":
+                    has_fused_node = True
+                    # NOTE(review): this break only exits the inner loop;
+                    # harmless since has_fused_node is already latched.
+                    break
+
+        self.assertTrue(
+            has_fused_node,
+            "MusaMatMulBiasAdd fusion was NOT applied to the graph"
+        )
+
+    def test_matmul_biasadd_fusion_various_batch_sizes(self):
+        """Test fusion correctness across several batch sizes."""
+        np.random.seed(7)
+        tf.random.set_seed(7)
+
+        k, n = 6, 10
+        w_np = np.random.randn(k, n).astype(np.float32)
+        b_np = np.random.randn(n).astype(np.float32)
+
+        for m in (1, 3, 8):
+            x_np = np.random.randn(m, k).astype(np.float32)
+
+            # Reference on CPU
+            with tf.device('/CPU:0'):
+                x_tf = tf.constant(x_np)
+                w_tf = tf.constant(w_np)
+                b_tf = tf.constant(b_np)
+                expected = tf.nn.bias_add(tf.matmul(x_tf, w_tf), b_tf) * 1.5
+
+            # MUSA graph
+            graph = tf.Graph()
+            with graph.as_default():
+                with tf.device('/device:MUSA:0'):
+                    x = tf.compat.v1.placeholder(tf.float32, shape=[None, k], name="x_bs")
+                    w = tf.constant(w_np, dtype=tf.float32, name="w_bs")
+                    b = tf.constant(b_np, dtype=tf.float32, name="b_bs")
+
+                    mm = tf.matmul(x, w)
+                    bias = tf.nn.bias_add(mm, b)
+                    # extra consumer
+                    out = bias * 1.5
+
+            config = create_config_with_musa_optimizer()
+            with tf.compat.v1.Session(graph=graph, config=config) as sess:
+                actual = sess.run(out, feed_dict={x: x_np})
+
+            self.assertAllClose(actual, expected.numpy(), rtol=1e-5, atol=1e-5)
+
+    def test_matmul_biasadd_fusion_not_applied_with_intervening_op(self):
+        """If an extra op exists between MatMul and BiasAdd, fusion should not occur."""
+        np.random.seed(99)
+        tf.random.set_seed(99)
+
+        m, k, n = 2, 5, 7
+        x_np = np.random.randn(m, k).astype(np.float32)
+        w_np = np.random.randn(k, n).astype(np.float32)
+        b_np = np.random.randn(n).astype(np.float32)
+
+        graph = tf.Graph()
+        with graph.as_default():
+            with tf.device('/device:MUSA:0'):
+                x = tf.compat.v1.placeholder(tf.float32, shape=[None, k], name="x_int")
+                w = tf.constant(w_np, dtype=tf.float32, name="w_int")
+                b = tf.constant(b_np, dtype=tf.float32, name="b_int")
+
+                mm = tf.matmul(x, w)
+                # Insert an identity (or any intervening op) to block fusion
+                mid = tf.identity(mm, name="intervening_identity")
+                bias = tf.nn.bias_add(mid, b)
+                output = bias * 2.0
+
+        config = create_config_with_musa_optimizer()
+        run_options = tf.compat.v1.RunOptions(output_partition_graphs=True)
+        run_metadata = tf.compat.v1.RunMetadata()
+
+        with tf.compat.v1.Session(graph=graph, config=config) as sess:
+            sess.run(output, feed_dict={x: x_np}, options=run_options, run_metadata=run_metadata)
+
+        has_fused_node = False
+        for partition_graph in run_metadata.partition_graphs:
+            for node in partition_graph.node:
+                if node.op == "MusaMatMulBiasAdd":
+                    has_fused_node = True
+                    break
+
+        self.assertFalse(
+            has_fused_node,
+            "MusaMatMulBiasAdd fusion should NOT be applied when an intervening op exists"
+        )
+
+    def test_matmul_biasadd_fusion_dtypes(self):
+        """Test fusion correctness across multiple dtypes: float32, float16, bfloat16."""
+        np.random.seed(21)
+        tf.random.set_seed(21)
+
+        m, k, n = 3, 6, 8
+        w_np = np.random.randn(k, n).astype(np.float32)
+        b_np = np.random.randn(n).astype(np.float32)
+
+        dtypes = [tf.float32, tf.float16, tf.bfloat16]
+
+        for dtype in dtypes:
+            x_np = np.random.randn(m, k).astype(np.float32)
+
+            # Reference computed in float32
+            with tf.device('/CPU:0'):
+                x_tf = tf.constant(x_np, dtype=tf.float32)
+                w_tf = tf.constant(w_np, dtype=tf.float32)
+                b_tf = tf.constant(b_np, dtype=tf.float32)
+                expected = tf.nn.bias_add(tf.matmul(x_tf, w_tf), b_tf) * 0.75
+                expected_f32 = expected.numpy()
+
+            # Build MUSA graph: accept float32 feeds then cast to target dtype inside graph
+            graph = tf.Graph()
+            with graph.as_default():
+                with tf.device('/device:MUSA:0'):
+                    x_ph = tf.compat.v1.placeholder(tf.float32, shape=[None, k], name="x_dt")
+                    x = tf.cast(x_ph, dtype)
+                    w = tf.constant(w_np, dtype=dtype, name="w_dt")
+                    b = tf.constant(b_np, dtype=dtype, name="b_dt")
+
+                    mm = tf.matmul(x, w)
+                    bias = tf.nn.bias_add(mm, b)
+                    out = bias * tf.constant(0.75, dtype=dtype)
+                    # cast back to float32 for stable comparison
+                    out_f32 = tf.cast(out, tf.float32)
+
+            config = create_config_with_musa_optimizer()
+            with tf.compat.v1.Session(graph=graph, config=config) as sess:
+                actual = sess.run(out_f32, feed_dict={x_ph: x_np})
+
+            if dtype == tf.float32:
+                rtol, atol = 1e-5, 1e-5
+            elif dtype == tf.float16:
+                rtol, atol = 1e-2, 1e-2
+            else:  # bfloat16
+                rtol, atol = 2e-2, 2e-2
+
+            self.assertAllClose(actual, expected_f32, rtol=rtol, atol=atol)
+
+    def test_matmul_biasadd_fusion_large_features(self):
+        """Optional large-feature test. Enable by setting MUSA_RUN_LARGE_TESTS=1."""
+        if not os.environ.get("MUSA_RUN_LARGE_TESTS"):
+            self.skipTest("Large tests disabled; set MUSA_RUN_LARGE_TESTS=1 to run")
+
+        np.random.seed(321)
+        tf.random.set_seed(321)
+
+        m, k, n = 128, 2048, 1024
+        x_np = np.random.randn(m, k).astype(np.float32)
+        w_np = np.random.randn(k, n).astype(np.float32)
+        b_np = np.random.randn(n).astype(np.float32)
+
+        # Reference on CPU
+        with tf.device('/CPU:0'):
+            expected = tf.nn.bias_add(
+                tf.matmul(tf.constant(x_np), tf.constant(w_np)),
+                tf.constant(b_np)
+            ) * 0.9
+
+        # MUSA graph
+        graph = tf.Graph()
+        with graph.as_default():
+            with tf.device('/device:MUSA:0'):
+                x = tf.compat.v1.placeholder(tf.float32, shape=[None, k], name="x_large_feat")
+                w = tf.constant(w_np, dtype=tf.float32, name="w_large_feat")
+                b = tf.constant(b_np, dtype=tf.float32, name="b_large_feat")
+
+                mm = tf.matmul(x, w)
+                bias = tf.nn.bias_add(mm, b)
+                out = bias * 0.9
+
+        config = create_config_with_musa_optimizer()
+        with tf.compat.v1.Session(graph=graph, config=config) as sess:
+            actual = sess.run(out, feed_dict={x: x_np})
+
+        self.assertAllClose(actual, expected.numpy(), rtol=1e-4, atol=1e-4)
+
+    def test_matmul_biasadd_fusion_large_batch(self):
+        """Optional large-batch test. Enable by setting MUSA_RUN_LARGE_TESTS=1."""
+        if not os.environ.get("MUSA_RUN_LARGE_TESTS"):
+            self.skipTest("Large tests disabled; set MUSA_RUN_LARGE_TESTS=1 to run")
+
+        np.random.seed(123)
+        tf.random.set_seed(123)
+
+        m, k, n = 2048, 512, 512
+        x_np = np.random.randn(m, k).astype(np.float32)
+        w_np = np.random.randn(k, n).astype(np.float32)
+        b_np = np.random.randn(n).astype(np.float32)
+
+        # Reference on CPU
+        with tf.device('/CPU:0'):
+            expected = tf.nn.bias_add(
+                tf.matmul(tf.constant(x_np), tf.constant(w_np)),
+                tf.constant(b_np)
+            )
+
+        # MUSA graph
+        graph = tf.Graph()
+        with graph.as_default():
+            with tf.device('/device:MUSA:0'):
+                x = tf.compat.v1.placeholder(tf.float32, shape=[None, k], name="x_large_batch")
+                w = tf.constant(w_np, dtype=tf.float32, name="w_large_batch")
+                b = tf.constant(b_np, dtype=tf.float32, name="b_large_batch")
+
+                mm = tf.matmul(x, w)
+                out = tf.nn.bias_add(mm, b)
+
+        config = create_config_with_musa_optimizer()
+        with tf.compat.v1.Session(graph=graph, config=config) as sess:
+            actual = sess.run(out, feed_dict={x: x_np})
+
+        self.assertAllClose(actual, expected.numpy(), rtol=1e-4, atol=1e-4)
+
+
+if __name__ == "__main__":
+    tf.test.main()
\ No newline at end of file
From 5eb6d6039b0baac0cf9feae1936672c74fa79211 Mon Sep 17 00:00:00 2001
From: yaotianhang
Date: Wed, 25 Mar 2026 10:07:13 +0800
Subject: [PATCH 02/10] MatmulBias unfinished

---
 musa_ext/kernels/nn/musa_matmul_bias_op.txt | 163 +++++++++
 .../kernels/nn/musa_matmul_bias_op_ori.cc | 205 +++++++++++
 test/fusion/matmul_biasadd_fusion_test.py | 337 ++++++++++++++++++
 3 files changed, 705 insertions(+)
 create mode 100644 musa_ext/kernels/nn/musa_matmul_bias_op.txt
 create mode 100644 musa_ext/kernels/nn/musa_matmul_bias_op_ori.cc
 create mode 100644 test/fusion/matmul_biasadd_fusion_test.py

diff --git a/musa_ext/kernels/nn/musa_matmul_bias_op.txt b/musa_ext/kernels/nn/musa_matmul_bias_op.txt
new file mode 100644
index 00000000..758a5321
--- /dev/null
+++ 
b/musa_ext/kernels/nn/musa_matmul_bias_op.txt @@ -0,0 +1,163 @@ +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "../utils_op.h" + +namespace tensorflow { +namespace musa { + +template +class MusaMatMulBiasAddOp : public MusaOpKernel { + public: + explicit MusaMatMulBiasAddOp(OpKernelConstruction* ctx) : MusaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_a", &transpose_a_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_b", &transpose_b_)); + } + + bool IsExpensive() override { return true; } + + void Compute(OpKernelContext* ctx) override { + const Tensor& a = ctx->input(0); + const Tensor& b = ctx->input(1); + const Tensor& bias = ctx->input(2); + + OP_REQUIRES(ctx, a.dims() == 2, + errors::InvalidArgument("MatMulBiasAdd requires input a to be 2D, got shape ", + a.shape().DebugString())); + OP_REQUIRES(ctx, b.dims() == 2, + errors::InvalidArgument("MatMulBiasAdd requires input b to be 2D, got shape ", + b.shape().DebugString())); + OP_REQUIRES(ctx, bias.dims() == 1, + errors::InvalidArgument("MatMulBiasAdd requires bias to be 1D, got shape ", + bias.shape().DebugString())); + + if (a.NumElements() == 0 || b.NumElements() == 0 || bias.NumElements() == 0) { + TensorShape out_shape; + OP_REQUIRES_OK(ctx, ComputeOutputShape(a, b, &out_shape)); + + Tensor* output = nullptr; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &output)); + return; + } + + const int64_t a_rows = a.dim_size(0); + const int64_t a_cols = a.dim_size(1); + const int64_t b_rows = b.dim_size(0); + const int64_t b_cols = b.dim_size(1); + + const int64_t m = transpose_a_ ? a_cols : a_rows; + const int64_t k_a = transpose_a_ ? a_rows : a_cols; + const int64_t k_b = transpose_b_ ? b_cols : b_rows; + const int64_t n = transpose_b_ ? 
b_rows : b_cols; + + OP_REQUIRES( + ctx, k_a == k_b, + errors::InvalidArgument("Matrix size-incompatible: a shape ", + a.shape().DebugString(), ", b shape ", + b.shape().DebugString(), ", transpose_a=", + transpose_a_, ", transpose_b=", transpose_b_)); + + OP_REQUIRES( + ctx, bias.dim_size(0) == n, + errors::InvalidArgument("Bias dimension mismatch: bias shape ", + bias.shape().DebugString(), + ", expected [", n, "]")); + + TensorShape out_shape({m, n}); + Tensor* output = nullptr; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &output)); + + auto& handle = GetHandleByCtx(ctx); + + mTensor mt_a = CreateMTensor(a, format_); + mTensor mt_b = CreateMTensor(b, format_); + mTensor mt_bias = CreateMTensor(bias, format_); + mTensor mt_out = CreateMTensor(*output, format_); + + ::musa::dnn::MatMul op; + auto status = op.SetTranspose(transpose_a_, transpose_b_); + OP_REQUIRES(ctx, status == ::musa::dnn::Status::SUCCESS, + errors::Internal("muDNN MatMul SetTranspose failed, status=", + static_cast(status))); + + status = op.SetAlpha(1.0); + OP_REQUIRES(ctx, status == ::musa::dnn::Status::SUCCESS, + errors::Internal("muDNN MatMul SetAlpha failed, status=", + static_cast(status))); + + status = op.SetBeta(0.0); + OP_REQUIRES(ctx, status == ::musa::dnn::Status::SUCCESS, + errors::Internal("muDNN MatMul SetBeta failed, status=", + static_cast(status))); + + status = op.SetGamma(1.0); + OP_REQUIRES(ctx, status == ::musa::dnn::Status::SUCCESS, + errors::Internal("muDNN MatMul SetGamma failed, status=", + static_cast(status))); + + + status = op.RunWithBiasAdd(handle, mt_out, mt_a, mt_b, mt_bias); + + OP_REQUIRES(ctx, status == ::musa::dnn::Status::SUCCESS, + errors::Internal("muDNN MatMulBiasAdd failed, status=", + static_cast(status))); + } + + private: + Status ComputeOutputShape(const Tensor& a, const Tensor& b, + TensorShape* out_shape) { + const int64_t a_rows = a.dim_size(0); + const int64_t a_cols = a.dim_size(1); + const int64_t b_rows = b.dim_size(0); + const 
int64_t b_cols = b.dim_size(1); + + const int64_t m = transpose_a_ ? a_cols : a_rows; + const int64_t k_a = transpose_a_ ? a_rows : a_cols; + const int64_t k_b = transpose_b_ ? b_cols : b_rows; + const int64_t n = transpose_b_ ? b_rows : b_cols; + + if (k_a != k_b) { + return errors::InvalidArgument( + "Matrix size-incompatible: a shape ", a.shape().DebugString(), + ", b shape ", b.shape().DebugString(), + ", transpose_a=", transpose_a_, + ", transpose_b=", transpose_b_); + } + + *out_shape = TensorShape({m, n}); + return Status::OK(); + } + + private: + bool transpose_a_; + bool transpose_b_; +}; + +#define REGISTER_MUSA_MATMUL_BIASADD(TYPE) \ + REGISTER_KERNEL_BUILDER( \ + Name("MusaMatMulBiasAdd").Device("MUSA").TypeConstraint("T"), \ + MusaMatMulBiasAddOp); + +REGISTER_MUSA_MATMUL_BIASADD(float); +// REGISTER_MUSA_MATMUL_BIASADD(double); +REGISTER_MUSA_MATMUL_BIASADD(Eigen::half); +REGISTER_MUSA_MATMUL_BIASADD(bfloat16); + +#undef REGISTER_MUSA_MATMUL_BIASADD + +} // namespace musa + +using shape_inference::DimensionHandle; +using shape_inference::InferenceContext; +using shape_inference::ShapeHandle; + +REGISTER_OP("MusaMatMulBiasAdd") + .Input("a: T") + .Input("b: T") + .Input("bias: T") + .Output("product: T") + .Attr("T: {float, half, bfloat16}") + .Attr("transpose_a: bool = false") + .Attr("transpose_b: bool = false") + .SetShapeFn(::tensorflow::shape_inference::MatMulShape); +} // namespace tensorflow diff --git a/musa_ext/kernels/nn/musa_matmul_bias_op_ori.cc b/musa_ext/kernels/nn/musa_matmul_bias_op_ori.cc new file mode 100644 index 00000000..76483c68 --- /dev/null +++ b/musa_ext/kernels/nn/musa_matmul_bias_op_ori.cc @@ -0,0 +1,205 @@ +#include +#include +#include + +#include "../utils_op.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/util/matmul_bcast.h" +#include "utils/logging.h" + +namespace tensorflow { +namespace musa 
{ + +template +class MusaMatMulBiasAddOp : public MusaOpKernel { + public: + explicit MusaMatMulBiasAddOp(OpKernelConstruction* ctx) : MusaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_a", &trans_a_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_b", &trans_b_)); + } + + void Compute(OpKernelContext* ctx) override { + MUSA_KERNEL_TIMING_GUARD(ctx); + + const Tensor& in0 = ctx->input(0); // a + const Tensor& in1 = ctx->input(1); // b + const Tensor& bias = ctx->input(2); // bias + + OP_REQUIRES( + ctx, TensorShapeUtils::IsVector(bias.shape()), + errors::InvalidArgument("bias must be a 1-D tensor, but got shape: ", + bias.shape().DebugString())); + + // ---------------------------- + // 1. Infer MatMul output shape + // ---------------------------- + MatMulBCast bcast(in0.shape().dim_sizes(), in1.shape().dim_sizes()); + OP_REQUIRES(ctx, bcast.IsValid(), + errors::InvalidArgument("Incompatible shapes for MatMul: ", + in0.shape().DebugString(), " vs ", + in1.shape().DebugString())); + + OP_REQUIRES(ctx, in0.dims() >= 2, + errors::InvalidArgument( + "Input a must have rank >= 2, got rank ", in0.dims())); + OP_REQUIRES(ctx, in1.dims() >= 2, + errors::InvalidArgument( + "Input b must have rank >= 2, got rank ", in1.dims())); + + int64 d0 = in0.dim_size(in0.dims() - 2); + int64 d1 = in0.dim_size(in0.dims() - 1); + int64 d2 = in1.dim_size(in1.dims() - 2); + int64 d3 = in1.dim_size(in1.dims() - 1); + + int64 m = trans_a_ ? d1 : d0; + int64 k = trans_a_ ? d0 : d1; + int64 n = trans_b_ ? d2 : d3; + int64 k_check = trans_b_ ? d3 : d2; + + OP_REQUIRES(ctx, k == k_check, + errors::InvalidArgument( + "Matrix size incompatible: lhs k=", k, ", rhs k=", k_check, + ", lhs shape=", in0.shape().DebugString(), + ", rhs shape=", in1.shape().DebugString())); + + OP_REQUIRES( + ctx, bias.dim_size(0) == n, + errors::InvalidArgument( + "Bias size mismatch: bias.shape[0] must equal MatMul output's last " + "dimension. 
Got bias size = ", + bias.dim_size(0), ", expected = ", n)); + + TensorShape out_shape = bcast.output_batch_shape(); + out_shape.AddDim(m); + out_shape.AddDim(n); + + TensorShape mm_out_shape = out_shape; + + // ----------------------------------- + // 2. Allocate temp for MatMul result + // ----------------------------------- + Tensor mm_out_tensor; + OP_REQUIRES_OK( + ctx, ctx->allocate_temp(in0.dtype(), mm_out_shape, &mm_out_tensor)); + + Tensor* output = nullptr; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &output)); + + if (output->NumElements() == 0) { + return; + } + + // ---------------------------- + // 3. mudnn MatMul / BatchMatMul + // ---------------------------- + auto& handle = GetHandleByCtx(ctx); + handle.SetAllowTF32(tf32_enabled_); + + mTensor mt_a = CreateMTensor(in0); + mTensor mt_b = CreateMTensor(in1); + mTensor mt_mm_out = CreateMTensor(mm_out_tensor); + + ::musa::dnn::Status status; + + if (in0.dims() == 2 && in1.dims() == 2) { + mMatMul op; + op.SetTranspose(trans_a_, trans_b_); + op.SetAlpha(1.0); + op.SetBeta(0.0); + + status = op.Run(handle, mt_mm_out, mt_a, mt_b); + } else { + mBatchMatMul op; + op.SetTranspose(trans_a_, trans_b_); + op.SetAlpha(1.0); + op.SetBeta(0.0); + + int64_t out_batch = bcast.output_batch_shape().num_elements(); + + auto ReshapeTo3D = [out_batch](mTensor& mt, const Tensor& t) { + int64_t dims = t.dims(); + int64_t rows = t.dim_size(dims - 2); + int64_t cols = t.dim_size(dims - 1); + int64_t batch = t.NumElements() / (rows * cols); + + if (dims != 3 || (batch == 1 && out_batch > 1)) { + mt.SetNdInfo( + {batch == 1 && out_batch > 1 ? out_batch : batch, rows, cols}, + {batch == 1 && out_batch > 1 ? 
0 : rows * cols, cols, 1}); + } + }; + + ReshapeTo3D(mt_a, in0); + ReshapeTo3D(mt_b, in1); + mt_mm_out.SetNdInfo({out_batch, m, n}, {m * n, n, 1}); + + status = op.Run(handle, mt_mm_out, mt_a, mt_b); + } + + OP_REQUIRES(ctx, status == ::musa::dnn::Status::SUCCESS, + errors::Internal( + "MUSA MatMul/BatchMatMul failed in MusaMatMulBiasAdd")); + + // ---------------------------- + // 4. mudnn BiasAdd + // ---------------------------- + mTensor mt_bias = CreateMTensor(bias); + mTensor mt_out = CreateMTensor(*output); + + const int dims_cnt = mm_out_shape.dims(); + const int channel_dim = dims_cnt - 1; + + std::vector b_dims(dims_cnt, 1); + std::vector b_strides(dims_cnt, 0); + b_dims[channel_dim] = bias.dim_size(0); + b_strides[channel_dim] = 1; + + mt_bias.SetNdInfo(dims_cnt, b_dims.data(), b_strides.data()); + + mBinary bias_add_op; + bias_add_op.SetMode(::musa::dnn::Binary::Mode::ADD); + + status = bias_add_op.Run(handle, mt_out, mt_mm_out, mt_bias); + + OP_REQUIRES(ctx, status == ::musa::dnn::Status::SUCCESS, + errors::Internal("MUSA BiasAdd failed in MusaMatMulBiasAdd")); + } + + bool IsExpensive() override { return true; } + + private: + bool trans_a_ = false; + bool trans_b_ = false; + bool tf32_enabled_ = false; +}; + +#define REGISTER_MUSA_MATMUL_BIASADD(TYPE) \ + REGISTER_KERNEL_BUILDER( \ + Name("MusaMatMulBiasAdd").Device("MUSA").TypeConstraint("T"), \ + MusaMatMulBiasAddOp) + +REGISTER_MUSA_MATMUL_BIASADD(float); +REGISTER_MUSA_MATMUL_BIASADD(Eigen::half); +REGISTER_MUSA_MATMUL_BIASADD(bfloat16); + +#undef REGISTER_MUSA_MATMUL_BIASADD + +} // namespace musa + +using shape_inference::DimensionHandle; +using shape_inference::InferenceContext; +using shape_inference::ShapeHandle; + +REGISTER_OP("MusaMatMulBiasAdd") + .Input("a: T") + .Input("b: T") + .Input("bias: T") + .Output("product: T") + .Attr("T: {float, half, bfloat16}") + .Attr("transpose_a: bool = false") + .Attr("transpose_b: bool = false") + 
.SetShapeFn(::tensorflow::shape_inference::MatMulShape); + +} // namespace tensorflow diff --git a/test/fusion/matmul_biasadd_fusion_test.py b/test/fusion/matmul_biasadd_fusion_test.py new file mode 100644 index 00000000..7ce2e7bf --- /dev/null +++ b/test/fusion/matmul_biasadd_fusion_test.py @@ -0,0 +1,337 @@ +# Copyright 2026 The TensorFlow MUSA Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Tests for MatMul+BiasAdd fusion.""" + +import os +import numpy as np +import tensorflow as tf +from musa_test_utils import MUSATestCase + +from tensorflow.core.protobuf import config_pb2 +from tensorflow.core.protobuf import rewriter_config_pb2 + + +def create_config_with_musa_optimizer(): + """Create ConfigProto with MUSA optimizer enabled.""" + config = config_pb2.ConfigProto() + config.allow_soft_placement = True + + rewriter_config = config.graph_options.rewrite_options + + custom_optimizer = rewriter_config.custom_optimizers.add() + custom_optimizer.name = "musa_graph_optimizer" + + rewriter_config.min_graph_nodes = -1 + rewriter_config.optimizers.extend(["musa_graph_optimizer"]) + + return config + + +class MatMulBiasAddFusionTest(MUSATestCase): + """Tests for MatMul+BiasAdd fusion.""" + + def test_matmul_biasadd_fusion_basic(self): + """Test MatMul+BiasAdd pattern fusion.""" + np.random.seed(42) + tf.random.set_seed(42) + + m, k, n = 4, 8, 16 + + x_np = 
np.random.randn(m, k).astype(np.float32) + w_np = np.random.randn(k, n).astype(np.float32) + b_np = np.random.randn(n).astype(np.float32) + + # Reference implementation (CPU) + with tf.device('/CPU:0'): + x_tf = tf.constant(x_np) + w_tf = tf.constant(w_np) + b_tf = tf.constant(b_np) + + mm = tf.matmul(x_tf, w_tf) + expected_out = tf.nn.bias_add(mm, b_tf) + # Add a consumer to ensure it's not pruned and has someone to redirect to + expected_out = expected_out * 2.0 + + # Build graph with explicit MUSA device placement + graph = tf.Graph() + with graph.as_default(): + with tf.device('/device:MUSA:0'): + x = tf.compat.v1.placeholder(tf.float32, shape=[None, k], name="x") + w = tf.constant(w_np, dtype=tf.float32, name="w") + b = tf.constant(b_np, dtype=tf.float32, name="b") + + # This pattern should be matched by MatMulBiasAddFusion + mm_musa = tf.matmul(x, w) + bias_musa = tf.nn.bias_add(mm_musa, b) + # Add a consumer node + output = bias_musa * 2.0 + + config = create_config_with_musa_optimizer() + + with tf.compat.v1.Session(graph=graph, config=config) as sess: + actual_out = sess.run(output, feed_dict={x: x_np}) + + self.assertAllClose(actual_out, expected_out.numpy(), rtol=1e-5, atol=1e-5) + + def test_matmul_biasadd_fusion_applied(self): + """Verify that MatMul+BiasAdd fusion is applied: MusaMatMulBiasAdd node exists in optimized graph.""" + np.random.seed(123) + tf.random.set_seed(123) + + m, k, n = 4, 8, 16 + x_np = np.random.randn(m, k).astype(np.float32) + w_np = np.random.randn(k, n).astype(np.float32) + b_np = np.random.randn(n).astype(np.float32) + + graph = tf.Graph() + with graph.as_default(): + with tf.device('/device:MUSA:0'): + x = tf.compat.v1.placeholder(tf.float32, shape=[None, k], name="x") + w = tf.constant(w_np, dtype=tf.float32, name="w") + b = tf.constant(b_np, dtype=tf.float32, name="b") + + mm_musa = tf.matmul(x, w) + bias_musa = tf.nn.bias_add(mm_musa, b) + # Add a consumer node + output = bias_musa * 2.0 + + config = 
create_config_with_musa_optimizer() + run_options = tf.compat.v1.RunOptions(output_partition_graphs=True) + run_metadata = tf.compat.v1.RunMetadata() + + with tf.compat.v1.Session(graph=graph, config=config) as sess: + sess.run(output, feed_dict={x: x_np}, options=run_options, run_metadata=run_metadata) + + has_fused_node = False + for partition_graph in run_metadata.partition_graphs: + for node in partition_graph.node: + if node.op == "MusaMatMulBiasAdd": + has_fused_node = True + break + + self.assertTrue( + has_fused_node, + "MusaMatMulBiasAdd fusion was NOT applied to the graph" + ) + + def test_matmul_biasadd_fusion_various_batch_sizes(self): + """Test fusion correctness across several batch sizes.""" + np.random.seed(7) + tf.random.set_seed(7) + + k, n = 6, 10 + w_np = np.random.randn(k, n).astype(np.float32) + b_np = np.random.randn(n).astype(np.float32) + + for m in (1, 3, 8): + x_np = np.random.randn(m, k).astype(np.float32) + + # Reference on CPU + with tf.device('/CPU:0'): + x_tf = tf.constant(x_np) + w_tf = tf.constant(w_np) + b_tf = tf.constant(b_np) + expected = tf.nn.bias_add(tf.matmul(x_tf, w_tf), b_tf) * 1.5 + + # MUSA graph + graph = tf.Graph() + with graph.as_default(): + with tf.device('/device:MUSA:0'): + x = tf.compat.v1.placeholder(tf.float32, shape=[None, k], name="x_bs") + w = tf.constant(w_np, dtype=tf.float32, name="w_bs") + b = tf.constant(b_np, dtype=tf.float32, name="b_bs") + + mm = tf.matmul(x, w) + bias = tf.nn.bias_add(mm, b) + # extra consumer + out = bias * 1.5 + + config = create_config_with_musa_optimizer() + with tf.compat.v1.Session(graph=graph, config=config) as sess: + actual = sess.run(out, feed_dict={x: x_np}) + + self.assertAllClose(actual, expected.numpy(), rtol=1e-5, atol=1e-5) + + def test_matmul_biasadd_fusion_not_applied_with_intervening_op(self): + """If an extra op exists between MatMul and BiasAdd, fusion should not occur.""" + np.random.seed(99) + tf.random.set_seed(99) + + m, k, n = 2, 5, 7 + x_np = 
np.random.randn(m, k).astype(np.float32) + w_np = np.random.randn(k, n).astype(np.float32) + b_np = np.random.randn(n).astype(np.float32) + + graph = tf.Graph() + with graph.as_default(): + with tf.device('/device:MUSA:0'): + x = tf.compat.v1.placeholder(tf.float32, shape=[None, k], name="x_int") + w = tf.constant(w_np, dtype=tf.float32, name="w_int") + b = tf.constant(b_np, dtype=tf.float32, name="b_int") + + mm = tf.matmul(x, w) + # Insert an identity (or any intervening op) to block fusion + mid = tf.identity(mm, name="intervening_identity") + bias = tf.nn.bias_add(mid, b) + output = bias * 2.0 + + config = create_config_with_musa_optimizer() + run_options = tf.compat.v1.RunOptions(output_partition_graphs=True) + run_metadata = tf.compat.v1.RunMetadata() + + with tf.compat.v1.Session(graph=graph, config=config) as sess: + sess.run(output, feed_dict={x: x_np}, options=run_options, run_metadata=run_metadata) + + has_fused_node = False + for partition_graph in run_metadata.partition_graphs: + for node in partition_graph.node: + if node.op == "MusaMatMulBiasAdd": + has_fused_node = True + break + + self.assertFalse( + has_fused_node, + "MusaMatMulBiasAdd fusion should NOT be applied when an intervening op exists" + ) + + def test_matmul_biasadd_fusion_dtypes(self): + """Test fusion correctness across multiple dtypes: float32, float16, bfloat16.""" + np.random.seed(21) + tf.random.set_seed(21) + + m, k, n = 3, 6, 8 + w_np = np.random.randn(k, n).astype(np.float32) + b_np = np.random.randn(n).astype(np.float32) + + dtypes = [tf.float32, tf.float16, tf.bfloat16] + + for dtype in dtypes: + x_np = np.random.randn(m, k).astype(np.float32) + + # Reference computed in float32 + with tf.device('/CPU:0'): + x_tf = tf.constant(x_np, dtype=tf.float32) + w_tf = tf.constant(w_np, dtype=tf.float32) + b_tf = tf.constant(b_np, dtype=tf.float32) + expected = tf.nn.bias_add(tf.matmul(x_tf, w_tf), b_tf) * 0.75 + expected_f32 = expected.numpy() + + # Build MUSA graph: accept float32 
feeds then cast to target dtype inside graph + graph = tf.Graph() + with graph.as_default(): + with tf.device('/device:MUSA:0'): + x_ph = tf.compat.v1.placeholder(tf.float32, shape=[None, k], name="x_dt") + x = tf.cast(x_ph, dtype) + w = tf.constant(w_np, dtype=dtype, name="w_dt") + b = tf.constant(b_np, dtype=dtype, name="b_dt") + + mm = tf.matmul(x, w) + bias = tf.nn.bias_add(mm, b) + out = bias * tf.constant(0.75, dtype=dtype) + # cast back to float32 for stable comparison + out_f32 = tf.cast(out, tf.float32) + + config = create_config_with_musa_optimizer() + with tf.compat.v1.Session(graph=graph, config=config) as sess: + actual = sess.run(out_f32, feed_dict={x_ph: x_np}) + + if dtype == tf.float32: + rtol, atol = 1e-5, 1e-5 + elif dtype == tf.float16: + rtol, atol = 1e-2, 1e-2 + else: # bfloat16 + rtol, atol = 2e-2, 2e-2 + + self.assertAllClose(actual, expected_f32, rtol=rtol, atol=atol) + + def test_matmul_biasadd_fusion_large_features(self): + """Optional large-feature test. Enable by setting MUSA_RUN_LARGE_TESTS=1.""" + if not os.environ.get("MUSA_RUN_LARGE_TESTS"): + self.skipTest("Large tests disabled; set MUSA_RUN_LARGE_TESTS=1 to run") + + np.random.seed(321) + tf.random.set_seed(321) + + m, k, n = 128, 2048, 1024 + x_np = np.random.randn(m, k).astype(np.float32) + w_np = np.random.randn(k, n).astype(np.float32) + b_np = np.random.randn(n).astype(np.float32) + + # Reference on CPU + with tf.device('/CPU:0'): + expected = tf.nn.bias_add( + tf.matmul(tf.constant(x_np), tf.constant(w_np)), + tf.constant(b_np) + ) * 0.9 + + # MUSA graph + graph = tf.Graph() + with graph.as_default(): + with tf.device('/device:MUSA:0'): + x = tf.compat.v1.placeholder(tf.float32, shape=[None, k], name="x_large_feat") + w = tf.constant(w_np, dtype=tf.float32, name="w_large_feat") + b = tf.constant(b_np, dtype=tf.float32, name="b_large_feat") + + mm = tf.matmul(x, w) + bias = tf.nn.bias_add(mm, b) + out = bias * 0.9 + + config = create_config_with_musa_optimizer() + with 
tf.compat.v1.Session(graph=graph, config=config) as sess: + actual = sess.run(out, feed_dict={x: x_np}) + + self.assertAllClose(actual, expected.numpy(), rtol=1e-4, atol=1e-4) + + def test_matmul_biasadd_fusion_large_batch(self): + """Optional large-batch test. Enable by setting MUSA_RUN_LARGE_TESTS=1.""" + if not os.environ.get("MUSA_RUN_LARGE_TESTS"): + self.skipTest("Large tests disabled; set MUSA_RUN_LARGE_TESTS=1 to run") + + np.random.seed(123) + tf.random.set_seed(123) + + m, k, n = 2048, 512, 512 + x_np = np.random.randn(m, k).astype(np.float32) + w_np = np.random.randn(k, n).astype(np.float32) + b_np = np.random.randn(n).astype(np.float32) + + # Reference on CPU + with tf.device('/CPU:0'): + expected = tf.nn.bias_add( + tf.matmul(tf.constant(x_np), tf.constant(w_np)), + tf.constant(b_np) + ) + + # MUSA graph + graph = tf.Graph() + with graph.as_default(): + with tf.device('/device:MUSA:0'): + x = tf.compat.v1.placeholder(tf.float32, shape=[None, k], name="x_large_batch") + w = tf.constant(w_np, dtype=tf.float32, name="w_large_batch") + b = tf.constant(b_np, dtype=tf.float32, name="b_large_batch") + + mm = tf.matmul(x, w) + out = tf.nn.bias_add(mm, b) + + config = create_config_with_musa_optimizer() + with tf.compat.v1.Session(graph=graph, config=config) as sess: + actual = sess.run(out, feed_dict={x: x_np}) + + self.assertAllClose(actual, expected.numpy(), rtol=1e-4, atol=1e-4) + + +if __name__ == "__main__": + tf.test.main() From 1d891feb5ea70167a58a7329c3fb1cd227514858 Mon Sep 17 00:00:00 2001 From: yaotianhang Date: Tue, 31 Mar 2026 16:16:41 +0800 Subject: [PATCH 03/10] modify fusion priority --- musa_ext/kernels/nn/musa_matmul_bias_op.txt | 163 -------------- .../kernels/nn/musa_matmul_bias_op_ori.cc | 205 ------------------ .../mu/graph_fusion/matmul_biasadd_fusion.h | 2 +- 3 files changed, 1 insertion(+), 369 deletions(-) delete mode 100644 musa_ext/kernels/nn/musa_matmul_bias_op.txt delete mode 100644 musa_ext/kernels/nn/musa_matmul_bias_op_ori.cc 
diff --git a/musa_ext/kernels/nn/musa_matmul_bias_op.txt b/musa_ext/kernels/nn/musa_matmul_bias_op.txt deleted file mode 100644 index 758a5321..00000000 --- a/musa_ext/kernels/nn/musa_matmul_bias_op.txt +++ /dev/null @@ -1,163 +0,0 @@ -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/register_types.h" -#include "tensorflow/core/framework/tensor.h" -#include "../utils_op.h" - -namespace tensorflow { -namespace musa { - -template -class MusaMatMulBiasAddOp : public MusaOpKernel { - public: - explicit MusaMatMulBiasAddOp(OpKernelConstruction* ctx) : MusaOpKernel(ctx) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_a", &transpose_a_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_b", &transpose_b_)); - } - - bool IsExpensive() override { return true; } - - void Compute(OpKernelContext* ctx) override { - const Tensor& a = ctx->input(0); - const Tensor& b = ctx->input(1); - const Tensor& bias = ctx->input(2); - - OP_REQUIRES(ctx, a.dims() == 2, - errors::InvalidArgument("MatMulBiasAdd requires input a to be 2D, got shape ", - a.shape().DebugString())); - OP_REQUIRES(ctx, b.dims() == 2, - errors::InvalidArgument("MatMulBiasAdd requires input b to be 2D, got shape ", - b.shape().DebugString())); - OP_REQUIRES(ctx, bias.dims() == 1, - errors::InvalidArgument("MatMulBiasAdd requires bias to be 1D, got shape ", - bias.shape().DebugString())); - - if (a.NumElements() == 0 || b.NumElements() == 0 || bias.NumElements() == 0) { - TensorShape out_shape; - OP_REQUIRES_OK(ctx, ComputeOutputShape(a, b, &out_shape)); - - Tensor* output = nullptr; - OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &output)); - return; - } - - const int64_t a_rows = a.dim_size(0); - const int64_t a_cols = a.dim_size(1); - const int64_t b_rows = b.dim_size(0); - const int64_t b_cols = b.dim_size(1); - - const int64_t m = transpose_a_ ? a_cols : a_rows; - const int64_t k_a = transpose_a_ ? a_rows : a_cols; - const int64_t k_b = transpose_b_ ? 
b_cols : b_rows; - const int64_t n = transpose_b_ ? b_rows : b_cols; - - OP_REQUIRES( - ctx, k_a == k_b, - errors::InvalidArgument("Matrix size-incompatible: a shape ", - a.shape().DebugString(), ", b shape ", - b.shape().DebugString(), ", transpose_a=", - transpose_a_, ", transpose_b=", transpose_b_)); - - OP_REQUIRES( - ctx, bias.dim_size(0) == n, - errors::InvalidArgument("Bias dimension mismatch: bias shape ", - bias.shape().DebugString(), - ", expected [", n, "]")); - - TensorShape out_shape({m, n}); - Tensor* output = nullptr; - OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &output)); - - auto& handle = GetHandleByCtx(ctx); - - mTensor mt_a = CreateMTensor(a, format_); - mTensor mt_b = CreateMTensor(b, format_); - mTensor mt_bias = CreateMTensor(bias, format_); - mTensor mt_out = CreateMTensor(*output, format_); - - ::musa::dnn::MatMul op; - auto status = op.SetTranspose(transpose_a_, transpose_b_); - OP_REQUIRES(ctx, status == ::musa::dnn::Status::SUCCESS, - errors::Internal("muDNN MatMul SetTranspose failed, status=", - static_cast(status))); - - status = op.SetAlpha(1.0); - OP_REQUIRES(ctx, status == ::musa::dnn::Status::SUCCESS, - errors::Internal("muDNN MatMul SetAlpha failed, status=", - static_cast(status))); - - status = op.SetBeta(0.0); - OP_REQUIRES(ctx, status == ::musa::dnn::Status::SUCCESS, - errors::Internal("muDNN MatMul SetBeta failed, status=", - static_cast(status))); - - status = op.SetGamma(1.0); - OP_REQUIRES(ctx, status == ::musa::dnn::Status::SUCCESS, - errors::Internal("muDNN MatMul SetGamma failed, status=", - static_cast(status))); - - - status = op.RunWithBiasAdd(handle, mt_out, mt_a, mt_b, mt_bias); - - OP_REQUIRES(ctx, status == ::musa::dnn::Status::SUCCESS, - errors::Internal("muDNN MatMulBiasAdd failed, status=", - static_cast(status))); - } - - private: - Status ComputeOutputShape(const Tensor& a, const Tensor& b, - TensorShape* out_shape) { - const int64_t a_rows = a.dim_size(0); - const int64_t a_cols = 
a.dim_size(1); - const int64_t b_rows = b.dim_size(0); - const int64_t b_cols = b.dim_size(1); - - const int64_t m = transpose_a_ ? a_cols : a_rows; - const int64_t k_a = transpose_a_ ? a_rows : a_cols; - const int64_t k_b = transpose_b_ ? b_cols : b_rows; - const int64_t n = transpose_b_ ? b_rows : b_cols; - - if (k_a != k_b) { - return errors::InvalidArgument( - "Matrix size-incompatible: a shape ", a.shape().DebugString(), - ", b shape ", b.shape().DebugString(), - ", transpose_a=", transpose_a_, - ", transpose_b=", transpose_b_); - } - - *out_shape = TensorShape({m, n}); - return Status::OK(); - } - - private: - bool transpose_a_; - bool transpose_b_; -}; - -#define REGISTER_MUSA_MATMUL_BIASADD(TYPE) \ - REGISTER_KERNEL_BUILDER( \ - Name("MusaMatMulBiasAdd").Device("MUSA").TypeConstraint("T"), \ - MusaMatMulBiasAddOp); - -REGISTER_MUSA_MATMUL_BIASADD(float); -// REGISTER_MUSA_MATMUL_BIASADD(double); -REGISTER_MUSA_MATMUL_BIASADD(Eigen::half); -REGISTER_MUSA_MATMUL_BIASADD(bfloat16); - -#undef REGISTER_MUSA_MATMUL_BIASADD - -} // namespace musa - -using shape_inference::DimensionHandle; -using shape_inference::InferenceContext; -using shape_inference::ShapeHandle; - -REGISTER_OP("MusaMatMulBiasAdd") - .Input("a: T") - .Input("b: T") - .Input("bias: T") - .Output("product: T") - .Attr("T: {float, half, bfloat16}") - .Attr("transpose_a: bool = false") - .Attr("transpose_b: bool = false") - .SetShapeFn(::tensorflow::shape_inference::MatMulShape); -} // namespace tensorflow diff --git a/musa_ext/kernels/nn/musa_matmul_bias_op_ori.cc b/musa_ext/kernels/nn/musa_matmul_bias_op_ori.cc deleted file mode 100644 index 76483c68..00000000 --- a/musa_ext/kernels/nn/musa_matmul_bias_op_ori.cc +++ /dev/null @@ -1,205 +0,0 @@ -#include -#include -#include - -#include "../utils_op.h" -#include "tensorflow/core/framework/op.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/shape_inference.h" -#include "tensorflow/core/util/matmul_bcast.h" 
-#include "utils/logging.h" - -namespace tensorflow { -namespace musa { - -template -class MusaMatMulBiasAddOp : public MusaOpKernel { - public: - explicit MusaMatMulBiasAddOp(OpKernelConstruction* ctx) : MusaOpKernel(ctx) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_a", &trans_a_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_b", &trans_b_)); - } - - void Compute(OpKernelContext* ctx) override { - MUSA_KERNEL_TIMING_GUARD(ctx); - - const Tensor& in0 = ctx->input(0); // a - const Tensor& in1 = ctx->input(1); // b - const Tensor& bias = ctx->input(2); // bias - - OP_REQUIRES( - ctx, TensorShapeUtils::IsVector(bias.shape()), - errors::InvalidArgument("bias must be a 1-D tensor, but got shape: ", - bias.shape().DebugString())); - - // ---------------------------- - // 1. Infer MatMul output shape - // ---------------------------- - MatMulBCast bcast(in0.shape().dim_sizes(), in1.shape().dim_sizes()); - OP_REQUIRES(ctx, bcast.IsValid(), - errors::InvalidArgument("Incompatible shapes for MatMul: ", - in0.shape().DebugString(), " vs ", - in1.shape().DebugString())); - - OP_REQUIRES(ctx, in0.dims() >= 2, - errors::InvalidArgument( - "Input a must have rank >= 2, got rank ", in0.dims())); - OP_REQUIRES(ctx, in1.dims() >= 2, - errors::InvalidArgument( - "Input b must have rank >= 2, got rank ", in1.dims())); - - int64 d0 = in0.dim_size(in0.dims() - 2); - int64 d1 = in0.dim_size(in0.dims() - 1); - int64 d2 = in1.dim_size(in1.dims() - 2); - int64 d3 = in1.dim_size(in1.dims() - 1); - - int64 m = trans_a_ ? d1 : d0; - int64 k = trans_a_ ? d0 : d1; - int64 n = trans_b_ ? d2 : d3; - int64 k_check = trans_b_ ? 
d3 : d2; - - OP_REQUIRES(ctx, k == k_check, - errors::InvalidArgument( - "Matrix size incompatible: lhs k=", k, ", rhs k=", k_check, - ", lhs shape=", in0.shape().DebugString(), - ", rhs shape=", in1.shape().DebugString())); - - OP_REQUIRES( - ctx, bias.dim_size(0) == n, - errors::InvalidArgument( - "Bias size mismatch: bias.shape[0] must equal MatMul output's last " - "dimension. Got bias size = ", - bias.dim_size(0), ", expected = ", n)); - - TensorShape out_shape = bcast.output_batch_shape(); - out_shape.AddDim(m); - out_shape.AddDim(n); - - TensorShape mm_out_shape = out_shape; - - // ----------------------------------- - // 2. Allocate temp for MatMul result - // ----------------------------------- - Tensor mm_out_tensor; - OP_REQUIRES_OK( - ctx, ctx->allocate_temp(in0.dtype(), mm_out_shape, &mm_out_tensor)); - - Tensor* output = nullptr; - OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &output)); - - if (output->NumElements() == 0) { - return; - } - - // ---------------------------- - // 3. 
mudnn MatMul / BatchMatMul - // ---------------------------- - auto& handle = GetHandleByCtx(ctx); - handle.SetAllowTF32(tf32_enabled_); - - mTensor mt_a = CreateMTensor(in0); - mTensor mt_b = CreateMTensor(in1); - mTensor mt_mm_out = CreateMTensor(mm_out_tensor); - - ::musa::dnn::Status status; - - if (in0.dims() == 2 && in1.dims() == 2) { - mMatMul op; - op.SetTranspose(trans_a_, trans_b_); - op.SetAlpha(1.0); - op.SetBeta(0.0); - - status = op.Run(handle, mt_mm_out, mt_a, mt_b); - } else { - mBatchMatMul op; - op.SetTranspose(trans_a_, trans_b_); - op.SetAlpha(1.0); - op.SetBeta(0.0); - - int64_t out_batch = bcast.output_batch_shape().num_elements(); - - auto ReshapeTo3D = [out_batch](mTensor& mt, const Tensor& t) { - int64_t dims = t.dims(); - int64_t rows = t.dim_size(dims - 2); - int64_t cols = t.dim_size(dims - 1); - int64_t batch = t.NumElements() / (rows * cols); - - if (dims != 3 || (batch == 1 && out_batch > 1)) { - mt.SetNdInfo( - {batch == 1 && out_batch > 1 ? out_batch : batch, rows, cols}, - {batch == 1 && out_batch > 1 ? 0 : rows * cols, cols, 1}); - } - }; - - ReshapeTo3D(mt_a, in0); - ReshapeTo3D(mt_b, in1); - mt_mm_out.SetNdInfo({out_batch, m, n}, {m * n, n, 1}); - - status = op.Run(handle, mt_mm_out, mt_a, mt_b); - } - - OP_REQUIRES(ctx, status == ::musa::dnn::Status::SUCCESS, - errors::Internal( - "MUSA MatMul/BatchMatMul failed in MusaMatMulBiasAdd")); - - // ---------------------------- - // 4. 
mudnn BiasAdd - // ---------------------------- - mTensor mt_bias = CreateMTensor(bias); - mTensor mt_out = CreateMTensor(*output); - - const int dims_cnt = mm_out_shape.dims(); - const int channel_dim = dims_cnt - 1; - - std::vector b_dims(dims_cnt, 1); - std::vector b_strides(dims_cnt, 0); - b_dims[channel_dim] = bias.dim_size(0); - b_strides[channel_dim] = 1; - - mt_bias.SetNdInfo(dims_cnt, b_dims.data(), b_strides.data()); - - mBinary bias_add_op; - bias_add_op.SetMode(::musa::dnn::Binary::Mode::ADD); - - status = bias_add_op.Run(handle, mt_out, mt_mm_out, mt_bias); - - OP_REQUIRES(ctx, status == ::musa::dnn::Status::SUCCESS, - errors::Internal("MUSA BiasAdd failed in MusaMatMulBiasAdd")); - } - - bool IsExpensive() override { return true; } - - private: - bool trans_a_ = false; - bool trans_b_ = false; - bool tf32_enabled_ = false; -}; - -#define REGISTER_MUSA_MATMUL_BIASADD(TYPE) \ - REGISTER_KERNEL_BUILDER( \ - Name("MusaMatMulBiasAdd").Device("MUSA").TypeConstraint("T"), \ - MusaMatMulBiasAddOp) - -REGISTER_MUSA_MATMUL_BIASADD(float); -REGISTER_MUSA_MATMUL_BIASADD(Eigen::half); -REGISTER_MUSA_MATMUL_BIASADD(bfloat16); - -#undef REGISTER_MUSA_MATMUL_BIASADD - -} // namespace musa - -using shape_inference::DimensionHandle; -using shape_inference::InferenceContext; -using shape_inference::ShapeHandle; - -REGISTER_OP("MusaMatMulBiasAdd") - .Input("a: T") - .Input("b: T") - .Input("bias: T") - .Output("product: T") - .Attr("T: {float, half, bfloat16}") - .Attr("transpose_a: bool = false") - .Attr("transpose_b: bool = false") - .SetShapeFn(::tensorflow::shape_inference::MatMulShape); - -} // namespace tensorflow diff --git a/musa_ext/mu/graph_fusion/matmul_biasadd_fusion.h b/musa_ext/mu/graph_fusion/matmul_biasadd_fusion.h index 212dd25d..0e1bda7f 100644 --- a/musa_ext/mu/graph_fusion/matmul_biasadd_fusion.h +++ b/musa_ext/mu/graph_fusion/matmul_biasadd_fusion.h @@ -22,7 +22,7 @@ class MatMulBiasAddFusion : public FusionPattern { Status Apply(GraphDef* graph, 
const FusionMatchResult& match_result) const override; - int GetPriority() const override { return 100; } + int GetPriority() const override { return 99; } bool IsKernelAvailable() const override; From 1edfedb1a4509c260233f27b3f70fb99e99a90d9 Mon Sep 17 00:00:00 2001 From: yaotianhang Date: Tue, 31 Mar 2026 16:48:23 +0800 Subject: [PATCH 04/10] change linear relu priority --- musa_ext/mu/graph_fusion/linear_relu_fusion.h | 2 +- musa_ext/mu/graph_fusion/matmul_biasadd_fusion.h | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/musa_ext/mu/graph_fusion/linear_relu_fusion.h b/musa_ext/mu/graph_fusion/linear_relu_fusion.h index 97847e1f..2d67e628 100644 --- a/musa_ext/mu/graph_fusion/linear_relu_fusion.h +++ b/musa_ext/mu/graph_fusion/linear_relu_fusion.h @@ -20,7 +20,7 @@ class LinearReluFusion : public FusionPattern { Status Apply(GraphDef* graph, const FusionMatchResult& match_result) const override; - int GetPriority() const override { return 100; } + int GetPriority() const override { return 98; } bool IsKernelAvailable() const override; diff --git a/musa_ext/mu/graph_fusion/matmul_biasadd_fusion.h b/musa_ext/mu/graph_fusion/matmul_biasadd_fusion.h index 0e1bda7f..ae0f432e 100644 --- a/musa_ext/mu/graph_fusion/matmul_biasadd_fusion.h +++ b/musa_ext/mu/graph_fusion/matmul_biasadd_fusion.h @@ -43,4 +43,4 @@ class MatMulBiasAddFusion : public FusionPattern { } // namespace musa_fusion } // namespace grappler -} // namespace tensorflow \ No newline at end of file +} // namespace tensorflow From 49320227516ff659ca076ebdc418b228c7218a90 Mon Sep 17 00:00:00 2001 From: yaotianhang Date: Tue, 31 Mar 2026 18:51:49 +0800 Subject: [PATCH 05/10] use mudnn runwithbiasadd api --- musa_ext/kernels/nn/musa_matmul_bias_op.cc | 252 +++++++++------------ 1 file changed, 103 insertions(+), 149 deletions(-) diff --git a/musa_ext/kernels/nn/musa_matmul_bias_op.cc b/musa_ext/kernels/nn/musa_matmul_bias_op.cc index cb324c3e..4d409a70 100644 --- 
a/musa_ext/kernels/nn/musa_matmul_bias_op.cc +++ b/musa_ext/kernels/nn/musa_matmul_bias_op.cc @@ -1,13 +1,7 @@ -#include -#include -#include - -#include "../utils_op.h" -#include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/shape_inference.h" -#include "tensorflow/core/util/matmul_bcast.h" -#include "utils/logging.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "../utils_op.h" namespace tensorflow { namespace musa { @@ -16,175 +10,136 @@ template class MusaMatMulBiasAddOp : public MusaOpKernel { public: explicit MusaMatMulBiasAddOp(OpKernelConstruction* ctx) : MusaOpKernel(ctx) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_a", &trans_a_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_b", &trans_b_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_a", &transpose_a_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_b", &transpose_b_)); } + bool IsExpensive() override { return true; } + void Compute(OpKernelContext* ctx) override { - MUSA_KERNEL_TIMING_GUARD(ctx); - - const Tensor& in0 = ctx->input(0); // a - const Tensor& in1 = ctx->input(1); // b - const Tensor& bias = ctx->input(2); // bias - - OP_REQUIRES(ctx, TensorShapeUtils::IsVector(bias.shape()), - errors::InvalidArgument( - "bias must be a 1-D tensor, but got shape: ", - bias.shape().DebugString())); - - // ---------------------------- - // 1. 
Infer MatMul output shape - // ---------------------------- - MatMulBCast bcast(in0.shape().dim_sizes(), in1.shape().dim_sizes()); - OP_REQUIRES(ctx, bcast.IsValid(), - errors::InvalidArgument( - "Incompatible shapes for MatMul: ", - in0.shape().DebugString(), " vs ", - in1.shape().DebugString())); - - OP_REQUIRES(ctx, in0.dims() >= 2, - errors::InvalidArgument("Input a must have rank >= 2, got rank ", - in0.dims())); - OP_REQUIRES(ctx, in1.dims() >= 2, - errors::InvalidArgument("Input b must have rank >= 2, got rank ", - in1.dims())); - - int64 d0 = in0.dim_size(in0.dims() - 2); - int64 d1 = in0.dim_size(in0.dims() - 1); - int64 d2 = in1.dim_size(in1.dims() - 2); - int64 d3 = in1.dim_size(in1.dims() - 1); - - int64 m = trans_a_ ? d1 : d0; - int64 k = trans_a_ ? d0 : d1; - int64 n = trans_b_ ? d2 : d3; - int64 k_check = trans_b_ ? d3 : d2; + const Tensor& a = ctx->input(0); + const Tensor& b = ctx->input(1); + const Tensor& bias = ctx->input(2); + + OP_REQUIRES(ctx, a.dims() == 2, + errors::InvalidArgument("MatMulBiasAdd requires input a to be 2D, got shape ", + a.shape().DebugString())); + OP_REQUIRES(ctx, b.dims() == 2, + errors::InvalidArgument("MatMulBiasAdd requires input b to be 2D, got shape ", + b.shape().DebugString())); + OP_REQUIRES(ctx, bias.dims() == 1, + errors::InvalidArgument("MatMulBiasAdd requires bias to be 1D, got shape ", + bias.shape().DebugString())); + + if (a.NumElements() == 0 || b.NumElements() == 0 || bias.NumElements() == 0) { + TensorShape out_shape; + OP_REQUIRES_OK(ctx, ComputeOutputShape(a, b, &out_shape)); + + Tensor* output = nullptr; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &output)); + return; + } + + const int64_t a_rows = a.dim_size(0); + const int64_t a_cols = a.dim_size(1); + const int64_t b_rows = b.dim_size(0); + const int64_t b_cols = b.dim_size(1); + + const int64_t m = transpose_a_ ? a_cols : a_rows; + const int64_t k_a = transpose_a_ ? a_rows : a_cols; + const int64_t k_b = transpose_b_ ? 
b_cols : b_rows; + const int64_t n = transpose_b_ ? b_rows : b_cols; OP_REQUIRES( - ctx, k == k_check, - errors::InvalidArgument("Matrix size incompatible: lhs k=", k, - ", rhs k=", k_check, - ", lhs shape=", in0.shape().DebugString(), - ", rhs shape=", in1.shape().DebugString())); + ctx, k_a == k_b, + errors::InvalidArgument("Matrix size-incompatible: a shape ", + a.shape().DebugString(), ", b shape ", + b.shape().DebugString(), ", transpose_a=", + transpose_a_, ", transpose_b=", transpose_b_)); OP_REQUIRES( ctx, bias.dim_size(0) == n, - errors::InvalidArgument( - "Bias size mismatch: bias.shape[0] must equal MatMul output's last " - "dimension. Got bias size = ", - bias.dim_size(0), ", expected = ", n)); - - TensorShape out_shape = bcast.output_batch_shape(); - out_shape.AddDim(m); - out_shape.AddDim(n); - - TensorShape mm_out_shape = out_shape; - - // ----------------------------------- - // 2. Allocate temp for MatMul result - // ----------------------------------- - Tensor mm_out_tensor; - OP_REQUIRES_OK( - ctx, ctx->allocate_temp(in0.dtype(), mm_out_shape, &mm_out_tensor)); + errors::InvalidArgument("Bias dimension mismatch: bias shape ", + bias.shape().DebugString(), + ", expected [", n, "]")); + TensorShape out_shape({m, n}); Tensor* output = nullptr; OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &output)); - if (output->NumElements() == 0) { - return; - } - - // ---------------------------- - // 3. 
mudnn MatMul / BatchMatMul - // ---------------------------- auto& handle = GetHandleByCtx(ctx); - handle.SetAllowTF32(tf32_enabled_); - - mTensor mt_a = CreateMTensor(in0); - mTensor mt_b = CreateMTensor(in1); - mTensor mt_mm_out = CreateMTensor(mm_out_tensor); - - ::musa::dnn::Status status; - - if (in0.dims() == 2 && in1.dims() == 2) { - mMatMul op; - op.SetTranspose(trans_a_, trans_b_); - op.SetAlpha(1.0); - op.SetBeta(0.0); - - status = op.Run(handle, mt_mm_out, mt_a, mt_b); - } else { - mBatchMatMul op; - op.SetTranspose(trans_a_, trans_b_); - op.SetAlpha(1.0); - op.SetBeta(0.0); - - int64_t out_batch = bcast.output_batch_shape().num_elements(); - - auto ReshapeTo3D = [out_batch](mTensor& mt, const Tensor& t) { - int64_t dims = t.dims(); - int64_t rows = t.dim_size(dims - 2); - int64_t cols = t.dim_size(dims - 1); - int64_t batch = t.NumElements() / (rows * cols); - - - if (dims != 3 || (batch == 1 && out_batch > 1)) { - mt.SetNdInfo( - {batch == 1 && out_batch > 1 ? out_batch : batch, rows, cols}, - {batch == 1 && out_batch > 1 ? 0 : rows * cols, cols, 1}); - } - }; - - ReshapeTo3D(mt_a, in0); - ReshapeTo3D(mt_b, in1); - mt_mm_out.SetNdInfo({out_batch, m, n}, {m * n, n, 1}); - - status = op.Run(handle, mt_mm_out, mt_a, mt_b); - } - OP_REQUIRES( - ctx, status == ::musa::dnn::Status::SUCCESS, - errors::Internal("MUSA MatMul/BatchMatMul failed in MusaMatMulBiasAdd")); + mTensor mt_a = CreateMTensor(a, format_); + mTensor mt_b = CreateMTensor(b, format_); + mTensor mt_bias = CreateMTensor(bias, format_); + mTensor mt_out = CreateMTensor(*output, format_); - // ---------------------------- - // 4. 
mudnn BiasAdd - // ---------------------------- - mTensor mt_bias = CreateMTensor(bias); - mTensor mt_out = CreateMTensor(*output); + ::musa::dnn::MatMul op; + auto status = op.SetTranspose(transpose_a_, transpose_b_); + OP_REQUIRES(ctx, status == ::musa::dnn::Status::SUCCESS, + errors::Internal("muDNN MatMul SetTranspose failed, status=", + static_cast(status))); - const int dims_cnt = mm_out_shape.dims(); - const int channel_dim = dims_cnt - 1; + status = op.SetAlpha(1.0); + OP_REQUIRES(ctx, status == ::musa::dnn::Status::SUCCESS, + errors::Internal("muDNN MatMul SetAlpha failed, status=", + static_cast(status))); - std::vector b_dims(dims_cnt, 1); - std::vector b_strides(dims_cnt, 0); - b_dims[channel_dim] = bias.dim_size(0); - b_strides[channel_dim] = 1; + status = op.SetBeta(0.0); + OP_REQUIRES(ctx, status == ::musa::dnn::Status::SUCCESS, + errors::Internal("muDNN MatMul SetBeta failed, status=", + static_cast(status))); - - mt_bias.SetNdInfo(dims_cnt, b_dims.data(), b_strides.data()); + status = op.SetGamma(1.0); + OP_REQUIRES(ctx, status == ::musa::dnn::Status::SUCCESS, + errors::Internal("muDNN MatMul SetGamma failed, status=", + static_cast(status))); - mBinary bias_add_op; - bias_add_op.SetMode(::musa::dnn::Binary::Mode::ADD); - status = bias_add_op.Run(handle, mt_out, mt_mm_out, mt_bias); + status = op.RunWithBiasAdd(handle, mt_out, mt_a, mt_b, mt_bias); OP_REQUIRES(ctx, status == ::musa::dnn::Status::SUCCESS, - errors::Internal("MUSA BiasAdd failed in MusaMatMulBiasAdd")); + errors::Internal("muDNN MatMulBiasAdd failed, status=", + static_cast(status))); } - bool IsExpensive() override { return true; } + private: + Status ComputeOutputShape(const Tensor& a, const Tensor& b, + TensorShape* out_shape) { + const int64_t a_rows = a.dim_size(0); + const int64_t a_cols = a.dim_size(1); + const int64_t b_rows = b.dim_size(0); + const int64_t b_cols = b.dim_size(1); + + const int64_t m = transpose_a_ ? a_cols : a_rows; + const int64_t k_a = transpose_a_ ? 
a_rows : a_cols; + const int64_t k_b = transpose_b_ ? b_cols : b_rows; + const int64_t n = transpose_b_ ? b_rows : b_cols; + + if (k_a != k_b) { + return errors::InvalidArgument( + "Matrix size-incompatible: a shape ", a.shape().DebugString(), + ", b shape ", b.shape().DebugString(), + ", transpose_a=", transpose_a_, + ", transpose_b=", transpose_b_); + } + + *out_shape = TensorShape({m, n}); + return Status::OK(); + } private: - bool trans_a_ = false; - bool trans_b_ = false; - bool tf32_enabled_ = false; + bool transpose_a_; + bool transpose_b_; }; -#define REGISTER_MUSA_MATMUL_BIASADD(TYPE) \ - REGISTER_KERNEL_BUILDER( \ - Name("MusaMatMulBiasAdd").Device("MUSA").TypeConstraint("T"), \ - MusaMatMulBiasAddOp) +#define REGISTER_MUSA_MATMUL_BIASADD(TYPE) \ + REGISTER_KERNEL_BUILDER( \ + Name("MusaMatMulBiasAdd").Device("MUSA").TypeConstraint("T"), \ + MusaMatMulBiasAddOp); REGISTER_MUSA_MATMUL_BIASADD(float); +// REGISTER_MUSA_MATMUL_BIASADD(double); REGISTER_MUSA_MATMUL_BIASADD(Eigen::half); REGISTER_MUSA_MATMUL_BIASADD(bfloat16); @@ -205,5 +160,4 @@ REGISTER_OP("MusaMatMulBiasAdd") .Attr("transpose_a: bool = false") .Attr("transpose_b: bool = false") .SetShapeFn(::tensorflow::shape_inference::MatMulShape); - } // namespace tensorflow \ No newline at end of file From a8c82bb65482f476a7836cec688d4c1c89fc5416 Mon Sep 17 00:00:00 2001 From: yaotianhang Date: Thu, 2 Apr 2026 15:52:44 +0800 Subject: [PATCH 06/10] update matmul bias --- musa_ext/kernels/nn/musa_matmul_bias_op.cc | 53 ++++++++++--------- .../mu/graph_fusion/matmul_biasadd_fusion.cc | 2 +- 2 files changed, 28 insertions(+), 27 deletions(-) diff --git a/musa_ext/kernels/nn/musa_matmul_bias_op.cc b/musa_ext/kernels/nn/musa_matmul_bias_op.cc index 4d409a70..a988ffc2 100644 --- a/musa_ext/kernels/nn/musa_matmul_bias_op.cc +++ b/musa_ext/kernels/nn/musa_matmul_bias_op.cc @@ -1,7 +1,7 @@ +#include "../utils_op.h" #include "tensorflow/core/framework/op_kernel.h" #include 
"tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" -#include "../utils_op.h" namespace tensorflow { namespace musa { @@ -22,16 +22,20 @@ class MusaMatMulBiasAddOp : public MusaOpKernel { const Tensor& bias = ctx->input(2); OP_REQUIRES(ctx, a.dims() == 2, - errors::InvalidArgument("MatMulBiasAdd requires input a to be 2D, got shape ", - a.shape().DebugString())); + errors::InvalidArgument( + "MatMulBiasAdd requires input a to be 2D, got shape ", + a.shape().DebugString())); OP_REQUIRES(ctx, b.dims() == 2, - errors::InvalidArgument("MatMulBiasAdd requires input b to be 2D, got shape ", - b.shape().DebugString())); + errors::InvalidArgument( + "MatMulBiasAdd requires input b to be 2D, got shape ", + b.shape().DebugString())); OP_REQUIRES(ctx, bias.dims() == 1, - errors::InvalidArgument("MatMulBiasAdd requires bias to be 1D, got shape ", - bias.shape().DebugString())); + errors::InvalidArgument( + "MatMulBiasAdd requires bias to be 1D, got shape ", + bias.shape().DebugString())); - if (a.NumElements() == 0 || b.NumElements() == 0 || bias.NumElements() == 0) { + if (a.NumElements() == 0 || b.NumElements() == 0 || + bias.NumElements() == 0) { TensorShape out_shape; OP_REQUIRES_OK(ctx, ComputeOutputShape(a, b, &out_shape)); @@ -50,18 +54,17 @@ class MusaMatMulBiasAddOp : public MusaOpKernel { const int64_t k_b = transpose_b_ ? b_cols : b_rows; const int64_t n = transpose_b_ ? 
b_rows : b_cols; - OP_REQUIRES( - ctx, k_a == k_b, - errors::InvalidArgument("Matrix size-incompatible: a shape ", - a.shape().DebugString(), ", b shape ", - b.shape().DebugString(), ", transpose_a=", - transpose_a_, ", transpose_b=", transpose_b_)); + OP_REQUIRES(ctx, k_a == k_b, + errors::InvalidArgument("Matrix size-incompatible: a shape ", + a.shape().DebugString(), ", b shape ", + b.shape().DebugString(), + ", transpose_a=", transpose_a_, + ", transpose_b=", transpose_b_)); - OP_REQUIRES( - ctx, bias.dim_size(0) == n, - errors::InvalidArgument("Bias dimension mismatch: bias shape ", - bias.shape().DebugString(), - ", expected [", n, "]")); + OP_REQUIRES(ctx, bias.dim_size(0) == n, + errors::InvalidArgument("Bias dimension mismatch: bias shape ", + bias.shape().DebugString(), + ", expected [", n, "]")); TensorShape out_shape({m, n}); Tensor* output = nullptr; @@ -95,7 +98,6 @@ class MusaMatMulBiasAddOp : public MusaOpKernel { errors::Internal("muDNN MatMul SetGamma failed, status=", static_cast(status))); - status = op.RunWithBiasAdd(handle, mt_out, mt_a, mt_b, mt_bias); OP_REQUIRES(ctx, status == ::musa::dnn::Status::SUCCESS, @@ -119,8 +121,7 @@ class MusaMatMulBiasAddOp : public MusaOpKernel { if (k_a != k_b) { return errors::InvalidArgument( "Matrix size-incompatible: a shape ", a.shape().DebugString(), - ", b shape ", b.shape().DebugString(), - ", transpose_a=", transpose_a_, + ", b shape ", b.shape().DebugString(), ", transpose_a=", transpose_a_, ", transpose_b=", transpose_b_); } @@ -133,9 +134,9 @@ class MusaMatMulBiasAddOp : public MusaOpKernel { bool transpose_b_; }; -#define REGISTER_MUSA_MATMUL_BIASADD(TYPE) \ - REGISTER_KERNEL_BUILDER( \ - Name("MusaMatMulBiasAdd").Device("MUSA").TypeConstraint("T"), \ +#define REGISTER_MUSA_MATMUL_BIASADD(TYPE) \ + REGISTER_KERNEL_BUILDER( \ + Name("MusaMatMulBiasAdd").Device("MUSA").TypeConstraint("T"), \ MusaMatMulBiasAddOp); REGISTER_MUSA_MATMUL_BIASADD(float); @@ -160,4 +161,4 @@ 
REGISTER_OP("MusaMatMulBiasAdd") .Attr("transpose_a: bool = false") .Attr("transpose_b: bool = false") .SetShapeFn(::tensorflow::shape_inference::MatMulShape); -} // namespace tensorflow \ No newline at end of file +} // namespace tensorflow diff --git a/musa_ext/mu/graph_fusion/matmul_biasadd_fusion.cc b/musa_ext/mu/graph_fusion/matmul_biasadd_fusion.cc index 8b7aca3b..f487edd9 100644 --- a/musa_ext/mu/graph_fusion/matmul_biasadd_fusion.cc +++ b/musa_ext/mu/graph_fusion/matmul_biasadd_fusion.cc @@ -235,7 +235,7 @@ Status MatMulBiasAddFusion::Apply( REGISTER_FUSION_PATTERN(MatMulBiasAddFusion); // Register kernel availability -REGISTER_FUSION_KERNEL(MatMulBiasAddFusion, []() { return true; }); +REGISTER_FUSION_KERNEL(MatMulBiasAddFusion, []() { return false; }); } // namespace musa_fusion } // namespace grappler From 2516aaf073e6ea5c60049d804054d9dd18c4e2f1 Mon Sep 17 00:00:00 2001 From: yaotianhang Date: Thu, 2 Apr 2026 15:56:36 +0800 Subject: [PATCH 07/10] update --- .../mu/graph_fusion/matmul_biasadd_fusion.cc | 22 +++++++++---------- 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/musa_ext/mu/graph_fusion/matmul_biasadd_fusion.cc b/musa_ext/mu/graph_fusion/matmul_biasadd_fusion.cc index f487edd9..ef1b1b1b 100644 --- a/musa_ext/mu/graph_fusion/matmul_biasadd_fusion.cc +++ b/musa_ext/mu/graph_fusion/matmul_biasadd_fusion.cc @@ -107,8 +107,8 @@ FusionMatchResult MatMulBiasAddFusion::Match(const GraphDef& graph, return result; } -Status MatMulBiasAddFusion::Apply( - GraphDef* graph, const FusionMatchResult& match_result) const { +Status MatMulBiasAddFusion::Apply(GraphDef* graph, + const FusionMatchResult& match_result) const { if (!match_result.IsValid()) { return Status(error::INVALID_ARGUMENT, "Invalid MatMulBiasAdd match result"); @@ -194,11 +194,10 @@ Status MatMulBiasAddFusion::Apply( fused_node->add_input(matmul_node->input(0)); fused_node->add_input(matmul_node->input(1)); - // Keep original bias edge string exactly, do not replace with 
bias_node->name() - // because input could be "bias:0" instead of just "bias". - fused_node->add_input( - bias_add_node->input(bias_add_node->input(0) == matmul_node->name() ? 1 - : 0)); + // Keep original bias edge string exactly, do not replace with + // bias_node->name() because input could be "bias:0" instead of just "bias". + fused_node->add_input(bias_add_node->input( + bias_add_node->input(0) == matmul_node->name() ? 1 : 0)); auto* attr = fused_node->mutable_attr(); (*attr)["T"] = output_dtype; @@ -216,9 +215,8 @@ Status MatMulBiasAddFusion::Apply( } // Remove matched nodes if now unused. - std::vector removable_names = { - original_output_name, matmul_node->name() - }; + std::vector removable_names = {original_output_name, + matmul_node->name()}; FusionGraphUtils::RemoveNodesIfUnused( graph, removable_names, @@ -235,8 +233,8 @@ Status MatMulBiasAddFusion::Apply( REGISTER_FUSION_PATTERN(MatMulBiasAddFusion); // Register kernel availability -REGISTER_FUSION_KERNEL(MatMulBiasAddFusion, []() { return false; }); +REGISTER_FUSION_KERNEL(MatMulBiasAddFusion, []() { return true; }); } // namespace musa_fusion } // namespace grappler -} // namespace tensorflow \ No newline at end of file +} // namespace tensorflow From 5bca84e5c3a2a01733bdb458f12f12bf2c98d3c8 Mon Sep 17 00:00:00 2001 From: yaotianhang Date: Fri, 3 Apr 2026 12:09:21 +0800 Subject: [PATCH 08/10] replace linear relu op with matmul bias relu and revise the priority --- ...rnel.mu => musa_matmulbias_relu_kernel.mu} | 0 ..._relu_op.cc => musa_matmulbias_relu_op.cc} | 32 +++++++++---------- .../mu/graph_fusion/matmul_biasadd_fusion.h | 2 +- ...lu_fusion.cc => matmul_biasrelu_fusion.cc} | 32 +++++++++---------- ...relu_fusion.h => matmul_biasrelu_fusion.h} | 12 +++---- ...test.py => matmulbias_relu_fusion_test.py} | 12 +++---- 6 files changed, 45 insertions(+), 45 deletions(-) rename musa_ext/kernels/nn/{musa_linear_relu_kernel.mu => musa_matmulbias_relu_kernel.mu} (100%) rename 
musa_ext/kernels/nn/{musa_linear_relu_op.cc => musa_matmulbias_relu_op.cc} (87%) rename musa_ext/mu/graph_fusion/{linear_relu_fusion.cc => matmul_biasrelu_fusion.cc} (87%) rename musa_ext/mu/graph_fusion/{linear_relu_fusion.h => matmul_biasrelu_fusion.h} (70%) rename test/fusion/{linear_relu_fusion_test.py => matmulbias_relu_fusion_test.py} (96%) diff --git a/musa_ext/kernels/nn/musa_linear_relu_kernel.mu b/musa_ext/kernels/nn/musa_matmulbias_relu_kernel.mu similarity index 100% rename from musa_ext/kernels/nn/musa_linear_relu_kernel.mu rename to musa_ext/kernels/nn/musa_matmulbias_relu_kernel.mu diff --git a/musa_ext/kernels/nn/musa_linear_relu_op.cc b/musa_ext/kernels/nn/musa_matmulbias_relu_op.cc similarity index 87% rename from musa_ext/kernels/nn/musa_linear_relu_op.cc rename to musa_ext/kernels/nn/musa_matmulbias_relu_op.cc index d048d854..645ffb9b 100644 --- a/musa_ext/kernels/nn/musa_linear_relu_op.cc +++ b/musa_ext/kernels/nn/musa_matmulbias_relu_op.cc @@ -9,7 +9,7 @@ namespace tensorflow { namespace musa { -// The fused op for MusaLinearRelu, which computes MatMul + BiasAdd + Relu +// The fused op for MusaMatmulBiasRelu, which computes MatMul + BiasAdd + Relu // Provides two types of implementations: // 1) A pure MUSA implementation using mudnn for MatMul and a custom kernel for // BiasAdd+Relu @@ -20,9 +20,9 @@ template void LaunchBiasAddReluKernel(const T*, const T*, T*, int, int, musaStream_t); template -class MusaLinearReluOp : public MusaOpKernel { +class MusaMatmulBiasReluOp : public MusaOpKernel { public: - explicit MusaLinearReluOp(OpKernelConstruction* ctx) : MusaOpKernel(ctx) { + explicit MusaMatmulBiasReluOp(OpKernelConstruction* ctx) : MusaOpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_a", &trans_a_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_b", &trans_b_)); } @@ -108,7 +108,7 @@ class MusaLinearReluOp : public MusaOpKernel { OP_REQUIRES(ctx, status == ::musa::dnn::Status::SUCCESS, errors::Internal( - "MUSA MatMul/BatchMatMul 
execution failed in LinearRelu.")); + "MUSA Matmul/BatchMatmul execution failed in MatmulBiasRelu.")); // 2. BiasAdd + Relu MUSA_KERNEL_TRACE_START("UseMudnn"); @@ -139,7 +139,7 @@ class MusaLinearReluOp : public MusaOpKernel { int channel_dim = mm_out_shape.dims() - 1; OP_REQUIRES( ctx, bias_input.dim_size(0) == mm_out_shape.dim_size(channel_dim), - errors::InvalidArgument("Dimension mismatch in BiasAdd of LinearRelu")); + errors::InvalidArgument("Dimension mismatch in BiasAdd of MatmulBiasRelu")); int dims_cnt = mm_out_shape.dims(); std::vector b_dims(dims_cnt, 1); @@ -154,7 +154,7 @@ class MusaLinearReluOp : public MusaOpKernel { mStatus status = bias_op.Run(handle, mt_out, mt_mm_out, mt_bias); OP_REQUIRES(ctx, status == ::musa::dnn::Status::SUCCESS, - errors::Internal("MUSA BiasAdd failed in LinearRelu.")); + errors::Internal("MUSA BiasAdd failed in MatmulBiasRelu.")); // 3. Relu (In-place on current output) mUnary relu_op; @@ -162,7 +162,7 @@ class MusaLinearReluOp : public MusaOpKernel { status = relu_op.Run(handle, mt_out, mt_out); OP_REQUIRES(ctx, status == ::musa::dnn::Status::SUCCESS, - errors::Internal("MUSA Relu failed in LinearRelu.")); + errors::Internal("MUSA Relu failed in MatmulBiasRelu.")); } void UseKernel(OpKernelContext* ctx, const Tensor& bias_input, @@ -178,20 +178,20 @@ class MusaLinearReluOp : public MusaOpKernel { } }; -#define REGISTER_MUSA_LINEAR_RELU(TYPE) \ +#define REGISTER_MUSA_MatmulBias_RELU(TYPE) \ REGISTER_KERNEL_BUILDER( \ - Name("MusaLinearRelu").Device("MUSA").TypeConstraint("T"), \ - MusaLinearReluOp); + Name("MusaMatmulBiasRelu").Device("MUSA").TypeConstraint("T"), \ + MusaMatmulBiasReluOp); -REGISTER_MUSA_LINEAR_RELU(float); -REGISTER_MUSA_LINEAR_RELU(Eigen::half); -REGISTER_MUSA_LINEAR_RELU(bfloat16); -REGISTER_MUSA_LINEAR_RELU(double); +REGISTER_MUSA_MatmulBias_RELU(float); +REGISTER_MUSA_MatmulBias_RELU(Eigen::half); +REGISTER_MUSA_MatmulBias_RELU(bfloat16); +REGISTER_MUSA_MatmulBias_RELU(double); -#undef 
REGISTER_MUSA_LINEAR_RELU +#undef REGISTER_MUSA_MatmulBias_RELU } // namespace musa -REGISTER_OP("MusaLinearRelu") +REGISTER_OP("MusaMatmulBiasRelu") .Input("a: T") .Input("b: T") .Input("bias: T") diff --git a/musa_ext/mu/graph_fusion/matmul_biasadd_fusion.h b/musa_ext/mu/graph_fusion/matmul_biasadd_fusion.h index ae0f432e..3661d9d5 100644 --- a/musa_ext/mu/graph_fusion/matmul_biasadd_fusion.h +++ b/musa_ext/mu/graph_fusion/matmul_biasadd_fusion.h @@ -22,7 +22,7 @@ class MatMulBiasAddFusion : public FusionPattern { Status Apply(GraphDef* graph, const FusionMatchResult& match_result) const override; - int GetPriority() const override { return 99; } + int GetPriority() const override { return 98; } bool IsKernelAvailable() const override; diff --git a/musa_ext/mu/graph_fusion/linear_relu_fusion.cc b/musa_ext/mu/graph_fusion/matmul_biasrelu_fusion.cc similarity index 87% rename from musa_ext/mu/graph_fusion/linear_relu_fusion.cc rename to musa_ext/mu/graph_fusion/matmul_biasrelu_fusion.cc index e4dd8052..5ad513f5 100644 --- a/musa_ext/mu/graph_fusion/linear_relu_fusion.cc +++ b/musa_ext/mu/graph_fusion/matmul_biasrelu_fusion.cc @@ -1,4 +1,4 @@ -#include "mu/graph_fusion/linear_relu_fusion.h" +#include "mu/graph_fusion/matmul_biasrelu_fusion.h" #include #include @@ -47,7 +47,7 @@ bool HasOriginalSuffix(const std::string& node_name) { } // namespace -bool LinearReluFusion::IsKernelAvailable() const { +bool MatmulBiasReluFusion::IsKernelAvailable() const { if (!kernel_checked_) { kernel_available_ = true; kernel_checked_ = true; @@ -55,7 +55,7 @@ bool LinearReluFusion::IsKernelAvailable() const { return kernel_available_; } -FusionMatchResult LinearReluFusion::Match(const GraphDef& graph, +FusionMatchResult MatmulBiasReluFusion::Match(const GraphDef& graph, int start_node_idx) const { FusionMatchResult result; if (start_node_idx < 0 || start_node_idx >= graph.node_size()) { @@ -117,10 +117,10 @@ FusionMatchResult LinearReluFusion::Match(const GraphDef& graph, return 
result; } -Status LinearReluFusion::Apply(GraphDef* graph, +Status MatmulBiasReluFusion::Apply(GraphDef* graph, const FusionMatchResult& match_result) const { if (!match_result.IsValid()) { - return Status(error::INVALID_ARGUMENT, "Invalid LinearRelu match result"); + return Status(error::INVALID_ARGUMENT, "Invalid MatmulBiasReluFusion match result"); } if (!IsKernelAvailable()) { @@ -136,7 +136,7 @@ Status LinearReluFusion::Apply(GraphDef* graph, matmul_it == match_result.captured_nodes.end() || bias_it == match_result.captured_nodes.end()) { return Status(error::INVALID_ARGUMENT, - "Missing required nodes in LinearRelu pattern"); + "Missing required nodes in MatmulBiasReluFusion pattern"); } const NodeDef* output_node = output_it->second; @@ -148,8 +148,8 @@ Status LinearReluFusion::Apply(GraphDef* graph, // Check if this output node has already been fused (avoid duplicates) for (const auto& node : graph->node()) { - if (node.name() == original_name && node.op() == "MusaLinearRelu") { - VLOG(1) << "MusaLinearRelu: Output node " << original_name + if (node.name() == original_name && node.op() == "MusaMatmulBiasRelu") { + VLOG(1) << "MusaMatmulBiasRelu: Output node " << original_name << " is already a fused node, skipping"; return Status::OK(); } @@ -168,8 +168,8 @@ Status LinearReluFusion::Apply(GraphDef* graph, "Failed to find output node in graph: " + original_name); } - VLOG(1) << "LinearReluFusion: Replacing " << original_name - << " with MusaLinearRelu"; + VLOG(1) << "MusaMatmulBiasReluFusion: Replacing " << original_name + << " with MusaMatmulBiasReluFusion"; NodeDef* original_output_node = graph->mutable_node(output_node_idx); const std::string output_device = original_output_node->device(); @@ -192,10 +192,10 @@ Status LinearReluFusion::Apply(GraphDef* graph, NodeDef* fused_node = graph->add_node(); fused_node->set_name(original_name); - fused_node->set_op("MusaLinearRelu"); + fused_node->set_op("MusaMatmulBiasRelu"); fused_node->set_device(output_device); 
- // MusaLinearRelu inputs: a, b, bias + // MusaMatmulBiasRelu inputs: a, b, bias fused_node->add_input(matmul_node->input(0)); fused_node->add_input(matmul_node->input(1)); // bias input might need port handling if it's more than just a name @@ -230,17 +230,17 @@ Status LinearReluFusion::Apply(GraphDef* graph, {matmul_node->input(0), matmul_node->input(1), bias_node->name(), original_name}); - VLOG(1) << "LinearReluFusion: Successfully replaced '" << original_name - << "' with MusaLinearRelu"; + VLOG(1) << "MatmulBiasReluFusion: Successfully replaced '" << original_name + << "' with MusaMatmulBiasRelu"; return Status::OK(); } // Register the pattern -REGISTER_FUSION_PATTERN(LinearReluFusion); +REGISTER_FUSION_PATTERN(MatmulBiasReluFusion); // Register kernel availability -REGISTER_FUSION_KERNEL(LinearReluFusion, []() { return true; }); +REGISTER_FUSION_KERNEL(MatmulBiasReluFusion, []() { return true; }); } // namespace musa_fusion } // namespace grappler diff --git a/musa_ext/mu/graph_fusion/linear_relu_fusion.h b/musa_ext/mu/graph_fusion/matmul_biasrelu_fusion.h similarity index 70% rename from musa_ext/mu/graph_fusion/linear_relu_fusion.h rename to musa_ext/mu/graph_fusion/matmul_biasrelu_fusion.h index 2d67e628..48611e3f 100644 --- a/musa_ext/mu/graph_fusion/linear_relu_fusion.h +++ b/musa_ext/mu/graph_fusion/matmul_biasrelu_fusion.h @@ -9,10 +9,10 @@ namespace musa_fusion { // Computes: MatMul + BiasAdd + Relu -class LinearReluFusion : public FusionPattern { +class MatmulBiasReluFusion : public FusionPattern { public: - LinearReluFusion() = default; - ~LinearReluFusion() override = default; + MatmulBiasReluFusion() = default; + ~MatmulBiasReluFusion() override = default; FusionMatchResult Match(const GraphDef& graph, int start_node_idx) const override; @@ -20,15 +20,15 @@ class LinearReluFusion : public FusionPattern { Status Apply(GraphDef* graph, const FusionMatchResult& match_result) const override; - int GetPriority() const override { return 98; } + int 
GetPriority() const override { return 96; } bool IsKernelAvailable() const override; - std::string GetName() const override { return "LinearReluFusion"; } + std::string GetName() const override { return "MatmulBiasReluFusion"; } std::string GetFallbackReason() const override { if (!kernel_available_) { - return "LinearReluFusion kernel not available on this device"; + return "MatmulBiasReluFusion kernel not available on this device"; } return ""; } diff --git a/test/fusion/linear_relu_fusion_test.py b/test/fusion/matmulbias_relu_fusion_test.py similarity index 96% rename from test/fusion/linear_relu_fusion_test.py rename to test/fusion/matmulbias_relu_fusion_test.py index 39cab819..ed8b8512 100644 --- a/test/fusion/linear_relu_fusion_test.py +++ b/test/fusion/matmulbias_relu_fusion_test.py @@ -13,7 +13,7 @@ # limitations under the License. # ============================================================================== -"""Tests for Linear+Relu fusion.""" +"""Tests for Matmul+BiasAdd+Relu fusion.""" import os import numpy as np @@ -123,11 +123,11 @@ def test_linear_relu_fusion_applied(self): has_fused_node = False for partition_graph in run_metadata.partition_graphs: for node in partition_graph.node: - if node.op == "MusaLinearRelu": + if node.op == "MusaMatmulBiasRelu": has_fused_node = True break - self.assertTrue(has_fused_node, "MusaLinearRelu fusion was NOT applied to the graph") + self.assertTrue(has_fused_node, "MusaMatmulBiasRelu fusion was NOT applied to the graph") def test_linear_relu_fusion_various_batch_sizes(self): """Test fusion correctness across several batch sizes.""" @@ -169,7 +169,7 @@ def test_linear_relu_fusion_various_batch_sizes(self): self.assertAllClose(actual, expected.numpy(), rtol=1e-5, atol=1e-5) def test_linear_relu_fusion_not_applied_with_intervening_op(self): - """If an extra op exists between MatMul and BiasAdd, fusion should not occur.""" + """If an extra op exists between Matmul and BiasAdd, fusion should not occur.""" m, k, n = 
2, 5, 7 x_np = np.random.randn(m, k).astype(np.float32) w_np = np.random.randn(k, n).astype(np.float32) @@ -200,11 +200,11 @@ def test_linear_relu_fusion_not_applied_with_intervening_op(self): has_fused_node = False for partition_graph in run_metadata.partition_graphs: for node in partition_graph.node: - if node.op == "MusaLinearRelu": + if node.op == "MusaMatmulBiasRelu": has_fused_node = True break - self.assertFalse(has_fused_node, "MusaLinearRelu fusion should NOT be applied when an intervening op exists") + self.assertFalse(has_fused_node, "MusaMatmulBiasRelu fusion should NOT be applied when an intervening op exists") def test_linear_relu_fusion_dtypes(self): """Test fusion correctness across multiple dtypes: float32, float16, bfloat16.""" From 23511314e267dc3db0f9cb8e222413597cd275ec Mon Sep 17 00:00:00 2001 From: yaotianhang Date: Fri, 3 Apr 2026 12:12:37 +0800 Subject: [PATCH 09/10] clang format --- musa_ext/kernels/nn/musa_matmulbias_relu_op.cc | 16 +++++++++------- .../mu/graph_fusion/matmul_biasrelu_fusion.cc | 9 +++++---- .../mu/graph_fusion/matmul_biasrelu_fusion.h | 4 ++-- 3 files changed, 16 insertions(+), 13 deletions(-) diff --git a/musa_ext/kernels/nn/musa_matmulbias_relu_op.cc b/musa_ext/kernels/nn/musa_matmulbias_relu_op.cc index 645ffb9b..f0919eb5 100644 --- a/musa_ext/kernels/nn/musa_matmulbias_relu_op.cc +++ b/musa_ext/kernels/nn/musa_matmulbias_relu_op.cc @@ -106,9 +106,10 @@ class MusaMatmulBiasReluOp : public MusaOpKernel { status = op.Run(handle, mt_mm_out, mt_a, mt_b); } - OP_REQUIRES(ctx, status == ::musa::dnn::Status::SUCCESS, - errors::Internal( - "MUSA Matmul/BatchMatmul execution failed in MatmulBiasRelu.")); + OP_REQUIRES( + ctx, status == ::musa::dnn::Status::SUCCESS, + errors::Internal( + "MUSA Matmul/BatchMatmul execution failed in MatmulBiasRelu.")); // 2. 
BiasAdd + Relu MUSA_KERNEL_TRACE_START("UseMudnn"); @@ -137,9 +138,10 @@ class MusaMatmulBiasReluOp : public MusaOpKernel { mTensor mt_out = CreateMTensor(*output); int channel_dim = mm_out_shape.dims() - 1; - OP_REQUIRES( - ctx, bias_input.dim_size(0) == mm_out_shape.dim_size(channel_dim), - errors::InvalidArgument("Dimension mismatch in BiasAdd of MatmulBiasRelu")); + OP_REQUIRES(ctx, + bias_input.dim_size(0) == mm_out_shape.dim_size(channel_dim), + errors::InvalidArgument( + "Dimension mismatch in BiasAdd of MatmulBiasRelu")); int dims_cnt = mm_out_shape.dims(); std::vector b_dims(dims_cnt, 1); @@ -179,7 +181,7 @@ class MusaMatmulBiasReluOp : public MusaOpKernel { }; #define REGISTER_MUSA_MatmulBias_RELU(TYPE) \ - REGISTER_KERNEL_BUILDER( \ + REGISTER_KERNEL_BUILDER( \ Name("MusaMatmulBiasRelu").Device("MUSA").TypeConstraint("T"), \ MusaMatmulBiasReluOp); diff --git a/musa_ext/mu/graph_fusion/matmul_biasrelu_fusion.cc b/musa_ext/mu/graph_fusion/matmul_biasrelu_fusion.cc index 5ad513f5..6b9b2dfa 100644 --- a/musa_ext/mu/graph_fusion/matmul_biasrelu_fusion.cc +++ b/musa_ext/mu/graph_fusion/matmul_biasrelu_fusion.cc @@ -56,7 +56,7 @@ bool MatmulBiasReluFusion::IsKernelAvailable() const { } FusionMatchResult MatmulBiasReluFusion::Match(const GraphDef& graph, - int start_node_idx) const { + int start_node_idx) const { FusionMatchResult result; if (start_node_idx < 0 || start_node_idx >= graph.node_size()) { return result; @@ -117,10 +117,11 @@ FusionMatchResult MatmulBiasReluFusion::Match(const GraphDef& graph, return result; } -Status MatmulBiasReluFusion::Apply(GraphDef* graph, - const FusionMatchResult& match_result) const { +Status MatmulBiasReluFusion::Apply( + GraphDef* graph, const FusionMatchResult& match_result) const { if (!match_result.IsValid()) { - return Status(error::INVALID_ARGUMENT, "Invalid MatmulBiasReluFusion match result"); + return Status(error::INVALID_ARGUMENT, + "Invalid MatmulBiasReluFusion match result"); } if (!IsKernelAvailable()) { diff 
--git a/musa_ext/mu/graph_fusion/matmul_biasrelu_fusion.h b/musa_ext/mu/graph_fusion/matmul_biasrelu_fusion.h index 48611e3f..c671c5fd 100644 --- a/musa_ext/mu/graph_fusion/matmul_biasrelu_fusion.h +++ b/musa_ext/mu/graph_fusion/matmul_biasrelu_fusion.h @@ -11,8 +11,8 @@ namespace musa_fusion { class MatmulBiasReluFusion : public FusionPattern { public: - MatmulBiasReluFusion() = default; - ~MatmulBiasReluFusion() override = default; + MatmulBiasReluFusion() = default; + ~MatmulBiasReluFusion() override = default; FusionMatchResult Match(const GraphDef& graph, int start_node_idx) const override; From aed931038be0a5856aa54852e111cfbffbc4493d Mon Sep 17 00:00:00 2001 From: yaotianhang Date: Tue, 7 Apr 2026 15:35:27 +0800 Subject: [PATCH 10/10] update --- musa_ext/mu/graph_fusion/matmul_biasadd_fusion.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/musa_ext/mu/graph_fusion/matmul_biasadd_fusion.cc b/musa_ext/mu/graph_fusion/matmul_biasadd_fusion.cc index ef1b1b1b..13b8544b 100644 --- a/musa_ext/mu/graph_fusion/matmul_biasadd_fusion.cc +++ b/musa_ext/mu/graph_fusion/matmul_biasadd_fusion.cc @@ -230,10 +230,10 @@ Status MatMulBiasAddFusion::Apply(GraphDef* graph, } // Register the pattern -REGISTER_FUSION_PATTERN(MatMulBiasAddFusion); +// REGISTER_FUSION_PATTERN(MatMulBiasAddFusion); -// Register kernel availability -REGISTER_FUSION_KERNEL(MatMulBiasAddFusion, []() { return true; }); +// // Register kernel availability +// REGISTER_FUSION_KERNEL(MatMulBiasAddFusion, []() { return true; }); } // namespace musa_fusion } // namespace grappler