Skip to content

Commit a435afe

Browse files
committed
feat: expose the Reference class and allow it to be passed to Dataset.open to avoid data duplication. feat: begin work for returning variant info
1 parent a192755 commit a435afe

File tree

10 files changed

+89
-1444
lines changed

10 files changed

+89
-1444
lines changed

docs/source/api.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,10 @@
3131
3232
.. autofunction:: get_dummy_dataset
3333
34+
.. autoclass:: Reference
35+
:members:
36+
:exclude-members: __new__, __init__
37+
3438
.. autoclass:: RaggedDataset
3539
:exclude-members: __new__, __init__
3640

python/genvarloader/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from ._bigwig import BigWigs
44
from ._dataset._impl import ArrayDataset, Dataset, RaggedDataset
5+
from ._dataset._reconstruct import Reference
56
from ._dataset._write import write
67
from ._dummy import get_dummy_dataset
78
from ._ragged import Ragged
@@ -19,4 +20,5 @@
1920
"get_dummy_dataset",
2021
"ArrayDataset",
2122
"RaggedDataset",
23+
"Reference",
2224
]

python/genvarloader/_dataset/_impl.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def open(
107107
@staticmethod
108108
def open(
109109
path: str | Path,
110-
reference: str | Path,
110+
reference: str | Path | Reference,
111111
jitter: int = 0,
112112
rng: int | np.random.Generator | None = False,
113113
deterministic: bool = True,
@@ -116,7 +116,7 @@ def open(
116116
@staticmethod
117117
def open(
118118
path: str | Path,
119-
reference: str | Path | None = None,
119+
reference: str | Path | Reference | None = None,
120120
jitter: int = 0,
121121
rng: int | np.random.Generator | None = False,
122122
deterministic: bool = True,
@@ -202,7 +202,10 @@ def open(
202202
logger.info(
203203
"Loading reference genome into memory. This typically has a modest memory footprint (a few GB) and greatly improves performance."
204204
)
205-
_reference = Reference.from_path_and_contigs(reference, contigs)
205+
if isinstance(reference, Reference):
206+
_reference = reference
207+
else:
208+
_reference = Reference.from_path(reference, contigs)
206209
seqs = Seqs(reference=_reference)
207210
tracks = Tracks.from_path(path, regions, len(samples))
208211
tracks = tracks.with_tracks(list(tracks.intervals))
@@ -211,7 +214,10 @@ def open(
211214
logger.info(
212215
"Loading reference genome into memory. This typically has a modest memory footprint (a few GB) and greatly improves performance."
213216
)
214-
_reference = Reference.from_path_and_contigs(reference, contigs)
217+
if isinstance(reference, Reference):
218+
_reference = reference
219+
else:
220+
_reference = Reference.from_path(reference, contigs)
215221
assert phased is not None
216222
assert ploidy is not None
217223
seqs = Haps.from_path(
@@ -228,7 +234,10 @@ def open(
228234
logger.info(
229235
"Loading reference genome into memory. This typically has a modest memory footprint (a few GB) and greatly improves performance."
230236
)
231-
_reference = Reference.from_path_and_contigs(reference, contigs)
237+
if isinstance(reference, Reference):
238+
_reference = reference
239+
else:
240+
_reference = Reference.from_path(reference, contigs)
232241
assert phased is not None
233242
assert ploidy is not None
234243
seqs = Haps.from_path(
@@ -253,7 +262,8 @@ def open(
253262
)
254263
out_of_bounds = bed.select(
255264
(
256-
pl.col("chromStart") >= pl.col("chrom").replace_strict(contig_lengths)
265+
pl.col("chromStart")
266+
>= pl.col("chrom").replace_strict(contig_lengths)
257267
).any()
258268
).item()
259269
if out_of_bounds:

python/genvarloader/_dataset/_reconstruct.py

Lines changed: 36 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
RaggedIntervals,
3535
)
3636
from .._utils import _lengths_to_offsets, _normalize_contig_name
37-
from .._variants._records import VLenAlleles
37+
from .._variants._records import RaggedAlleles
3838
from ._genotypes import (
3939
SparseGenotypes,
4040
SparseSomaticGenotypes,
@@ -52,23 +52,49 @@
5252

5353
@define
5454
class Reference:
55+
"""A reference genome kept in-memory. Typically this is only instantiated to be
56+
passed to :meth:`Dataset.open <genvarloader.Dataset.open>` and avoid data duplication.
57+
58+
.. note::
59+
Do not instantiate this class directly. Use :meth:`Reference.from_path` instead.
60+
"""
61+
5562
reference: NDArray[np.uint8]
5663
contigs: List[str]
5764
offsets: NDArray[np.uint64]
5865
pad_char: int
5966

6067
@classmethod
61-
def from_path_and_contigs(cls, fasta: Union[str, Path], contigs: List[str]):
68+
def from_path(cls, fasta: Union[str, Path], contigs: List[str] | None = None):
69+
"""Load a reference genome from a FASTA file.
70+
71+
Parameters
72+
----------
73+
fasta
74+
Path to the FASTA file.
75+
contigs
76+
List of contig names to load. If None, all contigs in the FASTA file are loaded.
77+
Can be either UCSC or Ensembl style (i.e. with or without the "chr" prefix) and
78+
will be handled appropriately to match the underlying FASTA.
79+
"""
6280
_fasta = Fasta("ref", fasta, "N")
6381

6482
if not _fasta.cache_path.exists():
6583
logger.info("Memory-mapping FASTA file for faster access.")
6684
_fasta._write_to_cache()
6785

68-
contigs = cast(
69-
List[str],
70-
[_normalize_contig_name(c, _fasta.contigs) for c in contigs],
71-
)
86+
if contigs is None:
87+
contigs = list(_fasta.contigs)
88+
89+
_contigs = [_normalize_contig_name(c, _fasta.contigs) for c in contigs]
90+
if unmapped := [
91+
source for source, mapped in zip(contigs, _contigs) if mapped is None
92+
]:
93+
raise ValueError(
94+
f"Some of the given contig names are not present in reference file: {unmapped}"
95+
)
96+
contigs = cast(list[str], _contigs)
97+
7298
_fasta.sequences = _fasta._get_sequences(contigs)
7399
if TYPE_CHECKING:
74100
assert _fasta.sequences is not None
@@ -95,7 +121,7 @@ def from_path_and_contigs(cls, fasta: Union[str, Path], contigs: List[str]):
95121
class _Variants:
96122
positions: NDArray[np.int32]
97123
sizes: NDArray[np.int32]
98-
alts: VLenAlleles
124+
alts: RaggedAlleles
99125

100126
@classmethod
101127
def from_table(cls, variants: Union[str, Path, pl.DataFrame]):
@@ -104,7 +130,7 @@ def from_table(cls, variants: Union[str, Path, pl.DataFrame]):
104130
return cls(
105131
variants["POS"].to_numpy(),
106132
variants["ILEN"].to_numpy(),
107-
VLenAlleles.from_polars(variants["ALT"]),
133+
RaggedAlleles.from_polars(variants["ALT"]),
108134
)
109135

110136

@@ -268,7 +294,7 @@ def from_path(
268294
variants = _Variants(
269295
svar_index["POS"].to_numpy() - 1,
270296
svar_index["ILEN"].to_numpy(),
271-
VLenAlleles.from_polars(svar_index["ALT"]),
297+
RaggedAlleles.from_polars(svar_index["ALT"]),
272298
)
273299
return cls(
274300
reference=reference,
@@ -547,7 +573,7 @@ def _get_haplotypes(
547573
geno_v_idxs=self.genotypes.data,
548574
positions=self.variants.positions,
549575
sizes=self.variants.sizes,
550-
alt_alleles=self.variants.alts.alleles.view(np.uint8),
576+
alt_alleles=self.variants.alts.data.view(np.uint8),
551577
alt_offsets=self.variants.alts.offsets,
552578
ref=self.reference.reference,
553579
ref_offsets=self.reference.offsets,

python/genvarloader/_dataset/_reference.py

Lines changed: 0 additions & 59 deletions
This file was deleted.

python/genvarloader/_dummy.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from ._dataset._utils import bed_to_regions
1515
from ._ragged import Ragged, RaggedIntervals
1616
from ._utils import _lengths_to_offsets
17-
from ._variants._records import VLenAlleles
17+
from ._variants._records import RaggedAlleles
1818

1919

2020
def get_dummy_dataset():
@@ -25,6 +25,7 @@ def get_dummy_dataset():
2525
max_jitter = 2
2626

2727
dummy_samples = ["Aang", "Katara", "Sokka", "Toph"]
28+
n_samples = len(dummy_samples)
2829

2930
dummy_contigs = [str(i) for i in range(1, 23)] + ["X", "Y", "MT"]
3031
dummy_bed = pl.DataFrame(
@@ -35,6 +36,7 @@ def get_dummy_dataset():
3536
"strand": ["+", "-", "+", "+"],
3637
}
3738
)
39+
n_regions = len(dummy_bed)
3840

3941
with pl.StringCache():
4042
pl.Series(natsorted(dummy_contigs), dtype=pl.Categorical())
@@ -61,12 +63,13 @@ def get_dummy_dataset():
6163
)
6264

6365
dummy_vars = _Variants(
64-
positions=repeat(dummy_regions[:, 1], "r -> (r s)", s=4),
65-
sizes=repeat(np.array([-2, -1, 0, 1], np.int32), "s -> (r s)", r=4),
66-
alts=VLenAlleles(
67-
alleles=repeat(sp.cast_seqs("ACGTT"), "a -> (r a)", r=4),
66+
positions=repeat(dummy_regions[:, 1], "r -> (r s)", s=n_samples),
67+
sizes=repeat(np.array([-2, -1, 0, 1], np.int32), "s -> (r s)", r=n_regions),
68+
alts=RaggedAlleles.from_offsets(
69+
data=repeat(sp.cast_seqs("ACGTT"), "a -> (r a)", r=n_regions),
70+
shape=n_regions*n_samples,
6871
offsets=_lengths_to_offsets(
69-
repeat(np.array([1, 1, 1, 2]), "s -> (r s)", r=4)
72+
repeat(np.array([1, 1, 1, 2]), "s -> (r s)", r=n_regions)
7073
),
7174
),
7275
)

python/genvarloader/_ragged.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from typing import Any, Optional, Tuple, TypeGuard, TypeVar, Union
44

5+
import awkward as ak
56
import numba as nb
67
import numpy as np
78
from attrs import define
@@ -54,6 +55,16 @@ def to_fixed_shape(self, shape: tuple[int, ...]) -> AnnotatedHaps:
5455
return AnnotatedHaps(haps, var_idxs, ref_coords)
5556

5657

58+
@define
59+
class RaggedVariants:
60+
"""Typically contains ragged arrays with shape (batch, ploidy, ~variants)"""
61+
62+
alts: ak.Array # (batch, ploidy, ~variants, ~length)
63+
pos: Ragged[np.int32]
64+
ilens: Ragged[np.int32]
65+
ccfs: Ragged[np.float32]
66+
67+
5768
def is_rag_dtype(rag: Ragged, dtype: type[DTYPE]) -> TypeGuard[Ragged[DTYPE]]:
5869
return np.issubdtype(rag.data.dtype, dtype)
5970

0 commit comments

Comments
 (0)