Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 83 additions & 69 deletions ggml/src/ggml-openvino/ggml-decoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,26 +44,26 @@ GgmlOvDecoder::GgmlOvDecoder(ggml_tensor * node,
int num_heads_kv,
int head_size,
const std::vector<int> & swa_layers) :
m_is_static(is_static),
m_cgraph(cgraph),
m_node(node),
m_op_name(std::string(node->name)),
m_context_size(context_size),
m_context_size_swa(context_size_swa),
m_swa_layers(swa_layers),
m_num_heads(num_heads),
m_num_heads_kv(num_heads_kv),
m_ctx(context_size),
m_ctx_swa(context_size_swa),
m_n_heads(num_heads),
m_n_heads_kv(num_heads_kv),
m_head_size(head_size),
m_is_static(is_static) {
m_swa_layers(swa_layers) {
set_input_output(node);
}

GgmlOvDecoder::GgmlOvDecoder(ggml_cgraph * cgraph,
std::map<std::string, std::shared_ptr<ov::Node>> & model_weights,
bool is_static) :
m_is_static(is_static),
m_cgraph(cgraph),
m_op_name(m_node ? std::string(m_node->name) : ""),
m_model_weights(model_weights),
m_is_static(is_static) {
m_model_weights(model_weights) {
if (auto * env = getenv("GGML_OPENVINO_PRINT_CGRAPH_TENSOR_ADDRESS"); env && std::string(env) != "0") {
unsetenv("GGML_OPENVINO_PRINT_CGRAPH_TENSOR_ADDRESS");
print_tensor_address_map(cgraph);
Expand All @@ -78,7 +78,7 @@ GgmlOvDecoder::GgmlOvDecoder(ggml_cgraph * cgraph,
set_input_output(cur_node);
}

// add_extra_inputs();
add_extra_inputs();
}

GgmlOvDecoder::GgmlOvDecoder(ggml_cgraph * cgraph, std::map<std::string, std::shared_ptr<ov::Node>> & model_weights) {
Expand Down Expand Up @@ -125,7 +125,8 @@ void GgmlOvDecoder::set_input_output(ggml_tensor * node, bool naive) {
// Add model inputs and weights constants, if called for the whole graph
if (naive) {
if (m_model_weights.find(src_name) == m_model_weights.end()) {
auto param_node = std::make_shared<ov::op::v0::Parameter>(get_ov_type(src), get_graph_input_shape(src));
auto param_node =
std::make_shared<ov::op::v0::Parameter>(get_ov_type(src), get_graph_input_shape(node, src));
param_node->set_friendly_name(src_name);
param_node->output(0).get_tensor().set_names({src_name});
m_model_inputs[src_name] = param_node;
Expand All @@ -142,7 +143,8 @@ void GgmlOvDecoder::set_input_output(ggml_tensor * node, bool naive) {
if (m_model_inputs.find(src_name) != m_model_inputs.end()) {
continue;
}
auto param_node = std::make_shared<ov::op::v0::Parameter>(get_ov_type(src), get_graph_input_shape(src));
auto param_node =
std::make_shared<ov::op::v0::Parameter>(get_ov_type(src), get_graph_input_shape(node, src));
param_node->set_friendly_name(src_name);
param_node->output(0).get_tensor().set_names({src_name});
m_model_inputs[src_name] = param_node;
Expand Down Expand Up @@ -175,15 +177,20 @@ void GgmlOvDecoder::set_input_output(ggml_tensor * node, bool naive) {
if (m_node) {
switch (node->op) {
case GGML_OP_RESHAPE: {
if (node->src[0]->op == GGML_OP_RESHAPE && node->src[0]->src[0]->ne[0] == node->ne[0] &&
node->src[0]->src[0]->ne[1] == node->ne[1]) {
auto * src = node->src[0];
if (src->op == GGML_OP_RESHAPE && src->src[0]->ne[0] == node->ne[0] && src->src[0]->ne[1] == node->ne[1]) {
m_op_case = 4;
} else if (node->ne[0] * node->ne[1] == node->src[0]->ne[0]) {
} else if (node->ne[0] * node->ne[1] == src->ne[0]) {
m_op_case = 1;
} else if (node->src[0]->ne[0] * node->src[0]->ne[1] == node->ne[0]) {
} else if (src->ne[0] * src->ne[1] == node->ne[0]) {
m_op_case = 2;
} else if (node->src[0]->ne[0] * node->src[0]->ne[1] == node->ne[1]) {
if (src->ne[2] * src->ne[3] == node->ne[1]) {
m_op_case = 5;
}
} else if (src->ne[0] * src->ne[1] == node->ne[1]) {
m_op_case = 3;
} else if (src->ne[1] * src->ne[2] == node->ne[1]) {
m_op_case = 6;
}
break;
}
Expand All @@ -204,7 +211,8 @@ void GgmlOvDecoder::set_input_output(ggml_tensor * node, bool naive) {
} else if (ggml_is_contiguous(node->src[0])) {
std::string src_name(node->view_src->name);
if (src_name.find("cache") == std::string::npos) {
m_op_case = 1;
// permute Qcur
m_op_case = 4;
} else {
// Permute kv cache (view)
int layer = extract_layer_from_name(src_name);
Expand Down Expand Up @@ -272,64 +280,80 @@ void GgmlOvDecoder::set_llm_params() {
auto * node = m_cgraph->nodes[i];
std::string name = std::string(node->name);
if (node->op == GGML_OP_FLASH_ATTN_EXT) {
auto * cache_k = node->src[1];
cache_k = cache_k->view_src ? cache_k->view_src : cache_k;
auto * cache_k_perm = node->src[1];
assert(cache_k_perm->op == GGML_OP_PERMUTE);
auto * cache_k_view = cache_k_perm->src[0];
assert(cache_k_view->op == GGML_OP_VIEW);

auto * cache_k = cache_k_view->src[0];
int layer = extract_layer_from_name(cache_k->name);
auto * mask = node->src[3];
std::string mask_name(mask->name);
assert(mask_name.find("KQ_mask") == 0);

if (std::string(node->src[3]->name).find("swa") != std::string::npos) {
m_swa_layers.push_back(layer);
m_context_size_swa = cache_k->ne[1];
m_ctx_per_seq_swa = cache_k->ne[1];
} else {
m_context_size = cache_k->ne[1];
m_ctx_per_seq = cache_k->ne[1];
m_n_seq = cache_k->ne[2];
}

m_n_seq_active = mask->ne[3];
auto seq_size = cache_k->ne[0] * cache_k->ne[1] * ggml_type_size(cache_k->type);
m_seq_active_start = ((size_t *) cache_k_view->op_params)[0] / seq_size;
m_token_len_per_seq = node->ne[2];

if (mask_name.find("swa") != std::string::npos) {
m_attention_size_swa = mask->ne[0];
} else {
m_attention_size = mask->ne[0];
}

} else if (node->op == GGML_OP_ROPE) {
if (name.find("Qcur-0") == 0 || std::string(node->src[0]->name).find("Qcur-0") == 0) {
m_head_size = node->ne[0];
m_num_heads = node->ne[1];
m_n_heads = node->ne[1];
m_rope_params = node->op_params;
auto * inp_pos = node->src[1];
m_input_len = inp_pos->ne[0];
m_past_kv_len = *(int32_t *) inp_pos->data;
} else if (name.find("Kcur-0") == 0 || std::string(node->src[0]->name).find("Kcur-0") == 0) {
m_num_heads_kv = node->ne[1];
m_n_heads_kv = node->ne[1];
}
}
}
m_ctx = m_ctx_per_seq * m_n_seq;
m_ctx_swa = m_ctx_per_seq_swa * m_n_seq;
}

void GgmlOvDecoder::validate_cgraph() const {}
void GgmlOvDecoder::validate_cgraph() const {
if (m_n_seq > 1 && m_is_static == true) {
throw std::runtime_error("n_seq > 1 is not supported on NPU");
}
}

ov::PartialShape GgmlOvDecoder::get_graph_input_shape(const ggml_tensor * src) const {
auto name = std::string(src->name);
ov::PartialShape GgmlOvDecoder::get_graph_input_shape(const ggml_tensor * op, const ggml_tensor * input) const {
auto name = std::string(input->name);
ov::PartialShape input_shape;

if (name == "inp_tokens" || name == "inp_pos" || name == "inp_out_ids") {
input_shape = ov::PartialShape{1, 1, m_is_static ? 1 : -1};
input_shape = ov::PartialShape{1, 1, 1, m_is_static ? 1 : -1};

} else if (name.find("KQ_mask") == 0) {
if (m_is_static) {
input_shape = ov::PartialShape{1, 1, m_context_size};
input_shape = ov::PartialShape{1, 1, 1, m_ctx};
} else {
input_shape = ov::PartialShape{1, -1, -1};
}

} else if (name.find("cache_") == 0) {
auto past_token_len = -1;
if (m_is_static) {
int layer = extract_layer_from_name(name);
bool is_swa = is_swa_layer(layer);
past_token_len = is_swa ? m_context_size_swa : m_context_size;
input_shape = ov::PartialShape{-1, 1, -1, -1};
}
input_shape = ov::PartialShape{past_token_len, m_num_heads_kv, m_head_size};

} else if (const auto * op = get_tensor_used_op(src); op && op->op == GGML_OP_SET_ROWS) {
input_shape = ov::PartialShape{1, 1, m_is_static ? 1 : -1};
} else if (op && op->op == GGML_OP_SET_ROWS && op->src[1] == input) {
input_shape = ov::PartialShape{1, 1, 1, m_is_static ? 1 : -1};

} else if (src->op == GGML_OP_VIEW) {
} else if (input->op == GGML_OP_VIEW) {
// This case is added to make test-backend-ops work
input_shape = ov::PartialShape{get_shape(src->view_src)};
input_shape = ov::PartialShape{get_shape(input->view_src)};
} else {
input_shape = ov::PartialShape{get_shape(src)};
input_shape = ov::PartialShape{get_shape(input)};
}
return input_shape;
}
Expand All @@ -339,25 +363,9 @@ void GgmlOvDecoder::add_extra_inputs() {
// 1. `attention_size`, used in FLASH_ATTN where the shape of the matmul's are 256 aligned,
// see llama_kv_cache_unified::get_n_kv and llama_kv_cache_unified::get_padding.
// Not used for NPU.
// Update: not used anymore after the optimization of making kvcache dynamic (but breaks iSWA models)
int64_t attention_size = -1;
int64_t attention_size_swa = -1;
for (const auto & node : m_nodes) {
if (node->op == GGML_OP_FLASH_ATTN_EXT) {
auto * mask = node->src[3];
std::string mask_name(mask->name);
if (mask_name.find("KQ_mask") != 0) {
throw std::runtime_error("Unexpected flash attention node: " + std::string(mask->name));
}
if (mask_name.find("swa") != std::string::npos) {
attention_size_swa = mask->ne[0];
} else {
attention_size = mask->ne[0];
}
}
}
// 2. `n_seq_active` and `seq_active_start`, used in FLASH_ATTN_EXT to indicate the active sequences in the batch

auto create_attention_size_input = [this](const std::string & name, int64_t size) {
auto create_1d_input = [this](const std::string & name, int64_t size) {
auto param_node = std::make_shared<ov::op::v0::Parameter>(ov::element::i64, ov::Shape{1});
param_node->set_friendly_name(name);
param_node->output(0).get_tensor().set_names({name});
Expand All @@ -368,10 +376,14 @@ void GgmlOvDecoder::add_extra_inputs() {
m_model_extra_input_values[name] = tensor;
};

create_attention_size_input("attention_size", attention_size);
if (attention_size_swa != -1) {
create_attention_size_input("attention_size_swa", attention_size_swa);
create_1d_input("attention_size", m_attention_size);
if (m_attention_size_swa != -1) {
create_1d_input("attention_size_swa", m_attention_size_swa);
}
create_1d_input("seq_active_start", m_seq_active_start);
create_1d_input("n_seq_active", m_n_seq_active);
create_1d_input("token_len_per_seq", m_token_len_per_seq);
create_1d_input("seq_active_end", m_seq_active_start + m_n_seq_active);
}

const ggml_tensor * GgmlOvDecoder::get_tensor_used_op(const ggml_tensor * tensor) const {
Expand Down Expand Up @@ -472,6 +484,8 @@ std::shared_ptr<ov::Node> GgmlOvDecoder::create_weight_node(ggml_tensor * tensor
auto node_shape = get_shape(tensor);
auto ne_total = ggml_nelements(tensor);

OPENVINO_ASSERT(node_shape[0] == 1, "Got 4D weights, expect all weights to be 2D: ", tensor->name);
node_shape.erase(node_shape.begin());
OPENVINO_ASSERT(node_shape[0] == 1, "Got 3D weights, expect all weights to be 2D: ", tensor->name);
node_shape.erase(node_shape.begin());

Expand Down Expand Up @@ -641,15 +655,15 @@ void print_tensor_address_map(const ggml_cgraph * cgraph) {

std::vector<size_t> GgmlOvDecoder::get_shape(const ggml_tensor * tensor) {
std::vector<size_t> shape;
for (int i = GGML_MAX_DIMS - 2; i >= 0; --i) {
for (int i = GGML_MAX_DIMS - 1; i >= 0; --i) {
shape.push_back(static_cast<size_t>(tensor->ne[i]));
}
return shape;
}

std::vector<size_t> GgmlOvDecoder::get_stride(const ggml_tensor * tensor) {
std::vector<size_t> stride;
for (int i = GGML_MAX_DIMS - 2; i >= 0; --i) {
for (int i = GGML_MAX_DIMS - 1; i >= 0; --i) {
stride.push_back(static_cast<size_t>(tensor->nb[i]));
}
return stride;
Expand Down Expand Up @@ -738,8 +752,8 @@ int32_t * GgmlOvDecoder::get_output_op_params(const std::string & name) const {

void GgmlOvDecoder::visit_subgraph(std::function<void(std::shared_ptr<GgmlDecoder>)> node_visitor) const {
for (const auto & node : m_nodes) {
auto decoder = std::make_shared<GgmlOvDecoder>(node, m_cgraph, m_is_static, m_context_size, m_context_size_swa,
m_num_heads, m_num_heads_kv, m_head_size, m_swa_layers);
auto decoder = std::make_shared<GgmlOvDecoder>(node, m_cgraph, m_is_static, m_ctx, m_ctx_swa, m_n_heads,
m_n_heads_kv, m_head_size, m_swa_layers);
node_visitor(decoder);
}
}
Expand Down
50 changes: 31 additions & 19 deletions ggml/src/ggml-openvino/ggml-decoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,19 +103,19 @@ class GgmlOvDecoder : public ov::frontend::ggml::GgmlDecoder {

virtual const std::vector<std::string> & get_model_output_names() const override { return m_model_output_names; }

virtual int get_context_size() const override { return m_context_size; }
virtual int get_ctx_size() const { return m_ctx; }

virtual int get_context_size_swa() const override { return m_context_size_swa; }
virtual int get_ctx_swa_size() const { return m_ctx_swa; }

virtual int is_swa_layer(int layer) const override {
return std::find(m_swa_layers.begin(), m_swa_layers.end(), layer) != m_swa_layers.end();
}
virtual int get_ctx_per_seq() const { return m_ctx_per_seq; }

virtual int get_num_heads() const override { return m_num_heads; }
virtual int get_ctx_per_seq_swa() const { return m_ctx_per_seq_swa; }

virtual int get_num_heads_kv() const override { return m_num_heads_kv; }
virtual int get_n_seq() const { return m_n_seq; }

virtual int get_head_size() const override { return m_head_size; }
virtual int is_swa_layer(int layer) const override {
return std::find(m_swa_layers.begin(), m_swa_layers.end(), layer) != m_swa_layers.end();
}

int get_past_kv_len() const { return m_past_kv_len; }

Expand All @@ -127,7 +127,7 @@ class GgmlOvDecoder : public ov::frontend::ggml::GgmlDecoder {

virtual bool is_static() const override { return m_is_static; }

ov::PartialShape get_graph_input_shape(const ggml_tensor * src) const;
ov::PartialShape get_graph_input_shape(const ggml_tensor * op, const ggml_tensor * input) const;

static void dump_cgraph(const ggml_cgraph * cgraph, std::string & filename);

Expand All @@ -151,10 +151,11 @@ class GgmlOvDecoder : public ov::frontend::ggml::GgmlDecoder {
static std::vector<size_t> get_stride(const ggml_tensor * tensor);
static ov::element::Type get_ov_type(const ggml_tensor * tensor);

// set context_size, num_heads, etc
void set_llm_params();
void validate_cgraph() const;

bool m_is_static = false;

ggml_cgraph * m_cgraph = nullptr;
ggml_tensor * m_node = nullptr;
std::vector<ggml_tensor *> m_nodes;
Expand All @@ -171,17 +172,28 @@ class GgmlOvDecoder : public ov::frontend::ggml::GgmlDecoder {
std::map<std::string, std::shared_ptr<ov::Tensor>> m_model_extra_input_values;
std::map<std::string, std::shared_ptr<ov::Node>> m_model_weights;
std::vector<std::string> m_model_output_names;
int m_context_size;
int m_context_size_swa;

// Fixed for a model
int m_ctx = -1;
int m_ctx_swa = -1;
int m_ctx_per_seq = -1;
int m_ctx_per_seq_swa = -1;
int m_n_seq = -1;
int m_n_heads = -1;
int m_n_heads_kv = -1;
int m_head_size = -1;
std::vector<int> m_swa_layers;
int m_num_heads;
int m_num_heads_kv;
int m_head_size;
int m_past_kv_len;
int m_input_len;
int32_t * m_rope_params;
std::vector<std::string> m_kv_names;
bool m_is_static = false;

// Changed per inference
int m_n_seq_active = -1;
int m_seq_active_start = -1;
int m_attention_size = -1;
int m_attention_size_swa = -1;
int m_input_len = -1;
int m_token_len_per_seq = -1;
int m_past_kv_len = -1;
int32_t * m_rope_params = nullptr;
};

void print_tensor_address_map(const ggml_cgraph * cgraph);
Expand Down
Loading