Skip to content

Commit 95a8bf4

Browse files
committed
update: integrate sanity noise model into bootstrap
1 parent 158a87d commit 95a8bf4

5 files changed

Lines changed: 85 additions & 26 deletions

File tree

src/scloop/analyzing/bootstrap.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ def run_single_bootstrap(
8282
decay_random_walk: PositiveFloat = 1.0,
8383
noise_random_walk: PositiveFloat = 1.0,
8484
seed_random_walk: int = 1,
85+
do_force_deviate_random_walk: bool = False,
8586
k_neighbors_check_equivalence: int = DEFAULT_K_NEIGHBORS_CHECK_EQUIVALENCE,
8687
method_geometric_equivalence: LoopDistMethod = DEFAULT_LOOP_DIST_METHOD,
8788
n_pairs_check_equivalence: int = DEFAULT_N_PAIRS_CHECK_EQUIVALENCE,
@@ -145,6 +146,7 @@ def run_single_bootstrap(
145146
decay_random_walk=decay_random_walk,
146147
noise_random_walk=noise_random_walk,
147148
seed_random_walk=seed_random_walk,
149+
do_force_deviate_random_walk=do_force_deviate_random_walk,
148150
bootstrap=False,
149151
do_clean_cocycle_region=True,
150152
)
@@ -169,6 +171,7 @@ def run_single_bootstrap(
169171
decay_random_walk=decay_random_walk,
170172
noise_random_walk=noise_random_walk,
171173
seed_random_walk=seed_random_walk,
174+
do_force_deviate_random_walk=do_force_deviate_random_walk,
172175
bootstrap=True,
173176
)
174177

@@ -309,6 +312,7 @@ def run_bootstrap_pipeline(
309312
with_relaxation_equivalence: bool = DEFAULT_WITH_RELAXATION_EQUIVALENCE,
310313
n_hubs_relaxation_equivalence: int = DEFAULT_N_HUBS_RELAXATION_EQUIVALENCE,
311314
max_n_edges_relaxation_equivalence: int = DEFAULT_MAX_N_EDGES_RELAXATION_EQUIVALENCE,
315+
do_force_deviate_random_walk: bool = False,
312316
**kwargs,
313317
) -> list[BootstrapResult]:
314318
results: list[BootstrapResult] = []
@@ -344,6 +348,7 @@ def run_bootstrap_pipeline(
344348
with_relaxation_equivalence=with_relaxation_equivalence,
345349
n_hubs_relaxation_equivalence=n_hubs_relaxation_equivalence,
346350
max_n_edges_relaxation_equivalence=max_n_edges_relaxation_equivalence,
351+
do_force_deviate_random_walk=do_force_deviate_random_walk,
347352
**kwargs,
348353
)
349354
tasks[task] = i

src/scloop/computing/homology.py

Lines changed: 61 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
sample_farthest_points,
3030
sample_farthest_points_randomized,
3131
)
32+
from ..utils.denoise.Sanity_py import sample_posterior_predictive_counts
3233
from ..utils.distance_metrics.frechet_py import compute_pairwise_loop_frechet
3334
from ..utils.linear_algebra_gf2 import ( # type: ignore
3435
solve_multiple_gf2_m4ri, # type: ignore[import-not-found]
@@ -38,11 +39,42 @@
3839
from ..data.containers import BoundaryMatrixD1
3940

4041

42+
def _sample_bootstrap_embedding(
    adata: AnnData,
    meta: ScloopMeta,
    selected_indices: list[int],
    sample_idx: np.ndarray,
    bootstrap_noise_model: str,
    noise_scale: float,
) -> tuple[np.ndarray, list[int]]:
    """Build the perturbed point cloud for one bootstrap replicate.

    ``sample_idx`` holds positions into ``selected_indices``; they are mapped
    back to the row indices stored in ``selected_indices`` (``boot_idx``).
    The coordinates for those rows are then produced by one of two models:

    * ``"sanity"`` with a ``"pca"`` embedding: delegated to
      ``sample_posterior_predictive_counts`` for the ``boot_idx`` rows.
      NOTE(review): presumably this returns coordinates comparable to the
      stored PCA embedding -- confirm against the helper.
    * anything else: rows of ``adata.obsm["X_<embedding_method>"]`` plus
      zero-mean Gaussian noise whose per-column standard deviation is
      ``noise_scale`` times the column standard deviation over the
      ``selected_indices`` rows.

    Args:
        adata: Object exposing the embedding under ``obsm``.
        meta: Metadata; ``meta.preprocess.embedding_method`` must be set.
        selected_indices: Row indices into the embedding that are eligible
            for sampling.
        sample_idx: Positions into ``selected_indices`` chosen by the caller.
        bootstrap_noise_model: ``"gaussian"`` or ``"sanity"``.
        noise_scale: Multiplier on the per-column standard deviation used by
            the Gaussian model.

    Returns:
        ``(X, boot_idx)``: the sampled coordinates and the corresponding row
        indices.

    Warns:
        UserWarning: if ``"sanity"`` is requested with a non-PCA embedding,
            or if the model name is not recognized. Gaussian noise is used
            in both cases, as before.
    """
    import warnings

    preprocess = meta.preprocess
    assert preprocess is not None
    embedding_method = preprocess.embedding_method
    assert embedding_method is not None

    # Translate positions within the selected subset into embedding row ids.
    boot_idx = [selected_indices[int(pos)] for pos in sample_idx.tolist()]

    wants_sanity = bootstrap_noise_model == "sanity"
    if wants_sanity and embedding_method == "pca":
        # The stored embedding is not needed here, so it is not loaded.
        X = sample_posterior_predictive_counts(
            adata=adata,
            cell_idx=np.asarray(boot_idx, dtype=np.int64),
            scale_before_pca=preprocess.scale_before_pca,
            n_pca_comps=preprocess.n_pca_comps,
        )
        return X, boot_idx

    # Make the Gaussian fallback visible instead of silently ignoring the
    # requested noise model.
    if wants_sanity:
        warnings.warn(
            "bootstrap_noise_model='sanity' requires embedding_method='pca' "
            f"(got {embedding_method!r}); falling back to Gaussian noise.",
            UserWarning,
            stacklevel=2,
        )
    elif bootstrap_noise_model != "gaussian":
        warnings.warn(
            f"Unknown bootstrap_noise_model {bootstrap_noise_model!r}; "
            "falling back to Gaussian noise.",
            UserWarning,
            stacklevel=2,
        )

    emb = np.asarray(adata.obsm[f"X_{embedding_method}"])
    # Explicit integer index arrays keep fancy indexing well-defined even
    # when an index list is empty.
    reference_rows = np.asarray(selected_indices, dtype=np.intp)
    sampled_rows = np.asarray(boot_idx, dtype=np.intp)

    # Noise is scaled by the spread of the full selected subset, not of the
    # bootstrap sample.
    std_X = np.std(emb[reference_rows], axis=0)
    X = emb[sampled_rows]
    X = X + np.random.normal(scale=std_X * noise_scale, size=X.shape)

    return X, boot_idx
70+
71+
4172
def compute_sparse_pairwise_distance(
4273
adata: AnnData,
4374
meta: ScloopMeta,
4475
bootstrap: bool = False,
4576
noise_scale: float = 1e-3,
77+
bootstrap_noise_model: str = "gaussian",
4678
thresh: Diameter_t | None = None,
4779
bootstrap_sampling: str = "resample",
4880
bootstrap_downsample_fraction: Percent_t = 2 / 3,
@@ -68,22 +100,28 @@ def compute_sparse_pairwise_distance(
68100
if bootstrap_sampling == "resample":
69101
sample_idx = np.random.choice(
70102
len(selected_indices), size=len(selected_indices), replace=True
71-
).tolist()
72-
boot_idx = [selected_indices[i] for i in sample_idx]
73-
std_X = np.std(X, axis=0)
74-
X = X[sample_idx] + np.random.normal(
75-
scale=std_X * noise_scale, size=X.shape
103+
)
104+
X, boot_idx = _sample_bootstrap_embedding(
105+
adata=adata,
106+
meta=meta,
107+
selected_indices=selected_indices,
108+
sample_idx=np.asarray(sample_idx, dtype=np.int64),
109+
bootstrap_noise_model=bootstrap_noise_model,
110+
noise_scale=noise_scale,
76111
)
77112
elif bootstrap_sampling == "fps":
78113
n_keep = max(
79114
2, int(round(len(selected_indices) * bootstrap_downsample_fraction))
80115
)
81116
n_keep = min(n_keep, len(selected_indices))
82117
sample_idx = sample_farthest_points(X, n_keep)
83-
boot_idx = [selected_indices[int(i)] for i in sample_idx.tolist()]
84-
std_X = np.std(X, axis=0)
85-
X = X[sample_idx] + np.random.normal(
86-
scale=std_X * noise_scale, size=(n_keep, X.shape[1])
118+
X, boot_idx = _sample_bootstrap_embedding(
119+
adata=adata,
120+
meta=meta,
121+
selected_indices=selected_indices,
122+
sample_idx=np.asarray(sample_idx, dtype=np.int64),
123+
bootstrap_noise_model=bootstrap_noise_model,
124+
noise_scale=noise_scale,
87125
)
88126
elif bootstrap_sampling == "fps_random":
89127
if bootstrap_fps_top_k <= 0:
@@ -97,10 +135,13 @@ def compute_sparse_pairwise_distance(
97135
sample_idx = sample_farthest_points_randomized(
98136
X, n_keep, top_k=bootstrap_fps_top_k, alpha=bootstrap_fps_alpha
99137
)
100-
boot_idx = [selected_indices[int(i)] for i in sample_idx.tolist()]
101-
std_X = np.std(X, axis=0)
102-
X = X[sample_idx] + np.random.normal(
103-
scale=std_X * noise_scale, size=(n_keep, X.shape[1])
138+
X, boot_idx = _sample_bootstrap_embedding(
139+
adata=adata,
140+
meta=meta,
141+
selected_indices=selected_indices,
142+
sample_idx=np.asarray(sample_idx, dtype=np.int64),
143+
bootstrap_noise_model=bootstrap_noise_model,
144+
noise_scale=noise_scale,
104145
)
105146
elif (
106147
bootstrap_sampling == "herding"
@@ -118,10 +159,13 @@ def compute_sparse_pairwise_distance(
118159
frequency_seed=bootstrap_herding_seed,
119160
n_features=int(bootstrap_herding_n_features),
120161
)
121-
boot_idx = [selected_indices[int(i)] for i in sample_idx.tolist()]
122-
std_X = np.std(X, axis=0)
123-
X = X[sample_idx] + np.random.normal(
124-
scale=std_X * noise_scale, size=(n_keep, X.shape[1])
162+
X, boot_idx = _sample_bootstrap_embedding(
163+
adata=adata,
164+
meta=meta,
165+
selected_indices=selected_indices,
166+
sample_idx=np.asarray(sample_idx, dtype=np.int64),
167+
bootstrap_noise_model=bootstrap_noise_model,
168+
noise_scale=noise_scale,
125169
)
126170
else:
127171
boot_idx = selected_indices

src/scloop/computing/loops.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@ def compute_loop_representatives(
128128
decay_random_walk: PositiveFloat = 1.0,
129129
noise_random_walk: PositiveFloat = 1.0,
130130
seed_random_walk: int = 1,
131+
do_force_deviate_random_walk: bool = False,
131132
bootstrap: bool = False,
132133
rank_offset: int = 0,
133134
do_clean_cocycle_region: bool = False,
@@ -231,6 +232,7 @@ def compute_loop_representatives(
231232
decay_random_walk=decay_random_walk,
232233
noise_random_walk=noise_random_walk,
233234
seed_random_walk=seed_random_walk,
235+
do_force_deviate_random_walk=do_force_deviate_random_walk,
234236
do_clean_cocycle_region=do_clean_cocycle_region,
235237
)
236238

@@ -262,11 +264,12 @@ def reconstruct_n_loop_representatives(
262264
loop_lower_pct: float = 5,
263265
loop_upper_pct: float = 95,
264266
n_cocycles_used: Count_t = DEFAULT_N_COCYCLES_USED,
265-
do_random_walk: bool = False,
267+
do_random_walk: bool = False, # random walk works but still less robust than the force deviate branch
266268
n_random_graphs: Count_t = 10,
267269
decay_random_walk: PositiveFloat = 1.0,
268270
noise_random_walk: PositiveFloat = 1.0,
269271
seed_random_walk: int = 1,
272+
do_force_deviate_random_walk: bool = False,
270273
*,
271274
do_clean_cocycle_region: bool = False,
272275
) -> Tuple[List[List[int]], List[float]]:
@@ -363,10 +366,11 @@ def reconstruct_n_loop_representatives(
363366
paths_this_round.append(path)
364367
cycles_dist.append(dist)
365368

366-
for path in paths_this_round:
367-
for u, v in zip(path[:-1], path[1:]):
368-
key = (min(u, v), max(u, v))
369-
edge_weight_dict[key] = math.inf
369+
if (not do_random_walk) or do_force_deviate_random_walk:
370+
for path in paths_this_round:
371+
for u, v in zip(path[:-1], path[1:]):
372+
key = (min(u, v), max(u, v))
373+
edge_weight_dict[key] = math.inf
370374

371375
return _select_diverse_loops(
372376
cycles=cycles_pool,

src/scloop/data/containers.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -484,6 +484,7 @@ def _compute_loop_representatives(
484484
decay_random_walk: PositiveFloat = 1.0,
485485
noise_random_walk: PositiveFloat = 1.0,
486486
seed_random_walk: int = 1,
487+
do_force_deviate_random_walk: bool = False,
487488
):
488489
assert pairwise_distance_matrix.shape is not None
489490
assert self.meta.preprocess is not None
@@ -529,6 +530,7 @@ def _compute_loop_representatives(
529530
decay_random_walk=decay_random_walk,
530531
noise_random_walk=noise_random_walk,
531532
seed_random_walk=seed_random_walk,
533+
do_force_deviate_random_walk=do_force_deviate_random_walk,
532534
bootstrap=bootstrap,
533535
rank_offset=0,
534536
)
@@ -809,6 +811,7 @@ def _bootstrap(
809811
decay_random_walk: PositiveFloat = 1.0,
810812
noise_random_walk: PositiveFloat = 1.0,
811813
seed_random_walk: int = 1,
814+
do_force_deviate_random_walk: bool = False,
812815
n_pairs_check_equivalence: Count_t = DEFAULT_N_PAIRS_CHECK_EQUIVALENCE,
813816
with_relaxation_equivalence: bool = DEFAULT_WITH_RELAXATION_EQUIVALENCE,
814817
n_hubs_relaxation_equivalence: Count_t = DEFAULT_N_HUBS_RELAXATION_EQUIVALENCE,
@@ -853,6 +856,7 @@ def _bootstrap(
853856
decay_random_walk=decay_random_walk,
854857
noise_random_walk=noise_random_walk,
855858
seed_random_walk=seed_random_walk,
859+
do_force_deviate_random_walk=do_force_deviate_random_walk,
856860
n_pairs_check_equivalence=n_pairs_check_equivalence,
857861
with_relaxation_equivalence=with_relaxation_equivalence,
858862
n_hubs_relaxation_equivalence=n_hubs_relaxation_equivalence,
@@ -931,6 +935,7 @@ def _bootstrap(
931935
decay_random_walk=decay_random_walk,
932936
noise_random_walk=noise_random_walk,
933937
seed_random_walk=seed_random_walk,
938+
do_force_deviate_random_walk=do_force_deviate_random_walk,
934939
)
935940
if verbose:
936941
logger.info("Matching bootstrapped loops to the original loops")
@@ -1072,6 +1077,7 @@ def _bootstrap_parallel(
10721077
decay_random_walk: PositiveFloat = 1.0,
10731078
noise_random_walk: PositiveFloat = 1.0,
10741079
seed_random_walk: int = 1,
1080+
do_force_deviate_random_walk: bool = False,
10751081
n_pairs_check_equivalence: int = DEFAULT_N_PAIRS_CHECK_EQUIVALENCE,
10761082
with_relaxation_equivalence: bool = DEFAULT_WITH_RELAXATION_EQUIVALENCE,
10771083
n_hubs_relaxation_equivalence: int = DEFAULT_N_HUBS_RELAXATION_EQUIVALENCE,
@@ -1113,6 +1119,7 @@ def _bootstrap_parallel(
11131119
decay_random_walk=decay_random_walk,
11141120
noise_random_walk=noise_random_walk,
11151121
seed_random_walk=seed_random_walk,
1122+
do_force_deviate_random_walk=do_force_deviate_random_walk,
11161123
k_neighbors_check_equivalence=k_neighbors_check_equivalence,
11171124
method_geometric_equivalence=method_geometric_equivalence,
11181125
n_pairs_check_equivalence=n_pairs_check_equivalence,

src/scloop/tools/_loops.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def find_loops(
5656
tightness_loops: Percent_t = 0,
5757
n_candidates: NonZeroCount_t = 1,
5858
n_bootstrap: Size_t = DEFAULT_N_BOOTSTRAP,
59-
bootstrap_sampling: str = "fps", # this does help quite a bit
59+
bootstrap_sampling: str = "fps_random",
6060
bootstrap_downsample_fraction: Percent_t = 2 / 3,
6161
bootstrap_fps_top_k: int = 5,
6262
bootstrap_fps_alpha: float = 1.0,
@@ -67,8 +67,7 @@ def find_loops(
6767
auto_shrink_boundary_matrix: bool = True,
6868
auto_shrink_factor: Percent_t = 0.9,
6969
n_max_workers: NonZeroCount_t = DEFAULT_N_MAX_WORKERS,
70-
use_parallel: bool = False,
71-
reconstruct_bootstrap_on_full_data: bool = False, # this might help but not sure how to fully disconnect loop at this point
70+
use_parallel: bool = True,
7271
verbose: bool = False,
7372
max_log_messages: int | None = None,
7473
kwargs_bootstrap: dict[str, Any] | None = None,
@@ -179,7 +178,7 @@ def find_loops(
179178
k_neighbors_check_equivalence=DEFAULT_K_NEIGHBORS_CHECK_EQUIVALENCE,
180179
n_max_workers=n_max_workers,
181180
life_pct=tightness_loops,
182-
reconstruct_on_full_data=reconstruct_bootstrap_on_full_data,
181+
reconstruct_on_full_data=False, # this is not working, revisit in the future
183182
bootstrap_sampling=bootstrap_sampling,
184183
bootstrap_downsample_fraction=bootstrap_downsample_fraction,
185184
bootstrap_fps_top_k=bootstrap_fps_top_k,

0 commit comments

Comments
 (0)