feat: Support exact size config for BatchCoalescer #8112

Open · wants to merge 5 commits into main

248 changes: 220 additions & 28 deletions arrow-select/src/coalesce.rs
@@ -93,6 +93,25 @@ use primitive::InProgressPrimitiveArray;
/// assert!(coalescer.next_completed_batch().is_none());
/// ```
///
/// Non-strict example (`with_exact_size(false)`):
/// ```
/// # use arrow_array::record_batch;
/// # use arrow_select::coalesce::BatchCoalescer;
/// let batch1 = record_batch!(("a", Int32, [1, 2, 3])).unwrap();
/// let batch2 = record_batch!(("a", Int32, [4, 5])).unwrap();
///
/// // Non-strict: produce batch once buffered >= target, batch may be larger than target
/// let mut coalescer = BatchCoalescer::new(batch1.schema(), 4).with_exact_size(false);
/// coalescer.push_batch(batch1).unwrap();
/// // still < 4 rows buffered
/// assert!(coalescer.next_completed_batch().is_none());
/// coalescer.push_batch(batch2).unwrap();
/// // now buffered >= 4, non-strict mode emits whole buffered set (5 rows)
/// let finished = coalescer.next_completed_batch().unwrap();
/// let expected = record_batch!(("a", Int32, [1, 2, 3, 4, 5])).unwrap();
/// assert_eq!(finished, expected);
/// ```
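///
/// In non-strict mode, rows left below the target stay buffered until
/// [`BatchCoalescer::finish_buffered_batch`] is called explicitly. A minimal
/// sketch of that flush, using only the API shown above:
/// ```
/// # use arrow_array::record_batch;
/// # use arrow_select::coalesce::BatchCoalescer;
/// let batch = record_batch!(("a", Int32, [1, 2, 3])).unwrap();
/// let mut coalescer = BatchCoalescer::new(batch.schema(), 4).with_exact_size(false);
/// coalescer.push_batch(batch).unwrap();
/// // fewer than 4 rows buffered: nothing completed yet
/// assert!(coalescer.next_completed_batch().is_none());
/// // explicitly flush the remaining buffered rows
/// coalescer.finish_buffered_batch().unwrap();
/// assert_eq!(coalescer.next_completed_batch().unwrap().num_rows(), 3);
/// ```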
///
/// # Background
///
/// Generally speaking, larger [`RecordBatch`]es are more efficient to process
@@ -128,6 +147,14 @@ use primitive::InProgressPrimitiveArray;
///
/// 2. The output is a sequence of batches, with all but the last having exactly
///    `target_batch_size` rows.
///
/// Notes on `exact_size` (see the sketch after this list):
///
/// - `exact_size == true` (strict, the default): output batches are produced so
///   that all but the final batch have exactly `target_batch_size` rows.
/// - `exact_size == false` (non-strict): an output batch is produced once the
///   buffered rows reach `target_batch_size`; the produced batch may be larger
///   than `target_batch_size` (i.e., size >= target).
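///
/// For contrast, a minimal strict-mode sketch (the default, so no builder call
/// is needed), assuming the same API as the examples above:
/// ```
/// # use arrow_array::record_batch;
/// # use arrow_select::coalesce::BatchCoalescer;
/// let batch = record_batch!(("a", Int32, [1, 2, 3, 4, 5])).unwrap();
/// let mut coalescer = BatchCoalescer::new(batch.schema(), 4);
/// coalescer.push_batch(batch).unwrap();
/// // strict mode emits exactly 4 rows; the 5th stays buffered
/// assert_eq!(coalescer.next_completed_batch().unwrap().num_rows(), 4);
/// assert!(coalescer.next_completed_batch().is_none());
/// ```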
#[derive(Debug)]
pub struct BatchCoalescer {
/// The input schema
@@ -142,6 +169,8 @@ pub struct BatchCoalescer {
buffered_rows: usize,
/// Completed batches
completed: VecDeque<RecordBatch>,
/// Whether the output batches are guaranteed to be exactly `target_batch_size`
exact_size: bool,
}

impl BatchCoalescer {
@@ -166,9 +195,24 @@ impl BatchCoalescer {
// We will always store at least one completed batch
completed: VecDeque::with_capacity(1),
buffered_rows: 0,
exact_size: true,
}
}

/// Controls whether output batches are produced with exactly `target_batch_size` rows.
///
/// When `exact_size == true` the coalescer will produce batches of exactly
/// `target_batch_size` rows (except possibly the final batch). This is the
/// historical behavior.
///
/// When `exact_size == false` the coalescer will produce a batch once the
/// buffered rows >= `target_batch_size`, but the produced batch may be larger
/// than `target_batch_size` (i.e. batch size >= target).
pub fn with_exact_size(mut self, exact: bool) -> Self {
self.exact_size = exact;
self
}

/// Return the schema of the output batches
pub fn schema(&self) -> SchemaRef {
Arc::clone(&self.schema)
@@ -241,48 +285,56 @@ impl BatchCoalescer {
return Ok(());
}

-        // setup input rows
+        // set sources
         assert_eq!(arrays.len(), self.in_progress_arrays.len());
         self.in_progress_arrays
             .iter_mut()
             .zip(arrays)
-            .for_each(|(in_progress, array)| {
-                in_progress.set_source(Some(array));
-            });
+            .for_each(|(in_progress, array)| in_progress.set_source(Some(array)));

-        // If pushing this batch would exceed the target batch size,
-        // finish the current batch and start a new one
         let mut offset = 0;
-        while num_rows > (self.target_batch_size - self.buffered_rows) {
-            let remaining_rows = self.target_batch_size - self.buffered_rows;
-            debug_assert!(remaining_rows > 0);
-
-            // Copy remaining_rows from each array
-            for in_progress in self.in_progress_arrays.iter_mut() {
-                in_progress.copy_rows(offset, remaining_rows)?;
-            }
-
-            self.buffered_rows += remaining_rows;
-            offset += remaining_rows;
-            num_rows -= remaining_rows;
-
-            self.finish_buffered_batch()?;
-        }
-
-        // Add any the remaining rows to the buffer
-        self.buffered_rows += num_rows;
-        if num_rows > 0 {
-            for in_progress in self.in_progress_arrays.iter_mut() {
-                in_progress.copy_rows(offset, num_rows)?;
-            }
-        }
-
-        // If we have reached the target batch size, finalize the buffered batch
-        if self.buffered_rows >= self.target_batch_size {
-            self.finish_buffered_batch()?;
-        }
+        if self.exact_size {
+            // Strict: produce exactly target-sized batches (except last)
+            while num_rows > (self.target_batch_size - self.buffered_rows) {
+                let remaining_rows = self.target_batch_size - self.buffered_rows;
+                debug_assert!(remaining_rows > 0);
+                for in_progress in self.in_progress_arrays.iter_mut() {
+                    in_progress.copy_rows(offset, remaining_rows)?;
+                }
+                self.buffered_rows += remaining_rows;
+                offset += remaining_rows;
+                num_rows -= remaining_rows;
+                self.finish_buffered_batch()?;
+            }
+
+            if num_rows > 0 {
+                for in_progress in self.in_progress_arrays.iter_mut() {
+                    in_progress.copy_rows(offset, num_rows)?;
+                }
+                self.buffered_rows += num_rows;
+            }
+
+            // ensure strict invariant: only finish when exactly full
+            if self.buffered_rows >= self.target_batch_size {
+                self.finish_buffered_batch()?;
+            }
+        } else {
+            // Non-strict: append all remaining rows; if buffered >= target, emit them
+            if num_rows > 0 {
+                for in_progress in self.in_progress_arrays.iter_mut() {
+                    in_progress.copy_rows(offset, num_rows)?;
+                }
+                self.buffered_rows += num_rows;
+            }
+
+            // If we've reached or exceeded target, emit the whole buffered set
alamb (Contributor) commented on the line above:

If we go over the target size, I think it means the underlying storage will reallocate (and thus copy the data).

I think the more performant way to do this is: if adding `num_rows` to the output would go over `target_rows`, emit early (even though some of the allocated space is not yet used).

zhuqi-lucas (Contributor, Author) replied on Aug 13, 2025:

Hi @alamb, thank you for the review and the good suggestion!

I tried the patch below to do this on top of the current PR, but performance does not get better; the best benchmark performance still comes from emitting exact-size batches. 🤔

diff --git a/arrow-select/src/coalesce.rs b/arrow-select/src/coalesce.rs
index be2bbcafb6..93bdf86a2d 100644
--- a/arrow-select/src/coalesce.rs
+++ b/arrow-select/src/coalesce.rs
@@ -100,16 +100,23 @@ use primitive::InProgressPrimitiveArray;
 /// let batch1 = record_batch!(("a", Int32, [1, 2, 3])).unwrap();
 /// let batch2 = record_batch!(("a", Int32, [4, 5])).unwrap();
 ///
-/// // Non-strict: produce batch once buffered >= target, batch may be larger than target
+/// // Non-strict: optimized for memory efficiency, may emit early to avoid reallocation
 /// let mut coalescer = BatchCoalescer::new(batch1.schema(), 4).with_exact_size(false);
 /// coalescer.push_batch(batch1).unwrap();
 /// // still < 4 rows buffered
 /// assert!(coalescer.next_completed_batch().is_none());
 /// coalescer.push_batch(batch2).unwrap();
-/// // now buffered >= 4, non-strict mode emits whole buffered set (5 rows)
+/// // buffered=3, new=2, total would be 5 > 4, so emit buffered 3 rows first
 /// let finished = coalescer.next_completed_batch().unwrap();
-/// let expected = record_batch!(("a", Int32, [1, 2, 3, 4, 5])).unwrap();
+/// let expected = record_batch!(("a", Int32, [1, 2, 3])).unwrap();
 /// assert_eq!(finished, expected);
+///
+/// // The remaining 2 rows from batch2 are now buffered
+/// assert!(coalescer.next_completed_batch().is_none());
+/// coalescer.finish_buffered_batch().unwrap();
+/// let remaining = coalescer.next_completed_batch().unwrap();
+/// let expected = record_batch!(("a", Int32, [4, 5])).unwrap();
+/// assert_eq!(remaining, expected);
 /// ```
 ///
 /// # Background
@@ -145,16 +152,21 @@ use primitive::InProgressPrimitiveArray;
 ///
 /// 1. Output rows are produced in the same order as the input rows
 ///
-/// 2. The output is a sequence of batches, with all but the last being at exactly
-///    `target_batch_size` rows.
+/// 2. The output batch sizes depend on the `exact_size` setting:
+///    - In strict mode: all but the last batch have exactly `target_batch_size` rows
+///    - In non-strict mode: batch sizes are optimized to avoid memory reallocation
 ///
 /// Notes on `exact_size`:
 ///
 /// - `exact_size == true` (strict): output batches are produced so that all but
 ///   the final batch have exactly `target_batch_size` rows (default behavior).
-/// - `exact_size == false` (non-strict, default for this crate): output batches
-///   will be produced when the buffered rows are >= `target_batch_size`. The
-///   produced batch may be larger than `target_batch_size` (i.e., size >= target).
+/// - `exact_size == false` (non-strict): output batches are optimized for memory
+///   efficiency. Batches are emitted early to avoid buffer reallocation when adding
+///   new data would exceed the target size. Large input batches are split into
+///   target-sized chunks to prevent excessive memory allocation. This may result in
+///   output batches that are smaller than `target_batch_size`, but the algorithm
+///   ensures batches are as close to the target size as possible while maintaining
+///   memory efficiency. Small batches only occur to avoid costly memory operations.
 #[derive(Debug)]
 pub struct BatchCoalescer {
     /// The input schema
@@ -320,7 +332,29 @@ impl BatchCoalescer {
                 self.finish_buffered_batch()?;
             }
         } else {
-            // Non-strict: append all remaining rows; if buffered >= target, emit them
+            // Non-strict: emit early if adding num_rows would exceed target to avoid reallocation
+            if self.buffered_rows > 0 && self.buffered_rows + num_rows > self.target_batch_size {
+                // Emit the current buffered data before processing the new batch
+                // This avoids potential reallocation in the underlying storage
+                self.finish_buffered_batch()?;
+            }
+
+            // If num_rows is larger than target_batch_size, split it into target-sized chunks
+            // to avoid allocating overly large buffers
+            while num_rows > self.target_batch_size {
+                let chunk_size = self.target_batch_size;
+                for in_progress in self.in_progress_arrays.iter_mut() {
+                    in_progress.copy_rows(offset, chunk_size)?;
+                }
+                self.buffered_rows += chunk_size;
+                offset += chunk_size;
+                num_rows -= chunk_size;
+
+                // Emit this full chunk immediately
+                self.finish_buffered_batch()?;
+            }
+
+            // Now append remaining rows (guaranteed to be <= target_batch_size) to buffer
             if num_rows > 0 {
                 for in_progress in self.in_progress_arrays.iter_mut() {
                     in_progress.copy_rows(offset, num_rows)?;
@@ -328,7 +362,7 @@ impl BatchCoalescer {
                 self.buffered_rows += num_rows;
             }
 
-            // If we've reached or exceeded target, emit the whole buffered set
+            // If the current buffer has reached or exceeded target, emit it
             if self.buffered_rows >= self.target_batch_size {
                 self.finish_buffered_batch()?;
             }
@@ -1381,38 +1415,62 @@ mod tests {
         coalescer.push_batch(batch1).unwrap();
         assert!(coalescer.next_completed_batch().is_none());
 
-        // push second batch (2 rows) -> buffered becomes 5 >= 4, non-strict emits all 5 rows
+        // push second batch (2 rows) -> buffered=3, new=2, 3+2=5 > 4
+        // NEW BEHAVIOR: emit buffered 3 rows first to avoid reallocation
         coalescer.push_batch(batch2).unwrap();
-        let out = coalescer
+        let out1 = coalescer
             .next_completed_batch()
-            .expect("expected a completed batch");
-        assert_eq!(out.num_rows(), 5);
-
-        // check contents equal to concatenation of 0..5
-        let expected = uint32_batch(0..5);
-        let actual = normalize_batch(out);
-        let expected = normalize_batch(expected);
-        assert_eq!(expected, actual);
+            .expect("expected first batch");
+        assert_eq!(out1.num_rows(), 3); // Only the first batch (early emit)
+
+        // The second batch should be buffered now
+        assert!(coalescer.next_completed_batch().is_none());
+
+        // Finish to get the remaining buffered data
+        coalescer.finish_buffered_batch().unwrap();
+        let out2 = coalescer
+            .next_completed_batch()
+            .expect("expected second batch");
+        assert_eq!(out2.num_rows(), 2); // The second batch
+
+        // check contents
+        let expected1 = uint32_batch(0..3);
+        let expected2 = uint32_batch(3..5);
+        assert_eq!(normalize_batch(out1), normalize_batch(expected1));
+        assert_eq!(normalize_batch(out2), normalize_batch(expected2));
     }
 
     #[test]
     fn test_non_strict_single_large_batch() {
-        // one large batch > target: in non-strict mode whole batch should be emitted
+        // one large batch > target: should be split into target-sized chunks
         let batch = uint32_batch(0..4096);
         let schema = Arc::clone(&batch.schema());
         let mut coalescer = BatchCoalescer::new(Arc::clone(&schema), 1000).with_exact_size(false);
 
         coalescer.push_batch(batch).unwrap();
-        let out = coalescer
-            .next_completed_batch()
-            .expect("expected a completed batch");
-        assert_eq!(out.num_rows(), 4096);
-
-        // compare to expected
-        let expected = uint32_batch(0..4096);
-        let actual = normalize_batch(out);
-        let expected = normalize_batch(expected);
-        assert_eq!(expected, actual);
+
+        // NEW BEHAVIOR: large batch should be split into chunks of target_batch_size
+        // 4096 / 1000 = 4 full batches + 96 remainder
+        let mut outputs = vec![];
+        while let Some(b) = coalescer.next_completed_batch() {
+            outputs.push(b);
+        }
+
+        assert_eq!(outputs.len(), 4); // 4 full batches emitted immediately
+
+        // Each should be exactly 1000 rows
+        for (i, out) in outputs.iter().enumerate() {
+            assert_eq!(out.num_rows(), 1000);
+            let expected = uint32_batch((i * 1000) as u32..((i + 1) * 1000) as u32);
+            assert_eq!(normalize_batch(out.clone()), normalize_batch(expected));
+        }
+
+        // Remaining 96 rows should be buffered
+        coalescer.finish_buffered_batch().unwrap();
+        let final_batch = coalescer.next_completed_batch().expect("expected final batch");
+        assert_eq!(final_batch.num_rows(), 96);
+        let expected_final = uint32_batch(4000..4096);
+        assert_eq!(normalize_batch(final_batch), normalize_batch(expected_final));
     }
 
     #[test]
@@ -1439,71 +1497,104 @@ mod tests {
 
     #[test]
     fn test_non_strict_multiple_emits_over_time() {
-        // multiple pushes that each eventually push buffered >= target and emit
+        // multiple pushes with early emit behavior
         let b1 = uint32_batch(0..3); // 3
-        let b2 = uint32_batch(3..5); // 2 -> 3+2=5 emit (first)
-        let b3 = uint32_batch(5..8); // 3
-        let b4 = uint32_batch(8..10); // 2 -> 3+2=5 emit (second)
+        let b2 = uint32_batch(3..5); // 2 -> 3+2=5 > 4, emit 3 first
+        let b3 = uint32_batch(5..8); // 3 -> 2+3=5 > 4, emit 2 first
+        let b4 = uint32_batch(8..10); // 2 -> 3+2=5 > 4, emit 3 first
 
         let schema = Arc::clone(&b1.schema());
         let mut coalescer = BatchCoalescer::new(Arc::clone(&schema), 4).with_exact_size(false);
 
+        // Push first batch (3 rows) -> buffered
         coalescer.push_batch(b1).unwrap();
         assert!(coalescer.next_completed_batch().is_none());
 
+        // Push second batch (2 rows) -> 3+2=5 > 4, emit buffered 3 rows first
         coalescer.push_batch(b2).unwrap();
         let out1 = coalescer
             .next_completed_batch()
             .expect("expected first batch");
-        assert_eq!(out1.num_rows(), 5);
-        assert_eq!(normalize_batch(out1), normalize_batch(uint32_batch(0..5)));
+        assert_eq!(out1.num_rows(), 3);
+        assert_eq!(normalize_batch(out1), normalize_batch(uint32_batch(0..3)));
 
+        // Now 2 rows from b2 are buffered, push b3 (3 rows) -> 2+3=5 > 4, emit 2 rows first
         coalescer.push_batch(b3).unwrap();
-        assert!(coalescer.next_completed_batch().is_none());
-
-        coalescer.push_batch(b4).unwrap();
         let out2 = coalescer
             .next_completed_batch()
             .expect("expected second batch");
-        assert_eq!(out2.num_rows(), 5);
-        assert_eq!(normalize_batch(out2), normalize_batch(uint32_batch(5..10)));
+        assert_eq!(out2.num_rows(), 2);
+        assert_eq!(normalize_batch(out2), normalize_batch(uint32_batch(3..5)));
+
+        // Now 3 rows from b3 are buffered, push b4 (2 rows) -> 3+2=5 > 4, emit 3 rows first
+        coalescer.push_batch(b4).unwrap();
+        let out3 = coalescer
+            .next_completed_batch()
+            .expect("expected third batch");
+        assert_eq!(out3.num_rows(), 3);
+        assert_eq!(normalize_batch(out3), normalize_batch(uint32_batch(5..8)));
+
+        // Finish to get remaining 2 rows from b4
+        coalescer.finish_buffered_batch().unwrap();
+        let out4 = coalescer
+            .next_completed_batch()
+            .expect("expected fourth batch");
+        assert_eq!(out4.num_rows(), 2);
+        assert_eq!(normalize_batch(out4), normalize_batch(uint32_batch(8..10)));
     }
 
     #[test]
     fn test_non_strict_large_then_more_outputs() {
-        // first push a large batch (should produce one big output), then push more small ones to produce another
+        // first push a large batch (should be split), then push more small ones
         let big = uint32_batch(0..5000);
         let small1 = uint32_batch(5000..5002); // 2
-        let small2 = uint32_batch(5002..5005); // 3 -> 2+3=5 >=4 emit
+        let small2 = uint32_batch(5002..5005); // 3 -> 2+3=5 > 4, emit 2 first
 
         let schema = Arc::clone(&big.schema());
-        // Use small target (4) so that small1 + small2 will trigger an emit
         let mut coalescer = BatchCoalescer::new(Arc::clone(&schema), 4).with_exact_size(false);
 
-        // push big: non-strict mode should emit the whole big batch (5000 rows)
+        // push big: should be split into chunks of 4
+        // 5000 / 4 = 1250 full batches
         coalescer.push_batch(big).unwrap();
-        let out_big = coalescer
-            .next_completed_batch()
-            .expect("expected big batch");
-        assert_eq!(out_big.num_rows(), 5000);
-        assert_eq!(
-            normalize_batch(out_big),
-            normalize_batch(uint32_batch(0..5000))
-        );
 
-        // push small1 (2 rows) -> not enough yet
+        let mut big_outputs = vec![];
+        while let Some(b) = coalescer.next_completed_batch() {
+            big_outputs.push(b);
+        }
+
+        assert_eq!(big_outputs.len(), 1250); // 1250 batches of 4 rows each
+        for (i, out) in big_outputs.iter().enumerate() {
+            assert_eq!(out.num_rows(), 4);
+            let start = i * 4;
+            let end = (i + 1) * 4;
+            let expected = uint32_batch(start as u32..end as u32);
+            assert_eq!(normalize_batch(out.clone()), normalize_batch(expected));
+        }
+
+        // push small1 (2 rows) -> buffered
         coalescer.push_batch(small1).unwrap();
         assert!(coalescer.next_completed_batch().is_none());
 
-        // push small2 (3 rows) -> now buffered = 2 + 3 = 5 >= 4, non-strict emits all 5 rows
+        // push small2 (3 rows) -> 2+3=5 > 4, emit buffered 2 rows first
         coalescer.push_batch(small2).unwrap();
-        let out_small = coalescer
+        let out_small1 = coalescer
+            .next_completed_batch()
+            .expect("expected small batch 1");
+        assert_eq!(out_small1.num_rows(), 2);
+        assert_eq!(
+            normalize_batch(out_small1),
+            normalize_batch(uint32_batch(5000..5002))
+        );
+
+        // Finish to get remaining 3 rows from small2
+        coalescer.finish_buffered_batch().unwrap();
+        let out_small2 = coalescer
             .next_completed_batch()
-            .expect("expected small batch");
-        assert_eq!(out_small.num_rows(), 5);
+            .expect("expected small batch 2");
+        assert_eq!(out_small2.num_rows(), 3);
         assert_eq!(
-            normalize_batch(out_small),
-            normalize_batch(uint32_batch(5000..5005))
+            normalize_batch(out_small2),
+            normalize_batch(uint32_batch(5002..5005))
         );
     }
 }

+            if self.buffered_rows >= self.target_batch_size {
+                self.finish_buffered_batch()?;
+            }
+        }

-        // clear in progress sources (to allow the memory to be freed)
+        // clear sources
for in_progress in self.in_progress_arrays.iter_mut() {
in_progress.set_source(None);
}
@@ -1314,4 +1366,144 @@ mod tests {
let options = RecordBatchOptions::new().with_row_count(Some(row_count));
RecordBatch::try_new_with_options(schema, columns, &options).unwrap()
}

// Tests for the `exact_size = false` (non-strict) setting
#[test]
fn test_non_coalesce_small_batches() {
// two small batches -> combined when buffered >= target
let batch1 = uint32_batch(0..3); // 3 rows
let batch2 = uint32_batch(3..5); // 2 rows

let schema = Arc::clone(&batch1.schema());
let mut coalescer = BatchCoalescer::new(Arc::clone(&schema), 4).with_exact_size(false);

// push first batch (3 rows) -> not enough
coalescer.push_batch(batch1).unwrap();
assert!(coalescer.next_completed_batch().is_none());

// push second batch (2 rows) -> buffered becomes 5 >= 4, non-strict emits all 5 rows
coalescer.push_batch(batch2).unwrap();
let out = coalescer
.next_completed_batch()
.expect("expected a completed batch");
assert_eq!(out.num_rows(), 5);

// check contents equal to concatenation of 0..5
let expected = uint32_batch(0..5);
let actual = normalize_batch(out);
let expected = normalize_batch(expected);
assert_eq!(expected, actual);
}

#[test]
fn test_non_strict_single_large_batch() {
// one large batch > target: in non-strict mode whole batch should be emitted
let batch = uint32_batch(0..4096);
let schema = Arc::clone(&batch.schema());
let mut coalescer = BatchCoalescer::new(Arc::clone(&schema), 1000).with_exact_size(false);

coalescer.push_batch(batch).unwrap();
let out = coalescer
.next_completed_batch()
.expect("expected a completed batch");
assert_eq!(out.num_rows(), 4096);

// compare to expected
let expected = uint32_batch(0..4096);
let actual = normalize_batch(out);
let expected = normalize_batch(expected);
assert_eq!(expected, actual);
}

#[test]
fn test_strict_single_large_batch_multiple_outputs() {
// single large batch -> split into multiple exact target batches
let batch = uint32_batch(0..5000);
let schema = Arc::clone(&batch.schema());
let mut coalescer = BatchCoalescer::new(Arc::clone(&schema), 1000).with_exact_size(true);

coalescer.push_batch(batch).unwrap();

// should emit 5 batches of 1000 each
let mut outputs = vec![];
while let Some(b) = coalescer.next_completed_batch() {
outputs.push(b);
}
assert_eq!(outputs.len(), 5);
for (i, out) in outputs.into_iter().enumerate() {
assert_eq!(out.num_rows(), 1000);
let expected = uint32_batch((i * 1000) as u32..((i + 1) * 1000) as u32);
assert_eq!(normalize_batch(out), normalize_batch(expected));
}
}

#[test]
fn test_non_strict_multiple_emits_over_time() {
// multiple pushes that each eventually push buffered >= target and emit
let b1 = uint32_batch(0..3); // 3
let b2 = uint32_batch(3..5); // 2 -> 3+2=5 emit (first)
let b3 = uint32_batch(5..8); // 3
let b4 = uint32_batch(8..10); // 2 -> 3+2=5 emit (second)

let schema = Arc::clone(&b1.schema());
let mut coalescer = BatchCoalescer::new(Arc::clone(&schema), 4).with_exact_size(false);

coalescer.push_batch(b1).unwrap();
assert!(coalescer.next_completed_batch().is_none());

coalescer.push_batch(b2).unwrap();
let out1 = coalescer
.next_completed_batch()
.expect("expected first batch");
assert_eq!(out1.num_rows(), 5);
assert_eq!(normalize_batch(out1), normalize_batch(uint32_batch(0..5)));

coalescer.push_batch(b3).unwrap();
assert!(coalescer.next_completed_batch().is_none());

coalescer.push_batch(b4).unwrap();
let out2 = coalescer
.next_completed_batch()
.expect("expected second batch");
assert_eq!(out2.num_rows(), 5);
assert_eq!(normalize_batch(out2), normalize_batch(uint32_batch(5..10)));
}

#[test]
fn test_non_strict_large_then_more_outputs() {
// first push a large batch (should produce one big output), then push more small ones to produce another
let big = uint32_batch(0..5000);
let small1 = uint32_batch(5000..5002); // 2
let small2 = uint32_batch(5002..5005); // 3 -> 2+3=5 >=4 emit

let schema = Arc::clone(&big.schema());
// Use small target (4) so that small1 + small2 will trigger an emit
let mut coalescer = BatchCoalescer::new(Arc::clone(&schema), 4).with_exact_size(false);

// push big: non-strict mode should emit the whole big batch (5000 rows)
coalescer.push_batch(big).unwrap();
let out_big = coalescer
.next_completed_batch()
.expect("expected big batch");
assert_eq!(out_big.num_rows(), 5000);
assert_eq!(
normalize_batch(out_big),
normalize_batch(uint32_batch(0..5000))
);

// push small1 (2 rows) -> not enough yet
coalescer.push_batch(small1).unwrap();
assert!(coalescer.next_completed_batch().is_none());

// push small2 (3 rows) -> now buffered = 2 + 3 = 5 >= 4, non-strict emits all 5 rows
coalescer.push_batch(small2).unwrap();
let out_small = coalescer
.next_completed_batch()
.expect("expected small batch");
assert_eq!(out_small.num_rows(), 5);
assert_eq!(
normalize_batch(out_small),
normalize_batch(uint32_batch(5000..5005))
);
}
}
2 changes: 1 addition & 1 deletion arrow/benches/coalesce_kernels.rs
@@ -232,7 +232,7 @@ fn filter_streams(
) {
let schema = data_stream.schema();
let batch_size = data_stream.batch_size();
-    let mut coalescer = BatchCoalescer::new(Arc::clone(schema), batch_size);
+    let mut coalescer = BatchCoalescer::new(Arc::clone(schema), batch_size).with_exact_size(false);

while num_output_batches > 0 {
let filter = filter_stream.next_filter();