Skip to content

Commit 13848fa

Browse files
committed
update: fix diffusion map mapping
1 parent 21b801f commit 13848fa

1 file changed

Lines changed: 9 additions & 7 deletions

File tree

src/scloop/utils/denoise/Sanity_py.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -34,16 +34,20 @@ def _sample_posterior_predictive_counts(
3434
log_var: np.ndarray,
3535
library_size: np.ndarray,
3636
cell_idx: np.ndarray,
37+
n_posterior: int = 1000,
38+
ltq_var_scale: float = 0.1,
3739
) -> np.ndarray:
3840
n_cells = cell_idx.shape[0]
3941
n_genes = log_mean.shape[1]
4042
counts = np.empty((n_cells, n_genes), dtype=np.float64)
4143
for i in range(n_cells):
4244
c = cell_idx[i]
4345
for g in range(n_genes):
44-
ltq = np.random.normal(log_mean[c, g], np.sqrt(log_var[c, g]))
46+
ltq = np.random.normal(
47+
log_mean[c, g], ltq_var_scale * np.sqrt(log_var[c, g])
48+
)
4549
rate = library_size[c] * np.exp(ltq)
46-
counts[i, g] = np.random.poisson(rate)
50+
counts[i, g] = np.mean(np.random.poisson(rate, size=n_posterior))
4751
return counts
4852

4953

@@ -55,17 +59,15 @@ def sample_posterior_predictive_counts(
5559
) -> np.ndarray:
5660
log_mean = np.ascontiguousarray(adata.layers["sanity_log_mean"])
5761
log_var = np.ascontiguousarray(adata.layers["sanity_log_var"])
58-
library_size = np.asarray(
59-
adata.layers["counts"].sum(axis=1), dtype=np.float64
60-
).ravel()
62+
library_size = np.array(adata.obs["library_size_sanity"])
6163
X = _sample_posterior_predictive_counts(log_mean, log_var, library_size, cell_idx)
6264
# map back to Sanity's log-fraction scale
6365
N_c = library_size[cell_idx]
6466
X = np.log(X / N_c[:, np.newaxis] + NUMERIC_EPSILON)
6567
if scale_before_pca:
66-
X = (X - adata.var["mean"].values) / adata.var["std"].values
68+
X = (X - np.mean(log_mean, axis=0)) / np.std(log_mean, axis=0)
6769
elif not scale_before_pca and n_pca_comps is not None:
68-
X = X - adata.uns["pca"]["params"]["mean"]
70+
X = X - np.mean(log_mean, axis=0)
6971
if n_pca_comps is not None:
7072
X = X @ adata.varm["PCs"]
7173
return X

0 commit comments

Comments
 (0)