Skip to content

Commit d317d2e

Browse files
authored
fix: Catch rst stream error for all transactions (#934)
* fix: rst retry for txn * rst changes and tests * fix * rst stream comment changes * lint * lint
1 parent c53f273 commit d317d2e

File tree

8 files changed

+268
-14
lines changed

8 files changed

+268
-14
lines changed

google/cloud/spanner_v1/_helpers.py

+54
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import datetime
1818
import decimal
1919
import math
20+
import time
2021

2122
from google.protobuf.struct_pb2 import ListValue
2223
from google.protobuf.struct_pb2 import Value
@@ -294,6 +295,59 @@ def _metadata_with_prefix(prefix, **kw):
294295
return [("google-cloud-resource-prefix", prefix)]
295296

296297

298+
def _retry(
299+
func,
300+
retry_count=5,
301+
delay=2,
302+
allowed_exceptions=None,
303+
):
304+
"""
305+
Retry a function with a specified number of retries, delay between retries, and list of allowed exceptions.
306+
307+
Args:
308+
func: The function to be retried.
309+
retry_count: The maximum number of times to retry the function.
310+
delay: The delay in seconds between retries.
311+
allowed_exceptions: A tuple of exceptions that are allowed to occur without triggering a retry.
312+
Passing allowed_exceptions as None will lead to retrying for all exceptions.
313+
314+
Returns:
315+
The result of the function if it is successful, or raises the last exception if all retries fail.
316+
"""
317+
retries = 0
318+
while retries <= retry_count:
319+
try:
320+
return func()
321+
except Exception as exc:
322+
if (
323+
allowed_exceptions is None or exc.__class__ in allowed_exceptions
324+
) and retries < retry_count:
325+
if (
326+
allowed_exceptions is not None
327+
and allowed_exceptions[exc.__class__] is not None
328+
):
329+
allowed_exceptions[exc.__class__](exc)
330+
time.sleep(delay)
331+
delay = delay * 2
332+
retries = retries + 1
333+
else:
334+
raise exc
335+
336+
337+
def _check_rst_stream_error(exc):
338+
resumable_error = (
339+
any(
340+
resumable_message in exc.message
341+
for resumable_message in (
342+
"RST_STREAM",
343+
"Received unexpected EOS on DATA frame from server",
344+
)
345+
),
346+
)
347+
if not resumable_error:
348+
raise
349+
350+
297351
def _metadata_with_leader_aware_routing(value, **kw):
298352
"""Create RPC metadata containing a leader aware routing header
299353

google/cloud/spanner_v1/batch.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
"""Context manager for Cloud Spanner batched writes."""
16+
import functools
1617

1718
from google.cloud.spanner_v1 import CommitRequest
1819
from google.cloud.spanner_v1 import Mutation
@@ -26,6 +27,9 @@
2627
)
2728
from google.cloud.spanner_v1._opentelemetry_tracing import trace_call
2829
from google.cloud.spanner_v1 import RequestOptions
30+
from google.cloud.spanner_v1._helpers import _retry
31+
from google.cloud.spanner_v1._helpers import _check_rst_stream_error
32+
from google.api_core.exceptions import InternalServerError
2933

3034

3135
class _BatchBase(_SessionWrapper):
@@ -186,10 +190,15 @@ def commit(self, return_commit_stats=False, request_options=None):
186190
request_options=request_options,
187191
)
188192
with trace_call("CloudSpanner.Commit", self._session, trace_attributes):
189-
response = api.commit(
193+
method = functools.partial(
194+
api.commit,
190195
request=request,
191196
metadata=metadata,
192197
)
198+
response = _retry(
199+
method,
200+
allowed_exceptions={InternalServerError: _check_rst_stream_error},
201+
)
193202
self.committed = response.commit_timestamp
194203
self.commit_stats = response.commit_stats
195204
return self.committed

google/cloud/spanner_v1/snapshot.py

+23-6
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,15 @@
2929
from google.api_core.exceptions import ServiceUnavailable
3030
from google.api_core.exceptions import InvalidArgument
3131
from google.api_core import gapic_v1
32-
from google.cloud.spanner_v1._helpers import _make_value_pb
33-
from google.cloud.spanner_v1._helpers import _merge_query_options
3432
from google.cloud.spanner_v1._helpers import (
33+
_make_value_pb,
34+
_merge_query_options,
3535
_metadata_with_prefix,
3636
_metadata_with_leader_aware_routing,
37+
_retry,
38+
_check_rst_stream_error,
39+
_SessionWrapper,
3740
)
38-
from google.cloud.spanner_v1._helpers import _SessionWrapper
3941
from google.cloud.spanner_v1._opentelemetry_tracing import trace_call
4042
from google.cloud.spanner_v1.streamed import StreamedResultSet
4143
from google.cloud.spanner_v1 import RequestOptions
@@ -560,12 +562,17 @@ def partition_read(
560562
with trace_call(
561563
"CloudSpanner.PartitionReadOnlyTransaction", self._session, trace_attributes
562564
):
563-
response = api.partition_read(
565+
method = functools.partial(
566+
api.partition_read,
564567
request=request,
565568
metadata=metadata,
566569
retry=retry,
567570
timeout=timeout,
568571
)
572+
response = _retry(
573+
method,
574+
allowed_exceptions={InternalServerError: _check_rst_stream_error},
575+
)
569576

570577
return [partition.partition_token for partition in response.partitions]
571578

@@ -659,12 +666,17 @@ def partition_query(
659666
self._session,
660667
trace_attributes,
661668
):
662-
response = api.partition_query(
669+
method = functools.partial(
670+
api.partition_query,
663671
request=request,
664672
metadata=metadata,
665673
retry=retry,
666674
timeout=timeout,
667675
)
676+
response = _retry(
677+
method,
678+
allowed_exceptions={InternalServerError: _check_rst_stream_error},
679+
)
668680

669681
return [partition.partition_token for partition in response.partitions]
670682

@@ -791,10 +803,15 @@ def begin(self):
791803
)
792804
txn_selector = self._make_txn_selector()
793805
with trace_call("CloudSpanner.BeginTransaction", self._session):
794-
response = api.begin_transaction(
806+
method = functools.partial(
807+
api.begin_transaction,
795808
session=self._session.name,
796809
options=txn_selector.begin,
797810
metadata=metadata,
798811
)
812+
response = _retry(
813+
method,
814+
allowed_exceptions={InternalServerError: _check_rst_stream_error},
815+
)
799816
self._transaction_id = response.id
800817
return self._transaction_id

google/cloud/spanner_v1/transaction.py

+29-5
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
_merge_query_options,
2323
_metadata_with_prefix,
2424
_metadata_with_leader_aware_routing,
25+
_retry,
26+
_check_rst_stream_error,
2527
)
2628
from google.cloud.spanner_v1 import CommitRequest
2729
from google.cloud.spanner_v1 import ExecuteBatchDmlRequest
@@ -33,6 +35,7 @@
3335
from google.cloud.spanner_v1._opentelemetry_tracing import trace_call
3436
from google.cloud.spanner_v1 import RequestOptions
3537
from google.api_core import gapic_v1
38+
from google.api_core.exceptions import InternalServerError
3639

3740

3841
class Transaction(_SnapshotBase, _BatchBase):
@@ -102,7 +105,11 @@ def _execute_request(
102105
transaction = self._make_txn_selector()
103106
request.transaction = transaction
104107
with trace_call(trace_name, session, attributes):
105-
response = method(request=request)
108+
method = functools.partial(method, request=request)
109+
response = _retry(
110+
method,
111+
allowed_exceptions={InternalServerError: _check_rst_stream_error},
112+
)
106113

107114
return response
108115

@@ -132,8 +139,15 @@ def begin(self):
132139
)
133140
txn_options = TransactionOptions(read_write=TransactionOptions.ReadWrite())
134141
with trace_call("CloudSpanner.BeginTransaction", self._session):
135-
response = api.begin_transaction(
136-
session=self._session.name, options=txn_options, metadata=metadata
142+
method = functools.partial(
143+
api.begin_transaction,
144+
session=self._session.name,
145+
options=txn_options,
146+
metadata=metadata,
147+
)
148+
response = _retry(
149+
method,
150+
allowed_exceptions={InternalServerError: _check_rst_stream_error},
137151
)
138152
self._transaction_id = response.id
139153
return self._transaction_id
@@ -153,11 +167,16 @@ def rollback(self):
153167
)
154168
)
155169
with trace_call("CloudSpanner.Rollback", self._session):
156-
api.rollback(
170+
method = functools.partial(
171+
api.rollback,
157172
session=self._session.name,
158173
transaction_id=self._transaction_id,
159174
metadata=metadata,
160175
)
176+
_retry(
177+
method,
178+
allowed_exceptions={InternalServerError: _check_rst_stream_error},
179+
)
161180
self.rolled_back = True
162181
del self._session._transaction
163182

@@ -212,10 +231,15 @@ def commit(self, return_commit_stats=False, request_options=None):
212231
request_options=request_options,
213232
)
214233
with trace_call("CloudSpanner.Commit", self._session, trace_attributes):
215-
response = api.commit(
234+
method = functools.partial(
235+
api.commit,
216236
request=request,
217237
metadata=metadata,
218238
)
239+
response = _retry(
240+
method,
241+
allowed_exceptions={InternalServerError: _check_rst_stream_error},
242+
)
219243
self.committed = response.commit_timestamp
220244
if return_commit_stats:
221245
self.commit_stats = response.commit_stats

tests/unit/spanner_dbapi/test_connection.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ def test__session_checkout(self, mock_database):
170170
connection._session_checkout()
171171
self.assertEqual(connection._session, "db_session")
172172

173-
def test__session_checkout_database_error(self):
173+
def test_session_checkout_database_error(self):
174174
from google.cloud.spanner_dbapi import Connection
175175

176176
connection = Connection(INSTANCE)
@@ -191,7 +191,7 @@ def test__release_session(self, mock_database):
191191
pool.put.assert_called_once_with("session")
192192
self.assertIsNone(connection._session)
193193

194-
def test__release_session_database_error(self):
194+
def test_release_session_database_error(self):
195195
from google.cloud.spanner_dbapi import Connection
196196

197197
connection = Connection(INSTANCE)

tests/unit/test__helpers.py

+78
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515

1616
import unittest
17+
import mock
1718

1819

1920
class Test_merge_query_options(unittest.TestCase):
@@ -671,6 +672,83 @@ def test(self):
671672
self.assertEqual(metadata, [("google-cloud-resource-prefix", prefix)])
672673

673674

675+
class Test_retry(unittest.TestCase):
    """Unit tests for ``_helpers._retry`` and ``_check_rst_stream_error``.

    Every call passes ``delay=0`` so the exponential backoff inside
    ``_retry`` does not make the suite sleep for real.
    """

    class test_class:
        """Minimal object whose method is autospec'd as the retried call."""

        def test_fxn(self):
            return True

    def test_retry_on_error(self):
        """With no ``allowed_exceptions`` every exception type is retried."""
        from google.api_core.exceptions import InternalServerError, NotFound
        from google.cloud.spanner_v1._helpers import _retry
        import functools

        test_api = mock.create_autospec(self.test_class)
        test_api.test_fxn.side_effect = [
            InternalServerError("testing"),
            NotFound("testing"),
            True,
        ]

        _retry(functools.partial(test_api.test_fxn), delay=0)

        self.assertEqual(test_api.test_fxn.call_count, 3)

    def test_retry_allowed_exceptions(self):
        """Only exception classes listed in the mapping are retried."""
        from google.api_core.exceptions import InternalServerError, NotFound
        from google.cloud.spanner_v1._helpers import _retry
        import functools

        test_api = mock.create_autospec(self.test_class)
        test_api.test_fxn.side_effect = [
            NotFound("testing"),
            InternalServerError("testing"),
            True,
        ]

        with self.assertRaises(InternalServerError):
            _retry(
                functools.partial(test_api.test_fxn),
                delay=0,
                allowed_exceptions={NotFound: None},
            )

        self.assertEqual(test_api.test_fxn.call_count, 2)

    def test_retry_count(self):
        """The last exception is raised once ``retry_count`` is exhausted."""
        from google.api_core.exceptions import InternalServerError
        from google.cloud.spanner_v1._helpers import _retry
        import functools

        test_api = mock.create_autospec(self.test_class)
        test_api.test_fxn.side_effect = [
            InternalServerError("testing"),
            InternalServerError("testing"),
        ]

        with self.assertRaises(InternalServerError):
            _retry(functools.partial(test_api.test_fxn), retry_count=1, delay=0)

        self.assertEqual(test_api.test_fxn.call_count, 2)

    def test_check_rst_stream_error(self):
        """Resumable stream errors are retried until the call succeeds."""
        from google.api_core.exceptions import InternalServerError
        from google.cloud.spanner_v1._helpers import _retry, _check_rst_stream_error
        import functools

        test_api = mock.create_autospec(self.test_class)
        test_api.test_fxn.side_effect = [
            InternalServerError("Received unexpected EOS on DATA frame from server"),
            InternalServerError("RST_STREAM"),
            True,
        ]

        _retry(
            functools.partial(test_api.test_fxn),
            delay=0,
            allowed_exceptions={InternalServerError: _check_rst_stream_error},
        )

        self.assertEqual(test_api.test_fxn.call_count, 3)
750+
751+
674752
class Test_metadata_with_leader_aware_routing(unittest.TestCase):
675753
def _call_fut(self, *args, **kw):
676754
from google.cloud.spanner_v1._helpers import _metadata_with_leader_aware_routing

0 commit comments

Comments
 (0)