diff --git a/docs/changelog/130576.yaml b/docs/changelog/130576.yaml new file mode 100644 index 0000000000000..29947d9259c6b --- /dev/null +++ b/docs/changelog/130576.yaml @@ -0,0 +1,5 @@ +pr: 130576 +summary: Avoid O(N^2) in VALUES with ordinals grouping +area: ES|QL +type: bug +issues: [] diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregator.java index 682a1e51ff3fc..72c840f090b83 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregator.java @@ -53,12 +53,12 @@ public static Block evaluateFinal(SingleState state, DriverContext driverContext return state.toBlock(driverContext.blockFactory()); } - public static GroupingState initGrouping(BigArrays bigArrays) { - return new GroupingState(bigArrays); + public static GroupingState initGrouping(DriverContext driverContext) { + return new GroupingState(driverContext); } public static void combine(GroupingState state, int groupId, BytesRef v) { - state.values.add(groupId, BlockHash.hashOrdToGroup(state.bytes.add(v))); + state.addValue(groupId, v); } public static void combineIntermediate(GroupingState state, int groupId, BytesRefBlock values, int valuesPosition) { @@ -66,17 +66,20 @@ public static void combineIntermediate(GroupingState state, int groupId, BytesRe int start = values.getFirstValueIndex(valuesPosition); int end = start + values.getValueCount(valuesPosition); for (int i = start; i < end; i++) { - combine(state, groupId, values.getBytesRef(i, scratch)); + state.addValue(groupId, values.getBytesRef(i, scratch)); } } public static void combineStates(GroupingState current, int currentGroupId, GroupingState state, int statePosition) { - BytesRef scratch = new BytesRef(); - for (int id = 0; id < state.values.size(); id++) { - if (state.values.getKey1(id) == statePosition) { - long value = state.values.getKey2(id); - combine(current, currentGroupId, state.bytes.get(value, scratch)); - } + var sorted = state.sortedForOrdinalMerging(current); + if (statePosition > state.maxGroupId) { + return; + } + var start = statePosition > 0 ? sorted.counts[statePosition - 1] : 0; + var end = sorted.counts[statePosition]; + for (int i = start; i < end; i++) { + int id = sorted.ids[i]; + current.addValueOrdinal(currentGroupId, id); } } @@ -119,6 +122,22 @@ public void close() { } } + /** + * Values are collected in a hash. Iterating over them in order (row by row) to build the output, + * or merging with other state, can be expensive. To optimize this, we build a sorted structure once, + * and then use it to iterate over the values in order. + * + * @param ids positions of the {@link GroupingState#values} to read. + * If built from {@link GroupingState#sortedForOrdinalMerging(GroupingState)}, + * these are ordinals referring to the {@link GroupingState#bytes} in the target state. + */ + private record Sorted(Releasable releasable, int[] counts, int[] ids) implements Releasable { + @Override + public void close() { + releasable.close(); + } + } + /** * State for a grouped {@code VALUES} aggregation. This implementation * emphasizes collect-time performance over the performance of rendering @@ -126,19 +145,47 @@ public void close() { * an {@code O(n^2)} operation for collection to support a {@code O(1)} * collector operation. But at least it's fairly simple. */ - public static class GroupingState implements Releasable { + public static class GroupingState implements GroupingAggregatorState { + private int maxGroupId = -1; + private final BlockFactory blockFactory; private final LongLongHash values; - private final BytesRefHash bytes; + private BytesRefHash bytes; - private GroupingState(BigArrays bigArrays) { - values = new LongLongHash(1, bigArrays); - bytes = new BytesRefHash(1, bigArrays); + private Sorted sortedForOrdinalMerging = null; + + private GroupingState(DriverContext driverContext) { + this.blockFactory = driverContext.blockFactory(); + LongLongHash _values = null; + BytesRefHash _bytes = null; + try { + _values = new LongLongHash(1, driverContext.bigArrays()); + _bytes = new BytesRefHash(1, driverContext.bigArrays()); + + values = _values; + bytes = _bytes; + + _values = null; + _bytes = null; + } finally { + Releasables.closeExpectNoException(_values, _bytes); + } } - void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { + @Override + public void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { blocks[offset] = toBlock(driverContext.blockFactory(), selected); } + void addValueOrdinal(int groupId, long valueOrdinal) { + values.add(groupId, valueOrdinal); + maxGroupId = Math.max(maxGroupId, groupId); + } + + void addValue(int groupId, BytesRef v) { + values.add(groupId, BlockHash.hashOrdToGroup(bytes.add(v))); + maxGroupId = Math.max(maxGroupId, groupId); + } + /** * Builds a {@link Block} with the unique values collected for the {@code #selected} * groups. This is the implementation of the final and intermediate results of the agg. @@ -148,8 +195,15 @@ Block toBlock(BlockFactory blockFactory, IntVector selected) { return blockFactory.newConstantNullBlock(selected.getPositionCount()); } + try (var sorted = buildSorted(selected)) { + return buildOutputBlock(blockFactory, selected, sorted.counts, sorted.ids); + } + } + + private Sorted buildSorted(IntVector selected) { long selectedCountsSize = 0; long idsSize = 0; + Sorted sorted = null; try { /* * Get a count of all groups less than the maximum selected group. Count @@ -224,40 +278,74 @@ Block toBlock(BlockFactory blockFactory, IntVector selected) { ids[selectedCounts[group]++] = id; } } + final long totalMemoryUsed = selectedCountsSize + idsSize; + sorted = new Sorted(() -> blockFactory.adjustBreaker(-totalMemoryUsed), selectedCounts, ids); + return sorted; + } finally { + if (sorted == null) { + blockFactory.adjustBreaker(-selectedCountsSize - idsSize); + } + } + } - /* - * Insert the ids in order. - */ - BytesRef scratch = new BytesRef(); - try (BytesRefBlock.Builder builder = blockFactory.newBytesRefBlockBuilder(selected.getPositionCount())) { - int start = 0; - for (int s = 0; s < selected.getPositionCount(); s++) { - int group = selected.getInt(s); - int end = selectedCounts[group]; - int count = end - start; - switch (count) { - case 0 -> builder.appendNull(); - case 1 -> append(builder, ids[start], scratch); - default -> { - builder.beginPositionEntry(); - for (int i = start; i < end; i++) { - append(builder, ids[i], scratch); - } - builder.endPositionEntry(); + private Sorted sortedForOrdinalMerging(GroupingState other) { + if (sortedForOrdinalMerging == null) { + try (var selected = IntVector.range(0, maxGroupId + 1, blockFactory)) { + sortedForOrdinalMerging = buildSorted(selected); + // hash all the bytes to the destination to avoid hashing them multiple times + BytesRef scratch = new BytesRef(); + final int totalValue = Math.toIntExact(bytes.size()); + blockFactory.adjustBreaker((long) totalValue * Integer.BYTES); + try { + final int[] mappedIds = new int[totalValue]; + for (int i = 0; i < totalValue; i++) { + var v = bytes.get(i, scratch); + mappedIds[i] = Math.toIntExact(BlockHash.hashOrdToGroup(other.bytes.add(v))); + } + // no longer need the bytes + bytes.close(); + bytes = null; + for (int i = 0; i < sortedForOrdinalMerging.ids.length; i++) { + sortedForOrdinalMerging.ids[i] = mappedIds[Math.toIntExact(values.getKey2(sortedForOrdinalMerging.ids[i]))]; + } + } finally { + blockFactory.adjustBreaker(-(long) totalValue * Integer.BYTES); + } + } + } + return sortedForOrdinalMerging; + } + + Block buildOutputBlock(BlockFactory blockFactory, IntVector selected, int[] selectedCounts, int[] ids) { + /* + * Insert the ids in order. + */ + BytesRef scratch = new BytesRef(); + try (BytesRefBlock.Builder builder = blockFactory.newBytesRefBlockBuilder(selected.getPositionCount())) { + int start = 0; + for (int s = 0; s < selected.getPositionCount(); s++) { + int group = selected.getInt(s); + int end = selectedCounts[group]; + int count = end - start; + switch (count) { + case 0 -> builder.appendNull(); + case 1 -> builder.appendBytesRef(getValue(ids[start], scratch)); + default -> { + builder.beginPositionEntry(); + for (int i = start; i < end; i++) { + builder.appendBytesRef(getValue(ids[i], scratch)); } + builder.endPositionEntry(); } - start = end; } - return builder.build(); + start = end; } - } finally { - blockFactory.adjustBreaker(-selectedCountsSize - idsSize); + return builder.build(); } } - private void append(BytesRefBlock.Builder builder, int id, BytesRef scratch) { - BytesRef value = bytes.get(values.getKey2(id), scratch); - builder.appendBytesRef(value); + BytesRef getValue(int valueId, BytesRef scratch) { + return bytes.get(values.getKey2(valueId), scratch); } public void enableGroupIdTracking(SeenGroupIds seen) { @@ -266,7 +354,7 @@ public void enableGroupIdTracking(SeenGroupIds seen) { @Override public void close() { - Releasables.closeExpectNoException(values, bytes); + Releasables.closeExpectNoException(values, bytes, sortedForOrdinalMerging); } } } diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesDoubleAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesDoubleAggregator.java index 505d3a91991ec..99919b3b05158 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesDoubleAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesDoubleAggregator.java @@ -20,6 +20,7 @@ import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.core.Releasable; +import org.elasticsearch.core.Releasables; /** * Aggregates field values for double. @@ -49,28 +50,32 @@ public static Block evaluateFinal(SingleState state, DriverContext driverContext return state.toBlock(driverContext.blockFactory()); } - public static GroupingState initGrouping(BigArrays bigArrays) { - return new GroupingState(bigArrays); + public static GroupingState initGrouping(DriverContext driverContext) { + return new GroupingState(driverContext); } public static void combine(GroupingState state, int groupId, double v) { - state.values.add(groupId, Double.doubleToLongBits(v)); + state.addValue(groupId, v); } public static void combineIntermediate(GroupingState state, int groupId, DoubleBlock values, int valuesPosition) { int start = values.getFirstValueIndex(valuesPosition); int end = start + values.getValueCount(valuesPosition); for (int i = start; i < end; i++) { - combine(state, groupId, values.getDouble(i)); + state.addValue(groupId, values.getDouble(i)); } } public static void combineStates(GroupingState current, int currentGroupId, GroupingState state, int statePosition) { - for (int id = 0; id < state.values.size(); id++) { - if (state.values.getKey1(id) == statePosition) { - double value = Double.longBitsToDouble(state.values.getKey2(id)); - combine(current, currentGroupId, value); - } + var sorted = state.sortedForOrdinalMerging(current); + if (statePosition > state.maxGroupId) { + return; + } + var start = statePosition > 0 ? sorted.counts[statePosition - 1] : 0; + var end = sorted.counts[statePosition]; + for (int i = start; i < end; i++) { + int id = sorted.ids[i]; + current.addValue(currentGroupId, state.getValue(id)); } } @@ -112,6 +117,20 @@ public void close() { } } + /** + * Values are collected in a hash. Iterating over them in order (row by row) to build the output, + * or merging with other state, can be expensive. To optimize this, we build a sorted structure once, + * and then use it to iterate over the values in order. + * + * @param ids positions of the {@link GroupingState#values} to read. + */ + private record Sorted(Releasable releasable, int[] counts, int[] ids) implements Releasable { + @Override + public void close() { + releasable.close(); + } + } + /** * State for a grouped {@code VALUES} aggregation. This implementation * emphasizes collect-time performance over the performance of rendering @@ -119,17 +138,28 @@ public void close() { * an {@code O(n^2)} operation for collection to support a {@code O(1)} * collector operation. But at least it's fairly simple. */ - public static class GroupingState implements Releasable { + public static class GroupingState implements GroupingAggregatorState { + private int maxGroupId = -1; + private final BlockFactory blockFactory; private final LongLongHash values; - private GroupingState(BigArrays bigArrays) { - values = new LongLongHash(1, bigArrays); + private Sorted sortedForOrdinalMerging = null; + + private GroupingState(DriverContext driverContext) { + this.blockFactory = driverContext.blockFactory(); + values = new LongLongHash(1, driverContext.bigArrays()); } - void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { + @Override + public void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { blocks[offset] = toBlock(driverContext.blockFactory(), selected); } + void addValue(int groupId, double v) { + values.add(groupId, Double.doubleToLongBits(v)); + maxGroupId = Math.max(maxGroupId, groupId); + } + /** * Builds a {@link Block} with the unique values collected for the {@code #selected} * groups. This is the implementation of the final and intermediate results of the agg. @@ -139,8 +169,15 @@ Block toBlock(BlockFactory blockFactory, IntVector selected) { return blockFactory.newConstantNullBlock(selected.getPositionCount()); } + try (var sorted = buildSorted(selected)) { + return buildOutputBlock(blockFactory, selected, sorted.counts, sorted.ids); + } + } + + private Sorted buildSorted(IntVector selected) { long selectedCountsSize = 0; long idsSize = 0; + Sorted sorted = null; try { /* * Get a count of all groups less than the maximum selected group. Count @@ -215,39 +252,54 @@ Block toBlock(BlockFactory blockFactory, IntVector selected) { ids[selectedCounts[group]++] = id; } } + final long totalMemoryUsed = selectedCountsSize + idsSize; + sorted = new Sorted(() -> blockFactory.adjustBreaker(-totalMemoryUsed), selectedCounts, ids); + return sorted; + } finally { + if (sorted == null) { + blockFactory.adjustBreaker(-selectedCountsSize - idsSize); + } + } + } - /* - * Insert the ids in order. - */ - try (DoubleBlock.Builder builder = blockFactory.newDoubleBlockBuilder(selected.getPositionCount())) { - int start = 0; - for (int s = 0; s < selected.getPositionCount(); s++) { - int group = selected.getInt(s); - int end = selectedCounts[group]; - int count = end - start; - switch (count) { - case 0 -> builder.appendNull(); - case 1 -> append(builder, ids[start]); - default -> { - builder.beginPositionEntry(); - for (int i = start; i < end; i++) { - append(builder, ids[i]); - } - builder.endPositionEntry(); + private Sorted sortedForOrdinalMerging(GroupingState other) { + if (sortedForOrdinalMerging == null) { + try (var selected = IntVector.range(0, maxGroupId + 1, blockFactory)) { + sortedForOrdinalMerging = buildSorted(selected); + } + } + return sortedForOrdinalMerging; + } + + Block buildOutputBlock(BlockFactory blockFactory, IntVector selected, int[] selectedCounts, int[] ids) { + /* + * Insert the ids in order. + */ + try (DoubleBlock.Builder builder = blockFactory.newDoubleBlockBuilder(selected.getPositionCount())) { + int start = 0; + for (int s = 0; s < selected.getPositionCount(); s++) { + int group = selected.getInt(s); + int end = selectedCounts[group]; + int count = end - start; + switch (count) { + case 0 -> builder.appendNull(); + case 1 -> builder.appendDouble(getValue(ids[start])); + default -> { + builder.beginPositionEntry(); + for (int i = start; i < end; i++) { + builder.appendDouble(getValue(ids[i])); } + builder.endPositionEntry(); } - start = end; } - return builder.build(); + start = end; } - } finally { - blockFactory.adjustBreaker(-selectedCountsSize - idsSize); + return builder.build(); } } - private void append(DoubleBlock.Builder builder, int id) { - double value = Double.longBitsToDouble(values.getKey2(id)); - builder.appendDouble(value); + double getValue(int valueId) { + return Double.longBitsToDouble(values.getKey2(valueId)); } public void enableGroupIdTracking(SeenGroupIds seen) { @@ -256,7 +308,7 @@ public void enableGroupIdTracking(SeenGroupIds seen) { @Override public void close() { - values.close(); + Releasables.closeExpectNoException(values, sortedForOrdinalMerging); } } } diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesFloatAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesFloatAggregator.java index 9c50552110183..e544635de2ff5 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesFloatAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesFloatAggregator.java @@ -19,6 +19,7 @@ import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.core.Releasable; +import org.elasticsearch.core.Releasables; /** * Aggregates field values for float. @@ -48,34 +49,32 @@ public static Block evaluateFinal(SingleState state, DriverContext driverContext return state.toBlock(driverContext.blockFactory()); } - public static GroupingState initGrouping(BigArrays bigArrays) { - return new GroupingState(bigArrays); + public static GroupingState initGrouping(DriverContext driverContext) { + return new GroupingState(driverContext); } public static void combine(GroupingState state, int groupId, float v) { - /* - * Encode the groupId and value into a single long - - * the top 32 bits for the group, the bottom 32 for the value. - */ - state.values.add((((long) groupId) << Float.SIZE) | (Float.floatToIntBits(v) & 0xFFFFFFFFL)); + state.addValue(groupId, v); } public static void combineIntermediate(GroupingState state, int groupId, FloatBlock values, int valuesPosition) { int start = values.getFirstValueIndex(valuesPosition); int end = start + values.getValueCount(valuesPosition); for (int i = start; i < end; i++) { - combine(state, groupId, values.getFloat(i)); + state.addValue(groupId, values.getFloat(i)); } } public static void combineStates(GroupingState current, int currentGroupId, GroupingState state, int statePosition) { - for (int id = 0; id < state.values.size(); id++) { - long both = state.values.get(id); - int group = (int) (both >>> Float.SIZE); - if (group == statePosition) { - float value = Float.intBitsToFloat((int) both); - combine(current, currentGroupId, value); - } + var sorted = state.sortedForOrdinalMerging(current); + if (statePosition > state.maxGroupId) { + return; + } + var start = statePosition > 0 ? sorted.counts[statePosition - 1] : 0; + var end = sorted.counts[statePosition]; + for (int i = start; i < end; i++) { + int id = sorted.ids[i]; + current.addValue(currentGroupId, state.getValue(id)); } } @@ -117,6 +116,20 @@ public void close() { } } + /** + * Values are collected in a hash. Iterating over them in order (row by row) to build the output, + * or merging with other state, can be expensive. To optimize this, we build a sorted structure once, + * and then use it to iterate over the values in order. + * + * @param ids positions of the {@link GroupingState#values} to read. + */ + private record Sorted(Releasable releasable, int[] counts, int[] ids) implements Releasable { + @Override + public void close() { + releasable.close(); + } + } + /** * State for a grouped {@code VALUES} aggregation. This implementation * emphasizes collect-time performance over the performance of rendering @@ -124,17 +137,32 @@ public void close() { * an {@code O(n^2)} operation for collection to support a {@code O(1)} * collector operation. But at least it's fairly simple. */ - public static class GroupingState implements Releasable { + public static class GroupingState implements GroupingAggregatorState { + private int maxGroupId = -1; + private final BlockFactory blockFactory; private final LongHash values; - private GroupingState(BigArrays bigArrays) { - values = new LongHash(1, bigArrays); + private Sorted sortedForOrdinalMerging = null; + + private GroupingState(DriverContext driverContext) { + this.blockFactory = driverContext.blockFactory(); + values = new LongHash(1, driverContext.bigArrays()); } - void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { + @Override + public void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { blocks[offset] = toBlock(driverContext.blockFactory(), selected); } + void addValue(int groupId, float v) { + /* + * Encode the groupId and value into a single long - + * the top 32 bits for the group, the bottom 32 for the value. + */ + values.add((((long) groupId) << Float.SIZE) | (Float.floatToIntBits(v) & 0xFFFFFFFFL)); + maxGroupId = Math.max(maxGroupId, groupId); + } + /** * Builds a {@link Block} with the unique values collected for the {@code #selected} * groups. This is the implementation of the final and intermediate results of the agg. @@ -144,8 +172,15 @@ Block toBlock(BlockFactory blockFactory, IntVector selected) { return blockFactory.newConstantNullBlock(selected.getPositionCount()); } + try (var sorted = buildSorted(selected)) { + return buildOutputBlock(blockFactory, selected, sorted.counts, sorted.ids); + } + } + + private Sorted buildSorted(IntVector selected) { long selectedCountsSize = 0; long idsSize = 0; + Sorted sorted = null; try { /* * Get a count of all groups less than the maximum selected group. Count @@ -222,40 +257,55 @@ Block toBlock(BlockFactory blockFactory, IntVector selected) { ids[selectedCounts[group]++] = id; } } + final long totalMemoryUsed = selectedCountsSize + idsSize; + sorted = new Sorted(() -> blockFactory.adjustBreaker(-totalMemoryUsed), selectedCounts, ids); + return sorted; + } finally { + if (sorted == null) { + blockFactory.adjustBreaker(-selectedCountsSize - idsSize); + } + } + } - /* - * Insert the ids in order. - */ - try (FloatBlock.Builder builder = blockFactory.newFloatBlockBuilder(selected.getPositionCount())) { - int start = 0; - for (int s = 0; s < selected.getPositionCount(); s++) { - int group = selected.getInt(s); - int end = selectedCounts[group]; - int count = end - start; - switch (count) { - case 0 -> builder.appendNull(); - case 1 -> append(builder, ids[start]); - default -> { - builder.beginPositionEntry(); - for (int i = start; i < end; i++) { - append(builder, ids[i]); - } - builder.endPositionEntry(); + private Sorted sortedForOrdinalMerging(GroupingState other) { + if (sortedForOrdinalMerging == null) { + try (var selected = IntVector.range(0, maxGroupId + 1, blockFactory)) { + sortedForOrdinalMerging = buildSorted(selected); + } + } + return sortedForOrdinalMerging; + } + + Block buildOutputBlock(BlockFactory blockFactory, IntVector selected, int[] selectedCounts, int[] ids) { + /* + * Insert the ids in order. + */ + try (FloatBlock.Builder builder = blockFactory.newFloatBlockBuilder(selected.getPositionCount())) { + int start = 0; + for (int s = 0; s < selected.getPositionCount(); s++) { + int group = selected.getInt(s); + int end = selectedCounts[group]; + int count = end - start; + switch (count) { + case 0 -> builder.appendNull(); + case 1 -> builder.appendFloat(getValue(ids[start])); + default -> { + builder.beginPositionEntry(); + for (int i = start; i < end; i++) { + builder.appendFloat(getValue(ids[i])); } + builder.endPositionEntry(); } - start = end; } - return builder.build(); + start = end; } - } finally { - blockFactory.adjustBreaker(-selectedCountsSize - idsSize); + return builder.build(); } } - private void append(FloatBlock.Builder builder, int id) { - long both = values.get(id); - float value = Float.intBitsToFloat((int) both); - builder.appendFloat(value); + float getValue(int valueId) { + long both = values.get(valueId); + return Float.intBitsToFloat((int) both); } public void enableGroupIdTracking(SeenGroupIds seen) { @@ -264,7 +314,7 @@ public void enableGroupIdTracking(SeenGroupIds seen) { @Override public void close() { - values.close(); + Releasables.closeExpectNoException(values, sortedForOrdinalMerging); } } } diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesIntAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesIntAggregator.java index 1e0ca72b8d1a6..0ef8923a812ff 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesIntAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesIntAggregator.java @@ -19,6 +19,7 @@ import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.core.Releasable; +import org.elasticsearch.core.Releasables; /** * Aggregates field values for int. @@ -48,34 +49,32 @@ public static Block evaluateFinal(SingleState state, DriverContext driverContext return state.toBlock(driverContext.blockFactory()); } - public static GroupingState initGrouping(BigArrays bigArrays) { - return new GroupingState(bigArrays); + public static GroupingState initGrouping(DriverContext driverContext) { + return new GroupingState(driverContext); } public static void combine(GroupingState state, int groupId, int v) { - /* - * Encode the groupId and value into a single long - - * the top 32 bits for the group, the bottom 32 for the value. - */ - state.values.add((((long) groupId) << Integer.SIZE) | (v & 0xFFFFFFFFL)); + state.addValue(groupId, v); } public static void combineIntermediate(GroupingState state, int groupId, IntBlock values, int valuesPosition) { int start = values.getFirstValueIndex(valuesPosition); int end = start + values.getValueCount(valuesPosition); for (int i = start; i < end; i++) { - combine(state, groupId, values.getInt(i)); + state.addValue(groupId, values.getInt(i)); } } public static void combineStates(GroupingState current, int currentGroupId, GroupingState state, int statePosition) { - for (int id = 0; id < state.values.size(); id++) { - long both = state.values.get(id); - int group = (int) (both >>> Integer.SIZE); - if (group == statePosition) { - int value = (int) both; - combine(current, currentGroupId, value); - } + var sorted = state.sortedForOrdinalMerging(current); + if (statePosition > state.maxGroupId) { + return; + } + var start = statePosition > 0 ? sorted.counts[statePosition - 1] : 0; + var end = sorted.counts[statePosition]; + for (int i = start; i < end; i++) { + int id = sorted.ids[i]; + current.addValue(currentGroupId, state.getValue(id)); } } @@ -117,6 +116,20 @@ public void close() { } } + /** + * Values are collected in a hash. Iterating over them in order (row by row) to build the output, + * or merging with other state, can be expensive. To optimize this, we build a sorted structure once, + * and then use it to iterate over the values in order. + * + * @param ids positions of the {@link GroupingState#values} to read. + */ + private record Sorted(Releasable releasable, int[] counts, int[] ids) implements Releasable { + @Override + public void close() { + releasable.close(); + } + } + /** * State for a grouped {@code VALUES} aggregation. This implementation * emphasizes collect-time performance over the performance of rendering @@ -124,17 +137,32 @@ public void close() { * an {@code O(n^2)} operation for collection to support a {@code O(1)} * collector operation. But at least it's fairly simple. */ - public static class GroupingState implements Releasable { + public static class GroupingState implements GroupingAggregatorState { + private int maxGroupId = -1; + private final BlockFactory blockFactory; private final LongHash values; - private GroupingState(BigArrays bigArrays) { - values = new LongHash(1, bigArrays); + private Sorted sortedForOrdinalMerging = null; + + private GroupingState(DriverContext driverContext) { + this.blockFactory = driverContext.blockFactory(); + values = new LongHash(1, driverContext.bigArrays()); } - void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { + @Override + public void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { blocks[offset] = toBlock(driverContext.blockFactory(), selected); } + void addValue(int groupId, int v) { + /* + * Encode the groupId and value into a single long - + * the top 32 bits for the group, the bottom 32 for the value. + */ + values.add((((long) groupId) << Integer.SIZE) | (v & 0xFFFFFFFFL)); + maxGroupId = Math.max(maxGroupId, groupId); + } + /** * Builds a {@link Block} with the unique values collected for the {@code #selected} * groups. This is the implementation of the final and intermediate results of the agg. @@ -144,8 +172,15 @@ Block toBlock(BlockFactory blockFactory, IntVector selected) { return blockFactory.newConstantNullBlock(selected.getPositionCount()); } + try (var sorted = buildSorted(selected)) { + return buildOutputBlock(blockFactory, selected, sorted.counts, sorted.ids); + } + } + + private Sorted buildSorted(IntVector selected) { long selectedCountsSize = 0; long idsSize = 0; + Sorted sorted = null; try { /* * Get a count of all groups less than the maximum selected group. Count @@ -222,40 +257,55 @@ Block toBlock(BlockFactory blockFactory, IntVector selected) { ids[selectedCounts[group]++] = id; } } + final long totalMemoryUsed = selectedCountsSize + idsSize; + sorted = new Sorted(() -> blockFactory.adjustBreaker(-totalMemoryUsed), selectedCounts, ids); + return sorted; + } finally { + if (sorted == null) { + blockFactory.adjustBreaker(-selectedCountsSize - idsSize); + } + } + } - /* - * Insert the ids in order. - */ - try (IntBlock.Builder builder = blockFactory.newIntBlockBuilder(selected.getPositionCount())) { - int start = 0; - for (int s = 0; s < selected.getPositionCount(); s++) { - int group = selected.getInt(s); - int end = selectedCounts[group]; - int count = end - start; - switch (count) { - case 0 -> builder.appendNull(); - case 1 -> append(builder, ids[start]); - default -> { - builder.beginPositionEntry(); - for (int i = start; i < end; i++) { - append(builder, ids[i]); - } - builder.endPositionEntry(); + private Sorted sortedForOrdinalMerging(GroupingState other) { + if (sortedForOrdinalMerging == null) { + try (var selected = IntVector.range(0, maxGroupId + 1, blockFactory)) { + sortedForOrdinalMerging = buildSorted(selected); + } + } + return sortedForOrdinalMerging; + } + + Block buildOutputBlock(BlockFactory blockFactory, IntVector selected, int[] selectedCounts, int[] ids) { + /* + * Insert the ids in order. + */ + try (IntBlock.Builder builder = blockFactory.newIntBlockBuilder(selected.getPositionCount())) { + int start = 0; + for (int s = 0; s < selected.getPositionCount(); s++) { + int group = selected.getInt(s); + int end = selectedCounts[group]; + int count = end - start; + switch (count) { + case 0 -> builder.appendNull(); + case 1 -> builder.appendInt(getValue(ids[start])); + default -> { + builder.beginPositionEntry(); + for (int i = start; i < end; i++) { + builder.appendInt(getValue(ids[i])); } + builder.endPositionEntry(); } - start = end; } - return builder.build(); + start = end; } - } finally { - blockFactory.adjustBreaker(-selectedCountsSize - idsSize); + return builder.build(); } } - private void append(IntBlock.Builder builder, int id) { - long both = values.get(id); - int value = (int) both; - builder.appendInt(value); + int getValue(int valueId) { + long both = values.get(valueId); + return (int) both; } public void enableGroupIdTracking(SeenGroupIds seen) { @@ -264,7 +314,7 @@ public void enableGroupIdTracking(SeenGroupIds seen) { @Override public void close() { - values.close(); + Releasables.closeExpectNoException(values, sortedForOrdinalMerging); } } } diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesLongAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesLongAggregator.java index ba04f928b9fb9..7a0a8bed86fe9 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesLongAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesLongAggregator.java @@ -20,6 +20,7 @@ import org.elasticsearch.compute.data.LongBlock; import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.core.Releasable; +import org.elasticsearch.core.Releasables; /** * Aggregates field values for long. @@ -49,28 +50,32 @@ public static Block evaluateFinal(SingleState state, DriverContext driverContext return state.toBlock(driverContext.blockFactory()); } - public static GroupingState initGrouping(BigArrays bigArrays) { - return new GroupingState(bigArrays); + public static GroupingState initGrouping(DriverContext driverContext) { + return new GroupingState(driverContext); } public static void combine(GroupingState state, int groupId, long v) { - state.values.add(groupId, v); + state.addValue(groupId, v); } public static void combineIntermediate(GroupingState state, int groupId, LongBlock values, int valuesPosition) { int start = values.getFirstValueIndex(valuesPosition); int end = start + values.getValueCount(valuesPosition); for (int i = start; i < end; i++) { - combine(state, groupId, values.getLong(i)); + state.addValue(groupId, values.getLong(i)); } } public static void combineStates(GroupingState current, int currentGroupId, GroupingState state, int statePosition) { - for (int id = 0; id < state.values.size(); id++) { - if (state.values.getKey1(id) == statePosition) { - long value = state.values.getKey2(id); - combine(current, currentGroupId, value); - } + var sorted = state.sortedForOrdinalMerging(current); + if (statePosition > state.maxGroupId) { + return; + } + var start = statePosition > 0 ? sorted.counts[statePosition - 1] : 0; + var end = sorted.counts[statePosition]; + for (int i = start; i < end; i++) { + int id = sorted.ids[i]; + current.addValue(currentGroupId, state.getValue(id)); } } @@ -112,6 +117,20 @@ public void close() { } } + /** + * Values are collected in a hash. Iterating over them in order (row by row) to build the output, + * or merging with other state, can be expensive. To optimize this, we build a sorted structure once, + * and then use it to iterate over the values in order. + * + * @param ids positions of the {@link GroupingState#values} to read. + */ + private record Sorted(Releasable releasable, int[] counts, int[] ids) implements Releasable { + @Override + public void close() { + releasable.close(); + } + } + /** * State for a grouped {@code VALUES} aggregation. This implementation * emphasizes collect-time performance over the performance of rendering @@ -119,17 +138,28 @@ public void close() { * an {@code O(n^2)} operation for collection to support a {@code O(1)} * collector operation. But at least it's fairly simple. */ - public static class GroupingState implements Releasable { + public static class GroupingState implements GroupingAggregatorState { + private int maxGroupId = -1; + private final BlockFactory blockFactory; private final LongLongHash values; - private GroupingState(BigArrays bigArrays) { - values = new LongLongHash(1, bigArrays); + private Sorted sortedForOrdinalMerging = null; + + private GroupingState(DriverContext driverContext) { + this.blockFactory = driverContext.blockFactory(); + values = new LongLongHash(1, driverContext.bigArrays()); } - void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { + @Override + public void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { blocks[offset] = toBlock(driverContext.blockFactory(), selected); } + void addValue(int groupId, long v) { + values.add(groupId, v); + maxGroupId = Math.max(maxGroupId, groupId); + } + /** * Builds a {@link Block} with the unique values collected for the {@code #selected} * groups. This is the implementation of the final and intermediate results of the agg. @@ -139,8 +169,15 @@ Block toBlock(BlockFactory blockFactory, IntVector selected) { return blockFactory.newConstantNullBlock(selected.getPositionCount()); } + try (var sorted = buildSorted(selected)) { + return buildOutputBlock(blockFactory, selected, sorted.counts, sorted.ids); + } + } + + private Sorted buildSorted(IntVector selected) { long selectedCountsSize = 0; long idsSize = 0; + Sorted sorted = null; try { /* * Get a count of all groups less than the maximum selected group. Count @@ -215,39 +252,54 @@ Block toBlock(BlockFactory blockFactory, IntVector selected) { ids[selectedCounts[group]++] = id; } } + final long totalMemoryUsed = selectedCountsSize + idsSize; + sorted = new Sorted(() -> blockFactory.adjustBreaker(-totalMemoryUsed), selectedCounts, ids); + return sorted; + } finally { + if (sorted == null) { + blockFactory.adjustBreaker(-selectedCountsSize - idsSize); + } + } + } - /* - * Insert the ids in order. - */ - try (LongBlock.Builder builder = blockFactory.newLongBlockBuilder(selected.getPositionCount())) { - int start = 0; - for (int s = 0; s < selected.getPositionCount(); s++) { - int group = selected.getInt(s); - int end = selectedCounts[group]; - int count = end - start; - switch (count) { - case 0 -> builder.appendNull(); - case 1 -> append(builder, ids[start]); - default -> { - builder.beginPositionEntry(); - for (int i = start; i < end; i++) { - append(builder, ids[i]); - } - builder.endPositionEntry(); + private Sorted sortedForOrdinalMerging(GroupingState other) { + if (sortedForOrdinalMerging == null) { + try (var selected = IntVector.range(0, maxGroupId + 1, blockFactory)) { + sortedForOrdinalMerging = buildSorted(selected); + } + } + return sortedForOrdinalMerging; + } + + Block buildOutputBlock(BlockFactory blockFactory, IntVector selected, int[] selectedCounts, int[] ids) { + /* + * Insert the ids in order. + */ + try (LongBlock.Builder builder = blockFactory.newLongBlockBuilder(selected.getPositionCount())) { + int start = 0; + for (int s = 0; s < selected.getPositionCount(); s++) { + int group = selected.getInt(s); + int end = selectedCounts[group]; + int count = end - start; + switch (count) { + case 0 -> builder.appendNull(); + case 1 -> builder.appendLong(getValue(ids[start])); + default -> { + builder.beginPositionEntry(); + for (int i = start; i < end; i++) { + builder.appendLong(getValue(ids[i])); } + builder.endPositionEntry(); } - start = end; } - return builder.build(); + start = end; } - } finally { - blockFactory.adjustBreaker(-selectedCountsSize - idsSize); + return builder.build(); } } - private void append(LongBlock.Builder builder, int id) { - long value = values.getKey2(id); - builder.appendLong(value); + long getValue(int valueId) { + return values.getKey2(valueId); } public void enableGroupIdTracking(SeenGroupIds seen) { @@ -256,7 +308,7 @@ public void enableGroupIdTracking(SeenGroupIds seen) { @Override public void close() { - values.close(); + Releasables.closeExpectNoException(values, sortedForOrdinalMerging); } } } diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesBytesRefGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesBytesRefGroupingAggregatorFunction.java index bdce606f92168..290409830be6a 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesBytesRefGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesBytesRefGroupingAggregatorFunction.java @@ -42,7 +42,7 @@ public ValuesBytesRefGroupingAggregatorFunction(List channels, public static ValuesBytesRefGroupingAggregatorFunction create(List channels, DriverContext driverContext) { - return new ValuesBytesRefGroupingAggregatorFunction(channels, ValuesBytesRefAggregator.initGrouping(driverContext.bigArrays()), driverContext); + return new ValuesBytesRefGroupingAggregatorFunction(channels, ValuesBytesRefAggregator.initGrouping(driverContext), driverContext); } public static List intermediateStateDesc() { diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesDoubleGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesDoubleGroupingAggregatorFunction.java index 5b8c2ac802663..001ae62f11d30 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesDoubleGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesDoubleGroupingAggregatorFunction.java @@ -41,7 +41,7 @@ public ValuesDoubleGroupingAggregatorFunction(List channels, public static ValuesDoubleGroupingAggregatorFunction create(List channels, DriverContext driverContext) { - return new ValuesDoubleGroupingAggregatorFunction(channels, ValuesDoubleAggregator.initGrouping(driverContext.bigArrays()), driverContext); + return new ValuesDoubleGroupingAggregatorFunction(channels, ValuesDoubleAggregator.initGrouping(driverContext), driverContext); } public static List intermediateStateDesc() { diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesFloatGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesFloatGroupingAggregatorFunction.java index f50c5a67d15a5..ddbd15a089bf9 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesFloatGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesFloatGroupingAggregatorFunction.java @@ -41,7 +41,7 @@ public ValuesFloatGroupingAggregatorFunction(List channels, public static ValuesFloatGroupingAggregatorFunction create(List channels, DriverContext driverContext) { - return new ValuesFloatGroupingAggregatorFunction(channels, ValuesFloatAggregator.initGrouping(driverContext.bigArrays()), driverContext); + return new ValuesFloatGroupingAggregatorFunction(channels, ValuesFloatAggregator.initGrouping(driverContext), driverContext); } public static List intermediateStateDesc() { diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesIntGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesIntGroupingAggregatorFunction.java index c90fcedb291cf..8e7a10eb5d1e8 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesIntGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesIntGroupingAggregatorFunction.java @@ -39,7 +39,7 @@ public ValuesIntGroupingAggregatorFunction(List channels, public static ValuesIntGroupingAggregatorFunction create(List channels, DriverContext driverContext) { - return new ValuesIntGroupingAggregatorFunction(channels, ValuesIntAggregator.initGrouping(driverContext.bigArrays()), driverContext); + return new ValuesIntGroupingAggregatorFunction(channels, ValuesIntAggregator.initGrouping(driverContext), driverContext); } public static List intermediateStateDesc() { diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesLongGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesLongGroupingAggregatorFunction.java index 8a79cd7d942ee..fbb41bfbf0679 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesLongGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesLongGroupingAggregatorFunction.java @@ -41,7 +41,7 @@ public ValuesLongGroupingAggregatorFunction(List channels, public static ValuesLongGroupingAggregatorFunction create(List channels, DriverContext driverContext) { - return new ValuesLongGroupingAggregatorFunction(channels, ValuesLongAggregator.initGrouping(driverContext.bigArrays()), driverContext); + return new ValuesLongGroupingAggregatorFunction(channels, ValuesLongAggregator.initGrouping(driverContext), driverContext); } public static List intermediateStateDesc() { diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st index 69243332449f6..ecf8d8180e168 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st @@ -37,12 +37,8 @@ import org.elasticsearch.compute.data.LongBlock; $endif$ import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.core.Releasable; -$if(BytesRef)$ import org.elasticsearch.core.Releasables; -$else$ - -$endif$ /** * Aggregates field values for $type$. * This class is generated. Edit @{code X-ValuesAggregator.java.st} instead @@ -84,30 +80,12 @@ $endif$ return state.toBlock(driverContext.blockFactory()); } - public static GroupingState initGrouping(BigArrays bigArrays) { - return new GroupingState(bigArrays); + public static GroupingState initGrouping(DriverContext driverContext) { + return new GroupingState(driverContext); } public static void combine(GroupingState state, int groupId, $type$ v) { -$if(long)$ - state.values.add(groupId, v); -$elseif(double)$ - state.values.add(groupId, Double.doubleToLongBits(v)); -$elseif(BytesRef)$ - state.values.add(groupId, BlockHash.hashOrdToGroup(state.bytes.add(v))); -$elseif(int)$ - /* - * Encode the groupId and value into a single long - - * the top 32 bits for the group, the bottom 32 for the value. - */ - state.values.add((((long) groupId) << Integer.SIZE) | (v & 0xFFFFFFFFL)); -$elseif(float)$ - /* - * Encode the groupId and value into a single long - - * the top 32 bits for the group, the bottom 32 for the value. - */ - state.values.add((((long) groupId) << Float.SIZE) | (Float.floatToIntBits(v) & 0xFFFFFFFFL)); -$endif$ + state.addValue(groupId, v); } public static void combineIntermediate(GroupingState state, int groupId, $Type$Block values, int valuesPosition) { @@ -118,37 +96,27 @@ $endif$ int end = start + values.getValueCount(valuesPosition); for (int i = start; i < end; i++) { $if(BytesRef)$ - combine(state, groupId, values.getBytesRef(i, scratch)); + state.addValue(groupId, values.getBytesRef(i, scratch)); $else$ - combine(state, groupId, values.get$Type$(i)); + state.addValue(groupId, values.get$Type$(i)); $endif$ } } public static void combineStates(GroupingState current, int currentGroupId, GroupingState state, int statePosition) { + var sorted = state.sortedForOrdinalMerging(current); + if (statePosition > state.maxGroupId) { + return; + } + var start = statePosition > 0 ? sorted.counts[statePosition - 1] : 0; + var end = sorted.counts[statePosition]; + for (int i = start; i < end; i++) { + int id = sorted.ids[i]; $if(BytesRef)$ - BytesRef scratch = new BytesRef(); -$endif$ - for (int id = 0; id < state.values.size(); id++) { -$if(long||BytesRef)$ - if (state.values.getKey1(id) == statePosition) { - long value = state.values.getKey2(id); -$elseif(double)$ - if (state.values.getKey1(id) == statePosition) { - double value = Double.longBitsToDouble(state.values.getKey2(id)); -$elseif(int)$ - long both = state.values.get(id); - int group = (int) (both >>> Integer.SIZE); - if (group == statePosition) { - int value = (int) both; -$elseif(float)$ - long both = state.values.get(id); - int group = (int) (both >>> Float.SIZE); - if (group == statePosition) { - float value = Float.intBitsToFloat((int) both); + current.addValueOrdinal(currentGroupId, id); +$else$ + current.addValue(currentGroupId, state.getValue(id)); $endif$ - combine(current, currentGroupId, $if(BytesRef)$state.bytes.get(value, scratch)$else$value$endif$); - } } } @@ -222,6 +190,24 @@ $endif$ } } + /** + * Values are collected in a hash. Iterating over them in order (row by row) to build the output, + * or merging with other state, can be expensive. To optimize this, we build a sorted structure once, + * and then use it to iterate over the values in order. + * + * @param ids positions of the {@link GroupingState#values} to read. +$if(BytesRef)$ + * If built from {@link GroupingState#sortedForOrdinalMerging(GroupingState)}, + * these are ordinals referring to the {@link GroupingState#bytes} in the target state. +$endif$ + */ + private record Sorted(Releasable releasable, int[] counts, int[] ids) implements Releasable { + @Override + public void close() { + releasable.close(); + } + } + /** * State for a grouped {@code VALUES} aggregation. This implementation * emphasizes collect-time performance over the performance of rendering @@ -229,33 +215,81 @@ $endif$ * an {@code O(n^2)} operation for collection to support a {@code O(1)} * collector operation. But at least it's fairly simple. */ - public static class GroupingState implements Releasable { + public static class GroupingState implements GroupingAggregatorState { + private int maxGroupId = -1; + private final BlockFactory blockFactory; $if(long||double)$ private final LongLongHash values; $elseif(BytesRef)$ private final LongLongHash values; - private final BytesRefHash bytes; + private BytesRefHash bytes; $elseif(int||float)$ private final LongHash values; $endif$ - private GroupingState(BigArrays bigArrays) { + private Sorted sortedForOrdinalMerging = null; + + private GroupingState(DriverContext driverContext) { + this.blockFactory = driverContext.blockFactory(); $if(long||double)$ - values = new LongLongHash(1, bigArrays); + values = new LongLongHash(1, driverContext.bigArrays()); $elseif(BytesRef)$ - values = new LongLongHash(1, bigArrays); - bytes = new BytesRefHash(1, bigArrays); + LongLongHash _values = null; + BytesRefHash _bytes = null; + try { + _values = new LongLongHash(1, driverContext.bigArrays()); + _bytes = new BytesRefHash(1, driverContext.bigArrays()); + + values = _values; + bytes = _bytes; + + _values = null; + _bytes = null; + } finally { + Releasables.closeExpectNoException(_values, _bytes); + } $elseif(int||float)$ - values = new LongHash(1, bigArrays); + values = new LongHash(1, driverContext.bigArrays()); $endif$ } - void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { + @Override + public void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { blocks[offset] = toBlock(driverContext.blockFactory(), selected); } +$if(BytesRef)$ + void addValueOrdinal(int groupId, long valueOrdinal) { + values.add(groupId, valueOrdinal); + maxGroupId = Math.max(maxGroupId, groupId); + } + +$endif$ + void addValue(int groupId, $type$ v) { +$if(long)$ + values.add(groupId, v); +$elseif(double)$ + values.add(groupId, Double.doubleToLongBits(v)); +$elseif(BytesRef)$ + values.add(groupId, BlockHash.hashOrdToGroup(bytes.add(v))); +$elseif(int)$ + /* + * Encode the groupId and value into a single long - + * the top 32 bits for the group, the bottom 32 for the value. + */ + values.add((((long) groupId) << Integer.SIZE) | (v & 0xFFFFFFFFL)); +$elseif(float)$ + /* + * Encode the groupId and value into a single long - + * the top 32 bits for the group, the bottom 32 for the value. + */ + values.add((((long) groupId) << Float.SIZE) | (Float.floatToIntBits(v) & 0xFFFFFFFFL)); +$endif$ + maxGroupId = Math.max(maxGroupId, groupId); + } + /** * Builds a {@link Block} with the unique values collected for the {@code #selected} * groups. This is the implementation of the final and intermediate results of the agg. @@ -265,8 +299,15 @@ $endif$ return blockFactory.newConstantNullBlock(selected.getPositionCount()); } + try (var sorted = buildSorted(selected)) { + return buildOutputBlock(blockFactory, selected, sorted.counts, sorted.ids); + } + } + + private Sorted buildSorted(IntVector selected) { long selectedCountsSize = 0; long idsSize = 0; + Sorted sorted = null; try { /* * Get a count of all groups less than the maximum selected group. Count @@ -341,72 +382,102 @@ $endif$ idsSize = adjust; int[] ids = new int[total]; for (int id = 0; id < values.size(); id++) { -$if(long||BytesRef||double)$ + $if(long||BytesRef||double)$ int group = (int) values.getKey1(id); -$elseif(float||int)$ + $elseif(float||int)$ long both = values.get(id); int group = (int) (both >>> Float.SIZE); -$endif$ + $endif$ if (group < selectedCounts.length && selectedCounts[group] >= 0) { ids[selectedCounts[group]++] = id; } } + final long totalMemoryUsed = selectedCountsSize + idsSize; + sorted = new Sorted(() -> blockFactory.adjustBreaker(-totalMemoryUsed), selectedCounts, ids); + return sorted; + } finally { + if (sorted == null) { + blockFactory.adjustBreaker(-selectedCountsSize - idsSize); + } + } + } - /* - * Insert the ids in order. - */ + private Sorted sortedForOrdinalMerging(GroupingState other) { + if (sortedForOrdinalMerging == null) { + try (var selected = IntVector.range(0, maxGroupId + 1, blockFactory)) { + sortedForOrdinalMerging = buildSorted(selected); $if(BytesRef)$ - BytesRef scratch = new BytesRef(); -$endif$ - try ($Type$Block.Builder builder = blockFactory.new$Type$BlockBuilder(selected.getPositionCount())) { - int start = 0; - for (int s = 0; s < selected.getPositionCount(); s++) { - int group = selected.getInt(s); - int end = selectedCounts[group]; - int count = end - start; - switch (count) { - case 0 -> builder.appendNull(); - case 1 -> append(builder, ids[start]$if(BytesRef)$, scratch$endif$); - default -> { - builder.beginPositionEntry(); - for (int i = start; i < end; i++) { - append(builder, ids[i]$if(BytesRef)$, scratch$endif$); - } - builder.endPositionEntry(); - } + // hash all the bytes to the destination to avoid hashing them multiple times + BytesRef scratch = new BytesRef(); + final int totalValue = Math.toIntExact(bytes.size()); + blockFactory.adjustBreaker((long) totalValue * Integer.BYTES); + try { + final int[] mappedIds = new int[totalValue]; + for (int i = 0; i < totalValue; i++) { + var v = bytes.get(i, scratch); + mappedIds[i] = Math.toIntExact(BlockHash.hashOrdToGroup(other.bytes.add(v))); + } + // no longer need the bytes + bytes.close(); + bytes = null; + for (int i = 0; i < sortedForOrdinalMerging.ids.length; i++) { + sortedForOrdinalMerging.ids[i] = mappedIds[Math.toIntExact(values.getKey2(sortedForOrdinalMerging.ids[i]))]; } - start = end; + } finally { + blockFactory.adjustBreaker(-(long) totalValue * Integer.BYTES); } - return builder.build(); +$endif$ } - } finally { - blockFactory.adjustBreaker(-selectedCountsSize - idsSize); } + return sortedForOrdinalMerging; } + Block buildOutputBlock(BlockFactory blockFactory, IntVector selected, int[] selectedCounts, int[] ids) { + /* + * Insert the ids in order. + */ $if(BytesRef)$ - private void append($Type$Block.Builder builder, int id, BytesRef scratch) { - BytesRef value = bytes.get(values.getKey2(id), scratch); - builder.appendBytesRef(value); + BytesRef scratch = new BytesRef(); +$endif$ + try ($Type$Block.Builder builder = blockFactory.new$Type$BlockBuilder(selected.getPositionCount())) { + int start = 0; + for (int s = 0; s < selected.getPositionCount(); s++) { + int group = selected.getInt(s); + int end = selectedCounts[group]; + int count = end - start; + switch (count) { + case 0 -> builder.appendNull(); + case 1 -> builder.append$Type$(getValue(ids[start]$if(BytesRef)$, scratch$endif$)); + default -> { + builder.beginPositionEntry(); + for (int i = start; i < end; i++) { + builder.append$Type$(getValue(ids[i]$if(BytesRef)$, scratch$endif$)); + } + builder.endPositionEntry(); + } + } + start = end; + } + return builder.build(); + } } -$else$ - private void append($Type$Block.Builder builder, int id) { -$if(long)$ - long value = values.getKey2(id); + $type$ getValue(int valueId$if(BytesRef)$, BytesRef scratch$endif$) { +$if(BytesRef)$ + return bytes.get(values.getKey2(valueId), scratch); +$elseif(long)$ + return values.getKey2(valueId); $elseif(double)$ - double value = Double.longBitsToDouble(values.getKey2(id)); + return Double.longBitsToDouble(values.getKey2(valueId)); $elseif(float)$ - long both = values.get(id); - float value = Float.intBitsToFloat((int) both); + long both = values.get(valueId); + return Float.intBitsToFloat((int) both); $elseif(int)$ - long both = values.get(id); - int value = (int) both; + long both = values.get(valueId); + return (int) both; $endif$ - builder.append$Type$(value); } -$endif$ public void enableGroupIdTracking(SeenGroupIds seen) { // we figure out seen values from nulls on the values block } @@ -414,9 +485,9 @@ $endif$ @Override public void close() { $if(BytesRef)$ - Releasables.closeExpectNoException(values, bytes); + Releasables.closeExpectNoException(values, bytes, sortedForOrdinalMerging); $else$ - values.close(); + Releasables.closeExpectNoException(values, sortedForOrdinalMerging); $endif$ } } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/HashAggregationOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/HashAggregationOperator.java index 03a4ca2b0ad5e..f0f21c3530a9b 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/HashAggregationOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/HashAggregationOperator.java @@ -66,7 +66,7 @@ public String describe() { private final BlockHash blockHash; - private final List aggregators; + protected final List aggregators; private final DriverContext driverContext;