Skip to content

Commit 37914a4

Browse files
authored
feat: add series.sample (identical to existing dataframe.sample) (#187)
We're duplicating some arg-parsing logic here. Discussed briefly with Trevor. This is the case for other methods as well; we might want to add a sharing mechanism for dataframe/series (a superclass, like pandas?) in the future. The documentation already exists in third_party/core/generic.py, which is actually what prompted this feat/fix.
1 parent d49ae42 commit 37914a4

File tree

2 files changed

+43
-0
lines changed

2 files changed

+43
-0
lines changed

bigframes/series.py

+16
Original file line numberDiff line numberDiff line change
@@ -1447,6 +1447,22 @@ def map(
14471447
result_df = self_df.join(map_df, on="series")
14481448
return result_df[self.name]
14491449

1450+
def sample(
    self,
    n: Optional[int] = None,
    frac: Optional[float] = None,
    *,
    random_state: Optional[int] = None,
) -> Series:
    """Return a Series built from the first piece of a split of this series.

    Args:
        n: Value forwarded to ``self._block._split`` as the single entry of
            ``ns``. Mutually exclusive with ``frac``.
        frac: Value forwarded to ``self._block._split`` as the single entry
            of ``fracs``. Mutually exclusive with ``n``.
        random_state: Passed through unchanged to ``self._block._split``.

    Returns:
        A ``Series`` wrapping element ``[0]`` of the ``_split`` result.

    Raises:
        ValueError: If both ``n`` and ``frac`` are given.
    """
    # Guard first: supplying both selectors is ambiguous.
    if not (n is None or frac is None):
        raise ValueError("Only one of 'n' or 'frac' parameter can be specified.")

    # _split takes tuples; an omitted selector becomes an empty tuple.
    # NOTE(review): when both are omitted, both tuples are empty and the
    # result depends on _split's own default -- confirm against _split.
    sizes = () if n is None else (n,)
    fractions = () if frac is None else (frac,)

    pieces = self._block._split(
        ns=sizes, fracs=fractions, random_state=random_state
    )
    return Series(pieces[0])
1465+
14501466
def __array_ufunc__(
14511467
self, ufunc: numpy.ufunc, method: str, *inputs, **kwargs
14521468
) -> Series:

tests/system/small/test_series.py

+27
Original file line numberDiff line numberDiff line change
@@ -2922,3 +2922,30 @@ def test_map_series_input_duplicates_error(scalars_dfs):
29222922
scalars_pandas_df.int64_too.map(pd_map_series)
29232923
with pytest.raises(pd.errors.InvalidIndexError):
29242924
scalars_df.int64_too.map(bf_map_series, verify_integrity=True)
2925+
2926+
2927+
@pytest.mark.parametrize(
    ("frac", "n", "random_state"),
    [
        (None, 4, None),
        (0.5, None, None),
        (None, 4, 10),
        (0.5, None, 10),
        (None, None, None),
    ],
    ids=[
        "n_wo_random_state",
        "frac_wo_random_state",
        "n_w_random_state",
        "frac_w_random_state",
        "n_default",
    ],
)
def test_sample(scalars_dfs, frac, n, random_state):
    """Check that ``Series.sample`` returns the expected number of rows.

    Only the size of the result is asserted, not which rows were chosen.
    """
    bf_df, _ = scalars_dfs
    sampled = bf_df.int64_col.sample(frac=frac, n=n, random_state=random_state)
    pd_sampled = sampled.to_pandas()

    total_rows = bf_df.shape[0]
    if frac is not None:
        # A fractional request is expected to yield the rounded share of rows.
        expected_sample_size = round(frac * total_rows)
    elif n is not None:
        expected_sample_size = n
    else:
        # Neither selector given: the test expects exactly one row.
        expected_sample_size = 1

    assert pd_sampled.shape[0] == expected_sample_size

0 commit comments

Comments
 (0)