Skip to content

Commit 3d2f115

Browse files
gfyoung authored and jreback committed
BUG: Prevent abuse of kwargs in stat functions
Addresses issue #12301 by filtering the `kwargs` argument in stat functions to prevent clearly invalid arguments from being passed, while maintaining compatibility with the analogous `numpy` functions. Author: gfyoung <[email protected]>. Closes #12318 from gfyoung/kwarg_remover and squashes the following commit: f9de80f [gfyoung] BUG: Prevent abuse of kwargs in stat functions.
1 parent faedd11 commit 3d2f115

File tree

3 files changed

+39
-3
lines changed

3 files changed

+39
-3
lines changed

doc/source/whatsnew/v0.18.0.txt

+2
Original file line numberDiff line numberDiff line change
@@ -824,6 +824,8 @@ Other API Changes
824824

825825
- As part of the new API for :ref:`window functions <whatsnew_0180.enhancements.moments>` and :ref:`resampling <whatsnew_0180.breaking.resample>`, aggregation functions have been clarified, raising more informative error messages on invalid aggregations. (:issue:`9052`). A full set of examples is presented in :ref:`groupby <groupby.aggregation>`.
826826

827+
- Statistical functions for ``NDFrame`` objects will now raise a ``TypeError`` if non-numpy-compatible arguments are passed in for ``**kwargs`` (:issue:`12301`)
828+
827829
.. _whatsnew_0180.deprecations:
828830

829831
Deprecations

pandas/core/generic.py

+20
Original file line numberDiff line numberDiff line change
@@ -5207,12 +5207,29 @@ def _doc_parms(cls):
52075207
%(outname)s : %(name1)s\n"""
52085208

52095209

5210+
def _validate_kwargs(fname, kwargs, *compat_args):
5211+
"""
5212+
Checks whether parameters passed to the
5213+
**kwargs argument in a 'stat' function 'fname'
5214+
are valid parameters as specified in *compat_args
5215+
5216+
"""
5217+
list(map(kwargs.__delitem__, filter(
5218+
kwargs.__contains__, compat_args)))
5219+
if kwargs:
5220+
bad_arg = list(kwargs)[0] # first 'key' element
5221+
raise TypeError(("{fname}() got an unexpected "
5222+
"keyword argument '{arg}'".
5223+
format(fname=fname, arg=bad_arg)))
5224+
5225+
52105226
def _make_stat_function(name, name1, name2, axis_descr, desc, f):
52115227
@Substitution(outname=name, desc=desc, name1=name1, name2=name2,
52125228
axis_descr=axis_descr)
52135229
@Appender(_num_doc)
52145230
def stat_func(self, axis=None, skipna=None, level=None, numeric_only=None,
52155231
**kwargs):
5232+
_validate_kwargs(name, kwargs, 'out', 'dtype')
52165233
if skipna is None:
52175234
skipna = True
52185235
if axis is None:
@@ -5233,6 +5250,7 @@ def _make_stat_function_ddof(name, name1, name2, axis_descr, desc, f):
52335250
@Appender(_num_ddof_doc)
52345251
def stat_func(self, axis=None, skipna=None, level=None, ddof=1,
52355252
numeric_only=None, **kwargs):
5253+
_validate_kwargs(name, kwargs, 'out', 'dtype')
52365254
if skipna is None:
52375255
skipna = True
52385256
if axis is None:
@@ -5254,6 +5272,7 @@ def _make_cum_function(name, name1, name2, axis_descr, desc, accum_func,
52545272
@Appender("Return cumulative {0} over requested axis.".format(name) +
52555273
_cnum_doc)
52565274
def func(self, axis=None, dtype=None, out=None, skipna=True, **kwargs):
5275+
_validate_kwargs(name, kwargs, 'out', 'dtype')
52575276
if axis is None:
52585277
axis = self._stat_axis_number
52595278
else:
@@ -5288,6 +5307,7 @@ def _make_logical_function(name, name1, name2, axis_descr, desc, f):
52885307
@Appender(_bool_doc)
52895308
def logical_func(self, axis=None, bool_only=None, skipna=None, level=None,
52905309
**kwargs):
5310+
_validate_kwargs(name, kwargs, 'out', 'dtype')
52915311
if skipna is None:
52925312
skipna = True
52935313
if axis is None:

pandas/tests/test_generic.py

+17-3
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,14 @@
1616

1717
from pandas.compat import range, zip
1818
from pandas import compat
19-
from pandas.util.testing import (assert_series_equal,
19+
from pandas.util.testing import (assertRaisesRegexp,
20+
assert_series_equal,
2021
assert_frame_equal,
2122
assert_panel_equal,
2223
assert_panel4d_equal,
2324
assert_almost_equal,
2425
assert_equal)
26+
2527
import pandas.util.testing as tm
2628

2729

@@ -483,8 +485,6 @@ def test_split_compat(self):
483485
self.assertTrue(len(np.array_split(o, 2)) == 2)
484486

485487
def test_unexpected_keyword(self): # GH8597
486-
from pandas.util.testing import assertRaisesRegexp
487-
488488
df = DataFrame(np.random.randn(5, 2), columns=['jim', 'joe'])
489489
ca = pd.Categorical([0, 0, 2, 2, 3, np.nan])
490490
ts = df['joe'].copy()
@@ -502,6 +502,20 @@ def test_unexpected_keyword(self): # GH8597
502502
with assertRaisesRegexp(TypeError, 'unexpected keyword'):
503503
ts.fillna(0, in_place=True)
504504

505+
# See gh-12301
506+
def test_stat_unexpected_keyword(self):
    """
    Each family of statistical methods must raise TypeError when given a
    keyword argument it does not recognise.
    """
    obj = self._construct(5)
    starwars = 'Star Wars'

    with assertRaisesRegexp(TypeError, 'unexpected keyword'):
        obj.max(epic=starwars)  # stat_function
    with assertRaisesRegexp(TypeError, 'unexpected keyword'):
        obj.var(epic=starwars)  # stat_function_ddof
    # Use cumsum rather than sum here: sum is presumably built by the
    # plain stat-function factory (already covered by max above), so it
    # would not exercise the cumulative wrapper -- TODO confirm against
    # the method registration in pandas/core/generic.py.
    with assertRaisesRegexp(TypeError, 'unexpected keyword'):
        obj.cumsum(epic=starwars)  # cum_function
    with assertRaisesRegexp(TypeError, 'unexpected keyword'):
        obj.any(epic=starwars)  # logical_function
518+
505519

506520
class TestSeries(tm.TestCase, Generic):
507521
_typ = Series

0 commit comments

Comments
 (0)