@@ -1582,10 +1582,6 @@ def round_up(x, alignment):
                            intermediate_size_per_partition_padded //
                            weight_vec_size)
 
-        print("creating weights for NVFP4 fused MoE")
-        print("padded w3_w1_weight_shape:", w3_w1_weight_shape)
-        print("padded w2_weight_shape:", w2_weight_shape)
-
         # Divide by 4 because we use int32 to pack 4 fp8 values
         # column parallel
         w3_w1_weight_scale = nn.Parameter(
@@ -1596,10 +1592,6 @@ def round_up(x, alignment):
                        dtype=block_scales_dtype),
             requires_grad=False)
         module.register_parameter("w3_w1_weight_scale", w3_w1_weight_scale)
-        print("w3_w1_hidden_size_padded:", w3_w1_hidden_size_padded)
-        print("module.scaling_vector_size:", module.scaling_vector_size)
-        print("block_scales_vec_size:", block_scales_vec_size)
-        print("w3_w1_weight_scale shape:", w3_w1_weight_scale.shape)
 
         # row parallel
         w2_weight_scale = nn.Parameter(
@@ -1610,7 +1602,6 @@ def round_up(x, alignment):
                        dtype=block_scales_dtype),
             requires_grad=False)
         module.register_parameter("w2_weight_scale", w2_weight_scale)
-        print("w2_weight_scale shape:", w2_weight_scale.shape)
 
         fc31_input_scale = nn.Parameter(torch.tensor(1., dtype=torch.float32),
                                         requires_grad=False)
@@ -1922,18 +1913,14 @@ def _fp4_quantize_pad_unpad(weight: torch.Tensor, alignment: Tuple[int, int]):
     assert weight.device.type == 'cuda', "Only cuda tensor is supported."
     assert weight.dtype == torch.bfloat16, "Only bfloat16 tensor is supported."
 
-    print("Original shape:", weight.shape)
     padding_dim_0 = (alignment[0] -
                      weight.shape[0] % alignment[0]) % alignment[0]
     padding_dim_1 = (alignment[1] -
                      weight.shape[1] % alignment[1]) % alignment[1]
-    print("Padding:", (padding_dim_0, padding_dim_1))
     weight_padded = torch.nn.functional.pad(
         weight, (0, padding_dim_1, 0, padding_dim_0))
-    print("Padded shape:", weight_padded.shape)
 
     global_scale_factor = (448 * 6) / weight.abs().max().float()
-    print("Global scale factor:", global_scale_factor.item())
 
     weight_nvfp4, block_scale_factor = torch.ops.trtllm.fp4_quantize(
         weight,
@@ -1942,8 +1929,6 @@ def _fp4_quantize_pad_unpad(weight: torch.Tensor, alignment: Tuple[int, int]):
         sfUseUE8M0=False,
         isSfSwizzledLayout=False)
     block_scale_factor = block_scale_factor.view(weight.shape[0], -1)
-    print("Weight nvfp4 shape:", weight_nvfp4.shape)
-    print("Block scale factor shape:", block_scale_factor.shape)
     return weight_nvfp4, global_scale_factor, block_scale_factor
 
 
@@ -1993,10 +1978,6 @@ def load_expert_w3_w1_weight(self, module: torch.nn.Module,
         dst_w3_w1_weight_gpu = dst_w3_w1_weight if dst_on_gpu else dst_w3_w1_weight.cuda(
         )
 
-        print("w1 shape and dtype before pad:", w1_weight.shape,
-              w1_weight.dtype)
-        print("w3 shape and dtype before pad:", w3_weight.shape,
-              w3_weight.dtype)
         alignment = _get_weight_alignment(self.weight_alignment,
                                           module.scaling_vector_size,
                                           module.tp_size, w1_weight.shape[0])
@@ -2017,8 +1998,6 @@ def load_expert_w3_w1_weight(self, module: torch.nn.Module,
             assert len(w3_weight.shape) == 1
         w1_weight = maybe_pad_for_weights(w1_weight, alignment).float()
         w3_weight = maybe_pad_for_weights(w3_weight, alignment).float()
-        print("w1 shape and dtype after pad:", w1_weight.shape, w1_weight.dtype)
-        print("w3 shape and dtype after pad:", w3_weight.shape, w3_weight.dtype)
 
         w1_weight_shard = load_weight_shard(w1_weight,
                                             module.tp_size,
@@ -2063,8 +2042,6 @@ def load_expert_w2_weight(self, module: torch.nn.Module,
         dst_w2_weight_gpu = dst_w2_weight if dst_on_gpu else dst_w2_weight.cuda(
         )
 
-        print("w2 shape and dtype before pad:", w2_weight.shape,
-              w2_weight.dtype)
         shard_w2_weight_dim = 2 * w2_weight.shape[1] if len(
             w2_weight.shape) == 2 else w2_weight.shape[0]
         alignment = _get_weight_alignment(self.weight_alignment,
@@ -2077,7 +2054,6 @@ def load_expert_w2_weight(self, module: torch.nn.Module,
         else:
             assert len(w2_weight.shape) == 1
             w2_weight = maybe_pad_for_weights(w2_weight, self.weight_alignment)
-        print("w2 shape and dtype after pad:", w2_weight.shape, w2_weight.dtype)
 
         w2_weight_shard = load_weight_shard(w2_weight,
                                             module.tp_size,
@@ -2119,10 +2095,6 @@ def load_expert_w3_w1_weight_scale_nvfp4(
         dst_w3_w1_weight_scale_gpu = dst_w3_w1_weight_scale if dst_on_gpu else dst_w3_w1_weight_scale.cuda(
         )
 
-        print("w1 scale shape and dtype before pad:", w1_weight_scale.shape,
-              w1_weight_scale.dtype)
-        print("w3 scale shape and dtype before pad:", w3_weight_scale.shape,
-              w3_weight_scale.dtype)
         alignment = _get_weight_alignment(self.weight_alignment,
                                           module.scaling_vector_size,
                                           module.tp_size,
@@ -2135,10 +2107,6 @@ def load_expert_w3_w1_weight_scale_nvfp4(
             w3_weight_scale,
             self.input_hidden_alignment // module.scaling_vector_size,
             alignment)
-        print("w1 scale shape and dtype after pad:", w1_weight_scale.shape,
-              w1_weight_scale.dtype)
-        print("w3 scale shape and dtype after pad:", w3_weight_scale.shape,
-              w3_weight_scale.dtype)
 
         w1_weight_scale = load_weight_shard(w1_weight_scale,
                                             module.tp_size,
@@ -2197,17 +2165,13 @@ def load_expert_w2_weight_scale_nvfp4(self,
         dst_w2_weight_scale_gpu = dst_w2_weight_scale if dst_on_gpu else dst_w2_weight_scale.cuda(
         )
 
-        print("w2 scale shape and dtype before pad:", w2_weight_scale.shape,
-              w2_weight_scale.dtype)
         alignment = _get_weight_alignment(self.weight_alignment,
                                           module.scaling_vector_size,
                                           module.tp_size,
                                           w2_weight_scale.shape[-1])
         w2_weight_scale = maybe_pad_for_weights(
             w2_weight_scale, alignment // module.scaling_vector_size,
             self.weight_alignment)
-        print("w2 scale shape and dtype after pad:", w2_weight_scale.shape,
-              w2_weight_scale.dtype)
 
         w2_weight_scale = load_weight_shard(w2_weight_scale,
                                             module.tp_size,
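
For context, the padding that the removed debug prints were tracing follows the usual pad-to-multiple pattern, and the NVFP4 global scale in the diff is (448 * 6) / amax (FP8 E4M3 max times FP4 E2M1 max). Below is a minimal, self-contained sketch of just that padding step; the helper name pad_to_alignment is illustrative and not part of the module.

from typing import Tuple

import torch
import torch.nn.functional as F


def pad_to_alignment(weight: torch.Tensor,
                     alignment: Tuple[int, int]) -> torch.Tensor:
    # Zero-pad a 2D tensor so each dim becomes a multiple of its alignment,
    # mirroring the padding step in _fp4_quantize_pad_unpad above.
    pad_0 = (alignment[0] - weight.shape[0] % alignment[0]) % alignment[0]
    pad_1 = (alignment[1] - weight.shape[1] % alignment[1]) % alignment[1]
    # F.pad pads the last dim first: (left, right) for dim -1, then dim -2.
    return F.pad(weight, (0, pad_1, 0, pad_0))


if __name__ == "__main__":
    w = torch.randn(300, 100, dtype=torch.bfloat16)
    assert pad_to_alignment(w, (128, 64)).shape == (384, 128)
    # Per-tensor global scale applied before FP4 quantization, as in the diff.
    global_scale_factor = (448 * 6) / w.abs().max().float()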