Fix multi-label classification in run_classification.py (closes #43116) #43198
+149
−40
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Fixes #43116
Summary
This PR addresses multi-label classification bugs in run_classification.py and adds confidence scores output following community feedback.
Bug fixes
Fixed 4 bugs that broke multi-label classification with JSON datasets:
New features
Based on feedback from @ziorufus, added configurable threshold and confidence scores output.
New parameters:
--output_confidence_scores(bool, default: False) - Output JSON with confidence scores instead of binary predictions--multi_label_threshold(float, default: 0.5) - Threshold for converting probabilities to binary predictions--top_k_labels(int, optional) - Limit output to top K most confident labelsOutput format follows transformers Pipeline API convention:
Traditional mode (default):
Confidence scores mode:
[ { "index": 0, "predictions": [ {"label": "positive", "score": 0.89}, {"label": "urgent", "score": 0.67} ] } ]Backward compatibility
Default behavior unchanged. New features require explicit flags. No breaking changes.
Testing
Validated with:
Test scenarios:
Usage examples
Traditional mode:
Confidence scores:
Custom threshold:
Top-K labels:
Implementation notes
Changes