
Commit 82290b9

Optimize tensor conversion, improve TTFT
1 parent ca5e725 commit 82290b9


ggml/src/ggml-openvino/ggml-decoder.cpp

Lines changed: 17 additions & 58 deletions
@@ -24,6 +24,7 @@
 #include <openvino/runtime/tensor.hpp>
 #include <ostream>
 #include <set>
+#include <stdexcept>
 #include <string>
 
 #include "ggml-backend-impl.h"
@@ -391,53 +392,12 @@ std::map<std::string, std::shared_ptr<ov::Node>> GgmlOvDecoder::create_weight_no
 }
 
 std::shared_ptr<ov::Node> GgmlOvDecoder::create_weight_node(ggml_tensor* tensor) {
-    std::shared_ptr<ov::Node> weight_node;
     auto node_type = get_ov_type(tensor);
     auto node_shape = get_shape(tensor);
     auto ne_total = ggml_nelements(tensor);
-    switch (tensor->type) {
-        case GGML_TYPE_I32: {
-            const auto* ptr = reinterpret_cast<const int32_t*>(tensor->data);
-            std::vector<int32_t> data(ptr, ptr + ne_total);
-            weight_node = std::make_shared<ov::op::v0::Constant>(node_type, node_shape, data);
-            break;
-        }
-        case GGML_TYPE_I64: {
-            const auto* ptr = reinterpret_cast<const int64_t*>(tensor->data);
-            std::vector<int64_t> data(ptr, ptr + ne_total);
-            weight_node = std::make_shared<ov::op::v0::Constant>(node_type, node_shape, data);
-            break;
-        }
-        case GGML_TYPE_F32: {
-            const auto* ptr = reinterpret_cast<const float*>(tensor->data);
-            std::vector<float> data(ptr, ptr + ne_total);
-            weight_node = std::make_shared<ov::op::v0::Constant>(node_type, node_shape, data);
-            break;
-        }
-        case GGML_TYPE_F16: {
-            const auto* ptr = reinterpret_cast<const uint16_t*>(tensor->data);
-            std::vector<ov::float16> data_f16;
-            data_f16.reserve(ne_total);
-            for (int i = 0; i < ne_total; ++i) {
-                data_f16.push_back(ov::float16::from_bits(ptr[i]));
-            }
-            weight_node = std::make_shared<ov::op::v0::Constant>(node_type, node_shape, data_f16);
-            break;
-        }
-        case GGML_TYPE_BF16: {
-            const auto* ptr = reinterpret_cast<const uint16_t*>(tensor->data);
-            std::vector<ov::bfloat16> data_bf16;
-            data_bf16.reserve(ne_total);
-            for (int i = 0; i < ne_total; ++i) {
-                data_bf16.push_back(ov::bfloat16::from_bits(ptr[i]));
-            }
-            weight_node = std::make_shared<ov::op::v0::Constant>(node_type, node_shape, data_bf16);
-            break;
-        }
-        default:
-            throw std::invalid_argument("Unsupported tensor type");
-    }
-    return weight_node;
+    ov::Tensor weights(node_type, node_shape);
+    memcpy(weights.data(), tensor->data, ne_total * node_type.size());
+    return std::make_shared<ov::op::v0::Constant>(weights);
 }
 
 void GgmlOvDecoder::dump_cgraph(const struct ggml_cgraph* cgraph, std::string& filename) {
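
Note on this hunk: the removed switch copied every weight tensor element-wise into a temporary std::vector (including a from_bits loop for F16/BF16) before constructing the Constant, while the new code performs a single memcpy into an ov::Tensor of the matching element type and builds the Constant from that tensor. Skipping the per-element pass over every weight is what the commit message credits for the improved TTFT. Below is a minimal standalone sketch contrasting the two paths for a contiguous FP16 weight buffer; the old_path/new_path helpers and the raw_f16 argument are illustrative only, not code from this repository.

#include <cstdint>
#include <cstring>
#include <memory>
#include <vector>

#include <openvino/core/shape.hpp>
#include <openvino/core/type/float16.hpp>
#include <openvino/op/constant.hpp>
#include <openvino/runtime/tensor.hpp>

// Old path: element-wise conversion through a temporary std::vector.
std::shared_ptr<ov::Node> old_path(const uint16_t* raw_f16, size_t n, const ov::Shape& shape) {
    std::vector<ov::float16> data;
    data.reserve(n);
    for (size_t i = 0; i < n; ++i) {
        data.push_back(ov::float16::from_bits(raw_f16[i]));  // one conversion per weight
    }
    return std::make_shared<ov::op::v0::Constant>(ov::element::f16, shape, data);
}

// New path: one bulk memcpy into an ov::Tensor, then a Constant built from it.
std::shared_ptr<ov::Node> new_path(const uint16_t* raw_f16, size_t n, const ov::Shape& shape) {
    ov::Tensor weights(ov::element::f16, shape);
    std::memcpy(weights.data(), raw_f16, n * ov::element::f16.size());
    return std::make_shared<ov::op::v0::Constant>(weights);
}

The bulk copy assumes the ggml buffer is contiguous and bit-compatible with the target OpenVINO element type, which holds for the plain F32/F16/BF16/I32/I64 weights the removed switch handled.
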
@@ -549,27 +509,26 @@ std::vector<size_t> GgmlOvDecoder::get_stride(const ggml_tensor* tensor) {
 }
 
 ov::element::Type GgmlOvDecoder::get_ov_type(const ggml_tensor* tensor) {
-    ov::element::Type type = ov::element::dynamic;
     switch (tensor->type) {
+        case GGML_TYPE_F64:
+            return ov::element::f64;
         case GGML_TYPE_F32:
-            type = ov::element::f32;
-            break;
+            return ov::element::f32;
         case GGML_TYPE_F16:
-            type = ov::element::f16;
-            break;
+            return ov::element::f16;
         case GGML_TYPE_BF16:
-            type = ov::element::bf16;
-            break;
-        case GGML_TYPE_I64:
-            type = ov::element::i64;
-            break;
+            return ov::element::bf16;
+        case GGML_TYPE_I8:
+            return ov::element::i8;
+        case GGML_TYPE_I16:
+            return ov::element::i16;
         case GGML_TYPE_I32:
-            type = ov::element::i32;
-            break;
+            return ov::element::i32;
+        case GGML_TYPE_I64:
+            return ov::element::i64;
         default:
-            break;
+            throw std::runtime_error("Unsupported tensor type");
    }
-    return type;
 }
 
 ov::PartialShape GgmlOvDecoder::get_input_shape(const std::string& name) const {
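
Note on this hunk: each supported ggml type now maps straight to its ov::element counterpart with a direct return, F64/I8/I16 are newly covered, and unsupported (for example quantized) types throw instead of silently yielding ov::element::dynamic. A hedged sketch of that behavior follows; the DemoType enum and to_ov_type helper are illustrative stand-ins, not the real ggml/OpenVINO glue. It also shows why the mapped type matters to the hunk above: its size() drives the memcpy length in create_weight_node (ne_total * node_type.size()).

#include <cstdio>
#include <stdexcept>

#include <openvino/core/type/element_type.hpp>

enum class DemoType { F64, F32, F16, BF16, I8, I16, I32, I64, Q4_0 /* quantized: unsupported here */ };

// Mirrors the new direct-return mapping style of get_ov_type().
ov::element::Type to_ov_type(DemoType t) {
    switch (t) {
        case DemoType::F64:  return ov::element::f64;
        case DemoType::F32:  return ov::element::f32;
        case DemoType::F16:  return ov::element::f16;
        case DemoType::BF16: return ov::element::bf16;
        case DemoType::I8:   return ov::element::i8;
        case DemoType::I16:  return ov::element::i16;
        case DemoType::I32:  return ov::element::i32;
        case DemoType::I64:  return ov::element::i64;
        default:             throw std::runtime_error("Unsupported tensor type");
    }
}

int main() {
    // The element size of the mapped type is what sizes the bulk weight copy.
    std::printf("f16 element size: %zu bytes\n", to_ov_type(DemoType::F16).size());
    try {
        to_ov_type(DemoType::Q4_0);
    } catch (const std::runtime_error& e) {
        std::printf("caught: %s\n", e.what());  // unsupported types now fail loudly
    }
    return 0;
}
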
