
 namespace tensorrt_llm::kernels::mnnvl
 {
+
+// Guard for internal helper functions
+namespace
+{
 __device__ bool isNegZero(float v)
 {
     return v == 0.f && signbit(v);
@@ -49,6 +53,12 @@ inline __device__ float toFloat<__nv_bfloat16>(__nv_bfloat16 val)
     return __bfloat162float(val);
 }

+template <>
+inline __device__ float toFloat<__nv_half>(__nv_half val)
+{
+    return __half2float(val);
+}
+
 template <typename T>
 inline __device__ T fromFloat(float val)
 {
@@ -61,30 +71,76 @@ inline __device__ __nv_bfloat16 fromFloat<__nv_bfloat16>(float val)
     return __float2bfloat16(val);
 }

-__device__ float4 loadfloat4(void const* ptr)
+template <>
+inline __device__ __nv_half fromFloat<__nv_half>(float val)
 {
+    return __float2half(val);
+}

-    float return_value[4];
-
-    asm volatile("ld.volatile.global.v4.f32 {%0, %1, %2, %3}, [%4];\n"
-        : "=f"(return_value[0]), "=f"(return_value[1]), "=f"(return_value[2]), "=f"(return_value[3])
-        : "l"(ptr));
-
-    return *(float4*) return_value;
+inline __device__ float2 loadfloat2(void const* ptr)
+{
+    float2 return_value;
+    asm volatile("ld.volatile.global.v2.f32 {%0, %1}, [%2];\n" : "=f"(return_value.x), "=f"(return_value.y) : "l"(ptr));
+    return return_value;
 }

-__device__ __inline__ float2 loadfloat2(void const* ptr)
+template <typename T>
+inline __device__ T divUp(T val, T divisor)
 {
+    return (val + divisor - 1) / divisor;
+}

-    float return_value[2];
+__device__ struct __attribute__((aligned(32))) LamportFlags
+{
+    uint32_t buffer_size;
+    uint32_t input_offset;
+    uint32_t clear_offset;
+    uint32_t num_tokens_prev;
+    uint32_t* offset_access_ptr;
+    uint32_t* buffer_flags;
+
+    __device__ explicit LamportFlags(uint32_t* buffer_flags)
+        : offset_access_ptr(&buffer_flags[4])
+        , buffer_flags(buffer_flags)
+    {
+        uint4 flag = reinterpret_cast<uint4*>(buffer_flags)[0];
+        buffer_size = flag.z;
+        input_offset = flag.x * (buffer_size << 1U);
+        clear_offset = flag.y * (buffer_size << 1U);
+        num_tokens_prev = flag.w;
+    }

-    asm volatile("ld.volatile.global.v2.f32 {%0, %1}, [%2];\n"
-        : "=f"(return_value[0]), "=f"(return_value[1])
-        : "l"(ptr)
-        : "memory");
+    __device__ void cta_arrive()
+    {
+        __syncthreads();
+        if (threadIdx.x == 0)
+        {
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000))
+            asm volatile("red.async.release.global.gpu.add.u32 [%0], %1;" ::"l"(offset_access_ptr), "r"(1) : "memory");
+#elif (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+            asm volatile("red.global.gpu.add.u32 [%0], %1;" ::"l"(offset_access_ptr), "r"(1) : "memory");
+#else
+            atomicAdd(offset_access_ptr, 1);
+#endif
+        }
+    }

-    return *(float2*) return_value;
-}
+    __device__ void wait_and_update(uint32_t num_tokens)
+    {
+        if (threadIdx.x == 0 && blockIdx.x == gridDim.x - 1 && blockIdx.y == 0)
+        {
+            while (*reinterpret_cast<uint32_t volatile*>(offset_access_ptr) < gridDim.x * gridDim.y)
+            {
+            }
+            uint4 flag = reinterpret_cast<uint4*>(buffer_flags)[0];
+            buffer_flags[0] = (flag.x + 1) % 3;
+            buffer_flags[1] = (flag.y + 1) % 3;
+            buffer_flags[3] = num_tokens;
+            *(offset_access_ptr) = 0;
+        }
+    }
+};
+} // namespace

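For reference, a minimal illustrative sketch of the buffer_flags layout that LamportFlags assumes: five uint32_t words, three buffer slots, each slot holding two buffers (reduce-scatter plus allgather) and therefore spanning 2 * buffer_size elements. The struct and helper names below are hypothetical and not part of this diff.

// Illustrative sketch only -- not part of the diff.
#include <cstdint>

struct BufferFlagsView                 // mirrors the five words read by LamportFlags
{
    uint32_t input_slot;               // buffer_flags[0]: slot read/written by this call (flag.x)
    uint32_t clear_slot;               // buffer_flags[1]: slot cleared by this call (flag.y)
    uint32_t buffer_size;              // buffer_flags[2]: elements per buffer, M * N (flag.z)
    uint32_t num_tokens_prev;          // buffer_flags[3]: token count of the previous call (flag.w)
    uint32_t cta_counter;              // buffer_flags[4]: CTAs that have consumed the offsets
};

// A slot spans two buffers (reduce-scatter + allgather), exactly as in LamportFlags.
inline uint32_t slotOffset(BufferFlagsView const& f, uint32_t slot)
{
    return slot * (f.buffer_size << 1U);
}

// wait_and_update() advances both slot indices modulo 3 and records num_tokens.
inline void rotate(BufferFlagsView& f, uint32_t num_tokens)
{
    f.input_slot = (f.input_slot + 1) % 3;
    f.clear_slot = (f.clear_slot + 1) % 3;
    f.num_tokens_prev = num_tokens;
    f.cta_counter = 0;
}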
 template <int WORLD_SIZE, typename T>
 __global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_ptrs, T* mcast_ptr, int num_tokens,
@@ -99,13 +155,14 @@ __global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_
     cudaGridDependencySynchronize();
 #endif

-    // [input_ptr, clear_ptr, buffer_size, access_counter]
-    uint4 flag = reinterpret_cast<uint4*>(buffer_flags)[0];
-    // Each buffer is M * N and we have 2 buffers in each group, one for reduce-scatter and one for allgather
-    uint32_t buffer_group_size = flag.z << 1;
-    uint32_t input_offset = flag.x * buffer_group_size;
-    uint32_t clear_offset = flag.y * buffer_group_size;
-    uint32_t* offset_access_ptr = &buffer_flags[3];
+    LamportFlags flags(buffer_flags);
+
+    // Capture the number of tokens in the previous iteration so that we can properly clear the buffer.
+    // The scatter stage uses the buffer at WORLD_SIZE granularity, so we need to round up.
+    uint32_t clr_toks_cta
+        = divUp<uint32_t>(flags.num_tokens_prev > num_tokens ? flags.num_tokens_prev : num_tokens, WORLD_SIZE)
+        * WORLD_SIZE;
+    clr_toks_cta = divUp<uint32_t>(clr_toks_cta, gridDim.x);

     if (elt < token_dim)
     {
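Worked example of the clearing arithmetic above (illustrative numbers only): with num_tokens_prev = 10, num_tokens = 6, WORLD_SIZE = 8, and gridDim.x = 4, the larger token count (10) is first rounded up to a multiple of WORLD_SIZE, divUp(10, 8) * 8 = 16, and then split across the CTAs, divUp(16, 4) = 4, so each CTA clears 4 token rows of the previous slot.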
@@ -115,29 +172,33 @@ __global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_
         T val = shard_ptr[token * token_dim + elt];
         if (isNegZero(val))
             val = fromFloat<T>(0.f);
-        input_ptrs[dest_rank][input_offset + dest_token_offset * token_dim * WORLD_SIZE + rank * token_dim + elt] = val;
+        input_ptrs[dest_rank][flags.input_offset + dest_token_offset * token_dim * WORLD_SIZE + rank * token_dim + elt]
+            = val;

-        // Reduce and broadcast
+        // Clear the buffer used by the previous call. Note the number of tokens to clear could be larger than the
+        // number of tokens in the current call.
+        for (int clr_tok = 0; clr_tok < clr_toks_cta; clr_tok++)
+        {
+            uint32_t clr_token_idx = token + clr_tok * gridDim.x;
+            if (clr_token_idx < buffer_M)
+            {
+                input_ptrs[rank][flags.clear_offset + clr_token_idx * token_dim + elt] = fromFloat<T>(-0.f);
+            }
+        }

+        // Reduce and broadcast
         if ((token % WORLD_SIZE) == rank)
         {
             int local_token = token / WORLD_SIZE;
             float accum = 0.f;

             T values[WORLD_SIZE];
-
-            for (int r = 0; r < WORLD_SIZE; r++)
-            {
-                input_ptrs[rank][clear_offset + local_token * token_dim * WORLD_SIZE + r * token_dim + elt]
-                    = fromFloat<T>(-0.f);
-            }
-
             while (1)
             {
                 bool valid = true;
                 for (int r = 0; r < WORLD_SIZE; r++)
                 {
-                    T volatile* lamport_ptr = (T volatile*) &input_ptrs[rank][input_offset
+                    T volatile* lamport_ptr = (T volatile*) &input_ptrs[rank][flags.input_offset
                         + local_token * token_dim * WORLD_SIZE + r * token_dim + elt];
                     values[r] = *lamport_ptr;
                     valid &= !isNegZero(values[r]);
@@ -149,40 +210,39 @@ __global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_
             {
                 accum += toFloat<T>(values[r]);
             }
-            mcast_ptr[input_offset + buffer_M * token_dim + token * token_dim + elt] = fromFloat<T>(accum);
+            mcast_ptr[flags.input_offset + buffer_M * token_dim + token * token_dim + elt] = fromFloat<T>(accum);
         }
     }

 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
     cudaTriggerProgrammaticLaunchCompletion();
 #endif

-    input_ptrs[rank][clear_offset + buffer_M * token_dim + token * token_dim + elt] = fromFloat<T>(-0.f);
+    // Similarly, clear the broadcast buffer here
+    for (int clr_tok = 0; clr_tok < clr_toks_cta; clr_tok++)
+    {
+        uint32_t clr_token_idx = token + clr_tok * gridDim.x;
+        if (clr_token_idx < buffer_M)
+        {
+            input_ptrs[rank][flags.clear_offset + buffer_M * token_dim + clr_token_idx * token_dim + elt]
+                = fromFloat<T>(-0.f);
+        }
+    }

     // Optionally wait for results if the next layer isn't doing the Lamport check
     if (wait_for_results)
     {
         // Update the atomic counter to indicate the block has read the offsets
-        __syncthreads();
+        flags.cta_arrive();

-        if (threadIdx.x == 0)
-        {
-#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000))
-            asm volatile("red.async.release.global.gpu.add.u32 [%0], %1;" ::"l"(offset_access_ptr), "r"(1) : "memory");
-#elif (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
-            asm volatile("red.global.gpu.add.u32 [%0], %1;" ::"l"(offset_access_ptr), "r"(1) : "memory");
-#else
-            atomicAdd(offset_access_ptr, 1);
-#endif
-        }
         // Only use a set of CTAs for the Lamport sync; rearrange the grid
         constexpr int ELTS_PER_LOAD = sizeof(float2) / sizeof(T);
         // blockDim.x / ELTS_PER_LOAD should be at least the size of a warp (32)
         if (threadIdx.x < (blockDim.x / ELTS_PER_LOAD))
         {
             uint64_t current_pos = blockIdx.x * token_dim + blockIdx.y * blockDim.x + threadIdx.x * ELTS_PER_LOAD;

-            void* lamport_ptr = (void*) &input_ptrs[rank][input_offset + buffer_M * token_dim + current_pos];
+            void* lamport_ptr = (void*) &input_ptrs[rank][flags.input_offset + buffer_M * token_dim + current_pos];
             // We have 2 assumptions here:
             // 1. The write is atomic in 8B granularity -> Each buffer in the buffer group should be aligned to 8B
             // 2. The num_token * token_dim is divisible by ELTS_PER_LOAD (4 for BF16 and 2 for FP32)
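The readiness loop that consumes these 8-byte loads is elided from this hunk. As a hedged sketch of the pattern only (the helper name is hypothetical), the per-lane Lamport check over one loadfloat2 result could look like this, reusing isNegZero and toFloat from the helpers above:

// Illustrative sketch only -- not part of the diff.
// Checks every T lane of one 8-byte volatile load for the -0.0f
// "not yet written" sentinel used by the Lamport protocol.
template <typename T>
__device__ bool lanesValid(float2 packed)
{
    constexpr int kLanes = sizeof(float2) / sizeof(T);
    T lanes[kLanes];
    memcpy(lanes, &packed, sizeof(packed)); // reinterpret the 8 bytes as T lanes

    bool valid = true;
    for (int i = 0; i < kLanes; ++i)
    {
        valid &= !isNegZero(toFloat<T>(lanes[i]));
    }
    return valid;
}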
@@ -198,16 +258,7 @@ __global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_
         }

         // Update the buffer flags
-        if (threadIdx.x == 0 && blockIdx.x == gridDim.x - 1 && blockIdx.y == 0)
-        {
-            // Make sure all blocks have finished reading the offsets, 2-D grid
-            while (*reinterpret_cast<uint32_t volatile*>(offset_access_ptr) < gridDim.x * gridDim.y)
-            {
-            }
-            buffer_flags[0] = (flag.x + 1) % 3;
-            buffer_flags[1] = (flag.y + 1) % 3;
-            *(offset_access_ptr) = 0;
-        }
+        flags.wait_and_update(num_tokens);
     }
 }

@@ -273,12 +324,28 @@ void twoshot_allreduce_op(AllReduceParams const& params)
         default: TLLM_CHECK_WITH_INFO(false, "[TwoShot AllReduce]: unsupported world_size.");
         }
     }
+    else if (dtype == nvinfer1::DataType::kHALF)
+    {
+        switch (world_size)
+        {
+        case 2: LAUNCH_ALL_REDUCE_KERNEL(2, __nv_half); break;
+        case 4: LAUNCH_ALL_REDUCE_KERNEL(4, __nv_half); break;
+        case 8: LAUNCH_ALL_REDUCE_KERNEL(8, __nv_half); break;
+        case 16: LAUNCH_ALL_REDUCE_KERNEL(16, __nv_half); break;
+        case 32: LAUNCH_ALL_REDUCE_KERNEL(32, __nv_half); break;
+        case 64: LAUNCH_ALL_REDUCE_KERNEL(64, __nv_half); break;
+        default: TLLM_CHECK_WITH_INFO(false, "[TwoShot AllReduce]: unsupported world_size.");
+        }
+    }
     else
     {
         TLLM_CHECK_WITH_INFO(false, "[TwoShot AllReduce]: unsupported dtype.");
     }
 }

+// Guard for internal helper functions
+namespace
+{
 template <typename T_IN>
 __device__ void copy_f4(T_IN* dst, T_IN const* src)
 {
@@ -327,6 +394,19 @@ inline __device__ float block_reduce_sum(float val)
     return val;
 }

+__device__ float4 loadfloat4(void const* ptr)
+{
+
+    float4 return_value;
+
+    asm volatile("ld.volatile.global.v4.f32 {%0, %1, %2, %3}, [%4];\n"
+        : "=f"(return_value.x), "=f"(return_value.y), "=f"(return_value.z), "=f"(return_value.w)
+        : "l"(ptr));
+
+    return return_value;
+}
+} // namespace
+
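A minimal usage sketch for the helper above (illustrative only; the function name is hypothetical and both pointers are assumed 16-byte aligned): one volatile 16-byte global read moved into sizeof(float4) / sizeof(T_IN) lanes of T_IN.

// Illustrative sketch only -- not part of the diff.
template <typename T_IN>
__device__ void volatileCopy16B(T_IN* dst, T_IN const* src)
{
    float4 v = loadfloat4(src);          // one ld.volatile.global.v4.f32 (16 bytes)
    *reinterpret_cast<float4*>(dst) = v; // 8 lanes for half/bf16, 4 lanes for float
}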
 template <int DIM, int NUM_THREADS, int NUM_INPUTS, typename T_OUT, typename T_IN>
 __global__ void __launch_bounds__(128, 1)
     RMSNorm(T_IN* input_plus_residual, T_OUT* output_norm, T_IN const* buffer_input, T_IN const* gamma, float epsilon,
@@ -353,12 +433,8 @@ __global__ void __launch_bounds__(128, 1)

     int offsets[NUM_INPUTS][DIM / (1 * ELTS_PER_THREAD * NUM_THREADS)];

-    uint32_t* offset_access_ptr = &buffer_flags[3];
-    uint4 flag = reinterpret_cast<uint4*>(buffer_flags)[0];
-    // Buffer size is M * N, and we need two buffers for reduce-scatter and allgather
-    uint32_t buffer_size = flag.z;
-    uint32_t buffer_offset = flag.x * (buffer_size << 1);
-    T_IN const* input = &buffer_input[buffer_offset + buffer_size];
+    LamportFlags flags(buffer_flags);
+    T_IN const* input = &buffer_input[flags.input_offset + flags.buffer_size];

     cudaTriggerProgrammaticLaunchCompletion();

@@ -388,17 +464,7 @@ __global__ void __launch_bounds__(128, 1)
     }

     __pipeline_commit();
-    __syncthreads();
-    if (threadIdx.x == 0)
-    {
-#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000))
-        asm volatile("red.async.release.global.gpu.add.u32 [%0], %1;" ::"l"(offset_access_ptr), "r"(1) : "memory");
-#elif (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
-        asm volatile("red.global.gpu.add.u32 [%0], %1;" ::"l"(offset_access_ptr), "r"(1) : "memory");
-#else
-        atomicAdd(offset_access_ptr, 1);
-#endif
-    }
+    flags.cta_arrive();
     // Load all inputs
     bool valid = false;

@@ -528,16 +594,7 @@ __global__ void __launch_bounds__(128, 1)
             = out4;
     }
     // Update the buffer pointers
-    if (threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0)
-    {
-        // Make sure all blocks have finished accessing the buffer
-        while (*reinterpret_cast<uint32_t volatile*>(offset_access_ptr) < gridDim.x * gridDim.y)
-        {
-        }
-        buffer_flags[0] = (flag.x + 1) % 3;
-        buffer_flags[1] = (flag.y + 1) % 3;
-        *(offset_access_ptr) = 0;
-    }
+    flags.wait_and_update(batch_size);
 #endif
 }

@@ -548,8 +605,6 @@ void twoshot_rmsnorm(T* prenorm_output, T* normed_output, T const* input, T cons

     // The input to RMSNorm is the buffer from the two-shot allreduce.
     // We should use the prenorm output to determine the actual used size.
-    // int batch = normed_output.sizes()[0];
-    // int dim = normed_output.sizes()[1];
     float _epsilon{static_cast<float>(epsilon)};

     static constexpr int NUM_THREADS = 128;
@@ -612,6 +667,20 @@ void twoshot_rmsnorm_op(RMSNormParams const& params)
         default: TLLM_CHECK_WITH_INFO(false, "[MNNVL TwoShot RMSNorm]: unsupported hidden_dim.");
         }
     }
+    else if (dtype == nvinfer1::DataType::kHALF)
+    {
+        switch (params.hidden_dim)
+        {
+        case 2048: LAUNCH_RMSNORM_KERNEL(__nv_half, 2048); break;
+        case 4096: LAUNCH_RMSNORM_KERNEL(__nv_half, 4096); break;
+        // Llama-4 Hidden Dimension
+        case 5120: LAUNCH_RMSNORM_KERNEL(__nv_half, 5120); break;
+        // DeepSeek Hidden Dimension
+        case 7168: LAUNCH_RMSNORM_KERNEL(__nv_half, 7168); break;
+        case 8192: LAUNCH_RMSNORM_KERNEL(__nv_half, 8192); break;
+        default: TLLM_CHECK_WITH_INFO(false, "[MNNVL TwoShot RMSNorm]: unsupported hidden_dim.");
+        }
+    }
     else
     {
         TLLM_CHECK_WITH_INFO(false, "[MNNVL TwoShot RMSNorm]: unsupported dtype.");