
Commit 863ba74

fix and optimize fill_kv_cache_quant (#4140)
1 parent 1a859f4 commit 863ba74

File tree

1 file changed: +217 -87 lines changed

lmdeploy/pytorch/kernels/cuda/fill_kv_cache.py

Lines changed: 217 additions & 87 deletions
@@ -6,11 +6,6 @@
 from torch import Tensor
 
 
-@triton.jit
-def _div_up(val, other):
-    return (val + other - 1) // other
-
-
 @triton.jit
 def _quant_int8(val):
     val_min = tl.min(val, 1)
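
Note on the hunk above: the removed _div_up helper has no remaining callers after this commit. The rewritten quant kernel below computes its page index with plain floor division (history_seqlen // BLOCK), and the host side already uses triton.cdiv for ceiling division, which covers what _div_up did:

import triton

assert triton.cdiv(300, 64) == 5      # host-side ceiling division
assert (300 + 64 - 1) // 64 == 5      # what the removed _div_up computed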
@@ -126,6 +121,149 @@ def _fill_kv_cache_kernel(
     tl.store(vc_ptrs, v, mask=mask_vc)
 
 
+@triton.jit
+def _fill_page_quant_int8(
+    state_ptr,
+    cache_ptr,
+    scales_zeros_ptr,
+    block_off,
+    head_id,
+    page_offs,
+    q_offs,
+    kv_mask,
+    head_dim: tl.constexpr,
+    stride_ss,
+    stride_sh,
+    stride_sd,
+    stride_cn: tl.constexpr,
+    stride_cb: tl.constexpr,
+    stride_ch: tl.constexpr,
+    stride_cd: tl.constexpr,
+    stride_szn: tl.constexpr,
+    stride_szb: tl.constexpr,
+    stride_szh: tl.constexpr,
+    stride_szd: tl.constexpr,
+    BLOCK_D: tl.constexpr,
+):
+    """Fill page int8."""
+    d_off = tl.arange(0, BLOCK_D)
+    mask_kc = kv_mask[:, None] & (d_off[None, :] < head_dim)
+    state_ptr = state_ptr + head_id * stride_sh
+    state_ptrs = state_ptr + q_offs[:, None] * stride_ss + d_off[None, :] * stride_sd
+    cache_ptr = cache_ptr + block_off * stride_cn + head_id * stride_ch
+    cache_ptrs = cache_ptr + page_offs[:, None] * stride_cb + d_off[None, :] * stride_cd
+    scales_zeros_ptr = scales_zeros_ptr + block_off * stride_szn + head_id * stride_szh
+    scales_ptrs = scales_zeros_ptr + page_offs[:, None] * stride_szb
+    zeros_ptrs = scales_ptrs + stride_szd
+
+    state = tl.load(state_ptrs, mask=kv_mask[:, None])
+    state, scales, zeros = _quant_int8(state)
+
+    tl.store(cache_ptrs, state, mask=mask_kc)
+    tl.store(scales_ptrs, scales[:, None], mask=kv_mask[:, None])
+    tl.store(zeros_ptrs, zeros[:, None], mask=kv_mask[:, None])
+
+
+@triton.jit
+def _fill_page_quant_int4(
+    state_ptr,
+    cache_ptr,
+    scales_zeros_ptr,
+    block_off,
+    head_id,
+    page_offs,
+    q_offs,
+    kv_mask,
+    head_dim: tl.constexpr,
+    stride_ss,
+    stride_sh,
+    stride_sd,
+    stride_cn: tl.constexpr,
+    stride_cb: tl.constexpr,
+    stride_ch: tl.constexpr,
+    stride_cd: tl.constexpr,
+    stride_szn: tl.constexpr,
+    stride_szb: tl.constexpr,
+    stride_szh: tl.constexpr,
+    stride_szd: tl.constexpr,
+    BLOCK_D: tl.constexpr,
+):
+    """Fill page int4."""
+    d_off = tl.arange(0, BLOCK_D)
+    mask_kc = kv_mask[:, None] & (d_off[None, :] < head_dim)
+    state_ptr = state_ptr + head_id * stride_sh
+    state0_ptrs = state_ptr + q_offs[:, None] * stride_ss + d_off[None, :] * stride_sd
+    state1_ptrs = state0_ptrs + head_dim * stride_sd
+    cache_ptr = cache_ptr + block_off * stride_cn + head_id * stride_ch
+    cache_ptrs = cache_ptr + page_offs[:, None] * stride_cb + d_off[None, :] * stride_cd
+    scales_zeros_ptr = scales_zeros_ptr + block_off * stride_szn + head_id * stride_szh
+    scales_ptrs = scales_zeros_ptr + page_offs[:, None] * stride_szb
+    zeros_ptrs = scales_ptrs + stride_szd
+
+    state0 = tl.load(state0_ptrs, mask=mask_kc)
+    state1 = tl.load(state1_ptrs, mask=mask_kc)
+    state, scales, zeros = _quant_int4(state0, state1)
+
+    tl.store(cache_ptrs, state, mask=mask_kc)
+    tl.store(scales_ptrs, scales[:, None], mask=kv_mask[:, None])
+    tl.store(zeros_ptrs, zeros[:, None], mask=kv_mask[:, None])
+
+
+@triton.jit
+def _fill_page_quant(state_ptr, cache_ptr, scales_zeros_ptr, block_off, head_id, page_offs, q_offs, kv_mask,
+                     head_dim: tl.constexpr, stride_ss, stride_sh, stride_sd, stride_cn: tl.constexpr,
+                     stride_cb: tl.constexpr, stride_ch: tl.constexpr, stride_cd: tl.constexpr,
+                     stride_szn: tl.constexpr, stride_szb: tl.constexpr, stride_szh: tl.constexpr,
+                     stride_szd: tl.constexpr, BLOCK_D: tl.constexpr, quant_policy: tl.constexpr):
+    """Fill page."""
+    if quant_policy == 8:
+        return _fill_page_quant_int8(state_ptr,
+                                     cache_ptr,
+                                     scales_zeros_ptr,
+                                     block_off,
+                                     head_id,
+                                     page_offs,
+                                     q_offs,
+                                     kv_mask,
+                                     head_dim=head_dim,
+                                     stride_ss=stride_ss,
+                                     stride_sh=stride_sh,
+                                     stride_sd=stride_sd,
+                                     stride_cn=stride_cn,
+                                     stride_cb=stride_cb,
+                                     stride_ch=stride_ch,
+                                     stride_cd=stride_cd,
+                                     stride_szn=stride_szn,
+                                     stride_szb=stride_szb,
+                                     stride_szh=stride_szh,
+                                     stride_szd=stride_szd,
+                                     BLOCK_D=BLOCK_D)
+    elif quant_policy == 4:
+        return _fill_page_quant_int4(state_ptr,
+                                     cache_ptr,
+                                     scales_zeros_ptr,
+                                     block_off,
+                                     head_id,
+                                     page_offs,
+                                     q_offs,
+                                     kv_mask,
+                                     head_dim=head_dim,
+                                     stride_ss=stride_ss,
+                                     stride_sh=stride_sh,
+                                     stride_sd=stride_sd,
+                                     stride_cn=stride_cn,
+                                     stride_cb=stride_cb,
+                                     stride_ch=stride_ch,
+                                     stride_cd=stride_cd,
+                                     stride_szn=stride_szn,
+                                     stride_szb=stride_szb,
+                                     stride_szh=stride_szh,
+                                     stride_szd=stride_szd,
+                                     BLOCK_D=BLOCK_D)
+    else:
+        tl.static_assert(False, 'Unsupported quant policy')
+
+
 @triton.jit
 def _fill_kv_cache_quant_kernel(
     KStates,
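
The helpers added in the hunk above quantize one page of tokens for a single head and store a per-token (scale, zero) pair next to the quantized values, with the zero point written stride_szd after the scale. As a rough illustration only (not necessarily the exact rounding or zero-point convention of _quant_int8), per-token asymmetric int8 quantization looks like this:

import torch

def quant_int8_per_token(x: torch.Tensor):
    # x: (num_tokens, head_dim) states for one head; one (scale, zero) per token,
    # mirroring the scales_zeros layout the kernel writes.
    x_min = x.amin(dim=1, keepdim=True)
    x_max = x.amax(dim=1, keepdim=True)
    scale = ((x_max - x_min) / 255.0).clamp_min(1e-8)
    zero = x_min
    q = torch.clamp(torch.round((x - zero) / scale), 0, 255).to(torch.uint8)
    return q, scale, zero

def dequant_int8_per_token(q, scale, zero):
    return q.to(torch.float32) * scale + zero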
@@ -138,7 +276,7 @@ def _fill_kv_cache_quant_kernel(
     QSeqLens,
     KVSeqLens,
     BlockOffsets,
-    num_heads: tl.constexpr,
+    is_decoding: tl.constexpr,
     head_dim: tl.constexpr,
     head_dim_v: tl.constexpr,
     stride_kss,
@@ -168,7 +306,6 @@ def _fill_kv_cache_quant_kernel(
     BLOCK: tl.constexpr,
     BLOCK_D: tl.constexpr,
     BLOCK_DV: tl.constexpr,
-    BLOCK_H: tl.constexpr,
 ):
     """Fill kv cache kernel with int4 and int8 quant fuzed.
 
@@ -181,88 +318,82 @@ def _fill_kv_cache_quant_kernel(
     stride_xh: stride of head_num dim
     stride_xd: stride of head_size dim
     """
-    batch_id = tl.program_id(0)
+    batch_id = tl.program_id(2)
+    head_id = tl.program_id(0)
     block_id = tl.program_id(1)
-    d_off = tl.arange(0, BLOCK_D)
-
-    # initialize
-    h_off = tl.arange(0, BLOCK_H)
-    szd_off = tl.arange(0, 2)
 
     q_startloc = tl.load(QStartLoc + batch_id)
     q_seqlen = tl.load(QSeqLens + batch_id)
     kv_seqlen = tl.load(KVSeqLens + batch_id)
     history_seqlen = kv_seqlen - q_seqlen
 
-    block0_first_tokenloc = history_seqlen % BLOCK
+    kv_block_id = history_seqlen // BLOCK + block_id
+
+    if kv_seqlen <= 0:
+        return
+
+    if kv_block_id * BLOCK >= kv_seqlen:
+        return
+
+    if is_decoding:
+        page_offs = tl.full((1, ), history_seqlen % BLOCK, dtype=tl.int32)
+        kv_mask = tl.full((1, ), 1, dtype=tl.int1)
+        q_offs = tl.full((1, ), q_startloc, dtype=tl.int32)
+    else:
+        page_offs = tl.arange(0, BLOCK)
+        kv_offs = kv_block_id * BLOCK + page_offs
+        kv_mask = (kv_offs >= history_seqlen) & (kv_offs < kv_seqlen)
+        token_off = q_startloc + kv_block_id * BLOCK - history_seqlen
+        q_offs = token_off + page_offs
 
-    state_token_offset = tl.maximum(block_id * BLOCK - block0_first_tokenloc, 0)
-    kv_block_id = _div_up(history_seqlen + 1, BLOCK) - 1 + block_id
-    kv_block_id = min(kv_block_id, stride_boff - 1)
     block_off = tl.load(BlockOffsets + batch_id * stride_boff + kv_block_id)
 
-    cur_startloc = q_startloc + state_token_offset
-    ks_ptr = KStates + cur_startloc * stride_kss
-    vs_ptr = VStates + cur_startloc * stride_vss
-
-    kc_ptr = KCaches + block_off * stride_kcn
-    vc_ptr = VCaches + block_off * stride_vcn
-
-    ksz_ptr = KScalesZeros + block_off * stride_kszn
-    vsz_ptr = VScalesZeros + block_off * stride_vszn
-
-    c_first_tokenloc = block0_first_tokenloc
-    if block_id != 0:
-        c_first_tokenloc *= 0
-    c_last_tokenloc = tl.minimum(BLOCK, q_seqlen + block0_first_tokenloc - block_id * BLOCK)
-
-    for bidx in range(c_first_tokenloc, c_last_tokenloc):
-        sidx = bidx - c_first_tokenloc
-        mask = (h_off[:, None] < num_heads) & (d_off[None, :] < head_dim)
-        if quant_policy == 4:
-            k1 = tl.load(ks_ptr + sidx * stride_kss + h_off[:, None] * stride_ksh + d_off[None, :] * stride_ksd,
-                         mask=mask)
-            k2 = tl.load(ks_ptr + sidx * stride_kss + h_off[:, None] * stride_ksh + d_off[None, :] * stride_ksd +
-                         head_dim * stride_ksd,
-                         mask=mask)
-            q_k, k_scales, k_zeros = _quant_int4(k1, k2)
-        else:
-            k = tl.load(ks_ptr + sidx * stride_kss + h_off[:, None] * stride_ksh + d_off[None, :] * stride_ksd,
-                        mask=mask)
-            q_k, k_scales, k_zeros = _quant_int8(k)
-        tl.store(kc_ptr + bidx * stride_kcb + h_off[:, None] * stride_kch + d_off[None, :] * stride_kcd, q_k, mask=mask)
-        tl.store(ksz_ptr + bidx * stride_kszb + h_off[:, None] * stride_kszh + szd_off[None, :] * stride_kszd,
-                 k_scales[:, None],
-                 mask=(h_off[:, None] < num_heads) & (szd_off[None, :] < 1))
-        tl.store(ksz_ptr + bidx * stride_kszb + h_off[:, None] * stride_kszh + szd_off[None, :] * stride_kszd,
-                 k_zeros[:, None],
-                 mask=(h_off[:, None] < num_heads) & (szd_off[None, :] == 1))
-
-        if BLOCK_DV > 0:
-            if quant_policy == 4:
-                dv_off = tl.arange(0, BLOCK_DV // 2)  # int4 pack, half the head_dim
-                maskv = (h_off[:, None] < num_heads) & (dv_off[None, :] < head_dim_v // 2)
-                v1 = tl.load(vs_ptr + sidx * stride_vss + h_off[:, None] * stride_vsh + dv_off[None, :] * stride_vsd,
-                             mask=maskv)
-                v2 = tl.load(vs_ptr + sidx * stride_vss + h_off[:, None] * stride_vsh + dv_off[None, :] * stride_vsd +
-                             head_dim_v // 2 * stride_vsd,
-                             mask=maskv)
-                q_v, v_scales, v_zeros = _quant_int4(v1, v2)
-            else:
-                dv_off = tl.arange(0, BLOCK_DV)
-                maskv = (h_off[:, None] < num_heads) & (dv_off[None, :] < head_dim_v)
-                v = tl.load(vs_ptr + sidx * stride_vss + h_off[:, None] * stride_vsh + dv_off[None, :] * stride_vsd,
-                            mask=maskv)
-                q_v, v_scales, v_zeros = _quant_int8(v)
-            tl.store(vc_ptr + bidx * stride_vcb + h_off[:, None] * stride_vch + dv_off[None, :] * stride_vcd,
-                     q_v,
-                     mask=maskv)
-            tl.store(vsz_ptr + bidx * stride_vszb + h_off[:, None] * stride_vszh + szd_off[None, :] * stride_vszd,
-                     v_scales[:, None],
-                     mask=(h_off[:, None] < num_heads) & (szd_off[None, :] < 1))
-            tl.store(vsz_ptr + bidx * stride_vszb + h_off[:, None] * stride_vszh + szd_off[None, :] * stride_vszd,
-                     v_zeros[:, None],
-                     mask=(h_off[:, None] < num_heads) & (szd_off[None, :] == 1))
+    _fill_page_quant(KStates,
+                     KCaches,
+                     KScalesZeros,
+                     block_off,
+                     head_id,
+                     page_offs,
+                     q_offs,
+                     kv_mask,
+                     head_dim=head_dim,
+                     stride_ss=stride_kss,
+                     stride_sh=stride_ksh,
+                     stride_sd=stride_ksd,
+                     stride_cn=stride_kcn,
+                     stride_cb=stride_kcb,
+                     stride_ch=stride_kch,
+                     stride_cd=stride_kcd,
+                     stride_szn=stride_kszn,
+                     stride_szb=stride_kszb,
+                     stride_szh=stride_kszh,
+                     stride_szd=stride_kszd,
+                     BLOCK_D=BLOCK_D,
+                     quant_policy=quant_policy)
+
+    if BLOCK_DV > 0:
+        _fill_page_quant(VStates,
+                         VCaches,
+                         VScalesZeros,
+                         block_off,
+                         head_id,
+                         page_offs,
+                         q_offs,
+                         kv_mask,
+                         head_dim=head_dim_v,
+                         stride_ss=stride_vss,
+                         stride_sh=stride_vsh,
+                         stride_sd=stride_vsd,
+                         stride_cn=stride_vcn,
+                         stride_cb=stride_vcb,
+                         stride_ch=stride_vch,
+                         stride_cd=stride_vcd,
+                         stride_szn=stride_vszn,
+                         stride_szb=stride_vszb,
+                         stride_szh=stride_vszh,
+                         stride_szd=stride_vszd,
+                         BLOCK_D=BLOCK_DV,
+                         quant_policy=quant_policy)
 
 
 def fill_kv_cache(k_states: Tensor,
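
With the grid laid out as (head, page block, batch), each program in the hunk above derives which cache page it owns and which new tokens fall inside it. A small pure-Python sketch of the prefill-path index math (same variable meanings, scalar lists instead of tl ops), with one worked example:

def page_fill_indices(q_startloc, q_seqlen, kv_seqlen, block_id, BLOCK):
    # Mirrors the prefill branch of _fill_kv_cache_quant_kernel.
    history_seqlen = kv_seqlen - q_seqlen
    kv_block_id = history_seqlen // BLOCK + block_id
    if kv_seqlen <= 0 or kv_block_id * BLOCK >= kv_seqlen:
        return None  # nothing for this program to fill
    page_offs = range(BLOCK)
    kv_offs = [kv_block_id * BLOCK + p for p in page_offs]
    kv_mask = [history_seqlen <= o < kv_seqlen for o in kv_offs]
    token_off = q_startloc + kv_block_id * BLOCK - history_seqlen
    q_offs = [token_off + p for p in page_offs]
    return kv_block_id, kv_mask, q_offs

# Example: 100 cached tokens, 30 new tokens, BLOCK=64, q_startloc=0.
# block_id=0 -> logical page 1, slots 36..63 receive query tokens 0..27;
# block_id=1 -> logical page 2, slots 0..1 receive query tokens 28..29.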
@@ -291,21 +422,22 @@ def fill_kv_cache(k_states: Tensor,
     block_size = k_caches.size(s_dim)
     num_heads = k_caches.size(h_dim)
     head_dim = k_caches.size(d_dim)
-    head_dim_v = v_states.size(-1)
+    head_dim_v = v_caches.size(d_dim)
+    if v_states.size(-1) == 0:
+        head_dim_v = 0
     if max_q_seq_length == 1:
         max_num_blocks = 1
     else:
         max_num_blocks = triton.cdiv(max_q_seq_length, block_size) + 1
 
     BLOCK = block_size
-    BLOCK_H = triton.next_power_of_2(num_heads)
     BLOCK_D = triton.next_power_of_2(head_dim)
     BLOCK_DV = triton.next_power_of_2(head_dim_v)
     if k_caches.data_ptr() == v_caches.data_ptr() and head_dim_v <= head_dim:
         BLOCK_DV = 0
+    grid = (num_heads, max_num_blocks, batch_size)
+    is_decoding = max_num_blocks == 1
     if quant_policy == 0:
-        grid = (num_heads, max_num_blocks, batch_size)
-        is_decoding = max_num_blocks == 1
         _fill_kv_cache_kernel[grid](
             k_states,
             v_states,
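
Host side, both the plain and the quantized kernels now share the same (num_heads, max_num_blocks, batch_size) grid, and is_decoding collapses the page loop to a single slot. With hypothetical sizes, purely for illustration, the sizing from the hunk above works out as:

import triton

# Hypothetical values; the real ones come from the caches and sequence lengths.
num_heads, block_size, max_q_seq_length, batch_size = 32, 64, 300, 8

if max_q_seq_length == 1:  # decoding: each request appends exactly one token
    max_num_blocks = 1
else:                      # prefill: a chunk can straddle one extra page boundary
    max_num_blocks = triton.cdiv(max_q_seq_length, block_size) + 1

grid = (num_heads, max_num_blocks, batch_size)  # (32, 6, 8) -> 1536 programs
is_decoding = max_num_blocks == 1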
@@ -340,7 +472,6 @@ def fill_kv_cache(k_states: Tensor,
             num_stages=3,
         )
     else:
-        grid = (batch_size, max_num_blocks)
         _fill_kv_cache_quant_kernel[grid](
             k_states,
             v_states,
@@ -352,7 +483,7 @@ def fill_kv_cache(k_states: Tensor,
             q_seq_length,
             kv_seq_length,
             block_offsets,
-            num_heads=num_heads,
+            is_decoding=is_decoding,
             head_dim=head_dim,
             head_dim_v=head_dim_v,
             stride_kss=k_states.stride(-3),
@@ -382,7 +513,6 @@ def fill_kv_cache(k_states: Tensor,
             BLOCK=BLOCK,
             BLOCK_D=BLOCK_D,
             BLOCK_DV=BLOCK_DV,
-            BLOCK_H=BLOCK_H,
             num_warps=4,
-            num_stages=3,
+            num_stages=1,
         )
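
For reference, the strides passed above are consistent with paged caches shaped roughly as sketched below. This is only an assumed layout for illustration; the actual allocation (dtypes, head_dim_v, device placement) is decided elsewhere in lmdeploy and may differ:

import torch

# Assumed example sizes, not taken from the commit.
num_blocks, block_size, num_heads, head_dim = 128, 64, 32, 128
quant_policy = 8  # with quant_policy == 4, two int4 values pack into one byte

cache_dim = head_dim if quant_policy == 8 else head_dim // 2
k_caches = torch.empty(num_blocks, block_size, num_heads, cache_dim, dtype=torch.uint8)
v_caches = torch.empty_like(k_caches)
# One (scale, zero) pair per token and head; the zero lives stride_szd after the scale.
k_scales_zeros = torch.empty(num_blocks, block_size, num_heads, 2, dtype=torch.float16)
v_scales_zeros = torch.empty_like(k_scales_zeros)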
