diff --git a/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/storage/FlintQueryCompiler.scala b/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/storage/FlintQueryCompiler.scala
index 38418ba6a..2be0523ac 100644
--- a/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/storage/FlintQueryCompiler.scala
+++ b/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/storage/FlintQueryCompiler.scala
@@ -13,7 +13,6 @@ import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.parseColumnPath
 import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, LiteralValue}
 import org.apache.spark.sql.connector.expressions.filter.{And, Predicate}
 import org.apache.spark.sql.flint.datatype.FlintDataType.STRICT_DATE_OPTIONAL_TIME_FORMATTER_WITH_NANOS
-import org.apache.spark.sql.flint.datatype.FlintMetadataExtensions
 import org.apache.spark.sql.flint.datatype.FlintMetadataExtensions.MetadataExtension
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
@@ -34,6 +33,14 @@ case class FlintQueryCompiler(schema: StructType) {
     compile(predicates.reduce(new And(_, _)))
   }
 
+  /**
+   * Compile an expression to a query string. Returns an empty string if any part of the
+   * expression is unsupported.
+   */
+  def compile(expr: Expression, quoteString: Boolean = true): String = {
+    compileOpt(expr, quoteString).getOrElse("")
+  }
+
   /**
    * Compile Expression to Flint query string.
    *
@@ -42,13 +49,13 @@ case class FlintQueryCompiler(schema: StructType) {
    * @return
-   *   empty if does not support.
+   *   None if the expression is not supported.
    */
-  def compile(expr: Expression, quoteString: Boolean = true): String = {
+  def compileOpt(expr: Expression, quoteString: Boolean = true): Option[String] = {
     expr match {
       case LiteralValue(value, dataType) =>
-        quote(extract, quoteString)(value, dataType)
+        Some(quote(extract, quoteString)(value, dataType))
       case p: Predicate => visitPredicate(p)
-      case f: FieldReference => f.toString()
-      case _ => ""
+      case f: FieldReference => Some(f.toString())
+      case _ => None
     }
   }
 
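The core of the change is visible in these first hunks: the old compile returned "" as an in-band failure marker, which callers spliced into larger JSON templates; compileOpt makes failure explicit with Option, so an unsupported sub-expression anywhere short-circuits the whole predicate, while the thin compile wrapper preserves the old empty-string contract for callers. A minimal sketch of the difference (compileChild is a hypothetical stand-in for compileOpt on a child predicate, not part of the PR):

    object OptionShortCircuitSketch extends App {
      // Hypothetical stand-in for compileOpt on a child predicate:
      // Some(compiled JSON) when supported, None otherwise.
      def compileChild(supported: Boolean): Option[String] =
        if (supported) Some("""{"term":{"aInt":{"value":1}}}""") else None

      // New contract: an unsupported child collapses the whole NOT to None,
      // so nothing is pushed down for this predicate.
      val notQuery = compileChild(supported = false)
        .map(child => s"""{"bool":{"must_not":$child}}""")
      assert(notQuery.isEmpty)

      // Old contract: "" was spliced straight into the template, producing
      // {"bool":{"must_not":}} -- invalid JSON rejected by OpenSearch.
      val child = ""
      println(s"""{"bool":{"must_not":$child}}""")
    }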
@@ -77,56 +84,101 @@ case class FlintQueryCompiler(schema: StructType) {
    * 1. currently, we map spark contains to OpenSearch match query. Can we leverage more full
    * text queries for text field. 2. configuration of expensive query.
    */
-  def visitPredicate(p: Predicate): String = {
-    val name = p.name()
-    name match {
-      case "IS_NULL" =>
-        s"""{"bool":{"must_not":{"exists":{"field":"${compile(p.children()(0))}"}}}}"""
-      case "IS_NOT_NULL" =>
-        s"""{"exists":{"field":"${compile(p.children()(0))}"}}"""
-      case "AND" =>
-        s"""{"bool":{"filter":[${compile(p.children()(0))},${compile(p.children()(1))}]}}"""
-      case "OR" =>
-        s"""{"bool":{"should":[{"bool":{"filter":${compile(
-            p.children()(0))}}},{"bool":{"filter":${compile(p.children()(1))}}}]}}"""
-      case "NOT" =>
-        s"""{"bool":{"must_not":${compile(p.children()(0))}}}"""
-      case "=" =>
-        val fieldName = compile(p.children()(0))
-        if (isTextField(fieldName)) {
-          getKeywordSubfield(fieldName) match {
-            case Some(keywordField) =>
-              s"""{"term":{"$keywordField":{"value":${compile(p.children()(1))}}}}"""
-            case None => ""
+  def visitPredicate(p: Predicate): Option[String] = p.name() match {
+    case "IS_NULL" =>
+      compileOpt(p.children()(0)).map { field =>
+        s"""{"bool":{"must_not":{"exists":{"field":"$field"}}}}"""
+      }
+    case "IS_NOT_NULL" =>
+      compileOpt(p.children()(0)).map { field =>
+        s"""{"exists":{"field":"$field"}}"""
+      }
+    case "AND" =>
+      for {
+        left <- compileOpt(p.children()(0))
+        right <- compileOpt(p.children()(1))
+      } yield s"""{"bool":{"filter":[$left,$right]}}"""
+    case "OR" =>
+      for {
+        left <- compileOpt(p.children()(0))
+        right <- compileOpt(p.children()(1))
+      } yield s"""{"bool":{"should":[{"bool":{"filter":$left}},{"bool":{"filter":$right}}]}}"""
+    case "NOT" =>
+      compileOpt(p.children()(0)).map { child =>
+        s"""{"bool":{"must_not":$child}}"""
+      }
+    case "=" =>
+      for {
+        field <- compileOpt(p.children()(0))
+        value <- compileOpt(p.children()(1))
+        result <-
+          if (isTextField(field)) {
+            getKeywordSubfield(field) match {
+              case Some(keywordField) =>
+                Some(s"""{"term":{"$keywordField":{"value":$value}}}""")
+              case None => None // Text field without a keyword subfield cannot be pushed down
+            }
+          } else {
+            Some(s"""{"term":{"$field":{"value":$value}}}""")
           }
+      } yield result
+    case ">" =>
+      for {
+        field <- compileOpt(p.children()(0))
+        value <- compileOpt(p.children()(1))
+      } yield s"""{"range":{"$field":{"gt":$value}}}"""
+    case ">=" =>
+      for {
+        field <- compileOpt(p.children()(0))
+        value <- compileOpt(p.children()(1))
+      } yield s"""{"range":{"$field":{"gte":$value}}}"""
+    case "<" =>
+      for {
+        field <- compileOpt(p.children()(0))
+        value <- compileOpt(p.children()(1))
+      } yield s"""{"range":{"$field":{"lt":$value}}}"""
+    case "<=" =>
+      for {
+        field <- compileOpt(p.children()(0))
+        value <- compileOpt(p.children()(1))
+      } yield s"""{"range":{"$field":{"lte":$value}}}"""
+    case "IN" =>
+      for {
+        field <- compileOpt(p.children()(0))
+        values = p.children().tail.map(expr => compileOpt(expr))
+        // Push down only if every value compiles; a partial list would change semantics
+        if values.nonEmpty && values.forall(_.isDefined)
+      } yield {
+        val valuesJson = values.flatten.mkString("[", ",", "]")
+        s"""{"terms":{"$field":$valuesJson}}"""
+      }
+    case "STARTS_WITH" =>
+      for {
+        field <- compileOpt(p.children()(0))
+        value <- compileOpt(p.children()(1))
+      } yield s"""{"prefix":{"$field":{"value":$value}}}"""
+    case "CONTAINS" =>
+      for {
+        field <- compileOpt(p.children()(0))
+        quoteValue <- compileOpt(p.children()(1))
+        unQuoteValue <- compileOpt(p.children()(1), false)
+      } yield {
+        if (isTextField(field)) {
+          s"""{"match":{"$field":{"query":$quoteValue}}}"""
         } else {
-          s"""{"term":{"$fieldName":{"value":${compile(p.children()(1))}}}}"""
+          s"""{"wildcard":{"$field":{"value":"*$unQuoteValue*"}}}"""
        }
-      case ">" =>
-        s"""{"range":{"${compile(p.children()(0))}":{"gt":${compile(p.children()(1))}}}}"""
-      case ">=" =>
-        s"""{"range":{"${compile(p.children()(0))}":{"gte":${compile(p.children()(1))}}}}"""
-      case "<" =>
-        s"""{"range":{"${compile(p.children()(0))}":{"lt":${compile(p.children()(1))}}}}"""
-      case "<=" =>
-        s"""{"range":{"${compile(p.children()(0))}":{"lte":${compile(p.children()(1))}}}}"""
-      case "IN" =>
-        val values = p.children().tail.map(expr => compile(expr)).mkString("[", ",", "]")
-        s"""{"terms":{"${compile(p.children()(0))}":$values}}"""
-      case "STARTS_WITH" =>
-        s"""{"prefix":{"${compile(p.children()(0))}":{"value":${compile(p.children()(1))}}}}"""
-      case "CONTAINS" =>
-        val fieldName = compile(p.children()(0))
-        if (isTextField(fieldName)) {
-          s"""{"match":{"$fieldName":{"query":${compile(p.children()(1))}}}}"""
-        } else {
-          s"""{"wildcard":{"$fieldName":{"value":"*${compile(p.children()(1), false)}*"}}}"""
-        }
-      case "ENDS_WITH" =>
-        s"""{"wildcard":{"${compile(p.children()(0))}":{"value":"*${compile(
-            p.children()(1),
-            false)}"}}}"""
-      case "BLOOM_FILTER_MIGHT_CONTAIN" =>
+      }
+    case "ENDS_WITH" =>
+      for {
+        field <- compileOpt(p.children()(0))
+        value <- compileOpt(p.children()(1), false)
+      } yield s"""{"wildcard":{"$field":{"value":"*$value"}}}"""
+    case "BLOOM_FILTER_MIGHT_CONTAIN" =>
+      for {
+        field <- compileOpt(p.children()(0))
+        value <- compileOpt(p.children()(1))
+      } yield {
         val code = Source.fromResource("bloom_filter_query.script").getLines().mkString(" ")
         s"""
           |{
@@ -137,8 +189,8 @@ case class FlintQueryCompiler(schema: StructType) {
           |          "lang": "painless",
           |          "source": "$code",
           |          "params": {
-          |            "fieldName": "${compile(p.children()(0))}",
-          |            "value": ${compile(p.children()(1))}
+          |            "fieldName": "$field",
+          |            "value": $value
           |          }
           |        }
           |      }
@@ -146,8 +198,8 @@ case class FlintQueryCompiler(schema: StructType) {
           |  }
           |}
           |""".stripMargin
-      case _ => ""
-    }
+      }
+    case _ => None
   }
 
   /**
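Every rewritten case in visitPredicate follows the same shape: compile each child with compileOpt, then build the query DSL only when all parts succeed, typically via a for-comprehension over Option. A standalone sketch of that shape (rangeQuery is illustrative only, not the PR's API):

    object PredicatePatternSketch extends App {
      // Compile both children; emit the DSL only if both succeeded.
      def rangeQuery(field: Option[String], op: String, value: Option[String]): Option[String] =
        for {
          f <- field
          v <- value
        } yield s"""{"range":{"$f":{"$op":$v}}}"""

      println(rangeQuery(Some("age"), "gt", Some("30"))) // Some({"range":{"age":{"gt":30}}})
      println(rangeQuery(Some("age"), "gt", None))       // None: predicate is not pushed down
    }
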
s"""{"range":{"${compile(p.children()(0))}":{"gt":${compile(p.children()(1))}}}}""" - case ">=" => - s"""{"range":{"${compile(p.children()(0))}":{"gte":${compile(p.children()(1))}}}}""" - case "<" => - s"""{"range":{"${compile(p.children()(0))}":{"lt":${compile(p.children()(1))}}}}""" - case "<=" => - s"""{"range":{"${compile(p.children()(0))}":{"lte":${compile(p.children()(1))}}}}""" - case "IN" => - val values = p.children().tail.map(expr => compile(expr)).mkString("[", ",", "]") - s"""{"terms":{"${compile(p.children()(0))}":$values}}""" - case "STARTS_WITH" => - s"""{"prefix":{"${compile(p.children()(0))}":{"value":${compile(p.children()(1))}}}}""" - case "CONTAINS" => - val fieldName = compile(p.children()(0)) - if (isTextField(fieldName)) { - s"""{"match":{"$fieldName":{"query":${compile(p.children()(1))}}}}""" - } else { - s"""{"wildcard":{"$fieldName":{"value":"*${compile(p.children()(1), false)}*"}}}""" - } - case "ENDS_WITH" => - s"""{"wildcard":{"${compile(p.children()(0))}":{"value":"*${compile( - p.children()(1), - false)}"}}}""" - case "BLOOM_FILTER_MIGHT_CONTAIN" => + } + case "ENDS_WITH" => + for { + field <- compileOpt(p.children()(0)) + value <- compileOpt(p.children()(1), false) + } yield s"""{"wildcard":{"$field":{"value":"*$value"}}}""" + case "BLOOM_FILTER_MIGHT_CONTAIN" => + for { + field <- compileOpt(p.children()(0)) + value <- compileOpt(p.children()(1)) + } yield { val code = Source.fromResource("bloom_filter_query.script").getLines().mkString(" ") s""" |{ @@ -137,8 +189,8 @@ case class FlintQueryCompiler(schema: StructType) { | "lang": "painless", | "source": "$code", | "params": { - | "fieldName": "${compile(p.children()(0))}", - | "value": ${compile(p.children()(1))} + | "fieldName": "$field", + | "value": $value | } | } | } @@ -146,8 +198,8 @@ case class FlintQueryCompiler(schema: StructType) { | } |} |""".stripMargin - case _ => "" - } + } + case _ => None } /** diff --git a/flint-spark-integration/src/test/scala/org/apache/spark/sql/flint/storage/FlintQueryCompilerSuite.scala b/flint-spark-integration/src/test/scala/org/apache/spark/sql/flint/storage/FlintQueryCompilerSuite.scala index 403585864..aff0a8be7 100644 --- a/flint-spark-integration/src/test/scala/org/apache/spark/sql/flint/storage/FlintQueryCompilerSuite.scala +++ b/flint-spark-integration/src/test/scala/org/apache/spark/sql/flint/storage/FlintQueryCompilerSuite.scala @@ -196,6 +196,20 @@ class FlintQueryCompilerSuite extends FlintSuite { assertResult("")(query) } + test( + "Bug fix, https://github.com/opensearch-project/opensearch-spark/actions/runs/13338348402/job/37258321066?pr=1052 ") { + val schema = StructType( + Seq( + StructField( + "aText", + StringType, + nullable = true, + new MetadataBuilder().withTextField.build()))) + + val query = FlintQueryCompiler(schema).compile(Not(EqualTo("aText", "text")).toV2) + assertResult("")(query) + } + protected def schema(): StructType = { StructType( Seq(