Add "emitEmptyBuckets" parameter to the "Bucket" function. #131609

Status: Draft (wants to merge 1 commit into base: main)

@@ -191,11 +191,12 @@ private static Operator operator(DriverContext driverContext, String grouping, S
         new BlockHash.GroupSpec(2, ElementType.BYTES_REF)
     );
     case TOP_N_LONGS -> List.of(
-        new BlockHash.GroupSpec(0, ElementType.LONG, null, new BlockHash.TopNDef(0, true, true, TOP_N_LIMIT))
+        new BlockHash.GroupSpec(0, ElementType.LONG, null, new BlockHash.TopNDef(0, true, true, TOP_N_LIMIT), null)
     );
     default -> throw new IllegalArgumentException("unsupported grouping [" + grouping + "]");
 };
 return new HashAggregationOperator(
+    groups,
     List.of(supplier(op, dataType, filter).groupingAggregatorFactory(AggregatorMode.SINGLE, List.of(groups.size()))),
     () -> BlockHash.build(groups, driverContext.blockFactory(), 16 * 1024, false),
     driverContext
@@ -122,6 +122,7 @@ private static Operator operator(DriverContext driverContext, int groups, String
 }
 List<BlockHash.GroupSpec> groupSpec = List.of(new BlockHash.GroupSpec(0, ElementType.LONG));
 return new HashAggregationOperator(
+    groupSpec,
     List.of(supplier(dataType).groupingAggregatorFactory(mode, List.of(1))),
     () -> BlockHash.build(groupSpec, driverContext.blockFactory(), 16 * 1024, false),
     driverContext

(Some generated files in this diff are not rendered by default and cannot be displayed.)
@@ -352,6 +352,7 @@ static TransportVersion def(int id) {
 public static final TransportVersion ESQL_TOPN_TIMINGS = def(9_128_0_00);
 public static final TransportVersion NODE_WEIGHTS_ADDED_TO_NODE_BALANCE_STATS = def(9_129_0_00);
 public static final TransportVersion RERANK_SNIPPETS = def(9_130_0_00);
+public static final TransportVersion ESQL_EMIT_EMPTY_BUCKETS = def(9_131_0_00);

 /*
  * STOP! READ THIS FIRST! No, really,
@@ -23,6 +23,10 @@ public class GroupingAggregator implements Releasable {

 private final AggregatorMode mode;

+public AggregatorMode getMode() {
+    return mode;
+}
+
 public interface Factory extends Function<DriverContext, GroupingAggregator>, Describable {}

 public GroupingAggregator(GroupingAggregatorFunction aggregatorFunction, AggregatorMode mode) {
@@ -127,6 +127,12 @@ public abstract class BlockHash implements Releasable, SeenGroupIds {
  */
 public record TopNDef(int order, boolean asc, boolean nullsFirst, int limit) {}

+public interface EmptyBucketGenerator {
+    int getEmptyBucketCount();
+
+    void generate(Block.Builder blockBuilder);
+}
+
 /**
  * Configuration for a BlockHash group spec that is doing text categorization.
  */
@@ -137,13 +143,19 @@ public enum OutputFormat {
     }
 }

-public record GroupSpec(int channel, ElementType elementType, @Nullable CategorizeDef categorizeDef, @Nullable TopNDef topNDef) {
+public record GroupSpec(
+    int channel,
+    ElementType elementType,
+    @Nullable CategorizeDef categorizeDef,
+    @Nullable TopNDef topNDef,
+    @Nullable EmptyBucketGenerator emptyBucketGenerator
+) {
     public GroupSpec(int channel, ElementType elementType) {
-        this(channel, elementType, null, null);
+        this(channel, elementType, null, null, null);
     }

     public GroupSpec(int channel, ElementType elementType, CategorizeDef categorizeDef) {
-        this(channel, elementType, categorizeDef, null);
+        this(channel, elementType, categorizeDef, null, null);
     }

     public boolean isCategorize() {
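For orientation, here is a minimal sketch of what an EmptyBucketGenerator implementation could look like. This class is hypothetical and not part of the PR: it pre-seeds one LONG key per hourly bucket over a fixed range, so that buckets with no matching rows still get a group key.

// Hypothetical example, not part of this PR: seeds one key per hour so that
// hourly date-histogram buckets with no matching rows still appear.
class HourlyEmptyBucketGenerator implements BlockHash.EmptyBucketGenerator {
    private static final long HOUR_MILLIS = 60 * 60 * 1000L;
    private final long startMillis;
    private final long endMillis;

    HourlyEmptyBucketGenerator(long startMillis, long endMillis) {
        this.startMillis = startMillis;
        this.endMillis = endMillis;
    }

    @Override
    public int getEmptyBucketCount() {
        // One bucket per full hour in [startMillis, endMillis).
        return (int) ((endMillis - startMillis) / HOUR_MILLIS);
    }

    @Override
    public void generate(Block.Builder blockBuilder) {
        // The caller hands us the builder for this group's key channel;
        // appending the rounded keys is what creates the empty buckets.
        LongBlock.Builder longs = (LongBlock.Builder) blockBuilder;
        for (long t = startMillis; t + HOUR_MILLIS <= endMillis; t += HOUR_MILLIS) {
            longs.appendLong(t);
        }
    }
}

A GroupSpec for that key channel would then carry the generator through the new fifth component, e.g. new BlockHash.GroupSpec(0, ElementType.LONG, null, null, new HourlyEmptyBucketGenerator(start, end)).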
@@ -20,6 +20,7 @@
 import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction;
 import org.elasticsearch.compute.aggregation.blockhash.BlockHash;
 import org.elasticsearch.compute.data.Block;
+import org.elasticsearch.compute.data.DocBlock;
 import org.elasticsearch.compute.data.IntArrayBlock;
 import org.elasticsearch.compute.data.IntBigArrayBlock;
 import org.elasticsearch.compute.data.IntVector;
@@ -34,6 +35,7 @@
 import java.util.Arrays;
 import java.util.List;
 import java.util.Objects;
+import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.function.Supplier;

 import static java.util.Objects.requireNonNull;
@@ -52,6 +54,7 @@ public record HashAggregationOperatorFactory(
 public Operator get(DriverContext driverContext) {
     if (groups.stream().anyMatch(BlockHash.GroupSpec::isCategorize)) {
         return new HashAggregationOperator(
+            groups,
             aggregators,
             () -> BlockHash.buildCategorizeBlockHash(
                 groups,
@@ -64,6 +67,7 @@ public Operator get(DriverContext driverContext) {
         );
     }
     return new HashAggregationOperator(
+        groups,
         aggregators,
         () -> BlockHash.build(groups, driverContext.blockFactory(), maxPageSize, false),
         driverContext
@@ -83,6 +87,7 @@ public String describe() {
 private boolean finished;
 private Page output;

+private final List<BlockHash.GroupSpec> groups;
 private final BlockHash blockHash;

 protected final List<GroupingAggregator> aggregators;
@@ -117,10 +122,12 @@ public String describe() {

 @SuppressWarnings("this-escape")
 public HashAggregationOperator(
+    List<BlockHash.GroupSpec> groups,
     List<GroupingAggregator.Factory> aggregators,
     Supplier<BlockHash> blockHash,
     DriverContext driverContext
 ) {
+    this.groups = groups;
     this.aggregators = new ArrayList<>(aggregators.size());
     this.driverContext = driverContext;
     boolean success = false;
@@ -142,8 +149,22 @@ public boolean needsInput() {
     return finished == false;
 }

+private final AtomicBoolean isInitialPage = new AtomicBoolean(true);
+
 @Override
 public void addInput(Page page) {
+    if (isInitialPage.compareAndSet(true, false)
+        && (aggregators.size() == 0 || AggregatorMode.INITIAL.equals(aggregators.get(0).getMode()))) {

[Review comment from a Contributor]
For later: it would be nice if we could do this on output() on the coordinator, on FINAL aggs only.
I remember the reason not to do this was that Bucket wouldn't be on the coord physical plan, as it's moved to an early eval.

+        Page initialPage = createInitialPage(page);
+        if (initialPage != null) {
+            addInputInternal(initialPage);
+            return;
+        }
+    }
+    addInputInternal(page);
+}
+
+private void addInputInternal(Page page) {
     try {
         GroupingAggregatorFunction.AddInput[] prepared = new GroupingAggregatorFunction.AddInput[aggregators.size()];
         class AddInput implements GroupingAggregatorFunction.AddInput {
@@ -289,6 +310,42 @@ protected Page wrapPage(Page page) {
     return page;
 }

+private Page createInitialPage(Page page) {
+    // If no groups are generating bucket keys, move on
+    if (groups.stream().allMatch(g -> g.emptyBucketGenerator() == null)) {

[Review comment from a Contributor]
I would add an assert on the constructor (probably?) that either all groups have a generator, or none have. If there's a mix, I guess we'll have a problem.

+        return page;
+    }
+    Block.Builder[] blockBuilders = new Block.Builder[page.getBlockCount()];
+    for (int channel = 0; channel < page.getBlockCount(); channel++) {
+        Block block = page.getBlock(channel);
+        blockBuilders[channel] = block.elementType().newBlockBuilder(block.getPositionCount(), driverContext.blockFactory());
+        blockBuilders[channel].copyFrom(block, 0, block.getPositionCount());
+    }
+    for (BlockHash.GroupSpec group : groups) {
+        BlockHash.EmptyBucketGenerator emptyBucketGenerator = group.emptyBucketGenerator();
+        if (emptyBucketGenerator != null) {
+            for (int channel = 0; channel < page.getBlockCount(); channel++) {
+                if (group.channel() == channel) {
+                    emptyBucketGenerator.generate(blockBuilders[channel]);
+                } else {
+                    for (int i = 0; i < emptyBucketGenerator.getEmptyBucketCount(); i++) {
+                        if (page.getBlock(channel) instanceof DocBlock) {
+                            // TODO: DocBlock doesn't allow appending nulls
+                            ((DocBlock.Builder) blockBuilders[channel]).appendShard(0).appendSegment(0).appendDoc(0);

[Review comment from a Contributor, on the two DocBlock lines above]
This feels dangerous. Adding a comment so we don't forget it.


+                        } else {
+                            blockBuilders[channel].appendNull();
+                        }
+                    }
+                }
+            }
+        }
+    }
+    Block[] blocks = Arrays.stream(blockBuilders).map(Block.Builder::build).toArray(Block[]::new);
+    Releasables.closeExpectNoException(blockBuilders);
+    page.releaseBlocks();
+    return new Page(blocks);
+}

 @Override
 public String toString() {
     StringBuilder sb = new StringBuilder();
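To make the injection concrete, here is a toy walk-through of createInitialPage with invented values, using the hypothetical hourly generator sketched earlier seeding three keys. Note the AtomicBoolean guard means this happens at most once, on the first page seen by an INITIAL-mode (or aggregator-less) operator.

// Toy walk-through (invented values). First incoming page:
//   channel 0 (bucket keys):    [5]
//   channel 1 (values to agg):  [42]
// A generator on channel 0 with getEmptyBucketCount() == 3 seeds keys {0, 1, 2}:
//   channel 0: [5, 0, 1, 2]            // original keys, then generated keys
//   channel 1: [42, null, null, null]  // one null per seeded row
// The null rows register the group keys in the BlockHash without contributing
// to the aggregate values, so the empty buckets survive into the final output.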
@@ -40,7 +40,7 @@ public record Factory(
 @Override
 public Operator get(DriverContext driverContext) {
     // TODO: use TimeSeriesBlockHash when possible
-    return new TimeSeriesAggregationOperator(timeBucket, aggregators, () -> {
+    return new TimeSeriesAggregationOperator(timeBucket, groups, aggregators, () -> {
         if (sortedInput && groups.size() == 2) {
             return new TimeSeriesBlockHash(groups.get(0).channel(), groups.get(1).channel(), driverContext.blockFactory());
         } else {
@@ -68,11 +68,12 @@ public String describe() {

 public TimeSeriesAggregationOperator(
     Rounding.Prepared timeBucket,
+    List<BlockHash.GroupSpec> groups,
     List<GroupingAggregator.Factory> aggregators,
     Supplier<BlockHash> blockHash,
     DriverContext driverContext
 ) {
-    super(aggregators, blockHash, driverContext);
+    super(groups, aggregators, blockHash, driverContext);
     this.timeBucket = timeBucket;
 }
@@ -910,7 +910,7 @@ public void close() {
         };
     };

-    return new HashAggregationOperator(aggregators, blockHashSupplier, driverContext);
+    return new HashAggregationOperator(groups, aggregators, blockHashSupplier, driverContext);
 }

 @Override
@@ -363,7 +363,7 @@ private void hashBatchesCallbackOnLast(Consumer<OrdsAndKeys> callback, Block[]..
 private BlockHash buildBlockHash(int emitBatchSize, Block... values) {
     List<BlockHash.GroupSpec> specs = new ArrayList<>(values.length);
     for (int c = 0; c < values.length; c++) {
-        specs.add(new BlockHash.GroupSpec(c, values[c].elementType(), null, topNDef(c)));
+        specs.add(new BlockHash.GroupSpec(c, values[c].elementType(), null, topNDef(c), null));
     }
     assert forcePackedHash == false : "Packed TopN hash not implemented yet";
     /*return forcePackedHash