
Commit 4917316

fxdawnn authored and pytorchmergebot committed
[Dynamo][Guards] Fix TLParse CPP guard message with sorting get_leaf_guards and verbose_code_parts (pytorch#169102)
Fixes pytorch#168379. Two solutions were considered: 1. Sort ``get_leaf_guards()`` in ``construct_manager_string`` so that the ``___dict_contains`` guard messages come out ordered by their verbose code parts; this was also suggested in https://fb.workplace.com/groups/1075192433118967/permalink/1650742858897252/. 2. Use an ``OrderedSet`` in ``GuardsSet`` during guard construction so that the ``___dict_contains`` guards are displayed in the order they were added. We adopted the second option for its simplicity and to avoid the sorting overhead; the improved test validates the resulting order. Pull Request resolved: pytorch#169102 Approved by: https://github.com/anijain2305
1 parent 1e34fb2 commit 4917316
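
The fix hinges on replacing Python's unordered ``set`` with an insertion-ordered set, so guard strings are emitted in the order the guards were installed. Below is a minimal, dict-backed sketch of that idea; ``InsertionOrderedSet`` is an illustrative stand-in, not PyTorch's actual ``torch.utils._ordered_set.OrderedSet`` implementation.

```python
# A minimal, dict-backed sketch of an insertion-ordered set -- a stand-in
# for torch.utils._ordered_set.OrderedSet, not its actual implementation.
from collections.abc import Iterable, Iterator


class InsertionOrderedSet:
    def __init__(self, items: Iterable[str] = ()) -> None:
        # dict keys deduplicate while preserving first-insertion order
        # (guaranteed for dicts since Python 3.7)
        self._items: dict[str, None] = dict.fromkeys(items)

    def add(self, item: str) -> None:
        self._items.setdefault(item, None)

    def __iter__(self) -> Iterator[str]:
        return iter(self._items)


# Guards come back out in the order they were installed, so the TLParse
# dump is deterministic; a plain set() may reorder these run to run.
guards = InsertionOrderedSet()
for code in [
    "___dict_contains('operator', G['sys'].modules)",
    "not ___dict_contains('aaaaaaaa', G['sys'].modules)",
    "not ___dict_contains('bbbbbbbb', G['sys'].modules)",
]:
    guards.add(code)
print("\n".join(guards))
```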

File tree

3 files changed: +30 -29 lines

test/dynamo/test_misc.py

Lines changed: 15 additions & 16 deletions
```diff
@@ -1225,30 +1225,29 @@ def fn(x, y):
         # Filter out id-matches that won't reproduce run to run
         guard_code = filter(
             lambda line: "id" not in line and "lookup_backend" not in line,
-            sorted(guard_code),
+            guard_code,
         )
         guard_code_str = "\n".join(guard_code)
 
-        for line in """\
-2 <= L['x'].size()[0]
-L['x'] is L['y']
-L['x'].ndimension() == 2
-L['x'].requires_grad == False
+        # Make sure the dict_contains guards are present in the order they were added
+        self.assertExpectedInline(
+            guard_code_str,
+            """\
 L['x'].size()[1] == L['x'].size()[0]
 L['x'].storage_offset() == 0
-___dict_contains('operator', G['sys'].modules)
-___dict_contains('operator', G['sys'].modules)
+2 <= L['x'].size()[0]
+utils_device.CURRENT_DEVICE == None
+str(L['x'].dtype) == 'torch.float32'
+str(L['x'].device) == 'cpu'
+L['x'].requires_grad == False
+L['x'].ndimension() == 2
 hasattr(L['x'], '_dynamo_dynamic_indices') == False
+L['x'] is L['y']
 not ___dict_contains('aaaaaaaa', G['sys'].modules)
 not ___dict_contains('bbbbbbbb', G['sys'].modules)
-not ___dict_contains('cccccccc', G['sys'].modules)
-str(L['x'].device) == 'cpu'
-str(L['x'].dtype) == 'torch.float32'
-utils_device.CURRENT_DEVICE == None""".split("\n"):
-            self.assertIn(
-                line,
-                guard_code_str,
-            )
+___dict_contains('operator', G['sys'].modules)
+not ___dict_contains('cccccccc', G['sys'].modules)""",
+        )
 
     def test_fold(self):
         def fn(a):
```
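
The test change swaps an order-insensitive ``assertIn`` loop over a sorted list for a single ``assertExpectedInline`` comparison, which fails if the guard lines come out in a different order. A hypothetical minimal example of the distinction (plain ``unittest``, not the real PyTorch harness):

```python
# Hypothetical minimal example (not the real PyTorch test harness)
# contrasting the two assertion styles: per-line membership checks pass
# under any ordering, while an exact multi-line comparison pins it down
# the way assertExpectedInline does in the updated test.
import unittest

EMITTED = "guard_a\nguard_b\nguard_c"  # stand-in for guard_code_str


class GuardOrderDemo(unittest.TestCase):
    def test_membership_is_order_insensitive(self):
        # Old style: assertIn passes even if EMITTED were reordered.
        for line in ["guard_c", "guard_a", "guard_b"]:
            self.assertIn(line, EMITTED)

    def test_exact_match_is_order_sensitive(self):
        # New style: the whole string must match, locking in the order.
        self.assertEqual(EMITTED, "guard_a\nguard_b\nguard_c")


if __name__ == "__main__":
    unittest.main()
```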

torch/_dynamo/guards.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -3871,15 +3871,15 @@ def _ref(x: Any) -> Any:
             },
             global_scope=global_scope_state,
             _guards=torch._guards.GuardsSet(
-                {
+                OrderedSet(
                     dataclasses.replace(
                         guard,
                         obj_weakref=None,
                         guarded_class_weakref=None,
                         create_fn=normalize_create_fn(guard.create_fn),
                     )
                     for guard in sorted_guards
-                }
+                )
             ),
             input_source_to_sizes_strides=pytree.tree_map(
                 convert_int_to_concrete_values,
```

torch/_guards.py

Lines changed: 13 additions & 11 deletions
```diff
@@ -14,10 +14,11 @@
 from collections import defaultdict
 from contextlib import contextmanager
 from dataclasses import dataclass
-from typing import Any, Generic, NamedTuple, TYPE_CHECKING, TypeVar
+from typing import Any, Generic, NamedTuple, Optional, TYPE_CHECKING, TypeVar
 
 import torch
 from torch.utils import _pytree as pytree
+from torch.utils._ordered_set import OrderedSet
 from torch.utils._python_dispatch import is_traceable_wrapper_subclass
 from torch.utils._traceback import CapturedTraceback, format_frame
 from torch.utils.weak import WeakTensorKeyDictionary
@@ -487,16 +488,16 @@ class GuardsCheckpointState:
     The GuardCheckpointState - it is the T of Checkpointable[T] for GuardsContext
     """
 
-    dynamo_guards: set[Guard] = set()
+    dynamo_guards: OrderedSet[Guard]
 
-    def __init__(self, dynamo_guards: set[Guard]) -> None:
+    def __init__(self, dynamo_guards: OrderedSet[Guard]) -> None:
         self.dynamo_guards = dynamo_guards
 
-    def diff(self, other: GuardsCheckpointState) -> set[Guard] | None:
+    def diff(self, other: GuardsCheckpointState) -> Optional[OrderedSet[Guard]]:
         """
         Produces a delta against another GuardsCheckpointState.
 
-        Returns None if no delta is found, otherwise, return a set() of mismatched
+        Returns None if no delta is found, otherwise, return an OrderedSet() of mismatched
         Guard type objects.
         """
         r = self.dynamo_guards.difference(other.dynamo_guards)
@@ -605,10 +606,11 @@ def restore_graphstate(self, state: GlobalContextCheckpointState) -> None:
 # Like a Set[Guard] but will record the user stack on all guards at the
 # time they were installed at their destination
 class GuardsSet:
-    def __init__(self, inner: set[Guard] | None = None) -> None:
+    def __init__(self, inner: Optional[OrderedSet[Guard]] = None) -> None:
         if inner is None:
-            inner = set()
-        self.inner = inner
+            self.inner: OrderedSet[Guard] = OrderedSet()
+        else:
+            self.inner = inner
 
     def __iter__(self) -> Iterator[Guard]:
         return iter(self.inner)
@@ -645,9 +647,9 @@ def remove_guards_with_source(self, source: Source) -> None:
         """Delete all guards that contains a given source"""
         from ._dynamo.source import is_from_source
 
-        self.inner = {
+        self.inner = OrderedSet(
             g for g in self.inner if not is_from_source(g.originating_source, source)
-        }
+        )
 
 
 """
@@ -664,7 +666,7 @@ def __init__(self) -> None:
         self.aotautograd_guards: list[GuardEnvExpr] = []
 
     def copy_graphstate(self) -> GuardsCheckpointState:
-        return GuardsCheckpointState(set(self.dynamo_guards.inner))
+        return GuardsCheckpointState(OrderedSet(self.dynamo_guards.inner))
 
     def restore_graphstate(self, state: GuardsCheckpointState) -> None:
         # NB: "steals" the passed in state
```
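
``copy_graphstate`` now snapshots guards into an ``OrderedSet``, and ``diff`` still relies on set-style ``difference``. A small sketch of those semantics, assuming ``difference`` keeps the left-hand side's insertion order (the guard names here are illustrative, not real ``Guard`` objects):

```python
# Minimal sketch of the checkpoint diff semantics, under the assumption
# that OrderedSet.difference() behaves like set.difference() but keeps
# the left-hand side's insertion order.
def ordered_difference(a: list[str], b: list[str]) -> list[str]:
    # Same membership result as set(a) - set(b), preserving a's order.
    b_set = set(b)
    return [g for g in a if g not in b_set]


checkpoint = ["guard_a", "guard_b", "guard_c"]  # copy_graphstate() snapshot
current = ["guard_a", "guard_c"]                # guards after a rollback

# GuardsCheckpointState.diff() returns None when nothing differs,
# otherwise the mismatched guards.
delta = ordered_difference(checkpoint, current)
print(delta or None)  # -> ['guard_b']
```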
