Skip to content

Commit d709392

Browse files
Apply suggestion from @kylesayrs
Co-authored-by: Kyle Sayers <[email protected]> Signed-off-by: Brian Dellabetta <[email protected]> Signed-off-by: Brian Dellabetta <[email protected]>
1 parent f72f778 commit d709392

File tree

1 file changed

+14
-13
lines changed

1 file changed

+14
-13
lines changed

src/compressed_tensors/utils/helpers.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,17 @@
1616
import warnings
1717
from functools import wraps
1818
from types import MappingProxyType
19-
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional, TypeVar
19+
from typing import (
20+
TYPE_CHECKING,
21+
Any,
22+
Callable,
23+
Dict,
24+
Iterable,
25+
List,
26+
Mapping,
27+
Optional,
28+
TypeVar,
29+
)
2030

2131
import numpy
2232
import torch
@@ -391,19 +401,10 @@ def patch_attrs(bases: Iterable[Any], attr: str, values: Iterable[Any]):
391401
>>> assert not hasattr(obj1, "attribute")
392402
>>> assert not hasattr(obj2, "attribute")
393403
"""
394-
_sentinel = object()
395-
original_values = [getattr(base, attr, _sentinel) for base in bases]
396-
397-
for base, value in zip(bases, values):
398-
setattr(base, attr, value)
399-
try:
404+
with contextlib.ExitStack() as stack:
405+
for base, value in zip(bases, values):
406+
stack.enter_context(patch_attr(base, attr, value))
400407
yield
401-
finally:
402-
for base, original_value in zip(bases, original_values):
403-
if original_value is not _sentinel:
404-
setattr(base, attr, original_value)
405-
else:
406-
delattr(base, attr)
407408

408409

409410
class ParameterizedDefaultDict(dict):

0 commit comments

Comments
 (0)