Skip to content

Commit 6640888

Browse files
authored
feat: Fixing and refactoring transaction retry logic in dbapi. Also adding interceptors support for testing (#1056)
* feat: Fixing and refactoring transaction retry logic in dbapi. Also adding interceptors support for testing
* Incorporated review comments and added changes to also store the Cursor object with the statement details for retry
* Some refactoring of transaction_helper.py and maintaining the state of the rows update count for batch DML in the cursor
* Small fix
* Maintaining a map from cursor to the last statement added in transaction_helper.py
* Rolling back the transaction when an Aborted exception is thrown from an interceptor
* Small change
* Disabling a test for the emulator run
* Reformatting
1 parent 7ada21c commit 6640888

16 files changed

+1812
-988
lines changed

google/cloud/spanner_dbapi/batch_dml_executor.py

+12-13
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616

1717
from enum import Enum
1818
from typing import TYPE_CHECKING, List
19-
from google.cloud.spanner_dbapi.checksum import ResultsChecksum
2019
from google.cloud.spanner_dbapi.parsed_statement import (
2120
ParsedStatement,
2221
StatementType,
@@ -80,8 +79,10 @@ def run_batch_dml(cursor: "Cursor", statements: List[Statement]):
8079
"""
8180
from google.cloud.spanner_dbapi import OperationalError
8281

83-
connection = cursor.connection
8482
many_result_set = StreamedManyResultSets()
83+
if not statements:
84+
return many_result_set
85+
connection = cursor.connection
8586
statements_tuple = []
8687
for statement in statements:
8788
statements_tuple.append(statement.get_tuple())
@@ -90,28 +91,26 @@ def run_batch_dml(cursor: "Cursor", statements: List[Statement]):
9091
many_result_set.add_iter(res)
9192
cursor._row_count = sum([max(val, 0) for val in res])
9293
else:
93-
retried = False
9494
while True:
9595
try:
9696
transaction = connection.transaction_checkout()
9797
status, res = transaction.batch_update(statements_tuple)
98-
many_result_set.add_iter(res)
99-
res_checksum = ResultsChecksum()
100-
res_checksum.consume_result(res)
101-
res_checksum.consume_result(status.code)
102-
if not retried:
103-
connection._statements.append((statements, res_checksum))
104-
cursor._row_count = sum([max(val, 0) for val in res])
105-
10698
if status.code == ABORTED:
10799
connection._transaction = None
108100
raise Aborted(status.message)
109101
elif status.code != OK:
110102
raise OperationalError(status.message)
103+
104+
cursor._batch_dml_rows_count = res
105+
many_result_set.add_iter(res)
106+
cursor._row_count = sum([max(val, 0) for val in res])
111107
return many_result_set
112108
except Aborted:
113-
connection.retry_transaction()
114-
retried = True
109+
# Re-raise the exception so that it can be handled and retried in transaction_helper.py
110+
if cursor._in_retry_mode:
111+
raise
112+
else:
113+
connection._transaction_helper.retry_transaction()
115114

116115

117116
def _do_batch_update(transaction, statements):

google/cloud/spanner_dbapi/checksum.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@ def consume_result(self, result):
6262

6363

6464
def _compare_checksums(original, retried):
65+
from google.cloud.spanner_dbapi.transaction_helper import RETRY_ABORTED_ERROR
66+
6567
"""Compare the given checksums.
6668
6769
Raise an error if the given checksums are not equal.
@@ -75,6 +77,4 @@ def _compare_checksums(original, retried):
7577
:raises: :exc:`google.cloud.spanner_dbapi.exceptions.RetryAborted` if the checksums are not equal.
7678
"""
7779
if retried != original:
78-
raise RetryAborted(
79-
"The transaction was aborted and could not be retried due to a concurrent modification."
80-
)
80+
raise RetryAborted(RETRY_ABORTED_ERROR)

google/cloud/spanner_dbapi/connection.py

+19-108
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
# limitations under the License.
1414

1515
"""DB-API Connection for the Google Cloud Spanner."""
16-
import time
1716
import warnings
1817

1918
from google.api_core.exceptions import Aborted
@@ -23,19 +22,16 @@
2322
from google.cloud.spanner_dbapi.batch_dml_executor import BatchMode, BatchDmlExecutor
2423
from google.cloud.spanner_dbapi.parse_utils import _get_statement_type
2524
from google.cloud.spanner_dbapi.parsed_statement import (
26-
ParsedStatement,
27-
Statement,
2825
StatementType,
2926
)
3027
from google.cloud.spanner_dbapi.partition_helper import PartitionId
28+
from google.cloud.spanner_dbapi.parsed_statement import ParsedStatement, Statement
29+
from google.cloud.spanner_dbapi.transaction_helper import TransactionRetryHelper
30+
from google.cloud.spanner_dbapi.cursor import Cursor
3131
from google.cloud.spanner_v1 import RequestOptions
32-
from google.cloud.spanner_v1.session import _get_retry_delay
3332
from google.cloud.spanner_v1.snapshot import Snapshot
3433
from deprecated import deprecated
3534

36-
from google.cloud.spanner_dbapi.checksum import _compare_checksums
37-
from google.cloud.spanner_dbapi.checksum import ResultsChecksum
38-
from google.cloud.spanner_dbapi.cursor import Cursor
3935
from google.cloud.spanner_dbapi.exceptions import (
4036
InterfaceError,
4137
OperationalError,
@@ -44,13 +40,10 @@
4440
from google.cloud.spanner_dbapi.version import DEFAULT_USER_AGENT
4541
from google.cloud.spanner_dbapi.version import PY_VERSION
4642

47-
from google.rpc.code_pb2 import ABORTED
48-
4943

5044
CLIENT_TRANSACTION_NOT_STARTED_WARNING = (
5145
"This method is non-operational as a transaction has not been started."
5246
)
53-
MAX_INTERNAL_RETRIES = 50
5447

5548

5649
def check_not_closed(function):
@@ -106,9 +99,6 @@ def __init__(self, instance, database=None, read_only=False):
10699
self._transaction = None
107100
self._session = None
108101
self._snapshot = None
109-
# SQL statements, which were executed
110-
# within the current transaction
111-
self._statements = []
112102

113103
self.is_closed = False
114104
self._autocommit = False
@@ -125,6 +115,7 @@ def __init__(self, instance, database=None, read_only=False):
125115
self._spanner_transaction_started = False
126116
self._batch_mode = BatchMode.NONE
127117
self._batch_dml_executor: BatchDmlExecutor = None
118+
self._transaction_helper = TransactionRetryHelper(self)
128119

129120
@property
130121
def autocommit(self):
@@ -288,76 +279,6 @@ def _release_session(self):
288279
self.database._pool.put(self._session)
289280
self._session = None
290281

291-
def retry_transaction(self):
292-
"""Retry the aborted transaction.
293-
294-
All the statements executed in the original transaction
295-
will be re-executed in new one. Results checksums of the
296-
original statements and the retried ones will be compared.
297-
298-
:raises: :class:`google.cloud.spanner_dbapi.exceptions.RetryAborted`
299-
If results checksum of the retried statement is
300-
not equal to the checksum of the original one.
301-
"""
302-
attempt = 0
303-
while True:
304-
self._spanner_transaction_started = False
305-
attempt += 1
306-
if attempt > MAX_INTERNAL_RETRIES:
307-
raise
308-
309-
try:
310-
self._rerun_previous_statements()
311-
break
312-
except Aborted as exc:
313-
delay = _get_retry_delay(exc.errors[0], attempt)
314-
if delay:
315-
time.sleep(delay)
316-
317-
def _rerun_previous_statements(self):
318-
"""
319-
Helper to run all the remembered statements
320-
from the last transaction.
321-
"""
322-
for statement in self._statements:
323-
if isinstance(statement, list):
324-
statements, checksum = statement
325-
326-
transaction = self.transaction_checkout()
327-
statements_tuple = []
328-
for single_statement in statements:
329-
statements_tuple.append(single_statement.get_tuple())
330-
status, res = transaction.batch_update(statements_tuple)
331-
332-
if status.code == ABORTED:
333-
raise Aborted(status.details)
334-
335-
retried_checksum = ResultsChecksum()
336-
retried_checksum.consume_result(res)
337-
retried_checksum.consume_result(status.code)
338-
339-
_compare_checksums(checksum, retried_checksum)
340-
else:
341-
res_iter, retried_checksum = self.run_statement(statement, retried=True)
342-
# executing all the completed statements
343-
if statement != self._statements[-1]:
344-
for res in res_iter:
345-
retried_checksum.consume_result(res)
346-
347-
_compare_checksums(statement.checksum, retried_checksum)
348-
# executing the failed statement
349-
else:
350-
# streaming up to the failed result or
351-
# to the end of the streaming iterator
352-
while len(retried_checksum) < len(statement.checksum):
353-
try:
354-
res = next(iter(res_iter))
355-
retried_checksum.consume_result(res)
356-
except StopIteration:
357-
break
358-
359-
_compare_checksums(statement.checksum, retried_checksum)
360-
361282
def transaction_checkout(self):
362283
"""Get a Cloud Spanner transaction.
363284
@@ -433,12 +354,10 @@ def begin(self):
433354

434355
def commit(self):
435356
"""Commits any pending transaction to the database.
436-
437357
This is a no-op if there is no active client transaction.
438358
"""
439359
if self.database is None:
440360
raise ValueError("Database needs to be passed for this operation")
441-
442361
if not self._client_transaction_started:
443362
warnings.warn(
444363
CLIENT_TRANSACTION_NOT_STARTED_WARNING, UserWarning, stacklevel=2
@@ -450,33 +369,31 @@ def commit(self):
450369
if self._spanner_transaction_started and not self._read_only:
451370
self._transaction.commit()
452371
except Aborted:
453-
self.retry_transaction()
372+
self._transaction_helper.retry_transaction()
454373
self.commit()
455374
finally:
456-
self._release_session()
457-
self._statements = []
458-
self._transaction_begin_marked = False
459-
self._spanner_transaction_started = False
375+
self._reset_post_commit_or_rollback()
460376

461377
def rollback(self):
462378
"""Rolls back any pending transaction.
463-
464379
This is a no-op if there is no active client transaction.
465380
"""
466381
if not self._client_transaction_started:
467382
warnings.warn(
468383
CLIENT_TRANSACTION_NOT_STARTED_WARNING, UserWarning, stacklevel=2
469384
)
470385
return
471-
472386
try:
473387
if self._spanner_transaction_started and not self._read_only:
474388
self._transaction.rollback()
475389
finally:
476-
self._release_session()
477-
self._statements = []
478-
self._transaction_begin_marked = False
479-
self._spanner_transaction_started = False
390+
self._reset_post_commit_or_rollback()
391+
392+
def _reset_post_commit_or_rollback(self):
393+
self._release_session()
394+
self._transaction_helper.reset()
395+
self._transaction_begin_marked = False
396+
self._spanner_transaction_started = False
480397

481398
@check_not_closed
482399
def cursor(self):
@@ -493,7 +410,7 @@ def run_prior_DDL_statements(self):
493410

494411
return self.database.update_ddl(ddl_statements).result()
495412

496-
def run_statement(self, statement: Statement, retried=False):
413+
def run_statement(self, statement: Statement):
497414
"""Run single SQL statement in begun transaction.
498415
499416
This method is never used in autocommit mode. In
@@ -513,17 +430,11 @@ def run_statement(self, statement: Statement, retried=False):
513430
checksum of this statement results.
514431
"""
515432
transaction = self.transaction_checkout()
516-
if not retried:
517-
self._statements.append(statement)
518-
519-
return (
520-
transaction.execute_sql(
521-
statement.sql,
522-
statement.params,
523-
param_types=statement.param_types,
524-
request_options=self.request_options,
525-
),
526-
ResultsChecksum() if retried else statement.checksum,
433+
return transaction.execute_sql(
434+
statement.sql,
435+
statement.params,
436+
param_types=statement.param_types,
437+
request_options=self.request_options,
527438
)
528439

529440
@check_not_closed

0 commit comments

Comments
 (0)