Skip to content

Commit d714dad

Browse files
huydt84 and huydt-bti authored
pooling : make cls_b and cls_out_b optional (#14165)
Co-authored-by: dinhhuy <[email protected]>
1 parent ffad043 commit d714dad

File tree

1 file changed

+13
-6
lines changed

1 file changed

+13
-6
lines changed

src/llama-graph.cpp

Lines changed: 13 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -1556,23 +1556,30 @@ void llm_graph_context::build_pooling(
15561556
ggml_tensor * inp_cls = build_inp_cls();
15571557
inp = ggml_get_rows(ctx0, inp, inp_cls);
15581558

1559-
if (cls != nullptr && cls_b != nullptr) {
1559+
if (cls) {
15601560
// classification head
15611561
// https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
1562-
cur = ggml_add(ctx0, ggml_mul_mat(ctx0, cls, inp), cls_b);
1562+
cur = ggml_mul_mat(ctx0, cls, inp);
1563+
if (cls_b) {
1564+
cur = ggml_add(ctx0, cur, cls_b);
1565+
}
15631566
cur = ggml_tanh(ctx0, cur);
15641567

15651568
// some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
15661569
// https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896
15671570
if (cls_out) {
1568-
GGML_ASSERT(cls_out_b != nullptr);
1569-
cur = ggml_add(ctx0, ggml_mul_mat(ctx0, cls_out, cur), cls_out_b);
1571+
cur = ggml_mul_mat(ctx0, cls_out, cur);
1572+
if (cls_out_b) {
1573+
cur = ggml_add(ctx0, cur, cls_out_b);
1574+
}
15701575
}
15711576
} else if (cls_out) {
15721577
// Single layer classification head (direct projection)
15731578
// https://github.com/huggingface/transformers/blob/f4fc42216cd56ab6b68270bf80d811614d8d59e4/src/transformers/models/bert/modeling_bert.py#L1476
1574-
GGML_ASSERT(cls_out_b != nullptr);
1575-
cur = ggml_add(ctx0, ggml_mul_mat(ctx0, cls_out, inp), cls_out_b);
1579+
cur = ggml_mul_mat(ctx0, cls_out, inp);
1580+
if (cls_out_b) {
1581+
cur = ggml_add(ctx0, cur, cls_out_b);
1582+
}
15761583
} else {
15771584
GGML_ABORT("RANK pooling requires either cls+cls_b or cls_out+cls_out_b");
15781585
}

0 commit comments

Comments
 (0)