Skip to content

Commit bdbc8b5

Browse files
[FLINK-38576][table] Align commonJoinKey in MultiJoin for logical and physical rules
1 parent 1c2d953 commit bdbc8b5

File tree

7 files changed

+3284
-813
lines changed

7 files changed

+3284
-813
lines changed
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
19+
package org.apache.flink.table.utils;
20+
21+
import org.apache.flink.annotation.Internal;
22+
import org.apache.flink.util.FlinkRuntimeException;
23+
24+
/** Thrown when a MultiJoin node has no common join key. */
25+
@Internal
26+
public class NoCommonJoinKeyException extends FlinkRuntimeException {
27+
private static final long serialVersionUID = 1L;
28+
29+
public NoCommonJoinKeyException(String message) {
30+
super(message);
31+
}
32+
33+
public NoCommonJoinKeyException(String message, Throwable cause) {
34+
super(message, cause);
35+
}
36+
37+
public NoCommonJoinKeyException(Throwable cause) {
38+
super(cause);
39+
}
40+
}

flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/JoinToMultiJoinRule.java

Lines changed: 38 additions & 130 deletions
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,18 @@
1818

1919
package org.apache.flink.table.planner.plan.rules.logical;
2020

21-
import org.apache.flink.api.java.tuple.Tuple2;
2221
import org.apache.flink.table.api.TableException;
22+
import org.apache.flink.table.planner.calcite.FlinkTypeFactory;
2323
import org.apache.flink.table.planner.hint.FlinkHints;
2424
import org.apache.flink.table.planner.hint.StateTtlHint;
2525
import org.apache.flink.table.planner.plan.nodes.physical.stream.StreamPhysicalMultiJoin;
2626
import org.apache.flink.table.planner.plan.utils.IntervalJoinUtil;
27+
import org.apache.flink.table.runtime.operators.join.stream.keyselector.AttributeBasedJoinKeyExtractor;
28+
import org.apache.flink.table.runtime.operators.join.stream.keyselector.JoinKeyExtractor;
29+
import org.apache.flink.table.types.logical.RowType;
30+
import org.apache.flink.table.utils.NoCommonJoinKeyException;
2731

2832
import org.apache.calcite.plan.RelOptRuleCall;
29-
import org.apache.calcite.plan.RelOptTable;
3033
import org.apache.calcite.plan.RelOptUtil;
3134
import org.apache.calcite.plan.RelRule;
3235
import org.apache.calcite.plan.hep.HepRelVertex;
@@ -36,21 +39,15 @@
3639
import org.apache.calcite.rel.core.Join;
3740
import org.apache.calcite.rel.core.JoinInfo;
3841
import org.apache.calcite.rel.core.JoinRelType;
39-
import org.apache.calcite.rel.core.TableFunctionScan;
40-
import org.apache.calcite.rel.core.TableScan;
41-
import org.apache.calcite.rel.core.Values;
4242
import org.apache.calcite.rel.hint.RelHint;
4343
import org.apache.calcite.rel.logical.LogicalJoin;
4444
import org.apache.calcite.rel.logical.LogicalSnapshot;
45-
import org.apache.calcite.rel.metadata.RelColumnOrigin;
46-
import org.apache.calcite.rel.metadata.RelMetadataQuery;
4745
import org.apache.calcite.rel.rules.CoreRules;
4846
import org.apache.calcite.rel.rules.FilterMultiJoinMergeRule;
4947
import org.apache.calcite.rel.rules.MultiJoin;
5048
import org.apache.calcite.rel.rules.ProjectMultiJoinMergeRule;
5149
import org.apache.calcite.rel.rules.TransformationRule;
5250
import org.apache.calcite.rex.RexBuilder;
53-
import org.apache.calcite.rex.RexCall;
5451
import org.apache.calcite.rex.RexInputRef;
5552
import org.apache.calcite.rex.RexNode;
5653
import org.apache.calcite.rex.RexUtil;
@@ -65,14 +62,14 @@
6562
import java.util.Arrays;
6663
import java.util.Collections;
6764
import java.util.HashMap;
68-
import java.util.HashSet;
6965
import java.util.List;
7066
import java.util.Map;
7167
import java.util.Objects;
72-
import java.util.Set;
7368
import java.util.stream.Collectors;
69+
import java.util.stream.Stream;
7470

7571
import static org.apache.flink.table.planner.hint.StateTtlHint.STATE_TTL;
72+
import static org.apache.flink.table.planner.plan.utils.MultiJoinUtil.createJoinAttributeMap;
7673

7774
/**
7875
* Flink Planner rule to flatten a tree of {@link Join}s into a single {@link MultiJoin} with N
@@ -442,134 +439,45 @@ private boolean canCombine(RelNode input, Join origJoin) {
442439

443440
/**
444441
* Checks if original join and child multi-join have common join keys to decide if we can merge
445-
* them into a single MultiJoin with one more input.
442+
* them into a single MultiJoin with one more input. The method probes whether an {@link
443+
* AttributeBasedJoinKeyExtractor} built over the combined inputs yields a common join key.
446444
*
447445
* @param origJoin original Join
448446
* @param otherJoin child MultiJoin
449447
* @return true if original Join and child multi-join have at least one common JoinKey
450448
*/
451449
private boolean haveCommonJoinKey(Join origJoin, MultiJoin otherJoin) {
452-
Set<String> origJoinKeys = getJoinKeys(origJoin);
453-
Set<String> otherJoinKeys = getJoinKeys(otherJoin);
454-
455-
origJoinKeys.retainAll(otherJoinKeys);
456-
457-
return !origJoinKeys.isEmpty();
458-
}
459-
460-
/**
461-
* Returns a set of join keys as strings following this format [table_name.field_name].
462-
*
463-
* @param join Join or MultiJoin node
464-
* @return set of all the join keys (keys from join conditions)
465-
*/
466-
public Set<String> getJoinKeys(RelNode join) {
467-
Set<String> joinKeys = new HashSet<>();
468-
List<RexCall> conditions = Collections.emptyList();
469-
List<RelNode> inputs = join.getInputs();
470-
471-
if (join instanceof Join) {
472-
conditions = collectConjunctions(((Join) join).getCondition());
473-
} else if (join instanceof MultiJoin) {
474-
conditions =
475-
((MultiJoin) join)
476-
.getOuterJoinConditions().stream()
477-
.flatMap(cond -> collectConjunctions(cond).stream())
478-
.collect(Collectors.toList());
450+
final List<RelNode> combinedJoinInputs =
451+
Stream.concat(otherJoin.getInputs().stream(), Stream.of(origJoin.getRight()))
452+
.collect(Collectors.toUnmodifiableList());
453+
454+
final List<RowType> combinedInputTypes =
455+
combinedJoinInputs.stream()
456+
.map(i -> FlinkTypeFactory.toLogicalRowType(i.getRowType()))
457+
.collect(Collectors.toUnmodifiableList());
458+
459+
final List<RexNode> combinedJoinConditions =
460+
Stream.concat(
461+
otherJoin.getOuterJoinConditions().stream(),
462+
List.of(origJoin.getCondition()).stream())
463+
.collect(Collectors.toUnmodifiableList());
464+
465+
final Map<Integer, List<AttributeBasedJoinKeyExtractor.ConditionAttributeRef>>
466+
joinAttributeMap =
467+
createJoinAttributeMap(combinedJoinInputs, combinedJoinConditions);
468+
469+
boolean haveCommonJoinKey = false;
470+
try {
471+
// probe by constructing an AttributeBasedJoinKeyExtractor to check whether the common
472+
// join key structures can be initialized for the combined inputs
473+
final JoinKeyExtractor keyExtractor =
474+
new AttributeBasedJoinKeyExtractor(joinAttributeMap, combinedInputTypes);
475+
haveCommonJoinKey = keyExtractor.getCommonJoinKeyIndices(0).length > 0;
476+
} catch (NoCommonJoinKeyException ignored) {
477+
// failed to instantiate common join key structures => no common join key
479478
}
480479

481-
RelMetadataQuery mq = join.getCluster().getMetadataQuery();
482-
483-
for (RexCall condition : conditions) {
484-
for (RexNode operand : condition.getOperands()) {
485-
if (operand instanceof RexInputRef) {
486-
addJoinKeysByOperand((RexInputRef) operand, inputs, mq, joinKeys);
487-
}
488-
}
489-
}
490-
491-
return joinKeys;
492-
}
493-
494-
/**
495-
* Retrieves conjunctions from joinCondition.
496-
*
497-
* @param joinCondition join condition
498-
* @return List of RexCalls representing conditions
499-
*/
500-
private List<RexCall> collectConjunctions(RexNode joinCondition) {
501-
return RelOptUtil.conjunctions(joinCondition).stream()
502-
.map(rexNode -> (RexCall) rexNode)
503-
.collect(Collectors.toList());
504-
}
505-
506-
/**
507-
* Appends join key's string representation to the set of join keys.
508-
*
509-
* @param ref input ref to the operand
510-
* @param inputs List of node's inputs
511-
* @param mq RelMetadataQuery needed to retrieve column origins
512-
* @param joinKeys Set of join keys to be added
513-
*/
514-
private void addJoinKeysByOperand(
515-
RexInputRef ref, List<RelNode> inputs, RelMetadataQuery mq, Set<String> joinKeys) {
516-
int inputRefIndex = ref.getIndex();
517-
Tuple2<RelNode, Integer> targetInputAndIdx = getTargetInputAndIdx(inputRefIndex, inputs);
518-
RelNode targetInput = targetInputAndIdx.f0;
519-
int idxInTargetInput = targetInputAndIdx.f1;
520-
521-
Set<RelColumnOrigin> origins = mq.getColumnOrigins(targetInput, idxInTargetInput);
522-
if (origins != null) {
523-
for (RelColumnOrigin origin : origins) {
524-
RelOptTable originTable = origin.getOriginTable();
525-
List<String> qualifiedName = originTable.getQualifiedName();
526-
String fieldName =
527-
originTable
528-
.getRowType()
529-
.getFieldList()
530-
.get(origin.getOriginColumnOrdinal())
531-
.getName();
532-
joinKeys.add(qualifiedName.get(qualifiedName.size() - 1) + "." + fieldName);
533-
}
534-
}
535-
}
536-
537-
/**
538-
* Get real table that contains needed input ref (join key).
539-
*
540-
* @param inputRefIndex index of the required field
541-
* @param inputs inputs of the node
542-
* @return target input and the index of the required field within that input
543-
*/
544-
private Tuple2<RelNode, Integer> getTargetInputAndIdx(int inputRefIndex, List<RelNode> inputs) {
545-
RelNode targetInput = null;
546-
int idxInTargetInput = 0;
547-
int inputFieldEnd = 0;
548-
for (RelNode input : inputs) {
549-
inputFieldEnd += input.getRowType().getFieldCount();
550-
if (inputRefIndex < inputFieldEnd) {
551-
targetInput = input;
552-
int targetInputStartIdx = inputFieldEnd - input.getRowType().getFieldCount();
553-
idxInTargetInput = inputRefIndex - targetInputStartIdx;
554-
break;
555-
}
556-
}
557-
558-
targetInput =
559-
(targetInput instanceof HepRelVertex)
560-
? ((HepRelVertex) targetInput).getCurrentRel()
561-
: targetInput;
562-
563-
assert targetInput != null;
564-
565-
if (targetInput instanceof TableScan
566-
|| targetInput instanceof Values
567-
|| targetInput instanceof TableFunctionScan
568-
|| targetInput.getInputs().isEmpty()) {
569-
return new Tuple2<>(targetInput, idxInTargetInput);
570-
} else {
571-
return getTargetInputAndIdx(idxInTargetInput, targetInput.getInputs());
572-
}
480+
return haveCommonJoinKey;
573481
}
574482

575483
/**

0 commit comments

Comments
 (0)