From 29c9347121075e0ece33c21d797a02db767bab43 Mon Sep 17 00:00:00 2001 From: Cyrix126 <58007246+Cyrix126@users.noreply.github.com> Date: Tue, 22 Jul 2025 12:39:27 +0200 Subject: [PATCH] feat: check overflow --- src/algorithms/bnb.rs | 31 ++++++++++++++++--------------- src/algorithms/fifo.rs | 19 ++++++++++--------- src/algorithms/knapsack.rs | 16 +++++++++------- src/algorithms/leastchange.rs | 18 ++++++++++-------- src/algorithms/lowestlarger.rs | 34 ++++++++++++++++++++-------------- src/algorithms/srd.rs | 22 ++++++++++++---------- src/types.rs | 1 + src/utils.rs | 6 ++++++ 8 files changed, 84 insertions(+), 63 deletions(-) diff --git a/src/algorithms/bnb.rs b/src/algorithms/bnb.rs index 91fa770..4809a65 100644 --- a/src/algorithms/bnb.rs +++ b/src/algorithms/bnb.rs @@ -1,6 +1,6 @@ use crate::{ types::{CoinSelectionOpt, OutputGroup, SelectionError, SelectionOutput, WasteMetric}, - utils::{calculate_fee, calculate_waste}, + utils::{calculate_fee, calculate_waste, sum}, }; /// Struct MatchParameters encapsulates target_for_match, match_range, and target_feerate, options, tries, best solution. @@ -28,10 +28,11 @@ pub fn select_coin_bnb( sorted_inputs.sort_by_key(|(_, input)| input.value); let mut ctx = BnbContext { - target_for_match: options.target_value - + options.min_change_value - + base_fee.max(options.min_absolute_fee), - match_range: cost_per_input + cost_per_output, + target_for_match: sum( + sum(options.target_value, options.min_change_value)?, + base_fee.max(options.min_absolute_fee), + )?, + match_range: sum(cost_per_input, cost_per_output)?, options: options.clone(), tries: 1_000_000, best_solution: None, @@ -39,7 +40,7 @@ pub fn select_coin_bnb( let mut selected_inputs = vec![]; - bnb(&sorted_inputs, &mut selected_inputs, 0, 0, 0, &mut ctx); + bnb(&sorted_inputs, &mut selected_inputs, 0, 0, 0, &mut ctx)?; match ctx.best_solution { Some((selected, waste)) => Ok(SelectionOutput { @@ -57,9 +58,9 @@ fn bnb( acc_weight: u64, depth: usize, ctx: &mut BnbContext, -) { +) -> Result<(), SelectionError> { if ctx.tries == 0 || depth >= sorted.len() { - return; + return Ok(()); } ctx.tries -= 1; @@ -72,8 +73,8 @@ fn bnb( let effective_value = acc_value.saturating_sub(fee); // Prune if we're way over target (including change consideration) - if effective_value > ctx.target_for_match + ctx.match_range { - return; + if effective_value > sum(ctx.target_for_match, ctx.match_range)? { + return Ok(()); } // Check for valid solution (must cover target + min change) @@ -82,7 +83,7 @@ fn bnb( if ctx.best_solution.is_none() || waste < ctx.best_solution.as_ref().unwrap().1 { ctx.best_solution = Some((selected.clone(), waste)); } - return; + return Ok(()); } let (index, input) = sorted[depth]; @@ -96,15 +97,15 @@ fn bnb( bnb( sorted, selected, - acc_value + input_effective_value, - acc_weight + input.weight, + sum(acc_value, input_effective_value)?, + sum(acc_weight, input.weight)?, depth + 1, ctx, - ); + )?; selected.pop(); // Branch 2: Exclude current input - bnb(sorted, selected, acc_value, acc_weight, depth + 1, ctx); + bnb(sorted, selected, acc_value, acc_weight, depth + 1, ctx) } #[cfg(test)] diff --git a/src/algorithms/fifo.rs b/src/algorithms/fifo.rs index 78efe14..80a2125 100644 --- a/src/algorithms/fifo.rs +++ b/src/algorithms/fifo.rs @@ -1,6 +1,6 @@ use crate::{ types::{CoinSelectionOpt, OutputGroup, SelectionError, SelectionOutput, WasteMetric}, - utils::{calculate_fee, calculate_waste}, + utils::{calculate_fee, calculate_waste, sum}, }; /// Performs coin selection using the First-In-First-Out (FIFO) algorithm. @@ -15,8 +15,10 @@ pub fn select_coin_fifo( let mut selected_inputs: Vec = Vec::new(); let mut estimated_fees: u64 = 0; let base_fees = calculate_fee(options.base_weight, options.target_feerate).unwrap_or_default(); - let target = - options.target_value + options.min_change_value + base_fees.max(options.min_absolute_fee); + let target = sum( + sum(options.target_value, options.min_change_value)?, + base_fees.max(options.min_absolute_fee), + )?; // Sorting the inputs vector based on creation_sequence let mut sorted_inputs: Vec<_> = inputs @@ -36,16 +38,15 @@ pub fn select_coin_fifo( sorted_inputs.extend(inputs_without_sequence); for (index, inputs) in sorted_inputs { - estimated_fees = - calculate_fee(accumulated_weight, options.target_feerate).unwrap_or_default(); - if accumulated_value >= target + estimated_fees { + estimated_fees = calculate_fee(accumulated_weight, options.target_feerate)?; + if accumulated_value >= sum(target, estimated_fees)? { break; } - accumulated_value += inputs.value; - accumulated_weight += inputs.weight; + accumulated_value = sum(accumulated_value, inputs.value)?; + accumulated_weight = sum(accumulated_weight, inputs.weight)?; selected_inputs.push(index); } - if accumulated_value < target + estimated_fees { + if accumulated_value < sum(target, estimated_fees)? { Err(SelectionError::InsufficientFunds) } else { let waste: f32 = calculate_waste( diff --git a/src/algorithms/knapsack.rs b/src/algorithms/knapsack.rs index 492e84a..a675674 100644 --- a/src/algorithms/knapsack.rs +++ b/src/algorithms/knapsack.rs @@ -2,7 +2,7 @@ use crate::{ types::{ CoinSelectionOpt, EffectiveValue, OutputGroup, SelectionError, SelectionOutput, WasteMetric, }, - utils::{calculate_fee, calculate_waste, effective_value}, + utils::{calculate_fee, calculate_waste, effective_value, sum}, }; use rand::{thread_rng, Rng}; use std::collections::HashSet; @@ -11,9 +11,11 @@ pub fn select_coin_knapsack( options: &CoinSelectionOpt, ) -> Result { // Calculate base fees with no inputs - let base_fees = calculate_fee(options.base_weight, options.target_feerate).unwrap_or_default(); - let adjusted_target = - options.target_value + options.min_change_value + base_fees.max(options.min_absolute_fee); + let base_fees = calculate_fee(options.base_weight, options.target_feerate)?; + let adjusted_target = sum( + sum(options.target_value, options.min_change_value)?, + base_fees.max(options.min_absolute_fee), + )?; let mut smaller_coins = inputs .iter() @@ -61,12 +63,12 @@ fn knap_sack( let toss_result: bool = rng.gen_bool(0.5); if (pass == 2 && !selected_inputs.contains(&index)) || (pass == 1 && toss_result) { selected_inputs.insert(index); - accumulated_value += value; - accumulated_weight += weight; + accumulated_value = sum(accumulated_value, value)?; + accumulated_weight = sum(accumulated_weight, weight)?; // Calculate current fees and required value let estimated_fees = calculate_fee(accumulated_weight, options.target_feerate)?; - let required_value = adjusted_target + estimated_fees; + let required_value = sum(adjusted_target, estimated_fees)?; if accumulated_value == required_value { let waste = calculate_waste( diff --git a/src/algorithms/leastchange.rs b/src/algorithms/leastchange.rs index 9ce2d5c..34915d8 100644 --- a/src/algorithms/leastchange.rs +++ b/src/algorithms/leastchange.rs @@ -2,7 +2,7 @@ use std::vec; use crate::{ types::{CoinSelectionOpt, OutputGroup, SelectionError, SelectionOutput, WasteMetric}, - utils::{calculate_fee, calculate_waste, effective_value}, + utils::{calculate_fee, calculate_waste, effective_value, sum}, }; /// A Branch and Bound state for Least Change selection which stores the state while traversing the tree. @@ -21,8 +21,10 @@ pub fn select_coin_bnb_leastchange( ) -> Result { let mut best: Option<(Vec, u64, usize)> = None; // (selection, change, count) let base_fees = calculate_fee(options.base_weight, options.target_feerate).unwrap_or_default(); - let target = - options.target_value + options.min_change_value + base_fees.max(options.min_absolute_fee); + let target = sum( + sum(options.target_value, options.min_change_value)?, + base_fees.max(options.min_absolute_fee), + )?; // Precompute net values and filter beneficial inputs let mut filtered = inputs @@ -43,7 +45,7 @@ pub fn select_coin_bnb_leastchange( let n = filtered.len(); let mut remaining_net = vec![0; n + 1]; for i in (0..n).rev() { - remaining_net[i] = remaining_net[i + 1] + filtered[i].1; + remaining_net[i] = sum(remaining_net[sum(i as u64, 1)? as usize], filtered[i].1)?; } // DFS with BnB pruning @@ -61,7 +63,7 @@ pub fn select_coin_bnb_leastchange( } // Prune if impossible to reach target - if state.current_eff_value + remaining_net[state.index] < target { + if sum(state.current_eff_value, remaining_net[state.index])? < target { continue; } @@ -74,15 +76,15 @@ pub fn select_coin_bnb_leastchange( }); let (orig_idx, net_value, weight) = filtered[state.index]; - let new_eff_value = state.current_eff_value + net_value; + let new_eff_value = sum(state.current_eff_value, net_value)?; let mut new_selection = state.current_selection.clone(); new_selection.push(orig_idx); let new_count = state.current_count + 1; - let new_weight = state.current_weight + weight; + let new_weight = sum(state.current_weight, weight)?; // Calculate fees based on current selection let estimated_fees = calculate_fee(new_weight, options.target_feerate).unwrap_or(0); - let required_value = target + estimated_fees; + let required_value = sum(target, estimated_fees)?; if new_eff_value >= required_value { let change = new_eff_value - required_value; let update = match best { diff --git a/src/algorithms/lowestlarger.rs b/src/algorithms/lowestlarger.rs index 396207d..8aaba15 100644 --- a/src/algorithms/lowestlarger.rs +++ b/src/algorithms/lowestlarger.rs @@ -1,6 +1,6 @@ use crate::{ types::{CoinSelectionOpt, OutputGroup, SelectionError, SelectionOutput, WasteMetric}, - utils::{calculate_fee, calculate_waste, effective_value}, + utils::{calculate_fee, calculate_waste, effective_value, sum}, }; /// Performs coin selection using the Lowest Larger algorithm. @@ -14,43 +14,49 @@ pub fn select_coin_lowestlarger( let mut accumulated_weight: u64 = 0; let mut selected_inputs: Vec = Vec::new(); let mut estimated_fees: u64 = 0; - let base_fees = calculate_fee(options.base_weight, options.target_feerate).unwrap_or_default(); - let target = - options.target_value + options.min_change_value + base_fees.max(options.min_absolute_fee); + let base_fees = calculate_fee(options.base_weight, options.target_feerate)?; + let target = sum( + sum(options.target_value, options.min_change_value)?, + base_fees.max(options.min_absolute_fee), + )?; let mut sorted_inputs: Vec<_> = inputs.iter().enumerate().collect(); sorted_inputs.sort_by_key(|(_, input)| effective_value(input, options.target_feerate)); let index = sorted_inputs.partition_point(|(_, input)| { - input.value - <= (target + calculate_fee(input.weight, options.target_feerate).unwrap_or_default()) + if let Ok(fee) = calculate_fee(input.weight, options.target_feerate) { + if let Ok(target_and_fee) = sum(target, fee) { + return input.value <= target_and_fee; + } + } + false }); for (idx, input) in sorted_inputs.iter().take(index).rev() { - accumulated_value += input.value; - accumulated_weight += input.weight; + accumulated_value = sum(accumulated_value, input.value)?; + accumulated_weight = sum(accumulated_weight, input.weight)?; estimated_fees = calculate_fee(accumulated_weight, options.target_feerate)?; selected_inputs.push(*idx); - if accumulated_value >= (target + estimated_fees) { + if accumulated_value >= sum(target, estimated_fees)? { break; } } - if accumulated_value < (target + estimated_fees) { + if accumulated_value < sum(target, estimated_fees)? { for (idx, input) in sorted_inputs.iter().skip(index) { - accumulated_value += input.value; - accumulated_weight += input.weight; + accumulated_value = sum(accumulated_value, input.value)?; + accumulated_weight = sum(accumulated_weight, input.weight)?; estimated_fees = calculate_fee(accumulated_weight, options.target_feerate)?; selected_inputs.push(*idx); - if accumulated_value >= (target + estimated_fees.max(options.min_absolute_fee)) { + if accumulated_value >= sum(target, estimated_fees.max(options.min_absolute_fee))? { break; } } } - if accumulated_value < (target + estimated_fees) { + if accumulated_value < sum(target, estimated_fees)? { Err(SelectionError::InsufficientFunds) } else { let waste: f32 = calculate_waste( diff --git a/src/algorithms/srd.rs b/src/algorithms/srd.rs index 6c4c796..d8dc07f 100644 --- a/src/algorithms/srd.rs +++ b/src/algorithms/srd.rs @@ -1,6 +1,6 @@ use crate::{ types::{CoinSelectionOpt, OutputGroup, SelectionError, SelectionOutput, WasteMetric}, - utils::{calculate_fee, calculate_waste}, + utils::{calculate_fee, calculate_waste, sum}, }; use rand::{seq::SliceRandom, thread_rng}; @@ -15,33 +15,35 @@ pub fn select_coin_srd( // So keep track of the indexes when randomiz ing the vec let mut randomized_inputs: Vec<_> = inputs.iter().enumerate().collect(); let base_fees = calculate_fee(options.base_weight, options.target_feerate).unwrap_or_default(); - let target = - options.target_value + options.min_change_value + base_fees.max(options.min_absolute_fee); + let target = sum( + sum(options.target_value, options.min_change_value)?, + base_fees.max(options.min_absolute_fee), + )?; // Randomize the inputs order to simulate the random draw let mut rng = thread_rng(); randomized_inputs.shuffle(&mut rng); - let mut accumulated_value = 0; + let mut accumulated_value: u64 = 0; let mut selected_inputs = Vec::new(); - let mut accumulated_weight = 0; + let mut accumulated_weight: u64 = 0; let mut estimated_fee = 0; let mut _input_counts = 0; for (index, input) in randomized_inputs { selected_inputs.push(index); - accumulated_value += input.value; - accumulated_weight += input.weight; - _input_counts += input.input_count; + accumulated_value = sum(accumulated_value, input.value)?; + accumulated_weight = sum(accumulated_weight, input.weight)?; + _input_counts = sum(_input_counts, input.input_count as u64)?; estimated_fee = calculate_fee(accumulated_weight, options.target_feerate)?; - if accumulated_value >= target + estimated_fee { + if accumulated_value >= sum(target, estimated_fee)? { break; } } - if accumulated_value < target + estimated_fee { + if accumulated_value < sum(target, estimated_fee)? { return Err(SelectionError::InsufficientFunds); } let waste = calculate_waste( diff --git a/src/types.rs b/src/types.rs index edbc449..ab9cb68 100644 --- a/src/types.rs +++ b/src/types.rs @@ -92,6 +92,7 @@ pub enum SelectionError { NonPositiveTarget, NonPositiveFeeRate, AbnormallyHighFeeRate, + AbnormallyHighAmount, } /// Measures the efficiency of input selection in satoshis, helping evaluate algorithms based on current and long-term fee rates diff --git a/src/utils.rs b/src/utils.rs index 1b9ee77..d219b2b 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -43,6 +43,11 @@ pub fn calculate_accumulated_weight( } accumulated_weight } +/// sugar to return a SelectionError when overflowing +pub fn sum(a: u64, b: u64) -> Result { + a.checked_add(b) + .ok_or_else(|| SelectionError::AbnormallyHighAmount) +} impl fmt::Display for SelectionError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { @@ -52,6 +57,7 @@ impl fmt::Display for SelectionError { SelectionError::AbnormallyHighFeeRate => write!(f, "Abnormally high fee rate"), SelectionError::InsufficientFunds => write!(f, "The Inputs funds are insufficient"), SelectionError::NoSolutionFound => write!(f, "No solution could be derived"), + SelectionError::AbnormallyHighAmount => write!(f, "Abnormally high amount"), } } }