Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions py/torch_tensorrt/_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -1215,14 +1215,15 @@ def _replace_execute_engine_for_executorch(exp_program: Any) -> Any:
f"'{engine_node.target}' not found on graph module"
)
elif engine_node.op == "placeholder":
constants = getattr(exp_program, "constants", {})
engine_obj = constants.get(engine_node.name) or constants.get(
engine_node.target
)
from torch_tensorrt.dynamo._exporter import _resolve_lifted_custom_obj

engine_obj = _resolve_lifted_custom_obj(exp_program, engine_node)
if engine_obj is None:
raise RuntimeError(
f"execute_engine node '{node.name}': placeholder engine "
f"'{engine_node.name}' not found in exp_program.constants"
f"'{engine_node.name}' did not resolve to a lifted "
f"custom-object constant (available: "
f"{sorted(getattr(exp_program, 'constants', {}) or {})})"
)
else:
raise RuntimeError(
Expand Down
32 changes: 32 additions & 0 deletions py/torch_tensorrt/dynamo/_exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import torch
from torch._export.non_strict_utils import make_constraints
from torch._guards import detect_fake_mode
from torch._library.fake_class_registry import FakeScriptObject
from torch._subclasses.fake_tensor import FakeTensor
from torch.export import ExportedProgram, ExportGraphSignature
from torch.export._trace import _combine_args
Expand All @@ -23,6 +24,37 @@
from torch_tensorrt.dynamo.runtime._TorchTensorRTModule import ENGINE_IDX, NAME_IDX


def _resolve_lifted_custom_obj(
exp_program: ExportedProgram, node: torch.fx.Node
) -> Any:
# torch.export lifts custom objects into exp_program.constants keyed by their
# graph-signature FQN and renames the placeholder node, so constants[node.name]
# misses. Resolve name -> FQN through the signature mapping; the direct
# name/target lookup is only for legacy programs that carry no such mapping.
constants = getattr(exp_program, "constants", {}) or {}
sig = getattr(exp_program, "graph_signature", None)
name_to_fqn = (
getattr(sig, "inputs_to_lifted_custom_objs", {}) or {}
if sig is not None
else {}
)

obj = None
fqn = name_to_fqn.get(node.name)
if fqn is not None:
obj = constants.get(fqn)
elif not name_to_fqn:
for key in (node.target, node.name):
if key in constants:
obj = constants[key]
break

# A FakeScriptObject has no __getstate__; callers need the real object.
if isinstance(obj, FakeScriptObject):
obj = obj.real_obj
return obj


def export(
gm: torch.fx.GraphModule,
*,
Expand Down
48 changes: 31 additions & 17 deletions py/torch_tensorrt/executorch/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
PreprocessResult,
)
from torch.export.exported_program import ExportedProgram
from torch_tensorrt.dynamo._exporter import _resolve_lifted_custom_obj
from torch_tensorrt.dynamo.runtime._TorchTensorRTModule import (
DEVICE_IDX,
ENGINE_IDX,
Expand All @@ -36,18 +37,25 @@ def _schema_name(target: Any) -> str:
return ""


# Schema names of the ops that count as TensorRT engine nodes; compared against
# ``_schema_name(node.target)`` by ``_get_engine_nodes_in``.
_ENGINE_OP_SCHEMA_NAMES = (
"tensorrt::execute_engine",
"tensorrt::no_op_placeholder_for_execute_engine",
)


def _get_engine_nodes_in(nodes: Any) -> List[Any]:
    """Collect the TRT engine nodes from an iterable of FX nodes.

    ``nodes`` may be a whole graph's node list or a single partition's nodes.
    A node qualifies when it is a ``call_function`` whose target's schema name
    is listed in ``_ENGINE_OP_SCHEMA_NAMES``.
    """
    engine_nodes: List[Any] = []
    for candidate in nodes:
        if candidate.op != "call_function":
            continue
        if _schema_name(candidate.target) in _ENGINE_OP_SCHEMA_NAMES:
            engine_nodes.append(candidate)
    return engine_nodes


def _get_engine_nodes_from_edge_program(edge_program: ExportedProgram) -> List[Any]:
    """Return all TRT engine nodes found in a lowered ExecuTorch partition.

    Args:
        edge_program: Program whose ``graph_module.graph.nodes`` are scanned.

    Returns:
        The ``call_function`` nodes whose schema name marks them as TensorRT
        engine ops, in graph order.
    """
    # Delegate to the shared filter so whole-program and per-partition lookups
    # cannot drift apart; this replaces a duplicate hand-rolled loop that also
    # left the delegating return unreachable.
    return _get_engine_nodes_in(edge_program.graph_module.graph.nodes)


def _get_engine_info_from_edge_program(edge_program: ExportedProgram) -> List[Any]:
Expand All @@ -63,15 +71,22 @@ def _get_engine_info_from_edge_program(edge_program: ExportedProgram) -> List[An
Uses schema name comparison (not object identity) so it works for both
OpOverload and EdgeOpOverload targets.
"""
gm = edge_program.graph_module
engine_nodes = _get_engine_nodes_from_edge_program(edge_program)
if len(engine_nodes) != 1:
raise RuntimeError(
"TensorRT ExecuTorch backend expects exactly 1 engine node per "
f"partition, found {len(engine_nodes)}."
)
return _get_engine_info_for_node(edge_program, engine_nodes[0])


node = engine_nodes[0]
def _get_engine_info_for_node(
edge_program: ExportedProgram, node: torch.fx.Node
) -> List[Any]:
# Engine-info extraction for a single TRT node; callable per-partition so a
# coalesced multi-engine graph can resolve each engine without the
# whole-program "exactly 1 engine" assumption.
gm = edge_program.graph_module
name = _schema_name(node.target)

if name == "tensorrt::no_op_placeholder_for_execute_engine":
Expand Down Expand Up @@ -135,14 +150,13 @@ def _get_engine_info_from_edge_program(edge_program: ExportedProgram) -> List[An
f"'{engine_node.target}' not found on graph module"
)
elif engine_node.op == "placeholder":
constants = getattr(edge_program, "constants", {})
engine_obj = constants.get(engine_node.name) or constants.get(
engine_node.target
)
engine_obj = _resolve_lifted_custom_obj(edge_program, engine_node)
if engine_obj is None:
raise RuntimeError(
f"execute_engine node '{node.name}': placeholder engine "
f"'{engine_node.name}' not found in edge_program.constants"
f"'{engine_node.name}' did not resolve to a lifted custom-object "
f"constant (available: "
f"{sorted(getattr(edge_program, 'constants', {}) or {})})"
)
else:
raise RuntimeError(
Expand Down
64 changes: 36 additions & 28 deletions py/torch_tensorrt/executorch/partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,12 @@
)
from executorch.exir.backend.utils import tag_constant_data
from torch.export import ExportedProgram
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner, Partition
from torch_tensorrt.dynamo.runtime._TorchTensorRTModule import DEVICE_IDX
from torch_tensorrt.executorch.backend import (
TensorRTBackend,
_get_engine_info_from_edge_program,
_get_engine_info_for_node,
_get_engine_nodes_in,
_parse_device_id,
)
from torch_tensorrt.executorch.operator_support import TensorRTOperatorSupport
Expand Down Expand Up @@ -77,27 +78,34 @@ def __init__(
compile_specs=self.compile_specs,
)

def _resolve_target_device(self, exported_program: ExportedProgram) -> bytes:
"""Best-effort ``target_device`` for the delegate-boundary TensorSpecs.
def _resolve_target_device_for_partition(
self, exported_program: ExportedProgram, partition: Partition
) -> bytes:
"""Best-effort ``target_device`` for one partition's delegate boundary.

Reuses the backend's own engine-info extraction so the device index
cannot drift from the runtime blob. Any extraction failure -- no single
engine node (zero or multiple TRT partitions) or an unreadable index --
falls back to ``cuda:0``; per-partition multi-GPU labeling is left to a
follow-up.
Derives the device from this partition's own TRT engine node, so a
coalesced multi-engine graph labels each delegate with its correct GPU
instead of stamping every partition with a single whole-program value.
Any extraction failure falls back to ``cuda:0``.
"""
try:
engine_info = _get_engine_info_from_edge_program(exported_program)
engine_nodes = _get_engine_nodes_in(partition.nodes)
if len(engine_nodes) != 1:
raise RuntimeError(
f"expected exactly 1 engine node in partition "
f"{getattr(partition, 'id', '?')}, found {len(engine_nodes)}"
)
engine_info = _get_engine_info_for_node(exported_program, engine_nodes[0])
return f"cuda:{_parse_device_id(engine_info[DEVICE_IDX])}".encode()
except Exception as e:
# Broad by design: any extraction failure must fall back, not abort
# the export. Warn so a non-default GPU silently labeled cuda:0 stays
# diagnosable.
logger.warning(
"Could not derive target_device from the TensorRT engine (%s); "
"falling back to cuda:0. A non-default GPU engine may be "
'mislabeled -- pin it via CompileSpec("target_device", '
"Could not derive target_device for partition %s (%s); falling "
'back to cuda:0. Pin it via CompileSpec("target_device", '
'b"cuda:<index>").',
getattr(partition, "id", "?"),
e,
)
return b"cuda:0"
Expand All @@ -110,26 +118,26 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
)
partition_list = capability_partitioner.propose_partitions()

if self._has_explicit_target_device:
delegation_spec = self.delegation_spec
else:
delegation_spec = DelegationSpec(
backend_id=TensorRTBackend.__name__,
compile_specs=self.compile_specs
+ [
CompileSpec(
_TARGET_DEVICE_COMPILE_SPEC_KEY,
self._resolve_target_device(exported_program),
)
],
)

partition_tags: Dict[str, DelegationSpec] = {}
for partition in partition_list:
tag = f"tensorrt_{partition.id}"
for node in partition.nodes:
node.meta["delegation_tag"] = tag
partition_tags[tag] = delegation_spec
if self._has_explicit_target_device:
partition_tags[tag] = self.delegation_spec
else:
partition_tags[tag] = DelegationSpec(
backend_id=TensorRTBackend.__name__,
compile_specs=self.compile_specs
+ [
CompileSpec(
_TARGET_DEVICE_COMPILE_SPEC_KEY,
self._resolve_target_device_for_partition(
exported_program, partition
),
)
],
)

tag_constant_data(exported_program)

Expand Down
122 changes: 122 additions & 0 deletions tests/py/dynamo/executorch/test_api.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
import ast
import importlib
import sys
import types
from pathlib import Path

import pytest
import torch
from torch._library.fake_class_registry import FakeScriptObject

from torch_tensorrt.dynamo._exporter import _resolve_lifted_custom_obj


@pytest.mark.unit
Expand Down Expand Up @@ -144,3 +148,121 @@ def test_executorch_headers_are_not_dlfw_gated():
isinstance(node, ast.Name) and node.id == "IS_DLFW_CI"
for node in ast.walk(header_package_data)
)


def _stub_node(name, target=None):
    """Build a minimal stand-in for an FX node with ``name`` and ``target``.

    ``target`` defaults to ``name`` when it is not supplied.
    """
    if target is None:
        target = name
    return types.SimpleNamespace(name=name, target=target)


def _stub_exported_program(constants, name_to_fqn=None):
    """Build a minimal stand-in for an exported program.

    The stub exposes ``constants`` and ``graph_signature``. The signature is
    ``None`` unless ``name_to_fqn`` is given, in which case it carries that
    mapping as ``inputs_to_lifted_custom_objs``.
    """
    signature = None
    if name_to_fqn is not None:
        signature = types.SimpleNamespace(inputs_to_lifted_custom_objs=name_to_fqn)
    return types.SimpleNamespace(constants=constants, graph_signature=signature)


@pytest.mark.unit
def test_resolve_lifted_custom_obj_via_signature_fqn():
    """The signature mapping translates the placeholder name to the FQN key."""
    # Modern torch.export: placeholder name differs from the constants FQN key.
    engine = object()
    program = _stub_exported_program(
        {"engine_fqn": engine},
        {"obj_engine": "engine_fqn"},
    )
    placeholder = _stub_node("obj_engine")
    assert _resolve_lifted_custom_obj(program, placeholder) is engine


@pytest.mark.unit
def test_resolve_lifted_custom_obj_legacy_fallback():
    """Without a signature mapping the constant is found by name/target."""
    # No signature mapping: fall back to a direct name/target lookup.
    engine = object()
    program = _stub_exported_program({"engine": engine}, name_to_fqn=None)
    placeholder = _stub_node("engine")
    assert _resolve_lifted_custom_obj(program, placeholder) is engine


@pytest.mark.unit
def test_resolve_lifted_custom_obj_signature_present_name_absent_is_none():
    """A mapping that lacks the node's name resolves to ``None``."""
    # A present-but-incomplete mapping must not bind a different object by name.
    program = _stub_exported_program(
        {"engine": object()},
        {"some_other_obj": "x"},
    )
    placeholder = _stub_node("engine")
    assert _resolve_lifted_custom_obj(program, placeholder) is None


@pytest.mark.unit
def test_resolve_lifted_custom_obj_missing_is_none():
    """An empty constants table with no mapping resolves to ``None``."""
    program = _stub_exported_program({}, name_to_fqn=None)
    placeholder = _stub_node("missing")
    assert _resolve_lifted_custom_obj(program, placeholder) is None


@pytest.mark.unit
def test_resolve_lifted_custom_obj_unwraps_fake_script_object():
    """A ``FakeScriptObject`` constant is returned as the object it wraps."""

    class _Real:
        pass

    wrapped = FakeScriptObject(object(), "Engine", _Real())
    program = _stub_exported_program(
        {"engine_fqn": wrapped},
        {"obj_engine": "engine_fqn"},
    )
    result = _resolve_lifted_custom_obj(program, _stub_node("obj_engine"))
    assert not isinstance(result, FakeScriptObject)
    assert isinstance(result, _Real)


# --- per-partition target_device (TensorRTPartitioner) -----------------------
# These exercise the partitioner directly, so they need ExecuTorch installed;
# they run in the dedicated executorch CI job and skip elsewhere.


@pytest.mark.unit
def test_resolve_target_device_uses_partition_engine(monkeypatch):
    """The device label comes from the partition's own engine info."""
    pytest.importorskip("executorch.exir")
    from torch_tensorrt.dynamo.runtime._TorchTensorRTModule import DEVICE_IDX
    from torch_tensorrt.executorch import partitioner as P

    partitioner = P.TensorRTPartitioner()
    engine_node = object()
    # Engine info whose device slot says "2"; every earlier slot is filler.
    engine_info = ["0"] * DEVICE_IDX + ["2"]

    monkeypatch.setattr(P, "_get_engine_nodes_in", lambda nodes: [engine_node])
    monkeypatch.setattr(P, "_get_engine_info_for_node", lambda ep, n: engine_info)

    partition = types.SimpleNamespace(id=0, nodes=[engine_node])
    device = partitioner._resolve_target_device_for_partition(object(), partition)
    assert device == b"cuda:2"


@pytest.mark.unit
def test_resolve_target_device_falls_back_when_not_one_engine(monkeypatch):
    """Zero or several engine nodes in a partition yield ``cuda:0``."""
    pytest.importorskip("executorch.exir")
    from torch_tensorrt.executorch import partitioner as P

    partitioner = P.TensorRTPartitioner()
    partition = types.SimpleNamespace(id=1, nodes=[])

    # First no engine node at all, then more than one.
    for found in ([], [object(), object()]):
        monkeypatch.setattr(
            P, "_get_engine_nodes_in", lambda nodes, found=found: found
        )
        device = partitioner._resolve_target_device_for_partition(
            object(), partition
        )
        assert device == b"cuda:0"


@pytest.mark.unit
def test_per_partition_distinct_target_devices(monkeypatch):
    """Partitions whose engines report different devices get different labels."""
    pytest.importorskip("executorch.exir")
    from torch_tensorrt.dynamo.runtime._TorchTensorRTModule import DEVICE_IDX
    from torch_tensorrt.executorch import partitioner as P

    partitioner = P.TensorRTPartitioner()

    def first_node_only(nodes):
        return [nodes[0]]

    def info_from_node(ep, node):
        # Each partition's engine node carries its own device id as its value.
        return ["0"] * DEVICE_IDX + [str(node)]

    monkeypatch.setattr(P, "_get_engine_nodes_in", first_node_only)
    monkeypatch.setattr(P, "_get_engine_info_for_node", info_from_node)

    devices = [
        partitioner._resolve_target_device_for_partition(
            object(), types.SimpleNamespace(id=index, nodes=[str(index)])
        )
        for index in (0, 1)
    ]
    assert devices[0] == b"cuda:0"
    assert devices[1] == b"cuda:1"
    assert devices[0] != devices[1]
Loading