From 6f96c5b24577c5f72e7eb15f380a016253d5b7ba Mon Sep 17 00:00:00 2001 From: Nhat Nguyen Date: Mon, 7 Jul 2025 14:51:29 -0700 Subject: [PATCH] Avoid O(N^2) in VALUES with ordinals grouping (#130576) Using the VALUES aggregator with ordinals grouping led to accidental quadratic complexity. Queries like FROM .. | STATS ... VALUES(field) ... BY keyword-field are affected by this performance issue. This change caches a sorted structure - previously used to fix a similar O(N^2) problem when emitting the output block - during the merging phase of the OrdinalGroupingOperator. # Conflicts: # x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregator.java # x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesDoubleAggregator.java # x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesFloatAggregator.java # x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesIntAggregator.java # x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesLongAggregator.java # x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregators.java # x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st --- .../operator/ValuesAggregatorBenchmark.java | 33 ++- docs/changelog/130576.yaml | 5 + .../aggregation/ValuesBytesRefAggregator.java | 158 ++++++++--- .../aggregation/ValuesDoubleAggregator.java | 126 ++++++--- .../aggregation/ValuesFloatAggregator.java | 138 +++++++--- .../aggregation/ValuesIntAggregator.java | 138 +++++++--- .../aggregation/ValuesLongAggregator.java | 126 ++++++--- ...uesBytesRefGroupingAggregatorFunction.java | 2 +- ...aluesDoubleGroupingAggregatorFunction.java | 2 +- ...ValuesFloatGroupingAggregatorFunction.java | 2 +- .../ValuesIntGroupingAggregatorFunction.java | 2 +- .../ValuesLongGroupingAggregatorFunction.java | 2 +- .../aggregation/X-ValuesAggregator.java.st | 255 +++++++++++------- .../operator/HashAggregationOperator.java | 2 +- 14 files changed, 678 insertions(+), 313 deletions(-) create mode 100644 docs/changelog/130576.yaml diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/ValuesAggregatorBenchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/ValuesAggregatorBenchmark.java index 280e6274d84de..2522df371fc38 100644 --- a/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/ValuesAggregatorBenchmark.java +++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/ValuesAggregatorBenchmark.java @@ -85,7 +85,8 @@ public class ValuesAggregatorBenchmark { try { for (String groups : ValuesAggregatorBenchmark.class.getField("groups").getAnnotationsByType(Param.class)[0].value()) { for (String dataType : ValuesAggregatorBenchmark.class.getField("dataType").getAnnotationsByType(Param.class)[0].value()) { - run(Integer.parseInt(groups), dataType, 10); + run(Integer.parseInt(groups), dataType, 10, 0); + run(Integer.parseInt(groups), dataType, 10, 1); } } } catch (NoSuchFieldException e) { @@ -103,7 +104,10 @@ public class ValuesAggregatorBenchmark { @Param({ BYTES_REF, INT, LONG }) public String dataType; - private static Operator operator(DriverContext driverContext, int groups, String dataType) { + @Param({ "0", "1" }) + public int numOrdinalMerges; + + private static Operator operator(DriverContext driverContext, int groups, String dataType, int numOrdinalMerges) { if (groups == 1) { return new AggregationOperator( List.of(supplier(dataType).aggregatorFactory(AggregatorMode.SINGLE, List.of(0)).apply(driverContext)), @@ -115,7 +119,24 @@ private static Operator operator(DriverContext driverContext, int groups, String List.of(supplier(dataType).groupingAggregatorFactory(AggregatorMode.SINGLE, List.of(1))), () -> BlockHash.build(groupSpec, driverContext.blockFactory(), 16 * 1024, false), driverContext - ); + ) { + @Override + public Page getOutput() { + mergeOrdinal(); + return super.getOutput(); + } + + // simulate OrdinalsGroupingOperator + void mergeOrdinal() { + var merged = supplier(dataType).groupingAggregatorFactory(AggregatorMode.SINGLE, List.of(1)).apply(driverContext); + for (int i = 0; i < numOrdinalMerges; i++) { + for (int p = 0; p < groups; p++) { + merged.addIntermediateRow(p, aggregators.getFirst(), p); + } + } + aggregators.set(0, merged); + } + }; } private static AggregatorFunctionSupplier supplier(String dataType) { @@ -314,12 +335,12 @@ private static Block groupingBlock(int groups) { @Benchmark public void run() { - run(groups, dataType, OP_COUNT); + run(groups, dataType, OP_COUNT, numOrdinalMerges); } - private static void run(int groups, String dataType, int opCount) { + private static void run(int groups, String dataType, int opCount, int numOrdinalMerges) { DriverContext driverContext = driverContext(); - try (Operator operator = operator(driverContext, groups, dataType)) { + try (Operator operator = operator(driverContext, groups, dataType, numOrdinalMerges)) { Page page = page(groups, dataType); for (int i = 0; i < opCount; i++) { operator.addInput(page.shallowCopy()); diff --git a/docs/changelog/130576.yaml b/docs/changelog/130576.yaml new file mode 100644 index 0000000000000..29947d9259c6b --- /dev/null +++ b/docs/changelog/130576.yaml @@ -0,0 +1,5 @@ +pr: 130576 +summary: Avoid O(N^2) in VALUES with ordinals grouping +area: ES|QL +type: bug +issues: [] diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregator.java index f326492664fb8..e40e2c5db0cf0 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregator.java @@ -21,6 +21,7 @@ import org.elasticsearch.compute.data.BytesRefBlock; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.core.Releasable; import org.elasticsearch.core.Releasables; /** @@ -52,12 +53,12 @@ public static Block evaluateFinal(SingleState state, DriverContext driverContext return state.toBlock(driverContext.blockFactory()); } - public static GroupingState initGrouping(BigArrays bigArrays) { - return new GroupingState(bigArrays); + public static GroupingState initGrouping(DriverContext driverContext) { + return new GroupingState(driverContext); } public static void combine(GroupingState state, int groupId, BytesRef v) { - state.values.add(groupId, BlockHash.hashOrdToGroup(state.bytes.add(v))); + state.addValue(groupId, v); } public static void combineIntermediate(GroupingState state, int groupId, BytesRefBlock values, int valuesPosition) { @@ -65,17 +66,20 @@ public static void combineIntermediate(GroupingState state, int groupId, BytesRe int start = values.getFirstValueIndex(valuesPosition); int end = start + values.getValueCount(valuesPosition); for (int i = start; i < end; i++) { - combine(state, groupId, values.getBytesRef(i, scratch)); + state.addValue(groupId, values.getBytesRef(i, scratch)); } } public static void combineStates(GroupingState current, int currentGroupId, GroupingState state, int statePosition) { - BytesRef scratch = new BytesRef(); - for (int id = 0; id < state.values.size(); id++) { - if (state.values.getKey1(id) == statePosition) { - long value = state.values.getKey2(id); - combine(current, currentGroupId, state.bytes.get(value, scratch)); - } + var sorted = state.sortedForOrdinalMerging(current); + if (statePosition > state.maxGroupId) { + return; + } + var start = statePosition > 0 ? sorted.counts[statePosition - 1] : 0; + var end = sorted.counts[statePosition]; + for (int i = start; i < end; i++) { + int id = sorted.ids[i]; + current.addValueOrdinal(currentGroupId, id); } } @@ -119,6 +123,22 @@ public void close() { } } + /** + * Values are collected in a hash. Iterating over them in order (row by row) to build the output, + * or merging with other state, can be expensive. To optimize this, we build a sorted structure once, + * and then use it to iterate over the values in order. + * + * @param ids positions of the {@link GroupingState#values} to read. + * If built from {@link GroupingState#sortedForOrdinalMerging(GroupingState)}, + * these are ordinals referring to the {@link GroupingState#bytes} in the target state. + */ + private record Sorted(Releasable releasable, int[] counts, int[] ids) implements Releasable { + @Override + public void close() { + releasable.close(); + } + } + /** * State for a grouped {@code VALUES} aggregation. This implementation * emphasizes collect-time performance over the performance of rendering @@ -127,15 +147,20 @@ public void close() { * collector operation. But at least it's fairly simple. */ public static class GroupingState implements GroupingAggregatorState { + private int maxGroupId = -1; + private final BlockFactory blockFactory; private final LongLongHash values; - private final BytesRefHash bytes; + private BytesRefHash bytes; + + private Sorted sortedForOrdinalMerging = null; - private GroupingState(BigArrays bigArrays) { + private GroupingState(DriverContext driverContext) { + this.blockFactory = driverContext.blockFactory(); LongLongHash _values = null; BytesRefHash _bytes = null; try { - _values = new LongLongHash(1, bigArrays); - _bytes = new BytesRefHash(1, bigArrays); + _values = new LongLongHash(1, driverContext.bigArrays()); + _bytes = new BytesRefHash(1, driverContext.bigArrays()); values = _values; bytes = _bytes; @@ -152,6 +177,16 @@ public void toIntermediate(Block[] blocks, int offset, IntVector selected, Drive blocks[offset] = toBlock(driverContext.blockFactory(), selected); } + void addValueOrdinal(int groupId, long valueOrdinal) { + values.add(groupId, valueOrdinal); + maxGroupId = Math.max(maxGroupId, groupId); + } + + void addValue(int groupId, BytesRef v) { + values.add(groupId, BlockHash.hashOrdToGroup(bytes.add(v))); + maxGroupId = Math.max(maxGroupId, groupId); + } + /** * Builds a {@link Block} with the unique values collected for the {@code #selected} * groups. This is the implementation of the final and intermediate results of the agg. @@ -161,8 +196,15 @@ Block toBlock(BlockFactory blockFactory, IntVector selected) { return blockFactory.newConstantNullBlock(selected.getPositionCount()); } + try (var sorted = buildSorted(selected)) { + return buildOutputBlock(blockFactory, selected, sorted.counts, sorted.ids); + } + } + + private Sorted buildSorted(IntVector selected) { long selectedCountsSize = 0; long idsSize = 0; + Sorted sorted = null; try { /* * Get a count of all groups less than the maximum selected group. Count @@ -237,40 +279,74 @@ Block toBlock(BlockFactory blockFactory, IntVector selected) { ids[selectedCounts[group]++] = id; } } + final long totalMemoryUsed = selectedCountsSize + idsSize; + sorted = new Sorted(() -> blockFactory.adjustBreaker(-totalMemoryUsed), selectedCounts, ids); + return sorted; + } finally { + if (sorted == null) { + blockFactory.adjustBreaker(-selectedCountsSize - idsSize); + } + } + } - /* - * Insert the ids in order. - */ - BytesRef scratch = new BytesRef(); - try (BytesRefBlock.Builder builder = blockFactory.newBytesRefBlockBuilder(selected.getPositionCount())) { - int start = 0; - for (int s = 0; s < selected.getPositionCount(); s++) { - int group = selected.getInt(s); - int end = selectedCounts[group]; - int count = end - start; - switch (count) { - case 0 -> builder.appendNull(); - case 1 -> append(builder, ids[start], scratch); - default -> { - builder.beginPositionEntry(); - for (int i = start; i < end; i++) { - append(builder, ids[i], scratch); - } - builder.endPositionEntry(); + private Sorted sortedForOrdinalMerging(GroupingState other) { + if (sortedForOrdinalMerging == null) { + try (var selected = IntVector.range(0, maxGroupId + 1, blockFactory)) { + sortedForOrdinalMerging = buildSorted(selected); + // hash all the bytes to the destination to avoid hashing them multiple times + BytesRef scratch = new BytesRef(); + final int totalValue = Math.toIntExact(bytes.size()); + blockFactory.adjustBreaker((long) totalValue * Integer.BYTES); + try { + final int[] mappedIds = new int[totalValue]; + for (int i = 0; i < totalValue; i++) { + BytesRef v = bytes.get(i, scratch); + mappedIds[i] = Math.toIntExact(BlockHash.hashOrdToGroup(other.bytes.add(v))); + } + // no longer need the bytes + bytes.close(); + bytes = null; + for (int i = 0; i < sortedForOrdinalMerging.ids.length; i++) { + sortedForOrdinalMerging.ids[i] = mappedIds[Math.toIntExact(values.getKey2(sortedForOrdinalMerging.ids[i]))]; + } + } finally { + blockFactory.adjustBreaker(-(long) totalValue * Integer.BYTES); + } + } + } + return sortedForOrdinalMerging; + } + + Block buildOutputBlock(BlockFactory blockFactory, IntVector selected, int[] selectedCounts, int[] ids) { + /* + * Insert the ids in order. + */ + BytesRef scratch = new BytesRef(); + try (BytesRefBlock.Builder builder = blockFactory.newBytesRefBlockBuilder(selected.getPositionCount())) { + int start = 0; + for (int s = 0; s < selected.getPositionCount(); s++) { + int group = selected.getInt(s); + int end = selectedCounts[group]; + int count = end - start; + switch (count) { + case 0 -> builder.appendNull(); + case 1 -> builder.appendBytesRef(getValue(ids[start], scratch)); + default -> { + builder.beginPositionEntry(); + for (int i = start; i < end; i++) { + builder.appendBytesRef(getValue(ids[i], scratch)); } + builder.endPositionEntry(); } - start = end; } - return builder.build(); + start = end; } - } finally { - blockFactory.adjustBreaker(-selectedCountsSize - idsSize); + return builder.build(); } } - private void append(BytesRefBlock.Builder builder, int id, BytesRef scratch) { - BytesRef value = bytes.get(values.getKey2(id), scratch); - builder.appendBytesRef(value); + BytesRef getValue(int valueId, BytesRef scratch) { + return bytes.get(values.getKey2(valueId), scratch); } @Override @@ -280,7 +356,7 @@ public void enableGroupIdTracking(SeenGroupIds seen) { @Override public void close() { - Releasables.closeExpectNoException(values, bytes); + Releasables.closeExpectNoException(values, bytes, sortedForOrdinalMerging); } } } diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesDoubleAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesDoubleAggregator.java index 752cd53a140f7..9aebe85c0cd89 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesDoubleAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesDoubleAggregator.java @@ -19,6 +19,8 @@ import org.elasticsearch.compute.data.DoubleBlock; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.core.Releasable; +import org.elasticsearch.core.Releasables; /** * Aggregates field values for double. @@ -48,28 +50,32 @@ public static Block evaluateFinal(SingleState state, DriverContext driverContext return state.toBlock(driverContext.blockFactory()); } - public static GroupingState initGrouping(BigArrays bigArrays) { - return new GroupingState(bigArrays); + public static GroupingState initGrouping(DriverContext driverContext) { + return new GroupingState(driverContext); } public static void combine(GroupingState state, int groupId, double v) { - state.values.add(groupId, Double.doubleToLongBits(v)); + state.addValue(groupId, v); } public static void combineIntermediate(GroupingState state, int groupId, DoubleBlock values, int valuesPosition) { int start = values.getFirstValueIndex(valuesPosition); int end = start + values.getValueCount(valuesPosition); for (int i = start; i < end; i++) { - combine(state, groupId, values.getDouble(i)); + state.addValue(groupId, values.getDouble(i)); } } public static void combineStates(GroupingState current, int currentGroupId, GroupingState state, int statePosition) { - for (int id = 0; id < state.values.size(); id++) { - if (state.values.getKey1(id) == statePosition) { - double value = Double.longBitsToDouble(state.values.getKey2(id)); - combine(current, currentGroupId, value); - } + var sorted = state.sortedForOrdinalMerging(current); + if (statePosition > state.maxGroupId) { + return; + } + var start = statePosition > 0 ? sorted.counts[statePosition - 1] : 0; + var end = sorted.counts[statePosition]; + for (int i = start; i < end; i++) { + int id = sorted.ids[i]; + current.addValue(currentGroupId, state.getValue(id)); } } @@ -112,6 +118,20 @@ public void close() { } } + /** + * Values are collected in a hash. Iterating over them in order (row by row) to build the output, + * or merging with other state, can be expensive. To optimize this, we build a sorted structure once, + * and then use it to iterate over the values in order. + * + * @param ids positions of the {@link GroupingState#values} to read. + */ + private record Sorted(Releasable releasable, int[] counts, int[] ids) implements Releasable { + @Override + public void close() { + releasable.close(); + } + } + /** * State for a grouped {@code VALUES} aggregation. This implementation * emphasizes collect-time performance over the performance of rendering @@ -120,10 +140,15 @@ public void close() { * collector operation. But at least it's fairly simple. */ public static class GroupingState implements GroupingAggregatorState { + private int maxGroupId = -1; + private final BlockFactory blockFactory; private final LongLongHash values; - private GroupingState(BigArrays bigArrays) { - values = new LongLongHash(1, bigArrays); + private Sorted sortedForOrdinalMerging = null; + + private GroupingState(DriverContext driverContext) { + this.blockFactory = driverContext.blockFactory(); + values = new LongLongHash(1, driverContext.bigArrays()); } @Override @@ -131,6 +156,11 @@ public void toIntermediate(Block[] blocks, int offset, IntVector selected, Drive blocks[offset] = toBlock(driverContext.blockFactory(), selected); } + void addValue(int groupId, double v) { + values.add(groupId, Double.doubleToLongBits(v)); + maxGroupId = Math.max(maxGroupId, groupId); + } + /** * Builds a {@link Block} with the unique values collected for the {@code #selected} * groups. This is the implementation of the final and intermediate results of the agg. @@ -140,8 +170,15 @@ Block toBlock(BlockFactory blockFactory, IntVector selected) { return blockFactory.newConstantNullBlock(selected.getPositionCount()); } + try (var sorted = buildSorted(selected)) { + return buildOutputBlock(blockFactory, selected, sorted.counts, sorted.ids); + } + } + + private Sorted buildSorted(IntVector selected) { long selectedCountsSize = 0; long idsSize = 0; + Sorted sorted = null; try { /* * Get a count of all groups less than the maximum selected group. Count @@ -216,39 +253,54 @@ Block toBlock(BlockFactory blockFactory, IntVector selected) { ids[selectedCounts[group]++] = id; } } + final long totalMemoryUsed = selectedCountsSize + idsSize; + sorted = new Sorted(() -> blockFactory.adjustBreaker(-totalMemoryUsed), selectedCounts, ids); + return sorted; + } finally { + if (sorted == null) { + blockFactory.adjustBreaker(-selectedCountsSize - idsSize); + } + } + } - /* - * Insert the ids in order. - */ - try (DoubleBlock.Builder builder = blockFactory.newDoubleBlockBuilder(selected.getPositionCount())) { - int start = 0; - for (int s = 0; s < selected.getPositionCount(); s++) { - int group = selected.getInt(s); - int end = selectedCounts[group]; - int count = end - start; - switch (count) { - case 0 -> builder.appendNull(); - case 1 -> append(builder, ids[start]); - default -> { - builder.beginPositionEntry(); - for (int i = start; i < end; i++) { - append(builder, ids[i]); - } - builder.endPositionEntry(); + private Sorted sortedForOrdinalMerging(GroupingState other) { + if (sortedForOrdinalMerging == null) { + try (var selected = IntVector.range(0, maxGroupId + 1, blockFactory)) { + sortedForOrdinalMerging = buildSorted(selected); + } + } + return sortedForOrdinalMerging; + } + + Block buildOutputBlock(BlockFactory blockFactory, IntVector selected, int[] selectedCounts, int[] ids) { + /* + * Insert the ids in order. + */ + try (DoubleBlock.Builder builder = blockFactory.newDoubleBlockBuilder(selected.getPositionCount())) { + int start = 0; + for (int s = 0; s < selected.getPositionCount(); s++) { + int group = selected.getInt(s); + int end = selectedCounts[group]; + int count = end - start; + switch (count) { + case 0 -> builder.appendNull(); + case 1 -> builder.appendDouble(getValue(ids[start])); + default -> { + builder.beginPositionEntry(); + for (int i = start; i < end; i++) { + builder.appendDouble(getValue(ids[i])); } + builder.endPositionEntry(); } - start = end; } - return builder.build(); + start = end; } - } finally { - blockFactory.adjustBreaker(-selectedCountsSize - idsSize); + return builder.build(); } } - private void append(DoubleBlock.Builder builder, int id) { - double value = Double.longBitsToDouble(values.getKey2(id)); - builder.appendDouble(value); + double getValue(int valueId) { + return Double.longBitsToDouble(values.getKey2(valueId)); } @Override @@ -258,7 +310,7 @@ public void enableGroupIdTracking(SeenGroupIds seen) { @Override public void close() { - values.close(); + Releasables.closeExpectNoException(values, sortedForOrdinalMerging); } } } diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesFloatAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesFloatAggregator.java index 91f1730ab3111..28c736783122d 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesFloatAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesFloatAggregator.java @@ -18,6 +18,8 @@ import org.elasticsearch.compute.data.FloatBlock; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.core.Releasable; +import org.elasticsearch.core.Releasables; /** * Aggregates field values for float. @@ -47,34 +49,32 @@ public static Block evaluateFinal(SingleState state, DriverContext driverContext return state.toBlock(driverContext.blockFactory()); } - public static GroupingState initGrouping(BigArrays bigArrays) { - return new GroupingState(bigArrays); + public static GroupingState initGrouping(DriverContext driverContext) { + return new GroupingState(driverContext); } public static void combine(GroupingState state, int groupId, float v) { - /* - * Encode the groupId and value into a single long - - * the top 32 bits for the group, the bottom 32 for the value. - */ - state.values.add((((long) groupId) << Float.SIZE) | (Float.floatToIntBits(v) & 0xFFFFFFFFL)); + state.addValue(groupId, v); } public static void combineIntermediate(GroupingState state, int groupId, FloatBlock values, int valuesPosition) { int start = values.getFirstValueIndex(valuesPosition); int end = start + values.getValueCount(valuesPosition); for (int i = start; i < end; i++) { - combine(state, groupId, values.getFloat(i)); + state.addValue(groupId, values.getFloat(i)); } } public static void combineStates(GroupingState current, int currentGroupId, GroupingState state, int statePosition) { - for (int id = 0; id < state.values.size(); id++) { - long both = state.values.get(id); - int group = (int) (both >>> Float.SIZE); - if (group == statePosition) { - float value = Float.intBitsToFloat((int) both); - combine(current, currentGroupId, value); - } + var sorted = state.sortedForOrdinalMerging(current); + if (statePosition > state.maxGroupId) { + return; + } + var start = statePosition > 0 ? sorted.counts[statePosition - 1] : 0; + var end = sorted.counts[statePosition]; + for (int i = start; i < end; i++) { + int id = sorted.ids[i]; + current.addValue(currentGroupId, state.getValue(id)); } } @@ -117,6 +117,20 @@ public void close() { } } + /** + * Values are collected in a hash. Iterating over them in order (row by row) to build the output, + * or merging with other state, can be expensive. To optimize this, we build a sorted structure once, + * and then use it to iterate over the values in order. + * + * @param ids positions of the {@link GroupingState#values} to read. + */ + private record Sorted(Releasable releasable, int[] counts, int[] ids) implements Releasable { + @Override + public void close() { + releasable.close(); + } + } + /** * State for a grouped {@code VALUES} aggregation. This implementation * emphasizes collect-time performance over the performance of rendering @@ -125,10 +139,15 @@ public void close() { * collector operation. But at least it's fairly simple. */ public static class GroupingState implements GroupingAggregatorState { + private int maxGroupId = -1; + private final BlockFactory blockFactory; private final LongHash values; - private GroupingState(BigArrays bigArrays) { - values = new LongHash(1, bigArrays); + private Sorted sortedForOrdinalMerging = null; + + private GroupingState(DriverContext driverContext) { + this.blockFactory = driverContext.blockFactory(); + values = new LongHash(1, driverContext.bigArrays()); } @Override @@ -136,6 +155,15 @@ public void toIntermediate(Block[] blocks, int offset, IntVector selected, Drive blocks[offset] = toBlock(driverContext.blockFactory(), selected); } + void addValue(int groupId, float v) { + /* + * Encode the groupId and value into a single long - + * the top 32 bits for the group, the bottom 32 for the value. + */ + values.add((((long) groupId) << Float.SIZE) | (Float.floatToIntBits(v) & 0xFFFFFFFFL)); + maxGroupId = Math.max(maxGroupId, groupId); + } + /** * Builds a {@link Block} with the unique values collected for the {@code #selected} * groups. This is the implementation of the final and intermediate results of the agg. @@ -145,8 +173,15 @@ Block toBlock(BlockFactory blockFactory, IntVector selected) { return blockFactory.newConstantNullBlock(selected.getPositionCount()); } + try (var sorted = buildSorted(selected)) { + return buildOutputBlock(blockFactory, selected, sorted.counts, sorted.ids); + } + } + + private Sorted buildSorted(IntVector selected) { long selectedCountsSize = 0; long idsSize = 0; + Sorted sorted = null; try { /* * Get a count of all groups less than the maximum selected group. Count @@ -223,40 +258,55 @@ Block toBlock(BlockFactory blockFactory, IntVector selected) { ids[selectedCounts[group]++] = id; } } + final long totalMemoryUsed = selectedCountsSize + idsSize; + sorted = new Sorted(() -> blockFactory.adjustBreaker(-totalMemoryUsed), selectedCounts, ids); + return sorted; + } finally { + if (sorted == null) { + blockFactory.adjustBreaker(-selectedCountsSize - idsSize); + } + } + } - /* - * Insert the ids in order. - */ - try (FloatBlock.Builder builder = blockFactory.newFloatBlockBuilder(selected.getPositionCount())) { - int start = 0; - for (int s = 0; s < selected.getPositionCount(); s++) { - int group = selected.getInt(s); - int end = selectedCounts[group]; - int count = end - start; - switch (count) { - case 0 -> builder.appendNull(); - case 1 -> append(builder, ids[start]); - default -> { - builder.beginPositionEntry(); - for (int i = start; i < end; i++) { - append(builder, ids[i]); - } - builder.endPositionEntry(); + private Sorted sortedForOrdinalMerging(GroupingState other) { + if (sortedForOrdinalMerging == null) { + try (var selected = IntVector.range(0, maxGroupId + 1, blockFactory)) { + sortedForOrdinalMerging = buildSorted(selected); + } + } + return sortedForOrdinalMerging; + } + + Block buildOutputBlock(BlockFactory blockFactory, IntVector selected, int[] selectedCounts, int[] ids) { + /* + * Insert the ids in order. + */ + try (FloatBlock.Builder builder = blockFactory.newFloatBlockBuilder(selected.getPositionCount())) { + int start = 0; + for (int s = 0; s < selected.getPositionCount(); s++) { + int group = selected.getInt(s); + int end = selectedCounts[group]; + int count = end - start; + switch (count) { + case 0 -> builder.appendNull(); + case 1 -> builder.appendFloat(getValue(ids[start])); + default -> { + builder.beginPositionEntry(); + for (int i = start; i < end; i++) { + builder.appendFloat(getValue(ids[i])); } + builder.endPositionEntry(); } - start = end; } - return builder.build(); + start = end; } - } finally { - blockFactory.adjustBreaker(-selectedCountsSize - idsSize); + return builder.build(); } } - private void append(FloatBlock.Builder builder, int id) { - long both = values.get(id); - float value = Float.intBitsToFloat((int) both); - builder.appendFloat(value); + float getValue(int valueId) { + long both = values.get(valueId); + return Float.intBitsToFloat((int) both); } @Override @@ -266,7 +316,7 @@ public void enableGroupIdTracking(SeenGroupIds seen) { @Override public void close() { - values.close(); + Releasables.closeExpectNoException(values, sortedForOrdinalMerging); } } } diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesIntAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesIntAggregator.java index c4f595d938aa9..39dbcd155954c 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesIntAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesIntAggregator.java @@ -18,6 +18,8 @@ import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.core.Releasable; +import org.elasticsearch.core.Releasables; /** * Aggregates field values for int. @@ -47,34 +49,32 @@ public static Block evaluateFinal(SingleState state, DriverContext driverContext return state.toBlock(driverContext.blockFactory()); } - public static GroupingState initGrouping(BigArrays bigArrays) { - return new GroupingState(bigArrays); + public static GroupingState initGrouping(DriverContext driverContext) { + return new GroupingState(driverContext); } public static void combine(GroupingState state, int groupId, int v) { - /* - * Encode the groupId and value into a single long - - * the top 32 bits for the group, the bottom 32 for the value. - */ - state.values.add((((long) groupId) << Integer.SIZE) | (v & 0xFFFFFFFFL)); + state.addValue(groupId, v); } public static void combineIntermediate(GroupingState state, int groupId, IntBlock values, int valuesPosition) { int start = values.getFirstValueIndex(valuesPosition); int end = start + values.getValueCount(valuesPosition); for (int i = start; i < end; i++) { - combine(state, groupId, values.getInt(i)); + state.addValue(groupId, values.getInt(i)); } } public static void combineStates(GroupingState current, int currentGroupId, GroupingState state, int statePosition) { - for (int id = 0; id < state.values.size(); id++) { - long both = state.values.get(id); - int group = (int) (both >>> Integer.SIZE); - if (group == statePosition) { - int value = (int) both; - combine(current, currentGroupId, value); - } + var sorted = state.sortedForOrdinalMerging(current); + if (statePosition > state.maxGroupId) { + return; + } + var start = statePosition > 0 ? sorted.counts[statePosition - 1] : 0; + var end = sorted.counts[statePosition]; + for (int i = start; i < end; i++) { + int id = sorted.ids[i]; + current.addValue(currentGroupId, state.getValue(id)); } } @@ -117,6 +117,20 @@ public void close() { } } + /** + * Values are collected in a hash. Iterating over them in order (row by row) to build the output, + * or merging with other state, can be expensive. To optimize this, we build a sorted structure once, + * and then use it to iterate over the values in order. + * + * @param ids positions of the {@link GroupingState#values} to read. + */ + private record Sorted(Releasable releasable, int[] counts, int[] ids) implements Releasable { + @Override + public void close() { + releasable.close(); + } + } + /** * State for a grouped {@code VALUES} aggregation. This implementation * emphasizes collect-time performance over the performance of rendering @@ -125,10 +139,15 @@ public void close() { * collector operation. But at least it's fairly simple. */ public static class GroupingState implements GroupingAggregatorState { + private int maxGroupId = -1; + private final BlockFactory blockFactory; private final LongHash values; - private GroupingState(BigArrays bigArrays) { - values = new LongHash(1, bigArrays); + private Sorted sortedForOrdinalMerging = null; + + private GroupingState(DriverContext driverContext) { + this.blockFactory = driverContext.blockFactory(); + values = new LongHash(1, driverContext.bigArrays()); } @Override @@ -136,6 +155,15 @@ public void toIntermediate(Block[] blocks, int offset, IntVector selected, Drive blocks[offset] = toBlock(driverContext.blockFactory(), selected); } + void addValue(int groupId, int v) { + /* + * Encode the groupId and value into a single long - + * the top 32 bits for the group, the bottom 32 for the value. + */ + values.add((((long) groupId) << Integer.SIZE) | (v & 0xFFFFFFFFL)); + maxGroupId = Math.max(maxGroupId, groupId); + } + /** * Builds a {@link Block} with the unique values collected for the {@code #selected} * groups. This is the implementation of the final and intermediate results of the agg. @@ -145,8 +173,15 @@ Block toBlock(BlockFactory blockFactory, IntVector selected) { return blockFactory.newConstantNullBlock(selected.getPositionCount()); } + try (var sorted = buildSorted(selected)) { + return buildOutputBlock(blockFactory, selected, sorted.counts, sorted.ids); + } + } + + private Sorted buildSorted(IntVector selected) { long selectedCountsSize = 0; long idsSize = 0; + Sorted sorted = null; try { /* * Get a count of all groups less than the maximum selected group. Count @@ -223,40 +258,55 @@ Block toBlock(BlockFactory blockFactory, IntVector selected) { ids[selectedCounts[group]++] = id; } } + final long totalMemoryUsed = selectedCountsSize + idsSize; + sorted = new Sorted(() -> blockFactory.adjustBreaker(-totalMemoryUsed), selectedCounts, ids); + return sorted; + } finally { + if (sorted == null) { + blockFactory.adjustBreaker(-selectedCountsSize - idsSize); + } + } + } - /* - * Insert the ids in order. - */ - try (IntBlock.Builder builder = blockFactory.newIntBlockBuilder(selected.getPositionCount())) { - int start = 0; - for (int s = 0; s < selected.getPositionCount(); s++) { - int group = selected.getInt(s); - int end = selectedCounts[group]; - int count = end - start; - switch (count) { - case 0 -> builder.appendNull(); - case 1 -> append(builder, ids[start]); - default -> { - builder.beginPositionEntry(); - for (int i = start; i < end; i++) { - append(builder, ids[i]); - } - builder.endPositionEntry(); + private Sorted sortedForOrdinalMerging(GroupingState other) { + if (sortedForOrdinalMerging == null) { + try (var selected = IntVector.range(0, maxGroupId + 1, blockFactory)) { + sortedForOrdinalMerging = buildSorted(selected); + } + } + return sortedForOrdinalMerging; + } + + Block buildOutputBlock(BlockFactory blockFactory, IntVector selected, int[] selectedCounts, int[] ids) { + /* + * Insert the ids in order. + */ + try (IntBlock.Builder builder = blockFactory.newIntBlockBuilder(selected.getPositionCount())) { + int start = 0; + for (int s = 0; s < selected.getPositionCount(); s++) { + int group = selected.getInt(s); + int end = selectedCounts[group]; + int count = end - start; + switch (count) { + case 0 -> builder.appendNull(); + case 1 -> builder.appendInt(getValue(ids[start])); + default -> { + builder.beginPositionEntry(); + for (int i = start; i < end; i++) { + builder.appendInt(getValue(ids[i])); } + builder.endPositionEntry(); } - start = end; } - return builder.build(); + start = end; } - } finally { - blockFactory.adjustBreaker(-selectedCountsSize - idsSize); + return builder.build(); } } - private void append(IntBlock.Builder builder, int id) { - long both = values.get(id); - int value = (int) both; - builder.appendInt(value); + int getValue(int valueId) { + long both = values.get(valueId); + return (int) both; } @Override @@ -266,7 +316,7 @@ public void enableGroupIdTracking(SeenGroupIds seen) { @Override public void close() { - values.close(); + Releasables.closeExpectNoException(values, sortedForOrdinalMerging); } } } diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesLongAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesLongAggregator.java index 8ae5da509151e..a7d3a8fe539df 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesLongAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesLongAggregator.java @@ -19,6 +19,8 @@ import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.data.LongBlock; import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.core.Releasable; +import org.elasticsearch.core.Releasables; /** * Aggregates field values for long. @@ -48,28 +50,32 @@ public static Block evaluateFinal(SingleState state, DriverContext driverContext return state.toBlock(driverContext.blockFactory()); } - public static GroupingState initGrouping(BigArrays bigArrays) { - return new GroupingState(bigArrays); + public static GroupingState initGrouping(DriverContext driverContext) { + return new GroupingState(driverContext); } public static void combine(GroupingState state, int groupId, long v) { - state.values.add(groupId, v); + state.addValue(groupId, v); } public static void combineIntermediate(GroupingState state, int groupId, LongBlock values, int valuesPosition) { int start = values.getFirstValueIndex(valuesPosition); int end = start + values.getValueCount(valuesPosition); for (int i = start; i < end; i++) { - combine(state, groupId, values.getLong(i)); + state.addValue(groupId, values.getLong(i)); } } public static void combineStates(GroupingState current, int currentGroupId, GroupingState state, int statePosition) { - for (int id = 0; id < state.values.size(); id++) { - if (state.values.getKey1(id) == statePosition) { - long value = state.values.getKey2(id); - combine(current, currentGroupId, value); - } + var sorted = state.sortedForOrdinalMerging(current); + if (statePosition > state.maxGroupId) { + return; + } + var start = statePosition > 0 ? sorted.counts[statePosition - 1] : 0; + var end = sorted.counts[statePosition]; + for (int i = start; i < end; i++) { + int id = sorted.ids[i]; + current.addValue(currentGroupId, state.getValue(id)); } } @@ -112,6 +118,20 @@ public void close() { } } + /** + * Values are collected in a hash. Iterating over them in order (row by row) to build the output, + * or merging with other state, can be expensive. To optimize this, we build a sorted structure once, + * and then use it to iterate over the values in order. + * + * @param ids positions of the {@link GroupingState#values} to read. + */ + private record Sorted(Releasable releasable, int[] counts, int[] ids) implements Releasable { + @Override + public void close() { + releasable.close(); + } + } + /** * State for a grouped {@code VALUES} aggregation. This implementation * emphasizes collect-time performance over the performance of rendering @@ -120,10 +140,15 @@ public void close() { * collector operation. But at least it's fairly simple. */ public static class GroupingState implements GroupingAggregatorState { + private int maxGroupId = -1; + private final BlockFactory blockFactory; private final LongLongHash values; - private GroupingState(BigArrays bigArrays) { - values = new LongLongHash(1, bigArrays); + private Sorted sortedForOrdinalMerging = null; + + private GroupingState(DriverContext driverContext) { + this.blockFactory = driverContext.blockFactory(); + values = new LongLongHash(1, driverContext.bigArrays()); } @Override @@ -131,6 +156,11 @@ public void toIntermediate(Block[] blocks, int offset, IntVector selected, Drive blocks[offset] = toBlock(driverContext.blockFactory(), selected); } + void addValue(int groupId, long v) { + values.add(groupId, v); + maxGroupId = Math.max(maxGroupId, groupId); + } + /** * Builds a {@link Block} with the unique values collected for the {@code #selected} * groups. This is the implementation of the final and intermediate results of the agg. @@ -140,8 +170,15 @@ Block toBlock(BlockFactory blockFactory, IntVector selected) { return blockFactory.newConstantNullBlock(selected.getPositionCount()); } + try (var sorted = buildSorted(selected)) { + return buildOutputBlock(blockFactory, selected, sorted.counts, sorted.ids); + } + } + + private Sorted buildSorted(IntVector selected) { long selectedCountsSize = 0; long idsSize = 0; + Sorted sorted = null; try { /* * Get a count of all groups less than the maximum selected group. Count @@ -216,39 +253,54 @@ Block toBlock(BlockFactory blockFactory, IntVector selected) { ids[selectedCounts[group]++] = id; } } + final long totalMemoryUsed = selectedCountsSize + idsSize; + sorted = new Sorted(() -> blockFactory.adjustBreaker(-totalMemoryUsed), selectedCounts, ids); + return sorted; + } finally { + if (sorted == null) { + blockFactory.adjustBreaker(-selectedCountsSize - idsSize); + } + } + } - /* - * Insert the ids in order. - */ - try (LongBlock.Builder builder = blockFactory.newLongBlockBuilder(selected.getPositionCount())) { - int start = 0; - for (int s = 0; s < selected.getPositionCount(); s++) { - int group = selected.getInt(s); - int end = selectedCounts[group]; - int count = end - start; - switch (count) { - case 0 -> builder.appendNull(); - case 1 -> append(builder, ids[start]); - default -> { - builder.beginPositionEntry(); - for (int i = start; i < end; i++) { - append(builder, ids[i]); - } - builder.endPositionEntry(); + private Sorted sortedForOrdinalMerging(GroupingState other) { + if (sortedForOrdinalMerging == null) { + try (var selected = IntVector.range(0, maxGroupId + 1, blockFactory)) { + sortedForOrdinalMerging = buildSorted(selected); + } + } + return sortedForOrdinalMerging; + } + + Block buildOutputBlock(BlockFactory blockFactory, IntVector selected, int[] selectedCounts, int[] ids) { + /* + * Insert the ids in order. + */ + try (LongBlock.Builder builder = blockFactory.newLongBlockBuilder(selected.getPositionCount())) { + int start = 0; + for (int s = 0; s < selected.getPositionCount(); s++) { + int group = selected.getInt(s); + int end = selectedCounts[group]; + int count = end - start; + switch (count) { + case 0 -> builder.appendNull(); + case 1 -> builder.appendLong(getValue(ids[start])); + default -> { + builder.beginPositionEntry(); + for (int i = start; i < end; i++) { + builder.appendLong(getValue(ids[i])); } + builder.endPositionEntry(); } - start = end; } - return builder.build(); + start = end; } - } finally { - blockFactory.adjustBreaker(-selectedCountsSize - idsSize); + return builder.build(); } } - private void append(LongBlock.Builder builder, int id) { - long value = values.getKey2(id); - builder.appendLong(value); + long getValue(int valueId) { + return values.getKey2(valueId); } @Override @@ -258,7 +310,7 @@ public void enableGroupIdTracking(SeenGroupIds seen) { @Override public void close() { - values.close(); + Releasables.closeExpectNoException(values, sortedForOrdinalMerging); } } } diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesBytesRefGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesBytesRefGroupingAggregatorFunction.java index 6db44ffce8faf..bc7c898d652ce 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesBytesRefGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesBytesRefGroupingAggregatorFunction.java @@ -42,7 +42,7 @@ public ValuesBytesRefGroupingAggregatorFunction(List channels, public static ValuesBytesRefGroupingAggregatorFunction create(List channels, DriverContext driverContext) { - return new ValuesBytesRefGroupingAggregatorFunction(channels, ValuesBytesRefAggregator.initGrouping(driverContext.bigArrays()), driverContext); + return new ValuesBytesRefGroupingAggregatorFunction(channels, ValuesBytesRefAggregator.initGrouping(driverContext), driverContext); } public static List intermediateStateDesc() { diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesDoubleGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesDoubleGroupingAggregatorFunction.java index 893d8fcd2ea5d..c9d4fa9ba0569 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesDoubleGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesDoubleGroupingAggregatorFunction.java @@ -41,7 +41,7 @@ public ValuesDoubleGroupingAggregatorFunction(List channels, public static ValuesDoubleGroupingAggregatorFunction create(List channels, DriverContext driverContext) { - return new ValuesDoubleGroupingAggregatorFunction(channels, ValuesDoubleAggregator.initGrouping(driverContext.bigArrays()), driverContext); + return new ValuesDoubleGroupingAggregatorFunction(channels, ValuesDoubleAggregator.initGrouping(driverContext), driverContext); } public static List intermediateStateDesc() { diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesFloatGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesFloatGroupingAggregatorFunction.java index 8afd75384aa87..b1b71e0e0f34e 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesFloatGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesFloatGroupingAggregatorFunction.java @@ -41,7 +41,7 @@ public ValuesFloatGroupingAggregatorFunction(List channels, public static ValuesFloatGroupingAggregatorFunction create(List channels, DriverContext driverContext) { - return new ValuesFloatGroupingAggregatorFunction(channels, ValuesFloatAggregator.initGrouping(driverContext.bigArrays()), driverContext); + return new ValuesFloatGroupingAggregatorFunction(channels, ValuesFloatAggregator.initGrouping(driverContext), driverContext); } public static List intermediateStateDesc() { diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesIntGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesIntGroupingAggregatorFunction.java index 468320a69fc98..bf4bbd01ba969 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesIntGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesIntGroupingAggregatorFunction.java @@ -39,7 +39,7 @@ public ValuesIntGroupingAggregatorFunction(List channels, public static ValuesIntGroupingAggregatorFunction create(List channels, DriverContext driverContext) { - return new ValuesIntGroupingAggregatorFunction(channels, ValuesIntAggregator.initGrouping(driverContext.bigArrays()), driverContext); + return new ValuesIntGroupingAggregatorFunction(channels, ValuesIntAggregator.initGrouping(driverContext), driverContext); } public static List intermediateStateDesc() { diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesLongGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesLongGroupingAggregatorFunction.java index cc6e7121c5afb..0d57da6efc466 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesLongGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesLongGroupingAggregatorFunction.java @@ -41,7 +41,7 @@ public ValuesLongGroupingAggregatorFunction(List channels, public static ValuesLongGroupingAggregatorFunction create(List channels, DriverContext driverContext) { - return new ValuesLongGroupingAggregatorFunction(channels, ValuesLongAggregator.initGrouping(driverContext.bigArrays()), driverContext); + return new ValuesLongGroupingAggregatorFunction(channels, ValuesLongAggregator.initGrouping(driverContext), driverContext); } public static List intermediateStateDesc() { diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st index 68c6a8640cbd0..f25bfa1f1855e 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st @@ -36,12 +36,9 @@ $if(long)$ import org.elasticsearch.compute.data.LongBlock; $endif$ import org.elasticsearch.compute.operator.DriverContext; -$if(BytesRef)$ +import org.elasticsearch.core.Releasable; import org.elasticsearch.core.Releasables; -$else$ - -$endif$ /** * Aggregates field values for $type$. * This class is generated. Edit @{code X-ValuesAggregator.java.st} instead @@ -83,30 +80,12 @@ $endif$ return state.toBlock(driverContext.blockFactory()); } - public static GroupingState initGrouping(BigArrays bigArrays) { - return new GroupingState(bigArrays); + public static GroupingState initGrouping(DriverContext driverContext) { + return new GroupingState(driverContext); } public static void combine(GroupingState state, int groupId, $type$ v) { -$if(long)$ - state.values.add(groupId, v); -$elseif(double)$ - state.values.add(groupId, Double.doubleToLongBits(v)); -$elseif(BytesRef)$ - state.values.add(groupId, BlockHash.hashOrdToGroup(state.bytes.add(v))); -$elseif(int)$ - /* - * Encode the groupId and value into a single long - - * the top 32 bits for the group, the bottom 32 for the value. - */ - state.values.add((((long) groupId) << Integer.SIZE) | (v & 0xFFFFFFFFL)); -$elseif(float)$ - /* - * Encode the groupId and value into a single long - - * the top 32 bits for the group, the bottom 32 for the value. - */ - state.values.add((((long) groupId) << Float.SIZE) | (Float.floatToIntBits(v) & 0xFFFFFFFFL)); -$endif$ + state.addValue(groupId, v); } public static void combineIntermediate(GroupingState state, int groupId, $Type$Block values, int valuesPosition) { @@ -117,37 +96,27 @@ $endif$ int end = start + values.getValueCount(valuesPosition); for (int i = start; i < end; i++) { $if(BytesRef)$ - combine(state, groupId, values.getBytesRef(i, scratch)); + state.addValue(groupId, values.getBytesRef(i, scratch)); $else$ - combine(state, groupId, values.get$Type$(i)); + state.addValue(groupId, values.get$Type$(i)); $endif$ } } public static void combineStates(GroupingState current, int currentGroupId, GroupingState state, int statePosition) { + var sorted = state.sortedForOrdinalMerging(current); + if (statePosition > state.maxGroupId) { + return; + } + var start = statePosition > 0 ? sorted.counts[statePosition - 1] : 0; + var end = sorted.counts[statePosition]; + for (int i = start; i < end; i++) { + int id = sorted.ids[i]; $if(BytesRef)$ - BytesRef scratch = new BytesRef(); -$endif$ - for (int id = 0; id < state.values.size(); id++) { -$if(long||BytesRef)$ - if (state.values.getKey1(id) == statePosition) { - long value = state.values.getKey2(id); -$elseif(double)$ - if (state.values.getKey1(id) == statePosition) { - double value = Double.longBitsToDouble(state.values.getKey2(id)); -$elseif(int)$ - long both = state.values.get(id); - int group = (int) (both >>> Integer.SIZE); - if (group == statePosition) { - int value = (int) both; -$elseif(float)$ - long both = state.values.get(id); - int group = (int) (both >>> Float.SIZE); - if (group == statePosition) { - float value = Float.intBitsToFloat((int) both); + current.addValueOrdinal(currentGroupId, id); +$else$ + current.addValue(currentGroupId, state.getValue(id)); $endif$ - combine(current, currentGroupId, $if(BytesRef)$state.bytes.get(value, scratch)$else$value$endif$); - } } } @@ -222,6 +191,24 @@ $endif$ } } + /** + * Values are collected in a hash. Iterating over them in order (row by row) to build the output, + * or merging with other state, can be expensive. To optimize this, we build a sorted structure once, + * and then use it to iterate over the values in order. + * + * @param ids positions of the {@link GroupingState#values} to read. +$if(BytesRef)$ + * If built from {@link GroupingState#sortedForOrdinalMerging(GroupingState)}, + * these are ordinals referring to the {@link GroupingState#bytes} in the target state. +$endif$ + */ + private record Sorted(Releasable releasable, int[] counts, int[] ids) implements Releasable { + @Override + public void close() { + releasable.close(); + } + } + /** * State for a grouped {@code VALUES} aggregation. This implementation * emphasizes collect-time performance over the performance of rendering @@ -230,26 +217,31 @@ $endif$ * collector operation. But at least it's fairly simple. */ public static class GroupingState implements GroupingAggregatorState { + private int maxGroupId = -1; + private final BlockFactory blockFactory; $if(long||double)$ private final LongLongHash values; $elseif(BytesRef)$ private final LongLongHash values; - private final BytesRefHash bytes; + private BytesRefHash bytes; $elseif(int||float)$ private final LongHash values; $endif$ - private GroupingState(BigArrays bigArrays) { + private Sorted sortedForOrdinalMerging = null; + + private GroupingState(DriverContext driverContext) { + this.blockFactory = driverContext.blockFactory(); $if(long||double)$ - values = new LongLongHash(1, bigArrays); + values = new LongLongHash(1, driverContext.bigArrays()); $elseif(BytesRef)$ LongLongHash _values = null; BytesRefHash _bytes = null; try { - _values = new LongLongHash(1, bigArrays); - _bytes = new BytesRefHash(1, bigArrays); + _values = new LongLongHash(1, driverContext.bigArrays()); + _bytes = new BytesRefHash(1, driverContext.bigArrays()); values = _values; bytes = _bytes; @@ -260,7 +252,7 @@ $elseif(BytesRef)$ Releasables.closeExpectNoException(_values, _bytes); } $elseif(int||float)$ - values = new LongHash(1, bigArrays); + values = new LongHash(1, driverContext.bigArrays()); $endif$ } @@ -269,6 +261,36 @@ $endif$ blocks[offset] = toBlock(driverContext.blockFactory(), selected); } +$if(BytesRef)$ + void addValueOrdinal(int groupId, long valueOrdinal) { + values.add(groupId, valueOrdinal); + maxGroupId = Math.max(maxGroupId, groupId); + } + +$endif$ + void addValue(int groupId, $type$ v) { +$if(long)$ + values.add(groupId, v); +$elseif(double)$ + values.add(groupId, Double.doubleToLongBits(v)); +$elseif(BytesRef)$ + values.add(groupId, BlockHash.hashOrdToGroup(bytes.add(v))); +$elseif(int)$ + /* + * Encode the groupId and value into a single long - + * the top 32 bits for the group, the bottom 32 for the value. + */ + values.add((((long) groupId) << Integer.SIZE) | (v & 0xFFFFFFFFL)); +$elseif(float)$ + /* + * Encode the groupId and value into a single long - + * the top 32 bits for the group, the bottom 32 for the value. + */ + values.add((((long) groupId) << Float.SIZE) | (Float.floatToIntBits(v) & 0xFFFFFFFFL)); +$endif$ + maxGroupId = Math.max(maxGroupId, groupId); + } + /** * Builds a {@link Block} with the unique values collected for the {@code #selected} * groups. This is the implementation of the final and intermediate results of the agg. @@ -278,8 +300,15 @@ $endif$ return blockFactory.newConstantNullBlock(selected.getPositionCount()); } + try (var sorted = buildSorted(selected)) { + return buildOutputBlock(blockFactory, selected, sorted.counts, sorted.ids); + } + } + + private Sorted buildSorted(IntVector selected) { long selectedCountsSize = 0; long idsSize = 0; + Sorted sorted = null; try { /* * Get a count of all groups less than the maximum selected group. Count @@ -354,72 +383,102 @@ $endif$ idsSize = adjust; int[] ids = new int[total]; for (int id = 0; id < values.size(); id++) { -$if(long||BytesRef||double)$ + $if(long||BytesRef||double)$ int group = (int) values.getKey1(id); -$elseif(float||int)$ + $elseif(float||int)$ long both = values.get(id); int group = (int) (both >>> Float.SIZE); -$endif$ + $endif$ if (group < selectedCounts.length && selectedCounts[group] >= 0) { ids[selectedCounts[group]++] = id; } } + final long totalMemoryUsed = selectedCountsSize + idsSize; + sorted = new Sorted(() -> blockFactory.adjustBreaker(-totalMemoryUsed), selectedCounts, ids); + return sorted; + } finally { + if (sorted == null) { + blockFactory.adjustBreaker(-selectedCountsSize - idsSize); + } + } + } - /* - * Insert the ids in order. - */ + private Sorted sortedForOrdinalMerging(GroupingState other) { + if (sortedForOrdinalMerging == null) { + try (var selected = IntVector.range(0, maxGroupId + 1, blockFactory)) { + sortedForOrdinalMerging = buildSorted(selected); $if(BytesRef)$ - BytesRef scratch = new BytesRef(); -$endif$ - try ($Type$Block.Builder builder = blockFactory.new$Type$BlockBuilder(selected.getPositionCount())) { - int start = 0; - for (int s = 0; s < selected.getPositionCount(); s++) { - int group = selected.getInt(s); - int end = selectedCounts[group]; - int count = end - start; - switch (count) { - case 0 -> builder.appendNull(); - case 1 -> append(builder, ids[start]$if(BytesRef)$, scratch$endif$); - default -> { - builder.beginPositionEntry(); - for (int i = start; i < end; i++) { - append(builder, ids[i]$if(BytesRef)$, scratch$endif$); - } - builder.endPositionEntry(); - } + // hash all the bytes to the destination to avoid hashing them multiple times + BytesRef scratch = new BytesRef(); + final int totalValue = Math.toIntExact(bytes.size()); + blockFactory.adjustBreaker((long) totalValue * Integer.BYTES); + try { + final int[] mappedIds = new int[totalValue]; + for (int i = 0; i < totalValue; i++) { + BytesRef v = bytes.get(i, scratch); + mappedIds[i] = Math.toIntExact(BlockHash.hashOrdToGroup(other.bytes.add(v))); + } + // no longer need the bytes + bytes.close(); + bytes = null; + for (int i = 0; i < sortedForOrdinalMerging.ids.length; i++) { + sortedForOrdinalMerging.ids[i] = mappedIds[Math.toIntExact(values.getKey2(sortedForOrdinalMerging.ids[i]))]; } - start = end; + } finally { + blockFactory.adjustBreaker(-(long) totalValue * Integer.BYTES); } - return builder.build(); +$endif$ } - } finally { - blockFactory.adjustBreaker(-selectedCountsSize - idsSize); } + return sortedForOrdinalMerging; } + Block buildOutputBlock(BlockFactory blockFactory, IntVector selected, int[] selectedCounts, int[] ids) { + /* + * Insert the ids in order. + */ $if(BytesRef)$ - private void append($Type$Block.Builder builder, int id, BytesRef scratch) { - BytesRef value = bytes.get(values.getKey2(id), scratch); - builder.appendBytesRef(value); + BytesRef scratch = new BytesRef(); +$endif$ + try ($Type$Block.Builder builder = blockFactory.new$Type$BlockBuilder(selected.getPositionCount())) { + int start = 0; + for (int s = 0; s < selected.getPositionCount(); s++) { + int group = selected.getInt(s); + int end = selectedCounts[group]; + int count = end - start; + switch (count) { + case 0 -> builder.appendNull(); + case 1 -> builder.append$Type$(getValue(ids[start]$if(BytesRef)$, scratch$endif$)); + default -> { + builder.beginPositionEntry(); + for (int i = start; i < end; i++) { + builder.append$Type$(getValue(ids[i]$if(BytesRef)$, scratch$endif$)); + } + builder.endPositionEntry(); + } + } + start = end; + } + return builder.build(); + } } -$else$ - private void append($Type$Block.Builder builder, int id) { -$if(long)$ - long value = values.getKey2(id); + $type$ getValue(int valueId$if(BytesRef)$, BytesRef scratch$endif$) { +$if(BytesRef)$ + return bytes.get(values.getKey2(valueId), scratch); +$elseif(long)$ + return values.getKey2(valueId); $elseif(double)$ - double value = Double.longBitsToDouble(values.getKey2(id)); + return Double.longBitsToDouble(values.getKey2(valueId)); $elseif(float)$ - long both = values.get(id); - float value = Float.intBitsToFloat((int) both); + long both = values.get(valueId); + return Float.intBitsToFloat((int) both); $elseif(int)$ - long both = values.get(id); - int value = (int) both; + long both = values.get(valueId); + return (int) both; $endif$ - builder.append$Type$(value); } -$endif$ @Override public void enableGroupIdTracking(SeenGroupIds seen) { // we figure out seen values from nulls on the values block @@ -428,9 +487,9 @@ $endif$ @Override public void close() { $if(BytesRef)$ - Releasables.closeExpectNoException(values, bytes); + Releasables.closeExpectNoException(values, bytes, sortedForOrdinalMerging); $else$ - values.close(); + Releasables.closeExpectNoException(values, sortedForOrdinalMerging); $endif$ } } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/HashAggregationOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/HashAggregationOperator.java index c47b6cebdaddc..56ac7fb242472 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/HashAggregationOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/HashAggregationOperator.java @@ -83,7 +83,7 @@ public String describe() { private final BlockHash blockHash; - private final List aggregators; + protected final List aggregators; private final DriverContext driverContext;