diff --git a/src/wavespeed/api/client.py b/src/wavespeed/api/client.py
index c7d7ff4..1668198 100644
--- a/src/wavespeed/api/client.py
+++ b/src/wavespeed/api/client.py
@@ -9,6 +9,9 @@
 from wavespeed.config import api as api_config
 
+# HTTP status codes that are safe to retry
+_RETRYABLE_STATUS_CODES = {429, 500, 502, 503, 504}
+
 
 class Client:
     """WaveSpeed API client.
@@ -57,6 +60,18 @@ def __init__(
             retry_interval if retry_interval is not None else api_config.retry_interval
         )
 
+    @staticmethod
+    def _is_retryable_status(status_code: int) -> bool:
+        """Check if an HTTP status code is retryable.
+
+        Args:
+            status_code: HTTP response status code.
+
+        Returns:
+            True if the status code indicates a transient error worth retrying.
+        """
+        return status_code in _RETRYABLE_STATUS_CODES
+
     def _get_headers(self) -> dict[str, str]:
         """Get request headers with authentication."""
         if not self.api_key:
@@ -106,6 +121,8 @@ def _submit(
         )
         timeouts = (connect_timeout, request_timeout)
 
+        last_error: Exception | None = None
+
         for retry in range(self.max_connection_retries + 1):
             try:
                 response = requests.post(
@@ -113,6 +130,24 @@ def _submit(
                 )
                 if response.status_code != 200:
+                    # Retry on transient server errors (500/502/503/504) and rate limiting (429)
+                    if self._is_retryable_status(response.status_code):
+                        last_error = RuntimeError(
+                            f"Failed to submit prediction: HTTP {response.status_code}: "
+                            f"{response.text}"
+                        )
+                        if retry < self.max_connection_retries:
+                            delay = self.retry_interval * (retry + 1)
+                            print(
+                                f"Server error (HTTP {response.status_code}) on attempt "
+                                f"{retry + 1}/{self.max_connection_retries + 1}, "
+                                f"retrying in {delay} seconds..."
+                            )
+                            time.sleep(delay)
+                            continue
+                        raise last_error
+
+                    # Non-retryable HTTP errors (4xx etc.) fail immediately
                     raise RuntimeError(
                         f"Failed to submit prediction: HTTP {response.status_code}: "
                         f"{response.text}"
                     )
@@ -133,6 +168,7 @@ def _submit(
                 requests.exceptions.ConnectionError,
                 requests.exceptions.Timeout,
             ) as e:
+                last_error = e
                 print(
                     f"Connection error on attempt {retry + 1}/{self.max_connection_retries + 1}:"
                 )
@@ -147,6 +183,11 @@ def _submit(
                     f"Failed to submit prediction after {self.max_connection_retries + 1} attempts"
                 ) from e
 
+        # Should not reach here, but guard against it
+        raise last_error or RuntimeError(
+            f"Failed to submit prediction after {self.max_connection_retries + 1} attempts"
+        )
+
     def _get_result(
         self, request_id: str, timeout: float | None = None
     ) -> dict[str, Any]:
@@ -171,6 +212,8 @@ def _get_result(
         )
         timeouts = (connect_timeout, request_timeout)
 
+        last_error: Exception | None = None
+
         for retry in range(self.max_connection_retries + 1):
             try:
                 response = requests.get(
@@ -178,6 +221,23 @@ def _get_result(
                 )
                 if response.status_code != 200:
+                    # Retry on transient server errors (500/502/503/504) and rate limiting (429)
+                    if self._is_retryable_status(response.status_code):
+                        last_error = RuntimeError(
+                            f"Failed to get result for task {request_id}: "
+                            f"HTTP {response.status_code}: {response.text}"
+                        )
+                        if retry < self.max_connection_retries:
+                            delay = self.retry_interval * (retry + 1)
+                            print(
+                                f"Server error (HTTP {response.status_code}) getting result "
+                                f"on attempt {retry + 1}/{self.max_connection_retries + 1}, "
+                                f"retrying in {delay} seconds..."
+                            )
+                            time.sleep(delay)
+                            continue
+                        raise last_error
+
                     raise RuntimeError(
                         f"Failed to get result for task {request_id}: "
                         f"HTTP {response.status_code}: {response.text}"
                     )
@@ -189,6 +249,7 @@ def _get_result(
                 requests.exceptions.ConnectionError,
                 requests.exceptions.Timeout,
             ) as e:
+                last_error = e
                 print(
                     f"Connection error getting result on attempt {retry + 1}/{self.max_connection_retries + 1}:"
                 )
@@ -204,6 +265,12 @@ def _get_result(
                     f"after {self.max_connection_retries + 1} attempts"
                 ) from e
 
+        # Should not reach here, but guard against it
+        raise last_error or RuntimeError(
+            f"Failed to get result for task {request_id} "
+            f"after {self.max_connection_retries + 1} attempts"
+        )
+
     def _wait(
         self,
         request_id: str,
@@ -261,7 +328,12 @@ def _is_retryable_error(self, error: Exception) -> bool:
         """
         # Always retry timeout and connection errors
         if isinstance(
-            error, (requests.exceptions.Timeout, requests.exceptions.ConnectionError)
+            error,
+            (
+                requests.exceptions.Timeout,
+                requests.exceptions.ConnectionError,
+                TimeoutError,
+            ),
         ):
             return True
 
@@ -291,6 +363,8 @@ def run(
             timeout: Maximum time to wait for completion (None = no timeout).
             poll_interval: Interval between status checks in seconds.
             enable_sync_mode: If True, use synchronous mode (single request).
+                If sync mode fails with a gateway timeout (HTTP 502/504),
+                the SDK automatically falls back to async mode (submit + poll).
             max_retries: Maximum task-level retries (overrides client setting).
 
         Returns:
@@ -303,14 +377,19 @@ def run(
         """
        task_retries = max_retries if max_retries is not None else self.max_retries
        last_error = None
+        # Track whether we should fall back from sync to async mode.
+        # This happens when sync mode hits a gateway timeout (502/504) after
+        # exhausting connection-level retries: the gateway cannot hold the
+        # connection open long enough, but the backend may still be healthy.
+        use_sync = enable_sync_mode
 
        for attempt in range(task_retries + 1):
            try:
                request_id, sync_result = self._submit(
-                    model, input, enable_sync_mode=enable_sync_mode, timeout=timeout
+                    model, input, enable_sync_mode=use_sync, timeout=timeout
                )
 
-                if enable_sync_mode:
+                if use_sync:
                    # In sync mode, extract outputs from the result
                    status = sync_result.get("data", {}).get("status")
                    if status != "completed":
@@ -328,6 +407,18 @@ def run(
            except Exception as e:
                last_error = e
+
+                # Sync-to-async fallback: if sync mode hit a gateway timeout
+                # (502/504) after all connection retries, switch to async mode
+                # and retry at once (this still consumes a task-level attempt).
+                if use_sync and self._is_gateway_timeout(e):
+                    print(
+                        "Sync mode hit gateway timeout, "
+                        "falling back to async mode (submit + poll)..."
+                    )
+                    use_sync = False
+                    continue
+
                is_retryable = self._is_retryable_error(e)
 
                if not is_retryable or attempt >= task_retries:
@@ -343,6 +434,21 @@ def run(
            raise last_error
        raise RuntimeError(f"All {task_retries + 1} attempts failed")
 
+    @staticmethod
+    def _is_gateway_timeout(error: Exception) -> bool:
+        """Check if an error is a gateway timeout (HTTP 502 or 504).
+
+        Args:
+            error: The exception to check.
+
+        Returns:
+            True if the error indicates a gateway timeout.
+        """
+        if isinstance(error, RuntimeError):
+            error_str = str(error)
+            return "HTTP 502" in error_str or "HTTP 504" in error_str
+        return False
+
    def upload(self, file: str | BinaryIO, *, timeout: float | None = None) -> str:
        """Upload a file to WaveSpeed.
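
For reviewers, a minimal sketch of what the new behavior looks like from the calling side. The constructor arguments and model id below are assumptions for illustration only (the diff shows the client carries `api_key`, `max_retries`, `retry_interval`, and `max_connection_retries`, but not the constructor signature); the retry semantics are the ones added above.

```python
from wavespeed.api.client import Client

# Hypothetical constructor arguments, named after the attributes the diff
# references (self.api_key, self.max_retries, self.retry_interval).
client = Client(
    api_key="ws-...",
    max_retries=2,       # task-level retries used by run()
    retry_interval=1.0,  # connection retry N waits retry_interval * N seconds
)

# On HTTP 429/500/502/503/504 each request is retried with linear backoff
# (1.0s, then 2.0s, ... with the settings above). If sync mode still ends
# in a gateway timeout (502/504), run() falls back to async submit-and-poll
# instead of failing the task.
result = client.run(
    "some-org/some-model",               # hypothetical model id
    {"prompt": "a lighthouse at dusk"},  # hypothetical model input
    enable_sync_mode=True,
    timeout=300,
)
```

One design point worth noting in review: `_is_gateway_timeout` matches the literal substrings `"HTTP 502"` / `"HTTP 504"` in the `RuntimeError` messages that `_submit` and `_get_result` format, so the fallback only fires for errors this client raised itself, and rewording those messages would silently disable it.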