diff --git a/src/Compiler.jl b/src/Compiler.jl
index d06992d13e..815a6a1a99 100644
--- a/src/Compiler.jl
+++ b/src/Compiler.jl
@@ -1182,12 +1182,6 @@ function compile_mlir!(
         raise isa Bool && (raise = true)
     end
 
-    concrete_seen = OrderedIdDict()
-
-    concrete_result = make_tracer(
-        concrete_seen, traced_result, ("result",), TracedToConcrete; runtime
-    )
-
     optimize isa Bool && (optimize = ifelse(optimize, :all, :none))
 
     toolkit = ""
@@ -1811,6 +1805,12 @@ function compile_mlir!(
         ]
     end
 
+    if result_shardings isa Vector
+        result_shardings = [
+            result_shardings[i] for (i, present) in enumerate(results_mask) if present
+        ]
+    end
+
     func3 = MLIR.Dialects.func.func_(;
         sym_name="main",
         function_type=MLIR.IR.FunctionType(in_tys, out_tys2),
@@ -1897,7 +1897,6 @@ function compile_mlir!(
         mlir_fn_res.num_replicas,
         mlir_fn_res.is_sharded,
         preserved_args,
-        concrete_result,
         mlir_fn_res.unique_meshes,
         mlir_fn_res.mutated_args,
         use_shardy_partitioner,
@@ -2985,7 +2984,7 @@ function compile(f, args; sync=false, kwargs...)
         seen_args,
         linear_results,
         preserved_args,
-        concrete_result,
+        traced_result,
         donated_args_mask,
     ) = mlir_fn_res
 
@@ -3037,6 +3036,15 @@ function compile(f, args; sync=false, kwargs...)
         ndevices,
     )
 
+    concrete_result = make_tracer(
+        OrderedIdDict(),
+        traced_result,
+        ("result",),
+        TracedToConcrete;
+        runtime=XLA.runtime(client),
+        sharding=mlir_fn_res.result_shardings,
+    )
+
     unflatten_code, used_shardinfo = codegen_unflatten!(
         linear_args,
         preserved_args,
diff --git a/src/TracedUtils.jl b/src/TracedUtils.jl
index 4241ee316a..ab17fe652c 100644
--- a/src/TracedUtils.jl
+++ b/src/TracedUtils.jl
@@ -211,7 +211,7 @@ Base.@nospecializeinfer function unpad_val_op(
     )
 end
 
-mutable struct CompiledMlirFnResult{F,TR,Re,Rt,LA,LR,PA,CR,M,MA,RS,GD,DA}
+mutable struct CompiledMlirFnResult{F,TR,Re,Rt,LA,LR,PA,M,MA,RS,GD,DA}
     fnwrapped::Bool
     f::F
     traced_result::TR
@@ -225,7 +225,6 @@ mutable struct CompiledMlirFnResult{F,TR,Re,Rt,LA,LR,PA,CR,M,MA,RS,GD,DA}
     num_replicas::Int
     is_sharded::Bool
     preserved_args::PA
-    concrete_result::CR
     unique_meshes::M
     mutated_args::MA
     use_shardy_partitioner::Bool
@@ -370,7 +369,6 @@ function make_mlir_fn(
         num_replicas,
         is_sharded,
         nothing,
-        nothing,
         unique_meshes,
         mutated_args,
         true,