Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 109 additions & 3 deletions src/wavespeed/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@

from wavespeed.config import api as api_config

# HTTP status codes that are safe to retry: 429 (rate limited) plus the
# transient server/gateway errors 500, 502, 503 and 504.
_RETRYABLE_STATUS_CODES = {429, 500, 502, 503, 504}


class Client:
"""WaveSpeed API client.
Expand Down Expand Up @@ -57,6 +60,18 @@ def __init__(
retry_interval if retry_interval is not None else api_config.retry_interval
)

@staticmethod
def _is_retryable_status(status_code: int) -> bool:
    """Decide whether an HTTP response code warrants another attempt.

    Args:
        status_code: Status code from an HTTP response.

    Returns:
        True for rate limiting (429) and transient server errors
        (500/502/503/504); False for everything else.
    """
    return any(status_code == code for code in _RETRYABLE_STATUS_CODES)

def _get_headers(self) -> dict[str, str]:
"""Get request headers with authentication."""
if not self.api_key:
Expand Down Expand Up @@ -106,13 +121,33 @@ def _submit(
)
timeouts = (connect_timeout, request_timeout)

last_error: Exception | None = None

for retry in range(self.max_connection_retries + 1):
try:
response = requests.post(
url, json=body, headers=self._get_headers(), timeout=timeouts
)

if response.status_code != 200:
# Retry on transient server errors (5xx) and rate limiting (429)
if self._is_retryable_status(response.status_code):
last_error = RuntimeError(
f"Failed to submit prediction: HTTP {response.status_code}: "
f"{response.text}"
)
if retry < self.max_connection_retries:
delay = self.retry_interval * (retry + 1)
print(
f"Server error (HTTP {response.status_code}) on attempt "
f"{retry + 1}/{self.max_connection_retries + 1}, "
f"retrying in {delay} seconds..."
)
time.sleep(delay)
continue
raise last_error

# Non-retryable HTTP errors (4xx etc.) fail immediately
raise RuntimeError(
f"Failed to submit prediction: HTTP {response.status_code}: "
f"{response.text}"
Expand All @@ -133,6 +168,7 @@ def _submit(
requests.exceptions.ConnectionError,
requests.exceptions.Timeout,
) as e:
last_error = e
print(
f"Connection error on attempt {retry + 1}/{self.max_connection_retries + 1}:"
)
Expand All @@ -147,6 +183,11 @@ def _submit(
f"Failed to submit prediction after {self.max_connection_retries + 1} attempts"
) from e

# Should not reach here, but guard against it
raise last_error or RuntimeError(
f"Failed to submit prediction after {self.max_connection_retries + 1} attempts"
)

def _get_result(
self, request_id: str, timeout: float | None = None
) -> dict[str, Any]:
Expand All @@ -171,13 +212,32 @@ def _get_result(
)
timeouts = (connect_timeout, request_timeout)

last_error: Exception | None = None

for retry in range(self.max_connection_retries + 1):
try:
response = requests.get(
url, headers=self._get_headers(), timeout=timeouts
)

if response.status_code != 200:
# Retry on transient server errors (5xx) and rate limiting (429)
if self._is_retryable_status(response.status_code):
last_error = RuntimeError(
f"Failed to get result for task {request_id}: "
f"HTTP {response.status_code}: {response.text}"
)
if retry < self.max_connection_retries:
delay = self.retry_interval * (retry + 1)
print(
f"Server error (HTTP {response.status_code}) getting result "
f"on attempt {retry + 1}/{self.max_connection_retries + 1}, "
f"retrying in {delay} seconds..."
)
time.sleep(delay)
continue
raise last_error

raise RuntimeError(
f"Failed to get result for task {request_id}: "
f"HTTP {response.status_code}: {response.text}"
Expand All @@ -189,6 +249,7 @@ def _get_result(
requests.exceptions.ConnectionError,
requests.exceptions.Timeout,
) as e:
last_error = e
print(
f"Connection error getting result on attempt {retry + 1}/{self.max_connection_retries + 1}:"
)
Expand All @@ -204,6 +265,12 @@ def _get_result(
f"after {self.max_connection_retries + 1} attempts"
) from e

# Should not reach here, but guard against it
raise last_error or RuntimeError(
f"Failed to get result for task {request_id} "
f"after {self.max_connection_retries + 1} attempts"
)

def _wait(
self,
request_id: str,
Expand Down Expand Up @@ -261,7 +328,12 @@ def _is_retryable_error(self, error: Exception) -> bool:
"""
# Always retry timeout and connection errors
if isinstance(
error, (requests.exceptions.Timeout, requests.exceptions.ConnectionError)
error,
(
requests.exceptions.Timeout,
requests.exceptions.ConnectionError,
TimeoutError,
),
):
return True

Expand Down Expand Up @@ -291,6 +363,8 @@ def run(
timeout: Maximum time to wait for completion (None = no timeout).
poll_interval: Interval between status checks in seconds.
enable_sync_mode: If True, use synchronous mode (single request).
If sync mode fails with a gateway timeout (HTTP 502/504),
the SDK automatically falls back to async mode (submit + poll).
max_retries: Maximum task-level retries (overrides client setting).

Returns:
Expand All @@ -303,14 +377,19 @@ def run(
"""
task_retries = max_retries if max_retries is not None else self.max_retries
last_error = None
# Track whether we should fall back from sync to async mode.
# This happens when sync mode hits a gateway timeout (502/504) after
# exhausting connection-level retries — the gateway cannot hold the
# connection long enough, but the backend may still be healthy.
use_sync = enable_sync_mode

for attempt in range(task_retries + 1):
try:
request_id, sync_result = self._submit(
model, input, enable_sync_mode=enable_sync_mode, timeout=timeout
model, input, enable_sync_mode=use_sync, timeout=timeout
)

if enable_sync_mode:
if use_sync:
# In sync mode, extract outputs from the result
status = sync_result.get("data", {}).get("status")
if status != "completed":
Expand All @@ -328,6 +407,18 @@ def run(

except Exception as e:
last_error = e

# Sync-to-async fallback: if sync mode got a gateway timeout
# (502/504) after all connection retries, switch to async mode
# and retry immediately without consuming a task-level retry.
if use_sync and self._is_gateway_timeout(e):
print(
"Sync mode hit gateway timeout, "
"falling back to async mode (submit + poll)..."
)
use_sync = False
continue

is_retryable = self._is_retryable_error(e)

if not is_retryable or attempt >= task_retries:
Expand All @@ -343,6 +434,21 @@ def run(
raise last_error
raise RuntimeError(f"All {task_retries + 1} attempts failed")

@staticmethod
def _is_gateway_timeout(error: Exception) -> bool:
"""Check if an error is a gateway timeout (HTTP 502 or 504).

Args:
error: The exception to check.

Returns:
True if the error indicates a gateway timeout.
"""
if isinstance(error, RuntimeError):
error_str = str(error)
return "HTTP 502" in error_str or "HTTP 504" in error_str
return False

def upload(self, file: str | BinaryIO, *, timeout: float | None = None) -> str:
"""Upload a file to WaveSpeed.

Expand Down
Loading