Stabilize adaptive rate limit by considering current rate (#1027)
* Stabilize adaptive rate limit by considering current rate

Signed-off-by: Tomoyuki Morita <[email protected]>

* Fix bulk retry condition

Signed-off-by: Tomoyuki Morita <[email protected]>

* Fix FlintSparkConfSuite

Signed-off-by: Tomoyuki Morita <[email protected]>

* Fix FlintSparkConfSuite

Signed-off-by: Tomoyuki Morita <[email protected]>

* Reformat

Signed-off-by: Tomoyuki Morita <[email protected]>

* Use Queue interface instead of List

Signed-off-by: Tomoyuki Morita <[email protected]>

* Fix doc for BULK_INITIAL_BACKOFF

Signed-off-by: Tomoyuki Morita <[email protected]>

* Add retryable http status code

Signed-off-by: Tomoyuki Morita <[email protected]>

* Fix test

Signed-off-by: Tomoyuki Morita <[email protected]>

* Fix RequestRateMeter.addDataPoint

Signed-off-by: Tomoyuki Morita <[email protected]>

---------

Signed-off-by: Tomoyuki Morita <[email protected]>
ykmr1224 authored Feb 6, 2025
1 parent 4d6ba7d commit 77da4a7
Showing 9 changed files with 204 additions and 20 deletions.
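At a glance, the change gates the rate limiter's additive increase on the measured request rate. A minimal standalone sketch of that behavior follows; the class, field names, and numbers are illustrative and not taken from the repository.

```java
// Illustrative sketch only -- not the actual BulkRequestRateLimiterImpl.
public class AdaptiveRateLimitSketch {
    private double rate = 2000;               // hypothetical current limit (docs/sec)
    private final double increaseStep = 500;  // hypothetical additive step
    private final double decreaseRatio = 0.8; // hypothetical multiplicative decrease

    // Called after a successful bulk request; estimatedRate would come from a
    // sliding-window meter such as the new RequestRateMeter.
    void onBulkSuccess(double estimatedRate) {
        if (rate * 0.8 < estimatedRate) {     // only grow when the limit is nearly saturated
            rate += increaseStep;
        }
    }

    // Called when a bulk request is throttled or fails with a retryable status.
    void onBulkFailure() {
        rate *= decreaseRatio;
    }
}
```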
2 changes: 2 additions & 0 deletions docs/index.md
@@ -543,6 +543,8 @@ In the index mapping, the `_meta` and `properties`field stores meta and schema i
- `spark.datasource.flint.read.scroll_size`: default value is 100.
- `spark.datasource.flint.read.scroll_duration`: default value is 5 minutes. scroll context keep alive duration.
- `spark.datasource.flint.retry.max_retries`: max retries on failed HTTP request. default value is 3. Use 0 to disable retry.
- `spark.datasource.flint.retry.bulk.max_retries`: max retries on failed bulk request. default value is 10. Use 0 to disable retry.
- `spark.datasource.flint.retry.bulk.initial_backoff`: initial backoff in seconds for bulk request retry, default is 4.
- `spark.datasource.flint.retry.http_status_codes`: retryable HTTP response status code list. default value is "429,502" (429 Too Many Requests and 502 Bad Gateway).
- `spark.datasource.flint.retry.exception_class_names`: retryable exception class name list. by default no retry on any exception thrown.
- `spark.datasource.flint.read.support_shard`: default is true. set to false if index does not support shard (AWS OpenSearch Serverless collection). Do not use in production, this setting will be removed in later version.
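The new bulk retry options above are ordinary Spark datasource settings. A hedged sketch of setting them on a SparkSession follows; the application name is made up, and the values simply restate the defaults introduced by this commit.

```java
import org.apache.spark.sql.SparkSession;

// Sketch: wiring the bulk retry options documented above into a Spark job.
// The status code list matches the new code default in FlintRetryOptions.
public class FlintRetryConfExample {
    public static void main(String[] args) {
        SparkSession spark = SparkSession.builder()
            .appName("flint-bulk-retry-example")
            .config("spark.datasource.flint.retry.bulk.max_retries", "10")
            .config("spark.datasource.flint.retry.bulk.initial_backoff", "4")
            .config("spark.datasource.flint.retry.http_status_codes", "429,500,502")
            .getOrCreate();
        // ... run Flint reads/writes here ...
        spark.stop();
    }
}
```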
@@ -15,9 +15,12 @@
import dev.failsafe.event.ExecutionAttemptedEvent;
import dev.failsafe.function.CheckedPredicate;
import java.time.Duration;
import java.util.Arrays;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.logging.Logger;
import java.util.stream.Collectors;
import org.opensearch.action.bulk.BulkResponse;
import org.opensearch.flint.core.http.handler.ExceptionClassNameFailurePredicate;
import org.opensearch.flint.core.http.handler.HttpAOSSResultPredicate;
@@ -41,8 +44,12 @@ public class FlintRetryOptions implements Serializable {
*/
public static final int DEFAULT_MAX_RETRIES = 3;
public static final String MAX_RETRIES = "retry.max_retries";
public static final int DEFAULT_BULK_MAX_RETRIES = 10;
public static final String BULK_MAX_RETRIES = "retry.bulk.max_retries";
public static final int DEFAULT_BULK_INITIAL_BACKOFF = 4;
public static final String BULK_INITIAL_BACKOFF = "retry.bulk.initial_backoff";

public static final String DEFAULT_RETRYABLE_HTTP_STATUS_CODES = "429,502";
public static final String DEFAULT_RETRYABLE_HTTP_STATUS_CODES = "429,500,502";
public static final String RETRYABLE_HTTP_STATUS_CODES = "retry.http_status_codes";

/**
@@ -90,9 +97,9 @@ public <T> RetryPolicy<T> getRetryPolicy() {
public RetryPolicy<BulkResponse> getBulkRetryPolicy(CheckedPredicate<BulkResponse> resultPredicate) {
return RetryPolicy.<BulkResponse>builder()
// Using higher initial backoff to mitigate throttling quickly
.withBackoff(4, 30, SECONDS)
.withBackoff(getBulkInitialBackoff(), 30, SECONDS)
.withJitter(Duration.ofMillis(100))
.withMaxRetries(getMaxRetries())
.withMaxRetries(getBulkMaxRetries())
// Do not retry on exception (will be handled by the other retry policy)
.handleIf((ex) -> false)
.handleResultIf(resultPredicate)
@@ -122,10 +129,27 @@ public int getMaxRetries() {
}

/**
* @return retryable HTTP status code list
* @return bulk maximum retry option value
*/
public int getBulkMaxRetries() {
return Integer.parseInt(
options.getOrDefault(BULK_MAX_RETRIES, String.valueOf(DEFAULT_BULK_MAX_RETRIES)));
}

/**
* @return bulk initial backoff option value in seconds
*/
public String getRetryableHttpStatusCodes() {
return options.getOrDefault(RETRYABLE_HTTP_STATUS_CODES, DEFAULT_RETRYABLE_HTTP_STATUS_CODES);
public int getBulkInitialBackoff() {
return Integer.parseInt(
options.getOrDefault(BULK_INITIAL_BACKOFF, String.valueOf(DEFAULT_BULK_INITIAL_BACKOFF)));
}

public Set<Integer> getRetryableHttpStatusCodes() {
String statusCodes = options.getOrDefault(RETRYABLE_HTTP_STATUS_CODES, DEFAULT_RETRYABLE_HTTP_STATUS_CODES);
return Arrays.stream(statusCodes.split(","))
.map(String::trim)
.map(Integer::valueOf)
.collect(Collectors.toSet());
}

/**
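To make the accessor change concrete, the standalone sketch below mirrors the parsing that getRetryableHttpStatusCodes() now performs, turning the comma-separated option into a Set<Integer> once instead of handing callers a raw String; the sample value is illustrative.

```java
import java.util.Arrays;
import java.util.Set;
import java.util.stream.Collectors;

// Standalone sketch mirroring FlintRetryOptions.getRetryableHttpStatusCodes().
public class StatusCodeParsingSketch {
    public static void main(String[] args) {
        String configured = "429, 500, 502";          // e.g. value of retry.http_status_codes
        Set<Integer> retryable = Arrays.stream(configured.split(","))
            .map(String::trim)                        // tolerate spaces after commas
            .map(Integer::valueOf)
            .collect(Collectors.toSet());
        System.out.println(retryable.contains(502));  // true
        System.out.println(retryable.contains(400));  // false
    }
}
```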
@@ -26,12 +26,8 @@ public class HttpStatusCodeResultPredicate<T> implements CheckedPredicate<T> {
*/
private final Set<Integer> retryableStatusCodes;

public HttpStatusCodeResultPredicate(String httpStatusCodes) {
this.retryableStatusCodes =
Arrays.stream(httpStatusCodes.split(","))
.map(String::trim)
.map(Integer::valueOf)
.collect(Collectors.toSet());
public HttpStatusCodeResultPredicate(Set<Integer> httpStatusCodes) {
this.retryableStatusCodes = httpStatusCodes;
}

@Override
@@ -19,12 +19,14 @@ public class BulkRequestRateLimiterImpl implements BulkRequestRateLimiter {
private final long maxRate;
private final long increaseStep;
private final double decreaseRatio;
private final RequestRateMeter requestRateMeter;

public BulkRequestRateLimiterImpl(FlintOptions flintOptions) {
minRate = flintOptions.getBulkRequestMinRateLimitPerNode();
maxRate = flintOptions.getBulkRequestMaxRateLimitPerNode();
increaseStep = flintOptions.getBulkRequestRateLimitPerNodeIncreaseStep();
decreaseRatio = flintOptions.getBulkRequestRateLimitPerNodeDecreaseRatio();
requestRateMeter = new RequestRateMeter();

LOG.info("Setting rate limit for bulk request to " + minRate + " documents/sec");
this.rateLimiter = RateLimiter.create(minRate);
@@ -42,14 +44,26 @@ public void acquirePermit() {
public void acquirePermit(int permits) {
this.rateLimiter.acquire(permits);
LOG.info("Acquired " + permits + " permits");
requestRateMeter.addDataPoint(System.currentTimeMillis(), permits);
}

/**
* Increase rate limit additively.
*/
@Override
public void increaseRate() {
setRate(getRate() + increaseStep);
if (isEstimatedCurrentRateCloseToLimit()) {
setRate(getRate() + increaseStep);
} else {
LOG.info("Rate increase was blocked.");
}
LOG.info("Current rate limit for bulk request is " + getRate() + " documents/sec");
}

private boolean isEstimatedCurrentRateCloseToLimit() {
long currentEstimatedRate = requestRateMeter.getCurrentEstimatedRate();
LOG.info("Current estimated rate is " + currentEstimatedRate + " documents/sec");
return getRate() * 0.8 < currentEstimatedRate;
}

/**
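A short worked example of the new 80% gate in isEstimatedCurrentRateCloseToLimit(); the numbers are illustrative.

```java
// With a current limit of 1000 docs/sec, the additive increase is applied
// only once the measured rate exceeds 800 docs/sec.
public class RateGateExample {
    public static void main(String[] args) {
        long rateLimit = 1000;
        long estimated = 700;
        System.out.println(rateLimit * 0.8 < estimated); // false -> increase blocked
        estimated = 900;
        System.out.println(rateLimit * 0.8 < estimated); // true  -> rate grows by increaseStep
    }
}
```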
@@ -11,6 +11,7 @@
import dev.failsafe.function.CheckedPredicate;
import java.util.Arrays;
import java.util.List;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.logging.Logger;
@@ -34,10 +35,12 @@ public class OpenSearchBulkWrapper {

private final RetryPolicy<BulkResponse> retryPolicy;
private final BulkRequestRateLimiter rateLimiter;
private final Set<Integer> retryableStatusCodes;

public OpenSearchBulkWrapper(FlintRetryOptions retryOptions, BulkRequestRateLimiter rateLimiter) {
this.retryPolicy = retryOptions.getBulkRetryPolicy(bulkItemRetryableResultPredicate);
this.rateLimiter = rateLimiter;
this.retryableStatusCodes = retryOptions.getRetryableHttpStatusCodes();
}

/**
@@ -50,7 +53,6 @@ public OpenSearchBulkWrapper(FlintRetryOptions retryOptions, BulkRequestRateLimi
* @return Last result
*/
public BulkResponse bulk(RestHighLevelClient client, BulkRequest bulkRequest, RequestOptions options) {
rateLimiter.acquirePermit(bulkRequest.requests().size());
return bulkWithPartialRetry(client, bulkRequest, options);
}

@@ -69,11 +71,13 @@ private BulkResponse bulkWithPartialRetry(RestHighLevelClient client, BulkReques
})
.get(() -> {
requestCount.incrementAndGet();
rateLimiter.acquirePermit(nextRequest.get().requests().size());
BulkResponse response = client.bulk(nextRequest.get(), options);

if (!bulkItemRetryableResultPredicate.test(response)) {
rateLimiter.increaseRate();
} else {
LOG.info("Bulk request failed. attempt = " + (requestCount.get() - 1));
rateLimiter.decreaseRate();
if (retryPolicy.getConfig().allowsRetries()) {
nextRequest.set(getRetryableRequest(nextRequest.get(), response));
@@ -118,10 +122,10 @@ private static void verifyIdMatch(DocWriteRequest<?> request, BulkItemResponse r
/**
* A predicate to decide if a BulkResponse is retryable or not.
*/
private static final CheckedPredicate<BulkResponse> bulkItemRetryableResultPredicate = bulkResponse ->
private final CheckedPredicate<BulkResponse> bulkItemRetryableResultPredicate = bulkResponse ->
bulkResponse.hasFailures() && isRetryable(bulkResponse);

private static boolean isRetryable(BulkResponse bulkResponse) {
private boolean isRetryable(BulkResponse bulkResponse) {
if (Arrays.stream(bulkResponse.getItems())
.anyMatch(itemResp -> isItemRetryable(itemResp))) {
LOG.info("Found retryable failure in the bulk response");
@@ -130,12 +134,23 @@ private static boolean isRetryable(BulkResponse bulkResponse) {
return false;
}

private static boolean isItemRetryable(BulkItemResponse itemResponse) {
return itemResponse.isFailed() && !isCreateConflict(itemResponse);
private boolean isItemRetryable(BulkItemResponse itemResponse) {
return itemResponse.isFailed() && !isCreateConflict(itemResponse)
&& isFailureStatusRetryable(itemResponse);
}

private static boolean isCreateConflict(BulkItemResponse itemResp) {
return itemResp.getOpType() == DocWriteRequest.OpType.CREATE &&
itemResp.getFailure().getStatus() == RestStatus.CONFLICT;
}

private boolean isFailureStatusRetryable(BulkItemResponse itemResp) {
if (retryableStatusCodes.contains(itemResp.getFailure().getStatus().getStatus())) {
return true;
} else {
LOG.info("Found non-retryable failure in bulk response: " + itemResp.getFailure().getStatus()
+ ", " + itemResp.getFailure().toString());
return false;
}
}
}
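The per-item retry rule added above boils down to three checks. The simplified sketch below restates that rule with plain booleans and status codes instead of BulkItemResponse; method names and sample values are illustrative.

```java
import java.util.Set;

// Simplified sketch of OpenSearchBulkWrapper's per-item retry decision:
// retry only items that failed, are not CREATE version conflicts, and
// whose failure status is in the configured retryable set.
public class ItemRetryDecisionSketch {
    static boolean shouldRetry(boolean failed, boolean createConflict,
                               int failureStatus, Set<Integer> retryableCodes) {
        return failed && !createConflict && retryableCodes.contains(failureStatus);
    }

    public static void main(String[] args) {
        Set<Integer> codes = Set.of(429, 500, 502);
        System.out.println(shouldRetry(true, false, 429, codes)); // true  -> item re-sent
        System.out.println(shouldRetry(true, false, 400, codes)); // false -> non-retryable failure
        System.out.println(shouldRetry(true, true, 409, codes));  // false -> duplicate create, not retried
    }
}
```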
@@ -0,0 +1,48 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.flint.core.storage;

import java.util.LinkedList;
import java.util.List;
import java.util.Queue;

/**
* Track the current request rate based on the past requests within ESTIMATE_RANGE_DURATION_MSEC
* milliseconds period.
*/
public class RequestRateMeter {
private static final long ESTIMATE_RANGE_DURATION_MSEC = 3000;

private static class DataPoint {
long timestamp;
long requestCount;
public DataPoint(long timestamp, long requestCount) {
this.timestamp = timestamp;
this.requestCount = requestCount;
}
}

private Queue<DataPoint> dataPoints = new LinkedList<>();
private long currentSum = 0;

public synchronized void addDataPoint(long timestamp, long requestCount) {
dataPoints.add(new DataPoint(timestamp, requestCount));
currentSum += requestCount;
removeOldDataPoints();
}

public synchronized long getCurrentEstimatedRate() {
removeOldDataPoints();
return currentSum * 1000 / ESTIMATE_RANGE_DURATION_MSEC;
}

private synchronized void removeOldDataPoints() {
long curr = System.currentTimeMillis();
while (!dataPoints.isEmpty() && dataPoints.peek().timestamp < curr - ESTIMATE_RANGE_DURATION_MSEC) {
currentSum -= dataPoints.remove().requestCount;
}
}
}
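A small usage sketch of the new meter, following the same arithmetic as the tests that follow: data points older than the 3-second window are dropped, and the estimate is the in-window sum divided by 3.

```java
import org.opensearch.flint.core.storage.RequestRateMeter;

// Points older than ESTIMATE_RANGE_DURATION_MSEC (3s) are discarded;
// the estimate is the remaining request count divided by 3 seconds.
public class RequestRateMeterExample {
    public static void main(String[] args) {
        RequestRateMeter meter = new RequestRateMeter();
        long now = System.currentTimeMillis();
        meter.addDataPoint(now - 4000, 30);  // outside the 3s window, dropped
        meter.addDataPoint(now, 90);         // inside the window
        System.out.println(meter.getCurrentEstimatedRate()); // 90 / 3 = 30 docs/sec
    }
}
```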
@@ -0,0 +1,66 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.flint.core.storage;

import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.*;

class RequestRateMeterTest {

private RequestRateMeter requestRateMeter;

@BeforeEach
void setUp() {
requestRateMeter = new RequestRateMeter();
}

@Test
void testAddDataPoint() {
long timestamp = System.currentTimeMillis();
requestRateMeter.addDataPoint(timestamp, 30);
assertEquals(10, requestRateMeter.getCurrentEstimatedRate());
}

@Test
void testAddDataPointRemoveOldDataPoint() {
long timestamp = System.currentTimeMillis();
requestRateMeter.addDataPoint(timestamp - 4000, 30);
requestRateMeter.addDataPoint(timestamp, 90);
assertEquals(90 / 3, requestRateMeter.getCurrentEstimatedRate());
}

@Test
void testRemoveOldDataPoints() {
long currentTime = System.currentTimeMillis();
requestRateMeter.addDataPoint(currentTime - 4000, 30);
requestRateMeter.addDataPoint(currentTime - 2000, 60);
requestRateMeter.addDataPoint(currentTime, 90);

assertEquals((60 + 90)/3, requestRateMeter.getCurrentEstimatedRate());
}

@Test
void testGetCurrentEstimatedRate() {
long currentTime = System.currentTimeMillis();
requestRateMeter.addDataPoint(currentTime - 2500, 30);
requestRateMeter.addDataPoint(currentTime - 1500, 60);
requestRateMeter.addDataPoint(currentTime - 500, 90);

assertEquals((30 + 60 + 90)/3, requestRateMeter.getCurrentEstimatedRate());
}

@Test
void testEmptyRateMeter() {
assertEquals(0, requestRateMeter.getCurrentEstimatedRate());
}

@Test
void testSingleDataPoint() {
requestRateMeter.addDataPoint(System.currentTimeMillis(), 30);
assertEquals(30 / 3, requestRateMeter.getCurrentEstimatedRate());
}
}
@@ -136,6 +136,18 @@ object FlintSparkConf {
.doc("max retries on failed HTTP request, 0 means retry is disabled, default is 3")
.createWithDefault(String.valueOf(FlintRetryOptions.DEFAULT_MAX_RETRIES))

val BULK_MAX_RETRIES =
FlintConfig(s"spark.datasource.flint.${FlintRetryOptions.BULK_MAX_RETRIES}")
.datasourceOption()
.doc("max retries on failed HTTP request, 0 means retry is disabled, default is 10")
.createWithDefault(String.valueOf(FlintRetryOptions.DEFAULT_BULK_MAX_RETRIES))

val BULK_INITIAL_BACKOFF =
FlintConfig(s"spark.datasource.flint.${FlintRetryOptions.BULK_INITIAL_BACKOFF}")
.datasourceOption()
.doc("initial backoff in seconds for bulk request retry, default is 4s")
.createWithDefault(String.valueOf(FlintRetryOptions.DEFAULT_BULK_INITIAL_BACKOFF))

val BULK_REQUEST_RATE_LIMIT_PER_NODE_ENABLED =
FlintConfig(
s"spark.datasource.flint.${FlintOptions.BULK_REQUEST_RATE_LIMIT_PER_NODE_ENABLED}")
@@ -368,6 +380,8 @@ case class FlintSparkConf(properties: JMap[String, String]) extends Serializable
SCHEME,
AUTH,
MAX_RETRIES,
BULK_MAX_RETRIES,
BULK_INITIAL_BACKOFF,
RETRYABLE_HTTP_STATUS_CODES,
BULK_REQUEST_RATE_LIMIT_PER_NODE_ENABLED,
BULK_REQUEST_MIN_RATE_LIMIT_PER_NODE,
@@ -11,6 +11,7 @@ import scala.collection.JavaConverters._

import org.opensearch.flint.core.FlintOptions
import org.opensearch.flint.core.http.FlintRetryOptions._
import org.scalatest.matchers.must.Matchers.contain
import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper

import org.apache.spark.FlintSuite
@@ -46,7 +47,7 @@ class FlintSparkConfSuite extends FlintSuite {
test("test retry options default values") {
val retryOptions = FlintSparkConf().flintOptions().getRetryOptions
retryOptions.getMaxRetries shouldBe DEFAULT_MAX_RETRIES
retryOptions.getRetryableHttpStatusCodes shouldBe DEFAULT_RETRYABLE_HTTP_STATUS_CODES
retryOptions.getRetryableHttpStatusCodes should contain theSameElementsAs Set(429, 500, 502)
retryOptions.getRetryableExceptionClassNames shouldBe Optional.empty
}

@@ -60,7 +61,11 @@
.getRetryOptions

retryOptions.getMaxRetries shouldBe 5
retryOptions.getRetryableHttpStatusCodes shouldBe "429,502,503,504"
retryOptions.getRetryableHttpStatusCodes should contain theSameElementsAs Set(
429,
502,
503,
504)
retryOptions.getRetryableExceptionClassNames.get() shouldBe "java.net.ConnectException"
}

