From 18facdffabed8b4d3fe6b01a3f92cd6efcf9aaa4 Mon Sep 17 00:00:00 2001 From: sinhrks Date: Sat, 14 Jun 2014 15:25:10 +0900 Subject: [PATCH] BUG/CLN: LinePlot uses incorrect xlim when secondary_y=True --- doc/source/v0.14.1.txt | 5 ++ pandas/tests/test_graphics.py | 35 +++++++++++++ pandas/tools/plotting.py | 92 ++++++++--------------------------- pandas/tseries/plotting.py | 14 +----- 4 files changed, 62 insertions(+), 84 deletions(-) diff --git a/doc/source/v0.14.1.txt b/doc/source/v0.14.1.txt index 9e992573f568d..4ef0110013925 100644 --- a/doc/source/v0.14.1.txt +++ b/doc/source/v0.14.1.txt @@ -148,6 +148,11 @@ Performance + +- Bug in line plot doesn't set correct ``xlim`` if ``secondary_y=True`` (:issue:`7459`) + + + Experimental ~~~~~~~~~~~~ diff --git a/pandas/tests/test_graphics.py b/pandas/tests/test_graphics.py index dfdd37c468a85..b195e668a731b 100644 --- a/pandas/tests/test_graphics.py +++ b/pandas/tests/test_graphics.py @@ -457,6 +457,20 @@ def test_plot_figsize_and_title(self): self._check_text_labels(ax.title, 'Test') self._check_axes_shape(ax, axes_num=1, layout=(1, 1), figsize=(16, 8)) + def test_ts_line_lim(self): + ax = self.ts.plot() + xmin, xmax = ax.get_xlim() + lines = ax.get_lines() + self.assertEqual(xmin, lines[0].get_data(orig=False)[0][0]) + self.assertEqual(xmax, lines[0].get_data(orig=False)[0][-1]) + tm.close() + + ax = self.ts.plot(secondary_y=True) + xmin, xmax = ax.get_xlim() + lines = ax.get_lines() + self.assertEqual(xmin, lines[0].get_data(orig=False)[0][0]) + self.assertEqual(xmax, lines[0].get_data(orig=False)[0][-1]) + def test_ts_area_lim(self): ax = self.ts.plot(kind='area', stacked=False) xmin, xmax = ax.get_xlim() @@ -1091,6 +1105,27 @@ def test_line_area_nan_df(self): self.assert_numpy_array_equal(ax.lines[0].get_ydata(), expected1) self.assert_numpy_array_equal(ax.lines[1].get_ydata(), expected2) + def test_line_lim(self): + df = DataFrame(rand(6, 3), columns=['x', 'y', 'z']) + ax = df.plot() + xmin, xmax = ax.get_xlim() + lines = ax.get_lines() + self.assertEqual(xmin, lines[0].get_data()[0][0]) + self.assertEqual(xmax, lines[0].get_data()[0][-1]) + + ax = df.plot(secondary_y=True) + xmin, xmax = ax.get_xlim() + lines = ax.get_lines() + self.assertEqual(xmin, lines[0].get_data()[0][0]) + self.assertEqual(xmax, lines[0].get_data()[0][-1]) + + axes = df.plot(secondary_y=True, subplots=True) + for ax in axes: + xmin, xmax = ax.get_xlim() + lines = ax.get_lines() + self.assertEqual(xmin, lines[0].get_data()[0][0]) + self.assertEqual(xmax, lines[0].get_data()[0][-1]) + def test_area_lim(self): df = DataFrame(rand(6, 4), columns=['x', 'y', 'z', 'four']) diff --git a/pandas/tools/plotting.py b/pandas/tools/plotting.py index 7f2f583c5e20e..b09896788a21d 100644 --- a/pandas/tools/plotting.py +++ b/pandas/tools/plotting.py @@ -909,13 +909,9 @@ def _args_adjust(self): pass def _maybe_right_yaxis(self, ax): - _types = (list, tuple, np.ndarray) - sec_true = isinstance(self.secondary_y, bool) and self.secondary_y - list_sec = isinstance(self.secondary_y, _types) - has_sec = list_sec and len(self.secondary_y) > 0 - all_sec = list_sec and len(self.secondary_y) == self.nseries - - if (sec_true or has_sec) and not hasattr(ax, 'right_ax'): + if hasattr(ax, 'right_ax'): + return ax.right_ax + else: orig_ax, new_ax = ax, ax.twinx() new_ax._get_lines.color_cycle = orig_ax._get_lines.color_cycle @@ -924,38 +920,25 @@ def _maybe_right_yaxis(self, ax): if len(orig_ax.get_lines()) == 0: # no data on left y orig_ax.get_yaxis().set_visible(False) - - if len(new_ax.get_lines()) == 0: - new_ax.get_yaxis().set_visible(False) - - if sec_true or all_sec: - ax = new_ax - else: - ax.get_yaxis().set_visible(True) - - return ax + return new_ax def _setup_subplots(self): if self.subplots: nrows, ncols = self._get_layout() fig, axes = _subplots(nrows=nrows, ncols=ncols, sharex=self.sharex, sharey=self.sharey, - figsize=self.figsize, ax=self.ax, - secondary_y=self.secondary_y, - data=self.data) + figsize=self.figsize, ax=self.ax) if not com.is_list_like(axes): axes = np.array([axes]) else: if self.ax is None: fig = self.plt.figure(figsize=self.figsize) ax = fig.add_subplot(111) - ax = self._maybe_right_yaxis(ax) else: fig = self.ax.get_figure() if self.figsize is not None: fig.set_size_inches(self.figsize) - ax = self._maybe_right_yaxis(self.ax) - + ax = self.ax axes = [ax] if self.logx or self.loglog: @@ -1182,14 +1165,21 @@ def _get_ax(self, i): # get the twinx ax if appropriate if self.subplots: ax = self.axes[i] + + if self.on_right(i): + ax = self._maybe_right_yaxis(ax) + self.axes[i] = ax else: ax = self.axes[0] - if self.on_right(i): - if hasattr(ax, 'right_ax'): - ax = ax.right_ax - elif hasattr(ax, 'left_ax'): - ax = ax.left_ax + if self.on_right(i): + ax = self._maybe_right_yaxis(ax) + + sec_true = isinstance(self.secondary_y, bool) and self.secondary_y + all_sec = (com.is_list_like(self.secondary_y) and + len(self.secondary_y) == self.nseries) + if sec_true or all_sec: + self.axes[0] = ax ax.get_yaxis().set_visible(True) return ax @@ -1550,8 +1540,6 @@ def _make_plot(self): data = self._maybe_convert_index(self.data) self._make_ts_plot(data) else: - from pandas.core.frame import DataFrame - lines = [] x = self._get_xticks(convert_period=True) plotf = self._get_plot_function() @@ -1563,8 +1551,6 @@ def _make_plot(self): kwds = self.kwds.copy() self._maybe_add_color(colors, kwds, style, i) - lines += _get_all_lines(ax) - errors = self._get_errorbars(label=label, index=i) kwds = dict(kwds, **errors) @@ -1588,15 +1574,13 @@ def _make_plot(self): newlines = plotf(*args, **kwds) self._add_legend_handle(newlines[0], label, index=i) - lines.append(newlines[0]) - if self.stacked and not self.subplots: if (y >= 0).all(): self._pos_prior += y elif (y <= 0).all(): self._neg_prior += y - if self._is_datetype(): + lines = _get_all_lines(ax) left, right = _get_xlim(lines) ax.set_xlim(left, right) @@ -2253,14 +2237,7 @@ def plot_series(series, label=None, kind='line', use_index=True, rot=None, import matplotlib.pyplot as plt if ax is None and len(plt.get_fignums()) > 0: ax = _gca() - if ax.get_yaxis().get_ticks_position().strip().lower() == 'right': - fig = _gcf() - axes = fig.get_axes() - for i in reversed(range(len(axes))): - ax = axes[i] - ypos = ax.get_yaxis().get_ticks_position().strip().lower() - if ypos == 'left': - break + ax = getattr(ax, 'left_ax', ax) # is there harm in this? if label is None: @@ -2890,8 +2867,7 @@ def _get_layout(nplots, layout=None): def _subplots(nrows=1, ncols=1, naxes=None, sharex=False, sharey=False, squeeze=True, - subplot_kw=None, ax=None, secondary_y=False, data=None, - **fig_kw): + subplot_kw=None, ax=None, **fig_kw): """Create a figure with a set of subplots already made. This utility wrapper makes it convenient to create common layouts of @@ -2932,12 +2908,6 @@ def _subplots(nrows=1, ncols=1, naxes=None, sharex=False, sharey=False, squeeze= ax : Matplotlib axis object, optional - secondary_y : boolean or sequence of ints, default False - If True then y-axis will be on the right - - data : DataFrame, optional - If secondary_y is a sequence, data is used to select columns. - fig_kw : Other keyword arguments to be passed to the figure() call. Note that all keywords not recognized above will be automatically included here. @@ -2993,22 +2963,8 @@ def _subplots(nrows=1, ncols=1, naxes=None, sharex=False, sharey=False, squeeze= axarr = np.empty(nplots, dtype=object) - def on_right(i): - if isinstance(secondary_y, bool): - return secondary_y - if isinstance(data, DataFrame): - return data.columns[i] in secondary_y - # Create first subplot separately, so we can share it if requested ax0 = fig.add_subplot(nrows, ncols, 1, **subplot_kw) - if on_right(0): - orig_ax = ax0 - ax0 = ax0.twinx() - ax0._get_lines.color_cycle = orig_ax._get_lines.color_cycle - - orig_ax.get_yaxis().set_visible(False) - orig_ax.right_ax = ax0 - ax0.left_ax = orig_ax if sharex: subplot_kw['sharex'] = ax0 @@ -3020,12 +2976,6 @@ def on_right(i): # convention. for i in range(1, nplots): ax = fig.add_subplot(nrows, ncols, i + 1, **subplot_kw) - if on_right(i): - orig_ax = ax - ax = ax.twinx() - ax._get_lines.color_cycle = orig_ax._get_lines.color_cycle - - orig_ax.get_yaxis().set_visible(False) axarr[i] = ax if nplots > 1: diff --git a/pandas/tseries/plotting.py b/pandas/tseries/plotting.py index e390607a0e7e2..6031482fd9927 100644 --- a/pandas/tseries/plotting.py +++ b/pandas/tseries/plotting.py @@ -4,11 +4,7 @@ """ #!!! TODO: Use the fact that axis can have units to simplify the process -import datetime as pydt -from datetime import datetime - from matplotlib import pylab -import matplotlib.units as units import numpy as np @@ -22,7 +18,7 @@ from pandas.tseries.converter import (PeriodConverter, TimeSeries_DateLocator, TimeSeries_DateFormatter) -from pandas.tools.plotting import _get_all_lines +from pandas.tools.plotting import _get_all_lines, _get_xlim #---------------------------------------------------------------------- # Plotting functions and monkey patches @@ -222,14 +218,6 @@ def _get_freq(ax, series): return freq -def _get_xlim(lines): - left, right = np.inf, -np.inf - for l in lines: - x = l.get_xdata() - left = min(x[0].ordinal, left) - right = max(x[-1].ordinal, right) - return left, right - # Patch methods for subplot. Only format_dateaxis is currently used. # Do we need the rest for convenience?