Skip to content

Commit 791f962

Browse files
committed
handle parallel inputs during default PreHookInsertionPoint generation
1 parent 0d5db82 commit 791f962

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

src/nncf/common/insertion_point_graph.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -236,9 +236,10 @@ def _get_default_pre_hook_ip_list(nncf_graph: NNCFGraph) -> list[PreHookInsertio
236236

237237
for pred_node in pred_nodes:
238238
input_edge = nncf_graph.get_edge(pred_node, nncf_node)
239-
allowed_pre_hook_insertion_points.append(
240-
PreHookInsertionPoint(nncf_node.node_name, input_edge.input_port_id)
241-
)
239+
input_port_ids = [input_edge.input_port_id] + input_edge.parallel_input_port_ids
240+
node_name = nncf_node.node_name
241+
for input_port_id in input_port_ids:
242+
allowed_pre_hook_insertion_points.append(PreHookInsertionPoint(node_name, input_port_id))
242243
return allowed_pre_hook_insertion_points
243244

244245
@staticmethod

0 commit comments

Comments
 (0)