[BUG] Fix the PPL Lookup command behavior when inputField is missing and REPLACE with existing fields #1035

Merged
merged 4 commits on Feb 10, 2025
Changes from 1 commit
2 changes: 1 addition & 1 deletion docs/ppl-lang/ppl-lookup-command.md
@@ -29,7 +29,7 @@ LOOKUP <lookupIndex> (<lookupMappingField> [AS <sourceMappingField>])...
**inputField**
- Optional
- Default: All fields of \<lookupIndex\> where matched values are applied to result output if no field is specified.
- Description: A field in \<lookupIndex\> where matched values are applied to result output. You can specify multiple \<inputField\> with comma-delimited. If you don't specify any \<inputField\>, all fields of \<lookupIndex\> where matched values are applied to result output.
- Description: A field in \<lookupIndex\> whose matched values are applied to the result output. You can specify multiple \<inputField\> values as a comma-delimited list. If you don't specify any \<inputField\>, all fields except \<lookupMappingField\> from \<lookupIndex\> are applied to the result output for matched values.

**outputField**
- Optional
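
As a hedged illustration of the inputField default described above (index and field names are made up for this note, not taken from the PR):

    source = workers | LOOKUP work_info uid AS id

With no <inputField> specified, every field of work_info except the mapping field uid is applied to the result output for matched rows.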

Large diffs are not rendered by default.

@@ -368,6 +368,7 @@ public Expression visitWindowFunction(WindowFunction node, CatalystPlanContext c
@Override
public Expression visitInSubquery(InSubquery node, CatalystPlanContext outerContext) {
CatalystPlanContext innerContext = new CatalystPlanContext();
innerContext.withSparkSession(outerContext.getSparkSession());
visitExpressionList(node.getChild(), innerContext);
Seq<Expression> values = innerContext.retainAllNamedParseExpressions(p -> p);
UnresolvedPlan outerPlan = node.getQuery();
@@ -387,6 +388,7 @@ public Expression visitInSubquery(InSubquery node, CatalystPlanContext outerCont
@Override
public Expression visitScalarSubquery(ScalarSubquery node, CatalystPlanContext context) {
CatalystPlanContext innerContext = new CatalystPlanContext();
innerContext.withSparkSession(context.getSparkSession());
UnresolvedPlan outerPlan = node.getQuery();
LogicalPlan subSearch = outerPlan.accept(planVisitor, innerContext);
Expression scalarSubQuery = ScalarSubquery$.MODULE$.apply(
@@ -402,6 +404,7 @@ public Expression visitScalarSubquery(ScalarSubquery node, CatalystPlanContext c
@Override
public Expression visitExistsSubquery(ExistsSubquery node, CatalystPlanContext context) {
CatalystPlanContext innerContext = new CatalystPlanContext();
innerContext.withSparkSession(context.getSparkSession());
UnresolvedPlan outerPlan = node.getQuery();
LogicalPlan subSearch = outerPlan.accept(planVisitor, innerContext);
Expression existsSubQuery = Exists$.MODULE$.apply(
@@ -6,6 +6,7 @@
package org.opensearch.sql.ppl;

import lombok.Getter;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation;
import org.apache.spark.sql.catalyst.expressions.AttributeReference;
import org.apache.spark.sql.catalyst.expressions.Expression;
@@ -38,6 +39,8 @@
* The context used for Catalyst logical plan.
*/
public class CatalystPlanContext {

@Getter private SparkSession sparkSession;
/**
* Catalyst relations list
**/
@@ -283,4 +286,8 @@ public Expression resolveJoinCondition(
isResolvingJoinCondition = false;
return result;
}

public void withSparkSession(SparkSession sparkSession) {
this.sparkSession = sparkSession;
}
}
@@ -193,41 +193,40 @@ public LogicalPlan visitFilter(Filter node, CatalystPlanContext context) {
public LogicalPlan visitLookup(Lookup node, CatalystPlanContext context) {
visitFirstChild(node, context);
return context.apply( searchSide -> {
LogicalPlan target;
LogicalPlan lookupTable = node.getLookupRelation().accept(this, context);
Expression lookupCondition = buildLookupMappingCondition(node, expressionAnalyzer, context);
// If no output field is specified, all fields from lookup table are applied to the output.
// If no output field is specified, all fields except mapping fields from lookup table are applied to the output.
if (node.allFieldsShouldAppliedToOutputList()) {
context.retainAllNamedParseExpressions(p -> p);
context.retainAllPlans(p -> p);
return join(searchSide, lookupTable, Join.JoinType.LEFT, Optional.of(lookupCondition), new Join.JoinHint());
}

// If the output fields are specified, build a project list for lookup table.
// The mapping fields of lookup table should be added in this project list, otherwise join will fail.
// So the mapping fields of lookup table should be dropped after join.
List<NamedExpression> lookupTableProjectList = buildLookupRelationProjectList(node, expressionAnalyzer, context);
LogicalPlan lookupTableWithProject = Project$.MODULE$.apply(seq(lookupTableProjectList), lookupTable);

LogicalPlan join = join(searchSide, lookupTableWithProject, Join.JoinType.LEFT, Optional.of(lookupCondition), new Join.JoinHint());
target = join(searchSide, lookupTable, Join.JoinType.LEFT, Optional.of(lookupCondition), new Join.JoinHint());
} else {
// If the output fields are specified, build a project list for lookup table.
// The mapping fields of lookup table should be added in this project list, otherwise join will fail.
// So the mapping fields of lookup table should be dropped after join.
List<NamedExpression> lookupTableProjectList = buildLookupRelationProjectList(node, expressionAnalyzer, context);
LogicalPlan lookupTableWithProject = Project$.MODULE$.apply(seq(lookupTableProjectList), lookupTable);

// Add all outputFields by __auto_generated_subquery_name_s.*
List<NamedExpression> outputFieldsWithNewAdded = new ArrayList<>();
outputFieldsWithNewAdded.add(UnresolvedStar$.MODULE$.apply(Option.apply(seq(node.getSourceSubqueryAliasName()))));
LogicalPlan join = join(searchSide, lookupTableWithProject, Join.JoinType.LEFT, Optional.of(lookupCondition), new Join.JoinHint());

// Add new columns based on different strategies:
// Append: coalesce($outputField, $"inputField").as(outputFieldName)
// Replace: $outputField.as(outputFieldName)
outputFieldsWithNewAdded.addAll(buildOutputProjectList(node, node.getOutputStrategy(), expressionAnalyzer, context));
// Add all outputFields by __auto_generated_subquery_name_s.*
List<NamedExpression> outputFieldsWithNewAdded = new ArrayList<>();
outputFieldsWithNewAdded.add(UnresolvedStar$.MODULE$.apply(Option.apply(seq(node.getSourceSubqueryAliasName()))));

org.apache.spark.sql.catalyst.plans.logical.Project outputWithNewAdded = Project$.MODULE$.apply(seq(outputFieldsWithNewAdded), join);
// Add new columns based on different strategies:
// Append: coalesce($outputField, $"inputField").as(outputFieldName)
// Replace: $outputField.as(outputFieldName)
outputFieldsWithNewAdded.addAll(buildOutputProjectList(node, node.getOutputStrategy(), expressionAnalyzer, context, searchSide));

target = Project$.MODULE$.apply(seq(outputFieldsWithNewAdded), join);
}
// Drop the mapping fields of lookup table in result:
// For example, in command "LOOKUP lookTbl Field1 AS Field2, Field3",
// the Field1 and Field3 are projection fields and join keys which will be dropped in result.
List<Field> mappingFieldsOfLookup = node.getLookupMappingMap().entrySet().stream()
.map(kv -> kv.getKey().getField() == kv.getValue().getField() ? buildFieldWithLookupSubqueryAlias(node, kv.getKey()) : kv.getKey())
.collect(Collectors.toList());
// List<Field> mappingFieldsOfLookup = new ArrayList<>(node.getLookupMappingMap().keySet());
List<Expression> dropListOfLookupMappingFields =
buildProjectListFromFields(mappingFieldsOfLookup, expressionAnalyzer, context).stream()
.map(Expression.class::cast).collect(Collectors.toList());
@@ -237,7 +236,7 @@ public LogicalPlan visitLookup(Lookup node, CatalystPlanContext context) {
List<Expression> toDrop = new ArrayList<>(dropListOfLookupMappingFields);
toDrop.addAll(dropListOfSourceFields);

LogicalPlan outputWithDropped = DataFrameDropColumns$.MODULE$.apply(seq(toDrop), outputWithNewAdded);
LogicalPlan outputWithDropped = DataFrameDropColumns$.MODULE$.apply(seq(toDrop), target);

context.retainAllNamedParseExpressions(p -> p);
context.retainAllPlans(p -> p);
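
For orientation, a LOOKUP with no <inputField> takes the left-join-only branch above, while a hedged sketch like the following (names are illustrative) takes the project-then-join branch:

    source = workers | LOOKUP work_info uid AS id APPEND department AS dept

Here the lookup-side project list carries department plus the mapping field uid, the left join is applied, the appended column is a coalesce of the existing source value and the lookup value, and the mapping fields are dropped afterwards.
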
@@ -11,13 +11,13 @@
import org.apache.spark.sql.catalyst.expressions.EqualTo$;
import org.apache.spark.sql.catalyst.expressions.Expression;
import org.apache.spark.sql.catalyst.expressions.NamedExpression;
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan;
import org.opensearch.sql.ast.expression.Alias;
import org.opensearch.sql.ast.expression.Field;
import org.opensearch.sql.ast.expression.QualifiedName;
import org.opensearch.sql.ast.tree.Lookup;
import org.opensearch.sql.ppl.CatalystExpressionVisitor;
import org.opensearch.sql.ppl.CatalystPlanContext;
import org.opensearch.sql.ppl.CatalystQueryPlanVisitor;
import scala.Option;

import java.util.ArrayList;
@@ -83,15 +83,23 @@ static List<NamedExpression> buildOutputProjectList(
Lookup node,
Lookup.OutputStrategy strategy,
CatalystExpressionVisitor expressionAnalyzer,
CatalystPlanContext context) {
CatalystPlanContext context,
LogicalPlan searchSide) {
List<NamedExpression> outputProjectList = new ArrayList<>();
for (Map.Entry<Alias, Field> entry : node.getOutputCandidateMap().entrySet()) {
Alias inputFieldWithAlias = entry.getKey();
Field inputField = (Field) inputFieldWithAlias.getDelegated();
Field outputField = entry.getValue();
Expression inputCol = expressionAnalyzer.visitField(inputField, context);
Expression outputCol = expressionAnalyzer.visitField(outputField, context);

// Always resolve the inputCol expression with the lookup alias: __auto_generated_subquery_name_l.<fieldName>
// If the outputField exists in the source table, resolve the outputCol expression with the source alias: __auto_generated_subquery_name_s.<fieldName>
// If not, resolve the outputCol expression without an alias: <fieldName>, to avoid an unresolved-attribute failure.
Expression inputCol = expressionAnalyzer.visitField(buildFieldWithLookupSubqueryAlias(node, inputField), context);
Expression outputCol;
if (RelationUtils.columnExistsInCatalogTable(context.getSparkSession(), outputField, searchSide)) {
outputCol = expressionAnalyzer.visitField(buildFieldWithSourceSubqueryAlias(node, outputField), context);
} else {
outputCol = expressionAnalyzer.visitField(outputField, context);
}
Expression child;
if (strategy == Lookup.OutputStrategy.APPEND) {
child = Coalesce$.MODULE$.apply(seq(outputCol, inputCol));
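
To make the new alias resolution concrete, consider a hedged REPLACE sketch (names are illustrative):

    source = workers | LOOKUP work_info uid AS id REPLACE department

The input column always resolves as __auto_generated_subquery_name_l.department. If the source table workers already has a department column, the output column resolves as __auto_generated_subquery_name_s.department; otherwise it resolves as the bare department, which avoids the unresolved-attribute failure this PR addresses.
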
@@ -5,16 +5,26 @@

package org.opensearch.sql.ppl.utils;

import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.catalyst.TableIdentifier;
import org.apache.spark.sql.catalyst.analysis.NoSuchDatabaseException;
import org.apache.spark.sql.catalyst.analysis.NoSuchTableException;
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation;
import org.apache.spark.sql.catalyst.catalog.CatalogTable;
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan;
import org.opensearch.flint.spark.ppl.PPLSparkUtils;
import org.opensearch.sql.ast.expression.Field;
import org.opensearch.sql.ast.expression.QualifiedName;
import scala.Option$;

import java.util.Arrays;
import java.util.List;
import java.util.Optional;
import java.util.logging.Logger;

public interface RelationUtils {
Logger LOG = Logger.getLogger(RelationUtils.class.getName());

/**
* attempt resolving if the field is relating to the given relation
* if name doesnt contain table prefix - add the current relation prefix to the fields name - returns true
@@ -65,4 +75,23 @@ static TableIdentifier getTableIdentifier(QualifiedName qualifiedName) {
}
return identifier;
}

static boolean columnExistsInCatalogTable(SparkSession spark, Field field, LogicalPlan plan) {
UnresolvedRelation relation = PPLSparkUtils.findLogicalRelations(plan).head();
QualifiedName tableQualifiedName = QualifiedName.of(Arrays.asList(relation.tableName().split("\\.")));
TableIdentifier sourceTableIdentifier = getTableIdentifier(tableQualifiedName);
boolean sourceTableExists = spark.sessionState().catalog().tableExists(sourceTableIdentifier);
if (sourceTableExists) {
try {
CatalogTable table = spark.sessionState().catalog().getTableMetadata(getTableIdentifier(tableQualifiedName));
@LantaoJin (Member, Author) commented on Feb 5, 2025:
This is just a metadata call via the external catalog obtained from the SparkSession passed in. No additional Spark job will be submitted.

@YANG-DB (Member) commented on Feb 5, 2025:
Would it make sense to make this metadata discovery more generic so that all relations have this info?

@LantaoJin (Member, Author) replied:
Yes, if we see more requirements later. For now I still prefer to leave the metadata binding to Spark internals unless it becomes necessary.

return Arrays.stream(table.dataSchema().fields()).anyMatch(f -> f.name().equalsIgnoreCase(field.getField().toString()));
} catch (NoSuchDatabaseException | NoSuchTableException e) {
LOG.warning("Source table or database " + sourceTableIdentifier + " not found");
return false;
}
} else {
LOG.warning("Source table " + sourceTableIdentifier + " not found");
return false;
}
}
}
@@ -16,7 +16,7 @@ class FlintPPLSparkExtensions extends (SparkSessionExtensions => Unit) {

override def apply(extensions: SparkSessionExtensions): Unit = {
extensions.injectParser { (spark, parser) =>
new FlintSparkPPLParser(parser)
new FlintSparkPPLParser(parser, spark)
@LantaoJin (Member, Author) commented:
The SparkSession in use is passed in from here.

}
}
}
@@ -31,6 +31,7 @@ import org.opensearch.flint.spark.ppl.PlaneUtils.plan
import org.opensearch.sql.common.antlr.SyntaxCheckException
import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor}

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.parser._
@@ -44,7 +45,8 @@ import org.apache.spark.sql.types.{DataType, StructType}
* @param sparkParser
* Spark SQL parser
*/
class FlintSparkPPLParser(sparkParser: ParserInterface) extends ParserInterface {
class FlintSparkPPLParser(sparkParser: ParserInterface, val spark: SparkSession)
extends ParserInterface {

/** OpenSearch (PPL) AST builder. */
private val planTransformer = new CatalystQueryPlanVisitor()
@@ -55,6 +57,7 @@ class FlintSparkPPLParser(sparkParser: ParserInterface) extends ParserInterface
try {
// if successful build ppl logical plan and translate to catalyst logical plan
val context = new CatalystPlanContext
context.withSparkSession(spark)
planTransformer.visit(plan(pplParser, sqlText), context)
context.getPlan
} catch {
@@ -0,0 +1,22 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.flint.spark.ppl

import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

object PPLSparkUtils {

def findLogicalRelations(plan: LogicalPlan): Seq[UnresolvedRelation] = {
plan
.transformDown { case relation: UnresolvedRelation =>
relation
}
.collect { case relation: UnresolvedRelation =>
relation
}
}
}