|
6 | 6 | # properties. |
7 | 7 | # Note that this set does not contain `xaxis`/`yaxis` because these behave a |
8 | 8 | # little differently. |
| 9 | +from __future__ import annotations |
9 | 10 | import collections |
10 | 11 |
|
11 | | -import plotly.graph_objects as go |
12 | | -from typing import Literal, Optional, Tuple, TypedDict, Iterable |
| 12 | +from typing import Literal, Optional, Tuple, TypedDict, TYPE_CHECKING |
| 13 | +if TYPE_CHECKING: |
| 14 | + from plotly.graph_objects import Layout, XAxis |
13 | 15 |
|
14 | 16 | _single_subplot_types = {"scene", "geo", "polar", "ternary", "map", "mapbox"} |
15 | 17 | _subplot_types = set.union(_single_subplot_types, {"xy", "domain"}) |
@@ -38,8 +40,9 @@ class SubplotSpec(TypedDict): |
38 | 40 | type : Literal['xy', 'scene', 'polar', 'ternary', 'map', 'mapbox', 'domain'] | str |
39 | 41 | secondary_y : bool |
40 | 42 | colspan : int |
41 | | - rowspan : int |
42 | | - l : float |
| 43 | + rowspan : int |
| 44 | + # NOTE: that this is the dictionary as defined by the documentation, so the ambiguous name 'l' can't be changed without changing the documentation |
| 45 | + l : float # noqa: E741 |
43 | 46 | r : float |
44 | 47 | t : float |
45 | 48 | b : float |
@@ -759,19 +762,10 @@ def _check_hv_spacing(dimsize, spacing, name, dimvarname, dimname): |
759 | 762 | ) |
760 | 763 | grid_ref[r][c] = subplot_refs |
761 | 764 |
|
762 | | - _configure_shared_axes(layout, grid_ref, specs, "x", shared_xaxes, row_dir, False) |
763 | | - _configure_shared_axes(layout, grid_ref, specs, "y", shared_yaxes, row_dir, False) |
764 | 765 |
|
765 | | - any_secondary_y = any( |
766 | | - spec["secondary_y"] |
767 | | - for spec_row in specs |
768 | | - for spec in spec_row |
769 | | - if spec is not None |
770 | | - ) |
771 | | - if any_secondary_y: |
772 | | - _configure_shared_axes( |
773 | | - layout, grid_ref, specs, "y", shared_yaxes, row_dir, True |
774 | | - ) |
| 766 | + _configure_shared_axes(layout, grid_ref, specs, "x", shared_xaxes, row_dir) |
| 767 | + _configure_shared_axes(layout, grid_ref, specs, "y", shared_yaxes, row_dir) |
| 768 | + |
775 | 769 |
|
776 | 770 | # Build inset reference |
777 | 771 | # --------------------- |
@@ -903,172 +897,180 @@ def _check_hv_spacing(dimsize, spacing, name, dimvarname, dimname): |
903 | 897 | return figure |
904 | 898 |
|
905 | 899 | def _configure_shared_axes( |
906 | | - layout : go.Layout, |
| 900 | + layout : Layout, |
907 | 901 | grid_ref : Tuple[Tuple[SubplotRef]], |
908 | 902 | specs : Tuple[Tuple[SubplotSpec]], |
909 | 903 | x_or_y : Literal['x', 'y'], |
910 | 904 | shared : bool | Literal['rows', 'columns', 'all'], |
911 | | - row_direction : Literal[1, -1], |
912 | | - secondary_y : bool |
| 905 | + row_direction : Literal[1, -1] |
913 | 906 | ) -> None: |
914 | 907 | ''' |
915 | | - Sets the axes to be shared, making them use the same axis |
| 908 | + Sets the axes to be shared, making them use the same axis |
916 | 909 |
|
917 | 910 | Parameters: |
918 | 911 | ----------- |
919 | 912 | layout (go.Layout) : The layout of the figure to be updating |
920 | 913 | grid_ref (Tuple[Tuple[SubplotRef]]) : The grid of subplots within the figure; grid_ref[row][column] = subplot at that coordinate |
921 | 914 | specs (Tuple[Tuple[SubplotSpec]]) : The specifications of each of the subplots within the figure; specs[row][column] = specs of the subplot at that coordinate |
922 | | - x_or_y ('x' | 'y') : The axis to make shared (x-axis or y-axis) |
923 | | - shared ('rows' | 'columns' | 'all' | bool) : Share the axis within the row, column, or across all of the subplots (True defaults to columns mode) |
924 | | - row_direction (1 | -1) : The directional that the rows go |
925 | | - secondary_y (bool) : Whether there are different or shared y-axis |
| 915 | + x_or_y ('x' | 'y') : The axis to configure |
| 916 | + shared ('rows' | 'columns' | 'all' | bool) : The sharing mode, (True is 'columns' mode, False means no sharing) ie share the axis with all subplots in the corresponding row, column, or entire figure |
| 917 | + row_direction (1 | -1) : The directional that the rows go |
926 | 918 | ''' |
927 | 919 |
|
928 | 920 | row_count : int = len(grid_ref) |
929 | 921 | column_count : int = len(grid_ref[0]) |
930 | 922 |
|
931 | | - rows : Iterable[int] = tuple(range(row_count - 1, -1, -1)) if row_direction < 0 else tuple(range(row_count)) |
932 | | - columns : Iterable[int] = tuple(range(column_count - 1, -1, -1)) if secondary_y else tuple(range(column_count)) |
933 | | - |
934 | | - axis_index : int = 1 if secondary_y else 0 |
935 | | - layout_axis_index : int = 0 if x_or_y == 'x' else 1 |
| 923 | + axis_index : int = 0 if x_or_y == 'x' else 1 |
936 | 924 |
|
937 | | - def find_label_and_index(row_order : int | Iterable[int], column_order : int | Iterable[int]) -> Optional[Tuple[str, Tuple[int, int]]]: |
| 925 | + def find_label_and_index(row_order : int | Tuple[int], column_order : int | Tuple[int], trace_layer : int) -> Optional[Tuple[str, Tuple[int, int]]]: |
938 | 926 | ''' |
939 | | - Searches the grid through the row, column order provided (doing row, then column); will only check things that appear in those lists |
| 927 | + Searches the grid through the row, column order provided (doing row, then column); will only check things that appear in those lists; ONLY WORKS WITH 2D CARTESIAN SUBPLOTS AKA 'xy' TYPE SUBPLOTS |
940 | 928 |
|
941 | 929 | Parameters: |
942 | 930 | ----------- |
943 | | - row_order (int | Iterable[int]): If an int, will look only at the that row index, else it will look at all of the rows in the order of the iterable |
944 | | - column_order (int | Iterable[int]): If an int, will only look at that column index, else it will look at all of the columns in the order of the iterable |
945 | | -
|
| 931 | + row_order (int | Tuple[int]): If an int, will look only at the that row index, else it will look at all of the rows in the order of the iterable |
| 932 | + column_order (int | Tuple[int]): If an int, will only look at that column index, else it will look at all of the columns in the order of the iterable |
| 933 | + trace_layer (int) : Which axis of traces to look at [Since there can be multiple traces on one subplot ie the secondary_y traces are on layer 1] |
946 | 934 | Return: |
947 | 935 | ------- |
948 | 936 | Returns (Label : str, (Row : int, Column : int)): returning the label found, and the row and column it was found at (uses x_or_y to determine which of the axes' labels to pull) |
949 | 937 | Return (None): No label was found |
950 | 938 | ''' |
951 | 939 |
|
952 | 940 | # Turn them into lists with one element, so that both row_order and column_order are iterables |
953 | | - row_order : Iterable[int] = [row_order] if isinstance(row_order, int) else row_order |
954 | | - column_order : Iterable[int] = [column_order] if isinstance(column_order, int) else column_order |
| 941 | + row_order : Tuple[int] = [row_order] if isinstance(row_order, int) else row_order |
| 942 | + column_order : Tuple[int] = [column_order] if isinstance(column_order, int) else column_order |
955 | 943 |
|
956 | 944 |
|
957 | 945 | # Iterate through the rows and columns |
958 | 946 | for row in row_order: |
959 | 947 | for column in column_order: |
960 | | - if not grid_ref[row][column] or axis_index >= len(grid_ref[row][column]): |
| 948 | + if not grid_ref[row][column]: |
961 | 949 | continue |
962 | | - |
963 | | - subplot_reference : SubplotRef = grid_ref[row][column][axis_index] |
964 | | - spec : SubplotSpec = specs[row][column] |
965 | 950 |
|
966 | | - if subplot_reference is None: |
967 | | - continue |
968 | | - |
969 | | - span = spec['colspan'] if x_or_y == 'x' else spec['rowspan'] |
970 | | - if subplot_reference.subplot_type != 'xy' or span != 1: |
| 951 | + subplot_traces : Tuple[Optional[SubplotRef]] = grid_ref[row][column] |
| 952 | + subplot_spec : SubplotSpec = specs[row][column] |
| 953 | + |
| 954 | + span = subplot_spec['colspan'] if x_or_y == 'x' else subplot_spec['rowspan'] |
| 955 | + if subplot_spec['type'] != 'xy' or span != 1 or trace_layer >= len(subplot_traces): |
| 956 | + continue |
| 957 | + |
| 958 | + trace = subplot_traces[trace_layer] |
| 959 | + if trace is None or trace.subplot_type != 'xy': |
971 | 960 | continue |
972 | 961 |
|
973 | | - label_name : str = subplot_reference.layout_keys[layout_axis_index] |
| 962 | + label_name : str = trace.layout_keys[axis_index] |
974 | 963 | label : str = label_name.replace("axis", "") |
975 | 964 | return label, (row, column) |
976 | 965 | return None |
977 | 966 |
|
978 | 967 |
|
979 | | - def update_trace_axis(matched_label : str, row : int, column : int, can_remove_label : bool) -> None: |
| 968 | + def update_trace_axis(axis_label : str, row : int, column : int, trace_layer : int, can_reassign_axis : bool, can_hide_ticks : bool, can_match_axis : bool) -> None: |
980 | 969 | ''' |
981 | | - Updates the trace at the given row and column with the given label, and removes the label visibility if necessary |
| 970 | + Updates the specific subplot trace at the given row and column with the given label, and removes the label visibility if necessary; ONLY WORKS WITH 2D CARTESIAN SUBPLOTS AKA 'xy' TYPE SUBPLOTS |
982 | 971 |
|
983 | 972 | Parameters: |
984 | 973 | ----------- |
985 | | - matched_label (str) : The label to make the axis match (uses the x_or_y value to determine which of the axes to change), if there is a subplot at the given location |
986 | | - row (int) : The row of the subplot within grid_ref to update |
987 | | - column (int) : The column of the subplot within grid_ref to update |
988 | | - can_remove_label (bool): Whether the label should be visible (only the bottom label should be visible) |
989 | | - can_change_trace_kwargs (bool): If True the label itself can be changed directly to be the exact same axis (ie use the exact same axis in the trace keyword arguments), or if False, can only mark as matching (ie don't change the trace keyword args) |
| 974 | + axis_label (str) : The label to make the axis match (uses the x_or_y value to determine which of the axes to change), if there is a subplot at the given location |
| 975 | + row (int) : The row of the subplot within grid_ref to update |
| 976 | + column (int) : The column of the subplot within grid_ref to update |
| 977 | + trace_layer (int) : Which axis of traces to look at [Since there can be multiple traces on one subplot ie the secondary_y traces are on layer 1] |
| 978 | + can_reassign_axis (bool): If True, can change the unique axis for the shared axis in the trace keywords, otherwise, will keep using the axis name it already has |
| 979 | + can_hide_ticks (bool): If the function is allowed to hide the ticks (if True, it will hide the ticks, if False, it will leave the ticks as their current state) |
| 980 | + can_match_axis (bool): If the axis should be marked as a match to the axis label |
990 | 981 | ''' |
991 | | - if not grid_ref[row][column] or axis_index >= len(grid_ref[row][column]): |
992 | | - return |
993 | | - |
994 | | - subplot_reference : SubplotRef = grid_ref[row][column][axis_index] |
995 | | - spec : SubplotSpec = specs[row][column] |
996 | | - |
997 | | - if subplot_reference is None: |
998 | | - return |
999 | 982 |
|
1000 | | - span = spec['colspan'] if x_or_y == 'x' else spec['rowspan'] |
1001 | | - if subplot_reference.subplot_type != 'xy' or span != 1: |
1002 | | - return |
| 983 | + if not grid_ref[row][column] or specs[row][column] is None: |
| 984 | + return |
1003 | 985 |
|
1004 | | - axis_name : str = subplot_reference.layout_keys[layout_axis_index] |
1005 | | - axis_dimension : str = 'xaxis' if x_or_y == 'x' else 'yaxis' |
1006 | | - axis : go.XAxis = layout[axis_name] |
| 986 | + subplot_traces : Tuple[Optional[SubplotRef]] = grid_ref[row][column] |
| 987 | + subplot_spec : SubplotSpec = specs[row][column] |
1007 | 988 |
|
1008 | | - axis.matches = matched_label |
1009 | | - subplot_reference.trace_kwargs[axis_dimension] = matched_label |
| 989 | + span = subplot_spec['colspan'] if x_or_y == 'x' else subplot_spec['rowspan'] |
| 990 | + if subplot_spec['type'] != 'xy' or span != 1 or trace_layer >= len(subplot_traces): |
| 991 | + return |
1010 | 992 |
|
1011 | | - if can_remove_label: |
1012 | | - axis.showticklabels = False |
| 993 | + trace : Optional[SubplotRef] = subplot_traces[trace_layer] |
| 994 | + |
| 995 | + if trace is None or trace.subplot_type != 'xy' or span != 1: |
| 996 | + return |
| 997 | + |
| 998 | + axis_name : str = trace.layout_keys[axis_index] |
| 999 | + axis_dimension : str = 'xaxis' if x_or_y == 'x' else 'yaxis' |
| 1000 | + axis : XAxis = layout[axis_name] |
1013 | 1001 |
|
1014 | | - def columns_mode(): |
| 1002 | + if can_match_axis: |
| 1003 | + axis.matches = axis_label |
| 1004 | + |
| 1005 | + if can_hide_ticks: |
| 1006 | + axis.showticklabels = False |
| 1007 | + |
| 1008 | + if can_reassign_axis: |
| 1009 | + # trace.trace_kwargs[axis_dimension] = axis_label |
| 1010 | + pass |
| 1011 | + |
| 1012 | + def columns_mode(rows : Tuple[int], columns : Tuple[int], trace_layer : int): |
1015 | 1013 | for column in columns: |
1016 | 1014 | # Get the label used by all the rows in the column |
1017 | | - label_data = find_label_and_index(rows, column) |
| 1015 | + label_data = find_label_and_index(rows, column, trace_layer) |
1018 | 1016 | if label_data is None: |
1019 | 1017 | continue |
1020 | | - column_label, (label_row, _) = label_data |
1021 | | - # Set all of the values in the column |
| 1018 | + axis_label, (label_row, _) = label_data |
1022 | 1019 |
|
1023 | | - can_remove_label : bool = (x_or_y == 'x') |
1024 | | - |
| 1020 | + # Set all of the values in the column |
1025 | 1021 | for row in rows: |
1026 | | - if row == label_row: # Don't update the figure that the label we are matching comes from |
1027 | | - continue |
1028 | | - |
1029 | | - update_trace_axis(column_label, row, column, can_remove_label) |
| 1022 | + subplot_spec : SubplotSpec = specs[row][column] |
| 1023 | + can_reassign_axis : bool = (x_or_y != 'y' or not subplot_spec["secondary_y"]) # Every subplot in the same column should share the same axis if in columns mode |
| 1024 | + can_match_axis : bool = (row != label_row) |
| 1025 | + can_hide_ticks : bool = can_match_axis and x_or_y == 'x' # Sharing column wise can only hide x-axis; still need all of the different y-axis across plots in the same columns |
| 1026 | + |
| 1027 | + update_trace_axis(axis_label, row, column, trace_layer, can_reassign_axis, can_hide_ticks, can_match_axis) |
1030 | 1028 |
|
1031 | 1029 |
|
1032 | | - def rows_mode(): |
| 1030 | + def rows_mode(rows : Tuple[int], columns : Tuple[int], trace_layer : int): |
1033 | 1031 | for row in rows: |
1034 | | - label_data = find_label_and_index(row, columns) |
| 1032 | + label_data = find_label_and_index(row, columns, trace_layer) |
1035 | 1033 | if label_data is None: |
1036 | 1034 | continue |
1037 | | - row_label, (_, label_column) = label_data |
1038 | | - |
1039 | | - can_remove_label : bool = (x_or_y == 'y') |
| 1035 | + axis_label, (_, label_column) = label_data |
1040 | 1036 |
|
1041 | 1037 | for column in columns: |
1042 | | - if column == label_column: # Don't update the figure that the label we are matching comes from |
1043 | | - continue |
| 1038 | + spec : SubplotSpec = specs[row][column] |
| 1039 | + can_reassign_axis : bool = (x_or_y != 'y' or not spec['secondary_y']) |
| 1040 | + can_match_axis : bool = (column != label_column) |
| 1041 | + can_hide_ticks : bool = can_match_axis and x_or_y == 'y' # Sharing row wise can only hide y-axis; still need all of the different x-axis across plots in the same row |
1044 | 1042 |
|
1045 | | - update_trace_axis(row_label, row, column, can_remove_label) |
| 1043 | + update_trace_axis(axis_label, row, column, trace_layer, can_reassign_axis, can_hide_ticks, can_match_axis) |
1046 | 1044 |
|
1047 | | - def all_mode(): |
1048 | | - label_data = find_label_and_index(rows, columns) |
| 1045 | + def all_mode(rows : Tuple[int], columns : Tuple[int], trace_layer : int): |
| 1046 | + label_data = find_label_and_index(rows, columns, trace_layer) |
1049 | 1047 | if label_data is None: |
1050 | 1048 | return |
1051 | | - label, (label_row, label_column) = label_data |
| 1049 | + axis_label, (label_row, label_column) = label_data |
1052 | 1050 |
|
1053 | | - for row_index, row in enumerate(rows): |
| 1051 | + for row in rows: |
1054 | 1052 | for column in columns: |
1055 | | - if row == label_row and column == label_column: # Don't update the figure that the label we are matching comes from |
1056 | | - continue |
1057 | | - |
1058 | | - if x_or_y == 'y': |
1059 | | - can_remove_label : bool = (column < column_count - 1 if secondary_y else column > 0) |
1060 | | - else: |
1061 | | - can_remove_label : bool = (row_index > 0 if row_direction > 0 else row < row_count - 1) |
1062 | | - |
1063 | | - update_trace_axis(label, row, column, can_remove_label) |
| 1053 | + spec : SubplotSpec = specs[row][column] |
| 1054 | + can_reassign_axis : bool = (x_or_y != 'y' or not spec['secondary_y']) |
| 1055 | + can_match_axis : bool = (row != label_row or column != label_column) |
| 1056 | + can_hide_ticks : bool = not ((row == label_row and x_or_y == 'x') or (column == label_column and x_or_y == 'y')) # The x-axis is across the first row, and the y-axis is along the first column |
| 1057 | + update_trace_axis(axis_label, row, column, trace_layer, can_reassign_axis, can_hide_ticks, can_match_axis) |
1064 | 1058 |
|
1065 | | - match(shared, x_or_y, shared): |
1066 | | - case ('columns', _, _) | (_, 'x', True): # If columns mode, or shared and x |
1067 | | - columns_mode() |
1068 | | - case ('rows', _, _) | (_, 'y', True): # If rows mode, or shared and y |
1069 | | - rows_mode() |
1070 | | - case ('all', _, _): # If all mode |
1071 | | - all_mode() |
| 1059 | + |
| 1060 | + rows : Tuple[int] = tuple(range(row_count - 1, -1, -1)) if row_direction < 0 else tuple(range(row_count)) |
| 1061 | + columns : Tuple[int] = tuple(range(column_count)) |
| 1062 | + BASE_TRACE_LAYER = 0 |
| 1063 | + SECOND_Y_LAYER = 1 |
| 1064 | + match(shared, x_or_y): |
| 1065 | + case ('columns', _) | (True, 'x'): # If columns mode, or shared and x |
| 1066 | + columns_mode(rows, columns, BASE_TRACE_LAYER) |
| 1067 | + columns_mode(tuple(reversed(rows)), columns, SECOND_Y_LAYER) |
| 1068 | + case ('rows', _) | (True, 'y'): # If rows mode, or shared and y |
| 1069 | + rows_mode(rows, columns, BASE_TRACE_LAYER) |
| 1070 | + rows_mode(rows, tuple(reversed(columns)), SECOND_Y_LAYER) |
| 1071 | + case ('all', _): # If all mode |
| 1072 | + all_mode(rows, columns, BASE_TRACE_LAYER) |
| 1073 | + all_mode(tuple(reversed(rows)), tuple(reversed(columns)), SECOND_Y_LAYER) |
1072 | 1074 | case _: # If reached the other case |
1073 | 1075 | return |
1074 | 1076 |
|
|
0 commit comments