Skip to content

Commit 49abab7

Browse files
authored
Merge pull request #330 from clement-th/subsampling
Subsampling
2 parents cf510dd + 2b57edf commit 49abab7

2 files changed

Lines changed: 15 additions & 11 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,8 @@ dependencies = [
4242
"matplotlib",
4343
"scipy",
4444
"torch",
45-
"pytorch_wavelets",
45+
"ptwt",
4646
"torchvision",
47-
"PyWavelets",
4847
"sympy",
4948
"requests",
5049
"tqdm",

spyrit/misc/sampling.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from scipy.stats import rankdata
77
from scipy.ndimage import label
88
from spyrit.core.torch import walsh_matrix_2d
9-
from pytorch_wavelets import DWTForward
9+
import ptwt
1010

1111

1212
# from /misc/statistics.py
@@ -481,11 +481,11 @@ def define_order(n: int, order: str, pdf: bool = False):
481481
for i in range(N):
482482

483483
patt = np.asarray(h[i])
484-
binary = (patt > 0).astype(int)
485-
labeled, num = label(binary)
486-
if i == 0:
487-
num = 0
488-
CC_values[i] = num
484+
pos = (patt > 0).astype(int)
485+
neg = (patt < 0).astype(int)
486+
_, num_pos = label(pos)
487+
_, num_neg = label(neg)
488+
CC_values[i] = num_pos + num_neg
489489

490490
score_CC = 1 / (CC_values + 1e-8)
491491
score_CC[0] = 1
@@ -613,7 +613,7 @@ def sampling_map_multilevel_VDS(
613613
levels: int,
614614
J: int = 3,
615615
wave: str = "sym8",
616-
mode: str = "periodization",
616+
mode: str = "periodic",
617617
seed: int = 0,
618618
):
619619
"""
@@ -648,7 +648,7 @@ def sampling_map_multilevel_VDS(
648648
N = n**2
649649
H = walsh_matrix_2d(n)
650650

651-
dwt = DWTForward(J=J, wave=wave, mode=mode)
651+
# dwt = DWTForward(J=J, wave=wave, mode=mode)
652652

653653
lvl_sizes = torch.zeros(levels) # number of elements in each level
654654
lvl_maps = torch.zeros(levels, n, n)
@@ -680,7 +680,12 @@ def sampling_map_multilevel_VDS(
680680
) # Local coherences inside each level
681681

682682
for i in range(int(lvl_sizes[k])):
683-
coeffs = dwt(H_k[i].reshape(n, n).unsqueeze(0).unsqueeze(0))
683+
coeffs = ptwt.wavedec2(
684+
H_k[i].reshape(n, n).unsqueeze(0).unsqueeze(0),
685+
wavelet=wave,
686+
mode=mode,
687+
level=J,
688+
)
684689
mu_loc[i, 0] = torch.max(abs(coeffs[0]))
685690
for j in range(J):
686691
mu_loc[i, j + 1] = torch.max(abs(coeffs[1][2 - j]))

0 commit comments

Comments
 (0)