Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions src/compressed_tensors/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
"pack_bitmasks",
"unpack_bitmasks",
"patch_attr",
"patch_attrs",
"ParameterizedDefaultDict",
"get_num_attn_heads",
"get_num_kv_heads",
Expand Down Expand Up @@ -368,6 +369,43 @@ def patch_attr(base: object, attr: str, value: Any):
delattr(base, attr)


@contextlib.contextmanager
def patch_attrs(bases: list[object], attr: str, values: list[Any]):
"""
Same as `patch_attr` but for a list of objects to patch
Patch attribute for a list of objects with list of values.
Original values are restored upon exit

:param bases: objects which has the attribute to patch
:param attr: name of the the attribute to patch
:param values: used to replace original values. Must be same
length as bases

Usage:
>>> from types import SimpleNamespace
>>> obj1 = SimpleNamespace()
>>> obj2 = SimpleNamespace()
>>> with patch_attr([obj1, obj2], "attribute", ["value1", "value2"]):
... assert obj1.attribute == "value1"
... assert obj2.attribute == "value2"
>>> assert not hasattr(obj1, "attribute")
>>> assert not hasattr(obj2, "attribute")
"""
_sentinel = object()
original_values = [getattr(base, attr, _sentinel) for base in bases]

for base, value in zip(bases, values):
setattr(base, attr, value)
try:
yield
finally:
for base, original_value in zip(bases, original_values):
if original_value is not _sentinel:
setattr(base, attr, original_value)
else:
delattr(base, attr)


class ParameterizedDefaultDict(dict):
"""
Similar to `collections.DefaultDict`, but upon fetching a key which is missing,
Expand Down
18 changes: 18 additions & 0 deletions tests/test_utils/test_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
ParameterizedDefaultDict,
load_compressed,
patch_attr,
patch_attrs,
save_compressed,
save_compressed_model,
)
Expand Down Expand Up @@ -176,6 +177,23 @@ def test_patch_attr():
assert not hasattr(obj, "attribute")


def test_patch_attrs():
num_objs = 4
objs = [SimpleNamespace() for _ in range(num_objs)]
for idx, obj in enumerate(objs):
if idx % 2 == 0:
obj.attribute = f"original_{idx}"
with patch_attrs(objs, "attribute", [f"patched_{idx}" for idx in range(num_objs)]):
for idx, obj in enumerate(objs):
assert obj.attribute == f"patched_{idx}"
obj.attribute = "modified"
for idx, obj in enumerate(objs):
if idx % 2 == 0:
assert obj.attribute == f"original_{idx}"
else:
assert not hasattr(obj, "attribute")


def test_parameterized_default_dict():
def add_one(value):
return value + 1
Expand Down