Skip to content

Commit bb5fa1f

Browse files
authored
feat: Implementation of client side statements that return (#1046)
* Implementation of client side statements that return * Small fix * Incorporated comments * Added tests for exception in commit and rollback * Fix in tests * Skipping few tests from running in emulator * Few fixes * Refactoring * Incorporated comments * Incorporating comments
1 parent 95b8e74 commit bb5fa1f

11 files changed

+581
-259
lines changed

google/cloud/spanner_dbapi/client_side_statement_executor.py

+60-6
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,27 @@
1515

1616
if TYPE_CHECKING:
1717
from google.cloud.spanner_dbapi import Connection
18+
from google.cloud.spanner_dbapi import ProgrammingError
19+
1820
from google.cloud.spanner_dbapi.parsed_statement import (
1921
ParsedStatement,
2022
ClientSideStatementType,
2123
)
24+
from google.cloud.spanner_v1 import (
25+
Type,
26+
StructType,
27+
TypeCode,
28+
ResultSetMetadata,
29+
PartialResultSet,
30+
)
31+
32+
from google.cloud.spanner_v1._helpers import _make_value_pb
33+
from google.cloud.spanner_v1.streamed import StreamedResultSet
34+
35+
CONNECTION_CLOSED_ERROR = "This connection is closed"
36+
TRANSACTION_NOT_STARTED_WARNING = (
37+
"This method is non-operational as a transaction has not been started."
38+
)
2239

2340

2441
def execute(connection: "Connection", parsed_statement: ParsedStatement):
@@ -32,9 +49,46 @@ def execute(connection: "Connection", parsed_statement: ParsedStatement):
3249
:type parsed_statement: ParsedStatement
3350
:param parsed_statement: parsed_statement based on the sql query
3451
"""
35-
if parsed_statement.client_side_statement_type == ClientSideStatementType.COMMIT:
36-
return connection.commit()
37-
if parsed_statement.client_side_statement_type == ClientSideStatementType.BEGIN:
38-
return connection.begin()
39-
if parsed_statement.client_side_statement_type == ClientSideStatementType.ROLLBACK:
40-
return connection.rollback()
52+
if connection.is_closed:
53+
raise ProgrammingError(CONNECTION_CLOSED_ERROR)
54+
statement_type = parsed_statement.client_side_statement_type
55+
if statement_type == ClientSideStatementType.COMMIT:
56+
connection.commit()
57+
return None
58+
if statement_type == ClientSideStatementType.BEGIN:
59+
connection.begin()
60+
return None
61+
if statement_type == ClientSideStatementType.ROLLBACK:
62+
connection.rollback()
63+
return None
64+
if statement_type == ClientSideStatementType.SHOW_COMMIT_TIMESTAMP:
65+
if connection._transaction is None:
66+
committed_timestamp = None
67+
else:
68+
committed_timestamp = connection._transaction.committed
69+
return _get_streamed_result_set(
70+
ClientSideStatementType.SHOW_COMMIT_TIMESTAMP.name,
71+
TypeCode.TIMESTAMP,
72+
committed_timestamp,
73+
)
74+
if statement_type == ClientSideStatementType.SHOW_READ_TIMESTAMP:
75+
if connection._snapshot is None:
76+
read_timestamp = None
77+
else:
78+
read_timestamp = connection._snapshot._transaction_read_timestamp
79+
return _get_streamed_result_set(
80+
ClientSideStatementType.SHOW_READ_TIMESTAMP.name,
81+
TypeCode.TIMESTAMP,
82+
read_timestamp,
83+
)
84+
85+
86+
def _get_streamed_result_set(column_name, type_code, column_value):
87+
struct_type_pb = StructType(
88+
fields=[StructType.Field(name=column_name, type_=Type(code=type_code))]
89+
)
90+
91+
result_set = PartialResultSet(metadata=ResultSetMetadata(row_type=struct_type_pb))
92+
if column_value is not None:
93+
result_set.values.extend([_make_value_pb(column_value)])
94+
return StreamedResultSet(iter([result_set]))

google/cloud/spanner_dbapi/client_side_statement_parser.py

+16-7
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,12 @@
2323
RE_BEGIN = re.compile(r"^\s*(BEGIN|START)(TRANSACTION)?", re.IGNORECASE)
2424
RE_COMMIT = re.compile(r"^\s*(COMMIT)(TRANSACTION)?", re.IGNORECASE)
2525
RE_ROLLBACK = re.compile(r"^\s*(ROLLBACK)(TRANSACTION)?", re.IGNORECASE)
26+
RE_SHOW_COMMIT_TIMESTAMP = re.compile(
27+
r"^\s*(SHOW)\s+(VARIABLE)\s+(COMMIT_TIMESTAMP)", re.IGNORECASE
28+
)
29+
RE_SHOW_READ_TIMESTAMP = re.compile(
30+
r"^\s*(SHOW)\s+(VARIABLE)\s+(READ_TIMESTAMP)", re.IGNORECASE
31+
)
2632

2733

2834
def parse_stmt(query):
@@ -37,16 +43,19 @@ def parse_stmt(query):
3743
:rtype: ParsedStatement
3844
:returns: ParsedStatement object.
3945
"""
46+
client_side_statement_type = None
4047
if RE_COMMIT.match(query):
41-
return ParsedStatement(
42-
StatementType.CLIENT_SIDE, query, ClientSideStatementType.COMMIT
43-
)
48+
client_side_statement_type = ClientSideStatementType.COMMIT
4449
if RE_BEGIN.match(query):
45-
return ParsedStatement(
46-
StatementType.CLIENT_SIDE, query, ClientSideStatementType.BEGIN
47-
)
50+
client_side_statement_type = ClientSideStatementType.BEGIN
4851
if RE_ROLLBACK.match(query):
52+
client_side_statement_type = ClientSideStatementType.ROLLBACK
53+
if RE_SHOW_COMMIT_TIMESTAMP.match(query):
54+
client_side_statement_type = ClientSideStatementType.SHOW_COMMIT_TIMESTAMP
55+
if RE_SHOW_READ_TIMESTAMP.match(query):
56+
client_side_statement_type = ClientSideStatementType.SHOW_READ_TIMESTAMP
57+
if client_side_statement_type is not None:
4958
return ParsedStatement(
50-
StatementType.CLIENT_SIDE, query, ClientSideStatementType.ROLLBACK
59+
StatementType.CLIENT_SIDE, query, client_side_statement_type
5160
)
5261
return None

google/cloud/spanner_dbapi/connection.py

+38-43
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from google.cloud.spanner_v1 import RequestOptions
2424
from google.cloud.spanner_v1.session import _get_retry_delay
2525
from google.cloud.spanner_v1.snapshot import Snapshot
26+
from deprecated import deprecated
2627

2728
from google.cloud.spanner_dbapi.checksum import _compare_checksums
2829
from google.cloud.spanner_dbapi.checksum import ResultsChecksum
@@ -35,7 +36,7 @@
3536

3637

3738
CLIENT_TRANSACTION_NOT_STARTED_WARNING = (
38-
"This method is non-operational as transaction has not started"
39+
"This method is non-operational as a transaction has not been started."
3940
)
4041
MAX_INTERNAL_RETRIES = 50
4142

@@ -107,6 +108,9 @@ def __init__(self, instance, database=None, read_only=False):
107108
self._staleness = None
108109
self.request_priority = None
109110
self._transaction_begin_marked = False
111+
# whether transaction started at Spanner. This means that we had
112+
# made atleast one call to Spanner.
113+
self._spanner_transaction_started = False
110114

111115
@property
112116
def autocommit(self):
@@ -140,26 +144,15 @@ def database(self):
140144
return self._database
141145

142146
@property
143-
def _spanner_transaction_started(self):
144-
"""Flag: whether transaction started at Spanner. This means that we had
145-
made atleast one call to Spanner. Property client_transaction_started
146-
would always be true if this is true as transaction has to start first
147-
at clientside than at Spanner
148-
149-
Returns:
150-
bool: True if Spanner transaction started, False otherwise.
151-
"""
147+
@deprecated(
148+
reason="This method is deprecated. Use _spanner_transaction_started field"
149+
)
150+
def inside_transaction(self):
152151
return (
153152
self._transaction
154153
and not self._transaction.committed
155154
and not self._transaction.rolled_back
156-
) or (self._snapshot is not None)
157-
158-
@property
159-
def inside_transaction(self):
160-
"""Deprecated property which won't be supported in future versions.
161-
Please use spanner_transaction_started property instead."""
162-
return self._spanner_transaction_started
155+
)
163156

164157
@property
165158
def _client_transaction_started(self):
@@ -277,7 +270,8 @@ def _release_session(self):
277270
"""
278271
if self.database is None:
279272
raise ValueError("Database needs to be passed for this operation")
280-
self.database._pool.put(self._session)
273+
if self._session is not None:
274+
self.database._pool.put(self._session)
281275
self._session = None
282276

283277
def retry_transaction(self):
@@ -293,7 +287,7 @@ def retry_transaction(self):
293287
"""
294288
attempt = 0
295289
while True:
296-
self._transaction = None
290+
self._spanner_transaction_started = False
297291
attempt += 1
298292
if attempt > MAX_INTERNAL_RETRIES:
299293
raise
@@ -319,7 +313,6 @@ def _rerun_previous_statements(self):
319313
status, res = transaction.batch_update(statements)
320314

321315
if status.code == ABORTED:
322-
self.connection._transaction = None
323316
raise Aborted(status.details)
324317

325318
retried_checksum = ResultsChecksum()
@@ -363,6 +356,8 @@ def transaction_checkout(self):
363356
if not self.read_only and self._client_transaction_started:
364357
if not self._spanner_transaction_started:
365358
self._transaction = self._session_checkout().transaction()
359+
self._snapshot = None
360+
self._spanner_transaction_started = True
366361
self._transaction.begin()
367362

368363
return self._transaction
@@ -377,11 +372,13 @@ def snapshot_checkout(self):
377372
:returns: A Cloud Spanner snapshot object, ready to use.
378373
"""
379374
if self.read_only and self._client_transaction_started:
380-
if not self._snapshot:
375+
if not self._spanner_transaction_started:
381376
self._snapshot = Snapshot(
382377
self._session_checkout(), multi_use=True, **self.staleness
383378
)
379+
self._transaction = None
384380
self._snapshot.begin()
381+
self._spanner_transaction_started = True
385382

386383
return self._snapshot
387384

@@ -391,7 +388,7 @@ def close(self):
391388
The connection will be unusable from this point forward. If the
392389
connection has an active transaction, it will be rolled back.
393390
"""
394-
if self._spanner_transaction_started and not self.read_only:
391+
if self._spanner_transaction_started and not self._read_only:
395392
self._transaction.rollback()
396393

397394
if self._own_pool and self.database:
@@ -405,13 +402,15 @@ def begin(self):
405402
Marks the transaction as started.
406403
407404
:raises: :class:`InterfaceError`: if this connection is closed.
408-
:raises: :class:`OperationalError`: if there is an existing transaction that has begin or is running
405+
:raises: :class:`OperationalError`: if there is an existing transaction
406+
that has been started
409407
"""
410408
if self._transaction_begin_marked:
411409
raise OperationalError("A transaction has already started")
412410
if self._spanner_transaction_started:
413411
raise OperationalError(
414-
"Beginning a new transaction is not allowed when a transaction is already running"
412+
"Beginning a new transaction is not allowed when a transaction "
413+
"is already running"
415414
)
416415
self._transaction_begin_marked = True
417416

@@ -430,41 +429,37 @@ def commit(self):
430429
return
431430

432431
self.run_prior_DDL_statements()
433-
if self._spanner_transaction_started:
434-
try:
435-
if self.read_only:
436-
self._snapshot = None
437-
else:
438-
self._transaction.commit()
439-
440-
self._release_session()
441-
self._statements = []
442-
self._transaction_begin_marked = False
443-
except Aborted:
444-
self.retry_transaction()
445-
self.commit()
432+
try:
433+
if self._spanner_transaction_started and not self._read_only:
434+
self._transaction.commit()
435+
except Aborted:
436+
self.retry_transaction()
437+
self.commit()
438+
finally:
439+
self._release_session()
440+
self._statements = []
441+
self._transaction_begin_marked = False
442+
self._spanner_transaction_started = False
446443

447444
def rollback(self):
448445
"""Rolls back any pending transaction.
449446
450447
This is a no-op if there is no active client transaction.
451448
"""
452-
453449
if not self._client_transaction_started:
454450
warnings.warn(
455451
CLIENT_TRANSACTION_NOT_STARTED_WARNING, UserWarning, stacklevel=2
456452
)
457453
return
458454

459-
if self._spanner_transaction_started:
460-
if self.read_only:
461-
self._snapshot = None
462-
else:
455+
try:
456+
if self._spanner_transaction_started and not self._read_only:
463457
self._transaction.rollback()
464-
458+
finally:
465459
self._release_session()
466460
self._statements = []
467461
self._transaction_begin_marked = False
462+
self._spanner_transaction_started = False
468463

469464
@check_not_closed
470465
def cursor(self):

0 commit comments

Comments
 (0)