From 4279851faa3ad679119d5a0a1bcb7c19777faa7a Mon Sep 17 00:00:00 2001
From: Avik Pal <avikpal@mit.edu>
Date: Sun, 4 May 2025 18:34:40 -0400
Subject: [PATCH] feat: add a macro to directly visualize the generated mlir

---
 docs/src/api/api.md |  1 +
 src/Compiler.jl     | 45 +++++++++++++++++++++++++++++++++++++++++++++
 src/Reactant.jl     | 11 ++++++++++-
 3 files changed, 56 insertions(+), 1 deletion(-)

diff --git a/docs/src/api/api.md b/docs/src/api/api.md
index b2a32148a0..9bc3a3a820 100644
--- a/docs/src/api/api.md
+++ b/docs/src/api/api.md
@@ -27,6 +27,7 @@ within_compile
 @code_hlo
 @code_mhlo
 @code_xla
+@mlir_visualize
 ```
 
 ## Profile XLA
diff --git a/src/Compiler.jl b/src/Compiler.jl
index c724d10b47..b258f55b0b 100644
--- a/src/Compiler.jl
+++ b/src/Compiler.jl
@@ -2068,6 +2068,51 @@ macro jit(args...)
     #! format: on
 end
 
+"""
+    @mlir_visualize [optimize = ...] [no_nan = <true/false>] f(args...)
+
+Runs `@code_hlo` and visualizes the MLIR module using `model-explorer`. This expects the
+`model-explorer` executable to be in your `PATH`. Installation instructions can be found
+[here](https://github.com/google-ai-edge/model-explorer).
+"""
+macro mlir_visualize(args...)
+    default_options = Dict{Symbol,Any}(
+        :optimize => true,
+        :no_nan => false,
+        :client => nothing,
+        :raise => false,
+        :raise_first => false,
+        :shardy_passes => :(:to_mhlo_shardings),
+        :assert_nonallocating => false,
+        :donated_args => :(:auto),
+        :transpose_propagate => :(:up),
+        :reshape_propagate => :(:up),
+        :optimize_then_pad => true,
+        :optimize_communications => true,
+        :cudnn_hlo_optimize => false,
+    )
+    compile_expr, (; compiled) = compile_call_expr(
+        __module__, compile_mlir, default_options, args...
+    )
+    #! format: off
+    return esc(
+        :(
+            if Sys.which("model-explorer") === nothing
+                error("model-explorer is not in your PATH. Please install it from \
+                       https://github.com/google-ai-edge/model-explorer")
+            end;
+            $(compile_expr);
+            mlir_mod = $(first)($(compiled));
+            tmpfile = tempname() * ".mlir";
+            open(tmpfile, "w") do io
+                print(io, mlir_mod)
+            end;
+            run(`model-explorer $(tmpfile)`)
+        )
+    )
+    #! format: on
+end
+
 function compile_call_expr(mod, compiler, options::Dict, args...)
     while length(args) > 1
         option, args = args[1], args[2:end]
diff --git a/src/Reactant.jl b/src/Reactant.jl
index b389d07237..1b698c3bcf 100644
--- a/src/Reactant.jl
+++ b/src/Reactant.jl
@@ -201,7 +201,15 @@ function Enzyme.make_zero(
     return res
 end
 
-using .Compiler: @compile, @code_hlo, @code_mhlo, @jit, @code_xla, traced_getfield, compile
+using .Compiler:
+    @compile,
+    @code_hlo,
+    @code_mhlo,
+    @jit,
+    @code_xla,
+    @mlir_visualize,
+    traced_getfield,
+    compile
 export ConcreteRArray,
     ConcreteRNumber,
     ConcretePJRTArray,
@@ -214,6 +222,7 @@ export ConcreteRArray,
     @code_xla,
     @jit,
     @trace,
+    @mlir_visualize,
     within_compile
 
 const registry = Ref{Union{Nothing,MLIR.IR.DialectRegistry}}()