11from datetime import datetime
2- from typing import Any , Dict , List , Optional , Union , Tuple
2+ from typing import Any , ClassVar , Dict , List , Optional , Tuple , Union
3+ import polars as pl
34
45from pyhealth .tasks .base_task import BaseTask
56
67class EHRFoundationalModelMIMIC4 (BaseTask ):
7-
8+
89 task_name : str = "EHRFoundationalModelMIMIC4"
910 TOKEN_REPRESENTING_MISSING_TEXT = "<missing>"
1011 TOKEN_REPRESENTING_MISSING_FLOAT = float ("nan" )
11-
12+ PADDING : int = 0
13+
14+ LAB_CATEGORIES : ClassVar [Dict [str , List [str ]]] = {
15+ "Sodium" : ["50824" , "52455" , "50983" , "52623" ],
16+ "Potassium" : ["50822" , "52452" , "50971" , "52610" ],
17+ "Chloride" : ["50806" , "52434" , "50902" , "52535" ],
18+ "Bicarbonate" : ["50803" , "50804" ],
19+ "Glucose" : ["50809" , "52027" , "50931" , "52569" ],
20+ "Calcium" : ["50808" , "51624" ],
21+ "Magnesium" : ["50960" ],
22+ "Anion Gap" : ["50868" , "52500" ],
23+ "Osmolality" : ["52031" , "50964" , "51701" ],
24+ "Phosphate" : ["50970" ],
25+ }
26+
27+ LAB_CATEGORY_NAMES : ClassVar [List [str ]] = [
28+ "Sodium" , "Potassium" , "Chloride" , "Bicarbonate" , "Glucose" ,
29+ "Calcium" , "Magnesium" , "Anion Gap" , "Osmolality" , "Phosphate" ,
30+ ]
31+
32+ LABITEMS : ClassVar [List [str ]] = [
33+ item for itemids in LAB_CATEGORIES .values () for item in itemids
34+ ]
35+
1236 def __init__ (self ):
1337 """Initialize the EHR Foundational Model task."""
1438 self .input_schema : Dict [str , Union [str , Tuple [str , Dict ]]] = {
@@ -25,7 +49,9 @@ def __init__(self):
2549 "tokenizer_name" : "bert-base-uncased" ,
2650 "type_tag" : "note" ,
2751 },
28- )
52+ ),
53+ "icd_codes" : ("stagenet" , {"padding" : self .PADDING }),
54+ "labs" : ("stagenet_tensor" , {}),
2955 }
3056 self .output_schema : Dict [str , str ] = {"mortality" : "binary" }
3157
@@ -79,11 +105,26 @@ def __call__(self, patient: Any) -> List[Dict[str, Any]]:
79105 all_discharge_times_from_admission : List [float ] = []
80106 all_radiology_texts : List [str ] = []
81107 all_radiology_times_from_admission : List [float ] = []
108+ all_icd_codes : List [List [str ]] = []
109+ all_icd_times : List [float ] = []
110+ all_lab_values : List [List [Any ]] = []
111+ all_lab_times : List [float ] = []
112+ previous_admission_time = None
82113
83114 # Process each admission independently (per hadm_id)
84115 for admission in admissions_to_process :
85116 admission_time = admission .timestamp
86117
118+ try :
119+ admission_dischtime = datetime .strptime (
120+ admission .dischtime , "%Y-%m-%d %H:%M:%S"
121+ )
122+ except (ValueError , AttributeError ):
123+ continue
124+
125+ if admission_dischtime < admission_time :
126+ continue
127+
87128 # Get notes for this hadm_id only
88129 discharge_notes = patient .get_events (
89130 event_type = "discharge" , filters = [("hadm_id" , "==" , admission .hadm_id )]
@@ -122,6 +163,64 @@ def __call__(self, patient: Any) -> List[Dict[str, Any]]:
122163 all_radiology_texts .append (self .TOKEN_REPRESENTING_MISSING_TEXT ) # Token representing missing text
123164 all_radiology_times_from_admission .append (self .TOKEN_REPRESENTING_MISSING_FLOAT ) # Token representing missing time(?)
124165
166+ # ICD codes (diagnoses + procedures) with time relative to previous admission
167+ diagnoses_icd = patient .get_events (
168+ event_type = "diagnoses_icd" , filters = [("hadm_id" , "==" , admission .hadm_id )]
169+ )
170+ procedures_icd = patient .get_events (
171+ event_type = "procedures_icd" , filters = [("hadm_id" , "==" , admission .hadm_id )]
172+ )
173+ visit_icd_codes = (
174+ [e .icd_code for e in diagnoses_icd if hasattr (e , "icd_code" ) and e .icd_code ] +
175+ [e .icd_code for e in procedures_icd if hasattr (e , "icd_code" ) and e .icd_code ]
176+ )
177+ if visit_icd_codes :
178+ if previous_admission_time is None :
179+ time_from_previous = 0.0
180+ else :
181+ time_from_previous = (admission_time - previous_admission_time ).total_seconds () / 3600.0
182+ all_icd_codes .append (visit_icd_codes )
183+ all_icd_times .append (time_from_previous )
184+
185+ previous_admission_time = admission_time
186+
187+ # Lab events with time relative to this admission's start
188+ labevents_df = patient .get_events (
189+ event_type = "labevents" ,
190+ start = admission_time ,
191+ end = admission_dischtime ,
192+ return_df = True ,
193+ )
194+ labevents_df = labevents_df .filter (
195+ pl .col ("labevents/itemid" ).is_in (self .LABITEMS )
196+ )
197+ if labevents_df .height > 0 :
198+ labevents_df = labevents_df .with_columns (
199+ pl .col ("labevents/storetime" ).str .strptime (pl .Datetime , "%Y-%m-%d %H:%M:%S" )
200+ )
201+ labevents_df = labevents_df .filter (
202+ pl .col ("labevents/storetime" ) <= admission_dischtime
203+ )
204+ if labevents_df .height > 0 :
205+ labevents_df = labevents_df .select (
206+ pl .col ("timestamp" ),
207+ pl .col ("labevents/itemid" ),
208+ pl .col ("labevents/valuenum" ).cast (pl .Float64 ),
209+ )
210+ for lab_ts in sorted (labevents_df ["timestamp" ].unique ().to_list ()):
211+ ts_labs = labevents_df .filter (pl .col ("timestamp" ) == lab_ts )
212+ lab_vector : List [Any ] = []
213+ for category_name in self .LAB_CATEGORY_NAMES :
214+ category_value = None
215+ for itemid in self .LAB_CATEGORIES [category_name ]:
216+ matching = ts_labs .filter (pl .col ("labevents/itemid" ) == itemid )
217+ if matching .height > 0 :
218+ category_value = matching ["labevents/valuenum" ][0 ]
219+ break
220+ lab_vector .append (category_value )
221+ all_lab_values .append (lab_vector )
222+ all_lab_times .append ((lab_ts - admission_time ).total_seconds () / 3600.0 )
223+
125224 discharge_note_times_from_admission = (all_discharge_texts , all_discharge_times_from_admission )
126225 radiology_note_times_from_admission = (all_radiology_texts , all_radiology_times_from_admission )
127226
@@ -130,6 +229,8 @@ def __call__(self, patient: Any) -> List[Dict[str, Any]]:
130229 "patient_id" : patient .patient_id ,
131230 "discharge_note_times" : discharge_note_times_from_admission ,
132231 "radiology_note_times" : radiology_note_times_from_admission ,
232+ "icd_codes" : (all_icd_times , all_icd_codes ),
233+ "labs" : (all_lab_times , all_lab_values ),
133234 "mortality" : mortality_label ,
134235 }
135- ]
236+ ]
0 commit comments