@@ -227,12 +227,6 @@ def _create_output_vars(self, input_vars):
227227 int (output_shape_1 ),
228228 self .channels ,
229229 )
230- print (
231- f"Conv2D layer with input shape { input_vars .shape } gives output shape { output_shape } "
232- )
233- print (
234- f" kernel size { self .kernel_size } , stride { self .strides } , padding { self .padding } "
235- )
236230 rval = self .gp_model .addMVar (output_shape , lb = - gp .GRB .INFINITY , name = "act" )
237231 self .gp_model .update ()
238232 return rval
@@ -242,7 +236,8 @@ def _mip_model(self, **kwargs):
242236 model = self .gp_model
243237 model .update ()
244238
245- (_ , height , width , _ ) = self .input .shape
239+ (_ , height , width , in_c ) = self .input .shape
240+ out_n , out_h , out_w , out_c = self .output .shape
246241 mixing = self .gp_model .addMVar (
247242 self .output .shape ,
248243 lb = - gp .GRB .INFINITY ,
@@ -254,25 +249,24 @@ def _mip_model(self, **kwargs):
254249
255250 assert self .padding == "valid"
256251
257- # Here comes the complicated loop...
258- # I am sure there is a better way but this is a pedestrian version
259- kernel_w , kernel_h = self .kernel_size
260- stride_h , stride_w = self .strides
261- for k in range (self .channels ):
262- for out_i , i in enumerate (range (0 , height - kernel_h + 1 , stride_h )):
263- if i + kernel_h > height :
252+ kh , kw = self .kernel_size
253+ sh , sw = self .strides
254+ # Pre-flatten kernel to (kh*kw*in_c, out_c) for efficient batched matmul
255+ coefs_flat = self .coefs .reshape (int (kh * kw * in_c ), int (out_c ))
256+
257+ for oi in range (int (out_h )):
258+ i = oi * sh
259+ if i + kh > height :
260+ continue
261+ for oj in range (int (out_w )):
262+ j = oj * sw
263+ if j + kw > width :
264264 continue
265- for out_j , j in enumerate (range (0 , width - kernel_w + 1 , stride_w )):
266- if j + kernel_w > width :
267- continue
268- self .gp_model .addConstr (
269- mixing [:, out_i , out_j , k ]
270- == (
271- self .input [:, i : i + kernel_h , j : j + kernel_w , :]
272- * self .coefs [:, :, :, k ]
273- ).sum ()
274- + self .intercept [k ]
275- )
265+ # Extract patch (batch, kh, kw, in_c) and flatten to (batch, kh*kw*in_c)
266+ patch = self .input [:, i : i + kh , j : j + kw , :]
267+ patch2d = patch .reshape (int (out_n ), int (kh * kw * in_c ))
268+ expr = patch2d @ coefs_flat + self .intercept
269+ self .gp_model .addConstr (mixing [:, oi , oj , :] == expr )
276270
277271 if "activation" in kwargs :
278272 activation = kwargs ["activation" ]
@@ -313,7 +307,6 @@ def __init__(self, gp_model, output_vars, input_vars, **kwargs):
313307 def _create_output_vars (self , input_vars ):
314308 assert len (input_vars .shape ) >= 2
315309 output_shape = (input_vars .shape [0 ], int (np .prod (input_vars .shape [1 :])))
316- print (f"Flattening { input_vars .shape } into { output_shape } " )
317310 rval = self .gp_model .addMVar (output_shape , lb = - gp .GRB .INFINITY , name = "act" )
318311 self .gp_model .update ()
319312 return rval
@@ -370,9 +363,6 @@ def _create_output_vars(self, input_vars):
370363 )
371364 rval = self .gp_model .addMVar (output_shape , lb = - gp .GRB .INFINITY , name = "act" )
372365 self .gp_model .update ()
373- print (
374- f"MaxPool2D layer with input shape { input_vars .shape } gives output shape { output_shape } "
375- )
376366 return rval
377367
378368 def _mip_model (self , ** kwargs ):
0 commit comments