diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/QueryOptimizer.java b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/QueryOptimizer.java index da92a1a160ee2..8bc9c3e70745c 100644 --- a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/QueryOptimizer.java +++ b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/QueryOptimizer.java @@ -45,6 +45,7 @@ import com.starrocks.sql.optimizer.rule.transformation.CTEProduceAddProjectionRule; import com.starrocks.sql.optimizer.rule.transformation.ConvertToEqualForNullRule; import com.starrocks.sql.optimizer.rule.transformation.DeriveRangeJoinPredicateRule; +import com.starrocks.sql.optimizer.rule.transformation.DrivingTableSelection; import com.starrocks.sql.optimizer.rule.transformation.EliminateAggRule; import com.starrocks.sql.optimizer.rule.transformation.EliminateConstantCTERule; import com.starrocks.sql.optimizer.rule.transformation.EliminateSortColumnWithEqualityPredicateRule; @@ -544,6 +545,7 @@ private OptExpression logicalRuleRewrite( // we need to compute the stats of child project(like subfield). skewJoinOptimize(tree, rootTaskContext); scheduler.rewriteOnce(tree, rootTaskContext, new IcebergEqualityDeleteRewriteRule()); + scheduler.rewriteOnce(tree, rootTaskContext, new DrivingTableSelection()); tree = pruneSubfield(tree, rootTaskContext, requiredColumns); diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/operator/scalar/CompoundPredicateOperator.java b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/operator/scalar/CompoundPredicateOperator.java index 7630ed2031760..1128b30ad8578 100644 --- a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/operator/scalar/CompoundPredicateOperator.java +++ b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/operator/scalar/CompoundPredicateOperator.java @@ -116,7 +116,7 @@ public String debugString() { } } - private List normalizeChildren() { + public List normalizeChildren() { List sortedChildren; switch (type) { case AND: diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/RuleType.java b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/RuleType.java index 80175c21f6b2c..bfd55db58e873 100644 --- a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/RuleType.java +++ b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/RuleType.java @@ -16,6 +16,7 @@ public enum RuleType { TRANSFORMATION_RULES, + TF_DRIVING_TABLE_SELECTION, TF_JOIN_ASSOCIATIVITY_INNER, TF_JOIN_ASSOCIATIVITY_OUTER, diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/transformation/DrivingTableSelection.java b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/transformation/DrivingTableSelection.java new file mode 100755 index 0000000000000..549753d2d8756 --- /dev/null +++ b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/transformation/DrivingTableSelection.java @@ -0,0 +1,364 @@ +// Copyright 2021-present StarRocks, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.starrocks.sql.optimizer.rule.transformation; + +import com.starrocks.analysis.JoinOperator; +import com.starrocks.catalog.Table; +import com.starrocks.common.Pair; +import com.starrocks.sql.optimizer.OptExpression; +import com.starrocks.sql.optimizer.OptimizerContext; +import com.starrocks.sql.optimizer.base.ColumnRefFactory; +import com.starrocks.sql.optimizer.operator.Operator; +import com.starrocks.sql.optimizer.operator.OperatorType; +import com.starrocks.sql.optimizer.operator.logical.LogicalAssertOneRowOperator; +import com.starrocks.sql.optimizer.operator.logical.LogicalJoinOperator; +import com.starrocks.sql.optimizer.operator.logical.LogicalProjectOperator; +import com.starrocks.sql.optimizer.operator.logical.LogicalScanOperator; +import com.starrocks.sql.optimizer.operator.logical.LogicalUnionOperator; +import com.starrocks.sql.optimizer.operator.pattern.Pattern; +import com.starrocks.sql.optimizer.operator.scalar.BinaryPredicateOperator; +import com.starrocks.sql.optimizer.operator.scalar.ColumnRefOperator; +import com.starrocks.sql.optimizer.operator.scalar.CompoundPredicateOperator; +import com.starrocks.sql.optimizer.operator.scalar.ConstantOperator; +import com.starrocks.sql.optimizer.operator.scalar.ScalarOperator; +import com.starrocks.sql.optimizer.rewrite.BaseScalarOperatorShuttle; +import com.starrocks.sql.optimizer.rule.RuleType; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.function.BiFunction; + +public class DrivingTableSelection extends TransformationRule { + + public DrivingTableSelection() { + super(RuleType.TF_DRIVING_TABLE_SELECTION, Pattern.create(OperatorType.LOGICAL_JOIN)); + } + + private class Node { + OptExpression parent; + OptExpression child; + Integer childIndex; + + public Node(OptExpression parent, OptExpression child, Integer childIndex) { + this.parent = parent; + this.child = child; + this.childIndex = childIndex; + } + } + + private Optional getSourceTableId(OptExpression parent, int childIdx, OptExpression node, List projections) { + Operator operator = node.getOp(); + if (operator instanceof LogicalScanOperator) { + return Optional.of(((LogicalScanOperator) operator).getTable().getId()); + } + if (operator instanceof LogicalJoinOperator) { + return Optional.empty(); + } + for (int i = 0; i < node.getInputs().size(); ++i) { + OptExpression child = node.inputAt(i); + Optional tableId = getSourceTableId(node, i, child, projections); + if (tableId.isPresent()) { + return tableId; + } + } + if (operator instanceof LogicalProjectOperator) { + projections.add(new Node(parent, node, childIdx)); + } + return Optional.empty(); + } + + private void extractJoins(OptExpression parent, int childIdx, OptExpression root, List> joinTables, + List> joinWithTableIdx, List projections) { + if (root.getOp() instanceof LogicalJoinOperator joinOperator && + (joinOperator.getJoinType().isCrossJoin() || joinOperator.getJoinType().isInnerJoin() && childIdx == -1)) { + Optional tableIdx = Optional.empty(); + for (int i = 0; i < root.getInputs().size(); i++) { + OptExpression child = root.inputAt(i); + Optional sourceTableId = getSourceTableId(root, i, child, projections); + if (sourceTableId.isPresent()) { + OptExpression tableChild = root.inputAt(i); + joinTables.add(new Pair<>(sourceTableId.get(), tableChild)); + joinWithTableIdx.add(new Pair<>(new Node(parent, root, childIdx), i)); + + if (tableIdx.isPresent()) { + tableIdx = Optional.empty(); + } else { + tableIdx = Optional.of(i); + } + } + } + if (tableIdx.isPresent()) { + int joinIdx = tableIdx.get() == 1 ? 0 : 1; + extractJoins(root, joinIdx, root.inputAt(joinIdx), joinTables, joinWithTableIdx, projections); + } + } else if (root.getOp() instanceof LogicalProjectOperator) { + extractJoins(root, 0, root.inputAt(0), joinTables, joinWithTableIdx, projections); + } + } + + boolean isCrossJoin(OptExpression root) { + if (root.getOp() instanceof LogicalJoinOperator joinOp && joinOp.getJoinType().isCrossJoin()) { + return true; + } else if (root.getOp() instanceof LogicalProjectOperator) { + return isCrossJoin(root.inputAt(0)); + } else { + return false; + } + } + + private void extractJoinOutputColumnMapping(OptExpression root, Map columnMapping, boolean isRoot) { + Operator operator = root.getOp(); + if (operator instanceof LogicalJoinOperator joinOperator) { + JoinOperator joinType = joinOperator.getJoinType(); + + if (!joinType.isCrossJoin() && !isRoot) { + return; + } + + } else if (operator instanceof LogicalProjectOperator projectOperator) { + projectOperator.getRowOutputInfo(root.getInputs()).getColumnRefMap().forEach((columnOp, scalarOp) -> { + for (int columnId : scalarOp.getUsedColumns().getColumnIds()) { + if (columnOp.getId() != columnId) { + columnMapping.put(columnOp.getId(), columnId); + } + } + }); + } else if (!(operator instanceof LogicalAssertOneRowOperator)) { + return; + } + + for (int i = 0; i < root.getInputs().size(); ++i) { + extractJoinOutputColumnMapping(root.inputAt(i), columnMapping, false); + } + } + + @Override + public boolean check(OptExpression input, OptimizerContext context) { + Operator inputOp = input.getOp(); + + boolean isInnerJoin = + inputOp instanceof LogicalJoinOperator && ((LogicalJoinOperator) inputOp).getJoinType().isInnerJoin(); + if (!isInnerJoin) { + return false; + } + if (input.getInputs().stream().noneMatch(this::isCrossJoin)) { + return false; + } + return super.check(input, context); + } + + @Override + public List transform(OptExpression input, OptimizerContext context) { + ColumnRefFactory columnRefFactory = context.getColumnRefFactory(); + LogicalJoinOperator rootJoinOp = (LogicalJoinOperator) input.getOp(); + ScalarOperator innerOnPredicate = rootJoinOp.getOnPredicate(); + + Map outputColumnMapping = new HashMap<>(); + extractJoinOutputColumnMapping(input, outputColumnMapping, true); + + Map> tableRelations = new HashMap<>(); + List> compoundTypes = new ArrayList<>(); + if (rootJoinOp.getJoinType().isInnerJoin() && rootJoinOp.getOnPredicate() != null) { + if (binaryRelation(rootJoinOp.getOnPredicate(), columnRefFactory, tableRelations, outputColumnMapping, compoundTypes, + 0, null)) { + return Collections.emptyList(); + } + } + if (tableRelations.size() <= 1) { + return Collections.emptyList(); + } + // all tables have a relationship with only the same table. + // e.g. select * from t1, t2, t3 inner join t4 on t4.c1 = t1.c1 or t4.c1 = t2.c1 or t4.c1 = t3.c1 + Optional> tableDepthMap = Optional.empty(); + for (Map.Entry> entry : tableRelations.entrySet()) { + HashMap map = entry.getValue(); + if (map.size() > 1) { + if (tableDepthMap.isPresent()) { + return Collections.emptyList(); + } + // driving table first + map.put(entry.getKey(), 0); + tableDepthMap = Optional.of(map); + } + } + if (tableDepthMap.isPresent()) { + List> joinTables = new ArrayList<>(); + List> joinWithTableIdx = new ArrayList<>(); + List projections = new ArrayList<>(); + extractJoins(null, -1, input, joinTables, joinWithTableIdx, projections); + + // JoinReorder + Map finalTableDepthMap = tableDepthMap.get(); + joinTables.sort(Comparator.comparingInt((Pair pair) -> finalTableDepthMap.get(pair.first)) + .thenComparingLong(pair -> pair.first)); + + for (int i = 0; i < joinTables.size(); i++) { + Pair joinPair = joinWithTableIdx.get(i); + Node join = joinPair.first; + OptExpression joinRoot = join.child; + int tableChildId = joinPair.second; + Pair tablePair = joinTables.get(i); + OptExpression tableChild = tablePair.second; + + joinRoot.setChild(tableChildId, tableChild); + } + + // Projection + Collections.reverse(projections); + for (Node node : projections) { + OptExpression child = node.child; + List childInputs = node.parent.inputAt(node.childIndex).getInputs(); + + Map newMap = new HashMap<>(); + for (OptExpression projectionChild : childInputs) { + newMap.putAll(projectionChild.getRowOutputInfo().getColumnRefMap()); + } + + LogicalProjectOperator projectionOp = (LogicalProjectOperator) child.getOp(); + node.parent.setChild(node.childIndex, OptExpression.create( + LogicalProjectOperator.builder().withOperator(projectionOp).setColumnRefMap(newMap).build(), + childInputs)); + } + ScalarOperator newOnPredicate = + innerOnPredicate.accept(new ColumnMappingRewriter(outputColumnMapping, columnRefFactory), null); + + compoundTypes.sort(Comparator.comparingInt((pair -> pair.first))); + for (int i = 0; i < joinWithTableIdx.size(); i++) { + Pair joinPair = joinWithTableIdx.get(i); + Node join = joinPair.first; + OptExpression joinRoot = join.child; + + if (join.parent != null) { + Pair compoundTypePair = compoundTypes.get(i - 1); + if (!compoundTypePair.second.equals(CompoundPredicateOperator.CompoundType.OR)) { + continue; + } + // rewrite to union all + List result = new ArrayList<>(); + List> childOutputColumns = List.of(new ArrayList<>(), new ArrayList<>()); + + Map leftMap = new HashMap<>(); + Map rightMap = new HashMap<>(); + OptExpression leftInput = joinRoot.inputAt(0); + OptExpression rightInput = joinRoot.inputAt(1); + + extractUnionInput(result, childOutputColumns, leftMap, rightMap, leftInput); + extractUnionInput(result, childOutputColumns, rightMap, leftMap, rightInput); + List newInputs = new ArrayList<>(); + newInputs.add( + OptExpression.create(LogicalProjectOperator.builder().setColumnRefMap(leftMap).build(), leftInput)); + newInputs.add( + OptExpression.create(LogicalProjectOperator.builder().setColumnRefMap(rightMap).build(), rightInput)); + + join.parent.setChild(join.childIndex, + OptExpression.create(new LogicalUnionOperator(result, childOutputColumns, true), newInputs)); + } + } + + return List.of(OptExpression.create( + LogicalJoinOperator.builder().withOperator(rootJoinOp).setOnPredicate(newOnPredicate).build(), + input.getInputs())); + } + return Collections.emptyList(); + } + + private void extractUnionInput(List result, List> childOutputColumns, + Map leftMap, + Map rightMap, OptExpression input) { + input.getRowOutputInfo().getColumnRefMap().forEach((columnOp, scalarOp) -> { + ColumnRefOperator nullableColumnOp = + new ColumnRefOperator(columnOp.getId(), columnOp.getType(), columnOp.getName(), true); + result.add(nullableColumnOp); + childOutputColumns.get(0).add(nullableColumnOp); + childOutputColumns.get(1).add(nullableColumnOp); + leftMap.put(nullableColumnOp, scalarOp); + rightMap.put(nullableColumnOp, ConstantOperator.createNull(scalarOp.getType())); + }); + } + + private Long getTableIdByColumnId(int columnId, ColumnRefFactory columnRefFactory, Map columnMapping) { + Table table = columnRefFactory.getTableForColumn(columnId); + + if (null == table) { + if (columnMapping.containsKey(columnId)) { + return getTableIdByColumnId(columnMapping.get(columnId), columnRefFactory, columnMapping); + } else { + return null; + } + } else { + return table.getId(); + } + } + + private boolean binaryRelation(ScalarOperator onPredicate, ColumnRefFactory columnRefFactory, + Map> tableRelations, Map columnMapping, + List> compoundTypes, + int depth, CompoundPredicateOperator.CompoundType compoundType) { + if (onPredicate instanceof CompoundPredicateOperator compoundPredicate) { + for (ScalarOperator scalarOperator : compoundPredicate.normalizeChildren()) { + if (binaryRelation(scalarOperator, columnRefFactory, tableRelations, columnMapping, compoundTypes, depth + 1, + compoundPredicate.getCompoundType())) { + return true; + } + } + } else if (onPredicate instanceof BinaryPredicateOperator && depth != 0) { + int[] columnIds = onPredicate.getUsedColumns().getColumnIds(); + if (columnIds.length == 2) { + int leftIdx = columnIds[0]; + int rightIdx = columnIds[1]; + + Long leftTableId = getTableIdByColumnId(leftIdx, columnRefFactory, columnMapping); + Long rightTableId = getTableIdByColumnId(rightIdx, columnRefFactory, columnMapping); + + if (leftTableId != null && rightTableId != null) { + BiFunction function = + (tableId, tableDepth) -> tableDepth == null ? depth : Math.min(depth, tableDepth); + tableRelations.computeIfAbsent(leftTableId, k -> new HashMap<>()).compute(rightTableId, function); + tableRelations.computeIfAbsent(rightTableId, k -> new HashMap<>()).compute(leftTableId, function); + compoundTypes.add(new Pair<>(depth, compoundType)); + } + } else { + // e,g, `t1.c1 = t2.c1 + t3.c1` + return columnIds.length > 2; + } + } + return false; + } + + private class ColumnMappingRewriter extends BaseScalarOperatorShuttle { + Map columnMapping; + ColumnRefFactory columnRefFactory; + + public ColumnMappingRewriter(Map columnMapping, ColumnRefFactory columnRefFactory) { + this.columnMapping = columnMapping; + this.columnRefFactory = columnRefFactory; + } + + @Override + public ScalarOperator visitVariableReference(ColumnRefOperator variable, Void context) { + if (columnMapping.containsKey(variable.getId())) { + Integer mappingColumnId = columnMapping.get(variable.getId()); + return columnRefFactory.getColumnRef(mappingColumnId); + } else { + return super.visitVariableReference(variable, context); + } + } + } +} diff --git a/fe/fe-core/src/test/java/com/starrocks/sql/optimizer/rule/transformation/DrivingTableSelectionRuleTest.java b/fe/fe-core/src/test/java/com/starrocks/sql/optimizer/rule/transformation/DrivingTableSelectionRuleTest.java new file mode 100755 index 0000000000000..4c77dfb8293c4 --- /dev/null +++ b/fe/fe-core/src/test/java/com/starrocks/sql/optimizer/rule/transformation/DrivingTableSelectionRuleTest.java @@ -0,0 +1,160 @@ +// Copyright 2021-present StarRocks, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.starrocks.sql.optimizer.rule.transformation; + +import com.google.common.collect.Maps; +import com.starrocks.analysis.BinaryType; +import com.starrocks.analysis.JoinOperator; +import com.starrocks.catalog.Column; +import com.starrocks.catalog.OlapTable; +import com.starrocks.catalog.ScalarType; +import com.starrocks.catalog.Table; +import com.starrocks.sql.optimizer.OptExpression; +import com.starrocks.sql.optimizer.OptimizerContext; +import com.starrocks.sql.optimizer.OptimizerFactory; +import com.starrocks.sql.optimizer.base.ColumnRefFactory; +import com.starrocks.sql.optimizer.operator.logical.LogicalJoinOperator; +import com.starrocks.sql.optimizer.operator.logical.LogicalOlapScanOperator; +import com.starrocks.sql.optimizer.operator.logical.LogicalProjectOperator; +import com.starrocks.sql.optimizer.operator.scalar.BinaryPredicateOperator; +import com.starrocks.sql.optimizer.operator.scalar.ColumnRefOperator; +import com.starrocks.sql.optimizer.operator.scalar.CompoundPredicateOperator; +import com.starrocks.sql.optimizer.operator.scalar.ConstantOperator; +import com.starrocks.sql.optimizer.operator.scalar.ScalarOperator; +import com.starrocks.utframe.UtFrameUtils; +import org.junit.Test; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +public class DrivingTableSelectionRuleTest { + + @Test + public void transform() { + DrivingTableSelection rule = new DrivingTableSelection(); + ColumnRefFactory columnRefFactory = new ColumnRefFactory(); + + Column c1 = new Column("c1", ScalarType.INT, true); + Column c2 = new Column("c2", ScalarType.INT, true); + Column c3 = new Column("c3", ScalarType.INT, true); + Column c4 = new Column("c4", ScalarType.INT, true); + Column c5 = new Column("c4", ScalarType.INT, true); + Column c6 = new Column("c5", ScalarType.INT, true); + ColumnRefOperator c1Operator = columnRefFactory.create("c1", ScalarType.INT, true); + ColumnRefOperator c2Operator = columnRefFactory.create("c2", ScalarType.INT, true); + ColumnRefOperator c3Operator = columnRefFactory.create("c3", ScalarType.INT, true); + ColumnRefOperator c4Operator = columnRefFactory.create("c4", ScalarType.INT, true); + ColumnRefOperator c5Operator = columnRefFactory.create("c5", ScalarType.INT, true); + ColumnRefOperator c6Operator = columnRefFactory.create("c6", ScalarType.INT, true); + + Table t = new OlapTable(0, "t", List.of(c1, c2, c3), null, null, null); + Table t1 = new OlapTable(1, "t1", List.of(c4), null, null, null); + Table t2 = new OlapTable(2, "t2", List.of(c5), null, null, null); + Table t3 = new OlapTable(3, "t3", List.of(c6), null, null, null); + + columnRefFactory.updateColumnRefToColumns(c1Operator, c1, t); + columnRefFactory.updateColumnRefToColumns(c2Operator, c2, t); + columnRefFactory.updateColumnRefToColumns(c3Operator, c3, t); + columnRefFactory.updateColumnRefToColumns(c4Operator, c4, t1); + columnRefFactory.updateColumnRefToColumns(c5Operator, c5, t2); + columnRefFactory.updateColumnRefToColumns(c6Operator, c6, t3); + + Map tColumnMap = Maps.newHashMap(); + Map t1ColumnMap = Maps.newHashMap(); + Map t2ColumnMap = Maps.newHashMap(); + Map t3ColumnMap = Maps.newHashMap(); + tColumnMap.put(c1Operator, c1); + tColumnMap.put(c2Operator, c2); + tColumnMap.put(c3Operator, c3); + t1ColumnMap.put(c4Operator, c4); + t2ColumnMap.put(c5Operator, c5); + t3ColumnMap.put(c6Operator, c6); + + OptExpression tScan = new OptExpression(new LogicalOlapScanOperator(t, tColumnMap, Maps.newHashMap(), null, -1, null)); + OptExpression t1Scan = new OptExpression(new LogicalOlapScanOperator(t1, t1ColumnMap, Maps.newHashMap(), null, -1, null)); + OptExpression t2Scan = new OptExpression(new LogicalOlapScanOperator(t2, t2ColumnMap, Maps.newHashMap(), null, -1, null)); + OptExpression t3Scan = new OptExpression(new LogicalOlapScanOperator(t3, t3ColumnMap, Maps.newHashMap(), null, -1, null)); + + Map projection1ColumnRefMap = Maps.newHashMap(); + Map projection2ColumnRefMap = Maps.newHashMap(); + Map projection3ColumnRefMap = Maps.newHashMap(); + projection1ColumnRefMap.putAll(tScan.getRowOutputInfo().getColumnRefMap()); + projection1ColumnRefMap.putAll(t1Scan.getRowOutputInfo().getColumnRefMap()); + projection2ColumnRefMap.putAll(projection1ColumnRefMap); + projection2ColumnRefMap.putAll(t2Scan.getRowOutputInfo().getColumnRefMap()); + projection3ColumnRefMap.putAll(t3Scan.getRowOutputInfo().getColumnRefMap()); + + ScalarOperator onKeys = new CompoundPredicateOperator(CompoundPredicateOperator.CompoundType.AND, + new BinaryPredicateOperator(BinaryType.EQ, c1Operator, c4Operator), + new CompoundPredicateOperator(CompoundPredicateOperator.CompoundType.OR, + new BinaryPredicateOperator(BinaryType.EQ, c2Operator, c5Operator), + new BinaryPredicateOperator(BinaryType.EQ, c3Operator, c6Operator))); + OptExpression crossJoin1 = OptExpression.create(new LogicalJoinOperator(JoinOperator.CROSS_JOIN, null), tScan, t1Scan); + OptExpression projection1 = + OptExpression.create(LogicalProjectOperator.builder().setColumnRefMap(projection1ColumnRefMap).build(), + crossJoin1); + OptExpression crossJoin2 = + OptExpression.create(new LogicalJoinOperator(JoinOperator.CROSS_JOIN, null), projection1, t2Scan); + OptExpression projection2 = + OptExpression.create(LogicalProjectOperator.builder().setColumnRefMap(projection2ColumnRefMap).build(), + crossJoin2); + OptExpression projection3 = + OptExpression.create(LogicalProjectOperator.builder().setColumnRefMap(projection3ColumnRefMap).build(), t3Scan); + OptExpression innerJoin = + OptExpression.create(new LogicalJoinOperator(JoinOperator.INNER_JOIN, onKeys), projection2, projection3); + + OptimizerContext context = OptimizerFactory.mockContext(UtFrameUtils.createDefaultCtx(), columnRefFactory); + + System.out.println("Before: " + innerJoin.debugString()); + assertTrue(rule.check(innerJoin, context)); + List transform = rule.transform(innerJoin, context); + + OptExpression newJoinTree = transform.get(0); + System.out.println("After: " + newJoinTree.debugString()); + + // JOIN + { + assertEquals(0, ((LogicalOlapScanOperator) newJoinTree.inputAt(1).getOp()).getTable().getId()); + assertEquals(2, + ((LogicalOlapScanOperator) newJoinTree.inputAt(0).inputAt(0).inputAt(0).inputAt(0).inputAt(0).inputAt(0) + .getOp()).getTable().getId()); + } + // Projection + { + Map map0 = new HashMap<>(); + map0.putAll(t1Scan.getRowOutputInfo().getColumnRefMap()); + map0.putAll(t2Scan.getRowOutputInfo().getColumnRefMap()); + map0.putAll(t3Scan.getRowOutputInfo().getColumnRefMap()); + + assertEquals(map0, ((LogicalProjectOperator) newJoinTree.inputAt(0).getOp()).getColumnRefMap()); + + Map map1 = new HashMap<>(); + map1.putAll(t2Scan.getRowOutputInfo().getColumnRefMap()); + map1.putAll(t3Scan.getRowOutputInfo().getColumnRefMap()); + + assertEquals(map1, ((LogicalProjectOperator) newJoinTree.inputAt(0).inputAt(0).inputAt(0).getOp()).getColumnRefMap()); + + Map map2 = new HashMap<>(t3Scan.getRowOutputInfo().getColumnRefMap()); + t2Scan.getRowOutputInfo().getColumnRefMap().forEach((k, v) -> map2.put(k, ConstantOperator.createNull(v.getType()))); + + assertEquals(map2, ((LogicalProjectOperator) newJoinTree.inputAt(0).inputAt(0).inputAt(0).inputAt(0).inputAt(1) + .getOp()).getColumnRefMap()); + } + } +} diff --git a/test/sql/test_join/R/test_driving_table_selection b/test/sql/test_join/R/test_driving_table_selection new file mode 100644 index 0000000000000..824afd0ec5aa5 --- /dev/null +++ b/test/sql/test_join/R/test_driving_table_selection @@ -0,0 +1,24 @@ +-- name: test_driving_table_selection +create table T (a int, b int, c int) properties("replication_num"="1"); +-- result: +-- !result +insert into T values(1,1,1); +-- result: +-- !result +create table T0 properties("replication_num"="1") as select * from T; +-- result: +-- !result +create table T1 properties("replication_num"="1") as select * from T; +-- result: +-- !result +create table T2 properties("replication_num"="1") as select * from T; +-- result: +-- !result +select * from T where T.a = (select a from T0) or T.b = (select b from T1) or T.c = (select c from T2); +-- result: +1 1 1 +-- !result +select * from T where T.a = (select a from T0) and T.b = (select b from T1) or T.c = (select c from T2); +-- result: +1 1 1 +-- !result \ No newline at end of file diff --git a/test/sql/test_join/T/test_driving_table_selection b/test/sql/test_join/T/test_driving_table_selection new file mode 100644 index 0000000000000..fc210a0006b35 --- /dev/null +++ b/test/sql/test_join/T/test_driving_table_selection @@ -0,0 +1,12 @@ +-- name: test_driving_table_selection +create table T (a int, b int, c int) properties("replication_num"="1"); + +insert into T values(1,1,1); + +create table T0 properties("replication_num"="1") as select * from T; +create table T1 properties("replication_num"="1") as select * from T; +create table T2 properties("replication_num"="1") as select * from T; + +select * from T where T.a = (select a from T0) or T.b = (select b from T1) or T.c = (select c from T2); + +select * from T where T.a = (select a from T0) and T.b = (select b from T1) or T.c = (select c from T2); \ No newline at end of file