Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
## Changes in 1.1.1 (under development)

- Add support for configuring the number of concurrent workers used by
  `preload_data` via a new `max_workers` parameter, which limits the number of
  parallel preload tasks. Defaults to 4 workers.


## Changes in 1.1.0

- Support for RAR-compressed datasets.
Expand Down
10 changes: 3 additions & 7 deletions test/test_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import shutil
import unittest
from unittest.mock import MagicMock, patch

Expand Down Expand Up @@ -222,7 +221,6 @@ def test_open_data_compressed_format_not_preloaded(self):
def test_preload_data_tar_gz(self):
store = new_data_store(DATA_STORE_ID, root="6453099")
cache_store = store.preload_data(silent=True)
cache_store.preload_handle.close()

self.assertCountEqual(["diaz2016_inputs_raw.zarr"], cache_store.list_data_ids())
ds = cache_store.open_data("diaz2016_inputs_raw.zarr")
Expand All @@ -242,7 +240,7 @@ def test_preload_data_tar_gz(self):
},
ds.sizes,
)
shutil.rmtree(cache_store.root)
cache_store.preload_handle.close()

@pytest.mark.vcr()
def test_preload_data_zip(self):
Expand All @@ -251,7 +249,6 @@ def test_preload_data_zip(self):
"andorra.zip",
silent=True,
)
cache_store.preload_handle.close()

self.assertCountEqual(
[
Expand All @@ -273,7 +270,7 @@ def test_preload_data_zip(self):
self.assertIsInstance(ds, xr.Dataset)
self.assertCountEqual([f"band_{i}" for i in range(1, 40)], list(ds.data_vars))
self.assertEqual(ds["band_1"].shape, (971, 1149))
shutil.rmtree(cache_store.root)
cache_store.preload_handle.close()

@pytest.mark.vcr()
def test_preload_data_zip_preload_params(self):
Expand All @@ -298,7 +295,6 @@ def test_preload_data_zip_preload_params(self):
"The preload request is discarded."
)
self.assertEqual(msg, str(cm.output[-1]))
cache_store.preload_handle.close()

self.assertCountEqual(
[
Expand All @@ -322,7 +318,7 @@ def test_preload_data_zip_preload_params(self):
[f"band_{i}" for i in range(1, 40)] + ["spatial_ref"], list(ds.data_vars)
)
self.assertEqual(ds["band_1"].shape, (971, 1149))
shutil.rmtree(cache_store.root)
cache_store.preload_handle.close()

@pytest.mark.vcr()
def test_preload_data_download_fails(self):
Expand Down
11 changes: 11 additions & 0 deletions xcube_zenodo/preload.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import tarfile
import zipfile
from collections.abc import Sequence
from concurrent.futures.thread import ThreadPoolExecutor

import fsspec
import rarfile
Expand Down Expand Up @@ -66,12 +67,22 @@ def __init__(
if not self._process_fs.isdir(self._process_root):
self._process_fs.makedirs(self._process_root)

# limit number of workers in the preload process
if "executor" not in preload_params:
max_workers = preload_params.pop("max_workers", 4)
preload_params["executor"] = ThreadPoolExecutor(max_workers=max_workers)

# trigger preload in parent class
self._data_ids = {data_id.split("/")[-1]: data_id for data_id in data_ids}
super().__init__(data_ids=tuple(self._data_ids.keys()), **preload_params)

# delete temp storage
self._clean_up()

def close(self) -> None:
    """Release all storage held by this preload handle.

    First removes the temporary processing storage via ``_clean_up``
    (the same cleanup that runs at the end of ``__init__``), then
    deletes the cache root — including every preloaded dataset — if it
    still exists on the cache filesystem.

    NOTE(review): after ``close`` the cached data is gone, so any cache
    store backed by ``self._cache_root`` becomes unusable — callers
    (see the updated tests) must only close after reading the data.
    """
    # Drop temporary processing artifacts (idempotent re-run of the
    # cleanup already performed after preloading).
    self._clean_up()
    # Remove the entire cache directory recursively; guard with isdir
    # so a double close() does not raise on the missing path.
    if self._cache_fs.isdir(self._cache_root):
        self._cache_fs.rm(self._cache_root, recursive=True)

def preload_data(self, data_id: str, **preload_params):
format_ext = identify_compressed_file_format(data_id)
Expand Down
8 changes: 7 additions & 1 deletion xcube_zenodo/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,12 @@ def get_preload_data_params_schema(self) -> JsonObjectSchema:
),
default=True,
),
max_workers=JsonIntegerSchema(
title="Maximum number of concurrent workers.",
description="Limits the number of parallel preload tasks.",
minimum=1,
default=4,
),
target_format=JsonStringSchema(
title="Format of the preloaded dataset in the cache.",
description="If not given, native format is kept.",
Expand All @@ -219,7 +225,7 @@ def get_preload_data_params_schema(self) -> JsonObjectSchema:
return JsonObjectSchema(
properties=dict(**params),
required=[],
additional_properties=False,
additional_properties=True,
)

def search_data(self, data_type: DataTypeLike = None, **search_params):
Expand Down
2 changes: 1 addition & 1 deletion xcube_zenodo/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,4 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

version = "1.1.0"
version = "1.1.1.dev0"