310 changes: 300 additions & 10 deletions pycs/astro/wl/hos_peaks_l1.py
@@ -10,17 +10,30 @@
from numpy import linalg as LA
from scipy.special import erf

from pycs.sparsity.sparse2d.starlet import *
from pycs.misc.cosmostat_init import *
from pycs.misc.mr_prog import *
from pycs.misc.utilHSS import *
from pycs.misc.im1d_tend import *
from pycs.misc.stats import *
from pycs.sparsity.sparse2d.dct import dct2d, idct2d
from pycs.sparsity.sparse2d.dct_inpainting import dct_inpainting
from pycs.misc.im_isospec import *
from pycs.astro.wl.mass_mapping import *
# Spherical (HEALPix) functionality required by the sphere-based routines below
from pycs.sparsity.mrs.mrs_starlet import mrs_uwttrans, CMRStarlet

# Conditional imports to avoid dependency issues with 2D starlet
try:
from pycs.misc.cosmostat_init import *
from pycs.misc.mr_prog import *
from pycs.misc.utilHSS import *
from pycs.misc.im1d_tend import *
from pycs.misc.stats import *
except ImportError as e:
print(f"Warning: Some utility functionality may not be available: {e}")

try:
from pycs.sparsity.sparse2d.starlet import *
from pycs.sparsity.sparse2d.dct import dct2d, idct2d
from pycs.sparsity.sparse2d.dct_inpainting import dct_inpainting
from pycs.misc.im_isospec import *
from pycs.astro.wl.mass_mapping import *
except ImportError as e:
print(f"Warning: Some 2D starlet functionality may not be available: {e}")
# Define minimal replacement functions for compatibility
def conv(a, b):
return a  # Placeholder: returns the first input unchanged when the 2D starlet tools are unavailable
import healpy as hp # Added for Nside calculation


@@ -683,6 +696,283 @@ def get_wtl1_sphere(
return np.array(bins_coll), np.array(l1norm_coll)


def get_peaks_sphere(healpix_map, threshold=None, ordered=True, mask=None, nside=None):
"""Identify peaks in a HEALPix map above a given threshold.

A peak, or local maximum, is defined as a pixel with a value larger than
all of its neighbors on the sphere. A mask may be provided to exclude
certain regions from the search.

Parameters
----------
healpix_map : array_like
One-dimensional HEALPix map.
threshold : float, optional
Minimum pixel value for a pixel to be considered as a peak. If not
provided, it defaults to the minimum of the unmasked pixels of
`healpix_map`, so all local maxima are kept.
ordered : bool, optional
If True, return peaks in decreasing order according to height.
mask : array_like (same shape as `healpix_map`), optional
Boolean array selecting which pixels of `healpix_map` are considered in
the peak search. A numerical array is converted to binary; only pixels
equal to zero are treated as masked and excluded.
nside : int, optional
HEALPix nside parameter. If not provided, it is inferred from the map size.

Returns
-------
pixel_indices, heights : tuple of 1D numpy arrays
Pixel indices of peak positions and their associated heights.

Notes
-----
This is the spherical version of get_peaks, designed for HEALPix maps.
It uses healpy.get_all_neighbours to find the neighbors of each pixel.
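
Examples
--------
A minimal usage sketch (the nside, noise level and threshold below are
illustrative values, not package defaults):

>>> import numpy as np
>>> import healpy as hp
>>> nside = 16
>>> kappa = np.random.normal(0.0, 0.3, hp.nside2npix(nside))
>>> pix, heights = get_peaks_sphere(kappa, threshold=0.5)
>>> pix.shape == heights.shape
True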
"""
healpix_map = np.atleast_1d(healpix_map)

# Determine nside if not provided
if nside is None:
nside = hp.npix2nside(len(healpix_map))

npix = hp.nside2npix(nside)
if len(healpix_map) != npix:
raise ValueError(f"Map size ({len(healpix_map)}) doesn't match nside={nside} (npix={npix})")

# Deal with the mask first
if mask is not None:
mask = np.atleast_1d(mask)
if mask.shape != healpix_map.shape:
print("Warning: mask not compatible with map -> ignoring.")
mask = np.ones(healpix_map.shape)
else:
# Make sure mask is binary, i.e. turn nonzero values into ones
mask = mask.astype(bool).astype(float)
else:
mask = np.ones(healpix_map.shape)

# Determine threshold level
if threshold is None:
threshold = healpix_map[mask.astype('bool')].min()
else:
threshold = max(threshold, healpix_map.min())

# Find peaks by checking each pixel against its neighbors
peak_pixels = []
peak_heights = []

for ipix in range(npix):
# Skip if pixel is masked
if mask[ipix] == 0:
continue

pixel_value = healpix_map[ipix]

# Skip if below threshold
if pixel_value < threshold:
continue

# Get neighbors of this pixel
neighbors = hp.get_all_neighbours(nside, ipix)

# Remove invalid neighbors (-1 values)
valid_neighbors = neighbors[neighbors >= 0]

# Check if this pixel is higher than all its valid neighbors
is_peak = True
for neighbor_idx in valid_neighbors:
# Skip masked neighbors
if mask[neighbor_idx] == 0:
continue
if healpix_map[neighbor_idx] >= pixel_value:
is_peak = False
break

if is_peak:
peak_pixels.append(ipix)
peak_heights.append(pixel_value)

peak_pixels = np.array(peak_pixels)
peak_heights = np.array(peak_heights)

# Sort by height if requested
if ordered and len(peak_heights) > 0:
sort_indices = np.argsort(peak_heights)[::-1] # Descending order
peak_pixels = peak_pixels[sort_indices]
peak_heights = peak_heights[sort_indices]

return peak_pixels, peak_heights


def get_wtpeaks_sphere(
Map,
nscales,
nbins=None,
Mask=None,
min_snr=None,
max_snr=None,
noise_std=None,
peak_threshold=None,
):
"""
Calculate multi-scale peak counts for a HEALPix map using spherical wavelet transform.

This function performs a spherical wavelet transform using CMRStarlet, then identifies
peaks at each scale and returns histograms of peak counts binned by peak height/SNR.
This is the spherical analog of the get_wtpeaks method.

Parameters
----------
Map : array_like
HEALPix map to analyze.
nscales : int
Number of wavelet scales to use.
nbins : int, optional
Number of bins for the histogram. Default is 40.
Mask : array_like, optional
Mask indicating where we have observations. Only pixels where Mask != 0 are considered.
min_snr : float, optional
Minimum value for binning the peak heights. If None, the minimum of the
peak heights detected at the current scale is used.
max_snr : float, optional
Maximum value for binning the peak heights. If None, the maximum of the
peak heights detected at the current scale is used.
noise_std : float, optional
Noise standard deviation. If provided, coefficients are divided by this value
to compute an SNR before binning. Default is None.
peak_threshold : float, optional
Minimum height for a pixel to be counted as a peak. If None, the minimum
(unmasked) value of each scale is used, so all local maxima are kept.

Returns
-------
tuple
(bins_coll, peaks_count_coll, peaks_pixels_coll, peaks_heights_coll) where:
- bins_coll and peaks_count_coll are arrays of shape (nscales, nbins) holding
the bin centers and the peak counts per bin for each scale
- peaks_pixels_coll and peaks_heights_coll are lists with, for each scale, the
pixel indices and heights of the detected peaks
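
Examples
--------
A minimal usage sketch (nside, nscales and noise_std are illustrative
assumptions, not package defaults):

>>> import numpy as np
>>> import healpy as hp
>>> kappa = np.random.normal(0.0, 0.3, hp.nside2npix(16))
>>> bins, counts, pix_list, h_list = get_wtpeaks_sphere(
...     kappa, nscales=3, nbins=20, noise_std=0.3)
>>> counts.shape
(3, 20)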
"""

# Set default for nbins if not provided
if nbins is None:
nbins = 40

# Determine Nside from the input map
Nside = hp.npix2nside(Map.shape[0])

# Initialize and perform CMRStarlet transform
C = CMRStarlet()
C.init_starlet(Nside, nscale=nscales)
C.transform(Map)

bins_coll = []
peaks_count_coll = []
peaks_pixels_coll = []
peaks_heights_coll = []

# Loop through each scale of the wavelet transform
for i in range(nscales):
# Get normalized wavelet coefficients for the i-th scale
if C.TabNorm[i] == 0: # Avoid division by zero if TabNorm is zero
ScaleCoeffs = C.coef[i].copy()
else:
ScaleCoeffs = C.coef[i] / C.TabNorm[i]

# If noise_std is provided, convert to SNR
if noise_std is not None:
ScaleCoeffs = ScaleCoeffs / noise_std

# Find peaks in the current scale
peak_pixels, peak_heights = get_peaks_sphere(
ScaleCoeffs,
threshold=peak_threshold,
ordered=True,
mask=Mask,
nside=Nside
)

# Store peak information
peaks_pixels_coll.append(peak_pixels)
peaks_heights_coll.append(peak_heights)

# Create histogram of peak heights if we have peaks
if len(peak_heights) > 0:
# Determine binning range
current_min_val = min_snr if min_snr is not None else np.min(peak_heights)
current_max_val = max_snr if max_snr is not None else np.max(peak_heights)

# Define thresholds and bins
thresholds = np.linspace(current_min_val, current_max_val, nbins + 1)
bins = 0.5 * (thresholds[:-1] + thresholds[1:])

# Create histogram of peak heights
counts, _ = np.histogram(peak_heights, bins=thresholds)
else:
# No peaks found, create empty bins
if min_snr is not None and max_snr is not None:
thresholds = np.linspace(min_snr, max_snr, nbins + 1)
else:
thresholds = np.linspace(0, 1, nbins + 1)
bins = 0.5 * (thresholds[:-1] + thresholds[1:])
counts = np.zeros(nbins, dtype=int)

# Store the bins and counts for this scale
bins_coll.append(bins)
peaks_count_coll.append(counts)

return (
np.array(bins_coll),
np.array(peaks_count_coll),
peaks_pixels_coll,
peaks_heights_coll
)


def test_spherical_peaks():
"""Test function for the new spherical peak counting functionality."""
import numpy as np
import healpy as hp

print("Testing spherical peak counting functions...")

# Create test data
nside = 8
npix = hp.nside2npix(nside)
test_map = np.random.normal(0, 0.3, npix)

# Add some peaks
test_map[100] = 2.0
test_map[200] = 1.5
test_map[300] = 1.0

# Test get_peaks_sphere
peak_pixels, peak_heights = get_peaks_sphere(test_map, threshold=0.5)
print(f"get_peaks_sphere: Found {len(peak_pixels)} peaks")

# Test get_wtpeaks_sphere
bins, counts, pixels_list, heights_list = get_wtpeaks_sphere(
test_map, nscales=3, nbins=10
)
print(f"get_wtpeaks_sphere: Analysis complete!")
print(f"Peaks per scale: {[len(p) for p in pixels_list]}")

# Test with SNR
bins_snr, counts_snr, _, _ = get_wtpeaks_sphere(
test_map, nscales=3, nbins=10, noise_std=0.3
)
print("SNR analysis successful!")

return True
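

# A minimal sketch of how the smoke test above could be invoked when this
# module is run directly (the guard is illustrative, not part of the library API).
if __name__ == "__main__":
    test_spherical_peaks()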


############# TESTS routine ############

