
Commit f3b23b2

feat: Implementation of run partition query (#1080)
* feat: Implementation of run partition query
* Comments incorporated
* Comments incorporated
* Comments incorporated
1 parent ec87c08 commit f3b23b2

10 files changed: +388, -28 lines

google/cloud/spanner_dbapi/client_side_statement_executor.py

+2

@@ -103,6 +103,8 @@ def execute(cursor: "Cursor", parsed_statement: ParsedStatement):
         return connection.run_partition(
             parsed_statement.client_side_statement_params[0]
         )
+    if statement_type == ClientSideStatementType.RUN_PARTITIONED_QUERY:
+        return connection.run_partitioned_query(parsed_statement)
 
 
 def _get_streamed_result_set(column_name, type_code, column_values):

google/cloud/spanner_dbapi/client_side_statement_parser.py

+17 -10

@@ -35,6 +35,9 @@
 RE_ABORT_BATCH = re.compile(r"^\s*(ABORT)\s+(BATCH)", re.IGNORECASE)
 RE_PARTITION_QUERY = re.compile(r"^\s*(PARTITION)\s+(.+)", re.IGNORECASE)
 RE_RUN_PARTITION = re.compile(r"^\s*(RUN)\s+(PARTITION)\s+(.+)", re.IGNORECASE)
+RE_RUN_PARTITIONED_QUERY = re.compile(
+    r"^\s*(RUN)\s+(PARTITIONED)\s+(QUERY)\s+(.+)", re.IGNORECASE
+)
 
 
 def parse_stmt(query):
@@ -53,25 +56,29 @@ def parse_stmt(query):
     client_side_statement_params = []
     if RE_COMMIT.match(query):
         client_side_statement_type = ClientSideStatementType.COMMIT
-    if RE_BEGIN.match(query):
-        client_side_statement_type = ClientSideStatementType.BEGIN
-    if RE_ROLLBACK.match(query):
+    elif RE_ROLLBACK.match(query):
         client_side_statement_type = ClientSideStatementType.ROLLBACK
-    if RE_SHOW_COMMIT_TIMESTAMP.match(query):
+    elif RE_SHOW_COMMIT_TIMESTAMP.match(query):
         client_side_statement_type = ClientSideStatementType.SHOW_COMMIT_TIMESTAMP
-    if RE_SHOW_READ_TIMESTAMP.match(query):
+    elif RE_SHOW_READ_TIMESTAMP.match(query):
         client_side_statement_type = ClientSideStatementType.SHOW_READ_TIMESTAMP
-    if RE_START_BATCH_DML.match(query):
+    elif RE_START_BATCH_DML.match(query):
         client_side_statement_type = ClientSideStatementType.START_BATCH_DML
-    if RE_RUN_BATCH.match(query):
+    elif RE_BEGIN.match(query):
+        client_side_statement_type = ClientSideStatementType.BEGIN
+    elif RE_RUN_BATCH.match(query):
         client_side_statement_type = ClientSideStatementType.RUN_BATCH
-    if RE_ABORT_BATCH.match(query):
+    elif RE_ABORT_BATCH.match(query):
         client_side_statement_type = ClientSideStatementType.ABORT_BATCH
-    if RE_PARTITION_QUERY.match(query):
+    elif RE_RUN_PARTITIONED_QUERY.match(query):
+        match = re.search(RE_RUN_PARTITIONED_QUERY, query)
+        client_side_statement_params.append(match.group(4))
+        client_side_statement_type = ClientSideStatementType.RUN_PARTITIONED_QUERY
+    elif RE_PARTITION_QUERY.match(query):
         match = re.search(RE_PARTITION_QUERY, query)
         client_side_statement_params.append(match.group(2))
         client_side_statement_type = ClientSideStatementType.PARTITION_QUERY
-    if RE_RUN_PARTITION.match(query):
+    elif RE_RUN_PARTITION.match(query):
         match = re.search(RE_RUN_PARTITION, query)
         client_side_statement_params.append(match.group(3))
         client_side_statement_type = ClientSideStatementType.RUN_PARTITION
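
The wrapped SQL is captured by group(4) of the new regex and becomes the single client-side statement parameter. A quick standalone check of the pattern (the SQL text is illustrative):

import re

RE_RUN_PARTITIONED_QUERY = re.compile(
    r"^\s*(RUN)\s+(PARTITIONED)\s+(QUERY)\s+(.+)", re.IGNORECASE
)

match = RE_RUN_PARTITIONED_QUERY.match("run partitioned query SELECT * FROM Singers")
assert match is not None
# group(4) is what parse_stmt appends to client_side_statement_params
assert match.group(4) == "SELECT * FROM Singers"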

google/cloud/spanner_dbapi/connection.py

+28 -12

@@ -511,15 +511,7 @@ def partition_query(
     ):
         statement = parsed_statement.statement
         partitioned_query = parsed_statement.client_side_statement_params[0]
-        if _get_statement_type(Statement(partitioned_query)) is not StatementType.QUERY:
-            raise ProgrammingError(
-                "Only queries can be partitioned. Invalid statement: " + statement.sql
-            )
-        if self.read_only is not True and self._client_transaction_started is True:
-            raise ProgrammingError(
-                "Partitioned query not supported as the connection is not in "
-                "read only mode or ReadWrite transaction started"
-            )
+        self._partitioned_query_validation(partitioned_query, statement)
 
         batch_snapshot = self._database.batch_snapshot()
         partition_ids = []
@@ -531,17 +523,18 @@ def partition_query(
                 query_options=query_options,
             )
         )
+
+        batch_transaction_id = batch_snapshot.get_batch_transaction_id()
         for partition in partitions:
-            batch_transaction_id = batch_snapshot.get_batch_transaction_id()
             partition_ids.append(
                 partition_helper.encode_to_string(batch_transaction_id, partition)
             )
         return partition_ids
 
     @check_not_closed
-    def run_partition(self, batch_transaction_id):
+    def run_partition(self, encoded_partition_id):
         partition_id: PartitionId = partition_helper.decode_from_string(
-            batch_transaction_id
+            encoded_partition_id
         )
         batch_transaction_id = partition_id.batch_transaction_id
         batch_snapshot = self._database.batch_snapshot(
@@ -551,6 +544,29 @@ def run_partition(self, batch_transaction_id):
         )
         return batch_snapshot.process(partition_id.partition_result)
 
+    @check_not_closed
+    def run_partitioned_query(
+        self,
+        parsed_statement: ParsedStatement,
+    ):
+        statement = parsed_statement.statement
+        partitioned_query = parsed_statement.client_side_statement_params[0]
+        self._partitioned_query_validation(partitioned_query, statement)
+        batch_snapshot = self._database.batch_snapshot()
+        return batch_snapshot.run_partitioned_query(
+            partitioned_query, statement.params, statement.param_types
+        )
+
+    def _partitioned_query_validation(self, partitioned_query, statement):
+        if _get_statement_type(Statement(partitioned_query)) is not StatementType.QUERY:
+            raise ProgrammingError(
+                "Only queries can be partitioned. Invalid statement: " + statement.sql
+            )
+        if self.read_only is not True and self._client_transaction_started is True:
+            raise ProgrammingError(
+                "Partitioned query is not supported because the connection is in a read/write transaction."
+            )
+
     def __enter__(self):
         return self
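
Taken together with the executor and parser changes, this enables the statement end to end through the DB-API. A hypothetical usage sketch (instance and database names are placeholders; assumes credentials and an existing Spanner database):

from google.cloud.spanner_dbapi import connect

connection = connect("my-instance", "my-database")
connection.read_only = True  # validation rejects active read/write transactions
cursor = connection.cursor()

# The statement is parsed client-side; the wrapped SQL is partitioned and
# executed across threads, and the cursor iterates the merged rows.
cursor.execute("RUN PARTITIONED QUERY SELECT SingerId, FirstName FROM Singers")
for row in cursor:
    print(row)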

google/cloud/spanner_dbapi/cursor.py

+4 -1

@@ -49,6 +49,7 @@
 from google.cloud.spanner_dbapi.transaction_helper import CursorStatementType
 from google.cloud.spanner_dbapi.utils import PeekIterator
 from google.cloud.spanner_dbapi.utils import StreamedManyResultSets
+from google.cloud.spanner_v1.merged_result_set import MergedResultSet
 
 ColumnDetails = namedtuple("column_details", ["null_ok", "spanner_type"])
 
@@ -248,7 +249,9 @@ def _execute(self, sql, args=None, call_from_execute_many=False):
                 self, self._parsed_statement
             )
             if self._result_set is not None:
-                if isinstance(self._result_set, StreamedManyResultSets):
+                if isinstance(
+                    self._result_set, StreamedManyResultSets
+                ) or isinstance(self._result_set, MergedResultSet):
                     self._itr = self._result_set
                 else:
                     self._itr = PeekIterator(self._result_set)
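
A minor style note on the widened check: isinstance accepts a tuple of types, so isinstance(self._result_set, (StreamedManyResultSets, MergedResultSet)) would express the same condition more compactly. A standalone sanity check of the equivalence (stand-in classes, not the Spanner result-set types):

class StreamedStandIn: ...
class MergedStandIn: ...

for obj in (StreamedStandIn(), MergedStandIn(), object()):
    chained = isinstance(obj, StreamedStandIn) or isinstance(obj, MergedStandIn)
    combined = isinstance(obj, (StreamedStandIn, MergedStandIn))
    assert chained == combined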

google/cloud/spanner_dbapi/parsed_statement.py

+1

@@ -35,6 +35,7 @@ class ClientSideStatementType(Enum):
     ABORT_BATCH = 8
     PARTITION_QUERY = 9
     RUN_PARTITION = 10
+    RUN_PARTITIONED_QUERY = 11
 
 
 @dataclass

google/cloud/spanner_v1/database.py

+67 -5

@@ -54,6 +54,7 @@
 from google.cloud.spanner_v1.batch import Batch
 from google.cloud.spanner_v1.batch import MutationGroups
 from google.cloud.spanner_v1.keyset import KeySet
+from google.cloud.spanner_v1.merged_result_set import MergedResultSet
 from google.cloud.spanner_v1.pool import BurstyPool
 from google.cloud.spanner_v1.pool import SessionCheckout
 from google.cloud.spanner_v1.session import Session
@@ -1416,11 +1417,6 @@ def generate_query_batches(
             (Optional) desired size for each partition generated. The service
             uses this as a hint, the actual partition size may differ.
 
-        :type partition_size_bytes: int
-        :param partition_size_bytes:
-            (Optional) desired size for each partition generated. The service
-            uses this as a hint, the actual partition size may differ.
-
         :type max_partitions: int
         :param max_partitions:
             (Optional) desired maximum number of partitions generated. The
@@ -1513,6 +1509,72 @@ def process_query_batch(
             partition=batch["partition"], **batch["query"], retry=retry, timeout=timeout
         )
 
+    def run_partitioned_query(
+        self,
+        sql,
+        params=None,
+        param_types=None,
+        partition_size_bytes=None,
+        max_partitions=None,
+        query_options=None,
+        data_boost_enabled=False,
+    ):
+        """Start a partitioned query operation to get a list of partitions,
+        then execute each partition on a separate thread.
+
+        :type sql: str
+        :param sql: SQL query statement
+
+        :type params: dict, {str -> column value}
+        :param params: values for parameter replacement. Keys must match
+            the names used in ``sql``.
+
+        :type param_types: dict[str -> Union[dict, .types.Type]]
+        :param param_types:
+            (Optional) maps explicit types for one or more param values;
+            required if parameters are passed.
+
+        :type partition_size_bytes: int
+        :param partition_size_bytes:
+            (Optional) desired size for each partition generated. The service
+            uses this as a hint; the actual partition size may differ.
+
+        :type max_partitions: int
+        :param max_partitions:
+            (Optional) desired maximum number of partitions generated. The
+            service uses this as a hint; the actual number of partitions may
+            differ.
+
+        :type query_options:
+            :class:`~google.cloud.spanner_v1.types.ExecuteSqlRequest.QueryOptions`
+            or :class:`dict`
+        :param query_options:
+            (Optional) Query optimizer configuration to use for the given query.
+            If a dict is provided, it must be of the same form as the protobuf
+            message :class:`~google.cloud.spanner_v1.types.QueryOptions`
+
+        :type data_boost_enabled: bool
+        :param data_boost_enabled:
+            (Optional) If this field is set to ``true``, the request is
+            executed using Data Boost.
+            See https://cloud.google.com/spanner/docs/databoost/databoost-overview
+
+        :rtype: :class:`~google.cloud.spanner_v1.merged_result_set.MergedResultSet`
+        :returns: a result set instance which can be used to consume rows.
+        """
+        partitions = list(
+            self.generate_query_batches(
+                sql,
+                params,
+                param_types,
+                partition_size_bytes,
+                max_partitions,
+                query_options,
+                data_boost_enabled,
+            )
+        )
+        return MergedResultSet(self, partitions, 0)
+
     def process(self, batch):
         """Process a single, partitioned query or read.
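
The same capability is reachable directly at the database level, without the DB-API layer. A hypothetical usage sketch (names and schema are placeholders; assumes credentials and an existing database):

from google.cloud import spanner
from google.cloud.spanner_v1 import param_types

client = spanner.Client()
database = client.instance("my-instance").database("my-database")

# Rows from all partitions are merged; ordering across partitions is not
# guaranteed.
results = database.run_partitioned_query(
    "SELECT SingerId, FirstName FROM Singers WHERE active = @active",
    params={"active": True},
    param_types={"active": param_types.BOOL},
)
for row in results:
    print(row)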
google/cloud/spanner_v1/merged_result_set.py

+133

@@ -0,0 +1,133 @@
+# Copyright 2024 Google LLC All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from concurrent.futures import ThreadPoolExecutor
+from dataclasses import dataclass
+from queue import Queue
+from typing import Any, TYPE_CHECKING
+from threading import Lock, Event
+
+if TYPE_CHECKING:
+    from google.cloud.spanner_v1.database import BatchSnapshot
+
+QUEUE_SIZE_PER_WORKER = 32
+MAX_PARALLELISM = 16
+
+
+class PartitionExecutor:
+    """
+    Executor that executes a single partition on a separate thread and inserts
+    rows into the queue
+    """
+
+    def __init__(self, batch_snapshot, partition_id, merged_result_set):
+        self._batch_snapshot: BatchSnapshot = batch_snapshot
+        self._partition_id = partition_id
+        self._merged_result_set: MergedResultSet = merged_result_set
+        self._queue: Queue[PartitionExecutorResult] = merged_result_set._queue
+
+    def run(self):
+        results = None
+        try:
+            results = self._batch_snapshot.process_query_batch(self._partition_id)
+            for row in results:
+                if self._merged_result_set._metadata is None:
+                    self._set_metadata(results)
+                self._queue.put(PartitionExecutorResult(data=row))
+            # Special case: The result set did not return any rows.
+            # Push the metadata to the merged result set.
+            if self._merged_result_set._metadata is None:
+                self._set_metadata(results)
+        except Exception as ex:
+            if self._merged_result_set._metadata is None:
+                self._set_metadata(results, True)
+            self._queue.put(PartitionExecutorResult(exception=ex))
+        finally:
+            # Emit a special 'is_last' result to ensure that the MergedResultSet
+            # is not blocked on a queue that never receives any more results.
+            self._queue.put(PartitionExecutorResult(is_last=True))
+
+    def _set_metadata(self, results, is_exception=False):
+        self._merged_result_set.metadata_lock.acquire()
+        try:
+            if not is_exception:
+                self._merged_result_set._metadata = results.metadata
+        finally:
+            self._merged_result_set.metadata_lock.release()
+        self._merged_result_set.metadata_event.set()
+
+
+@dataclass
+class PartitionExecutorResult:
+    data: Any = None
+    exception: Exception = None
+    is_last: bool = False
+
+
+class MergedResultSet:
+    """
+    Executes multiple partitions on different threads and then combines the
+    results from multiple queries using a synchronized queue. The order of the
+    records in the MergedResultSet is not guaranteed.
+    """
+
+    def __init__(self, batch_snapshot, partition_ids, max_parallelism):
+        self._exception = None
+        self._metadata = None
+        self.metadata_event = Event()
+        self.metadata_lock = Lock()
+
+        partition_ids_count = len(partition_ids)
+        self._finished_count_down_latch = partition_ids_count
+        parallelism = min(MAX_PARALLELISM, partition_ids_count)
+        if max_parallelism != 0:
+            parallelism = min(partition_ids_count, max_parallelism)
+        self._queue = Queue(maxsize=QUEUE_SIZE_PER_WORKER * parallelism)
+
+        partition_executors = []
+        for partition_id in partition_ids:
+            partition_executors.append(
+                PartitionExecutor(batch_snapshot, partition_id, self)
+            )
+        executor = ThreadPoolExecutor(max_workers=parallelism)
+        for partition_executor in partition_executors:
+            executor.submit(partition_executor.run)
+        executor.shutdown(False)
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        if self._exception is not None:
+            raise self._exception
+        while True:
+            partition_result = self._queue.get()
+            if partition_result.is_last:
+                self._finished_count_down_latch -= 1
+                if self._finished_count_down_latch == 0:
+                    raise StopIteration
+            elif partition_result.exception is not None:
+                self._exception = partition_result.exception
+                raise self._exception
+            else:
+                return partition_result.data
+
+    @property
+    def metadata(self):
+        self.metadata_event.wait()
+        return self._metadata
+
+    @property
+    def stats(self):
+        # TODO: Implement
+        return None
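
The new file is a bounded producer-consumer fan-in: each PartitionExecutor thread enqueues rows followed by a final is_last sentinel, and __next__ stops once it has seen one sentinel per partition. A minimal standalone sketch of that pattern (simplified; no Spanner types involved):

from queue import Queue
from threading import Thread

def producer(queue, rows):
    for row in rows:
        queue.put(("row", row))
    queue.put(("last", None))  # the 'is_last' sentinel

queue = Queue(maxsize=32)  # bounded, like QUEUE_SIZE_PER_WORKER * parallelism
partitions = [[1, 2], [3], [4, 5, 6]]
for rows in partitions:
    Thread(target=producer, args=(queue, rows)).start()

remaining = len(partitions)  # the count-down latch
merged = []
while remaining:
    kind, value = queue.get()
    if kind == "last":
        remaining -= 1
    else:
        merged.append(value)
print(sorted(merged))  # order across partitions is not guaranteed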
