diff --git a/testing/python/compile_pipeline/compile_pipeline.py b/testing/python/compile_pipeline/compile_pipeline.py new file mode 100644 index 0000000000..1ee7ea9fa7 --- /dev/null +++ b/testing/python/compile_pipeline/compile_pipeline.py @@ -0,0 +1,597 @@ +import os +import re +import warnings +import tilelang +from tilelang import tvm +from tilelang.transform import PassConfigKey +from tilelang.utils.target import determine_target +from typing import Any, Literal, Callable +from tvm.target import Target +from tilelang.language.eager import PrimFunc +from tvm import tir, IRModule +from tvm.ir import CallingConv +from tilelang.engine.param import KernelParam +from tilelang.transform import PassContext +from tilelang.contrib.nvcc import have_tma +from tilelang.utils.target import target_is_sunmmio +from tilelang.jit.adapter.utils import is_cuda_target + + +def allow_warp_specialized(pass_ctx: PassContext | None = None, target: Target | None = None) -> bool: + if pass_ctx is None: + pass_ctx = tilelang.transform.get_pass_context() + if (not is_cuda_target(target)) or (not have_tma(target)): + return False + disable_warp_specialized = pass_ctx.config.get("tl.disable_warp_specialized", False) + return not disable_warp_specialized + + +def allow_tma_and_warp_specialized(pass_ctx: PassContext | None = None, target: Target | None = None) -> bool: + if pass_ctx is None: + pass_ctx = tilelang.transform.get_pass_context() + if not have_tma(target): + return False + disable_tma_lower = pass_ctx.config.get("tl.disable_tma_lower", False) + return not disable_tma_lower and allow_warp_specialized(pass_ctx=pass_ctx, target=target) + + +def allow_fence_proxy(target: Target | None = None) -> bool: + return have_tma(target) + + +def allow_vectorize(pass_ctx: PassContext | None = None) -> bool: + if pass_ctx is None: + pass_ctx = tilelang.transform.get_pass_context() + disable_vectorize = pass_ctx.config.get("tir.disable_vectorize", False) + return not disable_vectorize + + +def 
allow_global_thread_synchronization(pass_ctx: PassContext | None = None) -> bool: + if pass_ctx is None: + pass_ctx = tilelang.transform.get_pass_context() + enable_global_thread_sync = pass_ctx.config.get("tir.detect_global_barrier", False) + return enable_global_thread_sync + + +def should_enable_aggressive_merge(pass_ctx: PassContext | None = None, target: Target | None = None) -> bool: + if pass_ctx is None: + pass_ctx = tilelang.transform.get_pass_context() + enable_aggressive_merge = bool(pass_ctx.config.get(tilelang.PassConfigKey.TL_ENABLE_AGGRESSIVE_SHARED_MEMORY_MERGE, False)) + if allow_warp_specialized(pass_ctx=pass_ctx, target=target): + enable_aggressive_merge = False + return enable_aggressive_merge + + +def should_force_let_inline(pass_ctx: PassContext | None = None) -> bool: + if pass_ctx is None: + pass_ctx = tilelang.transform.get_pass_context() + return bool(pass_ctx and pass_ctx.config.get(tilelang.PassConfigKey.TL_FORCE_LET_INLINE, False)) + + +def should_enable_ast_print(pass_ctx: PassContext | None = None) -> bool: + if pass_ctx is None: + pass_ctx = tilelang.transform.get_pass_context() + return bool(pass_ctx and pass_ctx.config.get(tilelang.PassConfigKey.TL_AST_PRINT_ENABLE, False)) + + +def should_enable_layout_visual(pass_ctx: PassContext | None = None) -> bool: + if pass_ctx is None: + pass_ctx = tilelang.transform.get_pass_context() + enabled = pass_ctx.config.get(tilelang.PassConfigKey.TL_LAYOUT_VISUALIZATION_ENABLE, False) + return enabled + + +def should_enable_race_check(pass_ctx: PassContext | None = None) -> bool: + if pass_ctx is None: + pass_ctx = tilelang.transform.get_pass_context() + enabled = not pass_ctx.config.get(tilelang.PassConfigKey.TL_DISABLE_DATA_RACE_CHECK, False) + return enabled + + +def get_layout_visual_formats(pass_ctx: PassContext | None = None) -> list[str]: + if pass_ctx is None: + pass_ctx = tilelang.transform.get_pass_context() + formats_value = 
pass_ctx.config.get(tilelang.PassConfigKey.TL_LAYOUT_VISUALIZATION_FORMATS, "") + if not formats_value: + return ["txt"] + + formats_str = formats_value.strip().lower() + valid_formats = ["txt", "png", "pdf", "svg", "all"] + + if formats_str == "all": + return ["txt", "png", "pdf", "svg"] + + if "," in formats_str: + formats_list = [f.strip() for f in formats_str.split(",")] + else: + formats_list = [formats_str] + + invalid_formats = [f for f in formats_list if f not in valid_formats] + if invalid_formats: + raise ValueError( + f"Invalid formats for TL_LAYOUT_VISUALIZATION_FORMATS: {invalid_formats}. " + f"Valid formats are: {valid_formats}. " + f"You can choose one of the valid formats or a comma-separated list of formats.(e.g., 'txt,png,pdf')" + ) + return formats_list + + +def LayoutVisual(mod: IRModule) -> None: + if should_enable_layout_visual(): + formats = get_layout_visual_formats() + tilelang.analysis.LayoutVisual(formats=formats)(mod) + + +def is_cpu_device_backend(target: Target): + return target.kind.name == "c" + + +def has_device_kernel_launch(attrs) -> bool: + """Check if the attributes indicate a device kernel launch.""" + return bool(attrs and "calling_conv" in attrs and attrs["calling_conv"] == CallingConv.DEVICE_KERNEL_LAUNCH) + + +def is_device_call_c_device(func: tir.PrimFunc): + attrs = func.attrs + calling_conv = attrs.get("calling_conv", CallingConv.DEFAULT) + is_cpacked = calling_conv == CallingConv.C_PACKED_FUNC + + # Check if it's a C target + if "target" in attrs and attrs["target"].kind.name == "c" and not is_cpacked: + return True + + return has_device_kernel_launch(attrs) + + +def is_device_call(func: tir.PrimFunc): + return has_device_kernel_launch(func.attrs) + + +def get_device_call(is_device_c: bool = False) -> Callable[[tir.PrimFunc], bool]: + return is_device_call_c_device if is_device_c else is_device_call + + +def get_host_call(is_device_c: bool = False) -> Callable[[tir.PrimFunc], bool]: + return lambda func: not 
get_device_call(is_device_c)(func) + + +def is_sunmmio_call(func: tir.PrimFunc): + attrs = func.attrs + return bool(attrs and "target" in attrs and target_is_sunmmio(attrs["target"])) + + +def get_device_call_sunmmio() -> Callable[[tir.PrimFunc], bool]: + return is_sunmmio_call + + +def get_host_call_sunmmio() -> Callable[[tir.PrimFunc], bool]: + return lambda func: not is_sunmmio_call(func) + + +def extrac_params(func: tir.PrimFunc) -> list[KernelParam]: + tensor_types = [] + for var in func.params: + if var in func.buffer_map: + tensor_types.append(KernelParam.from_buffer(func.buffer_map[var])) + else: + tensor_types.append(KernelParam.from_var(var)) + return tensor_types + + +def canon_target_host(target: str | Target, target_host: str | Target | None): + if not target_host: + target_host = "llvm" if tvm.runtime.enabled("llvm") else "c" + return target_host + + +def PreLowerSemanticCheck(mod: IRModule) -> None: + if should_enable_ast_print(): + tilelang.analysis.ASTPrinter()(mod) + tilelang.analysis.NestedLoopChecker()(mod) + tilelang.analysis.FragmentLoopChecker()(mod) + + +def pass_test(mod: IRModule, pass_name: str, test_config: dict[str, Any]) -> None: + if pass_name in test_config: + test_info = test_config[pass_name] + print(f"testing {pass_name}") + if "script_expected" in test_info: + expect = test_info["script_expected"] + if isinstance(expect, str): + expect = [expect.strip()] + elif isinstance(expect, list): + for lines in expect: + assert isinstance(lines, str), f"Invalid type for script_expected: {type(lines)}" + expect = [lines.strip() for lines in expect] + else: + raise ValueError(f"Invalid type for script_expected: {type(expect)}") + script = mod.script(show_meta=True).strip() + error_msg = f"The generated script of {pass_name} does not match the expected output." 
+ if "show_generated_script" in test_info and test_info["show_generated_script"]: + error_msg = error_msg + f"\nGenerated script:\n{script}" + for lines in expect: + if lines not in script: + warnings.warn(error_msg, stacklevel=2) + + if "formal_verify" in test_info: + formal_verify = test_info["formal_verify"] + if isinstance(formal_verify, list): + for test_func in formal_verify: + test_func(mod) + else: + formal_verify(mod) + + +def LowerAndLegalize_sunmmio_test( + mod: IRModule, + target: Target, + test_config: dict[str, Any] | None = None, + log_pass_output: bool = False, + show_meta: bool = False, + log_dir: str = "./", + log_passes: list[str] | None = None, +) -> IRModule: + if test_config is None: + test_config = {} + out_file = os.path.join(log_dir, "passes_lower_and_legalize.log") + if log_pass_output: + print(f"Logging pass output in LowerAndLegalize to {out_file}") + with open(out_file, "w") as f: + f.write("\n'=== Initial Mod ==='\n") + if log_passes is None: + f.write(mod.script(show_meta=show_meta).strip() + "\n\n") + + def pass_output_process(mod: IRModule, pass_name: str, test_config: dict[str, Any]) -> None: + if log_pass_output and (log_passes is None or pass_name in log_passes): + with open(out_file, "a") as f: + f.write(f"'=== After Pass {pass_name} ==='\n") + f.write(mod.script(show_meta=show_meta).strip() + "\n\n") + pass_test(mod, pass_name, test_config) + + mod = tir.transform.BindTarget(target)(mod) + pass_output_process(mod, "BindTarget", test_config) + + if should_force_let_inline(): + mod = tilelang.transform.LetInline()(mod) + pass_output_process(mod, "LetInline", test_config) + + # mod = tilelang.transform.AddWrapperForSingleBufStore()(mod) + mod = tilelang.transform.LegalizeNegativeIndex()(mod) + pass_output_process(mod, "LegalizeNegativeIndex", test_config) + + # if should_enable_race_check(): + # mod = tilelang.transform.VerifyParallelLoop()(mod) + mod = tilelang.transform.InjectAssumes()(mod) + pass_output_process(mod, 
"InjectAssumes", test_config) + + mod = tilelang.transform.Simplify()(mod) + pass_output_process(mod, "Simplify_lower_1", test_config) + + mod = tilelang.transform.InferSramScope()(mod) + pass_output_process(mod, "InferSramScope", test_config) + + # mod = tilelang.transform.LayoutReducer()(mod) + mod = tilelang.transform.LayoutInference()(mod) + pass_output_process(mod, "LayoutInference", test_config) + + LayoutVisual(mod) + mod = tilelang.transform.LowerTileOp()(mod) + pass_output_process(mod, "LowerTileOp", test_config) + + mod = tilelang.transform.LegalizeTilesLoop()(mod) + pass_output_process(mod, "LegalizeTilesLoop", test_config) + + mod = tilelang.transform.TilesLoop()(mod) + pass_output_process(mod, "TilesLoop", test_config) + + # mod = tilelang.transform.LowerL2Persistent()(mod) + mod = tilelang.transform.DecoupleTypeCast()(mod) + pass_output_process(mod, "DecoupleTypeCast", test_config) + + mod = tilelang.transform.LegalizeVectorizedLoop()(mod) + pass_output_process(mod, "LegalizeVectorizedLoop", test_config) + + mod = tilelang.transform.LegalizeSafeMemoryAccess()(mod) + pass_output_process(mod, "LegalizeSafeMemoryAccess", test_config) + + mod = tilelang.transform.LowerAccessPtr()(mod) + pass_output_process(mod, "LowerAccessPtr", test_config) + + mod = tilelang.transform.Simplify()(mod) + pass_output_process(mod, "Simplify_lower_2", test_config) + + mod = tilelang.transform.HoistNonRestrictParams()(mod) + pass_output_process(mod, "HoistNonRestrictParams", test_config) + + return mod + + +def OptimizeForSunmmio_test( + mod: IRModule, + target: Target, + test_config: dict[str, Any] | None = None, + log_pass_output: bool = False, + show_meta: bool = False, + log_dir: str = "./", + log_passes: list[str] | None = None, +) -> IRModule: + if test_config is None: + test_config = {} + out_file = os.path.join(log_dir, "passes_optimize_for_sunmmio.log") + if log_pass_output: + print(f"Logging pass output in OptimizeForSunmmio to {out_file}") + with open(out_file, 
"w") as f: + f.write("\n'=== Initial Mod ==='\n") + if log_passes is None: + f.write(mod.script(show_meta=show_meta).strip() + "\n\n") + + def pass_output_process(mod: IRModule, pass_name: str, test_config: dict[str, Any]) -> None: + if log_pass_output and (log_passes is None or pass_name in log_passes): + with open(out_file, "a") as f: + f.write(f"'=== After Pass {pass_name} ==='\n") + f.write(mod.script(show_meta=show_meta).strip() + "\n\n") + pass_test(mod, pass_name, test_config) + + mod = tilelang.transform.IfStmtBinding()(mod) + pass_output_process(mod, "IfStmtBinding", test_config) + + mod = tilelang.transform.SunmmioPipelinePlanning(debug=False)(mod) + pass_output_process(mod, "SunmmioPipelinePlanning", test_config) + + mod = tilelang.transform.InjectSunmmioPipeline()(mod) + pass_output_process(mod, "InjectSunmmioPipeline", test_config) + + mod = tilelang.transform.PlanAndUpdateBufferAllocationLocation()(mod) + pass_output_process(mod, "PlanAndUpdateBufferAllocationLocation", test_config) + + mod = tilelang.transform.LowerOpaqueBlock()(mod) + pass_output_process(mod, "LowerOpaqueBlock", test_config) + + mod = tir.transform.Simplify()(mod) + pass_output_process(mod, "Simplify_optimize_1", test_config) + + mod = tir.transform.NarrowDataType(32)(mod) + pass_output_process(mod, "NarrowDataType", test_config) + + mod = tir.transform.HoistIfThenElse()(mod) + pass_output_process(mod, "HoistIfThenElse", test_config) + + mod = tilelang.transform.LoopUnswitching()(mod) + pass_output_process(mod, "LoopUnswitching", test_config) + + mod = tir.transform.UnrollLoop()(mod) + pass_output_process(mod, "UnrollLoop", test_config) + + mod = tir.transform.Simplify()(mod) + pass_output_process(mod, "Simplify_optimize_2", test_config) + + mod = tir.transform.VerifyMemory()(mod) + pass_output_process(mod, "VerifyMemory", test_config) + + mod = tir.transform.AnnotateEntryFunc()(mod) + pass_output_process(mod, "AnnotateEntryFunc", test_config) + + mod = 
tilelang.transform.AnnotateDeviceRegions()(mod) + pass_output_process(mod, "AnnotateDeviceRegions", test_config) + + mod = tilelang.transform.SplitHostDevice()(mod) + pass_output_process(mod, "SplitHostDevice", test_config) + + mod = tilelang.transform.MergeIfStmt()(mod) + pass_output_process(mod, "MergeIfStmt", test_config) + + mod = tilelang.transform.InjectSunmmioSync()(mod) + pass_output_process(mod, "InjectSunmmioSync", test_config) + + mod = tilelang.transform.FlattenBuffer()(mod) + pass_output_process(mod, "FlattenBuffer", test_config) + + mod = tilelang.transform.ConfigIndexBitwidth()(mod) + pass_output_process(mod, "ConfigIndexBitwidth", test_config) + + mod = tir.transform.Simplify()(mod) + pass_output_process(mod, "Simplify_optimize_3", test_config) + + mod = tilelang.transform.VectorizeLoop(enable_vectorize=True)(mod) + pass_output_process(mod, "VectorizeLoop", test_config) + + mod = tilelang.transform.StorageRewrite()(mod) + pass_output_process(mod, "StorageRewrite", test_config) + + mod = tir.transform.RemoveNoOp()(mod) + pass_output_process(mod, "RemoveNoOp", test_config) + + mod = tir.transform.RenormalizeSplitPattern()(mod) + pass_output_process(mod, "RenormalizeSplitPattern", test_config) + + mod = tir.transform.Simplify()(mod) + pass_output_process(mod, "Simplify_optimize_4", test_config) + + mod = tilelang.transform.AnnotateReadOnlyParams()(mod) + pass_output_process(mod, "AnnotateReadOnlyParams", test_config) + + mod = tilelang.transform.MergeSharedMemoryAllocationsSunmmio(enable_aggressive_merge=True)(mod) + pass_output_process(mod, "MergeSharedMemoryAllocationsSunmmio", test_config) + + mod = tilelang.transform.MakePackedAPI()(mod) + pass_output_process(mod, "MakePackedAPI", test_config) + + mod = tilelang.transform.Simplify()(mod) + pass_output_process(mod, "Simplify_optimize_5", test_config) + + mod = tilelang.transform.LowerDeviceKernelLaunch()(mod) + pass_output_process(mod, "LowerDeviceKernelLaunch", test_config) + + return mod + + +def 
process_passes_output(log_dir, filenames, remove_header=False): + def clean_header(text): + if not remove_header: + return text + # Header to remove: + # # from tvm.script import ir as I + # # from tvm.script import tir as T + # + # @I.ir_module + header_pattern = r"# from tvm\.script import ir as I\s*\n# from tvm\.script import tir as T\s*\n\s*\n@I\.ir_module\s*\n" + text = re.sub(header_pattern, "", text) + + # Metadata omitted comment to remove + meta_pattern = r"\n\s*\n# Metadata omitted\. Use show_meta=True in script\(\) method to show it\.\s*\n\n" + text = re.sub(meta_pattern, "\n", text) + return text + + for filename in filenames: + log_file = os.path.join(log_dir, filename) + if not os.path.exists(log_file): + continue + + with open(log_file, "r") as f: + content = f.read() + + # Split by the "=== After Pass ... ===" or "=== Initial Mod ===" markers + # We need to capture the markers to keep them + pattern = r"('(?:=== Initial Mod ===|=== After Pass [^=]+ ===)')" + parts = re.split(pattern, content) + + if len(parts) < 3: + continue + + new_parts = [parts[0]] + # Always keep the first pass (Initial Mod) + new_parts.append(parts[1]) + new_parts.append("\n" + clean_header(parts[2]).strip() + "\n") + + last_content = parts[2].strip() + + for i in range(3, len(parts), 2): + header = parts[i] + current_content = parts[i + 1] + + # Compare current content with last content + if current_content.strip() == last_content: + new_parts.append(header) + new_parts.append("\nNo change.\n") + else: + new_parts.append(header) + new_parts.append("\n" + clean_header(current_content).strip() + "\n") + last_content = current_content.strip() + + with open(log_file, "w") as f: + f.write("".join(new_parts)) + + +def compile_test( + func: PrimFunc = None, + out_idx: list[int] | int | None = None, + execution_backend: (Literal["auto", "dlpack", "tvm_ffi", "cython", "nvrtc", "torch", "cutedsl"] | None) = None, + target: str | Target | None = None, + target_host: str | Target | None 
= None, + pass_configs: dict[str, Any] | None = None, + compile_flags: list[str] | str | None = None, + test_config: dict[str, Any] | None = None, + log_pass_output: bool = False, + show_meta: bool = False, + log_dir: str = "./", + remove_header: bool = False, + log_passes: list[str] | None = None, +): + """ + Compile the given TileLang PrimFunc with TVM and return the host_mod and device_mod. + This function mimics tilelang.jit.compile but exposes the intermediate TIR modules. + It manually implements the lower logic to avoid failing at codegen step. + + Returns: + tuple: (host_mod, device_mod) corresponding to the modules generated in lower.log + """ + if test_config is None: + test_config = {} + + for pass_name in test_config: + for key in test_config[pass_name]: + assert key in [ + "script_expected", + "show_generated_script", + "formal_verify", + ], f"wrong key :{key} for pass {pass_name}" + + if execution_backend is None: + execution_backend = "tvm_ffi" + + if execution_backend == "auto": + execution_backend = "tvm_ffi" + + if pass_configs is None: + pass_configs = {} + + if compile_flags is not None: + compile_flags_cfg = pass_configs.get(PassConfigKey.TL_DEVICE_COMPILE_FLAGS) + pass_configs[PassConfigKey.TL_DEVICE_COMPILE_FLAGS] = ( + compile_flags_cfg + compile_flags if compile_flags_cfg is not None else compile_flags + ) + + # Determine target + target = determine_target(target, return_object=True) + + # Custom Lower implementation + func_or_mod = func + + mod = func_or_mod + if isinstance(func_or_mod, tir.PrimFunc): + func = func_or_mod + mod = tvm.IRModule({func.attrs["global_symbol"]: func}) + + if isinstance(target, str): + target = determine_target(target) + + target_host = canon_target_host(target, target_host) + + target_host = tvm.target.Target.canon_target(target_host) + target = tvm.target.Target(target, target_host) + + if target_is_sunmmio(target): + _is_host_call = get_host_call_sunmmio() + _is_device_call = get_device_call_sunmmio() + else: 
+ _is_host_call = get_host_call(is_device_c=is_cpu_device_backend(target)) + _is_device_call = get_device_call(is_device_c=is_cpu_device_backend(target)) + + with tvm.transform.PassContext(opt_level=3, config=pass_configs), target: + # Before lowering, do semantic check + PreLowerSemanticCheck(mod) + + if log_pass_output: + os.makedirs(log_dir, exist_ok=True) + + # Phase 1: Lower and legalize the IR module + mod = LowerAndLegalize_sunmmio_test(mod, target, test_config, log_pass_output, show_meta, log_dir, log_passes) + + # Phase 2: Optimize the IR for the target + mod = OptimizeForSunmmio_test(mod, target, test_config, log_pass_output, show_meta, log_dir, log_passes) + + host_mod = tir.transform.Filter(_is_host_call)(mod) + device_mod = tir.transform.Filter(_is_device_call)(mod) + + out_file = os.path.join(log_dir, "passes_optimize_for_sunmmio.log") + + def _log_pass(pass_name, m): + if log_pass_output and (log_passes is None or pass_name in log_passes): + with open(out_file, "a") as f: + f.write(f"'=== After Pass {pass_name} ==='\n") + f.write(m.script(show_meta=show_meta).strip() + "\n\n") + + _log_pass("HostMod", host_mod) + pass_test(host_mod, "HostMod", test_config) + + _log_pass("DeviceMod", device_mod) + pass_test(device_mod, "DeviceMod", test_config) + + if log_pass_output: + process_passes_output( + log_dir, + ["passes_lower_and_legalize.log", "passes_optimize_for_sunmmio.log"], + remove_header=remove_header, + ) + + return host_mod, device_mod diff --git a/testing/python/compile_pipeline/formal_verify_funcs.py b/testing/python/compile_pipeline/formal_verify_funcs.py new file mode 100644 index 0000000000..20be1474e8 --- /dev/null +++ b/testing/python/compile_pipeline/formal_verify_funcs.py @@ -0,0 +1,431 @@ +import re +from tvm import tir, IRModule, arith + + +def verify_comm_lower(func: tir.PrimFunc): + expected_broadcasts = [] + analyzer = arith.Analyzer() + current_mesh_nrow = 4 + current_mesh_ncol = 4 + + def get_region_size(node): + if 
isinstance(node, tir.Call) and node.op.name == "tl.tileop.region": + size = 1 + for i in range(2, len(node.args)): + size *= node.args[i] + return analyzer.simplify(size) + elif isinstance(node, tir.BufferRegion): + size = 1 + for r in node.region: + size *= r.extent + return analyzer.simplify(size) + elif isinstance(node, tir.Buffer): + size = 1 + for s in node.shape: + size *= s + return analyzer.simplify(size) + elif isinstance(node, tir.ProducerLoad): + size = 1 + for s in node.buffer.shape: + size *= s + return analyzer.simplify(size) + elif isinstance(node, tir.BufferLoad): + # If it's a BufferLoad, we try to see if it looks like a slice. + # In TileLang's debug representation, slices look like A[start:end, ...] + # But in the actual object, it's a BufferLoad with indices. + # If we can't determine the slice size, we return the buffer size + # BUT we mark it as "possibly a full buffer" so the caller can prefer + # a more specific size if available. + # For now, let's just return the buffer size but try to handle the fallback better. 
+ size = 1 + for s in node.buffer.shape: + size *= s + return analyzer.simplify(size) + elif hasattr(node, "buffer"): + return get_region_size(node.buffer) + return tir.IntImm("int32", 1) + + def stringify_expr(expr): + if isinstance(expr, tir.IntImm): + return str(int(expr)) + return str(analyzer.simplify(expr)) + + def visitor(node): + nonlocal current_mesh_nrow, current_mesh_ncol + if isinstance(node, tir.Call) and node.op.name.startswith("tl.tileop.comm_"): + if node.op.name == "tl.tileop.comm_broadcast": + # args: src, dst, size, dst_offset, src_core, direction + size_expr = node.args[2] + if isinstance(size_expr, tir.IntImm) and int(size_expr) > 0: + size = analyzer.simplify(size_expr) + else: + # Infer size from src region and dst buffer + size0 = get_region_size(node.args[0]) + size1 = get_region_size(node.args[1]) + if isinstance(size0, tir.IntImm) and isinstance(size1, tir.IntImm): + if int(size0) > 0 and int(size1) > 0: + size = tir.IntImm("int32", min(int(size0), int(size1))) + else: + size = analyzer.simplify(size0 if int(size0) > 0 else size1) + else: + size = analyzer.simplify(size0) + + src_core = node.args[4] + direction_val = node.args[5] + if isinstance(direction_val, tir.StringImm): + direction = 0 if direction_val.value == "h" else 1 if direction_val.value == "v" else 2 + else: + direction = int(direction_val) if isinstance(direction_val, tir.IntImm) else 0 + + if direction == 0 or direction == 1: + expected_broadcasts.append((stringify_expr(size), stringify_expr(src_core), direction)) + elif direction == 2 and isinstance(src_core, tir.IntImm): + # 2D broadcast: only supports constant src_core in C++ + src_core_val = int(src_core) + src_core_col = src_core_val % current_mesh_ncol + expected_broadcasts.append((stringify_expr(size), stringify_expr(src_core), 1)) + for i in range(current_mesh_nrow): + expected_broadcasts.append( + ( + stringify_expr(size), + str(i * current_mesh_ncol + src_core_col), + 0, + ) + ) + + elif node.op.name == 
"tl.tileop.comm_put": + # args: src, dst, size, src_core, dst_core + size_expr = node.args[2] + if isinstance(size_expr, tir.IntImm) and int(size_expr) > 0: + size = analyzer.simplify(size_expr) + else: + # Infer size from src region, fallback to dst buffer if src is just a pointer/point + size = get_region_size(node.args[0]) + if isinstance(size, tir.IntImm) and int(size) == 1: + size = get_region_size(node.args[1]) + + src_core = node.args[3] + dst_core = node.args[4] + + # Put logic in C++ requires constant core IDs + if isinstance(src_core, tir.IntImm) and isinstance(dst_core, tir.IntImm): + src_core_val = int(src_core) + dst_core_val = int(dst_core) + src_row, src_col = ( + src_core_val // current_mesh_ncol, + src_core_val % current_mesh_ncol, + ) + dst_row, dst_col = ( + dst_core_val // current_mesh_ncol, + dst_core_val % current_mesh_ncol, + ) + + if src_row == dst_row: + expected_broadcasts.append((stringify_expr(size), stringify_expr(src_core), 0)) + elif src_col == dst_col: + expected_broadcasts.append((stringify_expr(size), stringify_expr(src_core), 1)) + else: + intermediate_core = dst_row * current_mesh_ncol + src_col + expected_broadcasts.append((stringify_expr(size), stringify_expr(src_core), 1)) + expected_broadcasts.append((stringify_expr(size), str(intermediate_core), 0)) + + elif node.op.name == "tl.tileop.comm_allgather": + # args: send, recv, direction, size + direction_val = node.args[2] + if isinstance(direction_val, tir.StringImm): + direction = 0 if direction_val.value == "h" else 1 if direction_val.value == "v" else 2 + else: + direction = int(direction_val) if isinstance(direction_val, tir.IntImm) else 0 + + size_expr = node.args[3] + if isinstance(size_expr, tir.IntImm) and int(size_expr) > 0: + size = analyzer.simplify(size_expr) + else: + # Infer size from send region and recv buffer + size0 = get_region_size(node.args[0]) + size1 = get_region_size(node.args[1]) + if isinstance(size0, tir.IntImm) and isinstance(size1, tir.IntImm): + if 
int(size0) > 0 and int(size1) > 0: + size = tir.IntImm("int32", min(int(size0), int(size1))) + else: + size = analyzer.simplify(size0 if int(size0) > 0 else size1) + else: + size = analyzer.simplify(size0) + + if direction == 0: # horizontal + for i in range(current_mesh_nrow): + for j in range(current_mesh_ncol): + expected_broadcasts.append( + ( + stringify_expr(size), + str(i * current_mesh_ncol + j), + 0, + ) + ) + elif direction == 1: # vertical + for j in range(current_mesh_ncol): + for i in range(current_mesh_nrow): + expected_broadcasts.append( + ( + stringify_expr(size), + str(i * current_mesh_ncol + j), + 1, + ) + ) + elif direction == 2: # all + # horizontal first + for i in range(current_mesh_nrow): + for j in range(current_mesh_ncol): + expected_broadcasts.append( + ( + stringify_expr(size), + str(i * current_mesh_ncol + j), + 0, + ) + ) + # then vertical + allgather_size = analyzer.simplify(size * current_mesh_ncol) + for j in range(current_mesh_ncol): + for i in range(current_mesh_nrow): + expected_broadcasts.append( + ( + stringify_expr(allgather_size), + str(i * current_mesh_ncol + j), + 1, + ) + ) + + if not isinstance(func, tir.PrimFunc): + raise ValueError(f"Expected PrimFunc, got {type(func)}") + + # 1. Gather expectations from func + if func.attrs is not None: + if "mesh_nrow" in func.attrs: + current_mesh_nrow = int(func.attrs["mesh_nrow"]) + if "mesh_ncol" in func.attrs: + current_mesh_ncol = int(func.attrs["mesh_ncol"]) + tir.stmt_functor.post_order_visit(func.body, visitor) + + # 2. Return the check function + def check(mod: IRModule): + script = mod.script() + for size, core, direction in expected_broadcasts: + # Match T.broadcast_(..., size, core, direction, ...) 
+ if size == "1": + size_pattern = r"\d+" + else: + size_pattern = re.escape(size).replace(r"\ ", r"\s*") + escaped_core = re.escape(core).replace(r"\ ", r"\s*") + pattern = rf"T\.broadcast_\(.*?,\s*.*?,\s*{size_pattern},\s*{escaped_core},\s*{direction}" + assert re.search(pattern, script), ( + f"Expected broadcast_ with size={size}, core={core}, direction={direction} not found in IRModule" + ) + + return check + + +def verify_SunmmioSync(mod: IRModule): + script = mod.script() + lines = [l.strip() for l in script.split("\n")] + + token_ids = [int(l.split("sync_token_id(")[1].split(")")[0]) for l in lines if "sync_token_id" in l] + barrier_ids = [int(l.split("(")[1].split(")")[0].split(",")[0]) for l in lines if "barrier_init" in l] + wait_ids = [int(l.split("(")[1].split(")")[0]) for l in lines if "wait_token" in l] + arrive_ids = [int(l.split("(")[1].split(")")[0]) for l in lines if "barrier_arrive_and_wait" in l] + token_num = max(token_ids) + 1 if token_ids else 0 + barrier_num = max(barrier_ids) + 1 if barrier_ids else 0 + + # Check count of wait_lines and arrive_lines + assert len(wait_ids) >= token_num, "wait_lines should be greater than token_lines" + assert len(arrive_ids) >= barrier_num, "arrive_lines should be greater than barrier_lines" + # Check range of wait_ids and arrive_ids + for i in wait_ids: + assert i < token_num, f"wait_token({i}) is out of range {token_num}" + for i in arrive_ids: + assert i < barrier_num, f"arrive_token({i}) is out of range {barrier_num}" + + # Check order of sync_token_id(id) (or sync_null_token(id)) and wait_token(id) + for i in range(token_num): + idx_token = script.find(f"sync_token_id({i})") + idx_null = script.find(f"sync_null_token({i})") + if idx_null != -1: + assert idx_null < idx_token, f"sync_null_token({i}) is after sync_token_id({i})" + idx_token = idx_null + idx_wait = script.find(f"wait_token({i})") + assert idx_token != -1, f"sync_token_id({i}) is not found in script" + assert idx_wait != -1, 
f"wait_token({i}) is not found in script" + assert idx_token < idx_wait, f"wait_token({i}) is before sync_token_id({i})" + # Check order of barrier_init(id) and barrier_arrive_and_wait(id) + for i in range(barrier_num): + idx_barrier = script.find(f"barrier_init({i}") + idx_arrive = script.find(f"barrier_arrive_and_wait({i})") + assert idx_barrier != -1, f"barrier_init({i}) is not found in script" + assert idx_arrive != -1, f"barrier_arrive_and_wait({i}) is not found in script" + assert idx_barrier < idx_arrive, f"barrier_init({i}) is after barrier_arrive_and_wait({i})" + + +def _count_alloc_by_scope(body, scope): + cnt = 0 + + def visitor(n): + nonlocal cnt + if isinstance(n, tir.Allocate): + ta = n.buffer_var.type_annotation + if getattr(ta, "storage_scope", "") == scope: + cnt += 1 + + tir.stmt_functor.post_order_visit(body, visitor) + return cnt + + +def _get_single_alloc_extent(body, scope): + extent = None + + def visitor(n): + nonlocal extent + if isinstance(n, tir.Allocate): + ta = n.buffer_var.type_annotation + if getattr(ta, "storage_scope", "") == scope and extent is None: + extent = n.extents[0] + + tir.stmt_functor.post_order_visit(body, visitor) + return extent + + +def build_verify_merge_allocate(kernel_name: str, cnt_a=0, cnt_w=0, cnt_r=0): + def verify_merge_allocate(mod: IRModule): + device_mod = mod[kernel_name] + # Expect exactly one Allocate per scope after merge + assert _count_alloc_by_scope(device_mod.body, "shared.asram") <= 1, "shared.asram not 1" + assert _count_alloc_by_scope(device_mod.body, "shared.wsram") <= 1, "shared.wsram not 1" + assert _count_alloc_by_scope(device_mod.body, "shared.rsram") <= 1, "shared.rsram not 1" + if cnt_a > 0: + real_a = _get_single_alloc_extent(device_mod.body, "shared.asram") + assert real_a == cnt_a, f"shared.asram extent error, expected {cnt_a}, got {real_a}" + if cnt_w > 0: + real_w = _get_single_alloc_extent(device_mod.body, "shared.wsram") + assert real_w == cnt_w, f"shared.wsram extent error, 
expected {cnt_w}, got {real_w}" + if cnt_r > 0: + real_r = _get_single_alloc_extent(device_mod.body, "shared.rsram") + assert real_r == cnt_r, f"shared.rsram extent error, expected {cnt_r}, got {real_r}" + + return verify_merge_allocate + + +def verify_tiles_ops(prim_func: tir.PrimFunc): + has_fill = False + has_reduce = False + has_tiles = False + + def visitor(node): + nonlocal has_fill, has_reduce, has_tiles + if isinstance(node, tir.Call): + if node.op.name == "tl.fill": + has_fill = True + elif node.op.name == "tl.reduce": + has_reduce = True + elif isinstance(node, tir.For) and "tile.loop_parallel" in node.annotations and "tile.tiled_buffer" in node.annotations: + has_tiles = True + + tir.stmt_functor.post_order_visit(prim_func.body, visitor) + + def check_lower_tile_op(mod: IRModule): + script = mod.script() + if has_fill: + # fill and clear + assert "tile.loop_parallel" in script, "Expected tile.loop_parallel in script for fill/clear" + assert "tile.loop_stage" in script, "Expected tile.loop_stage in script for fill/clear" + assert "tile.tiled_buffer" in script, "Expected tile.tiled_buffer in script for fill/clear" + + if has_reduce: + assert "reduce_tile_op" in script, "Expected reduce_tile_op block in script" + assert "shared_acc" in script, "Expected shared_acc buffer in script" + + def check_legalize_tiles_loop(mod: IRModule): + script = mod.script() + if has_tiles: + assert "tile.buffer_new_shape" in script, "Expected tile.buffer_new_shape in script for Tiles" + assert "tile.dim_map" in script, "Expected tile.dim_map in script for Tiles" + assert "tile.tile_size" in script, "Expected tile.tile_size in script for Tiles" + # At this stage, loop_stage should be 1 + assert '"tile.loop_stage": 1' in script, "Expected tile.loop_stage: 1 in script" + + def check_tiles_loop(mod: IRModule): + script = mod.script() + if has_tiles: + assert "tile.interior" in script, "Expected tile.interior in script for Tiles" + assert "tile.interior_axis" in script, "Expected 
def verify_host_mod_separation(mod: IRModule):
    """Check that *mod* is the host half of a host/device split.

    All checks are textual, against the TVMScript printout: the host module
    keeps the entry-function attribute and packed/extern calls, and must not
    contain device-side global-function attributes or thread launches.
    """
    printed = mod.script()

    # Host module characteristics based on string representation
    assert "tir.is_entry_func" in printed, "Host module should contain the entry function attribute"
    assert "tir.is_global_func" not in printed, "Host module should not contain device global function attributes"
    has_launch_call = ("T.call_packed" in printed) or ("T.call_extern" in printed)
    assert has_launch_call, (
        "Host module should contain packed or extern calls for kernel launch/errors"
    )
    assert "T.launch_thread" not in printed, "Host module should not contain device thread launches"


def verify_device_mod_separation(mod: IRModule):
    """Check that *mod* is the device half of a host/device split.

    Textual checks on the TVMScript printout: device code carries the
    global-function attribute and none of the host-only markers.
    """
    printed = mod.script()

    # Device module characteristics based on string representation
    assert "tir.is_global_func" in printed, "Device module should contain global function attributes"
    assert "tir.is_entry_func" not in printed, "Device module should not contain the host entry function attribute"
    assert "__tvm_error_" not in printed, "Device module should not contain host-side error checking"


def get_or_add_default_verify(func: tir.PrimFunc, test_config: dict = None):
    """Build the default per-pass verification table for *func*.

    When *test_config* is None the default table is returned directly;
    otherwise the default verifiers are appended into *test_config*
    (mutated in place) and the same dict is returned.
    """
    tile_checks = verify_tiles_ops(func)
    # NOTE(review): verify_comm_lower is defined elsewhere in this module —
    # not visible here; presumed to return a single verifier callable.
    comm_check = verify_comm_lower(func)

    defaults = {
        "LowerTileOp": {"formal_verify": [tile_checks["LowerTileOp"], comm_check]},
        "LegalizeTilesLoop": {"formal_verify": [tile_checks["LegalizeTilesLoop"]]},
        "TilesLoop": {"formal_verify": [tile_checks["TilesLoop"]]},
        "InjectSunmmioSync": {"formal_verify": [verify_SunmmioSync]},
        "VectorizeLoop": {"formal_verify": [tile_checks["VectorizeLoop"]]},
        "HostMod": {"formal_verify": [verify_host_mod_separation]},
        "DeviceMod": {"formal_verify": [verify_SunmmioSync, comm_check, verify_device_mod_separation]},
    }

    if test_config is None:
        return defaults

    # Merge the defaults into the caller-supplied config, creating the
    # per-stage entry and its "formal_verify" list on demand.
    for stage, extra in defaults.items():
        entry = test_config.setdefault(stage, {})
        entry.setdefault("formal_verify", []).extend(extra["formal_verify"])

    return test_config
T.alloc_shared((block_K, block_N), dtype=dtype) + B_remote_3 = T.alloc_shared((block_K, block_N), dtype=dtype) + C_shared = T.alloc_shared((block_M, block_N), dtype=accum_dtype) + C_allgather_1 = T.alloc_shared((16, block_M, block_N), dtype=accum_dtype) + C_allgather_2 = T.alloc_shared((4, block_M, block_N), dtype=accum_dtype) + C_allgather_3 = T.alloc_shared((4, block_M, block_N), dtype=accum_dtype) + + T.clear(A_shared) + T.clear(B_shared) + T.clear(C_shared) # Avoid Fill op unsupported scope error + T.comm.broadcast(A_shared, A_remote_1, (0, 0), direction="all") + T.comm.broadcast(A_shared, A_remote_2, (0, 0), direction="h") + T.comm.broadcast(A_shared, A_remote_3, (0, 0), direction="v") + T.comm.put(B_shared, B_remote_1, (1, 2), (2, 3)) + T.comm.put(B_shared, B_remote_2, (1, 2), (1, 3)) + T.comm.put(B_shared, B_remote_3, (1, 2), (3, 2)) + T.comm.all_gather(C_shared, C_allgather_1, direction="all") + T.comm.all_gather(C_shared, C_allgather_2, direction="h") + T.comm.all_gather(C_shared, C_allgather_3, direction="v") + + return main + + +def test_comm(): + func = kernel_comm(1024 * 16, 1024 * 16, 1024 * 16, 1024, 1024, 1024) + script_comm = [ + """ + T.broadcast_(T.region(A_shared[0, 0], 1, 1024, 1024), T.region(A_remote_1[0, 0], 2, 1024, 1024), 1048576, 0, 1) + T.broadcast_(T.region(A_remote_1[0, 0], 1, 1024, 1024), T.region(A_remote_1[0, 0], 2, 1024, 1024), 1048576, 0, 0) + T.broadcast_(T.region(A_remote_1[0, 0], 1, 1024, 1024), T.region(A_remote_1[0, 0], 2, 1024, 1024), 1048576, 4, 0) + T.broadcast_(T.region(A_remote_1[0, 0], 1, 1024, 1024), T.region(A_remote_1[0, 0], 2, 1024, 1024), 1048576, 8, 0) + T.broadcast_(T.region(A_remote_1[0, 0], 1, 1024, 1024), T.region(A_remote_1[0, 0], 2, 1024, 1024), 1048576, 12, 0) + """, + """ + T.broadcast_(T.region(A_shared[0, 0], 1, 1024, 1024), T.region(A_remote_2[0, 0], 2, 1024, 1024), 1048576, 0, 0) + """, + """ + T.broadcast_(T.region(A_shared[0, 0], 1, 1024, 1024), T.region(A_remote_3[0, 0], 2, 1024, 1024), 1048576, 0, 
1) + """, + """ + T.broadcast_(T.region(B_shared[0, 0], 1, 1024, 1024), T.region(B_remote_1[0, 0], 2, 1024, 1024), 1048576, 6, 1, 0, 1, 3) + T.broadcast_(T.region(B_remote_1[0, 0], 1, 1024, 1024), T.region(B_remote_1[0, 0], 2, 1024, 1024), 1048576, 10, 0, 0, 1, 2) + """, + """ + T.broadcast_(T.region(B_shared[0, 0], 1, 1024, 1024), T.region(B_remote_2[0, 0], 2, 1024, 1024), 1048576, 6, 0, 0, 1, 2) + """, + """ + T.broadcast_(T.region(B_shared[0, 0], 1, 1024, 1024), T.region(B_remote_3[0, 0], 2, 1024, 1024), 1048576, 6, 1, 0, 1, 2) + """, + """ + T.broadcast_(T.region(C_shared[0, 0], 1, 1024, 1024), T.region(C_allgather_1[0, 0, 0], 2, 16, 1024, 1024), 1048576, 0, 0) + T.broadcast_(T.region(C_shared[0, 0], 1, 1024, 1024), T.region(C_allgather_1[0, 0, 0], 2, 16, 1024, 1024), 1048576, 1, 0) + T.broadcast_(T.region(C_shared[0, 0], 1, 1024, 1024), T.region(C_allgather_1[0, 0, 0], 2, 16, 1024, 1024), 1048576, 2, 0) + T.broadcast_(T.region(C_shared[0, 0], 1, 1024, 1024), T.region(C_allgather_1[0, 0, 0], 2, 16, 1024, 1024), 1048576, 3, 0) + T.broadcast_(T.region(C_shared[0, 0], 1, 1024, 1024), T.region(C_allgather_1[0, 0, 0], 2, 16, 1024, 1024), 1048576, 4, 0) + T.broadcast_(T.region(C_shared[0, 0], 1, 1024, 1024), T.region(C_allgather_1[0, 0, 0], 2, 16, 1024, 1024), 1048576, 5, 0) + T.broadcast_(T.region(C_shared[0, 0], 1, 1024, 1024), T.region(C_allgather_1[0, 0, 0], 2, 16, 1024, 1024), 1048576, 6, 0) + T.broadcast_(T.region(C_shared[0, 0], 1, 1024, 1024), T.region(C_allgather_1[0, 0, 0], 2, 16, 1024, 1024), 1048576, 7, 0) + T.broadcast_(T.region(C_shared[0, 0], 1, 1024, 1024), T.region(C_allgather_1[0, 0, 0], 2, 16, 1024, 1024), 1048576, 8, 0) + T.broadcast_(T.region(C_shared[0, 0], 1, 1024, 1024), T.region(C_allgather_1[0, 0, 0], 2, 16, 1024, 1024), 1048576, 9, 0) + T.broadcast_(T.region(C_shared[0, 0], 1, 1024, 1024), T.region(C_allgather_1[0, 0, 0], 2, 16, 1024, 1024), 1048576, 10, 0) + T.broadcast_(T.region(C_shared[0, 0], 1, 1024, 1024), T.region(C_allgather_1[0, 0, 
0], 2, 16, 1024, 1024), 1048576, 11, 0) + T.broadcast_(T.region(C_shared[0, 0], 1, 1024, 1024), T.region(C_allgather_1[0, 0, 0], 2, 16, 1024, 1024), 1048576, 12, 0) + T.broadcast_(T.region(C_shared[0, 0], 1, 1024, 1024), T.region(C_allgather_1[0, 0, 0], 2, 16, 1024, 1024), 1048576, 13, 0) + T.broadcast_(T.region(C_shared[0, 0], 1, 1024, 1024), T.region(C_allgather_1[0, 0, 0], 2, 16, 1024, 1024), 1048576, 14, 0) + T.broadcast_(T.region(C_shared[0, 0], 1, 1024, 1024), T.region(C_allgather_1[0, 0, 0], 2, 16, 1024, 1024), 1048576, 15, 0) + T.broadcast_(T.region(C_allgather_1[0, 0, 0], 1, 4, 1024, 1024), T.region(C_allgather_1[0, 0, 0], 2, 4, 1024, 1024), 4194304, 0, 1) + T.broadcast_(T.region(C_allgather_1[4, 0, 0], 1, 4, 1024, 1024), T.region(C_allgather_1[4, 0, 0], 2, 4, 1024, 1024), 4194304, 4, 1) + T.broadcast_(T.region(C_allgather_1[8, 0, 0], 1, 4, 1024, 1024), T.region(C_allgather_1[8, 0, 0], 2, 4, 1024, 1024), 4194304, 8, 1) + T.broadcast_(T.region(C_allgather_1[12, 0, 0], 1, 4, 1024, 1024), T.region(C_allgather_1[12, 0, 0], 2, 4, 1024, 1024), 4194304, 12, 1) + T.broadcast_(T.region(C_allgather_1[0, 0, 0], 1, 4, 1024, 1024), T.region(C_allgather_1[0, 0, 0], 2, 4, 1024, 1024), 4194304, 1, 1) + T.broadcast_(T.region(C_allgather_1[4, 0, 0], 1, 4, 1024, 1024), T.region(C_allgather_1[4, 0, 0], 2, 4, 1024, 1024), 4194304, 5, 1) + T.broadcast_(T.region(C_allgather_1[8, 0, 0], 1, 4, 1024, 1024), T.region(C_allgather_1[8, 0, 0], 2, 4, 1024, 1024), 4194304, 9, 1) + T.broadcast_(T.region(C_allgather_1[12, 0, 0], 1, 4, 1024, 1024), T.region(C_allgather_1[12, 0, 0], 2, 4, 1024, 1024), 4194304, 13, 1) + T.broadcast_(T.region(C_allgather_1[0, 0, 0], 1, 4, 1024, 1024), T.region(C_allgather_1[0, 0, 0], 2, 4, 1024, 1024), 4194304, 2, 1) + T.broadcast_(T.region(C_allgather_1[4, 0, 0], 1, 4, 1024, 1024), T.region(C_allgather_1[4, 0, 0], 2, 4, 1024, 1024), 4194304, 6, 1) + T.broadcast_(T.region(C_allgather_1[8, 0, 0], 1, 4, 1024, 1024), T.region(C_allgather_1[8, 0, 0], 2, 4, 1024, 
1024), 4194304, 10, 1) + T.broadcast_(T.region(C_allgather_1[12, 0, 0], 1, 4, 1024, 1024), T.region(C_allgather_1[12, 0, 0], 2, 4, 1024, 1024), 4194304, 14, 1) + T.broadcast_(T.region(C_allgather_1[0, 0, 0], 1, 4, 1024, 1024), T.region(C_allgather_1[0, 0, 0], 2, 4, 1024, 1024), 4194304, 3, 1) + T.broadcast_(T.region(C_allgather_1[4, 0, 0], 1, 4, 1024, 1024), T.region(C_allgather_1[4, 0, 0], 2, 4, 1024, 1024), 4194304, 7, 1) + T.broadcast_(T.region(C_allgather_1[8, 0, 0], 1, 4, 1024, 1024), T.region(C_allgather_1[8, 0, 0], 2, 4, 1024, 1024), 4194304, 11, 1) + T.broadcast_(T.region(C_allgather_1[12, 0, 0], 1, 4, 1024, 1024), T.region(C_allgather_1[12, 0, 0], 2, 4, 1024, 1024), 4194304, 15, 1) + """, + """ + T.broadcast_(T.region(C_shared[0, 0], 1, 1024, 1024), T.region(C_allgather_2[0, 0, 0], 2, 4, 1024, 1024), 1048576, 0, 0) + T.broadcast_(T.region(C_shared[0, 0], 1, 1024, 1024), T.region(C_allgather_2[0, 0, 0], 2, 4, 1024, 1024), 1048576, 1, 0) + T.broadcast_(T.region(C_shared[0, 0], 1, 1024, 1024), T.region(C_allgather_2[0, 0, 0], 2, 4, 1024, 1024), 1048576, 2, 0) + T.broadcast_(T.region(C_shared[0, 0], 1, 1024, 1024), T.region(C_allgather_2[0, 0, 0], 2, 4, 1024, 1024), 1048576, 3, 0) + T.broadcast_(T.region(C_shared[0, 0], 1, 1024, 1024), T.region(C_allgather_2[0, 0, 0], 2, 4, 1024, 1024), 1048576, 4, 0) + T.broadcast_(T.region(C_shared[0, 0], 1, 1024, 1024), T.region(C_allgather_2[0, 0, 0], 2, 4, 1024, 1024), 1048576, 5, 0) + T.broadcast_(T.region(C_shared[0, 0], 1, 1024, 1024), T.region(C_allgather_2[0, 0, 0], 2, 4, 1024, 1024), 1048576, 6, 0) + T.broadcast_(T.region(C_shared[0, 0], 1, 1024, 1024), T.region(C_allgather_2[0, 0, 0], 2, 4, 1024, 1024), 1048576, 7, 0) + T.broadcast_(T.region(C_shared[0, 0], 1, 1024, 1024), T.region(C_allgather_2[0, 0, 0], 2, 4, 1024, 1024), 1048576, 8, 0) + T.broadcast_(T.region(C_shared[0, 0], 1, 1024, 1024), T.region(C_allgather_2[0, 0, 0], 2, 4, 1024, 1024), 1048576, 9, 0) + T.broadcast_(T.region(C_shared[0, 0], 1, 1024, 1024), 
T.region(C_allgather_2[0, 0, 0], 2, 4, 1024, 1024), 1048576, 10, 0) + T.broadcast_(T.region(C_shared[0, 0], 1, 1024, 1024), T.region(C_allgather_2[0, 0, 0], 2, 4, 1024, 1024), 1048576, 11, 0) + T.broadcast_(T.region(C_shared[0, 0], 1, 1024, 1024), T.region(C_allgather_2[0, 0, 0], 2, 4, 1024, 1024), 1048576, 12, 0) + T.broadcast_(T.region(C_shared[0, 0], 1, 1024, 1024), T.region(C_allgather_2[0, 0, 0], 2, 4, 1024, 1024), 1048576, 13, 0) + T.broadcast_(T.region(C_shared[0, 0], 1, 1024, 1024), T.region(C_allgather_2[0, 0, 0], 2, 4, 1024, 1024), 1048576, 14, 0) + T.broadcast_(T.region(C_shared[0, 0], 1, 1024, 1024), T.region(C_allgather_2[0, 0, 0], 2, 4, 1024, 1024), 1048576, 15, 0) + """, + """ + T.broadcast_(T.region(C_shared[0, 0], 1, 1024, 1024), T.region(C_allgather_3[0, 0, 0], 2, 4, 1024, 1024), 1048576, 0, 1) + T.broadcast_(T.region(C_shared[0, 0], 1, 1024, 1024), T.region(C_allgather_3[0, 0, 0], 2, 4, 1024, 1024), 1048576, 4, 1) + T.broadcast_(T.region(C_shared[0, 0], 1, 1024, 1024), T.region(C_allgather_3[0, 0, 0], 2, 4, 1024, 1024), 1048576, 8, 1) + T.broadcast_(T.region(C_shared[0, 0], 1, 1024, 1024), T.region(C_allgather_3[0, 0, 0], 2, 4, 1024, 1024), 1048576, 12, 1) + T.broadcast_(T.region(C_shared[0, 0], 1, 1024, 1024), T.region(C_allgather_3[0, 0, 0], 2, 4, 1024, 1024), 1048576, 1, 1) + T.broadcast_(T.region(C_shared[0, 0], 1, 1024, 1024), T.region(C_allgather_3[0, 0, 0], 2, 4, 1024, 1024), 1048576, 5, 1) + T.broadcast_(T.region(C_shared[0, 0], 1, 1024, 1024), T.region(C_allgather_3[0, 0, 0], 2, 4, 1024, 1024), 1048576, 9, 1) + T.broadcast_(T.region(C_shared[0, 0], 1, 1024, 1024), T.region(C_allgather_3[0, 0, 0], 2, 4, 1024, 1024), 1048576, 13, 1) + T.broadcast_(T.region(C_shared[0, 0], 1, 1024, 1024), T.region(C_allgather_3[0, 0, 0], 2, 4, 1024, 1024), 1048576, 2, 1) + T.broadcast_(T.region(C_shared[0, 0], 1, 1024, 1024), T.region(C_allgather_3[0, 0, 0], 2, 4, 1024, 1024), 1048576, 6, 1) + T.broadcast_(T.region(C_shared[0, 0], 1, 1024, 1024), 
def kernel_flashattn(
    batch,
    heads,
    seq_len,
    dim,
    is_causal,
    block_M=64,
    block_N=64,
    num_stages=1,
    threads=1,
):
    """Build a FlashAttention-style forward kernel as a TileLang PrimFunc.

    Args:
        batch, heads, seq_len, dim: logical shape [batch, seq_len, heads, dim]
            shared by Q, K, V and the output.
        is_causal: when True, mask positions where key index > query index.
        block_M: tile size along the query sequence dimension.
        block_N: tile size along the key/value sequence dimension.
        num_stages: software-pipelining depth of the K/V loop (T.Pipelined).
        threads: threads per block passed to T.Kernel.

    Returns:
        The @T.prim_func ``main(Q, K, V, Output)``.
    """
    # exp(x) is computed as exp2(x * log2(e)); log2(e) is folded into scale.
    scale = (1.0 / dim) ** 0.5 * 1.44269504  # log2(e)
    shape = [batch, seq_len, heads, dim]
    # Different precisions will cause different number of allocates. The default allocate is allocated according to uint8, so when the data type is float16, the number of allocates will be doubled.
    dtype = T.float16
    # accum_dtype = T.float32
    accum_dtype = T.float16

    @T.prim_func
    def main(
        Q: T.Tensor(shape, dtype),
        K: T.Tensor(shape, dtype),
        V: T.Tensor(shape, dtype),
        Output: T.Tensor(shape, dtype),
    ):
        with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (
            bx,
            by,
            bz,
        ):
            # Per-tile working buffers.
            Q_shared = T.alloc_shared([block_M, dim], dtype)
            K_shared = T.alloc_shared([block_N, dim], dtype)
            V_shared = T.alloc_shared([block_N, dim], dtype)
            O_shared = T.alloc_shared([block_M, dim], dtype)
            acc_s = T.alloc_shared([block_M, block_N], accum_dtype)
            acc_s_cast = T.alloc_shared([block_M, block_N], accum_dtype, scope="shared.asram")
            acc_o = T.alloc_shared([block_M, dim], accum_dtype, scope="shared.rsram")
            # Online-softmax per-row statistics.
            scores_max = T.alloc_shared([block_M], accum_dtype)
            scores_max_prev = T.alloc_shared([block_M], accum_dtype)
            scores_scale = T.alloc_shared([block_M], accum_dtype)
            scores_sum = T.alloc_shared([block_M], accum_dtype)
            logsum = T.alloc_shared([block_M], accum_dtype)

            T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared)
            T.fill(acc_o, 0)
            T.fill(logsum, 0)
            T.fill(scores_max, -T.infinity(accum_dtype))

            # Causal kernels only need K/V tiles up to the current query tile.
            loop_range = (
                T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N)
            )

            for k in T.Pipelined(loop_range, num_stages=num_stages):
                T.copy(K[bz, k * block_N : (k + 1) * block_N, by, :], K_shared)
                # Pre-seed acc_s with the mask (0 for kept, -inf for masked),
                # then accumulate Q @ K^T on top of it.
                if is_causal:
                    for i, j in T.Parallel(block_M, block_N):
                        acc_s[i, j] = T.if_then_else(
                            bx * block_M + i >= k * block_N + j,
                            0,
                            -T.infinity(acc_s.dtype),
                        )
                else:
                    for i, j in T.Parallel(block_M, block_N):
                        acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0)
                T.gemm(Q_shared, K_shared, acc_s, transpose_B=True)

                # Running row-max update over this tile's scores.
                for i in T.serial(0, block_M):
                    scores_max_prev[i] = scores_max[i]
                    scores_max[i] = -T.infinity(accum_dtype)
                    for j in T.serial(0, block_N):
                        scores_max[i] = T.max(scores_max[i], acc_s[i, j])
                    scores_max[i] = T.max(scores_max[i], scores_max_prev[i])

                # Rescaling factor for previously accumulated output/logsum.
                for i in T.Parallel(block_M):
                    scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)

                # Exponentiate scores relative to the new row max; sum rows.
                for i in T.serial(0, block_M):
                    scores_sum[i] = T.cast(0, accum_dtype)
                    for j in T.serial(0, block_N):
                        acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
                        scores_sum[i] = scores_sum[i] + acc_s[i, j]

                for i in T.Parallel(block_M):
                    logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
                T.copy(acc_s, acc_s_cast)

                # Rescale the running output before adding this tile's term.
                for i, j in T.Parallel(block_M, dim):
                    acc_o[i, j] *= scores_scale[i]

                T.copy(V[bz, k * block_N : (k + 1) * block_N, by, :], V_shared)
                T.gemm(acc_s_cast, V_shared, acc_o)

            # Final normalization by the softmax denominator, then write out.
            for i, j in T.Parallel(block_M, dim):
                acc_o[i, j] /= logsum[i]
            T.copy(acc_o, O_shared)
            T.copy(O_shared, Output[bz, bx * block_M : (bx + 1) * block_M, by, :])

    return main
128), "float16", data=K_shared.data, scope="shared.wsram") + V_shared = T.alloc_buffer((128, 128), "float16", data=V_shared.data, scope="shared.wsram") + O_shared = T.alloc_buffer((128, 128), "float16", data=O_shared.data, scope="shared.rsram") + acc_s = T.alloc_buffer((128, 128), "float16", data=acc_s.data, scope="shared.rsram") + acc_s_cast = T.alloc_buffer((128, 128), "float16", data=acc_s_cast.data, scope="shared.asram") + acc_o = T.alloc_buffer((128, 128), "float16", data=acc_o.data, scope="shared.rsram") + scores_max = T.alloc_buffer((128,), "float16", scope="shared.rsram") + scores_max_prev = T.alloc_buffer((128,), "float16", scope="shared.rsram") + scores_scale = T.alloc_buffer((128,), "float16", scope="shared.rsram") + scores_sum = T.alloc_buffer((128,), "float16", scope="shared.rsram") + logsum = T.alloc_buffer((128,), "float16", scope="shared.rsram") + T.dma_copy(T.region(Q[bz, bx * 128, by, 0], 1, 1, 128, 1, 128), T.region(Q_shared[0, 0], 2, 128, 128)) + for i0 in T.serial(128, annotations={"tile.domain": [128, 128], "tile.loop_parallel": 1, "tile.loop_stage": 0}): + for i1 in T.serial(128, annotations={"tile.loop_parallel": 1, "tile.loop_stage": 0}): + acc_o[i0, i1] = T.Cast("float16", 0) + for i0 in T.serial(128, annotations={"tile.domain": [128], "tile.loop_parallel": 1, "tile.loop_stage": 0}): + logsum[i0] = T.Cast("float16", 0) + for i0 in T.serial(128, annotations={"tile.domain": [128], "tile.loop_parallel": 1, "tile.loop_stage": 0}): + scores_max[i0] = T.infinity("float16") * T.float16(-1.0) + for k in T.serial(32, annotations={"num_stages": 1}): + T.dma_copy(T.region(K[bz, k * 128, by, 0], 1, 1, 128, 1, 128), T.region(K_shared[0, 0], 2, 128, 128)) + for i in T.unroll(512, annotations={"pragma_unroll_explicit": T.bool(False)}): + for vec in T.vectorized(32): + acc_s[(i * 32 + vec) // 128, (i * 32 + vec) % 128] = T.float16(0.0) + with T.block("_gemm_sss"): + T.reads() + T.writes() + T.mma_sunmmio(T.region(Q_shared[0, 0], 1, 128, 128), 
T.region(K_shared[0, 0], 1, 128, 128), T.region(acc_s[0, 0], 3, 128, 128), T.bool(False), T.bool(True), T.bool(False)) + for i in range(128): + scores_max_prev[i] = scores_max[i] + scores_max[i] = T.infinity("float16") * T.float16(-1.0) + for j in range(128): + scores_max[i] = T.max(scores_max[i], acc_s[i, j]) + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + for i in T.unroll(4, annotations={"pragma_unroll_explicit": T.bool(False)}): + for vec in T.vectorized(32): + scores_scale[i * 32 + vec] = T.Cast("float16", T.exp2(T.Cast("float32", scores_max_prev[i * 32 + vec]) * T.float32(0.1275174307460247) - T.Cast("float32", scores_max[i * 32 + vec]) * T.float32(0.1275174307460247))) + for i in range(128): + scores_sum[i] = T.float16(0.0) + for j in range(128): + acc_s[i, j] = T.Cast("float16", T.exp2(T.Cast("float32", acc_s[i, j]) * T.float32(0.1275174307460247) - T.Cast("float32", scores_max[i]) * T.float32(0.1275174307460247))) + scores_sum[i] = scores_sum[i] + acc_s[i, j] + for i in T.unroll(2, annotations={"pragma_unroll_explicit": T.bool(False)}): + for vec in T.vectorized(64): + logsum[i * 64 + vec] = logsum[i * 64 + vec] * scores_scale[i * 64 + vec] + scores_sum[i * 64 + vec] + T.dma_copy(T.region(acc_s[0, 0], 1, 128, 128), T.region(acc_s_cast[0, 0], 2, 128, 128)) + for i in T.unroll(512, annotations={"pragma_unroll_explicit": T.bool(False)}): + for vec in T.vectorized(32): + acc_o[(i * 32 + vec) // 128, (i * 32 + vec) % 128] = acc_o[(i * 32 + vec) // 128, (i * 32 + vec) % 128] * scores_scale[(i * 32 + vec) // 128] + T.dma_copy(T.region(V[bz, k * 128, by, 0], 1, 1, 128, 1, 128), T.region(V_shared[0, 0], 2, 128, 128)) + with T.block("_gemm_sss"): + T.reads() + T.writes() + T.mma_sunmmio(T.region(acc_s_cast[0, 0], 1, 128, 128), T.region(V_shared[0, 0], 1, 128, 128), T.region(acc_o[0, 0], 3, 128, 128), T.bool(False), T.bool(False), T.bool(False)) + for i in T.unroll(512, annotations={"pragma_unroll_explicit": T.bool(False)}): + for vec in 
T.vectorized(32): + acc_o[(i * 32 + vec) // 128, (i * 32 + vec) % 128] = acc_o[(i * 32 + vec) // 128, (i * 32 + vec) % 128] / logsum[(i * 32 + vec) // 128] + for i in T.serial(128, annotations={"tile.domain": [128, 128], "tile.loop_parallel": 1, "tile.loop_stage": 0}): + for j in T.serial(128, annotations={"tile.loop_parallel": 1, "tile.loop_stage": 0}): + O_shared[i, j] = acc_o[i, j] + T.dma_copy(T.region(O_shared[0, 0], 1, 128, 128), T.region(Output[bz, bx * 128, by, 0], 2, 1, 128, 1, 128)) + """ + + script_inject_sunmmio_sync = """ + with T.launch_thread("blockIdx.x", 32) as bx: + by = T.launch_thread("blockIdx.y", 32) + bz = T.launch_thread("blockIdx.z", 8) + tx = T.launch_thread("threadIdx.x", 1) + ty = T.launch_thread("threadIdx.y", 1) + tz = T.launch_thread("threadIdx.z", 1) + with T.decl_buffer((128, 128), "float16", scope="shared.asram") as Q_shared: + K_shared = T.decl_buffer((1, 128, 128), "float16", scope="shared.wsram") + V_shared = T.decl_buffer((1, 128, 128), "float16", scope="shared.wsram") + O_shared = T.decl_buffer((128, 128), "float16", scope="shared.rsram") + acc_s = T.decl_buffer((1, 128, 128), "float16", scope="shared.rsram") + acc_s_cast = T.decl_buffer((1, 128, 128), "float16", scope="shared.asram") + acc_o = T.decl_buffer((128, 128), "float16", scope="shared.rsram") + scores_max = T.decl_buffer((128,), "float16", scope="shared.rsram") + scores_max_prev = T.decl_buffer((128,), "float16", scope="shared.rsram") + scores_scale = T.decl_buffer((128,), "float16", scope="shared.rsram") + scores_sum = T.decl_buffer((128,), "float16", scope="shared.rsram") + logsum = T.decl_buffer((128,), "float16", scope="shared.rsram") + Q_2 = T.Buffer((8, 4096, 32, 128), "float16", data=Q, strides=(16777216, 4096, 128, 1)) + T.dma_copy(T.region(Q_2[bz, bx * 128, by, 0], 1, 1, 128, 1, 128), T.region(Q_shared[0, 0], 2, 128, 128), T.sync_token_id(0)) + for i0 in T.serial(16, annotations={"tile.domain": [128, 128], "tile.execution_axis": 0, 
"tile.execution_domain_axes": [0, 1], "tile.scope_entry": 1, "tile.tile_size": [8, 32]}): + for i1 in T.serial(4, annotations={"tile.execution_axis": 1}): + for ki in T.serial(8, annotations={"tile.interior": 1, "tile.interior_axis": 0}): + for kj in T.vectorized(32, annotations={"tile.interior": 1, "tile.interior_axis": 1}): + acc_o[i0 * 8 + ki, i1 * 32 + kj] = T.float16(0.0) + for i0 in T.serial(1, annotations={"tile.domain": [128], "tile.execution_axis": 0, "tile.execution_domain_axes": [0], "tile.scope_entry": 1, "tile.tile_size": [128]}): + for ki in T.serial(2, annotations={"tile.interior": 1, "tile.interior_axis": 0}): + for vec in T.vectorized(64): + logsum[ki * 64 + vec] = T.float16(0.0) + for i0 in T.serial(1, annotations={"tile.domain": [128], "tile.execution_axis": 0, "tile.execution_domain_axes": [0], "tile.scope_entry": 1, "tile.tile_size": [128]}): + for ki in T.serial(2, annotations={"tile.interior": 1, "tile.interior_axis": 0}): + for vec in T.vectorized(64): + scores_max[ki * 64 + vec] = T.infinity("float16") * T.float16(-1.0) + K_2 = T.Buffer((8, 4096, 32, 128), "float16", data=K, strides=(16777216, 4096, 128, 1)) + T.dma_copy(T.region(K_2[bz, 0, by, 0], 1, 1, 128, 1, 128), T.region(K_shared[0, 0, 0], 2, 1, 128, 128), T.sync_token_id(1)) + for i in T.unroll(512): + for vec in T.vectorized(32): + acc_s[0, i // 4, i % 4 * 32 + vec] = T.float16(0.0) + V_2 = T.Buffer((8, 4096, 32, 128), "float16", data=V, strides=(16777216, 4096, 128, 1)) + T.dma_copy(T.region(V_2[bz, 0, by, 0], 1, 1, 128, 1, 128), T.region(V_shared[0, 0, 0], 2, 1, 128, 128), T.sync_token_id(2)) + T.sync_null_token(4) + T.sync_null_token(5) + T.sync_null_token(6) + T.sync_null_token(7) + for k in range(31): + T.wait_token(0) + T.wait_token(1) + T.wait_token(4) + T.wait_token(5) + T.mma_sunmmio(T.region(Q_shared[0, 0], 1, 128, 128), T.region(K_shared[0, 0, 0], 1, 1, 128, 128), T.region(acc_s[0, 0, 0], 3, 1, 128, 128), T.bool(False), T.bool(True), T.bool(False), T.sync_token_id(3)) + 
for i in range(128): + scores_max_prev[i] = scores_max[i] + scores_max[i] = T.infinity("float16") * T.float16(-1.0) + for j in range(128): + T.wait_token(3) + scores_max[i] = T.max(scores_max[i], acc_s[0, i, j]) + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + T.dma_copy(T.region(K_2[bz, k * 128 + 128, by, 0], 1, 1, 128, 1, 128), T.region(K_shared[0, 0, 0], 2, 1, 128, 128), T.sync_token_id(4)) + for i in range(128): + scores_sum[i] = T.float16(0.0) + for j in range(128): + acc_s[0, i, j] = T.Cast("float16", T.exp2(T.Cast("float32", acc_s[0, i, j]) * T.float32(0.1275174307460247) - T.Cast("float32", scores_max[i]) * T.float32(0.1275174307460247))) + scores_sum[i] = scores_sum[i] + acc_s[0, i, j] + for i in T.unroll(4): + for vec in T.vectorized(32): + scores_scale[i * 32 + vec] = T.Cast("float16", T.exp2(T.Cast("float32", scores_max_prev[i * 32 + vec]) * T.float32(0.1275174307460247) - T.Cast("float32", scores_max[i * 32 + vec]) * T.float32(0.1275174307460247))) + T.wait_token(6) + T.dma_copy(T.region(acc_s[0, 0, 0], 1, 1, 128, 128), T.region(acc_s_cast[0, 0, 0], 2, 1, 128, 128), T.sync_token_id(5)) + for i in T.unroll(512): + for vec in T.vectorized(32): + acc_o[i // 4, i % 4 * 32 + vec] = acc_o[i // 4, i % 4 * 32 + vec] * scores_scale[i // 4] + T.wait_token(2) + T.wait_token(7) + T.mma_sunmmio(T.region(acc_s_cast[0, 0, 0], 1, 1, 128, 128), T.region(V_shared[0, 0, 0], 1, 1, 128, 128), T.region(acc_o[0, 0], 3, 128, 128), T.bool(False), T.bool(False), T.bool(False), T.sync_token_id(6)) + for i in T.unroll(2): + for vec in T.vectorized(64): + logsum[i * 64 + vec] = logsum[i * 64 + vec] * scores_scale[i * 64 + vec] + scores_sum[i * 64 + vec] + for i in T.unroll(512): + for vec in T.vectorized(32): + acc_s[0, i // 4, i % 4 * 32 + vec] = T.float16(0.0) + T.dma_copy(T.region(V_2[bz, k * 128 + 128, by, 0], 1, 1, 128, 1, 128), T.region(V_shared[0, 0, 0], 2, 1, 128, 128), T.sync_token_id(7)) + T.wait_token(4) + T.mma_sunmmio(T.region(Q_shared[0, 0], 1, 128, 128), 
T.region(K_shared[0, 0, 0], 1, 1, 128, 128), T.region(acc_s[0, 0, 0], 3, 1, 128, 128), T.bool(False), T.bool(True), T.bool(False), T.sync_token_id(8)) + for i in range(128): + scores_max_prev[i] = scores_max[i] + scores_max[i] = T.infinity("float16") * T.float16(-1.0) + for j in range(128): + T.wait_token(8) + scores_max[i] = T.max(scores_max[i], acc_s[0, i, j]) + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + for i in range(128): + scores_sum[i] = T.float16(0.0) + for j in range(128): + acc_s[0, i, j] = T.Cast("float16", T.exp2(T.Cast("float32", acc_s[0, i, j]) * T.float32(0.1275174307460247) - T.Cast("float32", scores_max[i]) * T.float32(0.1275174307460247))) + scores_sum[i] = scores_sum[i] + acc_s[0, i, j] + for i in T.unroll(4): + for vec in T.vectorized(32): + scores_scale[i * 32 + vec] = T.Cast("float16", T.exp2(T.Cast("float32", scores_max_prev[i * 32 + vec]) * T.float32(0.1275174307460247) - T.Cast("float32", scores_max[i * 32 + vec]) * T.float32(0.1275174307460247))) + T.dma_copy(T.region(acc_s[0, 0, 0], 1, 1, 128, 128), T.region(acc_s_cast[0, 0, 0], 2, 1, 128, 128), T.sync_token_id(9)) + for i in T.unroll(512): + for vec in T.vectorized(32): + acc_o[i // 4, i % 4 * 32 + vec] = acc_o[i // 4, i % 4 * 32 + vec] * scores_scale[i // 4] + T.wait_token(9) + T.wait_token(7) + T.mma_sunmmio(T.region(acc_s_cast[0, 0, 0], 1, 1, 128, 128), T.region(V_shared[0, 0, 0], 1, 1, 128, 128), T.region(acc_o[0, 0], 3, 128, 128), T.bool(False), T.bool(False), T.bool(False), T.sync_token_id(10)) + for i in T.unroll(2): + for vec in T.vectorized(64): + logsum[i * 64 + vec] = logsum[i * 64 + vec] * scores_scale[i * 64 + vec] + scores_sum[i * 64 + vec] + for i in T.unroll(512): + for vec in T.vectorized(32): + T.wait_token(10) + acc_o[i // 4, i % 4 * 32 + vec] = acc_o[i // 4, i % 4 * 32 + vec] / logsum[i // 4] + for i in T.serial(16, annotations={"tile.domain": [128, 128], "tile.execution_axis": 0, "tile.execution_domain_axes": [0, 1], "tile.scope_entry": 1, 
"tile.tile_size": [8, 32]}): + for j in T.serial(4, annotations={"tile.execution_axis": 1}): + for ki in T.serial(8, annotations={"tile.interior": 1, "tile.interior_axis": 0}): + for kj in T.vectorized(32, annotations={"tile.interior": 1, "tile.interior_axis": 1}): + O_shared[i * 8 + ki, j * 32 + kj] = acc_o[i * 8 + ki, j * 32 + kj] + Output_2 = T.Buffer((8, 4096, 32, 128), "float16", data=Output, strides=(16777216, 4096, 128, 1)) + T.dma_copy(T.region(O_shared[0, 0], 1, 128, 128), T.region(Output_2[bz, bx * 128, by, 0], 2, 1, 128, 1, 128), T.sync_token_id(11)) + T.wait_token(11) + """ + + script_device_mode = """ + def main_kernel(K: T.handle("float16", "global"), Output: T.handle("float16", "global"), Q: T.handle("float16", "global"), V: T.handle("float16", "global")) -> T.int32: + T.func_attr({"target": T.target({"keys": ["cpu"], "kind": "llvm", "mattr": ["device_mesh_nrow_4", "device_mesh_ncol_4"], "mcpu": "sunmmio-a4e", "tag": ""}), "thread_extent": {"blockIdx.x": 32, "blockIdx.y": 32, "blockIdx.z": 8, "threadIdx.x": 1, "threadIdx.y": 1, "threadIdx.z": 1}, "tir.is_global_func": T.bool(True), "tir.noalias": True, "tl.non_restrict_params": [], "tl.readonly_param_indices": [0, 1, 2, 3]}) + with T.launch_thread("blockIdx.x", 32) as bx: + buf_shmem = T.allocate([100352], "uint8", "shared.rsram") + buf_shmem_1 = T.allocate([65536], "uint8", "shared.wsram") + buf_shmem_2 = T.allocate([65536], "uint8", "shared.asram") + by = T.launch_thread("blockIdx.y", 32) + bz = T.launch_thread("blockIdx.z", 8) + tx = T.launch_thread("threadIdx.x", 1) + ty = T.launch_thread("threadIdx.y", 1) + tz = T.launch_thread("threadIdx.z", 1) + Q_1 = T.Buffer((134217728,), "float16", data=Q) + Q_shared = T.Buffer((16384,), "float16", data=buf_shmem_2, scope="shared.asram") + T.dma_copy(T.region(Q_1[bz * 16777216 + bx * 524288 + by * 128], 1, 520320), T.region(Q_shared[16384], 2, 16384), T.sync_token_id(0)) + acc_o = T.Buffer((16384,), "float16", data=buf_shmem, scope="shared.rsram") + for i 
in T.unroll(512): + acc_o[i * 32:i * 32 + 32] = T.Broadcast(T.float16(0.0), 32) + logsum = T.Buffer((128,), "float16", data=buf_shmem, scope="shared.rsram") + for i in T.unroll(2): + logsum[i * 64 + 32768:i * 64 + 32768 + 64] = T.Broadcast(T.float16(0.0), 64) + scores_max = T.Buffer((128,), "float16", data=buf_shmem, scope="shared.rsram") + for i in T.unroll(2): + scores_max[i * 64 + 32896:i * 64 + 32896 + 64] = T.Broadcast(T.infinity("float16") * T.float16(-1.0), 64) + K_1 = T.Buffer((134217728,), "float16", data=K) + K_shared = T.Buffer((16384,), "float16", data=buf_shmem_1, scope="shared.wsram") + T.dma_copy(T.region(K_1[bz * 16777216 + by * 128], 1, 520320), T.region(K_shared[16384], 2, 16384), T.sync_token_id(1)) + acc_s = T.Buffer((16384,), "float16", data=buf_shmem, scope="shared.rsram") + for i in T.unroll(512): + acc_s[i * 32 + 16384:i * 32 + 16384 + 32] = T.Broadcast(T.float16(0.0), 32) + V_1 = T.Buffer((134217728,), "float16", data=V) + V_shared = T.Buffer((16384,), "float16", data=buf_shmem_1, scope="shared.wsram") + T.dma_copy(T.region(V_1[bz * 16777216 + by * 128], 1, 520320), T.region(V_shared[0], 2, 16384), T.sync_token_id(2)) + T.sync_null_token(4) + T.sync_null_token(5) + T.sync_null_token(6) + T.sync_null_token(7) + scores_max_prev = T.Buffer((128,), "float16", data=buf_shmem, scope="shared.rsram") + scores_sum = T.Buffer((128,), "float16", data=buf_shmem, scope="shared.rsram") + scores_scale = T.Buffer((128,), "float16", data=buf_shmem, scope="shared.rsram") + acc_s_cast = T.Buffer((16384,), "float16", data=buf_shmem_2, scope="shared.asram") + for k in range(31): + T.wait_token(0) + T.wait_token(1) + T.wait_token(4) + T.wait_token(5) + T.mma_sunmmio(T.region(Q_shared[16384], 1, 16384), T.region(K_shared[16384], 1, 16384), T.region(acc_s[16384], 3, 16384), T.bool(False), T.bool(True), T.bool(False), T.sync_token_id(3)) + for i in range(128): + scores_max_prev[i + 33024] = scores_max[i + 32896] + scores_max[i + 32896] = T.infinity("float16") * 
T.float16(-1.0) + for j in range(128): + T.wait_token(3) + scores_max[i + 32896] = T.max(scores_max[i + 32896], acc_s[i * 128 + j + 16384]) + scores_max[i + 32896] = T.max(scores_max[i + 32896], scores_max_prev[i + 33024]) + T.dma_copy(T.region(K_1[bz * 16777216 + k * 524288 + by * 128 + 524288], 1, 520320), T.region(K_shared[16384], 2, 16384), T.sync_token_id(4)) + for i in range(128): + scores_sum[i + 33280] = T.float16(0.0) + for j in range(128): + acc_s[i * 128 + j + 16384] = T.Cast("float16", T.exp2(T.Cast("float32", acc_s[i * 128 + j + 16384]) * T.float32(0.1275174307460247) - T.Cast("float32", scores_max[i + 32896]) * T.float32(0.1275174307460247))) + scores_sum[i + 33280] = scores_sum[i + 33280] + acc_s[i * 128 + j + 16384] + for i in T.unroll(4): + scores_scale[i * 32 + 33152:i * 32 + 33152 + 32] = T.Cast("float16x32", T.exp2(T.Cast("float32x32", scores_max_prev[i * 32 + 33024:i * 32 + 33024 + 32]) * T.Broadcast(T.float32(0.1275174307460247), 32) - T.Cast("float32x32", scores_max[i * 32 + 32896:i * 32 + 32896 + 32]) * T.Broadcast(T.float32(0.1275174307460247), 32))) + T.wait_token(6) + T.dma_copy(T.region(acc_s[16384], 1, 16384), T.region(acc_s_cast[0], 2, 16384), T.sync_token_id(5)) + for i in T.unroll(512): + acc_o[i * 32:i * 32 + 32] = acc_o[i * 32:i * 32 + 32] * T.Broadcast(scores_scale[i // 4 + 33152], 32) + T.wait_token(2) + T.wait_token(7) + T.mma_sunmmio(T.region(acc_s_cast[0], 1, 16384), T.region(V_shared[0], 1, 16384), T.region(acc_o[0], 3, 16384), T.bool(False), T.bool(False), T.bool(False), T.sync_token_id(6)) + for i in T.unroll(2): + logsum[i * 64 + 32768:i * 64 + 32768 + 64] = logsum[i * 64 + 32768:i * 64 + 32768 + 64] * scores_scale[i * 64 + 33152:i * 64 + 33152 + 64] + scores_sum[i * 64 + 33280:i * 64 + 33280 + 64] + for i in T.unroll(512): + acc_s[i * 32 + 16384:i * 32 + 16384 + 32] = T.Broadcast(T.float16(0.0), 32) + T.dma_copy(T.region(V_1[bz * 16777216 + k * 524288 + by * 128 + 524288], 1, 520320), T.region(V_shared[0], 2, 16384), 
T.sync_token_id(7)) + T.wait_token(4) + T.mma_sunmmio(T.region(Q_shared[16384], 1, 16384), T.region(K_shared[16384], 1, 16384), T.region(acc_s[16384], 3, 16384), T.bool(False), T.bool(True), T.bool(False), T.sync_token_id(8)) + for i in range(128): + scores_max_prev[i + 33024] = scores_max[i + 32896] + scores_max[i + 32896] = T.infinity("float16") * T.float16(-1.0) + for j in range(128): + T.wait_token(8) + scores_max[i + 32896] = T.max(scores_max[i + 32896], acc_s[i * 128 + j + 16384]) + scores_max[i + 32896] = T.max(scores_max[i + 32896], scores_max_prev[i + 33024]) + for i in range(128): + scores_sum[i + 33280] = T.float16(0.0) + for j in range(128): + acc_s[i * 128 + j + 16384] = T.Cast("float16", T.exp2(T.Cast("float32", acc_s[i * 128 + j + 16384]) * T.float32(0.1275174307460247) - T.Cast("float32", scores_max[i + 32896]) * T.float32(0.1275174307460247))) + scores_sum[i + 33280] = scores_sum[i + 33280] + acc_s[i * 128 + j + 16384] + for i in T.unroll(4): + scores_scale[i * 32 + 33152:i * 32 + 33152 + 32] = T.Cast("float16x32", T.exp2(T.Cast("float32x32", scores_max_prev[i * 32 + 33024:i * 32 + 33024 + 32]) * T.Broadcast(T.float32(0.1275174307460247), 32) - T.Cast("float32x32", scores_max[i * 32 + 32896:i * 32 + 32896 + 32]) * T.Broadcast(T.float32(0.1275174307460247), 32))) + T.dma_copy(T.region(acc_s[16384], 1, 16384), T.region(acc_s_cast[0], 2, 16384), T.sync_token_id(9)) + for i in T.unroll(512): + acc_o[i * 32:i * 32 + 32] = acc_o[i * 32:i * 32 + 32] * T.Broadcast(scores_scale[i // 4 + 33152], 32) + T.wait_token(9) + T.wait_token(7) + T.mma_sunmmio(T.region(acc_s_cast[0], 1, 16384), T.region(V_shared[0], 1, 16384), T.region(acc_o[0], 3, 16384), T.bool(False), T.bool(False), T.bool(False), T.sync_token_id(10)) + for i in T.unroll(2): + logsum[i * 64 + 32768:i * 64 + 32768 + 64] = logsum[i * 64 + 32768:i * 64 + 32768 + 64] * scores_scale[i * 64 + 33152:i * 64 + 33152 + 64] + scores_sum[i * 64 + 33280:i * 64 + 33280 + 64] + for i in T.unroll(512): + 
T.wait_token(10) + acc_o[i * 32:i * 32 + 32] = acc_o[i * 32:i * 32 + 32] / T.Broadcast(logsum[i // 4 + 32768], 32) + O_shared = T.Buffer((16384,), "float16", data=buf_shmem, scope="shared.rsram") + for v0 in T.serial(4, annotations={"tile.buffer_new_shape": [4, 4, 32, 32], "tile.dim_map": [-2, -1], "tile.execution": 1, "tile.loop_parallel": 1, "tile.loop_stage": 2, "tile.scope_entry": 1, "tile.tile_size": [32, 32], "tile.tiled_buffer": acc_o_1}): + acc_o_1 = T.handle("float16", "shared.rsram") + for v1 in T.serial(4, annotations={"tile.buffer_new_shape": [4, 4, 32, 32], "tile.dim_map": [-2, -1], "tile.execution": 1, "tile.loop_parallel": 1, "tile.loop_stage": 2, "tile.tile_size": [32, 32], "tile.tiled_buffer": acc_o_1}): + for ki in T.serial(32, annotations={"tile.interior": 1, "tile.interior_axis": 0, "tile.loop_stage": 2, "tile.tiled_buffer": acc_o_1}): + O_shared[v0 * 4096 + ki * 128 + v1 * 32 + 33792:v0 * 4096 + ki * 128 + v1 * 32 + 33792 + 32] = acc_o[v0 * 4096 + ki * 128 + v1 * 32:v0 * 4096 + ki * 128 + v1 * 32 + 32] + Output_1 = T.Buffer((134217728,), "float16", data=Output) + T.dma_copy(T.region(O_shared[33792], 1, 16384), T.region(Output_1[bz * 16777216 + bx * 524288 + by * 128], 2, 520320), T.sync_token_id(11)) + T.wait_token(11) + return 0 + """ + + def get_verify_merge_allocate(): + kernel_name = "main_kernel" + # 65536 65536 100352 + block_m, block_n, dim = 128, 128, 128 + cnt_a = block_m * dim + block_m * block_n + cnt_w = block_n * dim + block_n * dim + # rsram mainly has three matrix blocks and 5 vector blocks, the matrix size is 128 (float16), the third matrix block is at the end, aligned to 2048, + # the 5 vector blocks combined are less than 2048, so we just take the alignment size + cnt_r = block_m * dim + block_m * block_n + block_m * dim + 1024 + cnt_a *= 2 + cnt_w *= 2 + cnt_r *= 2 + return build_verify_merge_allocate(kernel_name=kernel_name, cnt_a=cnt_a, cnt_w=cnt_w, cnt_r=cnt_r) + + test_config = { + "LowerTileOp": { + "script_expected": 
script_lower_tile_op, + }, + "InjectSunmmioSync": { + "script_expected": script_inject_sunmmio_sync, + }, + "MergeSharedMemoryAllocationsSunmmio": { + "formal_verify": get_verify_merge_allocate(), + }, + "DeviceMode": { + "script_expected": script_device_mode, + }, + } + + test_config = get_or_add_default_verify(func, test_config) + compile_test(func, target="Sunmmio", test_config=test_config) + + +if __name__ == "__main__": + test_flashattn() diff --git a/testing/python/compile_pipeline/test_mma_3times.py b/testing/python/compile_pipeline/test_mma_3times.py new file mode 100644 index 0000000000..d6047d131c --- /dev/null +++ b/testing/python/compile_pipeline/test_mma_3times.py @@ -0,0 +1,86 @@ +import tilelang.language as T +from compile_pipeline import compile_test +from formal_verify_funcs import * + + +def kernel_mma_3times_single_thread(M=16, N=16, K=16, block_M=128, block_N=128, block_K=32, dtype="float16"): + @T.prim_func + def mma_3times_kernel( + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), + ): + # Initialize single-thread Kernel context + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=1) as ( + bx, + by, + ): + # with T.Kernel(1, 1, threads=1) as (bx, by): + # [Key modification] Split multiple shared memory allocations to test merge_shared_memory_allocations + # Allocate multiple slice memories related to A (simulate A data storage in different stages) + A_shared1 = T.alloc_shared((block_M, block_K), dtype) + A_shared2 = T.alloc_shared((block_M, block_K), dtype) + A_shared3 = T.alloc_shared((block_M, block_K), dtype) + # Allocate multiple slice memories related to B (simulate B data storage in different stages) + B_shared1 = T.alloc_shared((block_K, block_N), dtype) + B_shared2 = T.alloc_shared((block_K, block_N), dtype) + B_shared3 = T.alloc_shared((block_K, block_N), dtype) + # Allocate multiple accumulation memories related to C (simulate intermediate results in different MMA stages) + 
C_shared1 = T.alloc_shared((block_M, block_N), dtype) + # C_shared2 = T.alloc_shared((block_M, block_N), dtype) + # C_shared3 = T.alloc_shared((block_M, block_N), dtype) + + # 1st MMA: copy data to stage1 memory -> compute -> save result to acc1 + + T.copy(A[block_M * 0, block_K * 0], A_shared1) + T.copy(B[block_K * 0, block_N * 0], B_shared1) + T.clear(C_shared1) + T.gemm(A_shared1, B_shared1, C_shared1) + + # 2nd MMA: copy data to stage2 memory -> accumulate based on acc1 -> save result to acc2 + T.copy(A[block_M * 1, block_K * 0], A_shared2) + T.copy(B[block_K * 1, block_N * 0], B_shared2) + T.gemm(A_shared2, B_shared2, C_shared1) + + # 3rd MMA: copy data to stage3 memory -> accumulate based on acc2 -> save result to final + T.copy(A[block_M * 2, block_K * 0], A_shared3) + T.copy(B[block_K * 2, block_N * 0], B_shared3) + T.gemm(A_shared3, B_shared3, C_shared1) + + # Write the final result back to global memory + T.copy(C_shared1, C[0, 0]) + + return mma_3times_kernel + + +def test_mma_3times(): + func = kernel_mma_3times_single_thread(1024, 1024, 1024) + + script_mere_allocate = """ + with T.launch_thread("blockIdx.x", 8) as bx: + buf_shmem = T.allocate([16384], "uint8", "shared.wsram") + buf_shmem_1 = T.allocate([16384], "uint8", "shared.asram") + C_shared1 = T.allocate([16384], "float16", "shared.rsram") + """ + + def get_verify_merge_allocate(): + kernel_name = "mma_3times_kernel_kernel" + # block 128*32; float16->uint * 2; a,c reuse, so only 2 buffer spaces are needed * 2 + cnt_a = 128 * 32 * 2 * 2 + cnt_w = 128 * 32 * 2 * 2 + # c_shared only has one, does not participate in merge, so it remains the original size, type unchanged + cnt_r = 128 * 128 + return build_verify_merge_allocate(kernel_name=kernel_name, cnt_a=cnt_a, cnt_w=cnt_w, cnt_r=cnt_r) + + test_config = { + "MergeSharedMemoryAllocationsSunmmio": { + "script_expected": script_mere_allocate, + "formal_verify": get_verify_merge_allocate(), + }, + } + test_config = get_or_add_default_verify(func, 
test_config) + compile_test(func, target="Sunmmio", test_config=test_config) + + +if __name__ == "__main__": + test_mma_3times() diff --git a/testing/python/compile_pipeline/test_overall.py b/testing/python/compile_pipeline/test_overall.py new file mode 100644 index 0000000000..769ec678a4 --- /dev/null +++ b/testing/python/compile_pipeline/test_overall.py @@ -0,0 +1,231 @@ +import tilelang.language as T +from compile_pipeline import compile_test +from formal_verify_funcs import * + + +def kernel_overall(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float32"): + mesh_device_config = (4, 4) + + @T.prim_func + def main( + A: T.MeshTensor((M, K), T.MeshShardingPolicy(x=1, y=0), mesh_device_config, dtype), + B: T.MeshTensor((K, N), T.MeshShardingPolicy(x=1, y=0), mesh_device_config, dtype), + Bias: T.MeshTensor((M, N), T.MeshShardingPolicy(x=1, y=0), mesh_device_config, accum_dtype), + C: T.MeshTensor((M, N), T.MeshShardingPolicy(x=1, y=0), mesh_device_config, accum_dtype), + ): + # Initialize Kernel Context + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as ( + bx, + by, + ): + # [wanghz18] Automatic SRAM Scope Inference + # We declare generic 'shared' scope, expecting InferSramScope pass to + # refine them to 'shared.asram', 'shared.wsram', 'shared.rsram' + A_shared = T.alloc_shared((block_M, block_K), dtype=dtype) + B_shared = T.alloc_shared((block_K, block_N), dtype=dtype) + C_shared = T.alloc_shared((block_M, block_N), accum_dtype) + Bias_shared = T.alloc_shared((block_M, block_N), accum_dtype) + + T.clear(C_shared) # Avoid Fill op unsupported scope error + + # [wanghz18] GEMM Lowering to mma_sunmmio intrinsic + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=2): + T.copy(A[by * block_M, k * block_K], A_shared) + T.copy(B[k * block_K, bx * block_N], B_shared) + T.gemm(A_shared, B_shared, C_shared) + + # Load Bias + T.copy(Bias[by * block_M, bx * block_N], Bias_shared) + + # [weizzh] Tiles Loop for Element-wise 
operation + # This loop should be legalized and vectorized by LegalizeTilesLoop/TilesLoop passes + for i, j in T.Tiles(C_shared, parallel=True): + C_shared[i, j] = C_shared[i, j] + Bias_shared[i, j] + + # [xiaoyao-NKU] Inter-core Communication (Broadcast) + C_remote = T.alloc_shared((block_M, block_N), accum_dtype) + T.comm.broadcast(C_shared, C_remote, (0, 0), direction="h") + + # Store result + T.copy(C_remote, C[by * block_M, bx * block_N]) + + return main + + +def test_overall(): + func = kernel_overall(128, 128, 128, 64, 64, 32) + script_lower_tile_op = """ + with T.block("tilelang_root"): + T.reads(A[by * 64, 0:97], B[0:97, bx * 64], Bias[by * 64, bx * 64], C[by * 64, bx * 64]) + T.writes() + T.block_attr({"global_layout_map": {A: metadata["tl.Layout"][0], B: metadata["tl.Layout"][1], Bias: metadata["tl.Layout"][2], C: metadata["tl.Layout"][3]}, "layout_map": {A_shared: metadata["tl.Layout"][4], B_shared: metadata["tl.Layout"][5], Bias_shared: metadata["tl.Layout"][6], C_shared: metadata["tl.Layout"][7], C_remote: metadata["tl.Layout"][8]}}) + A_shared = T.alloc_buffer((64, 32), "float16", data=A_shared.data, scope="shared.asram") + B_shared = T.alloc_buffer((32, 64), "float16", data=B_shared.data, scope="shared.wsram") + C_shared = T.alloc_buffer((64, 64), data=C_shared.data, scope="shared.rsram") + Bias_shared = T.alloc_buffer((64, 64), data=Bias_shared.data, scope="shared.rsram") + C_remote = T.alloc_buffer((64, 64), data=C_remote.data, scope="shared.rsram") + for i0 in T.serial(64, annotations={"tile.domain": [64, 64], "tile.loop_parallel": 1, "tile.loop_stage": 0}): + for i1 in T.serial(64, annotations={"tile.loop_parallel": 1, "tile.loop_stage": 0}): + C_shared[i0, i1] = T.Cast("float32", 0) + for k in T.serial(4, annotations={"num_stages": 2}): + T.dma_copy(T.region(A[by * 64, k * 32], 1, 64, 32), T.region(A_shared[0, 0], 2, 64, 32)) + T.dma_copy(T.region(B[k * 32, bx * 64], 1, 32, 64), T.region(B_shared[0, 0], 2, 32, 64)) + with T.block("_gemm_sss"): 
+ T.reads() + T.writes() + T.mma_sunmmio(T.region(A_shared[0, 0], 1, 64, 32), T.region(B_shared[0, 0], 1, 32, 64), T.region(C_shared[0, 0], 3, 64, 64), T.bool(False), T.bool(False), T.bool(False)) + T.dma_copy(T.region(Bias[by * 64, bx * 64], 1, 64, 64), T.region(Bias_shared[0, 0], 2, 64, 64)) + for i in T.serial(64, annotations={"tile.domain": [64, 64], "tile.loop_parallel": 1, "tile.loop_stage": 0}): + for j in T.serial(64, annotations={"tile.loop_parallel": 1, "tile.loop_stage": 0}): + C_shared[i, j] = C_shared[i, j] + Bias_shared[i, j] + T.broadcast_(T.region(C_shared[0, 0], 1, 64, 64), T.region(C_remote[0, 0], 2, 64, 64), 4096, 0, 0) + T.dma_copy(T.region(C_remote[0, 0], 1, 64, 64), T.region(C[by * 64, bx * 64], 2, 64, 64)) + """ + script_InjectSunmmioSync = """ + with T.launch_thread("blockIdx.x", 2) as bx: + by = T.launch_thread("blockIdx.y", 2) + tx = T.launch_thread("threadIdx.x", 128) + ty = T.launch_thread("threadIdx.y", 1) + tz = T.launch_thread("threadIdx.z", 1) + with T.decl_buffer((2, 64, 32), "float16", scope="shared.asram") as A_shared: + B_shared = T.decl_buffer((2, 32, 64), "float16", scope="shared.wsram") + C_shared = T.decl_buffer((64, 64), scope="shared.rsram") + Bias_shared = T.decl_buffer((64, 64), scope="shared.rsram") + C_remote = T.decl_buffer((64, 64), scope="shared.rsram") + for i0 in T.serial(16, annotations={"tile.domain": [64, 64], "tile.execution_axis": 0, "tile.execution_domain_axes": [0, 1], "tile.scope_entry": 1, "tile.tile_size": [4, 32]}): + for i1 in T.serial(2, annotations={"tile.execution_axis": 1}): + for ki in T.serial(4, annotations={"tile.interior": 1, "tile.interior_axis": 0}): + for kj in T.vectorized(32, annotations={"tile.interior": 1, "tile.interior_axis": 1}): + C_shared[i0 * 4 + ki, i1 * 32 + kj] = T.float32(0.0) + A_2 = T.Buffer((32, 32), "float16", data=A, strides=(32, 1)) + T.dma_copy(T.region(A_2[by * 64, 0], 1, 64, 32), T.region(A_shared[0, 0, 0], 2, 1, 64, 32), T.sync_token_id(0)) + B_2 = T.Buffer((32, 32), 
"float16", data=B, strides=(32, 1)) + T.dma_copy(T.region(B_2[0, bx * 64], 1, 32, 64), T.region(B_shared[0, 0, 0], 2, 1, 32, 64), T.sync_token_id(1)) + T.wait_token(0) + T.wait_token(1) + T.mma_sunmmio(T.region(A_shared[0, 0, 0], 1, 1, 64, 32), T.region(B_shared[0, 0, 0], 1, 1, 32, 64), T.region(C_shared[0, 0], 3, 64, 64), T.bool(False), T.bool(False), T.bool(False), T.sync_token_id(2)) + T.dma_copy(T.region(A_2[by * 64, 32], 1, 64, 32), T.region(A_shared[1, 0, 0], 2, 1, 64, 32), T.sync_token_id(3)) + T.dma_copy(T.region(B_2[32, bx * 64], 1, 32, 64), T.region(B_shared[1, 0, 0], 2, 1, 32, 64), T.sync_token_id(4)) + T.wait_token(3) + T.wait_token(4) + T.wait_token(2) + T.mma_sunmmio(T.region(A_shared[1, 0, 0], 1, 1, 64, 32), T.region(B_shared[1, 0, 0], 1, 1, 32, 64), T.region(C_shared[0, 0], 3, 64, 64), T.bool(False), T.bool(False), T.bool(False), T.sync_token_id(5)) + T.dma_copy(T.region(A_2[by * 64, 64], 1, 64, 32), T.region(A_shared[0, 0, 0], 2, 1, 64, 32), T.sync_token_id(6)) + T.dma_copy(T.region(B_2[64, bx * 64], 1, 32, 64), T.region(B_shared[0, 0, 0], 2, 1, 32, 64), T.sync_token_id(7)) + T.wait_token(6) + T.wait_token(7) + T.wait_token(5) + T.mma_sunmmio(T.region(A_shared[0, 0, 0], 1, 1, 64, 32), T.region(B_shared[0, 0, 0], 1, 1, 32, 64), T.region(C_shared[0, 0], 3, 64, 64), T.bool(False), T.bool(False), T.bool(False), T.sync_token_id(8)) + T.dma_copy(T.region(A_2[by * 64, 96], 1, 64, 32), T.region(A_shared[1, 0, 0], 2, 1, 64, 32), T.sync_token_id(9)) + T.dma_copy(T.region(B_2[96, bx * 64], 1, 32, 64), T.region(B_shared[1, 0, 0], 2, 1, 32, 64), T.sync_token_id(10)) + T.wait_token(9) + T.wait_token(10) + T.wait_token(8) + T.mma_sunmmio(T.region(A_shared[1, 0, 0], 1, 1, 64, 32), T.region(B_shared[1, 0, 0], 1, 1, 32, 64), T.region(C_shared[0, 0], 3, 64, 64), T.bool(False), T.bool(False), T.bool(False), T.sync_token_id(11)) + Bias_2 = T.Buffer((32, 32), data=Bias, strides=(32, 1)) + T.dma_copy(T.region(Bias_2[by * 64, bx * 64], 1, 64, 64), T.region(Bias_shared[0, 
0], 2, 64, 64), T.sync_token_id(12)) + for i in T.serial(16, annotations={"tile.domain": [64, 64], "tile.execution_axis": 0, "tile.execution_domain_axes": [0, 1], "tile.scope_entry": 1, "tile.tile_size": [4, 32]}): + for j in T.serial(2, annotations={"tile.execution_axis": 1}): + for ki in T.serial(4, annotations={"tile.interior": 1, "tile.interior_axis": 0}): + for kj in T.vectorized(32, annotations={"tile.interior": 1, "tile.interior_axis": 1}): + T.wait_token(11) + T.wait_token(12) + C_shared[i * 4 + ki, j * 32 + kj] = C_shared[i * 4 + ki, j * 32 + kj] + Bias_shared[i * 4 + ki, j * 32 + kj] + T.broadcast_(T.region(C_shared[0, 0], 1, 64, 64), T.region(C_remote[0, 0], 2, 64, 64), 4096, 0, 0, T.sync_token_id(13)) + T.barrier_init(0, 0, 1, 2, 3) + T.wait_token(13) + T.barrier_arrive_and_wait(0) + C_2 = T.Buffer((32, 32), data=C, strides=(32, 1)) + T.dma_copy(T.region(C_remote[0, 0], 1, 64, 64), T.region(C_2[by * 64, bx * 64], 2, 64, 64), T.sync_token_id(14)) + T.wait_token(14) + """ + script_device_mode = """ + with T.launch_thread("blockIdx.x", 2) as bx: + buf_shmem = T.allocate([32768], "uint8", "shared.rsram") + A_shared = T.allocate([4096], "float16", "shared.asram") + B_shared = T.allocate([4096], "float16", "shared.wsram") + by = T.launch_thread("blockIdx.y", 2) + tx = T.launch_thread("threadIdx.x", 128) + ty = T.launch_thread("threadIdx.y", 1) + tz = T.launch_thread("threadIdx.z", 1) + C_shared = T.Buffer((4096,), data=buf_shmem, scope="shared.rsram") + for i0 in T.serial(16, annotations={"tile.domain": [64, 64], "tile.execution_axis": 0, "tile.execution_domain_axes": [0, 1], "tile.scope_entry": 1, "tile.tile_size": [4, 32]}): + for i1 in T.serial(2, annotations={"tile.execution_axis": 1}): + for ki in T.serial(4, annotations={"tile.interior": 1, "tile.interior_axis": 0}): + C_shared[i0 * 256 + ki * 64 + i1 * 32:i0 * 256 + ki * 64 + i1 * 32 + 32] = T.Broadcast(T.float32(0.0), 32) + A_1 = T.Buffer((1024,), "float16", data=A) + A_shared_1 = T.Buffer((4096,), 
"float16", data=A_shared, scope="shared.asram") + T.dma_copy(T.region(A_1[by * 2048], 1, 2048), T.region(A_shared_1[0], 2, 2048), T.sync_token_id(0)) + B_1 = T.Buffer((1024,), "float16", data=B) + B_shared_1 = T.Buffer((4096,), "float16", data=B_shared, scope="shared.wsram") + T.dma_copy(T.region(B_1[bx * 64], 1, 1056), T.region(B_shared_1[0], 2, 2048), T.sync_token_id(1)) + T.wait_token(0) + T.wait_token(1) + T.mma_sunmmio(T.region(A_shared_1[0], 1, 2048), T.region(B_shared_1[0], 1, 2048), T.region(C_shared[0], 3, 4096), T.bool(False), T.bool(False), T.bool(False), T.sync_token_id(2)) + T.dma_copy(T.region(A_1[by * 2048 + 32], 1, 2048), T.region(A_shared_1[2048], 2, 2048), T.sync_token_id(3)) + T.dma_copy(T.region(B_1[bx * 64 + 1024], 1, 1056), T.region(B_shared_1[2048], 2, 2048), T.sync_token_id(4)) + T.wait_token(3) + T.wait_token(4) + T.wait_token(2) + T.mma_sunmmio(T.region(A_shared_1[2048], 1, 2048), T.region(B_shared_1[2048], 1, 2048), T.region(C_shared[0], 3, 4096), T.bool(False), T.bool(False), T.bool(False), T.sync_token_id(5)) + T.dma_copy(T.region(A_1[by * 2048 + 64], 1, 2048), T.region(A_shared_1[0], 2, 2048), T.sync_token_id(6)) + T.dma_copy(T.region(B_1[bx * 64 + 2048], 1, 1056), T.region(B_shared_1[0], 2, 2048), T.sync_token_id(7)) + T.wait_token(6) + T.wait_token(7) + T.wait_token(5) + T.mma_sunmmio(T.region(A_shared_1[0], 1, 2048), T.region(B_shared_1[0], 1, 2048), T.region(C_shared[0], 3, 4096), T.bool(False), T.bool(False), T.bool(False), T.sync_token_id(8)) + T.dma_copy(T.region(A_1[by * 2048 + 96], 1, 2048), T.region(A_shared_1[2048], 2, 2048), T.sync_token_id(9)) + T.dma_copy(T.region(B_1[bx * 64 + 3072], 1, 1056), T.region(B_shared_1[2048], 2, 2048), T.sync_token_id(10)) + T.wait_token(9) + T.wait_token(10) + T.wait_token(8) + T.mma_sunmmio(T.region(A_shared_1[2048], 1, 2048), T.region(B_shared_1[2048], 1, 2048), T.region(C_shared[0], 3, 4096), T.bool(False), T.bool(False), T.bool(False), T.sync_token_id(11)) + Bias_1 = T.Buffer((1024,), 
data=Bias) + Bias_shared = T.Buffer((4096,), data=buf_shmem, scope="shared.rsram") + T.dma_copy(T.region(Bias_1[by * 2048 + bx * 64], 1, 2080), T.region(Bias_shared[4096], 2, 4096), T.sync_token_id(12)) + for i in T.serial(16, annotations={"tile.domain": [64, 64], "tile.execution_axis": 0, "tile.execution_domain_axes": [0, 1], "tile.scope_entry": 1, "tile.tile_size": [4, 32]}): + for j in T.serial(2, annotations={"tile.execution_axis": 1}): + for ki in T.serial(4, annotations={"tile.interior": 1, "tile.interior_axis": 0}): + T.wait_token(11) + T.wait_token(12) + C_shared[i * 256 + ki * 64 + j * 32:i * 256 + ki * 64 + j * 32 + 32] = C_shared[i * 256 + ki * 64 + j * 32:i * 256 + ki * 64 + j * 32 + 32] + Bias_shared[i * 256 + ki * 64 + j * 32 + 4096:i * 256 + ki * 64 + j * 32 + 4096 + 32] + C_remote = T.Buffer((4096,), data=buf_shmem, scope="shared.rsram") + T.broadcast_(T.region(C_shared[0], 1, 4096), T.region(C_remote[4096], 2, 4096), 4096, 0, 0, T.sync_token_id(13)) + T.barrier_init(0, 0, 1, 2, 3) + T.wait_token(13) + T.barrier_arrive_and_wait(0) + C_1 = T.Buffer((1024,), data=C) + T.dma_copy(T.region(C_remote[4096], 1, 4096), T.region(C_1[by * 2048 + bx * 64], 2, 2080), T.sync_token_id(14)) + T.wait_token(14) + """ + + def get_verify_merge_allocate(): + kernel_name = "main_kernel" + # 65536 65536 100352 + block_m, block_n, block_k = 64, 64, 32 + cnt_a = block_m * block_k * 2 + cnt_w = block_k * block_n * 2 + # c_shared, bias_shared and c_remote are all on rsram, dtype = *4, bias and c_remote reuse, so only 2 buffer spaces are needed * 2 + cnt_r = block_m * block_n * 2 * 4 + return build_verify_merge_allocate(kernel_name=kernel_name, cnt_a=cnt_a, cnt_w=cnt_w, cnt_r=cnt_r) + + test_config = { + "LowerTileOp": { + "script_expected": script_lower_tile_op, + }, + "InjectSunmmioSync": { + "script_expected": script_InjectSunmmioSync, + }, + "MergeSharedMemoryAllocationsSunmmio": { + "formal_verify": get_verify_merge_allocate(), + }, + "DeviceMode": { + "script_expected": 
script_device_mode, + }, + } + test_config = get_or_add_default_verify(func, test_config) + compile_test(func, out_idx=[2], target="Sunmmio", test_config=test_config) + + +if __name__ == "__main__": + test_overall() diff --git a/testing/python/compile_pipeline/test_summa.py b/testing/python/compile_pipeline/test_summa.py new file mode 100644 index 0000000000..7413e36cf8 --- /dev/null +++ b/testing/python/compile_pipeline/test_summa.py @@ -0,0 +1,191 @@ +import tilelang.language as T +from compile_pipeline import compile_test +from formal_verify_funcs import * + + +def summa_matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float32"): + """ + SUMMA (Scalable Universal Matrix Multiplication Algorithm) + for a 4x4 mesh. + + Grid size: (N/block_N, M/block_M) = (4, 4) + """ + + @T.prim_func + def kernel( + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), accum_dtype), + ): + # Assume the current is a 4x4 processor grid (Mesh) + # Each core is responsible for outputting a 32x32 block of matrix C + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as ( + bx, + by, + ): + # Allocate local SRAM cache + # A_shared is placed in ASRAM (usually used for A matrix cache) + # B_shared is placed in WSRAM (usually used for B matrix cache) + A_shared = T.alloc_shared((block_M, block_K), dtype, scope="shared.asram") + B_shared = T.alloc_shared((block_K, block_N), dtype, scope="shared.wsram") + + # Local accumulator, placed in RSRAM + C_local = T.alloc_shared((block_M, block_N), accum_dtype, scope="shared.rsram") + T.clear(C_local) + + # Number of iterations in K dimension (for 128/32 = 4 steps) + K_steps = T.ceildiv(K, block_K) + + # Core loop of SUMMA algorithm + for k_tile in range(K_steps): + # --- Step 1: Broadcast row block of matrix A --- + # Broadcast directly from DRAM to asram of each core + # Source core coordinate is (by, k_tile), which is responsible for reading from DRAM and broadcasting to all 
cores in the same row + T.comm.broadcast( + A[ + by * block_M : by * block_M + block_M, + k_tile * block_K : k_tile * block_K + block_K, + ], + A_shared, + (by, k_tile), + direction="h", + ) + + # --- Step 2: Broadcast column block of matrix B --- + # Broadcast directly from DRAM to wsram of each core + # Source core coordinate is (k_tile, bx), which is responsible for reading from DRAM and broadcasting to all cores in the same column + T.comm.broadcast( + B[ + k_tile * block_K : k_tile * block_K + block_K, + bx * block_N : bx * block_N + block_N, + ], + B_shared, + (k_tile, bx), + direction="v", + ) + + # --- Step 3: Local computation --- + # Each core performs local GEMM using broadcasted A_shared and B_shared + T.gemm(A_shared, B_shared, C_local) + + # After the loop ends, write local computation result back to DRAM + T.copy(C_local, C[by * block_M, bx * block_N]) + + return kernel + + +def test_summa(): + func = summa_matmul(128, 128, 128, 32, 32, 32) + + script_lower_tile_op = """ + with T.block("tilelang_root"): + T.reads(A[by * 32:by * 32 + 32, 0:128], B[0:128, bx * 32:bx * 32 + 32], C[by * 32, bx * 32]) + T.writes() + T.block_attr({"layout_map": {C_local: metadata["tl.Layout"][0], A_shared: metadata["tl.Layout"][1], B_shared: metadata["tl.Layout"][2]}}) + A_shared = T.alloc_buffer((32, 32), "float16", data=A_shared.data, scope="shared.asram") + B_shared = T.alloc_buffer((32, 32), "float16", data=B_shared.data, scope="shared.wsram") + C_local = T.alloc_buffer((32, 32), data=C_local.data, scope="shared.rsram") + for i0 in T.serial(32, annotations={"tile.domain": [32, 32], "tile.loop_parallel": 1, "tile.loop_stage": 0}): + for i1 in T.serial(32, annotations={"tile.loop_parallel": 1, "tile.loop_stage": 0}): + C_local[i0, i1] = T.Cast("float32", 0) + for k_tile in range(4): + T.broadcast_(T.region(A[by * 32, k_tile * 32], 1, 32, 32), T.region(A_shared[0, 0], 2, 32, 32), 1024, by * 4 + k_tile, 0) + T.broadcast_(T.region(B[k_tile * 32, bx * 32], 1, 32, 32), 
T.region(B_shared[0, 0], 2, 32, 32), 1024, k_tile * 4 + bx, 1) + with T.block("_gemm_sss"): + T.reads() + T.writes() + T.mma_sunmmio(T.region(A_shared[0, 0], 1, 32, 32), T.region(B_shared[0, 0], 1, 32, 32), T.region(C_local[0, 0], 3, 32, 32), T.bool(False), T.bool(False), T.bool(False)) + T.dma_copy(T.region(C_local[0, 0], 1, 32, 32), T.region(C[by * 32, bx * 32], 2, 32, 32)) + """ + + script_InjectSunmmioSync = """ + with T.decl_buffer((32, 32), scope="shared.rsram") as C_local: + for i0 in T.serial(8, annotations={"tile.domain": [32, 32], "tile.execution_axis": 0, "tile.execution_domain_axes": [0, 1], "tile.scope_entry": 1, "tile.tile_size": [4, 32]}): + for i1 in T.serial(1, annotations={"tile.execution_axis": 1}): + for ki in T.serial(4, annotations={"tile.interior": 1, "tile.interior_axis": 0}): + for kj in T.vectorized(32, annotations={"tile.interior": 1, "tile.interior_axis": 1}): + C_local[i0 * 4 + ki, kj] = T.float32(0.0) + T.sync_null_token(2) + for k_tile in range(4): + A_shared = T.decl_buffer((32, 32), "float16", scope="shared.asram") + B_shared = T.decl_buffer((32, 32), "float16", scope="shared.wsram") + T.wait_token(2) + A_2 = T.Buffer((128, 128), "float16", data=A, strides=(128, 1)) + T.broadcast_(T.region(A_2[by * 32, k_tile * 32], 1, 32, 32), T.region(A_shared[0, 0], 2, 32, 32), 1024, by * 4 + k_tile, 0, T.sync_token_id(0)) + T.barrier_init(0, by * 4 + k_tile, k_tile // 4 * 4 + by * 4, k_tile // 4 * 4 + by * 4 + 1, k_tile // 4 * 4 + by * 4 + 2, k_tile // 4 * 4 + by * 4 + 3) + B_2 = T.Buffer((128, 128), "float16", data=B, strides=(128, 1)) + T.broadcast_(T.region(B_2[k_tile * 32, bx * 32], 1, 32, 32), T.region(B_shared[0, 0], 2, 32, 32), 1024, k_tile * 4 + bx, 1, T.sync_token_id(1)) + T.barrier_init(1, k_tile * 4 + bx, bx % 4, bx % 4 + 4, bx % 4 + 8, bx % 4 + 12) + T.wait_token(0) + T.barrier_arrive_and_wait(0) + T.wait_token(1) + T.barrier_arrive_and_wait(1) + T.mma_sunmmio(T.region(A_shared[0, 0], 1, 32, 32), T.region(B_shared[0, 0], 1, 32, 32), 
T.region(C_local[0, 0], 3, 32, 32), T.bool(False), T.bool(False), T.bool(False), T.sync_token_id(2)) + T.wait_token(2) + C_2 = T.Buffer((128, 128), data=C, strides=(128, 1)) + T.dma_copy(T.region(C_local[0, 0], 1, 32, 32), T.region(C_2[by * 32, bx * 32], 2, 32, 32), T.sync_token_id(3)) + T.wait_token(3) + """ + + script_device_mode = """ + @T.prim_func + def kernel_kernel(A: T.handle("float16", "global"), B: T.handle("float16", "global"), C: T.handle("float32", "global")) -> T.int32: + T.func_attr({"target": T.target({"keys": ["cpu"], "kind": "llvm", "mattr": ["device_mesh_nrow_4", "device_mesh_ncol_4"], "mcpu": "sunmmio-a4e", "tag": ""}), "thread_extent": {"blockIdx.x": 4, "blockIdx.y": 4, "threadIdx.x": 128, "threadIdx.y": 1, "threadIdx.z": 1}, "tir.is_global_func": T.bool(True), "tir.noalias": True, "tl.non_restrict_params": [], "tl.readonly_param_indices": [0, 1, 2]}) + with T.launch_thread("blockIdx.x", 4) as bx: + C_local = T.allocate([1024], "float32", "shared.rsram") + A_shared = T.allocate([1024], "float16", "shared.asram") + B_shared = T.allocate([1024], "float16", "shared.wsram") + by = T.launch_thread("blockIdx.y", 4) + tx = T.launch_thread("threadIdx.x", 128) + ty = T.launch_thread("threadIdx.y", 1) + tz = T.launch_thread("threadIdx.z", 1) + C_local_1 = T.Buffer((1024,), data=C_local, scope="shared.rsram") + C_local_1[tx * 8:tx * 8 + 8] = T.Broadcast(T.float32(0.0), 8) + T.sync_null_token(2) + for k_tile in range(4): + T.wait_token(2) + A_1 = T.Buffer((16384,), "float16", data=A) + A_shared_1 = T.Buffer((1024,), "float16", data=A_shared, scope="shared.asram") + T.broadcast_(T.region(A_1[by * 4096 + k_tile * 32], 1, 4000), T.region(A_shared_1[0], 2, 1024), 1024, by * 4 + k_tile, 0, T.sync_token_id(0)) + T.barrier_init(0, by * 4 + k_tile, by * 4, by * 4 + 1, by * 4 + 2, by * 4 + 3) + B_1 = T.Buffer((16384,), "float16", data=B) + B_shared_1 = T.Buffer((1024,), "float16", data=B_shared, scope="shared.wsram") + T.broadcast_(T.region(B_1[k_tile * 4096 + bx * 
32], 1, 4000), T.region(B_shared_1[0], 2, 1024), 1024, k_tile * 4 + bx, 1, T.sync_token_id(1)) + T.barrier_init(1, k_tile * 4 + bx, bx, bx + 4, bx + 8, bx + 12) + T.wait_token(0) + T.barrier_arrive_and_wait(0) + T.wait_token(1) + T.barrier_arrive_and_wait(1) + T.mma_sunmmio(T.region(A_shared_1[0], 1, 1024), T.region(B_shared_1[0], 1, 1024), T.region(C_local_1[0], 3, 1024), T.bool(False), T.bool(False), T.bool(False), T.sync_token_id(2)) + T.wait_token(2) + C_1 = T.Buffer((16384,), data=C) + T.dma_copy(T.region(C_local_1[0], 1, 1024), T.region(C_1[by * 4096 + bx * 32], 2, 4000), T.sync_token_id(3)) + T.wait_token(3) + return 0 + """ + + def get_verify_merge_allocate(): + kernel_name = "kernel_kernel" + # Only one buffer per scope, no need to verify merge size + return build_verify_merge_allocate(kernel_name=kernel_name) + + test_config = { + "LowerTileOp": { + "script_expected": script_lower_tile_op, + }, + "InjectSunmmioSync": { + "script_expected": script_InjectSunmmioSync, + }, + "MergeSharedMemoryAllocationsSunmmio": { + "formal_verify": get_verify_merge_allocate(), + }, + "DeviceMode": { + "script_expected": script_device_mode, + }, + } + test_config = get_or_add_default_verify(func, test_config) + compile_test(func, target="Sunmmio", test_config=test_config) + + +if __name__ == "__main__": + test_summa() diff --git a/testing/python/compile_pipeline/test_sync.py b/testing/python/compile_pipeline/test_sync.py new file mode 100644 index 0000000000..4bb417f760 --- /dev/null +++ b/testing/python/compile_pipeline/test_sync.py @@ -0,0 +1,170 @@ +import tilelang.language as T +from compile_pipeline import compile_test +from formal_verify_funcs import * + + +def kernel_sync(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): + @T.prim_func + def kernel( + A: T.Tensor((M, K), dtype), + B: T.Tensor((M, K), dtype), + C: T.Tensor((M, N), dtype), + ): + # Initialize Kernel Context + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), 
threads=1) as ( + bx, + by, + ): + A_shared = T.alloc_shared((1024, 1024), dtype, scope="shared.asram") + B_shared = T.alloc_shared((1024, 1024), dtype, scope="shared.wsram") + C_shared = T.alloc_shared((1024, 1024), dtype, scope="shared.rsram") + D_shared = T.alloc_shared((1024, 1024), dtype, scope="shared.rsram") + E_shared = T.alloc_shared((1024, 1024), dtype, scope="shared.rsram") + + T.gemm(A_shared, B_shared, C_shared) + if bx <= 2: + T.clear(D_shared) + + for i in range(5): + C_shared[i, 0] = C_shared[i, 0] + 1.0 + + for _i in range(10): + T.comm.broadcast(D_shared, E_shared, (0, 0), direction="h") + E_shared[0, 0] = E_shared[0, 0] + 1.0 + T.comm.broadcast(E_shared, D_shared, (0, 0), direction="h") + + return kernel + + +def test_sync(): + func = kernel_sync(1024 * 16, 1024 * 16, 1024 * 16, 1024, 1024, 1024) + + script_lower_tile_op = """ + with T.block("_gemm_sss"): + T.reads() + T.writes() + T.mma_sunmmio(T.region(A_shared[0, 0], 1, 1024, 1024), T.region(B_shared[0, 0], 1, 1024, 1024), T.region(C_shared[0, 0], 3, 1024, 1024), T.bool(False), T.bool(False), T.bool(False)) + if bx <= 2: + for i0 in T.serial(1024, annotations={"tile.domain": [1024, 1024], "tile.loop_parallel": 1, "tile.loop_stage": 0}): + for i1 in T.serial(1024, annotations={"tile.loop_parallel": 1, "tile.loop_stage": 0}): + D_shared[i0, i1] = T.Cast("float16", 0) + for i in range(5): + C_shared[i, 0] = T.Cast("float16", T.Cast("float32", C_shared[i, 0]) + T.float32(1.0)) + for _i in range(10): + T.broadcast_(T.region(D_shared[0, 0], 1, 1024, 1024), T.region(E_shared[0, 0], 2, 1024, 1024), 1048576, 0, 0) + E_shared[0, 0] = T.Cast("float16", T.Cast("float32", E_shared[0, 0]) + T.float32(1.0)) + T.broadcast_(T.region(E_shared[0, 0], 1, 1024, 1024), T.region(D_shared[0, 0], 2, 1024, 1024), 1048576, 0, 0) + """ + + script_InjectSunmmioSync = """ + with T.decl_buffer((1024, 1024), "float16", scope="shared.asram") as A_shared: + B_shared = T.decl_buffer((1024, 1024), "float16", 
scope="shared.wsram") + T.mma_sunmmio(T.region(A_shared[0, 0], 1, 1024, 1024), T.region(B_shared[0, 0], 1, 1024, 1024), T.region(C_shared[0, 0], 3, 1024, 1024), T.bool(False), T.bool(False), T.bool(False), T.sync_token_id(0)) + if bx <= 2: + for i0 in T.serial(1024, annotations={"tile.domain": [1024, 1024], "tile.execution_axis": 0, "tile.execution_domain_axes": [0, 1], "tile.scope_entry": 1, "tile.tile_size": [1, 256]}): + for i1 in T.serial(4, annotations={"tile.execution_axis": 1}): + for ki in T.serial(1, annotations={"tile.interior": 1, "tile.interior_axis": 0}): + for kj in T.serial(4, annotations={"tile.interior": 1, "tile.interior_axis": 1}): + for vec in T.vectorized(64): + D_shared[i0, i1 * 256 + kj * 64 + vec] = T.float16(0.0) + for i in range(5): + T.wait_token(0) + C_shared[i, 0] = T.Cast("float16", T.Cast("float32", C_shared[i, 0]) + T.float32(1.0)) + T.sync_null_token(2) + T.barrier_init(1, 0, 1, 2, 3) + for _i in range(10): + E_shared = T.decl_buffer((1024, 1024), "float16", scope="shared.rsram") + T.wait_token(2) + T.barrier_arrive_and_wait(1) + T.broadcast_(T.region(D_shared[0, 0], 1, 1024, 1024), T.region(E_shared[0, 0], 2, 1024, 1024), 1048576, 0, 0, T.sync_token_id(1)) + T.barrier_init(0, 0, 1, 2, 3) + T.wait_token(1) + T.barrier_arrive_and_wait(0) + E_shared[0, 0] = T.Cast("float16", T.Cast("float32", E_shared[0, 0]) + T.float32(1.0)) + T.broadcast_(T.region(E_shared[0, 0], 1, 1024, 1024), T.region(D_shared[0, 0], 2, 1024, 1024), 1048576, 0, 0, T.sync_token_id(2)) + T.barrier_init(1, 0, 1, 2, 3) + T.wait_token(2) + T.barrier_arrive_and_wait(1) + """ + + script_device_mode = """ + @T.prim_func(private=True) + def kernel_kernel() -> T.int32: + T.func_attr({"target": T.target({"keys": ["cpu"], "kind": "llvm", "mattr": ["device_mesh_nrow_4", "device_mesh_ncol_4"], "mcpu": "sunmmio-a4e", "tag": ""}), "tir.is_global_func": True, "tir.noalias": True, "tl.non_restrict_params": []}) + with T.launch_thread("blockIdx.x", 16) as bx: + by = 
T.launch_thread("blockIdx.y", 16) + tx = T.launch_thread("threadIdx.x", 1) + ty = T.launch_thread("threadIdx.y", 1) + tz = T.launch_thread("threadIdx.z", 1) + with T.allocate([1048576], "float16", "shared.rsram") as C_shared: + D_shared = T.allocate([1048576], "float16", "shared.rsram") + C_shared_1 = T.Buffer((1048576,), "float16", data=C_shared, scope="shared.rsram") + with T.allocate([1048576], "float16", "shared.asram") as A_shared: + B_shared = T.allocate([1048576], "float16", "shared.wsram") + A_shared_1 = T.Buffer((1048576,), "float16", data=A_shared, scope="shared.asram") + B_shared_1 = T.Buffer((1048576,), "float16", data=B_shared, scope="shared.wsram") + T.mma_sunmmio(T.region(A_shared_1[0], 1, 1048576), T.region(B_shared_1[0], 1, 1048576), T.region(C_shared_1[0], 3, 1048576), T.bool(False), T.bool(False), T.bool(False), T.sync_token_id(0)) + D_shared_1 = T.Buffer((1048576,), "float16", data=D_shared, scope="shared.rsram") + if bx <= 2: + for i in T.unroll(16384): + D_shared_1[i * 64:i * 64 + 64] = T.Broadcast(T.float16(0.0), 64) + for i in range(5): + T.wait_token(0) + C_shared_1[i * 1024] = T.Cast("float16", T.Cast("float32", C_shared_1[i * 1024]) + T.float32(1.0)) + T.sync_null_token(2) + T.barrier_init(1, 0, 1, 2, 3) + for _i in range(10): + E_shared = T.allocate([1048576], "float16", "shared.rsram") + T.wait_token(2) + T.barrier_arrive_and_wait(1) + E_shared_1 = T.Buffer((1048576,), "float16", data=E_shared, scope="shared.rsram") + T.broadcast_(T.region(D_shared_1[0], 1, 1048576), T.region(E_shared_1[0], 2, 1048576), 1048576, 0, 0, T.sync_token_id(1)) + T.barrier_init(0, 0, 1, 2, 3) + T.wait_token(1) + T.barrier_arrive_and_wait(0) + E_shared_1[0] = T.Cast("float16", T.Cast("float32", E_shared_1[0]) + T.float32(1.0)) + T.broadcast_(T.region(E_shared_1[0], 1, 1048576), T.region(D_shared_1[0], 2, 1048576), 1048576, 0, 0, T.sync_token_id(2)) + T.barrier_init(1, 0, 1, 2, 3) + T.wait_token(2) + T.barrier_arrive_and_wait(1) + return 0 + """ + + 
script_mere_allocate = [ + """ + buf_shmem = T.allocate([4194304], "uint8", "shared.rsram") + A_shared = T.allocate([1048576], "float16", "shared.asram") + B_shared = T.allocate([1048576], "float16", "shared.wsram") + """ + ] + + def get_verify_merge_allocate(): + kernel_name = "kernel_kernel" + # asram and wsram each hold a single buffer, so their merged sizes are unchanged. + # rsram has five buffers: two are unused, and of the remaining three (C, D, E), C and E can reuse the same storage, so the merged pool only needs two buffers' worth of space (*2) at 2 bytes per float16 element (*2). + cnt_a = 1024 * 1024 + cnt_w = 1024 * 1024 + cnt_r = 1024 * 1024 * 2 * 2 + return build_verify_merge_allocate(kernel_name=kernel_name, cnt_a=cnt_a, cnt_w=cnt_w, cnt_r=cnt_r) + + test_config = { + "LowerTileOp": { + "script_expected": script_lower_tile_op, + }, + "InjectSunmmioSync": { + "script_expected": script_InjectSunmmioSync, + }, + "MergeSharedMemoryAllocationsSunmmio": { + "script_expected": script_mere_allocate, + "formal_verify": get_verify_merge_allocate(), + }, + "DeviceMode": { + "script_expected": script_device_mode, + }, + } + test_config = get_or_add_default_verify(func, test_config) + compile_test(func, out_idx=[2], target="Sunmmio", test_config=test_config) + + + if __name__ == "__main__": + test_sync()