Skip to content

Commit 538a818

Browse files
committed
fix: plot.scatter s parameter cannot accept float-like column
1 parent 1caac27 commit 538a818

File tree

2 files changed

+38
-7
lines changed

2 files changed

+38
-7
lines changed

bigframes/operations/_matplotlib/core.py

+12-6
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,18 @@ def _compute_plot_data(self):
115115
if self._is_column_name(c, sample) and sample[c].dtype == dtypes.STRING_DTYPE:
116116
sample[c] = sample[c].astype("object")
117117

118+
# To avoid Matplotlib's automatic conversion of `Float64` or `Int64` columns
119+
# to `object` types (which breaks float-like behavior), this code proactively
120+
# converts the column to a compatible format."
121+
s = self.kwargs.get("s", None)
122+
if pd.core.dtypes.common.is_integer(s):
123+
s = self.data.columns[s]
124+
if self._is_column_name(s, sample):
125+
if sample[s].dtype == dtypes.INT_DTYPE:
126+
sample[s] = sample[s].astype("int64")
127+
elif sample[s].dtype == dtypes.FLOAT_DTYPE:
128+
sample[s] = sample[s].astype("float64")
129+
118130
return sample
119131

120132
def _is_sequence_arg(self, arg):
@@ -130,9 +142,3 @@ def _is_column_name(self, arg, data):
130142
and pd.core.dtypes.common.is_hashable(arg)
131143
and arg in data.columns
132144
)
133-
134-
def _generate_new_column_name(self, data):
135-
col_name = None
136-
while col_name is None or col_name in data.columns:
137-
col_name = f"plot_temp_{str(uuid.uuid4())[:8]}"
138-
return col_name

tests/system/small/operations/test_plotting.py

+26-1
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,32 @@ def test_scatter_args_c(c):
240240
)
241241

242242

243+
@pytest.mark.parametrize(
244+
("s"),
245+
[
246+
pytest.param([10, 34, 50], id="int"),
247+
pytest.param([1.0, 3.4, 5.0], id="float"),
248+
pytest.param(
249+
[True, True, False], id="bool", marks=pytest.mark.xfail(raises=ValueError)
250+
),
251+
],
252+
)
253+
def test_scatter_args_s(s):
254+
data = {
255+
"a": [1, 2, 3],
256+
"b": [1, 2, 3],
257+
}
258+
data["s"] = s
259+
df = bpd.DataFrame(data)
260+
pd_df = pd.DataFrame(data)
261+
262+
ax = df.plot.scatter(x="a", y="b", s="s")
263+
pd_ax = pd_df.plot.scatter(x="a", y="b", s="s")
264+
tm.assert_numpy_array_equal(
265+
ax.collections[0].get_sizes(), pd_ax.collections[0].get_sizes()
266+
)
267+
268+
243269
@pytest.mark.parametrize(
244270
("arg_name"),
245271
[
@@ -255,7 +281,6 @@ def test_scatter_sequence_arg(arg_name):
255281
arg_value = [3, 3, 1]
256282
bpd.DataFrame(data).plot.scatter(x="a", y="b", **{arg_name: arg_value})
257283

258-
259284
def test_sampling_plot_args_n():
260285
df = bpd.DataFrame(np.arange(bf_mpl.DEFAULT_SAMPLING_N * 10), columns=["one"])
261286
ax = df.plot.line()

0 commit comments

Comments
 (0)