 from torch import Tensor
 
 
-@triton.jit
-def _div_up(val, other):
-    return (val + other - 1) // other
-
-
 @triton.jit
 def _quant_int8(val):
     val_min = tl.min(val, 1)
@@ -126,6 +121,149 @@ def _fill_kv_cache_kernel(
     tl.store(vc_ptrs, v, mask=mask_vc)
 
 
+@triton.jit
+def _fill_page_quant_int8(
+    state_ptr,
+    cache_ptr,
+    scales_zeros_ptr,
+    block_off,
+    head_id,
+    page_offs,
+    q_offs,
+    kv_mask,
+    head_dim: tl.constexpr,
+    stride_ss,
+    stride_sh,
+    stride_sd,
+    stride_cn: tl.constexpr,
+    stride_cb: tl.constexpr,
+    stride_ch: tl.constexpr,
+    stride_cd: tl.constexpr,
+    stride_szn: tl.constexpr,
+    stride_szb: tl.constexpr,
+    stride_szh: tl.constexpr,
+    stride_szd: tl.constexpr,
+    BLOCK_D: tl.constexpr,
+):
+    """Fill page int8."""
+    d_off = tl.arange(0, BLOCK_D)
+    mask_kc = kv_mask[:, None] & (d_off[None, :] < head_dim)
+    state_ptr = state_ptr + head_id * stride_sh
+    state_ptrs = state_ptr + q_offs[:, None] * stride_ss + d_off[None, :] * stride_sd
+    cache_ptr = cache_ptr + block_off * stride_cn + head_id * stride_ch
+    cache_ptrs = cache_ptr + page_offs[:, None] * stride_cb + d_off[None, :] * stride_cd
+    scales_zeros_ptr = scales_zeros_ptr + block_off * stride_szn + head_id * stride_szh
+    scales_ptrs = scales_zeros_ptr + page_offs[:, None] * stride_szb
+    zeros_ptrs = scales_ptrs + stride_szd
+
+    state = tl.load(state_ptrs, mask=kv_mask[:, None])
+    state, scales, zeros = _quant_int8(state)
+
+    tl.store(cache_ptrs, state, mask=mask_kc)
+    tl.store(scales_ptrs, scales[:, None], mask=kv_mask[:, None])
+    tl.store(zeros_ptrs, zeros[:, None], mask=kv_mask[:, None])
+
+
+@triton.jit
+def _fill_page_quant_int4(
+    state_ptr,
+    cache_ptr,
+    scales_zeros_ptr,
+    block_off,
+    head_id,
+    page_offs,
+    q_offs,
+    kv_mask,
+    head_dim: tl.constexpr,
+    stride_ss,
+    stride_sh,
+    stride_sd,
+    stride_cn: tl.constexpr,
+    stride_cb: tl.constexpr,
+    stride_ch: tl.constexpr,
+    stride_cd: tl.constexpr,
+    stride_szn: tl.constexpr,
+    stride_szb: tl.constexpr,
+    stride_szh: tl.constexpr,
+    stride_szd: tl.constexpr,
+    BLOCK_D: tl.constexpr,
+):
+    """Fill page int4."""
+    d_off = tl.arange(0, BLOCK_D)
+    mask_kc = kv_mask[:, None] & (d_off[None, :] < head_dim)
+    state_ptr = state_ptr + head_id * stride_sh
+    state0_ptrs = state_ptr + q_offs[:, None] * stride_ss + d_off[None, :] * stride_sd
+    state1_ptrs = state0_ptrs + head_dim * stride_sd
+    cache_ptr = cache_ptr + block_off * stride_cn + head_id * stride_ch
+    cache_ptrs = cache_ptr + page_offs[:, None] * stride_cb + d_off[None, :] * stride_cd
+    scales_zeros_ptr = scales_zeros_ptr + block_off * stride_szn + head_id * stride_szh
+    scales_ptrs = scales_zeros_ptr + page_offs[:, None] * stride_szb
+    zeros_ptrs = scales_ptrs + stride_szd
+
+    state0 = tl.load(state0_ptrs, mask=mask_kc)
+    state1 = tl.load(state1_ptrs, mask=mask_kc)
+    state, scales, zeros = _quant_int4(state0, state1)
+
+    tl.store(cache_ptrs, state, mask=mask_kc)
+    tl.store(scales_ptrs, scales[:, None], mask=kv_mask[:, None])
+    tl.store(zeros_ptrs, zeros[:, None], mask=kv_mask[:, None])
+
+
+@triton.jit
+def _fill_page_quant(state_ptr, cache_ptr, scales_zeros_ptr, block_off, head_id, page_offs, q_offs, kv_mask,
+                     head_dim: tl.constexpr, stride_ss, stride_sh, stride_sd, stride_cn: tl.constexpr,
+                     stride_cb: tl.constexpr, stride_ch: tl.constexpr, stride_cd: tl.constexpr,
+                     stride_szn: tl.constexpr, stride_szb: tl.constexpr, stride_szh: tl.constexpr,
+                     stride_szd: tl.constexpr, BLOCK_D: tl.constexpr, quant_policy: tl.constexpr):
+    """Fill page."""
+    if quant_policy == 8:
+        return _fill_page_quant_int8(state_ptr,
+                                     cache_ptr,
+                                     scales_zeros_ptr,
+                                     block_off,
+                                     head_id,
+                                     page_offs,
+                                     q_offs,
+                                     kv_mask,
+                                     head_dim=head_dim,
+                                     stride_ss=stride_ss,
+                                     stride_sh=stride_sh,
+                                     stride_sd=stride_sd,
+                                     stride_cn=stride_cn,
+                                     stride_cb=stride_cb,
+                                     stride_ch=stride_ch,
+                                     stride_cd=stride_cd,
+                                     stride_szn=stride_szn,
+                                     stride_szb=stride_szb,
+                                     stride_szh=stride_szh,
+                                     stride_szd=stride_szd,
+                                     BLOCK_D=BLOCK_D)
+    elif quant_policy == 4:
+        return _fill_page_quant_int4(state_ptr,
+                                     cache_ptr,
+                                     scales_zeros_ptr,
+                                     block_off,
+                                     head_id,
+                                     page_offs,
+                                     q_offs,
+                                     kv_mask,
+                                     head_dim=head_dim,
+                                     stride_ss=stride_ss,
+                                     stride_sh=stride_sh,
+                                     stride_sd=stride_sd,
+                                     stride_cn=stride_cn,
+                                     stride_cb=stride_cb,
+                                     stride_ch=stride_ch,
+                                     stride_cd=stride_cd,
+                                     stride_szn=stride_szn,
+                                     stride_szb=stride_szb,
+                                     stride_szh=stride_szh,
+                                     stride_szd=stride_szd,
+                                     BLOCK_D=BLOCK_D)
+    else:
+        tl.static_assert(False, 'Unsupported quant policy')
+
+
 @triton.jit
 def _fill_kv_cache_quant_kernel(
     KStates,
@@ -138,7 +276,7 @@ def _fill_kv_cache_quant_kernel(
     QSeqLens,
     KVSeqLens,
     BlockOffsets,
-    num_heads: tl.constexpr,
+    is_decoding: tl.constexpr,
     head_dim: tl.constexpr,
     head_dim_v: tl.constexpr,
     stride_kss,
@@ -168,7 +306,6 @@ def _fill_kv_cache_quant_kernel(
     BLOCK: tl.constexpr,
     BLOCK_D: tl.constexpr,
     BLOCK_DV: tl.constexpr,
-    BLOCK_H: tl.constexpr,
 ):
     """Fill kv cache kernel with int4 and int8 quant fused.
 
@@ -181,88 +318,82 @@ def _fill_kv_cache_quant_kernel(
         stride_xh: stride of head_num dim
         stride_xd: stride of head_size dim
     """
-    batch_id = tl.program_id(0)
+    batch_id = tl.program_id(2)
+    head_id = tl.program_id(0)
     block_id = tl.program_id(1)
-    d_off = tl.arange(0, BLOCK_D)
-
-    # initialize
-    h_off = tl.arange(0, BLOCK_H)
-    szd_off = tl.arange(0, 2)
 
     q_startloc = tl.load(QStartLoc + batch_id)
     q_seqlen = tl.load(QSeqLens + batch_id)
     kv_seqlen = tl.load(KVSeqLens + batch_id)
     history_seqlen = kv_seqlen - q_seqlen
 
-    block0_first_tokenloc = history_seqlen % BLOCK
+    kv_block_id = history_seqlen // BLOCK + block_id
+
+    if kv_seqlen <= 0:
+        return
+
+    if kv_block_id * BLOCK >= kv_seqlen:
+        return
+
+    if is_decoding:
+        page_offs = tl.full((1, ), history_seqlen % BLOCK, dtype=tl.int32)
+        kv_mask = tl.full((1, ), 1, dtype=tl.int1)
+        q_offs = tl.full((1, ), q_startloc, dtype=tl.int32)
+    else:
+        page_offs = tl.arange(0, BLOCK)
+        kv_offs = kv_block_id * BLOCK + page_offs
+        kv_mask = (kv_offs >= history_seqlen) & (kv_offs < kv_seqlen)
+        token_off = q_startloc + kv_block_id * BLOCK - history_seqlen
+        q_offs = token_off + page_offs
 
-    state_token_offset = tl.maximum(block_id * BLOCK - block0_first_tokenloc, 0)
-    kv_block_id = _div_up(history_seqlen + 1, BLOCK) - 1 + block_id
-    kv_block_id = min(kv_block_id, stride_boff - 1)
     block_off = tl.load(BlockOffsets + batch_id * stride_boff + kv_block_id)
 
-    cur_startloc = q_startloc + state_token_offset
-    ks_ptr = KStates + cur_startloc * stride_kss
-    vs_ptr = VStates + cur_startloc * stride_vss
-
-    kc_ptr = KCaches + block_off * stride_kcn
-    vc_ptr = VCaches + block_off * stride_vcn
-
-    ksz_ptr = KScalesZeros + block_off * stride_kszn
-    vsz_ptr = VScalesZeros + block_off * stride_vszn
-
-    c_first_tokenloc = block0_first_tokenloc
-    if block_id != 0:
-        c_first_tokenloc *= 0
-    c_last_tokenloc = tl.minimum(BLOCK, q_seqlen + block0_first_tokenloc - block_id * BLOCK)
-
-    for bidx in range(c_first_tokenloc, c_last_tokenloc):
-        sidx = bidx - c_first_tokenloc
-        mask = (h_off[:, None] < num_heads) & (d_off[None, :] < head_dim)
-        if quant_policy == 4:
-            k1 = tl.load(ks_ptr + sidx * stride_kss + h_off[:, None] * stride_ksh + d_off[None, :] * stride_ksd,
-                         mask=mask)
-            k2 = tl.load(ks_ptr + sidx * stride_kss + h_off[:, None] * stride_ksh + d_off[None, :] * stride_ksd +
-                         head_dim * stride_ksd,
-                         mask=mask)
-            q_k, k_scales, k_zeros = _quant_int4(k1, k2)
-        else:
-            k = tl.load(ks_ptr + sidx * stride_kss + h_off[:, None] * stride_ksh + d_off[None, :] * stride_ksd,
-                        mask=mask)
-            q_k, k_scales, k_zeros = _quant_int8(k)
-        tl.store(kc_ptr + bidx * stride_kcb + h_off[:, None] * stride_kch + d_off[None, :] * stride_kcd, q_k, mask=mask)
-        tl.store(ksz_ptr + bidx * stride_kszb + h_off[:, None] * stride_kszh + szd_off[None, :] * stride_kszd,
-                 k_scales[:, None],
-                 mask=(h_off[:, None] < num_heads) & (szd_off[None, :] < 1))
-        tl.store(ksz_ptr + bidx * stride_kszb + h_off[:, None] * stride_kszh + szd_off[None, :] * stride_kszd,
-                 k_zeros[:, None],
-                 mask=(h_off[:, None] < num_heads) & (szd_off[None, :] == 1))
-
-        if BLOCK_DV > 0:
-            if quant_policy == 4:
-                dv_off = tl.arange(0, BLOCK_DV // 2)  # int4 pack, half the head_dim
-                maskv = (h_off[:, None] < num_heads) & (dv_off[None, :] < head_dim_v // 2)
-                v1 = tl.load(vs_ptr + sidx * stride_vss + h_off[:, None] * stride_vsh + dv_off[None, :] * stride_vsd,
-                             mask=maskv)
-                v2 = tl.load(vs_ptr + sidx * stride_vss + h_off[:, None] * stride_vsh + dv_off[None, :] * stride_vsd +
-                             head_dim_v // 2 * stride_vsd,
-                             mask=maskv)
-                q_v, v_scales, v_zeros = _quant_int4(v1, v2)
-            else:
-                dv_off = tl.arange(0, BLOCK_DV)
-                maskv = (h_off[:, None] < num_heads) & (dv_off[None, :] < head_dim_v)
-                v = tl.load(vs_ptr + sidx * stride_vss + h_off[:, None] * stride_vsh + dv_off[None, :] * stride_vsd,
-                            mask=maskv)
-                q_v, v_scales, v_zeros = _quant_int8(v)
-            tl.store(vc_ptr + bidx * stride_vcb + h_off[:, None] * stride_vch + dv_off[None, :] * stride_vcd,
-                     q_v,
-                     mask=maskv)
-            tl.store(vsz_ptr + bidx * stride_vszb + h_off[:, None] * stride_vszh + szd_off[None, :] * stride_vszd,
-                     v_scales[:, None],
-                     mask=(h_off[:, None] < num_heads) & (szd_off[None, :] < 1))
-            tl.store(vsz_ptr + bidx * stride_vszb + h_off[:, None] * stride_vszh + szd_off[None, :] * stride_vszd,
-                     v_zeros[:, None],
-                     mask=(h_off[:, None] < num_heads) & (szd_off[None, :] == 1))
+    _fill_page_quant(KStates,
+                     KCaches,
+                     KScalesZeros,
+                     block_off,
+                     head_id,
+                     page_offs,
+                     q_offs,
+                     kv_mask,
+                     head_dim=head_dim,
+                     stride_ss=stride_kss,
+                     stride_sh=stride_ksh,
+                     stride_sd=stride_ksd,
+                     stride_cn=stride_kcn,
+                     stride_cb=stride_kcb,
+                     stride_ch=stride_kch,
+                     stride_cd=stride_kcd,
+                     stride_szn=stride_kszn,
+                     stride_szb=stride_kszb,
+                     stride_szh=stride_kszh,
+                     stride_szd=stride_kszd,
+                     BLOCK_D=BLOCK_D,
+                     quant_policy=quant_policy)
+
+    if BLOCK_DV > 0:
+        _fill_page_quant(VStates,
+                         VCaches,
+                         VScalesZeros,
+                         block_off,
+                         head_id,
+                         page_offs,
+                         q_offs,
+                         kv_mask,
+                         head_dim=head_dim_v,
+                         stride_ss=stride_vss,
+                         stride_sh=stride_vsh,
+                         stride_sd=stride_vsd,
+                         stride_cn=stride_vcn,
+                         stride_cb=stride_vcb,
+                         stride_ch=stride_vch,
+                         stride_cd=stride_vcd,
+                         stride_szn=stride_vszn,
+                         stride_szb=stride_vszb,
+                         stride_szh=stride_vszh,
+                         stride_szd=stride_vszd,
+                         BLOCK_D=BLOCK_DV,
+                         quant_policy=quant_policy)
 
 
 def fill_kv_cache(k_states: Tensor,
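For readers checking the rewritten indexing: in the prefill branch above, each (block_id, page slot) pair is mapped to one row of the packed key/value state tensor, and out-of-range slots are masked off. The following plain-Python sketch mirrors that arithmetic for illustration only; it is not part of the commit, and the helper name prefill_page_indices is made up here.

    import numpy as np

    def prefill_page_indices(q_startloc, q_seqlen, kv_seqlen, block_id, BLOCK):
        # Mirrors the non-decoding branch of _fill_kv_cache_quant_kernel.
        history_seqlen = kv_seqlen - q_seqlen
        kv_block_id = history_seqlen // BLOCK + block_id
        page_offs = np.arange(BLOCK)                 # slot inside the cache page
        kv_offs = kv_block_id * BLOCK + page_offs    # absolute kv position
        kv_mask = (kv_offs >= history_seqlen) & (kv_offs < kv_seqlen)
        token_off = q_startloc + kv_block_id * BLOCK - history_seqlen
        q_offs = token_off + page_offs               # row in k_states / v_states
        return kv_block_id, page_offs[kv_mask], q_offs[kv_mask]

    # 5 cached tokens, 6 new tokens, page size 4: block_id 0 writes page 1,
    # slots 1..3, from state rows 0..2; block_id 1 writes page 2, slots 0..2,
    # from state rows 3..5.
    print(prefill_page_indices(q_startloc=0, q_seqlen=6, kv_seqlen=11, block_id=0, BLOCK=4))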
@@ -291,21 +422,22 @@ def fill_kv_cache(k_states: Tensor,
     block_size = k_caches.size(s_dim)
     num_heads = k_caches.size(h_dim)
     head_dim = k_caches.size(d_dim)
-    head_dim_v = v_states.size(-1)
+    head_dim_v = v_caches.size(d_dim)
+    if v_states.size(-1) == 0:
+        head_dim_v = 0
     if max_q_seq_length == 1:
         max_num_blocks = 1
     else:
         max_num_blocks = triton.cdiv(max_q_seq_length, block_size) + 1
 
     BLOCK = block_size
-    BLOCK_H = triton.next_power_of_2(num_heads)
     BLOCK_D = triton.next_power_of_2(head_dim)
     BLOCK_DV = triton.next_power_of_2(head_dim_v)
     if k_caches.data_ptr() == v_caches.data_ptr() and head_dim_v <= head_dim:
         BLOCK_DV = 0
+    grid = (num_heads, max_num_blocks, batch_size)
+    is_decoding = max_num_blocks == 1
     if quant_policy == 0:
-        grid = (num_heads, max_num_blocks, batch_size)
-        is_decoding = max_num_blocks == 1
         _fill_kv_cache_kernel[grid](
             k_states,
             v_states,
@@ -340,7 +472,6 @@ def fill_kv_cache(k_states: Tensor,
             num_stages=3,
         )
     else:
-        grid = (batch_size, max_num_blocks)
         _fill_kv_cache_quant_kernel[grid](
             k_states,
             v_states,
@@ -352,7 +483,7 @@ def fill_kv_cache(k_states: Tensor,
             q_seq_length,
             kv_seq_length,
             block_offsets,
-            num_heads=num_heads,
+            is_decoding=is_decoding,
             head_dim=head_dim,
             head_dim_v=head_dim_v,
             stride_kss=k_states.stride(-3),
@@ -382,7 +513,6 @@ def fill_kv_cache(k_states: Tensor,
             BLOCK=BLOCK,
             BLOCK_D=BLOCK_D,
             BLOCK_DV=BLOCK_DV,
-            BLOCK_H=BLOCK_H,
             num_warps=4,
-            num_stages=3,
+            num_stages=1,
         )
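A short note on the launch-grid sizing kept by this change: max_num_blocks is triton.cdiv(max_q_seq_length, block_size) + 1 because a query span that starts mid-page can straddle one extra cache page, while the decoding path (max_q_seq_length == 1) keeps a single block and now also drives the is_decoding branch. A quick plain-Python check of that bound, for illustration only; the helper name pages_touched is invented here.

    import math

    def pages_touched(history_seqlen, q_seqlen, block_size):
        # Cache pages covered by tokens [history_seqlen, history_seqlen + q_seqlen).
        first = history_seqlen // block_size
        last = (history_seqlen + q_seqlen - 1) // block_size
        return last - first + 1

    q_seqlen, block_size = 6, 4
    worst = max(pages_touched(h, q_seqlen, block_size) for h in range(block_size))
    assert worst <= math.ceil(q_seqlen / block_size) + 1
    print(worst)  # 3 pages for this example, matching max_num_blocks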