Skip to content

Commit 03db702

Browse files
fix: ensure ack() doesn't wait on stream messages (#234)
* fix: ensure ack() doesn't wait on stream messages
  also fix error propagation to streaming pull future
* fix: ensure ack() doesn't wait on stream messages
  also fix error propagation to streaming pull future
* fix: ensure ack() doesn't wait on stream messages
  also fix error propagation to streaming pull future
* fix: remove debug log
1 parent 435ad27 commit 03db702

15 files changed

+146 -178 lines changed

google/cloud/pubsublite/cloudpubsub/internal/ack_set_tracker.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def track(self, offset: int):
3535
"""
3636

3737
@abstractmethod
38-
async def ack(self, offset: int):
38+
def ack(self, offset: int):
3939
"""
4040
Acknowledge the message with the provided offset. The offset must have previously been tracked.
4141

google/cloud/pubsublite/cloudpubsub/internal/ack_set_tracker_impl.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from typing import Optional
1818

1919
from google.api_core.exceptions import FailedPrecondition
20+
2021
from google.cloud.pubsublite.cloudpubsub.internal.ack_set_tracker import AckSetTracker
2122
from google.cloud.pubsublite.internal.wire.committer import Committer
2223
from google.cloud.pubsublite_v1 import Cursor
@@ -43,9 +44,7 @@ def track(self, offset: int):
4344
)
4445
self._receipts.append(offset)
4546

46-
async def ack(self, offset: int):
47-
# Note: put_nowait is used here and below to ensure that the below logic is executed without yielding
48-
# to another coroutine in the event loop. The queue is unbounded so it will never throw.
47+
def ack(self, offset: int):
4948
self._acks.put_nowait(offset)
5049
prefix_acked_offset: Optional[int] = None
5150
while len(self._receipts) != 0 and not self._acks.empty():
@@ -60,7 +59,7 @@ async def ack(self, offset: int):
6059
if prefix_acked_offset is None:
6160
return
6261
# Convert from last acked to first unacked.
63-
await self._committer.commit(Cursor(offset=prefix_acked_offset + 1))
62+
self._committer.commit(Cursor(offset=prefix_acked_offset + 1))
6463

6564
async def clear_and_commit(self):
6665
self._receipts.clear()

google/cloud/pubsublite/cloudpubsub/internal/assigning_subscriber.py

+9-4
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,10 @@
2020
from google.cloud.pubsublite.cloudpubsub.internal.single_subscriber import (
2121
AsyncSingleSubscriber,
2222
)
23-
from google.cloud.pubsublite.internal.wait_ignore_cancelled import wait_ignore_cancelled
23+
from google.cloud.pubsublite.internal.wait_ignore_cancelled import (
24+
wait_ignore_cancelled,
25+
wait_ignore_errors,
26+
)
2427
from google.cloud.pubsublite.internal.wire.assigner import Assigner
2528
from google.cloud.pubsublite.internal.wire.permanent_failable import PermanentFailable
2629
from google.cloud.pubsublite.types import Partition
@@ -100,8 +103,10 @@ async def __aenter__(self):
100103

101104
async def __aexit__(self, exc_type, exc_value, traceback):
102105
self._assign_poller.cancel()
103-
await wait_ignore_cancelled(self._assign_poller)
104-
await self._assigner.__aexit__(exc_type, exc_value, traceback)
106+
await wait_ignore_errors(self._assign_poller)
107+
await wait_ignore_errors(
108+
self._assigner.__aexit__(exc_type, exc_value, traceback)
109+
)
105110
for running in self._subscribers.values():
106-
await self._stop_subscriber(running)
111+
await wait_ignore_errors(self._stop_subscriber(running))
107112
pass

google/cloud/pubsublite/cloudpubsub/internal/single_partition_subscriber.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -126,8 +126,8 @@ async def read(self) -> List[Message]:
126126
self.fail(e)
127127
raise e
128128

129-
async def _handle_ack(self, message: requests.AckRequest):
130-
await self._underlying.allow_flow(
129+
def _handle_ack(self, message: requests.AckRequest):
130+
self._underlying.allow_flow(
131131
FlowControlRequest(
132132
allowed_messages=1,
133133
allowed_bytes=self._messages_by_ack_id[message.ack_id].size_bytes,
@@ -138,7 +138,7 @@ async def _handle_ack(self, message: requests.AckRequest):
138138
ack_id = _AckId.parse(message.ack_id)
139139
if ack_id.generation == self._ack_generation_id:
140140
try:
141-
await self._ack_set_tracker.ack(ack_id.offset)
141+
self._ack_set_tracker.ack(ack_id.offset)
142142
except GoogleAPICallError as e:
143143
self.fail(e)
144144

@@ -179,7 +179,7 @@ async def _handle_queue_message(
179179
)
180180
)
181181
elif isinstance(message, requests.AckRequest):
182-
await self._handle_ack(message)
182+
self._handle_ack(message)
183183
else:
184184
self._handle_nack(message)
185185

@@ -198,7 +198,7 @@ async def __aenter__(self):
198198
await self._ack_set_tracker.__aenter__()
199199
await self._underlying.__aenter__()
200200
self._looper_future = asyncio.ensure_future(self._looper())
201-
await self._underlying.allow_flow(
201+
self._underlying.allow_flow(
202202
FlowControlRequest(
203203
allowed_messages=self._flow_control_settings.messages_outstanding,
204204
allowed_bytes=self._flow_control_settings.bytes_outstanding,

google/cloud/pubsublite/cloudpubsub/internal/subscriber_impl.py

+11-6
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@
1717
from concurrent.futures.thread import ThreadPoolExecutor
1818
from typing import ContextManager, Optional
1919
from google.api_core.exceptions import GoogleAPICallError
20+
from functools import partial
21+
22+
from google.cloud.pubsublite.internal.wait_ignore_cancelled import wait_ignore_errors
2023
from google.cloud.pubsublite.cloudpubsub.internal.managed_event_loop import (
2124
ManagedEventLoop,
2225
)
@@ -86,8 +89,8 @@ async def _poller(self):
8689
while True:
8790
batch = await self._underlying.read()
8891
self._unowned_executor.map(self._callback, batch)
89-
except GoogleAPICallError as e: # noqa: F841 Flake8 thinks e is unused
90-
self._unowned_executor.submit(lambda: self._fail(e)) # noqa: F821
92+
except GoogleAPICallError as e:
93+
self._unowned_executor.submit(partial(self._fail, e))
9194

9295
def __enter__(self):
9396
assert self._close_callback is not None
@@ -97,13 +100,15 @@ def __enter__(self):
97100
return self
98101

99102
def __exit__(self, exc_type, exc_value, traceback):
103+
self._poller_future.cancel()
100104
try:
101-
self._poller_future.cancel()
102-
self._poller_future.result()
103-
except concurrent.futures.CancelledError:
105+
self._poller_future.result() # Ignore error.
106+
except: # noqa: E722
104107
pass
105108
self._event_loop.submit(
106-
self._underlying.__aexit__(exc_type, exc_value, traceback)
109+
wait_ignore_errors(
110+
self._underlying.__aexit__(exc_type, exc_value, traceback)
111+
)
107112
).result()
108113
self._event_loop.__exit__(exc_type, exc_value, traceback)
109114
assert self._close_callback is not None

google/cloud/pubsublite/internal/wire/committer.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,13 @@ class Committer(AsyncContextManager, metaclass=ABCMeta):
2424
"""
2525

2626
@abstractmethod
27-
async def commit(self, cursor: Cursor) -> None:
27+
def commit(self, cursor: Cursor) -> None:
28+
"""
29+
Start the commit for a cursor.
30+
31+
Raises:
32+
GoogleAPICallError: When the committer terminates in failure.
33+
"""
2834
pass
2935

3036
@abstractmethod

google/cloud/pubsublite/internal/wire/committer_impl.py

+18-37
Original file line numberDiff line numberDiff line change
@@ -28,14 +28,12 @@
2828
ConnectionReinitializer,
2929
)
3030
from google.cloud.pubsublite.internal.wire.connection import Connection
31-
from google.cloud.pubsublite.internal.wire.serial_batcher import SerialBatcher
3231
from google.cloud.pubsublite_v1 import Cursor
3332
from google.cloud.pubsublite_v1.types import (
3433
StreamingCommitCursorRequest,
3534
StreamingCommitCursorResponse,
3635
InitialCommitCursorRequest,
3736
)
38-
from google.cloud.pubsublite.internal.wire.work_item import WorkItem
3937

4038

4139
_LOGGER = logging.getLogger(__name__)
@@ -53,9 +51,8 @@ class CommitterImpl(
5351
StreamingCommitCursorRequest, StreamingCommitCursorResponse
5452
]
5553

56-
_batcher: SerialBatcher[Cursor, None]
57-
58-
_outstanding_commits: List[List[WorkItem[Cursor, None]]]
54+
_next_to_commit: Optional[Cursor]
55+
_outstanding_commits: List[Cursor]
5956

6057
_receiver: Optional[asyncio.Future]
6158
_flusher: Optional[asyncio.Future]
@@ -72,7 +69,7 @@ def __init__(
7269
self._initial = initial
7370
self._flush_seconds = flush_seconds
7471
self._connection = RetryingConnection(factory, self)
75-
self._batcher = SerialBatcher()
72+
self._next_to_commit = None
7673
self._outstanding_commits = []
7774
self._receiver = None
7875
self._flusher = None
@@ -113,9 +110,7 @@ def _handle_response(self, response: StreamingCommitCursorResponse):
113110
)
114111
)
115112
for _ in range(response.commit.acknowledged_commits):
116-
batch = self._outstanding_commits.pop(0)
117-
for item in batch:
118-
item.response_future.set_result(None)
113+
self._outstanding_commits.pop(0)
119114
if len(self._outstanding_commits) == 0:
120115
self._empty.set()
121116

@@ -131,39 +126,31 @@ async def _flush_loop(self):
131126

132127
async def __aexit__(self, exc_type, exc_val, exc_tb):
133128
await self._stop_loopers()
134-
if self._connection.error():
135-
self._fail_if_retrying_failed()
136-
else:
129+
if not self._connection.error():
137130
await self._flush()
138131
await self._connection.__aexit__(exc_type, exc_val, exc_tb)
139132

140-
def _fail_if_retrying_failed(self):
141-
if self._connection.error():
142-
for batch in self._outstanding_commits:
143-
for item in batch:
144-
item.response_future.set_exception(self._connection.error())
145-
146133
async def _flush(self):
147-
batch = self._batcher.flush()
148-
if not batch:
134+
if self._next_to_commit is None:
149135
return
150-
self._outstanding_commits.append(batch)
151-
self._empty.clear()
152136
req = StreamingCommitCursorRequest()
153-
req.commit.cursor = batch[-1].request
137+
req.commit.cursor = self._next_to_commit
138+
self._outstanding_commits.append(self._next_to_commit)
139+
self._next_to_commit = None
140+
self._empty.clear()
154141
try:
155142
await self._connection.write(req)
156143
except GoogleAPICallError as e:
157144
_LOGGER.debug(f"Failed commit on stream: {e}")
158-
self._fail_if_retrying_failed()
159145

160146
async def wait_until_empty(self):
161147
await self._flush()
162148
await self._connection.await_unless_failed(self._empty.wait())
163149

164-
async def commit(self, cursor: Cursor) -> None:
165-
future = self._batcher.add(cursor)
166-
await future
150+
def commit(self, cursor: Cursor) -> None:
151+
if self._connection.error():
152+
raise self._connection.error()
153+
self._next_to_commit = cursor
167154

168155
async def reinitialize(
169156
self,
@@ -181,14 +168,8 @@ async def reinitialize(
181168
"Received an invalid initial response on the publish stream."
182169
)
183170
)
184-
if self._outstanding_commits:
185-
# Roll up outstanding commits
186-
rollup: List[WorkItem[Cursor, None]] = []
187-
for batch in self._outstanding_commits:
188-
for item in batch:
189-
rollup.append(item)
190-
self._outstanding_commits = [rollup]
191-
req = StreamingCommitCursorRequest()
192-
req.commit.cursor = rollup[-1].request
193-
await connection.write(req)
171+
if self._next_to_commit is None:
172+
if self._outstanding_commits:
173+
self._next_to_commit = self._outstanding_commits[-1]
174+
self._outstanding_commits = []
194175
self._start_loopers()

google/cloud/pubsublite/internal/wire/flow_control_batcher.py

+6-16
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,14 @@ class _AggregateRequest:
2626
def __init__(self):
2727
self._request = FlowControlRequest.meta.pb()
2828

29-
def __add__(self, other: FlowControlRequest.meta.pb):
30-
self._request.allowed_bytes = self._request.allowed_bytes + other.allowed_bytes
29+
def __add__(self, other: FlowControlRequest):
30+
other_pb = other._pb
31+
self._request.allowed_bytes = (
32+
self._request.allowed_bytes + other_pb.allowed_bytes
33+
)
3134
self._request.allowed_bytes = min(self._request.allowed_bytes, _MAX_INT64)
3235
self._request.allowed_messages = (
33-
self._request.allowed_messages + other.allowed_messages
36+
self._request.allowed_messages + other_pb.allowed_messages
3437
)
3538
self._request.allowed_messages = min(self._request.allowed_messages, _MAX_INT64)
3639
return self
@@ -77,16 +80,3 @@ def release_pending_request(self) -> Optional[FlowControlRequest]:
7780
request = self._pending_tokens
7881
self._pending_tokens = _AggregateRequest()
7982
return request.to_optional()
80-
81-
def should_expedite(self):
82-
pending_request = self._pending_tokens._request
83-
client_request = self._client_tokens._request
84-
if _exceeds_expedite_ratio(
85-
pending_request.allowed_bytes, client_request.allowed_bytes
86-
):
87-
return True
88-
if _exceeds_expedite_ratio(
89-
pending_request.allowed_messages, client_request.allowed_messages
90-
):
91-
return True
92-
return False

google/cloud/pubsublite/internal/wire/subscriber.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ async def read(self) -> List[SequencedMessage.meta.pb]:
3636
raise NotImplementedError()
3737

3838
@abstractmethod
39-
async def allow_flow(self, request: FlowControlRequest):
39+
def allow_flow(self, request: FlowControlRequest):
4040
"""
4141
Allow an additional amount of messages and bytes to be sent to this client.
4242
"""

google/cloud/pubsublite/internal/wire/subscriber_impl.py

+1-6
Original file line numberDiff line numberDiff line change
@@ -201,10 +201,5 @@ async def reinitialize(
201201
async def read(self) -> List[SequencedMessage.meta.pb]:
202202
return await self._connection.await_unless_failed(self._message_queue.get())
203203

204-
async def allow_flow(self, request: FlowControlRequest):
204+
def allow_flow(self, request: FlowControlRequest):
205205
self._outstanding_flow_control.add(request)
206-
if (
207-
not self._reinitializing
208-
and self._outstanding_flow_control.should_expedite()
209-
):
210-
await self._try_send_tokens()

tests/unit/pubsublite/cloudpubsub/internal/ack_set_tracker_impl_test.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -49,17 +49,17 @@ async def test_track_and_aggregate_acks(committer, tracker: AckSetTracker):
4949
tracker.track(offset=7)
5050

5151
committer.commit.assert_has_calls([])
52-
await tracker.ack(offset=3)
52+
tracker.ack(offset=3)
5353
committer.commit.assert_has_calls([])
54-
await tracker.ack(offset=1)
54+
tracker.ack(offset=1)
5555
committer.commit.assert_has_calls([call(Cursor(offset=4))])
56-
await tracker.ack(offset=5)
56+
tracker.ack(offset=5)
5757
committer.commit.assert_has_calls(
5858
[call(Cursor(offset=4)), call(Cursor(offset=6))]
5959
)
6060

6161
tracker.track(offset=8)
62-
await tracker.ack(offset=7)
62+
tracker.ack(offset=7)
6363
committer.commit.assert_has_calls(
6464
[call(Cursor(offset=4)), call(Cursor(offset=6)), call(Cursor(offset=8))]
6565
)
@@ -74,14 +74,14 @@ async def test_clear_and_commit(committer, tracker: AckSetTracker):
7474

7575
with pytest.raises(FailedPrecondition):
7676
tracker.track(offset=1)
77-
await tracker.ack(offset=5)
77+
tracker.ack(offset=5)
7878
committer.commit.assert_has_calls([])
7979

8080
await tracker.clear_and_commit()
8181
committer.wait_until_empty.assert_called_once()
8282

8383
# After clearing, it should be possible to track earlier offsets.
8484
tracker.track(offset=1)
85-
await tracker.ack(offset=1)
85+
tracker.ack(offset=1)
8686
committer.commit.assert_has_calls([call(Cursor(offset=2))])
8787
committer.__aexit__.assert_called_once()

0 commit comments

Comments (0)