
Commit dc1b82a

matthew29tang authored and copybara-github committed

fix: Rollback BigQuery Datasource to use do_write() interface

PiperOrigin-RevId: 577245702

1 parent 1def3f6 · commit dc1b82a

2 files changed: +72 −53
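
For context on the rollback: the newer Ray Data Datasource API has the datasource implement write(blocks, ctx), which Ray invokes from write tasks it schedules itself, while the legacy interface restored here has do_write() create one Ray task per block via cached_remote_fn and return the resulting object refs for Ray to wait on. A minimal usage sketch under the rolled-back interface; it assumes a Ray version whose Dataset.write_datasource() still dispatches to do_write(), and the project, dataset, and table names are placeholders:

import ray

from google.cloud.aiplatform.preview.vertex_ray.bigquery_datasource import (
    BigQueryDatasource,
)

# Read rows from BigQuery into a Ray Dataset (the query and dataset kwargs are
# mutually exclusive, per _BigQueryDatasourceReader).
ds = ray.data.read_datasource(
    BigQueryDatasource(),
    query="SELECT * FROM `my-project.my_dataset.src_table`",  # placeholder query
)

# Write the blocks back out; on legacy Ray versions this routes through
# BigQueryDatasource.do_write(), launching one load job per block.
ds.write_datasource(
    BigQueryDatasource(),
    dataset="my_dataset.dst_table",  # placeholder "dataset.table" target
)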

google/cloud/aiplatform/preview/vertex_ray/bigquery_datasource.py (+50 −31)

@@ -21,16 +21,16 @@
 import time
 from typing import Any, Dict, List, Optional
 import uuid
-import pyarrow.parquet as pq
 
 from google.api_core import client_info
 from google.api_core import exceptions
 from google.api_core.gapic_v1 import client_info as v1_client_info
 from google.cloud import bigquery
 from google.cloud import bigquery_storage
 from google.cloud.aiplatform import initializer
-
-from ray.data._internal.execution.interfaces import TaskContext
+from google.cloud.bigquery_storage import types
+import pyarrow.parquet as pq
+from ray.data._internal.remote_fn import cached_remote_fn
 from ray.data.block import Block
 from ray.data.block import BlockAccessor
 from ray.data.block import BlockMetadata
@@ -50,9 +50,6 @@
     gapic_version=_BQS_GAPIC_VERSION, user_agent=f"ray-on-vertex/{_BQS_GAPIC_VERSION}"
 )
 
-MAX_RETRY_CNT = 10
-RATE_LIMIT_EXCEEDED_SLEEP_TIME = 11
-
 
 class _BigQueryDatasourceReader(Reader):
     def __init__(
@@ -70,12 +67,12 @@ def __init__(
 
         if query is not None and dataset is not None:
             raise ValueError(
-                "[Ray on Vertex AI]: Query and dataset kwargs cannot both "
-                + "be provided (must be mutually exclusive)."
+                "[Ray on Vertex AI]: Query and dataset kwargs cannot both be provided (must be mutually exclusive)."
             )
 
     def get_read_tasks(self, parallelism: int) -> List[ReadTask]:
-        def _read_single_partition(stream) -> Block:
+        # Executed by a worker node
+        def _read_single_partition(stream, kwargs) -> Block:
             client = bigquery_storage.BigQueryReadClient(client_info=bqstorage_info)
             reader = client.read_rows(stream.name)
             return reader.to_arrow()
@@ -99,9 +96,9 @@ def _read_single_partition(stream) -> Block:
 
         if parallelism == -1:
             parallelism = None
-        requested_session = bigquery_storage.types.ReadSession(
+        requested_session = types.ReadSession(
             table=table,
-            data_format=bigquery_storage.types.DataFormat.ARROW,
+            data_format=types.DataFormat.ARROW,
         )
         read_session = bqs_client.create_read_session(
             parent=f"projects/{self._project_id}",
@@ -110,9 +107,9 @@ def _read_single_partition(stream) -> Block:
         )
 
         read_tasks = []
-        logging.info(f"Created streams: {len(read_session.streams)}")
+        print("[Ray on Vertex AI]: Created streams:", len(read_session.streams))
         if len(read_session.streams) < parallelism:
-            logging.info(
+            print(
                 "[Ray on Vertex AI]: The number of streams created by the "
                 + "BigQuery Storage Read API is less than the requested "
                 + "parallelism due to the size of the dataset."
@@ -128,11 +125,15 @@ def _read_single_partition(stream) -> Block:
                 exec_stats=None,
             )
 
-            # Create the read task and pass the no-arg wrapper and metadata in
-            read_task = ReadTask(
-                lambda stream=stream: [_read_single_partition(stream)],
-                metadata,
+            # Create a no-arg wrapper read function which returns a block
+            read_single_partition = (
+                lambda stream=stream, kwargs=self._kwargs: [  # noqa: F731
+                    _read_single_partition(stream, kwargs)
+                ]
             )
+
+            # Create the read task and pass the wrapper and metadata in
+            read_task = ReadTask(read_single_partition, metadata)
             read_tasks.append(read_task)
 
         return read_tasks
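
A side note on the wrapper above: binding stream=stream (and kwargs=self._kwargs) as lambda default arguments is what makes each read task capture its own stream. A plain closure defined inside the loop would be subject to Python's late binding, and every task would end up reading the last stream. A small self-contained illustration of the difference, unrelated to BigQuery:

# Late binding: every callback sees the loop variable's final value.
late = [lambda: i for i in range(3)]
print([f() for f in late])  # [2, 2, 2]

# Default-argument binding (the pattern used above) captures the value
# at definition time, so each no-arg callable keeps its own copy.
bound = [lambda i=i: i for i in range(3)]
print([f() for f in bound])  # [0, 1, 2]
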
@@ -167,14 +168,18 @@ class BigQueryDatasource(Datasource):
     def create_reader(self, **kwargs) -> Reader:
         return _BigQueryDatasourceReader(**kwargs)
 
-    def write(
+    def do_write(
         self,
         blocks: List[ObjectRef[Block]],
-        ctx: TaskContext,
+        metadata: List[BlockMetadata],
+        ray_remote_args: Optional[Dict[str, Any]],
         project_id: Optional[str] = None,
         dataset: Optional[str] = None,
-    ) -> WriteResult:
-        def _write_single_block(block: Block, project_id: str, dataset: str):
+    ) -> List[ObjectRef[WriteResult]]:
+        def _write_single_block(
+            block: Block, metadata: BlockMetadata, project_id: str, dataset: str
+        ):
+            print("[Ray on Vertex AI]: Starting to write", metadata.num_rows, "rows")
             block = BlockAccessor.for_block(block).to_arrow()
 
             client = bigquery.Client(project=project_id, client_info=bq_info)
@@ -187,7 +192,7 @@ def _write_single_block(block: Block, project_id: str, dataset: str):
                 pq.write_table(block, fp, compression="SNAPPY")
 
                 retry_cnt = 0
-                while retry_cnt < MAX_RETRY_CNT:
+                while retry_cnt < 10:
                     with open(fp, "rb") as source_file:
                         job = client.load_table_from_file(
                             source_file, dataset, job_config=job_config
@@ -197,11 +202,12 @@ def _write_single_block(block: Block, project_id: str, dataset: str):
                         logging.info(job.result())
                         break
                     except exceptions.Forbidden as e:
-                        logging.info(
+                        print(
                             "[Ray on Vertex AI]: Rate limit exceeded... Sleeping to try again"
                         )
                         logging.debug(e)
-                        time.sleep(RATE_LIMIT_EXCEEDED_SLEEP_TIME)
+                        time.sleep(11)
+            print("[Ray on Vertex AI]: Finished writing", metadata.num_rows, "rows")
 
         project_id = project_id or initializer.global_config.project
 
@@ -210,21 +216,34 @@ def _write_single_block(block: Block, project_id: str, dataset: str):
                 "[Ray on Vertex AI]: Dataset is required when writing to BigQuery."
            )
 
+        if ray_remote_args is None:
+            ray_remote_args = {}
+
+        _write_single_block = cached_remote_fn(_write_single_block).options(
+            **ray_remote_args
+        )
+        write_tasks = []
+
         # Set up datasets to write
         client = bigquery.Client(project=project_id, client_info=bq_info)
         dataset_id = dataset.split(".", 1)[0]
         try:
             client.create_dataset(f"{project_id}.{dataset_id}", timeout=30)
-            logging.info(f"[Ray on Vertex AI]: Created dataset {dataset_id}.")
+            print("[Ray on Vertex AI]: Created dataset", dataset_id)
         except exceptions.Conflict:
-            logging.info(
-                f"[Ray on Vertex AI]: Dataset {dataset_id} already exists. "
-                + "The table will be overwritten if it already exists."
+            print(
+                "[Ray on Vertex AI]: Dataset",
+                dataset_id,
+                "already exists. The table will be overwritten if it already exists.",
             )
 
         # Delete table if it already exists
         client.delete_table(f"{project_id}.{dataset}", not_found_ok=True)
 
-        for block in blocks:
-            _write_single_block(block, project_id, dataset)
-        return "ok"
+        print("[Ray on Vertex AI]: Writing", len(blocks), "blocks")
+        for i in range(len(blocks)):
+            write_task = _write_single_block.remote(
+                blocks[i], metadata[i], project_id, dataset
+            )
+            write_tasks.append(write_task)
+        return write_tasks
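
In the restored do_write() above, cached_remote_fn wraps _write_single_block as a Ray remote function, .options(**ray_remote_args) applies the caller-supplied scheduling options, and .remote(...) launches one write task per block; the returned object refs are what Ray (or a caller) then waits on. A rough sketch of that dispatch-and-wait pattern with a stand-in function; cached_remote_fn is a Ray Data internal helper, so treat the import path and behavior as assumptions tied to the Ray version:

import ray
from ray.data._internal.remote_fn import cached_remote_fn


def _write_one(block_id: int) -> str:
    # Stand-in for _write_single_block's BigQuery load job.
    return f"wrote block {block_id}"


# Wrap once, apply remote args, then launch one task per block, mirroring do_write().
write_one = cached_remote_fn(_write_one).options(num_cpus=1)
write_tasks = [write_one.remote(i) for i in range(4)]

# The legacy write path (or a caller) blocks on the refs that do_write() returns.
print(ray.get(write_tasks))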

tests/unit/vertex_ray/test_bigquery.py (+22 −22)

@@ -27,7 +27,6 @@
 from google.cloud.bigquery import job
 from google.cloud.bigquery_storage_v1.types import stream as gcbqs_stream
 import mock
-import pyarrow as pa
 import pytest
 import ray
 
@@ -90,6 +89,7 @@ def bq_query_mock(query):
     client_mock.query = bq_query_mock
 
     monkeypatch.setattr(bigquery, "Client", client_mock)
+    client_mock.reset_mock()
     return client_mock
 
 
@@ -108,6 +108,7 @@ def bqs_create_read_session(max_stream_count=0, **kwargs):
     client_mock.create_read_session = bqs_create_read_session
 
     monkeypatch.setattr(bigquery_storage, "BigQueryReadClient", client_mock)
+    client_mock.reset_mock()
     return client_mock
 
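The client_mock.reset_mock() added to both fixtures above clears any calls recorded while the mock was being configured (or left over from a previous test using the same patched class), so per-test assertions on call counts start from zero. A tiny standalone illustration of what reset_mock() does:

import mock

client_mock = mock.Mock()
client_mock.create_dataset("demo_dataset")      # call recorded during setup
assert client_mock.create_dataset.call_count == 1

client_mock.reset_mock()                        # what the fixtures now do
assert client_mock.create_dataset.call_count == 0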

@@ -258,16 +259,16 @@ def setup_method(self):
     def teardown_method(self):
         aiplatform.initializer.global_pool.shutdown(wait=True)
 
-    def test_write(self):
+    def test_do_write(self, ray_remote_function_mock):
         bq_ds = bigquery_datasource.BigQueryDatasource()
-        arr = pa.array([2, 4, 5, 100])
-        block = pa.Table.from_arrays([arr], names=["data"])
-        status = bq_ds.write(
-            blocks=[block],
-            ctx=None,
+        write_tasks_list = bq_ds.do_write(
+            blocks=[1, 2, 3, 4],
+            metadata=[1, 2, 3, 4],
+            ray_remote_args={},
+            project_id=tc.ProjectConstants._TEST_GCP_PROJECT_ID,
             dataset=_TEST_BQ_DATASET,
         )
-        assert status == "ok"
+        assert len(write_tasks_list) == 4
 
     def test_do_write_initialized(self, ray_remote_function_mock):
         """If initialized, do_write doesn't need to specify project_id."""
@@ -276,22 +277,21 @@ def test_do_write_initialized(self, ray_remote_function_mock):
             staging_bucket=tc.ProjectConstants._TEST_ARTIFACT_URI,
         )
         bq_ds = bigquery_datasource.BigQueryDatasource()
-        arr = pa.array([2, 4, 5, 100])
-        block = pa.Table.from_arrays([arr], names=["data"])
-        status = bq_ds.write(
-            blocks=[block],
-            ctx=None,
-            dataset="existingdataset" + "." + _TEST_BQ_TABLE_ID,
+        write_tasks_list = bq_ds.do_write(
+            blocks=[1, 2, 3, 4],
+            metadata=[1, 2, 3, 4],
+            ray_remote_args={},
+            dataset=_TEST_BQ_DATASET,
         )
-        assert status == "ok"
+        assert len(write_tasks_list) == 4
 
-    def test_write_dataset_exists(self, ray_remote_function_mock):
+    def test_do_write_dataset_exists(self, ray_remote_function_mock):
         bq_ds = bigquery_datasource.BigQueryDatasource()
-        arr = pa.array([2, 4, 5, 100])
-        block = pa.Table.from_arrays([arr], names=["data"])
-        status = bq_ds.write(
-            blocks=[block],
-            ctx=None,
+        write_tasks_list = bq_ds.do_write(
+            blocks=[1, 2, 3, 4],
+            metadata=[1, 2, 3, 4],
+            ray_remote_args={},
+            project_id=tc.ProjectConstants._TEST_GCP_PROJECT_ID,
             dataset="existingdataset" + "." + _TEST_BQ_TABLE_ID,
         )
-        assert status == "ok"
+        assert len(write_tasks_list) == 4
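
The rewritten tests can pass plain integers as blocks and metadata because they rely on the ray_remote_function_mock fixture (not shown in this diff), which presumably stubs out the Ray remote-function machinery so _write_single_block never actually runs; the assertions only check that do_write() creates one task per block. A hypothetical fixture in that spirit, purely illustrative of the assumption:

import mock
import pytest


@pytest.fixture
def ray_remote_function_mock():
    # Hypothetical: patch cached_remote_fn so .options(...).remote(...) returns a
    # dummy ref instead of launching a Ray task; the real fixture may differ.
    with mock.patch(
        "google.cloud.aiplatform.preview.vertex_ray.bigquery_datasource.cached_remote_fn"
    ) as remote_fn_mock:
        remote_fn_mock.return_value.options.return_value.remote.return_value = None
        yield remote_fn_mock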
