Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 67 additions & 9 deletions packages/smithy-core/src/smithy_core/aio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from ..auth import AuthParams
from ..deserializers import DeserializeableShape, ShapeDeserializer
from ..endpoints import EndpointResolverParams
from ..exceptions import ClientTimeoutError, RetryError, SmithyError
from ..exceptions import CallError, ClientTimeoutError, RetryError, SmithyError
from ..interceptors import (
InputContext,
Interceptor,
Expand All @@ -23,6 +23,7 @@
from ..interfaces import Endpoint, TypedProperties
from ..interfaces.auth import AuthOption, AuthSchemeResolver
from ..interfaces.retries import RetryStrategy
from ..retries import AdaptiveRetryStrategy
from ..schemas import APIOperation
from ..serializers import SerializeableShape
from ..shapes import ShapeID
Expand Down Expand Up @@ -338,16 +339,47 @@ async def _retry[I: SerializeableShape, O: DeserializeableShape](
if retry_token.retry_delay:
await sleep(retry_token.retry_delay)

output_context = await self._handle_attempt(
call,
replace(
request_context,
transport_request=copy(request_context.transport_request),
),
request_future,
)
try:
# Rate limiting before request (adaptive only)
await self._handle_pre_request_rate_limiting(retry_strategy)

output_context = await self._handle_attempt(
call,
replace(
request_context,
transport_request=copy(request_context.transport_request),
),
request_future,
)
except TimeoutError as timeout_error:
error = CallError(
fault="client",
message=str(timeout_error),
is_retry_safe=True, # Make it retryable
)

# Token acquisition timeout will be treated as retryable error
try:
retry_token = retry_strategy.refresh_retry_token_for_retry(
token_to_renew=retry_token,
error=error,
)
except RetryError:
raise timeout_error

_LOGGER.debug(
"Token acquisition timeout. Attempting request #%s in %.4f seconds.",
retry_token.retry_count + 1,
retry_token.retry_delay,
)
continue # Skip to next retry iteration

if isinstance(output_context.response, Exception):
# Update rate limiter after failed response (adaptive only)
await self._handle_post_error_response_rate_limiting(
retry_strategy, output_context.response
)

try:
retry_token = retry_strategy.refresh_retry_token_for_retry(
token_to_renew=retry_token,
Expand All @@ -364,9 +396,35 @@ async def _retry[I: SerializeableShape, O: DeserializeableShape](

await seek(request_context.transport_request.body, 0)
else:
# Update rate limiter after successful response (adaptive only)
await self._handle_success_rate_limiting(retry_strategy)
retry_strategy.record_success(token=retry_token)
return output_context

async def _handle_pre_request_rate_limiting(
    self, retry_strategy: RetryStrategy
) -> None:
    """Apply client-side rate limiting before a request is sent.

    Only adaptive retry strategies participate in rate limiting; every
    other strategy makes this a no-op.
    """
    if not isinstance(retry_strategy, AdaptiveRetryStrategy):
        return
    await retry_strategy.acquire_from_token_bucket()

async def _handle_post_error_response_rate_limiting(
    self, retry_strategy: RetryStrategy, error: Exception
) -> None:
    """Feed a failed response back into the adaptive rate limiter.

    Non-adaptive strategies carry no rate limiter, so they are skipped.
    Whether the failure was a throttling error determines how the
    limiter adjusts its send rate.
    """
    if not isinstance(retry_strategy, AdaptiveRetryStrategy):
        return
    throttled = retry_strategy.is_throttling_error(error)
    await retry_strategy.rate_limiter.after_receiving_response(throttled)

async def _handle_success_rate_limiting(
    self, retry_strategy: RetryStrategy
) -> None:
    """Feed a successful response back into the adaptive rate limiter.

    Non-adaptive strategies carry no rate limiter, so they are skipped.
    """
    if not isinstance(retry_strategy, AdaptiveRetryStrategy):
        return
    await retry_strategy.rate_limiter.after_receiving_response(
        throttling_error=False
    )

async def _handle_attempt[I: SerializeableShape, O: DeserializeableShape](
self,
call: ClientCall[I, O],
Expand Down
62 changes: 59 additions & 3 deletions packages/smithy-core/src/smithy_core/retries.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from .interfaces import retries as retries_interface
from .interfaces.retries import RetryStrategy

RetryStrategyType = Literal["simple", "standard"]
RetryStrategyType = Literal["simple", "standard", "adaptive"]


@dataclass(kw_only=True, frozen=True)
Expand Down Expand Up @@ -69,6 +69,8 @@ def _create_retry_strategy(
return SimpleRetryStrategy(**filtered_kwargs)
case "standard":
return StandardRetryStrategy(**filtered_kwargs)
case "adaptive":
return AdaptiveRetryStrategy(**filtered_kwargs)
case _:
raise ValueError(f"Unknown retry mode: {retry_mode}")

Expand Down Expand Up @@ -665,8 +667,9 @@ def calculate_throttled_request_rate(
:param timestamp: Timestamp of the response.
:return: New calculated request rate based on CUBIC throttling.
"""
calculated_rate = rate_to_use * self._BETA
self._last_max_rate = rate_to_use
self.calculate_and_update_inflection_point()
calculated_rate = rate_to_use * self._BETA
self._last_throttle_time = timestamp
return calculated_rate

Expand Down Expand Up @@ -802,7 +805,6 @@ async def after_receiving_response(self, throttling_error: bool) -> None:
fill_rate = self._token_bucket.fill_rate
rate_to_use = min(measured_rate, fill_rate)

self._cubic_calculator.calculate_and_update_inflection_point()
cubic_calculated_rate = (
self._cubic_calculator.calculate_throttled_request_rate(
rate_to_use, timestamp
Expand All @@ -820,3 +822,57 @@ async def after_receiving_response(self, throttling_error: bool) -> None:
@property
def rate_limit_enabled(self) -> bool:
return self._rate_limiter_enabled


class AdaptiveRetryStrategy(StandardRetryStrategy):
    """Adaptive retry strategy with client-side rate limiting using CUBIC algorithm.

    Extends StandardRetryStrategy with a token-bucket rate limiter driven
    by CUBIC congestion control. Rate limiting activates after the first
    throttling response and the sending rate then adapts to each
    subsequent response.
    """

    # Initial maximum request rate used to seed the CUBIC calculator.
    STARTING_MAX_RATE = 0.5

    def __init__(self, *, rate_limiter: ClientRateLimiter | None = None, **kwargs):  # type: ignore
        """Initialize AdaptiveRetryStrategy.

        :param rate_limiter: Optional pre-configured rate limiter. If None,
            default components are created with rate limiting initially
            disabled.
        """
        super().__init__(**kwargs)  # type: ignore

        if rate_limiter is not None:
            self._rate_limiter = rate_limiter
            return

        # No limiter supplied: assemble the default components. The limiter
        # stays disabled until the first throttling response is observed.
        self._rate_limiter = ClientRateLimiter(
            token_bucket=TokenBucket(),
            cubic_calculator=CubicCalculator(
                starting_max_rate=self.STARTING_MAX_RATE,
                start_time=time.monotonic(),
            ),
            rate_tracker=RequestRateTracker(),
            rate_limiter_enabled=False,  # Disabled until first throttle
        )

    @property
    def rate_limiter(self) -> ClientRateLimiter:
        """The rate limiter used by the request pipeline for this strategy."""
        return self._rate_limiter

    def is_throttling_error(self, error: Exception) -> bool:
        """Return True if *error* advertises itself as a throttling error.

        Relies on the existing ErrorRetryInfo interface; anything that does
        not implement it is treated as non-throttling.
        """
        return (
            isinstance(error, retries_interface.ErrorRetryInfo)
            and error.is_throttling_error
        )

    async def acquire_from_token_bucket(self) -> None:
        """Wait for send capacity, but only once rate limiting is enabled."""
        limiter = self._rate_limiter
        if limiter.rate_limit_enabled:
            await limiter.before_sending_request()

    def __deepcopy__(self, memo: Any) -> "AdaptiveRetryStrategy":
        # Override return type from StandardRetryStrategy to
        # AdaptiveRetryStrategy. The instance is intentionally shared rather
        # than copied so rate-limiter state persists across copies.
        return self
Loading
Loading