diff --git a/musa_ext/mu/optimizer/musa_graph_optimizer.cc b/musa_ext/mu/optimizer/musa_graph_optimizer.cc index 99ae4822..17fa966a 100644 --- a/musa_ext/mu/optimizer/musa_graph_optimizer.cc +++ b/musa_ext/mu/optimizer/musa_graph_optimizer.cc @@ -75,10 +75,10 @@ struct MusaOptimizerConfigs { // Keep disabled (handled internally by MUSA) TriState auto_mixed_precision = TriState::kOff; - TriState layout_optimizer = TriState::kOff; // Keep as Default or enable as needed TriState disable_model_pruning = TriState::kDefault; + TriState layout_optimizer = TriState::kOff; // MUSA handles layout internally TriState loop_optimization = TriState::kDefault; TriState dependency_optimization = TriState::kDefault; TriState auto_parallel = TriState::kDefault; @@ -396,6 +396,12 @@ class MusaGraphOptimizer : public CustomGraphOptimizer { VLOG(1) << "MusaGraphOptimizer: Optimizing graph with " << optimized_graph->node_size() << " nodes"; + if (VLOG_IS_ON(2)) { + VLOG(2) << "Nodes in graph:"; + for (const auto& node : optimized_graph->node()) { + VLOG(2) << " - " << node.name() << " (" << node.op() << ")"; + } + } // Step 1: Layout optimization (NHWC -> NCHW) if (configs_.layout_optimizer != TriState::kOff) { @@ -576,6 +582,7 @@ class MusaGraphOptimizer : public CustomGraphOptimizer { // Layout Optimization void OptimizeLayout(GraphDef* graph) { + VLOG(1) << "MusaGraphOptimizer: Starting layout optimization"; bool changed = true; int iteration = 0; const int kMaxIterations = 5; @@ -646,6 +653,7 @@ class MusaGraphOptimizer : public CustomGraphOptimizer { // AMP Optimization void OptimizeAMP(GraphDef* graph) { + VLOG(1) << "MusaGraphOptimizer: Starting AMP fix optimization"; std::unordered_map should_convert; AnalyzeGraphForAMP(*graph, should_convert); @@ -666,7 +674,7 @@ class MusaGraphOptimizer : public CustomGraphOptimizer { continue; } - ConvertNodeToLowPrecision(graph, node); + ConvertNodeToLowPrecision(graph, node, should_convert); } } @@ -745,7 +753,9 @@ class MusaGraphOptimizer : public CustomGraphOptimizer { return DT_INVALID; } - bool ConvertNodeToLowPrecision(GraphDef* graph, NodeDef* node) { + bool ConvertNodeToLowPrecision( + GraphDef* graph, NodeDef* node, + std::unordered_map should_convert) { string op_name = node->name(); string device = node->device(); DataType target_t = amp_config_.target_dtype; @@ -764,7 +774,13 @@ class MusaGraphOptimizer : public CustomGraphOptimizer { new_inputs.push_back(input_name); continue; } - + // no need to insert cast node if upstream node is convertible or already + // casted + string upstream_name = GetNodeNameFromInput(input_name); + if (should_convert[upstream_name]) { + new_inputs.push_back(input_name); + continue; + } if (input_name.find("/CastF2Lower") != std::string::npos) { new_inputs.push_back(input_name); continue; @@ -790,7 +806,9 @@ class MusaGraphOptimizer : public CustomGraphOptimizer { for (int j = 0; j < graph->node_size(); ++j) { NodeDef* consumer = graph->mutable_node(j); if (consumer->name() == cast_out_name) continue; - + if (should_convert[consumer->name()]) continue; + // no need to inset cast node if downstream node is convertible or already + // casted for (int k = 0; k < consumer->input_size(); ++k) { string inp = consumer->input(k); diff --git a/test/AMP/InsertCastNode_test.py b/test/AMP/InsertCastNode_test.py new file mode 100644 index 00000000..1f08d2d6 --- /dev/null +++ b/test/AMP/InsertCastNode_test.py @@ -0,0 +1,248 @@ +import os +import time +import json +import numpy as np +import tensorflow as tf + +tf.compat.v1.disable_eager_execution() + + +# ========================= +# 1. 环境变量 +# ========================= +# 日志等级:0=INFO, 1=WARNING, 2=ERROR +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0" +os.environ["TF_CPP_MAX_VLOG_LEVEL"] = "2" +# 只打开 musa_graph_optimizer 的 VLOG 1 +# 这个变量名来自你给的模板 +os.environ["TF_CPP_VMODULE"] = "musa_graph_optimizer_FixAMP=2" + + +# ========================= +# 2. 可调测试参数 +# ========================= +BATCH = 4 +SEQ = 128 +HIDDEN = 768 +NUM_BLOCKS = 1 # 可以调大到 32 / 48,让冗余 cast 更明显 +WARMUP_STEPS = 30 +BENCH_STEPS = 10 +SEED = 1234 + +# 你也可以切换成 BF16 +PRECISION_MODE = "FP16" # "FP16" or "BF16" + +# 是否使用 aggressive_mode +AGGRESSIVE_MODE = False + +# 为了让测试更聚焦 AMP,这里默认关闭 layout optimizer +DISABLE_LAYOUT_OPTIMIZER = True + + + + + + +musa_plugin_path = "/workspace/tensorflow_musa_extension/build/libmusa_plugin.so" + +# ========================================== +# 3. 加载 MUSA 插件 +# ========================================== +def load_musa_plugin(): + if os.path.exists(musa_plugin_path): + try: + tf.load_op_library(musa_plugin_path) + print(f">>>> [MUSA] Plugin loaded successfully from: {musa_plugin_path}") + except Exception as e: + print(f"!!!! [MUSA] Failed to load plugin: {e}") + else: + print(f"!!!! [MUSA] Plugin not found at {musa_plugin_path}, assuming built-in.") + +# ========================= +# 4. 图构建 +# ========================= +def dense_block(x, in_dim, out_dim, block_id): + """ + MatMul -> BiasAdd -> Relu + 用 numpy 预生成常量,避免在 MUSA 上创建随机初始化 op。 + """ + rng = np.random.RandomState(SEED + block_id) + + w_np = rng.randn(in_dim, out_dim).astype(np.float32) * 0.02 + b_np = rng.randn(out_dim).astype(np.float32) + + with tf.name_scope(f"block_{block_id}"): + w = tf.constant(w_np, dtype=tf.float32, name="w") + b = tf.constant(b_np, dtype=tf.float32, name="b") + + y = tf.matmul(x, w, name="matmul") + y = tf.nn.bias_add(y, b, name="bias_add") + y = tf.nn.relu(y, name="relu") + return y + + +def build_test_graph(num_blocks=NUM_BLOCKS): + """ + 输入 [BATCH, SEQ, HIDDEN] + 先 reshape 成二维,再串很多 block,最后做一个 reduce_mean, + 保证图里既有 AMP 候选,又有输出 fetch。 + """ + graph = tf.Graph() + with graph.as_default(): + with tf.device("/device:MUSA:0"): + x = tf.compat.v1.placeholder( + tf.float32, shape=[BATCH, SEQ, HIDDEN], name="input" + ) + + y = tf.reshape(x, [BATCH * SEQ, HIDDEN], name="flatten") + + for i in range(num_blocks): + y = dense_block(y, HIDDEN, HIDDEN, i) + + # 避免整个输出过大,fetch 一个较小结果 + out = tf.reduce_mean(y, axis=1, name="reduce_mean") + out = tf.identity(out, name="final_output") + + return graph + + +# ========================= +# 5. Session 配置 +# ========================= +def make_session_config(enable_amp): + config = tf.compat.v1.ConfigProto() + + rewriter = config.graph_options.rewrite_options + opt = rewriter.custom_optimizers.add() + opt.name = "musa_graph_optimizer" + + # 只使用你当前文件里已经存在的参数 + opt.parameter_map["aggressive_mode"].b = AGGRESSIVE_MODE + opt.parameter_map["precision_mode"].s = PRECISION_MODE.encode("utf-8") + opt.parameter_map["disable_layout_optimizer"].b = DISABLE_LAYOUT_OPTIMIZER + opt.parameter_map["disable_amp"].b = (not enable_amp) + + return config + + +# ========================= +# 6. 计时函数 +# ========================= +def benchmark_one_case(graph, enable_amp, input_data): + """ + 返回: + { + "enable_amp": bool, + "warmup_avg_ms": ..., + "bench_avg_ms": ..., + "bench_p50_ms": ..., + "bench_p90_ms": ..., + "bench_p95_ms": ..., + "bench_min_ms": ..., + "bench_max_ms": ..., + } + """ + config = make_session_config(enable_amp=enable_amp) + + with tf.compat.v1.Session(graph=graph, config=config) as sess: + # 初始化变量 + sess.run(tf.compat.v1.global_variables_initializer()) + + x = graph.get_tensor_by_name("input:0") + out = graph.get_tensor_by_name("final_output:0") + + # Warmup + warmup_times = [] + for _ in range(WARMUP_STEPS): + with tf.compat.v1.Session(graph=graph, config=config) as sess: + t0 = time.perf_counter() + _ = sess.run(out, feed_dict={x: input_data}) + t1 = time.perf_counter() + warmup_times.append((t1 - t0) * 1000.0) + + # Benchmark + bench_times = [] + bench_results = [] + for _ in range(BENCH_STEPS): + with tf.compat.v1.Session(graph=graph, config=config) as sess: + t0 = time.perf_counter() + result = sess.run(out, feed_dict={x: input_data}) + t1 = time.perf_counter() + bench_times.append((t1 - t0) * 1000.0) + bench_results.append(result) + if(enable_amp): + with open('result_AMP.txt', 'w') as f: + for line in bench_results: + f.write(f"{line}\n") + bench_arr = np.array(bench_times, dtype=np.float64) + warmup_arr = np.array(warmup_times, dtype=np.float64) + + return { + "enable_amp": enable_amp, + "warmup_avg_ms": float(np.mean(warmup_arr)), + "bench_avg_ms": float(np.mean(bench_arr)), + "bench_avg_result": float(np.mean(bench_results)), + "bench_p50_ms": float(np.percentile(bench_arr, 50)), + "bench_p90_ms": float(np.percentile(bench_arr, 90)), + "bench_p95_ms": float(np.percentile(bench_arr, 95)), + "bench_min_ms": float(np.min(bench_arr)), + "bench_max_ms": float(np.max(bench_arr)), + } + + +# ========================= +# 7. 主流程 +# ========================= +def main(): + load_musa_plugin() + np.random.seed(SEED) + + graph = build_test_graph(num_blocks=NUM_BLOCKS) + + input_data = np.random.randn(BATCH, SEQ, HIDDEN).astype(np.float32) + + print("=" * 80) + print("Benchmark config") + print(f"BATCH={BATCH}, SEQ={SEQ}, HIDDEN={HIDDEN}") + print(f"NUM_BLOCKS={NUM_BLOCKS}") + print(f"WARMUP_STEPS={WARMUP_STEPS}, BENCH_STEPS={BENCH_STEPS}") + print(f"PRECISION_MODE={PRECISION_MODE}") + print(f"AGGRESSIVE_MODE={AGGRESSIVE_MODE}") + print(f"DISABLE_LAYOUT_OPTIMIZER={DISABLE_LAYOUT_OPTIMIZER}") + print("=" * 80) + + # Case 1: AMP 关闭 + print("AMP OFF =============") + result_no_amp = benchmark_one_case( + graph=graph, + enable_amp=False, + input_data=input_data, + ) + print("AMP ON =============") + # Case 2: AMP 开启(当前版本可能含冗余 cast) + result_amp = benchmark_one_case( + graph=graph, + enable_amp=True, + input_data=input_data, + ) + + print("\n[Result] AMP OFF") + print(json.dumps(result_no_amp, indent=2)) + + print("\n[Result] AMP ON") + print(json.dumps(result_amp, indent=2)) + + speedup = result_no_amp["bench_avg_ms"] / result_amp["bench_avg_ms"] + print("\n[Summary]") + print(f"Speedup (AMP OFF / AMP ON) = {speedup:.4f}x") + + print("\n[How to use this script]") + print("1) 先用当前 optimizer 跑一遍,记录 AMP ON 的 bench_avg_ms") + print("2) 修改 OptimizeAMP,去掉冗余 cast") + print("3) 重新编译 plugin") + print("4) 用完全相同的脚本和参数再跑一遍") + print("5) 对比两次 AMP ON 的 bench_avg_ms / p50 / p95") + + +if __name__ == "__main__": + main() diff --git a/test/AMP/PrecisionCompare.py b/test/AMP/PrecisionCompare.py new file mode 100644 index 00000000..4cc97d8c --- /dev/null +++ b/test/AMP/PrecisionCompare.py @@ -0,0 +1,230 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +""" +Compare result_AMP.txt vs result_AMP_fix.txt + +Assumption: +- Each file contains multiple result arrays. +- Each result array is stored in a bracket block like: + [0.1 0.2 0.3 ...] +- Typically there should be 10 runs in each file. + +What this script does: +1. Parse all bracketed arrays from both txt files +2. Check run count and shape consistency +3. Compare each paired run: + - max abs error + - mean abs error + - RMSE + - max/mean relative error + - allclose / pass rate +4. Aggregate all runs +5. Give a simple conclusion on whether there is a large precision deviation +""" + +import re +import sys +import argparse +from pathlib import Path + +import numpy as np + + +def parse_result_file(path: str): + """ + Parse all bracketed float arrays from a txt file. + + Supports content like: + [0.1 0.2 0.3] + [0.4 0.5 0.6] + + Returns: + runs: list[np.ndarray] + """ + text = Path(path).read_text(encoding="utf-8", errors="ignore") + + # Match every [...] block, including multiline blocks + blocks = re.findall(r"\[(.*?)\]", text, flags=re.S) + if not blocks: + raise ValueError(f"No bracketed result blocks found in file: {path}") + + runs = [] + for i, block in enumerate(blocks): + # Normalize whitespace, then parse floats + normalized = " ".join(block.replace("\n", " ").split()) + arr = np.fromstring(normalized, sep=" ", dtype=np.float64) + + if arr.size == 0: + raise ValueError(f"Failed to parse block #{i} in file: {path}") + + runs.append(arr) + + return runs + + +def compare_two_arrays(ref: np.ndarray, test: np.ndarray, rtol=1e-5, atol=1e-7): + if ref.shape != test.shape: + raise ValueError(f"Shape mismatch: {ref.shape} vs {test.shape}") + + diff = test - ref + abs_err = np.abs(diff) + rel_err = abs_err / np.maximum(np.abs(ref), 1e-12) + + metrics = { + "shape": ref.shape, + "max_abs_err": float(abs_err.max()), + "mean_abs_err": float(abs_err.mean()), + "median_abs_err": float(np.median(abs_err)), + "rmse": float(np.sqrt(np.mean(diff ** 2))), + "max_rel_err": float(rel_err.max()), + "mean_rel_err": float(rel_err.mean()), + "relative_l2": float(np.linalg.norm(diff) / (np.linalg.norm(ref) + 1e-12)), + "pass_rate": float(np.isclose(test, ref, rtol=rtol, atol=atol).mean()), + "all_close": bool(np.allclose(test, ref, rtol=rtol, atol=atol)), + "ref_mean": float(ref.mean()), + "test_mean": float(test.mean()), + "ref_std": float(ref.std()), + "test_std": float(test.std()), + "has_nan_ref": bool(np.isnan(ref).any()), + "has_nan_test": bool(np.isnan(test).any()), + "has_inf_ref": bool(np.isinf(ref).any()), + "has_inf_test": bool(np.isinf(test).any()), + } + + worst_idx = int(np.argmax(abs_err)) + metrics["worst_index_flat"] = worst_idx + metrics["worst_ref"] = float(ref.reshape(-1)[worst_idx]) + metrics["worst_test"] = float(test.reshape(-1)[worst_idx]) + metrics["worst_abs_err"] = float(abs_err.reshape(-1)[worst_idx]) + metrics["worst_rel_err"] = float(rel_err.reshape(-1)[worst_idx]) + + return metrics + + +def print_run_report(run_id: int, m: dict): + print("=" * 90) + print(f"Run #{run_id}") + print("=" * 90) + print(f"shape : {m['shape']}") + print(f"ref mean/std : {m['ref_mean']:.10f} / {m['ref_std']:.10f}") + print(f"test mean/std : {m['test_mean']:.10f} / {m['test_std']:.10f}") + print(f"max abs err : {m['max_abs_err']:.10e}") + print(f"mean abs err : {m['mean_abs_err']:.10e}") + print(f"median abs err : {m['median_abs_err']:.10e}") + print(f"rmse : {m['rmse']:.10e}") + print(f"max rel err : {m['max_rel_err']:.10e}") + print(f"mean rel err : {m['mean_rel_err']:.10e}") + print(f"relative L2 : {m['relative_l2']:.10e}") + print(f"pass rate : {m['pass_rate']:.6f}") + print(f"all close : {m['all_close']}") + print(f"nan(ref/test) : {m['has_nan_ref']} / {m['has_nan_test']}") + print(f"inf(ref/test) : {m['has_inf_ref']} / {m['has_inf_test']}") + print(f"worst flat index : {m['worst_index_flat']}") + print(f"worst ref/test : {m['worst_ref']:.10f} / {m['worst_test']:.10f}") + print(f"worst abs err : {m['worst_abs_err']:.10e}") + print(f"worst rel err : {m['worst_rel_err']:.10e}") + + +def summarize_all(metrics_list): + keys = [ + "max_abs_err", "mean_abs_err", "median_abs_err", "rmse", + "max_rel_err", "mean_rel_err", "relative_l2", "pass_rate" + ] + + print("\n" + "#" * 90) + print("Aggregate summary across paired runs") + print("#" * 90) + + for k in keys: + vals = np.array([m[k] for m in metrics_list], dtype=np.float64) + print( + f"{k:16s}: " + f"min={vals.min():.10e}, " + f"mean={vals.mean():.10e}, " + f"max={vals.max():.10e}" + ) + + all_allclose = all(m["all_close"] for m in metrics_list) + print(f"\nall runs allclose : {all_allclose}") + + # Simple engineering conclusion + max_abs = max(m["max_abs_err"] for m in metrics_list) + max_rel_l2 = max(m["relative_l2"] for m in metrics_list) + min_pass_rate = min(m["pass_rate"] for m in metrics_list) + + print("\nConclusion:") + if max_abs < 1e-8 and max_rel_l2 < 1e-8: + print("-> The two result files are essentially identical; no obvious precision deviation.") + elif min_pass_rate == 1.0 and max_rel_l2 < 1e-5: + print("-> The difference is extremely small; no large precision deviation.") + elif min_pass_rate > 0.999 and max_rel_l2 < 1e-3: + print("-> The difference is small and likely acceptable, but worth checking worst-case elements.") + else: + print("-> There may be noticeable precision deviation; inspect the per-run worst errors.") + + +def main(): + print("=============") + parser = argparse.ArgumentParser() + parser.add_argument("--amp", type=str, default="result_AMP.txt", + help="Path to original AMP result txt") + parser.add_argument("--fix", type=str, default="result_AMP_fix.txt", + help="Path to optimized AMP result txt") + parser.add_argument("--rtol", type=float, default=1e-5, + help="Relative tolerance for allclose/isclose") + parser.add_argument("--atol", type=float, default=1e-7, + help="Absolute tolerance for allclose/isclose") + parser.add_argument("--expected-runs", type=int, default=10, + help="Expected number of runs in each file") + args = parser.parse_args() + + amp_runs = parse_result_file(args.amp) + fix_runs = parse_result_file(args.fix) + + print(f"Parsed {len(amp_runs)} runs from {args.amp}") + print(f"Parsed {len(fix_runs)} runs from {args.fix}") + + if len(amp_runs) != len(fix_runs): + raise ValueError( + f"Run count mismatch: {len(amp_runs)} (AMP) vs {len(fix_runs)} (FIX)" + ) + + if args.expected_runs is not None: + if len(amp_runs) != args.expected_runs: + print( + f"[Warning] Expected {args.expected_runs} runs, " + f"but parsed {len(amp_runs)} from AMP file." + ) + if len(fix_runs) != args.expected_runs: + print( + f"[Warning] Expected {args.expected_runs} runs, " + f"but parsed {len(fix_runs)} from FIX file." + ) + + metrics_list = [] + for i, (ref, test) in enumerate(zip(amp_runs, fix_runs)): + if ref.shape != test.shape: + raise ValueError( + f"Shape mismatch at run #{i}: {ref.shape} vs {test.shape}" + ) + + m = compare_two_arrays(ref, test, rtol=args.rtol, atol=args.atol) + metrics_list.append(m) + print_run_report(i, m) + + # Also compare all runs flattened together + ref_all = np.concatenate([x.reshape(-1) for x in amp_runs], axis=0) + test_all = np.concatenate([x.reshape(-1) for x in fix_runs], axis=0) + all_metrics = compare_two_arrays(ref_all, test_all, rtol=args.rtol, atol=args.atol) + + print("\n" + "#" * 90) + print("Global comparison over all runs concatenated") + print("#" * 90) + print_run_report(-1, all_metrics) + + summarize_all(metrics_list) + + +if __name__ == "__main__": + main()