|
5 | 5 | #include <map> |
6 | 6 | #include <memory> |
7 | 7 | #include <openvino/core/node.hpp> |
| 8 | +#include <openvino/op/add.hpp> |
8 | 9 | #include <openvino/op/broadcast.hpp> |
9 | 10 | #include <openvino/op/concat.hpp> |
10 | 11 | #include <openvino/op/convert.hpp> |
@@ -78,67 +79,59 @@ void add_kv_update_indices(TensorMap& tensor_map, GgmlDecoder& ggml_model_decode |
78 | 79 | // cache_k layout: [S, N, H] (seq, num_heads, head_size) |
79 | 80 | // cache_v layout: [N, H, S] (num_heads, head_size, seq) |
80 | 81 | // When writing to cache_v, cache should be reshaped to [N*H, S] and v-curr should be flattened |
81 | | - auto inp_pos = tensor_map.at("inp_pos").get_node_shared_ptr(); |
| 82 | + auto past_token_len = tensor_map.at("past_token_len").get_node_shared_ptr(); |
82 | 83 | auto token_len = tensor_map.at("token_len").get_node_shared_ptr(); |
83 | 84 |
|
84 | | - std::shared_ptr<ov::Node> update_indices_k; |
85 | | - std::shared_ptr<ov::Node> update_indices_v; |
| 85 | + Output<Node> update_indices_k; |
| 86 | + Output<Node> update_indices_v; |
86 | 87 |
|
87 | 88 | auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0}); |
88 | 89 | auto zero_scalar = ov::op::v0::Constant::create(ov::element::i64, {}, {0}); |
89 | 90 | auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1}); |
90 | 91 | auto one_scalar = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{}, {1}); |
91 | 92 | auto two = ov::op::v0::Constant::create(ov::element::i64, {1}, {2}); |
92 | 93 |
|
93 | | - update_indices_k = |
94 | | - std::make_shared<ov::op::v0::Squeeze>(inp_pos, ov::op::v0::Constant::create(ov::element::i64, {2}, {0, 1})); |
95 | | - update_indices_k = std::make_shared<ov::op::v0::Unsqueeze>(update_indices_k, one); |
96 | | - update_indices_k->set_friendly_name("update_indices_k"); |
97 | | - tensor_map.insert({"update_indices_k", update_indices_k->output(0)}); |
| 94 | + auto past_token_len_scalar = std::make_shared<ov::op::v0::Squeeze>(past_token_len, zero); |
| 95 | + auto token_len_scalar = std::make_shared<ov::op::v0::Squeeze>(token_len, zero); |
| 96 | + auto total_token_len_scalar = std::make_shared<ov::op::v1::Add>(past_token_len_scalar, token_len_scalar); |
| 97 | + |
| 98 | + Output<Node> update_indices = std::make_shared<ov::op::v4::Range>( |
| 99 | + past_token_len_scalar, total_token_len_scalar, one_scalar, ov::element::i64); |
| 100 | + if (ggml_model_decoder.is_static()) { |
| 101 | + update_indices = past_token_len; |
| 102 | + } |
| 103 | + |
| 104 | + update_indices_k = std::make_shared<ov::op::v0::Unsqueeze>(update_indices, one); |
| 105 | + update_indices_k.get_node_shared_ptr()->set_friendly_name("update_indices_k"); |
| 106 | + tensor_map.insert({"update_indices_k", update_indices_k}); |
98 | 107 |
|
99 | 108 | auto total_head_size = ggml_model_decoder.get_num_heads_kv() * ggml_model_decoder.get_head_size(); |
100 | 109 | auto total_head_size_node = ov::op::v0::Constant::create(ov::element::i64, {1}, {total_head_size}); |
101 | 110 | auto total_head_size_scalar = std::make_shared<ov::op::v0::Squeeze>(total_head_size_node, zero); |
102 | 111 |
|
103 | 112 | // 1D tensor of shape [total_head_size], values starting from 0 |
104 | 113 | auto range_row = |
105 | | - std::make_shared<ov::op::v4::Range>(zero_scalar, total_head_size_scalar, one_scalar, ov::element::i32); |
| 114 | + std::make_shared<ov::op::v4::Range>(zero_scalar, total_head_size_scalar, one_scalar, ov::element::i64); |
106 | 115 | auto range_row_reshaped = |
107 | 116 | std::make_shared<ov::op::v0::Unsqueeze>(range_row, ov::op::v0::Constant::create(ov::element::i64, {2}, {1, 2})); |
108 | 117 | auto row_indices = std::make_shared<ov::op::v3::Broadcast>( |
109 | 118 | range_row_reshaped, |
110 | 119 | std::make_shared<ov::op::v0::Concat>(ov::OutputVector{total_head_size_node, token_len, one}, 0)); |
111 | 120 |
|
112 | 121 | // 1D tensor of shape [token_len], values starting from past_token_len |
113 | | - auto range_col = |
114 | | - std::make_shared<ov::op::v0::Squeeze>(inp_pos, ov::op::v0::Constant::create(ov::element::i64, {2}, {0, 1})); |
| 122 | + auto range_col = update_indices; |
115 | 123 | auto range_col_reshaped = |
116 | 124 | std::make_shared<ov::op::v0::Unsqueeze>(range_col, ov::op::v0::Constant::create(ov::element::i64, {2}, {0, 2})); |
117 | 125 | auto col_indices = std::make_shared<ov::op::v3::Broadcast>( |
118 | 126 | range_col_reshaped, |
119 | 127 | std::make_shared<ov::op::v0::Concat>(ov::OutputVector{total_head_size_node, token_len, one}, 0)); |
120 | 128 |
|
121 | 129 | // Stack row_indices and col_indices along last axis: [total_head_size, token_len, 2] |
122 | | - auto indices = std::make_shared<ov::op::v0::Concat>(OutputVector{row_indices, col_indices}, 2); |
| 130 | + update_indices_v = std::make_shared<ov::op::v0::Concat>(OutputVector{row_indices, col_indices}, 2); |
123 | 131 | update_indices_v = std::make_shared<ov::op::v1::Reshape>( |
124 | | - indices, ov::op::v0::Constant::create(ov::element::i64, {2}, std::vector<int64_t>{-1, 2}), false); |
125 | | - update_indices_v->set_friendly_name("update_indices_v"); |
126 | | - tensor_map.insert({"update_indices_v", update_indices_v->output(0)}); |
127 | | -} |
128 | | - |
129 | | -float ggml_rope_yarn_corr_dim(int n_dims, int n_ctx_orig, float n_rot, float base) { |
130 | | -#ifndef M_PI |
131 | | -# define M_PI 3.14159265358979323846 |
132 | | -#endif |
133 | | - return n_dims * logf(n_ctx_orig / (n_rot * 2 * (float) M_PI)) / (2 * logf(base)); |
134 | | -} |
135 | | - |
136 | | -void ggml_rope_yarn_corr_dims(int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, |
137 | | - float dims[2]) { |
138 | | - float start = floorf(ggml_rope_yarn_corr_dim(n_dims, n_ctx_orig, beta_fast, freq_base)); |
139 | | - float end = ceilf(ggml_rope_yarn_corr_dim(n_dims, n_ctx_orig, beta_slow, freq_base)); |
140 | | - dims[0] = std::max(0.0f, start); |
141 | | - dims[1] = std::min(static_cast<float>(n_dims - 1), end); |
| 132 | + update_indices_v, ov::op::v0::Constant::create(ov::element::i64, {2}, std::vector<int64_t>{-1, 2}), false); |
| 133 | + update_indices_v.get_node_shared_ptr()->set_friendly_name("update_indices_v"); |
| 134 | + tensor_map.insert({"update_indices_v", update_indices_v}); |
142 | 135 | } |
143 | 136 |
|
144 | 137 | void add_rope_sin_cos(TensorMap& tensor_map, GgmlDecoder& ggml_model_decoder) { |
|
0 commit comments