Skip to content

Commit ed26545

Browse files
Added logic to support recursive searching of upstream visits (#725)
* Added logic to support recursive searching of upstream visits * Add logic to recursive search function to create dictionary, update it recursively, and return it instead of using a variable from outside the function scope * Use 'None' as a default for the 'result' parameter instead * Safer registration of data collections (#726) Ensure that dcg, dc and pjid exist before inserting into murfey db. Add a sleep for the case where they cannot be registered to allow the database to settle. --------- Co-authored-by: Stephen Riggs <122790971+stephen-riggs@users.noreply.github.com>
1 parent 80a68b5 commit ed26545

File tree

2 files changed

+52
-9
lines changed

2 files changed

+52
-9
lines changed

src/murfey/server/api/session_shared.py

Lines changed: 46 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import logging
2+
import os
23
from pathlib import Path
34
from typing import Dict, List
45

@@ -136,11 +137,48 @@ def get_foil_hole(session_id: int, fh_name: int, db) -> Dict[str, int]:
136137
return {f[1].tag: f[0].id for f in foil_holes}
137138

138139

139-
def find_upstream_visits(session_id: int, db: SQLModelSession):
140+
def find_upstream_visits(session_id: int, db: SQLModelSession, max_depth: int = 2):
140141
"""
141142
Returns a nested dictionary, in which visits and the full paths to their directories
142143
are further grouped by instrument name.
143144
"""
145+
146+
def _recursive_search(
147+
dirpath: str | Path,
148+
search_string: str,
149+
partial_match: bool = True,
150+
max_depth: int = 1,
151+
result: dict[str, Path] | None = None,
152+
):
153+
# If no dictionary was passed in, create a new dictionary
154+
if result is None:
155+
result = {}
156+
# Stop recursing for this route once max depth hits 0
157+
if max_depth == 0:
158+
return result
159+
160+
# Walk through the directories
161+
for entry in os.scandir(dirpath):
162+
if entry.is_dir():
163+
# Update dictionary with match and stop recursing for this route
164+
if (
165+
search_string in entry.name
166+
if partial_match
167+
else search_string == entry.name
168+
):
169+
if result is not None: # MyPy needs this 'is not None' check
170+
result[entry.name] = Path(entry.path)
171+
else:
172+
# Continue searching down this route until max depth is reached
173+
result = _recursive_search(
174+
dirpath=entry.path,
175+
search_string=search_string,
176+
partial_match=partial_match,
177+
max_depth=max_depth - 1,
178+
result=result,
179+
)
180+
return result
181+
144182
murfey_session = db.exec(
145183
select(MurfeySession).where(MurfeySession.id == session_id)
146184
).one()
@@ -155,12 +193,13 @@ def find_upstream_visits(session_id: int, db: SQLModelSession):
155193
upstream_instrument,
156194
upstream_data_dir,
157195
) in machine_config.upstream_data_directories.items():
158-
# Looks for visit name in file path
159-
current_upstream_visits = {}
160-
for visit_path in Path(upstream_data_dir).glob(f"{visit_name.split('-')[0]}-*"):
161-
if visit_path.is_dir():
162-
current_upstream_visits[visit_path.name] = visit_path
163-
upstream_visits[upstream_instrument] = current_upstream_visits
196+
# Recursively look for matching visit names under current directory
197+
upstream_visits[upstream_instrument] = _recursive_search(
198+
dirpath=upstream_data_dir,
199+
search_string=f"{visit_name.split('-')[0]}-",
200+
partial_match=True,
201+
max_depth=max_depth,
202+
)
164203
return upstream_visits
165204

166205

tests/server/api/test_session_shared.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,11 @@
99
from tests.conftest import ExampleVisit
1010

1111

12+
@pytest.mark.parametrize("recurse", (True, False))
1213
def test_find_upstream_visits(
1314
mocker: MockerFixture,
1415
tmp_path: Path,
15-
# murfey_db_session,
16+
recurse: bool,
1617
):
1718
# Get the visit, instrument name, and session ID
1819
visit_name_root = f"{ExampleVisit.proposal_code}{ExampleVisit.proposal_number}"
@@ -40,7 +41,10 @@ def test_find_upstream_visits(
4041
# Only directories should be picked up
4142
upstream_visit.mkdir(parents=True, exist_ok=True)
4243
upstream_visits[upstream_instrument] = {upstream_visit.stem: upstream_visit}
43-
upstream_data_dirs[upstream_instrument] = upstream_visit.parent
44+
# Check that the function can cope with recursive searching
45+
upstream_data_dirs[upstream_instrument] = (
46+
upstream_visit.parent.parent if recurse else upstream_visit.parent
47+
)
4448
else:
4549
upstream_visit.parent.mkdir(parents=True, exist_ok=True)
4650
upstream_visit.touch(exist_ok=True)

0 commit comments

Comments
 (0)