|
6 | 6 | from scipy.stats import rankdata |
7 | 7 | from scipy.ndimage import label |
8 | 8 | from spyrit.core.torch import walsh_matrix_2d |
9 | | -from pytorch_wavelets import DWTForward |
| 9 | +import ptwt |
10 | 10 |
|
11 | 11 |
|
12 | 12 | # from /misc/statistics.py |
@@ -481,11 +481,11 @@ def define_order(n: int, order: str, pdf: bool = False): |
481 | 481 | for i in range(N): |
482 | 482 |
|
483 | 483 | patt = np.asarray(h[i]) |
484 | | - binary = (patt > 0).astype(int) |
485 | | - labeled, num = label(binary) |
486 | | - if i == 0: |
487 | | - num = 0 |
488 | | - CC_values[i] = num |
| 484 | + pos = (patt > 0).astype(int) |
| 485 | + neg = (patt < 0).astype(int) |
| 486 | + _, num_pos = label(pos) |
| 487 | + _, num_neg = label(neg) |
| 488 | + CC_values[i] = num_pos + num_neg |
489 | 489 |
|
490 | 490 | score_CC = 1 / (CC_values + 1e-8) |
491 | 491 | score_CC[0] = 1 |
@@ -613,7 +613,7 @@ def sampling_map_multilevel_VDS( |
613 | 613 | levels: int, |
614 | 614 | J: int = 3, |
615 | 615 | wave: str = "sym8", |
616 | | - mode: str = "periodization", |
| 616 | + mode: str = "periodic", |
617 | 617 | seed: int = 0, |
618 | 618 | ): |
619 | 619 | """ |
@@ -648,7 +648,7 @@ def sampling_map_multilevel_VDS( |
648 | 648 | N = n**2 |
649 | 649 | H = walsh_matrix_2d(n) |
650 | 650 |
|
651 | | - dwt = DWTForward(J=J, wave=wave, mode=mode) |
| 651 | + # dwt = DWTForward(J=J, wave=wave, mode=mode) |
652 | 652 |
|
653 | 653 | lvl_sizes = torch.zeros(levels) # number of elements in each level |
654 | 654 | lvl_maps = torch.zeros(levels, n, n) |
@@ -680,7 +680,12 @@ def sampling_map_multilevel_VDS( |
680 | 680 | ) # Local coherences inside each level |
681 | 681 |
|
682 | 682 | for i in range(int(lvl_sizes[k])): |
683 | | - coeffs = dwt(H_k[i].reshape(n, n).unsqueeze(0).unsqueeze(0)) |
| 683 | + coeffs = ptwt.wavedec2( |
| 684 | + H_k[i].reshape(n, n).unsqueeze(0).unsqueeze(0), |
| 685 | + wavelet=wave, |
| 686 | + mode=mode, |
| 687 | + level=J, |
| 688 | + ) |
684 | 689 | mu_loc[i, 0] = torch.max(abs(coeffs[0])) |
685 | 690 | for j in range(J): |
686 | 691 | mu_loc[i, j + 1] = torch.max(abs(coeffs[1][2 - j])) |
|
0 commit comments