graph : refactor context to not pass gf explicitly #14629

Open · wants to merge 1 commit into base: gg/llama-reuse-graphs
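In short, this change stops threading the graph through every helper: `llm_graph_context` now keeps a `ggml_cgraph * gf` member initialized from `res->get_gf()` next to `ctx0`, the explicit `gf` parameter is dropped from `build_attn_mha`, `build_attn`, `build_rs`, `build_rwkv_token_shift_load` and `build_pooling`, and the `llm_graph_result_i` interface is folded into the concrete `llm_graph_result` (which gains `set_params()` since `params` becomes private). Below is a small standalone sketch of the pattern only; it is a toy model, not llama.cpp code, and every name in it is invented for illustration.

```cpp
// A toy, self-contained model of the refactor (not llama.cpp code): the graph
// handle moves from a per-call argument into the builder object, the same way
// llm_graph_context now initializes its `gf` member from `res->get_gf()`
// instead of passing `ggml_cgraph *` into every build_* helper.
#include <cstdio>
#include <vector>

struct toy_graph {
    std::vector<int> nodes;                 // stand-in for ggml_cgraph
};

struct toy_result {
    toy_graph g;
    toy_graph * get_gf() { return &g; }     // mirrors llm_graph_result::get_gf()
};

struct toy_builder {
    toy_result * res;
    toy_graph  * gf;                        // held once, like llm_graph_context::gf

    explicit toy_builder(toy_result * r) : res(r), gf(r->get_gf()) {}

    // before the refactor this helper would have taken `toy_graph * gf` explicitly
    void build_node(int node) const { gf->nodes.push_back(node); }
};

int main() {
    toy_result res;
    toy_builder builder(&res);
    builder.build_node(1);
    builder.build_node(2);
    std::printf("nodes built: %zu\n", res.get_gf()->nodes.size());
    return 0;
}
```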
4 changes: 2 additions & 2 deletions src/llama-context.cpp
@@ -674,7 +674,7 @@ bool llama_context::apply_adapter_cvec(
return cvec.apply(model, data, len, n_embd, il_start, il_end);
}

llm_graph_result_i * llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) {
llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) {
if (mctx && !mctx->apply()) {
LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__);
ret = GGML_STATUS_FAILED;
@@ -1324,7 +1324,7 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
}

llm_graph_params llama_context::graph_params(
llm_graph_result_i * res,
llm_graph_result * res,
const llama_ubatch & ubatch,
const llama_memory_context_i * mctx,
llm_graph_type gtype) const {
4 changes: 2 additions & 2 deletions src/llama-context.h
@@ -94,7 +94,7 @@ struct llama_context {
// if memory_context is provided, it will be applied first to the context's memory
// ret contains the status of the graph computation
// returns nullptr only if ret != GGML_STATUS_SUCCESS
llm_graph_result_i * process_ubatch(
llm_graph_result * process_ubatch(
const llama_ubatch & ubatch,
llm_graph_type gtype,
llama_memory_context_i * mctx,
@@ -196,7 +196,7 @@ struct llama_context {

private:
llm_graph_params graph_params(
llm_graph_result_i * res,
llm_graph_result * res,
const llama_ubatch & ubatch,
const llama_memory_context_i * mctx,
llm_graph_type gtype) const;
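For orientation, here is a hedged usage sketch of the updated `process_ubatch` signature, written from the comment above it; the wrapper function, its name, and the choice of graph type are assumptions for illustration and are not code from this PR.

```cpp
// Hypothetical caller sketch, following the comment in llama-context.h above.
// run_one_ubatch is an invented helper; LLM_GRAPH_TYPE_DEFAULT is assumed to be
// one of the llm_graph_type values.
static ggml_status run_one_ubatch(llama_context & lctx,
                                  const llama_ubatch & ubatch,
                                  llama_memory_context_i * mctx) {
    ggml_status status = GGML_STATUS_SUCCESS;

    // the memory context, if provided, is applied before the graph is computed
    llm_graph_result * res = lctx.process_ubatch(ubatch, LLM_GRAPH_TYPE_DEFAULT, mctx, status);

    // per the header comment: nullptr only if status != GGML_STATUS_SUCCESS
    if (res == nullptr) {
        return status;
    }

    // outputs are read from the result object, now the concrete llm_graph_result type
    ggml_tensor * t_logits = res->get_logits();
    (void) t_logits;

    return status;
}
```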
30 changes: 11 additions & 19 deletions src/llama-graph.cpp
@@ -448,9 +448,10 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
mctx (params.mctx),
cross (params.cross),
cb_func (params.cb),
res (static_cast<llm_graph_result *>(params.res)),
ctx0 (res->get_ctx()) {
res->params = params;
res (params.res),
ctx0 (res->get_ctx()),
gf (res->get_gf()) {
res->set_params(params);
}

void llm_graph_context::cb(ggml_tensor * cur, const char * name, int il) const {
@@ -1040,7 +1041,6 @@ ggml_tensor * llm_graph_context::build_pos_bias(ggml_tensor * pos_bucket, ggml_t
}

ggml_tensor * llm_graph_context::build_attn_mha(
ggml_cgraph * gf,
ggml_tensor * q,
ggml_tensor * k,
ggml_tensor * v,
@@ -1170,7 +1170,6 @@ llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() con

ggml_tensor * llm_graph_context::build_attn(
llm_graph_input_attn_no_cache * inp,
ggml_cgraph * gf,
ggml_tensor * wo,
ggml_tensor * wo_b,
ggml_tensor * q_cur,
@@ -1194,7 +1193,7 @@ ggml_tensor * llm_graph_context::build_attn(
ggml_tensor * k = k_cur;
ggml_tensor * v = v_cur;

ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, kq_scale);
cb(cur, "kqv_out", il);

if (wo) {
@@ -1249,7 +1248,6 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()

ggml_tensor * llm_graph_context::build_attn(
llm_graph_input_attn_kv_unified * inp,
ggml_cgraph * gf,
ggml_tensor * wo,
ggml_tensor * wo_b,
ggml_tensor * q_cur,
@@ -1282,7 +1280,7 @@ ggml_tensor * llm_graph_context::build_attn(
ggml_tensor * k = mctx_cur->get_k(ctx0, il);
ggml_tensor * v = mctx_cur->get_v(ctx0, il);

ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, kq_scale);
cb(cur, "kqv_out", il);

if (wo) {
@@ -1302,7 +1300,6 @@ ggml_tensor * llm_graph_context::build_attn(

ggml_tensor * llm_graph_context::build_attn(
llm_graph_input_attn_kv_unified_iswa * inp,
ggml_cgraph * gf,
ggml_tensor * wo,
ggml_tensor * wo_b,
ggml_tensor * q_cur,
@@ -1349,7 +1346,7 @@ ggml_tensor * llm_graph_context::build_attn(
ggml_tensor * k = mctx_cur->get_k(ctx0, il);
ggml_tensor * v = mctx_cur->get_v(ctx0, il);

ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, kq_scale);
cb(cur, "kqv_out", il);

if (wo) {
@@ -1382,7 +1379,6 @@ llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const {

ggml_tensor * llm_graph_context::build_attn(
llm_graph_input_attn_cross * inp,
ggml_cgraph * gf,
ggml_tensor * wo,
ggml_tensor * wo_b,
ggml_tensor * q_cur,
@@ -1404,7 +1400,7 @@ ggml_tensor * llm_graph_context::build_attn(
ggml_tensor * k = k_cur;
ggml_tensor * v = v_cur;

ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, kq_scale);
cb(cur, "kqv_out", il);

if (wo) {
@@ -1460,7 +1456,6 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
}

ggml_tensor * llm_graph_context::build_rs(
ggml_cgraph * gf,
ggml_tensor * s,
ggml_tensor * state_copy,
int32_t state_size,
@@ -1518,21 +1513,19 @@ llm_graph_input_rs * llm_graph_context::build_rs_inp() const {

ggml_tensor * llm_graph_context::build_rs(
llm_graph_input_rs * inp,
ggml_cgraph * gf,
ggml_tensor * s,
int32_t state_size,
int32_t n_seqs,
const llm_graph_get_rows_fn & get_state_rows) const {
const auto * kv_state = inp->mctx;

return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), get_state_rows);
return build_rs(s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), get_state_rows);
}

ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
llm_graph_input_rs * inp,
ggml_cgraph * gf,
const llama_ubatch & ubatch,
int il) const {
int il) const {
const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);

const auto token_shift_count = hparams.token_shift_count;
@@ -1542,7 +1535,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
ggml_tensor * token_shift_all = mctx_cur->get_r_l(il);

ggml_tensor * token_shift = build_rs(
inp, gf, token_shift_all,
inp, token_shift_all,
hparams.n_embd_r(), n_seqs);

token_shift = ggml_reshape_3d(ctx0, token_shift, hparams.n_embd, token_shift_count, n_seqs);
@@ -1582,7 +1575,6 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
}

void llm_graph_context::build_pooling(
ggml_cgraph * gf,
ggml_tensor * cls,
ggml_tensor * cls_b,
ggml_tensor * cls_out,
66 changes: 22 additions & 44 deletions src/llama-graph.h
@@ -371,31 +371,11 @@ class llm_graph_input_mem_hybrid : public llm_graph_input_i {
// along with the input tensors, the object also provides commonly used output tensors, such as logits, embeddings, etc.
// these are used by the llama_context to extract the relevant data, based on the compute parameters

// TODO: this interface seems redundant - remove it
class llm_graph_result_i {
public:
virtual ~llm_graph_result_i() = default;

virtual ggml_tensor * get_tokens() const = 0;
virtual ggml_tensor * get_logits() const = 0;
virtual ggml_tensor * get_embd() const = 0;
virtual ggml_tensor * get_embd_pooled() const = 0;

virtual ggml_cgraph * get_gf() = 0;
virtual ggml_context * get_ctx() = 0;

virtual void reset() = 0;

virtual void set_inputs(const llama_ubatch * ubatch) = 0;

virtual bool can_reuse(const llm_graph_params & params) = 0;
};

using llm_graph_result_ptr = std::unique_ptr<llm_graph_result_i>;

// callback that allows us to apply custom logic to each tensor (e.g. ggml-alloc, offloading, etc.)
using llm_graph_cb = std::function<void(const llama_ubatch & ubatch, ggml_tensor * cur, const char * name, int il)>;

class llm_graph_result;

struct llm_graph_params {
llm_arch arch = LLM_ARCH_UNKNOWN;

@@ -418,8 +398,7 @@ struct llm_graph_params {

llm_graph_cb cb;

// TODO: temporary
llm_graph_result_i * res;
llm_graph_result * res;

// return true if the "other" params would result in a graph with the same topology as with the current params
// having the same topology allows us to reuse the graph in some cases
@@ -462,27 +441,27 @@ struct llm_graph_params {
}
};

class llm_graph_result : public llm_graph_result_i {
class llm_graph_result {
public:
llm_graph_result(int64_t max_nodes) : max_nodes(max_nodes) {
reset();
}

virtual ~llm_graph_result() = default;

ggml_tensor * get_tokens() const override { return t_tokens; }
ggml_tensor * get_logits() const override { return t_logits; }
ggml_tensor * get_embd() const override { return t_embd; }
ggml_tensor * get_embd_pooled() const override { return t_embd_pooled; }
ggml_tensor * get_tokens() const { return t_tokens; }
ggml_tensor * get_logits() const { return t_logits; }
ggml_tensor * get_embd() const { return t_embd; }
ggml_tensor * get_embd_pooled() const { return t_embd_pooled; }

ggml_cgraph * get_gf() override { return gf; }
ggml_context * get_ctx() override { return ctx_compute.get(); }
ggml_cgraph * get_gf() { return gf; }
ggml_context * get_ctx() { return ctx_compute.get(); }

void set_max_nodes(int64_t max_nodes) {
this->max_nodes = max_nodes;
}

void reset() override {
void reset() {
t_tokens = nullptr;
t_logits = nullptr;
t_embd = nullptr;
@@ -503,7 +482,7 @@ class llm_graph_result : public llm_graph_result_i {
gf = ggml_new_graph_custom(ctx_compute.get(), max_nodes, false);
}

void set_inputs(const llama_ubatch * ubatch) override {
void set_inputs(const llama_ubatch * ubatch) {
for (auto & input : inputs) {
input->set_input(ubatch);
}
@@ -514,7 +493,7 @@ class llm_graph_result : public llm_graph_result_i {
// would be identical to the existing graph. in that case, we simply have to update the memory
// contexts of the input tensors of the graph and we can reuse it for another computation
// return true if the graph was updated and can be reused
bool can_reuse(const llm_graph_params & params) override {
bool can_reuse(const llm_graph_params & params) {
if (!this->params.allow_reuse(params)) {
return false;
}
@@ -533,6 +512,10 @@ class llm_graph_result : public llm_graph_result_i {
return inputs.back().get();
}

void set_params(const llm_graph_params & params) {
this->params = params;
}

// important graph nodes
ggml_tensor * t_tokens = nullptr;
ggml_tensor * t_logits = nullptr;
@@ -550,12 +533,15 @@ class llm_graph_result : public llm_graph_result_i {

int64_t max_nodes;

private:
// keep a copy of the previous graph parameters
// we will use this to determine whether the graph can be reused by comparing them with the new parameters
// note: these are updated after constructing the new graph
llm_graph_params params;
};

using llm_graph_result_ptr = std::unique_ptr<llm_graph_result>;

//
// llm_graph_context
//
@@ -613,6 +599,7 @@ struct llm_graph_context {
llm_graph_result * res;

ggml_context * ctx0 = nullptr;
ggml_cgraph * gf = nullptr;

llm_graph_context(const llm_graph_params & params);
virtual ~llm_graph_context() = default;
@@ -698,7 +685,6 @@ struct llm_graph_context {
//

ggml_tensor * build_attn_mha(
ggml_cgraph * gf,
ggml_tensor * q, // [n_embd_head_q, n_head_q, n_tokens]
ggml_tensor * k, // [n_embd_head_k, n_head_k, n_tokens]
ggml_tensor * v, // [n_embd_head_v, n_head_v, n_tokens] (v_trans == false)
@@ -711,7 +697,6 @@ struct llm_graph_context {

ggml_tensor * build_attn(
llm_graph_input_attn_no_cache * inp,
ggml_cgraph * gf,
ggml_tensor * wo,
ggml_tensor * wo_b,
ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
@@ -726,7 +711,6 @@ struct llm_graph_context {

ggml_tensor * build_attn(
llm_graph_input_attn_kv_unified * inp,
ggml_cgraph * gf,
ggml_tensor * wo,
ggml_tensor * wo_b,
ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
@@ -742,7 +726,6 @@ struct llm_graph_context {
// note: if k_cur or v_cur are not provided, they will not be stored in the memory
ggml_tensor * build_attn(
llm_graph_input_attn_kv_unified_iswa * inp,
ggml_cgraph * gf,
ggml_tensor * wo,
ggml_tensor * wo_b,
ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
@@ -757,7 +740,6 @@ struct llm_graph_context {

ggml_tensor * build_attn(
llm_graph_input_attn_cross * inp,
ggml_cgraph * gf,
ggml_tensor * wo,
ggml_tensor * wo_b,
ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
@@ -779,7 +761,6 @@ struct llm_graph_context {
// implementation in 2 separate methods. the goal is to avoid calling `ggml_build_forward_expand` in
// `llama_memory_recurrent`
ggml_tensor * build_rs(
ggml_cgraph * gf,
ggml_tensor * s,
ggml_tensor * state_copy,
int32_t state_size,
@@ -794,17 +775,15 @@ struct llm_graph_context {

ggml_tensor * build_rs(
llm_graph_input_rs * inp,
ggml_cgraph * gf,
ggml_tensor * s,
int32_t state_size,
int32_t n_seqs,
const llm_graph_get_rows_fn & get_state_rows = ggml_get_rows) const;

ggml_tensor * build_rwkv_token_shift_load(
llm_graph_input_rs * inp,
ggml_cgraph * gf,
const llama_ubatch & ubatch,
int il) const;
int il) const;

ggml_tensor * build_rwkv_token_shift_store(
ggml_tensor * token_shift,
@@ -821,7 +800,6 @@ struct llm_graph_context {
//

void build_pooling(
ggml_cgraph * gf,
ggml_tensor * cls,
ggml_tensor * cls_b,
ggml_tensor * cls_out,
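With the interface removed, callers work with the concrete `llm_graph_result` directly. The sketch below outlines the reuse flow implied by the `reset()` / `can_reuse()` / `set_inputs()` comments in the class; variable names and the exact ordering inside `llama_context` are assumptions, not a verbatim excerpt.

```cpp
// Assumed flow, mirroring the comments in llm_graph_result above.
// gres:   a cached llm_graph_result_ptr (std::unique_ptr<llm_graph_result>)
// params: freshly built llm_graph_params for the incoming ubatch
if (!gres->can_reuse(params)) {
    gres->reset();              // fresh ctx_compute and an empty gf
    // ... rebuild the model graph through llm_graph_context(params), whose
    //     constructor now reads gf via res->get_gf() and stores the params
    //     with res->set_params(params) ...
}
gres->set_inputs(&ubatch);      // (re)bind the input tensors for this ubatch
```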