Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Stabilize adaptive rate limit by considering current rate #1027

Merged
merged 10 commits into from
Feb 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -543,6 +543,8 @@ In the index mapping, the `_meta` and `properties`field stores meta and schema i
- `spark.datasource.flint.read.scroll_size`: default value is 100.
- `spark.datasource.flint.read.scroll_duration`: default value is 5 minutes. scroll context keep alive duration.
- `spark.datasource.flint.retry.max_retries`: max retries on failed HTTP request. default value is 3. Use 0 to disable retry.
- `spark.datasource.flint.retry.bulk.max_retries`: max retries on failed bulk request. default value is 10. Use 0 to disable retry.
- `spark.datasource.flint.retry.bulk.initial_backoff`: initial backoff in seconds for bulk request retry, default is 4.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

any reason to choose 4s as the default value?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It was a fixed value and I made it a configuration. The original intention is having a higher initial backoff to quickly reduce the rate.

- `spark.datasource.flint.retry.http_status_codes`: retryable HTTP response status code list. default value is "429,500,502" (429 Too Many Requests, 500 Internal Server Error and 502 Bad Gateway).
- `spark.datasource.flint.retry.exception_class_names`: retryable exception class name list. by default no retry on any exception thrown.
- `spark.datasource.flint.read.support_shard`: default is true. set to false if index does not support shard (AWS OpenSearch Serverless collection). Do not use in production, this setting will be removed in later version.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,12 @@
import dev.failsafe.event.ExecutionAttemptedEvent;
import dev.failsafe.function.CheckedPredicate;
import java.time.Duration;
import java.util.Arrays;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.logging.Logger;
import java.util.stream.Collectors;
import org.opensearch.action.bulk.BulkResponse;
import org.opensearch.flint.core.http.handler.ExceptionClassNameFailurePredicate;
import org.opensearch.flint.core.http.handler.HttpAOSSResultPredicate;
Expand All @@ -41,8 +44,12 @@ public class FlintRetryOptions implements Serializable {
*/
public static final int DEFAULT_MAX_RETRIES = 3;
public static final String MAX_RETRIES = "retry.max_retries";
public static final int DEFAULT_BULK_MAX_RETRIES = 10;
public static final String BULK_MAX_RETRIES = "retry.bulk.max_retries";
public static final int DEFAULT_BULK_INITIAL_BACKOFF = 4;
public static final String BULK_INITIAL_BACKOFF = "retry.bulk.initial_backoff";

public static final String DEFAULT_RETRYABLE_HTTP_STATUS_CODES = "429,502";
public static final String DEFAULT_RETRYABLE_HTTP_STATUS_CODES = "429,500,502";
public static final String RETRYABLE_HTTP_STATUS_CODES = "retry.http_status_codes";

/**
Expand Down Expand Up @@ -90,9 +97,9 @@ public <T> RetryPolicy<T> getRetryPolicy() {
public RetryPolicy<BulkResponse> getBulkRetryPolicy(CheckedPredicate<BulkResponse> resultPredicate) {
return RetryPolicy.<BulkResponse>builder()
// Using higher initial backoff to mitigate throttling quickly
.withBackoff(4, 30, SECONDS)
.withBackoff(getBulkInitialBackoff(), 30, SECONDS)
.withJitter(Duration.ofMillis(100))
.withMaxRetries(getMaxRetries())
.withMaxRetries(getBulkMaxRetries())
// Do not retry on exception (will be handled by the other retry policy
.handleIf((ex) -> false)
.handleResultIf(resultPredicate)
Expand Down Expand Up @@ -122,10 +129,27 @@ public int getMaxRetries() {
}

/**
* @return retryable HTTP status code list
* @return bulk maximum retry option value
*/
public int getBulkMaxRetries() {
return Integer.parseInt(
options.getOrDefault(BULK_MAX_RETRIES, String.valueOf(DEFAULT_BULK_MAX_RETRIES)));
}

/**
* @return bulk initial backoff option value in seconds
*/
public String getRetryableHttpStatusCodes() {
return options.getOrDefault(RETRYABLE_HTTP_STATUS_CODES, DEFAULT_RETRYABLE_HTTP_STATUS_CODES);
public int getBulkInitialBackoff() {
return Integer.parseInt(
options.getOrDefault(BULK_INITIAL_BACKOFF, String.valueOf(DEFAULT_BULK_INITIAL_BACKOFF)));
}

/**
 * Parse the configured retryable HTTP status codes into a set of integers.
 * Blank entries (e.g. from a trailing comma in the configuration value) are
 * ignored instead of raising a NumberFormatException.
 *
 * @return set of HTTP status codes that should trigger a retry
 */
public Set<Integer> getRetryableHttpStatusCodes() {
  String statusCodes = options.getOrDefault(RETRYABLE_HTTP_STATUS_CODES, DEFAULT_RETRYABLE_HTTP_STATUS_CODES);
  return Arrays.stream(statusCodes.split(","))
      .map(String::trim)
      // Tolerate malformed lists such as "429,502," or "429,,502".
      .filter(code -> !code.isEmpty())
      .map(Integer::valueOf)
      .collect(Collectors.toSet());
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,8 @@ public class HttpStatusCodeResultPredicate<T> implements CheckedPredicate<T> {
*/
private final Set<Integer> retryableStatusCodes;

public HttpStatusCodeResultPredicate(String httpStatusCodes) {
this.retryableStatusCodes =
Arrays.stream(httpStatusCodes.split(","))
.map(String::trim)
.map(Integer::valueOf)
.collect(Collectors.toSet());
/**
 * @param httpStatusCodes HTTP status codes that should be treated as retryable;
 *     the set is stored by reference (no defensive copy is made), so callers
 *     should not mutate it after construction
 */
public HttpStatusCodeResultPredicate(Set<Integer> httpStatusCodes) {
this.retryableStatusCodes = httpStatusCodes;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,14 @@ public class BulkRequestRateLimiterImpl implements BulkRequestRateLimiter {
private final long maxRate;
private final long increaseStep;
private final double decreaseRatio;
private final RequestRateMeter requestRateMeter;

public BulkRequestRateLimiterImpl(FlintOptions flintOptions) {
minRate = flintOptions.getBulkRequestMinRateLimitPerNode();
maxRate = flintOptions.getBulkRequestMaxRateLimitPerNode();
increaseStep = flintOptions.getBulkRequestRateLimitPerNodeIncreaseStep();
decreaseRatio = flintOptions.getBulkRequestRateLimitPerNodeDecreaseRatio();
requestRateMeter = new RequestRateMeter();

LOG.info("Setting rate limit for bulk request to " + minRate + " documents/sec");
this.rateLimiter = RateLimiter.create(minRate);
Expand All @@ -42,14 +44,26 @@ public void acquirePermit() {
/**
 * Acquires the given number of permits from the rate limiter (may block until
 * they become available) and records the acquisition in the request-rate meter.
 *
 * @param permits number of permits (request documents) to acquire
 */
public void acquirePermit(int permits) {
this.rateLimiter.acquire(permits);
LOG.info("Acquired " + permits + " permits");
// Record when and how many permits were granted; increaseRate() uses this to
// decide whether the measured throughput is close enough to the limit.
requestRateMeter.addDataPoint(System.currentTimeMillis(), permits);
}

/**
 * Increase the rate limit additively, but only when the estimated current
 * request rate is close to the configured limit. Without this guard the limit
 * would keep drifting far above the actual traffic, making later throttling
 * ineffective when load spikes.
 */
@Override
public void increaseRate() {
  if (isEstimatedCurrentRateCloseToLimit()) {
    setRate(getRate() + increaseStep);
  } else {
    // Actual traffic is well below the limit; raising it further is pointless.
    LOG.info("Rate increase was blocked.");
  }
  LOG.info("Current rate limit for bulk request is " + getRate() + " documents/sec");
}

/**
 * @return true when the measured request rate exceeds 80% of the current rate
 *     limit, i.e. traffic is actually pushing against the limit
 */
private boolean isEstimatedCurrentRateCloseToLimit() {
long currentEstimatedRate = requestRateMeter.getCurrentEstimatedRate();
LOG.info("Current estimated rate is " + currentEstimatedRate + " documents/sec");
// 0.8 threshold: only treat the limit as a bottleneck once actual throughput
// passes 80% of it; below that, increasing the limit has no effect on traffic.
return getRate() * 0.8 < currentEstimatedRate;
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import dev.failsafe.function.CheckedPredicate;
import java.util.Arrays;
import java.util.List;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.logging.Logger;
Expand All @@ -34,10 +35,12 @@ public class OpenSearchBulkWrapper {

private final RetryPolicy<BulkResponse> retryPolicy;
private final BulkRequestRateLimiter rateLimiter;
private final Set<Integer> retryableStatusCodes;

/**
 * @param retryOptions retry configuration used to build the bulk retry policy
 *     and the set of retryable HTTP status codes
 * @param rateLimiter rate limiter applied before each bulk attempt
 */
public OpenSearchBulkWrapper(FlintRetryOptions retryOptions, BulkRequestRateLimiter rateLimiter) {
// NOTE(review): bulkItemRetryableResultPredicate is an instance-field
// initializer, so it is already constructed when referenced here;
// retryableStatusCodes is assigned below before any bulk request can
// invoke the predicate.
this.retryPolicy = retryOptions.getBulkRetryPolicy(bulkItemRetryableResultPredicate);
this.rateLimiter = rateLimiter;
this.retryableStatusCodes = retryOptions.getRetryableHttpStatusCodes();
}

/**
Expand All @@ -50,7 +53,6 @@ public OpenSearchBulkWrapper(FlintRetryOptions retryOptions, BulkRequestRateLimi
* @return Last result
*/
public BulkResponse bulk(RestHighLevelClient client, BulkRequest bulkRequest, RequestOptions options) {
rateLimiter.acquirePermit(bulkRequest.requests().size());
return bulkWithPartialRetry(client, bulkRequest, options);
}

Expand All @@ -69,11 +71,13 @@ private BulkResponse bulkWithPartialRetry(RestHighLevelClient client, BulkReques
})
.get(() -> {
requestCount.incrementAndGet();
rateLimiter.acquirePermit(nextRequest.get().requests().size());
BulkResponse response = client.bulk(nextRequest.get(), options);

if (!bulkItemRetryableResultPredicate.test(response)) {
rateLimiter.increaseRate();
} else {
LOG.info("Bulk request failed. attempt = " + (requestCount.get() - 1));
rateLimiter.decreaseRate();
if (retryPolicy.getConfig().allowsRetries()) {
nextRequest.set(getRetryableRequest(nextRequest.get(), response));
Expand Down Expand Up @@ -118,10 +122,10 @@ private static void verifyIdMatch(DocWriteRequest<?> request, BulkItemResponse r
/**
* A predicate to decide if a BulkResponse is retryable or not.
*/
private static final CheckedPredicate<BulkResponse> bulkItemRetryableResultPredicate = bulkResponse ->
private final CheckedPredicate<BulkResponse> bulkItemRetryableResultPredicate = bulkResponse ->
bulkResponse.hasFailures() && isRetryable(bulkResponse);

private static boolean isRetryable(BulkResponse bulkResponse) {
private boolean isRetryable(BulkResponse bulkResponse) {
if (Arrays.stream(bulkResponse.getItems())
.anyMatch(itemResp -> isItemRetryable(itemResp))) {
LOG.info("Found retryable failure in the bulk response");
Expand All @@ -130,12 +134,23 @@ private static boolean isRetryable(BulkResponse bulkResponse) {
return false;
}

private static boolean isItemRetryable(BulkItemResponse itemResponse) {
return itemResponse.isFailed() && !isCreateConflict(itemResponse);
/**
 * Decide whether a single bulk item failure should be retried.
 * An item is retryable only if it failed, the failure is not a create
 * conflict (document already exists), and its status code is configured
 * as retryable.
 */
private boolean isItemRetryable(BulkItemResponse itemResponse) {
  if (!itemResponse.isFailed()) {
    return false;
  }
  if (isCreateConflict(itemResponse)) {
    return false;
  }
  return isFailureStatusRetryable(itemResponse);
}

/**
 * A CREATE operation that failed with a version conflict means the document
 * already exists; such failures are deliberate no-ops, not retryable errors.
 */
private static boolean isCreateConflict(BulkItemResponse itemResp) {
  boolean isCreateOp = itemResp.getOpType() == DocWriteRequest.OpType.CREATE;
  return isCreateOp && itemResp.getFailure().getStatus() == RestStatus.CONFLICT;
}

/**
 * Check the failed item's HTTP status against the configured retryable set;
 * logs the failure detail when it is not retryable.
 */
private boolean isFailureStatusRetryable(BulkItemResponse itemResp) {
  int statusCode = itemResp.getFailure().getStatus().getStatus();
  if (!retryableStatusCodes.contains(statusCode)) {
    LOG.info("Found non-retryable failure in bulk response: " + itemResp.getFailure().getStatus()
        + ", " + itemResp.getFailure().toString());
    return false;
  }
  return true;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.flint.core.storage;

import java.util.LinkedList;
import java.util.List;
import java.util.Queue;

/**
* Track the current request rate based on the past requests within ESTIMATE_RANGE_DURATION_MSEC
* milliseconds period.
*/
public class RequestRateMeter {
private static final long ESTIMATE_RANGE_DURATION_MSEC = 3000;

/**
 * A single rate sample: the number of permits (request documents) acquired
 * at a given instant.
 */
private static class DataPoint {
// Epoch milliseconds at which the sample was recorded.
long timestamp;
// Number of request documents counted in this sample.
long requestCount;
public DataPoint(long timestamp, long requestCount) {
this.timestamp = timestamp;
this.requestCount = requestCount;
}
}

private Queue<DataPoint> dataPoints = new LinkedList<>();
private long currentSum = 0;

public synchronized void addDataPoint(long timestamp, long requestCount) {
dataPoints.add(new DataPoint(timestamp, requestCount));
currentSum += requestCount;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we could removeOldDataPoint during add to reduce memory usage.

Copy link
Collaborator Author

@ykmr1224 ykmr1224 Feb 5, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above. Memory usage is almost ignorable.

removeOldDataPoints();
}

/**
 * Estimate the current request rate in documents per second, based on the
 * samples recorded within the last ESTIMATE_RANGE_DURATION_MSEC milliseconds.
 */
public synchronized long getCurrentEstimatedRate() {
  // Drop samples that fell out of the estimation window before computing.
  removeOldDataPoints();
  // Scale the windowed request total to a per-second rate.
  long perSecondRate = currentSum * 1000 / ESTIMATE_RANGE_DURATION_MSEC;
  return perSecondRate;
}

/**
 * Evict samples older than the estimation window, keeping currentSum in sync
 * with the remaining queue contents.
 */
private synchronized void removeOldDataPoints() {
  final long cutoff = System.currentTimeMillis() - ESTIMATE_RANGE_DURATION_MSEC;
  for (DataPoint head = dataPoints.peek(); head != null && head.timestamp < cutoff;
      head = dataPoints.peek()) {
    dataPoints.remove();
    currentSum -= head.requestCount;
  }
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.flint.core.storage;

import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.*;

/** Unit tests for {@link RequestRateMeter} (3-second estimation window). */
class RequestRateMeterTest {

  private RequestRateMeter meter;

  @BeforeEach
  void setUp() {
    meter = new RequestRateMeter();
  }

  @Test
  void testAddDataPoint() {
    long now = System.currentTimeMillis();
    meter.addDataPoint(now, 30);
    // 30 documents over the 3-second window -> 10 docs/sec.
    assertEquals(10, meter.getCurrentEstimatedRate());
  }

  @Test
  void testAddDataPointRemoveOldDataPoint() {
    long now = System.currentTimeMillis();
    // This sample is older than the window and must be discarded: 90 / 3 = 30.
    meter.addDataPoint(now - 4000, 30);
    meter.addDataPoint(now, 90);
    assertEquals(30, meter.getCurrentEstimatedRate());
  }

  @Test
  void testRemoveOldDataPoints() {
    long now = System.currentTimeMillis();
    meter.addDataPoint(now - 4000, 30); // outside the 3s window
    meter.addDataPoint(now - 2000, 60);
    meter.addDataPoint(now, 90);
    // Only the two in-window samples count: (60 + 90) / 3 = 50 docs/sec.
    assertEquals(50, meter.getCurrentEstimatedRate());
  }

  @Test
  void testGetCurrentEstimatedRate() {
    long now = System.currentTimeMillis();
    meter.addDataPoint(now - 2500, 30);
    meter.addDataPoint(now - 1500, 60);
    meter.addDataPoint(now - 500, 90);
    // All three samples fall inside the window: (30 + 60 + 90) / 3 = 60 docs/sec.
    assertEquals(60, meter.getCurrentEstimatedRate());
  }

  @Test
  void testEmptyRateMeter() {
    // No samples recorded yet -> estimated rate is zero.
    assertEquals(0, meter.getCurrentEstimatedRate());
  }

  @Test
  void testSingleDataPoint() {
    meter.addDataPoint(System.currentTimeMillis(), 30);
    // 30 / 3 = 10 docs/sec.
    assertEquals(10, meter.getCurrentEstimatedRate());
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,18 @@ object FlintSparkConf {
.doc("max retries on failed HTTP request, 0 means retry is disabled, default is 3")
.createWithDefault(String.valueOf(FlintRetryOptions.DEFAULT_MAX_RETRIES))

val BULK_MAX_RETRIES =
  FlintConfig(s"spark.datasource.flint.${FlintRetryOptions.BULK_MAX_RETRIES}")
    .datasourceOption()
    // This option governs bulk request retries (not generic HTTP retries,
    // which MAX_RETRIES covers); description matches docs/index.md.
    .doc("max retries on failed bulk request, 0 means retry is disabled, default is 10")
    .createWithDefault(String.valueOf(FlintRetryOptions.DEFAULT_BULK_MAX_RETRIES))

// Initial delay in seconds for the bulk retry backoff policy; the delay is
// capped at 30 seconds in FlintRetryOptions.getBulkRetryPolicy.
val BULK_INITIAL_BACKOFF =
FlintConfig(s"spark.datasource.flint.${FlintRetryOptions.BULK_INITIAL_BACKOFF}")
.datasourceOption()
.doc("initial backoff in seconds for bulk request retry, default is 4s")
.createWithDefault(String.valueOf(FlintRetryOptions.DEFAULT_BULK_INITIAL_BACKOFF))

val BULK_REQUEST_RATE_LIMIT_PER_NODE_ENABLED =
FlintConfig(
s"spark.datasource.flint.${FlintOptions.BULK_REQUEST_RATE_LIMIT_PER_NODE_ENABLED}")
Expand Down Expand Up @@ -368,6 +380,8 @@ case class FlintSparkConf(properties: JMap[String, String]) extends Serializable
SCHEME,
AUTH,
MAX_RETRIES,
BULK_MAX_RETRIES,
BULK_INITIAL_BACKOFF,
RETRYABLE_HTTP_STATUS_CODES,
BULK_REQUEST_RATE_LIMIT_PER_NODE_ENABLED,
BULK_REQUEST_MIN_RATE_LIMIT_PER_NODE,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import scala.collection.JavaConverters._

import org.opensearch.flint.core.FlintOptions
import org.opensearch.flint.core.http.FlintRetryOptions._
import org.scalatest.matchers.must.Matchers.contain
import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper

import org.apache.spark.FlintSuite
Expand Down Expand Up @@ -46,7 +47,7 @@ class FlintSparkConfSuite extends FlintSuite {
test("test retry options default values") {
val retryOptions = FlintSparkConf().flintOptions().getRetryOptions
retryOptions.getMaxRetries shouldBe DEFAULT_MAX_RETRIES
retryOptions.getRetryableHttpStatusCodes shouldBe DEFAULT_RETRYABLE_HTTP_STATUS_CODES
retryOptions.getRetryableHttpStatusCodes should contain theSameElementsAs Set(429, 500, 502)
retryOptions.getRetryableExceptionClassNames shouldBe Optional.empty
}

Expand All @@ -60,7 +61,11 @@ class FlintSparkConfSuite extends FlintSuite {
.getRetryOptions

retryOptions.getMaxRetries shouldBe 5
retryOptions.getRetryableHttpStatusCodes shouldBe "429,502,503,504"
retryOptions.getRetryableHttpStatusCodes should contain theSameElementsAs Set(
429,
502,
503,
504)
retryOptions.getRetryableExceptionClassNames.get() shouldBe "java.net.ConnectException"
}

Expand Down
Loading