1
1
import collections
2
+ import keyword
2
3
import typing
3
4
4
5
import nir
@@ -68,6 +69,14 @@ def _default_map_linear(linear: nir.Linear) -> torch.nn.Linear:
68
69
}
69
70
70
71
72
+ def _sanitize_name (name : str ) -> str :
73
+ """Sanitize module name to ensure torch.fx doesn't write any keywords in code"""
74
+ if keyword .iskeyword (name ):
75
+ return "nir_node_" + name
76
+ else :
77
+ return name
78
+
79
+
71
80
def _map_nir_node_to_torch (
72
81
node : nir .NIRNode , node_map : NodeMapType
73
82
) -> typing .Optional [torch .nn .Module ]:
@@ -86,11 +95,13 @@ def _construct_module_dict_recursive(
86
95
for name , node in nir_graph .nodes .items ():
87
96
# Recurse into subgraphs
88
97
if isinstance (node , nir .NIRGraph ):
89
- owning_module [name ] = _construct_module_dict_recursive (node , node_map )
98
+ owning_module [_sanitize_name (name )] = _construct_module_dict_recursive (
99
+ node , node_map
100
+ )
90
101
else :
91
102
mapped_module = _map_nir_node_to_torch (node , node_map = node_map )
92
103
if mapped_module is not None :
93
- owning_module [name ] = mapped_module
104
+ owning_module [_sanitize_name ( name ) ] = mapped_module
94
105
return owning_module
95
106
96
107
@@ -138,6 +149,9 @@ def _construct_fx_graph(
138
149
) -> torch .fx .GraphModule :
139
150
node_outputs = {}
140
151
recursion_counter = collections .Counter (nir_graph .nodes .keys ())
152
+ sanitized_edges = [
153
+ (_sanitize_name (a ), _sanitize_name (b )) for a , b in nir_graph .edges
154
+ ]
141
155
# The maximum iterations per node (see https://github.com/neuromorphs/NIRTorch/pull/28#discussion_r1959343951)
142
156
max_iterations = min (3 , len (recursion_counter ))
143
157
torch_graph = torch .fx .Graph (owning_module )
@@ -151,6 +165,10 @@ def _construct_fx_graph(
151
165
# Loop through all the nodes in the queue
152
166
while module_queue :
153
167
module_name , module = module_queue .popleft ()
168
+ # Sanitize the module name to avoid writing keywords in the generated Python code
169
+ module_name = _sanitize_name (module_name )
170
+
171
+ # Test for number of recursions
154
172
if recursion_counter [module_name ] > max_iterations :
155
173
raise RecursionError (
156
174
f"Module { module_name } has been traversed multiple times"
@@ -185,7 +203,7 @@ def _construct_fx_graph(
185
203
for input_name , output in module .input_type .items ():
186
204
# First fetch the required input nodes
187
205
module_input_nodes = _find_input_nodes (
188
- module_name , edges = nir_graph . edges , node_outputs = node_outputs
206
+ module_name , edges = sanitized_edges , node_outputs = node_outputs
189
207
)
190
208
# If the module uses input that is not yet defined, set the inputs to some dummy value
191
209
# and enqueue the module again for processing (where it's hopefully defined)
@@ -239,7 +257,7 @@ def _construct_fx_graph(
239
257
# Add the raw output to the graph
240
258
node_outputs [f"{ module_name } _raw" ] = output
241
259
# Add the module state to the graph for use as the input to the next module
242
- node_outputs [f" { module_name } " ] = torch_graph .call_method (
260
+ node_outputs [module_name ] = torch_graph .call_method (
243
261
"__getitem__" , (output , 0 )
244
262
)
245
263
# Add the state to the state dictionary
@@ -273,6 +291,7 @@ def nir_to_torch(
273
291
Finally, we wrap the execution in a StatefulInterpreter, to ensure that the internal state of modules are handled correctly.
274
292
275
293
Example:
294
+
276
295
>>> # First, we describe the NIR graph
277
296
>>> nir_avgpool = nir.AvgPool2d(kernel_size=np.array([2, 2]), stride=np.array([1]), padding=np.array([0, 0]))
278
297
>>> nir_linear = nir.Linear(weight=np.ones((5, 5), dtype=np.float32))
0 commit comments