Skip to content

Commit 5331462

Browse files
Undoing the fix, due to breaking 'all' mode
- I ended up refactoring the code to try to find the problem, but it turns out my solution breaks when all mode is selected, so I have commented that part out - I am including this, because I thought the refactor is nice, but this ended up not being a good fix
1 parent 4b5cb84 commit 5331462

2 files changed

Lines changed: 117 additions & 113 deletions

File tree

plotly/_subplots.py

Lines changed: 111 additions & 109 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,12 @@
66
# properties.
77
# Note that this set does not contain `xaxis`/`yaxis` because these behave a
88
# little differently.
9+
from __future__ import annotations
910
import collections
1011

11-
import plotly.graph_objects as go
12-
from typing import Literal, Optional, Tuple, TypedDict, Iterable
12+
from typing import Literal, Optional, Tuple, TypedDict, TYPE_CHECKING
13+
if TYPE_CHECKING:
14+
from plotly.graph_objects import Layout, XAxis
1315

1416
_single_subplot_types = {"scene", "geo", "polar", "ternary", "map", "mapbox"}
1517
_subplot_types = set.union(_single_subplot_types, {"xy", "domain"})
@@ -38,8 +40,9 @@ class SubplotSpec(TypedDict):
3840
type : Literal['xy', 'scene', 'polar', 'ternary', 'map', 'mapbox', 'domain'] | str
3941
secondary_y : bool
4042
colspan : int
41-
rowspan : int
42-
l : float
43+
rowspan : int
44+
# NOTE: that this is the dictionary as defined by the documentation, so the ambiguous name 'l' can't be changed without changing the documentation
45+
l : float # noqa: E741
4346
r : float
4447
t : float
4548
b : float
@@ -759,19 +762,10 @@ def _check_hv_spacing(dimsize, spacing, name, dimvarname, dimname):
759762
)
760763
grid_ref[r][c] = subplot_refs
761764

762-
_configure_shared_axes(layout, grid_ref, specs, "x", shared_xaxes, row_dir, False)
763-
_configure_shared_axes(layout, grid_ref, specs, "y", shared_yaxes, row_dir, False)
764765

765-
any_secondary_y = any(
766-
spec["secondary_y"]
767-
for spec_row in specs
768-
for spec in spec_row
769-
if spec is not None
770-
)
771-
if any_secondary_y:
772-
_configure_shared_axes(
773-
layout, grid_ref, specs, "y", shared_yaxes, row_dir, True
774-
)
766+
_configure_shared_axes(layout, grid_ref, specs, "x", shared_xaxes, row_dir)
767+
_configure_shared_axes(layout, grid_ref, specs, "y", shared_yaxes, row_dir)
768+
775769

776770
# Build inset reference
777771
# ---------------------
@@ -903,172 +897,180 @@ def _check_hv_spacing(dimsize, spacing, name, dimvarname, dimname):
903897
return figure
904898

905899
def _configure_shared_axes(
906-
layout : go.Layout,
900+
layout : Layout,
907901
grid_ref : Tuple[Tuple[SubplotRef]],
908902
specs : Tuple[Tuple[SubplotSpec]],
909903
x_or_y : Literal['x', 'y'],
910904
shared : bool | Literal['rows', 'columns', 'all'],
911-
row_direction : Literal[1, -1],
912-
secondary_y : bool
905+
row_direction : Literal[1, -1]
913906
) -> None:
914907
'''
915-
Sets the axes to be shared, making them use the same axis
908+
Sets the axes to be shared, making them use the same axis
916909
917910
Parameters:
918911
-----------
919912
layout (go.Layout) : The layout of the figure to be updating
920913
grid_ref (Tuple[Tuple[SubplotRef]]) : The grid of subplots within the figure; grid_ref[row][column] = subplot at that coordinate
921914
specs (Tuple[Tuple[SubplotSpec]]) : The specifications of each of the subplots within the figure; specs[row][column] = specs of the subplot at that coordinate
922-
x_or_y ('x' | 'y') : The axis to make shared (x-axis or y-axis)
923-
shared ('rows' | 'columns' | 'all' | bool) : Share the axis within the row, column, or across all of the subplots (True defaults to columns mode)
924-
row_direction (1 | -1) : The directional that the rows go
925-
secondary_y (bool) : Whether there are different or shared y-axis
915+
x_or_y ('x' | 'y') : The axis to configure
916+
shared ('rows' | 'columns' | 'all' | bool) : The sharing mode, (True is 'columns' mode, False means no sharing) ie share the axis with all subplots in the corresponding row, column, or entire figure
917+
row_direction (1 | -1) : The directional that the rows go
926918
'''
927919

928920
row_count : int = len(grid_ref)
929921
column_count : int = len(grid_ref[0])
930922

931-
rows : Iterable[int] = tuple(range(row_count - 1, -1, -1)) if row_direction < 0 else tuple(range(row_count))
932-
columns : Iterable[int] = tuple(range(column_count - 1, -1, -1)) if secondary_y else tuple(range(column_count))
933-
934-
axis_index : int = 1 if secondary_y else 0
935-
layout_axis_index : int = 0 if x_or_y == 'x' else 1
923+
axis_index : int = 0 if x_or_y == 'x' else 1
936924

937-
def find_label_and_index(row_order : int | Iterable[int], column_order : int | Iterable[int]) -> Optional[Tuple[str, Tuple[int, int]]]:
925+
def find_label_and_index(row_order : int | Tuple[int], column_order : int | Tuple[int], trace_layer : int) -> Optional[Tuple[str, Tuple[int, int]]]:
938926
'''
939-
Searches the grid through the row, column order provided (doing row, then column); will only check things that appear in those lists
927+
Searches the grid through the row, column order provided (doing row, then column); will only check things that appear in those lists; ONLY WORKS WITH 2D CARTESIAN SUBPLOTS AKA 'xy' TYPE SUBPLOTS
940928
941929
Parameters:
942930
-----------
943-
row_order (int | Iterable[int]): If an int, will look only at the that row index, else it will look at all of the rows in the order of the iterable
944-
column_order (int | Iterable[int]): If an int, will only look at that column index, else it will look at all of the columns in the order of the iterable
945-
931+
row_order (int | Tuple[int]): If an int, will look only at the that row index, else it will look at all of the rows in the order of the iterable
932+
column_order (int | Tuple[int]): If an int, will only look at that column index, else it will look at all of the columns in the order of the iterable
933+
trace_layer (int) : Which axis of traces to look at [Since there can be multiple traces on one subplot ie the secondary_y traces are on layer 1]
946934
Return:
947935
-------
948936
Returns (Label : str, (Row : int, Column : int)): returning the label found, and the row and column it was found at (uses x_or_y to determine which of the axes' labels to pull)
949937
Return (None): No label was found
950938
'''
951939

952940
# Turn them into lists with one element, so that both row_order and column_order are iterables
953-
row_order : Iterable[int] = [row_order] if isinstance(row_order, int) else row_order
954-
column_order : Iterable[int] = [column_order] if isinstance(column_order, int) else column_order
941+
row_order : Tuple[int] = [row_order] if isinstance(row_order, int) else row_order
942+
column_order : Tuple[int] = [column_order] if isinstance(column_order, int) else column_order
955943

956944

957945
# Iterate through the rows and columns
958946
for row in row_order:
959947
for column in column_order:
960-
if not grid_ref[row][column] or axis_index >= len(grid_ref[row][column]):
948+
if not grid_ref[row][column]:
961949
continue
962-
963-
subplot_reference : SubplotRef = grid_ref[row][column][axis_index]
964-
spec : SubplotSpec = specs[row][column]
965950

966-
if subplot_reference is None:
967-
continue
968-
969-
span = spec['colspan'] if x_or_y == 'x' else spec['rowspan']
970-
if subplot_reference.subplot_type != 'xy' or span != 1:
951+
subplot_traces : Tuple[Optional[SubplotRef]] = grid_ref[row][column]
952+
subplot_spec : SubplotSpec = specs[row][column]
953+
954+
span = subplot_spec['colspan'] if x_or_y == 'x' else subplot_spec['rowspan']
955+
if subplot_spec['type'] != 'xy' or span != 1 or trace_layer >= len(subplot_traces):
956+
continue
957+
958+
trace = subplot_traces[trace_layer]
959+
if trace is None or trace.subplot_type != 'xy':
971960
continue
972961

973-
label_name : str = subplot_reference.layout_keys[layout_axis_index]
962+
label_name : str = trace.layout_keys[axis_index]
974963
label : str = label_name.replace("axis", "")
975964
return label, (row, column)
976965
return None
977966

978967

979-
def update_trace_axis(matched_label : str, row : int, column : int, can_remove_label : bool) -> None:
968+
def update_trace_axis(axis_label : str, row : int, column : int, trace_layer : int, can_reassign_axis : bool, can_hide_ticks : bool, can_match_axis : bool) -> None:
980969
'''
981-
Updates the trace at the given row and column with the given label, and removes the label visibility if necessary
970+
Updates the specific subplot trace at the given row and column with the given label, and removes the label visibility if necessary; ONLY WORKS WITH 2D CARTESIAN SUBPLOTS AKA 'xy' TYPE SUBPLOTS
982971
983972
Parameters:
984973
-----------
985-
matched_label (str) : The label to make the axis match (uses the x_or_y value to determine which of the axes to change), if there is a subplot at the given location
986-
row (int) : The row of the subplot within grid_ref to update
987-
column (int) : The column of the subplot within grid_ref to update
988-
can_remove_label (bool): Whether the label should be visible (only the bottom label should be visible)
989-
can_change_trace_kwargs (bool): If True the label itself can be changed directly to be the exact same axis (ie use the exact same axis in the trace keyword arguments), or if False, can only mark as matching (ie don't change the trace keyword args)
974+
axis_label (str) : The label to make the axis match (uses the x_or_y value to determine which of the axes to change), if there is a subplot at the given location
975+
row (int) : The row of the subplot within grid_ref to update
976+
column (int) : The column of the subplot within grid_ref to update
977+
trace_layer (int) : Which axis of traces to look at [Since there can be multiple traces on one subplot ie the secondary_y traces are on layer 1]
978+
can_reassign_axis (bool): If True, can change the unique axis for the shared axis in the trace keywords, otherwise, will keep using the axis name it already has
979+
can_hide_ticks (bool): If the function is allowed to hide the ticks (if True, it will hide the ticks, if False, it will leave the ticks as their current state)
980+
can_match_axis (bool): If the axis should be marked as a match to the axis label
990981
'''
991-
if not grid_ref[row][column] or axis_index >= len(grid_ref[row][column]):
992-
return
993-
994-
subplot_reference : SubplotRef = grid_ref[row][column][axis_index]
995-
spec : SubplotSpec = specs[row][column]
996-
997-
if subplot_reference is None:
998-
return
999982

1000-
span = spec['colspan'] if x_or_y == 'x' else spec['rowspan']
1001-
if subplot_reference.subplot_type != 'xy' or span != 1:
1002-
return
983+
if not grid_ref[row][column] or specs[row][column] is None:
984+
return
1003985

1004-
axis_name : str = subplot_reference.layout_keys[layout_axis_index]
1005-
axis_dimension : str = 'xaxis' if x_or_y == 'x' else 'yaxis'
1006-
axis : go.XAxis = layout[axis_name]
986+
subplot_traces : Tuple[Optional[SubplotRef]] = grid_ref[row][column]
987+
subplot_spec : SubplotSpec = specs[row][column]
1007988

1008-
axis.matches = matched_label
1009-
subplot_reference.trace_kwargs[axis_dimension] = matched_label
989+
span = subplot_spec['colspan'] if x_or_y == 'x' else subplot_spec['rowspan']
990+
if subplot_spec['type'] != 'xy' or span != 1 or trace_layer >= len(subplot_traces):
991+
return
1010992

1011-
if can_remove_label:
1012-
axis.showticklabels = False
993+
trace : Optional[SubplotRef] = subplot_traces[trace_layer]
994+
995+
if trace is None or trace.subplot_type != 'xy' or span != 1:
996+
return
997+
998+
axis_name : str = trace.layout_keys[axis_index]
999+
axis_dimension : str = 'xaxis' if x_or_y == 'x' else 'yaxis'
1000+
axis : XAxis = layout[axis_name]
10131001

1014-
def columns_mode():
1002+
if can_match_axis:
1003+
axis.matches = axis_label
1004+
1005+
if can_hide_ticks:
1006+
axis.showticklabels = False
1007+
1008+
if can_reassign_axis:
1009+
# trace.trace_kwargs[axis_dimension] = axis_label
1010+
pass
1011+
1012+
def columns_mode(rows : Tuple[int], columns : Tuple[int], trace_layer : int):
10151013
for column in columns:
10161014
# Get the label used by all the rows in the column
1017-
label_data = find_label_and_index(rows, column)
1015+
label_data = find_label_and_index(rows, column, trace_layer)
10181016
if label_data is None:
10191017
continue
1020-
column_label, (label_row, _) = label_data
1021-
# Set all of the values in the column
1018+
axis_label, (label_row, _) = label_data
10221019

1023-
can_remove_label : bool = (x_or_y == 'x')
1024-
1020+
# Set all of the values in the column
10251021
for row in rows:
1026-
if row == label_row: # Don't update the figure that the label we are matching comes from
1027-
continue
1028-
1029-
update_trace_axis(column_label, row, column, can_remove_label)
1022+
subplot_spec : SubplotSpec = specs[row][column]
1023+
can_reassign_axis : bool = (x_or_y != 'y' or not subplot_spec["secondary_y"]) # Every subplot in the same column should share the same axis if in columns mode
1024+
can_match_axis : bool = (row != label_row)
1025+
can_hide_ticks : bool = can_match_axis and x_or_y == 'x' # Sharing column wise can only hide x-axis; still need all of the different y-axis across plots in the same columns
1026+
1027+
update_trace_axis(axis_label, row, column, trace_layer, can_reassign_axis, can_hide_ticks, can_match_axis)
10301028

10311029

1032-
def rows_mode():
1030+
def rows_mode(rows : Tuple[int], columns : Tuple[int], trace_layer : int):
10331031
for row in rows:
1034-
label_data = find_label_and_index(row, columns)
1032+
label_data = find_label_and_index(row, columns, trace_layer)
10351033
if label_data is None:
10361034
continue
1037-
row_label, (_, label_column) = label_data
1038-
1039-
can_remove_label : bool = (x_or_y == 'y')
1035+
axis_label, (_, label_column) = label_data
10401036

10411037
for column in columns:
1042-
if column == label_column: # Don't update the figure that the label we are matching comes from
1043-
continue
1038+
spec : SubplotSpec = specs[row][column]
1039+
can_reassign_axis : bool = (x_or_y != 'y' or not spec['secondary_y'])
1040+
can_match_axis : bool = (column != label_column)
1041+
can_hide_ticks : bool = can_match_axis and x_or_y == 'y' # Sharing row wise can only hide y-axis; still need all of the different x-axis across plots in the same row
10441042

1045-
update_trace_axis(row_label, row, column, can_remove_label)
1043+
update_trace_axis(axis_label, row, column, trace_layer, can_reassign_axis, can_hide_ticks, can_match_axis)
10461044

1047-
def all_mode():
1048-
label_data = find_label_and_index(rows, columns)
1045+
def all_mode(rows : Tuple[int], columns : Tuple[int], trace_layer : int):
1046+
label_data = find_label_and_index(rows, columns, trace_layer)
10491047
if label_data is None:
10501048
return
1051-
label, (label_row, label_column) = label_data
1049+
axis_label, (label_row, label_column) = label_data
10521050

1053-
for row_index, row in enumerate(rows):
1051+
for row in rows:
10541052
for column in columns:
1055-
if row == label_row and column == label_column: # Don't update the figure that the label we are matching comes from
1056-
continue
1057-
1058-
if x_or_y == 'y':
1059-
can_remove_label : bool = (column < column_count - 1 if secondary_y else column > 0)
1060-
else:
1061-
can_remove_label : bool = (row_index > 0 if row_direction > 0 else row < row_count - 1)
1062-
1063-
update_trace_axis(label, row, column, can_remove_label)
1053+
spec : SubplotSpec = specs[row][column]
1054+
can_reassign_axis : bool = (x_or_y != 'y' or not spec['secondary_y'])
1055+
can_match_axis : bool = (row != label_row or column != label_column)
1056+
can_hide_ticks : bool = not ((row == label_row and x_or_y == 'x') or (column == label_column and x_or_y == 'y')) # The x-axis is across the first row, and the y-axis is along the first column
1057+
update_trace_axis(axis_label, row, column, trace_layer, can_reassign_axis, can_hide_ticks, can_match_axis)
10641058

1065-
match(shared, x_or_y, shared):
1066-
case ('columns', _, _) | (_, 'x', True): # If columns mode, or shared and x
1067-
columns_mode()
1068-
case ('rows', _, _) | (_, 'y', True): # If rows mode, or shared and y
1069-
rows_mode()
1070-
case ('all', _, _): # If all mode
1071-
all_mode()
1059+
1060+
rows : Tuple[int] = tuple(range(row_count - 1, -1, -1)) if row_direction < 0 else tuple(range(row_count))
1061+
columns : Tuple[int] = tuple(range(column_count))
1062+
BASE_TRACE_LAYER = 0
1063+
SECOND_Y_LAYER = 1
1064+
match(shared, x_or_y):
1065+
case ('columns', _) | (True, 'x'): # If columns mode, or shared and x
1066+
columns_mode(rows, columns, BASE_TRACE_LAYER)
1067+
columns_mode(tuple(reversed(rows)), columns, SECOND_Y_LAYER)
1068+
case ('rows', _) | (True, 'y'): # If rows mode, or shared and y
1069+
rows_mode(rows, columns, BASE_TRACE_LAYER)
1070+
rows_mode(rows, tuple(reversed(columns)), SECOND_Y_LAYER)
1071+
case ('all', _): # If all mode
1072+
all_mode(rows, columns, BASE_TRACE_LAYER)
1073+
all_mode(tuple(reversed(rows)), tuple(reversed(columns)), SECOND_Y_LAYER)
10721074
case _: # If reached the other case
10731075
return
10741076

tests/test_core/test_subplots/test_make_subplots.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1465,6 +1465,8 @@ def test_subplot_titles_shared_axes_rows_columns(self):
14651465
shared_xaxes="rows",
14661466
shared_yaxes="columns",
14671467
)
1468+
print(f'Expected {expected}')
1469+
print(f'Actual: {fig}')
14681470
self.assertEqual(fig.to_plotly_json(), expected.to_plotly_json())
14691471

14701472
def test_subplot_titles_irregular_layout(self):
@@ -1848,8 +1850,8 @@ def test_secondary_y_subplots(self):
18481850
fig.add_scatter(y=[0, 2, 4], name="Fifth", row=2, col=1)
18491851
fig.add_scatter(y=[2, 1, 3], name="Sixth", row=2, col=1, secondary_y=True)
18501852

1851-
fig.add_scatter(y=[2, 4, 0], name="Fifth", row=2, col=2)
1852-
fig.add_scatter(y=[2, 3, 6], name="Sixth", row=2, col=2, secondary_y=True)
1853+
fig.add_scatter(y=[2, 4, 0], name="Seventh", row=2, col=2)
1854+
fig.add_scatter(y=[2, 3, 6], name="Eighth", row=2, col=2, secondary_y=True)
18531855

18541856
fig.update_traces(uid=None)
18551857

@@ -1899,14 +1901,14 @@ def test_secondary_y_subplots(self):
18991901
"yaxis": "y6",
19001902
},
19011903
{
1902-
"name": "Fifth",
1904+
"name": "Seventh",
19031905
"type": "scatter",
19041906
"xaxis": "x4",
19051907
"y": [2, 4, 0],
19061908
"yaxis": "y7",
19071909
},
19081910
{
1909-
"name": "Sixth",
1911+
"name": "Eighth",
19101912
"type": "scatter",
19111913
"xaxis": "x4",
19121914
"y": [2, 3, 6],

0 commit comments

Comments
 (0)