diff --git a/src/Compiler.jl b/src/Compiler.jl index d06992d13e..815a6a1a99 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -1182,12 +1182,6 @@ function compile_mlir!( raise isa Bool && (raise = true) end - concrete_seen = OrderedIdDict() - - concrete_result = make_tracer( - concrete_seen, traced_result, ("result",), TracedToConcrete; runtime - ) - optimize isa Bool && (optimize = ifelse(optimize, :all, :none)) toolkit = "" @@ -1811,6 +1805,12 @@ function compile_mlir!( ] end + if result_shardings isa Vector + result_shardings = [ + result_shardings[i] for (i, present) in enumerate(results_mask) if present + ] + end + func3 = MLIR.Dialects.func.func_(; sym_name="main", function_type=MLIR.IR.FunctionType(in_tys, out_tys2), @@ -1897,7 +1897,6 @@ function compile_mlir!( mlir_fn_res.num_replicas, mlir_fn_res.is_sharded, preserved_args, - concrete_result, mlir_fn_res.unique_meshes, mlir_fn_res.mutated_args, use_shardy_partitioner, @@ -2985,7 +2984,7 @@ function compile(f, args; sync=false, kwargs...) seen_args, linear_results, preserved_args, - concrete_result, + traced_result, donated_args_mask, ) = mlir_fn_res @@ -3037,6 +3036,15 @@ function compile(f, args; sync=false, kwargs...) ndevices, ) + concrete_result = make_tracer( + OrderedIdDict(), + traced_result, + ("result",), + TracedToConcrete; + runtime=XLA.runtime(client), + sharding=mlir_fn_res.result_shardings, + ) + unflatten_code, used_shardinfo = codegen_unflatten!( linear_args, preserved_args, diff --git a/src/TracedUtils.jl b/src/TracedUtils.jl index 4241ee316a..ab17fe652c 100644 --- a/src/TracedUtils.jl +++ b/src/TracedUtils.jl @@ -211,7 +211,7 @@ Base.@nospecializeinfer function unpad_val_op( ) end -mutable struct CompiledMlirFnResult{F,TR,Re,Rt,LA,LR,PA,CR,M,MA,RS,GD,DA} +mutable struct CompiledMlirFnResult{F,TR,Re,Rt,LA,LR,PA,M,MA,RS,GD,DA} fnwrapped::Bool f::F traced_result::TR @@ -225,7 +225,6 @@ mutable struct CompiledMlirFnResult{F,TR,Re,Rt,LA,LR,PA,CR,M,MA,RS,GD,DA} num_replicas::Int is_sharded::Bool preserved_args::PA - concrete_result::CR unique_meshes::M mutated_args::MA use_shardy_partitioner::Bool @@ -370,7 +369,6 @@ function make_mlir_fn( num_replicas, is_sharded, nothing, - nothing, unique_meshes, mutated_args, true,