Skip to content

Commit 589b115

Browse files
authored
Hotfix to sync with triton/main (#29)
1 parent 62ddb6f commit 589b115

File tree

1 file changed

+5
-0
lines changed

1 file changed

+5
-0
lines changed

triton_viz/interpreter.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
_implicit_cvt,
2020
RESERVED_KWS,
2121
interpreter_builder,
22+
InterpretedFunction,
2223
)
2324
from triton.runtime.interpreter import _patch_lang as triton_patch_lang
2425
from triton.runtime import JITFunction
@@ -342,6 +343,8 @@ def wrapper(input, axis=None, keep_dims=False):
342343
def patch():
343344
old_grid_executor_call = GridExecutor.__call__
344345
old_jit_function_call = JITFunction.__call__
346+
# XXX(Keren): Temporarily disable rewriting of AST
347+
old_rewrite_ast = InterpretedFunction._rewrite_ast
345348
old_create_make_range = interpreter_builder.create_make_range
346349
old_create_masked_load = interpreter_builder.create_masked_load
347350
old_create_expand_dims = interpreter_builder.create_expand_dims
@@ -350,6 +353,7 @@ def patch():
350353
old_create_masked_store = interpreter_builder.create_masked_store
351354
GridExecutor.__call__ = _grid_executor_call
352355
JITFunction.__call__ = _jit_function_call
356+
InterpretedFunction._rewrite_ast = lambda self: self.fn
353357
interpreter_builder.create_make_range = _create_make_range(
354358
interpreter_builder.create_make_range
355359
)
@@ -369,6 +373,7 @@ def patch():
369373
finally:
370374
GridExecutor.__call__ = old_grid_executor_call
371375
JITFunction.__call__ = old_jit_function_call
376+
InterpretedFunction._rewrite_ast = old_rewrite_ast
372377
interpreter_builder.create_make_range = old_create_make_range
373378
interpreter_builder.create_masked_load = old_create_masked_load
374379
interpreter_builder.create_expand_dims = old_create_expand_dims

0 commit comments

Comments
 (0)