
Commit d068839

Add some detection of non-sequential neural networks
1 parent 069eb8b commit d068839

File tree

src/gurobi_ml/onnx/onnx_model.py
tests/test_onnx/test_onnx_exceptions.py

2 files changed, +188 -0 lines changed

src/gurobi_ml/onnx/onnx_model.py

Lines changed: 66 additions & 0 deletions
@@ -81,6 +81,69 @@ def __init__(self, gp_model, predictor, input_vars, output_vars=None, **kwargs):
 
         super().__init__(gp_model, predictor, input_vars, output_vars, **kwargs)
 
+    def _validate_sequential_architecture(self, graph, init_map):
+        """Validate that the graph has a sequential architecture.
+
+        Raises NoModel if the graph contains:
+        - Skip connections (same intermediate value used by multiple nodes)
+        - Residual connections (Add nodes combining non-bias values)
+        - Non-sequential topology
+        """
+        # Build usage map: which nodes use each tensor
+        tensor_usage = {}
+        for node in graph.node:
+            for inp in node.input:
+                if inp not in tensor_usage:
+                    tensor_usage[inp] = []
+                tensor_usage[inp].append(node.name)
+
+        # Check 1: Input should only be used by one node (first layer)
+        for graph_input in graph.input:
+            input_name = graph_input.name
+            if input_name in tensor_usage and len(tensor_usage[input_name]) > 1:
+                raise NoModel(
+                    graph,
+                    f"Non-sequential architecture detected: input '{input_name}' is used by multiple nodes {tensor_usage[input_name]}. "
+                    "Skip connections and residual architectures are not supported.",
+                )
+
+        # Check 2: Each intermediate node output should be used by at most one node
+        # (except for the final output which may not be used by any node)
+        for node in graph.node:
+            for output in node.output:
+                if output in tensor_usage and len(tensor_usage[output]) > 1:
+                    raise NoModel(
+                        graph,
+                        f"Non-sequential architecture detected: node '{node.name}' output '{output}' is used by multiple nodes {tensor_usage[output]}. "
+                        "Skip connections and residual architectures are not supported.",
+                    )
+
+        # Check 3: Add nodes should only be used for bias addition (MatMul+Add pattern)
+        # Not for combining two computed branches (residual connections)
+        for node in graph.node:
+            if node.op_type == "Add":
+                # An Add is valid if one of its inputs is an initializer (bias)
+                # and the other is from a MatMul
+                inputs = list(node.input)
+                if len(inputs) != 2:
+                    continue
+
+                # Check if this is a MatMul+Add pattern (one input from MatMul, one is initializer)
+                is_bias_add = False
+                for inp in inputs:
+                    if inp in init_map:
+                        # One input is a constant (bias)
+                        is_bias_add = True
+                        break
+
+                if not is_bias_add:
+                    # Both inputs are computed values - this is a residual connection
+                    raise NoModel(
+                        graph,
+                        f"Non-sequential architecture detected: Add node '{node.name}' combines two computed values {inputs}. "
+                        "Residual connections are not supported.",
+                    )
+
     def _parse_mlp(self, model: onnx.ModelProto) -> list[_ONNXLayer]:
         """Parse a limited subset of ONNX graphs representing MLPs.
 
@@ -106,6 +169,9 @@ def _get_attr(node, name, default=None):
                     return float(a.f)
             return default
 
+        # Validate that the graph is sequential (no skip connections or residual adds)
+        self._validate_sequential_architecture(graph, init_map)
+
         # Build a map from output name to node for easier traversal
         output_to_node = {}
         for node in graph.node:
tests/test_onnx/test_onnx_exceptions.py

Lines changed: 122 additions & 0 deletions
@@ -27,3 +27,125 @@ def test_unsupported_op(self):
         x = m.addMVar(example.shape, lb=0.0, ub=1.0, name="x")
         with self.assertRaises(NoModel):
             add_predictor_constr(m, model, x)
+
+    def test_skip_connection_rejected(self):
+        # Build a model with skip connection: input used by multiple nodes
+        n_in, n_hidden, n_out = 4, 8, 2
+
+        W1 = np.random.randn(n_in, n_hidden).astype(np.float32)
+        b1 = np.random.randn(n_hidden).astype(np.float32)
+        W2 = np.random.randn(n_hidden, n_out).astype(np.float32)
+        b2 = np.random.randn(n_out).astype(np.float32)
+        W_skip = np.random.randn(n_in, n_out).astype(np.float32)
+        b_skip = np.random.randn(n_out).astype(np.float32)
+
+        X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [None, n_in])
+        Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [None, n_out])
+
+        init_W1 = helper.make_tensor(
+            "W1", TensorProto.FLOAT, W1.T.shape, W1.T.flatten()
+        )
+        init_b1 = helper.make_tensor("b1", TensorProto.FLOAT, b1.shape, b1)
+        init_W2 = helper.make_tensor(
+            "W2", TensorProto.FLOAT, W2.T.shape, W2.T.flatten()
+        )
+        init_b2 = helper.make_tensor("b2", TensorProto.FLOAT, b2.shape, b2)
+        init_W_skip = helper.make_tensor(
+            "W_skip", TensorProto.FLOAT, W_skip.T.shape, W_skip.T.flatten()
+        )
+        init_b_skip = helper.make_tensor(
+            "b_skip", TensorProto.FLOAT, b_skip.shape, b_skip
+        )
+
+        # Main path
+        gemm1 = helper.make_node("Gemm", ["X", "W1", "b1"], ["H1"], transB=1)
+        relu1 = helper.make_node("Relu", ["H1"], ["A1"])
+        gemm2 = helper.make_node("Gemm", ["A1", "W2", "b2"], ["branch1"], transB=1)
+
+        # Skip connection path - uses X again!
+        gemm_skip = helper.make_node(
+            "Gemm", ["X", "W_skip", "b_skip"], ["branch2"], transB=1
+        )
+
+        # Combine branches (residual add)
+        add = helper.make_node("Add", ["branch1", "branch2"], ["Y"])
+
+        graph = helper.make_graph(
+            [gemm1, relu1, gemm2, gemm_skip, add],
+            "SkipConnectionMLP",
+            [X],
+            [Y],
+            [init_W1, init_b1, init_W2, init_b2, init_W_skip, init_b_skip],
+        )
+
+        model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 18)])
+        model.ir_version = 9
+        onnx.checker.check_model(model)
+
+        m = gp.Model()
+        x = m.addMVar((n_in,), lb=-1.0, ub=1.0, name="x")
+        with self.assertRaises(NoModel) as cm:
+            add_predictor_constr(m, model, x)
+
+        # Verify the error message mentions skip connections
+        self.assertIn("skip connection", str(cm.exception).lower())
+
+    def test_residual_connection_rejected(self):
+        # Build a model with residual connection: intermediate value used by multiple nodes
+        n_in, n_hidden, n_out = 4, 8, 2
+
+        W1 = np.random.randn(n_in, n_hidden).astype(np.float32)
+        b1 = np.random.randn(n_hidden).astype(np.float32)
+        W2a = np.random.randn(n_hidden, n_out).astype(np.float32)
+        b2a = np.random.randn(n_out).astype(np.float32)
+        W2b = np.random.randn(n_hidden, n_out).astype(np.float32)
+        b2b = np.random.randn(n_out).astype(np.float32)
+
+        X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [None, n_in])
+        Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [None, n_out])
+
+        init_W1 = helper.make_tensor(
+            "W1", TensorProto.FLOAT, W1.T.shape, W1.T.flatten()
+        )
+        init_b1 = helper.make_tensor("b1", TensorProto.FLOAT, b1.shape, b1)
+        init_W2a = helper.make_tensor(
+            "W2a", TensorProto.FLOAT, W2a.T.shape, W2a.T.flatten()
+        )
+        init_b2a = helper.make_tensor("b2a", TensorProto.FLOAT, b2a.shape, b2a)
+        init_W2b = helper.make_tensor(
+            "W2b", TensorProto.FLOAT, W2b.T.shape, W2b.T.flatten()
+        )
+        init_b2b = helper.make_tensor("b2b", TensorProto.FLOAT, b2b.shape, b2b)
+
+        # Shared layer
+        gemm1 = helper.make_node("Gemm", ["X", "W1", "b1"], ["H1"], transB=1)
+        relu1 = helper.make_node("Relu", ["H1"], ["A1"])
+
+        # Branch 1 - uses A1
+        gemm2a = helper.make_node("Gemm", ["A1", "W2a", "b2a"], ["branch1"], transB=1)
+
+        # Branch 2 - also uses A1!
+        gemm2b = helper.make_node("Gemm", ["A1", "W2b", "b2b"], ["branch2"], transB=1)
+
+        # Combine branches
+        add = helper.make_node("Add", ["branch1", "branch2"], ["Y"])
+
+        graph = helper.make_graph(
+            [gemm1, relu1, gemm2a, gemm2b, add],
+            "ResidualMLP",
+            [X],
+            [Y],
+            [init_W1, init_b1, init_W2a, init_b2a, init_W2b, init_b2b],
+        )
+
+        model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 18)])
+        model.ir_version = 9
+        onnx.checker.check_model(model)
+
+        m = gp.Model()
+        x = m.addMVar((n_in,), lb=-1.0, ub=1.0, name="x")
+        with self.assertRaises(NoModel) as cm:
+            add_predictor_constr(m, model, x)
+
+        # Verify the error message mentions the architecture issue
+        self.assertIn("non-sequential", str(cm.exception).lower())
