Skip to content

Commit c458ee3

Browse files
committed
feat: Enhance auto_pytree functionality with max_print_length parameter and add type ignore for Union checks
1 parent 0f2782c commit c458ee3

File tree

2 files changed

+5
-2
lines changed

2 files changed

+5
-2
lines changed

eformer/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
# pyright: reportUnsupportedDunderAll=none
16+
1517
__version__ = "0.0.42"
1618

1719
__all__ = (

eformer/pytree/_pytree.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def _is_non_jax_type(typ: type) -> bool:
5757
return False
5858

5959
origin = tp.get_origin(typ)
60-
if origin is tp.Union:
60+
if origin is tp.Union: # type:ignore
6161
args = tp.get_args(typ)
6262
return any(_is_non_jax_type(arg) for arg in args)
6363

@@ -127,6 +127,7 @@ def auto_pytree(
127127
meta_fields: tuple[str, ...] | None = None,
128128
json_serializable: bool = True,
129129
frozen: bool = False,
130+
max_print_length: int = 500,
130131
):
131132
"""
132133
A class decorator that automatically registers a dataclass as a JAX PyTree.
@@ -206,7 +207,7 @@ def enhanced_repr(self):
206207
if not k.startswith("_"): # Avoid private/internal attributes
207208
try:
208209
repr_str = str(v).replace("\n", "\n ")
209-
if len(repr_str) > 200: # Truncate long representations
210+
if len(repr_str) > max_print_length: # Truncate long representations
210211
repr_str = f"{v.__class__.__name__}(...)"
211212
items.append(f" {k} : {repr_str}")
212213
except TypeError:

0 commit comments

Comments
 (0)