Skip to content

Commit 4f65885

Browse files
DamonFoolCISC
andauthored
llama : support T5 models with unequal number of encoder-decoder layers (ggml-org#15909)
* Extend the support of T5 models with different encoder-decoder layers Signed-off-by: Jie Fu <[email protected]> * Update convert_hf_to_gguf.py Co-authored-by: Sigbjørn Skjæret <[email protected]> * Update gguf-py/gguf/constants.py Co-authored-by: Sigbjørn Skjæret <[email protected]> * Update gguf-py/gguf/gguf_writer.py Co-authored-by: Sigbjørn Skjæret <[email protected]> * Update src/llama-arch.cpp Co-authored-by: Sigbjørn Skjæret <[email protected]> * Update src/llama-arch.h Co-authored-by: Sigbjørn Skjæret <[email protected]> * Update src/llama-model.cpp Co-authored-by: Sigbjørn Skjæret <[email protected]> * Update src/llama-model.cpp Co-authored-by: Sigbjørn Skjæret <[email protected]> * Update src/llama-model.cpp Co-authored-by: Sigbjørn Skjæret <[email protected]> * Update src/llama-model.cpp Co-authored-by: Sigbjørn Skjæret <[email protected]> * Update src/llama-hparams.h Co-authored-by: Sigbjørn Skjæret <[email protected]> * Update src/llama-model.cpp Co-authored-by: Sigbjørn Skjæret <[email protected]> * Update src/llama-model.cpp Co-authored-by: Sigbjørn Skjæret <[email protected]> * Update src/llama-model.cpp Co-authored-by: Sigbjørn Skjæret <[email protected]> * Update src/llama-model.cpp Co-authored-by: Sigbjørn Skjæret <[email protected]> * Update src/llama-model.cpp Co-authored-by: Sigbjørn Skjæret <[email protected]> * Update src/llama-model.cpp Co-authored-by: Sigbjørn Skjæret <[email protected]> * Update src/llama-model.cpp Co-authored-by: Sigbjørn Skjæret <[email protected]> * Update src/llama-model.cpp Co-authored-by: Sigbjørn Skjæret <[email protected]> * Update src/llama-model.cpp Co-authored-by: Sigbjørn Skjæret <[email protected]> * Update src/llama-model.cpp Co-authored-by: Sigbjørn Skjæret <[email protected]> * Update src/llama-model.cpp Co-authored-by: Sigbjørn Skjæret <[email protected]> * Update src/llama-model.cpp Co-authored-by: Sigbjørn Skjæret <[email protected]> * Update src/llama-model.cpp Co-authored-by: Sigbjørn Skjæret <[email protected]> * Update src/llama-model.cpp Co-authored-by: Sigbjørn Skjæret <[email protected]> * Rename n_dec_layer --> dec_n_layer Signed-off-by: Jie Fu <[email protected]> * Adapt to cases when dec_n_layer > n_layer Signed-off-by: Jie Fu <[email protected]> --------- Signed-off-by: Jie Fu <[email protected]> Co-authored-by: Sigbjørn Skjæret <[email protected]>
1 parent 6ab397e commit 4f65885

File tree

7 files changed

+31
-4
lines changed

7 files changed

+31
-4
lines changed

convert_hf_to_gguf.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6701,6 +6701,8 @@ def set_gguf_parameters(self):
67016701
self.gguf_writer.add_embedding_length(self.hparams["d_model"])
67026702
self.gguf_writer.add_feed_forward_length(self.hparams["d_ff"])
67036703
self.gguf_writer.add_block_count(self.hparams["num_layers"])
6704+
if (dec_n_layer := self.hparams.get("num_decoder_layers")) is not None:
6705+
self.gguf_writer.add_decoder_block_count(dec_n_layer)
67046706
self.gguf_writer.add_head_count(self.hparams["num_heads"])
67056707
self.gguf_writer.add_key_length(self.hparams["d_kv"])
67066708
self.gguf_writer.add_value_length(self.hparams["d_kv"])

gguf-py/gguf/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ class LLM:
109109
POOLING_TYPE = "{arch}.pooling_type"
110110
LOGIT_SCALE = "{arch}.logit_scale"
111111
DECODER_START_TOKEN_ID = "{arch}.decoder_start_token_id"
112+
DECODER_BLOCK_COUNT = "{arch}.decoder_block_count"
112113
ATTN_LOGIT_SOFTCAPPING = "{arch}.attn_logit_softcapping"
113114
FINAL_LOGIT_SOFTCAPPING = "{arch}.final_logit_softcapping"
114115
SWIN_NORM = "{arch}.swin_norm"

gguf-py/gguf/gguf_writer.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -676,6 +676,9 @@ def add_parallel_residual(self, use: bool) -> None:
676676
def add_decoder_start_token_id(self, id: int) -> None:
677677
self.add_uint32(Keys.LLM.DECODER_START_TOKEN_ID.format(arch=self.arch), id)
678678

679+
def add_decoder_block_count(self, value: int) -> None:
680+
self.add_uint32(Keys.LLM.DECODER_BLOCK_COUNT.format(arch=self.arch), value)
681+
679682
def add_embedding_length_per_layer_input(self, value: int) -> None:
680683
self.add_uint32(Keys.LLM.EMBD_LENGTH_PER_LAYER_INP.format(arch=self.arch), value)
681684

src/llama-arch.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
137137
{ LLM_KV_POOLING_TYPE, "%s.pooling_type" },
138138
{ LLM_KV_LOGIT_SCALE, "%s.logit_scale" },
139139
{ LLM_KV_DECODER_START_TOKEN_ID, "%s.decoder_start_token_id" },
140+
{ LLM_KV_DECODER_BLOCK_COUNT, "%s.decoder_block_count" },
140141
{ LLM_KV_ATTN_LOGIT_SOFTCAPPING, "%s.attn_logit_softcapping" },
141142
{ LLM_KV_FINAL_LOGIT_SOFTCAPPING, "%s.final_logit_softcapping" },
142143
{ LLM_KV_SWIN_NORM, "%s.swin_norm" },

src/llama-arch.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,7 @@ enum llm_kv {
141141
LLM_KV_POOLING_TYPE,
142142
LLM_KV_LOGIT_SCALE,
143143
LLM_KV_DECODER_START_TOKEN_ID,
144+
LLM_KV_DECODER_BLOCK_COUNT,
144145
LLM_KV_ATTN_LOGIT_SOFTCAPPING,
145146
LLM_KV_FINAL_LOGIT_SOFTCAPPING,
146147
LLM_KV_SWIN_NORM,

src/llama-hparams.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,7 @@ struct llama_hparams {
159159
// needed by encoder-decoder models (e.g. T5, FLAN-T5)
160160
// ref: https://github.com/ggerganov/llama.cpp/pull/8141
161161
llama_token dec_start_token_id = LLAMA_TOKEN_NULL;
162+
uint32_t dec_n_layer = 0;
162163

163164
enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_NONE;
164165
enum llama_rope_type rope_type = LLAMA_ROPE_TYPE_NONE;

src/llama-model.cpp

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1542,6 +1542,9 @@ void llama_model::load_hparams(llama_model_loader & ml) {
15421542
hparams.dec_start_token_id = dec_start_token_id;
15431543
}
15441544

1545+
hparams.dec_n_layer = hparams.n_layer;
1546+
ml.get_key(LLM_KV_DECODER_BLOCK_COUNT, hparams.dec_n_layer, false);
1547+
15451548
switch (hparams.n_layer) {
15461549
case 6: type = LLM_TYPE_60M; break; // t5-small
15471550
case 8: type = LLM_TYPE_80M; break; // flan-t5-small
@@ -4414,6 +4417,14 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
44144417
output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
44154418
}
44164419

4420+
// n_layer: number of encoder_layers
4421+
// dec_n_layer: number of decoder_layers
4422+
const int dec_n_layer = hparams.dec_n_layer;
4423+
if (dec_n_layer > n_layer) {
4424+
layers.resize(dec_n_layer);
4425+
}
4426+
4427+
// load encoder layers
44174428
for (int i = 0; i < n_layer; ++i) {
44184429
auto & layer = layers[i];
44194430

@@ -4429,6 +4440,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
44294440
layer.ffn_gate_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED);
44304441
layer.ffn_down_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
44314442
layer.ffn_up_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
4443+
}
4444+
4445+
// load decoder layers
4446+
for (int i = 0; i < dec_n_layer; ++i) {
4447+
auto & layer = layers[i];
44324448

44334449
layer.attn_norm = create_tensor(tn(LLM_TENSOR_DEC_ATTN_NORM, "weight", i), {n_embd}, 0);
44344450
layer.attn_rel_b = create_tensor(tn(LLM_TENSOR_DEC_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, TENSOR_NOT_REQUIRED);
@@ -13509,7 +13525,9 @@ struct llm_build_t5_dec : public llm_graph_context {
1350913525

1351013526
ggml_tensor * inp_out_ids = build_inp_out_ids();
1351113527

13512-
for (int il = 0; il < n_layer; ++il) {
13528+
const int64_t dec_n_layer = hparams.dec_n_layer;
13529+
13530+
for (int il = 0; il < dec_n_layer; ++il) {
1351313531
ggml_tensor * inpSA = inpL;
1351413532

1351513533
// norm
@@ -13600,7 +13618,7 @@ struct llm_build_t5_dec : public llm_graph_context {
1360013618
//cb(cur, "kqv_out", il);
1360113619
}
1360213620

13603-
if (il == n_layer - 1 && inp_out_ids) {
13621+
if (il == dec_n_layer - 1 && inp_out_ids) {
1360413622
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
1360513623
inpCA = ggml_get_rows(ctx0, inpCA, inp_out_ids);
1360613624
}
@@ -13621,8 +13639,8 @@ struct llm_build_t5_dec : public llm_graph_context {
1362113639
model.layers[il].ffn_gate, NULL, NULL,
1362213640
model.layers[il].ffn_down, NULL, NULL,
1362313641
NULL,
13624-
model.layers[il].ffn_gate_enc ? LLM_FFN_GELU : LLM_FFN_RELU,
13625-
model.layers[il].ffn_gate_enc ? LLM_FFN_PAR : LLM_FFN_SEQ,
13642+
model.layers[il].ffn_gate ? LLM_FFN_GELU : LLM_FFN_RELU,
13643+
model.layers[il].ffn_gate ? LLM_FFN_PAR : LLM_FFN_SEQ,
1362613644
il);
1362713645
cb(cur, "ffn_out", il);
1362813646
}

0 commit comments

Comments
 (0)