diff --git a/userbenchmark/dynamo/dynamobench/_dynamo/utils.py b/userbenchmark/dynamo/dynamobench/_dynamo/utils.py index c8b2ac7a95..d0b9cd59ad 100644 --- a/userbenchmark/dynamo/dynamobench/_dynamo/utils.py +++ b/userbenchmark/dynamo/dynamobench/_dynamo/utils.py @@ -3203,3 +3203,9 @@ def get_user_object_from_id(obj_id): def store_user_object_weakref(obj): obj_id = id(obj) user_obj_id_to_weakref[obj_id] = weakref.ref(obj) + + +def realize_inputs(inputs: List[torch.fx.Node]): + for inp in inputs: + if isinstance(inp, torch.fx.node.Node): + inp.meta["inductor_realize_to_strides"] = True