diff --git a/evals/harness.py b/evals/harness.py index 24739c9c0..e54620783 100644 --- a/evals/harness.py +++ b/evals/harness.py @@ -2,11 +2,12 @@ from __future__ import annotations -import fla # noqa from lm_eval.__main__ import cli_evaluate from lm_eval.api.registry import register_model from lm_eval.models.huggingface import HFLM +import fla # noqa + @register_model('fla') class FlashLinearAttentionLMWrapper(HFLM): diff --git a/fla/ops/gated_delta_product/chunk.py b/fla/ops/gated_delta_product/chunk.py index ef1ef89dd..4a968d49c 100644 --- a/fla/ops/gated_delta_product/chunk.py +++ b/fla/ops/gated_delta_product/chunk.py @@ -8,11 +8,9 @@ from fla.modules.l2norm import l2norm_bwd, l2norm_fwd from fla.ops.common.chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd -from fla.ops.delta_rule.chunk import chunk_delta_rule_bwd from fla.ops.delta_rule.wy_fast import recompute_w_u_fwd as dn_recompute_w_u_fwd from fla.ops.gated_delta_product.chunk_deltaproduct_h import chunk_gated_delta_product_fwd_h from fla.ops.gated_delta_product.chunk_deltaproduct_o import chunk_gated_delta_product_fwd_o -from fla.ops.gated_delta_rule.chunk import chunk_gated_delta_rule_bwd from fla.ops.gated_delta_rule.wy_fast import recompute_w_u_fwd as gdn_recompute_w_u_fwd from fla.ops.utils import chunk_local_cumsum, solve_tril from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard @@ -93,6 +91,224 @@ def chunk_gated_delta_product_fwd( return g, g_interleaved, o, A, final_state +def chunk_gated_delta_product_bwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + g_interleaved: torch.Tensor, + beta: torch.Tensor, + A: torch.Tensor, + h: torch.Tensor, + v_new: torch.Tensor, + do: torch.Tensor, + dht: torch.Tensor, + scale: float, + cu_seqlens: Optional[torch.LongTensor] = None, + initial_state: Optional[torch.Tensor] = None, + num_householder: int = 1, +): + # from fla.ops.gated_delta_product.chunk_deltaproduct_h import chunk_gated_delta_product_bwd_dhu + # from fla.ops.gated_delta_product.chunk_deltaproduct_o import chunk_gated_delta_product_bwd_o + + # cu_seqlens_dp = cu_seqlens * num_householder if cu_seqlens is not None else None + + # # compute gradients wrt q, k, v + # # dv_new_o is gradient w.r.t. 
v_new + # # might not be fully parallelizable due to gradient Q including previous states H_0 to H_T + # dq_o, dv_new_o, dk_o, dh = chunk_gated_delta_product_bwd_o( + # q=q, + # k=k, + # v=v_new, # v_new = U[i] - W[i]H[i]^T + # h=h, + # g=g, # forward_h takes in g instead of g_interleaved + # scale=scale, + # cu_seqlens=cu_seqlens, + # num_householder=num_householder, + # do=do, # gradient of the output + # ) + + # # recompute w, u from WY representation to compute gradient + # from fla.ops.gated_delta_rule.wy_fast import recompute_w_u_fwd as gdn_recompute_w_u_fwd + # from fla.ops.delta_rule.wy_fast import recompute_w_u_fwd as dn_recompute_w_u_fwd + + # if g_interleaved is not None: + # w, u = gdn_recompute_w_u_fwd( + # k=k, v=v, beta=beta, A=A, g=g_interleaved, cu_seqlens=cu_seqlens_dp, + # ) + # else: + # w, u = dn_recompute_w_u_fwd( + # k=k, v=v, beta=beta, A=A, cu_seqlens=cu_seqlens_dp, + # ) + + # # compute gradients with respect to u and w + # # but need to account for gradients used for sequential computation of hidden states of H_0 to H_T (sequential) + # dh0, du, dw = chunk_gated_delta_product_bwd_dhu( + # q=q, + # k=k, + # w=w, + # u=u, + # g=g_interleaved, + # initial_state=initial_state, #H_0 + # cu_seqlens=cu_seqlens_dp, + # num_householder=num_householder, + # dht=dht, # gradient w.r.t to last hidden state + # dv=dv_new_o, # gradient w.r.t. v_new + # scale=scale, + # ) + + # # compute gradients w.r.t. WY representation (dk, dv, dbeta, dg) + # # This involves computing gradients through the Householder transformations + # from fla.ops.gated_delta_rule.wy_fast import prepare_wy_repr_bwd + + # # gradient descent from W and U + # # g is used for computing W and U in the forward pass + # # dk2 accounts for the gradient of hidden state wrt to k + # # dv_final is the gradient of hidden state wrt to v + # # this can be fully parallelized + # dk2, dv, dbeta, dg = prepare_wy_repr_bwd( + # k=k, + # v=v, + # beta=beta, + # g=g_interleaved, + # A=A, + # dw=dw, # Use key gradients from output as weights gradients + # du=du, # Use value gradients from hidden state backward + # cu_seqlens=cu_seqlens_dp, + # ) + + # # accumulate gradients + # # dk_final = dk_o + dk2 + # # dk_final = dk2 # should there be (Q[i] K[i]^T \cdot M) + # dv_final = dv + # dg_final = dg + + # # process gating gradients with local cumsum (reverse) + # if g is not None: + # from fla.ops.utils import chunk_local_cumsum + # dg_final = chunk_local_cumsum(dg_final, chunk_size=64, reverse=True, cu_seqlens=cu_seqlens_dp) + + # # Convert interleaved gating gradients back to original format + # dg_final = rearrange(dg_final, 'b (l n) h -> b l n h', n=num_householder)[:, :, 0].contiguous() + # else: + # dg_final = None + + # return dq_o, dk, dv_final, dg_final, dbeta, dh0 + + from fla.ops.gated_delta_product.chunk_deltaproduct_h import chunk_gated_delta_product_bwd_dhu + + cu_seqlens_dp = cu_seqlens * num_householder if cu_seqlens is not None else None + + # recompute w, u from WY representation to compute gradient + from fla.ops.delta_rule.wy_fast import recompute_w_u_fwd as dn_recompute_w_u_fwd + from fla.ops.gated_delta_rule.wy_fast import recompute_w_u_fwd as gdn_recompute_w_u_fwd + + if g_interleaved is not None: + w, u = gdn_recompute_w_u_fwd( + k=k, v=v, beta=beta, A=A, g=g_interleaved, cu_seqlens=cu_seqlens_dp, + ) + else: + w, u = dn_recompute_w_u_fwd( + k=k, v=v, beta=beta, A=A, cu_seqlens=cu_seqlens_dp, + ) + + # dv_new is gradient w.r.t. 
v_new + # Eq: Q[i] @ H[i] + (Q[i]K[i]^T ⊙ M) * v_new[i] + # recurses from do to (Q[i]K[i]^T ⊙ M) * v_new[i] to compute dv_new + # fully parallelizable + from fla.ops.common.chunk_o import chunk_bwd_dv_local + + # is the same for chunk gated delta product and chunk gated delta rule + dv_new = chunk_bwd_dv_local( + q=q, + k=k, + g=g, # chunk_gated_delta_product_fwd_o uses g + do=do, + scale=scale, + cu_seqlens=cu_seqlens, + ) + + # du is the gradient wrt to u + # dh is [B, NT, H, K, V] (NT is number of chunks or groups of tokens) the gradient wrt to H[0] to H[T] + # need to account for gradients used for sequential computation of hidden states of H_0 to H_T (sequential) + dh, dh0, du = chunk_gated_delta_product_bwd_dhu( + q=q, + k=k, + w=w, + g=g_interleaved, # chunk_gated_delta_product_fwd_h uses g_interleaved + h0=initial_state, # H_0 + dht=dht, # gradient w.r.t to last hidden state + do=do, # gradient of the output + dv=dv_new, # gradient w.r.t. v_new + scale=scale, + cu_seqlens=cu_seqlens_dp, # use cu_seqlens_dp which is expanded + num_householder=num_householder, + ) + + from fla.ops.common.chunk_o import chunk_bwd_dqkwg + + # TODO implement delta product version of chunk_bwd_dqkwg + # dq and dw is final + # By multivariate chain rule, for L = f(O) and O = g(k, v_new(k)) + # dL/dK = dL/dO * dO/dK + dL/dO * dO/dv_new * dv_new/dK + # dk_direct_gradient is the direct gradient computed considering v_new as a fixed quantity (by product rule) + # dk_direct_gradient should be fully parallelizable + # might be sequential as dq depends on H (might need forward pass) + dq, dk_direct_gradient, dw, dg_local = chunk_bwd_dqkwg( + q=q, + k=k, + v=v_new, # v_new = U[i] - W[i]H[i]^T + w=w, + g=g, # should this be g or g_interleaved? since we don't find dk for the hidden states, is it g + h=h, + dv=du, # can be thought as gradient wrt v_new + do=do, + dh=dh, + scale=scale, + cu_seqlens=cu_seqlens_dp, # cu_seqlens * num_householder + ) + + # compute gradients w.r.t. 
WY representation (dk, dv, dbeta, dg) + # This involves computing gradients through the Householder transformations + from fla.ops.gated_delta_rule.wy_fast import prepare_wy_repr_bwd + + # compute gradient descent wrt W and U + # g_interleaved is used for computing W and U in the forward pass + # dv is the final gradient wrt to v (only place v appears) + # this should be fully parallelized + # TODO implement delta product version of prepare_wy_repr_bwd + dk_hidden_state_gradient, dv, dbeta, dg2 = prepare_wy_repr_bwd( + k=k, + v=v, + beta=beta, + g=g_interleaved, + A=A, + dw=dw, # Use key gradients from output as weights gradients + du=du, # Use value gradients from hidden state backward + cu_seqlens=cu_seqlens_dp, + ) + + # accumulate gradients + dk_direct_gradient.add_(dk_hidden_state_gradient) # dL/dK = dL/dO * dO/dK + dL/dO * dO/dv_new * dv_new/dK + # by product rule add the two together + dv_final = dv + dg_final = dg2 + if dg_local is not None: + dg_final = dg_final + dg_local # dL/dg = dL/dO * dO/dg + dL/dO * dO/dv_new * dv_new/dg = dg_local + dg2 + + # process gating gradients with local cumsum (reverse) + if g is not None: + from fla.ops.utils import chunk_local_cumsum + dg_final = chunk_local_cumsum(dg_final, chunk_size=64, reverse=True, cu_seqlens=cu_seqlens_dp) + + # Convert interleaved gating gradients back to original format + dg_final = rearrange(dg_final, 'b (l n) h -> b l n h', n=num_householder)[:, :, 0].contiguous() + else: + dg_final = None + + return dq, dk_direct_gradient, dv_final, dg_final, dbeta, dh0 + + class ChunkGatedDeltaProductFunction(torch.autograd.Function): @staticmethod @@ -144,48 +360,120 @@ def backward( do: torch.Tensor, dht: torch.Tensor ): + # q, q_rstd, k, k_rstd, v, g, beta, A, initial_state, cu_seqlens = ctx.saved_tensors + # q_new = q.new_zeros(q.shape[0], q.shape[1], ctx.num_householder, q.shape[2], q.shape[3]) + # q_new[:, :, -1] = q + # do_new = do.new_zeros(do.shape[0], do.shape[1], ctx.num_householder, do.shape[2], do.shape[3]) + # do_new[:, :, -1] = do + # q_org, q = q, rearrange(q_new, 'b t n h d -> b (t n) h d') + # do = rearrange(do_new, 'b t n h d -> b (t n) h d') + # # call the gated deltanet kernel for now. + # # TODO: optimize the backward pass like the forward pass. 
+ # if g is not None: + # dq, dk, dv, db, dg, dh0 = chunk_gated_delta_rule_bwd( + # q=q, + # k=k, + # v=v, + # g=g, + # beta=beta, + # A=A, + # scale=ctx.scale, + # initial_state=initial_state, + # do=do, + # dht=dht, + # cu_seqlens=cu_seqlens * ctx.num_householder if cu_seqlens is not None else None, + # ) + # dg = rearrange(dg, 'b (l n) h -> b l n h ', n=ctx.num_householder)[:, :, 0].contiguous().to(g) + # else: + # dq, dk, dv, db, dh0 = chunk_delta_rule_bwd( + # q=q, + # k=k, + # v=v, + # beta=beta, + # A=A, + # scale=ctx.scale, + # initial_state=initial_state, + # do=do, + # dht=dht, + # cu_seqlens=cu_seqlens * ctx.num_householder if cu_seqlens is not None else None, + # ) + # dg = None + # dq = rearrange(dq, 'b (l n) h d -> b l n h d', n=ctx.num_householder)[:, :, -1].contiguous() + # if ctx.use_qk_l2norm_in_kernel: + # dq = l2norm_bwd(q_org, q_rstd, dq) + # dk = l2norm_bwd(k, k_rstd, dk) + # return dq.to(q), dk.to(k), dv.to(v), dg, db.to(beta), None, None, dh0, None, None, None + + # New backward impelmentation q, q_rstd, k, k_rstd, v, g, beta, A, initial_state, cu_seqlens = ctx.saved_tensors - q_new = q.new_zeros(q.shape[0], q.shape[1], ctx.num_householder, q.shape[2], q.shape[3]) - q_new[:, :, -1] = q - do_new = do.new_zeros(do.shape[0], do.shape[1], ctx.num_householder, do.shape[2], do.shape[3]) - do_new[:, :, -1] = do - q_org, q = q, rearrange(q_new, 'b t n h d -> b (t n) h d') - do = rearrange(do_new, 'b t n h d -> b (t n) h d') - # call the gated deltanet kernel for now. - # TODO: optimize the backward pass like the forward pass. + + # recompute forward intermediate values + from fla.ops.delta_rule.wy_fast import recompute_w_u_fwd as dn_recompute_w_u_fwd + from fla.ops.gated_delta_product.chunk_deltaproduct_h import chunk_gated_delta_product_fwd_h + from fla.ops.gated_delta_rule.wy_fast import recompute_w_u_fwd as gdn_recompute_w_u_fwd + + cu_seqlens_dp = cu_seqlens * ctx.num_householder if cu_seqlens is not None else None + + # recompute w, u from WY representation (gated and non-gated) if g is not None: - dq, dk, dv, db, dg, dh0 = chunk_gated_delta_rule_bwd( - q=q, - k=k, - v=v, - g=g, - beta=beta, - A=A, - scale=ctx.scale, - initial_state=initial_state, - do=do, - dht=dht, - cu_seqlens=cu_seqlens * ctx.num_householder if cu_seqlens is not None else None, + w, u = gdn_recompute_w_u_fwd( + k=k, v=v, beta=beta, A=A, g=g, cu_seqlens=cu_seqlens_dp, ) - dg = rearrange(dg, 'b (l n) h -> b l n h ', n=ctx.num_householder)[:, :, 0].contiguous().to(g) else: - dq, dk, dv, db, dh0 = chunk_delta_rule_bwd( - q=q, - k=k, - v=v, - beta=beta, - A=A, - scale=ctx.scale, - initial_state=initial_state, - do=do, - dht=dht, - cu_seqlens=cu_seqlens * ctx.num_householder if cu_seqlens is not None else None, + w, u = dn_recompute_w_u_fwd( + k=k, v=v, beta=beta, A=A, cu_seqlens=cu_seqlens_dp, + ) + + # recompute h and v_new from hidden state computation (why can't we store this originally in ctx) + # v_new = (U[t] - W[t]H[t]^T) + h, v_new, _ = chunk_gated_delta_product_fwd_h( + k=k, + w=w, + u=u, + g=g, + initial_state=initial_state, + output_final_state=False, + cu_seqlens=cu_seqlens_dp, + num_householder=ctx.num_householder, + ) + + # need to get g and g_interleaved + if g is not None: + g_interleaved = g.new_zeros(g.shape[0], g.shape[1], ctx.num_householder, g.shape[2], dtype=torch.float32) + g_interleaved[:, :, 0] = g + g_interleaved = rearrange(g_interleaved, 'b l n h -> b (l n) h').contiguous() + g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens, output_dtype=torch.float32) + 
g_interleaved = chunk_local_cumsum( + g_interleaved, chunk_size=64, cu_seqlens=cu_seqlens_dp, output_dtype=torch.float32 ) - dg = None - dq = rearrange(dq, 'b (l n) h d -> b l n h d', n=ctx.num_householder)[:, :, -1].contiguous() + else: + g_interleaved = None + g = None + + # call our optimized backward pass + dq, dk, dv, dg, db, dh0 = chunk_gated_delta_product_bwd( + q=q, + k=k, + v=v, + g=g, # use g + g_interleaved=g_interleaved, # use computed g_interleaved + beta=beta, + A=A, + h=h, + v_new=v_new, + do=do, + dht=dht, + scale=ctx.scale, + cu_seqlens=cu_seqlens, + initial_state=initial_state, + num_householder=ctx.num_householder, + ) + + # if use_qk_l2norm_in_kernel, do l2norm_bwd (calculate gradient for l2norm) if ctx.use_qk_l2norm_in_kernel: - dq = l2norm_bwd(q_org, q_rstd, dq) + dq = l2norm_bwd(q, q_rstd, dq) dk = l2norm_bwd(k, k_rstd, dk) + return dq.to(q), dk.to(k), dv.to(v), dg, db.to(beta), None, None, dh0, None, None, None diff --git a/fla/ops/gated_delta_product/chunk_deltaproduct_h.py b/fla/ops/gated_delta_product/chunk_deltaproduct_h.py index 37f7ca068..5f1b23f04 100644 --- a/fla/ops/gated_delta_product/chunk_deltaproduct_h.py +++ b/fla/ops/gated_delta_product/chunk_deltaproduct_h.py @@ -225,6 +225,7 @@ def chunk_gated_delta_product_bwd_kernel_dhu_blockdim64( chunk_offsets, scale, T, + num_householder: tl.constexpr, H: tl.constexpr, K: tl.constexpr, V: tl.constexpr, @@ -237,6 +238,7 @@ def chunk_gated_delta_product_bwd_kernel_dhu_blockdim64( ): i_v, i_nh = tl.program_id(0), tl.program_id(1) i_n, i_h = i_nh // H, i_nh % H + if IS_VARLEN: bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) T = eos - bos @@ -245,7 +247,9 @@ def chunk_gated_delta_product_bwd_kernel_dhu_blockdim64( else: bos, eos = i_n * T, i_n * T + T NT = tl.cdiv(T, BT) - boh = i_n * NT + # boh = i_n * tl.cdiv(T, BT) + # Jinha: update boh to match the chunk_gated_delta_product_fwd_kernel_h_blockdim64 implementation + boh = i_n * tl.cdiv(T // num_householder, BT) # [BK, BV] b_dh1 = tl.zeros([64, BV], dtype=tl.float32) @@ -312,13 +316,13 @@ def chunk_gated_delta_product_bwd_kernel_dhu_blockdim64( b_g_exp = None p_dv = tl.make_block_ptr(dv, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) - p_wo = tl.make_block_ptr(do, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_do = tl.make_block_ptr(do, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) p_dv2 = tl.make_block_ptr(dv2, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) - b_wo = tl.load(p_wo, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) b_dv = tl.zeros([BT, BV], dtype=tl.float32) - # Update dv + # Update dv based on hidden state gradients p_k = tl.make_block_ptr(k, (T, K), (stride_k, 1), (i_t * BT, 0), (BT, 64), (1, 0)) b_k = tl.load(p_k, boundary_check=(0, 1)) b_dv += tl.dot(b_k, b_dh1.to(b_k.dtype)) @@ -344,7 +348,8 @@ def chunk_gated_delta_product_bwd_kernel_dhu_blockdim64( b_dv += tl.load(p_dv, boundary_check=(0, 1)) tl.store(p_dv2, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) - # Update dh + + # Update hidden state gradients p_w = tl.make_block_ptr(w, (K, T), (1, stride_k), (0, i_t * BT), (64, BT), (0, 1)) p_q = tl.make_block_ptr(q, (K, T), (1, stride_k), (0, i_t * BT), (64, BT), (0, 1)) b_w = tl.load(p_w, boundary_check=(0, 1)) @@ -353,7 +358,8 @@ def chunk_gated_delta_product_bwd_kernel_dhu_blockdim64( b_dh1 *= bg_last_exp b_q = b_q * b_g_exp[None, :] b_q = (b_q * scale).to(b_q.dtype) - b_dh1 += tl.dot(b_q, 
b_wo.to(b_q.dtype))-tl.dot(b_w, b_dv.to(b_w.dtype)) + b_dh1 += tl.dot(b_q, b_do.to(b_q.dtype)) - tl.dot(b_w, b_dv.to(b_w.dtype)) + if K > 64: p_q = tl.make_block_ptr(q, (K, T), (1, stride_k), (64, i_t * BT), (64, BT), (0, 1)) p_w = tl.make_block_ptr(w, (K, T), (1, stride_k), (64, i_t * BT), (64, BT), (0, 1)) @@ -363,7 +369,8 @@ def chunk_gated_delta_product_bwd_kernel_dhu_blockdim64( b_dh2 *= bg_last_exp b_q = b_q * b_g_exp[None, :] b_q = (b_q * scale).to(b_q.dtype) - b_dh2 += tl.dot(b_q, b_wo.to(b_q.dtype))-tl.dot(b_w, b_dv.to(b_w.dtype)) + b_dh2 += tl.dot(b_q, b_do.to(b_q.dtype)) - tl.dot(b_w, b_dv.to(b_w.dtype)) + if K > 128: p_q = tl.make_block_ptr(q, (K, T), (1, stride_k), (128, i_t * BT), (64, BT), (0, 1)) p_w = tl.make_block_ptr(w, (K, T), (1, stride_k), (128, i_t * BT), (64, BT), (0, 1)) @@ -373,7 +380,8 @@ def chunk_gated_delta_product_bwd_kernel_dhu_blockdim64( b_dh3 *= bg_last_exp b_q = b_q * b_g_exp[None, :] b_q = (b_q * scale).to(b_q.dtype) - b_dh3 += tl.dot(b_q, b_wo.to(b_q.dtype))-tl.dot(b_w, b_dv.to(b_w.dtype)) + b_dh3 += tl.dot(b_q, b_do.to(b_q.dtype)) - tl.dot(b_w, b_dv.to(b_w.dtype)) + if K > 192: p_q = tl.make_block_ptr(q, (K, T), (1, stride_k), (192, i_t * BT), (64, BT), (0, 1)) p_w = tl.make_block_ptr(w, (K, T), (1, stride_k), (192, i_t * BT), (64, BT), (0, 1)) @@ -383,7 +391,7 @@ def chunk_gated_delta_product_bwd_kernel_dhu_blockdim64( b_dh4 *= bg_last_exp b_q = b_q * b_g_exp[None, :] b_q = (b_q * scale).to(b_q.dtype) - b_dh4 += tl.dot(b_q, b_wo.to(b_q.dtype))-tl.dot(b_w, b_dv.to(b_w.dtype)) + b_dh4 += tl.dot(b_q, b_do.to(b_q.dtype)) - tl.dot(b_w, b_dv.to(b_w.dtype)) if USE_INITIAL_STATE: p_dh0 = tl.make_block_ptr(dh0, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0)) @@ -460,19 +468,25 @@ def chunk_gated_delta_product_bwd_dhu( dv: torch.Tensor, scale: float, cu_seqlens: Optional[torch.LongTensor] = None, - chunk_size: int = 64, # SY: remove this argument and force chunk size 64? + chunk_size: int = 64, + num_householder: int = 1, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: B, T, H, K, V = *q.shape, do.shape[-1] + assert T % num_householder == 0, "T must be divisible by num_householder" + T_true = T // num_householder # N: the actual number of sequences in the batch with either equal or variable lengths BT = 64 assert K <= 256, "current kernel does not support head dimension being larger than 256." 
- chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) if cu_seqlens is not None else None + chunk_indices = prepare_chunk_indices(cu_seqlens // num_householder, chunk_size) if cu_seqlens is not None else None if cu_seqlens is None: - N, NT, chunk_offsets = B, triton.cdiv(T, BT), None + N, NT, chunk_offsets = B, triton.cdiv(T_true, BT), None else: - N, NT, chunk_offsets = len(cu_seqlens) - 1, len(chunk_indices), prepare_chunk_offsets(cu_seqlens, BT) + N, NT, chunk_offsets = ( + len(cu_seqlens) - 1, len(chunk_indices), + prepare_chunk_offsets(cu_seqlens // num_householder, BT) + ) dh = q.new_empty(B, NT, H, K, V) dh0 = torch.empty_like(h0, dtype=torch.float32) if h0 is not None else None @@ -494,9 +508,12 @@ def grid(meta): return (triton.cdiv(V, meta['BV']), N*H) chunk_offsets=chunk_offsets, scale=scale, T=T, + num_householder=num_householder, H=H, K=K, V=V, BT=BT, ) + # could call chunk_gated_delta_rule_bwd_kernel_dhu_blockdim64 instead + # after adjusting number of tokens return dh, dh0, dv2 diff --git a/fla/ops/gated_delta_product/chunk_deltaproduct_o.py b/fla/ops/gated_delta_product/chunk_deltaproduct_o.py index 665db10fe..d83dbe8ce 100644 --- a/fla/ops/gated_delta_product/chunk_deltaproduct_o.py +++ b/fla/ops/gated_delta_product/chunk_deltaproduct_o.py @@ -152,3 +152,212 @@ def grid(meta): return (triton.cdiv(V, meta['BV']), NT, B * H) BT=BT, ) return o + +# @triton.heuristics({ +# 'USE_G': lambda args: args['g'] is not None, +# 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None +# }) +# @triton.autotune( +# configs=[ +# triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages) +# for BK in BKV_LIST +# for BV in BKV_LIST +# for num_warps in NUM_WARPS +# for num_stages in [2, 3, 4] +# ], +# key=['H', 'K', 'V', 'BT'], +# ) +# @triton.jit(do_not_specialize=['T']) +# def chunk_gated_delta_product_bwd_kernel_o( +# q, +# k, +# v, +# h, +# g, +# do, +# dq, +# dk, +# dv, +# dh, +# cu_seqlens, +# chunk_indices, +# scale, +# T, +# num_householder: tl.constexpr, +# H: tl.constexpr, +# K: tl.constexpr, +# V: tl.constexpr, +# BT: tl.constexpr, +# BK: tl.constexpr, +# BV: tl.constexpr, +# USE_G: tl.constexpr, +# IS_VARLEN: tl.constexpr, +# ): +# # same parameters as forward pass +# i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) +# i_b, i_h = i_bh // H, i_bh % H + +# if IS_VARLEN: +# i_tg = i_t +# i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) +# bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) +# T = eos - bos +# NT = tl.cdiv(T, BT) +# else: +# NT = tl.cdiv(T, BT) +# i_tg = i_b * NT + i_t +# bos, eos = i_b * T, i_b * T + T + +# # offset calculation +# q += (bos * H + i_h) * K +# k += (bos * H + i_h) * K +# v += (bos * H + i_h) * V +# do += (bos * H + i_h) * V +# dq += (bos * H + i_h) * K +# dk += (bos * H + i_h) * K +# dv += (bos * H + i_h) * V +# h += (i_tg * H + i_h).to(tl.int64) * K*V +# dh += (i_tg * H + i_h).to(tl.int64) * K*V + +# b_dq = tl.zeros([BT, BK], dtype=tl.float32) +# b_dk = tl.zeros([BT, BK], dtype=tl.float32) +# b_dv = tl.zeros([BT, BV], dtype=tl.float32) +# b_ds = tl.zeros([BT, BT], dtype=tl.float32) + +# # Compute gradients from hidden state +# for i_k in range(tl.cdiv(K, BK)): +# p_q = tl.make_block_ptr(q, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) +# p_h = tl.make_block_ptr(h, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) +# p_dh = tl.make_block_ptr(dh, (K, V), (V, 1), (i_k * BK, i_v * BV), 
(BK, BV), (1, 0)) +# p_do = tl.make_block_ptr(do, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + +# # [BT, BK] +# b_q = tl.load(p_q, boundary_check=(0, 1)) +# # [BK, BV] +# b_h = tl.load(p_h, boundary_check=(0, 1)) +# b_dh = tl.load(p_dh, boundary_check=(0, 1)) +# # [BT, BV] +# b_do = tl.load(p_do, boundary_check=(0, 1)) + +# # Compute gradients w.r.t. q: dq += do @ h^T +# b_dq += tl.dot(b_do, tl.trans(b_h)) + +# # Compute gradients w.r.t. h: dh += q^T @ do +# tl.store(p_dh, (b_dh + tl.dot(tl.trans(b_q), b_do)).to(p_dh.dtype.element_ty), boundary_check=(0, 1)) + +# # Process multiple Householder transformations +# for i_dp in range(num_householder): +# b_A = tl.zeros([BT, BT], dtype=tl.float32) + +# # Compute attention matrix A = Q @ K^T for this Householder step +# for i_k in range(tl.cdiv(K, BK)): +# p_q = tl.make_block_ptr(q, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) +# p_k = tl.make_block_ptr(k+i_dp*H*K, (K, T), (1, num_householder*H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + +# b_q = tl.load(p_q, boundary_check=(0, 1)) +# b_k = tl.load(p_k, boundary_check=(0, 1)) +# b_A += tl.dot(b_q, b_k) + +# # Apply causal mask and gating +# o_t = i_t * BT + tl.arange(0, BT) +# m_t = o_t < T +# if USE_G: +# g += bos * H + i_h +# p_g = tl.make_block_ptr(g, (T,), (H,), (i_t * BT,), (BT,), (0,)) +# b_g = tl.load(p_g, boundary_check=(0,)) +# m_A = (o_t[:, None] >= o_t[None, :]) & (m_t[:, None] & m_t) +# b_A = tl.where(m_A, b_A * exp(b_g[:, None] - b_g[None, :]), 0) +# else: +# m_A = (o_t[:, None] >= o_t[None, :]) & (m_t[:, None] & m_t) +# b_A = tl.where(m_A, b_A, 0) + +# # Load values for this Householder step +# p_v = tl.make_block_ptr(v+i_dp*H*V, (T, V), (H*V*num_householder, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) +# p_do = tl.make_block_ptr(do, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) +# b_v = tl.load(p_v, boundary_check=(0, 1)) +# b_do = tl.load(p_do, boundary_check=(0, 1)) + +# # Gradient w.r.t. values: dv += A^T @ do +# b_dv += tl.dot(tl.trans(b_A.to(b_v.dtype)), b_do) + +# # Gradient w.r.t. 
attention scores: ds = do @ v^T +# b_ds += tl.dot(b_do, tl.trans(b_v)) + +# # Apply scale and gating to score gradients +# b_ds = b_ds * scale +# if USE_G: +# b_ds = tl.where(m_A, b_ds * exp(b_g[:, None] - b_g[None, :]), 0) +# else: +# b_ds = tl.where(m_A, b_ds, 0) + +# # Compute final gradients for each Householder step +# for i_dp in range(num_householder): +# for i_k in range(tl.cdiv(K, BK)): +# p_q = tl.make_block_ptr(q, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) +# p_k = tl.make_block_ptr(k+i_dp*H*K, (K, T), (1, num_householder*H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) +# p_dq = tl.make_block_ptr(dq, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) +# p_dk = tl.make_block_ptr(dk+i_dp*H*K, (K, T), (1, num_householder*H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + +# b_q = tl.load(p_q, boundary_check=(0, 1)) +# b_k = tl.load(p_k, boundary_check=(0, 1)) + +# # dq += ds @ k^T +# b_dq += tl.dot(b_ds, tl.trans(b_k)) +# # dk += q^T @ ds +# b_dk = tl.dot(tl.trans(b_q), b_ds) + +# tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) +# tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + +# # Store value gradients +# for i_dp in range(num_householder): +# p_dv = tl.make_block_ptr(dv+i_dp*H*V, (T, V), (H*V*num_householder, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) +# tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + + +# def chunk_gated_delta_product_bwd_o( +# q: torch.Tensor, +# k: torch.Tensor, +# v: torch.Tensor, +# h: torch.Tensor, +# g: Optional[torch.Tensor] = None, +# do: torch.Tensor = None, +# scale: Optional[float] = None, +# cu_seqlens: Optional[torch.LongTensor] = None, +# chunk_size: int = 64, +# num_householder: int = 1, +# ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: +# assert q.shape[1] * num_householder == k.shape[1], "q.shape[1] * num_householder must be equal to k.shape[1]" +# B, T, H, K, V = *q.shape, v.shape[-1] +# BT = min(chunk_size, max(16, triton.next_power_of_2(T))) +# chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None +# NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + +# dq = torch.zeros_like(q) +# dk = torch.zeros_like(k) +# dv = torch.zeros_like(v) +# dh = torch.zeros_like(h) + +# def grid(meta): return (triton.cdiv(V, meta['BV']), NT, B * H) +# chunk_gated_delta_product_bwd_kernel_o[grid]( +# q, +# k, +# v, +# h, +# g, +# do, +# dq, +# dk, +# dv, +# dh, +# cu_seqlens, +# chunk_indices, +# scale, +# T=T, +# num_householder=num_householder, +# H=H, +# K=K, +# V=V, +# BT=BT, +# ) +# return dq, dk, dv, dh diff --git a/fla/ops/simple_gla/README.md b/fla/ops/simple_gla/README.md index 2a64f3dcd..c359ced5e 100644 --- a/fla/ops/simple_gla/README.md +++ b/fla/ops/simple_gla/README.md @@ -1,10 +1,10 @@ # Simple GLA -Gating mechanism in [Gated RFA](https://arxiv.org/abs/2103.02143), [Mamba2](https://arxiv.org/abs/2405.21060) and [YOCO](https://arxiv.org/abs/2405.05254) (a.k.a., Gated RetNet). +Gating mechanism in [Gated RFA](https://arxiv.org/abs/2103.02143), [Mamba2](https://arxiv.org/abs/2405.21060) and [YOCO](https://arxiv.org/abs/2405.05254) (a.k.a., Gated RetNet). -Compared to GLA, the gating is head-wise instead of elementwise. -As a result, we can adapt the RetNet kernel for training using matmul w/o numerical instability. -It is faster than GLA but has less expressive power. +Compared to GLA, the gating is head-wise instead of elementwise. 
+As a result, we can adapt the RetNet kernel for training using matmul w/o numerical instability. +It is faster than GLA but has less expressive power. I will use it as a baseline for the GLA. $S_{t+1} = g_{t+1} \odot S_{t} + K_{t+1} V_{t+1}^{\top}$ where $g$ is a scalar. diff --git a/legacy/training/README.md b/legacy/training/README.md index 13bbc1730..9a21de234 100644 --- a/legacy/training/README.md +++ b/legacy/training/README.md @@ -7,14 +7,14 @@ > [!IMPORTANT] > The `flame` project has been migrated to a new project built on torchtitan. > Please visit the [new repository](https://github.com/fla-org/flame) for details and updates. -> +> > The code here is now **archived as legacy**, and no future updates will be synchronized here. A minimal framework for training FLA models, whether from scratch or through finetuning. Built on the robust infrastructure of 🤗, `flame` enables you to train large language models with just a few lines of code: we use `datasets` for data processing, `transformers` for model definitions, and `accelerate`[^1] for seamless distributed training. - + In this README, we will guide you through the process of using `flame` to train GLA models. ## Setup @@ -25,7 +25,7 @@ Clone the `fla` repository and install the necessary packages as follows: ```bash git clone https://github.com/sustcsonglin/flash-linear-attention.git -pip install . +pip install . pip install accelerate ``` @@ -35,8 +35,8 @@ pip install accelerate ## Preprocessing -Before training, you need to download and pre-tokenize your dataset. -We provide a straightforward script for this. +Before training, you need to download and pre-tokenize your dataset. +We provide a straightforward script for this. For instance, to tokenize a 10B sample of the `fineweb-edu` dataset, run: ```bash @@ -103,15 +103,15 @@ Other scheduler types like WSD (`warmup_stable_decay`)[^2] are also supported. The total number of tokens processed per batch, referred to as `global_batch_size`, is calculated as `batch_size × gradient_accumulation_steps × context_length × num_gpus_per_node × num_nodes`. -For instance, in the 340M model example, the `global_batch_size` calculates to $32 \times 1 \times 2048 \times 8 \times 1 = 524,288$ (0.5M tokens). +For instance, in the 340M model example, the `global_batch_size` calculates to $32 \times 1 \times 2048 \times 8 \times 1 = 524,288$ (0.5M tokens). The `warmup_steps` parameter indicates the number of steps for the learning rate warmup phase, while `max_steps` represents the maximum number of training steps. -Each step processes `global_batch_size` tokens. +Each step processes `global_batch_size` tokens. Consequently, `512` and `20480` correspond to processing 0.5B and 10B tokens, respectively. :warning: Monitor the value of `global_batch_size`, `warmup_steps`, and `max_steps` carefully when modifying any of the hyperparameters!! -`flame` also supports resuming interrupted training by specifying the checkpoint path. +`flame` also supports resuming interrupted training by specifying the checkpoint path. Simply use the following command: ```bash @@ -141,7 +141,7 @@ You can also use `wandb` to monitor your training process effectively. ## Continual Pretraining `flame` supports continual training from a pretrained checkpoint. -Below, we provide an example of how to finetune Mistral-7B to GLA. +Below, we provide an example of how to finetune Mistral-7B to GLA. You can follow similar steps to reproduce the results in the [GSA paper](https://arxiv.org/abs/2409.07146): 1. 
Initialize a brand-new GLA-7B model from the config and copy the mathced pretrained weights from Mistral-7B: @@ -171,7 +171,7 @@ bash train.sh \ cache=data/SlimPajama-627B/train ``` -Please be aware that finetuning on a single node may not be the most efficient approach. +Please be aware that finetuning on a single node may not be the most efficient approach. If available, consider leveraging multi-node GPUs for optimal performance. You can find guidance on how to launch a multi-node job in the [accelerate tutorial](https://github.com/huggingface/accelerate/blob/main/examples/slurm/submit_multinode.sh). diff --git a/legacy/training/configs/gla_1B.json b/legacy/training/configs/gla_1B.json index eed54325e..95ef59945 100644 --- a/legacy/training/configs/gla_1B.json +++ b/legacy/training/configs/gla_1B.json @@ -22,4 +22,4 @@ "use_gk": true, "use_gv": false, "vocab_size": 32000 -} \ No newline at end of file +} diff --git a/legacy/training/configs/gla_340M.json b/legacy/training/configs/gla_340M.json index 378d80e70..bcb3fc3b0 100644 --- a/legacy/training/configs/gla_340M.json +++ b/legacy/training/configs/gla_340M.json @@ -21,4 +21,4 @@ "use_gk": true, "use_gv": false, "vocab_size": 32000 -} \ No newline at end of file +} diff --git a/legacy/training/configs/gla_7B.json b/legacy/training/configs/gla_7B.json index ca5658aab..c321d3d72 100644 --- a/legacy/training/configs/gla_7B.json +++ b/legacy/training/configs/gla_7B.json @@ -25,4 +25,4 @@ "use_gk": true, "use_gv": false, "vocab_size": 32000 -} \ No newline at end of file +} diff --git a/legacy/training/configs/transformer_340M.json b/legacy/training/configs/transformer_340M.json index e703797ca..08356de26 100644 --- a/legacy/training/configs/transformer_340M.json +++ b/legacy/training/configs/transformer_340M.json @@ -15,4 +15,4 @@ "tie_word_embeddings": true, "use_cache": true, "vocab_size": 32000 -} \ No newline at end of file +} diff --git a/legacy/training/flame/logging.py b/legacy/training/flame/logging.py index 0b5ebe3d3..9b572d6aa 100644 --- a/legacy/training/flame/logging.py +++ b/legacy/training/flame/logging.py @@ -6,8 +6,7 @@ import sys import time -from transformers.trainer_callback import (ExportableState, TrainerCallback, - TrainerControl, TrainerState) +from transformers.trainer_callback import ExportableState, TrainerCallback, TrainerControl, TrainerState from transformers.training_args import TrainingArguments diff --git a/legacy/training/flame/parser.py b/legacy/training/flame/parser.py index 3b54d76e2..921fcb4d9 100644 --- a/legacy/training/flame/parser.py +++ b/legacy/training/flame/parser.py @@ -6,9 +6,8 @@ from typing import Optional import transformers -from transformers import HfArgumentParser, TrainingArguments - from flame.logging import get_logger +from transformers import HfArgumentParser, TrainingArguments logger = get_logger(__name__) diff --git a/legacy/training/run.py b/legacy/training/run.py index 0689d28fa..151324919 100644 --- a/legacy/training/run.py +++ b/legacy/training/run.py @@ -1,13 +1,12 @@ # -*- coding: utf-8 -*- from datasets import load_from_disk -from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer, - Trainer) - -import fla # noqa from flame.data import DataCollatorForLanguageModeling from flame.logging import LogCallback, get_logger from flame.parser import get_train_args +from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, Trainer + +import fla # noqa logger = get_logger(__name__)
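
A note on the gradient accumulation in `chunk_gated_delta_product_bwd` above (this restates the comments in that function rather than adding anything to the patch): the key gradient is assembled from two chain-rule terms, because the chunk output depends on `k` both directly, through the intra-chunk attention term, and indirectly, through `v_new = U - W Hᵀ`. In the notation used by the new code:

$$
\frac{\partial L}{\partial K}
= \underbrace{\left.\frac{\partial L}{\partial O}\,\frac{\partial O}{\partial K}\right|_{v_{\mathrm{new}}\ \text{fixed}}}_{\texttt{dk\_direct\_gradient}}
\;+\;
\underbrace{\frac{\partial L}{\partial O}\,\frac{\partial O}{\partial v_{\mathrm{new}}}\,\frac{\partial v_{\mathrm{new}}}{\partial K}}_{\texttt{dk\_hidden\_state\_gradient}}
$$

The gate gradient follows the same pattern, `dg_final = dg_local + dg2`, and is then pushed through a reverse `chunk_local_cumsum` and de-interleaved back to one gate per token.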
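
For validating the new backward path, a naive per-token reference of the gated DeltaProduct recurrence can be differentiated with autograd and compared against the fused kernels. The sketch below is an assumption-laden reference, not code from this repository: it assumes a `[K, V]` state per head, that the log-gate `g` is applied once per token before that token's `num_householder` delta-rule sub-updates (consistent with `g_interleaved[:, :, 0] = g` above), and that the output reads the state after the current token's full update.

```python
import torch


def naive_gated_delta_product(q, k, v, beta, g, scale, num_householder, initial_state=None):
    """Naive O(T) reference (assumed semantics, for gradient checking only).

    q: [B, T, H, K]; k, v, beta: [B, T * num_householder, H, ...]; g: [B, T, H] log-gates or None.
    """
    B, T, H, K = q.shape
    V = v.shape[-1]
    S = (torch.zeros(B, H, K, V, dtype=torch.float32, device=q.device)
         if initial_state is None else initial_state.float().clone())
    o = torch.zeros(B, T, H, V, dtype=torch.float32, device=q.device)
    for t in range(T):
        if g is not None:
            # decay the carried state once per token
            S = S * g[:, t].float().exp()[..., None, None]
        for j in range(num_householder):
            i = t * num_householder + j
            k_i, v_i, b_i = k[:, i].float(), v[:, i].float(), beta[:, i].float()
            # delta-rule (Householder-style) sub-update: S += beta * k (v - S^T k)^T
            pred = torch.einsum('bhk,bhkv->bhv', k_i, S)
            err = (v_i - pred) * b_i[..., None]
            S = S + torch.einsum('bhk,bhv->bhkv', k_i, err)
        # read out with the query after the token's full update
        o[:, t] = scale * torch.einsum('bhk,bhkv->bhv', q[:, t].float(), S)
    return o, S
```

Under those assumptions, `torch.autograd.grad` of a scalar loss on this reference output (and final state), run in float64, gives `dq`, `dk`, `dv`, `dbeta`, `dg`, and `dh0` values to compare against what `ChunkGatedDeltaProductFunction.backward` returns.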