Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@
_extract_bidi_writes_redirect_proto,
)


_MAX_CHUNK_SIZE_BYTES = 2 * 1024 * 1024 # 2 MiB
_DEFAULT_FLUSH_INTERVAL_BYTES = 16 * 1024 * 1024 # 16 MiB
_BIDI_WRITE_REDIRECTED_TYPE_URL = (
Expand Down Expand Up @@ -289,8 +288,7 @@ async def _do_open():
await self.write_obj_stream.close()
except Exception as e:
logger.warning(
"Error closing previous write stream during open retry. Got exception: ",
{e},
f"Error closing previous write stream during open retry. Got exception: {e}"
)
self.write_obj_stream = None
self._is_stream_open = False
Expand Down Expand Up @@ -383,8 +381,6 @@ async def generator():
logger.info(
f"Re-opening the stream with attempt_count: {attempt_count}"
)
if self.write_obj_stream and self.write_obj_stream.is_stream_open:
await self.write_obj_stream.close()

current_metadata = list(metadata) if metadata else []
if write_state.routing_token:
Expand Down
23 changes: 23 additions & 0 deletions google/cloud/storage/asyncio/async_grpc_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
DEFAULT_CLIENT_INFO,
)
from google.cloud.storage import __version__
import grpc
from google.auth import credentials as auth_credentials


class AsyncGrpcClient:
Expand Down Expand Up @@ -52,6 +54,12 @@ def __init__(
*,
attempt_direct_path=True,
):
if isinstance(credentials, auth_credentials.AnonymousCredentials):
self._grpc_client = self._create_anonymous_client(
client_options, credentials
)
return

if client_info is None:
client_info = DEFAULT_CLIENT_INFO
client_info.client_library_version = __version__
Expand All @@ -68,6 +76,21 @@ def __init__(
attempt_direct_path=attempt_direct_path,
)

def _create_anonymous_client(self, client_options, credentials):
    """Build a storage async client over an insecure (plaintext) gRPC channel.

    Called from ``__init__`` when the supplied credentials are
    ``AnonymousCredentials``.

    Args:
        client_options: Object exposing an ``api_endpoint`` attribute (or a
            dict with an ``"api_endpoint"`` key) naming the ``host:port``
            to connect to.
        credentials: Credentials object forwarded to the gRPC transport.

    Returns:
        A ``storage_v2.StorageAsyncClient`` bound to the insecure channel.

    Raises:
        ValueError: If no API endpoint can be read from ``client_options``.
    """
    # Resolve the endpoint defensively: ``client_options`` may be None or a
    # plain dict, in which case attribute access would raise an opaque
    # AttributeError instead of telling the caller what is missing.
    if isinstance(client_options, dict):
        api_endpoint = client_options.get("api_endpoint")
    else:
        api_endpoint = getattr(client_options, "api_endpoint", None)

    if not api_endpoint:
        raise ValueError(
            "client_options.api_endpoint must be set when using "
            "AnonymousCredentials (insecure gRPC channel)."
        )

    channel = grpc.aio.insecure_channel(api_endpoint)
    transport = storage_v2.services.storage.transports.StorageGrpcAsyncIOTransport(
        channel=channel, credentials=credentials
    )
    return storage_v2.StorageAsyncClient(transport=transport)

@classmethod
def _create_insecure_grpc_client(cls, client_options):
    """Alternate constructor that uses anonymous credentials.

    Instantiates the class with a fresh ``AnonymousCredentials`` object,
    the given ``client_options``, and ``attempt_direct_path`` disabled.

    Args:
        client_options: Client options forwarded unchanged to the
            constructor.

    Returns:
        A new instance of ``cls``.
    """
    init_kwargs = {
        "credentials": auth_credentials.AnonymousCredentials(),
        "client_options": client_options,
        "attempt_direct_path": False,
    }
    return cls(**init_kwargs)

def _create_async_grpc_client(
self,
credentials=None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@
from google.cloud import _storage_v2
from google.cloud.storage._helpers import generate_random_56_bit_integer


_MAX_READ_RANGES_PER_BIDI_READ_REQUEST = 100
_BIDI_READ_REDIRECTED_TYPE_URL = (
"type.googleapis.com/google.storage.v2.BidiReadObjectRedirectedError"
Expand Down Expand Up @@ -432,7 +431,7 @@ async def generator():

if attempt_count > 1:
logger.info(
f"Resuming download (attempt {attempt_count - 1}) for {len(requests)} ranges."
f"Resuming download (attempt {attempt_count}) for {len(requests)} ranges."
)

async with lock:
Expand All @@ -453,11 +452,7 @@ async def generator():
logger.info(
f"Re-opening stream with routing token: {current_token}"
)
# Close existing stream if any
if self.read_obj_str and self.read_obj_str.is_stream_open:
await self.read_obj_str.close()

# Re-initialize stream
self.read_obj_str = _AsyncReadObjectStream(
client=self.client.grpc_client,
bucket_name=self.bucket_name,
Expand Down
1 change: 1 addition & 0 deletions noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,7 @@ def conftest_retry(session):
session.install(
"pytest",
"pytest-xdist",
"pytest-asyncio",
"grpcio",
"grpcio-status",
"grpc-google-iam-v1",
Expand Down
30 changes: 30 additions & 0 deletions tests/conformance/_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import time
import requests

def start_grpc_server(grpc_endpoint, http_endpoint):
    """Ask the testbench to start its gRPC server, retrying on failure.

    Sends ``GET {http_endpoint}/start_grpc?port=<port>`` and returns as soon
    as a response with status 200 is received. This is roughly equivalent to:

    `curl -s --retry 5 --retry-max-time 40 "http://localhost:9000/start_grpc?port=8888"`

    Args:
        grpc_endpoint: ``host:port`` string; only the part after the last
            ``:`` is used, as the port to request.
        http_endpoint: Base URL of the testbench HTTP endpoint.

    Raises:
        RuntimeError: If the time limit elapses, or if every attempt fails
            (request error or non-200 status).
    """
    max_time = 40
    retries = 5
    port = grpc_endpoint.split(":")[-1]
    url = f"{http_endpoint}/start_grpc?port={port}"

    # A monotonic clock keeps the deadline immune to wall-clock adjustments.
    deadline = time.monotonic() + max_time
    last_failure = None

    for attempt in range(1, retries + 1):
        try:
            response = requests.get(url, timeout=10)
        except requests.exceptions.RequestException as exc:
            last_failure = repr(exc)
        else:
            if response.status_code == 200:
                return
            last_failure = f"HTTP {response.status_code}"

        if time.monotonic() >= deadline:
            raise RuntimeError("Failed to start gRPC server within the time limit.")

        # Back off between attempts; sleeping after the final one is pointless.
        if attempt < retries:
            time.sleep(1)

    # Every attempt failed: surface it rather than returning as if the server
    # had started, which would only defer the failure to the first gRPC call.
    raise RuntimeError(
        f"Failed to start gRPC server after {retries} attempts "
        f"(last failure: {last_failure})."
    )
79 changes: 38 additions & 41 deletions tests/conformance/test_bidi_reads.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,21 @@
import asyncio
import io
import uuid
import grpc
import requests

from google.api_core import exceptions
from google.api_core import exceptions, client_options
from google.auth import credentials as auth_credentials
from google.cloud import _storage_v2 as storage_v2

from google.cloud.storage._experimental.asyncio.async_multi_range_downloader import (
from google.cloud.storage.asyncio.async_multi_range_downloader import (
AsyncMultiRangeDownloader,
)

from google.cloud.storage.asyncio.async_grpc_client import AsyncGrpcClient
import pytest

from tests.conformance._utils import start_grpc_server

# --- Configuration ---
PROJECT_NUMBER = "12345" # A dummy project number is fine for the testbench.
GRPC_ENDPOINT = "localhost:8888"
Expand Down Expand Up @@ -50,8 +54,11 @@ async def run_test_scenario(
retry_test_id = resp.json()["id"]

# 2. Set up downloader and metadata for fault injection.
grpc_client = AsyncGrpcClient._create_insecure_grpc_client(
client_options=client_options.ClientOptions(api_endpoint=GRPC_ENDPOINT),
)
downloader = await AsyncMultiRangeDownloader.create_mrd(
gapic_client, bucket_name, object_name
grpc_client, bucket_name, object_name
)
fault_injection_metadata = (("x-retry-test-id", retry_test_id),)

Expand Down Expand Up @@ -82,8 +89,12 @@ async def run_test_scenario(
http_client.delete(f"{HTTP_ENDPOINT}/retry_test/{retry_test_id}")


async def main():
@pytest.mark.asyncio
async def test_bidi_reads():
"""Main function to set up resources and run all test scenarios."""
start_grpc_server(
GRPC_ENDPOINT, HTTP_ENDPOINT
) # Ensure the testbench gRPC server is running before this test executes.
channel = grpc.aio.insecure_channel(GRPC_ENDPOINT)
creds = auth_credentials.AnonymousCredentials()
transport = storage_v2.services.storage.transports.StorageGrpcAsyncIOTransport(
Expand All @@ -97,42 +108,12 @@ async def main():

# Define all test scenarios
test_scenarios = [
{
"name": "Retry on Service Unavailable (503)",
"method": "storage.objects.get",
"instruction": "return-503",
"expected_error": None,
},
{
"name": "Retry on 500",
"method": "storage.objects.get",
"instruction": "return-500",
"expected_error": None,
},
{
"name": "Retry on 504",
"method": "storage.objects.get",
"instruction": "return-504",
"expected_error": None,
},
{
"name": "Retry on 429",
"method": "storage.objects.get",
"instruction": "return-429",
"expected_error": None,
},
{
"name": "Smarter Resumption: Retry 503 after partial data",
"method": "storage.objects.get",
"instruction": "return-broken-stream-after-2K",
"expected_error": None,
},
{
"name": "Retry on BidiReadObjectRedirectedError",
"method": "storage.objects.get",
"instruction": "redirect-send-handle-and-token-tokenval", # Testbench instruction for redirect
"expected_error": None,
},
]

try:
Expand Down Expand Up @@ -185,6 +166,24 @@ async def write_req_gen():
"instruction": "return-401",
"expected_error": exceptions.Unauthorized,
},
{
"name": "Retry on 500",
"method": "storage.objects.get",
"instruction": "return-500",
"expected_error": None,
},
{
"name": "Retry on 504",
"method": "storage.objects.get",
"instruction": "return-504",
"expected_error": None,
},
{
"name": "Retry on 429",
"method": "storage.objects.get",
"instruction": "return-429",
"expected_error": None,
},
]
for scenario in open_test_scenarios:
await run_open_test_scenario(
Expand Down Expand Up @@ -227,15 +226,17 @@ async def run_open_test_scenario(
resp = http_client.post(f"{HTTP_ENDPOINT}/retry_test", json=retry_test_config)
resp.raise_for_status()
retry_test_id = resp.json()["id"]
print(f"Retry Test created with ID: {retry_test_id}")

# 2. Set up metadata for fault injection.
fault_injection_metadata = (("x-retry-test-id", retry_test_id),)

# 3. Execute the open (via create_mrd) and assert the outcome.
try:
grpc_client = AsyncGrpcClient._create_insecure_grpc_client(
client_options=client_options.ClientOptions(api_endpoint=GRPC_ENDPOINT),
)
downloader = await AsyncMultiRangeDownloader.create_mrd(
gapic_client,
grpc_client,
bucket_name,
object_name,
metadata=fault_injection_metadata,
Expand All @@ -260,7 +261,3 @@ async def run_open_test_scenario(
# 4. Clean up the Retry Test resource.
if retry_test_id:
http_client.delete(f"{HTTP_ENDPOINT}/retry_test/{retry_test_id}")


if __name__ == "__main__":
asyncio.run(main())
Loading