|
16 | 16 | import warnings |
17 | 17 | from functools import wraps |
18 | 18 | from types import MappingProxyType |
19 | | -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional, TypeVar |
| 19 | +from typing import ( |
| 20 | + TYPE_CHECKING, |
| 21 | + Any, |
| 22 | + Callable, |
| 23 | + Dict, |
| 24 | + Iterable, |
| 25 | + List, |
| 26 | + Mapping, |
| 27 | + Optional, |
| 28 | + TypeVar, |
| 29 | +) |
20 | 30 |
|
21 | 31 | import numpy |
22 | 32 | import torch |
|
44 | 54 | "pack_bitmasks", |
45 | 55 | "unpack_bitmasks", |
46 | 56 | "patch_attr", |
| 57 | + "patch_attrs", |
47 | 58 | "ParameterizedDefaultDict", |
48 | 59 | "get_num_attn_heads", |
49 | 60 | "get_num_kv_heads", |
@@ -368,6 +379,34 @@ def patch_attr(base: object, attr: str, value: Any): |
368 | 379 | delattr(base, attr) |
369 | 380 |
|
370 | 381 |
|
| 382 | +@contextlib.contextmanager |
| 383 | +def patch_attrs(bases: Iterable[Any], attr: str, values: Iterable[Any]): |
| 384 | + """ |
| 385 | + Same as `patch_attr` but for a list of objects to patch |
| 386 | + Patch attribute for a list of objects with list of values. |
| 387 | + Original values are restored upon exit |
| 388 | +
|
| 389 | + :param bases: objects which has the attribute to patch |
| 390 | + :param attr: name of the the attribute to patch |
| 391 | + :param values: used to replace original values. Must be same |
| 392 | + length as bases |
| 393 | +
|
| 394 | + Usage: |
| 395 | + >>> from types import SimpleNamespace |
| 396 | + >>> obj1 = SimpleNamespace() |
| 397 | + >>> obj2 = SimpleNamespace() |
| 398 | + >>> with patch_attr([obj1, obj2], "attribute", ["value1", "value2"]): |
| 399 | + ... assert obj1.attribute == "value1" |
| 400 | + ... assert obj2.attribute == "value2" |
| 401 | + >>> assert not hasattr(obj1, "attribute") |
| 402 | + >>> assert not hasattr(obj2, "attribute") |
| 403 | + """ |
| 404 | + with contextlib.ExitStack() as stack: |
| 405 | + for base, value in zip(bases, values): |
| 406 | + stack.enter_context(patch_attr(base, attr, value)) |
| 407 | + yield |
| 408 | + |
| 409 | + |
371 | 410 | class ParameterizedDefaultDict(dict): |
372 | 411 | """ |
373 | 412 | Similar to `collections.DefaultDict`, but upon fetching a key which is missing, |
|
0 commit comments