Skip to content

Commit 1caae7f

Browse files
authored
gguf-py : add add_classifier_output_labels method to writer (#14031)
* add add_classifier_output_labels * use add_classifier_output_labels
1 parent 669c13e commit 1caae7f

File tree

2 files changed

+4
-2
lines changed

2 files changed

+4
-2
lines changed

convert_hf_to_gguf.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3709,8 +3709,7 @@ def set_gguf_parameters(self):
37093709
self._try_set_pooling_type()
37103710

37113711
if self.cls_out_labels:
3712-
key_name = gguf.Keys.Classifier.OUTPUT_LABELS.format(arch = gguf.MODEL_ARCH_NAMES[self.model_arch])
3713-
self.gguf_writer.add_array(key_name, [v for k, v in sorted(self.cls_out_labels.items())])
3712+
self.gguf_writer.add_classifier_output_labels([v for k, v in sorted(self.cls_out_labels.items())])
37143713

37153714
def set_vocab(self):
37163715
tokens, toktypes, tokpre = self.get_vocab_base()

gguf-py/gguf/gguf_writer.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -935,6 +935,9 @@ def add_eot_token_id(self, id: int) -> None:
935935
def add_eom_token_id(self, id: int) -> None:
936936
self.add_uint32(Keys.Tokenizer.EOM_ID, id)
937937

938+
def add_classifier_output_labels(self, labels: Sequence[str]) -> None:
939+
self.add_array(Keys.Classifier.OUTPUT_LABELS.format(arch=self.arch), labels)
940+
938941
# for vision models
939942

940943
def add_clip_has_vision_encoder(self, value: bool) -> None:

0 commit comments

Comments
 (0)