1919 _implicit_cvt ,
2020 RESERVED_KWS ,
2121 interpreter_builder ,
22+ InterpretedFunction ,
2223)
2324from triton .runtime .interpreter import _patch_lang as triton_patch_lang
2425from triton .runtime import JITFunction
@@ -342,6 +343,8 @@ def wrapper(input, axis=None, keep_dims=False):
342343def patch ():
343344 old_grid_executor_call = GridExecutor .__call__
344345 old_jit_function_call = JITFunction .__call__
346+ # XXX(Keren): Temporarily disable rewriting of AST
347+ old_rewrite_ast = InterpretedFunction ._rewrite_ast
345348 old_create_make_range = interpreter_builder .create_make_range
346349 old_create_masked_load = interpreter_builder .create_masked_load
347350 old_create_expand_dims = interpreter_builder .create_expand_dims
@@ -350,6 +353,7 @@ def patch():
350353 old_create_masked_store = interpreter_builder .create_masked_store
351354 GridExecutor .__call__ = _grid_executor_call
352355 JITFunction .__call__ = _jit_function_call
356+ InterpretedFunction ._rewrite_ast = lambda self : self .fn
353357 interpreter_builder .create_make_range = _create_make_range (
354358 interpreter_builder .create_make_range
355359 )
@@ -369,6 +373,7 @@ def patch():
369373 finally :
370374 GridExecutor .__call__ = old_grid_executor_call
371375 JITFunction .__call__ = old_jit_function_call
376+ InterpretedFunction ._rewrite_ast = old_rewrite_ast
372377 interpreter_builder .create_make_range = old_create_make_range
373378 interpreter_builder .create_masked_load = old_create_masked_load
374379 interpreter_builder .create_expand_dims = old_create_expand_dims
0 commit comments