@@ -27,3 +27,125 @@ def test_unsupported_op(self):
         x = m.addMVar(example.shape, lb=0.0, ub=1.0, name="x")
         with self.assertRaises(NoModel):
             add_predictor_constr(m, model, x)
+
+    def test_skip_connection_rejected(self):
+        # Build a model with a skip connection: the input is used by multiple nodes
+        n_in, n_hidden, n_out = 4, 8, 2
+
+        W1 = np.random.randn(n_in, n_hidden).astype(np.float32)
+        b1 = np.random.randn(n_hidden).astype(np.float32)
+        W2 = np.random.randn(n_hidden, n_out).astype(np.float32)
+        b2 = np.random.randn(n_out).astype(np.float32)
+        W_skip = np.random.randn(n_in, n_out).astype(np.float32)
+        b_skip = np.random.randn(n_out).astype(np.float32)
+
+        X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [None, n_in])
+        Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [None, n_out])
+
+        init_W1 = helper.make_tensor(
+            "W1", TensorProto.FLOAT, W1.T.shape, W1.T.flatten()
+        )
+        init_b1 = helper.make_tensor("b1", TensorProto.FLOAT, b1.shape, b1)
+        init_W2 = helper.make_tensor(
+            "W2", TensorProto.FLOAT, W2.T.shape, W2.T.flatten()
+        )
+        init_b2 = helper.make_tensor("b2", TensorProto.FLOAT, b2.shape, b2)
+        init_W_skip = helper.make_tensor(
+            "W_skip", TensorProto.FLOAT, W_skip.T.shape, W_skip.T.flatten()
+        )
+        init_b_skip = helper.make_tensor(
+            "b_skip", TensorProto.FLOAT, b_skip.shape, b_skip
+        )
+
+        # Main path
+        gemm1 = helper.make_node("Gemm", ["X", "W1", "b1"], ["H1"], transB=1)
+        relu1 = helper.make_node("Relu", ["H1"], ["A1"])
+        gemm2 = helper.make_node("Gemm", ["A1", "W2", "b2"], ["branch1"], transB=1)
+
+        # Skip connection path - uses X again!
+        gemm_skip = helper.make_node(
+            "Gemm", ["X", "W_skip", "b_skip"], ["branch2"], transB=1
+        )
+
+        # Combine branches (residual add)
+        add = helper.make_node("Add", ["branch1", "branch2"], ["Y"])
+
+        graph = helper.make_graph(
+            [gemm1, relu1, gemm2, gemm_skip, add],
+            "SkipConnectionMLP",
+            [X],
+            [Y],
+            [init_W1, init_b1, init_W2, init_b2, init_W_skip, init_b_skip],
+        )
+
+        model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 18)])
+        model.ir_version = 9
+        onnx.checker.check_model(model)
+
+        m = gp.Model()
+        x = m.addMVar((n_in,), lb=-1.0, ub=1.0, name="x")
+        with self.assertRaises(NoModel) as cm:
+            add_predictor_constr(m, model, x)
+
+        # Verify the error message mentions skip connections
+        self.assertIn("skip connection", str(cm.exception).lower())
+
+    def test_residual_connection_rejected(self):
+        # Build a model with a residual connection: an intermediate value is used by multiple nodes
+        n_in, n_hidden, n_out = 4, 8, 2
+
+        W1 = np.random.randn(n_in, n_hidden).astype(np.float32)
+        b1 = np.random.randn(n_hidden).astype(np.float32)
+        W2a = np.random.randn(n_hidden, n_out).astype(np.float32)
+        b2a = np.random.randn(n_out).astype(np.float32)
+        W2b = np.random.randn(n_hidden, n_out).astype(np.float32)
+        b2b = np.random.randn(n_out).astype(np.float32)
+
+        X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [None, n_in])
+        Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [None, n_out])
+
+        init_W1 = helper.make_tensor(
+            "W1", TensorProto.FLOAT, W1.T.shape, W1.T.flatten()
+        )
+        init_b1 = helper.make_tensor("b1", TensorProto.FLOAT, b1.shape, b1)
+        init_W2a = helper.make_tensor(
+            "W2a", TensorProto.FLOAT, W2a.T.shape, W2a.T.flatten()
+        )
+        init_b2a = helper.make_tensor("b2a", TensorProto.FLOAT, b2a.shape, b2a)
+        init_W2b = helper.make_tensor(
+            "W2b", TensorProto.FLOAT, W2b.T.shape, W2b.T.flatten()
+        )
+        init_b2b = helper.make_tensor("b2b", TensorProto.FLOAT, b2b.shape, b2b)
+
+        # Shared layer
+        gemm1 = helper.make_node("Gemm", ["X", "W1", "b1"], ["H1"], transB=1)
+        relu1 = helper.make_node("Relu", ["H1"], ["A1"])
+
+        # Branch 1 - uses A1
+        gemm2a = helper.make_node("Gemm", ["A1", "W2a", "b2a"], ["branch1"], transB=1)
+
+        # Branch 2 - also uses A1!
+        gemm2b = helper.make_node("Gemm", ["A1", "W2b", "b2b"], ["branch2"], transB=1)
+
+        # Combine branches
+        add = helper.make_node("Add", ["branch1", "branch2"], ["Y"])
+
+        graph = helper.make_graph(
+            [gemm1, relu1, gemm2a, gemm2b, add],
+            "ResidualMLP",
+            [X],
+            [Y],
+            [init_W1, init_b1, init_W2a, init_b2a, init_W2b, init_b2b],
+        )
+
+        model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 18)])
+        model.ir_version = 9
+        onnx.checker.check_model(model)
+
+        m = gp.Model()
+        x = m.addMVar((n_in,), lb=-1.0, ub=1.0, name="x")
+        with self.assertRaises(NoModel) as cm:
+            add_predictor_constr(m, model, x)
+
+        # Verify the error message mentions the architecture issue
+        self.assertIn("non-sequential", str(cm.exception).lower())
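
Both tests exercise the same structural property: a tensor consumed by more than one node makes the ONNX graph non-sequential. As a minimal sketch of how such a check could be implemented (the helper name is hypothetical, not the parser's actual code; initializers are excluded since weight tensors may legitimately be shared):

from collections import Counter

import onnx


def find_multi_consumer_tensors(model: onnx.ModelProto) -> list[str]:
    """Return tensor names consumed by more than one node (hypothetical helper)."""
    # Weights/biases stored as initializers can be reused freely.
    weights = {init.name for init in model.graph.initializer}
    # Count how many nodes read each non-weight tensor.
    consumers = Counter(
        name
        for node in model.graph.node
        for name in node.input
        if name not in weights
    )
    return [name for name, n in consumers.items() if n > 1]

Under these assumptions, the sketch would return ["X"] for SkipConnectionMLP and ["A1"] for ResidualMLP, matching the reuse each test deliberately constructs.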