@@ -107,7 +107,9 @@ enum ggml_status openvino_frontend_compute(ggml_backend_t backend, struct ggml_c
 
     std::shared_ptr<GgmlOvDecoder> ggml_decoder;
     std::shared_ptr<ov::InferRequest> infer_request;
-    bool is_first_token = get_is_first_token(cgraph);
+
+    const ggml_tensor* inp_pos = get_inp_pos_tensor(cgraph);
+    bool is_first_token = get_is_first_token(inp_pos);
 
     int64_t decoder_end_time;
     int64_t conversion_end_time;
@@ -207,41 +209,17 @@ enum ggml_status openvino_frontend_compute(ggml_backend_t backend, struct ggml_c
         }
     }
 
-    // TODO not correct yet
-    // Even if we make this correct, there will still be a corner case that will fail:
-    // in llama-server, enter some prompt in a conversation, after it completes,
-    // enter the same prompt in another conversation. This should still be treated as
-    // prefill but get_is_prefill will return false because len(inp_pos) == 1 && inp_pos[0] != 0
-    // which in most cases means generate stage.
-    if (!is_static && get_is_prefill(cgraph)) {
+    if (!is_static) {
         auto states = infer_request->query_state();
-        if (get_is_first_token(cgraph)) {
+        int32_t kv_len = *(int32_t *) inp_pos->data;
+        int32_t kv_len_in_state = states[0].get_state().get_shape()[1];
+        if (kv_len != kv_len_in_state) {
             for (auto & state : states) {
-                state.reset();
-            }
-        } else {
-            const auto * inp_pos = get_inp_pos_tensor(cgraph);
-            for (auto & state : states) {
-                std::string state_name = state.get_name();
-                state_name = state_name.substr(0, state_name.size() / 2);
-                ggml_tensor* kv_tensor;
-                for (const auto & kv : kv_tensors) {
-                    if (state_name == kv.first) {
-                        kv_tensor = kv.second;
-                        break;
-                    }
-                }
-                // shape should be [1, inp_pos[0], num_heads, head_dim]
-                ov::Shape state_shape = state.get_state().get_shape();
-                // std::cout << state_shape << std::endl;
-                state_shape[1] = *(int32_t *) inp_pos->data;
-                // std::cout << state_shape << std::endl;
-                ov::Tensor state_tensor(state.get_state().get_element_type(), state_shape, kv_tensor->data);
-
-                // The above is wrong becaues I am setting the state using kvbuffer's in the cgraph which
-                // we never update with our stateful approach.
-                // What we should do is to get the kv values from ov by state.get_state(), slice the to the
-                // rows of inp_pos->data, and use that as the new state.
+                ov::Tensor state_tensor = state.get_state();
+                ov::Shape state_shape = state_tensor.get_shape();
+                state_shape[1] = kv_len;
+                state_tensor.set_shape(state_shape);
+                state.set_state(state_tensor);
             }
         }
     }
@@ -547,17 +525,10 @@ const ggml_tensor* get_inp_pos_tensor(struct ggml_cgraph* cgraph) {
     throw std::runtime_error("get_inp_pos_tensor: inp_pos not found in cgraph");
 }
 
-bool get_is_first_token(struct ggml_cgraph * cgraph) {
-    const auto * inp_pos = get_inp_pos_tensor(cgraph);
+bool get_is_first_token(const ggml_tensor* inp_pos) {
     return *(int32_t *) inp_pos->data == 0;
 }
 
-// Check if the graph is for prefill (first token or batch size > 1)
-bool get_is_prefill(struct ggml_cgraph * cgraph) {
-    const auto * inp_pos = get_inp_pos_tensor(cgraph);
-    return *(int32_t *) inp_pos->data == 0 || inp_pos->ne[0] > 1;
-}
-
 std::vector<std::pair<std::string, ggml_tensor*>> get_kv_tensors(struct ggml_cgraph * cgraph) {
     std::vector<std::pair<std::string, ggml_tensor*>> kv_tensors;
     for (int i = 0; i < cgraph->n_nodes; ++i) {
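The change drops the fragile get_is_prefill heuristic: instead of guessing the prefill/generate stage from inp_pos, the variable states are resized whenever the KV length implied by inp_pos[0] disagrees with the sequence length already held in the OpenVINO state, which also covers the llama-server corner case described in the removed TODO (re-entering the same prompt in a fresh conversation). The sketch below restates that pattern as a standalone helper; the name trim_kv_states and the [1, seq_len, num_heads, head_dim] state layout are assumptions taken from the old block's comments, not part of this commit.

// Minimal sketch of the state-trimming pattern above, outside the ggml glue code.
// Assumptions: trim_kv_states is a hypothetical helper name, and every variable
// state tensor stores the KV sequence length in dimension 1 ([1, seq_len, heads, dim]).
#include <cstdint>
#include <openvino/openvino.hpp>

static void trim_kv_states(ov::InferRequest & infer_request, int32_t kv_len) {
    auto states = infer_request.query_state();
    if (states.empty()) {
        return;
    }
    int32_t kv_len_in_state = states[0].get_state().get_shape()[1];
    if (kv_len == kv_len_in_state) {
        return;  // states already match the position implied by inp_pos[0]
    }
    for (auto & state : states) {
        ov::Tensor state_tensor = state.get_state();
        ov::Shape state_shape = state_tensor.get_shape();
        // Shrinking dimension 1 keeps the first kv_len positions and drops stale
        // KV rows, e.g. when a llama-server slot is reused for a new conversation.
        state_shape[1] = kv_len;
        state_tensor.set_shape(state_shape);
        state.set_state(state_tensor);
    }
}

Unlike the earlier attempt, which rebuilt the state from the cgraph's KV buffers that the stateful model never updates, this only reshapes the tensors OpenVINO already owns, which is what the removed comment at the bottom of the old block asked for.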