From 0ee7d9ee8886187eacb0a2d96c6efcc22c8a34cb Mon Sep 17 00:00:00 2001 From: Nhat Nguyen Date: Thu, 3 Jul 2025 15:33:40 -0700 Subject: [PATCH 1/8] Avoid O(N^2) in VALUES with OrdinalGrouping --- .../operator/ValuesAggregatorBenchmark.java | 34 ++- .../aggregation/ValuesBytesRefAggregator.java | 119 +++++++-- .../aggregation/ValuesDoubleAggregator.java | 82 +++++-- .../aggregation/ValuesFloatAggregator.java | 94 ++++++-- .../aggregation/ValuesIntAggregator.java | 94 ++++++-- .../aggregation/ValuesLongAggregator.java | 82 +++++-- ...uesBytesRefGroupingAggregatorFunction.java | 2 +- ...aluesDoubleGroupingAggregatorFunction.java | 2 +- ...ValuesFloatGroupingAggregatorFunction.java | 2 +- .../ValuesIntGroupingAggregatorFunction.java | 2 +- .../ValuesLongGroupingAggregatorFunction.java | 2 +- .../ValuesBytesRefAggregators.java | 12 +- .../aggregation/X-ValuesAggregator.java.st | 228 +++++++++++------- .../operator/HashAggregationOperator.java | 2 +- 14 files changed, 533 insertions(+), 224 deletions(-) diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/ValuesAggregatorBenchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/ValuesAggregatorBenchmark.java index 238540bf2c799..c082079362cc4 100644 --- a/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/ValuesAggregatorBenchmark.java +++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/ValuesAggregatorBenchmark.java @@ -65,7 +65,7 @@ @Fork(1) public class ValuesAggregatorBenchmark { static final int MIN_BLOCK_LENGTH = 8 * 1024; - private static final int OP_COUNT = 1024; + private static final int OP_COUNT = 20; private static final int UNIQUE_VALUES = 6; private static final BytesRef[] KEYWORDS = new BytesRef[] { new BytesRef("Tokyo"), @@ -95,7 +95,8 @@ static void selfTest() { try { for (String groups : ValuesAggregatorBenchmark.class.getField("groups").getAnnotationsByType(Param.class)[0].value()) { for (String dataType : ValuesAggregatorBenchmark.class.getField("dataType").getAnnotationsByType(Param.class)[0].value()) { - run(Integer.parseInt(groups), dataType, 10); + run(Integer.parseInt(groups), dataType, 10, 0); + run(Integer.parseInt(groups), dataType, 10, 1); } } } catch (NoSuchFieldException e) { @@ -113,7 +114,10 @@ static void selfTest() { @Param({ BYTES_REF, INT, LONG }) public String dataType; - private static Operator operator(DriverContext driverContext, int groups, String dataType) { + @Param({ "0", "1"}) + public int numOrdinalMerges; + + private static Operator operator(DriverContext driverContext, int groups, String dataType, int numOrdinalMerges) { if (groups == 1) { return new AggregationOperator( List.of(supplier(dataType).aggregatorFactory(AggregatorMode.SINGLE, List.of(0)).apply(driverContext)), @@ -125,7 +129,23 @@ private static Operator operator(DriverContext driverContext, int groups, String List.of(supplier(dataType).groupingAggregatorFactory(AggregatorMode.SINGLE, List.of(1))), () -> BlockHash.build(groupSpec, driverContext.blockFactory(), 16 * 1024, false), driverContext - ); + ) { + @Override + public Page getOutput() { + mergeOrdinal(); + return super.getOutput(); + } + // simulate OrdinalsGroupingOperator + void mergeOrdinal() { + var merged = supplier(dataType).groupingAggregatorFactory(AggregatorMode.SINGLE, List.of(1)).apply(driverContext); + for (int i = 0; i < numOrdinalMerges; i++) { + for (int p = 0; p < groups; p++) { + merged.addIntermediateRow(p, aggregators.getFirst(), p); + } + } + aggregators.set(0, merged); + } + }; } private static AggregatorFunctionSupplier supplier(String dataType) { @@ -331,12 +351,12 @@ private static Block groupingBlock(int groups) { @Benchmark public void run() { - run(groups, dataType, OP_COUNT); + run(groups, dataType, OP_COUNT, numOrdinalMerges); } - private static void run(int groups, String dataType, int opCount) { + private static void run(int groups, String dataType, int opCount, int numOrdinalMerges) { DriverContext driverContext = driverContext(); - try (Operator operator = operator(driverContext, groups, dataType)) { + try (Operator operator = operator(driverContext, groups, dataType, numOrdinalMerges)) { Page page = page(groups, dataType); for (int i = 0; i < opCount; i++) { operator.addInput(page.shallowCopy()); diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregator.java index 51195578ac363..c7a21ef1a3075 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregator.java @@ -24,6 +24,7 @@ import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.data.OrdinalBytesRefBlock; import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.core.Releasable; import org.elasticsearch.core.Releasables; /** @@ -55,8 +56,8 @@ public static Block evaluateFinal(SingleState state, DriverContext driverContext return state.toBlock(driverContext.blockFactory()); } - public static GroupingState initGrouping(BigArrays bigArrays) { - return new GroupingState(bigArrays); + public static GroupingState initGrouping(DriverContext driverContext) { + return new GroupingState(driverContext); } public static GroupingAggregatorFunction.AddInput wrapAddInput( @@ -76,7 +77,7 @@ public static GroupingAggregatorFunction.AddInput wrapAddInput( } public static void combine(GroupingState state, int groupId, BytesRef v) { - state.values.add(groupId, BlockHash.hashOrdToGroup(state.bytes.add(v))); + state.addValue(groupId, v); } public static void combineIntermediate(GroupingState state, int groupId, BytesRefBlock values, int valuesPosition) { @@ -84,17 +85,17 @@ public static void combineIntermediate(GroupingState state, int groupId, BytesRe int start = values.getFirstValueIndex(valuesPosition); int end = start + values.getValueCount(valuesPosition); for (int i = start; i < end; i++) { - combine(state, groupId, values.getBytesRef(i, scratch)); + state.addValue(groupId, values.getBytesRef(i, scratch)); } } public static void combineStates(GroupingState current, int currentGroupId, GroupingState state, int statePosition) { - BytesRef scratch = new BytesRef(); - for (int id = 0; id < state.values.size(); id++) { - if (state.values.getKey1(id) == statePosition) { - long value = state.values.getKey2(id); - combine(current, currentGroupId, state.bytes.get(value, scratch)); - } + var sorted = state.sortedForOrdinalMerging(current); + var start = statePosition > 0 ? sorted.counts[statePosition - 1] : 0; + var end = sorted.counts[statePosition]; + for (int i = start; i < end; i++) { + int id = sorted.ids[i]; + current.addValueOrdinal(currentGroupId, id); } } @@ -138,6 +139,18 @@ public void close() { } } + /** + * Values are collected in a hash. Iterating over them in order (row by row) to build the output, + * or merging with other state, can be expensive. To optimize this, we build a sorted structure once, + * and then use it to iterate over the values in order. + */ + private record Sorted(Releasable releasable, int[] counts, int[] ids) implements Releasable { + @Override + public void close() { + releasable.close(); + } + } + /** * State for a grouped {@code VALUES} aggregation. This implementation * emphasizes collect-time performance over the performance of rendering @@ -146,15 +159,20 @@ public void close() { * collector operation. But at least it's fairly simple. */ public static class GroupingState implements GroupingAggregatorState { - final LongLongHash values; + private int maxGroupId = -1; + private final BlockFactory blockFactory; + private final LongLongHash values; BytesRefHash bytes; - private GroupingState(BigArrays bigArrays) { + private Sorted sortedForOrdinalMerging = null; + + private GroupingState(DriverContext driverContext) { + this.blockFactory = driverContext.blockFactory(); LongLongHash _values = null; BytesRefHash _bytes = null; try { - _values = new LongLongHash(1, bigArrays); - _bytes = new BytesRefHash(1, bigArrays); + _values = new LongLongHash(1, driverContext.bigArrays()); + _bytes = new BytesRefHash(1, driverContext.bigArrays()); values = _values; bytes = _bytes; @@ -171,6 +189,16 @@ public void toIntermediate(Block[] blocks, int offset, IntVector selected, Drive blocks[offset] = toBlock(driverContext.blockFactory(), selected); } + void addValueOrdinal(int groupId, long valueOrdinal) { + values.add(groupId, valueOrdinal); + maxGroupId = Math.max(maxGroupId, groupId); + } + + void addValue(int groupId, BytesRef v) { + values.add(groupId, BlockHash.hashOrdToGroup(bytes.add(v))); + maxGroupId = Math.max(maxGroupId, groupId); + } + /** * Builds a {@link Block} with the unique values collected for the {@code #selected} * groups. This is the implementation of the final and intermediate results of the agg. @@ -180,8 +208,19 @@ Block toBlock(BlockFactory blockFactory, IntVector selected) { return blockFactory.newConstantNullBlock(selected.getPositionCount()); } + try (var sorted = buildSorted(selected)) { + if (OrdinalBytesRefBlock.isDense(selected.getPositionCount(), Math.toIntExact(values.size()))) { + return buildOrdinalOutputBlock(blockFactory, selected, sorted.counts, sorted.ids); + } else { + return buildOutputBlock(blockFactory, selected, sorted.counts, sorted.ids); + } + } + } + + private Sorted buildSorted(IntVector selected) { long selectedCountsSize = 0; long idsSize = 0; + Sorted sorted = null; try { /* * Get a count of all groups less than the maximum selected group. Count @@ -256,14 +295,43 @@ Block toBlock(BlockFactory blockFactory, IntVector selected) { ids[selectedCounts[group]++] = id; } } - if (OrdinalBytesRefBlock.isDense(selected.getPositionCount(), Math.toIntExact(values.size()))) { - return buildOrdinalOutputBlock(blockFactory, selected, selectedCounts, ids); - } else { - return buildOutputBlock(blockFactory, selected, selectedCounts, ids); - } + final long totalMemoryUsed = selectedCountsSize + idsSize; + sorted = new Sorted(() -> blockFactory.adjustBreaker(-totalMemoryUsed), selectedCounts, ids); + return sorted; } finally { - blockFactory.adjustBreaker(-selectedCountsSize - idsSize); + if (sorted == null) { + blockFactory.adjustBreaker(-selectedCountsSize - idsSize); + } + } + } + + private Sorted sortedForOrdinalMerging(GroupingState other) { + if (sortedForOrdinalMerging == null) { + try (var selected = IntVector.range(0, maxGroupId + 1, blockFactory)) { + sortedForOrdinalMerging = buildSorted(selected); + // hash all the bytes to the destination to avoid hashing them multiple times + BytesRef scratch = new BytesRef(); + final int totalValue = Math.toIntExact(bytes.size()); + blockFactory.adjustBreaker((long) totalValue * Integer.BYTES); + try { + final int[] mappedIds = new int[totalValue]; + for (int i = 0; i < totalValue; i++) { + var v = bytes.get(i, scratch); + mappedIds[i] = Math.toIntExact(BlockHash.hashOrdToGroup(other.bytes.add(v))); + } + // no longer need the bytes + bytes.close(); + bytes = null; + int[] ids = sortedForOrdinalMerging.ids; + for (int i = 0; i < ids.length; i++) { + ids[i] = mappedIds[Math.toIntExact(values.getKey2(ids[i]))]; + } + } finally { + blockFactory.adjustBreaker(-(long) totalValue * Integer.BYTES); + } + } } + return sortedForOrdinalMerging; } Block buildOutputBlock(BlockFactory blockFactory, IntVector selected, int[] selectedCounts, int[] ids) { @@ -279,11 +347,11 @@ Block buildOutputBlock(BlockFactory blockFactory, IntVector selected, int[] sele int count = end - start; switch (count) { case 0 -> builder.appendNull(); - case 1 -> append(builder, ids[start], scratch); + case 1 -> builder.appendBytesRef(getValue(ids[start], scratch)); default -> { builder.beginPositionEntry(); for (int i = start; i < end; i++) { - append(builder, ids[i], scratch); + builder.appendBytesRef(getValue(ids[i], scratch)); } builder.endPositionEntry(); } @@ -331,9 +399,8 @@ Block buildOrdinalOutputBlock(BlockFactory blockFactory, IntVector selected, int } } - private void append(BytesRefBlock.Builder builder, int id, BytesRef scratch) { - BytesRef value = bytes.get(values.getKey2(id), scratch); - builder.appendBytesRef(value); + BytesRef getValue(int valueId, BytesRef scratch) { + return bytes.get(values.getKey2(valueId), scratch); } @Override @@ -343,7 +410,7 @@ public void enableGroupIdTracking(SeenGroupIds seen) { @Override public void close() { - Releasables.closeExpectNoException(values, bytes); + Releasables.closeExpectNoException(values, bytes, sortedForOrdinalMerging); } } } diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesDoubleAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesDoubleAggregator.java index f5b0d519dd890..1404f93c27f8f 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesDoubleAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesDoubleAggregator.java @@ -19,6 +19,8 @@ import org.elasticsearch.compute.data.DoubleBlock; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.core.Releasable; +import org.elasticsearch.core.Releasables; /** * Aggregates field values for double. @@ -48,28 +50,29 @@ public static Block evaluateFinal(SingleState state, DriverContext driverContext return state.toBlock(driverContext.blockFactory()); } - public static GroupingState initGrouping(BigArrays bigArrays) { - return new GroupingState(bigArrays); + public static GroupingState initGrouping(DriverContext driverContext) { + return new GroupingState(driverContext); } public static void combine(GroupingState state, int groupId, double v) { - state.values.add(groupId, Double.doubleToLongBits(v)); + state.addValue(groupId, v); } public static void combineIntermediate(GroupingState state, int groupId, DoubleBlock values, int valuesPosition) { int start = values.getFirstValueIndex(valuesPosition); int end = start + values.getValueCount(valuesPosition); for (int i = start; i < end; i++) { - combine(state, groupId, values.getDouble(i)); + state.addValue(groupId, values.getDouble(i)); } } public static void combineStates(GroupingState current, int currentGroupId, GroupingState state, int statePosition) { - for (int id = 0; id < state.values.size(); id++) { - if (state.values.getKey1(id) == statePosition) { - double value = Double.longBitsToDouble(state.values.getKey2(id)); - combine(current, currentGroupId, value); - } + var sorted = state.sortedForOrdinalMerging(current); + var start = statePosition > 0 ? sorted.counts[statePosition - 1] : 0; + var end = sorted.counts[statePosition]; + for (int i = start; i < end; i++) { + int id = sorted.ids[i]; + current.addValue(currentGroupId, state.getValue(id)); } } @@ -112,6 +115,18 @@ public void close() { } } + /** + * Values are collected in a hash. Iterating over them in order (row by row) to build the output, + * or merging with other state, can be expensive. To optimize this, we build a sorted structure once, + * and then use it to iterate over the values in order. + */ + private record Sorted(Releasable releasable, int[] counts, int[] ids) implements Releasable { + @Override + public void close() { + releasable.close(); + } + } + /** * State for a grouped {@code VALUES} aggregation. This implementation * emphasizes collect-time performance over the performance of rendering @@ -120,10 +135,15 @@ public void close() { * collector operation. But at least it's fairly simple. */ public static class GroupingState implements GroupingAggregatorState { + private int maxGroupId = -1; + private final BlockFactory blockFactory; private final LongLongHash values; - private GroupingState(BigArrays bigArrays) { - values = new LongLongHash(1, bigArrays); + private Sorted sortedForOrdinalMerging = null; + + private GroupingState(DriverContext driverContext) { + this.blockFactory = driverContext.blockFactory(); + values = new LongLongHash(1, driverContext.bigArrays()); } @Override @@ -131,6 +151,11 @@ public void toIntermediate(Block[] blocks, int offset, IntVector selected, Drive blocks[offset] = toBlock(driverContext.blockFactory(), selected); } + void addValue(int groupId, double v) { + values.add(groupId, Double.doubleToLongBits(v)); + maxGroupId = Math.max(maxGroupId, groupId); + } + /** * Builds a {@link Block} with the unique values collected for the {@code #selected} * groups. This is the implementation of the final and intermediate results of the agg. @@ -140,8 +165,15 @@ Block toBlock(BlockFactory blockFactory, IntVector selected) { return blockFactory.newConstantNullBlock(selected.getPositionCount()); } + try (var sorted = buildSorted(selected)) { + return buildOutputBlock(blockFactory, selected, sorted.counts, sorted.ids); + } + } + + private Sorted buildSorted(IntVector selected) { long selectedCountsSize = 0; long idsSize = 0; + Sorted sorted = null; try { /* * Get a count of all groups less than the maximum selected group. Count @@ -216,12 +248,25 @@ Block toBlock(BlockFactory blockFactory, IntVector selected) { ids[selectedCounts[group]++] = id; } } - return buildOutputBlock(blockFactory, selected, selectedCounts, ids); + final long totalMemoryUsed = selectedCountsSize + idsSize; + sorted = new Sorted(() -> blockFactory.adjustBreaker(-totalMemoryUsed), selectedCounts, ids); + return sorted; } finally { - blockFactory.adjustBreaker(-selectedCountsSize - idsSize); + if (sorted == null) { + blockFactory.adjustBreaker(-selectedCountsSize - idsSize); + } } } + private Sorted sortedForOrdinalMerging(GroupingState other) { + if (sortedForOrdinalMerging == null) { + try (var selected = IntVector.range(0, maxGroupId + 1, blockFactory)) { + sortedForOrdinalMerging = buildSorted(selected); + } + } + return sortedForOrdinalMerging; + } + Block buildOutputBlock(BlockFactory blockFactory, IntVector selected, int[] selectedCounts, int[] ids) { /* * Insert the ids in order. @@ -234,11 +279,11 @@ Block buildOutputBlock(BlockFactory blockFactory, IntVector selected, int[] sele int count = end - start; switch (count) { case 0 -> builder.appendNull(); - case 1 -> append(builder, ids[start]); + case 1 -> builder.appendDouble(getValue(ids[start])); default -> { builder.beginPositionEntry(); for (int i = start; i < end; i++) { - append(builder, ids[i]); + builder.appendDouble(getValue(ids[i])); } builder.endPositionEntry(); } @@ -249,9 +294,8 @@ Block buildOutputBlock(BlockFactory blockFactory, IntVector selected, int[] sele } } - private void append(DoubleBlock.Builder builder, int id) { - double value = Double.longBitsToDouble(values.getKey2(id)); - builder.appendDouble(value); + double getValue(int valueId) { + return Double.longBitsToDouble(values.getKey2(valueId)); } @Override @@ -261,7 +305,7 @@ public void enableGroupIdTracking(SeenGroupIds seen) { @Override public void close() { - values.close(); + Releasables.closeExpectNoException(values, sortedForOrdinalMerging); } } } diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesFloatAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesFloatAggregator.java index 4cfbf329a895d..bb43c75eb39ba 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesFloatAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesFloatAggregator.java @@ -18,6 +18,8 @@ import org.elasticsearch.compute.data.FloatBlock; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.core.Releasable; +import org.elasticsearch.core.Releasables; /** * Aggregates field values for float. @@ -47,34 +49,29 @@ public static Block evaluateFinal(SingleState state, DriverContext driverContext return state.toBlock(driverContext.blockFactory()); } - public static GroupingState initGrouping(BigArrays bigArrays) { - return new GroupingState(bigArrays); + public static GroupingState initGrouping(DriverContext driverContext) { + return new GroupingState(driverContext); } public static void combine(GroupingState state, int groupId, float v) { - /* - * Encode the groupId and value into a single long - - * the top 32 bits for the group, the bottom 32 for the value. - */ - state.values.add((((long) groupId) << Float.SIZE) | (Float.floatToIntBits(v) & 0xFFFFFFFFL)); + state.addValue(groupId, v); } public static void combineIntermediate(GroupingState state, int groupId, FloatBlock values, int valuesPosition) { int start = values.getFirstValueIndex(valuesPosition); int end = start + values.getValueCount(valuesPosition); for (int i = start; i < end; i++) { - combine(state, groupId, values.getFloat(i)); + state.addValue(groupId, values.getFloat(i)); } } public static void combineStates(GroupingState current, int currentGroupId, GroupingState state, int statePosition) { - for (int id = 0; id < state.values.size(); id++) { - long both = state.values.get(id); - int group = (int) (both >>> Float.SIZE); - if (group == statePosition) { - float value = Float.intBitsToFloat((int) both); - combine(current, currentGroupId, value); - } + var sorted = state.sortedForOrdinalMerging(current); + var start = statePosition > 0 ? sorted.counts[statePosition - 1] : 0; + var end = sorted.counts[statePosition]; + for (int i = start; i < end; i++) { + int id = sorted.ids[i]; + current.addValue(currentGroupId, state.getValue(id)); } } @@ -117,6 +114,18 @@ public void close() { } } + /** + * Values are collected in a hash. Iterating over them in order (row by row) to build the output, + * or merging with other state, can be expensive. To optimize this, we build a sorted structure once, + * and then use it to iterate over the values in order. + */ + private record Sorted(Releasable releasable, int[] counts, int[] ids) implements Releasable { + @Override + public void close() { + releasable.close(); + } + } + /** * State for a grouped {@code VALUES} aggregation. This implementation * emphasizes collect-time performance over the performance of rendering @@ -125,10 +134,15 @@ public void close() { * collector operation. But at least it's fairly simple. */ public static class GroupingState implements GroupingAggregatorState { + private int maxGroupId = -1; + private final BlockFactory blockFactory; private final LongHash values; - private GroupingState(BigArrays bigArrays) { - values = new LongHash(1, bigArrays); + private Sorted sortedForOrdinalMerging = null; + + private GroupingState(DriverContext driverContext) { + this.blockFactory = driverContext.blockFactory(); + values = new LongHash(1, driverContext.bigArrays()); } @Override @@ -136,6 +150,15 @@ public void toIntermediate(Block[] blocks, int offset, IntVector selected, Drive blocks[offset] = toBlock(driverContext.blockFactory(), selected); } + void addValue(int groupId, float v) { + /* + * Encode the groupId and value into a single long - + * the top 32 bits for the group, the bottom 32 for the value. + */ + values.add((((long) groupId) << Float.SIZE) | (Float.floatToIntBits(v) & 0xFFFFFFFFL)); + maxGroupId = Math.max(maxGroupId, groupId); + } + /** * Builds a {@link Block} with the unique values collected for the {@code #selected} * groups. This is the implementation of the final and intermediate results of the agg. @@ -145,8 +168,15 @@ Block toBlock(BlockFactory blockFactory, IntVector selected) { return blockFactory.newConstantNullBlock(selected.getPositionCount()); } + try (var sorted = buildSorted(selected)) { + return buildOutputBlock(blockFactory, selected, sorted.counts, sorted.ids); + } + } + + private Sorted buildSorted(IntVector selected) { long selectedCountsSize = 0; long idsSize = 0; + Sorted sorted = null; try { /* * Get a count of all groups less than the maximum selected group. Count @@ -223,10 +253,23 @@ Block toBlock(BlockFactory blockFactory, IntVector selected) { ids[selectedCounts[group]++] = id; } } - return buildOutputBlock(blockFactory, selected, selectedCounts, ids); + final long totalMemoryUsed = selectedCountsSize + idsSize; + sorted = new Sorted(() -> blockFactory.adjustBreaker(-totalMemoryUsed), selectedCounts, ids); + return sorted; } finally { - blockFactory.adjustBreaker(-selectedCountsSize - idsSize); + if (sorted == null) { + blockFactory.adjustBreaker(-selectedCountsSize - idsSize); + } + } + } + + private Sorted sortedForOrdinalMerging(GroupingState other) { + if (sortedForOrdinalMerging == null) { + try (var selected = IntVector.range(0, maxGroupId + 1, blockFactory)) { + sortedForOrdinalMerging = buildSorted(selected); + } } + return sortedForOrdinalMerging; } Block buildOutputBlock(BlockFactory blockFactory, IntVector selected, int[] selectedCounts, int[] ids) { @@ -241,11 +284,11 @@ Block buildOutputBlock(BlockFactory blockFactory, IntVector selected, int[] sele int count = end - start; switch (count) { case 0 -> builder.appendNull(); - case 1 -> append(builder, ids[start]); + case 1 -> builder.appendFloat(getValue(ids[start])); default -> { builder.beginPositionEntry(); for (int i = start; i < end; i++) { - append(builder, ids[i]); + builder.appendFloat(getValue(ids[i])); } builder.endPositionEntry(); } @@ -256,10 +299,9 @@ Block buildOutputBlock(BlockFactory blockFactory, IntVector selected, int[] sele } } - private void append(FloatBlock.Builder builder, int id) { - long both = values.get(id); - float value = Float.intBitsToFloat((int) both); - builder.appendFloat(value); + float getValue(int valueId) { + long both = values.get(valueId); + return Float.intBitsToFloat((int) both); } @Override @@ -269,7 +311,7 @@ public void enableGroupIdTracking(SeenGroupIds seen) { @Override public void close() { - values.close(); + Releasables.closeExpectNoException(values, sortedForOrdinalMerging); } } } diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesIntAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesIntAggregator.java index 38e5ad99cf581..4bf7125051efa 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesIntAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesIntAggregator.java @@ -18,6 +18,8 @@ import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.core.Releasable; +import org.elasticsearch.core.Releasables; /** * Aggregates field values for int. @@ -47,34 +49,29 @@ public static Block evaluateFinal(SingleState state, DriverContext driverContext return state.toBlock(driverContext.blockFactory()); } - public static GroupingState initGrouping(BigArrays bigArrays) { - return new GroupingState(bigArrays); + public static GroupingState initGrouping(DriverContext driverContext) { + return new GroupingState(driverContext); } public static void combine(GroupingState state, int groupId, int v) { - /* - * Encode the groupId and value into a single long - - * the top 32 bits for the group, the bottom 32 for the value. - */ - state.values.add((((long) groupId) << Integer.SIZE) | (v & 0xFFFFFFFFL)); + state.addValue(groupId, v); } public static void combineIntermediate(GroupingState state, int groupId, IntBlock values, int valuesPosition) { int start = values.getFirstValueIndex(valuesPosition); int end = start + values.getValueCount(valuesPosition); for (int i = start; i < end; i++) { - combine(state, groupId, values.getInt(i)); + state.addValue(groupId, values.getInt(i)); } } public static void combineStates(GroupingState current, int currentGroupId, GroupingState state, int statePosition) { - for (int id = 0; id < state.values.size(); id++) { - long both = state.values.get(id); - int group = (int) (both >>> Integer.SIZE); - if (group == statePosition) { - int value = (int) both; - combine(current, currentGroupId, value); - } + var sorted = state.sortedForOrdinalMerging(current); + var start = statePosition > 0 ? sorted.counts[statePosition - 1] : 0; + var end = sorted.counts[statePosition]; + for (int i = start; i < end; i++) { + int id = sorted.ids[i]; + current.addValue(currentGroupId, state.getValue(id)); } } @@ -117,6 +114,18 @@ public void close() { } } + /** + * Values are collected in a hash. Iterating over them in order (row by row) to build the output, + * or merging with other state, can be expensive. To optimize this, we build a sorted structure once, + * and then use it to iterate over the values in order. + */ + private record Sorted(Releasable releasable, int[] counts, int[] ids) implements Releasable { + @Override + public void close() { + releasable.close(); + } + } + /** * State for a grouped {@code VALUES} aggregation. This implementation * emphasizes collect-time performance over the performance of rendering @@ -125,10 +134,15 @@ public void close() { * collector operation. But at least it's fairly simple. */ public static class GroupingState implements GroupingAggregatorState { + private int maxGroupId = -1; + private final BlockFactory blockFactory; private final LongHash values; - private GroupingState(BigArrays bigArrays) { - values = new LongHash(1, bigArrays); + private Sorted sortedForOrdinalMerging = null; + + private GroupingState(DriverContext driverContext) { + this.blockFactory = driverContext.blockFactory(); + values = new LongHash(1, driverContext.bigArrays()); } @Override @@ -136,6 +150,15 @@ public void toIntermediate(Block[] blocks, int offset, IntVector selected, Drive blocks[offset] = toBlock(driverContext.blockFactory(), selected); } + void addValue(int groupId, int v) { + /* + * Encode the groupId and value into a single long - + * the top 32 bits for the group, the bottom 32 for the value. + */ + values.add((((long) groupId) << Integer.SIZE) | (v & 0xFFFFFFFFL)); + maxGroupId = Math.max(maxGroupId, groupId); + } + /** * Builds a {@link Block} with the unique values collected for the {@code #selected} * groups. This is the implementation of the final and intermediate results of the agg. @@ -145,8 +168,15 @@ Block toBlock(BlockFactory blockFactory, IntVector selected) { return blockFactory.newConstantNullBlock(selected.getPositionCount()); } + try (var sorted = buildSorted(selected)) { + return buildOutputBlock(blockFactory, selected, sorted.counts, sorted.ids); + } + } + + private Sorted buildSorted(IntVector selected) { long selectedCountsSize = 0; long idsSize = 0; + Sorted sorted = null; try { /* * Get a count of all groups less than the maximum selected group. Count @@ -223,10 +253,23 @@ Block toBlock(BlockFactory blockFactory, IntVector selected) { ids[selectedCounts[group]++] = id; } } - return buildOutputBlock(blockFactory, selected, selectedCounts, ids); + final long totalMemoryUsed = selectedCountsSize + idsSize; + sorted = new Sorted(() -> blockFactory.adjustBreaker(-totalMemoryUsed), selectedCounts, ids); + return sorted; } finally { - blockFactory.adjustBreaker(-selectedCountsSize - idsSize); + if (sorted == null) { + blockFactory.adjustBreaker(-selectedCountsSize - idsSize); + } + } + } + + private Sorted sortedForOrdinalMerging(GroupingState other) { + if (sortedForOrdinalMerging == null) { + try (var selected = IntVector.range(0, maxGroupId + 1, blockFactory)) { + sortedForOrdinalMerging = buildSorted(selected); + } } + return sortedForOrdinalMerging; } Block buildOutputBlock(BlockFactory blockFactory, IntVector selected, int[] selectedCounts, int[] ids) { @@ -241,11 +284,11 @@ Block buildOutputBlock(BlockFactory blockFactory, IntVector selected, int[] sele int count = end - start; switch (count) { case 0 -> builder.appendNull(); - case 1 -> append(builder, ids[start]); + case 1 -> builder.appendInt(getValue(ids[start])); default -> { builder.beginPositionEntry(); for (int i = start; i < end; i++) { - append(builder, ids[i]); + builder.appendInt(getValue(ids[i])); } builder.endPositionEntry(); } @@ -256,10 +299,9 @@ Block buildOutputBlock(BlockFactory blockFactory, IntVector selected, int[] sele } } - private void append(IntBlock.Builder builder, int id) { - long both = values.get(id); - int value = (int) both; - builder.appendInt(value); + int getValue(int valueId) { + long both = values.get(valueId); + return (int) both; } @Override @@ -269,7 +311,7 @@ public void enableGroupIdTracking(SeenGroupIds seen) { @Override public void close() { - values.close(); + Releasables.closeExpectNoException(values, sortedForOrdinalMerging); } } } diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesLongAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesLongAggregator.java index 4bfc230d7e1f7..56c6912f58a76 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesLongAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesLongAggregator.java @@ -19,6 +19,8 @@ import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.data.LongBlock; import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.core.Releasable; +import org.elasticsearch.core.Releasables; /** * Aggregates field values for long. @@ -48,28 +50,29 @@ public static Block evaluateFinal(SingleState state, DriverContext driverContext return state.toBlock(driverContext.blockFactory()); } - public static GroupingState initGrouping(BigArrays bigArrays) { - return new GroupingState(bigArrays); + public static GroupingState initGrouping(DriverContext driverContext) { + return new GroupingState(driverContext); } public static void combine(GroupingState state, int groupId, long v) { - state.values.add(groupId, v); + state.addValue(groupId, v); } public static void combineIntermediate(GroupingState state, int groupId, LongBlock values, int valuesPosition) { int start = values.getFirstValueIndex(valuesPosition); int end = start + values.getValueCount(valuesPosition); for (int i = start; i < end; i++) { - combine(state, groupId, values.getLong(i)); + state.addValue(groupId, values.getLong(i)); } } public static void combineStates(GroupingState current, int currentGroupId, GroupingState state, int statePosition) { - for (int id = 0; id < state.values.size(); id++) { - if (state.values.getKey1(id) == statePosition) { - long value = state.values.getKey2(id); - combine(current, currentGroupId, value); - } + var sorted = state.sortedForOrdinalMerging(current); + var start = statePosition > 0 ? sorted.counts[statePosition - 1] : 0; + var end = sorted.counts[statePosition]; + for (int i = start; i < end; i++) { + int id = sorted.ids[i]; + current.addValue(currentGroupId, state.getValue(id)); } } @@ -112,6 +115,18 @@ public void close() { } } + /** + * Values are collected in a hash. Iterating over them in order (row by row) to build the output, + * or merging with other state, can be expensive. To optimize this, we build a sorted structure once, + * and then use it to iterate over the values in order. + */ + private record Sorted(Releasable releasable, int[] counts, int[] ids) implements Releasable { + @Override + public void close() { + releasable.close(); + } + } + /** * State for a grouped {@code VALUES} aggregation. This implementation * emphasizes collect-time performance over the performance of rendering @@ -120,10 +135,15 @@ public void close() { * collector operation. But at least it's fairly simple. */ public static class GroupingState implements GroupingAggregatorState { + private int maxGroupId = -1; + private final BlockFactory blockFactory; private final LongLongHash values; - private GroupingState(BigArrays bigArrays) { - values = new LongLongHash(1, bigArrays); + private Sorted sortedForOrdinalMerging = null; + + private GroupingState(DriverContext driverContext) { + this.blockFactory = driverContext.blockFactory(); + values = new LongLongHash(1, driverContext.bigArrays()); } @Override @@ -131,6 +151,11 @@ public void toIntermediate(Block[] blocks, int offset, IntVector selected, Drive blocks[offset] = toBlock(driverContext.blockFactory(), selected); } + void addValue(int groupId, long v) { + values.add(groupId, v); + maxGroupId = Math.max(maxGroupId, groupId); + } + /** * Builds a {@link Block} with the unique values collected for the {@code #selected} * groups. This is the implementation of the final and intermediate results of the agg. @@ -140,8 +165,15 @@ Block toBlock(BlockFactory blockFactory, IntVector selected) { return blockFactory.newConstantNullBlock(selected.getPositionCount()); } + try (var sorted = buildSorted(selected)) { + return buildOutputBlock(blockFactory, selected, sorted.counts, sorted.ids); + } + } + + private Sorted buildSorted(IntVector selected) { long selectedCountsSize = 0; long idsSize = 0; + Sorted sorted = null; try { /* * Get a count of all groups less than the maximum selected group. Count @@ -216,12 +248,25 @@ Block toBlock(BlockFactory blockFactory, IntVector selected) { ids[selectedCounts[group]++] = id; } } - return buildOutputBlock(blockFactory, selected, selectedCounts, ids); + final long totalMemoryUsed = selectedCountsSize + idsSize; + sorted = new Sorted(() -> blockFactory.adjustBreaker(-totalMemoryUsed), selectedCounts, ids); + return sorted; } finally { - blockFactory.adjustBreaker(-selectedCountsSize - idsSize); + if (sorted == null) { + blockFactory.adjustBreaker(-selectedCountsSize - idsSize); + } } } + private Sorted sortedForOrdinalMerging(GroupingState other) { + if (sortedForOrdinalMerging == null) { + try (var selected = IntVector.range(0, maxGroupId + 1, blockFactory)) { + sortedForOrdinalMerging = buildSorted(selected); + } + } + return sortedForOrdinalMerging; + } + Block buildOutputBlock(BlockFactory blockFactory, IntVector selected, int[] selectedCounts, int[] ids) { /* * Insert the ids in order. @@ -234,11 +279,11 @@ Block buildOutputBlock(BlockFactory blockFactory, IntVector selected, int[] sele int count = end - start; switch (count) { case 0 -> builder.appendNull(); - case 1 -> append(builder, ids[start]); + case 1 -> builder.appendLong(getValue(ids[start])); default -> { builder.beginPositionEntry(); for (int i = start; i < end; i++) { - append(builder, ids[i]); + builder.appendLong(getValue(ids[i])); } builder.endPositionEntry(); } @@ -249,9 +294,8 @@ Block buildOutputBlock(BlockFactory blockFactory, IntVector selected, int[] sele } } - private void append(LongBlock.Builder builder, int id) { - long value = values.getKey2(id); - builder.appendLong(value); + long getValue(int valueId) { + return values.getKey2(valueId); } @Override @@ -261,7 +305,7 @@ public void enableGroupIdTracking(SeenGroupIds seen) { @Override public void close() { - values.close(); + Releasables.closeExpectNoException(values, sortedForOrdinalMerging); } } } diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesBytesRefGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesBytesRefGroupingAggregatorFunction.java index 28843942b73cb..ebbc4cd5eb8a3 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesBytesRefGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesBytesRefGroupingAggregatorFunction.java @@ -43,7 +43,7 @@ public ValuesBytesRefGroupingAggregatorFunction(List channels, public static ValuesBytesRefGroupingAggregatorFunction create(List channels, DriverContext driverContext) { - return new ValuesBytesRefGroupingAggregatorFunction(channels, ValuesBytesRefAggregator.initGrouping(driverContext.bigArrays()), driverContext); + return new ValuesBytesRefGroupingAggregatorFunction(channels, ValuesBytesRefAggregator.initGrouping(driverContext), driverContext); } public static List intermediateStateDesc() { diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesDoubleGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesDoubleGroupingAggregatorFunction.java index 76c865b33fd09..e61ffacb17274 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesDoubleGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesDoubleGroupingAggregatorFunction.java @@ -42,7 +42,7 @@ public ValuesDoubleGroupingAggregatorFunction(List channels, public static ValuesDoubleGroupingAggregatorFunction create(List channels, DriverContext driverContext) { - return new ValuesDoubleGroupingAggregatorFunction(channels, ValuesDoubleAggregator.initGrouping(driverContext.bigArrays()), driverContext); + return new ValuesDoubleGroupingAggregatorFunction(channels, ValuesDoubleAggregator.initGrouping(driverContext), driverContext); } public static List intermediateStateDesc() { diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesFloatGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesFloatGroupingAggregatorFunction.java index bed9a884ccd10..d7eb4bc97bacb 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesFloatGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesFloatGroupingAggregatorFunction.java @@ -42,7 +42,7 @@ public ValuesFloatGroupingAggregatorFunction(List channels, public static ValuesFloatGroupingAggregatorFunction create(List channels, DriverContext driverContext) { - return new ValuesFloatGroupingAggregatorFunction(channels, ValuesFloatAggregator.initGrouping(driverContext.bigArrays()), driverContext); + return new ValuesFloatGroupingAggregatorFunction(channels, ValuesFloatAggregator.initGrouping(driverContext), driverContext); } public static List intermediateStateDesc() { diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesIntGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesIntGroupingAggregatorFunction.java index fb801eadcf5cd..bd34ac0d27098 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesIntGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesIntGroupingAggregatorFunction.java @@ -41,7 +41,7 @@ public ValuesIntGroupingAggregatorFunction(List channels, public static ValuesIntGroupingAggregatorFunction create(List channels, DriverContext driverContext) { - return new ValuesIntGroupingAggregatorFunction(channels, ValuesIntAggregator.initGrouping(driverContext.bigArrays()), driverContext); + return new ValuesIntGroupingAggregatorFunction(channels, ValuesIntAggregator.initGrouping(driverContext), driverContext); } public static List intermediateStateDesc() { diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesLongGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesLongGroupingAggregatorFunction.java index 061af9fcc9213..39f485f3b174d 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesLongGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesLongGroupingAggregatorFunction.java @@ -42,7 +42,7 @@ public ValuesLongGroupingAggregatorFunction(List channels, public static ValuesLongGroupingAggregatorFunction create(List channels, DriverContext driverContext) { - return new ValuesLongGroupingAggregatorFunction(channels, ValuesLongAggregator.initGrouping(driverContext.bigArrays()), driverContext); + return new ValuesLongGroupingAggregatorFunction(channels, ValuesLongAggregator.initGrouping(driverContext), driverContext); } public static List intermediateStateDesc() { diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregators.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregators.java index 78a083b8daac7..4a2fa0923abe4 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregators.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregators.java @@ -55,7 +55,7 @@ public void add(int positionOffset, IntArrayBlock groupIds) { int valuesStart = ordinalIds.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + ordinalIds.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { - state.values.add(groupId, hashIds.getInt(ordinalIds.getInt(v))); + state.addValueOrdinal(groupId, hashIds.getInt(ordinalIds.getInt(v))); } } } @@ -77,7 +77,7 @@ public void add(int positionOffset, IntBigArrayBlock groupIds) { int valuesStart = ordinalIds.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + ordinalIds.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { - state.values.add(groupId, hashIds.getInt(ordinalIds.getInt(v))); + state.addValueOrdinal(groupId, hashIds.getInt(ordinalIds.getInt(v))); } } } @@ -93,7 +93,7 @@ public void add(int positionOffset, IntVector groupIds) { int valuesStart = ordinalIds.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + ordinalIds.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { - state.values.add(groupId, hashIds.getInt(ordinalIds.getInt(v))); + state.addValueOrdinal(groupId, hashIds.getInt(ordinalIds.getInt(v))); } } } @@ -135,7 +135,7 @@ public void add(int positionOffset, IntArrayBlock groupIds) { int groupEnd = groupStart + groupIds.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groupIds.getInt(g); - state.values.add(groupId, hashIds.getInt(ordinalIds.getInt(groupPosition + positionOffset))); + state.addValueOrdinal(groupId, hashIds.getInt(ordinalIds.getInt(groupPosition + positionOffset))); } } } @@ -150,7 +150,7 @@ public void add(int positionOffset, IntBigArrayBlock groupIds) { int groupEnd = groupStart + groupIds.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groupIds.getInt(g); - state.values.add(groupId, hashIds.getInt(ordinalIds.getInt(groupPosition + positionOffset))); + state.addValueOrdinal(groupId, hashIds.getInt(ordinalIds.getInt(groupPosition + positionOffset))); } } } @@ -159,7 +159,7 @@ public void add(int positionOffset, IntBigArrayBlock groupIds) { public void add(int positionOffset, IntVector groupIds) { for (int groupPosition = 0; groupPosition < groupIds.getPositionCount(); groupPosition++) { int groupId = groupIds.getInt(groupPosition); - state.values.add(groupId, hashIds.getInt(ordinalIds.getInt(groupPosition + positionOffset))); + state.addValueOrdinal(groupId, hashIds.getInt(ordinalIds.getInt(groupPosition + positionOffset))); } } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st index 67f32fc4a4d4e..70b2f487cef36 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st @@ -43,12 +43,9 @@ $if(BytesRef)$ import org.elasticsearch.compute.data.OrdinalBytesRefBlock; $endif$ import org.elasticsearch.compute.operator.DriverContext; -$if(BytesRef)$ +import org.elasticsearch.core.Releasable; import org.elasticsearch.core.Releasables; -$else$ - -$endif$ /** * Aggregates field values for $type$. * This class is generated. Edit @{code X-ValuesAggregator.java.st} instead @@ -90,8 +87,8 @@ $endif$ return state.toBlock(driverContext.blockFactory()); } - public static GroupingState initGrouping(BigArrays bigArrays) { - return new GroupingState(bigArrays); + public static GroupingState initGrouping(DriverContext driverContext) { + return new GroupingState(driverContext); } $if(BytesRef)$ @@ -113,25 +110,7 @@ $if(BytesRef)$ $endif$ public static void combine(GroupingState state, int groupId, $type$ v) { -$if(long)$ - state.values.add(groupId, v); -$elseif(double)$ - state.values.add(groupId, Double.doubleToLongBits(v)); -$elseif(BytesRef)$ - state.values.add(groupId, BlockHash.hashOrdToGroup(state.bytes.add(v))); -$elseif(int)$ - /* - * Encode the groupId and value into a single long - - * the top 32 bits for the group, the bottom 32 for the value. - */ - state.values.add((((long) groupId) << Integer.SIZE) | (v & 0xFFFFFFFFL)); -$elseif(float)$ - /* - * Encode the groupId and value into a single long - - * the top 32 bits for the group, the bottom 32 for the value. - */ - state.values.add((((long) groupId) << Float.SIZE) | (Float.floatToIntBits(v) & 0xFFFFFFFFL)); -$endif$ + state.addValue(groupId, v); } public static void combineIntermediate(GroupingState state, int groupId, $Type$Block values, int valuesPosition) { @@ -142,37 +121,24 @@ $endif$ int end = start + values.getValueCount(valuesPosition); for (int i = start; i < end; i++) { $if(BytesRef)$ - combine(state, groupId, values.getBytesRef(i, scratch)); + state.addValue(groupId, values.getBytesRef(i, scratch)); $else$ - combine(state, groupId, values.get$Type$(i)); + state.addValue(groupId, values.get$Type$(i)); $endif$ } } public static void combineStates(GroupingState current, int currentGroupId, GroupingState state, int statePosition) { + var sorted = state.sortedForOrdinalMerging(current); + var start = statePosition > 0 ? sorted.counts[statePosition - 1] : 0; + var end = sorted.counts[statePosition]; + for (int i = start; i < end; i++) { + int id = sorted.ids[i]; $if(BytesRef)$ - BytesRef scratch = new BytesRef(); -$endif$ - for (int id = 0; id < state.values.size(); id++) { -$if(long||BytesRef)$ - if (state.values.getKey1(id) == statePosition) { - long value = state.values.getKey2(id); -$elseif(double)$ - if (state.values.getKey1(id) == statePosition) { - double value = Double.longBitsToDouble(state.values.getKey2(id)); -$elseif(int)$ - long both = state.values.get(id); - int group = (int) (both >>> Integer.SIZE); - if (group == statePosition) { - int value = (int) both; -$elseif(float)$ - long both = state.values.get(id); - int group = (int) (both >>> Float.SIZE); - if (group == statePosition) { - float value = Float.intBitsToFloat((int) both); + current.addValueOrdinal(currentGroupId, id); +$else$ + current.addValue(currentGroupId, state.getValue(id)); $endif$ - combine(current, currentGroupId, $if(BytesRef)$state.bytes.get(value, scratch)$else$value$endif$); - } } } @@ -247,6 +213,18 @@ $endif$ } } + /** + * Values are collected in a hash. Iterating over them in order (row by row) to build the output, + * or merging with other state, can be expensive. To optimize this, we build a sorted structure once, + * and then use it to iterate over the values in order. + */ + private record Sorted(Releasable releasable, int[] counts, int[] ids) implements Releasable { + @Override + public void close() { + releasable.close(); + } + } + /** * State for a grouped {@code VALUES} aggregation. This implementation * emphasizes collect-time performance over the performance of rendering @@ -255,26 +233,31 @@ $endif$ * collector operation. But at least it's fairly simple. */ public static class GroupingState implements GroupingAggregatorState { + private int maxGroupId = -1; + private final BlockFactory blockFactory; $if(long||double)$ private final LongLongHash values; $elseif(BytesRef)$ - final LongLongHash values; + private final LongLongHash values; BytesRefHash bytes; $elseif(int||float)$ private final LongHash values; $endif$ - private GroupingState(BigArrays bigArrays) { + private Sorted sortedForOrdinalMerging = null; + + private GroupingState(DriverContext driverContext) { + this.blockFactory = driverContext.blockFactory(); $if(long||double)$ - values = new LongLongHash(1, bigArrays); + values = new LongLongHash(1, driverContext.bigArrays()); $elseif(BytesRef)$ LongLongHash _values = null; BytesRefHash _bytes = null; try { - _values = new LongLongHash(1, bigArrays); - _bytes = new BytesRefHash(1, bigArrays); + _values = new LongLongHash(1, driverContext.bigArrays()); + _bytes = new BytesRefHash(1, driverContext.bigArrays()); values = _values; bytes = _bytes; @@ -285,7 +268,7 @@ $elseif(BytesRef)$ Releasables.closeExpectNoException(_values, _bytes); } $elseif(int||float)$ - values = new LongHash(1, bigArrays); + values = new LongHash(1, driverContext.bigArrays()); $endif$ } @@ -294,6 +277,36 @@ $endif$ blocks[offset] = toBlock(driverContext.blockFactory(), selected); } +$if(BytesRef)$ + void addValueOrdinal(int groupId, long valueOrdinal) { + values.add(groupId, valueOrdinal); + maxGroupId = Math.max(maxGroupId, groupId); + } + +$endif$ + void addValue(int groupId, $type$ v) { +$if(long)$ + values.add(groupId, v); +$elseif(double)$ + values.add(groupId, Double.doubleToLongBits(v)); +$elseif(BytesRef)$ + values.add(groupId, BlockHash.hashOrdToGroup(bytes.add(v))); +$elseif(int)$ + /* + * Encode the groupId and value into a single long - + * the top 32 bits for the group, the bottom 32 for the value. + */ + values.add((((long) groupId) << Integer.SIZE) | (v & 0xFFFFFFFFL)); +$elseif(float)$ + /* + * Encode the groupId and value into a single long - + * the top 32 bits for the group, the bottom 32 for the value. + */ + values.add((((long) groupId) << Float.SIZE) | (Float.floatToIntBits(v) & 0xFFFFFFFFL)); +$endif$ + maxGroupId = Math.max(maxGroupId, groupId); + } + /** * Builds a {@link Block} with the unique values collected for the {@code #selected} * groups. This is the implementation of the final and intermediate results of the agg. @@ -303,8 +316,23 @@ $endif$ return blockFactory.newConstantNullBlock(selected.getPositionCount()); } + try (var sorted = buildSorted(selected)) { +$if(BytesRef)$ + if (OrdinalBytesRefBlock.isDense(selected.getPositionCount(), Math.toIntExact(values.size()))) { + return buildOrdinalOutputBlock(blockFactory, selected, sorted.counts, sorted.ids); + } else { + return buildOutputBlock(blockFactory, selected, sorted.counts, sorted.ids); + } +$else$ + return buildOutputBlock(blockFactory, selected, sorted.counts, sorted.ids); +$endif$ + } + } + + private Sorted buildSorted(IntVector selected) { long selectedCountsSize = 0; long idsSize = 0; + Sorted sorted = null; try { /* * Get a count of all groups less than the maximum selected group. Count @@ -319,12 +347,12 @@ $endif$ selectedCountsSize = adjust; int[] selectedCounts = new int[selectedCountsLen]; for (int id = 0; id < values.size(); id++) { -$if(long||BytesRef||double)$ + $if(long||BytesRef||double)$ int group = (int) values.getKey1(id); -$elseif(float||int)$ + $elseif(float||int)$ long both = values.get(id); int group = (int) (both >>> Float.SIZE); -$endif$ + $endif$ if (group < selectedCounts.length) { selectedCounts[group]--; } @@ -379,28 +407,55 @@ $endif$ idsSize = adjust; int[] ids = new int[total]; for (int id = 0; id < values.size(); id++) { -$if(long||BytesRef||double)$ + $if(long||BytesRef||double)$ int group = (int) values.getKey1(id); -$elseif(float||int)$ + $elseif(float||int)$ long both = values.get(id); int group = (int) (both >>> Float.SIZE); -$endif$ + $endif$ if (group < selectedCounts.length && selectedCounts[group] >= 0) { ids[selectedCounts[group]++] = id; } } -$if(BytesRef)$ - if (OrdinalBytesRefBlock.isDense(selected.getPositionCount(), Math.toIntExact(values.size()))) { - return buildOrdinalOutputBlock(blockFactory, selected, selectedCounts, ids); - } else { - return buildOutputBlock(blockFactory, selected, selectedCounts, ids); + final long totalMemoryUsed = selectedCountsSize + idsSize; + sorted = new Sorted(() -> blockFactory.adjustBreaker(-totalMemoryUsed), selectedCounts, ids); + return sorted; + } finally { + if (sorted == null) { + blockFactory.adjustBreaker(-selectedCountsSize - idsSize); } -$else$ - return buildOutputBlock(blockFactory, selected, selectedCounts, ids); + } + } + + private Sorted sortedForOrdinalMerging(GroupingState other) { + if (sortedForOrdinalMerging == null) { + try (var selected = IntVector.range(0, maxGroupId + 1, blockFactory)) { + sortedForOrdinalMerging = buildSorted(selected); +$if(BytesRef)$ + // hash all the bytes to the destination to avoid hashing them multiple times + BytesRef scratch = new BytesRef(); + final int totalValue = Math.toIntExact(bytes.size()); + blockFactory.adjustBreaker((long) totalValue * Integer.BYTES); + try { + final int[] mappedIds = new int[totalValue]; + for (int i = 0; i < totalValue; i++) { + var v = bytes.get(i, scratch); + mappedIds[i] = Math.toIntExact(BlockHash.hashOrdToGroup(other.bytes.add(v))); + } + // no longer need the bytes + bytes.close(); + bytes = null; + int[] ids = sortedForOrdinalMerging.ids; + for (int i = 0; i < ids.length; i++) { + ids[i] = mappedIds[Math.toIntExact(values.getKey2(ids[i]))]; + } + } finally { + blockFactory.adjustBreaker(-(long) totalValue * Integer.BYTES); + } $endif$ - } finally { - blockFactory.adjustBreaker(-selectedCountsSize - idsSize); + } } + return sortedForOrdinalMerging; } Block buildOutputBlock(BlockFactory blockFactory, IntVector selected, int[] selectedCounts, int[] ids) { @@ -418,11 +473,11 @@ $endif$ int count = end - start; switch (count) { case 0 -> builder.appendNull(); - case 1 -> append(builder, ids[start]$if(BytesRef)$, scratch$endif$); + case 1 -> builder.append$Type$(getValue(ids[start]$if(BytesRef)$, scratch$endif$)); default -> { builder.beginPositionEntry(); for (int i = start; i < end; i++) { - append(builder, ids[i]$if(BytesRef)$, scratch$endif$); + builder.append$Type$(getValue(ids[i]$if(BytesRef)$, scratch$endif$)); } builder.endPositionEntry(); } @@ -470,29 +525,24 @@ $if(BytesRef)$ } } } +$endif$ - private void append($Type$Block.Builder builder, int id, BytesRef scratch) { - BytesRef value = bytes.get(values.getKey2(id), scratch); - builder.appendBytesRef(value); - } - -$else$ - private void append($Type$Block.Builder builder, int id) { -$if(long)$ - long value = values.getKey2(id); + $type$ getValue(int valueId$if(BytesRef)$, BytesRef scratch$endif$) { +$if(BytesRef)$ + return bytes.get(values.getKey2(valueId), scratch); +$elseif(long)$ + return values.getKey2(valueId); $elseif(double)$ - double value = Double.longBitsToDouble(values.getKey2(id)); + return Double.longBitsToDouble(values.getKey2(valueId)); $elseif(float)$ - long both = values.get(id); - float value = Float.intBitsToFloat((int) both); + long both = values.get(valueId); + return Float.intBitsToFloat((int) both); $elseif(int)$ - long both = values.get(id); - int value = (int) both; + long both = values.get(valueId); + return (int) both; $endif$ - builder.append$Type$(value); } -$endif$ @Override public void enableGroupIdTracking(SeenGroupIds seen) { // we figure out seen values from nulls on the values block @@ -501,9 +551,9 @@ $endif$ @Override public void close() { $if(BytesRef)$ - Releasables.closeExpectNoException(values, bytes); + Releasables.closeExpectNoException(values, bytes, sortedForOrdinalMerging); $else$ - values.close(); + Releasables.closeExpectNoException(values, sortedForOrdinalMerging); $endif$ } } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/HashAggregationOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/HashAggregationOperator.java index 9c4b9dd360062..cbce712ed9cdb 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/HashAggregationOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/HashAggregationOperator.java @@ -85,7 +85,7 @@ public String describe() { private final BlockHash blockHash; - private final List aggregators; + protected final List aggregators; protected final DriverContext driverContext; From 814abc7cd0bfc27087e427458b2e0a9a1786fc30 Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Fri, 4 Jul 2025 03:59:05 +0000 Subject: [PATCH 2/8] [CI] Auto commit changes from spotless --- .../benchmark/compute/operator/ValuesAggregatorBenchmark.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/ValuesAggregatorBenchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/ValuesAggregatorBenchmark.java index c082079362cc4..ca3662203ca7b 100644 --- a/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/ValuesAggregatorBenchmark.java +++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/ValuesAggregatorBenchmark.java @@ -114,7 +114,7 @@ static void selfTest() { @Param({ BYTES_REF, INT, LONG }) public String dataType; - @Param({ "0", "1"}) + @Param({ "0", "1" }) public int numOrdinalMerges; private static Operator operator(DriverContext driverContext, int groups, String dataType, int numOrdinalMerges) { @@ -135,6 +135,7 @@ public Page getOutput() { mergeOrdinal(); return super.getOutput(); } + // simulate OrdinalsGroupingOperator void mergeOrdinal() { var merged = supplier(dataType).groupingAggregatorFactory(AggregatorMode.SINGLE, List.of(1)).apply(driverContext); From 4c44fa1e91c7cde8b90a7b7d29ca1ccf1599be46 Mon Sep 17 00:00:00 2001 From: Nhat Nguyen Date: Thu, 3 Jul 2025 21:01:03 -0700 Subject: [PATCH 3/8] Update docs/changelog/130576.yaml --- docs/changelog/130576.yaml | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 docs/changelog/130576.yaml diff --git a/docs/changelog/130576.yaml b/docs/changelog/130576.yaml new file mode 100644 index 0000000000000..29947d9259c6b --- /dev/null +++ b/docs/changelog/130576.yaml @@ -0,0 +1,5 @@ +pr: 130576 +summary: Avoid O(N^2) in VALUES with ordinals grouping +area: ES|QL +type: bug +issues: [] From 3e363b222822770aea8cdfb585b327d1f77b2468 Mon Sep 17 00:00:00 2001 From: Nhat Nguyen Date: Mon, 7 Jul 2025 07:59:47 -0700 Subject: [PATCH 4/8] revert --- .../benchmark/compute/operator/ValuesAggregatorBenchmark.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/ValuesAggregatorBenchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/ValuesAggregatorBenchmark.java index ca3662203ca7b..879418e7f954c 100644 --- a/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/ValuesAggregatorBenchmark.java +++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/ValuesAggregatorBenchmark.java @@ -65,7 +65,7 @@ @Fork(1) public class ValuesAggregatorBenchmark { static final int MIN_BLOCK_LENGTH = 8 * 1024; - private static final int OP_COUNT = 20; + private static final int OP_COUNT = 1024; private static final int UNIQUE_VALUES = 6; private static final BytesRef[] KEYWORDS = new BytesRef[] { new BytesRef("Tokyo"), From d8e6b20e120349cd7705ff3d4eb72f0bcb0efda9 Mon Sep 17 00:00:00 2001 From: Nhat Nguyen Date: Mon, 7 Jul 2025 08:03:58 -0700 Subject: [PATCH 5/8] indent --- .../compute/aggregation/X-ValuesAggregator.java.st | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st index 70b2f487cef36..3721c214cfa31 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st @@ -347,12 +347,12 @@ $endif$ selectedCountsSize = adjust; int[] selectedCounts = new int[selectedCountsLen]; for (int id = 0; id < values.size(); id++) { - $if(long||BytesRef||double)$ +$if(long||BytesRef||double)$ int group = (int) values.getKey1(id); - $elseif(float||int)$ +$elseif(float||int)$ long both = values.get(id); int group = (int) (both >>> Float.SIZE); - $endif$ +$endif$ if (group < selectedCounts.length) { selectedCounts[group]--; } From f57c0c33df1f5c9cc0509f2e685559e4ff789288 Mon Sep 17 00:00:00 2001 From: Nhat Nguyen Date: Mon, 7 Jul 2025 08:12:22 -0700 Subject: [PATCH 6/8] inline --- .../compute/aggregation/ValuesBytesRefAggregator.java | 5 ++--- .../compute/aggregation/X-ValuesAggregator.java.st | 5 ++--- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregator.java index c7a21ef1a3075..198de38d4fbe0 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregator.java @@ -322,9 +322,8 @@ private Sorted sortedForOrdinalMerging(GroupingState other) { // no longer need the bytes bytes.close(); bytes = null; - int[] ids = sortedForOrdinalMerging.ids; - for (int i = 0; i < ids.length; i++) { - ids[i] = mappedIds[Math.toIntExact(values.getKey2(ids[i]))]; + for (int i = 0; i < sortedForOrdinalMerging.ids.length; i++) { + sortedForOrdinalMerging.ids[i] = mappedIds[Math.toIntExact(values.getKey2(sortedForOrdinalMerging.ids[i]))]; } } finally { blockFactory.adjustBreaker(-(long) totalValue * Integer.BYTES); diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st index 3721c214cfa31..24df2b5f2c4c7 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st @@ -445,9 +445,8 @@ $if(BytesRef)$ // no longer need the bytes bytes.close(); bytes = null; - int[] ids = sortedForOrdinalMerging.ids; - for (int i = 0; i < ids.length; i++) { - ids[i] = mappedIds[Math.toIntExact(values.getKey2(ids[i]))]; + for (int i = 0; i < sortedForOrdinalMerging.ids.length; i++) { + sortedForOrdinalMerging.ids[i] = mappedIds[Math.toIntExact(values.getKey2(sortedForOrdinalMerging.ids[i]))]; } } finally { blockFactory.adjustBreaker(-(long) totalValue * Integer.BYTES); From 6401294ff5dc919e4d00cf5287ecf4f7fc1d367c Mon Sep 17 00:00:00 2001 From: Nhat Nguyen Date: Mon, 7 Jul 2025 08:31:31 -0700 Subject: [PATCH 7/8] javadoc for ids --- .../compute/aggregation/ValuesBytesRefAggregator.java | 4 ++++ .../compute/aggregation/ValuesDoubleAggregator.java | 2 ++ .../compute/aggregation/ValuesFloatAggregator.java | 2 ++ .../compute/aggregation/ValuesIntAggregator.java | 2 ++ .../compute/aggregation/ValuesLongAggregator.java | 2 ++ .../compute/aggregation/X-ValuesAggregator.java.st | 6 ++++++ 6 files changed, 18 insertions(+) diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregator.java index 198de38d4fbe0..d7afb3ff130b1 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregator.java @@ -143,6 +143,10 @@ public void close() { * Values are collected in a hash. Iterating over them in order (row by row) to build the output, * or merging with other state, can be expensive. To optimize this, we build a sorted structure once, * and then use it to iterate over the values in order. + * + * @param ids positions of the {@link GroupingState#values} to read. + * If built from {@link GroupingState#sortedForOrdinalMerging(GroupingState)}, + * these are ordinals referring to the {@link GroupingState#bytes} in the target state. */ private record Sorted(Releasable releasable, int[] counts, int[] ids) implements Releasable { @Override diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesDoubleAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesDoubleAggregator.java index 1404f93c27f8f..c35f22af53fd6 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesDoubleAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesDoubleAggregator.java @@ -119,6 +119,8 @@ public void close() { * Values are collected in a hash. Iterating over them in order (row by row) to build the output, * or merging with other state, can be expensive. To optimize this, we build a sorted structure once, * and then use it to iterate over the values in order. + * + * @param ids positions of the {@link GroupingState#values} to read. */ private record Sorted(Releasable releasable, int[] counts, int[] ids) implements Releasable { @Override diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesFloatAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesFloatAggregator.java index bb43c75eb39ba..edca3a5faf7b4 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesFloatAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesFloatAggregator.java @@ -118,6 +118,8 @@ public void close() { * Values are collected in a hash. Iterating over them in order (row by row) to build the output, * or merging with other state, can be expensive. To optimize this, we build a sorted structure once, * and then use it to iterate over the values in order. + * + * @param ids positions of the {@link GroupingState#values} to read. */ private record Sorted(Releasable releasable, int[] counts, int[] ids) implements Releasable { @Override diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesIntAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesIntAggregator.java index 4bf7125051efa..53cb057f3166e 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesIntAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesIntAggregator.java @@ -118,6 +118,8 @@ public void close() { * Values are collected in a hash. Iterating over them in order (row by row) to build the output, * or merging with other state, can be expensive. To optimize this, we build a sorted structure once, * and then use it to iterate over the values in order. + * + * @param ids positions of the {@link GroupingState#values} to read. */ private record Sorted(Releasable releasable, int[] counts, int[] ids) implements Releasable { @Override diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesLongAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesLongAggregator.java index 56c6912f58a76..a5df3f694fc1c 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesLongAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesLongAggregator.java @@ -119,6 +119,8 @@ public void close() { * Values are collected in a hash. Iterating over them in order (row by row) to build the output, * or merging with other state, can be expensive. To optimize this, we build a sorted structure once, * and then use it to iterate over the values in order. + * + * @param ids positions of the {@link GroupingState#values} to read. */ private record Sorted(Releasable releasable, int[] counts, int[] ids) implements Releasable { @Override diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st index 24df2b5f2c4c7..da821935adbbe 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st @@ -217,6 +217,12 @@ $endif$ * Values are collected in a hash. Iterating over them in order (row by row) to build the output, * or merging with other state, can be expensive. To optimize this, we build a sorted structure once, * and then use it to iterate over the values in order. + * + * @param ids positions of the {@link GroupingState#values} to read. +$if(BytesRef)$ + * If built from {@link GroupingState#sortedForOrdinalMerging(GroupingState)}, + * these are ordinals referring to the {@link GroupingState#bytes} in the target state. +$endif$ */ private record Sorted(Releasable releasable, int[] counts, int[] ids) implements Releasable { @Override From 46e6999ef1ba19c90b1772b434c70e2b088795ce Mon Sep 17 00:00:00 2001 From: Nhat Nguyen Date: Mon, 7 Jul 2025 12:07:57 -0700 Subject: [PATCH 8/8] fix positions --- .../compute/aggregation/ValuesBytesRefAggregator.java | 3 +++ .../compute/aggregation/ValuesDoubleAggregator.java | 3 +++ .../compute/aggregation/ValuesFloatAggregator.java | 3 +++ .../elasticsearch/compute/aggregation/ValuesIntAggregator.java | 3 +++ .../compute/aggregation/ValuesLongAggregator.java | 3 +++ .../compute/aggregation/X-ValuesAggregator.java.st | 3 +++ 6 files changed, 18 insertions(+) diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregator.java index d7afb3ff130b1..992fff70f90a0 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregator.java @@ -91,6 +91,9 @@ public static void combineIntermediate(GroupingState state, int groupId, BytesRe public static void combineStates(GroupingState current, int currentGroupId, GroupingState state, int statePosition) { var sorted = state.sortedForOrdinalMerging(current); + if (statePosition > state.maxGroupId) { + return; + } var start = statePosition > 0 ? sorted.counts[statePosition - 1] : 0; var end = sorted.counts[statePosition]; for (int i = start; i < end; i++) { diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesDoubleAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesDoubleAggregator.java index c35f22af53fd6..9aebe85c0cd89 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesDoubleAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesDoubleAggregator.java @@ -68,6 +68,9 @@ public static void combineIntermediate(GroupingState state, int groupId, DoubleB public static void combineStates(GroupingState current, int currentGroupId, GroupingState state, int statePosition) { var sorted = state.sortedForOrdinalMerging(current); + if (statePosition > state.maxGroupId) { + return; + } var start = statePosition > 0 ? sorted.counts[statePosition - 1] : 0; var end = sorted.counts[statePosition]; for (int i = start; i < end; i++) { diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesFloatAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesFloatAggregator.java index edca3a5faf7b4..28c736783122d 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesFloatAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesFloatAggregator.java @@ -67,6 +67,9 @@ public static void combineIntermediate(GroupingState state, int groupId, FloatBl public static void combineStates(GroupingState current, int currentGroupId, GroupingState state, int statePosition) { var sorted = state.sortedForOrdinalMerging(current); + if (statePosition > state.maxGroupId) { + return; + } var start = statePosition > 0 ? sorted.counts[statePosition - 1] : 0; var end = sorted.counts[statePosition]; for (int i = start; i < end; i++) { diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesIntAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesIntAggregator.java index 53cb057f3166e..39dbcd155954c 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesIntAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesIntAggregator.java @@ -67,6 +67,9 @@ public static void combineIntermediate(GroupingState state, int groupId, IntBloc public static void combineStates(GroupingState current, int currentGroupId, GroupingState state, int statePosition) { var sorted = state.sortedForOrdinalMerging(current); + if (statePosition > state.maxGroupId) { + return; + } var start = statePosition > 0 ? sorted.counts[statePosition - 1] : 0; var end = sorted.counts[statePosition]; for (int i = start; i < end; i++) { diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesLongAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesLongAggregator.java index a5df3f694fc1c..a7d3a8fe539df 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesLongAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesLongAggregator.java @@ -68,6 +68,9 @@ public static void combineIntermediate(GroupingState state, int groupId, LongBlo public static void combineStates(GroupingState current, int currentGroupId, GroupingState state, int statePosition) { var sorted = state.sortedForOrdinalMerging(current); + if (statePosition > state.maxGroupId) { + return; + } var start = statePosition > 0 ? sorted.counts[statePosition - 1] : 0; var end = sorted.counts[statePosition]; for (int i = start; i < end; i++) { diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st index da821935adbbe..6a6adca143e94 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st @@ -130,6 +130,9 @@ $endif$ public static void combineStates(GroupingState current, int currentGroupId, GroupingState state, int statePosition) { var sorted = state.sortedForOrdinalMerging(current); + if (statePosition > state.maxGroupId) { + return; + } var start = statePosition > 0 ? sorted.counts[statePosition - 1] : 0; var end = sorted.counts[statePosition]; for (int i = start; i < end; i++) {