 import time
 from typing import Any, Dict, List, Optional
 import uuid
-import pyarrow.parquet as pq

 from google.api_core import client_info
 from google.api_core import exceptions
 from google.api_core.gapic_v1 import client_info as v1_client_info
 from google.cloud import bigquery
 from google.cloud import bigquery_storage
 from google.cloud.aiplatform import initializer
-
-from ray.data._internal.execution.interfaces import TaskContext
+from google.cloud.bigquery_storage import types
+import pyarrow.parquet as pq
+from ray.data._internal.remote_fn import cached_remote_fn
 from ray.data.block import Block
 from ray.data.block import BlockAccessor
 from ray.data.block import BlockMetadata
@@ -50,9 +50,6 @@
     gapic_version=_BQS_GAPIC_VERSION, user_agent=f"ray-on-vertex/{_BQS_GAPIC_VERSION}"
 )

-MAX_RETRY_CNT = 10
-RATE_LIMIT_EXCEEDED_SLEEP_TIME = 11
-

 class _BigQueryDatasourceReader(Reader):
     def __init__(
@@ -70,12 +67,12 @@ def __init__(

         if query is not None and dataset is not None:
             raise ValueError(
-                "[Ray on Vertex AI]: Query and dataset kwargs cannot both "
-                + "be provided (must be mutually exclusive)."
+                "[Ray on Vertex AI]: Query and dataset kwargs cannot both be provided (must be mutually exclusive)."
             )

     def get_read_tasks(self, parallelism: int) -> List[ReadTask]:
-        def _read_single_partition(stream) -> Block:
+        # Executed by a worker node
+        def _read_single_partition(stream, kwargs) -> Block:
             client = bigquery_storage.BigQueryReadClient(client_info=bqstorage_info)
             reader = client.read_rows(stream.name)
             return reader.to_arrow()
@@ -99,9 +96,9 @@ def _read_single_partition(stream) -> Block:

         if parallelism == -1:
             parallelism = None
-        requested_session = bigquery_storage.types.ReadSession(
+        requested_session = types.ReadSession(
             table=table,
-            data_format=bigquery_storage.types.DataFormat.ARROW,
+            data_format=types.DataFormat.ARROW,
         )
         read_session = bqs_client.create_read_session(
             parent=f"projects/{self._project_id}",
@@ -110,9 +107,9 @@ def _read_single_partition(stream) -> Block:
         )

         read_tasks = []
-        logging.info(f"Created streams: {len(read_session.streams)}")
+        print("[Ray on Vertex AI]: Created streams:", len(read_session.streams))
         if len(read_session.streams) < parallelism:
-            logging.info(
+            print(
                 "[Ray on Vertex AI]: The number of streams created by the "
                 + "BigQuery Storage Read API is less than the requested "
                 + "parallelism due to the size of the dataset."
@@ -128,11 +125,15 @@ def _read_single_partition(stream) -> Block:
                 exec_stats=None,
             )

-            # Create the read task and pass the no-arg wrapper and metadata in
-            read_task = ReadTask(
-                lambda stream=stream: [_read_single_partition(stream)],
-                metadata,
+            # Create a no-arg wrapper read function which returns a block
+            read_single_partition = (
+                lambda stream=stream, kwargs=self._kwargs: [  # noqa: F731
+                    _read_single_partition(stream, kwargs)
+                ]
             )
+
+            # Create the read task and pass the wrapper and metadata in
+            read_task = ReadTask(read_single_partition, metadata)
             read_tasks.append(read_task)

         return read_tasks
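
Aside on the lambda defaults above (an illustration, not part of the patch): binding stream=stream and kwargs=self._kwargs as default arguments sidesteps Python's late-binding closures, so each ReadTask captures its own stream rather than the last stream assigned in the loop. A minimal standalone sketch of the difference:

    # Without default-argument binding, every closure sees the final loop value.
    late = [lambda: i * 2 for i in range(3)]
    bound = [lambda i=i: i * 2 for i in range(3)]
    print([f() for f in late])   # [4, 4, 4]
    print([f() for f in bound])  # [0, 2, 4]
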
@@ -167,14 +168,18 @@ class BigQueryDatasource(Datasource):
     def create_reader(self, **kwargs) -> Reader:
         return _BigQueryDatasourceReader(**kwargs)

-    def write(
+    def do_write(
         self,
         blocks: List[ObjectRef[Block]],
-        ctx: TaskContext,
+        metadata: List[BlockMetadata],
+        ray_remote_args: Optional[Dict[str, Any]],
         project_id: Optional[str] = None,
         dataset: Optional[str] = None,
-    ) -> WriteResult:
-        def _write_single_block(block: Block, project_id: str, dataset: str):
+    ) -> List[ObjectRef[WriteResult]]:
+        def _write_single_block(
+            block: Block, metadata: BlockMetadata, project_id: str, dataset: str
+        ):
+            print("[Ray on Vertex AI]: Starting to write", metadata.num_rows, "rows")
             block = BlockAccessor.for_block(block).to_arrow()

             client = bigquery.Client(project=project_id, client_info=bq_info)
@@ -187,7 +192,7 @@ def _write_single_block(block: Block, project_id: str, dataset: str):
             pq.write_table(block, fp, compression="SNAPPY")

             retry_cnt = 0
-            while retry_cnt < MAX_RETRY_CNT:
+            while retry_cnt < 10:
                 with open(fp, "rb") as source_file:
                     job = client.load_table_from_file(
                         source_file, dataset, job_config=job_config
@@ -197,11 +202,12 @@ def _write_single_block(block: Block, project_id: str, dataset: str):
                     logging.info(job.result())
                     break
                 except exceptions.Forbidden as e:
-                    logging.info(
+                    print(
                         "[Ray on Vertex AI]: Rate limit exceeded... Sleeping to try again"
                     )
                     logging.debug(e)
-                    time.sleep(RATE_LIMIT_EXCEEDED_SLEEP_TIME)
+                    time.sleep(11)
+            print("[Ray on Vertex AI]: Finished writing", metadata.num_rows, "rows")

         project_id = project_id or initializer.global_config.project

@@ -210,21 +216,34 @@ def _write_single_block(block: Block, project_id: str, dataset: str):
                 "[Ray on Vertex AI]: Dataset is required when writing to BigQuery."
             )

+        if ray_remote_args is None:
+            ray_remote_args = {}
+
+        _write_single_block = cached_remote_fn(_write_single_block).options(
+            **ray_remote_args
+        )
+        write_tasks = []
+
         # Set up datasets to write
         client = bigquery.Client(project=project_id, client_info=bq_info)
         dataset_id = dataset.split(".", 1)[0]
         try:
             client.create_dataset(f"{project_id}.{dataset_id}", timeout=30)
-            logging.info(f"[Ray on Vertex AI]: Created dataset {dataset_id}.")
+            print("[Ray on Vertex AI]: Created dataset", dataset_id)
         except exceptions.Conflict:
-            logging.info(
-                f"[Ray on Vertex AI]: Dataset {dataset_id} already exists. "
-                + "The table will be overwritten if it already exists."
+            print(
+                "[Ray on Vertex AI]: Dataset",
+                dataset_id,
+                "already exists. The table will be overwritten if it already exists.",
             )

         # Delete table if it already exists
         client.delete_table(f"{project_id}.{dataset}", not_found_ok=True)

-        for block in blocks:
-            _write_single_block(block, project_id, dataset)
-        return "ok"
+        print("[Ray on Vertex AI]: Writing", len(blocks), "blocks")
+        for i in range(len(blocks)):
+            write_task = _write_single_block.remote(
+                blocks[i], metadata[i], project_id, dataset
+            )
+            write_tasks.append(write_task)
+        return write_tasks
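
With this change, do_write fans each block out as a Ray task via cached_remote_fn and returns the resulting ObjectRefs for Ray Data to wait on. A rough end-to-end usage sketch follows; the module path, project, and table names are placeholders, and the read_datasource/write_datasource entry points assume a Ray release that still uses the legacy Datasource do_write API:

    import ray
    from bigquery_datasource import BigQueryDatasource  # assumed local module name

    ray.init()

    # Read: keyword args are forwarded to _BigQueryDatasourceReader
    # (dataset and query are mutually exclusive, as enforced above).
    ds = ray.data.read_datasource(
        BigQueryDatasource(),
        parallelism=4,                  # upper bound on requested read streams
        project_id="my-project",        # placeholder
        dataset="my_dataset.my_table",  # placeholder "dataset.table"
    )

    # Write: Ray Data invokes do_write() and waits on the returned ObjectRefs,
    # i.e. effectively ray.get(write_tasks).
    ds.write_datasource(
        BigQueryDatasource(),
        dataset="my_dataset.my_output_table",  # placeholder
    )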