66class EHRFoundationalModelMIMIC4 (BaseTask ):
77
88 task_name : str = "EHRFoundationalModelMIMIC4"
9+ TOKEN_REPRESENTING_MISSING_TEXT = "<missing>"
10+ TOKEN_REPRESENTING_MISSING_FLOAT = float ("nan" )
911
1012 def __init__ (self ):
1113 """Initialize the EHR Foundational Model task."""
@@ -23,47 +25,14 @@ def __init__(self):
2325 "tokenizer_name" : "bert-base-uncased" ,
2426 "type_tag" : "note" ,
2527 },
26- ),
27- "icd_codes" : (
28- "stagenet" ,
29- {"padding" : 0
30- }
31- ),
28+ )
3229 }
3330 self .output_schema : Dict [str , str ] = {"mortality" : "binary" }
3431
3532 def _clean_text (self , text : Optional [str ]) -> Optional [str ]:
3633 """Return text if non-empty, otherwise None."""
3734 return text if text else None
3835
39- def _compute_time_diffs (self , notes_with_timestamps , first_admission_time ):
40- """Compute hourly time offsets for notes relative to first admission.
41-
42- Sorts notes chronologically by timestamp, then computes each note's
43- offset (in hours) from the first admission time.
44-
45- Args:
46- notes_with_timestamps: List of (text, timestamp) tuples where
47- text is the clinical note string and timestamp is a datetime.
48- first_admission_time: datetime of the patient's first admission,
49- used as the anchor (t=0) for all time offsets.
50-
51- Returns:
52- Tuple of (texts, time_diffs) where:
53- - texts: List[str] of note contents, sorted chronologically
54- - time_diffs: List[float] of hours since first admission
55- Returns (["<missing>"], [0.0]) if no notes are available.
56- """
57- result = []
58-
59- if not notes_with_timestamps :
60- return (["<missing>" ], [0.0 ]) # TODO: Need to also figure out how to tokenize missing timestamps
61- notes_with_timestamps .sort (key = lambda x : x [1 ])
62- result = [(text , (ts - first_admission_time ).total_seconds () / 3600 ) for text , ts in notes_with_timestamps ]
63- texts , time_diffs = zip (* result )
64-
65- return (list (texts ), list (time_diffs ))
66-
6736 def __call__ (self , patient : Any ) -> List [Dict [str , Any ]]:
6837 # Get demographic info to filter by age
6938 demographics = patient .get_events (event_type = "patients" )
@@ -105,102 +74,62 @@ def __call__(self, patient: Any) -> List[Dict[str, Any]]:
10574 if len (admissions_to_process ) == 0 :
10675 return []
10776
108- # Get first admission time as reference for notes time offset
109- first_admission_time = admissions_to_process [0 ].timestamp
77+ # Aggregated notes and time offsets across all admissions (per hadm_id)
78+ all_discharge_texts : List [str ] = []
79+ all_discharge_times_from_admission : List [float ] = []
80+ all_radiology_texts : List [str ] = []
81+ all_radiology_times_from_admission : List [float ] = []
11082
111- # Aggregated data across all admissions
112- all_discharge_notes_timestamped = [] # List of (note_text, timestamp) tuples
113- all_radiology_notes_timestamped = [] # List of (note_text, timestamp) tuples
114- all_icd_codes = [] # ICD code lists per visit
115- all_icd_times = [] # Hours from first admission per visit
116-
117- # Process each admission and aggregate data
83+ # Process each admission independently (per hadm_id)
11884 for admission in admissions_to_process :
119- # Parse admission discharge time for lab events filtering
120- try :
121- admission_dischtime = datetime .strptime (
122- admission .dischtime , "%Y-%m-%d %H:%M:%S"
123- )
124- except (ValueError , AttributeError ):
125- # If we can't parse discharge time, skip this admission
126- continue
127-
128- # Skip if discharge is before admission (data quality issue)
129- if admission_dischtime < admission .timestamp :
130- continue
131-
132- # Get notes using hadm_id filtering
85+ admission_time = admission .timestamp
86+
87+ # Get notes for this hadm_id only
13388 discharge_notes = patient .get_events (
13489 event_type = "discharge" , filters = [("hadm_id" , "==" , admission .hadm_id )]
13590 )
13691 radiology_notes = patient .get_events (
13792 event_type = "radiology" , filters = [("hadm_id" , "==" , admission .hadm_id )]
13893 )
13994
140- # Extract and aggregate notes as individual items in lists
141- # Note: attribute is "text" (from mimic4_note.yaml), not "discharge"/"radiology"
142- for note in discharge_notes :
95+ for note in discharge_notes : #TODO: Maybe make this into a helper function?
14396 try :
14497 note_text = self ._clean_text (note .text )
14598 if note_text :
146- all_discharge_notes_timestamped .append ((note_text , note .timestamp ))
147- except AttributeError :
99+ time_from_admission = (
100+ note .timestamp - admission_time
101+ ).total_seconds () / 3600.0
102+ all_discharge_texts .append (note_text )
103+ all_discharge_times_from_admission .append (time_from_admission )
104+ except AttributeError : # note object is missing .text or .timestamp attribute (e.g. malformed note)
148105 pass
106+ if not discharge_notes : # If we get an empty list
107+ all_discharge_texts .append (self .TOKEN_REPRESENTING_MISSING_TEXT ) # Token representing missing text
108+ all_discharge_times_from_admission .append (self .TOKEN_REPRESENTING_MISSING_FLOAT ) # Token representing missing time(?)
149109
150- for note in radiology_notes :
110+ for note in radiology_notes : #TODO: Maybe make this into a helper function?
151111 try :
152112 note_text = self ._clean_text (note .text )
153113 if note_text :
154- all_radiology_notes_timestamped .append ((note_text , note .timestamp ))
155- except AttributeError :
114+ time_from_admission = (
115+ note .timestamp - admission_time
116+ ).total_seconds () / 3600.0
117+ all_radiology_texts .append (note_text )
118+ all_radiology_times_from_admission .append (time_from_admission )
119+ except AttributeError : # note object is missing .text or .timestamp attribute (e.g. malformed note)
156120 pass
121+ if not radiology_notes : # If we receive empty list
122+ all_radiology_texts .append (self .TOKEN_REPRESENTING_MISSING_TEXT ) # Token representing missing text
123+ all_radiology_times_from_admission .append (self .TOKEN_REPRESENTING_MISSING_FLOAT ) # Token representing missing time(?)
157124
158- # Get diagnosis codes for this admission using hadm_id
159- diagnoses_icd = patient .get_events (
160- event_type = "diagnoses_icd" ,
161- filters = [("hadm_id" , "==" , admission .hadm_id )],
162- )
163- visit_diagnoses = [
164- event .icd_code
165- for event in diagnoses_icd
166- if hasattr (event , "icd_code" ) and event .icd_code
167- ]
168-
169- # Get procedure codes for this admission using hadm_id
170- procedures_icd = patient .get_events (
171- event_type = "procedures_icd" ,
172- filters = [("hadm_id" , "==" , admission .hadm_id )],
173- )
174- visit_procedures = [
175- event .icd_code
176- for event in procedures_icd
177- if hasattr (event , "icd_code" ) and event .icd_code
178- ]
179-
180- # Combine diagnoses and procedures into single ICD code list
181- visit_icd_codes = visit_diagnoses + visit_procedures
182-
183- # Calculate time from admission start (hours)
184- if visit_icd_codes :
185- time_from_first = (
186- admission .timestamp - first_admission_time
187- ).total_seconds () / 3600.0
188- all_icd_codes .append (visit_icd_codes )
189- all_icd_times .append (time_from_first )
190-
191- # Convert (note_text, timestamp) tuples to (note_text, time_diff_hours) tuples
192- discharge_note_times = self ._compute_time_diffs (all_discharge_notes_timestamped , first_admission_time )
193- radiology_note_times = self ._compute_time_diffs (all_radiology_notes_timestamped , first_admission_time )
194-
195- # icd_codes: (List[List[str]], List[float]) — codes per visit, hours from first admission
196- icd_codes = (all_icd_codes , all_icd_times )
125+ discharge_note_times_from_admission = (all_discharge_texts , all_discharge_times_from_admission )
126+ radiology_note_times_from_admission = (all_radiology_texts , all_radiology_times_from_admission )
197127
198128 return [
199129 {
200130 "patient_id" : patient .patient_id ,
201- "discharge_note_times" : discharge_note_times ,
202- "radiology_note_times" : radiology_note_times ,
203- "icd_codes" : icd_codes ,
131+ "discharge_note_times" : discharge_note_times_from_admission ,
132+ "radiology_note_times" : radiology_note_times_from_admission ,
204133 "mortality" : mortality_label ,
205134 }
206135 ]
0 commit comments