Skip to content

Commit a6dc15f

Browse files
feat: Implement SerialBatcher which helps with transforming single writes into batch writes. (#7)
1 parent f72a2f0 commit a6dc15f

File tree

1 file changed

+50
-0
lines changed

1 file changed

+50
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
from abc import ABC, abstractmethod
2+
from typing import Generic, List, Iterable
3+
import asyncio
4+
5+
from google.cloud.pubsublite.internal.wire.connection import Request, Response
6+
from google.cloud.pubsublite.internal.wire.work_item import WorkItem
7+
8+
9+
class BatchTester(Generic[Request], ABC):
10+
"""A BatchTester determines whether a given batch of messages must be sent."""
11+
@abstractmethod
12+
def test(self, requests: Iterable[Request]) -> bool:
13+
"""
14+
Args:
15+
requests: The current outstanding batch.
16+
17+
Returns: Whether that batch must be sent.
18+
"""
19+
raise NotImplementedError()
20+
21+
22+
class SerialBatcher(Generic[Request, Response]):
23+
_tester: BatchTester[Request]
24+
_requests: List[WorkItem[Request]] # A list of outstanding requests
25+
26+
def __init__(self, tester: BatchTester[Request]):
27+
self._tester = tester
28+
self._requests = []
29+
30+
def add(self, request: Request) -> 'asyncio.Future[Response]':
31+
"""Add a new request to this batcher. Callers must always call should_flush() after add, and flush() if that returns
32+
true.
33+
34+
Args:
35+
request: The request to send.
36+
37+
Returns:
38+
A future that will resolve to the response or a GoogleAPICallError.
39+
"""
40+
item = WorkItem[Request](request)
41+
self._requests.append(item)
42+
return item.response_future
43+
44+
def should_flush(self) -> bool:
45+
return self._tester.test(item.request for item in self._requests)
46+
47+
def flush(self) -> Iterable[WorkItem[Request]]:
48+
requests = self._requests
49+
self._requests = []
50+
return requests

0 commit comments

Comments
 (0)