@@ -268,60 +268,68 @@ def __init__(
         )
 
         conv_dim = self.d_inner_local_tp + 2 * self.ngroups_local_tp * self.d_state  # x B C
-        with get_cuda_rng_tracker().fork():
-            # weight shape: [conv_dim, 1, d_conv]
-            # bias shape: [conv_dim]
-            self.conv1d = nn.Conv1d(
-                in_channels=conv_dim,
-                out_channels=conv_dim,
-                bias=conv_bias,
-                kernel_size=d_conv,
-                groups=conv_dim,
-                padding=d_conv - 1,
-                device=torch.cuda.current_device(),
-                dtype=config.params_dtype,
-            )
-            setattr(self.conv1d.weight, "tensor_model_parallel", True)
-            setattr(self.conv1d.bias, "tensor_model_parallel", True)
+        # weight shape: [conv_dim, 1, d_conv]
+        # bias shape: [conv_dim]
+        self.conv1d = nn.Conv1d(
+            in_channels=conv_dim,
+            out_channels=conv_dim,
+            bias=conv_bias,
+            kernel_size=d_conv,
+            groups=conv_dim,
+            padding=d_conv - 1,
+            device=torch.cuda.current_device(),
+            dtype=config.params_dtype,
+        )
+        setattr(self.conv1d.weight, "tensor_model_parallel", True)
+        setattr(self.conv1d.bias, "tensor_model_parallel", True)
 
-            if self.conv_init is not None:
+        if self.config.perform_initialization and self.conv_init is not None:
+            with get_cuda_rng_tracker().fork():
                 nn.init.uniform_(self.conv1d.weight, -self.conv_init, self.conv_init)
 
         self.activation = "silu"
         self.act = nn.SiLU()
 
-        with get_cuda_rng_tracker().fork():
-            # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
-            dt = torch.exp(
-                torch.rand(
-                    self.nheads_local_tp,
-                    device=torch.cuda.current_device(),
-                    dtype=config.params_dtype,
-                )
-                * (math.log(dt_max) - math.log(dt_min))
-                + math.log(dt_min)
-            ).clamp(min=dt_init_floor)
-            # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
-            inv_dt = dt + torch.log(-torch.expm1(-dt))
-            self.dt_bias = nn.Parameter(inv_dt)
-            # Our initialization would set all Linear.bias to zero,
-            # need to mark this one as _no_reinit
-            self.dt_bias._no_reinit = True
-            # Just to be explicit. Without this we already don't
-            # put wd on dt_bias because of the check
-            # name.endswith("bias") in param_grouping.py
-            self.dt_bias._no_weight_decay = True
-            setattr(self.dt_bias, "tensor_model_parallel", True)
-
-            # A parameter
-            assert A_init_range[0] > 0 and A_init_range[1] >= A_init_range[0]
-            A = torch.empty(
-                self.nheads_local_tp, dtype=torch.float32, device=torch.cuda.current_device()
-            ).uniform_(*A_init_range)
-            A_log = torch.log(A)  # Keep A_log in fp32
-            self.A_log = nn.Parameter(A_log)
-            self.A_log._no_weight_decay = True
-            setattr(self.A_log, "tensor_model_parallel", True)
+        if self.config.perform_initialization:
+            with get_cuda_rng_tracker().fork():
+                # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
+                dt = torch.exp(
+                    torch.rand(
+                        self.nheads_local_tp,
+                        device=torch.cuda.current_device(),
+                        dtype=config.params_dtype,
+                    )
+                    * (math.log(dt_max) - math.log(dt_min))
+                    + math.log(dt_min)
+                ).clamp(min=dt_init_floor)
+                # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
+                inv_dt = dt + torch.log(-torch.expm1(-dt))
+        else:
+            inv_dt = torch.empty(
+                self.nheads_local_tp, device=torch.cuda.current_device(), dtype=config.params_dtype
+            )
+
+        self.dt_bias = nn.Parameter(inv_dt)
+        # Our initialization would set all Linear.bias to zero,
+        # need to mark this one as _no_reinit
+        self.dt_bias._no_reinit = True
+        # Just to be explicit. Without this we already don't
+        # put wd on dt_bias because of the check
+        # name.endswith("bias") in param_grouping.py
+        self.dt_bias._no_weight_decay = True
+        setattr(self.dt_bias, "tensor_model_parallel", True)
+
+        # A parameter
+        assert A_init_range[0] > 0 and A_init_range[1] >= A_init_range[0]
+        A = torch.empty(
+            self.nheads_local_tp, dtype=torch.float32, device=torch.cuda.current_device()
+        )
+        if self.config.perform_initialization:
+            A = A.uniform_(*A_init_range)
+        A_log = torch.log(A)  # Keep A_log in fp32
+        self.A_log = nn.Parameter(A_log)
+        self.A_log._no_weight_decay = True
+        setattr(self.A_log, "tensor_model_parallel", True)
 
         # D "skip" parameter
         self.D = nn.Parameter(
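For reference, the snippet below is a minimal, self-contained sketch of the inverse-softplus dt_bias initialization that the hunk above moves behind config.perform_initialization. It is not part of the commit: nheads, dt_min, dt_max, and dt_init_floor are illustrative placeholder values, and it runs on CPU in the default dtype rather than under Megatron's CUDA RNG tracker.

import math

import torch
import torch.nn.functional as F

nheads = 8            # placeholder for self.nheads_local_tp
dt_min, dt_max = 1e-3, 0.1
dt_init_floor = 1e-4

# Sample dt log-uniformly in [dt_min, dt_max], then clamp to the floor value.
dt = torch.exp(
    torch.rand(nheads) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min)
).clamp(min=dt_init_floor)

# Inverse of softplus: softplus(x) = log(1 + exp(x)), so x = dt + log(-expm1(-dt)).
inv_dt = dt + torch.log(-torch.expm1(-dt))

# Storing inv_dt as the parameter keeps dt_bias unconstrained while
# F.softplus(dt_bias) recovers a value in the intended [dt_min, dt_max] range.
assert torch.allclose(F.softplus(inv_dt), dt, atol=1e-5)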