Commit a803cb0

desertfire authored and pytorchmergebot committed
[AOTI] Refactor how cpp_wrapper specific options are set (pytorch#136035)
Summary:
1) When cpp-wrapper is turned on, certain Triton-specific options need to be set, for both forward and backward. This PR consolidates those settings in one place.
2) Change config.triton.autotune_at_compile_time to default to None. If the flag is not explicitly set by the user, default it to True for cpp-wrapper.

Differential Revision: [D62689940](https://our.internmc.facebook.com/intern/diff/D62689940)

Pull Request resolved: pytorch#136035

Approved by: https://github.com/chenyang78
1 parent bbc3fdb commit a803cb0
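
For context, a minimal usage sketch (not part of this commit) of the intended behavior: with cpp_wrapper enabled and triton.autotune_at_compile_time left unset, the flag now defaults to True on the cpp-wrapper path, while an explicit user setting is respected. The model and input below are placeholders.

```python
import torch
from torch._inductor import config

# Enable the AOTI cpp-wrapper codegen path.
config.cpp_wrapper = True
# config.triton.autotune_at_compile_time is left at its new default (None),
# so the cpp-wrapper path treats it as True.
# To opt out explicitly: config.triton.autotune_at_compile_time = False

def f(x):
    return torch.relu(x) + 1

compiled_f = torch.compile(f)
print(compiled_f(torch.randn(8)))
```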

3 files changed: +30 −28


torch/_inductor/compile_fx.py

Lines changed: 27 additions & 22 deletions
@@ -1246,6 +1246,18 @@ def wrapper(args):
     return wrapper
 
 
+def get_cpp_wrapper_config():
+    return {
+        # Set autotune_at_compile_time to True as default if the option is not explicitly set
+        "triton.autotune_at_compile_time": config.triton.autotune_at_compile_time
+        if config.triton.autotune_at_compile_time is not None
+        else True,
+        "triton.autotune_cublasLt": False,
+        "triton.cudagraphs": False,  # TODO: to be removed
+        "triton.store_cubin": True,
+    }
+
+
 def compile_fx(
     model_: torch.fx.GraphModule,
     example_inputs_: List[torch.Tensor],
@@ -1268,18 +1280,8 @@ def compile_fx(
     if config.cpp_wrapper:
         with config.patch(
             {
-                "cpp_wrapper": False,
-                # For triton.autotune_at_compile_time, disable by default for
-                # FBCode, but enabled by default for OSS.
-                "triton.autotune_at_compile_time": config.triton.autotune_at_compile_time
-                if config.is_fbcode()
-                else os.environ.get(
-                    "TORCHINDUCTOR_TRITON_AUTOTUNE_AT_COMPILE_TIME", "1"
-                )
-                == "1",
-                "triton.autotune_cublasLt": False,
-                "triton.cudagraphs": False,
-                "triton.store_cubin": True,
+                "cpp_wrapper": False,  # reset to break recursive call to compile_fx
+                **get_cpp_wrapper_config(),
             }
         ), V.set_real_inputs(example_inputs_):
             inputs_ = example_inputs_
@@ -1470,16 +1472,19 @@ def bw_compiler(
                 n.name for n in model_outputs if isinstance(n, torch.fx.Node)
             )
         fixed = count_tangents(model)
-        return inner_compile(
-            model,
-            example_inputs,
-            static_input_idxs=list(range(fixed)),
-            cudagraphs=cudagraphs,
-            is_backward=True,
-            graph_id=graph_id,
-            boxed_forward_device_index=forward_device,
-            user_visible_outputs=user_visible_outputs,
-        )
+        with config.patch(
+            get_cpp_wrapper_config()
+        ) if config.cpp_wrapper else contextlib.nullcontext():
+            return inner_compile(
+                model,
+                example_inputs,
+                static_input_idxs=list(range(fixed)),
+                cudagraphs=cudagraphs,
+                is_backward=True,
+                graph_id=graph_id,
+                boxed_forward_device_index=forward_device,
+                user_visible_outputs=user_visible_outputs,
+            )
 
     # TODO: can add logging before/after the call to create_aot_dispatcher_function
     # in torch._functorch/aot_autograd.py::aot_module_simplified::aot_function_simplified::new_func
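
As a standalone sketch (generic stand-ins, not the PyTorch internals) of the conditional-context-manager pattern used above in bw_compiler: the config patch is applied only when cpp_wrapper is on, otherwise a no-op context is used.

```python
import contextlib

@contextlib.contextmanager
def fake_patch(settings):
    # Stand-in for torch._inductor.config.patch(): apply overrides, then restore.
    print(f"applying {settings}")
    try:
        yield
    finally:
        print("restoring previous settings")

use_overrides = True  # stand-in for config.cpp_wrapper

with fake_patch({"triton.store_cubin": True}) if use_overrides else contextlib.nullcontext():
    print("compiling backward graph under the selected settings")
```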

torch/_inductor/config.py

Lines changed: 2 additions & 1 deletion
@@ -905,7 +905,8 @@ class triton:
     autotune_cublasLt = True
 
     # Tune the generated Triton kernels at compile time instead of first time they run
-    autotune_at_compile_time = False
+    # Setting to None means uninitialized
+    autotune_at_compile_time: Optional[bool] = None
 
     # should we stop a fusion to allow better tiling?
     tiling_prevents_pointwise_fusion = True
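
A tiny illustration (hypothetical helper, not in the diff) of the tri-state default this introduces: None means the user never set the flag, so the cpp-wrapper path may fill in True, while an explicit True or False is preserved.

```python
from typing import Optional

def resolve_autotune_at_compile_time(user_value: Optional[bool]) -> bool:
    # Mirrors the resolution in get_cpp_wrapper_config(): unset (None) defaults
    # to True under cpp-wrapper; an explicit user choice is kept as-is.
    return user_value if user_value is not None else True

assert resolve_autotune_at_compile_time(None) is True    # unset -> enabled
assert resolve_autotune_at_compile_time(False) is False  # explicit opt-out respected
assert resolve_autotune_at_compile_time(True) is True
```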

torch/_inductor/graph.py

Lines changed: 1 addition & 5 deletions
@@ -1719,11 +1719,7 @@ def codegen_with_cpp_wrapper(self) -> Tuple[str, List[Tuple[int, Node]]]:
         if any(device in self.device_types for device in ["cuda", "xpu"]):
             # first pass
             self.cpp_wrapper = False
-            # Although triton.store_cubin was OrderedSet in compile_fx, the backward pass didn't pick
-            # that up. In theory it should work by only setting triton.store_cubin to True here,
-            # but that will cause a problem when use_runtime_constant_folding is OrderedSet.
-            with config.patch({"triton.store_cubin": True}):
-                compiled = self.compile_to_module().call
+            compiled = self.compile_to_module().call
 
             if not config.triton.autotune_at_compile_time:
