diff --git a/musa_ext/kernels/nn/musa_matmul_bias_op.cc b/musa_ext/kernels/nn/musa_matmul_bias_op.cc new file mode 100644 index 00000000..a988ffc2 --- /dev/null +++ b/musa_ext/kernels/nn/musa_matmul_bias_op.cc @@ -0,0 +1,164 @@ +#include "../utils_op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" + +namespace tensorflow { +namespace musa { + +template +class MusaMatMulBiasAddOp : public MusaOpKernel { + public: + explicit MusaMatMulBiasAddOp(OpKernelConstruction* ctx) : MusaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_a", &transpose_a_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_b", &transpose_b_)); + } + + bool IsExpensive() override { return true; } + + void Compute(OpKernelContext* ctx) override { + const Tensor& a = ctx->input(0); + const Tensor& b = ctx->input(1); + const Tensor& bias = ctx->input(2); + + OP_REQUIRES(ctx, a.dims() == 2, + errors::InvalidArgument( + "MatMulBiasAdd requires input a to be 2D, got shape ", + a.shape().DebugString())); + OP_REQUIRES(ctx, b.dims() == 2, + errors::InvalidArgument( + "MatMulBiasAdd requires input b to be 2D, got shape ", + b.shape().DebugString())); + OP_REQUIRES(ctx, bias.dims() == 1, + errors::InvalidArgument( + "MatMulBiasAdd requires bias to be 1D, got shape ", + bias.shape().DebugString())); + + if (a.NumElements() == 0 || b.NumElements() == 0 || + bias.NumElements() == 0) { + TensorShape out_shape; + OP_REQUIRES_OK(ctx, ComputeOutputShape(a, b, &out_shape)); + + Tensor* output = nullptr; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &output)); + return; + } + + const int64_t a_rows = a.dim_size(0); + const int64_t a_cols = a.dim_size(1); + const int64_t b_rows = b.dim_size(0); + const int64_t b_cols = b.dim_size(1); + + const int64_t m = transpose_a_ ? a_cols : a_rows; + const int64_t k_a = transpose_a_ ? a_rows : a_cols; + const int64_t k_b = transpose_b_ ? b_cols : b_rows; + const int64_t n = transpose_b_ ? b_rows : b_cols; + + OP_REQUIRES(ctx, k_a == k_b, + errors::InvalidArgument("Matrix size-incompatible: a shape ", + a.shape().DebugString(), ", b shape ", + b.shape().DebugString(), + ", transpose_a=", transpose_a_, + ", transpose_b=", transpose_b_)); + + OP_REQUIRES(ctx, bias.dim_size(0) == n, + errors::InvalidArgument("Bias dimension mismatch: bias shape ", + bias.shape().DebugString(), + ", expected [", n, "]")); + + TensorShape out_shape({m, n}); + Tensor* output = nullptr; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &output)); + + auto& handle = GetHandleByCtx(ctx); + + mTensor mt_a = CreateMTensor(a, format_); + mTensor mt_b = CreateMTensor(b, format_); + mTensor mt_bias = CreateMTensor(bias, format_); + mTensor mt_out = CreateMTensor(*output, format_); + + ::musa::dnn::MatMul op; + auto status = op.SetTranspose(transpose_a_, transpose_b_); + OP_REQUIRES(ctx, status == ::musa::dnn::Status::SUCCESS, + errors::Internal("muDNN MatMul SetTranspose failed, status=", + static_cast(status))); + + status = op.SetAlpha(1.0); + OP_REQUIRES(ctx, status == ::musa::dnn::Status::SUCCESS, + errors::Internal("muDNN MatMul SetAlpha failed, status=", + static_cast(status))); + + status = op.SetBeta(0.0); + OP_REQUIRES(ctx, status == ::musa::dnn::Status::SUCCESS, + errors::Internal("muDNN MatMul SetBeta failed, status=", + static_cast(status))); + + status = op.SetGamma(1.0); + OP_REQUIRES(ctx, status == ::musa::dnn::Status::SUCCESS, + errors::Internal("muDNN MatMul SetGamma failed, status=", + static_cast(status))); + + status = op.RunWithBiasAdd(handle, mt_out, mt_a, mt_b, mt_bias); + + OP_REQUIRES(ctx, status == ::musa::dnn::Status::SUCCESS, + errors::Internal("muDNN MatMulBiasAdd failed, status=", + static_cast(status))); + } + + private: + Status ComputeOutputShape(const Tensor& a, const Tensor& b, + TensorShape* out_shape) { + const int64_t a_rows = a.dim_size(0); + const int64_t a_cols = a.dim_size(1); + const int64_t b_rows = b.dim_size(0); + const int64_t b_cols = b.dim_size(1); + + const int64_t m = transpose_a_ ? a_cols : a_rows; + const int64_t k_a = transpose_a_ ? a_rows : a_cols; + const int64_t k_b = transpose_b_ ? b_cols : b_rows; + const int64_t n = transpose_b_ ? b_rows : b_cols; + + if (k_a != k_b) { + return errors::InvalidArgument( + "Matrix size-incompatible: a shape ", a.shape().DebugString(), + ", b shape ", b.shape().DebugString(), ", transpose_a=", transpose_a_, + ", transpose_b=", transpose_b_); + } + + *out_shape = TensorShape({m, n}); + return Status::OK(); + } + + private: + bool transpose_a_; + bool transpose_b_; +}; + +#define REGISTER_MUSA_MATMUL_BIASADD(TYPE) \ + REGISTER_KERNEL_BUILDER( \ + Name("MusaMatMulBiasAdd").Device("MUSA").TypeConstraint("T"), \ + MusaMatMulBiasAddOp); + +REGISTER_MUSA_MATMUL_BIASADD(float); +// REGISTER_MUSA_MATMUL_BIASADD(double); +REGISTER_MUSA_MATMUL_BIASADD(Eigen::half); +REGISTER_MUSA_MATMUL_BIASADD(bfloat16); + +#undef REGISTER_MUSA_MATMUL_BIASADD + +} // namespace musa + +using shape_inference::DimensionHandle; +using shape_inference::InferenceContext; +using shape_inference::ShapeHandle; + +REGISTER_OP("MusaMatMulBiasAdd") + .Input("a: T") + .Input("b: T") + .Input("bias: T") + .Output("product: T") + .Attr("T: {float, half, bfloat16}") + .Attr("transpose_a: bool = false") + .Attr("transpose_b: bool = false") + .SetShapeFn(::tensorflow::shape_inference::MatMulShape); +} // namespace tensorflow diff --git a/musa_ext/kernels/nn/musa_linear_relu_kernel.mu b/musa_ext/kernels/nn/musa_matmulbias_relu_kernel.mu similarity index 100% rename from musa_ext/kernels/nn/musa_linear_relu_kernel.mu rename to musa_ext/kernels/nn/musa_matmulbias_relu_kernel.mu diff --git a/musa_ext/kernels/nn/musa_linear_relu_op.cc b/musa_ext/kernels/nn/musa_matmulbias_relu_op.cc similarity index 83% rename from musa_ext/kernels/nn/musa_linear_relu_op.cc rename to musa_ext/kernels/nn/musa_matmulbias_relu_op.cc index d048d854..f0919eb5 100644 --- a/musa_ext/kernels/nn/musa_linear_relu_op.cc +++ b/musa_ext/kernels/nn/musa_matmulbias_relu_op.cc @@ -9,7 +9,7 @@ namespace tensorflow { namespace musa { -// The fused op for MusaLinearRelu, which computes MatMul + BiasAdd + Relu +// The fused op for MusaMatmulBiasRelu, which computes MatMul + BiasAdd + Relu // Provides two types of implementations: // 1) A pure MUSA implementation using mudnn for MatMul and a custom kernel for // BiasAdd+Relu @@ -20,9 +20,9 @@ template void LaunchBiasAddReluKernel(const T*, const T*, T*, int, int, musaStream_t); template -class MusaLinearReluOp : public MusaOpKernel { +class MusaMatmulBiasReluOp : public MusaOpKernel { public: - explicit MusaLinearReluOp(OpKernelConstruction* ctx) : MusaOpKernel(ctx) { + explicit MusaMatmulBiasReluOp(OpKernelConstruction* ctx) : MusaOpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_a", &trans_a_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_b", &trans_b_)); } @@ -106,9 +106,10 @@ class MusaLinearReluOp : public MusaOpKernel { status = op.Run(handle, mt_mm_out, mt_a, mt_b); } - OP_REQUIRES(ctx, status == ::musa::dnn::Status::SUCCESS, - errors::Internal( - "MUSA MatMul/BatchMatMul execution failed in LinearRelu.")); + OP_REQUIRES( + ctx, status == ::musa::dnn::Status::SUCCESS, + errors::Internal( + "MUSA Matmul/BatchMatmul execution failed in MatmulBiasRelu.")); // 2. BiasAdd + Relu MUSA_KERNEL_TRACE_START("UseMudnn"); @@ -137,9 +138,10 @@ class MusaLinearReluOp : public MusaOpKernel { mTensor mt_out = CreateMTensor(*output); int channel_dim = mm_out_shape.dims() - 1; - OP_REQUIRES( - ctx, bias_input.dim_size(0) == mm_out_shape.dim_size(channel_dim), - errors::InvalidArgument("Dimension mismatch in BiasAdd of LinearRelu")); + OP_REQUIRES(ctx, + bias_input.dim_size(0) == mm_out_shape.dim_size(channel_dim), + errors::InvalidArgument( + "Dimension mismatch in BiasAdd of MatmulBiasRelu")); int dims_cnt = mm_out_shape.dims(); std::vector b_dims(dims_cnt, 1); @@ -154,7 +156,7 @@ class MusaLinearReluOp : public MusaOpKernel { mStatus status = bias_op.Run(handle, mt_out, mt_mm_out, mt_bias); OP_REQUIRES(ctx, status == ::musa::dnn::Status::SUCCESS, - errors::Internal("MUSA BiasAdd failed in LinearRelu.")); + errors::Internal("MUSA BiasAdd failed in MatmulBiasRelu.")); // 3. Relu (In-place on current output) mUnary relu_op; @@ -162,7 +164,7 @@ class MusaLinearReluOp : public MusaOpKernel { status = relu_op.Run(handle, mt_out, mt_out); OP_REQUIRES(ctx, status == ::musa::dnn::Status::SUCCESS, - errors::Internal("MUSA Relu failed in LinearRelu.")); + errors::Internal("MUSA Relu failed in MatmulBiasRelu.")); } void UseKernel(OpKernelContext* ctx, const Tensor& bias_input, @@ -178,20 +180,20 @@ class MusaLinearReluOp : public MusaOpKernel { } }; -#define REGISTER_MUSA_LINEAR_RELU(TYPE) \ - REGISTER_KERNEL_BUILDER( \ - Name("MusaLinearRelu").Device("MUSA").TypeConstraint("T"), \ - MusaLinearReluOp); +#define REGISTER_MUSA_MatmulBias_RELU(TYPE) \ + REGISTER_KERNEL_BUILDER( \ + Name("MusaMatmulBiasRelu").Device("MUSA").TypeConstraint("T"), \ + MusaMatmulBiasReluOp); -REGISTER_MUSA_LINEAR_RELU(float); -REGISTER_MUSA_LINEAR_RELU(Eigen::half); -REGISTER_MUSA_LINEAR_RELU(bfloat16); -REGISTER_MUSA_LINEAR_RELU(double); +REGISTER_MUSA_MatmulBias_RELU(float); +REGISTER_MUSA_MatmulBias_RELU(Eigen::half); +REGISTER_MUSA_MatmulBias_RELU(bfloat16); +REGISTER_MUSA_MatmulBias_RELU(double); -#undef REGISTER_MUSA_LINEAR_RELU +#undef REGISTER_MUSA_MatmulBias_RELU } // namespace musa -REGISTER_OP("MusaLinearRelu") +REGISTER_OP("MusaMatmulBiasRelu") .Input("a: T") .Input("b: T") .Input("bias: T") diff --git a/musa_ext/mu/graph_fusion/matmul_biasadd_fusion.cc b/musa_ext/mu/graph_fusion/matmul_biasadd_fusion.cc new file mode 100644 index 00000000..13b8544b --- /dev/null +++ b/musa_ext/mu/graph_fusion/matmul_biasadd_fusion.cc @@ -0,0 +1,240 @@ +#include "mu/graph_fusion/matmul_biasadd_fusion.h" + +#include +#include + +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { +namespace grappler { +namespace musa_fusion { + +namespace { + +// Helper to check if node has specific op type +bool IsOp(const NodeDef& node, const std::string& op_type) { + return node.op() == op_type; +} + +// Helper to find node's input producer +const NodeDef* FindProducer(const GraphDef& graph, const std::string& input) { + if (input.empty()) return nullptr; + + std::string node_name = input; + if (node_name[0] == '^') { + node_name = node_name.substr(1); + } + const size_t colon_pos = node_name.find(':'); + if (colon_pos != std::string::npos) { + node_name = node_name.substr(0, colon_pos); + } + + for (int i = 0; i < graph.node_size(); ++i) { + if (graph.node(i).name() == node_name) { + return &graph.node(i); + } + } + return nullptr; +} + +bool HasOriginalSuffix(const std::string& node_name) { + static const std::string kOriginalSuffix = "_original"; + return node_name.size() >= kOriginalSuffix.size() && + node_name.compare(node_name.size() - kOriginalSuffix.size(), + kOriginalSuffix.size(), kOriginalSuffix) == 0; +} + +} // namespace + +bool MatMulBiasAddFusion::IsKernelAvailable() const { + if (!kernel_checked_) { + kernel_available_ = true; + kernel_checked_ = true; + } + return kernel_available_; +} + +FusionMatchResult MatMulBiasAddFusion::Match(const GraphDef& graph, + int start_node_idx) const { + FusionMatchResult result; + if (start_node_idx < 0 || start_node_idx >= graph.node_size()) { + return result; + } + + const NodeDef& bias_add_node = graph.node(start_node_idx); + + // Start node must be BiasAdd / Add / AddV2 + if (!IsOp(bias_add_node, "BiasAdd")) { + return result; + } + + if (HasOriginalSuffix(bias_add_node.name())) { + return result; + } + + // Find MatMul node and bias node from the two inputs of BiasAdd/Add/AddV2 + const NodeDef* matmul_node = nullptr; + const NodeDef* bias_node = nullptr; + + if (bias_add_node.input_size() >= 2) { + const NodeDef* in0 = FindProducer(graph, bias_add_node.input(0)); + const NodeDef* in1 = FindProducer(graph, bias_add_node.input(1)); + + if (in0 && IsOp(*in0, "MatMul")) { + matmul_node = in0; + bias_node = in1; + } else if (in1 && IsOp(*in1, "MatMul")) { + matmul_node = in1; + bias_node = in0; + } + } + + if (!matmul_node || !bias_node) { + return result; + } + + // Record into result + result.matched = true; + result.matched_nodes.push_back(&bias_add_node); + result.matched_nodes.push_back(matmul_node); + + result.captured_nodes["output"] = &bias_add_node; + result.captured_nodes["bias_add"] = &bias_add_node; + result.captured_nodes["matmul"] = matmul_node; + result.captured_nodes["bias"] = bias_node; + + return result; +} + +Status MatMulBiasAddFusion::Apply(GraphDef* graph, + const FusionMatchResult& match_result) const { + if (!match_result.IsValid()) { + return Status(error::INVALID_ARGUMENT, + "Invalid MatMulBiasAdd match result"); + } + + if (!IsKernelAvailable()) { + return Status::OK(); + } + + auto output_it = match_result.captured_nodes.find("output"); + auto matmul_it = match_result.captured_nodes.find("matmul"); + auto bias_it = match_result.captured_nodes.find("bias"); + auto bias_add_it = match_result.captured_nodes.find("bias_add"); + + if (output_it == match_result.captured_nodes.end() || + matmul_it == match_result.captured_nodes.end() || + bias_it == match_result.captured_nodes.end() || + bias_add_it == match_result.captured_nodes.end()) { + return Status(error::INVALID_ARGUMENT, + "Missing required nodes in MatMulBiasAdd pattern"); + } + + const NodeDef* output_node = output_it->second; + const NodeDef* matmul_node = matmul_it->second; + const NodeDef* bias_node = bias_it->second; + const NodeDef* bias_add_node = bias_add_it->second; + + const std::string original_name = output_node->name(); + const std::string original_output_name = original_name + "_original"; + + // Avoid duplicate fusion + for (const auto& node : graph->node()) { + if (node.name() == original_name && node.op() == "MusaMatMulBiasAdd") { + VLOG(1) << "MusaMatMulBiasAdd: Output node " << original_name + << " is already a fused node, skipping"; + return Status::OK(); + } + } + + int output_node_idx = -1; + for (int i = 0; i < graph->node_size(); ++i) { + if (graph->node(i).name() == original_name) { + output_node_idx = i; + break; + } + } + + if (output_node_idx < 0) { + return Status(error::INVALID_ARGUMENT, + "Failed to find output node in graph: " + original_name); + } + + VLOG(1) << "MatMulBiasAddFusion: Replacing " << original_name + << " with MusaMatMulBiasAdd"; + + NodeDef* original_output_node = graph->mutable_node(output_node_idx); + const std::string output_device = original_output_node->device(); + + // Pick dtype from MatMul first, then output node, otherwise float + AttrValue output_dtype; + auto dtype_it = matmul_node->attr().find("T"); + if (dtype_it != matmul_node->attr().end()) { + output_dtype = dtype_it->second; + } else { + dtype_it = original_output_node->attr().find("T"); + if (dtype_it != original_output_node->attr().end()) { + output_dtype = dtype_it->second; + } else { + output_dtype.set_type(DT_FLOAT); + } + } + + // Rename the original output node + original_output_node->set_name(original_output_name); + + // Create fused node using the original output node name + NodeDef* fused_node = graph->add_node(); + fused_node->set_name(original_name); + fused_node->set_op("MusaMatMulBiasAdd"); + fused_node->set_device(output_device); + + // MusaMatMulBiasAdd inputs: a, b, bias + fused_node->add_input(matmul_node->input(0)); + fused_node->add_input(matmul_node->input(1)); + + // Keep original bias edge string exactly, do not replace with + // bias_node->name() because input could be "bias:0" instead of just "bias". + fused_node->add_input(bias_add_node->input( + bias_add_node->input(0) == matmul_node->name() ? 1 : 0)); + + auto* attr = fused_node->mutable_attr(); + (*attr)["T"] = output_dtype; + + if (matmul_node->attr().count("transpose_a")) { + (*attr)["transpose_a"] = matmul_node->attr().at("transpose_a"); + } else { + (*attr)["transpose_a"].set_b(false); + } + + if (matmul_node->attr().count("transpose_b")) { + (*attr)["transpose_b"] = matmul_node->attr().at("transpose_b"); + } else { + (*attr)["transpose_b"].set_b(false); + } + + // Remove matched nodes if now unused. + std::vector removable_names = {original_output_name, + matmul_node->name()}; + + FusionGraphUtils::RemoveNodesIfUnused( + graph, removable_names, + {matmul_node->input(0), matmul_node->input(1), bias_node->name(), + original_name}); + + VLOG(1) << "MatMulBiasAddFusion: Successfully replaced '" << original_name + << "' with MusaMatMulBiasAdd"; + + return Status::OK(); +} + +// Register the pattern +// REGISTER_FUSION_PATTERN(MatMulBiasAddFusion); + +// // Register kernel availability +// REGISTER_FUSION_KERNEL(MatMulBiasAddFusion, []() { return true; }); + +} // namespace musa_fusion +} // namespace grappler +} // namespace tensorflow diff --git a/musa_ext/mu/graph_fusion/matmul_biasadd_fusion.h b/musa_ext/mu/graph_fusion/matmul_biasadd_fusion.h new file mode 100644 index 00000000..3661d9d5 --- /dev/null +++ b/musa_ext/mu/graph_fusion/matmul_biasadd_fusion.h @@ -0,0 +1,46 @@ +#pragma once + +#include +#include + +#include "mu/graph_fusion/fusion_pattern_manager.h" + +namespace tensorflow { +namespace grappler { +namespace musa_fusion { + +// Computes: MatMul + BiasAdd + +class MatMulBiasAddFusion : public FusionPattern { + public: + MatMulBiasAddFusion() = default; + ~MatMulBiasAddFusion() override = default; + + FusionMatchResult Match(const GraphDef& graph, + int start_node_idx) const override; + + Status Apply(GraphDef* graph, + const FusionMatchResult& match_result) const override; + + int GetPriority() const override { return 98; } + + bool IsKernelAvailable() const override; + + std::string GetName() const override { return "MatMulBiasAddFusion"; } + + std::string GetFallbackReason() const override { + if (!kernel_available_) { + return "MatMulBiasAddFusion kernel not available on this device"; + } + return ""; + } + + private: + // Kernel availability flag + mutable bool kernel_available_ = true; + mutable bool kernel_checked_ = false; +}; + +} // namespace musa_fusion +} // namespace grappler +} // namespace tensorflow diff --git a/musa_ext/mu/graph_fusion/linear_relu_fusion.cc b/musa_ext/mu/graph_fusion/matmul_biasrelu_fusion.cc similarity index 85% rename from musa_ext/mu/graph_fusion/linear_relu_fusion.cc rename to musa_ext/mu/graph_fusion/matmul_biasrelu_fusion.cc index e4dd8052..6b9b2dfa 100644 --- a/musa_ext/mu/graph_fusion/linear_relu_fusion.cc +++ b/musa_ext/mu/graph_fusion/matmul_biasrelu_fusion.cc @@ -1,4 +1,4 @@ -#include "mu/graph_fusion/linear_relu_fusion.h" +#include "mu/graph_fusion/matmul_biasrelu_fusion.h" #include #include @@ -47,7 +47,7 @@ bool HasOriginalSuffix(const std::string& node_name) { } // namespace -bool LinearReluFusion::IsKernelAvailable() const { +bool MatmulBiasReluFusion::IsKernelAvailable() const { if (!kernel_checked_) { kernel_available_ = true; kernel_checked_ = true; @@ -55,8 +55,8 @@ bool LinearReluFusion::IsKernelAvailable() const { return kernel_available_; } -FusionMatchResult LinearReluFusion::Match(const GraphDef& graph, - int start_node_idx) const { +FusionMatchResult MatmulBiasReluFusion::Match(const GraphDef& graph, + int start_node_idx) const { FusionMatchResult result; if (start_node_idx < 0 || start_node_idx >= graph.node_size()) { return result; @@ -117,10 +117,11 @@ FusionMatchResult LinearReluFusion::Match(const GraphDef& graph, return result; } -Status LinearReluFusion::Apply(GraphDef* graph, - const FusionMatchResult& match_result) const { +Status MatmulBiasReluFusion::Apply( + GraphDef* graph, const FusionMatchResult& match_result) const { if (!match_result.IsValid()) { - return Status(error::INVALID_ARGUMENT, "Invalid LinearRelu match result"); + return Status(error::INVALID_ARGUMENT, + "Invalid MatmulBiasReluFusion match result"); } if (!IsKernelAvailable()) { @@ -136,7 +137,7 @@ Status LinearReluFusion::Apply(GraphDef* graph, matmul_it == match_result.captured_nodes.end() || bias_it == match_result.captured_nodes.end()) { return Status(error::INVALID_ARGUMENT, - "Missing required nodes in LinearRelu pattern"); + "Missing required nodes in MatmulBiasReluFusion pattern"); } const NodeDef* output_node = output_it->second; @@ -148,8 +149,8 @@ Status LinearReluFusion::Apply(GraphDef* graph, // Check if this output node has already been fused (avoid duplicates) for (const auto& node : graph->node()) { - if (node.name() == original_name && node.op() == "MusaLinearRelu") { - VLOG(1) << "MusaLinearRelu: Output node " << original_name + if (node.name() == original_name && node.op() == "MusaMatmulBiasRelu") { + VLOG(1) << "MusaMatmulBiasRelu: Output node " << original_name << " is already a fused node, skipping"; return Status::OK(); } @@ -168,8 +169,8 @@ Status LinearReluFusion::Apply(GraphDef* graph, "Failed to find output node in graph: " + original_name); } - VLOG(1) << "LinearReluFusion: Replacing " << original_name - << " with MusaLinearRelu"; + VLOG(1) << "MusaMatmulBiasReluFusion: Replacing " << original_name + << " with MusaMatmulBiasReluFusion"; NodeDef* original_output_node = graph->mutable_node(output_node_idx); const std::string output_device = original_output_node->device(); @@ -192,10 +193,10 @@ Status LinearReluFusion::Apply(GraphDef* graph, NodeDef* fused_node = graph->add_node(); fused_node->set_name(original_name); - fused_node->set_op("MusaLinearRelu"); + fused_node->set_op("MusaMatmulBiasRelu"); fused_node->set_device(output_device); - // MusaLinearRelu inputs: a, b, bias + // MusaMatmulBiasRelu inputs: a, b, bias fused_node->add_input(matmul_node->input(0)); fused_node->add_input(matmul_node->input(1)); // bias input might need port handling if it's more than just a name @@ -230,17 +231,17 @@ Status LinearReluFusion::Apply(GraphDef* graph, {matmul_node->input(0), matmul_node->input(1), bias_node->name(), original_name}); - VLOG(1) << "LinearReluFusion: Successfully replaced '" << original_name - << "' with MusaLinearRelu"; + VLOG(1) << "MatmulBiasReluFusion: Successfully replaced '" << original_name + << "' with MusaMatmulBiasRelu"; return Status::OK(); } // Register the pattern -REGISTER_FUSION_PATTERN(LinearReluFusion); +REGISTER_FUSION_PATTERN(MatmulBiasReluFusion); // Register kernel availability -REGISTER_FUSION_KERNEL(LinearReluFusion, []() { return true; }); +REGISTER_FUSION_KERNEL(MatmulBiasReluFusion, []() { return true; }); } // namespace musa_fusion } // namespace grappler diff --git a/musa_ext/mu/graph_fusion/linear_relu_fusion.h b/musa_ext/mu/graph_fusion/matmul_biasrelu_fusion.h similarity index 70% rename from musa_ext/mu/graph_fusion/linear_relu_fusion.h rename to musa_ext/mu/graph_fusion/matmul_biasrelu_fusion.h index 97847e1f..c671c5fd 100644 --- a/musa_ext/mu/graph_fusion/linear_relu_fusion.h +++ b/musa_ext/mu/graph_fusion/matmul_biasrelu_fusion.h @@ -9,10 +9,10 @@ namespace musa_fusion { // Computes: MatMul + BiasAdd + Relu -class LinearReluFusion : public FusionPattern { +class MatmulBiasReluFusion : public FusionPattern { public: - LinearReluFusion() = default; - ~LinearReluFusion() override = default; + MatmulBiasReluFusion() = default; + ~MatmulBiasReluFusion() override = default; FusionMatchResult Match(const GraphDef& graph, int start_node_idx) const override; @@ -20,15 +20,15 @@ class LinearReluFusion : public FusionPattern { Status Apply(GraphDef* graph, const FusionMatchResult& match_result) const override; - int GetPriority() const override { return 100; } + int GetPriority() const override { return 96; } bool IsKernelAvailable() const override; - std::string GetName() const override { return "LinearReluFusion"; } + std::string GetName() const override { return "MatmulBiasReluFusion"; } std::string GetFallbackReason() const override { if (!kernel_available_) { - return "LinearReluFusion kernel not available on this device"; + return "MatmulBiasReluFusion kernel not available on this device"; } return ""; } diff --git a/test/fusion/matmul_biasadd_fusion.py b/test/fusion/matmul_biasadd_fusion.py new file mode 100644 index 00000000..f98d4c18 --- /dev/null +++ b/test/fusion/matmul_biasadd_fusion.py @@ -0,0 +1,337 @@ +# Copyright 2026 The TensorFlow MUSA Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Tests for MatMul+BiasAdd fusion.""" + +import os +import numpy as np +import tensorflow as tf +from musa_test_utils import MUSATestCase + +from tensorflow.core.protobuf import config_pb2 +from tensorflow.core.protobuf import rewriter_config_pb2 + + +def create_config_with_musa_optimizer(): + """Create ConfigProto with MUSA optimizer enabled.""" + config = config_pb2.ConfigProto() + config.allow_soft_placement = True + + rewriter_config = config.graph_options.rewrite_options + + custom_optimizer = rewriter_config.custom_optimizers.add() + custom_optimizer.name = "musa_graph_optimizer" + + rewriter_config.min_graph_nodes = -1 + rewriter_config.optimizers.extend(["musa_graph_optimizer"]) + + return config + + +class MatMulBiasAddFusionTest(MUSATestCase): + """Tests for MatMul+BiasAdd fusion.""" + + def test_matmul_biasadd_fusion_basic(self): + """Test MatMul+BiasAdd pattern fusion.""" + np.random.seed(42) + tf.random.set_seed(42) + + m, k, n = 4, 8, 16 + + x_np = np.random.randn(m, k).astype(np.float32) + w_np = np.random.randn(k, n).astype(np.float32) + b_np = np.random.randn(n).astype(np.float32) + + # Reference implementation (CPU) + with tf.device('/CPU:0'): + x_tf = tf.constant(x_np) + w_tf = tf.constant(w_np) + b_tf = tf.constant(b_np) + + mm = tf.matmul(x_tf, w_tf) + expected_out = tf.nn.bias_add(mm, b_tf) + # Add a consumer to ensure it's not pruned and has someone to redirect to + expected_out = expected_out * 2.0 + + # Build graph with explicit MUSA device placement + graph = tf.Graph() + with graph.as_default(): + with tf.device('/device:MUSA:0'): + x = tf.compat.v1.placeholder(tf.float32, shape=[None, k], name="x") + w = tf.constant(w_np, dtype=tf.float32, name="w") + b = tf.constant(b_np, dtype=tf.float32, name="b") + + # This pattern should be matched by MatMulBiasAddFusion + mm_musa = tf.matmul(x, w) + bias_musa = tf.nn.bias_add(mm_musa, b) + # Add a consumer node + output = bias_musa * 2.0 + + config = create_config_with_musa_optimizer() + + with tf.compat.v1.Session(graph=graph, config=config) as sess: + actual_out = sess.run(output, feed_dict={x: x_np}) + + self.assertAllClose(actual_out, expected_out.numpy(), rtol=1e-5, atol=1e-5) + + def test_matmul_biasadd_fusion_applied(self): + """Verify that MatMul+BiasAdd fusion is applied: MusaMatMulBiasAdd node exists in optimized graph.""" + np.random.seed(123) + tf.random.set_seed(123) + + m, k, n = 4, 8, 16 + x_np = np.random.randn(m, k).astype(np.float32) + w_np = np.random.randn(k, n).astype(np.float32) + b_np = np.random.randn(n).astype(np.float32) + + graph = tf.Graph() + with graph.as_default(): + with tf.device('/device:MUSA:0'): + x = tf.compat.v1.placeholder(tf.float32, shape=[None, k], name="x") + w = tf.constant(w_np, dtype=tf.float32, name="w") + b = tf.constant(b_np, dtype=tf.float32, name="b") + + mm_musa = tf.matmul(x, w) + bias_musa = tf.nn.bias_add(mm_musa, b) + # Add a consumer node + output = bias_musa * 2.0 + + config = create_config_with_musa_optimizer() + run_options = tf.compat.v1.RunOptions(output_partition_graphs=True) + run_metadata = tf.compat.v1.RunMetadata() + + with tf.compat.v1.Session(graph=graph, config=config) as sess: + sess.run(output, feed_dict={x: x_np}, options=run_options, run_metadata=run_metadata) + + has_fused_node = False + for partition_graph in run_metadata.partition_graphs: + for node in partition_graph.node: + if node.op == "MusaMatMulBiasAdd": + has_fused_node = True + break + + self.assertTrue( + has_fused_node, + "MusaMatMulBiasAdd fusion was NOT applied to the graph" + ) + + def test_matmul_biasadd_fusion_various_batch_sizes(self): + """Test fusion correctness across several batch sizes.""" + np.random.seed(7) + tf.random.set_seed(7) + + k, n = 6, 10 + w_np = np.random.randn(k, n).astype(np.float32) + b_np = np.random.randn(n).astype(np.float32) + + for m in (1, 3, 8): + x_np = np.random.randn(m, k).astype(np.float32) + + # Reference on CPU + with tf.device('/CPU:0'): + x_tf = tf.constant(x_np) + w_tf = tf.constant(w_np) + b_tf = tf.constant(b_np) + expected = tf.nn.bias_add(tf.matmul(x_tf, w_tf), b_tf) * 1.5 + + # MUSA graph + graph = tf.Graph() + with graph.as_default(): + with tf.device('/device:MUSA:0'): + x = tf.compat.v1.placeholder(tf.float32, shape=[None, k], name="x_bs") + w = tf.constant(w_np, dtype=tf.float32, name="w_bs") + b = tf.constant(b_np, dtype=tf.float32, name="b_bs") + + mm = tf.matmul(x, w) + bias = tf.nn.bias_add(mm, b) + # extra consumer + out = bias * 1.5 + + config = create_config_with_musa_optimizer() + with tf.compat.v1.Session(graph=graph, config=config) as sess: + actual = sess.run(out, feed_dict={x: x_np}) + + self.assertAllClose(actual, expected.numpy(), rtol=1e-5, atol=1e-5) + + def test_matmul_biasadd_fusion_not_applied_with_intervening_op(self): + """If an extra op exists between MatMul and BiasAdd, fusion should not occur.""" + np.random.seed(99) + tf.random.set_seed(99) + + m, k, n = 2, 5, 7 + x_np = np.random.randn(m, k).astype(np.float32) + w_np = np.random.randn(k, n).astype(np.float32) + b_np = np.random.randn(n).astype(np.float32) + + graph = tf.Graph() + with graph.as_default(): + with tf.device('/device:MUSA:0'): + x = tf.compat.v1.placeholder(tf.float32, shape=[None, k], name="x_int") + w = tf.constant(w_np, dtype=tf.float32, name="w_int") + b = tf.constant(b_np, dtype=tf.float32, name="b_int") + + mm = tf.matmul(x, w) + # Insert an identity (or any intervening op) to block fusion + mid = tf.identity(mm, name="intervening_identity") + bias = tf.nn.bias_add(mid, b) + output = bias * 2.0 + + config = create_config_with_musa_optimizer() + run_options = tf.compat.v1.RunOptions(output_partition_graphs=True) + run_metadata = tf.compat.v1.RunMetadata() + + with tf.compat.v1.Session(graph=graph, config=config) as sess: + sess.run(output, feed_dict={x: x_np}, options=run_options, run_metadata=run_metadata) + + has_fused_node = False + for partition_graph in run_metadata.partition_graphs: + for node in partition_graph.node: + if node.op == "MusaMatMulBiasAdd": + has_fused_node = True + break + + self.assertFalse( + has_fused_node, + "MusaMatMulBiasAdd fusion should NOT be applied when an intervening op exists" + ) + + def test_matmul_biasadd_fusion_dtypes(self): + """Test fusion correctness across multiple dtypes: float32, float16, bfloat16.""" + np.random.seed(21) + tf.random.set_seed(21) + + m, k, n = 3, 6, 8 + w_np = np.random.randn(k, n).astype(np.float32) + b_np = np.random.randn(n).astype(np.float32) + + dtypes = [tf.float32, tf.float16, tf.bfloat16] + + for dtype in dtypes: + x_np = np.random.randn(m, k).astype(np.float32) + + # Reference computed in float32 + with tf.device('/CPU:0'): + x_tf = tf.constant(x_np, dtype=tf.float32) + w_tf = tf.constant(w_np, dtype=tf.float32) + b_tf = tf.constant(b_np, dtype=tf.float32) + expected = tf.nn.bias_add(tf.matmul(x_tf, w_tf), b_tf) * 0.75 + expected_f32 = expected.numpy() + + # Build MUSA graph: accept float32 feeds then cast to target dtype inside graph + graph = tf.Graph() + with graph.as_default(): + with tf.device('/device:MUSA:0'): + x_ph = tf.compat.v1.placeholder(tf.float32, shape=[None, k], name="x_dt") + x = tf.cast(x_ph, dtype) + w = tf.constant(w_np, dtype=dtype, name="w_dt") + b = tf.constant(b_np, dtype=dtype, name="b_dt") + + mm = tf.matmul(x, w) + bias = tf.nn.bias_add(mm, b) + out = bias * tf.constant(0.75, dtype=dtype) + # cast back to float32 for stable comparison + out_f32 = tf.cast(out, tf.float32) + + config = create_config_with_musa_optimizer() + with tf.compat.v1.Session(graph=graph, config=config) as sess: + actual = sess.run(out_f32, feed_dict={x_ph: x_np}) + + if dtype == tf.float32: + rtol, atol = 1e-5, 1e-5 + elif dtype == tf.float16: + rtol, atol = 1e-2, 1e-2 + else: # bfloat16 + rtol, atol = 2e-2, 2e-2 + + self.assertAllClose(actual, expected_f32, rtol=rtol, atol=atol) + + def test_matmul_biasadd_fusion_large_features(self): + """Optional large-feature test. Enable by setting MUSA_RUN_LARGE_TESTS=1.""" + if not os.environ.get("MUSA_RUN_LARGE_TESTS"): + self.skipTest("Large tests disabled; set MUSA_RUN_LARGE_TESTS=1 to run") + + np.random.seed(321) + tf.random.set_seed(321) + + m, k, n = 128, 2048, 1024 + x_np = np.random.randn(m, k).astype(np.float32) + w_np = np.random.randn(k, n).astype(np.float32) + b_np = np.random.randn(n).astype(np.float32) + + # Reference on CPU + with tf.device('/CPU:0'): + expected = tf.nn.bias_add( + tf.matmul(tf.constant(x_np), tf.constant(w_np)), + tf.constant(b_np) + ) * 0.9 + + # MUSA graph + graph = tf.Graph() + with graph.as_default(): + with tf.device('/device:MUSA:0'): + x = tf.compat.v1.placeholder(tf.float32, shape=[None, k], name="x_large_feat") + w = tf.constant(w_np, dtype=tf.float32, name="w_large_feat") + b = tf.constant(b_np, dtype=tf.float32, name="b_large_feat") + + mm = tf.matmul(x, w) + bias = tf.nn.bias_add(mm, b) + out = bias * 0.9 + + config = create_config_with_musa_optimizer() + with tf.compat.v1.Session(graph=graph, config=config) as sess: + actual = sess.run(out, feed_dict={x: x_np}) + + self.assertAllClose(actual, expected.numpy(), rtol=1e-4, atol=1e-4) + + def test_matmul_biasadd_fusion_large_batch(self): + """Optional large-batch test. Enable by setting MUSA_RUN_LARGE_TESTS=1.""" + if not os.environ.get("MUSA_RUN_LARGE_TESTS"): + self.skipTest("Large tests disabled; set MUSA_RUN_LARGE_TESTS=1 to run") + + np.random.seed(123) + tf.random.set_seed(123) + + m, k, n = 2048, 512, 512 + x_np = np.random.randn(m, k).astype(np.float32) + w_np = np.random.randn(k, n).astype(np.float32) + b_np = np.random.randn(n).astype(np.float32) + + # Reference on CPU + with tf.device('/CPU:0'): + expected = tf.nn.bias_add( + tf.matmul(tf.constant(x_np), tf.constant(w_np)), + tf.constant(b_np) + ) + + # MUSA graph + graph = tf.Graph() + with graph.as_default(): + with tf.device('/device:MUSA:0'): + x = tf.compat.v1.placeholder(tf.float32, shape=[None, k], name="x_large_batch") + w = tf.constant(w_np, dtype=tf.float32, name="w_large_batch") + b = tf.constant(b_np, dtype=tf.float32, name="b_large_batch") + + mm = tf.matmul(x, w) + out = tf.nn.bias_add(mm, b) + + config = create_config_with_musa_optimizer() + with tf.compat.v1.Session(graph=graph, config=config) as sess: + actual = sess.run(out, feed_dict={x: x_np}) + + self.assertAllClose(actual, expected.numpy(), rtol=1e-4, atol=1e-4) + + +if __name__ == "__main__": + tf.test.main() \ No newline at end of file diff --git a/test/fusion/matmul_biasadd_fusion_test.py b/test/fusion/matmul_biasadd_fusion_test.py new file mode 100644 index 00000000..7ce2e7bf --- /dev/null +++ b/test/fusion/matmul_biasadd_fusion_test.py @@ -0,0 +1,337 @@ +# Copyright 2026 The TensorFlow MUSA Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Tests for MatMul+BiasAdd fusion.""" + +import os +import numpy as np +import tensorflow as tf +from musa_test_utils import MUSATestCase + +from tensorflow.core.protobuf import config_pb2 +from tensorflow.core.protobuf import rewriter_config_pb2 + + +def create_config_with_musa_optimizer(): + """Create ConfigProto with MUSA optimizer enabled.""" + config = config_pb2.ConfigProto() + config.allow_soft_placement = True + + rewriter_config = config.graph_options.rewrite_options + + custom_optimizer = rewriter_config.custom_optimizers.add() + custom_optimizer.name = "musa_graph_optimizer" + + rewriter_config.min_graph_nodes = -1 + rewriter_config.optimizers.extend(["musa_graph_optimizer"]) + + return config + + +class MatMulBiasAddFusionTest(MUSATestCase): + """Tests for MatMul+BiasAdd fusion.""" + + def test_matmul_biasadd_fusion_basic(self): + """Test MatMul+BiasAdd pattern fusion.""" + np.random.seed(42) + tf.random.set_seed(42) + + m, k, n = 4, 8, 16 + + x_np = np.random.randn(m, k).astype(np.float32) + w_np = np.random.randn(k, n).astype(np.float32) + b_np = np.random.randn(n).astype(np.float32) + + # Reference implementation (CPU) + with tf.device('/CPU:0'): + x_tf = tf.constant(x_np) + w_tf = tf.constant(w_np) + b_tf = tf.constant(b_np) + + mm = tf.matmul(x_tf, w_tf) + expected_out = tf.nn.bias_add(mm, b_tf) + # Add a consumer to ensure it's not pruned and has someone to redirect to + expected_out = expected_out * 2.0 + + # Build graph with explicit MUSA device placement + graph = tf.Graph() + with graph.as_default(): + with tf.device('/device:MUSA:0'): + x = tf.compat.v1.placeholder(tf.float32, shape=[None, k], name="x") + w = tf.constant(w_np, dtype=tf.float32, name="w") + b = tf.constant(b_np, dtype=tf.float32, name="b") + + # This pattern should be matched by MatMulBiasAddFusion + mm_musa = tf.matmul(x, w) + bias_musa = tf.nn.bias_add(mm_musa, b) + # Add a consumer node + output = bias_musa * 2.0 + + config = create_config_with_musa_optimizer() + + with tf.compat.v1.Session(graph=graph, config=config) as sess: + actual_out = sess.run(output, feed_dict={x: x_np}) + + self.assertAllClose(actual_out, expected_out.numpy(), rtol=1e-5, atol=1e-5) + + def test_matmul_biasadd_fusion_applied(self): + """Verify that MatMul+BiasAdd fusion is applied: MusaMatMulBiasAdd node exists in optimized graph.""" + np.random.seed(123) + tf.random.set_seed(123) + + m, k, n = 4, 8, 16 + x_np = np.random.randn(m, k).astype(np.float32) + w_np = np.random.randn(k, n).astype(np.float32) + b_np = np.random.randn(n).astype(np.float32) + + graph = tf.Graph() + with graph.as_default(): + with tf.device('/device:MUSA:0'): + x = tf.compat.v1.placeholder(tf.float32, shape=[None, k], name="x") + w = tf.constant(w_np, dtype=tf.float32, name="w") + b = tf.constant(b_np, dtype=tf.float32, name="b") + + mm_musa = tf.matmul(x, w) + bias_musa = tf.nn.bias_add(mm_musa, b) + # Add a consumer node + output = bias_musa * 2.0 + + config = create_config_with_musa_optimizer() + run_options = tf.compat.v1.RunOptions(output_partition_graphs=True) + run_metadata = tf.compat.v1.RunMetadata() + + with tf.compat.v1.Session(graph=graph, config=config) as sess: + sess.run(output, feed_dict={x: x_np}, options=run_options, run_metadata=run_metadata) + + has_fused_node = False + for partition_graph in run_metadata.partition_graphs: + for node in partition_graph.node: + if node.op == "MusaMatMulBiasAdd": + has_fused_node = True + break + + self.assertTrue( + has_fused_node, + "MusaMatMulBiasAdd fusion was NOT applied to the graph" + ) + + def test_matmul_biasadd_fusion_various_batch_sizes(self): + """Test fusion correctness across several batch sizes.""" + np.random.seed(7) + tf.random.set_seed(7) + + k, n = 6, 10 + w_np = np.random.randn(k, n).astype(np.float32) + b_np = np.random.randn(n).astype(np.float32) + + for m in (1, 3, 8): + x_np = np.random.randn(m, k).astype(np.float32) + + # Reference on CPU + with tf.device('/CPU:0'): + x_tf = tf.constant(x_np) + w_tf = tf.constant(w_np) + b_tf = tf.constant(b_np) + expected = tf.nn.bias_add(tf.matmul(x_tf, w_tf), b_tf) * 1.5 + + # MUSA graph + graph = tf.Graph() + with graph.as_default(): + with tf.device('/device:MUSA:0'): + x = tf.compat.v1.placeholder(tf.float32, shape=[None, k], name="x_bs") + w = tf.constant(w_np, dtype=tf.float32, name="w_bs") + b = tf.constant(b_np, dtype=tf.float32, name="b_bs") + + mm = tf.matmul(x, w) + bias = tf.nn.bias_add(mm, b) + # extra consumer + out = bias * 1.5 + + config = create_config_with_musa_optimizer() + with tf.compat.v1.Session(graph=graph, config=config) as sess: + actual = sess.run(out, feed_dict={x: x_np}) + + self.assertAllClose(actual, expected.numpy(), rtol=1e-5, atol=1e-5) + + def test_matmul_biasadd_fusion_not_applied_with_intervening_op(self): + """If an extra op exists between MatMul and BiasAdd, fusion should not occur.""" + np.random.seed(99) + tf.random.set_seed(99) + + m, k, n = 2, 5, 7 + x_np = np.random.randn(m, k).astype(np.float32) + w_np = np.random.randn(k, n).astype(np.float32) + b_np = np.random.randn(n).astype(np.float32) + + graph = tf.Graph() + with graph.as_default(): + with tf.device('/device:MUSA:0'): + x = tf.compat.v1.placeholder(tf.float32, shape=[None, k], name="x_int") + w = tf.constant(w_np, dtype=tf.float32, name="w_int") + b = tf.constant(b_np, dtype=tf.float32, name="b_int") + + mm = tf.matmul(x, w) + # Insert an identity (or any intervening op) to block fusion + mid = tf.identity(mm, name="intervening_identity") + bias = tf.nn.bias_add(mid, b) + output = bias * 2.0 + + config = create_config_with_musa_optimizer() + run_options = tf.compat.v1.RunOptions(output_partition_graphs=True) + run_metadata = tf.compat.v1.RunMetadata() + + with tf.compat.v1.Session(graph=graph, config=config) as sess: + sess.run(output, feed_dict={x: x_np}, options=run_options, run_metadata=run_metadata) + + has_fused_node = False + for partition_graph in run_metadata.partition_graphs: + for node in partition_graph.node: + if node.op == "MusaMatMulBiasAdd": + has_fused_node = True + break + + self.assertFalse( + has_fused_node, + "MusaMatMulBiasAdd fusion should NOT be applied when an intervening op exists" + ) + + def test_matmul_biasadd_fusion_dtypes(self): + """Test fusion correctness across multiple dtypes: float32, float16, bfloat16.""" + np.random.seed(21) + tf.random.set_seed(21) + + m, k, n = 3, 6, 8 + w_np = np.random.randn(k, n).astype(np.float32) + b_np = np.random.randn(n).astype(np.float32) + + dtypes = [tf.float32, tf.float16, tf.bfloat16] + + for dtype in dtypes: + x_np = np.random.randn(m, k).astype(np.float32) + + # Reference computed in float32 + with tf.device('/CPU:0'): + x_tf = tf.constant(x_np, dtype=tf.float32) + w_tf = tf.constant(w_np, dtype=tf.float32) + b_tf = tf.constant(b_np, dtype=tf.float32) + expected = tf.nn.bias_add(tf.matmul(x_tf, w_tf), b_tf) * 0.75 + expected_f32 = expected.numpy() + + # Build MUSA graph: accept float32 feeds then cast to target dtype inside graph + graph = tf.Graph() + with graph.as_default(): + with tf.device('/device:MUSA:0'): + x_ph = tf.compat.v1.placeholder(tf.float32, shape=[None, k], name="x_dt") + x = tf.cast(x_ph, dtype) + w = tf.constant(w_np, dtype=dtype, name="w_dt") + b = tf.constant(b_np, dtype=dtype, name="b_dt") + + mm = tf.matmul(x, w) + bias = tf.nn.bias_add(mm, b) + out = bias * tf.constant(0.75, dtype=dtype) + # cast back to float32 for stable comparison + out_f32 = tf.cast(out, tf.float32) + + config = create_config_with_musa_optimizer() + with tf.compat.v1.Session(graph=graph, config=config) as sess: + actual = sess.run(out_f32, feed_dict={x_ph: x_np}) + + if dtype == tf.float32: + rtol, atol = 1e-5, 1e-5 + elif dtype == tf.float16: + rtol, atol = 1e-2, 1e-2 + else: # bfloat16 + rtol, atol = 2e-2, 2e-2 + + self.assertAllClose(actual, expected_f32, rtol=rtol, atol=atol) + + def test_matmul_biasadd_fusion_large_features(self): + """Optional large-feature test. Enable by setting MUSA_RUN_LARGE_TESTS=1.""" + if not os.environ.get("MUSA_RUN_LARGE_TESTS"): + self.skipTest("Large tests disabled; set MUSA_RUN_LARGE_TESTS=1 to run") + + np.random.seed(321) + tf.random.set_seed(321) + + m, k, n = 128, 2048, 1024 + x_np = np.random.randn(m, k).astype(np.float32) + w_np = np.random.randn(k, n).astype(np.float32) + b_np = np.random.randn(n).astype(np.float32) + + # Reference on CPU + with tf.device('/CPU:0'): + expected = tf.nn.bias_add( + tf.matmul(tf.constant(x_np), tf.constant(w_np)), + tf.constant(b_np) + ) * 0.9 + + # MUSA graph + graph = tf.Graph() + with graph.as_default(): + with tf.device('/device:MUSA:0'): + x = tf.compat.v1.placeholder(tf.float32, shape=[None, k], name="x_large_feat") + w = tf.constant(w_np, dtype=tf.float32, name="w_large_feat") + b = tf.constant(b_np, dtype=tf.float32, name="b_large_feat") + + mm = tf.matmul(x, w) + bias = tf.nn.bias_add(mm, b) + out = bias * 0.9 + + config = create_config_with_musa_optimizer() + with tf.compat.v1.Session(graph=graph, config=config) as sess: + actual = sess.run(out, feed_dict={x: x_np}) + + self.assertAllClose(actual, expected.numpy(), rtol=1e-4, atol=1e-4) + + def test_matmul_biasadd_fusion_large_batch(self): + """Optional large-batch test. Enable by setting MUSA_RUN_LARGE_TESTS=1.""" + if not os.environ.get("MUSA_RUN_LARGE_TESTS"): + self.skipTest("Large tests disabled; set MUSA_RUN_LARGE_TESTS=1 to run") + + np.random.seed(123) + tf.random.set_seed(123) + + m, k, n = 2048, 512, 512 + x_np = np.random.randn(m, k).astype(np.float32) + w_np = np.random.randn(k, n).astype(np.float32) + b_np = np.random.randn(n).astype(np.float32) + + # Reference on CPU + with tf.device('/CPU:0'): + expected = tf.nn.bias_add( + tf.matmul(tf.constant(x_np), tf.constant(w_np)), + tf.constant(b_np) + ) + + # MUSA graph + graph = tf.Graph() + with graph.as_default(): + with tf.device('/device:MUSA:0'): + x = tf.compat.v1.placeholder(tf.float32, shape=[None, k], name="x_large_batch") + w = tf.constant(w_np, dtype=tf.float32, name="w_large_batch") + b = tf.constant(b_np, dtype=tf.float32, name="b_large_batch") + + mm = tf.matmul(x, w) + out = tf.nn.bias_add(mm, b) + + config = create_config_with_musa_optimizer() + with tf.compat.v1.Session(graph=graph, config=config) as sess: + actual = sess.run(out, feed_dict={x: x_np}) + + self.assertAllClose(actual, expected.numpy(), rtol=1e-4, atol=1e-4) + + +if __name__ == "__main__": + tf.test.main() diff --git a/test/fusion/linear_relu_fusion_test.py b/test/fusion/matmulbias_relu_fusion_test.py similarity index 96% rename from test/fusion/linear_relu_fusion_test.py rename to test/fusion/matmulbias_relu_fusion_test.py index 39cab819..ed8b8512 100644 --- a/test/fusion/linear_relu_fusion_test.py +++ b/test/fusion/matmulbias_relu_fusion_test.py @@ -13,7 +13,7 @@ # limitations under the License. # ============================================================================== -"""Tests for Linear+Relu fusion.""" +"""Tests for Matmul+BiasAdd+Relu fusion.""" import os import numpy as np @@ -123,11 +123,11 @@ def test_linear_relu_fusion_applied(self): has_fused_node = False for partition_graph in run_metadata.partition_graphs: for node in partition_graph.node: - if node.op == "MusaLinearRelu": + if node.op == "MusaMatmulBiasRelu": has_fused_node = True break - self.assertTrue(has_fused_node, "MusaLinearRelu fusion was NOT applied to the graph") + self.assertTrue(has_fused_node, "MusaMatmulBiasRelu fusion was NOT applied to the graph") def test_linear_relu_fusion_various_batch_sizes(self): """Test fusion correctness across several batch sizes.""" @@ -169,7 +169,7 @@ def test_linear_relu_fusion_various_batch_sizes(self): self.assertAllClose(actual, expected.numpy(), rtol=1e-5, atol=1e-5) def test_linear_relu_fusion_not_applied_with_intervening_op(self): - """If an extra op exists between MatMul and BiasAdd, fusion should not occur.""" + """If an extra op exists between Matmul and BiasAdd, fusion should not occur.""" m, k, n = 2, 5, 7 x_np = np.random.randn(m, k).astype(np.float32) w_np = np.random.randn(k, n).astype(np.float32) @@ -200,11 +200,11 @@ def test_linear_relu_fusion_not_applied_with_intervening_op(self): has_fused_node = False for partition_graph in run_metadata.partition_graphs: for node in partition_graph.node: - if node.op == "MusaLinearRelu": + if node.op == "MusaMatmulBiasRelu": has_fused_node = True break - self.assertFalse(has_fused_node, "MusaLinearRelu fusion should NOT be applied when an intervening op exists") + self.assertFalse(has_fused_node, "MusaMatmulBiasRelu fusion should NOT be applied when an intervening op exists") def test_linear_relu_fusion_dtypes(self): """Test fusion correctness across multiple dtypes: float32, float16, bfloat16."""