Skip to content

Commit a9293c1

Browse files
feat: Implement CPS non-asyncio subscriber (#25)
* feat: Implement publisher and subscriber factories. * Implement CPS subscriber that takes a message callback and returns a StreamingPullFuture. * docs: document why we have two phase init
1 parent 4890cae commit a9293c1

File tree

7 files changed

+247
-9
lines changed

7 files changed

+247
-9
lines changed

google/cloud/pubsublite/cloudpubsub/internal/assigning_subscriber.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ class AssigningSubscriber(AsyncSubscriber, PermanentFailable):
2525
_messages: "Queue[Message]"
2626
_assign_poller: Future
2727

28-
def __init__(self, assigner: Assigner, subscriber_factory: _PartitionSubscriberFactory):
28+
def __init__(self, assigner: Assigner, subscriber_factory: PartitionSubscriberFactory):
2929
super().__init__()
3030
self._assigner = assigner
3131
self._subscriber_factory = subscriber_factory
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
from abc import ABC, abstractmethod
2+
from typing import Optional, Callable
3+
4+
from google.api_core.exceptions import GoogleAPICallError
5+
6+
7+
CloseCallback = Callable[["StreamingPullManager", Optional[GoogleAPICallError]], None]
8+
9+
10+
class StreamingPullManager(ABC):
11+
"""The API expected by StreamingPullFuture."""
12+
@abstractmethod
13+
def add_close_callback(self, close_callback: CloseCallback):
14+
pass
15+
16+
@abstractmethod
17+
def close(self):
18+
pass
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
import concurrent.futures
2+
import threading
3+
from asyncio import CancelledError
4+
from concurrent.futures.thread import ThreadPoolExecutor
5+
from typing import ContextManager, Optional
6+
from google.api_core.exceptions import GoogleAPICallError
7+
from google.cloud.pubsublite.cloudpubsub.internal.managed_event_loop import ManagedEventLoop
8+
from google.cloud.pubsublite.cloudpubsub.internal.streaming_pull_manager import StreamingPullManager, CloseCallback
9+
from google.cloud.pubsublite.cloudpubsub.subscriber import AsyncSubscriber, MessageCallback
10+
11+
12+
class SubscriberImpl(ContextManager, StreamingPullManager):
13+
_underlying: AsyncSubscriber
14+
_callback: MessageCallback
15+
_executor: ThreadPoolExecutor
16+
17+
_event_loop: ManagedEventLoop
18+
19+
_poller_future: concurrent.futures.Future
20+
_close_lock: threading.Lock
21+
_failure: Optional[GoogleAPICallError]
22+
_close_callback: Optional[CloseCallback]
23+
_closed: bool
24+
25+
def __init__(self, underlying: AsyncSubscriber, callback: MessageCallback, executor: ThreadPoolExecutor):
26+
self._underlying = underlying
27+
self._callback = callback
28+
self._executor = executor
29+
self._event_loop = ManagedEventLoop()
30+
self._close_lock = threading.Lock()
31+
self._failure = None
32+
self._close_callback = None
33+
self._closed = False
34+
35+
def add_close_callback(self, close_callback: CloseCallback):
36+
"""
37+
A close callback must be set exactly once by the StreamingPullFuture managing this subscriber.
38+
39+
This two-phase init model is made necessary by the requirements of StreamingPullFuture.
40+
"""
41+
with self._close_lock:
42+
assert self._close_callback is None
43+
self._close_callback = close_callback
44+
45+
def close(self):
46+
with self._close_lock:
47+
if not self._closed:
48+
self._closed = True
49+
self.__exit__(None, None, None)
50+
51+
def _fail(self, error: GoogleAPICallError):
52+
self._failure = error
53+
self.close()
54+
55+
async def _poller(self):
56+
try:
57+
while True:
58+
message = await self._underlying.read()
59+
self._executor.submit(self._callback, message)
60+
except GoogleAPICallError as e:
61+
self._executor.submit(lambda: self._fail(e))
62+
63+
def __enter__(self):
64+
assert self._close_callback is not None
65+
self._event_loop.__enter__()
66+
self._event_loop.submit(self._underlying.__aenter__()).result()
67+
self._poller_future = self._event_loop.submit(self._poller())
68+
return self
69+
70+
def __exit__(self, exc_type, exc_value, traceback):
71+
try:
72+
self._poller_future.cancel()
73+
self._poller_future.result()
74+
except CancelledError:
75+
pass
76+
self._event_loop.submit(self._underlying.__aexit__(exc_type, exc_value, traceback)).result()
77+
self._event_loop.__exit__(exc_type, exc_value, traceback)
78+
assert self._close_callback is not None
79+
self._executor.shutdown(wait=False) # __exit__ may be called from the executor.
80+
self._close_callback(self, self._failure)

google/cloud/pubsublite/cloudpubsub/make_subscriber.py

+50-5
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,19 @@
1+
from concurrent.futures.thread import ThreadPoolExecutor
12
from typing import Optional, Mapping, Set, AsyncIterator
23
from uuid import uuid4
34

45
from google.api_core.client_options import ClientOptions
56
from google.auth.credentials import Credentials
6-
7+
from google.cloud.pubsub_v1.subscriber.futures import StreamingPullFuture
78
from google.cloud.pubsublite.cloudpubsub.flow_control_settings import FlowControlSettings
89
from google.cloud.pubsublite.cloudpubsub.internal.ack_set_tracker_impl import AckSetTrackerImpl
910
from google.cloud.pubsublite.cloudpubsub.internal.assigning_subscriber import PartitionSubscriberFactory, \
1011
AssigningSubscriber
1112
from google.cloud.pubsublite.cloudpubsub.internal.single_partition_subscriber import SinglePartitionSubscriber
13+
import google.cloud.pubsublite.cloudpubsub.internal.subscriber_impl as cps_subscriber
1214
from google.cloud.pubsublite.cloudpubsub.message_transformer import MessageTransformer, DefaultMessageTransformer
1315
from google.cloud.pubsublite.cloudpubsub.nack_handler import NackHandler, DefaultNackHandler
14-
from google.cloud.pubsublite.cloudpubsub.subscriber import AsyncSubscriber
16+
from google.cloud.pubsublite.cloudpubsub.subscriber import AsyncSubscriber, MessageCallback
1517
from google.cloud.pubsublite.endpoints import regional_endpoint
1618
from google.cloud.pubsublite.internal.wire.assigner import Assigner
1719
from google.cloud.pubsublite.internal.wire.assigner_impl import AssignerImpl
@@ -20,7 +22,7 @@
2022
from google.cloud.pubsublite.internal.wire.gapic_connection import GapicConnectionFactory
2123
from google.cloud.pubsublite.internal.wire.merge_metadata import merge_metadata
2224
from google.cloud.pubsublite.internal.wire.pubsub_context import pubsub_context
23-
from google.cloud.pubsublite.internal.wire.subscriber_impl import SubscriberImpl
25+
import google.cloud.pubsublite.internal.wire.subscriber_impl as wire_subscriber
2426
from google.cloud.pubsublite.partition import Partition
2527
from google.cloud.pubsublite.paths import SubscriptionPath
2628
from google.cloud.pubsublite.routing_metadata import subscription_routing_metadata
@@ -63,14 +65,14 @@ def subscribe_connection_factory(requests: AsyncIterator[SubscribeRequest]):
6365
def cursor_connection_factory(requests: AsyncIterator[StreamingCommitCursorRequest]):
6466
return cursor_client.streaming_commit_cursor(requests, metadata=list(final_metadata.items()))
6567

66-
wire_subscriber = SubscriberImpl(
68+
subscriber = wire_subscriber.SubscriberImpl(
6769
InitialSubscribeRequest(subscription=str(subscription), partition=partition.value),
6870
_DEFAULT_FLUSH_SECONDS, GapicConnectionFactory(subscribe_connection_factory))
6971
committer = CommitterImpl(
7072
InitialCommitCursorRequest(subscription=str(subscription), partition=partition.value),
7173
_DEFAULT_FLUSH_SECONDS, GapicConnectionFactory(cursor_connection_factory))
7274
ack_set_tracker = AckSetTrackerImpl(committer)
73-
return SinglePartitionSubscriber(wire_subscriber, flow_control_settings, ack_set_tracker, nack_handler,
75+
return SinglePartitionSubscriber(subscriber, flow_control_settings, ack_set_tracker, nack_handler,
7476
message_transformer)
7577

7678
return factory
@@ -124,3 +126,46 @@ def make_async_subscriber(
124126
metadata, per_partition_flow_control_settings,
125127
nack_handler, message_transformer)
126128
return AssigningSubscriber(assigner, partition_subscriber_factory)
129+
130+
131+
def make_subscriber(
132+
subscription: SubscriptionPath,
133+
per_partition_flow_control_settings: FlowControlSettings,
134+
callback: MessageCallback,
135+
nack_handler: Optional[NackHandler] = None,
136+
message_transformer: Optional[MessageTransformer] = None,
137+
fixed_partitions: Optional[Set[Partition]] = None,
138+
executor: Optional[ThreadPoolExecutor] = None,
139+
credentials: Optional[Credentials] = None,
140+
client_options: Optional[ClientOptions] = None,
141+
metadata: Optional[Mapping[str, str]] = None) -> StreamingPullFuture:
142+
"""
143+
Make a Pub/Sub Lite Subscriber.
144+
145+
Args:
146+
subscription: The subscription to subscribe to.
147+
per_partition_flow_control_settings: The flow control settings for each partition subscribed to. Note that these
148+
settings apply to each partition individually, not in aggregate.
149+
callback: The callback to call with each message.
150+
nack_handler: An optional handler for when nack() is called on a Message. The default will fail the client.
151+
message_transformer: An optional transformer from Pub/Sub Lite messages to Cloud Pub/Sub messages.
152+
fixed_partitions: A fixed set of partitions to subscribe to. If not present, will instead use auto-assignment.
153+
executor: The executor to use for user callbacks. If not provided, will use the default constructed
154+
ThreadPoolExecutor. If provided a single threaded executor, messages will be ordered per-partition, but take care
155+
that the callback does not block for too long as it will impede forward progress on all partitions.
156+
credentials: The credentials to use to connect. GOOGLE_DEFAULT_CREDENTIALS is used if None.
157+
client_options: Other options to pass to the client. Note that if you pass any you must set api_endpoint.
158+
metadata: Additional metadata to send with the RPC.
159+
160+
Returns:
161+
A StreamingPullFuture, managing the subscriber's lifetime.
162+
"""
163+
underlying = make_async_subscriber(
164+
subscription, per_partition_flow_control_settings, nack_handler, message_transformer, fixed_partitions, credentials,
165+
client_options, metadata)
166+
if executor is None:
167+
executor = ThreadPoolExecutor()
168+
subscriber = cps_subscriber.SubscriberImpl(underlying, callback, executor)
169+
future = StreamingPullFuture(subscriber)
170+
subscriber.__enter__()
171+
return future

google/cloud/pubsublite/cloudpubsub/subscriber.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from abc import abstractmethod
2-
from typing import AsyncContextManager
2+
from typing import AsyncContextManager, Callable
33

44
from google.cloud.pubsub_v1.subscriber.message import Message
55

@@ -23,3 +23,6 @@ async def read(self) -> Message:
2323
GoogleAPICallError: On a permanent error.
2424
"""
2525
raise NotImplementedError()
26+
27+
28+
MessageCallback = Callable[[Message], None]

tests/unit/pubsublite/cloudpubsub/internal/assigning_subscriber_test.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
import asyncio
2-
from typing import Callable, Set
1+
from typing import Set
32

43
from asynctest.mock import MagicMock, call
54
import pytest
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
import asyncio
2+
import concurrent
3+
from concurrent.futures.thread import ThreadPoolExecutor
4+
from queue import Queue
5+
6+
from asynctest.mock import MagicMock
7+
import pytest
8+
from google.api_core.exceptions import FailedPrecondition
9+
from google.cloud.pubsub_v1.subscriber.message import Message
10+
from google.pubsub_v1 import PubsubMessage
11+
12+
from google.cloud.pubsublite.cloudpubsub.internal.streaming_pull_manager import CloseCallback
13+
from google.cloud.pubsublite.cloudpubsub.internal.subscriber_impl import SubscriberImpl
14+
from google.cloud.pubsublite.cloudpubsub.subscriber import AsyncSubscriber, MessageCallback
15+
from google.cloud.pubsublite.testing.test_utils import Box
16+
17+
18+
@pytest.fixture()
19+
def async_subscriber():
20+
subscriber = MagicMock(spec=AsyncSubscriber)
21+
subscriber.__aenter__.return_value = subscriber
22+
return subscriber
23+
24+
25+
@pytest.fixture()
26+
def message_callback():
27+
return MagicMock(spec=MessageCallback)
28+
29+
30+
@pytest.fixture()
31+
def close_callback():
32+
return MagicMock(spec=CloseCallback)
33+
34+
35+
@pytest.fixture()
36+
def subscriber(async_subscriber, message_callback, close_callback):
37+
return SubscriberImpl(async_subscriber, message_callback, ThreadPoolExecutor(max_workers=1))
38+
39+
40+
async def sleep_forever(*args, **kwargs):
41+
await asyncio.sleep(float("inf"))
42+
43+
44+
def test_init(subscriber: SubscriberImpl, async_subscriber, close_callback):
45+
async_subscriber.read.side_effect = sleep_forever
46+
subscriber.add_close_callback(close_callback)
47+
subscriber.__enter__()
48+
async_subscriber.__aenter__.assert_called_once()
49+
subscriber.close()
50+
async_subscriber.__aexit__.assert_called_once()
51+
close_callback.assert_called_once_with(subscriber, None)
52+
53+
54+
def test_failed(subscriber: SubscriberImpl, async_subscriber, close_callback):
55+
error = FailedPrecondition("bad read")
56+
async_subscriber.read.side_effect = error
57+
58+
close_called = concurrent.futures.Future()
59+
close_callback.side_effect = lambda manager, err: close_called.set_result(None)
60+
61+
subscriber.add_close_callback(close_callback)
62+
subscriber.__enter__()
63+
async_subscriber.__aenter__.assert_called_once()
64+
close_called.result()
65+
async_subscriber.__aexit__.assert_called_once()
66+
close_callback.assert_called_once_with(subscriber, error)
67+
68+
69+
def test_messages_received(subscriber: SubscriberImpl, async_subscriber, message_callback, close_callback):
70+
message1 = Message(PubsubMessage(message_id="1")._pb, "", 0, None)
71+
message2 = Message(PubsubMessage(message_id="2")._pb, "", 0, None)
72+
73+
counter = Box[int]()
74+
counter.val = 0
75+
76+
async def on_read() -> Message:
77+
counter.val += 1
78+
if counter.val == 1:
79+
return message1
80+
if counter.val == 2:
81+
return message2
82+
await sleep_forever()
83+
84+
async_subscriber.read.side_effect = on_read
85+
86+
results = Queue()
87+
message_callback.side_effect = lambda m: results.put(m.message_id)
88+
89+
subscriber.add_close_callback(close_callback)
90+
subscriber.__enter__()
91+
assert results.get() == "1"
92+
assert results.get() == "2"
93+
subscriber.close()

0 commit comments

Comments
 (0)