From 3095fab68e8591218374b2ee1cd7a681fcc7c119 Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Thu, 13 Nov 2025 22:51:15 +0000 Subject: [PATCH 1/5] patch_attrs helper Signed-off-by: Brian Dellabetta --- src/compressed_tensors/utils/helpers.py | 34 +++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/src/compressed_tensors/utils/helpers.py b/src/compressed_tensors/utils/helpers.py index 7649f0d0..5fad083d 100644 --- a/src/compressed_tensors/utils/helpers.py +++ b/src/compressed_tensors/utils/helpers.py @@ -44,6 +44,7 @@ "pack_bitmasks", "unpack_bitmasks", "patch_attr", + "patch_attrs", "ParameterizedDefaultDict", "get_num_attn_heads", "get_num_kv_heads", @@ -368,6 +369,39 @@ def patch_attr(base: object, attr: str, value: Any): delattr(base, attr) +@contextlib.contextmanager +def patch_attrs(bases: list[object], attr: str, values: list[Any]): + """ + Patch attribute for a list of objects with list of values. + Original values are restored upon exit + + :param bases: objects which has the attribute to patch + :param attr: name of the the attribute to patch + :param values: used to replace original values. Must be same + length as bases + + Usage: + >>> from types import SimpleNamespace + >>> obj = SimpleNamespace() + >>> with patch_attr(obj, "attribute", "value"): + ... assert obj.attribute == "value" + >>> assert not hasattr(obj, "attribute") + """ + _sentinel = object() + original_values = [getattr(base, attr, _sentinel) for base in bases] + + for base, value in zip(bases, values): + setattr(base, attr, value) + try: + yield + finally: + for base, original_value in zip(bases, original_values): + if original_value is not _sentinel: + setattr(base, attr, original_value) + else: + delattr(base, attr) + + class ParameterizedDefaultDict(dict): """ Similar to `collections.DefaultDict`, but upon fetching a key which is missing, From ae305265f6771c2ca8fa01e53b130fa751353703 Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Thu, 13 Nov 2025 22:59:10 +0000 Subject: [PATCH 2/5] unit test Signed-off-by: Brian Dellabetta --- src/compressed_tensors/utils/helpers.py | 9 ++++++--- tests/test_utils/test_helpers.py | 18 ++++++++++++++++++ 2 files changed, 24 insertions(+), 3 deletions(-) diff --git a/src/compressed_tensors/utils/helpers.py b/src/compressed_tensors/utils/helpers.py index 5fad083d..b87996ee 100644 --- a/src/compressed_tensors/utils/helpers.py +++ b/src/compressed_tensors/utils/helpers.py @@ -372,6 +372,7 @@ def patch_attr(base: object, attr: str, value: Any): @contextlib.contextmanager def patch_attrs(bases: list[object], attr: str, values: list[Any]): """ + Same as `patch_attr` but for a list of objects to patch Patch attribute for a list of objects with list of values. Original values are restored upon exit @@ -383,9 +384,11 @@ def patch_attrs(bases: list[object], attr: str, values: list[Any]): Usage: >>> from types import SimpleNamespace >>> obj = SimpleNamespace() - >>> with patch_attr(obj, "attribute", "value"): - ... assert obj.attribute == "value" - >>> assert not hasattr(obj, "attribute") + >>> with patch_attr([obj1, obj2], "attribute", ["value1", "value2"]): + ... assert obj1.attribute == "value1" + ... assert obj2.attribute == "value2" + >>> assert not hasattr(obj1, "attribute") + >>> assert not hasattr(obj2, "attribute") """ _sentinel = object() original_values = [getattr(base, attr, _sentinel) for base in bases] diff --git a/tests/test_utils/test_helpers.py b/tests/test_utils/test_helpers.py index 1c0aed95..eccb7b80 100644 --- a/tests/test_utils/test_helpers.py +++ b/tests/test_utils/test_helpers.py @@ -21,6 +21,7 @@ ParameterizedDefaultDict, load_compressed, patch_attr, + patch_attrs, save_compressed, save_compressed_model, ) @@ -176,6 +177,23 @@ def test_patch_attr(): assert not hasattr(obj, "attribute") +def test_patch_attrs(): + num_objs = 4 + objs = [SimpleNamespace() for _ in range(num_objs)] + for idx, obj in enumerate(objs): + if idx % 2 == 0: + obj.attribute = f"original_{idx}" + with patch_attrs(objs, "attribute", [f"patched_{idx}" for idx in range(num_objs)]): + for idx, obj in enumerate(objs): + assert obj.attribute == f"patched_{idx}" + obj.attribute = "modified" + for idx, obj in enumerate(objs): + if idx % 2 == 0: + assert obj.attribute == f"original_{idx}" + else: + assert not hasattr(obj, "attribute") + + def test_parameterized_default_dict(): def add_one(value): return value + 1 From 429c105ca0d8193e4f3cdefdc09b980e86547f11 Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Fri, 14 Nov 2025 15:22:44 +0000 Subject: [PATCH 3/5] fix docstring Signed-off-by: Brian Dellabetta --- src/compressed_tensors/utils/helpers.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/compressed_tensors/utils/helpers.py b/src/compressed_tensors/utils/helpers.py index b87996ee..1f9298ab 100644 --- a/src/compressed_tensors/utils/helpers.py +++ b/src/compressed_tensors/utils/helpers.py @@ -383,7 +383,8 @@ def patch_attrs(bases: list[object], attr: str, values: list[Any]): Usage: >>> from types import SimpleNamespace - >>> obj = SimpleNamespace() + >>> obj1 = SimpleNamespace() + >>> obj2 = SimpleNamespace() >>> with patch_attr([obj1, obj2], "attribute", ["value1", "value2"]): ... assert obj1.attribute == "value1" ... assert obj2.attribute == "value2" From f72f778c034d06c28f5f836cf3a6764622e6e399 Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Fri, 14 Nov 2025 12:32:19 -0500 Subject: [PATCH 4/5] Update src/compressed_tensors/utils/helpers.py Co-authored-by: Kyle Sayers Signed-off-by: Brian Dellabetta --- src/compressed_tensors/utils/helpers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/compressed_tensors/utils/helpers.py b/src/compressed_tensors/utils/helpers.py index 1f9298ab..31b04e6e 100644 --- a/src/compressed_tensors/utils/helpers.py +++ b/src/compressed_tensors/utils/helpers.py @@ -370,7 +370,7 @@ def patch_attr(base: object, attr: str, value: Any): @contextlib.contextmanager -def patch_attrs(bases: list[object], attr: str, values: list[Any]): +def patch_attrs(bases: Iterable[Any], attr: str, values: Iterable[Any]): """ Same as `patch_attr` but for a list of objects to patch Patch attribute for a list of objects with list of values. From d709392a0b943a180b111baf6af8b381f93b1ddf Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Fri, 14 Nov 2025 12:36:45 -0500 Subject: [PATCH 5/5] Apply suggestion from @kylesayrs Co-authored-by: Kyle Sayers Signed-off-by: Brian Dellabetta Signed-off-by: Brian Dellabetta --- src/compressed_tensors/utils/helpers.py | 27 +++++++++++++------------ 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/src/compressed_tensors/utils/helpers.py b/src/compressed_tensors/utils/helpers.py index 31b04e6e..d9b3d26e 100644 --- a/src/compressed_tensors/utils/helpers.py +++ b/src/compressed_tensors/utils/helpers.py @@ -16,7 +16,17 @@ import warnings from functools import wraps from types import MappingProxyType -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional, TypeVar +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Iterable, + List, + Mapping, + Optional, + TypeVar, +) import numpy import torch @@ -391,19 +401,10 @@ def patch_attrs(bases: Iterable[Any], attr: str, values: Iterable[Any]): >>> assert not hasattr(obj1, "attribute") >>> assert not hasattr(obj2, "attribute") """ - _sentinel = object() - original_values = [getattr(base, attr, _sentinel) for base in bases] - - for base, value in zip(bases, values): - setattr(base, attr, value) - try: + with contextlib.ExitStack() as stack: + for base, value in zip(bases, values): + stack.enter_context(patch_attr(base, attr, value)) yield - finally: - for base, original_value in zip(bases, original_values): - if original_value is not _sentinel: - setattr(base, attr, original_value) - else: - delattr(base, attr) class ParameterizedDefaultDict(dict):