Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file modified examples/figures/fig7b.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified examples/figures/genx_plot.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
76 changes: 66 additions & 10 deletions src/stonerplots/context/multiple_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from matplotlib import pyplot as plt
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from matplotlib.ticker import FixedLocator

from ..counter import counter
from .base import PlotContextSequence, PreserveFigureMixin, RavelList
Expand Down Expand Up @@ -410,22 +411,77 @@ def _align_labels(self) -> None:
ax.yaxis.set_label_coords(label_pos, 0.5)

def _fix_limits(self, ix: int, ax: Axes) -> None:
"""Adjust the y-axis limits to ensure tick labels are inside the axes frame."""
"""Adjust the y-axis limits to ensure tick labels are inside the axes frame.

For joined (no-gap) stacked subplots the frame lines of adjacent panels share the
same y-coordinate. A tick label whose tick mark sits close to the shared edge will
have half its text rendered inside the neighbouring panel. This method expands the
y-axis limits just enough so that every visible tick mark sits at least *dy* (half
a label height in axes units) away from the top and bottom frame edges, then freezes
the tick positions with a :class:`~matplotlib.ticker.FixedLocator` so that the
enlarged limits cannot cause the auto-locator to add new edge ticks that would
recreate the problem.
"""
fig = self.figure
# Log-scale axes use unevenly spaced ticks; the formula below assumes a linear
# mapping from data to axes units, so skip the adjustment for non-linear scales.
if ax.get_yscale() != "linear":
return
ticklabels = ax.yaxis.get_ticklabels()
if not ticklabels:
return # No tick labels to adjust for
fnt_pts_val = ticklabels[0].get_fontsize()
fnt_pts = float(fnt_pts_val) if isinstance(fnt_pts_val, (int, float, str)) else 10.0
ax_height = ax.bbox.transformed(fig.transFigure.inverted()).height * fig.get_figheight() * 72
dy = 1.40 * fnt_pts / ax_height # Space needed in axes units for labels 7/5 font size.
if ax_height <= 0:
return
dy = 1.40 * fnt_pts / ax_height # Space needed in axes units for one full label height.

ylim = list(ax.get_ylim())
tr = ax.transData + ax.transAxes.inverted() # Transform data to axes units
yticks = [tr.transform((0, x))[1] for x in ax.get_yticks()] # Tick positions in axes units.

if len(yticks) > 1 and yticks[1] < dy and ix != len(self.axes) - 1: # Adjust range for non-bottom plots
ylim[0] = tr.inverted().transform((0, -dy))[1]
if len(yticks) > 2 and yticks[-2] < 1.0 - dy and ix != 0: # Adjust range for non-top plots
ylim[1] = tr.inverted().transform((0, 1 + dy))[1]
ax.set_ylim(ylim[0], ylim[1])
yticks_data = ax.get_yticks()
tr = ax.transData + ax.transAxes.inverted() # Transform: data → axes units (0–1)
yticks_axes = [tr.transform((0, x))[1] for x in yticks_data] # Tick positions in axes units.

# Identify ticks that lie within (or just outside) the visible frame so we know
# which are the extreme rendered ticks. A small tolerance of 0.01 in axes units
# admits ticks that the auto-locator may place right at the boundary.
visible = [(td, ta) for td, ta in zip(yticks_data, yticks_axes) if -0.01 <= ta <= 1.01]
if not visible:
return

t_bottom_data, t_bottom_ax = min(visible, key=lambda x: x[1])
t_top_data, t_top_ax = max(visible, key=lambda x: x[1])

adjust_lower = ix != len(self.axes) - 1 # All subplots except the bottom-most
adjust_upper = ix != 0 # All subplots except the top-most

needs_lower = adjust_lower and t_bottom_ax < dy
needs_upper = adjust_upper and t_top_ax > 1.0 - dy

new_ylim = list(ylim)

if needs_lower and needs_upper:
# Both edges need padding: solve the system so that t_bottom sits at dy and
# t_top sits at (1 - dy) in the new axes coordinate frame.
tick_range = t_top_data - t_bottom_data
if tick_range <= 0 or 1.0 - 2.0 * dy <= 0:
return
total_range = tick_range / (1.0 - 2.0 * dy) # 2.0: one dy margin on each edge
new_ylim[0] = t_bottom_data - dy * total_range
new_ylim[1] = t_top_data + dy * total_range
elif needs_lower:
# Only the bottom edge needs padding; keep the upper limit fixed.
# Solve: dy == (t_bottom_data - new_lower) / (ylim[1] - new_lower)
new_ylim[0] = (t_bottom_data - dy * ylim[1]) / (1.0 - dy)
elif needs_upper:
# Only the top edge needs padding; keep the lower limit fixed.
# Solve: (1 - dy) == (t_top_data - ylim[0]) / (new_upper - ylim[0])
new_ylim[1] = ylim[0] + (t_top_data - ylim[0]) / (1.0 - dy)

if new_ylim != ylim:
# Freeze tick positions so the enlarged limits do not cause the auto-locator
# to place new ticks near the frame edges, which would recreate the problem.
ax.yaxis.set_major_locator(FixedLocator([td for td, _ in visible]))

ax.set_ylim(new_ylim[0], new_ylim[1])
self.figure.canvas.draw()
185 changes: 185 additions & 0 deletions tests/stonerplots/test_stack_vertical.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
# -*- coding: utf-8 -*-
"""Tests for StackVertical context manager, focusing on the _fix_limits robustness fix.

These tests verify that the `_fix_limits` method in StackVertical correctly adjusts
y-axis limits for joined subplots so that tick labels do not overflow into adjacent
panels, using the algebraic approach with a FixedLocator to prevent feedback ticks.
"""
import numpy as np
import pytest
from matplotlib import pyplot as plt
from matplotlib.ticker import FixedLocator

from stonerplots import StackVertical


def _axes_dy(ax) -> float:
"""Return the minimum required tick-to-edge clearance in axes units for *ax*."""
fig = ax.get_figure()
ticklabels = ax.yaxis.get_ticklabels()
if not ticklabels:
return 0.0
fnt_pts = float(ticklabels[0].get_fontsize())
ax_height = ax.bbox.transformed(fig.transFigure.inverted()).height * fig.get_figheight() * 72
return 1.40 * fnt_pts / ax_height if ax_height > 0 else 0.0


class TestFixLimitsBottomTickPadding:
"""Verify that the lowest visible tick is pushed away from the bottom frame edge."""

def test_bottom_tick_not_too_close_to_bottom_for_non_bottom_subplot(self):
"""For every non-bottom subplot the lowest visible tick must sit >= dy from the bottom."""
plt.figure()
with StackVertical(3, joined=True) as axes:
for ax in axes:
ax.plot([0, 1, 2], [0, 1, 0])

# Inspect each non-bottom subplot (indices 0 and 1 in a 3-panel stack)
for ix, ax in enumerate(axes[:-1]):
dy = _axes_dy(ax)
if dy <= 0:
continue
tr = ax.transData + ax.transAxes.inverted()
yticks_axes = [tr.transform((0, t))[1] for t in ax.get_yticks()]
visible = [ta for ta in yticks_axes if -0.01 <= ta <= 1.01]
if not visible:
continue
lowest = min(visible)
# Allow a small numerical tolerance of 1e-6 in the comparison.
assert lowest >= dy - 1e-6, (
f"Subplot {ix}: lowest visible tick at axes pos {lowest:.4f} "
f"is closer than dy={dy:.4f} to the bottom edge"
)

plt.close("all")

def test_top_tick_not_too_close_to_top_for_non_top_subplot(self):
"""For every non-top subplot the highest visible tick must sit >= dy from the top."""
plt.figure()
with StackVertical(3, joined=True) as axes:
for ax in axes:
ax.plot([0, 1, 2], [0, 1, 0])

# Inspect each non-top subplot (indices 1 and 2 in a 3-panel stack)
for ix, ax in enumerate(axes[1:], start=1):
dy = _axes_dy(ax)
if dy <= 0:
continue
tr = ax.transData + ax.transAxes.inverted()
yticks_axes = [tr.transform((0, t))[1] for t in ax.get_yticks()]
visible = [ta for ta in yticks_axes if -0.01 <= ta <= 1.01]
if not visible:
continue
highest = max(visible)
assert highest <= 1.0 - dy + 1e-6, (
f"Subplot {ix}: highest visible tick at axes pos {highest:.4f} "
f"is closer than dy={dy:.4f} to the top edge"
)

plt.close("all")


class TestFixLimitsLockerPreventsEdgeTicks:
"""Verify that a FixedLocator is installed after a limit adjustment."""

def test_fixed_locator_installed_when_limits_are_adjusted(self):
"""At least one non-bottom/non-top subplot should have a FixedLocator after exit."""
plt.figure()
with StackVertical(3, joined=True) as axes:
for ax in axes:
ax.plot([0, 1, 2], [0, 1, 0])

# The middle subplot is most likely to need both lower and upper adjustments.
# If the data happened to be perfectly spaced the locator may remain unchanged,
# but when an adjustment is made it must be a FixedLocator.
middle_ax = axes[1]
locator = middle_ax.yaxis.get_major_locator()
if isinstance(locator, FixedLocator):
ticks = locator.locs
assert len(ticks) > 0, "FixedLocator should contain at least one tick"
plt.close("all")


class TestFixLimitsNoJoined:
"""When joined=False, _fix_limits should not be called and limits remain default."""

def test_limits_unchanged_when_not_joined(self):
"""With joined=False the context manager should not adjust y limits at all."""
plt.figure()
with StackVertical(3, joined=False) as axes:
for ax in axes:
ax.plot([0, 1, 2], [0, 1, 0])

for ax in axes:
locator = ax.yaxis.get_major_locator()
assert not isinstance(locator, FixedLocator), (
"FixedLocator should not be applied when joined=False"
)
plt.close("all")


class TestFixLimitsTwoSubplots:
"""Sanity check for a 2-subplot stack (top and bottom only, no middle)."""

def test_two_subplot_stack_exits_cleanly(self):
"""A 2-panel StackVertical must exit without raising an exception."""
plt.figure()
with StackVertical(2, joined=True) as axes:
for ax in axes:
ax.plot([0, 1, 2], [0, 0.5, 0])
plt.close("all")

def test_two_subplot_bottom_tick_ok(self):
"""In a 2-panel stack the top subplot's bottom tick must be padded."""
plt.figure()
with StackVertical(2, joined=True) as axes:
for ax in axes:
ax.plot([0, 1, 2], [0, 0.5, 0])

top_ax = axes[0]
dy = _axes_dy(top_ax)
if dy > 0:
tr = top_ax.transData + top_ax.transAxes.inverted()
yticks_axes = [tr.transform((0, t))[1] for t in top_ax.get_yticks()]
visible = [ta for ta in yticks_axes if -0.01 <= ta <= 1.01]
if visible:
assert min(visible) >= dy - 1e-6
plt.close("all")


class TestFixLimitsEdgeCases:
"""Edge-case handling in _fix_limits."""

def test_single_tick_does_not_raise(self):
"""_fix_limits must not raise when an axis has only one visible tick."""
plt.figure()
with StackVertical(2, joined=True) as axes:
axes[0].plot([0, 1], [0.5, 0.5]) # Horizontal line → single tick at 0.5
plt.close("all")

def test_large_number_of_subplots(self):
"""A 5-panel stack should produce consistent padding on every panel."""
plt.figure()
with StackVertical(5, joined=True) as axes:
for ax in axes:
ax.plot([0, 1, 2], [0, 1, 0])

for ix, ax in enumerate(axes):
dy = _axes_dy(ax)
if dy <= 0:
continue
tr = ax.transData + ax.transAxes.inverted()
yticks_axes = [tr.transform((0, t))[1] for t in ax.get_yticks()]
visible = [ta for ta in yticks_axes if -0.01 <= ta <= 1.01]
if not visible:
continue
if ix != len(axes) - 1: # non-bottom
assert min(visible) >= dy - 1e-6, f"Panel {ix}: bottom tick too close to edge"
if ix != 0: # non-top
assert max(visible) <= 1.0 - dy + 1e-6, f"Panel {ix}: top tick too close to edge"

plt.close("all")


if __name__ == "__main__":
pytest.main([__file__, "-v"])
Loading