Skip to content

Commit ec19dfc

Browse files
fix: remaining issues with subscriber client (#43)
* fix: Remaining issues with subscriber client. Fix make_subscriber to defer GRPC client creation. Fix retrying_connection to not let __aenter__ return until successful initialization or permanent failure. * chore: reformat
1 parent a037d0b commit ec19dfc

File tree

3 files changed

+30
-49
lines changed

3 files changed

+30
-49
lines changed

google/cloud/pubsublite/cloudpubsub/make_subscriber.py

+13-12
Original file line numberDiff line numberDiff line change
@@ -88,14 +88,18 @@ def assignment_connection_factory(
8888

8989
def _make_partition_subscriber_factory(
9090
subscription: SubscriptionPath,
91-
subscribe_client: SubscriberServiceAsyncClient,
92-
cursor_client: CursorServiceAsyncClient,
91+
client_options: ClientOptions,
92+
credentials: Optional[Credentials],
9393
base_metadata: Optional[Mapping[str, str]],
9494
flow_control_settings: FlowControlSettings,
9595
nack_handler: NackHandler,
9696
message_transformer: MessageTransformer,
9797
) -> PartitionSubscriberFactory:
9898
def factory(partition: Partition) -> AsyncSubscriber:
99+
subscribe_client = SubscriberServiceAsyncClient(
100+
credentials=credentials, client_options=client_options
101+
) # type: ignore
102+
cursor_client = CursorServiceAsyncClient(credentials=credentials, client_options=client_options) # type: ignore
99103
final_metadata = merge_metadata(
100104
base_metadata, subscription_routing_metadata(subscription, partition)
101105
)
@@ -174,25 +178,22 @@ def make_async_subscriber(
174178
if fixed_partitions:
175179
assigner_factory = lambda: FixedSetAssigner(fixed_partitions) # noqa: E731
176180
else:
177-
assignment_client = PartitionAssignmentServiceAsyncClient(
178-
credentials=credentials, client_options=client_options
179-
) # type: ignore
180181
assigner_factory = lambda: _make_dynamic_assigner( # noqa: E731
181-
subscription, assignment_client, metadata
182+
subscription,
183+
PartitionAssignmentServiceAsyncClient(
184+
credentials=credentials, client_options=client_options
185+
),
186+
metadata,
182187
)
183188

184-
subscribe_client = SubscriberServiceAsyncClient(
185-
credentials=credentials, client_options=client_options
186-
) # type: ignore
187-
cursor_client = CursorServiceAsyncClient(credentials=credentials, client_options=client_options) # type: ignore
188189
if nack_handler is None:
189190
nack_handler = DefaultNackHandler()
190191
if message_transformer is None:
191192
message_transformer = DefaultMessageTransformer()
192193
partition_subscriber_factory = _make_partition_subscriber_factory(
193194
subscription,
194-
subscribe_client,
195-
cursor_client,
195+
client_options,
196+
credentials,
196197
metadata,
197198
per_partition_flow_control_settings,
198199
nack_handler,

google/cloud/pubsublite/internal/wire/retrying_connection.py

+4
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ class RetryingConnection(Connection[Request, Response], PermanentFailable):
2424

2525
_connection_factory: ConnectionFactory[Request, Response]
2626
_reinitializer: ConnectionReinitializer[Request, Response]
27+
_initialized_once: asyncio.Event
2728

2829
_loop_task: asyncio.Future
2930

@@ -38,11 +39,13 @@ def __init__(
3839
super().__init__()
3940
self._connection_factory = connection_factory
4041
self._reinitializer = reinitializer
42+
self._initialized_once = asyncio.Event()
4143
self._write_queue = asyncio.Queue(maxsize=1)
4244
self._read_queue = asyncio.Queue(maxsize=1)
4345

4446
async def __aenter__(self):
4547
self._loop_task = asyncio.ensure_future(self._run_loop())
48+
await self.await_unless_failed(self._initialized_once.wait())
4649
return self
4750

4851
async def __aexit__(self, exc_type, exc_val, exc_tb):
@@ -76,6 +79,7 @@ async def _run_loop(self):
7679
self._read_queue = asyncio.Queue(maxsize=1)
7780
self._write_queue = asyncio.Queue(maxsize=1)
7881
await self._reinitializer.reinitialize(connection)
82+
self._initialized_once.set()
7983
bad_retries = 0
8084
await self._loop_connection(connection)
8185
except GoogleAPICallError as e:

tests/unit/pubsublite/internal/wire/retrying_connection_test.py

+13-37
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import asyncio
2-
from typing import Union
32

43
from asynctest.mock import MagicMock, CoroutineMock
54
import pytest
@@ -15,6 +14,7 @@
1514
RetryingConnection,
1615
_MIN_BACKOFF_SECS,
1716
)
17+
from google.cloud.pubsublite.testing.test_utils import wire_queues
1818

1919
# All test coroutines will be treated as marked.
2020
pytestmark = pytest.mark.asyncio
@@ -41,7 +41,7 @@ def connection_factory(default_connection):
4141

4242
@pytest.fixture()
4343
def retrying_connection(connection_factory, reinitializer):
44-
return RetryingConnection[int, int](connection_factory, reinitializer)
44+
return RetryingConnection(connection_factory, reinitializer)
4545

4646

4747
@pytest.fixture
@@ -55,40 +55,27 @@ def asyncio_sleep(monkeypatch):
5555
async def test_permanent_error_on_reinitializer(
5656
retrying_connection: Connection[int, int], reinitializer, default_connection
5757
):
58-
fut = asyncio.Future()
59-
reinitialize_called = asyncio.Future()
60-
6158
async def reinit_action(conn):
6259
assert conn == default_connection
63-
reinitialize_called.set_result(None)
64-
return await fut
60+
raise InvalidArgument("abc")
6561

6662
reinitializer.reinitialize.side_effect = reinit_action
67-
async with retrying_connection as _:
68-
await reinitialize_called
69-
reinitializer.reinitialize.assert_called_once()
70-
fut.set_exception(InvalidArgument("abc"))
71-
with pytest.raises(InvalidArgument):
72-
await retrying_connection.read()
63+
with pytest.raises(InvalidArgument):
64+
async with retrying_connection as _:
65+
pass
7366

7467

7568
async def test_successful_reinitialize(
7669
retrying_connection: Connection[int, int], reinitializer, default_connection
7770
):
78-
fut = asyncio.Future()
79-
reinitialize_called = asyncio.Future()
80-
8171
async def reinit_action(conn):
8272
assert conn == default_connection
83-
reinitialize_called.set_result(None)
84-
return await fut
73+
return None
74+
75+
default_connection.read.return_value = 1
8576

8677
reinitializer.reinitialize.side_effect = reinit_action
8778
async with retrying_connection as _:
88-
await reinitialize_called
89-
reinitializer.reinitialize.assert_called_once()
90-
fut.set_result(None)
91-
default_connection.read.return_value = 1
9279
assert await retrying_connection.read() == 1
9380
assert (
9481
default_connection.read.call_count == 2
@@ -111,26 +98,15 @@ async def test_reinitialize_after_retryable(
11198
default_connection,
11299
asyncio_sleep,
113100
):
114-
reinit_called = asyncio.Queue()
115-
reinit_results: "asyncio.Queue[Union[None, Exception]]" = asyncio.Queue()
101+
reinit_queues = wire_queues(reinitializer.reinitialize)
116102

117-
async def reinit_action(conn):
118-
assert conn == default_connection
119-
await reinit_called.put(None)
120-
result = await reinit_results.get()
121-
if isinstance(result, Exception):
122-
raise result
103+
default_connection.read.return_value = 1
123104

124-
reinitializer.reinitialize.side_effect = reinit_action
105+
await reinit_queues.results.put(InternalServerError("abc"))
106+
await reinit_queues.results.put(None)
125107
async with retrying_connection as _:
126-
await reinit_called.get()
127-
reinitializer.reinitialize.assert_called_once()
128-
await reinit_results.put(InternalServerError("abc"))
129-
await reinit_called.get()
130108
asyncio_sleep.assert_called_once_with(_MIN_BACKOFF_SECS)
131109
assert reinitializer.reinitialize.call_count == 2
132-
await reinit_results.put(None)
133-
default_connection.read.return_value = 1
134110
assert await retrying_connection.read() == 1
135111
assert (
136112
default_connection.read.call_count == 2

0 commit comments

Comments
 (0)