1919import hashlib
2020import os
2121import pickle
22+ import re
2223import typing as tp
2324import warnings
2425from pathlib import Path
@@ -88,6 +89,7 @@ def get_cache_dir() -> Path:
8889COMPILE_FUNC_DIR = CACHE_DIR / "compiled_funcs"
8990COMPILE_FUNC_DIR .mkdir (parents = True , exist_ok = True )
9091COMPILED_FILE_NAME = "compiled.func"
92+ SIGNATURE_FILE_NAME = "compiled.signature"
9193
9294COMPILED_CACHE : dict [tuple , tp .Any ] = {}
9395
@@ -103,38 +105,75 @@ def is_jit_wrapped(fn: tp.Any) -> bool:
103105 )
104106
105107
106- def cjit (
107- fn : tp .Callable [P , R ],
108- static_argnums : tuple [int , ...] | None = None ,
109- static_argnames : tuple [str , ...] | None = None ,
110- verbose : bool = True ,
111- ):
108+ def remove_memory_addresses (input_string : str ) -> str :
109+ """
110+ Removes hexadecimal memory address patterns (e.g., 0x736a142445e0) from a string.
111+
112+ Args:
113+ input_string: The string to process.
114+
115+ Returns:
116+ The string with memory address patterns removed.
117+ """
118+ # Regex to find hexadecimal memory addresses:
119+ # 0x : matches the literal "0x"
120+ # [0-9a-fA-F]+ : matches one or more hexadecimal characters (0-9, a-f, A-F)
121+
122+ return re .sub (r"0x[0-9a-fA-F]+" , "" , input_string )
123+
124+
125+ def cjit (fn : tp .Callable [P , R ], verbose : bool = True ):
112126 """
113127 A decorator that adds caching to a JAX JIT-compiled function.
114128 The input `fn` must already be a JIT-transformed function (e.g., from @jax.jit).
115129 """
116130 assert is_jit_wrapped (fn = fn ), "function should be jit wrapped already"
117131
132+ static_argnums = fn ._jit_info .static_argnums
133+ static_argnames = fn ._jit_info .static_argnames
134+
135+ if len (static_argnames ) == 0 :
136+ static_argnames = None
137+ if len (static_argnums ) == 0 :
138+ static_argnums = None
139+
140+ state_signature = remove_memory_addresses (
141+ str (
142+ (
143+ fn .__signature__ ._hash_basis (),
144+ fn .__annotations__ ,
145+ fn ._fun .__annotations__ ,
146+ fn ._fun .__kwdefaults__ ,
147+ fn ._fun .__name__ ,
148+ static_argnums ,
149+ static_argnames ,
150+ )
151+ )
152+ )
153+
154+ static_arg_indices = set (static_argnums ) if static_argnums is not None else set ()
155+
118156 @functools .wraps (fn )
119157 def wrapped (* args , ** kwargs ):
120- static_arg_indices = set (static_argnums ) if static_argnums is not None else set ()
121158 dynamic_args = tuple (arg for i , arg in enumerate (args ) if i not in static_arg_indices )
122159 dynamic_kwargs = kwargs .copy ()
123160 if static_argnames is not None :
124161 for key in static_argnames :
125162 dynamic_kwargs .pop (key , None )
126163 signature = get_signature_tree_util (dynamic_args , dynamic_kwargs )
127- cache_key = (fn , signature )
164+ cache_key = (state_signature , signature )
128165 if cache_key in COMPILED_CACHE :
129166 compiled_func = COMPILED_CACHE [cache_key ]
130167 return compiled_func (* dynamic_args , ** dynamic_kwargs )
131168 lowered_func : Lowered = fn .lower (* args , ** kwargs )
132- compiled_func = smart_compile (
169+ compiled_func , signature = smart_compile (
133170 lowered_func = lowered_func ,
134171 tag = "cached-jit" ,
135172 verbose = verbose ,
173+ cache_key = cache_key ,
136174 )
137- COMPILED_CACHE [cache_key ] = compiled_func
175+
176+ COMPILED_CACHE [signature ] = compiled_func
138177
139178 return compiled_func (* dynamic_args , ** dynamic_kwargs )
140179
@@ -256,7 +295,8 @@ def smart_compile(
256295 lowered_func : Lowered ,
257296 tag : str | None = None ,
258297 verbose : bool = True ,
259- ) -> Compiled :
298+ cache_key : tuple [str , tuple ] | None = None ,
299+ ) -> tuple [Compiled , tuple [str , tuple ] | None ]:
260300 """Compile a lowered JAX function with caching.
261301
262302 Args:
@@ -271,16 +311,19 @@ def smart_compile(
271311 foldername = str (func_hash ) if tag is None else f"{ tag } -{ func_hash } "
272312 func_dir = COMPILE_FUNC_DIR / foldername
273313 filepath = func_dir / COMPILED_FILE_NAME
314+ signature_filepath = func_dir / SIGNATURE_FILE_NAME
274315 post_fix = f" (TAG : { tag } )" if tag else ""
316+ signature = cache_key
275317 if filepath .exists () and not RECOMPILE_FORCE :
276318 try :
277319 (serialized , in_tree , out_tree ) = pickle .load (open (filepath , "rb" ))
320+ signature = pickle .load (open (signature_filepath , "rb" ))
278321 compiled_func = deserialize_and_load (
279322 serialized = serialized ,
280323 in_tree = in_tree ,
281324 out_tree = out_tree ,
282325 )
283- return compiled_func
326+ return compiled_func , signature
284327 except Exception as e :
285328 if verbose :
286329 warnings .warn (
@@ -293,27 +336,29 @@ def smart_compile(
293336 func_dir .mkdir (parents = True , exist_ok = True )
294337 try :
295338 pickle .dump ((serialized , in_tree , out_tree ), open (filepath , "wb" ))
339+ pickle .dump (cache_key , open (signature_filepath , "wb" ))
296340 except Exception as e :
297341 if verbose :
298342 warnings .warn (
299343 f"couldn't save compiled function due to { e } " + post_fix ,
300344 stacklevel = 4 ,
301345 )
302- return compiled_func
346+ return compiled_func , signature
303347 else :
304348 compiled_func : Compiled = lowered_func .compile ()
305349 if ECACHE_COMPILES :
306350 try :
307351 serialized , in_tree , out_tree = serialize (compiled_func )
308352 func_dir .mkdir (parents = True , exist_ok = True )
309353 pickle .dump ((serialized , in_tree , out_tree ), open (filepath , "wb" ))
354+ pickle .dump (cache_key , open (signature_filepath , "wb" ))
310355 except Exception as e :
311356 if verbose :
312357 warnings .warn (
313358 f"couldn't save and serialize compiled function due to { e } " + post_fix ,
314359 stacklevel = 4 ,
315360 )
316- return compiled_func
361+ return compiled_func , signature
317362
318363
319364def save_compiled_fn (
@@ -341,10 +386,7 @@ def save_compiled_fn(
341386 warnings .warn (f"couldn't save compiled function due to { e } " , stacklevel = 4 )
342387
343388
344- def load_compiled_fn (
345- path : str | os .PathLike ,
346- prefix : str | None = None ,
347- ):
389+ def load_compiled_fn (path : str | os .PathLike , prefix : str | None = None ):
348390 """Load a compiled function from disk.
349391
350392 Args:
@@ -524,8 +566,20 @@ def compile_function(
524566 ).compile ()
525567
526568
569+ def load_cached_functions () -> None :
570+ files = [o for o in os .listdir (COMPILE_FUNC_DIR ) if os .path .exists (os .path .join (COMPILE_FUNC_DIR , o ))]
571+ for file in files :
572+ target_compiled_function = Path (COMPILE_FUNC_DIR ) / file / COMPILED_FILE_NAME
573+ target_signature = Path (COMPILE_FUNC_DIR ) / file / SIGNATURE_FILE_NAME
574+ (serialized , in_tree , out_tree ) = pickle .loads (target_compiled_function .read_bytes ())
575+ signature = pickle .loads (target_signature .read_bytes ())
576+ compiled_function = deserialize_and_load (serialized = serialized , in_tree = in_tree , out_tree = out_tree )
577+ COMPILED_CACHE [signature ] = compiled_function
578+
579+
527580if __name__ == "__main__" :
528581 jnp = jax .numpy
582+ load_cached_functions ()
529583
530584 @cjit
531585 @jax .jit
0 commit comments