
Commit b926a99

Ensured fx nodes cannot be named after keywords (#32)

Fixes a bug where nodes like nir.IF would be named "if" in the generated Python code.

1 parent 7fb614e commit b926a99
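Why this matters: torch.fx compiles the graph to Python source in which, roughly speaking, each node name becomes a variable and each call_module target becomes an attribute access on the owning module. A node named "if" would therefore produce source along the lines of if = self.if(x), which is a SyntaxError. A minimal sketch of the keyword check the fix relies on (the node names below are hypothetical, not from the repository):

import keyword

# Names that are Python keywords cannot appear as variables in generated code.
for name in ["linear", "if", "lambda"]:
    print(f"{name}: keyword = {keyword.iskeyword(name)}")
# linear: keyword = False
# if: keyword = True
# lambda: keyword = True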

3 files changed: +55, -5 lines

nirtorch/nir_interpreter.py

Lines changed: 23 additions & 4 deletions
@@ -1,4 +1,5 @@
 import collections
+import keyword
 import typing
 
 import nir
@@ -68,6 +69,14 @@ def _default_map_linear(linear: nir.Linear) -> torch.nn.Linear:
 }
 
 
+def _sanitize_name(name: str) -> str:
+    """Sanitize module name to ensure torch.fx doesn't write any keywords in code"""
+    if keyword.iskeyword(name):
+        return "nir_node_" + name
+    else:
+        return name
+
+
 def _map_nir_node_to_torch(
     node: nir.NIRNode, node_map: NodeMapType
 ) -> typing.Optional[torch.nn.Module]:
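The new helper can be sanity-checked in isolation; the function body below is reproduced from the hunk above, while the assertions are illustrative and not part of the commit:

import keyword

def _sanitize_name(name: str) -> str:
    """Sanitize module name to ensure torch.fx doesn't write any keywords in code"""
    if keyword.iskeyword(name):
        return "nir_node_" + name
    else:
        return name

assert _sanitize_name("if") == "nir_node_if"  # keywords are prefixed
assert _sanitize_name("linear") == "linear"  # ordinary names pass through unchanged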
@@ -86,11 +95,13 @@ def _construct_module_dict_recursive(
     for name, node in nir_graph.nodes.items():
         # Recurse into subgraphs
         if isinstance(node, nir.NIRGraph):
-            owning_module[name] = _construct_module_dict_recursive(node, node_map)
+            owning_module[_sanitize_name(name)] = _construct_module_dict_recursive(
+                node, node_map
+            )
         else:
             mapped_module = _map_nir_node_to_torch(node, node_map=node_map)
             if mapped_module is not None:
-                owning_module[name] = mapped_module
+                owning_module[_sanitize_name(name)] = mapped_module
     return owning_module
 
 
@@ -138,6 +149,9 @@ def _construct_fx_graph(
 ) -> torch.fx.GraphModule:
     node_outputs = {}
     recursion_counter = collections.Counter(nir_graph.nodes.keys())
+    sanitized_edges = [
+        (_sanitize_name(a), _sanitize_name(b)) for a, b in nir_graph.edges
+    ]
     # The maximum iterations per node (see https://github.com/neuromorphs/NIRTorch/pull/28#discussion_r1959343951)
     max_iterations = min(3, len(recursion_counter))
     torch_graph = torch.fx.Graph(owning_module)
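Sanitizing the edge list keeps it consistent with the sanitized module names, so that the input lookup below resolves nodes under the same keys the module dictionary uses. A small illustration with a hypothetical edge list (not from the repository), reusing _sanitize_name from above:

# An edge touching a keyword-named node is rewritten to use the sanitized name.
edges = [("input", "if"), ("if", "output")]
sanitized = [(_sanitize_name(a), _sanitize_name(b)) for a, b in edges]
print(sanitized)  # [('input', 'nir_node_if'), ('nir_node_if', 'output')]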
@@ -151,6 +165,10 @@ def _construct_fx_graph(
     # Loop through all the nodes in the queue
     while module_queue:
         module_name, module = module_queue.popleft()
+        # Sanitize the module name to avoid writing keywords in the generated Python code
+        module_name = _sanitize_name(module_name)
+
+        # Test for number of recursions
         if recursion_counter[module_name] > max_iterations:
             raise RecursionError(
                 f"Module {module_name} has been traversed multiple times"
@@ -185,7 +203,7 @@ def _construct_fx_graph(
         for input_name, output in module.input_type.items():
             # First fetch the required input nodes
             module_input_nodes = _find_input_nodes(
-                module_name, edges=nir_graph.edges, node_outputs=node_outputs
+                module_name, edges=sanitized_edges, node_outputs=node_outputs
             )
             # If the module uses input that is not yet defined, set the inputs to some dummy value
             # and enqueue the module again for processing (where it's hopefully defined)
@@ -239,7 +257,7 @@ def _construct_fx_graph(
             # Add the raw output to the graph
             node_outputs[f"{module_name}_raw"] = output
             # Add the module state to the graph for use as the input to the next module
-            node_outputs[f"{module_name}"] = torch_graph.call_method(
+            node_outputs[module_name] = torch_graph.call_method(
                 "__getitem__", (output, 0)
            )
             # Add the state to the state dictionary
@@ -273,6 +291,7 @@ def nir_to_torch(
     Finally, we wrap the execution in a StatefulInterpreter, to ensure that the internal state of modules are handled correctly.
 
     Example:
+
     >>> # First, we describe the NIR graph
     >>> nir_avgpool = nir.AvgPool2d(kernel_size=np.array([2, 2]), stride=np.array([1]), padding=np.array([0, 0]))
     >>> nir_linear = nir.Linear(weight=np.ones((5, 5), dtype=np.float32))

nirtorch/torch_tracer.py

Lines changed: 3 additions & 1 deletion
@@ -71,7 +71,9 @@ def call_module(self, target, args, kwargs):
 def torch_to_nir(
     module: torch.nn.Module,
     module_map: Dict[torch.nn.Module, Callable[[torch.nn.Module], nir.NIRNode]],
-    default_dict: Dict[torch.nn.Module, Callable[[torch.nn.Module], nir.NIRNode]] = DEFAULT_MAP,
+    default_dict: Dict[
+        torch.nn.Module, Callable[[torch.nn.Module], nir.NIRNode]
+    ] = DEFAULT_MAP,
 ) -> nir.NIRGraph:
     """
     Traces a PyTorch module and converts it to a NIR graph using the specified module map.

tests/test_nir_interpreter.py

Lines changed: 29 additions & 0 deletions
@@ -115,6 +115,35 @@ def test_map_conv2d_node():
     assert torch_conv(torch.ones(1, 3, 10, 11)).shape == (1, 2, 5, 3)
 
 
+def test_map_if_node():
+    # Generating code with the "if" keyword can be sensitive
+    # This test ensures that it works properly
+    v_th = np.random.random(1)
+    r = np.random.random(1)
+    nir_node = nir.IF(r=r, v_threshold=v_th)
+
+    class MyIF(torch.nn.Module):
+        def __init__(self, r, v_th):
+            super().__init__()
+            self.r = r
+            self.v_th = v_th
+
+        def forward(self, x, state):
+            return x * self.r - self.v_th, state
+
+    node_map = {
+        nir.IF: lambda node: MyIF(
+            torch.from_numpy(node.r), torch.from_numpy(node.v_threshold)
+        )
+    }
+    torch_if = nir_interpreter._map_nir_node_to_torch(nir_node, node_map)
+    data = torch.rand(1)
+    assert torch.allclose(data * r - v_th, torch_if(data, None)[0])
+
+    torch_if = nir_interpreter.nir_to_torch(nir.NIRGraph.from_list(nir_node), node_map)
+    assert isinstance(torch_if.get_submodule("nir_node_if"), MyIF)
+
+
 def test_map_leaky_stateful_graph_single_module():
     # Test that the graph can handle a single stateful module
     tau = np.random.random(1)
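The final assertion in test_map_if_node exercises the fix end to end: the nir.IF node is registered and retrieved under its sanitized attribute name "nir_node_if" (the "nir_node_" prefix added by _sanitize_name), so the generated fx code never contains a bare if as a variable name.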
