1111from collections import defaultdict
1212import logging
1313import math
14+ import numpy as np
1415
1516from medcat .utils .import_utils import ensure_optional_extras_installed
1617import medcat
@@ -85,8 +86,9 @@ def __init__(self, cdb: CDB, config: Config) -> None:
8586 ]
8687 for name in self ._name_keys
8788 ]
89+ self ._initialize_filter_structures ()
8890
89- def create_embeddings (self ,
91+ def create_embeddings (self ,
9092 embedding_model_name : Optional [str ] = None ,
9193 max_length : Optional [int ] = None ,
9294 ):
@@ -281,6 +283,149 @@ def _get_context_vectors(
281283 texts .append (text )
282284 return self ._embed (texts , self .device )
283285
286+ def _initialize_filter_structures (self ) -> None :
287+ """Call once during initialization to create efficient lookup structures."""
288+ # Build an inverted index: cui_idx -> list of name indices that contain it
289+ # This is the KEY optimization - we flip the lookup direction
290+ if not hasattr (self , '_cui_idx_to_name_idxs' ):
291+ cui2name_indices : defaultdict [
292+ int , list [int ]] = defaultdict (list )
293+
294+ for name_idx , cui_idxs in enumerate (self ._name_to_cui_idxs ):
295+ for cui_idx in cui_idxs :
296+ cui2name_indices [cui_idx ].append (name_idx )
297+
298+ # Convert lists to numpy arrays for faster indexing
299+ self ._cui_idx_to_name_idxs = {
300+ cui_idx : np .array (name_idxs , dtype = np .int32 )
301+ for cui_idx , name_idxs in cui2name_indices .items ()
302+ }
303+
304+ # Cache _has_cuis_all
305+ if not hasattr (self , '_has_cuis_all_cached' ):
306+ self ._has_cuis_all_cached = torch .tensor (
307+ [bool (self .cdb .name2info [name ]["per_cui_status" ])
308+ for name in self ._name_keys ],
309+ device = self .device ,
310+ dtype = torch .bool ,
311+ )
312+
313+ def _get_include_filters_1cui (
314+ self , cui : str , n : int ) -> torch .Tensor :
315+ """Optimized single CUI include filter using inverted index."""
316+ if cui not in self ._cui_to_idx :
317+ return torch .zeros (n , dtype = torch .bool , device = self .device )
318+
319+ cui_idx = self ._cui_to_idx [cui ]
320+
321+ # Use inverted index: get all name indices that contain this CUI
322+ if cui_idx in self ._cui_idx_to_name_idxs :
323+ name_indices = self ._cui_idx_to_name_idxs [cui_idx ]
324+
325+ # Create mask by setting specific indices to True
326+ allowed_mask = torch .zeros (n , dtype = torch .bool , device = self .device )
327+ allowed_mask [torch .from_numpy (name_indices ).to (self .device )] = True
328+ return allowed_mask
329+ else :
330+ return torch .zeros (n , dtype = torch .bool , device = self .device )
331+
332+ def _get_include_filters_multi_cui (
333+ self , include_set : Set [str ], n : int ) -> torch .Tensor :
334+ """Optimized multi-CUI include filter using inverted index."""
335+ include_cui_idxs = [
336+ self ._cui_to_idx [cui ] for cui in include_set
337+ if cui in self ._cui_to_idx
338+ ]
339+
340+ if not include_cui_idxs :
341+ return torch .zeros (n , dtype = torch .bool , device = self .device )
342+
343+ # Collect all name indices from inverted index
344+ all_name_indices_list : list [np .ndarray ] = []
345+ for cui_idx in include_cui_idxs :
346+ if cui_idx in self ._cui_idx_to_name_idxs :
347+ all_name_indices_list .append (
348+ self ._cui_idx_to_name_idxs [cui_idx ])
349+
350+ if not all_name_indices_list :
351+ return torch .zeros (n , dtype = torch .bool , device = self .device )
352+
353+ # Concatenate and get unique indices
354+ all_name_indices = np .unique (
355+ np .concatenate (all_name_indices_list ))
356+
357+ # Create mask
358+ allowed_mask = torch .zeros (n , dtype = torch .bool , device = self .device )
359+ allowed_mask [torch .from_numpy (all_name_indices ).to (self .device )] = True
360+ return allowed_mask
361+
362+ def _get_include_filters (
363+ self , include_set : Set [str ], n : int ) -> torch .Tensor :
364+ """Route to appropriate include filter method."""
365+ if len (include_set ) == 1 :
366+ cui = next (iter (include_set ))
367+ return self ._get_include_filters_1cui (cui , n )
368+ else :
369+ return self ._get_include_filters_multi_cui (
370+ include_set , n )
371+
372+ def _get_exclude_filters_1cui (
373+ self , allowed_mask : torch .Tensor , cui : str ) -> torch .Tensor :
374+ """Optimized single CUI exclude filter using inverted index."""
375+ if cui not in self ._cui_to_idx :
376+ return allowed_mask
377+
378+ cui_idx = self ._cui_to_idx [cui ]
379+
380+ if cui_idx in self ._cui_idx_to_name_idxs :
381+ name_indices = self ._cui_idx_to_name_idxs [cui_idx ]
382+ # Set specific indices to False
383+ allowed_mask [
384+ torch .from_numpy (name_indices ).to (self .device )] = False
385+
386+ return allowed_mask
387+
388+ def _get_exclude_filters_multi_cui (
389+ self , allowed_mask : torch .Tensor , exclude_set : Set [str ],
390+ ) -> torch .Tensor :
391+ """Optimized multi-CUI exclude filter using inverted index."""
392+ exclude_cui_idxs = [
393+ self ._cui_to_idx [cui ] for cui in exclude_set
394+ if cui in self ._cui_to_idx
395+ ]
396+
397+ if not exclude_cui_idxs :
398+ return allowed_mask
399+
400+ # Collect all name indices to exclude
401+ _all_name_indices : list [np .ndarray ] = []
402+ for cui_idx in exclude_cui_idxs :
403+ if cui_idx in self ._cui_idx_to_name_idxs :
404+ _all_name_indices .append (self ._cui_idx_to_name_idxs [cui_idx ])
405+
406+ if _all_name_indices :
407+ all_name_indices = np .unique (np .concatenate (_all_name_indices ))
408+ allowed_mask [torch .from_numpy (all_name_indices ).to (self .device )] = False
409+
410+ return allowed_mask
411+
412+ def _get_exclude_filters (
413+ self , exclude_set : Set [str ], n : int ) -> torch .Tensor :
414+ """Route to appropriate exclude filter method."""
415+ # Start with all allowed
416+ allowed_mask = torch .ones (n , dtype = torch .bool , device = self .device )
417+
418+ if not exclude_set :
419+ return allowed_mask
420+
421+ if len (exclude_set ) == 1 :
422+ cui = next (iter (exclude_set ))
423+ return self ._get_exclude_filters_1cui (
424+ allowed_mask , cui )
425+ else :
426+ return self ._get_exclude_filters_multi_cui (
427+ allowed_mask , exclude_set )
428+
284429 def _set_filters (self ) -> None :
285430 include_set = self .cnf_l .filters .cuis
286431 exclude_set = self .cnf_l .filters .cuis_exclude
@@ -295,54 +440,18 @@ def _set_filters(self) -> None:
295440 return
296441
297442 n = len (self ._name_keys )
298- allowed_mask = torch .empty (n , dtype = torch .bool , device = self .device )
299443
300444 if include_set :
301- # if in include set, ignore exclude set.
302- allowed_mask [:] = False
303- include_cui_idxs = {
304- self ._cui_to_idx [cui ] for cui in include_set if cui in self ._cui_to_idx
305- }
306- include_idxs = [
307- name_idx
308- for name_idx , name_cui_idxs in enumerate (self ._name_to_cui_idxs )
309- if any (cui in include_cui_idxs for cui in name_cui_idxs )
310- ]
311- allowed_mask [
312- torch .tensor (include_idxs , dtype = torch .long , device = self .device )
313- ] = True
445+ allowed_mask = self ._get_include_filters (
446+ include_set , n )
314447 else :
315- # only look at exclude if there's no include set
316- allowed_mask [:] = True
317- if exclude_set :
318- exclude_cui_idxs = {
319- self ._cui_to_idx [cui ]
320- for cui in exclude_set
321- if cui in self ._cui_to_idx
322- }
323- exclude_idxs = [
324- i
325- for i , name_cui_idxs in enumerate (self ._name_to_cui_idxs )
326- if any (ci in exclude_cui_idxs for ci in name_cui_idxs )
327- ]
328- allowed_mask [
329- torch .tensor (exclude_idxs , dtype = torch .long , device = self .device )
330- ] = False
448+ allowed_mask = self ._get_exclude_filters (
449+ exclude_set , n )
331450
332- # checking if a name has at least 1 cui related to it.
333- _has_cuis_all = torch .tensor (
334- [
335- bool (self .cdb .name2info [name ]["per_cui_status" ])
336- for name in self ._name_keys
337- ],
338- device = self .device ,
339- dtype = torch .bool ,
340- )
341- self ._valid_names = _has_cuis_all & allowed_mask
451+ self ._valid_names = self ._has_cuis_all_cached & allowed_mask
342452 self ._last_include_set = set (include_set ) if include_set is not None else None
343453 self ._last_exclude_set = set (exclude_set ) if exclude_set is not None else None
344454
345-
346455 def _disambiguate_by_cui (
347456 self , cui_candidates : list [str ], scores : Tensor
348457 ) -> tuple [str , float ]:
0 commit comments