Skip to content

Commit 01a0196

Browse files
committed
fix: update retry strategy for mutation calls to handle aborted transactions
1 parent ab31078 commit 01a0196

File tree

8 files changed

+189
-65
lines changed

8 files changed

+189
-65
lines changed

.gitignore

+4
Original file line numberDiff line numberDiff line change
@@ -62,3 +62,7 @@ system_tests/local_test_setup
6262
# Make sure a generated file isn't accidentally committed.
6363
pylintrc
6464
pylintrc.test
65+
66+
67+
# Ignore coverage files
68+
.coverage*

google/cloud/spanner_dbapi/transaction_helper.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
from google.cloud.spanner_dbapi.batch_dml_executor import BatchMode
2222
from google.cloud.spanner_dbapi.exceptions import RetryAborted
23-
from google.cloud.spanner_v1.session import _get_retry_delay
23+
from google.cloud.spanner_v1._helpers import _get_retry_delay
2424

2525
if TYPE_CHECKING:
2626
from google.cloud.spanner_dbapi import Connection, Cursor

google/cloud/spanner_v1/_helpers.py

+75-2
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,15 @@
2727
from google.protobuf.internal.enum_type_wrapper import EnumTypeWrapper
2828

2929
from google.api_core import datetime_helpers
30+
from google.api_core.exceptions import Aborted
3031
from google.cloud._helpers import _date_from_iso8601_date
3132
from google.cloud.spanner_v1 import TypeCode
3233
from google.cloud.spanner_v1 import ExecuteSqlRequest
3334
from google.cloud.spanner_v1 import JsonObject
3435
from google.cloud.spanner_v1.request_id_header import with_request_id
36+
from google.rpc.error_details_pb2 import RetryInfo
37+
38+
import random
3539

3640
# Validation error messages
3741
NUMERIC_MAX_SCALE_ERR_MSG = (
@@ -466,13 +470,19 @@ def _retry(
466470
delay=2,
467471
allowed_exceptions=None,
468472
beforeNextRetry=None,
473+
deadline=None,
469474
):
470475
"""
471-
Retry a function with a specified number of retries, delay between retries, and list of allowed exceptions.
476+
Retry a specified function with different logic based on the type of exception raised.
477+
478+
If the exception is of type google.api_core.exceptions.Aborted,
479+
apply an alternate retry strategy that relies on the provided deadline value instead of a fixed number of retries.
480+
For all other exceptions, retry the function up to a specified number of times.
472481
473482
Args:
474483
func: The function to be retried.
475484
retry_count: The maximum number of times to retry the function.
485+
deadline: This will be used in case of Aborted transactions.
476486
delay: The delay in seconds between retries.
477487
allowed_exceptions: A tuple of exceptions that are allowed to occur without triggering a retry.
478488
Passing allowed_exceptions as None will lead to retrying for all exceptions.
@@ -481,13 +491,21 @@ def _retry(
481491
The result of the function if it is successful, or raises the last exception if all retries fail.
482492
"""
483493
retries = 0
484-
while retries <= retry_count:
494+
while True:
485495
if retries > 0 and beforeNextRetry:
486496
beforeNextRetry(retries, delay)
487497

488498
try:
489499
return func()
490500
except Exception as exc:
501+
if isinstance(exc, Aborted) and deadline is not None:
502+
if (
503+
allowed_exceptions is not None
504+
and allowed_exceptions.get(exc.__class__) is not None
505+
):
506+
retries += 1
507+
_delay_until_retry(exc, deadline=deadline, attempts=retries)
508+
continue
491509
if (
492510
allowed_exceptions is None or exc.__class__ in allowed_exceptions
493511
) and retries < retry_count:
@@ -529,6 +547,61 @@ def _metadata_with_leader_aware_routing(value, **kw):
529547
return ("x-goog-spanner-route-to-leader", str(value).lower())
530548

531549

550+
def _delay_until_retry(exc, deadline, attempts):
551+
"""Helper for :meth:`Session.run_in_transaction`.
552+
553+
Detect retryable abort, and impose server-supplied delay.
554+
555+
:type exc: :class:`google.api_core.exceptions.Aborted`
556+
:param exc: exception for aborted transaction
557+
558+
:type deadline: float
559+
:param deadline: maximum timestamp to continue retrying the transaction.
560+
561+
:type attempts: int
562+
:param attempts: number of call retries
563+
"""
564+
565+
cause = exc.errors[0]
566+
now = time.time()
567+
if now >= deadline:
568+
raise
569+
570+
delay = _get_retry_delay(cause, attempts)
571+
print(now, delay, deadline)
572+
if delay is not None:
573+
if now + delay > deadline:
574+
raise
575+
576+
time.sleep(delay)
577+
578+
579+
def _get_retry_delay(cause, attempts):
580+
"""Helper for :func:`_delay_until_retry`.
581+
582+
:type exc: :class:`grpc.Call`
583+
:param exc: exception for aborted transaction
584+
585+
:rtype: float
586+
:returns: seconds to wait before retrying the transaction.
587+
588+
:type attempts: int
589+
:param attempts: number of call retries
590+
"""
591+
if hasattr(cause, "trailing_metadata"):
592+
metadata = dict(cause.trailing_metadata())
593+
else:
594+
metadata = {}
595+
retry_info_pb = metadata.get("google.rpc.retryinfo-bin")
596+
if retry_info_pb is not None:
597+
retry_info = RetryInfo()
598+
retry_info.ParseFromString(retry_info_pb)
599+
nanos = retry_info.retry_delay.nanos
600+
return retry_info.retry_delay.seconds + nanos / 1.0e9
601+
602+
return 2**attempts + random.random()
603+
604+
532605
class AtomicCounter:
533606
def __init__(self, start_value=0):
534607
self.__lock = threading.Lock()

google/cloud/spanner_v1/batch.py

+29-3
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,10 @@
3131
from google.cloud.spanner_v1._helpers import _retry
3232
from google.cloud.spanner_v1._helpers import _check_rst_stream_error
3333
from google.api_core.exceptions import InternalServerError
34+
from google.api_core.exceptions import Aborted
35+
import time
36+
37+
DEFAULT_RETRY_TIMEOUT_SECS = 30
3438

3539

3640
class _BatchBase(_SessionWrapper):
@@ -162,6 +166,7 @@ def commit(
162166
request_options=None,
163167
max_commit_delay=None,
164168
exclude_txn_from_change_streams=False,
169+
**kwargs,
165170
):
166171
"""Commit mutations to the database.
167172
@@ -227,9 +232,16 @@ def commit(
227232
request=request,
228233
metadata=metadata,
229234
)
235+
deadline = time.time() + kwargs.get(
236+
"timeout_secs", DEFAULT_RETRY_TIMEOUT_SECS
237+
)
230238
response = _retry(
231239
method,
232-
allowed_exceptions={InternalServerError: _check_rst_stream_error},
240+
allowed_exceptions={
241+
InternalServerError: _check_rst_stream_error,
242+
Aborted: no_op_handler,
243+
},
244+
deadline=deadline,
233245
)
234246
self.committed = response.commit_timestamp
235247
self.commit_stats = response.commit_stats
@@ -293,7 +305,9 @@ def group(self):
293305
self._mutation_groups.append(mutation_group)
294306
return MutationGroup(self._session, mutation_group.mutations)
295307

296-
def batch_write(self, request_options=None, exclude_txn_from_change_streams=False):
308+
def batch_write(
309+
self, request_options=None, exclude_txn_from_change_streams=False, **kwargs
310+
):
297311
"""Executes batch_write.
298312
299313
:type request_options:
@@ -346,9 +360,16 @@ def batch_write(self, request_options=None, exclude_txn_from_change_streams=Fals
346360
request=request,
347361
metadata=metadata,
348362
)
363+
deadline = time.time() + kwargs.get(
364+
"timeout_secs", DEFAULT_RETRY_TIMEOUT_SECS
365+
)
349366
response = _retry(
350367
method,
351-
allowed_exceptions={InternalServerError: _check_rst_stream_error},
368+
allowed_exceptions={
369+
InternalServerError: _check_rst_stream_error,
370+
Aborted: no_op_handler,
371+
},
372+
deadline=deadline,
352373
)
353374
self.committed = True
354375
return response
@@ -372,3 +393,8 @@ def _make_write_pb(table, columns, values):
372393
return Mutation.Write(
373394
table=table, columns=columns, values=_make_list_value_pbs(values)
374395
)
396+
397+
398+
def no_op_handler(exc):
399+
# No-op (does nothing)
400+
pass

google/cloud/spanner_v1/database.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -775,6 +775,7 @@ def batch(
775775
request_options=None,
776776
max_commit_delay=None,
777777
exclude_txn_from_change_streams=False,
778+
**kw,
778779
):
779780
"""Return an object which wraps a batch.
780781
@@ -805,7 +806,11 @@ def batch(
805806
:returns: new wrapper
806807
"""
807808
return BatchCheckout(
808-
self, request_options, max_commit_delay, exclude_txn_from_change_streams
809+
self,
810+
request_options,
811+
max_commit_delay,
812+
exclude_txn_from_change_streams,
813+
**kw,
809814
)
810815

811816
def mutation_groups(self):
@@ -1166,6 +1171,7 @@ def __init__(
11661171
request_options=None,
11671172
max_commit_delay=None,
11681173
exclude_txn_from_change_streams=False,
1174+
**kw,
11691175
):
11701176
self._database = database
11711177
self._session = self._batch = None
@@ -1177,6 +1183,7 @@ def __init__(
11771183
self._request_options = request_options
11781184
self._max_commit_delay = max_commit_delay
11791185
self._exclude_txn_from_change_streams = exclude_txn_from_change_streams
1186+
self._kw = kw
11801187

11811188
def __enter__(self):
11821189
"""Begin ``with`` block."""
@@ -1197,6 +1204,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
11971204
request_options=self._request_options,
11981205
max_commit_delay=self._max_commit_delay,
11991206
exclude_txn_from_change_streams=self._exclude_txn_from_change_streams,
1207+
**self._kw,
12001208
)
12011209
finally:
12021210
if self._database.log_commit_stats and self._batch.commit_stats:

google/cloud/spanner_v1/session.py

+2-56
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,15 @@
1515
"""Wrapper for Cloud Spanner Session objects."""
1616

1717
from functools import total_ordering
18-
import random
1918
import time
2019
from datetime import datetime
2120

2221
from google.api_core.exceptions import Aborted
2322
from google.api_core.exceptions import GoogleAPICallError
2423
from google.api_core.exceptions import NotFound
2524
from google.api_core.gapic_v1 import method
26-
from google.rpc.error_details_pb2 import RetryInfo
25+
from google.cloud.spanner_v1._helpers import _delay_until_retry
26+
from google.cloud.spanner_v1._helpers import _get_retry_delay
2727

2828
from google.cloud.spanner_v1 import ExecuteSqlRequest
2929
from google.cloud.spanner_v1 import CreateSessionRequest
@@ -554,57 +554,3 @@ def run_in_transaction(self, func, *args, **kw):
554554
extra={"commit_stats": txn.commit_stats},
555555
)
556556
return return_value
557-
558-
559-
# Rational: this function factors out complex shared deadline / retry
560-
# handling from two `except:` clauses.
561-
def _delay_until_retry(exc, deadline, attempts):
562-
"""Helper for :meth:`Session.run_in_transaction`.
563-
564-
Detect retryable abort, and impose server-supplied delay.
565-
566-
:type exc: :class:`google.api_core.exceptions.Aborted`
567-
:param exc: exception for aborted transaction
568-
569-
:type deadline: float
570-
:param deadline: maximum timestamp to continue retrying the transaction.
571-
572-
:type attempts: int
573-
:param attempts: number of call retries
574-
"""
575-
cause = exc.errors[0]
576-
577-
now = time.time()
578-
579-
if now >= deadline:
580-
raise
581-
582-
delay = _get_retry_delay(cause, attempts)
583-
if delay is not None:
584-
if now + delay > deadline:
585-
raise
586-
587-
time.sleep(delay)
588-
589-
590-
def _get_retry_delay(cause, attempts):
591-
"""Helper for :func:`_delay_until_retry`.
592-
593-
:type exc: :class:`grpc.Call`
594-
:param exc: exception for aborted transaction
595-
596-
:rtype: float
597-
:returns: seconds to wait before retrying the transaction.
598-
599-
:type attempts: int
600-
:param attempts: number of call retries
601-
"""
602-
metadata = dict(cause.trailing_metadata())
603-
retry_info_pb = metadata.get("google.rpc.retryinfo-bin")
604-
if retry_info_pb is not None:
605-
retry_info = RetryInfo()
606-
retry_info.ParseFromString(retry_info_pb)
607-
nanos = retry_info.retry_delay.nanos
608-
return retry_info.retry_delay.seconds + nanos / 1.0e9
609-
610-
return 2**attempts + random.random()

0 commit comments

Comments
 (0)