Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 40 additions & 1 deletion src/compressed_tensors/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,17 @@
import warnings
from functools import wraps
from types import MappingProxyType
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional, TypeVar
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Iterable,
List,
Mapping,
Optional,
TypeVar,
)

import numpy
import torch
Expand Down Expand Up @@ -44,6 +54,7 @@
"pack_bitmasks",
"unpack_bitmasks",
"patch_attr",
"patch_attrs",
"ParameterizedDefaultDict",
"get_num_attn_heads",
"get_num_kv_heads",
Expand Down Expand Up @@ -368,6 +379,34 @@ def patch_attr(base: object, attr: str, value: Any):
delattr(base, attr)


@contextlib.contextmanager
def patch_attrs(bases: Iterable[Any], attr: str, values: Iterable[Any]):
"""
Same as `patch_attr` but for a list of objects to patch
Patch attribute for a list of objects with list of values.
Original values are restored upon exit

:param bases: objects which has the attribute to patch
:param attr: name of the the attribute to patch
:param values: used to replace original values. Must be same
length as bases

Usage:
>>> from types import SimpleNamespace
>>> obj1 = SimpleNamespace()
>>> obj2 = SimpleNamespace()
>>> with patch_attr([obj1, obj2], "attribute", ["value1", "value2"]):
... assert obj1.attribute == "value1"
... assert obj2.attribute == "value2"
>>> assert not hasattr(obj1, "attribute")
>>> assert not hasattr(obj2, "attribute")
"""
with contextlib.ExitStack() as stack:
for base, value in zip(bases, values):
stack.enter_context(patch_attr(base, attr, value))
yield


class ParameterizedDefaultDict(dict):
"""
Similar to `collections.DefaultDict`, but upon fetching a key which is missing,
Expand Down
18 changes: 18 additions & 0 deletions tests/test_utils/test_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
ParameterizedDefaultDict,
load_compressed,
patch_attr,
patch_attrs,
save_compressed,
save_compressed_model,
)
Expand Down Expand Up @@ -176,6 +177,23 @@ def test_patch_attr():
assert not hasattr(obj, "attribute")


def test_patch_attrs():
num_objs = 4
objs = [SimpleNamespace() for _ in range(num_objs)]
for idx, obj in enumerate(objs):
if idx % 2 == 0:
obj.attribute = f"original_{idx}"
with patch_attrs(objs, "attribute", [f"patched_{idx}" for idx in range(num_objs)]):
for idx, obj in enumerate(objs):
assert obj.attribute == f"patched_{idx}"
obj.attribute = "modified"
for idx, obj in enumerate(objs):
if idx % 2 == 0:
assert obj.attribute == f"original_{idx}"
else:
assert not hasattr(obj, "attribute")


def test_parameterized_default_dict():
def add_one(value):
return value + 1
Expand Down