-
-
Notifications
You must be signed in to change notification settings - Fork 18.5k
PERF: Allow jitting of groupby agg loop #35759
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
c7ba0eb
7f1159b
2d79984
ddfe6d8
d349816
6292d75
66edc21
608c955
cd0ed3f
3d2f955
b4d8dab
dfad4f5
8f5e9db
09c4309
0009fc4
7234f7e
5282f16
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -69,19 +69,16 @@ | |
GroupBy, | ||
_agg_template, | ||
_apply_docs, | ||
_group_selection_context, | ||
_transform_template, | ||
get_groupby, | ||
) | ||
from pandas.core.groupby.numba_ import generate_numba_func, split_for_numba | ||
from pandas.core.indexes.api import Index, MultiIndex, all_indexes_same | ||
import pandas.core.indexes.base as ibase | ||
from pandas.core.internals import BlockManager, make_block | ||
from pandas.core.series import Series | ||
from pandas.core.util.numba_ import ( | ||
NUMBA_FUNC_CACHE, | ||
generate_numba_func, | ||
maybe_use_numba, | ||
split_for_numba, | ||
) | ||
from pandas.core.util.numba_ import NUMBA_FUNC_CACHE, maybe_use_numba | ||
|
||
from pandas.plotting import boxplot_frame_groupby | ||
|
||
|
@@ -229,6 +226,18 @@ def apply(self, func, *args, **kwargs): | |
) | ||
def aggregate(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs): | ||
|
||
if maybe_use_numba(engine): | ||
if not callable(func): | ||
raise NotImplementedError( | ||
"Numba engine can only be used with a single function." | ||
) | ||
with _group_selection_context(self): | ||
data = self._selected_obj | ||
result, index = self._aggregate_with_numba( | ||
data.to_frame(), func, *args, engine_kwargs=engine_kwargs, **kwargs | ||
) | ||
return self.obj._constructor(result.ravel(), index=index, name=data.name) | ||
|
||
relabeling = func is None | ||
columns = None | ||
if relabeling: | ||
|
@@ -251,16 +260,11 @@ def aggregate(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs) | |
return getattr(self, cyfunc)() | ||
|
||
if self.grouper.nkeys > 1: | ||
return self._python_agg_general( | ||
func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs | ||
) | ||
return self._python_agg_general(func, *args, **kwargs) | ||
|
||
try: | ||
return self._python_agg_general( | ||
func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs | ||
) | ||
return self._python_agg_general(func, *args, **kwargs) | ||
except (ValueError, KeyError): | ||
# Do not catch Numba errors here, we want to raise and not fall back. | ||
# TODO: KeyError is raised in _python_agg_general, | ||
# see test_groupby.test_basic | ||
result = self._aggregate_named(func, *args, **kwargs) | ||
|
@@ -936,12 +940,19 @@ class DataFrameGroupBy(GroupBy[DataFrame]): | |
) | ||
def aggregate(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs): | ||
|
||
relabeling, func, columns, order = reconstruct_func(func, **kwargs) | ||
|
||
if maybe_use_numba(engine): | ||
return self._python_agg_general( | ||
func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs | ||
if not callable(func): | ||
raise NotImplementedError( | ||
"Numba engine can only be used with a single function." | ||
) | ||
with _group_selection_context(self): | ||
data = self._selected_obj | ||
result, index = self._aggregate_with_numba( | ||
data, func, *args, engine_kwargs=engine_kwargs, **kwargs | ||
) | ||
return self.obj._constructor(result, index=index, columns=data.columns) | ||
|
||
relabeling, func, columns, order = reconstruct_func(func, **kwargs) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. same comment here |
||
|
||
result, how = self._aggregate(func, *args, **kwargs) | ||
if how is None: | ||
|
Original file line number | Diff line number | Diff line change | ||
---|---|---|---|---|
|
@@ -34,7 +34,7 @@ class providing the base-class of operations. | |||
|
||||
from pandas._config.config import option_context | ||||
|
||||
from pandas._libs import Timestamp | ||||
from pandas._libs import Timestamp, lib | ||||
import pandas._libs.groupby as libgroupby | ||||
from pandas._typing import F, FrameOrSeries, FrameOrSeriesUnion, Scalar | ||||
from pandas.compat.numpy import function as nv | ||||
|
@@ -61,11 +61,11 @@ class providing the base-class of operations. | |||
import pandas.core.common as com | ||||
from pandas.core.frame import DataFrame | ||||
from pandas.core.generic import NDFrame | ||||
from pandas.core.groupby import base, ops | ||||
from pandas.core.groupby import base, numba_, ops | ||||
from pandas.core.indexes.api import CategoricalIndex, Index, MultiIndex | ||||
from pandas.core.series import Series | ||||
from pandas.core.sorting import get_group_index_sorter | ||||
from pandas.core.util.numba_ import maybe_use_numba | ||||
from pandas.core.util.numba_ import NUMBA_FUNC_CACHE | ||||
|
||||
_common_see_also = """ | ||||
See Also | ||||
|
@@ -384,7 +384,8 @@ class providing the base-class of operations. | |||
- dict of axis labels -> functions, function names or list of such. | ||||
|
||||
Can also accept a Numba JIT function with | ||||
``engine='numba'`` specified. | ||||
``engine='numba'`` specified. Only passing a single function is supported | ||||
with this engine. | ||||
|
||||
If the ``'numba'`` engine is chosen, the function must be | ||||
a user defined function with ``values`` and ``index`` as the | ||||
|
@@ -1053,12 +1054,43 @@ def _cython_agg_general( | |||
|
||||
return self._wrap_aggregated_output(output, index=self.grouper.result_index) | ||||
|
||||
def _python_agg_general( | ||||
self, func, *args, engine="cython", engine_kwargs=None, **kwargs | ||||
): | ||||
def _aggregate_with_numba(self, data, func, *args, engine_kwargs=None, **kwargs):
    """
    Aggregate ``data`` group by group using a Numba jitted loop.

    The rows of ``data`` are reordered so that each group is contiguous,
    the start/end offsets of every group are computed, and the resulting
    arrays are handed to a jitted aggregation function built by
    ``numba_.generate_numba_agg_func``.

    Parameters
    ----------
    data : object supporting ``take``, ``to_numpy`` and ``columns``
        Data to aggregate.
    func : callable
        User defined function; also used as part of the cache key.
    *args
        Positional arguments forwarded (as a tuple) to the jitted function
        generator.
    engine_kwargs : dict or None
        Forwarded to ``numba_.generate_numba_agg_func``.
    **kwargs
        Forwarded (as a dict) to ``numba_.generate_numba_agg_func``.

    Returns
    -------
    tuple
        ``(result, index)`` where ``result`` is the value returned by the
        jitted function and ``index`` is a ``MultiIndex`` when grouping by
        more than one key, otherwise an ``Index``.
    """
    grouper = self.grouper
    group_keys = grouper._get_group_keys()
    labels, _, n_groups = grouper.group_info

    # Reorder labels and data so that members of a group sit next to each other.
    sorted_index = get_group_index_sorter(labels, n_groups)
    sorted_labels = algorithms.take_nd(labels, sorted_index, allow_fill=False)
    sorted_data = data.take(sorted_index, axis=self.axis).to_numpy()
    starts, ends = lib.generate_slices(sorted_labels, n_groups)

    cache_key = (func, "groupby_agg")
    already_cached = cache_key in NUMBA_FUNC_CACHE
    if already_cached:
        # Reuse the previously generated groupby aggregation function.
        numba_agg_func = NUMBA_FUNC_CACHE[cache_key]
    else:
        numba_agg_func = numba_.generate_numba_agg_func(
            tuple(args), kwargs, func, engine_kwargs
        )

    result = numba_agg_func(
        sorted_data, sorted_index, starts, ends, len(group_keys), len(data.columns)
    )

    # Store the function only after it ran successfully, so that a function
    # which fails on its first evaluation is never put in the cache.
    if not already_cached:
        NUMBA_FUNC_CACHE[cache_key] = numba_agg_func

    if self.grouper.nkeys > 1:
        return result, MultiIndex.from_tuples(group_keys, names=self.grouper.names)
    return result, Index(group_keys, name=self.grouper.names[0])
|
||||
def _python_agg_general(self, func, *args, **kwargs): | ||||
func = self._is_builtin_func(func) | ||||
if engine != "numba": | ||||
f = lambda x: func(x, *args, **kwargs) | ||||
f = lambda x: func(x, *args, **kwargs) | ||||
|
||||
# iterate through "columns" ex exclusions to populate output dict | ||||
output: Dict[base.OutputKey, np.ndarray] = {} | ||||
|
@@ -1069,21 +1101,11 @@ def _python_agg_general( | |||
# agg_series below assumes ngroups > 0 | ||||
continue | ||||
|
||||
if maybe_use_numba(engine): | ||||
result, counts = self.grouper.agg_series( | ||||
obj, | ||||
func, | ||||
*args, | ||||
engine=engine, | ||||
engine_kwargs=engine_kwargs, | ||||
**kwargs, | ||||
) | ||||
else: | ||||
try: | ||||
# if this function is invalid for this dtype, we will ignore it. | ||||
result, counts = self.grouper.agg_series(obj, f) | ||||
except TypeError: | ||||
continue | ||||
try: | ||||
# if this function is invalid for this dtype, we will ignore it. | ||||
result, counts = self.grouper.agg_series(obj, f) | ||||
except TypeError: | ||||
continue | ||||
|
||||
assert result is not None | ||||
key = base.OutputKey(label=name, position=idx) | ||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,172 @@ | ||
"""Common utilities for Numba operations with groupby ops""" | ||
jreback marked this conversation as resolved.
Show resolved
Hide resolved
|
||
import inspect | ||
from typing import Any, Callable, Dict, Optional, Tuple | ||
|
||
import numpy as np | ||
|
||
from pandas._typing import FrameOrSeries, Scalar | ||
from pandas.compat._optional import import_optional_dependency | ||
|
||
from pandas.core.util.numba_ import ( | ||
NUMBA_FUNC_CACHE, | ||
NumbaUtilError, | ||
check_kwargs_and_nopython, | ||
get_jit_arguments, | ||
jit_user_function, | ||
) | ||
|
||
|
||
def split_for_numba(arg: "FrameOrSeries") -> Tuple[np.ndarray, np.ndarray]:
    """
    Break a pandas object into plain numpy arrays usable by numba functions.

    Parameters
    ----------
    arg : Series or DataFrame

    Returns
    -------
    (ndarray, ndarray)
        values, index
    """
    values = arg.to_numpy()
    index = arg.index.to_numpy()
    return values, index
|
||
|
||
def validate_udf(func: Callable) -> None:
    """
    Validate user defined function for ops when using Numba with groupby ops.

    The first signature arguments should include:

    def f(values, index, ...):
        ...

    Parameters
    ----------
    func : function
        user defined function

    Returns
    -------
    None

    Raises
    ------
    NumbaUtilError
        If the leading parameters of ``func`` are not named ``values`` and
        ``index`` (in that order).
    """
    expected_args = ["values", "index"]
    min_number_args = len(expected_args)
    udf_signature = list(inspect.signature(func).parameters)

    # A signature with fewer parameters than expected yields a shorter slice,
    # so this single comparison also rejects the "too few arguments" case.
    if udf_signature[:min_number_args] != expected_args:
        # Not every callable has ``__name__`` (e.g. functools.partial objects);
        # fall back to repr so the intended error is raised, not AttributeError.
        func_name = getattr(func, "__name__", repr(func))
        message = (
            f"The first {min_number_args} arguments to {func_name} must be "
            f"{expected_args}"
        )
        raise NumbaUtilError(message)
|
||
|
||
def generate_numba_func(
    func: Callable,
    engine_kwargs: Optional[Dict[str, bool]],
    kwargs: dict,
    cache_key_str: str,
) -> Tuple[Callable, Tuple[Callable, str]]:
    """
    Return a JITed function and cache key for the NUMBA_FUNC_CACHE

    This _may_ be specific to groupby (as it's only used there currently).

    The cache is only read here; storing the returned function under the
    returned key is left to the caller.

    Parameters
    ----------
    func : function
        user defined function
    engine_kwargs : dict or None
        numba.jit arguments
    kwargs : dict
        kwargs for func
    cache_key_str : str
        string representing the second part of the cache key tuple

    Returns
    -------
    (JITed function, cache key)

    Raises
    ------
    NumbaUtilError
        If ``func`` does not start with the ``values, index`` parameters.
    """
    nopython, nogil, parallel = get_jit_arguments(engine_kwargs)
    check_kwargs_and_nopython(kwargs, nopython)
    validate_udf(func)

    cache_key = (func, cache_key_str)
    # ``dict.get(key, default)`` evaluates ``default`` eagerly, which would
    # re-wrap the user function on every call; only build it on a cache miss.
    if cache_key in NUMBA_FUNC_CACHE:
        numba_func = NUMBA_FUNC_CACHE[cache_key]
    else:
        numba_func = jit_user_function(func, nopython, nogil, parallel)
    return numba_func, cache_key
|
||
|
||
def generate_numba_agg_func(
    args: Tuple,
    kwargs: Dict[str, Any],
    func: Callable[..., Scalar],
    engine_kwargs: Optional[Dict[str, bool]],
) -> Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int, int], np.ndarray]:
    """
    Generate a numba jitted agg function specified by values from engine_kwargs.

    1. jit the user's function
    2. Return a groupby agg function with the jitted function inline

    Configurations specified in engine_kwargs apply to both the user's
    function _AND_ the groupby aggregation loop.

    Parameters
    ----------
    args : tuple
        *args to be passed into the function
    kwargs : dict
        **kwargs for the function; they are only validated together with
        ``nopython`` here and are not forwarded to the jitted call
    func : function
        function to be applied to each group of each column and will be JITed
    engine_kwargs : dict or None
        dictionary of arguments to be passed into numba.jit

    Returns
    -------
    Numba function
        Callable taking ``(values, index, begin, end, num_groups,
        num_columns)`` and returning a ``(num_groups, num_columns)`` array.

    Raises
    ------
    NumbaUtilError
        If ``func`` does not start with the ``values, index`` parameters.
    """
    nopython, nogil, parallel = get_jit_arguments(engine_kwargs)
    check_kwargs_and_nopython(kwargs, nopython)
    validate_udf(func)

    numba_func = jit_user_function(func, nopython, nogil, parallel)
    numba = import_optional_dependency("numba")

    # Use numba's parallel range only when parallel execution was requested.
    loop_range = numba.prange if parallel else range

    @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
    def group_apply(
        values: np.ndarray,
        index: np.ndarray,
        begin: np.ndarray,
        end: np.ndarray,
        num_groups: int,
        num_columns: int,
    ) -> np.ndarray:
        # np.empty defaults to float64, so every aggregate is stored as float.
        result = np.empty((num_groups, num_columns))
        for i in loop_range(num_groups):
            # begin[i]:end[i] delimits the rows belonging to group i.
            group_index = index[begin[i] : end[i]]
            for j in loop_range(num_columns):
                group = values[begin[i] : end[i], j]
                result[i, j] = numba_func(group, group_index, *args)
        return result

    return group_apply
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
would not object to making a _aggregate_with_python_cython (where you put everything from L242 on down).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I could do this in a follow up refactor PR.
I guess I would need to make a Series and DataFrame version of this function since it looks like both are different.