diff --git a/tree_math/_src/structs.py b/tree_math/_src/structs.py index 2a16af6..4b82d0c 100644 --- a/tree_math/_src/structs.py +++ b/tree_math/_src/structs.py @@ -19,61 +19,73 @@ def struct(cls): - """Class decorator that enables JAX function transforms as well as tree math. - - Decorating a class with `@struct` makes it a dataclass that is compatible - with arithmetic infix operators like `+`, `-`, `*` and `/`. The decorated - class is also a valid pytree, making it compatible with JAX function - transformations such as `jit` and `grad`. - - Example usage: - ``` - import jax - import tree_math - - @tree_math.struct - class Point: - x: float - y: float - - a = Point(0.0, 1.0) - b = Point(2.0, 3.0) - - a + 3 * b # Point(6.0, 10.0) - jax.grad(lambda x, y: x @ y)(a, b) # Point(2.0, 3.0) - ``` - - Args: - cls: a class, written with the same syntax as a `dataclass`. - - Returns: - A wrapped version of `cls` that implements dataclass, pytree and tree_math - functionality. - """ - @property - def fields(self): - return dataclasses.fields(self) - - def asdict(self): - return {field.name: getattr(self, field.name) for field in self.fields} - - def astuple(self): - return tuple(getattr(self, field.name) for field in self.fields) - - def tree_flatten(self): - return self.astuple(), None - - @classmethod - def tree_unflatten(cls, _, children): - return cls(*children) - - cls_as_struct = type(cls.__name__, - (VectorMixin, dataclasses.dataclass(cls)), - {'fields': fields, - 'asdict': asdict, - 'astuple': astuple, - 'replace': dataclasses.replace, - 'tree_flatten': tree_flatten, - 'tree_unflatten': tree_unflatten, - '__module__': cls.__module__}) - return jax.tree_util.register_pytree_node_class(cls_as_struct) + """Class decorator that enables JAX function transforms as well as tree math. + + Decorating a class with `@struct` makes it a dataclass that is compatible + with arithmetic infix operators like `+`, `-`, `*` and `/`. The decorated + class is also a valid pytree, making it compatible with JAX function + transformations such as `jit` and `grad`. + + Example usage: + ``` + import jax + import tree_math + + @tree_math.struct + class Point: + x: float + y: float + static_field: int = 0 # base case + + a = Point(0.0, 1.0) + b = Point(2.0, 3.0) + + a + 3 * b # Point(6.0, 10.0) + jax.grad(lambda x, y: x @ y)(a, b) # Point(2.0, 3.0) + ``` + + Args: + cls: a class, written with the same syntax as a `dataclass`. + + Returns: + A wrapped version of `cls` that implements dataclass, pytree and tree_math + functionality. + """ + + # Get static fields from the class if defined + static_fields = getattr(cls, 'static_fields', []) + + @property + def fields(self): + return dataclasses.fields(self) + + def asdict(self): + return {field.name: getattr(self, field.name) for field in self.fields} + + def astuple(self): + return tuple(getattr(self, field.name) for field in self.fields if field.name not in static_fields) + + def tree_flatten(self): + # Flatten only the non-static fields + children = [getattr(self, field.name) for field in self.fields if field.name not in static_fields] + return children, None + + @classmethod + def tree_unflatten(cls, _, children): + # Create an instance with the provided children and static fields + instance = cls(*children) + for field in cls.static_fields: + setattr(instance, field, getattr(cls, field)) # Set static fields + return instance + + cls_as_struct = type(cls.__name__, + (VectorMixin, dataclasses.dataclass(cls)), + {'fields': fields, + 'asdict': asdict, + 'astuple': astuple, + 'replace': dataclasses.replace, + 'tree_flatten': tree_flatten, + 'tree_unflatten': tree_unflatten, + '__module__': cls.__module__}) + + return jax.tree_util.register_pytree_node_class(cls_as_struct) diff --git a/tree_math/_src/structs_test.py b/tree_math/_src/structs_test.py index b3af767..bd6d3c5 100644 --- a/tree_math/_src/structs_test.py +++ b/tree_math/_src/structs_test.py @@ -30,8 +30,14 @@ @tree_math.struct class TestStruct: - a: ArrayLike - b: ArrayLike + a: ArrayLike + b: ArrayLike + static_field: int = 0 # This will be a static field + + # Define static fields as a class variable + static_fields = ['static_field'] # Specify which fields are static + + class StructsTest(test_util.TestCase):