diff --git a/tests/ref_gdr.py b/tests/ref_gdr.py index fde5335..b6f7c7c 100644 --- a/tests/ref_gdr.py +++ b/tests/ref_gdr.py @@ -203,7 +203,8 @@ def torch_chunk_gdr_fwd( if cu_seqlens is not None: vn = pack(vn, cu_seqlens) - h = pack(h, prepare_chunk_offsets(cu_seqlens, chunk_size)) + chunk_offsets, _ = prepare_chunk_offsets(cu_seqlens, chunk_size) + h = pack(h, chunk_offsets) return h, vn, last_state @@ -223,7 +224,8 @@ def torch_chunk_o_fwd( k = unpack(k, cu_seqlens) v = unpack(v, cu_seqlens) g = unpack(g, cu_seqlens) - h = unpack(h, prepare_chunk_offsets(cu_seqlens, chunk_size)) + chunk_offsets, _ = prepare_chunk_offsets(cu_seqlens, chunk_size) + h = unpack(h, chunk_offsets) batch_size, num_tokens, num_k_heads, head_dim_k = k.shape _, _, num_v_heads, head_dim_v = v.shape @@ -379,7 +381,8 @@ def torch_chunk_gdr_bwd( dv = dv.reshape((batch_size, -1, num_v_heads, head_dim_v))[:, :num_tokens] if cu_seqlens is not None: dv = pack(dv, cu_seqlens) - dh = pack(dh, prepare_chunk_offsets(cu_seqlens, chunk_size)) + chunk_offsets, _ = prepare_chunk_offsets(cu_seqlens, chunk_size) + dh = pack(dh, chunk_offsets) return dh, dh0, dv @@ -405,8 +408,9 @@ def torch_chunk_dqkwg_bwd( g = unpack(g, cu_seqlens) do = unpack(do, cu_seqlens) dv = unpack(dv, cu_seqlens) - h = unpack(h, prepare_chunk_offsets(cu_seqlens, chunk_size)) - dh = unpack(dh, prepare_chunk_offsets(cu_seqlens, chunk_size)) + chunk_offsets, _ = prepare_chunk_offsets(cu_seqlens, chunk_size) + h = unpack(h, chunk_offsets) + dh = unpack(dh, chunk_offsets) batch_size, num_tokens, num_k_heads, head_dim_k = k.shape _, _, num_v_heads, head_dim_v = do.shape