Skip to content

Commit 24c81cf

Browse files
authored
feat(medcat): CU-869bhknfm Refactor setting of filters for embedding linker (#268)
* CU-869bhknfm: Refactor setting of filters for embedding linker * CU-869bfagqw: Fix a small typing issue * CU-869bhknfm: Remove unused protected method
1 parent 6291baa commit 24c81cf

File tree

1 file changed

+151
-42
lines changed

1 file changed

+151
-42
lines changed

medcat-v2/medcat/components/linking/embedding_linker.py

Lines changed: 151 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from collections import defaultdict
1212
import logging
1313
import math
14+
import numpy as np
1415

1516
from medcat.utils.import_utils import ensure_optional_extras_installed
1617
import medcat
@@ -85,8 +86,9 @@ def __init__(self, cdb: CDB, config: Config) -> None:
8586
]
8687
for name in self._name_keys
8788
]
89+
self._initialize_filter_structures()
8890

89-
def create_embeddings(self,
91+
def create_embeddings(self,
9092
embedding_model_name: Optional[str] = None,
9193
max_length: Optional[int] = None,
9294
):
@@ -281,6 +283,149 @@ def _get_context_vectors(
281283
texts.append(text)
282284
return self._embed(texts, self.device)
283285

286+
def _initialize_filter_structures(self) -> None:
287+
"""Call once during initialization to create efficient lookup structures."""
288+
# Build an inverted index: cui_idx -> list of name indices that contain it
289+
# This is the KEY optimization - we flip the lookup direction
290+
if not hasattr(self, '_cui_idx_to_name_idxs'):
291+
cui2name_indices: defaultdict[
292+
int, list[int]] = defaultdict(list)
293+
294+
for name_idx, cui_idxs in enumerate(self._name_to_cui_idxs):
295+
for cui_idx in cui_idxs:
296+
cui2name_indices[cui_idx].append(name_idx)
297+
298+
# Convert lists to numpy arrays for faster indexing
299+
self._cui_idx_to_name_idxs = {
300+
cui_idx: np.array(name_idxs, dtype=np.int32)
301+
for cui_idx, name_idxs in cui2name_indices.items()
302+
}
303+
304+
# Cache _has_cuis_all
305+
if not hasattr(self, '_has_cuis_all_cached'):
306+
self._has_cuis_all_cached = torch.tensor(
307+
[bool(self.cdb.name2info[name]["per_cui_status"])
308+
for name in self._name_keys],
309+
device=self.device,
310+
dtype=torch.bool,
311+
)
312+
313+
def _get_include_filters_1cui(
314+
self, cui: str, n: int) -> torch.Tensor:
315+
"""Optimized single CUI include filter using inverted index."""
316+
if cui not in self._cui_to_idx:
317+
return torch.zeros(n, dtype=torch.bool, device=self.device)
318+
319+
cui_idx = self._cui_to_idx[cui]
320+
321+
# Use inverted index: get all name indices that contain this CUI
322+
if cui_idx in self._cui_idx_to_name_idxs:
323+
name_indices = self._cui_idx_to_name_idxs[cui_idx]
324+
325+
# Create mask by setting specific indices to True
326+
allowed_mask = torch.zeros(n, dtype=torch.bool, device=self.device)
327+
allowed_mask[torch.from_numpy(name_indices).to(self.device)] = True
328+
return allowed_mask
329+
else:
330+
return torch.zeros(n, dtype=torch.bool, device=self.device)
331+
332+
def _get_include_filters_multi_cui(
333+
self, include_set: Set[str], n: int) -> torch.Tensor:
334+
"""Optimized multi-CUI include filter using inverted index."""
335+
include_cui_idxs = [
336+
self._cui_to_idx[cui] for cui in include_set
337+
if cui in self._cui_to_idx
338+
]
339+
340+
if not include_cui_idxs:
341+
return torch.zeros(n, dtype=torch.bool, device=self.device)
342+
343+
# Collect all name indices from inverted index
344+
all_name_indices_list: list[np.ndarray] = []
345+
for cui_idx in include_cui_idxs:
346+
if cui_idx in self._cui_idx_to_name_idxs:
347+
all_name_indices_list.append(
348+
self._cui_idx_to_name_idxs[cui_idx])
349+
350+
if not all_name_indices_list:
351+
return torch.zeros(n, dtype=torch.bool, device=self.device)
352+
353+
# Concatenate and get unique indices
354+
all_name_indices = np.unique(
355+
np.concatenate(all_name_indices_list))
356+
357+
# Create mask
358+
allowed_mask = torch.zeros(n, dtype=torch.bool, device=self.device)
359+
allowed_mask[torch.from_numpy(all_name_indices).to(self.device)] = True
360+
return allowed_mask
361+
362+
def _get_include_filters(
        self, include_set: Set[str], n: int) -> torch.Tensor:
    """Build the include mask, choosing the single- or multi-CUI path.

    Args:
        include_set: CUIs whose names should be allowed.
        n: Length of the mask to build.

    Returns:
        The boolean mask produced by ``_get_include_filters_1cui`` when
        ``include_set`` holds exactly one CUI, otherwise the one produced
        by ``_get_include_filters_multi_cui``.
    """
    if len(include_set) != 1:
        return self._get_include_filters_multi_cui(include_set, n)

    # Exactly one element: unpack it without building an intermediate list.
    (only_cui,) = include_set
    return self._get_include_filters_1cui(only_cui, n)
371+
372+
def _get_exclude_filters_1cui(
373+
self, allowed_mask: torch.Tensor, cui: str) -> torch.Tensor:
374+
"""Optimized single CUI exclude filter using inverted index."""
375+
if cui not in self._cui_to_idx:
376+
return allowed_mask
377+
378+
cui_idx = self._cui_to_idx[cui]
379+
380+
if cui_idx in self._cui_idx_to_name_idxs:
381+
name_indices = self._cui_idx_to_name_idxs[cui_idx]
382+
# Set specific indices to False
383+
allowed_mask[
384+
torch.from_numpy(name_indices).to(self.device)] = False
385+
386+
return allowed_mask
387+
388+
def _get_exclude_filters_multi_cui(
389+
self, allowed_mask: torch.Tensor, exclude_set: Set[str],
390+
) -> torch.Tensor:
391+
"""Optimized multi-CUI exclude filter using inverted index."""
392+
exclude_cui_idxs = [
393+
self._cui_to_idx[cui] for cui in exclude_set
394+
if cui in self._cui_to_idx
395+
]
396+
397+
if not exclude_cui_idxs:
398+
return allowed_mask
399+
400+
# Collect all name indices to exclude
401+
_all_name_indices: list[np.ndarray] = []
402+
for cui_idx in exclude_cui_idxs:
403+
if cui_idx in self._cui_idx_to_name_idxs:
404+
_all_name_indices.append(self._cui_idx_to_name_idxs[cui_idx])
405+
406+
if _all_name_indices:
407+
all_name_indices = np.unique(np.concatenate(_all_name_indices))
408+
allowed_mask[torch.from_numpy(all_name_indices).to(self.device)] = False
409+
410+
return allowed_mask
411+
412+
def _get_exclude_filters(
413+
self, exclude_set: Set[str], n: int) -> torch.Tensor:
414+
"""Route to appropriate exclude filter method."""
415+
# Start with all allowed
416+
allowed_mask = torch.ones(n, dtype=torch.bool, device=self.device)
417+
418+
if not exclude_set:
419+
return allowed_mask
420+
421+
if len(exclude_set) == 1:
422+
cui = next(iter(exclude_set))
423+
return self._get_exclude_filters_1cui(
424+
allowed_mask, cui)
425+
else:
426+
return self._get_exclude_filters_multi_cui(
427+
allowed_mask, exclude_set)
428+
284429
def _set_filters(self) -> None:
285430
include_set = self.cnf_l.filters.cuis
286431
exclude_set = self.cnf_l.filters.cuis_exclude
@@ -295,54 +440,18 @@ def _set_filters(self) -> None:
295440
return
296441

297442
n = len(self._name_keys)
298-
allowed_mask = torch.empty(n, dtype=torch.bool, device=self.device)
299443

300444
if include_set:
301-
# if in include set, ignore exclude set.
302-
allowed_mask[:] = False
303-
include_cui_idxs = {
304-
self._cui_to_idx[cui] for cui in include_set if cui in self._cui_to_idx
305-
}
306-
include_idxs = [
307-
name_idx
308-
for name_idx, name_cui_idxs in enumerate(self._name_to_cui_idxs)
309-
if any(cui in include_cui_idxs for cui in name_cui_idxs)
310-
]
311-
allowed_mask[
312-
torch.tensor(include_idxs, dtype=torch.long, device=self.device)
313-
] = True
445+
allowed_mask = self._get_include_filters(
446+
include_set, n)
314447
else:
315-
# only look at exclude if there's no include set
316-
allowed_mask[:] = True
317-
if exclude_set:
318-
exclude_cui_idxs = {
319-
self._cui_to_idx[cui]
320-
for cui in exclude_set
321-
if cui in self._cui_to_idx
322-
}
323-
exclude_idxs = [
324-
i
325-
for i, name_cui_idxs in enumerate(self._name_to_cui_idxs)
326-
if any(ci in exclude_cui_idxs for ci in name_cui_idxs)
327-
]
328-
allowed_mask[
329-
torch.tensor(exclude_idxs, dtype=torch.long, device=self.device)
330-
] = False
448+
allowed_mask = self._get_exclude_filters(
449+
exclude_set, n)
331450

332-
# checking if a name has at least 1 cui related to it.
333-
_has_cuis_all = torch.tensor(
334-
[
335-
bool(self.cdb.name2info[name]["per_cui_status"])
336-
for name in self._name_keys
337-
],
338-
device=self.device,
339-
dtype=torch.bool,
340-
)
341-
self._valid_names = _has_cuis_all & allowed_mask
451+
self._valid_names = self._has_cuis_all_cached & allowed_mask
342452
self._last_include_set = set(include_set) if include_set is not None else None
343453
self._last_exclude_set = set(exclude_set) if exclude_set is not None else None
344454

345-
346455
def _disambiguate_by_cui(
347456
self, cui_candidates: list[str], scores: Tensor
348457
) -> tuple[str, float]:

0 commit comments

Comments
 (0)