@@ -399,10 +399,10 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
399399 auto shifted_buffer_x_int4 = channel_x_buffers.buffer () + token_idx_in_buffer * hidden_int4;
400400 auto shifted_recv_x_int4 = recv_x + static_cast <int64_t >(total_offset + chunk_idx) * hidden_int4;
401401#ifndef DISABLE_SM90_FEATURES
402- if ( elect_one_sync ()) {
403- # pragma unroll
404- for ( int i = 0 ; i < 2 ; ++ i) {
405- tma_store_wait< 0 >();
402+ # pragma unroll
403+ for ( int i = 0 ; i < 2 ; ++ i) {
404+ tma_store_wait< 0 >();
405+ if ( elect_one_sync ()) {
406406 tma_load_1d (tma_buffer, shifted_buffer_x_int4 + i * half_hidden_int4, tma_mbarrier, half_hidden_bytes);
407407 mbarrier_arrive_and_expect_tx (tma_mbarrier, half_hidden_bytes);
408408 mbarrier_wait (tma_mbarrier, tma_phase);
@@ -589,6 +589,7 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
589589
590590 constexpr int kDtypePerInt4 = sizeof (int4 ) / sizeof (dtype_t );
591591 int hidden_int4 = hidden * sizeof (dtype_t ) / sizeof (int4 );
592+ int hidden_int4_aligned = align_down (hidden_int4, 32 );
592593 auto x_int4 = reinterpret_cast <const int4 *>(x);
593594 auto bias_0_int4 = reinterpret_cast <const int4 *>(bias_0);
594595 auto bias_1_int4 = reinterpret_cast <const int4 *>(bias_1);
@@ -791,8 +792,7 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
791792
792793 // Wait shared memory release
793794#ifndef DISABLE_SM90_FEATURES
794- if (elect_one_sync ())
795- tma_store_wait<0 >();
795+ tma_store_wait<0 >();
796796 __syncwarp ();
797797#endif
798798
@@ -837,26 +837,29 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
837837 out_dtypes[j] = static_cast <dtype_t >(values[j]);
838838
839839#ifndef DISABLE_SM90_FEATURES
840- // Wait TMA arrival
841- if ( elect_one_sync ())
840+ if (i < hidden_int4_aligned) {
841+ // Wait TMA arrival
842842 tma_store_wait<kNumStages - 1 >();
843- __syncwarp ();
844-
845- // Write into TMA buffer
846- auto tma_stage_idx = (i / 32 ) % kNumStages ;
847- reinterpret_cast <int4 *>(tma_buffer)[tma_stage_idx * 32 + lane_id] = out_int4;
848-
849- // Issue TMA
850- tma_store_fence ();
851- __syncwarp ();
852- if (elect_one_sync ()) {
853- auto tma_bytes = min (32 , hidden_int4 - i) * static_cast <int >(sizeof (int4 ));
854- tma_store_1d (reinterpret_cast <int4 *>(tma_buffer) + tma_stage_idx * 32 ,
855- recv_int4 + token_idx * hidden_int4 + i, tma_bytes, false );
843+ __syncwarp ();
844+
845+ // Write into TMA buffer
846+ auto tma_stage_idx = (i / 32 ) % kNumStages ;
847+ reinterpret_cast <int4 *>(tma_buffer)[tma_stage_idx * 32 + lane_id] = out_int4;
848+
849+ // Issue TMA
850+ tma_store_fence ();
851+ __syncwarp ();
852+ if (elect_one_sync ()) {
853+ auto tma_bytes = min (32 , hidden_int4 - i) * static_cast <int >(sizeof (int4 ));
854+ tma_store_1d (reinterpret_cast <int4 *>(tma_buffer) + tma_stage_idx * 32 ,
855+ recv_int4 + token_idx * hidden_int4 + i, tma_bytes, false );
856+ }
857+ __syncwarp ();
858+ } else {
859+ #endif
860+ recv_int4[token_idx * hidden_int4 + i] = out_int4;
861+ #ifndef DISABLE_SM90_FEATURES
856862 }
857- __syncwarp ();
858- #else
859- recv_int4[token_idx * hidden_int4 + i] = out_int4;
860863#endif
861864 }
862865
0 commit comments