Skip to content

Commit 3464f56

Browse files
committed
chore: Bump version to 0.0.41 and update module imports for improved functionality
1 parent 1aa6fb1 commit 3464f56

File tree

6 files changed

+225
-187
lines changed

6 files changed

+225
-187
lines changed

eformer/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
__version__ = "0.0.40"
15+
__version__ = "0.0.41"
1616

1717
__all__ = (
1818
"aparser",

eformer/callib/__init__.py

Lines changed: 5 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -12,42 +12,20 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from ._cjit import cjit, compile_function, lower_function
16-
from ._triton_call import (
17-
cdiv,
18-
get_triton_type,
19-
next_power_of_2,
20-
normalize_grid,
21-
strides_from_shape,
22-
triton_call,
23-
)
24-
25-
try:
26-
from ._suppress_triton import (
27-
enable_all_triton_output,
28-
silence_all_triton_output,
29-
)
30-
except ImportError:
31-
32-
def silence_all_triton_output():
33-
"""Fallback function when real suppression isn't available."""
34-
return False
35-
36-
def enable_all_triton_output():
37-
"""Fallback function when real suppression isn't available."""
38-
pass
39-
15+
from ._cjit import cjit, compile_function, load_cached_functions, lower_function
16+
from ._suppress_triton import disable_cpp_logs
17+
from ._triton_call import cdiv, get_triton_type, next_power_of_2, normalize_grid, strides_from_shape, triton_call
4018

4119
__all__ = (
4220
"cdiv",
4321
"cjit",
4422
"compile_function",
45-
"enable_all_triton_output",
23+
"disable_cpp_logs",
4624
"get_triton_type",
25+
"load_cached_functions",
4726
"lower_function",
4827
"next_power_of_2",
4928
"normalize_grid",
50-
"silence_all_triton_output",
5129
"strides_from_shape",
5230
"triton_call",
5331
)

eformer/callib/_cjit.py

Lines changed: 72 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import hashlib
2020
import os
2121
import pickle
22+
import re
2223
import typing as tp
2324
import warnings
2425
from pathlib import Path
@@ -88,6 +89,7 @@ def get_cache_dir() -> Path:
8889
COMPILE_FUNC_DIR = CACHE_DIR / "compiled_funcs"
8990
COMPILE_FUNC_DIR.mkdir(parents=True, exist_ok=True)
9091
COMPILED_FILE_NAME = "compiled.func"
92+
SIGNATURE_FILE_NAME = "compiled.signature"
9193

9294
COMPILED_CACHE: dict[tuple, tp.Any] = {}
9395

@@ -103,38 +105,75 @@ def is_jit_wrapped(fn: tp.Any) -> bool:
103105
)
104106

105107

106-
def cjit(
107-
fn: tp.Callable[P, R],
108-
static_argnums: tuple[int, ...] | None = None,
109-
static_argnames: tuple[str, ...] | None = None,
110-
verbose: bool = True,
111-
):
108+
def remove_memory_addresses(input_string: str) -> str:
    """
    Strip hexadecimal memory addresses (such as ``0x736a142445e0``) from a string.

    Every occurrence of the literal prefix ``0x`` followed by one or more
    hexadecimal digits (``0-9``, ``a-f``, ``A-F``) is deleted, prefix included.
    The rest of the text is returned untouched.

    Args:
        input_string: The text to clean.

    Returns:
        A copy of ``input_string`` with every address-like token removed.
    """
    # "0x" literal, then at least one hex digit in either case.
    address_pattern = re.compile(r"0x[0-9a-fA-F]+")
    return address_pattern.sub("", input_string)
123+
124+
125+
def cjit(fn: tp.Callable[P, R], verbose: bool = True):
112126
"""
113127
A decorator that adds caching to a JAX JIT-compiled function.
114128
The input `fn` must already be a JIT-transformed function (e.g., from @jax.jit).
115129
"""
116130
assert is_jit_wrapped(fn=fn), "function should be jit wrapped already"
117131

132+
static_argnums = fn._jit_info.static_argnums
133+
static_argnames = fn._jit_info.static_argnames
134+
135+
if len(static_argnames) == 0:
136+
static_argnames = None
137+
if len(static_argnums) == 0:
138+
static_argnums = None
139+
140+
state_signature = remove_memory_addresses(
141+
str(
142+
(
143+
fn.__signature__._hash_basis(),
144+
fn.__annotations__,
145+
fn._fun.__annotations__,
146+
fn._fun.__kwdefaults__,
147+
fn._fun.__name__,
148+
static_argnums,
149+
static_argnames,
150+
)
151+
)
152+
)
153+
154+
static_arg_indices = set(static_argnums) if static_argnums is not None else set()
155+
118156
@functools.wraps(fn)
119157
def wrapped(*args, **kwargs):
120-
static_arg_indices = set(static_argnums) if static_argnums is not None else set()
121158
dynamic_args = tuple(arg for i, arg in enumerate(args) if i not in static_arg_indices)
122159
dynamic_kwargs = kwargs.copy()
123160
if static_argnames is not None:
124161
for key in static_argnames:
125162
dynamic_kwargs.pop(key, None)
126163
signature = get_signature_tree_util(dynamic_args, dynamic_kwargs)
127-
cache_key = (fn, signature)
164+
cache_key = (state_signature, signature)
128165
if cache_key in COMPILED_CACHE:
129166
compiled_func = COMPILED_CACHE[cache_key]
130167
return compiled_func(*dynamic_args, **dynamic_kwargs)
131168
lowered_func: Lowered = fn.lower(*args, **kwargs)
132-
compiled_func = smart_compile(
169+
compiled_func, signature = smart_compile(
133170
lowered_func=lowered_func,
134171
tag="cached-jit",
135172
verbose=verbose,
173+
cache_key=cache_key,
136174
)
137-
COMPILED_CACHE[cache_key] = compiled_func
175+
176+
COMPILED_CACHE[signature] = compiled_func
138177

139178
return compiled_func(*dynamic_args, **dynamic_kwargs)
140179

@@ -256,7 +295,8 @@ def smart_compile(
256295
lowered_func: Lowered,
257296
tag: str | None = None,
258297
verbose: bool = True,
259-
) -> Compiled:
298+
cache_key: tuple[str, tuple] | None = None,
299+
) -> tuple[Compiled, tuple[str, tuple] | None]:
260300
"""Compile a lowered JAX function with caching.
261301
262302
Args:
@@ -271,16 +311,19 @@ def smart_compile(
271311
foldername = str(func_hash) if tag is None else f"{tag}-{func_hash}"
272312
func_dir = COMPILE_FUNC_DIR / foldername
273313
filepath = func_dir / COMPILED_FILE_NAME
314+
signature_filepath = func_dir / SIGNATURE_FILE_NAME
274315
post_fix = f" (TAG : {tag})" if tag else ""
316+
signature = cache_key
275317
if filepath.exists() and not RECOMPILE_FORCE:
276318
try:
277319
(serialized, in_tree, out_tree) = pickle.load(open(filepath, "rb"))
320+
signature = pickle.load(open(signature_filepath, "rb"))
278321
compiled_func = deserialize_and_load(
279322
serialized=serialized,
280323
in_tree=in_tree,
281324
out_tree=out_tree,
282325
)
283-
return compiled_func
326+
return compiled_func, signature
284327
except Exception as e:
285328
if verbose:
286329
warnings.warn(
@@ -293,27 +336,29 @@ def smart_compile(
293336
func_dir.mkdir(parents=True, exist_ok=True)
294337
try:
295338
pickle.dump((serialized, in_tree, out_tree), open(filepath, "wb"))
339+
pickle.dump(cache_key, open(signature_filepath, "wb"))
296340
except Exception as e:
297341
if verbose:
298342
warnings.warn(
299343
f"couldn't save compiled function due to {e}" + post_fix,
300344
stacklevel=4,
301345
)
302-
return compiled_func
346+
return compiled_func, signature
303347
else:
304348
compiled_func: Compiled = lowered_func.compile()
305349
if ECACHE_COMPILES:
306350
try:
307351
serialized, in_tree, out_tree = serialize(compiled_func)
308352
func_dir.mkdir(parents=True, exist_ok=True)
309353
pickle.dump((serialized, in_tree, out_tree), open(filepath, "wb"))
354+
pickle.dump(cache_key, open(signature_filepath, "wb"))
310355
except Exception as e:
311356
if verbose:
312357
warnings.warn(
313358
f"couldn't save and serialize compiled function due to {e}" + post_fix,
314359
stacklevel=4,
315360
)
316-
return compiled_func
361+
return compiled_func, signature
317362

318363

319364
def save_compiled_fn(
@@ -341,10 +386,7 @@ def save_compiled_fn(
341386
warnings.warn(f"couldn't save compiled function due to {e}", stacklevel=4)
342387

343388

344-
def load_compiled_fn(
345-
path: str | os.PathLike,
346-
prefix: str | None = None,
347-
):
389+
def load_compiled_fn(path: str | os.PathLike, prefix: str | None = None):
348390
"""Load a compiled function from disk.
349391
350392
Args:
@@ -524,8 +566,20 @@ def compile_function(
524566
).compile()
525567

526568

569+
def load_cached_functions() -> None:
    """Preload compiled functions stored on disk into ``COMPILED_CACHE``.

    Each sub-directory of ``COMPILE_FUNC_DIR`` is expected to hold a pickled
    ``(serialized, in_tree, out_tree)`` triple in ``COMPILED_FILE_NAME`` and a
    pickled cache key in ``SIGNATURE_FILE_NAME``. Every entry that loads
    successfully is registered in ``COMPILED_CACHE`` under its stored key.

    Loading is best-effort: an entry is skipped (never raised on) when

    * it is not a directory,
    * either of the two files is missing (e.g. a cache written before the
      signature file existed),
    * the stored key is ``None`` (written when the function was compiled
      without a cache key), or
    * unpickling / deserialization fails, in which case a warning is emitted.

    Returns:
        None. The function only mutates ``COMPILED_CACHE``.
    """
    for entry in Path(COMPILE_FUNC_DIR).iterdir():
        if not entry.is_dir():
            continue

        compiled_path = entry / COMPILED_FILE_NAME
        signature_path = entry / SIGNATURE_FILE_NAME
        if not (compiled_path.is_file() and signature_path.is_file()):
            continue

        try:
            # NOTE: unpickling can execute arbitrary code; the cache directory
            # must only ever contain files written by this module.
            signature = pickle.loads(signature_path.read_bytes())
            if signature is None:
                # No usable key was recorded for this entry.
                continue
            serialized, in_tree, out_tree = pickle.loads(compiled_path.read_bytes())
            COMPILED_CACHE[signature] = deserialize_and_load(
                serialized=serialized,
                in_tree=in_tree,
                out_tree=out_tree,
            )
        except Exception as e:
            # One stale or corrupt entry must not prevent loading the others.
            warnings.warn(
                f"couldn't load cached function from {entry} due to {e}",
                stacklevel=2,
            )
578+
579+
527580
if __name__ == "__main__":
528581
jnp = jax.numpy
582+
load_cached_functions()
529583

530584
@cjit
531585
@jax.jit

eformer/callib/_suppress_triton.py

Lines changed: 34 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -44,32 +44,34 @@ def __init__(self):
4444

4545
def start(self):
    """Start suppressing C-level stderr output.

    Duplicates the current stderr file descriptor, opens the platform's null
    device and points ``STDERR_FILENO`` at it.

    The redirect is only attempted when the ``LET_TRITON_TALK`` environment
    variable is unset or set to ``1`` / ``true`` / ``on`` (case-insensitive).
    NOTE(review): the variable's name suggests the opposite polarity (truthy
    meaning "do not suppress"); the existing check is kept as written --
    confirm the intended meaning.

    Returns:
        bool: ``True`` if stderr was redirected by this call. ``False`` if the
        environment gate is off, suppression is unavailable, suppression is
        already active, or the redirect raised ``OSError``/``AttributeError``.
    """
    gate = os.environ.get("LET_TRITON_TALK", "true").lower()
    if gate not in ("1", "true", "on"):
        # Always report a bool; falling through here used to return None.
        return False

    if not SUPPRESSION_AVAILABLE or self.suppressing:
        return False

    try:
        # Keep a duplicate of the real stderr so it can be restored later.
        self.old_stderr_fd = libc.dup(STDERR_FILENO)
        null_device = "NUL" if sys.platform == "win32" else "/dev/null"
        self.null_fd = os.open(null_device, os.O_WRONLY)
        libc.dup2(self.null_fd, STDERR_FILENO)
        self.suppressing = True
        return True
    except (OSError, AttributeError):
        self.cleanup()
        return False
6263

6364
def stop(self):
    """Stop suppressing C-level stderr output and restore original stderr.

    Does nothing when suppression is not active. Otherwise the saved stderr
    descriptor (if any) is copied back onto ``STDERR_FILENO`` and
    ``cleanup()`` is always invoked afterwards, even if the restore raises.

    Restoration deliberately does not consult ``LET_TRITON_TALK``: if the
    variable changed after a successful ``start()``, gating on it here would
    leave stderr redirected and skip ``cleanup()``.
    """
    if not self.suppressing:
        return

    try:
        if self.old_stderr_fd is not None:
            libc.dup2(self.old_stderr_fd, STDERR_FILENO)
    finally:
        self.cleanup()
7375

7476
def cleanup(self):
7577
"""Clean up resources."""
@@ -141,18 +143,21 @@ def enable_all_triton_output():
141143

142144

143145
@contextmanager
def disable_cpp_logs(verbose: bool = False):
    """Context manager that temporarily suppresses Triton kernel autotuning logs.

    Args:
        verbose: When true, nothing is suppressed and the managed block runs
            with logging left as it is.

    Yields:
        ``False`` when ``verbose`` is true; otherwise the value returned by
        ``suppress_triton_logs()``. If that value is truthy,
        ``restore_triton_logs()`` is called when the block exits, including
        when it exits with an exception.
    """
    if verbose:
        yield False
        return

    did_start = suppress_triton_logs()
    try:
        yield did_start
    finally:
        if did_start:
            restore_triton_logs()
152157

153158

154159
if __name__ == "__main__":
155-
with no_triton_logs():
160+
with disable_cpp_logs():
156161
print("Inside the context manager - C/C++ stderr is suppressed")
157162

158163
suppress_triton_logs()

0 commit comments

Comments
 (0)