Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
173 changes: 152 additions & 21 deletions tab_right/plotting/plot_segmentations.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,51 @@
ColorMap = Union[str, list]


def normalize_scores(scores: np.ndarray, method: str = "minmax", k: float = 2.0) -> np.ndarray:
    """Normalize scores to the range [0, 1] using the chosen scaling method.

    NaN entries are ignored when the scaling range is computed and are
    returned as NaN, so a single missing score does not wipe out the
    normalization of the remaining ones.

    Parameters
    ----------
    scores : np.ndarray
        Array of scores to normalize. Any array-like of numbers is accepted;
        it is converted to a float64 array internally.
    method : str, default="minmax"
        Scaling method to use. Options:
        - "minmax": Min-max normalization using actual min and max values
        - "std": Standard deviation based scaling using mean ± k*std
    k : float, default=2.0
        Number of standard deviations to use for "std" method scaling range.

    Returns
    -------
    np.ndarray
        Float array with the same shape as ``scores``. Non-NaN entries are in
        the range [0, 1]; NaN entries stay NaN. If the scaling range is
        degenerate (all valid scores equal, or zero spread for "std"), every
        valid entry is 0. An empty input yields an empty array.

    Raises
    ------
    ValueError
        If an unknown scaling method is provided.

    """
    # Validate the method first so a bad name is reported even for empty input.
    if method not in ("minmax", "std"):
        raise ValueError(f"Unknown method: {method}. Supported methods are 'minmax' and 'std'.")

    # Work on a float copy so integer input does not produce an integer result.
    values = np.asarray(scores, dtype=np.float64)
    nan_mask = np.isnan(values)
    valid = values[~nan_mask]

    # Nothing to scale (empty or all-NaN input): np.min/np.max would raise on
    # an empty array, so return early with a same-shaped all-NaN result.
    if valid.size == 0:
        return np.full(values.shape, np.nan, dtype=np.float64)

    if method == "std":
        center = valid.mean()
        spread = k * valid.std()
        vmin = center - spread
        vmax = center + spread
    else:
        vmin = valid.min()
        vmax = valid.max()

    # Avoid division by zero when all valid scores collapse to one value.
    if vmax - vmin == 0:
        return np.where(nan_mask, np.nan, 0.0)

    # Clip values to the computed range and normalize to [0, 1].
    # np.clip propagates NaN, so missing entries remain NaN in the output.
    clipped = np.clip(values, vmin, vmax)
    return (clipped - vmin) / (vmax - vmin)


def _prepare_data(df: pd.DataFrame) -> pd.DataFrame:
"""Prepare data for segmentation plotting by sorting.

Expand Down Expand Up @@ -86,7 +131,13 @@
return {"cmap": "RdYlGn"} # Red (low/bad) to Green (high/good)


def plot_single_segmentation(df: pd.DataFrame, lower_is_better: bool = True, backend: Backend = "plotly") -> Figure:
def plot_single_segmentation(
df: pd.DataFrame,
lower_is_better: bool = True,
backend: Backend = "plotly",
scaling_method: str = "minmax",
scaling_k: float = 2.0,
) -> Figure:
"""Plot the single segmentation of a given DataFrame as a bar chart.

This function can use either Plotly or Matplotlib as backend.
Expand All @@ -99,6 +150,10 @@
Whether lower values of the metric indicate better performance.
backend : str, default="plotly"
The plotting backend to use. Either "plotly" or "matplotlib".
scaling_method : str, default="minmax"
Method for scaling colors. Options: "minmax" or "std".
scaling_k : float, default=2.0
Number of standard deviations for "std" scaling method.

Returns
-------
Expand All @@ -107,12 +162,14 @@

"""
if backend == "plotly":
return _plot_single_segmentation_plotly(df, lower_is_better)
return _plot_single_segmentation_plotly(df, lower_is_better, scaling_method, scaling_k)
else:
return _plot_single_segmentation_matplotlib(df, lower_is_better)
return _plot_single_segmentation_matplotlib(df, lower_is_better, scaling_method, scaling_k)


def _plot_single_segmentation_plotly(df: pd.DataFrame, lower_is_better: bool = True) -> PlotlyFigure:
def _plot_single_segmentation_plotly(
df: pd.DataFrame, lower_is_better: bool = True, scaling_method: str = "minmax", scaling_k: float = 2.0
) -> PlotlyFigure:
"""Implement the single segmentation plot as a Plotly bar chart.

Parameters
Expand All @@ -121,6 +178,10 @@
See module docstring for format details.
lower_is_better : bool, default=True
Whether lower values of the metric indicate better performance.
scaling_method : str, default="minmax"
Method for scaling colors. Options: "minmax" or "std".
scaling_k : float, default=2.0
Number of standard deviations for "std" scaling method.

Returns
-------
Expand All @@ -134,14 +195,17 @@
# Get color scheme
color_scheme = _get_color_scheme(lower_is_better, "plotly")

# Normalize scores for color mapping
normalized_scores = normalize_scores(df_sorted["score"].values, method=scaling_method, k=scaling_k)

# Create a bar chart
fig = go.Figure(
data=[
go.Bar(
x=df_sorted["segment_name"].astype(str),
y=df_sorted["score"],
marker=dict(
color=df_sorted["score"],
color=normalized_scores,
colorscale=color_scheme["colorscale"],
colorbar=dict(title="Score"),
),
Expand All @@ -163,7 +227,9 @@
return fig


def _plot_single_segmentation_matplotlib(df: pd.DataFrame, lower_is_better: bool = True) -> MatplotlibFigure:
def _plot_single_segmentation_matplotlib(
df: pd.DataFrame, lower_is_better: bool = True, scaling_method: str = "minmax", scaling_k: float = 2.0
) -> MatplotlibFigure:
"""Implement the single segmentation plot as a Matplotlib bar chart.

Parameters
Expand All @@ -172,6 +238,10 @@
See module docstring for format details.
lower_is_better : bool, default=True
Whether lower values of the metric indicate better performance.
scaling_method : str, default="minmax"
Method for scaling colors. Options: "minmax" or "std".
scaling_k : float, default=2.0
Number of standard deviations for "std" scaling method.

Returns
-------
Expand All @@ -191,14 +261,22 @@
cmap_name = color_scheme["cmap"]
assert isinstance(cmap_name, str), "matplotlib cmap should be a string"

# Normalize the scores for colormapping
if len(df_sorted) > 1:
norm = plt.Normalize(float(df_sorted["score"].min()), float(df_sorted["score"].max()))
else:
# Normalize scores for color mapping
normalized_scores = normalize_scores(df_sorted["score"].values, method=scaling_method, k=scaling_k)

# Create normalization based on the scaling method
if scaling_method == "std":
# For std method, use [0, 1] range since normalized_scores are already in that range
norm = plt.Normalize(0, 1)
else:
# For minmax method, use the actual score range for the colorbar
if len(df_sorted) > 1:
norm = plt.Normalize(float(df_sorted["score"].min()), float(df_sorted["score"].max()))
else:
norm = plt.Normalize(0, 1)

Check warning on line 276 in tab_right/plotting/plot_segmentations.py

View check run for this annotation

Codecov / codecov/patch

tab_right/plotting/plot_segmentations.py#L276

Added line #L276 was not covered by tests

cmap = plt.get_cmap(cmap_name)
colors = cmap(norm(df_sorted["score"].values.astype(np.float64)))
colors = cmap(normalized_scores)

# Create bar chart
bars = ax.bar(df_sorted["segment_name"].astype(str), df_sorted["score"], color=colors)
Expand All @@ -210,8 +288,14 @@
bar.get_x() + bar.get_width() / 2.0, height + 0.01, f"{height:.3f}", ha="center", va="bottom", fontsize=9
)

# Create colorbar
sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
# Create colorbar - use the original score values for the colorbar scale
if scaling_method == "std":
# For std method, create a colorbar that shows the normalized range
sm = plt.cm.ScalarMappable(cmap=cmap, norm=plt.Normalize(0, 1))
else:
# For minmax method, use the actual score range
sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)

sm.set_array([])
cbar = plt.colorbar(sm, ax=ax)
cbar.set_label("Score")
Expand All @@ -227,7 +311,9 @@


# For backward compatibility
def plot_single_segmentation_mp(df: pd.DataFrame, lower_is_better: bool = True) -> MatplotlibFigure:
def plot_single_segmentation_mp(
df: pd.DataFrame, lower_is_better: bool = True, scaling_method: str = "minmax", scaling_k: float = 2.0
) -> MatplotlibFigure:
"""Plot the single segmentation using matplotlib (compatibility function).

This is a wrapper around plot_single_segmentation with backend="matplotlib" for backwards compatibility.
Expand All @@ -238,18 +324,26 @@
See module docstring for format details.
lower_is_better : bool, default=True
Whether lower values indicate better performance.
scaling_method : str, default="minmax"
Method for scaling colors. Options: "minmax" or "std".
scaling_k : float, default=2.0
Number of standard deviations for "std" scaling method.

Returns
-------
MatplotlibFigure
A matplotlib bar chart showing each segment with its corresponding score.

"""
return plot_single_segmentation(df, lower_is_better, backend="matplotlib")
return plot_single_segmentation(
df, lower_is_better, backend="matplotlib", scaling_method=scaling_method, scaling_k=scaling_k
)


# For backward compatibility
def plot_single_segmentation_impl(df: pd.DataFrame, lower_is_better: bool = True) -> PlotlyFigure:
def plot_single_segmentation_impl(
df: pd.DataFrame, lower_is_better: bool = True, scaling_method: str = "minmax", scaling_k: float = 2.0
) -> PlotlyFigure:
"""Implement the single segmentation plot as a Plotly bar chart (compatibility function).

This is kept for backwards compatibility and wraps _plot_single_segmentation_plotly.
Expand All @@ -260,14 +354,18 @@
See module docstring for format details.
lower_is_better : bool, default=True
Whether lower values indicate better performance.
scaling_method : str, default="minmax"
Method for scaling colors. Options: "minmax" or "std".
scaling_k : float, default=2.0
Number of standard deviations for "std" scaling method.

Returns
-------
PlotlyFigure
A Plotly bar chart.

"""
return _plot_single_segmentation_plotly(df, lower_is_better)
return _plot_single_segmentation_plotly(df, lower_is_better, scaling_method, scaling_k)

Check warning on line 368 in tab_right/plotting/plot_segmentations.py

View check run for this annotation

Codecov / codecov/patch

tab_right/plotting/plot_segmentations.py#L368

Added line #L368 was not covered by tests


@dataclass
Expand All @@ -284,6 +382,8 @@
metric_name: str = "score"
lower_is_better: bool = True
backend: Backend = "plotly"
scaling_method: str = "minmax"
scaling_k: float = 2.0

def get_heatmap_df(self) -> pd.DataFrame:
"""Get the DataFrame for the heatmap from the double segmentation df.
Expand Down Expand Up @@ -315,10 +415,25 @@
# Get color scheme
color_scheme = _get_color_scheme(self.lower_is_better, "plotly")

# Normalize scores for color mapping
scores = heatmap_df.values.flatten()
# Remove NaN values for normalization
valid_scores = scores[~np.isnan(scores)]
if len(valid_scores) > 0:
normalized_scores = normalize_scores(valid_scores, method=self.scaling_method, k=self.scaling_k)
# Create a normalized version of the heatmap
normalized_heatmap = np.full_like(heatmap_df.values, np.nan)
valid_mask = ~np.isnan(heatmap_df.values)
if len(valid_scores) > 0:
# Map back the normalized scores to the heatmap structure
normalized_heatmap[valid_mask] = normalized_scores
else:
normalized_heatmap = heatmap_df.values

Check warning on line 431 in tab_right/plotting/plot_segmentations.py

View check run for this annotation

Codecov / codecov/patch

tab_right/plotting/plot_segmentations.py#L431

Added line #L431 was not covered by tests

# Create heatmap
fig = go.Figure(
data=go.Heatmap(
z=heatmap_df.values,
z=normalized_heatmap,
x=heatmap_df.columns,
y=heatmap_df.index,
colorscale=color_scheme["colorscale"],
Expand Down Expand Up @@ -362,13 +477,27 @@
cmap = color_scheme["cmap"]
assert isinstance(cmap, str), "matplotlib cmap should be a string"

# Normalize scores for color mapping
scores = heatmap_df.values.flatten()
# Remove NaN values for normalization
valid_scores = scores[~np.isnan(scores)]

if len(valid_scores) > 0:
normalized_scores = normalize_scores(valid_scores, method=self.scaling_method, k=self.scaling_k)
# Create a normalized version of the heatmap
normalized_heatmap = np.full_like(heatmap_df.values, np.nan)
valid_mask = ~np.isnan(heatmap_df.values)
normalized_heatmap[valid_mask] = normalized_scores
else:
normalized_heatmap = heatmap_df.values

Check warning on line 492 in tab_right/plotting/plot_segmentations.py

View check run for this annotation

Codecov / codecov/patch

tab_right/plotting/plot_segmentations.py#L492

Added line #L492 was not covered by tests

# Create heatmap using pcolormesh which creates a QuadMesh collection
# First create a meshgrid for the x and y coordinates
x = np.arange(len(heatmap_df.columns) + 1)
y = np.arange(len(heatmap_df.index) + 1)

# Create the heatmap using pcolormesh
mesh = ax.pcolormesh(x, y, heatmap_df.values, cmap=cmap)
mesh = ax.pcolormesh(x, y, normalized_heatmap, cmap=cmap, vmin=0, vmax=1)

# Set x and y labels
ax.set_xticks(np.arange(len(heatmap_df.columns)) + 0.5)
Expand All @@ -381,12 +510,14 @@
cbar = fig.colorbar(mesh, ax=ax)
cbar.set_label(self.metric_name)

# Add text annotations with the values
# Add text annotations with the values (use original values, not normalized)
for i in range(len(heatmap_df.index)):
for j in range(len(heatmap_df.columns)):
value = heatmap_df.values[i, j]
if not pd.isna(value):
text_color = "black" if 0.3 < value < 0.7 else "white"
# Determine text color based on normalized value for better contrast
normalized_value = normalized_heatmap[i, j] if not pd.isna(normalized_heatmap[i, j]) else 0.5
text_color = "black" if 0.3 < normalized_value < 0.7 else "white"
ax.text(j + 0.5, i + 0.5, f"{value:.3f}", ha="center", va="center", color=text_color)

# Set titles
Expand Down
Loading
Loading