Add FlintJob to support queries in warmpool mode
Signed-off-by: Shri Saran Raj N <[email protected]>
Shri Saran Raj N committed Jan 26, 2025
1 parent 20ef890 commit e1db8de
Showing 9 changed files with 472 additions and 101 deletions.
MetricConstants.java
@@ -95,6 +95,11 @@ public final class MetricConstants {
*/
public static final String RESULT_METADATA_WRITE_METRIC_PREFIX = "result.metadata.write";

/**
* Prefix for metrics related to interactive queries
*/
public static final String STATEMENT = "statement";

/**
* Metric name for counting the number of statements currently running.
*/
@@ -135,11 +140,31 @@ public final class MetricConstants {
*/
public static final String STREAMING_HEARTBEAT_FAILED_METRIC = "streaming.heartbeat.failed.count";

/**
* Metric for tracking the count of jobs failed during query execution
*/
public static final String QUERY_EXECUTION_FAILED_METRIC = "execution.failed.count";

/**
* Metric for tracking the count of jobs failed during query result write
*/
public static final String RESULT_WRITER_FAILED_METRIC = "writer.failed.count";

/**
* Metric for tracking the latency of query execution (start to complete query execution) excluding result write.
*/
public static final String QUERY_EXECUTION_TIME_METRIC = "query.execution.processingTime";

/**
* Metric for tracking the latency of query result write only (excluding query execution)
*/
public static final String QUERY_RESULT_WRITER_TIME_METRIC = "result.writer.processingTime";

/**
* Metric for tracking the latency of query total execution including result write.
*/
public static final String QUERY_TOTAL_TIME_METRIC = "query.total.processingTime";

/**
* Metric for query count of each query type (DROP/VACUUM/ALTER/REFRESH/CREATE INDEX)
*/
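For orientation, a minimal sketch (not part of this commit) of how the new failure counter and execution timer might be driven around a single query. It assumes MetricsUtil exposes incrementCounter(String), getTimerContext(String) and stopTimer(Timer.Context) helpers, as used elsewhere in this repository; executeWithMetrics itself is hypothetical.

import com.codahale.metrics.Timer

import org.opensearch.flint.core.metrics.{MetricConstants, MetricsUtil}

def executeWithMetrics[T](runQuery: => T): T = {
  // Time only the execution phase; result writing is measured separately
  // under QUERY_RESULT_WRITER_TIME_METRIC.
  val timerContext: Timer.Context =
    MetricsUtil.getTimerContext(MetricConstants.QUERY_EXECUTION_TIME_METRIC)
  try {
    runQuery
  } catch {
    case t: Throwable =>
      // Count executions that fail before a result can be written.
      MetricsUtil.incrementCounter(MetricConstants.QUERY_EXECUTION_FAILED_METRIC)
      throw t
  } finally {
    MetricsUtil.stopTimer(timerContext)
  }
}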
FlintSparkConf.scala
@@ -171,6 +171,12 @@ object FlintSparkConf {
.doc("Enable external scheduler for index refresh")
.createWithDefault("false")

val WARMPOOL_ENABLED =
FlintConfig("spark.flint.job.warmpoolEnabled")
.doc("Enable warm pool mode for the Flint job")
.createWithDefault("false")

val MAX_EXECUTORS_COUNT = FlintConfig("spark.dynamicAllocation.maxExecutors").createOptional()

val EXTERNAL_SCHEDULER_INTERVAL_THRESHOLD =
FlintConfig("spark.flint.job.externalScheduler.interval")
.doc("Interval threshold in minutes for external scheduler to trigger index refresh")
@@ -246,6 +252,10 @@ object FlintSparkConf {
FlintConfig(s"spark.flint.job.requestIndex")
.doc("Request index")
.createOptional()
val RESULT_INDEX =
FlintConfig(s"spark.flint.job.resultIndex")
.doc("Result index")
.createOptional()
val EXCLUDE_JOB_IDS =
FlintConfig(s"spark.flint.deployment.excludeJobs")
.doc("Exclude job ids")
@@ -271,6 +281,9 @@
val CUSTOM_QUERY_RESULT_WRITER =
FlintConfig("spark.flint.job.customQueryResultWriter")
.createOptional()
val TERMINATE_JVM = FlintConfig("spark.flint.terminateJVM")
.doc("Indicates whether the JVM should be terminated after query execution")
.createWithDefault("true")
}

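These warm pool settings are ordinary Spark configuration entries keyed as defined above, so they can be supplied at submission time. A minimal sketch, assuming only the keys shown in this object; the result index name is a placeholder, and whether warm pool deployments actually disable JVM termination is an assumption here, not something this diff states.

import org.apache.spark.SparkConf

val conf = new SparkConf()
  // Route FlintJob away from the single-query batch/streaming path and into
  // warm pool mode, where WarmpoolJob fetches and executes queries.
  .set("spark.flint.job.warmpoolEnabled", "true")
  // The new RESULT_INDEX key allows the result index to come from configuration.
  .set("spark.flint.job.resultIndex", "query_execution_result") // placeholder name
  // Keep the JVM alive after query execution (assumed useful for pool reuse).
  .set("spark.flint.terminateJVM", "false")

// FlintJob reads the flag the same way, defaulting to false.
val warmpoolEnabled = conf.get("spark.flint.job.warmpoolEnabled", "false").toBoolean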
FlintJobITSuite.scala
@@ -39,6 +39,7 @@ class FlintJobITSuite extends FlintSparkSuite with JobTest {
val appId = "00feq82b752mbt0p"
val dataSourceName = "my_glue1"
val queryId = "testQueryId"
val requestIndex = "testRequestIndex"
var osClient: OSClient = _
val threadLocalFuture = new ThreadLocal[Future[Unit]]()

@@ -83,13 +84,15 @@ class FlintJobITSuite extends FlintSparkSuite with JobTest {

def createJobOperator(query: String, jobRunId: String): JobOperator = {
val streamingRunningCount = new AtomicInteger(0)
val statementRunningCount = new AtomicInteger(0)

/*
* Because we cannot test from FlintJob.main() for the reason below, we have to manually
* configure all Spark conf required by the underlying Flint code.
*/
spark.conf.set(DATA_SOURCE_NAME.key, dataSourceName)
spark.conf.set(JOB_TYPE.key, FlintJobType.STREAMING)
spark.conf.set(REQUEST_INDEX.key, requestIndex)

val job = JobOperator(
appId,
@@ -100,7 +103,8 @@ class FlintJobITSuite extends FlintSparkSuite with JobTest {
dataSourceName,
resultIndex,
FlintJobType.STREAMING,
streamingRunningCount)
streamingRunningCount,
statementRunningCount)
job.terminateJVM = false
job
}
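A hedged sketch of how the suite might drive this operator off the test thread using the threadLocalFuture declared above; startJobInBackground and the implicit ExecutionContext are illustrative, not the suite's actual wiring.

import scala.concurrent.{ExecutionContext, Future}

def startJobInBackground(query: String, jobRunId: String)(implicit ec: ExecutionContext): Unit = {
  val job = createJobOperator(query, jobRunId)
  // Run the job asynchronously so the test can poll the result index while it executes.
  threadLocalFuture.set(Future {
    job.start()
  })
}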
FlintJob.scala
@@ -8,12 +8,15 @@ package org.apache.spark.sql

import java.util.concurrent.atomic.AtomicInteger

import scala.concurrent.{ExecutionContext, ExecutionContextExecutor}

import org.opensearch.flint.core.logging.CustomLogging
import org.opensearch.flint.core.metrics.MetricConstants
import org.opensearch.flint.core.metrics.MetricsUtil.registerGauge

import org.apache.spark.internal.Logging
import org.apache.spark.sql.flint.config.FlintSparkConf
import org.apache.spark.util.ThreadUtils

/**
* Spark SQL Application entrypoint
Expand All @@ -26,52 +29,71 @@ import org.apache.spark.sql.flint.config.FlintSparkConf
* write sql query result to given opensearch index
*/
object FlintJob extends Logging with FlintJobExecutor {
private val streamingRunningCount = new AtomicInteger(0)
private val statementRunningCount = new AtomicInteger(0)

def main(args: Array[String]): Unit = {
val (queryOption, resultIndexOption) = parseArgs(args)

val conf = createSparkConf()
val jobType = conf.get("spark.flint.job.type", FlintJobType.BATCH)
CustomLogging.logInfo(s"""Job type is: ${jobType}""")
conf.set(FlintSparkConf.JOB_TYPE.key, jobType)

val dataSource = conf.get("spark.flint.datasource.name", "")
val query = queryOption.getOrElse(unescapeQuery(conf.get(FlintSparkConf.QUERY.key, "")))
if (query.isEmpty) {
logAndThrow(s"Query undefined for the ${jobType} job.")
}
val queryId = conf.get(FlintSparkConf.QUERY_ID.key, "")

if (resultIndexOption.isEmpty) {
logAndThrow("resultIndex is not set")
}
// https://github.com/opensearch-project/opensearch-spark/issues/138
/*
* To execute queries such as `CREATE SKIPPING INDEX ON my_glue1.default.http_logs_plain (`@timestamp` VALUE_SET) WITH (auto_refresh = true)`,
* it's necessary to set `spark.sql.defaultCatalog=my_glue1`. This is because AWS Glue uses a single database (default) and table (http_logs_plain),
* and we need to configure Spark to recognize `my_glue1` as a reference to AWS Glue's database and table.
* By doing this, we effectively map `my_glue1` to AWS Glue, allowing Spark to resolve the database and table names correctly.
* Without this setup, Spark would not recognize names in the format `my_glue1.default`.
*/
conf.set("spark.sql.defaultCatalog", dataSource)
configDYNMaxExecutors(conf, jobType)

val sparkSession = createSparkSession(conf)
val applicationId =
environmentProvider.getEnvVar("SERVERLESS_EMR_VIRTUAL_CLUSTER_ID", "unknown")
val jobId = environmentProvider.getEnvVar("SERVERLESS_EMR_JOB_ID", "unknown")
val isWarmpoolEnabled = conf.get(FlintSparkConf.WARMPOOL_ENABLED.key, "false").toBoolean
logInfo(s"isWarmpoolEnabled: ${isWarmpoolEnabled}")

if (!isWarmpoolEnabled) {
val jobType = sparkSession.conf.get("spark.flint.job.type", FlintJobType.BATCH)
CustomLogging.logInfo(s"""Job type is: ${jobType}""")
sparkSession.conf.set(FlintSparkConf.JOB_TYPE.key, jobType)

val dataSource = conf.get("spark.flint.datasource.name", "")
val query = queryOption.getOrElse(unescapeQuery(conf.get(FlintSparkConf.QUERY.key, "")))
if (query.isEmpty) {
logAndThrow(s"Query undefined for the ${jobType} job.")
}
val queryId = conf.get(FlintSparkConf.QUERY_ID.key, "")

val streamingRunningCount = new AtomicInteger(0)
val jobOperator =
JobOperator(
applicationId,
jobId,
createSparkSession(conf),
query,
queryId,
dataSource,
resultIndexOption.get,
jobType,
streamingRunningCount)
registerGauge(MetricConstants.STREAMING_RUNNING_METRIC, streamingRunningCount)
jobOperator.start()
if (resultIndexOption.isEmpty) {
logAndThrow("resultIndex is not set")
}
// https://github.com/opensearch-project/opensearch-spark/issues/138
/*
* To execute queries such as `CREATE SKIPPING INDEX ON my_glue1.default.http_logs_plain (`@timestamp` VALUE_SET) WITH (auto_refresh = true)`,
* it's necessary to set `spark.sql.defaultCatalog=my_glue1`. This is because AWS Glue uses a single database (default) and table (http_logs_plain),
* and we need to configure Spark to recognize `my_glue1` as a reference to AWS Glue's database and table.
* By doing this, we effectively map `my_glue1` to AWS Glue, allowing Spark to resolve the database and table names correctly.
* Without this setup, Spark would not recognize names in the format `my_glue1.default`.
*/
conf.set("spark.sql.defaultCatalog", dataSource)
configDYNMaxExecutors(conf, jobType)

val jobOperator =
JobOperator(
applicationId,
jobId,
sparkSession,
query,
queryId,
dataSource,
resultIndexOption.get,
jobType,
streamingRunningCount,
statementRunningCount,
Map.empty)
registerGauge(MetricConstants.STREAMING_RUNNING_METRIC, streamingRunningCount)
jobOperator.start()
} else {
// Fetch and execute queries in warm pool mode
val warmpoolJob =
WarmpoolJob(
applicationId,
jobId,
sparkSession,
streamingRunningCount,
statementRunningCount)
warmpoolJob.start()
}
}
}
FlintJobExecutor.scala
@@ -6,6 +6,7 @@
package org.apache.spark.sql

import java.util.Locale
import java.util.concurrent.ThreadPoolExecutor

import com.amazonaws.services.glue.model.{AccessDeniedException, AWSGlueException}
import com.amazonaws.services.s3.model.AmazonS3Exception
@@ -20,6 +21,7 @@ import play.api.libs.json._

import org.apache.spark.{SparkConf, SparkException}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.FlintREPL.instantiate
import org.apache.spark.sql.SparkConfConstants.{DEFAULT_SQL_EXTENSIONS, SQL_EXTENSIONS_KEY}
import org.apache.spark.sql.catalyst.parser.ParseException
import org.apache.spark.sql.exception.UnrecoverableException
@@ -566,4 +568,31 @@ trait FlintJobExecutor {
}
}
}

def instantiateQueryResultWriter(
spark: SparkSession,
commandContext: CommandContext): QueryResultWriter = {
instantiate(
new QueryResultWriterImpl(commandContext),
spark.conf.get(FlintSparkConf.CUSTOM_QUERY_RESULT_WRITER.key, ""))
}

def instantiateStatementExecutionManager(
commandContext: CommandContext): StatementExecutionManager = {
import commandContext._
instantiate(
new StatementExecutionManagerImpl(commandContext),
spark.conf.get(FlintSparkConf.CUSTOM_STATEMENT_MANAGER.key, ""),
spark,
sessionId)
}

def instantiateSessionManager(
spark: SparkSession,
resultIndexOption: Option[String]): SessionManager = {
instantiate(
new SessionManagerImpl(spark, resultIndexOption),
spark.conf.get(FlintSparkConf.CUSTOM_SESSION_MANAGER.key, ""),
resultIndexOption.getOrElse(""))
}
}
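With these factory helpers promoted from FlintREPL into the shared trait, any entrypoint mixing in FlintJobExecutor can build the same collaborators. A hypothetical wiring sketch (buildQueryPipeline is not part of this commit; commandContext and resultIndexOption stand in for values the caller already has):

// Assumed to live in a class or object that mixes in FlintJobExecutor.
def buildQueryPipeline(
    spark: SparkSession,
    commandContext: CommandContext,
    resultIndexOption: Option[String])
    : (SessionManager, StatementExecutionManager, QueryResultWriter) = {
  // Each helper falls back to its default *Impl unless a custom implementation
  // class is configured (e.g. spark.flint.job.customQueryResultWriter).
  val sessionManager = instantiateSessionManager(spark, resultIndexOption)
  val statementManager = instantiateStatementExecutionManager(commandContext)
  val resultWriter = instantiateQueryResultWriter(spark, commandContext)
  (sessionManager, statementManager, resultWriter)
}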
FlintREPL.scala
@@ -106,7 +106,8 @@ object FlintREPL extends Logging with FlintJobExecutor {
dataSource,
resultIndexOption.get,
jobType,
streamingRunningCount)
streamingRunningCount,
statementRunningCount)
registerGauge(MetricConstants.STREAMING_RUNNING_METRIC, streamingRunningCount)
jobOperator.start()
} else {
@@ -1021,33 +1022,6 @@ object FlintREPL extends Logging with FlintJobExecutor {
}
}

private def instantiateSessionManager(
spark: SparkSession,
resultIndexOption: Option[String]): SessionManager = {
instantiate(
new SessionManagerImpl(spark, resultIndexOption),
spark.conf.get(FlintSparkConf.CUSTOM_SESSION_MANAGER.key, ""),
resultIndexOption.getOrElse(""))
}

private def instantiateStatementExecutionManager(
commandContext: CommandContext): StatementExecutionManager = {
import commandContext._
instantiate(
new StatementExecutionManagerImpl(commandContext),
spark.conf.get(FlintSparkConf.CUSTOM_STATEMENT_MANAGER.key, ""),
spark,
sessionId)
}

private def instantiateQueryResultWriter(
spark: SparkSession,
commandContext: CommandContext): QueryResultWriter = {
instantiate(
new QueryResultWriterImpl(commandContext),
spark.conf.get(FlintSparkConf.CUSTOM_QUERY_RESULT_WRITER.key, ""))
}

private def recordSessionSuccess(sessionTimerContext: Timer.Context): Unit = {
logInfo("Session Success")
stopTimer(sessionTimerContext)