Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CALCITE-6652] RelDecorrelator can't decorrelate query with limit 1 #4181

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
166 changes: 156 additions & 10 deletions core/src/main/java/org/apache/calcite/sql2rel/RelDecorrelator.java
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import org.apache.calcite.plan.hep.HepRelVertex;
import org.apache.calcite.rel.BiRel;
import org.apache.calcite.rel.RelCollation;
import org.apache.calcite.rel.RelFieldCollation;
import org.apache.calcite.rel.RelHomogeneousShuttle;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Aggregate;
Expand Down Expand Up @@ -73,6 +74,7 @@
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.rex.RexVisitorImpl;
import org.apache.calcite.runtime.PairList;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlExplainFormat;
import org.apache.calcite.sql.SqlExplainLevel;
import org.apache.calcite.sql.SqlFunction;
Expand All @@ -94,6 +96,7 @@
import org.apache.calcite.util.mapping.Mappings;
import org.apache.calcite.util.trace.CalciteTrace;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.ImmutableSortedMap;
Expand Down Expand Up @@ -523,6 +526,19 @@ protected RexNode removeCorrelationExpr(
return null;
}

if (isCorVarDefined && (rel.fetch != null || rel.offset != null)) {
if (rel.fetch != null
&& rel.offset == null
&& RexLiteral.intValue(rel.fetch) == 1) {
return decorrelateFetchOneSort(rel, frame);
}
// Can not decorrelate if the sort has per-correlate-key attributes like
// offset or fetch limit, because these attributes scope would change to
// global after decorrelation. They should take effect within the scope
// of the correlation key actually.
return null;
}

final RelNode newInput = frame.r;

Mappings.TargetMapping mapping =
Expand Down Expand Up @@ -767,16 +783,6 @@ private static void shiftMapping(Map<Integer, Integer> mapping, int startIndex,
public @Nullable Frame getInvoke(RelNode r, boolean isCorVarDefined, @Nullable RelNode parent) {
final Frame frame = dispatcher.invoke(r, isCorVarDefined);
currentRel = parent;
if (frame != null && isCorVarDefined && r instanceof Sort) {
final Sort sort = (Sort) r;
// Can not decorrelate if the sort has per-correlate-key attributes like
// offset or fetch limit, because these attributes scope would change to
// global after decorrelation. They should take effect within the scope
// of the correlation key actually.
if (sort.offset != null || sort.fetch != null) {
return null;
}
}
if (frame != null) {
map.put(r, frame);
}
Expand All @@ -795,6 +801,146 @@ private static void shiftMapping(Map<Integer, Integer> mapping, int startIndex,
return null;
}

protected @Nullable Frame decorrelateFetchOneSort(Sort sort, final Frame frame) {
Frame aggFrame = decorrelateSortAsAggregate(sort, frame);
if (aggFrame != null) {
return aggFrame;
}
//
// Rewrite logic:
//
// If sorted without offset and fetch = 1 (enforced by the caller), rewrite the sort to be
// Aggregate(group=(corVar.. , field..))
// project(first_value(field) over (partition by corVar order by (sort collation)))
// input
//
// 1. For the original sorted input, apply the FIRST_VALUE window function to produce
// the result of sorting with LIMIT 1, and the same as the decorrelate of aggregate,
// add correlated variables in partition list to maintain semantic consistency.
// 2. To ensure that there is at most one row of output for
// any combination of correlated variables, distinct for correlated variables.
// 3. Since we have partitioned by all correlated variables
// in the sorted output field window, so for any combination of correlated variables,
// all other field values are unique. So the following two are equivalent:
// - group by corVar1, covVar2, field1, field2
// - any_value(fields1), any_value(fields2) group by corVar1, covVar2
// Here we use the first.
final Map<Integer, Integer> mapOldToNewOutputs = new HashMap<>();
final NavigableMap<CorDef, Integer> corDefOutputs = new TreeMap<>();

final PairList<RexNode, String> corVarProjects = PairList.of();
List<RelDataTypeField> fieldList = frame.r.getRowType().getFieldList();
for (Map.Entry<CorDef, Integer> entry : frame.corDefOutputs.entrySet()) {
corDefOutputs.put(entry.getKey(),
sort.getRowType().getFieldCount() + corVarProjects.size());
RexInputRef.add2(corVarProjects, entry.getValue(), fieldList);
}

final List<RexNode> sortExprs =
new ArrayList<>(sort.getCollation().getFieldCollations().size());
for (RelFieldCollation collation : sort.getCollation().getFieldCollations()) {
Integer newIdx = requireNonNull(frame.oldToNewOutputs.get(collation.getFieldIndex()));
RexNode node = RexInputRef.of(newIdx, fieldList);
if (collation.direction == RelFieldCollation.Direction.DESCENDING) {
node = relBuilder.desc(node);
}
if (collation.nullDirection == RelFieldCollation.NullDirection.FIRST) {
node = relBuilder.nullsFirst(node);
} else if (collation.nullDirection == RelFieldCollation.NullDirection.LAST) {
node = relBuilder.nullsLast(node);
}
sortExprs.add(node);
}

final PairList<RexNode, String> newProjExprs = PairList.of();
for (RelDataTypeField field : sort.getRowType().getFieldList()) {
final int newIdx =
requireNonNull(frame.oldToNewOutputs.get(field.getIndex()));

RelBuilder.AggCall aggCall =
relBuilder.aggregateCall(SqlStdOperatorTable.FIRST_VALUE,
RexInputRef.of(newIdx, fieldList));

// Convert each field from the sorted output to a window function that partitions by
// correlated variables, orders by the collation, and return the first_value.
RexNode winCall = aggCall.over()
.orderBy(sortExprs)
.partitionBy(corVarProjects.leftList())
.toRex();
mapOldToNewOutputs.put(newProjExprs.size(), newProjExprs.size());
newProjExprs.add(winCall, field.getName());
}
newProjExprs.addAll(corVarProjects);
RelNode result = relBuilder.push(frame.r)
.project(newProjExprs.leftList(), newProjExprs.rightList())
.distinct().build();

return register(sort, result, mapOldToNewOutputs, corDefOutputs);
}

protected @Nullable Frame decorrelateSortAsAggregate(Sort sort, final Frame frame) {
final Map<Integer, Integer> mapOldToNewOutputs = new HashMap<>();
final NavigableMap<CorDef, Integer> corDefOutputs = new TreeMap<>();
if (sort.getCollation().getFieldCollations().size() == 1
&& sort.getRowType().getFieldCount() == 1
&& !frame.corDefOutputs.isEmpty()) {
//
// Rewrite logic:
//
// If sorted with no OFFSET and FETCH = 1, and only one collation field,
// rewrite the Sort as Aggregate using MIN/MAX function.
// Example:
// Sort(sort0=[$0], dir0=[ASC], fetch=[1])
// input
// Rewrite to:
// Aggregate(group=(corVar), agg=[min($0))
//
// Note: MIN/MAX is not strictly equivalent to LIMIT 1. When the input has 0 rows,
// MIN/MAX returns NULL, while LIMIT 1 returns 0 rows.
// However, in the decorrelate, we add correlated variables to the group list
// to ensure equivalence when Correlate is transformed to Join. When the group list
// is non-empty, MIN/MAX will also return 0 rows if the input has 0 rows.
// So in this case, the transformation is legal.
RelFieldCollation collation = Util.first(sort.getCollation().getFieldCollations());

if (collation.nullDirection != RelFieldCollation.NullDirection.LAST) {
return null;
}

SqlAggFunction aggFunction;
switch (collation.getDirection()) {
case ASCENDING:
case STRICTLY_ASCENDING:
aggFunction = SqlStdOperatorTable.MIN;
break;
case DESCENDING:
case STRICTLY_DESCENDING:
aggFunction = SqlStdOperatorTable.MAX;
break;
default:
return null;
}

final int newIdx = requireNonNull(frame.oldToNewOutputs.get(collation.getFieldIndex()));
RelBuilder.AggCall aggCall = relBuilder.push(frame.r)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am wondering: sorting has collation, but min/max do not.
For example, sorting can specify what happens to nulls.

Copy link
Contributor

@rubenada rubenada Feb 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was my exact same thought after a quick look at the PR... is the proposed conversion 100% equivalent in all cases, also when nulls are involved? What about when the original Sort is nulls-first and the relevant field has null values?

Copy link
Contributor Author

@suibianwanwank suibianwanwank Feb 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, as mentioned in JIRA, I missed this. This conversion is only valid if nulls-last.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we have some unit tests with nulls-first, to see the expected result being the conversion not being applied?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would think that MAX works for NULL LAST only, and MIN for NULL FIRST.
Am I wrong?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For MIN/MAX, null values are ignored in the databases I know (I didn't find a corresponding description in the sql standard). So I think both asc and desc should be null last, i.e.return null only if all values are null.

.aggregateCall(aggFunction, relBuilder.fields(ImmutableList.of(newIdx)));

// As with the aggregate decorrelate, add correlated variables to the group list.
final List<RexInputRef> groupKey = new ArrayList<>();
for (Map.Entry<CorDef, Integer> entry : frame.corDefOutputs.entrySet()) {
groupKey.add(RexInputRef.of(entry.getValue(), frame.r.getRowType()));
corDefOutputs.put(entry.getKey(), corDefOutputs.size());
}

RelNode aggregate = relBuilder.aggregate(relBuilder.groupKey(groupKey), aggCall).build();

// Add the mapping for the added aggregate fields.
mapOldToNewOutputs.put(0, groupKey.size());
return register(sort, aggregate, mapOldToNewOutputs, corDefOutputs);
}
return null;
}

public @Nullable Frame decorrelateRel(LogicalProject rel, boolean isCorVarDefined) {
return decorrelateRel((Project) rel, isCorVarDefined);
}
Expand Down
2 changes: 1 addition & 1 deletion core/src/main/java/org/apache/calcite/util/Util.java
Original file line number Diff line number Diff line change
Expand Up @@ -2086,7 +2086,7 @@ public static <T> Iterable<T> orEmpty(@Nullable Iterable<T> v0) {
*
* @throws java.lang.IndexOutOfBoundsException if the list is empty
*/
public <E> E first(List<E> list) {
public static <E> E first(List<E> list) {
return list.get(0);
}

Expand Down
105 changes: 105 additions & 0 deletions core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -8577,6 +8577,111 @@ private void checkSemiJoinRuleOnAntiJoin(RelOptRule rule) {
.check();
}

/** Test case for
* <a href="https://issues.apache.org/jira/browse/CALCITE-6652">[CALCITE-6652]
* RelDecorrelator can't decorrelate query with limit 1</a>.
*/
@Test void testDecorrelateProjectWithFetchOne() {
final String query = "SELECT name, "
+ "(SELECT sal FROM emp where dept.deptno = emp.deptno order by sal limit 1) "
+ "FROM dept";
sql(query).withRule(CoreRules.PROJECT_SUB_QUERY_TO_CORRELATE)
.withLateDecorrelate(true)
.check();
}

/** Test case for
* <a href="https://issues.apache.org/jira/browse/CALCITE-6652">[CALCITE-6652]
* RelDecorrelator can't decorrelate query with limit 1</a>.
*/
@Test void testDecorrelateProjectWithFetchOneDesc() {
final String query = "SELECT name, "
+ "(SELECT emp.sal FROM emp WHERE dept.deptno = emp.deptno "
+ "ORDER BY emp.sal desc nulls last LIMIT 1) "
+ "FROM dept";
sql(query).withRule(CoreRules.PROJECT_SUB_QUERY_TO_CORRELATE)
.withLateDecorrelate(true)
.check();
}

/** Test case for
* <a href="https://issues.apache.org/jira/browse/CALCITE-6652">[CALCITE-6652]
* RelDecorrelator can't decorrelate query with limit 1</a>.
*/
@Test void testDecorrelateFilterWithFetchOne() {
final String query = "SELECT name FROM dept "
+ "WHERE 10 > (SELECT emp.sal FROM emp where dept.deptno = emp.deptno "
+ "ORDER BY emp.sal limit 1)";
sql(query).withRule(CoreRules.FILTER_SUB_QUERY_TO_CORRELATE)
.withLateDecorrelate(true)
.check();
}

/** Test case for
* <a href="https://issues.apache.org/jira/browse/CALCITE-6652">[CALCITE-6652]
* RelDecorrelator can't decorrelate query with limit 1</a>.
*/
@Test void testDecorrelateFilterWithFetchOneDesc() {
final String query = "SELECT name FROM dept "
rubenada marked this conversation as resolved.
Show resolved Hide resolved
+ "WHERE 10 > (SELECT emp.sal FROM emp where dept.deptno = emp.deptno "
+ "ORDER BY emp.sal desc nulls last limit 1)";
sql(query).withRule(CoreRules.FILTER_SUB_QUERY_TO_CORRELATE)
.withLateDecorrelate(true)
.check();
}

/** Test case for
* <a href="https://issues.apache.org/jira/browse/CALCITE-6652">[CALCITE-6652]
* RelDecorrelator can't decorrelate query with limit 1</a>.
*/
@Test void testDecorrelateFilterWithFetchOneDesc1() {
final String query = "SELECT name FROM dept "
+ "WHERE 10 > (SELECT emp.sal FROM emp where dept.deptno = emp.deptno "
+ "ORDER BY emp.sal desc limit 1)";
sql(query).withRule(CoreRules.FILTER_SUB_QUERY_TO_CORRELATE)
.withLateDecorrelate(true)
.check();
}

/** Test case for
* <a href="https://issues.apache.org/jira/browse/CALCITE-6652">[CALCITE-6652]
* RelDecorrelator can't decorrelate query with limit 1</a>.
*/
@Test void testDecorrelateProjectWithMultiKeyAndFetchOne() {
final String query = "SELECT name, "
+ "(SELECT sal FROM emp where dept.deptno = emp.deptno "
+ "order by year(hiredate), emp.sal limit 1) FROM dept";
sql(query).withRule(CoreRules.PROJECT_SUB_QUERY_TO_CORRELATE)
.withLateDecorrelate(true)
.check();
}

/** Test case for
* <a href="https://issues.apache.org/jira/browse/CALCITE-6652">[CALCITE-6652]
* RelDecorrelator can't decorrelate query with limit 1</a>.
*/
@Test void testDecorrelateProjectWithMultiKeyAndFetchOne1() {
final String query = "SELECT name, "
+ "(SELECT sal FROM emp where dept.deptno = emp.deptno and dept.name = emp.ename "
+ "order by year(hiredate), emp.sal limit 1) FROM dept";
sql(query).withRule(CoreRules.PROJECT_SUB_QUERY_TO_CORRELATE)
.withLateDecorrelate(true)
.check();
}

/** Test case for
* <a href="https://issues.apache.org/jira/browse/CALCITE-6652">[CALCITE-6652]
* RelDecorrelator can't decorrelate query with limit 1</a>.
*/
@Test void testDecorrelateFilterWithMultiKeyAndFetchOne() {
final String query = "SELECT name FROM dept "
+ "WHERE 10 > (SELECT emp.sal FROM emp where dept.deptno = emp.deptno "
+ "order by year(hiredate), emp.sal desc limit 1)";
sql(query).withRule(CoreRules.FILTER_SUB_QUERY_TO_CORRELATE)
.withLateDecorrelate(true)
.check();
}

/** Test case for
* <a href="https://issues.apache.org/jira/browse/CALCITE-434">[CALCITE-434]
* Converting predicates on date dimension columns into date ranges</a>,
Expand Down
Loading