Skip to content

feat: add Series.case_when() #673

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
May 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions bigframes/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from dataclasses import dataclass
import functools
import io
import itertools
import typing
from typing import Iterable, Sequence

Expand Down Expand Up @@ -370,14 +371,16 @@ def unpivot(
for col_id, input_ids in unpivot_columns:
# row explode offset used to choose the input column
# we use offset instead of label as labels are not necessarily unique
cases = tuple(
(
ops.eq_op.as_expr(explode_offsets_id, ex.const(i)),
ex.free_var(id_or_null)
if (id_or_null is not None)
else ex.const(None),
cases = itertools.chain(
*(
(
ops.eq_op.as_expr(explode_offsets_id, ex.const(i)),
ex.free_var(id_or_null)
if (id_or_null is not None)
else ex.const(None),
)
for i, id_or_null in enumerate(input_ids)
)
for i, id_or_null in enumerate(input_ids)
)
col_expr = ops.case_when_op.as_expr(*cases)
unpivot_exprs.append((col_expr, col_id))
Expand Down
9 changes: 9 additions & 0 deletions bigframes/core/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -803,6 +803,15 @@ def apply_ternary_op(
expr = op.as_expr(col_id_1, col_id_2, col_id_3)
return self.project_expr(expr, result_label)

def apply_nary_op(
self,
columns: Iterable[str],
op: ops.NaryOp,
result_label: Label = None,
) -> typing.Tuple[Block, str]:
expr = op.as_expr(*columns)
return self.project_expr(expr, result_label)

def multi_apply_window_op(
self,
columns: typing.Sequence[str],
Expand Down
4 changes: 4 additions & 0 deletions bigframes/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -655,6 +655,10 @@ def _repr_html_(self) -> str:
html_string += f"[{row_count} rows x {column_count} columns in total]"
return html_string

def __delitem__(self, key: str):
df = self.drop(columns=[key])
self._set_block(df._get_block())

def __setitem__(self, key: str, value: SingleItemValue):
df = self._assign_single_item(key, value)
self._set_block(df._get_block())
Expand Down
51 changes: 25 additions & 26 deletions bigframes/operations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import dataclasses
import functools
import typing
from typing import Tuple, Union
from typing import Union

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -46,7 +46,7 @@ def order_preserving(self) -> bool:


@dataclasses.dataclass(frozen=True)
class NaryOp:
class ScalarOp:
@property
def name(self) -> str:
raise NotImplementedError("RowOp abstract base class has no implementation")
Expand All @@ -60,10 +60,30 @@ def order_preserving(self) -> bool:
return False


@dataclasses.dataclass(frozen=True)
class NaryOp(ScalarOp):
def as_expr(
self,
*exprs: Union[str | bigframes.core.expression.Expression],
) -> bigframes.core.expression.Expression:
import bigframes.core.expression

# Keep this in sync with output_type and compilers
inputs: list[bigframes.core.expression.Expression] = []

for expr in exprs:
inputs.append(_convert_expr_input(expr))

return bigframes.core.expression.OpExpression(
self,
tuple(inputs),
)


# These classes can be used to create simple ops that don't take local parameters
# All is needed is a unique name, and to register an implementation in ibis_mappings.py
@dataclasses.dataclass(frozen=True)
class UnaryOp(NaryOp):
class UnaryOp(ScalarOp):
@property
def arguments(self) -> int:
return 1
Expand All @@ -79,7 +99,7 @@ def as_expr(


@dataclasses.dataclass(frozen=True)
class BinaryOp(NaryOp):
class BinaryOp(ScalarOp):
@property
def arguments(self) -> int:
return 2
Expand All @@ -101,7 +121,7 @@ def as_expr(


@dataclasses.dataclass(frozen=True)
class TernaryOp(NaryOp):
class TernaryOp(ScalarOp):
@property
def arguments(self) -> int:
return 3
Expand Down Expand Up @@ -655,27 +675,6 @@ def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionT
output_expr_types,
)

def as_expr(
self,
*case_output_pairs: Tuple[
Union[str | bigframes.core.expression.Expression],
Union[str | bigframes.core.expression.Expression],
],
) -> bigframes.core.expression.Expression:
import bigframes.core.expression

# Keep this in sync with output_type and compilers
inputs: list[bigframes.core.expression.Expression] = []

for case, output in case_output_pairs:
inputs.append(_convert_expr_input(case))
inputs.append(_convert_expr_input(output))

return bigframes.core.expression.OpExpression(
self,
tuple(inputs),
)


case_when_op = CaseWhenOp()

Expand Down
23 changes: 22 additions & 1 deletion bigframes/operations/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from __future__ import annotations

import typing
from typing import List, Sequence

import bigframes_vendored.pandas.pandas._typing as vendored_pandas_typing
import numpy
Expand Down Expand Up @@ -205,6 +206,21 @@ def _apply_binary_op(
block, result_id = self._block.project_expr(expr, name)
return series.Series(block.select_column(result_id))

def _apply_nary_op(
self,
op: ops.NaryOp,
others: Sequence[typing.Union[series.Series, scalars.Scalar]],
ignore_self=False,
):
"""Applies an n-ary operator to the series and others."""
values, block = self._align_n(others, ignore_self=ignore_self)
block, result_id = block.apply_nary_op(
values,
op,
self._name,
)
return series.Series(block.select_column(result_id))

def _apply_binary_aggregation(
self, other: series.Series, stat: agg_ops.BinaryAggregateOp
) -> float:
Expand All @@ -226,8 +242,13 @@ def _align_n(
self,
others: typing.Sequence[typing.Union[series.Series, scalars.Scalar]],
how="outer",
ignore_self=False,
) -> tuple[typing.Sequence[str], blocks.Block]:
value_ids = [self._value_column]
if ignore_self:
value_ids: List[str] = []
else:
value_ids = [self._value_column]

block = self._block
for other in others:
if isinstance(other, series.Series):
Expand Down
19 changes: 19 additions & 0 deletions bigframes/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,25 @@ def between(self, left, right, inclusive="both"):
self._apply_binary_op(right, right_op)
)

def case_when(self, caselist) -> Series:
return self._apply_nary_op(
ops.case_when_op,
tuple(
itertools.chain(
itertools.chain(*caselist),
# Fallback to current value if no other matches.
(
# We make a Series with a constant value to avoid casts to
# types other than boolean.
Series(True, index=self.index, dtype=pandas.BooleanDtype()),
self,
),
),
),
# Self is already included in "others".
ignore_self=True,
)

def cumsum(self) -> Series:
return self._apply_window_op(
agg_ops.sum_op, bigframes.core.window_spec.WindowSpec(following=0)
Expand Down
137 changes: 137 additions & 0 deletions samples/snippets/logistic_regression_prediction_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""BigQuery DataFrames code samples for
https://cloud.google.com/bigquery/docs/logistic-regression-prediction.
"""


def test_logistic_regression_prediction(random_model_id: str) -> None:
your_model_id = random_model_id

# [START bigquery_dataframes_logistic_regression_prediction_examine]
import bigframes.pandas as bpd

df = bpd.read_gbq(
"bigquery-public-data.ml_datasets.census_adult_income",
columns=(
"age",
"workclass",
"marital_status",
"education_num",
"occupation",
"hours_per_week",
"income_bracket",
"functional_weight",
),
max_results=100,
)
df.peek()
# Output:
# age workclass marital_status education_num occupation hours_per_week income_bracket functional_weight
# 47 Local-gov Married-civ-spouse 13 Prof-specialty 40 >50K 198660
# 56 Private Never-married 9 Adm-clerical 40 <=50K 85018
# 40 Private Married-civ-spouse 12 Tech-support 40 >50K 285787
# 34 Self-emp-inc Married-civ-spouse 9 Craft-repair 54 >50K 207668
# 23 Private Married-civ-spouse 10 Handlers-cleaners 40 <=50K 40060
# [END bigquery_dataframes_logistic_regression_prediction_examine]

# [START bigquery_dataframes_logistic_regression_prediction_prepare]
import bigframes.pandas as bpd

input_data = bpd.read_gbq(
"bigquery-public-data.ml_datasets.census_adult_income",
columns=(
"age",
"workclass",
"marital_status",
"education_num",
"occupation",
"hours_per_week",
"income_bracket",
"functional_weight",
),
)
input_data["dataframe"] = bpd.Series("training", index=input_data.index,).case_when(
[
(((input_data["functional_weight"] % 10) == 8), "evaluation"),
(((input_data["functional_weight"] % 10) == 9), "prediction"),
]
)
del input_data["functional_weight"]
# [END bigquery_dataframes_logistic_regression_prediction_prepare]

# [START bigquery_dataframes_logistic_regression_prediction_create_model]
import bigframes.ml.linear_model

# input_data is defined in an earlier step.
training_data = input_data[input_data["dataframe"] == "training"]
X = training_data.drop(columns=["income_bracket", "dataframe"])
y = training_data["income_bracket"]

census_model = bigframes.ml.linear_model.LogisticRegression()
census_model.fit(X, y)

census_model.to_gbq(
your_model_id, # For example: "your-project.census.census_model"
replace=True,
)
# [END bigquery_dataframes_logistic_regression_prediction_create_model]

# [START bigquery_dataframes_logistic_regression_prediction_evaluate_model]
# Select model you'll use for predictions. `read_gbq_model` loads model
# data from BigQuery, but you could also use the `census_model` object
# from previous steps.
census_model = bpd.read_gbq_model(
your_model_id, # For example: "your-project.census.census_model"
)

# input_data is defined in an earlier step.
evaluation_data = input_data[input_data["dataframe"] == "evaluation"]
X = evaluation_data.drop(columns=["income_bracket", "dataframe"])
y = evaluation_data["income_bracket"]

# The score() method evaluates how the model performs compared to the
# actual data. Output DataFrame matches that of ML.EVALUATE().
score = census_model.score(X, y)
score.peek()
# Output:
# precision recall accuracy f1_score log_loss roc_auc
# 0 0.685764 0.536685 0.83819 0.602134 0.350417 0.882953
# [END bigquery_dataframes_logistic_regression_prediction_evaluate_model]

# [START bigquery_dataframes_logistic_regression_prediction_predict_income_bracket]
# Select model you'll use for predictions. `read_gbq_model` loads model
# data from BigQuery, but you could also use the `census_model` object
# from previous steps.
census_model = bpd.read_gbq_model(
your_model_id, # For example: "your-project.census.census_model"
)

# input_data is defined in an earlier step.
prediction_data = input_data[input_data["dataframe"] == "prediction"]

predictions = census_model.predict(prediction_data)
predictions.peek()
# Output:
# predicted_income_bracket predicted_income_bracket_probs age workclass ... occupation hours_per_week income_bracket dataframe
# 18004 <=50K [{'label': ' >50K', 'prob': 0.0763305999358786... 75 ? ... ? 6 <=50K prediction
# 18886 <=50K [{'label': ' >50K', 'prob': 0.0448866871906495... 73 ? ... ? 22 >50K prediction
# 31024 <=50K [{'label': ' >50K', 'prob': 0.0362982319421936... 69 ? ... ? 1 <=50K prediction
# 31022 <=50K [{'label': ' >50K', 'prob': 0.0787836112058324... 75 ? ... ? 5 <=50K prediction
# 23295 <=50K [{'label': ' >50K', 'prob': 0.3385373037905673... 78 ? ... ? 32 <=50K prediction
# [END bigquery_dataframes_logistic_regression_prediction_predict_income_bracket]

# TODO(tswast): Implement ML.EXPLAIN_PREDICT() and corresponding sample.
# TODO(tswast): Implement ML.GLOBAL_EXPLAIN() and corresponding sample.
Loading