Skip to content

deduplicate torch ao debugger tests between pytorch/ao and ExecuTorch #2390

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 3 additions & 149 deletions test/quantization/pt2e/test_numeric_debugger.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,22 +12,13 @@

import torch
from torch.testing._internal.common_quantization import TestHelperModules
from torch.testing._internal.common_utils import IS_WINDOWS, TestCase, run_tests
from torch.testing._internal.common_utils import IS_WINDOWS, run_tests

from torchao.quantization.pt2e import (
CUSTOM_KEY,
NUMERIC_DEBUG_HANDLE_KEY,
compare_results,
extract_results_from_loggers,
generate_numeric_debug_handle,
prepare_for_propagation_comparison,
)
from torchao.quantization.pt2e.graph_utils import bfs_trace_with_node_process
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
from torchao.testing.pt2e._xnnpack_quantizer import (
XNNPACKQuantizer,
get_symmetric_quantization_config,
)
from torchao.testing.pt2e.utils import PT2ENumericDebuggerTestCase
from torchao.utils import TORCH_VERSION_AT_LEAST_2_7

if TORCH_VERSION_AT_LEAST_2_7:
Expand All @@ -36,59 +27,7 @@

@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_7, "Requires torch 2.7+")
@unittest.skipIf(IS_WINDOWS, "Windows not yet supported for torch.compile")
class TestNumericDebugger(TestCase):
def _assert_each_node_has_debug_handle(self, model) -> None:
    """Fail the test if a traced node of ``model`` has no numeric debug handle.

    Every node handed to the callback by ``bfs_trace_with_node_process`` must
    have ``NUMERIC_DEBUG_HANDLE_KEY`` inside ``node.meta[CUSTOM_KEY]``.
    """

    def _check(n):
        has_custom_meta = CUSTOM_KEY in n.meta
        self.assertTrue(
            has_custom_meta and NUMERIC_DEBUG_HANDLE_KEY in n.meta[CUSTOM_KEY],
            f"Node {n} doesn't have debug handle",
        )

    bfs_trace_with_node_process(model, _check)

def _extract_debug_handles(self, model) -> dict[str, int]:
    """Collect ``str(node) -> debug handle`` for traced nodes that carry a handle.

    Nodes without ``CUSTOM_KEY`` metadata, or without a
    ``NUMERIC_DEBUG_HANDLE_KEY`` entry inside it, are skipped.
    """
    handles: dict[str, int] = {}

    def _record(n):
        # Guard clauses: only nodes with a handle contribute to the map.
        if CUSTOM_KEY not in n.meta:
            return
        if NUMERIC_DEBUG_HANDLE_KEY not in n.meta[CUSTOM_KEY]:
            return
        handles[str(n)] = n.meta[CUSTOM_KEY][NUMERIC_DEBUG_HANDLE_KEY]

    bfs_trace_with_node_process(model, _record)

    return handles

def _extract_debug_handles_with_prev_decomp_op(self, model) -> dict[str, int]:
    """Map each pre-decomposition key to the debug handle shared by its nodes.

    The key is ``str(node.meta.get("nn_module_stack"))``. Nodes without a
    debug handle are ignored.

    Returns:
        Dict from the stringified ``nn_module_stack`` to the debug handle.

    Raises:
        AssertionError: if two nodes with the same key carry different
            debug handles.
    """
    prev_decomp_op_to_debug_handle_map: dict[str, int] = {}

    def _extract_debug_handles_with_prev_decomp_op_from_node(node):
        if (
            CUSTOM_KEY in node.meta
            and NUMERIC_DEBUG_HANDLE_KEY in node.meta[CUSTOM_KEY]
        ):
            prev_decomp_op = str(node.meta.get("nn_module_stack"))
            debug_handle = node.meta[CUSTOM_KEY][NUMERIC_DEBUG_HANDLE_KEY]
            if prev_decomp_op not in prev_decomp_op_to_debug_handle_map:
                prev_decomp_op_to_debug_handle_map[prev_decomp_op] = debug_handle
            else:
                # The message is parenthesised so both f-string pieces belong
                # to the assert; previously the second literal was a separate
                # no-op statement and its placeholder was never interpolated.
                assert (
                    prev_decomp_op_to_debug_handle_map[prev_decomp_op]
                    == debug_handle
                ), (
                    f"Node {node} has different debug handle {debug_handle} "
                    f"than previous node sharing the same decomp op {prev_decomp_op}"
                )

    bfs_trace_with_node_process(
        model, _extract_debug_handles_with_prev_decomp_op_from_node
    )
    return prev_decomp_op_to_debug_handle_map

class TestNumericDebuggerInfra(PT2ENumericDebuggerTestCase):
@unittest.skip(
"torch._dynamo.exc.FailOnRecompileLimitHit: recompile_limit reached with one_graph=True. Excessive recompilations can degrade performance due to the compilation overhead of each recompilation. To monitor recom..."
)
Expand All @@ -113,36 +52,6 @@ def test_control_flow(self):

self.assertEqual(len(set(debug_handle_map.values())), len(debug_handle_map))

def test_quantize_pt2e_preserve_handle(self):
    """Check debug-handle counts after ``prepare_pt2e`` and after ``convert_pt2e``."""
    float_model = TestHelperModules.Conv2dThenConv1d()
    example_inputs = float_model.example_inputs()
    ep = export_for_training(float_model, example_inputs, strict=True)
    generate_numeric_debug_handle(ep)
    exported = ep.module()

    quantizer = XNNPACKQuantizer().set_global(
        get_symmetric_quantization_config(is_per_channel=False)
    )
    prepared = prepare_pt2e(exported, quantizer)

    # Three ids appear twice because the id is copied from a node to its
    # output observer: torch.ops.aten.conv2d.default,
    # torch.ops.aten.squeeze.dim and torch.ops.aten.conv1d.default.
    handle_counts = Counter(self._extract_debug_handles(prepared).values())
    for dh_id in [1, 2, 3]:
        self.assertEqual(handle_counts[dh_id], 2)

    # Run once on the example inputs before converting.
    prepared(*example_inputs)
    converted = convert_pt2e(prepared)
    self._assert_each_node_has_debug_handle(ep)

    # The same ids are repeated, because the id is copied from the
    # observer/fake_quant to the dequantize node.
    handle_counts = Counter(self._extract_debug_handles(converted).values())
    for dh_id in [1, 2, 3]:
        self.assertEqual(handle_counts[dh_id], 2)

def test_copy_preserve_handle(self):
m = TestHelperModules.Conv2dThenConv1d()
example_inputs = m.example_inputs()
Expand Down Expand Up @@ -262,61 +171,6 @@ def test_prepare_for_propagation_comparison(self):
self.assertTrue("conv2d" in [logger.node_name for logger in loggers])
self.assertEqual(res, ref)

def test_extract_results_from_loggers(self):
    """Compare logged reference vs. quantized results; each first SQNR must be >= 35."""
    float_model = TestHelperModules.Conv2dThenConv1d()
    example_inputs = float_model.example_inputs()
    ep = export_for_training(float_model, example_inputs, strict=True)
    generate_numeric_debug_handle(ep)
    exported = ep.module()
    # The reference logger is built before the module goes through prepare/convert.
    ref_logged = prepare_for_propagation_comparison(exported)

    quantizer = XNNPACKQuantizer().set_global(
        get_symmetric_quantization_config(is_per_channel=False)
    )
    prepared = prepare_pt2e(exported, quantizer)
    prepared(*example_inputs)
    converted = convert_pt2e(prepared)
    quant_logged = prepare_for_propagation_comparison(converted)

    ref_logged(*example_inputs)
    quant_logged(*example_inputs)
    ref_results = extract_results_from_loggers(ref_logged)
    quant_results = extract_results_from_loggers(quant_logged)

    for summary in compare_results(ref_results, quant_results).values():
        if len(summary.results) == 0:
            continue
        self.assertGreaterEqual(summary.results[0].sqnr, 35)

def test_extract_results_from_loggers_list_output(self):
    """Same SQNR >= 35 check for ``Conv2dWithSplit``, where ``sqnr`` may be a list."""
    float_model = TestHelperModules.Conv2dWithSplit()
    example_inputs = float_model.example_inputs()
    ep = export_for_training(float_model, example_inputs, strict=True)
    generate_numeric_debug_handle(ep)
    exported = ep.module()
    # The reference logger is built before the module goes through prepare/convert.
    ref_logged = prepare_for_propagation_comparison(exported)

    quantizer = XNNPACKQuantizer().set_global(
        get_symmetric_quantization_config(is_per_channel=False)
    )
    prepared = prepare_pt2e(exported, quantizer)
    prepared(*example_inputs)
    converted = convert_pt2e(prepared)
    quant_logged = prepare_for_propagation_comparison(converted)

    ref_logged(*example_inputs)
    quant_logged(*example_inputs)
    ref_results = extract_results_from_loggers(ref_logged)
    quant_results = extract_results_from_loggers(quant_logged)

    for summary in compare_results(ref_results, quant_results).values():
        if len(summary.results) == 0:
            continue
        sqnr = summary.results[0].sqnr
        # Normalise to a list so scalar and list outputs share one check.
        sqnr_values = sqnr if isinstance(sqnr, list) else [sqnr]
        for value in sqnr_values:
            self.assertGreaterEqual(value, 35)

def test_added_node_gets_unique_id(self) -> None:
m = TestHelperModules.Conv2dThenConv1d()
example_inputs = m.example_inputs()
Expand Down
72 changes: 71 additions & 1 deletion torchao/testing/pt2e/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import copy
import unittest
from typing import Dict

import torch
from torch.ao.quantization.backend_config import (
Expand All @@ -19,13 +20,19 @@
NodeSpec,
QuantizationTestCase,
)
from torch.testing._internal.common_utils import TestCase

from torchao.quantization.pt2e import (
CUSTOM_KEY,
NUMERIC_DEBUG_HANDLE_KEY,
)
from torchao.quantization.pt2e.graph_utils import bfs_trace_with_node_process
from torchao.quantization.pt2e.quantize_pt2e import (
convert_pt2e,
prepare_pt2e,
prepare_qat_pt2e,
)
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_7

if TORCH_VERSION_AT_LEAST_2_5:
from torch.export import export_for_training
Expand Down Expand Up @@ -133,3 +140,66 @@ def _test_quantizer(
fx_quant_output = m_fx(*example_inputs)
self.assertEqual(fx_quant_output, pt2_quant_output)
return m


@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_7, "Requires torch 2.7+")
class PT2ENumericDebuggerTestCase(TestCase):
    """
    Base test case class for PT2E numeric debugger tests containing common utility functions
    for numeric debugging functionality.

    Each helper walks a model with ``bfs_trace_with_node_process`` and inspects
    ``node.meta[CUSTOM_KEY][NUMERIC_DEBUG_HANDLE_KEY]`` on the nodes passed to
    its callback.
    """

    def _assert_each_node_has_debug_handle(self, model) -> None:
        """Assert that each node in the model has a debug handle.

        Fails the test (via ``assertTrue``) on the first node whose ``meta``
        lacks ``CUSTOM_KEY`` or whose custom metadata lacks
        ``NUMERIC_DEBUG_HANDLE_KEY``.
        """

        def _assert_node_has_debug_handle(node):
            self.assertTrue(
                CUSTOM_KEY in node.meta
                and NUMERIC_DEBUG_HANDLE_KEY in node.meta[CUSTOM_KEY],
                f"Node {node} doesn't have debug handle",
            )

        bfs_trace_with_node_process(model, _assert_node_has_debug_handle)

    def _extract_debug_handles(self, model) -> Dict[str, int]:
        """Extract debug handles from all nodes in the model.

        Returns:
            Dict from ``str(node)`` to that node's debug handle. Nodes without
            a debug handle are omitted.
        """
        debug_handle_map: Dict[str, int] = {}

        def _extract_debug_handles_from_node(node):
            # The dict is only mutated, never rebound, so no ``nonlocal`` is needed.
            if (
                CUSTOM_KEY in node.meta
                and NUMERIC_DEBUG_HANDLE_KEY in node.meta[CUSTOM_KEY]
            ):
                debug_handle_map[str(node)] = node.meta[CUSTOM_KEY][
                    NUMERIC_DEBUG_HANDLE_KEY
                ]

        bfs_trace_with_node_process(model, _extract_debug_handles_from_node)
        return debug_handle_map

    def _extract_debug_handles_with_prev_decomp_op(self, model) -> Dict[str, int]:
        """Extract debug handles with previous decomposition operation mapping.

        The key for each node is ``str(node.meta.get("nn_module_stack"))``.
        Nodes without a debug handle are ignored.

        Returns:
            Dict from the stringified ``nn_module_stack`` to the debug handle.

        Raises:
            AssertionError: if two nodes sharing the same key carry different
                debug handles.
        """
        prev_decomp_op_to_debug_handle_map: Dict[str, int] = {}

        def _extract_debug_handles_with_prev_decomp_op_from_node(node):
            if (
                CUSTOM_KEY in node.meta
                and NUMERIC_DEBUG_HANDLE_KEY in node.meta[CUSTOM_KEY]
            ):
                prev_decomp_op = str(node.meta.get("nn_module_stack"))
                debug_handle = node.meta[CUSTOM_KEY][NUMERIC_DEBUG_HANDLE_KEY]
                if prev_decomp_op not in prev_decomp_op_to_debug_handle_map:
                    prev_decomp_op_to_debug_handle_map[prev_decomp_op] = debug_handle
                else:
                    # The message is parenthesised so both f-string pieces
                    # belong to the assert; previously the second literal was
                    # a separate no-op statement and its placeholder was never
                    # interpolated.
                    assert (
                        prev_decomp_op_to_debug_handle_map[prev_decomp_op]
                        == debug_handle
                    ), (
                        f"Node {node} has different debug handle {debug_handle} "
                        f"than previous node sharing the same decomp op {prev_decomp_op}"
                    )

        bfs_trace_with_node_process(
            model, _extract_debug_handles_with_prev_decomp_op_from_node
        )
        return prev_decomp_op_to_debug_handle_map
Loading