 #include <set>
 #include <stdexcept>
 #include <string>
+#include <vector>

 #include "ggml-backend-impl.h"
 #include "ggml-backend.h"
 #include "ggml-quants.hpp"

 GgmlOvDecoder::GgmlOvDecoder(struct ggml_tensor* node, struct ggml_cgraph* cgraph, bool is_static, bool is_first_token,
-                             int context_size, int num_heads, int num_heads_kv, int head_size) :
+                             int context_size, int context_size_swa, int num_heads, int num_heads_kv, int head_size,
+                             const std::vector<int>& swa_layers) :
     m_cgraph(cgraph),
     m_node(node),
     m_op_name(std::string(node->name)),
     m_context_size(context_size),
+    m_context_size_swa(context_size_swa),
+    m_swa_layers(swa_layers),
     m_num_heads(num_heads),
     m_num_heads_kv(num_heads_kv),
     m_head_size(head_size),
@@ -204,11 +208,14 @@ void GgmlOvDecoder::set_input_output(ggml_tensor* node, bool naive) {
         if (node->src[0]->op != GGML_OP_VIEW) {
             m_op_case = 1;
         } else if (ggml_is_contiguous(node->src[0])) {
-            // Permute cache_k (view)
-            m_op_case = 2;
-        } else {
-            // Permute cache_v (view), deprecated, cache_v will also fall to case 2
-            m_op_case = 3;
+            // Permute kv cache (view)
+            std::string src_name(node->view_src->name);
+            int layer = extract_layer_from_name(src_name);
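+            // op_case 2: regular (full-attention) KV cache layer; op_case 3: sliding-window (SWA) layer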
+            if (!is_swa_layer(layer)) {
+                m_op_case = 2;
+            } else {
+                m_op_case = 3;
+            }
         }
         break;
     }
@@ -239,13 +246,34 @@ void GgmlOvDecoder::set_input_output(ggml_tensor* node, bool naive) {
     }
 }

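+// Extracts the layer index from a tensor name such as "cache_k_l0 (view)" -> 0.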
+int extract_layer_from_name(const std::string& name) {
+    size_t pos1 = name.find("_l");
+    assert(pos1 != std::string::npos);
+    pos1 += 2;
+    size_t pos2 = name.find(' ', pos1);
+    if (pos2 == std::string::npos) {
+        pos2 = name.length();
+    }
+    std::string layer_str = name.substr(pos1, pos2 - pos1);
+    int layer = std::stoi(layer_str);
+    return layer;
+}
+
 void GgmlOvDecoder::set_llm_params() {
     for (int i = 0; i < m_cgraph->n_nodes; i++) {
         auto* node = m_cgraph->nodes[i];
         std::string name = std::string(node->name);
-        if (node->op == GGML_OP_VIEW && std::string(node->name) == "cache_k_l0 (view)") {
-            auto* cache_k = node->src[0];
-            m_context_size = cache_k->ne[1];
+        if (node->op == GGML_OP_FLASH_ATTN_EXT) {
+            auto* cache_k = node->src[1];
+            cache_k = cache_k->view_src ? cache_k->view_src : cache_k;
+            int layer = extract_layer_from_name(cache_k->name);
+
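+            // A KQ mask (src[3]) whose name contains "swa" marks this layer as sliding-window attention.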
+            if (std::string(node->src[3]->name).find("swa") != std::string::npos) {
+                m_swa_layers.push_back(layer);
+                m_context_size_swa = cache_k->ne[1];
+            } else {
+                m_context_size = cache_k->ne[1];
+            }
         } else if (node->op == GGML_OP_ROPE &&
                    (name.find("Qcur-0") == 0 || std::string(node->src[0]->name).find("Qcur-0") == 0)) {
             m_head_size = node->ne[0];
@@ -269,25 +297,24 @@ ov::PartialShape GgmlOvDecoder::get_graph_input_shape(const ggml_tensor* src) const {
                 input_shape = ov::PartialShape{1, 1, 1};
             }
         } else {
-            input_shape = ov::PartialShape{1, 1, ov::Dimension(1, m_context_size)};
+            input_shape = ov::PartialShape{1, 1, -1};
         }
     } else if (name == "inp_out_ids" && !m_is_static) {
-        input_shape = ov::PartialShape{1, 1, ov::Dimension(1, m_context_size)};
-    } else if (name == "KQ_mask") {
+        input_shape = ov::PartialShape{1, 1, -1};
+    } else if (name.find("KQ_mask") == 0) {
         if (m_is_static) {
             if (m_is_first_token) {
                 input_shape = ov::PartialShape{1, m_context_size, m_context_size};
             } else {
                 input_shape = ov::PartialShape{1, 1, m_context_size};
             }
         } else {
-            auto max_mask_size = GGML_PAD(m_context_size, GGML_KQ_MASK_PAD);
-            input_shape = ov::PartialShape{1, ov::Dimension(1, max_mask_size), ov::Dimension(1, max_mask_size)};
+            input_shape = ov::PartialShape{1, -1, -1};
         }
-    } else if (name.find("cache_k") == 0) {
-        input_shape = ov::PartialShape{m_context_size, m_num_heads_kv, m_head_size};
-    } else if (name.find("cache_v") == 0) {
-        input_shape = ov::PartialShape{m_context_size, m_num_heads_kv, m_head_size};
+    } else if (name.find("cache_") == 0) {
+        int layer = extract_layer_from_name(name);
+        bool is_swa = is_swa_layer(layer);
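+        // SWA layers use the sliding-window context length as the KV cache's first dimension.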
+        input_shape = ov::PartialShape{is_swa ? m_context_size_swa : m_context_size, m_num_heads_kv, m_head_size};
     } else if (const auto* op = get_tensor_used_op(src); op && op->op == GGML_OP_SET_ROWS) {
         input_shape = ov::PartialShape{1, 1, m_is_static ? 1 : -1};
     } else if (src->op == GGML_OP_VIEW) {
@@ -305,35 +332,35 @@ void GgmlOvDecoder::add_extra_inputs() {
     // see llama_kv_cache_unified::get_n_kv and llama_kv_cache_unified::get_padding.
     // Not used for NPU
     int64_t attention_size = -1;
+    int64_t attention_size_swa = -1;
     for (const auto& node : m_nodes) {
-        if (node->op == GGML_OP_SOFT_MAX) {
-            auto* mask = node->src[1];
-            if (std::string(mask->name).find("KQ_mask") != 0) {
-                throw std::runtime_error("Unexpected softmax node: " + std::string(mask->name));
-            }
-            attention_size = mask->ne[0];
-            break;
-        }
         if (node->op == GGML_OP_FLASH_ATTN_EXT) {
             auto* mask = node->src[3];
-            if (std::string(mask->name).find("KQ_mask") != 0) {
+            std::string mask_name(mask->name);
+            if (mask_name.find("KQ_mask") != 0) {
                 throw std::runtime_error("Unexpected flash attention node: " + std::string(mask->name));
             }
-            attention_size = mask->ne[0];
+            if (mask_name.find("swa") != std::string::npos) {
+                attention_size_swa = mask->ne[0];
+            } else {
+                attention_size = mask->ne[0];
+            }
         }
     }

-    {
-        std::string name = "attention_size";
+    auto create_attention_size_input = [this](const std::string& name, int64_t size) {
         auto param_node = std::make_shared<ov::op::v0::Parameter>(ov::element::i64, ov::Shape{1});
         param_node->set_friendly_name(name);
         param_node->output(0).get_tensor().set_names({name});
         m_model_extra_inputs[name] = param_node;

         auto tensor = std::make_shared<ov::Tensor>(ov::element::i64, ov::Shape{1});
-        *tensor->data<int64_t>() = attention_size;
+        *tensor->data<int64_t>() = size;
         m_model_extra_input_values[name] = tensor;
-    }
+    };
+
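+    // attention_size_swa stays at -1 when the graph has no sliding-window attention layers.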
+    create_attention_size_input("attention_size", attention_size);
+    create_attention_size_input("attention_size_swa", attention_size_swa);
 }

 const ggml_tensor* GgmlOvDecoder::get_tensor_used_op(const ggml_tensor* tensor) const {
@@ -706,8 +733,16 @@ int32_t* GgmlOvDecoder::get_output_op_params(const std::string& name) const {

 void GgmlOvDecoder::visit_subgraph(std::function<void(std::shared_ptr<GgmlDecoder>)> node_visitor) const {
     for (const auto& node : m_nodes) {
-        auto decoder = std::make_shared<GgmlOvDecoder>(
-            node, m_cgraph, m_is_static, m_is_first_token, m_context_size, m_num_heads, m_num_heads_kv, m_head_size);
+        auto decoder = std::make_shared<GgmlOvDecoder>(node,
+                                                       m_cgraph,
+                                                       m_is_static,
+                                                       m_is_first_token,
+                                                       m_context_size,
+                                                       m_context_size_swa,
+                                                       m_num_heads,
+                                                       m_num_heads_kv,
+                                                       m_head_size,
+                                                       m_swa_layers);
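+        // Each per-node decoder inherits the SWA layer list and both context sizes from the graph-level decoder.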
         node_visitor(decoder);
     }
 }