|
19 | 19 | from typing import cast, Literal, Optional, Union
|
20 | 20 | import warnings
|
21 | 21 |
|
| 22 | +from google.cloud import bigquery |
| 23 | + |
22 | 24 | import bigframes
|
23 | 25 | from bigframes import clients, constants
|
24 | 26 | from bigframes.core import blocks, log_adapter
|
@@ -113,6 +115,26 @@ def _create_bqml_model(self):
|
113 | 115 | session=self.session, connection_name=self.connection_name, options=options
|
114 | 116 | )
|
115 | 117 |
|
| 118 | + @classmethod |
| 119 | + def _from_bq( |
| 120 | + cls, session: bigframes.Session, model: bigquery.Model |
| 121 | + ) -> PaLM2TextGenerator: |
| 122 | + assert model.model_type == "MODEL_TYPE_UNSPECIFIED" |
| 123 | + assert "remoteModelInfo" in model._properties |
| 124 | + assert "endpoint" in model._properties["remoteModelInfo"] |
| 125 | + assert "connection" in model._properties["remoteModelInfo"] |
| 126 | + |
| 127 | + # Parse the remote model endpoint |
| 128 | + bqml_endpoint = model._properties["remoteModelInfo"]["endpoint"] |
| 129 | + model_connection = model._properties["remoteModelInfo"]["connection"] |
| 130 | + model_endpoint = bqml_endpoint.split("/")[-1] |
| 131 | + |
| 132 | + text_generator_model = cls( |
| 133 | + session=session, model_name=model_endpoint, connection_name=model_connection |
| 134 | + ) |
| 135 | + text_generator_model._bqml_model = core.BqmlModel(session, model) |
| 136 | + return text_generator_model |
| 137 | + |
116 | 138 | def predict(
|
117 | 139 | self,
|
118 | 140 | X: Union[bpd.DataFrame, bpd.Series],
|
@@ -200,6 +222,21 @@ def predict(
|
200 | 222 |
|
201 | 223 | return df
|
202 | 224 |
|
| 225 | + def to_gbq(self, model_name: str, replace: bool = False) -> PaLM2TextGenerator: |
| 226 | + """Save the model to BigQuery. |
| 227 | +
|
| 228 | + Args: |
| 229 | + model_name (str): |
| 230 | + the name of the model. |
| 231 | + replace (bool, default False): |
| 232 | + whether to replace if the model already exists. Default to False. |
| 233 | +
|
| 234 | + Returns: |
| 235 | + PaLM2TextGenerator: saved model.""" |
| 236 | + |
| 237 | + new_model = self._bqml_model.copy(model_name, replace) |
| 238 | + return new_model.session.read_gbq_model(model_name) |
| 239 | + |
203 | 240 |
|
204 | 241 | @log_adapter.class_logger
|
205 | 242 | class PaLM2TextEmbeddingGenerator(base.Predictor):
|
@@ -271,6 +308,26 @@ def _create_bqml_model(self):
|
271 | 308 | session=self.session, connection_name=self.connection_name, options=options
|
272 | 309 | )
|
273 | 310 |
|
| 311 | + @classmethod |
| 312 | + def _from_bq( |
| 313 | + cls, session: bigframes.Session, model: bigquery.Model |
| 314 | + ) -> PaLM2TextEmbeddingGenerator: |
| 315 | + assert model.model_type == "MODEL_TYPE_UNSPECIFIED" |
| 316 | + assert "remoteModelInfo" in model._properties |
| 317 | + assert "endpoint" in model._properties["remoteModelInfo"] |
| 318 | + assert "connection" in model._properties["remoteModelInfo"] |
| 319 | + |
| 320 | + # Parse the remote model endpoint |
| 321 | + bqml_endpoint = model._properties["remoteModelInfo"]["endpoint"] |
| 322 | + model_connection = model._properties["remoteModelInfo"]["connection"] |
| 323 | + model_endpoint = bqml_endpoint.split("/")[-1] |
| 324 | + |
| 325 | + embedding_generator_model = cls( |
| 326 | + session=session, model_name=model_endpoint, connection_name=model_connection |
| 327 | + ) |
| 328 | + embedding_generator_model._bqml_model = core.BqmlModel(session, model) |
| 329 | + return embedding_generator_model |
| 330 | + |
274 | 331 | def predict(self, X: Union[bpd.DataFrame, bpd.Series]) -> bpd.DataFrame:
|
275 | 332 | """Predict the result from input DataFrame.
|
276 | 333 |
|
@@ -307,3 +364,20 @@ def predict(self, X: Union[bpd.DataFrame, bpd.Series]) -> bpd.DataFrame:
|
307 | 364 | )
|
308 | 365 |
|
309 | 366 | return df
|
| 367 | + |
| 368 | + def to_gbq( |
| 369 | + self, model_name: str, replace: bool = False |
| 370 | + ) -> PaLM2TextEmbeddingGenerator: |
| 371 | + """Save the model to BigQuery. |
| 372 | +
|
| 373 | + Args: |
| 374 | + model_name (str): |
| 375 | + the name of the model. |
| 376 | + replace (bool, default False): |
| 377 | + whether to replace if the model already exists. Default to False. |
| 378 | +
|
| 379 | + Returns: |
| 380 | + PaLM2TextEmbeddingGenerator: saved model.""" |
| 381 | + |
| 382 | + new_model = self._bqml_model.copy(model_name, replace) |
| 383 | + return new_model.session.read_gbq_model(model_name) |
0 commit comments