diff --git a/src/wif/ml/lightGBMWrapper.cpp b/src/wif/ml/lightGBMWrapper.cpp index 0fd713f..f7c59ca 100644 --- a/src/wif/ml/lightGBMWrapper.cpp +++ b/src/wif/ml/lightGBMWrapper.cpp @@ -54,8 +54,8 @@ ClfResult LightGBMWrapper::classify(const FlowFeatures& flowFeatures) int64_t outLen; // length of output result int numOfClasses; // number of classes LGBM_BoosterGetNumClasses(m_booster, &numOfClasses); - std::vector pred(numOfClasses); // vector with predictions + for (const auto& featureID : m_featureIDs) { double value = flowFeatures.get(featureID); dataToClassify.push_back(value); @@ -74,6 +74,11 @@ ClfResult LightGBMWrapper::classify(const FlowFeatures& flowFeatures) &outLen, pred.data()); + if (numOfClasses == 1) { + double tmp = pred[0]; + pred.insert(pred.begin(), (1.0 - tmp)); + } + return ClfResult(pred); } @@ -114,6 +119,12 @@ std::vector LightGBMWrapper::classify(const std::vector std::vector probabilities( pred.begin() + idx * numOfClasses, pred.begin() + (idx + 1) * numOfClasses); + + if (numOfClasses == 1) { + double tmp = probabilities[0]; + probabilities.insert(probabilities.begin(), (1.0 - tmp)); + } + burstResults.emplace_back(probabilities); }